diff options
Diffstat (limited to '')
-rw-r--r-- | lualib/lua_util.lua | 1639 |
1 files changed, 1639 insertions, 0 deletions
diff --git a/lualib/lua_util.lua b/lualib/lua_util.lua new file mode 100644 index 0000000..6964b0f --- /dev/null +++ b/lualib/lua_util.lua @@ -0,0 +1,1639 @@ +--[[ +Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_util +-- This module contains utility functions for working with Lua and/or Rspamd +--]] + +local exports = {} +local lpeg = require 'lpeg' +local rspamd_util = require "rspamd_util" +local fun = require "fun" +local lupa = require "lupa" + +local split_grammar = {} +local spaces_split_grammar +local space = lpeg.S ' \t\n\v\f\r' +local nospace = 1 - space +local ptrim = space ^ 0 * lpeg.C((space ^ 0 * nospace ^ 1) ^ 0) +local match = lpeg.match + +lupa.configure('{%', '%}', '{=', '=}', '{#', '#}', { + keep_trailing_newline = true, + autoescape = false, +}) + +lupa.filters.pbkdf = function(s) + local cr = require "rspamd_cryptobox" + return cr.pbkdf(s) +end + +local function rspamd_str_split(s, sep) + local gr + if not sep then + if not spaces_split_grammar then + local _sep = space + local elem = lpeg.C((1 - _sep) ^ 0) + local p = lpeg.Ct(elem * (_sep * elem) ^ 0) + spaces_split_grammar = p + end + + gr = spaces_split_grammar + else + gr = split_grammar[sep] + + if not gr then + local _sep + if type(sep) == 'string' then + _sep = lpeg.S(sep) -- Assume set + else + _sep = sep -- Assume lpeg object + end + local elem = lpeg.C((1 - _sep) ^ 0) + local p = lpeg.Ct(elem * (_sep * elem) ^ 0) + gr = p + split_grammar[sep] = gr + end + end + + return gr:match(s) +end + +--[[[ +-- @function lua_util.str_split(text, delimiter) +-- Splits text into a numeric table by delimiter +-- @param {string} text delimited text +-- @param {string} delimiter the delimiter +-- @return {table} numeric table containing string parts +--]] + +exports.rspamd_str_split = rspamd_str_split +exports.str_split = rspamd_str_split + +local function rspamd_str_trim(s) + return match(ptrim, s) +end +exports.rspamd_str_trim = rspamd_str_trim +--[[[ +-- @function lua_util.str_trim(text) +-- Returns a string with no trailing and leading spaces +-- @param {string} text input text +-- @return {string} string with no trailing and leading spaces +--]] +exports.str_trim = rspamd_str_trim + +--[[[ +-- @function lua_util.str_startswith(text, prefix) +-- @param {string} text +-- @param {string} prefix +-- @return {boolean} true if text starts with the specified prefix, false otherwise +--]] +exports.str_startswith = function(s, prefix) + return s:sub(1, prefix:len()) == prefix +end + +--[[[ +-- @function lua_util.str_endswith(text, suffix) +-- @param {string} text +-- @param {string} suffix +-- @return {boolean} true if text ends with the specified suffix, false otherwise +--]] +exports.str_endswith = function(s, suffix) + return s:find(suffix, -suffix:len(), true) ~= nil +end + +--[[[ +-- @function lua_util.round(number, decimalPlaces) +-- Round number to fixed number of decimal points +-- @param {number} number number to round +-- @param {number} decimalPlaces number of decimal points +-- @return {number} rounded number +--]] + +-- modified version from Robert Jay Gould http://lua-users.org/wiki/SimpleRound +exports.round = function(num, numDecimalPlaces) + local mult = 10 ^ (numDecimalPlaces or 0) + if num >= 0 then + return math.floor(num * mult + 0.5) / mult + else + return math.ceil(num * mult - 0.5) / mult + end +end + +--[[[ +-- @function lua_util.template(text, replacements) +-- Replaces values in a text template +-- Variable names can contain letters, numbers and underscores, are prefixed with `$` and may or not use curly braces. +-- @param {string} text text containing variables +-- @param {table} replacements key/value pairs for replacements +-- @return {string} string containing replaced values +-- @example +-- local goop = lua_util.template("HELLO $FOO ${BAR}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'}) +-- -- goop contains "HELLO LUA WORLD!" +--]] + +exports.template = function(tmpl, keys) + local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" } + local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit ^ 1) / keys) } + local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit ^ 1) / keys) * (lpeg.P("}") / "") } + + local template_grammar = lpeg.Cs((var + var_braced + 1) ^ 0) + + return lpeg.match(template_grammar, tmpl) +end + +local function enrich_template_with_globals(env) + local newenv = exports.shallowcopy(env) + newenv.paths = rspamd_paths + newenv.env = rspamd_env + + return newenv +end +--[[[ +-- @function lua_util.jinja_template(text, env[, skip_global_env]) +-- Replaces values in a text template according to jinja2 syntax +-- @param {string} text text containing variables +-- @param {table} replacements key/value pairs for replacements +-- @param {boolean} skip_global_env don't export Rspamd superglobals +-- @return {string} string containing replaced values +-- @example +-- lua_util.jinja_template("HELLO {{FOO}} {{BAR}}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'}) +-- "HELLO LUA WORLD!" +--]] +exports.jinja_template = function(text, env, skip_global_env) + if not skip_global_env then + env = enrich_template_with_globals(env) + end + + return lupa.expand(text, env) +end + +--[[[ +-- @function lua_util.jinja_file(filename, env[, skip_global_env]) +-- Replaces values in a text template according to jinja2 syntax +-- @param {string} filename name of file to expand +-- @param {table} replacements key/value pairs for replacements +-- @param {boolean} skip_global_env don't export Rspamd superglobals +-- @return {string} string containing replaced values +-- @example +-- lua_util.jinja_template("HELLO {{FOO}} {{BAR}}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'}) +-- "HELLO LUA WORLD!" +--]] +exports.jinja_template_file = function(filename, env, skip_global_env) + if not skip_global_env then + env = enrich_template_with_globals(env) + end + + return lupa.expand_file(filename, env) +end + +exports.remove_email_aliases = function(email_addr) + local function check_gmail_user(addr) + -- Remove all points + local no_dots_user = string.gsub(addr.user, '%.', '') + local cap, pluses = string.match(no_dots_user, '^([^%+][^%+]*)(%+.*)$') + if cap then + return cap, rspamd_str_split(pluses, '+'), nil + elseif no_dots_user ~= addr.user then + return no_dots_user, {}, nil + end + + return nil + end + + local function check_address(addr) + if addr.user then + local cap, pluses = string.match(addr.user, '^([^%+][^%+]*)(%+.*)$') + if cap then + return cap, rspamd_str_split(pluses, '+'), nil + end + end + + return nil + end + + local function set_addr(addr, new_user, new_domain) + if new_user then + addr.user = new_user + end + if new_domain then + addr.domain = new_domain + end + + if addr.domain then + addr.addr = string.format('%s@%s', addr.user, addr.domain) + else + addr.addr = string.format('%s@', addr.user) + end + + if addr.name and #addr.name > 0 then + addr.raw = string.format('"%s" <%s>', addr.name, addr.addr) + else + addr.raw = string.format('<%s>', addr.addr) + end + end + + local function check_gmail(addr) + local nu, tags, nd = check_gmail_user(addr) + + if nu then + return nu, tags, nd + end + + return nil + end + + local function check_googlemail(addr) + local nd = 'gmail.com' + local nu, tags = check_gmail_user(addr) + + if nu then + return nu, tags, nd + end + + return nil, nil, nd + end + + local specific_domains = { + ['gmail.com'] = check_gmail, + ['googlemail.com'] = check_googlemail, + } + + if email_addr then + if email_addr.domain and specific_domains[email_addr.domain] then + local nu, tags, nd = specific_domains[email_addr.domain](email_addr) + if nu or nd then + set_addr(email_addr, nu, nd) + + return nu, tags + end + else + local nu, tags, nd = check_address(email_addr) + if nu or nd then + set_addr(email_addr, nu, nd) + + return nu, tags + end + end + + return nil + end +end + +exports.is_rspamc_or_controller = function(task) + local ua = task:get_request_header('User-Agent') or '' + local pwd = task:get_request_header('Password') + local is_rspamc = false + if tostring(ua) == 'rspamc' or pwd then + is_rspamc = true + end + + return is_rspamc +end + +--[[[ +-- @function lua_util.unpack(table) +-- Converts numeric table to varargs +-- This is `unpack` on Lua 5.1/5.2/LuaJIT and `table.unpack` on Lua 5.3 +-- @param {table} table numerically indexed table to unpack +-- @return {varargs} unpacked table elements +--]] + +local unpack_function = table.unpack or unpack +exports.unpack = function(t) + return unpack_function(t) +end + +--[[[ +-- @function lua_util.flatten(table) +-- Flatten underlying tables in a single table +-- @param {table} table table of tables +-- @return {table} flattened table +--]] +exports.flatten = function(t) + local res = {} + for _, e in fun.iter(t) do + for _, v in fun.iter(e) do + res[#res + 1] = v + end + end + + return res +end + +--[[[ +-- @function lua_util.spairs(table) +-- Like `pairs` but keys are sorted lexicographically +-- @param {table} table table containing key/value pairs +-- @return {function} generator function returning key/value pairs +--]] + +-- Sorted iteration: +-- for k,v in spairs(t) do ... end +-- +-- or with custom comparison: +-- for k, v in spairs(t, function(t, a, b) return t[a] < t[b] end) +-- +-- optional limit is also available (e.g. return top X elements) +local function spairs(t, order, lim) + -- collect the keys + local keys = {} + for k in pairs(t) do + keys[#keys + 1] = k + end + + -- if order function given, sort by it by passing the table and keys a, b, + -- otherwise just sort the keys + if order then + table.sort(keys, function(a, b) + return order(t, a, b) + end) + else + table.sort(keys) + end + + -- return the iterator function + local i = 0 + return function() + i = i + 1 + if not lim or i <= lim then + if keys[i] then + return keys[i], t[keys[i]] + end + end + end +end + +exports.spairs = spairs + +local lua_cfg_utils = require "lua_cfg_utils" + +exports.config_utils = lua_cfg_utils +exports.disable_module = lua_cfg_utils.disable_module + +--[[[ +-- @function lua_util.disable_module(modname) +-- Checks experimental plugins state and disable if needed +-- @param {string} modname name of plugin to check +-- @return {boolean} true if plugin should be enabled, false otherwise +--]] +local function check_experimental(modname) + if rspamd_config:experimental_enabled() then + return true + else + lua_cfg_utils.disable_module(modname, 'experimental') + end + + return false +end + +exports.check_experimental = check_experimental + +--[[[ +-- @function lua_util.list_to_hash(list) +-- Converts numerically-indexed table to table indexed by values +-- @param {table} list numerically-indexed table or string, which is treated as a one-element list +-- @return {table} table indexed by values +-- @example +-- local h = lua_util.list_to_hash({"a", "b"}) +-- -- h contains {a = true, b = true} +--]] +local function list_to_hash(list) + if type(list) == 'table' then + if list[1] then + local h = {} + for _, e in ipairs(list) do + h[e] = true + end + return h + else + return list + end + elseif type(list) == 'string' then + local h = {} + h[list] = true + return h + end +end + +exports.list_to_hash = list_to_hash + +--[[[ +-- @function lua_util.nkeys(table|gen, param, state) +-- Returns number of keys in a table (i.e. from both the array and hash parts combined) +-- @param {table} list numerically-indexed table or string, which is treated as a one-element list +-- @return {number} number of keys +-- @example +-- print(lua_util.nkeys({})) -- 0 +-- print(lua_util.nkeys({ "a", nil, "b" })) -- 2 +-- print(lua_util.nkeys({ dog = 3, cat = 4, bird = nil })) -- 2 +-- print(lua_util.nkeys({ "a", dog = 3, cat = 4 })) -- 3 +-- +--]] +local function nkeys(gen, param, state) + local n = 0 + if not param then + for _, _ in pairs(gen) do + n = n + 1 + end + else + for _, _ in fun.iter(gen, param, state) do + n = n + 1 + end + end + return n +end + +exports.nkeys = nkeys + +--[[[ +-- @function lua_util.parse_time_interval(str) +-- Parses human readable time interval +-- Accepts 's' for seconds, 'm' for minutes, 'h' for hours, 'd' for days, +-- 'w' for weeks, 'y' for years +-- @param {string} str input string +-- @return {number|nil} parsed interval as seconds (might be fractional) +--]] +local function parse_time_interval(str) + local function parse_time_suffix(s) + if s == 's' then + return 1 + elseif s == 'm' then + return 60 + elseif s == 'h' then + return 3600 + elseif s == 'd' then + return 86400 + elseif s == 'w' then + return 86400 * 7 + elseif s == 'y' then + return 365 * 86400; + end + end + + local digit = lpeg.R("09") + local parser = {} + parser.integer = (lpeg.S("+-") ^ -1) * + (digit ^ 1) + parser.fractional = (lpeg.P(".")) * + (digit ^ 1) + parser.number = (parser.integer * + (parser.fractional ^ -1)) + + (lpeg.S("+-") * parser.fractional) + parser.time = lpeg.Cf(lpeg.Cc(1) * + (parser.number / tonumber) * + ((lpeg.S("smhdwy") / parse_time_suffix) ^ -1), + function(acc, val) + return acc * val + end) + + local t = lpeg.match(parser.time, str) + + return t +end + +exports.parse_time_interval = parse_time_interval + +--[[[ +-- @function lua_util.dehumanize_number(str) +-- Parses human readable number +-- Accepts 'k' for thousands, 'm' for millions, 'g' for billions, 'b' suffix for 1024 multiplier, +-- e.g. `10mb` equal to `10 * 1024 * 1024` +-- @param {string} str input string +-- @return {number|nil} parsed number +--]] +local function dehumanize_number(str) + local function parse_suffix(s) + if s == 'k' then + return 1000 + elseif s == 'm' then + return 1000000 + elseif s == 'g' then + return 1e9 + elseif s == 'kb' then + return 1024 + elseif s == 'mb' then + return 1024 * 1024 + elseif s == 'gb' then + return 1024 * 1024; + end + end + + local digit = lpeg.R("09") + local parser = {} + parser.integer = (lpeg.S("+-") ^ -1) * + (digit ^ 1) + parser.fractional = (lpeg.P(".")) * + (digit ^ 1) + parser.number = (parser.integer * + (parser.fractional ^ -1)) + + (lpeg.S("+-") * parser.fractional) + parser.humanized_number = lpeg.Cf(lpeg.Cc(1) * + (parser.number / tonumber) * + (((lpeg.S("kmg") * (lpeg.P("b") ^ -1)) / parse_suffix) ^ -1), + function(acc, val) + return acc * val + end) + + local t = lpeg.match(parser.humanized_number, str) + + return t +end + +exports.dehumanize_number = dehumanize_number + +--[[[ +-- @function lua_util.table_cmp(t1, t2) +-- Compare two tables deeply +--]] +local function table_cmp(table1, table2) + local avoid_loops = {} + local function recurse(t1, t2) + if type(t1) ~= type(t2) then + return false + end + if type(t1) ~= "table" then + return t1 == t2 + end + + if avoid_loops[t1] then + return avoid_loops[t1] == t2 + end + avoid_loops[t1] = t2 + -- Copy keys from t2 + local t2keys = {} + local t2tablekeys = {} + for k, _ in pairs(t2) do + if type(k) == "table" then + table.insert(t2tablekeys, k) + end + t2keys[k] = true + end + -- Let's iterate keys from t1 + for k1, v1 in pairs(t1) do + local v2 = t2[k1] + if type(k1) == "table" then + -- if key is a table, we need to find an equivalent one. + local ok = false + for i, tk in ipairs(t2tablekeys) do + if table_cmp(k1, tk) and recurse(v1, t2[tk]) then + table.remove(t2tablekeys, i) + t2keys[tk] = nil + ok = true + break + end + end + if not ok then + return false + end + else + -- t1 has a key which t2 doesn't have, fail. + if v2 == nil then + return false + end + t2keys[k1] = nil + if not recurse(v1, v2) then + return false + end + end + end + -- if t2 has a key which t1 doesn't have, fail. + if next(t2keys) then + return false + end + return true + end + return recurse(table1, table2) +end + +exports.table_cmp = table_cmp + +--[[[ +-- @function lua_util.table_merge(t1, t2) +-- Merge two tables +--]] +local function table_merge(t1, t2) + local res = {} + local nidx = 1 -- for numeric indicies + local it_func = function(k, v) + if type(k) == 'number' then + res[nidx] = v + nidx = nidx + 1 + else + res[k] = v + end + end + for k, v in pairs(t1) do + it_func(k, v) + end + for k, v in pairs(t2) do + it_func(k, v) + end + return res +end + +exports.table_merge = table_merge + +--[[[ +-- @function lua_util.table_cmp(task, name, value, stop_chars) +-- Performs header folding +--]] +exports.fold_header = function(task, name, value, stop_chars) + + local how + + if task:has_flag("milter") then + how = "lf" + else + how = task:get_newlines_type() + end + + return rspamd_util.fold_header(name, value, how, stop_chars) +end + +--[[[ +-- @function lua_util.override_defaults(defaults, override) +-- Overrides values from defaults with override +--]] +local function override_defaults(def, override) + -- Corner cases + if not override or type(override) ~= 'table' then + return def + end + if not def or type(def) ~= 'table' then + return override + end + + local res = {} + + for k, v in pairs(override) do + if type(v) == 'table' then + if def[k] and type(def[k]) == 'table' then + -- Recursively override elements + res[k] = override_defaults(def[k], v) + else + res[k] = v + end + else + res[k] = v + end + end + + for k, v in pairs(def) do + if type(res[k]) == 'nil' then + res[k] = v + end + end + + return res +end + +exports.override_defaults = override_defaults + +--[[[ +-- @function lua_util.filter_specific_urls(urls, params) +-- params: { +- - task - if needed to save in the cache +- - limit <int> (default = 9999) +- - esld_limit <int> (default = 9999) n domains per eSLD (effective second level domain) + works only if number of unique eSLD less than `limit` +- - need_emails <bool> (default = false) +- - filter <callback> (default = nil) +- - prefix <string> cache prefix (default = nil) +-- } +-- Apply heuristic in extracting of urls from `urls` table, this function +-- tries its best to extract specific number of urls from a task based on +-- their characteristics +--]] +exports.filter_specific_urls = function(urls, params) + local cache_key + + if params.task and not params.no_cache then + if params.prefix then + cache_key = params.prefix + else + cache_key = string.format('sp_urls_%d%s%s%s', params.limit, + tostring(params.need_emails or false), + tostring(params.need_images or false), + tostring(params.need_content or false)) + end + local cached = params.task:cache_get(cache_key) + + if cached then + return cached + end + end + + if not urls then + return {} + end + + if params.filter then + urls = fun.totable(fun.filter(params.filter, urls)) + end + + -- Filter by tld: + local tlds = {} + local eslds = {} + local ntlds, neslds = 0, 0 + + local res = {} + local nres = 0 + + local function insert_url(str, u) + if not res[str] then + res[str] = u + nres = nres + 1 + + return true + end + + return false + end + + local function process_single_url(u, default_priority) + local priority = default_priority or 1 -- Normal priority + local flags = u:get_flags() + if params.ignore_ip and flags.numeric then + return + end + + if flags.redirected then + local redir = u:get_redirected() -- get the real url + + if params.ignore_redirected then + -- Replace `u` with redir + u = redir + priority = 2 + else + -- Process both redirected url and the original one + process_single_url(redir, 2) + end + end + + if flags.image then + if not params.need_images then + -- Ignore url + return + else + -- Penalise images in urls + priority = 0 + end + end + + local esld = u:get_tld() + local str_hash = tostring(u) + + if esld then + -- Special cases + if (u:get_protocol() ~= 'mailto') and (not flags.html_displayed) then + if flags.obscured then + priority = 3 + else + if (flags.has_user or flags.has_port) then + priority = 2 + elseif (flags.subject or flags.phished) then + priority = 2 + end + end + elseif flags.html_displayed then + priority = 0 + end + + if not eslds[esld] then + eslds[esld] = { { str_hash, u, priority } } + neslds = neslds + 1 + else + if #eslds[esld] < params.esld_limit then + table.insert(eslds[esld], { str_hash, u, priority }) + end + end + + + -- eSLD - 1 part => tld + local parts = rspamd_str_split(esld, '.') + local tld = table.concat(fun.totable(fun.tail(parts)), '.') + + if not tlds[tld] then + tlds[tld] = { { str_hash, u, priority } } + ntlds = ntlds + 1 + else + table.insert(tlds[tld], { str_hash, u, priority }) + end + end + end + + for _, u in ipairs(urls) do + process_single_url(u) + end + + local limit = params.limit + limit = limit - nres + if limit < 0 then + limit = 0 + end + + if limit == 0 then + res = exports.values(res) + if params.task and not params.no_cache then + params.task:cache_set(cache_key, res) + end + return res + end + + -- Sort eSLDs and tlds + local function sort_stuff(tbl) + -- Sort according to max priority + table.sort(tbl, function(e1, e2) + -- Sort by priority so max priority is at the end + table.sort(e1, function(tr1, tr2) + return tr1[3] < tr2[3] + end) + table.sort(e2, function(tr1, tr2) + return tr1[3] < tr2[3] + end) + + if e1[#e1][3] ~= e2[#e2][3] then + -- Sort by priority so max priority is at the beginning + return e1[#e1][3] > e2[#e2][3] + else + -- Prefer less urls to more urls per esld + return #e1 < #e2 + end + + end) + + return tbl + end + + eslds = sort_stuff(exports.values(eslds)) + neslds = #eslds + + if neslds <= limit then + -- Number of eslds < limit + repeat + local item_found = false + + for _, lurls in ipairs(eslds) do + if #lurls > 0 then + local last = table.remove(lurls) + insert_url(last[1], last[2]) + limit = limit - 1 + item_found = true + end + end + + until limit <= 0 or not item_found + + res = exports.values(res) + if params.task and not params.no_cache then + params.task:cache_set(cache_key, res) + end + return res + end + + tlds = sort_stuff(exports.values(tlds)) + ntlds = #tlds + + -- Number of tlds < limit + while limit > 0 do + for _, lurls in ipairs(tlds) do + if #lurls > 0 then + local last = table.remove(lurls) + insert_url(last[1], last[2]) + limit = limit - 1 + end + if limit == 0 then + break + end + end + end + + res = exports.values(res) + if params.task and not params.no_cache then + params.task:cache_set(cache_key, res) + end + return res +end + +--[[[ +-- @function lua_util.extract_specific_urls(params) +-- params: { +- - task +- - limit <int> (default = 9999) +- - esld_limit <int> (default = 9999) n domains per eSLD (effective second level domain) + works only if number of unique eSLD less than `limit` +- - need_emails <bool> (default = false) +- - filter <callback> (default = nil) +- - prefix <string> cache prefix (default = nil) +- - ignore_redirected <bool> (default = false) +- - need_images <bool> (default = false) +- - need_content <bool> (default = false) +-- } +-- Apply heuristic in extracting of urls from task, this function +-- tries its best to extract specific number of urls from a task based on +-- their characteristics +--]] +-- exports.extract_specific_urls = function(params_or_task, limit, need_emails, filter, prefix) +exports.extract_specific_urls = function(params_or_task, lim, need_emails, filter, prefix) + local default_params = { + limit = 9999, + esld_limit = 9999, + need_emails = false, + need_images = false, + need_content = false, + filter = nil, + prefix = nil, + ignore_ip = false, + ignore_redirected = false, + no_cache = false, + } + + local params + if type(params_or_task) == 'table' and type(lim) == 'nil' then + params = params_or_task + else + -- Deprecated call + params = { + task = params_or_task, + limit = lim, + need_emails = need_emails, + filter = filter, + prefix = prefix + } + end + for k, v in pairs(default_params) do + if type(params[k]) == 'nil' and v ~= nil then + params[k] = v + end + end + local url_params = { + emails = params.need_emails, + images = params.need_images, + content = params.need_content, + flags = params.flags, -- maybe nil + flags_mode = params.flags_mode, -- maybe nil + } + + -- Shortcut for cached stuff + if params.task and not params.no_cache then + local cache_key + if params.prefix then + cache_key = params.prefix + else + local cache_key_suffix + if params.flags then + cache_key_suffix = table.concat(params.flags) .. (params.flags_mode or '') + else + cache_key_suffix = string.format('%s%s%s', + tostring(params.need_emails or false), + tostring(params.need_images or false), + tostring(params.need_content or false)) + end + cache_key = string.format('sp_urls_%d%s', params.limit, cache_key_suffix) + end + local cached = params.task:cache_get(cache_key) + + if cached then + return cached + end + end + + -- No cache version + local urls = params.task:get_urls(url_params) + + return exports.filter_specific_urls(urls, params) +end + +--[[[ +-- @function lua_util.deepcopy(table) +-- params: { +- - table +-- } +-- Performs deep copy of the table. Including metatables +--]] +local function deepcopy(orig) + local orig_type = type(orig) + local copy + if orig_type == 'table' then + copy = {} + for orig_key, orig_value in next, orig, nil do + copy[deepcopy(orig_key)] = deepcopy(orig_value) + end + if getmetatable(orig) then + setmetatable(copy, deepcopy(getmetatable(orig))) + end + else + -- number, string, boolean, etc + copy = orig + end + return copy +end + +exports.deepcopy = deepcopy + +--[[[ +-- @function lua_util.deepsort(table) +-- params: { +- - table +-- } +-- Performs recursive in-place sort of a table +--]] +local function default_sort_cmp(e1, e2) + if type(e1) == type(e2) then + return e1 < e2 + else + return type(e1) < type(e2) + end +end + +local function deepsort(tbl, sort_func) + local orig_type = type(tbl) + if orig_type == 'table' then + table.sort(tbl, sort_func or default_sort_cmp) + for _, orig_value in next, tbl, nil do + deepsort(orig_value) + end + end +end + +exports.deepsort = deepsort + +--[[[ +-- @function lua_util.shallowcopy(tbl) +-- Performs shallow (and fast) copy of a table or another Lua type +--]] +exports.shallowcopy = function(orig) + local orig_type = type(orig) + local copy + if orig_type == 'table' then + copy = {} + for orig_key, orig_value in pairs(orig) do + copy[orig_key] = orig_value + end + else + copy = orig + end + return copy +end + +-- Debugging support +local logger = require "rspamd_logger" +local unconditional_debug = logger.log_level() == 'debug' +local debug_modules = {} +local debug_aliases = {} +local log_level = 384 -- debug + forced (1 << 7 | 1 << 8) + + +exports.init_debug_logging = function(config) + -- Fill debug modules from the config + if not unconditional_debug then + local log_config = config:get_all_opt('logging') + if log_config then + local log_level_str = log_config.level + if log_level_str then + if log_level_str == 'debug' then + unconditional_debug = true + end + end + if log_config.debug_modules then + for _, m in ipairs(log_config.debug_modules) do + debug_modules[m] = true + logger.infox(config, 'enable debug for Lua module %s', m) + end + end + + if #debug_aliases > 0 then + for alias, mod in pairs(debug_aliases) do + if debug_modules[mod] then + debug_modules[alias] = true + logger.infox(config, 'enable debug for Lua module %s (%s aliased)', + alias, mod) + end + end + end + end + end +end + +exports.enable_debug_logging = function() + unconditional_debug = true +end + +exports.enable_debug_modules = function(...) + for _, m in ipairs({ ... }) do + debug_modules[m] = true + end +end + +exports.disable_debug_logging = function() + unconditional_debug = false +end + +--[[[ +-- @function lua_util.debugm(module, [log_object], format, ...) +-- Performs fast debug log for a specific module +--]] +exports.debugm = function(mod, obj_or_fmt, fmt_or_something, ...) + if unconditional_debug or debug_modules[mod] then + if type(obj_or_fmt) == 'string' then + logger.logx(log_level, mod, '', 2, obj_or_fmt, fmt_or_something, ...) + else + logger.logx(log_level, mod, obj_or_fmt, 2, fmt_or_something, ...) + end + end +end + +--[[[ +-- @function lua_util.add_debug_alias(mod, alias) +-- Add debugging alias so logging to `alias` will be treated as logging to `mod` +--]] +exports.add_debug_alias = function(mod, alias) + debug_aliases[alias] = mod + + if debug_modules[mod] then + debug_modules[alias] = true + logger.infox(rspamd_config, 'enable debug for Lua module %s (%s aliased)', + alias, mod) + end +end +---[[[ +-- @function lua_util.get_task_verdict(task) +-- Returns verdict for a task + score if certain, must be called from idempotent filters only +-- Returns string: +-- * `spam`: if message have over reject threshold and has more than one positive rule +-- * `junk`: if a message has between score between [add_header/rewrite subject] to reject thresholds and has more than two positive rules +-- * `passthrough`: if a message has been passed through some short-circuit rule +-- * `ham`: if a message has overall score below junk level **and** more than three negative rule, or negative total score +-- * `uncertain`: all other cases +--]] +exports.get_task_verdict = function(task) + local lua_verdict = require "lua_verdict" + + return lua_verdict.get_default_verdict(task) +end + +---[[[ +-- @function lua_util.maybe_obfuscate_string(subject, settings, prefix) +-- Obfuscate string if enabled in settings. Also checks utf8 validity - if +-- string is not valid utf8 then '???' is returned. Empty string returned as is. +-- Supported settings: +-- * <prefix>_privacy = false - subject privacy is off +-- * <prefix>_privacy_alg = 'blake2' - default hash-algorithm to obfuscate subject +-- * <prefix>_privacy_prefix = 'obf' - prefix to show it's obfuscated +-- * <prefix>_privacy_length = 16 - cut the length of the hash; if 0 or fasle full hash is returned +-- @return obfuscated or validated subject +--]] + +exports.maybe_obfuscate_string = function(subject, settings, prefix) + local hash = require 'rspamd_cryptobox_hash' + if not subject or subject == '' then + return subject + elseif not rspamd_util.is_valid_utf8(subject) then + subject = '???' + elseif settings[prefix .. '_privacy'] then + local hash_alg = settings[prefix .. '_privacy_alg'] or 'blake2' + local subject_hash = hash.create_specific(hash_alg, subject) + + local strip_len = settings[prefix .. '_privacy_length'] + if strip_len and strip_len > 0 then + subject = subject_hash:hex():sub(1, strip_len) + else + subject = subject_hash:hex() + end + + local privacy_prefix = settings[prefix .. '_privacy_prefix'] + if privacy_prefix and #privacy_prefix > 0 then + subject = privacy_prefix .. ':' .. subject + end + end + + return subject +end + +---[[[ +-- @function lua_util.callback_from_string(str) +-- Converts a string like `return function(...) end` to lua function and return true and this function +-- or returns false + error message +-- @return status code and function object or an error message +--]]] +exports.callback_from_string = function(s) + local loadstring = loadstring or load + + if not s or #s == 0 then + return false, 'invalid or empty string' + end + + s = exports.rspamd_str_trim(s) + local inp + + if s:match('^return%s*function') then + -- 'return function', can be evaluated directly + inp = s + elseif s:match('^function%s*%(') then + inp = 'return ' .. s + else + -- Just a plain sequence + inp = 'return function(...)\n' .. s .. '; end' + end + + local ret, res_or_err = pcall(loadstring(inp)) + + if not ret or type(res_or_err) ~= 'function' then + return false, res_or_err + end + + return ret, res_or_err +end + +---[[[ +-- @function lua_util.keys(t) +-- Returns all keys from a specific table +-- @param {table} t input table (or iterator triplet) +-- @return array of keys +--]]] +exports.keys = function(gen, param, state) + local keys = {} + local i = 1 + + if param then + for k, _ in fun.iter(gen, param, state) do + rawset(keys, i, k) + i = i + 1 + end + else + for k, _ in pairs(gen) do + rawset(keys, i, k) + i = i + 1 + end + end + + return keys +end + +---[[[ +-- @function lua_util.values(t) +-- Returns all values from a specific table +-- @param {table} t input table +-- @return array of values +--]]] +exports.values = function(gen, param, state) + local values = {} + local i = 1 + + if param then + for _, v in fun.iter(gen, param, state) do + rawset(values, i, v) + i = i + 1 + end + else + for _, v in pairs(gen) do + rawset(values, i, v) + i = i + 1 + end + end + + return values +end + +---[[[ +-- @function lua_util.distance_sorted(t1, t2) +-- Returns distance between two sorted tables t1 and t2 +-- @param {table} t1 input table +-- @param {table} t2 input table +-- @return distance between `t1` and `t2` +--]]] +exports.distance_sorted = function(t1, t2) + local ncomp = #t1 + local ndiff = 0 + local i, j = 1, 1 + + if ncomp < #t2 then + ncomp = #t2 + end + + for _ = 1, ncomp do + if j > #t2 then + ndiff = ndiff + ncomp - #t2 + if i > j then + ndiff = ndiff - (i - j) + end + break + elseif i > #t1 then + ndiff = ndiff + ncomp - #t1 + if j > i then + ndiff = ndiff - (j - i) + end + break + end + + if t1[i] == t2[j] then + i = i + 1 + j = j + 1 + elseif t1[i] < t2[j] then + i = i + 1 + ndiff = ndiff + 1 + else + j = j + 1 + ndiff = ndiff + 1 + end + end + + return ndiff +end + +---[[[ +-- @function lua_util.table_digest(t) +-- Returns hash of all values if t[1] is string or all keys/values otherwise +-- @param {table} t input array or map +-- @return {string} base32 representation of blake2b hash of all strings +--]]] +local function table_digest(t) + local cr = require "rspamd_cryptobox_hash" + local h = cr.create() + + if t[1] then + for _, e in ipairs(t) do + if type(e) == 'table' then + h:update(table_digest(e)) + else + h:update(tostring(e)) + end + end + else + for k, v in pairs(t) do + h:update(tostring(k)) + + if type(v) == 'string' then + h:update(v) + elseif type(v) == 'table' then + h:update(table_digest(v)) + end + end + end + return h:base32() +end + +exports.table_digest = table_digest + +---[[[ +-- @function lua_util.toboolean(v) +-- Converts a string or a number to boolean +-- @param {string|number} v +-- @return {boolean} v converted to boolean +--]]] +exports.toboolean = function(v) + local true_t = { + ['1'] = true, + ['true'] = true, + ['TRUE'] = true, + ['True'] = true, + }; + local false_t = { + ['0'] = false, + ['false'] = false, + ['FALSE'] = false, + ['False'] = false, + }; + + if type(v) == 'string' then + if true_t[v] == true then + return true; + elseif false_t[v] == false then + return false; + else + return false, string.format('cannot convert %q to boolean', v); + end + elseif type(v) == 'number' then + return v ~= 0 + else + return false, string.format('cannot convert %q to boolean', v); + end +end + +---[[[ +-- @function lua_util.config_check_local_or_authed(config, modname) +-- Reads check_local and check_authed from the config as this is used in many modules +-- @param {rspamd_config} config `rspamd_config` global +-- @param {name} module name +-- @return {boolean} v converted to boolean +--]]] +exports.config_check_local_or_authed = function(rspamd_config, modname, def_local, def_authed) + local check_local = def_local or false + local check_authed = def_authed or false + + local function try_section(where) + local ret = false + local opts = rspamd_config:get_all_opt(where) + if type(opts) == 'table' then + if type(opts['check_local']) == 'boolean' then + check_local = opts['check_local'] + ret = true + end + if type(opts['check_authed']) == 'boolean' then + check_authed = opts['check_authed'] + ret = true + end + end + + return ret + end + + if not try_section(modname) then + try_section('options') + end + + return { check_local, check_authed } +end + +---[[[ +-- @function lua_util.is_skip_local_or_authed(task, conf[, ip]) +-- Returns `true` if local or authenticated task should be skipped for this module +-- @param {rspamd_task} task +-- @param {table} conf table returned from `config_check_local_or_authed` +-- @param {rspamd_ip} ip optional ip address (can be obtained from a task) +-- @return {boolean} true if check should be skipped +--]]] +exports.is_skip_local_or_authed = function(task, conf, ip) + if not ip then + ip = task:get_from_ip() + end + if not conf then + conf = { false, false } + end + if ((not conf[2] and task:get_user()) or + (not conf[1] and type(ip) == 'userdata' and ip:is_local())) then + return true + end + + return false +end + +---[[[ +-- @function lua_util.maybe_smtp_quote_value(str) +-- Checks string for the forbidden elements (tspecials in RFC and quote string if needed) +-- @param {string} str input string +-- @return {string} original or quoted string +--]]] +local tspecial = lpeg.S "()<>,;:\\\"/[]?= \t\v" +local special_match = lpeg.P((1 - tspecial) ^ 0 * tspecial ^ 1) +exports.maybe_smtp_quote_value = function(str) + if special_match:match(str) then + return string.format('"%s"', str:gsub('"', '\\"')) + end + + return str +end + +---[[[ +-- @function lua_util.shuffle(table) +-- Performs in-place shuffling of a table +-- @param {table} tbl table to shuffle +-- @return {table} same table +--]]] +exports.shuffle = function(tbl) + local size = #tbl + for i = size, 1, -1 do + local rand = math.random(size) + tbl[i], tbl[rand] = tbl[rand], tbl[i] + end + return tbl +end + +-- +local hex_table = {} +for idx = 0, 255 do + hex_table[("%02X"):format(idx)] = string.char(idx) + hex_table[("%02x"):format(idx)] = string.char(idx) +end + +---[[[ +-- @function lua_util.unhex(str) +-- Decode hex encoded string +-- @param {string} str string to decode +-- @return {string} hex decoded string (valid hex pairs are decoded, everything else is printed as is) +--]]] +exports.unhex = function(str) + return str:gsub('(..)', hex_table) +end + +local http_upstream_lists = {} +local function http_upstreams_by_url(pool, url) + local rspamd_url = require "rspamd_url" + + local cached = http_upstream_lists[url] + if cached then + return cached + end + + local real_url = rspamd_url.create(pool, url) + + if not real_url then + return nil + end + + local host = real_url:get_host() + local proto = real_url:get_protocol() or 'http' + local port = real_url:get_port() or (proto == 'https' and 443 or 80) + local upstream_list = require "rspamd_upstream_list" + local upstreams = upstream_list.create(host, port) + + if upstreams then + http_upstream_lists[url] = upstreams + return upstreams + end + + return nil +end +---[[[ +-- @function lua_util.http_upstreams_by_url(pool, url) +-- Returns a cached or new upstreams list that corresponds to the specific url +-- @param {mempool} pool memory pool to use (typically static pool from rspamd_config) +-- @param {string} url full url +-- @return {upstreams_list} object to get upstream from an url +--]]] +exports.http_upstreams_by_url = http_upstreams_by_url + +---[[[ +-- @function lua_util.dns_timeout_augmentation(cfg) +-- Returns an augmentation suitable to define DNS timeout for a module +-- @return {string} a string in format 'timeout=x' where `x` is a number of seconds for DNS timeout +--]]] +local function dns_timeout_augmentation(cfg) + return string.format('timeout=%f', cfg:get_dns_timeout() or 0.0) +end + +exports.dns_timeout_augmentation = dns_timeout_augmentation + +---[[[ +--- @function lua_util.strip_lua_comments(lua_code) +-- Strips single-line and multi-line comments from a given Lua code string and removes +-- any extra spaces or newlines. +-- +-- @param lua_code The Lua code string to strip comments from. +-- @return The resulting Lua code string with comments and extra spaces removed. +-- +---]]] +local function strip_lua_comments(lua_code) + -- Remove single-line comments + lua_code = lua_code:gsub("%-%-[^\r\n]*", "") + + -- Remove multi-line comments + lua_code = lua_code:gsub("%-%-%[%[.-%]%]", "") + + -- Remove extra spaces and newlines + lua_code = lua_code:gsub("%s+", " ") + + return lua_code +end + +exports.strip_lua_comments = strip_lua_comments + +---[[[ +-- @function lua_util.join_path(...) +-- Joins path components into a single path string using the appropriate separator +-- for the current operating system. +-- +-- @param ... Any number of path components to join together. +-- @return A single path string, with components separated by the appropriate separator. +-- +---]]] +local path_sep = package.config:sub(1, 1) or '/' +local function join_path(...) + local components = { ... } + + -- Join components using separator + return table.concat(components, path_sep) +end +exports.join_path = join_path + +-- Short unit test for sanity +if path_sep == '/' then + assert(join_path('/path', 'to', 'file') == '/path/to/file') +else + assert(join_path('C:', 'path', 'to', 'file') == 'C:\\path\\to\\file') +end + +-- Defines symbols priorities for common usage in prefilters/postfilters +exports.symbols_priorities = { + top = 10, -- Symbols must be executed first (or last), such as settings + high = 9, -- Example: asn + medium = 5, -- Everything should use this as default + low = 0, +} + +return exports |