summaryrefslogtreecommitdiffstats
path: root/lualib
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-10 21:30:40 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-10 21:30:40 +0000
commit133a45c109da5310add55824db21af5239951f93 (patch)
treeba6ac4c0a950a0dda56451944315d66409923918 /lualib
parentInitial commit. (diff)
downloadrspamd-133a45c109da5310add55824db21af5239951f93.tar.xz
rspamd-133a45c109da5310add55824db21af5239951f93.zip
Adding upstream version 3.8.1.upstream/3.8.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'lualib')
-rw-r--r--lualib/ansicolors.lua68
-rw-r--r--lualib/global_functions.lua56
-rw-r--r--lualib/lua_auth_results.lua301
-rw-r--r--lualib/lua_aws.lua300
-rw-r--r--lualib/lua_bayes_learn.lua151
-rw-r--r--lualib/lua_bayes_redis.lua244
-rw-r--r--lualib/lua_cfg_transform.lua634
-rw-r--r--lualib/lua_cfg_utils.lua84
-rw-r--r--lualib/lua_clickhouse.lua547
-rw-r--r--lualib/lua_content/ical.lua105
-rw-r--r--lualib/lua_content/init.lua109
-rw-r--r--lualib/lua_content/pdf.lua1424
-rw-r--r--lualib/lua_content/vcard.lua84
-rw-r--r--lualib/lua_dkim_tools.lua742
-rw-r--r--lualib/lua_ffi/common.lua45
-rw-r--r--lualib/lua_ffi/dkim.lua144
-rw-r--r--lualib/lua_ffi/init.lua59
-rw-r--r--lualib/lua_ffi/linalg.lua87
-rw-r--r--lualib/lua_ffi/spf.lua143
-rw-r--r--lualib/lua_fuzzy.lua355
-rw-r--r--lualib/lua_lexer.lua163
-rw-r--r--lualib/lua_magic/heuristics.lua605
-rw-r--r--lualib/lua_magic/init.lua388
-rw-r--r--lualib/lua_magic/patterns.lua471
-rw-r--r--lualib/lua_magic/types.lua327
-rw-r--r--lualib/lua_maps.lua612
-rw-r--r--lualib/lua_maps_expressions.lua219
-rw-r--r--lualib/lua_meta.lua549
-rw-r--r--lualib/lua_mime.lua760
-rw-r--r--lualib/lua_mime_types.lua745
-rw-r--r--lualib/lua_redis.lua1817
-rw-r--r--lualib/lua_scanners/avast.lua304
-rw-r--r--lualib/lua_scanners/clamav.lua193
-rw-r--r--lualib/lua_scanners/cloudmark.lua372
-rw-r--r--lualib/lua_scanners/common.lua539
-rw-r--r--lualib/lua_scanners/dcc.lua313
-rw-r--r--lualib/lua_scanners/fprot.lua181
-rw-r--r--lualib/lua_scanners/icap.lua713
-rw-r--r--lualib/lua_scanners/init.lua75
-rw-r--r--lualib/lua_scanners/kaspersky_av.lua197
-rw-r--r--lualib/lua_scanners/kaspersky_se.lua287
-rw-r--r--lualib/lua_scanners/oletools.lua369
-rw-r--r--lualib/lua_scanners/p0f.lua227
-rw-r--r--lualib/lua_scanners/pyzor.lua206
-rw-r--r--lualib/lua_scanners/razor.lua181
-rw-r--r--lualib/lua_scanners/savapi.lua261
-rw-r--r--lualib/lua_scanners/sophos.lua192
-rw-r--r--lualib/lua_scanners/spamassassin.lua213
-rw-r--r--lualib/lua_scanners/vadesecure.lua351
-rw-r--r--lualib/lua_scanners/virustotal.lua214
-rw-r--r--lualib/lua_selectors/common.lua95
-rw-r--r--lualib/lua_selectors/extractors.lua565
-rw-r--r--lualib/lua_selectors/init.lua668
-rw-r--r--lualib/lua_selectors/maps.lua19
-rw-r--r--lualib/lua_selectors/transforms.lua571
-rw-r--r--lualib/lua_settings.lua309
-rw-r--r--lualib/lua_smtp.lua201
-rw-r--r--lualib/lua_stat.lua869
-rw-r--r--lualib/lua_tcp_sync.lua213
-rw-r--r--lualib/lua_urls_compose.lua286
-rw-r--r--lualib/lua_util.lua1639
-rw-r--r--lualib/lua_verdict.lua208
-rw-r--r--lualib/plugins/dmarc.lua359
-rw-r--r--lualib/plugins/neural.lua892
-rw-r--r--lualib/plugins/rbl.lua232
-rw-r--r--lualib/plugins_stats.lua48
-rw-r--r--lualib/redis_scripts/bayes_cache_check.lua20
-rw-r--r--lualib/redis_scripts/bayes_cache_learn.lua61
-rw-r--r--lualib/redis_scripts/bayes_classify.lua37
-rw-r--r--lualib/redis_scripts/bayes_learn.lua44
-rw-r--r--lualib/redis_scripts/bayes_stat.lua19
-rw-r--r--lualib/redis_scripts/neural_maybe_invalidate.lua25
-rw-r--r--lualib/redis_scripts/neural_maybe_lock.lua19
-rw-r--r--lualib/redis_scripts/neural_save_unlock.lua24
-rw-r--r--lualib/redis_scripts/neural_train_size.lua24
-rw-r--r--lualib/redis_scripts/ratelimit_check.lua85
-rw-r--r--lualib/redis_scripts/ratelimit_cleanup_pending.lua33
-rw-r--r--lualib/redis_scripts/ratelimit_update.lua93
-rw-r--r--lualib/rspamadm/clickhouse.lua528
-rw-r--r--lualib/rspamadm/configgraph.lua172
-rw-r--r--lualib/rspamadm/confighelp.lua123
-rw-r--r--lualib/rspamadm/configwizard.lua849
-rw-r--r--lualib/rspamadm/cookie.lua125
-rw-r--r--lualib/rspamadm/corpus_test.lua185
-rw-r--r--lualib/rspamadm/dkim_keygen.lua178
-rw-r--r--lualib/rspamadm/dmarc_report.lua737
-rw-r--r--lualib/rspamadm/dns_tool.lua232
-rw-r--r--lualib/rspamadm/fuzzy_convert.lua208
-rw-r--r--lualib/rspamadm/fuzzy_ping.lua259
-rw-r--r--lualib/rspamadm/fuzzy_stat.lua366
-rw-r--r--lualib/rspamadm/grep.lua174
-rw-r--r--lualib/rspamadm/keypair.lua508
-rw-r--r--lualib/rspamadm/mime.lua1012
-rw-r--r--lualib/rspamadm/neural_test.lua228
-rw-r--r--lualib/rspamadm/publicsuffix.lua82
-rw-r--r--lualib/rspamadm/stat_convert.lua38
-rw-r--r--lualib/rspamadm/statistics_dump.lua544
-rw-r--r--lualib/rspamadm/template.lua131
-rw-r--r--lualib/rspamadm/vault.lua579
99 files changed, 32642 insertions, 0 deletions
diff --git a/lualib/ansicolors.lua b/lualib/ansicolors.lua
new file mode 100644
index 0000000..81783f6
--- /dev/null
+++ b/lualib/ansicolors.lua
@@ -0,0 +1,68 @@
+local colormt = {}
+local ansicolors = {}
+
+local rspamd_util = require "rspamd_util"
+local isatty = rspamd_util.isatty()
+
+function colormt:__tostring()
+ return self.value
+end
+
+function colormt:__concat(other)
+ return tostring(self) .. tostring(other)
+end
+
+function colormt:__call(s)
+ return self .. s .. ansicolors.reset
+end
+
+colormt.__metatable = {}
+local function makecolor(value)
+ if isatty then
+ return setmetatable({
+ value = string.char(27) .. '[' .. tostring(value) .. 'm'
+ }, colormt)
+ else
+ return setmetatable({
+ value = ''
+ }, colormt)
+ end
+end
+
+local colors = {
+ -- attributes
+ reset = 0,
+ clear = 0,
+ bright = 1,
+ dim = 2,
+ underscore = 4,
+ blink = 5,
+ reverse = 7,
+ hidden = 8,
+
+ -- foreground
+ black = 30,
+ red = 31,
+ green = 32,
+ yellow = 33,
+ blue = 34,
+ magenta = 35,
+ cyan = 36,
+ white = 37,
+
+ -- background
+ onblack = 40,
+ onred = 41,
+ ongreen = 42,
+ onyellow = 43,
+ onblue = 44,
+ onmagenta = 45,
+ oncyan = 46,
+ onwhite = 47,
+}
+
+for c, v in pairs(colors) do
+ ansicolors[c] = makecolor(v)
+end
+
+return ansicolors \ No newline at end of file
diff --git a/lualib/global_functions.lua b/lualib/global_functions.lua
new file mode 100644
index 0000000..45d4e84
--- /dev/null
+++ b/lualib/global_functions.lua
@@ -0,0 +1,56 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local logger = require "rspamd_logger"
+local lua_util = require "lua_util"
+local lua_redis = require "lua_redis"
+local meta_functions = require "lua_meta"
+local maps = require "lua_maps"
+
+local exports = {}
+
+exports.rspamd_parse_redis_server = lua_redis.rspamd_parse_redis_server
+exports.parse_redis_server = lua_redis.rspamd_parse_redis_server
+exports.rspamd_redis_make_request = lua_redis.rspamd_redis_make_request
+exports.redis_make_request = lua_redis.rspamd_redis_make_request
+
+exports.rspamd_gen_metatokens = meta_functions.rspamd_gen_metatokens
+exports.rspamd_count_metatokens = meta_functions.rspamd_count_metatokens
+
+exports.rspamd_map_add = maps.rspamd_map_add
+
+exports.rspamd_str_split = lua_util.rspamd_str_split
+
+-- a special syntax sugar to export all functions to the global table
+setmetatable(exports, {
+ __call = function(t, override)
+ for k, v in pairs(t) do
+ if _G[k] ~= nil then
+ local msg = 'function ' .. k .. ' already exists in global scope.'
+ if override then
+ _G[k] = v
+ logger.errx('WARNING: ' .. msg .. ' Overwritten.')
+ else
+ logger.errx('NOTICE: ' .. msg .. ' Skipped.')
+ end
+ else
+ _G[k] = v
+ end
+ end
+ end,
+})
+
+return exports
diff --git a/lualib/lua_auth_results.lua b/lualib/lua_auth_results.lua
new file mode 100644
index 0000000..8c907d9
--- /dev/null
+++ b/lualib/lua_auth_results.lua
@@ -0,0 +1,301 @@
+--[[
+Copyright (c) 2016, Andrew Lewis <nerf@judo.za.org>
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local rspamd_util = require "rspamd_util"
+local lua_util = require "lua_util"
+
+local default_settings = {
+ spf_symbols = {
+ pass = 'R_SPF_ALLOW',
+ fail = 'R_SPF_FAIL',
+ softfail = 'R_SPF_SOFTFAIL',
+ neutral = 'R_SPF_NEUTRAL',
+ temperror = 'R_SPF_DNSFAIL',
+ none = 'R_SPF_NA',
+ permerror = 'R_SPF_PERMFAIL',
+ },
+ dmarc_symbols = {
+ pass = 'DMARC_POLICY_ALLOW',
+ permerror = 'DMARC_BAD_POLICY',
+ temperror = 'DMARC_DNSFAIL',
+ none = 'DMARC_NA',
+ reject = 'DMARC_POLICY_REJECT',
+ softfail = 'DMARC_POLICY_SOFTFAIL',
+ quarantine = 'DMARC_POLICY_QUARANTINE',
+ },
+ arc_symbols = {
+ pass = 'ARC_ALLOW',
+ permerror = 'ARC_INVALID',
+ temperror = 'ARC_DNSFAIL',
+ none = 'ARC_NA',
+ reject = 'ARC_REJECT',
+ },
+ dkim_symbols = {
+ none = 'R_DKIM_NA',
+ },
+ add_smtp_user = true,
+}
+
+local exports = {
+ default_settings = default_settings
+}
+
+local local_hostname = rspamd_util.get_hostname()
+
+local function gen_auth_results(task, settings)
+ local auth_results, hdr_parts = {}, {}
+
+ if not settings then
+ settings = default_settings
+ end
+
+ local auth_types = {
+ dkim = settings.dkim_symbols,
+ dmarc = settings.dmarc_symbols,
+ spf = settings.spf_symbols,
+ arc = settings.arc_symbols,
+ }
+
+ local common = {
+ symbols = {}
+ }
+
+ local mta_hostname = task:get_request_header('MTA-Name') or
+ task:get_request_header('MTA-Tag')
+ if mta_hostname then
+ mta_hostname = tostring(mta_hostname)
+ else
+ mta_hostname = local_hostname
+ end
+
+ table.insert(hdr_parts, mta_hostname)
+
+ for auth_type, symbols in pairs(auth_types) do
+ for key, sym in pairs(symbols) do
+ if not common.symbols.sym then
+ local s = task:get_symbol(sym)
+ if not s then
+ common.symbols[sym] = false
+ else
+ common.symbols[sym] = s
+ if not auth_results[auth_type] then
+ auth_results[auth_type] = { key }
+ else
+ table.insert(auth_results[auth_type], key)
+ end
+
+ if auth_type ~= 'dkim' then
+ break
+ end
+ end
+ end
+ end
+ end
+
+ local dkim_results = task:get_dkim_results()
+ -- For each signature we set authentication results
+ -- dkim=neutral (body hash did not verify) header.d=example.com header.s=sel header.b=fA8VVvJ8;
+ -- dkim=neutral (body hash did not verify) header.d=example.com header.s=sel header.b=f8pM8o90;
+
+ for _, dres in ipairs(dkim_results) do
+ local ar_string = 'none'
+
+ if dres.result == 'reject' then
+ ar_string = 'fail' -- imply failure, not neutral
+ elseif dres.result == 'allow' then
+ ar_string = 'pass'
+ elseif dres.result == 'bad record' or dres.result == 'permerror' then
+ ar_string = 'permerror'
+ elseif dres.result == 'tempfail' then
+ ar_string = 'temperror'
+ end
+ local hdr = {}
+
+ hdr[1] = string.format('dkim=%s', ar_string)
+
+ if dres.fail_reason then
+ hdr[#hdr + 1] = string.format('(%s)', lua_util.maybe_smtp_quote_value(dres.fail_reason))
+ end
+
+ if dres.domain then
+ hdr[#hdr + 1] = string.format('header.d=%s', lua_util.maybe_smtp_quote_value(dres.domain))
+ end
+
+ if dres.selector then
+ hdr[#hdr + 1] = string.format('header.s=%s', lua_util.maybe_smtp_quote_value(dres.selector))
+ end
+
+ if dres.bhash then
+ hdr[#hdr + 1] = string.format('header.b=%s', lua_util.maybe_smtp_quote_value(dres.bhash))
+ end
+
+ table.insert(hdr_parts, table.concat(hdr, ' '))
+ end
+
+ if #dkim_results == 0 then
+ -- We have no dkim results, so check for DKIM_NA symbol
+ if common.symbols[settings.dkim_symbols.none] then
+ table.insert(hdr_parts, 'dkim=none')
+ end
+ end
+
+ for auth_type, keys in pairs(auth_results) do
+ for _, key in ipairs(keys) do
+ local hdr = ''
+ if auth_type == 'dmarc' then
+ local opts = common.symbols[auth_types['dmarc'][key]][1]['options'] or {}
+ hdr = hdr .. 'dmarc='
+ if key == 'reject' or key == 'quarantine' or key == 'softfail' then
+ hdr = hdr .. 'fail'
+ else
+ hdr = hdr .. lua_util.maybe_smtp_quote_value(key)
+ end
+ if key == 'pass' then
+ hdr = hdr .. ' (policy=' .. lua_util.maybe_smtp_quote_value(opts[2]) .. ')'
+ hdr = hdr .. ' header.from=' .. lua_util.maybe_smtp_quote_value(opts[1])
+ elseif key ~= 'none' then
+ local t = { opts[1]:match('^([^%s]+) : (.*)$') }
+ if #t > 0 then
+ local dom = t[1]
+ local rsn = t[2]
+ if rsn then
+ hdr = string.format('%s reason=%s', hdr, lua_util.maybe_smtp_quote_value(rsn))
+ end
+ hdr = string.format('%s header.from=%s', hdr, lua_util.maybe_smtp_quote_value(dom))
+ end
+ if key == 'softfail' then
+ hdr = hdr .. ' (policy=none)'
+ else
+ hdr = hdr .. ' (policy=' .. lua_util.maybe_smtp_quote_value(key) .. ')'
+ end
+ end
+ table.insert(hdr_parts, hdr)
+ elseif auth_type == 'arc' then
+ if common.symbols[auth_types['arc'][key]][1] then
+ local opts = common.symbols[auth_types['arc'][key]][1]['options'] or {}
+ for _, v in ipairs(opts) do
+ hdr = string.format('%s%s=%s (%s)', hdr, auth_type,
+ lua_util.maybe_smtp_quote_value(key), lua_util.maybe_smtp_quote_value(v))
+ table.insert(hdr_parts, hdr)
+ end
+ end
+ elseif auth_type == 'spf' then
+ -- Main type
+ local sender
+ local sender_type
+ local smtp_from = task:get_from({ 'smtp', 'orig' })
+
+ if smtp_from and
+ smtp_from[1] and
+ smtp_from[1]['addr'] ~= '' and
+ smtp_from[1]['addr'] ~= nil then
+ sender = lua_util.maybe_smtp_quote_value(smtp_from[1]['addr'])
+ sender_type = 'smtp.mailfrom'
+ else
+ local helo = task:get_helo()
+ if helo then
+ sender = lua_util.maybe_smtp_quote_value(helo)
+ sender_type = 'smtp.helo'
+ end
+ end
+
+ if sender and sender_type then
+ -- Comment line
+ local comment = ''
+ if key == 'pass' then
+ comment = string.format('%s: domain of %s designates %s as permitted sender',
+ mta_hostname, sender, tostring(task:get_from_ip() or 'unknown'))
+ elseif key == 'fail' then
+ comment = string.format('%s: domain of %s does not designate %s as permitted sender',
+ mta_hostname, sender, tostring(task:get_from_ip() or 'unknown'))
+ elseif key == 'neutral' or key == 'softfail' then
+ comment = string.format('%s: %s is neither permitted nor denied by domain of %s',
+ mta_hostname, tostring(task:get_from_ip() or 'unknown'), sender)
+ elseif key == 'permerror' then
+ comment = string.format('%s: domain of %s uses mechanism not recognized by this client',
+ mta_hostname, sender)
+ elseif key == 'temperror' then
+ comment = string.format('%s: error in processing during lookup of %s: DNS error',
+ mta_hostname, sender)
+ elseif key == 'none' then
+ comment = string.format('%s: domain of %s has no SPF policy when checking %s',
+ mta_hostname, sender, tostring(task:get_from_ip() or 'unknown'))
+ end
+ hdr = string.format('%s=%s (%s) %s=%s', auth_type, key,
+ comment, sender_type, sender)
+ else
+ hdr = string.format('%s=%s', auth_type, key)
+ end
+
+ table.insert(hdr_parts, hdr)
+ end
+ end
+ end
+
+ local u = task:get_user()
+ local smtp_from = task:get_from({ 'smtp', 'orig' })
+
+ if u and smtp_from then
+ local hdr = { [1] = 'auth=pass' }
+
+ if settings['add_smtp_user'] then
+ table.insert(hdr, 'smtp.auth=' .. lua_util.maybe_smtp_quote_value(u))
+ end
+ if smtp_from[1]['addr'] then
+ table.insert(hdr, 'smtp.mailfrom=' .. lua_util.maybe_smtp_quote_value(smtp_from[1]['addr']))
+ end
+
+ table.insert(hdr_parts, table.concat(hdr, ' '))
+ end
+
+ if #hdr_parts > 0 then
+ if #hdr_parts == 1 then
+ hdr_parts[2] = 'none'
+ end
+ return table.concat(hdr_parts, '; ')
+ end
+
+ return nil
+end
+
+exports.gen_auth_results = gen_auth_results
+
+local aar_elt_grammar
+-- This function parses an ar element to a table of kv pairs that represents different
+-- elements
+local function parse_ar_element(elt)
+
+ if not aar_elt_grammar then
+ -- Generate grammar
+ local lpeg = require "lpeg"
+ local P = lpeg.P
+ local S = lpeg.S
+ local V = lpeg.V
+ local C = lpeg.C
+ local space = S(" ") ^ 0
+ local doublequoted = space * P '"' * ((1 - S '"\r\n\f\\') + (P '\\' * 1)) ^ 0 * '"' * space
+ local comment = space * P { "(" * ((1 - S "()") + V(1)) ^ 0 * ")" } * space
+ local name = C((1 - S('=(" ')) ^ 1) * space
+ local pair = lpeg.Cg(name * "=" * space * name) * space
+ aar_elt_grammar = lpeg.Cf(lpeg.Ct("") * (pair + comment + doublequoted) ^ 1, rawset)
+ end
+
+ return aar_elt_grammar:match(elt)
+end
+exports.parse_ar_element = parse_ar_element
+
+return exports
diff --git a/lualib/lua_aws.lua b/lualib/lua_aws.lua
new file mode 100644
index 0000000..e6c4b29
--- /dev/null
+++ b/lualib/lua_aws.lua
@@ -0,0 +1,300 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+
+--[[[
+-- @module lua_aws
+-- This module contains Amazon AWS utility functions
+--]]
+
+local N = "aws"
+--local rspamd_logger = require "rspamd_logger"
+local ts = (require "tableshape").types
+local lua_util = require "lua_util"
+local fun = require "fun"
+local rspamd_crypto_hash = require "rspamd_cryptobox_hash"
+
+local exports = {}
+
+-- Returns a canonical representation of today date
+local function today_canonical()
+ return os.date('!%Y%m%d')
+end
+
+--[[[
+-- @function lua_aws.aws_date([date_str])
+-- Returns an aws date header corresponding to the specific date
+--]]
+local function aws_date(date_str)
+ if not date_str then
+ date_str = today_canonical()
+ end
+
+ return date_str .. os.date('!T%H%M%SZ')
+end
+
+exports.aws_date = aws_date
+
+
+-- Local cache of the keys to save resources
+local cached_keys = {}
+
+local function maybe_get_cached_key(date_str, secret_key, region, service, req_type)
+ local bucket = cached_keys[tonumber(date_str)]
+
+ if not bucket then
+ return nil
+ end
+
+ local elt = bucket[string.format('%s.%s.%s.%s', secret_key, region, service, req_type)]
+ if elt then
+ return elt
+ end
+end
+
+local function save_cached_key(date_str, secret_key, region, service, req_type, key)
+ local numdate = tonumber(date_str)
+ -- expire old buckets
+ for k, _ in pairs(cached_keys) do
+ if k < numdate then
+ cached_keys[k] = nil
+ end
+ end
+
+ local bucket = cached_keys[tonumber(date_str)]
+ local idx = string.format('%s.%s.%s.%s', secret_key, region, service, req_type)
+
+ if not bucket then
+ cached_keys[tonumber(date_str)] = {
+ idx = key
+ }
+ else
+ bucket[idx] = key
+ end
+end
+--[[[
+-- @function lua_aws.aws_signing_key([date_str], secret_key, region, [service='s3'], [req_type='aws4_request'])
+-- Returns a signing key for the specific parameters
+--]]
+local function aws_signing_key(date_str, secret_key, region, service, req_type)
+ if not date_str then
+ date_str = today_canonical()
+ end
+
+ if not service then
+ service = 's3'
+ end
+
+ if not req_type then
+ req_type = 'aws4_request'
+ end
+
+ assert(type(secret_key) == 'string')
+ assert(type(region) == 'string')
+
+ local maybe_cached = maybe_get_cached_key(date_str, secret_key, region, service, req_type)
+
+ if maybe_cached then
+ return maybe_cached
+ end
+
+ local hmac1 = rspamd_crypto_hash.create_specific_keyed("AWS4" .. secret_key, "sha256", date_str):bin()
+ local hmac2 = rspamd_crypto_hash.create_specific_keyed(hmac1, "sha256", region):bin()
+ local hmac3 = rspamd_crypto_hash.create_specific_keyed(hmac2, "sha256", service):bin()
+ local final_key = rspamd_crypto_hash.create_specific_keyed(hmac3, "sha256", req_type):bin()
+
+ save_cached_key(date_str, secret_key, region, service, req_type, final_key)
+
+ return final_key
+end
+
+exports.aws_signing_key = aws_signing_key
+
+--[[[
+-- @function lua_aws.aws_canon_request_hash(method, path, headers_to_sign, hex_hash)
+-- Returns a hash + list of headers as required to produce signature afterwards
+--]]
+local function aws_canon_request_hash(method, uri, headers_to_sign, hex_hash)
+ assert(type(method) == 'string')
+ assert(type(uri) == 'string')
+ assert(type(headers_to_sign) == 'table')
+
+ if not hex_hash then
+ hex_hash = headers_to_sign['x-amz-content-sha256']
+ end
+
+ assert(type(hex_hash) == 'string')
+
+ local sha_ctx = rspamd_crypto_hash.create_specific('sha256')
+
+ lua_util.debugm(N, 'update signature with the method %s',
+ method)
+ sha_ctx:update(method .. '\n')
+ lua_util.debugm(N, 'update signature with the uri %s',
+ uri)
+ sha_ctx:update(uri .. '\n')
+ -- XXX add query string canonicalisation
+ sha_ctx:update('\n')
+ -- Sort auth headers and canonicalise them as requested
+ local hdr_canon = fun.tomap(fun.map(function(k, v)
+ return k:lower(), lua_util.str_trim(v)
+ end, headers_to_sign))
+ local header_names = lua_util.keys(hdr_canon)
+ table.sort(header_names)
+ for _, hn in ipairs(header_names) do
+ local v = hdr_canon[hn]
+ lua_util.debugm(N, 'update signature with the header %s, %s',
+ hn, v)
+ sha_ctx:update(string.format('%s:%s\n', hn, v))
+ end
+ local hdrs_list = table.concat(header_names, ';')
+ lua_util.debugm(N, 'headers list to sign: %s', hdrs_list)
+ sha_ctx:update(string.format('\n%s\n%s', hdrs_list, hex_hash))
+
+ return sha_ctx:hex(), hdrs_list
+end
+
+exports.aws_canon_request_hash = aws_canon_request_hash
+
+local aws_authorization_hdr_args_schema = ts.shape {
+ date = ts.string + ts['nil'] / today_canonical,
+ secret_key = ts.string,
+ method = ts.string + ts['nil'] / function()
+ return 'GET'
+ end,
+ uri = ts.string,
+ region = ts.string,
+ service = ts.string + ts['nil'] / function()
+ return 's3'
+ end,
+ req_type = ts.string + ts['nil'] / function()
+ return 'aws4_request'
+ end,
+ headers = ts.map_of(ts.string, ts.string),
+ key_id = ts.string,
+}
+--[[[
+-- @function lua_aws.aws_authorization_hdr(params)
+-- Produces an authorization header as required by AWS
+-- Parameters schema is the following:
+ts.shape{
+ date = ts.string + ts['nil'] / today_canonical,
+ secret_key = ts.string,
+ method = ts.string + ts['nil'] / function() return 'GET' end,
+ uri = ts.string,
+ region = ts.string,
+ service = ts.string + ts['nil'] / function() return 's3' end,
+ req_type = ts.string + ts['nil'] / function() return 'aws4_request' end,
+ headers = ts.map_of(ts.string, ts.string),
+ key_id = ts.string,
+}
+--
+--]]
+local function aws_authorization_hdr(tbl, transformed)
+ local res, err
+ if not transformed then
+ res, err = aws_authorization_hdr_args_schema:transform(tbl)
+ assert(res, err)
+ else
+ res = tbl
+ end
+
+ local signing_key = aws_signing_key(res.date, res.secret_key, res.region, res.service,
+ res.req_type)
+ assert(signing_key ~= nil)
+ local signed_sha, signed_hdrs = aws_canon_request_hash(res.method, res.uri,
+ res.headers)
+
+ if not signed_sha then
+ return nil
+ end
+
+ local string_to_sign = string.format('AWS4-HMAC-SHA256\n%s\n%s/%s/%s/%s\n%s',
+ res.headers['x-amz-date'] or aws_date(),
+ res.date, res.region, res.service, res.req_type,
+ signed_sha)
+ lua_util.debugm(N, "string to sign: %s", string_to_sign)
+ local hmac = rspamd_crypto_hash.create_specific_keyed(signing_key, 'sha256', string_to_sign):hex()
+ lua_util.debugm(N, "hmac: %s", hmac)
+ local auth_hdr = string.format('AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/%s,' ..
+ 'SignedHeaders=%s,Signature=%s',
+ res.key_id, res.date, res.region, res.service, res.req_type,
+ signed_hdrs, hmac)
+
+ return auth_hdr
+end
+
+exports.aws_authorization_hdr = aws_authorization_hdr
+
+
+
+--[[[
+-- @function lua_aws.aws_request_enrich(params, content)
+-- Produces an authorization header as required by AWS
+-- Parameters schema is the following:
+ts.shape{
+ date = ts.string + ts['nil'] / today_canonical,
+ secret_key = ts.string,
+ method = ts.string + ts['nil'] / function() return 'GET' end,
+ uri = ts.string,
+ region = ts.string,
+ service = ts.string + ts['nil'] / function() return 's3' end,
+ req_type = ts.string + ts['nil'] / function() return 'aws4_request' end,
+ headers = ts.map_of(ts.string, ts.string),
+ key_id = ts.string,
+}
+This method returns new/modified in place table of the headers
+--
+--]]
+local function aws_request_enrich(tbl, content)
+ local res, err = aws_authorization_hdr_args_schema:transform(tbl)
+ assert(res, err)
+ local content_sha256 = rspamd_crypto_hash.create_specific('sha256', content):hex()
+ local hdrs = res.headers
+ hdrs['x-amz-content-sha256'] = content_sha256
+ if not hdrs['x-amz-date'] then
+ hdrs['x-amz-date'] = aws_date(res.date)
+ end
+ hdrs['Authorization'] = aws_authorization_hdr(res, true)
+
+ return hdrs
+end
+
+exports.aws_request_enrich = aws_request_enrich
+
+-- A simple tests according to AWS docs to check sanity
+local test_request_hdrs = {
+ ['Host'] = 'examplebucket.s3.amazonaws.com',
+ ['x-amz-date'] = '20130524T000000Z',
+ ['Range'] = 'bytes=0-9',
+ ['x-amz-content-sha256'] = 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855',
+}
+
+assert(aws_canon_request_hash('GET', '/test.txt', test_request_hdrs) ==
+ '7344ae5b7ee6c3e7e6b0fe0640412a37625d1fbfff95c48bbb2dc43964946972')
+
+assert(aws_authorization_hdr {
+ date = '20130524',
+ region = 'us-east-1',
+ headers = test_request_hdrs,
+ uri = '/test.txt',
+ key_id = 'AKIAIOSFODNN7EXAMPLE',
+ secret_key = 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY',
+} == 'AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/s3/aws4_request,' ..
+ 'SignedHeaders=host;range;x-amz-content-sha256;x-amz-date,' ..
+ 'Signature=f0e8bdb87c964420e857bd35b5d6ed310bd44f0170aba48dd91039c6036bdb41')
+
+return exports
diff --git a/lualib/lua_bayes_learn.lua b/lualib/lua_bayes_learn.lua
new file mode 100644
index 0000000..ea97db6
--- /dev/null
+++ b/lualib/lua_bayes_learn.lua
@@ -0,0 +1,151 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+-- This file contains functions to simplify bayes classifier auto-learning
+
+local lua_util = require "lua_util"
+local lua_verdict = require "lua_verdict"
+local N = "lua_bayes"
+
+local exports = {}
+
+exports.can_learn = function(task, is_spam, is_unlearn)
+ local learn_type = task:get_request_header('Learn-Type')
+
+ if not (learn_type and tostring(learn_type) == 'bulk') then
+ local prob = task:get_mempool():get_variable('bayes_prob', 'double')
+
+ if prob then
+ local in_class = false
+ local cl
+ if is_spam then
+ cl = 'spam'
+ in_class = prob >= 0.95
+ else
+ cl = 'ham'
+ in_class = prob <= 0.05
+ end
+
+ if in_class then
+ return false, string.format(
+ 'already in class %s; probability %.2f%%',
+ cl, math.abs((prob - 0.5) * 200.0))
+ end
+ end
+ end
+
+ return true
+end
+
+exports.autolearn = function(task, conf)
+ local function log_can_autolearn(verdict, score, threshold)
+ local from = task:get_from('smtp')
+ local mime_rcpts = 'undef'
+ local mr = task:get_recipients('mime')
+ if mr then
+ for _, r in ipairs(mr) do
+ if mime_rcpts == 'undef' then
+ mime_rcpts = r.addr
+ else
+ mime_rcpts = mime_rcpts .. ',' .. r.addr
+ end
+ end
+ end
+
+ lua_util.debugm(N, task, 'id: %s, from: <%s>: can autolearn %s: score %s %s %s, mime_rcpts: <%s>',
+ task:get_header('Message-Id') or '<undef>',
+ from and from[1].addr or 'undef',
+ verdict,
+ string.format("%.2f", score),
+ verdict == 'ham' and '<=' or verdict == 'spam' and '>=' or '/',
+ threshold,
+ mime_rcpts)
+ end
+
+ -- We have autolearn config so let's figure out what is requested
+ local verdict, score = lua_verdict.get_specific_verdict("bayes", task)
+ local learn_spam, learn_ham = false, false
+
+ if verdict == 'passthrough' then
+ -- No need to autolearn
+ lua_util.debugm(N, task, 'no need to autolearn - verdict: %s',
+ verdict)
+ return
+ end
+
+ if conf.spam_threshold and conf.ham_threshold then
+ if verdict == 'spam' then
+ if conf.spam_threshold and score >= conf.spam_threshold then
+ log_can_autolearn(verdict, score, conf.spam_threshold)
+ learn_spam = true
+ end
+ elseif verdict == 'junk' then
+ if conf.junk_threshold and score >= conf.junk_threshold then
+ log_can_autolearn(verdict, score, conf.junk_threshold)
+ learn_spam = true
+ end
+ elseif verdict == 'ham' then
+ if conf.ham_threshold and score <= conf.ham_threshold then
+ log_can_autolearn(verdict, score, conf.ham_threshold)
+ learn_ham = true
+ end
+ end
+ elseif conf.learn_verdict then
+ if verdict == 'spam' or verdict == 'junk' then
+ learn_spam = true
+ elseif verdict == 'ham' then
+ learn_ham = true
+ end
+ end
+
+ if conf.check_balance then
+ -- Check balance of learns
+ local spam_learns = task:get_mempool():get_variable('spam_learns', 'int64') or 0
+ local ham_learns = task:get_mempool():get_variable('ham_learns', 'int64') or 0
+
+ local min_balance = 0.9
+ if conf.min_balance then
+ min_balance = conf.min_balance
+ end
+
+ if spam_learns > 0 or ham_learns > 0 then
+ local max_ratio = 1.0 / min_balance
+ local spam_learns_ratio = spam_learns / (ham_learns + 1)
+ if spam_learns_ratio > max_ratio and learn_spam then
+ lua_util.debugm(N, task,
+ 'skip learning spam, balance is not satisfied: %s < %s; %s spam learns; %s ham learns',
+ spam_learns_ratio, min_balance, spam_learns, ham_learns)
+ learn_spam = false
+ end
+
+ local ham_learns_ratio = ham_learns / (spam_learns + 1)
+ if ham_learns_ratio > max_ratio and learn_ham then
+ lua_util.debugm(N, task,
+ 'skip learning ham, balance is not satisfied: %s < %s; %s spam learns; %s ham learns',
+ ham_learns_ratio, min_balance, spam_learns, ham_learns)
+ learn_ham = false
+ end
+ end
+ end
+
+ if learn_spam then
+ return 'spam'
+ elseif learn_ham then
+ return 'ham'
+ end
+end
+
+return exports \ No newline at end of file
diff --git a/lualib/lua_bayes_redis.lua b/lualib/lua_bayes_redis.lua
new file mode 100644
index 0000000..7533997
--- /dev/null
+++ b/lualib/lua_bayes_redis.lua
@@ -0,0 +1,244 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]
+
+-- This file contains functions to support Bayes statistics in Redis
+
+local exports = {}
+local lua_redis = require "lua_redis"
+local logger = require "rspamd_logger"
+local lua_util = require "lua_util"
+local ucl = require "ucl"
+
+local N = "bayes"
+
+local function gen_classify_functor(redis_params, classify_script_id)
+ return function(task, expanded_key, id, is_spam, stat_tokens, callback)
+
+ local function classify_redis_cb(err, data)
+ lua_util.debugm(N, task, 'classify redis cb: %s, %s', err, data)
+ if err then
+ callback(task, false, err)
+ else
+ callback(task, true, data[1], data[2], data[3], data[4])
+ end
+ end
+
+ lua_redis.exec_redis_script(classify_script_id,
+ { task = task, is_write = false, key = expanded_key },
+ classify_redis_cb, { expanded_key, stat_tokens })
+ end
+end
+
+local function gen_learn_functor(redis_params, learn_script_id)
+ return function(task, expanded_key, id, is_spam, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens)
+ local function learn_redis_cb(err, data)
+ lua_util.debugm(N, task, 'learn redis cb: %s, %s', err, data)
+ if err then
+ callback(task, false, err)
+ else
+ callback(task, true)
+ end
+ end
+
+ if maybe_text_tokens then
+ lua_redis.exec_redis_script(learn_script_id,
+ { task = task, is_write = true, key = expanded_key },
+ learn_redis_cb,
+ { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens })
+ else
+ lua_redis.exec_redis_script(learn_script_id,
+ { task = task, is_write = true, key = expanded_key },
+ learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens })
+ end
+
+ end
+end
+
+local function load_redis_params(classifier_ucl, statfile_ucl)
+ local redis_params
+
+ -- Try load from statfile options
+ if statfile_ucl.redis then
+ redis_params = lua_redis.try_load_redis_servers(statfile_ucl.redis, rspamd_config, true)
+ end
+
+ if not redis_params then
+ if statfile_ucl then
+ redis_params = lua_redis.try_load_redis_servers(statfile_ucl, rspamd_config, true)
+ end
+ end
+
+ -- Try load from classifier config
+ if not redis_params and classifier_ucl.backend then
+ redis_params = lua_redis.try_load_redis_servers(classifier_ucl.backend, rspamd_config, true)
+ end
+
+ if not redis_params and classifier_ucl.redis then
+ redis_params = lua_redis.try_load_redis_servers(classifier_ucl.redis, rspamd_config, true)
+ end
+
+ if not redis_params then
+ redis_params = lua_redis.try_load_redis_servers(classifier_ucl, rspamd_config, true)
+ end
+
+ -- Try load global options
+ if not redis_params then
+ redis_params = lua_redis.try_load_redis_servers(rspamd_config:get_all_opt('redis'), rspamd_config, true)
+ end
+
+ if not redis_params then
+ logger.err(rspamd_config, "cannot load Redis parameters for the classifier")
+ return nil
+ end
+
+ return redis_params
+end
+
+---
+--- Init bayes classifier
+--- @param classifier_ucl ucl of the classifier config
+--- @param statfile_ucl ucl of the statfile config
+--- @return a pair of (classify_functor, learn_functor) or `nil` in case of error
+exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb)
+
+ local redis_params = load_redis_params(classifier_ucl, statfile_ucl)
+
+ if not redis_params then
+ return nil
+ end
+
+ local classify_script_id = lua_redis.load_redis_script_from_file("bayes_classify.lua", redis_params)
+ local learn_script_id = lua_redis.load_redis_script_from_file("bayes_learn.lua", redis_params)
+ local stat_script_id = lua_redis.load_redis_script_from_file("bayes_stat.lua", redis_params)
+ local max_users = classifier_ucl.max_users or 1000
+
+ local current_data = {
+ users = 0,
+ revision = 0,
+ }
+ local final_data = {
+ users = 0,
+ revision = 0, -- number of learns
+ }
+ local cursor = 0
+ rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _)
+
+ local function stat_redis_cb(err, data)
+ lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data)
+
+ if err then
+ logger.warn(cfg, 'cannot get bayes statistics for %s: %s', symbol, err)
+ else
+ local new_cursor = data[1]
+ current_data.users = current_data.users + data[2]
+ current_data.revision = current_data.revision + data[3]
+ if new_cursor == 0 then
+ -- Done iteration
+ final_data = lua_util.shallowcopy(current_data)
+ current_data = {
+ users = 0,
+ revision = 0,
+ }
+ lua_util.debugm(N, cfg, 'final data: %s', final_data)
+ stat_periodic_cb(cfg, final_data)
+ end
+
+ cursor = new_cursor
+ end
+ end
+
+ lua_redis.exec_redis_script(stat_script_id,
+ { ev_base = ev_base, cfg = cfg, is_write = false },
+ stat_redis_cb, { tostring(cursor),
+ symbol,
+ is_spam and "learns_spam" or "learns_ham",
+ tostring(max_users) })
+ return statfile_ucl.monitor_timeout or classifier_ucl.monitor_timeout or 30.0
+ end)
+
+ return gen_classify_functor(redis_params, classify_script_id), gen_learn_functor(redis_params, learn_script_id)
+end
+
+local function gen_cache_check_functor(redis_params, check_script_id, conf)
+ local packed_conf = ucl.to_format(conf, 'msgpack')
+ return function(task, cache_id, callback)
+
+ local function classify_redis_cb(err, data)
+ lua_util.debugm(N, task, 'check cache redis cb: %s, %s (%s)', err, data, type(data))
+ if err then
+ callback(task, false, err)
+ else
+ if type(data) == 'number' then
+ callback(task, true, data)
+ else
+ callback(task, false, 'not found')
+ end
+ end
+ end
+
+ lua_util.debugm(N, task, 'checking cache: %s', cache_id)
+ lua_redis.exec_redis_script(check_script_id,
+ { task = task, is_write = false, key = cache_id },
+ classify_redis_cb, { cache_id, packed_conf })
+ end
+end
+
+local function gen_cache_learn_functor(redis_params, learn_script_id, conf)
+ local packed_conf = ucl.to_format(conf, 'msgpack')
+ return function(task, cache_id, is_spam)
+ local function learn_redis_cb(err, data)
+ lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data)
+ end
+
+ lua_util.debugm(N, task, 'try to learn cache: %s', cache_id)
+ lua_redis.exec_redis_script(learn_script_id,
+ { task = task, is_write = true, key = cache_id },
+ learn_redis_cb,
+ { cache_id, is_spam and "1" or "0", packed_conf })
+
+ end
+end
+
+exports.lua_bayes_init_cache = function(classifier_ucl, statfile_ucl)
+ local redis_params = load_redis_params(classifier_ucl, statfile_ucl)
+
+ if not redis_params then
+ return nil
+ end
+
+ local default_conf = {
+ cache_prefix = "learned_ids",
+ cache_max_elt = 10000, -- Maximum number of elements in the cache key
+ cache_max_keys = 5, -- Maximum number of keys in the cache
+ cache_elt_len = 32, -- Length of the element in the cache (will trim id to that value)
+ }
+
+ local conf = lua_util.override_defaults(default_conf, classifier_ucl)
+ -- Clean all not known configurations
+ for k, _ in pairs(conf) do
+ if default_conf[k] == nil then
+ conf[k] = nil
+ end
+ end
+
+ local check_script_id = lua_redis.load_redis_script_from_file("bayes_cache_check.lua", redis_params)
+ local learn_script_id = lua_redis.load_redis_script_from_file("bayes_cache_learn.lua", redis_params)
+
+ return gen_cache_check_functor(redis_params, check_script_id, conf), gen_cache_learn_functor(redis_params,
+ learn_script_id, conf)
+end
+
+return exports
diff --git a/lualib/lua_cfg_transform.lua b/lualib/lua_cfg_transform.lua
new file mode 100644
index 0000000..d6243ad
--- /dev/null
+++ b/lualib/lua_cfg_transform.lua
@@ -0,0 +1,634 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local logger = require "rspamd_logger"
+local lua_util = require "lua_util"
+local rspamd_util = require "rspamd_util"
+local fun = require "fun"
+
+local function is_implicit(t)
+ local mt = getmetatable(t)
+
+ return mt and mt.class and mt.class == 'ucl.type.impl_array'
+end
+
+local function metric_pairs(t)
+ -- collect the keys
+ local keys = {}
+ local implicit_array = is_implicit(t)
+
+ local function gen_keys(tbl)
+ if implicit_array then
+ for _, v in ipairs(tbl) do
+ if v.name then
+ table.insert(keys, { v.name, v })
+ v.name = nil
+ else
+ -- Very tricky to distinguish:
+ -- group {name = "foo" ... } + group "blah" { ... }
+ for gr_name, gr in pairs(v) do
+ if type(gr_name) ~= 'number' then
+ -- We can also have implicit arrays here
+ local gr_implicit = is_implicit(gr)
+
+ if gr_implicit then
+ for _, gr_elt in ipairs(gr) do
+ table.insert(keys, { gr_name, gr_elt })
+ end
+ else
+ table.insert(keys, { gr_name, gr })
+ end
+ end
+ end
+ end
+ end
+ else
+ if tbl.name then
+ table.insert(keys, { tbl.name, tbl })
+ tbl.name = nil
+ else
+ for k, v in pairs(tbl) do
+ if type(k) ~= 'number' then
+ -- We can also have implicit arrays here
+ local sym_implicit = is_implicit(v)
+
+ if sym_implicit then
+ for _, elt in ipairs(v) do
+ table.insert(keys, { k, elt })
+ end
+ else
+ table.insert(keys, { k, v })
+ end
+ end
+ end
+ end
+ end
+ end
+
+ gen_keys(t)
+
+ -- return the iterator function
+ local i = 0
+ return function()
+ i = i + 1
+ if keys[i] then
+ return keys[i][1], keys[i][2]
+ end
+ end
+end
+
+local function group_transform(cfg, k, v)
+ if v.name then
+ k = v.name
+ end
+
+ local new_group = {
+ symbols = {}
+ }
+
+ if v.enabled then
+ new_group.enabled = v.enabled
+ end
+ if v.disabled then
+ new_group.disabled = v.disabled
+ end
+ if v.max_score then
+ new_group.max_score = v.max_score
+ end
+
+ if v.symbol then
+ for sk, sv in metric_pairs(v.symbol) do
+ if sv.name then
+ sk = sv.name
+ sv.name = nil -- Remove field
+ end
+
+ new_group.symbols[sk] = sv
+ end
+ end
+
+ if not cfg.group then
+ cfg.group = {}
+ end
+
+ if cfg.group[k] then
+ cfg.group[k] = lua_util.override_defaults(cfg.group[k], new_group)
+ else
+ cfg.group[k] = new_group
+ end
+
+ logger.infox("overriding group %s from the legacy metric settings", k)
+end
+
+local function symbol_transform(cfg, k, v)
+ -- first try to find any group where there is a definition of this symbol
+ for gr_n, gr in pairs(cfg.group) do
+ if gr.symbols and gr.symbols[k] then
+ -- We override group symbol with ungrouped symbol
+ logger.infox("overriding group symbol %s in the group %s", k, gr_n)
+ gr.symbols[k] = lua_util.override_defaults(gr.symbols[k], v)
+ return
+ end
+ end
+ -- Now check what Rspamd knows about this symbol
+ local sym = rspamd_config:get_symbol(k)
+
+ if not sym or not sym.group then
+ -- Otherwise we just use group 'ungrouped'
+ if not cfg.group.ungrouped then
+ cfg.group.ungrouped = {
+ symbols = {}
+ }
+ end
+
+ cfg.group.ungrouped.symbols[k] = v
+ logger.debugx("adding symbol %s to the group 'ungrouped'", k)
+ end
+end
+
+local function test_groups(groups)
+ for gr_name, gr in pairs(groups) do
+ if not gr.symbols then
+ local cnt = 0
+ for _, _ in pairs(gr) do
+ cnt = cnt + 1
+ end
+
+ if cnt == 0 then
+ logger.debugx('group %s is empty', gr_name)
+ else
+ logger.infox('group %s has no symbols', gr_name)
+ end
+ end
+ end
+end
+
+local function convert_metric(cfg, metric)
+ if metric.actions then
+ cfg.actions = lua_util.override_defaults(cfg.actions, metric.actions)
+ logger.infox("overriding actions from the legacy metric settings")
+ end
+ if metric.unknown_weight then
+ cfg.actions.unknown_weight = metric.unknown_weight
+ end
+
+ if metric.subject then
+ logger.infox("overriding subject from the legacy metric settings")
+ cfg.actions.subject = metric.subject
+ end
+
+ if metric.group then
+ for k, v in metric_pairs(metric.group) do
+ group_transform(cfg, k, v)
+ end
+ else
+ if not cfg.group then
+ cfg.group = {
+ ungrouped = {
+ symbols = {}
+ }
+ }
+ end
+ end
+
+ if metric.symbol then
+ for k, v in metric_pairs(metric.symbol) do
+ symbol_transform(cfg, k, v)
+ end
+ end
+
+ return cfg
+end
+
+-- Converts a table of groups indexed by number (implicit array) to a
+-- merged group definition
+local function merge_groups(groups)
+ local ret = {}
+ for k, gr in pairs(groups) do
+ if type(k) == 'number' then
+ for key, sec in pairs(gr) do
+ ret[key] = sec
+ end
+ else
+ ret[k] = gr
+ end
+ end
+
+ return ret
+end
+
+-- Checks configuration files for statistics
+local function check_statistics_sanity()
+ local local_conf = rspamd_paths['LOCAL_CONFDIR']
+ local local_stat = string.format('%s/local.d/%s', local_conf,
+ 'statistic.conf')
+ local local_bayes = string.format('%s/local.d/%s', local_conf,
+ 'classifier-bayes.conf')
+
+ if rspamd_util.file_exists(local_stat) and
+ rspamd_util.file_exists(local_bayes) then
+ logger.warnx(rspamd_config, 'conflicting files %s and %s are found: ' ..
+ 'Rspamd classifier configuration might be broken!', local_stat, local_bayes)
+ end
+end
+
+-- Converts surbl module config to rbl module
+local function surbl_section_convert(cfg, section)
+ local rbl_section = cfg.rbl.rbls
+ local wl = section.whitelist
+ for name, value in pairs(section.rules or {}) do
+ if rbl_section[name] then
+ logger.warnx(rspamd_config, 'conflicting names in surbl and rbl rules: %s, prefer surbl rule!',
+ name)
+ end
+ local converted = {
+ urls = true,
+ ignore_defaults = true,
+ }
+
+ if wl then
+ converted.whitelist = wl
+ end
+
+ for k, v in pairs(value) do
+ local skip = false
+ -- Rename
+ if k == 'suffix' then
+ k = 'rbl'
+ end
+ if k == 'ips' then
+ k = 'returncodes'
+ end
+ if k == 'bits' then
+ k = 'returnbits'
+ end
+ if k == 'noip' then
+ k = 'no_ip'
+ end
+ -- Crappy legacy
+ if k == 'options' then
+ if v == 'noip' or v == 'no_ip' then
+ converted.no_ip = true
+ skip = true
+ end
+ end
+ if k:match('check_') then
+ local n = k:match('check_(.*)')
+ k = n
+ end
+
+ if k == 'dkim' and v then
+ converted.dkim_domainonly = false
+ converted.dkim_match_from = true
+ end
+
+ if k == 'emails' and v then
+ -- To match surbl behaviour
+ converted.emails_domainonly = true
+ end
+
+ if not skip then
+ converted[k] = lua_util.deepcopy(v)
+ end
+ end
+ rbl_section[name] = lua_util.override_defaults(rbl_section[name], converted)
+ end
+end
+
+-- Converts surbl module config to rbl module
+local function emails_section_convert(cfg, section)
+ local rbl_section = cfg.rbl.rbls
+ local wl = section.whitelist
+ for name, value in pairs(section.rules or {}) do
+ if rbl_section[name] then
+ logger.warnx(rspamd_config, 'conflicting names in emails and rbl rules: %s, prefer emails rule!',
+ name)
+ end
+ local converted = {
+ emails = true,
+ ignore_defaults = true,
+ }
+
+ if wl then
+ converted.whitelist = wl
+ end
+
+ for k, v in pairs(value) do
+ local skip = false
+ -- Rename
+ if k == 'dnsbl' then
+ k = 'rbl'
+ end
+ if k == 'check_replyto' then
+ k = 'replyto'
+ end
+ if k == 'hashlen' then
+ k = 'hash_len'
+ end
+ if k == 'encoding' then
+ k = 'hash_format'
+ end
+ if k == 'domain_only' then
+ k = 'emails_domainonly'
+ end
+ if k == 'delimiter' then
+ k = 'emails_delimiter'
+ end
+ if k == 'skip_body' then
+ skip = true
+ if v then
+ -- Hack
+ converted.emails = false
+ converted.replyto = true
+ else
+ converted.emails = true
+ end
+ end
+ if k == 'expect_ip' then
+ -- Another stupid hack
+ if not converted.return_codes then
+ converted.returncodes = {}
+ end
+ local symbol = value.symbol or name
+ converted.returncodes[symbol] = { v }
+ skip = true
+ end
+
+ if not skip then
+ converted[k] = lua_util.deepcopy(v)
+ end
+ end
+ rbl_section[name] = lua_util.override_defaults(rbl_section[name], converted)
+ end
+end
+
+return function(cfg)
+ local ret = false
+
+ if cfg['metric'] then
+ for _, v in metric_pairs(cfg.metric) do
+ cfg = convert_metric(cfg, v)
+ end
+ ret = true
+ end
+
+ if cfg.symbols then
+ for k, v in metric_pairs(cfg.symbols) do
+ symbol_transform(cfg, k, v)
+ end
+ end
+
+ check_statistics_sanity()
+
+ if not cfg.actions then
+ logger.errx('no actions defined')
+ else
+ -- Perform sanity check for actions
+ local actions_defs = { 'no action', 'no_action', -- In case if that's added
+ 'greylist', 'add header', 'add_header',
+ 'rewrite subject', 'rewrite_subject', 'quarantine',
+ 'reject', 'discard' }
+
+ if not cfg.actions['no action'] and not cfg.actions['no_action'] and
+ not cfg.actions['accept'] then
+ for _, d in ipairs(actions_defs) do
+ if cfg.actions[d] then
+
+ local action_score = nil
+ if type(cfg.actions[d]) == 'number' then
+ action_score = cfg.actions[d]
+ elseif type(cfg.actions[d]) == 'table' and cfg.actions[d]['score'] then
+ action_score = cfg.actions[d]['score']
+ end
+
+ if type(cfg.actions[d]) ~= 'table' and not action_score then
+ cfg.actions[d] = nil
+ elseif type(action_score) == 'number' and action_score < 0 then
+ cfg.actions['no_action'] = cfg.actions[d] - 0.001
+ logger.infox(rspamd_config, 'set no_action score to: %s, as action %s has negative score',
+ cfg.actions['no_action'], d)
+ break
+ end
+ end
+ end
+ end
+
+ local actions_set = lua_util.list_to_hash(actions_defs)
+
+ -- Now check actions section for garbage
+ actions_set['unknown_weight'] = true
+ actions_set['grow_factor'] = true
+ actions_set['subject'] = true
+
+ for k, _ in pairs(cfg.actions) do
+ if not actions_set[k] then
+ logger.warnx(rspamd_config, 'unknown element in actions section: %s', k)
+ end
+ end
+
+ -- Performs thresholds sanity
+ -- We exclude greylist here as it can be set to whatever threshold in practice
+ local actions_order = {
+ 'no_action',
+ 'add_header',
+ 'rewrite_subject',
+ 'quarantine',
+ 'reject',
+ 'discard'
+ }
+ for i = 1, (#actions_order - 1) do
+ local act = actions_order[i]
+
+ if cfg.actions[act] and type(cfg.actions[act]) == 'number' then
+ local score = cfg.actions[act]
+
+ for j = i + 1, #actions_order do
+ local next_act = actions_order[j]
+ if cfg.actions[next_act] and type(cfg.actions[next_act]) == 'number' then
+ local next_score = cfg.actions[next_act]
+ if next_score <= score then
+ logger.errx(rspamd_config, 'invalid actions thresholds order: action %s (%s) must have lower ' ..
+ 'score than action %s (%s)', act, score, next_act, next_score)
+ ret = false
+ end
+ end
+ end
+ end
+ end
+ end
+
+ if not cfg.group then
+ logger.errx('no symbol groups defined')
+ else
+ if cfg.group[1] then
+ -- We need to merge groups
+ cfg.group = merge_groups(cfg.group)
+ ret = true
+ end
+ test_groups(cfg.group)
+ end
+
+ -- Deal with dkim settings
+ if not cfg.dkim then
+ cfg.dkim = {}
+ else
+ if cfg.dkim.sign_condition then
+ -- We have an obsoleted sign condition, so we need to either add dkim_signing and move it
+ -- there or just move sign condition there...
+ if not cfg.dkim_signing then
+ logger.warnx('obsoleted DKIM signing method used, converting it to "dkim_signing" module')
+ cfg.dkim_signing = {
+ sign_condition = cfg.dkim.sign_condition
+ }
+ else
+ if not cfg.dkim_signing.sign_condition then
+ logger.warnx('obsoleted DKIM signing method used, move it to "dkim_signing" module')
+ cfg.dkim_signing.sign_condition = cfg.dkim.sign_condition
+ else
+ logger.warnx('obsoleted DKIM signing method used, ignore it as "dkim_signing" also defines condition!')
+ end
+ end
+ end
+ end
+
+ -- Again: legacy stuff :(
+ if not cfg.dkim.sign_headers then
+ local sec = cfg.dkim_signing
+ if sec and sec[1] then
+ sec = cfg.dkim_signing[1]
+ end
+
+ if sec and sec.sign_headers then
+ cfg.dkim.sign_headers = sec.sign_headers
+ end
+ end
+
+ -- DKIM signing/ARC legacy
+ for _, mod in ipairs({ 'dkim_signing', 'arc' }) do
+ if cfg[mod] then
+ if cfg[mod].auth_only ~= nil then
+ if cfg[mod].sign_authenticated ~= nil then
+ logger.warnx(rspamd_config,
+ 'both auth_only (%s) and sign_authenticated (%s) for %s are specified, prefer auth_only',
+ cfg[mod].auth_only, cfg[mod].sign_authenticated, mod)
+ end
+ cfg[mod].sign_authenticated = cfg[mod].auth_only
+ end
+ end
+ end
+
+ if cfg.dkim and cfg.dkim.sign_headers and type(cfg.dkim.sign_headers) == 'table' then
+ -- Flatten
+ cfg.dkim.sign_headers = table.concat(cfg.dkim.sign_headers, ':')
+ end
+
+ -- Try to find some obvious issues with configuration
+ for k, v in pairs(cfg) do
+ if type(v) == 'table' and v[k] and type(v[k]) == 'table' then
+ logger.errx('nested section: %s { %s { ... } }, it is likely a configuration error',
+ k, k)
+ end
+ end
+
+ -- If neural network is enabled we MUST have `check_all_filters` flag
+ if cfg.neural then
+ if not cfg.options then
+ cfg.options = {}
+ end
+
+ if not cfg.options.check_all_filters then
+ logger.infox(rspamd_config, 'enable `options.check_all_filters` for neural network')
+ cfg.options.check_all_filters = true
+ end
+ end
+
+ -- Deal with IP_SCORE
+ if cfg.ip_score and (cfg.ip_score.servers or cfg.redis.servers) then
+ logger.warnx(rspamd_config, 'ip_score module is deprecated in honor of reputation module!')
+
+ if not cfg.reputation then
+ cfg.reputation = {
+ rules = {}
+ }
+ end
+
+ if not cfg.reputation.rules then
+ cfg.reputation.rules = {}
+ end
+
+ if not fun.any(function(_, v)
+ return v.selector and v.selector.ip
+ end,
+ cfg.reputation.rules) then
+ logger.infox(rspamd_config, 'attach ip reputation element to use it')
+
+ cfg.reputation.rules.ip_score = {
+ selector = {
+ ip = {},
+ },
+ backend = {
+ redis = {},
+ }
+ }
+
+ if cfg.ip_score.servers then
+ cfg.reputation.rules.ip_score.backend.redis.servers = cfg.ip_score.servers
+ end
+
+ if cfg.symbols and cfg.symbols['IP_SCORE'] then
+ local t = cfg.symbols['IP_SCORE']
+
+ if not cfg.symbols['SENDER_REP_SPAM'] then
+ cfg.symbols['SENDER_REP_SPAM'] = t
+ cfg.symbols['SENDER_REP_HAM'] = t
+ cfg.symbols['SENDER_REP_HAM'].weight = -(t.weight or 0)
+ end
+ end
+ else
+ logger.infox(rspamd_config, 'ip reputation already exists, do not do any IP_SCORE transforms')
+ end
+ end
+
+ if cfg.surbl then
+ if not cfg.rbl then
+ cfg.rbl = {
+ rbls = {}
+ }
+ end
+ if not cfg.rbl.rbls then
+ cfg.rbl.rbls = {}
+ end
+ surbl_section_convert(cfg, cfg.surbl)
+ logger.infox(rspamd_config, 'converted surbl rules to rbl rules')
+ cfg.surbl = {}
+ end
+
+ if cfg.emails then
+ if not cfg.rbl then
+ cfg.rbl = {
+ rbls = {}
+ }
+ end
+ if not cfg.rbl.rbls then
+ cfg.rbl.rbls = {}
+ end
+ emails_section_convert(cfg, cfg.emails)
+ logger.infox(rspamd_config, 'converted emails rules to rbl rules')
+ cfg.emails = {}
+ end
+
+ return ret, cfg
+end
diff --git a/lualib/lua_cfg_utils.lua b/lualib/lua_cfg_utils.lua
new file mode 100644
index 0000000..e07a3ae
--- /dev/null
+++ b/lualib/lua_cfg_utils.lua
@@ -0,0 +1,84 @@
+--[[
+Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_cfg_utils
+-- This module contains utility functions for configuration of Rspamd modules
+--]]
+
+local rspamd_logger = require "rspamd_logger"
+local exports = {}
+
+--[[[
+-- @function lua_util.disable_module(modname, how[, reason])
+-- Disables a plugin
+-- @param {string} modname name of plugin to disable
+-- @param {string} how 'redis' to disable redis, 'config' to disable startup
+-- @param {string} reason optional reason for failure
+--]]
+exports.disable_module = function(modname, how, reason)
+ if rspamd_plugins_state.enabled[modname] then
+ rspamd_plugins_state.enabled[modname] = nil
+ end
+
+ if how == 'redis' then
+ rspamd_plugins_state.disabled_redis[modname] = {}
+ elseif how == 'config' then
+ rspamd_plugins_state.disabled_unconfigured[modname] = {}
+ elseif how == 'experimental' then
+ rspamd_plugins_state.disabled_experimental[modname] = {}
+ elseif how == 'failed' then
+ rspamd_plugins_state.disabled_failed[modname] = { reason = reason }
+ else
+ rspamd_plugins_state.disabled_unknown[modname] = {}
+ end
+end
+
+--[[[
+-- @function lua_util.push_config_error(module, err)
+-- Pushes a configuration error to the state
+-- @param {string} module name of module
+-- @param {string} err error string
+--]]
+exports.push_config_error = function(module, err)
+ if not rspamd_plugins_state.config_errors then
+ rspamd_plugins_state.config_errors = {}
+ end
+
+ if not rspamd_plugins_state.config_errors[module] then
+ rspamd_plugins_state.config_errors[module] = {}
+ end
+
+ table.insert(rspamd_plugins_state.config_errors[module], err)
+end
+
+exports.check_configuration_errors = function()
+ local ret = true
+
+ if type(rspamd_plugins_state.config_errors) == 'table' then
+ -- We have some errors found during the configuration, so we need to show them
+ for m, errs in pairs(rspamd_plugins_state.config_errors) do
+ for _, err in ipairs(errs) do
+ rspamd_logger.errx(rspamd_config, 'configuration error: module %s: %s', m, err)
+ ret = false
+ end
+ end
+ end
+
+ return ret
+end
+
+return exports \ No newline at end of file
diff --git a/lualib/lua_clickhouse.lua b/lualib/lua_clickhouse.lua
new file mode 100644
index 0000000..28366d2
--- /dev/null
+++ b/lualib/lua_clickhouse.lua
@@ -0,0 +1,547 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+Copyright (c) 2018, Mikhail Galanin <mgalanin@mimecast.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_clickhouse
+-- This module contains Clickhouse access functions
+--]]
+
+local rspamd_logger = require "rspamd_logger"
+local rspamd_http = require "rspamd_http"
+local lua_util = require "lua_util"
+local rspamd_text = require "rspamd_text"
+
+local exports = {}
+local N = 'clickhouse'
+
+local default_timeout = 10.0
+
+local function escape_spaces(query)
+ return query:gsub('%s', '%%20')
+end
+
+local function ch_number(a)
+ if (a + 2 ^ 52) - 2 ^ 52 == a then
+ -- Integer
+ return tostring(math.floor(a))
+ end
+
+ return tostring(a)
+end
+
+local function clickhouse_quote(str)
+ if str then
+ return str:gsub('[\'\\\n\t\r]', {
+ ['\''] = [[\']],
+ ['\\'] = [[\\]],
+ ['\n'] = [[\n]],
+ ['\t'] = [[\t]],
+ ['\r'] = [[\r]],
+ })
+ end
+
+ return ''
+end
+
+-- Converts an array to a string suitable for clickhouse
+local function array_to_string(ar)
+ for i, elt in ipairs(ar) do
+ local t = type(elt)
+ if t == 'string' then
+ ar[i] = string.format('\'%s\'', clickhouse_quote(elt))
+ elseif t == 'userdata' then
+ ar[i] = string.format('\'%s\'', clickhouse_quote(tostring(elt)))
+ elseif t == 'number' then
+ ar[i] = ch_number(elt)
+ end
+ end
+
+ return table.concat(ar, ',')
+end
+
+-- Converts a row into TSV, taking extra care about arrays
+local function row_to_tsv(row)
+
+ for i, elt in ipairs(row) do
+ local t = type(elt)
+ if t == 'table' then
+ row[i] = '[' .. array_to_string(elt) .. ']'
+ elseif t == 'number' then
+ row[i] = ch_number(elt)
+ elseif t == 'userdata' then
+ row[i] = clickhouse_quote(tostring(elt))
+ else
+ row[i] = clickhouse_quote(elt)
+ end
+ end
+
+ return rspamd_text.fromtable(row, '\t')
+end
+
+exports.row_to_tsv = row_to_tsv
+
+-- Parses JSONEachRow reply from CH
+local function parse_clickhouse_response_json_eachrow(params, data, row_cb)
+ local ucl = require "ucl"
+
+ if data == nil then
+ -- clickhouse returned no data (i.e. empty result set): exiting
+ return {}
+ end
+
+ local function parse_string(s)
+ local parser = ucl.parser()
+ local res, err
+ if type(s) == 'string' then
+ res, err = parser:parse_string(s)
+ else
+ res, err = parser:parse_text(s)
+ end
+
+ if not res then
+ rspamd_logger.errx(params.log_obj, 'Parser error: %s', err)
+ return nil
+ end
+ return parser:get_object()
+ end
+
+ -- iterate over rows and parse
+ local parsed_rows = {}
+ for plain_row in data:lines() do
+ if plain_row and #plain_row > 1 then
+ local parsed_row = parse_string(plain_row)
+ if parsed_row then
+ if row_cb then
+ row_cb(parsed_row)
+ else
+ table.insert(parsed_rows, parsed_row)
+ end
+ end
+ end
+ end
+
+ return parsed_rows
+end
+
+-- Parses JSON reply from CH
+local function parse_clickhouse_response_json(params, data)
+ local ucl = require "ucl"
+
+ if data == nil then
+ -- clickhouse returned no data (i.e. empty result set) considered valid!
+ return nil, {}
+ end
+
+ local function parse_string(s)
+ local parser = ucl.parser()
+ local res, err
+
+ if type(s) == 'string' then
+ res, err = parser:parse_string(s)
+ else
+ res, err = parser:parse_text(s)
+ end
+
+ if not res then
+ rspamd_logger.errx(params.log_obj, 'Parser error: %s', err)
+ return nil
+ end
+ return parser:get_object()
+ end
+
+ local json = parse_string(data)
+
+ if not json then
+ return 'bad json', {}
+ end
+
+ return nil, json
+end
+
+-- Helper to generate HTTP closure
+local function mk_http_select_cb(upstream, params, ok_cb, fail_cb, row_cb)
+ local function http_cb(err_message, code, data, _)
+ if code ~= 200 or err_message then
+ if not err_message then
+ err_message = data
+ end
+ local ip_addr = upstream:get_addr():to_string(true)
+
+ if fail_cb then
+ fail_cb(params, err_message, data)
+ else
+ rspamd_logger.errx(params.log_obj,
+ "request failed on clickhouse server %s: %s",
+ ip_addr, err_message)
+ end
+ upstream:fail()
+ else
+ upstream:ok()
+ local rows = parse_clickhouse_response_json_eachrow(params, data, row_cb)
+
+ if rows then
+ if ok_cb then
+ ok_cb(params, rows)
+ else
+ lua_util.debugm(N, params.log_obj,
+ "http_select_cb ok: %s, %s, %s, %s", err_message, code,
+ data:gsub('[\n%s]+', ' '), _)
+ end
+ else
+ if fail_cb then
+ fail_cb(params, 'failed to parse reply', data)
+ else
+ local ip_addr = upstream:get_addr():to_string(true)
+ rspamd_logger.errx(params.log_obj,
+ "request failed on clickhouse server %s: %s",
+ ip_addr, 'failed to parse reply')
+ end
+ end
+ end
+ end
+
+ return http_cb
+end
+
+-- Helper to generate HTTP closure
+local function mk_http_insert_cb(upstream, params, ok_cb, fail_cb)
+ local function http_cb(err_message, code, data, _)
+ if code ~= 200 or err_message then
+ if not err_message then
+ err_message = data
+ end
+ local ip_addr = upstream:get_addr():to_string(true)
+
+ if fail_cb then
+ fail_cb(params, err_message, data)
+ else
+ rspamd_logger.errx(params.log_obj,
+ "request failed on clickhouse server %s: %s",
+ ip_addr, err_message)
+ end
+ upstream:fail()
+ else
+ upstream:ok()
+
+ if ok_cb then
+ local err, parsed = parse_clickhouse_response_json(data)
+
+ if err then
+ fail_cb(params, err, data)
+ else
+ ok_cb(params, parsed)
+ end
+
+ else
+ lua_util.debugm(N, params.log_obj,
+ "http_insert_cb ok: %s, %s, %s, %s", err_message, code,
+ data:gsub('[\n%s]+', ' '), _)
+ end
+ end
+ end
+
+ return http_cb
+end
+
+--[[[
+-- @function lua_clickhouse.select(upstream, settings, params, query,
+ ok_cb, fail_cb)
+-- Make select request to clickhouse
+-- @param {upstream} upstream clickhouse server upstream
+-- @param {table} settings global settings table:
+-- * use_gsip: use gzip compression
+-- * timeout: request timeout
+-- * no_ssl_verify: skip SSL verification
+-- * user: HTTP user
+-- * password: HTTP password
+-- @param {params} HTTP request params
+-- @param {string} query select query (passed in HTTP body)
+-- @param {function} ok_cb callback to be called in case of success
+-- @param {function} fail_cb callback to be called in case of some error
+-- @param {function} row_cb optional callback to be called on each parsed data row (instead of table insertion)
+-- @return {boolean} whether a connection was successful
+-- @example
+--
+--]]
+exports.select = function(upstream, settings, params, query, ok_cb, fail_cb, row_cb)
+ local http_params = {}
+
+ for k, v in pairs(params) do
+ http_params[k] = v
+ end
+
+ http_params.callback = mk_http_select_cb(upstream, http_params, ok_cb, fail_cb, row_cb)
+ http_params.gzip = settings.use_gzip
+ http_params.mime_type = 'text/plain'
+ http_params.timeout = settings.timeout or default_timeout
+ http_params.no_ssl_verify = settings.no_ssl_verify
+ http_params.user = settings.user
+ http_params.password = settings.password
+ http_params.body = query
+ http_params.log_obj = params.task or params.config
+ http_params.opaque_body = true
+
+ lua_util.debugm(N, http_params.log_obj, "clickhouse select request: %s", http_params.body)
+
+ if not http_params.url then
+ local connect_prefix = "http://"
+ if settings.use_https then
+ connect_prefix = 'https://'
+ end
+ local ip_addr = upstream:get_addr():to_string(true)
+ local database = settings.database or 'default'
+ http_params.url = string.format('%s%s/?database=%s&default_format=JSONEachRow',
+ connect_prefix, ip_addr, escape_spaces(database))
+ end
+
+ return rspamd_http.request(http_params)
+end
+
+--[[[
+-- @function lua_clickhouse.select_sync(upstream, settings, params, query,
+ ok_cb, fail_cb, row_cb)
+-- Make select request to clickhouse
+-- @param {upstream} upstream clickhouse server upstream
+-- @param {table} settings global settings table:
+-- * use_gsip: use gzip compression
+-- * timeout: request timeout
+-- * no_ssl_verify: skip SSL verification
+-- * user: HTTP user
+-- * password: HTTP password
+-- @param {params} HTTP request params
+-- @param {string} query select query (passed in HTTP body)
+-- @param {function} ok_cb callback to be called in case of success
+-- @param {function} fail_cb callback to be called in case of some error
+-- @param {function} row_cb optional callback to be called on each parsed data row (instead of table insertion)
+-- @return
+-- {string} error message if exists
+-- nil | {rows} | {http_response}
+-- @example
+--
+--]]
+exports.select_sync = function(upstream, settings, params, query, row_cb)
+ local http_params = {}
+
+ for k, v in pairs(params) do
+ http_params[k] = v
+ end
+
+ http_params.gzip = settings.use_gzip
+ http_params.mime_type = 'text/plain'
+ http_params.timeout = settings.timeout or default_timeout
+ http_params.no_ssl_verify = settings.no_ssl_verify
+ http_params.user = settings.user
+ http_params.password = settings.password
+ http_params.body = query
+ http_params.log_obj = params.task or params.config
+ http_params.opaque_body = true
+
+ lua_util.debugm(N, http_params.log_obj, "clickhouse select request: %s", http_params.body)
+
+ if not http_params.url then
+ local connect_prefix = "http://"
+ if settings.use_https then
+ connect_prefix = 'https://'
+ end
+ local ip_addr = upstream:get_addr():to_string(true)
+ local database = settings.database or 'default'
+ http_params.url = string.format('%s%s/?database=%s&default_format=JSONEachRow',
+ connect_prefix, ip_addr, escape_spaces(database))
+ end
+
+ local err, response = rspamd_http.request(http_params)
+
+ if err then
+ return err, nil
+ elseif response.code ~= 200 then
+ return response.content, response
+ else
+ lua_util.debugm(N, http_params.log_obj, "clickhouse select response: %1", response)
+ local rows = parse_clickhouse_response_json_eachrow(params, response.content, row_cb)
+ return nil, rows
+ end
+end
+
+--[[[
+-- @function lua_clickhouse.insert(upstream, settings, params, query, rows,
+ ok_cb, fail_cb)
+-- Insert data rows to clickhouse
+-- @param {upstream} upstream clickhouse server upstream
+-- @param {table} settings global settings table:
+-- * use_gsip: use gzip compression
+-- * timeout: request timeout
+-- * no_ssl_verify: skip SSL verification
+-- * user: HTTP user
+-- * password: HTTP password
+-- @param {params} HTTP request params
+-- @param {string} query select query (passed in `query` request element with spaces escaped)
+-- @param {table|mixed} rows mix of strings, numbers or tables (for arrays)
+-- @param {function} ok_cb callback to be called in case of success
+-- @param {function} fail_cb callback to be called in case of some error
+-- @return {boolean} whether a connection was successful
+-- @example
+--
+--]]
+exports.insert = function(upstream, settings, params, query, rows,
+ ok_cb, fail_cb)
+ local http_params = {}
+
+ for k, v in pairs(params) do
+ http_params[k] = v
+ end
+
+ http_params.callback = mk_http_insert_cb(upstream, http_params, ok_cb, fail_cb)
+ http_params.gzip = settings.use_gzip
+ http_params.mime_type = 'text/plain'
+ http_params.timeout = settings.timeout or default_timeout
+ http_params.no_ssl_verify = settings.no_ssl_verify
+ http_params.user = settings.user
+ http_params.password = settings.password
+ http_params.method = 'POST'
+ http_params.body = { rspamd_text.fromtable(rows, '\n'), '\n' }
+ http_params.log_obj = params.task or params.config
+
+ if not http_params.url then
+ local connect_prefix = "http://"
+ if settings.use_https then
+ connect_prefix = 'https://'
+ end
+ local ip_addr = upstream:get_addr():to_string(true)
+ local database = settings.database or 'default'
+ http_params.url = string.format('%s%s/?database=%s&query=%s%%20FORMAT%%20TabSeparated',
+ connect_prefix,
+ ip_addr,
+ escape_spaces(database),
+ escape_spaces(query))
+ end
+
+ return rspamd_http.request(http_params)
+end
+
+--[[[
+-- @function lua_clickhouse.generic(upstream, settings, params, query,
+ ok_cb, fail_cb)
+-- Make a generic request to Clickhouse (e.g. alter)
+-- @param {upstream} upstream clickhouse server upstream
+-- @param {table} settings global settings table:
+-- * use_gsip: use gzip compression
+-- * timeout: request timeout
+-- * no_ssl_verify: skip SSL verification
+-- * user: HTTP user
+-- * password: HTTP password
+-- @param {params} HTTP request params
+-- @param {string} query Clickhouse query (passed in `query` request element with spaces escaped)
+-- @param {function} ok_cb callback to be called in case of success
+-- @param {function} fail_cb callback to be called in case of some error
+-- @return {boolean} whether a connection was successful
+-- @example
+--
+--]]
+exports.generic = function(upstream, settings, params, query,
+ ok_cb, fail_cb)
+ local http_params = {}
+
+ for k, v in pairs(params) do
+ http_params[k] = v
+ end
+
+ http_params.callback = mk_http_insert_cb(upstream, http_params, ok_cb, fail_cb)
+ http_params.gzip = settings.use_gzip
+ http_params.mime_type = 'text/plain'
+ http_params.timeout = settings.timeout or default_timeout
+ http_params.no_ssl_verify = settings.no_ssl_verify
+ http_params.user = settings.user
+ http_params.password = settings.password
+ http_params.log_obj = params.task or params.config
+ http_params.body = query
+
+ if not http_params.url then
+ local connect_prefix = "http://"
+ if settings.use_https then
+ connect_prefix = 'https://'
+ end
+ local ip_addr = upstream:get_addr():to_string(true)
+ local database = settings.database or 'default'
+ http_params.url = string.format('%s%s/?database=%s&default_format=JSONEachRow',
+ connect_prefix, ip_addr, escape_spaces(database))
+ end
+
+ return rspamd_http.request(http_params)
+end
+
+--[[[
+-- @function lua_clickhouse.generic_sync(upstream, settings, params, query,
+ ok_cb, fail_cb)
+-- Make a generic request to Clickhouse (e.g. alter)
+-- @param {upstream} upstream clickhouse server upstream
+-- @param {table} settings global settings table:
+-- * use_gsip: use gzip compression
+-- * timeout: request timeout
+-- * no_ssl_verify: skip SSL verification
+-- * user: HTTP user
+-- * password: HTTP password
+-- @param {params} HTTP request params
+-- @param {string} query Clickhouse query (passed in `query` request element with spaces escaped)
+-- @return {boolean} whether a connection was successful
+-- @example
+--
+--]]
+exports.generic_sync = function(upstream, settings, params, query)
+ local http_params = {}
+
+ for k, v in pairs(params) do
+ http_params[k] = v
+ end
+
+ http_params.gzip = settings.use_gzip
+ http_params.mime_type = 'text/plain'
+ http_params.timeout = settings.timeout or default_timeout
+ http_params.no_ssl_verify = settings.no_ssl_verify
+ http_params.user = settings.user
+ http_params.password = settings.password
+ http_params.log_obj = params.task or params.config
+ http_params.body = query
+
+ if not http_params.url then
+ local connect_prefix = "http://"
+ if settings.use_https then
+ connect_prefix = 'https://'
+ end
+ local ip_addr = upstream:get_addr():to_string(true)
+ local database = settings.database or 'default'
+ http_params.url = string.format('%s%s/?database=%s&default_format=JSON',
+ connect_prefix, ip_addr, escape_spaces(database))
+ end
+
+ local err, response = rspamd_http.request(http_params)
+
+ if err then
+ return err, nil
+ elseif response.code ~= 200 then
+ return response.content, response
+ else
+ lua_util.debugm(N, http_params.log_obj, "clickhouse generic response: %1", response)
+ local e, obj = parse_clickhouse_response_json(params, response.content)
+
+ if e then
+ return e, nil
+ end
+ return nil, obj
+ end
+end
+
+return exports
diff --git a/lualib/lua_content/ical.lua b/lualib/lua_content/ical.lua
new file mode 100644
index 0000000..d018a85
--- /dev/null
+++ b/lualib/lua_content/ical.lua
@@ -0,0 +1,105 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local l = require 'lpeg'
+local lua_util = require "lua_util"
+local N = "lua_content"
+
+local ical_grammar
+
+local function gen_grammar()
+ if not ical_grammar then
+ local wsp = l.S(" \t\v\f")
+ local crlf = (l.P "\r" ^ -1 * l.P "\n") + l.P "\r"
+ local eol = (crlf * #crlf) + (crlf - (crlf ^ -1 * wsp))
+ local name = l.C((l.P(1) - (l.P ":")) ^ 1) / function(v)
+ return (v:gsub("[\n\r]+%s", ""))
+ end
+ local value = l.C((l.P(1) - eol) ^ 0) / function(v)
+ return (v:gsub("[\n\r]+%s", ""))
+ end
+ ical_grammar = name * ":" * wsp ^ 0 * value * eol ^ -1
+ end
+
+ return ical_grammar
+end
+
+local exports = {}
+
+local function extract_text_data(specific)
+ local fun = require "fun"
+
+ local tbl = fun.totable(fun.map(function(e)
+ return e[2]:lower()
+ end, specific.elts))
+ return table.concat(tbl, '\n')
+end
+
+
+-- Keys that can have visible urls
+local url_keys = lua_util.list_to_hash {
+ 'description',
+ 'location',
+ 'summary',
+ 'organizer',
+ 'organiser',
+ 'attendee',
+ 'url'
+}
+
+local function process_ical(input, mpart, task)
+ local control = { n = '\n', r = '' }
+ local rspamd_url = require "rspamd_url"
+ local escaper = l.Ct((gen_grammar() / function(key, value)
+ value = value:gsub("\\(.)", control)
+ key = key:lower():match('^([^;]+)')
+
+ if key and url_keys[key] then
+ local local_urls = rspamd_url.all(task:get_mempool(), value)
+
+ if local_urls and #local_urls > 0 then
+ for _, u in ipairs(local_urls) do
+ lua_util.debugm(N, task, 'ical: found URL in ical key "%s": %s',
+ key, tostring(u))
+ task:inject_url(u, mpart)
+ end
+ end
+ end
+ lua_util.debugm(N, task, 'ical: ical key %s = "%s"',
+ key, value)
+ return { key, value }
+ end) ^ 1)
+
+ local elts = escaper:match(input)
+
+ if not elts then
+ return nil
+ end
+
+ return {
+ tag = 'ical',
+ extract_text = extract_text_data,
+ elts = elts
+ }
+end
+
+--[[[
+-- @function lua_ical.process(input)
+-- Returns all values from ical as a plain text. Names are completely ignored.
+--]]
+exports.process = process_ical
+
+return exports \ No newline at end of file
diff --git a/lualib/lua_content/init.lua b/lualib/lua_content/init.lua
new file mode 100644
index 0000000..701d223
--- /dev/null
+++ b/lualib/lua_content/init.lua
@@ -0,0 +1,109 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_content
+-- This module contains content processing logic
+--]]
+
+
+local exports = {}
+local N = "lua_content"
+local lua_util = require "lua_util"
+
+local content_modules = {
+ ical = {
+ mime_type = { "text/calendar", "application/calendar" },
+ module = require "lua_content/ical",
+ extensions = { 'ics' },
+ output = "text"
+ },
+ vcf = {
+ mime_type = { "text/vcard", "application/vcard" },
+ module = require "lua_content/vcard",
+ extensions = { 'vcf' },
+ output = "text"
+ },
+ pdf = {
+ mime_type = "application/pdf",
+ module = require "lua_content/pdf",
+ extensions = { 'pdf' },
+ output = "table"
+ },
+}
+
+local modules_by_mime_type
+local modules_by_extension
+
+local function init()
+ modules_by_mime_type = {}
+ modules_by_extension = {}
+ for k, v in pairs(content_modules) do
+ if v.mime_type then
+ if type(v.mime_type) == 'table' then
+ for _, mt in ipairs(v.mime_type) do
+ modules_by_mime_type[mt] = { k, v }
+ end
+ else
+ modules_by_mime_type[v.mime_type] = { k, v }
+ end
+
+ end
+ if v.extensions then
+ for _, ext in ipairs(v.extensions) do
+ modules_by_extension[ext] = { k, v }
+ end
+ end
+ end
+end
+
+exports.maybe_process_mime_part = function(part, task)
+ if not modules_by_mime_type then
+ init()
+ end
+
+ local ctype, csubtype = part:get_type()
+ local mt = string.format("%s/%s", ctype or 'application',
+ csubtype or 'octet-stream')
+ local pair = modules_by_mime_type[mt]
+
+ if not pair then
+ local ext = part:get_detected_ext()
+
+ if ext then
+ pair = modules_by_extension[ext]
+ end
+ end
+
+ if pair then
+ lua_util.debugm(N, task, "found known content of type %s: %s",
+ mt, pair[1])
+
+ local data = pair[2].module.process(part:get_content(), part, task)
+
+ if data then
+ lua_util.debugm(N, task, "extracted content from %s: %s type",
+ pair[1], type(data))
+ part:set_specific(data)
+ else
+ lua_util.debugm(N, task, "failed to extract anything from %s",
+ pair[1])
+ end
+ end
+
+end
+
+return exports \ No newline at end of file
diff --git a/lualib/lua_content/pdf.lua b/lualib/lua_content/pdf.lua
new file mode 100644
index 0000000..f6d5c0b
--- /dev/null
+++ b/lualib/lua_content/pdf.lua
@@ -0,0 +1,1424 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_content/pdf
+-- This module contains some heuristics for PDF files
+--]]
+
+local rspamd_trie = require "rspamd_trie"
+local rspamd_util = require "rspamd_util"
+local rspamd_text = require "rspamd_text"
+local rspamd_url = require "rspamd_url"
+local bit = require "bit"
+local N = "lua_content"
+local lua_util = require "lua_util"
+local rspamd_regexp = require "rspamd_regexp"
+local lpeg = require "lpeg"
+local pdf_patterns = {
+ trailer = {
+ patterns = {
+ [[\ntrailer\r?\n]]
+ }
+ },
+ suspicious = {
+ patterns = {
+ [[netsh\s]],
+ [[echo\s]],
+ [=[\/[A-Za-z]*#\d\d[#A-Za-z<>/\s]]=], -- Hex encode obfuscation
+ }
+ },
+ start_object = {
+ patterns = {
+ [=[[\r\n\0]\s*\d+\s+\d+\s+obj[\s<]]=]
+ }
+ },
+ end_object = {
+ patterns = {
+ [=[endobj[\r\n]]=]
+ }
+ },
+ start_stream = {
+ patterns = {
+ [=[>\s*stream[\r\n]]=],
+ }
+ },
+ end_stream = {
+ patterns = {
+ [=[endstream[\r\n]]=]
+ }
+ }
+}
+
+local pdf_text_patterns = {
+ start = {
+ patterns = {
+ [[\sBT\s]]
+ }
+ },
+ stop = {
+ patterns = {
+ [[\sET\b]]
+ }
+ }
+}
+
+local pdf_cmap_patterns = {
+ start = {
+ patterns = {
+ [[\d\s+beginbfchar\s]],
+ [[\d\s+beginbfrange\s]]
+ }
+ },
+ stop = {
+ patterns = {
+ [[\sendbfrange\b]],
+ [[\sendbchar\b]]
+ }
+ }
+}
+
+-- index[n] ->
+-- t[1] - pattern,
+-- t[2] - key in patterns table,
+-- t[3] - value in patterns table
+-- t[4] - local pattern index
+local pdf_indexes = {}
+local pdf_text_indexes = {}
+local pdf_cmap_indexes = {}
+
+local pdf_trie
+local pdf_text_trie
+local pdf_cmap_trie
+
+local exports = {}
+
+local config = {
+ max_extraction_size = 512 * 1024,
+ max_processing_size = 32 * 1024,
+ text_extraction = false, -- NYI feature
+ url_extraction = true,
+ enabled = true,
+ js_fuzzy = true, -- Generate fuzzy hashes from PDF javascripts
+ min_js_fuzzy = 256, -- Minimum size of js to be considered as a fuzzy
+ openaction_fuzzy_only = false, -- Generate fuzzy from all scripts
+ max_pdf_objects = 10000, -- Maximum number of objects to be considered
+ max_pdf_trailer = 10 * 1024 * 1024, -- Maximum trailer size (to avoid abuse)
+ max_pdf_trailer_lines = 100, -- Maximum number of lines in pdf trailer
+ pdf_process_timeout = 1.0, -- Timeout in seconds for processing
+}
+
+-- Used to process patterns found in PDF
+-- positions for functional processors should be a iter/table from trie matcher in form
+---- [{n1, pat_idx1}, ... {nn, pat_idxn}] where
+---- pat_idxn is pattern index and n1 ... nn are match positions
+local processors = {}
+-- PDF objects outer grammar in LPEG style (performing table captures)
+local pdf_outer_grammar
+local pdf_text_grammar
+
+-- Used to match objects
+local object_re = rspamd_regexp.create_cached([=[/(\d+)\s+(\d+)\s+obj\s*/]=])
+
+local function config_module()
+ local opts = rspamd_config:get_all_opt('lua_content')
+
+ if opts and opts.pdf then
+ config = lua_util.override_defaults(config, opts.pdf)
+ end
+end
+
+local function compile_tries()
+ local default_compile_flags = bit.bor(rspamd_trie.flags.re,
+ rspamd_trie.flags.dot_all,
+ rspamd_trie.flags.no_start)
+ local function compile_pats(patterns, indexes, compile_flags)
+ local strs = {}
+ for what, data in pairs(patterns) do
+ for i, pat in ipairs(data.patterns) do
+ strs[#strs + 1] = pat
+ indexes[#indexes + 1] = { what, data, pat, i }
+ end
+ end
+
+ return rspamd_trie.create(strs, compile_flags or default_compile_flags)
+ end
+
+ if not pdf_trie then
+ pdf_trie = compile_pats(pdf_patterns, pdf_indexes)
+ end
+ if not pdf_text_trie then
+ pdf_text_trie = compile_pats(pdf_text_patterns, pdf_text_indexes)
+ end
+ if not pdf_cmap_trie then
+ pdf_cmap_trie = compile_pats(pdf_cmap_patterns, pdf_cmap_indexes)
+ end
+end
+
+-- Returns a table with generic grammar elements for PDF
+local function generic_grammar_elts()
+ local P = lpeg.P
+ local R = lpeg.R
+ local S = lpeg.S
+ local V = lpeg.V
+ local C = lpeg.C
+ local D = R '09' -- Digits
+
+ local grammar_elts = {}
+
+ -- Helper functions
+ local function pdf_hexstring_unescape(s)
+ if #s % 2 == 0 then
+ -- Sane hex string
+ return lua_util.unhex(s)
+ end
+
+ -- WTF hex string
+ -- Append '0' to it and unescape...
+ return lua_util.unhex(s:sub(1, #s - 1)) .. lua_util.unhex((s:sub(#s) .. '0'))
+ end
+
+ local function pdf_string_unescape(s)
+ local function ue_single(cc)
+ if cc == '\\r' then
+ return '\r'
+ elseif cc == '\\n' then
+ return '\n'
+ else
+ return cc:gsub(2, 2)
+ end
+ end
+ -- simple unescape \char
+ s = s:gsub('\\[^%d]', ue_single)
+ -- unescape octal
+ local function ue_octal(cc)
+ -- Replace unknown stuff with '?'
+ return string.char(tonumber(cc:sub(2), 8) or 63)
+ end
+ s = s:gsub('\\%d%d?%d?', ue_octal)
+
+ return s
+ end
+
+ local function pdf_id_unescape(s)
+ return (s:gsub('#%d%d', function(cc)
+ return string.char(tonumber(cc:sub(2), 16))
+ end))
+ end
+
+ local delim = S '()<>[]{}/%'
+ grammar_elts.ws = S '\0 \r\n\t\f'
+ local hex = R 'af' + R 'AF' + D
+ -- Comments.
+ local eol = P '\r\n' + '\n'
+ local line = (1 - S '\r\n\f') ^ 0 * eol ^ -1
+ grammar_elts.comment = P '%' * line
+
+ -- Numbers.
+ local sign = S '+-' ^ -1
+ local decimal = D ^ 1
+ local float = D ^ 1 * P '.' * D ^ 0 + P '.' * D ^ 1
+ grammar_elts.number = C(sign * (float + decimal)) / tonumber
+
+ -- String
+ grammar_elts.str = P { "(" * C(((1 - S "()\\") + (P '\\' * 1) + V(1)) ^ 0) / pdf_string_unescape * ")" }
+ grammar_elts.hexstr = P { "<" * C(hex ^ 0) / pdf_hexstring_unescape * ">" }
+
+ -- Identifier
+ grammar_elts.id = P { '/' * C((1 - (delim + grammar_elts.ws)) ^ 1) / pdf_id_unescape }
+
+ -- Booleans (who care about them?)
+ grammar_elts.boolean = C(P("true") + P("false"))
+
+ -- Stupid references
+ grammar_elts.ref = lpeg.Ct { lpeg.Cc("%REF%") * C(D ^ 1) * " " * C(D ^ 1) * " " * "R" }
+
+ return grammar_elts
+end
+
+
+-- Generates a grammar to parse outer elements (external objects in PDF notation)
+local function gen_outer_grammar()
+ local V = lpeg.V
+ local gen = generic_grammar_elts()
+
+ return lpeg.P {
+ "EXPR";
+ EXPR = gen.ws ^ 0 * V("ELT") ^ 0 * gen.ws ^ 0,
+ ELT = V("ARRAY") + V("DICT") + V("ATOM"),
+ ATOM = gen.ws ^ 0 * (gen.comment + gen.boolean + gen.ref +
+ gen.number + V("STRING") + gen.id) * gen.ws ^ 0,
+ DICT = "<<" * gen.ws ^ 0 * lpeg.Cf(lpeg.Ct("") * V("KV_PAIR") ^ 0, rawset) * gen.ws ^ 0 * ">>",
+ KV_PAIR = lpeg.Cg(gen.id * gen.ws ^ 0 * V("ELT") * gen.ws ^ 0),
+ ARRAY = "[" * gen.ws ^ 0 * lpeg.Ct(V("ELT") ^ 0) * gen.ws ^ 0 * "]",
+ STRING = lpeg.P { gen.str + gen.hexstr },
+ }
+end
+
+-- Graphic state in PDF
+local function gen_graphics_unary()
+ local P = lpeg.P
+ local S = lpeg.S
+
+ return P("q") + P("Q") + P("h")
+ + S("WSsFfBb") * P("*") ^ 0 + P("n")
+
+end
+local function gen_graphics_binary()
+ local P = lpeg.P
+ local S = lpeg.S
+
+ return S("gGwJjMi") +
+ P("M") + P("ri") + P("gs") +
+ P("CS") + P("cs") + P("sh")
+end
+local function gen_graphics_ternary()
+ local P = lpeg.P
+ local S = lpeg.S
+
+ return P("d") + P("m") + S("lm")
+end
+local function gen_graphics_nary()
+ local P = lpeg.P
+ local S = lpeg.S
+
+ return P("SC") + P("sc") + P("SCN") + P("scn") + P("k") + P("K") + P("re") + S("cvy") +
+ P("RG") + P("rg")
+end
+
+-- Generates a grammar to parse text blocks (between BT and ET)
+local function gen_text_grammar()
+ local V = lpeg.V
+ local P = lpeg.P
+ local C = lpeg.C
+ local gen = generic_grammar_elts()
+
+ local empty = ""
+ local unary_ops = C("T*") / "\n" +
+ C(gen_graphics_unary()) / empty
+ local binary_ops = P("Tc") + P("Tw") + P("Tz") + P("TL") + P("Tr") + P("Ts") +
+ gen_graphics_binary()
+ local ternary_ops = P("TD") + P("Td") + gen_graphics_ternary()
+ local nary_op = P("Tm") + gen_graphics_nary()
+ local text_binary_op = P("Tj") + P("TJ") + P("'")
+ local text_quote_op = P('"')
+ local font_op = P("Tf")
+
+ return lpeg.P {
+ "EXPR";
+ EXPR = gen.ws ^ 0 * lpeg.Ct(V("COMMAND") ^ 0),
+ COMMAND = (V("UNARY") + V("BINARY") + V("TERNARY") + V("NARY") + V("TEXT") +
+ V("FONT") + gen.comment) * gen.ws ^ 0,
+ UNARY = unary_ops,
+ BINARY = V("ARG") / empty * gen.ws ^ 1 * binary_ops,
+ TERNARY = V("ARG") / empty * gen.ws ^ 1 * V("ARG") / empty * gen.ws ^ 1 * ternary_ops,
+ NARY = (gen.number / 0 * gen.ws ^ 1) ^ 1 * (gen.id / empty * gen.ws ^ 0) ^ -1 * nary_op,
+ ARG = V("ARRAY") + V("DICT") + V("ATOM"),
+ ATOM = (gen.comment + gen.boolean + gen.ref +
+ gen.number + V("STRING") + gen.id),
+ DICT = "<<" * gen.ws ^ 0 * lpeg.Cf(lpeg.Ct("") * V("KV_PAIR") ^ 0, rawset) * gen.ws ^ 0 * ">>",
+ KV_PAIR = lpeg.Cg(gen.id * gen.ws ^ 0 * V("ARG") * gen.ws ^ 0),
+ ARRAY = "[" * gen.ws ^ 0 * lpeg.Ct(V("ARG") ^ 0) * gen.ws ^ 0 * "]",
+ STRING = lpeg.P { gen.str + gen.hexstr },
+ TEXT = (V("TEXT_ARG") * gen.ws ^ 1 * text_binary_op) +
+ (V("ARG") / 0 * gen.ws ^ 1 * V("ARG") / 0 * gen.ws ^ 1 * V("TEXT_ARG") * gen.ws ^ 1 * text_quote_op),
+ FONT = (V("FONT_ARG") * gen.ws ^ 1 * (gen.number / 0) * gen.ws ^ 1 * font_op),
+ FONT_ARG = lpeg.Ct(lpeg.Cc("%font%") * gen.id),
+ TEXT_ARG = lpeg.Ct(V("STRING")) + V("TEXT_ARRAY"),
+ TEXT_ARRAY = "[" *
+ lpeg.Ct(((gen.ws ^ 0 * (gen.ws ^ 0 * (gen.number / 0) ^ 0 * gen.ws ^ 0 * (gen.str + gen.hexstr))) ^ 1)) * gen.ws ^ 0 * "]",
+ }
+end
+
+
+-- Call immediately on require
+compile_tries()
+config_module()
+pdf_outer_grammar = gen_outer_grammar()
+pdf_text_grammar = gen_text_grammar()
+
+local function extract_text_data(specific)
+ return nil -- NYI
+end
+
+-- Generates index for major/minor pair
+local function obj_ref(major, minor)
+ return major * 10.0 + 1.0 / (minor + 1.0)
+end
+
+-- Return indirect object reference (if needed)
+local function maybe_dereference_object(elt, pdf, task)
+ if type(elt) == 'table' and elt[1] == '%REF%' then
+ local ref = obj_ref(elt[2], elt[3])
+
+ if pdf.ref[ref] then
+ -- No recursion!
+ return pdf.ref[ref]
+ else
+ lua_util.debugm(N, task, 'cannot dereference %s:%s -> %s, no object',
+ elt[2], elt[3], obj_ref(elt[2], elt[3]))
+ return nil
+ end
+ end
+
+ return elt
+end
+
+-- Apply PDF stream filter
+local function apply_pdf_filter(input, filt)
+ if filt == 'FlateDecode' then
+ return rspamd_util.inflate(input, config.max_extraction_size)
+ end
+
+ return nil
+end
+
+-- Conditionally apply a pipeline of stream filters and return uncompressed data
+local function maybe_apply_filter(dict, data, pdf, task)
+ local uncompressed = data
+
+ if dict.Filter then
+ local filt = dict.Filter
+ if type(filt) == 'string' then
+ filt = { filt }
+ end
+
+ if dict.DecodeParms then
+ local decode_params = maybe_dereference_object(dict.DecodeParms, pdf, task)
+
+ if type(decode_params) == 'table' then
+ if decode_params.Predictor then
+ return nil, 'predictor exists'
+ end
+ end
+ end
+
+ for _, f in ipairs(filt) do
+ uncompressed = apply_pdf_filter(uncompressed, f)
+
+ if not uncompressed then
+ break
+ end
+ end
+ end
+
+ return uncompressed, nil
+end
+
+-- Conditionally extract stream data from object and attach it as obj.uncompressed
+local function maybe_extract_object_stream(obj, pdf, task)
+ if pdf.encrypted then
+ -- TODO add decryption some day
+ return nil
+ end
+ local dict = obj.dict
+ if dict.Length and type(obj.stream) == 'table' then
+ local len = math.min(obj.stream.len,
+ tonumber(maybe_dereference_object(dict.Length, pdf, task)) or 0)
+ if len > 0 then
+ local real_stream = obj.stream.data:span(1, len)
+
+ local uncompressed, filter_err = maybe_apply_filter(dict, real_stream, pdf, task)
+
+ if uncompressed then
+ obj.uncompressed = uncompressed
+ lua_util.debugm(N, task, 'extracted object %s:%s: (%s -> %s)',
+ obj.major, obj.minor, len, uncompressed:len())
+ return obj.uncompressed
+ else
+ lua_util.debugm(N, task, 'cannot extract object %s:%s; len = %s; filter = %s: %s',
+ obj.major, obj.minor, len, dict.Filter, filter_err)
+ end
+ else
+ lua_util.debugm(N, task, 'cannot extract object %s:%s; len = %s',
+ obj.major, obj.minor, len)
+ end
+ end
+end
+
+local function parse_object_grammar(obj, task, pdf)
+ -- Parse grammar
+ local obj_dict_span
+ if obj.stream then
+ obj_dict_span = obj.data:span(1, obj.stream.start - obj.start)
+ else
+ obj_dict_span = obj.data
+ end
+
+ if obj_dict_span:len() < config.max_processing_size then
+ local ret, obj_or_err = pcall(pdf_outer_grammar.match, pdf_outer_grammar, obj_dict_span)
+
+ if ret then
+ if obj.stream then
+ if type(obj_or_err) == 'table' then
+ obj.dict = obj_or_err
+ else
+ obj.dict = {}
+ end
+
+ lua_util.debugm(N, task, 'stream object %s:%s is parsed to: %s',
+ obj.major, obj.minor, obj_or_err)
+ else
+ -- Direct object
+ if type(obj_or_err) == 'table' then
+ obj.dict = obj_or_err
+ obj.uncompressed = obj_or_err
+ lua_util.debugm(N, task, 'direct object %s:%s is parsed to: %s',
+ obj.major, obj.minor, obj_or_err)
+ pdf.ref[obj_ref(obj.major, obj.minor)] = obj
+ else
+ lua_util.debugm(N, task, 'direct object %s:%s is parsed to raw data: %s',
+ obj.major, obj.minor, obj_or_err)
+ pdf.ref[obj_ref(obj.major, obj.minor)] = obj_or_err
+ obj.dict = {}
+ obj.uncompressed = obj_or_err
+ end
+ end
+ else
+ lua_util.debugm(N, task, 'object %s:%s cannot be parsed: %s',
+ obj.major, obj.minor, obj_or_err)
+ end
+ else
+ lua_util.debugm(N, task, 'object %s:%s cannot be parsed: too large %s',
+ obj.major, obj.minor, obj_dict_span:len())
+ end
+end
+
+-- Extracts font data and process /ToUnicode mappings
+-- NYI in fact as cmap is ridiculously stupid and complicated
+--[[
+local function process_font(task, pdf, font, fname)
+ local dict = font
+ if font.dict then
+ dict = font.dict
+ end
+
+ if type(dict) == 'table' and dict.ToUnicode then
+ local cmap = maybe_dereference_object(dict.ToUnicode, pdf, task)
+
+ if cmap and cmap.dict then
+ maybe_extract_object_stream(cmap, pdf, task)
+ lua_util.debugm(N, task, 'found cmap for font %s: %s',
+ fname, cmap.uncompressed)
+ end
+ end
+end
+--]]
+
+-- Forward declaration
+local process_dict
+
+-- This function processes javascript string and returns JS hash and JS rspamd_text
+local function process_javascript(task, pdf, js, obj)
+ local rspamd_cryptobox_hash = require "rspamd_cryptobox_hash"
+ if type(js) == 'string' then
+ js = rspamd_text.fromstring(js):oneline()
+ elseif type(js) == 'userdata' then
+ js = js:oneline()
+ else
+ return nil
+ end
+
+ local hash = rspamd_cryptobox_hash.create(js)
+ local bin_hash = hash:bin()
+
+ if not pdf.scripts then
+ pdf.scripts = {}
+ end
+
+ if pdf.scripts[bin_hash] then
+ -- Duplicate
+ return pdf.scripts[bin_hash]
+ end
+
+ local njs = {
+ data = js,
+ hash = hash:hex(),
+ bin_hash = bin_hash,
+ object = obj,
+ }
+ pdf.scripts[bin_hash] = njs
+ return njs
+end
+
+-- Extract interesting stuff from /Action, e.g. javascript
+local function process_action(task, pdf, obj)
+ if not (obj.js or obj.launch) and (obj.dict and obj.dict.JS) then
+ local js = maybe_dereference_object(obj.dict.JS, pdf, task)
+
+ if js then
+ if type(js) == 'table' then
+ local extracted_js = maybe_extract_object_stream(js, pdf, task)
+
+ if not extracted_js then
+ lua_util.debugm(N, task, 'invalid type for JavaScript from %s:%s: %s',
+ obj.major, obj.minor, js)
+ else
+ js = extracted_js
+ end
+ end
+
+ js = process_javascript(task, pdf, js, obj)
+ if js then
+ obj.js = js
+ lua_util.debugm(N, task, 'extracted javascript from %s:%s: %s',
+ obj.major, obj.minor, obj.js.data)
+ else
+ lua_util.debugm(N, task, 'invalid type for JavaScript from %s:%s: %s',
+ obj.major, obj.minor, js)
+ end
+ elseif obj.dict.F then
+ local launch = maybe_dereference_object(obj.dict.F, pdf, task)
+
+ if launch then
+ if type(launch) == 'string' then
+ obj.launch = rspamd_text.fromstring(launch):exclude_chars('%n%c')
+ lua_util.debugm(N, task, 'extracted launch from %s:%s: %s',
+ obj.major, obj.minor, obj.launch)
+ elseif type(launch) == 'userdata' then
+ obj.launch = launch:exclude_chars('%n%c')
+ lua_util.debugm(N, task, 'extracted launch from %s:%s: %s',
+ obj.major, obj.minor, obj.launch)
+ else
+ lua_util.debugm(N, task, 'invalid type for launch from %s:%s: %s',
+ obj.major, obj.minor, launch)
+ end
+ end
+ else
+
+ lua_util.debugm(N, task, 'no JS attribute in action %s:%s',
+ obj.major, obj.minor)
+ end
+ end
+end
+
+-- Extract interesting stuff from /Catalog, e.g. javascript in /OpenAction
+local function process_catalog(task, pdf, obj)
+ if obj.dict then
+ if obj.dict.OpenAction then
+ local action = maybe_dereference_object(obj.dict.OpenAction, pdf, task)
+
+ if action and type(action) == 'table' then
+ -- This also processes action js (if not already processed)
+ process_dict(task, pdf, action, action.dict)
+ if action.js then
+ lua_util.debugm(N, task, 'found openaction JS in %s:%s: %s',
+ obj.major, obj.minor, action.js)
+ pdf.openaction = action.js
+ action.js.object = obj
+ elseif action.launch then
+ lua_util.debugm(N, task, 'found openaction launch in %s:%s: %s',
+ obj.major, obj.minor, action.launch)
+ pdf.launch = action.launch
+ else
+ lua_util.debugm(N, task, 'no JS in openaction %s:%s: %s',
+ obj.major, obj.minor, action)
+ end
+ else
+ lua_util.debugm(N, task, 'cannot find openaction %s:%s: %s -> %s',
+ obj.major, obj.minor, obj.dict.OpenAction, action)
+ end
+ else
+ lua_util.debugm(N, task, 'no openaction in catalog %s:%s',
+ obj.major, obj.minor)
+ end
+ end
+end
+
+local function process_xref(task, pdf, obj)
+ if obj.dict then
+ if obj.dict.Encrypt then
+ local encrypt = maybe_dereference_object(obj.dict.Encrypt, pdf, task)
+ lua_util.debugm(N, task, 'found encrypt: %s in xref object %s:%s',
+ encrypt, obj.major, obj.minor)
+ pdf.encrypted = true
+ end
+ end
+end
+
+process_dict = function(task, pdf, obj, dict)
+ if not obj.type and type(dict) == 'table' then
+ if dict.Type and type(dict.Type) == 'string' then
+ -- Common stuff
+ obj.type = dict.Type
+ end
+
+ if not obj.type then
+
+ if obj.dict.S and obj.dict.JS then
+ obj.type = 'Javascript'
+ lua_util.debugm(N, task, 'implicit type for JavaScript object %s:%s',
+ obj.major, obj.minor)
+ else
+ lua_util.debugm(N, task, 'no type for %s:%s',
+ obj.major, obj.minor)
+ return
+ end
+ end
+
+ lua_util.debugm(N, task, 'processed stream dictionary for object %s:%s -> %s',
+ obj.major, obj.minor, obj.type)
+ local contents = dict.Contents
+ if contents and type(contents) == 'table' then
+ if contents[1] == '%REF%' then
+ -- Single reference
+ contents = { contents }
+ end
+ obj.contents = {}
+
+ for _, c in ipairs(contents) do
+ local cobj = maybe_dereference_object(c, pdf, task)
+ if cobj and type(cobj) == 'table' then
+ obj.contents[#obj.contents + 1] = cobj
+ cobj.parent = obj
+ cobj.type = 'content'
+ end
+ end
+
+ lua_util.debugm(N, task, 'found content objects for %s:%s -> %s',
+ obj.major, obj.minor, #obj.contents)
+ end
+
+ local resources = dict.Resources
+ if resources and type(resources) == 'table' then
+ local res_ref = maybe_dereference_object(resources, pdf, task)
+
+ if type(res_ref) ~= 'table' then
+ lua_util.debugm(N, task, 'cannot parse resources from pdf: %s',
+ resources)
+ obj.resources = {}
+ elseif res_ref.dict then
+ obj.resources = res_ref.dict
+ else
+ obj.resources = {}
+ end
+ else
+ -- Fucking pdf: we need to inherit from parent
+ resources = {}
+ if dict.Parent then
+ local parent = maybe_dereference_object(dict.Parent, pdf, task)
+
+ if parent and type(parent) == 'table' and parent.dict then
+ if parent.resources then
+ lua_util.debugm(N, task, 'propagated resources from %s:%s to %s:%s',
+ parent.major, parent.minor, obj.major, obj.minor)
+ resources = parent.resources
+ end
+ end
+ end
+
+ obj.resources = resources
+ end
+
+
+
+ --[[Disabled fonts extraction
+ local fonts = obj.resources.Font
+ if fonts and type(fonts) == 'table' then
+ obj.fonts = {}
+ for k,v in pairs(fonts) do
+ obj.fonts[k] = maybe_dereference_object(v, pdf, task)
+
+ if obj.fonts[k] then
+ local font = obj.fonts[k]
+
+ if config.text_extraction then
+ process_font(task, pdf, font, k)
+ lua_util.debugm(N, task, 'found font "%s" for object %s:%s -> %s',
+ k, obj.major, obj.minor, font)
+ end
+ end
+ end
+ end
+ ]]
+
+ lua_util.debugm(N, task, 'found resources for object %s:%s (%s): %s',
+ obj.major, obj.minor, obj.type, obj.resources)
+
+ if obj.type == 'Action' then
+ process_action(task, pdf, obj)
+ elseif obj.type == 'Catalog' then
+ process_catalog(task, pdf, obj)
+ elseif obj.type == 'XRef' then
+ -- XRef stream instead of trailer from PDF 1.5 (thanks Adobe)
+ process_xref(task, pdf, obj)
+ elseif obj.type == 'Javascript' then
+ local js = maybe_dereference_object(obj.dict.JS, pdf, task)
+
+ if js then
+ if type(js) == 'table' then
+ local extracted_js = maybe_extract_object_stream(js, pdf, task)
+
+ if not extracted_js then
+ lua_util.debugm(N, task, 'invalid type for JavaScript from %s:%s: %s',
+ obj.major, obj.minor, js)
+ else
+ js = extracted_js
+ end
+ end
+
+ js = process_javascript(task, pdf, js, obj)
+ if js then
+ obj.js = js
+ lua_util.debugm(N, task, 'extracted javascript from %s:%s: %s',
+ obj.major, obj.minor, obj.js.data)
+ else
+ lua_util.debugm(N, task, 'invalid type for JavaScript from %s:%s: %s',
+ obj.major, obj.minor, js)
+ end
+ end
+ end
+ end -- Already processed dict (obj.type is not empty)
+end
+
+-- This function is intended to unpack objects from ObjStm crappy structure
+local compound_obj_grammar
+local function compound_obj_grammar_gen()
+ if not compound_obj_grammar then
+ local gen = generic_grammar_elts()
+ compound_obj_grammar = gen.ws ^ 0 * (gen.comment * gen.ws ^ 1) ^ 0 *
+ lpeg.Ct(lpeg.Ct(gen.number * gen.ws ^ 1 * gen.number * gen.ws ^ 0) ^ 1)
+ end
+
+ return compound_obj_grammar
+end
+local function pdf_compound_object_unpack(_, uncompressed, pdf, task, first)
+ -- First, we need to parse data line by line likely to find a line
+ -- that consists of pairs of numbers
+ compound_obj_grammar_gen()
+ local elts = compound_obj_grammar:match(uncompressed)
+ if elts and #elts > 0 then
+ lua_util.debugm(N, task, 'compound elts (chunk length %s): %s',
+ #uncompressed, elts)
+
+ for i, pair in ipairs(elts) do
+ local obj_number, offset = pair[1], pair[2]
+
+ offset = offset + first
+ if offset < #uncompressed then
+ local span_len
+ if i == #elts then
+ span_len = #uncompressed - offset
+ else
+ span_len = (elts[i + 1][2] + first) - offset
+ end
+
+ if span_len > 0 and offset + span_len <= #uncompressed then
+ local obj = {
+ major = obj_number,
+ minor = 0, -- Implicit
+ data = uncompressed:span(offset + 1, span_len),
+ ref = obj_ref(obj_number, 0)
+ }
+ parse_object_grammar(obj, task, pdf)
+
+ if obj.dict then
+ pdf.objects[#pdf.objects + 1] = obj
+ end
+ else
+ lua_util.debugm(N, task, 'invalid span_len for compound object %s:%s; offset = %s, len = %s',
+ pair[1], pair[2], offset + span_len, #uncompressed)
+ end
+ end
+ end
+ end
+end
+
+-- PDF 1.5 ObjStmt
+local function extract_pdf_compound_objects(task, pdf)
+ for i, obj in ipairs(pdf.objects or {}) do
+ if i > 0 and i % 100 == 0 then
+ local now = rspamd_util.get_ticks()
+
+ if now >= pdf.end_timestamp then
+ pdf.timeout_processing = now - pdf.start_timestamp
+
+ lua_util.debugm(N, task, 'pdf: timeout processing compound objects after spending %s seconds, ' ..
+ '%s elements processed',
+ pdf.timeout_processing, i)
+ break
+ end
+ end
+ if obj.stream and obj.dict and type(obj.dict) == 'table' then
+ local t = obj.dict.Type
+ if t and t == 'ObjStm' then
+ -- We are in troubles sir...
+ local nobjs = tonumber(maybe_dereference_object(obj.dict.N, pdf, task))
+ local first = tonumber(maybe_dereference_object(obj.dict.First, pdf, task))
+
+ if nobjs and first then
+ --local extend = maybe_dereference_object(obj.dict.Extends, pdf, task)
+ lua_util.debugm(N, task, 'extract ObjStm with %s objects (%s first) %s extend',
+ nobjs, first, obj.dict.Extends)
+
+ local uncompressed = maybe_extract_object_stream(obj, pdf, task)
+
+ if uncompressed then
+ pdf_compound_object_unpack(obj, uncompressed, pdf, task, first)
+ end
+ else
+ lua_util.debugm(N, task, 'ObjStm object %s:%s has bad dict: %s',
+ obj.major, obj.minor, obj.dict)
+ end
+ end
+ end
+ end
+end
+
+-- This function arranges starts and ends of all objects and process them into initial
+-- set of objects
+local function extract_outer_objects(task, input, pdf)
+ local start_pos, end_pos = 1, 1
+ local max_start_pos, max_end_pos
+ local obj_count = 0
+
+ max_start_pos = math.min(config.max_pdf_objects, #pdf.start_objects)
+ max_end_pos = math.min(config.max_pdf_objects, #pdf.end_objects)
+ lua_util.debugm(N, task, "pdf: extract objects from %s start positions and %s end positions",
+ max_start_pos, max_end_pos)
+
+ while start_pos <= max_start_pos and end_pos <= max_end_pos do
+ local first = pdf.start_objects[start_pos]
+ local last = pdf.end_objects[end_pos]
+
+ -- 7 is length of `endobj\n`
+ if first + 6 < last then
+ local len = last - first - 6
+
+ -- Also get the starting span and try to match it versus obj re to get numbers
+ local obj_line_potential = first - 32
+ if obj_line_potential < 1 then
+ obj_line_potential = 1
+ end
+ local prev_obj_end = pdf.end_objects[end_pos - 1]
+ if end_pos > 1 and prev_obj_end >= obj_line_potential and prev_obj_end < first then
+ obj_line_potential = prev_obj_end + 1
+ end
+
+ local obj_line_span = input:span(obj_line_potential, first - obj_line_potential + 1)
+ local matches = object_re:search(obj_line_span, true, true)
+
+ if matches and matches[1] then
+ local nobj = {
+ start = first,
+ len = len,
+ data = input:span(first, len),
+ major = tonumber(matches[1][2]),
+ minor = tonumber(matches[1][3]),
+ }
+ pdf.objects[obj_count + 1] = nobj
+ if nobj.major and nobj.minor then
+ -- Add reference
+ local ref = obj_ref(nobj.major, nobj.minor)
+ nobj.ref = ref -- Our internal reference
+ pdf.ref[ref] = nobj
+ end
+ end
+
+ obj_count = obj_count + 1
+ start_pos = start_pos + 1
+ end_pos = end_pos + 1
+ elseif first > last then
+ end_pos = end_pos + 1
+ else
+ start_pos = start_pos + 1
+ end_pos = end_pos + 1
+ end
+ end
+end
+
+-- This function attaches streams to objects and processes outer pdf grammar
+local function attach_pdf_streams(task, input, pdf)
+ if pdf.start_streams and pdf.end_streams then
+ local start_pos, end_pos = 1, 1
+ local max_start_pos, max_end_pos
+
+ max_start_pos = math.min(config.max_pdf_objects, #pdf.start_streams)
+ max_end_pos = math.min(config.max_pdf_objects, #pdf.end_streams)
+
+ for _, obj in ipairs(pdf.objects) do
+ while start_pos <= max_start_pos and end_pos <= max_end_pos do
+ local first = pdf.start_streams[start_pos]
+ local last = pdf.end_streams[end_pos]
+ last = last - 10 -- Exclude endstream\n pattern
+ lua_util.debugm(N, task, "start: %s, end: %s; obj: %s-%s",
+ first, last, obj.start, obj.start + obj.len)
+ if first > obj.start and last < obj.start + obj.len and last > first then
+ -- In case if we have fake endstream :(
+ while pdf.end_streams[end_pos + 1] and pdf.end_streams[end_pos + 1] < obj.start + obj.len do
+ end_pos = end_pos + 1
+ last = pdf.end_streams[end_pos]
+ end
+ -- Strip the first \n
+ while first < last do
+ local chr = input:byte(first)
+ if chr ~= 13 and chr ~= 10 then
+ break
+ end
+ first = first + 1
+ end
+ local len = last - first
+ obj.stream = {
+ start = first,
+ len = len,
+ data = input:span(first, len)
+ }
+ start_pos = start_pos + 1
+ end_pos = end_pos + 1
+ break
+ elseif first < obj.start then
+ start_pos = start_pos + 1
+ elseif last > obj.start + obj.len then
+ -- Not this object
+ break
+ else
+ start_pos = start_pos + 1
+ end_pos = end_pos + 1
+ end
+ end
+ if obj.stream then
+ lua_util.debugm(N, task, 'found object %s:%s %s start %s len, %s stream start, %s stream length',
+ obj.major, obj.minor, obj.start, obj.len, obj.stream.start, obj.stream.len)
+ else
+ lua_util.debugm(N, task, 'found object %s:%s %s start %s len, no stream',
+ obj.major, obj.minor, obj.start, obj.len)
+ end
+ end
+ end
+end
+
+-- Processes PDF objects: extracts streams, object numbers, process outer grammar,
+-- augment object types
+local function postprocess_pdf_objects(task, input, pdf)
+ pdf.objects = {} -- objects table
+ pdf.ref = {} -- references table
+ extract_outer_objects(task, input, pdf)
+
+ -- Now we have objects and we need to attach streams that are in bounds
+ attach_pdf_streams(task, input, pdf)
+ -- Parse grammar for outer objects
+ for i, obj in ipairs(pdf.objects) do
+ if i > 0 and i % 100 == 0 then
+ local now = rspamd_util.get_ticks()
+
+ if now >= pdf.end_timestamp then
+ pdf.timeout_processing = now - pdf.start_timestamp
+
+ lua_util.debugm(N, task, 'pdf: timeout processing grammars after spending %s seconds, ' ..
+ '%s elements processed',
+ pdf.timeout_processing, i)
+ break
+ end
+ end
+ if obj.ref then
+ parse_object_grammar(obj, task, pdf)
+
+ -- Special early handling
+ if obj.dict and obj.dict.Type and obj.dict.Type == 'XRef' then
+ process_xref(task, pdf, obj)
+ end
+ end
+ end
+
+ if not pdf.timeout_processing then
+ extract_pdf_compound_objects(task, pdf)
+ else
+ -- ENOTIME
+ return
+ end
+
+ -- Now we might probably have all objects being processed
+ for i, obj in ipairs(pdf.objects) do
+ if obj.dict then
+ -- Types processing
+ if i > 0 and i % 100 == 0 then
+ local now = rspamd_util.get_ticks()
+
+ if now >= pdf.end_timestamp then
+ pdf.timeout_processing = now - pdf.start_timestamp
+
+ lua_util.debugm(N, task, 'pdf: timeout processing dicts after spending %s seconds, ' ..
+ '%s elements processed',
+ pdf.timeout_processing, i)
+ break
+ end
+ end
+ process_dict(task, pdf, obj, obj.dict)
+ end
+ end
+end
+
+local function offsets_to_blocks(starts, ends, out)
+ local start_pos, end_pos = 1, 1
+
+ while start_pos <= #starts and end_pos <= #ends do
+ local first = starts[start_pos]
+ local last = ends[end_pos]
+
+ if first < last then
+ local len = last - first
+ out[#out + 1] = {
+ start = first,
+ len = len,
+ }
+ start_pos = start_pos + 1
+ end_pos = end_pos + 1
+ elseif first > last then
+ end_pos = end_pos + 1
+ else
+ -- Not ordered properly!
+ break
+ end
+ end
+end
+
+local function search_text(task, pdf)
+ for _, obj in ipairs(pdf.objects) do
+ if obj.type == 'Page' and obj.contents then
+ local text = {}
+ for _, tobj in ipairs(obj.contents) do
+ maybe_extract_object_stream(tobj, pdf, task)
+ local matches = pdf_text_trie:match(tobj.uncompressed or '')
+ if matches then
+ local text_blocks = {}
+ local starts = {}
+ local ends = {}
+
+ for npat, matched_positions in pairs(matches) do
+ if npat == 1 then
+ for _, pos in ipairs(matched_positions) do
+ starts[#starts + 1] = pos
+ end
+ else
+ for _, pos in ipairs(matched_positions) do
+ ends[#ends + 1] = pos
+ end
+ end
+ end
+
+ offsets_to_blocks(starts, ends, text_blocks)
+ for _, bl in ipairs(text_blocks) do
+ if bl.len > 2 then
+ -- To remove \s+ET\b pattern (it can leave trailing space or not but it doesn't matter)
+ bl.len = bl.len - 2
+ end
+
+ bl.data = tobj.uncompressed:span(bl.start, bl.len)
+ --lua_util.debugm(N, task, 'extracted text from object %s:%s: %s',
+ -- tobj.major, tobj.minor, bl.data)
+
+ if bl.len < config.max_processing_size then
+ local ret, obj_or_err = pcall(pdf_text_grammar.match, pdf_text_grammar,
+ bl.data)
+
+ if ret then
+ text[#text + 1] = obj_or_err
+ lua_util.debugm(N, task, 'attached %s from content object %s:%s to %s:%s',
+ obj_or_err, tobj.major, tobj.minor, obj.major, obj.minor)
+ else
+ lua_util.debugm(N, task, 'object %s:%s cannot be parsed: %s',
+ obj.major, obj.minor, obj_or_err)
+ end
+
+ end
+ end
+ end
+ end
+
+ -- Join all text data together
+ if #text > 0 then
+ obj.text = rspamd_text.fromtable(text)
+ lua_util.debugm(N, task, 'object %s:%s is parsed to: %s',
+ obj.major, obj.minor, obj.text)
+ end
+ end
+ end
+end
+
+-- This function searches objects for `/URI` key and parses it's content
+local function search_urls(task, pdf, mpart)
+ local function recursive_object_traverse(obj, dict, rec)
+ if rec > 10 then
+ lua_util.debugm(N, task, 'object %s:%s recurses too much',
+ obj.major, obj.minor)
+ return
+ end
+
+ for k, v in pairs(dict) do
+ if type(v) == 'table' then
+ recursive_object_traverse(obj, v, rec + 1)
+ elseif k == 'URI' then
+ v = maybe_dereference_object(v, pdf, task)
+ if type(v) == 'string' then
+ local url = rspamd_url.create(task:get_mempool(), v, { 'content' })
+
+ if url then
+ lua_util.debugm(N, task, 'found url %s in object %s:%s',
+ v, obj.major, obj.minor)
+ task:inject_url(url, mpart)
+ end
+ end
+ end
+ end
+ end
+
+ for _, obj in ipairs(pdf.objects) do
+ if obj.dict and type(obj.dict) == 'table' then
+ recursive_object_traverse(obj, obj.dict, 0)
+ end
+ end
+end
+
+local function process_pdf(input, mpart, task)
+
+ if not config.enabled then
+ -- Skip processing
+ return {}
+ end
+
+ local matches = pdf_trie:match(input)
+
+ if matches then
+ local start_ts = rspamd_util.get_ticks()
+ -- Temp object used to share data between pdf extraction methods
+ local pdf_object = {
+ tag = 'pdf',
+ extract_text = extract_text_data,
+ start_timestamp = start_ts,
+ end_timestamp = start_ts + config.pdf_process_timeout,
+ }
+ -- Output object that excludes all internal stuff
+ local pdf_output = lua_util.shallowcopy(pdf_object)
+ local grouped_processors = {}
+ for npat, matched_positions in pairs(matches) do
+ local index = pdf_indexes[npat]
+
+ local proc_key, loc_npat = index[1], index[4]
+
+ if not grouped_processors[proc_key] then
+ grouped_processors[proc_key] = {
+ processor_func = processors[proc_key],
+ offsets = {},
+ }
+ end
+ local proc = grouped_processors[proc_key]
+ -- Fill offsets
+ for _, pos in ipairs(matched_positions) do
+ proc.offsets[#proc.offsets + 1] = { pos, loc_npat }
+ end
+ end
+
+ for name, processor in pairs(grouped_processors) do
+ -- Sort by offset
+ lua_util.debugm(N, task, "pdf: process group %s with %s matches",
+ name, #processor.offsets)
+ table.sort(processor.offsets, function(e1, e2)
+ return e1[1] < e2[1]
+ end)
+ processor.processor_func(input, task, processor.offsets, pdf_object, pdf_output)
+ end
+
+ pdf_output.flags = {}
+
+ if pdf_object.start_objects and pdf_object.end_objects then
+ if #pdf_object.start_objects > config.max_pdf_objects then
+ pdf_output.many_objects = #pdf_object.start_objects
+ -- Trim
+ end
+
+ -- Postprocess objects
+ postprocess_pdf_objects(task, input, pdf_object)
+ if config.text_extraction then
+ search_text(task, pdf_object, pdf_output)
+ end
+ if config.url_extraction then
+ search_urls(task, pdf_object, mpart, pdf_output)
+ end
+
+ if config.js_fuzzy and pdf_object.scripts then
+ pdf_output.fuzzy_hashes = {}
+ if config.openaction_fuzzy_only then
+ -- OpenAction only
+ if pdf_object.openaction and pdf_object.openaction.bin_hash then
+ if config.min_js_fuzzy and #pdf_object.openaction.data >= config.min_js_fuzzy then
+ lua_util.debugm(N, task, "pdf: add fuzzy hash from openaction: %s; size = %s; object: %s:%s",
+ pdf_object.openaction.hash,
+ #pdf_object.openaction.data,
+ pdf_object.openaction.object.major, pdf_object.openaction.object.minor)
+ table.insert(pdf_output.fuzzy_hashes, pdf_object.openaction.bin_hash)
+ else
+ lua_util.debugm(N, task, "pdf: skip fuzzy hash from JavaScript: %s, too short: %s",
+ pdf_object.openaction.hash, #pdf_object.openaction.data)
+ end
+ end
+ else
+ -- All hashes
+ for h, sc in pairs(pdf_object.scripts) do
+ if config.min_js_fuzzy and #sc.data >= config.min_js_fuzzy then
+ lua_util.debugm(N, task, "pdf: add fuzzy hash from JavaScript: %s; size = %s; object: %s:%s",
+ sc.hash,
+ #sc.data,
+ sc.object.major, sc.object.minor)
+ table.insert(pdf_output.fuzzy_hashes, h)
+ else
+ lua_util.debugm(N, task, "pdf: skip fuzzy hash from JavaScript: %s, too short: %s",
+ sc.hash, #sc.data)
+ end
+ end
+
+ end
+ end
+ else
+ pdf_output.flags.no_objects = true
+ end
+
+ -- Propagate from object to output
+ if pdf_object.encrypted then
+ pdf_output.encrypted = true
+ end
+ if pdf_object.scripts then
+ pdf_output.scripts = true
+ end
+
+ return pdf_output
+ end
+end
+
+-- Processes the PDF trailer
+processors.trailer = function(input, task, positions, pdf_object, pdf_output)
+ local last_pos = positions[#positions]
+
+ lua_util.debugm(N, task, 'pdf: process trailer at position %s (%s total length)',
+ last_pos, #input)
+
+ if last_pos[1] > config.max_pdf_trailer then
+ pdf_output.long_trailer = #input - last_pos[1]
+ return
+ end
+
+ local last_span = input:span(last_pos[1])
+ local lines_checked = 0
+ for line in last_span:lines(true) do
+ if line:find('/Encrypt ') then
+ lua_util.debugm(N, task, "pdf: found encrypted line in trailer: %s",
+ line)
+ pdf_output.encrypted = true
+ pdf_object.encrypted = true
+ break
+ end
+ lines_checked = lines_checked + 1
+
+ if lines_checked > config.max_pdf_trailer_lines then
+ lua_util.debugm(N, task, "pdf: trailer has too many lines, stop checking")
+ pdf_output.long_trailer = #input - last_pos[1]
+ break
+ end
+ end
+end
+
+processors.suspicious = function(input, task, positions, pdf_object, pdf_output)
+ local suspicious_factor = 0.0
+ local nexec = 0
+ local nencoded = 0
+ local close_encoded = 0
+ local last_encoded
+ for _, match in ipairs(positions) do
+ if match[2] == 1 then
+ -- netsh
+ suspicious_factor = suspicious_factor + 0.5
+ elseif match[2] == 2 then
+ nexec = nexec + 1
+ elseif match[2] == 3 then
+ local enc_data = input:sub(match[1] - 2, match[1] - 1)
+ local legal_escape = false
+
+ if enc_data then
+ enc_data = enc_data:strtoul()
+
+ if enc_data then
+ -- Legit encode cases are non printable characters (e.g. spaces)
+ if enc_data < 0x21 or enc_data >= 0x7f then
+ legal_escape = true
+ end
+ end
+ end
+
+ if not legal_escape then
+ nencoded = nencoded + 1
+
+ if last_encoded then
+ if match[1] - last_encoded < 8 then
+ -- likely consecutive encoded chars, increase factor
+ close_encoded = close_encoded + 1
+ end
+ end
+ last_encoded = match[1]
+
+ end
+ end
+ end
+
+ if nencoded > 10 then
+ suspicious_factor = suspicious_factor + nencoded / 10
+ end
+ if nexec > 1 then
+ suspicious_factor = suspicious_factor + nexec / 2.0
+ end
+ if close_encoded > 4 and nencoded - close_encoded < 5 then
+ -- Too many close encoded comparing to the total number of encoded characters
+ suspicious_factor = suspicious_factor + 0.5
+ end
+
+ lua_util.debugm(N, task, 'pdf: found a suspicious patterns: %s exec, %s encoded (%s close), ' ..
+ '%s final factor',
+ nexec, nencoded, close_encoded, suspicious_factor)
+
+ if suspicious_factor > 1.0 then
+ suspicious_factor = 1.0
+ end
+
+ pdf_output.suspicious = suspicious_factor
+end
+
+local function generic_table_inserter(positions, pdf_object, output_key)
+ if not pdf_object[output_key] then
+ pdf_object[output_key] = {}
+ end
+ local shift = #pdf_object[output_key]
+ for i, pos in ipairs(positions) do
+ pdf_object[output_key][i + shift] = pos[1]
+ end
+end
+
+processors.start_object = function(_, task, positions, pdf_object)
+ generic_table_inserter(positions, pdf_object, 'start_objects')
+end
+
+processors.end_object = function(_, task, positions, pdf_object)
+ generic_table_inserter(positions, pdf_object, 'end_objects')
+end
+
+processors.start_stream = function(_, task, positions, pdf_object)
+ generic_table_inserter(positions, pdf_object, 'start_streams')
+end
+
+processors.end_stream = function(_, task, positions, pdf_object)
+ generic_table_inserter(positions, pdf_object, 'end_streams')
+end
+
+exports.process = process_pdf
+
+return exports \ No newline at end of file
diff --git a/lualib/lua_content/vcard.lua b/lualib/lua_content/vcard.lua
new file mode 100644
index 0000000..ed14412
--- /dev/null
+++ b/lualib/lua_content/vcard.lua
@@ -0,0 +1,84 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local l = require 'lpeg'
+local lua_util = require "lua_util"
+local N = "lua_content"
+
+local vcard_grammar
+
+-- XXX: Currently it is a copy of ical grammar
+local function gen_grammar()
+ if not vcard_grammar then
+ local wsp = l.S(" \t\v\f")
+ local crlf = (l.P "\r" ^ -1 * l.P "\n") + l.P "\r"
+ local eol = (crlf * #crlf) + (crlf - (crlf ^ -1 * wsp))
+ local name = l.C((l.P(1) - (l.P ":")) ^ 1) / function(v)
+ return (v:gsub("[\n\r]+%s", ""))
+ end
+ local value = l.C((l.P(1) - eol) ^ 0) / function(v)
+ return (v:gsub("[\n\r]+%s", ""))
+ end
+ vcard_grammar = name * ":" * wsp ^ 0 * value * eol ^ -1
+ end
+
+ return vcard_grammar
+end
+
+local exports = {}
+
+local function process_vcard(input, mpart, task)
+ local control = { n = '\n', r = '' }
+ local rspamd_url = require "rspamd_url"
+ local escaper = l.Ct((gen_grammar() / function(key, value)
+ value = value:gsub("\\(.)", control)
+ key = key:lower()
+ local local_urls = rspamd_url.all(task:get_mempool(), value)
+
+ if local_urls and #local_urls > 0 then
+ for _, u in ipairs(local_urls) do
+ lua_util.debugm(N, task, 'vcard: found URL in vcard %s',
+ tostring(u))
+ task:inject_url(u, mpart)
+ end
+ end
+ lua_util.debugm(N, task, 'vcard: vcard key %s = "%s"',
+ key, value)
+ return { key, value }
+ end) ^ 1)
+
+ local elts = escaper:match(input)
+
+ if not elts then
+ return nil
+ end
+
+ return {
+ tag = 'vcard',
+ extract_text = function()
+ return nil
+ end, -- NYI
+ elts = elts
+ }
+end
+
+--[[[
+-- @function vcard.process(input)
+-- Returns all values from vcard as a plain text. Names are completely ignored.
+--]]
+exports.process = process_vcard
+
+return exports \ No newline at end of file
diff --git a/lualib/lua_dkim_tools.lua b/lualib/lua_dkim_tools.lua
new file mode 100644
index 0000000..165ea8f
--- /dev/null
+++ b/lualib/lua_dkim_tools.lua
@@ -0,0 +1,742 @@
+--[[
+Copyright (c) 2016, Andrew Lewis <nerf@judo.za.org>
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local exports = {}
+
+local E = {}
+local lua_util = require "lua_util"
+local rspamd_util = require "rspamd_util"
+local logger = require "rspamd_logger"
+local fun = require "fun"
+
+local function check_violation(N, task, domain)
+ -- Check for DKIM_REJECT
+ local sym_check = 'R_DKIM_REJECT'
+
+ if N == 'arc' then
+ sym_check = 'ARC_REJECT'
+ end
+ if task:has_symbol(sym_check) then
+ local sym = task:get_symbol(sym_check)[1]
+ logger.infox(task, 'skip signing for %s: violation %s found: %s',
+ domain, sym_check, sym.options)
+ return false
+ end
+
+ return true
+end
+
+local function insert_or_update_prop(N, task, p, prop, origin, data)
+ if #p == 0 then
+ local k = {}
+ k[prop] = data
+ table.insert(p, k)
+ lua_util.debugm(N, task, 'add %s "%s" using %s', prop, data, origin)
+ else
+ for _, k in ipairs(p) do
+ if not k[prop] then
+ k[prop] = data
+ lua_util.debugm(N, task, 'set %s to "%s" using %s', prop, data, origin)
+ end
+ end
+ end
+end
+
+local function get_mempool_selectors(N, task)
+ local p = {}
+ local key_var = "dkim_key"
+ local selector_var = "dkim_selector"
+ if N == "arc" then
+ key_var = "arc_key"
+ selector_var = "arc_selector"
+ end
+
+ p.key = task:get_mempool():get_variable(key_var)
+ p.selector = task:get_mempool():get_variable(selector_var)
+
+ if (not p.key or not p.selector) then
+ return false, {}
+ end
+
+ lua_util.debugm(N, task, 'override selector and key to %s:%s', p.key, p.selector)
+ return true, p
+end
+
+local function parse_dkim_http_headers(N, task, settings)
+ -- Configure headers
+ local headers = {
+ sign_header = settings.http_sign_header or "PerformDkimSign",
+ sign_on_reject_header = settings.http_sign_on_reject_header_header or 'SignOnAuthFailed',
+ domain_header = settings.http_domain_header or 'DkimDomain',
+ selector_header = settings.http_selector_header or 'DkimSelector',
+ key_header = settings.http_key_header or 'DkimPrivateKey'
+ }
+
+ if task:get_request_header(headers.sign_header) then
+ local domain = task:get_request_header(headers.domain_header)
+ local selector = task:get_request_header(headers.selector_header)
+ local key = task:get_request_header(headers.key_header)
+
+ if not (domain and selector and key) then
+
+ logger.errx(task, 'missing required headers to sign email')
+ return false, {}
+ end
+
+ -- Now check if we need to check the existing auth
+ local hdr = task:get_request_header(headers.sign_on_reject_header)
+ if not hdr or tostring(hdr) == '0' or tostring(hdr) == 'false' then
+ if not check_violation(N, task, domain, selector) then
+ return false, {}
+ end
+ end
+
+ local p = {}
+ local k = {
+ domain = tostring(domain),
+ rawkey = tostring(key),
+ selector = tostring(selector),
+ }
+ table.insert(p, k)
+ return true, p
+ end
+
+ lua_util.debugm(N, task, 'no sign header %s', headers.sign_header)
+ return false, {}
+end
+
+local function prepare_dkim_signing(N, task, settings)
+ local is_local, is_sign_networks, is_authed
+
+ if settings.use_http_headers then
+ local res, tbl = parse_dkim_http_headers(N, task, settings)
+
+ if not res then
+ if not settings.allow_headers_fallback then
+ return res, {}
+ else
+ lua_util.debugm(N, task, 'failed to read http headers, fallback to normal schema')
+ end
+ else
+ return res, tbl
+ end
+ end
+
+ if settings.sign_condition and type(settings.sign_condition) == 'function' then
+ -- Use sign condition only
+ local ret = settings.sign_condition(task)
+
+ if not ret then
+ return false, {}
+ end
+
+ if ret[1] then
+ return true, ret
+ else
+ return true, { ret }
+ end
+ end
+
+ local auser = task:get_user()
+ local ip = task:get_from_ip()
+
+ if ip and ip:is_local() then
+ is_local = true
+ end
+
+ local has_pre_result = task:has_pre_result()
+ if has_pre_result then
+ local metric_action = task:get_metric_action()
+
+ if metric_action == 'reject' or metric_action == 'drop' then
+ -- No need to sign what we are already rejecting/dropping
+ lua_util.debugm(N, task, 'task result is already %s, no need to sign', metric_action)
+ return false, {}
+ end
+
+ if metric_action == 'soft reject' then
+ -- Same here, we are going to delay an email, signing is just a waste of time
+ lua_util.debugm(N, task, 'task result is %s, skip signing', metric_action)
+ return false, {}
+ end
+
+ -- For spam actions, there is no clear distinction
+ if metric_action ~= 'no action' and type(settings.skip_spam_sign) == 'boolean' and settings.skip_spam_sign then
+ lua_util.debugm(N, task, 'task result is %s, no need to sign', metric_action)
+ return false, {}
+ end
+ end
+
+ if settings.sign_authenticated and auser then
+ lua_util.debugm(N, task, 'user is authenticated')
+ is_authed = true
+ elseif (settings.sign_networks and settings.sign_networks:get_key(ip)) then
+ is_sign_networks = true
+ lua_util.debugm(N, task, 'mail is from address in sign_networks')
+ elseif settings.sign_local and is_local then
+ lua_util.debugm(N, task, 'mail is from local address')
+ elseif settings.sign_inbound and not is_local and not auser then
+ lua_util.debugm(N, task, 'mail was sent to us')
+ else
+ lua_util.debugm(N, task, 'mail is ineligible for signing')
+ return false, {}
+ end
+
+ local efrom = task:get_from('smtp')
+ local empty_envelope = false
+ if #(((efrom or E)[1] or E).addr or '') == 0 then
+ if not settings.allow_envfrom_empty then
+ lua_util.debugm(N, task, 'empty envelope from not allowed')
+ return false, {}
+ else
+ empty_envelope = true
+ end
+ end
+
+ local hfrom = task:get_from('mime')
+ if not settings.allow_hdrfrom_multiple and (hfrom or E)[2] then
+ lua_util.debugm(N, task, 'multiple header from not allowed')
+ return false, {}
+ end
+
+ local eto = task:get_recipients(0)
+
+ local dkim_domain
+ local hdom = ((hfrom or E)[1] or E).domain
+ local edom = ((efrom or E)[1] or E).domain
+ local tdom = ((eto or E)[1] or E).domain
+ local udom = string.match(auser or '', '.*@(.*)')
+
+ local function get_dkim_domain(dtype)
+ if settings[dtype] == 'header' then
+ return hdom
+ elseif settings[dtype] == 'envelope' then
+ return edom
+ elseif settings[dtype] == 'auth' then
+ return udom
+ elseif settings[dtype] == 'recipient' then
+ return tdom
+ else
+ return settings[dtype]:lower()
+ end
+ end
+
+ local function is_skip_sign()
+ return not (settings.sign_networks and is_sign_networks) and
+ not (settings.sign_authenticated and is_authed) and
+ not (settings.sign_local and is_local)
+ end
+
+ if hdom then
+ hdom = hdom:lower()
+ end
+ if edom then
+ edom = edom:lower()
+ end
+ if udom then
+ udom = udom:lower()
+ end
+ if tdom then
+ tdom = tdom:lower()
+ end
+
+ if settings.signing_table and (settings.key_table or settings.use_vault) then
+ -- OpenDKIM style
+ if is_skip_sign() then
+ lua_util.debugm(N, task,
+ 'skip signing: is_sign_network: %s, is_authed: %s, is_local: %s',
+ is_sign_networks, is_authed, is_local)
+ return false, {}
+ end
+
+ if not hfrom or not hfrom[1] or not hfrom[1].addr then
+ lua_util.debugm(N, task,
+ 'signing_table: cannot get data when no header from is presented')
+ return false, {}
+ end
+ local sign_entry = settings.signing_table:get_key(hfrom[1].addr:lower())
+
+ if sign_entry then
+ -- Check opendkim style entries
+ lua_util.debugm(N, task,
+ 'signing_table: found entry for %s: %s', hfrom[1].addr, sign_entry)
+ if sign_entry == '%' then
+ sign_entry = hdom
+ end
+
+ if settings.key_table then
+ -- Now search in key table
+ local key_entry = settings.key_table:get_key(sign_entry)
+
+ if key_entry then
+ local parts = lua_util.str_split(key_entry, ':')
+
+ if #parts == 2 then
+ -- domain + key
+ local selector = settings.selector
+
+ if not selector then
+ logger.errx(task, 'no selector defined for sign_entry %s, key_entry %s',
+ sign_entry, key_entry)
+ return false, {}
+ end
+
+ local res = {
+ selector = selector,
+ domain = parts[1]:gsub('%%', hdom)
+ }
+
+ local st = parts[2]:sub(1, 2)
+
+ if st:sub(1, 1) == '/' or st == './' or st == '..' then
+ res.key = parts[2]:gsub('%%', hdom)
+ lua_util.debugm(N, task, 'perform dkim signing for %s, selector=%s, domain=%s, key file=%s',
+ hdom, selector, res.domain, res.key)
+ else
+ res.rawkey = parts[2] -- No sanity check here
+ lua_util.debugm(N, task, 'perform dkim signing for %s, selector=%s, domain=%s, raw key used',
+ hdom, selector, res.domain)
+ end
+
+ return true, { res }
+ elseif #parts == 3 then
+ -- domain, selector, key
+ local selector = parts[2]
+
+ local res = {
+ selector = selector,
+ domain = parts[1]:gsub('%%', hdom)
+ }
+
+ local st = parts[3]:sub(1, 2)
+
+ if st:sub(1, 1) == '/' or st == './' or st == '..' then
+ res.key = parts[3]:gsub('%%', hdom)
+ lua_util.debugm(N, task, 'perform dkim signing for %s, selector=%s, domain=%s, key file=%s',
+ hdom, selector, res.domain, res.key)
+ else
+ res.rawkey = parts[3] -- No sanity check here
+ lua_util.debugm(N, task, 'perform dkim signing for %s, selector=%s, domain=%s, raw key used',
+ hdom, selector, res.domain)
+ end
+
+ return true, { res }
+ else
+ logger.errx(task, 'invalid key entry for sign entry %s: %s; when signing %s domain',
+ sign_entry, key_entry, hdom)
+ return false, {}
+ end
+ elseif settings.use_vault then
+ -- Sign table is presented, the rest is covered by vault
+ lua_util.debugm(N, task, 'check vault for %s, by sign entry %s, key entry is missing',
+ hdom, sign_entry)
+ return true, {
+ domain = sign_entry,
+ vault = true
+ }
+ else
+ logger.errx(task, 'missing key entry for sign entry %s; when signing %s domain',
+ sign_entry, hdom)
+ return false, {}
+ end
+ else
+ logger.errx(task, 'cannot get key entry for signing entry %s, when signing %s domain',
+ sign_entry, hdom)
+ return false, {}
+ end
+ else
+ lua_util.debugm(N, task,
+ 'signing_table: no entry for %s', hfrom[1].addr)
+ return false, {}
+ end
+ else
+ if settings.use_domain_sign_networks and is_sign_networks then
+ dkim_domain = get_dkim_domain('use_domain_sign_networks')
+ lua_util.debugm(N, task,
+ 'sign_networks: use domain(%s) for signature: %s',
+ settings.use_domain_sign_networks, dkim_domain)
+ elseif settings.use_domain_sign_local and is_local then
+ dkim_domain = get_dkim_domain('use_domain_sign_local')
+ lua_util.debugm(N, task, 'local: use domain(%s) for signature: %s',
+ settings.use_domain_sign_local, dkim_domain)
+ elseif settings.use_domain_sign_inbound and not is_local and not auser then
+ dkim_domain = get_dkim_domain('use_domain_sign_inbound')
+ lua_util.debugm(N, task, 'inbound: use domain(%s) for signature: %s',
+ settings.use_domain_sign_inbound, dkim_domain)
+ elseif settings.use_domain_custom then
+ if type(settings.use_domain_custom) == 'string' then
+ -- Load custom function
+ local loadstring = loadstring or load
+ local ret, res_or_err = pcall(loadstring(settings.use_domain_custom))
+ if ret then
+ if type(res_or_err) == 'function' then
+ settings.use_domain_custom = res_or_err
+ dkim_domain = settings.use_domain_custom(task)
+ lua_util.debugm(N, task, 'use custom domain for signing: %s',
+ dkim_domain)
+ else
+ logger.errx(task, 'cannot load dkim domain custom script: invalid type: %s, expected function',
+ type(res_or_err))
+ settings.use_domain_custom = nil
+ end
+ else
+ logger.errx(task, 'cannot load dkim domain custom script: %s', res_or_err)
+ settings.use_domain_custom = nil
+ end
+ else
+ dkim_domain = settings.use_domain_custom(task)
+ lua_util.debugm(N, task, 'use custom domain for signing: %s',
+ dkim_domain)
+ end
+ else
+ dkim_domain = get_dkim_domain('use_domain')
+ lua_util.debugm(N, task, 'use domain(%s) for signature: %s',
+ settings.use_domain, dkim_domain)
+ end
+ end
+
+ if not dkim_domain then
+ lua_util.debugm(N, task, 'could not extract dkim domain')
+ return false, {}
+ end
+
+ if settings.use_esld then
+ dkim_domain = rspamd_util.get_tld(dkim_domain)
+ if hdom then
+ hdom = rspamd_util.get_tld(hdom)
+ end
+ if edom then
+ edom = rspamd_util.get_tld(edom)
+ end
+ end
+
+ lua_util.debugm(N, task, 'final DKIM domain: %s', dkim_domain)
+
+ -- Sanity checks
+ if edom and hdom and not settings.allow_hdrfrom_mismatch and hdom ~= edom then
+ if settings.allow_hdrfrom_mismatch_local and is_local then
+ lua_util.debugm(N, task, 'domain mismatch allowed for local IP: %1 != %2', hdom, edom)
+ elseif settings.allow_hdrfrom_mismatch_sign_networks and is_sign_networks then
+ lua_util.debugm(N, task, 'domain mismatch allowed for sign_networks: %1 != %2', hdom, edom)
+ else
+ if empty_envelope and hdom then
+ lua_util.debugm(N, task, 'domain mismatch allowed for empty envelope: %1 != %2', hdom, edom)
+ else
+ lua_util.debugm(N, task, 'domain mismatch not allowed: %1 != %2', hdom, edom)
+ return false, {}
+ end
+ end
+ end
+
+ if auser and not settings.allow_username_mismatch then
+ if not udom then
+ lua_util.debugm(N, task, 'couldnt find domain in username')
+ return false, {}
+ end
+ if settings.use_esld then
+ udom = rspamd_util.get_tld(udom)
+ end
+ if udom ~= dkim_domain then
+ lua_util.debugm(N, task, 'user domain mismatch')
+ return false, {}
+ end
+ end
+
+ local p = {}
+
+ if settings.use_vault then
+ if settings.vault_domains then
+ if settings.vault_domains:get_key(dkim_domain) then
+ return true, {
+ domain = dkim_domain,
+ vault = true,
+ }
+ else
+ lua_util.debugm(N, task, 'domain %s is not designated for vault',
+ dkim_domain)
+ return false, {}
+ end
+ else
+ -- TODO: try every domain in the vault
+ return true, {
+ domain = dkim_domain,
+ vault = true,
+ }
+ end
+ end
+
+ if settings.domain[dkim_domain] then
+ -- support old style selector/paths
+ if settings.domain[dkim_domain].selector or
+ settings.domain[dkim_domain].path then
+ local k = {}
+ k.selector = settings.domain[dkim_domain].selector
+ k.key = settings.domain[dkim_domain].path
+ table.insert(p, k)
+ end
+ for _, s in ipairs((settings.domain[dkim_domain].selectors or {})) do
+ lua_util.debugm(N, task, 'adding selector: %1', s)
+ local k = {}
+ k.selector = s.selector
+ k.key = s.path
+ table.insert(p, k)
+ end
+ end
+
+ if #p == 0 then
+ local ret, k = get_mempool_selectors(N, task)
+ if ret then
+ table.insert(p, k)
+ lua_util.debugm(N, task, 'using mempool selector %s with key %s',
+ k.selector, k.key)
+ end
+ end
+
+ if settings.selector_map then
+ local data = settings.selector_map:get_key(dkim_domain)
+ if data then
+ insert_or_update_prop(N, task, p, 'selector', 'selector_map', data)
+ else
+ lua_util.debugm(N, task, 'no selector in map for %s', dkim_domain)
+ end
+ end
+
+ if settings.path_map then
+ local data = settings.path_map:get_key(dkim_domain)
+ if data then
+ insert_or_update_prop(N, task, p, 'key', 'path_map', data)
+ else
+ lua_util.debugm(N, task, 'no key in map for %s', dkim_domain)
+ end
+ end
+
+ if #p == 0 and not settings.try_fallback then
+ lua_util.debugm(N, task, 'dkim unconfigured and fallback disabled')
+ return false, {}
+ end
+
+ if not settings.use_redis then
+ insert_or_update_prop(N, task, p, 'key',
+ 'default path', settings.path)
+ end
+
+ insert_or_update_prop(N, task, p, 'selector',
+ 'default selector', settings.selector)
+
+ if settings.check_violation then
+ if not check_violation(N, task, p.domain) then
+ return false, {}
+ end
+ end
+
+ insert_or_update_prop(N, task, p, 'domain', 'dkim_domain',
+ dkim_domain)
+
+ return true, p
+end
+
+exports.prepare_dkim_signing = prepare_dkim_signing
+
+exports.sign_using_redis = function(N, task, settings, selectors, sign_func, err_func)
+ local lua_redis = require "lua_redis"
+
+ local function try_redis_key(selector, p)
+ p.key = nil
+ p.selector = selector
+ local rk = string.format('%s.%s', p.selector, p.domain)
+ local function redis_key_cb(err, data)
+ if err then
+ err_func(string.format("cannot make request to load DKIM key for %s: %s",
+ rk, err))
+ elseif type(data) ~= 'string' then
+ lua_util.debugm(N, task, "missing DKIM key for %s", rk)
+ else
+ p.rawkey = data
+ lua_util.debugm(N, task, 'found and parsed key for %s:%s in Redis',
+ p.domain, p.selector)
+ sign_func(task, p)
+ end
+ end
+ local rret = lua_redis.redis_make_request(task,
+ settings.redis_params, -- connect params
+ rk, -- hash key
+ false, -- is write
+ redis_key_cb, --callback
+ 'HGET', -- command
+ { settings.key_prefix, rk } -- arguments
+ )
+ if not rret then
+ err_func(task,
+ string.format("cannot make request to load DKIM key for %s", rk))
+ end
+ end
+
+ for _, p in ipairs(selectors) do
+ if settings.selector_prefix then
+ logger.infox(task, "using selector prefix '%s' for domain '%s'",
+ settings.selector_prefix, p.domain);
+ local function redis_selector_cb(err, data)
+ if err or type(data) ~= 'string' then
+ err_func(task, string.format("cannot make request to load DKIM selector for domain %s: %s",
+ p.domain, err))
+ else
+ try_redis_key(data, p)
+ end
+ end
+ local rret = lua_redis.redis_make_request(task,
+ settings.redis_params, -- connect params
+ p.domain, -- hash key
+ false, -- is write
+ redis_selector_cb, --callback
+ 'HGET', -- command
+ { settings.selector_prefix, p.domain } -- arguments
+ )
+ if not rret then
+ err_func(task, string.format("cannot make Redis request to load DKIM selector for domain %s",
+ p.domain))
+ end
+ else
+ try_redis_key(p.selector, p)
+ end
+ end
+end
+
+exports.sign_using_vault = function(N, task, settings, selectors, sign_func, err_func)
+ local http = require "rspamd_http"
+ local ucl = require "ucl"
+
+ local full_url = string.format('%s/v1/%s/%s',
+ settings.vault_url, settings.vault_path or 'dkim', selectors.domain)
+ local upstream_list = lua_util.http_upstreams_by_url(rspamd_config:get_mempool(), settings.vault_url)
+
+ local function vault_callback(err, code, body, _)
+ if code ~= 200 then
+ err_func(task, string.format('cannot request data from the vault url: %s; %s (%s)',
+ full_url, err, body))
+ else
+ local parser = ucl.parser()
+ local res, parser_err = parser:parse_string(body)
+ if not res then
+ err_func(task, string.format('vault reply for %s (data=%s) cannot be parsed: %s',
+ full_url, body, parser_err))
+ else
+ local obj = parser:get_object()
+
+ if not obj or not obj.data then
+ err_func(task, string.format('vault reply for %s (data=%s) is invalid, no data',
+ full_url, body))
+ else
+ local elts = obj.data.selectors or {}
+
+ -- Filter selectors by time/sanity
+ local function is_selector_valid(p)
+ if not p.key or not p.selector then
+ return false
+ end
+
+ if p.valid_start then
+ -- Check start time
+ if rspamd_util.get_time() < tonumber(p.valid_start) then
+ return false
+ end
+ end
+
+ if p.valid_end then
+ if rspamd_util.get_time() >= tonumber(p.valid_end) then
+ return false
+ end
+ end
+
+ return true
+ end
+ fun.each(function(p)
+ local dkim_sign_data = {
+ rawkey = p.key,
+ selector = p.selector,
+ domain = p.domain or selectors.domain,
+ alg = p.alg,
+ }
+ lua_util.debugm(N, task, 'found and parsed key for %s:%s in Vault',
+ dkim_sign_data.domain, dkim_sign_data.selector)
+ sign_func(task, dkim_sign_data)
+ end, fun.filter(is_selector_valid, elts))
+ end
+ end
+ end
+ end
+
+ local ret = http.request {
+ task = task,
+ url = full_url,
+ callback = vault_callback,
+ timeout = settings.http_timeout or 5.0,
+ no_ssl_verify = settings.no_ssl_verify,
+ keepalive = true,
+ upstream = upstream_list and upstream_list:get_upstream_round_robin() or nil,
+ headers = {
+ ['X-Vault-Token'] = settings.vault_token,
+ },
+ }
+
+ if not ret then
+ err_func(task, string.format("cannot make HTTP request to load DKIM data domain %s",
+ selectors.domain))
+ end
+end
+
+exports.validate_signing_settings = function(settings)
+ return settings.use_redis or
+ settings.path or
+ settings.domain or
+ settings.path_map or
+ settings.selector_map or
+ settings.use_http_headers or
+ (settings.signing_table and settings.key_table) or
+ (settings.use_vault and settings.vault_url and settings.vault_token) or
+ settings.sign_condition
+end
+
+exports.process_signing_settings = function(N, settings, opts)
+ local lua_maps = require "lua_maps"
+ -- Used to convert plain options to the maps
+ local maps_opts = {
+ sign_networks = { 'radix', 'DKIM signing networks' },
+ path_map = { 'map', 'Paths to DKIM signing keys' },
+ selector_map = { 'map', 'DKIM selectors' },
+ signing_table = { 'glob', 'DKIM signing table' },
+ key_table = { 'glob', 'DKIM keys table' },
+ vault_domains = { 'glob', 'DKIM signing domains in vault' },
+ whitelisted_signers_map = { 'set', 'ARC trusted signers domains' }
+ }
+ for k, v in pairs(opts) do
+ local maybe_map = maps_opts[k]
+ if maybe_map then
+ settings[k] = lua_maps.map_add_from_ucl(v, maybe_map[1], maybe_map[2])
+ elseif k == 'sign_condition' then
+ local ret, f = lua_util.callback_from_string(v)
+ if ret then
+ settings[k] = f
+ else
+ logger.errx(rspamd_config, 'cannot load sign condition %s: %s', v, f)
+ end
+ else
+ settings[k] = v
+ end
+ end
+end
+
+return exports
diff --git a/lualib/lua_ffi/common.lua b/lualib/lua_ffi/common.lua
new file mode 100644
index 0000000..4076cfa
--- /dev/null
+++ b/lualib/lua_ffi/common.lua
@@ -0,0 +1,45 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_ffi/common
+-- Common ffi definitions
+--]]
+
+local ffi = require 'ffi'
+
+ffi.cdef [[
+struct GString {
+ char *str;
+ size_t len;
+ size_t allocated_len;
+};
+struct GArray {
+ char *data;
+ unsigned len;
+};
+typedef void (*ref_dtor_cb_t)(void *data);
+struct ref_entry_s {
+ unsigned int refcount;
+ ref_dtor_cb_t dtor;
+};
+
+void g_string_free (struct GString *st, int free_data);
+void g_free (void *p);
+long rspamd_snprintf (char *buf, long max, const char *fmt, ...);
+]]
+
+return {} \ No newline at end of file
diff --git a/lualib/lua_ffi/dkim.lua b/lualib/lua_ffi/dkim.lua
new file mode 100644
index 0000000..e4592c2
--- /dev/null
+++ b/lualib/lua_ffi/dkim.lua
@@ -0,0 +1,144 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_ffi/dkim
+-- This module contains ffi interfaces to DKIM
+--]]
+
+local ffi = require 'ffi'
+
+ffi.cdef [[
+struct rspamd_dkim_sign_context_s;
+struct rspamd_dkim_key_s;
+struct rspamd_task;
+enum rspamd_dkim_key_format {
+ RSPAMD_DKIM_KEY_FILE = 0,
+ RSPAMD_DKIM_KEY_PEM,
+ RSPAMD_DKIM_KEY_BASE64,
+ RSPAMD_DKIM_KEY_RAW,
+};
+enum rspamd_dkim_type {
+ RSPAMD_DKIM_NORMAL,
+ RSPAMD_DKIM_ARC_SIG,
+ RSPAMD_DKIM_ARC_SEAL
+};
+struct rspamd_dkim_sign_context_s*
+rspamd_create_dkim_sign_context (struct rspamd_task *task,
+ struct rspamd_dkim_key_s *priv_key,
+ int headers_canon,
+ int body_canon,
+ const char *dkim_headers,
+ enum rspamd_dkim_type type,
+ void *unused);
+struct rspamd_dkim_key_s* rspamd_dkim_sign_key_load (const char *what, size_t len,
+ enum rspamd_dkim_key_format,
+ void *err);
+void rspamd_dkim_key_unref (struct rspamd_dkim_key_s *k);
+struct GString *rspamd_dkim_sign (struct rspamd_task *task,
+ const char *selector,
+ const char *domain,
+ unsigned long expire,
+ size_t len,
+ unsigned int idx,
+ const char *arc_cv,
+ struct rspamd_dkim_sign_context_s *ctx);
+]]
+
+local function load_sign_key(what, format)
+ if not format then
+ format = ffi.C.RSPAMD_DKIM_KEY_PEM
+ else
+ if format == 'file' then
+ format = ffi.C.RSPAMD_DKIM_KEY_FILE
+ elseif format == 'base64' then
+ format = ffi.C.RSPAMD_DKIM_KEY_BASE64
+ elseif format == 'raw' then
+ format = ffi.C.RSPAMD_DKIM_KEY_RAW
+ else
+ return nil, 'unknown key format'
+ end
+ end
+
+ return ffi.C.rspamd_dkim_sign_key_load(what, #what, format, nil)
+end
+
+local default_dkim_headers = "(o)from:(o)sender:(o)reply-to:(o)subject:(o)date:(o)message-id:" ..
+ "(o)to:(o)cc:(o)mime-version:(o)content-type:(o)content-transfer-encoding:" ..
+ "resent-to:resent-cc:resent-from:resent-sender:resent-message-id:" ..
+ "(o)in-reply-to:(o)references:list-id:list-owner:list-unsubscribe:" ..
+ "list-subscribe:list-post:(o)openpgp:(o)autocrypt"
+
+local function create_sign_context(task, privkey, dkim_headers, sign_type)
+ if not task or not privkey then
+ return nil, 'invalid arguments'
+ end
+
+ if not dkim_headers then
+ dkim_headers = default_dkim_headers
+ end
+
+ if not sign_type then
+ sign_type = 'dkim'
+ end
+
+ if sign_type == 'dkim' then
+ sign_type = ffi.C.RSPAMD_DKIM_NORMAL
+ elseif sign_type == 'arc-sig' then
+ sign_type = ffi.C.RSPAMD_DKIM_ARC_SIG
+ elseif sign_type == 'arc-seal' then
+ sign_type = ffi.C.RSPAMD_DKIM_ARC_SEAL
+ else
+ return nil, 'invalid sign type'
+ end
+
+ return ffi.C.rspamd_create_dkim_sign_context(task:topointer(), privkey,
+ 1, 1, dkim_headers, sign_type, nil)
+end
+
+local function do_sign(task, sign_context, selector, domain,
+ expire, len, arc_idx)
+ if not task or not sign_context or not selector or not domain then
+ return nil, 'invalid arguments'
+ end
+
+ if not expire then
+ expire = 0
+ end
+ if not len then
+ len = 0
+ end
+ if not arc_idx then
+ arc_idx = 0
+ end
+
+ local gstring = ffi.C.rspamd_dkim_sign(task:topointer(), selector, domain, expire, len, arc_idx, nil, sign_context)
+
+ if not gstring then
+ return nil, 'cannot sign'
+ end
+
+ local ret = ffi.string(gstring.str, gstring.len)
+ ffi.C.g_string_free(gstring, true)
+
+ return ret
+end
+
+return {
+ load_sign_key = load_sign_key,
+ create_sign_context = create_sign_context,
+ do_sign = do_sign
+}
diff --git a/lualib/lua_ffi/init.lua b/lualib/lua_ffi/init.lua
new file mode 100644
index 0000000..efbbc7a
--- /dev/null
+++ b/lualib/lua_ffi/init.lua
@@ -0,0 +1,59 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_ffi
+-- This module contains ffi interfaces (requires luajit or lua-ffi)
+--]]
+
+local ffi
+
+local exports = {}
+
+if type(jit) == 'table' then
+ ffi = require "ffi"
+ local NULL = ffi.new 'void*'
+
+ exports.is_null = function(o)
+ return o ~= NULL
+ end
+else
+ local ret, result_or_err = pcall(require, 'ffi')
+
+ if not ret then
+ return {}
+ end
+
+ ffi = result_or_err
+ -- Lua ffi
+ local NULL = ffi.NULL or ffi.C.NULL
+ exports.is_null = function(o)
+ return o ~= NULL
+ end
+end
+
+pcall(ffi.load, "rspamd-server", true)
+exports.common = require "lua_ffi/common"
+exports.dkim = require "lua_ffi/dkim"
+exports.spf = require "lua_ffi/spf"
+exports.linalg = require "lua_ffi/linalg"
+
+for k, v in pairs(ffi) do
+ -- Preserve all stuff to use lua_ffi as ffi itself
+ exports[k] = v
+end
+
+return exports \ No newline at end of file
diff --git a/lualib/lua_ffi/linalg.lua b/lualib/lua_ffi/linalg.lua
new file mode 100644
index 0000000..2df488a
--- /dev/null
+++ b/lualib/lua_ffi/linalg.lua
@@ -0,0 +1,87 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_ffi/linalg
+-- This module contains ffi interfaces to linear algebra routines
+--]]
+
+local ffi = require 'ffi'
+
+local exports = {}
+
+ffi.cdef [[
+ void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C);
+ bool kad_ssyev_simple (int N, float *A, float *output);
+]]
+
+local function table_to_ffi(a, m, n)
+ local a_conv = ffi.new("float[?]", m * n)
+ for i = 1, m or #a do
+ for j = 1, n or #a[1] do
+ a_conv[(i - 1) * n + (j - 1)] = a[i][j]
+ end
+ end
+ return a_conv
+end
+
+local function ffi_to_table(a, m, n)
+ local res = {}
+
+ for i = 0, m - 1 do
+ res[i + 1] = {}
+ for j = 0, n - 1 do
+ res[i + 1][j + 1] = a[i * n + j]
+ end
+ end
+
+ return res
+end
+
+exports.sgemm = function(a, m, b, n, k, trans_a, trans_b)
+ if type(a) == 'table' then
+ -- Need to convert, slow!
+ a = table_to_ffi(a, m, k)
+ end
+ if type(b) == 'table' then
+ b = table_to_ffi(b, k, n)
+ end
+ local res = ffi.new("float[?]", m * n)
+ ffi.C.kad_sgemm_simple(trans_a or 0, trans_b or 0, m, n, k, ffi.cast('const float*', a),
+ ffi.cast('const float*', b), ffi.cast('float*', res))
+ return res
+end
+
+exports.eigen = function(a, n)
+ if type(a) == 'table' then
+ -- Need to convert, slow!
+ n = n or #a
+ a = table_to_ffi(a, n, n)
+ end
+
+ local res = ffi.new("float[?]", n)
+
+ if ffi.C.kad_ssyev_simple(n, ffi.cast('float*', a), res) then
+ return res, a
+ end
+
+ return nil
+end
+
+exports.ffi_to_table = ffi_to_table
+exports.table_to_ffi = table_to_ffi
+
+return exports \ No newline at end of file
diff --git a/lualib/lua_ffi/spf.lua b/lualib/lua_ffi/spf.lua
new file mode 100644
index 0000000..0f982f2
--- /dev/null
+++ b/lualib/lua_ffi/spf.lua
@@ -0,0 +1,143 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_ffi/spf
+-- This module contains ffi interfaces to SPF
+--]]
+
+local ffi = require 'ffi'
+
+ffi.cdef [[
+enum spf_mech_e {
+ SPF_FAIL,
+ SPF_SOFT_FAIL,
+ SPF_PASS,
+ SPF_NEUTRAL
+};
+static const unsigned RSPAMD_SPF_FLAG_IPV6 = (1 << 0);
+static const unsigned RSPAMD_SPF_FLAG_IPV4 = (1 << 1);
+static const unsigned RSPAMD_SPF_FLAG_ANY = (1 << 3);
+struct spf_addr {
+ unsigned char addr6[16];
+ unsigned char addr4[4];
+ union {
+ struct {
+ uint16_t mask_v4;
+ uint16_t mask_v6;
+ } dual;
+ uint32_t idx;
+ } m;
+ unsigned flags;
+ enum spf_mech_e mech;
+ char *spf_string;
+ struct spf_addr *prev, *next;
+};
+
+struct spf_resolved {
+ char *domain;
+ unsigned ttl;
+ int temp_failed;
+ int na;
+ int perm_failed;
+ uint64_t digest;
+ struct GArray *elts;
+ struct ref_entry_s ref;
+};
+
+typedef void (*spf_cb_t)(struct spf_resolved *record,
+ struct rspamd_task *task, void *data);
+struct rspamd_task;
+int rspamd_spf_resolve(struct rspamd_task *task, spf_cb_t callback,
+ void *cbdata);
+const char * rspamd_spf_get_domain (struct rspamd_task *task);
+struct spf_resolved * spf_record_ref (struct spf_resolved *rec);
+void spf_record_unref (struct spf_resolved *rec);
+char * spf_addr_mask_to_string (struct spf_addr *addr);
+struct spf_addr * spf_addr_match_task (struct rspamd_task *task, struct spf_resolved *rec);
+]]
+
+local function convert_mech(mech)
+ if mech == ffi.C.SPF_FAIL then
+ return 'fail'
+ elseif mech == ffi.C.SPF_SOFT_FAIL then
+ return 'softfail'
+ elseif mech == ffi.C.SPF_PASS then
+ return 'pass'
+ elseif mech == ffi.C.SPF_NEUTRAL then
+ return 'neutral'
+ end
+end
+
+local NULL = ffi.new 'void*'
+
+local function spf_addr_tolua(ffi_spf_addr)
+ local ipstr = ffi.C.spf_addr_mask_to_string(ffi_spf_addr)
+ local ret = {
+ res = convert_mech(ffi_spf_addr.mech),
+ ipnet = ffi.string(ipstr),
+ }
+
+ if ffi_spf_addr.spf_string ~= NULL then
+ ret.spf_str = ffi.string(ffi_spf_addr.spf_string)
+ end
+
+ ffi.C.g_free(ipstr)
+ return ret
+end
+
+local function spf_resolve(task, cb)
+ local function spf_cb(rec, _, _)
+ if not rec then
+ cb(false, 'record is empty')
+ else
+ local nelts = rec.elts.len
+ local elts = ffi.cast("struct spf_addr *", rec.elts.data)
+ local res = {
+ addrs = {}
+ }
+ local digstr = ffi.new("char[64]")
+ ffi.C.rspamd_snprintf(digstr, 64, "0x%xuL", rec.digest)
+ res.digest = ffi.string(digstr)
+ for i = 1, nelts do
+ res.addrs[i] = spf_addr_tolua(elts[i - 1])
+ end
+
+ local matched = ffi.C.spf_addr_match_task(task:topointer(), rec)
+
+ if matched ~= NULL then
+ cb(true, res, spf_addr_tolua(matched))
+ else
+ cb(true, res, nil)
+ end
+ end
+ end
+
+ local ret = ffi.C.rspamd_spf_resolve(task:topointer(), spf_cb, nil)
+
+ if not ret then
+ cb(false, 'cannot perform resolving')
+ end
+end
+
+local function spf_unref(rec)
+ ffi.C.spf_record_unref(rec)
+end
+
+return {
+ spf_resolve = spf_resolve,
+ spf_unref = spf_unref
+} \ No newline at end of file
diff --git a/lualib/lua_fuzzy.lua b/lualib/lua_fuzzy.lua
new file mode 100644
index 0000000..986d1a0
--- /dev/null
+++ b/lualib/lua_fuzzy.lua
@@ -0,0 +1,355 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_fuzzy
+-- This module contains helper functions for supporting fuzzy check module
+--]]
+
+
+local N = "lua_fuzzy"
+local lua_util = require "lua_util"
+local rspamd_regexp = require "rspamd_regexp"
+local fun = require "fun"
+local rspamd_logger = require "rspamd_logger"
+local ts = require("tableshape").types
+
+-- Filled by C code, indexed by number in this table
+local rules = {}
+
+-- Pre-defined rules options
+local policies = {
+ recommended = {
+ min_bytes = 1024,
+ min_height = 500,
+ min_width = 500,
+ min_length = 64,
+ text_multiplier = 4.0, -- divide min_bytes by 4 for texts
+ mime_types = { "application/*" },
+ scan_archives = true,
+ short_text_direct_hash = true,
+ text_shingles = true,
+ skip_images = false,
+ }
+}
+
+local default_policy = policies.recommended
+
+local schema_fields = {
+ min_bytes = ts.number + ts.string / tonumber,
+ min_height = ts.number + ts.string / tonumber,
+ min_width = ts.number + ts.string / tonumber,
+ min_length = ts.number + ts.string / tonumber,
+ text_multiplier = ts.number,
+ mime_types = ts.array_of(ts.string),
+ scan_archives = ts.boolean,
+ short_text_direct_hash = ts.boolean,
+ text_shingles = ts.boolean,
+ skip_images = ts.boolean,
+}
+local policy_schema = ts.shape(schema_fields)
+
+local policy_schema_open = ts.shape(schema_fields, {
+ open = true,
+})
+
+local exports = {}
+
+
+--[[[
+-- @function lua_fuzzy.register_policy(name, policy)
+-- Adds a new policy with name `name`. Must be valid, checked using policy_schema
+--]]
+exports.register_policy = function(name, policy)
+ if policies[name] then
+ rspamd_logger.warnx(rspamd_config, "overriding policy %s", name)
+ end
+
+ local parsed_policy, err = policy_schema:transform(policy)
+
+ if not parsed_policy then
+ rspamd_logger.errx(rspamd_config, 'invalid fuzzy rule policy %s: %s',
+ name, err)
+
+ return
+ else
+ policies.name = parsed_policy
+ end
+end
+
+--[[[
+-- @function lua_fuzzy.process_rule(rule)
+-- Processes fuzzy rule (applying policies or defaults if needed). Returns policy id
+--]]
+exports.process_rule = function(rule)
+ local processed_rule = lua_util.shallowcopy(rule)
+ local policy = default_policy
+
+ if processed_rule.policy then
+ policy = policies[processed_rule.policy]
+ end
+
+ if policy then
+ processed_rule = lua_util.override_defaults(policy, processed_rule)
+
+ local parsed_policy, err = policy_schema_open:transform(processed_rule)
+
+ if not parsed_policy then
+ rspamd_logger.errx(rspamd_config, 'invalid fuzzy rule default fields: %s', err)
+ else
+ processed_rule = parsed_policy
+ end
+ else
+ rspamd_logger.warnx(rspamd_config, "unknown policy %s", processed_rule.policy)
+ end
+
+ if processed_rule.mime_types then
+ processed_rule.mime_types = fun.totable(fun.map(function(gl)
+ return rspamd_regexp.import_glob(gl, 'i')
+ end, processed_rule.mime_types))
+ end
+
+ table.insert(rules, processed_rule)
+ return #rules
+end
+
+local function check_length(task, part, rule)
+ local bytes = part:get_length()
+ local length_ok = bytes > 0
+
+ local id = part:get_id()
+ lua_util.debugm(N, task, 'check size of part %s', id)
+
+ if length_ok and rule.min_bytes > 0 then
+
+ local adjusted_bytes = bytes
+
+ if part:is_text() then
+ -- Fuzzy plugin uses stripped utf content to get an exact hash, that
+ -- corresponds to `get_content_oneline()`
+ -- However, in the case of empty parts this method returns `nil`, so extra
+ -- sanity check is required.
+ bytes = #(part:get_text():get_content_oneline() or '')
+
+ -- Short hashing algorithm also use subject unless explicitly denied
+ if not rule.no_subject then
+ local subject = task:get_subject() or ''
+ bytes = bytes + #subject
+ end
+
+ if rule.text_multiplier then
+ adjusted_bytes = bytes * rule.text_multiplier
+ end
+ end
+
+ if rule.min_bytes > adjusted_bytes then
+ lua_util.debugm(N, task, 'skip part of length %s (%s adjusted) ' ..
+ 'as it has less than %s bytes',
+ bytes, adjusted_bytes, rule.min_bytes)
+ length_ok = false
+ else
+ lua_util.debugm(N, task, 'allow part of length %s (%s adjusted)',
+ bytes, adjusted_bytes, rule.min_bytes)
+ end
+ else
+ lua_util.debugm(N, task, 'allow part %s, no length limits', id)
+ end
+
+ return length_ok
+end
+
+local function check_text_part(task, part, rule, text)
+ local allow_direct, allow_shingles = false, false
+
+ local id = part:get_id()
+ lua_util.debugm(N, task, 'check text part %s', id)
+ local wcnt = text:get_words_count()
+
+ if rule.text_shingles then
+ -- Check number of words
+ local min_words = rule.min_length or 0
+ if min_words < 32 then
+ min_words = 32 -- Minimum for shingles
+ end
+ if wcnt < min_words then
+ lua_util.debugm(N, task, 'text has less than %s words: %s; disable shingles',
+ rule.min_length, wcnt)
+ allow_shingles = false
+ else
+ lua_util.debugm(N, task, 'allow shingles in text %s, %s words',
+ id, wcnt)
+ allow_shingles = true
+ end
+
+ if not rule.short_text_direct_hash and not allow_shingles then
+ allow_direct = false
+ else
+ if not allow_shingles then
+ lua_util.debugm(N, task,
+ 'allow direct hash for short text %s, %s words',
+ id, wcnt)
+ allow_direct = check_length(task, part, rule)
+ else
+ allow_direct = wcnt > 0
+ end
+ end
+ else
+ lua_util.debugm(N, task,
+ 'disable shingles in text %s', id)
+ allow_direct = check_length(task, part, rule)
+ end
+
+ return allow_direct, allow_shingles
+end
+
+--local function has_sane_text_parts(task)
+-- local text_parts = task:get_text_parts() or {}
+-- return fun.any(function(tp) return tp:get_words_count() > 32 end, text_parts)
+--end
+
+local function check_image_part(task, part, rule, image)
+ if rule.skip_images then
+ lua_util.debugm(N, task, 'skip image part as images are disabled')
+ return false, false
+ end
+
+ local id = part:get_id()
+ lua_util.debugm(N, task, 'check image part %s', id)
+
+ if rule.min_width > 0 or rule.min_height > 0 then
+ -- Check dimensions
+ local min_width = rule.min_width or rule.min_height
+ local min_height = rule.min_height or rule.min_width
+ local height = image:get_height()
+ local width = image:get_width()
+
+ if height and width then
+ if height < min_height or width < min_width then
+ lua_util.debugm(N, task, 'skip image part %s as it does not meet minimum sizes: %sx%s < %sx%s',
+ id, width, height, min_width, min_height)
+ return false, false
+ else
+ lua_util.debugm(N, task, 'allow image part %s: %sx%s',
+ id, width, height)
+ end
+ end
+ end
+
+ return check_length(task, part, rule), false
+end
+
+local function mime_types_check(task, part, rule)
+ local t, st = part:get_type()
+
+ if not t then
+ return false, false
+ end
+
+ local ct = string.format('%s/%s', t, st)
+
+ local detected_ct
+ t, st = part:get_detected_type()
+ if t then
+ detected_ct = string.format('%s/%s', t, st)
+ else
+ detected_ct = ct
+ end
+
+ local id = part:get_id()
+ lua_util.debugm(N, task, 'check binary part %s: %s', id, ct)
+
+ -- For bad mime parts we implicitly enable fuzzy check
+ local mime_trace = (task:get_symbol('MIME_TRACE') or {})[1]
+ local opts = {}
+
+ if mime_trace then
+ opts = mime_trace.options or opts
+ end
+ opts = fun.tomap(fun.map(function(opt)
+ local elts = lua_util.str_split(opt, ':')
+ return elts[1], elts[2]
+ end, opts))
+
+ if opts[id] and opts[id] == '-' then
+ lua_util.debugm(N, task, 'explicitly check binary part %s: bad mime type %s', id, ct)
+ return check_length(task, part, rule), false
+ end
+
+ if rule.mime_types then
+
+ if fun.any(function(gl_re)
+ if gl_re:match(ct) or (detected_ct and gl_re:match(detected_ct)) then
+ return true
+ else
+ return false
+ end
+ end, rule.mime_types) then
+ lua_util.debugm(N, task, 'found mime type match for part %s: %s (%s detected)',
+ id, ct, detected_ct)
+ return check_length(task, part, rule), false
+ end
+
+ return false, false
+ end
+
+ return false, false
+end
+
+exports.check_mime_part = function(task, part, rule_id)
+ local rule = rules[rule_id]
+
+ if not rule then
+ rspamd_logger.errx(task, 'cannot find rule with id %s', rule_id)
+
+ return false, false
+ end
+
+ if part:is_text() then
+ return check_text_part(task, part, rule, part:get_text())
+ end
+
+ if part:is_image() then
+ return check_image_part(task, part, rule, part:get_image())
+ end
+
+ if part:is_archive() and rule.scan_archives then
+ -- Always send archives
+ lua_util.debugm(N, task, 'check archive part %s', part:get_id())
+
+ return true, false
+ end
+
+ if part:is_specific() then
+ local sp = part:get_specific()
+
+ if type(sp) == 'table' and sp.fuzzy_hashes then
+ lua_util.debugm(N, task, 'check specific part %s', part:get_id())
+ return true, false
+ end
+ end
+
+ if part:is_attachment() then
+ return mime_types_check(task, part, rule)
+ end
+
+ return false, false
+end
+
+exports.cleanup_rules = function()
+ rules = {}
+end
+
+return exports
diff --git a/lualib/lua_lexer.lua b/lualib/lua_lexer.lua
new file mode 100644
index 0000000..54bbd7c
--- /dev/null
+++ b/lualib/lua_lexer.lua
@@ -0,0 +1,163 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[ Lua LPEG grammar based on https://github.com/xolox/lua-lxsh/ ]]
+
+
+local lpeg = require "lpeg"
+
+local P = lpeg.P
+local R = lpeg.R
+local S = lpeg.S
+local D = R '09' -- Digits
+local I = R('AZ', 'az', '\127\255') + '_' -- Identifiers
+local B = -(I + D) -- Word boundary
+local EOS = -lpeg.P(1) -- end of string
+
+-- Pattern for long strings and long comments.
+local longstring = #(P '[[' + (P '[' * P '=' ^ 0 * '[')) * P(function(input, index)
+ local level = input:match('^%[(=*)%[', index)
+ if level then
+ local _, last = input:find(']' .. level .. ']', index, true)
+ if last then
+ return last + 1
+ end
+ end
+end)
+
+-- String literals.
+local singlequoted = P "'" * ((1 - S "'\r\n\f\\") + (P '\\' * 1)) ^ 0 * "'"
+local doublequoted = P '"' * ((1 - S '"\r\n\f\\') + (P '\\' * 1)) ^ 0 * '"'
+
+-- Comments.
+local eol = P '\r\n' + '\n'
+local line = (1 - S '\r\n\f') ^ 0 * eol ^ -1
+local singleline = P '--' * line
+local multiline = P '--' * longstring
+
+-- Numbers.
+local sign = S '+-' ^ -1
+local decimal = D ^ 1
+local hexadecimal = P '0' * S 'xX' * R('09', 'AF', 'af') ^ 1
+local float = D ^ 1 * P '.' * D ^ 0 + P '.' * D ^ 1
+local maybeexp = (float + decimal) * (S 'eE' * sign * D ^ 1) ^ -1
+
+local function compile_keywords(keywords)
+ local list = {}
+ for word in keywords:gmatch('%S+') do
+ list[#list + 1] = word
+ end
+ -- Sort by length
+ table.sort(list, function(a, b)
+ return #a > #b
+ end)
+
+ local pattern
+ for _, word in ipairs(list) do
+ local p = lpeg.P(word)
+ pattern = pattern and (pattern + p) or p
+ end
+
+ local AB = B + EOS -- ending boundary
+ return pattern * AB
+end
+
+-- Identifiers
+local ident = I * (I + D) ^ 0
+local expr = ('.' * ident) ^ 0
+
+local patterns = {
+ { 'whitespace', S '\r\n\f\t\v ' ^ 1 },
+ { 'constant', (P 'true' + 'false' + 'nil') * B },
+ { 'string', singlequoted + doublequoted + longstring },
+ { 'comment', multiline + singleline },
+ { 'number', hexadecimal + maybeexp },
+ { 'operator', P 'not' + '...' + 'and' + '..' + '~=' + '==' + '>=' + '<='
+ + 'or' + S ']{=>^[<;)*(%}+-:,/.#' },
+ { 'keyword', compile_keywords([[
+ break do else elseif end for function if in local repeat return then until while
+ ]]) },
+ { 'identifier', lpeg.Cmt(ident,
+ function(input, index)
+ return expr:match(input, index)
+ end)
+ },
+ { 'error', 1 },
+}
+
+local compiled
+
+local function compile_patterns()
+ if not compiled then
+ local function process(elt)
+ local n, grammar = elt[1], elt[2]
+ return lpeg.Cc(n) * lpeg.P(grammar) * lpeg.Cp()
+ end
+ local any = process(patterns[1])
+ for i = 2, #patterns do
+ any = any + process(patterns[i])
+ end
+ compiled = any
+ end
+
+ return compiled
+end
+
+local function sync(token, lnum, cnum)
+ local lastidx
+ lnum, cnum = lnum or 1, cnum or 1
+ if token:find '\n' then
+ for i in token:gmatch '()\n' do
+ lnum = lnum + 1
+ lastidx = i
+ end
+ cnum = #token - lastidx + 1
+ else
+ cnum = cnum + #token
+ end
+ return lnum, cnum
+end
+
+local exports = {}
+
+exports.gmatch = function(input)
+ local parser = compile_patterns()
+ local index, lnum, cnum = 1, 1, 1
+
+ return function()
+ local kind, after = parser:match(input, index)
+ if kind and after then
+ local text = input:sub(index, after - 1)
+ local oldlnum, oldcnum = lnum, cnum
+ index = after
+ lnum, cnum = sync(text, lnum, cnum)
+ return kind, text, oldlnum, oldcnum
+ end
+ end
+end
+
+exports.lex_to_table = function(input)
+ local out = {}
+
+ for kind, text, lnum, cnum in exports.gmatch(input) do
+ out[#out + 1] = { kind, text, lnum, cnum }
+ end
+
+ return out
+end
+
+return exports
+
diff --git a/lualib/lua_magic/heuristics.lua b/lualib/lua_magic/heuristics.lua
new file mode 100644
index 0000000..b8a1b41
--- /dev/null
+++ b/lualib/lua_magic/heuristics.lua
@@ -0,0 +1,605 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_magic/heuristics
+-- This module contains heuristics for some specific cases
+--]]
+
+local rspamd_trie = require "rspamd_trie"
+local rspamd_util = require "rspamd_util"
+local lua_util = require "lua_util"
+local bit = require "bit"
+local fun = require "fun"
+
+local N = "lua_magic"
+local msoffice_trie
+local msoffice_patterns = {
+ doc = { [[WordDocument]] },
+ xls = { [[Workbook]], [[Book]] },
+ ppt = { [[PowerPoint Document]], [[Current User]] },
+ vsd = { [[VisioDocument]] },
+}
+local msoffice_trie_clsid
+local msoffice_clsids = {
+ doc = { [[0609020000000000c000000000000046]] },
+ xls = { [[1008020000000000c000000000000046]], [[2008020000000000c000000000000046]] },
+ ppt = { [[108d81649b4fcf1186ea00aa00b929e8]] },
+ msg = { [[46f0060000000000c000000000000046]], [[0b0d020000000000c000000000000046]] },
+ msi = { [[84100c0000000000c000000000000046]] },
+}
+local zip_trie
+local zip_patterns = {
+ -- https://lists.oasis-open.org/archives/office/200505/msg00006.html
+ odt = {
+ [[mimetypeapplication/vnd\.oasis\.opendocument\.text]],
+ [[mimetypeapplication/vnd\.oasis\.opendocument\.image]],
+ [[mimetypeapplication/vnd\.oasis\.opendocument\.graphic]]
+ },
+ ods = {
+ [[mimetypeapplication/vnd\.oasis\.opendocument\.spreadsheet]],
+ [[mimetypeapplication/vnd\.oasis\.opendocument\.formula]],
+ [[mimetypeapplication/vnd\.oasis\.opendocument\.chart]]
+ },
+ odp = { [[mimetypeapplication/vnd\.oasis\.opendocument\.presentation]] },
+ epub = { [[epub\+zip]] },
+ asice = { [[mimetypeapplication/vnd\.etsi\.asic-e\+zipPK]] },
+ asics = { [[mimetypeapplication/vnd\.etsi\.asic-s\+zipPK]] },
+}
+
+local txt_trie
+local txt_patterns = {
+ html = {
+ { [=[(?i)<html[\s>]]=], 32 },
+ { [[(?i)<script\b]], 20 }, -- Commonly used by spammers
+ { [[<script\s+type="text\/javascript">]], 31 }, -- Another spammy pattern
+ { [[(?i)<\!DOCTYPE HTML\b]], 33 },
+ { [[(?i)<body\b]], 20 },
+ { [[(?i)<table\b]], 20 },
+ { [[(?i)<a\s]], 10 },
+ { [[(?i)<p\b]], 10 },
+ { [[(?i)<div\b]], 10 },
+ { [[(?i)<span\b]], 10 },
+ },
+ csv = {
+ { [[(?:[-a-zA-Z0-9_]+\s*,){2,}(?:[-a-zA-Z0-9_]+,?[ ]*[\r\n])]], 20 }
+ },
+ ics = {
+ { [[^BEGIN:VCALENDAR\r?\n]], 40 },
+ },
+ vcf = {
+ { [[^BEGIN:VCARD\r?\n]], 40 },
+ },
+ xml = {
+ { [[<\?xml\b.+\?>]], 31 },
+ }
+}
+
+-- Used to match pattern index and extension
+local msoffice_clsid_indexes = {}
+local msoffice_patterns_indexes = {}
+local zip_patterns_indexes = {}
+local txt_patterns_indexes = {}
+
+local exports = {}
+
+local function compile_tries()
+ local default_compile_flags = bit.bor(rspamd_trie.flags.re,
+ rspamd_trie.flags.dot_all,
+ rspamd_trie.flags.single_match,
+ rspamd_trie.flags.no_start)
+ local function compile_pats(patterns, indexes, transform_func, compile_flags)
+ local strs = {}
+ for ext, pats in pairs(patterns) do
+ for _, pat in ipairs(pats) do
+ -- These are utf16 strings in fact...
+ strs[#strs + 1] = transform_func(pat)
+ indexes[#indexes + 1] = { ext, pat }
+ end
+ end
+
+ return rspamd_trie.create(strs, compile_flags or default_compile_flags)
+ end
+
+ if not msoffice_trie then
+ -- Directory names
+ local function msoffice_pattern_transform(pat)
+ return '^' ..
+ table.concat(
+ fun.totable(
+ fun.map(function(c)
+ return c .. [[\x{00}]]
+ end,
+ fun.iter(pat))))
+ end
+ local function msoffice_clsid_transform(pat)
+ local hex_table = {}
+ for i = 1, #pat, 2 do
+ local subc = pat:sub(i, i + 1)
+ hex_table[#hex_table + 1] = string.format('\\x{%s}', subc)
+ end
+
+ return '^' .. table.concat(hex_table) .. '$'
+ end
+ -- Directory entries
+ msoffice_trie = compile_pats(msoffice_patterns, msoffice_patterns_indexes,
+ msoffice_pattern_transform)
+ -- Clsids
+ msoffice_trie_clsid = compile_pats(msoffice_clsids, msoffice_clsid_indexes,
+ msoffice_clsid_transform)
+ -- Misc zip patterns at the initial fragment
+ zip_trie = compile_pats(zip_patterns, zip_patterns_indexes,
+ function(pat)
+ return pat
+ end)
+ -- Text patterns at the initial fragment
+ txt_trie = compile_pats(txt_patterns, txt_patterns_indexes,
+ function(pat_tbl)
+ return pat_tbl[1]
+ end,
+ bit.bor(rspamd_trie.flags.re,
+ rspamd_trie.flags.dot_all,
+ rspamd_trie.flags.no_start))
+ end
+end
+
+-- Call immediately on require
+compile_tries()
+
+local function detect_ole_format(input, log_obj, _, part)
+ local inplen = #input
+ if inplen < 0x31 + 4 then
+ lua_util.debugm(N, log_obj, "short length: %s", inplen)
+ return nil
+ end
+
+ local bom, sec_size = rspamd_util.unpack('<I2<I2', input:span(29, 4))
+ if bom == 0xFFFE then
+ bom = '<'
+ else
+ lua_util.debugm(N, log_obj, "bom file!: %s", bom)
+ bom = '>';
+ sec_size = bit.bswap(sec_size)
+ end
+
+ if sec_size < 7 or sec_size > 31 then
+ lua_util.debugm(N, log_obj, "bad sec_size: %s", sec_size)
+ return nil
+ end
+
+ sec_size = 2 ^ sec_size
+
+ -- SecID of first sector of the directory stream
+ local directory_offset = (rspamd_util.unpack(bom .. 'I4', input:span(0x31, 4)))
+ * sec_size + 512 + 1
+ lua_util.debugm(N, log_obj, "directory: %s", directory_offset)
+
+ if inplen < directory_offset then
+ lua_util.debugm(N, log_obj, "short length: %s", inplen)
+ return nil
+ end
+
+ local function process_dir_entry(offset)
+ local dtype = input:byte(offset + 66)
+ lua_util.debugm(N, log_obj, "dtype: %s, offset: %s", dtype, offset)
+
+ if dtype then
+ if dtype == 5 then
+ -- Extract clsid
+ local matches = msoffice_trie_clsid:match(input:span(offset + 80, 16))
+ if matches then
+ for n, _ in pairs(matches) do
+ if msoffice_clsid_indexes[n] then
+ lua_util.debugm(N, log_obj, "found valid clsid for %s",
+ msoffice_clsid_indexes[n][1])
+ return true, msoffice_clsid_indexes[n][1]
+ end
+ end
+ end
+ return true, nil
+ elseif dtype == 2 then
+ local matches = msoffice_trie:match(input:span(offset, 64))
+ if matches then
+ for n, _ in pairs(matches) do
+ if msoffice_patterns_indexes[n] then
+ return true, msoffice_patterns_indexes[n][1]
+ end
+ end
+ end
+ return true, nil
+ elseif dtype >= 0 and dtype < 5 then
+ -- Bad type
+ return true, nil
+ end
+ end
+
+ return false, nil
+ end
+
+ repeat
+ local res, ext = process_dir_entry(directory_offset)
+
+ if res and ext then
+ return ext, 60
+ end
+
+ if not res then
+ break
+ end
+
+ directory_offset = directory_offset + 128
+ until directory_offset >= inplen
+end
+
+exports.ole_format_heuristic = detect_ole_format
+
+local function process_top_detected(res)
+ local extensions = lua_util.keys(res)
+
+ if #extensions > 0 then
+ table.sort(extensions, function(ex1, ex2)
+ return res[ex1] > res[ex2]
+ end)
+
+ return extensions[1], res[extensions[1]]
+ end
+
+ return nil
+end
+
+local function detect_archive_flaw(part, arch, log_obj, _)
+ local arch_type = arch:get_type()
+ local res = {
+ docx = 0,
+ xlsx = 0,
+ pptx = 0,
+ jar = 0,
+ odt = 0,
+ odp = 0,
+ ods = 0,
+ apk = 0,
+ } -- ext + confidence pairs
+
+ -- General msoffice patterns
+ local function add_msoffice_confidence(incr)
+ res.docx = res.docx + incr
+ res.xlsx = res.xlsx + incr
+ res.pptx = res.pptx + incr
+ end
+
+ if arch_type == 'zip' then
+ -- Find specific files/folders in zip file
+ local files = arch:get_files(100) or {}
+ for _, file in ipairs(files) do
+ if file == '[Content_Types].xml' then
+ add_msoffice_confidence(10)
+ elseif file:sub(1, 3) == 'xl/' then
+ res.xlsx = res.xlsx + 30
+ elseif file:sub(1, 5) == 'word/' then
+ res.docx = res.docx + 30
+ elseif file:sub(1, 4) == 'ppt/' then
+ res.pptx = res.pptx + 30
+ elseif file == 'META-INF/MANIFEST.MF' then
+ res.jar = res.jar + 40
+ elseif file == 'AndroidManifest.xml' then
+ res.apk = res.apk + 60
+ end
+ end
+
+ local ext, weight = process_top_detected(res)
+
+ if weight >= 40 then
+ return ext, weight
+ end
+
+ -- Apply misc Zip detection logic
+ local content = part:get_content()
+
+ if #content > 128 then
+ local start_span = content:span(1, 128)
+
+ local matches = zip_trie:match(start_span)
+ if matches then
+ for n, _ in pairs(matches) do
+ if zip_patterns_indexes[n] then
+ lua_util.debugm(N, log_obj, "found zip pattern for %s",
+ zip_patterns_indexes[n][1])
+ return zip_patterns_indexes[n][1], 40
+ end
+ end
+ end
+ end
+ end
+
+ return arch_type:lower(), 40
+end
+
+local csv_grammar
+-- Returns a grammar that will count commas
+local function get_csv_grammar()
+ if not csv_grammar then
+ local lpeg = require 'lpeg'
+
+ local field = '"' * lpeg.Cs(((lpeg.P(1) - '"') + lpeg.P '""' / '"') ^ 0) * '"' +
+ lpeg.C((1 - lpeg.S ',\n"') ^ 0)
+
+ csv_grammar = lpeg.Cf(lpeg.Cc(0) * field * lpeg.P((lpeg.P(',') +
+ lpeg.P('\t')) * field) ^ 1 * (lpeg.S '\r\n' + -1),
+ function(acc)
+ return acc + 1
+ end)
+ end
+
+ return csv_grammar
+end
+local function validate_csv(part, content, log_obj)
+ local max_chunk = 32768
+ local chunk = content:sub(1, max_chunk)
+
+ local expected_commas
+ local matched_lines = 0
+ local max_matched_lines = 10
+
+ lua_util.debugm(N, log_obj, "check for csv pattern")
+
+ for s in chunk:lines() do
+ local ncommas = get_csv_grammar():match(s)
+
+ if not ncommas then
+ lua_util.debugm(N, log_obj, "not a csv line at line number %s",
+ matched_lines)
+ return false
+ end
+
+ if expected_commas and ncommas ~= expected_commas then
+ -- Mismatched commas
+ lua_util.debugm(N, log_obj, "missmatched commas on line %s: %s != %s",
+ matched_lines, ncommas, expected_commas)
+ return false
+ elseif not expected_commas then
+ if ncommas == 0 then
+ lua_util.debugm(N, log_obj, "no commas in the first line")
+ return false
+ end
+ expected_commas = ncommas
+ end
+
+ matched_lines = matched_lines + 1
+
+ if matched_lines > max_matched_lines then
+ break
+ end
+ end
+
+ lua_util.debugm(N, log_obj, "csv content is sane: %s fields; %s lines checked",
+ expected_commas, matched_lines)
+
+ return true
+end
+
+exports.mime_part_heuristic = function(part, log_obj, _)
+ if part:is_archive() then
+ local arch = part:get_archive()
+ return detect_archive_flaw(part, arch, log_obj)
+ end
+
+ return nil
+end
+
+exports.text_part_heuristic = function(part, log_obj, _)
+ -- We get some span of data and check it
+ local function is_span_text(span)
+ -- We examine 8 bit content, and we assume it might be localized text
+ -- if it has more than 3 subsequent 8 bit characters
+ local function rough_8bit_check(bytes, idx, remain, len)
+ local b = bytes[idx]
+ local n8bit = 0
+
+ while b >= 127 and idx < len do
+ -- utf8 part
+ if bit.band(b, 0xe0) == 0xc0 and remain > 1 and
+ bit.band(bytes[idx + 1], 0xc0) == 0x80 then
+ return true, 1
+ elseif bit.band(b, 0xf0) == 0xe0 and remain > 2 and
+ bit.band(bytes[idx + 1], 0xc0) == 0x80 and
+ bit.band(bytes[idx + 2], 0xc0) == 0x80 then
+ return true, 2
+ elseif bit.band(b, 0xf8) == 0xf0 and remain > 3 and
+ bit.band(bytes[idx + 1], 0xc0) == 0x80 and
+ bit.band(bytes[idx + 2], 0xc0) == 0x80 and
+ bit.band(bytes[idx + 3], 0xc0) == 0x80 then
+ return true, 3
+ end
+
+ n8bit = n8bit + 1
+ idx = idx + 1
+ b = bytes[idx]
+ remain = remain - 1
+ end
+
+ if n8bit >= 3 then
+ return true, n8bit
+ end
+
+ return false, 0
+ end
+
+ -- Convert to string as LuaJIT can optimise string.sub (and fun.iter) but not C calls
+ local tlen = #span
+ local non_printable = 0
+ local bytes = span:bytes()
+ local i = 1
+ repeat
+ local b = bytes[i]
+
+ if (b < 0x20) and not (b == 0x0d or b == 0x0a or b == 0x09) then
+ non_printable = non_printable + 1
+ elseif b >= 127 then
+ local c, nskip = rough_8bit_check(bytes, i, tlen - i, tlen)
+
+ if not c then
+ non_printable = non_printable + 1
+ else
+ i = i + nskip
+ end
+ end
+ i = i + 1
+ until i > tlen
+
+ lua_util.debugm(N, log_obj, "text part check: %s printable, %s non-printable, %s total",
+ tlen - non_printable, non_printable, tlen)
+ if non_printable / tlen > 0.0078125 then
+ return false
+ end
+
+ return true
+ end
+
+ local parent = part:get_parent()
+
+ if parent then
+ local parent_type, parent_subtype = parent:get_type()
+
+ if parent_type == 'multipart' and parent_subtype == 'encrypted' then
+ -- Skip text heuristics for encrypted parts
+ lua_util.debugm(N, log_obj, "text part check: parent is encrypted, not a text part")
+
+ return false
+ end
+ end
+
+ local content = part:get_content()
+ local mtype, msubtype = part:get_type()
+ local clen = #content
+ local is_text
+
+ if clen > 0 then
+ if clen > 80 * 3 then
+ -- Use chunks
+ is_text = is_span_text(content:span(1, 160)) and is_span_text(content:span(clen - 80, 80))
+ else
+ is_text = is_span_text(content)
+ end
+
+ if is_text and mtype ~= 'message' then
+ -- Try patterns
+ local span_len = math.min(4096, clen)
+ local start_span = content:span(1, span_len)
+ local matches = txt_trie:match(start_span)
+ local res = {}
+ local fname = part:get_filename()
+
+ if matches then
+ -- Require at least 2 occurrences of those patterns
+ for n, positions in pairs(matches) do
+ local ext, weight = txt_patterns_indexes[n][1], txt_patterns_indexes[n][2][2]
+ if ext then
+ res[ext] = (res[ext] or 0) + weight * #positions
+ lua_util.debugm(N, log_obj, "found txt pattern for %s: %s, total: %s; %s/%s announced",
+ ext, weight * #positions, res[ext], mtype, msubtype)
+ end
+ end
+
+ if res.html and res.html >= 40 then
+ -- HTML has priority over something like js...
+ return 'html', res.html
+ end
+
+ local ext, weight = process_top_detected(res)
+
+ if weight then
+ if weight >= 40 then
+ -- Extra validation for csv extension
+ if ext ~= 'csv' or validate_csv(part, content, log_obj) then
+ return ext, weight
+ end
+ elseif fname and weight >= 20 then
+ return ext, weight
+ end
+ end
+ end
+
+ -- Content type stuff
+ if (mtype == 'text' or mtype == 'application') and
+ (msubtype == 'html' or msubtype == 'xhtml+xml') then
+ return 'html', 21
+ end
+
+ if msubtype:lower() == 'csv' then
+ if validate_csv(part, content, log_obj) then
+ return 'csv', 40
+ end
+ end
+
+ -- Extension stuff
+ local function has_extension(file, ext)
+ local ext_len = ext:len()
+ return file:len() > ext_len + 1
+ and file:sub(-ext_len):lower() == ext
+ and file:sub(-ext_len - 1, -ext_len - 1) == '.'
+ end
+
+ if fname and (has_extension(fname, 'htm') or has_extension(fname, 'html')) then
+ return 'html', 21
+ end
+
+ if mtype ~= 'text' then
+ -- Do not treat non text patterns as text
+ return nil
+ end
+
+ return 'txt', 40
+ end
+ end
+end
+
+exports.pdf_format_heuristic = function(input, log_obj, pos, part)
+ local weight = 10
+ local ext = string.match(part:get_filename() or '', '%.([^.]+)$')
+ -- If we found a pattern at the beginning
+ if pos <= 10 then
+ weight = weight + 30
+ end
+ -- If the announced extension is `pdf`
+ if ext and ext:lower() == 'pdf' then
+ weight = weight + 30
+ end
+
+ return 'pdf', weight
+end
+
+exports.pe_part_heuristic = function(input, log_obj, pos, part)
+ if not input then
+ return
+ end
+
+ -- pe header should start at the offset that is placed in msdos header at position 60..64
+ local pe_ptr_bin = input:sub(60, 64)
+ if #pe_ptr_bin ~= 4 then
+ return
+ end
+
+ -- it is an LE 32 bit integer
+ local pe_ptr = rspamd_util.unpack("<I4", pe_ptr_bin)
+ -- if pe header magic matches the offset, it is definitely a PE file
+ if pe_ptr ~= pos then
+ return
+ end
+
+ return 'exe', 30
+end
+
+return exports
diff --git a/lualib/lua_magic/init.lua b/lualib/lua_magic/init.lua
new file mode 100644
index 0000000..38bfddb
--- /dev/null
+++ b/lualib/lua_magic/init.lua
@@ -0,0 +1,388 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_magic
+-- This module contains file types detection logic
+--]]
+
+local patterns = require "lua_magic/patterns"
+local types = require "lua_magic/types"
+local heuristics = require "lua_magic/heuristics"
+local fun = require "fun"
+local lua_util = require "lua_util"
+
+local rspamd_text = require "rspamd_text"
+local rspamd_trie = require "rspamd_trie"
+
+local N = "lua_magic"
+local exports = {}
+-- trie objects
+local compiled_patterns
+local compiled_short_patterns
+local compiled_tail_patterns
+-- {<str>, <match_object>, <pattern_object>} indexed by pattern number
+local processed_patterns = {}
+local short_patterns = {}
+local tail_patterns = {}
+
+local short_match_limit = 128
+local max_short_offset = -1
+local min_tail_offset = math.huge
+
+local function process_patterns(log_obj)
+ -- Add pattern to either short patterns or to normal patterns
+ local function add_processed(str, match, pattern)
+ if match.position and type(match.position) == 'number' then
+ if match.tail then
+ -- Tail pattern
+ tail_patterns[#tail_patterns + 1] = {
+ str, match, pattern
+ }
+ if min_tail_offset > match.tail then
+ min_tail_offset = match.tail
+ end
+
+ lua_util.debugm(N, log_obj, 'add tail pattern %s for ext %s',
+ str, pattern.ext)
+ elseif match.position < short_match_limit then
+ short_patterns[#short_patterns + 1] = {
+ str, match, pattern
+ }
+ if str:sub(1, 1) == '^' then
+ lua_util.debugm(N, log_obj, 'add head pattern %s for ext %s',
+ str, pattern.ext)
+ else
+ lua_util.debugm(N, log_obj, 'add short pattern %s for ext %s',
+ str, pattern.ext)
+ end
+
+ if max_short_offset < match.position then
+ max_short_offset = match.position
+ end
+ else
+ processed_patterns[#processed_patterns + 1] = {
+ str, match, pattern
+ }
+
+ lua_util.debugm(N, log_obj, 'add long pattern %s for ext %s',
+ str, pattern.ext)
+ end
+ else
+ processed_patterns[#processed_patterns + 1] = {
+ str, match, pattern
+ }
+
+ lua_util.debugm(N, log_obj, 'add long pattern %s for ext %s',
+ str, pattern.ext)
+ end
+ end
+
+ if not compiled_patterns then
+ for ext, pattern in pairs(patterns) do
+ assert(types[ext], 'not found type: ' .. ext)
+ pattern.ext = ext
+ for _, match in ipairs(pattern.matches) do
+ if match.string then
+ if match.relative_position and not match.position then
+ match.position = match.relative_position + #match.string
+
+ if match.relative_position == 0 then
+ if match.string:sub(1, 1) ~= '^' then
+ match.string = '^' .. match.string
+ end
+ end
+ end
+ add_processed(match.string, match, pattern)
+ elseif match.hex then
+ local hex_table = {}
+
+ for i = 1, #match.hex, 2 do
+ local subc = match.hex:sub(i, i + 1)
+ hex_table[#hex_table + 1] = string.format('\\x{%s}', subc)
+ end
+
+ if match.relative_position and not match.position then
+ match.position = match.relative_position + #match.hex / 2
+ end
+ if match.relative_position == 0 then
+ table.insert(hex_table, 1, '^')
+ end
+ add_processed(table.concat(hex_table), match, pattern)
+ end
+ end
+ end
+ local bit = require "bit"
+ local compile_flags = bit.bor(rspamd_trie.flags.re, rspamd_trie.flags.dot_all)
+ compile_flags = bit.bor(compile_flags, rspamd_trie.flags.single_match)
+ compile_flags = bit.bor(compile_flags, rspamd_trie.flags.no_start)
+ compiled_patterns = rspamd_trie.create(fun.totable(
+ fun.map(function(t)
+ return t[1]
+ end, processed_patterns)),
+ compile_flags
+ )
+ compiled_short_patterns = rspamd_trie.create(fun.totable(
+ fun.map(function(t)
+ return t[1]
+ end, short_patterns)),
+ compile_flags
+ )
+ compiled_tail_patterns = rspamd_trie.create(fun.totable(
+ fun.map(function(t)
+ return t[1]
+ end, tail_patterns)),
+ compile_flags
+ )
+
+ lua_util.debugm(N, log_obj,
+ 'compiled %s (%s short; %s long; %s tail) patterns',
+ #processed_patterns + #short_patterns + #tail_patterns,
+ #short_patterns, #processed_patterns, #tail_patterns)
+ end
+end
+
+process_patterns(rspamd_config)
+
+local function match_chunk(chunk, input, tlen, offset, trie, processed_tbl, log_obj, res, part)
+ local matches = trie:match(chunk)
+
+ local last = tlen
+
+ local function add_result(weight, ext)
+ if not res[ext] then
+ res[ext] = 0
+ end
+ if weight then
+ res[ext] = res[ext] + weight
+ else
+ res[ext] = res[ext] + 1
+ end
+
+ lua_util.debugm(N, log_obj, 'add pattern for %s, weight %s, total weight %s',
+ ext, weight, res[ext])
+ end
+
+ local function match_position(pos, expected)
+ local cmp = function(a, b)
+ return a == b
+ end
+ if type(expected) == 'table' then
+ -- Something like {'>', 0}
+ if expected[1] == '>' then
+ cmp = function(a, b)
+ return a > b
+ end
+ elseif expected[1] == '>=' then
+ cmp = function(a, b)
+ return a >= b
+ end
+ elseif expected[1] == '<' then
+ cmp = function(a, b)
+ return a < b
+ end
+ elseif expected[1] == '<=' then
+ cmp = function(a, b)
+ return a <= b
+ end
+ elseif expected[1] == '!=' then
+ cmp = function(a, b)
+ return a ~= b
+ end
+ end
+ expected = expected[2]
+ end
+
+ -- Tail match
+ if expected < 0 then
+ expected = last + expected + 1
+ end
+ return cmp(pos, expected)
+ end
+
+ for npat, matched_positions in pairs(matches) do
+ local pat_data = processed_tbl[npat]
+ local pattern = pat_data[3]
+ local match = pat_data[2]
+
+ -- Single position
+ if match.position then
+ local position = match.position
+
+ for _, pos in ipairs(matched_positions) do
+ lua_util.debugm(N, log_obj, 'found match %s at offset %s(from %s)',
+ pattern.ext, pos, offset)
+ if match_position(pos + offset, position) then
+ if match.heuristic then
+ local ext, weight = match.heuristic(input, log_obj, pos + offset, part)
+
+ if ext then
+ add_result(weight, ext)
+ break
+ end
+ else
+ add_result(match.weight, pattern.ext)
+ break
+ end
+ end
+ end
+ elseif match.positions then
+ -- Match all positions
+ local all_right = true
+ local matched_pos = 0
+ for _, position in ipairs(match.positions) do
+ local matched = false
+ for _, pos in ipairs(matched_positions) do
+ lua_util.debugm(N, log_obj, 'found match %s at offset %s(from %s)',
+ pattern.ext, pos, offset)
+ if not match_position(pos + offset, position) then
+ matched = true
+ matched_pos = pos
+ break
+ end
+ end
+ if not matched then
+ all_right = false
+ break
+ end
+ end
+
+ if all_right then
+ if match.heuristic then
+ local ext, weight = match.heuristic(input, log_obj, matched_pos + offset, part)
+
+ if ext then
+ add_result(weight, ext)
+ break
+ end
+ else
+ add_result(match.weight, pattern.ext)
+ break
+ end
+ end
+ end
+ end
+
+end
+
+local function process_detected(res)
+ local extensions = lua_util.keys(res)
+
+ if #extensions > 0 then
+ table.sort(extensions, function(ex1, ex2)
+ return res[ex1] > res[ex2]
+ end)
+
+ return extensions, res[extensions[1]]
+ end
+
+ return nil
+end
+
+exports.detect = function(part, log_obj)
+ if not log_obj then
+ log_obj = rspamd_config
+ end
+ local input = part:get_content()
+
+ local res = {}
+
+ if type(input) == 'string' then
+ -- Convert to rspamd_text
+ input = rspamd_text.fromstring(input)
+ end
+
+ if type(input) == 'userdata' then
+ local inplen = #input
+
+ -- Check tail matches
+ if inplen > min_tail_offset then
+ local tail = input:span(inplen - min_tail_offset, min_tail_offset)
+ match_chunk(tail, input, inplen, inplen - min_tail_offset,
+ compiled_tail_patterns, tail_patterns, log_obj, res, part)
+ end
+
+ -- Try short match
+ local head = input:span(1, math.min(max_short_offset, inplen))
+ match_chunk(head, input, inplen, 0,
+ compiled_short_patterns, short_patterns, log_obj, res, part)
+
+ -- Check if we have enough data or go to long patterns
+ local extensions, confidence = process_detected(res)
+
+ if extensions and #extensions > 0 and confidence > 30 then
+ -- We are done on short patterns
+ return extensions[1], types[extensions[1]]
+ end
+
+ -- No way, let's check data in chunks or just the whole input if it is small enough
+ if #input > exports.chunk_size * 3 then
+ -- Chunked version as input is too long
+ local chunk1, chunk2 = input:span(1, exports.chunk_size * 2),
+ input:span(inplen - exports.chunk_size, exports.chunk_size)
+ local offset1, offset2 = 0, inplen - exports.chunk_size
+
+ match_chunk(chunk1, input, inplen,
+ offset1, compiled_patterns, processed_patterns, log_obj, res, part)
+ match_chunk(chunk2, input, inplen,
+ offset2, compiled_patterns, processed_patterns, log_obj, res, part)
+ else
+ -- Input is short enough to match it at all
+ match_chunk(input, input, inplen, 0,
+ compiled_patterns, processed_patterns, log_obj, res, part)
+ end
+ else
+ -- Table input is NYI
+ assert(0, 'table input for match')
+ end
+
+ local extensions = process_detected(res)
+
+ if extensions and #extensions > 0 then
+ return extensions[1], types[extensions[1]]
+ end
+
+ -- Nothing found
+ return nil
+end
+
+exports.detect_mime_part = function(part, log_obj)
+ local ext, weight = heuristics.mime_part_heuristic(part, log_obj)
+
+ if ext and weight and weight > 20 then
+ return ext, types[ext]
+ end
+
+ ext = exports.detect(part, log_obj)
+
+ if ext then
+ return ext, types[ext]
+ end
+
+ -- Text/html and other parts
+ ext, weight = heuristics.text_part_heuristic(part, log_obj)
+ if ext and weight and weight > 20 then
+ return ext, types[ext]
+ end
+end
+
+-- This parameter specifies how many bytes are checked in the input
+-- Rspamd checks 2 chunks at start and 1 chunk at the end
+exports.chunk_size = 32768
+
+exports.types = types
+
+return exports \ No newline at end of file
diff --git a/lualib/lua_magic/patterns.lua b/lualib/lua_magic/patterns.lua
new file mode 100644
index 0000000..971ddd9
--- /dev/null
+++ b/lualib/lua_magic/patterns.lua
@@ -0,0 +1,471 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_magic/patterns
+-- This module contains most common patterns
+--]]
+
+local heuristics = require "lua_magic/heuristics"
+
+local patterns = {
+ pdf = {
+ -- These are alternatives
+ matches = {
+ {
+ string = [[%PDF-[12]\.\d]],
+ position = { '<=', 1024 },
+ weight = 60,
+ heuristic = heuristics.pdf_format_heuristic
+ },
+ {
+ string = [[%FDF-[12]\.\d]],
+ position = { '<=', 1024 },
+ weight = 60,
+ heuristic = heuristics.pdf_format_heuristic
+ },
+ },
+ },
+ ps = {
+ matches = {
+ {
+ string = [[%!PS-Adobe]],
+ relative_position = 0,
+ weight = 60,
+ },
+ },
+ },
+ -- RTF document
+ rtf = {
+ matches = {
+ {
+ string = [[^{\\rt]],
+ position = 4,
+ weight = 60,
+ }
+ }
+ },
+ chm = {
+ matches = {
+ {
+ string = [[ITSF]],
+ relative_position = 0,
+ weight = 60,
+ }
+ }
+ },
+ djvu = {
+ matches = {
+ {
+ string = [[AT&TFORM]],
+ relative_position = 0,
+ weight = 60,
+ },
+ {
+ string = [[DJVM]],
+ relative_position = 0x0c,
+ weight = 60,
+ }
+ }
+ },
+ -- MS Office format, needs heuristic
+ ole = {
+ matches = {
+ {
+ hex = [[d0cf11e0a1b11ae1]],
+ relative_position = 0,
+ weight = 60,
+ heuristic = heuristics.ole_format_heuristic
+ }
+ }
+ },
+ -- MS Exe file
+ exe = {
+ matches = {
+ {
+ string = [[MZ]],
+ relative_position = 0,
+ weight = 15,
+ },
+ -- PE part
+ {
+ string = [[PE\x{00}\x{00}]],
+ position = { '>=', 0x3c + 4 },
+ weight = 15,
+ heuristic = heuristics.pe_part_heuristic,
+ }
+ }
+ },
+ elf = {
+ matches = {
+ {
+ hex = [[7f454c46]],
+ relative_position = 0,
+ weight = 60,
+ },
+ }
+ },
+ lnk = {
+ matches = {
+ {
+ hex = [[4C0000000114020000000000C000000000000046]],
+ relative_position = 0,
+ weight = 60,
+ },
+ }
+ },
+ bat = {
+ matches = {
+ {
+ string = [[(?i)@\s*ECHO\s+OFF]],
+ position = { '>=', 0 },
+ weight = 60,
+ },
+ }
+ },
+ class = {
+ -- Technically, this also matches MachO files, but I don't care about
+ -- Apple and their mental health problems here: just consider Java files,
+ -- Mach object files and all other cafe babes as bad and block them!
+ matches = {
+ {
+ hex = [[cafebabe]],
+ relative_position = 0,
+ weight = 60,
+ },
+ }
+ },
+ ics = {
+ matches = {
+ {
+ string = [[BEGIN:VCALENDAR]],
+ weight = 60,
+ relative_position = 0,
+ }
+ }
+ },
+ vcf = {
+ matches = {
+ {
+ string = [[BEGIN:VCARD]],
+ weight = 60,
+ relative_position = 0,
+ }
+ }
+ },
+ -- Archives
+ arj = {
+ matches = {
+ {
+ hex = '60EA',
+ relative_position = 0,
+ weight = 60,
+ },
+ }
+ },
+ ace = {
+ matches = {
+ {
+ string = [[\*\*ACE\*\*]],
+ position = 14,
+ weight = 60,
+ },
+ }
+ },
+ cab = {
+ matches = {
+ {
+ hex = [[4d53434600000000]], -- Can be anywhere for SFX :(
+ position = { '>=', 8 },
+ weight = 60,
+ },
+ }
+ },
+ tar = {
+ matches = {
+ {
+ string = [[ustar]],
+ relative_position = 257,
+ weight = 60,
+ },
+ }
+ },
+ bz2 = {
+ matches = {
+ {
+ string = "^BZ[h0]",
+ position = 3,
+ weight = 60,
+ },
+ }
+ },
+ lz4 = {
+ matches = {
+ {
+ hex = "04224d18",
+ relative_position = 0,
+ weight = 60,
+ },
+ {
+ hex = "03214c18",
+ relative_position = 0,
+ weight = 60,
+ },
+ {
+ hex = "02214c18",
+ relative_position = 0,
+ weight = 60,
+ },
+ {
+ -- MozLZ4
+ hex = '6d6f7a4c7a343000',
+ relative_position = 0,
+ weight = 60,
+ }
+ }
+ },
+ zst = {
+ matches = {
+ {
+ string = [[^[\x{22}-\x{40}]\x{B5}\x{2F}\x{FD}]],
+ position = 4,
+ weight = 60,
+ },
+ }
+ },
+ zoo = {
+ matches = {
+ {
+ hex = [[dca7c4fd]],
+ relative_position = 20,
+ weight = 60,
+ },
+ }
+ },
+ xar = {
+ matches = {
+ {
+ string = [[xar!]],
+ relative_position = 0,
+ weight = 60,
+ },
+ }
+ },
+ iso = {
+ matches = {
+ {
+ string = [[\x{01}CD001\x{01}]],
+ position = { '>=', 0x8000 + 7 }, -- first 32k is unused
+ weight = 60,
+ },
+ }
+ },
+ egg = {
+ -- ALZip egg
+ matches = {
+ {
+ string = [[EGGA]],
+ weight = 60,
+ relative_position = 0,
+ },
+ }
+ },
+ alz = {
+ -- ALZip alz
+ matches = {
+ {
+ string = [[ALZ\x{01}]],
+ weight = 60,
+ relative_position = 0,
+ },
+ }
+ },
+ -- Apple is a 'special' child: this needs to be matched at the data tail...
+ dmg = {
+ matches = {
+ {
+ string = [[koly\x{00}\x{00}\x{00}\x{04}]],
+ position = -512 + 8,
+ weight = 61,
+ tail = 512,
+ },
+ }
+ },
+ szdd = {
+ matches = {
+ {
+ hex = [[535a4444]],
+ relative_position = 0,
+ weight = 60,
+ },
+ }
+ },
+ xz = {
+ matches = {
+ {
+ hex = [[FD377A585A00]],
+ relative_position = 0,
+ weight = 60,
+ },
+ }
+ },
+ -- Images
+ psd = {
+ matches = {
+ {
+ string = [[8BPS]],
+ relative_position = 0,
+ weight = 60,
+ },
+ }
+ },
+ ico = {
+ matches = {
+ {
+ hex = [[00000100]],
+ relative_position = 0,
+ weight = 60,
+ },
+ }
+ },
+ pcx = {
+ matches = {
+ {
+ hex = [[0A050108]],
+ relative_position = 0,
+ weight = 60,
+ },
+ }
+ },
+ pic = {
+ matches = {
+ {
+ hex = [[FF80C9C71A00]],
+ relative_position = 0,
+ weight = 60,
+ },
+ }
+ },
+ swf = {
+ matches = {
+ {
+ hex = [[5a5753]], -- LZMA
+ relative_position = 0,
+ weight = 60,
+ },
+ {
+ hex = [[435753]], -- Zlib
+ relative_position = 0,
+ weight = 60,
+ },
+ {
+ hex = [[465753]], -- Uncompressed
+ relative_position = 0,
+ weight = 60,
+ },
+ }
+ },
+ tiff = {
+ matches = {
+ {
+ hex = [[49492a00]], -- LE encoded
+ relative_position = 0,
+ weight = 60,
+ },
+ {
+ hex = [[4d4d]], -- BE tiff
+ relative_position = 0,
+ weight = 60,
+ },
+ }
+ },
+ -- Other
+ pgp = {
+ matches = {
+ {
+ hex = [[A803504750]],
+ relative_position = 0,
+ weight = 60,
+ },
+ {
+ hex = [[2D424547494E20504750204D4553534147452D]],
+ relative_position = 0,
+ weight = 60,
+ },
+ }
+ },
+ uue = {
+ matches = {
+ {
+ hex = [[626567696e20]],
+ relative_position = 0,
+ weight = 60,
+ },
+ }
+ },
+ dwg = {
+ matches = {
+ {
+ string = '^AC10[12][2-9]',
+ position = 6,
+ weight = 60,
+ }
+ }
+ },
+ jpg = {
+ matches = {
+ { -- JPEG2000
+ hex = [[0000000c6a5020200d0a870a]],
+ relative_position = 0,
+ weight = 60,
+ },
+ {
+ string = [[^\x{ff}\x{d8}\x{ff}]],
+ weight = 60,
+ position = 3,
+ },
+ },
+ },
+ png = {
+ matches = {
+ {
+ string = [[^\x{89}PNG\x{0d}\x{0a}\x{1a}\x{0a}]],
+ position = 8,
+ weight = 60,
+ },
+ }
+ },
+ gif = {
+ matches = {
+ {
+ string = [[^GIF8\d]],
+ position = 5,
+ weight = 60,
+ },
+ }
+ },
+ bmp = {
+ matches = {
+ {
+ string = [[^BM...\x{00}\x{00}\x{00}\x{00}]],
+ position = 9,
+ weight = 60,
+ },
+ }
+ },
+}
+
+return patterns
diff --git a/lualib/lua_magic/types.lua b/lualib/lua_magic/types.lua
new file mode 100644
index 0000000..3dce2e1
--- /dev/null
+++ b/lualib/lua_magic/types.lua
@@ -0,0 +1,327 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_magic/patterns
+-- This module contains types definitions
+--]]
+
+-- This table is indexed by msdos extension for convenience
+
+local types = {
+ -- exe
+ exe = {
+ ct = 'application/x-ms-application',
+ type = 'executable',
+ },
+ elf = {
+ ct = 'application/x-elf-executable',
+ type = 'executable',
+ },
+ lnk = {
+ ct = 'application/x-ms-application',
+ type = 'executable',
+ },
+ class = {
+ ct = 'application/x-java-applet',
+ type = 'executable',
+ },
+ jar = {
+ ct = 'application/java-archive',
+ type = 'archive',
+ },
+ apk = {
+ ct = 'application/vnd.android.package-archive',
+ type = 'archive',
+ },
+ bat = {
+ ct = 'application/x-bat',
+ type = 'executable',
+ },
+ -- text
+ rtf = {
+ ct = "application/rtf",
+ type = 'binary',
+ },
+ pdf = {
+ ct = 'application/pdf',
+ type = 'binary',
+ },
+ ps = {
+ ct = 'application/postscript',
+ type = 'binary',
+ },
+ chm = {
+ ct = 'application/x-chm',
+ type = 'binary',
+ },
+ djvu = {
+ ct = 'application/x-djvu',
+ type = 'binary',
+ },
+ -- archives
+ arj = {
+ ct = 'application/x-arj',
+ type = 'archive',
+ },
+ cab = {
+ ct = 'application/x-cab',
+ type = 'archive',
+ },
+ ace = {
+ ct = 'application/x-ace',
+ type = 'archive',
+ },
+ tar = {
+ ct = 'application/x-tar',
+ type = 'archive',
+ },
+ bz2 = {
+ ct = 'application/x-bzip',
+ type = 'archive',
+ },
+ xz = {
+ ct = 'application/x-xz',
+ type = 'archive',
+ },
+ lz4 = {
+ ct = 'application/x-lz4',
+ type = 'archive',
+ },
+ zst = {
+ ct = 'application/x-zstandard',
+ type = 'archive',
+ },
+ dmg = {
+ ct = 'application/x-dmg',
+ type = 'archive',
+ },
+ iso = {
+ ct = 'application/x-iso',
+ type = 'archive',
+ },
+ zoo = {
+ ct = 'application/x-zoo',
+ type = 'archive',
+ },
+ egg = {
+ ct = 'application/x-egg',
+ type = 'archive',
+ },
+ alz = {
+ ct = 'application/x-alz',
+ type = 'archive',
+ },
+ xar = {
+ ct = 'application/x-xar',
+ type = 'archive',
+ },
+ epub = {
+ ct = 'application/x-epub',
+ type = 'archive'
+ },
+ szdd = { -- in fact, their MSDOS extension is like FOO.TX_ or FOO.TX$
+ ct = 'application/x-compressed',
+ type = 'archive',
+ },
+ -- images
+ psd = {
+ ct = 'image/psd',
+ type = 'image',
+ av_check = false,
+ },
+ pcx = {
+ ct = 'image/pcx',
+ type = 'image',
+ av_check = false,
+ },
+ pic = {
+ ct = 'image/pic',
+ type = 'image',
+ av_check = false,
+ },
+ tiff = {
+ ct = 'image/tiff',
+ type = 'image',
+ av_check = false,
+ },
+ ico = {
+ ct = 'image/ico',
+ type = 'image',
+ av_check = false,
+ },
+ swf = {
+ ct = 'application/x-shockwave-flash',
+ type = 'image',
+ },
+ -- Ole files
+ ole = {
+ ct = 'application/octet-stream',
+ type = 'office'
+ },
+ doc = {
+ ct = 'application/msword',
+ type = 'office'
+ },
+ xls = {
+ ct = 'application/vnd.ms-excel',
+ type = 'office'
+ },
+ ppt = {
+ ct = 'application/vnd.ms-powerpoint',
+ type = 'office'
+ },
+ vsd = {
+ ct = 'application/vnd.visio',
+ type = 'office'
+ },
+ msi = {
+ ct = 'application/x-msi',
+ type = 'executable'
+ },
+ msg = {
+ ct = 'application/vnd.ms-outlook',
+ type = 'office'
+ },
+ -- newer office (2007+)
+ docx = {
+ ct = 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
+ type = 'office'
+ },
+ xlsx = {
+ ct = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
+ type = 'office'
+ },
+ pptx = {
+ ct = 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
+ type = 'office'
+ },
+ -- OpenOffice formats
+ odt = {
+ ct = 'application/vnd.oasis.opendocument.text',
+ type = 'office'
+ },
+ ods = {
+ ct = 'application/vnd.oasis.opendocument.spreadsheet',
+ type = 'office'
+ },
+ odp = {
+ ct = 'application/vnd.oasis.opendocument.presentation',
+ type = 'office'
+ },
+ -- https://en.wikipedia.org/wiki/Associated_Signature_Containers
+ asice = {
+ ct = 'application/vnd.etsi.asic-e+zip',
+ type = 'office'
+ },
+ asics = {
+ ct = 'application/vnd.etsi.asic-s+zip',
+ type = 'office'
+ },
+ -- other
+ pgp = {
+ ct = 'application/encrypted',
+ type = 'encrypted'
+ },
+ uue = {
+ ct = 'application/x-uuencoded',
+ type = 'binary',
+ },
+ -- Types that are detected by Rspamd itself
+ -- Archives
+ zip = {
+ ct = 'application/zip',
+ type = 'archive',
+ },
+ rar = {
+ ct = 'application/x-rar',
+ type = 'archive',
+ },
+ ['7z'] = {
+ ct = 'application/x-7z-compressed',
+ type = 'archive',
+ },
+ gz = {
+ ct = 'application/gzip',
+ type = 'archive',
+ },
+ -- Images
+ png = {
+ ct = 'image/png',
+ type = 'image',
+ av_check = false,
+ },
+ gif = {
+ ct = 'image/gif',
+ type = 'image',
+ av_check = false,
+ },
+ jpg = {
+ ct = 'image/jpeg',
+ type = 'image',
+ av_check = false,
+ },
+ bmp = {
+ type = 'image',
+ ct = 'image/bmp',
+ av_check = false,
+ },
+ dwg = {
+ type = 'image',
+ ct = 'image/vnd.dwg',
+ },
+ -- Text
+ xml = {
+ ct = 'application/xml',
+ type = 'text',
+ no_text = true,
+ },
+ txt = {
+ type = 'text',
+ ct = 'text/plain',
+ av_check = false,
+ },
+ html = {
+ type = 'text',
+ ct = 'text/html',
+ av_check = false,
+ },
+ csv = {
+ type = 'text',
+ ct = 'text/csv',
+ av_check = false,
+ no_text = true,
+ },
+ ics = {
+ type = 'text',
+ ct = 'text/calendar',
+ av_check = false,
+ no_text = true,
+ },
+ vcf = {
+ type = 'text',
+ ct = 'text/vcard',
+ av_check = false,
+ no_text = true,
+ },
+ eml = {
+ type = 'message',
+ ct = 'message/rfc822',
+ av_check = false,
+ },
+}
+
+return types \ No newline at end of file
diff --git a/lualib/lua_maps.lua b/lualib/lua_maps.lua
new file mode 100644
index 0000000..d357310
--- /dev/null
+++ b/lualib/lua_maps.lua
@@ -0,0 +1,612 @@
+--[[[
+-- @module lua_maps
+-- This module contains helper functions for managing rspamd maps
+--]]
+
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local rspamd_logger = require "rspamd_logger"
+local ts = require("tableshape").types
+local lua_util = require "lua_util"
+
+local exports = {}
+
+local maps_cache = {}
+
+local function map_hash_key(data, mtype)
+ local hash = require "rspamd_cryptobox_hash"
+ local st = hash.create_specific('xxh64')
+ st:update(data)
+ st:update(mtype)
+
+ return st:hex()
+end
+
+local function starts(where, st)
+ return string.sub(where, 1, string.len(st)) == st
+end
+
+local function cut_prefix(where, st)
+ return string.sub(where, #st + 1)
+end
+
+local function maybe_adjust_type(data, mtype)
+ local function check_prefix(prefix, t)
+ if starts(data, prefix) then
+ data = cut_prefix(data, prefix)
+ mtype = t
+
+ return true
+ end
+
+ return false
+ end
+
+ local known_types = {
+ { 'regexp;', 'regexp' },
+ { 're;', 'regexp' },
+ { 'regexp_multi;', 'regexp_multi' },
+ { 're_multi;', 'regexp_multi' },
+ { 'glob;', 'glob' },
+ { 'glob_multi;', 'glob_multi' },
+ { 'radix;', 'radix' },
+ { 'ipnet;', 'radix' },
+ { 'set;', 'set' },
+ { 'hash;', 'hash' },
+ { 'plain;', 'hash' },
+ { 'cdb;', 'cdb' },
+ { 'cdb:/', 'cdb' },
+ }
+
+ if mtype == 'callback' then
+ return mtype
+ end
+
+ for _, t in ipairs(known_types) do
+ if check_prefix(t[1], t[2]) then
+ return data, mtype
+ end
+ end
+
+ -- No change
+ return data, mtype
+end
+
+local external_map_schema = ts.shape {
+ external = ts.equivalent(true), -- must be true
+ backend = ts.string, -- where to get data, required
+ method = ts.one_of { "body", "header", "query" }, -- how to pass input
+ encode = ts.one_of { "json", "messagepack" }:is_optional(), -- how to encode input (if relevant)
+ timeout = (ts.number + ts.string / lua_util.parse_time_interval):is_optional(),
+}
+
+local rspamd_http = require "rspamd_http"
+local ucl = require "ucl"
+
+local function url_encode_string(str)
+ str = string.gsub(str, "([^%w _%%%-%.~])",
+ function(c)
+ return string.format("%%%02X", string.byte(c))
+ end)
+ str = string.gsub(str, " ", "+")
+ return str
+end
+
+assert(url_encode_string('上海+中國') == '%E4%B8%8A%E6%B5%B7%2B%E4%B8%AD%E5%9C%8B')
+assert(url_encode_string('? and the Mysterians') == '%3F+and+the+Mysterians')
+
+local function query_external_map(map_config, upstreams, key, callback, task)
+ local http_method = (map_config.method == 'body' or map_config.method == 'form') and 'POST' or 'GET'
+ local upstream = upstreams:get_upstream_round_robin()
+ local http_headers = {
+ ['Accept'] = '*/*'
+ }
+ local http_body = nil
+ local url = map_config.backend
+
+ if type(key) == 'string' or type(key) == 'userdata' then
+ if map_config.method == 'body' then
+ http_body = key
+ http_headers['Content-Type'] = 'text/plain'
+ elseif map_config.method == 'header' then
+ http_headers = {
+ key = key
+ }
+ elseif map_config.method == 'query' then
+ url = string.format('%s?key=%s', url, url_encode_string(tostring(key)))
+ end
+ elseif type(key) == 'table' then
+ if map_config.method == 'body' then
+ if map_config.encode == 'json' then
+ http_body = ucl.to_format(key, 'json-compact', true)
+ http_headers['Content-Type'] = 'application/json'
+ elseif map_config.encode == 'messagepack' then
+ http_body = ucl.to_format(key, 'messagepack', true)
+ http_headers['Content-Type'] = 'application/msgpack'
+ else
+ local caller = debug.getinfo(2) or {}
+ rspamd_logger.errx(task,
+ "requested external map key with a wrong combination body method and missing encode; caller: %s:%s",
+ caller.short_src, caller.currentline)
+ callback(false, 'invalid map usage', 500, task)
+ end
+ else
+ -- query/header and no encode
+ if map_config.method == 'query' then
+ local params_table = {}
+ for k, v in pairs(key) do
+ if type(v) == 'string' then
+ table.insert(params_table, string.format('%s=%s', url_encode_string(k), url_encode_string(v)))
+ end
+ end
+ url = string.format('%s?%s', url, table.concat(params_table, '&'))
+ elseif map_config.method == 'header' then
+ http_headers = key
+ else
+ local caller = debug.getinfo(2) or {}
+ rspamd_logger.errx(task,
+ "requested external map key with a wrong combination of encode and input; caller: %s:%s",
+ caller.short_src, caller.currentline)
+ callback(false, 'invalid map usage', 500, task)
+ return
+ end
+ end
+ end
+
+ local function map_callback(err, code, body, _)
+ if err then
+ callback(false, err, code, task)
+ elseif code == 200 then
+ callback(true, body, 200, task)
+ else
+ callback(false, err, code, task)
+ end
+ end
+
+ local ret = rspamd_http.request {
+ task = task,
+ url = url,
+ callback = map_callback,
+ timeout = map_config.timeout or 1.0,
+ keepalive = true,
+ upstream = upstream,
+ method = http_method,
+ headers = http_headers,
+ body = http_body,
+ }
+
+ if not ret then
+ callback(false, 'http request error', 500, task)
+ end
+end
+
+--[[[
+-- @function lua_maps.map_add_from_ucl(opt, mtype, description)
+-- Creates a map from static data
+-- Returns true if map was added or nil
+-- @param {string or table} opt data for map (or URL)
+-- @param {string} mtype type of map (`set`, `map`, `radix`, `regexp`)
+-- @param {string} description human-readable description of map
+-- @param {function} callback optional callback that will be called on map match (required for external maps)
+-- @return {bool} true on success, or `nil`
+--]]
+local function rspamd_map_add_from_ucl(opt, mtype, description, callback)
+ local ret = {
+ get_key = function(t, k, key_callback, task)
+ if t.__data then
+ local cb = key_callback or callback
+ if t.__external then
+ if not cb or not task then
+ local caller = debug.getinfo(2) or {}
+ rspamd_logger.errx(rspamd_config, "requested external map key without callback or task; caller: %s:%s",
+ caller.short_src, caller.currentline)
+ return nil
+ end
+ query_external_map(t.__data, t.__upstreams, k, cb, task)
+ else
+ local result = t.__data:get_key(k)
+ if cb then
+ if result then
+ cb(true, result, 200, task)
+ else
+ cb(false, 'not found', 404, task)
+ end
+ else
+ return result
+ end
+ end
+ end
+
+ return nil
+ end,
+ foreach = function(t, cb)
+ return t.__data:foreach(cb)
+ end,
+ on_load = function(t, cb)
+ t.__data:on_load(cb)
+ end
+ }
+ local ret_mt = {
+ __index = function(t, k, key_callback, task)
+ if t.__data then
+ return t.get_key(k, key_callback, task)
+ end
+
+ return nil
+ end
+ }
+
+ if not opt then
+ return nil
+ end
+
+ local function maybe_register_selector()
+ if opt.selector_alias then
+ local lua_selectors = require "lua_selectors"
+ lua_selectors.add_map(opt.selector_alias, ret)
+ end
+ end
+
+ if type(opt) == 'string' then
+ opt, mtype = maybe_adjust_type(opt, mtype)
+ local cache_key = map_hash_key(opt, mtype)
+ if not callback and maps_cache[cache_key] then
+ rspamd_logger.infox(rspamd_config, 'reuse url for %s(%s)',
+ opt, mtype)
+
+ return maps_cache[cache_key]
+ end
+ -- We have a single string, so we treat it as a map
+ local map = rspamd_config:add_map {
+ type = mtype,
+ description = description,
+ url = opt,
+ }
+
+ if map then
+ ret.__data = map
+ ret.hash = cache_key
+ setmetatable(ret, ret_mt)
+ maps_cache[cache_key] = ret
+ return ret
+ end
+ elseif type(opt) == 'table' then
+ local cache_key = lua_util.table_digest(opt)
+ if not callback and maps_cache[cache_key] then
+ rspamd_logger.infox(rspamd_config, 'reuse url for complex map definition %s: %s',
+ cache_key:sub(1, 8), description)
+
+ return maps_cache[cache_key]
+ end
+
+ if opt[1] then
+ -- Adjust each element if needed
+ local adjusted
+ for i, source in ipairs(opt) do
+ local nsrc, ntype = maybe_adjust_type(source, mtype)
+
+ if mtype ~= ntype then
+ if not adjusted then
+ mtype = ntype
+ end
+ adjusted = true
+ end
+ opt[i] = nsrc
+ end
+
+ if mtype == 'radix' then
+
+ if string.find(opt[1], '^%d') then
+ local map = rspamd_config:radix_from_ucl(opt)
+
+ if map then
+ ret.__data = map
+ setmetatable(ret, ret_mt)
+ maps_cache[cache_key] = ret
+ maybe_register_selector()
+
+ return ret
+ end
+ else
+ -- Plain table
+ local map = rspamd_config:add_map {
+ type = mtype,
+ description = description,
+ url = opt,
+ }
+ if map then
+ ret.__data = map
+ setmetatable(ret, ret_mt)
+ maps_cache[cache_key] = ret
+ maybe_register_selector()
+
+ return ret
+ end
+ end
+ elseif mtype == 'regexp' or mtype == 'glob' then
+ if string.find(opt[1], '^/%a') or string.find(opt[1], '^http') then
+ -- Plain table
+ local map = rspamd_config:add_map {
+ type = mtype,
+ description = description,
+ url = opt,
+ }
+ if map then
+ ret.__data = map
+ setmetatable(ret, ret_mt)
+ maps_cache[cache_key] = ret
+ maybe_register_selector()
+
+ return ret
+ end
+ else
+ local map = rspamd_config:add_map {
+ type = mtype,
+ description = description,
+ url = {
+ url = 'static',
+ data = opt,
+ }
+ }
+ if map then
+ ret.__data = map
+ setmetatable(ret, ret_mt)
+ maps_cache[cache_key] = ret
+ maybe_register_selector()
+
+ return ret
+ end
+ end
+ else
+ if string.find(opt[1], '^/%a') or string.find(opt[1], '^http') then
+ -- Plain table
+ local map = rspamd_config:add_map {
+ type = mtype,
+ description = description,
+ url = opt,
+ }
+ if map then
+ ret.__data = map
+ setmetatable(ret, ret_mt)
+ maps_cache[cache_key] = ret
+ maybe_register_selector()
+
+ return ret
+ end
+ else
+ local data = {}
+ local nelts = 0
+ -- Plain array of keys, count merely numeric elts
+ for _, elt in ipairs(opt) do
+ if type(elt) == 'string' then
+ -- Numeric table
+ if mtype == 'hash' then
+ -- Treat as KV pair
+ local pieces = lua_util.str_split(elt, ' ')
+ if #pieces > 1 then
+ local key = table.remove(pieces, 1)
+ data[key] = table.concat(pieces, ' ')
+ else
+ data[elt] = true
+ end
+ else
+ data[elt] = true
+ end
+
+ nelts = nelts + 1
+ end
+ end
+
+ if nelts > 0 then
+ -- Plain Lua table that is used as a map
+ ret.__data = data
+ ret.get_key = function(t, k)
+ if k ~= '__data' then
+ return t.__data[k]
+ end
+
+ return nil
+ end
+ ret.foreach = function(_, func)
+ for k, v in pairs(ret.__data) do
+ if not func(k, v) then
+ return false
+ end
+ end
+
+ return true
+ end
+ ret.on_load = function(_, cb)
+ rspamd_config:add_on_load(function(_, _, _)
+ cb()
+ end)
+ end
+
+ maps_cache[cache_key] = ret
+ maybe_register_selector()
+
+ return ret
+ else
+ -- Empty map, huh?
+ rspamd_logger.errx(rspamd_config, 'invalid map element: %s',
+ opt)
+ end
+ end
+ end
+ else
+ if opt.external then
+ -- External map definition, missing fields are handled by schema
+ local parse_res, parse_err = external_map_schema(opt)
+
+ if parse_res then
+ ret.__upstreams = lua_util.http_upstreams_by_url(rspamd_config:get_mempool(), opt.backend)
+ if ret.__upstreams then
+ ret.__data = opt
+ ret.__external = true
+ setmetatable(ret, ret_mt)
+ maybe_register_selector()
+
+ return ret
+ else
+ rspamd_logger.errx(rspamd_config, 'cannot parse external map upstreams: %s',
+ opt.backend)
+ end
+ else
+ rspamd_logger.errx(rspamd_config, 'cannot parse external map: %s',
+ parse_err)
+ end
+ else
+ -- Adjust lua specific augmentations in a trivial case
+ if type(opt.url) == 'string' then
+ local nsrc, ntype = maybe_adjust_type(opt.url, mtype)
+ if nsrc and ntype then
+ opt.url = nsrc
+ mtype = ntype
+ end
+ end
+ -- We have some non-trivial object so let C code to deal with it somehow...
+ local map = rspamd_config:add_map {
+ type = mtype,
+ description = description,
+ url = opt,
+ }
+ if map then
+ ret.__data = map
+ setmetatable(ret, ret_mt)
+ maps_cache[cache_key] = ret
+ maybe_register_selector()
+
+ return ret
+ end
+ end
+ end -- opt[1]
+ end
+
+ return nil
+end
+
+--[[[
+-- @function lua_maps.map_add(mname, optname, mtype, description)
+-- Creates a map from configuration elements (static data or URL)
+-- Returns true if map was added or nil
+-- @param {string} mname config section to use
+-- @param {string} optname option name to use
+-- @param {string} mtype type of map ('set', 'hash', 'radix', 'regexp', 'glob')
+-- @param {string} description human-readable description of map
+-- @param {function} callback optional callback that will be called on map match (required for external maps)
+-- @return {bool} true on success, or `nil`
+--]]
+
+local function rspamd_map_add(mname, optname, mtype, description, callback)
+ local opt = rspamd_config:get_module_opt(mname, optname)
+
+ return rspamd_map_add_from_ucl(opt, mtype, description, callback)
+end
+
+exports.rspamd_map_add = rspamd_map_add
+exports.map_add = rspamd_map_add
+exports.rspamd_map_add_from_ucl = rspamd_map_add_from_ucl
+exports.map_add_from_ucl = rspamd_map_add_from_ucl
+
+-- Check `what` for being lua_map name, otherwise just compares key with what
+local function rspamd_maybe_check_map(key, what)
+ local fun = require "fun"
+
+ if type(what) == "table" then
+ return fun.any(function(elt)
+ return rspamd_maybe_check_map(key, elt)
+ end, what)
+ end
+ if type(rspamd_maps) == "table" then
+ local mn
+ if starts(key, "map:") then
+ mn = string.sub(key, 5)
+ elseif starts(key, "map://") then
+ mn = string.sub(key, 7)
+ end
+
+ if mn and rspamd_maps[mn] then
+ return rspamd_maps[mn]:get_key(what)
+ end
+ end
+
+ return what:lower() == key
+end
+
+exports.rspamd_maybe_check_map = rspamd_maybe_check_map
+
+--[[[
+-- @function lua_maps.fill_config_maps(mname, options, defs)
+-- Fill maps that could be defined in defs, from the config in the options
+-- Defs is a table indexed by a map's parameter name and defining it's config,
+-- @example
+-- defs = {
+-- my_map = {
+-- type = 'map',
+-- description = 'my cool map',
+-- optional = true,
+-- }
+-- }
+-- -- Then this function will look for opts.my_map parameter and try to replace it with
+-- -- a map with the specific type, description but not failing if it was empty.
+-- -- It will also set options.my_map_orig to the original value defined in the map.
+--]]
+exports.fill_config_maps = function(mname, opts, map_defs)
+ assert(type(opts) == 'table')
+ assert(type(map_defs) == 'table')
+ for k, v in pairs(map_defs) do
+ if opts[k] then
+ local map = rspamd_map_add_from_ucl(opts[k], v.type or 'map', v.description)
+ if not map then
+ rspamd_logger.errx(rspamd_config, 'map add error %s for module %s', k, mname)
+ return false
+ end
+ opts[k .. '_orig'] = opts[k]
+ opts[k] = map
+ elseif not v.optional then
+ rspamd_logger.errx(rspamd_config, 'cannot find non optional map %s for module %s', k, mname)
+ return false
+ end
+ end
+
+ return true
+end
+
+local direct_map_schema = ts.shape { -- complex object
+ name = ts.string:is_optional(),
+ description = ts.string:is_optional(),
+ selector_alias = ts.string:is_optional(), -- an optional alias for the selectos framework
+ timeout = ts.number,
+ data = ts.array_of(ts.string):is_optional(),
+ -- Tableshape has no options support for something like key1 or key2?
+ upstreams = ts.one_of {
+ ts.string,
+ ts.array_of(ts.string),
+ } :is_optional(),
+ url = ts.one_of {
+ ts.string,
+ ts.array_of(ts.string),
+ } :is_optional(),
+}
+
+exports.map_schema = ts.one_of {
+ ts.string, -- 'http://some_map'
+ ts.array_of(ts.string), -- ['foo', 'bar']
+ ts.one_of { direct_map_schema, external_map_schema }
+}
+
+return exports
diff --git a/lualib/lua_maps_expressions.lua b/lualib/lua_maps_expressions.lua
new file mode 100644
index 0000000..996de99
--- /dev/null
+++ b/lualib/lua_maps_expressions.lua
@@ -0,0 +1,219 @@
+--[[[
+-- @module lua_maps_expressions
+-- This module contains routines to combine maps, selectors and expressions
+-- in a generic framework
+@example
+whitelist_ip_from = {
+ rules {
+ ip {
+ selector = "ip";
+ map = "/path/to/whitelist_ip.map";
+ }
+ from {
+ selector = "from(smtp)";
+ map = "/path/to/whitelist_from.map";
+ }
+ }
+ expression = "ip & from";
+}
+--]]
+
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local lua_selectors = require "lua_selectors"
+local lua_maps = require "lua_maps"
+local rspamd_expression = require "rspamd_expression"
+local rspamd_logger = require "rspamd_logger"
+local fun = require "fun"
+local ts = require("tableshape").types
+
+local exports = {}
+
+local function process_func(elt, task)
+ local matched = {}
+ local function process_atom(atom)
+ local rule = elt.rules[atom]
+ local res = 0
+
+ local function match_rule(val)
+ local map_match = rule.map:get_key(val)
+ if map_match then
+ res = 1.0
+ matched[rule.name] = {
+ matched = val,
+ value = map_match
+ }
+ end
+ end
+
+ local values = rule.selector(task)
+
+ if values then
+ if type(values) == 'table' then
+ for _, val in ipairs(values) do
+ if res == 0 then
+ match_rule(val)
+ end
+ end
+ else
+ match_rule(values)
+ end
+ end
+
+ return res
+ end
+
+ local res = elt.expr:process(process_atom)
+
+ if res > 0 then
+ return res, matched
+ end
+
+ return nil
+end
+
+exports.schema = ts.shape {
+ expression = ts.string,
+ rules = ts.array_of(
+ ts.shape {
+ selector = ts.string,
+ map = lua_maps.map_schema,
+ }
+ )
+}
+
+--[[[
+-- @function lua_maps_expression.create(config, object, module_name)
+-- Creates a new maps combination from `object` for `module_name`.
+-- The input should be table with the following fields:
+--
+-- * `rules` - kv map of rules where each rule has `map` and `selector` mandatory attribute, also `type` for map type, e.g. `regexp`
+-- * `expression` - Rspamd expression where elements are names from `rules` field, e.g. `ip & from`
+--
+-- This function returns an object with public method `process(task)` that checks
+-- a task for the conditions defined in `expression` and `rules` and returns 2 values:
+--
+-- 1. value returned by an expression (e.g. 1 or 0)
+-- 2. an map (rule_name -> table) of matches, where each element has the following fields:
+-- * `matched` - selector's value
+-- * `value` - map's result
+--
+-- In case if `expression` is false a `nil` value is returned.
+-- @param {rspamd_config} cfg rspamd config
+-- @param {table} obj configuration table
+--
+--]]
+local function create(cfg, obj, module_name)
+ if not module_name then
+ module_name = 'lua_maps_expressions'
+ end
+
+ if not obj or not obj.rules or not obj.expression then
+ rspamd_logger.errx(cfg, 'cannot add maps combination for module %s: required elements are missing',
+ module_name)
+ return nil
+ end
+
+ local ret = {
+ process = process_func,
+ rules = {},
+ module_name = module_name
+ }
+
+ for name, rule in pairs(obj.rules) do
+ local sel = lua_selectors.create_selector_closure(cfg, rule.selector)
+
+ if not sel then
+ rspamd_logger.errx(cfg, 'cannot add selector for element %s in module %s',
+ name, module_name)
+ end
+
+ if not rule.type then
+ -- Guess type
+ if name:find('ip') or name:find('ipnet') then
+ rule.type = 'radix'
+ elseif name:find('regexp') or name:find('re_') then
+ rule.type = 'regexp'
+ elseif name:find('glob') then
+ rule.type = 'regexp'
+ else
+ rule.type = 'set'
+ end
+ end
+ local map = lua_maps.map_add_from_ucl(rule.map, rule.type,
+ obj.description or module_name)
+ if not map then
+ rspamd_logger.errx(cfg, 'cannot add map for element %s in module %s',
+ name, module_name)
+ end
+
+ if sel and map then
+ ret.rules[name] = {
+ selector = sel,
+ map = map,
+ name = name,
+ }
+ else
+ return nil
+ end
+ end
+
+ -- Now process and parse expression
+ local function parse_atom(str)
+ local atom = table.concat(fun.totable(fun.take_while(function(c)
+ if string.find(', \t()><+!|&\n', c, 1, true) then
+ return false
+ end
+ return true
+ end, fun.iter(str))), '')
+
+ if ret.rules[atom] then
+ return atom
+ end
+
+ rspamd_logger.errx(cfg, 'use of undefined element "%s" when parsing maps expression for %s',
+ atom, module_name)
+
+ return nil
+ end
+ local expr = rspamd_expression.create(obj.expression, parse_atom,
+ rspamd_config:get_mempool())
+
+ if not expr then
+ rspamd_logger.errx(cfg, 'cannot add map expression for module %s',
+ module_name)
+ return nil
+ end
+
+ ret.expr = expr
+
+ if obj.symbol then
+ rspamd_config:register_symbol {
+ type = 'virtual,ghost',
+ name = obj.symbol,
+ score = 0.0,
+ }
+ end
+
+ ret.symbol = obj.symbol
+
+ return ret
+end
+
+exports.create = create
+
+return exports \ No newline at end of file
diff --git a/lualib/lua_meta.lua b/lualib/lua_meta.lua
new file mode 100644
index 0000000..340d89e
--- /dev/null
+++ b/lualib/lua_meta.lua
@@ -0,0 +1,549 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local exports = {}
+
+local N = "metatokens"
+local ts = require("tableshape").types
+local logger = require "rspamd_logger"
+
+-- Metafunctions
+local function meta_size_function(task)
+ local sizes = {
+ 100,
+ 200,
+ 500,
+ 1000,
+ 2000,
+ 4000,
+ 10000,
+ 20000,
+ 30000,
+ 100000,
+ 200000,
+ 400000,
+ 800000,
+ 1000000,
+ 2000000,
+ 8000000,
+ }
+
+ local size = task:get_size()
+ for i = 1, #sizes do
+ if sizes[i] >= size then
+ return { (1.0 * i) / #sizes }
+ end
+ end
+
+ return { 0 }
+end
+
+local function meta_images_function(task)
+ local images = task:get_images()
+ local ntotal = 0
+ local njpg = 0
+ local npng = 0
+ local nlarge = 0
+ local nsmall = 0
+
+ if images then
+ for _, img in ipairs(images) do
+ if img:get_type() == 'png' then
+ npng = npng + 1
+ elseif img:get_type() == 'jpeg' then
+ njpg = njpg + 1
+ end
+
+ local w = img:get_width()
+ local h = img:get_height()
+
+ if w > 0 and h > 0 then
+ if w + h > 256 then
+ nlarge = nlarge + 1
+ else
+ nsmall = nsmall + 1
+ end
+ end
+
+ ntotal = ntotal + 1
+ end
+ end
+ if ntotal > 0 then
+ njpg = 1.0 * njpg / ntotal
+ npng = 1.0 * npng / ntotal
+ nlarge = 1.0 * nlarge / ntotal
+ nsmall = 1.0 * nsmall / ntotal
+ end
+ return { ntotal, njpg, npng, nlarge, nsmall }
+end
+
+local function meta_nparts_function(task)
+ local nattachments = 0
+ local ntextparts = 0
+ local totalparts = 1
+
+ local tp = task:get_text_parts()
+ if tp then
+ ntextparts = #tp
+ end
+
+ local parts = task:get_parts()
+
+ if parts then
+ for _, p in ipairs(parts) do
+ if p:is_attachment() then
+ nattachments = nattachments + 1
+ end
+ totalparts = totalparts + 1
+ end
+ end
+
+ return { (1.0 * ntextparts) / totalparts, (1.0 * nattachments) / totalparts }
+end
+
+local function meta_encoding_function(task)
+ local nutf = 0
+ local nother = 0
+
+ local tp = task:get_text_parts()
+ if tp and #tp > 0 then
+ for _, p in ipairs(tp) do
+ if p:is_utf() then
+ nutf = nutf + 1
+ else
+ nother = nother + 1
+ end
+ end
+
+ return { nutf / #tp, nother / #tp }
+ end
+
+ return { 0, 0 }
+end
+
+local function meta_recipients_function(task)
+ local nmime = 0
+ local nsmtp = 0
+
+ if task:has_recipients('mime') then
+ nmime = #(task:get_recipients('mime'))
+ end
+ if task:has_recipients('smtp') then
+ nsmtp = #(task:get_recipients('smtp'))
+ end
+
+ if nmime > 0 then
+ nmime = 1.0 / nmime
+ end
+ if nsmtp > 0 then
+ nsmtp = 1.0 / nsmtp
+ end
+
+ return { nmime, nsmtp }
+end
+
+local function meta_received_function(task)
+ local count_factor = 0
+ local invalid_factor = 0
+ local rh = task:get_received_headers()
+ local time_factor = 0
+ local secure_factor = 0
+ local fun = require "fun"
+
+ if rh and #rh > 0 then
+
+ local ntotal = 0.0
+ local init_time = 0
+
+ fun.each(function(rc)
+ ntotal = ntotal + 1.0
+
+ if not rc.by_hostname then
+ invalid_factor = invalid_factor + 1.0
+ end
+ if init_time == 0 and rc.timestamp then
+ init_time = rc.timestamp
+ elseif rc.timestamp then
+ time_factor = time_factor + math.abs(init_time - rc.timestamp)
+ init_time = rc.timestamp
+ end
+ if rc.flags and (rc.flags['ssl'] or rc.flags['authenticated']) then
+ secure_factor = secure_factor + 1.0
+ end
+ end,
+ fun.filter(function(rc)
+ return not rc.flags or not rc.flags['artificial']
+ end, rh))
+
+ if ntotal > 0 then
+ invalid_factor = invalid_factor / ntotal
+ secure_factor = secure_factor / ntotal
+ count_factor = 1.0 / ntotal
+ end
+
+ if time_factor ~= 0 then
+ time_factor = 1.0 / time_factor
+ end
+ end
+
+ return { count_factor, invalid_factor, time_factor, secure_factor }
+end
+
+local function meta_urls_function(task)
+ local has_urls, nurls = task:has_urls()
+ if has_urls and nurls > 0 then
+ return { 1.0 / nurls }
+ end
+
+ return { 0 }
+end
+
+local function meta_words_function(task)
+ local avg_len = task:get_mempool():get_variable("avg_words_len", "double") or 0.0
+ local short_words = task:get_mempool():get_variable("short_words_cnt", "double") or 0.0
+ local ret_len = 0
+
+ local lens = {
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ 15,
+ 20,
+ }
+
+ for i = 1, #lens do
+ if lens[i] >= avg_len then
+ ret_len = (1.0 * i) / #lens
+ break
+ end
+ end
+
+ local tp = task:get_text_parts()
+ local wres = {
+ 0, -- spaces rate
+ 0, -- double spaces rate
+ 0, -- non spaces rate
+ 0, -- ascii characters rate
+ 0, -- non-ascii characters rate
+ 0, -- capital characters rate
+ 0, -- numeric characters
+ }
+ for _, p in ipairs(tp) do
+ local stats = p:get_stats()
+ local len = p:get_length()
+
+ if len > 0 then
+ wres[1] = wres[1] + stats['spaces'] / len
+ wres[2] = wres[2] + stats['double_spaces'] / len
+ wres[3] = wres[3] + stats['non_spaces'] / len
+ wres[4] = wres[4] + stats['ascii_characters'] / len
+ wres[5] = wres[5] + stats['non_ascii_characters'] / len
+ wres[6] = wres[6] + stats['capital_letters'] / len
+ wres[7] = wres[7] + stats['numeric_characters'] / len
+ end
+ end
+
+ local ret = {
+ short_words,
+ ret_len,
+ }
+
+ local divisor = 1.0
+ if #tp > 0 then
+ divisor = #tp
+ end
+
+ for _, wr in ipairs(wres) do
+ table.insert(ret, wr / divisor)
+ end
+
+ return ret
+end
+
+local metafunctions = {
+ {
+ cb = meta_size_function,
+ ninputs = 1,
+ names = {
+ "size"
+ },
+ description = 'Describes size of the message',
+ },
+ {
+ cb = meta_images_function,
+ ninputs = 5,
+ -- 1 - number of images,
+ -- 2 - number of png images,
+ -- 3 - number of jpeg images
+ -- 4 - number of large images (> 128 x 128)
+ -- 5 - number of small images (< 128 x 128)
+ names = {
+ 'nimages',
+ 'npng_images',
+ 'njpeg_images',
+ 'nlarge_images',
+ 'nsmall_images'
+ },
+ description = [[Functions for images matching:
+ - number of images,
+ - number of png images,
+ - number of jpeg images
+ - number of large images (> 128 x 128)
+ - number of small images (< 128 x 128)
+]]
+ },
+ {
+ cb = meta_nparts_function,
+ ninputs = 2,
+ -- 1 - number of text parts
+ -- 2 - number of attachments
+ names = {
+ 'ntext_parts',
+ 'nattachments'
+ },
+ description = [[Functions for images matching:
+ - number of text parts
+ - number of attachments
+]]
+ },
+ {
+ cb = meta_encoding_function,
+ ninputs = 2,
+ -- 1 - number of utf parts
+ -- 2 - number of non-utf parts
+ names = {
+ 'nutf_parts',
+ 'nascii_parts'
+ },
+ description = [[Functions for encoding matching:
+ - number of utf parts
+ - number of non-utf parts
+]]
+ },
+ {
+ cb = meta_recipients_function,
+ ninputs = 2,
+ -- 1 - number of mime rcpt
+ -- 2 - number of smtp rcpt
+ names = {
+ 'nmime_rcpt',
+ 'nsmtp_rcpt'
+ },
+ description = [[Functions for recipients data matching:
+ - number of mime rcpt
+ - number of smtp rcpt
+]]
+ },
+ {
+ cb = meta_received_function,
+ ninputs = 4,
+ names = {
+ 'nreceived',
+ 'nreceived_invalid',
+ 'nreceived_bad_time',
+ 'nreceived_secure'
+ },
+ description = [[Functions for received headers data matching:
+ - number of received headers
+ - number of bad received headers
+ - number of skewed time received headers
+ - number of received via secured relays
+]]
+ },
+ {
+ cb = meta_urls_function,
+ ninputs = 1,
+ names = {
+ 'nurls'
+ },
+ description = [[Functions for urls data matching:
+ - number of urls
+]]
+ },
+ {
+ cb = meta_words_function,
+ ninputs = 9,
+ names = {
+ 'avg_words_len',
+ 'nshort_words',
+ 'spaces_rate',
+ 'double_spaces_rate',
+ 'non_spaces_rate',
+ 'ascii_characters_rate',
+ 'non_ascii_characters_rate',
+ 'capital_characters_rate',
+ 'numeric_characters'
+ },
+ description = [[Functions for words data matching:
+ - average length of the words
+ - number of short words
+ - rate of spaces in the text
+ - rate of multiple spaces
+ - rate of non space characters
+ - rate of ascii characters
+ - rate of non-ascii characters
+ - rate of capital letters
+ - rate of numbers
+]]
+ },
+}
+
+local meta_schema = ts.shape {
+ cb = ts.func,
+ ninputs = ts.number,
+ names = ts.array_of(ts.string),
+ description = ts.string:is_optional()
+}
+
+local metatokens_by_name = {}
+
+local function fill_metatokens_by_name()
+ metatokens_by_name = {}
+
+ for _, mt in ipairs(metafunctions) do
+ for i = 1, mt.ninputs do
+ local name = mt.names[i]
+
+ metatokens_by_name[name] = function(task)
+ local results = mt.cb(task)
+ return results[i]
+ end
+ end
+ end
+end
+
+local function calculate_digest()
+ local cr = require "rspamd_cryptobox_hash"
+
+ local h = cr.create()
+ for _, mt in ipairs(metafunctions) do
+ for i = 1, mt.ninputs do
+ local name = mt.names[i]
+ h:update(name)
+ end
+ end
+
+ exports.digest = h:hex()
+end
+
+local function rspamd_gen_metatokens(task, names)
+ local lua_util = require "lua_util"
+ local ipairs = ipairs
+ local metatokens = {}
+
+ if not names then
+ local cached = task:cache_get('metatokens')
+
+ if cached then
+ return cached
+ else
+ for _, mt in ipairs(metafunctions) do
+ local ct = mt.cb(task)
+ for i, tok in ipairs(ct) do
+ lua_util.debugm(N, task, "metatoken: %s = %s",
+ mt.names[i], tok)
+ if tok ~= tok or tok == math.huge then
+ logger.errx(task, 'metatoken %s returned %s; replace it with 0 for sanity',
+ mt.names[i], tok)
+ tok = 0.0
+ end
+ table.insert(metatokens, tok)
+ end
+ end
+
+ task:cache_set('metatokens', metatokens)
+ end
+
+ else
+ for _, n in ipairs(names) do
+ if metatokens_by_name[n] then
+ local tok = metatokens_by_name[n](task)
+ if tok ~= tok or tok == math.huge then
+ logger.errx(task, 'metatoken %s returned %s; replace it with 0 for sanity',
+ n, tok)
+ tok = 0.0
+ end
+ table.insert(metatokens, tok)
+ else
+ logger.errx(task, 'unknown metatoken: %s', n)
+ end
+ end
+ end
+
+ return metatokens
+end
+
+exports.rspamd_gen_metatokens = rspamd_gen_metatokens
+exports.gen_metatokens = rspamd_gen_metatokens
+
+local function rspamd_gen_metatokens_table(task)
+ local metatokens = {}
+
+ for _, mt in ipairs(metafunctions) do
+ local ct = mt.cb(task)
+ for i, tok in ipairs(ct) do
+ if tok ~= tok or tok == math.huge then
+ logger.errx(task, 'metatoken %s returned %s; replace it with 0 for sanity',
+ mt.names[i], tok)
+ tok = 0.0
+ end
+
+ metatokens[mt.names[i]] = tok
+ end
+ end
+
+ return metatokens
+end
+
+exports.rspamd_gen_metatokens_table = rspamd_gen_metatokens_table
+exports.gen_metatokens_table = rspamd_gen_metatokens_table
+
+local function rspamd_count_metatokens()
+ local ipairs = ipairs
+ local total = 0
+ for _, mt in ipairs(metafunctions) do
+ total = total + mt.ninputs
+ end
+
+ return total
+end
+
+exports.rspamd_count_metatokens = rspamd_count_metatokens
+exports.count_metatokens = rspamd_count_metatokens
+exports.version = 1 -- MUST be increased on each change of metatokens
+
+exports.add_metafunction = function(tbl)
+ local ret, err = meta_schema(tbl)
+
+ if not ret then
+ logger.errx('cannot add metafunction: %s', err)
+ else
+ table.insert(metafunctions, tbl)
+ fill_metatokens_by_name()
+ calculate_digest()
+ end
+end
+
+fill_metatokens_by_name()
+calculate_digest()
+
+return exports
diff --git a/lualib/lua_mime.lua b/lualib/lua_mime.lua
new file mode 100644
index 0000000..0f5aa75
--- /dev/null
+++ b/lualib/lua_mime.lua
@@ -0,0 +1,760 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_mime
+-- This module contains helper functions to modify mime parts
+--]]
+
+local logger = require "rspamd_logger"
+local rspamd_util = require "rspamd_util"
+local rspamd_text = require "rspamd_text"
+local ucl = require "ucl"
+
+local exports = {}
+
+local function newline(task)
+ local t = task:get_newlines_type()
+
+ if t == 'cr' then
+ return '\r'
+ elseif t == 'lf' then
+ return '\n'
+ end
+
+ return '\r\n'
+end
+
+local function do_append_footer(task, part, footer, is_multipart, out, state)
+ local tp = part:get_text()
+ local ct = 'text/plain'
+ local cte = 'quoted-printable'
+ local newline_s = state.newline_s
+
+ if tp:is_html() then
+ ct = 'text/html'
+ end
+
+ local encode_func = function(input)
+ return rspamd_util.encode_qp(input, 80, task:get_newlines_type())
+ end
+
+ if part:get_cte() == '7bit' then
+ cte = '7bit'
+ encode_func = function(input)
+ if type(input) == 'userdata' then
+ return input
+ else
+ return rspamd_text.fromstring(input)
+ end
+ end
+ end
+
+ if is_multipart then
+ out[#out + 1] = string.format('Content-Type: %s; charset=utf-8%s' ..
+ 'Content-Transfer-Encoding: %s',
+ ct, newline_s, cte)
+ out[#out + 1] = ''
+ else
+ state.new_cte = cte
+ end
+
+ local content = tp:get_content('raw_utf') or ''
+ local double_nline = newline_s .. newline_s
+ local nlen = #double_nline
+ -- Hack, if part ends with 2 newline, then we append it after footer
+ if content:sub(-(nlen), nlen + 1) == double_nline then
+ -- content without last newline
+ content = content:sub(-(#newline_s), #newline_s + 1) .. footer
+ out[#out + 1] = { encode_func(content), true }
+ out[#out + 1] = ''
+ else
+ content = content .. footer
+ out[#out + 1] = { encode_func(content), true }
+ out[#out + 1] = ''
+ end
+
+end
+
+--[[[
+-- @function lua_mime.add_text_footer(task, html_footer, text_footer)
+-- Adds a footer to all text parts in a message. It returns a table with the following
+-- fields:
+-- * out: new content (body only)
+-- * need_rewrite_ct: boolean field that means if we must rewrite content type
+-- * new_ct: new content type (type => string, subtype => string)
+-- * new_cte: new content-transfer encoding (string)
+--]]
+exports.add_text_footer = function(task, html_footer, text_footer)
+ local newline_s = newline(task)
+ local state = {
+ newline_s = newline_s
+ }
+ local out = {}
+ local text_parts = task:get_text_parts()
+
+ if not (html_footer or text_footer) or not (text_parts and #text_parts > 0) then
+ return false
+ end
+
+ if html_footer or text_footer then
+ -- We need to take extra care about content-type and cte
+ local ct = task:get_header('Content-Type')
+ if ct then
+ ct = rspamd_util.parse_content_type(ct, task:get_mempool())
+ end
+
+ if ct then
+ if ct.type and ct.type == 'text' then
+ if ct.subtype then
+ if html_footer and (ct.subtype == 'html' or ct.subtype == 'htm') then
+ state.need_rewrite_ct = true
+ elseif text_footer and ct.subtype == 'plain' then
+ state.need_rewrite_ct = true
+ end
+ else
+ if text_footer then
+ state.need_rewrite_ct = true
+ end
+ end
+
+ state.new_ct = ct
+ end
+ else
+
+ if text_parts then
+
+ if #text_parts == 1 then
+ state.need_rewrite_ct = true
+ state.new_ct = {
+ type = 'text',
+ subtype = 'plain'
+ }
+ elseif #text_parts > 1 then
+ -- XXX: in fact, it cannot be
+ state.new_ct = {
+ type = 'multipart',
+ subtype = 'mixed'
+ }
+ end
+ end
+ end
+ end
+
+ local boundaries = {}
+ local cur_boundary
+ for _, part in ipairs(task:get_parts()) do
+ local boundary = part:get_boundary()
+ if part:is_multipart() then
+ if cur_boundary then
+ out[#out + 1] = string.format('--%s',
+ boundaries[#boundaries])
+ end
+
+ boundaries[#boundaries + 1] = boundary or '--XXX'
+ cur_boundary = boundary
+
+ local rh = part:get_raw_headers()
+ if #rh > 0 then
+ out[#out + 1] = { rh, true }
+ end
+ elseif part:is_message() then
+ if boundary then
+ if cur_boundary and boundary ~= cur_boundary then
+ -- Need to close boundary
+ out[#out + 1] = string.format('--%s--%s',
+ boundaries[#boundaries], newline_s)
+ table.remove(boundaries)
+ cur_boundary = nil
+ end
+ out[#out + 1] = string.format('--%s',
+ boundary)
+ end
+
+ out[#out + 1] = { part:get_raw_headers(), true }
+ else
+ local append_footer = false
+ local skip_footer = part:is_attachment()
+
+ local parent = part:get_parent()
+ if parent then
+ local t, st = parent:get_type()
+
+ if t == 'multipart' and st == 'signed' then
+ -- Do not modify signed parts
+ skip_footer = true
+ end
+ end
+ if text_footer and part:is_text() then
+ local tp = part:get_text()
+
+ if not tp:is_html() then
+ append_footer = text_footer
+ end
+ end
+
+ if html_footer and part:is_text() then
+ local tp = part:get_text()
+
+ if tp:is_html() then
+ append_footer = html_footer
+ end
+ end
+
+ if boundary then
+ if cur_boundary and boundary ~= cur_boundary then
+ -- Need to close boundary
+ out[#out + 1] = string.format('--%s--%s',
+ boundaries[#boundaries], newline_s)
+ table.remove(boundaries)
+ cur_boundary = boundary
+ end
+ out[#out + 1] = string.format('--%s',
+ boundary)
+ end
+
+ if append_footer and not skip_footer then
+ do_append_footer(task, part, append_footer,
+ parent and parent:is_multipart(), out, state)
+ else
+ out[#out + 1] = { part:get_raw_headers(), true }
+ out[#out + 1] = { part:get_raw_content(), false }
+ end
+ end
+ end
+
+ -- Close remaining
+ local b = table.remove(boundaries)
+ while b do
+ out[#out + 1] = string.format('--%s--', b)
+ if #boundaries > 0 then
+ out[#out + 1] = ''
+ end
+ b = table.remove(boundaries)
+ end
+
+ state.out = out
+
+ return state
+end
+
+local function do_replacement (task, part, mp, replacements,
+ is_multipart, out, state)
+
+ local tp = part:get_text()
+ local ct = 'text/plain'
+ local cte = 'quoted-printable'
+ local newline_s = state.newline_s
+
+ if tp:is_html() then
+ ct = 'text/html'
+ end
+
+ local encode_func = function(input)
+ return rspamd_util.encode_qp(input, 80, task:get_newlines_type())
+ end
+
+ if part:get_cte() == '7bit' then
+ cte = '7bit'
+ encode_func = function(input)
+ if type(input) == 'userdata' then
+ return input
+ else
+ return rspamd_text.fromstring(input)
+ end
+ end
+ end
+
+ local content = tp:get_content('raw_utf') or rspamd_text.fromstring('')
+ local match_pos = mp:match(content, true)
+
+ if match_pos then
+ -- sort matches and form the table:
+ -- start .. end for inclusion position
+ local matches_flattened = {}
+ for npat, matches in pairs(match_pos) do
+ for _, m in ipairs(matches) do
+ table.insert(matches_flattened, { m, npat })
+ end
+ end
+
+ -- Handle the case of empty match
+ if #matches_flattened == 0 then
+ out[#out + 1] = { part:get_raw_headers(), true }
+ out[#out + 1] = { part:get_raw_content(), false }
+
+ return
+ end
+
+ if is_multipart then
+ out[#out + 1] = { string.format('Content-Type: %s; charset="utf-8"%s' ..
+ 'Content-Transfer-Encoding: %s',
+ ct, newline_s, cte), true }
+ out[#out + 1] = { '', true }
+ else
+ state.new_cte = cte
+ end
+
+ state.has_matches = true
+ -- now sort flattened by start of match and eliminate all overlaps
+ table.sort(matches_flattened, function(m1, m2)
+ return m1[1][1] < m2[1][1]
+ end)
+
+ for i = 1, #matches_flattened - 1 do
+ local st = matches_flattened[i][1][1] -- current start of match
+ local e = matches_flattened[i][1][2] -- current end of match
+ local max_npat = matches_flattened[i][2]
+ for j = i + 1, #matches_flattened do
+ if matches_flattened[j][1][1] == st then
+ -- overlap
+ if matches_flattened[j][1][2] > e then
+ -- larger exclusion and switch replacement
+ e = matches_flattened[j][1][2]
+ max_npat = matches_flattened[j][2]
+ end
+ else
+ break
+ end
+ end
+ -- Maximum overlap for all matches
+ for j = i, #matches_flattened do
+ if matches_flattened[j][1][1] == st then
+ if e > matches_flattened[j][1][2] then
+ matches_flattened[j][1][2] = e
+ matches_flattened[j][2] = max_npat
+ end
+ else
+ break
+ end
+ end
+ end
+ -- Off-by one: match returns 0 based positions while we use 1 based in Lua
+ for _, m in ipairs(matches_flattened) do
+ m[1][1] = m[1][1] + 1
+ m[1][2] = m[1][2] + 1
+ end
+
+ -- Now flattened match table is sorted by start pos and has the maximum overlapped pattern
+ -- Matches with the same start and end are covering the same replacement
+ -- e.g. we had something like [1 .. 2] -> replacement 1 and [1 .. 4] -> replacement 2
+ -- after flattening we should have [1 .. 4] -> 2 and [1 .. 4] -> 2
+ -- we can safely ignore those duplicates in the following code
+
+ local cur_start = 1
+ local fragments = {}
+ for _, m in ipairs(matches_flattened) do
+ if m[1][1] >= cur_start then
+ fragments[#fragments + 1] = content:sub(cur_start, m[1][1] - 1)
+ fragments[#fragments + 1] = replacements[m[2]]
+ cur_start = m[1][2] -- end of match
+ end
+ end
+
+ -- last part
+ if cur_start < #content then
+ fragments[#fragments + 1] = content:span(cur_start)
+ end
+
+ -- Final stuff
+ out[#out + 1] = { encode_func(rspamd_text.fromtable(fragments)), false }
+ else
+ -- No matches
+ out[#out + 1] = { part:get_raw_headers(), true }
+ out[#out + 1] = { part:get_raw_content(), false }
+ end
+end
+
+--[[[
+-- @function lua_mime.multipattern_text_replace(task, mp, replacements)
+-- Replaces text according to multipattern matches. It returns a table with the following
+-- fields:
+-- * out: new content (body only)
+-- * need_rewrite_ct: boolean field that means if we must rewrite content type
+-- * new_ct: new content type (type => string, subtype => string)
+-- * new_cte: new content-transfer encoding (string)
+--]]
+exports.multipattern_text_replace = function(task, mp, replacements)
+ local newline_s = newline(task)
+ local state = {
+ newline_s = newline_s
+ }
+ local out = {}
+ local text_parts = task:get_text_parts()
+
+ if not mp or not (text_parts and #text_parts > 0) then
+ return false
+ end
+
+ -- We need to take extra care about content-type and cte
+ local ct = task:get_header('Content-Type')
+ if ct then
+ ct = rspamd_util.parse_content_type(ct, task:get_mempool())
+ end
+
+ if ct then
+ if ct.type and ct.type == 'text' then
+ state.need_rewrite_ct = true
+ state.new_ct = ct
+ end
+ else
+ -- No explicit CT, need to guess
+ if text_parts then
+ if #text_parts == 1 then
+ state.need_rewrite_ct = true
+ state.new_ct = {
+ type = 'text',
+ subtype = 'plain'
+ }
+ elseif #text_parts > 1 then
+ -- XXX: in fact, it cannot be
+ state.new_ct = {
+ type = 'multipart',
+ subtype = 'mixed'
+ }
+ end
+ end
+ end
+
+ local boundaries = {}
+ local cur_boundary
+ for _, part in ipairs(task:get_parts()) do
+ local boundary = part:get_boundary()
+ if part:is_multipart() then
+ if cur_boundary then
+ out[#out + 1] = { string.format('--%s',
+ boundaries[#boundaries]), true }
+ end
+
+ boundaries[#boundaries + 1] = boundary or '--XXX'
+ cur_boundary = boundary
+
+ local rh = part:get_raw_headers()
+ if #rh > 0 then
+ out[#out + 1] = { rh, true }
+ end
+ elseif part:is_message() then
+ if boundary then
+ if cur_boundary and boundary ~= cur_boundary then
+ -- Need to close boundary
+ out[#out + 1] = { string.format('--%s--',
+ boundaries[#boundaries]), true }
+ table.remove(boundaries)
+ cur_boundary = nil
+ end
+ out[#out + 1] = { string.format('--%s',
+ boundary), true }
+ end
+
+ out[#out + 1] = { part:get_raw_headers(), true }
+ else
+ local skip_replacement = part:is_attachment()
+
+ local parent = part:get_parent()
+ if parent then
+ local t, st = parent:get_type()
+
+ if t == 'multipart' and st == 'signed' then
+ -- Do not modify signed parts
+ skip_replacement = true
+ end
+ end
+ if not part:is_text() then
+ skip_replacement = true
+ end
+
+ if boundary then
+ if cur_boundary and boundary ~= cur_boundary then
+ -- Need to close boundary
+ out[#out + 1] = { string.format('--%s--',
+ boundaries[#boundaries]), true }
+ table.remove(boundaries)
+ cur_boundary = boundary
+ end
+ out[#out + 1] = { string.format('--%s',
+ boundary), true }
+ end
+
+ if not skip_replacement then
+ do_replacement(task, part, mp, replacements,
+ parent and parent:is_multipart(), out, state)
+ else
+ -- Append as is
+ out[#out + 1] = { part:get_raw_headers(), true }
+ out[#out + 1] = { part:get_raw_content(), false }
+ end
+ end
+ end
+
+ -- Close remaining
+ local b = table.remove(boundaries)
+ while b do
+ out[#out + 1] = { string.format('--%s--', b), true }
+ if #boundaries > 0 then
+ out[#out + 1] = { '', true }
+ end
+ b = table.remove(boundaries)
+ end
+
+ state.out = out
+
+ return state
+end
+
+--[[[
+-- @function lua_mime.modify_headers(task, {add = {hname = {value = 'value', order = 1}}, remove = {hname = {1,2}}})
+-- Adds/removes headers both internal and in the milter reply
+-- Mode defines to be compatible with Rspamd <=3.2 and is the default (equal to 'compat')
+--]]
+exports.modify_headers = function(task, hdr_alterations, mode)
+ -- Assume default mode compatibility
+ if not mode then
+ mode = 'compat'
+ end
+ local add = hdr_alterations.add or {}
+ local remove = hdr_alterations.remove or {}
+
+ local add_headers = {} -- For Milter reply
+ local hdr_flattened = {} -- For C API
+
+ local function flatten_add_header(hname, hdr)
+ if not add_headers[hname] then
+ add_headers[hname] = {}
+ end
+ if not hdr_flattened[hname] then
+ hdr_flattened[hname] = { add = {} }
+ end
+ local add_tbl = hdr_flattened[hname].add
+ if hdr.value then
+ table.insert(add_headers[hname], {
+ order = (tonumber(hdr.order) or -1),
+ value = hdr.value,
+ })
+ table.insert(add_tbl, { tonumber(hdr.order) or -1, hdr.value })
+ elseif type(hdr) == 'table' then
+ for _, v in ipairs(hdr) do
+ flatten_add_header(hname, v)
+ end
+ elseif type(hdr) == 'string' then
+ table.insert(add_headers[hname], {
+ order = -1,
+ value = hdr,
+ })
+ table.insert(add_tbl, { -1, hdr })
+ else
+ logger.errx(task, 'invalid modification of header: %s', hdr)
+ end
+
+ if mode == 'compat' and #add_headers[hname] == 1 then
+ -- Switch to the compatibility mode
+ add_headers[hname] = add_headers[hname][1]
+ end
+ end
+ if hdr_alterations.order then
+ -- Get headers alterations ordered
+ for _, hname in ipairs(hdr_alterations.order) do
+ flatten_add_header(hname, add[hname])
+ end
+ else
+ for hname, hdr in pairs(add) do
+ flatten_add_header(hname, hdr)
+ end
+ end
+
+ for hname, hdr in pairs(remove) do
+ if not hdr_flattened[hname] then
+ hdr_flattened[hname] = { remove = {} }
+ end
+ if not hdr_flattened[hname].remove then
+ hdr_flattened[hname].remove = {}
+ end
+ local remove_tbl = hdr_flattened[hname].remove
+ if type(hdr) == 'number' then
+ table.insert(remove_tbl, hdr)
+ else
+ for _, num in ipairs(hdr) do
+ table.insert(remove_tbl, num)
+ end
+ end
+ end
+
+ if mode == 'compat' then
+ -- Clear empty alterations in the compat mode
+ if add_headers and not next(add_headers) then
+ add_headers = nil
+ end
+ if hdr_alterations.remove and not next(hdr_alterations.remove) then
+ hdr_alterations.remove = nil
+ end
+ end
+ task:set_milter_reply({
+ add_headers = add_headers,
+ remove_headers = hdr_alterations.remove
+ })
+
+ for hname, flat_rules in pairs(hdr_flattened) do
+ task:modify_header(hname, flat_rules)
+ end
+end
+
+--[[[
+-- @function lua_mime.message_to_ucl(task, [stringify_content])
+-- Exports a message to an ucl object
+--]]
+exports.message_to_ucl = function(task, stringify_content)
+ local E = {}
+
+ local maybe_stringify_f = stringify_content and
+ tostring or function(t)
+ return t
+ end
+ local result = {
+ size = task:get_size(),
+ digest = task:get_digest(),
+ newlines = task:get_newlines_type(),
+ headers = task:get_headers(true)
+ }
+
+ -- Utility to convert ip addr to a string or nil if invalid/absent
+ local function maybe_stringify_ip(addr)
+ if addr and addr:is_valid() then
+ return addr:to_string()
+ end
+
+ return nil
+ end
+
+ -- Envelope (smtp) information from email (nil if empty)
+ result.envelope = {
+ from_smtp = (task:get_from('smtp') or E)[1],
+ recipients_smtp = task:get_recipients('smtp'),
+ helo = task:get_helo(),
+ hostname = task:get_hostname(),
+ client_ip = maybe_stringify_ip(task:get_client_ip()),
+ from_ip = maybe_stringify_ip(task:get_from_ip()),
+ }
+ if not next(result.envelope) then
+ result.envelope = ucl.null
+ end
+
+ local parts = task:get_parts() or E
+ result.parts = {}
+ for _, part in ipairs(parts) do
+ if not part:is_multipart() and not part:is_message() then
+ local p = {
+ size = part:get_length(),
+ type = string.format('%s/%s', part:get_type()),
+ detected_type = string.format('%s/%s', part:get_detected_type()),
+ filename = part:get_filename(),
+ content = maybe_stringify_f(part:get_content()),
+ headers = part:get_headers(true) or E,
+ boundary = part:get_enclosing_boundary(),
+ }
+ table.insert(result.parts, p)
+ else
+ -- Service part: multipart container or message/rfc822
+ local p = {
+ type = string.format('%s/%s', part:get_type()),
+ headers = part:get_headers(true) or E,
+ boundary = part:get_enclosing_boundary(),
+ size = 0,
+ }
+
+ if part:is_multipart() then
+ p.multipart_boundary = part:get_boundary()
+ end
+
+ table.insert(result.parts, p)
+ end
+ end
+
+ return result
+end
+
+--[[[
+-- @function lua_mime.message_to_ucl_schema()
+-- Returns schema for a message to verify result/document fields
+--]]
+exports.message_to_ucl_schema = function()
+ local ts = require("tableshape").types
+
+ local function headers_schema()
+ return ts.shape {
+ order = ts.integer:describe('Header order in a message'),
+ raw = ts.string:describe('Raw header value'):is_optional(),
+ empty_separator = ts.boolean:describe('Whether header has an empty separator'),
+ separator = ts.string:describe('Separator between a header and a value'),
+ decoded = ts.string:describe('Decoded value'):is_optional(),
+ value = ts.string:describe('Decoded value'):is_optional(),
+ name = ts.string:describe('Header name'),
+ tab_separated = ts.boolean:describe('Whether header has tab as a separator')
+ }
+ end
+
+ local function part_schema()
+ return ts.shape {
+ content = ts.string:describe('Decoded content'):is_optional(),
+ multipart_boundary = ts.string:describe('Multipart service boundary'):is_optional(),
+ size = ts.integer:describe('Size of the part'),
+ type = ts.string:describe('Announced type'):is_optional(),
+ detected_type = ts.string:describe('Detected type'):is_optional(),
+ boundary = ts.string:describe('Eclosing boundary'):is_optional(),
+ filename = ts.string:describe('File name for attachments'):is_optional(),
+ headers = ts.array_of(headers_schema()):describe('Part headers'),
+ }
+ end
+
+ local function email_addr_schema()
+ return ts.shape {
+ addr = ts.string:describe('Parsed address'):is_optional(),
+ raw = ts.string:describe('Raw address'),
+ flags = ts.shape {
+ valid = ts.boolean:describe('Valid address'):is_optional(),
+ ip = ts.boolean:describe('IP like address'):is_optional(),
+ braced = ts.boolean:describe('Have braces around address'):is_optional(),
+ quoted = ts.boolean:describe('Have quotes around address'):is_optional(),
+ empty = ts.boolean:describe('Empty address'):is_optional(),
+ backslash = ts.boolean:describe('Backslash in address'):is_optional(),
+ ['8bit'] = ts.boolean:describe('8 bit characters in address'):is_optional(),
+ },
+ user = ts.string:describe('Parsed user part'):is_optional(),
+ name = ts.string:describe('Displayed name'):is_optional(),
+ domain = ts.string:describe('Parsed domain part'):is_optional(),
+ }
+ end
+ local function envelope_schema()
+ return ts.shape {
+ from_smtp = email_addr_schema():describe('SMTP from'):is_optional(),
+ recipients_smtp = ts.array_of(email_addr_schema()):describe('SMTP recipients'):is_optional(),
+ helo = ts.string:describe('SMTP Helo'):is_optional(),
+ hostname = ts.string:describe('Sender hostname'):is_optional(),
+ client_ip = ts.string:describe('Client ip'):is_optional(),
+ from_ip = ts.string:describe('Sender ip'):is_optional(),
+ }
+ end
+
+ return ts.shape {
+ headers = ts.array_of(headers_schema()),
+ parts = ts.array_of(part_schema()),
+ digest = ts.pattern(string.format('^%s$', string.rep('%x', 32)))
+ :describe('Message digest'),
+ newlines = ts.one_of({ "cr", "lf", "crlf" }):describe('Newlines type'),
+ size = ts.integer:describe('Size of the message in bytes'),
+ envelope = envelope_schema()
+ }
+end
+
+return exports
diff --git a/lualib/lua_mime_types.lua b/lualib/lua_mime_types.lua
new file mode 100644
index 0000000..ba55f97
--- /dev/null
+++ b/lualib/lua_mime_types.lua
@@ -0,0 +1,745 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_mime_types
+-- This module contains mime types list
+--]]
+
+local exports = {}
+
+-- All mime extensions with corresponding content types
+exports.full_extensions_map = {
+ { "323", "text/h323" },
+ { "3g2", "video/3gpp2" },
+ { "3gp", "video/3gpp" },
+ { "3gp2", "video/3gpp2" },
+ { "3gpp", "video/3gpp" },
+ { "7z", { "application/x-7z-compressed", "application/7z" } },
+ { "aa", "audio/audible" },
+ { "AAC", "audio/aac" },
+ { "aaf", "application/octet-stream" },
+ { "aax", "audio/vnd.audible.aax" },
+ { "ac3", "audio/ac3" },
+ { "aca", "application/octet-stream" },
+ { "accda", "application/msaccess.addin" },
+ { "accdb", "application/msaccess" },
+ { "accdc", "application/msaccess.cab" },
+ { "accde", "application/msaccess" },
+ { "accdr", "application/msaccess.runtime" },
+ { "accdt", "application/msaccess" },
+ { "accdw", "application/msaccess.webapplication" },
+ { "accft", "application/msaccess.ftemplate" },
+ { "acx", "application/internet-property-stream" },
+ { "AddIn", "text/xml" },
+ { "ade", "application/msaccess" },
+ { "adobebridge", "application/x-bridge-url" },
+ { "adp", "application/msaccess" },
+ { "ADT", "audio/vnd.dlna.adts" },
+ { "ADTS", "audio/aac" },
+ { "afm", "application/octet-stream" },
+ { "ai", "application/postscript" },
+ { "aif", "audio/aiff" },
+ { "aifc", "audio/aiff" },
+ { "aiff", "audio/aiff" },
+ { "air", "application/vnd.adobe.air-application-installer-package+zip" },
+ { "amc", "application/mpeg" },
+ { "anx", "application/annodex" },
+ { "apk", "application/vnd.android.package-archive" },
+ { "application", "application/x-ms-application" },
+ { "art", "image/x-jg" },
+ { "asa", "application/xml" },
+ { "asax", "application/xml" },
+ { "ascx", "application/xml" },
+ { "asd", "application/octet-stream" },
+ { "asf", "video/x-ms-asf" },
+ { "ashx", "application/xml" },
+ { "asi", "application/octet-stream" },
+ { "asm", "text/plain" },
+ { "asmx", "application/xml" },
+ { "aspx", "application/xml" },
+ { "asr", "video/x-ms-asf" },
+ { "asx", "video/x-ms-asf" },
+ { "atom", "application/atom+xml" },
+ { "au", "audio/basic" },
+ { "avi", "video/x-msvideo" },
+ { "axa", "audio/annodex" },
+ { "axs", "application/olescript" },
+ { "axv", "video/annodex" },
+ { "bas", "text/plain" },
+ { "bcpio", "application/x-bcpio" },
+ { "bin", "application/octet-stream" },
+ { "bmp", { "image/bmp", "image/x-ms-bmp" } },
+ { "c", "text/plain" },
+ { "cab", "application/octet-stream" },
+ { "caf", "audio/x-caf" },
+ { "calx", "application/vnd.ms-office.calx" },
+ { "cat", "application/vnd.ms-pki.seccat" },
+ { "cc", "text/plain" },
+ { "cd", "text/plain" },
+ { "cdda", "audio/aiff" },
+ { "cdf", "application/x-cdf" },
+ { "cer", "application/x-x509-ca-cert" },
+ { "cfg", "text/plain" },
+ { "chm", "application/octet-stream" },
+ { "class", "application/x-java-applet" },
+ { "clp", "application/x-msclip" },
+ { "cmd", "text/plain" },
+ { "cmx", "image/x-cmx" },
+ { "cnf", "text/plain" },
+ { "cod", "image/cis-cod" },
+ { "config", "application/xml" },
+ { "contact", "text/x-ms-contact" },
+ { "coverage", "application/xml" },
+ { "cpio", "application/x-cpio" },
+ { "cpp", "text/plain" },
+ { "crd", "application/x-mscardfile" },
+ { "crl", "application/pkix-crl" },
+ { "crt", "application/x-x509-ca-cert" },
+ { "cs", "text/plain" },
+ { "csdproj", "text/plain" },
+ { "csh", "application/x-csh" },
+ { "csproj", "text/plain" },
+ { "css", "text/css" },
+ { "csv", { "application/vnd.ms-excel", "text/csv", "text/plain" } },
+ { "cur", "application/octet-stream" },
+ { "cxx", "text/plain" },
+ { "dat", { "application/octet-stream", "application/ms-tnef" } },
+ { "datasource", "application/xml" },
+ { "dbproj", "text/plain" },
+ { "dcr", "application/x-director" },
+ { "def", "text/plain" },
+ { "deploy", "application/octet-stream" },
+ { "der", "application/x-x509-ca-cert" },
+ { "dgml", "application/xml" },
+ { "dib", "image/bmp" },
+ { "dif", "video/x-dv" },
+ { "dir", "application/x-director" },
+ { "disco", "text/xml" },
+ { "divx", "video/divx" },
+ { "dll", "application/x-msdownload" },
+ { "dll.config", "text/xml" },
+ { "dlm", "text/dlm" },
+ { "doc", "application/msword" },
+ { "docm", "application/vnd.ms-word.document.macroEnabled.12" },
+ { "docx", {
+ "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+ "application/msword",
+ "application/vnd.ms-word.document.12",
+ "application/octet-stream",
+ } },
+ { "dot", "application/msword" },
+ { "dotm", "application/vnd.ms-word.template.macroEnabled.12" },
+ { "dotx", "application/vnd.openxmlformats-officedocument.wordprocessingml.template" },
+ { "dsp", "application/octet-stream" },
+ { "dsw", "text/plain" },
+ { "dtd", "text/xml" },
+ { "dtsConfig", "text/xml" },
+ { "dv", "video/x-dv" },
+ { "dvi", "application/x-dvi" },
+ { "dwf", "drawing/x-dwf" },
+ { "dwg", { "application/acad", "image/vnd.dwg" } },
+ { "dwp", "application/octet-stream" },
+ { "dxf", "application/x-dxf" },
+ { "dxr", "application/x-director" },
+ { "eml", "message/rfc822" },
+ { "emz", "application/octet-stream" },
+ { "eot", "application/vnd.ms-fontobject" },
+ { "eps", "application/postscript" },
+ { "etl", "application/etl" },
+ { "etx", "text/x-setext" },
+ { "evy", "application/envoy" },
+ { "exe", {
+ "application/x-dosexec",
+ "application/x-msdownload",
+ "application/x-executable",
+ } },
+ { "exe.config", "text/xml" },
+ { "fdf", "application/vnd.fdf" },
+ { "fif", "application/fractals" },
+ { "filters", "application/xml" },
+ { "fla", "application/octet-stream" },
+ { "flac", "audio/flac" },
+ { "flr", "x-world/x-vrml" },
+ { "flv", "video/x-flv" },
+ { "fsscript", "application/fsharp-script" },
+ { "fsx", "application/fsharp-script" },
+ { "generictest", "application/xml" },
+ { "gif", "image/gif" },
+ { "gpx", "application/gpx+xml" },
+ { "group", "text/x-ms-group" },
+ { "gsm", "audio/x-gsm" },
+ { "gtar", "application/x-gtar" },
+ { "gz", { "application/gzip", "application/x-gzip", "application/tlsrpt+gzip" } },
+ { "h", "text/plain" },
+ { "hdf", "application/x-hdf" },
+ { "hdml", "text/x-hdml" },
+ { "hhc", "application/x-oleobject" },
+ { "hhk", "application/octet-stream" },
+ { "hhp", "application/octet-stream" },
+ { "hlp", "application/winhlp" },
+ { "hpp", "text/plain" },
+ { "hqx", "application/mac-binhex40" },
+ { "hta", "application/hta" },
+ { "htc", "text/x-component" },
+ { "htm", "text/html" },
+ { "html", "text/html" },
+ { "htt", "text/webviewhtml" },
+ { "hxa", "application/xml" },
+ { "hxc", "application/xml" },
+ { "hxd", "application/octet-stream" },
+ { "hxe", "application/xml" },
+ { "hxf", "application/xml" },
+ { "hxh", "application/octet-stream" },
+ { "hxi", "application/octet-stream" },
+ { "hxk", "application/xml" },
+ { "hxq", "application/octet-stream" },
+ { "hxr", "application/octet-stream" },
+ { "hxs", "application/octet-stream" },
+ { "hxt", "text/html" },
+ { "hxv", "application/xml" },
+ { "hxw", "application/octet-stream" },
+ { "hxx", "text/plain" },
+ { "i", "text/plain" },
+ { "ico", "image/x-icon" },
+ { "ics", { "text/calendar", "application/ics", "application/octet-stream" } },
+ { "idl", "text/plain" },
+ { "ief", "image/ief" },
+ { "iii", "application/x-iphone" },
+ { "inc", "text/plain" },
+ { "inf", "application/octet-stream" },
+ { "ini", "text/plain" },
+ { "inl", "text/plain" },
+ { "ins", "application/x-internet-signup" },
+ { "ipa", "application/x-itunes-ipa" },
+ { "ipg", "application/x-itunes-ipg" },
+ { "ipproj", "text/plain" },
+ { "ipsw", "application/x-itunes-ipsw" },
+ { "iqy", "text/x-ms-iqy" },
+ { "isp", "application/x-internet-signup" },
+ { "ite", "application/x-itunes-ite" },
+ { "itlp", "application/x-itunes-itlp" },
+ { "itms", "application/x-itunes-itms" },
+ { "itpc", "application/x-itunes-itpc" },
+ { "IVF", "video/x-ivf" },
+ { "jar", "application/java-archive" },
+ { "java", "application/octet-stream" },
+ { "jck", "application/liquidmotion" },
+ { "jcz", "application/liquidmotion" },
+ { "jfif", { "image/jpeg", "image/pjpeg" } },
+ { "jnlp", "application/x-java-jnlp-file" },
+ { "jpb", "application/octet-stream" },
+ { "jpe", { "image/jpeg", "image/pjpeg" } },
+ { "jpeg", { "image/jpeg", "image/pjpeg" } },
+ { "jpg", { "image/jpeg", "image/pjpeg" } },
+ { "js", "application/javascript" },
+ { "json", "application/json" },
+ { "jsx", "text/jscript" },
+ { "jsxbin", "text/plain" },
+ { "latex", "application/x-latex" },
+ { "library-ms", "application/windows-library+xml" },
+ { "lit", "application/x-ms-reader" },
+ { "loadtest", "application/xml" },
+ { "lpk", "application/octet-stream" },
+ { "lsf", "video/x-la-asf" },
+ { "lst", "text/plain" },
+ { "lsx", "video/x-la-asf" },
+ { "lzh", "application/octet-stream" },
+ { "m13", "application/x-msmediaview" },
+ { "m14", "application/x-msmediaview" },
+ { "m1v", "video/mpeg" },
+ { "m2t", "video/vnd.dlna.mpeg-tts" },
+ { "m2ts", "video/vnd.dlna.mpeg-tts" },
+ { "m2v", "video/mpeg" },
+ { "m3u", "audio/x-mpegurl" },
+ { "m3u8", "audio/x-mpegurl" },
+ { "m4a", { "audio/m4a", "audio/x-m4a" } },
+ { "m4b", "audio/m4b" },
+ { "m4p", "audio/m4p" },
+ { "m4r", "audio/x-m4r" },
+ { "m4v", "video/x-m4v" },
+ { "mac", "image/x-macpaint" },
+ { "mak", "text/plain" },
+ { "man", "application/x-troff-man" },
+ { "manifest", "application/x-ms-manifest" },
+ { "map", "text/plain" },
+ { "master", "application/xml" },
+ { "mbox", "application/mbox" },
+ { "mda", "application/msaccess" },
+ { "mdb", "application/x-msaccess" },
+ { "mde", "application/msaccess" },
+ { "mdp", "application/octet-stream" },
+ { "me", "application/x-troff-me" },
+ { "mfp", "application/x-shockwave-flash" },
+ { "mht", "message/rfc822" },
+ { "mhtml", "message/rfc822" },
+ { "mid", "audio/mid" },
+ { "midi", "audio/mid" },
+ { "mix", "application/octet-stream" },
+ { "mk", "text/plain" },
+ { "mmf", "application/x-smaf" },
+ { "mno", "text/xml" },
+ { "mny", "application/x-msmoney" },
+ { "mod", "video/mpeg" },
+ { "mov", "video/quicktime" },
+ { "movie", "video/x-sgi-movie" },
+ { "mp2", "video/mpeg" },
+ { "mp2v", "video/mpeg" },
+ { "mp3", { "audio/mpeg", "audio/mpeg3", "audio/mp3", "audio/x-mpeg-3" } },
+ { "mp4", "video/mp4" },
+ { "mp4v", "video/mp4" },
+ { "mpa", "video/mpeg" },
+ { "mpe", "video/mpeg" },
+ { "mpeg", "video/mpeg" },
+ { "mpf", "application/vnd.ms-mediapackage" },
+ { "mpg", "video/mpeg" },
+ { "mpp", "application/vnd.ms-project" },
+ { "mpv2", "video/mpeg" },
+ { "mqv", "video/quicktime" },
+ { "ms", "application/x-troff-ms" },
+ { "msg", "application/vnd.ms-outlook" },
+ { "msi", { "application/x-msi", "application/octet-stream" } },
+ { "mso", "application/octet-stream" },
+ { "mts", "video/vnd.dlna.mpeg-tts" },
+ { "mtx", "application/xml" },
+ { "mvb", "application/x-msmediaview" },
+ { "mvc", "application/x-miva-compiled" },
+ { "mxp", "application/x-mmxp" },
+ { "nc", "application/x-netcdf" },
+ { "nsc", "video/x-ms-asf" },
+ { "nws", "message/rfc822" },
+ { "ocx", "application/octet-stream" },
+ { "oda", "application/oda" },
+ { "odb", "application/vnd.oasis.opendocument.database" },
+ { "odc", "application/vnd.oasis.opendocument.chart" },
+ { "odf", "application/vnd.oasis.opendocument.formula" },
+ { "odg", "application/vnd.oasis.opendocument.graphics" },
+ { "odh", "text/plain" },
+ { "odi", "application/vnd.oasis.opendocument.image" },
+ { "odl", "text/plain" },
+ { "odm", "application/vnd.oasis.opendocument.text-master" },
+ { "odp", "application/vnd.oasis.opendocument.presentation" },
+ { "ods", "application/vnd.oasis.opendocument.spreadsheet" },
+ { "odt", "application/vnd.oasis.opendocument.text" },
+ { "oga", "audio/ogg" },
+ { "ogg", "audio/ogg" },
+ { "ogv", "video/ogg" },
+ { "ogx", "application/ogg" },
+ { "one", "application/onenote" },
+ { "onea", "application/onenote" },
+ { "onepkg", "application/onenote" },
+ { "onetmp", "application/onenote" },
+ { "onetoc", "application/onenote" },
+ { "onetoc2", "application/onenote" },
+ { "opus", "audio/ogg" },
+ { "orderedtest", "application/xml" },
+ { "osdx", "application/opensearchdescription+xml" },
+ { "otf", "application/font-sfnt" },
+ { "otg", "application/vnd.oasis.opendocument.graphics-template" },
+ { "oth", "application/vnd.oasis.opendocument.text-web" },
+ { "otp", "application/vnd.oasis.opendocument.presentation-template" },
+ { "ots", "application/vnd.oasis.opendocument.spreadsheet-template" },
+ { "ott", "application/vnd.oasis.opendocument.text-template" },
+ { "oxt", "application/vnd.openofficeorg.extension" },
+ { "p10", "application/pkcs10" },
+ { "p12", "application/x-pkcs12" },
+ { "p7b", "application/x-pkcs7-certificates" },
+ { "p7c", "application/pkcs7-mime" },
+ { "p7m", "application/pkcs7-mime", "application/x-pkcs7-mime" },
+ { "p7r", "application/x-pkcs7-certreqresp" },
+ { "p7s", { "application/pkcs7-signature", "application/x-pkcs7-signature", "text/plain" } },
+ { "pbm", "image/x-portable-bitmap" },
+ { "pcast", "application/x-podcast" },
+ { "pct", "image/pict" },
+ { "pcx", "application/octet-stream" },
+ { "pcz", "application/octet-stream" },
+ { "pdf", "application/pdf" },
+ { "pfb", "application/octet-stream" },
+ { "pfm", "application/octet-stream" },
+ { "pfx", "application/x-pkcs12" },
+ { "pgm", "image/x-portable-graymap" },
+ { "pic", "image/pict" },
+ { "pict", "image/pict" },
+ { "pkgdef", "text/plain" },
+ { "pkgundef", "text/plain" },
+ { "pko", "application/vnd.ms-pki.pko" },
+ { "pls", "audio/scpls" },
+ { "pma", "application/x-perfmon" },
+ { "pmc", "application/x-perfmon" },
+ { "pml", "application/x-perfmon" },
+ { "pmr", "application/x-perfmon" },
+ { "pmw", "application/x-perfmon" },
+ { "png", "image/png" },
+ { "pnm", "image/x-portable-anymap" },
+ { "pnt", "image/x-macpaint" },
+ { "pntg", "image/x-macpaint" },
+ { "pnz", "image/png" },
+ { "pot", "application/vnd.ms-powerpoint" },
+ { "potm", "application/vnd.ms-powerpoint.template.macroEnabled.12" },
+ { "potx", "application/vnd.openxmlformats-officedocument.presentationml.template" },
+ { "ppa", "application/vnd.ms-powerpoint" },
+ { "ppam", "application/vnd.ms-powerpoint.addin.macroEnabled.12" },
+ { "ppm", "image/x-portable-pixmap" },
+ { "pps", "application/vnd.ms-powerpoint" },
+ { "ppsm", "application/vnd.ms-powerpoint.slideshow.macroEnabled.12" },
+ { "ppsx", "application/vnd.openxmlformats-officedocument.presentationml.slideshow" },
+ { "ppt", "application/vnd.ms-powerpoint" },
+ { "pptm", "application/vnd.ms-powerpoint.presentation.macroEnabled.12" },
+ { "pptx", "application/vnd.openxmlformats-officedocument.presentationml.presentation" },
+ { "prf", "application/pics-rules" },
+ { "prm", "application/octet-stream" },
+ { "prx", "application/octet-stream" },
+ { "ps", "application/postscript" },
+ { "psc1", "application/PowerShell" },
+ { "psd", "application/octet-stream" },
+ { "psess", "application/xml" },
+ { "psm", "application/octet-stream" },
+ { "psp", "application/octet-stream" },
+ { "pst", "application/vnd.ms-outlook" },
+ { "pub", "application/x-mspublisher" },
+ { "pwz", "application/vnd.ms-powerpoint" },
+ { "qht", "text/x-html-insertion" },
+ { "qhtm", "text/x-html-insertion" },
+ { "qt", "video/quicktime" },
+ { "qti", "image/x-quicktime" },
+ { "qtif", "image/x-quicktime" },
+ { "qtl", "application/x-quicktimeplayer" },
+ { "qxd", "application/octet-stream" },
+ { "ra", "audio/x-pn-realaudio" },
+ { "ram", "audio/x-pn-realaudio" },
+ { "rar", { "application/x-rar-compressed", "application/x-rar", "application/rar", "application/octet-stream" } },
+ { "ras", "image/x-cmu-raster" },
+ { "rat", "application/rat-file" },
+ { "rc", "text/plain" },
+ { "rc2", "text/plain" },
+ { "rct", "text/plain" },
+ { "rdlc", "application/xml" },
+ { "reg", "text/plain" },
+ { "resx", "application/xml" },
+ { "rf", "image/vnd.rn-realflash" },
+ { "rgb", "image/x-rgb" },
+ { "rgs", "text/plain" },
+ { "rm", "application/vnd.rn-realmedia" },
+ { "rmi", "audio/mid" },
+ { "rmp", "application/vnd.rn-rn_music_package" },
+ { "roff", "application/x-troff" },
+ { "rpm", "audio/x-pn-realaudio-plugin" },
+ { "rqy", "text/x-ms-rqy" },
+ { "rtf", { "application/rtf", "application/msword", "text/richtext", "text/rtf" } },
+ { "rtx", "text/richtext" },
+ { "rvt", "application/octet-stream" },
+ { "ruleset", "application/xml" },
+ { "s", "text/plain" },
+ { "safariextz", "application/x-safari-safariextz" },
+ { "scd", "application/x-msschedule" },
+ { "scr", "text/plain" },
+ { "sct", "text/scriptlet" },
+ { "sd2", "audio/x-sd2" },
+ { "sdp", "application/sdp" },
+ { "sea", "application/octet-stream" },
+ { "searchConnector-ms", "application/windows-search-connector+xml" },
+ { "setpay", "application/set-payment-initiation" },
+ { "setreg", "application/set-registration-initiation" },
+ { "settings", "application/xml" },
+ { "sgimb", "application/x-sgimb" },
+ { "sgml", "text/sgml" },
+ { "sh", "application/x-sh" },
+ { "shar", "application/x-shar" },
+ { "shtml", "text/html" },
+ { "sit", "application/x-stuffit" },
+ { "sitemap", "application/xml" },
+ { "skin", "application/xml" },
+ { "skp", "application/x-koan" },
+ { "sldm", "application/vnd.ms-powerpoint.slide.macroEnabled.12" },
+ { "sldx", "application/vnd.openxmlformats-officedocument.presentationml.slide" },
+ { "slk", "application/vnd.ms-excel" },
+ { "sln", "text/plain" },
+ { "slupkg-ms", "application/x-ms-license" },
+ { "smd", "audio/x-smd" },
+ { "smi", "application/octet-stream" },
+ { "smx", "audio/x-smd" },
+ { "smz", "audio/x-smd" },
+ { "snd", "audio/basic" },
+ { "snippet", "application/xml" },
+ { "snp", "application/octet-stream" },
+ { "sol", "text/plain" },
+ { "sor", "text/plain" },
+ { "spc", "application/x-pkcs7-certificates" },
+ { "spl", "application/futuresplash" },
+ { "spx", "audio/ogg" },
+ { "src", "application/x-wais-source" },
+ { "srf", "text/plain" },
+ { "SSISDeploymentManifest", "text/xml" },
+ { "ssm", "application/streamingmedia" },
+ { "sst", "application/vnd.ms-pki.certstore" },
+ { "stl", "application/vnd.ms-pki.stl" },
+ { "sv4cpio", "application/x-sv4cpio" },
+ { "sv4crc", "application/x-sv4crc" },
+ { "svc", "application/xml" },
+ { "svg", "image/svg+xml" },
+ { "swf", "application/x-shockwave-flash" },
+ { "step", "application/step" },
+ { "stp", "application/step" },
+ { "t", "application/x-troff" },
+ { "tar", "application/x-tar" },
+ { "tcl", "application/x-tcl" },
+ { "testrunconfig", "application/xml" },
+ { "testsettings", "application/xml" },
+ { "tex", "application/x-tex" },
+ { "texi", "application/x-texinfo" },
+ { "texinfo", "application/x-texinfo" },
+ { "tgz", "application/x-compressed" },
+ { "thmx", "application/vnd.ms-officetheme" },
+ { "thn", "application/octet-stream" },
+ { "tif", { "image/tiff", "application/octet-stream" } },
+ { "tiff", "image/tiff" },
+ { "tlh", "text/plain" },
+ { "tli", "text/plain" },
+ { "toc", "application/octet-stream" },
+ { "tr", "application/x-troff" },
+ { "trm", "application/x-msterminal" },
+ { "trx", "application/xml" },
+ { "ts", "video/vnd.dlna.mpeg-tts" },
+ { "tsv", "text/tab-separated-values" },
+ { "ttf", "application/font-sfnt" },
+ { "tts", "video/vnd.dlna.mpeg-tts" },
+ { "txt", "text/plain" },
+ { "u32", "application/octet-stream" },
+ { "uls", "text/iuls" },
+ { "user", "text/plain" },
+ { "ustar", "application/x-ustar" },
+ { "vb", "text/plain" },
+ { "vbdproj", "text/plain" },
+ { "vbk", "video/mpeg" },
+ { "vbproj", "text/plain" },
+ { "vbs", "text/vbscript" },
+ { "vcf", { "text/x-vcard", "text/vcard" } },
+ { "vcproj", "application/xml" },
+ { "vcs", "text/plain" },
+ { "vcxproj", "application/xml" },
+ { "vddproj", "text/plain" },
+ { "vdp", "text/plain" },
+ { "vdproj", "text/plain" },
+ { "vdx", "application/vnd.ms-visio.viewer" },
+ { "vml", "text/xml" },
+ { "vscontent", "application/xml" },
+ { "vsct", "text/xml" },
+ { "vsd", "application/vnd.visio" },
+ { "vsi", "application/ms-vsi" },
+ { "vsix", "application/vsix" },
+ { "vsixlangpack", "text/xml" },
+ { "vsixmanifest", "text/xml" },
+ { "vsmdi", "application/xml" },
+ { "vspscc", "text/plain" },
+ { "vss", "application/vnd.visio" },
+ { "vsscc", "text/plain" },
+ { "vssettings", "text/xml" },
+ { "vssscc", "text/plain" },
+ { "vst", "application/vnd.visio" },
+ { "vstemplate", "text/xml" },
+ { "vsto", "application/x-ms-vsto" },
+ { "vsw", "application/vnd.visio" },
+ { "vsx", "application/vnd.visio" },
+ { "vtx", "application/vnd.visio" },
+ { "wav", { "audio/wav", "audio/vnd.wave", "audio/x-wav" } },
+ { "wave", "audio/wav" },
+ { "wax", "audio/x-ms-wax" },
+ { "wbk", "application/msword" },
+ { "wbmp", "image/vnd.wap.wbmp" },
+ { "wcm", "application/vnd.ms-works" },
+ { "wdb", "application/vnd.ms-works" },
+ { "wdp", "image/vnd.ms-photo" },
+ { "webarchive", "application/x-safari-webarchive" },
+ { "webm", "video/webm" },
+ { "webp", "image/webp" },
+ { "webtest", "application/xml" },
+ { "wiq", "application/xml" },
+ { "wiz", "application/msword" },
+ { "wks", "application/vnd.ms-works" },
+ { "WLMP", "application/wlmoviemaker" },
+ { "wlpginstall", "application/x-wlpg-detect" },
+ { "wlpginstall3", "application/x-wlpg3-detect" },
+ { "wm", "video/x-ms-wm" },
+ { "wma", "audio/x-ms-wma" },
+ { "wmd", "application/x-ms-wmd" },
+ { "wmf", { "application/x-msmetafile", "image/wmf", "image/x-wmf" } },
+ { "wml", "text/vnd.wap.wml" },
+ { "wmlc", "application/vnd.wap.wmlc" },
+ { "wmls", "text/vnd.wap.wmlscript" },
+ { "wmlsc", "application/vnd.wap.wmlscriptc" },
+ { "wmp", "video/x-ms-wmp" },
+ { "wmv", "video/x-ms-wmv" },
+ { "wmx", "video/x-ms-wmx" },
+ { "wmz", "application/x-ms-wmz" },
+ { "woff", "application/font-woff" },
+ { "wpl", "application/vnd.ms-wpl" },
+ { "wps", "application/vnd.ms-works" },
+ { "wri", "application/x-mswrite" },
+ { "wrl", "x-world/x-vrml" },
+ { "wrz", "x-world/x-vrml" },
+ { "wsc", "text/scriptlet" },
+ { "wsdl", "text/xml" },
+ { "wvx", "video/x-ms-wvx" },
+ { "x", "application/directx" },
+ { "xaf", "x-world/x-vrml" },
+ { "xaml", "application/xaml+xml" },
+ { "xap", "application/x-silverlight-app" },
+ { "xbap", "application/x-ms-xbap" },
+ { "xbm", "image/x-xbitmap" },
+ { "xdr", "text/plain" },
+ { "xht", "application/xhtml+xml" },
+ { "xhtml", "application/xhtml+xml" },
+ { "xla", "application/vnd.ms-excel" },
+ { "xlam", "application/vnd.ms-excel.addin.macroEnabled.12" },
+ { "xlc", "application/vnd.ms-excel" },
+ { "xld", "application/vnd.ms-excel" },
+ { "xlk", "application/vnd.ms-excel" },
+ { "xll", "application/vnd.ms-excel" },
+ { "xlm", "application/vnd.ms-excel" },
+ { "xls", {
+ "application/excel",
+ "application/vnd.ms-excel",
+ "application/vnd.ms-office",
+ "application/x-excel",
+ "application/octet-stream"
+ } },
+ { "xlsb", "application/vnd.ms-excel.sheet.binary.macroEnabled.12" },
+ { "xlsm", "application/vnd.ms-excel.sheet.macroEnabled.12" },
+ { "xlsx", {
+ "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+ "application/vnd.ms-excel.12",
+ "application/octet-stream"
+ } },
+ { "xlt", "application/vnd.ms-excel" },
+ { "xltm", "application/vnd.ms-excel.template.macroEnabled.12" },
+ { "xltx", "application/vnd.openxmlformats-officedocument.spreadsheetml.template" },
+ { "xlw", "application/vnd.ms-excel" },
+ { "xml", { "application/xml", "text/xml", "application/octet-stream" } },
+ { "xmp", "application/octet-stream" },
+ { "xmta", "application/xml" },
+ { "xof", "x-world/x-vrml" },
+ { "XOML", "text/plain" },
+ { "xpm", "image/x-xpixmap" },
+ { "xps", "application/vnd.ms-xpsdocument" },
+ { "xrm-ms", "text/xml" },
+ { "xsc", "application/xml" },
+ { "xsd", "text/xml" },
+ { "xsf", "text/xml" },
+ { "xsl", "text/xml" },
+ { "xslt", "text/xml" },
+ { "xsn", "application/octet-stream" },
+ { "xss", "application/xml" },
+ { "xspf", "application/xspf+xml" },
+ { "xtp", "application/octet-stream" },
+ { "xwd", "image/x-xwindowdump" },
+ { "z", "application/x-compress" },
+ { "zip", {
+ "application/zip",
+ "application/x-zip-compressed",
+ "application/octet-stream"
+ } },
+ { "zlib", "application/zlib" },
+}
+
+-- Used to match extension by content type
+exports.reversed_extensions_map = {
+ ["text/html"] = "html",
+ ["text/css"] = "css",
+ ["text/xml"] = "xml",
+ ["image/gif"] = "gif",
+ ["image/jpeg"] = "jpeg",
+ ["application/javascript"] = "js",
+ ["application/atom+xml"] = "atom",
+ ["application/rss+xml"] = "rss",
+ ["application/csv"] = "csv",
+ ["text/mathml"] = "mml",
+ ["text/plain"] = "txt",
+ ["text/vnd.sun.j2me.app-descriptor"] = "jad",
+ ["text/vnd.wap.wml"] = "wml",
+ ["text/x-component"] = "htc",
+ ["image/png"] = "png",
+ ["image/svg+xml"] = "svg",
+ ["image/tiff"] = "tiff",
+ ["image/vnd.wap.wbmp"] = "wbmp",
+ ["image/webp"] = "webp",
+ ["image/x-icon"] = "ico",
+ ["image/x-jng"] = "jng",
+ ["image/x-ms-bmp"] = "bmp",
+ ["font/woff"] = "woff",
+ ["font/woff2"] = "woff2",
+ ["application/java-archive"] = "jar",
+ ["application/json"] = "json",
+ ["application/mac-binhex40"] = "hqx",
+ ["application/msword"] = "doc",
+ ["application/pdf"] = "pdf",
+ ["application/postscript"] = "ps",
+ ["application/rtf"] = "rtf",
+ ["application/vnd.apple.mpegurl"] = "m3u8",
+ ["application/vnd.google-earth.kml+xml"] = "kml",
+ ["application/vnd.google-earth.kmz"] = "kmz",
+ ["application/vnd.ms-excel"] = "xls",
+ ["application/vnd.ms-fontobject"] = "eot",
+ ["application/vnd.ms-powerpoint"] = "ppt",
+ ["application/vnd.oasis.opendocument.graphics"] = "odg",
+ ["application/vnd.oasis.opendocument.presentation"] = "odp",
+ ["application/vnd.oasis.opendocument.spreadsheet"] = "ods",
+ ["application/vnd.oasis.opendocument.text"] = "odt",
+ ["application/vnd.openxmlformats-officedocument.presentationml.presentation"] = "pptx",
+ ["application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"] = "xlsx",
+ ["application/vnd.openxmlformats-officedocument.wordprocessingml.document"] = "docx",
+ ["application/x-7z-compressed"] = "7z",
+ ["application/x-cocoa"] = "cco",
+ ["application/x-java-archive-diff"] = "jardiff",
+ ["application/x-java-jnlp-file"] = "jnlp",
+ ["application/x-makeself"] = "run",
+ ["application/x-perl"] = "pl",
+ ["application/x-pilot"] = "pdb",
+ ["application/x-rar-compressed"] = "rar",
+ ["application/x-redhat-package-manager"] = "rpm",
+ ["application/x-sea"] = "sea",
+ ["application/x-shockwave-flash"] = "swf",
+ ["application/x-stuffit"] = "sit",
+ ["application/x-tcl"] = "tcl",
+ ["application/x-x509-ca-cert"] = "crt",
+ ["application/x-xpinstall"] = "xpi",
+ ["application/xhtml+xml"] = "xhtml",
+ ["application/xspf+xml"] = "xspf",
+ ["application/zip"] = "zip",
+ ["application/x-dosexec"] = "exe",
+ ["application/x-msdownload"] = "exe",
+ ["application/x-executable"] = "exe",
+ ["text/x-msdos-batch"] = "bat",
+
+ ["audio/midi"] = "mid",
+ ["audio/mpeg"] = "mp3",
+ ["audio/ogg"] = "ogg",
+ ["audio/x-m4a"] = "m4a",
+ ["audio/x-realaudio"] = "ra",
+ ["video/3gpp"] = "3gpp",
+ ["video/mp2t"] = "ts",
+ ["video/mp4"] = "mp4",
+ ["video/mpeg"] = "mpeg",
+ ["video/quicktime"] = "mov",
+ ["video/webm"] = "webm",
+ ["video/x-flv"] = "flv",
+ ["video/x-m4v"] = "m4v",
+ ["video/x-mng"] = "mng",
+ ["video/x-ms-asf"] = "asx",
+ ["video/x-ms-wmv"] = "wmv",
+ ["video/x-msvideo"] = "avi",
+}
+
+return exports
diff --git a/lualib/lua_redis.lua b/lualib/lua_redis.lua
new file mode 100644
index 0000000..818d955
--- /dev/null
+++ b/lualib/lua_redis.lua
@@ -0,0 +1,1817 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local logger = require "rspamd_logger"
+local lutil = require "lua_util"
+local rspamd_util = require "rspamd_util"
+local ts = require("tableshape").types
+
+local exports = {}
+
+local E = {}
+local N = "lua_redis"
+
+local common_schema = {
+ timeout = (ts.number + ts.string / lutil.parse_time_interval):is_optional():describe("Connection timeout"),
+ db = ts.string:is_optional():describe("Database number"),
+ database = ts.string:is_optional():describe("Database number"),
+ dbname = ts.string:is_optional():describe("Database number"),
+ prefix = ts.string:is_optional():describe("Key prefix"),
+ username = ts.string:is_optional():describe("Username"),
+ password = ts.string:is_optional():describe("Password"),
+ expand_keys = ts.boolean:is_optional():describe("Expand keys"),
+ sentinels = (ts.string + ts.array_of(ts.string)):is_optional():describe("Sentinel servers"),
+ sentinel_watch_time = (ts.number + ts.string / lutil.parse_time_interval):is_optional():describe("Sentinel watch time"),
+ sentinel_masters_pattern = ts.string:is_optional():describe("Sentinel masters pattern"),
+ sentinel_master_maxerrors = (ts.number + ts.string / tonumber):is_optional():describe("Sentinel master max errors"),
+ sentinel_username = ts.string:is_optional():describe("Sentinel username"),
+ sentinel_password = ts.string:is_optional():describe("Sentinel password"),
+}
+
+local read_schema = lutil.table_merge({
+ read_servers = ts.string + ts.array_of(ts.string),
+}, common_schema)
+
+local write_schema = lutil.table_merge({
+ write_servers = ts.string + ts.array_of(ts.string),
+}, common_schema)
+
+local rw_schema = lutil.table_merge({
+ read_servers = ts.string + ts.array_of(ts.string),
+ write_servers = ts.string + ts.array_of(ts.string),
+}, common_schema)
+
+local servers_schema = lutil.table_merge({
+ servers = ts.string + ts.array_of(ts.string),
+}, common_schema)
+
+local server_schema = lutil.table_merge({
+ server = ts.string + ts.array_of(ts.string),
+}, common_schema)
+
+local enrich_schema = function(external)
+ return ts.one_of {
+ ts.shape(external), -- no specific redis parameters
+ ts.shape(lutil.table_merge(read_schema, external)), -- read_servers specified
+ ts.shape(lutil.table_merge(write_schema, external)), -- write_servers specified
+ ts.shape(lutil.table_merge(rw_schema, external)), -- both read and write servers defined
+ ts.shape(lutil.table_merge(servers_schema, external)), -- just servers for both ops
+ ts.shape(lutil.table_merge(server_schema, external)), -- legacy `server` attribute
+ }
+end
+
+exports.enrich_schema = enrich_schema
+
+local function redis_query_sentinel(ev_base, params, initialised)
+ local function flatten_redis_table(tbl)
+ local res = {}
+ for i = 1, #tbl, 2 do
+ res[tbl[i]] = tbl[i + 1]
+ end
+
+ return res
+ end
+ -- Coroutines syntax
+ local rspamd_redis = require "rspamd_redis"
+ local sentinels = params.sentinels
+ local addr = sentinels:get_upstream_round_robin()
+
+ local host = addr:get_addr()
+ local masters = {}
+ local process_masters -- Function that is called to process masters data
+
+ local function masters_cb(err, result)
+ if not err and result and type(result) == 'table' then
+
+ local pending_subrequests = 0
+
+ for _, m in ipairs(result) do
+ local master = flatten_redis_table(m)
+
+ -- Wrap IPv6-addresses in brackets
+ if (master.ip:match(":")) then
+ master.ip = "[" .. master.ip .. "]"
+ end
+
+ if params.sentinel_masters_pattern then
+ if master.name:match(params.sentinel_masters_pattern) then
+ lutil.debugm(N, 'found master %s with ip %s and port %s',
+ master.name, master.ip, master.port)
+ masters[master.name] = master
+ else
+ lutil.debugm(N, 'skip master %s with ip %s and port %s, pattern %s',
+ master.name, master.ip, master.port, params.sentinel_masters_pattern)
+ end
+ else
+ lutil.debugm(N, 'found master %s with ip %s and port %s',
+ master.name, master.ip, master.port)
+ masters[master.name] = master
+ end
+ end
+
+ -- For each master we need to get a list of slaves
+ for k, v in pairs(masters) do
+ v.slaves = {}
+ local function slaves_cb(slave_err, slave_result)
+ if not slave_err and type(slave_result) == 'table' then
+ for _, s in ipairs(slave_result) do
+ local slave = flatten_redis_table(s)
+ lutil.debugm(N, rspamd_config,
+ 'found slave for master %s with ip %s and port %s',
+ v.name, slave.ip, slave.port)
+ -- Wrap IPv6-addresses in brackets
+ if (slave.ip:match(":")) then
+ slave.ip = "[" .. slave.ip .. "]"
+ end
+ v.slaves[#v.slaves + 1] = slave
+ end
+ else
+ logger.errx('cannot get slaves data from Redis Sentinel %s: %s',
+ host:to_string(true), slave_err)
+ addr:fail()
+ end
+
+ pending_subrequests = pending_subrequests - 1
+
+ if pending_subrequests == 0 then
+ -- Finalize masters and slaves
+ process_masters()
+ end
+ end
+
+ local ret = rspamd_redis.make_request {
+ host = addr:get_addr(),
+ timeout = params.timeout,
+ username = params.sentinel_username,
+ password = params.sentinel_password,
+ config = rspamd_config,
+ ev_base = ev_base,
+ cmd = 'SENTINEL',
+ args = { 'slaves', k },
+ no_pool = true,
+ callback = slaves_cb
+ }
+
+ if not ret then
+ logger.errx(rspamd_config, 'cannot connect sentinel when query slaves at address: %s',
+ host:to_string(true))
+ addr:fail()
+ else
+ pending_subrequests = pending_subrequests + 1
+ end
+ end
+
+ addr:ok()
+ else
+ logger.errx('cannot get masters data from Redis Sentinel %s: %s',
+ host:to_string(true), err)
+ addr:fail()
+ end
+ end
+
+ local ret = rspamd_redis.make_request {
+ host = addr:get_addr(),
+ timeout = params.timeout,
+ config = rspamd_config,
+ ev_base = ev_base,
+ username = params.sentinel_username,
+ password = params.sentinel_password,
+ cmd = 'SENTINEL',
+ args = { 'masters' },
+ no_pool = true,
+ callback = masters_cb,
+ }
+
+ if not ret then
+ logger.errx(rspamd_config, 'cannot connect sentinel at address: %s',
+ host:to_string(true))
+ addr:fail()
+ end
+
+ process_masters = function()
+ -- We now form new strings for masters and slaves
+ local read_servers_tbl, write_servers_tbl = {}, {}
+
+ for _, master in pairs(masters) do
+ write_servers_tbl[#write_servers_tbl + 1] = string.format(
+ '%s:%s', master.ip, master.port
+ )
+ read_servers_tbl[#read_servers_tbl + 1] = string.format(
+ '%s:%s', master.ip, master.port
+ )
+
+ for _, slave in ipairs(master.slaves) do
+ if slave['master-link-status'] == 'ok' then
+ read_servers_tbl[#read_servers_tbl + 1] = string.format(
+ '%s:%s', slave.ip, slave.port
+ )
+ end
+ end
+ end
+
+ table.sort(read_servers_tbl)
+ table.sort(write_servers_tbl)
+
+ local read_servers_str = table.concat(read_servers_tbl, ',')
+ local write_servers_str = table.concat(write_servers_tbl, ',')
+
+ lutil.debugm(N, rspamd_config,
+ 'new servers list: %s read; %s write',
+ read_servers_str,
+ write_servers_str)
+
+ if read_servers_str ~= params.read_servers_str then
+ local upstream_list = require "rspamd_upstream_list"
+
+ local read_upstreams = upstream_list.create(rspamd_config,
+ read_servers_str, 6379)
+
+ if read_upstreams then
+ logger.infox(rspamd_config, 'sentinel %s: replace read servers with new list: %s',
+ host:to_string(true), read_servers_str)
+ params.read_servers = read_upstreams
+ params.read_servers_str = read_servers_str
+ end
+ end
+
+ if write_servers_str ~= params.write_servers_str then
+ local upstream_list = require "rspamd_upstream_list"
+
+ local write_upstreams = upstream_list.create(rspamd_config,
+ write_servers_str, 6379)
+
+ if write_upstreams then
+ logger.infox(rspamd_config, 'sentinel %s: replace write servers with new list: %s',
+ host:to_string(true), write_servers_str)
+ params.write_servers = write_upstreams
+ params.write_servers_str = write_servers_str
+
+ local queried = false
+
+ local function monitor_failures(up, _, count)
+ if count > params.sentinel_master_maxerrors and not queried then
+ logger.infox(rspamd_config, 'sentinel: master with address %s, caused %s failures, try to query sentinel',
+ host:to_string(true), count)
+ queried = true -- Avoid multiple checks caused by this monitor
+ redis_query_sentinel(ev_base, params, true)
+ end
+ end
+
+ write_upstreams:add_watcher('failure', monitor_failures)
+ end
+ end
+ end
+
+end
+
+local function add_redis_sentinels(params)
+ local upstream_list = require "rspamd_upstream_list"
+
+ local upstreams_sentinels = upstream_list.create(rspamd_config,
+ params.sentinels, 5000)
+
+ if not upstreams_sentinels then
+ logger.errx(rspamd_config, 'cannot load redis sentinels string: %s',
+ params.sentinels)
+
+ return
+ end
+
+ params.sentinels = upstreams_sentinels
+
+ if not params.sentinel_watch_time then
+ params.sentinel_watch_time = 60 -- Each minute
+ end
+
+ if not params.sentinel_master_maxerrors then
+ params.sentinel_master_maxerrors = 2 -- Maximum number of errors before rechecking
+ end
+
+ rspamd_config:add_on_load(function(_, ev_base, worker)
+ local initialised = false
+ if worker:is_scanner() or worker:get_type() == 'fuzzy' then
+ rspamd_config:add_periodic(ev_base, 0.0, function()
+ redis_query_sentinel(ev_base, params, initialised)
+ initialised = true
+
+ return params.sentinel_watch_time
+ end, false)
+ end
+ end)
+end
+
+local cached_results = {}
+
+local function calculate_redis_hash(params)
+ local cr = require "rspamd_cryptobox_hash"
+
+ local h = cr.create()
+
+ local function rec_hash(k, v)
+ if type(v) == 'string' then
+ h:update(k)
+ h:update(v)
+ elseif type(v) == 'number' then
+ h:update(k)
+ h:update(tostring(v))
+ elseif type(v) == 'table' then
+ for kk, vv in pairs(v) do
+ rec_hash(kk, vv)
+ end
+ end
+ end
+
+ rec_hash('top', params)
+
+ return h:base32()
+end
+
+local function process_redis_opts(options, redis_params)
+ local default_timeout = 1.0
+ local default_expand_keys = false
+
+ if not redis_params['timeout'] or redis_params['timeout'] == default_timeout then
+ if options['timeout'] then
+ redis_params['timeout'] = tonumber(options['timeout'])
+ else
+ redis_params['timeout'] = default_timeout
+ end
+ end
+
+ if options['prefix'] and not redis_params['prefix'] then
+ redis_params['prefix'] = options['prefix']
+ end
+
+ if type(options['expand_keys']) == 'boolean' then
+ redis_params['expand_keys'] = options['expand_keys']
+ else
+ redis_params['expand_keys'] = default_expand_keys
+ end
+
+ if not redis_params['db'] then
+ if options['db'] then
+ redis_params['db'] = tostring(options['db'])
+ elseif options['dbname'] then
+ redis_params['db'] = tostring(options['dbname'])
+ elseif options['database'] then
+ redis_params['db'] = tostring(options['database'])
+ end
+ end
+ if options['username'] and not redis_params['username'] then
+ redis_params['username'] = options['username']
+ end
+ if options['password'] and not redis_params['password'] then
+ redis_params['password'] = options['password']
+ end
+
+ if not redis_params.sentinels and options.sentinels then
+ redis_params.sentinels = options.sentinels
+ end
+
+ if options['sentinel_masters_pattern'] and not redis_params['sentinel_masters_pattern'] then
+ redis_params['sentinel_masters_pattern'] = options['sentinel_masters_pattern']
+ end
+
+end
+
+local function enrich_defaults(rspamd_config, module, redis_params)
+ if rspamd_config then
+ local opts = rspamd_config:get_all_opt('redis')
+
+ if opts then
+ if module then
+ if opts[module] then
+ process_redis_opts(opts[module], redis_params)
+ end
+ end
+
+ process_redis_opts(opts, redis_params)
+ end
+ end
+end
+
+local function maybe_return_cached(redis_params)
+ local h = calculate_redis_hash(redis_params)
+
+ if cached_results[h] then
+ lutil.debugm(N, 'reused redis server: %s', redis_params)
+ return cached_results[h]
+ end
+
+ redis_params.hash = h
+ cached_results[h] = redis_params
+
+ if not redis_params.read_only and redis_params.sentinels then
+ add_redis_sentinels(redis_params)
+ end
+
+ lutil.debugm(N, 'loaded new redis server: %s', redis_params)
+ return redis_params
+end
+
+--[[[
+-- @module lua_redis
+-- This module contains helper functions for working with Redis
+--]]
+local function process_redis_options(options, rspamd_config, result)
+ local default_port = 6379
+ local upstream_list = require "rspamd_upstream_list"
+ local read_only = true
+
+ -- Try to get read servers:
+ local upstreams_read, upstreams_write
+
+ if options['read_servers'] then
+ if rspamd_config then
+ upstreams_read = upstream_list.create(rspamd_config,
+ options['read_servers'], default_port)
+ else
+ upstreams_read = upstream_list.create(options['read_servers'],
+ default_port)
+ end
+
+ result.read_servers_str = options['read_servers']
+ elseif options['servers'] then
+ if rspamd_config then
+ upstreams_read = upstream_list.create(rspamd_config,
+ options['servers'], default_port)
+ else
+ upstreams_read = upstream_list.create(options['servers'], default_port)
+ end
+
+ result.read_servers_str = options['servers']
+ read_only = false
+ elseif options['server'] then
+ if rspamd_config then
+ upstreams_read = upstream_list.create(rspamd_config,
+ options['server'], default_port)
+ else
+ upstreams_read = upstream_list.create(options['server'], default_port)
+ end
+
+ result.read_servers_str = options['server']
+ read_only = false
+ end
+
+ if upstreams_read then
+ if options['write_servers'] then
+ if rspamd_config then
+ upstreams_write = upstream_list.create(rspamd_config,
+ options['write_servers'], default_port)
+ else
+ upstreams_write = upstream_list.create(options['write_servers'],
+ default_port)
+ end
+ result.write_servers_str = options['write_servers']
+ read_only = false
+ elseif not read_only then
+ upstreams_write = upstreams_read
+ result.write_servers_str = result.read_servers_str
+ end
+ end
+
+ -- Store options
+ process_redis_opts(options, result)
+
+ if read_only and not upstreams_write then
+ result.read_only = true
+ elseif upstreams_write then
+ result.read_only = false
+ end
+
+ if upstreams_read then
+ result.read_servers = upstreams_read
+
+ if upstreams_write then
+ result.write_servers = upstreams_write
+ end
+
+ return true
+ end
+
+ lutil.debugm(N, rspamd_config,
+ 'cannot load redis server from obj: %s, processed to %s',
+ options, result)
+
+ return false
+end
+
+--[[[
+@function try_load_redis_servers(options, rspamd_config, no_fallback)
+Tries to load redis servers from the specified `options` object.
+Returns `redis_params` table or nil in case of failure
+
+--]]
+exports.try_load_redis_servers = function(options, rspamd_config, no_fallback, module_name)
+ local result = {}
+
+ if process_redis_options(options, rspamd_config, result) then
+ if not no_fallback then
+ enrich_defaults(rspamd_config, module_name, result)
+ end
+ return maybe_return_cached(result)
+ end
+end
+
+-- This function parses redis server definition using either
+-- specific server string for this module or global
+-- redis section
+local function rspamd_parse_redis_server(module_name, module_opts, no_fallback)
+ local result = {}
+
+ -- Try local options
+ local opts
+ lutil.debugm(N, rspamd_config, 'try load redis config for: %s', module_name)
+ if not module_opts then
+ opts = rspamd_config:get_all_opt(module_name)
+ else
+ opts = module_opts
+ end
+
+ if opts then
+ local ret
+
+ if opts.redis then
+ ret = process_redis_options(opts.redis, rspamd_config, result)
+
+ if ret then
+ if not no_fallback then
+ enrich_defaults(rspamd_config, module_name, result)
+ end
+ return maybe_return_cached(result)
+ end
+ end
+
+ ret = process_redis_options(opts, rspamd_config, result)
+
+ if ret then
+ if not no_fallback then
+ enrich_defaults(rspamd_config, module_name, result)
+ end
+ return maybe_return_cached(result)
+ end
+ end
+
+ if no_fallback then
+ logger.infox(rspamd_config, "cannot find Redis definitions for %s and fallback is disabled",
+ module_name)
+
+ return nil
+ end
+
+ -- Try global options
+ opts = rspamd_config:get_all_opt('redis')
+
+ if opts then
+ local ret
+
+ if opts[module_name] then
+ ret = process_redis_options(opts[module_name], rspamd_config, result)
+
+ if ret then
+ return maybe_return_cached(result)
+ end
+ else
+ ret = process_redis_options(opts, rspamd_config, result)
+
+ -- Exclude disabled
+ if opts['disabled_modules'] then
+ for _, v in ipairs(opts['disabled_modules']) do
+ if v == module_name then
+ logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled",
+ module_name)
+
+ return nil
+ end
+ end
+ end
+
+ if ret then
+ logger.infox(rspamd_config, "use default Redis settings for %s",
+ module_name)
+ return maybe_return_cached(result)
+ end
+ end
+ end
+
+ if result.read_servers then
+ return maybe_return_cached(result)
+ end
+
+ return nil
+end
+
+--[[[
+-- @function lua_redis.parse_redis_server(module_name, module_opts, no_fallback)
+-- Extracts Redis server settings from configuration
+-- @param {string} module_name name of module to get settings for
+-- @param {table} module_opts settings for module or `nil` to fetch them from configuration
+-- @param {boolean} no_fallback should be `true` if global settings must not be used
+-- @return {table} redis server settings
+-- @example
+-- local rconfig = lua_redis.parse_redis_server('my_module')
+-- -- rconfig contains upstream_list objects in ['write_servers'] and ['read_servers']
+-- -- ['timeout'] contains timeout in seconds
+-- -- ['expand_keys'] if true tells that redis key expansion is enabled
+--]]
+
+exports.rspamd_parse_redis_server = rspamd_parse_redis_server
+exports.parse_redis_server = rspamd_parse_redis_server
+
+local process_cmd = {
+ bitop = function(args)
+ local idx_l = {}
+ for i = 2, #args do
+ table.insert(idx_l, i)
+ end
+ return idx_l
+ end,
+ blpop = function(args)
+ local idx_l = {}
+ for i = 1, #args - 1 do
+ table.insert(idx_l, i)
+ end
+ return idx_l
+ end,
+ eval = function(args)
+ local idx_l = {}
+ local numkeys = args[2]
+ if numkeys and tonumber(numkeys) >= 1 then
+ for i = 3, numkeys + 2 do
+ table.insert(idx_l, i)
+ end
+ end
+ return idx_l
+ end,
+ set = function(args)
+ return { 1 }
+ end,
+ mget = function(args)
+ local idx_l = {}
+ for i = 1, #args do
+ table.insert(idx_l, i)
+ end
+ return idx_l
+ end,
+ mset = function(args)
+ local idx_l = {}
+ for i = 1, #args, 2 do
+ table.insert(idx_l, i)
+ end
+ return idx_l
+ end,
+ sdiffstore = function(args)
+ local idx_l = {}
+ for i = 2, #args do
+ table.insert(idx_l, i)
+ end
+ return idx_l
+ end,
+ smove = function(args)
+ return { 1, 2 }
+ end,
+ script = function()
+ end
+}
+process_cmd.append = process_cmd.set
+process_cmd.auth = process_cmd.script
+process_cmd.bgrewriteaof = process_cmd.script
+process_cmd.bgsave = process_cmd.script
+process_cmd.bitcount = process_cmd.set
+process_cmd.bitfield = process_cmd.set
+process_cmd.bitpos = process_cmd.set
+process_cmd.brpop = process_cmd.blpop
+process_cmd.brpoplpush = process_cmd.blpop
+process_cmd.client = process_cmd.script
+process_cmd.cluster = process_cmd.script
+process_cmd.command = process_cmd.script
+process_cmd.config = process_cmd.script
+process_cmd.dbsize = process_cmd.script
+process_cmd.debug = process_cmd.script
+process_cmd.decr = process_cmd.set
+process_cmd.decrby = process_cmd.set
+process_cmd.del = process_cmd.mget
+process_cmd.discard = process_cmd.script
+process_cmd.dump = process_cmd.set
+process_cmd.echo = process_cmd.script
+process_cmd.evalsha = process_cmd.eval
+process_cmd.exec = process_cmd.script
+process_cmd.exists = process_cmd.mget
+process_cmd.expire = process_cmd.set
+process_cmd.expireat = process_cmd.set
+process_cmd.flushall = process_cmd.script
+process_cmd.flushdb = process_cmd.script
+process_cmd.geoadd = process_cmd.set
+process_cmd.geohash = process_cmd.set
+process_cmd.geopos = process_cmd.set
+process_cmd.geodist = process_cmd.set
+process_cmd.georadius = process_cmd.set
+process_cmd.georadiusbymember = process_cmd.set
+process_cmd.get = process_cmd.set
+process_cmd.getbit = process_cmd.set
+process_cmd.getrange = process_cmd.set
+process_cmd.getset = process_cmd.set
+process_cmd.hdel = process_cmd.set
+process_cmd.hexists = process_cmd.set
+process_cmd.hget = process_cmd.set
+process_cmd.hgetall = process_cmd.set
+process_cmd.hincrby = process_cmd.set
+process_cmd.hincrbyfloat = process_cmd.set
+process_cmd.hkeys = process_cmd.set
+process_cmd.hlen = process_cmd.set
+process_cmd.hmget = process_cmd.set
+process_cmd.hmset = process_cmd.set
+process_cmd.hscan = process_cmd.set
+process_cmd.hset = process_cmd.set
+process_cmd.hsetnx = process_cmd.set
+process_cmd.hstrlen = process_cmd.set
+process_cmd.hvals = process_cmd.set
+process_cmd.incr = process_cmd.set
+process_cmd.incrby = process_cmd.set
+process_cmd.incrbyfloat = process_cmd.set
+process_cmd.info = process_cmd.script
+process_cmd.keys = process_cmd.script
+process_cmd.lastsave = process_cmd.script
+process_cmd.lindex = process_cmd.set
+process_cmd.linsert = process_cmd.set
+process_cmd.llen = process_cmd.set
+process_cmd.lpop = process_cmd.set
+process_cmd.lpush = process_cmd.set
+process_cmd.lpushx = process_cmd.set
+process_cmd.lrange = process_cmd.set
+process_cmd.lrem = process_cmd.set
+process_cmd.lset = process_cmd.set
+process_cmd.ltrim = process_cmd.set
+process_cmd.migrate = process_cmd.script
+process_cmd.monitor = process_cmd.script
+process_cmd.move = process_cmd.set
+process_cmd.msetnx = process_cmd.mset
+process_cmd.multi = process_cmd.script
+process_cmd.object = process_cmd.script
+process_cmd.persist = process_cmd.set
+process_cmd.pexpire = process_cmd.set
+process_cmd.pexpireat = process_cmd.set
+process_cmd.pfadd = process_cmd.set
+process_cmd.pfcount = process_cmd.set
+process_cmd.pfmerge = process_cmd.mget
+process_cmd.ping = process_cmd.script
+process_cmd.psetex = process_cmd.set
+process_cmd.psubscribe = process_cmd.script
+process_cmd.pubsub = process_cmd.script
+process_cmd.pttl = process_cmd.set
+process_cmd.publish = process_cmd.script
+process_cmd.punsubscribe = process_cmd.script
+process_cmd.quit = process_cmd.script
+process_cmd.randomkey = process_cmd.script
+process_cmd.readonly = process_cmd.script
+process_cmd.readwrite = process_cmd.script
+process_cmd.rename = process_cmd.mget
+process_cmd.renamenx = process_cmd.mget
+process_cmd.restore = process_cmd.set
+process_cmd.role = process_cmd.script
+process_cmd.rpop = process_cmd.set
+process_cmd.rpoplpush = process_cmd.mget
+process_cmd.rpush = process_cmd.set
+process_cmd.rpushx = process_cmd.set
+process_cmd.sadd = process_cmd.set
+process_cmd.save = process_cmd.script
+process_cmd.scard = process_cmd.set
+process_cmd.sdiff = process_cmd.mget
+process_cmd.select = process_cmd.script
+process_cmd.setbit = process_cmd.set
+process_cmd.setex = process_cmd.set
+process_cmd.setnx = process_cmd.set
+process_cmd.sinterstore = process_cmd.sdiff
+process_cmd.sismember = process_cmd.set
+process_cmd.slaveof = process_cmd.script
+process_cmd.slowlog = process_cmd.script
+process_cmd.smembers = process_cmd.script
+process_cmd.sort = process_cmd.set
+process_cmd.spop = process_cmd.set
+process_cmd.srandmember = process_cmd.set
+process_cmd.srem = process_cmd.set
+process_cmd.strlen = process_cmd.set
+process_cmd.subscribe = process_cmd.script
+process_cmd.sunion = process_cmd.mget
+process_cmd.sunionstore = process_cmd.mget
+process_cmd.swapdb = process_cmd.script
+process_cmd.sync = process_cmd.script
+process_cmd.time = process_cmd.script
+process_cmd.touch = process_cmd.mget
+process_cmd.ttl = process_cmd.set
+process_cmd.type = process_cmd.set
+process_cmd.unsubscribe = process_cmd.script
+process_cmd.unlink = process_cmd.mget
+process_cmd.unwatch = process_cmd.script
+process_cmd.wait = process_cmd.script
+process_cmd.watch = process_cmd.mget
+process_cmd.zadd = process_cmd.set
+process_cmd.zcard = process_cmd.set
+process_cmd.zcount = process_cmd.set
+process_cmd.zincrby = process_cmd.set
+process_cmd.zinterstore = process_cmd.eval
+process_cmd.zlexcount = process_cmd.set
+process_cmd.zrange = process_cmd.set
+process_cmd.zrangebylex = process_cmd.set
+process_cmd.zrank = process_cmd.set
+process_cmd.zrem = process_cmd.set
+process_cmd.zrembylex = process_cmd.set
+process_cmd.zrembyrank = process_cmd.set
+process_cmd.zrembyscore = process_cmd.set
+process_cmd.zrevrange = process_cmd.set
+process_cmd.zrevrangebyscore = process_cmd.set
+process_cmd.zrevrank = process_cmd.set
+process_cmd.zscore = process_cmd.set
+process_cmd.zunionstore = process_cmd.eval
+process_cmd.scan = process_cmd.script
+process_cmd.sscan = process_cmd.set
+process_cmd.hscan = process_cmd.set
+process_cmd.zscan = process_cmd.set
+
+local function get_key_indexes(cmd, args)
+ local idx_l = {}
+ cmd = string.lower(cmd)
+ if process_cmd[cmd] then
+ idx_l = process_cmd[cmd](args)
+ else
+ logger.warnx(rspamd_config, "Don't know how to extract keys for %s Redis command", cmd)
+ end
+ return idx_l
+end
+
+local gen_meta = {
+ principal_recipient = function(task)
+ return task:get_principal_recipient()
+ end,
+ principal_recipient_domain = function(task)
+ local p = task:get_principal_recipient()
+ if not p then
+ return
+ end
+ return string.match(p, '.*@(.*)')
+ end,
+ ip = function(task)
+ local i = task:get_ip()
+ if i and i:is_valid() then
+ return i:to_string()
+ end
+ end,
+ from = function(task)
+ return ((task:get_from('smtp') or E)[1] or E)['addr']
+ end,
+ from_domain = function(task)
+ return ((task:get_from('smtp') or E)[1] or E)['domain']
+ end,
+ from_domain_or_helo_domain = function(task)
+ local d = ((task:get_from('smtp') or E)[1] or E)['domain']
+ if d and #d > 0 then
+ return d
+ end
+ return task:get_helo()
+ end,
+ mime_from = function(task)
+ return ((task:get_from('mime') or E)[1] or E)['addr']
+ end,
+ mime_from_domain = function(task)
+ return ((task:get_from('mime') or E)[1] or E)['domain']
+ end,
+}
+
+local function gen_get_esld(f)
+ return function(task)
+ local d = f(task)
+ if not d then
+ return
+ end
+ return rspamd_util.get_tld(d)
+ end
+end
+
+gen_meta.smtp_from = gen_meta.from
+gen_meta.smtp_from_domain = gen_meta.from_domain
+gen_meta.smtp_from_domain_or_helo_domain = gen_meta.from_domain_or_helo_domain
+gen_meta.esld_principal_recipient_domain = gen_get_esld(gen_meta.principal_recipient_domain)
+gen_meta.esld_from_domain = gen_get_esld(gen_meta.from_domain)
+gen_meta.esld_smtp_from_domain = gen_meta.esld_from_domain
+gen_meta.esld_mime_from_domain = gen_get_esld(gen_meta.mime_from_domain)
+gen_meta.esld_from_domain_or_helo_domain = gen_get_esld(gen_meta.from_domain_or_helo_domain)
+gen_meta.esld_smtp_from_domain_or_helo_domain = gen_meta.esld_from_domain_or_helo_domain
+
+local function get_key_expansion_metadata(task)
+
+ local md_mt = {
+ __index = function(self, k)
+ k = string.lower(k)
+ local v = rawget(self, k)
+ if v then
+ return v
+ end
+ if gen_meta[k] then
+ v = gen_meta[k](task)
+ rawset(self, k, v)
+ end
+ return v
+ end,
+ }
+
+ local lazy_meta = {}
+ setmetatable(lazy_meta, md_mt)
+ return lazy_meta
+
+end
+
+-- Performs async call to redis hiding all complexity inside function
+-- task - rspamd_task
+-- redis_params - valid params returned by rspamd_parse_redis_server
+-- key - key to select upstream or nil to select round-robin/master-slave
+-- is_write - true if need to write to redis server
+-- callback - function to be called upon request is completed
+-- command - redis command
+-- args - table of arguments
+-- extra_opts - table of optional request arguments
+local function rspamd_redis_make_request(task, redis_params, key, is_write,
+ callback, command, args, extra_opts)
+ local addr
+ local function rspamd_redis_make_request_cb(err, data)
+ if err then
+ addr:fail()
+ else
+ addr:ok()
+ end
+ if callback then
+ callback(err, data, addr)
+ end
+ end
+ if not task or not redis_params or not command then
+ return false, nil, nil
+ end
+
+ local rspamd_redis = require "rspamd_redis"
+
+ if key then
+ if is_write then
+ addr = redis_params['write_servers']:get_upstream_by_hash(key)
+ else
+ addr = redis_params['read_servers']:get_upstream_by_hash(key)
+ end
+ else
+ if is_write then
+ addr = redis_params['write_servers']:get_upstream_master_slave(key)
+ else
+ addr = redis_params['read_servers']:get_upstream_round_robin(key)
+ end
+ end
+
+ if not addr then
+ logger.errx(task, 'cannot select server to make redis request')
+ end
+
+ if redis_params['expand_keys'] then
+ local m = get_key_expansion_metadata(task)
+ local indexes = get_key_indexes(command, args)
+ for _, i in ipairs(indexes) do
+ args[i] = lutil.template(args[i], m)
+ end
+ end
+
+ local ip_addr = addr:get_addr()
+ local options = {
+ task = task,
+ callback = rspamd_redis_make_request_cb,
+ host = ip_addr,
+ timeout = redis_params['timeout'],
+ cmd = command,
+ args = args
+ }
+
+ if extra_opts then
+ for k, v in pairs(extra_opts) do
+ options[k] = v
+ end
+ end
+
+ if redis_params['username'] then
+ options['username'] = redis_params['username']
+ end
+
+ if redis_params['password'] then
+ options['password'] = redis_params['password']
+ end
+
+ if redis_params['db'] then
+ options['dbname'] = redis_params['db']
+ end
+
+ lutil.debugm(N, task, 'perform request to redis server' ..
+ ' (host=%s, timeout=%s): cmd: %s', ip_addr,
+ options.timeout, options.cmd)
+
+ local ret, conn = rspamd_redis.make_request(options)
+
+ if not ret then
+ addr:fail()
+ logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr))
+ end
+
+ return ret, conn, addr
+end
+
+--[[[
+-- @function lua_redis.redis_make_request(task, redis_params, key, is_write, callback, command, args)
+-- Sends a request to Redis
+-- @param {rspamd_task} task task object
+-- @param {table} redis_params redis configuration in format returned by lua_redis.parse_redis_server()
+-- @param {string} key key to use for sharding
+-- @param {boolean} is_write should be `true` if we are performing a write operating
+-- @param {function} callback callback function (first parameter is error if applicable, second is a 2D array (table))
+-- @param {string} command Redis command to run
+-- @param {table} args Numerically indexed table containing arguments for command
+--]]
+
+exports.rspamd_redis_make_request = rspamd_redis_make_request
+exports.redis_make_request = rspamd_redis_make_request
+
+local function redis_make_request_taskless(ev_base, cfg, redis_params, key,
+ is_write, callback, command, args, extra_opts)
+ if not ev_base or not redis_params or not command then
+ return false, nil, nil
+ end
+
+ local addr
+ local function rspamd_redis_make_request_cb(err, data)
+ if err then
+ addr:fail()
+ else
+ addr:ok()
+ end
+ if callback then
+ callback(err, data, addr)
+ end
+ end
+
+ local rspamd_redis = require "rspamd_redis"
+
+ if key then
+ if is_write then
+ addr = redis_params['write_servers']:get_upstream_by_hash(key)
+ else
+ addr = redis_params['read_servers']:get_upstream_by_hash(key)
+ end
+ else
+ if is_write then
+ addr = redis_params['write_servers']:get_upstream_master_slave(key)
+ else
+ addr = redis_params['read_servers']:get_upstream_round_robin(key)
+ end
+ end
+
+ if not addr then
+ logger.errx(cfg, 'cannot select server to make redis request')
+ end
+
+ local options = {
+ ev_base = ev_base,
+ config = cfg,
+ callback = rspamd_redis_make_request_cb,
+ host = addr:get_addr(),
+ timeout = redis_params['timeout'],
+ cmd = command,
+ args = args
+ }
+ if extra_opts then
+ for k, v in pairs(extra_opts) do
+ options[k] = v
+ end
+ end
+
+ if redis_params['username'] then
+ options['username'] = redis_params['username']
+ end
+
+ if redis_params['password'] then
+ options['password'] = redis_params['password']
+ end
+
+ if redis_params['db'] then
+ options['dbname'] = redis_params['db']
+ end
+
+ lutil.debugm(N, cfg, 'perform taskless request to redis server' ..
+ ' (host=%s, timeout=%s): cmd: %s', options.host:tostring(true),
+ options.timeout, options.cmd)
+ local ret, conn = rspamd_redis.make_request(options)
+ if not ret then
+ logger.errx('cannot execute redis request')
+ addr:fail()
+ end
+
+ return ret, conn, addr
+end
+
+--[[[
+-- @function lua_redis.redis_make_request_taskless(ev_base, redis_params, key, is_write, callback, command, args)
+-- Sends a request to Redis in context where `task` is not available for some specific use-cases
+-- Identical to redis_make_request() except in that first parameter is an `event base` object
+--]]
+
+exports.rspamd_redis_make_request_taskless = redis_make_request_taskless
+exports.redis_make_request_taskless = redis_make_request_taskless
+
+local redis_scripts = {
+}
+
+local function script_set_loaded(script)
+ if script.sha then
+ script.loaded = true
+ end
+
+ local wait_table = {}
+ for _, s in ipairs(script.waitq) do
+ table.insert(wait_table, s)
+ end
+
+ script.waitq = {}
+
+ for _, s in ipairs(wait_table) do
+ s(script.loaded)
+ end
+end
+
+local function prepare_redis_call(script)
+ local servers = {}
+ local options = {}
+
+ if script.redis_params.read_servers then
+ servers = lutil.table_merge(servers, script.redis_params.read_servers:all_upstreams())
+ end
+ if script.redis_params.write_servers then
+ servers = lutil.table_merge(servers, script.redis_params.write_servers:all_upstreams())
+ end
+
+ -- Call load script on each server, set loaded flag
+ script.in_flight = #servers
+ for _, s in ipairs(servers) do
+ local cur_opts = {
+ host = s:get_addr(),
+ timeout = script.redis_params['timeout'],
+ cmd = 'SCRIPT',
+ args = { 'LOAD', script.script },
+ upstream = s
+ }
+
+ if script.redis_params['username'] then
+ cur_opts['username'] = script.redis_params['username']
+ end
+
+ if script.redis_params['password'] then
+ cur_opts['password'] = script.redis_params['password']
+ end
+
+ if script.redis_params['db'] then
+ cur_opts['dbname'] = script.redis_params['db']
+ end
+
+ table.insert(options, cur_opts)
+ end
+
+ return options
+end
+
+local function load_script_task(script, task, is_write)
+ local rspamd_redis = require "rspamd_redis"
+ local opts = prepare_redis_call(script)
+
+ for _, opt in ipairs(opts) do
+ opt.task = task
+ opt.is_write = is_write
+ opt.callback = function(err, data)
+ if err then
+ logger.errx(task, 'cannot upload script to %s: %s; registered from: %s:%s',
+ opt.upstream:get_addr():to_string(true),
+ err, script.caller.short_src, script.caller.currentline)
+ opt.upstream:fail()
+ script.fatal_error = err
+ else
+ opt.upstream:ok()
+ logger.infox(task,
+ "uploaded redis script to %s %s %s, sha: %s",
+ opt.upstream:get_addr():to_string(true),
+ script.filename and "from file" or "with id", script.filename or script.id, data)
+ script.sha = data -- We assume that sha is the same on all servers
+ end
+ script.in_flight = script.in_flight - 1
+
+ if script.in_flight == 0 then
+ script_set_loaded(script)
+ end
+ end
+
+ local ret = rspamd_redis.make_request(opt)
+
+ if not ret then
+ logger.errx('cannot execute redis request to load script on %s',
+ opt.upstream:get_addr())
+ script.in_flight = script.in_flight - 1
+ opt.upstream:fail()
+ end
+
+ if script.in_flight == 0 then
+ script_set_loaded(script)
+ end
+ end
+end
+
+local function load_script_taskless(script, cfg, ev_base, is_write)
+ local rspamd_redis = require "rspamd_redis"
+ local opts = prepare_redis_call(script)
+
+ for _, opt in ipairs(opts) do
+ opt.config = cfg
+ opt.ev_base = ev_base
+ opt.is_write = is_write
+ opt.callback = function(err, data)
+ if err then
+ logger.errx(cfg, 'cannot upload script to %s: %s; registered from: %s:%s, filename: %s',
+ opt.upstream:get_addr():to_string(true),
+ err, script.caller.short_src, script.caller.currentline, script.filename)
+ opt.upstream:fail()
+ script.fatal_error = err
+ else
+ opt.upstream:ok()
+ logger.infox(cfg,
+ "uploaded redis script to %s %s %s, sha: %s",
+ opt.upstream:get_addr():to_string(true),
+ script.filename and "from file" or "with id", script.filename or script.id,
+ data)
+ script.sha = data -- We assume that sha is the same on all servers
+ script.fatal_error = nil
+ end
+ script.in_flight = script.in_flight - 1
+
+ if script.in_flight == 0 then
+ script_set_loaded(script)
+ end
+ end
+ local ret = rspamd_redis.make_request(opt)
+
+ if not ret then
+ logger.errx('cannot execute redis request to load script on %s',
+ opt.upstream:get_addr())
+ script.in_flight = script.in_flight - 1
+ opt.upstream:fail()
+ end
+
+ if script.in_flight == 0 then
+ script_set_loaded(script)
+ end
+ end
+end
+
+local function load_redis_script(script, cfg, ev_base, _)
+ if script.redis_params then
+ load_script_taskless(script, cfg, ev_base)
+ end
+end
+
+local function add_redis_script(script, redis_params, caller_level, maybe_filename)
+ if not caller_level then
+ caller_level = 2
+ end
+ local caller = debug.getinfo(caller_level) or debug.getinfo(caller_level - 1) or E
+
+ local new_script = {
+ caller = caller,
+ loaded = false,
+ redis_params = redis_params,
+ script = script,
+ waitq = {}, -- callbacks pending for script being loaded
+ id = #redis_scripts + 1,
+ filename = maybe_filename,
+ }
+
+ -- Register on load function
+ rspamd_config:add_on_load(function(cfg, ev_base, worker)
+ local mult = 0.0
+ rspamd_config:add_periodic(ev_base, 0.0, function()
+ if not new_script.sha then
+ load_redis_script(new_script, cfg, ev_base, worker)
+ mult = mult + 1
+ return 1.0 * mult -- Check one more time in one second
+ end
+
+ return false
+ end, false)
+ end)
+
+ table.insert(redis_scripts, new_script)
+
+ return #redis_scripts
+end
+exports.add_redis_script = add_redis_script
+
+-- Loads a Redis script from a file, strips comments, and passes the content to
+-- `add_redis_script` function.
+--
+-- @param filename The name of the file containing the Redis script.
+-- @param redis_params The Redis parameters to use for this script.
+-- @return The ID of the newly added Redis script.
+--
+local function load_redis_script_from_file(filename, redis_params, dir)
+ local lua_util = require "lua_util"
+ local rspamd_logger = require "rspamd_logger"
+
+ if not dir then
+ dir = rspamd_paths.LUALIBDIR
+ end
+ local path = filename
+ if filename:sub(1, 1) ~= package.config:sub(1, 1) then
+ -- Relative path
+ path = lua_util.join_path(dir, "redis_scripts", filename)
+ end
+ -- Read file contents
+ local file = io.open(path, "r")
+ if not file then
+ rspamd_logger.errx("failed to open Redis script file: %s", path)
+ return nil
+ end
+ local script = file:read("*all")
+ if not script then
+ rspamd_logger.errx("failed to load Redis script file: %s", path)
+ return nil
+ end
+ file:close()
+ script = lua_util.strip_lua_comments(script)
+
+ return add_redis_script(script, redis_params, 3, filename)
+end
+
+exports.load_redis_script_from_file = load_redis_script_from_file
+
+local function exec_redis_script(id, params, callback, keys, args)
+ local redis_args = {}
+
+ if not redis_scripts[id] then
+ logger.errx("cannot find registered script with id %s", id)
+ return false
+ end
+
+ local script = redis_scripts[id]
+
+ if script.fatal_error then
+ callback(script.fatal_error, nil)
+ return true
+ end
+
+ if not script.redis_params then
+ callback('no redis servers defined', nil)
+ return true
+ end
+
+ local function do_call(can_reload)
+ local function redis_cb(err, data)
+ if not err then
+ callback(err, data)
+ elseif string.match(err, 'NOSCRIPT') then
+ -- Schedule restart
+ script.sha = nil
+ if can_reload then
+ table.insert(script.waitq, do_call)
+ if script.in_flight == 0 then
+ -- Reload scripts if this has not been initiated yet
+ if params.task then
+ load_script_task(script, params.task)
+ else
+ load_script_taskless(script, rspamd_config, params.ev_base)
+ end
+ end
+ else
+ callback(err, data)
+ end
+ else
+ callback(err, data)
+ end
+ end
+
+ if #redis_args == 0 then
+ table.insert(redis_args, script.sha)
+ table.insert(redis_args, tostring(#keys))
+ for _, k in ipairs(keys) do
+ table.insert(redis_args, k)
+ end
+
+ if type(args) == 'table' then
+ for _, a in ipairs(args) do
+ table.insert(redis_args, a)
+ end
+ end
+ end
+
+ if params.task then
+ if not rspamd_redis_make_request(params.task, script.redis_params,
+ params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
+ callback('Cannot make redis request', nil)
+ end
+ else
+ if not redis_make_request_taskless(params.ev_base, rspamd_config,
+ script.redis_params,
+ params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then
+ callback('Cannot make redis request', nil)
+ end
+ end
+ end
+
+ if script.loaded then
+ do_call(true)
+ else
+ -- Delayed until scripts are loaded
+ if not params.task then
+ table.insert(script.waitq, do_call)
+ else
+ -- TODO: fix taskfull requests
+ table.insert(script.waitq, function()
+ if script.loaded then
+ do_call(false)
+ else
+ callback('NOSCRIPT', nil)
+ end
+ end)
+ load_script_task(script, params.task, params.is_write)
+ end
+ end
+
+ return true
+end
+
+exports.exec_redis_script = exec_redis_script
+
+local function redis_connect_sync(redis_params, is_write, key, cfg, ev_base)
+ if not redis_params then
+ return false, nil
+ end
+
+ local rspamd_redis = require "rspamd_redis"
+ local addr
+
+ if key then
+ if is_write then
+ addr = redis_params['write_servers']:get_upstream_by_hash(key)
+ else
+ addr = redis_params['read_servers']:get_upstream_by_hash(key)
+ end
+ else
+ if is_write then
+ addr = redis_params['write_servers']:get_upstream_master_slave(key)
+ else
+ addr = redis_params['read_servers']:get_upstream_round_robin(key)
+ end
+ end
+
+ if not addr then
+ logger.errx(cfg, 'cannot select server to make redis request')
+ end
+
+ local options = {
+ host = addr:get_addr(),
+ timeout = redis_params['timeout'],
+ config = cfg or rspamd_config,
+ ev_base = ev_base or rspamadm_ev_base,
+ session = redis_params.session or rspamadm_session
+ }
+
+ for k, v in pairs(redis_params) do
+ options[k] = v
+ end
+
+ if not options.config then
+ logger.errx('config is not set')
+ return false, nil, addr
+ end
+
+ if not options.ev_base then
+ logger.errx('ev_base is not set')
+ return false, nil, addr
+ end
+
+ if not options.session then
+ logger.errx('session is not set')
+ return false, nil, addr
+ end
+
+ local ret, conn = rspamd_redis.connect_sync(options)
+ if not ret then
+ logger.errx('cannot create redis connection: %s', conn)
+ addr:fail()
+
+ return false, nil, addr
+ end
+
+ if conn then
+ local need_exec = false
+ if redis_params['username'] then
+ if redis_params['password'] then
+ conn:add_cmd('AUTH', { redis_params['username'], redis_params['password'] })
+ need_exec = true
+ else
+ logger.warnx('Redis requires a password when username is supplied')
+ return false, nil, addr
+ end
+ elseif redis_params['password'] then
+ conn:add_cmd('AUTH', { redis_params['password'] })
+ need_exec = true
+ end
+
+ if redis_params['db'] then
+ conn:add_cmd('SELECT', { tostring(redis_params['db']) })
+ need_exec = true
+ elseif redis_params['dbname'] then
+ conn:add_cmd('SELECT', { tostring(redis_params['dbname']) })
+ need_exec = true
+ end
+
+ if need_exec then
+ local exec_ret, res = conn:exec()
+
+ if not exec_ret then
+ logger.errx('cannot prepare redis connection (authentication or db selection failure): %s',
+ res)
+ addr:fail()
+ return false, nil, addr
+ end
+ end
+ end
+
+ return ret, conn, addr
+end
+
+exports.redis_connect_sync = redis_connect_sync
+
+--[[[
+-- @function lua_redis.request(redis_params, attrs, req)
+-- Sends a request to Redis synchronously with coroutines or asynchronously using
+-- a callback (modern API)
+-- @param redis_params a table of redis server parameters
+-- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
+-- @param req a table of request: a command + command options
+-- @return {result,data/connection,address} boolean result, connection object in case of async request and results if using coroutines, redis server address
+--]]
+
+exports.request = function(redis_params, attrs, req)
+ local lua_util = require "lua_util"
+
+ if not attrs or not redis_params or not req then
+ logger.errx('invalid arguments for redis request')
+ return false, nil, nil
+ end
+
+ if not (attrs.task or (attrs.config and attrs.ev_base)) then
+ logger.errx('invalid attributes for redis request')
+ return false, nil, nil
+ end
+
+ local opts = lua_util.shallowcopy(attrs)
+
+ local log_obj = opts.task or opts.config
+
+ local addr
+
+ if opts.callback then
+ -- Wrap callback
+ local callback = opts.callback
+ local function rspamd_redis_make_request_cb(err, data)
+ if err then
+ addr:fail()
+ else
+ addr:ok()
+ end
+ callback(err, data, addr)
+ end
+ opts.callback = rspamd_redis_make_request_cb
+ end
+
+ local rspamd_redis = require "rspamd_redis"
+ local is_write = opts.is_write
+
+ if opts.key then
+ if is_write then
+ addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
+ else
+ addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
+ end
+ else
+ if is_write then
+ addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
+ else
+ addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
+ end
+ end
+
+ if not addr then
+ logger.errx(log_obj, 'cannot select server to make redis request')
+ end
+
+ opts.host = addr:get_addr()
+ opts.timeout = redis_params.timeout
+
+ if type(req) == 'string' then
+ opts.cmd = req
+ else
+ -- XXX: modifies the input table
+ opts.cmd = table.remove(req, 1);
+ opts.args = req
+ end
+
+ if redis_params.username then
+ opts.username = redis_params.username
+ end
+
+ if redis_params.password then
+ opts.password = redis_params.password
+ end
+
+ if redis_params.db then
+ opts.dbname = redis_params.db
+ end
+
+ lutil.debugm(N, 'perform generic request to redis server' ..
+ ' (host=%s, timeout=%s): cmd: %s, arguments: %s', addr,
+ opts.timeout, opts.cmd, opts.args)
+
+ if opts.callback then
+ local ret, conn = rspamd_redis.make_request(opts)
+ if not ret then
+ logger.errx(log_obj, 'cannot execute redis request')
+ addr:fail()
+ end
+
+ return ret, conn, addr
+ else
+ -- Coroutines version
+ local ret, conn = rspamd_redis.connect_sync(opts)
+ if not ret then
+ logger.errx(log_obj, 'cannot execute redis request')
+ addr:fail()
+ else
+ conn:add_cmd(opts.cmd, opts.args)
+ return conn:exec()
+ end
+ return false, nil, addr
+ end
+end
+
+--[[[
+-- @function lua_redis.connect(redis_params, attrs)
+-- Connects to Redis synchronously with coroutines or asynchronously using a callback (modern API)
+-- @param redis_params a table of redis server parameters
+-- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session)
+-- @return {result,connection,address} boolean result, connection object, redis server address
+--]]
+
+exports.connect = function(redis_params, attrs)
+ local lua_util = require "lua_util"
+
+ if not attrs or not redis_params then
+ logger.errx('invalid arguments for redis connect')
+ return false, nil, nil
+ end
+
+ if not (attrs.task or (attrs.config and attrs.ev_base)) then
+ logger.errx('invalid attributes for redis connect')
+ return false, nil, nil
+ end
+
+ local opts = lua_util.shallowcopy(attrs)
+
+ local log_obj = opts.task or opts.config
+
+ local addr
+
+ if opts.callback then
+ -- Wrap callback
+ local callback = opts.callback
+ local function rspamd_redis_make_request_cb(err, data)
+ if err then
+ addr:fail()
+ else
+ addr:ok()
+ end
+ callback(err, data, addr)
+ end
+ opts.callback = rspamd_redis_make_request_cb
+ end
+
+ local rspamd_redis = require "rspamd_redis"
+ local is_write = opts.is_write
+
+ if opts.key then
+ if is_write then
+ addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key)
+ else
+ addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key)
+ end
+ else
+ if is_write then
+ addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key)
+ else
+ addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key)
+ end
+ end
+
+ if not addr then
+ logger.errx(log_obj, 'cannot select server to make redis connect')
+ end
+
+ opts.host = addr:get_addr()
+ opts.timeout = redis_params.timeout
+
+ if redis_params.username then
+ opts.username = redis_params.username
+ end
+
+ if redis_params.password then
+ opts.password = redis_params.password
+ end
+
+ if redis_params.db then
+ opts.dbname = redis_params.db
+ end
+
+ if opts.callback then
+ local ret, conn = rspamd_redis.connect(opts)
+ if not ret then
+ logger.errx(log_obj, 'cannot execute redis connect')
+ addr:fail()
+ end
+
+ return ret, conn, addr
+ else
+ -- Coroutines version
+ local ret, conn = rspamd_redis.connect_sync(opts)
+ if not ret then
+ logger.errx(log_obj, 'cannot execute redis connect')
+ addr:fail()
+ else
+ return true, conn, addr
+ end
+
+ return false, nil, addr
+ end
+end
+
+local redis_prefixes = {}
+
+--[[[
+-- @function lua_redis.register_prefix(prefix, module, description[, optional])
+-- Register new redis prefix for documentation purposes
+-- @param {string} prefix string prefix
+-- @param {string} module module name
+-- @param {string} description prefix description
+-- @param {table} optional optional kv pairs (e.g. pattern)
+--]]
+local function register_prefix(prefix, module, description, optional)
+ local pr = {
+ module = module,
+ description = description
+ }
+
+ if optional and type(optional) == 'table' then
+ for k, v in pairs(optional) do
+ pr[k] = v
+ end
+ end
+
+ redis_prefixes[prefix] = pr
+end
+
+exports.register_prefix = register_prefix
+
+--[[[
+-- @function lua_redis.prefixes([mname])
+-- Returns prefixes for specific module (or all prefixes). Returns a table prefix -> table
+--]]
+exports.prefixes = function(mname)
+ if not mname then
+ return redis_prefixes
+ else
+ local fun = require "fun"
+
+ return fun.totable(fun.filter(function(_, data)
+ return data.module == mname
+ end,
+ redis_prefixes))
+ end
+end
+
+return exports
diff --git a/lualib/lua_scanners/avast.lua b/lualib/lua_scanners/avast.lua
new file mode 100644
index 0000000..7e77897
--- /dev/null
+++ b/lualib/lua_scanners/avast.lua
@@ -0,0 +1,304 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module avast
+-- This module contains avast av access functions
+--]]
+
+local lua_util = require "lua_util"
+local rspamd_util = require "rspamd_util"
+local tcp = require "rspamd_tcp"
+local upstream_list = require "rspamd_upstream_list"
+local rspamd_regexp = require "rspamd_regexp"
+local rspamd_logger = require "rspamd_logger"
+local common = require "lua_scanners/common"
+
+local N = "avast"
+
+local default_message = '${SCANNER}: virus found: "${VIRUS}"'
+
+local function avast_config(opts)
+ local avast_conf = {
+ name = N,
+ scan_mime_parts = true,
+ scan_text_mime = false,
+ scan_image_mime = false,
+ timeout = 4.0, -- FIXME: this will break task_timeout!
+ log_clean = false,
+ detection_category = "virus",
+ retransmits = 1,
+ servers = nil, -- e.g. /var/run/avast/scan.sock
+ cache_expire = 3600, -- expire redis in one hour
+ message = default_message,
+ tmpdir = '/tmp',
+ }
+
+ avast_conf = lua_util.override_defaults(avast_conf, opts)
+
+ if not avast_conf.prefix then
+ avast_conf.prefix = 'rs_' .. avast_conf.name .. '_'
+ end
+
+ if not avast_conf.log_prefix then
+ if avast_conf.name:lower() == avast_conf.type:lower() then
+ avast_conf.log_prefix = avast_conf.name
+ else
+ avast_conf.log_prefix = avast_conf.name .. ' (' .. avast_conf.type .. ')'
+ end
+ end
+
+ if not avast_conf['servers'] then
+ rspamd_logger.errx(rspamd_config, 'no servers/unix socket defined')
+
+ return nil
+ end
+
+ avast_conf['upstreams'] = upstream_list.create(rspamd_config,
+ avast_conf['servers'],
+ 0)
+
+ if avast_conf['upstreams'] then
+ lua_util.add_debug_alias('antivirus', avast_conf.name)
+ return avast_conf
+ end
+
+ rspamd_logger.errx(rspamd_config, 'cannot parse servers %s',
+ avast_conf['servers'])
+ return nil
+end
+
+local function avast_check(task, content, digest, rule, maybe_part)
+ local function avast_check_uncached ()
+ local upstream = rule.upstreams:get_upstream_round_robin()
+ local addr = upstream:get_addr()
+ local retransmits = rule.retransmits
+ local CRLF = '\r\n'
+
+ -- Common tcp options
+ local tcp_opts = {
+ stop_pattern = CRLF,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule.timeout,
+ task = task
+ }
+
+ -- Regexps to process reply from avast
+ local clean_re = rspamd_regexp.create_cached(
+ [=[(?!\\)\t\[\+\]]=]
+ )
+ local virus_re = rspamd_regexp.create_cached(
+ [[(?!\\)\t\[L\]\d\.\d\t\d\s(.*)]]
+ )
+ local error_re = rspamd_regexp.create_cached(
+ [[(?!\\)\t\[E\]\d+\.0\tError\s\d+\s(.*)]]
+ )
+
+ -- Used to make a dialog
+ local tcp_conn
+
+ -- Save content in file as avast can work with files only
+ local fname = string.format('%s/%s.avtmp',
+ rule.tmpdir, rspamd_util.random_hex(32))
+ local message_fd = rspamd_util.create_file(fname)
+
+ if not message_fd then
+ rspamd_logger.errx('cannot store file for avast scan: %s', fname)
+ return
+ end
+
+ if type(content) == 'string' then
+ -- Create rspamd_text
+ local rspamd_text = require "rspamd_text"
+ content = rspamd_text.fromstring(content)
+ end
+ content:save_in_file(message_fd)
+
+ -- Ensure file cleanup on task processed
+ task:get_mempool():add_destructor(function()
+ os.remove(fname)
+ rspamd_util.close_file(message_fd)
+ end)
+
+ -- Dialog stages closures
+ local avast_helo_cb
+ local avast_scan_cb
+ local avast_scan_done_cb
+
+ -- Utility closures
+ local function maybe_retransmit()
+ if retransmits > 0 then
+ retransmits = retransmits - 1
+ else
+ rspamd_logger.errx(task,
+ '%s [%s]: failed to scan, maximum retransmits exceed',
+ rule['symbol'], rule['type'])
+ common.yield_result(task, rule, 'failed to scan and retransmits exceed',
+ 0.0, 'fail', maybe_part)
+
+ return
+ end
+
+ upstream = rule.upstreams:get_upstream_round_robin()
+ addr = upstream:get_addr()
+ tcp_opts.upstream = upstream
+ tcp_opts.callback = avast_helo_cb
+
+ local is_succ, err = tcp.request(tcp_opts)
+
+ if not is_succ then
+ rspamd_logger.infox(task, 'cannot create connection to avast server: %s (%s)',
+ addr:to_string(true), err)
+ else
+ lua_util.debugm(rule.log_prefix, task, 'established connection to %s; retransmits=%s',
+ addr:to_string(true), retransmits)
+ end
+ end
+
+ local function no_connection_error(err)
+ if err then
+ if tcp_conn then
+ tcp_conn:close()
+ tcp_conn = nil
+
+ rspamd_logger.infox(task, 'failed to request to avast (%s): %s',
+ addr:to_string(true), err)
+ maybe_retransmit()
+ end
+
+ return false
+ end
+
+ return true
+ end
+
+
+ -- Define callbacks
+ avast_helo_cb = function(merr, mdata, conn)
+ -- Called when we have established a connection but not read anything
+ tcp_conn = conn
+
+ if no_connection_error(merr) then
+ -- Check mdata to ensure that it starts with 220
+ if #mdata > 3 and tostring(mdata:span(1, 3)) == '220' then
+ tcp_conn:add_write(avast_scan_cb, string.format(
+ 'SCAN %s%s', fname, CRLF))
+ else
+ rspamd_logger.errx(task, 'Unhandled response: %s', mdata)
+ end
+ end
+ end
+
+ avast_scan_cb = function(merr)
+ -- Called when we have send request to avast and are waiting for reply
+ if no_connection_error(merr) then
+ tcp_conn:add_read(avast_scan_done_cb, CRLF)
+ end
+ end
+
+ avast_scan_done_cb = function(merr, mdata)
+ if no_connection_error(merr) then
+ lua_util.debugm(rule.log_prefix, task, 'got reply from avast: %s',
+ mdata)
+ if #mdata > 4 then
+ local beg = tostring(mdata:span(1, 3))
+
+ if beg == '210' then
+ -- Ignore 210, fire another read
+ if tcp_conn then
+ tcp_conn:add_read(avast_scan_done_cb, CRLF)
+ end
+ elseif beg == '200' then
+ -- Final line
+ if tcp_conn then
+ tcp_conn:close()
+ tcp_conn = nil
+ end
+ else
+ -- Check line using regular expressions
+ local cached
+ local ret = clean_re:search(mdata, false, true)
+
+ if ret then
+ cached = 'OK'
+ if rule.log_clean then
+ rspamd_logger.infox(task,
+ '%s [%s]: message or mime_part is clean',
+ rule.symbol, rule.type)
+ end
+ end
+
+ if not cached then
+ ret = virus_re:search(mdata, false, true)
+
+ if ret then
+ local vname = ret[1][2]
+
+ if vname then
+ vname = vname:gsub('\\ ', ' '):gsub('\\\\', '\\')
+ common.yield_result(task, rule, vname, 1.0, nil, maybe_part)
+ cached = vname
+ end
+ end
+ end
+
+ if not cached then
+ ret = error_re:search(mdata, false, true)
+
+ if ret then
+ rspamd_logger.errx(task, '%s: error: %s', rule.log_prefix, ret[1][2])
+ common.yield_result(task, rule, 'error:' .. ret[1][2],
+ 0.0, 'fail', maybe_part)
+ end
+ end
+
+ if cached then
+ common.save_cache(task, digest, rule, cached, 1.0, maybe_part)
+ else
+ -- Unexpected reply
+ rspamd_logger.errx(task, '%s: unexpected reply: %s', rule.log_prefix, mdata)
+ end
+ -- Read more
+ if tcp_conn then
+ tcp_conn:add_read(avast_scan_done_cb, CRLF)
+ end
+ end
+ end
+ end
+ end
+
+ -- Send the real request
+ maybe_retransmit()
+ end
+
+ if common.condition_check_and_continue(task, content, rule, digest,
+ avast_check_uncached, maybe_part) then
+ return
+ else
+ avast_check_uncached()
+ end
+
+end
+
+return {
+ type = 'antivirus',
+ description = 'Avast antivirus',
+ configure = avast_config,
+ check = avast_check,
+ name = N
+}
diff --git a/lualib/lua_scanners/clamav.lua b/lualib/lua_scanners/clamav.lua
new file mode 100644
index 0000000..fc99ab0
--- /dev/null
+++ b/lualib/lua_scanners/clamav.lua
@@ -0,0 +1,193 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module clamav
+-- This module contains clamav access functions
+--]]
+
+local lua_util = require "lua_util"
+local tcp = require "rspamd_tcp"
+local upstream_list = require "rspamd_upstream_list"
+local rspamd_util = require "rspamd_util"
+local rspamd_logger = require "rspamd_logger"
+local common = require "lua_scanners/common"
+
+local N = "clamav"
+
+local default_message = '${SCANNER}: virus found: "${VIRUS}"'
+
+local function clamav_config(opts)
+ local clamav_conf = {
+ name = N,
+ scan_mime_parts = true,
+ scan_text_mime = false,
+ scan_image_mime = false,
+ default_port = 3310,
+ log_clean = false,
+ timeout = 5.0, -- FIXME: this will break task_timeout!
+ detection_category = "virus",
+ retransmits = 2,
+ cache_expire = 3600, -- expire redis in one hour
+ message = default_message,
+ }
+
+ clamav_conf = lua_util.override_defaults(clamav_conf, opts)
+
+ if not clamav_conf.prefix then
+ clamav_conf.prefix = 'rs_' .. clamav_conf.name .. '_'
+ end
+
+ if not clamav_conf.log_prefix then
+ if clamav_conf.name:lower() == clamav_conf.type:lower() then
+ clamav_conf.log_prefix = clamav_conf.name
+ else
+ clamav_conf.log_prefix = clamav_conf.name .. ' (' .. clamav_conf.type .. ')'
+ end
+ end
+
+ if not clamav_conf['servers'] then
+ rspamd_logger.errx(rspamd_config, 'no servers defined')
+
+ return nil
+ end
+
+ clamav_conf['upstreams'] = upstream_list.create(rspamd_config,
+ clamav_conf['servers'],
+ clamav_conf.default_port)
+
+ if clamav_conf['upstreams'] then
+ lua_util.add_debug_alias('antivirus', clamav_conf.name)
+ return clamav_conf
+ end
+
+ rspamd_logger.errx(rspamd_config, 'cannot parse servers %s',
+ clamav_conf['servers'])
+ return nil
+end
+
+local function clamav_check(task, content, digest, rule, maybe_part)
+ local function clamav_check_uncached ()
+ local upstream = rule.upstreams:get_upstream_round_robin()
+ local addr = upstream:get_addr()
+ local retransmits = rule.retransmits
+ local header = rspamd_util.pack("c9 c1 >I4", "zINSTREAM", "\0",
+ #content)
+ local footer = rspamd_util.pack(">I4", 0)
+
+ local function clamav_callback(err, data)
+ if err then
+
+ -- retry with another upstream until retransmits exceeds
+ if retransmits > 0 then
+
+ retransmits = retransmits - 1
+
+ -- Select a different upstream!
+ upstream = rule.upstreams:get_upstream_round_robin()
+ addr = upstream:get_addr()
+
+ lua_util.debugm(rule.name, task, '%s: error: %s; retry IP: %s; retries left: %s',
+ rule.log_prefix, err, addr, retransmits)
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule['timeout'],
+ callback = clamav_callback,
+ data = { header, content, footer },
+ stop_pattern = '\0'
+ })
+ else
+ rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits exceed', rule.log_prefix)
+ common.yield_result(task, rule,
+ 'failed to scan and retransmits exceed', 0.0, 'fail',
+ maybe_part)
+ end
+
+ else
+ data = tostring(data)
+ local cached
+ lua_util.debugm(rule.name, task, '%s: got reply: %s',
+ rule.log_prefix, data)
+ if data == 'stream: OK' then
+ cached = 'OK'
+ if rule['log_clean'] then
+ rspamd_logger.infox(task, '%s: message or mime_part is clean',
+ rule.log_prefix)
+ else
+ lua_util.debugm(rule.name, task, '%s: message or mime_part is clean', rule.log_prefix)
+ end
+ else
+ local vname = string.match(data, 'stream: (.+) FOUND')
+ if string.find(vname, '^Heuristics%.Encrypted') then
+ rspamd_logger.errx(task, '%s: File is encrypted', rule.log_prefix)
+ common.yield_result(task, rule, 'File is encrypted: ' .. vname,
+ 0.0, 'encrypted', maybe_part)
+ cached = 'ENCRYPTED'
+ elseif string.find(vname, '^Heuristics%.OLE2%.ContainsMacros') then
+ rspamd_logger.errx(task, '%s: ClamAV Found an OLE2 Office Macro', rule.log_prefix)
+ common.yield_result(task, rule, vname, 0.0, 'macro', maybe_part)
+ cached = 'MACRO'
+ elseif string.find(vname, '^Heuristics%.Limits%.Exceeded') then
+ rspamd_logger.errx(task, '%s: ClamAV Limits Exceeded', rule.log_prefix)
+ common.yield_result(task, rule, 'Limits Exceeded: ' .. vname, 0.0,
+ 'fail', maybe_part)
+ elseif vname then
+ common.yield_result(task, rule, vname, 1.0, nil, maybe_part)
+ cached = vname
+ else
+ rspamd_logger.errx(task, '%s: unhandled response: %s', rule.log_prefix, data)
+ common.yield_result(task, rule, 'unhandled response:' .. vname, 0.0,
+ 'fail', maybe_part)
+ end
+ end
+ if cached then
+ common.save_cache(task, digest, rule, cached, 1.0, maybe_part)
+ end
+ end
+ end
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ timeout = rule['timeout'],
+ callback = clamav_callback,
+ upstream = upstream,
+ data = { header, content, footer },
+ stop_pattern = '\0'
+ })
+ end
+
+ if common.condition_check_and_continue(task, content, rule, digest,
+ clamav_check_uncached, maybe_part) then
+ return
+ else
+ clamav_check_uncached()
+ end
+
+end
+
+return {
+ type = 'antivirus',
+ description = 'clamav antivirus',
+ configure = clamav_config,
+ check = clamav_check,
+ name = N
+}
diff --git a/lualib/lua_scanners/cloudmark.lua b/lualib/lua_scanners/cloudmark.lua
new file mode 100644
index 0000000..b07f238
--- /dev/null
+++ b/lualib/lua_scanners/cloudmark.lua
@@ -0,0 +1,372 @@
+--[[
+Copyright (c) 2021, Alexander Moisseev <moiseev@mezonplus.ru>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module cloudmark
+-- This module contains Cloudmark v2 interface
+--]]
+
+local lua_util = require "lua_util"
+local http = require "rspamd_http"
+local upstream_list = require "rspamd_upstream_list"
+local rspamd_logger = require "rspamd_logger"
+local ucl = require "ucl"
+local rspamd_util = require "rspamd_util"
+local common = require "lua_scanners/common"
+local fun = require "fun"
+local lua_mime = require "lua_mime"
+
+local N = 'cloudmark'
+-- Boundary for multipart transfers, generated on module init
+local static_boundary = rspamd_util.random_hex(32)
+
+local function cloudmark_url(rule, addr, maybe_url)
+ local url
+ local port = addr:get_port()
+
+ maybe_url = maybe_url or rule.url
+ if port == 0 then
+ port = rule.default_port
+ end
+ if rule.use_https then
+ url = string.format('https://%s:%d%s', tostring(addr),
+ port, maybe_url)
+ else
+ url = string.format('http://%s:%d%s', tostring(addr),
+ port, maybe_url)
+ end
+
+ return url
+end
+
+-- Detect cloudmark max size
+local function cloudmark_preload(rule, cfg, ev_base, _)
+ local upstream = rule.upstreams:get_upstream_round_robin()
+ local addr = upstream:get_addr()
+ local function max_message_size_cb(http_err, code, body, _)
+ if http_err then
+ rspamd_logger.errx(ev_base, 'HTTP error when getting max message size: %s',
+ http_err)
+ return
+ end
+ if code ~= 200 then
+ rspamd_logger.errx(ev_base, 'bad HTTP code when getting max message size: %s', code)
+ end
+ local parser = ucl.parser()
+ local ret, err = parser:parse_string(body)
+ if not ret then
+ rspamd_logger.errx(ev_base, 'could not parse response body [%s]: %s', body, err)
+ return
+ end
+ local obj = parser:get_object()
+ local ms = obj.maxMessageSize
+ if not ms then
+ rspamd_logger.errx(ev_base, 'missing maxMessageSize in the response body (JSON): %s', obj)
+ return
+ end
+
+ rule.max_size = ms
+ lua_util.debugm(N, cfg, 'set maximum message size set to %s bytes', ms)
+ end
+ http.request({
+ ev_base = ev_base,
+ config = cfg,
+ url = cloudmark_url(rule, addr, '/score/v2/max-message-size'),
+ callback = max_message_size_cb,
+ })
+end
+
+local function cloudmark_config(opts)
+
+ local cloudmark_conf = {
+ name = N,
+ default_port = 2713,
+ url = '/score/v2/message',
+ use_https = false,
+ timeout = 5.0,
+ log_clean = false,
+ retransmits = 1,
+ score_threshold = 90, -- minimum score to considerate reply
+ message = '${SCANNER}: spam message found: "${VIRUS}"',
+ max_message = 0,
+ detection_category = "hash",
+ default_score = 1,
+ action = false,
+ log_spamcause = true,
+ symbol_fail = 'CLOUDMARK_FAIL',
+ symbol = 'CLOUDMARK_CHECK',
+ symbol_spam = 'CLOUDMARK_SPAM',
+ add_headers = false, -- allow addition of the headers from Cloudmark
+ }
+
+ cloudmark_conf = lua_util.override_defaults(cloudmark_conf, opts)
+
+ if not cloudmark_conf.prefix then
+ cloudmark_conf.prefix = 'rs_' .. cloudmark_conf.name .. '_'
+ end
+
+ if not cloudmark_conf.log_prefix then
+ if cloudmark_conf.name:lower() == cloudmark_conf.type:lower() then
+ cloudmark_conf.log_prefix = cloudmark_conf.name
+ else
+ cloudmark_conf.log_prefix = cloudmark_conf.name .. ' (' .. cloudmark_conf.type .. ')'
+ end
+ end
+
+ if not cloudmark_conf.servers and cloudmark_conf.socket then
+ cloudmark_conf.servers = cloudmark_conf.socket
+ end
+
+ if not cloudmark_conf.servers then
+ rspamd_logger.errx(rspamd_config, 'no servers defined')
+
+ return nil
+ end
+
+ cloudmark_conf.upstreams = upstream_list.create(rspamd_config,
+ cloudmark_conf.servers,
+ cloudmark_conf.default_port)
+
+ if cloudmark_conf.upstreams then
+
+ cloudmark_conf.symbols = { { symbol = cloudmark_conf.symbol_spam, score = 5.0 } }
+ cloudmark_conf.preloads = { cloudmark_preload }
+ lua_util.add_debug_alias('external_services', cloudmark_conf.name)
+ return cloudmark_conf
+ end
+
+ rspamd_logger.errx(rspamd_config, 'cannot parse servers %s',
+ cloudmark_conf['servers'])
+ return nil
+end
+
+-- Converts a key-value map to the table representing multipart body, with the following values:
+-- `data`: data of the part
+-- `filename`: optional filename
+-- `content-type`: content type of the element (optional)
+-- `content-transfer-encoding`: optional CTE header
+local function table_to_multipart_body(tbl, boundary)
+ local seen_data = false
+ local out = {}
+
+ for k, v in pairs(tbl) do
+ if v.data then
+ seen_data = true
+ table.insert(out, string.format('--%s\r\n', boundary))
+ if v.filename then
+ table.insert(out,
+ string.format('Content-Disposition: form-data; name="%s"; filename="%s"\r\n',
+ k, v.filename))
+ else
+ table.insert(out,
+ string.format('Content-Disposition: form-data; name="%s"\r\n', k))
+ end
+ if v['content-type'] then
+ table.insert(out,
+ string.format('Content-Type: %s\r\n', v['content-type']))
+ else
+ table.insert(out, 'Content-Type: text/plain\r\n')
+ end
+ if v['content-transfer-encoding'] then
+ table.insert(out,
+ string.format('Content-Transfer-Encoding: %s\r\n',
+ v['content-transfer-encoding']))
+ else
+ table.insert(out, 'Content-Transfer-Encoding: binary\r\n')
+ end
+ table.insert(out, '\r\n')
+ table.insert(out, v.data)
+ table.insert(out, '\r\n')
+ end
+ end
+
+ if seen_data then
+ table.insert(out, string.format('--%s--\r\n', boundary))
+ end
+
+ return out
+end
+
+local function parse_cloudmark_reply(task, rule, body)
+ local parser = ucl.parser()
+ local ret, err = parser:parse_string(body)
+ if not ret then
+ rspamd_logger.errx(task, '%s: bad response body (raw): %s', N, body)
+ task:insert_result(rule.symbol_fail, 1.0, 'Parser error: ' .. err)
+ return
+ end
+ local obj = parser:get_object()
+ lua_util.debugm(N, task, 'cloudmark reply is: %s', obj)
+
+ if not obj.score then
+ rspamd_logger.errx(task, '%s: bad response body (raw): %s', N, body)
+ task:insert_result(rule.symbol_fail, 1.0, 'Parser error: no score')
+ return
+ end
+
+ if obj.analysis then
+ -- Report analysis string
+ rspamd_logger.infox(task, 'cloudmark report string: %s', obj.analysis)
+ end
+
+ local score = tonumber(obj.score) or 0
+ if score >= rule.score_threshold then
+ task:insert_result(rule.symbol_spam, 1.0, tostring(score))
+ end
+
+ if rule.add_headers and type(obj.appendHeaders) == 'table' then
+ local headers_add = fun.tomap(fun.map(function(h)
+ return h.headerField, {
+ order = 1, value = h.body
+ }
+ end, obj.appendHeaders))
+ lua_mime.modify_headers(task, {
+ add = headers_add
+ })
+ end
+
+end
+
+local function cloudmark_check(task, content, digest, rule, maybe_part)
+ local function cloudmark_check_uncached()
+ local upstream = rule.upstreams:get_upstream_round_robin()
+ local addr = upstream:get_addr()
+ local retransmits = rule.retransmits
+
+ local url = cloudmark_url(rule, addr)
+ local message_data = task:get_content()
+ if rule.max_message and rule.max_message > 0 and #message_data > rule.max_message then
+ task:insert_result(rule['symbol_fail'], 0.0, 'Message too large: ' .. #message_data)
+ return
+ end
+ local request = {
+ rfc822 = {
+ ['Content-Type'] = 'message/rfc822',
+ data = message_data,
+ }
+ }
+
+ local helo = task:get_helo()
+ if helo then
+ request['heloDomain'] = {
+ data = helo,
+ }
+ end
+ local mail_from = task:get_from('smtp') or {}
+ if mail_from[1] and #mail_from[1].addr > 1 then
+ request['mailFrom'] = {
+ data = mail_from[1].addr
+ }
+ end
+
+ local rcpt_to = task:get_recipients('smtp')
+ if rcpt_to then
+ request['rcptTo'] = {
+ data = table.concat(fun.totable(fun.map(function(r)
+ return r.addr
+ end, rcpt_to)), ',')
+ }
+ end
+
+ local fip = task:get_from_ip()
+ if fip and fip:is_valid() then
+ request['connIp'] = tostring(fip)
+ end
+
+ local hostname = task:get_hostname()
+ if hostname then
+ request['fromHost'] = hostname
+ end
+
+ local request_data = {
+ task = task,
+ url = url,
+ body = table_to_multipart_body(request, static_boundary),
+ headers = {
+ ['Content-Type'] = string.format('multipart/form-data; boundary="%s"', static_boundary)
+ },
+ timeout = rule.timeout,
+ }
+
+ local function cloudmark_callback(http_err, code, body, headers)
+
+ local function cloudmark_requery()
+ -- set current upstream to fail because an error occurred
+ upstream:fail()
+
+ -- retry with another upstream until retransmits exceeds
+ if retransmits > 0 then
+
+ retransmits = retransmits - 1
+
+ lua_util.debugm(rule.name, task,
+ '%s: request Error: %s - retries left: %s',
+ rule.log_prefix, http_err, retransmits)
+
+ -- Select a different upstream!
+ upstream = rule.upstreams:get_upstream_round_robin()
+ addr = upstream:get_addr()
+ url = cloudmark_url(rule, addr)
+
+ lua_util.debugm(rule.name, task, '%s: retry IP: %s:%s',
+ rule.log_prefix, addr, addr:get_port())
+ request_data.url = url
+
+ http.request(request_data)
+ else
+ rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits ' ..
+ 'exceed', rule.log_prefix)
+ task:insert_result(rule['symbol_fail'], 0.0, 'failed to scan and ' ..
+ 'retransmits exceed')
+ upstream:fail()
+ end
+ end
+
+ if http_err then
+ cloudmark_requery()
+ else
+ -- Parse the response
+ if upstream then
+ upstream:ok()
+ end
+ if code ~= 200 then
+ rspamd_logger.errx(task, 'invalid HTTP code: %s, body: %s, headers: %s', code, body, headers)
+ task:insert_result(rule.symbol_fail, 1.0, 'Bad HTTP code: ' .. code)
+ return
+ end
+ parse_cloudmark_reply(task, rule, body)
+ end
+ end
+
+ request_data.callback = cloudmark_callback
+ http.request(request_data)
+ end
+
+ if common.condition_check_and_continue(task, content, rule, digest,
+ cloudmark_check_uncached, maybe_part) then
+ return
+ else
+ cloudmark_check_uncached()
+ end
+end
+
+return {
+ type = { 'cloudmark', 'scanner' },
+ description = 'Cloudmark cartridge interface',
+ configure = cloudmark_config,
+ check = cloudmark_check,
+ name = N,
+}
diff --git a/lualib/lua_scanners/common.lua b/lualib/lua_scanners/common.lua
new file mode 100644
index 0000000..11f5e1f
--- /dev/null
+++ b/lualib/lua_scanners/common.lua
@@ -0,0 +1,539 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+Copyright (c) 2019, Carsten Rosenberg <c.rosenberg@heinlein-support.de>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_scanners_common
+-- This module contains common external scanners functions
+--]]
+
+local rspamd_logger = require "rspamd_logger"
+local rspamd_regexp = require "rspamd_regexp"
+local lua_util = require "lua_util"
+local lua_redis = require "lua_redis"
+local lua_magic_types = require "lua_magic/types"
+local fun = require "fun"
+
+local exports = {}
+
+local function log_clean(task, rule, msg)
+
+ msg = msg or 'message or mime_part is clean'
+
+ if rule.log_clean then
+ rspamd_logger.infox(task, '%s: %s', rule.log_prefix, msg)
+ else
+ lua_util.debugm(rule.name, task, '%s: %s', rule.log_prefix, msg)
+ end
+
+end
+
+local function match_patterns(default_sym, found, patterns, dyn_weight)
+ if type(patterns) ~= 'table' then
+ return default_sym, dyn_weight
+ end
+ if not patterns[1] then
+ for sym, pat in pairs(patterns) do
+ if pat:match(found) then
+ return sym, '1'
+ end
+ end
+ return default_sym, dyn_weight
+ else
+ for _, p in ipairs(patterns) do
+ for sym, pat in pairs(p) do
+ if pat:match(found) then
+ return sym, '1'
+ end
+ end
+ end
+ return default_sym, dyn_weight
+ end
+end
+
+local function yield_result(task, rule, vname, dyn_weight, is_fail, maybe_part)
+ local all_whitelisted = true
+ local patterns
+ local symbol
+ local threat_table
+ local threat_info
+ local flags
+
+ if type(vname) == 'string' then
+ threat_table = { vname }
+ elseif type(vname) == 'table' then
+ threat_table = vname
+ end
+
+
+ -- This should be more generic
+ if not is_fail then
+ patterns = rule.patterns
+ symbol = rule.symbol
+ threat_info = rule.detection_category .. 'found'
+ if not dyn_weight then
+ dyn_weight = 1.0
+ end
+ elseif is_fail == 'fail' then
+ patterns = rule.patterns_fail
+ symbol = rule.symbol_fail
+ threat_info = "FAILED with error"
+ dyn_weight = 0.0
+ elseif is_fail == 'encrypted' then
+ patterns = rule.patterns
+ symbol = rule.symbol_encrypted
+ threat_info = "Scan has returned that input was encrypted"
+ dyn_weight = 1.0
+ elseif is_fail == 'macro' then
+ patterns = rule.patterns
+ symbol = rule.symbol_macro
+ threat_info = "Scan has returned that input contains macros"
+ dyn_weight = 1.0
+ end
+
+ for _, tm in ipairs(threat_table) do
+ local symname, symscore = match_patterns(symbol, tm, patterns, dyn_weight)
+ if rule.whitelist and rule.whitelist:get_key(tm) then
+ rspamd_logger.infox(task, '%s: "%s" is in whitelist', rule.log_prefix, tm)
+ else
+ all_whitelisted = false
+ rspamd_logger.infox(task, '%s: result - %s: "%s - score: %s"',
+ rule.log_prefix, threat_info, tm, symscore)
+
+ if maybe_part and rule.show_attachments and maybe_part:get_filename() then
+ local fname = maybe_part:get_filename()
+ task:insert_result(symname, symscore, string.format("%s|%s",
+ tm, fname))
+ else
+ task:insert_result(symname, symscore, tm)
+ end
+
+ end
+ end
+
+ if rule.action and is_fail ~= 'fail' and not all_whitelisted then
+ threat_table = table.concat(threat_table, '; ')
+ if rule.action ~= 'reject' then
+ flags = 'least'
+ end
+ task:set_pre_result(rule.action,
+ lua_util.template(rule.message or 'Rejected', {
+ SCANNER = rule.name,
+ VIRUS = threat_table,
+ }), rule.name, nil, nil, flags)
+ end
+end
+
+local function message_not_too_large(task, content, rule)
+ local max_size = tonumber(rule.max_size)
+ if not max_size then
+ return true
+ end
+ if #content > max_size then
+ rspamd_logger.infox(task, "skip %s check as it is too large: %s (%s is allowed)",
+ rule.log_prefix, #content, max_size)
+ return false
+ end
+ return true
+end
+
+local function message_not_too_small(task, content, rule)
+ local min_size = tonumber(rule.min_size)
+ if not min_size then
+ return true
+ end
+ if #content < min_size then
+ rspamd_logger.infox(task, "skip %s check as it is too small: %s (%s is allowed)",
+ rule.log_prefix, #content, min_size)
+ return false
+ end
+ return true
+end
+
+local function message_min_words(task, rule)
+ if rule.text_part_min_words and tonumber(rule.text_part_min_words) > 0 then
+ local text_part_above_limit = false
+ local text_parts = task:get_text_parts()
+
+ local filter_func = function(p)
+ return p:get_words_count() >= tonumber(rule.text_part_min_words)
+ end
+
+ fun.each(function(p)
+ text_part_above_limit = true
+ end, fun.filter(filter_func, text_parts))
+
+ if not text_part_above_limit then
+ rspamd_logger.infox(task, '%s: #words in all text parts is below text_part_min_words limit: %s',
+ rule.log_prefix, rule.text_part_min_words)
+ end
+
+ return text_part_above_limit
+ else
+ return true
+ end
+end
+
+local function dynamic_scan(task, rule)
+ if rule.dynamic_scan then
+ if rule.action ~= 'reject' then
+ local metric_result = task:get_metric_score()
+ local metric_action = task:get_metric_action()
+ local has_pre_result = task:has_pre_result()
+ -- ToDo: needed?
+ -- Sometimes leads to FPs
+ --if rule.symbol_type == 'postfilter' and metric_action == 'reject' then
+ -- rspamd_logger.infox(task, '%s: aborting: %s', rule.log_prefix, "result is already reject")
+ -- return false
+ --elseif metric_result[1] > metric_result[2]*2 then
+ if metric_result[1] > metric_result[2] * 2 then
+ rspamd_logger.infox(task, '%s: aborting: %s', rule.log_prefix, 'score > 2 * reject_level: ' .. metric_result[1])
+ return false
+ elseif has_pre_result and metric_action == 'reject' then
+ rspamd_logger.infox(task, '%s: aborting: %s', rule.log_prefix, 'pre_result reject is set')
+ return false
+ else
+ return true, 'undecided'
+ end
+ else
+ return true, 'dynamic_scan is not possible with config `action=reject;`'
+ end
+ else
+ return true
+ end
+end
+
+local function need_check(task, content, rule, digest, fn, maybe_part)
+
+ local uncached = true
+ local key = digest
+
+ local function redis_av_cb(err, data)
+ if data and type(data) == 'string' then
+ -- Cached
+ data = lua_util.str_split(data, '\t')
+ local threat_string = lua_util.str_split(data[1], '\v')
+ local score = data[2] or rule.default_score
+
+ if threat_string[1] ~= 'OK' then
+ if threat_string[1] == 'MACRO' then
+ yield_result(task, rule, 'File contains macros',
+ 0.0, 'macro', maybe_part)
+ elseif threat_string[1] == 'ENCRYPTED' then
+ yield_result(task, rule, 'File is encrypted',
+ 0.0, 'encrypted', maybe_part)
+ else
+ lua_util.debugm(rule.name, task, '%s: got cached threat result for %s: %s - score: %s',
+ rule.log_prefix, key, threat_string[1], score)
+ yield_result(task, rule, threat_string, score, false, maybe_part)
+ end
+
+ else
+ lua_util.debugm(rule.name, task, '%s: got cached negative result for %s: %s',
+ rule.log_prefix, key, threat_string[1])
+ end
+ uncached = false
+ else
+ if err then
+ rspamd_logger.errx(task, 'got error checking cache: %s', err)
+ end
+ end
+
+ local f_message_not_too_large = message_not_too_large(task, content, rule)
+ local f_message_not_too_small = message_not_too_small(task, content, rule)
+ local f_message_min_words = message_min_words(task, rule)
+ local f_dynamic_scan = dynamic_scan(task, rule)
+
+ if uncached and
+ f_message_not_too_large and
+ f_message_not_too_small and
+ f_message_min_words and
+ f_dynamic_scan then
+
+ fn()
+
+ end
+
+ end
+
+ if rule.redis_params and not rule.no_cache then
+
+ key = rule.prefix .. key
+
+ if lua_redis.redis_make_request(task,
+ rule.redis_params, -- connect params
+ key, -- hash key
+ false, -- is write
+ redis_av_cb, --callback
+ 'GET', -- command
+ { key } -- arguments)
+ ) then
+ return true
+ end
+ end
+
+ return false
+
+end
+
+local function save_cache(task, digest, rule, to_save, dyn_weight, maybe_part)
+ local key = digest
+ if not dyn_weight then
+ dyn_weight = 1.0
+ end
+
+ local function redis_set_cb(err)
+ -- Do nothing
+ if err then
+ rspamd_logger.errx(task, 'failed to save %s cache for %s -> "%s": %s',
+ rule.detection_category, to_save, key, err)
+ else
+ lua_util.debugm(rule.name, task, '%s: saved cached result for %s: %s - score %s - ttl %s',
+ rule.log_prefix, key, to_save, dyn_weight, rule.cache_expire)
+ end
+ end
+
+ if type(to_save) == 'table' then
+ to_save = table.concat(to_save, '\v')
+ end
+
+ local value_tbl = { to_save, dyn_weight }
+ if maybe_part and rule.show_attachments and maybe_part:get_filename() then
+ local fname = maybe_part:get_filename()
+ table.insert(value_tbl, fname)
+ end
+ local value = table.concat(value_tbl, '\t')
+
+ if rule.redis_params and rule.prefix then
+ key = rule.prefix .. key
+
+ lua_redis.redis_make_request(task,
+ rule.redis_params, -- connect params
+ key, -- hash key
+ true, -- is write
+ redis_set_cb, --callback
+ 'SETEX', -- command
+ { key, rule.cache_expire or 0, value }
+ )
+ end
+
+ return false
+end
+
+local function create_regex_table(patterns)
+ local regex_table = {}
+ if patterns[1] then
+ for i, p in ipairs(patterns) do
+ if type(p) == 'table' then
+ local new_set = {}
+ for k, v in pairs(p) do
+ new_set[k] = rspamd_regexp.create_cached(v)
+ end
+ regex_table[i] = new_set
+ else
+ regex_table[i] = {}
+ end
+ end
+ else
+ for k, v in pairs(patterns) do
+ regex_table[k] = rspamd_regexp.create_cached(v)
+ end
+ end
+ return regex_table
+end
+
+local function match_filter(task, rule, found, patterns, pat_type)
+ if type(patterns) ~= 'table' or not found then
+ return false
+ end
+ if not patterns[1] then
+ for _, pat in pairs(patterns) do
+ if pat_type == 'ext' and tostring(pat) == tostring(found) then
+ return true
+ elseif pat_type == 'regex' and pat:match(found) then
+ return true
+ end
+ end
+ return false
+ else
+ for _, p in ipairs(patterns) do
+ for _, pat in ipairs(p) do
+ if pat_type == 'ext' and tostring(pat) == tostring(found) then
+ return true
+ elseif pat_type == 'regex' and pat:match(found) then
+ return true
+ end
+ end
+ end
+ return false
+ end
+end
+
+-- borrowed from mime_types.lua
+-- ext is the last extension, LOWERCASED
+-- ext2 is the one before last extension LOWERCASED
+local function gen_extension(fname)
+ local filename_parts = lua_util.str_split(fname, '.')
+
+ local ext = {}
+ for n = 1, 2 do
+ ext[n] = #filename_parts > n and string.lower(filename_parts[#filename_parts + 1 - n]) or nil
+ end
+ return ext[1], ext[2], filename_parts
+end
+
+local function check_parts_match(task, rule)
+
+ local filter_func = function(p)
+ local mtype, msubtype = p:get_type()
+ local detected_ext = p:get_detected_ext()
+ local fname = p:get_filename()
+ local ext, ext2
+
+ if rule.scan_all_mime_parts == false then
+ -- check file extension and filename regex matching
+ --lua_util.debugm(rule.name, task, '%s: filename: |%s|%s|', rule.log_prefix, fname)
+ if fname ~= nil then
+ ext, ext2 = gen_extension(fname)
+ --lua_util.debugm(rule.name, task, '%s: extension, fname: |%s|%s|%s|', rule.log_prefix, ext, ext2, fname)
+ if match_filter(task, rule, ext, rule.mime_parts_filter_ext, 'ext')
+ or match_filter(task, rule, ext2, rule.mime_parts_filter_ext, 'ext') then
+ lua_util.debugm(rule.name, task, '%s: extension matched: |%s|%s|', rule.log_prefix, ext, ext2)
+ return true
+ elseif match_filter(task, rule, fname, rule.mime_parts_filter_regex, 'regex') then
+ lua_util.debugm(rule.name, task, '%s: filename regex matched', rule.log_prefix)
+ return true
+ end
+ end
+ -- check content type string regex matching
+ if mtype ~= nil and msubtype ~= nil then
+ local ct = string.format('%s/%s', mtype, msubtype):lower()
+ if match_filter(task, rule, ct, rule.mime_parts_filter_regex, 'regex') then
+ lua_util.debugm(rule.name, task, '%s: regex content-type: %s', rule.log_prefix, ct)
+ return true
+ end
+ end
+ -- check detected content type (libmagic) regex matching
+ if detected_ext then
+ local magic = lua_magic_types[detected_ext] or {}
+ if match_filter(task, rule, detected_ext, rule.mime_parts_filter_ext, 'ext') then
+ lua_util.debugm(rule.name, task, '%s: detected extension matched: |%s|', rule.log_prefix, detected_ext)
+ return true
+ elseif magic.ct and match_filter(task, rule, magic.ct, rule.mime_parts_filter_regex, 'regex') then
+ lua_util.debugm(rule.name, task, '%s: regex detected libmagic content-type: %s',
+ rule.log_prefix, magic.ct)
+ return true
+ end
+ end
+ -- check filenames in archives
+ if p:is_archive() then
+ local arch = p:get_archive()
+ local filelist = arch:get_files_full(1000)
+ for _, f in ipairs(filelist) do
+ ext, ext2 = gen_extension(f.name)
+ if match_filter(task, rule, ext, rule.mime_parts_filter_ext, 'ext')
+ or match_filter(task, rule, ext2, rule.mime_parts_filter_ext, 'ext') then
+ lua_util.debugm(rule.name, task, '%s: extension matched in archive: |%s|%s|', rule.log_prefix, ext, ext2)
+ --lua_util.debugm(rule.name, task, '%s: extension matched in archive: %s', rule.log_prefix, ext)
+ return true
+ elseif match_filter(task, rule, f.name, rule.mime_parts_filter_regex, 'regex') then
+ lua_util.debugm(rule.name, task, '%s: filename regex matched in archive', rule.log_prefix)
+ return true
+ end
+ end
+ end
+ end
+
+ -- check text_part has more words than text_part_min_words_check
+ if rule.scan_text_mime and rule.text_part_min_words and p:is_text() and
+ p:get_words_count() >= tonumber(rule.text_part_min_words) then
+ return true
+ end
+
+ if rule.scan_image_mime and p:is_image() then
+ return true
+ end
+
+ if rule.scan_all_mime_parts ~= false then
+ local is_part_checkable = (p:is_attachment() and (not p:is_image() or rule.scan_image_mime))
+ if detected_ext then
+ -- We know what to scan!
+ local magic = lua_magic_types[detected_ext] or {}
+
+ if magic.av_check ~= false or is_part_checkable then
+ return true
+ end
+ elseif is_part_checkable then
+ -- Just rely on attachment property
+ return true
+ end
+ end
+
+ return false
+ end
+
+ return fun.filter(filter_func, task:get_parts())
+end
+
+local function check_metric_results(task, rule)
+
+ if rule.action ~= 'reject' then
+ local metric_result = task:get_metric_score()
+ local metric_action = task:get_metric_action()
+ local has_pre_result = task:has_pre_result()
+
+ if rule.symbol_type == 'postfilter' and metric_action == 'reject' then
+ return true, 'result is already reject'
+ elseif metric_result[1] > metric_result[2] * 2 then
+ return true, 'score > 2 * reject_level: ' .. metric_result[1]
+ elseif has_pre_result and metric_action == 'reject' then
+ return true, 'pre_result reject is set'
+ else
+ return false, 'undecided'
+ end
+ else
+ return false, 'dynamic_scan is not possible with config `action=reject;`'
+ end
+end
+
+exports.log_clean = log_clean
+exports.yield_result = yield_result
+exports.match_patterns = match_patterns
+exports.condition_check_and_continue = need_check
+exports.save_cache = save_cache
+exports.create_regex_table = create_regex_table
+exports.check_parts_match = check_parts_match
+exports.check_metric_results = check_metric_results
+
+setmetatable(exports, {
+ __call = function(t, override)
+ for k, v in pairs(t) do
+ if _G[k] ~= nil then
+ local msg = 'function ' .. k .. ' already exists in global scope.'
+ if override then
+ _G[k] = v
+ print('WARNING: ' .. msg .. ' Overwritten.')
+ else
+ print('NOTICE: ' .. msg .. ' Skipped.')
+ end
+ else
+ _G[k] = v
+ end
+ end
+ end,
+})
+
+return exports
diff --git a/lualib/lua_scanners/dcc.lua b/lualib/lua_scanners/dcc.lua
new file mode 100644
index 0000000..8d5e9e1
--- /dev/null
+++ b/lualib/lua_scanners/dcc.lua
@@ -0,0 +1,313 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+Copyright (c) 2018, Carsten Rosenberg <c.rosenberg@heinlein-support.de>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module dcc
+-- This module contains dcc access functions
+--]]
+
+local lua_util = require "lua_util"
+local tcp = require "rspamd_tcp"
+local upstream_list = require "rspamd_upstream_list"
+local rspamd_logger = require "rspamd_logger"
+local common = require "lua_scanners/common"
+local fun = require "fun"
+
+local N = 'dcc'
+
+local function dcc_config(opts)
+
+ local dcc_conf = {
+ name = N,
+ default_port = 10045,
+ timeout = 5.0,
+ log_clean = false,
+ retransmits = 2,
+ cache_expire = 7200, -- expire redis in 2h
+ message = '${SCANNER}: bulk message found: "${VIRUS}"',
+ detection_category = "hash",
+ default_score = 1,
+ action = false,
+ client = '0.0.0.0',
+ symbol_fail = 'DCC_FAIL',
+ symbol = 'DCC_REJECT',
+ symbol_bulk = 'DCC_BULK',
+ body_max = 999999,
+ fuz1_max = 999999,
+ fuz2_max = 999999,
+ }
+
+ dcc_conf = lua_util.override_defaults(dcc_conf, opts)
+
+ if not dcc_conf.prefix then
+ dcc_conf.prefix = 'rs_' .. dcc_conf.name .. '_'
+ end
+
+ if not dcc_conf.log_prefix then
+ dcc_conf.log_prefix = dcc_conf.name
+ end
+
+ if not dcc_conf.servers and dcc_conf.socket then
+ dcc_conf.servers = dcc_conf.socket
+ end
+
+ if not dcc_conf.servers then
+ rspamd_logger.errx(rspamd_config, 'no servers defined')
+
+ return nil
+ end
+
+ dcc_conf.upstreams = upstream_list.create(rspamd_config,
+ dcc_conf.servers,
+ dcc_conf.default_port)
+
+ if dcc_conf.upstreams then
+ lua_util.add_debug_alias('external_services', dcc_conf.name)
+ return dcc_conf
+ end
+
+ rspamd_logger.errx(rspamd_config, 'cannot parse servers %s',
+ dcc_conf['servers'])
+ return nil
+end
+
+local function dcc_check(task, content, digest, rule)
+ local function dcc_check_uncached ()
+ local upstream = rule.upstreams:get_upstream_round_robin()
+ local addr = upstream:get_addr()
+ local retransmits = rule.retransmits
+ local client = rule.client
+
+ local client_ip = task:get_from_ip()
+ if client_ip and client_ip:is_valid() then
+ client = client_ip:to_string()
+ end
+ local client_host = task:get_hostname()
+ if client_host then
+ client = client .. "\r" .. client_host
+ end
+
+ -- HELO
+ local helo = task:get_helo() or ''
+
+ -- Envelope From
+ local ef = task:get_from()
+ local envfrom = 'test@example.com'
+ if ef and ef[1] then
+ envfrom = ef[1]['addr']
+ end
+
+ -- Envelope To
+ local envrcpt = 'test@example.com'
+ local rcpts = task:get_recipients();
+ if rcpts then
+ local dcc_recipients = table.concat(fun.totable(fun.map(function(rcpt)
+ return rcpt['addr']
+ end,
+ rcpts)), '\n')
+ if dcc_recipients then
+ envrcpt = dcc_recipients
+ end
+ end
+
+ -- Build the DCC query
+ -- https://www.dcc-servers.net/dcc/dcc-tree/dccifd.html#Protocol
+ local request_data = {
+ "header\n",
+ client .. "\n",
+ helo .. "\n",
+ envfrom .. "\n",
+ envrcpt .. "\n",
+ "\n",
+ content
+ }
+
+ local function dcc_callback(err, data, conn)
+
+ local function dcc_requery()
+ -- retry with another upstream until retransmits exceeds
+ if retransmits > 0 then
+
+ retransmits = retransmits - 1
+
+ -- Select a different upstream!
+ upstream = rule.upstreams:get_upstream_round_robin()
+ addr = upstream:get_addr()
+
+ lua_util.debugm(rule.name, task, '%s: error: %s; retry IP: %s; retries left: %s',
+ rule.log_prefix, err, addr, retransmits)
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ timeout = rule.timeout or 2.0,
+ upstream = upstream,
+ shutdown = true,
+ data = request_data,
+ callback = dcc_callback,
+ body_max = 999999,
+ fuz1_max = 999999,
+ fuz2_max = 999999,
+ })
+ else
+ rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits ' ..
+ 'exceed', rule.log_prefix)
+ common.yield_result(task, rule, 'failed to scan and retransmits exceed', 0.0, 'fail')
+ end
+ end
+
+ if err then
+
+ dcc_requery()
+
+ else
+ -- Parse the response
+ local _, _, result, disposition, header = tostring(data):find("(.-)\n(.-)\n(.-)$")
+ lua_util.debugm(rule.name, task, 'DCC result=%1 disposition=%2 header="%3"',
+ result, disposition, header)
+
+ if header then
+ -- Unfold header
+ header = header:gsub('\r?\n%s*', ' ')
+ local _, _, info = header:find("; (.-)$")
+ if (result == 'R') then
+ -- Reject
+ common.yield_result(task, rule, info, rule.default_score)
+ common.save_cache(task, digest, rule, info, rule.default_score)
+ elseif (result == 'T') then
+ -- Temporary failure
+ rspamd_logger.warnx(task, 'DCC returned a temporary failure result: %s', result)
+ dcc_requery()
+ elseif result == 'A' then
+
+ local opts = {}
+ local score = 0.0
+ info = info:lower()
+ local rep = info:match('rep=([^=%s]+)')
+
+ -- Adjust reputation if available
+ if rep then
+ rep = tonumber(rep)
+ end
+ if not rep then
+ rep = 1.0
+ end
+
+ local function check_threshold(what, num, lim)
+ local rnum
+ if num == 'many' then
+ rnum = lim
+ else
+ rnum = tonumber(num)
+ end
+
+ if rnum and rnum >= lim then
+ opts[#opts + 1] = string.format('%s=%s', what, num)
+ score = score + rep / 3.0
+ end
+ end
+
+ info = info:lower()
+ local body = info:match('body=([^=%s]+)')
+
+ if body then
+ check_threshold('body', body, rule.body_max)
+ end
+
+ local fuz1 = info:match('fuz1=([^=%s]+)')
+
+ if fuz1 then
+ check_threshold('fuz1', fuz1, rule.fuz1_max)
+ end
+
+ local fuz2 = info:match('fuz2=([^=%s]+)')
+
+ if fuz2 then
+ check_threshold('fuz2', fuz2, rule.fuz2_max)
+ end
+
+ if #opts > 0 and score > 0 then
+ task:insert_result(rule.symbol_bulk,
+ score,
+ opts)
+ common.save_cache(task, digest, rule, opts, score)
+ else
+ common.save_cache(task, digest, rule, 'OK')
+ if rule.log_clean then
+ rspamd_logger.infox(task, '%s: clean, returned result A - info: %s',
+ rule.log_prefix, info)
+ else
+ lua_util.debugm(rule.name, task, '%s: returned result A - info: %s',
+ rule.log_prefix, info)
+ end
+ end
+ elseif result == 'G' then
+ -- do nothing
+ common.save_cache(task, digest, rule, 'OK')
+ if rule.log_clean then
+ rspamd_logger.infox(task, '%s: clean, returned result G - info: %s', rule.log_prefix, info)
+ else
+ lua_util.debugm(rule.name, task, '%s: returned result G - info: %s', rule.log_prefix, info)
+ end
+ elseif result == 'S' then
+ -- do nothing
+ common.save_cache(task, digest, rule, 'OK')
+ if rule.log_clean then
+ rspamd_logger.infox(task, '%s: clean, returned result S - info: %s', rule.log_prefix, info)
+ else
+ lua_util.debugm(rule.name, task, '%s: returned result S - info: %s', rule.log_prefix, info)
+ end
+ else
+ -- Unknown result
+ rspamd_logger.warnx(task, '%s: result error: %1', rule.log_prefix, result);
+ common.yield_result(task, rule, 'error: ' .. result, 0.0, 'fail')
+ end
+ end
+ end
+ end
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ timeout = rule.timeout or 2.0,
+ shutdown = true,
+ upstream = upstream,
+ data = request_data,
+ callback = dcc_callback,
+ body_max = 999999,
+ fuz1_max = 999999,
+ fuz2_max = 999999,
+ })
+ end
+
+ if common.condition_check_and_continue(task, content, rule, digest, dcc_check_uncached) then
+ return
+ else
+ dcc_check_uncached()
+ end
+
+end
+
+return {
+ type = { 'dcc', 'bulk', 'hash', 'scanner' },
+ description = 'dcc bulk scanner',
+ configure = dcc_config,
+ check = dcc_check,
+ name = N
+}
diff --git a/lualib/lua_scanners/fprot.lua b/lualib/lua_scanners/fprot.lua
new file mode 100644
index 0000000..5a469c3
--- /dev/null
+++ b/lualib/lua_scanners/fprot.lua
@@ -0,0 +1,181 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module fprot
+-- This module contains fprot access functions
+--]]
+
+local lua_util = require "lua_util"
+local tcp = require "rspamd_tcp"
+local upstream_list = require "rspamd_upstream_list"
+local rspamd_logger = require "rspamd_logger"
+local common = require "lua_scanners/common"
+
+local N = "fprot"
+
+local default_message = '${SCANNER}: virus found: "${VIRUS}"'
+
+local function fprot_config(opts)
+ local fprot_conf = {
+ name = N,
+ scan_mime_parts = true,
+ scan_text_mime = false,
+ scan_image_mime = false,
+ default_port = 10200,
+ timeout = 5.0, -- FIXME: this will break task_timeout!
+ log_clean = false,
+ detection_category = "virus",
+ retransmits = 2,
+ cache_expire = 3600, -- expire redis in one hour
+ message = default_message,
+ }
+
+ fprot_conf = lua_util.override_defaults(fprot_conf, opts)
+
+ if not fprot_conf.prefix then
+ fprot_conf.prefix = 'rs_' .. fprot_conf.name .. '_'
+ end
+
+ if not fprot_conf.log_prefix then
+ if fprot_conf.name:lower() == fprot_conf.type:lower() then
+ fprot_conf.log_prefix = fprot_conf.name
+ else
+ fprot_conf.log_prefix = fprot_conf.name .. ' (' .. fprot_conf.type .. ')'
+ end
+ end
+
+ if not fprot_conf['servers'] then
+ rspamd_logger.errx(rspamd_config, 'no servers defined')
+
+ return nil
+ end
+
+ fprot_conf['upstreams'] = upstream_list.create(rspamd_config,
+ fprot_conf['servers'],
+ fprot_conf.default_port)
+
+ if fprot_conf['upstreams'] then
+ lua_util.add_debug_alias('antivirus', fprot_conf.name)
+ return fprot_conf
+ end
+
+ rspamd_logger.errx(rspamd_config, 'cannot parse servers %s',
+ fprot_conf['servers'])
+ return nil
+end
+
+local function fprot_check(task, content, digest, rule, maybe_part)
+ local function fprot_check_uncached ()
+ local upstream = rule.upstreams:get_upstream_round_robin()
+ local addr = upstream:get_addr()
+ local retransmits = rule.retransmits
+ local scan_id = task:get_queue_id()
+ if not scan_id then
+ scan_id = task:get_uid()
+ end
+ local header = string.format('SCAN STREAM %s SIZE %d\n', scan_id,
+ #content)
+ local footer = '\n'
+
+ local function fprot_callback(err, data)
+ if err then
+ -- retry with another upstream until retransmits exceeds
+ if retransmits > 0 then
+
+ retransmits = retransmits - 1
+
+ -- Select a different upstream!
+ upstream = rule.upstreams:get_upstream_round_robin()
+ addr = upstream:get_addr()
+
+ lua_util.debugm(rule.name, task, '%s: error: %s; retry IP: %s; retries left: %s',
+ rule.log_prefix, err, addr, retransmits)
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule['timeout'],
+ callback = fprot_callback,
+ data = { header, content, footer },
+ stop_pattern = '\n'
+ })
+ else
+ rspamd_logger.errx(task,
+ '%s [%s]: failed to scan, maximum retransmits exceed',
+ rule['symbol'], rule['type'])
+ common.yield_result(task, rule, 'failed to scan and retransmits exceed',
+ 0.0, 'fail', maybe_part)
+ end
+ else
+ upstream:ok()
+ data = tostring(data)
+ local cached
+ local clean = string.match(data, '^0 <clean>')
+ if clean then
+ cached = 'OK'
+ if rule['log_clean'] then
+ rspamd_logger.infox(task,
+ '%s [%s]: message or mime_part is clean',
+ rule['symbol'], rule['type'])
+ end
+ else
+ -- returncodes: 1: infected, 2: suspicious, 3: both, 4-255: some error occurred
+ -- see http://www.f-prot.com/support/helpfiles/unix/appendix_c.html for more detail
+ local vname = string.match(data, '^[1-3] <[%w%s]-: (.-)>')
+ if not vname then
+ rspamd_logger.errx(task, 'Unhandled response: %s', data)
+ else
+ common.yield_result(task, rule, vname, 1.0, nil, maybe_part)
+ cached = vname
+ end
+ end
+ if cached then
+ common.save_cache(task, digest, rule, cached, 1.0, maybe_part)
+ end
+ end
+ end
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule['timeout'],
+ callback = fprot_callback,
+ data = { header, content, footer },
+ stop_pattern = '\n'
+ })
+ end
+
+ if common.condition_check_and_continue(task, content, rule, digest,
+ fprot_check_uncached, maybe_part) then
+ return
+ else
+ fprot_check_uncached()
+ end
+
+end
+
+return {
+ type = 'antivirus',
+ description = 'fprot antivirus',
+ configure = fprot_config,
+ check = fprot_check,
+ name = N
+}
diff --git a/lualib/lua_scanners/icap.lua b/lualib/lua_scanners/icap.lua
new file mode 100644
index 0000000..682562d
--- /dev/null
+++ b/lualib/lua_scanners/icap.lua
@@ -0,0 +1,713 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+Copyright (c) 2019, Carsten Rosenberg <c.rosenberg@heinlein-support.de>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[
+@module icap
+This module contains icap access functions.
+Currently tested with
+ - C-ICAP Squidclamav / echo
+ - Checkpoint Sandblast
+ - F-Secure Internet Gatekeeper
+ - Kaspersky Web Traffic Security
+ - Kaspersky Scan Engine 2.0
+ - McAfee Web Gateway 9/10/11
+ - Sophos Savdi
+ - Symantec (Rspamd <3.2, >=3.2 untested)
+ - Trend Micro IWSVA 6.0
+ - Trend Micro Web Gateway
+
+@TODO
+ - Preview / Continue
+ - Reqmod URL's
+ - Content-Type / Filename
+]] --
+
+--[[
+Configuration Notes:
+
+C-ICAP Squidclamav
+ scheme = "squidclamav";
+
+Checkpoint Sandblast example:
+ scheme = "sandblast";
+
+ESET Gateway Security / Antivirus for Linux example:
+ scheme = "scan";
+
+F-Secure Internet Gatekeeper example:
+ scheme = "respmod";
+ x_client_header = true;
+ x_rcpt_header = true;
+ x_from_header = true;
+
+Kaspersky Web Traffic Security example:
+ scheme = "av/respmod";
+ x_client_header = true;
+
+Kaspersky Web Traffic Security (as configured in kavicapd.xml):
+ scheme = "resp";
+ x_client_header = true;
+
+McAfee Web Gateway 10/11 (Headers must be activated with personal extra Rules)
+ scheme = "respmod";
+ x_client_header = true;
+
+Sophos SAVDI example:
+ # scheme as configured in savdi.conf (name option in service section)
+ scheme = "respmod";
+
+Symantec example:
+ scheme = "avscan";
+
+Trend Micro IWSVA example (X-Virus-ID/X-Infection-Found headers must be activated):
+ scheme = "avscan";
+ x_client_header = true;
+
+Trend Micro Web Gateway example (X-Virus-ID/X-Infection-Found headers must be activated):
+ scheme = "interscan";
+ x_client_header = true;
+]] --
+
+
+local lua_util = require "lua_util"
+local tcp = require "rspamd_tcp"
+local upstream_list = require "rspamd_upstream_list"
+local rspamd_logger = require "rspamd_logger"
+local common = require "lua_scanners/common"
+local rspamd_util = require "rspamd_util"
+local rspamd_version = rspamd_version
+
+local N = 'icap'
+
+local function icap_config(opts)
+
+ local icap_conf = {
+ name = N,
+ scan_mime_parts = true,
+ scan_all_mime_parts = true,
+ scan_text_mime = false,
+ scan_image_mime = false,
+ scheme = "scan",
+ default_port = 1344,
+ ssl = false,
+ no_ssl_verify = false,
+ timeout = 10.0,
+ log_clean = false,
+ retransmits = 2,
+ cache_expire = 7200, -- expire redis in one hour
+ message = '${SCANNER}: threat found with icap scanner: "${VIRUS}"',
+ detection_category = "virus",
+ default_score = 1,
+ action = false,
+ dynamic_scan = false,
+ user_agent = "Rspamd",
+ x_client_header = false,
+ x_rcpt_header = false,
+ x_from_header = false,
+ req_headers_enabled = true,
+ req_fake_url = "http://127.0.0.1/mail",
+ http_headers_enabled = true,
+ use_http_result_header = true,
+ use_http_3xx_as_threat = false,
+ use_specific_content_type = false, -- Use content type from a part where possible
+ }
+
+ icap_conf = lua_util.override_defaults(icap_conf, opts)
+
+ if not icap_conf.prefix then
+ icap_conf.prefix = 'rs_' .. icap_conf.name .. '_'
+ end
+
+ if not icap_conf.log_prefix then
+ icap_conf.log_prefix = icap_conf.name .. ' (' .. icap_conf.type .. ')'
+ end
+
+ if not icap_conf.log_prefix then
+ if icap_conf.name:lower() == icap_conf.type:lower() then
+ icap_conf.log_prefix = icap_conf.name
+ else
+ icap_conf.log_prefix = icap_conf.name .. ' (' .. icap_conf.type .. ')'
+ end
+ end
+
+ if not icap_conf.servers then
+ rspamd_logger.errx(rspamd_config, 'no servers defined')
+
+ return nil
+ end
+
+ icap_conf.upstreams = upstream_list.create(rspamd_config,
+ icap_conf.servers,
+ icap_conf.default_port)
+
+ if icap_conf.upstreams then
+ lua_util.add_debug_alias('external_services', icap_conf.name)
+ return icap_conf
+ end
+
+ rspamd_logger.errx(rspamd_config, 'cannot parse servers %s',
+ icap_conf.servers)
+ return nil
+end
+
+local function icap_check(task, content, digest, rule, maybe_part)
+ local function icap_check_uncached ()
+ local upstream = rule.upstreams:get_upstream_round_robin()
+ local addr = upstream:get_addr()
+ local retransmits = rule.retransmits
+ local http_headers = {}
+ local req_headers = {}
+ local tcp_options = {}
+ local threat_table = {}
+
+ -- Build extended User Agent
+ if rule.user_agent == "extended" then
+ rule.user_agent = string.format("Rspamd/%s-%s (%s/%s)",
+ rspamd_version('main'),
+ rspamd_version('id'),
+ rspamd_util.get_hostname(),
+ string.sub(task:get_uid(), 1, 6))
+ end
+
+ -- Build the icap queries
+ local options_request = {
+ string.format("OPTIONS icap://%s/%s ICAP/1.0\r\n", addr:to_string(), rule.scheme),
+ string.format('Host: %s\r\n', addr:to_string()),
+ string.format("User-Agent: %s\r\n", rule.user_agent),
+ "Connection: keep-alive\r\n",
+ "Encapsulated: null-body=0\r\n\r\n",
+ }
+ if rule.user_agent == "none" then
+ table.remove(options_request, 3)
+ end
+
+ local respond_headers = {
+ -- Add main RESPMOD header before any other
+ string.format('RESPMOD icap://%s/%s ICAP/1.0\r\n', addr:to_string(), rule.scheme),
+ string.format('Host: %s\r\n', addr:to_string()),
+ }
+
+ local size = tonumber(#content)
+ local chunked_size = string.format("%x", size)
+
+ local function icap_callback(err, conn)
+
+ local function icap_requery(err_m, info)
+ -- retry with another upstream until retransmits exceeds
+ if retransmits > 0 then
+
+ retransmits = retransmits - 1
+
+ lua_util.debugm(rule.name, task,
+ '%s: %s Request Error: %s - retries left: %s',
+ rule.log_prefix, info, err_m, retransmits)
+
+ -- Select a different upstream!
+ upstream = rule.upstreams:get_upstream_round_robin()
+ addr = upstream:get_addr()
+
+ lua_util.debugm(rule.name, task, '%s: retry IP: %s:%s',
+ rule.log_prefix, addr, addr:get_port())
+
+ tcp_options.host = addr:to_string()
+ tcp_options.port = addr:get_port()
+ tcp_options.callback = icap_callback
+ tcp_options.data = options_request
+ tcp_options.upstream = upstream
+
+ tcp.request(tcp_options)
+
+ else
+ rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits ' ..
+ 'exceed - error: %s', rule.log_prefix, err_m or '')
+ common.yield_result(task, rule, string.format('failed - error: %s', err_m),
+ 0.0, 'fail', maybe_part)
+ end
+ end
+
+ local function get_req_headers()
+
+ local in_client_ip = task:get_from_ip()
+ local req_hlen = 2
+ if maybe_part then
+ table.insert(req_headers,
+ string.format('GET http://%s/%s HTTP/1.0\r\n', in_client_ip, maybe_part:get_filename()))
+ if rule.use_specific_content_type then
+ table.insert(http_headers, string.format('Content-Type: %s/%s\r\n', maybe_part:get_detected_type()))
+ --else
+ -- To test: what content type is better for icap servers?
+ --table.insert(http_headers, 'Content-Type: text/plain\r\n')
+ end
+ else
+ table.insert(req_headers, string.format('GET %s HTTP/1.0\r\n', rule.req_fake_url))
+ table.insert(http_headers, string.format('Content-Type: application/octet-stream\r\n'))
+ end
+ table.insert(req_headers, string.format('Date: %s\r\n', rspamd_util.time_to_string(rspamd_util.get_time())))
+ if rule.user_agent ~= "none" then
+ table.insert(req_headers, string.format("User-Agent: %s\r\n", rule.user_agent))
+ end
+
+ for _, h in ipairs(req_headers) do
+ req_hlen = req_hlen + tonumber(#h)
+ end
+
+ return req_hlen, req_headers
+
+ end
+
+ local function get_http_headers()
+ local http_hlen = 2
+ table.insert(http_headers, 'HTTP/1.0 200 OK\r\n')
+ table.insert(http_headers, string.format('Date: %s\r\n', rspamd_util.time_to_string(rspamd_util.get_time())))
+ table.insert(http_headers, string.format('Server: %s\r\n', 'Apache/2.4'))
+ if rule.user_agent ~= "none" then
+ table.insert(http_headers, string.format("User-Agent: %s\r\n", rule.user_agent))
+ end
+ --table.insert(http_headers, string.format('Content-Type: %s\r\n', 'text/html'))
+ table.insert(http_headers, string.format('Content-Length: %s\r\n', size))
+
+ for _, h in ipairs(http_headers) do
+ http_hlen = http_hlen + tonumber(#h)
+ end
+
+ return http_hlen, http_headers
+
+ end
+
+ local function get_respond_query()
+ local req_hlen = 0
+ local resp_req_headers
+ local http_hlen = 0
+ local resp_http_headers
+
+ -- Append all extra headers
+ if rule.user_agent ~= "none" then
+ table.insert(respond_headers,
+ string.format("User-Agent: %s\r\n", rule.user_agent))
+ end
+
+ if rule.req_headers_enabled then
+ req_hlen, resp_req_headers = get_req_headers()
+ end
+ if rule.http_headers_enabled then
+ http_hlen, resp_http_headers = get_http_headers()
+ end
+
+ if rule.req_headers_enabled and rule.http_headers_enabled then
+ local res_body_hlen = req_hlen + http_hlen
+ table.insert(respond_headers,
+ string.format('Encapsulated: req-hdr=0, res-hdr=%s, res-body=%s\r\n',
+ req_hlen, res_body_hlen))
+ elseif rule.http_headers_enabled then
+ table.insert(respond_headers,
+ string.format('Encapsulated: res-hdr=0, res-body=%s\r\n',
+ http_hlen))
+ else
+ table.insert(respond_headers, 'Encapsulated: res-body=0\r\n')
+ end
+
+ table.insert(respond_headers, '\r\n')
+ for _, h in ipairs(resp_req_headers) do
+ table.insert(respond_headers, h)
+ end
+ table.insert(respond_headers, '\r\n')
+ for _, h in ipairs(resp_http_headers) do
+ table.insert(respond_headers, h)
+ end
+ table.insert(respond_headers, '\r\n')
+ table.insert(respond_headers, chunked_size .. '\r\n')
+ table.insert(respond_headers, content)
+ table.insert(respond_headers, '\r\n0\r\n\r\n')
+ return respond_headers
+ end
+
+ local function add_respond_header(name, value)
+ if name and value then
+ table.insert(respond_headers, string.format('%s: %s\r\n', name, value))
+ end
+ end
+
+ local function result_header_table(result)
+ local icap_headers = {}
+ for s in result:gmatch("[^\r\n]+") do
+ if string.find(s, '^ICAP') then
+ icap_headers['icap'] = tostring(s)
+ elseif string.find(s, '^HTTP') then
+ icap_headers['http'] = tostring(s)
+ elseif string.find(s, '[%a%d-+]-:') then
+ local _, _, key, value = tostring(s):find("([%a%d-+]-):%s?(.+)")
+ if key ~= nil then
+ icap_headers[key:lower()] = tostring(value)
+ end
+ end
+ end
+ lua_util.debugm(rule.name, task, '%s: icap_headers: %s',
+ rule.log_prefix, icap_headers)
+ return icap_headers
+ end
+
+ local function threat_table_add(icap_threat, maybe_split)
+
+ if maybe_split and string.find(icap_threat, ',') then
+ local threats = lua_util.str_split(string.gsub(icap_threat, "%s", ""), ',') or {}
+
+ for _, v in ipairs(threats) do
+ table.insert(threat_table, v)
+ end
+ else
+ table.insert(threat_table, icap_threat)
+ end
+ return true
+ end
+
+ local function icap_parse_result(headers)
+
+ --[[
+ @ToDo: handle type in response
+
+ Generic Strings:
+ icap: X-Infection-Found: Type=0; Resolution=2; Threat=Troj/DocDl-OYC;
+ icap: X-Infection-Found: Type=0; Resolution=2; Threat=W97M.Downloader;
+
+ Symantec String:
+ icap: X-Infection-Found: Type=2; Resolution=2; Threat=Container size violation
+ icap: X-Infection-Found: Type=2; Resolution=2; Threat=Encrypted container violation;
+
+ Sophos Strings:
+ icap: X-Virus-ID: Troj/DocDl-OYC
+ http: X-Blocked: Virus found during virus scan
+ http: X-Blocked-By: Sophos Anti-Virus
+
+ Kaspersky Web Traffic Security Strings:
+ icap: X-Virus-ID: HEUR:Backdoor.Java.QRat.gen
+ icap: X-Response-Info: blocked
+ icap: X-Virus-ID: no threats
+ icap: X-Response-Info: blocked
+ icap: X-Response-Info: passed
+ http: HTTP/1.1 403 Forbidden
+
+ Kaspersky Scan Engine 2.0 (ICAP mode)
+ icap: X-Virus-ID: EICAR-Test-File
+ http: HTTP/1.0 403 Forbidden
+
+ Trend Micro Strings:
+ icap: X-Virus-ID: Trojan.W97M.POWLOAD.SMTHF1
+ icap: X-Infection-Found: Type=0; Resolution=2; Threat=Trojan.W97M.POWLOAD.SMTHF1;
+ http: HTTP/1.1 403 Forbidden (TMWS Blocked)
+ http: HTTP/1.1 403 Forbidden
+
+ F-Secure Internet Gatekeeper Strings:
+ icap: X-FSecure-Scan-Result: infected
+ icap: X-FSecure-Infection-Name: "Malware.W97M/Agent.32584203"
+ icap: X-FSecure-Infected-Filename: "virus.doc"
+
+ ESET File Security for Linux 7.0
+ icap: X-Infection-Found: Type=0; Resolution=0; Threat=VBA/TrojanDownloader.Agent.JOA;
+ icap: X-Virus-ID: Trojaner
+ icap: X-Response-Info: Blocked
+
+ McAfee Web Gateway 10/11 (Headers must be activated with personal extra Rules)
+ icap: X-Virus-ID: EICAR test file
+ icap: X-Media-Type: text/plain
+ icap: X-Block-Result: 80
+ icap: X-Block-Reason: Malware found
+ icap: X-Block-Reason: Archive not supported
+ icap: X-Block-Reason: Media Type (Block List)
+ http: HTTP/1.0 403 VirusFound
+
+ C-ICAP Squidclamav
+ icap/http: X-Infection-Found: Type=0; Resolution=2; Threat={HEX}EICAR.TEST.3.UNOFFICIAL;
+ icap/http: X-Virus-ID: {HEX}EICAR.TEST.3.UNOFFICIAL
+ http: HTTP/1.0 307 Temporary Redirect
+ ]] --
+
+ -- Generic ICAP Headers
+ if headers['x-infection-found'] then
+ local _, _, icap_type, _, icap_threat = headers['x-infection-found']:find("Type=(.-); Resolution=(.-); Threat=(.-);$")
+
+ -- Type=2 is typical for scan error returns
+ if icap_type and icap_type == '2' then
+ lua_util.debugm(rule.name, task,
+ '%s: icap error X-Infection-Found: %s', rule.log_prefix, icap_threat)
+ common.yield_result(task, rule, icap_threat, 0,
+ 'fail', maybe_part)
+ return true
+ elseif icap_threat ~= nil then
+ lua_util.debugm(rule.name, task,
+ '%s: icap X-Infection-Found: %s', rule.log_prefix, icap_threat)
+ threat_table_add(icap_threat, false)
+ -- stupid workaround for unuseable x-infection-found header
+ -- but also x-virus-name set (McAfee Web Gateway 9)
+ elseif not icap_threat and headers['x-virus-name'] then
+ threat_table_add(headers['x-virus-name'], true)
+ else
+ threat_table_add(headers['x-infection-found'], true)
+ end
+ elseif headers['x-virus-name'] and headers['x-virus-name'] ~= "no threats" then
+ lua_util.debugm(rule.name, task,
+ '%s: icap X-Virus-Name: %s', rule.log_prefix, headers['x-virus-name'])
+ threat_table_add(headers['x-virus-name'], true)
+ elseif headers['x-virus-id'] and headers['x-virus-id'] ~= "no threats" then
+ lua_util.debugm(rule.name, task,
+ '%s: icap X-Virus-ID: %s', rule.log_prefix, headers['x-virus-id'])
+ threat_table_add(headers['x-virus-id'], true)
+ -- FSecure X-Headers
+ elseif headers['x-fsecure-scan-result'] and headers['x-fsecure-scan-result'] ~= "clean" then
+
+ local infected_filename = ""
+ local infection_name = "-unknown-"
+
+ if headers['x-fsecure-infected-filename'] then
+ infected_filename = string.gsub(headers['x-fsecure-infected-filename'], '[%s"]', '')
+ end
+ if headers['x-fsecure-infection-name'] then
+ infection_name = string.gsub(headers['x-fsecure-infection-name'], '[%s"]', '')
+ end
+
+ lua_util.debugm(rule.name, task,
+ '%s: icap X-FSecure-Infection-Name (X-FSecure-Infected-Filename): %s (%s)',
+ rule.log_prefix, infection_name, infected_filename)
+
+ threat_table_add(infection_name, true)
+ -- McAfee Web Gateway manual extra headers
+ elseif headers['x-mwg-block-reason'] and headers['x-mwg-block-reason'] ~= "" then
+ threat_table_add(headers['x-mwg-block-reason'], false)
+ -- Sophos SAVDI special http headers
+ elseif headers['x-blocked'] and headers['x-blocked'] ~= "" then
+ threat_table_add(headers['x-blocked'], false)
+ elseif headers['x-block-reason'] and headers['x-block-reason'] ~= "" then
+ threat_table_add(headers['x-block-reason'], false)
+ -- last try HTTP [4]xx return
+ elseif headers.http and string.find(headers.http, '^HTTP%/[12]%.. [4]%d%d') then
+ threat_table_add(
+ string.format("pseudo-virus (blocked): %s", string.gsub(headers.http, 'HTTP%/[12]%.. ', '')), false)
+ elseif rule.use_http_3xx_as_threat and
+ headers.http and
+ string.find(headers.http, '^HTTP%/[12]%.. [3]%d%d')
+ then
+ threat_table_add(
+ string.format("pseudo-virus (redirect): %s",
+ string.gsub(headers.http, 'HTTP%/[12]%.. ', '')), false)
+ end
+
+ if #threat_table > 0 then
+ common.yield_result(task, rule, threat_table, rule.default_score, nil, maybe_part)
+ common.save_cache(task, digest, rule, threat_table, rule.default_score, maybe_part)
+ return true
+ else
+ return false
+ end
+ end
+
+ local function icap_r_respond_http_cb(err_m, data, connection)
+ if err_m or connection == nil then
+ icap_requery(err_m, "icap_r_respond_http_cb")
+ else
+ local result = tostring(data)
+
+ local icap_http_headers = result_header_table(result) or {}
+ -- Find HTTP/[12].x [234]xx response
+ if icap_http_headers.http and string.find(icap_http_headers.http, 'HTTP%/[12]%.. [234]%d%d') then
+ local icap_http_header_result = icap_parse_result(icap_http_headers)
+ if icap_http_header_result then
+ -- Threat found - close connection
+ connection:close()
+ else
+ common.save_cache(task, digest, rule, 'OK', 0, maybe_part)
+ common.log_clean(task, rule)
+ end
+ else
+ rspamd_logger.errx(task, '%s: unhandled response |%s|',
+ rule.log_prefix, string.gsub(result, "\r\n", ", "))
+ common.yield_result(task, rule, string.format('unhandled icap response: %s', icap_http_headers.icap),
+ 0.0, 'fail', maybe_part)
+ end
+ end
+ end
+
+ local function icap_r_respond_cb(err_m, data, connection)
+ if err_m or connection == nil then
+ icap_requery(err_m, "icap_r_respond_cb")
+ else
+ local result = tostring(data)
+
+ local icap_headers = result_header_table(result) or {}
+ -- Find ICAP/1.x 2xx response
+ if icap_headers.icap and string.find(icap_headers.icap, 'ICAP%/1%.. 2%d%d') then
+ local icap_header_result = icap_parse_result(icap_headers)
+ if icap_header_result then
+ -- Threat found - close connection
+ connection:close()
+ elseif not icap_header_result
+ and rule.use_http_result_header
+ and icap_headers.encapsulated
+ and not string.find(icap_headers.encapsulated, 'null%-body=0')
+ then
+ -- Try to read encapsulated HTTP Headers
+ lua_util.debugm(rule.name, task, '%s: no ICAP virus header found - try HTTP headers',
+ rule.log_prefix)
+ connection:add_read(icap_r_respond_http_cb, '\r\n\r\n')
+ else
+ connection:close()
+ common.save_cache(task, digest, rule, 'OK', 0, maybe_part)
+ common.log_clean(task, rule)
+ end
+ elseif icap_headers.icap and string.find(icap_headers.icap, 'ICAP%/1%.. [45]%d%d') then
+ -- Find ICAP/1.x 5/4xx response
+ --[[
+ Symantec String:
+ ICAP/1.0 539 Aborted - No AV scanning license
+ SquidClamAV/C-ICAP:
+ ICAP/1.0 500 Server error
+ Eset:
+ ICAP/1.0 405 Forbidden
+ TrendMicro:
+ ICAP/1.0 400 Bad request
+ McAfee:
+ ICAP/1.0 418 Bad composition
+ ]]--
+ rspamd_logger.errx(task, '%s: ICAP ERROR: %s', rule.log_prefix, icap_headers.icap)
+ common.yield_result(task, rule, icap_headers.icap, 0.0,
+ 'fail', maybe_part)
+ return false
+ else
+ rspamd_logger.errx(task, '%s: unhandled response |%s|',
+ rule.log_prefix, string.gsub(result, "\r\n", ", "))
+ common.yield_result(task, rule, string.format('unhandled icap response: %s', icap_headers.icap),
+ 0.0, 'fail', maybe_part)
+ end
+ end
+ end
+
+ local function icap_w_respond_cb(err_m, connection)
+ if err_m or connection == nil then
+ icap_requery(err_m, "icap_w_respond_cb")
+ else
+ connection:add_read(icap_r_respond_cb, '\r\n\r\n')
+ end
+ end
+
+ local function icap_r_options_cb(err_m, data, connection)
+ if err_m or connection == nil then
+ icap_requery(err_m, "icap_r_options_cb")
+ else
+ local icap_headers = result_header_table(tostring(data))
+
+ if icap_headers.icap and string.find(icap_headers.icap, 'ICAP%/1%.. 2%d%d') then
+ if icap_headers['methods'] and string.find(icap_headers['methods'], 'RESPMOD') then
+ -- Allow "204 No Content" responses
+ -- https://datatracker.ietf.org/doc/html/rfc3507#section-4.6
+ if icap_headers['allow'] and string.find(icap_headers['allow'], '204') then
+ add_respond_header('Allow', '204')
+ end
+
+ if rule.x_client_header then
+ local client = task:get_from_ip()
+ if client then
+ add_respond_header('X-Client-IP', client:to_string())
+ end
+ end
+
+ -- F-Secure extra headers
+ if icap_headers['server'] and string.find(icap_headers['server'], 'f-secure icap server') then
+
+ if rule.x_rcpt_header then
+ local rcpt_to = task:get_principal_recipient()
+ if rcpt_to then
+ add_respond_header('X-Rcpt-To', rcpt_to)
+ end
+ end
+
+ if rule.x_from_header then
+ local mail_from = task:get_principal_recipient()
+ if mail_from and mail_from[1] then
+ add_respond_header('X-Rcpt-To', mail_from[1].addr)
+ end
+ end
+
+ end
+
+ if icap_headers.connection and icap_headers.connection:lower() == 'close' then
+ lua_util.debugm(rule.name, task, '%s: OPTIONS request Connection: %s - using new connection',
+ rule.log_prefix, icap_headers.connection)
+ connection:close()
+ tcp_options.callback = icap_w_respond_cb
+ tcp_options.data = get_respond_query()
+ tcp.request(tcp_options)
+ else
+ connection:add_write(icap_w_respond_cb, get_respond_query())
+ end
+
+ else
+ rspamd_logger.errx(task, '%s: RESPMOD method not advertised: Methods: %s',
+ rule.log_prefix, icap_headers['methods'])
+ common.yield_result(task, rule, 'NO RESPMOD', 0.0,
+ 'fail', maybe_part)
+ end
+ else
+ rspamd_logger.errx(task, '%s: OPTIONS query failed: %s',
+ rule.log_prefix, icap_headers.icap or "-")
+ common.yield_result(task, rule, 'OPTIONS query failed', 0.0,
+ 'fail', maybe_part)
+ end
+ end
+ end
+
+ if err or conn == nil then
+ icap_requery(err, "options_request")
+ else
+ conn:add_read(icap_r_options_cb, '\r\n\r\n')
+ end
+ end
+
+ tcp_options.task = task
+ tcp_options.stop_pattern = '\r\n'
+ tcp_options.read = false
+ tcp_options.timeout = rule.timeout
+ tcp_options.callback = icap_callback
+ tcp_options.data = options_request
+
+ if rule.ssl then
+ tcp_options.ssl = true
+ if rule.no_ssl_verify then
+ tcp_options.no_ssl_verify = true
+ end
+ end
+
+ tcp_options.host = addr:to_string()
+ tcp_options.port = addr:get_port()
+ tcp_options.upstream = upstream
+
+ tcp.request(tcp_options)
+ end
+
+ if common.condition_check_and_continue(task, content, rule, digest,
+ icap_check_uncached, maybe_part) then
+ return
+ else
+ icap_check_uncached()
+ end
+
+end
+
+return {
+ type = { N, 'virus', 'virus', 'scanner' },
+ description = 'generic icap antivirus',
+ configure = icap_config,
+ check = icap_check,
+ name = N
+}
diff --git a/lualib/lua_scanners/init.lua b/lualib/lua_scanners/init.lua
new file mode 100644
index 0000000..e47cebe
--- /dev/null
+++ b/lualib/lua_scanners/init.lua
@@ -0,0 +1,75 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_scanners
+-- This module contains external scanners functions
+--]]
+
+local fun = require "fun"
+
+local exports = {
+}
+
+local function require_scanner(name)
+ local sc = require("lua_scanners/" .. name)
+
+ exports[sc.name or name] = sc
+end
+
+-- Antiviruses
+require_scanner('clamav')
+require_scanner('fprot')
+require_scanner('kaspersky_av')
+require_scanner('kaspersky_se')
+require_scanner('savapi')
+require_scanner('sophos')
+require_scanner('virustotal')
+require_scanner('avast')
+
+-- Other scanners
+require_scanner('dcc')
+require_scanner('oletools')
+require_scanner('icap')
+require_scanner('vadesecure')
+require_scanner('spamassassin')
+require_scanner('p0f')
+require_scanner('razor')
+require_scanner('pyzor')
+require_scanner('cloudmark')
+
+exports.add_scanner = function(name, t, conf_func, check_func)
+ assert(type(conf_func) == 'function' and type(check_func) == 'function',
+ 'bad arguments')
+ exports[name] = {
+ type = t,
+ configure = conf_func,
+ check = check_func,
+ }
+end
+
+exports.filter = function(t)
+ return fun.tomap(fun.filter(function(_, elt)
+ return type(elt) == 'table' and elt.type and (
+ (type(elt.type) == 'string' and elt.type == t) or
+ (type(elt.type) == 'table' and fun.any(function(tt)
+ return tt == t
+ end, elt.type))
+ )
+ end, exports))
+end
+
+return exports
diff --git a/lualib/lua_scanners/kaspersky_av.lua b/lualib/lua_scanners/kaspersky_av.lua
new file mode 100644
index 0000000..d52cef0
--- /dev/null
+++ b/lualib/lua_scanners/kaspersky_av.lua
@@ -0,0 +1,197 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module kaspersky
+-- This module contains kaspersky antivirus access functions
+--]]
+
+local lua_util = require "lua_util"
+local tcp = require "rspamd_tcp"
+local upstream_list = require "rspamd_upstream_list"
+local rspamd_util = require "rspamd_util"
+local rspamd_logger = require "rspamd_logger"
+local common = require "lua_scanners/common"
+
+local N = "kaspersky"
+
+local default_message = '${SCANNER}: virus found: "${VIRUS}"'
+
+local function kaspersky_config(opts)
+ local kaspersky_conf = {
+ name = N,
+ scan_mime_parts = true,
+ scan_text_mime = false,
+ scan_image_mime = false,
+ product_id = 0,
+ log_clean = false,
+ timeout = 5.0,
+ retransmits = 1, -- use local files, retransmits are useless
+ cache_expire = 3600, -- expire redis in one hour
+ message = default_message,
+ detection_category = "virus",
+ tmpdir = '/tmp',
+ }
+
+ kaspersky_conf = lua_util.override_defaults(kaspersky_conf, opts)
+
+ if not kaspersky_conf.prefix then
+ kaspersky_conf.prefix = 'rs_' .. kaspersky_conf.name .. '_'
+ end
+
+ if not kaspersky_conf.log_prefix then
+ if kaspersky_conf.name:lower() == kaspersky_conf.type:lower() then
+ kaspersky_conf.log_prefix = kaspersky_conf.name
+ else
+ kaspersky_conf.log_prefix = kaspersky_conf.name .. ' (' .. kaspersky_conf.type .. ')'
+ end
+ end
+
+ if not kaspersky_conf['servers'] then
+ rspamd_logger.errx(rspamd_config, 'no servers defined')
+
+ return nil
+ end
+
+ kaspersky_conf['upstreams'] = upstream_list.create(rspamd_config,
+ kaspersky_conf['servers'], 0)
+
+ if kaspersky_conf['upstreams'] then
+ lua_util.add_debug_alias('antivirus', kaspersky_conf.name)
+ return kaspersky_conf
+ end
+
+ rspamd_logger.errx(rspamd_config, 'cannot parse servers %s',
+ kaspersky_conf['servers'])
+ return nil
+end
+
+local function kaspersky_check(task, content, digest, rule, maybe_part)
+ local function kaspersky_check_uncached ()
+ local upstream = rule.upstreams:get_upstream_round_robin()
+ local addr = upstream:get_addr()
+ local retransmits = rule.retransmits
+ local fname = string.format('%s/%s.tmp',
+ rule.tmpdir, rspamd_util.random_hex(32))
+ local message_fd = rspamd_util.create_file(fname)
+ local clamav_compat_cmd = string.format("nSCAN %s\n", fname)
+
+ if not message_fd then
+ rspamd_logger.errx('cannot store file for kaspersky scan: %s', fname)
+ return
+ end
+
+ if type(content) == 'string' then
+ -- Create rspamd_text
+ local rspamd_text = require "rspamd_text"
+ content = rspamd_text.fromstring(content)
+ end
+ content:save_in_file(message_fd)
+
+ -- Ensure file cleanup
+ task:get_mempool():add_destructor(function()
+ os.remove(fname)
+ rspamd_util.close_file(message_fd)
+ end)
+
+ local function kaspersky_callback(err, data)
+ if err then
+
+ -- retry with another upstream until retransmits exceeds
+ if retransmits > 0 then
+
+ retransmits = retransmits - 1
+
+ -- Select a different upstream!
+ upstream = rule.upstreams:get_upstream_round_robin()
+ addr = upstream:get_addr()
+
+ lua_util.debugm(rule.name, task, '%s: error: %s; retry IP: %s; retries left: %s',
+ rule.log_prefix, err, addr, retransmits)
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule['timeout'],
+ callback = kaspersky_callback,
+ data = { clamav_compat_cmd },
+ stop_pattern = '\n'
+ })
+ else
+ rspamd_logger.errx(task,
+ '%s [%s]: failed to scan, maximum retransmits exceed',
+ rule['symbol'], rule['type'])
+ common.yield_result(task, rule,
+ 'failed to scan and retransmits exceed', 0.0, 'fail',
+ maybe_part)
+ end
+
+ else
+ data = tostring(data)
+ local cached
+ lua_util.debugm(rule.name, task,
+ '%s [%s]: got reply: %s',
+ rule['symbol'], rule['type'], data)
+ if data == 'stream: OK' or data == fname .. ': OK' then
+ cached = 'OK'
+ common.log_clean(task, rule)
+ else
+ local vname = string.match(data, ': (.+) FOUND')
+ if vname then
+ common.yield_result(task, rule, vname, 1.0, nil, maybe_part)
+ cached = vname
+ else
+ rspamd_logger.errx(task, 'unhandled response: %s', data)
+ common.yield_result(task, rule, 'unhandled response',
+ 0.0, 'fail', maybe_part)
+ end
+ end
+ if cached then
+ common.save_cache(task, digest, rule, cached, 1.0, maybe_part)
+ end
+ end
+ end
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule['timeout'],
+ callback = kaspersky_callback,
+ data = { clamav_compat_cmd },
+ stop_pattern = '\n'
+ })
+ end
+
+ if common.condition_check_and_continue(task, content, rule, digest,
+ kaspersky_check_uncached, maybe_part) then
+ return
+ else
+ kaspersky_check_uncached()
+ end
+
+end
+
+return {
+ type = 'antivirus',
+ description = 'kaspersky antivirus',
+ configure = kaspersky_config,
+ check = kaspersky_check,
+ name = N
+}
diff --git a/lualib/lua_scanners/kaspersky_se.lua b/lualib/lua_scanners/kaspersky_se.lua
new file mode 100644
index 0000000..5e0f2ea
--- /dev/null
+++ b/lualib/lua_scanners/kaspersky_se.lua
@@ -0,0 +1,287 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module kaspersky_se
+-- This module contains Kaspersky Scan Engine integration support
+-- https://www.kaspersky.com/scan-engine
+--]]
+
+local lua_util = require "lua_util"
+local rspamd_util = require "rspamd_util"
+local http = require "rspamd_http"
+local upstream_list = require "rspamd_upstream_list"
+local rspamd_logger = require "rspamd_logger"
+local common = require "lua_scanners/common"
+
+local N = 'kaspersky_se'
+
+local function kaspersky_se_config(opts)
+
+ local default_conf = {
+ name = N,
+ default_port = 9999,
+ use_https = false,
+ use_files = false,
+ timeout = 5.0,
+ log_clean = false,
+ tmpdir = '/tmp',
+ retransmits = 1,
+ cache_expire = 7200, -- expire redis in 2h
+ message = '${SCANNER}: spam message found: "${VIRUS}"',
+ detection_category = "virus",
+ default_score = 1,
+ action = false,
+ scan_mime_parts = true,
+ scan_text_mime = false,
+ scan_image_mime = false,
+ }
+
+ default_conf = lua_util.override_defaults(default_conf, opts)
+
+ if not default_conf.prefix then
+ default_conf.prefix = 'rs_' .. default_conf.name .. '_'
+ end
+
+ if not default_conf.log_prefix then
+ if default_conf.name:lower() == default_conf.type:lower() then
+ default_conf.log_prefix = default_conf.name
+ else
+ default_conf.log_prefix = default_conf.name .. ' (' .. default_conf.type .. ')'
+ end
+ end
+
+ if not default_conf.servers and default_conf.socket then
+ default_conf.servers = default_conf.socket
+ end
+
+ if not default_conf.servers then
+ rspamd_logger.errx(rspamd_config, 'no servers defined')
+
+ return nil
+ end
+
+ default_conf.upstreams = upstream_list.create(rspamd_config,
+ default_conf.servers,
+ default_conf.default_port)
+
+ if default_conf.upstreams then
+ lua_util.add_debug_alias('external_services', default_conf.name)
+ return default_conf
+ end
+
+ rspamd_logger.errx(rspamd_config, 'cannot parse servers %s',
+ default_conf['servers'])
+ return nil
+end
+
+local function kaspersky_se_check(task, content, digest, rule, maybe_part)
+ local function kaspersky_se_check_uncached()
+ local function make_url(addr)
+ local url
+ local suffix = '/scanmemory'
+
+ if rule.use_files then
+ suffix = '/scanfile'
+ end
+ if rule.use_https then
+ url = string.format('https://%s:%d%s', tostring(addr),
+ addr:get_port(), suffix)
+ else
+ url = string.format('http://%s:%d%s', tostring(addr),
+ addr:get_port(), suffix)
+ end
+
+ return url
+ end
+
+ local upstream = rule.upstreams:get_upstream_round_robin()
+ local addr = upstream:get_addr()
+ local retransmits = rule.retransmits
+
+ local url = make_url(addr)
+ local hdrs = {
+ ['X-KAV-ProtocolVersion'] = '1',
+ ['X-KAV-Timeout'] = tostring(rule.timeout * 1000),
+ }
+
+ if task:has_from() then
+ hdrs['X-KAV-ObjectURL'] = string.format('[from:%s]', task:get_from()[1].addr)
+ end
+
+ local req_body
+
+ if rule.use_files then
+ local fname = string.format('%s/%s.tmp',
+ rule.tmpdir, rspamd_util.random_hex(32))
+ local message_fd = rspamd_util.create_file(fname)
+
+ if not message_fd then
+ rspamd_logger.errx('cannot store file for kaspersky_se scan: %s', fname)
+ return
+ end
+
+ if type(content) == 'string' then
+ -- Create rspamd_text
+ local rspamd_text = require "rspamd_text"
+ content = rspamd_text.fromstring(content)
+ end
+ content:save_in_file(message_fd)
+
+ -- Ensure cleanup
+ task:get_mempool():add_destructor(function()
+ os.remove(fname)
+ rspamd_util.close_file(message_fd)
+ end)
+
+ req_body = fname
+ else
+ req_body = content
+ end
+
+ local request_data = {
+ task = task,
+ url = url,
+ body = req_body,
+ headers = hdrs,
+ timeout = rule.timeout,
+ }
+
+ local function kas_callback(http_err, code, body, headers)
+
+ local function requery()
+ -- set current upstream to fail because an error occurred
+ upstream:fail()
+
+ -- retry with another upstream until retransmits exceeds
+ if retransmits > 0 then
+
+ retransmits = retransmits - 1
+
+ lua_util.debugm(rule.name, task,
+ '%s: Request Error: %s - retries left: %s',
+ rule.log_prefix, http_err, retransmits)
+
+ -- Select a different upstream!
+ upstream = rule.upstreams:get_upstream_round_robin()
+ addr = upstream:get_addr()
+ url = make_url(addr)
+
+ lua_util.debugm(rule.name, task, '%s: retry IP: %s:%s',
+ rule.log_prefix, addr, addr:get_port())
+ request_data.url = url
+ request_data.upstream = upstream
+
+ http.request(request_data)
+ else
+ rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits ' ..
+ 'exceed', rule.log_prefix)
+ task:insert_result(rule['symbol_fail'], 0.0, 'failed to scan and ' ..
+ 'retransmits exceed')
+ end
+ end
+
+ if http_err then
+ requery()
+ else
+ -- Parse the response
+ if upstream then
+ upstream:ok()
+ end
+ if code ~= 200 then
+ rspamd_logger.errx(task, 'invalid HTTP code: %s, body: %s, headers: %s', code, body, headers)
+ task:insert_result(rule.symbol_fail, 1.0, 'Bad HTTP code: ' .. code)
+ return
+ end
+ local data = string.gsub(tostring(body), '[\r\n%s]$', '')
+ local cached
+ lua_util.debugm(rule.name, task, '%s: got reply data: "%s"',
+ rule.log_prefix, data)
+
+ if data:find('^CLEAN') then
+ -- Handle CLEAN replies
+ if data == 'CLEAN' then
+ cached = 'OK'
+ if rule['log_clean'] then
+ rspamd_logger.infox(task, '%s: message or mime_part is clean',
+ rule.log_prefix)
+ else
+ lua_util.debugm(rule.name, task, '%s: message or mime_part is clean',
+ rule.log_prefix)
+ end
+ elseif data == 'CLEAN AND CONTAINS OFFICE MACRO' then
+ common.yield_result(task, rule, 'File contains macros',
+ 0.0, 'macro', maybe_part)
+ cached = 'MACRO'
+ else
+ rspamd_logger.errx(task, '%s: unhandled clean response: %s', rule.log_prefix, data)
+ common.yield_result(task, rule, 'unhandled response:' .. data,
+ 0.0, 'fail', maybe_part)
+ end
+ elseif data == 'SERVER_ERROR' then
+ rspamd_logger.errx(task, '%s: error: %s', rule.log_prefix, data)
+ common.yield_result(task, rule, 'error:' .. data,
+ 0.0, 'fail', maybe_part)
+ elseif string.match(data, 'DETECT (.+)') then
+ local vname = string.match(data, 'DETECT (.+)')
+ common.yield_result(task, rule, vname, 1.0, nil, maybe_part)
+ cached = vname
+ elseif string.match(data, 'NON_SCANNED %((.+)%)') then
+ local why = string.match(data, 'NON_SCANNED %((.+)%)')
+
+ if why == 'PASSWORD PROTECTED' then
+ rspamd_logger.errx(task, '%s: File is encrypted', rule.log_prefix)
+ common.yield_result(task, rule, 'File is encrypted: ' .. why,
+ 0.0, 'encrypted', maybe_part)
+ cached = 'ENCRYPTED'
+ else
+ common.yield_result(task, rule, 'unhandled response:' .. data,
+ 0.0, 'fail', maybe_part)
+ end
+ else
+ rspamd_logger.errx(task, '%s: unhandled response: %s', rule.log_prefix, data)
+ common.yield_result(task, rule, 'unhandled response:' .. data,
+ 0.0, 'fail', maybe_part)
+ end
+
+ if cached then
+ common.save_cache(task, digest, rule, cached, 1.0, maybe_part)
+ end
+
+ end
+ end
+
+ request_data.callback = kas_callback
+ http.request(request_data)
+ end
+
+ if common.condition_check_and_continue(task, content, rule, digest,
+ kaspersky_se_check_uncached, maybe_part) then
+ return
+ else
+
+ kaspersky_se_check_uncached()
+ end
+
+end
+
+return {
+ type = 'antivirus',
+ description = 'Kaspersky Scan Engine interface',
+ configure = kaspersky_se_config,
+ check = kaspersky_se_check,
+ name = N
+}
diff --git a/lualib/lua_scanners/oletools.lua b/lualib/lua_scanners/oletools.lua
new file mode 100644
index 0000000..378e094
--- /dev/null
+++ b/lualib/lua_scanners/oletools.lua
@@ -0,0 +1,369 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+Copyright (c) 2018, Carsten Rosenberg <c.rosenberg@heinlein-support.de>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module oletools
+-- This module contains oletools access functions.
+-- Olefy is needed: https://github.com/HeinleinSupport/olefy
+--]]
+
+local lua_util = require "lua_util"
+local tcp = require "rspamd_tcp"
+local upstream_list = require "rspamd_upstream_list"
+local rspamd_logger = require "rspamd_logger"
+local ucl = require "ucl"
+local common = require "lua_scanners/common"
+
+local N = 'oletools'
+
+local function oletools_config(opts)
+
+ local oletools_conf = {
+ name = N,
+ scan_mime_parts = true,
+ scan_text_mime = false,
+ scan_image_mime = false,
+ default_port = 10050,
+ timeout = 15.0,
+ log_clean = false,
+ retransmits = 2,
+ cache_expire = 86400, -- expire redis in 1d
+ min_size = 500,
+ symbol = "OLETOOLS",
+ message = '${SCANNER}: Oletools threat message found: "${VIRUS}"',
+ detection_category = "office macro",
+ default_score = 1,
+ action = false,
+ extended = false,
+ symbol_type = 'postfilter',
+ dynamic_scan = true,
+ }
+
+ oletools_conf = lua_util.override_defaults(oletools_conf, opts)
+
+ if not oletools_conf.prefix then
+ oletools_conf.prefix = 'rs_' .. oletools_conf.name .. '_'
+ end
+
+ if not oletools_conf.log_prefix then
+ if oletools_conf.name:lower() == oletools_conf.type:lower() then
+ oletools_conf.log_prefix = oletools_conf.name
+ else
+ oletools_conf.log_prefix = oletools_conf.name .. ' (' .. oletools_conf.type .. ')'
+ end
+ end
+
+ if not oletools_conf.servers then
+ rspamd_logger.errx(rspamd_config, 'no servers defined')
+
+ return nil
+ end
+
+ oletools_conf.upstreams = upstream_list.create(rspamd_config,
+ oletools_conf.servers,
+ oletools_conf.default_port)
+
+ if oletools_conf.upstreams then
+ lua_util.add_debug_alias('external_services', oletools_conf.name)
+ return oletools_conf
+ end
+
+ rspamd_logger.errx(rspamd_config, 'cannot parse servers %s',
+ oletools_conf.servers)
+ return nil
+end
+
+local function oletools_check(task, content, digest, rule, maybe_part)
+ local function oletools_check_uncached ()
+ local upstream = rule.upstreams:get_upstream_round_robin()
+ local addr = upstream:get_addr()
+ local retransmits = rule.retransmits
+ local protocol = 'OLEFY/1.0\nMethod: oletools\nRspamd-ID: ' .. task:get_uid() .. '\n\n'
+ local json_response = ""
+
+ local function oletools_callback(err, data, conn)
+
+ local function oletools_requery(error)
+
+ -- retry with another upstream until retransmits exceeds
+ if retransmits > 0 then
+
+ retransmits = retransmits - 1
+
+ -- Select a different upstream!
+ upstream = rule.upstreams:get_upstream_round_robin()
+ addr = upstream:get_addr()
+
+ lua_util.debugm(rule.name, task, '%s: error: %s; retry IP: %s; retries left: %s',
+ rule.log_prefix, err, addr, retransmits)
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule.timeout,
+ shutdown = true,
+ data = { protocol, content },
+ callback = oletools_callback,
+ })
+ else
+ rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits ' ..
+ 'exceed - err: %s', rule.log_prefix, error)
+ common.yield_result(task, rule,
+ 'failed to scan, maximum retransmits exceed - err: ' .. error,
+ 0.0, 'fail', maybe_part)
+ end
+ end
+
+ if err then
+
+ oletools_requery(err)
+
+ else
+ json_response = json_response .. tostring(data)
+
+ if not string.find(json_response, '\t\n\n\t') and #data == 8192 then
+ lua_util.debugm(rule.name, task, '%s: no stop word: add_read - #json: %s / current packet: %s',
+ rule.log_prefix, #json_response, #data)
+ conn:add_read(oletools_callback)
+
+ else
+ local ucl_parser = ucl.parser()
+ local ok, ucl_err = ucl_parser:parse_string(tostring(json_response))
+ if not ok then
+ rspamd_logger.errx(task, "%s: error parsing json response, retry: %s",
+ rule.log_prefix, ucl_err)
+ oletools_requery(ucl_err)
+ return
+ end
+
+ local result = ucl_parser:get_object()
+
+ local oletools_rc = {
+ [0] = 'RETURN_OK',
+ [1] = 'RETURN_WARNINGS',
+ [2] = 'RETURN_WRONG_ARGS',
+ [3] = 'RETURN_FILE_NOT_FOUND',
+ [4] = 'RETURN_XGLOB_ERR',
+ [5] = 'RETURN_OPEN_ERROR',
+ [6] = 'RETURN_PARSE_ERROR',
+ [7] = 'RETURN_SEVERAL_ERRS',
+ [8] = 'RETURN_UNEXPECTED',
+ [9] = 'RETURN_ENCRYPTED',
+ }
+
+ -- M=Macros, A=Auto-executable, S=Suspicious keywords, I=IOCs,
+ -- H=Hex strings, B=Base64 strings, D=Dridex strings, V=VBA strings
+ -- Keep sorted to avoid dragons
+ local analysis_cat_table = {
+ autoexec = '-',
+ base64 = '-',
+ dridex = '-',
+ hex = '-',
+ iocs = '-',
+ macro_exist = '-',
+ suspicious = '-',
+ vba = '-'
+ }
+ local analysis_keyword_table = {}
+
+ for _, v in ipairs(result) do
+
+ if v.error ~= nil and v.type ~= 'error' then
+ -- olefy, not oletools error
+ rspamd_logger.errx(task, '%s: ERROR found: %s', rule.log_prefix,
+ v.error)
+ if v.error == 'File too small' then
+ common.save_cache(task, digest, rule, 'OK', 1.0, maybe_part)
+ common.log_clean(task, rule, 'File too small to be scanned for macros')
+ return
+ else
+ oletools_requery(v.error)
+ end
+
+ elseif tostring(v.type) == "MetaInformation" and v.version ~= nil then
+ -- if MetaInformation section - check and print script and version
+
+ lua_util.debugm(N, task, '%s: version: %s %s', rule.log_prefix,
+ tostring(v.script_name), tostring(v.version))
+
+ elseif tostring(v.type) == "MetaInformation" and v.return_code ~= nil then
+ -- if MetaInformation section - check return_code
+
+ local oletools_rc_code = tonumber(v.return_code)
+ if oletools_rc_code == 9 then
+ rspamd_logger.warnx(task, '%s: File is encrypted.', rule.log_prefix)
+ common.yield_result(task, rule,
+ 'failed - err: ' .. oletools_rc[oletools_rc_code],
+ 0.0, 'encrypted', maybe_part)
+ common.save_cache(task, digest, rule, 'encrypted', 1.0, maybe_part)
+ return
+ elseif oletools_rc_code == 5 then
+ rspamd_logger.warnx(task, '%s: olefy could not open the file - error: %s', rule.log_prefix,
+ result[2]['message'])
+ common.yield_result(task, rule,
+ 'failed - err: ' .. oletools_rc[oletools_rc_code],
+ 0.0, 'fail', maybe_part)
+ return
+ elseif oletools_rc_code > 6 then
+ rspamd_logger.errx(task, '%s: MetaInfo section error code: %s',
+ rule.log_prefix, oletools_rc[oletools_rc_code])
+ rspamd_logger.errx(task, '%s: MetaInfo section message: %s',
+ rule.log_prefix, result[2]['message'])
+ common.yield_result(task, rule,
+ 'failed - err: ' .. oletools_rc[oletools_rc_code],
+ 0.0, 'fail', maybe_part)
+ return
+ elseif oletools_rc_code > 1 then
+ rspamd_logger.errx(task, '%s: Error message: %s',
+ rule.log_prefix, result[2]['message'])
+ oletools_requery(oletools_rc[oletools_rc_code])
+ end
+
+ elseif tostring(v.type) == "error" then
+ -- error section found - check message
+ rspamd_logger.errx(task, '%s: Error section error code: %s',
+ rule.log_prefix, v.error)
+ rspamd_logger.errx(task, '%s: Error section message: %s',
+ rule.log_prefix, v.message)
+ --common.yield_result(task, rule, 'failed - err: ' .. v.error, 0.0, 'fail')
+
+ elseif type(v.analysis) == 'table' and type(v.macros) == 'table' then
+ -- analysis + macro found - evaluate response
+
+ if type(v.analysis) == 'table' and #v.analysis == 0 and #v.macros == 0 then
+ rspamd_logger.warnx(task, '%s: maybe unhandled python or oletools error', rule.log_prefix)
+ oletools_requery('oletools unhandled error')
+
+ elseif #v.macros > 0 then
+
+ analysis_cat_table.macro_exist = 'M'
+
+ lua_util.debugm(rule.name, task,
+ '%s: filename: %s', rule.log_prefix, result[2]['file'])
+ lua_util.debugm(rule.name, task,
+ '%s: type: %s', rule.log_prefix, result[2]['type'])
+
+ for _, m in ipairs(v.macros) do
+ lua_util.debugm(rule.name, task, '%s: macros found - code: %s, ole_stream: %s, ' ..
+ 'vba_filename: %s', rule.log_prefix, m.code, m.ole_stream, m.vba_filename)
+ end
+
+ for _, a in ipairs(v.analysis) do
+ lua_util.debugm(rule.name, task, '%s: threat found - type: %s, keyword: %s, ' ..
+ 'description: %s', rule.log_prefix, a.type, a.keyword, a.description)
+ if a.type == 'AutoExec' then
+ analysis_cat_table.autoexec = 'A'
+ table.insert(analysis_keyword_table, a.keyword)
+ elseif a.type == 'Suspicious' then
+ if rule.extended == true or
+ (a.keyword ~= 'Base64 Strings' and a.keyword ~= 'Hex Strings')
+ then
+ analysis_cat_table.suspicious = 'S'
+ table.insert(analysis_keyword_table, a.keyword)
+ end
+ elseif a.type == 'IOC' then
+ analysis_cat_table.iocs = 'I'
+ elseif a.type == 'Hex strings' then
+ analysis_cat_table.hex = 'H'
+ elseif a.type == 'Base64 strings' then
+ analysis_cat_table.base64 = 'B'
+ elseif a.type == 'Dridex strings' then
+ analysis_cat_table.dridex = 'D'
+ elseif a.type == 'VBA strings' then
+ analysis_cat_table.vba = 'V'
+ end
+ end
+ end
+ end
+ end
+
+ lua_util.debugm(N, task, '%s: analysis_keyword_table: %s', rule.log_prefix, analysis_keyword_table)
+ lua_util.debugm(N, task, '%s: analysis_cat_table: %s', rule.log_prefix, analysis_cat_table)
+
+ if rule.extended == false and analysis_cat_table.autoexec == 'A' and analysis_cat_table.suspicious == 'S' then
+ -- use single string as virus name
+ local threat = 'AutoExec + Suspicious (' .. table.concat(analysis_keyword_table, ',') .. ')'
+ lua_util.debugm(rule.name, task, '%s: threat result: %s', rule.log_prefix, threat)
+ common.yield_result(task, rule, threat, rule.default_score, nil, maybe_part)
+ common.save_cache(task, digest, rule, threat, rule.default_score, maybe_part)
+
+ elseif rule.extended == true and #analysis_keyword_table > 0 then
+ -- report any flags (types) and any most keywords as individual virus name
+ local analysis_cat_table_values_sorted = {}
+
+ -- see https://github.com/rspamd/rspamd/commit/6bd3e2b9f49d1de3ab882aeca9c30bc7d526ac9d#commitcomment-40130493
+ -- for details
+ local analysis_cat_table_keys_sorted = lua_util.keys(analysis_cat_table)
+ table.sort(analysis_cat_table_keys_sorted)
+
+ for _, v in ipairs(analysis_cat_table_keys_sorted) do
+ table.insert(analysis_cat_table_values_sorted, analysis_cat_table[v])
+ end
+
+ table.insert(analysis_keyword_table, 1, table.concat(analysis_cat_table_values_sorted))
+
+ lua_util.debugm(rule.name, task, '%s: extended threat result: %s',
+ rule.log_prefix, table.concat(analysis_keyword_table, ','))
+
+ common.yield_result(task, rule, analysis_keyword_table,
+ rule.default_score, nil, maybe_part)
+ common.save_cache(task, digest, rule, analysis_keyword_table,
+ rule.default_score, maybe_part)
+
+ elseif analysis_cat_table.macro_exist == '-' and #analysis_keyword_table == 0 then
+ common.save_cache(task, digest, rule, 'OK', 1.0, maybe_part)
+ common.log_clean(task, rule, 'No macro found')
+
+ else
+ common.save_cache(task, digest, rule, 'OK', 1.0, maybe_part)
+ common.log_clean(task, rule, 'Scanned Macro is OK')
+ end
+ end
+ end
+ end
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule.timeout,
+ shutdown = true,
+ data = { protocol, content },
+ callback = oletools_callback,
+ })
+
+ end
+
+ if common.condition_check_and_continue(task, content, rule, digest,
+ oletools_check_uncached, maybe_part) then
+ return
+ else
+ oletools_check_uncached()
+ end
+
+end
+
+return {
+ type = { N, 'attachment scanner', 'hash', 'scanner' },
+ description = 'oletools office macro scanner',
+ configure = oletools_config,
+ check = oletools_check,
+ name = N
+}
diff --git a/lualib/lua_scanners/p0f.lua b/lualib/lua_scanners/p0f.lua
new file mode 100644
index 0000000..7785f83
--- /dev/null
+++ b/lualib/lua_scanners/p0f.lua
@@ -0,0 +1,227 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+Copyright (c) 2019, Denis Paavilainen <denpa@denpa.pro>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module p0f
+-- This module contains p0f access functions
+--]]
+
+local tcp = require "rspamd_tcp"
+local rspamd_util = require "rspamd_util"
+local rspamd_logger = require "rspamd_logger"
+local lua_redis = require "lua_redis"
+local lua_util = require "lua_util"
+local common = require "lua_scanners/common"
+
+-- SEE: https://github.com/p0f/p0f/blob/v3.06b/docs/README#L317
+local S = {
+ BAD_QUERY = 0x0,
+ OK = 0x10,
+ NO_MATCH = 0x20
+}
+
+local N = 'p0f'
+
+local function p0f_check(task, ip, rule)
+
+ local function ip2bin(addr)
+ addr = addr:to_table()
+
+ for k, v in ipairs(addr) do
+ addr[k] = rspamd_util.pack('B', v)
+ end
+
+ return table.concat(addr)
+ end
+
+ local function trim(...)
+ local vars = { ... }
+
+ for k, v in ipairs(vars) do
+ -- skip numbers, trim only strings
+ if tonumber(vars[k]) == nil then
+ vars[k] = string.gsub(v, '[^%w-_\\.\\(\\) ]', '')
+ end
+ end
+
+ return lua_util.unpack(vars)
+ end
+
+ local function parse_p0f_response(data)
+ --[[
+ p0f_api_response[232]: magic, status, first_seen, last_seen, total_conn,
+ uptime_min, up_mod_days, last_nat, last_chg, distance, bad_sw, os_match_q,
+ os_name, os_flavor, http_name, http_flavor, link_type, language
+ ]]--
+
+ data = tostring(data)
+
+ -- API response must be 232 bytes long
+ if #data ~= 232 then
+ rspamd_logger.errx(task, 'malformed response from p0f on %s, %s bytes',
+ rule.socket, #data)
+
+ common.yield_result(task, rule, 'Malformed Response: ' .. rule.socket,
+ 0.0, 'fail')
+ return
+ end
+
+ local _, status, _, _, _, uptime_min, _, _, _, distance, _, _, os_name,
+ os_flavor, _, _, link_type, _ = trim(rspamd_util.unpack(
+ 'I4I4I4I4I4I4I4I4I4hbbc32c32c32c32c32c32', data))
+
+ if status ~= S.OK then
+ if status == S.BAD_QUERY then
+ rspamd_logger.errx(task, 'malformed p0f query on %s', rule.socket)
+ common.yield_result(task, rule, 'Malformed Query: ' .. rule.socket,
+ 0.0, 'fail')
+ end
+
+ return
+ end
+
+ local os_string = #os_name == 0 and 'unknown' or os_name .. ' ' .. os_flavor
+
+ task:get_mempool():set_variable('os_fingerprint', os_string, link_type,
+ uptime_min, distance)
+
+ if link_type and #link_type > 0 then
+ common.yield_result(task, rule, {
+ os_string,
+ 'link=' .. link_type,
+ 'distance=' .. distance },
+ 0.0)
+ else
+ common.yield_result(task, rule, {
+ os_string,
+ 'link=unknown',
+ 'distance=' .. distance },
+ 0.0)
+ end
+
+ return data
+ end
+
+ local function make_p0f_request()
+
+ local function check_p0f_cb(err, data)
+
+ local function redis_set_cb(redis_set_err)
+ if redis_set_err then
+ rspamd_logger.errx(task, 'redis received an error: %s', redis_set_err)
+ end
+ end
+
+ if err then
+ rspamd_logger.errx(task, 'p0f received an error: %s', err)
+ common.yield_result(task, rule, 'Error getting result: ' .. err,
+ 0.0, 'fail')
+ return
+ end
+
+ data = parse_p0f_response(data)
+
+ if rule.redis_params and data then
+ local key = rule.prefix .. ip:to_string()
+ local ret = lua_redis.redis_make_request(task,
+ rule.redis_params,
+ key,
+ true,
+ redis_set_cb,
+ 'SETEX',
+ { key, tostring(rule.expire), data }
+ )
+
+ if not ret then
+ rspamd_logger.warnx(task, 'error connecting to redis')
+ end
+ end
+ end
+
+ local query = rspamd_util.pack('I4 I1 c16', 0x50304601,
+ ip:get_version(), ip2bin(ip))
+
+ tcp.request({
+ host = rule.socket,
+ callback = check_p0f_cb,
+ data = { query },
+ task = task,
+ timeout = rule.timeout
+ })
+ end
+
+ local function redis_get_cb(err, data)
+ if err or type(data) ~= 'string' then
+ make_p0f_request()
+ else
+ parse_p0f_response(data)
+ end
+ end
+
+ local ret = nil
+ if rule.redis_params then
+ local key = rule.prefix .. ip:to_string()
+ ret = lua_redis.redis_make_request(task,
+ rule.redis_params,
+ key,
+ false,
+ redis_get_cb,
+ 'GET',
+ { key }
+ )
+ end
+
+ if not ret then
+ make_p0f_request() -- fallback to directly querying p0f
+ end
+end
+
+local function p0f_config(opts)
+ local p0f_conf = {
+ name = N,
+ timeout = 5,
+ symbol = 'P0F',
+ symbol_fail = 'P0F_FAIL',
+ patterns = {},
+ expire = 7200,
+ prefix = 'p0f',
+ detection_category = 'fingerprint',
+ message = '${SCANNER}: fingerprint matched: "${VIRUS}"'
+ }
+
+ p0f_conf = lua_util.override_defaults(p0f_conf, opts)
+ p0f_conf.patterns = common.create_regex_table(p0f_conf.patterns)
+
+ if not p0f_conf.log_prefix then
+ p0f_conf.log_prefix = p0f_conf.name
+ end
+
+ if not p0f_conf.socket then
+ rspamd_logger.errx(rspamd_config, 'no servers defined')
+ return nil
+ end
+
+ return p0f_conf
+end
+
+return {
+ type = { N, 'fingerprint', 'scanner' },
+ description = 'passive OS fingerprinter',
+ configure = p0f_config,
+ check = p0f_check,
+ name = N
+}
diff --git a/lualib/lua_scanners/pyzor.lua b/lualib/lua_scanners/pyzor.lua
new file mode 100644
index 0000000..75c1b4a
--- /dev/null
+++ b/lualib/lua_scanners/pyzor.lua
@@ -0,0 +1,206 @@
+--[[
+Copyright (c) 2021, defkev <defkev@gmail.com>
+Copyright (c) 2018, Carsten Rosenberg <c.rosenberg@heinlein-support.de>
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module pyzor
+-- This module contains pyzor access functions
+--]]
+
+local lua_util = require "lua_util"
+local tcp = require "rspamd_tcp"
+local upstream_list = require "rspamd_upstream_list"
+local rspamd_logger = require "rspamd_logger"
+local common = require "lua_scanners/common"
+
+local N = 'pyzor'
+local categories = { 'pyzor', 'bulk', 'hash', 'scanner' }
+
+local function pyzor_config(opts)
+
+ local pyzor_conf = {
+ text_part_min_words = 2,
+ default_port = 5953,
+ timeout = 15.0,
+ log_clean = false,
+ retransmits = 2,
+ detection_category = "hash",
+ cache_expire = 7200, -- expire redis in one hour
+ message = '${SCANNER}: Pyzor bulk message found: "${VIRUS}"',
+ default_score = 1.5,
+ action = false,
+ }
+
+ pyzor_conf = lua_util.override_defaults(pyzor_conf, opts)
+
+ if not pyzor_conf.prefix then
+ pyzor_conf.prefix = 'rext_' .. N .. '_'
+ end
+
+ if not pyzor_conf.log_prefix then
+ pyzor_conf.log_prefix = N .. ' (' .. pyzor_conf.detection_category .. ')'
+ end
+
+ if not pyzor_conf['servers'] then
+ rspamd_logger.errx(rspamd_config, 'no servers defined')
+
+ return nil
+ end
+
+ pyzor_conf['upstreams'] = upstream_list.create(rspamd_config,
+ pyzor_conf['servers'],
+ pyzor_conf.default_port)
+
+ if pyzor_conf['upstreams'] then
+ lua_util.add_debug_alias('external_services', N)
+ return pyzor_conf
+ end
+
+ rspamd_logger.errx(rspamd_config, 'cannot parse servers %s',
+ pyzor_conf['servers'])
+ return nil
+end
+
+local function pyzor_check(task, content, digest, rule)
+ local function pyzor_check_uncached ()
+ local upstream = rule.upstreams:get_upstream_round_robin()
+ local addr = upstream:get_addr()
+ local retransmits = rule.retransmits
+
+ local function pyzor_callback(err, data, conn)
+
+ if err then
+
+ -- retry with another upstream until retransmits exceeds
+ if retransmits > 0 then
+
+ retransmits = retransmits - 1
+
+ -- Select a different upstream!
+ upstream = rule.upstreams:get_upstream_round_robin()
+ addr = upstream:get_addr()
+
+ lua_util.debugm(N, task, '%s: retry IP: %s:%s err: %s',
+ rule.log_prefix, addr, addr:get_port(), err)
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule['timeout'],
+ shutdown = true,
+ data = content,
+ callback = pyzor_callback,
+ })
+ else
+ rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits exceed',
+ rule['symbol'], rule['type'])
+ task:insert_result(rule['symbol_fail'], 0.0,
+ 'failed to scan and retransmits exceed')
+ end
+ else
+ -- pyzor output is unicode (\x09 -> tab, \0a -> newline)
+ -- public.pyzor.org:24441 (200, 'OK') 21285091 206759
+ -- server:port Code Diag Count WL-Count
+ local str_data = tostring(data)
+ lua_util.debugm(N, task, '%s: returned data: %s',
+ rule.log_prefix, str_data)
+ -- If pyzor would return JSON this wouldn't be necessary
+ local resp = {}
+ for v in string.gmatch(str_data, '[^\t]+') do
+ table.insert(resp, v)
+ end
+ -- rspamd_logger.infox(task, 'resp: %s', resp)
+ if resp[2] ~= [[(200, 'OK')]] then
+ rspamd_logger.errx(task, "error parsing response: %s", str_data)
+ return
+ end
+
+ local whitelisted = tonumber(resp[4])
+ local reported = tonumber(resp[3])
+
+ --rspamd_logger.infox(task, "%s - count=%s wl=%s", addr:to_string(), reported, whitelisted)
+
+ --[[
+ Weight is Count - WL-Count of rule.default_score in percent, e.g.
+ SPAM:
+ Count: 100 (100%)
+ WL-Count: 1 (1%)
+ rule.default_score: 1
+ Weight: 0.99
+ HAM:
+ Count: 10 (100%)
+ WL-Count: 10 (100%)
+ rule.default_score: 1
+ Weight: 0
+ ]]
+ local weight = tonumber(string.format("%.2f",
+ rule.default_score * (reported - whitelisted) / (reported + whitelisted)))
+ local info = string.format("count=%d wl=%d", reported, whitelisted)
+ local threat_string = string.format("bl_%d_wl_%d",
+ reported, whitelisted)
+
+ if weight > 0 then
+ lua_util.debugm(N, task, '%s: returned result is spam - info: %s',
+ rule.log_prefix, info)
+ common.yield_result(task, rule, threat_string, weight)
+ common.save_cache(task, digest, rule, threat_string, weight)
+ else
+ if rule.log_clean then
+ rspamd_logger.infox(task, '%s: clean, returned result is ham - info: %s',
+ rule.log_prefix, info)
+ else
+ lua_util.debugm(N, task, '%s: returned result is ham - info: %s',
+ rule.log_prefix, info)
+ end
+ common.save_cache(task, digest, rule, 'OK', weight)
+ end
+
+ end
+ end
+
+ if digest == 'da39a3ee5e6b4b0d3255bfef95601890afd80709' then
+ rspamd_logger.infox(task, '%s: not checking default digest', rule.log_prefix)
+ return
+ end
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule.timeout,
+ shutdown = true,
+ data = content,
+ callback = pyzor_callback,
+ })
+ end
+ if common.condition_check_and_continue(task, content, rule, digest, pyzor_check_uncached) then
+ return
+ else
+ pyzor_check_uncached()
+ end
+end
+
+return {
+ type = categories,
+ description = 'pyzor bulk scanner',
+ configure = pyzor_config,
+ check = pyzor_check,
+ name = N
+}
diff --git a/lualib/lua_scanners/razor.lua b/lualib/lua_scanners/razor.lua
new file mode 100644
index 0000000..fcc0a8e
--- /dev/null
+++ b/lualib/lua_scanners/razor.lua
@@ -0,0 +1,181 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+Copyright (c) 2018, Carsten Rosenberg <c.rosenberg@heinlein-support.de>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module razor
+-- This module contains razor access functions
+--]]
+
+local lua_util = require "lua_util"
+local tcp = require "rspamd_tcp"
+local upstream_list = require "rspamd_upstream_list"
+local rspamd_logger = require "rspamd_logger"
+local common = require "lua_scanners/common"
+
+local N = 'razor'
+
+local function razor_config(opts)
+
+ local razor_conf = {
+ name = N,
+ default_port = 11342,
+ timeout = 5.0,
+ log_clean = false,
+ retransmits = 2,
+ cache_expire = 7200, -- expire redis in 2h
+ message = '${SCANNER}: spam message found: "${VIRUS}"',
+ detection_category = "hash",
+ default_score = 1,
+ action = false,
+ dynamic_scan = false,
+ symbol_fail = 'RAZOR_FAIL',
+ symbol = 'RAZOR',
+ }
+
+ razor_conf = lua_util.override_defaults(razor_conf, opts)
+
+ if not razor_conf.prefix then
+ razor_conf.prefix = 'rs_' .. razor_conf.name .. '_'
+ end
+
+ if not razor_conf.log_prefix then
+ razor_conf.log_prefix = razor_conf.name
+ end
+
+ if not razor_conf.servers and razor_conf.socket then
+ razor_conf.servers = razor_conf.socket
+ end
+
+ if not razor_conf.servers then
+ rspamd_logger.errx(rspamd_config, 'no servers defined')
+
+ return nil
+ end
+
+ razor_conf.upstreams = upstream_list.create(rspamd_config,
+ razor_conf.servers,
+ razor_conf.default_port)
+
+ if razor_conf.upstreams then
+ lua_util.add_debug_alias('external_services', razor_conf.name)
+ return razor_conf
+ end
+
+ rspamd_logger.errx(rspamd_config, 'cannot parse servers %s',
+ razor_conf['servers'])
+ return nil
+end
+
+local function razor_check(task, content, digest, rule)
+ local function razor_check_uncached ()
+ local upstream = rule.upstreams:get_upstream_round_robin()
+ local addr = upstream:get_addr()
+ local retransmits = rule.retransmits
+
+ local function razor_callback(err, data, conn)
+
+ local function razor_requery()
+ -- retry with another upstream until retransmits exceeds
+ if retransmits > 0 then
+
+ retransmits = retransmits - 1
+
+ lua_util.debugm(rule.name, task, '%s: Request Error: %s - retries left: %s',
+ rule.log_prefix, err, retransmits)
+
+ -- Select a different upstream!
+ upstream = rule.upstreams:get_upstream_round_robin()
+ addr = upstream:get_addr()
+
+ lua_util.debugm(rule.name, task, '%s: retry IP: %s:%s',
+ rule.log_prefix, addr, addr:get_port())
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule.timeout or 2.0,
+ shutdown = true,
+ data = content,
+ callback = razor_callback,
+ })
+ else
+ rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits ' ..
+ 'exceed', rule.log_prefix)
+ common.yield_result(task, rule, 'failed to scan and retransmits exceed', 0.0, 'fail')
+ end
+ end
+
+ if err then
+
+ razor_requery()
+
+ else
+ --[[
+ @todo: Razorsocket currently only returns ham or spam. When the wrapper is fixed we should add dynamic scores here.
+ Maybe check spamassassin implementation.
+
+ This implementation is based on https://github.com/cgt/rspamd-plugins
+ Thanks @cgt!
+ ]] --
+
+ local threat_string = tostring(data)
+ if threat_string == "spam" then
+ lua_util.debugm(N, task, '%s: returned result is spam', rule['symbol'], rule['type'])
+ common.yield_result(task, rule, threat_string, rule.default_score)
+ common.save_cache(task, digest, rule, threat_string, rule.default_score)
+ elseif threat_string == "ham" then
+ if rule.log_clean then
+ rspamd_logger.infox(task, '%s: returned result is ham', rule['symbol'], rule['type'])
+ else
+ lua_util.debugm(N, task, '%s: returned result is ham', rule['symbol'], rule['type'])
+ end
+ common.save_cache(task, digest, rule, 'OK', rule.default_score)
+ else
+ rspamd_logger.errx(task, "%s - unknown response from razorfy: %s", addr:to_string(), threat_string)
+ end
+
+ end
+ end
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule.timeout or 2.0,
+ shutdown = true,
+ data = content,
+ callback = razor_callback,
+ })
+ end
+
+ if common.condition_check_and_continue(task, content, rule, digest, razor_check_uncached) then
+ return
+ else
+ razor_check_uncached()
+ end
+end
+
+return {
+ type = { 'razor', 'spam', 'hash', 'scanner' },
+ description = 'razor bulk scanner',
+ configure = razor_config,
+ check = razor_check,
+ name = N
+}
diff --git a/lualib/lua_scanners/savapi.lua b/lualib/lua_scanners/savapi.lua
new file mode 100644
index 0000000..08f7b66
--- /dev/null
+++ b/lualib/lua_scanners/savapi.lua
@@ -0,0 +1,261 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module savapi
+-- This module contains avira savapi antivirus access functions
+--]]
+
+local lua_util = require "lua_util"
+local tcp = require "rspamd_tcp"
+local upstream_list = require "rspamd_upstream_list"
+local rspamd_util = require "rspamd_util"
+local rspamd_logger = require "rspamd_logger"
+local common = require "lua_scanners/common"
+
+local N = "savapi"
+
+local default_message = '${SCANNER}: virus found: "${VIRUS}"'
+
+local function savapi_config(opts)
+ local savapi_conf = {
+ name = N,
+ scan_mime_parts = true,
+ scan_text_mime = false,
+ scan_image_mime = false,
+ default_port = 4444, -- note: You must set ListenAddress in savapi.conf
+ product_id = 0,
+ log_clean = false,
+ timeout = 15.0, -- FIXME: this will break task_timeout!
+ retransmits = 1, -- FIXME: useless, for local files
+ cache_expire = 3600, -- expire redis in one hour
+ message = default_message,
+ detection_category = "virus",
+ tmpdir = '/tmp',
+ }
+
+ savapi_conf = lua_util.override_defaults(savapi_conf, opts)
+
+ if not savapi_conf.prefix then
+ savapi_conf.prefix = 'rs_' .. savapi_conf.name .. '_'
+ end
+
+ if not savapi_conf.log_prefix then
+ if savapi_conf.name:lower() == savapi_conf.type:lower() then
+ savapi_conf.log_prefix = savapi_conf.name
+ else
+ savapi_conf.log_prefix = savapi_conf.name .. ' (' .. savapi_conf.type .. ')'
+ end
+ end
+
+ if not savapi_conf['servers'] then
+ rspamd_logger.errx(rspamd_config, 'no servers defined')
+
+ return nil
+ end
+
+ savapi_conf['upstreams'] = upstream_list.create(rspamd_config,
+ savapi_conf['servers'],
+ savapi_conf.default_port)
+
+ if savapi_conf['upstreams'] then
+ lua_util.add_debug_alias('antivirus', savapi_conf.name)
+ return savapi_conf
+ end
+
+ rspamd_logger.errx(rspamd_config, 'cannot parse servers %s',
+ savapi_conf['servers'])
+ return nil
+end
+
+local function savapi_check(task, content, digest, rule)
+ local function savapi_check_uncached ()
+ local upstream = rule.upstreams:get_upstream_round_robin()
+ local addr = upstream:get_addr()
+ local retransmits = rule.retransmits
+ local fname = string.format('%s/%s.tmp',
+ rule.tmpdir, rspamd_util.random_hex(32))
+ local message_fd = rspamd_util.create_file(fname)
+
+ if not message_fd then
+ rspamd_logger.errx('cannot store file for savapi scan: %s', fname)
+ return
+ end
+
+ if type(content) == 'string' then
+ -- Create rspamd_text
+ local rspamd_text = require "rspamd_text"
+ content = rspamd_text.fromstring(content)
+ end
+ content:save_in_file(message_fd)
+
+ -- Ensure cleanup
+ task:get_mempool():add_destructor(function()
+ os.remove(fname)
+ rspamd_util.close_file(message_fd)
+ end)
+
+ local vnames = {}
+
+ -- Forward declaration for recursive calls
+ local savapi_scan1_cb
+
+ local function savapi_fin_cb(err, conn)
+ local vnames_reordered = {}
+ -- Swap table
+ for virus, _ in pairs(vnames) do
+ table.insert(vnames_reordered, virus)
+ end
+ lua_util.debugm(rule.name, task, "%s: number of virus names found %s", rule['type'], #vnames_reordered)
+ if #vnames_reordered > 0 then
+ local vname = {}
+ for _, virus in ipairs(vnames_reordered) do
+ table.insert(vname, virus)
+ end
+
+ common.yield_result(task, rule, vname)
+ common.save_cache(task, digest, rule, vname)
+ end
+ if conn then
+ conn:close()
+ end
+ end
+
+ local function savapi_scan2_cb(err, data, conn)
+ local result = tostring(data)
+ lua_util.debugm(rule.name, task, "%s: got reply: %s",
+ rule.type, result)
+
+ -- Terminal response - clean
+ if string.find(result, '200') or string.find(result, '210') then
+ if rule['log_clean'] then
+ rspamd_logger.infox(task, '%s: message or mime_part is clean', rule['type'])
+ end
+ common.save_cache(task, digest, rule, 'OK')
+ conn:add_write(savapi_fin_cb, 'QUIT\n')
+
+ -- Terminal response - infected
+ elseif string.find(result, '319') then
+ conn:add_write(savapi_fin_cb, 'QUIT\n')
+
+ -- Non-terminal response
+ elseif string.find(result, '310') then
+ local virus
+ virus = result:match "310.*<<<%s(.*)%s+;.*;.*"
+ if not virus then
+ virus = result:match "310%s(.*)%s+;.*;.*"
+ if not virus then
+ rspamd_logger.errx(task, "%s: virus result unparseable: %s",
+ rule['type'], result)
+ common.yield_result(task, rule, 'virus result unparseable: ' .. result, 0.0, 'fail')
+ return
+ end
+ end
+ -- Store unique virus names
+ vnames[virus] = 1
+ -- More content is expected
+ conn:add_write(savapi_scan1_cb, '\n')
+ end
+ end
+
+ savapi_scan1_cb = function(err, conn)
+ conn:add_read(savapi_scan2_cb, '\n')
+ end
+
+ -- 100 PRODUCT:xyz
+ local function savapi_greet2_cb(err, data, conn)
+ local result = tostring(data)
+ if string.find(result, '100 PRODUCT') then
+ lua_util.debugm(rule.name, task, "%s: scanning file: %s",
+ rule['type'], fname)
+ conn:add_write(savapi_scan1_cb, { string.format('SCAN %s\n',
+ fname) })
+ else
+ rspamd_logger.errx(task, '%s: invalid product id %s', rule['type'],
+ rule['product_id'])
+ common.yield_result(task, rule, 'invalid product id: ' .. result, 0.0, 'fail')
+ conn:add_write(savapi_fin_cb, 'QUIT\n')
+ end
+ end
+
+ local function savapi_greet1_cb(err, conn)
+ conn:add_read(savapi_greet2_cb, '\n')
+ end
+
+ local function savapi_callback_init(err, data, conn)
+ if err then
+
+ -- retry with another upstream until retransmits exceeds
+ if retransmits > 0 then
+
+ retransmits = retransmits - 1
+
+ -- Select a different upstream!
+ upstream = rule.upstreams:get_upstream_round_robin()
+ addr = upstream:get_addr()
+
+ lua_util.debugm(rule.name, task, '%s: error: %s; retry IP: %s; retries left: %s',
+ rule.log_prefix, err, addr, retransmits)
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule['timeout'],
+ callback = savapi_callback_init,
+ stop_pattern = { '\n' },
+ })
+ else
+ rspamd_logger.errx(task, '%s [%s]: failed to scan, maximum retransmits exceed', rule['symbol'], rule['type'])
+ common.yield_result(task, rule, 'failed to scan and retransmits exceed', 0.0, 'fail')
+ end
+ else
+ local result = tostring(data)
+
+ -- 100 SAVAPI:4.0 greeting
+ if string.find(result, '100') then
+ conn:add_write(savapi_greet1_cb, { string.format('SET PRODUCT %s\n', rule['product_id']) })
+ end
+ end
+ end
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule['timeout'],
+ callback = savapi_callback_init,
+ stop_pattern = { '\n' },
+ })
+ end
+
+ if common.condition_check_and_continue(task, content, rule, digest, savapi_check_uncached) then
+ return
+ else
+ savapi_check_uncached()
+ end
+
+end
+
+return {
+ type = 'antivirus',
+ description = 'savapi avira antivirus',
+ configure = savapi_config,
+ check = savapi_check,
+ name = N
+}
diff --git a/lualib/lua_scanners/sophos.lua b/lualib/lua_scanners/sophos.lua
new file mode 100644
index 0000000..d9b64f1
--- /dev/null
+++ b/lualib/lua_scanners/sophos.lua
@@ -0,0 +1,192 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module savapi
+-- This module contains avira savapi antivirus access functions
+--]]
+
+local lua_util = require "lua_util"
+local tcp = require "rspamd_tcp"
+local upstream_list = require "rspamd_upstream_list"
+local rspamd_logger = require "rspamd_logger"
+local common = require "lua_scanners/common"
+
+local N = "sophos"
+
+local default_message = '${SCANNER}: virus found: "${VIRUS}"'
+
+local function sophos_config(opts)
+ local sophos_conf = {
+ name = N,
+ scan_mime_parts = true,
+ scan_text_mime = false,
+ scan_image_mime = false,
+ default_port = 4010,
+ timeout = 15.0,
+ log_clean = false,
+ retransmits = 2,
+ cache_expire = 3600, -- expire redis in one hour
+ message = default_message,
+ detection_category = "virus",
+ }
+
+ sophos_conf = lua_util.override_defaults(sophos_conf, opts)
+
+ if not sophos_conf.prefix then
+ sophos_conf.prefix = 'rs_' .. sophos_conf.name .. '_'
+ end
+
+ if not sophos_conf.log_prefix then
+ if sophos_conf.name:lower() == sophos_conf.type:lower() then
+ sophos_conf.log_prefix = sophos_conf.name
+ else
+ sophos_conf.log_prefix = sophos_conf.name .. ' (' .. sophos_conf.type .. ')'
+ end
+ end
+
+ if not sophos_conf['servers'] then
+ rspamd_logger.errx(rspamd_config, 'no servers defined')
+
+ return nil
+ end
+
+ sophos_conf['upstreams'] = upstream_list.create(rspamd_config,
+ sophos_conf['servers'],
+ sophos_conf.default_port)
+
+ if sophos_conf['upstreams'] then
+ lua_util.add_debug_alias('antivirus', sophos_conf.name)
+ return sophos_conf
+ end
+
+ rspamd_logger.errx(rspamd_config, 'cannot parse servers %s',
+ sophos_conf['servers'])
+ return nil
+end
+
+local function sophos_check(task, content, digest, rule, maybe_part)
+ local function sophos_check_uncached ()
+ local upstream = rule.upstreams:get_upstream_round_robin()
+ local addr = upstream:get_addr()
+ local retransmits = rule.retransmits
+ local protocol = 'SSSP/1.0\n'
+ local streamsize = string.format('SCANDATA %d\n', #content)
+ local bye = 'BYE\n'
+
+ local function sophos_callback(err, data, conn)
+
+ if err then
+ -- retry with another upstream until retransmits exceeds
+ if retransmits > 0 then
+
+ retransmits = retransmits - 1
+
+ -- Select a different upstream!
+ upstream = rule.upstreams:get_upstream_round_robin()
+ addr = upstream:get_addr()
+
+ lua_util.debugm(rule.name, task, '%s: error: %s; retry IP: %s; retries left: %s',
+ rule.log_prefix, err, addr, retransmits)
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule['timeout'],
+ callback = sophos_callback,
+ data = { protocol, streamsize, content, bye }
+ })
+ else
+ rspamd_logger.errx(task, '%s [%s]: failed to scan, maximum retransmits exceed', rule['symbol'], rule['type'])
+ common.yield_result(task, rule, 'failed to scan and retransmits exceed',
+ 0.0, 'fail', maybe_part)
+ end
+ else
+ data = tostring(data)
+ lua_util.debugm(rule.name, task,
+ '%s [%s]: got reply: %s', rule['symbol'], rule['type'], data)
+ local vname = string.match(data, 'VIRUS (%S+) ')
+ local cached
+ if vname then
+ common.yield_result(task, rule, vname, 1.0, nil, maybe_part)
+ common.save_cache(task, digest, rule, vname, 1.0, maybe_part)
+ else
+ if string.find(data, 'DONE OK') then
+ if rule['log_clean'] then
+ rspamd_logger.infox(task, '%s: message or mime_part is clean', rule.log_prefix)
+ else
+ lua_util.debugm(rule.name, task,
+ '%s: message or mime_part is clean', rule.log_prefix)
+ end
+ cached = 'OK'
+ -- not finished - continue
+ elseif string.find(data, 'ACC') or string.find(data, 'OK SSSP') then
+ conn:add_read(sophos_callback)
+ elseif string.find(data, 'FAIL 0212') then
+ rspamd_logger.warnx(task, 'Message is encrypted (FAIL 0212): %s', data)
+ common.yield_result(task, rule, 'SAVDI: Message is encrypted (FAIL 0212)',
+ 0.0, 'encrypted', maybe_part)
+ cached = 'ENCRYPTED'
+ elseif string.find(data, 'REJ 4') then
+ rspamd_logger.warnx(task, 'Message is oversized (REJ 4): %s', data)
+ common.yield_result(task, rule, 'SAVDI: Message oversized (REJ 4)',
+ 0.0, 'fail', maybe_part)
+ -- explicitly set REJ1 message when SAVDIreports a protocol error
+ elseif string.find(data, 'REJ 1') then
+ rspamd_logger.errx(task, 'SAVDI (Protocol error (REJ 1)): %s', data)
+ common.yield_result(task, rule, 'SAVDI: Protocol error (REJ 1)',
+ 0.0, 'fail', maybe_part)
+ else
+ rspamd_logger.errx(task, 'unhandled response: %s', data)
+ common.yield_result(task, rule, 'unhandled response: ' .. data,
+ 0.0, 'fail', maybe_part)
+ end
+ if cached then
+ common.save_cache(task, digest, rule, cached, 1.0, maybe_part)
+ end
+ end
+ end
+ end
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule['timeout'],
+ callback = sophos_callback,
+ data = { protocol, streamsize, content, bye }
+ })
+ end
+
+ if common.condition_check_and_continue(task, content, rule, digest,
+ sophos_check_uncached, maybe_part) then
+ return
+ else
+ sophos_check_uncached()
+ end
+
+end
+
+return {
+ type = 'antivirus',
+ description = 'sophos antivirus',
+ configure = sophos_config,
+ check = sophos_check,
+ name = N
+}
diff --git a/lualib/lua_scanners/spamassassin.lua b/lualib/lua_scanners/spamassassin.lua
new file mode 100644
index 0000000..f425924
--- /dev/null
+++ b/lualib/lua_scanners/spamassassin.lua
@@ -0,0 +1,213 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+Copyright (c) 2019, Carsten Rosenberg <c.rosenberg@heinlein-support.de>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module spamassassin
+-- This module contains spamd access functions.
+--]]
+
+local lua_util = require "lua_util"
+local tcp = require "rspamd_tcp"
+local upstream_list = require "rspamd_upstream_list"
+local rspamd_logger = require "rspamd_logger"
+local common = require "lua_scanners/common"
+
+local N = 'spamassassin'
+
+local function spamassassin_config(opts)
+
+ local spamassassin_conf = {
+ N = N,
+ scan_mime_parts = false,
+ scan_text_mime = false,
+ scan_image_mime = false,
+ default_port = 783,
+ timeout = 15.0,
+ log_clean = false,
+ retransmits = 2,
+ cache_expire = 3600, -- expire redis in one hour
+ symbol = "SPAMD",
+ message = '${SCANNER}: Spamassassin bulk message found: "${VIRUS}"',
+ detection_category = "spam",
+ default_score = 1,
+ action = false,
+ extended = false,
+ symbol_type = 'postfilter',
+ dynamic_scan = true,
+ }
+
+ spamassassin_conf = lua_util.override_defaults(spamassassin_conf, opts)
+
+ if not spamassassin_conf.prefix then
+ spamassassin_conf.prefix = 'rs_' .. spamassassin_conf.name .. '_'
+ end
+
+ if not spamassassin_conf.log_prefix then
+ if spamassassin_conf.name:lower() == spamassassin_conf.type:lower() then
+ spamassassin_conf.log_prefix = spamassassin_conf.name
+ else
+ spamassassin_conf.log_prefix = spamassassin_conf.name .. ' (' .. spamassassin_conf.type .. ')'
+ end
+ end
+
+ if not spamassassin_conf.servers then
+ rspamd_logger.errx(rspamd_config, 'no servers defined')
+
+ return nil
+ end
+
+ spamassassin_conf.upstreams = upstream_list.create(rspamd_config,
+ spamassassin_conf.servers,
+ spamassassin_conf.default_port)
+
+ if spamassassin_conf.upstreams then
+ lua_util.add_debug_alias('external_services', spamassassin_conf.N)
+ return spamassassin_conf
+ end
+
+ rspamd_logger.errx(rspamd_config, 'cannot parse servers %s',
+ spamassassin_conf.servers)
+ return nil
+end
+
+local function spamassassin_check(task, content, digest, rule)
+ local function spamassassin_check_uncached ()
+ local upstream = rule.upstreams:get_upstream_round_robin()
+ local addr = upstream:get_addr()
+ local retransmits = rule.retransmits
+
+ -- Build the spamd query
+ -- https://svn.apache.org/repos/asf/spamassassin/trunk/spamd/PROTOCOL
+ local request_data = {
+ "HEADERS SPAMC/1.5\r\n",
+ "User: root\r\n",
+ "Content-length: " .. #content .. "\r\n",
+ "\r\n",
+ content,
+ }
+
+ local function spamassassin_callback(err, data)
+
+ local function spamassassin_requery(error)
+
+ -- retry with another upstream until retransmits exceeds
+ if retransmits > 0 then
+
+ retransmits = retransmits - 1
+
+ lua_util.debugm(rule.N, task, '%s: Request Error: %s - retries left: %s',
+ rule.log_prefix, error, retransmits)
+
+ -- Select a different upstream!
+ upstream = rule.upstreams:get_upstream_round_robin()
+ addr = upstream:get_addr()
+
+ lua_util.debugm(rule.N, task, '%s: retry IP: %s:%s',
+ rule.log_prefix, addr, addr:get_port())
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule['timeout'],
+ data = request_data,
+ callback = spamassassin_callback,
+ })
+ else
+ rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits ' ..
+ 'exceed - err: %s', rule.log_prefix, error)
+ common.yield_result(task, rule, 'failed to scan and retransmits exceed: ' .. error, 0.0, 'fail')
+ end
+ end
+
+ if err then
+
+ spamassassin_requery(err)
+
+ else
+ --lua_util.debugm(rule.N, task, '%s: returned result: %s', rule.log_prefix, data)
+
+ --[[
+ patterns tested against Spamassassin 3.4.6
+
+ X-Spam-Status: No, score=1.1 required=5.0 tests=HTML_MESSAGE,MIME_HTML_ONLY,
+ TVD_RCVD_SPACE_BRACKET,UNPARSEABLE_RELAY autolearn=no
+ autolearn_force=no version=3.4.6
+ ]] --
+ local header = string.gsub(tostring(data), "[\r\n]+[\t ]", " ")
+ --lua_util.debugm(rule.N, task, '%s: returned header: %s', rule.log_prefix, header)
+
+ local symbols = ""
+ local spam_score = 0
+ for s in header:gmatch("[^\r\n]+") do
+ if string.find(s, 'X%-Spam%-Status: %S+, score') then
+ local pattern_symbols = "X%-Spam%-Status: %S+, score%=([%-%d%.]+)%s.*tests%=(.*,?)(%s*%S+)%sautolearn.*"
+ spam_score = string.gsub(s, pattern_symbols, "%1")
+ symbols = string.gsub(s, pattern_symbols, "%2%3")
+ symbols = string.gsub(symbols, "%s", "")
+ end
+ end
+
+ lua_util.debugm(rule.N, task, '%s: spam_score: %s, symbols: %s, int spam_score: |%s|, type spam_score: |%s|',
+ rule.log_prefix, spam_score, symbols, tonumber(spam_score), type(spam_score))
+
+ if tonumber(spam_score) > 0 and #symbols > 0 and symbols ~= "none" then
+
+ if rule.extended == false then
+ common.yield_result(task, rule, symbols, spam_score)
+ common.save_cache(task, digest, rule, symbols, spam_score)
+ else
+ local symbols_table = lua_util.str_split(symbols, ",")
+ lua_util.debugm(rule.N, task, '%s: returned symbols as table: %s', rule.log_prefix, symbols_table)
+
+ common.yield_result(task, rule, symbols_table, spam_score)
+ common.save_cache(task, digest, rule, symbols_table, spam_score)
+ end
+ else
+ common.save_cache(task, digest, rule, 'OK')
+ common.log_clean(task, rule, 'no spam detected - spam score: ' .. spam_score .. ', symbols: ' .. symbols)
+ end
+ end
+ end
+
+ tcp.request({
+ task = task,
+ host = addr:to_string(),
+ port = addr:get_port(),
+ upstream = upstream,
+ timeout = rule['timeout'],
+ data = request_data,
+ callback = spamassassin_callback,
+ })
+ end
+
+ if common.condition_check_and_continue(task, content, rule, digest, spamassassin_check_uncached) then
+ return
+ else
+ spamassassin_check_uncached()
+ end
+
+end
+
+return {
+ type = { N, 'spam', 'scanner' },
+ description = 'spamassassin spam scanner',
+ configure = spamassassin_config,
+ check = spamassassin_check,
+ name = N
+}
diff --git a/lualib/lua_scanners/vadesecure.lua b/lualib/lua_scanners/vadesecure.lua
new file mode 100644
index 0000000..826573a
--- /dev/null
+++ b/lualib/lua_scanners/vadesecure.lua
@@ -0,0 +1,351 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module vadesecure
+-- This module contains Vadesecure Filterd interface
+--]]
+
+local lua_util = require "lua_util"
+local http = require "rspamd_http"
+local upstream_list = require "rspamd_upstream_list"
+local rspamd_logger = require "rspamd_logger"
+local ucl = require "ucl"
+local common = require "lua_scanners/common"
+
+local N = 'vadesecure'
+
+local function vade_config(opts)
+
+ local vade_conf = {
+ name = N,
+ default_port = 23808,
+ url = '/api/v1/scan',
+ use_https = false,
+ timeout = 5.0,
+ log_clean = false,
+ retransmits = 1,
+ cache_expire = 7200, -- expire redis in 2h
+ message = '${SCANNER}: spam message found: "${VIRUS}"',
+ detection_category = "hash",
+ default_score = 1,
+ action = false,
+ log_spamcause = true,
+ symbol_fail = 'VADE_FAIL',
+ symbol = 'VADE_CHECK',
+ settings_outbound = nil, -- Set when there is a settings id for outbound messages
+ symbols = {
+ clean = {
+ symbol = 'VADE_CLEAN',
+ score = -0.5,
+ description = 'VadeSecure decided message to be clean'
+ },
+ spam = {
+ high = {
+ symbol = 'VADE_SPAM_HIGH',
+ score = 8.0,
+ description = 'VadeSecure decided message to be clearly spam'
+ },
+ medium = {
+ symbol = 'VADE_SPAM_MEDIUM',
+ score = 5.0,
+ description = 'VadeSecure decided message to be highly likely spam'
+ },
+ low = {
+ symbol = 'VADE_SPAM_LOW',
+ score = 2.0,
+ description = 'VadeSecure decided message to be likely spam'
+ },
+ },
+ malware = {
+ symbol = 'VADE_MALWARE',
+ score = 8.0,
+ description = 'VadeSecure decided message to be malware'
+ },
+ scam = {
+ symbol = 'VADE_SCAM',
+ score = 7.0,
+ description = 'VadeSecure decided message to be scam'
+ },
+ phishing = {
+ symbol = 'VADE_PHISHING',
+ score = 8.0,
+ description = 'VadeSecure decided message to be phishing'
+ },
+ commercial = {
+ symbol = 'VADE_COMMERCIAL',
+ score = 0.0,
+ description = 'VadeSecure decided message to be commercial message'
+ },
+ community = {
+ symbol = 'VADE_COMMUNITY',
+ score = 0.0,
+ description = 'VadeSecure decided message to be community message'
+ },
+ transactional = {
+ symbol = 'VADE_TRANSACTIONAL',
+ score = 0.0,
+ description = 'VadeSecure decided message to be transactional message'
+ },
+ suspect = {
+ symbol = 'VADE_SUSPECT',
+ score = 3.0,
+ description = 'VadeSecure decided message to be suspicious message'
+ },
+ bounce = {
+ symbol = 'VADE_BOUNCE',
+ score = 0.0,
+ description = 'VadeSecure decided message to be bounce message'
+ },
+ other = 'VADE_OTHER',
+ }
+ }
+
+ vade_conf = lua_util.override_defaults(vade_conf, opts)
+
+ if not vade_conf.prefix then
+ vade_conf.prefix = 'rs_' .. vade_conf.name .. '_'
+ end
+
+ if not vade_conf.log_prefix then
+ if vade_conf.name:lower() == vade_conf.type:lower() then
+ vade_conf.log_prefix = vade_conf.name
+ else
+ vade_conf.log_prefix = vade_conf.name .. ' (' .. vade_conf.type .. ')'
+ end
+ end
+
+ if not vade_conf.servers and vade_conf.socket then
+ vade_conf.servers = vade_conf.socket
+ end
+
+ if not vade_conf.servers then
+ rspamd_logger.errx(rspamd_config, 'no servers defined')
+
+ return nil
+ end
+
+ vade_conf.upstreams = upstream_list.create(rspamd_config,
+ vade_conf.servers,
+ vade_conf.default_port)
+
+ if vade_conf.upstreams then
+ lua_util.add_debug_alias('external_services', vade_conf.name)
+ return vade_conf
+ end
+
+ rspamd_logger.errx(rspamd_config, 'cannot parse servers %s',
+ vade_conf['servers'])
+ return nil
+end
+
+local function vade_check(task, content, digest, rule, maybe_part)
+ local function vade_check_uncached()
+ local function vade_url(addr)
+ local url
+ if rule.use_https then
+ url = string.format('https://%s:%d%s', tostring(addr),
+ rule.default_port, rule.url)
+ else
+ url = string.format('http://%s:%d%s', tostring(addr),
+ rule.default_port, rule.url)
+ end
+
+ return url
+ end
+
+ local upstream = rule.upstreams:get_upstream_round_robin()
+ local addr = upstream:get_addr()
+ local retransmits = rule.retransmits
+
+ local url = vade_url(addr)
+ local hdrs = {}
+
+ local helo = task:get_helo()
+ if helo then
+ hdrs['X-Helo'] = helo
+ end
+ local mail_from = task:get_from('smtp') or {}
+ if mail_from[1] and #mail_from[1].addr > 1 then
+ hdrs['X-Mailfrom'] = mail_from[1].addr
+ end
+
+ local rcpt_to = task:get_recipients('smtp')
+ if rcpt_to then
+ hdrs['X-Rcptto'] = {}
+ for _, r in ipairs(rcpt_to) do
+ table.insert(hdrs['X-Rcptto'], r.addr)
+ end
+ end
+
+ local fip = task:get_from_ip()
+ if fip and fip:is_valid() then
+ hdrs['X-Inet'] = tostring(fip)
+ end
+
+ if rule.settings_outbound then
+ local settings_id = task:get_settings_id()
+
+ if settings_id then
+ local lua_settings = require "lua_settings"
+ -- Convert to string
+ settings_id = lua_settings.settings_by_id(settings_id)
+
+ if settings_id then
+ settings_id = settings_id.name or ''
+
+ if settings_id == rule.settings_outbound then
+ lua_util.debugm(rule.name, task, '%s settings has matched outbound',
+ settings_id)
+ hdrs['X-Params'] = 'mode=smtpout'
+ end
+ end
+ end
+ end
+
+ local request_data = {
+ task = task,
+ url = url,
+ body = task:get_content(),
+ headers = hdrs,
+ timeout = rule.timeout,
+ }
+
+ local function vade_callback(http_err, code, body, headers)
+
+ local function vade_requery()
+ -- set current upstream to fail because an error occurred
+ upstream:fail()
+
+ -- retry with another upstream until retransmits exceeds
+ if retransmits > 0 then
+
+ retransmits = retransmits - 1
+
+ lua_util.debugm(rule.name, task,
+ '%s: Request Error: %s - retries left: %s',
+ rule.log_prefix, http_err, retransmits)
+
+ -- Select a different upstream!
+ upstream = rule.upstreams:get_upstream_round_robin()
+ addr = upstream:get_addr()
+ url = vade_url(addr)
+
+ lua_util.debugm(rule.name, task, '%s: retry IP: %s:%s',
+ rule.log_prefix, addr, addr:get_port())
+ request_data.url = url
+
+ http.request(request_data)
+ else
+ rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits ' ..
+ 'exceed', rule.log_prefix)
+ task:insert_result(rule['symbol_fail'], 0.0, 'failed to scan and ' ..
+ 'retransmits exceed')
+ end
+ end
+
+ if http_err then
+ vade_requery()
+ else
+ -- Parse the response
+ if upstream then
+ upstream:ok()
+ end
+ if code ~= 200 then
+ rspamd_logger.errx(task, 'invalid HTTP code: %s, body: %s, headers: %s', code, body, headers)
+ task:insert_result(rule.symbol_fail, 1.0, 'Bad HTTP code: ' .. code)
+ return
+ end
+ local parser = ucl.parser()
+ local ret, err = parser:parse_string(body)
+ if not ret then
+ rspamd_logger.errx(task, 'vade: bad response body (raw): %s', body)
+ task:insert_result(rule.symbol_fail, 1.0, 'Parser error: ' .. err)
+ return
+ end
+ local obj = parser:get_object()
+ local verdict = obj.verdict
+ if not verdict then
+ rspamd_logger.errx(task, 'vade: bad response JSON (no verdict): %s', obj)
+ task:insert_result(rule.symbol_fail, 1.0, 'No verdict/unknown verdict')
+ return
+ end
+ local vparts = lua_util.str_split(verdict, ":")
+ verdict = table.remove(vparts, 1) or verdict
+
+ local sym = rule.symbols[verdict]
+ if not sym then
+ sym = rule.symbols.other
+ end
+
+ if not sym.symbol then
+ -- Subcategory match
+ local lvl = 'low'
+ if vparts and vparts[1] then
+ lvl = vparts[1]
+ end
+
+ if sym[lvl] then
+ sym = sym[lvl]
+ else
+ sym = rule.symbols.other
+ end
+ end
+
+ local opts = {}
+ if obj.score then
+ table.insert(opts, 'score=' .. obj.score)
+ end
+ if obj.elapsed then
+ table.insert(opts, 'elapsed=' .. obj.elapsed)
+ end
+
+ if rule.log_spamcause and obj.spamcause then
+ rspamd_logger.infox(task, 'vadesecure verdict="%s", score=%s, spamcause="%s", message-id="%s"',
+ verdict, obj.score, obj.spamcause, task:get_message_id())
+ else
+ lua_util.debugm(rule.name, task, 'vadesecure returned verdict="%s", score=%s, spamcause="%s"',
+ verdict, obj.score, obj.spamcause)
+ end
+
+ if #vparts > 0 then
+ table.insert(opts, 'verdict=' .. verdict .. ';' .. table.concat(vparts, ':'))
+ end
+
+ task:insert_result(sym.symbol, 1.0, opts)
+ end
+ end
+
+ request_data.callback = vade_callback
+ http.request(request_data)
+ end
+
+ if common.condition_check_and_continue(task, content, rule, digest,
+ vade_check_uncached, maybe_part) then
+ return
+ else
+ vade_check_uncached()
+ end
+
+end
+
+return {
+ type = { 'vadesecure', 'scanner' },
+ description = 'VadeSecure Filterd interface',
+ configure = vade_config,
+ check = vade_check,
+ name = N
+}
diff --git a/lualib/lua_scanners/virustotal.lua b/lualib/lua_scanners/virustotal.lua
new file mode 100644
index 0000000..d937c41
--- /dev/null
+++ b/lualib/lua_scanners/virustotal.lua
@@ -0,0 +1,214 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module virustotal
+-- This module contains Virustotal integration support
+-- https://www.virustotal.com/
+--]]
+
+local lua_util = require "lua_util"
+local http = require "rspamd_http"
+local rspamd_cryptobox_hash = require "rspamd_cryptobox_hash"
+local rspamd_logger = require "rspamd_logger"
+local common = require "lua_scanners/common"
+
+local N = 'virustotal'
+
+local function virustotal_config(opts)
+
+ local default_conf = {
+ name = N,
+ url = 'https://www.virustotal.com/vtapi/v2/file',
+ timeout = 5.0,
+ log_clean = false,
+ retransmits = 1,
+ cache_expire = 7200, -- expire redis in 2h
+ message = '${SCANNER}: spam message found: "${VIRUS}"',
+ detection_category = "virus",
+ default_score = 1,
+ action = false,
+ scan_mime_parts = true,
+ scan_text_mime = false,
+ scan_image_mime = false,
+ apikey = nil, -- Required to set by user
+ -- Specific for virustotal
+ minimum_engines = 3, -- Minimum required to get scored
+ full_score_engines = 7, -- After this number we set max score
+ }
+
+ default_conf = lua_util.override_defaults(default_conf, opts)
+
+ if not default_conf.prefix then
+ default_conf.prefix = 'rs_' .. default_conf.name .. '_'
+ end
+
+ if not default_conf.log_prefix then
+ if default_conf.name:lower() == default_conf.type:lower() then
+ default_conf.log_prefix = default_conf.name
+ else
+ default_conf.log_prefix = default_conf.name .. ' (' .. default_conf.type .. ')'
+ end
+ end
+
+ if not default_conf.apikey then
+ rspamd_logger.errx(rspamd_config, 'no apikey defined for virustotal, disable checks')
+
+ return nil
+ end
+
+ lua_util.add_debug_alias('external_services', default_conf.name)
+ return default_conf
+end
+
+local function virustotal_check(task, content, digest, rule, maybe_part)
+ local function virustotal_check_uncached()
+ local function make_url(hash)
+ return string.format('%s/report?apikey=%s&resource=%s',
+ rule.url, rule.apikey, hash)
+ end
+
+ local hash = rspamd_cryptobox_hash.create_specific('md5')
+ hash:update(content)
+ hash = hash:hex()
+
+ local url = make_url(hash)
+ lua_util.debugm(N, task, "send request %s", url)
+ local request_data = {
+ task = task,
+ url = url,
+ timeout = rule.timeout,
+ }
+
+ local function vt_http_callback(http_err, code, body, headers)
+ if http_err then
+ rspamd_logger.errx(task, 'HTTP error: %s, body: %s, headers: %s', http_err, body, headers)
+ else
+ local cached
+ local dyn_score
+ -- Parse the response
+ if code ~= 200 then
+ if code == 404 then
+ cached = 'OK'
+ if rule['log_clean'] then
+ rspamd_logger.infox(task, '%s: hash %s clean (not found)',
+ rule.log_prefix, hash)
+ else
+ lua_util.debugm(rule.name, task, '%s: hash %s clean (not found)',
+ rule.log_prefix, hash)
+ end
+ elseif code == 204 then
+ -- Request rate limit exceeded
+ rspamd_logger.infox(task, 'virustotal request rate limit exceeded')
+ task:insert_result(rule.symbol_fail, 1.0, 'rate limit exceeded')
+ return
+ else
+ rspamd_logger.errx(task, 'invalid HTTP code: %s, body: %s, headers: %s', code, body, headers)
+ task:insert_result(rule.symbol_fail, 1.0, 'Bad HTTP code: ' .. code)
+ return
+ end
+ else
+ local ucl = require "ucl"
+ local parser = ucl.parser()
+ local res, json_err = parser:parse_string(body)
+
+ lua_util.debugm(rule.name, task, '%s: got reply data: "%s"',
+ rule.log_prefix, body)
+
+ if res then
+ local obj = parser:get_object()
+ if not obj.positives or type(obj.positives) ~= 'number' then
+ if obj.response_code then
+ if obj.response_code == 0 then
+ cached = 'OK'
+ if rule['log_clean'] then
+ rspamd_logger.infox(task, '%s: hash %s clean (not found)',
+ rule.log_prefix, hash)
+ else
+ lua_util.debugm(rule.name, task, '%s: hash %s clean (not found)',
+ rule.log_prefix, hash)
+ end
+ else
+ rspamd_logger.errx(task, 'invalid JSON reply: %s, body: %s, headers: %s',
+ 'bad response code: ' .. tostring(obj.response_code), body, headers)
+ task:insert_result(rule.symbol_fail, 1.0, 'Bad JSON reply: no `positives` element')
+ return
+ end
+ else
+ rspamd_logger.errx(task, 'invalid JSON reply: %s, body: %s, headers: %s',
+ 'no response_code', body, headers)
+ task:insert_result(rule.symbol_fail, 1.0, 'Bad JSON reply: no `positives` element')
+ return
+ end
+ else
+ if obj.positives < rule.minimum_engines then
+ lua_util.debugm(rule.name, task, '%s: hash %s has not enough hits: %s where %s is min',
+ rule.log_prefix, obj.positives, rule.minimum_engines)
+ -- TODO: add proper hashing!
+ cached = 'OK'
+ else
+ if obj.positives > rule.full_score_engines then
+ dyn_score = 1.0
+ else
+ local norm_pos = obj.positives - rule.minimum_engines
+ dyn_score = norm_pos / (rule.full_score_engines - rule.minimum_engines)
+ end
+
+ if dyn_score < 0 or dyn_score > 1 then
+ dyn_score = 1.0
+ end
+ local sopt = string.format("%s:%s/%s",
+ hash, obj.positives, obj.total)
+ common.yield_result(task, rule, sopt, dyn_score, nil, maybe_part)
+ cached = sopt
+ end
+ end
+ else
+ -- not res
+ rspamd_logger.errx(task, 'invalid JSON reply: %s, body: %s, headers: %s',
+ json_err, body, headers)
+ task:insert_result(rule.symbol_fail, 1.0, 'Bad JSON reply: ' .. json_err)
+ return
+ end
+ end
+
+ if cached then
+ common.save_cache(task, digest, rule, cached, dyn_score, maybe_part)
+ end
+ end
+ end
+
+ request_data.callback = vt_http_callback
+ http.request(request_data)
+ end
+
+ if common.condition_check_and_continue(task, content, rule, digest,
+ virustotal_check_uncached) then
+ return
+ else
+
+ virustotal_check_uncached()
+ end
+
+end
+
+return {
+ type = 'antivirus',
+ description = 'Virustotal integration',
+ configure = virustotal_config,
+ check = virustotal_check,
+ name = N
+}
diff --git a/lualib/lua_selectors/common.lua b/lualib/lua_selectors/common.lua
new file mode 100644
index 0000000..7b2372d
--- /dev/null
+++ b/lualib/lua_selectors/common.lua
@@ -0,0 +1,95 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local ts = require("tableshape").types
+local exports = {}
+local cr_hash = require 'rspamd_cryptobox_hash'
+
+local blake2b_key = cr_hash.create_specific('blake2'):update('rspamd'):bin()
+
+local function digest_schema()
+ return { ts.one_of { 'hex', 'base32', 'bleach32', 'rbase32', 'base64' }:is_optional(),
+ ts.one_of { 'blake2', 'sha256', 'sha1', 'sha512', 'md5' }:is_optional() }
+end
+
+exports.digest_schema = digest_schema
+
+local function create_raw_digest(data, args)
+ local ht = args[2] or 'blake2'
+
+ local h
+
+ if ht == 'blake2' then
+ -- Hack to be compatible with various 'get_digest' methods
+ h = cr_hash.create_keyed(blake2b_key):update(data)
+ else
+ h = cr_hash.create_specific(ht):update(data)
+ end
+
+ return h
+end
+
+local function encode_digest(h, args)
+ local encoding = args[1] or 'hex'
+
+ local s
+ if encoding == 'hex' then
+ s = h:hex()
+ elseif encoding == 'base32' then
+ s = h:base32()
+ elseif encoding == 'bleach32' then
+ s = h:base32('bleach')
+ elseif encoding == 'rbase32' then
+ s = h:base32('rfc')
+ elseif encoding == 'base64' then
+ s = h:base64()
+ end
+
+ return s
+end
+
+local function create_digest(data, args)
+ local h = create_raw_digest(data, args)
+ return encode_digest(h, args)
+end
+
+local function get_cached_or_raw_digest(task, idx, mime_part, args)
+ if #args == 0 then
+ -- Optimise as we already have this hash in the API
+ return mime_part:get_digest()
+ end
+
+ local ht = args[2] or 'blake2'
+ local cache_key = 'mp_digest_' .. ht .. tostring(idx)
+
+ local cached = task:cache_get(cache_key)
+
+ if cached then
+ return encode_digest(cached, args)
+ end
+
+ local h = create_raw_digest(mime_part:get_content('raw_parsed'), args)
+ task:cache_set(cache_key, h)
+
+ return encode_digest(h, args)
+end
+
+exports.create_digest = create_digest
+exports.create_raw_digest = create_raw_digest
+exports.get_cached_or_raw_digest = get_cached_or_raw_digest
+exports.encode_digest = encode_digest
+
+return exports \ No newline at end of file
diff --git a/lualib/lua_selectors/extractors.lua b/lualib/lua_selectors/extractors.lua
new file mode 100644
index 0000000..81dfa9d
--- /dev/null
+++ b/lualib/lua_selectors/extractors.lua
@@ -0,0 +1,565 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local fun = require 'fun'
+local meta_functions = require "lua_meta"
+local lua_util = require "lua_util"
+local rspamd_url = require "rspamd_url"
+local common = require "lua_selectors/common"
+local ts = require("tableshape").types
+local maps = require "lua_selectors/maps"
+local E = {}
+local M = "selectors"
+
+local url_flags_ts = ts.array_of(ts.one_of(lua_util.keys(rspamd_url.flags))):is_optional()
+
+local function gen_exclude_flags_filter(exclude_flags)
+ return function(u)
+ local got_flags = u:get_flags()
+ for _, flag in ipairs(exclude_flags) do
+ if got_flags[flag] then
+ return false
+ end
+ end
+ return true
+ end
+end
+
+local extractors = {
+ -- Plain id function
+ ['id'] = {
+ ['get_value'] = function(_, args)
+ if args[1] then
+ return args[1], 'string'
+ end
+
+ return '', 'string'
+ end,
+ ['description'] = [[Return value from function's argument or an empty string,
+For example, `id('Something')` returns a string 'Something']],
+ ['args_schema'] = { ts.string:is_optional() }
+ },
+ -- Similar but for making lists
+ ['list'] = {
+ ['get_value'] = function(_, args)
+ if args[1] then
+ return fun.map(tostring, args), 'string_list'
+ end
+
+ return {}, 'string_list'
+ end,
+ ['description'] = [[Return a list from function's arguments or an empty list,
+For example, `list('foo', 'bar')` returns a list {'foo', 'bar'}]],
+ },
+ -- Get source IP address
+ ['ip'] = {
+ ['get_value'] = function(task)
+ local ip = task:get_ip()
+ if ip and ip:is_valid() then
+ return ip, 'userdata'
+ end
+ return nil
+ end,
+ ['description'] = [[Get source IP address]],
+ },
+ -- Get MIME from
+ ['from'] = {
+ ['get_value'] = function(task, args)
+ local from
+ if type(args) == 'table' then
+ from = task:get_from(args)
+ else
+ from = task:get_from(0)
+ end
+ if ((from or E)[1] or E).addr then
+ return from[1], 'table'
+ end
+ return nil
+ end,
+ ['description'] = [[Get MIME or SMTP from (e.g. `from('smtp')` or `from('mime')`,
+uses any type by default)]],
+ },
+ ['rcpts'] = {
+ ['get_value'] = function(task, args)
+ local rcpts
+ if type(args) == 'table' then
+ rcpts = task:get_recipients(args)
+ else
+ rcpts = task:get_recipients(0)
+ end
+ if ((rcpts or E)[1] or E).addr then
+ return rcpts, 'table_list'
+ end
+ return nil
+ end,
+ ['description'] = [[Get MIME or SMTP rcpts (e.g. `rcpts('smtp')` or `rcpts('mime')`,
+uses any type by default)]],
+ },
+ -- Get country (ASN module must be executed first)
+ ['country'] = {
+ ['get_value'] = function(task)
+ local country = task:get_mempool():get_variable('country')
+ if not country then
+ return nil
+ else
+ return country, 'string'
+ end
+ end,
+ ['description'] = [[Get country (ASN module must be executed first)]],
+ },
+ -- Get ASN number
+ ['asn'] = {
+ ['type'] = 'string',
+ ['get_value'] = function(task)
+ local asn = task:get_mempool():get_variable('asn')
+ if not asn then
+ return nil
+ else
+ return asn, 'string'
+ end
+ end,
+ ['description'] = [[Get AS number (ASN module must be executed first)]],
+ },
+ -- Get authenticated username
+ ['user'] = {
+ ['get_value'] = function(task)
+ local auser = task:get_user()
+ if not auser then
+ return nil
+ else
+ return auser, 'string'
+ end
+ end,
+ ['description'] = 'Get authenticated user name',
+ },
+ -- Get principal recipient
+ ['to'] = {
+ ['get_value'] = function(task)
+ return task:get_principal_recipient(), 'string'
+ end,
+ ['description'] = 'Get principal recipient',
+ },
+ -- Get content digest
+ ['digest'] = {
+ ['get_value'] = function(task)
+ return task:get_digest(), 'string'
+ end,
+ ['description'] = 'Get content digest',
+ },
+ -- Get list of all attachments digests
+ ['attachments'] = {
+ ['get_value'] = function(task, args)
+ local parts = task:get_parts() or E
+ local digests = {}
+ for i, p in ipairs(parts) do
+ if p:is_attachment() then
+ table.insert(digests, common.get_cached_or_raw_digest(task, i, p, args))
+ end
+ end
+
+ if #digests > 0 then
+ return digests, 'string_list'
+ end
+
+ return nil
+ end,
+ ['description'] = [[Get list of all attachments digests.
+The first optional argument is encoding (`hex`, `base32` (and forms `bleach32`, `rbase32`), `base64`),
+the second optional argument is optional hash type (`blake2`, `sha256`, `sha1`, `sha512`, `md5`)]],
+ ['args_schema'] = common.digest_schema()
+
+ },
+ -- Get all attachments files
+ ['files'] = {
+ ['get_value'] = function(task)
+ local parts = task:get_parts() or E
+ local files = {}
+
+ for _, p in ipairs(parts) do
+ local fname = p:get_filename()
+ if fname then
+ table.insert(files, fname)
+ end
+ end
+
+ if #files > 0 then
+ return files, 'string_list'
+ end
+
+ return nil
+ end,
+ ['description'] = 'Get all attachments files',
+ },
+ -- Get languages for text parts
+ ['languages'] = {
+ ['get_value'] = function(task)
+ local text_parts = task:get_text_parts() or E
+ local languages = {}
+
+ for _, p in ipairs(text_parts) do
+ local lang = p:get_language()
+ if lang then
+ table.insert(languages, lang)
+ end
+ end
+
+ if #languages > 0 then
+ return languages, 'string_list'
+ end
+
+ return nil
+ end,
+ ['description'] = 'Get languages for text parts',
+ },
+ -- Get helo value
+ ['helo'] = {
+ ['get_value'] = function(task)
+ return task:get_helo(), 'string'
+ end,
+ ['description'] = 'Get helo value',
+ },
+ -- Get header with the name that is expected as an argument. Returns list of
+ -- headers with this name
+ ['header'] = {
+ ['get_value'] = function(task, args)
+ local strong = false
+ if args[2] then
+ if args[2]:match('strong') then
+ strong = true
+ end
+
+ if args[2]:match('full') then
+ return task:get_header_full(args[1], strong), 'table_list'
+ end
+
+ return task:get_header(args[1], strong), 'string'
+ else
+ return task:get_header(args[1]), 'string'
+ end
+ end,
+ ['description'] = [[Get header with the name that is expected as an argument.
+The optional second argument accepts list of flags:
+ - `full`: returns all headers with this name with all data (like task:get_header_full())
+ - `strong`: use case sensitive match when matching header's name]],
+ ['args_schema'] = { ts.string,
+ (ts.pattern("strong") + ts.pattern("full")):is_optional() }
+ },
+ -- Get list of received headers (returns list of tables)
+ ['received'] = {
+ ['get_value'] = function(task, args)
+ local rh = task:get_received_headers()
+ if not rh[1] then
+ return nil
+ end
+ if args[1] then
+ return fun.map(function(r)
+ return r[args[1]]
+ end, rh), 'string_list'
+ end
+
+ return rh, 'table_list'
+ end,
+ ['description'] = [[Get list of received headers.
+If no arguments specified, returns list of tables. Otherwise, selects a specific element,
+e.g. `by_hostname`]],
+ },
+ -- Get all urls
+ ['urls'] = {
+ ['get_value'] = function(task, args)
+ local urls = task:get_urls()
+ if not urls[1] then
+ return nil
+ end
+ if args[1] then
+ return fun.map(function(r)
+ return r[args[1]](r)
+ end, urls), 'string_list'
+ end
+ return urls, 'userdata_list'
+ end,
+ ['description'] = [[Get list of all urls.
+If no arguments specified, returns list of url objects. Otherwise, calls a specific method,
+e.g. `get_tld`]],
+ },
+ -- Get specific urls
+ ['specific_urls'] = {
+ ['get_value'] = function(task, args)
+ local params = args[1] or {}
+ params.task = task
+ params.no_cache = true
+ if params.exclude_flags then
+ params.filter = gen_exclude_flags_filter(params.exclude_flags)
+ end
+ local urls = lua_util.extract_specific_urls(params)
+ if not urls[1] then
+ return nil
+ end
+ return urls, 'userdata_list'
+ end,
+ ['description'] = [[Get most specific urls. Arguments are equal to the Lua API function]],
+ ['args_schema'] = { ts.shape {
+ limit = ts.number + ts.string / tonumber,
+ esld_limit = (ts.number + ts.string / tonumber):is_optional(),
+ exclude_flags = url_flags_ts,
+ flags = url_flags_ts,
+ flags_mode = ts.one_of { 'explicit' }:is_optional(),
+ prefix = ts.string:is_optional(),
+ need_content = (ts.boolean + ts.string / lua_util.toboolean):is_optional(),
+ need_emails = (ts.boolean + ts.string / lua_util.toboolean):is_optional(),
+ need_images = (ts.boolean + ts.string / lua_util.toboolean):is_optional(),
+ ignore_redirected = (ts.boolean + ts.string / lua_util.toboolean):is_optional(),
+ } }
+ },
+ ['specific_urls_filter_map'] = {
+ ['get_value'] = function(task, args)
+ local map = maps[args[1]]
+ if not map then
+ lua_util.debugm(M, "invalid/unknown map: %s", args[1])
+ end
+ local params = args[2] or {}
+ params.task = task
+ params.no_cache = true
+ if params.exclude_flags then
+ params.filter = gen_exclude_flags_filter(params.exclude_flags)
+ end
+ local urls = lua_util.extract_specific_urls(params)
+ if not urls[1] then
+ return nil
+ end
+ return fun.filter(function(u)
+ return map:get_key(tostring(u))
+ end, urls), 'userdata_list'
+ end,
+ ['description'] = [[Get most specific urls, filtered by some map. Arguments are equal to the Lua API function]],
+ ['args_schema'] = { ts.string, ts.shape {
+ limit = ts.number + ts.string / tonumber,
+ esld_limit = (ts.number + ts.string / tonumber):is_optional(),
+ exclude_flags = url_flags_ts,
+ flags = url_flags_ts,
+ flags_mode = ts.one_of { 'explicit' }:is_optional(),
+ prefix = ts.string:is_optional(),
+ need_content = (ts.boolean + ts.string / lua_util.toboolean):is_optional(),
+ need_emails = (ts.boolean + ts.string / lua_util.toboolean):is_optional(),
+ need_images = (ts.boolean + ts.string / lua_util.toboolean):is_optional(),
+ ignore_redirected = (ts.boolean + ts.string / lua_util.toboolean):is_optional(),
+ } }
+ },
+ -- URLs filtered by flags
+ ['urls_filtered'] = {
+ ['get_value'] = function(task, args)
+ local urls = task:get_urls_filtered(args[1], args[2])
+ if not urls[1] then
+ return nil
+ end
+ return urls, 'userdata_list'
+ end,
+ ['description'] = [[Get list of all urls filtered by flags_include/exclude
+(see rspamd_task:get_urls_filtered for description)]],
+ ['args_schema'] = { ts.array_of {
+ url_flags_ts:is_optional(), url_flags_ts:is_optional()
+ } }
+ },
+ -- Get all emails
+ ['emails'] = {
+ ['get_value'] = function(task, args)
+ local urls = task:get_emails()
+ if not urls[1] then
+ return nil
+ end
+ if args[1] then
+ return fun.map(function(r)
+ return r[args[1]](r)
+ end, urls), 'string_list'
+ end
+ return urls, 'userdata_list'
+ end,
+ ['description'] = [[Get list of all emails.
+If no arguments specified, returns list of url objects. Otherwise, calls a specific method,
+e.g. `get_user`]],
+ },
+ -- Get specific pool var. The first argument must be variable name,
+ -- the second argument is optional and defines the type (string by default)
+ ['pool_var'] = {
+ ['get_value'] = function(task, args)
+ local type = args[2] or 'string'
+ return task:get_mempool():get_variable(args[1], type), (type)
+ end,
+ ['description'] = [[Get specific pool var. The first argument must be variable name,
+the second argument is optional and defines the type (string by default)]],
+ ['args_schema'] = { ts.string, ts.string:is_optional() }
+ },
+ -- Get value of specific key from task cache
+ ['task_cache'] = {
+ ['get_value'] = function(task, args)
+ local val = task:cache_get(args[1])
+ if not val then
+ return
+ end
+ if type(val) == 'table' then
+ if not val[1] then
+ return
+ end
+ return val, 'string_list'
+ end
+ return val, 'string'
+ end,
+ ['description'] = [[Get value of specific key from task cache. The first argument must be
+the key name]],
+ ['args_schema'] = { ts.string }
+ },
+ -- Get specific HTTP request header. The first argument must be header name.
+ ['request_header'] = {
+ ['get_value'] = function(task, args)
+ local hdr = task:get_request_header(args[1])
+ if hdr then
+ return hdr, 'string'
+ end
+
+ return nil
+ end,
+ ['description'] = [[Get specific HTTP request header.
+The first argument must be header name.]],
+ ['args_schema'] = { ts.string }
+ },
+ -- Get task date, optionally formatted
+ ['time'] = {
+ ['get_value'] = function(task, args)
+ local what = args[1] or 'message'
+ local dt = task:get_date { format = what, gmt = true }
+
+ if dt then
+ if args[2] then
+ -- Should be in format !xxx, as dt is in GMT
+ return os.date(args[2], dt), 'string'
+ end
+
+ return tostring(dt), 'string'
+ end
+
+ return nil
+ end,
+ ['description'] = [[Get task timestamp. The first argument is type:
+ - `connect`: connection timestamp (default)
+ - `message`: timestamp as defined by `Date` header
+
+ The second argument is optional time format, see [os.date](http://pgl.yoyo.org/luai/i/os.date) description]],
+ ['args_schema'] = { ts.one_of { 'connect', 'message' }:is_optional(),
+ ts.string:is_optional() }
+ },
+ -- Get text words from a message
+ ['words'] = {
+ ['get_value'] = function(task, args)
+ local how = args[1] or 'stem'
+ local tp = task:get_text_parts()
+
+ if tp then
+ local rtype = 'string_list'
+ if how == 'full' then
+ rtype = 'table_list'
+ end
+
+ return lua_util.flatten(
+ fun.map(function(p)
+ return p:get_words(how)
+ end, tp)), rtype
+ end
+
+ return nil
+ end,
+ ['description'] = [[Get words from text parts
+ - `stem`: stemmed words (default)
+ - `raw`: raw words
+ - `norm`: normalised words (lowercased)
+ - `full`: list of tables
+ ]],
+ ['args_schema'] = { ts.one_of { 'stem', 'raw', 'norm', 'full' }:is_optional() },
+ },
+ -- Get queue ID
+ ['queueid'] = {
+ ['get_value'] = function(task)
+ local queueid = task:get_queue_id()
+ if queueid then
+ return queueid, 'string'
+ end
+ return nil
+ end,
+ ['description'] = [[Get queue ID]],
+ },
+ -- Get ID of the task being processed
+ ['uid'] = {
+ ['get_value'] = function(task)
+ local uid = task:get_uid()
+ if uid then
+ return uid, 'string'
+ end
+ return nil
+ end,
+ ['description'] = [[Get ID of the task being processed]],
+ },
+ -- Get message ID of the task being processed
+ ['messageid'] = {
+ ['get_value'] = function(task)
+ local mid = task:get_message_id()
+ if mid then
+ return mid, 'string'
+ end
+ return nil
+ end,
+ ['description'] = [[Get message ID]],
+ },
+ -- Get specific symbol
+ ['symbol'] = {
+ ['get_value'] = function(task, args)
+ local symbol = task:get_symbol(args[1], args[2])
+ if symbol then
+ return symbol[1], 'table'
+ end
+ end,
+ ['description'] = 'Get specific symbol. The first argument must be the symbol name. ' ..
+ 'The second argument is an optional shadow result name. ' ..
+ 'Returns the symbol table. See task:get_symbol()',
+ ['args_schema'] = { ts.string, ts.string:is_optional() }
+ },
+ -- Get full scan result
+ ['scan_result'] = {
+ ['get_value'] = function(task, args)
+ local res = task:get_metric_result(args[1])
+ if res then
+ return res, 'table'
+ end
+ end,
+ ['description'] = 'Get full scan result (either default or shadow if shadow result name is specified)' ..
+ 'Returns the result table. See task:get_metric_result()',
+ ['args_schema'] = { ts.string:is_optional() }
+ },
+ -- Get list of metatokens as strings
+ ['metatokens'] = {
+ ['get_value'] = function(task)
+ local tokens = meta_functions.gen_metatokens(task)
+ if not tokens[1] then
+ return nil
+ end
+ local res = {}
+ for _, t in ipairs(tokens) do
+ table.insert(res, tostring(t))
+ end
+ return res, 'string_list'
+ end,
+ ['description'] = 'Get metatokens for a message as strings',
+ },
+}
+
+return extractors
diff --git a/lualib/lua_selectors/init.lua b/lualib/lua_selectors/init.lua
new file mode 100644
index 0000000..5fcdb38
--- /dev/null
+++ b/lualib/lua_selectors/init.lua
@@ -0,0 +1,668 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+-- This module contains 'selectors' implementation: code to extract data
+-- from Rspamd tasks and compose those together
+--
+-- Read more at https://rspamd.com/doc/configuration/selectors.html
+
+--[[[
+-- @module lua_selectors
+-- This module contains 'selectors' implementation: code to extract data
+-- from Rspamd tasks and compose those together.
+-- Typical selector looks like this: header(User).lower.substring(1, 2):ip
+--]]
+
+local exports = {
+ maps = require "lua_selectors/maps"
+}
+
+local logger = require 'rspamd_logger'
+local fun = require 'fun'
+local lua_util = require "lua_util"
+local M = "selectors"
+local rspamd_text = require "rspamd_text"
+local unpack_function = table.unpack or unpack
+local E = {}
+
+local extractors = require "lua_selectors/extractors"
+local transform_function = require "lua_selectors/transforms"
+
+local text_cookie = rspamd_text.cookie
+
+local function pure_type(ltype)
+ return ltype:match('^(.*)_list$')
+end
+
+local function implicit_tostring(t, ud_or_table)
+ if t == 'table' then
+ -- Table (very special)
+ if ud_or_table.value then
+ return ud_or_table.value, 'string'
+ elseif ud_or_table.addr then
+ return ud_or_table.addr, 'string'
+ end
+
+ return logger.slog("%s", ud_or_table), 'string'
+ elseif (t == 'string' or t == 'text') and type(ud_or_table) == 'userdata' then
+ if ud_or_table.cookie and ud_or_table.cookie == text_cookie then
+ -- Preserve opaque
+ return ud_or_table, 'string'
+ else
+ return tostring(ud_or_table), 'string'
+ end
+ elseif t ~= 'nil' then
+ return tostring(ud_or_table), 'string'
+ end
+
+ return nil
+end
+
+local function process_selector(task, sel)
+ local function allowed_type(t)
+ if t == 'string' or t == 'string_list' then
+ return true
+ end
+
+ return false
+ end
+
+ local function list_type(t)
+ return pure_type(t)
+ end
+
+ local input, etype = sel.selector.get_value(task, sel.selector.args)
+
+ if not input then
+ lua_util.debugm(M, task, 'no value extracted for %s', sel.selector.name)
+ return nil
+ end
+
+ lua_util.debugm(M, task, 'extracted %s, type %s',
+ sel.selector.name, etype)
+
+ local pipe = sel.processor_pipe or E
+ local first_elt = pipe[1]
+
+ if first_elt and (first_elt.method or
+ fun.any(function(t)
+ return t == 'userdata' or t == 'table'
+ end, first_elt.types)) then
+ -- Explicit conversion
+ local meth = first_elt
+
+ if meth.types[etype] then
+ lua_util.debugm(M, task, 'apply method `%s` to %s',
+ meth.name, etype)
+ input, etype = meth.process(input, etype, meth.args)
+ else
+ local pt = pure_type(etype)
+
+ if meth.types[pt] then
+ lua_util.debugm(M, task, 'map method `%s` to list of %s',
+ meth.name, pt)
+ -- Map method to a list of inputs, excluding empty elements
+ -- We need to fold it down here to get a proper type resolution
+ input = fun.totable(fun.filter(function(map_elt, _)
+ return map_elt
+ end,
+ fun.map(function(list_elt)
+ local ret, ty = meth.process(list_elt, pt, meth.args)
+ if ret then
+ etype = ty
+ end
+ return ret
+ end, input)))
+ if input and etype then
+ etype = etype .. "_list"
+ else
+ input = nil
+ end
+ end
+ end
+ -- Remove method from the pipeline
+ pipe = fun.drop_n(1, pipe)
+ elseif etype:match('^userdata') or etype:match('^table') then
+ -- Implicit conversion
+ local pt = pure_type(etype)
+
+ if not pt then
+ lua_util.debugm(M, task, 'apply implicit conversion %s->string', etype)
+ input = implicit_tostring(etype, input)
+ etype = 'string'
+ else
+ lua_util.debugm(M, task, 'apply implicit map %s->string', pt)
+ input = fun.filter(function(map_elt)
+ return map_elt
+ end,
+ fun.map(function(list_elt)
+ local ret = implicit_tostring(pt, list_elt)
+ return ret
+ end, input))
+ etype = 'string_list'
+ end
+ else
+ lua_util.debugm(M, task, 'avoid implicit conversion as the transformer accepts complex input')
+ end
+
+ -- Now we fold elements using left fold
+ local function fold_function(acc, x)
+ if acc == nil or acc[1] == nil then
+ lua_util.debugm(M, task, 'do not apply %s, accumulator is nil', x.name)
+ return nil
+ end
+
+ local value = acc[1]
+ local t = acc[2]
+
+ if not x.types[t] then
+ local pt = pure_type(t)
+
+ if pt and x.types['list'] then
+ -- Generic list processor
+ lua_util.debugm(M, task, 'apply list function `%s` to %s', x.name, t)
+ return { x.process(value, t, x.args) }
+ elseif pt and x.map_type and x.types[pt] then
+ local map_type = x.map_type .. '_list'
+ lua_util.debugm(M, task, 'map `%s` to list of %s resulting %s',
+ x.name, pt, map_type)
+ -- Apply map, filtering empty values
+ return {
+ fun.filter(function(map_elt)
+ return map_elt
+ end,
+ fun.map(function(list_elt)
+ if not list_elt then
+ return nil
+ end
+ local ret, _ = x.process(list_elt, pt, x.args)
+ return ret
+ end, value)),
+ map_type -- Returned type
+ }
+ end
+ logger.errx(task, 'cannot apply transform %s for type %s', x.name, t)
+ return nil
+ end
+
+ lua_util.debugm(M, task, 'apply %s to %s', x.name, t)
+ return { x.process(value, t, x.args) }
+ end
+
+ local res = fun.foldl(fold_function,
+ { input, etype },
+ pipe)
+
+ if not res or not res[1] then
+ return nil
+ end -- Pipeline failed
+
+ if not allowed_type(res[2]) then
+ -- Search for implicit conversion
+ local pt = pure_type(res[2])
+
+ if pt then
+ lua_util.debugm(M, task, 'apply implicit map %s->string_list', pt)
+ res[1] = fun.map(function(e)
+ return implicit_tostring(pt, e)
+ end, res[1])
+ res[2] = 'string_list'
+ else
+ res[1] = implicit_tostring(res[2], res[1])
+ res[2] = 'string'
+ end
+ end
+
+ if list_type(res[2]) then
+ -- Convert to table as it might have a functional form
+ res[1] = fun.totable(res[1])
+ end
+
+ lua_util.debugm(M, task, 'final selector type: %s, value: %s', res[2], res[1])
+
+ return res[1]
+end
+
+local function make_grammar()
+ local l = require "lpeg"
+ local spc = l.S(" \t\n") ^ 0
+ local cont = l.R("\128\191") -- continuation byte
+ local utf8_high = l.R("\194\223") * cont
+ + l.R("\224\239") * cont * cont
+ + l.R("\240\244") * cont * cont * cont
+ local atom_start = (l.R("az") + l.R("AZ") + l.R("09") + utf8_high + l.S "-") ^ 1
+ local atom_end = (l.R("az") + l.R("AZ") + l.R("09") + l.S "-_" + utf8_high) ^ 1
+ local atom_mid = (1 - l.S("'\r\n\f\\,)(}{= " .. '"')) ^ 1
+ local atom_argument = l.C(atom_start * atom_mid ^ 0 * atom_end ^ 0) -- We allow more characters for the arguments
+ local atom = l.C(atom_start * atom_end ^ 0) -- We are more strict about selector names itself
+ local singlequoted_string = l.P "'" * l.C(((1 - l.S "'\r\n\f\\") + (l.P '\\' * 1)) ^ 0) * "'"
+ local doublequoted_string = l.P '"' * l.C(((1 - l.S '"\r\n\f\\') + (l.P '\\' * 1)) ^ 0) * '"'
+ local argument = atom_argument + singlequoted_string + doublequoted_string
+ local dot = l.P(".")
+ local semicolon = l.P(":")
+ local obrace = "(" * spc
+ local tbl_obrace = "{" * spc
+ local eqsign = spc * "=" * spc
+ local tbl_ebrace = spc * "}"
+ local ebrace = spc * ")"
+ local comma = spc * "," * spc
+ local sel_separator = spc * l.S ";*" * spc
+
+ return l.P {
+ "LIST";
+ LIST = l.Ct(l.V("EXPR")) * (sel_separator * l.Ct(l.V("EXPR"))) ^ 0,
+ EXPR = l.V("FUNCTION") * (semicolon * l.V("METHOD")) ^ -1 * (dot * l.V("PROCESSOR")) ^ 0,
+ PROCESSOR = l.Ct(atom * spc * (obrace * l.V("ARG_LIST") * ebrace) ^ 0),
+ FUNCTION = l.Ct(atom * spc * (obrace * l.V("ARG_LIST") * ebrace) ^ 0),
+ METHOD = l.Ct(atom / function(e)
+ return '__' .. e
+ end * spc * (obrace * l.V("ARG_LIST") * ebrace) ^ 0),
+ ARG_LIST = l.Ct((l.V("ARG") * comma ^ 0) ^ 0),
+ ARG = l.Cf(tbl_obrace * l.V("NAMED_ARG") * tbl_ebrace, rawset) + argument + l.V("LIST_ARGS"),
+ NAMED_ARG = (l.Ct("") * l.Cg(argument * eqsign * (argument + l.V("LIST_ARGS")) * comma ^ 0) ^ 0),
+ LIST_ARGS = l.Ct(tbl_obrace * l.V("LIST_ARG") * tbl_ebrace),
+ LIST_ARG = l.Cg(argument * comma ^ 0) ^ 0,
+ }
+end
+
+local parser = make_grammar()
+
+--[[[
+-- @function lua_selectors.parse_selector(cfg, str)
+--]]
+exports.parse_selector = function(cfg, str)
+ local parsed = { parser:match(str) }
+ local output = {}
+
+ if not parsed or not parsed[1] then
+ return nil
+ end
+
+ local function check_args(name, schema, args)
+ if schema then
+ if getmetatable(schema) then
+ -- Schema covers all arguments
+ local res, err = schema:transform(args)
+ if not res then
+ logger.errx(rspamd_config, 'invalid arguments for %s: %s', name, err)
+ return false
+ else
+ for i, elt in ipairs(res) do
+ args[i] = elt
+ end
+ end
+ else
+ for i, selt in ipairs(schema) do
+ local res, err = selt:transform(args[i])
+
+ if err then
+ logger.errx(rspamd_config, 'invalid arguments for %s: argument number: %s, error: %s', name, i, err)
+ return false
+ else
+ args[i] = res
+ end
+ end
+ end
+ end
+
+ return true
+ end
+
+ -- Output AST format is the following:
+ -- table of individual selectors
+ -- each selector: list of functions
+ -- each function: function name + optional list of arguments
+ for _, sel in ipairs(parsed) do
+ local res = {
+ selector = {},
+ processor_pipe = {},
+ }
+
+ local selector_tbl = sel[1]
+ if not selector_tbl then
+ logger.errx(cfg, 'no selector represented')
+ return nil
+ end
+ if not extractors[selector_tbl[1]] then
+ logger.errx(cfg, 'selector %s is unknown', selector_tbl[1])
+ return nil
+ end
+
+ res.selector = lua_util.shallowcopy(extractors[selector_tbl[1]])
+ res.selector.name = selector_tbl[1]
+ res.selector.args = selector_tbl[2] or E
+
+ if not check_args(res.selector.name,
+ res.selector.args_schema,
+ res.selector.args) then
+ return nil
+ end
+
+ lua_util.debugm(M, cfg, 'processed selector %s, args: %s',
+ res.selector.name, res.selector.args)
+
+ local pipeline_error = false
+ -- Now process processors pipe
+ fun.each(function(proc_tbl)
+ local proc_name = proc_tbl[1]
+
+ if proc_name:match('^__') then
+ -- Special case - method
+ local method_name = proc_name:match('^__(.*)$')
+ -- Check array indexing...
+ if tonumber(method_name) then
+ method_name = tonumber(method_name)
+ end
+ local processor = {
+ name = tostring(method_name),
+ method = true,
+ args = proc_tbl[2] or E,
+ types = {
+ userdata = true,
+ table = true,
+ string = true,
+ },
+ map_type = 'string',
+ process = function(inp, t, args)
+ local ret
+ if t == 'table' then
+ -- Plain table field
+ ret = inp[method_name]
+ else
+ -- We call method unpacking arguments and dropping all but the first result returned
+ ret = (inp[method_name](inp, unpack_function(args or E)))
+ end
+
+ local ret_type = type(ret)
+
+ if ret_type == 'nil' then
+ return nil
+ end
+ -- Now apply types heuristic
+ if ret_type == 'string' then
+ return ret, 'string'
+ elseif ret_type == 'table' then
+ -- TODO: we need to ensure that 1) table is numeric 2) table has merely strings
+ return ret, 'string_list'
+ else
+ return implicit_tostring(ret_type, ret)
+ end
+ end,
+ }
+ lua_util.debugm(M, cfg, 'attached method %s to selector %s, args: %s',
+ proc_name, res.selector.name, processor.args)
+ table.insert(res.processor_pipe, processor)
+ else
+
+ if not transform_function[proc_name] then
+ logger.errx(cfg, 'processor %s is unknown', proc_name)
+ pipeline_error = proc_name
+ return nil
+ end
+ local processor = lua_util.shallowcopy(transform_function[proc_name])
+ processor.name = proc_name
+ processor.args = proc_tbl[2] or E
+
+ if not check_args(processor.name, processor.args_schema, processor.args) then
+ pipeline_error = 'args schema for ' .. proc_name
+ return nil
+ end
+
+ lua_util.debugm(M, cfg, 'attached processor %s to selector %s, args: %s',
+ proc_name, res.selector.name, processor.args)
+ table.insert(res.processor_pipe, processor)
+ end
+ end, fun.tail(sel))
+
+ if pipeline_error then
+ logger.errx(cfg, 'unknown or invalid processor used: "%s", exiting', pipeline_error)
+ return nil
+ end
+
+ table.insert(output, res)
+ end
+
+ return output
+end
+
+--[[[
+-- @function lua_selectors.register_extractor(cfg, name, selector)
+--]]
+exports.register_extractor = function(cfg, name, selector)
+ if selector.get_value then
+ if extractors[name] then
+ logger.warnx(cfg, 'redefining selector %s', name)
+ end
+ extractors[name] = selector
+
+ return true
+ end
+
+ logger.errx(cfg, 'bad selector %s', name)
+ return false
+end
+
+--[[[
+-- @function lua_selectors.register_transform(cfg, name, transform)
+--]]
+exports.register_transform = function(cfg, name, transform)
+ if transform.process and transform.types then
+ if transform_function[name] then
+ logger.warnx(cfg, 'redefining transform function %s', name)
+ end
+ transform_function[name] = transform
+
+ return true
+ end
+
+ logger.errx(cfg, 'bad transform function %s', name)
+ return false
+end
+
+--[[[
+-- @function lua_selectors.process_selectors(task, selectors_pipe)
+--]]
+exports.process_selectors = function(task, selectors_pipe)
+ local ret = {}
+
+ for _, sel in ipairs(selectors_pipe) do
+ local r = process_selector(task, sel)
+
+ -- If any element is nil, then the whole selector is nil
+ if not r then
+ return nil
+ end
+ table.insert(ret, r)
+ end
+
+ return ret
+end
+
+--[[[
+-- @function lua_selectors.combine_selectors(task, selectors, delimiter)
+--]]
+exports.combine_selectors = function(_, selectors, delimiter)
+ if not delimiter then
+ delimiter = ''
+ end
+
+ if not selectors then
+ return nil
+ end
+
+ local have_tables, have_userdata
+
+ for _, s in ipairs(selectors) do
+ if type(s) == 'table' then
+ have_tables = true
+ elseif type(s) == 'userdata' then
+ have_userdata = true
+ end
+ end
+
+ if not have_tables then
+ if not have_userdata then
+ return table.concat(selectors, delimiter)
+ else
+ return rspamd_text.fromtable(selectors, delimiter)
+ end
+ else
+ -- We need to do a spill on each table selector and make a cortesian product
+ -- e.g. s:tbl:s -> s:telt1:s + s:telt2:s ...
+ local tbl = {}
+ local res = {}
+
+ for i, s in ipairs(selectors) do
+ if type(s) == 'string' then
+ rawset(tbl, i, fun.duplicate(s))
+ elseif type(s) == 'userdata' then
+ rawset(tbl, i, fun.duplicate(tostring(s)))
+ else
+ -- Raw table
+ rawset(tbl, i, fun.map(tostring, s))
+ end
+ end
+
+ fun.each(function(...)
+ table.insert(res, table.concat({ ... }, delimiter))
+ end, fun.zip(lua_util.unpack(tbl)))
+
+ return res
+ end
+end
+
+--[[[
+-- @function lua_selectors.flatten_selectors(selectors)
+-- Convert selectors to a flat table of elements
+--]]
+exports.flatten_selectors = function(_, selectors, _)
+ local res = {}
+
+ local function fill(tbl)
+ for _, s in ipairs(tbl) do
+ if type(s) == 'string' then
+ rawset(res, #res + 1, s)
+ elseif type(s) == 'userdata' then
+ rawset(res, #res + 1, tostring(s))
+ else
+ fill(s)
+ end
+ end
+ end
+
+ fill(selectors)
+
+ return res
+end
+
+--[[[
+-- @function lua_selectors.kv_table_from_pairs(selectors)
+-- Convert selectors to a table where the odd elements are keys and even are elements
+-- Similarly to make a map from (k, v) pairs list
+-- To specify the concrete constant keys, one can use the `id` extractor
+--]]
+exports.kv_table_from_pairs = function(log_obj, selectors, _)
+ local res = {}
+ local rspamd_logger = require "rspamd_logger"
+
+ local function fill(tbl)
+ local tbl_len = #tbl
+ if tbl_len % 2 ~= 0 or tbl_len == 0 then
+ rspamd_logger.errx(log_obj, "invalid invocation of the `kv_table_from_pairs`: table length is invalid %s",
+ tbl_len)
+ return
+ end
+ for i = 1, tbl_len, 2 do
+ local k = tostring(tbl[i])
+ local v = tbl[i + 1]
+ if type(v) == 'string' then
+ res[k] = v
+ elseif type(v) == 'userdata' then
+ res[k] = tostring(v)
+ else
+ res[k] = fun.totable(fun.map(function(elt)
+ return tostring(elt)
+ end, v))
+ end
+ end
+ end
+
+ fill(selectors)
+
+ return res
+end
+
+
+--[[[
+-- @function lua_selectors.create_closure(log_obj, cfg, selector_str, delimiter, fn)
+-- Creates a closure from a string selector, using the specific combinator function
+--]]
+exports.create_selector_closure_fn = function(log_obj, cfg, selector_str, delimiter, fn)
+ local selector = exports.parse_selector(cfg, selector_str)
+
+ if not selector then
+ return nil
+ end
+
+ return function(task)
+ local res = exports.process_selectors(task, selector)
+
+ if res then
+ return fn(log_obj, res, delimiter)
+ end
+
+ return nil
+ end
+end
+
+--[[[
+-- @function lua_selectors.create_closure(cfg, selector_str, delimiter='', flatten=false)
+-- Creates a closure from a string selector
+--]]
+exports.create_selector_closure = function(cfg, selector_str, delimiter, flatten)
+ local combinator_fn = flatten and exports.flatten_selectors or exports.combine_selectors
+
+ return exports.create_selector_closure_fn(nil, cfg, selector_str, delimiter, combinator_fn)
+end
+
+local function display_selectors(tbl)
+ return fun.tomap(fun.map(function(k, v)
+ return k, fun.tomap(fun.filter(function(kk, vv)
+ return type(vv) ~= 'function'
+ end, v))
+ end, tbl))
+end
+
+exports.list_extractors = function()
+ return display_selectors(extractors)
+end
+
+exports.list_transforms = function()
+ return display_selectors(transform_function)
+end
+
+exports.add_map = function(name, map)
+ if not exports.maps[name] then
+ exports.maps[name] = map
+ else
+ logger.errx(rspamd_config, "duplicate map redefinition for the selectors: %s", name)
+ end
+end
+
+-- Publish log target
+exports.M = M
+
+return exports
diff --git a/lualib/lua_selectors/maps.lua b/lualib/lua_selectors/maps.lua
new file mode 100644
index 0000000..85b54a6
--- /dev/null
+++ b/lualib/lua_selectors/maps.lua
@@ -0,0 +1,19 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local maps = {} -- Shared within selectors, indexed by name
+
+return maps \ No newline at end of file
diff --git a/lualib/lua_selectors/transforms.lua b/lualib/lua_selectors/transforms.lua
new file mode 100644
index 0000000..6c6bc71
--- /dev/null
+++ b/lualib/lua_selectors/transforms.lua
@@ -0,0 +1,571 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local fun = require 'fun'
+local lua_util = require "lua_util"
+local rspamd_util = require "rspamd_util"
+local ts = require("tableshape").types
+local logger = require 'rspamd_logger'
+local common = require "lua_selectors/common"
+local M = "selectors"
+
+local maps = require "lua_selectors/maps"
+
+local function pure_type(ltype)
+ return ltype:match('^(.*)_list$')
+end
+
+local transform_function = {
+ -- Returns the lowercased string
+ ['lower'] = {
+ ['types'] = {
+ ['string'] = true,
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, _)
+ return inp:lower(), 'string'
+ end,
+ ['description'] = 'Returns the lowercased string',
+ },
+ -- Returns the lowercased utf8 string
+ ['lower_utf8'] = {
+ ['types'] = {
+ ['string'] = true,
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, t)
+ return rspamd_util.lower_utf8(inp), t
+ end,
+ ['description'] = 'Returns the lowercased utf8 string',
+ },
+ -- Returns the first element
+ ['first'] = {
+ ['types'] = {
+ ['list'] = true,
+ },
+ ['process'] = function(inp, t)
+ return fun.head(inp), pure_type(t)
+ end,
+ ['description'] = 'Returns the first element',
+ },
+ -- Returns the last element
+ ['last'] = {
+ ['types'] = {
+ ['list'] = true,
+ },
+ ['process'] = function(inp, t)
+ return fun.nth(fun.length(inp), inp), pure_type(t)
+ end,
+ ['description'] = 'Returns the last element',
+ },
+ -- Returns the nth element
+ ['nth'] = {
+ ['types'] = {
+ ['list'] = true,
+ },
+ ['process'] = function(inp, t, args)
+ return fun.nth(args[1] or 1, inp), pure_type(t)
+ end,
+ ['description'] = 'Returns the nth element',
+ ['args_schema'] = { ts.number + ts.string / tonumber }
+ },
+ ['take_n'] = {
+ ['types'] = {
+ ['list'] = true,
+ },
+ ['process'] = function(inp, t, args)
+ return fun.take_n(args[1] or 1, inp), t
+ end,
+ ['description'] = 'Returns the n first elements',
+ ['args_schema'] = { ts.number + ts.string / tonumber }
+ },
+ ['drop_n'] = {
+ ['types'] = {
+ ['list'] = true,
+ },
+ ['process'] = function(inp, t, args)
+ return fun.drop_n(args[1] or 1, inp), t
+ end,
+ ['description'] = 'Returns list without the first n elements',
+ ['args_schema'] = { ts.number + ts.string / tonumber }
+ },
+ -- Joins strings into a single string using separator in the argument
+ ['join'] = {
+ ['types'] = {
+ ['string_list'] = true
+ },
+ ['process'] = function(inp, _, args)
+ return table.concat(fun.totable(inp), args[1] or ''), 'string'
+ end,
+ ['description'] = 'Joins strings into a single string using separator in the argument',
+ ['args_schema'] = { ts.string:is_optional() }
+ },
+ -- Joins strings into a set of strings using N elements and a separator in the argument
+ ['join_nth'] = {
+ ['types'] = {
+ ['string_list'] = true
+ },
+ ['process'] = function(inp, _, args)
+ local step = args[1]
+ local sep = args[2] or ''
+ local inp_t = fun.totable(inp)
+ local res = {}
+
+ for i = 1, #inp_t, step do
+ table.insert(res, table.concat(inp_t, sep, i, i + step))
+ end
+ return res, 'string_list'
+ end,
+ ['description'] = 'Joins strings into a set of strings using N elements and a separator in the argument',
+ ['args_schema'] = { ts.number + ts.string / tonumber, ts.string:is_optional() }
+ },
+ -- Joins tables into a table of strings
+ ['join_tables'] = {
+ ['types'] = {
+ ['list'] = true
+ },
+ ['process'] = function(inp, _, args)
+ local sep = args[1] or ''
+ return fun.map(function(t)
+ return table.concat(t, sep)
+ end, inp), 'string_list'
+ end,
+ ['description'] = 'Joins tables into a table of strings',
+ ['args_schema'] = { ts.string:is_optional() }
+ },
+ -- Sort strings
+ ['sort'] = {
+ ['types'] = {
+ ['list'] = true
+ },
+ ['process'] = function(inp, t, _)
+ table.sort(inp)
+ return inp, t
+ end,
+ ['description'] = 'Sort strings lexicographically',
+ },
+ -- Return unique elements based on hashing (can work without sorting)
+ ['uniq'] = {
+ ['types'] = {
+ ['list'] = true
+ },
+ ['process'] = function(inp, t, _)
+ local tmp = {}
+ fun.each(function(val)
+ tmp[val] = true
+ end, inp)
+
+ return fun.map(function(k, _)
+ return k
+ end, tmp), t
+ end,
+ ['description'] = 'Returns a list of unique elements (using a hash table)',
+ },
+ -- Create a digest from string or a list of strings
+ ['digest'] = {
+ ['types'] = {
+ ['string'] = true
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, _, args)
+ return common.create_digest(inp, args), 'string'
+ end,
+ ['description'] = [[Create a digest from a string.
+The first argument is encoding (`hex`, `base32` (and forms `bleach32`, `rbase32`), `base64`),
+the second argument is optional hash type (`blake2`, `sha256`, `sha1`, `sha512`, `md5`)]],
+ ['args_schema'] = common.digest_schema()
+ },
+ -- Extracts substring
+ ['substring'] = {
+ ['types'] = {
+ ['string'] = true
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, _, args)
+ local start_pos = args[1] or 1
+ local end_pos = args[2] or -1
+
+ return inp:sub(start_pos, end_pos), 'string'
+ end,
+ ['description'] = 'Extracts substring; the first argument is start, the second is the last (like in Lua)',
+ ['args_schema'] = { (ts.number + ts.string / tonumber):is_optional(),
+ (ts.number + ts.string / tonumber):is_optional() }
+ },
+ -- Prepends a string or a strings list
+ ['prepend'] = {
+ ['types'] = {
+ ['string'] = true
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, _, args)
+ local prepend = table.concat(args, '')
+
+ return prepend .. inp, 'string'
+ end,
+ ['description'] = 'Prepends a string or a strings list',
+ },
+ -- Appends a string or a strings list
+ ['append'] = {
+ ['types'] = {
+ ['string'] = true
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, _, args)
+ local append = table.concat(args, '')
+
+ return inp .. append, 'string'
+ end,
+ ['description'] = 'Appends a string or a strings list',
+ },
+ -- Regexp matching
+ ['regexp'] = {
+ ['types'] = {
+ ['string'] = true
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, _, args)
+ local rspamd_regexp = require "rspamd_regexp"
+
+ local re = rspamd_regexp.create_cached(args[1])
+
+ if not re then
+ logger.errx('invalid regexp: %s', args[1])
+ return nil
+ end
+
+ local res = re:search(inp, false, true)
+
+ if res then
+ -- Map all results in a single list
+ local flattened_table = {}
+ local function flatten_table(tbl)
+ for _, v in ipairs(tbl) do
+ if type(v) == 'table' then
+ flatten_table(v)
+ else
+ table.insert(flattened_table, v)
+ end
+ end
+ end
+ flatten_table(res)
+ return flattened_table, 'string_list'
+ end
+
+ return nil
+ end,
+ ['description'] = 'Regexp matching, returns all matches flattened in a single list',
+ ['args_schema'] = { ts.string }
+ },
+ -- Returns a value if it exists in some map (or acts like a `filter` function)
+ ['filter_map'] = {
+ ['types'] = {
+ ['string'] = true
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, t, args)
+ local map = maps[args[1]]
+
+ if not map then
+ logger.errx('invalid map name: %s', args[1])
+ return nil
+ end
+
+ local res = map:get_key(inp)
+
+ if res then
+ return inp, t
+ end
+
+ return nil
+ end,
+ ['description'] = 'Returns a value if it exists in some map (or acts like a `filter` function)',
+ ['args_schema'] = { ts.string }
+ },
+ -- Returns a value if it exists in some map (or acts like a `filter` function)
+ ['except_map'] = {
+ ['types'] = {
+ ['string'] = true
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, t, args)
+ local map = maps[args[1]]
+
+ if not map then
+ logger.errx('invalid map name: %s', args[1])
+ return nil
+ end
+
+ local res = map:get_key(inp)
+
+ if not res then
+ return inp, t
+ end
+
+ return nil
+ end,
+ ['description'] = 'Returns a value if it does not exists in some map (or acts like a `except` function)',
+ ['args_schema'] = { ts.string }
+ },
+ -- Returns a value from some map corresponding to some key (or acts like a `map` function)
+ ['apply_map'] = {
+ ['types'] = {
+ ['string'] = true
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, t, args)
+ local map = maps[args[1]]
+
+ if not map then
+ logger.errx('invalid map name: %s', args[1])
+ return nil
+ end
+
+ local res = map:get_key(inp)
+
+ if res then
+ return res, t
+ end
+
+ return nil
+ end,
+ ['description'] = 'Returns a value from some map corresponding to some key (or acts like a `map` function)',
+ ['args_schema'] = { ts.string }
+ },
+ -- Drops input value and return values from function's arguments or an empty string
+ ['id'] = {
+ ['types'] = {
+ ['string'] = true,
+ ['list'] = true,
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(_, _, args)
+ if args[1] and args[2] then
+ return fun.map(tostring, args), 'string_list'
+ elseif args[1] then
+ return args[1], 'string'
+ end
+
+ return '', 'string'
+ end,
+ ['description'] = 'Drops input value and return values from function\'s arguments or an empty string',
+ ['args_schema'] = (ts.string + ts.array_of(ts.string)):is_optional()
+ },
+ ['equal'] = {
+ ['types'] = {
+ ['string'] = true,
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, _, args)
+ if inp == args[1] then
+ return inp, 'string'
+ end
+
+ return nil
+ end,
+ ['description'] = [[Boolean function equal.
+Returns either nil or its argument if input is equal to argument]],
+ ['args_schema'] = { ts.string }
+ },
+ -- Boolean function in, returns either nil or its input if input is in args list
+ ['in'] = {
+ ['types'] = {
+ ['string'] = true,
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, t, args)
+ for _, a in ipairs(args) do
+ if a == inp then
+ return inp, t
+ end
+ end
+ return nil
+ end,
+ ['description'] = [[Boolean function in.
+Returns either nil or its input if input is in args list]],
+ ['args_schema'] = ts.array_of(ts.string)
+ },
+ ['not_in'] = {
+ ['types'] = {
+ ['string'] = true,
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, t, args)
+ for _, a in ipairs(args) do
+ if a == inp then
+ return nil
+ end
+ end
+ return inp, t
+ end,
+ ['description'] = [[Boolean function not in.
+Returns either nil or its input if input is not in args list]],
+ ['args_schema'] = ts.array_of(ts.string)
+ },
+ ['inverse'] = {
+ ['types'] = {
+ ['string'] = true,
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, _, args)
+ if inp then
+ return nil
+ else
+ return (args[1] or 'true'), 'string'
+ end
+ end,
+ ['description'] = [[Inverses input.
+Empty string comes the first argument or 'true', non-empty string comes nil]],
+ ['args_schema'] = { ts.string:is_optional() }
+ },
+ ['ipmask'] = {
+ ['types'] = {
+ ['string'] = true,
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, _, args)
+ local rspamd_ip = require "rspamd_ip"
+ -- Non optimal: convert string to an IP address
+ local ip = rspamd_ip.from_string(inp)
+
+ if not ip or not ip:is_valid() then
+ lua_util.debugm(M, "cannot convert %s to IP", inp)
+ return nil
+ end
+
+ if ip:get_version() == 4 then
+ local mask = tonumber(args[1])
+
+ return ip:apply_mask(mask):to_string(), 'string'
+ else
+ -- IPv6 takes the second argument or the first one...
+ local mask_str = args[2] or args[1]
+ local mask = tonumber(mask_str)
+
+ return ip:apply_mask(mask):to_string(), 'string'
+ end
+ end,
+ ['description'] = 'Applies mask to IP address.' ..
+ ' The first argument is the mask for IPv4 addresses, the second is the mask for IPv6 addresses.',
+ ['args_schema'] = { (ts.number + ts.string / tonumber),
+ (ts.number + ts.string / tonumber):is_optional() }
+ },
+ -- Returns the string(s) with all non ascii chars replaced
+ ['to_ascii'] = {
+ ['types'] = {
+ ['string'] = true,
+ ['list'] = true,
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, _, args)
+ if type(inp) == 'table' then
+ return fun.map(
+ function(s)
+ return string.gsub(tostring(s), '[\128-\255]', args[1] or '?')
+ end, inp), 'string_list'
+ else
+ return string.gsub(tostring(inp), '[\128-\255]', '?'), 'string'
+ end
+ end,
+ ['description'] = 'Returns the string with all non-ascii bytes replaced with the character ' ..
+ 'given as second argument or `?`',
+ ['args_schema'] = { ts.string:is_optional() }
+ },
+ -- Extracts tld from a hostname
+ ['get_tld'] = {
+ ['types'] = {
+ ['string'] = true
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, _, _)
+ return rspamd_util.get_tld(inp), 'string'
+ end,
+ ['description'] = 'Extracts tld from a hostname represented as a string',
+ ['args_schema'] = {}
+ },
+ -- Converts list of strings to numbers and returns a packed string
+ ['pack_numbers'] = {
+ ['types'] = {
+ ['string_list'] = true
+ },
+ ['map_type'] = 'string',
+ ['process'] = function(inp, _, args)
+ local fmt = args[1] or 'f'
+ local res = {}
+ for _, s in ipairs(inp) do
+ table.insert(res, tonumber(s))
+ end
+ return rspamd_util.pack(string.rep(fmt, #res), lua_util.unpack(res)), 'string'
+ end,
+ ['description'] = 'Converts a list of strings to numbers & returns a packed string',
+ ['args_schema'] = { ts.string:is_optional() }
+ },
+ -- Filter nils from a list
+ ['filter_string_nils'] = {
+ ['types'] = {
+ ['string_list'] = true
+ },
+ ['process'] = function(inp, _, _)
+ return fun.filter(function(val)
+ return type(val) == 'string' and val ~= 'nil'
+ end, inp), 'string_list'
+ end,
+ ['description'] = 'Removes all nils from a list of strings (when converted implicitly)',
+ ['args_schema'] = {}
+ },
+ -- Call a set of methods on a userdata object
+ ['apply_methods'] = {
+ ['types'] = {
+ ['userdata'] = true,
+ },
+ ['process'] = function(inp, _, args)
+ local res = {}
+ for _, arg in ipairs(args) do
+ local meth = inp[arg]
+ local ret = meth(inp)
+ if ret then
+ table.insert(res, tostring(ret))
+ end
+ end
+ return res, 'string_list'
+ end,
+ ['description'] = 'Apply a list of method calls to the userdata object',
+ },
+ -- Apply method to list of userdata and use it as a filter, excluding elements for which method returns false/nil
+ ['filter_method'] = {
+ ['types'] = {
+ ['userdata_list'] = true
+ },
+ ['process'] = function(inp, t, args)
+ local meth = args[1]
+
+ if not meth then
+ logger.errx('invalid method name: %s', args[1])
+ return nil
+ end
+
+ return fun.filter(function(val)
+ return val[meth](val)
+ end, inp), 'userdata_list'
+ end,
+ ['description'] = 'Apply method to list of userdata and use it as a filter,' ..
+ ' excluding elements for which method returns false/nil',
+ ['args_schema'] = { ts.string }
+ },
+}
+
+transform_function.match = transform_function.regexp
+
+return transform_function
diff --git a/lualib/lua_settings.lua b/lualib/lua_settings.lua
new file mode 100644
index 0000000..d6d24d6
--- /dev/null
+++ b/lualib/lua_settings.lua
@@ -0,0 +1,309 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_settings
+-- This module contains internal helpers for the settings infrastructure in Rspamd
+-- More details at https://rspamd.com/doc/configuration/settings.html
+--]]
+
+local exports = {}
+local known_ids = {}
+local post_init_added = false
+local post_init_performed = false
+local all_symbols
+local default_symbols
+
+local fun = require "fun"
+local lua_util = require "lua_util"
+local rspamd_logger = require "rspamd_logger"
+
+local function register_settings_cb(from_postload)
+ if not post_init_performed then
+ all_symbols = rspamd_config:get_symbols()
+
+ default_symbols = fun.totable(fun.filter(function(_, v)
+ return not v.allowed_ids or #v.allowed_ids == 0 or v.flags.explicit_disable
+ end, all_symbols))
+
+ local explicit_symbols = lua_util.keys(fun.filter(function(k, v)
+ return v.flags.explicit_disable
+ end, all_symbols))
+
+ local symnames = lua_util.list_to_hash(lua_util.keys(all_symbols))
+
+ for _, set in pairs(known_ids) do
+ local s = set.settings.apply or {}
+ set.symbols = lua_util.shallowcopy(symnames)
+ local enabled_symbols = {}
+ local seen_enabled = false
+ local disabled_symbols = {}
+ local seen_disabled = false
+
+ -- Enabled map
+ if s.symbols_enabled then
+ -- Remove all symbols from set.symbols aside of explicit_disable symbols
+ set.symbols = lua_util.list_to_hash(explicit_symbols)
+ seen_enabled = true
+ for _, sym in ipairs(s.symbols_enabled) do
+ enabled_symbols[sym] = true
+ set.symbols[sym] = true
+ end
+ end
+ if s.groups_enabled then
+ seen_enabled = true
+ for _, gr in ipairs(s.groups_enabled) do
+ local syms = rspamd_config:get_group_symbols(gr)
+
+ if syms then
+ for _, sym in ipairs(syms) do
+ enabled_symbols[sym] = true
+ set.symbols[sym] = true
+ end
+ end
+ end
+ end
+
+ -- Disabled map
+ if s.symbols_disabled then
+ seen_disabled = true
+ for _, sym in ipairs(s.symbols_disabled) do
+ disabled_symbols[sym] = true
+ set.symbols[sym] = false
+ end
+ end
+ if s.groups_disabled then
+ seen_disabled = true
+ for _, gr in ipairs(s.groups_disabled) do
+ local syms = rspamd_config:get_group_symbols(gr)
+
+ if syms then
+ for _, sym in ipairs(syms) do
+ disabled_symbols[sym] = true
+ set.symbols[sym] = false
+ end
+ end
+ end
+ end
+
+ -- Deal with complexity to avoid mess in C
+ if not seen_enabled then
+ enabled_symbols = nil
+ end
+ if not seen_disabled then
+ disabled_symbols = nil
+ end
+
+ if enabled_symbols or disabled_symbols then
+ -- Specify what symbols are really enabled for this settings id
+ set.has_specific_symbols = true
+ end
+
+ rspamd_config:register_settings_id(set.name, enabled_symbols, disabled_symbols)
+
+ -- Remove to avoid clash
+ s.symbols_disabled = nil
+ s.symbols_enabled = nil
+ s.groups_enabled = nil
+ s.groups_disabled = nil
+ end
+
+ -- We now iterate over all symbols and check for allowed_ids/forbidden_ids
+ for k, v in pairs(all_symbols) do
+ if v.allowed_ids and not v.flags.explicit_disable then
+ for _, id in ipairs(v.allowed_ids) do
+ if known_ids[id] then
+ local set = known_ids[id]
+ if not set.has_specific_symbols then
+ set.has_specific_symbols = true
+ end
+ set.symbols[k] = true
+ else
+ rspamd_logger.errx(rspamd_config, 'symbol %s is allowed at unknown settings id %s',
+ k, id)
+ end
+ end
+ end
+ if v.forbidden_ids then
+ for _, id in ipairs(v.forbidden_ids) do
+ if known_ids[id] then
+ local set = known_ids[id]
+ if not set.has_specific_symbols then
+ set.has_specific_symbols = true
+ end
+ set.symbols[k] = false
+ else
+ rspamd_logger.errx(rspamd_config, 'symbol %s is denied at unknown settings id %s',
+ k, id)
+ end
+ end
+ end
+ end
+
+ -- Now we create lists of symbols for each settings and digest
+ for _, set in pairs(known_ids) do
+ set.symbols = lua_util.keys(fun.filter(function(_, v)
+ return v
+ end, set.symbols))
+ table.sort(set.symbols)
+ set.digest = lua_util.table_digest(set.symbols)
+ end
+
+ post_init_performed = true
+ end
+end
+
+-- Returns numeric representation of the settings id
+local function numeric_settings_id(str)
+ local cr = require "rspamd_cryptobox_hash"
+ local util = require "rspamd_util"
+ local ret = util.unpack("I4",
+ cr.create_specific('xxh64'):update(str):bin())
+
+ return ret
+end
+
+exports.numeric_settings_id = numeric_settings_id
+
+-- Used to do the following:
+-- If there is a group of symbols_allowed, it checks if that is an array
+-- If that is a hash table then we transform it to a normal list, probably adding symbols to adjust scores
+local function transform_settings_maybe(settings, name)
+ if settings.apply then
+ local apply = settings.apply
+
+ if apply.symbols_enabled then
+ local senabled = apply.symbols_enabled
+
+ if not senabled[1] then
+ -- Transform map to a list
+ local nlist = {}
+ if not apply.scores then
+ apply.scores = {}
+ end
+ for k, v in pairs(senabled) do
+ if tonumber(v) then
+ -- Move to symbols as well
+ apply.scores[k] = tonumber(v)
+ lua_util.debugm('settings', rspamd_config,
+ 'set symbol %s -> %s for settings %s', k, v, name)
+ end
+ nlist[#nlist + 1] = k
+ end
+ -- Convert
+ apply.symbols_enabled = nlist
+ end
+
+ local symhash = lua_util.list_to_hash(apply.symbols_enabled)
+
+ if apply.symbols then
+ -- Check if added symbols are enabled
+ for k, v in pairs(apply.symbols) do
+ local s
+ -- Check if we have ["sym1", "sym2" ...] or {"sym1": xx, "sym2": yy}
+ if type(k) == 'string' then
+ s = k
+ else
+ s = v
+ end
+ if not symhash[s] then
+ lua_util.debugm('settings', rspamd_config,
+ 'added symbol %s to symbols_enabled for %s', s, name)
+ apply.symbols_enabled[#apply.symbols_enabled + 1] = s
+ end
+ end
+ end
+ end
+ end
+
+ return settings
+end
+
+local function register_settings_id(str, settings, from_postload)
+ local numeric_id = numeric_settings_id(str)
+
+ if known_ids[numeric_id] then
+ -- Might be either rewrite or a collision
+ if known_ids[numeric_id].name ~= str then
+ local logger = require "rspamd_logger"
+
+ logger.errx(rspamd_config, 'settings ID clash! id %s maps to %s and conflicts with %s',
+ numeric_id, known_ids[numeric_id].name, str)
+
+ return nil
+ end
+ else
+ known_ids[numeric_id] = {
+ name = str,
+ id = numeric_id,
+ settings = transform_settings_maybe(settings, str),
+ symbols = {}
+ }
+ end
+
+ if not from_postload and not post_init_added then
+ -- Use high priority to ensure that settings are initialised early but not before all
+ -- plugins are loaded
+ rspamd_config:add_post_init(function()
+ register_settings_cb(true)
+ end, 150)
+ rspamd_config:add_config_unload(function()
+ if post_init_added then
+ known_ids = {}
+ post_init_added = false
+ end
+ post_init_performed = false
+ end)
+
+ post_init_added = true
+ end
+
+ return numeric_id
+end
+
+exports.register_settings_id = register_settings_id
+
+local function settings_by_id(id)
+ if not post_init_performed then
+ register_settings_cb(false)
+ end
+ return known_ids[id]
+end
+
+exports.settings_by_id = settings_by_id
+exports.all_settings = function()
+ if not post_init_performed then
+ register_settings_cb(false)
+ end
+ return known_ids
+end
+exports.all_symbols = function()
+ if not post_init_performed then
+ register_settings_cb(false)
+ end
+ return all_symbols
+end
+-- What is enabled when no settings are there
+exports.default_symbols = function()
+ if not post_init_performed then
+ register_settings_cb(false)
+ end
+ return default_symbols
+end
+
+exports.load_all_settings = register_settings_cb
+
+return exports \ No newline at end of file
diff --git a/lualib/lua_smtp.lua b/lualib/lua_smtp.lua
new file mode 100644
index 0000000..3c40349
--- /dev/null
+++ b/lualib/lua_smtp.lua
@@ -0,0 +1,201 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local rspamd_tcp = require "rspamd_tcp"
+local lua_util = require "lua_util"
+
+local exports = {}
+
+local CRLF = '\r\n'
+local default_timeout = 10.0
+
+--[[[
+-- @function lua_smtp.sendmail(task, message, opts, callback)
+--]]
+local function sendmail(opts, message, callback)
+ local stage = 'connect'
+
+ local function mail_cb(err, data, conn)
+ local function no_error_write(merr)
+ if merr then
+ callback(false, string.format('error on stage %s: %s',
+ stage, merr))
+ if conn then
+ conn:close()
+ end
+
+ return false
+ end
+
+ return true
+ end
+
+ local function no_error_read(merr, mdata, wantcode)
+ wantcode = wantcode or '2'
+ if merr then
+ callback(false, string.format('error on stage %s: %s',
+ stage, merr))
+ if conn then
+ conn:close()
+ end
+
+ return false
+ end
+ if mdata then
+ if type(mdata) ~= 'string' then
+ mdata = tostring(mdata)
+ end
+ if string.sub(mdata, 1, 1) ~= wantcode then
+ callback(false, string.format('bad smtp response on stage %s: "%s" when "%s" expected',
+ stage, mdata, wantcode))
+ if conn then
+ conn:close()
+ end
+ return false
+ end
+ else
+ callback(false, string.format('no data on stage %s',
+ stage))
+ if conn then
+ conn:close()
+ end
+ return false
+ end
+ return true
+ end
+
+ -- After quit
+ local function all_done_cb(merr, mdata)
+ if conn then
+ conn:close()
+ end
+
+ callback(true, nil)
+
+ return true
+ end
+
+ -- QUIT stage
+ local function quit_done_cb(_, _)
+ conn:add_read(all_done_cb, CRLF)
+ end
+ local function quit_cb(merr, mdata)
+ if no_error_read(merr, mdata) then
+ conn:add_write(quit_done_cb, 'QUIT' .. CRLF)
+ end
+ end
+ local function pre_quit_cb(merr, _)
+ if no_error_write(merr) then
+ stage = 'quit'
+ conn:add_read(quit_cb, CRLF)
+ end
+ end
+
+ -- DATA stage
+ local function data_done_cb(merr, mdata)
+ if no_error_read(merr, mdata, '3') then
+ if type(message) == 'string' or type(message) == 'userdata' then
+ conn:add_write(pre_quit_cb, { message, CRLF .. '.' .. CRLF })
+ else
+ table.insert(message, CRLF .. '.' .. CRLF)
+ conn:add_write(pre_quit_cb, message)
+ end
+ end
+ end
+ local function data_cb(merr, _)
+ if no_error_write(merr) then
+ conn:add_read(data_done_cb, CRLF)
+ end
+ end
+
+ -- RCPT phase
+ local next_recipient
+ local function rcpt_done_cb_gen(i)
+ return function(merr, mdata)
+ if no_error_read(merr, mdata) then
+ if i == #opts.recipients then
+ conn:add_write(data_cb, 'DATA' .. CRLF)
+ else
+ next_recipient(i + 1)
+ end
+ end
+ end
+ end
+
+ local function rcpt_cb_gen(i)
+ return function(merr, _)
+ if no_error_write(merr, '2') then
+ conn:add_read(rcpt_done_cb_gen(i), CRLF)
+ end
+ end
+ end
+
+ next_recipient = function(i)
+ conn:add_write(rcpt_cb_gen(i),
+ string.format('RCPT TO: <%s>%s', opts.recipients[i], CRLF))
+ end
+
+ -- FROM stage
+ local function from_done_cb(merr, mdata)
+ -- We need to iterate over recipients sequentially
+ if no_error_read(merr, mdata, '2') then
+ stage = 'rcpt'
+ next_recipient(1)
+ end
+ end
+ local function from_cb(merr, _)
+ if no_error_write(merr) then
+ conn:add_read(from_done_cb, CRLF)
+ end
+ end
+ local function hello_done_cb(merr, mdata)
+ if no_error_read(merr, mdata) then
+ stage = 'from'
+ conn:add_write(from_cb, string.format(
+ 'MAIL FROM: <%s>%s', opts.from, CRLF))
+ end
+ end
+
+ -- HELO stage
+ local function hello_cb(merr)
+ if no_error_write(merr) then
+ conn:add_read(hello_done_cb, CRLF)
+ end
+ end
+ if no_error_read(err, data) then
+ stage = 'helo'
+ conn:add_write(hello_cb, string.format('HELO %s%s',
+ opts.helo, CRLF))
+ end
+ end
+
+ if type(opts.recipients) == 'string' then
+ opts.recipients = { opts.recipients }
+ end
+
+ local tcp_opts = lua_util.shallowcopy(opts)
+ tcp_opts.stop_pattern = CRLF
+ tcp_opts.timeout = opts.timeout or default_timeout
+ tcp_opts.callback = mail_cb
+
+ if not rspamd_tcp.request(tcp_opts) then
+ callback(false, 'cannot make a TCP connection')
+ end
+end
+
+exports.sendmail = sendmail
+
+return exports \ No newline at end of file
diff --git a/lualib/lua_stat.lua b/lualib/lua_stat.lua
new file mode 100644
index 0000000..a0f3303
--- /dev/null
+++ b/lualib/lua_stat.lua
@@ -0,0 +1,869 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_stat
+-- This module contains helper functions for supporting statistics
+--]]
+
+local logger = require "rspamd_logger"
+local sqlite3 = require "rspamd_sqlite3"
+local util = require "rspamd_util"
+local lua_redis = require "lua_redis"
+local lua_util = require "lua_util"
+local exports = {}
+
+local N = "stat_tools" -- luacheck: ignore (maybe unused)
+
+-- Performs synchronous conversion of redis schema
+local function convert_bayes_schema(redis_params, symbol_spam, symbol_ham, expire)
+
+ -- Old schema is the following one:
+ -- Keys are named <symbol>[<user>]
+ -- Elements are placed within hash:
+ -- BAYES_SPAM -> {<id1>: <num_hits>, <id2>: <num_hits> ...}
+ -- In new schema it is changed to a more extensible schema:
+ -- Keys are named RS[<user>]_<id> -> {'H': <ham_hits>, 'S': <spam_hits>}
+ -- So we can expire individual records, measure most popular elements by zranges,
+ -- add new fields, such as tokens etc
+
+ local res, conn = lua_redis.redis_connect_sync(redis_params, true)
+
+ if not res then
+ logger.errx("cannot connect to redis server")
+ return false
+ end
+
+ -- KEYS[1]: key to check (e.g. 'BAYES_SPAM')
+ -- KEYS[2]: hash key ('S' or 'H')
+ -- KEYS[3]: expire
+ local lua_script = [[
+redis.replicate_commands()
+local keys = redis.call('SMEMBERS', KEYS[1]..'_keys')
+local nconverted = 0
+for _,k in ipairs(keys) do
+ local cursor = redis.call('HSCAN', k, 0)
+ local neutral_prefix = string.gsub(k, KEYS[1], 'RS')
+ local elts
+ while cursor[1] ~= "0" do
+ elts = cursor[2]
+ cursor = redis.call('HSCAN', k, cursor[1])
+ local real_key
+ for i,v in ipairs(elts) do
+ if i % 2 ~= 0 then
+ real_key = v
+ else
+ local nkey = string.format('%s_%s', neutral_prefix, real_key)
+ redis.call('HSET', nkey, KEYS[2], v)
+ if KEYS[3] and tonumber(KEYS[3]) > 0 then
+ redis.call('EXPIRE', nkey, KEYS[3])
+ end
+ nconverted = nconverted + 1
+ end
+ end
+ end
+end
+return nconverted
+]]
+
+ conn:add_cmd('EVAL', { lua_script, '3', symbol_spam, 'S', tostring(expire) })
+ local ret
+ ret, res = conn:exec()
+
+ if not ret then
+ logger.errx('error converting symbol %s: %s', symbol_spam, res)
+ return false
+ else
+ logger.messagex('converted %s elements from symbol %s', res, symbol_spam)
+ end
+
+ conn:add_cmd('EVAL', { lua_script, '3', symbol_ham, 'H', tostring(expire) })
+ ret, res = conn:exec()
+
+ if not ret then
+ logger.errx('error converting symbol %s: %s', symbol_ham, res)
+ return false
+ else
+ logger.messagex('converted %s elements from symbol %s', res, symbol_ham)
+ end
+
+ -- We can now convert metadata: set + learned + version
+ -- KEYS[1]: key to check (e.g. 'BAYES_SPAM')
+ -- KEYS[2]: learn key (e.g. 'learns_spam' or 'learns_ham')
+ lua_script = [[
+local keys = redis.call('SMEMBERS', KEYS[1]..'_keys')
+
+for _,k in ipairs(keys) do
+ local learns = redis.call('HGET', k, 'learns') or 0
+ local neutral_prefix = string.gsub(k, KEYS[1], 'RS')
+
+ redis.call('HSET', neutral_prefix, KEYS[2], learns)
+ redis.call('SADD', KEYS[1]..'_keys', neutral_prefix)
+ redis.call('SREM', KEYS[1]..'_keys', k)
+ redis.call('DEL', KEYS[1])
+ redis.call('SET', k ..'_version', '2')
+end
+]]
+
+ conn:add_cmd('EVAL', { lua_script, '2', symbol_spam, 'learns_spam' })
+ ret, res = conn:exec()
+
+ if not ret then
+ logger.errx('error converting metadata for symbol %s: %s', symbol_spam, res)
+ return false
+ end
+
+ conn:add_cmd('EVAL', { lua_script, '2', symbol_ham, 'learns_ham' })
+ ret, res = conn:exec()
+
+ if not ret then
+ logger.errx('error converting metadata for symbol %s', symbol_ham, res)
+ return false
+ end
+
+ return true
+end
+
+exports.convert_bayes_schema = convert_bayes_schema
+
+-- It now accepts both ham and spam databases
+-- parameters:
+-- redis_params - how do we connect to a redis server
+-- sqlite_db_spam - name for sqlite database with spam tokens
+-- sqlite_db_ham - name for sqlite database with ham tokens
+-- symbol_ham - name for symbol representing spam, e.g. BAYES_SPAM
+-- symbol_spam - name for symbol representing ham, e.g. BAYES_HAM
+-- learn_cache_spam - name for sqlite database with spam learn cache
+-- learn_cache_ham - name for sqlite database with ham learn cache
+-- reset_previous - if true, then the old database is flushed (slow)
+local function convert_sqlite_to_redis(redis_params,
+ sqlite_db_spam, sqlite_db_ham, symbol_spam, symbol_ham,
+ learn_cache_db, expire, reset_previous)
+ local nusers = 0
+ local lim = 1000 -- Update each 1000 tokens
+ local users_map = {}
+ local converted = 0
+
+ local db_spam = sqlite3.open(sqlite_db_spam)
+ if not db_spam then
+ logger.errx('Cannot open source db: %s', sqlite_db_spam)
+ return false
+ end
+ local db_ham = sqlite3.open(sqlite_db_ham)
+ if not db_ham then
+ logger.errx('Cannot open source db: %s', sqlite_db_ham)
+ return false
+ end
+
+ local res, conn = lua_redis.redis_connect_sync(redis_params, true)
+
+ if not res then
+ logger.errx("cannot connect to redis server")
+ return false
+ end
+
+ if reset_previous then
+ -- Do a more complicated cleanup
+ -- execute a lua script that cleans up data
+ local script = [[
+local members = redis.call('SMEMBERS', KEYS[1]..'_keys')
+
+for _,prefix in ipairs(members) do
+ local keys = redis.call('KEYS', prefix..'*')
+ redis.call('DEL', keys)
+end
+]]
+ -- Common keys
+ for _, sym in ipairs({ symbol_spam, symbol_ham }) do
+ logger.messagex('Cleaning up old data for %s', sym)
+ conn:add_cmd('EVAL', { script, '1', sym })
+ conn:exec()
+ conn:add_cmd('DEL', { sym .. "_version" })
+ conn:add_cmd('DEL', { sym .. "_keys" })
+ conn:exec()
+ end
+
+ if learn_cache_db then
+ -- Cleanup learned_cache
+ logger.messagex('Cleaning up old data learned cache')
+ conn:add_cmd('DEL', { "learned_ids" })
+ conn:exec()
+ end
+ end
+
+ local function convert_db(db, is_spam)
+ -- Map users and languages
+ local what = 'ham'
+ if is_spam then
+ what = 'spam'
+ end
+
+ local learns = {}
+ db:sql('BEGIN;')
+ -- Fill users mapping
+ for row in db:rows('SELECT * FROM users;') do
+ if row.id == '0' then
+ users_map[row.id] = ''
+ else
+ users_map[row.id] = row.name
+ end
+ learns[row.id] = row.learns
+ nusers = nusers + 1
+ end
+
+ -- Workaround for old databases
+ for row in db:rows('SELECT * FROM languages') do
+ if learns['0'] then
+ learns['0'] = learns['0'] + row.learns
+ else
+ learns['0'] = row.learns
+ end
+ end
+
+ local function send_batch(tokens, prefix)
+ -- We use the new schema: RS[user]_token -> H=ham count
+ -- S=spam count
+ local hash_key = 'H'
+ if is_spam then
+ hash_key = 'S'
+ end
+ for _, tok in ipairs(tokens) do
+ -- tok schema:
+ -- tok[1] = token_id (uint64 represented as a string)
+ -- tok[2] = token value (number)
+ -- tok[3] = user_map[user_id] or ''
+ local rkey = string.format('%s%s_%s', prefix, tok[3], tok[1])
+ conn:add_cmd('HINCRBYFLOAT', { rkey, hash_key, tostring(tok[2]) })
+
+ if expire and expire ~= 0 then
+ conn:add_cmd('EXPIRE', { rkey, tostring(expire) })
+ end
+ end
+
+ return conn:exec()
+ end
+ -- Fill tokens, sending data to redis each `lim` records
+
+ local ntokens = db:query('SELECT count(*) as c FROM tokens')['c']
+ local tokens = {}
+ local num = 0
+ local total = 0
+
+ for row in db:rows('SELECT token,value,user FROM tokens;') do
+ local user = ''
+ if row.user ~= 0 and users_map[row.user] then
+ user = users_map[row.user]
+ end
+
+ table.insert(tokens, { row.token, row.value, user })
+ num = num + 1
+ total = total + 1
+ if num > lim then
+ -- TODO: we use the default 'RS' prefix, it can be false in case of
+ -- classifiers with labels
+ local ret, err_str = send_batch(tokens, 'RS')
+ if not ret then
+ logger.errx('Cannot send tokens to the redis server: ' .. err_str)
+ db:sql('COMMIT;')
+ return false
+ end
+
+ num = 0
+ tokens = {}
+ end
+
+ io.write(string.format('Processed batch %s: %s/%s\r', what, total, ntokens))
+ end
+ -- Last batch
+ if #tokens > 0 then
+ local ret, err_str = send_batch(tokens, 'RS')
+ if not ret then
+ logger.errx('Cannot send tokens to the redis server: ' .. err_str)
+ db:sql('COMMIT;')
+ return false
+ end
+
+ io.write(string.format('Processed batch %s: %s/%s\r', what, total, ntokens))
+ end
+ io.write('\n')
+
+ converted = converted + total
+
+ -- Close DB
+ db:sql('COMMIT;')
+ local symbol = symbol_ham
+ local learns_elt = "learns_ham"
+
+ if is_spam then
+ symbol = symbol_spam
+ learns_elt = "learns_spam"
+ end
+
+ for id, learned in pairs(learns) do
+ local user = users_map[id]
+ if not conn:add_cmd('HSET', { 'RS' .. user, learns_elt, learned }) then
+ logger.errx('Cannot update learns for user: ' .. user)
+ return false
+ end
+ if not conn:add_cmd('SADD', { symbol .. '_keys', 'RS' .. user }) then
+ logger.errx('Cannot update learns for user: ' .. user)
+ return false
+ end
+ end
+ -- Set version
+ conn:add_cmd('SET', { symbol .. '_version', '2' })
+ return conn:exec()
+ end
+
+ logger.messagex('Convert spam tokens')
+ if not convert_db(db_spam, true) then
+ return false
+ end
+
+ logger.messagex('Convert ham tokens')
+ if not convert_db(db_ham, false) then
+ return false
+ end
+
+ if learn_cache_db then
+ logger.messagex('Convert learned ids from %s', learn_cache_db)
+ local db = sqlite3.open(learn_cache_db)
+ local ret = true
+ local total = 0
+
+ if not db then
+ logger.errx('Cannot open cache database: ' .. learn_cache_db)
+ return false
+ end
+
+ db:sql('BEGIN;')
+
+ for row in db:rows('SELECT * FROM learns;') do
+ local is_spam
+ local digest = tostring(util.encode_base32(row.digest))
+
+ if row.flag == '0' then
+ is_spam = '-1'
+ else
+ is_spam = '1'
+ end
+
+ if not conn:add_cmd('HSET', { 'learned_ids', digest, is_spam }) then
+ logger.errx('Cannot add hash: ' .. digest)
+ ret = false
+ else
+ total = total + 1
+ end
+ end
+ db:sql('COMMIT;')
+
+ if ret then
+ conn:exec()
+ end
+
+ if ret then
+ logger.messagex('Converted %s cached items from sqlite3 learned cache to redis',
+ total)
+ else
+ logger.errx('Error occurred during sending data to redis')
+ end
+ end
+
+ logger.messagex('Migrated %s tokens for %s users for symbols (%s, %s)',
+ converted, nusers, symbol_spam, symbol_ham)
+ return true
+end
+
+exports.convert_sqlite_to_redis = convert_sqlite_to_redis
+
+-- Loads sqlite3 based classifiers and output data in form of array of objects:
+-- [
+-- {
+-- symbol_spam = XXX
+-- symbol_ham = YYY
+-- db_spam = XXX.sqlite
+-- db_ham = YYY.sqlite
+-- learn_cache = ZZZ.sqlite
+-- per_user = true/false
+-- label = str
+-- }
+-- ]
+local function load_sqlite_config(cfg)
+ local result = {}
+
+ local function parse_classifier(cls)
+ local tbl = {}
+ if cls.cache then
+ local cache = cls.cache
+ if cache.type == 'sqlite3' and (cache.file or cache.path) then
+ tbl.learn_cache = (cache.file or cache.path)
+ end
+ end
+
+ if cls.per_user then
+ tbl.per_user = cls.per_user
+ end
+
+ if cls.label then
+ tbl.label = cls.label
+ end
+
+ local statfiles = cls.statfile
+ for _, stf in ipairs(statfiles) do
+ local path = (stf.file or stf.path or stf.db or stf.dbname)
+ local symbol = stf.symbol or 'undefined'
+
+ if not path then
+ logger.errx('no path defined for statfile %s', symbol)
+ else
+
+ local spam
+ if stf.spam then
+ spam = stf.spam
+ else
+ if string.match(symbol:upper(), 'SPAM') then
+ spam = true
+ else
+ spam = false
+ end
+ end
+
+ if spam then
+ tbl.symbol_spam = symbol
+ tbl.db_spam = path
+ else
+ tbl.symbol_ham = symbol
+ tbl.db_ham = path
+ end
+ end
+ end
+
+ if tbl.symbol_spam and tbl.symbol_ham and tbl.db_ham and tbl.db_spam then
+ table.insert(result, tbl)
+ end
+ end
+
+ local classifier = cfg.classifier
+
+ if classifier then
+ if classifier[1] then
+ for _, cls in ipairs(classifier) do
+ if cls.bayes then
+ cls = cls.bayes
+ end
+ if cls.backend and cls.backend == 'sqlite3' then
+ parse_classifier(cls)
+ end
+ end
+ else
+ if classifier.bayes then
+ classifier = classifier.bayes
+ if classifier[1] then
+ for _, cls in ipairs(classifier) do
+ if cls.backend and cls.backend == 'sqlite3' then
+ parse_classifier(cls)
+ end
+ end
+ else
+ if classifier.backend and classifier.backend == 'sqlite3' then
+ parse_classifier(classifier)
+ end
+ end
+ end
+ end
+ end
+
+ return result
+end
+
+exports.load_sqlite_config = load_sqlite_config
+
+-- A helper method that suggests a user how to configure Redis based
+-- classifier based on the existing sqlite classifier
+local function redis_classifier_from_sqlite(sqlite_classifier, expire)
+ local result = {
+ new_schema = true,
+ backend = 'redis',
+ cache = {
+ backend = 'redis'
+ },
+ statfile = {
+ [sqlite_classifier.symbol_spam] = {
+ spam = true
+ },
+ [sqlite_classifier.symbol_ham] = {
+ spam = false
+ }
+ }
+ }
+
+ if expire then
+ result.expire = expire
+ end
+
+ return { classifier = { bayes = result } }
+end
+
+exports.redis_classifier_from_sqlite = redis_classifier_from_sqlite
+
+-- Reads statistics config and return preprocessed table
+local function process_stat_config(cfg)
+ local opts_section = cfg:get_all_opt('options') or {}
+
+ -- Check if we have a dedicated section for statistics
+ if opts_section.statistics then
+ opts_section = opts_section.statistics
+ end
+
+ -- Default
+ local res_config = {
+ classify_headers = {
+ "User-Agent",
+ "X-Mailer",
+ "Content-Type",
+ "X-MimeOLE",
+ "Organization",
+ "Organisation"
+ },
+ classify_images = true,
+ classify_mime_info = true,
+ classify_urls = true,
+ classify_meta = true,
+ classify_max_tlds = 10,
+ }
+
+ res_config = lua_util.override_defaults(res_config, opts_section)
+
+ -- Postprocess classify_headers
+ local classify_headers_parsed = {}
+
+ for _, v in ipairs(res_config.classify_headers) do
+ local s1, s2 = v:match("^([A-Z])[^%-]+%-([A-Z]).*$")
+
+ local hname
+ if s1 and s2 then
+ hname = string.format('%s-%s', s1, s2)
+ else
+ s1 = v:match("^X%-([A-Z].*)$")
+
+ if s1 then
+ hname = string.format('x%s', s1:sub(1, 3):lower())
+ else
+ hname = string.format('%s', v:sub(1, 3):lower())
+ end
+ end
+
+ if classify_headers_parsed[hname] then
+ table.insert(classify_headers_parsed[hname], v)
+ else
+ classify_headers_parsed[hname] = { v }
+ end
+ end
+
+ res_config.classify_headers_parsed = classify_headers_parsed
+
+ return res_config
+end
+
+local function get_mime_stat_tokens(task, res, i)
+ local parts = task:get_parts() or {}
+ local seen_multipart = false
+ local seen_plain = false
+ local seen_html = false
+ local empty_plain = false
+ local empty_html = false
+ local online_text = false
+
+ for _, part in ipairs(parts) do
+ local fname = part:get_filename()
+
+ local sz = part:get_length()
+
+ if sz > 0 then
+ rawset(res, i, string.format("#ps:%d",
+ math.floor(math.log(sz))))
+ lua_util.debugm("bayes", task, "part size: %s",
+ res[i])
+ i = i + 1
+ end
+
+ if fname then
+ rawset(res, i, "#f:" .. fname)
+ i = i + 1
+
+ lua_util.debugm("bayes", task, "added attachment: #f:%s",
+ fname)
+ end
+
+ if part:is_text() then
+ local tp = part:get_text()
+
+ if tp:is_html() then
+ seen_html = true
+
+ if tp:get_length() == 0 then
+ empty_html = true
+ end
+ else
+ seen_plain = true
+
+ if tp:get_length() == 0 then
+ empty_plain = true
+ end
+ end
+
+ if tp:get_lines_count() < 2 then
+ online_text = true
+ end
+
+ rawset(res, i, "#lang:" .. (tp:get_language() or 'unk'))
+ lua_util.debugm("bayes", task, "added language: %s",
+ res[i])
+ i = i + 1
+
+ rawset(res, i, "#cs:" .. (tp:get_charset() or 'unk'))
+ lua_util.debugm("bayes", task, "added charset: %s",
+ res[i])
+ i = i + 1
+
+ elseif part:is_multipart() then
+ seen_multipart = true;
+ end
+ end
+
+ -- Create a special token depending on parts structure
+ local st_tok = "#unk"
+ if seen_multipart and seen_html and seen_plain then
+ st_tok = '#mpth'
+ end
+
+ if seen_html and not seen_plain then
+ st_tok = "#ho"
+ end
+
+ if seen_plain and not seen_html then
+ st_tok = "#to"
+ end
+
+ local spec_tok = ""
+ if online_text then
+ spec_tok = "#ot"
+ end
+
+ if empty_plain then
+ spec_tok = spec_tok .. "#ep"
+ end
+
+ if empty_html then
+ spec_tok = spec_tok .. "#eh"
+ end
+
+ rawset(res, i, string.format("#m:%s%s", st_tok, spec_tok))
+ lua_util.debugm("bayes", task, "added mime token: %s",
+ res[i])
+ i = i + 1
+
+ return i
+end
+
+local function get_headers_stat_tokens(task, cf, res, i)
+ --[[
+ -- As discussed with Alexander Moisseev, this feature can skew statistics
+ -- especially when learning is separated from scanning, so learning
+ -- has a different set of tokens where this token can have too high weight
+ local hdrs_cksum = task:get_mempool():get_variable("headers_hash")
+
+ if hdrs_cksum then
+ rawset(res, i, string.format("#hh:%s", hdrs_cksum:sub(1, 7)))
+ lua_util.debugm("bayes", task, "added hdrs hash token: %s",
+ res[i])
+ i = i + 1
+ end
+ ]]--
+
+ for k, hdrs in pairs(cf.classify_headers_parsed) do
+ for _, hname in ipairs(hdrs) do
+ local value = task:get_header(hname)
+
+ if value then
+ rawset(res, i, string.format("#h:%s:%s", k, value))
+ lua_util.debugm("bayes", task, "added hdrs token: %s",
+ res[i])
+ i = i + 1
+ end
+ end
+ end
+
+ local from = (task:get_from('mime') or {})[1]
+
+ if from and from.name then
+ rawset(res, i, string.format("#F:%s", from.name))
+ lua_util.debugm("bayes", task, "added from name token: %s",
+ res[i])
+ i = i + 1
+ end
+
+ return i
+end
+
+local function get_meta_stat_tokens(task, res, i)
+ local day_and_hour = os.date('%u:%H',
+ task:get_date { format = 'message', gmt = true })
+ rawset(res, i, string.format("#dt:%s", day_and_hour))
+ lua_util.debugm("bayes", task, "added day_of_week token: %s",
+ res[i])
+ i = i + 1
+
+ local pol = {}
+
+ -- Authentication results
+ if task:has_symbol('DKIM_TRACE') then
+ -- Autolearn or scan
+ if task:has_symbol('R_SPF_ALLOW') then
+ table.insert(pol, 's=pass')
+ end
+
+ local trace = task:get_symbol('DKIM_TRACE')
+ local dkim_opts = trace[1]['options']
+ if dkim_opts then
+ for _, o in ipairs(dkim_opts) do
+ local check_res = string.sub(o, -1)
+ local domain = string.sub(o, 1, -3)
+
+ if check_res == '+' then
+ table.insert(pol, string.format('d=%s:%s', "pass", domain))
+ end
+ end
+ end
+ else
+ -- Offline learn
+ local aur = task:get_header('Authentication-Results')
+
+ if aur then
+ local spf = aur:match('spf=([a-z]+)')
+ local dkim, dkim_domain = aur:match('dkim=([a-z]+) header.d=([a-z.%-]+)')
+
+ if spf then
+ table.insert(pol, 's=' .. spf)
+ end
+ if dkim and dkim_domain then
+ table.insert(pol, string.format('d=%s:%s', dkim, dkim_domain))
+ end
+ end
+ end
+
+ if #pol > 0 then
+ rawset(res, i, string.format("#aur:%s", table.concat(pol, ',')))
+ lua_util.debugm("bayes", task, "added policies token: %s",
+ res[i])
+ i = i + 1
+ end
+
+ --[[
+ -- Disabled.
+ -- 1. Depending on the source the message has a different set of Received
+ -- headers as the receiving MTA adds another Received header.
+ -- 2. The usefulness of the Received tokens is questionable.
+ local rh = task:get_received_headers()
+
+ if rh and #rh > 0 then
+ local lim = math.min(5, #rh)
+ for j =1,lim do
+ local rcvd = rh[j]
+ local ip = rcvd.real_ip
+ if ip and ip:is_valid() and ip:get_version() == 4 then
+ local masked = ip:apply_mask(24)
+
+ rawset(res, i, string.format("#rcv:%s:%s", tostring(masked),
+ rcvd.proto))
+ lua_util.debugm("bayes", task, "added received token: %s",
+ res[i])
+ i = i + 1
+ end
+ end
+ end
+ ]]--
+
+ return i
+end
+
+local function get_stat_tokens(task, cf)
+ local res = {}
+ local E = {}
+ local i = 1
+
+ if cf.classify_images then
+ local images = task:get_images() or E
+
+ for _, img in ipairs(images) do
+ rawset(res, i, "image")
+ i = i + 1
+ rawset(res, i, tostring(img:get_height()))
+ i = i + 1
+ rawset(res, i, tostring(img:get_width()))
+ i = i + 1
+ rawset(res, i, tostring(img:get_type()))
+ i = i + 1
+
+ local fname = img:get_filename()
+
+ if fname then
+ rawset(res, i, tostring(img:get_filename()))
+ i = i + 1
+ end
+
+ lua_util.debugm("bayes", task, "added image: %s",
+ fname)
+ end
+ end
+
+ if cf.classify_mime_info then
+ i = get_mime_stat_tokens(task, res, i)
+ end
+
+ if cf.classify_headers and #cf.classify_headers > 0 then
+ i = get_headers_stat_tokens(task, cf, res, i)
+ end
+
+ if cf.classify_urls then
+ local urls = lua_util.extract_specific_urls { task = task, limit = 5, esld_limit = 1 }
+
+ if urls then
+ for _, u in ipairs(urls) do
+ rawset(res, i, string.format("#u:%s", u:get_tld()))
+ lua_util.debugm("bayes", task, "added url token: %s",
+ res[i])
+ i = i + 1
+ end
+ end
+ end
+
+ if cf.classify_meta then
+ i = get_meta_stat_tokens(task, res, i)
+ end
+
+ return res
+end
+
+exports.gen_stat_tokens = function(cfg)
+ local stat_config = process_stat_config(cfg)
+
+ return function(task)
+ return get_stat_tokens(task, stat_config)
+ end
+end
+
+return exports
diff --git a/lualib/lua_tcp_sync.lua b/lualib/lua_tcp_sync.lua
new file mode 100644
index 0000000..f8e6044
--- /dev/null
+++ b/lualib/lua_tcp_sync.lua
@@ -0,0 +1,213 @@
+local rspamd_tcp = require "rspamd_tcp"
+local lua_util = require "lua_util"
+
+local exports = {}
+local N = 'tcp_sync'
+
+local tcp_sync = { _conn = nil, _data = '', _eof = false, _addr = '' }
+local metatable = {
+ __tostring = function(self)
+ return "class {tcp_sync connect to: " .. self._addr .. "}"
+ end
+}
+
+function tcp_sync.new(connection)
+ local self = {}
+
+ for name, method in pairs(tcp_sync) do
+ if name ~= 'new' then
+ self[name] = method
+ end
+ end
+
+ self._conn = connection
+
+ setmetatable(self, metatable)
+
+ return self
+end
+
+--[[[
+-- @method tcp_sync.read_once()
+--
+-- Acts exactly like low-level tcp_sync.read_once()
+-- the only exception is that if there is some pending data,
+-- it's returned immediately and no underlying call is performed
+--
+-- @return
+-- true, {data} if everything is fine
+-- false, {error message} otherwise
+--
+--]]
+function tcp_sync:read_once()
+ local is_ok, data
+ if self._data:len() > 0 then
+ data = self._data
+ self._data = nil
+ return true, data
+ end
+
+ is_ok, data = self._conn:read_once()
+
+ return is_ok, data
+end
+
+--[[[
+-- @method tcp_sync.read_until(pattern)
+--
+-- Reads data from the connection until pattern is found
+-- returns all bytes before the pattern
+--
+-- @param {pattern} Read data until pattern is found
+-- @return
+-- true, {data} if everything is fine
+-- false, {error message} otherwise
+-- @example
+--
+--]]
+function tcp_sync:read_until(pattern)
+ repeat
+ local pos_start, pos_end = self._data:find(pattern, 1, true)
+ if pos_start then
+ local data = self._data:sub(1, pos_start - 1)
+ self._data = self._data:sub(pos_end + 1)
+ return true, data
+ end
+
+ local is_ok, more_data = self._conn:read_once()
+ if not is_ok then
+ return is_ok, more_data
+ end
+
+ self._data = self._data .. more_data
+ until false
+end
+
+--[[[
+-- @method tcp_sync.read_bytes(n)
+--
+-- Reads {n} bytes from the stream
+--
+-- @param {n} Number of bytes to read
+-- @return
+-- true, {data} if everything is fine
+-- false, {error message} otherwise
+--
+--]]
+function tcp_sync:read_bytes(n)
+ repeat
+ if self._data:len() >= n then
+ local data = self._data:sub(1, n)
+ self._data = self._data:sub(n + 1)
+ return true, data
+ end
+
+ local is_ok, more_data = self._conn:read_once()
+ if not is_ok then
+ return is_ok, more_data
+ end
+
+ self._data = self._data .. more_data
+ until false
+end
+
+--[[[
+-- @method tcp_sync.read_until_eof(n)
+--
+-- Reads stream until EOF is reached
+--
+-- @return
+-- true, {data} if everything is fine
+-- false, {error message} otherwise
+--
+--]]
+function tcp_sync:read_until_eof()
+ while not self:eof() do
+ local is_ok, more_data = self._conn:read_once()
+ if not is_ok then
+ if self:eof() then
+ -- this error is EOF (connection terminated)
+ -- exactly what we were waiting for
+ break
+ end
+ return is_ok, more_data
+ end
+ self._data = self._data .. more_data
+ end
+
+ local data = self._data
+ self._data = ''
+ return true, data
+end
+
+--[[[
+-- @method tcp_sync.write(n)
+--
+-- Writes data into the stream.
+--
+-- @return
+-- true if everything is fine
+-- false, {error message} otherwise
+--
+--]]
+function tcp_sync:write(data)
+ return self._conn:write(data)
+end
+
+--[[[
+-- @method tcp_sync.close()
+--
+-- Closes the connection. If the connection was created with task,
+-- this method is called automatically as soon as the task is done
+-- Calling this method helps to prevent connections leak.
+-- The object is finally destroyed by garbage collector.
+--
+-- @return
+--
+--]]
+function tcp_sync:close()
+ return self._conn:close()
+end
+
+--[[[
+-- @method tcp_sync.eof()
+--
+-- @return
+-- true if last "read" operation ended with EOF
+-- false otherwise
+--
+--]]
+function tcp_sync:eof()
+ if not self._eof and self._conn:eof() then
+ self._eof = true
+ end
+ return self._eof
+end
+
+--[[[
+-- @function tcp_sync.shutdown(n)
+--
+-- half-close socket
+--
+-- @return
+--
+--]]
+function tcp_sync:shutdown()
+ return self._conn:shutdown()
+end
+
+exports.connect = function(args)
+ local is_ok, connection = rspamd_tcp.connect_sync(args)
+ if not is_ok then
+ return is_ok, connection
+ end
+
+ local instance = tcp_sync.new(connection)
+ instance._addr = string.format("%s:%s", tostring(args.host), tostring(args.port))
+
+ lua_util.debugm(N, args.task, 'Connected to %s', instance._addr)
+
+ return true, instance
+end
+
+return exports \ No newline at end of file
diff --git a/lualib/lua_urls_compose.lua b/lualib/lua_urls_compose.lua
new file mode 100644
index 0000000..1113421
--- /dev/null
+++ b/lualib/lua_urls_compose.lua
@@ -0,0 +1,286 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_urls_compose
+-- This module contains functions to compose urls queries from hostname
+-- to TLD part
+--]]
+
+local N = "lua_urls_compose"
+local lua_util = require "lua_util"
+local rspamd_util = require "rspamd_util"
+local bit = require "bit"
+local rspamd_trie = require "rspamd_trie"
+local fun = require "fun"
+local rspamd_regexp = require "rspamd_regexp"
+
+local maps_cache = {}
+
+local exports = {}
+
+local function process_url(self, log_obj, url_tld, url_host)
+ local tld_elt = self.tlds[url_tld]
+
+ if tld_elt then
+ lua_util.debugm(N, log_obj, 'found compose tld for %s (host = %s)',
+ url_tld, url_host)
+
+ for _, excl in ipairs(tld_elt.except_rules) do
+ local matched, ret = excl[2](url_tld, url_host)
+ if matched then
+ lua_util.debugm(N, log_obj, 'found compose exclusion for %s (%s) -> %s',
+ url_host, excl[1], ret)
+
+ return ret
+ end
+ end
+
+ if tld_elt.multipattern_compose_rules then
+ local matches = tld_elt.multipattern_compose_rules:match(url_host)
+
+ if matches then
+ local lua_pat_idx = math.huge
+
+ for m, _ in pairs(matches) do
+ if m < lua_pat_idx then
+ lua_pat_idx = m
+ end
+ end
+
+ if #tld_elt.compose_rules >= lua_pat_idx then
+ local lua_pat = tld_elt.compose_rules[lua_pat_idx]
+ local matched, ret = lua_pat[2](url_tld, url_host)
+
+ if not matched then
+ lua_util.debugm(N, log_obj, 'NOT found compose inclusion for %s (%s) -> %s',
+ url_host, lua_pat[1], url_tld)
+
+ return url_tld
+ else
+ lua_util.debugm(N, log_obj, 'found compose inclusion for %s (%s) -> %s',
+ url_host, lua_pat[1], ret)
+
+ return ret
+ end
+ else
+ lua_util.debugm(N, log_obj, 'NOT found compose inclusion for %s (%s) -> %s',
+ url_host, lua_pat_idx, url_tld)
+
+ return url_tld
+ end
+ end
+ else
+ -- Match one by one
+ for _, lua_pat in ipairs(tld_elt.compose_rules) do
+ local matched, ret = lua_pat[2](url_tld, url_host)
+ if matched then
+ lua_util.debugm(N, log_obj, 'found compose inclusion for %s (%s) -> %s',
+ url_host, lua_pat[1], ret)
+
+ return ret
+ end
+ end
+ end
+
+ lua_util.debugm(N, log_obj, 'not found compose inclusion for %s in %s -> %s',
+ url_host, url_tld, url_tld)
+ else
+ lua_util.debugm(N, log_obj, 'not found compose tld for %s in %s -> %s',
+ url_host, url_tld, url_tld)
+ end
+
+ return url_tld
+end
+
+local function tld_pattern_transform(tld_pat)
+ -- Convert tld like pattern to a lua match pattern
+ -- blah -> %.blah
+ -- *.blah -> .*%.blah
+ local ret
+ if tld_pat:sub(1, 2) == '*.' then
+ ret = string.format('^((?:[^.]+\\.)*%s)$', tld_pat:sub(3))
+ else
+ ret = string.format('(?:^|\\.)((?:[^.]+\\.)?%s)$', tld_pat)
+ end
+
+ lua_util.debugm(N, nil, 'added pattern %s -> %s',
+ tld_pat, ret)
+
+ return ret
+end
+
+local function include_elt_gen(pat)
+ pat = rspamd_regexp.create(tld_pattern_transform(pat), 'i')
+ return function(_, host)
+ local matches = pat:search(host, false, true)
+ if matches then
+ return true, matches[1][2]
+ end
+
+ return false
+ end
+end
+
+local function exclude_elt_gen(pat)
+ pat = rspamd_regexp.create(tld_pattern_transform(pat))
+ return function(tld, host)
+ if pat:search(host) then
+ return true, tld
+ end
+
+ return false
+ end
+end
+
+local function compose_map_cb(self, map_text)
+ local lpeg = require "lpeg"
+
+ local singleline_comment = lpeg.P '#' * (1 - lpeg.S '\r\n\f') ^ 0
+ local comments_strip_grammar = lpeg.C((1 - lpeg.P '#') ^ 1) * lpeg.S(' \t') ^ 0 * singleline_comment ^ 0
+
+ local function process_tld_rule(tld_elt, l)
+ if l:sub(1, 1) == '!' then
+ -- Exclusion elt
+ table.insert(tld_elt.except_rules, { l, exclude_elt_gen(l:sub(2)) })
+ else
+ table.insert(tld_elt.compose_rules, { l, include_elt_gen(l) })
+ end
+ end
+
+ local function process_map_line(l)
+ -- Skip empty lines and comments
+ if #l == 0 then
+ return
+ end
+ l = comments_strip_grammar:match(l)
+ if not l or #l == 0 then
+ return
+ end
+
+ -- Get TLD
+ local tld = rspamd_util.get_tld(l)
+
+ if tld then
+ local tld_elt = self.tlds[tld]
+
+ if not tld_elt then
+ tld_elt = {
+ compose_rules = {},
+ except_rules = {},
+ multipattern_compose_rules = nil
+ }
+
+ lua_util.debugm(N, rspamd_config, 'processed new tld rule for %s', tld)
+ self.tlds[tld] = tld_elt
+ end
+
+ process_tld_rule(tld_elt, l)
+ else
+ lua_util.debugm(N, rspamd_config, 'cannot read tld from compose map line: %s', l)
+ end
+ end
+
+ for line in map_text:lines() do
+ process_map_line(line)
+ end
+
+ local multipattern_threshold = 1
+ for tld, tld_elt in pairs(self.tlds) do
+ -- Sort patterns to have longest labels before shortest ones,
+ -- so we can ensure that they match before
+ table.sort(tld_elt.compose_rules, function(e1, e2)
+ local _, ndots1 = string.gsub(e1[1], '(%.)', '')
+ local _, ndots2 = string.gsub(e2[1], '(%.)', '')
+
+ return ndots1 > ndots2
+ end)
+ if rspamd_trie.has_hyperscan() and #tld_elt.compose_rules >= multipattern_threshold then
+ lua_util.debugm(N, rspamd_config, 'tld %s has %s rules, apply multipattern',
+ tld, #tld_elt.compose_rules)
+ local flags = bit.bor(rspamd_trie.flags.re,
+ rspamd_trie.flags.dot_all,
+ rspamd_trie.flags.no_start,
+ rspamd_trie.flags.icase)
+
+
+ -- We now convert our internal patterns to multipattern patterns
+ local mp_table = fun.totable(fun.map(function(pat_elt)
+ return tld_pattern_transform(pat_elt[1])
+ end, tld_elt.compose_rules))
+ tld_elt.multipattern_compose_rules = rspamd_trie.create(mp_table, flags)
+ end
+ end
+end
+
+exports.add_composition_map = function(cfg, map_obj)
+ local hash_key = map_obj
+ if type(map_obj) == 'table' then
+ hash_key = lua_util.table_digest(map_obj)
+ end
+
+ local map = maps_cache[hash_key]
+
+ if not map then
+ local ret = {
+ process_url = process_url,
+ hash = hash_key,
+ tlds = {},
+ }
+
+ map = cfg:add_map {
+ type = 'callback',
+ description = 'URL compose map',
+ url = map_obj,
+ callback = function(input)
+ compose_map_cb(ret, input)
+ end,
+ opaque_data = true,
+ }
+
+ ret.map = map
+ maps_cache[hash_key] = ret
+ map = ret
+ end
+
+ return map
+end
+
+exports.inject_composition_rules = function(cfg, rules)
+ local hash_key = rules
+ local rspamd_text = require "rspamd_text"
+ if type(rules) == 'table' then
+ hash_key = lua_util.table_digest(rules)
+ end
+
+ local map = maps_cache[hash_key]
+
+ if not map then
+ local ret = {
+ process_url = process_url,
+ hash = hash_key,
+ tlds = {},
+ }
+
+ compose_map_cb(ret, rspamd_text.fromtable(rules, '\n'))
+ maps_cache[hash_key] = ret
+ map = ret
+ end
+
+ return map
+end
+
+return exports \ No newline at end of file
diff --git a/lualib/lua_util.lua b/lualib/lua_util.lua
new file mode 100644
index 0000000..6964b0f
--- /dev/null
+++ b/lualib/lua_util.lua
@@ -0,0 +1,1639 @@
+--[[
+Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_util
+-- This module contains utility functions for working with Lua and/or Rspamd
+--]]
+
+local exports = {}
+local lpeg = require 'lpeg'
+local rspamd_util = require "rspamd_util"
+local fun = require "fun"
+local lupa = require "lupa"
+
+local split_grammar = {}
+local spaces_split_grammar
+local space = lpeg.S ' \t\n\v\f\r'
+local nospace = 1 - space
+local ptrim = space ^ 0 * lpeg.C((space ^ 0 * nospace ^ 1) ^ 0)
+local match = lpeg.match
+
+lupa.configure('{%', '%}', '{=', '=}', '{#', '#}', {
+ keep_trailing_newline = true,
+ autoescape = false,
+})
+
+lupa.filters.pbkdf = function(s)
+ local cr = require "rspamd_cryptobox"
+ return cr.pbkdf(s)
+end
+
+local function rspamd_str_split(s, sep)
+ local gr
+ if not sep then
+ if not spaces_split_grammar then
+ local _sep = space
+ local elem = lpeg.C((1 - _sep) ^ 0)
+ local p = lpeg.Ct(elem * (_sep * elem) ^ 0)
+ spaces_split_grammar = p
+ end
+
+ gr = spaces_split_grammar
+ else
+ gr = split_grammar[sep]
+
+ if not gr then
+ local _sep
+ if type(sep) == 'string' then
+ _sep = lpeg.S(sep) -- Assume set
+ else
+ _sep = sep -- Assume lpeg object
+ end
+ local elem = lpeg.C((1 - _sep) ^ 0)
+ local p = lpeg.Ct(elem * (_sep * elem) ^ 0)
+ gr = p
+ split_grammar[sep] = gr
+ end
+ end
+
+ return gr:match(s)
+end
+
+--[[[
+-- @function lua_util.str_split(text, delimiter)
+-- Splits text into a numeric table by delimiter
+-- @param {string} text delimited text
+-- @param {string} delimiter the delimiter
+-- @return {table} numeric table containing string parts
+--]]
+
+exports.rspamd_str_split = rspamd_str_split
+exports.str_split = rspamd_str_split
+
+local function rspamd_str_trim(s)
+ return match(ptrim, s)
+end
+exports.rspamd_str_trim = rspamd_str_trim
+--[[[
+-- @function lua_util.str_trim(text)
+-- Returns a string with no trailing and leading spaces
+-- @param {string} text input text
+-- @return {string} string with no trailing and leading spaces
+--]]
+exports.str_trim = rspamd_str_trim
+
+--[[[
+-- @function lua_util.str_startswith(text, prefix)
+-- @param {string} text
+-- @param {string} prefix
+-- @return {boolean} true if text starts with the specified prefix, false otherwise
+--]]
+exports.str_startswith = function(s, prefix)
+ return s:sub(1, prefix:len()) == prefix
+end
+
+--[[[
+-- @function lua_util.str_endswith(text, suffix)
+-- @param {string} text
+-- @param {string} suffix
+-- @return {boolean} true if text ends with the specified suffix, false otherwise
+--]]
+exports.str_endswith = function(s, suffix)
+ return s:find(suffix, -suffix:len(), true) ~= nil
+end
+
+--[[[
+-- @function lua_util.round(number, decimalPlaces)
+-- Round number to fixed number of decimal points
+-- @param {number} number number to round
+-- @param {number} decimalPlaces number of decimal points
+-- @return {number} rounded number
+--]]
+
+-- modified version from Robert Jay Gould http://lua-users.org/wiki/SimpleRound
+exports.round = function(num, numDecimalPlaces)
+ local mult = 10 ^ (numDecimalPlaces or 0)
+ if num >= 0 then
+ return math.floor(num * mult + 0.5) / mult
+ else
+ return math.ceil(num * mult - 0.5) / mult
+ end
+end
+
+--[[[
+-- @function lua_util.template(text, replacements)
+-- Replaces values in a text template
+-- Variable names can contain letters, numbers and underscores, are prefixed with `$` and may or not use curly braces.
+-- @param {string} text text containing variables
+-- @param {table} replacements key/value pairs for replacements
+-- @return {string} string containing replaced values
+-- @example
+-- local goop = lua_util.template("HELLO $FOO ${BAR}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
+-- -- goop contains "HELLO LUA WORLD!"
+--]]
+
+exports.template = function(tmpl, keys)
+ local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" }
+ local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit ^ 1) / keys) }
+ local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit ^ 1) / keys) * (lpeg.P("}") / "") }
+
+ local template_grammar = lpeg.Cs((var + var_braced + 1) ^ 0)
+
+ return lpeg.match(template_grammar, tmpl)
+end
+
+local function enrich_template_with_globals(env)
+ local newenv = exports.shallowcopy(env)
+ newenv.paths = rspamd_paths
+ newenv.env = rspamd_env
+
+ return newenv
+end
+--[[[
+-- @function lua_util.jinja_template(text, env[, skip_global_env])
+-- Replaces values in a text template according to jinja2 syntax
+-- @param {string} text text containing variables
+-- @param {table} replacements key/value pairs for replacements
+-- @param {boolean} skip_global_env don't export Rspamd superglobals
+-- @return {string} string containing replaced values
+-- @example
+-- lua_util.jinja_template("HELLO {{FOO}} {{BAR}}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
+-- "HELLO LUA WORLD!"
+--]]
+exports.jinja_template = function(text, env, skip_global_env)
+ if not skip_global_env then
+ env = enrich_template_with_globals(env)
+ end
+
+ return lupa.expand(text, env)
+end
+
+--[[[
+-- @function lua_util.jinja_file(filename, env[, skip_global_env])
+-- Replaces values in a text template according to jinja2 syntax
+-- @param {string} filename name of file to expand
+-- @param {table} replacements key/value pairs for replacements
+-- @param {boolean} skip_global_env don't export Rspamd superglobals
+-- @return {string} string containing replaced values
+-- @example
+-- lua_util.jinja_template("HELLO {{FOO}} {{BAR}}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'})
+-- "HELLO LUA WORLD!"
+--]]
+exports.jinja_template_file = function(filename, env, skip_global_env)
+ if not skip_global_env then
+ env = enrich_template_with_globals(env)
+ end
+
+ return lupa.expand_file(filename, env)
+end
+
+exports.remove_email_aliases = function(email_addr)
+ local function check_gmail_user(addr)
+ -- Remove all points
+ local no_dots_user = string.gsub(addr.user, '%.', '')
+ local cap, pluses = string.match(no_dots_user, '^([^%+][^%+]*)(%+.*)$')
+ if cap then
+ return cap, rspamd_str_split(pluses, '+'), nil
+ elseif no_dots_user ~= addr.user then
+ return no_dots_user, {}, nil
+ end
+
+ return nil
+ end
+
+ local function check_address(addr)
+ if addr.user then
+ local cap, pluses = string.match(addr.user, '^([^%+][^%+]*)(%+.*)$')
+ if cap then
+ return cap, rspamd_str_split(pluses, '+'), nil
+ end
+ end
+
+ return nil
+ end
+
+ local function set_addr(addr, new_user, new_domain)
+ if new_user then
+ addr.user = new_user
+ end
+ if new_domain then
+ addr.domain = new_domain
+ end
+
+ if addr.domain then
+ addr.addr = string.format('%s@%s', addr.user, addr.domain)
+ else
+ addr.addr = string.format('%s@', addr.user)
+ end
+
+ if addr.name and #addr.name > 0 then
+ addr.raw = string.format('"%s" <%s>', addr.name, addr.addr)
+ else
+ addr.raw = string.format('<%s>', addr.addr)
+ end
+ end
+
+ local function check_gmail(addr)
+ local nu, tags, nd = check_gmail_user(addr)
+
+ if nu then
+ return nu, tags, nd
+ end
+
+ return nil
+ end
+
+ local function check_googlemail(addr)
+ local nd = 'gmail.com'
+ local nu, tags = check_gmail_user(addr)
+
+ if nu then
+ return nu, tags, nd
+ end
+
+ return nil, nil, nd
+ end
+
+ local specific_domains = {
+ ['gmail.com'] = check_gmail,
+ ['googlemail.com'] = check_googlemail,
+ }
+
+ if email_addr then
+ if email_addr.domain and specific_domains[email_addr.domain] then
+ local nu, tags, nd = specific_domains[email_addr.domain](email_addr)
+ if nu or nd then
+ set_addr(email_addr, nu, nd)
+
+ return nu, tags
+ end
+ else
+ local nu, tags, nd = check_address(email_addr)
+ if nu or nd then
+ set_addr(email_addr, nu, nd)
+
+ return nu, tags
+ end
+ end
+
+ return nil
+ end
+end
+
+exports.is_rspamc_or_controller = function(task)
+ local ua = task:get_request_header('User-Agent') or ''
+ local pwd = task:get_request_header('Password')
+ local is_rspamc = false
+ if tostring(ua) == 'rspamc' or pwd then
+ is_rspamc = true
+ end
+
+ return is_rspamc
+end
+
+--[[[
+-- @function lua_util.unpack(table)
+-- Converts numeric table to varargs
+-- This is `unpack` on Lua 5.1/5.2/LuaJIT and `table.unpack` on Lua 5.3
+-- @param {table} table numerically indexed table to unpack
+-- @return {varargs} unpacked table elements
+--]]
+
+local unpack_function = table.unpack or unpack
+exports.unpack = function(t)
+ return unpack_function(t)
+end
+
+--[[[
+-- @function lua_util.flatten(table)
+-- Flatten underlying tables in a single table
+-- @param {table} table table of tables
+-- @return {table} flattened table
+--]]
+exports.flatten = function(t)
+ local res = {}
+ for _, e in fun.iter(t) do
+ for _, v in fun.iter(e) do
+ res[#res + 1] = v
+ end
+ end
+
+ return res
+end
+
+--[[[
+-- @function lua_util.spairs(table)
+-- Like `pairs` but keys are sorted lexicographically
+-- @param {table} table table containing key/value pairs
+-- @return {function} generator function returning key/value pairs
+--]]
+
+-- Sorted iteration:
+-- for k,v in spairs(t) do ... end
+--
+-- or with custom comparison:
+-- for k, v in spairs(t, function(t, a, b) return t[a] < t[b] end)
+--
+-- optional limit is also available (e.g. return top X elements)
+local function spairs(t, order, lim)
+ -- collect the keys
+ local keys = {}
+ for k in pairs(t) do
+ keys[#keys + 1] = k
+ end
+
+ -- if order function given, sort by it by passing the table and keys a, b,
+ -- otherwise just sort the keys
+ if order then
+ table.sort(keys, function(a, b)
+ return order(t, a, b)
+ end)
+ else
+ table.sort(keys)
+ end
+
+ -- return the iterator function
+ local i = 0
+ return function()
+ i = i + 1
+ if not lim or i <= lim then
+ if keys[i] then
+ return keys[i], t[keys[i]]
+ end
+ end
+ end
+end
+
+exports.spairs = spairs
+
+local lua_cfg_utils = require "lua_cfg_utils"
+
+exports.config_utils = lua_cfg_utils
+exports.disable_module = lua_cfg_utils.disable_module
+
+--[[[
+-- @function lua_util.disable_module(modname)
+-- Checks experimental plugins state and disable if needed
+-- @param {string} modname name of plugin to check
+-- @return {boolean} true if plugin should be enabled, false otherwise
+--]]
+local function check_experimental(modname)
+ if rspamd_config:experimental_enabled() then
+ return true
+ else
+ lua_cfg_utils.disable_module(modname, 'experimental')
+ end
+
+ return false
+end
+
+exports.check_experimental = check_experimental
+
+--[[[
+-- @function lua_util.list_to_hash(list)
+-- Converts numerically-indexed table to table indexed by values
+-- @param {table} list numerically-indexed table or string, which is treated as a one-element list
+-- @return {table} table indexed by values
+-- @example
+-- local h = lua_util.list_to_hash({"a", "b"})
+-- -- h contains {a = true, b = true}
+--]]
+local function list_to_hash(list)
+ if type(list) == 'table' then
+ if list[1] then
+ local h = {}
+ for _, e in ipairs(list) do
+ h[e] = true
+ end
+ return h
+ else
+ return list
+ end
+ elseif type(list) == 'string' then
+ local h = {}
+ h[list] = true
+ return h
+ end
+end
+
+exports.list_to_hash = list_to_hash
+
+--[[[
+-- @function lua_util.nkeys(table|gen, param, state)
+-- Returns number of keys in a table (i.e. from both the array and hash parts combined)
+-- @param {table} list numerically-indexed table or string, which is treated as a one-element list
+-- @return {number} number of keys
+-- @example
+-- print(lua_util.nkeys({})) -- 0
+-- print(lua_util.nkeys({ "a", nil, "b" })) -- 2
+-- print(lua_util.nkeys({ dog = 3, cat = 4, bird = nil })) -- 2
+-- print(lua_util.nkeys({ "a", dog = 3, cat = 4 })) -- 3
+--
+--]]
+local function nkeys(gen, param, state)
+ local n = 0
+ if not param then
+ for _, _ in pairs(gen) do
+ n = n + 1
+ end
+ else
+ for _, _ in fun.iter(gen, param, state) do
+ n = n + 1
+ end
+ end
+ return n
+end
+
+exports.nkeys = nkeys
+
+--[[[
+-- @function lua_util.parse_time_interval(str)
+-- Parses human readable time interval
+-- Accepts 's' for seconds, 'm' for minutes, 'h' for hours, 'd' for days,
+-- 'w' for weeks, 'y' for years
+-- @param {string} str input string
+-- @return {number|nil} parsed interval as seconds (might be fractional)
+--]]
+local function parse_time_interval(str)
+ local function parse_time_suffix(s)
+ if s == 's' then
+ return 1
+ elseif s == 'm' then
+ return 60
+ elseif s == 'h' then
+ return 3600
+ elseif s == 'd' then
+ return 86400
+ elseif s == 'w' then
+ return 86400 * 7
+ elseif s == 'y' then
+ return 365 * 86400;
+ end
+ end
+
+ local digit = lpeg.R("09")
+ local parser = {}
+ parser.integer = (lpeg.S("+-") ^ -1) *
+ (digit ^ 1)
+ parser.fractional = (lpeg.P(".")) *
+ (digit ^ 1)
+ parser.number = (parser.integer *
+ (parser.fractional ^ -1)) +
+ (lpeg.S("+-") * parser.fractional)
+ parser.time = lpeg.Cf(lpeg.Cc(1) *
+ (parser.number / tonumber) *
+ ((lpeg.S("smhdwy") / parse_time_suffix) ^ -1),
+ function(acc, val)
+ return acc * val
+ end)
+
+ local t = lpeg.match(parser.time, str)
+
+ return t
+end
+
+exports.parse_time_interval = parse_time_interval
+
+--[[[
+-- @function lua_util.dehumanize_number(str)
+-- Parses human readable number
+-- Accepts 'k' for thousands, 'm' for millions, 'g' for billions, 'b' suffix for 1024 multiplier,
+-- e.g. `10mb` equal to `10 * 1024 * 1024`
+-- @param {string} str input string
+-- @return {number|nil} parsed number
+--]]
+local function dehumanize_number(str)
+ local function parse_suffix(s)
+ if s == 'k' then
+ return 1000
+ elseif s == 'm' then
+ return 1000000
+ elseif s == 'g' then
+ return 1e9
+ elseif s == 'kb' then
+ return 1024
+ elseif s == 'mb' then
+ return 1024 * 1024
+ elseif s == 'gb' then
+ return 1024 * 1024;
+ end
+ end
+
+ local digit = lpeg.R("09")
+ local parser = {}
+ parser.integer = (lpeg.S("+-") ^ -1) *
+ (digit ^ 1)
+ parser.fractional = (lpeg.P(".")) *
+ (digit ^ 1)
+ parser.number = (parser.integer *
+ (parser.fractional ^ -1)) +
+ (lpeg.S("+-") * parser.fractional)
+ parser.humanized_number = lpeg.Cf(lpeg.Cc(1) *
+ (parser.number / tonumber) *
+ (((lpeg.S("kmg") * (lpeg.P("b") ^ -1)) / parse_suffix) ^ -1),
+ function(acc, val)
+ return acc * val
+ end)
+
+ local t = lpeg.match(parser.humanized_number, str)
+
+ return t
+end
+
+exports.dehumanize_number = dehumanize_number
+
+--[[[
+-- @function lua_util.table_cmp(t1, t2)
+-- Compare two tables deeply
+--]]
+local function table_cmp(table1, table2)
+ local avoid_loops = {}
+ local function recurse(t1, t2)
+ if type(t1) ~= type(t2) then
+ return false
+ end
+ if type(t1) ~= "table" then
+ return t1 == t2
+ end
+
+ if avoid_loops[t1] then
+ return avoid_loops[t1] == t2
+ end
+ avoid_loops[t1] = t2
+ -- Copy keys from t2
+ local t2keys = {}
+ local t2tablekeys = {}
+ for k, _ in pairs(t2) do
+ if type(k) == "table" then
+ table.insert(t2tablekeys, k)
+ end
+ t2keys[k] = true
+ end
+ -- Let's iterate keys from t1
+ for k1, v1 in pairs(t1) do
+ local v2 = t2[k1]
+ if type(k1) == "table" then
+ -- if key is a table, we need to find an equivalent one.
+ local ok = false
+ for i, tk in ipairs(t2tablekeys) do
+ if table_cmp(k1, tk) and recurse(v1, t2[tk]) then
+ table.remove(t2tablekeys, i)
+ t2keys[tk] = nil
+ ok = true
+ break
+ end
+ end
+ if not ok then
+ return false
+ end
+ else
+ -- t1 has a key which t2 doesn't have, fail.
+ if v2 == nil then
+ return false
+ end
+ t2keys[k1] = nil
+ if not recurse(v1, v2) then
+ return false
+ end
+ end
+ end
+ -- if t2 has a key which t1 doesn't have, fail.
+ if next(t2keys) then
+ return false
+ end
+ return true
+ end
+ return recurse(table1, table2)
+end
+
+exports.table_cmp = table_cmp
+
+--[[[
+-- @function lua_util.table_merge(t1, t2)
+-- Merge two tables
+--]]
+local function table_merge(t1, t2)
+ local res = {}
+ local nidx = 1 -- for numeric indicies
+ local it_func = function(k, v)
+ if type(k) == 'number' then
+ res[nidx] = v
+ nidx = nidx + 1
+ else
+ res[k] = v
+ end
+ end
+ for k, v in pairs(t1) do
+ it_func(k, v)
+ end
+ for k, v in pairs(t2) do
+ it_func(k, v)
+ end
+ return res
+end
+
+exports.table_merge = table_merge
+
+--[[[
+-- @function lua_util.table_cmp(task, name, value, stop_chars)
+-- Performs header folding
+--]]
+exports.fold_header = function(task, name, value, stop_chars)
+
+ local how
+
+ if task:has_flag("milter") then
+ how = "lf"
+ else
+ how = task:get_newlines_type()
+ end
+
+ return rspamd_util.fold_header(name, value, how, stop_chars)
+end
+
+--[[[
+-- @function lua_util.override_defaults(defaults, override)
+-- Overrides values from defaults with override
+--]]
+local function override_defaults(def, override)
+ -- Corner cases
+ if not override or type(override) ~= 'table' then
+ return def
+ end
+ if not def or type(def) ~= 'table' then
+ return override
+ end
+
+ local res = {}
+
+ for k, v in pairs(override) do
+ if type(v) == 'table' then
+ if def[k] and type(def[k]) == 'table' then
+ -- Recursively override elements
+ res[k] = override_defaults(def[k], v)
+ else
+ res[k] = v
+ end
+ else
+ res[k] = v
+ end
+ end
+
+ for k, v in pairs(def) do
+ if type(res[k]) == 'nil' then
+ res[k] = v
+ end
+ end
+
+ return res
+end
+
+exports.override_defaults = override_defaults
+
+--[[[
+-- @function lua_util.filter_specific_urls(urls, params)
+-- params: {
+- - task - if needed to save in the cache
+- - limit <int> (default = 9999)
+- - esld_limit <int> (default = 9999) n domains per eSLD (effective second level domain)
+ works only if number of unique eSLD less than `limit`
+- - need_emails <bool> (default = false)
+- - filter <callback> (default = nil)
+- - prefix <string> cache prefix (default = nil)
+-- }
+-- Apply heuristic in extracting of urls from `urls` table, this function
+-- tries its best to extract specific number of urls from a task based on
+-- their characteristics
+--]]
+exports.filter_specific_urls = function(urls, params)
+ local cache_key
+
+ if params.task and not params.no_cache then
+ if params.prefix then
+ cache_key = params.prefix
+ else
+ cache_key = string.format('sp_urls_%d%s%s%s', params.limit,
+ tostring(params.need_emails or false),
+ tostring(params.need_images or false),
+ tostring(params.need_content or false))
+ end
+ local cached = params.task:cache_get(cache_key)
+
+ if cached then
+ return cached
+ end
+ end
+
+ if not urls then
+ return {}
+ end
+
+ if params.filter then
+ urls = fun.totable(fun.filter(params.filter, urls))
+ end
+
+ -- Filter by tld:
+ local tlds = {}
+ local eslds = {}
+ local ntlds, neslds = 0, 0
+
+ local res = {}
+ local nres = 0
+
+ local function insert_url(str, u)
+ if not res[str] then
+ res[str] = u
+ nres = nres + 1
+
+ return true
+ end
+
+ return false
+ end
+
+ local function process_single_url(u, default_priority)
+ local priority = default_priority or 1 -- Normal priority
+ local flags = u:get_flags()
+ if params.ignore_ip and flags.numeric then
+ return
+ end
+
+ if flags.redirected then
+ local redir = u:get_redirected() -- get the real url
+
+ if params.ignore_redirected then
+ -- Replace `u` with redir
+ u = redir
+ priority = 2
+ else
+ -- Process both redirected url and the original one
+ process_single_url(redir, 2)
+ end
+ end
+
+ if flags.image then
+ if not params.need_images then
+ -- Ignore url
+ return
+ else
+ -- Penalise images in urls
+ priority = 0
+ end
+ end
+
+ local esld = u:get_tld()
+ local str_hash = tostring(u)
+
+ if esld then
+ -- Special cases
+ if (u:get_protocol() ~= 'mailto') and (not flags.html_displayed) then
+ if flags.obscured then
+ priority = 3
+ else
+ if (flags.has_user or flags.has_port) then
+ priority = 2
+ elseif (flags.subject or flags.phished) then
+ priority = 2
+ end
+ end
+ elseif flags.html_displayed then
+ priority = 0
+ end
+
+ if not eslds[esld] then
+ eslds[esld] = { { str_hash, u, priority } }
+ neslds = neslds + 1
+ else
+ if #eslds[esld] < params.esld_limit then
+ table.insert(eslds[esld], { str_hash, u, priority })
+ end
+ end
+
+
+ -- eSLD - 1 part => tld
+ local parts = rspamd_str_split(esld, '.')
+ local tld = table.concat(fun.totable(fun.tail(parts)), '.')
+
+ if not tlds[tld] then
+ tlds[tld] = { { str_hash, u, priority } }
+ ntlds = ntlds + 1
+ else
+ table.insert(tlds[tld], { str_hash, u, priority })
+ end
+ end
+ end
+
+ for _, u in ipairs(urls) do
+ process_single_url(u)
+ end
+
+ local limit = params.limit
+ limit = limit - nres
+ if limit < 0 then
+ limit = 0
+ end
+
+ if limit == 0 then
+ res = exports.values(res)
+ if params.task and not params.no_cache then
+ params.task:cache_set(cache_key, res)
+ end
+ return res
+ end
+
+ -- Sort eSLDs and tlds
+ local function sort_stuff(tbl)
+ -- Sort according to max priority
+ table.sort(tbl, function(e1, e2)
+ -- Sort by priority so max priority is at the end
+ table.sort(e1, function(tr1, tr2)
+ return tr1[3] < tr2[3]
+ end)
+ table.sort(e2, function(tr1, tr2)
+ return tr1[3] < tr2[3]
+ end)
+
+ if e1[#e1][3] ~= e2[#e2][3] then
+ -- Sort by priority so max priority is at the beginning
+ return e1[#e1][3] > e2[#e2][3]
+ else
+ -- Prefer less urls to more urls per esld
+ return #e1 < #e2
+ end
+
+ end)
+
+ return tbl
+ end
+
+ eslds = sort_stuff(exports.values(eslds))
+ neslds = #eslds
+
+ if neslds <= limit then
+ -- Number of eslds < limit
+ repeat
+ local item_found = false
+
+ for _, lurls in ipairs(eslds) do
+ if #lurls > 0 then
+ local last = table.remove(lurls)
+ insert_url(last[1], last[2])
+ limit = limit - 1
+ item_found = true
+ end
+ end
+
+ until limit <= 0 or not item_found
+
+ res = exports.values(res)
+ if params.task and not params.no_cache then
+ params.task:cache_set(cache_key, res)
+ end
+ return res
+ end
+
+ tlds = sort_stuff(exports.values(tlds))
+ ntlds = #tlds
+
+ -- Number of tlds < limit
+ while limit > 0 do
+ for _, lurls in ipairs(tlds) do
+ if #lurls > 0 then
+ local last = table.remove(lurls)
+ insert_url(last[1], last[2])
+ limit = limit - 1
+ end
+ if limit == 0 then
+ break
+ end
+ end
+ end
+
+ res = exports.values(res)
+ if params.task and not params.no_cache then
+ params.task:cache_set(cache_key, res)
+ end
+ return res
+end
+
+--[[[
+-- @function lua_util.extract_specific_urls(params)
+-- params: {
+- - task
+- - limit <int> (default = 9999)
+- - esld_limit <int> (default = 9999) n domains per eSLD (effective second level domain)
+ works only if number of unique eSLD less than `limit`
+- - need_emails <bool> (default = false)
+- - filter <callback> (default = nil)
+- - prefix <string> cache prefix (default = nil)
+- - ignore_redirected <bool> (default = false)
+- - need_images <bool> (default = false)
+- - need_content <bool> (default = false)
+-- }
+-- Apply heuristic in extracting of urls from task, this function
+-- tries its best to extract specific number of urls from a task based on
+-- their characteristics
+--]]
+-- exports.extract_specific_urls = function(params_or_task, limit, need_emails, filter, prefix)
+exports.extract_specific_urls = function(params_or_task, lim, need_emails, filter, prefix)
+ local default_params = {
+ limit = 9999,
+ esld_limit = 9999,
+ need_emails = false,
+ need_images = false,
+ need_content = false,
+ filter = nil,
+ prefix = nil,
+ ignore_ip = false,
+ ignore_redirected = false,
+ no_cache = false,
+ }
+
+ local params
+ if type(params_or_task) == 'table' and type(lim) == 'nil' then
+ params = params_or_task
+ else
+ -- Deprecated call
+ params = {
+ task = params_or_task,
+ limit = lim,
+ need_emails = need_emails,
+ filter = filter,
+ prefix = prefix
+ }
+ end
+ for k, v in pairs(default_params) do
+ if type(params[k]) == 'nil' and v ~= nil then
+ params[k] = v
+ end
+ end
+ local url_params = {
+ emails = params.need_emails,
+ images = params.need_images,
+ content = params.need_content,
+ flags = params.flags, -- maybe nil
+ flags_mode = params.flags_mode, -- maybe nil
+ }
+
+ -- Shortcut for cached stuff
+ if params.task and not params.no_cache then
+ local cache_key
+ if params.prefix then
+ cache_key = params.prefix
+ else
+ local cache_key_suffix
+ if params.flags then
+ cache_key_suffix = table.concat(params.flags) .. (params.flags_mode or '')
+ else
+ cache_key_suffix = string.format('%s%s%s',
+ tostring(params.need_emails or false),
+ tostring(params.need_images or false),
+ tostring(params.need_content or false))
+ end
+ cache_key = string.format('sp_urls_%d%s', params.limit, cache_key_suffix)
+ end
+ local cached = params.task:cache_get(cache_key)
+
+ if cached then
+ return cached
+ end
+ end
+
+ -- No cache version
+ local urls = params.task:get_urls(url_params)
+
+ return exports.filter_specific_urls(urls, params)
+end
+
+--[[[
+-- @function lua_util.deepcopy(table)
+-- params: {
+- - table
+-- }
+-- Performs deep copy of the table. Including metatables
+--]]
+local function deepcopy(orig)
+ local orig_type = type(orig)
+ local copy
+ if orig_type == 'table' then
+ copy = {}
+ for orig_key, orig_value in next, orig, nil do
+ copy[deepcopy(orig_key)] = deepcopy(orig_value)
+ end
+ if getmetatable(orig) then
+ setmetatable(copy, deepcopy(getmetatable(orig)))
+ end
+ else
+ -- number, string, boolean, etc
+ copy = orig
+ end
+ return copy
+end
+
+exports.deepcopy = deepcopy
+
+--[[[
+-- @function lua_util.deepsort(table)
+-- params: {
+- - table
+-- }
+-- Performs recursive in-place sort of a table
+--]]
+local function default_sort_cmp(e1, e2)
+ if type(e1) == type(e2) then
+ return e1 < e2
+ else
+ return type(e1) < type(e2)
+ end
+end
+
+local function deepsort(tbl, sort_func)
+ local orig_type = type(tbl)
+ if orig_type == 'table' then
+ table.sort(tbl, sort_func or default_sort_cmp)
+ for _, orig_value in next, tbl, nil do
+ deepsort(orig_value)
+ end
+ end
+end
+
+exports.deepsort = deepsort
+
+--[[[
+-- @function lua_util.shallowcopy(tbl)
+-- Performs shallow (and fast) copy of a table or another Lua type
+--]]
+exports.shallowcopy = function(orig)
+ local orig_type = type(orig)
+ local copy
+ if orig_type == 'table' then
+ copy = {}
+ for orig_key, orig_value in pairs(orig) do
+ copy[orig_key] = orig_value
+ end
+ else
+ copy = orig
+ end
+ return copy
+end
+
+-- Debugging support
+local logger = require "rspamd_logger"
+local unconditional_debug = logger.log_level() == 'debug'
+local debug_modules = {}
+local debug_aliases = {}
+local log_level = 384 -- debug + forced (1 << 7 | 1 << 8)
+
+
+exports.init_debug_logging = function(config)
+ -- Fill debug modules from the config
+ if not unconditional_debug then
+ local log_config = config:get_all_opt('logging')
+ if log_config then
+ local log_level_str = log_config.level
+ if log_level_str then
+ if log_level_str == 'debug' then
+ unconditional_debug = true
+ end
+ end
+ if log_config.debug_modules then
+ for _, m in ipairs(log_config.debug_modules) do
+ debug_modules[m] = true
+ logger.infox(config, 'enable debug for Lua module %s', m)
+ end
+ end
+
+ if #debug_aliases > 0 then
+ for alias, mod in pairs(debug_aliases) do
+ if debug_modules[mod] then
+ debug_modules[alias] = true
+ logger.infox(config, 'enable debug for Lua module %s (%s aliased)',
+ alias, mod)
+ end
+ end
+ end
+ end
+ end
+end
+
+exports.enable_debug_logging = function()
+ unconditional_debug = true
+end
+
+exports.enable_debug_modules = function(...)
+ for _, m in ipairs({ ... }) do
+ debug_modules[m] = true
+ end
+end
+
+exports.disable_debug_logging = function()
+ unconditional_debug = false
+end
+
+--[[[
+-- @function lua_util.debugm(module, [log_object], format, ...)
+-- Performs fast debug log for a specific module
+--]]
+exports.debugm = function(mod, obj_or_fmt, fmt_or_something, ...)
+ if unconditional_debug or debug_modules[mod] then
+ if type(obj_or_fmt) == 'string' then
+ logger.logx(log_level, mod, '', 2, obj_or_fmt, fmt_or_something, ...)
+ else
+ logger.logx(log_level, mod, obj_or_fmt, 2, fmt_or_something, ...)
+ end
+ end
+end
+
+--[[[
+-- @function lua_util.add_debug_alias(mod, alias)
+-- Add debugging alias so logging to `alias` will be treated as logging to `mod`
+--]]
+exports.add_debug_alias = function(mod, alias)
+ debug_aliases[alias] = mod
+
+ if debug_modules[mod] then
+ debug_modules[alias] = true
+ logger.infox(rspamd_config, 'enable debug for Lua module %s (%s aliased)',
+ alias, mod)
+ end
+end
+---[[[
+-- @function lua_util.get_task_verdict(task)
+-- Returns verdict for a task + score if certain, must be called from idempotent filters only
+-- Returns string:
+-- * `spam`: if message have over reject threshold and has more than one positive rule
+-- * `junk`: if a message has between score between [add_header/rewrite subject] to reject thresholds and has more than two positive rules
+-- * `passthrough`: if a message has been passed through some short-circuit rule
+-- * `ham`: if a message has overall score below junk level **and** more than three negative rule, or negative total score
+-- * `uncertain`: all other cases
+--]]
+exports.get_task_verdict = function(task)
+ local lua_verdict = require "lua_verdict"
+
+ return lua_verdict.get_default_verdict(task)
+end
+
+---[[[
+-- @function lua_util.maybe_obfuscate_string(subject, settings, prefix)
+-- Obfuscate string if enabled in settings. Also checks utf8 validity - if
+-- string is not valid utf8 then '???' is returned. Empty string returned as is.
+-- Supported settings:
+-- * <prefix>_privacy = false - subject privacy is off
+-- * <prefix>_privacy_alg = 'blake2' - default hash-algorithm to obfuscate subject
+-- * <prefix>_privacy_prefix = 'obf' - prefix to show it's obfuscated
+-- * <prefix>_privacy_length = 16 - cut the length of the hash; if 0 or fasle full hash is returned
+-- @return obfuscated or validated subject
+--]]
+
+exports.maybe_obfuscate_string = function(subject, settings, prefix)
+ local hash = require 'rspamd_cryptobox_hash'
+ if not subject or subject == '' then
+ return subject
+ elseif not rspamd_util.is_valid_utf8(subject) then
+ subject = '???'
+ elseif settings[prefix .. '_privacy'] then
+ local hash_alg = settings[prefix .. '_privacy_alg'] or 'blake2'
+ local subject_hash = hash.create_specific(hash_alg, subject)
+
+ local strip_len = settings[prefix .. '_privacy_length']
+ if strip_len and strip_len > 0 then
+ subject = subject_hash:hex():sub(1, strip_len)
+ else
+ subject = subject_hash:hex()
+ end
+
+ local privacy_prefix = settings[prefix .. '_privacy_prefix']
+ if privacy_prefix and #privacy_prefix > 0 then
+ subject = privacy_prefix .. ':' .. subject
+ end
+ end
+
+ return subject
+end
+
+---[[[
+-- @function lua_util.callback_from_string(str)
+-- Converts a string like `return function(...) end` to lua function and return true and this function
+-- or returns false + error message
+-- @return status code and function object or an error message
+--]]]
+exports.callback_from_string = function(s)
+ local loadstring = loadstring or load
+
+ if not s or #s == 0 then
+ return false, 'invalid or empty string'
+ end
+
+ s = exports.rspamd_str_trim(s)
+ local inp
+
+ if s:match('^return%s*function') then
+ -- 'return function', can be evaluated directly
+ inp = s
+ elseif s:match('^function%s*%(') then
+ inp = 'return ' .. s
+ else
+ -- Just a plain sequence
+ inp = 'return function(...)\n' .. s .. '; end'
+ end
+
+ local ret, res_or_err = pcall(loadstring(inp))
+
+ if not ret or type(res_or_err) ~= 'function' then
+ return false, res_or_err
+ end
+
+ return ret, res_or_err
+end
+
+---[[[
+-- @function lua_util.keys(t)
+-- Returns all keys from a specific table
+-- @param {table} t input table (or iterator triplet)
+-- @return array of keys
+--]]]
+exports.keys = function(gen, param, state)
+ local keys = {}
+ local i = 1
+
+ if param then
+ for k, _ in fun.iter(gen, param, state) do
+ rawset(keys, i, k)
+ i = i + 1
+ end
+ else
+ for k, _ in pairs(gen) do
+ rawset(keys, i, k)
+ i = i + 1
+ end
+ end
+
+ return keys
+end
+
+---[[[
+-- @function lua_util.values(t)
+-- Returns all values from a specific table
+-- @param {table} t input table
+-- @return array of values
+--]]]
+exports.values = function(gen, param, state)
+ local values = {}
+ local i = 1
+
+ if param then
+ for _, v in fun.iter(gen, param, state) do
+ rawset(values, i, v)
+ i = i + 1
+ end
+ else
+ for _, v in pairs(gen) do
+ rawset(values, i, v)
+ i = i + 1
+ end
+ end
+
+ return values
+end
+
+---[[[
+-- @function lua_util.distance_sorted(t1, t2)
+-- Returns distance between two sorted tables t1 and t2
+-- @param {table} t1 input table
+-- @param {table} t2 input table
+-- @return distance between `t1` and `t2`
+--]]]
+exports.distance_sorted = function(t1, t2)
+ local ncomp = #t1
+ local ndiff = 0
+ local i, j = 1, 1
+
+ if ncomp < #t2 then
+ ncomp = #t2
+ end
+
+ for _ = 1, ncomp do
+ if j > #t2 then
+ ndiff = ndiff + ncomp - #t2
+ if i > j then
+ ndiff = ndiff - (i - j)
+ end
+ break
+ elseif i > #t1 then
+ ndiff = ndiff + ncomp - #t1
+ if j > i then
+ ndiff = ndiff - (j - i)
+ end
+ break
+ end
+
+ if t1[i] == t2[j] then
+ i = i + 1
+ j = j + 1
+ elseif t1[i] < t2[j] then
+ i = i + 1
+ ndiff = ndiff + 1
+ else
+ j = j + 1
+ ndiff = ndiff + 1
+ end
+ end
+
+ return ndiff
+end
+
+---[[[
+-- @function lua_util.table_digest(t)
+-- Returns hash of all values if t[1] is string or all keys/values otherwise
+-- @param {table} t input array or map
+-- @return {string} base32 representation of blake2b hash of all strings
+--]]]
+local function table_digest(t)
+ local cr = require "rspamd_cryptobox_hash"
+ local h = cr.create()
+
+ if t[1] then
+ for _, e in ipairs(t) do
+ if type(e) == 'table' then
+ h:update(table_digest(e))
+ else
+ h:update(tostring(e))
+ end
+ end
+ else
+ for k, v in pairs(t) do
+ h:update(tostring(k))
+
+ if type(v) == 'string' then
+ h:update(v)
+ elseif type(v) == 'table' then
+ h:update(table_digest(v))
+ end
+ end
+ end
+ return h:base32()
+end
+
+exports.table_digest = table_digest
+
+---[[[
+-- @function lua_util.toboolean(v)
+-- Converts a string or a number to boolean
+-- @param {string|number} v
+-- @return {boolean} v converted to boolean
+--]]]
+exports.toboolean = function(v)
+ local true_t = {
+ ['1'] = true,
+ ['true'] = true,
+ ['TRUE'] = true,
+ ['True'] = true,
+ };
+ local false_t = {
+ ['0'] = false,
+ ['false'] = false,
+ ['FALSE'] = false,
+ ['False'] = false,
+ };
+
+ if type(v) == 'string' then
+ if true_t[v] == true then
+ return true;
+ elseif false_t[v] == false then
+ return false;
+ else
+ return false, string.format('cannot convert %q to boolean', v);
+ end
+ elseif type(v) == 'number' then
+ return v ~= 0
+ else
+ return false, string.format('cannot convert %q to boolean', v);
+ end
+end
+
+---[[[
+-- @function lua_util.config_check_local_or_authed(config, modname)
+-- Reads check_local and check_authed from the config as this is used in many modules
+-- @param {rspamd_config} config `rspamd_config` global
+-- @param {name} module name
+-- @return {boolean} v converted to boolean
+--]]]
+exports.config_check_local_or_authed = function(rspamd_config, modname, def_local, def_authed)
+ local check_local = def_local or false
+ local check_authed = def_authed or false
+
+ local function try_section(where)
+ local ret = false
+ local opts = rspamd_config:get_all_opt(where)
+ if type(opts) == 'table' then
+ if type(opts['check_local']) == 'boolean' then
+ check_local = opts['check_local']
+ ret = true
+ end
+ if type(opts['check_authed']) == 'boolean' then
+ check_authed = opts['check_authed']
+ ret = true
+ end
+ end
+
+ return ret
+ end
+
+ if not try_section(modname) then
+ try_section('options')
+ end
+
+ return { check_local, check_authed }
+end
+
+---[[[
+-- @function lua_util.is_skip_local_or_authed(task, conf[, ip])
+-- Returns `true` if local or authenticated task should be skipped for this module
+-- @param {rspamd_task} task
+-- @param {table} conf table returned from `config_check_local_or_authed`
+-- @param {rspamd_ip} ip optional ip address (can be obtained from a task)
+-- @return {boolean} true if check should be skipped
+--]]]
+exports.is_skip_local_or_authed = function(task, conf, ip)
+ if not ip then
+ ip = task:get_from_ip()
+ end
+ if not conf then
+ conf = { false, false }
+ end
+ if ((not conf[2] and task:get_user()) or
+ (not conf[1] and type(ip) == 'userdata' and ip:is_local())) then
+ return true
+ end
+
+ return false
+end
+
+---[[[
+-- @function lua_util.maybe_smtp_quote_value(str)
+-- Checks string for the forbidden elements (tspecials in RFC and quote string if needed)
+-- @param {string} str input string
+-- @return {string} original or quoted string
+--]]]
+local tspecial = lpeg.S "()<>,;:\\\"/[]?= \t\v"
+local special_match = lpeg.P((1 - tspecial) ^ 0 * tspecial ^ 1)
+exports.maybe_smtp_quote_value = function(str)
+ if special_match:match(str) then
+ return string.format('"%s"', str:gsub('"', '\\"'))
+ end
+
+ return str
+end
+
+---[[[
+-- @function lua_util.shuffle(table)
+-- Performs in-place shuffling of a table
+-- @param {table} tbl table to shuffle
+-- @return {table} same table
+--]]]
+exports.shuffle = function(tbl)
+ local size = #tbl
+ for i = size, 1, -1 do
+ local rand = math.random(size)
+ tbl[i], tbl[rand] = tbl[rand], tbl[i]
+ end
+ return tbl
+end
+
+--
+local hex_table = {}
+for idx = 0, 255 do
+ hex_table[("%02X"):format(idx)] = string.char(idx)
+ hex_table[("%02x"):format(idx)] = string.char(idx)
+end
+
+---[[[
+-- @function lua_util.unhex(str)
+-- Decode hex encoded string
+-- @param {string} str string to decode
+-- @return {string} hex decoded string (valid hex pairs are decoded, everything else is printed as is)
+--]]]
+exports.unhex = function(str)
+ return str:gsub('(..)', hex_table)
+end
+
+local http_upstream_lists = {}
+local function http_upstreams_by_url(pool, url)
+ local rspamd_url = require "rspamd_url"
+
+ local cached = http_upstream_lists[url]
+ if cached then
+ return cached
+ end
+
+ local real_url = rspamd_url.create(pool, url)
+
+ if not real_url then
+ return nil
+ end
+
+ local host = real_url:get_host()
+ local proto = real_url:get_protocol() or 'http'
+ local port = real_url:get_port() or (proto == 'https' and 443 or 80)
+ local upstream_list = require "rspamd_upstream_list"
+ local upstreams = upstream_list.create(host, port)
+
+ if upstreams then
+ http_upstream_lists[url] = upstreams
+ return upstreams
+ end
+
+ return nil
+end
+---[[[
+-- @function lua_util.http_upstreams_by_url(pool, url)
+-- Returns a cached or new upstreams list that corresponds to the specific url
+-- @param {mempool} pool memory pool to use (typically static pool from rspamd_config)
+-- @param {string} url full url
+-- @return {upstreams_list} object to get upstream from an url
+--]]]
+exports.http_upstreams_by_url = http_upstreams_by_url
+
+---[[[
+-- @function lua_util.dns_timeout_augmentation(cfg)
+-- Returns an augmentation suitable to define DNS timeout for a module
+-- @return {string} a string in format 'timeout=x' where `x` is a number of seconds for DNS timeout
+--]]]
+local function dns_timeout_augmentation(cfg)
+ return string.format('timeout=%f', cfg:get_dns_timeout() or 0.0)
+end
+
+exports.dns_timeout_augmentation = dns_timeout_augmentation
+
+---[[[
+--- @function lua_util.strip_lua_comments(lua_code)
+-- Strips single-line and multi-line comments from a given Lua code string and removes
+-- any extra spaces or newlines.
+--
+-- @param lua_code The Lua code string to strip comments from.
+-- @return The resulting Lua code string with comments and extra spaces removed.
+--
+---]]]
+local function strip_lua_comments(lua_code)
+ -- Remove single-line comments
+ lua_code = lua_code:gsub("%-%-[^\r\n]*", "")
+
+ -- Remove multi-line comments
+ lua_code = lua_code:gsub("%-%-%[%[.-%]%]", "")
+
+ -- Remove extra spaces and newlines
+ lua_code = lua_code:gsub("%s+", " ")
+
+ return lua_code
+end
+
+exports.strip_lua_comments = strip_lua_comments
+
+---[[[
+-- @function lua_util.join_path(...)
+-- Joins path components into a single path string using the appropriate separator
+-- for the current operating system.
+--
+-- @param ... Any number of path components to join together.
+-- @return A single path string, with components separated by the appropriate separator.
+--
+---]]]
+local path_sep = package.config:sub(1, 1) or '/'
+local function join_path(...)
+ local components = { ... }
+
+ -- Join components using separator
+ return table.concat(components, path_sep)
+end
+exports.join_path = join_path
+
+-- Short unit test for sanity
+if path_sep == '/' then
+ assert(join_path('/path', 'to', 'file') == '/path/to/file')
+else
+ assert(join_path('C:', 'path', 'to', 'file') == 'C:\\path\\to\\file')
+end
+
+-- Defines symbols priorities for common usage in prefilters/postfilters
+exports.symbols_priorities = {
+ top = 10, -- Symbols must be executed first (or last), such as settings
+ high = 9, -- Example: asn
+ medium = 5, -- Everything should use this as default
+ low = 0,
+}
+
+return exports
diff --git a/lualib/lua_verdict.lua b/lualib/lua_verdict.lua
new file mode 100644
index 0000000..6ce99e6
--- /dev/null
+++ b/lualib/lua_verdict.lua
@@ -0,0 +1,208 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local exports = {}
+
+---[[[
+-- @function lua_verdict.get_default_verdict(task)
+-- Returns verdict for a task + score if certain, must be called from idempotent filters only
+-- Returns string:
+-- * `spam`: if message have over reject threshold and has more than one positive rule
+-- * `junk`: if a message has between score between [add_header/rewrite subject] to reject thresholds and has more than two positive rules
+-- * `passthrough`: if a message has been passed through some short-circuit rule
+-- * `ham`: if a message has overall score below junk level **and** more than three negative rule, or negative total score
+-- * `uncertain`: all other cases
+--]]
+local function default_verdict_function(task)
+ local result = task:get_metric_result()
+
+ if result then
+
+ if result.passthrough then
+ return 'passthrough', nil
+ end
+
+ local score = result.score
+
+ local action = result.action
+
+ if action == 'reject' and result.npositive > 1 then
+ return 'spam', score
+ elseif action == 'no action' then
+ if score < 0 or result.nnegative > 3 then
+ return 'ham', score
+ end
+ else
+ -- All colors of junk
+ if action == 'add header' or action == 'rewrite subject' then
+ if result.npositive > 2 then
+ return 'junk', score
+ end
+ end
+ end
+
+ return 'uncertain', score
+ end
+end
+
+local default_possible_verdicts = {
+ passthrough = {
+ can_learn = false,
+ description = 'message has passthrough result',
+ },
+ spam = {
+ can_learn = 'spam',
+ description = 'message is likely spam',
+ },
+ junk = {
+ can_learn = 'spam',
+ description = 'message is likely possible spam',
+ },
+ ham = {
+ can_learn = 'ham',
+ description = 'message is likely ham',
+ },
+ uncertain = {
+ can_learn = false,
+ description = 'not certainty in verdict'
+ }
+}
+
+-- Verdict functions specific for modules
+local specific_verdicts = {
+ default = {
+ callback = default_verdict_function,
+ possible_verdicts = default_possible_verdicts
+ }
+}
+
+local default_verdict = specific_verdicts.default
+
+exports.get_default_verdict = default_verdict.callback
+exports.set_verdict_function = function(func, what)
+ assert(type(func) == 'function')
+ if not what then
+ -- Default verdict
+ local existing = specific_verdicts.default.callback
+ specific_verdicts.default.callback = func
+ exports.get_default_verdict = func
+
+ return existing
+ else
+ local existing = specific_verdicts[what]
+
+ if not existing then
+ specific_verdicts[what] = {
+ callback = func,
+ possible_verdicts = default_possible_verdicts
+ }
+ else
+ existing = existing.callback
+ end
+
+ specific_verdicts[what].callback = func
+ return existing
+ end
+end
+
+exports.set_verdict_table = function(verdict_tbl, what)
+ assert(type(verdict_tbl) == 'table' and
+ type(verdict_tbl.callback) == 'function' and
+ type(verdict_tbl.possible_verdicts) == 'table')
+
+ if not what then
+ -- Default verdict
+ local existing = specific_verdicts.default
+ specific_verdicts.default = verdict_tbl
+ exports.get_default_verdict = specific_verdicts.default.callback
+
+ return existing
+ else
+ local existing = specific_verdicts[what]
+ specific_verdicts[what] = verdict_tbl
+ return existing
+ end
+end
+
+exports.get_specific_verdict = function(what, task)
+ if specific_verdicts[what] then
+ return specific_verdicts[what].callback(task)
+ end
+
+ return exports.get_default_verdict(task)
+end
+
+exports.get_possible_verdicts = function(what)
+ local lua_util = require "lua_util"
+ if what then
+ if specific_verdicts[what] then
+ return lua_util.keys(specific_verdicts[what].possible_verdicts)
+ end
+ else
+ return lua_util.keys(specific_verdicts.default.possible_verdicts)
+ end
+
+ return nil
+end
+
+exports.can_learn = function(verdict, what)
+ if what then
+ if specific_verdicts[what] and specific_verdicts[what].possible_verdicts[verdict] then
+ return specific_verdicts[what].possible_verdicts[verdict].can_learn
+ end
+ else
+ if specific_verdicts.default.possible_verdicts[verdict] then
+ return specific_verdicts.default.possible_verdicts[verdict].can_learn
+ end
+ end
+
+ return nil -- To distinguish from `false` that could happen in can_learn
+end
+
+exports.describe = function(verdict, what)
+ if what then
+ if specific_verdicts[what] and specific_verdicts[what].possible_verdicts[verdict] then
+ return specific_verdicts[what].possible_verdicts[verdict].description
+ end
+ else
+ if specific_verdicts.default.possible_verdicts[verdict] then
+ return specific_verdicts.default.possible_verdicts[verdict].description
+ end
+ end
+
+ return nil
+end
+
+---[[[
+-- @function lua_verdict.adjust_passthrough_action(task)
+-- If an action is `soft reject` then this function extracts a module that has set this action
+-- and returns an adjusted action (e.g. 'greylist' or 'ratelimit').
+-- Otherwise an action is returned as is.
+--]]
+exports.adjust_passthrough_action = function(task)
+ local action = task:get_metric_action()
+ if action == 'soft reject' then
+ local has_pr, _, _, module = task:has_pre_result()
+
+ if has_pr and module then
+ action = module
+ end
+ end
+
+ return action
+end
+
+return exports \ No newline at end of file
diff --git a/lualib/plugins/dmarc.lua b/lualib/plugins/dmarc.lua
new file mode 100644
index 0000000..7791f4e
--- /dev/null
+++ b/lualib/plugins/dmarc.lua
@@ -0,0 +1,359 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+Copyright (c) 2015-2016, Andrew Lewis <nerf@judo.za.org>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+-- Common dmarc stuff
+local rspamd_logger = require "rspamd_logger"
+local lua_util = require "lua_util"
+local N = "dmarc"
+
+local exports = {}
+
+exports.default_settings = {
+ auth_and_local_conf = false,
+ symbols = {
+ spf_allow_symbol = 'R_SPF_ALLOW',
+ spf_deny_symbol = 'R_SPF_FAIL',
+ spf_softfail_symbol = 'R_SPF_SOFTFAIL',
+ spf_neutral_symbol = 'R_SPF_NEUTRAL',
+ spf_tempfail_symbol = 'R_SPF_DNSFAIL',
+ spf_permfail_symbol = 'R_SPF_PERMFAIL',
+ spf_na_symbol = 'R_SPF_NA',
+
+ dkim_allow_symbol = 'R_DKIM_ALLOW',
+ dkim_deny_symbol = 'R_DKIM_REJECT',
+ dkim_tempfail_symbol = 'R_DKIM_TEMPFAIL',
+ dkim_na_symbol = 'R_DKIM_NA',
+ dkim_permfail_symbol = 'R_DKIM_PERMFAIL',
+
+ -- DMARC symbols
+ allow = 'DMARC_POLICY_ALLOW',
+ badpolicy = 'DMARC_BAD_POLICY',
+ dnsfail = 'DMARC_DNSFAIL',
+ na = 'DMARC_NA',
+ reject = 'DMARC_POLICY_REJECT',
+ softfail = 'DMARC_POLICY_SOFTFAIL',
+ quarantine = 'DMARC_POLICY_QUARANTINE',
+ },
+ no_sampling_domains = nil,
+ no_reporting_domains = nil,
+ reporting = {
+ report_local_controller = false, -- Store reports for local/controller scans (for testing only)
+ redis_keys = {
+ index_prefix = 'dmarc_idx',
+ report_prefix = 'dmarc_rpt',
+ join_char = ';',
+ },
+ helo = 'rspamd.localhost',
+ smtp = '127.0.0.1',
+ smtp_port = 25,
+ retries = 2,
+ from_name = 'Rspamd',
+ msgid_from = 'rspamd',
+ enabled = false,
+ max_entries = 1000,
+ keys_expire = 172800,
+ only_domains = nil,
+ },
+ actions = {},
+}
+
+
+-- Returns a key used to be inserted into dmarc report sample
+exports.dmarc_report = function(task, settings, data)
+ local rspamd_lua_utils = require "lua_util"
+ local E = {}
+
+ local ip = task:get_from_ip()
+ if not ip or not ip:is_valid() then
+ rspamd_logger.infox(task, 'cannot store dmarc report for %s: no valid source IP',
+ data.domain)
+ return nil
+ end
+
+ ip = ip:to_string()
+
+ if rspamd_lua_utils.is_rspamc_or_controller(task) and not settings.reporting.report_local_controller then
+ rspamd_logger.infox(task, 'cannot store dmarc report for %s from IP %s: has come from controller/rspamc',
+ data.domain, ip)
+ return
+ end
+
+ local dkim_pass = table.concat(data.dkim_results.pass or E, '|')
+ local dkim_fail = table.concat(data.dkim_results.fail or E, '|')
+ local dkim_temperror = table.concat(data.dkim_results.temperror or E, '|')
+ local dkim_permerror = table.concat(data.dkim_results.permerror or E, '|')
+ local disposition_to_return = data.disposition
+ local res = table.concat({
+ ip, data.spf_ok, data.dkim_ok,
+ disposition_to_return, (data.sampled_out and 'sampled_out' or ''), data.domain,
+ dkim_pass, dkim_fail, dkim_temperror, dkim_permerror, data.spf_domain, data.spf_result }, ',')
+
+ return res
+end
+
+exports.gen_munging_callback = function(munging_opts, settings)
+ local rspamd_util = require "rspamd_util"
+ local lua_mime = require "lua_mime"
+ return function(task)
+ if munging_opts.mitigate_allow_only then
+ if not task:has_symbol(settings.symbols.allow) then
+ lua_util.debugm(N, task, 'skip munging, no %s symbol',
+ settings.symbols.allow)
+ -- Excepted
+ return
+ end
+ else
+ local has_dmarc = task:has_symbol(settings.symbols.allow) or
+ task:has_symbol(settings.symbols.quarantine) or
+ task:has_symbol(settings.symbols.reject) or
+ task:has_symbol(settings.symbols.softfail)
+
+ if not has_dmarc then
+ lua_util.debugm(N, task, 'skip munging, no %s symbol',
+ settings.symbols.allow)
+ -- Excepted
+ return
+ end
+ end
+ if munging_opts.mitigate_strict_only then
+ local s = task:get_symbol(settings.symbols.allow) or { [1] = {} }
+ local sopts = s[1].options or {}
+
+ local seen_strict
+ for _, o in ipairs(sopts) do
+ if o == 'reject' or o == 'quarantine' then
+ seen_strict = true
+ break
+ end
+ end
+
+ if not seen_strict then
+ lua_util.debugm(N, task, 'skip munging, no strict policy found in %s',
+ settings.symbols.allow)
+ -- Excepted
+ return
+ end
+ end
+ if munging_opts.munge_map_condition then
+ local accepted, trace = munging_opts.munge_map_condition:process(task)
+ if not accepted then
+ lua_util.debugm(N, task, 'skip munging, maps condition not satisfied: (%s)',
+ trace)
+ -- Excepted
+ return
+ end
+ end
+ -- Now, look for domain for munging
+ local mr = task:get_recipients({ 'mime', 'orig' })
+ local rcpt_found
+ if mr then
+ for _, r in ipairs(mr) do
+ if r.domain and munging_opts.list_map:get_key(r.addr) then
+ rcpt_found = r
+ break
+ end
+ end
+ end
+
+ if not rcpt_found then
+ lua_util.debugm(N, task, 'skip munging, recipients are not in list_map')
+ -- Excepted
+ return
+ end
+
+ local from = task:get_from({ 'mime', 'orig' })
+
+ if not from or not from[1] then
+ lua_util.debugm(N, task, 'skip munging, from is bad')
+ -- Excepted
+ return
+ end
+
+ from = from[1]
+ local via_user = rcpt_found.user
+ local via_addr = rcpt_found.addr
+ local via_name
+
+ if from.name == "" then
+ via_name = string.format('%s via %s', from.user or 'unknown', via_user)
+ else
+ via_name = string.format('%s via %s', from.name, via_user)
+ end
+
+ local hdr_encoded = rspamd_util.fold_header('From',
+ rspamd_util.mime_header_encode(string.format('%s <%s>',
+ via_name, via_addr)), task:get_newlines_type())
+ local orig_from_encoded = rspamd_util.fold_header('X-Original-From',
+ rspamd_util.mime_header_encode(string.format('%s <%s>',
+ from.name or '', from.addr)), task:get_newlines_type())
+ local add_hdrs = {
+ ['From'] = { order = 1, value = hdr_encoded },
+ }
+ local remove_hdrs = { ['From'] = 0 }
+
+ local nreply = from.addr
+ if munging_opts.reply_goes_to_list then
+ -- Reply-to goes to the list
+ nreply = via_addr
+ end
+
+ if task:has_header('Reply-To') then
+ -- If we have reply-to header, then we need to insert an additional
+ -- address there
+ local orig_reply = task:get_header_full('Reply-To')[1]
+ if orig_reply.value then
+ nreply = string.format('%s, %s', orig_reply.value, nreply)
+ end
+ remove_hdrs['Reply-To'] = 1
+ end
+
+ add_hdrs['Reply-To'] = { order = 0, value = nreply }
+
+ add_hdrs['X-Original-From'] = { order = 0, value = orig_from_encoded }
+ lua_mime.modify_headers(task, {
+ remove = remove_hdrs,
+ add = add_hdrs
+ })
+ lua_util.debugm(N, task, 'munged DMARC header for %s: %s -> %s',
+ from.domain, hdr_encoded, from.addr)
+ rspamd_logger.infox(task, 'munged DMARC header for %s', from.addr)
+ task:insert_result('DMARC_MUNGED', 1.0, from.addr)
+ end
+end
+
+local function gen_dmarc_grammar()
+ local lpeg = require "lpeg"
+ lpeg.locale(lpeg)
+ local space = lpeg.space ^ 0
+ local name = lpeg.C(lpeg.alpha ^ 1) * space
+ local sep = space * (lpeg.S("\\;") * space) + (lpeg.space ^ 1)
+ local value = lpeg.C(lpeg.P(lpeg.graph - sep) ^ 1)
+ local pair = lpeg.Cg(name * "=" * space * value) * sep ^ -1
+ local list = lpeg.Cf(lpeg.Ct("") * pair ^ 0, rawset)
+ local version = lpeg.P("v") * space * lpeg.P("=") * space * lpeg.P("DMARC1")
+ local record = version * sep * list
+
+ return record
+end
+
+local dmarc_grammar = gen_dmarc_grammar()
+
+local function dmarc_key_value_case(elts)
+ if type(elts) ~= "table" then
+ return elts
+ end
+ local result = {}
+ for k, v in pairs(elts) do
+ k = k:lower()
+ if k ~= "v" then
+ v = v:lower()
+ end
+
+ result[k] = v
+ end
+
+ return result
+end
+
+--[[
+-- Used to check dmarc record, check elements and produce dmarc policy processed
+-- result.
+-- Returns:
+-- false,false - record is garbage
+-- false,error_message - record is invalid
+-- true,policy_table - record is valid and parsed
+]]
+local function dmarc_check_record(log_obj, record, is_tld)
+ local failed_policy
+ local result = {
+ dmarc_policy = 'none'
+ }
+
+ local elts = dmarc_grammar:match(record)
+ lua_util.debugm(N, log_obj, "got DMARC record: %s, tld_flag=%s, processed=%s",
+ record, is_tld, elts)
+
+ if elts then
+ elts = dmarc_key_value_case(elts)
+
+ local dkim_pol = elts['adkim']
+ if dkim_pol then
+ if dkim_pol == 's' then
+ result.strict_dkim = true
+ elseif dkim_pol ~= 'r' then
+ failed_policy = 'adkim tag has invalid value: ' .. dkim_pol
+ return false, failed_policy
+ end
+ end
+
+ local spf_pol = elts['aspf']
+ if spf_pol then
+ if spf_pol == 's' then
+ result.strict_spf = true
+ elseif spf_pol ~= 'r' then
+ failed_policy = 'aspf tag has invalid value: ' .. spf_pol
+ return false, failed_policy
+ end
+ end
+
+ local policy = elts['p']
+ if policy then
+ if (policy == 'reject') then
+ result.dmarc_policy = 'reject'
+ elseif (policy == 'quarantine') then
+ result.dmarc_policy = 'quarantine'
+ elseif (policy ~= 'none') then
+ failed_policy = 'p tag has invalid value: ' .. policy
+ return false, failed_policy
+ end
+ end
+
+ -- Adjust policy if we are in tld mode
+ local subdomain_policy = elts['sp']
+ if elts['sp'] and is_tld then
+ result.subdomain_policy = elts['sp']
+
+ if (subdomain_policy == 'reject') then
+ result.dmarc_policy = 'reject'
+ elseif (subdomain_policy == 'quarantine') then
+ result.dmarc_policy = 'quarantine'
+ elseif (subdomain_policy == 'none') then
+ result.dmarc_policy = 'none'
+ elseif (subdomain_policy ~= 'none') then
+ failed_policy = 'sp tag has invalid value: ' .. subdomain_policy
+ return false, failed_policy
+ end
+ end
+ result.pct = elts['pct']
+ if result.pct then
+ result.pct = tonumber(result.pct)
+ end
+
+ if elts.rua then
+ result.rua = elts['rua']
+ end
+ result.raw_elts = elts
+ else
+ return false, false -- Ignore garbage
+ end
+
+ return true, result
+end
+
+exports.dmarc_check_record = dmarc_check_record
+
+return exports
diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua
new file mode 100644
index 0000000..6e88ef2
--- /dev/null
+++ b/lualib/plugins/neural.lua
@@ -0,0 +1,892 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local fun = require "fun"
+local lua_redis = require "lua_redis"
+local lua_settings = require "lua_settings"
+local lua_util = require "lua_util"
+local meta_functions = require "lua_meta"
+local rspamd_kann = require "rspamd_kann"
+local rspamd_logger = require "rspamd_logger"
+local rspamd_tensor = require "rspamd_tensor"
+local rspamd_util = require "rspamd_util"
+local ucl = require "ucl"
+
+local N = 'neural'
+
+-- Used in prefix to avoid wrong ANN to be loaded
+local plugin_ver = '2'
+
+-- Module vars
+local default_options = {
+ train = {
+ max_trains = 1000,
+ max_epoch = 1000,
+ max_usages = 10,
+ max_iterations = 25, -- Torch style
+ mse = 0.001,
+ autotrain = true,
+ train_prob = 1.0,
+ learn_threads = 1,
+ learn_mode = 'balanced', -- Possible values: balanced, proportional
+ learning_rate = 0.01,
+ classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
+ spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
+ ham_skip_prob = 0.0, -- proportional mode: ham skip probability
+ store_pool_only = false, -- store tokens in cache only (disables autotrain);
+ -- neural_vec_mpack stores vector of training data in messagepack neural_profile_digest stores profile digest
+ },
+ watch_interval = 60.0,
+ lock_expire = 600,
+ learning_spawned = false,
+ ann_expire = 60 * 60 * 24 * 2, -- 2 days
+ hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
+ roc_enabled = false, -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds.
+ roc_misclassification_cost = 0.5, -- Cost of misclassifying a spam message (must be 0..1).
+ spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable)
+ ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable)
+ flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached
+ symbol_spam = 'NEURAL_SPAM',
+ symbol_ham = 'NEURAL_HAM',
+ max_inputs = nil, -- when PCA is used
+ blacklisted_symbols = {}, -- list of symbols skipped in neural processing
+}
+
+-- Rule structure:
+-- * static config fields (see `default_options`)
+-- * prefix - name or defined prefix
+-- * settings - table of settings indexed by settings id, -1 is used when no settings defined
+
+-- Rule settings element defines elements for specific settings id:
+-- * symbols - static symbols profile (defined by config or extracted from symcache)
+-- * name - name of settings id
+-- * digest - digest of all symbols
+-- * ann - dynamic ANN configuration loaded from Redis
+-- * train - train data for ANN (e.g. the currently trained ANN)
+
+-- Settings ANN table is loaded from Redis and represents dynamic profile for ANN
+-- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically
+-- * version - version of ANN loaded from redis
+-- * redis_key - name of ANN key in Redis
+-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
+-- * distance - distance between set.symbols and set.ann.symbols
+-- * ann - kann object
+
+local settings = {
+ rules = {},
+ prefix = 'rn', -- Neural network default prefix
+ max_profiles = 3, -- Maximum number of NN profiles stored
+}
+
+-- Get module & Redis configuration
+local module_config = rspamd_config:get_all_opt(N)
+settings = lua_util.override_defaults(settings, module_config)
+local redis_params = lua_redis.parse_redis_server('neural')
+
+local redis_lua_script_vectors_len = "neural_train_size.lua"
+local redis_lua_script_maybe_invalidate = "neural_maybe_invalidate.lua"
+local redis_lua_script_maybe_lock = "neural_maybe_lock.lua"
+local redis_lua_script_save_unlock = "neural_save_unlock.lua"
+
+local redis_script_id = {}
+
+local function load_scripts()
+ redis_script_id.vectors_len = lua_redis.load_redis_script_from_file(redis_lua_script_vectors_len,
+ redis_params)
+ redis_script_id.maybe_invalidate = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_invalidate,
+ redis_params)
+ redis_script_id.maybe_lock = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_lock,
+ redis_params)
+ redis_script_id.save_unlock = lua_redis.load_redis_script_from_file(redis_lua_script_save_unlock,
+ redis_params)
+end
+
+local function create_ann(n, nlayers, rule)
+ -- We ignore number of layers so far when using kann
+ local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0)
+ local t = rspamd_kann.layer.input(n)
+ t = rspamd_kann.transform.relu(t)
+ t = rspamd_kann.layer.dense(t, nhidden);
+ t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg)
+ return rspamd_kann.new.kann(t)
+end
+
+-- Fills ANN data for a specific settings element
+local function fill_set_ann(set, ann_key)
+ if not set.ann then
+ set.ann = {
+ symbols = set.symbols,
+ distance = 0,
+ digest = set.digest,
+ redis_key = ann_key,
+ version = 0,
+ }
+ end
+end
+
+-- This function takes all inputs, applies PCA transformation and returns the final
+-- PCA matrix as rspamd_tensor
+local function learn_pca(inputs, max_inputs)
+ local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs))
+ local eigenvals = scatter_matrix:eigen()
+ -- scatter matrix is not filled with eigenvectors
+ lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
+ local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1])
+ for i = 1, max_inputs do
+ w[i] = scatter_matrix[#scatter_matrix - i + 1]
+ end
+
+ lua_util.debugm(N, 'pca matrix: %s', w)
+
+ return w
+end
+
+-- This function computes optimal threshold using ROC for the given set of inputs.
+-- Returns a threshold that minimizes:
+-- alpha * (false_positive_rate) + beta * (false_negative_rate)
+-- Where alpha is cost of false positive result
+-- beta is cost of false negative result
+local function get_roc_thresholds(ann, inputs, outputs, alpha, beta)
+
+ -- Sorts list x and list y based on the values in list x.
+ local sort_relative = function(x, y)
+
+ local r = {}
+
+ assert(#x == #y)
+ local n = #x
+
+ local a = {}
+ local b = {}
+ for i = 1, n do
+ r[i] = i
+ end
+
+ local cmp = function(p, q)
+ return p < q
+ end
+
+ table.sort(r, function(p, q)
+ return cmp(x[p], x[q])
+ end)
+
+ for i = 1, n do
+ a[i] = x[r[i]]
+ b[i] = y[r[i]]
+ end
+
+ return a, b
+ end
+
+ local function get_scores(nn, input_vectors)
+ local scores = {}
+ for i = 1, #inputs do
+ local score = nn:apply1(input_vectors[i], nn.pca)[1]
+ scores[#scores + 1] = score
+ end
+
+ return scores
+ end
+
+ local fpr = {}
+ local fnr = {}
+ local scores = get_scores(ann, inputs)
+
+ scores, outputs = sort_relative(scores, outputs)
+
+ local n_samples = #outputs
+ local n_spam = 0
+ local n_ham = 0
+ local ham_count_ahead = {}
+ local spam_count_ahead = {}
+ local ham_count_behind = {}
+ local spam_count_behind = {}
+
+ ham_count_ahead[n_samples + 1] = 0
+ spam_count_ahead[n_samples + 1] = 0
+
+ for i = n_samples, 1, -1 do
+
+ if outputs[i][1] == 0 then
+ n_ham = n_ham + 1
+ ham_count_ahead[i] = 1
+ spam_count_ahead[i] = 0
+ else
+ n_spam = n_spam + 1
+ ham_count_ahead[i] = 0
+ spam_count_ahead[i] = 1
+ end
+
+ ham_count_ahead[i] = ham_count_ahead[i] + ham_count_ahead[i + 1]
+ spam_count_ahead[i] = spam_count_ahead[i] + spam_count_ahead[i + 1]
+ end
+
+ for i = 1, n_samples do
+ if outputs[i][1] == 0 then
+ ham_count_behind[i] = 1
+ spam_count_behind[i] = 0
+ else
+ ham_count_behind[i] = 0
+ spam_count_behind[i] = 1
+ end
+
+ if i ~= 1 then
+ ham_count_behind[i] = ham_count_behind[i] + ham_count_behind[i - 1]
+ spam_count_behind[i] = spam_count_behind[i] + spam_count_behind[i - 1]
+ end
+ end
+
+ for i = 1, n_samples do
+ fpr[i] = 0
+ fnr[i] = 0
+
+ if (ham_count_ahead[i + 1] + ham_count_behind[i]) ~= 0 then
+ fpr[i] = ham_count_ahead[i + 1] / (ham_count_ahead[i + 1] + ham_count_behind[i])
+ end
+
+ if (spam_count_behind[i] + spam_count_ahead[i + 1]) ~= 0 then
+ fnr[i] = spam_count_behind[i] / (spam_count_behind[i] + spam_count_ahead[i + 1])
+ end
+ end
+
+ local p = n_spam / (n_spam + n_ham)
+
+ local cost = {}
+ local min_cost_idx = 0
+ local min_cost = math.huge
+ for i = 1, n_samples do
+ cost[i] = ((1 - p) * alpha * fpr[i]) + (p * beta * fnr[i])
+ if min_cost >= cost[i] then
+ min_cost = cost[i]
+ min_cost_idx = i
+ end
+ end
+
+ return scores[min_cost_idx]
+end
+
+-- This function is intended to extend lock for ANN during training
+-- It registers periodic that increases locked key each 30 seconds unless
+-- `set.learning_spawned` is set to `true`
+local function register_lock_extender(rule, set, ev_base, ann_key)
+ rspamd_config:add_periodic(ev_base, 30.0,
+ function()
+ local function redis_lock_extend_cb(err, _)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
+ ann_key, err)
+ else
+ rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
+ ann_key)
+ end
+ end
+
+ if set.learning_spawned then
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ redis_lock_extend_cb, --callback
+ 'HINCRBY', -- command
+ { ann_key, 'lock', '30' }
+ )
+ else
+ lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
+ return false -- do not plan any more updates
+ end
+
+ return true
+ end
+ )
+end
+
+local function can_push_train_vector(rule, task, learn_type, nspam, nham)
+ local train_opts = rule.train
+ local coin = math.random()
+
+ if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then
+ rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
+ return false
+ end
+
+ if train_opts.learn_mode == 'balanced' then
+ -- Keep balanced training set based on number of spam and ham samples
+ if learn_type == 'spam' then
+ if nspam <= train_opts.max_trains then
+ if nspam > nham then
+ -- Apply sampling
+ local skip_rate = 1.0 - nham / (nspam + 1)
+ if coin < skip_rate - train_opts.classes_bias then
+ rspamd_logger.infox(task,
+ 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+ learn_type,
+ skip_rate - train_opts.classes_bias,
+ nspam, nham)
+ return false
+ end
+ end
+ return true
+ else
+ -- Enough learns
+ rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s',
+ learn_type,
+ nspam)
+ end
+ else
+ if nham <= train_opts.max_trains then
+ if nham > nspam then
+ -- Apply sampling
+ local skip_rate = 1.0 - nspam / (nham + 1)
+ if coin < skip_rate - train_opts.classes_bias then
+ rspamd_logger.infox(task,
+ 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+ learn_type,
+ skip_rate - train_opts.classes_bias,
+ nspam, nham)
+ return false
+ end
+ end
+ return true
+ else
+ rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type,
+ nham)
+ end
+ end
+ else
+ -- Probabilistic learn mode, we just skip learn if we already have enough samples or
+ -- if our coin drop is less than desired probability
+ if learn_type == 'spam' then
+ if nspam <= train_opts.max_trains then
+ if train_opts.spam_skip_prob then
+ if coin <= train_opts.spam_skip_prob then
+ rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type,
+ coin, train_opts.spam_skip_prob)
+ return false
+ end
+
+ return true
+ end
+ else
+ rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type,
+ nspam, train_opts.max_trains)
+ end
+ else
+ if nham <= train_opts.max_trains then
+ if train_opts.ham_skip_prob then
+ if coin <= train_opts.ham_skip_prob then
+ rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type,
+ coin, train_opts.ham_skip_prob)
+ return false
+ end
+
+ return true
+ end
+ else
+ rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type,
+ nham, train_opts.max_trains)
+ end
+ end
+ end
+
+ return false
+end
+
+-- Closure generator for unlock function
+local function gen_unlock_cb(rule, set, ann_key)
+ return function(err)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
+ rule.prefix, set.name, ann_key, err)
+ else
+ lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
+ rule.prefix, set.name, ann_key)
+ end
+ end
+end
+
+-- Used to generate new ANN key for specific profile
+local function new_ann_key(rule, set, version)
+ local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix,
+ rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))
+
+ return ann_key
+end
+
+local function redis_ann_prefix(rule, settings_name)
+ -- We also need to count metatokens:
+ local n = meta_functions.version
+ return string.format('%s%d_%s_%d_%s',
+ settings.prefix, plugin_ver, rule.prefix, n, settings_name)
+end
+
+-- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
+local function spawn_train(params)
+ -- Check training data sanity
+ -- Now we need to join inputs and create the appropriate test vectors
+ local n = #params.set.symbols +
+ meta_functions.rspamd_count_metatokens()
+
+ -- Now we can train ann
+ local train_ann = create_ann(params.rule.max_inputs or n, 3, params.rule)
+
+ if #params.ham_vec + #params.spam_vec < params.rule.train.max_trains / 2 then
+ -- Invalidate ANN as it is definitely invalid
+ -- TODO: add invalidation
+ assert(false)
+ else
+ local inputs, outputs = {}, {}
+
+ -- Used to show parsed vectors in a convenient format (for debugging only)
+ local function debug_vec(t)
+ local ret = {}
+ for i, v in ipairs(t) do
+ if v ~= 0 then
+ ret[#ret + 1] = string.format('%d=%.2f', i, v)
+ end
+ end
+
+ return ret
+ end
+
+ -- Make training set by joining vectors
+ -- KANN automatically shuffles those samples
+ -- 1.0 is used for spam and -1.0 is used for ham
+ -- It implies that output layer can express that (e.g. tanh output)
+ for _, e in ipairs(params.spam_vec) do
+ inputs[#inputs + 1] = e
+ outputs[#outputs + 1] = { 1.0 }
+ --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e))
+ end
+ for _, e in ipairs(params.ham_vec) do
+ inputs[#inputs + 1] = e
+ outputs[#outputs + 1] = { -1.0 }
+ --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e))
+ end
+
+ -- Called in child process
+ local function train()
+ local log_thresh = params.rule.train.max_iterations / 10
+ local seen_nan = false
+
+ local function train_cb(iter, train_cost, value_cost)
+ if (iter * (params.rule.train.max_iterations / log_thresh)) % (params.rule.train.max_iterations) == 0 then
+ if train_cost ~= train_cost and not seen_nan then
+ -- We have nan :( try to log lot's of stuff to dig into a problem
+ seen_nan = true
+ rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
+ params.rule.prefix, params.set.name,
+ value_cost)
+ for i, e in ipairs(inputs) do
+ lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
+ debug_vec(e), outputs[i][1])
+ end
+ end
+
+ rspamd_logger.infox(rspamd_config,
+ "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
+ params.rule.prefix, params.set.name,
+ params.ann_key,
+ iter,
+ train_cost,
+ value_cost)
+ end
+ end
+
+ lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
+ params.rule.prefix, params.set.name)
+
+ local pca
+ if params.rule.max_inputs then
+ -- Train PCA in the main process, presumably it is not that long
+ lua_util.debugm(N, rspamd_config, "start PCA train for ANN %s:%s",
+ params.rule.prefix, params.set.name)
+ pca = learn_pca(inputs, params.rule.max_inputs)
+ end
+
+ lua_util.debugm(N, rspamd_config, "start neural train for ANN %s:%s",
+ params.rule.prefix, params.set.name)
+ local ret, err = pcall(train_ann.train1, train_ann,
+ inputs, outputs, {
+ lr = params.rule.train.learning_rate,
+ max_epoch = params.rule.train.max_iterations,
+ cb = train_cb,
+ pca = pca
+ })
+
+ if not ret then
+ rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
+ params.rule.prefix, params.set.name, err)
+
+ return nil
+ else
+ lua_util.debugm(N, rspamd_config, "finished neural train for ANN %s:%s",
+ params.rule.prefix, params.set.name)
+ end
+
+ local roc_thresholds = {}
+ if params.rule.roc_enabled then
+ local spam_threshold = get_roc_thresholds(train_ann,
+ inputs,
+ outputs,
+ 1 - params.rule.roc_misclassification_cost,
+ params.rule.roc_misclassification_cost)
+ local ham_threshold = get_roc_thresholds(train_ann,
+ inputs,
+ outputs,
+ params.rule.roc_misclassification_cost,
+ 1 - params.rule.roc_misclassification_cost)
+ roc_thresholds = { spam_threshold, ham_threshold }
+
+ rspamd_logger.messagex("ROC thresholds: (spam_threshold: %s, ham_threshold: %s)",
+ roc_thresholds[1], roc_thresholds[2])
+ end
+
+ if not seen_nan then
+ -- Convert to strings as ucl cannot rspamd_text properly
+ local pca_data
+ if pca then
+ pca_data = tostring(pca:save())
+ end
+ local out = {
+ ann_data = tostring(train_ann:save()),
+ pca_data = pca_data,
+ roc_thresholds = roc_thresholds,
+ }
+
+ local final_data = ucl.to_format(out, 'msgpack')
+ lua_util.debugm(N, rspamd_config, "subprocess for ANN %s:%s returned %s bytes",
+ params.rule.prefix, params.set.name, #final_data)
+ return final_data
+ else
+ return nil
+ end
+ end
+
+ params.set.learning_spawned = true
+
+ local function redis_save_cb(err)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
+ params.rule.prefix, params.set.name, params.ann_key, err)
+ lua_redis.redis_make_request_taskless(params.ev_base,
+ rspamd_config,
+ params.rule.redis,
+ nil,
+ false, -- is write
+ gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
+ 'HDEL', -- command
+ { params.ann_key, 'lock' }
+ )
+ else
+ rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
+ params.rule.prefix, params.set.name, params.set.ann.redis_key)
+ end
+ end
+
+ local function ann_trained(err, data)
+ params.set.learning_spawned = false
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
+ params.rule.prefix, params.set.name, err)
+ lua_redis.redis_make_request_taskless(params.ev_base,
+ rspamd_config,
+ params.rule.redis,
+ nil,
+ true, -- is write
+ gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
+ 'HDEL', -- command
+ { params.ann_key, 'lock' }
+ )
+ else
+ local parser = ucl.parser()
+ local ok, parse_err = parser:parse_text(data, 'msgpack')
+ assert(ok, parse_err)
+ local parsed = parser:get_object()
+ local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
+ local pca_data = parsed.pca_data
+ local roc_thresholds = parsed.roc_thresholds
+
+ fill_set_ann(params.set, params.ann_key)
+ if pca_data then
+ params.set.ann.pca = rspamd_tensor.load(pca_data)
+ pca_data = rspamd_util.zstd_compress(pca_data)
+ end
+
+ if roc_thresholds then
+ params.set.ann.roc_thresholds = roc_thresholds
+ end
+
+
+ -- Deserialise ANN from the child process
+ ann_trained = rspamd_kann.load(parsed.ann_data)
+ local version = (params.set.ann.version or 0) + 1
+ params.set.ann.version = version
+ params.set.ann.ann = ann_trained
+ params.set.ann.symbols = params.set.symbols
+ params.set.ann.redis_key = new_ann_key(params.rule, params.set, version)
+
+ local profile = {
+ symbols = params.set.symbols,
+ digest = params.set.digest,
+ redis_key = params.set.ann.redis_key,
+ version = version
+ }
+
+ local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+ local roc_thresholds_serialized = ucl.to_format(roc_thresholds, 'json-compact', true)
+
+ rspamd_logger.infox(rspamd_config,
+ 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
+ params.rule.prefix, params.set.name,
+ #data, #ann_data,
+ #(params.set.ann.pca or {}), #(pca_data or {}),
+ params.set.ann.redis_key, params.ann_key)
+
+ lua_redis.exec_redis_script(redis_script_id.save_unlock,
+ { ev_base = params.ev_base, is_write = true },
+ redis_save_cb,
+ { profile.redis_key,
+ redis_ann_prefix(params.rule, params.set.name),
+ ann_data,
+ profile_serialized,
+ tostring(params.rule.ann_expire),
+ tostring(os.time()),
+ params.ann_key, -- old key to unlock...
+ roc_thresholds_serialized,
+ pca_data,
+ })
+ end
+ end
+
+ if params.rule.max_inputs then
+ fill_set_ann(params.set, params.ann_key)
+ end
+
+ params.worker:spawn_process {
+ func = train,
+ on_complete = ann_trained,
+ proctitle = string.format("ANN train for %s/%s", params.rule.prefix, params.set.name),
+ }
+ -- Spawn learn and register lock extension
+ params.set.learning_spawned = true
+ register_lock_extender(params.rule, params.set, params.ev_base, params.ann_key)
+ return
+
+ end
+end
+
+-- This function is used to adjust profiles and allowed setting ids for each rule
+-- It must be called when all settings are already registered (e.g. at post-init for config)
+local function process_rules_settings()
+ local function process_settings_elt(rule, selt)
+ local profile = rule.profile[selt.name]
+ if profile then
+ -- Use static user defined profile
+ -- Ensure that we have an array...
+ lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
+ rule.prefix, selt.name, profile)
+ if not profile[1] then
+ profile = lua_util.keys(profile)
+ end
+ selt.symbols = profile
+ else
+ lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
+ rule.prefix, selt.name)
+ end
+
+ local function filter_symbols_predicate(sname)
+ if settings.blacklisted_symbols and settings.blacklisted_symbols[sname] then
+ return false
+ end
+ local fl = rspamd_config:get_symbol_flags(sname)
+ if fl then
+ fl = lua_util.list_to_hash(fl)
+
+ return not (fl.nostat or fl.idempotent or fl.skip or fl.composite)
+ end
+
+ return false
+ end
+
+ -- Generic stuff
+ if not profile then
+ -- Do filtering merely if we are using a dynamic profile
+ selt.symbols = fun.totable(fun.filter(filter_symbols_predicate, selt.symbols))
+ end
+
+ table.sort(selt.symbols)
+
+ selt.digest = lua_util.table_digest(selt.symbols)
+ selt.prefix = redis_ann_prefix(rule, selt.name)
+
+ rspamd_logger.messagex(rspamd_config,
+ 'use NN prefix for rule %s; settings id "%s"; symbols digest: "%s"',
+ selt.prefix, selt.name, selt.digest)
+
+ lua_redis.register_prefix(selt.prefix, N,
+ string.format('NN prefix for rule "%s"; settings id "%s"',
+ selt.prefix, selt.name), {
+ persistent = true,
+ type = 'zlist',
+ })
+ -- Versions
+ lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
+ string.format('NN storage for rule "%s"; settings id "%s"',
+ selt.prefix, selt.name), {
+ persistent = true,
+ type = 'hash',
+ })
+ lua_redis.register_prefix(selt.prefix .. '_\\d+_spam_set', N,
+ string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
+ selt.prefix, selt.name), {
+ persistent = true,
+ type = 'set',
+ })
+ lua_redis.register_prefix(selt.prefix .. '_\\d+_ham_set', N,
+ string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
+ rule.prefix, selt.name), {
+ persistent = true,
+ type = 'set',
+ })
+ end
+
+ for k, rule in pairs(settings.rules) do
+ if not rule.allowed_settings then
+ rule.allowed_settings = {}
+ elseif rule.allowed_settings == 'all' then
+ -- Extract all settings ids
+ rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
+ end
+
+ -- Convert to a map <setting_id> -> true
+ rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
+
+ -- Check if we can work without settings
+ if k == 'default' or type(rule.default) ~= 'boolean' then
+ rule.default = true
+ end
+
+ rule.settings = {}
+
+ if rule.default then
+ local default_settings = {
+ symbols = lua_settings.default_symbols(),
+ name = 'default'
+ }
+
+ process_settings_elt(rule, default_settings)
+ rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
+ end
+
+ -- Now, for each allowed settings, we store sorted symbols + digest
+ -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
+ for s, _ in pairs(rule.allowed_settings) do
+ -- Here, we have a name, set of symbols and
+ local settings_id = s
+ if type(settings_id) ~= 'number' then
+ settings_id = lua_settings.numeric_settings_id(s)
+ end
+ local selt = lua_settings.settings_by_id(settings_id)
+
+ local nelt = {
+ symbols = selt.symbols, -- Already sorted
+ name = selt.name
+ }
+
+ process_settings_elt(rule, nelt)
+ for id, ex in pairs(rule.settings) do
+ if type(ex) == 'table' then
+ if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
+ -- Equal symbols, add reference
+ lua_util.debugm(N, rspamd_config,
+ 'added reference from settings id %s to %s; same symbols',
+ nelt.name, ex.name)
+ rule.settings[settings_id] = id
+ nelt = nil
+ end
+ end
+ end
+
+ if nelt then
+ rule.settings[settings_id] = nelt
+ lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
+ nelt.name, settings_id, rule.prefix)
+ end
+ end
+ end
+end
+
+-- Extract settings element for a specific settings id
+local function get_rule_settings(task, rule)
+ local sid = task:get_settings_id() or -1
+ local set = rule.settings[sid]
+
+ if not set then
+ return nil
+ end
+
+ while type(set) == 'number' do
+ -- Reference to another settings!
+ set = rule.settings[set]
+ end
+
+ return set
+end
+
+local function result_to_vector(task, profile)
+ if not profile.zeros then
+ -- Fill zeros vector
+ local zeros = {}
+ for i = 1, meta_functions.count_metatokens() do
+ zeros[i] = 0.0
+ end
+ for _, _ in ipairs(profile.symbols) do
+ zeros[#zeros + 1] = 0.0
+ end
+ profile.zeros = zeros
+ end
+
+ local vec = lua_util.shallowcopy(profile.zeros)
+ local mt = meta_functions.rspamd_gen_metatokens(task)
+
+ for i, v in ipairs(mt) do
+ vec[i] = v
+ end
+
+ task:process_ann_tokens(profile.symbols, vec, #mt, 0.1)
+
+ return vec
+end
+
+return {
+ can_push_train_vector = can_push_train_vector,
+ create_ann = create_ann,
+ default_options = default_options,
+ gen_unlock_cb = gen_unlock_cb,
+ get_rule_settings = get_rule_settings,
+ load_scripts = load_scripts,
+ module_config = module_config,
+ new_ann_key = new_ann_key,
+ plugin_ver = plugin_ver,
+ process_rules_settings = process_rules_settings,
+ redis_ann_prefix = redis_ann_prefix,
+ redis_params = redis_params,
+ redis_script_id = redis_script_id,
+ result_to_vector = result_to_vector,
+ settings = settings,
+ spawn_train = spawn_train,
+}
diff --git a/lualib/plugins/rbl.lua b/lualib/plugins/rbl.lua
new file mode 100644
index 0000000..af5d6bd
--- /dev/null
+++ b/lualib/plugins/rbl.lua
@@ -0,0 +1,232 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local ts = require("tableshape").types
+local lua_maps = require "lua_maps"
+local lua_util = require "lua_util"
+
+-- Common RBL plugin definitions
+
+local check_types = {
+ from = {
+ connfilter = true,
+ },
+ received = {},
+ helo = {
+ connfilter = true,
+ },
+ urls = {},
+ content_urls = {},
+ numeric_urls = {},
+ emails = {},
+ replyto = {},
+ dkim = {},
+ rdns = {
+ connfilter = true,
+ },
+ selector = {
+ require_argument = true,
+ },
+}
+
+local default_options = {
+ ['default_enabled'] = true,
+ ['default_ipv4'] = true,
+ ['default_ipv6'] = true,
+ ['default_unknown'] = false,
+ ['default_dkim_domainonly'] = true,
+ ['default_emails_domainonly'] = false,
+ ['default_exclude_users'] = false,
+ ['default_exclude_local'] = true,
+ ['default_no_ip'] = false,
+ ['default_dkim_match_from'] = false,
+ ['default_selector_flatten'] = true,
+}
+
+local return_codes_schema = ts.map_of(
+ ts.string / string.upper, -- Symbol name
+ (
+ ts.array_of(ts.string) +
+ (ts.string / function(s)
+ return { s }
+ end) -- List of IP patterns
+ )
+)
+local return_bits_schema = ts.map_of(
+ ts.string / string.upper, -- Symbol name
+ (
+ ts.array_of(ts.number + ts.string / tonumber) +
+ (ts.string / function(s)
+ return { tonumber(s) }
+ end) +
+ (ts.number / function(s)
+ return { s }
+ end)
+ )
+)
+
+local rule_schema_tbl = {
+ content_urls = ts.boolean:is_optional(),
+ disable_monitoring = ts.boolean:is_optional(),
+ disabled = ts.boolean:is_optional(),
+ dkim = ts.boolean:is_optional(),
+ dkim_domainonly = ts.boolean:is_optional(),
+ dkim_match_from = ts.boolean:is_optional(),
+ emails = ts.boolean:is_optional(),
+ emails_delimiter = ts.string:is_optional(),
+ emails_domainonly = ts.boolean:is_optional(),
+ enabled = ts.boolean:is_optional(),
+ exclude_local = ts.boolean:is_optional(),
+ exclude_users = ts.boolean:is_optional(),
+ from = ts.boolean:is_optional(),
+ hash = ts.one_of { "sha1", "sha256", "sha384", "sha512", "md5", "blake2" }:is_optional(),
+ hash_format = ts.one_of { "hex", "base32", "base64" }:is_optional(),
+ hash_len = (ts.integer + ts.string / tonumber):is_optional(),
+ helo = ts.boolean:is_optional(),
+ ignore_default = ts.boolean:is_optional(), -- alias
+ ignore_defaults = ts.boolean:is_optional(),
+ ignore_url_whitelist = ts.boolean:is_optional(),
+ ignore_whitelist = ts.boolean:is_optional(),
+ ignore_whitelists = ts.boolean:is_optional(), -- alias
+ images = ts.boolean:is_optional(),
+ ipv4 = ts.boolean:is_optional(),
+ ipv6 = ts.boolean:is_optional(),
+ is_whitelist = ts.boolean:is_optional(),
+ local_exclude_ip_map = ts.string:is_optional(),
+ monitored_address = ts.string:is_optional(),
+ no_ip = ts.boolean:is_optional(),
+ process_script = ts.string:is_optional(),
+ random_monitored = ts.boolean:is_optional(),
+ rbl = ts.string,
+ rdns = ts.boolean:is_optional(),
+ received = ts.boolean:is_optional(),
+ received_flags = ts.array_of(ts.string):is_optional(),
+ received_max_pos = ts.number:is_optional(),
+ received_min_pos = ts.number:is_optional(),
+ received_nflags = ts.array_of(ts.string):is_optional(),
+ replyto = ts.boolean:is_optional(),
+ requests_limit = (ts.integer + ts.string / tonumber):is_optional(),
+ require_symbols = (
+ ts.array_of(ts.string) + (ts.string / function(s)
+ return { s }
+ end)
+ ):is_optional(),
+ resolve_ip = ts.boolean:is_optional(),
+ return_bits = return_bits_schema:is_optional(),
+ return_codes = return_codes_schema:is_optional(),
+ returnbits = return_bits_schema:is_optional(),
+ returncodes = return_codes_schema:is_optional(),
+ returncodes_matcher = ts.one_of { "equality", "glob", "luapattern", "radix", "regexp" }:is_optional(),
+ selector = ts.one_of { ts.string, ts.table }:is_optional(),
+ selector_flatten = ts.boolean:is_optional(),
+ symbol = ts.string:is_optional(),
+ symbols_prefixes = ts.map_of(ts.string, ts.string):is_optional(),
+ unknown = ts.boolean:is_optional(),
+ url_compose_map = lua_maps.map_schema:is_optional(),
+ url_full_hostname = ts.boolean:is_optional(),
+ url_whitelist = lua_maps.map_schema:is_optional(),
+ urls = ts.boolean:is_optional(),
+ whitelist = lua_maps.map_schema:is_optional(),
+ whitelist_exception = (
+ ts.array_of(ts.string) + (ts.string / function(s)
+ return { s }
+ end)
+ ):is_optional(),
+ checks = ts.array_of(ts.one_of(lua_util.keys(check_types))):is_optional(),
+ exclude_checks = ts.array_of(ts.one_of(lua_util.keys(check_types))):is_optional(),
+}
+
+local function convert_checks(rule, name)
+ local rspamd_logger = require "rspamd_logger"
+ if rule.checks then
+ local all_connfilter = true
+ local exclude_checks = lua_util.list_to_hash(rule.exclude_checks or {})
+ for _, check in ipairs(rule.checks) do
+ if not exclude_checks[check] then
+ local check_type = check_types[check]
+ if check_type.require_argument then
+ if not rule[check] then
+ rspamd_logger.errx(rspamd_config, 'rbl rule %s has check %s which requires an argument',
+ name, check)
+ return nil
+ end
+ end
+
+ rule[check] = check_type
+
+ if not check_type.connfilter then
+ all_connfilter = false
+ end
+
+ if not check_type then
+ rspamd_logger.errx(rspamd_config, 'rbl rule %s has invalid check type: %s',
+ name, check)
+ return nil
+ end
+ else
+ rspamd_logger.infox(rspamd_config, 'disable check %s in %s: excluded explicitly',
+ check, name)
+ end
+ end
+ rule.connfilter = all_connfilter
+ end
+
+ -- Now check if we have any check enabled at all
+ local check_found = false
+ for k, _ in pairs(check_types) do
+ if type(rule[k]) ~= 'nil' then
+ check_found = true
+ break
+ end
+ end
+
+ if not check_found then
+ -- Enable implicit `from` check to allow upgrade
+ rspamd_logger.warnx(rspamd_config, 'rbl rule %s has no check enabled, enable default `from` check',
+ name)
+ rule.from = true
+ end
+
+ if rule.returncodes and not rule.returncodes_matcher then
+ for _, v in pairs(rule.returncodes) do
+ for _, e in ipairs(v) do
+ if e:find('[%%%[]') then
+ rspamd_logger.warn(rspamd_config, 'implicitly enabling luapattern returncodes_matcher for rule %s', name)
+ rule.returncodes_matcher = 'luapattern'
+ break
+ end
+ end
+ if rule.returncodes_matcher then
+ break
+ end
+ end
+ end
+
+ return rule
+end
+
+
+-- Add default boolean flags to the schema
+for def_k, _ in pairs(default_options) do
+ rule_schema_tbl[def_k:sub(#('default_') + 1)] = ts.boolean:is_optional()
+end
+
+return {
+ check_types = check_types,
+ rule_schema = ts.shape(rule_schema_tbl),
+ default_options = default_options,
+ convert_checks = convert_checks,
+}
diff --git a/lualib/plugins_stats.lua b/lualib/plugins_stats.lua
new file mode 100644
index 0000000..2497fb9
--- /dev/null
+++ b/lualib/plugins_stats.lua
@@ -0,0 +1,48 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local ansicolors = require "ansicolors"
+
+local function printf(fmt, ...)
+ print(string.format(fmt, ...))
+end
+
+local function highlight(str)
+ return ansicolors.white .. str .. ansicolors.reset
+end
+
+local function print_plugins_table(tbl, what)
+ local mods = {}
+ for k, _ in pairs(tbl) do
+ table.insert(mods, k)
+ end
+
+ printf("Modules %s: %s", highlight(what), table.concat(mods, ", "))
+end
+
+return function(args, _)
+ print_plugins_table(rspamd_plugins_state.enabled, "enabled")
+ print_plugins_table(rspamd_plugins_state.disabled_explicitly,
+ "disabled (explicitly)")
+ print_plugins_table(rspamd_plugins_state.disabled_unconfigured,
+ "disabled (unconfigured)")
+ print_plugins_table(rspamd_plugins_state.disabled_redis,
+ "disabled (no Redis)")
+ print_plugins_table(rspamd_plugins_state.disabled_experimental,
+ "disabled (experimental)")
+ print_plugins_table(rspamd_plugins_state.disabled_failed,
+ "disabled (failed)")
+end \ No newline at end of file
diff --git a/lualib/redis_scripts/bayes_cache_check.lua b/lualib/redis_scripts/bayes_cache_check.lua
new file mode 100644
index 0000000..f1ffc2b
--- /dev/null
+++ b/lualib/redis_scripts/bayes_cache_check.lua
@@ -0,0 +1,20 @@
+-- Lua script to perform cache checking for bayes classification
+-- This script accepts the following parameters:
+-- key1 - cache id
+-- key2 - configuration table in message pack
+
+local cache_id = KEYS[1]
+local conf = cmsgpack.unpack(KEYS[2])
+cache_id = string.sub(cache_id, 1, conf.cache_elt_len)
+
+-- Try each prefix that is in Redis
+for i = 0, conf.cache_max_keys do
+ local prefix = conf.cache_prefix .. string.rep("X", i)
+ local have = redis.call('HGET', prefix, cache_id)
+
+ if have then
+ return tonumber(have)
+ end
+end
+
+return nil
diff --git a/lualib/redis_scripts/bayes_cache_learn.lua b/lualib/redis_scripts/bayes_cache_learn.lua
new file mode 100644
index 0000000..8811f3c
--- /dev/null
+++ b/lualib/redis_scripts/bayes_cache_learn.lua
@@ -0,0 +1,61 @@
+-- Lua script to perform cache checking for bayes classification
+-- This script accepts the following parameters:
+-- key1 - cache id
+-- key3 - is spam (1 or 0)
+-- key3 - configuration table in message pack
+
+local cache_id = KEYS[1]
+local is_spam = KEYS[2]
+local conf = cmsgpack.unpack(KEYS[3])
+cache_id = string.sub(cache_id, 1, conf.cache_elt_len)
+
+-- Try each prefix that is in Redis (as some other instance might have set it)
+for i = 0, conf.cache_max_keys do
+ local prefix = conf.cache_prefix .. string.rep("X", i)
+ local have = redis.call('HGET', prefix, cache_id)
+
+ if have then
+ -- Already in cache
+ return false
+ end
+end
+
+local added = false
+local lim = conf.cache_max_elt
+for i = 0, conf.cache_max_keys do
+ if not added then
+ local prefix = conf.cache_prefix .. string.rep("X", i)
+ local count = redis.call('HLEN', prefix)
+
+ if count < lim then
+ -- We can add it to this prefix
+ redis.call('HSET', prefix, cache_id, is_spam)
+ added = true
+ end
+ end
+end
+
+if not added then
+ -- Need to expire some keys
+ local expired = false
+ for i = 0, conf.cache_max_keys do
+ local prefix = conf.cache_prefix .. string.rep("X", i)
+ local exists = redis.call('EXISTS', prefix)
+
+ if exists then
+ if expired then
+ redis.call('DEL', prefix)
+ redis.call('HSET', prefix, cache_id, is_spam)
+
+ -- Do not expire anything else
+ expired = true
+ elseif i > 0 then
+ -- Move key to a shorter prefix, so we will rotate them eventually from lower to upper
+ local new_prefix = conf.cache_prefix .. string.rep("X", i - 1)
+ redis.call('RENAME', prefix, new_prefix)
+ end
+ end
+ end
+end
+
+return true \ No newline at end of file
diff --git a/lualib/redis_scripts/bayes_classify.lua b/lualib/redis_scripts/bayes_classify.lua
new file mode 100644
index 0000000..e94f645
--- /dev/null
+++ b/lualib/redis_scripts/bayes_classify.lua
@@ -0,0 +1,37 @@
+-- Lua script to perform bayes classification
+-- This script accepts the following parameters:
+-- key1 - prefix for bayes tokens (e.g. for per-user classification)
+-- key2 - set of tokens encoded in messagepack array of strings
+
+local prefix = KEYS[1]
+local output_spam = {}
+local output_ham = {}
+
+local learned_ham = tonumber(redis.call('HGET', prefix, 'learns_ham')) or 0
+local learned_spam = tonumber(redis.call('HGET', prefix, 'learns_spam')) or 0
+
+-- Output is a set of pairs (token_index, token_count), tokens that are not
+-- found are not filled.
+-- This optimisation will save a lot of space for sparse tokens, and in Bayes that assumption is normally held
+
+if learned_ham > 0 and learned_spam > 0 then
+ local input_tokens = cmsgpack.unpack(KEYS[2])
+ for i, token in ipairs(input_tokens) do
+ local token_data = redis.call('HMGET', token, 'H', 'S')
+
+ if token_data then
+ local ham_count = token_data[1]
+ local spam_count = token_data[2]
+
+ if ham_count then
+ table.insert(output_ham, { i, tonumber(ham_count) })
+ end
+
+ if spam_count then
+ table.insert(output_spam, { i, tonumber(spam_count) })
+ end
+ end
+ end
+end
+
+return { learned_ham, learned_spam, output_ham, output_spam } \ No newline at end of file
diff --git a/lualib/redis_scripts/bayes_learn.lua b/lualib/redis_scripts/bayes_learn.lua
new file mode 100644
index 0000000..80d86d8
--- /dev/null
+++ b/lualib/redis_scripts/bayes_learn.lua
@@ -0,0 +1,44 @@
+-- Lua script to perform bayes learning
+-- This script accepts the following parameters:
+-- key1 - prefix for bayes tokens (e.g. for per-user classification)
+-- key2 - boolean is_spam
+-- key3 - string symbol
+-- key4 - boolean is_unlearn
+-- key5 - set of tokens encoded in messagepack array of strings
+-- key6 - set of text tokens (if any) encoded in messagepack array of strings (size must be twice of `KEYS[5]`)
+
+local prefix = KEYS[1]
+local is_spam = KEYS[2] == 'true' and true or false
+local symbol = KEYS[3]
+local is_unlearn = KEYS[4] == 'true' and true or false
+local input_tokens = cmsgpack.unpack(KEYS[5])
+local text_tokens
+
+if KEYS[6] then
+ text_tokens = cmsgpack.unpack(KEYS[6])
+end
+
+local hash_key = is_spam and 'S' or 'H'
+local learned_key = is_spam and 'learns_spam' or 'learns_ham'
+
+redis.call('SADD', symbol .. '_keys', prefix)
+redis.call('HSET', prefix, 'version', '2') -- new schema
+redis.call('HINCRBY', prefix, learned_key, is_unlearn and -1 or 1) -- increase or decrease learned count
+
+for i, token in ipairs(input_tokens) do
+ redis.call('HINCRBY', token, hash_key, 1)
+ if text_tokens then
+ local tok1 = text_tokens[i * 2 - 1]
+ local tok2 = text_tokens[i * 2]
+
+ if tok1 then
+ if tok2 then
+ redis.call('HSET', token, 'tokens', string.format('%s:%s', tok1, tok2))
+ else
+ redis.call('HSET', token, 'tokens', tok1)
+ end
+
+ redis.call('ZINCRBY', prefix .. '_z', is_unlearn and -1 or 1, token)
+ end
+ end
+end \ No newline at end of file
diff --git a/lualib/redis_scripts/bayes_stat.lua b/lualib/redis_scripts/bayes_stat.lua
new file mode 100644
index 0000000..31e5128
--- /dev/null
+++ b/lualib/redis_scripts/bayes_stat.lua
@@ -0,0 +1,19 @@
+-- Lua script to perform bayes stats
+-- This script accepts the following parameters:
+-- key1 - current cursor
+-- key2 - symbol to examine
+-- key3 - learn key (e.g. learns_ham or learns_spam)
+-- key4 - max users
+
+local cursor = tonumber(KEYS[1])
+
+local ret = redis.call('SSCAN', KEYS[2] .. '_keys', cursor, 'COUNT', tonumber(KEYS[4]))
+
+local new_cursor = tonumber(ret[1])
+local nkeys = #ret[2]
+local learns = 0
+for _, key in ipairs(ret[2]) do
+ learns = learns + (tonumber(redis.call('HGET', key, KEYS[3])) or 0)
+end
+
+return { new_cursor, nkeys, learns } \ No newline at end of file
diff --git a/lualib/redis_scripts/neural_maybe_invalidate.lua b/lualib/redis_scripts/neural_maybe_invalidate.lua
new file mode 100644
index 0000000..517fa01
--- /dev/null
+++ b/lualib/redis_scripts/neural_maybe_invalidate.lua
@@ -0,0 +1,25 @@
+-- Lua script to invalidate ANNs by rank
+-- Uses the following keys
+-- key1 - prefix for keys
+-- key2 - number of elements to leave
+
+local card = redis.call('ZCARD', KEYS[1])
+local lim = tonumber(KEYS[2])
+if card > lim then
+ local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
+ if to_delete then
+ for _, k in ipairs(to_delete) do
+ local tb = cjson.decode(k)
+ if type(tb) == 'table' and type(tb.redis_key) == 'string' then
+ redis.call('DEL', tb.redis_key)
+ -- Also train vectors
+ redis.call('DEL', tb.redis_key .. '_spam_set')
+ redis.call('DEL', tb.redis_key .. '_ham_set')
+ end
+ end
+ end
+ redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
+ return to_delete
+else
+ return {}
+end \ No newline at end of file
diff --git a/lualib/redis_scripts/neural_maybe_lock.lua b/lualib/redis_scripts/neural_maybe_lock.lua
new file mode 100644
index 0000000..f705115
--- /dev/null
+++ b/lualib/redis_scripts/neural_maybe_lock.lua
@@ -0,0 +1,19 @@
+-- Lua script lock ANN for learning
+-- Uses the following keys
+-- key1 - prefix for keys
+-- key2 - current time
+-- key3 - key expire
+-- key4 - hostname
+
+local locked = redis.call('HGET', KEYS[1], 'lock')
+local now = tonumber(KEYS[2])
+if locked then
+ locked = tonumber(locked)
+ local expire = tonumber(KEYS[3])
+ if now > locked and (now - locked) < expire then
+ return { tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown' }
+ end
+end
+redis.call('HSET', KEYS[1], 'lock', tostring(now))
+redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
+return 1 \ No newline at end of file
diff --git a/lualib/redis_scripts/neural_save_unlock.lua b/lualib/redis_scripts/neural_save_unlock.lua
new file mode 100644
index 0000000..5af1ddc
--- /dev/null
+++ b/lualib/redis_scripts/neural_save_unlock.lua
@@ -0,0 +1,24 @@
+-- Lua script to save and unlock ANN in redis
+-- Uses the following keys
+-- key1 - prefix for ANN
+-- key2 - prefix for profile
+-- key3 - compressed ANN
+-- key4 - profile as JSON
+-- key5 - expire in seconds
+-- key6 - current time
+-- key7 - old key
+-- key8 - ROC Thresholds
+-- key9 - optional PCA
+local now = tonumber(KEYS[6])
+redis.call('ZADD', KEYS[2], now, KEYS[4])
+redis.call('HSET', KEYS[1], 'ann', KEYS[3])
+redis.call('DEL', KEYS[1] .. '_spam_set')
+redis.call('DEL', KEYS[1] .. '_ham_set')
+redis.call('HDEL', KEYS[1], 'lock')
+redis.call('HDEL', KEYS[7], 'lock')
+redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
+redis.call('HSET', KEYS[1], 'roc_thresholds', KEYS[8])
+if KEYS[9] then
+ redis.call('HSET', KEYS[1], 'pca', KEYS[9])
+end
+return 1 \ No newline at end of file
diff --git a/lualib/redis_scripts/neural_train_size.lua b/lualib/redis_scripts/neural_train_size.lua
new file mode 100644
index 0000000..45ad6a9
--- /dev/null
+++ b/lualib/redis_scripts/neural_train_size.lua
@@ -0,0 +1,24 @@
+-- Lua script that checks if we can store a new training vector
+-- Uses the following keys:
+-- key1 - ann key
+-- returns nspam,nham (or nil if locked)
+
+local prefix = KEYS[1]
+local locked = redis.call('HGET', prefix, 'lock')
+if locked then
+ local host = redis.call('HGET', prefix, 'hostname') or 'unknown'
+ return string.format('%s:%s', host, locked)
+end
+local nspam = 0
+local nham = 0
+
+local ret = redis.call('SCARD', prefix .. '_spam_set')
+if ret then
+ nspam = tonumber(ret)
+end
+ret = redis.call('SCARD', prefix .. '_ham_set')
+if ret then
+ nham = tonumber(ret)
+end
+
+return { nspam, nham } \ No newline at end of file
diff --git a/lualib/redis_scripts/ratelimit_check.lua b/lualib/redis_scripts/ratelimit_check.lua
new file mode 100644
index 0000000..d39cdf1
--- /dev/null
+++ b/lualib/redis_scripts/ratelimit_check.lua
@@ -0,0 +1,85 @@
+-- This Lua script is a rate limiter for Redis using the token bucket algorithm.
+-- The script checks if a message should be rate-limited and updates the bucket status accordingly.
+-- Input keys:
+-- KEYS[1]: A prefix for the Redis keys, e.g., RL_<triplet>_<seconds>
+-- KEYS[2]: The current time in milliseconds
+-- KEYS[3]: The bucket leak rate (messages per millisecond)
+-- KEYS[4]: The maximum allowed burst
+-- KEYS[5]: The expiration time for a bucket
+-- KEYS[6]: The number of recipients for the message
+
+-- Redis keys used:
+-- l: Last hit (time in milliseconds)
+-- b: Current burst (number of tokens in the bucket)
+-- p: Pending messages (number of messages in processing)
+-- dr: Current dynamic rate multiplier (*10000)
+-- db: Current dynamic burst multiplier (*10000)
+
+-- Returns:
+-- An array containing:
+-- 1. if the message should be rate-limited or 0 if not
+-- 2. The current burst value after processing the message
+-- 3. The dynamic rate multiplier
+-- 4. The dynamic burst multiplier
+-- 5. The number of tokens leaked during processing
+
+local last = redis.call('HGET', KEYS[1], 'l')
+local now = tonumber(KEYS[2])
+local nrcpt = tonumber(KEYS[6])
+local leak_rate = tonumber(KEYS[3])
+local max_burst = tonumber(KEYS[4])
+local prefix = KEYS[1]
+local dynr, dynb, leaked = 0, 0, 0
+if not last then
+ -- New bucket
+ redis.call('HMSET', prefix, 'l', tostring(now), 'b', '0', 'dr', '10000', 'db', '10000', 'p', tostring(nrcpt))
+ redis.call('EXPIRE', prefix, KEYS[5])
+ return { 0, '0', '1', '1', '0' }
+end
+last = tonumber(last)
+
+local burst, pending = unpack(redis.call('HMGET', prefix, 'b', 'p'))
+burst, pending = tonumber(burst or '0'), tonumber(pending or '0')
+-- Sanity to avoid races
+if burst < 0 then
+ burst = 0
+end
+if pending < 0 then
+ pending = 0
+end
+pending = pending + nrcpt -- this message
+-- Perform leak
+if burst + pending > 0 then
+ -- If we have any time passed
+ if burst > 0 and last < now then
+ dynr = tonumber(redis.call('HGET', prefix, 'dr')) / 10000.0
+ if dynr == 0 then
+ dynr = 0.0001
+ end
+ leak_rate = leak_rate * dynr
+ leaked = ((now - last) * leak_rate)
+ if leaked > burst then
+ leaked = burst
+ end
+ burst = burst - leaked
+ redis.call('HINCRBYFLOAT', prefix, 'b', -(leaked))
+ redis.call('HSET', prefix, 'l', tostring(now))
+ end
+
+ dynb = tonumber(redis.call('HGET', prefix, 'db')) / 10000.0
+ if dynb == 0 then
+ dynb = 0.0001
+ end
+
+ burst = burst + pending
+ if burst > 0 and burst > max_burst * dynb then
+ return { 1, tostring(burst - pending), tostring(dynr), tostring(dynb), tostring(leaked) }
+ end
+ -- Increase pending if we allow ratelimit
+ redis.call('HINCRBY', prefix, 'p', nrcpt)
+else
+ burst = 0
+ redis.call('HMSET', prefix, 'b', '0', 'p', tostring(nrcpt))
+end
+
+return { 0, tostring(burst), tostring(dynr), tostring(dynb), tostring(leaked) } \ No newline at end of file
diff --git a/lualib/redis_scripts/ratelimit_cleanup_pending.lua b/lualib/redis_scripts/ratelimit_cleanup_pending.lua
new file mode 100644
index 0000000..698a3ec
--- /dev/null
+++ b/lualib/redis_scripts/ratelimit_cleanup_pending.lua
@@ -0,0 +1,33 @@
+-- This script cleans up the pending requests in Redis.
+
+-- KEYS: Input parameters
+-- KEYS[1] - prefix: The Redis key prefix used to store the bucket information.
+-- KEYS[2] - now: The current time in milliseconds.
+-- KEYS[3] - expire: The expiration time for the Redis key storing the bucket information, in seconds.
+-- KEYS[4] - number_of_recipients: The number of requests to be allowed (or the increase rate).
+
+-- 1. Retrieve the last hit time and initialize variables
+local prefix = KEYS[1]
+local last = redis.call('HGET', prefix, 'l')
+local nrcpt = tonumber(KEYS[4])
+if not last then
+ -- No bucket, no cleanup
+ return 0
+end
+
+
+-- 2. Update the pending values based on the number of recipients (requests)
+local pending = redis.call('HGET', prefix, 'p')
+pending = tonumber(pending or '0')
+if pending < nrcpt then
+ pending = 0
+else
+ pending = pending - nrcpt
+end
+
+-- 3. Set the updated values back to Redis and update the expiration time for the bucket
+redis.call('HMSET', prefix, 'p', tostring(pending), 'l', KEYS[2])
+redis.call('EXPIRE', prefix, KEYS[3])
+
+-- 4. Return the updated pending value
+return pending \ No newline at end of file
diff --git a/lualib/redis_scripts/ratelimit_update.lua b/lualib/redis_scripts/ratelimit_update.lua
new file mode 100644
index 0000000..caee8fb
--- /dev/null
+++ b/lualib/redis_scripts/ratelimit_update.lua
@@ -0,0 +1,93 @@
+-- This script updates a token bucket rate limiter with dynamic rate and burst multipliers in Redis.
+
+-- KEYS: Input parameters
+-- KEYS[1] - prefix: The Redis key prefix used to store the bucket information.
+-- KEYS[2] - now: The current time in milliseconds.
+-- KEYS[3] - dynamic_rate_multiplier: A multiplier to adjust the rate limit dynamically.
+-- KEYS[4] - dynamic_burst_multiplier: A multiplier to adjust the burst limit dynamically.
+-- KEYS[5] - max_dyn_rate: The maximum allowed value for the dynamic rate multiplier.
+-- KEYS[6] - max_burst_rate: The maximum allowed value for the dynamic burst multiplier.
+-- KEYS[7] - expire: The expiration time for the Redis key storing the bucket information, in seconds.
+-- KEYS[8] - number_of_recipients: The number of requests to be allowed (or the increase rate).
+
+-- 1. Retrieve the last hit time and initialize variables
+local prefix = KEYS[1]
+local last = redis.call('HGET', prefix, 'l')
+local now = tonumber(KEYS[2])
+local nrcpt = tonumber(KEYS[8])
+if not last then
+ -- 2. Initialize a new bucket if the last hit time is not found (must not happen)
+ redis.call('HMSET', prefix, 'l', tostring(now), 'b', tostring(nrcpt), 'dr', '10000', 'db', '10000', 'p', '0')
+ redis.call('EXPIRE', prefix, KEYS[7])
+ return { 1, 1, 1 }
+end
+
+-- 3. Update the dynamic rate multiplier based on input parameters
+local dr, db = 1.0, 1.0
+
+local max_dr = tonumber(KEYS[5])
+
+if max_dr > 1 then
+ local rate_mult = tonumber(KEYS[3])
+ dr = tonumber(redis.call('HGET', prefix, 'dr')) / 10000
+
+ if rate_mult > 1.0 and dr < max_dr then
+ dr = dr * rate_mult
+ if dr > 0.0001 then
+ redis.call('HSET', prefix, 'dr', tostring(math.floor(dr * 10000)))
+ else
+ redis.call('HSET', prefix, 'dr', '1')
+ end
+ elseif rate_mult < 1.0 and dr > (1.0 / max_dr) then
+ dr = dr * rate_mult
+ if dr > 0.0001 then
+ redis.call('HSET', prefix, 'dr', tostring(math.floor(dr * 10000)))
+ else
+ redis.call('HSET', prefix, 'dr', '1')
+ end
+ end
+end
+
+-- 4. Update the dynamic burst multiplier based on input parameters
+local max_db = tonumber(KEYS[6])
+if max_db > 1 then
+ local rate_mult = tonumber(KEYS[4])
+ db = tonumber(redis.call('HGET', prefix, 'db')) / 10000
+
+ if rate_mult > 1.0 and db < max_db then
+ db = db * rate_mult
+ if db > 0.0001 then
+ redis.call('HSET', prefix, 'db', tostring(math.floor(db * 10000)))
+ else
+ redis.call('HSET', prefix, 'db', '1')
+ end
+ elseif rate_mult < 1.0 and db > (1.0 / max_db) then
+ db = db * rate_mult
+ if db > 0.0001 then
+ redis.call('HSET', prefix, 'db', tostring(math.floor(db * 10000)))
+ else
+ redis.call('HSET', prefix, 'db', '1')
+ end
+ end
+end
+
+-- 5. Update the burst and pending values based on the number of recipients (requests)
+local burst, pending = unpack(redis.call('HMGET', prefix, 'b', 'p'))
+burst, pending = tonumber(burst or '0'), tonumber(pending or '0')
+if burst < 0 then
+ burst = nrcpt
+else
+ burst = burst + nrcpt
+end
+if pending < nrcpt then
+ pending = 0
+else
+ pending = pending - nrcpt
+end
+
+-- 6. Set the updated values back to Redis and update the expiration time for the bucket
+redis.call('HMSET', prefix, 'b', tostring(burst), 'p', tostring(pending), 'l', KEYS[2])
+redis.call('EXPIRE', prefix, KEYS[7])
+
+-- 7. Return the updated burst value, dynamic rate multiplier, and dynamic burst multiplier
+return { tostring(burst), tostring(dr), tostring(db) } \ No newline at end of file
diff --git a/lualib/rspamadm/clickhouse.lua b/lualib/rspamadm/clickhouse.lua
new file mode 100644
index 0000000..b22d800
--- /dev/null
+++ b/lualib/rspamadm/clickhouse.lua
@@ -0,0 +1,528 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+local lua_clickhouse = require "lua_clickhouse"
+local lua_util = require "lua_util"
+local rspamd_http = require "rspamd_http"
+local rspamd_upstream_list = require "rspamd_upstream_list"
+local rspamd_logger = require "rspamd_logger"
+local ucl = require "ucl"
+
+local E = {}
+
+-- Define command line options
+local parser = argparse()
+ :name 'rspamadm clickhouse'
+ :description 'Retrieve information from Clickhouse'
+ :help_description_margin(30)
+ :command_target('command')
+ :require_command(true)
+
+parser:option '-c --config'
+ :description 'Path to config file'
+ :argname('config_file')
+ :default(rspamd_paths['CONFDIR'] .. '/rspamd.conf')
+parser:option '-d --database'
+ :description 'Name of Clickhouse database to use'
+ :argname('database')
+ :default('default')
+parser:flag '--no-ssl-verify'
+ :description 'Disable SSL verification'
+ :argname('no_ssl_verify')
+parser:mutex(
+ parser:option '-p --password'
+ :description 'Password to use for Clickhouse'
+ :argname('password'),
+ parser:flag '-a --ask-password'
+ :description 'Ask password from the terminal'
+ :argname('ask_password')
+)
+parser:option '-s --server'
+ :description 'Address[:port] to connect to Clickhouse with'
+ :argname('server')
+parser:option '-u --user'
+ :description 'Username to use for Clickhouse'
+ :argname('user')
+parser:option '--use-gzip'
+ :description 'Use Gzip with Clickhouse'
+ :argname('use_gzip')
+ :default(true)
+parser:flag '--use-https'
+ :description 'Use HTTPS with Clickhouse'
+ :argname('use_https')
+
+local neural_profile = parser:command 'neural_profile'
+ :description 'Generate symbols profile using data from Clickhouse'
+neural_profile:option '-w --where'
+ :description 'WHERE clause for Clickhouse query'
+ :argname('where')
+neural_profile:flag '-j --json'
+ :description 'Write output as JSON'
+ :argname('json')
+neural_profile:option '--days'
+ :description 'Number of days to collect stats for'
+ :argname('days')
+ :default('7')
+neural_profile:option '--limit -l'
+ :description 'Maximum rows to fetch per day'
+ :argname('limit')
+neural_profile:option '--settings-id'
+ :description 'Settings ID to query'
+ :argname('settings_id')
+ :default('')
+
+local neural_train = parser:command 'neural_train'
+ :description 'Train neural using data from Clickhouse'
+neural_train:option '--days'
+ :description 'Number of days to query data for'
+ :argname('days')
+ :default('7')
+neural_train:option '--column-name-digest'
+ :description 'Name of neural profile digest column in Clickhouse'
+ :argname('column_name_digest')
+ :default('NeuralDigest')
+neural_train:option '--column-name-vector'
+ :description 'Name of neural training vector column in Clickhouse'
+ :argname('column_name_vector')
+ :default('NeuralMpack')
+neural_train:option '--limit -l'
+ :description 'Maximum rows to fetch per day'
+ :argname('limit')
+neural_train:option '--profile -p'
+ :description 'Profile to use for training'
+ :argname('profile')
+ :default('default')
+neural_train:option '--rule -r'
+ :description 'Rule to train'
+ :argname('rule')
+ :default('default')
+neural_train:option '--spam -s'
+ :description 'WHERE clause to use for spam'
+ :argname('spam')
+ :default("Action == 'reject'")
+neural_train:option '--ham -h'
+ :description 'WHERE clause to use for ham'
+ :argname('ham')
+ :default('Score < 0')
+neural_train:option '--url -u'
+ :description 'URL to use for training'
+ :argname('url')
+ :default('http://127.0.0.1:11334/plugins/neural/learn')
+
+local http_params = {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+}
+
+local function load_config(config_file)
+ local _r, err = rspamd_config:load_ucl(config_file)
+
+ if not _r then
+ rspamd_logger.errx('cannot load %s: %s', config_file, err)
+ os.exit(1)
+ end
+
+ _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+ if not _r then
+ rspamd_logger.errx('cannot process %s: %s', config_file, err)
+ os.exit(1)
+ end
+
+ if not rspamd_config:init_modules() then
+ rspamd_logger.errx('cannot init modules when parsing %s', config_file)
+ os.exit(1)
+ end
+
+ rspamd_config:init_subsystem('symcache')
+end
+
+local function days_list(days)
+ -- Create list of days to query starting with yesterday
+ local query_days = {}
+ local previous_date = os.time() - 86400
+ local num_days = tonumber(days)
+ for _ = 1, num_days do
+ table.insert(query_days, os.date('%Y-%m-%d', previous_date))
+ previous_date = previous_date - 86400
+ end
+ return query_days
+end
+
+local function get_excluded_symbols(known_symbols, correlations, seen_total)
+ -- Walk results once to collect all symbols & count occurrences
+
+ local remove = {}
+ local known_symbols_list = {}
+ local composites = rspamd_config:get_all_opt('composites')
+ local all_symbols = rspamd_config:get_symbols()
+ local skip_flags = {
+ nostat = true,
+ skip = true,
+ idempotent = true,
+ composite = true,
+ }
+ for k, v in pairs(known_symbols) do
+ local lower_count, higher_count
+ if v.seen_spam > v.seen_ham then
+ lower_count = v.seen_ham
+ higher_count = v.seen_spam
+ else
+ lower_count = v.seen_spam
+ higher_count = v.seen_ham
+ end
+
+ if composites[k] then
+ remove[k] = 'composite symbol'
+ elseif lower_count / higher_count >= 0.95 then
+ remove[k] = 'weak ham/spam correlation'
+ elseif v.seen / seen_total >= 0.9 then
+ remove[k] = 'omnipresent symbol'
+ elseif not all_symbols[k] then
+ remove[k] = 'nonexistent symbol'
+ else
+ for fl, _ in pairs(all_symbols[k].flags or {}) do
+ if skip_flags[fl] then
+ remove[k] = fl .. ' symbol'
+ break
+ end
+ end
+ end
+ known_symbols_list[v.id] = {
+ seen = v.seen,
+ name = k,
+ }
+ end
+
+ -- Walk correlation matrix and check total counts
+ for sym_id, row in pairs(correlations) do
+ for inner_sym_id, count in pairs(row) do
+ local known = known_symbols_list[sym_id]
+ local inner = known_symbols_list[inner_sym_id]
+ if known and count == known.seen and not remove[inner.name] and not remove[known.name] then
+ remove[known.name] = string.format("overlapped by %s",
+ known_symbols_list[inner_sym_id].name)
+ end
+ end
+ end
+
+ return remove
+end
+
+local function handle_neural_profile(args)
+
+ local known_symbols, correlations = {}, {}
+ local symbols_count, seen_total = 0, 0
+
+ local function process_row(r)
+ local is_spam = true
+ if r['Action'] == 'no action' or r['Action'] == 'greylist' then
+ is_spam = false
+ end
+ seen_total = seen_total + 1
+
+ local nsym = #r['Symbols.Names']
+
+ for i = 1, nsym do
+ local sym = r['Symbols.Names'][i]
+ local t = known_symbols[sym]
+ if not t then
+ local spam_count, ham_count = 0, 0
+ if is_spam then
+ spam_count = spam_count + 1
+ else
+ ham_count = ham_count + 1
+ end
+ known_symbols[sym] = {
+ id = symbols_count,
+ seen = 1,
+ seen_ham = ham_count,
+ seen_spam = spam_count,
+ }
+ symbols_count = symbols_count + 1
+ else
+ known_symbols[sym].seen = known_symbols[sym].seen + 1
+ if is_spam then
+ known_symbols[sym].seen_spam = known_symbols[sym].seen_spam + 1
+ else
+ known_symbols[sym].seen_ham = known_symbols[sym].seen_ham + 1
+ end
+ end
+ end
+
+ -- Fill correlations
+ for i = 1, nsym do
+ for j = 1, nsym do
+ if i ~= j then
+ local sym = r['Symbols.Names'][i]
+ local inner_sym_name = r['Symbols.Names'][j]
+ local known_sym = known_symbols[sym]
+ local inner_sym = known_symbols[inner_sym_name]
+ if known_sym and inner_sym then
+ if not correlations[known_sym.id] then
+ correlations[known_sym.id] = {}
+ end
+ local n = correlations[known_sym.id][inner_sym.id] or 0
+ n = n + 1
+ correlations[known_sym.id][inner_sym.id] = n
+ end
+ end
+ end
+ end
+ end
+
+ local query_days = days_list(args.days)
+ local conditions = {}
+ table.insert(conditions, string.format("SettingsId = '%s'", args.settings_id))
+ local limit = ''
+ local num_limit = tonumber(args.limit)
+ if num_limit then
+ limit = string.format(' LIMIT %d', num_limit) -- Contains leading space
+ end
+ if args.where then
+ table.insert(conditions, args.where)
+ end
+
+ local query_fmt = 'SELECT Action, Symbols.Names FROM rspamd WHERE %s%s'
+ for _, query_day in ipairs(query_days) do
+ -- Date should be the last condition
+ table.insert(conditions, string.format("Date = '%s'", query_day))
+ local query = string.format(query_fmt, table.concat(conditions, ' AND '), limit)
+ local upstream = args.upstream:get_upstream_round_robin()
+ local err = lua_clickhouse.select_sync(upstream, args, http_params, query, process_row)
+ if err ~= nil then
+ io.stderr:write(string.format('Error querying Clickhouse: %s\n', err))
+ os.exit(1)
+ end
+ conditions[#conditions] = nil -- remove Date condition
+ end
+
+ local remove = get_excluded_symbols(known_symbols, correlations, seen_total)
+ if not args.json then
+ for k in pairs(known_symbols) do
+ if not remove[k] then
+ io.stdout:write(string.format('%s\n', k))
+ end
+ end
+ os.exit(0)
+ end
+
+ local json_output = {
+ all_symbols = {},
+ removed_symbols = {},
+ used_symbols = {},
+ }
+ for k in pairs(known_symbols) do
+ table.insert(json_output.all_symbols, k)
+ local why_removed = remove[k]
+ if why_removed then
+ json_output.removed_symbols[k] = why_removed
+ else
+ table.insert(json_output.used_symbols, k)
+ end
+ end
+ io.stdout:write(ucl.to_format(json_output, 'json'))
+end
+
+local function post_neural_training(url, rule, spam_rows, ham_rows)
+ -- Prepare JSON payload
+ local payload = ucl.to_format(
+ {
+ ham_vec = ham_rows,
+ rule = rule,
+ spam_vec = spam_rows,
+ }, 'json')
+
+ -- POST the payload
+ local err, response = rspamd_http.request({
+ body = payload,
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ log_obj = rspamd_config,
+ resolver = rspamadm_dns_resolver,
+ session = rspamadm_session,
+ url = url,
+ })
+
+ if err then
+ io.stderr:write(string.format('HTTP error: %s\n', err))
+ os.exit(1)
+ end
+ if response.code ~= 200 then
+ io.stderr:write(string.format('bad HTTP code: %d\n', response.code))
+ os.exit(1)
+ end
+ io.stdout:write(string.format('%s\n', response.content))
+end
+
+local function handle_neural_train(args)
+
+ local this_where -- which class of messages are we collecting data for
+ local ham_rows, spam_rows = {}, {}
+ local want_spam, want_ham = true, true -- keep collecting while true
+
+ -- Try find profile in config
+ local neural_opts = rspamd_config:get_all_opt('neural')
+ local symbols_profile = ((((neural_opts or E).rules or E)[args.rule] or E).profile or E)[args.profile]
+ if not symbols_profile then
+ io.stderr:write(string.format("Couldn't find profile %s in rule %s\n", args.profile, args.rule))
+ os.exit(1)
+ end
+ -- Try find max_trains
+ local max_trains = (neural_opts.rules[args.rule].train or E).max_trains or 1000
+
+ -- Callback used to process rows from Clickhouse
+ local function process_row(r)
+ local destination -- which table to collect this information in
+ if this_where == args.ham then
+ destination = ham_rows
+ if #destination >= max_trains then
+ want_ham = false
+ return
+ end
+ else
+ destination = spam_rows
+ if #destination >= max_trains then
+ want_spam = false
+ return
+ end
+ end
+ local ucl_parser = ucl.parser()
+ local ok, err = ucl_parser:parse_string(r[args.column_name_vector], 'msgpack')
+ if not ok then
+ io.stderr:write(string.format("Couldn't parse [%s]: %s", r[args.column_name_vector], err))
+ os.exit(1)
+ end
+ table.insert(destination, ucl_parser:get_object())
+ end
+
+ -- Generate symbols digest
+ table.sort(symbols_profile)
+ local symbols_digest = lua_util.table_digest(symbols_profile)
+ -- Create list of days to query data for
+ local query_days = days_list(args.days)
+ -- Set value for limit
+ local limit = ''
+ local num_limit = tonumber(args.limit)
+ if num_limit then
+ limit = string.format(' LIMIT %d', num_limit) -- Contains leading space
+ end
+ -- Prepare query elements
+ local conditions = { string.format("%s = '%s'", args.column_name_digest, symbols_digest) }
+ local query_fmt = 'SELECT %s FROM rspamd WHERE %s%s'
+
+ -- Run queries
+ for _, the_where in ipairs({ args.ham, args.spam }) do
+ -- Inform callback which group of vectors we're collecting
+ this_where = the_where
+ table.insert(conditions, the_where) -- should be 2nd from last condition
+ -- Loop over days and try collect data
+ for _, query_day in ipairs(query_days) do
+ -- Break the loop if we have enough data already
+ if this_where == args.ham then
+ if not want_ham then
+ break
+ end
+ else
+ if not want_spam then
+ break
+ end
+ end
+ -- Date should be the last condition
+ table.insert(conditions, string.format("Date = '%s'", query_day))
+ local query = string.format(query_fmt, args.column_name_vector, table.concat(conditions, ' AND '), limit)
+ local upstream = args.upstream:get_upstream_round_robin()
+ local err = lua_clickhouse.select_sync(upstream, args, http_params, query, process_row)
+ if err ~= nil then
+ io.stderr:write(string.format('Error querying Clickhouse: %s\n', err))
+ os.exit(1)
+ end
+ conditions[#conditions] = nil -- remove Date condition
+ end
+ conditions[#conditions] = nil -- remove spam/ham condition
+ end
+
+ -- Make sure we collected enough data for training
+ if #ham_rows < max_trains then
+ io.stderr:write(string.format('Insufficient ham rows: %d/%d\n', #ham_rows, max_trains))
+ os.exit(1)
+ end
+ if #spam_rows < max_trains then
+ io.stderr:write(string.format('Insufficient spam rows: %d/%d\n', #spam_rows, max_trains))
+ os.exit(1)
+ end
+
+ return post_neural_training(args.url, args.rule, spam_rows, ham_rows)
+end
+
+local command_handlers = {
+ neural_profile = handle_neural_profile,
+ neural_train = handle_neural_train,
+}
+
+local function handler(args)
+ local cmd_opts = parser:parse(args)
+
+ load_config(cmd_opts.config_file)
+ local cfg_opts = rspamd_config:get_all_opt('clickhouse')
+
+ if cmd_opts.ask_password then
+ local rspamd_util = require "rspamd_util"
+
+ io.write('Password: ')
+ cmd_opts.password = rspamd_util.readpassphrase()
+ end
+
+ local function override_settings(params)
+ for _, which in ipairs(params) do
+ if cmd_opts[which] == nil then
+ cmd_opts[which] = cfg_opts[which]
+ end
+ end
+ end
+
+ override_settings({
+ 'database', 'no_ssl_verify', 'password', 'server',
+ 'use_gzip', 'use_https', 'user',
+ })
+
+ local servers = cmd_opts['server'] or cmd_opts['servers']
+ if not servers then
+ parser:error("server(s) unspecified & couldn't be fetched from config")
+ end
+
+ cmd_opts.upstream = rspamd_upstream_list.create(rspamd_config, servers, 8123)
+
+ if not cmd_opts.upstream then
+ io.stderr:write(string.format("can't parse clickhouse address: %s\n", servers))
+ os.exit(1)
+ end
+
+ local f = command_handlers[cmd_opts.command]
+ if not f then
+ parser:error(string.format("command isn't implemented: %s",
+ cmd_opts.command))
+ end
+ f(cmd_opts)
+end
+
+return {
+ handler = handler,
+ description = parser._description,
+ name = 'clickhouse'
+}
diff --git a/lualib/rspamadm/configgraph.lua b/lualib/rspamadm/configgraph.lua
new file mode 100644
index 0000000..07f14a9
--- /dev/null
+++ b/lualib/rspamadm/configgraph.lua
@@ -0,0 +1,172 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local rspamd_logger = require "rspamd_logger"
+local rspamd_util = require "rspamd_util"
+local rspamd_regexp = require "rspamd_regexp"
+local argparse = require "argparse"
+
+-- Define command line options
+local parser = argparse()
+ :name "rspamadm configgraph"
+ :description "Produces graph of Rspamd includes"
+ :help_description_margin(30)
+parser:option "-c --config"
+ :description "Path to config file"
+ :argname("<file>")
+ :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+parser:flag "-a --all"
+ :description('Show all nodes, not just existing ones')
+
+local function process_filename(fname)
+ local cdir = rspamd_paths['CONFDIR'] .. '/'
+ fname = fname:gsub(cdir, '')
+ return fname
+end
+
+local function output_dot(opts, nodes, adjacency)
+ rspamd_logger.messagex("digraph rspamd {")
+ for k, node in pairs(nodes) do
+ local attrs = { "shape=box" }
+ local skip = false
+ if node.exists then
+ if node.priority >= 10 then
+ attrs[#attrs + 1] = "color=red"
+ elseif node.priority > 0 then
+ attrs[#attrs + 1] = "color=blue"
+ end
+ else
+ if opts.all then
+ attrs[#attrs + 1] = "style=dotted"
+ else
+ skip = true
+ end
+ end
+
+ if not skip then
+ rspamd_logger.messagex("\"%s\" [%s];", process_filename(k),
+ table.concat(attrs, ','))
+ end
+ end
+ for _, adj in ipairs(adjacency) do
+ local attrs = {}
+ local skip = false
+
+ if adj.to.exists then
+ if adj.to.merge then
+ attrs[#attrs + 1] = "arrowhead=diamond"
+ attrs[#attrs + 1] = "label=\"+\""
+ elseif adj.to.priority > 1 then
+ attrs[#attrs + 1] = "color=red"
+ end
+ else
+ if opts.all then
+ attrs[#attrs + 1] = "style=dotted"
+ else
+ skip = true
+ end
+ end
+
+ if not skip then
+ rspamd_logger.messagex("\"%s\" -> \"%s\" [%s];", process_filename(adj.from),
+ adj.to.short_path, table.concat(attrs, ','))
+ end
+ end
+ rspamd_logger.messagex("}")
+end
+
+local function load_config_traced(opts)
+ local glob_traces = {}
+ local adjacency = {}
+ local nodes = {}
+
+ local function maybe_match_glob(file)
+ for _, gl in ipairs(glob_traces) do
+ if gl.re:match(file) then
+ return gl
+ end
+ end
+
+ return nil
+ end
+
+ local function add_dep(from, node, args)
+ adjacency[#adjacency + 1] = {
+ from = from,
+ to = node,
+ args = args
+ }
+ end
+
+ local function process_node(fname, args)
+ local node = nodes[fname]
+ if not node then
+ node = {
+ path = fname,
+ short_path = process_filename(fname),
+ exists = rspamd_util.file_exists(fname),
+ merge = args.duplicate and args.duplicate == 'merge',
+ priority = args.priority or 0,
+ glob = args.glob,
+ try = args.try,
+ }
+ nodes[fname] = node
+ end
+
+ return node
+ end
+
+ local function trace_func(cur_file, included_file, args, parent)
+ if args.glob then
+ glob_traces[#glob_traces + 1] = {
+ re = rspamd_regexp.import_glob(included_file, ''),
+ parent = cur_file,
+ args = args,
+ seen = {},
+ }
+ else
+ local node = process_node(included_file, args)
+ if opts.all or node.exists then
+ local gl_parent = maybe_match_glob(included_file)
+ if gl_parent and not gl_parent.seen[cur_file] then
+ add_dep(gl_parent.parent, nodes[cur_file], gl_parent.args)
+ gl_parent.seen[cur_file] = true
+ end
+ add_dep(cur_file, node, args)
+ end
+ end
+ end
+
+ local _r, err = rspamd_config:load_ucl(opts['config'], trace_func)
+ if not _r then
+ rspamd_logger.errx('cannot parse %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
+ output_dot(opts, nodes, adjacency)
+end
+
+local function handler(args)
+ local res = parser:parse(args)
+
+ load_config_traced(res)
+end
+
+return {
+ handler = handler,
+ description = parser._description,
+ name = 'configgraph'
+} \ No newline at end of file
diff --git a/lualib/rspamadm/confighelp.lua b/lualib/rspamadm/confighelp.lua
new file mode 100644
index 0000000..38b26b6
--- /dev/null
+++ b/lualib/rspamadm/confighelp.lua
@@ -0,0 +1,123 @@
+local opts
+local known_attrs = {
+ data = 1,
+ example = 1,
+ type = 1,
+ required = 1,
+ default = 1,
+}
+local argparse = require "argparse"
+local ansicolors = require "ansicolors"
+
+local parser = argparse()
+ :name "rspamadm confighelp"
+ :description "Shows help for the specified configuration options"
+ :help_description_margin(32)
+parser:argument "path":args "*"
+ :description('Optional config paths')
+parser:flag "--no-color"
+ :description "Disable coloured output"
+parser:flag "--short"
+ :description "Show only option names"
+parser:flag "--no-examples"
+ :description "Do not show examples (implied by --short)"
+
+local function maybe_print_color(key)
+ if not opts['no-color'] then
+ return ansicolors.white .. key .. ansicolors.reset
+ else
+ return key
+ end
+end
+
+local function sort_values(tbl)
+ local res = {}
+ for k, v in pairs(tbl) do
+ table.insert(res, { key = k, value = v })
+ end
+
+ -- Sort order
+ local order = {
+ options = 1,
+ dns = 2,
+ upstream = 3,
+ logging = 4,
+ metric = 5,
+ composite = 6,
+ classifier = 7,
+ modules = 8,
+ lua = 9,
+ worker = 10,
+ workers = 11,
+ }
+
+ table.sort(res, function(a, b)
+ local oa = order[a['key']]
+ local ob = order[b['key']]
+
+ if oa and ob then
+ return oa < ob
+ elseif oa then
+ return -1 < 0
+ elseif ob then
+ return 1 < 0
+ else
+ return a['key'] < b['key']
+ end
+
+ end)
+
+ return res
+end
+
+local function print_help(key, value, tabs)
+ print(string.format('%sConfiguration element: %s', tabs, maybe_print_color(key)))
+
+ if not opts['short'] then
+ if value['data'] then
+ local nv = string.match(value['data'], '^#%s*(.*)%s*$') or value.data
+ print(string.format('%s\tDescription: %s', tabs, nv))
+ end
+ if type(value['type']) == 'string' then
+ print(string.format('%s\tType: %s', tabs, value['type']))
+ end
+ if type(value['required']) == 'boolean' then
+ if value['required'] then
+ print(string.format('%s\tRequired: %s', tabs,
+ maybe_print_color(tostring(value['required']))))
+ else
+ print(string.format('%s\tRequired: %s', tabs,
+ tostring(value['required'])))
+ end
+ end
+ if value['default'] then
+ print(string.format('%s\tDefault: %s', tabs, value['default']))
+ end
+ if not opts['no-examples'] and value['example'] then
+ local nv = string.match(value['example'], '^%s*(.*[^%s])%s*$') or value.example
+ print(string.format('%s\tExample:\n%s', tabs, nv))
+ end
+ if value.type and value.type == 'object' then
+ print('')
+ end
+ end
+
+ local sorted = sort_values(value)
+ for _, v in ipairs(sorted) do
+ if not known_attrs[v['key']] then
+ -- We need to go deeper
+ print_help(v['key'], v['value'], tabs .. '\t')
+ end
+ end
+end
+
+return function(args, res)
+ opts = parser:parse(args)
+
+ local sorted = sort_values(res)
+
+ for _, v in ipairs(sorted) do
+ print_help(v['key'], v['value'], '')
+ print('')
+ end
+end
diff --git a/lualib/rspamadm/configwizard.lua b/lualib/rspamadm/configwizard.lua
new file mode 100644
index 0000000..2637036
--- /dev/null
+++ b/lualib/rspamadm/configwizard.lua
@@ -0,0 +1,849 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local ansicolors = require "ansicolors"
+local local_conf = rspamd_paths['CONFDIR']
+local rspamd_util = require "rspamd_util"
+local rspamd_logger = require "rspamd_logger"
+local lua_util = require "lua_util"
+local lua_stat_tools = require "lua_stat"
+local lua_redis = require "lua_redis"
+local ucl = require "ucl"
+local argparse = require "argparse"
+local fun = require "fun"
+
+local plugins_stat = require "plugins_stats"
+
+local rspamd_logo = [[
+ ____ _
+ | _ \ ___ _ __ __ _ _ __ ___ __| |
+ | |_) |/ __|| '_ \ / _` || '_ ` _ \ / _` |
+ | _ < \__ \| |_) || (_| || | | | | || (_| |
+ |_| \_\|___/| .__/ \__,_||_| |_| |_| \__,_|
+ |_|
+]]
+
+local parser = argparse()
+ :name "rspamadm configwizard"
+ :description "Perform guided configuration for Rspamd daemon"
+ :help_description_margin(32)
+parser:option "-c --config"
+ :description "Path to config file"
+ :argname("<file>")
+ :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+parser:argument "checks"
+ :description "Checks to do (or 'list')"
+ :argname("<checks>")
+ :args "*"
+
+local redis_params
+
+local function printf(fmt, ...)
+ if fmt then
+ io.write(string.format(fmt, ...))
+ end
+ io.write('\n')
+end
+
+local function highlight(str)
+ return ansicolors.white .. str .. ansicolors.reset
+end
+
+local function ask_yes_no(greet, default)
+ local def_str
+ if default then
+ greet = greet .. "[Y/n]: "
+ def_str = "yes"
+ else
+ greet = greet .. "[y/N]: "
+ def_str = "no"
+ end
+
+ local reply = rspamd_util.readline(greet)
+
+ if not reply then
+ os.exit(0)
+ end
+ if #reply == 0 then
+ reply = def_str
+ end
+ reply = reply:lower()
+ if reply == 'y' or reply == 'yes' then
+ return true
+ end
+
+ return false
+end
+
+local function readline_default(greet, def_value)
+ local reply = rspamd_util.readline(greet)
+ if not reply then
+ os.exit(0)
+ end
+
+ if #reply == 0 then
+ return def_value
+ end
+
+ return reply
+end
+
+local function readline_expire()
+ local expire = '100d'
+ repeat
+ expire = readline_default("Expire time for new tokens [" .. expire .. "]: ",
+ expire)
+ expire = lua_util.parse_time_interval(expire)
+
+ if not expire then
+ expire = '100d'
+ elseif expire > 2147483647 then
+ printf("The maximum possible value is 2147483647 (about 68y)")
+ expire = '68y'
+ elseif expire < -1 then
+ printf("The value must be a non-negative integer or -1")
+ expire = -1
+ elseif expire ~= math.floor(expire) then
+ printf("The value must be an integer")
+ expire = math.floor(expire)
+ else
+ return expire
+ end
+ until false
+end
+
+local function print_changes(changes)
+ local function print_change(k, c, where)
+ printf('File: %s, changes list:', highlight(local_conf .. '/'
+ .. where .. '/' .. k))
+
+ for ek, ev in pairs(c) do
+ printf("%s => %s", highlight(ek), rspamd_logger.slog("%s", ev))
+ end
+ end
+ for k, v in pairs(changes.l) do
+ print_change(k, v, 'local.d')
+ if changes.o[k] then
+ v = changes.o[k]
+ print_change(k, v, 'override.d')
+ end
+ print()
+ end
+end
+
+local function apply_changes(changes)
+ local function dirname(fname)
+ if fname:match(".-/.-") then
+ return string.gsub(fname, "(.*/)(.*)", "%1")
+ else
+ return nil
+ end
+ end
+
+ local function apply_change(k, c, where)
+ local fname = local_conf .. '/' .. where .. '/' .. k
+
+ if not rspamd_util.file_exists(fname) then
+ printf("Create file %s", highlight(fname))
+
+ local dname = dirname(fname)
+
+ if dname then
+ local ret, err = rspamd_util.mkdir(dname, true)
+
+ if not ret then
+ printf("Cannot make directory %s: %s", dname, highlight(err))
+ os.exit(1)
+ end
+ end
+ end
+
+ local f = io.open(fname, "a+")
+
+ if not f then
+ printf("Cannot open file %s, aborting", highlight(fname))
+ os.exit(1)
+ end
+
+ f:write(ucl.to_config(c))
+
+ f:close()
+ end
+ for k, v in pairs(changes.l) do
+ apply_change(k, v, 'local.d')
+ if changes.o[k] then
+ v = changes.o[k]
+ apply_change(k, v, 'override.d')
+ end
+ end
+end
+
+local function setup_controller(controller, changes)
+ printf("Setup %s and controller worker:", highlight("WebUI"))
+
+ if not controller.password or controller.password == 'q1' then
+ if ask_yes_no("Controller password is not set, do you want to set one?", true) then
+ local pw_encrypted = rspamadm.pw_encrypt()
+ if pw_encrypted then
+ printf("Set encrypted password to: %s", highlight(pw_encrypted))
+ changes.l['worker-controller.inc'] = {
+ password = pw_encrypted
+ }
+ end
+ end
+ end
+end
+
+local function setup_redis(cfg, changes)
+ local function parse_servers(servers)
+ local ls = lua_util.rspamd_str_split(servers, ",")
+
+ return ls
+ end
+
+ printf("%s servers are not set:", highlight("Redis"))
+ printf("The following modules will be enabled if you add Redis servers:")
+
+ for k, _ in pairs(rspamd_plugins_state.disabled_redis) do
+ printf("\t* %s", highlight(k))
+ end
+
+ if ask_yes_no("Do you wish to set Redis servers?", true) then
+ local read_servers = readline_default("Input read only servers separated by `,` [default: localhost]: ",
+ "localhost")
+
+ local rs = parse_servers(read_servers)
+ if rs and #rs > 0 then
+ changes.l['redis.conf'] = {
+ read_servers = table.concat(rs, ",")
+ }
+ end
+ local write_servers = readline_default("Input write only servers separated by `,` [default: "
+ .. read_servers .. "]: ", read_servers)
+
+ if not write_servers or #write_servers == 0 then
+ printf("Use read servers %s as write servers", highlight(table.concat(rs, ",")))
+ write_servers = read_servers
+ end
+
+ redis_params = {
+ read_servers = rs,
+ }
+
+ local ws = parse_servers(write_servers)
+ if ws and #ws > 0 then
+ changes.l['redis.conf']['write_servers'] = table.concat(ws, ",")
+ redis_params['write_servers'] = ws
+ end
+
+ if ask_yes_no('Do you have any username set for your Redis (ACL SETUSER and Redis 6.0+)') then
+ local username = readline_default("Enter Redis username:", nil)
+
+ if username then
+ changes.l['redis.conf'].username = username
+ redis_params.username = username
+ end
+
+ local passwd = readline_default("Enter Redis password:", nil)
+
+ if passwd then
+ changes.l['redis.conf']['password'] = passwd
+ redis_params['password'] = passwd
+ end
+ elseif ask_yes_no('Do you have any password set for your Redis?') then
+ local passwd = readline_default("Enter Redis password:", nil)
+
+ if passwd then
+ changes.l['redis.conf']['password'] = passwd
+ redis_params['password'] = passwd
+ end
+ end
+
+ if ask_yes_no('Do you have any specific database for your Redis?') then
+ local db = readline_default("Enter Redis database:", nil)
+
+ if db then
+ changes.l['redis.conf']['db'] = db
+ redis_params['db'] = db
+ end
+ end
+ end
+end
+
+local function setup_dkim_signing(cfg, changes)
+ -- Remove the trailing slash of a pathname, if present.
+ local function remove_trailing_slash(path)
+ if string.sub(path, -1) ~= "/" then
+ return path
+ end
+ return string.sub(path, 1, string.len(path) - 1)
+ end
+
+ printf('How would you like to set up DKIM signing?')
+ printf('1. Use domain from %s for sign', highlight('mime from header'))
+ printf('2. Use domain from %s for sign', highlight('SMTP envelope from'))
+ printf('3. Use domain from %s for sign', highlight('authenticated user'))
+ printf('4. Sign all mail from %s', highlight('specific networks'))
+ printf()
+
+ local sign_type = readline_default('Enter your choice (1, 2, 3, 4) [default: 1]: ', '1')
+ local sign_networks
+ local allow_mismatch
+ local sign_authenticated
+ local use_esld
+ local sign_domain = 'pet luacheck'
+
+ local defined_auth_types = { 'header', 'envelope', 'auth', 'recipient' }
+
+ if sign_type == '4' then
+ repeat
+ sign_networks = readline_default('Enter list of networks to perform dkim signing: ',
+ '')
+ until #sign_networks ~= 0
+
+ sign_networks = fun.totable(fun.map(lua_util.rspamd_str_trim,
+ lua_util.str_split(sign_networks, ',; ')))
+ printf('What domain would you like to use for signing?')
+ printf('* %s to use mime from domain', highlight('header'))
+ printf('* %s to use SMTP from domain', highlight('envelope'))
+ printf('* %s to use domain from SMTP auth', highlight('auth'))
+ printf('* %s to use domain from SMTP recipient', highlight('recipient'))
+ printf('* anything else to use as a %s domain (e.g. `example.com`)', highlight('static'))
+ printf()
+
+ sign_domain = readline_default('Enter your choice [default: header]: ', 'header')
+ else
+ if sign_type == '1' then
+ sign_domain = 'header'
+ elseif sign_type == '2' then
+ sign_domain = 'envelope'
+ else
+ sign_domain = 'auth'
+ end
+ end
+
+ if sign_type ~= '3' then
+ sign_authenticated = ask_yes_no(
+ string.format('Do you want to sign mail from %s? ',
+ highlight('authenticated users')), true)
+ else
+ sign_authenticated = true
+ end
+
+ if fun.any(function(s)
+ return s == sign_domain
+ end, defined_auth_types) then
+ -- Allow mismatch
+ allow_mismatch = ask_yes_no(
+ string.format('Allow data %s, e.g. if mime from domain is not equal to authenticated user domain? ',
+ highlight('mismatch')), true)
+ -- ESLD check
+ use_esld = ask_yes_no(
+ string.format('Do you want to use %s domain (e.g. example.com instead of foo.example.com)? ',
+ highlight('effective')), true)
+ else
+ allow_mismatch = true
+ end
+
+ local domains = {}
+ local has_domains = false
+
+ local dkim_keys_dir = rspamd_paths["DBDIR"] .. "/dkim/"
+
+ local prompt = string.format("Enter output directory for the keys [default: %s]: ",
+ highlight(dkim_keys_dir))
+ dkim_keys_dir = remove_trailing_slash(readline_default(prompt, dkim_keys_dir))
+
+ local ret, err = rspamd_util.mkdir(dkim_keys_dir, true)
+
+ if not ret then
+ printf("Cannot make directory %s: %s", dkim_keys_dir, highlight(err))
+ os.exit(1)
+ end
+
+ local function print_domains()
+ printf("Domains configured:")
+ for k, v in pairs(domains) do
+ printf("Domain: %s, selector: %s, privkey: %s", highlight(k),
+ v.selector, v.privkey)
+ end
+ printf("--")
+ end
+ local function print_public_key(pk)
+ local base64_pk = tostring(rspamd_util.encode_base64(pk))
+ printf('v=DKIM1; k=rsa; p=%s\n', base64_pk)
+ end
+ repeat
+ if has_domains then
+ print_domains()
+ end
+
+ local domain
+ repeat
+ domain = rspamd_util.readline("Enter domain to sign: ")
+ if not domain then
+ os.exit(1)
+ end
+ until #domain ~= 0
+
+ local selector = readline_default("Enter selector [default: dkim]: ", 'dkim')
+ if not selector then
+ selector = 'dkim'
+ end
+
+ local privkey_file = string.format("%s/%s.%s.key", dkim_keys_dir, domain,
+ selector)
+ if not rspamd_util.file_exists(privkey_file) then
+ if ask_yes_no("Do you want to create privkey " .. highlight(privkey_file),
+ true) then
+ local rsa = require "rspamd_rsa"
+ local sk, pk = rsa.keypair(2048)
+ sk:save(privkey_file, 'pem')
+ print("You need to chown private key file to rspamd user!!")
+ print("To make dkim signing working, to place the following record in your DNS zone:")
+ print_public_key(tostring(pk))
+ end
+ end
+
+ domains[domain] = {
+ selector = selector,
+ path = privkey_file,
+ }
+ until not ask_yes_no("Do you wish to add another DKIM domain?")
+
+ changes.l['dkim_signing.conf'] = { domain = domains }
+ local res_tbl = changes.l['dkim_signing.conf']
+
+ if sign_networks then
+ res_tbl.sign_networks = sign_networks
+ res_tbl.use_domain_sign_networks = sign_domain
+ else
+ res_tbl.use_domain = sign_domain
+ end
+
+ if allow_mismatch then
+ res_tbl.allow_hdrfrom_mismatch = true
+ res_tbl.allow_hdrfrom_mismatch_sign_networks = true
+ res_tbl.allow_username_mismatch = true
+ end
+
+ res_tbl.use_esld = use_esld
+ res_tbl.sign_authenticated = sign_authenticated
+end
+
+local function check_redis_classifier(cls, changes)
+ local symbol_spam, symbol_ham
+ -- Load symbols from statfiles
+ local statfiles = cls.statfile
+ for _, stf in ipairs(statfiles) do
+ local symbol = stf.symbol or 'undefined'
+
+ local spam
+ if stf.spam then
+ spam = stf.spam
+ else
+ if string.match(symbol:upper(), 'SPAM') then
+ spam = true
+ else
+ spam = false
+ end
+ end
+
+ if spam then
+ symbol_spam = symbol
+ else
+ symbol_ham = symbol
+ end
+ end
+
+ if not symbol_spam or not symbol_ham then
+ printf("Classifier has no symbols defined")
+ return
+ end
+
+ local parsed_redis = lua_redis.try_load_redis_servers(cls, nil)
+
+ if not parsed_redis and redis_params then
+ parsed_redis = lua_redis.try_load_redis_servers(redis_params, nil)
+ if not parsed_redis then
+ printf("Cannot parse Redis params")
+ return
+ end
+ end
+
+ local function try_convert(update_config)
+ if ask_yes_no("Do you wish to convert data to the new schema?", true) then
+ local expire = readline_expire()
+ if not lua_stat_tools.convert_bayes_schema(parsed_redis, symbol_spam,
+ symbol_ham, expire) then
+ printf("Conversion failed")
+ else
+ printf("Conversion succeed")
+ if update_config then
+ changes.l['classifier-bayes.conf'] = {
+ new_schema = true,
+ }
+
+ if expire then
+ changes.l['classifier-bayes.conf'].expire = expire
+ end
+ end
+ end
+ end
+ end
+
+ local function get_version(conn)
+ conn:add_cmd("SMEMBERS", { "RS_keys" })
+
+ local ret, members = conn:exec()
+
+ -- Empty db
+ if not ret or #members == 0 then
+ return false, 0
+ end
+
+ -- We still need to check versions
+ local lua_script = [[
+local ver = 0
+
+local tst = redis.call('GET', KEYS[1]..'_version')
+if tst then
+ ver = tonumber(tst) or 0
+end
+
+return ver
+]]
+ conn:add_cmd('EVAL', { lua_script, '1', 'RS' })
+ local _, ver = conn:exec()
+
+ return true, tonumber(ver)
+ end
+
+ local function check_expire(conn)
+ -- We still need to check versions
+ local lua_script = [[
+local ttl = 0
+
+local sc = redis.call('SCAN', 0, 'MATCH', 'RS*_*', 'COUNT', 1)
+local _,key = sc[1], sc[2]
+
+if key and key[1] then
+ ttl = redis.call('TTL', key[1])
+end
+
+return ttl
+]]
+ conn:add_cmd('EVAL', { lua_script, '0' })
+ local _, ttl = conn:exec()
+
+ return tonumber(ttl)
+ end
+
+ local res, conn = lua_redis.redis_connect_sync(parsed_redis, true)
+ if not res then
+ printf("Cannot connect to Redis server")
+ return false
+ end
+
+ if not cls.new_schema then
+ local r, ver = get_version(conn)
+ if not r then
+ return false
+ end
+ if ver ~= 2 then
+ if not ver then
+ printf('Key "RS_version" has not been found in Redis for %s/%s',
+ symbol_ham, symbol_spam)
+ else
+ printf("You are using an old schema version: %s for %s/%s",
+ ver, symbol_ham, symbol_spam)
+ end
+ try_convert(true)
+ else
+ printf("You have configured an old schema for %s/%s but your data has new layout",
+ symbol_ham, symbol_spam)
+
+ if ask_yes_no("Switch config to the new schema?", true) then
+ changes.l['classifier-bayes.conf'] = {
+ new_schema = true,
+ }
+
+ local expire = check_expire(conn)
+ if expire then
+ changes.l['classifier-bayes.conf'].expire = expire
+ end
+ end
+ end
+ else
+ local r, ver = get_version(conn)
+ if not r then
+ return false
+ end
+ if ver ~= 2 then
+ printf("You have configured new schema for %s/%s but your DB has old version: %s",
+ symbol_spam, symbol_ham, ver)
+ try_convert(false)
+ else
+ printf(
+ 'You have configured new schema for %s/%s and your DB already has new layout (v. %s).' ..
+ ' DB conversion is not needed.',
+ symbol_spam, symbol_ham, ver)
+ end
+ end
+end
+
+local function setup_statistic(cfg, changes)
+ local sqlite_configs = lua_stat_tools.load_sqlite_config(cfg)
+
+ if #sqlite_configs > 0 then
+
+ if not redis_params then
+ printf('You have %d sqlite classifiers, but you have no Redis servers being set',
+ #sqlite_configs)
+ return false
+ end
+
+ local parsed_redis = lua_redis.try_load_redis_servers(redis_params, nil)
+ if parsed_redis then
+ printf('You have %d sqlite classifiers', #sqlite_configs)
+ local expire = readline_expire()
+
+ local reset_previous = ask_yes_no("Reset previous data?")
+ if ask_yes_no('Do you wish to convert them to Redis?', true) then
+
+ for _, cls in ipairs(sqlite_configs) do
+ if rspamd_util.file_exists(cls.db_spam) and rspamd_util.file_exists(cls.db_ham) then
+ if not lua_stat_tools.convert_sqlite_to_redis(parsed_redis, cls.db_spam,
+ cls.db_ham, cls.symbol_spam, cls.symbol_ham, cls.learn_cache, expire,
+ reset_previous) then
+ rspamd_logger.errx('conversion failed')
+
+ return false
+ end
+ else
+ rspamd_logger.messagex('cannot find %s and %s, skip conversion',
+ cls.db_spam, cls.db_ham)
+ end
+
+ rspamd_logger.messagex('Converted classifier to the from sqlite to redis')
+ changes.l['classifier-bayes.conf'] = {
+ backend = 'redis',
+ new_schema = true,
+ }
+
+ if expire then
+ changes.l['classifier-bayes.conf'].expire = expire
+ end
+
+ if cls.learn_cache then
+ changes.l['classifier-bayes.conf'].cache = {
+ backend = 'redis'
+ }
+ end
+ end
+ end
+ end
+ else
+ -- Check sanity for the existing Redis classifiers
+ local classifier = cfg.classifier
+
+ if classifier then
+ if classifier[1] then
+ for _, cls in ipairs(classifier) do
+ if cls.bayes then
+ cls = cls.bayes
+ end
+ if cls.backend and cls.backend == 'redis' then
+ check_redis_classifier(cls, changes)
+ end
+ end
+ else
+ if classifier.bayes then
+
+ classifier = classifier.bayes
+ if classifier[1] then
+ for _, cls in ipairs(classifier) do
+ if cls.backend and cls.backend == 'redis' then
+ check_redis_classifier(cls, changes)
+ end
+ end
+ else
+ if classifier.backend and classifier.backend == 'redis' then
+ check_redis_classifier(classifier, changes)
+ end
+ end
+ end
+ end
+ end
+ end
+end
+
+local function find_worker(cfg, wtype)
+ if cfg.worker then
+ for k, s in pairs(cfg.worker) do
+ if type(k) == 'number' and type(s) == 'table' then
+ if s[wtype] then
+ return s[wtype]
+ end
+ end
+ if type(s) == 'table' and s.type and s.type == wtype then
+ return s
+ end
+ if type(k) == 'string' and k == wtype then
+ return s
+ end
+ end
+ end
+
+ return nil
+end
+
+return {
+ handler = function(cmd_args)
+ local changes = {
+ l = {}, -- local changes
+ o = {}, -- override changes
+ }
+
+ local interactive_start = true
+ local checks = {}
+ local all_checks = {
+ 'controller',
+ 'redis',
+ 'dkim',
+ 'statistic',
+ }
+
+ local opts = parser:parse(cmd_args)
+ local args = opts['checks'] or {}
+
+ local _r, err = rspamd_config:load_ucl(opts['config'])
+
+ if not _r then
+ rspamd_logger.errx('cannot parse %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
+ _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+ if not _r then
+ rspamd_logger.errx('cannot process %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
+ local cfg = rspamd_config:get_ucl()
+
+ if not rspamd_config:init_modules() then
+ rspamd_logger.errx('cannot init modules when parsing %s', opts['config'])
+ os.exit(1)
+ end
+
+ if #args > 0 then
+ interactive_start = false
+
+ for _, arg in ipairs(args) do
+ if arg == 'all' then
+ checks = all_checks
+ elseif arg == 'list' then
+ printf(highlight(rspamd_logo))
+ printf('Available modules')
+ for _, c in ipairs(all_checks) do
+ printf('- %s', c)
+ end
+ return
+ else
+ table.insert(checks, arg)
+ end
+ end
+ else
+ checks = all_checks
+ end
+
+ local function has_check(check)
+ for _, c in ipairs(checks) do
+ if c == check then
+ return true
+ end
+ end
+
+ return false
+ end
+
+ rspamd_util.umask('022')
+ if interactive_start then
+ printf(highlight(rspamd_logo))
+ printf("Welcome to the configuration tool")
+ printf("We use %s configuration file, writing results to %s",
+ highlight(opts['config']), highlight(local_conf))
+ plugins_stat(nil, nil)
+ end
+
+ if not interactive_start or
+ ask_yes_no("Do you wish to continue?", true) then
+
+ if has_check('controller') then
+ local controller = find_worker(cfg, 'controller')
+ if controller then
+ setup_controller(controller, changes)
+ end
+ end
+
+ if has_check('redis') then
+ if not cfg.redis or (not cfg.redis.servers and not cfg.redis.read_servers) then
+ setup_redis(cfg, changes)
+ else
+ redis_params = cfg.redis
+ end
+ else
+ redis_params = cfg.redis
+ end
+
+ if has_check('dkim') then
+ if cfg.dkim_signing and not cfg.dkim_signing.domain then
+ if ask_yes_no('Do you want to setup dkim signing feature?') then
+ setup_dkim_signing(cfg, changes)
+ end
+ end
+ end
+
+ if has_check('statistic') or has_check('statistics') then
+ setup_statistic(cfg, changes)
+ end
+
+ local nchanges = 0
+ for _, _ in pairs(changes.l) do
+ nchanges = nchanges + 1
+ end
+ for _, _ in pairs(changes.o) do
+ nchanges = nchanges + 1
+ end
+
+ if nchanges > 0 then
+ print_changes(changes)
+ if ask_yes_no("Apply changes?", true) then
+ apply_changes(changes)
+ printf("%d changes applied, the wizard is finished now", nchanges)
+ printf("*** Please reload the Rspamd configuration ***")
+ else
+ printf("No changes applied, the wizard is finished now")
+ end
+ else
+ printf("No changes found, the wizard is finished now")
+ end
+ end
+ end,
+ name = 'configwizard',
+ description = parser._description,
+}
diff --git a/lualib/rspamadm/cookie.lua b/lualib/rspamadm/cookie.lua
new file mode 100644
index 0000000..7e0526a
--- /dev/null
+++ b/lualib/rspamadm/cookie.lua
@@ -0,0 +1,125 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+
+
+-- Define command line options
+local parser = argparse()
+ :name "rspamadm cookie"
+ :description "Produces cookies or message ids"
+ :help_description_margin(30)
+
+parser:mutex(
+ parser:option "-k --key"
+ :description('Key to load')
+ :argname "<32hex>",
+ parser:flag "-K --new-key"
+ :description('Generates a new key')
+)
+
+parser:option "-d --domain"
+ :description('Use specified domain and generate full message id')
+ :argname "<domain>"
+parser:flag "-D --decrypt"
+ :description('Decrypt cookie instead of encrypting one')
+parser:flag "-t --timestamp"
+ :description('Show cookie timestamp (valid for decrypting only)')
+parser:argument "cookie":args "?"
+ :description('Use specified cookie')
+
+local function gen_cookie(args, key)
+ local cr = require "rspamd_cryptobox"
+
+ if not args.cookie then
+ return
+ end
+
+ local function encrypt()
+ if #args.cookie > 31 then
+ print('cookie too long (>31 characters), cannot encrypt')
+ os.exit(1)
+ end
+
+ local enc_cookie = cr.encrypt_cookie(key, args.cookie)
+ if args.domain then
+ print(string.format('<%s@%s>', enc_cookie, args.domain))
+ else
+ print(enc_cookie)
+ end
+ end
+
+ local function decrypt()
+ local extracted_cookie = args.cookie:match('^%<?([^@]+)@.*$')
+ if not extracted_cookie then
+ -- Assume full message id as a cookie
+ extracted_cookie = args.cookie
+ end
+
+ local dec_cookie, ts = cr.decrypt_cookie(key, extracted_cookie)
+
+ if dec_cookie then
+ if args.timestamp then
+ print(string.format('%s %s', dec_cookie, ts))
+ else
+ print(dec_cookie)
+ end
+ else
+ print('cannot decrypt cookie')
+ os.exit(1)
+ end
+ end
+
+ if args.decrypt then
+ decrypt()
+ else
+ encrypt()
+ end
+end
+
+local function handler(args)
+ local res = parser:parse(args)
+
+ if not (res.key or res['new_key']) then
+ parser:error('--key or --new-key must be specified')
+ end
+
+ if res.key then
+ local pattern = { '^' }
+ for i = 1, 32 do
+ pattern[i + 1] = '[a-zA-Z0-9]'
+ end
+ pattern[34] = '$'
+
+ if not res.key:match(table.concat(pattern, '')) then
+ parser:error('invalid key: ' .. res.key)
+ end
+
+ gen_cookie(res, res.key)
+ else
+ local util = require "rspamd_util"
+ local key = util.random_hex(32)
+
+ print(key)
+ gen_cookie(res, res.key)
+ end
+end
+
+return {
+ handler = handler,
+ description = parser._description,
+ name = 'cookie'
+} \ No newline at end of file
diff --git a/lualib/rspamadm/corpus_test.lua b/lualib/rspamadm/corpus_test.lua
new file mode 100644
index 0000000..0e63f9f
--- /dev/null
+++ b/lualib/rspamadm/corpus_test.lua
@@ -0,0 +1,185 @@
+local rspamd_logger = require "rspamd_logger"
+local ucl = require "ucl"
+local lua_util = require "lua_util"
+local argparse = require "argparse"
+
+local parser = argparse()
+ :name "rspamadm corpus_test"
+ :description "Create logs files from email corpus"
+ :help_description_margin(32)
+
+parser:option "-H --ham"
+ :description("Ham directory")
+ :argname("<dir>")
+parser:option "-S --spam"
+ :description("Spam directory")
+ :argname("<dir>")
+parser:option "-n --conns"
+ :description("Number of parallel connections")
+ :argname("<N>")
+ :convert(tonumber)
+ :default(10)
+parser:option "-o --output"
+ :description("Output file")
+ :argname("<file>")
+ :default('results.log')
+parser:option "-t --timeout"
+ :description("Timeout for client connections")
+ :argname("<sec>")
+ :convert(tonumber)
+ :default(60)
+parser:option "-c --connect"
+ :description("Connect to specific host")
+ :argname("<host>")
+ :default('localhost:11334')
+parser:option "-r --rspamc"
+ :description("Use specific rspamc path")
+ :argname("<path>")
+ :default('rspamc')
+
+local HAM = "HAM"
+local SPAM = "SPAM"
+local opts
+
+local function scan_email(n_parallel, path, timeout)
+
+ local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s",
+ opts.rspamc, opts.connect, n_parallel, timeout, path)
+ local result = assert(io.popen(rspamc_command))
+ result = result:read("*all")
+ return result
+end
+
+local function write_results(results, file)
+
+ local f = io.open(file, 'w')
+
+ for _, result in pairs(results) do
+ local log_line = string.format("%s %.2f %s",
+ result.type, result.score, result.action)
+
+ for _, sym in pairs(result.symbols) do
+ log_line = log_line .. " " .. sym
+ end
+
+ log_line = log_line .. " " .. result.scan_time .. " " .. file .. ':' .. result.filename
+
+ log_line = log_line .. "\r\n"
+
+ f:write(log_line)
+ end
+
+ f:close()
+end
+
+local function encoded_json_to_log(result)
+ -- Returns table containing score, action, list of symbols
+
+ local filtered_result = {}
+ local ucl_parser = ucl.parser()
+
+ local is_good, err = ucl_parser:parse_string(result)
+
+ if not is_good then
+ rspamd_logger.errx("Parser error: %1", err)
+ return nil
+ end
+
+ result = ucl_parser:get_object()
+
+ filtered_result.score = result.score
+ if not result.action then
+ rspamd_logger.errx("Bad JSON: %1", result)
+ return nil
+ end
+ local action = result.action:gsub("%s+", "_")
+ filtered_result.action = action
+
+ filtered_result.symbols = {}
+
+ for sym, _ in pairs(result.symbols) do
+ table.insert(filtered_result.symbols, sym)
+ end
+
+ filtered_result.filename = result.filename
+ filtered_result.scan_time = result.scan_time
+
+ return filtered_result
+end
+
+local function scan_results_to_logs(results, actual_email_type)
+
+ local logs = {}
+
+ results = lua_util.rspamd_str_split(results, "\n")
+
+ if results[#results] == "" then
+ results[#results] = nil
+ end
+
+ for _, result in pairs(results) do
+ result = encoded_json_to_log(result)
+ if result then
+ result['type'] = actual_email_type
+ table.insert(logs, result)
+ end
+ end
+
+ return logs
+end
+
+local function handler(args)
+ opts = parser:parse(args)
+ local ham_directory = opts['ham']
+ local spam_directory = opts['spam']
+ local connections = opts["conns"]
+ local output = opts["output"]
+
+ local results = {}
+
+ local start_time = os.time()
+ local no_of_ham = 0
+ local no_of_spam = 0
+
+ if ham_directory then
+ rspamd_logger.messagex("Scanning ham corpus...")
+ local ham_results = scan_email(connections, ham_directory, opts["timeout"])
+ ham_results = scan_results_to_logs(ham_results, HAM)
+
+ no_of_ham = #ham_results
+
+ for _, result in pairs(ham_results) do
+ table.insert(results, result)
+ end
+ end
+
+ if spam_directory then
+ rspamd_logger.messagex("Scanning spam corpus...")
+ local spam_results = scan_email(connections, spam_directory, opts.timeout)
+ spam_results = scan_results_to_logs(spam_results, SPAM)
+
+ no_of_spam = #spam_results
+
+ for _, result in pairs(spam_results) do
+ table.insert(results, result)
+ end
+ end
+
+ rspamd_logger.messagex("Writing results to %s", output)
+ write_results(results, output)
+
+ rspamd_logger.messagex("Stats: ")
+ local elapsed_time = os.time() - start_time
+ local total_msgs = no_of_ham + no_of_spam
+ rspamd_logger.messagex("Elapsed time: %ss", elapsed_time)
+ rspamd_logger.messagex("No of ham: %s", no_of_ham)
+ rspamd_logger.messagex("No of spam: %s", no_of_spam)
+ rspamd_logger.messagex("Messages/sec: %s", (total_msgs / elapsed_time))
+end
+
+return {
+ name = 'corpustest',
+ aliases = { 'corpus_test', 'corpus' },
+ handler = handler,
+ description = parser._description
+}
diff --git a/lualib/rspamadm/dkim_keygen.lua b/lualib/rspamadm/dkim_keygen.lua
new file mode 100644
index 0000000..211094d
--- /dev/null
+++ b/lualib/rspamadm/dkim_keygen.lua
@@ -0,0 +1,178 @@
+--[[
+Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+local rspamd_util = require "rspamd_util"
+local rspamd_cryptobox = require "rspamd_cryptobox"
+
+local parser = argparse()
+ :name 'rspamadm dkim_keygen'
+ :description 'Create key pairs for dkim signing'
+ :help_description_margin(30)
+parser:option '-d --domain'
+ :description 'Create a key for a specific domain'
+ :default "example.com"
+parser:option '-s --selector'
+ :description 'Create a key for a specific DKIM selector'
+ :default "mail"
+parser:option '-k --privkey'
+ :description 'Save private key to file instead of printing it to stdout'
+parser:option '-b --bits'
+ :convert(function(input)
+ local n = tonumber(input)
+ if not n or n < 512 or n > 4096 then
+ return nil
+ end
+ return n
+end)
+ :description 'Generate an RSA key with the specified number of bits (512 to 4096)'
+ :default(1024)
+parser:option '-t --type'
+ :description 'Key type: RSA, ED25519 or ED25119-seed'
+ :convert {
+ ['rsa'] = 'rsa',
+ ['RSA'] = 'rsa',
+ ['ed25519'] = 'ed25519',
+ ['ED25519'] = 'ed25519',
+ ['ed25519-seed'] = 'ed25519-seed',
+ ['ED25519-seed'] = 'ed25519-seed',
+}
+ :default 'rsa'
+parser:option '-o --output'
+ :description 'Output public key in the following format: dns, dnskey or plain'
+ :convert {
+ ['dns'] = 'dns',
+ ['plain'] = 'plain',
+ ['dnskey'] = 'dnskey',
+}
+ :default 'dns'
+parser:option '--priv-output'
+ :description 'Output private key in the following format: PEM or DER (for RSA)'
+ :convert {
+ ['pem'] = 'pem',
+ ['der'] = 'der',
+}
+ :default 'pem'
+parser:flag '-f --force'
+ :description 'Force overwrite of existing files'
+
+local function split_string(input, max_length)
+ max_length = max_length or 253
+ local pieces = {}
+ local index = 1
+
+ while index <= #input do
+ local piece = input:sub(index, index + max_length - 1)
+ table.insert(pieces, piece)
+ index = index + max_length
+ end
+
+ return pieces
+end
+
+local function print_public_key_dns(opts, base64_pk)
+ local key_type = opts.type == 'rsa' and 'rsa' or 'ed25519'
+ if #base64_pk < 255 - 2 then
+ io.write(string.format('%s._domainkey IN TXT ( "v=DKIM1; k=%s;" \n\t"p=%s" ) ;\n',
+ opts.selector, key_type, base64_pk))
+ else
+ -- Split it by parts
+ local parts = split_string(base64_pk)
+ io.write(string.format('%s._domainkey IN TXT ( "v=DKIM1; k=%s; "\n', opts.selector, key_type))
+ for i, part in ipairs(parts) do
+ if i == 1 then
+ io.write(string.format('\t"p=%s"\n', part))
+ else
+ io.write(string.format('\t"%s"\n', part))
+ end
+ end
+ io.write(") ; \n")
+ end
+
+end
+
+local function print_public_key(opts, pk, need_base64)
+ local key_type = opts.type == 'rsa' and 'rsa' or 'ed25519'
+ local base64_pk = need_base64 and tostring(rspamd_util.encode_base64(pk)) or tostring(pk)
+ if opts.output == 'plain' then
+ io.write(base64_pk)
+ io.write("\n")
+ elseif opts.output == 'dns' then
+ print_public_key_dns(opts, base64_pk, false)
+ elseif opts.output == 'dnskey' then
+ io.write(string.format('v=DKIM1; k=%s; p=%s\n', key_type, base64_pk))
+ end
+end
+
+local function gen_rsa_key(opts)
+ local rsa = require "rspamd_rsa"
+
+ local sk, pk = rsa.keypair(opts.bits or 1024)
+ if opts.privkey then
+ if opts.force then
+ os.remove(opts.privkey)
+ end
+ sk:save(opts.privkey, opts.priv_output)
+ else
+ sk:save("-", opts.priv_output)
+ end
+
+ -- We generate key directly via lua_rsa and it returns unencoded raw data
+ print_public_key(opts, tostring(pk), true)
+end
+
+local function gen_eddsa_key(opts)
+ local sk, pk = rspamd_cryptobox.gen_dkim_keypair(opts.type)
+
+ if opts.privkey and opts.force then
+ os.remove(opts.privkey)
+ end
+ if not sk:save_in_file(opts.privkey, tonumber('0600', 8)) then
+ io.stderr:write('cannot save private key to ' .. (opts.privkey or 'stdout') .. '\n')
+ os.exit(1)
+ end
+
+ if not opts.privkey then
+ io.write("\n")
+ io.flush()
+ end
+
+ -- gen_dkim_keypair function returns everything encoded in base64, so no need to do anything
+ print_public_key(opts, tostring(pk), false)
+end
+
+local function handler(args)
+ local opts = parser:parse(args)
+
+ if not opts then
+ os.exit(1)
+ end
+
+ if opts.type == 'rsa' then
+ gen_rsa_key(opts)
+ else
+ gen_eddsa_key(opts)
+ end
+end
+
+return {
+ name = 'dkim_keygen',
+ aliases = { 'dkimkeygen' },
+ handler = handler,
+ description = parser._description
+}
+
+
diff --git a/lualib/rspamadm/dmarc_report.lua b/lualib/rspamadm/dmarc_report.lua
new file mode 100644
index 0000000..42c801e
--- /dev/null
+++ b/lualib/rspamadm/dmarc_report.lua
@@ -0,0 +1,737 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+local lua_util = require "lua_util"
+local logger = require "rspamd_logger"
+local lua_redis = require "lua_redis"
+local dmarc_common = require "plugins/dmarc"
+local lupa = require "lupa"
+local rspamd_mempool = require "rspamd_mempool"
+local rspamd_url = require "rspamd_url"
+local rspamd_text = require "rspamd_text"
+local rspamd_util = require "rspamd_util"
+local rspamd_dns = require "rspamd_dns"
+
+local N = 'dmarc_report'
+
+-- Define command line options
+local parser = argparse()
+ :name "rspamadm dmarc_report"
+ :description "Dmarc reports sending tool"
+ :help_description_margin(30)
+
+parser:option "-c --config"
+ :description "Path to config file"
+ :argname("<cfg>")
+ :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+
+parser:flag "-v --verbose"
+ :description "Enable dmarc specific logging"
+
+parser:flag "-n --no-opt"
+ :description "Do not reset reporting data/send reports"
+
+parser:argument "date"
+ :description "Date to process (today by default)"
+ :argname "<YYYYMMDD>"
+ :args "*"
+parser:option "-b --batch-size"
+ :description "Send reports in batches up to <batch-size> messages"
+ :argname "<number>"
+ :convert(tonumber)
+ :default "10"
+
+local report_template = [[From: "{= from_name =}" <{= from_addr =}>
+To: {= rcpt =}
+{%+ if is_string(bcc) %}Bcc: {= bcc =}{%- endif %}
+Subject: Report Domain: {= reporting_domain =}
+ Submitter: {= submitter =}
+ Report-ID: {= report_id =}
+Date: {= report_date =}
+MIME-Version: 1.0
+Message-ID: <{= message_id =}>
+Content-Type: multipart/mixed;
+ boundary="----=_NextPart_{= uuid =}"
+
+This is a multipart message in MIME format.
+
+------=_NextPart_{= uuid =}
+Content-Type: text/plain; charset="us-ascii"
+Content-Transfer-Encoding: 7bit
+
+This is an aggregate report from {= submitter =}.
+
+Report domain: {= reporting_domain =}
+Submitter: {= submitter =}
+Report ID: {= report_id =}
+
+------=_NextPart_{= uuid =}
+Content-Type: application/gzip
+Content-Transfer-Encoding: base64
+Content-Disposition: attachment;
+ filename="{= submitter =}!{= reporting_domain =}!{= report_start =}!{= report_end =}.xml.gz"
+
+]]
+local report_footer = [[
+
+------=_NextPart_{= uuid =}--]]
+
+local dmarc_settings = {}
+local redis_params
+local redis_attrs = {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ log_obj = rspamd_config,
+ resolver = rspamadm_dns_resolver,
+}
+local pool
+
+local function load_config(opts)
+ local _r, err = rspamd_config:load_ucl(opts['config'])
+
+ if not _r then
+ logger.errx('cannot parse %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
+ _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+ if not _r then
+ logger.errx('cannot process %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+end
+
+-- Concat elements using redis_keys.join_char
+local function redis_prefix(...)
+ return table.concat({ ... }, dmarc_settings.reporting.redis_keys.join_char)
+end
+
+local function get_rua(rep_key)
+ local parts = lua_util.str_split(rep_key, dmarc_settings.reporting.redis_keys.join_char)
+
+ if #parts >= 3 then
+ return parts[3]
+ end
+
+ return nil
+end
+
+local function get_domain(rep_key)
+ local parts = lua_util.str_split(rep_key, dmarc_settings.reporting.redis_keys.join_char)
+
+ if #parts >= 3 then
+ return parts[2]
+ end
+
+ return nil
+end
+
+local function gen_uuid()
+ local template = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'
+ return string.gsub(template, '[xy]', function(c)
+ local v = (c == 'x') and math.random(0, 0xf) or math.random(8, 0xb)
+ return string.format('%x', v)
+ end)
+end
+
+local function gen_xml_grammar()
+ local lpeg = require 'lpeg'
+ local lt = lpeg.P('<') / '&lt;'
+ local gt = lpeg.P('>') / '&gt;'
+ local amp = lpeg.P('&') / '&amp;'
+ local quot = lpeg.P('"') / '&quot;'
+ local apos = lpeg.P("'") / '&apos;'
+ local special = lt + gt + amp + quot + apos
+ local grammar = lpeg.Cs((special + 1) ^ 0)
+ return grammar
+end
+
+local xml_grammar = gen_xml_grammar()
+
+local function escape_xml(input)
+ if type(input) == 'string' or type(input) == 'userdata' then
+ return xml_grammar:match(input)
+ else
+ input = tostring(input)
+
+ if input then
+ return xml_grammar:match(input)
+ end
+ end
+
+ return ''
+end
+-- Enable xml escaping in lupa templates
+lupa.filters.escape_xml = escape_xml
+
+-- Creates report XML header
+local function report_header(reporting_domain, report_start, report_end, domain_policy)
+ local report_id = string.format('%s.%d.%d',
+ reporting_domain, report_start, report_end)
+ local xml_template = [[
+<?xml version="1.0" encoding="UTF-8" ?>
+<feedback>
+ <report_metadata>
+ <org_name>{= report_settings.org_name | escape_xml =}</org_name>
+ <email>{= report_settings.email | escape_xml =}</email>
+ <report_id>{= report_id =}</report_id>
+ <date_range>
+ <begin>{= report_start =}</begin>
+ <end>{= report_end =}</end>
+ </date_range>
+ </report_metadata>
+ <policy_published>
+ <domain>{= reporting_domain | escape_xml =}</domain>
+ <adkim>{= domain_policy.adkim | escape_xml =}</adkim>
+ <aspf>{= domain_policy.aspf | escape_xml =}</aspf>
+ <p>{= domain_policy.p | escape_xml =}</p>
+ <sp>{= domain_policy.sp | escape_xml =}</sp>
+ <pct>{= domain_policy.pct | escape_xml =}</pct>
+ </policy_published>
+]]
+ return lua_util.jinja_template(xml_template, {
+ report_settings = dmarc_settings.reporting,
+ report_id = report_id,
+ report_start = report_start,
+ report_end = report_end,
+ domain_policy = domain_policy,
+ reporting_domain = reporting_domain,
+ }, true)
+end
+
+-- Generate xml entry for a preprocessed redis row
+local function entry_to_xml(data)
+ local xml_template = [[<record>
+ <row>
+ <source_ip>{= data.ip =}</source_ip>
+ <count>{= data.count =}</count>
+ <policy_evaluated>
+ <disposition>{= data.disposition =}</disposition>
+ <dkim>{= data.dkim_disposition =}</dkim>
+ <spf>{= data.spf_disposition =}</spf>
+ {% if data.override and data.override ~= '' -%}
+ <reason><type>{= data.override =}</type></reason>
+ {%- endif %}
+ </policy_evaluated>
+ </row>
+ <identifiers>
+ <header_from>{= data.header_from =}</header_from>
+ </identifiers>
+ <auth_results>
+ {% if data.dkim_results[1] -%}
+ {% for d in data.dkim_results -%}
+ <dkim>
+ <domain>{= d.domain =}</domain>
+ <result>{= d.result =}</result>
+ </dkim>
+ {%- endfor %}
+ {%- endif %}
+ <spf>
+ <domain>{= data.spf_domain =}</domain>
+ <result>{= data.spf_result =}</result>
+ </spf>
+ </auth_results>
+</record>
+]]
+ return lua_util.jinja_template(xml_template, { data = data }, true)
+end
+
+-- Process a report entry stored in Redis splitting it to a lua table
+local function process_report_entry(data, score)
+ local split = lua_util.str_split(data, ',')
+ local row = {
+ ip = split[1],
+ spf_disposition = split[2],
+ dkim_disposition = split[3],
+ disposition = split[4],
+ override = split[5],
+ header_from = split[6],
+ dkim_results = {},
+ spf_domain = split[11],
+ spf_result = split[12],
+ count = tonumber(score),
+ }
+ -- Process dkim entries
+ local function dkim_entries_process(dkim_data, result)
+ if dkim_data and dkim_data ~= '' then
+ local dkim_elts = lua_util.str_split(dkim_data, '|')
+ for _, d in ipairs(dkim_elts) do
+ table.insert(row.dkim_results, { domain = d, result = result })
+ end
+ end
+ end
+ dkim_entries_process(split[7], 'pass')
+ dkim_entries_process(split[8], 'fail')
+ dkim_entries_process(split[9], 'temperror')
+ dkim_entries_process(split[9], 'permerror')
+
+ return row
+end
+
+-- Process a single rua entry, validating in DNS if needed
+local function process_rua(dmarc_domain, rua)
+ local parts = lua_util.str_split(rua, ',')
+
+ -- Remove size limitation, as we don't care about them
+ local addrs = {}
+ for _, rua_part in ipairs(parts) do
+ local u = rspamd_url.create(pool, rua_part:gsub('!%d+[kmg]?$', ''))
+ local u2 = rspamd_url.create(pool, dmarc_domain)
+ if u and (u:get_protocol() or '') == 'mailto' and u:get_user() then
+ -- Check each address for sanity
+ if u:get_tld() == u2:get_tld() then
+ -- Same eSLD - always include
+ table.insert(addrs, u)
+ else
+ -- We need to check authority
+ local resolve_str = string.format('%s._report._dmarc.%s',
+ dmarc_domain, u:get_host())
+ local is_ok, results = rspamd_dns.request({
+ config = rspamd_config,
+ session = rspamadm_session,
+ type = 'txt',
+ name = resolve_str,
+ })
+
+ if not is_ok then
+ logger.errx('cannot resolve %s: %s; exclude %s', resolve_str, results, rua_part)
+ else
+ local found = false
+ for _, t in ipairs(results) do
+ if string.match(t, 'v=DMARC1') then
+ found = true
+ break
+ end
+ end
+
+ if not found then
+ logger.errx('%s is not authorized to process reports on %s', dmarc_domain, u:get_host())
+ else
+ -- All good
+ table.insert(addrs, u)
+ end
+ end
+ end
+ else
+ logger.errx('invalid rua url: "%s""', tostring(u or 'null'))
+ end
+ end
+
+ if #addrs > 0 then
+ return addrs
+ end
+
+ return nil
+end
+
+-- Validate reporting domain, extracting rua and checking 3rd party report domains
+-- This function returns a full dmarc record processed + rua as a list of url objects
+local function validate_reporting_domain(reporting_domain)
+ local is_ok, results = rspamd_dns.request({
+ config = rspamd_config,
+ session = rspamadm_session,
+ type = 'txt',
+ name = '_dmarc.' .. reporting_domain,
+ })
+
+ if not is_ok or not results then
+ logger.errx('cannot resolve _dmarc.%s: %s', reporting_domain, results)
+ return nil
+ end
+
+ for _, r in ipairs(results) do
+ local processed, rec = dmarc_common.dmarc_check_record(rspamd_config, r, false)
+ if processed and rec.rua then
+ -- We need to check or alter rua if needed
+ local processed_rua = process_rua(reporting_domain, rec.rua)
+ if processed_rua then
+ rec = rec.raw_elts
+ rec.rua = processed_rua
+
+ -- Fill defaults in a record to avoid nils in a report
+ rec['pct'] = rec['pct'] or 100
+ rec['adkim'] = rec['adkim'] or 'r'
+ rec['aspf'] = rec['aspf'] or 'r'
+ rec['p'] = rec['p'] or 'none'
+ rec['sp'] = rec['sp'] or 'none'
+ return rec
+ end
+ return nil
+ end
+ end
+
+ return nil
+end
+
+-- Returns a list of recipients from a table as a string processing elements if needed
+local function rcpt_list(tbl, func)
+ local res = {}
+ for _, r in ipairs(tbl) do
+ if func then
+ table.insert(res, func(r))
+ else
+ table.insert(res, r)
+ end
+ end
+
+ return table.concat(res, ',')
+end
+
+-- Synchronous smtp send function
+local function send_reports_by_smtp(opts, reports, finish_cb)
+ local lua_smtp = require "lua_smtp"
+ local reports_failed = 0
+ local reports_sent = 0
+ local report_settings = dmarc_settings.reporting
+
+ local function gen_sendmail_cb(report, args)
+ return function(ret, err)
+ -- We modify this from all callbacks
+ args.nreports = args.nreports - 1
+ if not ret then
+ logger.errx("Couldn't send mail for %s: %s", report.reporting_domain, err)
+ reports_failed = reports_failed + 1
+ else
+ reports_sent = reports_sent + 1
+ lua_util.debugm(N, 'successfully sent a report for %s: %s bytes sent',
+ report.reporting_domain, #report.message)
+ end
+
+ -- Tail call to the next batch or to the final function
+ if args.nreports == 0 then
+ if args.next_start > #reports then
+ finish_cb(reports_sent, reports_failed)
+ else
+ args.cont_func(args.next_start)
+ end
+ end
+ end
+ end
+
+ local function send_data_in_batches(cur_batch)
+ local nreports = math.min(#reports - cur_batch + 1, opts.batch_size)
+ local next_start = cur_batch + nreports
+ lua_util.debugm(N, 'send data for %s domains (from %s to %s)',
+ nreports, cur_batch, next_start - 1)
+ -- Shared across all closures
+ local gen_args = {
+ cont_func = send_data_in_batches,
+ nreports = nreports,
+ next_start = next_start
+ }
+ for i = cur_batch, next_start - 1 do
+ local report = reports[i]
+ local send_opts = {
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ config = rspamd_config,
+ host = report_settings.smtp,
+ port = report_settings.smtp_port or 25,
+ resolver = rspamadm_dns_resolver,
+ from = report_settings.email,
+ recipients = report.rcpts,
+ helo = report_settings.helo or 'rspamd.localhost',
+ }
+
+ lua_smtp.sendmail(send_opts,
+ report.message,
+ gen_sendmail_cb(report, gen_args))
+ end
+ end
+
+ send_data_in_batches(1)
+end
+
+local function prepare_report(opts, start_time, end_time, rep_key)
+ local rua = get_rua(rep_key)
+ local reporting_domain = get_domain(rep_key)
+
+ if not rua then
+ logger.errx('report %s has no valid rua, skip it', rep_key)
+ return nil
+ end
+ if not reporting_domain then
+ logger.errx('report %s has no valid reporting_domain, skip it', rep_key)
+ return nil
+ end
+
+ local ret, results = lua_redis.request(redis_params, redis_attrs,
+ { 'EXISTS', rep_key })
+
+ if not ret or not results or results == 0 then
+ return nil
+ end
+
+ -- Rename report key to avoid races
+ if not opts.no_opt then
+ lua_redis.request(redis_params, redis_attrs,
+ { 'RENAME', rep_key, rep_key .. '_processing' })
+ rep_key = rep_key .. '_processing'
+ end
+
+ local dmarc_record = validate_reporting_domain(reporting_domain)
+ lua_util.debugm(N, 'process reporting domain %s: %s', reporting_domain, dmarc_record)
+
+ if not dmarc_record then
+ if not opts.no_opt then
+ lua_redis.request(redis_params, redis_attrs,
+ { 'DEL', rep_key })
+ end
+ logger.messagex('Cannot process reports for domain %s; invalid dmarc record', reporting_domain)
+ return nil
+ end
+
+ -- Get all reports for a domain
+ ret, results = lua_redis.request(redis_params, redis_attrs,
+ { 'ZRANGE', rep_key, '0', '-1', 'WITHSCORES' })
+ local report_entries = {}
+ table.insert(report_entries,
+ report_header(reporting_domain, start_time, end_time, dmarc_record))
+ for i = 1, #results, 2 do
+ local xml_record = entry_to_xml(process_report_entry(results[i], results[i + 1]))
+ table.insert(report_entries, xml_record)
+ end
+ table.insert(report_entries, '</feedback>')
+ local xml_to_compress = rspamd_text.fromtable(report_entries)
+ lua_util.debugm(N, 'got xml: %s', xml_to_compress)
+
+ -- Prepare SMTP message
+ local report_settings = dmarc_settings.reporting
+ local rcpt_string = rcpt_list(dmarc_record.rua, function(rua_elt)
+ return string.format('%s@%s', rua_elt:get_user(), rua_elt:get_host())
+ end)
+ local bcc_string
+ if report_settings.bcc_addrs then
+ bcc_string = rcpt_list(report_settings.bcc_addrs)
+ end
+ local uuid = gen_uuid()
+ local rhead = lua_util.jinja_template(report_template, {
+ from_name = report_settings.from_name,
+ from_addr = report_settings.email,
+ rcpt = rcpt_string,
+ bcc = bcc_string,
+ uuid = uuid,
+ reporting_domain = reporting_domain,
+ submitter = report_settings.domain,
+ report_id = string.format('%s.%d.%d', reporting_domain, start_time,
+ end_time),
+ report_date = rspamd_util.time_to_string(rspamd_util.get_time()),
+ message_id = rspamd_util.random_hex(16) .. '@' .. report_settings.msgid_from,
+ report_start = start_time,
+ report_end = end_time
+ }, true)
+ local rfooter = lua_util.jinja_template(report_footer, {
+ uuid = uuid,
+ }, true)
+ local message = rspamd_text.fromtable {
+ (rhead:gsub("\n", "\r\n")),
+ rspamd_util.encode_base64(rspamd_util.gzip_compress(xml_to_compress), 73),
+ rfooter:gsub("\n", "\r\n"),
+ }
+
+ lua_util.debugm(N, 'got final message: %s', message)
+
+ if not opts.no_opt then
+ lua_redis.request(redis_params, redis_attrs,
+ { 'DEL', rep_key })
+ end
+
+ local report_rcpts = lua_util.str_split(rcpt_string, ',')
+
+ if report_settings.bcc_addrs then
+ for _, b in ipairs(report_settings.bcc_addrs) do
+ table.insert(report_rcpts, b)
+ end
+ end
+
+ return {
+ message = message,
+ rcpts = report_rcpts,
+ reporting_domain = reporting_domain
+ }
+end
+
+local function process_report_date(opts, start_time, end_time, date)
+ local idx_key = redis_prefix(dmarc_settings.reporting.redis_keys.index_prefix, date)
+ local ret, results = lua_redis.request(redis_params, redis_attrs,
+ { 'EXISTS', idx_key })
+
+ if not ret or not results or results == 0 then
+ logger.messagex('No reports for %s', date)
+ return {}
+ end
+
+ -- Rename index key to avoid races
+ if not opts.no_opt then
+ lua_redis.request(redis_params, redis_attrs,
+ { 'RENAME', idx_key, idx_key .. '_processing' })
+ idx_key = idx_key .. '_processing'
+ end
+ ret, results = lua_redis.request(redis_params, redis_attrs,
+ { 'SMEMBERS', idx_key })
+
+ if not ret or not results then
+ -- Remove bad key
+ if not opts.no_opt then
+ lua_redis.request(redis_params, redis_attrs,
+ { 'DEL', idx_key })
+ end
+ logger.messagex('Cannot get reports for %s', date)
+ return {}
+ end
+
+ local reports = {}
+ for _, rep in ipairs(results) do
+ local report = prepare_report(opts, start_time, end_time, rep)
+
+ if report then
+ table.insert(reports, report)
+ end
+ end
+
+ -- Shuffle reports to make sending more fair
+ lua_util.shuffle(reports)
+ -- Remove processed key
+ if not opts.no_opt then
+ lua_redis.request(redis_params, redis_attrs,
+ { 'DEL', idx_key })
+ end
+
+ return reports
+end
+
+
+-- Returns a day before today at 00:00 as unix seconds
+local function yesterday_midnight()
+ local piecewise_time = os.date("*t")
+ piecewise_time.day = piecewise_time.day - 1 -- Lua allows negative values here
+ piecewise_time.hour = 0
+ piecewise_time.sec = 0
+ piecewise_time.min = 0
+ return os.time(piecewise_time)
+end
+
+-- Returns today time at 00:00 as unix seconds
+local function today_midnight()
+ local piecewise_time = os.date("*t")
+ piecewise_time.hour = 0
+ piecewise_time.sec = 0
+ piecewise_time.min = 0
+ return os.time(piecewise_time)
+end
+
+local function handler(args)
+ local start_time
+ -- Preserve start time as report sending might take some time
+ local start_collection = today_midnight()
+
+ local opts = parser:parse(args)
+
+ pool = rspamd_mempool.create()
+ load_config(opts)
+ rspamd_url.init(rspamd_config:get_tld_path())
+
+ if opts.verbose then
+ lua_util.enable_debug_modules('dmarc', N)
+ end
+
+ dmarc_settings = rspamd_config:get_all_opt('dmarc')
+ if not dmarc_settings or not dmarc_settings.reporting or not dmarc_settings.reporting.enabled then
+ logger.errx('dmarc reporting is not enabled, exiting')
+ os.exit(1)
+ end
+
+ dmarc_settings = lua_util.override_defaults(dmarc_common.default_settings, dmarc_settings)
+ redis_params = lua_redis.parse_redis_server('dmarc', dmarc_settings)
+
+ if not redis_params then
+ logger.errx('Redis is not configured, exiting')
+ os.exit(1)
+ end
+
+ for _, e in ipairs({ 'email', 'domain', 'org_name' }) do
+ if not dmarc_settings.reporting[e] then
+ logger.errx('Missing required setting: dmarc.reporting.%s', e)
+ return
+ end
+ end
+
+ local ret, results = lua_redis.request(redis_params, redis_attrs, {
+ 'GET', 'rspamd_dmarc_last_collection'
+ })
+
+ if not ret or not tonumber(results) then
+ start_time = yesterday_midnight()
+ else
+ start_time = tonumber(results)
+ end
+
+ lua_util.debugm(N, 'previous last report date is %s', start_time)
+
+ if not opts.date or #opts.date == 0 then
+ opts.date = {}
+ table.insert(opts.date, os.date('%Y%m%d', yesterday_midnight()))
+ end
+
+ local ndates = 0
+ local nreports = 0
+ local all_reports = {}
+ for _, date in ipairs(opts.date) do
+ lua_util.debugm(N, 'Process date %s', date)
+ local reports_for_date = process_report_date(opts, start_time, start_collection, date)
+ if #reports_for_date > 0 then
+ ndates = ndates + 1
+ nreports = nreports + #reports_for_date
+
+ for _, r in ipairs(reports_for_date) do
+ table.insert(all_reports, r)
+ end
+ end
+ end
+
+ local function finish_cb(nsuccess, nfail)
+ if not opts.no_opt then
+ lua_util.debugm(N, 'set last report date to %s', start_collection)
+ -- Hack to avoid coroutines + async functions mess: we use async redis call here
+ redis_attrs.callback = function()
+ logger.messagex('Reporting collection has finished %s dates processed, %s reports: %s completed, %s failed',
+ ndates, nreports, nsuccess, nfail)
+ end
+ lua_redis.request(redis_params, redis_attrs,
+ { 'SETEX', 'rspamd_dmarc_last_collection', dmarc_settings.reporting.keys_expire * 2,
+ tostring(start_collection) })
+ else
+ logger.messagex('Reporting collection has finished %s dates processed, %s reports: %s completed, %s failed',
+ ndates, nreports, nsuccess, nfail)
+ end
+
+ pool:destroy()
+ end
+ if not opts.no_opt then
+ send_reports_by_smtp(opts, all_reports, finish_cb)
+ else
+ logger.messagex('Skip sending mails due to -n / --no-opt option')
+ end
+end
+
+return {
+ name = 'dmarc_report',
+ aliases = { 'dmarc_reporting' },
+ handler = handler,
+ description = parser._description
+}
diff --git a/lualib/rspamadm/dns_tool.lua b/lualib/rspamadm/dns_tool.lua
new file mode 100644
index 0000000..3eb09a8
--- /dev/null
+++ b/lualib/rspamadm/dns_tool.lua
@@ -0,0 +1,232 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+
+local argparse = require "argparse"
+local rspamd_logger = require "rspamd_logger"
+local ansicolors = require "ansicolors"
+local bit = require "bit"
+
+local parser = argparse()
+ :name "rspamadm dnstool"
+ :description "DNS tools provided by Rspamd"
+ :help_description_margin(30)
+ :command_target("command")
+ :require_command(true)
+
+parser:option "-c --config"
+ :description "Path to config file"
+ :argname("<cfg>")
+ :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+
+local spf = parser:command "spf"
+ :description "Extracts spf records"
+spf:mutex(
+ spf:option "-d --domain"
+ :description "Domain to use"
+ :argname("<domain>"),
+ spf:option "-f --from"
+ :description "SMTP from to use"
+ :argname("<from>")
+)
+
+spf:option "-i --ip"
+ :description "Source IP address to use"
+ :argname("<ip>")
+spf:flag "-a --all"
+ :description "Print all records"
+
+local function printf(fmt, ...)
+ if fmt then
+ io.write(string.format(fmt, ...))
+ end
+ io.write('\n')
+end
+
+local function highlight(str)
+ return ansicolors.white .. str .. ansicolors.reset
+end
+
+local function green(str)
+ return ansicolors.green .. str .. ansicolors.reset
+end
+
+local function red(str)
+ return ansicolors.red .. str .. ansicolors.reset
+end
+
+local function load_config(opts)
+ local _r, err = rspamd_config:load_ucl(opts['config'])
+
+ if not _r then
+ rspamd_logger.errx('cannot parse %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
+ _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+ if not _r then
+ rspamd_logger.errx('cannot process %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+end
+
+local function spf_handler(opts)
+ local rspamd_spf = require "rspamd_spf"
+ local rspamd_task = require "rspamd_task"
+ local rspamd_ip = require "rspamd_ip"
+
+ local task = rspamd_task:create(rspamd_config, rspamadm_ev_base)
+ task:set_session(rspamadm_session)
+ task:set_resolver(rspamadm_dns_resolver)
+
+ if opts.ip then
+ opts.ip = rspamd_ip.fromstring(opts.ip)
+ task:set_from_ip(opts.ip)
+ else
+ opts.all = true
+ end
+
+ if opts.from then
+ local rspamd_parsers = require "rspamd_parsers"
+ local addr_parsed = rspamd_parsers.parse_mail_address(opts.from)
+ if addr_parsed then
+ task:set_from('smtp', addr_parsed[1])
+ else
+ io.stderr:write('Invalid from addr\n')
+ os.exit(1)
+ end
+ elseif opts.domain then
+ task:set_from('smtp', { user = 'user', domain = opts.domain })
+ else
+ io.stderr:write('Neither domain nor from specified\n')
+ os.exit(1)
+ end
+
+ local function flag_to_str(fl)
+ if bit.band(fl, rspamd_spf.flags.temp_fail) ~= 0 then
+ return "temporary failure"
+ elseif bit.band(fl, rspamd_spf.flags.perm_fail) ~= 0 then
+ return "permanent failure"
+ elseif bit.band(fl, rspamd_spf.flags.na) ~= 0 then
+ return "no spf record"
+ end
+
+ return "unknown flag: " .. tostring(fl)
+ end
+
+ local function display_spf_results(elt, colored)
+ local dec = function(e)
+ return e
+ end
+ local policy_decode = function(e)
+ if e == rspamd_spf.policy.fail then
+ return 'reject'
+ elseif e == rspamd_spf.policy.pass then
+ return 'pass'
+ elseif e == rspamd_spf.policy.soft_fail then
+ return 'soft fail'
+ elseif e == rspamd_spf.policy.neutral then
+ return 'neutral'
+ end
+
+ return 'unknown'
+ end
+
+ if colored then
+ dec = function(e)
+ return highlight(e)
+ end
+
+ if elt.result == rspamd_spf.policy.pass then
+ dec = function(e)
+ return green(e)
+ end
+ elseif elt.result == rspamd_spf.policy.fail then
+ dec = function(e)
+ return red(e)
+ end
+ end
+
+ end
+ printf('%s: %s', highlight('Policy'), dec(policy_decode(elt.result)))
+ printf('%s: %s', highlight('Network'), dec(elt.addr))
+
+ if elt.str then
+ printf('%s: %s', highlight('Original'), elt.str)
+ end
+ end
+
+ local function cb(record, flags, err)
+ if record then
+ local result, flag_or_policy, error_or_addr
+ if opts.ip then
+ result, flag_or_policy, error_or_addr = record:check_ip(opts.ip)
+ elseif opts.all then
+ result = true
+ end
+ if opts.ip and not opts.all then
+ if result then
+ display_spf_results(error_or_addr, true)
+ else
+ printf('Not matched: %s', error_or_addr)
+ end
+
+ os.exit(0)
+ end
+
+ if result then
+ printf('SPF record for %s; digest: %s',
+ highlight(opts.domain or opts.from), highlight(record:get_digest()))
+ for _, elt in ipairs(record:get_elts()) do
+ if result and error_or_addr and elt.str and elt.str == error_or_addr.str then
+ printf("%s", highlight('*** Matched ***'))
+ display_spf_results(elt, true)
+ printf('------')
+ else
+ display_spf_results(elt, false)
+ printf('------')
+ end
+ end
+ else
+ printf('Error getting SPF record: %s (%s flag)', err,
+ flag_to_str(flag_or_policy or flags))
+ end
+ else
+ printf('Cannot get SPF record: %s', err)
+ end
+ end
+ rspamd_spf.resolve(task, cb)
+end
+
+local function handler(args)
+ local opts = parser:parse(args)
+ load_config(opts)
+
+ local command = opts.command
+
+ if command == 'spf' then
+ spf_handler(opts)
+ else
+ parser:error('command %s is not implemented', command)
+ end
+end
+
+return {
+ name = 'dnstool',
+ aliases = { 'dns', 'dns_tool' },
+ handler = handler,
+ description = parser._description
+} \ No newline at end of file
diff --git a/lualib/rspamadm/fuzzy_convert.lua b/lualib/rspamadm/fuzzy_convert.lua
new file mode 100644
index 0000000..fab3995
--- /dev/null
+++ b/lualib/rspamadm/fuzzy_convert.lua
@@ -0,0 +1,208 @@
+local sqlite3 = require "rspamd_sqlite3"
+local redis = require "rspamd_redis"
+local util = require "rspamd_util"
+
+local function connect_redis(server, username, password, db)
+ local ret
+ local conn, err = redis.connect_sync({
+ host = server,
+ })
+
+ if not conn then
+ return nil, 'Cannot connect: ' .. err
+ end
+
+ if username then
+ if password then
+ ret = conn:add_cmd('AUTH', { username, password })
+ if not ret then
+ return nil, 'Cannot queue command'
+ end
+ else
+ return nil, 'Redis requires a password when username is supplied'
+ end
+ elseif password then
+ ret = conn:add_cmd('AUTH', { password })
+ if not ret then
+ return nil, 'Cannot queue command'
+ end
+ end
+ if db then
+ ret = conn:add_cmd('SELECT', { db })
+ if not ret then
+ return nil, 'Cannot queue command'
+ end
+ end
+
+ return conn, nil
+end
+
+local function send_digests(digests, redis_host, redis_username, redis_password, redis_db)
+ local conn, err = connect_redis(redis_host, redis_username, redis_password, redis_db)
+ if err then
+ print(err)
+ return false
+ end
+ local ret
+ for _, v in ipairs(digests) do
+ ret = conn:add_cmd('HMSET', {
+ 'fuzzy' .. v[1],
+ 'F', v[2],
+ 'V', v[3],
+ })
+ if not ret then
+ print('Cannot batch command')
+ return false
+ end
+ ret = conn:add_cmd('EXPIRE', {
+ 'fuzzy' .. v[1],
+ tostring(v[4]),
+ })
+ if not ret then
+ print('Cannot batch command')
+ return false
+ end
+ end
+ ret, err = conn:exec()
+ if not ret then
+ print('Cannot execute batched commands: ' .. err)
+ return false
+ end
+ return true
+end
+
+local function send_shingles(shingles, redis_host, redis_username, redis_password, redis_db)
+ local conn, err = connect_redis(redis_host, redis_username, redis_password, redis_db)
+ if err then
+ print("Redis error: " .. err)
+ return false
+ end
+ local ret
+ for _, v in ipairs(shingles) do
+ ret = conn:add_cmd('SET', {
+ 'fuzzy_' .. v[2] .. '_' .. v[1],
+ v[4],
+ })
+ if not ret then
+ print('Cannot batch SET command: ' .. err)
+ return false
+ end
+ ret = conn:add_cmd('EXPIRE', {
+ 'fuzzy_' .. v[2] .. '_' .. v[1],
+ tostring(v[3]),
+ })
+ if not ret then
+ print('Cannot batch command')
+ return false
+ end
+ end
+ ret, err = conn:exec()
+ if not ret then
+ print('Cannot execute batched commands: ' .. err)
+ return false
+ end
+ return true
+end
+
+local function update_counters(total, redis_host, redis_username, redis_password, redis_db)
+ local conn, err = connect_redis(redis_host, redis_username, redis_password, redis_db)
+ if err then
+ print(err)
+ return false
+ end
+ local ret
+ ret = conn:add_cmd('SET', {
+ 'fuzzylocal',
+ total,
+ })
+ if not ret then
+ print('Cannot batch command')
+ return false
+ end
+ ret = conn:add_cmd('SET', {
+ 'fuzzy_count',
+ total,
+ })
+ if not ret then
+ print('Cannot batch command')
+ return false
+ end
+ ret, err = conn:exec()
+ if not ret then
+ print('Cannot execute batched commands: ' .. err)
+ return false
+ end
+ return true
+end
+
+return function(_, res)
+ local db = sqlite3.open(res['source_db'])
+ local shingles = {}
+ local digests = {}
+ local num_batch_digests = 0
+ local num_batch_shingles = 0
+ local total_digests = 0
+ local total_shingles = 0
+ local lim_batch = 1000 -- Update each 1000 entries
+ local redis_username = res['redis_username']
+ local redis_password = res['redis_password']
+ local redis_db = nil
+
+ if res['redis_db'] then
+ redis_db = tostring(res['redis_db'])
+ end
+
+ if not db then
+ print('Cannot open source db: ' .. res['source_db'])
+ return
+ end
+
+ local now = util.get_time()
+ for row in db:rows('SELECT id, flag, digest, value, time FROM digests') do
+
+ local expire_in = math.floor(now - row.time + res['expiry'])
+ if expire_in >= 1 then
+ table.insert(digests, { row.digest, row.flag, row.value, expire_in })
+ num_batch_digests = num_batch_digests + 1
+ total_digests = total_digests + 1
+ for srow in db:rows('SELECT value, number FROM shingles WHERE digest_id = ' .. row.id) do
+ table.insert(shingles, { srow.value, srow.number, expire_in, row.digest })
+ total_shingles = total_shingles + 1
+ num_batch_shingles = num_batch_shingles + 1
+ end
+ end
+ if num_batch_digests >= lim_batch then
+ if not send_digests(digests, res['redis_host'], redis_username, redis_password, redis_db) then
+ return
+ end
+ num_batch_digests = 0
+ digests = {}
+ end
+ if num_batch_shingles >= lim_batch then
+ if not send_shingles(shingles, res['redis_host'], redis_username, redis_password, redis_db) then
+ return
+ end
+ num_batch_shingles = 0
+ shingles = {}
+ end
+ end
+ if digests[1] then
+ if not send_digests(digests, res['redis_host'], redis_username, redis_password, redis_db) then
+ return
+ end
+ end
+ if shingles[1] then
+ if not send_shingles(shingles, res['redis_host'], redis_username, redis_password, redis_db) then
+ return
+ end
+ end
+
+ local message = string.format(
+ 'Migrated %d digests and %d shingles',
+ total_digests, total_shingles
+ )
+ if not update_counters(total_digests, res['redis_host'], redis_username, redis_password, redis_db) then
+ message = message .. ' but failed to update counters'
+ end
+ print(message)
+end
diff --git a/lualib/rspamadm/fuzzy_ping.lua b/lualib/rspamadm/fuzzy_ping.lua
new file mode 100644
index 0000000..e0345da
--- /dev/null
+++ b/lualib/rspamadm/fuzzy_ping.lua
@@ -0,0 +1,259 @@
+--[[
+Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+local ansicolors = require "ansicolors"
+local rspamd_logger = require "rspamd_logger"
+local lua_util = require "lua_util"
+
+local E = {}
+
+local parser = argparse()
+ :name 'rspamadm fuzzy_ping'
+ :description 'Pings fuzzy storage'
+ :help_description_margin(30)
+parser:option "-c --config"
+ :description "Path to config file"
+ :argname("<cfg>")
+ :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+parser:option "-r --rule"
+ :description "Storage to ping (must be configured in Rspamd configuration)"
+ :argname("<name>")
+ :default("rspamd.com")
+parser:flag "-f --flood"
+ :description "Flood mode (no waiting for replies)"
+parser:flag "-S --silent"
+ :description "Silent mode (statistics only)"
+parser:option "-t --timeout"
+ :description "Timeout for requests"
+ :argname("<timeout>")
+ :convert(tonumber)
+ :default(5)
+parser:option "-s --server"
+ :description "Override server to ping"
+ :argname("<name>")
+parser:option "-n --number"
+ :description "Timeout for requests"
+ :argname("<number>")
+ :convert(tonumber)
+ :default(5)
+parser:flag "-l --list"
+ :description "List configured storages"
+
+local function load_config(opts)
+ local _r, err = rspamd_config:load_ucl(opts['config'])
+
+ if not _r then
+ rspamd_logger.errx('cannot parse %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
+ -- Init the real structure excluding logging and workers
+ _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+ if not _r then
+ rspamd_logger.errx('cannot process %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
+ _r, err = rspamd_config:init_modules()
+ if not _r then
+ rspamd_logger.errx('cannot init modules from %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+end
+
+local function highlight(fmt, ...)
+ return ansicolors.white .. string.format(fmt, ...) .. ansicolors.reset
+end
+
+local function highlight_err(fmt, ...)
+ return ansicolors.red .. string.format(fmt, ...) .. ansicolors.reset
+end
+
+local function print_storages(rules)
+ for n, rule in pairs(rules) do
+ print(highlight('Rule: %s', n))
+ print(string.format("\tRead only: %s", rule.read_only))
+ print(string.format("\tServers: %s", table.concat(lua_util.values(rule.servers), ',')))
+ print("\tFlags:")
+
+ for fl, id in pairs(rule.flags or E) do
+ print(string.format("\t\t%s: %s", fl, id))
+ end
+ end
+end
+
+local function std_mean(tbl)
+ local function mean()
+ local sum = 0
+ local count = 0
+
+ for _, v in ipairs(tbl) do
+ sum = sum + v
+ count = count + 1
+ end
+
+ return (sum / count)
+ end
+
+ local m
+ local vm
+ local sum = 0
+ local count = 0
+ local result
+
+ m = mean(tbl)
+
+ for _, v in ipairs(tbl) do
+ vm = v - m
+ sum = sum + (vm * vm)
+ count = count + 1
+ end
+
+ result = math.sqrt(sum / (count - 1))
+
+ return result, m
+end
+
+local function maxmin(tbl)
+ local max = -math.huge
+ local min = math.huge
+
+ for _, v in ipairs(tbl) do
+ max = math.max(max, v)
+ min = math.min(min, v)
+ end
+
+ return max, min
+end
+
+local function print_results(results)
+ local servers = {}
+ local err_servers = {}
+ for _, res in ipairs(results) do
+ if res.success then
+ if servers[res.server] then
+ table.insert(servers[res.server], res.latency)
+ else
+ servers[res.server] = { res.latency }
+ end
+ else
+ if err_servers[res.server] then
+ err_servers[res.server] = err_servers[res.server] + 1
+ else
+ err_servers[res.server] = 1
+ end
+ -- For the case if no successful replies are detected
+ if not servers[res.server] then
+ servers[res.server] = {}
+ end
+ end
+ end
+
+ for s, l in pairs(servers) do
+ local total = #l + (err_servers[s] or 0)
+ print(highlight('Summary for %s: %d packets transmitted, %d packets received, %.1f%% packet loss',
+ s, total, #l, (total - #l) * 100.0 / total))
+ local mean, std = std_mean(l)
+ local max, min = maxmin(l)
+ print(string.format('round-trip min/avg/max/std-dev = %.2f/%.2f/%.2f/%.2f ms',
+ min, mean,
+ max, std))
+ end
+end
+
+local function handler(args)
+ local opts = parser:parse(args)
+
+ load_config(opts)
+
+ if opts.list then
+ print_storages(rspamd_plugins.fuzzy_check.list_storages(rspamd_config))
+ os.exit(0)
+ end
+
+ -- Perform ping using a fake task from async stuff provided by rspamadm
+ local rspamd_task = require "rspamd_task"
+
+ -- TODO: this task is not cleared at the end, do something about it some day
+ local task = rspamd_task.create(rspamd_config, rspamadm_ev_base)
+ task:set_session(rspamadm_session)
+ task:set_resolver(rspamadm_dns_resolver)
+
+ local replied = 0
+ local results = {}
+ local ping_fuzzy
+
+ local function gen_ping_fuzzy_cb(num)
+ return function(success, server, latency_or_err)
+ if not success then
+ if not opts.silent then
+ print(highlight_err('error from %s: %s', server, latency_or_err))
+ end
+ results[num] = {
+ success = false,
+ error = latency_or_err,
+ server = tostring(server),
+ }
+ else
+ if not opts.silent then
+ local adjusted_latency = math.floor(latency_or_err * 1000) * 1.0 / 1000;
+ print(highlight('reply from %s: %s ms', server, adjusted_latency))
+
+ end
+ results[num] = {
+ success = true,
+ latency = latency_or_err,
+ server = tostring(server),
+ }
+ end
+
+ if replied == opts.number - 1 then
+ print_results(results)
+ else
+ replied = replied + 1
+ if not opts.flood then
+ ping_fuzzy(replied + 1)
+ end
+ end
+ end
+ end
+
+ ping_fuzzy = function(num)
+ local ret, err = rspamd_plugins.fuzzy_check.ping_storage(task, gen_ping_fuzzy_cb(num),
+ opts.rule, opts.timeout, opts.server)
+
+ if not ret then
+ print(highlight_err('error from %s: %s', opts.server, err))
+ opts.number = opts.number - 1 -- To avoid issues with waiting for other replies
+ end
+ end
+
+ if opts.flood then
+ for i = 1, opts.number do
+ ping_fuzzy(i)
+ end
+ else
+ ping_fuzzy(1)
+ end
+end
+
+return {
+ name = 'fuzzy_ping',
+ aliases = { 'fuzzyping' },
+ handler = handler,
+ description = parser._description
+} \ No newline at end of file
diff --git a/lualib/rspamadm/fuzzy_stat.lua b/lualib/rspamadm/fuzzy_stat.lua
new file mode 100644
index 0000000..ef8a5de
--- /dev/null
+++ b/lualib/rspamadm/fuzzy_stat.lua
@@ -0,0 +1,366 @@
+local rspamd_util = require "rspamd_util"
+local lua_util = require "lua_util"
+local opts = {}
+
+local argparse = require "argparse"
+local parser = argparse()
+ :name "rspamadm control fuzzystat"
+ :description "Shows help for the specified configuration options"
+ :help_description_margin(32)
+parser:flag "--no-ips"
+ :description "No IPs stats"
+parser:flag "--no-keys"
+ :description "No keys stats"
+parser:flag "--short"
+ :description "Short output mode"
+parser:flag "-n --number"
+ :description "Disable numbers humanization"
+parser:option "--sort"
+ :description "Sort order"
+ :convert {
+ checked = "checked",
+ matched = "matched",
+ errors = "errors",
+ name = "name"
+}
+
+local function add_data(target, src)
+ for k, v in pairs(src) do
+ if type(v) == 'number' then
+ if target[k] then
+ target[k] = target[k] + v
+ else
+ target[k] = v
+ end
+ elseif k == 'ips' then
+ if not target['ips'] then
+ target['ips'] = {}
+ end
+ -- Iterate over IPs
+ for ip, st in pairs(v) do
+ if not target['ips'][ip] then
+ target['ips'][ip] = {}
+ end
+ add_data(target['ips'][ip], st)
+ end
+ elseif k == 'flags' then
+ if not target['flags'] then
+ target['flags'] = {}
+ end
+ -- Iterate over Flags
+ for flag, st in pairs(v) do
+ if not target['flags'][flag] then
+ target['flags'][flag] = {}
+ end
+ add_data(target['flags'][flag], st)
+ end
+ elseif k == 'keypair' then
+ if type(v.extensions) == 'table' then
+ if type(v.extensions.name) == 'string' then
+ target.name = v.extensions.name
+ end
+ end
+ end
+ end
+end
+
+local function print_num(num)
+ if num then
+ if opts['n'] or opts['number'] then
+ return tostring(num)
+ else
+ return rspamd_util.humanize_number(num)
+ end
+ else
+ return 'na'
+ end
+end
+
+local function print_stat(st, tabs)
+ if st['checked'] then
+ if st.checked_per_hour then
+ print(string.format('%sChecked: %s (%s per hour in average)', tabs,
+ print_num(st['checked']), print_num(st['checked_per_hour'])))
+ else
+ print(string.format('%sChecked: %s', tabs, print_num(st['checked'])))
+ end
+ end
+ if st['matched'] then
+ if st.checked and st.checked > 0 and st.checked <= st.matched then
+ local percentage = st.matched / st.checked * 100.0
+ if st.matched_per_hour then
+ print(string.format('%sMatched: %s - %s percent (%s per hour in average)', tabs,
+ print_num(st['matched']), percentage, print_num(st['matched_per_hour'])))
+ else
+ print(string.format('%sMatched: %s - %s percent', tabs, print_num(st['matched']), percentage))
+ end
+ else
+ if st.matched_per_hour then
+ print(string.format('%sMatched: %s (%s per hour in average)', tabs,
+ print_num(st['matched']), print_num(st['matched_per_hour'])))
+ else
+ print(string.format('%sMatched: %s', tabs, print_num(st['matched'])))
+ end
+ end
+ end
+ if st['errors'] then
+ print(string.format('%sErrors: %s', tabs, print_num(st['errors'])))
+ end
+ if st['added'] then
+ print(string.format('%sAdded: %s', tabs, print_num(st['added'])))
+ end
+ if st['deleted'] then
+ print(string.format('%sDeleted: %s', tabs, print_num(st['deleted'])))
+ end
+end
+
+-- Sort by checked
+local function sort_hash_table(tbl, sort_opts, key_key)
+ local res = {}
+ for k, v in pairs(tbl) do
+ table.insert(res, { [key_key] = k, data = v })
+ end
+
+ local function sort_order(elt)
+ local key = 'checked'
+ local sort_res = 0
+
+ if sort_opts['sort'] then
+ if sort_opts['sort'] == 'matched' then
+ key = 'matched'
+ elseif sort_opts['sort'] == 'errors' then
+ key = 'errors'
+ elseif sort_opts['sort'] == 'name' then
+ return elt[key_key]
+ end
+ end
+
+ if elt.data[key] then
+ sort_res = elt.data[key]
+ end
+
+ return sort_res
+ end
+
+ table.sort(res, function(a, b)
+ return sort_order(a) > sort_order(b)
+ end)
+
+ return res
+end
+
+local function add_result(dst, src, k)
+ if type(src) == 'table' then
+ if type(dst) == 'number' then
+ -- Convert dst to table
+ dst = { dst }
+ elseif type(dst) == 'nil' then
+ dst = {}
+ end
+
+ for i, v in ipairs(src) do
+ if dst[i] and k ~= 'fuzzy_stored' then
+ dst[i] = dst[i] + v
+ else
+ dst[i] = v
+ end
+ end
+ else
+ if type(dst) == 'table' then
+ if k ~= 'fuzzy_stored' then
+ dst[1] = dst[1] + src
+ else
+ dst[1] = src
+ end
+ else
+ if dst and k ~= 'fuzzy_stored' then
+ dst = dst + src
+ else
+ dst = src
+ end
+ end
+ end
+
+ return dst
+end
+
+local function print_result(r)
+ local function num_to_epoch(num)
+ if num == 1 then
+ return 'v0.6'
+ elseif num == 2 then
+ return 'v0.8'
+ elseif num == 3 then
+ return 'v0.9'
+ elseif num == 4 then
+ return 'v1.0+'
+ elseif num == 5 then
+ return 'v1.7+'
+ end
+ return '???'
+ end
+ if type(r) == 'table' then
+ local res = {}
+ for i, num in ipairs(r) do
+ res[i] = string.format('(%s: %s)', num_to_epoch(i), print_num(num))
+ end
+
+ return table.concat(res, ', ')
+ end
+
+ return print_num(r)
+end
+
+return function(args, res)
+ local res_ips = {}
+ local res_databases = {}
+ local wrk = res['workers']
+ opts = parser:parse(args)
+
+ if wrk then
+ for _, pr in pairs(wrk) do
+ -- processes cycle
+ if pr['data'] then
+ local id = pr['id']
+
+ if id then
+ local res_db = res_databases[id]
+ if not res_db then
+ res_db = {
+ keys = {}
+ }
+ res_databases[id] = res_db
+ end
+
+ -- General stats
+ for k, v in pairs(pr['data']) do
+ if k ~= 'keys' and k ~= 'errors_ips' then
+ res_db[k] = add_result(res_db[k], v, k)
+ elseif k == 'errors_ips' then
+ -- Errors ips
+ if not res_db['errors_ips'] then
+ res_db['errors_ips'] = {}
+ end
+ for ip, nerrors in pairs(v) do
+ if not res_db['errors_ips'][ip] then
+ res_db['errors_ips'][ip] = nerrors
+ else
+ res_db['errors_ips'][ip] = nerrors + res_db['errors_ips'][ip]
+ end
+ end
+ end
+ end
+
+ if pr['data']['keys'] then
+ local res_keys = res_db['keys']
+ if not res_keys then
+ res_keys = {}
+ res_db['keys'] = res_keys
+ end
+ -- Go through keys in input
+ for k, elts in pairs(pr['data']['keys']) do
+ -- keys cycle
+ if not res_keys[k] then
+ res_keys[k] = {}
+ end
+
+ add_data(res_keys[k], elts)
+
+ if elts['ips'] then
+ for ip, v in pairs(elts['ips']) do
+ if not res_ips[ip] then
+ res_ips[ip] = {}
+ end
+ add_data(res_ips[ip], v)
+ end
+ end
+ end
+ end
+ end
+ end
+ end
+ end
+
+ -- General stats
+ for db, st in pairs(res_databases) do
+ print(string.format('Statistics for storage %s', db))
+
+ for k, v in pairs(st) do
+ if k ~= 'keys' and k ~= 'errors_ips' then
+ print(string.format('%s: %s', k, print_result(v)))
+ end
+ end
+ print('')
+
+ local res_keys = st['keys']
+ if res_keys and not opts['no_keys'] and not opts['short'] then
+ print('Keys statistics:')
+ -- Convert into an array to allow sorting
+ local sorted_keys = sort_hash_table(res_keys, opts, 'key')
+
+ for _, key in ipairs(sorted_keys) do
+ local key_stat = key.data
+ if key_stat.name then
+ print(string.format('Key id: %s, name: %s', key.key, key_stat.name))
+ else
+ print(string.format('Key id: %s', key.key))
+ end
+
+ print_stat(key_stat, '\t')
+
+ if key_stat['ips'] and not opts['no_ips'] then
+ print('')
+ print('\tIPs stat:')
+ local sorted_ips = sort_hash_table(key_stat['ips'], opts, 'ip')
+
+ for _, v in ipairs(sorted_ips) do
+ print(string.format('\t%s', v['ip']))
+ print_stat(v['data'], '\t\t')
+ print('')
+ end
+ end
+
+ if key_stat.flags then
+ print('')
+ print('\tFlags stat:')
+ for flag, v in pairs(key_stat.flags) do
+ print(string.format('\t[%s]:', flag))
+ -- Remove irrelevant fields
+ v.checked = nil
+ print_stat(v, '\t\t')
+ print('')
+ end
+ end
+
+ print('')
+ end
+ end
+ if st['errors_ips'] and not opts['no_ips'] and not opts['short'] then
+ print('')
+ print('Errors IPs statistics:')
+ local ip_stat = st['errors_ips']
+ local ips = lua_util.keys(ip_stat)
+ -- Reverse sort by number of errors
+ table.sort(ips, function(a, b)
+ return ip_stat[a] > ip_stat[b]
+ end)
+ for _, ip in ipairs(ips) do
+ print(string.format('%s: %s', ip, print_result(ip_stat[ip])))
+ end
+ print('')
+ end
+ end
+
+ if not opts['no_ips'] and not opts['short'] then
+ print('')
+ print('IPs statistics:')
+
+ local sorted_ips = sort_hash_table(res_ips, opts, 'ip')
+ for _, v in ipairs(sorted_ips) do
+ print(string.format('%s', v['ip']))
+ print_stat(v['data'], '\t')
+ print('')
+ end
+ end
+end
+
diff --git a/lualib/rspamadm/grep.lua b/lualib/rspamadm/grep.lua
new file mode 100644
index 0000000..6ed0569
--- /dev/null
+++ b/lualib/rspamadm/grep.lua
@@ -0,0 +1,174 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+
+
+-- Define command line options
+local parser = argparse()
+ :name "rspamadm grep"
+ :description "Search for patterns in rspamd logs"
+ :help_description_margin(30)
+parser:mutex(
+ parser:option "-s --string"
+ :description('Plain string to search (case-insensitive)')
+ :argname "<str>",
+ parser:option "-p --pattern"
+ :description('Pattern to search for (regex)')
+ :argname "<re>"
+)
+parser:flag "-l --lua"
+ :description('Use Lua patterns in string search')
+
+parser:argument "input":args "*"
+ :description('Process specified inputs')
+ :default("stdin")
+parser:flag "-S --sensitive"
+ :description('Enable case-sensitivity in string search')
+parser:flag "-o --orphans"
+ :description('Print orphaned logs')
+parser:flag "-P --partial"
+ :description('Print partial logs')
+
+local function handler(args)
+
+ local rspamd_regexp = require 'rspamd_regexp'
+ local res = parser:parse(args)
+
+ if not res['string'] and not res['pattern'] then
+ parser:error('string or pattern options must be specified')
+ end
+
+ if res['string'] and res['pattern'] then
+ parser:error('string and pattern options are mutually exclusive')
+ end
+
+ local buffer = {}
+ local matches = {}
+
+ local pattern = res['pattern']
+ local re
+ if pattern then
+ re = rspamd_regexp.create(pattern)
+ if not re then
+ io.stderr:write("Couldn't compile regex: " .. pattern .. '\n')
+ os.exit(1)
+ end
+ end
+
+ local plainm = true
+ if res['lua'] then
+ plainm = false
+ end
+ local orphans = res['orphans']
+ local search_str = res['string']
+ local sensitive = res['sensitive']
+ local partial = res['partial']
+ if search_str and not sensitive then
+ search_str = string.lower(search_str)
+ end
+ local inputs = res['input'] or { 'stdin' }
+
+ for _, n in ipairs(inputs) do
+ local h, err
+ if string.match(n, '%.xz$') then
+ h, err = io.popen('xzcat ' .. n, 'r')
+ elseif string.match(n, '%.bz2$') then
+ h, err = io.popen('bzcat ' .. n, 'r')
+ elseif string.match(n, '%.gz$') then
+ h, err = io.popen('zcat ' .. n, 'r')
+ elseif string.match(n, '%.zst$') then
+ h, err = io.popen('zstdcat ' .. n, 'r')
+ elseif n == 'stdin' then
+ h = io.input()
+ else
+ h, err = io.open(n, 'r')
+ end
+ if not h then
+ if err then
+ io.stderr:write("Couldn't open file (" .. n .. '): ' .. err .. '\n')
+ else
+ io.stderr:write("Couldn't open file (" .. n .. '): no error\n')
+ end
+ else
+ for line in h:lines() do
+ local hash = string.match(line, '<(%x+)>')
+ local already_matching = false
+ if hash then
+ if matches[hash] then
+ table.insert(matches[hash], line)
+ already_matching = true
+ else
+ if buffer[hash] then
+ table.insert(buffer[hash], line)
+ else
+ buffer[hash] = { line }
+ end
+ end
+ end
+ local ismatch = false
+ if re then
+ ismatch = re:match(line)
+ elseif sensitive and search_str then
+ ismatch = string.find(line, search_str, 1, plainm)
+ elseif search_str then
+ local lwr = string.lower(line)
+ ismatch = string.find(lwr, search_str, 1, plainm)
+ end
+ if ismatch then
+ if not hash then
+ if orphans then
+ print('*** orphaned ***')
+ print(line)
+ print()
+ end
+ elseif not already_matching then
+ matches[hash] = buffer[hash]
+ end
+ end
+ local is_end = string.match(line, '<%x+>; task; rspamd_protocol_http_reply:')
+ if is_end then
+ buffer[hash] = nil
+ if matches[hash] then
+ for _, v in ipairs(matches[hash]) do
+ print(v)
+ end
+ print()
+ matches[hash] = nil
+ end
+ end
+ end
+ if partial then
+ for k, v in pairs(matches) do
+ print('*** partial ***')
+ for _, vv in ipairs(v) do
+ print(vv)
+ end
+ print()
+ matches[k] = nil
+ end
+ else
+ matches = {}
+ end
+ end
+ end
+end
+
+return {
+ handler = handler,
+ description = parser._description,
+ name = 'grep'
+} \ No newline at end of file
diff --git a/lualib/rspamadm/keypair.lua b/lualib/rspamadm/keypair.lua
new file mode 100644
index 0000000..f0716a2
--- /dev/null
+++ b/lualib/rspamadm/keypair.lua
@@ -0,0 +1,508 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+local rspamd_keypair = require "rspamd_cryptobox_keypair"
+local rspamd_pubkey = require "rspamd_cryptobox_pubkey"
+local rspamd_signature = require "rspamd_cryptobox_signature"
+local rspamd_crypto = require "rspamd_cryptobox"
+local rspamd_util = require "rspamd_util"
+local ucl = require "ucl"
+local logger = require "rspamd_logger"
+
+-- Define command line options
+local parser = argparse()
+ :name "rspamadm keypair"
+ :description "Manages keypairs for Rspamd"
+ :help_description_margin(30)
+ :command_target("command")
+ :require_command(false)
+
+-- Generate subcommand
+local generate = parser:command "generate gen g"
+ :description "Creates a new keypair"
+generate:flag "-s --sign"
+ :description "Generates a sign keypair instead of the encryption one"
+generate:flag "-n --nist"
+ :description "Uses nist encryption algorithm"
+generate:option "-o --output"
+ :description "Write keypair to file"
+ :argname "<file>"
+generate:mutex(
+ generate:flag "-j --json"
+ :description "Output JSON instead of UCL",
+ generate:flag "-u --ucl"
+ :description "Output UCL"
+ :default(true)
+)
+generate:option "--name"
+ :description "Adds name extension"
+ :argname "<name>"
+
+-- Sign subcommand
+
+local sign = parser:command "sign sig s"
+ :description "Signs a file using keypair"
+sign:option "-k --keypair"
+ :description "Keypair to use"
+ :argname "<file>"
+sign:option "-s --suffix"
+ :description "Suffix for signature"
+ :argname "<suffix>"
+ :default("sig")
+sign:argument "file"
+ :description "File to sign"
+ :argname "<file>"
+ :args "*"
+
+-- Verify subcommand
+
+local verify = parser:command "verify ver v"
+ :description "Verifies a file using keypair or a public key"
+verify:mutex(
+ verify:option "-p --pubkey"
+ :description "Load pubkey from the specified file"
+ :argname "<file>",
+ verify:option "-P --pubstring"
+ :description "Load pubkey from the base32 encoded string"
+ :argname "<base32>",
+ verify:option "-k --keypair"
+ :description "Get pubkey from the keypair file"
+ :argname "<file>"
+)
+verify:argument "file"
+ :description "File to verify"
+ :argname "<file>"
+ :args "*"
+verify:flag "-n --nist"
+ :description "Uses nistp curves (P256)"
+verify:option "-s --suffix"
+ :description "Suffix for signature"
+ :argname "<suffix>"
+ :default("sig")
+
+-- Encrypt subcommand
+
+local encrypt = parser:command "encrypt crypt enc e"
+ :description "Encrypts a file using keypair (or a pubkey)"
+encrypt:mutex(
+ encrypt:option "-p --pubkey"
+ :description "Load pubkey from the specified file"
+ :argname "<file>",
+ encrypt:option "-P --pubstring"
+ :description "Load pubkey from the base32 encoded string"
+ :argname "<base32>",
+ encrypt:option "-k --keypair"
+ :description "Get pubkey from the keypair file"
+ :argname "<file>"
+)
+encrypt:option "-s --suffix"
+ :description "Suffix for encrypted file"
+ :argname "<suffix>"
+ :default("enc")
+encrypt:argument "file"
+ :description "File to encrypt"
+ :argname "<file>"
+ :args "*"
+encrypt:flag "-r --rm"
+ :description "Remove unencrypted file"
+encrypt:flag "-f --force"
+ :description "Remove destination file if it exists"
+
+-- Decrypt subcommand
+
+local decrypt = parser:command "decrypt dec d"
+ :description "Decrypts a file using keypair"
+decrypt:option "-k --keypair"
+ :description "Get pubkey from the keypair file"
+ :argname "<file>"
+decrypt:flag "-S --keep-suffix"
+ :description "Preserve suffix for decrypted file (overwrite encrypted)"
+decrypt:argument "file"
+ :description "File to encrypt"
+ :argname "<file>"
+ :args "*"
+decrypt:flag "-f --force"
+ :description "Remove destination file if it exists (implied with -S)"
+decrypt:flag "-r --rm"
+ :description "Remove encrypted file"
+
+-- Default command is generate, so duplicate options to be compatible
+
+parser:flag "-s --sign"
+ :description "Generates a sign keypair instead of the encryption one"
+parser:flag "-n --nist"
+ :description "Uses nistp curves (P256)"
+parser:mutex(
+ parser:flag "-j --json"
+ :description "Output JSON instead of UCL",
+ parser:flag "-u --ucl"
+ :description "Output UCL"
+ :default(true)
+)
+parser:option "-o --output"
+ :description "Write keypair to file"
+ :argname "<file>"
+
+local function fatal(...)
+ logger.errx(...)
+ os.exit(1)
+end
+
+local function ask_yes_no(greet, default)
+ local def_str
+ if default then
+ greet = greet .. "[Y/n]: "
+ def_str = "yes"
+ else
+ greet = greet .. "[y/N]: "
+ def_str = "no"
+ end
+
+ local reply = rspamd_util.readline(greet)
+
+ if not reply then
+ os.exit(0)
+ end
+ if #reply == 0 then
+ reply = def_str
+ end
+ reply = reply:lower()
+ if reply == 'y' or reply == 'yes' then
+ return true
+ end
+
+ return false
+end
+
+local function generate_handler(opts)
+ local mode = 'encryption'
+ if opts.sign then
+ mode = 'sign'
+ end
+ local alg = 'curve25519'
+ if opts.nist then
+ alg = 'nist'
+ end
+ -- TODO: probably, do it in a more safe way
+ local kp = rspamd_keypair.create(mode, alg):totable()
+
+ if opts.name then
+ kp.keypair.extensions = {
+ name = opts.name
+ }
+ end
+
+ local format = 'ucl'
+
+ if opts.json then
+ format = 'json'
+ end
+
+ if opts.output then
+ local out = io.open(opts.output, 'w')
+ if not out then
+ fatal('cannot open output to write: ' .. opts.output)
+ end
+ out:write(ucl.to_format(kp, format))
+ out:close()
+ else
+ io.write(ucl.to_format(kp, format))
+ end
+end
+
+local function sign_handler(opts)
+ if opts.file then
+ if type(opts.file) == 'string' then
+ opts.file = { opts.file }
+ end
+ else
+ parser:error('no files to sign')
+ end
+ if not opts.keypair then
+ parser:error("no keypair specified")
+ end
+
+ local ucl_parser = ucl.parser()
+ local res, err = ucl_parser:parse_file(opts.keypair)
+
+ if not res then
+ fatal(string.format('cannot load %s: %s', opts.keypair, err))
+ end
+
+ local kp = rspamd_keypair.load(ucl_parser:get_object())
+
+ if not kp then
+ fatal("cannot load keypair: " .. opts.keypair)
+ end
+
+ for _, fname in ipairs(opts.file) do
+ local sig = rspamd_crypto.sign_file(kp, fname)
+
+ if not sig then
+ fatal(string.format("cannot sign %s\n", fname))
+ end
+
+ local out = string.format('%s.%s', fname, opts.suffix or 'sig')
+ local of = io.open(out, 'w')
+ if not of then
+ fatal('cannot open output to write: ' .. out)
+ end
+ of:write(sig:bin())
+ of:close()
+ io.write(string.format('signed %s -> %s (%s)\n', fname, out, sig:hex()))
+ end
+end
+
+local function verify_handler(opts)
+ if opts.file then
+ if type(opts.file) == 'string' then
+ opts.file = { opts.file }
+ end
+ else
+ parser:error('no files to verify')
+ end
+
+ local pk
+ local alg = 'curve25519'
+
+ if opts.keypair then
+ local ucl_parser = ucl.parser()
+ local res, err = ucl_parser:parse_file(opts.keypair)
+
+ if not res then
+ fatal(string.format('cannot load %s: %s', opts.keypair, err))
+ end
+
+ local kp = rspamd_keypair.load(ucl_parser:get_object())
+
+ if not kp then
+ fatal("cannot load keypair: " .. opts.keypair)
+ end
+
+ pk = kp:pk()
+ alg = kp:alg()
+ elseif opts.pubkey then
+ if opts.nist then
+ alg = 'nist'
+ end
+ pk = rspamd_pubkey.load(opts.pubkey, 'sign', alg)
+ elseif opts.pubstr then
+ if opts.nist then
+ alg = 'nist'
+ end
+ pk = rspamd_pubkey.create(opts.pubstr, 'sign', alg)
+ end
+
+ if not pk then
+ fatal("cannot create pubkey")
+ end
+
+ local valid = true
+
+ for _, fname in ipairs(opts.file) do
+
+ local sig_fname = string.format('%s.%s', fname, opts.suffix or 'sig')
+ local sig = rspamd_signature.load(sig_fname, alg)
+
+ if not sig then
+ fatal(string.format("cannot load signature for %s -> %s",
+ fname, sig_fname))
+ end
+
+ if rspamd_crypto.verify_file(pk, sig, fname, alg) then
+ io.write(string.format('verified %s -> %s (%s)\n', fname, sig_fname, sig:hex()))
+ else
+ valid = false
+ io.write(string.format('FAILED to verify %s -> %s (%s)\n', fname,
+ sig_fname, sig:hex()))
+ end
+ end
+
+ if not valid then
+ os.exit(1)
+ end
+end
+
+local function encrypt_handler(opts)
+ if opts.file then
+ if type(opts.file) == 'string' then
+ opts.file = { opts.file }
+ end
+ else
+ parser:error('no files to sign')
+ end
+
+ local pk
+ local alg = 'curve25519'
+
+ if opts.keypair then
+ local ucl_parser = ucl.parser()
+ local res, err = ucl_parser:parse_file(opts.keypair)
+
+ if not res then
+ fatal(string.format('cannot load %s: %s', opts.keypair, err))
+ end
+
+ local kp = rspamd_keypair.load(ucl_parser:get_object())
+
+ if not kp then
+ fatal("cannot load keypair: " .. opts.keypair)
+ end
+
+ pk = kp:pk()
+ alg = kp:alg()
+ elseif opts.pubkey then
+ if opts.nist then
+ alg = 'nist'
+ end
+ pk = rspamd_pubkey.load(opts.pubkey, 'sign', alg)
+ elseif opts.pubstr then
+ if opts.nist then
+ alg = 'nist'
+ end
+ pk = rspamd_pubkey.create(opts.pubstr, 'sign', alg)
+ end
+
+ if not pk then
+ fatal("cannot load keypair: " .. opts.keypair)
+ end
+
+ for _, fname in ipairs(opts.file) do
+ local enc = rspamd_crypto.encrypt_file(pk, fname, alg)
+
+ if not enc then
+ fatal(string.format("cannot encrypt %s\n", fname))
+ end
+
+ local out
+ if opts.suffix and #opts.suffix > 0 then
+ out = string.format('%s.%s', fname, opts.suffix)
+ else
+ out = string.format('%s', fname)
+ end
+
+ if rspamd_util.file_exists(out) then
+ if opts.force or ask_yes_no(string.format('File %s already exists, overwrite?',
+ out), true) then
+ os.remove(out)
+ else
+ os.exit(1)
+ end
+ end
+
+ enc:save_in_file(out)
+
+ if opts.rm then
+ os.remove(fname)
+ io.write(string.format('encrypted %s (deleted) -> %s\n', fname, out))
+ else
+ io.write(string.format('encrypted %s -> %s\n', fname, out))
+ end
+ end
+end
+
+local function decrypt_handler(opts)
+ if opts.file then
+ if type(opts.file) == 'string' then
+ opts.file = { opts.file }
+ end
+ else
+ parser:error('no files to decrypt')
+ end
+ if not opts.keypair then
+ parser:error("no keypair specified")
+ end
+
+ local ucl_parser = ucl.parser()
+ local res, err = ucl_parser:parse_file(opts.keypair)
+
+ if not res then
+ fatal(string.format('cannot load %s: %s', opts.keypair, err))
+ end
+
+ local kp = rspamd_keypair.load(ucl_parser:get_object())
+
+ if not kp then
+ fatal("cannot load keypair: " .. opts.keypair)
+ end
+
+ for _, fname in ipairs(opts.file) do
+ local decrypted = rspamd_crypto.decrypt_file(kp, fname)
+
+ if not decrypted then
+ fatal(string.format("cannot decrypt %s\n", fname))
+ end
+
+ local out
+ if not opts['keep-suffix'] then
+ -- Strip the last suffix
+ out = fname:match("^(.+)%..+$")
+ else
+ out = fname
+ end
+
+ local removed = false
+
+ if rspamd_util.file_exists(out) then
+ if (opts.force or opts['keep-suffix'])
+ or ask_yes_no(string.format('File %s already exists, overwrite?', out), true) then
+ os.remove(out)
+ removed = true
+ else
+ os.exit(1)
+ end
+ end
+
+ if opts.rm then
+ os.remove(fname)
+ removed = true
+ end
+
+ if removed then
+ io.write(string.format('decrypted %s (removed) -> %s\n', fname, out))
+ else
+ io.write(string.format('decrypted %s -> %s\n', fname, out))
+ end
+ end
+end
+
+local function handler(args)
+ local opts = parser:parse(args)
+
+ local command = opts.command or "generate"
+
+ if command == 'generate' then
+ generate_handler(opts)
+ elseif command == 'sign' then
+ sign_handler(opts)
+ elseif command == 'verify' then
+ verify_handler(opts)
+ elseif command == 'encrypt' then
+ encrypt_handler(opts)
+ elseif command == 'decrypt' then
+ decrypt_handler(opts)
+ else
+ parser:error('command %s is not implemented', command)
+ end
+end
+
+return {
+ name = 'keypair',
+ aliases = { 'kp', 'key' },
+ handler = handler,
+ description = parser._description
+} \ No newline at end of file
diff --git a/lualib/rspamadm/mime.lua b/lualib/rspamadm/mime.lua
new file mode 100644
index 0000000..6a589d6
--- /dev/null
+++ b/lualib/rspamadm/mime.lua
@@ -0,0 +1,1012 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+local ansicolors = require "ansicolors"
+local rspamd_util = require "rspamd_util"
+local rspamd_task = require "rspamd_task"
+local rspamd_text = require "rspamd_text"
+local rspamd_logger = require "rspamd_logger"
+local lua_meta = require "lua_meta"
+local rspamd_url = require "rspamd_url"
+local lua_util = require "lua_util"
+local lua_mime = require "lua_mime"
+local ucl = require "ucl"
+
+-- Define command line options
+local parser = argparse()
+ :name "rspamadm mime"
+ :description "Mime manipulations provided by Rspamd"
+ :help_description_margin(30)
+ :command_target("command")
+ :require_command(true)
+
+parser:option "-c --config"
+ :description "Path to config file"
+ :argname("<cfg>")
+ :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+parser:mutex(
+ parser:flag "-j --json"
+ :description "JSON output",
+ parser:flag "-U --ucl"
+ :description "UCL output",
+ parser:flag "-M --messagepack"
+ :description "MessagePack output"
+)
+parser:flag "-C --compact"
+ :description "Use compact format"
+parser:flag "--no-file"
+ :description "Do not print filename"
+
+-- Extract subcommand
+local extract = parser:command "extract ex e"
+ :description "Extracts data from MIME messages"
+extract:argument "file"
+ :description "File to process"
+ :argname "<file>"
+ :args "+"
+
+extract:flag "-t --text"
+ :description "Extracts plain text data from a message"
+extract:flag "-H --html"
+ :description "Extracts htm data from a message"
+extract:option "-o --output"
+ :description "Output format ('raw', 'content', 'oneline', 'decoded', 'decoded_utf')"
+ :argname("<type>")
+ :convert {
+ raw = "raw",
+ content = "content",
+ oneline = "content_oneline",
+ decoded = "raw_parsed",
+ decoded_utf = "raw_utf"
+}
+ :default "content"
+extract:flag "-w --words"
+ :description "Extracts words"
+extract:flag "-p --part"
+ :description "Show part info"
+extract:flag "-s --structure"
+ :description "Show structure info (e.g. HTML tags)"
+extract:flag "-i --invisible"
+ :description "Show invisible content for HTML parts"
+extract:option "-F --words-format"
+ :description "Words format ('stem', 'norm', 'raw', 'full')"
+ :argname("<type>")
+ :convert {
+ stem = "stem",
+ norm = "norm",
+ raw = "raw",
+ full = "full",
+}
+ :default "stem"
+
+local stat = parser:command "stat st s"
+ :description "Extracts statistical data from MIME messages"
+stat:argument "file"
+ :description "File to process"
+ :argname "<file>"
+ :args "+"
+stat:mutex(
+ stat:flag "-m --meta"
+ :description "Lua metatokens",
+ stat:flag "-b --bayes"
+ :description "Bayes tokens",
+ stat:flag "-F --fuzzy"
+ :description "Fuzzy hashes"
+)
+stat:flag "-s --shingles"
+ :description "Show shingles for fuzzy hashes"
+
+local urls = parser:command "urls url u"
+ :description "Extracts URLs from MIME messages"
+urls:argument "file"
+ :description "File to process"
+ :argname "<file>"
+ :args "+"
+urls:mutex(
+ urls:flag "-t --tld"
+ :description "Get TLDs only",
+ urls:flag "-H --host"
+ :description "Get hosts only",
+ urls:flag "-f --full"
+ :description "Show piecewise urls as processed by Rspamd"
+)
+
+urls:flag "-u --unique"
+ :description "Print only unique urls"
+urls:flag "-s --sort"
+ :description "Sort output"
+urls:flag "--count"
+ :description "Print count of each printed element"
+urls:flag "-r --reverse"
+ :description "Reverse sort order"
+
+local modify = parser:command "modify mod m"
+ :description "Modifies MIME message"
+modify:argument "file"
+ :description "File to process"
+ :argname "<file>"
+ :args "+"
+
+modify:option "-a --add-header"
+ :description "Adds specific header"
+ :argname "<header=value>"
+ :count "*"
+modify:option "-r --remove-header"
+ :description "Removes specific header (all occurrences)"
+ :argname "<header>"
+ :count "*"
+modify:option "-R --rewrite-header"
+ :description "Rewrites specific header, uses Lua string.format pattern"
+ :argname "<header=pattern>"
+ :count "*"
+modify:option "-t --text-footer"
+ :description "Adds footer to text/plain parts from a specific file"
+ :argname "<file>"
+modify:option "-H --html-footer"
+ :description "Adds footer to text/html parts from a specific file"
+ :argname "<file>"
+
+local sign = parser:command "sign"
+ :description "Performs DKIM signing"
+sign:argument "file"
+ :description "File to process"
+ :argname "<file>"
+ :args "+"
+
+sign:option "-d --domain"
+ :description "Use specific domain"
+ :argname "<domain>"
+ :count "1"
+sign:option "-s --selector"
+ :description "Use specific selector"
+ :argname "<selector>"
+ :count "1"
+sign:option "-k --key"
+ :description "Use specific key of file"
+ :argname "<key>"
+ :count "1"
+sign:option "-t --type"
+ :description "ARC or DKIM signing"
+ :argname("<arc|dkim>")
+ :convert {
+ ['arc'] = 'arc',
+ ['dkim'] = 'dkim',
+}
+ :default 'dkim'
+sign:option "-o --output"
+ :description "Output format"
+ :argname("<message|signature>")
+ :convert {
+ ['message'] = 'message',
+ ['signature'] = 'signature',
+}
+ :default 'message'
+
+local dump = parser:command "dump"
+ :description "Dumps a raw message in different formats"
+dump:argument "file"
+ :description "File to process"
+ :argname "<file>"
+ :args "+"
+-- Duplicate format for convenience
+dump:mutex(
+ parser:flag "-j --json"
+ :description "JSON output",
+ parser:flag "-U --ucl"
+ :description "UCL output",
+ parser:flag "-M --messagepack"
+ :description "MessagePack output"
+)
+dump:flag "-s --split"
+ :description "Split the output file contents such that no content is embedded"
+
+dump:option "-o --outdir"
+ :description "Output directory"
+ :argname("<directory>")
+
+local function load_config(opts)
+ local _r, err = rspamd_config:load_ucl(opts['config'])
+
+ if not _r then
+ rspamd_logger.errx('cannot parse %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
+ _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+ if not _r then
+ rspamd_logger.errx('cannot process %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+end
+
+local function load_task(opts, fname)
+ if not fname then
+ fname = '-'
+ end
+
+ local res, task = rspamd_task.load_from_file(fname, rspamd_config)
+
+ if not res then
+ parser:error(string.format('cannot read message from %s: %s', fname,
+ task))
+ end
+
+ if not task:process_message() then
+ parser:error(string.format('cannot read message from %s: %s', fname,
+ 'failed to parse'))
+ end
+
+ return task
+end
+
+local function highlight(fmt, ...)
+ return ansicolors.white .. string.format(fmt, ...) .. ansicolors.reset
+end
+
+local function maybe_print_fname(opts, fname)
+ if not opts.json and not opts['no-file'] then
+ rspamd_logger.messagex(highlight('File: %s', fname))
+ end
+end
+
+local function output_fmt(opts)
+ local fmt = 'json'
+ if opts.compact then
+ fmt = 'json-compact'
+ end
+ if opts.ucl then
+ fmt = 'ucl'
+ end
+ if opts.messagepack then
+ fmt = 'msgpack'
+ end
+
+ return fmt
+end
+
+-- Print elements in form
+-- filename -> table of elements
+local function print_elts(elts, opts, func)
+ local fun = require "fun"
+
+ if opts.json or opts.ucl then
+ io.write(ucl.to_format(elts, output_fmt(opts)))
+ else
+ fun.each(function(fname, elt)
+
+ if not opts.json and not opts.ucl then
+ if func then
+ elt = fun.map(func, elt)
+ end
+ maybe_print_fname(opts, fname)
+ fun.each(function(e)
+ io.write(e)
+ io.write("\n")
+ end, elt)
+ end
+ end, elts)
+ end
+end
+
+local function extract_handler(opts)
+ local out_elts = {}
+ local tasks = {}
+ local process_func
+
+ if opts.words then
+ -- Enable stemming and urls detection
+ load_config(opts)
+ rspamd_url.init(rspamd_config:get_tld_path())
+ rspamd_config:init_subsystem('langdet')
+ end
+
+ local function maybe_print_text_part_info(part, out)
+ local fun = require "fun"
+ if opts.part then
+ local t = 'plain text'
+ if part:is_html() then
+ t = 'html'
+ end
+
+ if not opts.json and not opts.ucl then
+ table.insert(out,
+ rspamd_logger.slog('Part: %s: %s, language: %s, size: %s (%s raw), words: %s',
+ part:get_mimepart():get_digest():sub(1, 8),
+ t,
+ part:get_language(),
+ part:get_length(), part:get_raw_length(),
+ part:get_words_count()))
+ table.insert(out,
+ rspamd_logger.slog('Stats: %s',
+ fun.foldl(function(acc, k, v)
+ if acc ~= '' then
+ return string.format('%s, %s:%s', acc, k, v)
+ else
+ return string.format('%s:%s', k, v)
+ end
+ end, '', part:get_stats())))
+ end
+ end
+ end
+
+ local function maybe_print_mime_part_info(part, out)
+ if opts.part then
+
+ if not opts.json and not opts.ucl then
+ local mtype, msubtype = part:get_type()
+ local det_mtype, det_msubtype = part:get_detected_type()
+ table.insert(out,
+ rspamd_logger.slog('Mime Part: %s: %s/%s (%s/%s detected), filename: %s (%s detected ext), size: %s',
+ part:get_digest():sub(1, 8),
+ mtype, msubtype,
+ det_mtype, det_msubtype,
+ part:get_filename(),
+ part:get_detected_ext(),
+ part:get_length()))
+ end
+ end
+ end
+
+ local function print_words(words, full)
+ local fun = require "fun"
+
+ if not full then
+ return table.concat(words, ' ')
+ else
+ return table.concat(
+ fun.totable(
+ fun.map(function(w)
+ -- [1] - stemmed word
+ -- [2] - normalised word
+ -- [3] - raw word
+ -- [4] - flags (table of strings)
+ return string.format('%s|%s|%s(%s)',
+ w[3], w[2], w[1], table.concat(w[4], ','))
+ end, words)
+ ),
+ ' '
+ )
+ end
+ end
+
+ for _, fname in ipairs(opts.file) do
+ local task = load_task(opts, fname)
+ out_elts[fname] = {}
+
+ if not opts.text and not opts.html then
+ opts.text = true
+ opts.html = true
+ end
+
+ if opts.words then
+ local how_words = opts['words_format'] or 'stem'
+ table.insert(out_elts[fname], 'meta_words: ' ..
+ print_words(task:get_meta_words(how_words), how_words == 'full'))
+ end
+
+ if opts.text or opts.html then
+ local mp = task:get_parts() or {}
+
+ for _, mime_part in ipairs(mp) do
+ local how = opts.output
+ local part
+ if mime_part:is_text() then
+ part = mime_part:get_text()
+ end
+
+ if part and opts.text and not part:is_html() then
+ maybe_print_text_part_info(part, out_elts[fname])
+ maybe_print_mime_part_info(mime_part, out_elts[fname])
+ if not opts.json and not opts.ucl then
+ table.insert(out_elts[fname], '\n')
+ end
+
+ if opts.words then
+ local how_words = opts['words_format'] or 'stem'
+ table.insert(out_elts[fname], print_words(part:get_words(how_words),
+ how_words == 'full'))
+ else
+ table.insert(out_elts[fname], tostring(part:get_content(how)))
+ end
+ elseif part and opts.html and part:is_html() then
+ maybe_print_text_part_info(part, out_elts[fname])
+ maybe_print_mime_part_info(mime_part, out_elts[fname])
+ if not opts.json and not opts.ucl then
+ table.insert(out_elts[fname], '\n')
+ end
+
+ if opts.words then
+ local how_words = opts['words_format'] or 'stem'
+ table.insert(out_elts[fname], print_words(part:get_words(how_words),
+ how_words == 'full'))
+ else
+ if opts.structure then
+ local hc = part:get_html()
+ local res = {}
+ process_func = function(elt)
+ local fun = require "fun"
+ if type(elt) == 'table' then
+ return table.concat(fun.totable(
+ fun.map(
+ function(t)
+ return rspamd_logger.slog("%s", t)
+ end,
+ elt)), '\n')
+ else
+ return rspamd_logger.slog("%s", elt)
+ end
+ end
+
+ hc:foreach_tag('any', function(tag)
+ local elt = {}
+ local ex = tag:get_extra()
+ elt.tag = tag:get_type()
+ if ex then
+ elt.extra = ex
+ end
+ local content = tag:get_content()
+ if content then
+ elt.content = tostring(content)
+ end
+ local style = tag:get_style()
+ if style then
+ elt.style = style
+ end
+ table.insert(res, elt)
+ end)
+ table.insert(out_elts[fname], res)
+ else
+ -- opts.structure
+ table.insert(out_elts[fname], tostring(part:get_content(how)))
+ end
+ if opts.invisible then
+ local hc = part:get_html()
+ table.insert(out_elts[fname], string.format('invisible content: %s',
+ tostring(hc:get_invisible())))
+ end
+ end
+ end
+
+ if not part then
+ maybe_print_mime_part_info(mime_part, out_elts[fname])
+ end
+ end
+ end
+
+ table.insert(out_elts[fname], "")
+ table.insert(tasks, task)
+ end
+
+ print_elts(out_elts, opts, process_func)
+ -- To avoid use after free we postpone tasks destruction
+ for _, task in ipairs(tasks) do
+ task:destroy()
+ end
+end
+
+local function stat_handler(opts)
+ local fun = require "fun"
+ local out_elts = {}
+
+ load_config(opts)
+ rspamd_url.init(rspamd_config:get_tld_path())
+ rspamd_config:init_subsystem('langdet,stat') -- Needed to gen stat tokens
+
+ local process_func
+
+ for _, fname in ipairs(opts.file) do
+ local task = load_task(opts, fname)
+ out_elts[fname] = {}
+
+ if opts.meta then
+ local mt = lua_meta.gen_metatokens_table(task)
+ out_elts[fname] = mt
+ process_func = function(k, v)
+ return string.format("%s = %s", k, v)
+ end
+ elseif opts.bayes then
+ local bt = task:get_stat_tokens()
+ out_elts[fname] = bt
+ process_func = function(e)
+ return string.format('%s (%d): "%s"+"%s", [%s]', e.data, e.win, e.t1 or "",
+ e.t2 or "", table.concat(fun.totable(
+ fun.map(function(k)
+ return k
+ end, e.flags)), ","))
+ end
+ elseif opts.fuzzy then
+ local parts = task:get_parts() or {}
+ out_elts[fname] = {}
+ process_func = function(e)
+ local ret = string.format('part: %s(%s): %s', e.type, e.file or "", e.digest)
+ if opts.shingles and e.shingles then
+ local sgl = {}
+ for _, s in ipairs(e.shingles) do
+ table.insert(sgl, string.format('%s: %s+%s+%s', s[1], s[2], s[3], s[4]))
+ end
+
+ ret = ret .. '\n' .. table.concat(sgl, '\n')
+ end
+ return ret
+ end
+ for _, part in ipairs(parts) do
+ if not part:is_multipart() then
+ local text = part:get_text()
+
+ if text then
+ local digest, shingles = text:get_fuzzy_hashes(task:get_mempool())
+ table.insert(out_elts[fname], {
+ digest = digest,
+ shingles = shingles,
+ type = string.format('%s/%s',
+ ({ part:get_type() })[1],
+ ({ part:get_type() })[2])
+ })
+ else
+ table.insert(out_elts[fname], {
+ digest = part:get_digest(),
+ file = part:get_filename(),
+ type = string.format('%s/%s',
+ ({ part:get_type() })[1],
+ ({ part:get_type() })[2])
+ })
+ end
+ end
+ end
+ end
+
+ task:destroy() -- No automatic dtor
+ end
+
+ print_elts(out_elts, opts, process_func)
+end
+
+local function urls_handler(opts)
+ load_config(opts)
+ rspamd_url.init(rspamd_config:get_tld_path())
+ local out_elts = {}
+
+ if opts.json then
+ rspamd_logger.messagex('[')
+ end
+
+ for _, fname in ipairs(opts.file) do
+ out_elts[fname] = {}
+ local task = load_task(opts, fname)
+ local elts = {}
+
+ local function process_url(u)
+ local s
+ if opts.tld then
+ s = u:get_tld()
+ elseif opts.host then
+ s = u:get_host()
+ elseif opts.full then
+ s = rspamd_logger.slog('%s: %s', u:get_text(), u:to_table())
+ else
+ s = u:get_text()
+ end
+
+ if opts.unique then
+ if elts[s] then
+ elts[s].count = elts[s].count + 1
+ else
+ elts[s] = {
+ count = 1,
+ url = u:to_table()
+ }
+ end
+ else
+ if opts.json then
+ table.insert(elts, u)
+ else
+ table.insert(elts, s)
+ end
+ end
+ end
+
+ for _, u in ipairs(task:get_urls(true)) do
+ process_url(u)
+ end
+
+ local json_elts = {}
+
+ local function process_elt(s, u)
+ if opts.unique then
+ -- s is string, u is {url = url, count = count }
+ if not opts.json then
+ if opts.count then
+ table.insert(json_elts, string.format('%s : %s', s, u.count))
+ else
+ table.insert(json_elts, s)
+ end
+ else
+ local tb = u.url
+ tb.count = u.count
+ table.insert(json_elts, tb)
+ end
+ else
+ -- s is index, u is url or string
+ if opts.json then
+ table.insert(json_elts, u)
+ else
+ table.insert(json_elts, u)
+ end
+ end
+ end
+
+ if opts.sort then
+ local sfunc
+ if opts.unique then
+ sfunc = function(t, a, b)
+ if t[a].count ~= t[b].count then
+ if opts.reverse then
+ return t[a].count > t[b].count
+ else
+ return t[a].count < t[b].count
+ end
+ else
+ -- Sort lexicography
+ if opts.reverse then
+ return a > b
+ else
+ return a < b
+ end
+ end
+ end
+ else
+ sfunc = function(t, a, b)
+ local va, vb
+ if opts.json then
+ va = t[a]:get_text()
+ vb = t[b]:get_text()
+ else
+ va = t[a]
+ vb = t[b]
+ end
+ if opts.reverse then
+ return va > vb
+ else
+ return va < vb
+ end
+ end
+ end
+
+ for s, u in lua_util.spairs(elts, sfunc) do
+ process_elt(s, u)
+ end
+ else
+ for s, u in pairs(elts) do
+ process_elt(s, u)
+ end
+ end
+
+ out_elts[fname] = json_elts
+
+ task:destroy() -- No automatic dtor
+ end
+
+ print_elts(out_elts, opts)
+end
+
+local function newline(task)
+ local t = task:get_newlines_type()
+
+ if t == 'cr' then
+ return '\r'
+ elseif t == 'lf' then
+ return '\n'
+ end
+
+ return '\r\n'
+end
+
+local function modify_handler(opts)
+ load_config(opts)
+ rspamd_url.init(rspamd_config:get_tld_path())
+
+ local function read_file(file)
+ local f = assert(io.open(file, "rb"))
+ local content = f:read("*all")
+ f:close()
+ return content
+ end
+
+ local text_footer, html_footer
+
+ if opts['text_footer'] then
+ text_footer = read_file(opts['text_footer'])
+ end
+
+ if opts['html_footer'] then
+ html_footer = read_file(opts['html_footer'])
+ end
+
+ for _, fname in ipairs(opts.file) do
+ local task = load_task(opts, fname)
+ local newline_s = newline(task)
+ local seen_cte
+
+ local rewrite = lua_mime.add_text_footer(task, html_footer, text_footer) or {}
+ local out = {} -- Start with headers
+
+ local function process_headers_cb(name, hdr)
+ for _, h in ipairs(opts['remove_header']) do
+ if name:match(h) then
+ return
+ end
+ end
+
+ for _, h in ipairs(opts['rewrite_header']) do
+ local hname, hpattern = h:match('^([^=]+)=(.+)$')
+ if hname == name then
+ local new_value = string.format(hpattern, hdr.decoded)
+ new_value = string.format('%s:%s%s',
+ name, hdr.separator,
+ rspamd_util.fold_header(name,
+ rspamd_util.mime_header_encode(new_value),
+ task:get_newlines_type()))
+ out[#out + 1] = new_value
+ return
+ end
+ end
+
+ if rewrite.need_rewrite_ct then
+ if name:lower() == 'content-type' then
+ local nct = string.format('%s: %s/%s; charset=utf-8',
+ 'Content-Type', rewrite.new_ct.type, rewrite.new_ct.subtype)
+ out[#out + 1] = nct
+ return
+ elseif name:lower() == 'content-transfer-encoding' then
+ out[#out + 1] = string.format('%s: %s',
+ 'Content-Transfer-Encoding', rewrite.new_cte or 'quoted-printable')
+ seen_cte = true
+ return
+ end
+ end
+
+ out[#out + 1] = hdr.raw:gsub('\r?\n?$', '')
+ end
+
+ task:headers_foreach(process_headers_cb, { full = true })
+
+ for _, h in ipairs(opts['add_header']) do
+ local hname, hvalue = h:match('^([^=]+)=(.+)$')
+
+ if hname and hvalue then
+ out[#out + 1] = string.format('%s: %s', hname,
+ rspamd_util.fold_header(hname, hvalue, task:get_newlines_type()))
+ end
+ end
+
+ if not seen_cte and rewrite.need_rewrite_ct then
+ out[#out + 1] = string.format('%s: %s',
+ 'Content-Transfer-Encoding', rewrite.new_cte or 'quoted-printable')
+ end
+
+ -- End of headers
+ out[#out + 1] = ''
+
+ if rewrite.out then
+ for _, o in ipairs(rewrite.out) do
+ out[#out + 1] = o
+ end
+ else
+ out[#out + 1] = { task:get_rawbody(), false }
+ end
+
+ for _, o in ipairs(out) do
+ if type(o) == 'string' then
+ io.write(o)
+ io.write(newline_s)
+ elseif type(o) == 'table' then
+ io.flush()
+ if type(o[1]) == 'string' then
+ io.write(o[1])
+ else
+ o[1]:save_in_file(1)
+ end
+
+ if o[2] then
+ io.write(newline_s)
+ end
+ else
+ o:save_in_file(1)
+ io.write(newline_s)
+ end
+ end
+
+ task:destroy() -- No automatic dtor
+ end
+end
+
+local function sign_handler(opts)
+ load_config(opts)
+ rspamd_url.init(rspamd_config:get_tld_path())
+
+ local lua_dkim = require("lua_ffi").dkim
+
+ if not lua_dkim then
+ io.stderr:write('FFI support is required: please use LuaJIT or install lua-ffi')
+ os.exit(1)
+ end
+
+ local sign_key
+ if rspamd_util.file_exists(opts.key) then
+ sign_key = lua_dkim.load_sign_key(opts.key, 'file')
+ else
+ sign_key = lua_dkim.load_sign_key(opts.key, 'base64')
+ end
+
+ if not sign_key then
+ io.stderr:write('Cannot load key: ' .. opts.key .. '\n')
+ os.exit(1)
+ end
+
+ for _, fname in ipairs(opts.file) do
+ local task = load_task(opts, fname)
+ local ctx = lua_dkim.create_sign_context(task, sign_key, nil, opts.algorithm)
+
+ if not ctx then
+ io.stderr:write('Cannot init signing\n')
+ os.exit(1)
+ end
+
+ local sig = lua_dkim.do_sign(task, ctx, opts.selector, opts.domain)
+
+ if not sig then
+ io.stderr:write('Cannot create signature\n')
+ os.exit(1)
+ end
+
+ if opts.output == 'signature' then
+ io.write(sig)
+ io.write('\n')
+ io.flush()
+ else
+ local dkim_hdr = string.format('%s: %s%s',
+ 'DKIM-Signature',
+ rspamd_util.fold_header('DKIM-Signature',
+ rspamd_util.mime_header_encode(sig),
+ task:get_newlines_type()),
+ newline(task))
+ io.write(dkim_hdr)
+ io.flush()
+ task:get_content():save_in_file(1)
+ end
+
+ task:destroy() -- No automatic dtor
+ end
+end
+
+-- Strips directories and .extensions (if present) from a filepath
+local function filename_only(filepath)
+ local filename = filepath:match(".*%/([^%.]+)")
+ if not filename then
+ filename = filepath:match("([^%.]+)")
+ end
+ return filename
+end
+
+assert(filename_only("very_simple") == "very_simple")
+assert(filename_only("/home/very_simple.eml") == "very_simple")
+assert(filename_only("very_simple.eml") == "very_simple")
+assert(filename_only("very_simple.example.eml") == "very_simple")
+assert(filename_only("/home/very_simple") == "very_simple")
+assert(filename_only("home/very_simple") == "very_simple")
+assert(filename_only("./home/very_simple") == "very_simple")
+assert(filename_only("../home/very_simple.eml") == "very_simple")
+assert(filename_only("/home/dir.with.dots/very_simple.eml") == "very_simple")
+
+--Write the dump content to file or standard out
+local function write_dump_content(dump_content, fname, extension, outdir)
+ if type(dump_content) == "string" then
+ dump_content = rspamd_text.fromstring(dump_content)
+ end
+
+ local wrote_filepath = nil
+ if outdir then
+ if outdir:sub(-1) ~= "/" then
+ outdir = outdir .. "/"
+ end
+
+ local outpath = string.format("%s%s.%s", outdir, filename_only(fname), extension)
+ if rspamd_util.file_exists(outpath) then
+ os.remove(outpath)
+ end
+ if dump_content:save_in_file(outpath) then
+ wrote_filepath = outpath
+ io.write(wrote_filepath .. "\n")
+ else
+ io.stderr:write(string.format("Unable to save dump content to file: %s\n", outpath))
+ end
+ else
+ dump_content:save_in_file(1)
+ end
+ return wrote_filepath
+end
+
+-- Get the formatted ucl (split or unsplit) or the raw task content
+local function get_dump_content(task, opts, fname)
+ if opts.ucl or opts.json or opts.messagepack then
+ local ucl_object = lua_mime.message_to_ucl(task)
+ -- Split out the content field into separate raws and update the ucl
+ if opts.split then
+ for i, part in ipairs(ucl_object.parts) do
+ if part.content then
+ local part_filename = string.format("%s-part%d", filename_only(fname), i)
+ local part_path = write_dump_content(part.content, part_filename, "raw", opts.outdir)
+ if part_path then
+ part.content = ucl.null
+ part.content_path = part_path
+ end
+ end
+ end
+ end
+ local extension = output_fmt(opts)
+ return ucl.to_format(ucl_object, extension), extension
+ end
+ return task:get_content(), "mime"
+end
+
+local function dump_handler(opts)
+ load_config(opts)
+ rspamd_url.init(rspamd_config:get_tld_path())
+
+ for _, fname in ipairs(opts.file) do
+ local task = load_task(opts, fname)
+ local data, extension = get_dump_content(task, opts, fname)
+ write_dump_content(data, fname, extension, opts.outdir)
+
+ task:destroy() -- No automatic dtor
+ end
+end
+
+local function handler(args)
+ local opts = parser:parse(args)
+
+ local command = opts.command
+
+ if type(opts.file) == 'string' then
+ opts.file = { opts.file }
+ elseif type(opts.file) == 'none' then
+ opts.file = {}
+ end
+
+ if command == 'extract' then
+ extract_handler(opts)
+ elseif command == 'stat' then
+ stat_handler(opts)
+ elseif command == 'urls' then
+ urls_handler(opts)
+ elseif command == 'modify' then
+ modify_handler(opts)
+ elseif command == 'sign' then
+ sign_handler(opts)
+ elseif command == 'dump' then
+ dump_handler(opts)
+ else
+ parser:error('command %s is not implemented', command)
+ end
+end
+
+return {
+ name = 'mime',
+ aliases = { 'mime_tool' },
+ handler = handler,
+ description = parser._description
+}
diff --git a/lualib/rspamadm/neural_test.lua b/lualib/rspamadm/neural_test.lua
new file mode 100644
index 0000000..31d21a9
--- /dev/null
+++ b/lualib/rspamadm/neural_test.lua
@@ -0,0 +1,228 @@
+local rspamd_logger = require "rspamd_logger"
+local argparse = require "argparse"
+local lua_util = require "lua_util"
+local ucl = require "ucl"
+
+local parser = argparse()
+ :name "rspamadm neural_test"
+ :description "Test the neural network with labelled dataset"
+ :help_description_margin(32)
+
+parser:option "-c --config"
+ :description "Path to config file"
+ :argname("<cfg>")
+ :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+parser:option "-H --hamdir"
+ :description("Ham directory")
+ :argname("<dir>")
+parser:option "-S --spamdir"
+ :description("Spam directory")
+ :argname("<dir>")
+parser:option "-t --timeout"
+ :description("Timeout for client connections")
+ :argname("<sec>")
+ :convert(tonumber)
+ :default(60)
+parser:option "-n --conns"
+ :description("Number of parallel connections")
+ :argname("<N>")
+ :convert(tonumber)
+ :default(10)
+parser:option "-c --connect"
+ :description("Connect to specific host")
+ :argname("<host>")
+ :default('localhost:11334')
+parser:option "-r --rspamc"
+ :description("Use specific rspamc path")
+ :argname("<path>")
+ :default('rspamc')
+parser:option '--rule'
+ :description 'Rule to test'
+ :argname('<rule>')
+
+local HAM = "HAM"
+local SPAM = "SPAM"
+
+local function load_config(opts)
+ local _r, err = rspamd_config:load_ucl(opts['config'])
+
+ if not _r then
+ rspamd_logger.errx('cannot parse %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
+ _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+ if not _r then
+ rspamd_logger.errx('cannot process %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+end
+
+local function scan_email(rspamc_path, host, n_parallel, path, timeout)
+
+ local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s",
+ rspamc_path, host, n_parallel, timeout, path)
+ local result = assert(io.popen(rspamc_command))
+ result = result:read("*all")
+ return result
+end
+
+local function encoded_json_to_log(result)
+ -- Returns table containing score, action, list of symbols
+
+ local filtered_result = {}
+ local ucl_parser = ucl.parser()
+
+ local is_good, err = ucl_parser:parse_string(result)
+
+ if not is_good then
+ rspamd_logger.errx("Parser error: %1", err)
+ return nil
+ end
+
+ result = ucl_parser:get_object()
+
+ filtered_result.score = result.score
+ if not result.action then
+ rspamd_logger.errx("Bad JSON: %1", result)
+ return nil
+ end
+ local action = result.action:gsub("%s+", "_")
+ filtered_result.action = action
+
+ filtered_result.symbols = {}
+
+ for sym, _ in pairs(result.symbols) do
+ table.insert(filtered_result.symbols, sym)
+ end
+
+ filtered_result.filename = result.filename
+ filtered_result.scan_time = result.scan_time
+
+ return filtered_result
+end
+
+local function filter_scan_results(results, actual_email_type)
+
+ local logs = {}
+
+ results = lua_util.rspamd_str_split(results, "\n")
+
+ if results[#results] == "" then
+ results[#results] = nil
+ end
+
+ for _, result in pairs(results) do
+ result = encoded_json_to_log(result)
+ if result then
+ result['type'] = actual_email_type
+ table.insert(logs, result)
+ end
+ end
+
+ return logs
+end
+
+local function get_stats_from_scan_results(results, rules)
+
+ local rule_stats = {}
+ for rule, _ in pairs(rules) do
+ rule_stats[rule] = { tp = 0, tn = 0, fp = 0, fn = 0 }
+ end
+
+ for _, result in ipairs(results) do
+ for _, symbol in ipairs(result["symbols"]) do
+ for name, rule in pairs(rules) do
+ if rule.symbol_spam and rule.symbol_spam == symbol then
+ if result.type == HAM then
+ rule_stats[name].fp = rule_stats[name].fp + 1
+ elseif result.type == SPAM then
+ rule_stats[name].tp = rule_stats[name].tp + 1
+ end
+ elseif rule.symbol_ham and rule.symbol_ham == symbol then
+ if result.type == HAM then
+ rule_stats[name].tn = rule_stats[name].tn + 1
+ elseif result.type == SPAM then
+ rule_stats[name].fn = rule_stats[name].fn + 1
+ end
+ end
+ end
+ end
+ end
+
+ for rule, _ in pairs(rules) do
+ rule_stats[rule].fpr = rule_stats[rule].fp / (rule_stats[rule].fp + rule_stats[rule].tn)
+ rule_stats[rule].fnr = rule_stats[rule].fn / (rule_stats[rule].fn + rule_stats[rule].tp)
+ end
+
+ return rule_stats
+end
+
+local function print_neural_stats(neural_stats)
+ for rule, stats in pairs(neural_stats) do
+ rspamd_logger.messagex("\nStats for rule: %s", rule)
+ rspamd_logger.messagex("False positive rate: %s%%", stats.fpr * 100)
+ rspamd_logger.messagex("False negative rate: %s%%", stats.fnr * 100)
+ end
+end
+
+local function handler(args)
+ local opts = parser:parse(args)
+
+ local ham_directory = opts['hamdir']
+ local spam_directory = opts['spamdir']
+ local connections = opts["conns"]
+
+ load_config(opts)
+
+ local neural_opts = rspamd_config:get_all_opt('neural')
+
+ if opts["rule"] then
+ local found = false
+ for rule_name, _ in pairs(neural_opts.rules) do
+ if string.lower(rule_name) == string.lower(opts["rule"]) then
+ found = true
+ else
+ neural_opts.rules[rule_name] = nil
+ end
+ end
+
+ if not found then
+ rspamd_logger.errx("Couldn't find the rule %s", opts["rule"])
+ return
+ end
+ end
+
+ local results = {}
+
+ if ham_directory then
+ rspamd_logger.messagex("Scanning ham corpus...")
+ local ham_results = scan_email(opts.rspamc, opts.connect, connections, ham_directory, opts.timeout)
+ ham_results = filter_scan_results(ham_results, HAM)
+
+ for _, result in pairs(ham_results) do
+ table.insert(results, result)
+ end
+ end
+
+ if spam_directory then
+ rspamd_logger.messagex("Scanning spam corpus...")
+ local spam_results = scan_email(opts.rspamc, opts.connect, connections, spam_directory, opts.timeout)
+ spam_results = filter_scan_results(spam_results, SPAM)
+
+ for _, result in pairs(spam_results) do
+ table.insert(results, result)
+ end
+ end
+
+ local neural_stats = get_stats_from_scan_results(results, neural_opts.rules)
+ print_neural_stats(neural_stats)
+
+end
+
+return {
+ name = "neuraltest",
+ aliases = { "neural_test" },
+ handler = handler,
+ description = parser._description
+} \ No newline at end of file
diff --git a/lualib/rspamadm/publicsuffix.lua b/lualib/rspamadm/publicsuffix.lua
new file mode 100644
index 0000000..96bf069
--- /dev/null
+++ b/lualib/rspamadm/publicsuffix.lua
@@ -0,0 +1,82 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+local rspamd_logger = require "rspamd_logger"
+
+-- Define command line options
+local parser = argparse()
+ :name 'rspamadm publicsuffix'
+ :description 'Do manipulations with the publicsuffix list'
+ :help_description_margin(30)
+ :command_target('command')
+ :require_command(true)
+
+parser:option '-c --config'
+ :description 'Path to config file'
+ :argname('config_file')
+ :default(rspamd_paths['CONFDIR'] .. '/rspamd.conf')
+
+parser:command 'compile'
+ :description 'Compile publicsuffix list if needed'
+
+local function load_config(config_file)
+ local _r, err = rspamd_config:load_ucl(config_file)
+
+ if not _r then
+ rspamd_logger.errx('cannot load %s: %s', config_file, err)
+ os.exit(1)
+ end
+
+ _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+ if not _r then
+ rspamd_logger.errx('cannot process %s: %s', config_file, err)
+ os.exit(1)
+ end
+end
+
+local function compile_handler(_)
+ local rspamd_url = require "rspamd_url"
+ local tld_file = rspamd_config:get_tld_path()
+
+ if not tld_file then
+ rspamd_logger.errx('missing `url_tld` option, cannot continue')
+ os.exit(1)
+ end
+
+ rspamd_logger.messagex('loading public suffix file from %s', tld_file)
+ rspamd_url.init(tld_file)
+ rspamd_logger.messagex('public suffix file has been loaded')
+end
+
+local function handler(args)
+ local cmd_opts = parser:parse(args)
+
+ load_config(cmd_opts.config_file)
+
+ if cmd_opts.command == 'compile' then
+ compile_handler(cmd_opts)
+ else
+ rspamd_logger.errx('unknown command: %s', cmd_opts.command)
+ os.exit(1)
+ end
+end
+
+return {
+ handler = handler,
+ description = parser._description,
+ name = 'publicsuffix'
+} \ No newline at end of file
diff --git a/lualib/rspamadm/stat_convert.lua b/lualib/rspamadm/stat_convert.lua
new file mode 100644
index 0000000..62a19a2
--- /dev/null
+++ b/lualib/rspamadm/stat_convert.lua
@@ -0,0 +1,38 @@
+local lua_redis = require "lua_redis"
+local stat_tools = require "lua_stat"
+local ucl = require "ucl"
+local logger = require "rspamd_logger"
+local lua_util = require "lua_util"
+
+return function(_, res)
+ local redis_params = lua_redis.try_load_redis_servers(res.redis, nil)
+ if res.expire then
+ res.expire = lua_util.parse_time_interval(res.expire)
+ end
+ if not redis_params then
+ logger.errx('cannot load redis server definition')
+
+ return false
+ end
+
+ local sqlite_params = stat_tools.load_sqlite_config(res)
+
+ if #sqlite_params == 0 then
+ logger.errx('cannot load sqlite classifiers')
+ return false
+ end
+
+ for _, cls in ipairs(sqlite_params) do
+ if not stat_tools.convert_sqlite_to_redis(redis_params, cls.db_spam,
+ cls.db_ham, cls.symbol_spam, cls.symbol_ham, cls.learn_cache, res.expire,
+ res.reset_previous) then
+ logger.errx('conversion failed')
+
+ return false
+ end
+ logger.messagex('Converted classifier to the from sqlite to redis')
+ logger.messagex('Suggested configuration:')
+ logger.messagex(ucl.to_format(stat_tools.redis_classifier_from_sqlite(cls, res.expire),
+ 'config'))
+ end
+end
diff --git a/lualib/rspamadm/statistics_dump.lua b/lualib/rspamadm/statistics_dump.lua
new file mode 100644
index 0000000..6bc0458
--- /dev/null
+++ b/lualib/rspamadm/statistics_dump.lua
@@ -0,0 +1,544 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local lua_redis = require "lua_redis"
+local rspamd_logger = require "rspamd_logger"
+local argparse = require "argparse"
+local rspamd_zstd = require "rspamd_zstd"
+local rspamd_text = require "rspamd_text"
+local rspamd_util = require "rspamd_util"
+local rspamd_cdb = require "rspamd_cdb"
+local lua_util = require "lua_util"
+local rspamd_i64 = require "rspamd_int64"
+local ucl = require "ucl"
+
+local N = "statistics_dump"
+local E = {}
+local classifiers = {}
+
+-- Define command line options
+local parser = argparse()
+ :name "rspamadm statistics_dump"
+ :description "Dump/restore Rspamd statistics"
+ :help_description_margin(30)
+ :command_target("command")
+ :require_command(false)
+
+parser:option "-c --config"
+ :description "Path to config file"
+ :argname("<cfg>")
+ :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf")
+
+-- Extract subcommand
+local dump = parser:command "dump d"
+ :description "Dump bayes statistics"
+dump:mutex(
+ dump:flag "-j --json"
+ :description "Json output",
+ dump:flag "-C --cdb"
+ :description "CDB output"
+)
+dump:flag "-c --compress"
+ :description "Compress output"
+dump:option "-b --batch-size"
+ :description "Number of entires to process at once"
+ :argname("<elts>")
+ :convert(tonumber)
+ :default(1000)
+
+
+-- Restore
+local restore = parser:command "restore r"
+ :description "Restore bayes statistics"
+restore:argument "file"
+ :description "Input file to process"
+ :argname "<file>"
+ :args "*"
+restore:option "-b --batch-size"
+ :description "Number of entires to process at once"
+ :argname("<elts>")
+ :convert(tonumber)
+ :default(1000)
+restore:option "-m --mode"
+ :description "Number of entires to process at once"
+ :argname("<append|subtract|replace>")
+ :convert {
+ ['append'] = 'append',
+ ['subtract'] = 'subtract',
+ ['replace'] = 'replace',
+}
+ :default 'append'
+restore:flag "-n --no-operation"
+ :description "Only show redis commands to be issued"
+
+local function load_config(opts)
+ local _r, err = rspamd_config:load_ucl(opts['config'])
+
+ if not _r then
+ rspamd_logger.errx('cannot parse %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+
+ _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' })
+ if not _r then
+ rspamd_logger.errx('cannot process %s: %s', opts['config'], err)
+ os.exit(1)
+ end
+end
+
+local function check_redis_classifier(cls, cfg)
+ -- Skip old classifiers
+ if cls.new_schema then
+ local symbol_spam, symbol_ham
+ -- Load symbols from statfiles
+
+ local function check_statfile_table(tbl, def_sym)
+ local symbol = tbl.symbol or def_sym
+
+ local spam
+ if tbl.spam then
+ spam = tbl.spam
+ else
+ if string.match(symbol:upper(), 'SPAM') then
+ spam = true
+ else
+ spam = false
+ end
+ end
+
+ if spam then
+ symbol_spam = symbol
+ else
+ symbol_ham = symbol
+ end
+ end
+
+ local statfiles = cls.statfile
+ if statfiles[1] then
+ for _, stf in ipairs(statfiles) do
+ if not stf.symbol then
+ for k, v in pairs(stf) do
+ check_statfile_table(v, k)
+ end
+ else
+ check_statfile_table(stf, 'undefined')
+ end
+ end
+ else
+ for stn, stf in pairs(statfiles) do
+ check_statfile_table(stf, stn)
+ end
+ end
+
+ local redis_params
+ redis_params = lua_redis.try_load_redis_servers(cls,
+ rspamd_config, false, 'bayes')
+ if not redis_params then
+ redis_params = lua_redis.try_load_redis_servers(cfg[N] or E,
+ rspamd_config, false, 'bayes')
+ if not redis_params then
+ redis_params = lua_redis.try_load_redis_servers(cfg[N] or E,
+ rspamd_config, true)
+ if not redis_params then
+ return false
+ end
+ end
+ end
+
+ table.insert(classifiers, {
+ symbol_spam = symbol_spam,
+ symbol_ham = symbol_ham,
+ redis_params = redis_params,
+ })
+ end
+end
+
+local function redis_map_zip(ar)
+ local data = {}
+ for j = 1, #ar, 2 do
+ data[ar[j]] = ar[j + 1]
+ end
+
+ return data
+end
+
+-- Used to clear tables
+local clear_fcn = table.clear or function(tbl)
+ local keys = lua_util.keys(tbl)
+ for _, k in ipairs(keys) do
+ tbl[k] = nil
+ end
+end
+
+local compress_ctx
+
+local function dump_out(out, opts, last)
+ if opts.compress and not compress_ctx then
+ compress_ctx = rspamd_zstd.compress_ctx()
+ end
+
+ if compress_ctx then
+ if last then
+ compress_ctx:stream(rspamd_text.fromtable(out), 'end'):write()
+ else
+ compress_ctx:stream(rspamd_text.fromtable(out), 'flush'):write()
+ end
+ else
+ for _, o in ipairs(out) do
+ io.write(o)
+ end
+ end
+end
+
+local function dump_cdb(out, opts, last, pattern)
+ local results = out[pattern]
+
+ if not out.cdb_builder then
+ -- First invocation
+ out.cdb_builder = rspamd_cdb.build(string.format('%s.cdb', pattern))
+ out.cdb_builder:add('_lrnspam', rspamd_i64.fromstring(results.learns_spam or '0'))
+ out.cdb_builder:add('_lrnham_', rspamd_i64.fromstring(results.learns_ham or '0'))
+ end
+
+ for _, o in ipairs(results.elts) do
+ out.cdb_builder:add(o.key, o.value)
+ end
+
+ if last then
+ out.cdb_builder:finalize()
+ out.cdb_builder = nil
+ end
+end
+
+local function dump_pattern(conn, pattern, opts, out, key)
+ local cursor = 0
+
+ repeat
+ conn:add_cmd('SCAN', { tostring(cursor),
+ 'MATCH', pattern,
+ 'COUNT', tostring(opts.batch_size) })
+ local ret, results = conn:exec()
+
+ if not ret then
+ rspamd_logger.errx("cannot connect execute scan command: %s", results)
+ os.exit(1)
+ end
+
+ cursor = tonumber(results[1])
+
+ local elts = results[2]
+ local tokens = {}
+
+ for _, e in ipairs(elts) do
+ conn:add_cmd('HGETALL', { e })
+ end
+ -- This function returns many results, each for each command
+ -- So if we have batch 1000, then we would have 1000 tables in form
+ -- [result, {hash_content}]
+ local all_results = { conn:exec() }
+
+ for i = 1, #all_results, 2 do
+ local r, hash_content = all_results[i], all_results[i + 1]
+
+ if r then
+ -- List to a hash map
+ local data = redis_map_zip(hash_content)
+ tokens[#tokens + 1] = { key = elts[(i + 1) / 2], data = data }
+ end
+ end
+
+ -- Output keeping track of the commas
+ for i, d in ipairs(tokens) do
+ if cursor == 0 and i == #tokens or not opts.json then
+ if opts.cdb then
+ table.insert(out[key].elts, {
+ key = rspamd_i64.fromstring(string.match(d.key, '%d+')),
+ value = rspamd_util.pack('ff', tonumber(d.data["S"] or '0') or 0,
+ tonumber(d.data["H"] or '0'))
+ })
+ else
+ out[#out + 1] = rspamd_logger.slog('"%s": %s\n', d.key,
+ ucl.to_format(d.data, "json-compact"))
+ end
+ else
+ out[#out + 1] = rspamd_logger.slog('"%s": %s,\n', d.key,
+ ucl.to_format(d.data, "json-compact"))
+ end
+
+ end
+
+ if opts.json and cursor == 0 then
+ out[#out + 1] = '}}\n'
+ end
+
+ -- Do not write the last chunk of out as it will be processed afterwards
+ if cursor ~= 0 then
+ if opts.cdb then
+ dump_out(out, opts, false)
+ clear_fcn(out)
+ else
+ dump_cdb(out, opts, false, key)
+ out[key].elts = {}
+ end
+ elseif opts.cdb then
+ dump_cdb(out, opts, true, key)
+ end
+
+ until cursor == 0
+end
+
+local function dump_handler(opts)
+ local patterns_seen = {}
+ for _, cls in ipairs(classifiers) do
+ local res, conn = lua_redis.redis_connect_sync(cls.redis_params, false)
+
+ if not res then
+ rspamd_logger.errx("cannot connect to redis server: %s", cls.redis_params)
+ os.exit(1)
+ end
+
+ local out = {}
+ local function check_keys(sym)
+ local sym_keys_pattern = string.format("%s_keys", sym)
+ conn:add_cmd('SMEMBERS', { sym_keys_pattern })
+ local ret, keys = conn:exec()
+
+ if not ret then
+ rspamd_logger.errx("cannot execute command to get keys: %s", keys)
+ os.exit(1)
+ end
+
+ if not opts.json then
+ out[#out + 1] = string.format('"%s": %s\n', sym_keys_pattern,
+ ucl.to_format(keys, 'json-compact'))
+ end
+ for _, k in ipairs(keys) do
+ local pat = string.format('%s*_*', k)
+ if not patterns_seen[pat] then
+ conn:add_cmd('HGETALL', { k })
+ local _ret, additional_keys = conn:exec()
+
+ if _ret then
+ if opts.json then
+ out[#out + 1] = string.format('{"pattern": "%s", "meta": %s, "elts": {\n',
+ k, ucl.to_format(redis_map_zip(additional_keys), 'json-compact'))
+ elseif opts.cdb then
+ out[k] = redis_map_zip(additional_keys)
+ out[k].elts = {}
+ else
+ out[#out + 1] = string.format('"%s": %s\n', k,
+ ucl.to_format(redis_map_zip(additional_keys), 'json-compact'))
+ end
+ dump_pattern(conn, pat, opts, out, k)
+ patterns_seen[pat] = true
+ end
+ end
+ end
+ end
+
+ check_keys(cls.symbol_spam)
+ check_keys(cls.symbol_ham)
+
+ if #out > 0 then
+ dump_out(out, opts, true)
+ end
+ end
+end
+
+local function obj_to_redis_arguments(obj, opts, cmd_pipe)
+ local key, value = next(obj)
+
+ if type(key) == 'string' then
+ if type(value) == 'table' then
+ if not value[1] then
+ if opts.mode == 'replace' then
+ local cmd = 'HMSET'
+ local params = { key }
+ for k, v in pairs(value) do
+ table.insert(params, k)
+ table.insert(params, v)
+ end
+ table.insert(cmd_pipe, { cmd, params })
+ else
+ local cmd = 'HINCRBYFLOAT'
+ local mult = 1.0
+ if opts.mode == 'subtract' then
+ mult = (-mult)
+ end
+
+ for k, v in pairs(value) do
+ if tonumber(v) then
+ v = tonumber(v)
+ table.insert(cmd_pipe, { cmd, { key, k, tostring(v * mult) } })
+ else
+ table.insert(cmd_pipe, { 'HSET', { key, k, v } })
+ end
+ end
+ end
+ else
+ -- Numeric table of elements (e.g. _keys) - it is actually a set in Redis
+ for _, elt in ipairs(value) do
+ table.insert(cmd_pipe, { 'SADD', { key, elt } })
+ end
+ end
+ end
+ end
+
+ return cmd_pipe
+end
+
+local function execute_batch(batch, conns, opts)
+ local cmd_pipe = {}
+
+ for _, cmd in ipairs(batch) do
+ obj_to_redis_arguments(cmd, opts, cmd_pipe)
+ end
+
+ if opts.no_operation then
+ for _, cmd in ipairs(cmd_pipe) do
+ rspamd_logger.messagex('%s %s', cmd[1], table.concat(cmd[2], ' '))
+ end
+ else
+ for _, conn in ipairs(conns) do
+ for _, cmd in ipairs(cmd_pipe) do
+ local is_ok, err = conn:add_cmd(cmd[1], cmd[2])
+
+ if not is_ok then
+ rspamd_logger.errx("cannot add command: %s with args: %s: %s", cmd[1], cmd[2], err)
+ end
+ end
+
+ conn:exec()
+ end
+ end
+end
+
+local function restore_handler(opts)
+ local files = opts.file or { '-' }
+ local conns = {}
+
+ for _, cls in ipairs(classifiers) do
+ local res, conn = lua_redis.redis_connect_sync(cls.redis_params, true)
+
+ if not res then
+ rspamd_logger.errx("cannot connect to redis server: %s", cls.redis_params)
+ os.exit(1)
+ end
+
+ table.insert(conns, conn)
+ end
+
+ local batch = {}
+
+ for _, f in ipairs(files) do
+ local fd
+ if f ~= '-' then
+ fd = io.open(f, 'r')
+ io.input(fd)
+ end
+
+ local cur_line = 1
+ for line in io.lines() do
+ local ucl_parser = ucl.parser()
+ local res, err
+ res, err = ucl_parser:parse_string(line)
+
+ if not res then
+ rspamd_logger.errx("%s: cannot read line %s: %s", f, cur_line, err)
+ os.exit(1)
+ end
+
+ table.insert(batch, ucl_parser:get_object())
+ cur_line = cur_line + 1
+
+ if cur_line % opts.batch_size == 0 then
+ execute_batch(batch, conns, opts)
+ batch = {}
+ end
+ end
+
+ if fd then
+ fd:close()
+ end
+ end
+
+ if #batch > 0 then
+ execute_batch(batch, conns, opts)
+ end
+end
+
+local function handler(args)
+ local opts = parser:parse(args)
+
+ local command = opts.command or 'dump'
+
+ load_config(opts)
+ rspamd_config:init_subsystem('stat')
+
+ local obj = rspamd_config:get_ucl()
+
+ local classifier = obj.classifier
+
+ if classifier then
+ if classifier[1] then
+ for _, cls in ipairs(classifier) do
+ if cls.bayes then
+ cls = cls.bayes
+ end
+ if cls.backend and cls.backend == 'redis' then
+ check_redis_classifier(cls, obj)
+ end
+ end
+ else
+ if classifier.bayes then
+
+ classifier = classifier.bayes
+ if classifier[1] then
+ for _, cls in ipairs(classifier) do
+ if cls.backend and cls.backend == 'redis' then
+ check_redis_classifier(cls, obj)
+ end
+ end
+ else
+ if classifier.backend and classifier.backend == 'redis' then
+ check_redis_classifier(classifier, obj)
+ end
+ end
+ end
+ end
+ end
+
+ if type(opts.file) == 'string' then
+ opts.file = { opts.file }
+ elseif type(opts.file) == 'none' then
+ opts.file = {}
+ end
+
+ if command == 'dump' then
+ dump_handler(opts)
+ elseif command == 'restore' then
+ restore_handler(opts)
+ else
+ parser:error('command %s is not implemented', command)
+ end
+end
+
+return {
+ name = 'statistics_dump',
+ aliases = { 'stat_dump', 'bayes_dump' },
+ handler = handler,
+ description = parser._description
+} \ No newline at end of file
diff --git a/lualib/rspamadm/template.lua b/lualib/rspamadm/template.lua
new file mode 100644
index 0000000..ca1779a
--- /dev/null
+++ b/lualib/rspamadm/template.lua
@@ -0,0 +1,131 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+
+
+-- Define command line options
+local parser = argparse()
+ :name "rspamadm template"
+ :description "Apply jinja templates for strings/files"
+ :help_description_margin(30)
+parser:argument "file"
+ :description "File to process"
+ :argname "<file>"
+ :args "*"
+
+parser:flag "-n --no-vars"
+ :description "Don't add Rspamd internal variables"
+parser:option "-e --env"
+ :description "Load additional environment vars from specific file (name=value)"
+ :argname "<filename>"
+ :count "*"
+parser:option "-l --lua-env"
+ :description "Load additional environment vars from specific file (lua source)"
+ :argname "<filename>"
+ :count "*"
+parser:mutex(
+ parser:option "-s --suffix"
+ :description "Store files with the new suffix"
+ :argname "<suffix>",
+ parser:flag "-i --inplace"
+ :description "Replace input file(s)"
+)
+
+local lua_util = require "lua_util"
+
+local function set_env(opts, env)
+ if opts.env then
+ for _, fname in ipairs(opts.env) do
+ for kv in assert(io.open(fname)):lines() do
+ if not kv:match('%s*#.*') then
+ local k, v = kv:match('([^=%s]+)%s*=%s*(.+)')
+
+ if k and v then
+ env[k] = v
+ else
+ io.write(string.format('invalid env line in %s: %s\n', fname, kv))
+ end
+ end
+ end
+ end
+ end
+
+ if opts.lua_env then
+ for _, fname in ipairs(opts.env) do
+ local ret, res_or_err = pcall(loadfile(fname))
+
+ if not ret then
+ io.write(string.format('cannot load %s: %s\n', fname, res_or_err))
+ else
+ if type(res_or_err) == 'table' then
+ for k, v in pairs(res_or_err) do
+ env[k] = lua_util.deepcopy(v)
+ end
+ else
+ io.write(string.format('cannot load %s: not a table\n', fname))
+ end
+ end
+ end
+ end
+end
+
+local function read_file(file)
+ local content
+ if file == '-' then
+ content = io.read("*all")
+ else
+ local f = assert(io.open(file, "rb"))
+ content = f:read("*all")
+ f:close()
+ end
+ return content
+end
+
+local function handler(args)
+ local opts = parser:parse(args)
+ local env = {}
+ set_env(opts, env)
+
+ if not opts.file or #opts.file == 0 then
+ opts.file = { '-' }
+ end
+ for _, fname in ipairs(opts.file) do
+ local content = read_file(fname)
+ local res = lua_util.jinja_template(content, env, opts.no_vars)
+
+ if opts.inplace then
+ local nfile = string.format('%s.new', fname)
+ local out = assert(io.open(nfile, 'w'))
+ out:write(content)
+ out:close()
+ os.rename(nfile, fname)
+ elseif opts.suffix then
+ local nfile = string.format('%s.%s', opts.suffix)
+ local out = assert(io.open(nfile, 'w'))
+ out:write(content)
+ out:close()
+ else
+ io.write(res)
+ end
+ end
+end
+
+return {
+ handler = handler,
+ description = parser._description,
+ name = 'template'
+} \ No newline at end of file
diff --git a/lualib/rspamadm/vault.lua b/lualib/rspamadm/vault.lua
new file mode 100644
index 0000000..840e504
--- /dev/null
+++ b/lualib/rspamadm/vault.lua
@@ -0,0 +1,579 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+
+local rspamd_logger = require "rspamd_logger"
+local ansicolors = require "ansicolors"
+local ucl = require "ucl"
+local argparse = require "argparse"
+local fun = require "fun"
+local rspamd_http = require "rspamd_http"
+local cr = require "rspamd_cryptobox"
+
+local parser = argparse()
+ :name "rspamadm vault"
+ :description "Perform Hashicorp Vault management"
+ :help_description_margin(32)
+ :command_target("command")
+ :require_command(true)
+
+parser:flag "-s --silent"
+ :description "Do not output extra information"
+parser:option "-a --addr"
+ :description "Vault address (if not defined in VAULT_ADDR env)"
+parser:option "-t --token"
+ :description "Vault token (not recommended, better define VAULT_TOKEN env)"
+parser:option "-p --path"
+ :description "Path to work with in the vault"
+ :default "dkim"
+parser:option "-o --output"
+ :description "Output format ('ucl', 'json', 'json-compact', 'yaml')"
+ :argname("<type>")
+ :convert {
+ ucl = "ucl",
+ json = "json",
+ ['json-compact'] = "json-compact",
+ yaml = "yaml",
+}
+ :default "ucl"
+
+parser:command "list ls l"
+ :description "List elements in the vault"
+
+local show = parser:command "show get"
+ :description "Extract element from the vault"
+show:argument "domain"
+ :description "Domain to create key for"
+ :args "+"
+
+local delete = parser:command "delete del rm remove"
+ :description "Delete element from the vault"
+delete:argument "domain"
+ :description "Domain to create delete key(s) for"
+ :args "+"
+
+local newkey = parser:command "newkey new create"
+ :description "Add new key to the vault"
+newkey:argument "domain"
+ :description "Domain to create key for"
+ :args "+"
+newkey:option "-s --selector"
+ :description "Selector to use"
+ :count "?"
+newkey:option "-A --algorithm"
+ :argname("<type>")
+ :convert {
+ rsa = "rsa",
+ ed25519 = "ed25519",
+ eddsa = "ed25519",
+}
+ :default "rsa"
+newkey:option "-b --bits"
+ :argname("<nbits>")
+ :convert(tonumber)
+ :default "1024"
+newkey:option "-x --expire"
+ :argname("<days>")
+ :convert(tonumber)
+newkey:flag "-r --rewrite"
+
+local roll = parser:command "roll rollover"
+ :description "Perform keys rollover"
+roll:argument "domain"
+ :description "Domain to roll key(s) for"
+ :args "+"
+roll:option "-T --ttl"
+ :description "Validity period for old keys (days)"
+ :convert(tonumber)
+ :default "1"
+roll:flag "-r --remove-expired"
+ :description "Remove expired keys"
+roll:option "-x --expire"
+ :argname("<days>")
+ :convert(tonumber)
+
+local function printf(fmt, ...)
+ if fmt then
+ io.write(rspamd_logger.slog(fmt, ...))
+ end
+ io.write('\n')
+end
+
+local function maybe_printf(opts, fmt, ...)
+ if not opts.silent then
+ printf(fmt, ...)
+ end
+end
+
+local function highlight(str, color)
+ return ansicolors[color or 'white'] .. str .. ansicolors.reset
+end
+
+local function vault_url(opts, path)
+ if path then
+ return string.format('%s/v1/%s/%s', opts.addr, opts.path, path)
+ end
+
+ return string.format('%s/v1/%s', opts.addr, opts.path)
+end
+
+local function is_http_error(err, data)
+ return err or (math.floor(data.code / 100) ~= 2)
+end
+
+local function parse_vault_reply(data)
+ local p = ucl.parser()
+ local res, parser_err = p:parse_string(data)
+
+ if not res then
+ return nil, parser_err
+ else
+ return p:get_object(), nil
+ end
+end
+
+local function maybe_print_vault_data(opts, data, func)
+ if data then
+ local res, parser_err = parse_vault_reply(data)
+
+ if not res then
+ printf('vault reply for cannot be parsed: %s', parser_err)
+ else
+ if func then
+ printf(ucl.to_format(func(res), opts.output))
+ else
+ printf(ucl.to_format(res, opts.output))
+ end
+ end
+ else
+ printf('no data received')
+ end
+end
+
+local function print_dkim_txt_record(b64, selector, alg)
+ local labels = {}
+ local prefix = string.format("v=DKIM1; k=%s; p=", alg)
+ b64 = prefix .. b64
+ if #b64 < 255 then
+ labels = { '"' .. b64 .. '"' }
+ else
+ for sl = 1, #b64, 256 do
+ table.insert(labels, '"' .. b64:sub(sl, sl + 255) .. '"')
+ end
+ end
+
+ printf("%s._domainkey IN TXT ( %s )", selector,
+ table.concat(labels, "\n\t"))
+end
+
+local function show_handler(opts, domain)
+ local uri = vault_url(opts, domain)
+ local err, data = rspamd_http.request {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+ url = uri,
+ headers = {
+ ['X-Vault-Token'] = opts.token
+ }
+ }
+
+ if is_http_error(err, data) then
+ printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code)
+ maybe_print_vault_data(opts, err)
+ os.exit(1)
+ else
+ maybe_print_vault_data(opts, data.content, function(obj)
+ return obj.data.selectors
+ end)
+ end
+end
+
+local function delete_handler(opts, domain)
+ local uri = vault_url(opts, domain)
+ local err, data = rspamd_http.request {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+ url = uri,
+ method = 'delete',
+ headers = {
+ ['X-Vault-Token'] = opts.token
+ }
+ }
+
+ if is_http_error(err, data) then
+ printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code)
+ maybe_print_vault_data(opts, err)
+ os.exit(1)
+ else
+ printf('deleted key(s) for %s', domain)
+ end
+end
+
+local function list_handler(opts)
+ local uri = vault_url(opts)
+ local err, data = rspamd_http.request {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+ url = uri .. '?list=true',
+ headers = {
+ ['X-Vault-Token'] = opts.token
+ }
+ }
+
+ if is_http_error(err, data) then
+ printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code)
+ maybe_print_vault_data(opts, err)
+ os.exit(1)
+ else
+ maybe_print_vault_data(opts, data.content, function(obj)
+ return obj.data.keys
+ end)
+ end
+end
+
+-- Returns pair privkey+pubkey
+local function genkey(opts)
+ return cr.gen_dkim_keypair(opts.algorithm, opts.bits)
+end
+
+local function create_and_push_key(opts, domain, existing)
+ local uri = vault_url(opts, domain)
+ local sk, pk = genkey(opts)
+
+ local res = {
+ selectors = {
+ [1] = {
+ selector = opts.selector,
+ domain = domain,
+ key = tostring(sk),
+ pubkey = tostring(pk),
+ alg = opts.algorithm,
+ bits = opts.bits or 0,
+ valid_start = os.time(),
+ }
+ }
+ }
+
+ for _, sel in ipairs(existing) do
+ res.selectors[#res.selectors + 1] = sel
+ end
+
+ if opts.expire then
+ res.selectors[1].valid_end = os.time() + opts.expire * 3600 * 24
+ end
+
+ local err, data = rspamd_http.request {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+ url = uri,
+ method = 'put',
+ headers = {
+ ['Content-Type'] = 'application/json',
+ ['X-Vault-Token'] = opts.token
+ },
+ body = {
+ ucl.to_format(res, 'json-compact')
+ },
+ }
+
+ if is_http_error(err, data) then
+ printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code)
+ maybe_print_vault_data(opts, data.content)
+ os.exit(1)
+ else
+ maybe_printf(opts, 'stored key for: %s, selector: %s', domain, opts.selector)
+ maybe_printf(opts, 'please place the corresponding public key as following:')
+
+ if opts.silent then
+ printf('%s', pk)
+ else
+ print_dkim_txt_record(tostring(pk), opts.selector, opts.algorithm)
+ end
+ end
+end
+
+local function newkey_handler(opts, domain)
+ local uri = vault_url(opts, domain)
+
+ if not opts.selector then
+ opts.selector = string.format('%s-%s', opts.algorithm,
+ os.date("!%Y%m%d"))
+ end
+
+ local err, data = rspamd_http.request {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+ url = uri,
+ method = 'get',
+ headers = {
+ ['X-Vault-Token'] = opts.token
+ }
+ }
+
+ if is_http_error(err, data) or not data.content then
+ create_and_push_key(opts, domain, {})
+ else
+ -- Key exists
+ local rep = parse_vault_reply(data.content)
+
+ if not rep or not rep.data then
+ printf('cannot parse reply for %s: %s', uri, data.content)
+ os.exit(1)
+ end
+
+ local elts = rep.data.selectors
+
+ if not elts then
+ create_and_push_key(opts, domain, {})
+ os.exit(0)
+ end
+
+ for _, sel in ipairs(elts) do
+ if sel.alg == opts.algorithm then
+ printf('key with the specific algorithm %s is already presented at %s selector for %s domain',
+ opts.algorithm, sel.selector, domain)
+ os.exit(1)
+ else
+ create_and_push_key(opts, domain, elts)
+ end
+ end
+ end
+end
+
+local function roll_handler(opts, domain)
+ local uri = vault_url(opts, domain)
+ local res = {
+ selectors = {}
+ }
+
+ local err, data = rspamd_http.request {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+ url = uri,
+ method = 'get',
+ headers = {
+ ['X-Vault-Token'] = opts.token
+ }
+ }
+
+ if is_http_error(err, data) or not data.content then
+ printf("No keys to roll for domain %s", domain)
+ os.exit(1)
+ else
+ local rep = parse_vault_reply(data.content)
+
+ if not rep or not rep.data then
+ printf('cannot parse reply for %s: %s', uri, data.content)
+ os.exit(1)
+ end
+
+ local elts = rep.data.selectors
+
+ if not elts then
+ printf("No keys to roll for domain %s", domain)
+ os.exit(1)
+ end
+
+ local nkeys = {} -- indexed by algorithm
+
+ local function insert_key(sel, add_expire)
+ if not nkeys[sel.alg] then
+ nkeys[sel.alg] = {}
+ end
+
+ if add_expire then
+ sel.valid_end = os.time() + opts.ttl * 3600 * 24
+ end
+
+ table.insert(nkeys[sel.alg], sel)
+ end
+
+ for _, sel in ipairs(elts) do
+ if sel.valid_end and sel.valid_end < os.time() then
+ if not opts.remove_expired then
+ insert_key(sel, false)
+ else
+ maybe_printf(opts, 'removed expired key for %s (selector %s, expire "%s"',
+ domain, sel.selector, os.date('%c', sel.valid_end))
+ end
+ else
+ insert_key(sel, true)
+ end
+ end
+
+ -- Now we need to ensure that all but one selectors have either expired or just a single key
+ for alg, keys in pairs(nkeys) do
+ table.sort(keys, function(k1, k2)
+ if k1.valid_end and k2.valid_end then
+ return k1.valid_end > k2.valid_end
+ elseif k1.valid_end then
+ return true
+ elseif k2.valid_end then
+ return false
+ end
+ return false
+ end)
+ -- Exclude the key with the highest expiration date and examine the rest
+ if not (#keys == 1 or fun.all(function(k)
+ return k.valid_end and k.valid_end < os.time()
+ end, fun.tail(keys))) then
+ printf('bad keys list for %s and %s algorithm', domain, alg)
+ fun.each(function(k)
+ if not k.valid_end then
+ printf('selector %s, algorithm %s has a key with no expire',
+ k.selector, k.alg)
+ elseif k.valid_end >= os.time() then
+ printf('selector %s, algorithm %s has a key that not yet expired: %s',
+ k.selector, k.alg, os.date('%c', k.valid_end))
+ end
+ end, fun.tail(keys))
+ os.exit(1)
+ end
+ -- Do not create new keys, if we only want to remove expired keys
+ if not opts.remove_expired then
+ -- OK to process
+ -- Insert keys for each algorithm in pairs <old_key(s)>, <new_key>
+ local sk, pk = genkey({ algorithm = alg, bits = keys[1].bits })
+ local selector = string.format('%s-%s', alg,
+ os.date("!%Y%m%d"))
+
+ if selector == keys[1].selector then
+ selector = selector .. '-1'
+ end
+ local nelt = {
+ selector = selector,
+ domain = domain,
+ key = tostring(sk),
+ pubkey = tostring(pk),
+ alg = alg,
+ bits = keys[1].bits,
+ valid_start = os.time(),
+ }
+
+ if opts.expire then
+ nelt.valid_end = os.time() + opts.expire * 3600 * 24
+ end
+
+ table.insert(res.selectors, nelt)
+ end
+ for _, k in ipairs(keys) do
+ table.insert(res.selectors, k)
+ end
+ end
+ end
+
+ -- We can now store res in the vault
+ err, data = rspamd_http.request {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+ url = uri,
+ method = 'put',
+ headers = {
+ ['Content-Type'] = 'application/json',
+ ['X-Vault-Token'] = opts.token
+ },
+ body = {
+ ucl.to_format(res, 'json-compact')
+ },
+ }
+
+ if is_http_error(err, data) then
+ printf('cannot put request to the vault (%s), HTTP error code %s', uri, data.code)
+ maybe_print_vault_data(opts, data.content)
+ os.exit(1)
+ else
+ for _, key in ipairs(res.selectors) do
+ if not key.valid_end or key.valid_end > os.time() + opts.ttl * 3600 * 24 then
+ maybe_printf(opts, 'rolled key for: %s, new selector: %s', domain, key.selector)
+ maybe_printf(opts, 'please place the corresponding public key as following:')
+
+ if opts.silent then
+ printf('%s', key.pubkey)
+ else
+ print_dkim_txt_record(key.pubkey, key.selector, key.alg)
+ end
+
+ end
+ end
+
+ maybe_printf(opts, 'your old keys will be valid until %s',
+ os.date('%c', os.time() + opts.ttl * 3600 * 24))
+ end
+end
+
+local function handler(args)
+ local opts = parser:parse(args)
+
+ if not opts.addr then
+ opts.addr = os.getenv('VAULT_ADDR')
+ end
+
+ if not opts.token then
+ opts.token = os.getenv('VAULT_TOKEN')
+ else
+ maybe_printf(opts, 'defining token via command line is insecure, define it via environment variable %s',
+ highlight('VAULT_TOKEN', 'red'))
+ end
+
+ if not opts.token or not opts.addr then
+ printf('no token or/and vault addr has been specified, exiting')
+ os.exit(1)
+ end
+
+ local command = opts.command
+
+ if command == 'list' then
+ list_handler(opts)
+ elseif command == 'show' then
+ fun.each(function(d)
+ show_handler(opts, d)
+ end, opts.domain)
+ elseif command == 'newkey' then
+ fun.each(function(d)
+ newkey_handler(opts, d)
+ end, opts.domain)
+ elseif command == 'roll' then
+ fun.each(function(d)
+ roll_handler(opts, d)
+ end, opts.domain)
+ elseif command == 'delete' then
+ fun.each(function(d)
+ delete_handler(opts, d)
+ end, opts.domain)
+ else
+ parser:error(string.format('command %s is not implemented', command))
+ end
+end
+
+return {
+ handler = handler,
+ description = parser._description,
+ name = 'vault'
+}