diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-10 21:30:40 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-10 21:30:40 +0000 |
commit | 133a45c109da5310add55824db21af5239951f93 (patch) | |
tree | ba6ac4c0a950a0dda56451944315d66409923918 /lualib/rspamadm | |
parent | Initial commit. (diff) | |
download | rspamd-upstream.tar.xz rspamd-upstream.zip |
Adding upstream version 3.8.1.upstream/3.8.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'lualib/rspamadm')
-rw-r--r-- | lualib/rspamadm/clickhouse.lua | 528 | ||||
-rw-r--r-- | lualib/rspamadm/configgraph.lua | 172 | ||||
-rw-r--r-- | lualib/rspamadm/confighelp.lua | 123 | ||||
-rw-r--r-- | lualib/rspamadm/configwizard.lua | 849 | ||||
-rw-r--r-- | lualib/rspamadm/cookie.lua | 125 | ||||
-rw-r--r-- | lualib/rspamadm/corpus_test.lua | 185 | ||||
-rw-r--r-- | lualib/rspamadm/dkim_keygen.lua | 178 | ||||
-rw-r--r-- | lualib/rspamadm/dmarc_report.lua | 737 | ||||
-rw-r--r-- | lualib/rspamadm/dns_tool.lua | 232 | ||||
-rw-r--r-- | lualib/rspamadm/fuzzy_convert.lua | 208 | ||||
-rw-r--r-- | lualib/rspamadm/fuzzy_ping.lua | 259 | ||||
-rw-r--r-- | lualib/rspamadm/fuzzy_stat.lua | 366 | ||||
-rw-r--r-- | lualib/rspamadm/grep.lua | 174 | ||||
-rw-r--r-- | lualib/rspamadm/keypair.lua | 508 | ||||
-rw-r--r-- | lualib/rspamadm/mime.lua | 1012 | ||||
-rw-r--r-- | lualib/rspamadm/neural_test.lua | 228 | ||||
-rw-r--r-- | lualib/rspamadm/publicsuffix.lua | 82 | ||||
-rw-r--r-- | lualib/rspamadm/stat_convert.lua | 38 | ||||
-rw-r--r-- | lualib/rspamadm/statistics_dump.lua | 544 | ||||
-rw-r--r-- | lualib/rspamadm/template.lua | 131 | ||||
-rw-r--r-- | lualib/rspamadm/vault.lua | 579 |
21 files changed, 7258 insertions, 0 deletions
diff --git a/lualib/rspamadm/clickhouse.lua b/lualib/rspamadm/clickhouse.lua new file mode 100644 index 0000000..b22d800 --- /dev/null +++ b/lualib/rspamadm/clickhouse.lua @@ -0,0 +1,528 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" +local lua_clickhouse = require "lua_clickhouse" +local lua_util = require "lua_util" +local rspamd_http = require "rspamd_http" +local rspamd_upstream_list = require "rspamd_upstream_list" +local rspamd_logger = require "rspamd_logger" +local ucl = require "ucl" + +local E = {} + +-- Define command line options +local parser = argparse() + :name 'rspamadm clickhouse' + :description 'Retrieve information from Clickhouse' + :help_description_margin(30) + :command_target('command') + :require_command(true) + +parser:option '-c --config' + :description 'Path to config file' + :argname('config_file') + :default(rspamd_paths['CONFDIR'] .. '/rspamd.conf') +parser:option '-d --database' + :description 'Name of Clickhouse database to use' + :argname('database') + :default('default') +parser:flag '--no-ssl-verify' + :description 'Disable SSL verification' + :argname('no_ssl_verify') +parser:mutex( + parser:option '-p --password' + :description 'Password to use for Clickhouse' + :argname('password'), + parser:flag '-a --ask-password' + :description 'Ask password from the terminal' + :argname('ask_password') +) +parser:option '-s --server' + :description 'Address[:port] to connect to Clickhouse with' + :argname('server') +parser:option '-u --user' + :description 'Username to use for Clickhouse' + :argname('user') +parser:option '--use-gzip' + :description 'Use Gzip with Clickhouse' + :argname('use_gzip') + :default(true) +parser:flag '--use-https' + :description 'Use HTTPS with Clickhouse' + :argname('use_https') + +local neural_profile = parser:command 'neural_profile' + :description 'Generate symbols profile using data from Clickhouse' +neural_profile:option '-w --where' + :description 'WHERE clause for Clickhouse query' + :argname('where') +neural_profile:flag '-j --json' + :description 'Write output as JSON' + :argname('json') +neural_profile:option '--days' + :description 'Number of days to collect stats for' + :argname('days') + :default('7') +neural_profile:option '--limit -l' + :description 'Maximum rows to fetch per day' + :argname('limit') +neural_profile:option '--settings-id' + :description 'Settings ID to query' + :argname('settings_id') + :default('') + +local neural_train = parser:command 'neural_train' + :description 'Train neural using data from Clickhouse' +neural_train:option '--days' + :description 'Number of days to query data for' + :argname('days') + :default('7') +neural_train:option '--column-name-digest' + :description 'Name of neural profile digest column in Clickhouse' + :argname('column_name_digest') + :default('NeuralDigest') +neural_train:option '--column-name-vector' + :description 'Name of neural training vector column in Clickhouse' + :argname('column_name_vector') + :default('NeuralMpack') +neural_train:option '--limit -l' + :description 'Maximum rows to fetch per day' + :argname('limit') +neural_train:option '--profile -p' + :description 'Profile to use for training' + :argname('profile') + :default('default') +neural_train:option '--rule -r' + :description 'Rule to train' + :argname('rule') + :default('default') +neural_train:option '--spam -s' + :description 'WHERE clause to use for spam' + :argname('spam') + :default("Action == 'reject'") +neural_train:option '--ham -h' + :description 'WHERE clause to use for ham' + :argname('ham') + :default('Score < 0') +neural_train:option '--url -u' + :description 'URL to use for training' + :argname('url') + :default('http://127.0.0.1:11334/plugins/neural/learn') + +local http_params = { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + resolver = rspamadm_dns_resolver, +} + +local function load_config(config_file) + local _r, err = rspamd_config:load_ucl(config_file) + + if not _r then + rspamd_logger.errx('cannot load %s: %s', config_file, err) + os.exit(1) + end + + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + rspamd_logger.errx('cannot process %s: %s', config_file, err) + os.exit(1) + end + + if not rspamd_config:init_modules() then + rspamd_logger.errx('cannot init modules when parsing %s', config_file) + os.exit(1) + end + + rspamd_config:init_subsystem('symcache') +end + +local function days_list(days) + -- Create list of days to query starting with yesterday + local query_days = {} + local previous_date = os.time() - 86400 + local num_days = tonumber(days) + for _ = 1, num_days do + table.insert(query_days, os.date('%Y-%m-%d', previous_date)) + previous_date = previous_date - 86400 + end + return query_days +end + +local function get_excluded_symbols(known_symbols, correlations, seen_total) + -- Walk results once to collect all symbols & count occurrences + + local remove = {} + local known_symbols_list = {} + local composites = rspamd_config:get_all_opt('composites') + local all_symbols = rspamd_config:get_symbols() + local skip_flags = { + nostat = true, + skip = true, + idempotent = true, + composite = true, + } + for k, v in pairs(known_symbols) do + local lower_count, higher_count + if v.seen_spam > v.seen_ham then + lower_count = v.seen_ham + higher_count = v.seen_spam + else + lower_count = v.seen_spam + higher_count = v.seen_ham + end + + if composites[k] then + remove[k] = 'composite symbol' + elseif lower_count / higher_count >= 0.95 then + remove[k] = 'weak ham/spam correlation' + elseif v.seen / seen_total >= 0.9 then + remove[k] = 'omnipresent symbol' + elseif not all_symbols[k] then + remove[k] = 'nonexistent symbol' + else + for fl, _ in pairs(all_symbols[k].flags or {}) do + if skip_flags[fl] then + remove[k] = fl .. ' symbol' + break + end + end + end + known_symbols_list[v.id] = { + seen = v.seen, + name = k, + } + end + + -- Walk correlation matrix and check total counts + for sym_id, row in pairs(correlations) do + for inner_sym_id, count in pairs(row) do + local known = known_symbols_list[sym_id] + local inner = known_symbols_list[inner_sym_id] + if known and count == known.seen and not remove[inner.name] and not remove[known.name] then + remove[known.name] = string.format("overlapped by %s", + known_symbols_list[inner_sym_id].name) + end + end + end + + return remove +end + +local function handle_neural_profile(args) + + local known_symbols, correlations = {}, {} + local symbols_count, seen_total = 0, 0 + + local function process_row(r) + local is_spam = true + if r['Action'] == 'no action' or r['Action'] == 'greylist' then + is_spam = false + end + seen_total = seen_total + 1 + + local nsym = #r['Symbols.Names'] + + for i = 1, nsym do + local sym = r['Symbols.Names'][i] + local t = known_symbols[sym] + if not t then + local spam_count, ham_count = 0, 0 + if is_spam then + spam_count = spam_count + 1 + else + ham_count = ham_count + 1 + end + known_symbols[sym] = { + id = symbols_count, + seen = 1, + seen_ham = ham_count, + seen_spam = spam_count, + } + symbols_count = symbols_count + 1 + else + known_symbols[sym].seen = known_symbols[sym].seen + 1 + if is_spam then + known_symbols[sym].seen_spam = known_symbols[sym].seen_spam + 1 + else + known_symbols[sym].seen_ham = known_symbols[sym].seen_ham + 1 + end + end + end + + -- Fill correlations + for i = 1, nsym do + for j = 1, nsym do + if i ~= j then + local sym = r['Symbols.Names'][i] + local inner_sym_name = r['Symbols.Names'][j] + local known_sym = known_symbols[sym] + local inner_sym = known_symbols[inner_sym_name] + if known_sym and inner_sym then + if not correlations[known_sym.id] then + correlations[known_sym.id] = {} + end + local n = correlations[known_sym.id][inner_sym.id] or 0 + n = n + 1 + correlations[known_sym.id][inner_sym.id] = n + end + end + end + end + end + + local query_days = days_list(args.days) + local conditions = {} + table.insert(conditions, string.format("SettingsId = '%s'", args.settings_id)) + local limit = '' + local num_limit = tonumber(args.limit) + if num_limit then + limit = string.format(' LIMIT %d', num_limit) -- Contains leading space + end + if args.where then + table.insert(conditions, args.where) + end + + local query_fmt = 'SELECT Action, Symbols.Names FROM rspamd WHERE %s%s' + for _, query_day in ipairs(query_days) do + -- Date should be the last condition + table.insert(conditions, string.format("Date = '%s'", query_day)) + local query = string.format(query_fmt, table.concat(conditions, ' AND '), limit) + local upstream = args.upstream:get_upstream_round_robin() + local err = lua_clickhouse.select_sync(upstream, args, http_params, query, process_row) + if err ~= nil then + io.stderr:write(string.format('Error querying Clickhouse: %s\n', err)) + os.exit(1) + end + conditions[#conditions] = nil -- remove Date condition + end + + local remove = get_excluded_symbols(known_symbols, correlations, seen_total) + if not args.json then + for k in pairs(known_symbols) do + if not remove[k] then + io.stdout:write(string.format('%s\n', k)) + end + end + os.exit(0) + end + + local json_output = { + all_symbols = {}, + removed_symbols = {}, + used_symbols = {}, + } + for k in pairs(known_symbols) do + table.insert(json_output.all_symbols, k) + local why_removed = remove[k] + if why_removed then + json_output.removed_symbols[k] = why_removed + else + table.insert(json_output.used_symbols, k) + end + end + io.stdout:write(ucl.to_format(json_output, 'json')) +end + +local function post_neural_training(url, rule, spam_rows, ham_rows) + -- Prepare JSON payload + local payload = ucl.to_format( + { + ham_vec = ham_rows, + rule = rule, + spam_vec = spam_rows, + }, 'json') + + -- POST the payload + local err, response = rspamd_http.request({ + body = payload, + config = rspamd_config, + ev_base = rspamadm_ev_base, + log_obj = rspamd_config, + resolver = rspamadm_dns_resolver, + session = rspamadm_session, + url = url, + }) + + if err then + io.stderr:write(string.format('HTTP error: %s\n', err)) + os.exit(1) + end + if response.code ~= 200 then + io.stderr:write(string.format('bad HTTP code: %d\n', response.code)) + os.exit(1) + end + io.stdout:write(string.format('%s\n', response.content)) +end + +local function handle_neural_train(args) + + local this_where -- which class of messages are we collecting data for + local ham_rows, spam_rows = {}, {} + local want_spam, want_ham = true, true -- keep collecting while true + + -- Try find profile in config + local neural_opts = rspamd_config:get_all_opt('neural') + local symbols_profile = ((((neural_opts or E).rules or E)[args.rule] or E).profile or E)[args.profile] + if not symbols_profile then + io.stderr:write(string.format("Couldn't find profile %s in rule %s\n", args.profile, args.rule)) + os.exit(1) + end + -- Try find max_trains + local max_trains = (neural_opts.rules[args.rule].train or E).max_trains or 1000 + + -- Callback used to process rows from Clickhouse + local function process_row(r) + local destination -- which table to collect this information in + if this_where == args.ham then + destination = ham_rows + if #destination >= max_trains then + want_ham = false + return + end + else + destination = spam_rows + if #destination >= max_trains then + want_spam = false + return + end + end + local ucl_parser = ucl.parser() + local ok, err = ucl_parser:parse_string(r[args.column_name_vector], 'msgpack') + if not ok then + io.stderr:write(string.format("Couldn't parse [%s]: %s", r[args.column_name_vector], err)) + os.exit(1) + end + table.insert(destination, ucl_parser:get_object()) + end + + -- Generate symbols digest + table.sort(symbols_profile) + local symbols_digest = lua_util.table_digest(symbols_profile) + -- Create list of days to query data for + local query_days = days_list(args.days) + -- Set value for limit + local limit = '' + local num_limit = tonumber(args.limit) + if num_limit then + limit = string.format(' LIMIT %d', num_limit) -- Contains leading space + end + -- Prepare query elements + local conditions = { string.format("%s = '%s'", args.column_name_digest, symbols_digest) } + local query_fmt = 'SELECT %s FROM rspamd WHERE %s%s' + + -- Run queries + for _, the_where in ipairs({ args.ham, args.spam }) do + -- Inform callback which group of vectors we're collecting + this_where = the_where + table.insert(conditions, the_where) -- should be 2nd from last condition + -- Loop over days and try collect data + for _, query_day in ipairs(query_days) do + -- Break the loop if we have enough data already + if this_where == args.ham then + if not want_ham then + break + end + else + if not want_spam then + break + end + end + -- Date should be the last condition + table.insert(conditions, string.format("Date = '%s'", query_day)) + local query = string.format(query_fmt, args.column_name_vector, table.concat(conditions, ' AND '), limit) + local upstream = args.upstream:get_upstream_round_robin() + local err = lua_clickhouse.select_sync(upstream, args, http_params, query, process_row) + if err ~= nil then + io.stderr:write(string.format('Error querying Clickhouse: %s\n', err)) + os.exit(1) + end + conditions[#conditions] = nil -- remove Date condition + end + conditions[#conditions] = nil -- remove spam/ham condition + end + + -- Make sure we collected enough data for training + if #ham_rows < max_trains then + io.stderr:write(string.format('Insufficient ham rows: %d/%d\n', #ham_rows, max_trains)) + os.exit(1) + end + if #spam_rows < max_trains then + io.stderr:write(string.format('Insufficient spam rows: %d/%d\n', #spam_rows, max_trains)) + os.exit(1) + end + + return post_neural_training(args.url, args.rule, spam_rows, ham_rows) +end + +local command_handlers = { + neural_profile = handle_neural_profile, + neural_train = handle_neural_train, +} + +local function handler(args) + local cmd_opts = parser:parse(args) + + load_config(cmd_opts.config_file) + local cfg_opts = rspamd_config:get_all_opt('clickhouse') + + if cmd_opts.ask_password then + local rspamd_util = require "rspamd_util" + + io.write('Password: ') + cmd_opts.password = rspamd_util.readpassphrase() + end + + local function override_settings(params) + for _, which in ipairs(params) do + if cmd_opts[which] == nil then + cmd_opts[which] = cfg_opts[which] + end + end + end + + override_settings({ + 'database', 'no_ssl_verify', 'password', 'server', + 'use_gzip', 'use_https', 'user', + }) + + local servers = cmd_opts['server'] or cmd_opts['servers'] + if not servers then + parser:error("server(s) unspecified & couldn't be fetched from config") + end + + cmd_opts.upstream = rspamd_upstream_list.create(rspamd_config, servers, 8123) + + if not cmd_opts.upstream then + io.stderr:write(string.format("can't parse clickhouse address: %s\n", servers)) + os.exit(1) + end + + local f = command_handlers[cmd_opts.command] + if not f then + parser:error(string.format("command isn't implemented: %s", + cmd_opts.command)) + end + f(cmd_opts) +end + +return { + handler = handler, + description = parser._description, + name = 'clickhouse' +} diff --git a/lualib/rspamadm/configgraph.lua b/lualib/rspamadm/configgraph.lua new file mode 100644 index 0000000..07f14a9 --- /dev/null +++ b/lualib/rspamadm/configgraph.lua @@ -0,0 +1,172 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local rspamd_logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local rspamd_regexp = require "rspamd_regexp" +local argparse = require "argparse" + +-- Define command line options +local parser = argparse() + :name "rspamadm configgraph" + :description "Produces graph of Rspamd includes" + :help_description_margin(30) +parser:option "-c --config" + :description "Path to config file" + :argname("<file>") + :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf") +parser:flag "-a --all" + :description('Show all nodes, not just existing ones') + +local function process_filename(fname) + local cdir = rspamd_paths['CONFDIR'] .. '/' + fname = fname:gsub(cdir, '') + return fname +end + +local function output_dot(opts, nodes, adjacency) + rspamd_logger.messagex("digraph rspamd {") + for k, node in pairs(nodes) do + local attrs = { "shape=box" } + local skip = false + if node.exists then + if node.priority >= 10 then + attrs[#attrs + 1] = "color=red" + elseif node.priority > 0 then + attrs[#attrs + 1] = "color=blue" + end + else + if opts.all then + attrs[#attrs + 1] = "style=dotted" + else + skip = true + end + end + + if not skip then + rspamd_logger.messagex("\"%s\" [%s];", process_filename(k), + table.concat(attrs, ',')) + end + end + for _, adj in ipairs(adjacency) do + local attrs = {} + local skip = false + + if adj.to.exists then + if adj.to.merge then + attrs[#attrs + 1] = "arrowhead=diamond" + attrs[#attrs + 1] = "label=\"+\"" + elseif adj.to.priority > 1 then + attrs[#attrs + 1] = "color=red" + end + else + if opts.all then + attrs[#attrs + 1] = "style=dotted" + else + skip = true + end + end + + if not skip then + rspamd_logger.messagex("\"%s\" -> \"%s\" [%s];", process_filename(adj.from), + adj.to.short_path, table.concat(attrs, ',')) + end + end + rspamd_logger.messagex("}") +end + +local function load_config_traced(opts) + local glob_traces = {} + local adjacency = {} + local nodes = {} + + local function maybe_match_glob(file) + for _, gl in ipairs(glob_traces) do + if gl.re:match(file) then + return gl + end + end + + return nil + end + + local function add_dep(from, node, args) + adjacency[#adjacency + 1] = { + from = from, + to = node, + args = args + } + end + + local function process_node(fname, args) + local node = nodes[fname] + if not node then + node = { + path = fname, + short_path = process_filename(fname), + exists = rspamd_util.file_exists(fname), + merge = args.duplicate and args.duplicate == 'merge', + priority = args.priority or 0, + glob = args.glob, + try = args.try, + } + nodes[fname] = node + end + + return node + end + + local function trace_func(cur_file, included_file, args, parent) + if args.glob then + glob_traces[#glob_traces + 1] = { + re = rspamd_regexp.import_glob(included_file, ''), + parent = cur_file, + args = args, + seen = {}, + } + else + local node = process_node(included_file, args) + if opts.all or node.exists then + local gl_parent = maybe_match_glob(included_file) + if gl_parent and not gl_parent.seen[cur_file] then + add_dep(gl_parent.parent, nodes[cur_file], gl_parent.args) + gl_parent.seen[cur_file] = true + end + add_dep(cur_file, node, args) + end + end + end + + local _r, err = rspamd_config:load_ucl(opts['config'], trace_func) + if not _r then + rspamd_logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + output_dot(opts, nodes, adjacency) +end + +local function handler(args) + local res = parser:parse(args) + + load_config_traced(res) +end + +return { + handler = handler, + description = parser._description, + name = 'configgraph' +}
\ No newline at end of file diff --git a/lualib/rspamadm/confighelp.lua b/lualib/rspamadm/confighelp.lua new file mode 100644 index 0000000..38b26b6 --- /dev/null +++ b/lualib/rspamadm/confighelp.lua @@ -0,0 +1,123 @@ +local opts +local known_attrs = { + data = 1, + example = 1, + type = 1, + required = 1, + default = 1, +} +local argparse = require "argparse" +local ansicolors = require "ansicolors" + +local parser = argparse() + :name "rspamadm confighelp" + :description "Shows help for the specified configuration options" + :help_description_margin(32) +parser:argument "path":args "*" + :description('Optional config paths') +parser:flag "--no-color" + :description "Disable coloured output" +parser:flag "--short" + :description "Show only option names" +parser:flag "--no-examples" + :description "Do not show examples (implied by --short)" + +local function maybe_print_color(key) + if not opts['no-color'] then + return ansicolors.white .. key .. ansicolors.reset + else + return key + end +end + +local function sort_values(tbl) + local res = {} + for k, v in pairs(tbl) do + table.insert(res, { key = k, value = v }) + end + + -- Sort order + local order = { + options = 1, + dns = 2, + upstream = 3, + logging = 4, + metric = 5, + composite = 6, + classifier = 7, + modules = 8, + lua = 9, + worker = 10, + workers = 11, + } + + table.sort(res, function(a, b) + local oa = order[a['key']] + local ob = order[b['key']] + + if oa and ob then + return oa < ob + elseif oa then + return -1 < 0 + elseif ob then + return 1 < 0 + else + return a['key'] < b['key'] + end + + end) + + return res +end + +local function print_help(key, value, tabs) + print(string.format('%sConfiguration element: %s', tabs, maybe_print_color(key))) + + if not opts['short'] then + if value['data'] then + local nv = string.match(value['data'], '^#%s*(.*)%s*$') or value.data + print(string.format('%s\tDescription: %s', tabs, nv)) + end + if type(value['type']) == 'string' then + print(string.format('%s\tType: %s', tabs, value['type'])) + end + if type(value['required']) == 'boolean' then + if value['required'] then + print(string.format('%s\tRequired: %s', tabs, + maybe_print_color(tostring(value['required'])))) + else + print(string.format('%s\tRequired: %s', tabs, + tostring(value['required']))) + end + end + if value['default'] then + print(string.format('%s\tDefault: %s', tabs, value['default'])) + end + if not opts['no-examples'] and value['example'] then + local nv = string.match(value['example'], '^%s*(.*[^%s])%s*$') or value.example + print(string.format('%s\tExample:\n%s', tabs, nv)) + end + if value.type and value.type == 'object' then + print('') + end + end + + local sorted = sort_values(value) + for _, v in ipairs(sorted) do + if not known_attrs[v['key']] then + -- We need to go deeper + print_help(v['key'], v['value'], tabs .. '\t') + end + end +end + +return function(args, res) + opts = parser:parse(args) + + local sorted = sort_values(res) + + for _, v in ipairs(sorted) do + print_help(v['key'], v['value'], '') + print('') + end +end diff --git a/lualib/rspamadm/configwizard.lua b/lualib/rspamadm/configwizard.lua new file mode 100644 index 0000000..2637036 --- /dev/null +++ b/lualib/rspamadm/configwizard.lua @@ -0,0 +1,849 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local ansicolors = require "ansicolors" +local local_conf = rspamd_paths['CONFDIR'] +local rspamd_util = require "rspamd_util" +local rspamd_logger = require "rspamd_logger" +local lua_util = require "lua_util" +local lua_stat_tools = require "lua_stat" +local lua_redis = require "lua_redis" +local ucl = require "ucl" +local argparse = require "argparse" +local fun = require "fun" + +local plugins_stat = require "plugins_stats" + +local rspamd_logo = [[ + ____ _ + | _ \ ___ _ __ __ _ _ __ ___ __| | + | |_) |/ __|| '_ \ / _` || '_ ` _ \ / _` | + | _ < \__ \| |_) || (_| || | | | | || (_| | + |_| \_\|___/| .__/ \__,_||_| |_| |_| \__,_| + |_| +]] + +local parser = argparse() + :name "rspamadm configwizard" + :description "Perform guided configuration for Rspamd daemon" + :help_description_margin(32) +parser:option "-c --config" + :description "Path to config file" + :argname("<file>") + :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf") +parser:argument "checks" + :description "Checks to do (or 'list')" + :argname("<checks>") + :args "*" + +local redis_params + +local function printf(fmt, ...) + if fmt then + io.write(string.format(fmt, ...)) + end + io.write('\n') +end + +local function highlight(str) + return ansicolors.white .. str .. ansicolors.reset +end + +local function ask_yes_no(greet, default) + local def_str + if default then + greet = greet .. "[Y/n]: " + def_str = "yes" + else + greet = greet .. "[y/N]: " + def_str = "no" + end + + local reply = rspamd_util.readline(greet) + + if not reply then + os.exit(0) + end + if #reply == 0 then + reply = def_str + end + reply = reply:lower() + if reply == 'y' or reply == 'yes' then + return true + end + + return false +end + +local function readline_default(greet, def_value) + local reply = rspamd_util.readline(greet) + if not reply then + os.exit(0) + end + + if #reply == 0 then + return def_value + end + + return reply +end + +local function readline_expire() + local expire = '100d' + repeat + expire = readline_default("Expire time for new tokens [" .. expire .. "]: ", + expire) + expire = lua_util.parse_time_interval(expire) + + if not expire then + expire = '100d' + elseif expire > 2147483647 then + printf("The maximum possible value is 2147483647 (about 68y)") + expire = '68y' + elseif expire < -1 then + printf("The value must be a non-negative integer or -1") + expire = -1 + elseif expire ~= math.floor(expire) then + printf("The value must be an integer") + expire = math.floor(expire) + else + return expire + end + until false +end + +local function print_changes(changes) + local function print_change(k, c, where) + printf('File: %s, changes list:', highlight(local_conf .. '/' + .. where .. '/' .. k)) + + for ek, ev in pairs(c) do + printf("%s => %s", highlight(ek), rspamd_logger.slog("%s", ev)) + end + end + for k, v in pairs(changes.l) do + print_change(k, v, 'local.d') + if changes.o[k] then + v = changes.o[k] + print_change(k, v, 'override.d') + end + print() + end +end + +local function apply_changes(changes) + local function dirname(fname) + if fname:match(".-/.-") then + return string.gsub(fname, "(.*/)(.*)", "%1") + else + return nil + end + end + + local function apply_change(k, c, where) + local fname = local_conf .. '/' .. where .. '/' .. k + + if not rspamd_util.file_exists(fname) then + printf("Create file %s", highlight(fname)) + + local dname = dirname(fname) + + if dname then + local ret, err = rspamd_util.mkdir(dname, true) + + if not ret then + printf("Cannot make directory %s: %s", dname, highlight(err)) + os.exit(1) + end + end + end + + local f = io.open(fname, "a+") + + if not f then + printf("Cannot open file %s, aborting", highlight(fname)) + os.exit(1) + end + + f:write(ucl.to_config(c)) + + f:close() + end + for k, v in pairs(changes.l) do + apply_change(k, v, 'local.d') + if changes.o[k] then + v = changes.o[k] + apply_change(k, v, 'override.d') + end + end +end + +local function setup_controller(controller, changes) + printf("Setup %s and controller worker:", highlight("WebUI")) + + if not controller.password or controller.password == 'q1' then + if ask_yes_no("Controller password is not set, do you want to set one?", true) then + local pw_encrypted = rspamadm.pw_encrypt() + if pw_encrypted then + printf("Set encrypted password to: %s", highlight(pw_encrypted)) + changes.l['worker-controller.inc'] = { + password = pw_encrypted + } + end + end + end +end + +local function setup_redis(cfg, changes) + local function parse_servers(servers) + local ls = lua_util.rspamd_str_split(servers, ",") + + return ls + end + + printf("%s servers are not set:", highlight("Redis")) + printf("The following modules will be enabled if you add Redis servers:") + + for k, _ in pairs(rspamd_plugins_state.disabled_redis) do + printf("\t* %s", highlight(k)) + end + + if ask_yes_no("Do you wish to set Redis servers?", true) then + local read_servers = readline_default("Input read only servers separated by `,` [default: localhost]: ", + "localhost") + + local rs = parse_servers(read_servers) + if rs and #rs > 0 then + changes.l['redis.conf'] = { + read_servers = table.concat(rs, ",") + } + end + local write_servers = readline_default("Input write only servers separated by `,` [default: " + .. read_servers .. "]: ", read_servers) + + if not write_servers or #write_servers == 0 then + printf("Use read servers %s as write servers", highlight(table.concat(rs, ","))) + write_servers = read_servers + end + + redis_params = { + read_servers = rs, + } + + local ws = parse_servers(write_servers) + if ws and #ws > 0 then + changes.l['redis.conf']['write_servers'] = table.concat(ws, ",") + redis_params['write_servers'] = ws + end + + if ask_yes_no('Do you have any username set for your Redis (ACL SETUSER and Redis 6.0+)') then + local username = readline_default("Enter Redis username:", nil) + + if username then + changes.l['redis.conf'].username = username + redis_params.username = username + end + + local passwd = readline_default("Enter Redis password:", nil) + + if passwd then + changes.l['redis.conf']['password'] = passwd + redis_params['password'] = passwd + end + elseif ask_yes_no('Do you have any password set for your Redis?') then + local passwd = readline_default("Enter Redis password:", nil) + + if passwd then + changes.l['redis.conf']['password'] = passwd + redis_params['password'] = passwd + end + end + + if ask_yes_no('Do you have any specific database for your Redis?') then + local db = readline_default("Enter Redis database:", nil) + + if db then + changes.l['redis.conf']['db'] = db + redis_params['db'] = db + end + end + end +end + +local function setup_dkim_signing(cfg, changes) + -- Remove the trailing slash of a pathname, if present. + local function remove_trailing_slash(path) + if string.sub(path, -1) ~= "/" then + return path + end + return string.sub(path, 1, string.len(path) - 1) + end + + printf('How would you like to set up DKIM signing?') + printf('1. Use domain from %s for sign', highlight('mime from header')) + printf('2. Use domain from %s for sign', highlight('SMTP envelope from')) + printf('3. Use domain from %s for sign', highlight('authenticated user')) + printf('4. Sign all mail from %s', highlight('specific networks')) + printf() + + local sign_type = readline_default('Enter your choice (1, 2, 3, 4) [default: 1]: ', '1') + local sign_networks + local allow_mismatch + local sign_authenticated + local use_esld + local sign_domain = 'pet luacheck' + + local defined_auth_types = { 'header', 'envelope', 'auth', 'recipient' } + + if sign_type == '4' then + repeat + sign_networks = readline_default('Enter list of networks to perform dkim signing: ', + '') + until #sign_networks ~= 0 + + sign_networks = fun.totable(fun.map(lua_util.rspamd_str_trim, + lua_util.str_split(sign_networks, ',; '))) + printf('What domain would you like to use for signing?') + printf('* %s to use mime from domain', highlight('header')) + printf('* %s to use SMTP from domain', highlight('envelope')) + printf('* %s to use domain from SMTP auth', highlight('auth')) + printf('* %s to use domain from SMTP recipient', highlight('recipient')) + printf('* anything else to use as a %s domain (e.g. `example.com`)', highlight('static')) + printf() + + sign_domain = readline_default('Enter your choice [default: header]: ', 'header') + else + if sign_type == '1' then + sign_domain = 'header' + elseif sign_type == '2' then + sign_domain = 'envelope' + else + sign_domain = 'auth' + end + end + + if sign_type ~= '3' then + sign_authenticated = ask_yes_no( + string.format('Do you want to sign mail from %s? ', + highlight('authenticated users')), true) + else + sign_authenticated = true + end + + if fun.any(function(s) + return s == sign_domain + end, defined_auth_types) then + -- Allow mismatch + allow_mismatch = ask_yes_no( + string.format('Allow data %s, e.g. if mime from domain is not equal to authenticated user domain? ', + highlight('mismatch')), true) + -- ESLD check + use_esld = ask_yes_no( + string.format('Do you want to use %s domain (e.g. example.com instead of foo.example.com)? ', + highlight('effective')), true) + else + allow_mismatch = true + end + + local domains = {} + local has_domains = false + + local dkim_keys_dir = rspamd_paths["DBDIR"] .. "/dkim/" + + local prompt = string.format("Enter output directory for the keys [default: %s]: ", + highlight(dkim_keys_dir)) + dkim_keys_dir = remove_trailing_slash(readline_default(prompt, dkim_keys_dir)) + + local ret, err = rspamd_util.mkdir(dkim_keys_dir, true) + + if not ret then + printf("Cannot make directory %s: %s", dkim_keys_dir, highlight(err)) + os.exit(1) + end + + local function print_domains() + printf("Domains configured:") + for k, v in pairs(domains) do + printf("Domain: %s, selector: %s, privkey: %s", highlight(k), + v.selector, v.privkey) + end + printf("--") + end + local function print_public_key(pk) + local base64_pk = tostring(rspamd_util.encode_base64(pk)) + printf('v=DKIM1; k=rsa; p=%s\n', base64_pk) + end + repeat + if has_domains then + print_domains() + end + + local domain + repeat + domain = rspamd_util.readline("Enter domain to sign: ") + if not domain then + os.exit(1) + end + until #domain ~= 0 + + local selector = readline_default("Enter selector [default: dkim]: ", 'dkim') + if not selector then + selector = 'dkim' + end + + local privkey_file = string.format("%s/%s.%s.key", dkim_keys_dir, domain, + selector) + if not rspamd_util.file_exists(privkey_file) then + if ask_yes_no("Do you want to create privkey " .. highlight(privkey_file), + true) then + local rsa = require "rspamd_rsa" + local sk, pk = rsa.keypair(2048) + sk:save(privkey_file, 'pem') + print("You need to chown private key file to rspamd user!!") + print("To make dkim signing working, to place the following record in your DNS zone:") + print_public_key(tostring(pk)) + end + end + + domains[domain] = { + selector = selector, + path = privkey_file, + } + until not ask_yes_no("Do you wish to add another DKIM domain?") + + changes.l['dkim_signing.conf'] = { domain = domains } + local res_tbl = changes.l['dkim_signing.conf'] + + if sign_networks then + res_tbl.sign_networks = sign_networks + res_tbl.use_domain_sign_networks = sign_domain + else + res_tbl.use_domain = sign_domain + end + + if allow_mismatch then + res_tbl.allow_hdrfrom_mismatch = true + res_tbl.allow_hdrfrom_mismatch_sign_networks = true + res_tbl.allow_username_mismatch = true + end + + res_tbl.use_esld = use_esld + res_tbl.sign_authenticated = sign_authenticated +end + +local function check_redis_classifier(cls, changes) + local symbol_spam, symbol_ham + -- Load symbols from statfiles + local statfiles = cls.statfile + for _, stf in ipairs(statfiles) do + local symbol = stf.symbol or 'undefined' + + local spam + if stf.spam then + spam = stf.spam + else + if string.match(symbol:upper(), 'SPAM') then + spam = true + else + spam = false + end + end + + if spam then + symbol_spam = symbol + else + symbol_ham = symbol + end + end + + if not symbol_spam or not symbol_ham then + printf("Classifier has no symbols defined") + return + end + + local parsed_redis = lua_redis.try_load_redis_servers(cls, nil) + + if not parsed_redis and redis_params then + parsed_redis = lua_redis.try_load_redis_servers(redis_params, nil) + if not parsed_redis then + printf("Cannot parse Redis params") + return + end + end + + local function try_convert(update_config) + if ask_yes_no("Do you wish to convert data to the new schema?", true) then + local expire = readline_expire() + if not lua_stat_tools.convert_bayes_schema(parsed_redis, symbol_spam, + symbol_ham, expire) then + printf("Conversion failed") + else + printf("Conversion succeed") + if update_config then + changes.l['classifier-bayes.conf'] = { + new_schema = true, + } + + if expire then + changes.l['classifier-bayes.conf'].expire = expire + end + end + end + end + end + + local function get_version(conn) + conn:add_cmd("SMEMBERS", { "RS_keys" }) + + local ret, members = conn:exec() + + -- Empty db + if not ret or #members == 0 then + return false, 0 + end + + -- We still need to check versions + local lua_script = [[ +local ver = 0 + +local tst = redis.call('GET', KEYS[1]..'_version') +if tst then + ver = tonumber(tst) or 0 +end + +return ver +]] + conn:add_cmd('EVAL', { lua_script, '1', 'RS' }) + local _, ver = conn:exec() + + return true, tonumber(ver) + end + + local function check_expire(conn) + -- We still need to check versions + local lua_script = [[ +local ttl = 0 + +local sc = redis.call('SCAN', 0, 'MATCH', 'RS*_*', 'COUNT', 1) +local _,key = sc[1], sc[2] + +if key and key[1] then + ttl = redis.call('TTL', key[1]) +end + +return ttl +]] + conn:add_cmd('EVAL', { lua_script, '0' }) + local _, ttl = conn:exec() + + return tonumber(ttl) + end + + local res, conn = lua_redis.redis_connect_sync(parsed_redis, true) + if not res then + printf("Cannot connect to Redis server") + return false + end + + if not cls.new_schema then + local r, ver = get_version(conn) + if not r then + return false + end + if ver ~= 2 then + if not ver then + printf('Key "RS_version" has not been found in Redis for %s/%s', + symbol_ham, symbol_spam) + else + printf("You are using an old schema version: %s for %s/%s", + ver, symbol_ham, symbol_spam) + end + try_convert(true) + else + printf("You have configured an old schema for %s/%s but your data has new layout", + symbol_ham, symbol_spam) + + if ask_yes_no("Switch config to the new schema?", true) then + changes.l['classifier-bayes.conf'] = { + new_schema = true, + } + + local expire = check_expire(conn) + if expire then + changes.l['classifier-bayes.conf'].expire = expire + end + end + end + else + local r, ver = get_version(conn) + if not r then + return false + end + if ver ~= 2 then + printf("You have configured new schema for %s/%s but your DB has old version: %s", + symbol_spam, symbol_ham, ver) + try_convert(false) + else + printf( + 'You have configured new schema for %s/%s and your DB already has new layout (v. %s).' .. + ' DB conversion is not needed.', + symbol_spam, symbol_ham, ver) + end + end +end + +local function setup_statistic(cfg, changes) + local sqlite_configs = lua_stat_tools.load_sqlite_config(cfg) + + if #sqlite_configs > 0 then + + if not redis_params then + printf('You have %d sqlite classifiers, but you have no Redis servers being set', + #sqlite_configs) + return false + end + + local parsed_redis = lua_redis.try_load_redis_servers(redis_params, nil) + if parsed_redis then + printf('You have %d sqlite classifiers', #sqlite_configs) + local expire = readline_expire() + + local reset_previous = ask_yes_no("Reset previous data?") + if ask_yes_no('Do you wish to convert them to Redis?', true) then + + for _, cls in ipairs(sqlite_configs) do + if rspamd_util.file_exists(cls.db_spam) and rspamd_util.file_exists(cls.db_ham) then + if not lua_stat_tools.convert_sqlite_to_redis(parsed_redis, cls.db_spam, + cls.db_ham, cls.symbol_spam, cls.symbol_ham, cls.learn_cache, expire, + reset_previous) then + rspamd_logger.errx('conversion failed') + + return false + end + else + rspamd_logger.messagex('cannot find %s and %s, skip conversion', + cls.db_spam, cls.db_ham) + end + + rspamd_logger.messagex('Converted classifier to the from sqlite to redis') + changes.l['classifier-bayes.conf'] = { + backend = 'redis', + new_schema = true, + } + + if expire then + changes.l['classifier-bayes.conf'].expire = expire + end + + if cls.learn_cache then + changes.l['classifier-bayes.conf'].cache = { + backend = 'redis' + } + end + end + end + end + else + -- Check sanity for the existing Redis classifiers + local classifier = cfg.classifier + + if classifier then + if classifier[1] then + for _, cls in ipairs(classifier) do + if cls.bayes then + cls = cls.bayes + end + if cls.backend and cls.backend == 'redis' then + check_redis_classifier(cls, changes) + end + end + else + if classifier.bayes then + + classifier = classifier.bayes + if classifier[1] then + for _, cls in ipairs(classifier) do + if cls.backend and cls.backend == 'redis' then + check_redis_classifier(cls, changes) + end + end + else + if classifier.backend and classifier.backend == 'redis' then + check_redis_classifier(classifier, changes) + end + end + end + end + end + end +end + +local function find_worker(cfg, wtype) + if cfg.worker then + for k, s in pairs(cfg.worker) do + if type(k) == 'number' and type(s) == 'table' then + if s[wtype] then + return s[wtype] + end + end + if type(s) == 'table' and s.type and s.type == wtype then + return s + end + if type(k) == 'string' and k == wtype then + return s + end + end + end + + return nil +end + +return { + handler = function(cmd_args) + local changes = { + l = {}, -- local changes + o = {}, -- override changes + } + + local interactive_start = true + local checks = {} + local all_checks = { + 'controller', + 'redis', + 'dkim', + 'statistic', + } + + local opts = parser:parse(cmd_args) + local args = opts['checks'] or {} + + local _r, err = rspamd_config:load_ucl(opts['config']) + + if not _r then + rspamd_logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + rspamd_logger.errx('cannot process %s: %s', opts['config'], err) + os.exit(1) + end + + local cfg = rspamd_config:get_ucl() + + if not rspamd_config:init_modules() then + rspamd_logger.errx('cannot init modules when parsing %s', opts['config']) + os.exit(1) + end + + if #args > 0 then + interactive_start = false + + for _, arg in ipairs(args) do + if arg == 'all' then + checks = all_checks + elseif arg == 'list' then + printf(highlight(rspamd_logo)) + printf('Available modules') + for _, c in ipairs(all_checks) do + printf('- %s', c) + end + return + else + table.insert(checks, arg) + end + end + else + checks = all_checks + end + + local function has_check(check) + for _, c in ipairs(checks) do + if c == check then + return true + end + end + + return false + end + + rspamd_util.umask('022') + if interactive_start then + printf(highlight(rspamd_logo)) + printf("Welcome to the configuration tool") + printf("We use %s configuration file, writing results to %s", + highlight(opts['config']), highlight(local_conf)) + plugins_stat(nil, nil) + end + + if not interactive_start or + ask_yes_no("Do you wish to continue?", true) then + + if has_check('controller') then + local controller = find_worker(cfg, 'controller') + if controller then + setup_controller(controller, changes) + end + end + + if has_check('redis') then + if not cfg.redis or (not cfg.redis.servers and not cfg.redis.read_servers) then + setup_redis(cfg, changes) + else + redis_params = cfg.redis + end + else + redis_params = cfg.redis + end + + if has_check('dkim') then + if cfg.dkim_signing and not cfg.dkim_signing.domain then + if ask_yes_no('Do you want to setup dkim signing feature?') then + setup_dkim_signing(cfg, changes) + end + end + end + + if has_check('statistic') or has_check('statistics') then + setup_statistic(cfg, changes) + end + + local nchanges = 0 + for _, _ in pairs(changes.l) do + nchanges = nchanges + 1 + end + for _, _ in pairs(changes.o) do + nchanges = nchanges + 1 + end + + if nchanges > 0 then + print_changes(changes) + if ask_yes_no("Apply changes?", true) then + apply_changes(changes) + printf("%d changes applied, the wizard is finished now", nchanges) + printf("*** Please reload the Rspamd configuration ***") + else + printf("No changes applied, the wizard is finished now") + end + else + printf("No changes found, the wizard is finished now") + end + end + end, + name = 'configwizard', + description = parser._description, +} diff --git a/lualib/rspamadm/cookie.lua b/lualib/rspamadm/cookie.lua new file mode 100644 index 0000000..7e0526a --- /dev/null +++ b/lualib/rspamadm/cookie.lua @@ -0,0 +1,125 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" + + +-- Define command line options +local parser = argparse() + :name "rspamadm cookie" + :description "Produces cookies or message ids" + :help_description_margin(30) + +parser:mutex( + parser:option "-k --key" + :description('Key to load') + :argname "<32hex>", + parser:flag "-K --new-key" + :description('Generates a new key') +) + +parser:option "-d --domain" + :description('Use specified domain and generate full message id') + :argname "<domain>" +parser:flag "-D --decrypt" + :description('Decrypt cookie instead of encrypting one') +parser:flag "-t --timestamp" + :description('Show cookie timestamp (valid for decrypting only)') +parser:argument "cookie":args "?" + :description('Use specified cookie') + +local function gen_cookie(args, key) + local cr = require "rspamd_cryptobox" + + if not args.cookie then + return + end + + local function encrypt() + if #args.cookie > 31 then + print('cookie too long (>31 characters), cannot encrypt') + os.exit(1) + end + + local enc_cookie = cr.encrypt_cookie(key, args.cookie) + if args.domain then + print(string.format('<%s@%s>', enc_cookie, args.domain)) + else + print(enc_cookie) + end + end + + local function decrypt() + local extracted_cookie = args.cookie:match('^%<?([^@]+)@.*$') + if not extracted_cookie then + -- Assume full message id as a cookie + extracted_cookie = args.cookie + end + + local dec_cookie, ts = cr.decrypt_cookie(key, extracted_cookie) + + if dec_cookie then + if args.timestamp then + print(string.format('%s %s', dec_cookie, ts)) + else + print(dec_cookie) + end + else + print('cannot decrypt cookie') + os.exit(1) + end + end + + if args.decrypt then + decrypt() + else + encrypt() + end +end + +local function handler(args) + local res = parser:parse(args) + + if not (res.key or res['new_key']) then + parser:error('--key or --new-key must be specified') + end + + if res.key then + local pattern = { '^' } + for i = 1, 32 do + pattern[i + 1] = '[a-zA-Z0-9]' + end + pattern[34] = '$' + + if not res.key:match(table.concat(pattern, '')) then + parser:error('invalid key: ' .. res.key) + end + + gen_cookie(res, res.key) + else + local util = require "rspamd_util" + local key = util.random_hex(32) + + print(key) + gen_cookie(res, res.key) + end +end + +return { + handler = handler, + description = parser._description, + name = 'cookie' +}
\ No newline at end of file diff --git a/lualib/rspamadm/corpus_test.lua b/lualib/rspamadm/corpus_test.lua new file mode 100644 index 0000000..0e63f9f --- /dev/null +++ b/lualib/rspamadm/corpus_test.lua @@ -0,0 +1,185 @@ +local rspamd_logger = require "rspamd_logger" +local ucl = require "ucl" +local lua_util = require "lua_util" +local argparse = require "argparse" + +local parser = argparse() + :name "rspamadm corpus_test" + :description "Create logs files from email corpus" + :help_description_margin(32) + +parser:option "-H --ham" + :description("Ham directory") + :argname("<dir>") +parser:option "-S --spam" + :description("Spam directory") + :argname("<dir>") +parser:option "-n --conns" + :description("Number of parallel connections") + :argname("<N>") + :convert(tonumber) + :default(10) +parser:option "-o --output" + :description("Output file") + :argname("<file>") + :default('results.log') +parser:option "-t --timeout" + :description("Timeout for client connections") + :argname("<sec>") + :convert(tonumber) + :default(60) +parser:option "-c --connect" + :description("Connect to specific host") + :argname("<host>") + :default('localhost:11334') +parser:option "-r --rspamc" + :description("Use specific rspamc path") + :argname("<path>") + :default('rspamc') + +local HAM = "HAM" +local SPAM = "SPAM" +local opts + +local function scan_email(n_parallel, path, timeout) + + local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s", + opts.rspamc, opts.connect, n_parallel, timeout, path) + local result = assert(io.popen(rspamc_command)) + result = result:read("*all") + return result +end + +local function write_results(results, file) + + local f = io.open(file, 'w') + + for _, result in pairs(results) do + local log_line = string.format("%s %.2f %s", + result.type, result.score, result.action) + + for _, sym in pairs(result.symbols) do + log_line = log_line .. " " .. sym + end + + log_line = log_line .. " " .. result.scan_time .. " " .. file .. ':' .. result.filename + + log_line = log_line .. "\r\n" + + f:write(log_line) + end + + f:close() +end + +local function encoded_json_to_log(result) + -- Returns table containing score, action, list of symbols + + local filtered_result = {} + local ucl_parser = ucl.parser() + + local is_good, err = ucl_parser:parse_string(result) + + if not is_good then + rspamd_logger.errx("Parser error: %1", err) + return nil + end + + result = ucl_parser:get_object() + + filtered_result.score = result.score + if not result.action then + rspamd_logger.errx("Bad JSON: %1", result) + return nil + end + local action = result.action:gsub("%s+", "_") + filtered_result.action = action + + filtered_result.symbols = {} + + for sym, _ in pairs(result.symbols) do + table.insert(filtered_result.symbols, sym) + end + + filtered_result.filename = result.filename + filtered_result.scan_time = result.scan_time + + return filtered_result +end + +local function scan_results_to_logs(results, actual_email_type) + + local logs = {} + + results = lua_util.rspamd_str_split(results, "\n") + + if results[#results] == "" then + results[#results] = nil + end + + for _, result in pairs(results) do + result = encoded_json_to_log(result) + if result then + result['type'] = actual_email_type + table.insert(logs, result) + end + end + + return logs +end + +local function handler(args) + opts = parser:parse(args) + local ham_directory = opts['ham'] + local spam_directory = opts['spam'] + local connections = opts["conns"] + local output = opts["output"] + + local results = {} + + local start_time = os.time() + local no_of_ham = 0 + local no_of_spam = 0 + + if ham_directory then + rspamd_logger.messagex("Scanning ham corpus...") + local ham_results = scan_email(connections, ham_directory, opts["timeout"]) + ham_results = scan_results_to_logs(ham_results, HAM) + + no_of_ham = #ham_results + + for _, result in pairs(ham_results) do + table.insert(results, result) + end + end + + if spam_directory then + rspamd_logger.messagex("Scanning spam corpus...") + local spam_results = scan_email(connections, spam_directory, opts.timeout) + spam_results = scan_results_to_logs(spam_results, SPAM) + + no_of_spam = #spam_results + + for _, result in pairs(spam_results) do + table.insert(results, result) + end + end + + rspamd_logger.messagex("Writing results to %s", output) + write_results(results, output) + + rspamd_logger.messagex("Stats: ") + local elapsed_time = os.time() - start_time + local total_msgs = no_of_ham + no_of_spam + rspamd_logger.messagex("Elapsed time: %ss", elapsed_time) + rspamd_logger.messagex("No of ham: %s", no_of_ham) + rspamd_logger.messagex("No of spam: %s", no_of_spam) + rspamd_logger.messagex("Messages/sec: %s", (total_msgs / elapsed_time)) +end + +return { + name = 'corpustest', + aliases = { 'corpus_test', 'corpus' }, + handler = handler, + description = parser._description +} diff --git a/lualib/rspamadm/dkim_keygen.lua b/lualib/rspamadm/dkim_keygen.lua new file mode 100644 index 0000000..211094d --- /dev/null +++ b/lualib/rspamadm/dkim_keygen.lua @@ -0,0 +1,178 @@ +--[[ +Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" +local rspamd_util = require "rspamd_util" +local rspamd_cryptobox = require "rspamd_cryptobox" + +local parser = argparse() + :name 'rspamadm dkim_keygen' + :description 'Create key pairs for dkim signing' + :help_description_margin(30) +parser:option '-d --domain' + :description 'Create a key for a specific domain' + :default "example.com" +parser:option '-s --selector' + :description 'Create a key for a specific DKIM selector' + :default "mail" +parser:option '-k --privkey' + :description 'Save private key to file instead of printing it to stdout' +parser:option '-b --bits' + :convert(function(input) + local n = tonumber(input) + if not n or n < 512 or n > 4096 then + return nil + end + return n +end) + :description 'Generate an RSA key with the specified number of bits (512 to 4096)' + :default(1024) +parser:option '-t --type' + :description 'Key type: RSA, ED25519 or ED25119-seed' + :convert { + ['rsa'] = 'rsa', + ['RSA'] = 'rsa', + ['ed25519'] = 'ed25519', + ['ED25519'] = 'ed25519', + ['ed25519-seed'] = 'ed25519-seed', + ['ED25519-seed'] = 'ed25519-seed', +} + :default 'rsa' +parser:option '-o --output' + :description 'Output public key in the following format: dns, dnskey or plain' + :convert { + ['dns'] = 'dns', + ['plain'] = 'plain', + ['dnskey'] = 'dnskey', +} + :default 'dns' +parser:option '--priv-output' + :description 'Output private key in the following format: PEM or DER (for RSA)' + :convert { + ['pem'] = 'pem', + ['der'] = 'der', +} + :default 'pem' +parser:flag '-f --force' + :description 'Force overwrite of existing files' + +local function split_string(input, max_length) + max_length = max_length or 253 + local pieces = {} + local index = 1 + + while index <= #input do + local piece = input:sub(index, index + max_length - 1) + table.insert(pieces, piece) + index = index + max_length + end + + return pieces +end + +local function print_public_key_dns(opts, base64_pk) + local key_type = opts.type == 'rsa' and 'rsa' or 'ed25519' + if #base64_pk < 255 - 2 then + io.write(string.format('%s._domainkey IN TXT ( "v=DKIM1; k=%s;" \n\t"p=%s" ) ;\n', + opts.selector, key_type, base64_pk)) + else + -- Split it by parts + local parts = split_string(base64_pk) + io.write(string.format('%s._domainkey IN TXT ( "v=DKIM1; k=%s; "\n', opts.selector, key_type)) + for i, part in ipairs(parts) do + if i == 1 then + io.write(string.format('\t"p=%s"\n', part)) + else + io.write(string.format('\t"%s"\n', part)) + end + end + io.write(") ; \n") + end + +end + +local function print_public_key(opts, pk, need_base64) + local key_type = opts.type == 'rsa' and 'rsa' or 'ed25519' + local base64_pk = need_base64 and tostring(rspamd_util.encode_base64(pk)) or tostring(pk) + if opts.output == 'plain' then + io.write(base64_pk) + io.write("\n") + elseif opts.output == 'dns' then + print_public_key_dns(opts, base64_pk, false) + elseif opts.output == 'dnskey' then + io.write(string.format('v=DKIM1; k=%s; p=%s\n', key_type, base64_pk)) + end +end + +local function gen_rsa_key(opts) + local rsa = require "rspamd_rsa" + + local sk, pk = rsa.keypair(opts.bits or 1024) + if opts.privkey then + if opts.force then + os.remove(opts.privkey) + end + sk:save(opts.privkey, opts.priv_output) + else + sk:save("-", opts.priv_output) + end + + -- We generate key directly via lua_rsa and it returns unencoded raw data + print_public_key(opts, tostring(pk), true) +end + +local function gen_eddsa_key(opts) + local sk, pk = rspamd_cryptobox.gen_dkim_keypair(opts.type) + + if opts.privkey and opts.force then + os.remove(opts.privkey) + end + if not sk:save_in_file(opts.privkey, tonumber('0600', 8)) then + io.stderr:write('cannot save private key to ' .. (opts.privkey or 'stdout') .. '\n') + os.exit(1) + end + + if not opts.privkey then + io.write("\n") + io.flush() + end + + -- gen_dkim_keypair function returns everything encoded in base64, so no need to do anything + print_public_key(opts, tostring(pk), false) +end + +local function handler(args) + local opts = parser:parse(args) + + if not opts then + os.exit(1) + end + + if opts.type == 'rsa' then + gen_rsa_key(opts) + else + gen_eddsa_key(opts) + end +end + +return { + name = 'dkim_keygen', + aliases = { 'dkimkeygen' }, + handler = handler, + description = parser._description +} + + diff --git a/lualib/rspamadm/dmarc_report.lua b/lualib/rspamadm/dmarc_report.lua new file mode 100644 index 0000000..42c801e --- /dev/null +++ b/lualib/rspamadm/dmarc_report.lua @@ -0,0 +1,737 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" +local lua_util = require "lua_util" +local logger = require "rspamd_logger" +local lua_redis = require "lua_redis" +local dmarc_common = require "plugins/dmarc" +local lupa = require "lupa" +local rspamd_mempool = require "rspamd_mempool" +local rspamd_url = require "rspamd_url" +local rspamd_text = require "rspamd_text" +local rspamd_util = require "rspamd_util" +local rspamd_dns = require "rspamd_dns" + +local N = 'dmarc_report' + +-- Define command line options +local parser = argparse() + :name "rspamadm dmarc_report" + :description "Dmarc reports sending tool" + :help_description_margin(30) + +parser:option "-c --config" + :description "Path to config file" + :argname("<cfg>") + :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf") + +parser:flag "-v --verbose" + :description "Enable dmarc specific logging" + +parser:flag "-n --no-opt" + :description "Do not reset reporting data/send reports" + +parser:argument "date" + :description "Date to process (today by default)" + :argname "<YYYYMMDD>" + :args "*" +parser:option "-b --batch-size" + :description "Send reports in batches up to <batch-size> messages" + :argname "<number>" + :convert(tonumber) + :default "10" + +local report_template = [[From: "{= from_name =}" <{= from_addr =}> +To: {= rcpt =} +{%+ if is_string(bcc) %}Bcc: {= bcc =}{%- endif %} +Subject: Report Domain: {= reporting_domain =} + Submitter: {= submitter =} + Report-ID: {= report_id =} +Date: {= report_date =} +MIME-Version: 1.0 +Message-ID: <{= message_id =}> +Content-Type: multipart/mixed; + boundary="----=_NextPart_{= uuid =}" + +This is a multipart message in MIME format. + +------=_NextPart_{= uuid =} +Content-Type: text/plain; charset="us-ascii" +Content-Transfer-Encoding: 7bit + +This is an aggregate report from {= submitter =}. + +Report domain: {= reporting_domain =} +Submitter: {= submitter =} +Report ID: {= report_id =} + +------=_NextPart_{= uuid =} +Content-Type: application/gzip +Content-Transfer-Encoding: base64 +Content-Disposition: attachment; + filename="{= submitter =}!{= reporting_domain =}!{= report_start =}!{= report_end =}.xml.gz" + +]] +local report_footer = [[ + +------=_NextPart_{= uuid =}--]] + +local dmarc_settings = {} +local redis_params +local redis_attrs = { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + log_obj = rspamd_config, + resolver = rspamadm_dns_resolver, +} +local pool + +local function load_config(opts) + local _r, err = rspamd_config:load_ucl(opts['config']) + + if not _r then + logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + logger.errx('cannot process %s: %s', opts['config'], err) + os.exit(1) + end +end + +-- Concat elements using redis_keys.join_char +local function redis_prefix(...) + return table.concat({ ... }, dmarc_settings.reporting.redis_keys.join_char) +end + +local function get_rua(rep_key) + local parts = lua_util.str_split(rep_key, dmarc_settings.reporting.redis_keys.join_char) + + if #parts >= 3 then + return parts[3] + end + + return nil +end + +local function get_domain(rep_key) + local parts = lua_util.str_split(rep_key, dmarc_settings.reporting.redis_keys.join_char) + + if #parts >= 3 then + return parts[2] + end + + return nil +end + +local function gen_uuid() + local template = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx' + return string.gsub(template, '[xy]', function(c) + local v = (c == 'x') and math.random(0, 0xf) or math.random(8, 0xb) + return string.format('%x', v) + end) +end + +local function gen_xml_grammar() + local lpeg = require 'lpeg' + local lt = lpeg.P('<') / '<' + local gt = lpeg.P('>') / '>' + local amp = lpeg.P('&') / '&' + local quot = lpeg.P('"') / '"' + local apos = lpeg.P("'") / ''' + local special = lt + gt + amp + quot + apos + local grammar = lpeg.Cs((special + 1) ^ 0) + return grammar +end + +local xml_grammar = gen_xml_grammar() + +local function escape_xml(input) + if type(input) == 'string' or type(input) == 'userdata' then + return xml_grammar:match(input) + else + input = tostring(input) + + if input then + return xml_grammar:match(input) + end + end + + return '' +end +-- Enable xml escaping in lupa templates +lupa.filters.escape_xml = escape_xml + +-- Creates report XML header +local function report_header(reporting_domain, report_start, report_end, domain_policy) + local report_id = string.format('%s.%d.%d', + reporting_domain, report_start, report_end) + local xml_template = [[ +<?xml version="1.0" encoding="UTF-8" ?> +<feedback> + <report_metadata> + <org_name>{= report_settings.org_name | escape_xml =}</org_name> + <email>{= report_settings.email | escape_xml =}</email> + <report_id>{= report_id =}</report_id> + <date_range> + <begin>{= report_start =}</begin> + <end>{= report_end =}</end> + </date_range> + </report_metadata> + <policy_published> + <domain>{= reporting_domain | escape_xml =}</domain> + <adkim>{= domain_policy.adkim | escape_xml =}</adkim> + <aspf>{= domain_policy.aspf | escape_xml =}</aspf> + <p>{= domain_policy.p | escape_xml =}</p> + <sp>{= domain_policy.sp | escape_xml =}</sp> + <pct>{= domain_policy.pct | escape_xml =}</pct> + </policy_published> +]] + return lua_util.jinja_template(xml_template, { + report_settings = dmarc_settings.reporting, + report_id = report_id, + report_start = report_start, + report_end = report_end, + domain_policy = domain_policy, + reporting_domain = reporting_domain, + }, true) +end + +-- Generate xml entry for a preprocessed redis row +local function entry_to_xml(data) + local xml_template = [[<record> + <row> + <source_ip>{= data.ip =}</source_ip> + <count>{= data.count =}</count> + <policy_evaluated> + <disposition>{= data.disposition =}</disposition> + <dkim>{= data.dkim_disposition =}</dkim> + <spf>{= data.spf_disposition =}</spf> + {% if data.override and data.override ~= '' -%} + <reason><type>{= data.override =}</type></reason> + {%- endif %} + </policy_evaluated> + </row> + <identifiers> + <header_from>{= data.header_from =}</header_from> + </identifiers> + <auth_results> + {% if data.dkim_results[1] -%} + {% for d in data.dkim_results -%} + <dkim> + <domain>{= d.domain =}</domain> + <result>{= d.result =}</result> + </dkim> + {%- endfor %} + {%- endif %} + <spf> + <domain>{= data.spf_domain =}</domain> + <result>{= data.spf_result =}</result> + </spf> + </auth_results> +</record> +]] + return lua_util.jinja_template(xml_template, { data = data }, true) +end + +-- Process a report entry stored in Redis splitting it to a lua table +local function process_report_entry(data, score) + local split = lua_util.str_split(data, ',') + local row = { + ip = split[1], + spf_disposition = split[2], + dkim_disposition = split[3], + disposition = split[4], + override = split[5], + header_from = split[6], + dkim_results = {}, + spf_domain = split[11], + spf_result = split[12], + count = tonumber(score), + } + -- Process dkim entries + local function dkim_entries_process(dkim_data, result) + if dkim_data and dkim_data ~= '' then + local dkim_elts = lua_util.str_split(dkim_data, '|') + for _, d in ipairs(dkim_elts) do + table.insert(row.dkim_results, { domain = d, result = result }) + end + end + end + dkim_entries_process(split[7], 'pass') + dkim_entries_process(split[8], 'fail') + dkim_entries_process(split[9], 'temperror') + dkim_entries_process(split[9], 'permerror') + + return row +end + +-- Process a single rua entry, validating in DNS if needed +local function process_rua(dmarc_domain, rua) + local parts = lua_util.str_split(rua, ',') + + -- Remove size limitation, as we don't care about them + local addrs = {} + for _, rua_part in ipairs(parts) do + local u = rspamd_url.create(pool, rua_part:gsub('!%d+[kmg]?$', '')) + local u2 = rspamd_url.create(pool, dmarc_domain) + if u and (u:get_protocol() or '') == 'mailto' and u:get_user() then + -- Check each address for sanity + if u:get_tld() == u2:get_tld() then + -- Same eSLD - always include + table.insert(addrs, u) + else + -- We need to check authority + local resolve_str = string.format('%s._report._dmarc.%s', + dmarc_domain, u:get_host()) + local is_ok, results = rspamd_dns.request({ + config = rspamd_config, + session = rspamadm_session, + type = 'txt', + name = resolve_str, + }) + + if not is_ok then + logger.errx('cannot resolve %s: %s; exclude %s', resolve_str, results, rua_part) + else + local found = false + for _, t in ipairs(results) do + if string.match(t, 'v=DMARC1') then + found = true + break + end + end + + if not found then + logger.errx('%s is not authorized to process reports on %s', dmarc_domain, u:get_host()) + else + -- All good + table.insert(addrs, u) + end + end + end + else + logger.errx('invalid rua url: "%s""', tostring(u or 'null')) + end + end + + if #addrs > 0 then + return addrs + end + + return nil +end + +-- Validate reporting domain, extracting rua and checking 3rd party report domains +-- This function returns a full dmarc record processed + rua as a list of url objects +local function validate_reporting_domain(reporting_domain) + local is_ok, results = rspamd_dns.request({ + config = rspamd_config, + session = rspamadm_session, + type = 'txt', + name = '_dmarc.' .. reporting_domain, + }) + + if not is_ok or not results then + logger.errx('cannot resolve _dmarc.%s: %s', reporting_domain, results) + return nil + end + + for _, r in ipairs(results) do + local processed, rec = dmarc_common.dmarc_check_record(rspamd_config, r, false) + if processed and rec.rua then + -- We need to check or alter rua if needed + local processed_rua = process_rua(reporting_domain, rec.rua) + if processed_rua then + rec = rec.raw_elts + rec.rua = processed_rua + + -- Fill defaults in a record to avoid nils in a report + rec['pct'] = rec['pct'] or 100 + rec['adkim'] = rec['adkim'] or 'r' + rec['aspf'] = rec['aspf'] or 'r' + rec['p'] = rec['p'] or 'none' + rec['sp'] = rec['sp'] or 'none' + return rec + end + return nil + end + end + + return nil +end + +-- Returns a list of recipients from a table as a string processing elements if needed +local function rcpt_list(tbl, func) + local res = {} + for _, r in ipairs(tbl) do + if func then + table.insert(res, func(r)) + else + table.insert(res, r) + end + end + + return table.concat(res, ',') +end + +-- Synchronous smtp send function +local function send_reports_by_smtp(opts, reports, finish_cb) + local lua_smtp = require "lua_smtp" + local reports_failed = 0 + local reports_sent = 0 + local report_settings = dmarc_settings.reporting + + local function gen_sendmail_cb(report, args) + return function(ret, err) + -- We modify this from all callbacks + args.nreports = args.nreports - 1 + if not ret then + logger.errx("Couldn't send mail for %s: %s", report.reporting_domain, err) + reports_failed = reports_failed + 1 + else + reports_sent = reports_sent + 1 + lua_util.debugm(N, 'successfully sent a report for %s: %s bytes sent', + report.reporting_domain, #report.message) + end + + -- Tail call to the next batch or to the final function + if args.nreports == 0 then + if args.next_start > #reports then + finish_cb(reports_sent, reports_failed) + else + args.cont_func(args.next_start) + end + end + end + end + + local function send_data_in_batches(cur_batch) + local nreports = math.min(#reports - cur_batch + 1, opts.batch_size) + local next_start = cur_batch + nreports + lua_util.debugm(N, 'send data for %s domains (from %s to %s)', + nreports, cur_batch, next_start - 1) + -- Shared across all closures + local gen_args = { + cont_func = send_data_in_batches, + nreports = nreports, + next_start = next_start + } + for i = cur_batch, next_start - 1 do + local report = reports[i] + local send_opts = { + ev_base = rspamadm_ev_base, + session = rspamadm_session, + config = rspamd_config, + host = report_settings.smtp, + port = report_settings.smtp_port or 25, + resolver = rspamadm_dns_resolver, + from = report_settings.email, + recipients = report.rcpts, + helo = report_settings.helo or 'rspamd.localhost', + } + + lua_smtp.sendmail(send_opts, + report.message, + gen_sendmail_cb(report, gen_args)) + end + end + + send_data_in_batches(1) +end + +local function prepare_report(opts, start_time, end_time, rep_key) + local rua = get_rua(rep_key) + local reporting_domain = get_domain(rep_key) + + if not rua then + logger.errx('report %s has no valid rua, skip it', rep_key) + return nil + end + if not reporting_domain then + logger.errx('report %s has no valid reporting_domain, skip it', rep_key) + return nil + end + + local ret, results = lua_redis.request(redis_params, redis_attrs, + { 'EXISTS', rep_key }) + + if not ret or not results or results == 0 then + return nil + end + + -- Rename report key to avoid races + if not opts.no_opt then + lua_redis.request(redis_params, redis_attrs, + { 'RENAME', rep_key, rep_key .. '_processing' }) + rep_key = rep_key .. '_processing' + end + + local dmarc_record = validate_reporting_domain(reporting_domain) + lua_util.debugm(N, 'process reporting domain %s: %s', reporting_domain, dmarc_record) + + if not dmarc_record then + if not opts.no_opt then + lua_redis.request(redis_params, redis_attrs, + { 'DEL', rep_key }) + end + logger.messagex('Cannot process reports for domain %s; invalid dmarc record', reporting_domain) + return nil + end + + -- Get all reports for a domain + ret, results = lua_redis.request(redis_params, redis_attrs, + { 'ZRANGE', rep_key, '0', '-1', 'WITHSCORES' }) + local report_entries = {} + table.insert(report_entries, + report_header(reporting_domain, start_time, end_time, dmarc_record)) + for i = 1, #results, 2 do + local xml_record = entry_to_xml(process_report_entry(results[i], results[i + 1])) + table.insert(report_entries, xml_record) + end + table.insert(report_entries, '</feedback>') + local xml_to_compress = rspamd_text.fromtable(report_entries) + lua_util.debugm(N, 'got xml: %s', xml_to_compress) + + -- Prepare SMTP message + local report_settings = dmarc_settings.reporting + local rcpt_string = rcpt_list(dmarc_record.rua, function(rua_elt) + return string.format('%s@%s', rua_elt:get_user(), rua_elt:get_host()) + end) + local bcc_string + if report_settings.bcc_addrs then + bcc_string = rcpt_list(report_settings.bcc_addrs) + end + local uuid = gen_uuid() + local rhead = lua_util.jinja_template(report_template, { + from_name = report_settings.from_name, + from_addr = report_settings.email, + rcpt = rcpt_string, + bcc = bcc_string, + uuid = uuid, + reporting_domain = reporting_domain, + submitter = report_settings.domain, + report_id = string.format('%s.%d.%d', reporting_domain, start_time, + end_time), + report_date = rspamd_util.time_to_string(rspamd_util.get_time()), + message_id = rspamd_util.random_hex(16) .. '@' .. report_settings.msgid_from, + report_start = start_time, + report_end = end_time + }, true) + local rfooter = lua_util.jinja_template(report_footer, { + uuid = uuid, + }, true) + local message = rspamd_text.fromtable { + (rhead:gsub("\n", "\r\n")), + rspamd_util.encode_base64(rspamd_util.gzip_compress(xml_to_compress), 73), + rfooter:gsub("\n", "\r\n"), + } + + lua_util.debugm(N, 'got final message: %s', message) + + if not opts.no_opt then + lua_redis.request(redis_params, redis_attrs, + { 'DEL', rep_key }) + end + + local report_rcpts = lua_util.str_split(rcpt_string, ',') + + if report_settings.bcc_addrs then + for _, b in ipairs(report_settings.bcc_addrs) do + table.insert(report_rcpts, b) + end + end + + return { + message = message, + rcpts = report_rcpts, + reporting_domain = reporting_domain + } +end + +local function process_report_date(opts, start_time, end_time, date) + local idx_key = redis_prefix(dmarc_settings.reporting.redis_keys.index_prefix, date) + local ret, results = lua_redis.request(redis_params, redis_attrs, + { 'EXISTS', idx_key }) + + if not ret or not results or results == 0 then + logger.messagex('No reports for %s', date) + return {} + end + + -- Rename index key to avoid races + if not opts.no_opt then + lua_redis.request(redis_params, redis_attrs, + { 'RENAME', idx_key, idx_key .. '_processing' }) + idx_key = idx_key .. '_processing' + end + ret, results = lua_redis.request(redis_params, redis_attrs, + { 'SMEMBERS', idx_key }) + + if not ret or not results then + -- Remove bad key + if not opts.no_opt then + lua_redis.request(redis_params, redis_attrs, + { 'DEL', idx_key }) + end + logger.messagex('Cannot get reports for %s', date) + return {} + end + + local reports = {} + for _, rep in ipairs(results) do + local report = prepare_report(opts, start_time, end_time, rep) + + if report then + table.insert(reports, report) + end + end + + -- Shuffle reports to make sending more fair + lua_util.shuffle(reports) + -- Remove processed key + if not opts.no_opt then + lua_redis.request(redis_params, redis_attrs, + { 'DEL', idx_key }) + end + + return reports +end + + +-- Returns a day before today at 00:00 as unix seconds +local function yesterday_midnight() + local piecewise_time = os.date("*t") + piecewise_time.day = piecewise_time.day - 1 -- Lua allows negative values here + piecewise_time.hour = 0 + piecewise_time.sec = 0 + piecewise_time.min = 0 + return os.time(piecewise_time) +end + +-- Returns today time at 00:00 as unix seconds +local function today_midnight() + local piecewise_time = os.date("*t") + piecewise_time.hour = 0 + piecewise_time.sec = 0 + piecewise_time.min = 0 + return os.time(piecewise_time) +end + +local function handler(args) + local start_time + -- Preserve start time as report sending might take some time + local start_collection = today_midnight() + + local opts = parser:parse(args) + + pool = rspamd_mempool.create() + load_config(opts) + rspamd_url.init(rspamd_config:get_tld_path()) + + if opts.verbose then + lua_util.enable_debug_modules('dmarc', N) + end + + dmarc_settings = rspamd_config:get_all_opt('dmarc') + if not dmarc_settings or not dmarc_settings.reporting or not dmarc_settings.reporting.enabled then + logger.errx('dmarc reporting is not enabled, exiting') + os.exit(1) + end + + dmarc_settings = lua_util.override_defaults(dmarc_common.default_settings, dmarc_settings) + redis_params = lua_redis.parse_redis_server('dmarc', dmarc_settings) + + if not redis_params then + logger.errx('Redis is not configured, exiting') + os.exit(1) + end + + for _, e in ipairs({ 'email', 'domain', 'org_name' }) do + if not dmarc_settings.reporting[e] then + logger.errx('Missing required setting: dmarc.reporting.%s', e) + return + end + end + + local ret, results = lua_redis.request(redis_params, redis_attrs, { + 'GET', 'rspamd_dmarc_last_collection' + }) + + if not ret or not tonumber(results) then + start_time = yesterday_midnight() + else + start_time = tonumber(results) + end + + lua_util.debugm(N, 'previous last report date is %s', start_time) + + if not opts.date or #opts.date == 0 then + opts.date = {} + table.insert(opts.date, os.date('%Y%m%d', yesterday_midnight())) + end + + local ndates = 0 + local nreports = 0 + local all_reports = {} + for _, date in ipairs(opts.date) do + lua_util.debugm(N, 'Process date %s', date) + local reports_for_date = process_report_date(opts, start_time, start_collection, date) + if #reports_for_date > 0 then + ndates = ndates + 1 + nreports = nreports + #reports_for_date + + for _, r in ipairs(reports_for_date) do + table.insert(all_reports, r) + end + end + end + + local function finish_cb(nsuccess, nfail) + if not opts.no_opt then + lua_util.debugm(N, 'set last report date to %s', start_collection) + -- Hack to avoid coroutines + async functions mess: we use async redis call here + redis_attrs.callback = function() + logger.messagex('Reporting collection has finished %s dates processed, %s reports: %s completed, %s failed', + ndates, nreports, nsuccess, nfail) + end + lua_redis.request(redis_params, redis_attrs, + { 'SETEX', 'rspamd_dmarc_last_collection', dmarc_settings.reporting.keys_expire * 2, + tostring(start_collection) }) + else + logger.messagex('Reporting collection has finished %s dates processed, %s reports: %s completed, %s failed', + ndates, nreports, nsuccess, nfail) + end + + pool:destroy() + end + if not opts.no_opt then + send_reports_by_smtp(opts, all_reports, finish_cb) + else + logger.messagex('Skip sending mails due to -n / --no-opt option') + end +end + +return { + name = 'dmarc_report', + aliases = { 'dmarc_reporting' }, + handler = handler, + description = parser._description +} diff --git a/lualib/rspamadm/dns_tool.lua b/lualib/rspamadm/dns_tool.lua new file mode 100644 index 0000000..3eb09a8 --- /dev/null +++ b/lualib/rspamadm/dns_tool.lua @@ -0,0 +1,232 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + + +local argparse = require "argparse" +local rspamd_logger = require "rspamd_logger" +local ansicolors = require "ansicolors" +local bit = require "bit" + +local parser = argparse() + :name "rspamadm dnstool" + :description "DNS tools provided by Rspamd" + :help_description_margin(30) + :command_target("command") + :require_command(true) + +parser:option "-c --config" + :description "Path to config file" + :argname("<cfg>") + :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf") + +local spf = parser:command "spf" + :description "Extracts spf records" +spf:mutex( + spf:option "-d --domain" + :description "Domain to use" + :argname("<domain>"), + spf:option "-f --from" + :description "SMTP from to use" + :argname("<from>") +) + +spf:option "-i --ip" + :description "Source IP address to use" + :argname("<ip>") +spf:flag "-a --all" + :description "Print all records" + +local function printf(fmt, ...) + if fmt then + io.write(string.format(fmt, ...)) + end + io.write('\n') +end + +local function highlight(str) + return ansicolors.white .. str .. ansicolors.reset +end + +local function green(str) + return ansicolors.green .. str .. ansicolors.reset +end + +local function red(str) + return ansicolors.red .. str .. ansicolors.reset +end + +local function load_config(opts) + local _r, err = rspamd_config:load_ucl(opts['config']) + + if not _r then + rspamd_logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + rspamd_logger.errx('cannot process %s: %s', opts['config'], err) + os.exit(1) + end +end + +local function spf_handler(opts) + local rspamd_spf = require "rspamd_spf" + local rspamd_task = require "rspamd_task" + local rspamd_ip = require "rspamd_ip" + + local task = rspamd_task:create(rspamd_config, rspamadm_ev_base) + task:set_session(rspamadm_session) + task:set_resolver(rspamadm_dns_resolver) + + if opts.ip then + opts.ip = rspamd_ip.fromstring(opts.ip) + task:set_from_ip(opts.ip) + else + opts.all = true + end + + if opts.from then + local rspamd_parsers = require "rspamd_parsers" + local addr_parsed = rspamd_parsers.parse_mail_address(opts.from) + if addr_parsed then + task:set_from('smtp', addr_parsed[1]) + else + io.stderr:write('Invalid from addr\n') + os.exit(1) + end + elseif opts.domain then + task:set_from('smtp', { user = 'user', domain = opts.domain }) + else + io.stderr:write('Neither domain nor from specified\n') + os.exit(1) + end + + local function flag_to_str(fl) + if bit.band(fl, rspamd_spf.flags.temp_fail) ~= 0 then + return "temporary failure" + elseif bit.band(fl, rspamd_spf.flags.perm_fail) ~= 0 then + return "permanent failure" + elseif bit.band(fl, rspamd_spf.flags.na) ~= 0 then + return "no spf record" + end + + return "unknown flag: " .. tostring(fl) + end + + local function display_spf_results(elt, colored) + local dec = function(e) + return e + end + local policy_decode = function(e) + if e == rspamd_spf.policy.fail then + return 'reject' + elseif e == rspamd_spf.policy.pass then + return 'pass' + elseif e == rspamd_spf.policy.soft_fail then + return 'soft fail' + elseif e == rspamd_spf.policy.neutral then + return 'neutral' + end + + return 'unknown' + end + + if colored then + dec = function(e) + return highlight(e) + end + + if elt.result == rspamd_spf.policy.pass then + dec = function(e) + return green(e) + end + elseif elt.result == rspamd_spf.policy.fail then + dec = function(e) + return red(e) + end + end + + end + printf('%s: %s', highlight('Policy'), dec(policy_decode(elt.result))) + printf('%s: %s', highlight('Network'), dec(elt.addr)) + + if elt.str then + printf('%s: %s', highlight('Original'), elt.str) + end + end + + local function cb(record, flags, err) + if record then + local result, flag_or_policy, error_or_addr + if opts.ip then + result, flag_or_policy, error_or_addr = record:check_ip(opts.ip) + elseif opts.all then + result = true + end + if opts.ip and not opts.all then + if result then + display_spf_results(error_or_addr, true) + else + printf('Not matched: %s', error_or_addr) + end + + os.exit(0) + end + + if result then + printf('SPF record for %s; digest: %s', + highlight(opts.domain or opts.from), highlight(record:get_digest())) + for _, elt in ipairs(record:get_elts()) do + if result and error_or_addr and elt.str and elt.str == error_or_addr.str then + printf("%s", highlight('*** Matched ***')) + display_spf_results(elt, true) + printf('------') + else + display_spf_results(elt, false) + printf('------') + end + end + else + printf('Error getting SPF record: %s (%s flag)', err, + flag_to_str(flag_or_policy or flags)) + end + else + printf('Cannot get SPF record: %s', err) + end + end + rspamd_spf.resolve(task, cb) +end + +local function handler(args) + local opts = parser:parse(args) + load_config(opts) + + local command = opts.command + + if command == 'spf' then + spf_handler(opts) + else + parser:error('command %s is not implemented', command) + end +end + +return { + name = 'dnstool', + aliases = { 'dns', 'dns_tool' }, + handler = handler, + description = parser._description +}
\ No newline at end of file diff --git a/lualib/rspamadm/fuzzy_convert.lua b/lualib/rspamadm/fuzzy_convert.lua new file mode 100644 index 0000000..fab3995 --- /dev/null +++ b/lualib/rspamadm/fuzzy_convert.lua @@ -0,0 +1,208 @@ +local sqlite3 = require "rspamd_sqlite3" +local redis = require "rspamd_redis" +local util = require "rspamd_util" + +local function connect_redis(server, username, password, db) + local ret + local conn, err = redis.connect_sync({ + host = server, + }) + + if not conn then + return nil, 'Cannot connect: ' .. err + end + + if username then + if password then + ret = conn:add_cmd('AUTH', { username, password }) + if not ret then + return nil, 'Cannot queue command' + end + else + return nil, 'Redis requires a password when username is supplied' + end + elseif password then + ret = conn:add_cmd('AUTH', { password }) + if not ret then + return nil, 'Cannot queue command' + end + end + if db then + ret = conn:add_cmd('SELECT', { db }) + if not ret then + return nil, 'Cannot queue command' + end + end + + return conn, nil +end + +local function send_digests(digests, redis_host, redis_username, redis_password, redis_db) + local conn, err = connect_redis(redis_host, redis_username, redis_password, redis_db) + if err then + print(err) + return false + end + local ret + for _, v in ipairs(digests) do + ret = conn:add_cmd('HMSET', { + 'fuzzy' .. v[1], + 'F', v[2], + 'V', v[3], + }) + if not ret then + print('Cannot batch command') + return false + end + ret = conn:add_cmd('EXPIRE', { + 'fuzzy' .. v[1], + tostring(v[4]), + }) + if not ret then + print('Cannot batch command') + return false + end + end + ret, err = conn:exec() + if not ret then + print('Cannot execute batched commands: ' .. err) + return false + end + return true +end + +local function send_shingles(shingles, redis_host, redis_username, redis_password, redis_db) + local conn, err = connect_redis(redis_host, redis_username, redis_password, redis_db) + if err then + print("Redis error: " .. err) + return false + end + local ret + for _, v in ipairs(shingles) do + ret = conn:add_cmd('SET', { + 'fuzzy_' .. v[2] .. '_' .. v[1], + v[4], + }) + if not ret then + print('Cannot batch SET command: ' .. err) + return false + end + ret = conn:add_cmd('EXPIRE', { + 'fuzzy_' .. v[2] .. '_' .. v[1], + tostring(v[3]), + }) + if not ret then + print('Cannot batch command') + return false + end + end + ret, err = conn:exec() + if not ret then + print('Cannot execute batched commands: ' .. err) + return false + end + return true +end + +local function update_counters(total, redis_host, redis_username, redis_password, redis_db) + local conn, err = connect_redis(redis_host, redis_username, redis_password, redis_db) + if err then + print(err) + return false + end + local ret + ret = conn:add_cmd('SET', { + 'fuzzylocal', + total, + }) + if not ret then + print('Cannot batch command') + return false + end + ret = conn:add_cmd('SET', { + 'fuzzy_count', + total, + }) + if not ret then + print('Cannot batch command') + return false + end + ret, err = conn:exec() + if not ret then + print('Cannot execute batched commands: ' .. err) + return false + end + return true +end + +return function(_, res) + local db = sqlite3.open(res['source_db']) + local shingles = {} + local digests = {} + local num_batch_digests = 0 + local num_batch_shingles = 0 + local total_digests = 0 + local total_shingles = 0 + local lim_batch = 1000 -- Update each 1000 entries + local redis_username = res['redis_username'] + local redis_password = res['redis_password'] + local redis_db = nil + + if res['redis_db'] then + redis_db = tostring(res['redis_db']) + end + + if not db then + print('Cannot open source db: ' .. res['source_db']) + return + end + + local now = util.get_time() + for row in db:rows('SELECT id, flag, digest, value, time FROM digests') do + + local expire_in = math.floor(now - row.time + res['expiry']) + if expire_in >= 1 then + table.insert(digests, { row.digest, row.flag, row.value, expire_in }) + num_batch_digests = num_batch_digests + 1 + total_digests = total_digests + 1 + for srow in db:rows('SELECT value, number FROM shingles WHERE digest_id = ' .. row.id) do + table.insert(shingles, { srow.value, srow.number, expire_in, row.digest }) + total_shingles = total_shingles + 1 + num_batch_shingles = num_batch_shingles + 1 + end + end + if num_batch_digests >= lim_batch then + if not send_digests(digests, res['redis_host'], redis_username, redis_password, redis_db) then + return + end + num_batch_digests = 0 + digests = {} + end + if num_batch_shingles >= lim_batch then + if not send_shingles(shingles, res['redis_host'], redis_username, redis_password, redis_db) then + return + end + num_batch_shingles = 0 + shingles = {} + end + end + if digests[1] then + if not send_digests(digests, res['redis_host'], redis_username, redis_password, redis_db) then + return + end + end + if shingles[1] then + if not send_shingles(shingles, res['redis_host'], redis_username, redis_password, redis_db) then + return + end + end + + local message = string.format( + 'Migrated %d digests and %d shingles', + total_digests, total_shingles + ) + if not update_counters(total_digests, res['redis_host'], redis_username, redis_password, redis_db) then + message = message .. ' but failed to update counters' + end + print(message) +end diff --git a/lualib/rspamadm/fuzzy_ping.lua b/lualib/rspamadm/fuzzy_ping.lua new file mode 100644 index 0000000..e0345da --- /dev/null +++ b/lualib/rspamadm/fuzzy_ping.lua @@ -0,0 +1,259 @@ +--[[ +Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" +local ansicolors = require "ansicolors" +local rspamd_logger = require "rspamd_logger" +local lua_util = require "lua_util" + +local E = {} + +local parser = argparse() + :name 'rspamadm fuzzy_ping' + :description 'Pings fuzzy storage' + :help_description_margin(30) +parser:option "-c --config" + :description "Path to config file" + :argname("<cfg>") + :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf") +parser:option "-r --rule" + :description "Storage to ping (must be configured in Rspamd configuration)" + :argname("<name>") + :default("rspamd.com") +parser:flag "-f --flood" + :description "Flood mode (no waiting for replies)" +parser:flag "-S --silent" + :description "Silent mode (statistics only)" +parser:option "-t --timeout" + :description "Timeout for requests" + :argname("<timeout>") + :convert(tonumber) + :default(5) +parser:option "-s --server" + :description "Override server to ping" + :argname("<name>") +parser:option "-n --number" + :description "Timeout for requests" + :argname("<number>") + :convert(tonumber) + :default(5) +parser:flag "-l --list" + :description "List configured storages" + +local function load_config(opts) + local _r, err = rspamd_config:load_ucl(opts['config']) + + if not _r then + rspamd_logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + -- Init the real structure excluding logging and workers + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + rspamd_logger.errx('cannot process %s: %s', opts['config'], err) + os.exit(1) + end + + _r, err = rspamd_config:init_modules() + if not _r then + rspamd_logger.errx('cannot init modules from %s: %s', opts['config'], err) + os.exit(1) + end +end + +local function highlight(fmt, ...) + return ansicolors.white .. string.format(fmt, ...) .. ansicolors.reset +end + +local function highlight_err(fmt, ...) + return ansicolors.red .. string.format(fmt, ...) .. ansicolors.reset +end + +local function print_storages(rules) + for n, rule in pairs(rules) do + print(highlight('Rule: %s', n)) + print(string.format("\tRead only: %s", rule.read_only)) + print(string.format("\tServers: %s", table.concat(lua_util.values(rule.servers), ','))) + print("\tFlags:") + + for fl, id in pairs(rule.flags or E) do + print(string.format("\t\t%s: %s", fl, id)) + end + end +end + +local function std_mean(tbl) + local function mean() + local sum = 0 + local count = 0 + + for _, v in ipairs(tbl) do + sum = sum + v + count = count + 1 + end + + return (sum / count) + end + + local m + local vm + local sum = 0 + local count = 0 + local result + + m = mean(tbl) + + for _, v in ipairs(tbl) do + vm = v - m + sum = sum + (vm * vm) + count = count + 1 + end + + result = math.sqrt(sum / (count - 1)) + + return result, m +end + +local function maxmin(tbl) + local max = -math.huge + local min = math.huge + + for _, v in ipairs(tbl) do + max = math.max(max, v) + min = math.min(min, v) + end + + return max, min +end + +local function print_results(results) + local servers = {} + local err_servers = {} + for _, res in ipairs(results) do + if res.success then + if servers[res.server] then + table.insert(servers[res.server], res.latency) + else + servers[res.server] = { res.latency } + end + else + if err_servers[res.server] then + err_servers[res.server] = err_servers[res.server] + 1 + else + err_servers[res.server] = 1 + end + -- For the case if no successful replies are detected + if not servers[res.server] then + servers[res.server] = {} + end + end + end + + for s, l in pairs(servers) do + local total = #l + (err_servers[s] or 0) + print(highlight('Summary for %s: %d packets transmitted, %d packets received, %.1f%% packet loss', + s, total, #l, (total - #l) * 100.0 / total)) + local mean, std = std_mean(l) + local max, min = maxmin(l) + print(string.format('round-trip min/avg/max/std-dev = %.2f/%.2f/%.2f/%.2f ms', + min, mean, + max, std)) + end +end + +local function handler(args) + local opts = parser:parse(args) + + load_config(opts) + + if opts.list then + print_storages(rspamd_plugins.fuzzy_check.list_storages(rspamd_config)) + os.exit(0) + end + + -- Perform ping using a fake task from async stuff provided by rspamadm + local rspamd_task = require "rspamd_task" + + -- TODO: this task is not cleared at the end, do something about it some day + local task = rspamd_task.create(rspamd_config, rspamadm_ev_base) + task:set_session(rspamadm_session) + task:set_resolver(rspamadm_dns_resolver) + + local replied = 0 + local results = {} + local ping_fuzzy + + local function gen_ping_fuzzy_cb(num) + return function(success, server, latency_or_err) + if not success then + if not opts.silent then + print(highlight_err('error from %s: %s', server, latency_or_err)) + end + results[num] = { + success = false, + error = latency_or_err, + server = tostring(server), + } + else + if not opts.silent then + local adjusted_latency = math.floor(latency_or_err * 1000) * 1.0 / 1000; + print(highlight('reply from %s: %s ms', server, adjusted_latency)) + + end + results[num] = { + success = true, + latency = latency_or_err, + server = tostring(server), + } + end + + if replied == opts.number - 1 then + print_results(results) + else + replied = replied + 1 + if not opts.flood then + ping_fuzzy(replied + 1) + end + end + end + end + + ping_fuzzy = function(num) + local ret, err = rspamd_plugins.fuzzy_check.ping_storage(task, gen_ping_fuzzy_cb(num), + opts.rule, opts.timeout, opts.server) + + if not ret then + print(highlight_err('error from %s: %s', opts.server, err)) + opts.number = opts.number - 1 -- To avoid issues with waiting for other replies + end + end + + if opts.flood then + for i = 1, opts.number do + ping_fuzzy(i) + end + else + ping_fuzzy(1) + end +end + +return { + name = 'fuzzy_ping', + aliases = { 'fuzzyping' }, + handler = handler, + description = parser._description +}
\ No newline at end of file diff --git a/lualib/rspamadm/fuzzy_stat.lua b/lualib/rspamadm/fuzzy_stat.lua new file mode 100644 index 0000000..ef8a5de --- /dev/null +++ b/lualib/rspamadm/fuzzy_stat.lua @@ -0,0 +1,366 @@ +local rspamd_util = require "rspamd_util" +local lua_util = require "lua_util" +local opts = {} + +local argparse = require "argparse" +local parser = argparse() + :name "rspamadm control fuzzystat" + :description "Shows help for the specified configuration options" + :help_description_margin(32) +parser:flag "--no-ips" + :description "No IPs stats" +parser:flag "--no-keys" + :description "No keys stats" +parser:flag "--short" + :description "Short output mode" +parser:flag "-n --number" + :description "Disable numbers humanization" +parser:option "--sort" + :description "Sort order" + :convert { + checked = "checked", + matched = "matched", + errors = "errors", + name = "name" +} + +local function add_data(target, src) + for k, v in pairs(src) do + if type(v) == 'number' then + if target[k] then + target[k] = target[k] + v + else + target[k] = v + end + elseif k == 'ips' then + if not target['ips'] then + target['ips'] = {} + end + -- Iterate over IPs + for ip, st in pairs(v) do + if not target['ips'][ip] then + target['ips'][ip] = {} + end + add_data(target['ips'][ip], st) + end + elseif k == 'flags' then + if not target['flags'] then + target['flags'] = {} + end + -- Iterate over Flags + for flag, st in pairs(v) do + if not target['flags'][flag] then + target['flags'][flag] = {} + end + add_data(target['flags'][flag], st) + end + elseif k == 'keypair' then + if type(v.extensions) == 'table' then + if type(v.extensions.name) == 'string' then + target.name = v.extensions.name + end + end + end + end +end + +local function print_num(num) + if num then + if opts['n'] or opts['number'] then + return tostring(num) + else + return rspamd_util.humanize_number(num) + end + else + return 'na' + end +end + +local function print_stat(st, tabs) + if st['checked'] then + if st.checked_per_hour then + print(string.format('%sChecked: %s (%s per hour in average)', tabs, + print_num(st['checked']), print_num(st['checked_per_hour']))) + else + print(string.format('%sChecked: %s', tabs, print_num(st['checked']))) + end + end + if st['matched'] then + if st.checked and st.checked > 0 and st.checked <= st.matched then + local percentage = st.matched / st.checked * 100.0 + if st.matched_per_hour then + print(string.format('%sMatched: %s - %s percent (%s per hour in average)', tabs, + print_num(st['matched']), percentage, print_num(st['matched_per_hour']))) + else + print(string.format('%sMatched: %s - %s percent', tabs, print_num(st['matched']), percentage)) + end + else + if st.matched_per_hour then + print(string.format('%sMatched: %s (%s per hour in average)', tabs, + print_num(st['matched']), print_num(st['matched_per_hour']))) + else + print(string.format('%sMatched: %s', tabs, print_num(st['matched']))) + end + end + end + if st['errors'] then + print(string.format('%sErrors: %s', tabs, print_num(st['errors']))) + end + if st['added'] then + print(string.format('%sAdded: %s', tabs, print_num(st['added']))) + end + if st['deleted'] then + print(string.format('%sDeleted: %s', tabs, print_num(st['deleted']))) + end +end + +-- Sort by checked +local function sort_hash_table(tbl, sort_opts, key_key) + local res = {} + for k, v in pairs(tbl) do + table.insert(res, { [key_key] = k, data = v }) + end + + local function sort_order(elt) + local key = 'checked' + local sort_res = 0 + + if sort_opts['sort'] then + if sort_opts['sort'] == 'matched' then + key = 'matched' + elseif sort_opts['sort'] == 'errors' then + key = 'errors' + elseif sort_opts['sort'] == 'name' then + return elt[key_key] + end + end + + if elt.data[key] then + sort_res = elt.data[key] + end + + return sort_res + end + + table.sort(res, function(a, b) + return sort_order(a) > sort_order(b) + end) + + return res +end + +local function add_result(dst, src, k) + if type(src) == 'table' then + if type(dst) == 'number' then + -- Convert dst to table + dst = { dst } + elseif type(dst) == 'nil' then + dst = {} + end + + for i, v in ipairs(src) do + if dst[i] and k ~= 'fuzzy_stored' then + dst[i] = dst[i] + v + else + dst[i] = v + end + end + else + if type(dst) == 'table' then + if k ~= 'fuzzy_stored' then + dst[1] = dst[1] + src + else + dst[1] = src + end + else + if dst and k ~= 'fuzzy_stored' then + dst = dst + src + else + dst = src + end + end + end + + return dst +end + +local function print_result(r) + local function num_to_epoch(num) + if num == 1 then + return 'v0.6' + elseif num == 2 then + return 'v0.8' + elseif num == 3 then + return 'v0.9' + elseif num == 4 then + return 'v1.0+' + elseif num == 5 then + return 'v1.7+' + end + return '???' + end + if type(r) == 'table' then + local res = {} + for i, num in ipairs(r) do + res[i] = string.format('(%s: %s)', num_to_epoch(i), print_num(num)) + end + + return table.concat(res, ', ') + end + + return print_num(r) +end + +return function(args, res) + local res_ips = {} + local res_databases = {} + local wrk = res['workers'] + opts = parser:parse(args) + + if wrk then + for _, pr in pairs(wrk) do + -- processes cycle + if pr['data'] then + local id = pr['id'] + + if id then + local res_db = res_databases[id] + if not res_db then + res_db = { + keys = {} + } + res_databases[id] = res_db + end + + -- General stats + for k, v in pairs(pr['data']) do + if k ~= 'keys' and k ~= 'errors_ips' then + res_db[k] = add_result(res_db[k], v, k) + elseif k == 'errors_ips' then + -- Errors ips + if not res_db['errors_ips'] then + res_db['errors_ips'] = {} + end + for ip, nerrors in pairs(v) do + if not res_db['errors_ips'][ip] then + res_db['errors_ips'][ip] = nerrors + else + res_db['errors_ips'][ip] = nerrors + res_db['errors_ips'][ip] + end + end + end + end + + if pr['data']['keys'] then + local res_keys = res_db['keys'] + if not res_keys then + res_keys = {} + res_db['keys'] = res_keys + end + -- Go through keys in input + for k, elts in pairs(pr['data']['keys']) do + -- keys cycle + if not res_keys[k] then + res_keys[k] = {} + end + + add_data(res_keys[k], elts) + + if elts['ips'] then + for ip, v in pairs(elts['ips']) do + if not res_ips[ip] then + res_ips[ip] = {} + end + add_data(res_ips[ip], v) + end + end + end + end + end + end + end + end + + -- General stats + for db, st in pairs(res_databases) do + print(string.format('Statistics for storage %s', db)) + + for k, v in pairs(st) do + if k ~= 'keys' and k ~= 'errors_ips' then + print(string.format('%s: %s', k, print_result(v))) + end + end + print('') + + local res_keys = st['keys'] + if res_keys and not opts['no_keys'] and not opts['short'] then + print('Keys statistics:') + -- Convert into an array to allow sorting + local sorted_keys = sort_hash_table(res_keys, opts, 'key') + + for _, key in ipairs(sorted_keys) do + local key_stat = key.data + if key_stat.name then + print(string.format('Key id: %s, name: %s', key.key, key_stat.name)) + else + print(string.format('Key id: %s', key.key)) + end + + print_stat(key_stat, '\t') + + if key_stat['ips'] and not opts['no_ips'] then + print('') + print('\tIPs stat:') + local sorted_ips = sort_hash_table(key_stat['ips'], opts, 'ip') + + for _, v in ipairs(sorted_ips) do + print(string.format('\t%s', v['ip'])) + print_stat(v['data'], '\t\t') + print('') + end + end + + if key_stat.flags then + print('') + print('\tFlags stat:') + for flag, v in pairs(key_stat.flags) do + print(string.format('\t[%s]:', flag)) + -- Remove irrelevant fields + v.checked = nil + print_stat(v, '\t\t') + print('') + end + end + + print('') + end + end + if st['errors_ips'] and not opts['no_ips'] and not opts['short'] then + print('') + print('Errors IPs statistics:') + local ip_stat = st['errors_ips'] + local ips = lua_util.keys(ip_stat) + -- Reverse sort by number of errors + table.sort(ips, function(a, b) + return ip_stat[a] > ip_stat[b] + end) + for _, ip in ipairs(ips) do + print(string.format('%s: %s', ip, print_result(ip_stat[ip]))) + end + print('') + end + end + + if not opts['no_ips'] and not opts['short'] then + print('') + print('IPs statistics:') + + local sorted_ips = sort_hash_table(res_ips, opts, 'ip') + for _, v in ipairs(sorted_ips) do + print(string.format('%s', v['ip'])) + print_stat(v['data'], '\t') + print('') + end + end +end + diff --git a/lualib/rspamadm/grep.lua b/lualib/rspamadm/grep.lua new file mode 100644 index 0000000..6ed0569 --- /dev/null +++ b/lualib/rspamadm/grep.lua @@ -0,0 +1,174 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" + + +-- Define command line options +local parser = argparse() + :name "rspamadm grep" + :description "Search for patterns in rspamd logs" + :help_description_margin(30) +parser:mutex( + parser:option "-s --string" + :description('Plain string to search (case-insensitive)') + :argname "<str>", + parser:option "-p --pattern" + :description('Pattern to search for (regex)') + :argname "<re>" +) +parser:flag "-l --lua" + :description('Use Lua patterns in string search') + +parser:argument "input":args "*" + :description('Process specified inputs') + :default("stdin") +parser:flag "-S --sensitive" + :description('Enable case-sensitivity in string search') +parser:flag "-o --orphans" + :description('Print orphaned logs') +parser:flag "-P --partial" + :description('Print partial logs') + +local function handler(args) + + local rspamd_regexp = require 'rspamd_regexp' + local res = parser:parse(args) + + if not res['string'] and not res['pattern'] then + parser:error('string or pattern options must be specified') + end + + if res['string'] and res['pattern'] then + parser:error('string and pattern options are mutually exclusive') + end + + local buffer = {} + local matches = {} + + local pattern = res['pattern'] + local re + if pattern then + re = rspamd_regexp.create(pattern) + if not re then + io.stderr:write("Couldn't compile regex: " .. pattern .. '\n') + os.exit(1) + end + end + + local plainm = true + if res['lua'] then + plainm = false + end + local orphans = res['orphans'] + local search_str = res['string'] + local sensitive = res['sensitive'] + local partial = res['partial'] + if search_str and not sensitive then + search_str = string.lower(search_str) + end + local inputs = res['input'] or { 'stdin' } + + for _, n in ipairs(inputs) do + local h, err + if string.match(n, '%.xz$') then + h, err = io.popen('xzcat ' .. n, 'r') + elseif string.match(n, '%.bz2$') then + h, err = io.popen('bzcat ' .. n, 'r') + elseif string.match(n, '%.gz$') then + h, err = io.popen('zcat ' .. n, 'r') + elseif string.match(n, '%.zst$') then + h, err = io.popen('zstdcat ' .. n, 'r') + elseif n == 'stdin' then + h = io.input() + else + h, err = io.open(n, 'r') + end + if not h then + if err then + io.stderr:write("Couldn't open file (" .. n .. '): ' .. err .. '\n') + else + io.stderr:write("Couldn't open file (" .. n .. '): no error\n') + end + else + for line in h:lines() do + local hash = string.match(line, '<(%x+)>') + local already_matching = false + if hash then + if matches[hash] then + table.insert(matches[hash], line) + already_matching = true + else + if buffer[hash] then + table.insert(buffer[hash], line) + else + buffer[hash] = { line } + end + end + end + local ismatch = false + if re then + ismatch = re:match(line) + elseif sensitive and search_str then + ismatch = string.find(line, search_str, 1, plainm) + elseif search_str then + local lwr = string.lower(line) + ismatch = string.find(lwr, search_str, 1, plainm) + end + if ismatch then + if not hash then + if orphans then + print('*** orphaned ***') + print(line) + print() + end + elseif not already_matching then + matches[hash] = buffer[hash] + end + end + local is_end = string.match(line, '<%x+>; task; rspamd_protocol_http_reply:') + if is_end then + buffer[hash] = nil + if matches[hash] then + for _, v in ipairs(matches[hash]) do + print(v) + end + print() + matches[hash] = nil + end + end + end + if partial then + for k, v in pairs(matches) do + print('*** partial ***') + for _, vv in ipairs(v) do + print(vv) + end + print() + matches[k] = nil + end + else + matches = {} + end + end + end +end + +return { + handler = handler, + description = parser._description, + name = 'grep' +}
\ No newline at end of file diff --git a/lualib/rspamadm/keypair.lua b/lualib/rspamadm/keypair.lua new file mode 100644 index 0000000..f0716a2 --- /dev/null +++ b/lualib/rspamadm/keypair.lua @@ -0,0 +1,508 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" +local rspamd_keypair = require "rspamd_cryptobox_keypair" +local rspamd_pubkey = require "rspamd_cryptobox_pubkey" +local rspamd_signature = require "rspamd_cryptobox_signature" +local rspamd_crypto = require "rspamd_cryptobox" +local rspamd_util = require "rspamd_util" +local ucl = require "ucl" +local logger = require "rspamd_logger" + +-- Define command line options +local parser = argparse() + :name "rspamadm keypair" + :description "Manages keypairs for Rspamd" + :help_description_margin(30) + :command_target("command") + :require_command(false) + +-- Generate subcommand +local generate = parser:command "generate gen g" + :description "Creates a new keypair" +generate:flag "-s --sign" + :description "Generates a sign keypair instead of the encryption one" +generate:flag "-n --nist" + :description "Uses nist encryption algorithm" +generate:option "-o --output" + :description "Write keypair to file" + :argname "<file>" +generate:mutex( + generate:flag "-j --json" + :description "Output JSON instead of UCL", + generate:flag "-u --ucl" + :description "Output UCL" + :default(true) +) +generate:option "--name" + :description "Adds name extension" + :argname "<name>" + +-- Sign subcommand + +local sign = parser:command "sign sig s" + :description "Signs a file using keypair" +sign:option "-k --keypair" + :description "Keypair to use" + :argname "<file>" +sign:option "-s --suffix" + :description "Suffix for signature" + :argname "<suffix>" + :default("sig") +sign:argument "file" + :description "File to sign" + :argname "<file>" + :args "*" + +-- Verify subcommand + +local verify = parser:command "verify ver v" + :description "Verifies a file using keypair or a public key" +verify:mutex( + verify:option "-p --pubkey" + :description "Load pubkey from the specified file" + :argname "<file>", + verify:option "-P --pubstring" + :description "Load pubkey from the base32 encoded string" + :argname "<base32>", + verify:option "-k --keypair" + :description "Get pubkey from the keypair file" + :argname "<file>" +) +verify:argument "file" + :description "File to verify" + :argname "<file>" + :args "*" +verify:flag "-n --nist" + :description "Uses nistp curves (P256)" +verify:option "-s --suffix" + :description "Suffix for signature" + :argname "<suffix>" + :default("sig") + +-- Encrypt subcommand + +local encrypt = parser:command "encrypt crypt enc e" + :description "Encrypts a file using keypair (or a pubkey)" +encrypt:mutex( + encrypt:option "-p --pubkey" + :description "Load pubkey from the specified file" + :argname "<file>", + encrypt:option "-P --pubstring" + :description "Load pubkey from the base32 encoded string" + :argname "<base32>", + encrypt:option "-k --keypair" + :description "Get pubkey from the keypair file" + :argname "<file>" +) +encrypt:option "-s --suffix" + :description "Suffix for encrypted file" + :argname "<suffix>" + :default("enc") +encrypt:argument "file" + :description "File to encrypt" + :argname "<file>" + :args "*" +encrypt:flag "-r --rm" + :description "Remove unencrypted file" +encrypt:flag "-f --force" + :description "Remove destination file if it exists" + +-- Decrypt subcommand + +local decrypt = parser:command "decrypt dec d" + :description "Decrypts a file using keypair" +decrypt:option "-k --keypair" + :description "Get pubkey from the keypair file" + :argname "<file>" +decrypt:flag "-S --keep-suffix" + :description "Preserve suffix for decrypted file (overwrite encrypted)" +decrypt:argument "file" + :description "File to encrypt" + :argname "<file>" + :args "*" +decrypt:flag "-f --force" + :description "Remove destination file if it exists (implied with -S)" +decrypt:flag "-r --rm" + :description "Remove encrypted file" + +-- Default command is generate, so duplicate options to be compatible + +parser:flag "-s --sign" + :description "Generates a sign keypair instead of the encryption one" +parser:flag "-n --nist" + :description "Uses nistp curves (P256)" +parser:mutex( + parser:flag "-j --json" + :description "Output JSON instead of UCL", + parser:flag "-u --ucl" + :description "Output UCL" + :default(true) +) +parser:option "-o --output" + :description "Write keypair to file" + :argname "<file>" + +local function fatal(...) + logger.errx(...) + os.exit(1) +end + +local function ask_yes_no(greet, default) + local def_str + if default then + greet = greet .. "[Y/n]: " + def_str = "yes" + else + greet = greet .. "[y/N]: " + def_str = "no" + end + + local reply = rspamd_util.readline(greet) + + if not reply then + os.exit(0) + end + if #reply == 0 then + reply = def_str + end + reply = reply:lower() + if reply == 'y' or reply == 'yes' then + return true + end + + return false +end + +local function generate_handler(opts) + local mode = 'encryption' + if opts.sign then + mode = 'sign' + end + local alg = 'curve25519' + if opts.nist then + alg = 'nist' + end + -- TODO: probably, do it in a more safe way + local kp = rspamd_keypair.create(mode, alg):totable() + + if opts.name then + kp.keypair.extensions = { + name = opts.name + } + end + + local format = 'ucl' + + if opts.json then + format = 'json' + end + + if opts.output then + local out = io.open(opts.output, 'w') + if not out then + fatal('cannot open output to write: ' .. opts.output) + end + out:write(ucl.to_format(kp, format)) + out:close() + else + io.write(ucl.to_format(kp, format)) + end +end + +local function sign_handler(opts) + if opts.file then + if type(opts.file) == 'string' then + opts.file = { opts.file } + end + else + parser:error('no files to sign') + end + if not opts.keypair then + parser:error("no keypair specified") + end + + local ucl_parser = ucl.parser() + local res, err = ucl_parser:parse_file(opts.keypair) + + if not res then + fatal(string.format('cannot load %s: %s', opts.keypair, err)) + end + + local kp = rspamd_keypair.load(ucl_parser:get_object()) + + if not kp then + fatal("cannot load keypair: " .. opts.keypair) + end + + for _, fname in ipairs(opts.file) do + local sig = rspamd_crypto.sign_file(kp, fname) + + if not sig then + fatal(string.format("cannot sign %s\n", fname)) + end + + local out = string.format('%s.%s', fname, opts.suffix or 'sig') + local of = io.open(out, 'w') + if not of then + fatal('cannot open output to write: ' .. out) + end + of:write(sig:bin()) + of:close() + io.write(string.format('signed %s -> %s (%s)\n', fname, out, sig:hex())) + end +end + +local function verify_handler(opts) + if opts.file then + if type(opts.file) == 'string' then + opts.file = { opts.file } + end + else + parser:error('no files to verify') + end + + local pk + local alg = 'curve25519' + + if opts.keypair then + local ucl_parser = ucl.parser() + local res, err = ucl_parser:parse_file(opts.keypair) + + if not res then + fatal(string.format('cannot load %s: %s', opts.keypair, err)) + end + + local kp = rspamd_keypair.load(ucl_parser:get_object()) + + if not kp then + fatal("cannot load keypair: " .. opts.keypair) + end + + pk = kp:pk() + alg = kp:alg() + elseif opts.pubkey then + if opts.nist then + alg = 'nist' + end + pk = rspamd_pubkey.load(opts.pubkey, 'sign', alg) + elseif opts.pubstr then + if opts.nist then + alg = 'nist' + end + pk = rspamd_pubkey.create(opts.pubstr, 'sign', alg) + end + + if not pk then + fatal("cannot create pubkey") + end + + local valid = true + + for _, fname in ipairs(opts.file) do + + local sig_fname = string.format('%s.%s', fname, opts.suffix or 'sig') + local sig = rspamd_signature.load(sig_fname, alg) + + if not sig then + fatal(string.format("cannot load signature for %s -> %s", + fname, sig_fname)) + end + + if rspamd_crypto.verify_file(pk, sig, fname, alg) then + io.write(string.format('verified %s -> %s (%s)\n', fname, sig_fname, sig:hex())) + else + valid = false + io.write(string.format('FAILED to verify %s -> %s (%s)\n', fname, + sig_fname, sig:hex())) + end + end + + if not valid then + os.exit(1) + end +end + +local function encrypt_handler(opts) + if opts.file then + if type(opts.file) == 'string' then + opts.file = { opts.file } + end + else + parser:error('no files to sign') + end + + local pk + local alg = 'curve25519' + + if opts.keypair then + local ucl_parser = ucl.parser() + local res, err = ucl_parser:parse_file(opts.keypair) + + if not res then + fatal(string.format('cannot load %s: %s', opts.keypair, err)) + end + + local kp = rspamd_keypair.load(ucl_parser:get_object()) + + if not kp then + fatal("cannot load keypair: " .. opts.keypair) + end + + pk = kp:pk() + alg = kp:alg() + elseif opts.pubkey then + if opts.nist then + alg = 'nist' + end + pk = rspamd_pubkey.load(opts.pubkey, 'sign', alg) + elseif opts.pubstr then + if opts.nist then + alg = 'nist' + end + pk = rspamd_pubkey.create(opts.pubstr, 'sign', alg) + end + + if not pk then + fatal("cannot load keypair: " .. opts.keypair) + end + + for _, fname in ipairs(opts.file) do + local enc = rspamd_crypto.encrypt_file(pk, fname, alg) + + if not enc then + fatal(string.format("cannot encrypt %s\n", fname)) + end + + local out + if opts.suffix and #opts.suffix > 0 then + out = string.format('%s.%s', fname, opts.suffix) + else + out = string.format('%s', fname) + end + + if rspamd_util.file_exists(out) then + if opts.force or ask_yes_no(string.format('File %s already exists, overwrite?', + out), true) then + os.remove(out) + else + os.exit(1) + end + end + + enc:save_in_file(out) + + if opts.rm then + os.remove(fname) + io.write(string.format('encrypted %s (deleted) -> %s\n', fname, out)) + else + io.write(string.format('encrypted %s -> %s\n', fname, out)) + end + end +end + +local function decrypt_handler(opts) + if opts.file then + if type(opts.file) == 'string' then + opts.file = { opts.file } + end + else + parser:error('no files to decrypt') + end + if not opts.keypair then + parser:error("no keypair specified") + end + + local ucl_parser = ucl.parser() + local res, err = ucl_parser:parse_file(opts.keypair) + + if not res then + fatal(string.format('cannot load %s: %s', opts.keypair, err)) + end + + local kp = rspamd_keypair.load(ucl_parser:get_object()) + + if not kp then + fatal("cannot load keypair: " .. opts.keypair) + end + + for _, fname in ipairs(opts.file) do + local decrypted = rspamd_crypto.decrypt_file(kp, fname) + + if not decrypted then + fatal(string.format("cannot decrypt %s\n", fname)) + end + + local out + if not opts['keep-suffix'] then + -- Strip the last suffix + out = fname:match("^(.+)%..+$") + else + out = fname + end + + local removed = false + + if rspamd_util.file_exists(out) then + if (opts.force or opts['keep-suffix']) + or ask_yes_no(string.format('File %s already exists, overwrite?', out), true) then + os.remove(out) + removed = true + else + os.exit(1) + end + end + + if opts.rm then + os.remove(fname) + removed = true + end + + if removed then + io.write(string.format('decrypted %s (removed) -> %s\n', fname, out)) + else + io.write(string.format('decrypted %s -> %s\n', fname, out)) + end + end +end + +local function handler(args) + local opts = parser:parse(args) + + local command = opts.command or "generate" + + if command == 'generate' then + generate_handler(opts) + elseif command == 'sign' then + sign_handler(opts) + elseif command == 'verify' then + verify_handler(opts) + elseif command == 'encrypt' then + encrypt_handler(opts) + elseif command == 'decrypt' then + decrypt_handler(opts) + else + parser:error('command %s is not implemented', command) + end +end + +return { + name = 'keypair', + aliases = { 'kp', 'key' }, + handler = handler, + description = parser._description +}
\ No newline at end of file diff --git a/lualib/rspamadm/mime.lua b/lualib/rspamadm/mime.lua new file mode 100644 index 0000000..6a589d6 --- /dev/null +++ b/lualib/rspamadm/mime.lua @@ -0,0 +1,1012 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" +local ansicolors = require "ansicolors" +local rspamd_util = require "rspamd_util" +local rspamd_task = require "rspamd_task" +local rspamd_text = require "rspamd_text" +local rspamd_logger = require "rspamd_logger" +local lua_meta = require "lua_meta" +local rspamd_url = require "rspamd_url" +local lua_util = require "lua_util" +local lua_mime = require "lua_mime" +local ucl = require "ucl" + +-- Define command line options +local parser = argparse() + :name "rspamadm mime" + :description "Mime manipulations provided by Rspamd" + :help_description_margin(30) + :command_target("command") + :require_command(true) + +parser:option "-c --config" + :description "Path to config file" + :argname("<cfg>") + :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf") +parser:mutex( + parser:flag "-j --json" + :description "JSON output", + parser:flag "-U --ucl" + :description "UCL output", + parser:flag "-M --messagepack" + :description "MessagePack output" +) +parser:flag "-C --compact" + :description "Use compact format" +parser:flag "--no-file" + :description "Do not print filename" + +-- Extract subcommand +local extract = parser:command "extract ex e" + :description "Extracts data from MIME messages" +extract:argument "file" + :description "File to process" + :argname "<file>" + :args "+" + +extract:flag "-t --text" + :description "Extracts plain text data from a message" +extract:flag "-H --html" + :description "Extracts htm data from a message" +extract:option "-o --output" + :description "Output format ('raw', 'content', 'oneline', 'decoded', 'decoded_utf')" + :argname("<type>") + :convert { + raw = "raw", + content = "content", + oneline = "content_oneline", + decoded = "raw_parsed", + decoded_utf = "raw_utf" +} + :default "content" +extract:flag "-w --words" + :description "Extracts words" +extract:flag "-p --part" + :description "Show part info" +extract:flag "-s --structure" + :description "Show structure info (e.g. HTML tags)" +extract:flag "-i --invisible" + :description "Show invisible content for HTML parts" +extract:option "-F --words-format" + :description "Words format ('stem', 'norm', 'raw', 'full')" + :argname("<type>") + :convert { + stem = "stem", + norm = "norm", + raw = "raw", + full = "full", +} + :default "stem" + +local stat = parser:command "stat st s" + :description "Extracts statistical data from MIME messages" +stat:argument "file" + :description "File to process" + :argname "<file>" + :args "+" +stat:mutex( + stat:flag "-m --meta" + :description "Lua metatokens", + stat:flag "-b --bayes" + :description "Bayes tokens", + stat:flag "-F --fuzzy" + :description "Fuzzy hashes" +) +stat:flag "-s --shingles" + :description "Show shingles for fuzzy hashes" + +local urls = parser:command "urls url u" + :description "Extracts URLs from MIME messages" +urls:argument "file" + :description "File to process" + :argname "<file>" + :args "+" +urls:mutex( + urls:flag "-t --tld" + :description "Get TLDs only", + urls:flag "-H --host" + :description "Get hosts only", + urls:flag "-f --full" + :description "Show piecewise urls as processed by Rspamd" +) + +urls:flag "-u --unique" + :description "Print only unique urls" +urls:flag "-s --sort" + :description "Sort output" +urls:flag "--count" + :description "Print count of each printed element" +urls:flag "-r --reverse" + :description "Reverse sort order" + +local modify = parser:command "modify mod m" + :description "Modifies MIME message" +modify:argument "file" + :description "File to process" + :argname "<file>" + :args "+" + +modify:option "-a --add-header" + :description "Adds specific header" + :argname "<header=value>" + :count "*" +modify:option "-r --remove-header" + :description "Removes specific header (all occurrences)" + :argname "<header>" + :count "*" +modify:option "-R --rewrite-header" + :description "Rewrites specific header, uses Lua string.format pattern" + :argname "<header=pattern>" + :count "*" +modify:option "-t --text-footer" + :description "Adds footer to text/plain parts from a specific file" + :argname "<file>" +modify:option "-H --html-footer" + :description "Adds footer to text/html parts from a specific file" + :argname "<file>" + +local sign = parser:command "sign" + :description "Performs DKIM signing" +sign:argument "file" + :description "File to process" + :argname "<file>" + :args "+" + +sign:option "-d --domain" + :description "Use specific domain" + :argname "<domain>" + :count "1" +sign:option "-s --selector" + :description "Use specific selector" + :argname "<selector>" + :count "1" +sign:option "-k --key" + :description "Use specific key of file" + :argname "<key>" + :count "1" +sign:option "-t --type" + :description "ARC or DKIM signing" + :argname("<arc|dkim>") + :convert { + ['arc'] = 'arc', + ['dkim'] = 'dkim', +} + :default 'dkim' +sign:option "-o --output" + :description "Output format" + :argname("<message|signature>") + :convert { + ['message'] = 'message', + ['signature'] = 'signature', +} + :default 'message' + +local dump = parser:command "dump" + :description "Dumps a raw message in different formats" +dump:argument "file" + :description "File to process" + :argname "<file>" + :args "+" +-- Duplicate format for convenience +dump:mutex( + parser:flag "-j --json" + :description "JSON output", + parser:flag "-U --ucl" + :description "UCL output", + parser:flag "-M --messagepack" + :description "MessagePack output" +) +dump:flag "-s --split" + :description "Split the output file contents such that no content is embedded" + +dump:option "-o --outdir" + :description "Output directory" + :argname("<directory>") + +local function load_config(opts) + local _r, err = rspamd_config:load_ucl(opts['config']) + + if not _r then + rspamd_logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + rspamd_logger.errx('cannot process %s: %s', opts['config'], err) + os.exit(1) + end +end + +local function load_task(opts, fname) + if not fname then + fname = '-' + end + + local res, task = rspamd_task.load_from_file(fname, rspamd_config) + + if not res then + parser:error(string.format('cannot read message from %s: %s', fname, + task)) + end + + if not task:process_message() then + parser:error(string.format('cannot read message from %s: %s', fname, + 'failed to parse')) + end + + return task +end + +local function highlight(fmt, ...) + return ansicolors.white .. string.format(fmt, ...) .. ansicolors.reset +end + +local function maybe_print_fname(opts, fname) + if not opts.json and not opts['no-file'] then + rspamd_logger.messagex(highlight('File: %s', fname)) + end +end + +local function output_fmt(opts) + local fmt = 'json' + if opts.compact then + fmt = 'json-compact' + end + if opts.ucl then + fmt = 'ucl' + end + if opts.messagepack then + fmt = 'msgpack' + end + + return fmt +end + +-- Print elements in form +-- filename -> table of elements +local function print_elts(elts, opts, func) + local fun = require "fun" + + if opts.json or opts.ucl then + io.write(ucl.to_format(elts, output_fmt(opts))) + else + fun.each(function(fname, elt) + + if not opts.json and not opts.ucl then + if func then + elt = fun.map(func, elt) + end + maybe_print_fname(opts, fname) + fun.each(function(e) + io.write(e) + io.write("\n") + end, elt) + end + end, elts) + end +end + +local function extract_handler(opts) + local out_elts = {} + local tasks = {} + local process_func + + if opts.words then + -- Enable stemming and urls detection + load_config(opts) + rspamd_url.init(rspamd_config:get_tld_path()) + rspamd_config:init_subsystem('langdet') + end + + local function maybe_print_text_part_info(part, out) + local fun = require "fun" + if opts.part then + local t = 'plain text' + if part:is_html() then + t = 'html' + end + + if not opts.json and not opts.ucl then + table.insert(out, + rspamd_logger.slog('Part: %s: %s, language: %s, size: %s (%s raw), words: %s', + part:get_mimepart():get_digest():sub(1, 8), + t, + part:get_language(), + part:get_length(), part:get_raw_length(), + part:get_words_count())) + table.insert(out, + rspamd_logger.slog('Stats: %s', + fun.foldl(function(acc, k, v) + if acc ~= '' then + return string.format('%s, %s:%s', acc, k, v) + else + return string.format('%s:%s', k, v) + end + end, '', part:get_stats()))) + end + end + end + + local function maybe_print_mime_part_info(part, out) + if opts.part then + + if not opts.json and not opts.ucl then + local mtype, msubtype = part:get_type() + local det_mtype, det_msubtype = part:get_detected_type() + table.insert(out, + rspamd_logger.slog('Mime Part: %s: %s/%s (%s/%s detected), filename: %s (%s detected ext), size: %s', + part:get_digest():sub(1, 8), + mtype, msubtype, + det_mtype, det_msubtype, + part:get_filename(), + part:get_detected_ext(), + part:get_length())) + end + end + end + + local function print_words(words, full) + local fun = require "fun" + + if not full then + return table.concat(words, ' ') + else + return table.concat( + fun.totable( + fun.map(function(w) + -- [1] - stemmed word + -- [2] - normalised word + -- [3] - raw word + -- [4] - flags (table of strings) + return string.format('%s|%s|%s(%s)', + w[3], w[2], w[1], table.concat(w[4], ',')) + end, words) + ), + ' ' + ) + end + end + + for _, fname in ipairs(opts.file) do + local task = load_task(opts, fname) + out_elts[fname] = {} + + if not opts.text and not opts.html then + opts.text = true + opts.html = true + end + + if opts.words then + local how_words = opts['words_format'] or 'stem' + table.insert(out_elts[fname], 'meta_words: ' .. + print_words(task:get_meta_words(how_words), how_words == 'full')) + end + + if opts.text or opts.html then + local mp = task:get_parts() or {} + + for _, mime_part in ipairs(mp) do + local how = opts.output + local part + if mime_part:is_text() then + part = mime_part:get_text() + end + + if part and opts.text and not part:is_html() then + maybe_print_text_part_info(part, out_elts[fname]) + maybe_print_mime_part_info(mime_part, out_elts[fname]) + if not opts.json and not opts.ucl then + table.insert(out_elts[fname], '\n') + end + + if opts.words then + local how_words = opts['words_format'] or 'stem' + table.insert(out_elts[fname], print_words(part:get_words(how_words), + how_words == 'full')) + else + table.insert(out_elts[fname], tostring(part:get_content(how))) + end + elseif part and opts.html and part:is_html() then + maybe_print_text_part_info(part, out_elts[fname]) + maybe_print_mime_part_info(mime_part, out_elts[fname]) + if not opts.json and not opts.ucl then + table.insert(out_elts[fname], '\n') + end + + if opts.words then + local how_words = opts['words_format'] or 'stem' + table.insert(out_elts[fname], print_words(part:get_words(how_words), + how_words == 'full')) + else + if opts.structure then + local hc = part:get_html() + local res = {} + process_func = function(elt) + local fun = require "fun" + if type(elt) == 'table' then + return table.concat(fun.totable( + fun.map( + function(t) + return rspamd_logger.slog("%s", t) + end, + elt)), '\n') + else + return rspamd_logger.slog("%s", elt) + end + end + + hc:foreach_tag('any', function(tag) + local elt = {} + local ex = tag:get_extra() + elt.tag = tag:get_type() + if ex then + elt.extra = ex + end + local content = tag:get_content() + if content then + elt.content = tostring(content) + end + local style = tag:get_style() + if style then + elt.style = style + end + table.insert(res, elt) + end) + table.insert(out_elts[fname], res) + else + -- opts.structure + table.insert(out_elts[fname], tostring(part:get_content(how))) + end + if opts.invisible then + local hc = part:get_html() + table.insert(out_elts[fname], string.format('invisible content: %s', + tostring(hc:get_invisible()))) + end + end + end + + if not part then + maybe_print_mime_part_info(mime_part, out_elts[fname]) + end + end + end + + table.insert(out_elts[fname], "") + table.insert(tasks, task) + end + + print_elts(out_elts, opts, process_func) + -- To avoid use after free we postpone tasks destruction + for _, task in ipairs(tasks) do + task:destroy() + end +end + +local function stat_handler(opts) + local fun = require "fun" + local out_elts = {} + + load_config(opts) + rspamd_url.init(rspamd_config:get_tld_path()) + rspamd_config:init_subsystem('langdet,stat') -- Needed to gen stat tokens + + local process_func + + for _, fname in ipairs(opts.file) do + local task = load_task(opts, fname) + out_elts[fname] = {} + + if opts.meta then + local mt = lua_meta.gen_metatokens_table(task) + out_elts[fname] = mt + process_func = function(k, v) + return string.format("%s = %s", k, v) + end + elseif opts.bayes then + local bt = task:get_stat_tokens() + out_elts[fname] = bt + process_func = function(e) + return string.format('%s (%d): "%s"+"%s", [%s]', e.data, e.win, e.t1 or "", + e.t2 or "", table.concat(fun.totable( + fun.map(function(k) + return k + end, e.flags)), ",")) + end + elseif opts.fuzzy then + local parts = task:get_parts() or {} + out_elts[fname] = {} + process_func = function(e) + local ret = string.format('part: %s(%s): %s', e.type, e.file or "", e.digest) + if opts.shingles and e.shingles then + local sgl = {} + for _, s in ipairs(e.shingles) do + table.insert(sgl, string.format('%s: %s+%s+%s', s[1], s[2], s[3], s[4])) + end + + ret = ret .. '\n' .. table.concat(sgl, '\n') + end + return ret + end + for _, part in ipairs(parts) do + if not part:is_multipart() then + local text = part:get_text() + + if text then + local digest, shingles = text:get_fuzzy_hashes(task:get_mempool()) + table.insert(out_elts[fname], { + digest = digest, + shingles = shingles, + type = string.format('%s/%s', + ({ part:get_type() })[1], + ({ part:get_type() })[2]) + }) + else + table.insert(out_elts[fname], { + digest = part:get_digest(), + file = part:get_filename(), + type = string.format('%s/%s', + ({ part:get_type() })[1], + ({ part:get_type() })[2]) + }) + end + end + end + end + + task:destroy() -- No automatic dtor + end + + print_elts(out_elts, opts, process_func) +end + +local function urls_handler(opts) + load_config(opts) + rspamd_url.init(rspamd_config:get_tld_path()) + local out_elts = {} + + if opts.json then + rspamd_logger.messagex('[') + end + + for _, fname in ipairs(opts.file) do + out_elts[fname] = {} + local task = load_task(opts, fname) + local elts = {} + + local function process_url(u) + local s + if opts.tld then + s = u:get_tld() + elseif opts.host then + s = u:get_host() + elseif opts.full then + s = rspamd_logger.slog('%s: %s', u:get_text(), u:to_table()) + else + s = u:get_text() + end + + if opts.unique then + if elts[s] then + elts[s].count = elts[s].count + 1 + else + elts[s] = { + count = 1, + url = u:to_table() + } + end + else + if opts.json then + table.insert(elts, u) + else + table.insert(elts, s) + end + end + end + + for _, u in ipairs(task:get_urls(true)) do + process_url(u) + end + + local json_elts = {} + + local function process_elt(s, u) + if opts.unique then + -- s is string, u is {url = url, count = count } + if not opts.json then + if opts.count then + table.insert(json_elts, string.format('%s : %s', s, u.count)) + else + table.insert(json_elts, s) + end + else + local tb = u.url + tb.count = u.count + table.insert(json_elts, tb) + end + else + -- s is index, u is url or string + if opts.json then + table.insert(json_elts, u) + else + table.insert(json_elts, u) + end + end + end + + if opts.sort then + local sfunc + if opts.unique then + sfunc = function(t, a, b) + if t[a].count ~= t[b].count then + if opts.reverse then + return t[a].count > t[b].count + else + return t[a].count < t[b].count + end + else + -- Sort lexicography + if opts.reverse then + return a > b + else + return a < b + end + end + end + else + sfunc = function(t, a, b) + local va, vb + if opts.json then + va = t[a]:get_text() + vb = t[b]:get_text() + else + va = t[a] + vb = t[b] + end + if opts.reverse then + return va > vb + else + return va < vb + end + end + end + + for s, u in lua_util.spairs(elts, sfunc) do + process_elt(s, u) + end + else + for s, u in pairs(elts) do + process_elt(s, u) + end + end + + out_elts[fname] = json_elts + + task:destroy() -- No automatic dtor + end + + print_elts(out_elts, opts) +end + +local function newline(task) + local t = task:get_newlines_type() + + if t == 'cr' then + return '\r' + elseif t == 'lf' then + return '\n' + end + + return '\r\n' +end + +local function modify_handler(opts) + load_config(opts) + rspamd_url.init(rspamd_config:get_tld_path()) + + local function read_file(file) + local f = assert(io.open(file, "rb")) + local content = f:read("*all") + f:close() + return content + end + + local text_footer, html_footer + + if opts['text_footer'] then + text_footer = read_file(opts['text_footer']) + end + + if opts['html_footer'] then + html_footer = read_file(opts['html_footer']) + end + + for _, fname in ipairs(opts.file) do + local task = load_task(opts, fname) + local newline_s = newline(task) + local seen_cte + + local rewrite = lua_mime.add_text_footer(task, html_footer, text_footer) or {} + local out = {} -- Start with headers + + local function process_headers_cb(name, hdr) + for _, h in ipairs(opts['remove_header']) do + if name:match(h) then + return + end + end + + for _, h in ipairs(opts['rewrite_header']) do + local hname, hpattern = h:match('^([^=]+)=(.+)$') + if hname == name then + local new_value = string.format(hpattern, hdr.decoded) + new_value = string.format('%s:%s%s', + name, hdr.separator, + rspamd_util.fold_header(name, + rspamd_util.mime_header_encode(new_value), + task:get_newlines_type())) + out[#out + 1] = new_value + return + end + end + + if rewrite.need_rewrite_ct then + if name:lower() == 'content-type' then + local nct = string.format('%s: %s/%s; charset=utf-8', + 'Content-Type', rewrite.new_ct.type, rewrite.new_ct.subtype) + out[#out + 1] = nct + return + elseif name:lower() == 'content-transfer-encoding' then + out[#out + 1] = string.format('%s: %s', + 'Content-Transfer-Encoding', rewrite.new_cte or 'quoted-printable') + seen_cte = true + return + end + end + + out[#out + 1] = hdr.raw:gsub('\r?\n?$', '') + end + + task:headers_foreach(process_headers_cb, { full = true }) + + for _, h in ipairs(opts['add_header']) do + local hname, hvalue = h:match('^([^=]+)=(.+)$') + + if hname and hvalue then + out[#out + 1] = string.format('%s: %s', hname, + rspamd_util.fold_header(hname, hvalue, task:get_newlines_type())) + end + end + + if not seen_cte and rewrite.need_rewrite_ct then + out[#out + 1] = string.format('%s: %s', + 'Content-Transfer-Encoding', rewrite.new_cte or 'quoted-printable') + end + + -- End of headers + out[#out + 1] = '' + + if rewrite.out then + for _, o in ipairs(rewrite.out) do + out[#out + 1] = o + end + else + out[#out + 1] = { task:get_rawbody(), false } + end + + for _, o in ipairs(out) do + if type(o) == 'string' then + io.write(o) + io.write(newline_s) + elseif type(o) == 'table' then + io.flush() + if type(o[1]) == 'string' then + io.write(o[1]) + else + o[1]:save_in_file(1) + end + + if o[2] then + io.write(newline_s) + end + else + o:save_in_file(1) + io.write(newline_s) + end + end + + task:destroy() -- No automatic dtor + end +end + +local function sign_handler(opts) + load_config(opts) + rspamd_url.init(rspamd_config:get_tld_path()) + + local lua_dkim = require("lua_ffi").dkim + + if not lua_dkim then + io.stderr:write('FFI support is required: please use LuaJIT or install lua-ffi') + os.exit(1) + end + + local sign_key + if rspamd_util.file_exists(opts.key) then + sign_key = lua_dkim.load_sign_key(opts.key, 'file') + else + sign_key = lua_dkim.load_sign_key(opts.key, 'base64') + end + + if not sign_key then + io.stderr:write('Cannot load key: ' .. opts.key .. '\n') + os.exit(1) + end + + for _, fname in ipairs(opts.file) do + local task = load_task(opts, fname) + local ctx = lua_dkim.create_sign_context(task, sign_key, nil, opts.algorithm) + + if not ctx then + io.stderr:write('Cannot init signing\n') + os.exit(1) + end + + local sig = lua_dkim.do_sign(task, ctx, opts.selector, opts.domain) + + if not sig then + io.stderr:write('Cannot create signature\n') + os.exit(1) + end + + if opts.output == 'signature' then + io.write(sig) + io.write('\n') + io.flush() + else + local dkim_hdr = string.format('%s: %s%s', + 'DKIM-Signature', + rspamd_util.fold_header('DKIM-Signature', + rspamd_util.mime_header_encode(sig), + task:get_newlines_type()), + newline(task)) + io.write(dkim_hdr) + io.flush() + task:get_content():save_in_file(1) + end + + task:destroy() -- No automatic dtor + end +end + +-- Strips directories and .extensions (if present) from a filepath +local function filename_only(filepath) + local filename = filepath:match(".*%/([^%.]+)") + if not filename then + filename = filepath:match("([^%.]+)") + end + return filename +end + +assert(filename_only("very_simple") == "very_simple") +assert(filename_only("/home/very_simple.eml") == "very_simple") +assert(filename_only("very_simple.eml") == "very_simple") +assert(filename_only("very_simple.example.eml") == "very_simple") +assert(filename_only("/home/very_simple") == "very_simple") +assert(filename_only("home/very_simple") == "very_simple") +assert(filename_only("./home/very_simple") == "very_simple") +assert(filename_only("../home/very_simple.eml") == "very_simple") +assert(filename_only("/home/dir.with.dots/very_simple.eml") == "very_simple") + +--Write the dump content to file or standard out +local function write_dump_content(dump_content, fname, extension, outdir) + if type(dump_content) == "string" then + dump_content = rspamd_text.fromstring(dump_content) + end + + local wrote_filepath = nil + if outdir then + if outdir:sub(-1) ~= "/" then + outdir = outdir .. "/" + end + + local outpath = string.format("%s%s.%s", outdir, filename_only(fname), extension) + if rspamd_util.file_exists(outpath) then + os.remove(outpath) + end + if dump_content:save_in_file(outpath) then + wrote_filepath = outpath + io.write(wrote_filepath .. "\n") + else + io.stderr:write(string.format("Unable to save dump content to file: %s\n", outpath)) + end + else + dump_content:save_in_file(1) + end + return wrote_filepath +end + +-- Get the formatted ucl (split or unsplit) or the raw task content +local function get_dump_content(task, opts, fname) + if opts.ucl or opts.json or opts.messagepack then + local ucl_object = lua_mime.message_to_ucl(task) + -- Split out the content field into separate raws and update the ucl + if opts.split then + for i, part in ipairs(ucl_object.parts) do + if part.content then + local part_filename = string.format("%s-part%d", filename_only(fname), i) + local part_path = write_dump_content(part.content, part_filename, "raw", opts.outdir) + if part_path then + part.content = ucl.null + part.content_path = part_path + end + end + end + end + local extension = output_fmt(opts) + return ucl.to_format(ucl_object, extension), extension + end + return task:get_content(), "mime" +end + +local function dump_handler(opts) + load_config(opts) + rspamd_url.init(rspamd_config:get_tld_path()) + + for _, fname in ipairs(opts.file) do + local task = load_task(opts, fname) + local data, extension = get_dump_content(task, opts, fname) + write_dump_content(data, fname, extension, opts.outdir) + + task:destroy() -- No automatic dtor + end +end + +local function handler(args) + local opts = parser:parse(args) + + local command = opts.command + + if type(opts.file) == 'string' then + opts.file = { opts.file } + elseif type(opts.file) == 'none' then + opts.file = {} + end + + if command == 'extract' then + extract_handler(opts) + elseif command == 'stat' then + stat_handler(opts) + elseif command == 'urls' then + urls_handler(opts) + elseif command == 'modify' then + modify_handler(opts) + elseif command == 'sign' then + sign_handler(opts) + elseif command == 'dump' then + dump_handler(opts) + else + parser:error('command %s is not implemented', command) + end +end + +return { + name = 'mime', + aliases = { 'mime_tool' }, + handler = handler, + description = parser._description +} diff --git a/lualib/rspamadm/neural_test.lua b/lualib/rspamadm/neural_test.lua new file mode 100644 index 0000000..31d21a9 --- /dev/null +++ b/lualib/rspamadm/neural_test.lua @@ -0,0 +1,228 @@ +local rspamd_logger = require "rspamd_logger" +local argparse = require "argparse" +local lua_util = require "lua_util" +local ucl = require "ucl" + +local parser = argparse() + :name "rspamadm neural_test" + :description "Test the neural network with labelled dataset" + :help_description_margin(32) + +parser:option "-c --config" + :description "Path to config file" + :argname("<cfg>") + :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf") +parser:option "-H --hamdir" + :description("Ham directory") + :argname("<dir>") +parser:option "-S --spamdir" + :description("Spam directory") + :argname("<dir>") +parser:option "-t --timeout" + :description("Timeout for client connections") + :argname("<sec>") + :convert(tonumber) + :default(60) +parser:option "-n --conns" + :description("Number of parallel connections") + :argname("<N>") + :convert(tonumber) + :default(10) +parser:option "-c --connect" + :description("Connect to specific host") + :argname("<host>") + :default('localhost:11334') +parser:option "-r --rspamc" + :description("Use specific rspamc path") + :argname("<path>") + :default('rspamc') +parser:option '--rule' + :description 'Rule to test' + :argname('<rule>') + +local HAM = "HAM" +local SPAM = "SPAM" + +local function load_config(opts) + local _r, err = rspamd_config:load_ucl(opts['config']) + + if not _r then + rspamd_logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + rspamd_logger.errx('cannot process %s: %s', opts['config'], err) + os.exit(1) + end +end + +local function scan_email(rspamc_path, host, n_parallel, path, timeout) + + local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s", + rspamc_path, host, n_parallel, timeout, path) + local result = assert(io.popen(rspamc_command)) + result = result:read("*all") + return result +end + +local function encoded_json_to_log(result) + -- Returns table containing score, action, list of symbols + + local filtered_result = {} + local ucl_parser = ucl.parser() + + local is_good, err = ucl_parser:parse_string(result) + + if not is_good then + rspamd_logger.errx("Parser error: %1", err) + return nil + end + + result = ucl_parser:get_object() + + filtered_result.score = result.score + if not result.action then + rspamd_logger.errx("Bad JSON: %1", result) + return nil + end + local action = result.action:gsub("%s+", "_") + filtered_result.action = action + + filtered_result.symbols = {} + + for sym, _ in pairs(result.symbols) do + table.insert(filtered_result.symbols, sym) + end + + filtered_result.filename = result.filename + filtered_result.scan_time = result.scan_time + + return filtered_result +end + +local function filter_scan_results(results, actual_email_type) + + local logs = {} + + results = lua_util.rspamd_str_split(results, "\n") + + if results[#results] == "" then + results[#results] = nil + end + + for _, result in pairs(results) do + result = encoded_json_to_log(result) + if result then + result['type'] = actual_email_type + table.insert(logs, result) + end + end + + return logs +end + +local function get_stats_from_scan_results(results, rules) + + local rule_stats = {} + for rule, _ in pairs(rules) do + rule_stats[rule] = { tp = 0, tn = 0, fp = 0, fn = 0 } + end + + for _, result in ipairs(results) do + for _, symbol in ipairs(result["symbols"]) do + for name, rule in pairs(rules) do + if rule.symbol_spam and rule.symbol_spam == symbol then + if result.type == HAM then + rule_stats[name].fp = rule_stats[name].fp + 1 + elseif result.type == SPAM then + rule_stats[name].tp = rule_stats[name].tp + 1 + end + elseif rule.symbol_ham and rule.symbol_ham == symbol then + if result.type == HAM then + rule_stats[name].tn = rule_stats[name].tn + 1 + elseif result.type == SPAM then + rule_stats[name].fn = rule_stats[name].fn + 1 + end + end + end + end + end + + for rule, _ in pairs(rules) do + rule_stats[rule].fpr = rule_stats[rule].fp / (rule_stats[rule].fp + rule_stats[rule].tn) + rule_stats[rule].fnr = rule_stats[rule].fn / (rule_stats[rule].fn + rule_stats[rule].tp) + end + + return rule_stats +end + +local function print_neural_stats(neural_stats) + for rule, stats in pairs(neural_stats) do + rspamd_logger.messagex("\nStats for rule: %s", rule) + rspamd_logger.messagex("False positive rate: %s%%", stats.fpr * 100) + rspamd_logger.messagex("False negative rate: %s%%", stats.fnr * 100) + end +end + +local function handler(args) + local opts = parser:parse(args) + + local ham_directory = opts['hamdir'] + local spam_directory = opts['spamdir'] + local connections = opts["conns"] + + load_config(opts) + + local neural_opts = rspamd_config:get_all_opt('neural') + + if opts["rule"] then + local found = false + for rule_name, _ in pairs(neural_opts.rules) do + if string.lower(rule_name) == string.lower(opts["rule"]) then + found = true + else + neural_opts.rules[rule_name] = nil + end + end + + if not found then + rspamd_logger.errx("Couldn't find the rule %s", opts["rule"]) + return + end + end + + local results = {} + + if ham_directory then + rspamd_logger.messagex("Scanning ham corpus...") + local ham_results = scan_email(opts.rspamc, opts.connect, connections, ham_directory, opts.timeout) + ham_results = filter_scan_results(ham_results, HAM) + + for _, result in pairs(ham_results) do + table.insert(results, result) + end + end + + if spam_directory then + rspamd_logger.messagex("Scanning spam corpus...") + local spam_results = scan_email(opts.rspamc, opts.connect, connections, spam_directory, opts.timeout) + spam_results = filter_scan_results(spam_results, SPAM) + + for _, result in pairs(spam_results) do + table.insert(results, result) + end + end + + local neural_stats = get_stats_from_scan_results(results, neural_opts.rules) + print_neural_stats(neural_stats) + +end + +return { + name = "neuraltest", + aliases = { "neural_test" }, + handler = handler, + description = parser._description +}
\ No newline at end of file diff --git a/lualib/rspamadm/publicsuffix.lua b/lualib/rspamadm/publicsuffix.lua new file mode 100644 index 0000000..96bf069 --- /dev/null +++ b/lualib/rspamadm/publicsuffix.lua @@ -0,0 +1,82 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" +local rspamd_logger = require "rspamd_logger" + +-- Define command line options +local parser = argparse() + :name 'rspamadm publicsuffix' + :description 'Do manipulations with the publicsuffix list' + :help_description_margin(30) + :command_target('command') + :require_command(true) + +parser:option '-c --config' + :description 'Path to config file' + :argname('config_file') + :default(rspamd_paths['CONFDIR'] .. '/rspamd.conf') + +parser:command 'compile' + :description 'Compile publicsuffix list if needed' + +local function load_config(config_file) + local _r, err = rspamd_config:load_ucl(config_file) + + if not _r then + rspamd_logger.errx('cannot load %s: %s', config_file, err) + os.exit(1) + end + + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + rspamd_logger.errx('cannot process %s: %s', config_file, err) + os.exit(1) + end +end + +local function compile_handler(_) + local rspamd_url = require "rspamd_url" + local tld_file = rspamd_config:get_tld_path() + + if not tld_file then + rspamd_logger.errx('missing `url_tld` option, cannot continue') + os.exit(1) + end + + rspamd_logger.messagex('loading public suffix file from %s', tld_file) + rspamd_url.init(tld_file) + rspamd_logger.messagex('public suffix file has been loaded') +end + +local function handler(args) + local cmd_opts = parser:parse(args) + + load_config(cmd_opts.config_file) + + if cmd_opts.command == 'compile' then + compile_handler(cmd_opts) + else + rspamd_logger.errx('unknown command: %s', cmd_opts.command) + os.exit(1) + end +end + +return { + handler = handler, + description = parser._description, + name = 'publicsuffix' +}
\ No newline at end of file diff --git a/lualib/rspamadm/stat_convert.lua b/lualib/rspamadm/stat_convert.lua new file mode 100644 index 0000000..62a19a2 --- /dev/null +++ b/lualib/rspamadm/stat_convert.lua @@ -0,0 +1,38 @@ +local lua_redis = require "lua_redis" +local stat_tools = require "lua_stat" +local ucl = require "ucl" +local logger = require "rspamd_logger" +local lua_util = require "lua_util" + +return function(_, res) + local redis_params = lua_redis.try_load_redis_servers(res.redis, nil) + if res.expire then + res.expire = lua_util.parse_time_interval(res.expire) + end + if not redis_params then + logger.errx('cannot load redis server definition') + + return false + end + + local sqlite_params = stat_tools.load_sqlite_config(res) + + if #sqlite_params == 0 then + logger.errx('cannot load sqlite classifiers') + return false + end + + for _, cls in ipairs(sqlite_params) do + if not stat_tools.convert_sqlite_to_redis(redis_params, cls.db_spam, + cls.db_ham, cls.symbol_spam, cls.symbol_ham, cls.learn_cache, res.expire, + res.reset_previous) then + logger.errx('conversion failed') + + return false + end + logger.messagex('Converted classifier to the from sqlite to redis') + logger.messagex('Suggested configuration:') + logger.messagex(ucl.to_format(stat_tools.redis_classifier_from_sqlite(cls, res.expire), + 'config')) + end +end diff --git a/lualib/rspamadm/statistics_dump.lua b/lualib/rspamadm/statistics_dump.lua new file mode 100644 index 0000000..6bc0458 --- /dev/null +++ b/lualib/rspamadm/statistics_dump.lua @@ -0,0 +1,544 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local lua_redis = require "lua_redis" +local rspamd_logger = require "rspamd_logger" +local argparse = require "argparse" +local rspamd_zstd = require "rspamd_zstd" +local rspamd_text = require "rspamd_text" +local rspamd_util = require "rspamd_util" +local rspamd_cdb = require "rspamd_cdb" +local lua_util = require "lua_util" +local rspamd_i64 = require "rspamd_int64" +local ucl = require "ucl" + +local N = "statistics_dump" +local E = {} +local classifiers = {} + +-- Define command line options +local parser = argparse() + :name "rspamadm statistics_dump" + :description "Dump/restore Rspamd statistics" + :help_description_margin(30) + :command_target("command") + :require_command(false) + +parser:option "-c --config" + :description "Path to config file" + :argname("<cfg>") + :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf") + +-- Extract subcommand +local dump = parser:command "dump d" + :description "Dump bayes statistics" +dump:mutex( + dump:flag "-j --json" + :description "Json output", + dump:flag "-C --cdb" + :description "CDB output" +) +dump:flag "-c --compress" + :description "Compress output" +dump:option "-b --batch-size" + :description "Number of entires to process at once" + :argname("<elts>") + :convert(tonumber) + :default(1000) + + +-- Restore +local restore = parser:command "restore r" + :description "Restore bayes statistics" +restore:argument "file" + :description "Input file to process" + :argname "<file>" + :args "*" +restore:option "-b --batch-size" + :description "Number of entires to process at once" + :argname("<elts>") + :convert(tonumber) + :default(1000) +restore:option "-m --mode" + :description "Number of entires to process at once" + :argname("<append|subtract|replace>") + :convert { + ['append'] = 'append', + ['subtract'] = 'subtract', + ['replace'] = 'replace', +} + :default 'append' +restore:flag "-n --no-operation" + :description "Only show redis commands to be issued" + +local function load_config(opts) + local _r, err = rspamd_config:load_ucl(opts['config']) + + if not _r then + rspamd_logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + rspamd_logger.errx('cannot process %s: %s', opts['config'], err) + os.exit(1) + end +end + +local function check_redis_classifier(cls, cfg) + -- Skip old classifiers + if cls.new_schema then + local symbol_spam, symbol_ham + -- Load symbols from statfiles + + local function check_statfile_table(tbl, def_sym) + local symbol = tbl.symbol or def_sym + + local spam + if tbl.spam then + spam = tbl.spam + else + if string.match(symbol:upper(), 'SPAM') then + spam = true + else + spam = false + end + end + + if spam then + symbol_spam = symbol + else + symbol_ham = symbol + end + end + + local statfiles = cls.statfile + if statfiles[1] then + for _, stf in ipairs(statfiles) do + if not stf.symbol then + for k, v in pairs(stf) do + check_statfile_table(v, k) + end + else + check_statfile_table(stf, 'undefined') + end + end + else + for stn, stf in pairs(statfiles) do + check_statfile_table(stf, stn) + end + end + + local redis_params + redis_params = lua_redis.try_load_redis_servers(cls, + rspamd_config, false, 'bayes') + if not redis_params then + redis_params = lua_redis.try_load_redis_servers(cfg[N] or E, + rspamd_config, false, 'bayes') + if not redis_params then + redis_params = lua_redis.try_load_redis_servers(cfg[N] or E, + rspamd_config, true) + if not redis_params then + return false + end + end + end + + table.insert(classifiers, { + symbol_spam = symbol_spam, + symbol_ham = symbol_ham, + redis_params = redis_params, + }) + end +end + +local function redis_map_zip(ar) + local data = {} + for j = 1, #ar, 2 do + data[ar[j]] = ar[j + 1] + end + + return data +end + +-- Used to clear tables +local clear_fcn = table.clear or function(tbl) + local keys = lua_util.keys(tbl) + for _, k in ipairs(keys) do + tbl[k] = nil + end +end + +local compress_ctx + +local function dump_out(out, opts, last) + if opts.compress and not compress_ctx then + compress_ctx = rspamd_zstd.compress_ctx() + end + + if compress_ctx then + if last then + compress_ctx:stream(rspamd_text.fromtable(out), 'end'):write() + else + compress_ctx:stream(rspamd_text.fromtable(out), 'flush'):write() + end + else + for _, o in ipairs(out) do + io.write(o) + end + end +end + +local function dump_cdb(out, opts, last, pattern) + local results = out[pattern] + + if not out.cdb_builder then + -- First invocation + out.cdb_builder = rspamd_cdb.build(string.format('%s.cdb', pattern)) + out.cdb_builder:add('_lrnspam', rspamd_i64.fromstring(results.learns_spam or '0')) + out.cdb_builder:add('_lrnham_', rspamd_i64.fromstring(results.learns_ham or '0')) + end + + for _, o in ipairs(results.elts) do + out.cdb_builder:add(o.key, o.value) + end + + if last then + out.cdb_builder:finalize() + out.cdb_builder = nil + end +end + +local function dump_pattern(conn, pattern, opts, out, key) + local cursor = 0 + + repeat + conn:add_cmd('SCAN', { tostring(cursor), + 'MATCH', pattern, + 'COUNT', tostring(opts.batch_size) }) + local ret, results = conn:exec() + + if not ret then + rspamd_logger.errx("cannot connect execute scan command: %s", results) + os.exit(1) + end + + cursor = tonumber(results[1]) + + local elts = results[2] + local tokens = {} + + for _, e in ipairs(elts) do + conn:add_cmd('HGETALL', { e }) + end + -- This function returns many results, each for each command + -- So if we have batch 1000, then we would have 1000 tables in form + -- [result, {hash_content}] + local all_results = { conn:exec() } + + for i = 1, #all_results, 2 do + local r, hash_content = all_results[i], all_results[i + 1] + + if r then + -- List to a hash map + local data = redis_map_zip(hash_content) + tokens[#tokens + 1] = { key = elts[(i + 1) / 2], data = data } + end + end + + -- Output keeping track of the commas + for i, d in ipairs(tokens) do + if cursor == 0 and i == #tokens or not opts.json then + if opts.cdb then + table.insert(out[key].elts, { + key = rspamd_i64.fromstring(string.match(d.key, '%d+')), + value = rspamd_util.pack('ff', tonumber(d.data["S"] or '0') or 0, + tonumber(d.data["H"] or '0')) + }) + else + out[#out + 1] = rspamd_logger.slog('"%s": %s\n', d.key, + ucl.to_format(d.data, "json-compact")) + end + else + out[#out + 1] = rspamd_logger.slog('"%s": %s,\n', d.key, + ucl.to_format(d.data, "json-compact")) + end + + end + + if opts.json and cursor == 0 then + out[#out + 1] = '}}\n' + end + + -- Do not write the last chunk of out as it will be processed afterwards + if cursor ~= 0 then + if opts.cdb then + dump_out(out, opts, false) + clear_fcn(out) + else + dump_cdb(out, opts, false, key) + out[key].elts = {} + end + elseif opts.cdb then + dump_cdb(out, opts, true, key) + end + + until cursor == 0 +end + +local function dump_handler(opts) + local patterns_seen = {} + for _, cls in ipairs(classifiers) do + local res, conn = lua_redis.redis_connect_sync(cls.redis_params, false) + + if not res then + rspamd_logger.errx("cannot connect to redis server: %s", cls.redis_params) + os.exit(1) + end + + local out = {} + local function check_keys(sym) + local sym_keys_pattern = string.format("%s_keys", sym) + conn:add_cmd('SMEMBERS', { sym_keys_pattern }) + local ret, keys = conn:exec() + + if not ret then + rspamd_logger.errx("cannot execute command to get keys: %s", keys) + os.exit(1) + end + + if not opts.json then + out[#out + 1] = string.format('"%s": %s\n', sym_keys_pattern, + ucl.to_format(keys, 'json-compact')) + end + for _, k in ipairs(keys) do + local pat = string.format('%s*_*', k) + if not patterns_seen[pat] then + conn:add_cmd('HGETALL', { k }) + local _ret, additional_keys = conn:exec() + + if _ret then + if opts.json then + out[#out + 1] = string.format('{"pattern": "%s", "meta": %s, "elts": {\n', + k, ucl.to_format(redis_map_zip(additional_keys), 'json-compact')) + elseif opts.cdb then + out[k] = redis_map_zip(additional_keys) + out[k].elts = {} + else + out[#out + 1] = string.format('"%s": %s\n', k, + ucl.to_format(redis_map_zip(additional_keys), 'json-compact')) + end + dump_pattern(conn, pat, opts, out, k) + patterns_seen[pat] = true + end + end + end + end + + check_keys(cls.symbol_spam) + check_keys(cls.symbol_ham) + + if #out > 0 then + dump_out(out, opts, true) + end + end +end + +local function obj_to_redis_arguments(obj, opts, cmd_pipe) + local key, value = next(obj) + + if type(key) == 'string' then + if type(value) == 'table' then + if not value[1] then + if opts.mode == 'replace' then + local cmd = 'HMSET' + local params = { key } + for k, v in pairs(value) do + table.insert(params, k) + table.insert(params, v) + end + table.insert(cmd_pipe, { cmd, params }) + else + local cmd = 'HINCRBYFLOAT' + local mult = 1.0 + if opts.mode == 'subtract' then + mult = (-mult) + end + + for k, v in pairs(value) do + if tonumber(v) then + v = tonumber(v) + table.insert(cmd_pipe, { cmd, { key, k, tostring(v * mult) } }) + else + table.insert(cmd_pipe, { 'HSET', { key, k, v } }) + end + end + end + else + -- Numeric table of elements (e.g. _keys) - it is actually a set in Redis + for _, elt in ipairs(value) do + table.insert(cmd_pipe, { 'SADD', { key, elt } }) + end + end + end + end + + return cmd_pipe +end + +local function execute_batch(batch, conns, opts) + local cmd_pipe = {} + + for _, cmd in ipairs(batch) do + obj_to_redis_arguments(cmd, opts, cmd_pipe) + end + + if opts.no_operation then + for _, cmd in ipairs(cmd_pipe) do + rspamd_logger.messagex('%s %s', cmd[1], table.concat(cmd[2], ' ')) + end + else + for _, conn in ipairs(conns) do + for _, cmd in ipairs(cmd_pipe) do + local is_ok, err = conn:add_cmd(cmd[1], cmd[2]) + + if not is_ok then + rspamd_logger.errx("cannot add command: %s with args: %s: %s", cmd[1], cmd[2], err) + end + end + + conn:exec() + end + end +end + +local function restore_handler(opts) + local files = opts.file or { '-' } + local conns = {} + + for _, cls in ipairs(classifiers) do + local res, conn = lua_redis.redis_connect_sync(cls.redis_params, true) + + if not res then + rspamd_logger.errx("cannot connect to redis server: %s", cls.redis_params) + os.exit(1) + end + + table.insert(conns, conn) + end + + local batch = {} + + for _, f in ipairs(files) do + local fd + if f ~= '-' then + fd = io.open(f, 'r') + io.input(fd) + end + + local cur_line = 1 + for line in io.lines() do + local ucl_parser = ucl.parser() + local res, err + res, err = ucl_parser:parse_string(line) + + if not res then + rspamd_logger.errx("%s: cannot read line %s: %s", f, cur_line, err) + os.exit(1) + end + + table.insert(batch, ucl_parser:get_object()) + cur_line = cur_line + 1 + + if cur_line % opts.batch_size == 0 then + execute_batch(batch, conns, opts) + batch = {} + end + end + + if fd then + fd:close() + end + end + + if #batch > 0 then + execute_batch(batch, conns, opts) + end +end + +local function handler(args) + local opts = parser:parse(args) + + local command = opts.command or 'dump' + + load_config(opts) + rspamd_config:init_subsystem('stat') + + local obj = rspamd_config:get_ucl() + + local classifier = obj.classifier + + if classifier then + if classifier[1] then + for _, cls in ipairs(classifier) do + if cls.bayes then + cls = cls.bayes + end + if cls.backend and cls.backend == 'redis' then + check_redis_classifier(cls, obj) + end + end + else + if classifier.bayes then + + classifier = classifier.bayes + if classifier[1] then + for _, cls in ipairs(classifier) do + if cls.backend and cls.backend == 'redis' then + check_redis_classifier(cls, obj) + end + end + else + if classifier.backend and classifier.backend == 'redis' then + check_redis_classifier(classifier, obj) + end + end + end + end + end + + if type(opts.file) == 'string' then + opts.file = { opts.file } + elseif type(opts.file) == 'none' then + opts.file = {} + end + + if command == 'dump' then + dump_handler(opts) + elseif command == 'restore' then + restore_handler(opts) + else + parser:error('command %s is not implemented', command) + end +end + +return { + name = 'statistics_dump', + aliases = { 'stat_dump', 'bayes_dump' }, + handler = handler, + description = parser._description +}
\ No newline at end of file diff --git a/lualib/rspamadm/template.lua b/lualib/rspamadm/template.lua new file mode 100644 index 0000000..ca1779a --- /dev/null +++ b/lualib/rspamadm/template.lua @@ -0,0 +1,131 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" + + +-- Define command line options +local parser = argparse() + :name "rspamadm template" + :description "Apply jinja templates for strings/files" + :help_description_margin(30) +parser:argument "file" + :description "File to process" + :argname "<file>" + :args "*" + +parser:flag "-n --no-vars" + :description "Don't add Rspamd internal variables" +parser:option "-e --env" + :description "Load additional environment vars from specific file (name=value)" + :argname "<filename>" + :count "*" +parser:option "-l --lua-env" + :description "Load additional environment vars from specific file (lua source)" + :argname "<filename>" + :count "*" +parser:mutex( + parser:option "-s --suffix" + :description "Store files with the new suffix" + :argname "<suffix>", + parser:flag "-i --inplace" + :description "Replace input file(s)" +) + +local lua_util = require "lua_util" + +local function set_env(opts, env) + if opts.env then + for _, fname in ipairs(opts.env) do + for kv in assert(io.open(fname)):lines() do + if not kv:match('%s*#.*') then + local k, v = kv:match('([^=%s]+)%s*=%s*(.+)') + + if k and v then + env[k] = v + else + io.write(string.format('invalid env line in %s: %s\n', fname, kv)) + end + end + end + end + end + + if opts.lua_env then + for _, fname in ipairs(opts.env) do + local ret, res_or_err = pcall(loadfile(fname)) + + if not ret then + io.write(string.format('cannot load %s: %s\n', fname, res_or_err)) + else + if type(res_or_err) == 'table' then + for k, v in pairs(res_or_err) do + env[k] = lua_util.deepcopy(v) + end + else + io.write(string.format('cannot load %s: not a table\n', fname)) + end + end + end + end +end + +local function read_file(file) + local content + if file == '-' then + content = io.read("*all") + else + local f = assert(io.open(file, "rb")) + content = f:read("*all") + f:close() + end + return content +end + +local function handler(args) + local opts = parser:parse(args) + local env = {} + set_env(opts, env) + + if not opts.file or #opts.file == 0 then + opts.file = { '-' } + end + for _, fname in ipairs(opts.file) do + local content = read_file(fname) + local res = lua_util.jinja_template(content, env, opts.no_vars) + + if opts.inplace then + local nfile = string.format('%s.new', fname) + local out = assert(io.open(nfile, 'w')) + out:write(content) + out:close() + os.rename(nfile, fname) + elseif opts.suffix then + local nfile = string.format('%s.%s', opts.suffix) + local out = assert(io.open(nfile, 'w')) + out:write(content) + out:close() + else + io.write(res) + end + end +end + +return { + handler = handler, + description = parser._description, + name = 'template' +}
\ No newline at end of file diff --git a/lualib/rspamadm/vault.lua b/lualib/rspamadm/vault.lua new file mode 100644 index 0000000..840e504 --- /dev/null +++ b/lualib/rspamadm/vault.lua @@ -0,0 +1,579 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + + +local rspamd_logger = require "rspamd_logger" +local ansicolors = require "ansicolors" +local ucl = require "ucl" +local argparse = require "argparse" +local fun = require "fun" +local rspamd_http = require "rspamd_http" +local cr = require "rspamd_cryptobox" + +local parser = argparse() + :name "rspamadm vault" + :description "Perform Hashicorp Vault management" + :help_description_margin(32) + :command_target("command") + :require_command(true) + +parser:flag "-s --silent" + :description "Do not output extra information" +parser:option "-a --addr" + :description "Vault address (if not defined in VAULT_ADDR env)" +parser:option "-t --token" + :description "Vault token (not recommended, better define VAULT_TOKEN env)" +parser:option "-p --path" + :description "Path to work with in the vault" + :default "dkim" +parser:option "-o --output" + :description "Output format ('ucl', 'json', 'json-compact', 'yaml')" + :argname("<type>") + :convert { + ucl = "ucl", + json = "json", + ['json-compact'] = "json-compact", + yaml = "yaml", +} + :default "ucl" + +parser:command "list ls l" + :description "List elements in the vault" + +local show = parser:command "show get" + :description "Extract element from the vault" +show:argument "domain" + :description "Domain to create key for" + :args "+" + +local delete = parser:command "delete del rm remove" + :description "Delete element from the vault" +delete:argument "domain" + :description "Domain to create delete key(s) for" + :args "+" + +local newkey = parser:command "newkey new create" + :description "Add new key to the vault" +newkey:argument "domain" + :description "Domain to create key for" + :args "+" +newkey:option "-s --selector" + :description "Selector to use" + :count "?" +newkey:option "-A --algorithm" + :argname("<type>") + :convert { + rsa = "rsa", + ed25519 = "ed25519", + eddsa = "ed25519", +} + :default "rsa" +newkey:option "-b --bits" + :argname("<nbits>") + :convert(tonumber) + :default "1024" +newkey:option "-x --expire" + :argname("<days>") + :convert(tonumber) +newkey:flag "-r --rewrite" + +local roll = parser:command "roll rollover" + :description "Perform keys rollover" +roll:argument "domain" + :description "Domain to roll key(s) for" + :args "+" +roll:option "-T --ttl" + :description "Validity period for old keys (days)" + :convert(tonumber) + :default "1" +roll:flag "-r --remove-expired" + :description "Remove expired keys" +roll:option "-x --expire" + :argname("<days>") + :convert(tonumber) + +local function printf(fmt, ...) + if fmt then + io.write(rspamd_logger.slog(fmt, ...)) + end + io.write('\n') +end + +local function maybe_printf(opts, fmt, ...) + if not opts.silent then + printf(fmt, ...) + end +end + +local function highlight(str, color) + return ansicolors[color or 'white'] .. str .. ansicolors.reset +end + +local function vault_url(opts, path) + if path then + return string.format('%s/v1/%s/%s', opts.addr, opts.path, path) + end + + return string.format('%s/v1/%s', opts.addr, opts.path) +end + +local function is_http_error(err, data) + return err or (math.floor(data.code / 100) ~= 2) +end + +local function parse_vault_reply(data) + local p = ucl.parser() + local res, parser_err = p:parse_string(data) + + if not res then + return nil, parser_err + else + return p:get_object(), nil + end +end + +local function maybe_print_vault_data(opts, data, func) + if data then + local res, parser_err = parse_vault_reply(data) + + if not res then + printf('vault reply for cannot be parsed: %s', parser_err) + else + if func then + printf(ucl.to_format(func(res), opts.output)) + else + printf(ucl.to_format(res, opts.output)) + end + end + else + printf('no data received') + end +end + +local function print_dkim_txt_record(b64, selector, alg) + local labels = {} + local prefix = string.format("v=DKIM1; k=%s; p=", alg) + b64 = prefix .. b64 + if #b64 < 255 then + labels = { '"' .. b64 .. '"' } + else + for sl = 1, #b64, 256 do + table.insert(labels, '"' .. b64:sub(sl, sl + 255) .. '"') + end + end + + printf("%s._domainkey IN TXT ( %s )", selector, + table.concat(labels, "\n\t")) +end + +local function show_handler(opts, domain) + local uri = vault_url(opts, domain) + local err, data = rspamd_http.request { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + resolver = rspamadm_dns_resolver, + url = uri, + headers = { + ['X-Vault-Token'] = opts.token + } + } + + if is_http_error(err, data) then + printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code) + maybe_print_vault_data(opts, err) + os.exit(1) + else + maybe_print_vault_data(opts, data.content, function(obj) + return obj.data.selectors + end) + end +end + +local function delete_handler(opts, domain) + local uri = vault_url(opts, domain) + local err, data = rspamd_http.request { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + resolver = rspamadm_dns_resolver, + url = uri, + method = 'delete', + headers = { + ['X-Vault-Token'] = opts.token + } + } + + if is_http_error(err, data) then + printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code) + maybe_print_vault_data(opts, err) + os.exit(1) + else + printf('deleted key(s) for %s', domain) + end +end + +local function list_handler(opts) + local uri = vault_url(opts) + local err, data = rspamd_http.request { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + resolver = rspamadm_dns_resolver, + url = uri .. '?list=true', + headers = { + ['X-Vault-Token'] = opts.token + } + } + + if is_http_error(err, data) then + printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code) + maybe_print_vault_data(opts, err) + os.exit(1) + else + maybe_print_vault_data(opts, data.content, function(obj) + return obj.data.keys + end) + end +end + +-- Returns pair privkey+pubkey +local function genkey(opts) + return cr.gen_dkim_keypair(opts.algorithm, opts.bits) +end + +local function create_and_push_key(opts, domain, existing) + local uri = vault_url(opts, domain) + local sk, pk = genkey(opts) + + local res = { + selectors = { + [1] = { + selector = opts.selector, + domain = domain, + key = tostring(sk), + pubkey = tostring(pk), + alg = opts.algorithm, + bits = opts.bits or 0, + valid_start = os.time(), + } + } + } + + for _, sel in ipairs(existing) do + res.selectors[#res.selectors + 1] = sel + end + + if opts.expire then + res.selectors[1].valid_end = os.time() + opts.expire * 3600 * 24 + end + + local err, data = rspamd_http.request { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + resolver = rspamadm_dns_resolver, + url = uri, + method = 'put', + headers = { + ['Content-Type'] = 'application/json', + ['X-Vault-Token'] = opts.token + }, + body = { + ucl.to_format(res, 'json-compact') + }, + } + + if is_http_error(err, data) then + printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code) + maybe_print_vault_data(opts, data.content) + os.exit(1) + else + maybe_printf(opts, 'stored key for: %s, selector: %s', domain, opts.selector) + maybe_printf(opts, 'please place the corresponding public key as following:') + + if opts.silent then + printf('%s', pk) + else + print_dkim_txt_record(tostring(pk), opts.selector, opts.algorithm) + end + end +end + +local function newkey_handler(opts, domain) + local uri = vault_url(opts, domain) + + if not opts.selector then + opts.selector = string.format('%s-%s', opts.algorithm, + os.date("!%Y%m%d")) + end + + local err, data = rspamd_http.request { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + resolver = rspamadm_dns_resolver, + url = uri, + method = 'get', + headers = { + ['X-Vault-Token'] = opts.token + } + } + + if is_http_error(err, data) or not data.content then + create_and_push_key(opts, domain, {}) + else + -- Key exists + local rep = parse_vault_reply(data.content) + + if not rep or not rep.data then + printf('cannot parse reply for %s: %s', uri, data.content) + os.exit(1) + end + + local elts = rep.data.selectors + + if not elts then + create_and_push_key(opts, domain, {}) + os.exit(0) + end + + for _, sel in ipairs(elts) do + if sel.alg == opts.algorithm then + printf('key with the specific algorithm %s is already presented at %s selector for %s domain', + opts.algorithm, sel.selector, domain) + os.exit(1) + else + create_and_push_key(opts, domain, elts) + end + end + end +end + +local function roll_handler(opts, domain) + local uri = vault_url(opts, domain) + local res = { + selectors = {} + } + + local err, data = rspamd_http.request { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + resolver = rspamadm_dns_resolver, + url = uri, + method = 'get', + headers = { + ['X-Vault-Token'] = opts.token + } + } + + if is_http_error(err, data) or not data.content then + printf("No keys to roll for domain %s", domain) + os.exit(1) + else + local rep = parse_vault_reply(data.content) + + if not rep or not rep.data then + printf('cannot parse reply for %s: %s', uri, data.content) + os.exit(1) + end + + local elts = rep.data.selectors + + if not elts then + printf("No keys to roll for domain %s", domain) + os.exit(1) + end + + local nkeys = {} -- indexed by algorithm + + local function insert_key(sel, add_expire) + if not nkeys[sel.alg] then + nkeys[sel.alg] = {} + end + + if add_expire then + sel.valid_end = os.time() + opts.ttl * 3600 * 24 + end + + table.insert(nkeys[sel.alg], sel) + end + + for _, sel in ipairs(elts) do + if sel.valid_end and sel.valid_end < os.time() then + if not opts.remove_expired then + insert_key(sel, false) + else + maybe_printf(opts, 'removed expired key for %s (selector %s, expire "%s"', + domain, sel.selector, os.date('%c', sel.valid_end)) + end + else + insert_key(sel, true) + end + end + + -- Now we need to ensure that all but one selectors have either expired or just a single key + for alg, keys in pairs(nkeys) do + table.sort(keys, function(k1, k2) + if k1.valid_end and k2.valid_end then + return k1.valid_end > k2.valid_end + elseif k1.valid_end then + return true + elseif k2.valid_end then + return false + end + return false + end) + -- Exclude the key with the highest expiration date and examine the rest + if not (#keys == 1 or fun.all(function(k) + return k.valid_end and k.valid_end < os.time() + end, fun.tail(keys))) then + printf('bad keys list for %s and %s algorithm', domain, alg) + fun.each(function(k) + if not k.valid_end then + printf('selector %s, algorithm %s has a key with no expire', + k.selector, k.alg) + elseif k.valid_end >= os.time() then + printf('selector %s, algorithm %s has a key that not yet expired: %s', + k.selector, k.alg, os.date('%c', k.valid_end)) + end + end, fun.tail(keys)) + os.exit(1) + end + -- Do not create new keys, if we only want to remove expired keys + if not opts.remove_expired then + -- OK to process + -- Insert keys for each algorithm in pairs <old_key(s)>, <new_key> + local sk, pk = genkey({ algorithm = alg, bits = keys[1].bits }) + local selector = string.format('%s-%s', alg, + os.date("!%Y%m%d")) + + if selector == keys[1].selector then + selector = selector .. '-1' + end + local nelt = { + selector = selector, + domain = domain, + key = tostring(sk), + pubkey = tostring(pk), + alg = alg, + bits = keys[1].bits, + valid_start = os.time(), + } + + if opts.expire then + nelt.valid_end = os.time() + opts.expire * 3600 * 24 + end + + table.insert(res.selectors, nelt) + end + for _, k in ipairs(keys) do + table.insert(res.selectors, k) + end + end + end + + -- We can now store res in the vault + err, data = rspamd_http.request { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + resolver = rspamadm_dns_resolver, + url = uri, + method = 'put', + headers = { + ['Content-Type'] = 'application/json', + ['X-Vault-Token'] = opts.token + }, + body = { + ucl.to_format(res, 'json-compact') + }, + } + + if is_http_error(err, data) then + printf('cannot put request to the vault (%s), HTTP error code %s', uri, data.code) + maybe_print_vault_data(opts, data.content) + os.exit(1) + else + for _, key in ipairs(res.selectors) do + if not key.valid_end or key.valid_end > os.time() + opts.ttl * 3600 * 24 then + maybe_printf(opts, 'rolled key for: %s, new selector: %s', domain, key.selector) + maybe_printf(opts, 'please place the corresponding public key as following:') + + if opts.silent then + printf('%s', key.pubkey) + else + print_dkim_txt_record(key.pubkey, key.selector, key.alg) + end + + end + end + + maybe_printf(opts, 'your old keys will be valid until %s', + os.date('%c', os.time() + opts.ttl * 3600 * 24)) + end +end + +local function handler(args) + local opts = parser:parse(args) + + if not opts.addr then + opts.addr = os.getenv('VAULT_ADDR') + end + + if not opts.token then + opts.token = os.getenv('VAULT_TOKEN') + else + maybe_printf(opts, 'defining token via command line is insecure, define it via environment variable %s', + highlight('VAULT_TOKEN', 'red')) + end + + if not opts.token or not opts.addr then + printf('no token or/and vault addr has been specified, exiting') + os.exit(1) + end + + local command = opts.command + + if command == 'list' then + list_handler(opts) + elseif command == 'show' then + fun.each(function(d) + show_handler(opts, d) + end, opts.domain) + elseif command == 'newkey' then + fun.each(function(d) + newkey_handler(opts, d) + end, opts.domain) + elseif command == 'roll' then + fun.each(function(d) + roll_handler(opts, d) + end, opts.domain) + elseif command == 'delete' then + fun.each(function(d) + delete_handler(opts, d) + end, opts.domain) + else + parser:error(string.format('command %s is not implemented', command)) + end +end + +return { + handler = handler, + description = parser._description, + name = 'vault' +} |