summaryrefslogtreecommitdiffstats
path: root/lualib/rspamadm
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-10 21:30:40 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-10 21:30:40 +0000
commit133a45c109da5310add55824db21af5239951f93 (patch)
treeba6ac4c0a950a0dda56451944315d66409923918 /lualib/rspamadm
parentInitial commit. (diff)
downloadrspamd-upstream.tar.xz
rspamd-upstream.zip
Adding upstream version 3.8.1.upstream/3.8.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'lualib/rspamadm')
-rw-r--r--lualib/rspamadm/clickhouse.lua528
-rw-r--r--lualib/rspamadm/configgraph.lua172
-rw-r--r--lualib/rspamadm/confighelp.lua123
-rw-r--r--lualib/rspamadm/configwizard.lua849
-rw-r--r--lualib/rspamadm/cookie.lua125
-rw-r--r--lualib/rspamadm/corpus_test.lua185
-rw-r--r--lualib/rspamadm/dkim_keygen.lua178
-rw-r--r--lualib/rspamadm/dmarc_report.lua737
-rw-r--r--lualib/rspamadm/dns_tool.lua232
-rw-r--r--lualib/rspamadm/fuzzy_convert.lua208
-rw-r--r--lualib/rspamadm/fuzzy_ping.lua259
-rw-r--r--lualib/rspamadm/fuzzy_stat.lua366
-rw-r--r--lualib/rspamadm/grep.lua174
-rw-r--r--lualib/rspamadm/keypair.lua508
-rw-r--r--lualib/rspamadm/mime.lua1012
-rw-r--r--lualib/rspamadm/neural_test.lua228
-rw-r--r--lualib/rspamadm/publicsuffix.lua82
-rw-r--r--lualib/rspamadm/stat_convert.lua38
-rw-r--r--lualib/rspamadm/statistics_dump.lua544
-rw-r--r--lualib/rspamadm/template.lua131
-rw-r--r--lualib/rspamadm/vault.lua579
21 files changed, 7258 insertions, 0 deletions
diff --git a/lualib/rspamadm/clickhouse.lua b/lualib/rspamadm/clickhouse.lua
new file mode 100644
index 0000000..b22d800
--- /dev/null
+++ b/lualib/rspamadm/clickhouse.lua
@@ -0,0 +1,528 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+local lua_clickhouse = require "lua_clickhouse"
+local lua_util = require "lua_util"
+local rspamd_http = require "rspamd_http"
+local rspamd_upstream_list = require "rspamd_upstream_list"
+local rspamd_logger = require "rspamd_logger"
+local ucl = require "ucl"
+
+local E = {}
+
+-- Define command line options
+-- Top-level parser for `rspamadm clickhouse`. A sub-command is mandatory and
+-- its name is stored under the `command` key of the parse result.
+local parser = argparse()
+ :name 'rspamadm clickhouse'
+ :description 'Retrieve information from Clickhouse'
+ :help_description_margin(30)
+ :command_target('command')
+ :require_command(true)
+
+-- Connection/configuration options shared by every sub-command.
+parser:option '-c --config'
+ :description 'Path to config file'
+ :argname('config_file')
+ :default(rspamd_paths['CONFDIR'] .. '/rspamd.conf')
+parser:option '-d --database'
+ :description 'Name of Clickhouse database to use'
+ :argname('database')
+ :default('default')
+parser:flag '--no-ssl-verify'
+ :description 'Disable SSL verification'
+ :argname('no_ssl_verify')
+-- An explicit password and the interactive prompt are mutually exclusive.
+parser:mutex(
+ parser:option '-p --password'
+ :description 'Password to use for Clickhouse'
+ :argname('password'),
+ parser:flag '-a --ask-password'
+ :description 'Ask password from the terminal'
+ :argname('ask_password')
+)
+parser:option '-s --server'
+ :description 'Address[:port] to connect to Clickhouse with'
+ :argname('server')
+parser:option '-u --user'
+ :description 'Username to use for Clickhouse'
+ :argname('user')
+-- NOTE(review): unlike the neighbouring switches this is declared as an
+-- option (it takes a value) with a boolean default of `true`; a value given
+-- on the command line would presumably arrive as a string - confirm intended.
+parser:option '--use-gzip'
+ :description 'Use Gzip with Clickhouse'
+ :argname('use_gzip')
+ :default(true)
+parser:flag '--use-https'
+ :description 'Use HTTPS with Clickhouse'
+ :argname('use_https')
+
+-- `neural_profile` sub-command: derives a symbols profile from stored rows.
+local neural_profile = parser:command 'neural_profile'
+ :description 'Generate symbols profile using data from Clickhouse'
+-- Extra raw SQL condition, AND-ed with the generated ones.
+neural_profile:option '-w --where'
+ :description 'WHERE clause for Clickhouse query'
+ :argname('where')
+neural_profile:flag '-j --json'
+ :description 'Write output as JSON'
+ :argname('json')
+-- The default is the string '7'; consumers convert it with tonumber().
+neural_profile:option '--days'
+ :description 'Number of days to collect stats for'
+ :argname('days')
+ :default('7')
+-- No default: when omitted no LIMIT clause is added to the query.
+neural_profile:option '--limit -l'
+ :description 'Maximum rows to fetch per day'
+ :argname('limit')
+neural_profile:option '--settings-id'
+ :description 'Settings ID to query'
+ :argname('settings_id')
+ :default('')
+
+-- `neural_train` sub-command: collects stored training vectors and posts
+-- them to a learn endpoint.
+-- NOTE(review): the short aliases -p, -s and -u below reuse letters already
+-- taken by the top-level --password, --server and --user options, and -h is
+-- commonly the help shortcut - confirm how argparse resolves these.
+local neural_train = parser:command 'neural_train'
+ :description 'Train neural using data from Clickhouse'
+-- The default is the string '7'; consumers convert it with tonumber().
+neural_train:option '--days'
+ :description 'Number of days to query data for'
+ :argname('days')
+ :default('7')
+neural_train:option '--column-name-digest'
+ :description 'Name of neural profile digest column in Clickhouse'
+ :argname('column_name_digest')
+ :default('NeuralDigest')
+neural_train:option '--column-name-vector'
+ :description 'Name of neural training vector column in Clickhouse'
+ :argname('column_name_vector')
+ :default('NeuralMpack')
+-- No default: when omitted no LIMIT clause is added to the query.
+neural_train:option '--limit -l'
+ :description 'Maximum rows to fetch per day'
+ :argname('limit')
+neural_train:option '--profile -p'
+ :description 'Profile to use for training'
+ :argname('profile')
+ :default('default')
+neural_train:option '--rule -r'
+ :description 'Rule to train'
+ :argname('rule')
+ :default('default')
+-- The spam/ham values are raw SQL conditions inserted into the WHERE clause.
+neural_train:option '--spam -s'
+ :description 'WHERE clause to use for spam'
+ :argname('spam')
+ :default("Action == 'reject'")
+neural_train:option '--ham -h'
+ :description 'WHERE clause to use for ham'
+ :argname('ham')
+ :default('Score < 0')
+neural_train:option '--url -u'
+ :description 'URL to use for training'
+ :argname('url')
+ :default('http://127.0.0.1:11334/plugins/neural/learn')
+
+-- Shared parameter table handed to lua_clickhouse.select_sync for every
+-- query; it bundles the rspamadm-provided config, event base, session and
+-- DNS resolver globals.
+local http_params = {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+}
+
+-- Load the Rspamd configuration from `config_file` and initialise it.
+-- The stages run in order: UCL load, RCL parse of the 'logging' and 'worker'
+-- sections, module initialisation, then symcache subsystem initialisation.
+-- A failure in any of the first three stages is logged and terminates the
+-- process with exit status 1.
+local function load_config(config_file)
+  local loaded, load_err = rspamd_config:load_ucl(config_file)
+  if not loaded then
+    rspamd_logger.errx('cannot load %s: %s', config_file, load_err)
+    os.exit(1)
+  end
+
+  local parsed, parse_err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+  if not parsed then
+    rspamd_logger.errx('cannot process %s: %s', config_file, parse_err)
+    os.exit(1)
+  end
+
+  local modules_ok = rspamd_config:init_modules()
+  if not modules_ok then
+    rspamd_logger.errx('cannot init modules when parsing %s', config_file)
+    os.exit(1)
+  end
+
+  rspamd_config:init_subsystem('symcache')
+end
+
+-- Build the list of calendar days to query, newest first, starting with
+-- yesterday (local time).
+-- @param days number of days to produce; a number or a numeric string
+-- @return array of 'YYYY-MM-DD' strings, `days` entries long
+-- Raises an error when `days` cannot be converted to a number.
+local function days_list(days)
+  local num_days = tonumber(days)
+  if not num_days then
+    error(string.format('invalid number of days: %s', tostring(days)), 2)
+  end
+
+  local query_days = {}
+  local today = os.date('*t')
+  for offset = 1, num_days do
+    -- Step by calendar day, anchored at local noon, instead of subtracting
+    -- 86400 seconds: days affected by a DST switch are 23 or 25 hours long,
+    -- so fixed-size steps can repeat or skip a date. os.time() normalises
+    -- an out-of-range `day` field.
+    local stamp = os.time({
+      year = today.year,
+      month = today.month,
+      day = today.day - offset,
+      hour = 12,
+    })
+    query_days[#query_days + 1] = os.date('%Y-%m-%d', stamp)
+  end
+  return query_days
+end
+
+-- Decide which of the observed symbols should be left out of a profile.
+-- @param known_symbols map: symbol name -> { id, seen, seen_ham, seen_spam }
+-- @param correlations map: symbol id -> { other symbol id -> co-occurrence count }
+-- @param seen_total total number of rows that were processed
+-- @return map: symbol name -> human readable reason for the exclusion
+local function get_excluded_symbols(known_symbols, correlations, seen_total)
+  local remove = {}
+  local known_symbols_list = {} -- symbol id -> { seen, name }
+  -- Either lookup may yield nothing (e.g. no such config section); fall back
+  -- to empty tables so the checks below do not index nil.
+  local composites = rspamd_config:get_all_opt('composites') or {}
+  local all_symbols = rspamd_config:get_symbols() or {}
+  local skip_flags = {
+    nostat = true,
+    skip = true,
+    idempotent = true,
+    composite = true,
+  }
+
+  -- Walk results once to collect all symbols & count occurrences
+  for k, v in pairs(known_symbols) do
+    local lower_count = math.min(v.seen_spam, v.seen_ham)
+    local higher_count = math.max(v.seen_spam, v.seen_ham)
+
+    -- The first matching reason wins, in this order.
+    if composites[k] then
+      remove[k] = 'composite symbol'
+    elseif lower_count / higher_count >= 0.95 then
+      -- seen (almost) equally often in ham and spam
+      remove[k] = 'weak ham/spam correlation'
+    elseif v.seen / seen_total >= 0.9 then
+      remove[k] = 'omnipresent symbol'
+    elseif not all_symbols[k] then
+      remove[k] = 'nonexistent symbol'
+    else
+      for fl, _ in pairs(all_symbols[k].flags or {}) do
+        if skip_flags[fl] then
+          remove[k] = fl .. ' symbol'
+          break
+        end
+      end
+    end
+    known_symbols_list[v.id] = {
+      seen = v.seen,
+      name = k,
+    }
+  end
+
+  -- Walk correlation matrix and check total counts: a symbol that appeared
+  -- together with another one on every occurrence is excluded, unless either
+  -- of the pair has already been excluded.
+  for sym_id, row in pairs(correlations) do
+    local known = known_symbols_list[sym_id]
+    for inner_sym_id, count in pairs(row) do
+      local inner = known_symbols_list[inner_sym_id]
+      if known and inner and count == known.seen
+          and not remove[inner.name] and not remove[known.name] then
+        remove[known.name] = string.format("overlapped by %s", inner.name)
+      end
+    end
+  end
+
+  return remove
+end
+
+-- Handler for the `neural_profile` command.
+-- Queries `Action` and `Symbols.Names` for each requested day, counts how
+-- often every symbol occurs (overall, in spam, in ham, and together with
+-- every other symbol), then prints the symbols that survive
+-- get_excluded_symbols(): one name per line, or a JSON document when
+-- `args.json` is set. Exits with status 1 on a query error; the plain-text
+-- path ends with os.exit(0).
+local function handle_neural_profile(args)
+
+ -- known_symbols: name -> { id, seen, seen_ham, seen_spam }
+ -- correlations: symbol id -> { other symbol id -> co-occurrence count }
+ local known_symbols, correlations = {}, {}
+ local symbols_count, seen_total = 0, 0
+
+ -- Row callback passed to lua_clickhouse.select_sync
+ local function process_row(r)
+ -- Every action other than 'no action' and 'greylist' counts as spam
+ local is_spam = true
+ if r['Action'] == 'no action' or r['Action'] == 'greylist' then
+ is_spam = false
+ end
+ seen_total = seen_total + 1
+
+ local nsym = #r['Symbols.Names']
+
+ for i = 1, nsym do
+ local sym = r['Symbols.Names'][i]
+ local t = known_symbols[sym]
+ if not t then
+ -- First sighting: allocate the next id (ids start at 0)
+ local spam_count, ham_count = 0, 0
+ if is_spam then
+ spam_count = spam_count + 1
+ else
+ ham_count = ham_count + 1
+ end
+ known_symbols[sym] = {
+ id = symbols_count,
+ seen = 1,
+ seen_ham = ham_count,
+ seen_spam = spam_count,
+ }
+ symbols_count = symbols_count + 1
+ else
+ known_symbols[sym].seen = known_symbols[sym].seen + 1
+ if is_spam then
+ known_symbols[sym].seen_spam = known_symbols[sym].seen_spam + 1
+ else
+ known_symbols[sym].seen_ham = known_symbols[sym].seen_ham + 1
+ end
+ end
+ end
+
+ -- Fill correlations
+ -- (every ordered pair of distinct positions in this row is counted)
+ for i = 1, nsym do
+ for j = 1, nsym do
+ if i ~= j then
+ local sym = r['Symbols.Names'][i]
+ local inner_sym_name = r['Symbols.Names'][j]
+ local known_sym = known_symbols[sym]
+ local inner_sym = known_symbols[inner_sym_name]
+ if known_sym and inner_sym then
+ if not correlations[known_sym.id] then
+ correlations[known_sym.id] = {}
+ end
+ local n = correlations[known_sym.id][inner_sym.id] or 0
+ n = n + 1
+ correlations[known_sym.id][inner_sym.id] = n
+ end
+ end
+ end
+ end
+ end
+
+ local query_days = days_list(args.days)
+ local conditions = {}
+ -- NOTE(review): settings_id (and the raw --where clause below) are placed
+ -- into the query text without any escaping.
+ table.insert(conditions, string.format("SettingsId = '%s'", args.settings_id))
+ local limit = ''
+ local num_limit = tonumber(args.limit)
+ if num_limit then
+ limit = string.format(' LIMIT %d', num_limit) -- Contains leading space
+ end
+ if args.where then
+ table.insert(conditions, args.where)
+ end
+
+ -- One query per day: the Date condition is pushed, used and popped again
+ local query_fmt = 'SELECT Action, Symbols.Names FROM rspamd WHERE %s%s'
+ for _, query_day in ipairs(query_days) do
+ -- Date should be the last condition
+ table.insert(conditions, string.format("Date = '%s'", query_day))
+ local query = string.format(query_fmt, table.concat(conditions, ' AND '), limit)
+ local upstream = args.upstream:get_upstream_round_robin()
+ local err = lua_clickhouse.select_sync(upstream, args, http_params, query, process_row)
+ if err ~= nil then
+ io.stderr:write(string.format('Error querying Clickhouse: %s\n', err))
+ os.exit(1)
+ end
+ conditions[#conditions] = nil -- remove Date condition
+ end
+
+ local remove = get_excluded_symbols(known_symbols, correlations, seen_total)
+ -- Plain-text output: only the symbols that were not excluded
+ if not args.json then
+ for k in pairs(known_symbols) do
+ if not remove[k] then
+ io.stdout:write(string.format('%s\n', k))
+ end
+ end
+ os.exit(0)
+ end
+
+ -- JSON output: all symbols, the excluded ones with their reason, and the
+ -- ones that remain in use
+ local json_output = {
+ all_symbols = {},
+ removed_symbols = {},
+ used_symbols = {},
+ }
+ for k in pairs(known_symbols) do
+ table.insert(json_output.all_symbols, k)
+ local why_removed = remove[k]
+ if why_removed then
+ json_output.removed_symbols[k] = why_removed
+ else
+ table.insert(json_output.used_symbols, k)
+ end
+ end
+ io.stdout:write(ucl.to_format(json_output, 'json'))
+end
+
+-- Send the collected training vectors to the learn endpoint at `url`.
+-- The body is a JSON object with the keys `ham_vec`, `rule` and `spam_vec`.
+-- A transport error or any HTTP status other than 200 is reported on stderr
+-- and ends the process with exit status 1; on success the response content
+-- is written to stdout followed by a newline.
+local function post_neural_training(url, rule, spam_rows, ham_rows)
+  local request = {
+    body = ucl.to_format({
+      ham_vec = ham_rows,
+      rule = rule,
+      spam_vec = spam_rows,
+    }, 'json'),
+    config = rspamd_config,
+    ev_base = rspamadm_ev_base,
+    log_obj = rspamd_config,
+    resolver = rspamadm_dns_resolver,
+    session = rspamadm_session,
+    url = url,
+  }
+
+  local err, response = rspamd_http.request(request)
+
+  if err then
+    io.stderr:write(string.format('HTTP error: %s\n', err))
+    os.exit(1)
+  end
+
+  local status = response.code
+  if status ~= 200 then
+    io.stderr:write(string.format('bad HTTP code: %d\n', status))
+    os.exit(1)
+  end
+
+  io.stdout:write(string.format('%s\n', response.content))
+end
+
+-- Handler for the `neural_train` command.
+-- Looks up the requested profile of the requested neural rule in the config,
+-- derives its digest, fetches stored training vectors matching that digest
+-- for ham and then for spam (day by day until `max_trains` rows of the class
+-- have been collected) and posts both sets via post_neural_training().
+-- Exits with status 1 when the profile is missing, a query fails, a vector
+-- cannot be parsed, or too few rows were found for either class.
+local function handle_neural_train(args)
+
+ local this_where -- which class of messages are we collecting data for
+ local ham_rows, spam_rows = {}, {}
+ local want_spam, want_ham = true, true -- keep collecting while true
+
+ -- Try find profile in config
+ -- (`E` is an empty table used so that a missing level yields nil instead of an error)
+ local neural_opts = rspamd_config:get_all_opt('neural')
+ local symbols_profile = ((((neural_opts or E).rules or E)[args.rule] or E).profile or E)[args.profile]
+ if not symbols_profile then
+ io.stderr:write(string.format("Couldn't find profile %s in rule %s\n", args.profile, args.rule))
+ os.exit(1)
+ end
+ -- Try find max_trains
+ -- (defaults to 1000; the rule is known to exist once the profile was found)
+ local max_trains = (neural_opts.rules[args.rule].train or E).max_trains or 1000
+
+ -- Callback used to process rows from Clickhouse
+ -- The class is identified by comparing `this_where` with args.ham, so
+ -- anything else is treated as spam.
+ local function process_row(r)
+ local destination -- which table to collect this information in
+ if this_where == args.ham then
+ destination = ham_rows
+ if #destination >= max_trains then
+ want_ham = false
+ return
+ end
+ else
+ destination = spam_rows
+ if #destination >= max_trains then
+ want_spam = false
+ return
+ end
+ end
+ -- The vector column is decoded as msgpack
+ local ucl_parser = ucl.parser()
+ local ok, err = ucl_parser:parse_string(r[args.column_name_vector], 'msgpack')
+ if not ok then
+ -- NOTE(review): unlike the other stderr messages this one has no trailing newline
+ io.stderr:write(string.format("Couldn't parse [%s]: %s", r[args.column_name_vector], err))
+ os.exit(1)
+ end
+ table.insert(destination, ucl_parser:get_object())
+ end
+
+ -- Generate symbols digest
+ -- (table.sort works in place, i.e. on the table obtained from the config)
+ table.sort(symbols_profile)
+ local symbols_digest = lua_util.table_digest(symbols_profile)
+ -- Create list of days to query data for
+ local query_days = days_list(args.days)
+ -- Set value for limit
+ local limit = ''
+ local num_limit = tonumber(args.limit)
+ if num_limit then
+ limit = string.format(' LIMIT %d', num_limit) -- Contains leading space
+ end
+ -- Prepare query elements
+ local conditions = { string.format("%s = '%s'", args.column_name_digest, symbols_digest) }
+ local query_fmt = 'SELECT %s FROM rspamd WHERE %s%s'
+
+ -- Run queries
+ -- `conditions` is used as a stack: digest, then class clause, then Date
+ for _, the_where in ipairs({ args.ham, args.spam }) do
+ -- Inform callback which group of vectors we're collecting
+ this_where = the_where
+ table.insert(conditions, the_where) -- should be 2nd from last condition
+ -- Loop over days and try collect data
+ for _, query_day in ipairs(query_days) do
+ -- Break the loop if we have enough data already
+ if this_where == args.ham then
+ if not want_ham then
+ break
+ end
+ else
+ if not want_spam then
+ break
+ end
+ end
+ -- Date should be the last condition
+ table.insert(conditions, string.format("Date = '%s'", query_day))
+ local query = string.format(query_fmt, args.column_name_vector, table.concat(conditions, ' AND '), limit)
+ local upstream = args.upstream:get_upstream_round_robin()
+ local err = lua_clickhouse.select_sync(upstream, args, http_params, query, process_row)
+ if err ~= nil then
+ io.stderr:write(string.format('Error querying Clickhouse: %s\n', err))
+ os.exit(1)
+ end
+ conditions[#conditions] = nil -- remove Date condition
+ end
+ conditions[#conditions] = nil -- remove spam/ham condition
+ end
+
+ -- Make sure we collected enough data for training
+ if #ham_rows < max_trains then
+ io.stderr:write(string.format('Insufficient ham rows: %d/%d\n', #ham_rows, max_trains))
+ os.exit(1)
+ end
+ if #spam_rows < max_trains then
+ io.stderr:write(string.format('Insufficient spam rows: %d/%d\n', #spam_rows, max_trains))
+ os.exit(1)
+ end
+
+ return post_neural_training(args.url, args.rule, spam_rows, ham_rows)
+end
+
+-- Dispatch table: sub-command name (as stored by the parser under
+-- `command`) -> function implementing it.
+local command_handlers = {
+ neural_profile = handle_neural_profile,
+ neural_train = handle_neural_train,
+}
+
+-- Entry point of `rspamadm clickhouse`.
+-- Parses the command line, loads the Rspamd configuration, fills in
+-- connection settings that were not given on the command line from the
+-- `clickhouse` config section, builds the upstream list and dispatches to
+-- the selected sub-command.
+-- @param args raw command line arguments
+local function handler(args)
+  local cmd_opts = parser:parse(args)
+
+  -- The option is declared as `--config`, so argparse stores it under the
+  -- `config` key (`config_file` is only its display name in the usage
+  -- text). The old key is kept as a fallback.
+  load_config(cmd_opts.config or cmd_opts.config_file)
+  -- The section may be missing from the config altogether.
+  local cfg_opts = rspamd_config:get_all_opt('clickhouse') or {}
+
+  if cmd_opts.ask_password then
+    local rspamd_util = require "rspamd_util"
+
+    io.write('Password: ')
+    cmd_opts.password = rspamd_util.readpassphrase()
+  end
+
+  -- Take a setting from the config only when the command line did not
+  -- provide it.
+  local function override_settings(params)
+    for _, which in ipairs(params) do
+      if cmd_opts[which] == nil then
+        cmd_opts[which] = cfg_opts[which]
+      end
+    end
+  end
+
+  -- `servers` is included because it is consulted below as an alternative
+  -- to `server`; without it that lookup could never succeed.
+  override_settings({
+    'database', 'no_ssl_verify', 'password', 'server', 'servers',
+    'use_gzip', 'use_https', 'user',
+  })
+
+  local servers = cmd_opts['server'] or cmd_opts['servers']
+  if not servers then
+    parser:error("server(s) unspecified & couldn't be fetched from config")
+  end
+
+  cmd_opts.upstream = rspamd_upstream_list.create(rspamd_config, servers, 8123)
+
+  if not cmd_opts.upstream then
+    -- tostring(): a value taken from the config need not be a plain string
+    io.stderr:write(string.format("can't parse clickhouse address: %s\n",
+        tostring(servers)))
+    os.exit(1)
+  end
+
+  local f = command_handlers[cmd_opts.command]
+  if not f then
+    parser:error(string.format("command isn't implemented: %s",
+        cmd_opts.command))
+  end
+  f(cmd_opts)
+end
+
+-- Module descriptor: the entry point, the parser's description text and
+-- the command name.
+return {
+ handler = handler,
+ description = parser._description,
+ name = 'clickhouse'
+}
diff --git a/lualib/rspamadm/configgraph.lua b/lualib/rspamadm/configgraph.lua
new file mode 100644
index 0000000..07f14a9
--- /dev/null
+++ b/lualib/rspamadm/configgraph.lua
@@ -0,0 +1,172 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local rspamd_logger = require "rspamd_logger"
+local rspamd_util = require "rspamd_util"
+local rspamd_regexp = require "rspamd_regexp"
+local argparse = require "argparse"
+
+-- Define command line options
+-- Parser for `rspamadm configgraph`: the config file to trace (defaults to
+-- rspamd.conf inside CONFDIR) and a switch to also show files that do not
+-- exist.
+local parser = argparse()
+ :name "rspamadm configgraph"
+ :description "Produces graph of Rspamd includes"
+ :help_description_margin(30)
+parser:option "-c --config"
+ :description "Path to config file"
+ :argname("<file>")
+ :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+parser:flag "-a --all"
+ :description('Show all nodes, not just existing ones')
+
+-- Shorten a path for display by removing every occurrence of the
+-- configuration directory (CONFDIR followed by '/') from it.
+-- @param fname path to shorten
+-- @return the shortened path (a single value)
+local function process_filename(fname)
+  local cdir = rspamd_paths['CONFDIR'] .. '/'
+  -- gsub() treats its second argument as a Lua pattern, so characters such
+  -- as '-', '.' or '+' in the directory name must be escaped to be matched
+  -- literally.
+  local literal_cdir = cdir:gsub('%p', '%%%0')
+  -- Parentheses drop the replacement count returned by gsub().
+  return (fname:gsub(literal_cdir, ''))
+end
+
+-- Emit the include graph in Graphviz DOT syntax through
+-- rspamd_logger.messagex.
+-- @param opts parsed options; `opts.all` also shows non-existing files
+-- @param nodes map: file path -> node ({ exists, priority, ... })
+-- @param adjacency array of edges ({ from = <path>, to = <node> })
+local function output_dot(opts, nodes, adjacency)
+  rspamd_logger.messagex("digraph rspamd {")
+
+  -- Nodes: existing files are coloured by priority, missing ones are either
+  -- drawn dotted (with --all) or left out.
+  for path, node in pairs(nodes) do
+    local attrs = { "shape=box" }
+    local visible = true
+
+    if not node.exists then
+      if opts.all then
+        table.insert(attrs, "style=dotted")
+      else
+        visible = false
+      end
+    elseif node.priority >= 10 then
+      table.insert(attrs, "color=red")
+    elseif node.priority > 0 then
+      table.insert(attrs, "color=blue")
+    end
+
+    if visible then
+      rspamd_logger.messagex("\"%s\" [%s];", process_filename(path),
+          table.concat(attrs, ','))
+    end
+  end
+
+  -- Edges: styled after their target node.
+  for _, edge in ipairs(adjacency) do
+    local target = edge.to
+    local attrs = {}
+    local visible = true
+
+    if not target.exists then
+      if opts.all then
+        table.insert(attrs, "style=dotted")
+      else
+        visible = false
+      end
+    elseif target.merge then
+      table.insert(attrs, "arrowhead=diamond")
+      table.insert(attrs, "label=\"+\"")
+    elseif target.priority > 1 then
+      table.insert(attrs, "color=red")
+    end
+
+    if visible then
+      rspamd_logger.messagex("\"%s\" -> \"%s\" [%s];", process_filename(edge.from),
+          target.short_path, table.concat(attrs, ','))
+    end
+  end
+
+  rspamd_logger.messagex("}")
+end
+
+-- Load the configuration named by opts['config'] with a trace callback that
+-- records every include, then print the collected graph via output_dot().
+-- Logs an error and exits with status 1 when the configuration cannot be
+-- parsed.
+local function load_config_traced(opts)
+ local glob_traces = {} -- glob includes seen so far: { re, parent, args, seen }
+ local adjacency = {} -- list of edges: { from = <path>, to = <node>, args }
+ local nodes = {} -- file path -> node
+
+ -- Return the first recorded glob whose expression matches `file`, or nil
+ local function maybe_match_glob(file)
+ for _, gl in ipairs(glob_traces) do
+ if gl.re:match(file) then
+ return gl
+ end
+ end
+
+ return nil
+ end
+
+ -- Append an edge from path `from` to `node`
+ local function add_dep(from, node, args)
+ adjacency[#adjacency + 1] = {
+ from = from,
+ to = node,
+ args = args
+ }
+ end
+
+ -- Return the node for `fname`, creating it on first use; the attributes
+ -- of an already existing node are not updated from `args`
+ local function process_node(fname, args)
+ local node = nodes[fname]
+ if not node then
+ node = {
+ path = fname,
+ short_path = process_filename(fname),
+ exists = rspamd_util.file_exists(fname),
+ merge = args.duplicate and args.duplicate == 'merge',
+ priority = args.priority or 0,
+ glob = args.glob,
+ try = args.try,
+ }
+ nodes[fname] = node
+ end
+
+ return node
+ end
+
+ -- Callback handed to load_ucl; presumably invoked once per include
+ -- statement - confirm against the load_ucl documentation.
+ -- Glob includes are only remembered; plain includes become nodes/edges.
+ local function trace_func(cur_file, included_file, args, parent)
+ if args.glob then
+ glob_traces[#glob_traces + 1] = {
+ re = rspamd_regexp.import_glob(included_file, ''),
+ parent = cur_file,
+ args = args,
+ seen = {},
+ }
+ else
+ local node = process_node(included_file, args)
+ if opts.all or node.exists then
+ local gl_parent = maybe_match_glob(included_file)
+ -- NOTE(review): the glob is matched against included_file but the
+ -- edge target is nodes[cur_file], which is nil when cur_file was never
+ -- passed to process_node; output_dot reads adj.to.exists without a
+ -- guard - confirm this cannot happen.
+ if gl_parent and not gl_parent.seen[cur_file] then
+ add_dep(gl_parent.parent, nodes[cur_file], gl_parent.args)
+ gl_parent.seen[cur_file] = true
+ end
+ add_dep(cur_file, node, args)
+ end
+ end
+ end
+
+ local _r, err = rspamd_config:load_ucl(opts['config'], trace_func)
+ if not _r then
+ rspamd_logger.errx('cannot parse %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
+ output_dot(opts, nodes, adjacency)
+end
+
+-- Entry point of `rspamadm configgraph`: parse the command line and trace
+-- the configuration with the resulting options.
+local function handler(args)
+  -- Parentheses keep only the first value returned by the parser.
+  load_config_traced((parser:parse(args)))
+end
+
+-- Module descriptor: the entry point, the parser's description text and
+-- the command name.
+return {
+ handler = handler,
+ description = parser._description,
+ name = 'configgraph'
+} \ No newline at end of file
diff --git a/lualib/rspamadm/confighelp.lua b/lualib/rspamadm/confighelp.lua
new file mode 100644
index 0000000..38b26b6
--- /dev/null
+++ b/lualib/rspamadm/confighelp.lua
@@ -0,0 +1,123 @@
+-- Parsed command line options; assigned by the exported entry function and
+-- read by the helpers below.
+local opts
+-- Keys of a help element that describe the element itself; every other key
+-- is treated as a nested configuration element by print_help().
+local known_attrs = {
+ data = 1,
+ example = 1,
+ type = 1,
+ required = 1,
+ default = 1,
+}
+local argparse = require "argparse"
+local ansicolors = require "ansicolors"
+
+-- Parser for `rspamadm confighelp`: optional config paths plus output
+-- switches.
+local parser = argparse()
+ :name "rspamadm confighelp"
+ :description "Shows help for the specified configuration options"
+ :help_description_margin(32)
+parser:argument "path":args "*"
+ :description('Optional config paths')
+parser:flag "--no-color"
+ :description "Disable coloured output"
+parser:flag "--short"
+ :description "Show only option names"
+parser:flag "--no-examples"
+ :description "Do not show examples (implied by --short)"
+
+-- Wrap `key` in ANSI highlighting unless coloured output was disabled.
+-- @param key text to highlight
+-- @return the text, highlighted or unchanged
+local function maybe_print_color(key)
+  -- argparse stores `--no-color` with the dashes turned into underscores,
+  -- so the dashed key alone never matches; both spellings are accepted.
+  if opts['no_color'] or opts['no-color'] then
+    return key
+  end
+
+  return ansicolors.white .. key .. ansicolors.reset
+end
+
+-- Turn a map into an array of { key = ..., value = ... } entries in display
+-- order: the well-known top-level sections first, in a fixed order, then
+-- all remaining keys in ascending order.
+local function sort_values(tbl)
+  -- Fixed rank of the well-known sections
+  local rank = {
+    options = 1,
+    dns = 2,
+    upstream = 3,
+    logging = 4,
+    metric = 5,
+    composite = 6,
+    classifier = 7,
+    modules = 8,
+    lua = 9,
+    worker = 10,
+    workers = 11,
+  }
+
+  local entries = {}
+  for name, val in pairs(tbl) do
+    entries[#entries + 1] = { key = name, value = val }
+  end
+
+  table.sort(entries, function(lhs, rhs)
+    local lrank, rrank = rank[lhs.key], rank[rhs.key]
+
+    if lrank and rrank then
+      return lrank < rrank
+    end
+    if lrank then
+      -- a ranked key always precedes an unranked one
+      return true
+    end
+    if rrank then
+      return false
+    end
+    return lhs.key < rhs.key
+  end)
+
+  return entries
+end
+
+-- Print the help for one configuration element and, recursively, for all
+-- of its nested elements.
+-- @param key name of the element
+-- @param value table describing it; the keys listed in `known_attrs`
+--   describe the element itself, every other key is a nested element
+-- @param tabs indentation prefix for this nesting level
+local function print_help(key, value, tabs)
+  print(string.format('%sConfiguration element: %s', tabs, maybe_print_color(key)))
+
+  if not opts['short'] then
+    if value['data'] then
+      -- Drop a leading '#' comment marker from the description
+      local nv = string.match(value['data'], '^#%s*(.*)%s*$') or value.data
+      print(string.format('%s\tDescription: %s', tabs, nv))
+    end
+    if type(value['type']) == 'string' then
+      print(string.format('%s\tType: %s', tabs, value['type']))
+    end
+    if type(value['required']) == 'boolean' then
+      -- Only `true` is highlighted
+      if value['required'] then
+        print(string.format('%s\tRequired: %s', tabs,
+            maybe_print_color(tostring(value['required']))))
+      else
+        print(string.format('%s\tRequired: %s', tabs,
+            tostring(value['required'])))
+      end
+    end
+    if value['default'] then
+      print(string.format('%s\tDefault: %s', tabs, value['default']))
+    end
+    -- argparse stores `--no-examples` with the dashes turned into
+    -- underscores, so the dashed key alone never matches; both spellings
+    -- are accepted.
+    local hide_examples = opts['no_examples'] or opts['no-examples']
+    if not hide_examples and value['example'] then
+      -- Trim surrounding whitespace from the example
+      local nv = string.match(value['example'], '^%s*(.*[^%s])%s*$') or value.example
+      print(string.format('%s\tExample:\n%s', tabs, nv))
+    end
+    if value.type and value.type == 'object' then
+      print('')
+    end
+  end
+
+  local sorted = sort_values(value)
+  for _, v in ipairs(sorted) do
+    if not known_attrs[v['key']] then
+      -- We need to go deeper
+      print_help(v['key'], v['value'], tabs .. '\t')
+    end
+  end
+end
+
+-- Module entry point: parse the command line into the shared `opts`, then
+-- print the help for every top-level element of `res`, each followed by an
+-- empty line.
+return function(args, res)
+  opts = parser:parse(args)
+
+  for _, entry in ipairs(sort_values(res)) do
+    print_help(entry.key, entry.value, '')
+    print('')
+  end
+end
diff --git a/lualib/rspamadm/configwizard.lua b/lualib/rspamadm/configwizard.lua
new file mode 100644
index 0000000..2637036
--- /dev/null
+++ b/lualib/rspamadm/configwizard.lua
@@ -0,0 +1,849 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local ansicolors = require "ansicolors"
+local local_conf = rspamd_paths['CONFDIR']
+local rspamd_util = require "rspamd_util"
+local rspamd_logger = require "rspamd_logger"
+local lua_util = require "lua_util"
+local lua_stat_tools = require "lua_stat"
+local lua_redis = require "lua_redis"
+local ucl = require "ucl"
+local argparse = require "argparse"
+local fun = require "fun"
+
+local plugins_stat = require "plugins_stats"
+
+local rspamd_logo = [[
+ ____ _
+ | _ \ ___ _ __ __ _ _ __ ___ __| |
+ | |_) |/ __|| '_ \ / _` || '_ ` _ \ / _` |
+ | _ < \__ \| |_) || (_| || | | | | || (_| |
+ |_| \_\|___/| .__/ \__,_||_| |_| |_| \__,_|
+ |_|
+]]
+
+-- Parser for `rspamadm configwizard`: the config file to inspect (defaults
+-- to rspamd.conf inside CONFDIR) and an optional list of checks.
+local parser = argparse()
+ :name "rspamadm configwizard"
+ :description "Perform guided configuration for Rspamd daemon"
+ :help_description_margin(32)
+parser:option "-c --config"
+ :description "Path to config file"
+ :argname("<file>")
+ :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+parser:argument "checks"
+ :description "Checks to do (or 'list')"
+ :argname("<checks>")
+ :args "*"
+
+-- Redis connection parameters gathered interactively; assigned by
+-- setup_redis().
+local redis_params
+
+-- Write a formatted line to standard output; without a format string only
+-- the line break is written.
+local function printf(fmt, ...)
+  local text = '\n'
+  if fmt then
+    text = string.format(fmt, ...) .. text
+  end
+  io.write(text)
+end
+
+-- Surround `str` with the ansicolors `white` and `reset` sequences.
+local function highlight(str)
+  local tail = str .. ansicolors.reset
+  return ansicolors.white .. tail
+end
+
+-- Ask a yes/no question on the terminal.
+-- @param greet the question; the "[Y/n]: " or "[y/N]: " hint is appended
+-- @param default answer used for an empty reply (truthy means yes)
+-- @return true when the (case-insensitive) answer is 'y' or 'yes',
+--   false otherwise
+-- Exits with status 0 when no reply can be read.
+local function ask_yes_no(greet, default)
+  local hint = default and "[Y/n]: " or "[y/N]: "
+  local fallback = default and "yes" or "no"
+
+  local answer = rspamd_util.readline(greet .. hint)
+  if not answer then
+    os.exit(0)
+  end
+
+  if #answer == 0 then
+    answer = fallback
+  end
+  answer = answer:lower()
+
+  return answer == 'y' or answer == 'yes'
+end
+
+-- Prompt with `greet` and return the reply, or `def_value` when the reply
+-- is empty. Exits with status 0 when no reply can be read.
+local function readline_default(greet, def_value)
+  local answer = rspamd_util.readline(greet)
+  if not answer then
+    os.exit(0)
+  end
+
+  if #answer > 0 then
+    return answer
+  end
+  return def_value
+end
+
+-- Ask for the expire time of new tokens until an acceptable value is given
+-- and return it as parsed by lua_util.parse_time_interval.
+-- A value is accepted when it is an integer between -1 and 2147483647;
+-- otherwise the question is repeated with a corrected value offered as the
+-- new default.
+local function readline_expire()
+ -- `expire` doubles as the default shown in the prompt and as the parsed
+ -- result; it starts as a string and the correction branches below may
+ -- replace it with a number
+ local expire = '100d'
+ repeat
+ expire = readline_default("Expire time for new tokens [" .. expire .. "]: ",
+ expire)
+ expire = lua_util.parse_time_interval(expire)
+
+ if not expire then
+ -- parsing yielded nothing: silently go back to the initial default
+ expire = '100d'
+ elseif expire > 2147483647 then
+ printf("The maximum possible value is 2147483647 (about 68y)")
+ expire = '68y'
+ elseif expire < -1 then
+ printf("The value must be a non-negative integer or -1")
+ expire = -1
+ elseif expire ~= math.floor(expire) then
+ printf("The value must be an integer")
+ expire = math.floor(expire)
+ else
+ return expire
+ end
+ until false
+end
+
+-- Show the pending configuration changes.
+-- For every file in `changes.l` the local.d changes are listed, followed by
+-- the override.d changes stored under the same name in `changes.o` (if
+-- any) and an empty line.
+local function print_changes(changes)
+  -- List the key => value pairs of one file below `where`
+  local function print_change(fname, items, where)
+    local full_path = local_conf .. '/' .. where .. '/' .. fname
+    printf('File: %s, changes list:', highlight(full_path))
+
+    for opt_name, opt_value in pairs(items) do
+      printf("%s => %s", highlight(opt_name), rspamd_logger.slog("%s", opt_value))
+    end
+  end
+
+  for fname, local_items in pairs(changes.l) do
+    print_change(fname, local_items, 'local.d')
+
+    local override_items = changes.o[fname]
+    if override_items then
+      print_change(fname, override_items, 'override.d')
+    end
+
+    print()
+  end
+end
+
+-- Write the pending configuration changes to disk.
+-- For every file in `changes.l` the local.d changes are written, followed
+-- by the override.d changes stored under the same name in `changes.o` (if
+-- any). Exits with status 1 when a directory cannot be created or a file
+-- cannot be opened.
+-- NOTE(review): entries of `changes.o` whose name is not also present in
+-- `changes.l` are never visited - confirm that this is intended.
+local function apply_changes(changes)
+ -- Return the directory part of `fname` including the trailing slash, or
+ -- nil when the path contains no slash. gsub's second result (the match
+ -- count) is returned as well; the caller keeps only the first value.
+ local function dirname(fname)
+ if fname:match(".-/.-") then
+ return string.gsub(fname, "(.*/)(.*)", "%1")
+ else
+ return nil
+ end
+ end
+
+ -- Append the changes `c` for file `k` below the `where` directory
+ local function apply_change(k, c, where)
+ local fname = local_conf .. '/' .. where .. '/' .. k
+
+ if not rspamd_util.file_exists(fname) then
+ printf("Create file %s", highlight(fname))
+
+ local dname = dirname(fname)
+
+ if dname then
+ local ret, err = rspamd_util.mkdir(dname, true)
+
+ if not ret then
+ printf("Cannot make directory %s: %s", dname, highlight(err))
+ os.exit(1)
+ end
+ end
+ end
+
+ -- "a+" appends: whatever the file already contains is kept
+ local f = io.open(fname, "a+")
+
+ if not f then
+ printf("Cannot open file %s, aborting", highlight(fname))
+ os.exit(1)
+ end
+
+ -- NOTE(review): the result of the write is not checked
+ f:write(ucl.to_config(c))
+
+ f:close()
+ end
+ for k, v in pairs(changes.l) do
+ apply_change(k, v, 'local.d')
+ if changes.o[k] then
+ v = changes.o[k]
+ apply_change(k, v, 'override.d')
+ end
+ end
+end
+
+-- Offer to set a controller password when none is configured or when it
+-- still equals 'q1'. On confirmation the encrypted password obtained from
+-- rspamadm.pw_encrypt() is recorded in changes.l['worker-controller.inc'].
+local function setup_controller(controller, changes)
+  printf("Setup %s and controller worker:", highlight("WebUI"))
+
+  local has_password = controller.password and controller.password ~= 'q1'
+  if has_password then
+    return
+  end
+
+  if not ask_yes_no("Controller password is not set, do you want to set one?", true) then
+    return
+  end
+
+  local pw_encrypted = rspamadm.pw_encrypt()
+  if not pw_encrypted then
+    return
+  end
+
+  printf("Set encrypted password to: %s", highlight(pw_encrypted))
+  changes.l['worker-controller.inc'] = {
+    password = pw_encrypted
+  }
+end
+
+-- Interactively collect the Redis connection settings.
+-- The answers are recorded in changes.l['redis.conf'] (as they will be
+-- written to the config) and in the module-level `redis_params` table.
+-- Nothing is changed when the user declines to configure Redis.
+-- @param cfg current configuration (not used here)
+-- @param changes accumulator of pending changes; `changes.l` is updated
+local function setup_redis(cfg, changes)
+  -- Split a comma separated server list into an array
+  local function parse_servers(servers)
+    return lua_util.rspamd_str_split(servers, ",")
+  end
+
+  printf("%s servers are not set:", highlight("Redis"))
+  printf("The following modules will be enabled if you add Redis servers:")
+
+  for k, _ in pairs(rspamd_plugins_state.disabled_redis) do
+    printf("\t* %s", highlight(k))
+  end
+
+  if not ask_yes_no("Do you wish to set Redis servers?", true) then
+    return
+  end
+
+  -- Settings are gathered here and published at the end. Previously the
+  -- table inside `changes.l` was only created for a non-empty read server
+  -- list but indexed unconditionally afterwards, which raised an error
+  -- when the list was empty (e.g. for the reply ",").
+  local redis_conf = {}
+
+  -- Ask for the Redis password and record a non-empty answer
+  local function ask_password()
+    local passwd = readline_default("Enter Redis password:", nil)
+
+    if passwd then
+      redis_conf['password'] = passwd
+      redis_params['password'] = passwd
+    end
+  end
+
+  local read_servers = readline_default("Input read only servers separated by `,` [default: localhost]: ",
+      "localhost")
+
+  local rs = parse_servers(read_servers)
+  if rs and #rs > 0 then
+    redis_conf.read_servers = table.concat(rs, ",")
+  end
+
+  -- An empty reply selects the read servers, so the result is never empty
+  -- and no further fallback is needed.
+  local write_servers = readline_default("Input write only servers separated by `,` [default: "
+      .. read_servers .. "]: ", read_servers)
+
+  redis_params = {
+    read_servers = rs,
+  }
+
+  local ws = parse_servers(write_servers)
+  if ws and #ws > 0 then
+    redis_conf['write_servers'] = table.concat(ws, ",")
+    redis_params['write_servers'] = ws
+  end
+
+  if ask_yes_no('Do you have any username set for your Redis (ACL SETUSER and Redis 6.0+)') then
+    local username = readline_default("Enter Redis username:", nil)
+
+    if username then
+      redis_conf.username = username
+      redis_params.username = username
+    end
+
+    ask_password()
+  elseif ask_yes_no('Do you have any password set for your Redis?') then
+    ask_password()
+  end
+
+  if ask_yes_no('Do you have any specific database for your Redis?') then
+    local db = readline_default("Enter Redis database:", nil)
+
+    if db then
+      redis_conf['db'] = db
+      redis_params['db'] = db
+    end
+  end
+
+  -- Publish only when at least one setting was collected
+  if next(redis_conf) ~= nil then
+    changes.l['redis.conf'] = redis_conf
+  end
+end
+
+--- Interactively configures DKIM signing.
+-- Asks how the signing domain is chosen, where private keys are stored and
+-- which domain/selector pairs to sign (optionally generating 2048 bit RSA
+-- keys), then stores the result in `changes.l['dkim_signing.conf']`.
+-- Exits the process when the keys directory cannot be created or when no
+-- domain name can be read.
+-- @param cfg parsed configuration (not consulted by this step)
+-- @param changes pending changes; only `changes.l` is modified
+local function setup_dkim_signing(cfg, changes)
+  -- Remove the trailing slash of a pathname, if present.
+  local function remove_trailing_slash(path)
+    if string.sub(path, -1) ~= "/" then
+      return path
+    end
+    return string.sub(path, 1, string.len(path) - 1)
+  end
+
+  printf('How would you like to set up DKIM signing?')
+  printf('1. Use domain from %s for sign', highlight('mime from header'))
+  printf('2. Use domain from %s for sign', highlight('SMTP envelope from'))
+  printf('3. Use domain from %s for sign', highlight('authenticated user'))
+  printf('4. Sign all mail from %s', highlight('specific networks'))
+  printf()
+
+  local sign_type = readline_default('Enter your choice (1, 2, 3, 4) [default: 1]: ', '1')
+  local sign_networks
+  local allow_mismatch
+  local sign_authenticated
+  local use_esld
+  -- Placeholder value: every branch below overwrites it
+  local sign_domain = 'pet luacheck'
+
+  local defined_auth_types = { 'header', 'envelope', 'auth', 'recipient' }
+
+  if sign_type == '4' then
+    -- Keep asking until a non-empty network list is entered
+    repeat
+      sign_networks = readline_default('Enter list of networks to perform dkim signing: ',
+          '')
+    until #sign_networks ~= 0
+
+    sign_networks = fun.totable(fun.map(lua_util.rspamd_str_trim,
+        lua_util.str_split(sign_networks, ',; ')))
+    printf('What domain would you like to use for signing?')
+    printf('* %s to use mime from domain', highlight('header'))
+    printf('* %s to use SMTP from domain', highlight('envelope'))
+    printf('* %s to use domain from SMTP auth', highlight('auth'))
+    printf('* %s to use domain from SMTP recipient', highlight('recipient'))
+    printf('* anything else to use as a %s domain (e.g. `example.com`)', highlight('static'))
+    printf()
+
+    sign_domain = readline_default('Enter your choice [default: header]: ', 'header')
+  else
+    if sign_type == '1' then
+      sign_domain = 'header'
+    elseif sign_type == '2' then
+      sign_domain = 'envelope'
+    else
+      -- Choice 3 and any unrecognised answer
+      sign_domain = 'auth'
+    end
+  end
+
+  if sign_type ~= '3' then
+    sign_authenticated = ask_yes_no(
+        string.format('Do you want to sign mail from %s? ',
+            highlight('authenticated users')), true)
+  else
+    sign_authenticated = true
+  end
+
+  if fun.any(function(s)
+    return s == sign_domain
+  end, defined_auth_types) then
+    -- Allow mismatch
+    allow_mismatch = ask_yes_no(
+        string.format('Allow data %s, e.g. if mime from domain is not equal to authenticated user domain? ',
+            highlight('mismatch')), true)
+    -- ESLD check
+    use_esld = ask_yes_no(
+        string.format('Do you want to use %s domain (e.g. example.com instead of foo.example.com)? ',
+            highlight('effective')), true)
+  else
+    -- A static signing domain never matches message data
+    allow_mismatch = true
+  end
+
+  local domains = {}
+  local has_domains = false
+
+  local dkim_keys_dir = rspamd_paths["DBDIR"] .. "/dkim/"
+
+  local prompt = string.format("Enter output directory for the keys [default: %s]: ",
+      highlight(dkim_keys_dir))
+  dkim_keys_dir = remove_trailing_slash(readline_default(prompt, dkim_keys_dir))
+
+  local ret, err = rspamd_util.mkdir(dkim_keys_dir, true)
+
+  if not ret then
+    printf("Cannot make directory %s: %s", dkim_keys_dir, highlight(err))
+    os.exit(1)
+  end
+
+  -- Lists the domains collected so far
+  local function print_domains()
+    printf("Domains configured:")
+    for k, v in pairs(domains) do
+      -- Entries store the key location under `path`
+      printf("Domain: %s, selector: %s, privkey: %s", highlight(k),
+          v.selector, v.path)
+    end
+    printf("--")
+  end
+  -- Prints the DNS TXT payload for a raw public key
+  local function print_public_key(pk)
+    local base64_pk = tostring(rspamd_util.encode_base64(pk))
+    printf('v=DKIM1; k=rsa; p=%s\n', base64_pk)
+  end
+  repeat
+    if has_domains then
+      print_domains()
+    end
+
+    local domain
+    repeat
+      domain = rspamd_util.readline("Enter domain to sign: ")
+      if not domain then
+        os.exit(1)
+      end
+    until #domain ~= 0
+
+    local selector = readline_default("Enter selector [default: dkim]: ", 'dkim')
+    if not selector then
+      selector = 'dkim'
+    end
+
+    local privkey_file = string.format("%s/%s.%s.key", dkim_keys_dir, domain,
+        selector)
+    if not rspamd_util.file_exists(privkey_file) then
+      if ask_yes_no("Do you want to create privkey " .. highlight(privkey_file),
+          true) then
+        local rsa = require "rspamd_rsa"
+        local sk, pk = rsa.keypair(2048)
+        sk:save(privkey_file, 'pem')
+        print("You need to chown private key file to rspamd user!!")
+        print("To make dkim signing working, to place the following record in your DNS zone:")
+        print_public_key(tostring(pk))
+      end
+    end
+
+    domains[domain] = {
+      selector = selector,
+      path = privkey_file,
+    }
+    -- Enables the summary shown before the next domain is asked for
+    has_domains = true
+  until not ask_yes_no("Do you wish to add another DKIM domain?")
+
+  changes.l['dkim_signing.conf'] = { domain = domains }
+  local res_tbl = changes.l['dkim_signing.conf']
+
+  if sign_networks then
+    res_tbl.sign_networks = sign_networks
+    res_tbl.use_domain_sign_networks = sign_domain
+  else
+    res_tbl.use_domain = sign_domain
+  end
+
+  if allow_mismatch then
+    res_tbl.allow_hdrfrom_mismatch = true
+    res_tbl.allow_hdrfrom_mismatch_sign_networks = true
+    res_tbl.allow_username_mismatch = true
+  end
+
+  res_tbl.use_esld = use_esld
+  res_tbl.sign_authenticated = sign_authenticated
+end
+
+-- Checks a Redis backed classifier: compares the schema selected in the
+-- configuration (`cls.new_schema`) with the version stored in Redis and offers
+-- either a data conversion or a config switch.
+-- `cls` is one classifier section, `changes` is the table of pending changes
+-- (only `changes.l['classifier-bayes.conf']` is written).
+-- Returns nil when symbols or Redis parameters are unusable, false when the
+-- connection or the version query fails, nothing otherwise.
+local function check_redis_classifier(cls, changes)
+ local symbol_spam, symbol_ham
+ -- Load symbols from statfiles
+ local statfiles = cls.statfile
+ for _, stf in ipairs(statfiles) do
+ local symbol = stf.symbol or 'undefined'
+
+ local spam
+ -- NOTE(review): an explicit `spam = false` also takes the else branch and
+ -- is therefore decided by the symbol name
+ if stf.spam then
+ spam = stf.spam
+ else
+ if string.match(symbol:upper(), 'SPAM') then
+ spam = true
+ else
+ spam = false
+ end
+ end
+
+ if spam then
+ symbol_spam = symbol
+ else
+ symbol_ham = symbol
+ end
+ end
+
+ if not symbol_spam or not symbol_ham then
+ printf("Classifier has no symbols defined")
+ return
+ end
+
+ -- Prefer Redis settings of the classifier itself, then the file-level
+ -- redis_params
+ local parsed_redis = lua_redis.try_load_redis_servers(cls, nil)
+
+ if not parsed_redis and redis_params then
+ parsed_redis = lua_redis.try_load_redis_servers(redis_params, nil)
+ if not parsed_redis then
+ printf("Cannot parse Redis params")
+ return
+ end
+ end
+
+ -- Offers the data conversion; on success and when `update_config` is set,
+ -- records `new_schema` (and the chosen expire, if any) in the changes
+ local function try_convert(update_config)
+ if ask_yes_no("Do you wish to convert data to the new schema?", true) then
+ local expire = readline_expire()
+ if not lua_stat_tools.convert_bayes_schema(parsed_redis, symbol_spam,
+ symbol_ham, expire) then
+ printf("Conversion failed")
+ else
+ printf("Conversion succeed")
+ if update_config then
+ changes.l['classifier-bayes.conf'] = {
+ new_schema = true,
+ }
+
+ if expire then
+ changes.l['classifier-bayes.conf'].expire = expire
+ end
+ end
+ end
+ end
+ end
+
+ -- Returns `false, 0` when the `RS_keys` set cannot be read or is empty,
+ -- otherwise `true` and the numeric value stored in `RS_version` (0 when the
+ -- key is missing or not a number)
+ local function get_version(conn)
+ conn:add_cmd("SMEMBERS", { "RS_keys" })
+
+ local ret, members = conn:exec()
+
+ -- Empty db
+ if not ret or #members == 0 then
+ return false, 0
+ end
+
+ -- We still need to check versions
+ local lua_script = [[
+local ver = 0
+
+local tst = redis.call('GET', KEYS[1]..'_version')
+if tst then
+ ver = tonumber(tst) or 0
+end
+
+return ver
+]]
+ conn:add_cmd('EVAL', { lua_script, '1', 'RS' })
+ local _, ver = conn:exec()
+
+ return true, tonumber(ver)
+ end
+
+ -- Returns the TTL of one key found by a single SCAN step matching `RS*_*`,
+ -- or 0 when that step returns no key
+ local function check_expire(conn)
+ -- Server side script: SCAN with COUNT 1, then TTL of the first key returned
+ local lua_script = [[
+local ttl = 0
+
+local sc = redis.call('SCAN', 0, 'MATCH', 'RS*_*', 'COUNT', 1)
+local _,key = sc[1], sc[2]
+
+if key and key[1] then
+ ttl = redis.call('TTL', key[1])
+end
+
+return ttl
+]]
+ conn:add_cmd('EVAL', { lua_script, '0' })
+ local _, ttl = conn:exec()
+
+ return tonumber(ttl)
+ end
+
+ local res, conn = lua_redis.redis_connect_sync(parsed_redis, true)
+ if not res then
+ printf("Cannot connect to Redis server")
+ return false
+ end
+
+ if not cls.new_schema then
+ -- Config uses the old schema
+ local r, ver = get_version(conn)
+ if not r then
+ return false
+ end
+ if ver ~= 2 then
+ if not ver then
+ printf('Key "RS_version" has not been found in Redis for %s/%s',
+ symbol_ham, symbol_spam)
+ else
+ printf("You are using an old schema version: %s for %s/%s",
+ ver, symbol_ham, symbol_spam)
+ end
+ try_convert(true)
+ else
+ printf("You have configured an old schema for %s/%s but your data has new layout",
+ symbol_ham, symbol_spam)
+
+ if ask_yes_no("Switch config to the new schema?", true) then
+ changes.l['classifier-bayes.conf'] = {
+ new_schema = true,
+ }
+
+ -- Carry over the TTL observed on existing keys as the expire setting
+ local expire = check_expire(conn)
+ if expire then
+ changes.l['classifier-bayes.conf'].expire = expire
+ end
+ end
+ end
+ else
+ -- Config already requests the new schema: only the data may need converting
+ local r, ver = get_version(conn)
+ if not r then
+ return false
+ end
+ if ver ~= 2 then
+ printf("You have configured new schema for %s/%s but your DB has old version: %s",
+ symbol_spam, symbol_ham, ver)
+ try_convert(false)
+ else
+ printf(
+ 'You have configured new schema for %s/%s and your DB already has new layout (v. %s).' ..
+ ' DB conversion is not needed.',
+ symbol_spam, symbol_ham, ver)
+ end
+ end
+end
+
+-- Statistics step of the wizard.
+-- When sqlite classifiers are configured it offers to convert them to Redis
+-- and records the new backend in `changes.l['classifier-bayes.conf']`;
+-- otherwise it runs check_redis_classifier on every Redis backed classifier
+-- found in `cfg.classifier`.
+-- Returns false when Redis is not configured or a conversion fails, nothing
+-- otherwise.
+local function setup_statistic(cfg, changes)
+ local sqlite_configs = lua_stat_tools.load_sqlite_config(cfg)
+
+ if #sqlite_configs > 0 then
+
+ if not redis_params then
+ printf('You have %d sqlite classifiers, but you have no Redis servers being set',
+ #sqlite_configs)
+ return false
+ end
+
+ local parsed_redis = lua_redis.try_load_redis_servers(redis_params, nil)
+ if parsed_redis then
+ printf('You have %d sqlite classifiers', #sqlite_configs)
+ -- Expire and reset are asked before the conversion itself is confirmed
+ local expire = readline_expire()
+
+ local reset_previous = ask_yes_no("Reset previous data?")
+ if ask_yes_no('Do you wish to convert them to Redis?', true) then
+
+ for _, cls in ipairs(sqlite_configs) do
+ if rspamd_util.file_exists(cls.db_spam) and rspamd_util.file_exists(cls.db_ham) then
+ if not lua_stat_tools.convert_sqlite_to_redis(parsed_redis, cls.db_spam,
+ cls.db_ham, cls.symbol_spam, cls.symbol_ham, cls.learn_cache, expire,
+ reset_previous) then
+ rspamd_logger.errx('conversion failed')
+
+ return false
+ end
+ else
+ rspamd_logger.messagex('cannot find %s and %s, skip conversion',
+ cls.db_spam, cls.db_ham)
+ end
+
+ -- NOTE(review): this message and the config change below are also
+ -- reached when the conversion above was skipped
+ rspamd_logger.messagex('Converted classifier to the from sqlite to redis')
+ changes.l['classifier-bayes.conf'] = {
+ backend = 'redis',
+ new_schema = true,
+ }
+
+ if expire then
+ changes.l['classifier-bayes.conf'].expire = expire
+ end
+
+ if cls.learn_cache then
+ changes.l['classifier-bayes.conf'].cache = {
+ backend = 'redis'
+ }
+ end
+ end
+ end
+ end
+ else
+ -- Check sanity for the existing Redis classifiers
+ local classifier = cfg.classifier
+
+ if classifier then
+ -- `classifier` may be an array of sections (each optionally wrapped in a
+ -- `bayes` key) or a table holding `bayes` as one section or an array
+ if classifier[1] then
+ for _, cls in ipairs(classifier) do
+ if cls.bayes then
+ cls = cls.bayes
+ end
+ if cls.backend and cls.backend == 'redis' then
+ check_redis_classifier(cls, changes)
+ end
+ end
+ else
+ if classifier.bayes then
+
+ classifier = classifier.bayes
+ if classifier[1] then
+ for _, cls in ipairs(classifier) do
+ if cls.backend and cls.backend == 'redis' then
+ check_redis_classifier(cls, changes)
+ end
+ end
+ else
+ if classifier.backend and classifier.backend == 'redis' then
+ check_redis_classifier(classifier, changes)
+ end
+ end
+ end
+ end
+ end
+ end
+end
+
+--- Looks up a worker section of the given type in `cfg.worker`.
+-- Three layouts are recognised for each entry: an array element holding the
+-- section under the type name, a section carrying a matching `type` field,
+-- and a section keyed directly by the type name.
+-- @param cfg configuration table
+-- @param wtype worker type name, e.g. 'controller'
+-- @return the matching section or nil
+local function find_worker(cfg, wtype)
+  local workers = cfg.worker
+  if not workers then
+    return nil
+  end
+
+  for key, section in pairs(workers) do
+    local is_section = type(section) == 'table'
+
+    -- worker = { { controller = {...} }, ... }
+    if is_section and type(key) == 'number' and section[wtype] then
+      return section[wtype]
+    end
+
+    -- worker = { { type = 'controller', ... }, ... }
+    if is_section and section.type and section.type == wtype then
+      return section
+    end
+
+    -- worker = { controller = {...} }
+    if type(key) == 'string' and key == wtype then
+      return section
+    end
+  end
+
+  return nil
+end
+
+-- Command descriptor of `configwizard`. The handler loads the configuration,
+-- runs the selected setup steps and then shows and optionally applies the
+-- changes they accumulated.
+return {
+ handler = function(cmd_args)
+ local changes = {
+ l = {}, -- local changes
+ o = {}, -- override changes
+ }
+
+ local interactive_start = true
+ local checks = {}
+ local all_checks = {
+ 'controller',
+ 'redis',
+ 'dkim',
+ 'statistic',
+ }
+
+ local opts = parser:parse(cmd_args)
+ local args = opts['checks'] or {}
+
+ local _r, err = rspamd_config:load_ucl(opts['config'])
+
+ if not _r then
+ rspamd_logger.errx('cannot parse %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
+ _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+ if not _r then
+ rspamd_logger.errx('cannot process %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
+ local cfg = rspamd_config:get_ucl()
+
+ if not rspamd_config:init_modules() then
+ rspamd_logger.errx('cannot init modules when parsing %s', opts['config'])
+ os.exit(1)
+ end
+
+ -- Explicit check names on the command line switch off the interactive intro
+ if #args > 0 then
+ interactive_start = false
+
+ for _, arg in ipairs(args) do
+ if arg == 'all' then
+ -- NOTE(review): `checks` aliases `all_checks` from here on, so a later
+ -- table.insert(checks, ...) also appends to `all_checks`
+ checks = all_checks
+ elseif arg == 'list' then
+ -- Only print the known check names and leave without running any step
+ printf(highlight(rspamd_logo))
+ printf('Available modules')
+ for _, c in ipairs(all_checks) do
+ printf('- %s', c)
+ end
+ return
+ else
+ table.insert(checks, arg)
+ end
+ end
+ else
+ checks = all_checks
+ end
+
+ -- True when `check` was selected
+ local function has_check(check)
+ for _, c in ipairs(checks) do
+ if c == check then
+ return true
+ end
+ end
+
+ return false
+ end
+
+ rspamd_util.umask('022')
+ if interactive_start then
+ printf(highlight(rspamd_logo))
+ printf("Welcome to the configuration tool")
+ printf("We use %s configuration file, writing results to %s",
+ highlight(opts['config']), highlight(local_conf))
+ plugins_stat(nil, nil)
+ end
+
+ if not interactive_start or
+ ask_yes_no("Do you wish to continue?", true) then
+
+ if has_check('controller') then
+ local controller = find_worker(cfg, 'controller')
+ if controller then
+ setup_controller(controller, changes)
+ end
+ end
+
+ -- redis_params is always taken from the config unless the redis step
+ -- runs and finds no servers configured, in which case setup_redis is called
+ if has_check('redis') then
+ if not cfg.redis or (not cfg.redis.servers and not cfg.redis.read_servers) then
+ setup_redis(cfg, changes)
+ else
+ redis_params = cfg.redis
+ end
+ else
+ redis_params = cfg.redis
+ end
+
+ if has_check('dkim') then
+ -- Offered only when a dkim_signing section exists without any `domain`
+ if cfg.dkim_signing and not cfg.dkim_signing.domain then
+ if ask_yes_no('Do you want to setup dkim signing feature?') then
+ setup_dkim_signing(cfg, changes)
+ end
+ end
+ end
+
+ if has_check('statistic') or has_check('statistics') then
+ setup_statistic(cfg, changes)
+ end
+
+ -- Count changed files (top-level keys), not individual options
+ local nchanges = 0
+ for _, _ in pairs(changes.l) do
+ nchanges = nchanges + 1
+ end
+ for _, _ in pairs(changes.o) do
+ nchanges = nchanges + 1
+ end
+
+ if nchanges > 0 then
+ print_changes(changes)
+ if ask_yes_no("Apply changes?", true) then
+ apply_changes(changes)
+ printf("%d changes applied, the wizard is finished now", nchanges)
+ printf("*** Please reload the Rspamd configuration ***")
+ else
+ printf("No changes applied, the wizard is finished now")
+ end
+ else
+ printf("No changes found, the wizard is finished now")
+ end
+ end
+ end,
+ name = 'configwizard',
+ description = parser._description,
+}
diff --git a/lualib/rspamadm/cookie.lua b/lualib/rspamadm/cookie.lua
new file mode 100644
index 0000000..7e0526a
--- /dev/null
+++ b/lualib/rspamadm/cookie.lua
@@ -0,0 +1,125 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+
+
+-- Define command line options
+local parser = argparse()
+ :name "rspamadm cookie"
+ :description "Produces cookies or message ids"
+ :help_description_margin(30)
+
+-- --key and --new-key are declared mutually exclusive
+parser:mutex(
+ parser:option "-k --key"
+ :description('Key to load')
+ :argname "<32hex>",
+ parser:flag "-K --new-key"
+ :description('Generates a new key')
+)
+
+parser:option "-d --domain"
+ :description('Use specified domain and generate full message id')
+ :argname "<domain>"
+parser:flag "-D --decrypt"
+ :description('Decrypt cookie instead of encrypting one')
+parser:flag "-t --timestamp"
+ :description('Show cookie timestamp (valid for decrypting only)')
+-- Optional positional argument: the cookie to encrypt or decrypt
+parser:argument "cookie":args "?"
+ :description('Use specified cookie')
+
+--- Encrypts or decrypts `args.cookie` with `key` and prints the result.
+-- Does nothing when no cookie was given. Exits with status 1 when the cookie
+-- is longer than 31 characters (encryption) or cannot be decrypted.
+-- @param args parsed command line options
+-- @param key key passed to the cryptobox cookie functions
+local function gen_cookie(args, key)
+  local cryptobox = require "rspamd_cryptobox"
+  local cookie = args.cookie
+
+  if not cookie then
+    return
+  end
+
+  if args.decrypt then
+    -- Take the part before `@` (skipping a leading `<`) when the input looks
+    -- like a message id, otherwise use the input as is
+    local payload = cookie:match('^%<?([^@]+)@.*$') or cookie
+    local plain, ts = cryptobox.decrypt_cookie(key, payload)
+
+    if not plain then
+      print('cannot decrypt cookie')
+      os.exit(1)
+    end
+
+    if args.timestamp then
+      print(string.format('%s %s', plain, ts))
+    else
+      print(plain)
+    end
+
+    return
+  end
+
+  if #cookie > 31 then
+    print('cookie too long (>31 characters), cannot encrypt')
+    os.exit(1)
+  end
+
+  local encrypted = cryptobox.encrypt_cookie(key, cookie)
+  if args.domain then
+    -- Wrap the cookie into a full message id
+    print(string.format('<%s@%s>', encrypted, args.domain))
+  else
+    print(encrypted)
+  end
+end
+
+--- Entry point of `rspamadm cookie`.
+-- Uses the key given with --key (validated as 32 alphanumeric characters) or
+-- generates and prints a new one for --new-key, then processes the cookie
+-- with that key.
+-- @param args raw command line arguments
+local function handler(args)
+  local res = parser:parse(args)
+
+  if not (res.key or res['new_key']) then
+    parser:error('--key or --new-key must be specified')
+  end
+
+  local key = res.key
+
+  if key then
+    local pattern = '^' .. string.rep('[a-zA-Z0-9]', 32) .. '$'
+
+    if not key:match(pattern) then
+      parser:error('invalid key: ' .. key)
+    end
+  else
+    local util = require "rspamd_util"
+    key = util.random_hex(32)
+
+    print(key)
+  end
+
+  -- Pass the effective key: `res.key` is nil when the key was just generated
+  gen_cookie(res, key)
+end
+
+-- Command descriptor returned by this file: entry point, help text, name
+return {
+ handler = handler,
+ description = parser._description,
+ name = 'cookie'
+} \ No newline at end of file
diff --git a/lualib/rspamadm/corpus_test.lua b/lualib/rspamadm/corpus_test.lua
new file mode 100644
index 0000000..0e63f9f
--- /dev/null
+++ b/lualib/rspamadm/corpus_test.lua
@@ -0,0 +1,185 @@
+local rspamd_logger = require "rspamd_logger"
+local ucl = require "ucl"
+local lua_util = require "lua_util"
+local argparse = require "argparse"
+
+-- Command line options of `rspamadm corpus_test`
+local parser = argparse()
+ :name "rspamadm corpus_test"
+ :description "Create logs files from email corpus"
+ :help_description_margin(32)
+
+parser:option "-H --ham"
+ :description("Ham directory")
+ :argname("<dir>")
+parser:option "-S --spam"
+ :description("Spam directory")
+ :argname("<dir>")
+parser:option "-n --conns"
+ :description("Number of parallel connections")
+ :argname("<N>")
+ :convert(tonumber)
+ :default(10)
+parser:option "-o --output"
+ :description("Output file")
+ :argname("<file>")
+ :default('results.log')
+parser:option "-t --timeout"
+ :description("Timeout for client connections")
+ :argname("<sec>")
+ :convert(tonumber)
+ :default(60)
+parser:option "-c --connect"
+ :description("Connect to specific host")
+ :argname("<host>")
+ :default('localhost:11334')
+parser:option "-r --rspamc"
+ :description("Use specific rspamc path")
+ :argname("<path>")
+ :default('rspamc')
+
+-- Labels written as the first field of every log line
+local HAM = "HAM"
+local SPAM = "SPAM"
+-- Parsed options; assigned in handler and read by scan_email
+local opts
+
+--- Runs rspamc over a corpus path and returns everything it printed.
+-- The command is built from the file-level `opts` (rspamc path and host).
+-- NOTE(review): the values are interpolated into a shell command without
+-- quoting, so they must come from a trusted command line.
+-- @param n_parallel number of parallel connections (rspamc `-n`)
+-- @param path file or directory to scan
+-- @param timeout timeout in seconds (rspamc `-t`)
+-- @return rspamc standard output as a single string
+local function scan_email(n_parallel, path, timeout)
+  local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s",
+      opts.rspamc, opts.connect, n_parallel, timeout, path)
+  local pipe = assert(io.popen(rspamc_command))
+  local output = pipe:read("*all")
+  -- Release the pipe: one is opened per scanned corpus
+  pipe:close()
+
+  return output
+end
+
+--- Writes scan results to a log file, one CRLF terminated line per result.
+-- Line layout: `<type> <score> <action> <symbols...> <scan_time> <file>:<filename>`.
+-- Raises an error carrying the system message when the file cannot be opened.
+-- @param results array of records with type, score, action, symbols,
+-- scan_time and filename fields
+-- @param file path of the log file (overwritten)
+local function write_results(results, file)
+  local f = assert(io.open(file, 'w'))
+
+  for _, result in pairs(results) do
+    -- Collect the fields and join them once instead of growing a string
+    local fields = {
+      string.format("%s %.2f %s", result.type, result.score, result.action)
+    }
+
+    for _, sym in pairs(result.symbols) do
+      fields[#fields + 1] = sym
+    end
+
+    fields[#fields + 1] = result.scan_time .. " " .. file .. ':' .. result.filename
+
+    f:write(table.concat(fields, " "), "\r\n")
+  end
+
+  f:close()
+end
+
+--- Converts one JSON scan result into a log record.
+-- @param result JSON text of a single scan result
+-- @return table with score, action (whitespace runs replaced by `_`),
+-- symbols (array of symbol names), filename and scan_time; nil when the text
+-- cannot be parsed or carries no action
+local function encoded_json_to_log(result)
+  local ucl_parser = ucl.parser()
+
+  local is_good, err = ucl_parser:parse_string(result)
+
+  if not is_good then
+    rspamd_logger.errx("Parser error: %1", err)
+    return nil
+  end
+
+  result = ucl_parser:get_object()
+
+  if not result.action then
+    rspamd_logger.errx("Bad JSON: %1", result)
+    return nil
+  end
+
+  local symbols = {}
+  -- Tolerate a reply without a `symbols` object instead of failing in pairs()
+  for sym, _ in pairs(result.symbols or {}) do
+    symbols[#symbols + 1] = sym
+  end
+
+  return {
+    score = result.score,
+    -- The extra parentheses drop the replacement count returned by gsub
+    action = (result.action:gsub("%s+", "_")),
+    symbols = symbols,
+    filename = result.filename,
+    scan_time = result.scan_time,
+  }
+end
+
+--- Turns raw scanner output into an array of log records.
+-- The output is split on newlines; every line that converts successfully is
+-- tagged with `actual_email_type` and collected.
+-- @param results text with one JSON document per line
+-- @param actual_email_type label stored in the `type` field of each record
+-- @return array of log records
+local function scan_results_to_logs(results, actual_email_type)
+  local lines = lua_util.rspamd_str_split(results, "\n")
+
+  -- Drop the empty element left by a trailing newline
+  local last = #lines
+  if lines[last] == "" then
+    lines[last] = nil
+  end
+
+  local logs = {}
+  for _, line in pairs(lines) do
+    local record = encoded_json_to_log(line)
+    if record then
+      record.type = actual_email_type
+      logs[#logs + 1] = record
+    end
+  end
+
+  return logs
+end
+
+--- Entry point of `rspamadm corpus_test`.
+-- Scans the ham and/or spam directories, writes all results to the output
+-- file and logs simple statistics.
+-- @param args raw command line arguments
+local function handler(args)
+  opts = parser:parse(args)
+  local connections = opts["conns"]
+  local output = opts["output"]
+
+  local results = {}
+
+  local start_time = os.time()
+
+  -- Scans one corpus directory, appends its records to `results` and returns
+  -- how many were added; a missing directory counts as zero
+  local function scan_corpus(directory, email_type, label)
+    if not directory then
+      return 0
+    end
+
+    rspamd_logger.messagex(string.format("Scanning %s corpus...", label))
+    local logs = scan_results_to_logs(
+        scan_email(connections, directory, opts["timeout"]), email_type)
+
+    for _, entry in ipairs(logs) do
+      results[#results + 1] = entry
+    end
+
+    return #logs
+  end
+
+  local no_of_ham = scan_corpus(opts['ham'], HAM, 'ham')
+  local no_of_spam = scan_corpus(opts['spam'], SPAM, 'spam')
+
+  rspamd_logger.messagex("Writing results to %s", output)
+  write_results(results, output)
+
+  rspamd_logger.messagex("Stats: ")
+  local elapsed_time = os.time() - start_time
+  local total_msgs = no_of_ham + no_of_spam
+  rspamd_logger.messagex("Elapsed time: %ss", elapsed_time)
+  rspamd_logger.messagex("No of ham: %s", no_of_ham)
+  rspamd_logger.messagex("No of spam: %s", no_of_spam)
+  -- os.time() has one second resolution: treat shorter runs as one second
+  -- rather than dividing by zero
+  rspamd_logger.messagex("Messages/sec: %s", (total_msgs / math.max(elapsed_time, 1)))
+end
+
+-- Command descriptor returned by this file: name, alternative names,
+-- entry point and help text
+return {
+ name = 'corpustest',
+ aliases = { 'corpus_test', 'corpus' },
+ handler = handler,
+ description = parser._description
+}
diff --git a/lualib/rspamadm/dkim_keygen.lua b/lualib/rspamadm/dkim_keygen.lua
new file mode 100644
index 0000000..211094d
--- /dev/null
+++ b/lualib/rspamadm/dkim_keygen.lua
@@ -0,0 +1,178 @@
+--[[
+Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+local rspamd_util = require "rspamd_util"
+local rspamd_cryptobox = require "rspamd_cryptobox"
+
+-- Command line options of `rspamadm dkim_keygen`
+local parser = argparse()
+    :name 'rspamadm dkim_keygen'
+    :description 'Create key pairs for dkim signing'
+    :help_description_margin(30)
+parser:option '-d --domain'
+      :description 'Create a key for a specific domain'
+      :default "example.com"
+parser:option '-s --selector'
+      :description 'Create a key for a specific DKIM selector'
+      :default "mail"
+parser:option '-k --privkey'
+      :description 'Save private key to file instead of printing it to stdout'
+parser:option '-b --bits'
+      -- Reject anything that is not a number within 512..4096
+      :convert(function(input)
+        local n = tonumber(input)
+        if not n or n < 512 or n > 4096 then
+          return nil
+        end
+        return n
+      end)
+      :description 'Generate an RSA key with the specified number of bits (512 to 4096)'
+      :default(1024)
+parser:option '-t --type'
+      -- The help text names the same values the table below accepts
+      :description 'Key type: RSA, ED25519 or ED25519-seed'
+      :convert {
+        ['rsa'] = 'rsa',
+        ['RSA'] = 'rsa',
+        ['ed25519'] = 'ed25519',
+        ['ED25519'] = 'ed25519',
+        ['ed25519-seed'] = 'ed25519-seed',
+        ['ED25519-seed'] = 'ed25519-seed',
+      }
+      :default 'rsa'
+parser:option '-o --output'
+      :description 'Output public key in the following format: dns, dnskey or plain'
+      :convert {
+        ['dns'] = 'dns',
+        ['plain'] = 'plain',
+        ['dnskey'] = 'dnskey',
+      }
+      :default 'dns'
+parser:option '--priv-output'
+      :description 'Output private key in the following format: PEM or DER (for RSA)'
+      :convert {
+        ['pem'] = 'pem',
+        ['der'] = 'der',
+      }
+      :default 'pem'
+parser:flag '-f --force'
+      :description 'Force overwrite of existing files'
+
+--- Cuts a string into consecutive pieces of at most `max_length` bytes.
+-- @param input string to split
+-- @param max_length maximum piece length, 253 when omitted
+-- @return array of pieces in order; empty for an empty input
+local function split_string(input, max_length)
+  local chunk_len = max_length or 253
+  local chunks = {}
+  local pos, total = 1, #input
+
+  while pos <= total do
+    chunks[#chunks + 1] = input:sub(pos, pos + chunk_len - 1)
+    pos = pos + chunk_len
+  end
+
+  return chunks
+end
+
+--- Prints the public key as a DNS zone TXT record for the selector.
+-- Keys shorter than 253 characters go into a single quoted string, longer
+-- ones are split into several quoted strings.
+-- @param opts parsed options (`selector` and `type` are used)
+-- @param base64_pk base64 text of the public key
+local function print_public_key_dns(opts, base64_pk)
+  local key_type = 'ed25519'
+  if opts.type == 'rsa' then
+    key_type = 'rsa'
+  end
+
+  if #base64_pk < 255 - 2 then
+    io.write(string.format('%s._domainkey IN TXT ( "v=DKIM1; k=%s;" \n\t"p=%s" ) ;\n',
+        opts.selector, key_type, base64_pk))
+    return
+  end
+
+  -- Split it by parts
+  local parts = split_string(base64_pk)
+  io.write(string.format('%s._domainkey IN TXT ( "v=DKIM1; k=%s; "\n', opts.selector, key_type))
+
+  for i = 1, #parts do
+    -- Only the first quoted string carries the `p=` tag
+    local line_fmt = '\t"%s"\n'
+    if i == 1 then
+      line_fmt = '\t"p=%s"\n'
+    end
+    io.write(string.format(line_fmt, parts[i]))
+  end
+
+  io.write(") ; \n")
+end
+
+--- Prints the public key in the format selected by `opts.output`.
+-- Prints nothing for any value other than `plain`, `dns` or `dnskey`.
+-- @param opts parsed options (`output` and `type` are used)
+-- @param pk public key
+-- @param need_base64 true when `pk` still has to be base64 encoded
+local function print_public_key(opts, pk, need_base64)
+  local encoded
+  if need_base64 then
+    encoded = tostring(rspamd_util.encode_base64(pk))
+  else
+    encoded = tostring(pk)
+  end
+
+  local output = opts.output
+
+  if output == 'plain' then
+    io.write(encoded)
+    io.write("\n")
+    return
+  end
+
+  if output == 'dns' then
+    print_public_key_dns(opts, encoded, false)
+    return
+  end
+
+  if output == 'dnskey' then
+    local key_type = 'ed25519'
+    if opts.type == 'rsa' then
+      key_type = 'rsa'
+    end
+    io.write(string.format('v=DKIM1; k=%s; p=%s\n', key_type, encoded))
+  end
+end
+
+--- Generates an RSA key pair, saves the private key and prints the public one.
+-- The private key goes to `opts.privkey` (removed first when --force is
+-- given) or to "-" when no file was requested.
+-- @param opts parsed options (`bits`, `privkey`, `force`, `priv_output`)
+local function gen_rsa_key(opts)
+  local rsa = require "rspamd_rsa"
+
+  local sk, pk = rsa.keypair(opts.bits or 1024)
+  local target = opts.privkey
+
+  if not target then
+    target = "-"
+  elseif opts.force then
+    os.remove(target)
+  end
+
+  sk:save(target, opts.priv_output)
+
+  -- We generate key directly via lua_rsa and it returns unencoded raw data
+  print_public_key(opts, tostring(pk), true)
+end
+
+--- Generates a key pair of `opts.type`, saves the private key and prints the
+-- public one. Exits with status 1 when the private key cannot be saved.
+-- @param opts parsed options (`type`, `privkey`, `force`)
+local function gen_eddsa_key(opts)
+  local sk, pk = rspamd_cryptobox.gen_dkim_keypair(opts.type)
+  local target = opts.privkey
+
+  if target and opts.force then
+    os.remove(target)
+  end
+
+  -- Octal 0600 is passed as the second argument of save_in_file
+  local saved = sk:save_in_file(target, tonumber('0600', 8))
+  if not saved then
+    io.stderr:write('cannot save private key to ' .. (target or 'stdout') .. '\n')
+    os.exit(1)
+  end
+
+  if not target then
+    -- No file was requested: put a newline before the public key output
+    io.write("\n")
+    io.flush()
+  end
+
+  -- gen_dkim_keypair function returns everything encoded in base64, so no need to do anything
+  print_public_key(opts, tostring(pk), false)
+end
+
+--- Entry point of `rspamadm dkim_keygen`: parses the options and dispatches
+-- to the RSA generator for type `rsa` and to the EdDSA generator otherwise.
+-- @param args raw command line arguments
+local function handler(args)
+  local opts = parser:parse(args)
+
+  if not opts then
+    os.exit(1)
+  end
+
+  local generate = gen_eddsa_key
+  if opts.type == 'rsa' then
+    generate = gen_rsa_key
+  end
+
+  generate(opts)
+end
+
+-- Command descriptor returned by this file: name, alternative name,
+-- entry point and help text
+return {
+ name = 'dkim_keygen',
+ aliases = { 'dkimkeygen' },
+ handler = handler,
+ description = parser._description
+}
+
+
diff --git a/lualib/rspamadm/dmarc_report.lua b/lualib/rspamadm/dmarc_report.lua
new file mode 100644
index 0000000..42c801e
--- /dev/null
+++ b/lualib/rspamadm/dmarc_report.lua
@@ -0,0 +1,737 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+local lua_util = require "lua_util"
+local logger = require "rspamd_logger"
+local lua_redis = require "lua_redis"
+local dmarc_common = require "plugins/dmarc"
+local lupa = require "lupa"
+local rspamd_mempool = require "rspamd_mempool"
+local rspamd_url = require "rspamd_url"
+local rspamd_text = require "rspamd_text"
+local rspamd_util = require "rspamd_util"
+local rspamd_dns = require "rspamd_dns"
+
+local N = 'dmarc_report'
+
+-- Define command line options
+local parser = argparse()
+ :name "rspamadm dmarc_report"
+ :description "Dmarc reports sending tool"
+ :help_description_margin(30)
+
+parser:option "-c --config"
+ :description "Path to config file"
+ :argname("<cfg>")
+ :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+
+parser:flag "-v --verbose"
+ :description "Enable dmarc specific logging"
+
+parser:flag "-n --no-opt"
+ :description "Do not reset reporting data/send reports"
+
+parser:argument "date"
+ :description "Date to process (today by default)"
+ :argname "<YYYYMMDD>"
+ :args "*"
+parser:option "-b --batch-size"
+ :description "Send reports in batches up to <batch-size> messages"
+ :argname "<number>"
+ :convert(tonumber)
+ :default "10"
+
+local report_template = [[From: "{= from_name =}" <{= from_addr =}>
+To: {= rcpt =}
+{%+ if is_string(bcc) %}Bcc: {= bcc =}{%- endif %}
+Subject: Report Domain: {= reporting_domain =}
+ Submitter: {= submitter =}
+ Report-ID: {= report_id =}
+Date: {= report_date =}
+MIME-Version: 1.0
+Message-ID: <{= message_id =}>
+Content-Type: multipart/mixed;
+ boundary="----=_NextPart_{= uuid =}"
+
+This is a multipart message in MIME format.
+
+------=_NextPart_{= uuid =}
+Content-Type: text/plain; charset="us-ascii"
+Content-Transfer-Encoding: 7bit
+
+This is an aggregate report from {= submitter =}.
+
+Report domain: {= reporting_domain =}
+Submitter: {= submitter =}
+Report ID: {= report_id =}
+
+------=_NextPart_{= uuid =}
+Content-Type: application/gzip
+Content-Transfer-Encoding: base64
+Content-Disposition: attachment;
+ filename="{= submitter =}!{= reporting_domain =}!{= report_start =}!{= report_end =}.xml.gz"
+
+]]
+local report_footer = [[
+
+------=_NextPart_{= uuid =}--]]
+
+local dmarc_settings = {}
+local redis_params
+local redis_attrs = {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ log_obj = rspamd_config,
+ resolver = rspamadm_dns_resolver,
+}
+local pool
+
+--- Loads and processes the configuration file named by `opts.config`.
+-- Logs the error and exits with status 1 if either step fails.
+-- @param opts parsed command line options
+local function load_config(opts)
+  local cfg_path = opts['config']
+
+  local ok, err = rspamd_config:load_ucl(cfg_path)
+  if not ok then
+    logger.errx('cannot parse %s: %s', cfg_path, err)
+    os.exit(1)
+  end
+
+  ok, err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+  if not ok then
+    logger.errx('cannot process %s: %s', cfg_path, err)
+    os.exit(1)
+  end
+end
+
+-- Joins all arguments into a single Redis key using redis_keys.join_char
+local function redis_prefix(...)
+  local separator = dmarc_settings.reporting.redis_keys.join_char
+  local elements = { ... }
+  return table.concat(elements, separator)
+end
+
+-- Returns the third join_char separated element of a report key (the rua),
+-- or nil if the key has fewer than three elements
+local function get_rua(rep_key)
+  local separator = dmarc_settings.reporting.redis_keys.join_char
+  local elements = lua_util.str_split(rep_key, separator)
+
+  if #elements < 3 then
+    return nil
+  end
+
+  return elements[3]
+end
+
+-- Returns the second join_char separated element of a report key (the domain),
+-- or nil if the key has fewer than three elements
+local function get_domain(rep_key)
+  local separator = dmarc_settings.reporting.redis_keys.join_char
+  local elements = lua_util.str_split(rep_key, separator)
+
+  if #elements < 3 then
+    return nil
+  end
+
+  return elements[2]
+end
+
+--- Generates a random identifier shaped like a version 4 UUID.
+-- Every 'x' in the template becomes a random hex digit and 'y' becomes one of
+-- 8, 9, a or b.
+-- @return the identifier as a single string
+local function gen_uuid()
+  local template = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'
+  -- string.gsub returns (result, substitution count); the extra parentheses
+  -- keep the count from leaking to callers as a second return value
+  return (string.gsub(template, '[xy]', function(c)
+    local v = (c == 'x') and math.random(0, 0xf) or math.random(8, 0xb)
+    return string.format('%x', v)
+  end))
+end
+
+--- Builds an LPeg substitution grammar that replaces the five XML special
+-- characters (< > & " ') with their entities and copies everything else.
+-- @return LPeg pattern; matching a string yields the escaped string
+local function gen_xml_grammar()
+  local lpeg = require 'lpeg'
+  local P, Cs = lpeg.P, lpeg.Cs
+
+  local entity = (P('<') / '&lt;')
+      + (P('>') / '&gt;')
+      + (P('&') / '&amp;')
+      + (P('"') / '&quot;')
+      + (P("'") / '&apos;')
+
+  return Cs((entity + 1) ^ 0)
+end
+
+local xml_grammar = gen_xml_grammar()
+
+--- Escapes XML special characters in a value.
+-- Strings and userdata are passed to the grammar as is, any other non-nil
+-- value is converted with tostring() first.
+-- @param input value to escape
+-- @return escaped string; an empty string for nil
+local function escape_xml(input)
+  -- tostring(nil) is the truthy string "nil", so the empty-string fallback has
+  -- to be decided before any conversion takes place
+  if input == nil then
+    return ''
+  end
+
+  local input_type = type(input)
+  if input_type ~= 'string' and input_type ~= 'userdata' then
+    input = tostring(input)
+  end
+
+  return xml_grammar:match(input)
+end
+-- Enable xml escaping in lupa templates
+lupa.filters.escape_xml = escape_xml
+
+-- Creates report XML header: the XML declaration, the opening <feedback> tag
+-- and the <report_metadata> and <policy_published> sections.
+-- The closing </feedback> tag is not part of this fragment.
+-- @param reporting_domain domain the report is generated for
+-- @param report_start start of the reporting window (formatted with %d)
+-- @param report_end end of the reporting window (formatted with %d)
+-- @param domain_policy table providing the adkim, aspf, p, sp and pct fields
+-- @return result of rendering the template with lua_util.jinja_template
+local function report_header(reporting_domain, report_start, report_end, domain_policy)
+ -- Report id has the form <domain>.<start>.<end>
+ local report_id = string.format('%s.%d.%d',
+ reporting_domain, report_start, report_end)
+ local xml_template = [[
+<?xml version="1.0" encoding="UTF-8" ?>
+<feedback>
+ <report_metadata>
+ <org_name>{= report_settings.org_name | escape_xml =}</org_name>
+ <email>{= report_settings.email | escape_xml =}</email>
+ <report_id>{= report_id =}</report_id>
+ <date_range>
+ <begin>{= report_start =}</begin>
+ <end>{= report_end =}</end>
+ </date_range>
+ </report_metadata>
+ <policy_published>
+ <domain>{= reporting_domain | escape_xml =}</domain>
+ <adkim>{= domain_policy.adkim | escape_xml =}</adkim>
+ <aspf>{= domain_policy.aspf | escape_xml =}</aspf>
+ <p>{= domain_policy.p | escape_xml =}</p>
+ <sp>{= domain_policy.sp | escape_xml =}</sp>
+ <pct>{= domain_policy.pct | escape_xml =}</pct>
+ </policy_published>
+]]
+ -- org_name and email are taken from the dmarc reporting settings
+ return lua_util.jinja_template(xml_template, {
+ report_settings = dmarc_settings.reporting,
+ report_id = report_id,
+ report_start = report_start,
+ report_end = report_end,
+ domain_policy = domain_policy,
+ reporting_domain = reporting_domain,
+ }, true)
+end
+
+-- Generate xml entry for a preprocessed redis row
+-- @param data row table as produced by process_report_entry
+-- @return rendered <record> element
+local function entry_to_xml(data)
+  -- The row values are pasted into XML, so every string is escaped first, the
+  -- same way the report header escapes its fields. Non-string values (such as
+  -- the numeric count) and missing values are passed through untouched.
+  local function escape_field(value)
+    if type(value) == 'string' then
+      return escape_xml(value)
+    end
+    return value
+  end
+
+  -- Work on a copy so that the caller's table is left unmodified
+  local escaped = {}
+  for key, value in pairs(data) do
+    escaped[key] = escape_field(value)
+  end
+  escaped.dkim_results = {}
+  for i, d in ipairs(data.dkim_results or {}) do
+    escaped.dkim_results[i] = {
+      domain = escape_field(d.domain),
+      result = escape_field(d.result),
+    }
+  end
+
+  local xml_template = [[<record>
+ <row>
+ <source_ip>{= data.ip =}</source_ip>
+ <count>{= data.count =}</count>
+ <policy_evaluated>
+ <disposition>{= data.disposition =}</disposition>
+ <dkim>{= data.dkim_disposition =}</dkim>
+ <spf>{= data.spf_disposition =}</spf>
+ {% if data.override and data.override ~= '' -%}
+ <reason><type>{= data.override =}</type></reason>
+ {%- endif %}
+ </policy_evaluated>
+ </row>
+ <identifiers>
+ <header_from>{= data.header_from =}</header_from>
+ </identifiers>
+ <auth_results>
+ {% if data.dkim_results[1] -%}
+ {% for d in data.dkim_results -%}
+ <dkim>
+ <domain>{= d.domain =}</domain>
+ <result>{= d.result =}</result>
+ </dkim>
+ {%- endfor %}
+ {%- endif %}
+ <spf>
+ <domain>{= data.spf_domain =}</domain>
+ <result>{= data.spf_result =}</result>
+ </spf>
+ </auth_results>
+</record>
+]]
+  return lua_util.jinja_template(xml_template, { data = escaped }, true)
+end
+
+-- Process a report entry stored in Redis splitting it to a lua table
+-- The entry is a comma separated list with the following columns:
+--  1 ip, 2 spf disposition, 3 dkim disposition, 4 disposition, 5 override,
+--  6 header from, 7..10 '|' separated dkim domains (pass, fail, temperror,
+--  permerror), 11 spf domain, 12 spf result
+-- @param data raw entry string
+-- @param score sorted set score of the entry, used as the message count
+-- @return row table suitable for entry_to_xml
+local function process_report_entry(data, score)
+  local split = lua_util.str_split(data, ',')
+  local row = {
+    ip = split[1],
+    spf_disposition = split[2],
+    dkim_disposition = split[3],
+    disposition = split[4],
+    override = split[5],
+    header_from = split[6],
+    dkim_results = {},
+    spf_domain = split[11],
+    spf_result = split[12],
+    count = tonumber(score),
+  }
+  -- Process dkim entries
+  local function dkim_entries_process(dkim_data, result)
+    if dkim_data and dkim_data ~= '' then
+      local dkim_elts = lua_util.str_split(dkim_data, '|')
+      for _, d in ipairs(dkim_elts) do
+        table.insert(row.dkim_results, { domain = d, result = result })
+      end
+    end
+  end
+  dkim_entries_process(split[7], 'pass')
+  dkim_entries_process(split[8], 'fail')
+  dkim_entries_process(split[9], 'temperror')
+  -- Column 10 holds the permerror domains; reading column 9 again would report
+  -- every temperror domain twice and drop the real permerror ones
+  dkim_entries_process(split[10], 'permerror')
+
+  return row
+end
+
+-- Process a single rua entry, validating in DNS if needed
+-- @param dmarc_domain domain the DMARC record was published for
+-- @param rua comma separated list of report URIs from the record
+-- @return list of accepted url objects or nil if none was accepted
+local function process_rua(dmarc_domain, rua)
+  local parts = lua_util.str_split(rua, ',')
+
+  -- The reporting domain does not change between iterations, so parse it once.
+  -- Parsing may fail; in that case no address is treated as being in the same
+  -- domain and all of them go through the DNS authority check.
+  local domain_url = rspamd_url.create(pool, dmarc_domain)
+  local domain_tld = domain_url and domain_url:get_tld()
+
+  local addrs = {}
+  for _, rua_part in ipairs(parts) do
+    -- Remove size limitation, as we don't care about them
+    -- (the extra parentheses drop the substitution count returned by gsub)
+    local u = rspamd_url.create(pool, (rua_part:gsub('!%d+[kmg]?$', '')))
+    if u and (u:get_protocol() or '') == 'mailto' and u:get_user() then
+      -- Check each address for sanity
+      if domain_tld and u:get_tld() == domain_tld then
+        -- Same eSLD - always include
+        table.insert(addrs, u)
+      else
+        -- We need to check authority
+        local resolve_str = string.format('%s._report._dmarc.%s',
+            dmarc_domain, u:get_host())
+        local is_ok, results = rspamd_dns.request({
+          config = rspamd_config,
+          session = rspamadm_session,
+          type = 'txt',
+          name = resolve_str,
+        })
+
+        if not is_ok then
+          logger.errx('cannot resolve %s: %s; exclude %s', resolve_str, results, rua_part)
+        else
+          local found = false
+          for _, t in ipairs(results) do
+            if string.match(t, 'v=DMARC1') then
+              found = true
+              break
+            end
+          end
+
+          if not found then
+            logger.errx('%s is not authorized to process reports on %s', dmarc_domain, u:get_host())
+          else
+            -- All good
+            table.insert(addrs, u)
+          end
+        end
+      end
+    else
+      -- Log the raw element: the url object is nil when parsing has failed
+      logger.errx('invalid rua url: "%s"', rua_part)
+    end
+  end
+
+  if #addrs > 0 then
+    return addrs
+  end
+
+  return nil
+end
+
+-- Validate reporting domain, extracting rua and checking 3rd party report domains
+-- This function returns a full dmarc record processed + rua as a list of url objects
+-- @param reporting_domain domain whose _dmarc TXT record is looked up
+-- @return table of record elements with defaults filled in, or nil when the
+-- lookup fails, no record with a rua is found, or no rua address is accepted
+local function validate_reporting_domain(reporting_domain)
+ local is_ok, results = rspamd_dns.request({
+ config = rspamd_config,
+ session = rspamadm_session,
+ type = 'txt',
+ name = '_dmarc.' .. reporting_domain,
+ })
+
+ if not is_ok or not results then
+ logger.errx('cannot resolve _dmarc.%s: %s', reporting_domain, results)
+ return nil
+ end
+
+ -- Only the first TXT record that parses and carries a rua is considered
+ for _, r in ipairs(results) do
+ local processed, rec = dmarc_common.dmarc_check_record(rspamd_config, r, false)
+ if processed and rec.rua then
+ -- We need to check or alter rua if needed
+ local processed_rua = process_rua(reporting_domain, rec.rua)
+ if processed_rua then
+ -- From here on `rec` is the raw_elts table of the record
+ rec = rec.raw_elts
+ rec.rua = processed_rua
+
+ -- Fill defaults in a record to avoid nils in a report
+ rec['pct'] = rec['pct'] or 100
+ rec['adkim'] = rec['adkim'] or 'r'
+ rec['aspf'] = rec['aspf'] or 'r'
+ rec['p'] = rec['p'] or 'none'
+ rec['sp'] = rec['sp'] or 'none'
+ return rec
+ end
+ -- No usable rua address: give up without looking at further records
+ return nil
+ end
+ end
+
+ return nil
+end
+
+-- Builds a comma separated string from the elements of `tbl`; when `func` is
+-- given each element is passed through it first
+local function rcpt_list(tbl, func)
+  local collected = {}
+
+  if func then
+    for _, elt in ipairs(tbl) do
+      table.insert(collected, func(elt))
+    end
+  else
+    for _, elt in ipairs(tbl) do
+      table.insert(collected, elt)
+    end
+  end
+
+  return table.concat(collected, ',')
+end
+
+-- Synchronous smtp send function
+-- Reports are sent in batches of up to opts.batch_size messages; the next
+-- batch is started once all callbacks of the current one have fired.
+-- @param opts parsed command line options (batch_size is read)
+-- @param reports list of tables with message, rcpts and reporting_domain
+-- @param finish_cb function(nsent, nfailed) called once after the last batch
+local function send_reports_by_smtp(opts, reports, finish_cb)
+  local lua_smtp = require "lua_smtp"
+  local reports_failed = 0
+  local reports_sent = 0
+  local report_settings = dmarc_settings.reporting
+
+  local function gen_sendmail_cb(report, args)
+    return function(ret, err)
+      -- We modify this from all callbacks
+      args.nreports = args.nreports - 1
+      if not ret then
+        logger.errx("Couldn't send mail for %s: %s", report.reporting_domain, err)
+        reports_failed = reports_failed + 1
+      else
+        reports_sent = reports_sent + 1
+        lua_util.debugm(N, 'successfully sent a report for %s: %s bytes sent',
+            report.reporting_domain, #report.message)
+      end
+
+      -- Tail call to the next batch or to the final function
+      if args.nreports == 0 then
+        if args.next_start > #reports then
+          finish_cb(reports_sent, reports_failed)
+        else
+          args.cont_func(args.next_start)
+        end
+      end
+    end
+  end
+
+  local function send_data_in_batches(cur_batch)
+    local nreports = math.min(#reports - cur_batch + 1, opts.batch_size)
+    local next_start = cur_batch + nreports
+    lua_util.debugm(N, 'send data for %s domains (from %s to %s)',
+        nreports, cur_batch, next_start - 1)
+    -- Shared across all closures
+    local gen_args = {
+      cont_func = send_data_in_batches,
+      nreports = nreports,
+      next_start = next_start
+    }
+    for i = cur_batch, next_start - 1 do
+      local report = reports[i]
+      local send_opts = {
+        ev_base = rspamadm_ev_base,
+        session = rspamadm_session,
+        config = rspamd_config,
+        host = report_settings.smtp,
+        port = report_settings.smtp_port or 25,
+        resolver = rspamadm_dns_resolver,
+        from = report_settings.email,
+        recipients = report.rcpts,
+        helo = report_settings.helo or 'rspamd.localhost',
+      }
+
+      lua_smtp.sendmail(send_opts,
+          report.message,
+          gen_sendmail_cb(report, gen_args))
+    end
+  end
+
+  -- With nothing to send no sendmail callback would ever fire, so the
+  -- completion callback has to be invoked directly
+  if #reports == 0 then
+    finish_cb(reports_sent, reports_failed)
+    return
+  end
+
+  send_data_in_batches(1)
+end
+
+-- Builds a ready to send report message for a single report key
+-- Unless opts.no_opt is set, the key is renamed before processing and removed
+-- once the message has been built (or when the DMARC record is not usable).
+-- @param opts parsed command line options
+-- @param start_time start of the reporting window (unix seconds)
+-- @param end_time end of the reporting window (unix seconds)
+-- @param rep_key Redis key of the report
+-- @return table with message, rcpts and reporting_domain, or nil if the
+-- report has to be skipped
+local function prepare_report(opts, start_time, end_time, rep_key)
+  local rua = get_rua(rep_key)
+  local reporting_domain = get_domain(rep_key)
+
+  if not rua then
+    logger.errx('report %s has no valid rua, skip it', rep_key)
+    return nil
+  end
+  if not reporting_domain then
+    logger.errx('report %s has no valid reporting_domain, skip it', rep_key)
+    return nil
+  end
+
+  local ret, results = lua_redis.request(redis_params, redis_attrs,
+      { 'EXISTS', rep_key })
+
+  if not ret or not results or results == 0 then
+    return nil
+  end
+
+  -- Rename report key to avoid races
+  if not opts.no_opt then
+    lua_redis.request(redis_params, redis_attrs,
+        { 'RENAME', rep_key, rep_key .. '_processing' })
+    rep_key = rep_key .. '_processing'
+  end
+
+  local dmarc_record = validate_reporting_domain(reporting_domain)
+  lua_util.debugm(N, 'process reporting domain %s: %s', reporting_domain, dmarc_record)
+
+  if not dmarc_record then
+    if not opts.no_opt then
+      lua_redis.request(redis_params, redis_attrs,
+          { 'DEL', rep_key })
+    end
+    logger.messagex('Cannot process reports for domain %s; invalid dmarc record', reporting_domain)
+    return nil
+  end
+
+  -- Get all reports for a domain
+  ret, results = lua_redis.request(redis_params, redis_attrs,
+      { 'ZRANGE', rep_key, '0', '-1', 'WITHSCORES' })
+  if not ret or type(results) ~= 'table' then
+    -- Without entries there is nothing to report; the key is left in place as
+    -- its content could not be read
+    logger.errx('cannot get report entries for %s: %s', reporting_domain, results)
+    return nil
+  end
+
+  local report_entries = {}
+  table.insert(report_entries,
+      report_header(reporting_domain, start_time, end_time, dmarc_record))
+  -- The reply alternates between an entry and its score
+  for i = 1, #results, 2 do
+    local xml_record = entry_to_xml(process_report_entry(results[i], results[i + 1]))
+    table.insert(report_entries, xml_record)
+  end
+  table.insert(report_entries, '</feedback>')
+  local xml_to_compress = rspamd_text.fromtable(report_entries)
+  lua_util.debugm(N, 'got xml: %s', xml_to_compress)
+
+  -- Prepare SMTP message
+  local report_settings = dmarc_settings.reporting
+  local rcpt_string = rcpt_list(dmarc_record.rua, function(rua_elt)
+    return string.format('%s@%s', rua_elt:get_user(), rua_elt:get_host())
+  end)
+  local bcc_string
+  if report_settings.bcc_addrs then
+    bcc_string = rcpt_list(report_settings.bcc_addrs)
+  end
+  local uuid = gen_uuid()
+  local rhead = lua_util.jinja_template(report_template, {
+    from_name = report_settings.from_name,
+    from_addr = report_settings.email,
+    rcpt = rcpt_string,
+    bcc = bcc_string,
+    uuid = uuid,
+    reporting_domain = reporting_domain,
+    submitter = report_settings.domain,
+    report_id = string.format('%s.%d.%d', reporting_domain, start_time,
+        end_time),
+    report_date = rspamd_util.time_to_string(rspamd_util.get_time()),
+    message_id = rspamd_util.random_hex(16) .. '@' .. report_settings.msgid_from,
+    report_start = start_time,
+    report_end = end_time
+  }, true)
+  local rfooter = lua_util.jinja_template(report_footer, {
+    uuid = uuid,
+  }, true)
+  -- gsub returns the substitution count as a second value; as the last element
+  -- of a table constructor it would be expanded into the list of message
+  -- parts, hence the parentheses around both calls
+  local message = rspamd_text.fromtable {
+    (rhead:gsub("\n", "\r\n")),
+    rspamd_util.encode_base64(rspamd_util.gzip_compress(xml_to_compress), 73),
+    (rfooter:gsub("\n", "\r\n")),
+  }
+
+  lua_util.debugm(N, 'got final message: %s', message)
+
+  if not opts.no_opt then
+    lua_redis.request(redis_params, redis_attrs,
+        { 'DEL', rep_key })
+  end
+
+  local report_rcpts = lua_util.str_split(rcpt_string, ',')
+
+  if report_settings.bcc_addrs then
+    for _, b in ipairs(report_settings.bcc_addrs) do
+      table.insert(report_rcpts, b)
+    end
+  end
+
+  return {
+    message = message,
+    rcpts = report_rcpts,
+    reporting_domain = reporting_domain
+  }
+end
+
+-- Collects all reports registered in the index set of a single date
+-- @param opts parsed command line options (no_opt disables Redis modifications)
+-- @param start_time start of the reporting window, passed to prepare_report
+-- @param end_time end of the reporting window, passed to prepare_report
+-- @param date date string used to build the index key
+-- @return list of prepared reports in shuffled order (possibly empty)
+local function process_report_date(opts, start_time, end_time, date)
+ local idx_key = redis_prefix(dmarc_settings.reporting.redis_keys.index_prefix, date)
+ local ret, results = lua_redis.request(redis_params, redis_attrs,
+ { 'EXISTS', idx_key })
+
+ if not ret or not results or results == 0 then
+ logger.messagex('No reports for %s', date)
+ return {}
+ end
+
+ -- Rename index key to avoid races
+ if not opts.no_opt then
+ lua_redis.request(redis_params, redis_attrs,
+ { 'RENAME', idx_key, idx_key .. '_processing' })
+ idx_key = idx_key .. '_processing'
+ end
+ -- Members of the index set are the keys of the individual reports
+ ret, results = lua_redis.request(redis_params, redis_attrs,
+ { 'SMEMBERS', idx_key })
+
+ if not ret or not results then
+ -- Remove bad key
+ if not opts.no_opt then
+ lua_redis.request(redis_params, redis_attrs,
+ { 'DEL', idx_key })
+ end
+ logger.messagex('Cannot get reports for %s', date)
+ return {}
+ end
+
+ local reports = {}
+ for _, rep in ipairs(results) do
+ local report = prepare_report(opts, start_time, end_time, rep)
+
+ -- prepare_report returns nil for reports that have to be skipped
+ if report then
+ table.insert(reports, report)
+ end
+ end
+
+ -- Shuffle reports to make sending more fair
+ lua_util.shuffle(reports)
+ -- Remove processed key
+ if not opts.no_opt then
+ lua_redis.request(redis_params, redis_attrs,
+ { 'DEL', idx_key })
+ end
+
+ return reports
+end
+
+
+-- Returns yesterday at 00:00 (local time) as unix seconds
+local function yesterday_midnight()
+  local piecewise_time = os.date("*t")
+  -- os.time normalises out of range fields, so day 0 becomes the last day of
+  -- the previous month
+  piecewise_time.day = piecewise_time.day - 1
+  piecewise_time.hour = 0
+  piecewise_time.sec = 0
+  piecewise_time.min = 0
+  -- The DST flag describes the current moment, not the computed one; dropping
+  -- it lets os.time work out the right value, otherwise the result can be off
+  -- by an hour around a daylight saving switch
+  piecewise_time.isdst = nil
+  return os.time(piecewise_time)
+end
+
+-- Returns today at 00:00 (local time) as unix seconds
+local function today_midnight()
+  local piecewise_time = os.date("*t")
+  piecewise_time.hour = 0
+  piecewise_time.sec = 0
+  piecewise_time.min = 0
+  -- The DST flag describes the current moment, not midnight; dropping it lets
+  -- os.time work out the right value, otherwise the result can be off by an
+  -- hour on the day of a daylight saving switch
+  piecewise_time.isdst = nil
+  return os.time(piecewise_time)
+end
+
+-- Command entry point: collects the reports for the requested dates (yesterday
+-- by default) and sends them by SMTP unless --no-opt is given
+-- @param args raw command line arguments
+local function handler(args)
+  local start_time
+  -- Preserve start time as report sending might take some time
+  local start_collection = today_midnight()
+
+  local opts = parser:parse(args)
+
+  pool = rspamd_mempool.create()
+  load_config(opts)
+  rspamd_url.init(rspamd_config:get_tld_path())
+
+  if opts.verbose then
+    lua_util.enable_debug_modules('dmarc', N)
+  end
+
+  dmarc_settings = rspamd_config:get_all_opt('dmarc')
+  if not dmarc_settings or not dmarc_settings.reporting or not dmarc_settings.reporting.enabled then
+    logger.errx('dmarc reporting is not enabled, exiting')
+    os.exit(1)
+  end
+
+  dmarc_settings = lua_util.override_defaults(dmarc_common.default_settings, dmarc_settings)
+  redis_params = lua_redis.parse_redis_server('dmarc', dmarc_settings)
+
+  if not redis_params then
+    logger.errx('Redis is not configured, exiting')
+    os.exit(1)
+  end
+
+  for _, e in ipairs({ 'email', 'domain', 'org_name' }) do
+    if not dmarc_settings.reporting[e] then
+      logger.errx('Missing required setting: dmarc.reporting.%s', e)
+      -- A fatal configuration error like the ones above, so report it through
+      -- the exit status as well instead of returning silently
+      os.exit(1)
+    end
+  end
+
+  local ret, results = lua_redis.request(redis_params, redis_attrs, {
+    'GET', 'rspamd_dmarc_last_collection'
+  })
+
+  -- Fall back to yesterday when there is no usable previous collection time
+  if not ret or not tonumber(results) then
+    start_time = yesterday_midnight()
+  else
+    start_time = tonumber(results)
+  end
+
+  lua_util.debugm(N, 'previous last report date is %s', start_time)
+
+  if not opts.date or #opts.date == 0 then
+    opts.date = {}
+    table.insert(opts.date, os.date('%Y%m%d', yesterday_midnight()))
+  end
+
+  local ndates = 0
+  local nreports = 0
+  local all_reports = {}
+  for _, date in ipairs(opts.date) do
+    lua_util.debugm(N, 'Process date %s', date)
+    local reports_for_date = process_report_date(opts, start_time, start_collection, date)
+    if #reports_for_date > 0 then
+      ndates = ndates + 1
+      nreports = nreports + #reports_for_date
+
+      for _, r in ipairs(reports_for_date) do
+        table.insert(all_reports, r)
+      end
+    end
+  end
+
+  local function finish_cb(nsuccess, nfail)
+    if not opts.no_opt then
+      lua_util.debugm(N, 'set last report date to %s', start_collection)
+      -- Hack to avoid coroutines + async functions mess: we use async redis call here
+      redis_attrs.callback = function()
+        logger.messagex('Reporting collection has finished %s dates processed, %s reports: %s completed, %s failed',
+            ndates, nreports, nsuccess, nfail)
+      end
+      lua_redis.request(redis_params, redis_attrs,
+          { 'SETEX', 'rspamd_dmarc_last_collection', dmarc_settings.reporting.keys_expire * 2,
+            tostring(start_collection) })
+    else
+      logger.messagex('Reporting collection has finished %s dates processed, %s reports: %s completed, %s failed',
+          ndates, nreports, nsuccess, nfail)
+    end
+
+    pool:destroy()
+  end
+  if not opts.no_opt then
+    send_reports_by_smtp(opts, all_reports, finish_cb)
+  else
+    logger.messagex('Skip sending mails due to -n / --no-opt option')
+    -- finish_cb is not going to run on this path, so release the pool here
+    pool:destroy()
+  end
+end
+
+return {
+ name = 'dmarc_report',
+ aliases = { 'dmarc_reporting' },
+ handler = handler,
+ description = parser._description
+}
diff --git a/lualib/rspamadm/dns_tool.lua b/lualib/rspamadm/dns_tool.lua
new file mode 100644
index 0000000..3eb09a8
--- /dev/null
+++ b/lualib/rspamadm/dns_tool.lua
@@ -0,0 +1,232 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+
+local argparse = require "argparse"
+local rspamd_logger = require "rspamd_logger"
+local ansicolors = require "ansicolors"
+local bit = require "bit"
+
+local parser = argparse()
+ :name "rspamadm dnstool"
+ :description "DNS tools provided by Rspamd"
+ :help_description_margin(30)
+ :command_target("command")
+ :require_command(true)
+
+parser:option "-c --config"
+ :description "Path to config file"
+ :argname("<cfg>")
+ :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+
+local spf = parser:command "spf"
+ :description "Extracts spf records"
+spf:mutex(
+ spf:option "-d --domain"
+ :description "Domain to use"
+ :argname("<domain>"),
+ spf:option "-f --from"
+ :description "SMTP from to use"
+ :argname("<from>")
+)
+
+spf:option "-i --ip"
+ :description "Source IP address to use"
+ :argname("<ip>")
+spf:flag "-a --all"
+ :description "Print all records"
+
+--- Writes a formatted line to stdout; with a nil/false `fmt` only the newline
+-- is written.
+local function printf(fmt, ...)
+  local text = ''
+  if fmt then
+    text = string.format(fmt, ...)
+  end
+  io.write(text, '\n')
+end
+
+-- Wraps `str` between ansicolors.white and ansicolors.reset
+local function highlight(str)
+ return ansicolors.white .. str .. ansicolors.reset
+end
+
+-- Wraps `str` between ansicolors.green and ansicolors.reset
+local function green(str)
+ return ansicolors.green .. str .. ansicolors.reset
+end
+
+-- Wraps `str` between ansicolors.red and ansicolors.reset
+local function red(str)
+ return ansicolors.red .. str .. ansicolors.reset
+end
+
+--- Loads and processes the configuration file named by `opts.config`.
+-- Logs the error and exits with status 1 if either step fails.
+-- @param opts parsed command line options
+local function load_config(opts)
+  local cfg_path = opts['config']
+
+  local loaded, load_err = rspamd_config:load_ucl(cfg_path)
+  if not loaded then
+    rspamd_logger.errx('cannot parse %s: %s', cfg_path, load_err)
+    os.exit(1)
+  end
+
+  local parsed, parse_err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+  if not parsed then
+    rspamd_logger.errx('cannot process %s: %s', cfg_path, parse_err)
+    os.exit(1)
+  end
+end
+
+-- Implements the `spf` sub-command: resolves the SPF record for the given
+-- domain or SMTP from address and prints it, optionally checking an IP
+-- address against it
+-- @param opts parsed command line options (ip, from, domain and all are read;
+-- ip is replaced by the parsed address object and all is forced to true when
+-- no ip is given)
+local function spf_handler(opts)
+ local rspamd_spf = require "rspamd_spf"
+ local rspamd_task = require "rspamd_task"
+ local rspamd_ip = require "rspamd_ip"
+
+ -- NOTE(review): `create` is called with `:`, so the rspamd_task module table
+ -- is passed as the first argument - confirm this is what create() expects
+ local task = rspamd_task:create(rspamd_config, rspamadm_ev_base)
+ task:set_session(rspamadm_session)
+ task:set_resolver(rspamadm_dns_resolver)
+
+ if opts.ip then
+ opts.ip = rspamd_ip.fromstring(opts.ip)
+ task:set_from_ip(opts.ip)
+ else
+ -- Without an address to check there is nothing to match: print everything
+ opts.all = true
+ end
+
+ if opts.from then
+ local rspamd_parsers = require "rspamd_parsers"
+ local addr_parsed = rspamd_parsers.parse_mail_address(opts.from)
+ if addr_parsed then
+ task:set_from('smtp', addr_parsed[1])
+ else
+ io.stderr:write('Invalid from addr\n')
+ os.exit(1)
+ end
+ elseif opts.domain then
+ -- Only the domain is known, so a placeholder local part is used
+ task:set_from('smtp', { user = 'user', domain = opts.domain })
+ else
+ io.stderr:write('Neither domain nor from specified\n')
+ os.exit(1)
+ end
+
+ -- Converts SPF flags to a description; the first matching flag wins
+ local function flag_to_str(fl)
+ if bit.band(fl, rspamd_spf.flags.temp_fail) ~= 0 then
+ return "temporary failure"
+ elseif bit.band(fl, rspamd_spf.flags.perm_fail) ~= 0 then
+ return "permanent failure"
+ elseif bit.band(fl, rspamd_spf.flags.na) ~= 0 then
+ return "no spf record"
+ end
+
+ return "unknown flag: " .. tostring(fl)
+ end
+
+ -- Prints policy, network and (if present) the original string of an element
+ local function display_spf_results(elt, colored)
+ -- Decorator applied to the printed values; identity unless colored
+ local dec = function(e)
+ return e
+ end
+ local policy_decode = function(e)
+ if e == rspamd_spf.policy.fail then
+ return 'reject'
+ elseif e == rspamd_spf.policy.pass then
+ return 'pass'
+ elseif e == rspamd_spf.policy.soft_fail then
+ return 'soft fail'
+ elseif e == rspamd_spf.policy.neutral then
+ return 'neutral'
+ end
+
+ return 'unknown'
+ end
+
+ if colored then
+ dec = function(e)
+ return highlight(e)
+ end
+
+ -- pass is shown in green, fail in red, anything else highlighted
+ if elt.result == rspamd_spf.policy.pass then
+ dec = function(e)
+ return green(e)
+ end
+ elseif elt.result == rspamd_spf.policy.fail then
+ dec = function(e)
+ return red(e)
+ end
+ end
+
+ end
+ printf('%s: %s', highlight('Policy'), dec(policy_decode(elt.result)))
+ printf('%s: %s', highlight('Network'), dec(elt.addr))
+
+ if elt.str then
+ printf('%s: %s', highlight('Original'), elt.str)
+ end
+ end
+
+ -- Callback passed to rspamd_spf.resolve
+ local function cb(record, flags, err)
+ if record then
+ local result, flag_or_policy, error_or_addr
+ if opts.ip then
+ result, flag_or_policy, error_or_addr = record:check_ip(opts.ip)
+ elseif opts.all then
+ result = true
+ end
+ -- IP given without --all: print only the check outcome and exit
+ if opts.ip and not opts.all then
+ if result then
+ display_spf_results(error_or_addr, true)
+ else
+ printf('Not matched: %s', error_or_addr)
+ end
+
+ os.exit(0)
+ end
+
+ if result then
+ printf('SPF record for %s; digest: %s',
+ highlight(opts.domain or opts.from), highlight(record:get_digest()))
+ for _, elt in ipairs(record:get_elts()) do
+ -- The element whose original string equals the matched one is marked
+ if result and error_or_addr and elt.str and elt.str == error_or_addr.str then
+ printf("%s", highlight('*** Matched ***'))
+ display_spf_results(elt, true)
+ printf('------')
+ else
+ display_spf_results(elt, false)
+ printf('------')
+ end
+ end
+ else
+ printf('Error getting SPF record: %s (%s flag)', err,
+ flag_to_str(flag_or_policy or flags))
+ end
+ else
+ printf('Cannot get SPF record: %s', err)
+ end
+ end
+ rspamd_spf.resolve(task, cb)
+end
+
+-- Command entry point: parses arguments, loads the configuration and runs the
+-- selected sub-command
+-- @param args raw command line arguments
+local function handler(args)
+  local opts = parser:parse(args)
+  load_config(opts)
+
+  local command = opts.command
+
+  if command == 'spf' then
+    spf_handler(opts)
+  else
+    -- parser:error() takes a ready-made message and does no printf style
+    -- formatting, so the message has to be built beforehand
+    parser:error(string.format('command %s is not implemented', tostring(command)))
+  end
+end
+
+return {
+ name = 'dnstool',
+ aliases = { 'dns', 'dns_tool' },
+ handler = handler,
+ description = parser._description
+} \ No newline at end of file
diff --git a/lualib/rspamadm/fuzzy_convert.lua b/lualib/rspamadm/fuzzy_convert.lua
new file mode 100644
index 0000000..fab3995
--- /dev/null
+++ b/lualib/rspamadm/fuzzy_convert.lua
@@ -0,0 +1,208 @@
+local sqlite3 = require "rspamd_sqlite3"
+local redis = require "rspamd_redis"
+local util = require "rspamd_util"
+
+--- Opens a synchronous Redis connection and queues the optional AUTH and
+-- SELECT commands on it.
+-- @param server Redis host
+-- @param username optional user name (requires a password)
+-- @param password optional password
+-- @param db optional database to select
+-- @return connection, nil on success; nil, error message on failure
+local function connect_redis(server, username, password, db)
+  local conn, err = redis.connect_sync({
+    host = server,
+  })
+  if not conn then
+    return nil, 'Cannot connect: ' .. err
+  end
+
+  -- Work out the AUTH arguments first, then queue the command once
+  local auth_args
+  if username then
+    if not password then
+      return nil, 'Redis requires a password when username is supplied'
+    end
+    auth_args = { username, password }
+  elseif password then
+    auth_args = { password }
+  end
+
+  if auth_args and not conn:add_cmd('AUTH', auth_args) then
+    return nil, 'Cannot queue command'
+  end
+
+  if db and not conn:add_cmd('SELECT', { db }) then
+    return nil, 'Cannot queue command'
+  end
+
+  return conn, nil
+end
+
+--- Stores a batch of digests in Redis: one HMSET plus one EXPIRE per digest,
+-- all executed as a single batch.
+-- @param digests list of { digest, flag, value, expire } tuples
+-- @return true on success, false otherwise (the error is printed)
+local function send_digests(digests, redis_host, redis_username, redis_password, redis_db)
+  local conn, err = connect_redis(redis_host, redis_username, redis_password, redis_db)
+  if err then
+    print(err)
+    return false
+  end
+
+  for _, digest in ipairs(digests) do
+    local queued = conn:add_cmd('HMSET', {
+      'fuzzy' .. digest[1],
+      'F', digest[2],
+      'V', digest[3],
+    })
+    if queued then
+      queued = conn:add_cmd('EXPIRE', {
+        'fuzzy' .. digest[1],
+        tostring(digest[4]),
+      })
+    end
+    if not queued then
+      print('Cannot batch command')
+      return false
+    end
+  end
+
+  local executed
+  executed, err = conn:exec()
+  if not executed then
+    print('Cannot execute batched commands: ' .. err)
+    return false
+  end
+
+  return true
+end
+
+--- Stores a batch of shingles in Redis: one SET plus one EXPIRE per shingle,
+-- all executed as a single batch.
+-- @param shingles list of { value, number, expire, digest } tuples
+-- @return true on success, false otherwise (the error is printed)
+local function send_shingles(shingles, redis_host, redis_username, redis_password, redis_db)
+  local conn, err = connect_redis(redis_host, redis_username, redis_password, redis_db)
+  if err then
+    print("Redis error: " .. err)
+    return false
+  end
+  local ret
+  for _, v in ipairs(shingles) do
+    ret = conn:add_cmd('SET', {
+      'fuzzy_' .. v[2] .. '_' .. v[1],
+      v[4],
+    })
+    if not ret then
+      -- `err` is nil at this point (the connection has succeeded), so it must
+      -- not be concatenated into the message
+      print('Cannot batch SET command')
+      return false
+    end
+    ret = conn:add_cmd('EXPIRE', {
+      'fuzzy_' .. v[2] .. '_' .. v[1],
+      tostring(v[3]),
+    })
+    if not ret then
+      print('Cannot batch command')
+      return false
+    end
+  end
+  ret, err = conn:exec()
+  if not ret then
+    print('Cannot execute batched commands: ' .. err)
+    return false
+  end
+  return true
+end
+
+--- Writes the total number of migrated digests to both counter keys.
+-- @param total value to store
+-- @return true on success, false otherwise (the error is printed)
+local function update_counters(total, redis_host, redis_username, redis_password, redis_db)
+  local conn, err = connect_redis(redis_host, redis_username, redis_password, redis_db)
+  if err then
+    print(err)
+    return false
+  end
+
+  for _, counter_key in ipairs({ 'fuzzylocal', 'fuzzy_count' }) do
+    local queued = conn:add_cmd('SET', {
+      counter_key,
+      total,
+    })
+    if not queued then
+      print('Cannot batch command')
+      return false
+    end
+  end
+
+  local executed
+  executed, err = conn:exec()
+  if not executed then
+    print('Cannot execute batched commands: ' .. err)
+    return false
+  end
+
+  return true
+end
+
+-- Migrates digests and shingles from an sqlite fuzzy storage to Redis.
+-- Rows are sent in batches of `lim_batch`; digests whose lifetime is already
+-- over are skipped.
+-- @param res options table: source_db, expiry, redis_host and the optional
+-- redis_username, redis_password and redis_db
+return function(_, res)
+  local db = sqlite3.open(res['source_db'])
+  local shingles = {}
+  local digests = {}
+  local num_batch_digests = 0
+  local num_batch_shingles = 0
+  local total_digests = 0
+  local total_shingles = 0
+  local lim_batch = 1000 -- Update each 1000 entries
+  local redis_username = res['redis_username']
+  local redis_password = res['redis_password']
+  local redis_db = nil
+
+  if res['redis_db'] then
+    redis_db = tostring(res['redis_db'])
+  end
+
+  if not db then
+    print('Cannot open source db: ' .. res['source_db'])
+    return
+  end
+
+  local now = util.get_time()
+  for row in db:rows('SELECT id, flag, digest, value, time FROM digests') do
+
+    -- Remaining lifetime: the entry dies at `time + expiry`. Adding the age
+    -- instead (now - time + expiry) would always exceed the expiry and the
+    -- check below would never drop anything.
+    -- NOTE(review): assumes `time` holds the unix timestamp of the last update
+    local expire_in = math.floor(row.time + res['expiry'] - now)
+    if expire_in >= 1 then
+      table.insert(digests, { row.digest, row.flag, row.value, expire_in })
+      num_batch_digests = num_batch_digests + 1
+      total_digests = total_digests + 1
+      for srow in db:rows('SELECT value, number FROM shingles WHERE digest_id = ' .. row.id) do
+        table.insert(shingles, { srow.value, srow.number, expire_in, row.digest })
+        total_shingles = total_shingles + 1
+        num_batch_shingles = num_batch_shingles + 1
+      end
+    end
+    if num_batch_digests >= lim_batch then
+      if not send_digests(digests, res['redis_host'], redis_username, redis_password, redis_db) then
+        return
+      end
+      num_batch_digests = 0
+      digests = {}
+    end
+    if num_batch_shingles >= lim_batch then
+      if not send_shingles(shingles, res['redis_host'], redis_username, redis_password, redis_db) then
+        return
+      end
+      num_batch_shingles = 0
+      shingles = {}
+    end
+  end
+  -- Flush whatever is left in the last, incomplete batches
+  if digests[1] then
+    if not send_digests(digests, res['redis_host'], redis_username, redis_password, redis_db) then
+      return
+    end
+  end
+  if shingles[1] then
+    if not send_shingles(shingles, res['redis_host'], redis_username, redis_password, redis_db) then
+      return
+    end
+  end
+
+  local message = string.format(
+      'Migrated %d digests and %d shingles',
+      total_digests, total_shingles
+  )
+  if not update_counters(total_digests, res['redis_host'], redis_username, redis_password, redis_db) then
+    message = message .. ' but failed to update counters'
+  end
+  print(message)
+end
diff --git a/lualib/rspamadm/fuzzy_ping.lua b/lualib/rspamadm/fuzzy_ping.lua
new file mode 100644
index 0000000..e0345da
--- /dev/null
+++ b/lualib/rspamadm/fuzzy_ping.lua
@@ -0,0 +1,259 @@
+--[[
+Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+local ansicolors = require "ansicolors"
+local rspamd_logger = require "rspamd_logger"
+local lua_util = require "lua_util"
+
+local E = {}
+
+-- Command line interface of `rspamadm fuzzy_ping`
+local parser = argparse()
+    :name 'rspamadm fuzzy_ping'
+    :description 'Pings fuzzy storage'
+    :help_description_margin(30)
+parser:option "-c --config"
+      :description "Path to config file"
+      :argname("<cfg>")
+      :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+parser:option "-r --rule"
+      :description "Storage to ping (must be configured in Rspamd configuration)"
+      :argname("<name>")
+      :default("rspamd.com")
+parser:flag "-f --flood"
+      :description "Flood mode (no waiting for replies)"
+parser:flag "-S --silent"
+      :description "Silent mode (statistics only)"
+parser:option "-t --timeout"
+      :description "Timeout for requests"
+      :argname("<timeout>")
+      :convert(tonumber)
+      :default(5)
+parser:option "-s --server"
+      :description "Override server to ping"
+      :argname("<name>")
+-- The handler uses this value as the number of pings to send
+parser:option "-n --number"
+      :description "Number of requests to send"
+      :argname("<number>")
+      :convert(tonumber)
+      :default(5)
+parser:flag "-l --list"
+      :description "List configured storages"
+--- Loads and initialises the Rspamd configuration named on the command line.
+-- The process exits with code 1 as soon as one of the stages fails.
+-- @param opts parsed command line options; `opts.config` is the file to load
+local function load_config(opts)
+  local ok, err = rspamd_config:load_ucl(opts.config)
+  if not ok then
+    rspamd_logger.errx('cannot parse %s: %s', opts.config, err)
+    os.exit(1)
+  end
+
+  -- Build the real structure, leaving out the logging and worker sections
+  ok, err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+  if not ok then
+    rspamd_logger.errx('cannot process %s: %s', opts.config, err)
+    os.exit(1)
+  end
+
+  ok, err = rspamd_config:init_modules()
+  if not ok then
+    rspamd_logger.errx('cannot init modules from %s: %s', opts.config, err)
+    os.exit(1)
+  end
+end
+
+--- Formats a message and wraps it into the white ANSI colour sequence.
+-- @param fmt `string.format` template
+-- @param ... values for the template
+local function highlight(fmt, ...)
+  local text = string.format(fmt, ...)
+  return ansicolors.white .. text .. ansicolors.reset
+end
+
+--- Formats a message and wraps it into the red ANSI colour sequence.
+-- @param fmt `string.format` template
+-- @param ... values for the template
+local function highlight_err(fmt, ...)
+  local text = string.format(fmt, ...)
+  return ansicolors.red .. text .. ansicolors.reset
+end
+
+--- Prints every fuzzy rule with its read-only state, servers and flags.
+-- @param rules table mapping a rule name to a description with the fields
+--   `read_only`, `servers` and (optionally) `flags`
+local function print_storages(rules)
+  for name, rule in pairs(rules) do
+    print(highlight('Rule: %s', name))
+    print(string.format("\tRead only: %s", rule.read_only))
+
+    local server_list = table.concat(lua_util.values(rule.servers), ',')
+    print(string.format("\tServers: %s", server_list))
+
+    print("\tFlags:")
+    -- A rule without flags is iterated as an empty table
+    local flags = rule.flags or E
+    for flag_name, flag_id in pairs(flags) do
+      print(string.format("\t\t%s: %s", flag_name, flag_id))
+    end
+  end
+end
+
+--- Computes the sample standard deviation and the arithmetic mean of a list.
+-- @param tbl array of numbers
+-- @return sample standard deviation (sum of squared deviations divided by
+--   `n - 1`); 0 when there are fewer than two samples
+-- @return arithmetic mean; 0 for an empty list
+local function std_mean(tbl)
+  local sum = 0
+  local count = 0
+
+  for _, v in ipairs(tbl) do
+    sum = sum + v
+    count = count + 1
+  end
+
+  -- Nothing to average: avoid 0/0, which would surface as NaN in the report
+  if count == 0 then
+    return 0, 0
+  end
+
+  local m = sum / count
+
+  -- A single sample has no spread; dividing by (count - 1) would give NaN
+  if count == 1 then
+    return 0, m
+  end
+
+  local sq_sum = 0
+  for _, v in ipairs(tbl) do
+    local vm = v - m
+    sq_sum = sq_sum + (vm * vm)
+  end
+
+  return math.sqrt(sq_sum / (count - 1)), m
+end
+
+--- Finds the largest and the smallest element of a list.
+-- @param tbl array of numbers
+-- @return maximum (`-math.huge` for an empty list)
+-- @return minimum (`math.huge` for an empty list)
+local function maxmin(tbl)
+  local hi, lo = -math.huge, math.huge
+
+  for _, value in ipairs(tbl) do
+    hi = math.max(hi, value)
+    lo = math.min(lo, value)
+  end
+
+  return hi, lo
+end
+
+--- Prints a ping-like summary for every server seen in the results.
+-- @param results array of records with the fields `success`, `server` and
+--   `latency` (successful replies only)
+local function print_results(results)
+  local servers = {} -- server -> list of latencies of successful replies
+  local err_servers = {} -- server -> number of failed requests
+
+  for _, res in ipairs(results) do
+    if res.success then
+      if servers[res.server] then
+        table.insert(servers[res.server], res.latency)
+      else
+        servers[res.server] = { res.latency }
+      end
+    else
+      err_servers[res.server] = (err_servers[res.server] or 0) + 1
+      -- Report the server even when it has never replied successfully
+      if not servers[res.server] then
+        servers[res.server] = {}
+      end
+    end
+  end
+
+  for s, l in pairs(servers) do
+    local total = #l + (err_servers[s] or 0)
+    print(highlight('Summary for %s: %d packets transmitted, %d packets received, %.1f%% packet loss',
+        s, total, #l, (total - #l) * 100.0 / total))
+    -- Round-trip figures only make sense when at least one reply arrived
+    if #l > 0 then
+      -- std_mean returns the deviation first and the mean second
+      local std, mean = std_mean(l)
+      local max, min = maxmin(l)
+      print(string.format('round-trip min/avg/max/std-dev = %.2f/%.2f/%.2f/%.2f ms',
+          min, mean,
+          max, std))
+    end
+  end
+end
+
+--- Entry point of `rspamadm fuzzy_ping`.
+-- Parses the command line, loads the configuration and then either lists the
+-- configured storages (`--list`) or pings the selected rule `opts.number` times.
+-- @param args command line arguments handed to the argparse parser
+local function handler(args)
+ local opts = parser:parse(args)
+
+ load_config(opts)
+
+ if opts.list then
+ print_storages(rspamd_plugins.fuzzy_check.list_storages(rspamd_config))
+ os.exit(0)
+ end
+
+ -- Perform ping using a fake task from async stuff provided by rspamadm
+ local rspamd_task = require "rspamd_task"
+
+ -- TODO: this task is not cleared at the end, do something about it some day
+ local task = rspamd_task.create(rspamd_config, rspamadm_ev_base)
+ task:set_session(rspamadm_session)
+ task:set_resolver(rspamadm_dns_resolver)
+
+ -- Number of callbacks processed so far (the final one is not counted)
+ local replied = 0
+ -- Reply records indexed by the request number
+ local results = {}
+ -- Forward declaration: the callback and ping_fuzzy refer to each other
+ local ping_fuzzy
+
+ -- Builds the reply callback for request number `num`
+ local function gen_ping_fuzzy_cb(num)
+ return function(success, server, latency_or_err)
+ if not success then
+ if not opts.silent then
+ print(highlight_err('error from %s: %s', server, latency_or_err))
+ end
+ results[num] = {
+ success = false,
+ error = latency_or_err,
+ server = tostring(server),
+ }
+ else
+ if not opts.silent then
+ -- Truncate the printed latency to three decimal places
+ local adjusted_latency = math.floor(latency_or_err * 1000) * 1.0 / 1000;
+ print(highlight('reply from %s: %s ms', server, adjusted_latency))
+
+ end
+ results[num] = {
+ success = true,
+ latency = latency_or_err,
+ server = tostring(server),
+ }
+ end
+
+ -- The last expected callback prints the summary; otherwise continue
+ if replied == opts.number - 1 then
+ print_results(results)
+ else
+ replied = replied + 1
+ -- In normal mode the next request is sent only after the previous reply
+ if not opts.flood then
+ ping_fuzzy(replied + 1)
+ end
+ end
+ end
+ end
+
+ -- Sends request number `num` to the storage selected by rule/server options
+ ping_fuzzy = function(num)
+ local ret, err = rspamd_plugins.fuzzy_check.ping_storage(task, gen_ping_fuzzy_cb(num),
+ opts.rule, opts.timeout, opts.server)
+
+ if not ret then
+ -- NOTE(review): opts.server is nil unless -s was given, so the message may print "nil"
+ print(highlight_err('error from %s: %s', opts.server, err))
+ -- NOTE(review): no callback is registered on this path, so in non-flood mode
+ -- nothing schedules the next request or the summary - confirm intended
+ -- NOTE(review): results[num] stays unset here, which leaves a hole that the
+ -- ipairs loop in print_results would stop at - confirm
+ opts.number = opts.number - 1 -- To avoid issues with waiting for other replies
+ end
+ end
+
+ if opts.flood then
+ -- Flood mode: fire all requests at once without waiting for replies
+ for i = 1, opts.number do
+ ping_fuzzy(i)
+ end
+ else
+ -- Normal mode: start the chain, callbacks send the following requests
+ ping_fuzzy(1)
+ end
+end
+
+-- Command descriptor returned by this module: name, aliases, entry point and
+-- the description text taken from the argparse parser
+return {
+ name = 'fuzzy_ping',
+ aliases = { 'fuzzyping' },
+ handler = handler,
+ description = parser._description
+} \ No newline at end of file
diff --git a/lualib/rspamadm/fuzzy_stat.lua b/lualib/rspamadm/fuzzy_stat.lua
new file mode 100644
index 0000000..ef8a5de
--- /dev/null
+++ b/lualib/rspamadm/fuzzy_stat.lua
@@ -0,0 +1,366 @@
+local rspamd_util = require "rspamd_util"
+local lua_util = require "lua_util"
+local opts = {}
+
+local argparse = require "argparse"
+-- Command line interface of `rspamadm control fuzzystat`
+local parser = argparse()
+    :name "rspamadm control fuzzystat"
+    :description "Shows statistics of fuzzy storages"
+    :help_description_margin(32)
+parser:flag "--no-ips"
+      :description "No IPs stats"
+parser:flag "--no-keys"
+      :description "No keys stats"
+parser:flag "--short"
+      :description "Short output mode"
+parser:flag "-n --number"
+      :description "Disable numbers humanization"
+-- Only the listed sort orders are accepted
+parser:option "--sort"
+      :description "Sort order"
+      :convert {
+  checked = "checked",
+  matched = "matched",
+  errors = "errors",
+  name = "name"
+}
+
+--- Accumulates one statistics record into another, in place.
+-- Numeric fields are summed, the nested `ips` and `flags` maps are merged
+-- recursively entry by entry, and a `keypair` record contributes its
+-- `extensions.name` as `target.name`. Everything else is ignored.
+-- @param target table that receives the accumulated values
+-- @param src table with the values to add
+local function add_data(target, src)
+  -- Merges a nested map (`ips` or `flags`) into the same field of target
+  local function merge_nested(field, nested)
+    if not target[field] then
+      target[field] = {}
+    end
+    local bucket = target[field]
+    for id, stats in pairs(nested) do
+      if not bucket[id] then
+        bucket[id] = {}
+      end
+      add_data(bucket[id], stats)
+    end
+  end
+
+  for field, value in pairs(src) do
+    if type(value) == 'number' then
+      local current = target[field]
+      if current then
+        target[field] = current + value
+      else
+        target[field] = value
+      end
+    elseif field == 'ips' or field == 'flags' then
+      merge_nested(field, value)
+    elseif field == 'keypair' then
+      local ext = value.extensions
+      if type(ext) == 'table' and type(ext.name) == 'string' then
+        target.name = ext.name
+      end
+    end
+  end
+end
+
+--- Renders a counter for output.
+-- @param num number to render (may be nil or false)
+-- @return 'na' for a missing value, the plain number when humanization is
+--   disabled on the command line, the humanized form otherwise
+local function print_num(num)
+  if not num then
+    return 'na'
+  end
+
+  local raw_output = opts['n'] or opts['number']
+  if raw_output then
+    return tostring(num)
+  end
+
+  return rspamd_util.humanize_number(num)
+end
+
+--- Prints the counters of one statistics record, one line per counter.
+-- Only the counters present in `st` are printed.
+-- @param st record with optional fields `checked`, `checked_per_hour`,
+--   `matched`, `matched_per_hour`, `errors`, `added`, `deleted`
+-- @param tabs prefix (indentation) put in front of every printed line
+local function print_stat(st, tabs)
+  if st['checked'] then
+    if st.checked_per_hour then
+      print(string.format('%sChecked: %s (%s per hour in average)', tabs,
+          print_num(st['checked']), print_num(st['checked_per_hour'])))
+    else
+      print(string.format('%sChecked: %s', tabs, print_num(st['checked'])))
+    end
+  end
+  if st['matched'] then
+    -- The share of matches is meaningful only when there were checks and the
+    -- number of matches does not exceed them (i.e. the result is 0..100%)
+    if st.checked and st.checked > 0 and st.matched <= st.checked then
+      local percentage = st.matched / st.checked * 100.0
+      if st.matched_per_hour then
+        print(string.format('%sMatched: %s - %s percent (%s per hour in average)', tabs,
+            print_num(st['matched']), percentage, print_num(st['matched_per_hour'])))
+      else
+        print(string.format('%sMatched: %s - %s percent', tabs, print_num(st['matched']), percentage))
+      end
+    else
+      if st.matched_per_hour then
+        print(string.format('%sMatched: %s (%s per hour in average)', tabs,
+            print_num(st['matched']), print_num(st['matched_per_hour'])))
+      else
+        print(string.format('%sMatched: %s', tabs, print_num(st['matched'])))
+      end
+    end
+  end
+  if st['errors'] then
+    print(string.format('%sErrors: %s', tabs, print_num(st['errors'])))
+  end
+  if st['added'] then
+    print(string.format('%sAdded: %s', tabs, print_num(st['added'])))
+  end
+  if st['deleted'] then
+    print(string.format('%sDeleted: %s', tabs, print_num(st['deleted'])))
+  end
+end
+
+--- Converts a hash table into an array sorted in descending order.
+-- The order is chosen by `sort_opts.sort`: 'matched' or 'errors' sort by that
+-- counter, 'name' sorts by the key itself, anything else sorts by 'checked'.
+-- A missing counter is treated as 0.
+-- @param tbl hash table of statistics records
+-- @param sort_opts options table; only the `sort` field is read
+-- @param key_key name of the field that receives the original key
+-- @return array of `{ [key_key] = key, data = record }`
+local function sort_hash_table(tbl, sort_opts, key_key)
+  local entries = {}
+  for k, v in pairs(tbl) do
+    table.insert(entries, { [key_key] = k, data = v })
+  end
+
+  -- Yields the value an entry is ranked by
+  local function rank(entry)
+    local order = sort_opts['sort']
+    if order == 'name' then
+      return entry[key_key]
+    end
+
+    local counter = 'checked'
+    if order == 'matched' or order == 'errors' then
+      counter = order
+    end
+
+    return entry.data[counter] or 0
+  end
+
+  table.sort(entries, function(lhs, rhs)
+    return rank(lhs) > rank(rhs)
+  end)
+
+  return entries
+end
+
+--- Merges one worker's value of a general statistics field into the total.
+-- Values may be plain numbers or arrays of numbers. Values are summed, except
+-- for the `fuzzy_stored` field, where the newest value replaces the old one.
+-- @param dst accumulated value so far (nil, number or array)
+-- @param src value reported by the current worker (number or array)
+-- @param k name of the field being merged
+-- @return the new accumulated value
+local function add_result(dst, src, k)
+  if type(src) == 'table' then
+    if type(dst) == 'number' then
+      -- Convert dst to table
+      dst = { dst }
+    elseif type(dst) == 'nil' then
+      dst = {}
+    end
+
+    for i, v in ipairs(src) do
+      if dst[i] and k ~= 'fuzzy_stored' then
+        dst[i] = dst[i] + v
+      else
+        dst[i] = v
+      end
+    end
+  else
+    if type(dst) == 'table' then
+      if k ~= 'fuzzy_stored' then
+        -- The accumulated array may still be empty, treat a missing slot as 0
+        dst[1] = (dst[1] or 0) + src
+      else
+        dst[1] = src
+      end
+    else
+      if dst and k ~= 'fuzzy_stored' then
+        dst = dst + src
+      else
+        dst = src
+      end
+    end
+  end
+
+  return dst
+end
+
+--- Renders a general statistics value for output.
+-- An array is rendered as a comma separated list of `(<epoch>: <number>)`
+-- pairs where the array index selects the epoch label; any other value is
+-- rendered as a single number.
+-- @param r number or array of numbers
+-- @return printable string
+local function print_result(r)
+  -- Epoch labels by array position; unknown positions are shown as '???'
+  local epoch_names = { 'v0.6', 'v0.8', 'v0.9', 'v1.0+', 'v1.7+' }
+
+  if type(r) ~= 'table' then
+    return print_num(r)
+  end
+
+  local parts = {}
+  for idx, num in ipairs(r) do
+    local epoch = epoch_names[idx] or '???'
+    parts[idx] = string.format('(%s: %s)', epoch, print_num(num))
+  end
+
+  return table.concat(parts, ', ')
+end
+
+--- Aggregates per-worker fuzzy statistics and prints them.
+-- @param args command line arguments for this module's argparse parser
+-- @param res reply table; `res.workers` holds one entry per process, each with
+--   optional `id` and `data` fields
+return function(args, res)
+ -- Per-IP statistics summed over all storages and keys
+ local res_ips = {}
+ -- Aggregated statistics per storage id
+ local res_databases = {}
+ local wrk = res['workers']
+ -- Stored in the file-level `opts` so that print_num can see the flags
+ opts = parser:parse(args)
+
+ if wrk then
+ for _, pr in pairs(wrk) do
+ -- processes cycle
+ if pr['data'] then
+ local id = pr['id']
+
+ if id then
+ local res_db = res_databases[id]
+ if not res_db then
+ res_db = {
+ keys = {}
+ }
+ res_databases[id] = res_db
+ end
+
+ -- General stats
+ for k, v in pairs(pr['data']) do
+ if k ~= 'keys' and k ~= 'errors_ips' then
+ res_db[k] = add_result(res_db[k], v, k)
+ elseif k == 'errors_ips' then
+ -- Errors ips
+ if not res_db['errors_ips'] then
+ res_db['errors_ips'] = {}
+ end
+ -- Sum the error counters per IP address
+ for ip, nerrors in pairs(v) do
+ if not res_db['errors_ips'][ip] then
+ res_db['errors_ips'][ip] = nerrors
+ else
+ res_db['errors_ips'][ip] = nerrors + res_db['errors_ips'][ip]
+ end
+ end
+ end
+ end
+
+ if pr['data']['keys'] then
+ local res_keys = res_db['keys']
+ if not res_keys then
+ res_keys = {}
+ res_db['keys'] = res_keys
+ end
+ -- Go through keys in input
+ for k, elts in pairs(pr['data']['keys']) do
+ -- keys cycle
+ if not res_keys[k] then
+ res_keys[k] = {}
+ end
+
+ add_data(res_keys[k], elts)
+
+ -- Feed the same per-IP records into the global per-IP totals
+ if elts['ips'] then
+ for ip, v in pairs(elts['ips']) do
+ if not res_ips[ip] then
+ res_ips[ip] = {}
+ end
+ add_data(res_ips[ip], v)
+ end
+ end
+ end
+ end
+ end
+ end
+ end
+ end
+
+ -- General stats
+ for db, st in pairs(res_databases) do
+ print(string.format('Statistics for storage %s', db))
+
+ for k, v in pairs(st) do
+ if k ~= 'keys' and k ~= 'errors_ips' then
+ print(string.format('%s: %s', k, print_result(v)))
+ end
+ end
+ print('')
+
+ local res_keys = st['keys']
+ if res_keys and not opts['no_keys'] and not opts['short'] then
+ print('Keys statistics:')
+ -- Convert into an array to allow sorting
+ local sorted_keys = sort_hash_table(res_keys, opts, 'key')
+
+ for _, key in ipairs(sorted_keys) do
+ local key_stat = key.data
+ if key_stat.name then
+ print(string.format('Key id: %s, name: %s', key.key, key_stat.name))
+ else
+ print(string.format('Key id: %s', key.key))
+ end
+
+ print_stat(key_stat, '\t')
+
+ if key_stat['ips'] and not opts['no_ips'] then
+ print('')
+ print('\tIPs stat:')
+ local sorted_ips = sort_hash_table(key_stat['ips'], opts, 'ip')
+
+ for _, v in ipairs(sorted_ips) do
+ print(string.format('\t%s', v['ip']))
+ print_stat(v['data'], '\t\t')
+ print('')
+ end
+ end
+
+ if key_stat.flags then
+ print('')
+ print('\tFlags stat:')
+ for flag, v in pairs(key_stat.flags) do
+ print(string.format('\t[%s]:', flag))
+ -- Remove irrelevant fields
+ -- (this clears the field in the aggregated record itself)
+ v.checked = nil
+ print_stat(v, '\t\t')
+ print('')
+ end
+ end
+
+ print('')
+ end
+ end
+ if st['errors_ips'] and not opts['no_ips'] and not opts['short'] then
+ print('')
+ print('Errors IPs statistics:')
+ local ip_stat = st['errors_ips']
+ local ips = lua_util.keys(ip_stat)
+ -- Reverse sort by number of errors
+ table.sort(ips, function(a, b)
+ return ip_stat[a] > ip_stat[b]
+ end)
+ for _, ip in ipairs(ips) do
+ print(string.format('%s: %s', ip, print_result(ip_stat[ip])))
+ end
+ print('')
+ end
+ end
+
+ -- Global per-IP totals, printed once after all storages
+ if not opts['no_ips'] and not opts['short'] then
+ print('')
+ print('IPs statistics:')
+
+ local sorted_ips = sort_hash_table(res_ips, opts, 'ip')
+ for _, v in ipairs(sorted_ips) do
+ print(string.format('%s', v['ip']))
+ print_stat(v['data'], '\t')
+ print('')
+ end
+ end
+end
+
diff --git a/lualib/rspamadm/grep.lua b/lualib/rspamadm/grep.lua
new file mode 100644
index 0000000..6ed0569
--- /dev/null
+++ b/lualib/rspamadm/grep.lua
@@ -0,0 +1,174 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+
+
+-- Define command line options
+-- Command line interface of `rspamadm grep`
+local parser = argparse()
+    :name("rspamadm grep")
+    :description("Search for patterns in rspamd logs")
+    :help_description_margin(30)
+
+-- A plain string and a regex cannot be given together
+local string_opt = parser:option("-s --string")
+    :description('Plain string to search (case-insensitive)')
+    :argname("<str>")
+local pattern_opt = parser:option("-p --pattern")
+    :description('Pattern to search for (regex)')
+    :argname("<re>")
+parser:mutex(string_opt, pattern_opt)
+
+parser:flag("-l --lua")
+    :description('Use Lua patterns in string search')
+
+parser:argument("input")
+    :args("*")
+    :description('Process specified inputs')
+    :default("stdin")
+parser:flag("-S --sensitive")
+    :description('Enable case-sensitivity in string search')
+parser:flag("-o --orphans")
+    :description('Print orphaned logs')
+parser:flag("-P --partial")
+    :description('Print partial logs')
+
+--- Quotes a string so that a POSIX shell treats it as one literal word.
+local function shell_quote(str)
+  local escaped = string.gsub(str, "'", "'\\''")
+  return "'" .. escaped .. "'"
+end
+
+--- Opens one input for line-wise reading.
+-- Compressed logs are piped through the matching decompressor; the file name
+-- is shell-quoted because it goes through `io.popen`.
+-- @param name input name: a path, or 'stdin' for the default input
+-- @return handle or nil on failure
+-- @return error message (if any)
+-- @return true when the caller owns the handle and has to close it
+local function open_input(name)
+  local decompressors = {
+    { suffix = '%.xz$', cmd = 'xzcat' },
+    { suffix = '%.bz2$', cmd = 'bzcat' },
+    { suffix = '%.gz$', cmd = 'zcat' },
+    { suffix = '%.zst$', cmd = 'zstdcat' },
+  }
+
+  for _, d in ipairs(decompressors) do
+    if string.match(name, d.suffix) then
+      local h, err = io.popen(d.cmd .. ' ' .. shell_quote(name), 'r')
+      return h, err, true
+    end
+  end
+
+  if name == 'stdin' then
+    -- The default input is shared, never close it
+    return io.input(), nil, false
+  end
+
+  local h, err = io.open(name, 'r')
+  return h, err, true
+end
+
+--- Entry point of `rspamadm grep`.
+-- Log lines are grouped by the hexadecimal tag found in `<...>`. Once any line
+-- of a group matches the search, the whole group is printed when the line that
+-- ends the task (`rspamd_protocol_http_reply`) is seen.
+-- @param args command line arguments handed to the argparse parser
+local function handler(args)
+
+  local rspamd_regexp = require 'rspamd_regexp'
+  local res = parser:parse(args)
+
+  if not res['string'] and not res['pattern'] then
+    parser:error('string or pattern options must be specified')
+  end
+
+  if res['string'] and res['pattern'] then
+    parser:error('string and pattern options are mutually exclusive')
+  end
+
+  local buffer = {} -- tag -> lines seen so far for groups that have not matched yet
+  local matches = {} -- tag -> lines of groups that have matched
+
+  local pattern = res['pattern']
+  local re
+  if pattern then
+    re = rspamd_regexp.create(pattern)
+    if not re then
+      io.stderr:write("Couldn't compile regex: " .. pattern .. '\n')
+      os.exit(1)
+    end
+  end
+
+  -- Plain (non-pattern) string search unless Lua patterns were requested
+  local plainm = true
+  if res['lua'] then
+    plainm = false
+  end
+  local orphans = res['orphans']
+  local search_str = res['string']
+  local sensitive = res['sensitive']
+  local partial = res['partial']
+  if search_str and not sensitive then
+    search_str = string.lower(search_str)
+  end
+  local inputs = res['input'] or { 'stdin' }
+
+  for _, n in ipairs(inputs) do
+    local h, err, owned = open_input(n)
+    if not h then
+      if err then
+        io.stderr:write("Couldn't open file (" .. n .. '): ' .. err .. '\n')
+      else
+        io.stderr:write("Couldn't open file (" .. n .. '): no error\n')
+      end
+    else
+      for line in h:lines() do
+        local hash = string.match(line, '<(%x+)>')
+        local already_matching = false
+        if hash then
+          if matches[hash] then
+            table.insert(matches[hash], line)
+            already_matching = true
+          else
+            if buffer[hash] then
+              table.insert(buffer[hash], line)
+            else
+              buffer[hash] = { line }
+            end
+          end
+        end
+        local ismatch = false
+        if re then
+          ismatch = re:match(line)
+        elseif sensitive and search_str then
+          ismatch = string.find(line, search_str, 1, plainm)
+        elseif search_str then
+          local lwr = string.lower(line)
+          ismatch = string.find(lwr, search_str, 1, plainm)
+        end
+        if ismatch then
+          if not hash then
+            if orphans then
+              print('*** orphaned ***')
+              print(line)
+              print()
+            end
+          elseif not already_matching then
+            -- Promote everything buffered for this tag (current line included)
+            matches[hash] = buffer[hash]
+          end
+        end
+        local is_end = string.match(line, '<%x+>; task; rspamd_protocol_http_reply:')
+        if is_end then
+          buffer[hash] = nil
+          if matches[hash] then
+            for _, v in ipairs(matches[hash]) do
+              print(v)
+            end
+            print()
+            matches[hash] = nil
+          end
+        end
+      end
+
+      -- Release files and decompressor pipes as soon as the input is consumed
+      if owned then
+        h:close()
+      end
+
+      if partial then
+        -- Groups that matched but never reached their final line
+        for k, v in pairs(matches) do
+          print('*** partial ***')
+          for _, vv in ipairs(v) do
+            print(vv)
+          end
+          print()
+          matches[k] = nil
+        end
+      else
+        matches = {}
+      end
+    end
+  end
+end
+
+-- Command descriptor returned by this module: entry point, description text
+-- taken from the argparse parser, and the command name
+return {
+ handler = handler,
+ description = parser._description,
+ name = 'grep'
+} \ No newline at end of file
diff --git a/lualib/rspamadm/keypair.lua b/lualib/rspamadm/keypair.lua
new file mode 100644
index 0000000..f0716a2
--- /dev/null
+++ b/lualib/rspamadm/keypair.lua
@@ -0,0 +1,508 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+local rspamd_keypair = require "rspamd_cryptobox_keypair"
+local rspamd_pubkey = require "rspamd_cryptobox_pubkey"
+local rspamd_signature = require "rspamd_cryptobox_signature"
+local rspamd_crypto = require "rspamd_cryptobox"
+local rspamd_util = require "rspamd_util"
+local ucl = require "ucl"
+local logger = require "rspamd_logger"
+
+-- Define command line options
+-- Top-level parser; the chosen subcommand is stored in `command` and may be omitted
+local parser = argparse()
+    :name("rspamadm keypair")
+    :description("Manages keypairs for Rspamd")
+    :help_description_margin(30)
+    :command_target("command")
+    :require_command(false)
+
+-- Generate subcommand
+local generate = parser:command("generate gen g")
+    :description("Creates a new keypair")
+generate:flag("-s --sign")
+    :description("Generates a sign keypair instead of the encryption one")
+generate:flag("-n --nist")
+    :description("Uses nist encryption algorithm")
+generate:option("-o --output")
+    :description("Write keypair to file")
+    :argname("<file>")
+
+-- The two output formats exclude each other
+local gen_json = generate:flag("-j --json")
+    :description("Output JSON instead of UCL")
+local gen_ucl = generate:flag("-u --ucl")
+    :description("Output UCL")
+    :default(true)
+generate:mutex(gen_json, gen_ucl)
+
+generate:option("--name")
+    :description("Adds name extension")
+    :argname("<name>")
+
+-- Sign subcommand
+
+-- `sign`: creates a signature file next to every listed file
+local sign = parser:command("sign sig s")
+    :description("Signs a file using keypair")
+sign:option("-k --keypair")
+    :description("Keypair to use")
+    :argname("<file>")
+sign:option("-s --suffix")
+    :description("Suffix for signature")
+    :argname("<suffix>")
+    :default("sig")
+sign:argument("file")
+    :description("File to sign")
+    :argname("<file>")
+    :args("*")
+
+-- Verify subcommand
+
+-- `verify`: checks signature files against the listed files
+local verify = parser:command("verify ver v")
+    :description("Verifies a file using keypair or a public key")
+
+-- Exactly one source of the public key may be given
+local verify_pubkey = verify:option("-p --pubkey")
+    :description("Load pubkey from the specified file")
+    :argname("<file>")
+local verify_pubstring = verify:option("-P --pubstring")
+    :description("Load pubkey from the base32 encoded string")
+    :argname("<base32>")
+local verify_keypair = verify:option("-k --keypair")
+    :description("Get pubkey from the keypair file")
+    :argname("<file>")
+verify:mutex(verify_pubkey, verify_pubstring, verify_keypair)
+
+verify:argument("file")
+    :description("File to verify")
+    :argname("<file>")
+    :args("*")
+verify:flag("-n --nist")
+    :description("Uses nistp curves (P256)")
+verify:option("-s --suffix")
+    :description("Suffix for signature")
+    :argname("<suffix>")
+    :default("sig")
+
+-- Encrypt subcommand
+
+-- `encrypt`: writes an encrypted copy of every listed file
+local encrypt = parser:command("encrypt crypt enc e")
+    :description("Encrypts a file using keypair (or a pubkey)")
+
+-- Exactly one source of the public key may be given
+local encrypt_pubkey = encrypt:option("-p --pubkey")
+    :description("Load pubkey from the specified file")
+    :argname("<file>")
+local encrypt_pubstring = encrypt:option("-P --pubstring")
+    :description("Load pubkey from the base32 encoded string")
+    :argname("<base32>")
+local encrypt_keypair = encrypt:option("-k --keypair")
+    :description("Get pubkey from the keypair file")
+    :argname("<file>")
+encrypt:mutex(encrypt_pubkey, encrypt_pubstring, encrypt_keypair)
+
+encrypt:option("-s --suffix")
+    :description("Suffix for encrypted file")
+    :argname("<suffix>")
+    :default("enc")
+encrypt:argument("file")
+    :description("File to encrypt")
+    :argname("<file>")
+    :args("*")
+encrypt:flag("-r --rm")
+    :description("Remove unencrypted file")
+encrypt:flag("-f --force")
+    :description("Remove destination file if it exists")
+
+-- Decrypt subcommand
+
+-- `decrypt`: restores the plain content of every listed file
+local decrypt = parser:command "decrypt dec d"
+                      :description "Decrypts a file using keypair"
+-- The handler passes the whole loaded keypair to the decryption routine
+decrypt:option "-k --keypair"
+       :description "Keypair to use for decryption"
+       :argname "<file>"
+decrypt:flag "-S --keep-suffix"
+       :description "Preserve suffix for decrypted file (overwrite encrypted)"
+decrypt:argument "file"
+       :description "File to decrypt"
+       :argname "<file>"
+       :args "*"
+decrypt:flag "-f --force"
+       :description "Remove destination file if it exists (implied with -S)"
+decrypt:flag "-r --rm"
+       :description "Remove encrypted file"
+
+-- Default command is generate, so duplicate options to be compatible
+
+-- Options of `generate` repeated on the top-level parser, because `generate`
+-- is the command used when no subcommand is given
+parser:flag("-s --sign")
+    :description("Generates a sign keypair instead of the encryption one")
+parser:flag("-n --nist")
+    :description("Uses nistp curves (P256)")
+
+local top_json = parser:flag("-j --json")
+    :description("Output JSON instead of UCL")
+local top_ucl = parser:flag("-u --ucl")
+    :description("Output UCL")
+    :default(true)
+parser:mutex(top_json, top_ucl)
+
+parser:option("-o --output")
+    :description("Write keypair to file")
+    :argname("<file>")
+
+--- Logs an error through the Rspamd logger and terminates with exit code 1.
+-- @param ... arguments forwarded unchanged to `logger.errx`
+local function fatal(...)
+ logger.errx(...)
+ os.exit(1)
+end
+
+--- Asks a yes/no question on the terminal.
+-- An empty reply selects the default. The process exits with code 0 when no
+-- reply can be read at all.
+-- @param greet question text; the `[Y/n]: ` or `[y/N]: ` hint is appended
+-- @param default answer assumed for an empty reply (truthy means yes)
+-- @return true for 'y' or 'yes' (case-insensitive), false otherwise
+local function ask_yes_no(greet, default)
+  local hint, fallback
+  if default then
+    hint, fallback = "[Y/n]: ", "yes"
+  else
+    hint, fallback = "[y/N]: ", "no"
+  end
+
+  local reply = rspamd_util.readline(greet .. hint)
+  if not reply then
+    os.exit(0)
+  end
+
+  if #reply == 0 then
+    reply = fallback
+  end
+
+  local answer = reply:lower()
+  return answer == 'y' or answer == 'yes'
+end
+
+--- Creates a new keypair and prints it or stores it in a file.
+-- @param opts parsed command line options; reads `sign`, `nist`, `name`,
+--   `json` and `output`
+local function generate_handler(opts)
+  local mode = opts.sign and 'sign' or 'encryption'
+  local alg = opts.nist and 'nist' or 'curve25519'
+
+  -- TODO: probably, do it in a more safe way
+  local kp = rspamd_keypair.create(mode, alg):totable()
+
+  if opts.name then
+    kp.keypair.extensions = {
+      name = opts.name
+    }
+  end
+
+  local format = opts.json and 'json' or 'ucl'
+
+  if not opts.output then
+    io.write(ucl.to_format(kp, format))
+    return
+  end
+
+  local out = io.open(opts.output, 'w')
+  if not out then
+    fatal('cannot open output to write: ' .. opts.output)
+  end
+  out:write(ucl.to_format(kp, format))
+  out:close()
+end
+
+--- Signs every file listed in `opts.file` with the keypair from `opts.keypair`.
+-- The binary signature is stored as `<file>.<suffix>` and its hex form is
+-- printed on standard output.
+-- @param opts parsed command line options
+local function sign_handler(opts)
+  if not opts.file then
+    parser:error('no files to sign')
+  elseif type(opts.file) == 'string' then
+    -- A single file name is normalised to a one-element list
+    opts.file = { opts.file }
+  end
+
+  if not opts.keypair then
+    parser:error("no keypair specified")
+  end
+
+  local ucl_parser = ucl.parser()
+  local parsed, parse_err = ucl_parser:parse_file(opts.keypair)
+  if not parsed then
+    fatal(string.format('cannot load %s: %s', opts.keypair, parse_err))
+  end
+
+  local kp = rspamd_keypair.load(ucl_parser:get_object())
+  if not kp then
+    fatal("cannot load keypair: " .. opts.keypair)
+  end
+
+  for _, fname in ipairs(opts.file) do
+    local sig = rspamd_crypto.sign_file(kp, fname)
+    if not sig then
+      fatal(string.format("cannot sign %s\n", fname))
+    end
+
+    local sig_path = string.format('%s.%s', fname, opts.suffix or 'sig')
+    local sig_file = io.open(sig_path, 'w')
+    if not sig_file then
+      fatal('cannot open output to write: ' .. sig_path)
+    end
+    sig_file:write(sig:bin())
+    sig_file:close()
+
+    io.write(string.format('signed %s -> %s (%s)\n', fname, sig_path, sig:hex()))
+  end
+end
+
+--- Verifies the detached signatures of the files listed in `opts.file`.
+-- The public key is taken from a keypair file (`-k`), a pubkey file (`-p`) or
+-- a base32 string (`-P`). The signature of `<file>` is read from
+-- `<file>.<suffix>`. Exits with code 1 when at least one file fails.
+-- @param opts parsed command line options
+local function verify_handler(opts)
+  if opts.file then
+    if type(opts.file) == 'string' then
+      opts.file = { opts.file }
+    end
+  else
+    parser:error('no files to verify')
+  end
+
+  local pk
+  local alg = 'curve25519'
+  -- The option is declared as `-P --pubstring`, so argparse stores it under
+  -- `pubstring`; `pubstr` is honoured as well for hand-built option tables
+  local pubstring = opts.pubstring or opts.pubstr
+
+  if opts.keypair then
+    local ucl_parser = ucl.parser()
+    local res, err = ucl_parser:parse_file(opts.keypair)
+
+    if not res then
+      fatal(string.format('cannot load %s: %s', opts.keypair, err))
+    end
+
+    local kp = rspamd_keypair.load(ucl_parser:get_object())
+
+    if not kp then
+      fatal("cannot load keypair: " .. opts.keypair)
+    end
+
+    -- A keypair knows its own algorithm, --nist is not consulted here
+    pk = kp:pk()
+    alg = kp:alg()
+  elseif opts.pubkey then
+    if opts.nist then
+      alg = 'nist'
+    end
+    pk = rspamd_pubkey.load(opts.pubkey, 'sign', alg)
+  elseif pubstring then
+    if opts.nist then
+      alg = 'nist'
+    end
+    pk = rspamd_pubkey.create(pubstring, 'sign', alg)
+  end
+
+  if not pk then
+    fatal("cannot create pubkey")
+  end
+
+  local valid = true
+
+  for _, fname in ipairs(opts.file) do
+
+    local sig_fname = string.format('%s.%s', fname, opts.suffix or 'sig')
+    local sig = rspamd_signature.load(sig_fname, alg)
+
+    if not sig then
+      fatal(string.format("cannot load signature for %s -> %s",
+          fname, sig_fname))
+    end
+
+    if rspamd_crypto.verify_file(pk, sig, fname, alg) then
+      io.write(string.format('verified %s -> %s (%s)\n', fname, sig_fname, sig:hex()))
+    else
+      -- Keep going so that every file is reported, fail at the end
+      valid = false
+      io.write(string.format('FAILED to verify %s -> %s (%s)\n', fname,
+          sig_fname, sig:hex()))
+    end
+  end
+
+  if not valid then
+    os.exit(1)
+  end
+end
+
+--- Encrypts every file listed in `opts.file`.
+-- The public key is taken from a keypair file (`-k`), a pubkey file (`-p`) or
+-- a base32 string (`-P`). The result is stored as `<file>.<suffix>`, or over
+-- the source file itself when the suffix is empty.
+-- @param opts parsed command line options
+local function encrypt_handler(opts)
+  if opts.file then
+    if type(opts.file) == 'string' then
+      opts.file = { opts.file }
+    end
+  else
+    parser:error('no files to encrypt')
+  end
+
+  local pk
+  local alg = 'curve25519'
+  -- The option is declared as `-P --pubstring`, so argparse stores it under
+  -- `pubstring`; `pubstr` is honoured as well for hand-built option tables
+  local pubstring = opts.pubstring or opts.pubstr
+
+  if opts.keypair then
+    local ucl_parser = ucl.parser()
+    local res, err = ucl_parser:parse_file(opts.keypair)
+
+    if not res then
+      fatal(string.format('cannot load %s: %s', opts.keypair, err))
+    end
+
+    local kp = rspamd_keypair.load(ucl_parser:get_object())
+
+    if not kp then
+      fatal("cannot load keypair: " .. opts.keypair)
+    end
+
+    pk = kp:pk()
+    alg = kp:alg()
+  elseif opts.pubkey then
+    if opts.nist then
+      alg = 'nist'
+    end
+    -- NOTE(review): the key is loaded with the 'sign' type although it is used
+    -- for encryption - confirm against the rspamd_cryptobox_pubkey API
+    pk = rspamd_pubkey.load(opts.pubkey, 'sign', alg)
+  elseif pubstring then
+    if opts.nist then
+      alg = 'nist'
+    end
+    pk = rspamd_pubkey.create(pubstring, 'sign', alg)
+  end
+
+  if not pk then
+    -- opts.keypair may be nil here, so it must not be concatenated
+    fatal("cannot create pubkey")
+  end
+
+  for _, fname in ipairs(opts.file) do
+    local enc = rspamd_crypto.encrypt_file(pk, fname, alg)
+
+    if not enc then
+      fatal(string.format("cannot encrypt %s\n", fname))
+    end
+
+    local out
+    if opts.suffix and #opts.suffix > 0 then
+      out = string.format('%s.%s', fname, opts.suffix)
+    else
+      out = string.format('%s', fname)
+    end
+
+    if rspamd_util.file_exists(out) then
+      if opts.force or ask_yes_no(string.format('File %s already exists, overwrite?',
+          out), true) then
+        os.remove(out)
+      else
+        os.exit(1)
+      end
+    end
+
+    enc:save_in_file(out)
+
+    if opts.rm then
+      os.remove(fname)
+      io.write(string.format('encrypted %s (deleted) -> %s\n', fname, out))
+    else
+      io.write(string.format('encrypted %s -> %s\n', fname, out))
+    end
+  end
+end
+
+--- Decrypts every file listed in `opts.file` with the keypair from `opts.keypair`.
+-- The plain content is written either over the source file (`--keep-suffix`)
+-- or to the source name with its last suffix stripped.
+-- @param opts parsed command line options
+local function decrypt_handler(opts)
+  if opts.file then
+    if type(opts.file) == 'string' then
+      opts.file = { opts.file }
+    end
+  else
+    parser:error('no files to decrypt')
+  end
+  if not opts.keypair then
+    parser:error("no keypair specified")
+  end
+
+  -- argparse stores `--keep-suffix` under `keep_suffix` (dashes become
+  -- underscores); the dashed key is honoured for hand-built option tables
+  local keep_suffix = opts.keep_suffix or opts['keep-suffix']
+
+  local ucl_parser = ucl.parser()
+  local res, err = ucl_parser:parse_file(opts.keypair)
+
+  if not res then
+    fatal(string.format('cannot load %s: %s', opts.keypair, err))
+  end
+
+  local kp = rspamd_keypair.load(ucl_parser:get_object())
+
+  if not kp then
+    fatal("cannot load keypair: " .. opts.keypair)
+  end
+
+  for _, fname in ipairs(opts.file) do
+    -- decrypt_file is documented to return a status followed by either the
+    -- decrypted data or an error message
+    local ok, decrypted = rspamd_crypto.decrypt_file(kp, fname)
+
+    if not ok then
+      fatal(string.format("cannot decrypt %s: %s\n", fname, tostring(decrypted)))
+    end
+    if decrypted == nil and type(ok) ~= 'boolean' then
+      -- Tolerate an implementation that returns the data as the only value
+      decrypted = ok
+    end
+    if decrypted == nil then
+      fatal(string.format("cannot decrypt %s: no data returned\n", fname))
+    end
+
+    local out
+    if not keep_suffix then
+      -- Strip the last suffix
+      out = fname:match("^(.+)%..+$")
+      if not out then
+        fatal(string.format("cannot derive output name for %s: it has no suffix (use -S)\n", fname))
+      end
+    else
+      out = fname
+    end
+
+    local removed = false
+
+    if rspamd_util.file_exists(out) then
+      if (opts.force or keep_suffix)
+          or ask_yes_no(string.format('File %s already exists, overwrite?', out), true) then
+        os.remove(out)
+        removed = true
+      else
+        os.exit(1)
+      end
+    end
+
+    -- Store the plain content before anything else is deleted
+    -- NOTE(review): assumes the data object offers save_in_file like the
+    -- result of encrypt_file does - confirm
+    decrypted:save_in_file(out)
+
+    -- When decrypting in place the source has just been replaced by the
+    -- output, removing it again would destroy the result
+    if opts.rm and out ~= fname then
+      os.remove(fname)
+      removed = true
+    end
+
+    if removed then
+      io.write(string.format('decrypted %s (removed) -> %s\n', fname, out))
+    else
+      io.write(string.format('decrypted %s -> %s\n', fname, out))
+    end
+  end
+end
+
+--- Entry point of `rspamadm keypair`: parses the command line and runs the
+-- selected subcommand (`generate` when none is given).
+-- @param args command line arguments handed to the argparse parser
+local function handler(args)
+  local opts = parser:parse(args)
+
+  local command = opts.command or "generate"
+
+  if command == 'generate' then
+    generate_handler(opts)
+  elseif command == 'sign' then
+    sign_handler(opts)
+  elseif command == 'verify' then
+    verify_handler(opts)
+  elseif command == 'encrypt' then
+    encrypt_handler(opts)
+  elseif command == 'decrypt' then
+    decrypt_handler(opts)
+  else
+    -- parser:error takes a ready message, so format it before passing it on
+    parser:error(string.format('command %s is not implemented', tostring(command)))
+  end
+end
+
+-- Command descriptor returned by this module: name, aliases, entry point and
+-- the description text taken from the argparse parser
+return {
+ name = 'keypair',
+ aliases = { 'kp', 'key' },
+ handler = handler,
+ description = parser._description
+} \ No newline at end of file
diff --git a/lualib/rspamadm/mime.lua b/lualib/rspamadm/mime.lua
new file mode 100644
index 0000000..6a589d6
--- /dev/null
+++ b/lualib/rspamadm/mime.lua
@@ -0,0 +1,1012 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+local ansicolors = require "ansicolors"
+local rspamd_util = require "rspamd_util"
+local rspamd_task = require "rspamd_task"
+local rspamd_text = require "rspamd_text"
+local rspamd_logger = require "rspamd_logger"
+local lua_meta = require "lua_meta"
+local rspamd_url = require "rspamd_url"
+local lua_util = require "lua_util"
+local lua_mime = require "lua_mime"
+local ucl = require "ucl"
+
+-- Top-level argument parser; the chosen subcommand is stored in `command`
+local parser = argparse()
+  :name("rspamadm mime")
+  :description("Mime manipulations provided by Rspamd")
+  :help_description_margin(30)
+  :command_target("command")
+  :require_command(true)
+
+parser:option("-c --config")
+  :description("Path to config file")
+  :argname("<cfg>")
+  :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+
+-- Output format flags exclude each other
+do
+  local json_flag = parser:flag("-j --json")
+    :description("JSON output")
+  local ucl_flag = parser:flag("-U --ucl")
+    :description("UCL output")
+  local msgpack_flag = parser:flag("-M --messagepack")
+    :description("MessagePack output")
+
+  parser:mutex(json_flag, ucl_flag, msgpack_flag)
+end
+
+parser:flag("-C --compact")
+  :description("Use compact format")
+parser:flag("--no-file")
+  :description("Do not print filename")
+
+-- Extract subcommand: selects which data is pulled out of the messages
+-- (plain text, HTML, words, part info, HTML structure, invisible content)
+local extract = parser:command "extract ex e"
+  :description "Extracts data from MIME messages"
+extract:argument "file"
+  :description "File to process"
+  :argname "<file>"
+  :args "+"
+
+extract:flag "-t --text"
+  :description "Extracts plain text data from a message"
+extract:flag "-H --html"
+  :description "Extracts HTML data from a message"
+-- User-facing names are mapped to the content type names passed to
+-- `get_content` by the handler
+extract:option "-o --output"
+  :description "Output format ('raw', 'content', 'oneline', 'decoded', 'decoded_utf')"
+  :argname("<type>")
+  :convert {
+    raw = "raw",
+    content = "content",
+    oneline = "content_oneline",
+    decoded = "raw_parsed",
+    decoded_utf = "raw_utf"
+  }
+  :default "content"
+extract:flag "-w --words"
+  :description "Extracts words"
+extract:flag "-p --part"
+  :description "Show part info"
+extract:flag "-s --structure"
+  :description "Show structure info (e.g. HTML tags)"
+extract:flag "-i --invisible"
+  :description "Show invisible content for HTML parts"
+-- The converter table doubles as a whitelist of accepted word formats
+extract:option "-F --words-format"
+  :description "Words format ('stem', 'norm', 'raw', 'full')"
+  :argname("<type>")
+  :convert {
+    stem = "stem",
+    norm = "norm",
+    raw = "raw",
+    full = "full",
+  }
+  :default "stem"
+
+-- Stat subcommand: metatokens, bayes tokens or fuzzy hashes of messages
+local stat = parser:command("stat st s")
+  :description("Extracts statistical data from MIME messages")
+stat:argument("file")
+  :description("File to process")
+  :argname("<file>")
+  :args("+")
+
+-- Only one kind of statistical data may be requested per run
+do
+  local meta_flag = stat:flag("-m --meta")
+    :description("Lua metatokens")
+  local bayes_flag = stat:flag("-b --bayes")
+    :description("Bayes tokens")
+  local fuzzy_flag = stat:flag("-F --fuzzy")
+    :description("Fuzzy hashes")
+
+  stat:mutex(meta_flag, bayes_flag, fuzzy_flag)
+end
+
+stat:flag("-s --shingles")
+  :description("Show shingles for fuzzy hashes")
+
+-- Urls subcommand: lists URLs found in messages
+local urls = parser:command("urls url u")
+  :description("Extracts URLs from MIME messages")
+urls:argument("file")
+  :description("File to process")
+  :argname("<file>")
+  :args("+")
+
+-- The representation of each URL is chosen by at most one of these flags
+do
+  local tld_flag = urls:flag("-t --tld")
+    :description("Get TLDs only")
+  local host_flag = urls:flag("-H --host")
+    :description("Get hosts only")
+  local full_flag = urls:flag("-f --full")
+    :description("Show piecewise urls as processed by Rspamd")
+
+  urls:mutex(tld_flag, host_flag, full_flag)
+end
+
+urls:flag("-u --unique")
+  :description("Print only unique urls")
+urls:flag("-s --sort")
+  :description("Sort output")
+urls:flag("--count")
+  :description("Print count of each printed element")
+urls:flag("-r --reverse")
+  :description("Reverse sort order")
+
+-- Modify subcommand: header manipulation and footers
+local modify = parser:command("modify mod m")
+  :description("Modifies MIME message")
+modify:argument("file")
+  :description("File to process")
+  :argname("<file>")
+  :args("+")
+
+-- The header options may be repeated any number of times
+modify:option("-a --add-header")
+  :description("Adds specific header")
+  :argname("<header=value>")
+  :count("*")
+modify:option("-r --remove-header")
+  :description("Removes specific header (all occurrences)")
+  :argname("<header>")
+  :count("*")
+modify:option("-R --rewrite-header")
+  :description("Rewrites specific header, uses Lua string.format pattern")
+  :argname("<header=pattern>")
+  :count("*")
+modify:option("-t --text-footer")
+  :description("Adds footer to text/plain parts from a specific file")
+  :argname("<file>")
+modify:option("-H --html-footer")
+  :description("Adds footer to text/html parts from a specific file")
+  :argname("<file>")
+
+-- Sign subcommand: domain, selector and key must each be given exactly once
+local sign = parser:command "sign"
+  :description "Performs DKIM signing"
+sign:argument "file"
+  :description "File to process"
+  :argname "<file>"
+  :args "+"
+
+sign:option "-d --domain"
+  :description "Use specific domain"
+  :argname "<domain>"
+  :count "1"
+sign:option "-s --selector"
+  :description "Use specific selector"
+  :argname "<selector>"
+  :count "1"
+-- The value is either a path to a key file or the key itself: the handler
+-- checks whether such a file exists and otherwise treats the value as base64
+sign:option "-k --key"
+  :description "Use specific key or file"
+  :argname "<key>"
+  :count "1"
+sign:option "-t --type"
+  :description "ARC or DKIM signing"
+  :argname("<arc|dkim>")
+  :convert {
+    ['arc'] = 'arc',
+    ['dkim'] = 'dkim',
+  }
+  :default 'dkim'
+sign:option "-o --output"
+  :description "Output format"
+  :argname("<message|signature>")
+  :convert {
+    ['message'] = 'message',
+    ['signature'] = 'signature',
+  }
+  :default 'message'
+
+-- Dump subcommand: writes a message as is or converted to a structured format
+local dump = parser:command("dump")
+  :description("Dumps a raw message in different formats")
+dump:argument("file")
+  :description("File to process")
+  :argname("<file>")
+  :args("+")
+
+-- Duplicate format for convenience
+-- (the flags themselves are created on the top-level parser, only the mutex
+-- belongs to the dump command)
+do
+  local json_flag = parser:flag("-j --json")
+    :description("JSON output")
+  local ucl_flag = parser:flag("-U --ucl")
+    :description("UCL output")
+  local msgpack_flag = parser:flag("-M --messagepack")
+    :description("MessagePack output")
+
+  dump:mutex(json_flag, ucl_flag, msgpack_flag)
+end
+
+dump:flag("-s --split")
+  :description("Split the output file contents such that no content is embedded")
+
+dump:option("-o --outdir")
+  :description("Output directory")
+  :argname("<directory>")
+
+--- Loads the configuration file named by `opts.config` into `rspamd_config`
+-- and then runs `parse_rcl` with the `{ 'logging', 'worker' }` list.
+-- Logs the error and terminates the process with status 1 if a step fails.
+-- @param opts parsed command line options
+local function load_config(opts)
+  local cfg_path = opts['config']
+  local ok, err = rspamd_config:load_ucl(cfg_path)
+
+  if ok then
+    ok, err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+    if ok then
+      return
+    end
+    rspamd_logger.errx('cannot process %s: %s', cfg_path, err)
+  else
+    rspamd_logger.errx('cannot parse %s: %s', cfg_path, err)
+  end
+
+  os.exit(1)
+end
+
+--- Reads a message from a file and processes it into a task.
+-- Reports a parser error if the file cannot be loaded or the message cannot
+-- be processed.
+-- @param opts parsed command line options (not consulted here)
+-- @param fname file to read; '-' is used when it is not given
+-- @return the processed task
+local function load_task(opts, fname)
+  fname = fname or '-'
+
+  -- On failure the second value carries the error description
+  local loaded, task = rspamd_task.load_from_file(fname, rspamd_config)
+  if not loaded then
+    parser:error(string.format('cannot read message from %s: %s', fname,
+        task))
+  end
+
+  local processed = task:process_message()
+  if not processed then
+    parser:error(string.format('cannot read message from %s: %s', fname,
+        'failed to parse'))
+  end
+
+  return task
+end
+
+--- Formats a message and wraps it into the white/reset colour codes.
+-- @param fmt string.format pattern
+-- @param ... values for the pattern
+-- @return the coloured string
+local function highlight(fmt, ...)
+  local colour_on, colour_off = ansicolors.white, ansicolors.reset
+  return colour_on .. string.format(fmt, ...) .. colour_off
+end
+
+--- Prints the highlighted `File: <name>` banner unless JSON output or
+-- `--no-file` was requested.
+-- @param opts parsed command line options
+-- @param fname file name to show
+local function maybe_print_fname(opts, fname)
+  -- argparse derives the result key from the long option name with dashes
+  -- replaced by underscores, so `--no-file` arrives as `no_file`; the dashed
+  -- key is still honoured for option tables built by hand
+  local suppressed = opts.no_file or opts['no-file']
+
+  if not opts.json and not suppressed then
+    rspamd_logger.messagex(highlight('File: %s', fname))
+  end
+end
+
+--- Picks the UCL emitter name for the requested output format.
+-- MessagePack takes precedence over UCL, which takes precedence over compact
+-- JSON; plain JSON is the fallback.
+-- @param opts parsed command line options
+-- @return 'msgpack', 'ucl', 'json-compact' or 'json'
+local function output_fmt(opts)
+  if opts.messagepack then
+    return 'msgpack'
+  end
+  if opts.ucl then
+    return 'ucl'
+  end
+  if opts.compact then
+    return 'json-compact'
+  end
+
+  return 'json'
+end
+
+-- Print elements in form
+-- filename -> table of elements
+--
+-- With any structured format flag (json, ucl, messagepack) the whole table is
+-- serialised at once in the format chosen by output_fmt. Otherwise every
+-- element is written on its own line after the optional file name banner.
+-- @param elts table of filename -> elements
+-- @param opts parsed command line options
+-- @param func optional mapper applied to the elements before plain printing
+local function print_elts(elts, opts, func)
+  local fun = require "fun"
+
+  if opts.json or opts.ucl or opts.messagepack then
+    io.write(ucl.to_format(elts, output_fmt(opts)))
+    return
+  end
+
+  fun.each(function(fname, elt)
+    if func then
+      elt = fun.map(func, elt)
+    end
+    maybe_print_fname(opts, fname)
+    fun.each(function(e)
+      io.write(e)
+      io.write("\n")
+    end, elt)
+  end, elts)
+end
+
+-- Handler of the `extract` subcommand: for every input file collects the
+-- requested data (part info, words, text/HTML content, HTML structure,
+-- invisible content) into `out_elts[fname]` and prints everything at the end
+-- via print_elts.
+local function extract_handler(opts)
+ local out_elts = {}
+ local tasks = {}
+ -- Mapper handed to print_elts; only assigned in the --structure branch
+ local process_func
+
+ if opts.words then
+ -- Enable stemming and urls detection
+ load_config(opts)
+ rspamd_url.init(rspamd_config:get_tld_path())
+ rspamd_config:init_subsystem('langdet')
+ end
+
+ -- With --part and plain output, appends a summary line and a stats line
+ -- describing a text part to `out`
+ local function maybe_print_text_part_info(part, out)
+ local fun = require "fun"
+ if opts.part then
+ local t = 'plain text'
+ if part:is_html() then
+ t = 'html'
+ end
+
+ if not opts.json and not opts.ucl then
+ table.insert(out,
+ rspamd_logger.slog('Part: %s: %s, language: %s, size: %s (%s raw), words: %s',
+ part:get_mimepart():get_digest():sub(1, 8),
+ t,
+ part:get_language(),
+ part:get_length(), part:get_raw_length(),
+ part:get_words_count()))
+ table.insert(out,
+ rspamd_logger.slog('Stats: %s',
+ -- Folds the stats table into a 'k:v, k:v' string
+ fun.foldl(function(acc, k, v)
+ if acc ~= '' then
+ return string.format('%s, %s:%s', acc, k, v)
+ else
+ return string.format('%s:%s', k, v)
+ end
+ end, '', part:get_stats())))
+ end
+ end
+ end
+
+ -- With --part and plain output, appends a line with the declared and
+ -- detected types, file name, detected extension and size of a mime part
+ local function maybe_print_mime_part_info(part, out)
+ if opts.part then
+
+ if not opts.json and not opts.ucl then
+ local mtype, msubtype = part:get_type()
+ local det_mtype, det_msubtype = part:get_detected_type()
+ table.insert(out,
+ rspamd_logger.slog('Mime Part: %s: %s/%s (%s/%s detected), filename: %s (%s detected ext), size: %s',
+ part:get_digest():sub(1, 8),
+ mtype, msubtype,
+ det_mtype, det_msubtype,
+ part:get_filename(),
+ part:get_detected_ext(),
+ part:get_length()))
+ end
+ end
+ end
+
+ -- Joins words with spaces; in full mode every word is a table that is
+ -- rendered as 'raw|normalised|stemmed(flags)'
+ local function print_words(words, full)
+ local fun = require "fun"
+
+ if not full then
+ return table.concat(words, ' ')
+ else
+ return table.concat(
+ fun.totable(
+ fun.map(function(w)
+ -- [1] - stemmed word
+ -- [2] - normalised word
+ -- [3] - raw word
+ -- [4] - flags (table of strings)
+ return string.format('%s|%s|%s(%s)',
+ w[3], w[2], w[1], table.concat(w[4], ','))
+ end, words)
+ ),
+ ' '
+ )
+ end
+ end
+
+ for _, fname in ipairs(opts.file) do
+ local task = load_task(opts, fname)
+ out_elts[fname] = {}
+
+ -- Neither --text nor --html given: extract both kinds of parts
+ if not opts.text and not opts.html then
+ opts.text = true
+ opts.html = true
+ end
+
+ if opts.words then
+ local how_words = opts['words_format'] or 'stem'
+ table.insert(out_elts[fname], 'meta_words: ' ..
+ print_words(task:get_meta_words(how_words), how_words == 'full'))
+ end
+
+ if opts.text or opts.html then
+ local mp = task:get_parts() or {}
+
+ for _, mime_part in ipairs(mp) do
+ local how = opts.output
+ -- Text view of the mime part; stays nil for non-text parts
+ local part
+ if mime_part:is_text() then
+ part = mime_part:get_text()
+ end
+
+ if part and opts.text and not part:is_html() then
+ maybe_print_text_part_info(part, out_elts[fname])
+ maybe_print_mime_part_info(mime_part, out_elts[fname])
+ if not opts.json and not opts.ucl then
+ table.insert(out_elts[fname], '\n')
+ end
+
+ if opts.words then
+ local how_words = opts['words_format'] or 'stem'
+ table.insert(out_elts[fname], print_words(part:get_words(how_words),
+ how_words == 'full'))
+ else
+ table.insert(out_elts[fname], tostring(part:get_content(how)))
+ end
+ elseif part and opts.html and part:is_html() then
+ maybe_print_text_part_info(part, out_elts[fname])
+ maybe_print_mime_part_info(mime_part, out_elts[fname])
+ if not opts.json and not opts.ucl then
+ table.insert(out_elts[fname], '\n')
+ end
+
+ if opts.words then
+ local how_words = opts['words_format'] or 'stem'
+ table.insert(out_elts[fname], print_words(part:get_words(how_words),
+ how_words == 'full'))
+ else
+ if opts.structure then
+ local hc = part:get_html()
+ local res = {}
+ -- Plain output mapper: tables are rendered element by element, one
+ -- per line, anything else is rendered as a single string
+ process_func = function(elt)
+ local fun = require "fun"
+ if type(elt) == 'table' then
+ return table.concat(fun.totable(
+ fun.map(
+ function(t)
+ return rspamd_logger.slog("%s", t)
+ end,
+ elt)), '\n')
+ else
+ return rspamd_logger.slog("%s", elt)
+ end
+ end
+
+ -- One record per HTML tag: type plus extra data, content and style
+ -- when the tag has them
+ hc:foreach_tag('any', function(tag)
+ local elt = {}
+ local ex = tag:get_extra()
+ elt.tag = tag:get_type()
+ if ex then
+ elt.extra = ex
+ end
+ local content = tag:get_content()
+ if content then
+ elt.content = tostring(content)
+ end
+ local style = tag:get_style()
+ if style then
+ elt.style = style
+ end
+ table.insert(res, elt)
+ end)
+ table.insert(out_elts[fname], res)
+ else
+ -- opts.structure
+ table.insert(out_elts[fname], tostring(part:get_content(how)))
+ end
+ if opts.invisible then
+ local hc = part:get_html()
+ table.insert(out_elts[fname], string.format('invisible content: %s',
+ tostring(hc:get_invisible())))
+ end
+ end
+ end
+
+ -- Non-text parts only get the mime part info line
+ if not part then
+ maybe_print_mime_part_info(mime_part, out_elts[fname])
+ end
+ end
+ end
+
+ table.insert(out_elts[fname], "")
+ table.insert(tasks, task)
+ end
+
+ print_elts(out_elts, opts, process_func)
+ -- To avoid use after free we postpone tasks destruction
+ for _, task in ipairs(tasks) do
+ task:destroy()
+ end
+end
+
+--- Handler of the `stat` subcommand: gathers metatokens, bayes tokens or
+-- fuzzy hashes (depending on the flag given) for every input file and prints
+-- them via print_elts.
+-- @param opts parsed command line options
+local function stat_handler(opts)
+  local fun = require "fun"
+  local out_elts = {}
+
+  load_config(opts)
+  rspamd_url.init(rspamd_config:get_tld_path())
+  rspamd_config:init_subsystem('langdet,stat') -- Needed to gen stat tokens
+
+  -- Plain output mapper matching the kind of collected data
+  local process_func
+
+  -- Renders a metatoken as 'name = value'
+  local function format_metatoken(k, v)
+    return string.format("%s = %s", k, v)
+  end
+
+  -- Renders a bayes token with its window, both source tokens and flags
+  local function format_bayes_token(e)
+    return string.format('%s (%d): "%s"+"%s", [%s]', e.data, e.win, e.t1 or "",
+        e.t2 or "", table.concat(fun.totable(
+            fun.map(function(k)
+              return k
+            end, e.flags)), ","))
+  end
+
+  -- Renders a part digest, followed by its shingles when they were requested
+  -- and are available
+  local function format_fuzzy_part(e)
+    local line = string.format('part: %s(%s): %s', e.type, e.file or "", e.digest)
+    if not (opts.shingles and e.shingles) then
+      return line
+    end
+
+    local shingle_lines = {}
+    for _, s in ipairs(e.shingles) do
+      table.insert(shingle_lines, string.format('%s: %s+%s+%s', s[1], s[2], s[3], s[4]))
+    end
+
+    return line .. '\n' .. table.concat(shingle_lines, '\n')
+  end
+
+  for _, fname in ipairs(opts.file) do
+    local task = load_task(opts, fname)
+    out_elts[fname] = {}
+
+    if opts.meta then
+      out_elts[fname] = lua_meta.gen_metatokens_table(task)
+      process_func = format_metatoken
+    elseif opts.bayes then
+      out_elts[fname] = task:get_stat_tokens()
+      process_func = format_bayes_token
+    elseif opts.fuzzy then
+      local parts = task:get_parts() or {}
+      local hashes = {}
+      out_elts[fname] = hashes
+      process_func = format_fuzzy_part
+
+      for _, part in ipairs(parts) do
+        if not part:is_multipart() then
+          local text = part:get_text()
+
+          if text then
+            -- Text parts are hashed with the fuzzy algorithm
+            local digest, shingles = text:get_fuzzy_hashes(task:get_mempool())
+            local mtype, msubtype = part:get_type()
+            table.insert(hashes, {
+              digest = digest,
+              shingles = shingles,
+              type = string.format('%s/%s', mtype, msubtype)
+            })
+          else
+            -- Other parts are represented by their digest and file name
+            local part_digest = part:get_digest()
+            local part_file = part:get_filename()
+            local mtype, msubtype = part:get_type()
+            table.insert(hashes, {
+              digest = part_digest,
+              file = part_file,
+              type = string.format('%s/%s', mtype, msubtype)
+            })
+          end
+        end
+      end
+    end
+
+    task:destroy() -- No automatic dtor
+  end
+
+  print_elts(out_elts, opts, process_func)
+end
+
+-- Handler of the `urls` subcommand: collects the URLs of every input file,
+-- optionally deduplicates, counts and sorts them, and prints the result via
+-- print_elts.
+local function urls_handler(opts)
+ load_config(opts)
+ rspamd_url.init(rspamd_config:get_tld_path())
+ local out_elts = {}
+
+ -- NOTE(review): no matching ']' is emitted anywhere in this function, and
+ -- print_elts serialises the whole result on its own - confirm this is wanted
+ if opts.json then
+ rspamd_logger.messagex('[')
+ end
+
+ for _, fname in ipairs(opts.file) do
+ out_elts[fname] = {}
+ local task = load_task(opts, fname)
+ -- With --unique: map of string -> { count, url }; otherwise a plain list
+ local elts = {}
+
+ -- Stores one URL in `elts` using the representation chosen by the flags
+ local function process_url(u)
+ local s
+ if opts.tld then
+ s = u:get_tld()
+ elseif opts.host then
+ s = u:get_host()
+ elseif opts.full then
+ s = rspamd_logger.slog('%s: %s', u:get_text(), u:to_table())
+ else
+ s = u:get_text()
+ end
+
+ if opts.unique then
+ if elts[s] then
+ elts[s].count = elts[s].count + 1
+ else
+ elts[s] = {
+ count = 1,
+ url = u:to_table()
+ }
+ end
+ else
+ -- JSON output keeps the url object itself, plain output its string form
+ if opts.json then
+ table.insert(elts, u)
+ else
+ table.insert(elts, s)
+ end
+ end
+ end
+
+ for _, u in ipairs(task:get_urls(true)) do
+ process_url(u)
+ end
+
+ -- Final list of elements for this file (used for plain output as well)
+ local json_elts = {}
+
+ -- Converts one entry of `elts` into its output form
+ local function process_elt(s, u)
+ if opts.unique then
+ -- s is string, u is {url = url, count = count }
+ if not opts.json then
+ if opts.count then
+ table.insert(json_elts, string.format('%s : %s', s, u.count))
+ else
+ table.insert(json_elts, s)
+ end
+ else
+ local tb = u.url
+ tb.count = u.count
+ table.insert(json_elts, tb)
+ end
+ else
+ -- s is index, u is url or string
+ -- NOTE(review): both branches below do the same thing
+ if opts.json then
+ table.insert(json_elts, u)
+ else
+ table.insert(json_elts, u)
+ end
+ end
+ end
+
+ if opts.sort then
+ -- Comparator receives the table being sorted and two of its keys
+ local sfunc
+ if opts.unique then
+ -- Order by count first, ties are broken by the string itself
+ sfunc = function(t, a, b)
+ if t[a].count ~= t[b].count then
+ if opts.reverse then
+ return t[a].count > t[b].count
+ else
+ return t[a].count < t[b].count
+ end
+ else
+ -- Sort lexicography
+ if opts.reverse then
+ return a > b
+ else
+ return a < b
+ end
+ end
+ end
+ else
+ -- Order by the url text (JSON mode) or by the stored string
+ sfunc = function(t, a, b)
+ local va, vb
+ if opts.json then
+ va = t[a]:get_text()
+ vb = t[b]:get_text()
+ else
+ va = t[a]
+ vb = t[b]
+ end
+ if opts.reverse then
+ return va > vb
+ else
+ return va < vb
+ end
+ end
+ end
+
+ for s, u in lua_util.spairs(elts, sfunc) do
+ process_elt(s, u)
+ end
+ else
+ for s, u in pairs(elts) do
+ process_elt(s, u)
+ end
+ end
+
+ out_elts[fname] = json_elts
+
+ task:destroy() -- No automatic dtor
+ end
+
+ print_elts(out_elts, opts)
+end
+
+--- Returns the line ending that corresponds to the newline type reported by
+-- the task: '\r' for 'cr', '\n' for 'lf' and '\r\n' for anything else.
+-- @param task object providing `get_newlines_type`
+-- @return line ending string
+local function newline(task)
+  local endings = {
+    cr = '\r',
+    lf = '\n',
+  }
+
+  return endings[task:get_newlines_type()] or '\r\n'
+end
+
+-- Handler of the `modify` subcommand: rewrites the headers of every input
+-- message (removal, rewriting, addition), optionally appends text/HTML
+-- footers and writes the resulting message to the standard output.
+local function modify_handler(opts)
+ load_config(opts)
+ rspamd_url.init(rspamd_config:get_tld_path())
+
+ -- Reads a whole file in binary mode; raises an error if it cannot be opened
+ local function read_file(file)
+ local f = assert(io.open(file, "rb"))
+ local content = f:read("*all")
+ f:close()
+ return content
+ end
+
+ local text_footer, html_footer
+
+ if opts['text_footer'] then
+ text_footer = read_file(opts['text_footer'])
+ end
+
+ if opts['html_footer'] then
+ html_footer = read_file(opts['html_footer'])
+ end
+
+ for _, fname in ipairs(opts.file) do
+ local task = load_task(opts, fname)
+ local newline_s = newline(task)
+ -- Set once a Content-Transfer-Encoding header has been rewritten
+ local seen_cte
+
+ local rewrite = lua_mime.add_text_footer(task, html_footer, text_footer) or {}
+ local out = {} -- Start with headers
+
+ -- Called for every header of the message; emits it unchanged, emits a
+ -- replacement or drops it
+ local function process_headers_cb(name, hdr)
+ -- The option value is used as a Lua pattern matched against the name
+ for _, h in ipairs(opts['remove_header']) do
+ if name:match(h) then
+ return
+ end
+ end
+
+ -- Option values have the form 'header=pattern'; the pattern is filled
+ -- with the decoded header value by string.format
+ for _, h in ipairs(opts['rewrite_header']) do
+ local hname, hpattern = h:match('^([^=]+)=(.+)$')
+ if hname == name then
+ local new_value = string.format(hpattern, hdr.decoded)
+ new_value = string.format('%s:%s%s',
+ name, hdr.separator,
+ rspamd_util.fold_header(name,
+ rspamd_util.mime_header_encode(new_value),
+ task:get_newlines_type()))
+ out[#out + 1] = new_value
+ return
+ end
+ end
+
+ -- Footer insertion may require new content type and transfer encoding
+ if rewrite.need_rewrite_ct then
+ if name:lower() == 'content-type' then
+ local nct = string.format('%s: %s/%s; charset=utf-8',
+ 'Content-Type', rewrite.new_ct.type, rewrite.new_ct.subtype)
+ out[#out + 1] = nct
+ return
+ elseif name:lower() == 'content-transfer-encoding' then
+ out[#out + 1] = string.format('%s: %s',
+ 'Content-Transfer-Encoding', rewrite.new_cte or 'quoted-printable')
+ seen_cte = true
+ return
+ end
+ end
+
+ -- Untouched header: raw text without its trailing line break
+ out[#out + 1] = hdr.raw:gsub('\r?\n?$', '')
+ end
+
+ task:headers_foreach(process_headers_cb, { full = true })
+
+ -- Option values have the form 'header=value'; malformed ones are skipped
+ for _, h in ipairs(opts['add_header']) do
+ local hname, hvalue = h:match('^([^=]+)=(.+)$')
+
+ if hname and hvalue then
+ out[#out + 1] = string.format('%s: %s', hname,
+ rspamd_util.fold_header(hname, hvalue, task:get_newlines_type()))
+ end
+ end
+
+ -- The message had no Content-Transfer-Encoding header to rewrite: add one
+ if not seen_cte and rewrite.need_rewrite_ct then
+ out[#out + 1] = string.format('%s: %s',
+ 'Content-Transfer-Encoding', rewrite.new_cte or 'quoted-printable')
+ end
+
+ -- End of headers
+ out[#out + 1] = ''
+
+ if rewrite.out then
+ for _, o in ipairs(rewrite.out) do
+ out[#out + 1] = o
+ end
+ else
+ out[#out + 1] = { task:get_rawbody(), false }
+ end
+
+ -- Elements of `out` are strings (written with a line ending), tables of
+ -- { data, need_newline } or objects saved directly to descriptor 1
+ for _, o in ipairs(out) do
+ if type(o) == 'string' then
+ io.write(o)
+ io.write(newline_s)
+ elseif type(o) == 'table' then
+ -- Flush buffered output before data is written via save_in_file
+ io.flush()
+ if type(o[1]) == 'string' then
+ io.write(o[1])
+ else
+ o[1]:save_in_file(1)
+ end
+
+ if o[2] then
+ io.write(newline_s)
+ end
+ else
+ o:save_in_file(1)
+ io.write(newline_s)
+ end
+ end
+
+ task:destroy() -- No automatic dtor
+ end
+end
+
+--- Handler of the `sign` subcommand: signs every input message with the
+-- given key, selector and domain and prints either the signature alone or
+-- the message prefixed with a `DKIM-Signature` header.
+-- Exits with status 1 when FFI is unavailable, the key cannot be loaded or
+-- signing fails.
+-- @param opts parsed command line options
+local function sign_handler(opts)
+  load_config(opts)
+  rspamd_url.init(rspamd_config:get_tld_path())
+
+  local lua_dkim = require("lua_ffi").dkim
+
+  if not lua_dkim then
+    io.stderr:write('FFI support is required: please use LuaJIT or install lua-ffi\n')
+    os.exit(1)
+  end
+
+  -- The key option is a path when such a file exists, otherwise it is taken
+  -- as the base64 encoded key itself
+  local sign_key
+  if rspamd_util.file_exists(opts.key) then
+    sign_key = lua_dkim.load_sign_key(opts.key, 'file')
+  else
+    sign_key = lua_dkim.load_sign_key(opts.key, 'base64')
+  end
+
+  if not sign_key then
+    io.stderr:write('Cannot load key: ' .. opts.key .. '\n')
+    os.exit(1)
+  end
+
+  for _, fname in ipairs(opts.file) do
+    local task = load_task(opts, fname)
+    -- NOTE(review): no `--algorithm` option is declared for this command in
+    -- this file, so opts.algorithm is presumably always nil; opts.type is not
+    -- consulted here either - confirm
+    local ctx = lua_dkim.create_sign_context(task, sign_key, nil, opts.algorithm)
+
+    if not ctx then
+      io.stderr:write('Cannot init signing\n')
+      os.exit(1)
+    end
+
+    local sig = lua_dkim.do_sign(task, ctx, opts.selector, opts.domain)
+
+    if not sig then
+      io.stderr:write('Cannot create signature\n')
+      os.exit(1)
+    end
+
+    if opts.output == 'signature' then
+      io.write(sig)
+      io.write('\n')
+      io.flush()
+    else
+      local dkim_hdr = string.format('%s: %s%s',
+          'DKIM-Signature',
+          rspamd_util.fold_header('DKIM-Signature',
+              rspamd_util.mime_header_encode(sig),
+              task:get_newlines_type()),
+          newline(task))
+      io.write(dkim_hdr)
+      -- Flush buffered output before the message is written via save_in_file
+      io.flush()
+      task:get_content():save_in_file(1)
+    end
+
+    task:destroy() -- No automatic dtor
+  end
+end
+
+-- Strips directories and .extensions (if present) from a filepath
+-- The text after the last slash up to the first dot is preferred; without a
+-- usable match the leading dot-free part of the whole path is returned.
+local function filename_only(filepath)
+  return filepath:match(".*%/([^%.]+)") or filepath:match("([^%.]+)")
+end
+
+-- Load-time sanity checks: { input path, expected result }
+for _, case in ipairs({
+  { "very_simple", "very_simple" },
+  { "/home/very_simple.eml", "very_simple" },
+  { "very_simple.eml", "very_simple" },
+  { "very_simple.example.eml", "very_simple" },
+  { "/home/very_simple", "very_simple" },
+  { "home/very_simple", "very_simple" },
+  { "./home/very_simple", "very_simple" },
+  { "../home/very_simple.eml", "very_simple" },
+  { "/home/dir.with.dots/very_simple.eml", "very_simple" },
+}) do
+  assert(filename_only(case[1]) == case[2])
+end
+
+--Write the dump content to file or standard out
+-- Strings are converted to rspamd_text first. With an output directory the
+-- data goes to `<outdir>/<name without dirs and extension>.<extension>`
+-- (an existing file is removed beforehand) and the path is echoed to the
+-- standard output; otherwise the data is saved to descriptor 1.
+-- Returns the written path, or nil when nothing was written to a file.
+local function write_dump_content(dump_content, fname, extension, outdir)
+  if type(dump_content) == "string" then
+    dump_content = rspamd_text.fromstring(dump_content)
+  end
+
+  if not outdir then
+    dump_content:save_in_file(1)
+    return nil
+  end
+
+  local dir = outdir
+  if dir:sub(-1) ~= "/" then
+    dir = dir .. "/"
+  end
+
+  local outpath = string.format("%s%s.%s", dir, filename_only(fname), extension)
+  if rspamd_util.file_exists(outpath) then
+    os.remove(outpath)
+  end
+
+  if not dump_content:save_in_file(outpath) then
+    io.stderr:write(string.format("Unable to save dump content to file: %s\n", outpath))
+    return nil
+  end
+
+  io.write(outpath .. "\n")
+  return outpath
+end
+
+-- Get the formatted ucl (split or unsplit) or the raw task content
+-- Returns the data to dump and the file extension to use for it: the raw
+-- message with "mime" when no structured format was requested, otherwise the
+-- serialised message with the format name. With --split the content of each
+-- part is written out separately and, when that produced a file, replaced
+-- in the structure by a null content plus a `content_path`.
+local function get_dump_content(task, opts, fname)
+  local structured = opts.ucl or opts.json or opts.messagepack
+  if not structured then
+    return task:get_content(), "mime"
+  end
+
+  local ucl_object = lua_mime.message_to_ucl(task)
+
+  if opts.split then
+    for idx, part in ipairs(ucl_object.parts) do
+      if part.content then
+        local part_name = string.format("%s-part%d", filename_only(fname), idx)
+        local written_to = write_dump_content(part.content, part_name, "raw", opts.outdir)
+        if written_to then
+          part.content = ucl.null
+          part.content_path = written_to
+        end
+      end
+    end
+  end
+
+  local extension = output_fmt(opts)
+  return ucl.to_format(ucl_object, extension), extension
+end
+
+--- Handler of the `dump` subcommand: converts every input message to the
+-- requested representation and writes it to the output directory or to the
+-- standard output.
+-- @param opts parsed command line options
+local function dump_handler(opts)
+  load_config(opts)
+  rspamd_url.init(rspamd_config:get_tld_path())
+
+  for _, path in ipairs(opts.file) do
+    local task = load_task(opts, path)
+    local content, ext = get_dump_content(task, opts, path)
+
+    write_dump_content(content, path, ext, opts.outdir)
+    task:destroy() -- No automatic dtor
+  end
+end
+
+--- Entry point of the `mime` command: parses the command line, normalises
+-- the list of input files and dispatches to the selected subcommand.
+-- @param args raw command line arguments
+local function handler(args)
+  local opts = parser:parse(args)
+
+  local command = opts.command
+
+  -- Handlers iterate over opts.file with ipairs, so make sure it is a list
+  if type(opts.file) == 'string' then
+    opts.file = { opts.file }
+  elseif opts.file == nil then
+    -- type() reports a missing value as 'nil'
+    opts.file = {}
+  end
+
+  if command == 'extract' then
+    extract_handler(opts)
+  elseif command == 'stat' then
+    stat_handler(opts)
+  elseif command == 'urls' then
+    urls_handler(opts)
+  elseif command == 'modify' then
+    modify_handler(opts)
+  elseif command == 'sign' then
+    sign_handler(opts)
+  elseif command == 'dump' then
+    dump_handler(opts)
+  else
+    -- parser:error() accepts a single ready-made message and performs no
+    -- formatting on its own, so the message is built here
+    parser:error(string.format('command %s is not implemented', command))
+  end
+end
+
+-- Table returned by this module: command name, its aliases, the entry point
+-- function and the description text taken from the argument parser
+local command_definition = {
+  name = 'mime',
+  aliases = { 'mime_tool' },
+  handler = handler,
+  description = parser._description,
+}
+
+return command_definition
diff --git a/lualib/rspamadm/neural_test.lua b/lualib/rspamadm/neural_test.lua
new file mode 100644
index 0000000..31d21a9
--- /dev/null
+++ b/lualib/rspamadm/neural_test.lua
@@ -0,0 +1,228 @@
+local rspamd_logger = require "rspamd_logger"
+local argparse = require "argparse"
+local lua_util = require "lua_util"
+local ucl = require "ucl"
+
+-- Command line options.
+-- The short alias `-c` belongs to `--connect` only: it used to be declared
+-- for `--config` as well, but an alias can refer to a single option and the
+-- later declaration (`--connect`) took it over, so `-c <cfg>` never worked.
+local parser = argparse()
+  :name "rspamadm neural_test"
+  :description "Test the neural network with labelled dataset"
+  :help_description_margin(32)
+
+parser:option "--config"
+  :description "Path to config file"
+  :argname("<cfg>")
+  :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+parser:option "-H --hamdir"
+  :description("Ham directory")
+  :argname("<dir>")
+parser:option "-S --spamdir"
+  :description("Spam directory")
+  :argname("<dir>")
+parser:option "-t --timeout"
+  :description("Timeout for client connections")
+  :argname("<sec>")
+  :convert(tonumber)
+  :default(60)
+parser:option "-n --conns"
+  :description("Number of parallel connections")
+  :argname("<N>")
+  :convert(tonumber)
+  :default(10)
+parser:option "-c --connect"
+  :description("Connect to specific host")
+  :argname("<host>")
+  :default('localhost:11334')
+parser:option "-r --rspamc"
+  :description("Use specific rspamc path")
+  :argname("<path>")
+  :default('rspamc')
+parser:option '--rule'
+  :description 'Rule to test'
+  :argname('<rule>')
+
+local HAM = "HAM"
+local SPAM = "SPAM"
+
+--- Loads the configuration file named by `opts.config` into `rspamd_config`
+-- and then runs `parse_rcl` with the `{ 'logging', 'worker' }` list.
+-- Logs the error and terminates the process with status 1 if a step fails.
+-- @param opts parsed command line options
+local function load_config(opts)
+  local cfg_path = opts['config']
+
+  local loaded, load_err = rspamd_config:load_ucl(cfg_path)
+  if not loaded then
+    rspamd_logger.errx('cannot parse %s: %s', cfg_path, load_err)
+    os.exit(1)
+  end
+
+  local processed, process_err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+  if not processed then
+    rspamd_logger.errx('cannot process %s: %s', cfg_path, process_err)
+    os.exit(1)
+  end
+end
+
+--- Runs rspamc over a path and returns everything it printed.
+-- The command line is passed to the shell as is, so the arguments must not
+-- contain characters the shell would interpret.
+-- @param rspamc_path rspamc executable to run
+-- @param host value for `--connect`
+-- @param n_parallel value for `-n`
+-- @param path file or directory handed to rspamc
+-- @param timeout value for `-t`, formatted with three decimals
+-- @return the complete output of the command as a string
+local function scan_email(rspamc_path, host, n_parallel, path, timeout)
+
+  local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s",
+      rspamc_path, host, n_parallel, timeout, path)
+  local pipe = assert(io.popen(rspamc_command))
+  local output = pipe:read("*all")
+  -- Release the pipe instead of leaving it to the garbage collector
+  pipe:close()
+  return output
+end
+
+--- Parses one serialised scan result and keeps the fields needed later.
+-- @param result string with a single scan result
+-- @return table with `score`, `action` (whitespace replaced by '_'),
+-- `symbols` (list of symbol names), `filename` and `scan_time`, or nil when
+-- the input cannot be parsed or has no `action`
+local function encoded_json_to_log(result)
+  -- Returns table containing score, action, list of symbols
+
+  local filtered_result = {}
+  local ucl_parser = ucl.parser()
+
+  local is_good, err = ucl_parser:parse_string(result)
+
+  if not is_good then
+    rspamd_logger.errx("Parser error: %1", err)
+    return nil
+  end
+
+  result = ucl_parser:get_object()
+
+  filtered_result.score = result.score
+  if not result.action then
+    rspamd_logger.errx("Bad JSON: %1", result)
+    return nil
+  end
+  local action = result.action:gsub("%s+", "_")
+  filtered_result.action = action
+
+  filtered_result.symbols = {}
+
+  -- A reply without the `symbols` field yields an empty list
+  for sym, _ in pairs(result.symbols or {}) do
+    table.insert(filtered_result.symbols, sym)
+  end
+
+  filtered_result.filename = result.filename
+  filtered_result.scan_time = result.scan_time
+
+  return filtered_result
+end
+
+--- Splits raw scanner output into lines, parses every line and labels the
+-- parsed results with the real type of the corpus they came from.
+-- Lines that cannot be parsed are dropped.
+-- @param results raw output, one result per line
+-- @param actual_email_type label stored in the `type` field of each result
+-- @return list of parsed results
+local function filter_scan_results(results, actual_email_type)
+
+  local lines = lua_util.rspamd_str_split(results, "\n")
+
+  -- Drop a trailing empty element
+  local last = #lines
+  if lines[last] == "" then
+    lines[last] = nil
+  end
+
+  local logs = {}
+
+  for _, line in pairs(lines) do
+    local entry = encoded_json_to_log(line)
+    if entry then
+      entry['type'] = actual_email_type
+      table.insert(logs, entry)
+    end
+  end
+
+  return logs
+end
+
+--- Computes per-rule confusion counters and error rates.
+-- A rule's spam symbol on a ham message counts as a false positive and on a
+-- spam message as a true positive; its ham symbol counts as a true negative
+-- on ham and as a false negative on spam.
+-- @param results list of scan results with `symbols` and `type` fields
+-- @param rules table of rule name -> rule with `symbol_spam`/`symbol_ham`
+-- @return table of rule name -> { tp, tn, fp, fn, fpr, fnr }
+local function get_stats_from_scan_results(results, rules)
+
+  -- Share of errors among all decisions; 0 when there were no decisions at
+  -- all, which avoids a NaN from 0/0
+  local function rate(errors, correct)
+    local total = errors + correct
+    if total == 0 then
+      return 0
+    end
+    return errors / total
+  end
+
+  local rule_stats = {}
+  for rule, _ in pairs(rules) do
+    rule_stats[rule] = { tp = 0, tn = 0, fp = 0, fn = 0 }
+  end
+
+  for _, result in ipairs(results) do
+    for _, symbol in ipairs(result["symbols"]) do
+      for name, rule in pairs(rules) do
+        if rule.symbol_spam and rule.symbol_spam == symbol then
+          if result.type == HAM then
+            rule_stats[name].fp = rule_stats[name].fp + 1
+          elseif result.type == SPAM then
+            rule_stats[name].tp = rule_stats[name].tp + 1
+          end
+        elseif rule.symbol_ham and rule.symbol_ham == symbol then
+          if result.type == HAM then
+            rule_stats[name].tn = rule_stats[name].tn + 1
+          elseif result.type == SPAM then
+            rule_stats[name].fn = rule_stats[name].fn + 1
+          end
+        end
+      end
+    end
+  end
+
+  for rule, _ in pairs(rules) do
+    local stats = rule_stats[rule]
+    stats.fpr = rate(stats.fp, stats.tn)
+    stats.fnr = rate(stats.fn, stats.tp)
+  end
+
+  return rule_stats
+end
+
+--- Logs the false positive and false negative rates of every rule as
+-- percentages.
+-- @param neural_stats table of rule name -> stats with `fpr` and `fnr`
+local function print_neural_stats(neural_stats)
+  local to_percent = 100
+
+  for rule_name, rule_stats in pairs(neural_stats) do
+    rspamd_logger.messagex("\nStats for rule: %s", rule_name)
+    rspamd_logger.messagex("False positive rate: %s%%", rule_stats.fpr * to_percent)
+    rspamd_logger.messagex("False negative rate: %s%%", rule_stats.fnr * to_percent)
+  end
+end
+
+--- Entry point of the `neural_test` command: scans the ham and spam corpora
+-- with rspamc and prints the error rates of the configured neural rules.
+-- @param args raw command line arguments
+local function handler(args)
+  local opts = parser:parse(args)
+
+  local ham_directory = opts['hamdir']
+  local spam_directory = opts['spamdir']
+  local connections = opts["conns"]
+
+  load_config(opts)
+
+  local neural_opts = rspamd_config:get_all_opt('neural')
+
+  -- Without rules there is nothing to evaluate; stop with a clear message
+  -- instead of failing later on a nil value
+  if not neural_opts or not neural_opts.rules then
+    rspamd_logger.errx("No neural rules are defined in %s", opts['config'])
+    os.exit(1)
+  end
+
+  if opts["rule"] then
+    -- Keep only the requested rule (compared case-insensitively)
+    local found = false
+    for rule_name, _ in pairs(neural_opts.rules) do
+      if string.lower(rule_name) == string.lower(opts["rule"]) then
+        found = true
+      else
+        neural_opts.rules[rule_name] = nil
+      end
+    end
+
+    if not found then
+      rspamd_logger.errx("Couldn't find the rule %s", opts["rule"])
+      return
+    end
+  end
+
+  local results = {}
+
+  if ham_directory then
+    rspamd_logger.messagex("Scanning ham corpus...")
+    local ham_results = scan_email(opts.rspamc, opts.connect, connections, ham_directory, opts.timeout)
+    ham_results = filter_scan_results(ham_results, HAM)
+
+    for _, result in pairs(ham_results) do
+      table.insert(results, result)
+    end
+  end
+
+  if spam_directory then
+    rspamd_logger.messagex("Scanning spam corpus...")
+    local spam_results = scan_email(opts.rspamc, opts.connect, connections, spam_directory, opts.timeout)
+    spam_results = filter_scan_results(spam_results, SPAM)
+
+    for _, result in pairs(spam_results) do
+      table.insert(results, result)
+    end
+  end
+
+  local neural_stats = get_stats_from_scan_results(results, neural_opts.rules)
+  print_neural_stats(neural_stats)
+
+end
+
+-- Table returned by this module: command name, its aliases, the entry point
+-- function and the description text taken from the argument parser
+return {
+ name = "neuraltest",
+ aliases = { "neural_test" },
+ handler = handler,
+ description = parser._description
+} \ No newline at end of file
diff --git a/lualib/rspamadm/publicsuffix.lua b/lualib/rspamadm/publicsuffix.lua
new file mode 100644
index 0000000..96bf069
--- /dev/null
+++ b/lualib/rspamadm/publicsuffix.lua
@@ -0,0 +1,82 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+local rspamd_logger = require "rspamd_logger"
+
+-- Define command line options
+-- A sub-command is mandatory; its name is stored under the `command` key of
+-- the parse result (see `command_target`)
+local parser = argparse()
+ :name 'rspamadm publicsuffix'
+ :description 'Do manipulations with the publicsuffix list'
+ :help_description_margin(30)
+ :command_target('command')
+ :require_command(true)
+
+-- Configuration file location, defaulting to <CONFDIR>/rspamd.conf
+parser:option '-c --config'
+ :description 'Path to config file'
+ :argname('config_file')
+ :default(rspamd_paths['CONFDIR'] .. '/rspamd.conf')
+
+-- The only sub-command defined in this file
+parser:command 'compile'
+ :description 'Compile publicsuffix list if needed'
+
+--- Load and process the Rspamd configuration, exiting with status 1 on failure.
+-- @param config_file path handed to `rspamd_config:load_ucl`
+local function load_config(config_file)
+  local loaded, load_err = rspamd_config:load_ucl(config_file)
+  if not loaded then
+    rspamd_logger.errx('cannot load %s: %s', config_file, load_err)
+    os.exit(1)
+  end
+
+  local parsed, parse_err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+  if not parsed then
+    rspamd_logger.errx('cannot process %s: %s', config_file, parse_err)
+    os.exit(1)
+  end
+end
+
+--- Handle the `compile` sub-command: initialise the URL library with the
+-- public suffix file configured in Rspamd.
+-- Exits with status 1 when no TLD file path is configured.
+-- @param _ parsed options (unused)
+local function compile_handler(_)
+  local url_lib = require "rspamd_url"
+  local tld_path = rspamd_config:get_tld_path()
+
+  if tld_path then
+    rspamd_logger.messagex('loading public suffix file from %s', tld_path)
+    url_lib.init(tld_path)
+    rspamd_logger.messagex('public suffix file has been loaded')
+  else
+    rspamd_logger.errx('missing `url_tld` option, cannot continue')
+    os.exit(1)
+  end
+end
+
+--- Entry point of `rspamadm publicsuffix`.
+-- Parses the command line, loads the configuration and dispatches the
+-- requested sub-command.
+-- @param args raw command line arguments
+local function handler(args)
+  local cmd_opts = parser:parse(args)
+
+  -- argparse stores `-c --config` under the `config` key; `config_file` is
+  -- only the placeholder shown in the usage message
+  load_config(cmd_opts.config)
+
+  if cmd_opts.command == 'compile' then
+    compile_handler(cmd_opts)
+  else
+    rspamd_logger.errx('unknown command: %s', cmd_opts.command)
+    os.exit(1)
+  end
+end
+
+-- Module export: entry point, description (taken from the parser) and command name
+return {
+ handler = handler,
+ description = parser._description,
+ name = 'publicsuffix'
+} \ No newline at end of file
diff --git a/lualib/rspamadm/stat_convert.lua b/lualib/rspamadm/stat_convert.lua
new file mode 100644
index 0000000..62a19a2
--- /dev/null
+++ b/lualib/rspamadm/stat_convert.lua
@@ -0,0 +1,38 @@
+local lua_redis = require "lua_redis"
+local stat_tools = require "lua_stat"
+local ucl = require "ucl"
+local logger = require "rspamd_logger"
+local lua_util = require "lua_util"
+
+--- Convert sqlite based statistics classifiers to Redis.
+-- @param _ unused
+-- @param res options table; the fields read here are `redis`, `expire` and
+--   `reset_previous` (the whole table is also passed to `load_sqlite_config`)
+-- @return false on failure; nothing when every classifier was converted
+return function(_, res)
+  local redis_params = lua_redis.try_load_redis_servers(res.redis, nil)
+
+  -- The expire value is normalised before the Redis parameters are checked
+  if res.expire then
+    res.expire = lua_util.parse_time_interval(res.expire)
+  end
+
+  if not redis_params then
+    logger.errx('cannot load redis server definition')
+    return false
+  end
+
+  local sqlite_classifiers = stat_tools.load_sqlite_config(res)
+  if #sqlite_classifiers == 0 then
+    logger.errx('cannot load sqlite classifiers')
+    return false
+  end
+
+  for _, classifier in ipairs(sqlite_classifiers) do
+    local converted = stat_tools.convert_sqlite_to_redis(redis_params,
+        classifier.db_spam, classifier.db_ham,
+        classifier.symbol_spam, classifier.symbol_ham,
+        classifier.learn_cache, res.expire, res.reset_previous)
+
+    if not converted then
+      logger.errx('conversion failed')
+      return false
+    end
+
+    logger.messagex('Converted classifier to the from sqlite to redis')
+    logger.messagex('Suggested configuration:')
+
+    local suggested = stat_tools.redis_classifier_from_sqlite(classifier, res.expire)
+    logger.messagex(ucl.to_format(suggested, 'config'))
+  end
+end
diff --git a/lualib/rspamadm/statistics_dump.lua b/lualib/rspamadm/statistics_dump.lua
new file mode 100644
index 0000000..6bc0458
--- /dev/null
+++ b/lualib/rspamadm/statistics_dump.lua
@@ -0,0 +1,544 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local lua_redis = require "lua_redis"
+local rspamd_logger = require "rspamd_logger"
+local argparse = require "argparse"
+local rspamd_zstd = require "rspamd_zstd"
+local rspamd_text = require "rspamd_text"
+local rspamd_util = require "rspamd_util"
+local rspamd_cdb = require "rspamd_cdb"
+local lua_util = require "lua_util"
+local rspamd_i64 = require "rspamd_int64"
+local ucl = require "ucl"
+
+local N = "statistics_dump"
+local E = {}
+local classifiers = {}
+
+-- Define command line options
+-- The sub-command is optional; its name is stored under the `command` key
+local parser = argparse()
+  :name "rspamadm statistics_dump"
+  :description "Dump/restore Rspamd statistics"
+  :help_description_margin(30)
+  :command_target("command")
+  :require_command(false)
+
+parser:option "-c --config"
+  :description "Path to config file"
+  :argname("<cfg>")
+  :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+
+-- Extract subcommand
+local dump = parser:command "dump d"
+  :description "Dump bayes statistics"
+-- JSON and CDB outputs are mutually exclusive
+dump:mutex(
+  dump:flag "-j --json"
+    :description "Json output",
+  dump:flag "-C --cdb"
+    :description "CDB output"
+)
+dump:flag "-c --compress"
+  :description "Compress output"
+dump:option "-b --batch-size"
+  :description "Number of entries to process at once"
+  :argname("<elts>")
+  :convert(tonumber)
+  :default(1000)
+
+
+-- Restore
+local restore = parser:command "restore r"
+  :description "Restore bayes statistics"
+restore:argument "file"
+  :description "Input file to process"
+  :argname "<file>"
+  :args "*"
+restore:option "-b --batch-size"
+  :description "Number of entries to process at once"
+  :argname("<elts>")
+  :convert(tonumber)
+  :default(1000)
+restore:option "-m --mode"
+  :description "Restore mode: append to, subtract from or replace the existing values"
+  :argname("<append|subtract|replace>")
+  :convert {
+    ['append'] = 'append',
+    ['subtract'] = 'subtract',
+    ['replace'] = 'replace',
+  }
+  :default 'append'
+restore:flag "-n --no-operation"
+  :description "Only show redis commands to be issued"
+
+--- Load and process the configuration file named by `opts.config`.
+-- Logs the error and exits with status 1 when either step fails.
+-- @param opts parsed command line options
+local function load_config(opts)
+  local config_path = opts['config']
+
+  local loaded, load_err = rspamd_config:load_ucl(config_path)
+  if not loaded then
+    rspamd_logger.errx('cannot parse %s: %s', config_path, load_err)
+    os.exit(1)
+  end
+
+  local processed, process_err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+  if not processed then
+    rspamd_logger.errx('cannot process %s: %s', config_path, process_err)
+    os.exit(1)
+  end
+end
+
+--- Register a Redis-backed classifier in the module-level `classifiers` list.
+-- Only classifiers with `new_schema` set are considered. The spam/ham symbols
+-- are taken from the statfile definitions and the Redis parameters from the
+-- classifier itself or, failing that, from the `statistics_dump` section.
+-- @param cls classifier configuration table
+-- @param cfg whole configuration object
+-- @return false when no Redis parameters could be loaded, nothing otherwise
+local function check_redis_classifier(cls, cfg)
+  -- Skip old classifiers
+  if cls.new_schema then
+    local symbol_spam, symbol_ham
+    -- Load symbols from statfiles
+
+    local function check_statfile_table(tbl, def_sym)
+      local symbol = tbl.symbol or def_sym
+
+      local spam
+      -- An explicit `spam = false` must be honoured, hence the nil comparison
+      -- rather than a truthiness test
+      if tbl.spam ~= nil then
+        spam = tbl.spam
+      else
+        -- Fall back to guessing from the symbol name
+        if string.match(symbol:upper(), 'SPAM') then
+          spam = true
+        else
+          spam = false
+        end
+      end
+
+      if spam then
+        symbol_spam = symbol
+      else
+        symbol_ham = symbol
+      end
+    end
+
+    local statfiles = cls.statfile
+    if statfiles[1] then
+      -- Array form: each element is a statfile or a table of named statfiles
+      for _, stf in ipairs(statfiles) do
+        if not stf.symbol then
+          for k, v in pairs(stf) do
+            check_statfile_table(v, k)
+          end
+        else
+          check_statfile_table(stf, 'undefined')
+        end
+      end
+    else
+      -- Map form: statfile name -> definition
+      for stn, stf in pairs(statfiles) do
+        check_statfile_table(stf, stn)
+      end
+    end
+
+    -- Try the classifier first, then this module's own section (with and
+    -- without the 'bayes' hint)
+    local redis_params = lua_redis.try_load_redis_servers(cls,
+        rspamd_config, false, 'bayes')
+    if not redis_params then
+      redis_params = lua_redis.try_load_redis_servers(cfg[N] or E,
+          rspamd_config, false, 'bayes')
+    end
+    if not redis_params then
+      redis_params = lua_redis.try_load_redis_servers(cfg[N] or E,
+          rspamd_config, true)
+    end
+    if not redis_params then
+      return false
+    end
+
+    table.insert(classifiers, {
+      symbol_spam = symbol_spam,
+      symbol_ham = symbol_ham,
+      redis_params = redis_params,
+    })
+  end
+end
+
+--- Turn a flat `{ key1, value1, key2, value2, ... }` array into a map.
+-- @param ar flat array of alternating keys and values
+-- @return table mapping each key to the element that follows it
+local function redis_map_zip(ar)
+  local map = {}
+  local idx, total = 1, #ar
+
+  while idx <= total do
+    map[ar[idx]] = ar[idx + 1]
+    idx = idx + 2
+  end
+
+  return map
+end
+
+-- Used to clear tables: prefers `table.clear` when it exists, otherwise
+-- removes every key one by one
+local clear_fcn = table.clear or function(tbl)
+  for _, key in ipairs(lua_util.keys(tbl)) do
+    tbl[key] = nil
+  end
+end
+
+local compress_ctx
+
+--- Write the accumulated output chunks to stdout.
+-- When `opts.compress` is set, a zstd compression context is created on first
+-- use (kept in the module-level `compress_ctx`) and the chunks are streamed
+-- through it; otherwise they are written as is.
+-- @param out array of output chunks
+-- @param opts parsed command line options
+-- @param last true for the final chunk (ends the compressed stream)
+local function dump_out(out, opts, last)
+  if opts.compress and not compress_ctx then
+    compress_ctx = rspamd_zstd.compress_ctx()
+  end
+
+  if not compress_ctx then
+    for _, chunk in ipairs(out) do
+      io.write(chunk)
+    end
+    return
+  end
+
+  local stream_mode = 'flush'
+  if last then
+    stream_mode = 'end'
+  end
+
+  compress_ctx:stream(rspamd_text.fromtable(out), stream_mode):write()
+end
+
+--- Append the elements collected for `pattern` to a CDB file.
+-- The builder for `<pattern>.cdb` is created on the first call, stored in
+-- `out.cdb_builder` and seeded with the learn counters; it is finalized and
+-- dropped when `last` is set.
+-- @param out output state; `out[pattern]` holds `elts`, `learns_spam`, `learns_ham`
+-- @param opts parsed command line options (unused here)
+-- @param last true when this is the final chunk for the pattern
+-- @param pattern key of the dumped pattern, also used as the file name stem
+local function dump_cdb(out, opts, last, pattern)
+  local results = out[pattern]
+
+  if not out.cdb_builder then
+    -- First invocation
+    local cdb_name = string.format('%s.cdb', pattern)
+    out.cdb_builder = rspamd_cdb.build(cdb_name)
+
+    local learns_spam = rspamd_i64.fromstring(results.learns_spam or '0')
+    out.cdb_builder:add('_lrnspam', learns_spam)
+
+    local learns_ham = rspamd_i64.fromstring(results.learns_ham or '0')
+    out.cdb_builder:add('_lrnham_', learns_ham)
+  end
+
+  for _, elt in ipairs(results.elts) do
+    out.cdb_builder:add(elt.key, elt.value)
+  end
+
+  if not last then
+    return
+  end
+
+  out.cdb_builder:finalize()
+  out.cdb_builder = nil
+end
+
+--- Dump every Redis hash matching `pattern` using SCAN + HGETALL.
+-- Textual output is accumulated in the array part of `out`; for CDB output
+-- the elements are accumulated in `out[key].elts`. After each non-final SCAN
+-- batch the accumulated data is flushed; the final textual chunk is left in
+-- `out` for the caller.
+-- @param conn synchronous Redis connection
+-- @param pattern SCAN match pattern
+-- @param opts parsed command line options (`batch_size`, `json`, `cdb`)
+-- @param out output state shared with the caller
+-- @param key name of the pattern entry in `out` (used for CDB output)
+local function dump_pattern(conn, pattern, opts, out, key)
+  local cursor = 0
+
+  repeat
+    conn:add_cmd('SCAN', { tostring(cursor),
+                           'MATCH', pattern,
+                           'COUNT', tostring(opts.batch_size) })
+    local ret, results = conn:exec()
+
+    if not ret then
+      rspamd_logger.errx("cannot connect execute scan command: %s", results)
+      os.exit(1)
+    end
+
+    cursor = tonumber(results[1])
+
+    local elts = results[2]
+    local tokens = {}
+
+    for _, e in ipairs(elts) do
+      conn:add_cmd('HGETALL', { e })
+    end
+    -- This function returns many results, each for each command
+    -- So if we have batch 1000, then we would have 1000 tables in form
+    -- [result, {hash_content}]
+    local all_results = { conn:exec() }
+
+    for i = 1, #all_results, 2 do
+      local r, hash_content = all_results[i], all_results[i + 1]
+
+      if r then
+        -- List to a hash map
+        local data = redis_map_zip(hash_content)
+        tokens[#tokens + 1] = { key = elts[(i + 1) / 2], data = data }
+      end
+    end
+
+    -- Output keeping track of the commas
+    for i, d in ipairs(tokens) do
+      if cursor == 0 and i == #tokens or not opts.json then
+        if opts.cdb then
+          table.insert(out[key].elts, {
+            key = rspamd_i64.fromstring(string.match(d.key, '%d+')),
+            -- Both counters fall back to 0 when missing or not numeric
+            value = rspamd_util.pack('ff', tonumber(d.data["S"] or '0') or 0,
+                tonumber(d.data["H"] or '0') or 0)
+          })
+        else
+          out[#out + 1] = rspamd_logger.slog('"%s": %s\n', d.key,
+              ucl.to_format(d.data, "json-compact"))
+        end
+      else
+        out[#out + 1] = rspamd_logger.slog('"%s": %s,\n', d.key,
+            ucl.to_format(d.data, "json-compact"))
+      end
+
+    end
+
+    if opts.json and cursor == 0 then
+      out[#out + 1] = '}}\n'
+    end
+
+    -- Do not write the last chunk of out as it will be processed afterwards
+    if cursor ~= 0 then
+      if opts.cdb then
+        -- CDB output: push this batch to the builder and start a new batch,
+        -- keeping the per-pattern state in `out` intact
+        dump_cdb(out, opts, false, key)
+        out[key].elts = {}
+      else
+        -- Textual output: flush what we have and start from an empty buffer
+        dump_out(out, opts, false)
+        clear_fcn(out)
+      end
+    elseif opts.cdb then
+      dump_cdb(out, opts, true, key)
+    end
+
+  until cursor == 0
+end
+
+--- Handle the `dump` sub-command.
+-- For every registered classifier: connect to Redis, read the `<symbol>_keys`
+-- sets of the spam and ham symbols and dump each key pattern found there.
+-- A pattern is dumped only once, even if several classifiers reference it.
+-- @param opts parsed command line options
+local function dump_handler(opts)
+ -- Patterns already dumped, shared across all classifiers
+ local patterns_seen = {}
+ for _, cls in ipairs(classifiers) do
+ local res, conn = lua_redis.redis_connect_sync(cls.redis_params, false)
+
+ if not res then
+ rspamd_logger.errx("cannot connect to redis server: %s", cls.redis_params)
+ os.exit(1)
+ end
+
+ -- Output accumulator for this classifier
+ local out = {}
+ -- Dump everything referenced by the `<sym>_keys` set of a symbol
+ local function check_keys(sym)
+ local sym_keys_pattern = string.format("%s_keys", sym)
+ conn:add_cmd('SMEMBERS', { sym_keys_pattern })
+ local ret, keys = conn:exec()
+
+ if not ret then
+ rspamd_logger.errx("cannot execute command to get keys: %s", keys)
+ os.exit(1)
+ end
+
+ -- The set itself is emitted as a line only when JSON output is off
+ if not opts.json then
+ out[#out + 1] = string.format('"%s": %s\n', sym_keys_pattern,
+ ucl.to_format(keys, 'json-compact'))
+ end
+ for _, k in ipairs(keys) do
+ local pat = string.format('%s*_*', k)
+ if not patterns_seen[pat] then
+ -- The hash stored at the key itself is emitted as metadata
+ conn:add_cmd('HGETALL', { k })
+ local _ret, additional_keys = conn:exec()
+
+ if _ret then
+ if opts.json then
+ out[#out + 1] = string.format('{"pattern": "%s", "meta": %s, "elts": {\n',
+ k, ucl.to_format(redis_map_zip(additional_keys), 'json-compact'))
+ elseif opts.cdb then
+ -- CDB mode keeps per-pattern state under out[k]
+ out[k] = redis_map_zip(additional_keys)
+ out[k].elts = {}
+ else
+ out[#out + 1] = string.format('"%s": %s\n', k,
+ ucl.to_format(redis_map_zip(additional_keys), 'json-compact'))
+ end
+ dump_pattern(conn, pat, opts, out, k)
+ patterns_seen[pat] = true
+ end
+ end
+ end
+ end
+
+ check_keys(cls.symbol_spam)
+ check_keys(cls.symbol_ham)
+
+ -- Write whatever is left in the array part of the accumulator
+ if #out > 0 then
+ dump_out(out, opts, true)
+ end
+ end
+end
+
+--- Translate one parsed dump object into Redis commands.
+-- Only the first key/value pair of `obj` is looked at, and only when the key
+-- is a string and the value a table:
+--   * array value -> one `SADD key elt` per element;
+--   * map value, mode `replace` -> a single `HMSET`;
+--   * map value, other modes -> `HINCRBYFLOAT` for numeric fields (negated in
+--     `subtract` mode) and `HSET` for the rest.
+-- @param obj parsed object
+-- @param opts parsed command line options (`mode` is read)
+-- @param cmd_pipe array that receives `{ command, arguments }` pairs
+-- @return cmd_pipe
+local function obj_to_redis_arguments(obj, opts, cmd_pipe)
+  local key, value = next(obj)
+
+  if type(key) ~= 'string' or type(value) ~= 'table' then
+    return cmd_pipe
+  end
+
+  if value[1] then
+    -- Numeric table of elements (e.g. _keys) - it is actually a set in Redis
+    for _, member in ipairs(value) do
+      cmd_pipe[#cmd_pipe + 1] = { 'SADD', { key, member } }
+    end
+    return cmd_pipe
+  end
+
+  if opts.mode == 'replace' then
+    local params = { key }
+    for field, field_value in pairs(value) do
+      params[#params + 1] = field
+      params[#params + 1] = field_value
+    end
+    cmd_pipe[#cmd_pipe + 1] = { 'HMSET', params }
+    return cmd_pipe
+  end
+
+  local mult = 1.0
+  if opts.mode == 'subtract' then
+    mult = -1.0
+  end
+
+  for field, raw in pairs(value) do
+    local num = tonumber(raw)
+    if num then
+      cmd_pipe[#cmd_pipe + 1] = { 'HINCRBYFLOAT', { key, field, tostring(num * mult) } }
+    else
+      cmd_pipe[#cmd_pipe + 1] = { 'HSET', { key, field, raw } }
+    end
+  end
+
+  return cmd_pipe
+end
+
+--- Convert a batch of parsed objects to Redis commands and run them.
+-- With `opts.no_operation` the commands are only logged; otherwise the whole
+-- command list is queued and executed on every connection in `conns`.
+-- @param batch array of parsed objects
+-- @param conns array of synchronous Redis connections
+-- @param opts parsed command line options
+local function execute_batch(batch, conns, opts)
+  local cmd_pipe = {}
+  for _, obj in ipairs(batch) do
+    obj_to_redis_arguments(obj, opts, cmd_pipe)
+  end
+
+  if opts.no_operation then
+    for _, cmd in ipairs(cmd_pipe) do
+      local name, params = cmd[1], cmd[2]
+      rspamd_logger.messagex('%s %s', name, table.concat(params, ' '))
+    end
+    return
+  end
+
+  for _, conn in ipairs(conns) do
+    for _, cmd in ipairs(cmd_pipe) do
+      local name, params = cmd[1], cmd[2]
+      local is_ok, err = conn:add_cmd(name, params)
+
+      if not is_ok then
+        rspamd_logger.errx("cannot add command: %s with args: %s: %s", name, params, err)
+      end
+    end
+
+    conn:exec()
+  end
+end
+
+--- Handle the `restore` sub-command.
+-- Reads dump lines from the given files (or stdin when no file or `-` is
+-- given), parses each line with UCL and replays it against every configured
+-- Redis classifier in batches of `opts.batch_size` objects.
+-- @param opts parsed command line options
+local function restore_handler(opts)
+  local files = opts.file
+  -- An empty list (argparse yields `{}` for `*` arguments) means stdin
+  if not files or #files == 0 then
+    files = { '-' }
+  end
+
+  local conns = {}
+
+  for _, cls in ipairs(classifiers) do
+    local res, conn = lua_redis.redis_connect_sync(cls.redis_params, true)
+
+    if not res then
+      rspamd_logger.errx("cannot connect to redis server: %s", cls.redis_params)
+      os.exit(1)
+    end
+
+    table.insert(conns, conn)
+  end
+
+  local batch = {}
+
+  for _, f in ipairs(files) do
+    local fd
+    if f == '-' then
+      fd = io.stdin
+    else
+      local open_err
+      fd, open_err = io.open(f, 'r')
+      if not fd then
+        rspamd_logger.errx("cannot open %s: %s", f, open_err)
+        os.exit(1)
+      end
+    end
+
+    local cur_line = 1
+    -- Iterate over the handle itself so the default input is never changed
+    for line in fd:lines() do
+      local ucl_parser = ucl.parser()
+      local res, err
+      res, err = ucl_parser:parse_string(line)
+
+      if not res then
+        rspamd_logger.errx("%s: cannot read line %s: %s", f, cur_line, err)
+        os.exit(1)
+      end
+
+      table.insert(batch, ucl_parser:get_object())
+      cur_line = cur_line + 1
+
+      -- Flush on the batch size, not the per-file line number, so that many
+      -- small files cannot grow the batch without bound
+      if #batch >= opts.batch_size then
+        execute_batch(batch, conns, opts)
+        batch = {}
+      end
+    end
+
+    if fd ~= io.stdin then
+      fd:close()
+    end
+  end
+
+  if #batch > 0 then
+    execute_batch(batch, conns, opts)
+  end
+end
+
+--- Entry point of `rspamadm statistics_dump`.
+-- Loads the configuration, registers every Redis-backed classifier found in
+-- the `classifier` section and runs the requested sub-command (`dump` when
+-- none is given).
+-- @param args raw command line arguments
+local function handler(args)
+  local opts = parser:parse(args)
+
+  local command = opts.command or 'dump'
+
+  load_config(opts)
+  rspamd_config:init_subsystem('stat')
+
+  local obj = rspamd_config:get_ucl()
+
+  local classifier = obj.classifier
+
+  if classifier then
+    if classifier[1] then
+      -- Array of classifiers, each optionally wrapped in a `bayes` key
+      for _, cls in ipairs(classifier) do
+        if cls.bayes then
+          cls = cls.bayes
+        end
+        if cls.backend and cls.backend == 'redis' then
+          check_redis_classifier(cls, obj)
+        end
+      end
+    else
+      if classifier.bayes then
+
+        classifier = classifier.bayes
+        if classifier[1] then
+          for _, cls in ipairs(classifier) do
+            if cls.backend and cls.backend == 'redis' then
+              check_redis_classifier(cls, obj)
+            end
+          end
+        else
+          if classifier.backend and classifier.backend == 'redis' then
+            check_redis_classifier(classifier, obj)
+          end
+        end
+      end
+    end
+  end
+
+  -- Normalise `file` to a list; type() reports a missing value as 'nil'
+  if type(opts.file) == 'string' then
+    opts.file = { opts.file }
+  elseif type(opts.file) == 'nil' then
+    opts.file = {}
+  end
+
+  if command == 'dump' then
+    dump_handler(opts)
+  elseif command == 'restore' then
+    restore_handler(opts)
+  else
+    -- parser:error takes a ready-made message, so format it here
+    parser:error(string.format('command %s is not implemented', command))
+  end
+end
+
+-- Module export: command name, its aliases, the entry point and the
+-- description taken from the parser
+return {
+ name = 'statistics_dump',
+ aliases = { 'stat_dump', 'bayes_dump' },
+ handler = handler,
+ description = parser._description
+} \ No newline at end of file
diff --git a/lualib/rspamadm/template.lua b/lualib/rspamadm/template.lua
new file mode 100644
index 0000000..ca1779a
--- /dev/null
+++ b/lualib/rspamadm/template.lua
@@ -0,0 +1,131 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+
+
+-- Define command line options
+local parser = argparse()
+ :name "rspamadm template"
+ :description "Apply jinja templates for strings/files"
+ :help_description_margin(30)
+-- Zero or more input files
+parser:argument "file"
+ :description "File to process"
+ :argname "<file>"
+ :args "*"
+
+parser:flag "-n --no-vars"
+ :description "Don't add Rspamd internal variables"
+-- Both environment options may be given any number of times
+parser:option "-e --env"
+ :description "Load additional environment vars from specific file (name=value)"
+ :argname "<filename>"
+ :count "*"
+parser:option "-l --lua-env"
+ :description "Load additional environment vars from specific file (lua source)"
+ :argname "<filename>"
+ :count "*"
+-- Output destination: a suffixed copy or in-place replacement, never both
+parser:mutex(
+ parser:option "-s --suffix"
+ :description "Store files with the new suffix"
+ :argname "<suffix>",
+ parser:flag "-i --inplace"
+ :description "Replace input file(s)"
+)
+
+local lua_util = require "lua_util"
+
+--- Fill `env` with template variables loaded from the files given on the
+-- command line.
+-- `opts.env` files hold `name=value` lines (blank lines and lines starting
+-- with `#` are skipped); `opts.lua_env` files are Lua chunks that must return
+-- a table, whose entries are deep-copied into `env`.
+-- Problems with individual lines/files are reported on stdout; a missing
+-- `--env` file raises an error.
+-- @param opts parsed command line options
+-- @param env table that receives the variables
+local function set_env(opts, env)
+  if opts.env then
+    for _, fname in ipairs(opts.env) do
+      local fh = assert(io.open(fname))
+
+      for kv in fh:lines() do
+        -- Anchor the comment check: a `#` inside a value is not a comment
+        if not kv:match('^%s*#') and not kv:match('^%s*$') then
+          local k, v = kv:match('([^=%s]+)%s*=%s*(.+)')
+
+          if k and v then
+            env[k] = v
+          else
+            io.write(string.format('invalid env line in %s: %s\n', fname, kv))
+          end
+        end
+      end
+
+      fh:close()
+    end
+  end
+
+  if opts.lua_env then
+    for _, fname in ipairs(opts.lua_env) do
+      -- loadfile returns nil plus a message on failure; check it before
+      -- calling so the real reason is reported
+      local chunk, load_err = loadfile(fname)
+
+      if not chunk then
+        io.write(string.format('cannot load %s: %s\n', fname, load_err))
+      else
+        local ret, res_or_err = pcall(chunk)
+
+        if not ret then
+          io.write(string.format('cannot load %s: %s\n', fname, res_or_err))
+        elseif type(res_or_err) == 'table' then
+          for k, v in pairs(res_or_err) do
+            env[k] = lua_util.deepcopy(v)
+          end
+        else
+          io.write(string.format('cannot load %s: not a table\n', fname))
+        end
+      end
+    end
+  end
+end
+
+--- Read the whole content of a file, or of the default input when the name is `-`.
+-- Raises an error when the file cannot be opened.
+-- @param file file name or `-`
+-- @return the content as a string
+local function read_file(file)
+  if file == '-' then
+    local from_input = io.read("*all")
+    return from_input
+  end
+
+  local handle = assert(io.open(file, "rb"))
+  local content = handle:read("*all")
+  handle:close()
+
+  return content
+end
+
+--- Entry point of `rspamadm template`.
+-- Renders every input file (stdin when none is given) through the jinja
+-- template engine and writes the result back in place, to `<file>.<suffix>`,
+-- or to stdout.
+-- @param args raw command line arguments
+local function handler(args)
+  local opts = parser:parse(args)
+  local env = {}
+  set_env(opts, env)
+
+  if not opts.file or #opts.file == 0 then
+    opts.file = { '-' }
+  end
+  for _, fname in ipairs(opts.file) do
+    local content = read_file(fname)
+    local res = lua_util.jinja_template(content, env, opts.no_vars)
+
+    if opts.inplace then
+      -- Write the rendered text to a temporary sibling, then replace the input
+      local nfile = string.format('%s.new', fname)
+      local out = assert(io.open(nfile, 'w'))
+      out:write(res)
+      out:close()
+      assert(os.rename(nfile, fname))
+    elseif opts.suffix then
+      local nfile = string.format('%s.%s', fname, opts.suffix)
+      local out = assert(io.open(nfile, 'w'))
+      out:write(res)
+      out:close()
+    else
+      io.write(res)
+    end
+  end
+end
+
+-- Module export: entry point, description (taken from the parser) and command name
+return {
+ handler = handler,
+ description = parser._description,
+ name = 'template'
+} \ No newline at end of file
diff --git a/lualib/rspamadm/vault.lua b/lualib/rspamadm/vault.lua
new file mode 100644
index 0000000..840e504
--- /dev/null
+++ b/lualib/rspamadm/vault.lua
@@ -0,0 +1,579 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+
+local rspamd_logger = require "rspamd_logger"
+local ansicolors = require "ansicolors"
+local ucl = require "ucl"
+local argparse = require "argparse"
+local fun = require "fun"
+local rspamd_http = require "rspamd_http"
+local cr = require "rspamd_cryptobox"
+
+-- Command line parser; a sub-command is mandatory and its name is stored
+-- under the `command` key of the parse result
+local parser = argparse()
+ :name "rspamadm vault"
+ :description "Perform Hashicorp Vault management"
+ :help_description_margin(32)
+ :command_target("command")
+ :require_command(true)
+
+-- Options shared by all sub-commands
+parser:flag "-s --silent"
+ :description "Do not output extra information"
+parser:option "-a --addr"
+ :description "Vault address (if not defined in VAULT_ADDR env)"
+parser:option "-t --token"
+ :description "Vault token (not recommended, better define VAULT_TOKEN env)"
+parser:option "-p --path"
+ :description "Path to work with in the vault"
+ :default "dkim"
+-- The converter table restricts the value to the listed formats
+parser:option "-o --output"
+ :description "Output format ('ucl', 'json', 'json-compact', 'yaml')"
+ :argname("<type>")
+ :convert {
+ ucl = "ucl",
+ json = "json",
+ ['json-compact'] = "json-compact",
+ yaml = "yaml",
+}
+ :default "ucl"
+
+parser:command "list ls l"
+ :description "List elements in the vault"
+
+-- `show` sub-command: print the selectors stored for one or more domains
+local show = parser:command "show get"
+  :description "Extract element from the vault"
+show:argument "domain"
+  :description "Domain to show key(s) for"
+  :args "+"
+
+-- `delete` sub-command: remove the keys stored for one or more domains
+local delete = parser:command "delete del rm remove"
+  :description "Delete element from the vault"
+delete:argument "domain"
+  :description "Domain to delete key(s) for"
+  :args "+"
+
+-- `newkey` sub-command: generate a key for one or more domains
+local newkey = parser:command "newkey new create"
+ :description "Add new key to the vault"
+newkey:argument "domain"
+ :description "Domain to create key for"
+ :args "+"
+-- NOTE(review): `-s` is also the short form of the global `--silent` flag — confirm which one wins here
+newkey:option "-s --selector"
+ :description "Selector to use"
+ :count "?"
+-- `eddsa` is accepted as a synonym of `ed25519`
+newkey:option "-A --algorithm"
+ :argname("<type>")
+ :convert {
+ rsa = "rsa",
+ ed25519 = "ed25519",
+ eddsa = "ed25519",
+}
+ :default "rsa"
+newkey:option "-b --bits"
+ :argname("<nbits>")
+ :convert(tonumber)
+ :default "1024"
+-- Key validity period in days
+newkey:option "-x --expire"
+ :argname("<days>")
+ :convert(tonumber)
+newkey:flag "-r --rewrite"
+
+-- `roll` sub-command: key rollover for one or more domains
+local roll = parser:command "roll rollover"
+ :description "Perform keys rollover"
+roll:argument "domain"
+ :description "Domain to roll key(s) for"
+ :args "+"
+roll:option "-T --ttl"
+ :description "Validity period for old keys (days)"
+ :convert(tonumber)
+ :default "1"
+roll:flag "-r --remove-expired"
+ :description "Remove expired keys"
+-- Validity period in days of the newly generated keys
+roll:option "-x --expire"
+ :argname("<days>")
+ :convert(tonumber)
+
+--- Write a formatted line to stdout; with no format only a newline is written.
+-- @param fmt format string understood by `rspamd_logger.slog` (may be nil)
+-- @param ... format arguments
+local function printf(fmt, ...)
+  if not fmt then
+    io.write('\n')
+    return
+  end
+
+  io.write(rspamd_logger.slog(fmt, ...))
+  io.write('\n')
+end
+
+--- Same as `printf`, but prints nothing when `opts.silent` is set.
+-- @param opts parsed command line options
+-- @param fmt format string
+-- @param ... format arguments
+local function maybe_printf(opts, fmt, ...)
+  if opts.silent then
+    return
+  end
+
+  printf(fmt, ...)
+end
+
+--- Wrap a string into the ANSI sequence of the given colour followed by a reset.
+-- @param str text to colour
+-- @param color key of the `ansicolors` table; `white` when omitted
+-- @return the decorated string
+local function highlight(str, color)
+  local color_code = ansicolors[color or 'white']
+
+  return color_code .. str .. ansicolors.reset
+end
+
+--- Build a Vault API URL: `<addr>/v1/<path>` with an optional trailing element.
+-- @param opts parsed command line options (`addr` and `path` are read)
+-- @param path optional element appended to the base URL
+-- @return the URL string
+local function vault_url(opts, path)
+  if not path then
+    return string.format('%s/v1/%s', opts.addr, opts.path)
+  end
+
+  return string.format('%s/v1/%s/%s', opts.addr, opts.path, path)
+end
+
+--- Tell whether an HTTP request failed.
+-- @param err error value of the request, if any
+-- @param data response table; its `code` is read only when `err` is not set
+-- @return `err` itself when set, otherwise true unless the code is 2xx
+local function is_http_error(err, data)
+  if err then
+    return err
+  end
+
+  local status_class = math.floor(data.code / 100)
+  return status_class ~= 2
+end
+
+--- Parse a Vault reply with the UCL parser.
+-- @param data reply body as a string
+-- @return the parsed object and nil on success; nil and the parser error otherwise
+local function parse_vault_reply(data)
+  local reply_parser = ucl.parser()
+  local ok, parser_err = reply_parser:parse_string(data)
+
+  if ok then
+    return reply_parser:get_object(), nil
+  end
+
+  return nil, parser_err
+end
+
+--- Print a Vault reply in the output format selected on the command line.
+-- @param opts parsed command line options (`output` is read)
+-- @param data reply body; a notice is printed when it is missing
+-- @param func optional function applied to the parsed reply to pick the part to print
+local function maybe_print_vault_data(opts, data, func)
+  if not data then
+    printf('no data received')
+    return
+  end
+
+  local res, parser_err = parse_vault_reply(data)
+  if not res then
+    printf('vault reply for cannot be parsed: %s', parser_err)
+    return
+  end
+
+  local to_print = res
+  if func then
+    to_print = func(res)
+  end
+
+  printf(ucl.to_format(to_print, opts.output))
+end
+
+--- Print a DKIM public key as a DNS TXT record.
+-- The record text (`v=DKIM1; k=<alg>; p=<key>`) is split into quoted strings
+-- of at most 255 bytes, which is the longest character-string a TXT record
+-- can carry.
+-- @param b64 base64 encoded public key
+-- @param selector DKIM selector
+-- @param alg key algorithm placed into the `k=` tag
+local function print_dkim_txt_record(b64, selector, alg)
+  local max_chunk = 255
+  local labels = {}
+  local prefix = string.format("v=DKIM1; k=%s; p=", alg)
+  b64 = prefix .. b64
+
+  -- A record that fits into one string yields exactly one chunk
+  for sl = 1, #b64, max_chunk do
+    labels[#labels + 1] = '"' .. b64:sub(sl, sl + max_chunk - 1) .. '"'
+  end
+
+  printf("%s._domainkey IN TXT ( %s )", selector,
+      table.concat(labels, "\n\t"))
+end
+
+--- Handle `show`: fetch and print the selectors stored for a domain.
+-- Exits with status 1 when the request fails.
+-- @param opts parsed command line options
+-- @param domain domain to look up
+local function show_handler(opts, domain)
+  local uri = vault_url(opts, domain)
+  local err, data = rspamd_http.request {
+    config = rspamd_config,
+    ev_base = rspamadm_ev_base,
+    session = rspamadm_session,
+    resolver = rspamadm_dns_resolver,
+    url = uri,
+    headers = {
+      ['X-Vault-Token'] = opts.token
+    }
+  }
+
+  if is_http_error(err, data) then
+    -- There may be no response object at all when the request itself failed
+    printf('cannot get request to the vault (%s), HTTP error code %s', uri,
+        data and data.code or 'unknown')
+    if err then
+      printf('error: %s', err)
+    else
+      maybe_print_vault_data(opts, data.content)
+    end
+    os.exit(1)
+  else
+    maybe_print_vault_data(opts, data.content, function(obj)
+      return obj.data.selectors
+    end)
+  end
+end
+
+--- Handle `delete`: remove the entry stored for a domain.
+-- Exits with status 1 when the request fails.
+-- @param opts parsed command line options
+-- @param domain domain whose keys are deleted
+local function delete_handler(opts, domain)
+  local uri = vault_url(opts, domain)
+  local err, data = rspamd_http.request {
+    config = rspamd_config,
+    ev_base = rspamadm_ev_base,
+    session = rspamadm_session,
+    resolver = rspamadm_dns_resolver,
+    url = uri,
+    method = 'delete',
+    headers = {
+      ['X-Vault-Token'] = opts.token
+    }
+  }
+
+  if is_http_error(err, data) then
+    -- There may be no response object at all when the request itself failed
+    printf('cannot get request to the vault (%s), HTTP error code %s', uri,
+        data and data.code or 'unknown')
+    if err then
+      printf('error: %s', err)
+    else
+      maybe_print_vault_data(opts, data.content)
+    end
+    os.exit(1)
+  else
+    printf('deleted key(s) for %s', domain)
+  end
+end
+
+--- Handle `list`: print the keys found under the configured vault path.
+-- Exits with status 1 when the request fails.
+-- @param opts parsed command line options
+local function list_handler(opts)
+  local uri = vault_url(opts)
+  local err, data = rspamd_http.request {
+    config = rspamd_config,
+    ev_base = rspamadm_ev_base,
+    session = rspamadm_session,
+    resolver = rspamadm_dns_resolver,
+    url = uri .. '?list=true',
+    headers = {
+      ['X-Vault-Token'] = opts.token
+    }
+  }
+
+  if is_http_error(err, data) then
+    -- There may be no response object at all when the request itself failed
+    printf('cannot get request to the vault (%s), HTTP error code %s', uri,
+        data and data.code or 'unknown')
+    if err then
+      printf('error: %s', err)
+    else
+      maybe_print_vault_data(opts, data.content)
+    end
+    os.exit(1)
+  else
+    maybe_print_vault_data(opts, data.content, function(obj)
+      return obj.data.keys
+    end)
+  end
+end
+
+-- Returns pair privkey+pubkey
+local function genkey(opts)
+ return cr.gen_dkim_keypair(opts.algorithm, opts.bits)
+end
+
+--- Generate a new DKIM key for a domain and store it in the vault.
+-- The new selector is placed first, followed by the selectors passed in
+-- `existing`. On success the public key is printed (as a DNS record unless
+-- `opts.silent` is set); on failure the process exits with status 1.
+-- @param opts parsed command line options (`selector`, `algorithm`, `bits`,
+--   `expire`, `token` and `silent` are read)
+-- @param domain domain the key is created for
+-- @param existing array of selectors that are already stored
+local function create_and_push_key(opts, domain, existing)
+  local uri = vault_url(opts, domain)
+  local sk, pk = genkey(opts)
+
+  local res = {
+    selectors = {
+      [1] = {
+        selector = opts.selector,
+        domain = domain,
+        key = tostring(sk),
+        pubkey = tostring(pk),
+        alg = opts.algorithm,
+        bits = opts.bits or 0,
+        valid_start = os.time(),
+      }
+    }
+  }
+
+  for _, sel in ipairs(existing) do
+    res.selectors[#res.selectors + 1] = sel
+  end
+
+  -- `expire` is given in days
+  if opts.expire then
+    res.selectors[1].valid_end = os.time() + opts.expire * 3600 * 24
+  end
+
+  local err, data = rspamd_http.request {
+    config = rspamd_config,
+    ev_base = rspamadm_ev_base,
+    session = rspamadm_session,
+    resolver = rspamadm_dns_resolver,
+    url = uri,
+    method = 'put',
+    headers = {
+      ['Content-Type'] = 'application/json',
+      ['X-Vault-Token'] = opts.token
+    },
+    body = {
+      ucl.to_format(res, 'json-compact')
+    },
+  }
+
+  if is_http_error(err, data) then
+    -- There may be no response object at all when the request itself failed
+    printf('cannot get request to the vault (%s), HTTP error code %s', uri,
+        data and data.code or 'unknown')
+    if err then
+      printf('error: %s', err)
+    else
+      maybe_print_vault_data(opts, data.content)
+    end
+    os.exit(1)
+  else
+    maybe_printf(opts, 'stored key for: %s, selector: %s', domain, opts.selector)
+    maybe_printf(opts, 'please place the corresponding public key as following:')
+
+    if opts.silent then
+      printf('%s', pk)
+    else
+      print_dkim_txt_record(tostring(pk), opts.selector, opts.algorithm)
+    end
+  end
+end
+
+--- Handle `newkey`: create a key for a domain unless a key with the same
+-- algorithm is already stored for it.
+-- When no selector is given, `<algorithm>-<YYYYMMDD>` (UTC date) is used.
+-- @param opts parsed command line options
+-- @param domain domain the key is created for
+local function newkey_handler(opts, domain)
+  local uri = vault_url(opts, domain)
+
+  if not opts.selector then
+    opts.selector = string.format('%s-%s', opts.algorithm,
+        os.date("!%Y%m%d"))
+  end
+
+  local err, data = rspamd_http.request {
+    config = rspamd_config,
+    ev_base = rspamadm_ev_base,
+    session = rspamadm_session,
+    resolver = rspamadm_dns_resolver,
+    url = uri,
+    method = 'get',
+    headers = {
+      ['X-Vault-Token'] = opts.token
+    }
+  }
+
+  if is_http_error(err, data) or not data.content then
+    -- Nothing is stored for this domain yet
+    create_and_push_key(opts, domain, {})
+  else
+    -- Key exists
+    local rep = parse_vault_reply(data.content)
+
+    if not rep or not rep.data then
+      printf('cannot parse reply for %s: %s', uri, data.content)
+      os.exit(1)
+    end
+
+    local elts = rep.data.selectors
+
+    if not elts then
+      create_and_push_key(opts, domain, {})
+      os.exit(0)
+    end
+
+    -- Inspect every stored selector before touching the vault, so that a
+    -- conflict is detected before anything is written
+    for _, sel in ipairs(elts) do
+      if sel.alg == opts.algorithm then
+        printf('key with the specific algorithm %s is already presented at %s selector for %s domain',
+            opts.algorithm, sel.selector, domain)
+        os.exit(1)
+      end
+    end
+
+    -- No conflict: add exactly one new key next to the existing selectors
+    create_and_push_key(opts, domain, elts)
+  end
+end
+
+--- Handle `roll`: perform key rollover for a domain.
+-- Fetches the stored selectors, gives every still valid key an expiration of
+-- `opts.ttl` days from now, keeps or drops already expired keys depending on
+-- `opts.remove_expired`, generates a new key per algorithm (unless only
+-- removing expired keys) and stores the resulting list in the vault.
+-- Exits with status 1 on any failure.
+-- @param opts parsed command line options
+-- @param domain domain to roll keys for
+local function roll_handler(opts, domain)
+  local uri = vault_url(opts, domain)
+  local res = {
+    selectors = {}
+  }
+
+  local err, data = rspamd_http.request {
+    config = rspamd_config,
+    ev_base = rspamadm_ev_base,
+    session = rspamadm_session,
+    resolver = rspamadm_dns_resolver,
+    url = uri,
+    method = 'get',
+    headers = {
+      ['X-Vault-Token'] = opts.token
+    }
+  }
+
+  if is_http_error(err, data) or not data.content then
+    printf("No keys to roll for domain %s", domain)
+    os.exit(1)
+  else
+    local rep = parse_vault_reply(data.content)
+
+    if not rep or not rep.data then
+      printf('cannot parse reply for %s: %s', uri, data.content)
+      os.exit(1)
+    end
+
+    local elts = rep.data.selectors
+
+    if not elts then
+      printf("No keys to roll for domain %s", domain)
+      os.exit(1)
+    end
+
+    local nkeys = {} -- indexed by algorithm
+
+    -- Group a selector by its algorithm, optionally stamping an expiration
+    -- of `ttl` days from now
+    local function insert_key(sel, add_expire)
+      if not nkeys[sel.alg] then
+        nkeys[sel.alg] = {}
+      end
+
+      if add_expire then
+        sel.valid_end = os.time() + opts.ttl * 3600 * 24
+      end
+
+      table.insert(nkeys[sel.alg], sel)
+    end
+
+    for _, sel in ipairs(elts) do
+      if sel.valid_end and sel.valid_end < os.time() then
+        -- Already expired: keep untouched or drop
+        if not opts.remove_expired then
+          insert_key(sel, false)
+        else
+          maybe_printf(opts, 'removed expired key for %s (selector %s, expire "%s"',
+              domain, sel.selector, os.date('%c', sel.valid_end))
+        end
+      else
+        insert_key(sel, true)
+      end
+    end
+
+    -- Now we need to ensure that all but one selectors have either expired or just a single key
+    for alg, keys in pairs(nkeys) do
+      -- Latest expiration first; keys without an expiration go last
+      table.sort(keys, function(k1, k2)
+        if k1.valid_end and k2.valid_end then
+          return k1.valid_end > k2.valid_end
+        elseif k1.valid_end then
+          return true
+        elseif k2.valid_end then
+          return false
+        end
+        return false
+      end)
+      -- Exclude the key with the highest expiration date and examine the rest
+      if not (#keys == 1 or fun.all(function(k)
+        return k.valid_end and k.valid_end < os.time()
+      end, fun.tail(keys))) then
+        printf('bad keys list for %s and %s algorithm', domain, alg)
+        fun.each(function(k)
+          if not k.valid_end then
+            printf('selector %s, algorithm %s has a key with no expire',
+                k.selector, k.alg)
+          elseif k.valid_end >= os.time() then
+            printf('selector %s, algorithm %s has a key that not yet expired: %s',
+                k.selector, k.alg, os.date('%c', k.valid_end))
+          end
+        end, fun.tail(keys))
+        os.exit(1)
+      end
+      -- Do not create new keys, if we only want to remove expired keys
+      if not opts.remove_expired then
+        -- OK to process
+        -- Insert keys for each algorithm in pairs <old_key(s)>, <new_key>
+        local sk, pk = genkey({ algorithm = alg, bits = keys[1].bits })
+        local selector = string.format('%s-%s', alg,
+            os.date("!%Y%m%d"))
+
+        -- Avoid reusing the selector of the most recent key
+        if selector == keys[1].selector then
+          selector = selector .. '-1'
+        end
+        local nelt = {
+          selector = selector,
+          domain = domain,
+          key = tostring(sk),
+          pubkey = tostring(pk),
+          alg = alg,
+          bits = keys[1].bits,
+          valid_start = os.time(),
+        }
+
+        if opts.expire then
+          nelt.valid_end = os.time() + opts.expire * 3600 * 24
+        end
+
+        table.insert(res.selectors, nelt)
+      end
+      for _, k in ipairs(keys) do
+        table.insert(res.selectors, k)
+      end
+    end
+  end
+
+  -- We can now store res in the vault
+  err, data = rspamd_http.request {
+    config = rspamd_config,
+    ev_base = rspamadm_ev_base,
+    session = rspamadm_session,
+    resolver = rspamadm_dns_resolver,
+    url = uri,
+    method = 'put',
+    headers = {
+      ['Content-Type'] = 'application/json',
+      ['X-Vault-Token'] = opts.token
+    },
+    body = {
+      ucl.to_format(res, 'json-compact')
+    },
+  }
+
+  if is_http_error(err, data) then
+    -- There may be no response object at all when the request itself failed
+    printf('cannot put request to the vault (%s), HTTP error code %s', uri,
+        data and data.code or 'unknown')
+    if err then
+      printf('error: %s', err)
+    else
+      maybe_print_vault_data(opts, data.content)
+    end
+    os.exit(1)
+  else
+    for _, key in ipairs(res.selectors) do
+      -- Old keys expire no later than now + ttl, so anything without an
+      -- expiration or expiring later than that is a freshly created key
+      if not key.valid_end or key.valid_end > os.time() + opts.ttl * 3600 * 24 then
+        maybe_printf(opts, 'rolled key for: %s, new selector: %s', domain, key.selector)
+        maybe_printf(opts, 'please place the corresponding public key as following:')
+
+        if opts.silent then
+          printf('%s', key.pubkey)
+        else
+          print_dkim_txt_record(key.pubkey, key.selector, key.alg)
+        end
+
+      end
+    end
+
+    maybe_printf(opts, 'your old keys will be valid until %s',
+        os.date('%c', os.time() + opts.ttl * 3600 * 24))
+  end
+end
+
+-- Entry point of the vault command.
+-- Parses `args` with the file-level `parser`, fills the vault address and
+-- token from the environment when they were not passed as options, and then
+-- runs the requested sub-command: `list` once, every other command once per
+-- entry of `opts.domain`.
+local function handler(args)
+  local opts = parser:parse(args)
+
+  -- An explicitly passed address wins; VAULT_ADDR is only the fallback
+  opts.addr = opts.addr or os.getenv('VAULT_ADDR')
+
+  if opts.token then
+    -- The token came from the command line: keep it, but warn about it
+    maybe_printf(opts, 'defining token via command line is insecure, define it via environment variable %s',
+        highlight('VAULT_TOKEN', 'red'))
+  else
+    opts.token = os.getenv('VAULT_TOKEN')
+  end
+
+  -- Both the token and the address are required to go any further
+  if not (opts.token and opts.addr) then
+    printf('no token or/and vault addr has been specified, exiting')
+    os.exit(1)
+  end
+
+  local command = opts.command
+
+  -- `list` is the only command that is not applied per domain
+  if command == 'list' then
+    list_handler(opts)
+    return
+  end
+
+  -- Commands that are invoked as fn(opts, domain) for each requested domain
+  local per_domain = {
+    show = show_handler,
+    newkey = newkey_handler,
+    roll = roll_handler,
+    delete = delete_handler,
+  }
+
+  local run = per_domain[command]
+
+  if not run then
+    parser:error(string.format('command %s is not implemented', command))
+    return
+  end
+
+  fun.each(function(domain)
+    run(opts, domain)
+  end, opts.domain)
+end
+
+return { -- table returned as the value of this file
+ handler = handler, -- the `handler` function defined above (parses args, dispatches the sub-command)
+ description = parser._description, -- description text read from the argument parser object
+ name = 'vault' -- NOTE(review): presumably the rspamadm sub-command name; confirm against the loader
+}