summaryrefslogtreecommitdiffstats
path: root/lualib/rspamadm/corpus_test.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lualib/rspamadm/corpus_test.lua')
-rw-r--r--lualib/rspamadm/corpus_test.lua185
1 files changed, 185 insertions, 0 deletions
diff --git a/lualib/rspamadm/corpus_test.lua b/lualib/rspamadm/corpus_test.lua
new file mode 100644
index 0000000..0e63f9f
--- /dev/null
+++ b/lualib/rspamadm/corpus_test.lua
@@ -0,0 +1,185 @@
+local rspamd_logger = require "rspamd_logger"
+local ucl = require "ucl"
+local lua_util = require "lua_util"
+local argparse = require "argparse"
+
+local parser = argparse()
+ :name "rspamadm corpus_test"
+ :description "Create logs files from email corpus"
+ :help_description_margin(32)
+
+parser:option "-H --ham"
+ :description("Ham directory")
+ :argname("<dir>")
+parser:option "-S --spam"
+ :description("Spam directory")
+ :argname("<dir>")
+parser:option "-n --conns"
+ :description("Number of parallel connections")
+ :argname("<N>")
+ :convert(tonumber)
+ :default(10)
+parser:option "-o --output"
+ :description("Output file")
+ :argname("<file>")
+ :default('results.log')
+parser:option "-t --timeout"
+ :description("Timeout for client connections")
+ :argname("<sec>")
+ :convert(tonumber)
+ :default(60)
+parser:option "-c --connect"
+ :description("Connect to specific host")
+ :argname("<host>")
+ :default('localhost:11334')
+parser:option "-r --rspamc"
+ :description("Use specific rspamc path")
+ :argname("<path>")
+ :default('rspamc')
+
+local HAM = "HAM"
+local SPAM = "SPAM"
+local opts
+
+local function scan_email(n_parallel, path, timeout)
+
+ local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s",
+ opts.rspamc, opts.connect, n_parallel, timeout, path)
+ local result = assert(io.popen(rspamc_command))
+ result = result:read("*all")
+ return result
+end
+
+local function write_results(results, file)
+
+ local f = io.open(file, 'w')
+
+ for _, result in pairs(results) do
+ local log_line = string.format("%s %.2f %s",
+ result.type, result.score, result.action)
+
+ for _, sym in pairs(result.symbols) do
+ log_line = log_line .. " " .. sym
+ end
+
+ log_line = log_line .. " " .. result.scan_time .. " " .. file .. ':' .. result.filename
+
+ log_line = log_line .. "\r\n"
+
+ f:write(log_line)
+ end
+
+ f:close()
+end
+
+local function encoded_json_to_log(result)
+ -- Returns table containing score, action, list of symbols
+
+ local filtered_result = {}
+ local ucl_parser = ucl.parser()
+
+ local is_good, err = ucl_parser:parse_string(result)
+
+ if not is_good then
+ rspamd_logger.errx("Parser error: %1", err)
+ return nil
+ end
+
+ result = ucl_parser:get_object()
+
+ filtered_result.score = result.score
+ if not result.action then
+ rspamd_logger.errx("Bad JSON: %1", result)
+ return nil
+ end
+ local action = result.action:gsub("%s+", "_")
+ filtered_result.action = action
+
+ filtered_result.symbols = {}
+
+ for sym, _ in pairs(result.symbols) do
+ table.insert(filtered_result.symbols, sym)
+ end
+
+ filtered_result.filename = result.filename
+ filtered_result.scan_time = result.scan_time
+
+ return filtered_result
+end
+
+local function scan_results_to_logs(results, actual_email_type)
+
+ local logs = {}
+
+ results = lua_util.rspamd_str_split(results, "\n")
+
+ if results[#results] == "" then
+ results[#results] = nil
+ end
+
+ for _, result in pairs(results) do
+ result = encoded_json_to_log(result)
+ if result then
+ result['type'] = actual_email_type
+ table.insert(logs, result)
+ end
+ end
+
+ return logs
+end
+
+local function handler(args)
+ opts = parser:parse(args)
+ local ham_directory = opts['ham']
+ local spam_directory = opts['spam']
+ local connections = opts["conns"]
+ local output = opts["output"]
+
+ local results = {}
+
+ local start_time = os.time()
+ local no_of_ham = 0
+ local no_of_spam = 0
+
+ if ham_directory then
+ rspamd_logger.messagex("Scanning ham corpus...")
+ local ham_results = scan_email(connections, ham_directory, opts["timeout"])
+ ham_results = scan_results_to_logs(ham_results, HAM)
+
+ no_of_ham = #ham_results
+
+ for _, result in pairs(ham_results) do
+ table.insert(results, result)
+ end
+ end
+
+ if spam_directory then
+ rspamd_logger.messagex("Scanning spam corpus...")
+ local spam_results = scan_email(connections, spam_directory, opts.timeout)
+ spam_results = scan_results_to_logs(spam_results, SPAM)
+
+ no_of_spam = #spam_results
+
+ for _, result in pairs(spam_results) do
+ table.insert(results, result)
+ end
+ end
+
+ rspamd_logger.messagex("Writing results to %s", output)
+ write_results(results, output)
+
+ rspamd_logger.messagex("Stats: ")
+ local elapsed_time = os.time() - start_time
+ local total_msgs = no_of_ham + no_of_spam
+ rspamd_logger.messagex("Elapsed time: %ss", elapsed_time)
+ rspamd_logger.messagex("No of ham: %s", no_of_ham)
+ rspamd_logger.messagex("No of spam: %s", no_of_spam)
+ rspamd_logger.messagex("Messages/sec: %s", (total_msgs / elapsed_time))
+end
+
+return {
+ name = 'corpustest',
+ aliases = { 'corpus_test', 'corpus' },
+ handler = handler,
+ description = parser._description
+}