summaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/multimap.lua
diff options
context:
space:
mode:
Diffstat (limited to 'src/plugins/lua/multimap.lua')
-rw-r--r--src/plugins/lua/multimap.lua1403
1 files changed, 1403 insertions, 0 deletions
diff --git a/src/plugins/lua/multimap.lua b/src/plugins/lua/multimap.lua
new file mode 100644
index 0000000..53b2732
--- /dev/null
+++ b/src/plugins/lua/multimap.lua
@@ -0,0 +1,1403 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+if confighelp then
+ return
+end
+
+-- Multimap is an rspamd module designed to define and operate on different kinds of maps
+
+local rules = {}
+local rspamd_logger = require "rspamd_logger"
+local rspamd_util = require "rspamd_util"
+local rspamd_regexp = require "rspamd_regexp"
+local rspamd_expression = require "rspamd_expression"
+local rspamd_ip = require "rspamd_ip"
+local lua_util = require "lua_util"
+local lua_selectors = require "lua_selectors"
+local lua_maps = require "lua_maps"
+local redis_params
+local fun = require "fun"
+local N = 'multimap'
+
+local multimap_grammar
+-- Parse result in form: <symbol>:<score>|<symbol>|<score>
+--- Parse a value returned by a map lookup.
+-- Accepted string forms are `<symbol>:<score>:<opt1>,<opt2>...`,
+-- `<symbol>:<score>`, `<symbol>` and `<score>`.
+-- @param parse_rule rule table; only `parse_rule.symbol` is read (for logging)
+-- @param p_ret value obtained from the map: a string, a boolean or anything else
+-- @return matched (boolean), symbol (string or nil), score (number), opts (table)
+--   * string input: always `true`; an unparsable string yields `true, nil, 1.0, {}`
+--   * boolean input: `p_ret, nil, 1.0, {}`
+--   * anything else: `false, nil, 0.0, {}`
+local function parse_multimap_value(parse_rule, p_ret)
+  if p_ret and type(p_ret) == 'string' then
+    local lpeg = require "lpeg"
+
+    -- The grammar is built once and cached in the file-level upvalue
+    if not multimap_grammar then
+      local number = {}
+
+      local digit = lpeg.R("09")
+      local sign = lpeg.S("+-")
+      number.integer = (sign ^ -1) *
+          (digit ^ 1)
+
+      -- Matches: .6, .899, .9999873
+      number.fractional = (lpeg.P(".")) *
+          (digit ^ 1)
+
+      -- Matches: 55.97, -90.8, .9
+      -- The sign in front of a purely fractional number is optional; it used
+      -- to be mandatory, which made the documented `.9` form unparsable.
+      number.decimal = (number.integer * -- Integer
+          (number.fractional ^ -1)) + -- Fractional
+          ((sign ^ -1) * number.fractional) -- Completely fractional number
+
+      local sym_start = lpeg.R("az", "AZ") + lpeg.S("_")
+      local sym_elt = sym_start + lpeg.R("09")
+      local symbol = sym_start * sym_elt ^ 0
+      local symbol_cap = lpeg.Cg(symbol, 'symbol')
+      local score_cap = lpeg.Cg(number.decimal, 'score')
+      local opts_cap = lpeg.Cg(lpeg.Ct(lpeg.C(symbol) * (lpeg.P(",") * lpeg.C(symbol)) ^ 0), 'opts')
+      local symscore_cap = (symbol_cap * lpeg.P(":") * score_cap)
+      local symscoreopt_cap = symscore_cap * lpeg.P(":") * opts_cap
+      -- Ordered choice: the most specific form is tried first
+      local grammar = symscoreopt_cap + symscore_cap + symbol_cap + score_cap
+      multimap_grammar = lpeg.Ct(grammar)
+    end
+    local tbl = multimap_grammar:match(p_ret)
+
+    if tbl then
+      local sym
+      local score = 1.0
+      local opts = {}
+
+      if tbl.symbol then
+        sym = tbl.symbol
+      end
+      if tbl.score then
+        score = tonumber(tbl.score)
+      end
+      if tbl.opts then
+        opts = tbl.opts
+      end
+
+      return true, sym, score, opts
+    else
+      -- An empty string is a plain "key exists" hit, nothing worth logging
+      if p_ret ~= '' then
+        rspamd_logger.infox(rspamd_config, '%s: cannot parse string "%s"',
+            parse_rule.symbol, p_ret)
+      end
+
+      return true, nil, 1.0, {}
+    end
+  elseif type(p_ret) == 'boolean' then
+    return p_ret, nil, 1.0, {}
+  end
+
+  return false, nil, 0.0, {}
+end
+
+-- Per rule type converters: `get_value` turns the value that was matched
+-- into what is attached to the inserted symbol as its option.
+local value_types
+do
+  -- The value is already in the form it should be reported in
+  local function as_is(val)
+    return val
+  end
+
+  -- IP objects are reported in their string form
+  local function ip_as_string(ip)
+    return ip:to_string()
+  end
+
+  -- Content matches never carry an option
+  local function no_value()
+    return nil
+  end
+
+  value_types = {
+    ip = { get_value = ip_as_string },
+    dnsbl = { get_value = ip_as_string },
+    content = { get_value = no_value },
+    from = { get_value = as_is },
+    helo = { get_value = as_is },
+    header = { get_value = as_is },
+    rcpt = { get_value = as_is },
+    user = { get_value = as_is },
+    url = { get_value = as_is },
+    filename = { get_value = as_is },
+    hostname = { get_value = as_is },
+    asn = { get_value = as_is },
+    country = { get_value = as_is },
+    received = { get_value = as_is },
+    mempool = { get_value = as_is },
+    selector = { get_value = as_is },
+    symbol_options = { get_value = as_is },
+  }
+end
+
+--- Build the DNS name to query for `ip` in the list `rbl`.
+-- The inversed octets of the address are joined with dots and the list
+-- name is appended after one more dot.
+-- @param ip object providing `inversed_str_octets()`
+-- @param rbl list name to append
+-- @return string name to resolve
+local function ip_to_rbl(ip, rbl)
+  local octets = ip:inversed_str_octets()
+  local reversed = table.concat(octets, ".")
+  return reversed .. '.' .. rbl
+end
+
+--- Apply a hostname filter (`tld`, `top` or `tld:regexp:<pattern>`).
+-- @param task task used for error logging
+-- @param filter filter name taken from the rule
+-- @param hostname hostname to transform
+-- @param r rule table; the compiled regexp is cached in `r['re_filter']`
+-- @return filtered value, or nothing when the filter is unusable / does not match
+local function apply_hostname_filter(task, filter, hostname, r)
+  if filter == 'tld' then
+    local tld = rspamd_util.get_tld(hostname)
+    return tld
+  end
+
+  if filter == 'top' then
+    local tld = rspamd_util.get_tld(hostname)
+    -- Keep only what follows the last dot
+    return tld:match('[^.]*$') or tld
+  end
+
+  -- Anything else must be a `tld:regexp:` filter; compile it on first use
+  local re = r['re_filter']
+  if not re then
+    local pat = string.match(filter, 'tld:regexp:(.+)')
+    if not pat then
+      rspamd_logger.errx(task, 'bad search filter: %s', filter)
+      return
+    end
+
+    re = rspamd_regexp.create_cached(pat)
+    r['re_filter'] = re
+    if not re then
+      rspamd_logger.errx(task, 'couldnt create regex: %s', pat)
+      return
+    end
+  end
+
+  local tld = rspamd_util.get_tld(hostname)
+  local found = re:search(tld)
+  if found then
+    return found[1]
+  end
+
+  return nil
+end
+
+--- Extract the part of an URL selected by `filter`.
+-- Without a filter, or when no filter branch produces a result, the URL
+-- host is returned.
+-- @param task task used for error logging
+-- @param filter filter name taken from the rule (may be nil)
+-- @param url url object
+-- @param r rule table; caches `r['re_filter']` and `r['template']`
+-- @return the selected value or nil when the URL does not satisfy the filter
+local function apply_url_filter(task, filter, url, r)
+  if not filter then
+    return url:get_host()
+  end
+
+  -- Returns the rule's regexp, compiling it from the `regexp:<pattern>` part
+  -- of the filter on first use; logs and returns nil when that is impossible
+  local function cached_regexp()
+    if not r['re_filter'] then
+      local prefix, pat = string.match(filter, '(regexp:)(.+)')
+      if prefix and pat then
+        r['re_filter'] = rspamd_regexp.create_cached(pat)
+      end
+    end
+
+    if not r['re_filter'] then
+      rspamd_logger.errx(task, 'bad search filter: %s', filter)
+    end
+
+    return r['re_filter']
+  end
+
+  -- Exact filter names: every one of them returns straight away
+  if filter == 'tld' then
+    return url:get_tld()
+  end
+  if filter == 'top' then
+    local tld = url:get_tld()
+    return tld:match('[^.]*$') or tld
+  end
+  if filter == 'full' then
+    return url:get_text()
+  end
+  if filter == 'is_phished' then
+    if not url:is_phished() then
+      return nil
+    end
+    return url:get_host()
+  end
+  if filter == 'is_redirected' then
+    if not url:is_redirected() then
+      return nil
+    end
+    return url:get_host()
+  end
+  if filter == 'is_obscured' then
+    if not url:is_obscured() then
+      return nil
+    end
+    return url:get_host()
+  end
+  if filter == 'path' then
+    return url:get_path()
+  end
+  if filter == 'query' then
+    return url:get_query()
+  end
+
+  -- Prefixed filters: only the first matching prefix is considered
+  if string.find(filter, 'tag:') then
+    local wanted = string.match(filter, 'tag:(.*)')
+    for _, tag in ipairs(url:get_tags()) do
+      if tag == wanted then
+        return url:get_host()
+      end
+    end
+    return nil
+  elseif string.find(filter, 'tld:regexp:') then
+    local re = cached_regexp()
+    if re then
+      local found = re:search(url:get_tld())
+      if found then
+        return found[1]
+      end
+      return nil
+    end
+  elseif string.find(filter, 'full:regexp:') then
+    local re = cached_regexp()
+    if re then
+      local found = re:search(url:get_text())
+      if found then
+        return found[1]
+      end
+      return nil
+    end
+  elseif string.find(filter, 'regexp:') then
+    local re = cached_regexp()
+    if re then
+      local found = re:search(url:get_host())
+      if found then
+        return found[1]
+      end
+      return nil
+    end
+  elseif string.find(filter, '^template:') then
+    if not r['template'] then
+      r['template'] = string.match(filter, '^template:(.+)')
+    end
+
+    if r['template'] then
+      return lua_util.template(r['template'], url:to_table())
+    end
+  end
+
+  -- Unknown filter or unusable regexp/template
+  return url:get_host()
+end
+
+--- Apply an address-oriented filter to a raw input string.
+-- Supports the `email*` family, `ip_addr` and `regexp:<pattern>`; whenever
+-- the filter yields nothing the input is returned as it was given.
+-- @param task task providing the memory pool and used for logging
+-- @param filter filter name taken from the rule
+-- @param input value to filter
+-- @param rule rule table; the compiled regexp is cached in `rule['re_filter']`
+-- @return a list of strings, an ip object, a string, or the untouched input
+local function apply_addr_filter(task, filter, input, rule)
+  local function addr_of(a)
+    return a.addr
+  end
+
+  -- What to take from every parsed address for each `email*` filter
+  local email_extractors = {
+    ['email'] = addr_of,
+    ['email:addr'] = addr_of,
+    ['email:user'] = function(a)
+      return a.user
+    end,
+    ['email:domain'] = function(a)
+      return a.domain
+    end,
+    ['email:domain:tld'] = function(a)
+      return rspamd_util.get_tld(a.domain)
+    end,
+    ['email:name'] = function(a)
+      return a.name
+    end,
+  }
+
+  local extractor = email_extractors[filter]
+
+  if extractor then
+    local parsed = rspamd_util.parse_mail_address(input, task:get_mempool(), 1024)
+    if parsed and parsed[1] then
+      return fun.totable(fun.map(extractor, parsed))
+    end
+  elseif filter == 'ip_addr' then
+    local ip_addr = rspamd_ip.from_string(input)
+
+    if ip_addr and ip_addr:is_valid() then
+      return ip_addr
+    end
+  else
+    -- regexp case
+    if not rule['re_filter'] then
+      local prefix, pat = string.match(filter, '(regexp:)(.+)')
+      if prefix and pat then
+        rule['re_filter'] = rspamd_regexp.create_cached(pat)
+      end
+    end
+
+    local re = rule['re_filter']
+    if re then
+      local found = re:search(input)
+      if found then
+        return found[1]
+      end
+    else
+      rspamd_logger.errx(task, 'bad search filter: %s', filter)
+    end
+  end
+
+  return input
+end
+--- Apply a filename filter: `extension`/`ext` or `regexp:<pattern>`.
+-- @param task task used for error logging
+-- @param filter filter name taken from the rule
+-- @param fn file name
+-- @param r rule table; the compiled regexp is cached in `r['re_filter']`
+-- @return extension, first regexp match, nil on no match, or `fn` itself
+--   when the filter is unknown or its regexp cannot be built
+local function apply_filename_filter(task, filter, fn, r)
+  local wants_extension = (filter == 'extension') or (filter == 'ext')
+  if wants_extension then
+    -- Text after the last dot; nil when the name has no dot
+    return string.match(fn, '%.([^.]+)$')
+  end
+
+  if string.find(filter, 'regexp:') then
+    if not r['re_filter'] then
+      local prefix, pat = string.match(filter, '(regexp:)(.+)')
+      if prefix and pat then
+        r['re_filter'] = rspamd_regexp.create_cached(pat)
+      end
+    end
+
+    local re = r['re_filter']
+    if re then
+      local found = re:search(fn)
+      if found then
+        return found[1]
+      end
+      return nil
+    end
+
+    rspamd_logger.errx(task, 'bad search filter: %s', filter)
+  end
+
+  return fn
+end
+
+--- Apply a `regexp:<pattern>` filter and return `results[1][2]` of a search
+-- performed with the extra arguments `false, true`.
+-- @param task task used for error logging
+-- @param filter filter name taken from the rule
+-- @param fn value to search in
+-- @param r rule table; the compiled regexp is cached in `r['re_filter']`
+-- @return the selected element, nil when nothing matches, or `fn` itself when
+--   the filter is not a regexp filter or its regexp cannot be built
+local function apply_regexp_filter(task, filter, fn, r)
+  if not string.find(filter, 'regexp:') then
+    return fn
+  end
+
+  if not r['re_filter'] then
+    local prefix, pat = string.match(filter, '(regexp:)(.+)')
+    if prefix and pat then
+      r['re_filter'] = rspamd_regexp.create_cached(pat)
+    end
+  end
+
+  local re = r['re_filter']
+  if not re then
+    rspamd_logger.errx(task, 'bad search filter: %s', filter)
+    return fn
+  end
+
+  local found = re:search(fn, false, true)
+  if found then
+    return found[1][2]
+  end
+
+  return nil
+end
+
+--- Collect the pieces of message content selected by `filter`.
+-- Known filters: `body`, `full`, `headers`, `text`, `rawtext`, `oneline`.
+-- @param task task to read the content from
+-- @param filter filter name taken from the rule
+-- @return list of content pieces; empty (after logging) for an unknown filter
+local function apply_content_filter(task, filter)
+  -- Runs `append(list, part)` for every text part and returns the list
+  local function over_text_parts(append)
+    local pieces = {}
+    for _, part in ipairs(task:get_text_parts()) do
+      append(pieces, part)
+    end
+    return pieces
+  end
+
+  local producers = {
+    body = function()
+      return { task:get_rawbody() }
+    end,
+    full = function()
+      return { task:get_content() }
+    end,
+    headers = function()
+      return { task:get_raw_headers() }
+    end,
+    text = function()
+      return over_text_parts(function(pieces, part)
+        table.insert(pieces, part:get_content())
+      end)
+    end,
+    rawtext = function()
+      return over_text_parts(function(pieces, part)
+        table.insert(pieces, part:get_content('raw_parsed'))
+      end)
+    end,
+    oneline = function()
+      return over_text_parts(function(pieces, part)
+        table.insert(pieces, part:get_content_oneline())
+      end)
+    end,
+  }
+
+  local producer = producers[filter]
+  if producer then
+    return producer()
+  end
+
+  rspamd_logger.errx(task, 'bad search filter: %s', filter)
+
+  return {}
+end
+
+-- Filter function to use for each rule type. Types that are absent here are
+-- matched without any value transformation.
+local multimap_filters = {
+  -- address-like values
+  from = apply_addr_filter,
+  rcpt = apply_addr_filter,
+  header = apply_addr_filter,
+  -- host names
+  helo = apply_hostname_filter,
+  hostname = apply_hostname_filter,
+  -- generic regexp extraction
+  symbol_options = apply_regexp_filter,
+  mempool = apply_regexp_filter,
+  selector = apply_regexp_filter,
+  -- dedicated filters
+  url = apply_url_filter,
+  filename = apply_filename_filter,
+  --content = apply_content_filter, -- Content filters are special :(
+}
+
+--- Look `value` up in the Redis hash `key`.
+-- Plain values are fetched with HGET. For an IP object an HMGET is issued
+-- for the address itself plus every `<masked address>/<bits>` from the widest
+-- mask (32 for IPv4, 128 for IPv6) down to the narrowest one checked
+-- (8 for IPv4, 64 for IPv6).
+-- @param key name of the Redis hash
+-- @param task task the request is bound to
+-- @param value string or ip object to look up
+-- @param callback called with the reply when there is no error and the reply
+--   is not userdata
+-- @return whatever `rspamd_redis_make_request` returns
+local function multimap_query_redis(key, task, value, callback)
+  local is_ip = type(value) == 'userdata' and value.class == 'rspamd{ip}'
+  local cmd = 'HGET'
+  if is_ip then
+    cmd = 'HMGET'
+  end
+
+  local srch = { key }
+
+  if is_ip then
+    -- Insert all ips for some mask :(
+    srch[#srch + 1] = tostring(value)
+    -- IPv6 case
+    local maxbits = 128
+    local minbits = 64
+    if value:get_version() == 4 then
+      maxbits = 32
+      minbits = 8
+    end
+    for i = maxbits, minbits, -1 do
+      -- Stringify the same way as the unmasked address above; the previous
+      -- `:tostring()` method call is not used anywhere else in this file
+      -- (IP objects are converted with `to_string()` or `tostring()`).
+      local nip = tostring(value:apply_mask(i)) .. "/" .. i
+      srch[#srch + 1] = nip
+    end
+  else
+    srch[#srch + 1] = value
+  end
+
+  local function redis_map_cb(err, data)
+    lua_util.debugm(N, task, 'got reply from Redis when trying to get key %s: err=%s, data=%s',
+        key, err, data)
+    -- Replies that arrive as userdata are not passed on
+    if not err and type(data) ~= 'userdata' then
+      callback(data)
+    end
+  end
+
+  return rspamd_redis_make_request(task,
+      redis_params, -- connect params
+      key, -- hash key
+      false, -- is write
+      redis_map_cb, --callback
+      cmd, -- command
+      srch -- arguments
+  )
+end
+
+--- Evaluate one multimap rule against a task.
+-- The values dictated by `rule.type` are extracted from the task, optionally
+-- filtered, looked up in the rule's map (or Redis hash) and every hit is
+-- turned into a symbol and, when `rule.action` is set, a pre-result.
+-- @param task task being processed
+-- @param rule rule table as returned by `add_multimap_rule`
+local function multimap_callback(task, rule)
+  -- Look one value up in the backend of `r` and pass every hit to `callback`
+  local function match_element(r, value, callback)
+    if not value then
+      return false
+    end
+
+    local function get_key_callback(ret, err_or_data, err_code)
+      lua_util.debugm(N, task, 'got return "%s" (err code = %s) for multimap %s',
+          err_or_data,
+          err_code,
+          rule.symbol)
+
+      if ret then
+        if type(err_or_data) == 'table' then
+          for _, elt in ipairs(err_or_data) do
+            callback(elt)
+          end
+        else
+          callback(err_or_data)
+        end
+      elseif err_code ~= 404 then
+        -- 404 is not reported
+        rspamd_logger.infox(task, "map %s: get key returned error %s: %s",
+            rule.symbol, err_code, err_or_data)
+      end
+    end
+
+    lua_util.debugm(N, task, 'check value %s for multimap %s', value,
+        rule.symbol)
+
+    local ret = false
+
+    if r.redis_key then
+      -- Deal with hash name here: it can be either plain string or a selector
+      if type(r.redis_key) == 'string' then
+        ret = multimap_query_redis(r.redis_key, task, value, callback)
+      else
+        -- Here we have a selector
+        local results = r.redis_key(task)
+
+        -- Here we need to split this function into multiple queries
+        if type(results) == 'table' then
+          for _, res in ipairs(results) do
+            ret = multimap_query_redis(res, task, value, callback)
+
+            if not ret then
+              break
+            end
+          end
+        else
+          ret = multimap_query_redis(results, task, value, callback)
+        end
+      end
+
+      return ret
+    elseif r.map_obj then
+      r.map_obj:get_key(value, get_key_callback, task)
+    end
+  end
+
+  -- Insert the symbol (and pre-result) for one map hit.
+  -- `result` is the raw map value, `opt` the option derived from the matched value.
+  local function insert_results(result, opt)
+    local _, symbol, score, opts = parse_multimap_value(rule, result)
+    local forced = false
+    if symbol then
+      if rule.symbols_set then
+        if not rule.symbols_set[symbol] then
+          rspamd_logger.infox(task, 'symbol %s is not registered for map %s, ' ..
+              'replace it with just %s',
+              symbol, rule.symbol, rule.symbol)
+          symbol = rule.symbol
+        end
+      elseif rule.disable_multisymbol then
+        symbol = rule.symbol
+        if type(opt) == 'table' then
+          table.insert(opt, result)
+        elseif opt ~= nil then
+          -- `type(opt) ~= nil` was always true (type() returns a string), so a
+          -- missing option produced `{ nil, result }`: a table with a hole
+          opt = { opt, result }
+        else
+          opt = { result }
+        end
+      else
+        forced = not rule.dynamic_symbols
+      end
+    else
+      symbol = rule.symbol
+    end
+
+    if opts and #opts > 0 then
+      -- Options come from the map itself
+      task:insert_result(forced, symbol, score, opts)
+    else
+      if opt then
+        if type(opt) == 'table' then
+          task:insert_result(forced, symbol, score, fun.totable(fun.map(tostring, opt)))
+        else
+          task:insert_result(forced, symbol, score, tostring(opt))
+        end
+
+      else
+        task:insert_result(forced, symbol, score)
+      end
+    end
+
+    if rule.action then
+      local message = rule.message
+      if rule.message_func then
+        message = rule.message_func(task, rule.symbol, opt)
+      end
+      if message then
+        task:set_pre_result(rule.action, message, N)
+      else
+        task:set_pre_result(rule.action, 'Matched map: ' .. rule.symbol, N)
+      end
+    end
+  end
+
+  -- Match a single value against a single rule
+  local function match_rule(r, value)
+    local function rule_callback(result)
+      if result then
+        if type(result) == 'table' then
+          -- Multi-value reply: handle every non-userdata element on its own
+          for _, rs in ipairs(result) do
+            if type(rs) ~= 'userdata' then
+              rule_callback(rs)
+            end
+          end
+          return
+        end
+        local opt = value_types[r['type']].get_value(value)
+        insert_results(result, opt)
+      end
+    end
+
+    -- URL rules always go through their filter, even when none is configured
+    if r.filter or r.type == 'url' then
+      local fn = multimap_filters[r.type]
+
+      if fn then
+
+        local filtered_value = fn(task, r.filter, value, r)
+        lua_util.debugm(N, task, 'apply filter %s for rule %s: %s -> %s',
+            r.filter, r.symbol, value, filtered_value)
+        value = filtered_value
+      end
+    end
+
+    if type(value) == 'table' then
+      fun.each(function(elt)
+        match_element(r, elt, rule_callback)
+      end, value)
+    else
+      match_element(r, value, rule_callback)
+    end
+  end
+
+  -- Match list of values according to the field.
+  -- `fields[1]` names the element field to match, the optional `fields[2]`
+  -- is a function applied to that field before matching.
+  local function match_list(r, ls, fields)
+    if ls then
+      if fields then
+        fun.each(function(e)
+          local match = e[fields[1]]
+          if match then
+            if fields[2] then
+              match = fields[2](match)
+            end
+            match_rule(r, match)
+          end
+        end, ls)
+      else
+        fun.each(function(e)
+          match_rule(r, e)
+        end, ls)
+      end
+    end
+  end
+
+  -- Match full addresses and, when the rule has no filter, also their
+  -- domain and user parts
+  local function match_addr(r, addr)
+    match_list(r, addr, { 'addr' })
+
+    if not r.filter then
+      match_list(r, addr, { 'domain' })
+      match_list(r, addr, { 'user' })
+    end
+  end
+
+  local function match_url(r, url)
+    match_rule(r, url)
+  end
+
+  local function match_hostname(r, hostname)
+    match_rule(r, hostname)
+  end
+
+  local function match_filename(r, fn)
+    match_rule(r, fn)
+  end
+
+  -- Match one Received header `h`, the `pos`-th out of `total`, honouring
+  -- the rule's min_pos/max_pos and flags/nflags restrictions
+  local function match_received_header(r, pos, total, h)
+    local use_tld = false
+    local filter = r['filter'] or 'real_ip'
+    if filter:match('^tld:') then
+      filter = filter:sub(5)
+      use_tld = true
+    end
+    local v = h[filter]
+    if v then
+      local min_pos = tonumber(r['min_pos'])
+      local max_pos = tonumber(r['max_pos'])
+      if min_pos then
+        if min_pos < 0 then
+          if min_pos == -1 then
+            if (pos ~= total) then
+              return
+            end
+          else
+            if pos <= (total - (min_pos * -1)) then
+              return
+            end
+          end
+        elseif pos < min_pos then
+          return
+        end
+      end
+      if max_pos then
+        if max_pos < -1 then
+          if (total - (max_pos * -1)) >= pos then
+            return
+          end
+        elseif max_pos > 0 then
+          if pos > max_pos then
+            return
+          end
+        end
+      end
+      local match_flags = r['flags']
+      local nmatch_flags = r['nflags']
+      if match_flags or nmatch_flags then
+        local got_flags = h['flags']
+        if match_flags then
+          -- Every required flag must be present
+          for _, flag in ipairs(match_flags) do
+            if not got_flags[flag] then
+              return
+            end
+          end
+        end
+        if nmatch_flags then
+          -- None of the forbidden flags may be present
+          for _, flag in ipairs(nmatch_flags) do
+            if got_flags[flag] then
+              return
+            end
+          end
+        end
+      end
+      if filter == 'real_ip' or filter == 'from_ip' then
+        if type(v) == 'string' then
+          v = rspamd_ip.from_string(v)
+        end
+        if v and v:is_valid() then
+          match_rule(r, v)
+        end
+      else
+        if use_tld and type(v) == 'string' then
+          v = rspamd_util.get_tld(v)
+        end
+        match_rule(r, v)
+      end
+    end
+  end
+
+  -- Match the content pieces selected by the rule's filter (the full
+  -- content when there is no filter)
+  local function match_content(r)
+    local data
+
+    if r['filter'] then
+      data = apply_content_filter(task, r['filter'], r)
+    else
+      data = { task:get_content() }
+    end
+
+    for _, v in ipairs(data) do
+      match_rule(r, v)
+    end
+  end
+
+  -- `require_symbols` precondition; for combined rules `rule.expression` is
+  -- passed to the combined map at creation time instead
+  if rule.expression and not rule.combined then
+    local res, trace = rule['expression']:process_traced(task)
+
+    if not res or res == 0 then
+      lua_util.debugm(N, task, 'condition is false for %s',
+          rule.symbol)
+      return
+    else
+      lua_util.debugm(N, task, 'condition is true for %s: %s',
+          rule.symbol,
+          trace)
+    end
+  end
+
+  -- One value extractor per rule type
+  local process_rule_funcs = {
+    ip = function()
+      local ip = task:get_from_ip()
+      if ip and ip:is_valid() then
+        match_rule(rule, ip)
+      end
+    end,
+    dnsbl = function()
+      local ip = task:get_from_ip()
+      if ip and ip:is_valid() then
+        local to_resolve = ip_to_rbl(ip, rule['map'])
+        local function dns_cb(_, _, results, err)
+          lua_util.debugm(N, rspamd_config,
+              'resolve() finished: results=%1, err=%2, to_resolve=%3',
+              results, err, to_resolve)
+
+          if err and
+              (err ~= 'requested record is not found' and
+                  err ~= 'no records with this name') then
+            -- Report the error itself; `results` used to be logged here
+            rspamd_logger.errx(task, 'error looking up %s: %s', to_resolve, err)
+          elseif results then
+            task:insert_result(rule['symbol'], 1, rule['map'])
+            if rule.action then
+              task:set_pre_result(rule['action'],
+                  'Matched map: ' .. rule['symbol'], N)
+            end
+          end
+        end
+
+        task:get_resolver():resolve_a({
+          task = task,
+          name = to_resolve,
+          callback = dns_cb,
+          forced = true
+        })
+      end
+    end,
+    header = function()
+      if type(rule['header']) == 'table' then
+        for _, rh in ipairs(rule['header']) do
+          local hv = task:get_header_full(rh)
+          match_list(rule, hv, { 'decoded' })
+        end
+      else
+        local hv = task:get_header_full(rule['header'])
+        match_list(rule, hv, { 'decoded' })
+      end
+    end,
+    rcpt = function()
+      -- SMTP recipients take precedence over MIME ones
+      if task:has_recipients('smtp') then
+        local rcpts = task:get_recipients('smtp')
+        match_addr(rule, rcpts)
+      elseif task:has_recipients('mime') then
+        local rcpts = task:get_recipients('mime')
+        match_addr(rule, rcpts)
+      end
+    end,
+    from = function()
+      -- SMTP sender takes precedence over the MIME one
+      if task:has_from('smtp') then
+        local from = task:get_from('smtp')
+        match_addr(rule, from)
+      elseif task:has_from('mime') then
+        local from = task:get_from('mime')
+        match_addr(rule, from)
+      end
+    end,
+    helo = function()
+      local helo = task:get_helo()
+      if helo then
+        match_hostname(rule, helo)
+      end
+    end,
+    url = function()
+      if task:has_urls() then
+        local msg_urls = task:get_urls()
+
+        for _, url in ipairs(msg_urls) do
+          match_url(rule, url)
+        end
+      end
+    end,
+    user = function()
+      local user = task:get_user()
+      if user then
+        match_rule(rule, user)
+      end
+    end,
+    filename = function()
+      local parts = task:get_parts()
+
+      -- Attachments, plus every part that is neither text nor multipart
+      local function filter_parts(p)
+        return p:is_attachment() or (not p:is_text()) and (not p:is_multipart())
+      end
+
+      -- True for archive parts whose detected type is also an archive,
+      -- unless the rule asks to skip archives
+      local function filter_archive(p)
+        local ext = p:get_detected_ext()
+        local det_type = 'unknown'
+
+        if ext then
+          local lua_magic_types = require "lua_magic/types"
+          local det_t = lua_magic_types[ext]
+
+          if det_t then
+            det_type = det_t.type
+          end
+        end
+
+        return p:is_archive() and det_type == 'archive' and not rule.skip_archives
+      end
+
+      for _, p in fun.iter(fun.filter(filter_parts, parts)) do
+        if filter_archive(p) then
+          local fnames = p:get_archive():get_files(1000)
+
+          for _, fn in ipairs(fnames) do
+            match_filename(rule, fn)
+          end
+        end
+
+        local fn = p:get_filename()
+        if fn then
+          match_filename(rule, fn)
+        end
+        -- Also deal with detected content type
+        if not rule.skip_detected then
+          local ext = p:get_detected_ext()
+
+          if ext then
+            local fake_fname = string.format('detected.%s', ext)
+            lua_util.debugm(N, task, 'detected filename %s',
+                fake_fname)
+            match_filename(rule, fake_fname)
+          end
+        end
+      end
+    end,
+
+    content = function()
+      match_content(rule)
+    end,
+    hostname = function()
+      local hostname = task:get_hostname()
+      if hostname then
+        match_hostname(rule, hostname)
+      end
+    end,
+    asn = function()
+      local asn = task:get_mempool():get_variable('asn')
+      if asn then
+        match_rule(rule, asn)
+      end
+    end,
+    country = function()
+      local country = task:get_mempool():get_variable('country')
+      if country then
+        match_rule(rule, country)
+      end
+    end,
+    mempool = function()
+      local var = task:get_mempool():get_variable(rule['variable'])
+      if var then
+        match_rule(rule, var)
+      end
+    end,
+    symbol_options = function()
+      local sym = task:get_symbol(rule['target_symbol'])
+      if sym and sym[1].options then
+        for _, o in ipairs(sym[1].options) do
+          match_rule(rule, o)
+        end
+      end
+    end,
+    received = function()
+      local hdrs = task:get_received_headers()
+      if hdrs and hdrs[1] then
+        -- Headers flagged as artificial are dropped unless the rule wants them
+        if not rule['artificial'] then
+          hdrs = fun.filter(function(h)
+            return not h['flags']['artificial']
+          end, hdrs):totable()
+        end
+        for pos, h in ipairs(hdrs) do
+          match_received_header(rule, pos, #hdrs, h)
+        end
+      end
+    end,
+    selector = function()
+      local elts = rule.selector(task)
+
+      if elts then
+        if type(elts) == 'table' then
+          for _, elt in ipairs(elts) do
+            match_rule(rule, elt)
+          end
+        else
+          match_rule(rule, elts)
+        end
+      end
+    end,
+    combined = function()
+      local ret, trace = rule.combined:process(task)
+      if ret and ret ~= 0 then
+        for n, t in pairs(trace) do
+          insert_results(t.value, string.format("%s=%s",
+              n, t.matched))
+        end
+      end
+    end,
+  }
+
+  local rt = rule.type
+  local process_func = process_rule_funcs[rt]
+  if process_func then
+    process_func()
+  else
+    rspamd_logger.errx(task, 'Unrecognised rule type: %s', rt)
+  end
+end
+
+--- Bind a rule to `multimap_callback`.
+-- @param rule rule table
+-- @return function(task) that evaluates `rule` against the task
+local function gen_multimap_callback(rule)
+  local function rule_callback(task)
+    multimap_callback(task, rule)
+  end
+
+  return rule_callback
+end
+
+--- Create the `on_load` handler of a map with dynamic symbols.
+-- Once the map is loaded, the handler walks over all of its entries and
+-- registers each distinct symbol named by the map values as a virtual child
+-- of the rule's callback symbol.
+-- @param rule rule table; `rule.map_obj` and `rule.callback_id` are read
+--   when the handler runs
+-- @return function() to be installed as the map's on_load handler
+local function multimap_on_load_gen(rule)
+  return function()
+    lua_util.debugm(N, rspamd_config, "loaded map object for rule %s", rule['symbol'])
+    local known_symbols = {}
+    rule.map_obj:foreach(function(key, value)
+      local r, symbol, score, _ = parse_multimap_value(rule, value)
+
+      if r and symbol and not known_symbols[symbol] then
+        lua_util.debugm(N, rspamd_config, "%s: adding new symbol %s (score = %s), triggered by %s",
+            rule.symbol, symbol, score, key)
+        -- Register the parsed symbol name: this is the name used when the
+        -- result is inserted, whereas the raw value may also carry a score
+        -- and options (`SYMBOL:score:opts`)
+        rspamd_config:register_symbol {
+          name = symbol,
+          parent = rule.callback_id,
+          type = 'virtual',
+          score = score,
+        }
+        rspamd_config:set_metric_symbol({
+          group = N,
+          score = 1.0, -- In future, we will parse score from `get_value` and use it as multiplier
+          description = 'Automatic symbol generated by rule: ' .. rule.symbol,
+          name = symbol,
+        })
+        -- Remember under the same key that is tested above, so that every
+        -- symbol is registered only once
+        known_symbols[symbol] = true
+      end
+    end)
+  end
+end
+
+--- Validate a rule definition and attach its map backend.
+-- @param key name of the rule in the configuration; used as the symbol when
+--   the rule does not define one
+-- @param newrule rule definition table; it is completed in place
+-- @return `newrule` when the rule is usable, nil otherwise
+local function add_multimap_rule(key, newrule)
+  local ret = false
+
+  -- Attach the key/value map flavour selected by the regexp/glob/multi options
+  local function multimap_load_kv_map(rule)
+    local map_type = 'hash'
+
+    if rule['regexp'] then
+      map_type = rule['multi'] and 'regexp_multi' or 'regexp'
+    elseif rule['glob'] then
+      map_type = rule['multi'] and 'glob_multi' or 'glob'
+    end
+
+    rule.map_obj = lua_maps.map_add_from_ucl(rule.map, map_type,
+        rule.description)
+  end
+
+  -- Attach a radix map to the rule
+  local function load_radix_map(rule)
+    rule.map_obj = lua_maps.map_add_from_ucl(rule.map, 'radix',
+        rule.description)
+  end
+
+  -- Tell whether the rule got its map object, warning when it did not
+  local function has_map_obj(rule)
+    if rule.map_obj then
+      return true
+    end
+
+    rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %1',
+        rule['map'])
+    return false
+  end
+
+  -- Rule types backed by a key/value map (or a radix map for `ip_addr`)
+  local known_generic_types = {
+    header = true,
+    rcpt = true,
+    from = true,
+    helo = true,
+    symbol_options = true,
+    filename = true,
+    url = true,
+    user = true,
+    content = true,
+    hostname = true,
+    asn = true,
+    country = true,
+    mempool = true,
+    selector = true,
+    combined = true
+  }
+
+  if newrule['message_func'] then
+    newrule['message_func'] = assert(load(newrule['message_func']))()
+  end
+  -- `url` is accepted in place of `map`
+  if newrule['url'] and not newrule['map'] then
+    newrule['map'] = newrule['url']
+  end
+  if not (newrule.map or newrule.rules) then
+    rspamd_logger.errx(rspamd_config, 'incomplete rule, missing map')
+    return nil
+  end
+  if not newrule['symbol'] then
+    if key then
+      newrule['symbol'] = key
+    else
+      rspamd_logger.errx(rspamd_config, 'incomplete rule, missing symbol')
+      return nil
+    end
+  end
+  if not newrule['description'] then
+    newrule['description'] = string.format('multimap, type %s: %s', newrule['type'],
+        newrule['symbol'])
+  end
+  if newrule['type'] == 'mempool' and not newrule['variable'] then
+    rspamd_logger.errx(rspamd_config, 'mempool map requires variable')
+    return nil
+  end
+  if newrule['type'] == 'selector' then
+    if not newrule['selector'] then
+      rspamd_logger.errx(rspamd_config, 'selector map requires selector definition')
+      return nil
+    end
+
+    local selector = lua_selectors.create_selector_closure(
+        rspamd_config, newrule['selector'], newrule['delimiter'] or "")
+
+    if not selector then
+      rspamd_logger.errx(rspamd_config, 'selector map has invalid selector: "%s", symbol: %s',
+          newrule['selector'], newrule['symbol'])
+      return nil
+    end
+
+    -- The selector definition is replaced by its compiled closure
+    newrule.selector = selector
+  end
+
+  local map_is_string = type(newrule['map']) == 'string'
+
+  if map_is_string and string.find(newrule['map'], '^redis://.*$') then
+    -- Redis hash with a fixed name
+    if not redis_params then
+      rspamd_logger.infox(rspamd_config, 'no redis servers are specified, ' ..
+          'cannot add redis map %s: %s', newrule['symbol'], newrule['map'])
+      return nil
+    end
+
+    newrule['redis_key'] = string.match(newrule['map'], '^redis://(.*)$')
+
+    if newrule['redis_key'] then
+      ret = true
+    end
+  elseif map_is_string and string.find(newrule['map'], '^redis%+selector://.*$') then
+    -- Redis hash whose name is produced by a selector
+    if not redis_params then
+      rspamd_logger.infox(rspamd_config, 'no redis servers are specified, ' ..
+          'cannot add redis map %s: %s', newrule['symbol'], newrule['map'])
+      return nil
+    end
+
+    local selector_str = string.match(newrule['map'], '^redis%+selector://(.*)$')
+    local selector = lua_selectors.create_selector_closure(
+        rspamd_config, selector_str, newrule['delimiter'] or "")
+
+    if not selector then
+      rspamd_logger.errx(rspamd_config, 'redis selector map has invalid selector: "%s", symbol: %s',
+          selector_str, newrule['symbol'])
+      return nil
+    end
+
+    newrule['redis_key'] = selector
+    ret = true
+  elseif newrule.type == 'combined' then
+    local lua_maps_expressions = require "lua_maps_expressions"
+    newrule.combined = lua_maps_expressions.create(rspamd_config,
+        {
+          rules = newrule.rules,
+          expression = newrule.expression,
+          on_load = newrule.dynamic_symbols and multimap_on_load_gen(newrule) or nil,
+        }, N, 'Combined map for ' .. newrule.symbol)
+    if newrule.combined then
+      ret = true
+    else
+      rspamd_logger.errx(rspamd_config, 'cannot add combined map for %s', newrule.symbol)
+    end
+  elseif newrule['type'] == 'ip' then
+    load_radix_map(newrule)
+    ret = has_map_obj(newrule)
+  elseif newrule['type'] == 'received' then
+    -- A single flag given as a string becomes a one element list
+    if type(newrule['flags']) == 'string' then
+      newrule['flags'] = { newrule['flags'] }
+    end
+    if type(newrule['nflags']) == 'string' then
+      newrule['nflags'] = { newrule['nflags'] }
+    end
+
+    local filter = newrule['filter'] or 'real_ip'
+    if filter == 'real_ip' or filter == 'from_ip' then
+      load_radix_map(newrule)
+    else
+      multimap_load_kv_map(newrule)
+    end
+    ret = has_map_obj(newrule)
+  elseif known_generic_types[newrule.type] then
+    if newrule.filter == 'ip_addr' then
+      load_radix_map(newrule)
+    elseif not newrule.combined then
+      multimap_load_kv_map(newrule)
+    end
+    ret = has_map_obj(newrule)
+  elseif newrule['type'] == 'dnsbl' then
+    -- Nothing to load for dnsbl rules
+    ret = true
+  end
+
+  if not ret then
+    return nil
+  end
+
+  if newrule.map_obj and newrule.dynamic_symbols then
+    newrule.map_obj:on_load(multimap_on_load_gen(newrule))
+  end
+  if newrule['type'] == 'symbol_options' then
+    rspamd_config:register_dependency(newrule['symbol'], newrule['target_symbol'])
+  end
+  if newrule['require_symbols'] then
+    local atoms = {}
+
+    -- An atom is everything up to the first separator/operator character
+    local function parse_atom(str)
+      local function is_atom_char(c)
+        if string.find(', \t()><+!|&\n', c, 1, true) then
+          return false
+        end
+        return true
+      end
+
+      local atom = table.concat(fun.totable(fun.take_while(is_atom_char, fun.iter(str))), '')
+      table.insert(atoms, atom)
+      return atom
+    end
+
+    -- An atom evaluates to 1 when the task has the symbol, 0 otherwise
+    local function process_atom(atom, task)
+      local f_ret = task:has_symbol(atom)
+      lua_util.debugm(N, rspamd_config, 'check for symbol %s: %s', atom, f_ret)
+
+      if f_ret then
+        return 1
+      end
+
+      return 0
+    end
+
+    local expression = rspamd_expression.create(newrule['require_symbols'],
+        { parse_atom, process_atom }, rspamd_config:get_mempool())
+    if expression then
+      newrule['expression'] = expression
+
+      -- Every symbol named in the expression becomes a dependency of the rule
+      fun.each(function(v)
+        lua_util.debugm(N, rspamd_config, 'add dependency %s -> %s',
+            newrule['symbol'], v)
+        rspamd_config:register_dependency(newrule['symbol'], v)
+      end, atoms)
+    end
+  end
+
+  return newrule
+end
+
+-- Registration
+local opts = rspamd_config:get_all_opt(N)
+if opts and type(opts) == 'table' then
+  redis_params = rspamd_parse_redis_server(N)
+
+  -- Every table option that has a `type` is treated as a rule definition
+  for k, m in pairs(opts) do
+    if type(m) == 'table' and m['type'] then
+      local rule = add_multimap_rule(k, m)
+      if rule then
+        rspamd_logger.infox(rspamd_config, 'added multimap rule: %s (%s)',
+            k, rule.type)
+        table.insert(rules, rule)
+      else
+        rspamd_logger.errx(rspamd_config, 'cannot add rule: "' .. k .. '"')
+      end
+    end
+  end
+
+  -- Register one callback symbol per accepted rule
+  for _, rule in ipairs(rules) do
+    local augmentations = {}
+
+    if rule.action then
+      table.insert(augmentations, 'passthrough')
+    end
+
+    local id = rspamd_config:register_symbol({
+      type = 'normal',
+      name = rule['symbol'],
+      augmentations = augmentations,
+      callback = gen_multimap_callback(rule),
+    })
+
+    rule.callback_id = id
+
+    if rule['symbols'] then
+      -- Find allowed symbols by this map
+      local allowed = {}
+      rule['symbols_set'] = allowed
+      fun.each(function(s)
+        rspamd_config:register_symbol({
+          type = 'virtual',
+          name = s,
+          parent = id,
+          score = tonumber(rule.score or "0") or 0, -- Default score
+        })
+        allowed[s] = 1
+      end, rule['symbols'])
+    end
+
+    if not rule.score then
+      rspamd_logger.infox(rspamd_config, 'set default score 0 for multimap rule %s', rule.symbol)
+      rule.score = 0
+    end
+
+    -- Register metric symbol (the score is always set at this point)
+    rule.name = rule.symbol
+    rule.description = rule.description or 'multimap symbol'
+    rule.group = rule.group or N
+
+    local saved_flags = rule.flags
+
+    if rule.type == 'received' and rule.flags then
+      -- XXX: hack to allow received flags/nflags
+      -- See issue #3526 on GH
+      rule.flags = nil
+    end
+
+    -- XXX: for combined maps we use trace, so flags must include one_shot to avoid scores multiplication
+    if rule.combined and not rule.flags then
+      rule.flags = 'one_shot'
+    end
+    rspamd_config:set_metric_symbol(rule)
+    -- Put back the flags the rule itself relies on
+    rule.flags = saved_flags
+  end
+
+  if #rules == 0 then
+    lua_util.disable_module(N, "config")
+  end
+end