Diffstat (limited to 'lualib')
99 files changed, 32642 insertions, 0 deletions
diff --git a/lualib/ansicolors.lua b/lualib/ansicolors.lua
new file mode 100644
index 0000000..81783f6
--- /dev/null
+++ b/lualib/ansicolors.lua
@@ -0,0 +1,68 @@
+local colormt = {}
+local ansicolors = {}
+
+local rspamd_util = require "rspamd_util"
+local isatty = rspamd_util.isatty()
+
+function colormt:__tostring()
+  return self.value
+end
+
+function colormt:__concat(other)
+  return tostring(self) .. tostring(other)
+end
+
+function colormt:__call(s)
+  return self .. s .. ansicolors.reset
+end
+
+colormt.__metatable = {}
+local function makecolor(value)
+  if isatty then
+    return setmetatable({
+      value = string.char(27) .. '[' .. tostring(value) .. 'm'
+    }, colormt)
+  else
+    return setmetatable({
+      value = ''
+    }, colormt)
+  end
+end
+
+local colors = {
+  -- attributes
+  reset = 0,
+  clear = 0,
+  bright = 1,
+  dim = 2,
+  underscore = 4,
+  blink = 5,
+  reverse = 7,
+  hidden = 8,
+
+  -- foreground
+  black = 30,
+  red = 31,
+  green = 32,
+  yellow = 33,
+  blue = 34,
+  magenta = 35,
+  cyan = 36,
+  white = 37,
+
+  -- background
+  onblack = 40,
+  onred = 41,
+  ongreen = 42,
+  onyellow = 43,
+  onblue = 44,
+  onmagenta = 45,
+  oncyan = 46,
+  onwhite = 47,
+}
+
+for c, v in pairs(colors) do
+  ansicolors[c] = makecolor(v)
+end
+
+return ansicolors
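-- Usage sketch for the ansicolors module above, assuming the rspamd Lua environment
-- (the module needs `rspamd_util` and the lualib search path). Each colour entry is a
-- callable object that wraps its argument in the escape sequence and appends `reset`;
-- concatenation also works via __concat, and every sequence degrades to an empty
-- string when stdout is not a tty.
local ansicolors = require "ansicolors"

print(ansicolors.green('all rules loaded'))                          -- call form, reset appended
print(ansicolors.bright .. 'warning: slow rule' .. ansicolors.reset) -- concat form, explicit reset
print(ansicolors.onred(ansicolors.white('fatal error')))             -- background plus foreground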
\ No newline at end of file diff --git a/lualib/global_functions.lua b/lualib/global_functions.lua new file mode 100644 index 0000000..45d4e84 --- /dev/null +++ b/lualib/global_functions.lua @@ -0,0 +1,56 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local logger = require "rspamd_logger" +local lua_util = require "lua_util" +local lua_redis = require "lua_redis" +local meta_functions = require "lua_meta" +local maps = require "lua_maps" + +local exports = {} + +exports.rspamd_parse_redis_server = lua_redis.rspamd_parse_redis_server +exports.parse_redis_server = lua_redis.rspamd_parse_redis_server +exports.rspamd_redis_make_request = lua_redis.rspamd_redis_make_request +exports.redis_make_request = lua_redis.rspamd_redis_make_request + +exports.rspamd_gen_metatokens = meta_functions.rspamd_gen_metatokens +exports.rspamd_count_metatokens = meta_functions.rspamd_count_metatokens + +exports.rspamd_map_add = maps.rspamd_map_add + +exports.rspamd_str_split = lua_util.rspamd_str_split + +-- a special syntax sugar to export all functions to the global table +setmetatable(exports, { + __call = function(t, override) + for k, v in pairs(t) do + if _G[k] ~= nil then + local msg = 'function ' .. k .. ' already exists in global scope.' + if override then + _G[k] = v + logger.errx('WARNING: ' .. msg .. ' Overwritten.') + else + logger.errx('NOTICE: ' .. msg .. ' Skipped.') + end + else + _G[k] = v + end + end + end, +}) + +return exports diff --git a/lualib/lua_auth_results.lua b/lualib/lua_auth_results.lua new file mode 100644 index 0000000..8c907d9 --- /dev/null +++ b/lualib/lua_auth_results.lua @@ -0,0 +1,301 @@ +--[[ +Copyright (c) 2016, Andrew Lewis <nerf@judo.za.org> +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +local rspamd_util = require "rspamd_util" +local lua_util = require "lua_util" + +local default_settings = { + spf_symbols = { + pass = 'R_SPF_ALLOW', + fail = 'R_SPF_FAIL', + softfail = 'R_SPF_SOFTFAIL', + neutral = 'R_SPF_NEUTRAL', + temperror = 'R_SPF_DNSFAIL', + none = 'R_SPF_NA', + permerror = 'R_SPF_PERMFAIL', + }, + dmarc_symbols = { + pass = 'DMARC_POLICY_ALLOW', + permerror = 'DMARC_BAD_POLICY', + temperror = 'DMARC_DNSFAIL', + none = 'DMARC_NA', + reject = 'DMARC_POLICY_REJECT', + softfail = 'DMARC_POLICY_SOFTFAIL', + quarantine = 'DMARC_POLICY_QUARANTINE', + }, + arc_symbols = { + pass = 'ARC_ALLOW', + permerror = 'ARC_INVALID', + temperror = 'ARC_DNSFAIL', + none = 'ARC_NA', + reject = 'ARC_REJECT', + }, + dkim_symbols = { + none = 'R_DKIM_NA', + }, + add_smtp_user = true, +} + +local exports = { + default_settings = default_settings +} + +local local_hostname = rspamd_util.get_hostname() + +local function gen_auth_results(task, settings) + local auth_results, hdr_parts = {}, {} + + if not settings then + settings = default_settings + end + + local auth_types = { + dkim = settings.dkim_symbols, + dmarc = settings.dmarc_symbols, + spf = settings.spf_symbols, + arc = settings.arc_symbols, + } + + local common = { + symbols = {} + } + + local mta_hostname = task:get_request_header('MTA-Name') or + task:get_request_header('MTA-Tag') + if mta_hostname then + mta_hostname = tostring(mta_hostname) + else + mta_hostname = local_hostname + end + + table.insert(hdr_parts, mta_hostname) + + for auth_type, symbols in pairs(auth_types) do + for key, sym in pairs(symbols) do + if not common.symbols.sym then + local s = task:get_symbol(sym) + if not s then + common.symbols[sym] = false + else + common.symbols[sym] = s + if not auth_results[auth_type] then + auth_results[auth_type] = { key } + else + table.insert(auth_results[auth_type], key) + end + + if auth_type ~= 'dkim' then + break + end + end + end + end + end + + local dkim_results = task:get_dkim_results() + -- For each signature we set authentication results + -- dkim=neutral (body hash did not verify) header.d=example.com header.s=sel header.b=fA8VVvJ8; + -- dkim=neutral (body hash did not verify) header.d=example.com header.s=sel header.b=f8pM8o90; + + for _, dres in ipairs(dkim_results) do + local ar_string = 'none' + + if dres.result == 'reject' then + ar_string = 'fail' -- imply failure, not neutral + elseif dres.result == 'allow' then + ar_string = 'pass' + elseif dres.result == 'bad record' or dres.result == 'permerror' then + ar_string = 'permerror' + elseif dres.result == 'tempfail' then + ar_string = 'temperror' + end + local hdr = {} + + hdr[1] = string.format('dkim=%s', ar_string) + + if dres.fail_reason then + hdr[#hdr + 1] = string.format('(%s)', lua_util.maybe_smtp_quote_value(dres.fail_reason)) + end + + if dres.domain then + hdr[#hdr + 1] = string.format('header.d=%s', lua_util.maybe_smtp_quote_value(dres.domain)) + end + + if dres.selector then + hdr[#hdr + 1] = string.format('header.s=%s', lua_util.maybe_smtp_quote_value(dres.selector)) + end + + if dres.bhash then + hdr[#hdr + 1] = string.format('header.b=%s', lua_util.maybe_smtp_quote_value(dres.bhash)) + end + + table.insert(hdr_parts, table.concat(hdr, ' ')) + end + + if #dkim_results == 0 then + -- We have no dkim results, so check for DKIM_NA symbol + if common.symbols[settings.dkim_symbols.none] then + table.insert(hdr_parts, 'dkim=none') + end + end + + for auth_type, keys in pairs(auth_results) do + for _, key in ipairs(keys) do + local hdr = '' + 
if auth_type == 'dmarc' then + local opts = common.symbols[auth_types['dmarc'][key]][1]['options'] or {} + hdr = hdr .. 'dmarc=' + if key == 'reject' or key == 'quarantine' or key == 'softfail' then + hdr = hdr .. 'fail' + else + hdr = hdr .. lua_util.maybe_smtp_quote_value(key) + end + if key == 'pass' then + hdr = hdr .. ' (policy=' .. lua_util.maybe_smtp_quote_value(opts[2]) .. ')' + hdr = hdr .. ' header.from=' .. lua_util.maybe_smtp_quote_value(opts[1]) + elseif key ~= 'none' then + local t = { opts[1]:match('^([^%s]+) : (.*)$') } + if #t > 0 then + local dom = t[1] + local rsn = t[2] + if rsn then + hdr = string.format('%s reason=%s', hdr, lua_util.maybe_smtp_quote_value(rsn)) + end + hdr = string.format('%s header.from=%s', hdr, lua_util.maybe_smtp_quote_value(dom)) + end + if key == 'softfail' then + hdr = hdr .. ' (policy=none)' + else + hdr = hdr .. ' (policy=' .. lua_util.maybe_smtp_quote_value(key) .. ')' + end + end + table.insert(hdr_parts, hdr) + elseif auth_type == 'arc' then + if common.symbols[auth_types['arc'][key]][1] then + local opts = common.symbols[auth_types['arc'][key]][1]['options'] or {} + for _, v in ipairs(opts) do + hdr = string.format('%s%s=%s (%s)', hdr, auth_type, + lua_util.maybe_smtp_quote_value(key), lua_util.maybe_smtp_quote_value(v)) + table.insert(hdr_parts, hdr) + end + end + elseif auth_type == 'spf' then + -- Main type + local sender + local sender_type + local smtp_from = task:get_from({ 'smtp', 'orig' }) + + if smtp_from and + smtp_from[1] and + smtp_from[1]['addr'] ~= '' and + smtp_from[1]['addr'] ~= nil then + sender = lua_util.maybe_smtp_quote_value(smtp_from[1]['addr']) + sender_type = 'smtp.mailfrom' + else + local helo = task:get_helo() + if helo then + sender = lua_util.maybe_smtp_quote_value(helo) + sender_type = 'smtp.helo' + end + end + + if sender and sender_type then + -- Comment line + local comment = '' + if key == 'pass' then + comment = string.format('%s: domain of %s designates %s as permitted sender', + mta_hostname, sender, tostring(task:get_from_ip() or 'unknown')) + elseif key == 'fail' then + comment = string.format('%s: domain of %s does not designate %s as permitted sender', + mta_hostname, sender, tostring(task:get_from_ip() or 'unknown')) + elseif key == 'neutral' or key == 'softfail' then + comment = string.format('%s: %s is neither permitted nor denied by domain of %s', + mta_hostname, tostring(task:get_from_ip() or 'unknown'), sender) + elseif key == 'permerror' then + comment = string.format('%s: domain of %s uses mechanism not recognized by this client', + mta_hostname, sender) + elseif key == 'temperror' then + comment = string.format('%s: error in processing during lookup of %s: DNS error', + mta_hostname, sender) + elseif key == 'none' then + comment = string.format('%s: domain of %s has no SPF policy when checking %s', + mta_hostname, sender, tostring(task:get_from_ip() or 'unknown')) + end + hdr = string.format('%s=%s (%s) %s=%s', auth_type, key, + comment, sender_type, sender) + else + hdr = string.format('%s=%s', auth_type, key) + end + + table.insert(hdr_parts, hdr) + end + end + end + + local u = task:get_user() + local smtp_from = task:get_from({ 'smtp', 'orig' }) + + if u and smtp_from then + local hdr = { [1] = 'auth=pass' } + + if settings['add_smtp_user'] then + table.insert(hdr, 'smtp.auth=' .. lua_util.maybe_smtp_quote_value(u)) + end + if smtp_from[1]['addr'] then + table.insert(hdr, 'smtp.mailfrom=' .. 
lua_util.maybe_smtp_quote_value(smtp_from[1]['addr'])) + end + + table.insert(hdr_parts, table.concat(hdr, ' ')) + end + + if #hdr_parts > 0 then + if #hdr_parts == 1 then + hdr_parts[2] = 'none' + end + return table.concat(hdr_parts, '; ') + end + + return nil +end + +exports.gen_auth_results = gen_auth_results + +local aar_elt_grammar +-- This function parses an ar element to a table of kv pairs that represents different +-- elements +local function parse_ar_element(elt) + + if not aar_elt_grammar then + -- Generate grammar + local lpeg = require "lpeg" + local P = lpeg.P + local S = lpeg.S + local V = lpeg.V + local C = lpeg.C + local space = S(" ") ^ 0 + local doublequoted = space * P '"' * ((1 - S '"\r\n\f\\') + (P '\\' * 1)) ^ 0 * '"' * space + local comment = space * P { "(" * ((1 - S "()") + V(1)) ^ 0 * ")" } * space + local name = C((1 - S('=(" ')) ^ 1) * space + local pair = lpeg.Cg(name * "=" * space * name) * space + aar_elt_grammar = lpeg.Cf(lpeg.Ct("") * (pair + comment + doublequoted) ^ 1, rawset) + end + + return aar_elt_grammar:match(elt) +end +exports.parse_ar_element = parse_ar_element + +return exports diff --git a/lualib/lua_aws.lua b/lualib/lua_aws.lua new file mode 100644 index 0000000..e6c4b29 --- /dev/null +++ b/lualib/lua_aws.lua @@ -0,0 +1,300 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + + +--[[[ +-- @module lua_aws +-- This module contains Amazon AWS utility functions +--]] + +local N = "aws" +--local rspamd_logger = require "rspamd_logger" +local ts = (require "tableshape").types +local lua_util = require "lua_util" +local fun = require "fun" +local rspamd_crypto_hash = require "rspamd_cryptobox_hash" + +local exports = {} + +-- Returns a canonical representation of today date +local function today_canonical() + return os.date('!%Y%m%d') +end + +--[[[ +-- @function lua_aws.aws_date([date_str]) +-- Returns an aws date header corresponding to the specific date +--]] +local function aws_date(date_str) + if not date_str then + date_str = today_canonical() + end + + return date_str .. 
os.date('!T%H%M%SZ') +end + +exports.aws_date = aws_date + + +-- Local cache of the keys to save resources +local cached_keys = {} + +local function maybe_get_cached_key(date_str, secret_key, region, service, req_type) + local bucket = cached_keys[tonumber(date_str)] + + if not bucket then + return nil + end + + local elt = bucket[string.format('%s.%s.%s.%s', secret_key, region, service, req_type)] + if elt then + return elt + end +end + +local function save_cached_key(date_str, secret_key, region, service, req_type, key) + local numdate = tonumber(date_str) + -- expire old buckets + for k, _ in pairs(cached_keys) do + if k < numdate then + cached_keys[k] = nil + end + end + + local bucket = cached_keys[tonumber(date_str)] + local idx = string.format('%s.%s.%s.%s', secret_key, region, service, req_type) + + if not bucket then + cached_keys[tonumber(date_str)] = { + idx = key + } + else + bucket[idx] = key + end +end +--[[[ +-- @function lua_aws.aws_signing_key([date_str], secret_key, region, [service='s3'], [req_type='aws4_request']) +-- Returns a signing key for the specific parameters +--]] +local function aws_signing_key(date_str, secret_key, region, service, req_type) + if not date_str then + date_str = today_canonical() + end + + if not service then + service = 's3' + end + + if not req_type then + req_type = 'aws4_request' + end + + assert(type(secret_key) == 'string') + assert(type(region) == 'string') + + local maybe_cached = maybe_get_cached_key(date_str, secret_key, region, service, req_type) + + if maybe_cached then + return maybe_cached + end + + local hmac1 = rspamd_crypto_hash.create_specific_keyed("AWS4" .. secret_key, "sha256", date_str):bin() + local hmac2 = rspamd_crypto_hash.create_specific_keyed(hmac1, "sha256", region):bin() + local hmac3 = rspamd_crypto_hash.create_specific_keyed(hmac2, "sha256", service):bin() + local final_key = rspamd_crypto_hash.create_specific_keyed(hmac3, "sha256", req_type):bin() + + save_cached_key(date_str, secret_key, region, service, req_type, final_key) + + return final_key +end + +exports.aws_signing_key = aws_signing_key + +--[[[ +-- @function lua_aws.aws_canon_request_hash(method, path, headers_to_sign, hex_hash) +-- Returns a hash + list of headers as required to produce signature afterwards +--]] +local function aws_canon_request_hash(method, uri, headers_to_sign, hex_hash) + assert(type(method) == 'string') + assert(type(uri) == 'string') + assert(type(headers_to_sign) == 'table') + + if not hex_hash then + hex_hash = headers_to_sign['x-amz-content-sha256'] + end + + assert(type(hex_hash) == 'string') + + local sha_ctx = rspamd_crypto_hash.create_specific('sha256') + + lua_util.debugm(N, 'update signature with the method %s', + method) + sha_ctx:update(method .. '\n') + lua_util.debugm(N, 'update signature with the uri %s', + uri) + sha_ctx:update(uri .. 
'\n') + -- XXX add query string canonicalisation + sha_ctx:update('\n') + -- Sort auth headers and canonicalise them as requested + local hdr_canon = fun.tomap(fun.map(function(k, v) + return k:lower(), lua_util.str_trim(v) + end, headers_to_sign)) + local header_names = lua_util.keys(hdr_canon) + table.sort(header_names) + for _, hn in ipairs(header_names) do + local v = hdr_canon[hn] + lua_util.debugm(N, 'update signature with the header %s, %s', + hn, v) + sha_ctx:update(string.format('%s:%s\n', hn, v)) + end + local hdrs_list = table.concat(header_names, ';') + lua_util.debugm(N, 'headers list to sign: %s', hdrs_list) + sha_ctx:update(string.format('\n%s\n%s', hdrs_list, hex_hash)) + + return sha_ctx:hex(), hdrs_list +end + +exports.aws_canon_request_hash = aws_canon_request_hash + +local aws_authorization_hdr_args_schema = ts.shape { + date = ts.string + ts['nil'] / today_canonical, + secret_key = ts.string, + method = ts.string + ts['nil'] / function() + return 'GET' + end, + uri = ts.string, + region = ts.string, + service = ts.string + ts['nil'] / function() + return 's3' + end, + req_type = ts.string + ts['nil'] / function() + return 'aws4_request' + end, + headers = ts.map_of(ts.string, ts.string), + key_id = ts.string, +} +--[[[ +-- @function lua_aws.aws_authorization_hdr(params) +-- Produces an authorization header as required by AWS +-- Parameters schema is the following: +ts.shape{ + date = ts.string + ts['nil'] / today_canonical, + secret_key = ts.string, + method = ts.string + ts['nil'] / function() return 'GET' end, + uri = ts.string, + region = ts.string, + service = ts.string + ts['nil'] / function() return 's3' end, + req_type = ts.string + ts['nil'] / function() return 'aws4_request' end, + headers = ts.map_of(ts.string, ts.string), + key_id = ts.string, +} +-- +--]] +local function aws_authorization_hdr(tbl, transformed) + local res, err + if not transformed then + res, err = aws_authorization_hdr_args_schema:transform(tbl) + assert(res, err) + else + res = tbl + end + + local signing_key = aws_signing_key(res.date, res.secret_key, res.region, res.service, + res.req_type) + assert(signing_key ~= nil) + local signed_sha, signed_hdrs = aws_canon_request_hash(res.method, res.uri, + res.headers) + + if not signed_sha then + return nil + end + + local string_to_sign = string.format('AWS4-HMAC-SHA256\n%s\n%s/%s/%s/%s\n%s', + res.headers['x-amz-date'] or aws_date(), + res.date, res.region, res.service, res.req_type, + signed_sha) + lua_util.debugm(N, "string to sign: %s", string_to_sign) + local hmac = rspamd_crypto_hash.create_specific_keyed(signing_key, 'sha256', string_to_sign):hex() + lua_util.debugm(N, "hmac: %s", hmac) + local auth_hdr = string.format('AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/%s,' .. 
+ 'SignedHeaders=%s,Signature=%s', + res.key_id, res.date, res.region, res.service, res.req_type, + signed_hdrs, hmac) + + return auth_hdr +end + +exports.aws_authorization_hdr = aws_authorization_hdr + + + +--[[[ +-- @function lua_aws.aws_request_enrich(params, content) +-- Produces an authorization header as required by AWS +-- Parameters schema is the following: +ts.shape{ + date = ts.string + ts['nil'] / today_canonical, + secret_key = ts.string, + method = ts.string + ts['nil'] / function() return 'GET' end, + uri = ts.string, + region = ts.string, + service = ts.string + ts['nil'] / function() return 's3' end, + req_type = ts.string + ts['nil'] / function() return 'aws4_request' end, + headers = ts.map_of(ts.string, ts.string), + key_id = ts.string, +} +This method returns new/modified in place table of the headers +-- +--]] +local function aws_request_enrich(tbl, content) + local res, err = aws_authorization_hdr_args_schema:transform(tbl) + assert(res, err) + local content_sha256 = rspamd_crypto_hash.create_specific('sha256', content):hex() + local hdrs = res.headers + hdrs['x-amz-content-sha256'] = content_sha256 + if not hdrs['x-amz-date'] then + hdrs['x-amz-date'] = aws_date(res.date) + end + hdrs['Authorization'] = aws_authorization_hdr(res, true) + + return hdrs +end + +exports.aws_request_enrich = aws_request_enrich + +-- A simple tests according to AWS docs to check sanity +local test_request_hdrs = { + ['Host'] = 'examplebucket.s3.amazonaws.com', + ['x-amz-date'] = '20130524T000000Z', + ['Range'] = 'bytes=0-9', + ['x-amz-content-sha256'] = 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855', +} + +assert(aws_canon_request_hash('GET', '/test.txt', test_request_hdrs) == + '7344ae5b7ee6c3e7e6b0fe0640412a37625d1fbfff95c48bbb2dc43964946972') + +assert(aws_authorization_hdr { + date = '20130524', + region = 'us-east-1', + headers = test_request_hdrs, + uri = '/test.txt', + key_id = 'AKIAIOSFODNN7EXAMPLE', + secret_key = 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY', +} == 'AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/s3/aws4_request,' .. + 'SignedHeaders=host;range;x-amz-content-sha256;x-amz-date,' .. + 'Signature=f0e8bdb87c964420e857bd35b5d6ed310bd44f0170aba48dd91039c6036bdb41') + +return exports diff --git a/lualib/lua_bayes_learn.lua b/lualib/lua_bayes_learn.lua new file mode 100644 index 0000000..ea97db6 --- /dev/null +++ b/lualib/lua_bayes_learn.lua @@ -0,0 +1,151 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +-- This file contains functions to simplify bayes classifier auto-learning + +local lua_util = require "lua_util" +local lua_verdict = require "lua_verdict" +local N = "lua_bayes" + +local exports = {} + +exports.can_learn = function(task, is_spam, is_unlearn) + local learn_type = task:get_request_header('Learn-Type') + + if not (learn_type and tostring(learn_type) == 'bulk') then + local prob = task:get_mempool():get_variable('bayes_prob', 'double') + + if prob then + local in_class = false + local cl + if is_spam then + cl = 'spam' + in_class = prob >= 0.95 + else + cl = 'ham' + in_class = prob <= 0.05 + end + + if in_class then + return false, string.format( + 'already in class %s; probability %.2f%%', + cl, math.abs((prob - 0.5) * 200.0)) + end + end + end + + return true +end + +exports.autolearn = function(task, conf) + local function log_can_autolearn(verdict, score, threshold) + local from = task:get_from('smtp') + local mime_rcpts = 'undef' + local mr = task:get_recipients('mime') + if mr then + for _, r in ipairs(mr) do + if mime_rcpts == 'undef' then + mime_rcpts = r.addr + else + mime_rcpts = mime_rcpts .. ',' .. r.addr + end + end + end + + lua_util.debugm(N, task, 'id: %s, from: <%s>: can autolearn %s: score %s %s %s, mime_rcpts: <%s>', + task:get_header('Message-Id') or '<undef>', + from and from[1].addr or 'undef', + verdict, + string.format("%.2f", score), + verdict == 'ham' and '<=' or verdict == 'spam' and '>=' or '/', + threshold, + mime_rcpts) + end + + -- We have autolearn config so let's figure out what is requested + local verdict, score = lua_verdict.get_specific_verdict("bayes", task) + local learn_spam, learn_ham = false, false + + if verdict == 'passthrough' then + -- No need to autolearn + lua_util.debugm(N, task, 'no need to autolearn - verdict: %s', + verdict) + return + end + + if conf.spam_threshold and conf.ham_threshold then + if verdict == 'spam' then + if conf.spam_threshold and score >= conf.spam_threshold then + log_can_autolearn(verdict, score, conf.spam_threshold) + learn_spam = true + end + elseif verdict == 'junk' then + if conf.junk_threshold and score >= conf.junk_threshold then + log_can_autolearn(verdict, score, conf.junk_threshold) + learn_spam = true + end + elseif verdict == 'ham' then + if conf.ham_threshold and score <= conf.ham_threshold then + log_can_autolearn(verdict, score, conf.ham_threshold) + learn_ham = true + end + end + elseif conf.learn_verdict then + if verdict == 'spam' or verdict == 'junk' then + learn_spam = true + elseif verdict == 'ham' then + learn_ham = true + end + end + + if conf.check_balance then + -- Check balance of learns + local spam_learns = task:get_mempool():get_variable('spam_learns', 'int64') or 0 + local ham_learns = task:get_mempool():get_variable('ham_learns', 'int64') or 0 + + local min_balance = 0.9 + if conf.min_balance then + min_balance = conf.min_balance + end + + if spam_learns > 0 or ham_learns > 0 then + local max_ratio = 1.0 / min_balance + local spam_learns_ratio = spam_learns / (ham_learns + 1) + if spam_learns_ratio > max_ratio and learn_spam then + lua_util.debugm(N, task, + 'skip learning spam, balance is not satisfied: %s < %s; %s spam learns; %s ham learns', + spam_learns_ratio, min_balance, spam_learns, ham_learns) + learn_spam = false + end + + local ham_learns_ratio = ham_learns / (spam_learns + 1) + if ham_learns_ratio > max_ratio and learn_ham then + lua_util.debugm(N, task, + 'skip learning ham, balance is not satisfied: %s < %s; %s spam learns; %s ham learns', + 
ham_learns_ratio, min_balance, spam_learns, ham_learns) + learn_ham = false + end + end + end + + if learn_spam then + return 'spam' + elseif learn_ham then + return 'ham' + end +end + +return exports
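-- Sketch of how the autolearn() helper above is typically driven, assuming it is
-- called from the bayes classifier with the classifier's `autolearn` section passed
-- as `conf` together with a scanned task; the threshold values are illustrative
-- examples, not defaults.
local lua_bayes_learn = require "lua_bayes_learn"

local autolearn_conf = {
  spam_threshold = 6.0,  -- learn as spam when the bayes-specific verdict score is >= 6.0
  junk_threshold = 4.0,  -- junk verdicts are also learned as spam above this score
  ham_threshold = -0.5,  -- learn as ham when the score is <= -0.5
  check_balance = true,  -- skip learning when spam/ham learn counts diverge too far
  min_balance = 0.9,
}

-- Inside a task callback: returns 'spam', 'ham' or nil (no autolearning wanted);
-- can_learn() can then veto learning if the message is already confidently in that class.
-- local what = lua_bayes_learn.autolearn(task, autolearn_conf)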
\ No newline at end of file diff --git a/lualib/lua_bayes_redis.lua b/lualib/lua_bayes_redis.lua new file mode 100644 index 0000000..7533997 --- /dev/null +++ b/lualib/lua_bayes_redis.lua @@ -0,0 +1,244 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]] + +-- This file contains functions to support Bayes statistics in Redis + +local exports = {} +local lua_redis = require "lua_redis" +local logger = require "rspamd_logger" +local lua_util = require "lua_util" +local ucl = require "ucl" + +local N = "bayes" + +local function gen_classify_functor(redis_params, classify_script_id) + return function(task, expanded_key, id, is_spam, stat_tokens, callback) + + local function classify_redis_cb(err, data) + lua_util.debugm(N, task, 'classify redis cb: %s, %s', err, data) + if err then + callback(task, false, err) + else + callback(task, true, data[1], data[2], data[3], data[4]) + end + end + + lua_redis.exec_redis_script(classify_script_id, + { task = task, is_write = false, key = expanded_key }, + classify_redis_cb, { expanded_key, stat_tokens }) + end +end + +local function gen_learn_functor(redis_params, learn_script_id) + return function(task, expanded_key, id, is_spam, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens) + local function learn_redis_cb(err, data) + lua_util.debugm(N, task, 'learn redis cb: %s, %s', err, data) + if err then + callback(task, false, err) + else + callback(task, true) + end + end + + if maybe_text_tokens then + lua_redis.exec_redis_script(learn_script_id, + { task = task, is_write = true, key = expanded_key }, + learn_redis_cb, + { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens }) + else + lua_redis.exec_redis_script(learn_script_id, + { task = task, is_write = true, key = expanded_key }, + learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens }) + end + + end +end + +local function load_redis_params(classifier_ucl, statfile_ucl) + local redis_params + + -- Try load from statfile options + if statfile_ucl.redis then + redis_params = lua_redis.try_load_redis_servers(statfile_ucl.redis, rspamd_config, true) + end + + if not redis_params then + if statfile_ucl then + redis_params = lua_redis.try_load_redis_servers(statfile_ucl, rspamd_config, true) + end + end + + -- Try load from classifier config + if not redis_params and classifier_ucl.backend then + redis_params = lua_redis.try_load_redis_servers(classifier_ucl.backend, rspamd_config, true) + end + + if not redis_params and classifier_ucl.redis then + redis_params = lua_redis.try_load_redis_servers(classifier_ucl.redis, rspamd_config, true) + end + + if not redis_params then + redis_params = lua_redis.try_load_redis_servers(classifier_ucl, rspamd_config, true) + end + + -- Try load global options + if not redis_params then + redis_params = lua_redis.try_load_redis_servers(rspamd_config:get_all_opt('redis'), rspamd_config, true) + end + + if not redis_params then + 
logger.err(rspamd_config, "cannot load Redis parameters for the classifier") + return nil + end + + return redis_params +end + +--- +--- Init bayes classifier +--- @param classifier_ucl ucl of the classifier config +--- @param statfile_ucl ucl of the statfile config +--- @return a pair of (classify_functor, learn_functor) or `nil` in case of error +exports.lua_bayes_init_statfile = function(classifier_ucl, statfile_ucl, symbol, is_spam, ev_base, stat_periodic_cb) + + local redis_params = load_redis_params(classifier_ucl, statfile_ucl) + + if not redis_params then + return nil + end + + local classify_script_id = lua_redis.load_redis_script_from_file("bayes_classify.lua", redis_params) + local learn_script_id = lua_redis.load_redis_script_from_file("bayes_learn.lua", redis_params) + local stat_script_id = lua_redis.load_redis_script_from_file("bayes_stat.lua", redis_params) + local max_users = classifier_ucl.max_users or 1000 + + local current_data = { + users = 0, + revision = 0, + } + local final_data = { + users = 0, + revision = 0, -- number of learns + } + local cursor = 0 + rspamd_config:add_periodic(ev_base, 0.0, function(cfg, _) + + local function stat_redis_cb(err, data) + lua_util.debugm(N, cfg, 'stat redis cb: %s, %s', err, data) + + if err then + logger.warn(cfg, 'cannot get bayes statistics for %s: %s', symbol, err) + else + local new_cursor = data[1] + current_data.users = current_data.users + data[2] + current_data.revision = current_data.revision + data[3] + if new_cursor == 0 then + -- Done iteration + final_data = lua_util.shallowcopy(current_data) + current_data = { + users = 0, + revision = 0, + } + lua_util.debugm(N, cfg, 'final data: %s', final_data) + stat_periodic_cb(cfg, final_data) + end + + cursor = new_cursor + end + end + + lua_redis.exec_redis_script(stat_script_id, + { ev_base = ev_base, cfg = cfg, is_write = false }, + stat_redis_cb, { tostring(cursor), + symbol, + is_spam and "learns_spam" or "learns_ham", + tostring(max_users) }) + return statfile_ucl.monitor_timeout or classifier_ucl.monitor_timeout or 30.0 + end) + + return gen_classify_functor(redis_params, classify_script_id), gen_learn_functor(redis_params, learn_script_id) +end + +local function gen_cache_check_functor(redis_params, check_script_id, conf) + local packed_conf = ucl.to_format(conf, 'msgpack') + return function(task, cache_id, callback) + + local function classify_redis_cb(err, data) + lua_util.debugm(N, task, 'check cache redis cb: %s, %s (%s)', err, data, type(data)) + if err then + callback(task, false, err) + else + if type(data) == 'number' then + callback(task, true, data) + else + callback(task, false, 'not found') + end + end + end + + lua_util.debugm(N, task, 'checking cache: %s', cache_id) + lua_redis.exec_redis_script(check_script_id, + { task = task, is_write = false, key = cache_id }, + classify_redis_cb, { cache_id, packed_conf }) + end +end + +local function gen_cache_learn_functor(redis_params, learn_script_id, conf) + local packed_conf = ucl.to_format(conf, 'msgpack') + return function(task, cache_id, is_spam) + local function learn_redis_cb(err, data) + lua_util.debugm(N, task, 'learn_cache redis cb: %s, %s', err, data) + end + + lua_util.debugm(N, task, 'try to learn cache: %s', cache_id) + lua_redis.exec_redis_script(learn_script_id, + { task = task, is_write = true, key = cache_id }, + learn_redis_cb, + { cache_id, is_spam and "1" or "0", packed_conf }) + + end +end + +exports.lua_bayes_init_cache = function(classifier_ucl, statfile_ucl) + local redis_params = 
load_redis_params(classifier_ucl, statfile_ucl) + + if not redis_params then + return nil + end + + local default_conf = { + cache_prefix = "learned_ids", + cache_max_elt = 10000, -- Maximum number of elements in the cache key + cache_max_keys = 5, -- Maximum number of keys in the cache + cache_elt_len = 32, -- Length of the element in the cache (will trim id to that value) + } + + local conf = lua_util.override_defaults(default_conf, classifier_ucl) + -- Clean all not known configurations + for k, _ in pairs(conf) do + if default_conf[k] == nil then + conf[k] = nil + end + end + + local check_script_id = lua_redis.load_redis_script_from_file("bayes_cache_check.lua", redis_params) + local learn_script_id = lua_redis.load_redis_script_from_file("bayes_cache_learn.lua", redis_params) + + return gen_cache_check_functor(redis_params, check_script_id, conf), gen_cache_learn_functor(redis_params, + learn_script_id, conf) +end + +return exports diff --git a/lualib/lua_cfg_transform.lua b/lualib/lua_cfg_transform.lua new file mode 100644 index 0000000..d6243ad --- /dev/null +++ b/lualib/lua_cfg_transform.lua @@ -0,0 +1,634 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local logger = require "rspamd_logger" +local lua_util = require "lua_util" +local rspamd_util = require "rspamd_util" +local fun = require "fun" + +local function is_implicit(t) + local mt = getmetatable(t) + + return mt and mt.class and mt.class == 'ucl.type.impl_array' +end + +local function metric_pairs(t) + -- collect the keys + local keys = {} + local implicit_array = is_implicit(t) + + local function gen_keys(tbl) + if implicit_array then + for _, v in ipairs(tbl) do + if v.name then + table.insert(keys, { v.name, v }) + v.name = nil + else + -- Very tricky to distinguish: + -- group {name = "foo" ... } + group "blah" { ... 
} + for gr_name, gr in pairs(v) do + if type(gr_name) ~= 'number' then + -- We can also have implicit arrays here + local gr_implicit = is_implicit(gr) + + if gr_implicit then + for _, gr_elt in ipairs(gr) do + table.insert(keys, { gr_name, gr_elt }) + end + else + table.insert(keys, { gr_name, gr }) + end + end + end + end + end + else + if tbl.name then + table.insert(keys, { tbl.name, tbl }) + tbl.name = nil + else + for k, v in pairs(tbl) do + if type(k) ~= 'number' then + -- We can also have implicit arrays here + local sym_implicit = is_implicit(v) + + if sym_implicit then + for _, elt in ipairs(v) do + table.insert(keys, { k, elt }) + end + else + table.insert(keys, { k, v }) + end + end + end + end + end + end + + gen_keys(t) + + -- return the iterator function + local i = 0 + return function() + i = i + 1 + if keys[i] then + return keys[i][1], keys[i][2] + end + end +end + +local function group_transform(cfg, k, v) + if v.name then + k = v.name + end + + local new_group = { + symbols = {} + } + + if v.enabled then + new_group.enabled = v.enabled + end + if v.disabled then + new_group.disabled = v.disabled + end + if v.max_score then + new_group.max_score = v.max_score + end + + if v.symbol then + for sk, sv in metric_pairs(v.symbol) do + if sv.name then + sk = sv.name + sv.name = nil -- Remove field + end + + new_group.symbols[sk] = sv + end + end + + if not cfg.group then + cfg.group = {} + end + + if cfg.group[k] then + cfg.group[k] = lua_util.override_defaults(cfg.group[k], new_group) + else + cfg.group[k] = new_group + end + + logger.infox("overriding group %s from the legacy metric settings", k) +end + +local function symbol_transform(cfg, k, v) + -- first try to find any group where there is a definition of this symbol + for gr_n, gr in pairs(cfg.group) do + if gr.symbols and gr.symbols[k] then + -- We override group symbol with ungrouped symbol + logger.infox("overriding group symbol %s in the group %s", k, gr_n) + gr.symbols[k] = lua_util.override_defaults(gr.symbols[k], v) + return + end + end + -- Now check what Rspamd knows about this symbol + local sym = rspamd_config:get_symbol(k) + + if not sym or not sym.group then + -- Otherwise we just use group 'ungrouped' + if not cfg.group.ungrouped then + cfg.group.ungrouped = { + symbols = {} + } + end + + cfg.group.ungrouped.symbols[k] = v + logger.debugx("adding symbol %s to the group 'ungrouped'", k) + end +end + +local function test_groups(groups) + for gr_name, gr in pairs(groups) do + if not gr.symbols then + local cnt = 0 + for _, _ in pairs(gr) do + cnt = cnt + 1 + end + + if cnt == 0 then + logger.debugx('group %s is empty', gr_name) + else + logger.infox('group %s has no symbols', gr_name) + end + end + end +end + +local function convert_metric(cfg, metric) + if metric.actions then + cfg.actions = lua_util.override_defaults(cfg.actions, metric.actions) + logger.infox("overriding actions from the legacy metric settings") + end + if metric.unknown_weight then + cfg.actions.unknown_weight = metric.unknown_weight + end + + if metric.subject then + logger.infox("overriding subject from the legacy metric settings") + cfg.actions.subject = metric.subject + end + + if metric.group then + for k, v in metric_pairs(metric.group) do + group_transform(cfg, k, v) + end + else + if not cfg.group then + cfg.group = { + ungrouped = { + symbols = {} + } + } + end + end + + if metric.symbol then + for k, v in metric_pairs(metric.symbol) do + symbol_transform(cfg, k, v) + end + end + + return cfg +end + +-- Converts a table of groups 
indexed by number (implicit array) to a +-- merged group definition +local function merge_groups(groups) + local ret = {} + for k, gr in pairs(groups) do + if type(k) == 'number' then + for key, sec in pairs(gr) do + ret[key] = sec + end + else + ret[k] = gr + end + end + + return ret +end + +-- Checks configuration files for statistics +local function check_statistics_sanity() + local local_conf = rspamd_paths['LOCAL_CONFDIR'] + local local_stat = string.format('%s/local.d/%s', local_conf, + 'statistic.conf') + local local_bayes = string.format('%s/local.d/%s', local_conf, + 'classifier-bayes.conf') + + if rspamd_util.file_exists(local_stat) and + rspamd_util.file_exists(local_bayes) then + logger.warnx(rspamd_config, 'conflicting files %s and %s are found: ' .. + 'Rspamd classifier configuration might be broken!', local_stat, local_bayes) + end +end + +-- Converts surbl module config to rbl module +local function surbl_section_convert(cfg, section) + local rbl_section = cfg.rbl.rbls + local wl = section.whitelist + for name, value in pairs(section.rules or {}) do + if rbl_section[name] then + logger.warnx(rspamd_config, 'conflicting names in surbl and rbl rules: %s, prefer surbl rule!', + name) + end + local converted = { + urls = true, + ignore_defaults = true, + } + + if wl then + converted.whitelist = wl + end + + for k, v in pairs(value) do + local skip = false + -- Rename + if k == 'suffix' then + k = 'rbl' + end + if k == 'ips' then + k = 'returncodes' + end + if k == 'bits' then + k = 'returnbits' + end + if k == 'noip' then + k = 'no_ip' + end + -- Crappy legacy + if k == 'options' then + if v == 'noip' or v == 'no_ip' then + converted.no_ip = true + skip = true + end + end + if k:match('check_') then + local n = k:match('check_(.*)') + k = n + end + + if k == 'dkim' and v then + converted.dkim_domainonly = false + converted.dkim_match_from = true + end + + if k == 'emails' and v then + -- To match surbl behaviour + converted.emails_domainonly = true + end + + if not skip then + converted[k] = lua_util.deepcopy(v) + end + end + rbl_section[name] = lua_util.override_defaults(rbl_section[name], converted) + end +end + +-- Converts surbl module config to rbl module +local function emails_section_convert(cfg, section) + local rbl_section = cfg.rbl.rbls + local wl = section.whitelist + for name, value in pairs(section.rules or {}) do + if rbl_section[name] then + logger.warnx(rspamd_config, 'conflicting names in emails and rbl rules: %s, prefer emails rule!', + name) + end + local converted = { + emails = true, + ignore_defaults = true, + } + + if wl then + converted.whitelist = wl + end + + for k, v in pairs(value) do + local skip = false + -- Rename + if k == 'dnsbl' then + k = 'rbl' + end + if k == 'check_replyto' then + k = 'replyto' + end + if k == 'hashlen' then + k = 'hash_len' + end + if k == 'encoding' then + k = 'hash_format' + end + if k == 'domain_only' then + k = 'emails_domainonly' + end + if k == 'delimiter' then + k = 'emails_delimiter' + end + if k == 'skip_body' then + skip = true + if v then + -- Hack + converted.emails = false + converted.replyto = true + else + converted.emails = true + end + end + if k == 'expect_ip' then + -- Another stupid hack + if not converted.return_codes then + converted.returncodes = {} + end + local symbol = value.symbol or name + converted.returncodes[symbol] = { v } + skip = true + end + + if not skip then + converted[k] = lua_util.deepcopy(v) + end + end + rbl_section[name] = lua_util.override_defaults(rbl_section[name], converted) + 
end +end + +return function(cfg) + local ret = false + + if cfg['metric'] then + for _, v in metric_pairs(cfg.metric) do + cfg = convert_metric(cfg, v) + end + ret = true + end + + if cfg.symbols then + for k, v in metric_pairs(cfg.symbols) do + symbol_transform(cfg, k, v) + end + end + + check_statistics_sanity() + + if not cfg.actions then + logger.errx('no actions defined') + else + -- Perform sanity check for actions + local actions_defs = { 'no action', 'no_action', -- In case if that's added + 'greylist', 'add header', 'add_header', + 'rewrite subject', 'rewrite_subject', 'quarantine', + 'reject', 'discard' } + + if not cfg.actions['no action'] and not cfg.actions['no_action'] and + not cfg.actions['accept'] then + for _, d in ipairs(actions_defs) do + if cfg.actions[d] then + + local action_score = nil + if type(cfg.actions[d]) == 'number' then + action_score = cfg.actions[d] + elseif type(cfg.actions[d]) == 'table' and cfg.actions[d]['score'] then + action_score = cfg.actions[d]['score'] + end + + if type(cfg.actions[d]) ~= 'table' and not action_score then + cfg.actions[d] = nil + elseif type(action_score) == 'number' and action_score < 0 then + cfg.actions['no_action'] = cfg.actions[d] - 0.001 + logger.infox(rspamd_config, 'set no_action score to: %s, as action %s has negative score', + cfg.actions['no_action'], d) + break + end + end + end + end + + local actions_set = lua_util.list_to_hash(actions_defs) + + -- Now check actions section for garbage + actions_set['unknown_weight'] = true + actions_set['grow_factor'] = true + actions_set['subject'] = true + + for k, _ in pairs(cfg.actions) do + if not actions_set[k] then + logger.warnx(rspamd_config, 'unknown element in actions section: %s', k) + end + end + + -- Performs thresholds sanity + -- We exclude greylist here as it can be set to whatever threshold in practice + local actions_order = { + 'no_action', + 'add_header', + 'rewrite_subject', + 'quarantine', + 'reject', + 'discard' + } + for i = 1, (#actions_order - 1) do + local act = actions_order[i] + + if cfg.actions[act] and type(cfg.actions[act]) == 'number' then + local score = cfg.actions[act] + + for j = i + 1, #actions_order do + local next_act = actions_order[j] + if cfg.actions[next_act] and type(cfg.actions[next_act]) == 'number' then + local next_score = cfg.actions[next_act] + if next_score <= score then + logger.errx(rspamd_config, 'invalid actions thresholds order: action %s (%s) must have lower ' .. + 'score than action %s (%s)', act, score, next_act, next_score) + ret = false + end + end + end + end + end + end + + if not cfg.group then + logger.errx('no symbol groups defined') + else + if cfg.group[1] then + -- We need to merge groups + cfg.group = merge_groups(cfg.group) + ret = true + end + test_groups(cfg.group) + end + + -- Deal with dkim settings + if not cfg.dkim then + cfg.dkim = {} + else + if cfg.dkim.sign_condition then + -- We have an obsoleted sign condition, so we need to either add dkim_signing and move it + -- there or just move sign condition there... 
+ if not cfg.dkim_signing then + logger.warnx('obsoleted DKIM signing method used, converting it to "dkim_signing" module') + cfg.dkim_signing = { + sign_condition = cfg.dkim.sign_condition + } + else + if not cfg.dkim_signing.sign_condition then + logger.warnx('obsoleted DKIM signing method used, move it to "dkim_signing" module') + cfg.dkim_signing.sign_condition = cfg.dkim.sign_condition + else + logger.warnx('obsoleted DKIM signing method used, ignore it as "dkim_signing" also defines condition!') + end + end + end + end + + -- Again: legacy stuff :( + if not cfg.dkim.sign_headers then + local sec = cfg.dkim_signing + if sec and sec[1] then + sec = cfg.dkim_signing[1] + end + + if sec and sec.sign_headers then + cfg.dkim.sign_headers = sec.sign_headers + end + end + + -- DKIM signing/ARC legacy + for _, mod in ipairs({ 'dkim_signing', 'arc' }) do + if cfg[mod] then + if cfg[mod].auth_only ~= nil then + if cfg[mod].sign_authenticated ~= nil then + logger.warnx(rspamd_config, + 'both auth_only (%s) and sign_authenticated (%s) for %s are specified, prefer auth_only', + cfg[mod].auth_only, cfg[mod].sign_authenticated, mod) + end + cfg[mod].sign_authenticated = cfg[mod].auth_only + end + end + end + + if cfg.dkim and cfg.dkim.sign_headers and type(cfg.dkim.sign_headers) == 'table' then + -- Flatten + cfg.dkim.sign_headers = table.concat(cfg.dkim.sign_headers, ':') + end + + -- Try to find some obvious issues with configuration + for k, v in pairs(cfg) do + if type(v) == 'table' and v[k] and type(v[k]) == 'table' then + logger.errx('nested section: %s { %s { ... } }, it is likely a configuration error', + k, k) + end + end + + -- If neural network is enabled we MUST have `check_all_filters` flag + if cfg.neural then + if not cfg.options then + cfg.options = {} + end + + if not cfg.options.check_all_filters then + logger.infox(rspamd_config, 'enable `options.check_all_filters` for neural network') + cfg.options.check_all_filters = true + end + end + + -- Deal with IP_SCORE + if cfg.ip_score and (cfg.ip_score.servers or cfg.redis.servers) then + logger.warnx(rspamd_config, 'ip_score module is deprecated in honor of reputation module!') + + if not cfg.reputation then + cfg.reputation = { + rules = {} + } + end + + if not cfg.reputation.rules then + cfg.reputation.rules = {} + end + + if not fun.any(function(_, v) + return v.selector and v.selector.ip + end, + cfg.reputation.rules) then + logger.infox(rspamd_config, 'attach ip reputation element to use it') + + cfg.reputation.rules.ip_score = { + selector = { + ip = {}, + }, + backend = { + redis = {}, + } + } + + if cfg.ip_score.servers then + cfg.reputation.rules.ip_score.backend.redis.servers = cfg.ip_score.servers + end + + if cfg.symbols and cfg.symbols['IP_SCORE'] then + local t = cfg.symbols['IP_SCORE'] + + if not cfg.symbols['SENDER_REP_SPAM'] then + cfg.symbols['SENDER_REP_SPAM'] = t + cfg.symbols['SENDER_REP_HAM'] = t + cfg.symbols['SENDER_REP_HAM'].weight = -(t.weight or 0) + end + end + else + logger.infox(rspamd_config, 'ip reputation already exists, do not do any IP_SCORE transforms') + end + end + + if cfg.surbl then + if not cfg.rbl then + cfg.rbl = { + rbls = {} + } + end + if not cfg.rbl.rbls then + cfg.rbl.rbls = {} + end + surbl_section_convert(cfg, cfg.surbl) + logger.infox(rspamd_config, 'converted surbl rules to rbl rules') + cfg.surbl = {} + end + + if cfg.emails then + if not cfg.rbl then + cfg.rbl = { + rbls = {} + } + end + if not cfg.rbl.rbls then + cfg.rbl.rbls = {} + end + emails_section_convert(cfg, cfg.emails) + 
logger.infox(rspamd_config, 'converted emails rules to rbl rules') + cfg.emails = {} + end + + return ret, cfg +end diff --git a/lualib/lua_cfg_utils.lua b/lualib/lua_cfg_utils.lua new file mode 100644 index 0000000..e07a3ae --- /dev/null +++ b/lualib/lua_cfg_utils.lua @@ -0,0 +1,84 @@ +--[[ +Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_cfg_utils +-- This module contains utility functions for configuration of Rspamd modules +--]] + +local rspamd_logger = require "rspamd_logger" +local exports = {} + +--[[[ +-- @function lua_util.disable_module(modname, how[, reason]) +-- Disables a plugin +-- @param {string} modname name of plugin to disable +-- @param {string} how 'redis' to disable redis, 'config' to disable startup +-- @param {string} reason optional reason for failure +--]] +exports.disable_module = function(modname, how, reason) + if rspamd_plugins_state.enabled[modname] then + rspamd_plugins_state.enabled[modname] = nil + end + + if how == 'redis' then + rspamd_plugins_state.disabled_redis[modname] = {} + elseif how == 'config' then + rspamd_plugins_state.disabled_unconfigured[modname] = {} + elseif how == 'experimental' then + rspamd_plugins_state.disabled_experimental[modname] = {} + elseif how == 'failed' then + rspamd_plugins_state.disabled_failed[modname] = { reason = reason } + else + rspamd_plugins_state.disabled_unknown[modname] = {} + end +end + +--[[[ +-- @function lua_util.push_config_error(module, err) +-- Pushes a configuration error to the state +-- @param {string} module name of module +-- @param {string} err error string +--]] +exports.push_config_error = function(module, err) + if not rspamd_plugins_state.config_errors then + rspamd_plugins_state.config_errors = {} + end + + if not rspamd_plugins_state.config_errors[module] then + rspamd_plugins_state.config_errors[module] = {} + end + + table.insert(rspamd_plugins_state.config_errors[module], err) +end + +exports.check_configuration_errors = function() + local ret = true + + if type(rspamd_plugins_state.config_errors) == 'table' then + -- We have some errors found during the configuration, so we need to show them + for m, errs in pairs(rspamd_plugins_state.config_errors) do + for _, err in ipairs(errs) do + rspamd_logger.errx(rspamd_config, 'configuration error: module %s: %s', m, err) + ret = false + end + end + end + + return ret +end + +return exports
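-- Sketch of how a plugin can use the lua_cfg_utils helpers above during configuration,
-- assuming the usual rspamd globals (`rspamd_config`, `rspamd_plugins_state`); the
-- plugin name and the missing `servers` option are made-up examples.
local lua_cfg_utils = require "lua_cfg_utils"

local N = 'my_plugin'
local settings = rspamd_config:get_all_opt(N)

if not (settings and settings.servers) then
  -- record the problem so check_configuration_errors() can report it, then disable the module
  lua_cfg_utils.push_config_error(N, 'no servers defined')
  lua_cfg_utils.disable_module(N, 'failed', 'no servers defined')
  return
end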
\ No newline at end of file diff --git a/lualib/lua_clickhouse.lua b/lualib/lua_clickhouse.lua new file mode 100644 index 0000000..28366d2 --- /dev/null +++ b/lualib/lua_clickhouse.lua @@ -0,0 +1,547 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2018, Mikhail Galanin <mgalanin@mimecast.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_clickhouse +-- This module contains Clickhouse access functions +--]] + +local rspamd_logger = require "rspamd_logger" +local rspamd_http = require "rspamd_http" +local lua_util = require "lua_util" +local rspamd_text = require "rspamd_text" + +local exports = {} +local N = 'clickhouse' + +local default_timeout = 10.0 + +local function escape_spaces(query) + return query:gsub('%s', '%%20') +end + +local function ch_number(a) + if (a + 2 ^ 52) - 2 ^ 52 == a then + -- Integer + return tostring(math.floor(a)) + end + + return tostring(a) +end + +local function clickhouse_quote(str) + if str then + return str:gsub('[\'\\\n\t\r]', { + ['\''] = [[\']], + ['\\'] = [[\\]], + ['\n'] = [[\n]], + ['\t'] = [[\t]], + ['\r'] = [[\r]], + }) + end + + return '' +end + +-- Converts an array to a string suitable for clickhouse +local function array_to_string(ar) + for i, elt in ipairs(ar) do + local t = type(elt) + if t == 'string' then + ar[i] = string.format('\'%s\'', clickhouse_quote(elt)) + elseif t == 'userdata' then + ar[i] = string.format('\'%s\'', clickhouse_quote(tostring(elt))) + elseif t == 'number' then + ar[i] = ch_number(elt) + end + end + + return table.concat(ar, ',') +end + +-- Converts a row into TSV, taking extra care about arrays +local function row_to_tsv(row) + + for i, elt in ipairs(row) do + local t = type(elt) + if t == 'table' then + row[i] = '[' .. array_to_string(elt) .. ']' + elseif t == 'number' then + row[i] = ch_number(elt) + elseif t == 'userdata' then + row[i] = clickhouse_quote(tostring(elt)) + else + row[i] = clickhouse_quote(elt) + end + end + + return rspamd_text.fromtable(row, '\t') +end + +exports.row_to_tsv = row_to_tsv + +-- Parses JSONEachRow reply from CH +local function parse_clickhouse_response_json_eachrow(params, data, row_cb) + local ucl = require "ucl" + + if data == nil then + -- clickhouse returned no data (i.e. 
empty result set): exiting + return {} + end + + local function parse_string(s) + local parser = ucl.parser() + local res, err + if type(s) == 'string' then + res, err = parser:parse_string(s) + else + res, err = parser:parse_text(s) + end + + if not res then + rspamd_logger.errx(params.log_obj, 'Parser error: %s', err) + return nil + end + return parser:get_object() + end + + -- iterate over rows and parse + local parsed_rows = {} + for plain_row in data:lines() do + if plain_row and #plain_row > 1 then + local parsed_row = parse_string(plain_row) + if parsed_row then + if row_cb then + row_cb(parsed_row) + else + table.insert(parsed_rows, parsed_row) + end + end + end + end + + return parsed_rows +end + +-- Parses JSON reply from CH +local function parse_clickhouse_response_json(params, data) + local ucl = require "ucl" + + if data == nil then + -- clickhouse returned no data (i.e. empty result set) considered valid! + return nil, {} + end + + local function parse_string(s) + local parser = ucl.parser() + local res, err + + if type(s) == 'string' then + res, err = parser:parse_string(s) + else + res, err = parser:parse_text(s) + end + + if not res then + rspamd_logger.errx(params.log_obj, 'Parser error: %s', err) + return nil + end + return parser:get_object() + end + + local json = parse_string(data) + + if not json then + return 'bad json', {} + end + + return nil, json +end + +-- Helper to generate HTTP closure +local function mk_http_select_cb(upstream, params, ok_cb, fail_cb, row_cb) + local function http_cb(err_message, code, data, _) + if code ~= 200 or err_message then + if not err_message then + err_message = data + end + local ip_addr = upstream:get_addr():to_string(true) + + if fail_cb then + fail_cb(params, err_message, data) + else + rspamd_logger.errx(params.log_obj, + "request failed on clickhouse server %s: %s", + ip_addr, err_message) + end + upstream:fail() + else + upstream:ok() + local rows = parse_clickhouse_response_json_eachrow(params, data, row_cb) + + if rows then + if ok_cb then + ok_cb(params, rows) + else + lua_util.debugm(N, params.log_obj, + "http_select_cb ok: %s, %s, %s, %s", err_message, code, + data:gsub('[\n%s]+', ' '), _) + end + else + if fail_cb then + fail_cb(params, 'failed to parse reply', data) + else + local ip_addr = upstream:get_addr():to_string(true) + rspamd_logger.errx(params.log_obj, + "request failed on clickhouse server %s: %s", + ip_addr, 'failed to parse reply') + end + end + end + end + + return http_cb +end + +-- Helper to generate HTTP closure +local function mk_http_insert_cb(upstream, params, ok_cb, fail_cb) + local function http_cb(err_message, code, data, _) + if code ~= 200 or err_message then + if not err_message then + err_message = data + end + local ip_addr = upstream:get_addr():to_string(true) + + if fail_cb then + fail_cb(params, err_message, data) + else + rspamd_logger.errx(params.log_obj, + "request failed on clickhouse server %s: %s", + ip_addr, err_message) + end + upstream:fail() + else + upstream:ok() + + if ok_cb then + local err, parsed = parse_clickhouse_response_json(data) + + if err then + fail_cb(params, err, data) + else + ok_cb(params, parsed) + end + + else + lua_util.debugm(N, params.log_obj, + "http_insert_cb ok: %s, %s, %s, %s", err_message, code, + data:gsub('[\n%s]+', ' '), _) + end + end + end + + return http_cb +end + +--[[[ +-- @function lua_clickhouse.select(upstream, settings, params, query, + ok_cb, fail_cb) +-- Make select request to clickhouse +-- @param {upstream} upstream clickhouse server 
upstream +-- @param {table} settings global settings table: +-- * use_gsip: use gzip compression +-- * timeout: request timeout +-- * no_ssl_verify: skip SSL verification +-- * user: HTTP user +-- * password: HTTP password +-- @param {params} HTTP request params +-- @param {string} query select query (passed in HTTP body) +-- @param {function} ok_cb callback to be called in case of success +-- @param {function} fail_cb callback to be called in case of some error +-- @param {function} row_cb optional callback to be called on each parsed data row (instead of table insertion) +-- @return {boolean} whether a connection was successful +-- @example +-- +--]] +exports.select = function(upstream, settings, params, query, ok_cb, fail_cb, row_cb) + local http_params = {} + + for k, v in pairs(params) do + http_params[k] = v + end + + http_params.callback = mk_http_select_cb(upstream, http_params, ok_cb, fail_cb, row_cb) + http_params.gzip = settings.use_gzip + http_params.mime_type = 'text/plain' + http_params.timeout = settings.timeout or default_timeout + http_params.no_ssl_verify = settings.no_ssl_verify + http_params.user = settings.user + http_params.password = settings.password + http_params.body = query + http_params.log_obj = params.task or params.config + http_params.opaque_body = true + + lua_util.debugm(N, http_params.log_obj, "clickhouse select request: %s", http_params.body) + + if not http_params.url then + local connect_prefix = "http://" + if settings.use_https then + connect_prefix = 'https://' + end + local ip_addr = upstream:get_addr():to_string(true) + local database = settings.database or 'default' + http_params.url = string.format('%s%s/?database=%s&default_format=JSONEachRow', + connect_prefix, ip_addr, escape_spaces(database)) + end + + return rspamd_http.request(http_params) +end + +--[[[ +-- @function lua_clickhouse.select_sync(upstream, settings, params, query, + ok_cb, fail_cb, row_cb) +-- Make select request to clickhouse +-- @param {upstream} upstream clickhouse server upstream +-- @param {table} settings global settings table: +-- * use_gsip: use gzip compression +-- * timeout: request timeout +-- * no_ssl_verify: skip SSL verification +-- * user: HTTP user +-- * password: HTTP password +-- @param {params} HTTP request params +-- @param {string} query select query (passed in HTTP body) +-- @param {function} ok_cb callback to be called in case of success +-- @param {function} fail_cb callback to be called in case of some error +-- @param {function} row_cb optional callback to be called on each parsed data row (instead of table insertion) +-- @return +-- {string} error message if exists +-- nil | {rows} | {http_response} +-- @example +-- +--]] +exports.select_sync = function(upstream, settings, params, query, row_cb) + local http_params = {} + + for k, v in pairs(params) do + http_params[k] = v + end + + http_params.gzip = settings.use_gzip + http_params.mime_type = 'text/plain' + http_params.timeout = settings.timeout or default_timeout + http_params.no_ssl_verify = settings.no_ssl_verify + http_params.user = settings.user + http_params.password = settings.password + http_params.body = query + http_params.log_obj = params.task or params.config + http_params.opaque_body = true + + lua_util.debugm(N, http_params.log_obj, "clickhouse select request: %s", http_params.body) + + if not http_params.url then + local connect_prefix = "http://" + if settings.use_https then + connect_prefix = 'https://' + end + local ip_addr = upstream:get_addr():to_string(true) + local database = 
settings.database or 'default' + http_params.url = string.format('%s%s/?database=%s&default_format=JSONEachRow', + connect_prefix, ip_addr, escape_spaces(database)) + end + + local err, response = rspamd_http.request(http_params) + + if err then + return err, nil + elseif response.code ~= 200 then + return response.content, response + else + lua_util.debugm(N, http_params.log_obj, "clickhouse select response: %1", response) + local rows = parse_clickhouse_response_json_eachrow(params, response.content, row_cb) + return nil, rows + end +end + +--[[[ +-- @function lua_clickhouse.insert(upstream, settings, params, query, rows, + ok_cb, fail_cb) +-- Insert data rows to clickhouse +-- @param {upstream} upstream clickhouse server upstream +-- @param {table} settings global settings table: +-- * use_gsip: use gzip compression +-- * timeout: request timeout +-- * no_ssl_verify: skip SSL verification +-- * user: HTTP user +-- * password: HTTP password +-- @param {params} HTTP request params +-- @param {string} query select query (passed in `query` request element with spaces escaped) +-- @param {table|mixed} rows mix of strings, numbers or tables (for arrays) +-- @param {function} ok_cb callback to be called in case of success +-- @param {function} fail_cb callback to be called in case of some error +-- @return {boolean} whether a connection was successful +-- @example +-- +--]] +exports.insert = function(upstream, settings, params, query, rows, + ok_cb, fail_cb) + local http_params = {} + + for k, v in pairs(params) do + http_params[k] = v + end + + http_params.callback = mk_http_insert_cb(upstream, http_params, ok_cb, fail_cb) + http_params.gzip = settings.use_gzip + http_params.mime_type = 'text/plain' + http_params.timeout = settings.timeout or default_timeout + http_params.no_ssl_verify = settings.no_ssl_verify + http_params.user = settings.user + http_params.password = settings.password + http_params.method = 'POST' + http_params.body = { rspamd_text.fromtable(rows, '\n'), '\n' } + http_params.log_obj = params.task or params.config + + if not http_params.url then + local connect_prefix = "http://" + if settings.use_https then + connect_prefix = 'https://' + end + local ip_addr = upstream:get_addr():to_string(true) + local database = settings.database or 'default' + http_params.url = string.format('%s%s/?database=%s&query=%s%%20FORMAT%%20TabSeparated', + connect_prefix, + ip_addr, + escape_spaces(database), + escape_spaces(query)) + end + + return rspamd_http.request(http_params) +end + +--[[[ +-- @function lua_clickhouse.generic(upstream, settings, params, query, + ok_cb, fail_cb) +-- Make a generic request to Clickhouse (e.g. 
alter) +-- @param {upstream} upstream clickhouse server upstream +-- @param {table} settings global settings table: +-- * use_gsip: use gzip compression +-- * timeout: request timeout +-- * no_ssl_verify: skip SSL verification +-- * user: HTTP user +-- * password: HTTP password +-- @param {params} HTTP request params +-- @param {string} query Clickhouse query (passed in `query` request element with spaces escaped) +-- @param {function} ok_cb callback to be called in case of success +-- @param {function} fail_cb callback to be called in case of some error +-- @return {boolean} whether a connection was successful +-- @example +-- +--]] +exports.generic = function(upstream, settings, params, query, + ok_cb, fail_cb) + local http_params = {} + + for k, v in pairs(params) do + http_params[k] = v + end + + http_params.callback = mk_http_insert_cb(upstream, http_params, ok_cb, fail_cb) + http_params.gzip = settings.use_gzip + http_params.mime_type = 'text/plain' + http_params.timeout = settings.timeout or default_timeout + http_params.no_ssl_verify = settings.no_ssl_verify + http_params.user = settings.user + http_params.password = settings.password + http_params.log_obj = params.task or params.config + http_params.body = query + + if not http_params.url then + local connect_prefix = "http://" + if settings.use_https then + connect_prefix = 'https://' + end + local ip_addr = upstream:get_addr():to_string(true) + local database = settings.database or 'default' + http_params.url = string.format('%s%s/?database=%s&default_format=JSONEachRow', + connect_prefix, ip_addr, escape_spaces(database)) + end + + return rspamd_http.request(http_params) +end + +--[[[ +-- @function lua_clickhouse.generic_sync(upstream, settings, params, query, + ok_cb, fail_cb) +-- Make a generic request to Clickhouse (e.g. 
alter) +-- @param {upstream} upstream clickhouse server upstream +-- @param {table} settings global settings table: +-- * use_gsip: use gzip compression +-- * timeout: request timeout +-- * no_ssl_verify: skip SSL verification +-- * user: HTTP user +-- * password: HTTP password +-- @param {params} HTTP request params +-- @param {string} query Clickhouse query (passed in `query` request element with spaces escaped) +-- @return {boolean} whether a connection was successful +-- @example +-- +--]] +exports.generic_sync = function(upstream, settings, params, query) + local http_params = {} + + for k, v in pairs(params) do + http_params[k] = v + end + + http_params.gzip = settings.use_gzip + http_params.mime_type = 'text/plain' + http_params.timeout = settings.timeout or default_timeout + http_params.no_ssl_verify = settings.no_ssl_verify + http_params.user = settings.user + http_params.password = settings.password + http_params.log_obj = params.task or params.config + http_params.body = query + + if not http_params.url then + local connect_prefix = "http://" + if settings.use_https then + connect_prefix = 'https://' + end + local ip_addr = upstream:get_addr():to_string(true) + local database = settings.database or 'default' + http_params.url = string.format('%s%s/?database=%s&default_format=JSON', + connect_prefix, ip_addr, escape_spaces(database)) + end + + local err, response = rspamd_http.request(http_params) + + if err then + return err, nil + elseif response.code ~= 200 then + return response.content, response + else + lua_util.debugm(N, http_params.log_obj, "clickhouse generic response: %1", response) + local e, obj = parse_clickhouse_response_json(params, response.content) + + if e then + return e, nil + end + return nil, obj + end +end + +return exports diff --git a/lualib/lua_content/ical.lua b/lualib/lua_content/ical.lua new file mode 100644 index 0000000..d018a85 --- /dev/null +++ b/lualib/lua_content/ical.lua @@ -0,0 +1,105 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +local l = require 'lpeg' +local lua_util = require "lua_util" +local N = "lua_content" + +local ical_grammar + +local function gen_grammar() + if not ical_grammar then + local wsp = l.S(" \t\v\f") + local crlf = (l.P "\r" ^ -1 * l.P "\n") + l.P "\r" + local eol = (crlf * #crlf) + (crlf - (crlf ^ -1 * wsp)) + local name = l.C((l.P(1) - (l.P ":")) ^ 1) / function(v) + return (v:gsub("[\n\r]+%s", "")) + end + local value = l.C((l.P(1) - eol) ^ 0) / function(v) + return (v:gsub("[\n\r]+%s", "")) + end + ical_grammar = name * ":" * wsp ^ 0 * value * eol ^ -1 + end + + return ical_grammar +end + +local exports = {} + +local function extract_text_data(specific) + local fun = require "fun" + + local tbl = fun.totable(fun.map(function(e) + return e[2]:lower() + end, specific.elts)) + return table.concat(tbl, '\n') +end + + +-- Keys that can have visible urls +local url_keys = lua_util.list_to_hash { + 'description', + 'location', + 'summary', + 'organizer', + 'organiser', + 'attendee', + 'url' +} + +local function process_ical(input, mpart, task) + local control = { n = '\n', r = '' } + local rspamd_url = require "rspamd_url" + local escaper = l.Ct((gen_grammar() / function(key, value) + value = value:gsub("\\(.)", control) + key = key:lower():match('^([^;]+)') + + if key and url_keys[key] then + local local_urls = rspamd_url.all(task:get_mempool(), value) + + if local_urls and #local_urls > 0 then + for _, u in ipairs(local_urls) do + lua_util.debugm(N, task, 'ical: found URL in ical key "%s": %s', + key, tostring(u)) + task:inject_url(u, mpart) + end + end + end + lua_util.debugm(N, task, 'ical: ical key %s = "%s"', + key, value) + return { key, value } + end) ^ 1) + + local elts = escaper:match(input) + + if not elts then + return nil + end + + return { + tag = 'ical', + extract_text = extract_text_data, + elts = elts + } +end + +--[[[ +-- @function lua_ical.process(input) +-- Returns all values from ical as a plain text. Names are completely ignored. +--]] +exports.process = process_ical + +return exports
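The lpeg grammar above is compact, so a standalone sketch may help. The following is a minimal reconstruction of the folded key/value parsing built by gen_grammar(), assuming plain Lua with the lpeg rock outside of rspamd; the sample input, the unfold helper and the kv_line/grammar names are illustrative only and not part of the patch.

local l = require 'lpeg'

local wsp = l.S(" \t\v\f")
local crlf = (l.P "\r" ^ -1 * l.P "\n") + l.P "\r"
-- treat a line break as end-of-line unless it is immediately followed by folding whitespace
local eol = (crlf * #crlf) + (crlf - (crlf ^ -1 * wsp))

-- iCalendar/vCard style unfolding: drop the line break plus the single whitespace after it
local function unfold(v)
  return (v:gsub("[\n\r]+%s", ""))
end

local name = l.C((l.P(1) - l.P ":") ^ 1) / unfold
local value = l.C((l.P(1) - eol) ^ 0) / unfold
local kv_line = name * ":" * wsp ^ 0 * value * eol ^ -1
local grammar = l.Ct((kv_line / function(key, val)
  return { key, val }
end) ^ 1)

local sample = "SUMMARY:Team meeting\r\n" ..
    "DESCRIPTION:Agenda follows\r\n  on a folded continuation line\r\n" ..
    "URL:https://example.com/invite\r\n"

for _, kv in ipairs(grammar:match(sample) or {}) do
  print(kv[1] .. ' = ' .. kv[2])
end
-- SUMMARY = Team meeting
-- DESCRIPTION = Agenda follows on a folded continuation line
-- URL = https://example.com/invite

process_ical() applies the same unfolding and then runs rspamd_url.all() only on values whose key appears in url_keys, so only keys such as description, location or url can contribute URLs to the task.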
\ No newline at end of file diff --git a/lualib/lua_content/init.lua b/lualib/lua_content/init.lua new file mode 100644 index 0000000..701d223 --- /dev/null +++ b/lualib/lua_content/init.lua @@ -0,0 +1,109 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_content +-- This module contains content processing logic +--]] + + +local exports = {} +local N = "lua_content" +local lua_util = require "lua_util" + +local content_modules = { + ical = { + mime_type = { "text/calendar", "application/calendar" }, + module = require "lua_content/ical", + extensions = { 'ics' }, + output = "text" + }, + vcf = { + mime_type = { "text/vcard", "application/vcard" }, + module = require "lua_content/vcard", + extensions = { 'vcf' }, + output = "text" + }, + pdf = { + mime_type = "application/pdf", + module = require "lua_content/pdf", + extensions = { 'pdf' }, + output = "table" + }, +} + +local modules_by_mime_type +local modules_by_extension + +local function init() + modules_by_mime_type = {} + modules_by_extension = {} + for k, v in pairs(content_modules) do + if v.mime_type then + if type(v.mime_type) == 'table' then + for _, mt in ipairs(v.mime_type) do + modules_by_mime_type[mt] = { k, v } + end + else + modules_by_mime_type[v.mime_type] = { k, v } + end + + end + if v.extensions then + for _, ext in ipairs(v.extensions) do + modules_by_extension[ext] = { k, v } + end + end + end +end + +exports.maybe_process_mime_part = function(part, task) + if not modules_by_mime_type then + init() + end + + local ctype, csubtype = part:get_type() + local mt = string.format("%s/%s", ctype or 'application', + csubtype or 'octet-stream') + local pair = modules_by_mime_type[mt] + + if not pair then + local ext = part:get_detected_ext() + + if ext then + pair = modules_by_extension[ext] + end + end + + if pair then + lua_util.debugm(N, task, "found known content of type %s: %s", + mt, pair[1]) + + local data = pair[2].module.process(part:get_content(), part, task) + + if data then + lua_util.debugm(N, task, "extracted content from %s: %s type", + pair[1], type(data)) + part:set_specific(data) + else + lua_util.debugm(N, task, "failed to extract anything from %s", + pair[1]) + end + end + +end + +return exports
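A brief usage sketch may show where maybe_process_mime_part() fits: it is meant to be called once per MIME part and stores whatever the matching content module returned via part:set_specific(). The snippet below assumes a regular rspamd local Lua rules file where rspamd_config and the task object are provided by rspamd; the symbol name CONTENT_TYPE_KNOWN, its score and description are hypothetical.

local lua_content = require "lua_content"

rspamd_config.CONTENT_TYPE_KNOWN = {
  callback = function(task)
    for _, part in ipairs(task:get_parts() or {}) do
      -- dispatches by the declared MIME type first, then by the detected extension
      lua_content.maybe_process_mime_part(part, task)

      local specific = part:get_specific()
      if type(specific) == 'table' and specific.tag then
        -- report which content module (ical, vcard or pdf) recognised the part
        return true, 1.0, specific.tag
      end
    end

    return false
  end,
  score = 0.0,
  description = 'Structured content (ical/vcf/pdf) was parsed from a MIME part',
}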
\ No newline at end of file diff --git a/lualib/lua_content/pdf.lua b/lualib/lua_content/pdf.lua new file mode 100644 index 0000000..f6d5c0b --- /dev/null +++ b/lualib/lua_content/pdf.lua @@ -0,0 +1,1424 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_content/pdf +-- This module contains some heuristics for PDF files +--]] + +local rspamd_trie = require "rspamd_trie" +local rspamd_util = require "rspamd_util" +local rspamd_text = require "rspamd_text" +local rspamd_url = require "rspamd_url" +local bit = require "bit" +local N = "lua_content" +local lua_util = require "lua_util" +local rspamd_regexp = require "rspamd_regexp" +local lpeg = require "lpeg" +local pdf_patterns = { + trailer = { + patterns = { + [[\ntrailer\r?\n]] + } + }, + suspicious = { + patterns = { + [[netsh\s]], + [[echo\s]], + [=[\/[A-Za-z]*#\d\d[#A-Za-z<>/\s]]=], -- Hex encode obfuscation + } + }, + start_object = { + patterns = { + [=[[\r\n\0]\s*\d+\s+\d+\s+obj[\s<]]=] + } + }, + end_object = { + patterns = { + [=[endobj[\r\n]]=] + } + }, + start_stream = { + patterns = { + [=[>\s*stream[\r\n]]=], + } + }, + end_stream = { + patterns = { + [=[endstream[\r\n]]=] + } + } +} + +local pdf_text_patterns = { + start = { + patterns = { + [[\sBT\s]] + } + }, + stop = { + patterns = { + [[\sET\b]] + } + } +} + +local pdf_cmap_patterns = { + start = { + patterns = { + [[\d\s+beginbfchar\s]], + [[\d\s+beginbfrange\s]] + } + }, + stop = { + patterns = { + [[\sendbfrange\b]], + [[\sendbchar\b]] + } + } +} + +-- index[n] -> +-- t[1] - pattern, +-- t[2] - key in patterns table, +-- t[3] - value in patterns table +-- t[4] - local pattern index +local pdf_indexes = {} +local pdf_text_indexes = {} +local pdf_cmap_indexes = {} + +local pdf_trie +local pdf_text_trie +local pdf_cmap_trie + +local exports = {} + +local config = { + max_extraction_size = 512 * 1024, + max_processing_size = 32 * 1024, + text_extraction = false, -- NYI feature + url_extraction = true, + enabled = true, + js_fuzzy = true, -- Generate fuzzy hashes from PDF javascripts + min_js_fuzzy = 256, -- Minimum size of js to be considered as a fuzzy + openaction_fuzzy_only = false, -- Generate fuzzy from all scripts + max_pdf_objects = 10000, -- Maximum number of objects to be considered + max_pdf_trailer = 10 * 1024 * 1024, -- Maximum trailer size (to avoid abuse) + max_pdf_trailer_lines = 100, -- Maximum number of lines in pdf trailer + pdf_process_timeout = 1.0, -- Timeout in seconds for processing +} + +-- Used to process patterns found in PDF +-- positions for functional processors should be a iter/table from trie matcher in form +---- [{n1, pat_idx1}, ... {nn, pat_idxn}] where +---- pat_idxn is pattern index and n1 ... 
nn are match positions +local processors = {} +-- PDF objects outer grammar in LPEG style (performing table captures) +local pdf_outer_grammar +local pdf_text_grammar + +-- Used to match objects +local object_re = rspamd_regexp.create_cached([=[/(\d+)\s+(\d+)\s+obj\s*/]=]) + +local function config_module() + local opts = rspamd_config:get_all_opt('lua_content') + + if opts and opts.pdf then + config = lua_util.override_defaults(config, opts.pdf) + end +end + +local function compile_tries() + local default_compile_flags = bit.bor(rspamd_trie.flags.re, + rspamd_trie.flags.dot_all, + rspamd_trie.flags.no_start) + local function compile_pats(patterns, indexes, compile_flags) + local strs = {} + for what, data in pairs(patterns) do + for i, pat in ipairs(data.patterns) do + strs[#strs + 1] = pat + indexes[#indexes + 1] = { what, data, pat, i } + end + end + + return rspamd_trie.create(strs, compile_flags or default_compile_flags) + end + + if not pdf_trie then + pdf_trie = compile_pats(pdf_patterns, pdf_indexes) + end + if not pdf_text_trie then + pdf_text_trie = compile_pats(pdf_text_patterns, pdf_text_indexes) + end + if not pdf_cmap_trie then + pdf_cmap_trie = compile_pats(pdf_cmap_patterns, pdf_cmap_indexes) + end +end + +-- Returns a table with generic grammar elements for PDF +local function generic_grammar_elts() + local P = lpeg.P + local R = lpeg.R + local S = lpeg.S + local V = lpeg.V + local C = lpeg.C + local D = R '09' -- Digits + + local grammar_elts = {} + + -- Helper functions + local function pdf_hexstring_unescape(s) + if #s % 2 == 0 then + -- Sane hex string + return lua_util.unhex(s) + end + + -- WTF hex string + -- Append '0' to it and unescape... + return lua_util.unhex(s:sub(1, #s - 1)) .. lua_util.unhex((s:sub(#s) .. '0')) + end + + local function pdf_string_unescape(s) + local function ue_single(cc) + if cc == '\\r' then + return '\r' + elseif cc == '\\n' then + return '\n' + else + return cc:gsub(2, 2) + end + end + -- simple unescape \char + s = s:gsub('\\[^%d]', ue_single) + -- unescape octal + local function ue_octal(cc) + -- Replace unknown stuff with '?' + return string.char(tonumber(cc:sub(2), 8) or 63) + end + s = s:gsub('\\%d%d?%d?', ue_octal) + + return s + end + + local function pdf_id_unescape(s) + return (s:gsub('#%d%d', function(cc) + return string.char(tonumber(cc:sub(2), 16)) + end)) + end + + local delim = S '()<>[]{}/%' + grammar_elts.ws = S '\0 \r\n\t\f' + local hex = R 'af' + R 'AF' + D + -- Comments. + local eol = P '\r\n' + '\n' + local line = (1 - S '\r\n\f') ^ 0 * eol ^ -1 + grammar_elts.comment = P '%' * line + + -- Numbers. + local sign = S '+-' ^ -1 + local decimal = D ^ 1 + local float = D ^ 1 * P '.' * D ^ 0 + P '.' * D ^ 1 + grammar_elts.number = C(sign * (float + decimal)) / tonumber + + -- String + grammar_elts.str = P { "(" * C(((1 - S "()\\") + (P '\\' * 1) + V(1)) ^ 0) / pdf_string_unescape * ")" } + grammar_elts.hexstr = P { "<" * C(hex ^ 0) / pdf_hexstring_unescape * ">" } + + -- Identifier + grammar_elts.id = P { '/' * C((1 - (delim + grammar_elts.ws)) ^ 1) / pdf_id_unescape } + + -- Booleans (who care about them?) 
+ grammar_elts.boolean = C(P("true") + P("false")) + + -- Stupid references + grammar_elts.ref = lpeg.Ct { lpeg.Cc("%REF%") * C(D ^ 1) * " " * C(D ^ 1) * " " * "R" } + + return grammar_elts +end + + +-- Generates a grammar to parse outer elements (external objects in PDF notation) +local function gen_outer_grammar() + local V = lpeg.V + local gen = generic_grammar_elts() + + return lpeg.P { + "EXPR"; + EXPR = gen.ws ^ 0 * V("ELT") ^ 0 * gen.ws ^ 0, + ELT = V("ARRAY") + V("DICT") + V("ATOM"), + ATOM = gen.ws ^ 0 * (gen.comment + gen.boolean + gen.ref + + gen.number + V("STRING") + gen.id) * gen.ws ^ 0, + DICT = "<<" * gen.ws ^ 0 * lpeg.Cf(lpeg.Ct("") * V("KV_PAIR") ^ 0, rawset) * gen.ws ^ 0 * ">>", + KV_PAIR = lpeg.Cg(gen.id * gen.ws ^ 0 * V("ELT") * gen.ws ^ 0), + ARRAY = "[" * gen.ws ^ 0 * lpeg.Ct(V("ELT") ^ 0) * gen.ws ^ 0 * "]", + STRING = lpeg.P { gen.str + gen.hexstr }, + } +end + +-- Graphic state in PDF +local function gen_graphics_unary() + local P = lpeg.P + local S = lpeg.S + + return P("q") + P("Q") + P("h") + + S("WSsFfBb") * P("*") ^ 0 + P("n") + +end +local function gen_graphics_binary() + local P = lpeg.P + local S = lpeg.S + + return S("gGwJjMi") + + P("M") + P("ri") + P("gs") + + P("CS") + P("cs") + P("sh") +end +local function gen_graphics_ternary() + local P = lpeg.P + local S = lpeg.S + + return P("d") + P("m") + S("lm") +end +local function gen_graphics_nary() + local P = lpeg.P + local S = lpeg.S + + return P("SC") + P("sc") + P("SCN") + P("scn") + P("k") + P("K") + P("re") + S("cvy") + + P("RG") + P("rg") +end + +-- Generates a grammar to parse text blocks (between BT and ET) +local function gen_text_grammar() + local V = lpeg.V + local P = lpeg.P + local C = lpeg.C + local gen = generic_grammar_elts() + + local empty = "" + local unary_ops = C("T*") / "\n" + + C(gen_graphics_unary()) / empty + local binary_ops = P("Tc") + P("Tw") + P("Tz") + P("TL") + P("Tr") + P("Ts") + + gen_graphics_binary() + local ternary_ops = P("TD") + P("Td") + gen_graphics_ternary() + local nary_op = P("Tm") + gen_graphics_nary() + local text_binary_op = P("Tj") + P("TJ") + P("'") + local text_quote_op = P('"') + local font_op = P("Tf") + + return lpeg.P { + "EXPR"; + EXPR = gen.ws ^ 0 * lpeg.Ct(V("COMMAND") ^ 0), + COMMAND = (V("UNARY") + V("BINARY") + V("TERNARY") + V("NARY") + V("TEXT") + + V("FONT") + gen.comment) * gen.ws ^ 0, + UNARY = unary_ops, + BINARY = V("ARG") / empty * gen.ws ^ 1 * binary_ops, + TERNARY = V("ARG") / empty * gen.ws ^ 1 * V("ARG") / empty * gen.ws ^ 1 * ternary_ops, + NARY = (gen.number / 0 * gen.ws ^ 1) ^ 1 * (gen.id / empty * gen.ws ^ 0) ^ -1 * nary_op, + ARG = V("ARRAY") + V("DICT") + V("ATOM"), + ATOM = (gen.comment + gen.boolean + gen.ref + + gen.number + V("STRING") + gen.id), + DICT = "<<" * gen.ws ^ 0 * lpeg.Cf(lpeg.Ct("") * V("KV_PAIR") ^ 0, rawset) * gen.ws ^ 0 * ">>", + KV_PAIR = lpeg.Cg(gen.id * gen.ws ^ 0 * V("ARG") * gen.ws ^ 0), + ARRAY = "[" * gen.ws ^ 0 * lpeg.Ct(V("ARG") ^ 0) * gen.ws ^ 0 * "]", + STRING = lpeg.P { gen.str + gen.hexstr }, + TEXT = (V("TEXT_ARG") * gen.ws ^ 1 * text_binary_op) + + (V("ARG") / 0 * gen.ws ^ 1 * V("ARG") / 0 * gen.ws ^ 1 * V("TEXT_ARG") * gen.ws ^ 1 * text_quote_op), + FONT = (V("FONT_ARG") * gen.ws ^ 1 * (gen.number / 0) * gen.ws ^ 1 * font_op), + FONT_ARG = lpeg.Ct(lpeg.Cc("%font%") * gen.id), + TEXT_ARG = lpeg.Ct(V("STRING")) + V("TEXT_ARRAY"), + TEXT_ARRAY = "[" * + lpeg.Ct(((gen.ws ^ 0 * (gen.ws ^ 0 * (gen.number / 0) ^ 0 * gen.ws ^ 0 * (gen.str + gen.hexstr))) ^ 1)) * gen.ws ^ 0 * "]", + } +end + + +-- Call 
immediately on require +compile_tries() +config_module() +pdf_outer_grammar = gen_outer_grammar() +pdf_text_grammar = gen_text_grammar() + +local function extract_text_data(specific) + return nil -- NYI +end + +-- Generates index for major/minor pair +local function obj_ref(major, minor) + return major * 10.0 + 1.0 / (minor + 1.0) +end + +-- Return indirect object reference (if needed) +local function maybe_dereference_object(elt, pdf, task) + if type(elt) == 'table' and elt[1] == '%REF%' then + local ref = obj_ref(elt[2], elt[3]) + + if pdf.ref[ref] then + -- No recursion! + return pdf.ref[ref] + else + lua_util.debugm(N, task, 'cannot dereference %s:%s -> %s, no object', + elt[2], elt[3], obj_ref(elt[2], elt[3])) + return nil + end + end + + return elt +end + +-- Apply PDF stream filter +local function apply_pdf_filter(input, filt) + if filt == 'FlateDecode' then + return rspamd_util.inflate(input, config.max_extraction_size) + end + + return nil +end + +-- Conditionally apply a pipeline of stream filters and return uncompressed data +local function maybe_apply_filter(dict, data, pdf, task) + local uncompressed = data + + if dict.Filter then + local filt = dict.Filter + if type(filt) == 'string' then + filt = { filt } + end + + if dict.DecodeParms then + local decode_params = maybe_dereference_object(dict.DecodeParms, pdf, task) + + if type(decode_params) == 'table' then + if decode_params.Predictor then + return nil, 'predictor exists' + end + end + end + + for _, f in ipairs(filt) do + uncompressed = apply_pdf_filter(uncompressed, f) + + if not uncompressed then + break + end + end + end + + return uncompressed, nil +end + +-- Conditionally extract stream data from object and attach it as obj.uncompressed +local function maybe_extract_object_stream(obj, pdf, task) + if pdf.encrypted then + -- TODO add decryption some day + return nil + end + local dict = obj.dict + if dict.Length and type(obj.stream) == 'table' then + local len = math.min(obj.stream.len, + tonumber(maybe_dereference_object(dict.Length, pdf, task)) or 0) + if len > 0 then + local real_stream = obj.stream.data:span(1, len) + + local uncompressed, filter_err = maybe_apply_filter(dict, real_stream, pdf, task) + + if uncompressed then + obj.uncompressed = uncompressed + lua_util.debugm(N, task, 'extracted object %s:%s: (%s -> %s)', + obj.major, obj.minor, len, uncompressed:len()) + return obj.uncompressed + else + lua_util.debugm(N, task, 'cannot extract object %s:%s; len = %s; filter = %s: %s', + obj.major, obj.minor, len, dict.Filter, filter_err) + end + else + lua_util.debugm(N, task, 'cannot extract object %s:%s; len = %s', + obj.major, obj.minor, len) + end + end +end + +local function parse_object_grammar(obj, task, pdf) + -- Parse grammar + local obj_dict_span + if obj.stream then + obj_dict_span = obj.data:span(1, obj.stream.start - obj.start) + else + obj_dict_span = obj.data + end + + if obj_dict_span:len() < config.max_processing_size then + local ret, obj_or_err = pcall(pdf_outer_grammar.match, pdf_outer_grammar, obj_dict_span) + + if ret then + if obj.stream then + if type(obj_or_err) == 'table' then + obj.dict = obj_or_err + else + obj.dict = {} + end + + lua_util.debugm(N, task, 'stream object %s:%s is parsed to: %s', + obj.major, obj.minor, obj_or_err) + else + -- Direct object + if type(obj_or_err) == 'table' then + obj.dict = obj_or_err + obj.uncompressed = obj_or_err + lua_util.debugm(N, task, 'direct object %s:%s is parsed to: %s', + obj.major, obj.minor, obj_or_err) + pdf.ref[obj_ref(obj.major, 
obj.minor)] = obj + else + lua_util.debugm(N, task, 'direct object %s:%s is parsed to raw data: %s', + obj.major, obj.minor, obj_or_err) + pdf.ref[obj_ref(obj.major, obj.minor)] = obj_or_err + obj.dict = {} + obj.uncompressed = obj_or_err + end + end + else + lua_util.debugm(N, task, 'object %s:%s cannot be parsed: %s', + obj.major, obj.minor, obj_or_err) + end + else + lua_util.debugm(N, task, 'object %s:%s cannot be parsed: too large %s', + obj.major, obj.minor, obj_dict_span:len()) + end +end + +-- Extracts font data and process /ToUnicode mappings +-- NYI in fact as cmap is ridiculously stupid and complicated +--[[ +local function process_font(task, pdf, font, fname) + local dict = font + if font.dict then + dict = font.dict + end + + if type(dict) == 'table' and dict.ToUnicode then + local cmap = maybe_dereference_object(dict.ToUnicode, pdf, task) + + if cmap and cmap.dict then + maybe_extract_object_stream(cmap, pdf, task) + lua_util.debugm(N, task, 'found cmap for font %s: %s', + fname, cmap.uncompressed) + end + end +end +--]] + +-- Forward declaration +local process_dict + +-- This function processes javascript string and returns JS hash and JS rspamd_text +local function process_javascript(task, pdf, js, obj) + local rspamd_cryptobox_hash = require "rspamd_cryptobox_hash" + if type(js) == 'string' then + js = rspamd_text.fromstring(js):oneline() + elseif type(js) == 'userdata' then + js = js:oneline() + else + return nil + end + + local hash = rspamd_cryptobox_hash.create(js) + local bin_hash = hash:bin() + + if not pdf.scripts then + pdf.scripts = {} + end + + if pdf.scripts[bin_hash] then + -- Duplicate + return pdf.scripts[bin_hash] + end + + local njs = { + data = js, + hash = hash:hex(), + bin_hash = bin_hash, + object = obj, + } + pdf.scripts[bin_hash] = njs + return njs +end + +-- Extract interesting stuff from /Action, e.g. javascript +local function process_action(task, pdf, obj) + if not (obj.js or obj.launch) and (obj.dict and obj.dict.JS) then + local js = maybe_dereference_object(obj.dict.JS, pdf, task) + + if js then + if type(js) == 'table' then + local extracted_js = maybe_extract_object_stream(js, pdf, task) + + if not extracted_js then + lua_util.debugm(N, task, 'invalid type for JavaScript from %s:%s: %s', + obj.major, obj.minor, js) + else + js = extracted_js + end + end + + js = process_javascript(task, pdf, js, obj) + if js then + obj.js = js + lua_util.debugm(N, task, 'extracted javascript from %s:%s: %s', + obj.major, obj.minor, obj.js.data) + else + lua_util.debugm(N, task, 'invalid type for JavaScript from %s:%s: %s', + obj.major, obj.minor, js) + end + elseif obj.dict.F then + local launch = maybe_dereference_object(obj.dict.F, pdf, task) + + if launch then + if type(launch) == 'string' then + obj.launch = rspamd_text.fromstring(launch):exclude_chars('%n%c') + lua_util.debugm(N, task, 'extracted launch from %s:%s: %s', + obj.major, obj.minor, obj.launch) + elseif type(launch) == 'userdata' then + obj.launch = launch:exclude_chars('%n%c') + lua_util.debugm(N, task, 'extracted launch from %s:%s: %s', + obj.major, obj.minor, obj.launch) + else + lua_util.debugm(N, task, 'invalid type for launch from %s:%s: %s', + obj.major, obj.minor, launch) + end + end + else + + lua_util.debugm(N, task, 'no JS attribute in action %s:%s', + obj.major, obj.minor) + end + end +end + +-- Extract interesting stuff from /Catalog, e.g. 
javascript in /OpenAction +local function process_catalog(task, pdf, obj) + if obj.dict then + if obj.dict.OpenAction then + local action = maybe_dereference_object(obj.dict.OpenAction, pdf, task) + + if action and type(action) == 'table' then + -- This also processes action js (if not already processed) + process_dict(task, pdf, action, action.dict) + if action.js then + lua_util.debugm(N, task, 'found openaction JS in %s:%s: %s', + obj.major, obj.minor, action.js) + pdf.openaction = action.js + action.js.object = obj + elseif action.launch then + lua_util.debugm(N, task, 'found openaction launch in %s:%s: %s', + obj.major, obj.minor, action.launch) + pdf.launch = action.launch + else + lua_util.debugm(N, task, 'no JS in openaction %s:%s: %s', + obj.major, obj.minor, action) + end + else + lua_util.debugm(N, task, 'cannot find openaction %s:%s: %s -> %s', + obj.major, obj.minor, obj.dict.OpenAction, action) + end + else + lua_util.debugm(N, task, 'no openaction in catalog %s:%s', + obj.major, obj.minor) + end + end +end + +local function process_xref(task, pdf, obj) + if obj.dict then + if obj.dict.Encrypt then + local encrypt = maybe_dereference_object(obj.dict.Encrypt, pdf, task) + lua_util.debugm(N, task, 'found encrypt: %s in xref object %s:%s', + encrypt, obj.major, obj.minor) + pdf.encrypted = true + end + end +end + +process_dict = function(task, pdf, obj, dict) + if not obj.type and type(dict) == 'table' then + if dict.Type and type(dict.Type) == 'string' then + -- Common stuff + obj.type = dict.Type + end + + if not obj.type then + + if obj.dict.S and obj.dict.JS then + obj.type = 'Javascript' + lua_util.debugm(N, task, 'implicit type for JavaScript object %s:%s', + obj.major, obj.minor) + else + lua_util.debugm(N, task, 'no type for %s:%s', + obj.major, obj.minor) + return + end + end + + lua_util.debugm(N, task, 'processed stream dictionary for object %s:%s -> %s', + obj.major, obj.minor, obj.type) + local contents = dict.Contents + if contents and type(contents) == 'table' then + if contents[1] == '%REF%' then + -- Single reference + contents = { contents } + end + obj.contents = {} + + for _, c in ipairs(contents) do + local cobj = maybe_dereference_object(c, pdf, task) + if cobj and type(cobj) == 'table' then + obj.contents[#obj.contents + 1] = cobj + cobj.parent = obj + cobj.type = 'content' + end + end + + lua_util.debugm(N, task, 'found content objects for %s:%s -> %s', + obj.major, obj.minor, #obj.contents) + end + + local resources = dict.Resources + if resources and type(resources) == 'table' then + local res_ref = maybe_dereference_object(resources, pdf, task) + + if type(res_ref) ~= 'table' then + lua_util.debugm(N, task, 'cannot parse resources from pdf: %s', + resources) + obj.resources = {} + elseif res_ref.dict then + obj.resources = res_ref.dict + else + obj.resources = {} + end + else + -- Fucking pdf: we need to inherit from parent + resources = {} + if dict.Parent then + local parent = maybe_dereference_object(dict.Parent, pdf, task) + + if parent and type(parent) == 'table' and parent.dict then + if parent.resources then + lua_util.debugm(N, task, 'propagated resources from %s:%s to %s:%s', + parent.major, parent.minor, obj.major, obj.minor) + resources = parent.resources + end + end + end + + obj.resources = resources + end + + + + --[[Disabled fonts extraction + local fonts = obj.resources.Font + if fonts and type(fonts) == 'table' then + obj.fonts = {} + for k,v in pairs(fonts) do + obj.fonts[k] = maybe_dereference_object(v, pdf, task) + + if obj.fonts[k] 
then + local font = obj.fonts[k] + + if config.text_extraction then + process_font(task, pdf, font, k) + lua_util.debugm(N, task, 'found font "%s" for object %s:%s -> %s', + k, obj.major, obj.minor, font) + end + end + end + end + ]] + + lua_util.debugm(N, task, 'found resources for object %s:%s (%s): %s', + obj.major, obj.minor, obj.type, obj.resources) + + if obj.type == 'Action' then + process_action(task, pdf, obj) + elseif obj.type == 'Catalog' then + process_catalog(task, pdf, obj) + elseif obj.type == 'XRef' then + -- XRef stream instead of trailer from PDF 1.5 (thanks Adobe) + process_xref(task, pdf, obj) + elseif obj.type == 'Javascript' then + local js = maybe_dereference_object(obj.dict.JS, pdf, task) + + if js then + if type(js) == 'table' then + local extracted_js = maybe_extract_object_stream(js, pdf, task) + + if not extracted_js then + lua_util.debugm(N, task, 'invalid type for JavaScript from %s:%s: %s', + obj.major, obj.minor, js) + else + js = extracted_js + end + end + + js = process_javascript(task, pdf, js, obj) + if js then + obj.js = js + lua_util.debugm(N, task, 'extracted javascript from %s:%s: %s', + obj.major, obj.minor, obj.js.data) + else + lua_util.debugm(N, task, 'invalid type for JavaScript from %s:%s: %s', + obj.major, obj.minor, js) + end + end + end + end -- Already processed dict (obj.type is not empty) +end + +-- This function is intended to unpack objects from ObjStm crappy structure +local compound_obj_grammar +local function compound_obj_grammar_gen() + if not compound_obj_grammar then + local gen = generic_grammar_elts() + compound_obj_grammar = gen.ws ^ 0 * (gen.comment * gen.ws ^ 1) ^ 0 * + lpeg.Ct(lpeg.Ct(gen.number * gen.ws ^ 1 * gen.number * gen.ws ^ 0) ^ 1) + end + + return compound_obj_grammar +end +local function pdf_compound_object_unpack(_, uncompressed, pdf, task, first) + -- First, we need to parse data line by line likely to find a line + -- that consists of pairs of numbers + compound_obj_grammar_gen() + local elts = compound_obj_grammar:match(uncompressed) + if elts and #elts > 0 then + lua_util.debugm(N, task, 'compound elts (chunk length %s): %s', + #uncompressed, elts) + + for i, pair in ipairs(elts) do + local obj_number, offset = pair[1], pair[2] + + offset = offset + first + if offset < #uncompressed then + local span_len + if i == #elts then + span_len = #uncompressed - offset + else + span_len = (elts[i + 1][2] + first) - offset + end + + if span_len > 0 and offset + span_len <= #uncompressed then + local obj = { + major = obj_number, + minor = 0, -- Implicit + data = uncompressed:span(offset + 1, span_len), + ref = obj_ref(obj_number, 0) + } + parse_object_grammar(obj, task, pdf) + + if obj.dict then + pdf.objects[#pdf.objects + 1] = obj + end + else + lua_util.debugm(N, task, 'invalid span_len for compound object %s:%s; offset = %s, len = %s', + pair[1], pair[2], offset + span_len, #uncompressed) + end + end + end + end +end + +-- PDF 1.5 ObjStmt +local function extract_pdf_compound_objects(task, pdf) + for i, obj in ipairs(pdf.objects or {}) do + if i > 0 and i % 100 == 0 then + local now = rspamd_util.get_ticks() + + if now >= pdf.end_timestamp then + pdf.timeout_processing = now - pdf.start_timestamp + + lua_util.debugm(N, task, 'pdf: timeout processing compound objects after spending %s seconds, ' .. 
+ '%s elements processed', + pdf.timeout_processing, i) + break + end + end + if obj.stream and obj.dict and type(obj.dict) == 'table' then + local t = obj.dict.Type + if t and t == 'ObjStm' then + -- We are in troubles sir... + local nobjs = tonumber(maybe_dereference_object(obj.dict.N, pdf, task)) + local first = tonumber(maybe_dereference_object(obj.dict.First, pdf, task)) + + if nobjs and first then + --local extend = maybe_dereference_object(obj.dict.Extends, pdf, task) + lua_util.debugm(N, task, 'extract ObjStm with %s objects (%s first) %s extend', + nobjs, first, obj.dict.Extends) + + local uncompressed = maybe_extract_object_stream(obj, pdf, task) + + if uncompressed then + pdf_compound_object_unpack(obj, uncompressed, pdf, task, first) + end + else + lua_util.debugm(N, task, 'ObjStm object %s:%s has bad dict: %s', + obj.major, obj.minor, obj.dict) + end + end + end + end +end + +-- This function arranges starts and ends of all objects and process them into initial +-- set of objects +local function extract_outer_objects(task, input, pdf) + local start_pos, end_pos = 1, 1 + local max_start_pos, max_end_pos + local obj_count = 0 + + max_start_pos = math.min(config.max_pdf_objects, #pdf.start_objects) + max_end_pos = math.min(config.max_pdf_objects, #pdf.end_objects) + lua_util.debugm(N, task, "pdf: extract objects from %s start positions and %s end positions", + max_start_pos, max_end_pos) + + while start_pos <= max_start_pos and end_pos <= max_end_pos do + local first = pdf.start_objects[start_pos] + local last = pdf.end_objects[end_pos] + + -- 7 is length of `endobj\n` + if first + 6 < last then + local len = last - first - 6 + + -- Also get the starting span and try to match it versus obj re to get numbers + local obj_line_potential = first - 32 + if obj_line_potential < 1 then + obj_line_potential = 1 + end + local prev_obj_end = pdf.end_objects[end_pos - 1] + if end_pos > 1 and prev_obj_end >= obj_line_potential and prev_obj_end < first then + obj_line_potential = prev_obj_end + 1 + end + + local obj_line_span = input:span(obj_line_potential, first - obj_line_potential + 1) + local matches = object_re:search(obj_line_span, true, true) + + if matches and matches[1] then + local nobj = { + start = first, + len = len, + data = input:span(first, len), + major = tonumber(matches[1][2]), + minor = tonumber(matches[1][3]), + } + pdf.objects[obj_count + 1] = nobj + if nobj.major and nobj.minor then + -- Add reference + local ref = obj_ref(nobj.major, nobj.minor) + nobj.ref = ref -- Our internal reference + pdf.ref[ref] = nobj + end + end + + obj_count = obj_count + 1 + start_pos = start_pos + 1 + end_pos = end_pos + 1 + elseif first > last then + end_pos = end_pos + 1 + else + start_pos = start_pos + 1 + end_pos = end_pos + 1 + end + end +end + +-- This function attaches streams to objects and processes outer pdf grammar +local function attach_pdf_streams(task, input, pdf) + if pdf.start_streams and pdf.end_streams then + local start_pos, end_pos = 1, 1 + local max_start_pos, max_end_pos + + max_start_pos = math.min(config.max_pdf_objects, #pdf.start_streams) + max_end_pos = math.min(config.max_pdf_objects, #pdf.end_streams) + + for _, obj in ipairs(pdf.objects) do + while start_pos <= max_start_pos and end_pos <= max_end_pos do + local first = pdf.start_streams[start_pos] + local last = pdf.end_streams[end_pos] + last = last - 10 -- Exclude endstream\n pattern + lua_util.debugm(N, task, "start: %s, end: %s; obj: %s-%s", + first, last, obj.start, obj.start + obj.len) + if first > 
obj.start and last < obj.start + obj.len and last > first then + -- In case if we have fake endstream :( + while pdf.end_streams[end_pos + 1] and pdf.end_streams[end_pos + 1] < obj.start + obj.len do + end_pos = end_pos + 1 + last = pdf.end_streams[end_pos] + end + -- Strip the first \n + while first < last do + local chr = input:byte(first) + if chr ~= 13 and chr ~= 10 then + break + end + first = first + 1 + end + local len = last - first + obj.stream = { + start = first, + len = len, + data = input:span(first, len) + } + start_pos = start_pos + 1 + end_pos = end_pos + 1 + break + elseif first < obj.start then + start_pos = start_pos + 1 + elseif last > obj.start + obj.len then + -- Not this object + break + else + start_pos = start_pos + 1 + end_pos = end_pos + 1 + end + end + if obj.stream then + lua_util.debugm(N, task, 'found object %s:%s %s start %s len, %s stream start, %s stream length', + obj.major, obj.minor, obj.start, obj.len, obj.stream.start, obj.stream.len) + else + lua_util.debugm(N, task, 'found object %s:%s %s start %s len, no stream', + obj.major, obj.minor, obj.start, obj.len) + end + end + end +end + +-- Processes PDF objects: extracts streams, object numbers, process outer grammar, +-- augment object types +local function postprocess_pdf_objects(task, input, pdf) + pdf.objects = {} -- objects table + pdf.ref = {} -- references table + extract_outer_objects(task, input, pdf) + + -- Now we have objects and we need to attach streams that are in bounds + attach_pdf_streams(task, input, pdf) + -- Parse grammar for outer objects + for i, obj in ipairs(pdf.objects) do + if i > 0 and i % 100 == 0 then + local now = rspamd_util.get_ticks() + + if now >= pdf.end_timestamp then + pdf.timeout_processing = now - pdf.start_timestamp + + lua_util.debugm(N, task, 'pdf: timeout processing grammars after spending %s seconds, ' .. + '%s elements processed', + pdf.timeout_processing, i) + break + end + end + if obj.ref then + parse_object_grammar(obj, task, pdf) + + -- Special early handling + if obj.dict and obj.dict.Type and obj.dict.Type == 'XRef' then + process_xref(task, pdf, obj) + end + end + end + + if not pdf.timeout_processing then + extract_pdf_compound_objects(task, pdf) + else + -- ENOTIME + return + end + + -- Now we might probably have all objects being processed + for i, obj in ipairs(pdf.objects) do + if obj.dict then + -- Types processing + if i > 0 and i % 100 == 0 then + local now = rspamd_util.get_ticks() + + if now >= pdf.end_timestamp then + pdf.timeout_processing = now - pdf.start_timestamp + + lua_util.debugm(N, task, 'pdf: timeout processing dicts after spending %s seconds, ' .. + '%s elements processed', + pdf.timeout_processing, i) + break + end + end + process_dict(task, pdf, obj, obj.dict) + end + end +end + +local function offsets_to_blocks(starts, ends, out) + local start_pos, end_pos = 1, 1 + + while start_pos <= #starts and end_pos <= #ends do + local first = starts[start_pos] + local last = ends[end_pos] + + if first < last then + local len = last - first + out[#out + 1] = { + start = first, + len = len, + } + start_pos = start_pos + 1 + end_pos = end_pos + 1 + elseif first > last then + end_pos = end_pos + 1 + else + -- Not ordered properly! 
+ break + end + end +end + +local function search_text(task, pdf) + for _, obj in ipairs(pdf.objects) do + if obj.type == 'Page' and obj.contents then + local text = {} + for _, tobj in ipairs(obj.contents) do + maybe_extract_object_stream(tobj, pdf, task) + local matches = pdf_text_trie:match(tobj.uncompressed or '') + if matches then + local text_blocks = {} + local starts = {} + local ends = {} + + for npat, matched_positions in pairs(matches) do + if npat == 1 then + for _, pos in ipairs(matched_positions) do + starts[#starts + 1] = pos + end + else + for _, pos in ipairs(matched_positions) do + ends[#ends + 1] = pos + end + end + end + + offsets_to_blocks(starts, ends, text_blocks) + for _, bl in ipairs(text_blocks) do + if bl.len > 2 then + -- To remove \s+ET\b pattern (it can leave trailing space or not but it doesn't matter) + bl.len = bl.len - 2 + end + + bl.data = tobj.uncompressed:span(bl.start, bl.len) + --lua_util.debugm(N, task, 'extracted text from object %s:%s: %s', + -- tobj.major, tobj.minor, bl.data) + + if bl.len < config.max_processing_size then + local ret, obj_or_err = pcall(pdf_text_grammar.match, pdf_text_grammar, + bl.data) + + if ret then + text[#text + 1] = obj_or_err + lua_util.debugm(N, task, 'attached %s from content object %s:%s to %s:%s', + obj_or_err, tobj.major, tobj.minor, obj.major, obj.minor) + else + lua_util.debugm(N, task, 'object %s:%s cannot be parsed: %s', + obj.major, obj.minor, obj_or_err) + end + + end + end + end + end + + -- Join all text data together + if #text > 0 then + obj.text = rspamd_text.fromtable(text) + lua_util.debugm(N, task, 'object %s:%s is parsed to: %s', + obj.major, obj.minor, obj.text) + end + end + end +end + +-- This function searches objects for `/URI` key and parses it's content +local function search_urls(task, pdf, mpart) + local function recursive_object_traverse(obj, dict, rec) + if rec > 10 then + lua_util.debugm(N, task, 'object %s:%s recurses too much', + obj.major, obj.minor) + return + end + + for k, v in pairs(dict) do + if type(v) == 'table' then + recursive_object_traverse(obj, v, rec + 1) + elseif k == 'URI' then + v = maybe_dereference_object(v, pdf, task) + if type(v) == 'string' then + local url = rspamd_url.create(task:get_mempool(), v, { 'content' }) + + if url then + lua_util.debugm(N, task, 'found url %s in object %s:%s', + v, obj.major, obj.minor) + task:inject_url(url, mpart) + end + end + end + end + end + + for _, obj in ipairs(pdf.objects) do + if obj.dict and type(obj.dict) == 'table' then + recursive_object_traverse(obj, obj.dict, 0) + end + end +end + +local function process_pdf(input, mpart, task) + + if not config.enabled then + -- Skip processing + return {} + end + + local matches = pdf_trie:match(input) + + if matches then + local start_ts = rspamd_util.get_ticks() + -- Temp object used to share data between pdf extraction methods + local pdf_object = { + tag = 'pdf', + extract_text = extract_text_data, + start_timestamp = start_ts, + end_timestamp = start_ts + config.pdf_process_timeout, + } + -- Output object that excludes all internal stuff + local pdf_output = lua_util.shallowcopy(pdf_object) + local grouped_processors = {} + for npat, matched_positions in pairs(matches) do + local index = pdf_indexes[npat] + + local proc_key, loc_npat = index[1], index[4] + + if not grouped_processors[proc_key] then + grouped_processors[proc_key] = { + processor_func = processors[proc_key], + offsets = {}, + } + end + local proc = grouped_processors[proc_key] + -- Fill offsets + for _, pos in 
ipairs(matched_positions) do + proc.offsets[#proc.offsets + 1] = { pos, loc_npat } + end + end + + for name, processor in pairs(grouped_processors) do + -- Sort by offset + lua_util.debugm(N, task, "pdf: process group %s with %s matches", + name, #processor.offsets) + table.sort(processor.offsets, function(e1, e2) + return e1[1] < e2[1] + end) + processor.processor_func(input, task, processor.offsets, pdf_object, pdf_output) + end + + pdf_output.flags = {} + + if pdf_object.start_objects and pdf_object.end_objects then + if #pdf_object.start_objects > config.max_pdf_objects then + pdf_output.many_objects = #pdf_object.start_objects + -- Trim + end + + -- Postprocess objects + postprocess_pdf_objects(task, input, pdf_object) + if config.text_extraction then + search_text(task, pdf_object, pdf_output) + end + if config.url_extraction then + search_urls(task, pdf_object, mpart, pdf_output) + end + + if config.js_fuzzy and pdf_object.scripts then + pdf_output.fuzzy_hashes = {} + if config.openaction_fuzzy_only then + -- OpenAction only + if pdf_object.openaction and pdf_object.openaction.bin_hash then + if config.min_js_fuzzy and #pdf_object.openaction.data >= config.min_js_fuzzy then + lua_util.debugm(N, task, "pdf: add fuzzy hash from openaction: %s; size = %s; object: %s:%s", + pdf_object.openaction.hash, + #pdf_object.openaction.data, + pdf_object.openaction.object.major, pdf_object.openaction.object.minor) + table.insert(pdf_output.fuzzy_hashes, pdf_object.openaction.bin_hash) + else + lua_util.debugm(N, task, "pdf: skip fuzzy hash from JavaScript: %s, too short: %s", + pdf_object.openaction.hash, #pdf_object.openaction.data) + end + end + else + -- All hashes + for h, sc in pairs(pdf_object.scripts) do + if config.min_js_fuzzy and #sc.data >= config.min_js_fuzzy then + lua_util.debugm(N, task, "pdf: add fuzzy hash from JavaScript: %s; size = %s; object: %s:%s", + sc.hash, + #sc.data, + sc.object.major, sc.object.minor) + table.insert(pdf_output.fuzzy_hashes, h) + else + lua_util.debugm(N, task, "pdf: skip fuzzy hash from JavaScript: %s, too short: %s", + sc.hash, #sc.data) + end + end + + end + end + else + pdf_output.flags.no_objects = true + end + + -- Propagate from object to output + if pdf_object.encrypted then + pdf_output.encrypted = true + end + if pdf_object.scripts then + pdf_output.scripts = true + end + + return pdf_output + end +end + +-- Processes the PDF trailer +processors.trailer = function(input, task, positions, pdf_object, pdf_output) + local last_pos = positions[#positions] + + lua_util.debugm(N, task, 'pdf: process trailer at position %s (%s total length)', + last_pos, #input) + + if last_pos[1] > config.max_pdf_trailer then + pdf_output.long_trailer = #input - last_pos[1] + return + end + + local last_span = input:span(last_pos[1]) + local lines_checked = 0 + for line in last_span:lines(true) do + if line:find('/Encrypt ') then + lua_util.debugm(N, task, "pdf: found encrypted line in trailer: %s", + line) + pdf_output.encrypted = true + pdf_object.encrypted = true + break + end + lines_checked = lines_checked + 1 + + if lines_checked > config.max_pdf_trailer_lines then + lua_util.debugm(N, task, "pdf: trailer has too many lines, stop checking") + pdf_output.long_trailer = #input - last_pos[1] + break + end + end +end + +processors.suspicious = function(input, task, positions, pdf_object, pdf_output) + local suspicious_factor = 0.0 + local nexec = 0 + local nencoded = 0 + local close_encoded = 0 + local last_encoded + for _, match in ipairs(positions) do + if 
match[2] == 1 then + -- netsh + suspicious_factor = suspicious_factor + 0.5 + elseif match[2] == 2 then + nexec = nexec + 1 + elseif match[2] == 3 then + local enc_data = input:sub(match[1] - 2, match[1] - 1) + local legal_escape = false + + if enc_data then + enc_data = enc_data:strtoul() + + if enc_data then + -- Legit encode cases are non printable characters (e.g. spaces) + if enc_data < 0x21 or enc_data >= 0x7f then + legal_escape = true + end + end + end + + if not legal_escape then + nencoded = nencoded + 1 + + if last_encoded then + if match[1] - last_encoded < 8 then + -- likely consecutive encoded chars, increase factor + close_encoded = close_encoded + 1 + end + end + last_encoded = match[1] + + end + end + end + + if nencoded > 10 then + suspicious_factor = suspicious_factor + nencoded / 10 + end + if nexec > 1 then + suspicious_factor = suspicious_factor + nexec / 2.0 + end + if close_encoded > 4 and nencoded - close_encoded < 5 then + -- Too many close encoded comparing to the total number of encoded characters + suspicious_factor = suspicious_factor + 0.5 + end + + lua_util.debugm(N, task, 'pdf: found a suspicious patterns: %s exec, %s encoded (%s close), ' .. + '%s final factor', + nexec, nencoded, close_encoded, suspicious_factor) + + if suspicious_factor > 1.0 then + suspicious_factor = 1.0 + end + + pdf_output.suspicious = suspicious_factor +end + +local function generic_table_inserter(positions, pdf_object, output_key) + if not pdf_object[output_key] then + pdf_object[output_key] = {} + end + local shift = #pdf_object[output_key] + for i, pos in ipairs(positions) do + pdf_object[output_key][i + shift] = pos[1] + end +end + +processors.start_object = function(_, task, positions, pdf_object) + generic_table_inserter(positions, pdf_object, 'start_objects') +end + +processors.end_object = function(_, task, positions, pdf_object) + generic_table_inserter(positions, pdf_object, 'end_objects') +end + +processors.start_stream = function(_, task, positions, pdf_object) + generic_table_inserter(positions, pdf_object, 'start_streams') +end + +processors.end_stream = function(_, task, positions, pdf_object) + generic_table_inserter(positions, pdf_object, 'end_streams') +end + +exports.process = process_pdf + +return exports
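Since process_pdf() builds a fairly large output table, a small consumer sketch may clarify which of its fields are intended for external use. It assumes the part has already been handled by lua_content (so part:get_specific() returns the pdf_output table assembled above); the pdf_verdict helper name and the 0.5 threshold are illustrative only.

local function pdf_verdict(part)
  local spec = part:get_specific()

  if type(spec) ~= 'table' or spec.tag ~= 'pdf' then
    return nil -- not a processed PDF part
  end

  if spec.encrypted then
    -- set from the trailer scan or from an /Encrypt entry in an XRef object
    return 'encrypted'
  elseif spec.scripts then
    -- embedded JavaScript was found; fuzzy hashes (if enabled) are in spec.fuzzy_hashes
    return 'javascript'
  elseif (spec.suspicious or 0) > 0.5 then
    -- heuristic factor from processors.suspicious, capped at 1.0
    return 'suspicious'
  elseif spec.flags and spec.flags.no_objects then
    return 'no-objects'
  end

  return 'plain'
end

Note that URLs found via /URI keys are not part of this table: search_urls() injects them directly into the task, so they surface through the normal URL rules.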
\ No newline at end of file diff --git a/lualib/lua_content/vcard.lua b/lualib/lua_content/vcard.lua new file mode 100644 index 0000000..ed14412 --- /dev/null +++ b/lualib/lua_content/vcard.lua @@ -0,0 +1,84 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local l = require 'lpeg' +local lua_util = require "lua_util" +local N = "lua_content" + +local vcard_grammar + +-- XXX: Currently it is a copy of ical grammar +local function gen_grammar() + if not vcard_grammar then + local wsp = l.S(" \t\v\f") + local crlf = (l.P "\r" ^ -1 * l.P "\n") + l.P "\r" + local eol = (crlf * #crlf) + (crlf - (crlf ^ -1 * wsp)) + local name = l.C((l.P(1) - (l.P ":")) ^ 1) / function(v) + return (v:gsub("[\n\r]+%s", "")) + end + local value = l.C((l.P(1) - eol) ^ 0) / function(v) + return (v:gsub("[\n\r]+%s", "")) + end + vcard_grammar = name * ":" * wsp ^ 0 * value * eol ^ -1 + end + + return vcard_grammar +end + +local exports = {} + +local function process_vcard(input, mpart, task) + local control = { n = '\n', r = '' } + local rspamd_url = require "rspamd_url" + local escaper = l.Ct((gen_grammar() / function(key, value) + value = value:gsub("\\(.)", control) + key = key:lower() + local local_urls = rspamd_url.all(task:get_mempool(), value) + + if local_urls and #local_urls > 0 then + for _, u in ipairs(local_urls) do + lua_util.debugm(N, task, 'vcard: found URL in vcard %s', + tostring(u)) + task:inject_url(u, mpart) + end + end + lua_util.debugm(N, task, 'vcard: vcard key %s = "%s"', + key, value) + return { key, value } + end) ^ 1) + + local elts = escaper:match(input) + + if not elts then + return nil + end + + return { + tag = 'vcard', + extract_text = function() + return nil + end, -- NYI + elts = elts + } +end + +--[[[ +-- @function vcard.process(input) +-- Returns all values from vcard as a plain text. Names are completely ignored. +--]] +exports.process = process_vcard + +return exports
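The value unescaping step is easy to miss because it is a single gsub shared by the ical and vcard processors: an escaped n becomes a real newline, an escaped r is dropped, and escapes with no entry in the control table (such as a backslash-comma) are left untouched. A standalone illustration, runnable in plain Lua with a made-up sample value:

local control = { n = '\n', r = '' }

local value = [[Example Street 1\, Floor 2\nExampleville\r]]
-- gsub with a table replacement keeps the whole match when the captured character has no entry
local unescaped = value:gsub("\\(.)", control)
print(unescaped)
-- Example Street 1\, Floor 2
-- Exampleville

One difference between the two modules: vcard.lua scans every value for URLs, while ical.lua restricts URL extraction to the keys listed in url_keys.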
\ No newline at end of file diff --git a/lualib/lua_dkim_tools.lua b/lualib/lua_dkim_tools.lua new file mode 100644 index 0000000..165ea8f --- /dev/null +++ b/lualib/lua_dkim_tools.lua @@ -0,0 +1,742 @@ +--[[ +Copyright (c) 2016, Andrew Lewis <nerf@judo.za.org> +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local exports = {} + +local E = {} +local lua_util = require "lua_util" +local rspamd_util = require "rspamd_util" +local logger = require "rspamd_logger" +local fun = require "fun" + +local function check_violation(N, task, domain) + -- Check for DKIM_REJECT + local sym_check = 'R_DKIM_REJECT' + + if N == 'arc' then + sym_check = 'ARC_REJECT' + end + if task:has_symbol(sym_check) then + local sym = task:get_symbol(sym_check)[1] + logger.infox(task, 'skip signing for %s: violation %s found: %s', + domain, sym_check, sym.options) + return false + end + + return true +end + +local function insert_or_update_prop(N, task, p, prop, origin, data) + if #p == 0 then + local k = {} + k[prop] = data + table.insert(p, k) + lua_util.debugm(N, task, 'add %s "%s" using %s', prop, data, origin) + else + for _, k in ipairs(p) do + if not k[prop] then + k[prop] = data + lua_util.debugm(N, task, 'set %s to "%s" using %s', prop, data, origin) + end + end + end +end + +local function get_mempool_selectors(N, task) + local p = {} + local key_var = "dkim_key" + local selector_var = "dkim_selector" + if N == "arc" then + key_var = "arc_key" + selector_var = "arc_selector" + end + + p.key = task:get_mempool():get_variable(key_var) + p.selector = task:get_mempool():get_variable(selector_var) + + if (not p.key or not p.selector) then + return false, {} + end + + lua_util.debugm(N, task, 'override selector and key to %s:%s', p.key, p.selector) + return true, p +end + +local function parse_dkim_http_headers(N, task, settings) + -- Configure headers + local headers = { + sign_header = settings.http_sign_header or "PerformDkimSign", + sign_on_reject_header = settings.http_sign_on_reject_header_header or 'SignOnAuthFailed', + domain_header = settings.http_domain_header or 'DkimDomain', + selector_header = settings.http_selector_header or 'DkimSelector', + key_header = settings.http_key_header or 'DkimPrivateKey' + } + + if task:get_request_header(headers.sign_header) then + local domain = task:get_request_header(headers.domain_header) + local selector = task:get_request_header(headers.selector_header) + local key = task:get_request_header(headers.key_header) + + if not (domain and selector and key) then + + logger.errx(task, 'missing required headers to sign email') + return false, {} + end + + -- Now check if we need to check the existing auth + local hdr = task:get_request_header(headers.sign_on_reject_header) + if not hdr or tostring(hdr) == '0' or tostring(hdr) == 'false' then + if not check_violation(N, task, domain, selector) then + return false, {} + end + end + + local p = {} + local k = { + domain = tostring(domain), + rawkey = tostring(key), + selector = 
tostring(selector), + } + table.insert(p, k) + return true, p + end + + lua_util.debugm(N, task, 'no sign header %s', headers.sign_header) + return false, {} +end + +local function prepare_dkim_signing(N, task, settings) + local is_local, is_sign_networks, is_authed + + if settings.use_http_headers then + local res, tbl = parse_dkim_http_headers(N, task, settings) + + if not res then + if not settings.allow_headers_fallback then + return res, {} + else + lua_util.debugm(N, task, 'failed to read http headers, fallback to normal schema') + end + else + return res, tbl + end + end + + if settings.sign_condition and type(settings.sign_condition) == 'function' then + -- Use sign condition only + local ret = settings.sign_condition(task) + + if not ret then + return false, {} + end + + if ret[1] then + return true, ret + else + return true, { ret } + end + end + + local auser = task:get_user() + local ip = task:get_from_ip() + + if ip and ip:is_local() then + is_local = true + end + + local has_pre_result = task:has_pre_result() + if has_pre_result then + local metric_action = task:get_metric_action() + + if metric_action == 'reject' or metric_action == 'drop' then + -- No need to sign what we are already rejecting/dropping + lua_util.debugm(N, task, 'task result is already %s, no need to sign', metric_action) + return false, {} + end + + if metric_action == 'soft reject' then + -- Same here, we are going to delay an email, signing is just a waste of time + lua_util.debugm(N, task, 'task result is %s, skip signing', metric_action) + return false, {} + end + + -- For spam actions, there is no clear distinction + if metric_action ~= 'no action' and type(settings.skip_spam_sign) == 'boolean' and settings.skip_spam_sign then + lua_util.debugm(N, task, 'task result is %s, no need to sign', metric_action) + return false, {} + end + end + + if settings.sign_authenticated and auser then + lua_util.debugm(N, task, 'user is authenticated') + is_authed = true + elseif (settings.sign_networks and settings.sign_networks:get_key(ip)) then + is_sign_networks = true + lua_util.debugm(N, task, 'mail is from address in sign_networks') + elseif settings.sign_local and is_local then + lua_util.debugm(N, task, 'mail is from local address') + elseif settings.sign_inbound and not is_local and not auser then + lua_util.debugm(N, task, 'mail was sent to us') + else + lua_util.debugm(N, task, 'mail is ineligible for signing') + return false, {} + end + + local efrom = task:get_from('smtp') + local empty_envelope = false + if #(((efrom or E)[1] or E).addr or '') == 0 then + if not settings.allow_envfrom_empty then + lua_util.debugm(N, task, 'empty envelope from not allowed') + return false, {} + else + empty_envelope = true + end + end + + local hfrom = task:get_from('mime') + if not settings.allow_hdrfrom_multiple and (hfrom or E)[2] then + lua_util.debugm(N, task, 'multiple header from not allowed') + return false, {} + end + + local eto = task:get_recipients(0) + + local dkim_domain + local hdom = ((hfrom or E)[1] or E).domain + local edom = ((efrom or E)[1] or E).domain + local tdom = ((eto or E)[1] or E).domain + local udom = string.match(auser or '', '.*@(.*)') + + local function get_dkim_domain(dtype) + if settings[dtype] == 'header' then + return hdom + elseif settings[dtype] == 'envelope' then + return edom + elseif settings[dtype] == 'auth' then + return udom + elseif settings[dtype] == 'recipient' then + return tdom + else + return settings[dtype]:lower() + end + end + + local function is_skip_sign() + return 
not (settings.sign_networks and is_sign_networks) and + not (settings.sign_authenticated and is_authed) and + not (settings.sign_local and is_local) + end + + if hdom then + hdom = hdom:lower() + end + if edom then + edom = edom:lower() + end + if udom then + udom = udom:lower() + end + if tdom then + tdom = tdom:lower() + end + + if settings.signing_table and (settings.key_table or settings.use_vault) then + -- OpenDKIM style + if is_skip_sign() then + lua_util.debugm(N, task, + 'skip signing: is_sign_network: %s, is_authed: %s, is_local: %s', + is_sign_networks, is_authed, is_local) + return false, {} + end + + if not hfrom or not hfrom[1] or not hfrom[1].addr then + lua_util.debugm(N, task, + 'signing_table: cannot get data when no header from is presented') + return false, {} + end + local sign_entry = settings.signing_table:get_key(hfrom[1].addr:lower()) + + if sign_entry then + -- Check opendkim style entries + lua_util.debugm(N, task, + 'signing_table: found entry for %s: %s', hfrom[1].addr, sign_entry) + if sign_entry == '%' then + sign_entry = hdom + end + + if settings.key_table then + -- Now search in key table + local key_entry = settings.key_table:get_key(sign_entry) + + if key_entry then + local parts = lua_util.str_split(key_entry, ':') + + if #parts == 2 then + -- domain + key + local selector = settings.selector + + if not selector then + logger.errx(task, 'no selector defined for sign_entry %s, key_entry %s', + sign_entry, key_entry) + return false, {} + end + + local res = { + selector = selector, + domain = parts[1]:gsub('%%', hdom) + } + + local st = parts[2]:sub(1, 2) + + if st:sub(1, 1) == '/' or st == './' or st == '..' then + res.key = parts[2]:gsub('%%', hdom) + lua_util.debugm(N, task, 'perform dkim signing for %s, selector=%s, domain=%s, key file=%s', + hdom, selector, res.domain, res.key) + else + res.rawkey = parts[2] -- No sanity check here + lua_util.debugm(N, task, 'perform dkim signing for %s, selector=%s, domain=%s, raw key used', + hdom, selector, res.domain) + end + + return true, { res } + elseif #parts == 3 then + -- domain, selector, key + local selector = parts[2] + + local res = { + selector = selector, + domain = parts[1]:gsub('%%', hdom) + } + + local st = parts[3]:sub(1, 2) + + if st:sub(1, 1) == '/' or st == './' or st == '..' 
then + res.key = parts[3]:gsub('%%', hdom) + lua_util.debugm(N, task, 'perform dkim signing for %s, selector=%s, domain=%s, key file=%s', + hdom, selector, res.domain, res.key) + else + res.rawkey = parts[3] -- No sanity check here + lua_util.debugm(N, task, 'perform dkim signing for %s, selector=%s, domain=%s, raw key used', + hdom, selector, res.domain) + end + + return true, { res } + else + logger.errx(task, 'invalid key entry for sign entry %s: %s; when signing %s domain', + sign_entry, key_entry, hdom) + return false, {} + end + elseif settings.use_vault then + -- Sign table is presented, the rest is covered by vault + lua_util.debugm(N, task, 'check vault for %s, by sign entry %s, key entry is missing', + hdom, sign_entry) + return true, { + domain = sign_entry, + vault = true + } + else + logger.errx(task, 'missing key entry for sign entry %s; when signing %s domain', + sign_entry, hdom) + return false, {} + end + else + logger.errx(task, 'cannot get key entry for signing entry %s, when signing %s domain', + sign_entry, hdom) + return false, {} + end + else + lua_util.debugm(N, task, + 'signing_table: no entry for %s', hfrom[1].addr) + return false, {} + end + else + if settings.use_domain_sign_networks and is_sign_networks then + dkim_domain = get_dkim_domain('use_domain_sign_networks') + lua_util.debugm(N, task, + 'sign_networks: use domain(%s) for signature: %s', + settings.use_domain_sign_networks, dkim_domain) + elseif settings.use_domain_sign_local and is_local then + dkim_domain = get_dkim_domain('use_domain_sign_local') + lua_util.debugm(N, task, 'local: use domain(%s) for signature: %s', + settings.use_domain_sign_local, dkim_domain) + elseif settings.use_domain_sign_inbound and not is_local and not auser then + dkim_domain = get_dkim_domain('use_domain_sign_inbound') + lua_util.debugm(N, task, 'inbound: use domain(%s) for signature: %s', + settings.use_domain_sign_inbound, dkim_domain) + elseif settings.use_domain_custom then + if type(settings.use_domain_custom) == 'string' then + -- Load custom function + local loadstring = loadstring or load + local ret, res_or_err = pcall(loadstring(settings.use_domain_custom)) + if ret then + if type(res_or_err) == 'function' then + settings.use_domain_custom = res_or_err + dkim_domain = settings.use_domain_custom(task) + lua_util.debugm(N, task, 'use custom domain for signing: %s', + dkim_domain) + else + logger.errx(task, 'cannot load dkim domain custom script: invalid type: %s, expected function', + type(res_or_err)) + settings.use_domain_custom = nil + end + else + logger.errx(task, 'cannot load dkim domain custom script: %s', res_or_err) + settings.use_domain_custom = nil + end + else + dkim_domain = settings.use_domain_custom(task) + lua_util.debugm(N, task, 'use custom domain for signing: %s', + dkim_domain) + end + else + dkim_domain = get_dkim_domain('use_domain') + lua_util.debugm(N, task, 'use domain(%s) for signature: %s', + settings.use_domain, dkim_domain) + end + end + + if not dkim_domain then + lua_util.debugm(N, task, 'could not extract dkim domain') + return false, {} + end + + if settings.use_esld then + dkim_domain = rspamd_util.get_tld(dkim_domain) + if hdom then + hdom = rspamd_util.get_tld(hdom) + end + if edom then + edom = rspamd_util.get_tld(edom) + end + end + + lua_util.debugm(N, task, 'final DKIM domain: %s', dkim_domain) + + -- Sanity checks + if edom and hdom and not settings.allow_hdrfrom_mismatch and hdom ~= edom then + if settings.allow_hdrfrom_mismatch_local and is_local then + lua_util.debugm(N, 
task, 'domain mismatch allowed for local IP: %1 != %2', hdom, edom) + elseif settings.allow_hdrfrom_mismatch_sign_networks and is_sign_networks then + lua_util.debugm(N, task, 'domain mismatch allowed for sign_networks: %1 != %2', hdom, edom) + else + if empty_envelope and hdom then + lua_util.debugm(N, task, 'domain mismatch allowed for empty envelope: %1 != %2', hdom, edom) + else + lua_util.debugm(N, task, 'domain mismatch not allowed: %1 != %2', hdom, edom) + return false, {} + end + end + end + + if auser and not settings.allow_username_mismatch then + if not udom then + lua_util.debugm(N, task, 'couldnt find domain in username') + return false, {} + end + if settings.use_esld then + udom = rspamd_util.get_tld(udom) + end + if udom ~= dkim_domain then + lua_util.debugm(N, task, 'user domain mismatch') + return false, {} + end + end + + local p = {} + + if settings.use_vault then + if settings.vault_domains then + if settings.vault_domains:get_key(dkim_domain) then + return true, { + domain = dkim_domain, + vault = true, + } + else + lua_util.debugm(N, task, 'domain %s is not designated for vault', + dkim_domain) + return false, {} + end + else + -- TODO: try every domain in the vault + return true, { + domain = dkim_domain, + vault = true, + } + end + end + + if settings.domain[dkim_domain] then + -- support old style selector/paths + if settings.domain[dkim_domain].selector or + settings.domain[dkim_domain].path then + local k = {} + k.selector = settings.domain[dkim_domain].selector + k.key = settings.domain[dkim_domain].path + table.insert(p, k) + end + for _, s in ipairs((settings.domain[dkim_domain].selectors or {})) do + lua_util.debugm(N, task, 'adding selector: %1', s) + local k = {} + k.selector = s.selector + k.key = s.path + table.insert(p, k) + end + end + + if #p == 0 then + local ret, k = get_mempool_selectors(N, task) + if ret then + table.insert(p, k) + lua_util.debugm(N, task, 'using mempool selector %s with key %s', + k.selector, k.key) + end + end + + if settings.selector_map then + local data = settings.selector_map:get_key(dkim_domain) + if data then + insert_or_update_prop(N, task, p, 'selector', 'selector_map', data) + else + lua_util.debugm(N, task, 'no selector in map for %s', dkim_domain) + end + end + + if settings.path_map then + local data = settings.path_map:get_key(dkim_domain) + if data then + insert_or_update_prop(N, task, p, 'key', 'path_map', data) + else + lua_util.debugm(N, task, 'no key in map for %s', dkim_domain) + end + end + + if #p == 0 and not settings.try_fallback then + lua_util.debugm(N, task, 'dkim unconfigured and fallback disabled') + return false, {} + end + + if not settings.use_redis then + insert_or_update_prop(N, task, p, 'key', + 'default path', settings.path) + end + + insert_or_update_prop(N, task, p, 'selector', + 'default selector', settings.selector) + + if settings.check_violation then + if not check_violation(N, task, p.domain) then + return false, {} + end + end + + insert_or_update_prop(N, task, p, 'domain', 'dkim_domain', + dkim_domain) + + return true, p +end + +exports.prepare_dkim_signing = prepare_dkim_signing + +exports.sign_using_redis = function(N, task, settings, selectors, sign_func, err_func) + local lua_redis = require "lua_redis" + + local function try_redis_key(selector, p) + p.key = nil + p.selector = selector + local rk = string.format('%s.%s', p.selector, p.domain) + local function redis_key_cb(err, data) + if err then + err_func(string.format("cannot make request to load DKIM key for %s: %s", + rk, 
err)) + elseif type(data) ~= 'string' then + lua_util.debugm(N, task, "missing DKIM key for %s", rk) + else + p.rawkey = data + lua_util.debugm(N, task, 'found and parsed key for %s:%s in Redis', + p.domain, p.selector) + sign_func(task, p) + end + end + local rret = lua_redis.redis_make_request(task, + settings.redis_params, -- connect params + rk, -- hash key + false, -- is write + redis_key_cb, --callback + 'HGET', -- command + { settings.key_prefix, rk } -- arguments + ) + if not rret then + err_func(task, + string.format("cannot make request to load DKIM key for %s", rk)) + end + end + + for _, p in ipairs(selectors) do + if settings.selector_prefix then + logger.infox(task, "using selector prefix '%s' for domain '%s'", + settings.selector_prefix, p.domain); + local function redis_selector_cb(err, data) + if err or type(data) ~= 'string' then + err_func(task, string.format("cannot make request to load DKIM selector for domain %s: %s", + p.domain, err)) + else + try_redis_key(data, p) + end + end + local rret = lua_redis.redis_make_request(task, + settings.redis_params, -- connect params + p.domain, -- hash key + false, -- is write + redis_selector_cb, --callback + 'HGET', -- command + { settings.selector_prefix, p.domain } -- arguments + ) + if not rret then + err_func(task, string.format("cannot make Redis request to load DKIM selector for domain %s", + p.domain)) + end + else + try_redis_key(p.selector, p) + end + end +end + +exports.sign_using_vault = function(N, task, settings, selectors, sign_func, err_func) + local http = require "rspamd_http" + local ucl = require "ucl" + + local full_url = string.format('%s/v1/%s/%s', + settings.vault_url, settings.vault_path or 'dkim', selectors.domain) + local upstream_list = lua_util.http_upstreams_by_url(rspamd_config:get_mempool(), settings.vault_url) + + local function vault_callback(err, code, body, _) + if code ~= 200 then + err_func(task, string.format('cannot request data from the vault url: %s; %s (%s)', + full_url, err, body)) + else + local parser = ucl.parser() + local res, parser_err = parser:parse_string(body) + if not res then + err_func(task, string.format('vault reply for %s (data=%s) cannot be parsed: %s', + full_url, body, parser_err)) + else + local obj = parser:get_object() + + if not obj or not obj.data then + err_func(task, string.format('vault reply for %s (data=%s) is invalid, no data', + full_url, body)) + else + local elts = obj.data.selectors or {} + + -- Filter selectors by time/sanity + local function is_selector_valid(p) + if not p.key or not p.selector then + return false + end + + if p.valid_start then + -- Check start time + if rspamd_util.get_time() < tonumber(p.valid_start) then + return false + end + end + + if p.valid_end then + if rspamd_util.get_time() >= tonumber(p.valid_end) then + return false + end + end + + return true + end + fun.each(function(p) + local dkim_sign_data = { + rawkey = p.key, + selector = p.selector, + domain = p.domain or selectors.domain, + alg = p.alg, + } + lua_util.debugm(N, task, 'found and parsed key for %s:%s in Vault', + dkim_sign_data.domain, dkim_sign_data.selector) + sign_func(task, dkim_sign_data) + end, fun.filter(is_selector_valid, elts)) + end + end + end + end + + local ret = http.request { + task = task, + url = full_url, + callback = vault_callback, + timeout = settings.http_timeout or 5.0, + no_ssl_verify = settings.no_ssl_verify, + keepalive = true, + upstream = upstream_list and upstream_list:get_upstream_round_robin() or nil, + headers = { + ['X-Vault-Token'] 
= settings.vault_token, + }, + } + + if not ret then + err_func(task, string.format("cannot make HTTP request to load DKIM data domain %s", + selectors.domain)) + end +end + +exports.validate_signing_settings = function(settings) + return settings.use_redis or + settings.path or + settings.domain or + settings.path_map or + settings.selector_map or + settings.use_http_headers or + (settings.signing_table and settings.key_table) or + (settings.use_vault and settings.vault_url and settings.vault_token) or + settings.sign_condition +end + +exports.process_signing_settings = function(N, settings, opts) + local lua_maps = require "lua_maps" + -- Used to convert plain options to the maps + local maps_opts = { + sign_networks = { 'radix', 'DKIM signing networks' }, + path_map = { 'map', 'Paths to DKIM signing keys' }, + selector_map = { 'map', 'DKIM selectors' }, + signing_table = { 'glob', 'DKIM signing table' }, + key_table = { 'glob', 'DKIM keys table' }, + vault_domains = { 'glob', 'DKIM signing domains in vault' }, + whitelisted_signers_map = { 'set', 'ARC trusted signers domains' } + } + for k, v in pairs(opts) do + local maybe_map = maps_opts[k] + if maybe_map then + settings[k] = lua_maps.map_add_from_ucl(v, maybe_map[1], maybe_map[2]) + elseif k == 'sign_condition' then + local ret, f = lua_util.callback_from_string(v) + if ret then + settings[k] = f + else + logger.errx(rspamd_config, 'cannot load sign condition %s: %s', v, f) + end + else + settings[k] = v + end + end +end + +return exports diff --git a/lualib/lua_ffi/common.lua b/lualib/lua_ffi/common.lua new file mode 100644 index 0000000..4076cfa --- /dev/null +++ b/lualib/lua_ffi/common.lua @@ -0,0 +1,45 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_ffi/common +-- Common ffi definitions +--]] + +local ffi = require 'ffi' + +ffi.cdef [[ +struct GString { + char *str; + size_t len; + size_t allocated_len; +}; +struct GArray { + char *data; + unsigned len; +}; +typedef void (*ref_dtor_cb_t)(void *data); +struct ref_entry_s { + unsigned int refcount; + ref_dtor_cb_t dtor; +}; + +void g_string_free (struct GString *st, int free_data); +void g_free (void *p); +long rspamd_snprintf (char *buf, long max, const char *fmt, ...); +]] + +return {}
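-- A minimal sketch of the ownership pattern these declarations support,
-- assuming a C call (resolved via the already-loaded rspamd/glib symbols)
-- that returns a `struct GString *`: copy the bytes into a Lua string, then
-- release the C allocation, mirroring how sibling lua_ffi modules use it.
local ffi = require 'ffi'

local function gstring_to_lua(gs)
  if gs == nil then
    -- NULL pointer
    return nil
  end
  local s = ffi.string(gs.str, gs.len) -- copy out of C memory
  ffi.C.g_string_free(gs, true)        -- free both the struct and its buffer
  return s
end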
\ No newline at end of file diff --git a/lualib/lua_ffi/dkim.lua b/lualib/lua_ffi/dkim.lua new file mode 100644 index 0000000..e4592c2 --- /dev/null +++ b/lualib/lua_ffi/dkim.lua @@ -0,0 +1,144 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_ffi/dkim +-- This module contains ffi interfaces to DKIM +--]] + +local ffi = require 'ffi' + +ffi.cdef [[ +struct rspamd_dkim_sign_context_s; +struct rspamd_dkim_key_s; +struct rspamd_task; +enum rspamd_dkim_key_format { + RSPAMD_DKIM_KEY_FILE = 0, + RSPAMD_DKIM_KEY_PEM, + RSPAMD_DKIM_KEY_BASE64, + RSPAMD_DKIM_KEY_RAW, +}; +enum rspamd_dkim_type { + RSPAMD_DKIM_NORMAL, + RSPAMD_DKIM_ARC_SIG, + RSPAMD_DKIM_ARC_SEAL +}; +struct rspamd_dkim_sign_context_s* +rspamd_create_dkim_sign_context (struct rspamd_task *task, + struct rspamd_dkim_key_s *priv_key, + int headers_canon, + int body_canon, + const char *dkim_headers, + enum rspamd_dkim_type type, + void *unused); +struct rspamd_dkim_key_s* rspamd_dkim_sign_key_load (const char *what, size_t len, + enum rspamd_dkim_key_format, + void *err); +void rspamd_dkim_key_unref (struct rspamd_dkim_key_s *k); +struct GString *rspamd_dkim_sign (struct rspamd_task *task, + const char *selector, + const char *domain, + unsigned long expire, + size_t len, + unsigned int idx, + const char *arc_cv, + struct rspamd_dkim_sign_context_s *ctx); +]] + +local function load_sign_key(what, format) + if not format then + format = ffi.C.RSPAMD_DKIM_KEY_PEM + else + if format == 'file' then + format = ffi.C.RSPAMD_DKIM_KEY_FILE + elseif format == 'base64' then + format = ffi.C.RSPAMD_DKIM_KEY_BASE64 + elseif format == 'raw' then + format = ffi.C.RSPAMD_DKIM_KEY_RAW + else + return nil, 'unknown key format' + end + end + + return ffi.C.rspamd_dkim_sign_key_load(what, #what, format, nil) +end + +local default_dkim_headers = "(o)from:(o)sender:(o)reply-to:(o)subject:(o)date:(o)message-id:" .. + "(o)to:(o)cc:(o)mime-version:(o)content-type:(o)content-transfer-encoding:" .. + "resent-to:resent-cc:resent-from:resent-sender:resent-message-id:" .. + "(o)in-reply-to:(o)references:list-id:list-owner:list-unsubscribe:" .. 
+ "list-subscribe:list-post:(o)openpgp:(o)autocrypt" + +local function create_sign_context(task, privkey, dkim_headers, sign_type) + if not task or not privkey then + return nil, 'invalid arguments' + end + + if not dkim_headers then + dkim_headers = default_dkim_headers + end + + if not sign_type then + sign_type = 'dkim' + end + + if sign_type == 'dkim' then + sign_type = ffi.C.RSPAMD_DKIM_NORMAL + elseif sign_type == 'arc-sig' then + sign_type = ffi.C.RSPAMD_DKIM_ARC_SIG + elseif sign_type == 'arc-seal' then + sign_type = ffi.C.RSPAMD_DKIM_ARC_SEAL + else + return nil, 'invalid sign type' + end + + return ffi.C.rspamd_create_dkim_sign_context(task:topointer(), privkey, + 1, 1, dkim_headers, sign_type, nil) +end + +local function do_sign(task, sign_context, selector, domain, + expire, len, arc_idx) + if not task or not sign_context or not selector or not domain then + return nil, 'invalid arguments' + end + + if not expire then + expire = 0 + end + if not len then + len = 0 + end + if not arc_idx then + arc_idx = 0 + end + + local gstring = ffi.C.rspamd_dkim_sign(task:topointer(), selector, domain, expire, len, arc_idx, nil, sign_context) + + if not gstring then + return nil, 'cannot sign' + end + + local ret = ffi.string(gstring.str, gstring.len) + ffi.C.g_string_free(gstring, true) + + return ret +end + +return { + load_sign_key = load_sign_key, + create_sign_context = create_sign_context, + do_sign = do_sign +} diff --git a/lualib/lua_ffi/init.lua b/lualib/lua_ffi/init.lua new file mode 100644 index 0000000..efbbc7a --- /dev/null +++ b/lualib/lua_ffi/init.lua @@ -0,0 +1,59 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_ffi +-- This module contains ffi interfaces (requires luajit or lua-ffi) +--]] + +local ffi + +local exports = {} + +if type(jit) == 'table' then + ffi = require "ffi" + local NULL = ffi.new 'void*' + + exports.is_null = function(o) + return o ~= NULL + end +else + local ret, result_or_err = pcall(require, 'ffi') + + if not ret then + return {} + end + + ffi = result_or_err + -- Lua ffi + local NULL = ffi.NULL or ffi.C.NULL + exports.is_null = function(o) + return o ~= NULL + end +end + +pcall(ffi.load, "rspamd-server", true) +exports.common = require "lua_ffi/common" +exports.dkim = require "lua_ffi/dkim" +exports.spf = require "lua_ffi/spf" +exports.linalg = require "lua_ffi/linalg" + +for k, v in pairs(ffi) do + -- Preserve all stuff to use lua_ffi as ffi itself + exports[k] = v +end + +return exports
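-- A minimal usage sketch, assuming it runs inside rspamd (or another process
-- where the "rspamd-server" symbols can be resolved) so the sub-modules
-- loaded above are functional. The key material shown is a placeholder.
local lua_ffi = require "lua_ffi"

if lua_ffi.dkim then
  -- the default key format is PEM when no format argument is given
  local pem = "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n"
  local key = lua_ffi.dkim.load_sign_key(pem)
  -- key is a struct rspamd_dkim_key_s * pointer (NULL cdata on failure)
end

-- The plain ffi API is re-exported as well, e.g. lua_ffi.cdef, lua_ffi.string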
\ No newline at end of file diff --git a/lualib/lua_ffi/linalg.lua b/lualib/lua_ffi/linalg.lua new file mode 100644 index 0000000..2df488a --- /dev/null +++ b/lualib/lua_ffi/linalg.lua @@ -0,0 +1,87 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_ffi/linalg +-- This module contains ffi interfaces to linear algebra routines +--]] + +local ffi = require 'ffi' + +local exports = {} + +ffi.cdef [[ + void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C); + bool kad_ssyev_simple (int N, float *A, float *output); +]] + +local function table_to_ffi(a, m, n) + local a_conv = ffi.new("float[?]", m * n) + for i = 1, m or #a do + for j = 1, n or #a[1] do + a_conv[(i - 1) * n + (j - 1)] = a[i][j] + end + end + return a_conv +end + +local function ffi_to_table(a, m, n) + local res = {} + + for i = 0, m - 1 do + res[i + 1] = {} + for j = 0, n - 1 do + res[i + 1][j + 1] = a[i * n + j] + end + end + + return res +end + +exports.sgemm = function(a, m, b, n, k, trans_a, trans_b) + if type(a) == 'table' then + -- Need to convert, slow! + a = table_to_ffi(a, m, k) + end + if type(b) == 'table' then + b = table_to_ffi(b, k, n) + end + local res = ffi.new("float[?]", m * n) + ffi.C.kad_sgemm_simple(trans_a or 0, trans_b or 0, m, n, k, ffi.cast('const float*', a), + ffi.cast('const float*', b), ffi.cast('float*', res)) + return res +end + +exports.eigen = function(a, n) + if type(a) == 'table' then + -- Need to convert, slow! + n = n or #a + a = table_to_ffi(a, n, n) + end + + local res = ffi.new("float[?]", n) + + if ffi.C.kad_ssyev_simple(n, ffi.cast('float*', a), res) then + return res, a + end + + return nil +end + +exports.ffi_to_table = ffi_to_table +exports.table_to_ffi = table_to_ffi + +return exports
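-- A minimal usage sketch, assuming the kad_* symbols (provided by the
-- rspamd-server library loaded in lua_ffi/init) are resolvable in-process.
-- Multiplies two 2x2 matrices given as plain Lua tables.
local linalg = require "lua_ffi/linalg"

local a = { { 1, 2 }, { 3, 4 } }
local b = { { 5, 6 }, { 7, 8 } }

-- C = A (2x2) x B (2x2); sgemm returns a flat float[m*n] array
local c = linalg.sgemm(a, 2, b, 2, 2)
local c_tbl = linalg.ffi_to_table(c, 2, 2) -- { { 19, 22 }, { 43, 50 } }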
\ No newline at end of file diff --git a/lualib/lua_ffi/spf.lua b/lualib/lua_ffi/spf.lua new file mode 100644 index 0000000..0f982f2 --- /dev/null +++ b/lualib/lua_ffi/spf.lua @@ -0,0 +1,143 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_ffi/spf +-- This module contains ffi interfaces to SPF +--]] + +local ffi = require 'ffi' + +ffi.cdef [[ +enum spf_mech_e { + SPF_FAIL, + SPF_SOFT_FAIL, + SPF_PASS, + SPF_NEUTRAL +}; +static const unsigned RSPAMD_SPF_FLAG_IPV6 = (1 << 0); +static const unsigned RSPAMD_SPF_FLAG_IPV4 = (1 << 1); +static const unsigned RSPAMD_SPF_FLAG_ANY = (1 << 3); +struct spf_addr { + unsigned char addr6[16]; + unsigned char addr4[4]; + union { + struct { + uint16_t mask_v4; + uint16_t mask_v6; + } dual; + uint32_t idx; + } m; + unsigned flags; + enum spf_mech_e mech; + char *spf_string; + struct spf_addr *prev, *next; +}; + +struct spf_resolved { + char *domain; + unsigned ttl; + int temp_failed; + int na; + int perm_failed; + uint64_t digest; + struct GArray *elts; + struct ref_entry_s ref; +}; + +typedef void (*spf_cb_t)(struct spf_resolved *record, + struct rspamd_task *task, void *data); +struct rspamd_task; +int rspamd_spf_resolve(struct rspamd_task *task, spf_cb_t callback, + void *cbdata); +const char * rspamd_spf_get_domain (struct rspamd_task *task); +struct spf_resolved * spf_record_ref (struct spf_resolved *rec); +void spf_record_unref (struct spf_resolved *rec); +char * spf_addr_mask_to_string (struct spf_addr *addr); +struct spf_addr * spf_addr_match_task (struct rspamd_task *task, struct spf_resolved *rec); +]] + +local function convert_mech(mech) + if mech == ffi.C.SPF_FAIL then + return 'fail' + elseif mech == ffi.C.SPF_SOFT_FAIL then + return 'softfail' + elseif mech == ffi.C.SPF_PASS then + return 'pass' + elseif mech == ffi.C.SPF_NEUTRAL then + return 'neutral' + end +end + +local NULL = ffi.new 'void*' + +local function spf_addr_tolua(ffi_spf_addr) + local ipstr = ffi.C.spf_addr_mask_to_string(ffi_spf_addr) + local ret = { + res = convert_mech(ffi_spf_addr.mech), + ipnet = ffi.string(ipstr), + } + + if ffi_spf_addr.spf_string ~= NULL then + ret.spf_str = ffi.string(ffi_spf_addr.spf_string) + end + + ffi.C.g_free(ipstr) + return ret +end + +local function spf_resolve(task, cb) + local function spf_cb(rec, _, _) + if not rec then + cb(false, 'record is empty') + else + local nelts = rec.elts.len + local elts = ffi.cast("struct spf_addr *", rec.elts.data) + local res = { + addrs = {} + } + local digstr = ffi.new("char[64]") + ffi.C.rspamd_snprintf(digstr, 64, "0x%xuL", rec.digest) + res.digest = ffi.string(digstr) + for i = 1, nelts do + res.addrs[i] = spf_addr_tolua(elts[i - 1]) + end + + local matched = ffi.C.spf_addr_match_task(task:topointer(), rec) + + if matched ~= NULL then + cb(true, res, spf_addr_tolua(matched)) + else + cb(true, res, nil) + end + end + end + + local ret = ffi.C.rspamd_spf_resolve(task:topointer(), spf_cb, nil) + + if not ret then + cb(false, 'cannot 
perform resolving') + end +end + +local function spf_unref(rec) + ffi.C.spf_record_unref(rec) +end + +return { + spf_resolve = spf_resolve, + spf_unref = spf_unref +}
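-- A minimal usage sketch, assuming it is called from a symbol callback so
-- `task` is a live rspamd task. The callback signature matches the one
-- wired up in spf_resolve above: (ok, result_or_error, matched_addr).
local spf_ffi = require "lua_ffi/spf"

local function spf_symbol_cb(task)
  spf_ffi.spf_resolve(task, function(ok, res, matched)
    if not ok then
      -- res is an error string here
      return
    end
    -- res.digest and res.addrs[i] ({ res = 'pass'|'fail'|'softfail'|'neutral',
    -- ipnet = ..., spf_str = ... }) describe the parsed record; `matched` is
    -- the address that matched the task sender, or nil if nothing matched
    if matched and matched.res == 'pass' then
      -- e.g. insert a symbol on the task in a real plugin
    end
  end)
end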
\ No newline at end of file diff --git a/lualib/lua_fuzzy.lua b/lualib/lua_fuzzy.lua new file mode 100644 index 0000000..986d1a0 --- /dev/null +++ b/lualib/lua_fuzzy.lua @@ -0,0 +1,355 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_fuzzy +-- This module contains helper functions for supporting fuzzy check module +--]] + + +local N = "lua_fuzzy" +local lua_util = require "lua_util" +local rspamd_regexp = require "rspamd_regexp" +local fun = require "fun" +local rspamd_logger = require "rspamd_logger" +local ts = require("tableshape").types + +-- Filled by C code, indexed by number in this table +local rules = {} + +-- Pre-defined rules options +local policies = { + recommended = { + min_bytes = 1024, + min_height = 500, + min_width = 500, + min_length = 64, + text_multiplier = 4.0, -- divide min_bytes by 4 for texts + mime_types = { "application/*" }, + scan_archives = true, + short_text_direct_hash = true, + text_shingles = true, + skip_images = false, + } +} + +local default_policy = policies.recommended + +local schema_fields = { + min_bytes = ts.number + ts.string / tonumber, + min_height = ts.number + ts.string / tonumber, + min_width = ts.number + ts.string / tonumber, + min_length = ts.number + ts.string / tonumber, + text_multiplier = ts.number, + mime_types = ts.array_of(ts.string), + scan_archives = ts.boolean, + short_text_direct_hash = ts.boolean, + text_shingles = ts.boolean, + skip_images = ts.boolean, +} +local policy_schema = ts.shape(schema_fields) + +local policy_schema_open = ts.shape(schema_fields, { + open = true, +}) + +local exports = {} + + +--[[[ +-- @function lua_fuzzy.register_policy(name, policy) +-- Adds a new policy with name `name`. Must be valid, checked using policy_schema +--]] +exports.register_policy = function(name, policy) + if policies[name] then + rspamd_logger.warnx(rspamd_config, "overriding policy %s", name) + end + + local parsed_policy, err = policy_schema:transform(policy) + + if not parsed_policy then + rspamd_logger.errx(rspamd_config, 'invalid fuzzy rule policy %s: %s', + name, err) + + return + else + policies.name = parsed_policy + end +end + +--[[[ +-- @function lua_fuzzy.process_rule(rule) +-- Processes fuzzy rule (applying policies or defaults if needed). 
Returns policy id +--]] +exports.process_rule = function(rule) + local processed_rule = lua_util.shallowcopy(rule) + local policy = default_policy + + if processed_rule.policy then + policy = policies[processed_rule.policy] + end + + if policy then + processed_rule = lua_util.override_defaults(policy, processed_rule) + + local parsed_policy, err = policy_schema_open:transform(processed_rule) + + if not parsed_policy then + rspamd_logger.errx(rspamd_config, 'invalid fuzzy rule default fields: %s', err) + else + processed_rule = parsed_policy + end + else + rspamd_logger.warnx(rspamd_config, "unknown policy %s", processed_rule.policy) + end + + if processed_rule.mime_types then + processed_rule.mime_types = fun.totable(fun.map(function(gl) + return rspamd_regexp.import_glob(gl, 'i') + end, processed_rule.mime_types)) + end + + table.insert(rules, processed_rule) + return #rules +end + +local function check_length(task, part, rule) + local bytes = part:get_length() + local length_ok = bytes > 0 + + local id = part:get_id() + lua_util.debugm(N, task, 'check size of part %s', id) + + if length_ok and rule.min_bytes > 0 then + + local adjusted_bytes = bytes + + if part:is_text() then + -- Fuzzy plugin uses stripped utf content to get an exact hash, that + -- corresponds to `get_content_oneline()` + -- However, in the case of empty parts this method returns `nil`, so extra + -- sanity check is required. + bytes = #(part:get_text():get_content_oneline() or '') + + -- Short hashing algorithm also use subject unless explicitly denied + if not rule.no_subject then + local subject = task:get_subject() or '' + bytes = bytes + #subject + end + + if rule.text_multiplier then + adjusted_bytes = bytes * rule.text_multiplier + end + end + + if rule.min_bytes > adjusted_bytes then + lua_util.debugm(N, task, 'skip part of length %s (%s adjusted) ' .. 
+ 'as it has less than %s bytes', + bytes, adjusted_bytes, rule.min_bytes) + length_ok = false + else + lua_util.debugm(N, task, 'allow part of length %s (%s adjusted)', + bytes, adjusted_bytes, rule.min_bytes) + end + else + lua_util.debugm(N, task, 'allow part %s, no length limits', id) + end + + return length_ok +end + +local function check_text_part(task, part, rule, text) + local allow_direct, allow_shingles = false, false + + local id = part:get_id() + lua_util.debugm(N, task, 'check text part %s', id) + local wcnt = text:get_words_count() + + if rule.text_shingles then + -- Check number of words + local min_words = rule.min_length or 0 + if min_words < 32 then + min_words = 32 -- Minimum for shingles + end + if wcnt < min_words then + lua_util.debugm(N, task, 'text has less than %s words: %s; disable shingles', + rule.min_length, wcnt) + allow_shingles = false + else + lua_util.debugm(N, task, 'allow shingles in text %s, %s words', + id, wcnt) + allow_shingles = true + end + + if not rule.short_text_direct_hash and not allow_shingles then + allow_direct = false + else + if not allow_shingles then + lua_util.debugm(N, task, + 'allow direct hash for short text %s, %s words', + id, wcnt) + allow_direct = check_length(task, part, rule) + else + allow_direct = wcnt > 0 + end + end + else + lua_util.debugm(N, task, + 'disable shingles in text %s', id) + allow_direct = check_length(task, part, rule) + end + + return allow_direct, allow_shingles +end + +--local function has_sane_text_parts(task) +-- local text_parts = task:get_text_parts() or {} +-- return fun.any(function(tp) return tp:get_words_count() > 32 end, text_parts) +--end + +local function check_image_part(task, part, rule, image) + if rule.skip_images then + lua_util.debugm(N, task, 'skip image part as images are disabled') + return false, false + end + + local id = part:get_id() + lua_util.debugm(N, task, 'check image part %s', id) + + if rule.min_width > 0 or rule.min_height > 0 then + -- Check dimensions + local min_width = rule.min_width or rule.min_height + local min_height = rule.min_height or rule.min_width + local height = image:get_height() + local width = image:get_width() + + if height and width then + if height < min_height or width < min_width then + lua_util.debugm(N, task, 'skip image part %s as it does not meet minimum sizes: %sx%s < %sx%s', + id, width, height, min_width, min_height) + return false, false + else + lua_util.debugm(N, task, 'allow image part %s: %sx%s', + id, width, height) + end + end + end + + return check_length(task, part, rule), false +end + +local function mime_types_check(task, part, rule) + local t, st = part:get_type() + + if not t then + return false, false + end + + local ct = string.format('%s/%s', t, st) + + local detected_ct + t, st = part:get_detected_type() + if t then + detected_ct = string.format('%s/%s', t, st) + else + detected_ct = ct + end + + local id = part:get_id() + lua_util.debugm(N, task, 'check binary part %s: %s', id, ct) + + -- For bad mime parts we implicitly enable fuzzy check + local mime_trace = (task:get_symbol('MIME_TRACE') or {})[1] + local opts = {} + + if mime_trace then + opts = mime_trace.options or opts + end + opts = fun.tomap(fun.map(function(opt) + local elts = lua_util.str_split(opt, ':') + return elts[1], elts[2] + end, opts)) + + if opts[id] and opts[id] == '-' then + lua_util.debugm(N, task, 'explicitly check binary part %s: bad mime type %s', id, ct) + return check_length(task, part, rule), false + end + + if rule.mime_types then + + if 
fun.any(function(gl_re) + if gl_re:match(ct) or (detected_ct and gl_re:match(detected_ct)) then + return true + else + return false + end + end, rule.mime_types) then + lua_util.debugm(N, task, 'found mime type match for part %s: %s (%s detected)', + id, ct, detected_ct) + return check_length(task, part, rule), false + end + + return false, false + end + + return false, false +end + +exports.check_mime_part = function(task, part, rule_id) + local rule = rules[rule_id] + + if not rule then + rspamd_logger.errx(task, 'cannot find rule with id %s', rule_id) + + return false, false + end + + if part:is_text() then + return check_text_part(task, part, rule, part:get_text()) + end + + if part:is_image() then + return check_image_part(task, part, rule, part:get_image()) + end + + if part:is_archive() and rule.scan_archives then + -- Always send archives + lua_util.debugm(N, task, 'check archive part %s', part:get_id()) + + return true, false + end + + if part:is_specific() then + local sp = part:get_specific() + + if type(sp) == 'table' and sp.fuzzy_hashes then + lua_util.debugm(N, task, 'check specific part %s', part:get_id()) + return true, false + end + end + + if part:is_attachment() then + return mime_types_check(task, part, rule) + end + + return false, false +end + +exports.cleanup_rules = function() + rules = {} +end + +return exports diff --git a/lualib/lua_lexer.lua b/lualib/lua_lexer.lua new file mode 100644 index 0000000..54bbd7c --- /dev/null +++ b/lualib/lua_lexer.lua @@ -0,0 +1,163 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[ Lua LPEG grammar based on https://github.com/xolox/lua-lxsh/ ]] + + +local lpeg = require "lpeg" + +local P = lpeg.P +local R = lpeg.R +local S = lpeg.S +local D = R '09' -- Digits +local I = R('AZ', 'az', '\127\255') + '_' -- Identifiers +local B = -(I + D) -- Word boundary +local EOS = -lpeg.P(1) -- end of string + +-- Pattern for long strings and long comments. +local longstring = #(P '[[' + (P '[' * P '=' ^ 0 * '[')) * P(function(input, index) + local level = input:match('^%[(=*)%[', index) + if level then + local _, last = input:find(']' .. level .. ']', index, true) + if last then + return last + 1 + end + end +end) + +-- String literals. +local singlequoted = P "'" * ((1 - S "'\r\n\f\\") + (P '\\' * 1)) ^ 0 * "'" +local doublequoted = P '"' * ((1 - S '"\r\n\f\\') + (P '\\' * 1)) ^ 0 * '"' + +-- Comments. +local eol = P '\r\n' + '\n' +local line = (1 - S '\r\n\f') ^ 0 * eol ^ -1 +local singleline = P '--' * line +local multiline = P '--' * longstring + +-- Numbers. +local sign = S '+-' ^ -1 +local decimal = D ^ 1 +local hexadecimal = P '0' * S 'xX' * R('09', 'AF', 'af') ^ 1 +local float = D ^ 1 * P '.' * D ^ 0 + P '.' 
* D ^ 1 +local maybeexp = (float + decimal) * (S 'eE' * sign * D ^ 1) ^ -1 + +local function compile_keywords(keywords) + local list = {} + for word in keywords:gmatch('%S+') do + list[#list + 1] = word + end + -- Sort by length + table.sort(list, function(a, b) + return #a > #b + end) + + local pattern + for _, word in ipairs(list) do + local p = lpeg.P(word) + pattern = pattern and (pattern + p) or p + end + + local AB = B + EOS -- ending boundary + return pattern * AB +end + +-- Identifiers +local ident = I * (I + D) ^ 0 +local expr = ('.' * ident) ^ 0 + +local patterns = { + { 'whitespace', S '\r\n\f\t\v ' ^ 1 }, + { 'constant', (P 'true' + 'false' + 'nil') * B }, + { 'string', singlequoted + doublequoted + longstring }, + { 'comment', multiline + singleline }, + { 'number', hexadecimal + maybeexp }, + { 'operator', P 'not' + '...' + 'and' + '..' + '~=' + '==' + '>=' + '<=' + + 'or' + S ']{=>^[<;)*(%}+-:,/.#' }, + { 'keyword', compile_keywords([[ + break do else elseif end for function if in local repeat return then until while + ]]) }, + { 'identifier', lpeg.Cmt(ident, + function(input, index) + return expr:match(input, index) + end) + }, + { 'error', 1 }, +} + +local compiled + +local function compile_patterns() + if not compiled then + local function process(elt) + local n, grammar = elt[1], elt[2] + return lpeg.Cc(n) * lpeg.P(grammar) * lpeg.Cp() + end + local any = process(patterns[1]) + for i = 2, #patterns do + any = any + process(patterns[i]) + end + compiled = any + end + + return compiled +end + +local function sync(token, lnum, cnum) + local lastidx + lnum, cnum = lnum or 1, cnum or 1 + if token:find '\n' then + for i in token:gmatch '()\n' do + lnum = lnum + 1 + lastidx = i + end + cnum = #token - lastidx + 1 + else + cnum = cnum + #token + end + return lnum, cnum +end + +local exports = {} + +exports.gmatch = function(input) + local parser = compile_patterns() + local index, lnum, cnum = 1, 1, 1 + + return function() + local kind, after = parser:match(input, index) + if kind and after then + local text = input:sub(index, after - 1) + local oldlnum, oldcnum = lnum, cnum + index = after + lnum, cnum = sync(text, lnum, cnum) + return kind, text, oldlnum, oldcnum + end + end +end + +exports.lex_to_table = function(input) + local out = {} + + for kind, text, lnum, cnum in exports.gmatch(input) do + out[#out + 1] = { kind, text, lnum, cnum } + end + + return out +end + +return exports + diff --git a/lualib/lua_magic/heuristics.lua b/lualib/lua_magic/heuristics.lua new file mode 100644 index 0000000..b8a1b41 --- /dev/null +++ b/lualib/lua_magic/heuristics.lua @@ -0,0 +1,605 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +--[[[ +-- @module lua_magic/heuristics +-- This module contains heuristics for some specific cases +--]] + +local rspamd_trie = require "rspamd_trie" +local rspamd_util = require "rspamd_util" +local lua_util = require "lua_util" +local bit = require "bit" +local fun = require "fun" + +local N = "lua_magic" +local msoffice_trie +local msoffice_patterns = { + doc = { [[WordDocument]] }, + xls = { [[Workbook]], [[Book]] }, + ppt = { [[PowerPoint Document]], [[Current User]] }, + vsd = { [[VisioDocument]] }, +} +local msoffice_trie_clsid +local msoffice_clsids = { + doc = { [[0609020000000000c000000000000046]] }, + xls = { [[1008020000000000c000000000000046]], [[2008020000000000c000000000000046]] }, + ppt = { [[108d81649b4fcf1186ea00aa00b929e8]] }, + msg = { [[46f0060000000000c000000000000046]], [[0b0d020000000000c000000000000046]] }, + msi = { [[84100c0000000000c000000000000046]] }, +} +local zip_trie +local zip_patterns = { + -- https://lists.oasis-open.org/archives/office/200505/msg00006.html + odt = { + [[mimetypeapplication/vnd\.oasis\.opendocument\.text]], + [[mimetypeapplication/vnd\.oasis\.opendocument\.image]], + [[mimetypeapplication/vnd\.oasis\.opendocument\.graphic]] + }, + ods = { + [[mimetypeapplication/vnd\.oasis\.opendocument\.spreadsheet]], + [[mimetypeapplication/vnd\.oasis\.opendocument\.formula]], + [[mimetypeapplication/vnd\.oasis\.opendocument\.chart]] + }, + odp = { [[mimetypeapplication/vnd\.oasis\.opendocument\.presentation]] }, + epub = { [[epub\+zip]] }, + asice = { [[mimetypeapplication/vnd\.etsi\.asic-e\+zipPK]] }, + asics = { [[mimetypeapplication/vnd\.etsi\.asic-s\+zipPK]] }, +} + +local txt_trie +local txt_patterns = { + html = { + { [=[(?i)<html[\s>]]=], 32 }, + { [[(?i)<script\b]], 20 }, -- Commonly used by spammers + { [[<script\s+type="text\/javascript">]], 31 }, -- Another spammy pattern + { [[(?i)<\!DOCTYPE HTML\b]], 33 }, + { [[(?i)<body\b]], 20 }, + { [[(?i)<table\b]], 20 }, + { [[(?i)<a\s]], 10 }, + { [[(?i)<p\b]], 10 }, + { [[(?i)<div\b]], 10 }, + { [[(?i)<span\b]], 10 }, + }, + csv = { + { [[(?:[-a-zA-Z0-9_]+\s*,){2,}(?:[-a-zA-Z0-9_]+,?[ ]*[\r\n])]], 20 } + }, + ics = { + { [[^BEGIN:VCALENDAR\r?\n]], 40 }, + }, + vcf = { + { [[^BEGIN:VCARD\r?\n]], 40 }, + }, + xml = { + { [[<\?xml\b.+\?>]], 31 }, + } +} + +-- Used to match pattern index and extension +local msoffice_clsid_indexes = {} +local msoffice_patterns_indexes = {} +local zip_patterns_indexes = {} +local txt_patterns_indexes = {} + +local exports = {} + +local function compile_tries() + local default_compile_flags = bit.bor(rspamd_trie.flags.re, + rspamd_trie.flags.dot_all, + rspamd_trie.flags.single_match, + rspamd_trie.flags.no_start) + local function compile_pats(patterns, indexes, transform_func, compile_flags) + local strs = {} + for ext, pats in pairs(patterns) do + for _, pat in ipairs(pats) do + -- These are utf16 strings in fact... + strs[#strs + 1] = transform_func(pat) + indexes[#indexes + 1] = { ext, pat } + end + end + + return rspamd_trie.create(strs, compile_flags or default_compile_flags) + end + + if not msoffice_trie then + -- Directory names + local function msoffice_pattern_transform(pat) + return '^' .. + table.concat( + fun.totable( + fun.map(function(c) + return c .. [[\x{00}]] + end, + fun.iter(pat)))) + end + local function msoffice_clsid_transform(pat) + local hex_table = {} + for i = 1, #pat, 2 do + local subc = pat:sub(i, i + 1) + hex_table[#hex_table + 1] = string.format('\\x{%s}', subc) + end + + return '^' .. table.concat(hex_table) .. 
'$' + end + -- Directory entries + msoffice_trie = compile_pats(msoffice_patterns, msoffice_patterns_indexes, + msoffice_pattern_transform) + -- Clsids + msoffice_trie_clsid = compile_pats(msoffice_clsids, msoffice_clsid_indexes, + msoffice_clsid_transform) + -- Misc zip patterns at the initial fragment + zip_trie = compile_pats(zip_patterns, zip_patterns_indexes, + function(pat) + return pat + end) + -- Text patterns at the initial fragment + txt_trie = compile_pats(txt_patterns, txt_patterns_indexes, + function(pat_tbl) + return pat_tbl[1] + end, + bit.bor(rspamd_trie.flags.re, + rspamd_trie.flags.dot_all, + rspamd_trie.flags.no_start)) + end +end + +-- Call immediately on require +compile_tries() + +local function detect_ole_format(input, log_obj, _, part) + local inplen = #input + if inplen < 0x31 + 4 then + lua_util.debugm(N, log_obj, "short length: %s", inplen) + return nil + end + + local bom, sec_size = rspamd_util.unpack('<I2<I2', input:span(29, 4)) + if bom == 0xFFFE then + bom = '<' + else + lua_util.debugm(N, log_obj, "bom file!: %s", bom) + bom = '>'; + sec_size = bit.bswap(sec_size) + end + + if sec_size < 7 or sec_size > 31 then + lua_util.debugm(N, log_obj, "bad sec_size: %s", sec_size) + return nil + end + + sec_size = 2 ^ sec_size + + -- SecID of first sector of the directory stream + local directory_offset = (rspamd_util.unpack(bom .. 'I4', input:span(0x31, 4))) + * sec_size + 512 + 1 + lua_util.debugm(N, log_obj, "directory: %s", directory_offset) + + if inplen < directory_offset then + lua_util.debugm(N, log_obj, "short length: %s", inplen) + return nil + end + + local function process_dir_entry(offset) + local dtype = input:byte(offset + 66) + lua_util.debugm(N, log_obj, "dtype: %s, offset: %s", dtype, offset) + + if dtype then + if dtype == 5 then + -- Extract clsid + local matches = msoffice_trie_clsid:match(input:span(offset + 80, 16)) + if matches then + for n, _ in pairs(matches) do + if msoffice_clsid_indexes[n] then + lua_util.debugm(N, log_obj, "found valid clsid for %s", + msoffice_clsid_indexes[n][1]) + return true, msoffice_clsid_indexes[n][1] + end + end + end + return true, nil + elseif dtype == 2 then + local matches = msoffice_trie:match(input:span(offset, 64)) + if matches then + for n, _ in pairs(matches) do + if msoffice_patterns_indexes[n] then + return true, msoffice_patterns_indexes[n][1] + end + end + end + return true, nil + elseif dtype >= 0 and dtype < 5 then + -- Bad type + return true, nil + end + end + + return false, nil + end + + repeat + local res, ext = process_dir_entry(directory_offset) + + if res and ext then + return ext, 60 + end + + if not res then + break + end + + directory_offset = directory_offset + 128 + until directory_offset >= inplen +end + +exports.ole_format_heuristic = detect_ole_format + +local function process_top_detected(res) + local extensions = lua_util.keys(res) + + if #extensions > 0 then + table.sort(extensions, function(ex1, ex2) + return res[ex1] > res[ex2] + end) + + return extensions[1], res[extensions[1]] + end + + return nil +end + +local function detect_archive_flaw(part, arch, log_obj, _) + local arch_type = arch:get_type() + local res = { + docx = 0, + xlsx = 0, + pptx = 0, + jar = 0, + odt = 0, + odp = 0, + ods = 0, + apk = 0, + } -- ext + confidence pairs + + -- General msoffice patterns + local function add_msoffice_confidence(incr) + res.docx = res.docx + incr + res.xlsx = res.xlsx + incr + res.pptx = res.pptx + incr + end + + if arch_type == 'zip' then + -- Find specific files/folders in zip file 
+ local files = arch:get_files(100) or {} + for _, file in ipairs(files) do + if file == '[Content_Types].xml' then + add_msoffice_confidence(10) + elseif file:sub(1, 3) == 'xl/' then + res.xlsx = res.xlsx + 30 + elseif file:sub(1, 5) == 'word/' then + res.docx = res.docx + 30 + elseif file:sub(1, 4) == 'ppt/' then + res.pptx = res.pptx + 30 + elseif file == 'META-INF/MANIFEST.MF' then + res.jar = res.jar + 40 + elseif file == 'AndroidManifest.xml' then + res.apk = res.apk + 60 + end + end + + local ext, weight = process_top_detected(res) + + if weight >= 40 then + return ext, weight + end + + -- Apply misc Zip detection logic + local content = part:get_content() + + if #content > 128 then + local start_span = content:span(1, 128) + + local matches = zip_trie:match(start_span) + if matches then + for n, _ in pairs(matches) do + if zip_patterns_indexes[n] then + lua_util.debugm(N, log_obj, "found zip pattern for %s", + zip_patterns_indexes[n][1]) + return zip_patterns_indexes[n][1], 40 + end + end + end + end + end + + return arch_type:lower(), 40 +end + +local csv_grammar +-- Returns a grammar that will count commas +local function get_csv_grammar() + if not csv_grammar then + local lpeg = require 'lpeg' + + local field = '"' * lpeg.Cs(((lpeg.P(1) - '"') + lpeg.P '""' / '"') ^ 0) * '"' + + lpeg.C((1 - lpeg.S ',\n"') ^ 0) + + csv_grammar = lpeg.Cf(lpeg.Cc(0) * field * lpeg.P((lpeg.P(',') + + lpeg.P('\t')) * field) ^ 1 * (lpeg.S '\r\n' + -1), + function(acc) + return acc + 1 + end) + end + + return csv_grammar +end +local function validate_csv(part, content, log_obj) + local max_chunk = 32768 + local chunk = content:sub(1, max_chunk) + + local expected_commas + local matched_lines = 0 + local max_matched_lines = 10 + + lua_util.debugm(N, log_obj, "check for csv pattern") + + for s in chunk:lines() do + local ncommas = get_csv_grammar():match(s) + + if not ncommas then + lua_util.debugm(N, log_obj, "not a csv line at line number %s", + matched_lines) + return false + end + + if expected_commas and ncommas ~= expected_commas then + -- Mismatched commas + lua_util.debugm(N, log_obj, "missmatched commas on line %s: %s != %s", + matched_lines, ncommas, expected_commas) + return false + elseif not expected_commas then + if ncommas == 0 then + lua_util.debugm(N, log_obj, "no commas in the first line") + return false + end + expected_commas = ncommas + end + + matched_lines = matched_lines + 1 + + if matched_lines > max_matched_lines then + break + end + end + + lua_util.debugm(N, log_obj, "csv content is sane: %s fields; %s lines checked", + expected_commas, matched_lines) + + return true +end + +exports.mime_part_heuristic = function(part, log_obj, _) + if part:is_archive() then + local arch = part:get_archive() + return detect_archive_flaw(part, arch, log_obj) + end + + return nil +end + +exports.text_part_heuristic = function(part, log_obj, _) + -- We get some span of data and check it + local function is_span_text(span) + -- We examine 8 bit content, and we assume it might be localized text + -- if it has more than 3 subsequent 8 bit characters + local function rough_8bit_check(bytes, idx, remain, len) + local b = bytes[idx] + local n8bit = 0 + + while b >= 127 and idx < len do + -- utf8 part + if bit.band(b, 0xe0) == 0xc0 and remain > 1 and + bit.band(bytes[idx + 1], 0xc0) == 0x80 then + return true, 1 + elseif bit.band(b, 0xf0) == 0xe0 and remain > 2 and + bit.band(bytes[idx + 1], 0xc0) == 0x80 and + bit.band(bytes[idx + 2], 0xc0) == 0x80 then + return true, 2 + elseif bit.band(b, 0xf8) == 
0xf0 and remain > 3 and + bit.band(bytes[idx + 1], 0xc0) == 0x80 and + bit.band(bytes[idx + 2], 0xc0) == 0x80 and + bit.band(bytes[idx + 3], 0xc0) == 0x80 then + return true, 3 + end + + n8bit = n8bit + 1 + idx = idx + 1 + b = bytes[idx] + remain = remain - 1 + end + + if n8bit >= 3 then + return true, n8bit + end + + return false, 0 + end + + -- Convert to string as LuaJIT can optimise string.sub (and fun.iter) but not C calls + local tlen = #span + local non_printable = 0 + local bytes = span:bytes() + local i = 1 + repeat + local b = bytes[i] + + if (b < 0x20) and not (b == 0x0d or b == 0x0a or b == 0x09) then + non_printable = non_printable + 1 + elseif b >= 127 then + local c, nskip = rough_8bit_check(bytes, i, tlen - i, tlen) + + if not c then + non_printable = non_printable + 1 + else + i = i + nskip + end + end + i = i + 1 + until i > tlen + + lua_util.debugm(N, log_obj, "text part check: %s printable, %s non-printable, %s total", + tlen - non_printable, non_printable, tlen) + if non_printable / tlen > 0.0078125 then + return false + end + + return true + end + + local parent = part:get_parent() + + if parent then + local parent_type, parent_subtype = parent:get_type() + + if parent_type == 'multipart' and parent_subtype == 'encrypted' then + -- Skip text heuristics for encrypted parts + lua_util.debugm(N, log_obj, "text part check: parent is encrypted, not a text part") + + return false + end + end + + local content = part:get_content() + local mtype, msubtype = part:get_type() + local clen = #content + local is_text + + if clen > 0 then + if clen > 80 * 3 then + -- Use chunks + is_text = is_span_text(content:span(1, 160)) and is_span_text(content:span(clen - 80, 80)) + else + is_text = is_span_text(content) + end + + if is_text and mtype ~= 'message' then + -- Try patterns + local span_len = math.min(4096, clen) + local start_span = content:span(1, span_len) + local matches = txt_trie:match(start_span) + local res = {} + local fname = part:get_filename() + + if matches then + -- Require at least 2 occurrences of those patterns + for n, positions in pairs(matches) do + local ext, weight = txt_patterns_indexes[n][1], txt_patterns_indexes[n][2][2] + if ext then + res[ext] = (res[ext] or 0) + weight * #positions + lua_util.debugm(N, log_obj, "found txt pattern for %s: %s, total: %s; %s/%s announced", + ext, weight * #positions, res[ext], mtype, msubtype) + end + end + + if res.html and res.html >= 40 then + -- HTML has priority over something like js... + return 'html', res.html + end + + local ext, weight = process_top_detected(res) + + if weight then + if weight >= 40 then + -- Extra validation for csv extension + if ext ~= 'csv' or validate_csv(part, content, log_obj) then + return ext, weight + end + elseif fname and weight >= 20 then + return ext, weight + end + end + end + + -- Content type stuff + if (mtype == 'text' or mtype == 'application') and + (msubtype == 'html' or msubtype == 'xhtml+xml') then + return 'html', 21 + end + + if msubtype:lower() == 'csv' then + if validate_csv(part, content, log_obj) then + return 'csv', 40 + end + end + + -- Extension stuff + local function has_extension(file, ext) + local ext_len = ext:len() + return file:len() > ext_len + 1 + and file:sub(-ext_len):lower() == ext + and file:sub(-ext_len - 1, -ext_len - 1) == '.' 
+ end + + if fname and (has_extension(fname, 'htm') or has_extension(fname, 'html')) then + return 'html', 21 + end + + if mtype ~= 'text' then + -- Do not treat non text patterns as text + return nil + end + + return 'txt', 40 + end + end +end + +exports.pdf_format_heuristic = function(input, log_obj, pos, part) + local weight = 10 + local ext = string.match(part:get_filename() or '', '%.([^.]+)$') + -- If we found a pattern at the beginning + if pos <= 10 then + weight = weight + 30 + end + -- If the announced extension is `pdf` + if ext and ext:lower() == 'pdf' then + weight = weight + 30 + end + + return 'pdf', weight +end + +exports.pe_part_heuristic = function(input, log_obj, pos, part) + if not input then + return + end + + -- pe header should start at the offset that is placed in msdos header at position 60..64 + local pe_ptr_bin = input:sub(60, 64) + if #pe_ptr_bin ~= 4 then + return + end + + -- it is an LE 32 bit integer + local pe_ptr = rspamd_util.unpack("<I4", pe_ptr_bin) + -- if pe header magic matches the offset, it is definitely a PE file + if pe_ptr ~= pos then + return + end + + return 'exe', 30 +end + +return exports diff --git a/lualib/lua_magic/init.lua b/lualib/lua_magic/init.lua new file mode 100644 index 0000000..38bfddb --- /dev/null +++ b/lualib/lua_magic/init.lua @@ -0,0 +1,388 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +--[[[ +-- @module lua_magic +-- This module contains file types detection logic +--]] + +local patterns = require "lua_magic/patterns" +local types = require "lua_magic/types" +local heuristics = require "lua_magic/heuristics" +local fun = require "fun" +local lua_util = require "lua_util" + +local rspamd_text = require "rspamd_text" +local rspamd_trie = require "rspamd_trie" + +local N = "lua_magic" +local exports = {} +-- trie objects +local compiled_patterns +local compiled_short_patterns +local compiled_tail_patterns +-- {<str>, <match_object>, <pattern_object>} indexed by pattern number +local processed_patterns = {} +local short_patterns = {} +local tail_patterns = {} + +local short_match_limit = 128 +local max_short_offset = -1 +local min_tail_offset = math.huge + +local function process_patterns(log_obj) + -- Add pattern to either short patterns or to normal patterns + local function add_processed(str, match, pattern) + if match.position and type(match.position) == 'number' then + if match.tail then + -- Tail pattern + tail_patterns[#tail_patterns + 1] = { + str, match, pattern + } + if min_tail_offset > match.tail then + min_tail_offset = match.tail + end + + lua_util.debugm(N, log_obj, 'add tail pattern %s for ext %s', + str, pattern.ext) + elseif match.position < short_match_limit then + short_patterns[#short_patterns + 1] = { + str, match, pattern + } + if str:sub(1, 1) == '^' then + lua_util.debugm(N, log_obj, 'add head pattern %s for ext %s', + str, pattern.ext) + else + lua_util.debugm(N, log_obj, 'add short pattern %s for ext %s', + str, pattern.ext) + end + + if max_short_offset < match.position then + max_short_offset = match.position + end + else + processed_patterns[#processed_patterns + 1] = { + str, match, pattern + } + + lua_util.debugm(N, log_obj, 'add long pattern %s for ext %s', + str, pattern.ext) + end + else + processed_patterns[#processed_patterns + 1] = { + str, match, pattern + } + + lua_util.debugm(N, log_obj, 'add long pattern %s for ext %s', + str, pattern.ext) + end + end + + if not compiled_patterns then + for ext, pattern in pairs(patterns) do + assert(types[ext], 'not found type: ' .. ext) + pattern.ext = ext + for _, match in ipairs(pattern.matches) do + if match.string then + if match.relative_position and not match.position then + match.position = match.relative_position + #match.string + + if match.relative_position == 0 then + if match.string:sub(1, 1) ~= '^' then + match.string = '^' .. 
match.string + end + end + end + add_processed(match.string, match, pattern) + elseif match.hex then + local hex_table = {} + + for i = 1, #match.hex, 2 do + local subc = match.hex:sub(i, i + 1) + hex_table[#hex_table + 1] = string.format('\\x{%s}', subc) + end + + if match.relative_position and not match.position then + match.position = match.relative_position + #match.hex / 2 + end + if match.relative_position == 0 then + table.insert(hex_table, 1, '^') + end + add_processed(table.concat(hex_table), match, pattern) + end + end + end + local bit = require "bit" + local compile_flags = bit.bor(rspamd_trie.flags.re, rspamd_trie.flags.dot_all) + compile_flags = bit.bor(compile_flags, rspamd_trie.flags.single_match) + compile_flags = bit.bor(compile_flags, rspamd_trie.flags.no_start) + compiled_patterns = rspamd_trie.create(fun.totable( + fun.map(function(t) + return t[1] + end, processed_patterns)), + compile_flags + ) + compiled_short_patterns = rspamd_trie.create(fun.totable( + fun.map(function(t) + return t[1] + end, short_patterns)), + compile_flags + ) + compiled_tail_patterns = rspamd_trie.create(fun.totable( + fun.map(function(t) + return t[1] + end, tail_patterns)), + compile_flags + ) + + lua_util.debugm(N, log_obj, + 'compiled %s (%s short; %s long; %s tail) patterns', + #processed_patterns + #short_patterns + #tail_patterns, + #short_patterns, #processed_patterns, #tail_patterns) + end +end + +process_patterns(rspamd_config) + +local function match_chunk(chunk, input, tlen, offset, trie, processed_tbl, log_obj, res, part) + local matches = trie:match(chunk) + + local last = tlen + + local function add_result(weight, ext) + if not res[ext] then + res[ext] = 0 + end + if weight then + res[ext] = res[ext] + weight + else + res[ext] = res[ext] + 1 + end + + lua_util.debugm(N, log_obj, 'add pattern for %s, weight %s, total weight %s', + ext, weight, res[ext]) + end + + local function match_position(pos, expected) + local cmp = function(a, b) + return a == b + end + if type(expected) == 'table' then + -- Something like {'>', 0} + if expected[1] == '>' then + cmp = function(a, b) + return a > b + end + elseif expected[1] == '>=' then + cmp = function(a, b) + return a >= b + end + elseif expected[1] == '<' then + cmp = function(a, b) + return a < b + end + elseif expected[1] == '<=' then + cmp = function(a, b) + return a <= b + end + elseif expected[1] == '!=' then + cmp = function(a, b) + return a ~= b + end + end + expected = expected[2] + end + + -- Tail match + if expected < 0 then + expected = last + expected + 1 + end + return cmp(pos, expected) + end + + for npat, matched_positions in pairs(matches) do + local pat_data = processed_tbl[npat] + local pattern = pat_data[3] + local match = pat_data[2] + + -- Single position + if match.position then + local position = match.position + + for _, pos in ipairs(matched_positions) do + lua_util.debugm(N, log_obj, 'found match %s at offset %s(from %s)', + pattern.ext, pos, offset) + if match_position(pos + offset, position) then + if match.heuristic then + local ext, weight = match.heuristic(input, log_obj, pos + offset, part) + + if ext then + add_result(weight, ext) + break + end + else + add_result(match.weight, pattern.ext) + break + end + end + end + elseif match.positions then + -- Match all positions + local all_right = true + local matched_pos = 0 + for _, position in ipairs(match.positions) do + local matched = false + for _, pos in ipairs(matched_positions) do + lua_util.debugm(N, log_obj, 'found match %s at offset %s(from %s)', + 
pattern.ext, pos, offset) + if not match_position(pos + offset, position) then + matched = true + matched_pos = pos + break + end + end + if not matched then + all_right = false + break + end + end + + if all_right then + if match.heuristic then + local ext, weight = match.heuristic(input, log_obj, matched_pos + offset, part) + + if ext then + add_result(weight, ext) + break + end + else + add_result(match.weight, pattern.ext) + break + end + end + end + end + +end + +local function process_detected(res) + local extensions = lua_util.keys(res) + + if #extensions > 0 then + table.sort(extensions, function(ex1, ex2) + return res[ex1] > res[ex2] + end) + + return extensions, res[extensions[1]] + end + + return nil +end + +exports.detect = function(part, log_obj) + if not log_obj then + log_obj = rspamd_config + end + local input = part:get_content() + + local res = {} + + if type(input) == 'string' then + -- Convert to rspamd_text + input = rspamd_text.fromstring(input) + end + + if type(input) == 'userdata' then + local inplen = #input + + -- Check tail matches + if inplen > min_tail_offset then + local tail = input:span(inplen - min_tail_offset, min_tail_offset) + match_chunk(tail, input, inplen, inplen - min_tail_offset, + compiled_tail_patterns, tail_patterns, log_obj, res, part) + end + + -- Try short match + local head = input:span(1, math.min(max_short_offset, inplen)) + match_chunk(head, input, inplen, 0, + compiled_short_patterns, short_patterns, log_obj, res, part) + + -- Check if we have enough data or go to long patterns + local extensions, confidence = process_detected(res) + + if extensions and #extensions > 0 and confidence > 30 then + -- We are done on short patterns + return extensions[1], types[extensions[1]] + end + + -- No way, let's check data in chunks or just the whole input if it is small enough + if #input > exports.chunk_size * 3 then + -- Chunked version as input is too long + local chunk1, chunk2 = input:span(1, exports.chunk_size * 2), + input:span(inplen - exports.chunk_size, exports.chunk_size) + local offset1, offset2 = 0, inplen - exports.chunk_size + + match_chunk(chunk1, input, inplen, + offset1, compiled_patterns, processed_patterns, log_obj, res, part) + match_chunk(chunk2, input, inplen, + offset2, compiled_patterns, processed_patterns, log_obj, res, part) + else + -- Input is short enough to match it at all + match_chunk(input, input, inplen, 0, + compiled_patterns, processed_patterns, log_obj, res, part) + end + else + -- Table input is NYI + assert(0, 'table input for match') + end + + local extensions = process_detected(res) + + if extensions and #extensions > 0 then + return extensions[1], types[extensions[1]] + end + + -- Nothing found + return nil +end + +exports.detect_mime_part = function(part, log_obj) + local ext, weight = heuristics.mime_part_heuristic(part, log_obj) + + if ext and weight and weight > 20 then + return ext, types[ext] + end + + ext = exports.detect(part, log_obj) + + if ext then + return ext, types[ext] + end + + -- Text/html and other parts + ext, weight = heuristics.text_part_heuristic(part, log_obj) + if ext and weight and weight > 20 then + return ext, types[ext] + end +end + +-- This parameter specifies how many bytes are checked in the input +-- Rspamd checks 2 chunks at start and 1 chunk at the end +exports.chunk_size = 32768 + +exports.types = types + +return exports
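-- Usage sketch (not part of the diff above): one way the detection API exported by
-- lua_magic/init.lua might be called from a symbol callback. It assumes rspamd's Lua
-- environment (rspamd_config and task objects); the symbol name is a placeholder.
local lua_magic = require "lua_magic"

rspamd_config:register_symbol {
  name = 'EXAMPLE_EXECUTABLE_ATTACHMENT', -- hypothetical symbol
  score = 0.0,
  callback = function(task)
    for _, part in ipairs(task:get_parts()) do
      -- detect_mime_part() combines the mime/text heuristics with the trie patterns
      -- and returns an extension plus its entry from lua_magic/types
      local ext, t = lua_magic.detect_mime_part(part, task)
      if ext and t and t.type == 'executable' then
        return true, ext
      end
    end
    return false
  end,
}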
\ No newline at end of file diff --git a/lualib/lua_magic/patterns.lua b/lualib/lua_magic/patterns.lua new file mode 100644 index 0000000..971ddd9 --- /dev/null +++ b/lualib/lua_magic/patterns.lua @@ -0,0 +1,471 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_magic/patterns +-- This module contains most common patterns +--]] + +local heuristics = require "lua_magic/heuristics" + +local patterns = { + pdf = { + -- These are alternatives + matches = { + { + string = [[%PDF-[12]\.\d]], + position = { '<=', 1024 }, + weight = 60, + heuristic = heuristics.pdf_format_heuristic + }, + { + string = [[%FDF-[12]\.\d]], + position = { '<=', 1024 }, + weight = 60, + heuristic = heuristics.pdf_format_heuristic + }, + }, + }, + ps = { + matches = { + { + string = [[%!PS-Adobe]], + relative_position = 0, + weight = 60, + }, + }, + }, + -- RTF document + rtf = { + matches = { + { + string = [[^{\\rt]], + position = 4, + weight = 60, + } + } + }, + chm = { + matches = { + { + string = [[ITSF]], + relative_position = 0, + weight = 60, + } + } + }, + djvu = { + matches = { + { + string = [[AT&TFORM]], + relative_position = 0, + weight = 60, + }, + { + string = [[DJVM]], + relative_position = 0x0c, + weight = 60, + } + } + }, + -- MS Office format, needs heuristic + ole = { + matches = { + { + hex = [[d0cf11e0a1b11ae1]], + relative_position = 0, + weight = 60, + heuristic = heuristics.ole_format_heuristic + } + } + }, + -- MS Exe file + exe = { + matches = { + { + string = [[MZ]], + relative_position = 0, + weight = 15, + }, + -- PE part + { + string = [[PE\x{00}\x{00}]], + position = { '>=', 0x3c + 4 }, + weight = 15, + heuristic = heuristics.pe_part_heuristic, + } + } + }, + elf = { + matches = { + { + hex = [[7f454c46]], + relative_position = 0, + weight = 60, + }, + } + }, + lnk = { + matches = { + { + hex = [[4C0000000114020000000000C000000000000046]], + relative_position = 0, + weight = 60, + }, + } + }, + bat = { + matches = { + { + string = [[(?i)@\s*ECHO\s+OFF]], + position = { '>=', 0 }, + weight = 60, + }, + } + }, + class = { + -- Technically, this also matches MachO files, but I don't care about + -- Apple and their mental health problems here: just consider Java files, + -- Mach object files and all other cafe babes as bad and block them! 
+ matches = { + { + hex = [[cafebabe]], + relative_position = 0, + weight = 60, + }, + } + }, + ics = { + matches = { + { + string = [[BEGIN:VCALENDAR]], + weight = 60, + relative_position = 0, + } + } + }, + vcf = { + matches = { + { + string = [[BEGIN:VCARD]], + weight = 60, + relative_position = 0, + } + } + }, + -- Archives + arj = { + matches = { + { + hex = '60EA', + relative_position = 0, + weight = 60, + }, + } + }, + ace = { + matches = { + { + string = [[\*\*ACE\*\*]], + position = 14, + weight = 60, + }, + } + }, + cab = { + matches = { + { + hex = [[4d53434600000000]], -- Can be anywhere for SFX :( + position = { '>=', 8 }, + weight = 60, + }, + } + }, + tar = { + matches = { + { + string = [[ustar]], + relative_position = 257, + weight = 60, + }, + } + }, + bz2 = { + matches = { + { + string = "^BZ[h0]", + position = 3, + weight = 60, + }, + } + }, + lz4 = { + matches = { + { + hex = "04224d18", + relative_position = 0, + weight = 60, + }, + { + hex = "03214c18", + relative_position = 0, + weight = 60, + }, + { + hex = "02214c18", + relative_position = 0, + weight = 60, + }, + { + -- MozLZ4 + hex = '6d6f7a4c7a343000', + relative_position = 0, + weight = 60, + } + } + }, + zst = { + matches = { + { + string = [[^[\x{22}-\x{40}]\x{B5}\x{2F}\x{FD}]], + position = 4, + weight = 60, + }, + } + }, + zoo = { + matches = { + { + hex = [[dca7c4fd]], + relative_position = 20, + weight = 60, + }, + } + }, + xar = { + matches = { + { + string = [[xar!]], + relative_position = 0, + weight = 60, + }, + } + }, + iso = { + matches = { + { + string = [[\x{01}CD001\x{01}]], + position = { '>=', 0x8000 + 7 }, -- first 32k is unused + weight = 60, + }, + } + }, + egg = { + -- ALZip egg + matches = { + { + string = [[EGGA]], + weight = 60, + relative_position = 0, + }, + } + }, + alz = { + -- ALZip alz + matches = { + { + string = [[ALZ\x{01}]], + weight = 60, + relative_position = 0, + }, + } + }, + -- Apple is a 'special' child: this needs to be matched at the data tail... 
+ dmg = { + matches = { + { + string = [[koly\x{00}\x{00}\x{00}\x{04}]], + position = -512 + 8, + weight = 61, + tail = 512, + }, + } + }, + szdd = { + matches = { + { + hex = [[535a4444]], + relative_position = 0, + weight = 60, + }, + } + }, + xz = { + matches = { + { + hex = [[FD377A585A00]], + relative_position = 0, + weight = 60, + }, + } + }, + -- Images + psd = { + matches = { + { + string = [[8BPS]], + relative_position = 0, + weight = 60, + }, + } + }, + ico = { + matches = { + { + hex = [[00000100]], + relative_position = 0, + weight = 60, + }, + } + }, + pcx = { + matches = { + { + hex = [[0A050108]], + relative_position = 0, + weight = 60, + }, + } + }, + pic = { + matches = { + { + hex = [[FF80C9C71A00]], + relative_position = 0, + weight = 60, + }, + } + }, + swf = { + matches = { + { + hex = [[5a5753]], -- LZMA + relative_position = 0, + weight = 60, + }, + { + hex = [[435753]], -- Zlib + relative_position = 0, + weight = 60, + }, + { + hex = [[465753]], -- Uncompressed + relative_position = 0, + weight = 60, + }, + } + }, + tiff = { + matches = { + { + hex = [[49492a00]], -- LE encoded + relative_position = 0, + weight = 60, + }, + { + hex = [[4d4d]], -- BE tiff + relative_position = 0, + weight = 60, + }, + } + }, + -- Other + pgp = { + matches = { + { + hex = [[A803504750]], + relative_position = 0, + weight = 60, + }, + { + hex = [[2D424547494E20504750204D4553534147452D]], + relative_position = 0, + weight = 60, + }, + } + }, + uue = { + matches = { + { + hex = [[626567696e20]], + relative_position = 0, + weight = 60, + }, + } + }, + dwg = { + matches = { + { + string = '^AC10[12][2-9]', + position = 6, + weight = 60, + } + } + }, + jpg = { + matches = { + { -- JPEG2000 + hex = [[0000000c6a5020200d0a870a]], + relative_position = 0, + weight = 60, + }, + { + string = [[^\x{ff}\x{d8}\x{ff}]], + weight = 60, + position = 3, + }, + }, + }, + png = { + matches = { + { + string = [[^\x{89}PNG\x{0d}\x{0a}\x{1a}\x{0a}]], + position = 8, + weight = 60, + }, + } + }, + gif = { + matches = { + { + string = [[^GIF8\d]], + position = 5, + weight = 60, + }, + } + }, + bmp = { + matches = { + { + string = [[^BM...\x{00}\x{00}\x{00}\x{00}]], + position = 9, + weight = 60, + }, + } + }, +} + +return patterns diff --git a/lualib/lua_magic/types.lua b/lualib/lua_magic/types.lua new file mode 100644 index 0000000..3dce2e1 --- /dev/null +++ b/lualib/lua_magic/types.lua @@ -0,0 +1,327 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +--[[[ +-- @module lua_magic/patterns +-- This module contains types definitions +--]] + +-- This table is indexed by msdos extension for convenience + +local types = { + -- exe + exe = { + ct = 'application/x-ms-application', + type = 'executable', + }, + elf = { + ct = 'application/x-elf-executable', + type = 'executable', + }, + lnk = { + ct = 'application/x-ms-application', + type = 'executable', + }, + class = { + ct = 'application/x-java-applet', + type = 'executable', + }, + jar = { + ct = 'application/java-archive', + type = 'archive', + }, + apk = { + ct = 'application/vnd.android.package-archive', + type = 'archive', + }, + bat = { + ct = 'application/x-bat', + type = 'executable', + }, + -- text + rtf = { + ct = "application/rtf", + type = 'binary', + }, + pdf = { + ct = 'application/pdf', + type = 'binary', + }, + ps = { + ct = 'application/postscript', + type = 'binary', + }, + chm = { + ct = 'application/x-chm', + type = 'binary', + }, + djvu = { + ct = 'application/x-djvu', + type = 'binary', + }, + -- archives + arj = { + ct = 'application/x-arj', + type = 'archive', + }, + cab = { + ct = 'application/x-cab', + type = 'archive', + }, + ace = { + ct = 'application/x-ace', + type = 'archive', + }, + tar = { + ct = 'application/x-tar', + type = 'archive', + }, + bz2 = { + ct = 'application/x-bzip', + type = 'archive', + }, + xz = { + ct = 'application/x-xz', + type = 'archive', + }, + lz4 = { + ct = 'application/x-lz4', + type = 'archive', + }, + zst = { + ct = 'application/x-zstandard', + type = 'archive', + }, + dmg = { + ct = 'application/x-dmg', + type = 'archive', + }, + iso = { + ct = 'application/x-iso', + type = 'archive', + }, + zoo = { + ct = 'application/x-zoo', + type = 'archive', + }, + egg = { + ct = 'application/x-egg', + type = 'archive', + }, + alz = { + ct = 'application/x-alz', + type = 'archive', + }, + xar = { + ct = 'application/x-xar', + type = 'archive', + }, + epub = { + ct = 'application/x-epub', + type = 'archive' + }, + szdd = { -- in fact, their MSDOS extension is like FOO.TX_ or FOO.TX$ + ct = 'application/x-compressed', + type = 'archive', + }, + -- images + psd = { + ct = 'image/psd', + type = 'image', + av_check = false, + }, + pcx = { + ct = 'image/pcx', + type = 'image', + av_check = false, + }, + pic = { + ct = 'image/pic', + type = 'image', + av_check = false, + }, + tiff = { + ct = 'image/tiff', + type = 'image', + av_check = false, + }, + ico = { + ct = 'image/ico', + type = 'image', + av_check = false, + }, + swf = { + ct = 'application/x-shockwave-flash', + type = 'image', + }, + -- Ole files + ole = { + ct = 'application/octet-stream', + type = 'office' + }, + doc = { + ct = 'application/msword', + type = 'office' + }, + xls = { + ct = 'application/vnd.ms-excel', + type = 'office' + }, + ppt = { + ct = 'application/vnd.ms-powerpoint', + type = 'office' + }, + vsd = { + ct = 'application/vnd.visio', + type = 'office' + }, + msi = { + ct = 'application/x-msi', + type = 'executable' + }, + msg = { + ct = 'application/vnd.ms-outlook', + type = 'office' + }, + -- newer office (2007+) + docx = { + ct = 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', + type = 'office' + }, + xlsx = { + ct = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + type = 'office' + }, + pptx = { + ct = 'application/vnd.openxmlformats-officedocument.presentationml.presentation', + type = 'office' + }, + -- OpenOffice formats + odt = { + ct = 'application/vnd.oasis.opendocument.text', + type = 'office' + }, + ods 
= { + ct = 'application/vnd.oasis.opendocument.spreadsheet', + type = 'office' + }, + odp = { + ct = 'application/vnd.oasis.opendocument.presentation', + type = 'office' + }, + -- https://en.wikipedia.org/wiki/Associated_Signature_Containers + asice = { + ct = 'application/vnd.etsi.asic-e+zip', + type = 'office' + }, + asics = { + ct = 'application/vnd.etsi.asic-s+zip', + type = 'office' + }, + -- other + pgp = { + ct = 'application/encrypted', + type = 'encrypted' + }, + uue = { + ct = 'application/x-uuencoded', + type = 'binary', + }, + -- Types that are detected by Rspamd itself + -- Archives + zip = { + ct = 'application/zip', + type = 'archive', + }, + rar = { + ct = 'application/x-rar', + type = 'archive', + }, + ['7z'] = { + ct = 'application/x-7z-compressed', + type = 'archive', + }, + gz = { + ct = 'application/gzip', + type = 'archive', + }, + -- Images + png = { + ct = 'image/png', + type = 'image', + av_check = false, + }, + gif = { + ct = 'image/gif', + type = 'image', + av_check = false, + }, + jpg = { + ct = 'image/jpeg', + type = 'image', + av_check = false, + }, + bmp = { + type = 'image', + ct = 'image/bmp', + av_check = false, + }, + dwg = { + type = 'image', + ct = 'image/vnd.dwg', + }, + -- Text + xml = { + ct = 'application/xml', + type = 'text', + no_text = true, + }, + txt = { + type = 'text', + ct = 'text/plain', + av_check = false, + }, + html = { + type = 'text', + ct = 'text/html', + av_check = false, + }, + csv = { + type = 'text', + ct = 'text/csv', + av_check = false, + no_text = true, + }, + ics = { + type = 'text', + ct = 'text/calendar', + av_check = false, + no_text = true, + }, + vcf = { + type = 'text', + ct = 'text/vcard', + av_check = false, + no_text = true, + }, + eml = { + type = 'message', + ct = 'message/rfc822', + av_check = false, + }, +} + +return types
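-- Usage sketch (not part of the diff above): reading entries of the types table
-- returned by this module. Only the module itself is assumed; the purpose of
-- av_check is inferred here (an opt-out flag for scanning) rather than documented.
local lua_magic_types = require "lua_magic/types"

local docx = lua_magic_types.docx
-- docx.ct   == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
-- docx.type == 'office'

-- List the types flagged with av_check = false (mostly images and plain text)
for ext, def in pairs(lua_magic_types) do
  if def.av_check == false then
    print(ext, def.ct)
  end
end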
\ No newline at end of file diff --git a/lualib/lua_maps.lua b/lualib/lua_maps.lua new file mode 100644 index 0000000..d357310 --- /dev/null +++ b/lualib/lua_maps.lua @@ -0,0 +1,612 @@ +--[[[ +-- @module lua_maps +-- This module contains helper functions for managing rspamd maps +--]] + +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local rspamd_logger = require "rspamd_logger" +local ts = require("tableshape").types +local lua_util = require "lua_util" + +local exports = {} + +local maps_cache = {} + +local function map_hash_key(data, mtype) + local hash = require "rspamd_cryptobox_hash" + local st = hash.create_specific('xxh64') + st:update(data) + st:update(mtype) + + return st:hex() +end + +local function starts(where, st) + return string.sub(where, 1, string.len(st)) == st +end + +local function cut_prefix(where, st) + return string.sub(where, #st + 1) +end + +local function maybe_adjust_type(data, mtype) + local function check_prefix(prefix, t) + if starts(data, prefix) then + data = cut_prefix(data, prefix) + mtype = t + + return true + end + + return false + end + + local known_types = { + { 'regexp;', 'regexp' }, + { 're;', 'regexp' }, + { 'regexp_multi;', 'regexp_multi' }, + { 're_multi;', 'regexp_multi' }, + { 'glob;', 'glob' }, + { 'glob_multi;', 'glob_multi' }, + { 'radix;', 'radix' }, + { 'ipnet;', 'radix' }, + { 'set;', 'set' }, + { 'hash;', 'hash' }, + { 'plain;', 'hash' }, + { 'cdb;', 'cdb' }, + { 'cdb:/', 'cdb' }, + } + + if mtype == 'callback' then + return mtype + end + + for _, t in ipairs(known_types) do + if check_prefix(t[1], t[2]) then + return data, mtype + end + end + + -- No change + return data, mtype +end + +local external_map_schema = ts.shape { + external = ts.equivalent(true), -- must be true + backend = ts.string, -- where to get data, required + method = ts.one_of { "body", "header", "query" }, -- how to pass input + encode = ts.one_of { "json", "messagepack" }:is_optional(), -- how to encode input (if relevant) + timeout = (ts.number + ts.string / lua_util.parse_time_interval):is_optional(), +} + +local rspamd_http = require "rspamd_http" +local ucl = require "ucl" + +local function url_encode_string(str) + str = string.gsub(str, "([^%w _%%%-%.~])", + function(c) + return string.format("%%%02X", string.byte(c)) + end) + str = string.gsub(str, " ", "+") + return str +end + +assert(url_encode_string('上海+中國') == '%E4%B8%8A%E6%B5%B7%2B%E4%B8%AD%E5%9C%8B') +assert(url_encode_string('? 
and the Mysterians') == '%3F+and+the+Mysterians') + +local function query_external_map(map_config, upstreams, key, callback, task) + local http_method = (map_config.method == 'body' or map_config.method == 'form') and 'POST' or 'GET' + local upstream = upstreams:get_upstream_round_robin() + local http_headers = { + ['Accept'] = '*/*' + } + local http_body = nil + local url = map_config.backend + + if type(key) == 'string' or type(key) == 'userdata' then + if map_config.method == 'body' then + http_body = key + http_headers['Content-Type'] = 'text/plain' + elseif map_config.method == 'header' then + http_headers = { + key = key + } + elseif map_config.method == 'query' then + url = string.format('%s?key=%s', url, url_encode_string(tostring(key))) + end + elseif type(key) == 'table' then + if map_config.method == 'body' then + if map_config.encode == 'json' then + http_body = ucl.to_format(key, 'json-compact', true) + http_headers['Content-Type'] = 'application/json' + elseif map_config.encode == 'messagepack' then + http_body = ucl.to_format(key, 'messagepack', true) + http_headers['Content-Type'] = 'application/msgpack' + else + local caller = debug.getinfo(2) or {} + rspamd_logger.errx(task, + "requested external map key with a wrong combination body method and missing encode; caller: %s:%s", + caller.short_src, caller.currentline) + callback(false, 'invalid map usage', 500, task) + end + else + -- query/header and no encode + if map_config.method == 'query' then + local params_table = {} + for k, v in pairs(key) do + if type(v) == 'string' then + table.insert(params_table, string.format('%s=%s', url_encode_string(k), url_encode_string(v))) + end + end + url = string.format('%s?%s', url, table.concat(params_table, '&')) + elseif map_config.method == 'header' then + http_headers = key + else + local caller = debug.getinfo(2) or {} + rspamd_logger.errx(task, + "requested external map key with a wrong combination of encode and input; caller: %s:%s", + caller.short_src, caller.currentline) + callback(false, 'invalid map usage', 500, task) + return + end + end + end + + local function map_callback(err, code, body, _) + if err then + callback(false, err, code, task) + elseif code == 200 then + callback(true, body, 200, task) + else + callback(false, err, code, task) + end + end + + local ret = rspamd_http.request { + task = task, + url = url, + callback = map_callback, + timeout = map_config.timeout or 1.0, + keepalive = true, + upstream = upstream, + method = http_method, + headers = http_headers, + body = http_body, + } + + if not ret then + callback(false, 'http request error', 500, task) + end +end + +--[[[ +-- @function lua_maps.map_add_from_ucl(opt, mtype, description) +-- Creates a map from static data +-- Returns true if map was added or nil +-- @param {string or table} opt data for map (or URL) +-- @param {string} mtype type of map (`set`, `map`, `radix`, `regexp`) +-- @param {string} description human-readable description of map +-- @param {function} callback optional callback that will be called on map match (required for external maps) +-- @return {bool} true on success, or `nil` +--]] +local function rspamd_map_add_from_ucl(opt, mtype, description, callback) + local ret = { + get_key = function(t, k, key_callback, task) + if t.__data then + local cb = key_callback or callback + if t.__external then + if not cb or not task then + local caller = debug.getinfo(2) or {} + rspamd_logger.errx(rspamd_config, "requested external map key without callback or task; caller: %s:%s", + 
caller.short_src, caller.currentline) + return nil + end + query_external_map(t.__data, t.__upstreams, k, cb, task) + else + local result = t.__data:get_key(k) + if cb then + if result then + cb(true, result, 200, task) + else + cb(false, 'not found', 404, task) + end + else + return result + end + end + end + + return nil + end, + foreach = function(t, cb) + return t.__data:foreach(cb) + end, + on_load = function(t, cb) + t.__data:on_load(cb) + end + } + local ret_mt = { + __index = function(t, k, key_callback, task) + if t.__data then + return t.get_key(k, key_callback, task) + end + + return nil + end + } + + if not opt then + return nil + end + + local function maybe_register_selector() + if opt.selector_alias then + local lua_selectors = require "lua_selectors" + lua_selectors.add_map(opt.selector_alias, ret) + end + end + + if type(opt) == 'string' then + opt, mtype = maybe_adjust_type(opt, mtype) + local cache_key = map_hash_key(opt, mtype) + if not callback and maps_cache[cache_key] then + rspamd_logger.infox(rspamd_config, 'reuse url for %s(%s)', + opt, mtype) + + return maps_cache[cache_key] + end + -- We have a single string, so we treat it as a map + local map = rspamd_config:add_map { + type = mtype, + description = description, + url = opt, + } + + if map then + ret.__data = map + ret.hash = cache_key + setmetatable(ret, ret_mt) + maps_cache[cache_key] = ret + return ret + end + elseif type(opt) == 'table' then + local cache_key = lua_util.table_digest(opt) + if not callback and maps_cache[cache_key] then + rspamd_logger.infox(rspamd_config, 'reuse url for complex map definition %s: %s', + cache_key:sub(1, 8), description) + + return maps_cache[cache_key] + end + + if opt[1] then + -- Adjust each element if needed + local adjusted + for i, source in ipairs(opt) do + local nsrc, ntype = maybe_adjust_type(source, mtype) + + if mtype ~= ntype then + if not adjusted then + mtype = ntype + end + adjusted = true + end + opt[i] = nsrc + end + + if mtype == 'radix' then + + if string.find(opt[1], '^%d') then + local map = rspamd_config:radix_from_ucl(opt) + + if map then + ret.__data = map + setmetatable(ret, ret_mt) + maps_cache[cache_key] = ret + maybe_register_selector() + + return ret + end + else + -- Plain table + local map = rspamd_config:add_map { + type = mtype, + description = description, + url = opt, + } + if map then + ret.__data = map + setmetatable(ret, ret_mt) + maps_cache[cache_key] = ret + maybe_register_selector() + + return ret + end + end + elseif mtype == 'regexp' or mtype == 'glob' then + if string.find(opt[1], '^/%a') or string.find(opt[1], '^http') then + -- Plain table + local map = rspamd_config:add_map { + type = mtype, + description = description, + url = opt, + } + if map then + ret.__data = map + setmetatable(ret, ret_mt) + maps_cache[cache_key] = ret + maybe_register_selector() + + return ret + end + else + local map = rspamd_config:add_map { + type = mtype, + description = description, + url = { + url = 'static', + data = opt, + } + } + if map then + ret.__data = map + setmetatable(ret, ret_mt) + maps_cache[cache_key] = ret + maybe_register_selector() + + return ret + end + end + else + if string.find(opt[1], '^/%a') or string.find(opt[1], '^http') then + -- Plain table + local map = rspamd_config:add_map { + type = mtype, + description = description, + url = opt, + } + if map then + ret.__data = map + setmetatable(ret, ret_mt) + maps_cache[cache_key] = ret + maybe_register_selector() + + return ret + end + else + local data = {} + local nelts = 0 + 
-- Plain array of keys, count merely numeric elts + for _, elt in ipairs(opt) do + if type(elt) == 'string' then + -- Numeric table + if mtype == 'hash' then + -- Treat as KV pair + local pieces = lua_util.str_split(elt, ' ') + if #pieces > 1 then + local key = table.remove(pieces, 1) + data[key] = table.concat(pieces, ' ') + else + data[elt] = true + end + else + data[elt] = true + end + + nelts = nelts + 1 + end + end + + if nelts > 0 then + -- Plain Lua table that is used as a map + ret.__data = data + ret.get_key = function(t, k) + if k ~= '__data' then + return t.__data[k] + end + + return nil + end + ret.foreach = function(_, func) + for k, v in pairs(ret.__data) do + if not func(k, v) then + return false + end + end + + return true + end + ret.on_load = function(_, cb) + rspamd_config:add_on_load(function(_, _, _) + cb() + end) + end + + maps_cache[cache_key] = ret + maybe_register_selector() + + return ret + else + -- Empty map, huh? + rspamd_logger.errx(rspamd_config, 'invalid map element: %s', + opt) + end + end + end + else + if opt.external then + -- External map definition, missing fields are handled by schema + local parse_res, parse_err = external_map_schema(opt) + + if parse_res then + ret.__upstreams = lua_util.http_upstreams_by_url(rspamd_config:get_mempool(), opt.backend) + if ret.__upstreams then + ret.__data = opt + ret.__external = true + setmetatable(ret, ret_mt) + maybe_register_selector() + + return ret + else + rspamd_logger.errx(rspamd_config, 'cannot parse external map upstreams: %s', + opt.backend) + end + else + rspamd_logger.errx(rspamd_config, 'cannot parse external map: %s', + parse_err) + end + else + -- Adjust lua specific augmentations in a trivial case + if type(opt.url) == 'string' then + local nsrc, ntype = maybe_adjust_type(opt.url, mtype) + if nsrc and ntype then + opt.url = nsrc + mtype = ntype + end + end + -- We have some non-trivial object so let C code to deal with it somehow... 
+ local map = rspamd_config:add_map { + type = mtype, + description = description, + url = opt, + } + if map then + ret.__data = map + setmetatable(ret, ret_mt) + maps_cache[cache_key] = ret + maybe_register_selector() + + return ret + end + end + end -- opt[1] + end + + return nil +end + +--[[[ +-- @function lua_maps.map_add(mname, optname, mtype, description) +-- Creates a map from configuration elements (static data or URL) +-- Returns true if map was added or nil +-- @param {string} mname config section to use +-- @param {string} optname option name to use +-- @param {string} mtype type of map ('set', 'hash', 'radix', 'regexp', 'glob') +-- @param {string} description human-readable description of map +-- @param {function} callback optional callback that will be called on map match (required for external maps) +-- @return {bool} true on success, or `nil` +--]] + +local function rspamd_map_add(mname, optname, mtype, description, callback) + local opt = rspamd_config:get_module_opt(mname, optname) + + return rspamd_map_add_from_ucl(opt, mtype, description, callback) +end + +exports.rspamd_map_add = rspamd_map_add +exports.map_add = rspamd_map_add +exports.rspamd_map_add_from_ucl = rspamd_map_add_from_ucl +exports.map_add_from_ucl = rspamd_map_add_from_ucl + +-- Check `what` for being lua_map name, otherwise just compares key with what +local function rspamd_maybe_check_map(key, what) + local fun = require "fun" + + if type(what) == "table" then + return fun.any(function(elt) + return rspamd_maybe_check_map(key, elt) + end, what) + end + if type(rspamd_maps) == "table" then + local mn + if starts(key, "map:") then + mn = string.sub(key, 5) + elseif starts(key, "map://") then + mn = string.sub(key, 7) + end + + if mn and rspamd_maps[mn] then + return rspamd_maps[mn]:get_key(what) + end + end + + return what:lower() == key +end + +exports.rspamd_maybe_check_map = rspamd_maybe_check_map + +--[[[ +-- @function lua_maps.fill_config_maps(mname, options, defs) +-- Fill maps that could be defined in defs, from the config in the options +-- Defs is a table indexed by a map's parameter name and defining it's config, +-- @example +-- defs = { +-- my_map = { +-- type = 'map', +-- description = 'my cool map', +-- optional = true, +-- } +-- } +-- -- Then this function will look for opts.my_map parameter and try to replace it with +-- -- a map with the specific type, description but not failing if it was empty. +-- -- It will also set options.my_map_orig to the original value defined in the map. +--]] +exports.fill_config_maps = function(mname, opts, map_defs) + assert(type(opts) == 'table') + assert(type(map_defs) == 'table') + for k, v in pairs(map_defs) do + if opts[k] then + local map = rspamd_map_add_from_ucl(opts[k], v.type or 'map', v.description) + if not map then + rspamd_logger.errx(rspamd_config, 'map add error %s for module %s', k, mname) + return false + end + opts[k .. '_orig'] = opts[k] + opts[k] = map + elseif not v.optional then + rspamd_logger.errx(rspamd_config, 'cannot find non optional map %s for module %s', k, mname) + return false + end + end + + return true +end + +local direct_map_schema = ts.shape { -- complex object + name = ts.string:is_optional(), + description = ts.string:is_optional(), + selector_alias = ts.string:is_optional(), -- an optional alias for the selectos framework + timeout = ts.number, + data = ts.array_of(ts.string):is_optional(), + -- Tableshape has no options support for something like key1 or key2? 
+ upstreams = ts.one_of { + ts.string, + ts.array_of(ts.string), + } :is_optional(), + url = ts.one_of { + ts.string, + ts.array_of(ts.string), + } :is_optional(), +} + +exports.map_schema = ts.one_of { + ts.string, -- 'http://some_map' + ts.array_of(ts.string), -- ['foo', 'bar'] + ts.one_of { direct_map_schema, external_map_schema } +} + +return exports diff --git a/lualib/lua_maps_expressions.lua b/lualib/lua_maps_expressions.lua new file mode 100644 index 0000000..996de99 --- /dev/null +++ b/lualib/lua_maps_expressions.lua @@ -0,0 +1,219 @@ +--[[[ +-- @module lua_maps_expressions +-- This module contains routines to combine maps, selectors and expressions +-- in a generic framework +@example +whitelist_ip_from = { + rules { + ip { + selector = "ip"; + map = "/path/to/whitelist_ip.map"; + } + from { + selector = "from(smtp)"; + map = "/path/to/whitelist_from.map"; + } + } + expression = "ip & from"; +} +--]] + +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local lua_selectors = require "lua_selectors" +local lua_maps = require "lua_maps" +local rspamd_expression = require "rspamd_expression" +local rspamd_logger = require "rspamd_logger" +local fun = require "fun" +local ts = require("tableshape").types + +local exports = {} + +local function process_func(elt, task) + local matched = {} + local function process_atom(atom) + local rule = elt.rules[atom] + local res = 0 + + local function match_rule(val) + local map_match = rule.map:get_key(val) + if map_match then + res = 1.0 + matched[rule.name] = { + matched = val, + value = map_match + } + end + end + + local values = rule.selector(task) + + if values then + if type(values) == 'table' then + for _, val in ipairs(values) do + if res == 0 then + match_rule(val) + end + end + else + match_rule(values) + end + end + + return res + end + + local res = elt.expr:process(process_atom) + + if res > 0 then + return res, matched + end + + return nil +end + +exports.schema = ts.shape { + expression = ts.string, + rules = ts.array_of( + ts.shape { + selector = ts.string, + map = lua_maps.map_schema, + } + ) +} + +--[[[ +-- @function lua_maps_expression.create(config, object, module_name) +-- Creates a new maps combination from `object` for `module_name`. +-- The input should be table with the following fields: +-- +-- * `rules` - kv map of rules where each rule has `map` and `selector` mandatory attribute, also `type` for map type, e.g. `regexp` +-- * `expression` - Rspamd expression where elements are names from `rules` field, e.g. `ip & from` +-- +-- This function returns an object with public method `process(task)` that checks +-- a task for the conditions defined in `expression` and `rules` and returns 2 values: +-- +-- 1. value returned by an expression (e.g. 1 or 0) +-- 2. an map (rule_name -> table) of matches, where each element has the following fields: +-- * `matched` - selector's value +-- * `value` - map's result +-- +-- In case if `expression` is false a `nil` value is returned. 
+-- @param {rspamd_config} cfg rspamd config +-- @param {table} obj configuration table +-- +--]] +local function create(cfg, obj, module_name) + if not module_name then + module_name = 'lua_maps_expressions' + end + + if not obj or not obj.rules or not obj.expression then + rspamd_logger.errx(cfg, 'cannot add maps combination for module %s: required elements are missing', + module_name) + return nil + end + + local ret = { + process = process_func, + rules = {}, + module_name = module_name + } + + for name, rule in pairs(obj.rules) do + local sel = lua_selectors.create_selector_closure(cfg, rule.selector) + + if not sel then + rspamd_logger.errx(cfg, 'cannot add selector for element %s in module %s', + name, module_name) + end + + if not rule.type then + -- Guess type + if name:find('ip') or name:find('ipnet') then + rule.type = 'radix' + elseif name:find('regexp') or name:find('re_') then + rule.type = 'regexp' + elseif name:find('glob') then + rule.type = 'regexp' + else + rule.type = 'set' + end + end + local map = lua_maps.map_add_from_ucl(rule.map, rule.type, + obj.description or module_name) + if not map then + rspamd_logger.errx(cfg, 'cannot add map for element %s in module %s', + name, module_name) + end + + if sel and map then + ret.rules[name] = { + selector = sel, + map = map, + name = name, + } + else + return nil + end + end + + -- Now process and parse expression + local function parse_atom(str) + local atom = table.concat(fun.totable(fun.take_while(function(c) + if string.find(', \t()><+!|&\n', c, 1, true) then + return false + end + return true + end, fun.iter(str))), '') + + if ret.rules[atom] then + return atom + end + + rspamd_logger.errx(cfg, 'use of undefined element "%s" when parsing maps expression for %s', + atom, module_name) + + return nil + end + local expr = rspamd_expression.create(obj.expression, parse_atom, + rspamd_config:get_mempool()) + + if not expr then + rspamd_logger.errx(cfg, 'cannot add map expression for module %s', + module_name) + return nil + end + + ret.expr = expr + + if obj.symbol then + rspamd_config:register_symbol { + type = 'virtual,ghost', + name = obj.symbol, + score = 0.0, + } + end + + ret.symbol = obj.symbol + + return ret +end + +exports.create = create + +return exports
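-- Usage sketch (not part of the diff above): wiring the maps/expression combination
-- into a module. It follows the @example from the module documentation; the map
-- paths, symbol names and module name are placeholders, and rspamd's Lua
-- environment (rspamd_config, task) is assumed.
local lua_maps_expressions = require "lua_maps_expressions"

local combined = lua_maps_expressions.create(rspamd_config, {
  rules = {
    ip = { selector = 'ip', map = '/path/to/whitelist_ip.map' },
    from = { selector = 'from(smtp)', map = '/path/to/whitelist_from.map' },
  },
  expression = 'ip & from',
  symbol = 'EXAMPLE_WHITELISTED', -- optional, registered as a virtual ghost symbol
}, 'example_module')

if combined then
  rspamd_config:register_symbol {
    name = 'EXAMPLE_WHITELISTED_CHECK', -- hypothetical callback symbol
    score = 0.0,
    callback = function(task)
      local res, matched = combined:process(task)
      if res then
        -- matched maps rule name -> { matched = <selector value>, value = <map value> }
        return true, matched.ip and tostring(matched.ip.matched)
      end
      return false
    end,
  }
end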
\ No newline at end of file diff --git a/lualib/lua_meta.lua b/lualib/lua_meta.lua new file mode 100644 index 0000000..340d89e --- /dev/null +++ b/lualib/lua_meta.lua @@ -0,0 +1,549 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local exports = {} + +local N = "metatokens" +local ts = require("tableshape").types +local logger = require "rspamd_logger" + +-- Metafunctions +local function meta_size_function(task) + local sizes = { + 100, + 200, + 500, + 1000, + 2000, + 4000, + 10000, + 20000, + 30000, + 100000, + 200000, + 400000, + 800000, + 1000000, + 2000000, + 8000000, + } + + local size = task:get_size() + for i = 1, #sizes do + if sizes[i] >= size then + return { (1.0 * i) / #sizes } + end + end + + return { 0 } +end + +local function meta_images_function(task) + local images = task:get_images() + local ntotal = 0 + local njpg = 0 + local npng = 0 + local nlarge = 0 + local nsmall = 0 + + if images then + for _, img in ipairs(images) do + if img:get_type() == 'png' then + npng = npng + 1 + elseif img:get_type() == 'jpeg' then + njpg = njpg + 1 + end + + local w = img:get_width() + local h = img:get_height() + + if w > 0 and h > 0 then + if w + h > 256 then + nlarge = nlarge + 1 + else + nsmall = nsmall + 1 + end + end + + ntotal = ntotal + 1 + end + end + if ntotal > 0 then + njpg = 1.0 * njpg / ntotal + npng = 1.0 * npng / ntotal + nlarge = 1.0 * nlarge / ntotal + nsmall = 1.0 * nsmall / ntotal + end + return { ntotal, njpg, npng, nlarge, nsmall } +end + +local function meta_nparts_function(task) + local nattachments = 0 + local ntextparts = 0 + local totalparts = 1 + + local tp = task:get_text_parts() + if tp then + ntextparts = #tp + end + + local parts = task:get_parts() + + if parts then + for _, p in ipairs(parts) do + if p:is_attachment() then + nattachments = nattachments + 1 + end + totalparts = totalparts + 1 + end + end + + return { (1.0 * ntextparts) / totalparts, (1.0 * nattachments) / totalparts } +end + +local function meta_encoding_function(task) + local nutf = 0 + local nother = 0 + + local tp = task:get_text_parts() + if tp and #tp > 0 then + for _, p in ipairs(tp) do + if p:is_utf() then + nutf = nutf + 1 + else + nother = nother + 1 + end + end + + return { nutf / #tp, nother / #tp } + end + + return { 0, 0 } +end + +local function meta_recipients_function(task) + local nmime = 0 + local nsmtp = 0 + + if task:has_recipients('mime') then + nmime = #(task:get_recipients('mime')) + end + if task:has_recipients('smtp') then + nsmtp = #(task:get_recipients('smtp')) + end + + if nmime > 0 then + nmime = 1.0 / nmime + end + if nsmtp > 0 then + nsmtp = 1.0 / nsmtp + end + + return { nmime, nsmtp } +end + +local function meta_received_function(task) + local count_factor = 0 + local invalid_factor = 0 + local rh = task:get_received_headers() + local time_factor = 0 + local secure_factor = 0 + local fun = require "fun" + + if rh and #rh > 0 then + + local ntotal = 0.0 + local init_time = 0 + + fun.each(function(rc) + 
ntotal = ntotal + 1.0 + + if not rc.by_hostname then + invalid_factor = invalid_factor + 1.0 + end + if init_time == 0 and rc.timestamp then + init_time = rc.timestamp + elseif rc.timestamp then + time_factor = time_factor + math.abs(init_time - rc.timestamp) + init_time = rc.timestamp + end + if rc.flags and (rc.flags['ssl'] or rc.flags['authenticated']) then + secure_factor = secure_factor + 1.0 + end + end, + fun.filter(function(rc) + return not rc.flags or not rc.flags['artificial'] + end, rh)) + + if ntotal > 0 then + invalid_factor = invalid_factor / ntotal + secure_factor = secure_factor / ntotal + count_factor = 1.0 / ntotal + end + + if time_factor ~= 0 then + time_factor = 1.0 / time_factor + end + end + + return { count_factor, invalid_factor, time_factor, secure_factor } +end + +local function meta_urls_function(task) + local has_urls, nurls = task:has_urls() + if has_urls and nurls > 0 then + return { 1.0 / nurls } + end + + return { 0 } +end + +local function meta_words_function(task) + local avg_len = task:get_mempool():get_variable("avg_words_len", "double") or 0.0 + local short_words = task:get_mempool():get_variable("short_words_cnt", "double") or 0.0 + local ret_len = 0 + + local lens = { + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 15, + 20, + } + + for i = 1, #lens do + if lens[i] >= avg_len then + ret_len = (1.0 * i) / #lens + break + end + end + + local tp = task:get_text_parts() + local wres = { + 0, -- spaces rate + 0, -- double spaces rate + 0, -- non spaces rate + 0, -- ascii characters rate + 0, -- non-ascii characters rate + 0, -- capital characters rate + 0, -- numeric characters + } + for _, p in ipairs(tp) do + local stats = p:get_stats() + local len = p:get_length() + + if len > 0 then + wres[1] = wres[1] + stats['spaces'] / len + wres[2] = wres[2] + stats['double_spaces'] / len + wres[3] = wres[3] + stats['non_spaces'] / len + wres[4] = wres[4] + stats['ascii_characters'] / len + wres[5] = wres[5] + stats['non_ascii_characters'] / len + wres[6] = wres[6] + stats['capital_letters'] / len + wres[7] = wres[7] + stats['numeric_characters'] / len + end + end + + local ret = { + short_words, + ret_len, + } + + local divisor = 1.0 + if #tp > 0 then + divisor = #tp + end + + for _, wr in ipairs(wres) do + table.insert(ret, wr / divisor) + end + + return ret +end + +local metafunctions = { + { + cb = meta_size_function, + ninputs = 1, + names = { + "size" + }, + description = 'Describes size of the message', + }, + { + cb = meta_images_function, + ninputs = 5, + -- 1 - number of images, + -- 2 - number of png images, + -- 3 - number of jpeg images + -- 4 - number of large images (> 128 x 128) + -- 5 - number of small images (< 128 x 128) + names = { + 'nimages', + 'npng_images', + 'njpeg_images', + 'nlarge_images', + 'nsmall_images' + }, + description = [[Functions for images matching: + - number of images, + - number of png images, + - number of jpeg images + - number of large images (> 128 x 128) + - number of small images (< 128 x 128) +]] + }, + { + cb = meta_nparts_function, + ninputs = 2, + -- 1 - number of text parts + -- 2 - number of attachments + names = { + 'ntext_parts', + 'nattachments' + }, + description = [[Functions for images matching: + - number of text parts + - number of attachments +]] + }, + { + cb = meta_encoding_function, + ninputs = 2, + -- 1 - number of utf parts + -- 2 - number of non-utf parts + names = { + 'nutf_parts', + 'nascii_parts' + }, + description = [[Functions for encoding matching: + - number of utf parts + - number of 
non-utf parts +]] + }, + { + cb = meta_recipients_function, + ninputs = 2, + -- 1 - number of mime rcpt + -- 2 - number of smtp rcpt + names = { + 'nmime_rcpt', + 'nsmtp_rcpt' + }, + description = [[Functions for recipients data matching: + - number of mime rcpt + - number of smtp rcpt +]] + }, + { + cb = meta_received_function, + ninputs = 4, + names = { + 'nreceived', + 'nreceived_invalid', + 'nreceived_bad_time', + 'nreceived_secure' + }, + description = [[Functions for received headers data matching: + - number of received headers + - number of bad received headers + - number of skewed time received headers + - number of received via secured relays +]] + }, + { + cb = meta_urls_function, + ninputs = 1, + names = { + 'nurls' + }, + description = [[Functions for urls data matching: + - number of urls +]] + }, + { + cb = meta_words_function, + ninputs = 9, + names = { + 'avg_words_len', + 'nshort_words', + 'spaces_rate', + 'double_spaces_rate', + 'non_spaces_rate', + 'ascii_characters_rate', + 'non_ascii_characters_rate', + 'capital_characters_rate', + 'numeric_characters' + }, + description = [[Functions for words data matching: + - average length of the words + - number of short words + - rate of spaces in the text + - rate of multiple spaces + - rate of non space characters + - rate of ascii characters + - rate of non-ascii characters + - rate of capital letters + - rate of numbers +]] + }, +} + +local meta_schema = ts.shape { + cb = ts.func, + ninputs = ts.number, + names = ts.array_of(ts.string), + description = ts.string:is_optional() +} + +local metatokens_by_name = {} + +local function fill_metatokens_by_name() + metatokens_by_name = {} + + for _, mt in ipairs(metafunctions) do + for i = 1, mt.ninputs do + local name = mt.names[i] + + metatokens_by_name[name] = function(task) + local results = mt.cb(task) + return results[i] + end + end + end +end + +local function calculate_digest() + local cr = require "rspamd_cryptobox_hash" + + local h = cr.create() + for _, mt in ipairs(metafunctions) do + for i = 1, mt.ninputs do + local name = mt.names[i] + h:update(name) + end + end + + exports.digest = h:hex() +end + +local function rspamd_gen_metatokens(task, names) + local lua_util = require "lua_util" + local ipairs = ipairs + local metatokens = {} + + if not names then + local cached = task:cache_get('metatokens') + + if cached then + return cached + else + for _, mt in ipairs(metafunctions) do + local ct = mt.cb(task) + for i, tok in ipairs(ct) do + lua_util.debugm(N, task, "metatoken: %s = %s", + mt.names[i], tok) + if tok ~= tok or tok == math.huge then + logger.errx(task, 'metatoken %s returned %s; replace it with 0 for sanity', + mt.names[i], tok) + tok = 0.0 + end + table.insert(metatokens, tok) + end + end + + task:cache_set('metatokens', metatokens) + end + + else + for _, n in ipairs(names) do + if metatokens_by_name[n] then + local tok = metatokens_by_name[n](task) + if tok ~= tok or tok == math.huge then + logger.errx(task, 'metatoken %s returned %s; replace it with 0 for sanity', + n, tok) + tok = 0.0 + end + table.insert(metatokens, tok) + else + logger.errx(task, 'unknown metatoken: %s', n) + end + end + end + + return metatokens +end + +exports.rspamd_gen_metatokens = rspamd_gen_metatokens +exports.gen_metatokens = rspamd_gen_metatokens + +local function rspamd_gen_metatokens_table(task) + local metatokens = {} + + for _, mt in ipairs(metafunctions) do + local ct = mt.cb(task) + for i, tok in ipairs(ct) do + if tok ~= tok or tok == math.huge then + logger.errx(task, 
'metatoken %s returned %s; replace it with 0 for sanity', + mt.names[i], tok) + tok = 0.0 + end + + metatokens[mt.names[i]] = tok + end + end + + return metatokens +end + +exports.rspamd_gen_metatokens_table = rspamd_gen_metatokens_table +exports.gen_metatokens_table = rspamd_gen_metatokens_table + +local function rspamd_count_metatokens() + local ipairs = ipairs + local total = 0 + for _, mt in ipairs(metafunctions) do + total = total + mt.ninputs + end + + return total +end + +exports.rspamd_count_metatokens = rspamd_count_metatokens +exports.count_metatokens = rspamd_count_metatokens +exports.version = 1 -- MUST be increased on each change of metatokens + +exports.add_metafunction = function(tbl) + local ret, err = meta_schema(tbl) + + if not ret then + logger.errx('cannot add metafunction: %s', err) + else + table.insert(metafunctions, tbl) + fill_metatokens_by_name() + calculate_digest() + end +end + +fill_metatokens_by_name() +calculate_digest() + +return exports diff --git a/lualib/lua_mime.lua b/lualib/lua_mime.lua new file mode 100644 index 0000000..0f5aa75 --- /dev/null +++ b/lualib/lua_mime.lua @@ -0,0 +1,760 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_mime +-- This module contains helper functions to modify mime parts +--]] + +local logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local rspamd_text = require "rspamd_text" +local ucl = require "ucl" + +local exports = {} + +local function newline(task) + local t = task:get_newlines_type() + + if t == 'cr' then + return '\r' + elseif t == 'lf' then + return '\n' + end + + return '\r\n' +end + +local function do_append_footer(task, part, footer, is_multipart, out, state) + local tp = part:get_text() + local ct = 'text/plain' + local cte = 'quoted-printable' + local newline_s = state.newline_s + + if tp:is_html() then + ct = 'text/html' + end + + local encode_func = function(input) + return rspamd_util.encode_qp(input, 80, task:get_newlines_type()) + end + + if part:get_cte() == '7bit' then + cte = '7bit' + encode_func = function(input) + if type(input) == 'userdata' then + return input + else + return rspamd_text.fromstring(input) + end + end + end + + if is_multipart then + out[#out + 1] = string.format('Content-Type: %s; charset=utf-8%s' .. + 'Content-Transfer-Encoding: %s', + ct, newline_s, cte) + out[#out + 1] = '' + else + state.new_cte = cte + end + + local content = tp:get_content('raw_utf') or '' + local double_nline = newline_s .. newline_s + local nlen = #double_nline + -- Hack, if part ends with 2 newline, then we append it after footer + if content:sub(-(nlen), nlen + 1) == double_nline then + -- content without last newline + content = content:sub(-(#newline_s), #newline_s + 1) .. footer + out[#out + 1] = { encode_func(content), true } + out[#out + 1] = '' + else + content = content .. 
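-- A minimal usage sketch for the metatokens API above (rspamd_gen_metatokens,
-- rspamd_gen_metatokens_table). It assumes the rspamd runtime; the
-- log_metatokens wrapper and the chosen token names are illustrative, and the
-- example is shown as a commented-out block.
--[[
local lua_meta = require "lua_meta"
local logger = require "rspamd_logger"

-- Call from any symbol callback or other task-handling code
local function log_metatokens(task)
  -- All metatokens as a flat numeric vector (cached per task)
  local all_toks = lua_meta.rspamd_gen_metatokens(task)
  -- Only selected tokens, addressed by name
  local some = lua_meta.rspamd_gen_metatokens(task, { 'nurls', 'size' })
  -- Named form, convenient for logging or export
  local named = lua_meta.rspamd_gen_metatokens_table(task)
  logger.infox(task, 'collected %s of %s tokens, avg_words_len=%s',
      #some, #all_toks, named.avg_words_len)
end
]]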
footer + out[#out + 1] = { encode_func(content), true } + out[#out + 1] = '' + end + +end + +--[[[ +-- @function lua_mime.add_text_footer(task, html_footer, text_footer) +-- Adds a footer to all text parts in a message. It returns a table with the following +-- fields: +-- * out: new content (body only) +-- * need_rewrite_ct: boolean field that means if we must rewrite content type +-- * new_ct: new content type (type => string, subtype => string) +-- * new_cte: new content-transfer encoding (string) +--]] +exports.add_text_footer = function(task, html_footer, text_footer) + local newline_s = newline(task) + local state = { + newline_s = newline_s + } + local out = {} + local text_parts = task:get_text_parts() + + if not (html_footer or text_footer) or not (text_parts and #text_parts > 0) then + return false + end + + if html_footer or text_footer then + -- We need to take extra care about content-type and cte + local ct = task:get_header('Content-Type') + if ct then + ct = rspamd_util.parse_content_type(ct, task:get_mempool()) + end + + if ct then + if ct.type and ct.type == 'text' then + if ct.subtype then + if html_footer and (ct.subtype == 'html' or ct.subtype == 'htm') then + state.need_rewrite_ct = true + elseif text_footer and ct.subtype == 'plain' then + state.need_rewrite_ct = true + end + else + if text_footer then + state.need_rewrite_ct = true + end + end + + state.new_ct = ct + end + else + + if text_parts then + + if #text_parts == 1 then + state.need_rewrite_ct = true + state.new_ct = { + type = 'text', + subtype = 'plain' + } + elseif #text_parts > 1 then + -- XXX: in fact, it cannot be + state.new_ct = { + type = 'multipart', + subtype = 'mixed' + } + end + end + end + end + + local boundaries = {} + local cur_boundary + for _, part in ipairs(task:get_parts()) do + local boundary = part:get_boundary() + if part:is_multipart() then + if cur_boundary then + out[#out + 1] = string.format('--%s', + boundaries[#boundaries]) + end + + boundaries[#boundaries + 1] = boundary or '--XXX' + cur_boundary = boundary + + local rh = part:get_raw_headers() + if #rh > 0 then + out[#out + 1] = { rh, true } + end + elseif part:is_message() then + if boundary then + if cur_boundary and boundary ~= cur_boundary then + -- Need to close boundary + out[#out + 1] = string.format('--%s--%s', + boundaries[#boundaries], newline_s) + table.remove(boundaries) + cur_boundary = nil + end + out[#out + 1] = string.format('--%s', + boundary) + end + + out[#out + 1] = { part:get_raw_headers(), true } + else + local append_footer = false + local skip_footer = part:is_attachment() + + local parent = part:get_parent() + if parent then + local t, st = parent:get_type() + + if t == 'multipart' and st == 'signed' then + -- Do not modify signed parts + skip_footer = true + end + end + if text_footer and part:is_text() then + local tp = part:get_text() + + if not tp:is_html() then + append_footer = text_footer + end + end + + if html_footer and part:is_text() then + local tp = part:get_text() + + if tp:is_html() then + append_footer = html_footer + end + end + + if boundary then + if cur_boundary and boundary ~= cur_boundary then + -- Need to close boundary + out[#out + 1] = string.format('--%s--%s', + boundaries[#boundaries], newline_s) + table.remove(boundaries) + cur_boundary = boundary + end + out[#out + 1] = string.format('--%s', + boundary) + end + + if append_footer and not skip_footer then + do_append_footer(task, part, append_footer, + parent and parent:is_multipart(), out, state) + else + out[#out + 1] = { 
part:get_raw_headers(), true } + out[#out + 1] = { part:get_raw_content(), false } + end + end + end + + -- Close remaining + local b = table.remove(boundaries) + while b do + out[#out + 1] = string.format('--%s--', b) + if #boundaries > 0 then + out[#out + 1] = '' + end + b = table.remove(boundaries) + end + + state.out = out + + return state +end + +local function do_replacement (task, part, mp, replacements, + is_multipart, out, state) + + local tp = part:get_text() + local ct = 'text/plain' + local cte = 'quoted-printable' + local newline_s = state.newline_s + + if tp:is_html() then + ct = 'text/html' + end + + local encode_func = function(input) + return rspamd_util.encode_qp(input, 80, task:get_newlines_type()) + end + + if part:get_cte() == '7bit' then + cte = '7bit' + encode_func = function(input) + if type(input) == 'userdata' then + return input + else + return rspamd_text.fromstring(input) + end + end + end + + local content = tp:get_content('raw_utf') or rspamd_text.fromstring('') + local match_pos = mp:match(content, true) + + if match_pos then + -- sort matches and form the table: + -- start .. end for inclusion position + local matches_flattened = {} + for npat, matches in pairs(match_pos) do + for _, m in ipairs(matches) do + table.insert(matches_flattened, { m, npat }) + end + end + + -- Handle the case of empty match + if #matches_flattened == 0 then + out[#out + 1] = { part:get_raw_headers(), true } + out[#out + 1] = { part:get_raw_content(), false } + + return + end + + if is_multipart then + out[#out + 1] = { string.format('Content-Type: %s; charset="utf-8"%s' .. + 'Content-Transfer-Encoding: %s', + ct, newline_s, cte), true } + out[#out + 1] = { '', true } + else + state.new_cte = cte + end + + state.has_matches = true + -- now sort flattened by start of match and eliminate all overlaps + table.sort(matches_flattened, function(m1, m2) + return m1[1][1] < m2[1][1] + end) + + for i = 1, #matches_flattened - 1 do + local st = matches_flattened[i][1][1] -- current start of match + local e = matches_flattened[i][1][2] -- current end of match + local max_npat = matches_flattened[i][2] + for j = i + 1, #matches_flattened do + if matches_flattened[j][1][1] == st then + -- overlap + if matches_flattened[j][1][2] > e then + -- larger exclusion and switch replacement + e = matches_flattened[j][1][2] + max_npat = matches_flattened[j][2] + end + else + break + end + end + -- Maximum overlap for all matches + for j = i, #matches_flattened do + if matches_flattened[j][1][1] == st then + if e > matches_flattened[j][1][2] then + matches_flattened[j][1][2] = e + matches_flattened[j][2] = max_npat + end + else + break + end + end + end + -- Off-by one: match returns 0 based positions while we use 1 based in Lua + for _, m in ipairs(matches_flattened) do + m[1][1] = m[1][1] + 1 + m[1][2] = m[1][2] + 1 + end + + -- Now flattened match table is sorted by start pos and has the maximum overlapped pattern + -- Matches with the same start and end are covering the same replacement + -- e.g. we had something like [1 .. 2] -> replacement 1 and [1 .. 4] -> replacement 2 + -- after flattening we should have [1 .. 4] -> 2 and [1 .. 
4] -> 2 + -- we can safely ignore those duplicates in the following code + + local cur_start = 1 + local fragments = {} + for _, m in ipairs(matches_flattened) do + if m[1][1] >= cur_start then + fragments[#fragments + 1] = content:sub(cur_start, m[1][1] - 1) + fragments[#fragments + 1] = replacements[m[2]] + cur_start = m[1][2] -- end of match + end + end + + -- last part + if cur_start < #content then + fragments[#fragments + 1] = content:span(cur_start) + end + + -- Final stuff + out[#out + 1] = { encode_func(rspamd_text.fromtable(fragments)), false } + else + -- No matches + out[#out + 1] = { part:get_raw_headers(), true } + out[#out + 1] = { part:get_raw_content(), false } + end +end + +--[[[ +-- @function lua_mime.multipattern_text_replace(task, mp, replacements) +-- Replaces text according to multipattern matches. It returns a table with the following +-- fields: +-- * out: new content (body only) +-- * need_rewrite_ct: boolean field that means if we must rewrite content type +-- * new_ct: new content type (type => string, subtype => string) +-- * new_cte: new content-transfer encoding (string) +--]] +exports.multipattern_text_replace = function(task, mp, replacements) + local newline_s = newline(task) + local state = { + newline_s = newline_s + } + local out = {} + local text_parts = task:get_text_parts() + + if not mp or not (text_parts and #text_parts > 0) then + return false + end + + -- We need to take extra care about content-type and cte + local ct = task:get_header('Content-Type') + if ct then + ct = rspamd_util.parse_content_type(ct, task:get_mempool()) + end + + if ct then + if ct.type and ct.type == 'text' then + state.need_rewrite_ct = true + state.new_ct = ct + end + else + -- No explicit CT, need to guess + if text_parts then + if #text_parts == 1 then + state.need_rewrite_ct = true + state.new_ct = { + type = 'text', + subtype = 'plain' + } + elseif #text_parts > 1 then + -- XXX: in fact, it cannot be + state.new_ct = { + type = 'multipart', + subtype = 'mixed' + } + end + end + end + + local boundaries = {} + local cur_boundary + for _, part in ipairs(task:get_parts()) do + local boundary = part:get_boundary() + if part:is_multipart() then + if cur_boundary then + out[#out + 1] = { string.format('--%s', + boundaries[#boundaries]), true } + end + + boundaries[#boundaries + 1] = boundary or '--XXX' + cur_boundary = boundary + + local rh = part:get_raw_headers() + if #rh > 0 then + out[#out + 1] = { rh, true } + end + elseif part:is_message() then + if boundary then + if cur_boundary and boundary ~= cur_boundary then + -- Need to close boundary + out[#out + 1] = { string.format('--%s--', + boundaries[#boundaries]), true } + table.remove(boundaries) + cur_boundary = nil + end + out[#out + 1] = { string.format('--%s', + boundary), true } + end + + out[#out + 1] = { part:get_raw_headers(), true } + else + local skip_replacement = part:is_attachment() + + local parent = part:get_parent() + if parent then + local t, st = parent:get_type() + + if t == 'multipart' and st == 'signed' then + -- Do not modify signed parts + skip_replacement = true + end + end + if not part:is_text() then + skip_replacement = true + end + + if boundary then + if cur_boundary and boundary ~= cur_boundary then + -- Need to close boundary + out[#out + 1] = { string.format('--%s--', + boundaries[#boundaries]), true } + table.remove(boundaries) + cur_boundary = boundary + end + out[#out + 1] = { string.format('--%s', + boundary), true } + end + + if not skip_replacement then + do_replacement(task, 
part, mp, replacements, + parent and parent:is_multipart(), out, state) + else + -- Append as is + out[#out + 1] = { part:get_raw_headers(), true } + out[#out + 1] = { part:get_raw_content(), false } + end + end + end + + -- Close remaining + local b = table.remove(boundaries) + while b do + out[#out + 1] = { string.format('--%s--', b), true } + if #boundaries > 0 then + out[#out + 1] = { '', true } + end + b = table.remove(boundaries) + end + + state.out = out + + return state +end + +--[[[ +-- @function lua_mime.modify_headers(task, {add = {hname = {value = 'value', order = 1}}, remove = {hname = {1,2}}}) +-- Adds/removes headers both internal and in the milter reply +-- Mode defines to be compatible with Rspamd <=3.2 and is the default (equal to 'compat') +--]] +exports.modify_headers = function(task, hdr_alterations, mode) + -- Assume default mode compatibility + if not mode then + mode = 'compat' + end + local add = hdr_alterations.add or {} + local remove = hdr_alterations.remove or {} + + local add_headers = {} -- For Milter reply + local hdr_flattened = {} -- For C API + + local function flatten_add_header(hname, hdr) + if not add_headers[hname] then + add_headers[hname] = {} + end + if not hdr_flattened[hname] then + hdr_flattened[hname] = { add = {} } + end + local add_tbl = hdr_flattened[hname].add + if hdr.value then + table.insert(add_headers[hname], { + order = (tonumber(hdr.order) or -1), + value = hdr.value, + }) + table.insert(add_tbl, { tonumber(hdr.order) or -1, hdr.value }) + elseif type(hdr) == 'table' then + for _, v in ipairs(hdr) do + flatten_add_header(hname, v) + end + elseif type(hdr) == 'string' then + table.insert(add_headers[hname], { + order = -1, + value = hdr, + }) + table.insert(add_tbl, { -1, hdr }) + else + logger.errx(task, 'invalid modification of header: %s', hdr) + end + + if mode == 'compat' and #add_headers[hname] == 1 then + -- Switch to the compatibility mode + add_headers[hname] = add_headers[hname][1] + end + end + if hdr_alterations.order then + -- Get headers alterations ordered + for _, hname in ipairs(hdr_alterations.order) do + flatten_add_header(hname, add[hname]) + end + else + for hname, hdr in pairs(add) do + flatten_add_header(hname, hdr) + end + end + + for hname, hdr in pairs(remove) do + if not hdr_flattened[hname] then + hdr_flattened[hname] = { remove = {} } + end + if not hdr_flattened[hname].remove then + hdr_flattened[hname].remove = {} + end + local remove_tbl = hdr_flattened[hname].remove + if type(hdr) == 'number' then + table.insert(remove_tbl, hdr) + else + for _, num in ipairs(hdr) do + table.insert(remove_tbl, num) + end + end + end + + if mode == 'compat' then + -- Clear empty alterations in the compat mode + if add_headers and not next(add_headers) then + add_headers = nil + end + if hdr_alterations.remove and not next(hdr_alterations.remove) then + hdr_alterations.remove = nil + end + end + task:set_milter_reply({ + add_headers = add_headers, + remove_headers = hdr_alterations.remove + }) + + for hname, flat_rules in pairs(hdr_flattened) do + task:modify_header(hname, flat_rules) + end +end + +--[[[ +-- @function lua_mime.message_to_ucl(task, [stringify_content]) +-- Exports a message to an ucl object +--]] +exports.message_to_ucl = function(task, stringify_content) + local E = {} + + local maybe_stringify_f = stringify_content and + tostring or function(t) + return t + end + local result = { + size = task:get_size(), + digest = task:get_digest(), + newlines = task:get_newlines_type(), + headers = 
task:get_headers(true) + } + + -- Utility to convert ip addr to a string or nil if invalid/absent + local function maybe_stringify_ip(addr) + if addr and addr:is_valid() then + return addr:to_string() + end + + return nil + end + + -- Envelope (smtp) information from email (nil if empty) + result.envelope = { + from_smtp = (task:get_from('smtp') or E)[1], + recipients_smtp = task:get_recipients('smtp'), + helo = task:get_helo(), + hostname = task:get_hostname(), + client_ip = maybe_stringify_ip(task:get_client_ip()), + from_ip = maybe_stringify_ip(task:get_from_ip()), + } + if not next(result.envelope) then + result.envelope = ucl.null + end + + local parts = task:get_parts() or E + result.parts = {} + for _, part in ipairs(parts) do + if not part:is_multipart() and not part:is_message() then + local p = { + size = part:get_length(), + type = string.format('%s/%s', part:get_type()), + detected_type = string.format('%s/%s', part:get_detected_type()), + filename = part:get_filename(), + content = maybe_stringify_f(part:get_content()), + headers = part:get_headers(true) or E, + boundary = part:get_enclosing_boundary(), + } + table.insert(result.parts, p) + else + -- Service part: multipart container or message/rfc822 + local p = { + type = string.format('%s/%s', part:get_type()), + headers = part:get_headers(true) or E, + boundary = part:get_enclosing_boundary(), + size = 0, + } + + if part:is_multipart() then + p.multipart_boundary = part:get_boundary() + end + + table.insert(result.parts, p) + end + end + + return result +end + +--[[[ +-- @function lua_mime.message_to_ucl_schema() +-- Returns schema for a message to verify result/document fields +--]] +exports.message_to_ucl_schema = function() + local ts = require("tableshape").types + + local function headers_schema() + return ts.shape { + order = ts.integer:describe('Header order in a message'), + raw = ts.string:describe('Raw header value'):is_optional(), + empty_separator = ts.boolean:describe('Whether header has an empty separator'), + separator = ts.string:describe('Separator between a header and a value'), + decoded = ts.string:describe('Decoded value'):is_optional(), + value = ts.string:describe('Decoded value'):is_optional(), + name = ts.string:describe('Header name'), + tab_separated = ts.boolean:describe('Whether header has tab as a separator') + } + end + + local function part_schema() + return ts.shape { + content = ts.string:describe('Decoded content'):is_optional(), + multipart_boundary = ts.string:describe('Multipart service boundary'):is_optional(), + size = ts.integer:describe('Size of the part'), + type = ts.string:describe('Announced type'):is_optional(), + detected_type = ts.string:describe('Detected type'):is_optional(), + boundary = ts.string:describe('Eclosing boundary'):is_optional(), + filename = ts.string:describe('File name for attachments'):is_optional(), + headers = ts.array_of(headers_schema()):describe('Part headers'), + } + end + + local function email_addr_schema() + return ts.shape { + addr = ts.string:describe('Parsed address'):is_optional(), + raw = ts.string:describe('Raw address'), + flags = ts.shape { + valid = ts.boolean:describe('Valid address'):is_optional(), + ip = ts.boolean:describe('IP like address'):is_optional(), + braced = ts.boolean:describe('Have braces around address'):is_optional(), + quoted = ts.boolean:describe('Have quotes around address'):is_optional(), + empty = ts.boolean:describe('Empty address'):is_optional(), + backslash = ts.boolean:describe('Backslash in 
address'):is_optional(), + ['8bit'] = ts.boolean:describe('8 bit characters in address'):is_optional(), + }, + user = ts.string:describe('Parsed user part'):is_optional(), + name = ts.string:describe('Displayed name'):is_optional(), + domain = ts.string:describe('Parsed domain part'):is_optional(), + } + end + local function envelope_schema() + return ts.shape { + from_smtp = email_addr_schema():describe('SMTP from'):is_optional(), + recipients_smtp = ts.array_of(email_addr_schema()):describe('SMTP recipients'):is_optional(), + helo = ts.string:describe('SMTP Helo'):is_optional(), + hostname = ts.string:describe('Sender hostname'):is_optional(), + client_ip = ts.string:describe('Client ip'):is_optional(), + from_ip = ts.string:describe('Sender ip'):is_optional(), + } + end + + return ts.shape { + headers = ts.array_of(headers_schema()), + parts = ts.array_of(part_schema()), + digest = ts.pattern(string.format('^%s$', string.rep('%x', 32))) + :describe('Message digest'), + newlines = ts.one_of({ "cr", "lf", "crlf" }):describe('Newlines type'), + size = ts.integer:describe('Size of the message in bytes'), + envelope = envelope_schema() + } +end + +return exports diff --git a/lualib/lua_mime_types.lua b/lualib/lua_mime_types.lua new file mode 100644 index 0000000..ba55f97 --- /dev/null +++ b/lualib/lua_mime_types.lua @@ -0,0 +1,745 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
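-- A minimal usage sketch for the lua_mime helpers above (modify_headers and
-- add_text_footer). It assumes the rspamd runtime; the header names, footer
-- texts and the demo_rewrite wrapper are illustrative.
local lua_mime = require "lua_mime"

local function demo_rewrite(task)
  -- Add one header and remove the second occurrence of another, both in the
  -- internal representation and in the milter reply
  lua_mime.modify_headers(task, {
    add = { ['X-Demo-Status'] = { value = 'checked', order = 1 } },
    remove = { ['X-Obsolete-Header'] = { 2 } },
  })

  -- Append footers to text parts; returns false when there is nothing to do
  local res = lua_mime.add_text_footer(task,
      '<p>HTML footer</p>', 'Plain text footer')
  if res then
    -- res.out holds the rewritten body, while res.need_rewrite_ct, res.new_ct
    -- and res.new_cte describe the required Content-Type/CTE adjustments
    return res
  end

  return false
end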
+]]-- + +--[[[ +-- @module lua_mime_types +-- This module contains mime types list +--]] + +local exports = {} + +-- All mime extensions with corresponding content types +exports.full_extensions_map = { + { "323", "text/h323" }, + { "3g2", "video/3gpp2" }, + { "3gp", "video/3gpp" }, + { "3gp2", "video/3gpp2" }, + { "3gpp", "video/3gpp" }, + { "7z", { "application/x-7z-compressed", "application/7z" } }, + { "aa", "audio/audible" }, + { "AAC", "audio/aac" }, + { "aaf", "application/octet-stream" }, + { "aax", "audio/vnd.audible.aax" }, + { "ac3", "audio/ac3" }, + { "aca", "application/octet-stream" }, + { "accda", "application/msaccess.addin" }, + { "accdb", "application/msaccess" }, + { "accdc", "application/msaccess.cab" }, + { "accde", "application/msaccess" }, + { "accdr", "application/msaccess.runtime" }, + { "accdt", "application/msaccess" }, + { "accdw", "application/msaccess.webapplication" }, + { "accft", "application/msaccess.ftemplate" }, + { "acx", "application/internet-property-stream" }, + { "AddIn", "text/xml" }, + { "ade", "application/msaccess" }, + { "adobebridge", "application/x-bridge-url" }, + { "adp", "application/msaccess" }, + { "ADT", "audio/vnd.dlna.adts" }, + { "ADTS", "audio/aac" }, + { "afm", "application/octet-stream" }, + { "ai", "application/postscript" }, + { "aif", "audio/aiff" }, + { "aifc", "audio/aiff" }, + { "aiff", "audio/aiff" }, + { "air", "application/vnd.adobe.air-application-installer-package+zip" }, + { "amc", "application/mpeg" }, + { "anx", "application/annodex" }, + { "apk", "application/vnd.android.package-archive" }, + { "application", "application/x-ms-application" }, + { "art", "image/x-jg" }, + { "asa", "application/xml" }, + { "asax", "application/xml" }, + { "ascx", "application/xml" }, + { "asd", "application/octet-stream" }, + { "asf", "video/x-ms-asf" }, + { "ashx", "application/xml" }, + { "asi", "application/octet-stream" }, + { "asm", "text/plain" }, + { "asmx", "application/xml" }, + { "aspx", "application/xml" }, + { "asr", "video/x-ms-asf" }, + { "asx", "video/x-ms-asf" }, + { "atom", "application/atom+xml" }, + { "au", "audio/basic" }, + { "avi", "video/x-msvideo" }, + { "axa", "audio/annodex" }, + { "axs", "application/olescript" }, + { "axv", "video/annodex" }, + { "bas", "text/plain" }, + { "bcpio", "application/x-bcpio" }, + { "bin", "application/octet-stream" }, + { "bmp", { "image/bmp", "image/x-ms-bmp" } }, + { "c", "text/plain" }, + { "cab", "application/octet-stream" }, + { "caf", "audio/x-caf" }, + { "calx", "application/vnd.ms-office.calx" }, + { "cat", "application/vnd.ms-pki.seccat" }, + { "cc", "text/plain" }, + { "cd", "text/plain" }, + { "cdda", "audio/aiff" }, + { "cdf", "application/x-cdf" }, + { "cer", "application/x-x509-ca-cert" }, + { "cfg", "text/plain" }, + { "chm", "application/octet-stream" }, + { "class", "application/x-java-applet" }, + { "clp", "application/x-msclip" }, + { "cmd", "text/plain" }, + { "cmx", "image/x-cmx" }, + { "cnf", "text/plain" }, + { "cod", "image/cis-cod" }, + { "config", "application/xml" }, + { "contact", "text/x-ms-contact" }, + { "coverage", "application/xml" }, + { "cpio", "application/x-cpio" }, + { "cpp", "text/plain" }, + { "crd", "application/x-mscardfile" }, + { "crl", "application/pkix-crl" }, + { "crt", "application/x-x509-ca-cert" }, + { "cs", "text/plain" }, + { "csdproj", "text/plain" }, + { "csh", "application/x-csh" }, + { "csproj", "text/plain" }, + { "css", "text/css" }, + { "csv", { "application/vnd.ms-excel", "text/csv", "text/plain" } }, + { "cur", 
"application/octet-stream" }, + { "cxx", "text/plain" }, + { "dat", { "application/octet-stream", "application/ms-tnef" } }, + { "datasource", "application/xml" }, + { "dbproj", "text/plain" }, + { "dcr", "application/x-director" }, + { "def", "text/plain" }, + { "deploy", "application/octet-stream" }, + { "der", "application/x-x509-ca-cert" }, + { "dgml", "application/xml" }, + { "dib", "image/bmp" }, + { "dif", "video/x-dv" }, + { "dir", "application/x-director" }, + { "disco", "text/xml" }, + { "divx", "video/divx" }, + { "dll", "application/x-msdownload" }, + { "dll.config", "text/xml" }, + { "dlm", "text/dlm" }, + { "doc", "application/msword" }, + { "docm", "application/vnd.ms-word.document.macroEnabled.12" }, + { "docx", { + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/msword", + "application/vnd.ms-word.document.12", + "application/octet-stream", + } }, + { "dot", "application/msword" }, + { "dotm", "application/vnd.ms-word.template.macroEnabled.12" }, + { "dotx", "application/vnd.openxmlformats-officedocument.wordprocessingml.template" }, + { "dsp", "application/octet-stream" }, + { "dsw", "text/plain" }, + { "dtd", "text/xml" }, + { "dtsConfig", "text/xml" }, + { "dv", "video/x-dv" }, + { "dvi", "application/x-dvi" }, + { "dwf", "drawing/x-dwf" }, + { "dwg", { "application/acad", "image/vnd.dwg" } }, + { "dwp", "application/octet-stream" }, + { "dxf", "application/x-dxf" }, + { "dxr", "application/x-director" }, + { "eml", "message/rfc822" }, + { "emz", "application/octet-stream" }, + { "eot", "application/vnd.ms-fontobject" }, + { "eps", "application/postscript" }, + { "etl", "application/etl" }, + { "etx", "text/x-setext" }, + { "evy", "application/envoy" }, + { "exe", { + "application/x-dosexec", + "application/x-msdownload", + "application/x-executable", + } }, + { "exe.config", "text/xml" }, + { "fdf", "application/vnd.fdf" }, + { "fif", "application/fractals" }, + { "filters", "application/xml" }, + { "fla", "application/octet-stream" }, + { "flac", "audio/flac" }, + { "flr", "x-world/x-vrml" }, + { "flv", "video/x-flv" }, + { "fsscript", "application/fsharp-script" }, + { "fsx", "application/fsharp-script" }, + { "generictest", "application/xml" }, + { "gif", "image/gif" }, + { "gpx", "application/gpx+xml" }, + { "group", "text/x-ms-group" }, + { "gsm", "audio/x-gsm" }, + { "gtar", "application/x-gtar" }, + { "gz", { "application/gzip", "application/x-gzip", "application/tlsrpt+gzip" } }, + { "h", "text/plain" }, + { "hdf", "application/x-hdf" }, + { "hdml", "text/x-hdml" }, + { "hhc", "application/x-oleobject" }, + { "hhk", "application/octet-stream" }, + { "hhp", "application/octet-stream" }, + { "hlp", "application/winhlp" }, + { "hpp", "text/plain" }, + { "hqx", "application/mac-binhex40" }, + { "hta", "application/hta" }, + { "htc", "text/x-component" }, + { "htm", "text/html" }, + { "html", "text/html" }, + { "htt", "text/webviewhtml" }, + { "hxa", "application/xml" }, + { "hxc", "application/xml" }, + { "hxd", "application/octet-stream" }, + { "hxe", "application/xml" }, + { "hxf", "application/xml" }, + { "hxh", "application/octet-stream" }, + { "hxi", "application/octet-stream" }, + { "hxk", "application/xml" }, + { "hxq", "application/octet-stream" }, + { "hxr", "application/octet-stream" }, + { "hxs", "application/octet-stream" }, + { "hxt", "text/html" }, + { "hxv", "application/xml" }, + { "hxw", "application/octet-stream" }, + { "hxx", "text/plain" }, + { "i", "text/plain" }, + { "ico", "image/x-icon" }, + { "ics", 
{ "text/calendar", "application/ics", "application/octet-stream" } }, + { "idl", "text/plain" }, + { "ief", "image/ief" }, + { "iii", "application/x-iphone" }, + { "inc", "text/plain" }, + { "inf", "application/octet-stream" }, + { "ini", "text/plain" }, + { "inl", "text/plain" }, + { "ins", "application/x-internet-signup" }, + { "ipa", "application/x-itunes-ipa" }, + { "ipg", "application/x-itunes-ipg" }, + { "ipproj", "text/plain" }, + { "ipsw", "application/x-itunes-ipsw" }, + { "iqy", "text/x-ms-iqy" }, + { "isp", "application/x-internet-signup" }, + { "ite", "application/x-itunes-ite" }, + { "itlp", "application/x-itunes-itlp" }, + { "itms", "application/x-itunes-itms" }, + { "itpc", "application/x-itunes-itpc" }, + { "IVF", "video/x-ivf" }, + { "jar", "application/java-archive" }, + { "java", "application/octet-stream" }, + { "jck", "application/liquidmotion" }, + { "jcz", "application/liquidmotion" }, + { "jfif", { "image/jpeg", "image/pjpeg" } }, + { "jnlp", "application/x-java-jnlp-file" }, + { "jpb", "application/octet-stream" }, + { "jpe", { "image/jpeg", "image/pjpeg" } }, + { "jpeg", { "image/jpeg", "image/pjpeg" } }, + { "jpg", { "image/jpeg", "image/pjpeg" } }, + { "js", "application/javascript" }, + { "json", "application/json" }, + { "jsx", "text/jscript" }, + { "jsxbin", "text/plain" }, + { "latex", "application/x-latex" }, + { "library-ms", "application/windows-library+xml" }, + { "lit", "application/x-ms-reader" }, + { "loadtest", "application/xml" }, + { "lpk", "application/octet-stream" }, + { "lsf", "video/x-la-asf" }, + { "lst", "text/plain" }, + { "lsx", "video/x-la-asf" }, + { "lzh", "application/octet-stream" }, + { "m13", "application/x-msmediaview" }, + { "m14", "application/x-msmediaview" }, + { "m1v", "video/mpeg" }, + { "m2t", "video/vnd.dlna.mpeg-tts" }, + { "m2ts", "video/vnd.dlna.mpeg-tts" }, + { "m2v", "video/mpeg" }, + { "m3u", "audio/x-mpegurl" }, + { "m3u8", "audio/x-mpegurl" }, + { "m4a", { "audio/m4a", "audio/x-m4a" } }, + { "m4b", "audio/m4b" }, + { "m4p", "audio/m4p" }, + { "m4r", "audio/x-m4r" }, + { "m4v", "video/x-m4v" }, + { "mac", "image/x-macpaint" }, + { "mak", "text/plain" }, + { "man", "application/x-troff-man" }, + { "manifest", "application/x-ms-manifest" }, + { "map", "text/plain" }, + { "master", "application/xml" }, + { "mbox", "application/mbox" }, + { "mda", "application/msaccess" }, + { "mdb", "application/x-msaccess" }, + { "mde", "application/msaccess" }, + { "mdp", "application/octet-stream" }, + { "me", "application/x-troff-me" }, + { "mfp", "application/x-shockwave-flash" }, + { "mht", "message/rfc822" }, + { "mhtml", "message/rfc822" }, + { "mid", "audio/mid" }, + { "midi", "audio/mid" }, + { "mix", "application/octet-stream" }, + { "mk", "text/plain" }, + { "mmf", "application/x-smaf" }, + { "mno", "text/xml" }, + { "mny", "application/x-msmoney" }, + { "mod", "video/mpeg" }, + { "mov", "video/quicktime" }, + { "movie", "video/x-sgi-movie" }, + { "mp2", "video/mpeg" }, + { "mp2v", "video/mpeg" }, + { "mp3", { "audio/mpeg", "audio/mpeg3", "audio/mp3", "audio/x-mpeg-3" } }, + { "mp4", "video/mp4" }, + { "mp4v", "video/mp4" }, + { "mpa", "video/mpeg" }, + { "mpe", "video/mpeg" }, + { "mpeg", "video/mpeg" }, + { "mpf", "application/vnd.ms-mediapackage" }, + { "mpg", "video/mpeg" }, + { "mpp", "application/vnd.ms-project" }, + { "mpv2", "video/mpeg" }, + { "mqv", "video/quicktime" }, + { "ms", "application/x-troff-ms" }, + { "msg", "application/vnd.ms-outlook" }, + { "msi", { "application/x-msi", "application/octet-stream" } }, 
+ { "mso", "application/octet-stream" }, + { "mts", "video/vnd.dlna.mpeg-tts" }, + { "mtx", "application/xml" }, + { "mvb", "application/x-msmediaview" }, + { "mvc", "application/x-miva-compiled" }, + { "mxp", "application/x-mmxp" }, + { "nc", "application/x-netcdf" }, + { "nsc", "video/x-ms-asf" }, + { "nws", "message/rfc822" }, + { "ocx", "application/octet-stream" }, + { "oda", "application/oda" }, + { "odb", "application/vnd.oasis.opendocument.database" }, + { "odc", "application/vnd.oasis.opendocument.chart" }, + { "odf", "application/vnd.oasis.opendocument.formula" }, + { "odg", "application/vnd.oasis.opendocument.graphics" }, + { "odh", "text/plain" }, + { "odi", "application/vnd.oasis.opendocument.image" }, + { "odl", "text/plain" }, + { "odm", "application/vnd.oasis.opendocument.text-master" }, + { "odp", "application/vnd.oasis.opendocument.presentation" }, + { "ods", "application/vnd.oasis.opendocument.spreadsheet" }, + { "odt", "application/vnd.oasis.opendocument.text" }, + { "oga", "audio/ogg" }, + { "ogg", "audio/ogg" }, + { "ogv", "video/ogg" }, + { "ogx", "application/ogg" }, + { "one", "application/onenote" }, + { "onea", "application/onenote" }, + { "onepkg", "application/onenote" }, + { "onetmp", "application/onenote" }, + { "onetoc", "application/onenote" }, + { "onetoc2", "application/onenote" }, + { "opus", "audio/ogg" }, + { "orderedtest", "application/xml" }, + { "osdx", "application/opensearchdescription+xml" }, + { "otf", "application/font-sfnt" }, + { "otg", "application/vnd.oasis.opendocument.graphics-template" }, + { "oth", "application/vnd.oasis.opendocument.text-web" }, + { "otp", "application/vnd.oasis.opendocument.presentation-template" }, + { "ots", "application/vnd.oasis.opendocument.spreadsheet-template" }, + { "ott", "application/vnd.oasis.opendocument.text-template" }, + { "oxt", "application/vnd.openofficeorg.extension" }, + { "p10", "application/pkcs10" }, + { "p12", "application/x-pkcs12" }, + { "p7b", "application/x-pkcs7-certificates" }, + { "p7c", "application/pkcs7-mime" }, + { "p7m", "application/pkcs7-mime", "application/x-pkcs7-mime" }, + { "p7r", "application/x-pkcs7-certreqresp" }, + { "p7s", { "application/pkcs7-signature", "application/x-pkcs7-signature", "text/plain" } }, + { "pbm", "image/x-portable-bitmap" }, + { "pcast", "application/x-podcast" }, + { "pct", "image/pict" }, + { "pcx", "application/octet-stream" }, + { "pcz", "application/octet-stream" }, + { "pdf", "application/pdf" }, + { "pfb", "application/octet-stream" }, + { "pfm", "application/octet-stream" }, + { "pfx", "application/x-pkcs12" }, + { "pgm", "image/x-portable-graymap" }, + { "pic", "image/pict" }, + { "pict", "image/pict" }, + { "pkgdef", "text/plain" }, + { "pkgundef", "text/plain" }, + { "pko", "application/vnd.ms-pki.pko" }, + { "pls", "audio/scpls" }, + { "pma", "application/x-perfmon" }, + { "pmc", "application/x-perfmon" }, + { "pml", "application/x-perfmon" }, + { "pmr", "application/x-perfmon" }, + { "pmw", "application/x-perfmon" }, + { "png", "image/png" }, + { "pnm", "image/x-portable-anymap" }, + { "pnt", "image/x-macpaint" }, + { "pntg", "image/x-macpaint" }, + { "pnz", "image/png" }, + { "pot", "application/vnd.ms-powerpoint" }, + { "potm", "application/vnd.ms-powerpoint.template.macroEnabled.12" }, + { "potx", "application/vnd.openxmlformats-officedocument.presentationml.template" }, + { "ppa", "application/vnd.ms-powerpoint" }, + { "ppam", "application/vnd.ms-powerpoint.addin.macroEnabled.12" }, + { "ppm", "image/x-portable-pixmap" }, + { "pps", 
"application/vnd.ms-powerpoint" }, + { "ppsm", "application/vnd.ms-powerpoint.slideshow.macroEnabled.12" }, + { "ppsx", "application/vnd.openxmlformats-officedocument.presentationml.slideshow" }, + { "ppt", "application/vnd.ms-powerpoint" }, + { "pptm", "application/vnd.ms-powerpoint.presentation.macroEnabled.12" }, + { "pptx", "application/vnd.openxmlformats-officedocument.presentationml.presentation" }, + { "prf", "application/pics-rules" }, + { "prm", "application/octet-stream" }, + { "prx", "application/octet-stream" }, + { "ps", "application/postscript" }, + { "psc1", "application/PowerShell" }, + { "psd", "application/octet-stream" }, + { "psess", "application/xml" }, + { "psm", "application/octet-stream" }, + { "psp", "application/octet-stream" }, + { "pst", "application/vnd.ms-outlook" }, + { "pub", "application/x-mspublisher" }, + { "pwz", "application/vnd.ms-powerpoint" }, + { "qht", "text/x-html-insertion" }, + { "qhtm", "text/x-html-insertion" }, + { "qt", "video/quicktime" }, + { "qti", "image/x-quicktime" }, + { "qtif", "image/x-quicktime" }, + { "qtl", "application/x-quicktimeplayer" }, + { "qxd", "application/octet-stream" }, + { "ra", "audio/x-pn-realaudio" }, + { "ram", "audio/x-pn-realaudio" }, + { "rar", { "application/x-rar-compressed", "application/x-rar", "application/rar", "application/octet-stream" } }, + { "ras", "image/x-cmu-raster" }, + { "rat", "application/rat-file" }, + { "rc", "text/plain" }, + { "rc2", "text/plain" }, + { "rct", "text/plain" }, + { "rdlc", "application/xml" }, + { "reg", "text/plain" }, + { "resx", "application/xml" }, + { "rf", "image/vnd.rn-realflash" }, + { "rgb", "image/x-rgb" }, + { "rgs", "text/plain" }, + { "rm", "application/vnd.rn-realmedia" }, + { "rmi", "audio/mid" }, + { "rmp", "application/vnd.rn-rn_music_package" }, + { "roff", "application/x-troff" }, + { "rpm", "audio/x-pn-realaudio-plugin" }, + { "rqy", "text/x-ms-rqy" }, + { "rtf", { "application/rtf", "application/msword", "text/richtext", "text/rtf" } }, + { "rtx", "text/richtext" }, + { "rvt", "application/octet-stream" }, + { "ruleset", "application/xml" }, + { "s", "text/plain" }, + { "safariextz", "application/x-safari-safariextz" }, + { "scd", "application/x-msschedule" }, + { "scr", "text/plain" }, + { "sct", "text/scriptlet" }, + { "sd2", "audio/x-sd2" }, + { "sdp", "application/sdp" }, + { "sea", "application/octet-stream" }, + { "searchConnector-ms", "application/windows-search-connector+xml" }, + { "setpay", "application/set-payment-initiation" }, + { "setreg", "application/set-registration-initiation" }, + { "settings", "application/xml" }, + { "sgimb", "application/x-sgimb" }, + { "sgml", "text/sgml" }, + { "sh", "application/x-sh" }, + { "shar", "application/x-shar" }, + { "shtml", "text/html" }, + { "sit", "application/x-stuffit" }, + { "sitemap", "application/xml" }, + { "skin", "application/xml" }, + { "skp", "application/x-koan" }, + { "sldm", "application/vnd.ms-powerpoint.slide.macroEnabled.12" }, + { "sldx", "application/vnd.openxmlformats-officedocument.presentationml.slide" }, + { "slk", "application/vnd.ms-excel" }, + { "sln", "text/plain" }, + { "slupkg-ms", "application/x-ms-license" }, + { "smd", "audio/x-smd" }, + { "smi", "application/octet-stream" }, + { "smx", "audio/x-smd" }, + { "smz", "audio/x-smd" }, + { "snd", "audio/basic" }, + { "snippet", "application/xml" }, + { "snp", "application/octet-stream" }, + { "sol", "text/plain" }, + { "sor", "text/plain" }, + { "spc", "application/x-pkcs7-certificates" }, + { "spl", 
"application/futuresplash" }, + { "spx", "audio/ogg" }, + { "src", "application/x-wais-source" }, + { "srf", "text/plain" }, + { "SSISDeploymentManifest", "text/xml" }, + { "ssm", "application/streamingmedia" }, + { "sst", "application/vnd.ms-pki.certstore" }, + { "stl", "application/vnd.ms-pki.stl" }, + { "sv4cpio", "application/x-sv4cpio" }, + { "sv4crc", "application/x-sv4crc" }, + { "svc", "application/xml" }, + { "svg", "image/svg+xml" }, + { "swf", "application/x-shockwave-flash" }, + { "step", "application/step" }, + { "stp", "application/step" }, + { "t", "application/x-troff" }, + { "tar", "application/x-tar" }, + { "tcl", "application/x-tcl" }, + { "testrunconfig", "application/xml" }, + { "testsettings", "application/xml" }, + { "tex", "application/x-tex" }, + { "texi", "application/x-texinfo" }, + { "texinfo", "application/x-texinfo" }, + { "tgz", "application/x-compressed" }, + { "thmx", "application/vnd.ms-officetheme" }, + { "thn", "application/octet-stream" }, + { "tif", { "image/tiff", "application/octet-stream" } }, + { "tiff", "image/tiff" }, + { "tlh", "text/plain" }, + { "tli", "text/plain" }, + { "toc", "application/octet-stream" }, + { "tr", "application/x-troff" }, + { "trm", "application/x-msterminal" }, + { "trx", "application/xml" }, + { "ts", "video/vnd.dlna.mpeg-tts" }, + { "tsv", "text/tab-separated-values" }, + { "ttf", "application/font-sfnt" }, + { "tts", "video/vnd.dlna.mpeg-tts" }, + { "txt", "text/plain" }, + { "u32", "application/octet-stream" }, + { "uls", "text/iuls" }, + { "user", "text/plain" }, + { "ustar", "application/x-ustar" }, + { "vb", "text/plain" }, + { "vbdproj", "text/plain" }, + { "vbk", "video/mpeg" }, + { "vbproj", "text/plain" }, + { "vbs", "text/vbscript" }, + { "vcf", { "text/x-vcard", "text/vcard" } }, + { "vcproj", "application/xml" }, + { "vcs", "text/plain" }, + { "vcxproj", "application/xml" }, + { "vddproj", "text/plain" }, + { "vdp", "text/plain" }, + { "vdproj", "text/plain" }, + { "vdx", "application/vnd.ms-visio.viewer" }, + { "vml", "text/xml" }, + { "vscontent", "application/xml" }, + { "vsct", "text/xml" }, + { "vsd", "application/vnd.visio" }, + { "vsi", "application/ms-vsi" }, + { "vsix", "application/vsix" }, + { "vsixlangpack", "text/xml" }, + { "vsixmanifest", "text/xml" }, + { "vsmdi", "application/xml" }, + { "vspscc", "text/plain" }, + { "vss", "application/vnd.visio" }, + { "vsscc", "text/plain" }, + { "vssettings", "text/xml" }, + { "vssscc", "text/plain" }, + { "vst", "application/vnd.visio" }, + { "vstemplate", "text/xml" }, + { "vsto", "application/x-ms-vsto" }, + { "vsw", "application/vnd.visio" }, + { "vsx", "application/vnd.visio" }, + { "vtx", "application/vnd.visio" }, + { "wav", { "audio/wav", "audio/vnd.wave", "audio/x-wav" } }, + { "wave", "audio/wav" }, + { "wax", "audio/x-ms-wax" }, + { "wbk", "application/msword" }, + { "wbmp", "image/vnd.wap.wbmp" }, + { "wcm", "application/vnd.ms-works" }, + { "wdb", "application/vnd.ms-works" }, + { "wdp", "image/vnd.ms-photo" }, + { "webarchive", "application/x-safari-webarchive" }, + { "webm", "video/webm" }, + { "webp", "image/webp" }, + { "webtest", "application/xml" }, + { "wiq", "application/xml" }, + { "wiz", "application/msword" }, + { "wks", "application/vnd.ms-works" }, + { "WLMP", "application/wlmoviemaker" }, + { "wlpginstall", "application/x-wlpg-detect" }, + { "wlpginstall3", "application/x-wlpg3-detect" }, + { "wm", "video/x-ms-wm" }, + { "wma", "audio/x-ms-wma" }, + { "wmd", "application/x-ms-wmd" }, + { "wmf", { "application/x-msmetafile", 
"image/wmf", "image/x-wmf" } }, + { "wml", "text/vnd.wap.wml" }, + { "wmlc", "application/vnd.wap.wmlc" }, + { "wmls", "text/vnd.wap.wmlscript" }, + { "wmlsc", "application/vnd.wap.wmlscriptc" }, + { "wmp", "video/x-ms-wmp" }, + { "wmv", "video/x-ms-wmv" }, + { "wmx", "video/x-ms-wmx" }, + { "wmz", "application/x-ms-wmz" }, + { "woff", "application/font-woff" }, + { "wpl", "application/vnd.ms-wpl" }, + { "wps", "application/vnd.ms-works" }, + { "wri", "application/x-mswrite" }, + { "wrl", "x-world/x-vrml" }, + { "wrz", "x-world/x-vrml" }, + { "wsc", "text/scriptlet" }, + { "wsdl", "text/xml" }, + { "wvx", "video/x-ms-wvx" }, + { "x", "application/directx" }, + { "xaf", "x-world/x-vrml" }, + { "xaml", "application/xaml+xml" }, + { "xap", "application/x-silverlight-app" }, + { "xbap", "application/x-ms-xbap" }, + { "xbm", "image/x-xbitmap" }, + { "xdr", "text/plain" }, + { "xht", "application/xhtml+xml" }, + { "xhtml", "application/xhtml+xml" }, + { "xla", "application/vnd.ms-excel" }, + { "xlam", "application/vnd.ms-excel.addin.macroEnabled.12" }, + { "xlc", "application/vnd.ms-excel" }, + { "xld", "application/vnd.ms-excel" }, + { "xlk", "application/vnd.ms-excel" }, + { "xll", "application/vnd.ms-excel" }, + { "xlm", "application/vnd.ms-excel" }, + { "xls", { + "application/excel", + "application/vnd.ms-excel", + "application/vnd.ms-office", + "application/x-excel", + "application/octet-stream" + } }, + { "xlsb", "application/vnd.ms-excel.sheet.binary.macroEnabled.12" }, + { "xlsm", "application/vnd.ms-excel.sheet.macroEnabled.12" }, + { "xlsx", { + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "application/vnd.ms-excel.12", + "application/octet-stream" + } }, + { "xlt", "application/vnd.ms-excel" }, + { "xltm", "application/vnd.ms-excel.template.macroEnabled.12" }, + { "xltx", "application/vnd.openxmlformats-officedocument.spreadsheetml.template" }, + { "xlw", "application/vnd.ms-excel" }, + { "xml", { "application/xml", "text/xml", "application/octet-stream" } }, + { "xmp", "application/octet-stream" }, + { "xmta", "application/xml" }, + { "xof", "x-world/x-vrml" }, + { "XOML", "text/plain" }, + { "xpm", "image/x-xpixmap" }, + { "xps", "application/vnd.ms-xpsdocument" }, + { "xrm-ms", "text/xml" }, + { "xsc", "application/xml" }, + { "xsd", "text/xml" }, + { "xsf", "text/xml" }, + { "xsl", "text/xml" }, + { "xslt", "text/xml" }, + { "xsn", "application/octet-stream" }, + { "xss", "application/xml" }, + { "xspf", "application/xspf+xml" }, + { "xtp", "application/octet-stream" }, + { "xwd", "image/x-xwindowdump" }, + { "z", "application/x-compress" }, + { "zip", { + "application/zip", + "application/x-zip-compressed", + "application/octet-stream" + } }, + { "zlib", "application/zlib" }, +} + +-- Used to match extension by content type +exports.reversed_extensions_map = { + ["text/html"] = "html", + ["text/css"] = "css", + ["text/xml"] = "xml", + ["image/gif"] = "gif", + ["image/jpeg"] = "jpeg", + ["application/javascript"] = "js", + ["application/atom+xml"] = "atom", + ["application/rss+xml"] = "rss", + ["application/csv"] = "csv", + ["text/mathml"] = "mml", + ["text/plain"] = "txt", + ["text/vnd.sun.j2me.app-descriptor"] = "jad", + ["text/vnd.wap.wml"] = "wml", + ["text/x-component"] = "htc", + ["image/png"] = "png", + ["image/svg+xml"] = "svg", + ["image/tiff"] = "tiff", + ["image/vnd.wap.wbmp"] = "wbmp", + ["image/webp"] = "webp", + ["image/x-icon"] = "ico", + ["image/x-jng"] = "jng", + ["image/x-ms-bmp"] = "bmp", + ["font/woff"] = "woff", + ["font/woff2"] = 
"woff2", + ["application/java-archive"] = "jar", + ["application/json"] = "json", + ["application/mac-binhex40"] = "hqx", + ["application/msword"] = "doc", + ["application/pdf"] = "pdf", + ["application/postscript"] = "ps", + ["application/rtf"] = "rtf", + ["application/vnd.apple.mpegurl"] = "m3u8", + ["application/vnd.google-earth.kml+xml"] = "kml", + ["application/vnd.google-earth.kmz"] = "kmz", + ["application/vnd.ms-excel"] = "xls", + ["application/vnd.ms-fontobject"] = "eot", + ["application/vnd.ms-powerpoint"] = "ppt", + ["application/vnd.oasis.opendocument.graphics"] = "odg", + ["application/vnd.oasis.opendocument.presentation"] = "odp", + ["application/vnd.oasis.opendocument.spreadsheet"] = "ods", + ["application/vnd.oasis.opendocument.text"] = "odt", + ["application/vnd.openxmlformats-officedocument.presentationml.presentation"] = "pptx", + ["application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"] = "xlsx", + ["application/vnd.openxmlformats-officedocument.wordprocessingml.document"] = "docx", + ["application/x-7z-compressed"] = "7z", + ["application/x-cocoa"] = "cco", + ["application/x-java-archive-diff"] = "jardiff", + ["application/x-java-jnlp-file"] = "jnlp", + ["application/x-makeself"] = "run", + ["application/x-perl"] = "pl", + ["application/x-pilot"] = "pdb", + ["application/x-rar-compressed"] = "rar", + ["application/x-redhat-package-manager"] = "rpm", + ["application/x-sea"] = "sea", + ["application/x-shockwave-flash"] = "swf", + ["application/x-stuffit"] = "sit", + ["application/x-tcl"] = "tcl", + ["application/x-x509-ca-cert"] = "crt", + ["application/x-xpinstall"] = "xpi", + ["application/xhtml+xml"] = "xhtml", + ["application/xspf+xml"] = "xspf", + ["application/zip"] = "zip", + ["application/x-dosexec"] = "exe", + ["application/x-msdownload"] = "exe", + ["application/x-executable"] = "exe", + ["text/x-msdos-batch"] = "bat", + + ["audio/midi"] = "mid", + ["audio/mpeg"] = "mp3", + ["audio/ogg"] = "ogg", + ["audio/x-m4a"] = "m4a", + ["audio/x-realaudio"] = "ra", + ["video/3gpp"] = "3gpp", + ["video/mp2t"] = "ts", + ["video/mp4"] = "mp4", + ["video/mpeg"] = "mpeg", + ["video/quicktime"] = "mov", + ["video/webm"] = "webm", + ["video/x-flv"] = "flv", + ["video/x-m4v"] = "m4v", + ["video/x-mng"] = "mng", + ["video/x-ms-asf"] = "asx", + ["video/x-ms-wmv"] = "wmv", + ["video/x-msvideo"] = "avi", +} + +return exports diff --git a/lualib/lua_redis.lua b/lualib/lua_redis.lua new file mode 100644 index 0000000..818d955 --- /dev/null +++ b/lualib/lua_redis.lua @@ -0,0 +1,1817 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +local logger = require "rspamd_logger" +local lutil = require "lua_util" +local rspamd_util = require "rspamd_util" +local ts = require("tableshape").types + +local exports = {} + +local E = {} +local N = "lua_redis" + +local common_schema = { + timeout = (ts.number + ts.string / lutil.parse_time_interval):is_optional():describe("Connection timeout"), + db = ts.string:is_optional():describe("Database number"), + database = ts.string:is_optional():describe("Database number"), + dbname = ts.string:is_optional():describe("Database number"), + prefix = ts.string:is_optional():describe("Key prefix"), + username = ts.string:is_optional():describe("Username"), + password = ts.string:is_optional():describe("Password"), + expand_keys = ts.boolean:is_optional():describe("Expand keys"), + sentinels = (ts.string + ts.array_of(ts.string)):is_optional():describe("Sentinel servers"), + sentinel_watch_time = (ts.number + ts.string / lutil.parse_time_interval):is_optional():describe("Sentinel watch time"), + sentinel_masters_pattern = ts.string:is_optional():describe("Sentinel masters pattern"), + sentinel_master_maxerrors = (ts.number + ts.string / tonumber):is_optional():describe("Sentinel master max errors"), + sentinel_username = ts.string:is_optional():describe("Sentinel username"), + sentinel_password = ts.string:is_optional():describe("Sentinel password"), +} + +local read_schema = lutil.table_merge({ + read_servers = ts.string + ts.array_of(ts.string), +}, common_schema) + +local write_schema = lutil.table_merge({ + write_servers = ts.string + ts.array_of(ts.string), +}, common_schema) + +local rw_schema = lutil.table_merge({ + read_servers = ts.string + ts.array_of(ts.string), + write_servers = ts.string + ts.array_of(ts.string), +}, common_schema) + +local servers_schema = lutil.table_merge({ + servers = ts.string + ts.array_of(ts.string), +}, common_schema) + +local server_schema = lutil.table_merge({ + server = ts.string + ts.array_of(ts.string), +}, common_schema) + +local enrich_schema = function(external) + return ts.one_of { + ts.shape(external), -- no specific redis parameters + ts.shape(lutil.table_merge(read_schema, external)), -- read_servers specified + ts.shape(lutil.table_merge(write_schema, external)), -- write_servers specified + ts.shape(lutil.table_merge(rw_schema, external)), -- both read and write servers defined + ts.shape(lutil.table_merge(servers_schema, external)), -- just servers for both ops + ts.shape(lutil.table_merge(server_schema, external)), -- legacy `server` attribute + } +end + +exports.enrich_schema = enrich_schema + +local function redis_query_sentinel(ev_base, params, initialised) + local function flatten_redis_table(tbl) + local res = {} + for i = 1, #tbl, 2 do + res[tbl[i]] = tbl[i + 1] + end + + return res + end + -- Coroutines syntax + local rspamd_redis = require "rspamd_redis" + local sentinels = params.sentinels + local addr = sentinels:get_upstream_round_robin() + + local host = addr:get_addr() + local masters = {} + local process_masters -- Function that is called to process masters data + + local function masters_cb(err, result) + if not err and result and type(result) == 'table' then + + local pending_subrequests = 0 + + for _, m in ipairs(result) do + local master = flatten_redis_table(m) + + -- Wrap IPv6-addresses in brackets + if (master.ip:match(":")) then + master.ip = "[" .. master.ip .. 
"]" + end + + if params.sentinel_masters_pattern then + if master.name:match(params.sentinel_masters_pattern) then + lutil.debugm(N, 'found master %s with ip %s and port %s', + master.name, master.ip, master.port) + masters[master.name] = master + else + lutil.debugm(N, 'skip master %s with ip %s and port %s, pattern %s', + master.name, master.ip, master.port, params.sentinel_masters_pattern) + end + else + lutil.debugm(N, 'found master %s with ip %s and port %s', + master.name, master.ip, master.port) + masters[master.name] = master + end + end + + -- For each master we need to get a list of slaves + for k, v in pairs(masters) do + v.slaves = {} + local function slaves_cb(slave_err, slave_result) + if not slave_err and type(slave_result) == 'table' then + for _, s in ipairs(slave_result) do + local slave = flatten_redis_table(s) + lutil.debugm(N, rspamd_config, + 'found slave for master %s with ip %s and port %s', + v.name, slave.ip, slave.port) + -- Wrap IPv6-addresses in brackets + if (slave.ip:match(":")) then + slave.ip = "[" .. slave.ip .. "]" + end + v.slaves[#v.slaves + 1] = slave + end + else + logger.errx('cannot get slaves data from Redis Sentinel %s: %s', + host:to_string(true), slave_err) + addr:fail() + end + + pending_subrequests = pending_subrequests - 1 + + if pending_subrequests == 0 then + -- Finalize masters and slaves + process_masters() + end + end + + local ret = rspamd_redis.make_request { + host = addr:get_addr(), + timeout = params.timeout, + username = params.sentinel_username, + password = params.sentinel_password, + config = rspamd_config, + ev_base = ev_base, + cmd = 'SENTINEL', + args = { 'slaves', k }, + no_pool = true, + callback = slaves_cb + } + + if not ret then + logger.errx(rspamd_config, 'cannot connect sentinel when query slaves at address: %s', + host:to_string(true)) + addr:fail() + else + pending_subrequests = pending_subrequests + 1 + end + end + + addr:ok() + else + logger.errx('cannot get masters data from Redis Sentinel %s: %s', + host:to_string(true), err) + addr:fail() + end + end + + local ret = rspamd_redis.make_request { + host = addr:get_addr(), + timeout = params.timeout, + config = rspamd_config, + ev_base = ev_base, + username = params.sentinel_username, + password = params.sentinel_password, + cmd = 'SENTINEL', + args = { 'masters' }, + no_pool = true, + callback = masters_cb, + } + + if not ret then + logger.errx(rspamd_config, 'cannot connect sentinel at address: %s', + host:to_string(true)) + addr:fail() + end + + process_masters = function() + -- We now form new strings for masters and slaves + local read_servers_tbl, write_servers_tbl = {}, {} + + for _, master in pairs(masters) do + write_servers_tbl[#write_servers_tbl + 1] = string.format( + '%s:%s', master.ip, master.port + ) + read_servers_tbl[#read_servers_tbl + 1] = string.format( + '%s:%s', master.ip, master.port + ) + + for _, slave in ipairs(master.slaves) do + if slave['master-link-status'] == 'ok' then + read_servers_tbl[#read_servers_tbl + 1] = string.format( + '%s:%s', slave.ip, slave.port + ) + end + end + end + + table.sort(read_servers_tbl) + table.sort(write_servers_tbl) + + local read_servers_str = table.concat(read_servers_tbl, ',') + local write_servers_str = table.concat(write_servers_tbl, ',') + + lutil.debugm(N, rspamd_config, + 'new servers list: %s read; %s write', + read_servers_str, + write_servers_str) + + if read_servers_str ~= params.read_servers_str then + local upstream_list = require "rspamd_upstream_list" + + local read_upstreams = 
upstream_list.create(rspamd_config, + read_servers_str, 6379) + + if read_upstreams then + logger.infox(rspamd_config, 'sentinel %s: replace read servers with new list: %s', + host:to_string(true), read_servers_str) + params.read_servers = read_upstreams + params.read_servers_str = read_servers_str + end + end + + if write_servers_str ~= params.write_servers_str then + local upstream_list = require "rspamd_upstream_list" + + local write_upstreams = upstream_list.create(rspamd_config, + write_servers_str, 6379) + + if write_upstreams then + logger.infox(rspamd_config, 'sentinel %s: replace write servers with new list: %s', + host:to_string(true), write_servers_str) + params.write_servers = write_upstreams + params.write_servers_str = write_servers_str + + local queried = false + + local function monitor_failures(up, _, count) + if count > params.sentinel_master_maxerrors and not queried then + logger.infox(rspamd_config, 'sentinel: master with address %s, caused %s failures, try to query sentinel', + host:to_string(true), count) + queried = true -- Avoid multiple checks caused by this monitor + redis_query_sentinel(ev_base, params, true) + end + end + + write_upstreams:add_watcher('failure', monitor_failures) + end + end + end + +end + +local function add_redis_sentinels(params) + local upstream_list = require "rspamd_upstream_list" + + local upstreams_sentinels = upstream_list.create(rspamd_config, + params.sentinels, 5000) + + if not upstreams_sentinels then + logger.errx(rspamd_config, 'cannot load redis sentinels string: %s', + params.sentinels) + + return + end + + params.sentinels = upstreams_sentinels + + if not params.sentinel_watch_time then + params.sentinel_watch_time = 60 -- Each minute + end + + if not params.sentinel_master_maxerrors then + params.sentinel_master_maxerrors = 2 -- Maximum number of errors before rechecking + end + + rspamd_config:add_on_load(function(_, ev_base, worker) + local initialised = false + if worker:is_scanner() or worker:get_type() == 'fuzzy' then + rspamd_config:add_periodic(ev_base, 0.0, function() + redis_query_sentinel(ev_base, params, initialised) + initialised = true + + return params.sentinel_watch_time + end, false) + end + end) +end + +local cached_results = {} + +local function calculate_redis_hash(params) + local cr = require "rspamd_cryptobox_hash" + + local h = cr.create() + + local function rec_hash(k, v) + if type(v) == 'string' then + h:update(k) + h:update(v) + elseif type(v) == 'number' then + h:update(k) + h:update(tostring(v)) + elseif type(v) == 'table' then + for kk, vv in pairs(v) do + rec_hash(kk, vv) + end + end + end + + rec_hash('top', params) + + return h:base32() +end + +local function process_redis_opts(options, redis_params) + local default_timeout = 1.0 + local default_expand_keys = false + + if not redis_params['timeout'] or redis_params['timeout'] == default_timeout then + if options['timeout'] then + redis_params['timeout'] = tonumber(options['timeout']) + else + redis_params['timeout'] = default_timeout + end + end + + if options['prefix'] and not redis_params['prefix'] then + redis_params['prefix'] = options['prefix'] + end + + if type(options['expand_keys']) == 'boolean' then + redis_params['expand_keys'] = options['expand_keys'] + else + redis_params['expand_keys'] = default_expand_keys + end + + if not redis_params['db'] then + if options['db'] then + redis_params['db'] = tostring(options['db']) + elseif options['dbname'] then + redis_params['db'] = tostring(options['dbname']) + elseif options['database'] then 
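+ -- 'database' is treated as yet another alias for 'db'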
+ redis_params['db'] = tostring(options['database']) + end + end + if options['username'] and not redis_params['username'] then + redis_params['username'] = options['username'] + end + if options['password'] and not redis_params['password'] then + redis_params['password'] = options['password'] + end + + if not redis_params.sentinels and options.sentinels then + redis_params.sentinels = options.sentinels + end + + if options['sentinel_masters_pattern'] and not redis_params['sentinel_masters_pattern'] then + redis_params['sentinel_masters_pattern'] = options['sentinel_masters_pattern'] + end + +end + +local function enrich_defaults(rspamd_config, module, redis_params) + if rspamd_config then + local opts = rspamd_config:get_all_opt('redis') + + if opts then + if module then + if opts[module] then + process_redis_opts(opts[module], redis_params) + end + end + + process_redis_opts(opts, redis_params) + end + end +end + +local function maybe_return_cached(redis_params) + local h = calculate_redis_hash(redis_params) + + if cached_results[h] then + lutil.debugm(N, 'reused redis server: %s', redis_params) + return cached_results[h] + end + + redis_params.hash = h + cached_results[h] = redis_params + + if not redis_params.read_only and redis_params.sentinels then + add_redis_sentinels(redis_params) + end + + lutil.debugm(N, 'loaded new redis server: %s', redis_params) + return redis_params +end + +--[[[ +-- @module lua_redis +-- This module contains helper functions for working with Redis +--]] +local function process_redis_options(options, rspamd_config, result) + local default_port = 6379 + local upstream_list = require "rspamd_upstream_list" + local read_only = true + + -- Try to get read servers: + local upstreams_read, upstreams_write + + if options['read_servers'] then + if rspamd_config then + upstreams_read = upstream_list.create(rspamd_config, + options['read_servers'], default_port) + else + upstreams_read = upstream_list.create(options['read_servers'], + default_port) + end + + result.read_servers_str = options['read_servers'] + elseif options['servers'] then + if rspamd_config then + upstreams_read = upstream_list.create(rspamd_config, + options['servers'], default_port) + else + upstreams_read = upstream_list.create(options['servers'], default_port) + end + + result.read_servers_str = options['servers'] + read_only = false + elseif options['server'] then + if rspamd_config then + upstreams_read = upstream_list.create(rspamd_config, + options['server'], default_port) + else + upstreams_read = upstream_list.create(options['server'], default_port) + end + + result.read_servers_str = options['server'] + read_only = false + end + + if upstreams_read then + if options['write_servers'] then + if rspamd_config then + upstreams_write = upstream_list.create(rspamd_config, + options['write_servers'], default_port) + else + upstreams_write = upstream_list.create(options['write_servers'], + default_port) + end + result.write_servers_str = options['write_servers'] + read_only = false + elseif not read_only then + upstreams_write = upstreams_read + result.write_servers_str = result.read_servers_str + end + end + + -- Store options + process_redis_opts(options, result) + + if read_only and not upstreams_write then + result.read_only = true + elseif upstreams_write then + result.read_only = false + end + + if upstreams_read then + result.read_servers = upstreams_read + + if upstreams_write then + result.write_servers = upstreams_write + end + + return true + end + + lutil.debugm(N, rspamd_config, + 
'cannot load redis server from obj: %s, processed to %s', + options, result) + + return false +end + +--[[[ +@function try_load_redis_servers(options, rspamd_config, no_fallback) +Tries to load redis servers from the specified `options` object. +Returns `redis_params` table or nil in case of failure + +--]] +exports.try_load_redis_servers = function(options, rspamd_config, no_fallback, module_name) + local result = {} + + if process_redis_options(options, rspamd_config, result) then + if not no_fallback then + enrich_defaults(rspamd_config, module_name, result) + end + return maybe_return_cached(result) + end +end + +-- This function parses redis server definition using either +-- specific server string for this module or global +-- redis section +local function rspamd_parse_redis_server(module_name, module_opts, no_fallback) + local result = {} + + -- Try local options + local opts + lutil.debugm(N, rspamd_config, 'try load redis config for: %s', module_name) + if not module_opts then + opts = rspamd_config:get_all_opt(module_name) + else + opts = module_opts + end + + if opts then + local ret + + if opts.redis then + ret = process_redis_options(opts.redis, rspamd_config, result) + + if ret then + if not no_fallback then + enrich_defaults(rspamd_config, module_name, result) + end + return maybe_return_cached(result) + end + end + + ret = process_redis_options(opts, rspamd_config, result) + + if ret then + if not no_fallback then + enrich_defaults(rspamd_config, module_name, result) + end + return maybe_return_cached(result) + end + end + + if no_fallback then + logger.infox(rspamd_config, "cannot find Redis definitions for %s and fallback is disabled", + module_name) + + return nil + end + + -- Try global options + opts = rspamd_config:get_all_opt('redis') + + if opts then + local ret + + if opts[module_name] then + ret = process_redis_options(opts[module_name], rspamd_config, result) + + if ret then + return maybe_return_cached(result) + end + else + ret = process_redis_options(opts, rspamd_config, result) + + -- Exclude disabled + if opts['disabled_modules'] then + for _, v in ipairs(opts['disabled_modules']) do + if v == module_name then + logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled", + module_name) + + return nil + end + end + end + + if ret then + logger.infox(rspamd_config, "use default Redis settings for %s", + module_name) + return maybe_return_cached(result) + end + end + end + + if result.read_servers then + return maybe_return_cached(result) + end + + return nil +end + +--[[[ +-- @function lua_redis.parse_redis_server(module_name, module_opts, no_fallback) +-- Extracts Redis server settings from configuration +-- @param {string} module_name name of module to get settings for +-- @param {table} module_opts settings for module or `nil` to fetch them from configuration +-- @param {boolean} no_fallback should be `true` if global settings must not be used +-- @return {table} redis server settings +-- @example +-- local rconfig = lua_redis.parse_redis_server('my_module') +-- -- rconfig contains upstream_list objects in ['write_servers'] and ['read_servers'] +-- -- ['timeout'] contains timeout in seconds +-- -- ['expand_keys'] if true tells that redis key expansion is enabled +--]] + +exports.rspamd_parse_redis_server = rspamd_parse_redis_server +exports.parse_redis_server = rspamd_parse_redis_server + +local process_cmd = { + bitop = function(args) + local idx_l = {} + for i = 2, #args do + table.insert(idx_l, i) + end + return idx_l + 
end, + blpop = function(args) + local idx_l = {} + for i = 1, #args - 1 do + table.insert(idx_l, i) + end + return idx_l + end, + eval = function(args) + local idx_l = {} + local numkeys = args[2] + if numkeys and tonumber(numkeys) >= 1 then + for i = 3, numkeys + 2 do + table.insert(idx_l, i) + end + end + return idx_l + end, + set = function(args) + return { 1 } + end, + mget = function(args) + local idx_l = {} + for i = 1, #args do + table.insert(idx_l, i) + end + return idx_l + end, + mset = function(args) + local idx_l = {} + for i = 1, #args, 2 do + table.insert(idx_l, i) + end + return idx_l + end, + sdiffstore = function(args) + local idx_l = {} + for i = 2, #args do + table.insert(idx_l, i) + end + return idx_l + end, + smove = function(args) + return { 1, 2 } + end, + script = function() + end +} +process_cmd.append = process_cmd.set +process_cmd.auth = process_cmd.script +process_cmd.bgrewriteaof = process_cmd.script +process_cmd.bgsave = process_cmd.script +process_cmd.bitcount = process_cmd.set +process_cmd.bitfield = process_cmd.set +process_cmd.bitpos = process_cmd.set +process_cmd.brpop = process_cmd.blpop +process_cmd.brpoplpush = process_cmd.blpop +process_cmd.client = process_cmd.script +process_cmd.cluster = process_cmd.script +process_cmd.command = process_cmd.script +process_cmd.config = process_cmd.script +process_cmd.dbsize = process_cmd.script +process_cmd.debug = process_cmd.script +process_cmd.decr = process_cmd.set +process_cmd.decrby = process_cmd.set +process_cmd.del = process_cmd.mget +process_cmd.discard = process_cmd.script +process_cmd.dump = process_cmd.set +process_cmd.echo = process_cmd.script +process_cmd.evalsha = process_cmd.eval +process_cmd.exec = process_cmd.script +process_cmd.exists = process_cmd.mget +process_cmd.expire = process_cmd.set +process_cmd.expireat = process_cmd.set +process_cmd.flushall = process_cmd.script +process_cmd.flushdb = process_cmd.script +process_cmd.geoadd = process_cmd.set +process_cmd.geohash = process_cmd.set +process_cmd.geopos = process_cmd.set +process_cmd.geodist = process_cmd.set +process_cmd.georadius = process_cmd.set +process_cmd.georadiusbymember = process_cmd.set +process_cmd.get = process_cmd.set +process_cmd.getbit = process_cmd.set +process_cmd.getrange = process_cmd.set +process_cmd.getset = process_cmd.set +process_cmd.hdel = process_cmd.set +process_cmd.hexists = process_cmd.set +process_cmd.hget = process_cmd.set +process_cmd.hgetall = process_cmd.set +process_cmd.hincrby = process_cmd.set +process_cmd.hincrbyfloat = process_cmd.set +process_cmd.hkeys = process_cmd.set +process_cmd.hlen = process_cmd.set +process_cmd.hmget = process_cmd.set +process_cmd.hmset = process_cmd.set +process_cmd.hscan = process_cmd.set +process_cmd.hset = process_cmd.set +process_cmd.hsetnx = process_cmd.set +process_cmd.hstrlen = process_cmd.set +process_cmd.hvals = process_cmd.set +process_cmd.incr = process_cmd.set +process_cmd.incrby = process_cmd.set +process_cmd.incrbyfloat = process_cmd.set +process_cmd.info = process_cmd.script +process_cmd.keys = process_cmd.script +process_cmd.lastsave = process_cmd.script +process_cmd.lindex = process_cmd.set +process_cmd.linsert = process_cmd.set +process_cmd.llen = process_cmd.set +process_cmd.lpop = process_cmd.set +process_cmd.lpush = process_cmd.set +process_cmd.lpushx = process_cmd.set +process_cmd.lrange = process_cmd.set +process_cmd.lrem = process_cmd.set +process_cmd.lset = process_cmd.set +process_cmd.ltrim = process_cmd.set +process_cmd.migrate = process_cmd.script 
+process_cmd.monitor = process_cmd.script +process_cmd.move = process_cmd.set +process_cmd.msetnx = process_cmd.mset +process_cmd.multi = process_cmd.script +process_cmd.object = process_cmd.script +process_cmd.persist = process_cmd.set +process_cmd.pexpire = process_cmd.set +process_cmd.pexpireat = process_cmd.set +process_cmd.pfadd = process_cmd.set +process_cmd.pfcount = process_cmd.set +process_cmd.pfmerge = process_cmd.mget +process_cmd.ping = process_cmd.script +process_cmd.psetex = process_cmd.set +process_cmd.psubscribe = process_cmd.script +process_cmd.pubsub = process_cmd.script +process_cmd.pttl = process_cmd.set +process_cmd.publish = process_cmd.script +process_cmd.punsubscribe = process_cmd.script +process_cmd.quit = process_cmd.script +process_cmd.randomkey = process_cmd.script +process_cmd.readonly = process_cmd.script +process_cmd.readwrite = process_cmd.script +process_cmd.rename = process_cmd.mget +process_cmd.renamenx = process_cmd.mget +process_cmd.restore = process_cmd.set +process_cmd.role = process_cmd.script +process_cmd.rpop = process_cmd.set +process_cmd.rpoplpush = process_cmd.mget +process_cmd.rpush = process_cmd.set +process_cmd.rpushx = process_cmd.set +process_cmd.sadd = process_cmd.set +process_cmd.save = process_cmd.script +process_cmd.scard = process_cmd.set +process_cmd.sdiff = process_cmd.mget +process_cmd.select = process_cmd.script +process_cmd.setbit = process_cmd.set +process_cmd.setex = process_cmd.set +process_cmd.setnx = process_cmd.set +process_cmd.sinterstore = process_cmd.sdiff +process_cmd.sismember = process_cmd.set +process_cmd.slaveof = process_cmd.script +process_cmd.slowlog = process_cmd.script +process_cmd.smembers = process_cmd.script +process_cmd.sort = process_cmd.set +process_cmd.spop = process_cmd.set +process_cmd.srandmember = process_cmd.set +process_cmd.srem = process_cmd.set +process_cmd.strlen = process_cmd.set +process_cmd.subscribe = process_cmd.script +process_cmd.sunion = process_cmd.mget +process_cmd.sunionstore = process_cmd.mget +process_cmd.swapdb = process_cmd.script +process_cmd.sync = process_cmd.script +process_cmd.time = process_cmd.script +process_cmd.touch = process_cmd.mget +process_cmd.ttl = process_cmd.set +process_cmd.type = process_cmd.set +process_cmd.unsubscribe = process_cmd.script +process_cmd.unlink = process_cmd.mget +process_cmd.unwatch = process_cmd.script +process_cmd.wait = process_cmd.script +process_cmd.watch = process_cmd.mget +process_cmd.zadd = process_cmd.set +process_cmd.zcard = process_cmd.set +process_cmd.zcount = process_cmd.set +process_cmd.zincrby = process_cmd.set +process_cmd.zinterstore = process_cmd.eval +process_cmd.zlexcount = process_cmd.set +process_cmd.zrange = process_cmd.set +process_cmd.zrangebylex = process_cmd.set +process_cmd.zrank = process_cmd.set +process_cmd.zrem = process_cmd.set +process_cmd.zrembylex = process_cmd.set +process_cmd.zrembyrank = process_cmd.set +process_cmd.zrembyscore = process_cmd.set +process_cmd.zrevrange = process_cmd.set +process_cmd.zrevrangebyscore = process_cmd.set +process_cmd.zrevrank = process_cmd.set +process_cmd.zscore = process_cmd.set +process_cmd.zunionstore = process_cmd.eval +process_cmd.scan = process_cmd.script +process_cmd.sscan = process_cmd.set +process_cmd.hscan = process_cmd.set +process_cmd.zscan = process_cmd.set + +local function get_key_indexes(cmd, args) + local idx_l = {} + cmd = string.lower(cmd) + if process_cmd[cmd] then + idx_l = process_cmd[cmd](args) + else + logger.warnx(rspamd_config, "Don't know how to 
extract keys for %s Redis command", cmd) + end + return idx_l +end + +local gen_meta = { + principal_recipient = function(task) + return task:get_principal_recipient() + end, + principal_recipient_domain = function(task) + local p = task:get_principal_recipient() + if not p then + return + end + return string.match(p, '.*@(.*)') + end, + ip = function(task) + local i = task:get_ip() + if i and i:is_valid() then + return i:to_string() + end + end, + from = function(task) + return ((task:get_from('smtp') or E)[1] or E)['addr'] + end, + from_domain = function(task) + return ((task:get_from('smtp') or E)[1] or E)['domain'] + end, + from_domain_or_helo_domain = function(task) + local d = ((task:get_from('smtp') or E)[1] or E)['domain'] + if d and #d > 0 then + return d + end + return task:get_helo() + end, + mime_from = function(task) + return ((task:get_from('mime') or E)[1] or E)['addr'] + end, + mime_from_domain = function(task) + return ((task:get_from('mime') or E)[1] or E)['domain'] + end, +} + +local function gen_get_esld(f) + return function(task) + local d = f(task) + if not d then + return + end + return rspamd_util.get_tld(d) + end +end + +gen_meta.smtp_from = gen_meta.from +gen_meta.smtp_from_domain = gen_meta.from_domain +gen_meta.smtp_from_domain_or_helo_domain = gen_meta.from_domain_or_helo_domain +gen_meta.esld_principal_recipient_domain = gen_get_esld(gen_meta.principal_recipient_domain) +gen_meta.esld_from_domain = gen_get_esld(gen_meta.from_domain) +gen_meta.esld_smtp_from_domain = gen_meta.esld_from_domain +gen_meta.esld_mime_from_domain = gen_get_esld(gen_meta.mime_from_domain) +gen_meta.esld_from_domain_or_helo_domain = gen_get_esld(gen_meta.from_domain_or_helo_domain) +gen_meta.esld_smtp_from_domain_or_helo_domain = gen_meta.esld_from_domain_or_helo_domain + +local function get_key_expansion_metadata(task) + + local md_mt = { + __index = function(self, k) + k = string.lower(k) + local v = rawget(self, k) + if v then + return v + end + if gen_meta[k] then + v = gen_meta[k](task) + rawset(self, k, v) + end + return v + end, + } + + local lazy_meta = {} + setmetatable(lazy_meta, md_mt) + return lazy_meta + +end + +-- Performs async call to redis hiding all complexity inside function +-- task - rspamd_task +-- redis_params - valid params returned by rspamd_parse_redis_server +-- key - key to select upstream or nil to select round-robin/master-slave +-- is_write - true if need to write to redis server +-- callback - function to be called upon request is completed +-- command - redis command +-- args - table of arguments +-- extra_opts - table of optional request arguments +local function rspamd_redis_make_request(task, redis_params, key, is_write, + callback, command, args, extra_opts) + local addr + local function rspamd_redis_make_request_cb(err, data) + if err then + addr:fail() + else + addr:ok() + end + if callback then + callback(err, data, addr) + end + end + if not task or not redis_params or not command then + return false, nil, nil + end + + local rspamd_redis = require "rspamd_redis" + + if key then + if is_write then + addr = redis_params['write_servers']:get_upstream_by_hash(key) + else + addr = redis_params['read_servers']:get_upstream_by_hash(key) + end + else + if is_write then + addr = redis_params['write_servers']:get_upstream_master_slave(key) + else + addr = redis_params['read_servers']:get_upstream_round_robin(key) + end + end + + if not addr then + logger.errx(task, 'cannot select server to make redis request') + end + + if redis_params['expand_keys'] 
then + local m = get_key_expansion_metadata(task) + local indexes = get_key_indexes(command, args) + for _, i in ipairs(indexes) do + args[i] = lutil.template(args[i], m) + end + end + + local ip_addr = addr:get_addr() + local options = { + task = task, + callback = rspamd_redis_make_request_cb, + host = ip_addr, + timeout = redis_params['timeout'], + cmd = command, + args = args + } + + if extra_opts then + for k, v in pairs(extra_opts) do + options[k] = v + end + end + + if redis_params['username'] then + options['username'] = redis_params['username'] + end + + if redis_params['password'] then + options['password'] = redis_params['password'] + end + + if redis_params['db'] then + options['dbname'] = redis_params['db'] + end + + lutil.debugm(N, task, 'perform request to redis server' .. + ' (host=%s, timeout=%s): cmd: %s', ip_addr, + options.timeout, options.cmd) + + local ret, conn = rspamd_redis.make_request(options) + + if not ret then + addr:fail() + logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr)) + end + + return ret, conn, addr +end + +--[[[ +-- @function lua_redis.redis_make_request(task, redis_params, key, is_write, callback, command, args) +-- Sends a request to Redis +-- @param {rspamd_task} task task object +-- @param {table} redis_params redis configuration in format returned by lua_redis.parse_redis_server() +-- @param {string} key key to use for sharding +-- @param {boolean} is_write should be `true` if we are performing a write operating +-- @param {function} callback callback function (first parameter is error if applicable, second is a 2D array (table)) +-- @param {string} command Redis command to run +-- @param {table} args Numerically indexed table containing arguments for command +--]] + +exports.rspamd_redis_make_request = rspamd_redis_make_request +exports.redis_make_request = rspamd_redis_make_request + +local function redis_make_request_taskless(ev_base, cfg, redis_params, key, + is_write, callback, command, args, extra_opts) + if not ev_base or not redis_params or not command then + return false, nil, nil + end + + local addr + local function rspamd_redis_make_request_cb(err, data) + if err then + addr:fail() + else + addr:ok() + end + if callback then + callback(err, data, addr) + end + end + + local rspamd_redis = require "rspamd_redis" + + if key then + if is_write then + addr = redis_params['write_servers']:get_upstream_by_hash(key) + else + addr = redis_params['read_servers']:get_upstream_by_hash(key) + end + else + if is_write then + addr = redis_params['write_servers']:get_upstream_master_slave(key) + else + addr = redis_params['read_servers']:get_upstream_round_robin(key) + end + end + + if not addr then + logger.errx(cfg, 'cannot select server to make redis request') + end + + local options = { + ev_base = ev_base, + config = cfg, + callback = rspamd_redis_make_request_cb, + host = addr:get_addr(), + timeout = redis_params['timeout'], + cmd = command, + args = args + } + if extra_opts then + for k, v in pairs(extra_opts) do + options[k] = v + end + end + + if redis_params['username'] then + options['username'] = redis_params['username'] + end + + if redis_params['password'] then + options['password'] = redis_params['password'] + end + + if redis_params['db'] then + options['dbname'] = redis_params['db'] + end + + lutil.debugm(N, cfg, 'perform taskless request to redis server' .. 
+ ' (host=%s, timeout=%s): cmd: %s', options.host:tostring(true), + options.timeout, options.cmd) + local ret, conn = rspamd_redis.make_request(options) + if not ret then + logger.errx('cannot execute redis request') + addr:fail() + end + + return ret, conn, addr +end + +--[[[ +-- @function lua_redis.redis_make_request_taskless(ev_base, redis_params, key, is_write, callback, command, args) +-- Sends a request to Redis in context where `task` is not available for some specific use-cases +-- Identical to redis_make_request() except in that first parameter is an `event base` object +--]] + +exports.rspamd_redis_make_request_taskless = redis_make_request_taskless +exports.redis_make_request_taskless = redis_make_request_taskless + +local redis_scripts = { +} + +local function script_set_loaded(script) + if script.sha then + script.loaded = true + end + + local wait_table = {} + for _, s in ipairs(script.waitq) do + table.insert(wait_table, s) + end + + script.waitq = {} + + for _, s in ipairs(wait_table) do + s(script.loaded) + end +end + +local function prepare_redis_call(script) + local servers = {} + local options = {} + + if script.redis_params.read_servers then + servers = lutil.table_merge(servers, script.redis_params.read_servers:all_upstreams()) + end + if script.redis_params.write_servers then + servers = lutil.table_merge(servers, script.redis_params.write_servers:all_upstreams()) + end + + -- Call load script on each server, set loaded flag + script.in_flight = #servers + for _, s in ipairs(servers) do + local cur_opts = { + host = s:get_addr(), + timeout = script.redis_params['timeout'], + cmd = 'SCRIPT', + args = { 'LOAD', script.script }, + upstream = s + } + + if script.redis_params['username'] then + cur_opts['username'] = script.redis_params['username'] + end + + if script.redis_params['password'] then + cur_opts['password'] = script.redis_params['password'] + end + + if script.redis_params['db'] then + cur_opts['dbname'] = script.redis_params['db'] + end + + table.insert(options, cur_opts) + end + + return options +end + +local function load_script_task(script, task, is_write) + local rspamd_redis = require "rspamd_redis" + local opts = prepare_redis_call(script) + + for _, opt in ipairs(opts) do + opt.task = task + opt.is_write = is_write + opt.callback = function(err, data) + if err then + logger.errx(task, 'cannot upload script to %s: %s; registered from: %s:%s', + opt.upstream:get_addr():to_string(true), + err, script.caller.short_src, script.caller.currentline) + opt.upstream:fail() + script.fatal_error = err + else + opt.upstream:ok() + logger.infox(task, + "uploaded redis script to %s %s %s, sha: %s", + opt.upstream:get_addr():to_string(true), + script.filename and "from file" or "with id", script.filename or script.id, data) + script.sha = data -- We assume that sha is the same on all servers + end + script.in_flight = script.in_flight - 1 + + if script.in_flight == 0 then + script_set_loaded(script) + end + end + + local ret = rspamd_redis.make_request(opt) + + if not ret then + logger.errx('cannot execute redis request to load script on %s', + opt.upstream:get_addr()) + script.in_flight = script.in_flight - 1 + opt.upstream:fail() + end + + if script.in_flight == 0 then + script_set_loaded(script) + end + end +end + +local function load_script_taskless(script, cfg, ev_base, is_write) + local rspamd_redis = require "rspamd_redis" + local opts = prepare_redis_call(script) + + for _, opt in ipairs(opts) do + opt.config = cfg + opt.ev_base = ev_base + opt.is_write = 
is_write + opt.callback = function(err, data) + if err then + logger.errx(cfg, 'cannot upload script to %s: %s; registered from: %s:%s, filename: %s', + opt.upstream:get_addr():to_string(true), + err, script.caller.short_src, script.caller.currentline, script.filename) + opt.upstream:fail() + script.fatal_error = err + else + opt.upstream:ok() + logger.infox(cfg, + "uploaded redis script to %s %s %s, sha: %s", + opt.upstream:get_addr():to_string(true), + script.filename and "from file" or "with id", script.filename or script.id, + data) + script.sha = data -- We assume that sha is the same on all servers + script.fatal_error = nil + end + script.in_flight = script.in_flight - 1 + + if script.in_flight == 0 then + script_set_loaded(script) + end + end + local ret = rspamd_redis.make_request(opt) + + if not ret then + logger.errx('cannot execute redis request to load script on %s', + opt.upstream:get_addr()) + script.in_flight = script.in_flight - 1 + opt.upstream:fail() + end + + if script.in_flight == 0 then + script_set_loaded(script) + end + end +end + +local function load_redis_script(script, cfg, ev_base, _) + if script.redis_params then + load_script_taskless(script, cfg, ev_base) + end +end + +local function add_redis_script(script, redis_params, caller_level, maybe_filename) + if not caller_level then + caller_level = 2 + end + local caller = debug.getinfo(caller_level) or debug.getinfo(caller_level - 1) or E + + local new_script = { + caller = caller, + loaded = false, + redis_params = redis_params, + script = script, + waitq = {}, -- callbacks pending for script being loaded + id = #redis_scripts + 1, + filename = maybe_filename, + } + + -- Register on load function + rspamd_config:add_on_load(function(cfg, ev_base, worker) + local mult = 0.0 + rspamd_config:add_periodic(ev_base, 0.0, function() + if not new_script.sha then + load_redis_script(new_script, cfg, ev_base, worker) + mult = mult + 1 + return 1.0 * mult -- Check one more time in one second + end + + return false + end, false) + end) + + table.insert(redis_scripts, new_script) + + return #redis_scripts +end +exports.add_redis_script = add_redis_script + +-- Loads a Redis script from a file, strips comments, and passes the content to +-- `add_redis_script` function. +-- +-- @param filename The name of the file containing the Redis script. +-- @param redis_params The Redis parameters to use for this script. +-- @return The ID of the newly added Redis script. 
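+--
+-- A minimal usage sketch (illustrative only; the module name, script file name and
+-- the key below are placeholders, not part of this API):
+--   local lua_redis = require "lua_redis"
+--   local redis_params = lua_redis.parse_redis_server('my_module')
+--   -- relative file names are resolved under the redis_scripts/ subdirectory
+--   local script_id = lua_redis.load_redis_script_from_file('my_script.lua', redis_params)
+--   -- later, e.g. from a symbol callback where a task is available:
+--   lua_redis.exec_redis_script(script_id,
+--     { task = task, is_write = true, key = 'some_key' },
+--     function(err, data)
+--       -- handle the script reply here
+--     end,
+--     { 'some_key' }, { 'some_arg' })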
+-- +local function load_redis_script_from_file(filename, redis_params, dir) + local lua_util = require "lua_util" + local rspamd_logger = require "rspamd_logger" + + if not dir then + dir = rspamd_paths.LUALIBDIR + end + local path = filename + if filename:sub(1, 1) ~= package.config:sub(1, 1) then + -- Relative path + path = lua_util.join_path(dir, "redis_scripts", filename) + end + -- Read file contents + local file = io.open(path, "r") + if not file then + rspamd_logger.errx("failed to open Redis script file: %s", path) + return nil + end + local script = file:read("*all") + if not script then + rspamd_logger.errx("failed to load Redis script file: %s", path) + return nil + end + file:close() + script = lua_util.strip_lua_comments(script) + + return add_redis_script(script, redis_params, 3, filename) +end + +exports.load_redis_script_from_file = load_redis_script_from_file + +local function exec_redis_script(id, params, callback, keys, args) + local redis_args = {} + + if not redis_scripts[id] then + logger.errx("cannot find registered script with id %s", id) + return false + end + + local script = redis_scripts[id] + + if script.fatal_error then + callback(script.fatal_error, nil) + return true + end + + if not script.redis_params then + callback('no redis servers defined', nil) + return true + end + + local function do_call(can_reload) + local function redis_cb(err, data) + if not err then + callback(err, data) + elseif string.match(err, 'NOSCRIPT') then + -- Schedule restart + script.sha = nil + if can_reload then + table.insert(script.waitq, do_call) + if script.in_flight == 0 then + -- Reload scripts if this has not been initiated yet + if params.task then + load_script_task(script, params.task) + else + load_script_taskless(script, rspamd_config, params.ev_base) + end + end + else + callback(err, data) + end + else + callback(err, data) + end + end + + if #redis_args == 0 then + table.insert(redis_args, script.sha) + table.insert(redis_args, tostring(#keys)) + for _, k in ipairs(keys) do + table.insert(redis_args, k) + end + + if type(args) == 'table' then + for _, a in ipairs(args) do + table.insert(redis_args, a) + end + end + end + + if params.task then + if not rspamd_redis_make_request(params.task, script.redis_params, + params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then + callback('Cannot make redis request', nil) + end + else + if not redis_make_request_taskless(params.ev_base, rspamd_config, + script.redis_params, + params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then + callback('Cannot make redis request', nil) + end + end + end + + if script.loaded then + do_call(true) + else + -- Delayed until scripts are loaded + if not params.task then + table.insert(script.waitq, do_call) + else + -- TODO: fix taskfull requests + table.insert(script.waitq, function() + if script.loaded then + do_call(false) + else + callback('NOSCRIPT', nil) + end + end) + load_script_task(script, params.task, params.is_write) + end + end + + return true +end + +exports.exec_redis_script = exec_redis_script + +local function redis_connect_sync(redis_params, is_write, key, cfg, ev_base) + if not redis_params then + return false, nil + end + + local rspamd_redis = require "rspamd_redis" + local addr + + if key then + if is_write then + addr = redis_params['write_servers']:get_upstream_by_hash(key) + else + addr = redis_params['read_servers']:get_upstream_by_hash(key) + end + else + if is_write then + addr = redis_params['write_servers']:get_upstream_master_slave(key) + 
else + addr = redis_params['read_servers']:get_upstream_round_robin(key) + end + end + + if not addr then + logger.errx(cfg, 'cannot select server to make redis request') + end + + local options = { + host = addr:get_addr(), + timeout = redis_params['timeout'], + config = cfg or rspamd_config, + ev_base = ev_base or rspamadm_ev_base, + session = redis_params.session or rspamadm_session + } + + for k, v in pairs(redis_params) do + options[k] = v + end + + if not options.config then + logger.errx('config is not set') + return false, nil, addr + end + + if not options.ev_base then + logger.errx('ev_base is not set') + return false, nil, addr + end + + if not options.session then + logger.errx('session is not set') + return false, nil, addr + end + + local ret, conn = rspamd_redis.connect_sync(options) + if not ret then + logger.errx('cannot create redis connection: %s', conn) + addr:fail() + + return false, nil, addr + end + + if conn then + local need_exec = false + if redis_params['username'] then + if redis_params['password'] then + conn:add_cmd('AUTH', { redis_params['username'], redis_params['password'] }) + need_exec = true + else + logger.warnx('Redis requires a password when username is supplied') + return false, nil, addr + end + elseif redis_params['password'] then + conn:add_cmd('AUTH', { redis_params['password'] }) + need_exec = true + end + + if redis_params['db'] then + conn:add_cmd('SELECT', { tostring(redis_params['db']) }) + need_exec = true + elseif redis_params['dbname'] then + conn:add_cmd('SELECT', { tostring(redis_params['dbname']) }) + need_exec = true + end + + if need_exec then + local exec_ret, res = conn:exec() + + if not exec_ret then + logger.errx('cannot prepare redis connection (authentication or db selection failure): %s', + res) + addr:fail() + return false, nil, addr + end + end + end + + return ret, conn, addr +end + +exports.redis_connect_sync = redis_connect_sync + +--[[[ +-- @function lua_redis.request(redis_params, attrs, req) +-- Sends a request to Redis synchronously with coroutines or asynchronously using +-- a callback (modern API) +-- @param redis_params a table of redis server parameters +-- @param attrs a table of redis request attributes (e.g. 
task, or ev_base + cfg + session) +-- @param req a table of request: a command + command options +-- @return {result,data/connection,address} boolean result, connection object in case of async request and results if using coroutines, redis server address +--]] + +exports.request = function(redis_params, attrs, req) + local lua_util = require "lua_util" + + if not attrs or not redis_params or not req then + logger.errx('invalid arguments for redis request') + return false, nil, nil + end + + if not (attrs.task or (attrs.config and attrs.ev_base)) then + logger.errx('invalid attributes for redis request') + return false, nil, nil + end + + local opts = lua_util.shallowcopy(attrs) + + local log_obj = opts.task or opts.config + + local addr + + if opts.callback then + -- Wrap callback + local callback = opts.callback + local function rspamd_redis_make_request_cb(err, data) + if err then + addr:fail() + else + addr:ok() + end + callback(err, data, addr) + end + opts.callback = rspamd_redis_make_request_cb + end + + local rspamd_redis = require "rspamd_redis" + local is_write = opts.is_write + + if opts.key then + if is_write then + addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key) + else + addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key) + end + else + if is_write then + addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key) + else + addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key) + end + end + + if not addr then + logger.errx(log_obj, 'cannot select server to make redis request') + end + + opts.host = addr:get_addr() + opts.timeout = redis_params.timeout + + if type(req) == 'string' then + opts.cmd = req + else + -- XXX: modifies the input table + opts.cmd = table.remove(req, 1); + opts.args = req + end + + if redis_params.username then + opts.username = redis_params.username + end + + if redis_params.password then + opts.password = redis_params.password + end + + if redis_params.db then + opts.dbname = redis_params.db + end + + lutil.debugm(N, 'perform generic request to redis server' .. + ' (host=%s, timeout=%s): cmd: %s, arguments: %s', addr, + opts.timeout, opts.cmd, opts.args) + + if opts.callback then + local ret, conn = rspamd_redis.make_request(opts) + if not ret then + logger.errx(log_obj, 'cannot execute redis request') + addr:fail() + end + + return ret, conn, addr + else + -- Coroutines version + local ret, conn = rspamd_redis.connect_sync(opts) + if not ret then + logger.errx(log_obj, 'cannot execute redis request') + addr:fail() + else + conn:add_cmd(opts.cmd, opts.args) + return conn:exec() + end + return false, nil, addr + end +end + +--[[[ +-- @function lua_redis.connect(redis_params, attrs) +-- Connects to Redis synchronously with coroutines or asynchronously using a callback (modern API) +-- @param redis_params a table of redis server parameters +-- @param attrs a table of redis request attributes (e.g. 
task, or ev_base + cfg + session) +-- @return {result,connection,address} boolean result, connection object, redis server address +--]] + +exports.connect = function(redis_params, attrs) + local lua_util = require "lua_util" + + if not attrs or not redis_params then + logger.errx('invalid arguments for redis connect') + return false, nil, nil + end + + if not (attrs.task or (attrs.config and attrs.ev_base)) then + logger.errx('invalid attributes for redis connect') + return false, nil, nil + end + + local opts = lua_util.shallowcopy(attrs) + + local log_obj = opts.task or opts.config + + local addr + + if opts.callback then + -- Wrap callback + local callback = opts.callback + local function rspamd_redis_make_request_cb(err, data) + if err then + addr:fail() + else + addr:ok() + end + callback(err, data, addr) + end + opts.callback = rspamd_redis_make_request_cb + end + + local rspamd_redis = require "rspamd_redis" + local is_write = opts.is_write + + if opts.key then + if is_write then + addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key) + else + addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key) + end + else + if is_write then + addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key) + else + addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key) + end + end + + if not addr then + logger.errx(log_obj, 'cannot select server to make redis connect') + end + + opts.host = addr:get_addr() + opts.timeout = redis_params.timeout + + if redis_params.username then + opts.username = redis_params.username + end + + if redis_params.password then + opts.password = redis_params.password + end + + if redis_params.db then + opts.dbname = redis_params.db + end + + if opts.callback then + local ret, conn = rspamd_redis.connect(opts) + if not ret then + logger.errx(log_obj, 'cannot execute redis connect') + addr:fail() + end + + return ret, conn, addr + else + -- Coroutines version + local ret, conn = rspamd_redis.connect_sync(opts) + if not ret then + logger.errx(log_obj, 'cannot execute redis connect') + addr:fail() + else + return true, conn, addr + end + + return false, nil, addr + end +end + +local redis_prefixes = {} + +--[[[ +-- @function lua_redis.register_prefix(prefix, module, description[, optional]) +-- Register new redis prefix for documentation purposes +-- @param {string} prefix string prefix +-- @param {string} module module name +-- @param {string} description prefix description +-- @param {table} optional optional kv pairs (e.g. pattern) +--]] +local function register_prefix(prefix, module, description, optional) + local pr = { + module = module, + description = description + } + + if optional and type(optional) == 'table' then + for k, v in pairs(optional) do + pr[k] = v + end + end + + redis_prefixes[prefix] = pr +end + +exports.register_prefix = register_prefix + +--[[[ +-- @function lua_redis.prefixes([mname]) +-- Returns prefixes for specific module (or all prefixes). 
Returns a table prefix -> table +--]] +exports.prefixes = function(mname) + if not mname then + return redis_prefixes + else + local fun = require "fun" + + return fun.totable(fun.filter(function(_, data) + return data.module == mname + end, + redis_prefixes)) + end +end + +return exports diff --git a/lualib/lua_scanners/avast.lua b/lualib/lua_scanners/avast.lua new file mode 100644 index 0000000..7e77897 --- /dev/null +++ b/lualib/lua_scanners/avast.lua @@ -0,0 +1,304 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module avast +-- This module contains avast av access functions +--]] + +local lua_util = require "lua_util" +local rspamd_util = require "rspamd_util" +local tcp = require "rspamd_tcp" +local upstream_list = require "rspamd_upstream_list" +local rspamd_regexp = require "rspamd_regexp" +local rspamd_logger = require "rspamd_logger" +local common = require "lua_scanners/common" + +local N = "avast" + +local default_message = '${SCANNER}: virus found: "${VIRUS}"' + +local function avast_config(opts) + local avast_conf = { + name = N, + scan_mime_parts = true, + scan_text_mime = false, + scan_image_mime = false, + timeout = 4.0, -- FIXME: this will break task_timeout! + log_clean = false, + detection_category = "virus", + retransmits = 1, + servers = nil, -- e.g. /var/run/avast/scan.sock + cache_expire = 3600, -- expire redis in one hour + message = default_message, + tmpdir = '/tmp', + } + + avast_conf = lua_util.override_defaults(avast_conf, opts) + + if not avast_conf.prefix then + avast_conf.prefix = 'rs_' .. avast_conf.name .. '_' + end + + if not avast_conf.log_prefix then + if avast_conf.name:lower() == avast_conf.type:lower() then + avast_conf.log_prefix = avast_conf.name + else + avast_conf.log_prefix = avast_conf.name .. ' (' .. avast_conf.type .. 
')' + end + end + + if not avast_conf['servers'] then + rspamd_logger.errx(rspamd_config, 'no servers/unix socket defined') + + return nil + end + + avast_conf['upstreams'] = upstream_list.create(rspamd_config, + avast_conf['servers'], + 0) + + if avast_conf['upstreams'] then + lua_util.add_debug_alias('antivirus', avast_conf.name) + return avast_conf + end + + rspamd_logger.errx(rspamd_config, 'cannot parse servers %s', + avast_conf['servers']) + return nil +end + +local function avast_check(task, content, digest, rule, maybe_part) + local function avast_check_uncached () + local upstream = rule.upstreams:get_upstream_round_robin() + local addr = upstream:get_addr() + local retransmits = rule.retransmits + local CRLF = '\r\n' + + -- Common tcp options + local tcp_opts = { + stop_pattern = CRLF, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule.timeout, + task = task + } + + -- Regexps to process reply from avast + local clean_re = rspamd_regexp.create_cached( + [=[(?!\\)\t\[\+\]]=] + ) + local virus_re = rspamd_regexp.create_cached( + [[(?!\\)\t\[L\]\d\.\d\t\d\s(.*)]] + ) + local error_re = rspamd_regexp.create_cached( + [[(?!\\)\t\[E\]\d+\.0\tError\s\d+\s(.*)]] + ) + + -- Used to make a dialog + local tcp_conn + + -- Save content in file as avast can work with files only + local fname = string.format('%s/%s.avtmp', + rule.tmpdir, rspamd_util.random_hex(32)) + local message_fd = rspamd_util.create_file(fname) + + if not message_fd then + rspamd_logger.errx('cannot store file for avast scan: %s', fname) + return + end + + if type(content) == 'string' then + -- Create rspamd_text + local rspamd_text = require "rspamd_text" + content = rspamd_text.fromstring(content) + end + content:save_in_file(message_fd) + + -- Ensure file cleanup on task processed + task:get_mempool():add_destructor(function() + os.remove(fname) + rspamd_util.close_file(message_fd) + end) + + -- Dialog stages closures + local avast_helo_cb + local avast_scan_cb + local avast_scan_done_cb + + -- Utility closures + local function maybe_retransmit() + if retransmits > 0 then + retransmits = retransmits - 1 + else + rspamd_logger.errx(task, + '%s [%s]: failed to scan, maximum retransmits exceed', + rule['symbol'], rule['type']) + common.yield_result(task, rule, 'failed to scan and retransmits exceed', + 0.0, 'fail', maybe_part) + + return + end + + upstream = rule.upstreams:get_upstream_round_robin() + addr = upstream:get_addr() + tcp_opts.upstream = upstream + tcp_opts.callback = avast_helo_cb + + local is_succ, err = tcp.request(tcp_opts) + + if not is_succ then + rspamd_logger.infox(task, 'cannot create connection to avast server: %s (%s)', + addr:to_string(true), err) + else + lua_util.debugm(rule.log_prefix, task, 'established connection to %s; retransmits=%s', + addr:to_string(true), retransmits) + end + end + + local function no_connection_error(err) + if err then + if tcp_conn then + tcp_conn:close() + tcp_conn = nil + + rspamd_logger.infox(task, 'failed to request to avast (%s): %s', + addr:to_string(true), err) + maybe_retransmit() + end + + return false + end + + return true + end + + + -- Define callbacks + avast_helo_cb = function(merr, mdata, conn) + -- Called when we have established a connection but not read anything + tcp_conn = conn + + if no_connection_error(merr) then + -- Check mdata to ensure that it starts with 220 + if #mdata > 3 and tostring(mdata:span(1, 3)) == '220' then + tcp_conn:add_write(avast_scan_cb, string.format( + 'SCAN %s%s', fname, CRLF)) + else 
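+ -- Anything other than a "220" greeting is unexpected here; log the raw banner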
+ rspamd_logger.errx(task, 'Unhandled response: %s', mdata) + end + end + end + + avast_scan_cb = function(merr) + -- Called when we have send request to avast and are waiting for reply + if no_connection_error(merr) then + tcp_conn:add_read(avast_scan_done_cb, CRLF) + end + end + + avast_scan_done_cb = function(merr, mdata) + if no_connection_error(merr) then + lua_util.debugm(rule.log_prefix, task, 'got reply from avast: %s', + mdata) + if #mdata > 4 then + local beg = tostring(mdata:span(1, 3)) + + if beg == '210' then + -- Ignore 210, fire another read + if tcp_conn then + tcp_conn:add_read(avast_scan_done_cb, CRLF) + end + elseif beg == '200' then + -- Final line + if tcp_conn then + tcp_conn:close() + tcp_conn = nil + end + else + -- Check line using regular expressions + local cached + local ret = clean_re:search(mdata, false, true) + + if ret then + cached = 'OK' + if rule.log_clean then + rspamd_logger.infox(task, + '%s [%s]: message or mime_part is clean', + rule.symbol, rule.type) + end + end + + if not cached then + ret = virus_re:search(mdata, false, true) + + if ret then + local vname = ret[1][2] + + if vname then + vname = vname:gsub('\\ ', ' '):gsub('\\\\', '\\') + common.yield_result(task, rule, vname, 1.0, nil, maybe_part) + cached = vname + end + end + end + + if not cached then + ret = error_re:search(mdata, false, true) + + if ret then + rspamd_logger.errx(task, '%s: error: %s', rule.log_prefix, ret[1][2]) + common.yield_result(task, rule, 'error:' .. ret[1][2], + 0.0, 'fail', maybe_part) + end + end + + if cached then + common.save_cache(task, digest, rule, cached, 1.0, maybe_part) + else + -- Unexpected reply + rspamd_logger.errx(task, '%s: unexpected reply: %s', rule.log_prefix, mdata) + end + -- Read more + if tcp_conn then + tcp_conn:add_read(avast_scan_done_cb, CRLF) + end + end + end + end + end + + -- Send the real request + maybe_retransmit() + end + + if common.condition_check_and_continue(task, content, rule, digest, + avast_check_uncached, maybe_part) then + return + else + avast_check_uncached() + end + +end + +return { + type = 'antivirus', + description = 'Avast antivirus', + configure = avast_config, + check = avast_check, + name = N +} diff --git a/lualib/lua_scanners/clamav.lua b/lualib/lua_scanners/clamav.lua new file mode 100644 index 0000000..fc99ab0 --- /dev/null +++ b/lualib/lua_scanners/clamav.lua @@ -0,0 +1,193 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +--[[[ +-- @module clamav +-- This module contains clamav access functions +--]] + +local lua_util = require "lua_util" +local tcp = require "rspamd_tcp" +local upstream_list = require "rspamd_upstream_list" +local rspamd_util = require "rspamd_util" +local rspamd_logger = require "rspamd_logger" +local common = require "lua_scanners/common" + +local N = "clamav" + +local default_message = '${SCANNER}: virus found: "${VIRUS}"' + +local function clamav_config(opts) + local clamav_conf = { + name = N, + scan_mime_parts = true, + scan_text_mime = false, + scan_image_mime = false, + default_port = 3310, + log_clean = false, + timeout = 5.0, -- FIXME: this will break task_timeout! + detection_category = "virus", + retransmits = 2, + cache_expire = 3600, -- expire redis in one hour + message = default_message, + } + + clamav_conf = lua_util.override_defaults(clamav_conf, opts) + + if not clamav_conf.prefix then + clamav_conf.prefix = 'rs_' .. clamav_conf.name .. '_' + end + + if not clamav_conf.log_prefix then + if clamav_conf.name:lower() == clamav_conf.type:lower() then + clamav_conf.log_prefix = clamav_conf.name + else + clamav_conf.log_prefix = clamav_conf.name .. ' (' .. clamav_conf.type .. ')' + end + end + + if not clamav_conf['servers'] then + rspamd_logger.errx(rspamd_config, 'no servers defined') + + return nil + end + + clamav_conf['upstreams'] = upstream_list.create(rspamd_config, + clamav_conf['servers'], + clamav_conf.default_port) + + if clamav_conf['upstreams'] then + lua_util.add_debug_alias('antivirus', clamav_conf.name) + return clamav_conf + end + + rspamd_logger.errx(rspamd_config, 'cannot parse servers %s', + clamav_conf['servers']) + return nil +end + +local function clamav_check(task, content, digest, rule, maybe_part) + local function clamav_check_uncached () + local upstream = rule.upstreams:get_upstream_round_robin() + local addr = upstream:get_addr() + local retransmits = rule.retransmits + local header = rspamd_util.pack("c9 c1 >I4", "zINSTREAM", "\0", + #content) + local footer = rspamd_util.pack(">I4", 0) + + local function clamav_callback(err, data) + if err then + + -- retry with another upstream until retransmits exceeds + if retransmits > 0 then + + retransmits = retransmits - 1 + + -- Select a different upstream! 
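+ -- (with several servers configured, the round-robin pick may hand the retry to a different clamd instance)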
+ upstream = rule.upstreams:get_upstream_round_robin() + addr = upstream:get_addr() + + lua_util.debugm(rule.name, task, '%s: error: %s; retry IP: %s; retries left: %s', + rule.log_prefix, err, addr, retransmits) + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule['timeout'], + callback = clamav_callback, + data = { header, content, footer }, + stop_pattern = '\0' + }) + else + rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits exceed', rule.log_prefix) + common.yield_result(task, rule, + 'failed to scan and retransmits exceed', 0.0, 'fail', + maybe_part) + end + + else + data = tostring(data) + local cached + lua_util.debugm(rule.name, task, '%s: got reply: %s', + rule.log_prefix, data) + if data == 'stream: OK' then + cached = 'OK' + if rule['log_clean'] then + rspamd_logger.infox(task, '%s: message or mime_part is clean', + rule.log_prefix) + else + lua_util.debugm(rule.name, task, '%s: message or mime_part is clean', rule.log_prefix) + end + else + local vname = string.match(data, 'stream: (.+) FOUND') + if string.find(vname, '^Heuristics%.Encrypted') then + rspamd_logger.errx(task, '%s: File is encrypted', rule.log_prefix) + common.yield_result(task, rule, 'File is encrypted: ' .. vname, + 0.0, 'encrypted', maybe_part) + cached = 'ENCRYPTED' + elseif string.find(vname, '^Heuristics%.OLE2%.ContainsMacros') then + rspamd_logger.errx(task, '%s: ClamAV Found an OLE2 Office Macro', rule.log_prefix) + common.yield_result(task, rule, vname, 0.0, 'macro', maybe_part) + cached = 'MACRO' + elseif string.find(vname, '^Heuristics%.Limits%.Exceeded') then + rspamd_logger.errx(task, '%s: ClamAV Limits Exceeded', rule.log_prefix) + common.yield_result(task, rule, 'Limits Exceeded: ' .. vname, 0.0, + 'fail', maybe_part) + elseif vname then + common.yield_result(task, rule, vname, 1.0, nil, maybe_part) + cached = vname + else + rspamd_logger.errx(task, '%s: unhandled response: %s', rule.log_prefix, data) + common.yield_result(task, rule, 'unhandled response:' .. vname, 0.0, + 'fail', maybe_part) + end + end + if cached then + common.save_cache(task, digest, rule, cached, 1.0, maybe_part) + end + end + end + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + timeout = rule['timeout'], + callback = clamav_callback, + upstream = upstream, + data = { header, content, footer }, + stop_pattern = '\0' + }) + end + + if common.condition_check_and_continue(task, content, rule, digest, + clamav_check_uncached, maybe_part) then + return + else + clamav_check_uncached() + end + +end + +return { + type = 'antivirus', + description = 'clamav antivirus', + configure = clamav_config, + check = clamav_check, + name = N +} diff --git a/lualib/lua_scanners/cloudmark.lua b/lualib/lua_scanners/cloudmark.lua new file mode 100644 index 0000000..b07f238 --- /dev/null +++ b/lualib/lua_scanners/cloudmark.lua @@ -0,0 +1,372 @@ +--[[ +Copyright (c) 2021, Alexander Moisseev <moiseev@mezonplus.ru> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module cloudmark +-- This module contains Cloudmark v2 interface +--]] + +local lua_util = require "lua_util" +local http = require "rspamd_http" +local upstream_list = require "rspamd_upstream_list" +local rspamd_logger = require "rspamd_logger" +local ucl = require "ucl" +local rspamd_util = require "rspamd_util" +local common = require "lua_scanners/common" +local fun = require "fun" +local lua_mime = require "lua_mime" + +local N = 'cloudmark' +-- Boundary for multipart transfers, generated on module init +local static_boundary = rspamd_util.random_hex(32) + +local function cloudmark_url(rule, addr, maybe_url) + local url + local port = addr:get_port() + + maybe_url = maybe_url or rule.url + if port == 0 then + port = rule.default_port + end + if rule.use_https then + url = string.format('https://%s:%d%s', tostring(addr), + port, maybe_url) + else + url = string.format('http://%s:%d%s', tostring(addr), + port, maybe_url) + end + + return url +end + +-- Detect cloudmark max size +local function cloudmark_preload(rule, cfg, ev_base, _) + local upstream = rule.upstreams:get_upstream_round_robin() + local addr = upstream:get_addr() + local function max_message_size_cb(http_err, code, body, _) + if http_err then + rspamd_logger.errx(ev_base, 'HTTP error when getting max message size: %s', + http_err) + return + end + if code ~= 200 then + rspamd_logger.errx(ev_base, 'bad HTTP code when getting max message size: %s', code) + end + local parser = ucl.parser() + local ret, err = parser:parse_string(body) + if not ret then + rspamd_logger.errx(ev_base, 'could not parse response body [%s]: %s', body, err) + return + end + local obj = parser:get_object() + local ms = obj.maxMessageSize + if not ms then + rspamd_logger.errx(ev_base, 'missing maxMessageSize in the response body (JSON): %s', obj) + return + end + + rule.max_size = ms + lua_util.debugm(N, cfg, 'set maximum message size set to %s bytes', ms) + end + http.request({ + ev_base = ev_base, + config = cfg, + url = cloudmark_url(rule, addr, '/score/v2/max-message-size'), + callback = max_message_size_cb, + }) +end + +local function cloudmark_config(opts) + + local cloudmark_conf = { + name = N, + default_port = 2713, + url = '/score/v2/message', + use_https = false, + timeout = 5.0, + log_clean = false, + retransmits = 1, + score_threshold = 90, -- minimum score to considerate reply + message = '${SCANNER}: spam message found: "${VIRUS}"', + max_message = 0, + detection_category = "hash", + default_score = 1, + action = false, + log_spamcause = true, + symbol_fail = 'CLOUDMARK_FAIL', + symbol = 'CLOUDMARK_CHECK', + symbol_spam = 'CLOUDMARK_SPAM', + add_headers = false, -- allow addition of the headers from Cloudmark + } + + cloudmark_conf = lua_util.override_defaults(cloudmark_conf, opts) + + if not cloudmark_conf.prefix then + cloudmark_conf.prefix = 'rs_' .. cloudmark_conf.name .. '_' + end + + if not cloudmark_conf.log_prefix then + if cloudmark_conf.name:lower() == cloudmark_conf.type:lower() then + cloudmark_conf.log_prefix = cloudmark_conf.name + else + cloudmark_conf.log_prefix = cloudmark_conf.name .. ' (' .. cloudmark_conf.type .. 
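--[[
cloudmark_preload() above asks the engine for its message size limit at
startup; only the maxMessageSize field of the JSON reply is used. A
hedged sketch of a reply it would accept (the payload shape beyond that
field is an assumption, not taken from Cloudmark documentation):

  GET /score/v2/max-message-size
  { "maxMessageSize": 104857600 }

The value ends up in rule.max_size, which gates scanning before any
content is submitted.
]]--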
')' + end + end + + if not cloudmark_conf.servers and cloudmark_conf.socket then + cloudmark_conf.servers = cloudmark_conf.socket + end + + if not cloudmark_conf.servers then + rspamd_logger.errx(rspamd_config, 'no servers defined') + + return nil + end + + cloudmark_conf.upstreams = upstream_list.create(rspamd_config, + cloudmark_conf.servers, + cloudmark_conf.default_port) + + if cloudmark_conf.upstreams then + + cloudmark_conf.symbols = { { symbol = cloudmark_conf.symbol_spam, score = 5.0 } } + cloudmark_conf.preloads = { cloudmark_preload } + lua_util.add_debug_alias('external_services', cloudmark_conf.name) + return cloudmark_conf + end + + rspamd_logger.errx(rspamd_config, 'cannot parse servers %s', + cloudmark_conf['servers']) + return nil +end + +-- Converts a key-value map to the table representing multipart body, with the following values: +-- `data`: data of the part +-- `filename`: optional filename +-- `content-type`: content type of the element (optional) +-- `content-transfer-encoding`: optional CTE header +local function table_to_multipart_body(tbl, boundary) + local seen_data = false + local out = {} + + for k, v in pairs(tbl) do + if v.data then + seen_data = true + table.insert(out, string.format('--%s\r\n', boundary)) + if v.filename then + table.insert(out, + string.format('Content-Disposition: form-data; name="%s"; filename="%s"\r\n', + k, v.filename)) + else + table.insert(out, + string.format('Content-Disposition: form-data; name="%s"\r\n', k)) + end + if v['content-type'] then + table.insert(out, + string.format('Content-Type: %s\r\n', v['content-type'])) + else + table.insert(out, 'Content-Type: text/plain\r\n') + end + if v['content-transfer-encoding'] then + table.insert(out, + string.format('Content-Transfer-Encoding: %s\r\n', + v['content-transfer-encoding'])) + else + table.insert(out, 'Content-Transfer-Encoding: binary\r\n') + end + table.insert(out, '\r\n') + table.insert(out, v.data) + table.insert(out, '\r\n') + end + end + + if seen_data then + table.insert(out, string.format('--%s--\r\n', boundary)) + end + + return out +end + +local function parse_cloudmark_reply(task, rule, body) + local parser = ucl.parser() + local ret, err = parser:parse_string(body) + if not ret then + rspamd_logger.errx(task, '%s: bad response body (raw): %s', N, body) + task:insert_result(rule.symbol_fail, 1.0, 'Parser error: ' .. 
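--[[
parse_cloudmark_reply() expects a JSON verdict carrying at least a
numeric score; analysis and appendHeaders are optional. A hedged sketch
of a reply that would trigger the spam symbol with the default
score_threshold of 90 (field values are illustrative):

  {
    "score": 97,
    "analysis": "spam",
    "appendHeaders": [ { "headerField": "X-CMAE-Analysis", "body": "v=spam" } ]
  }

score >= score_threshold inserts CLOUDMARK_SPAM with the score as its
option; appendHeaders are only folded into the message when add_headers
is enabled in the rule.
]]--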
err) + return + end + local obj = parser:get_object() + lua_util.debugm(N, task, 'cloudmark reply is: %s', obj) + + if not obj.score then + rspamd_logger.errx(task, '%s: bad response body (raw): %s', N, body) + task:insert_result(rule.symbol_fail, 1.0, 'Parser error: no score') + return + end + + if obj.analysis then + -- Report analysis string + rspamd_logger.infox(task, 'cloudmark report string: %s', obj.analysis) + end + + local score = tonumber(obj.score) or 0 + if score >= rule.score_threshold then + task:insert_result(rule.symbol_spam, 1.0, tostring(score)) + end + + if rule.add_headers and type(obj.appendHeaders) == 'table' then + local headers_add = fun.tomap(fun.map(function(h) + return h.headerField, { + order = 1, value = h.body + } + end, obj.appendHeaders)) + lua_mime.modify_headers(task, { + add = headers_add + }) + end + +end + +local function cloudmark_check(task, content, digest, rule, maybe_part) + local function cloudmark_check_uncached() + local upstream = rule.upstreams:get_upstream_round_robin() + local addr = upstream:get_addr() + local retransmits = rule.retransmits + + local url = cloudmark_url(rule, addr) + local message_data = task:get_content() + if rule.max_message and rule.max_message > 0 and #message_data > rule.max_message then + task:insert_result(rule['symbol_fail'], 0.0, 'Message too large: ' .. #message_data) + return + end + local request = { + rfc822 = { + ['Content-Type'] = 'message/rfc822', + data = message_data, + } + } + + local helo = task:get_helo() + if helo then + request['heloDomain'] = { + data = helo, + } + end + local mail_from = task:get_from('smtp') or {} + if mail_from[1] and #mail_from[1].addr > 1 then + request['mailFrom'] = { + data = mail_from[1].addr + } + end + + local rcpt_to = task:get_recipients('smtp') + if rcpt_to then + request['rcptTo'] = { + data = table.concat(fun.totable(fun.map(function(r) + return r.addr + end, rcpt_to)), ',') + } + end + + local fip = task:get_from_ip() + if fip and fip:is_valid() then + request['connIp'] = tostring(fip) + end + + local hostname = task:get_hostname() + if hostname then + request['fromHost'] = hostname + end + + local request_data = { + task = task, + url = url, + body = table_to_multipart_body(request, static_boundary), + headers = { + ['Content-Type'] = string.format('multipart/form-data; boundary="%s"', static_boundary) + }, + timeout = rule.timeout, + } + + local function cloudmark_callback(http_err, code, body, headers) + + local function cloudmark_requery() + -- set current upstream to fail because an error occurred + upstream:fail() + + -- retry with another upstream until retransmits exceeds + if retransmits > 0 then + + retransmits = retransmits - 1 + + lua_util.debugm(rule.name, task, + '%s: request Error: %s - retries left: %s', + rule.log_prefix, http_err, retransmits) + + -- Select a different upstream! + upstream = rule.upstreams:get_upstream_round_robin() + addr = upstream:get_addr() + url = cloudmark_url(rule, addr) + + lua_util.debugm(rule.name, task, '%s: retry IP: %s:%s', + rule.log_prefix, addr, addr:get_port()) + request_data.url = url + + http.request(request_data) + else + rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits ' .. + 'exceed', rule.log_prefix) + task:insert_result(rule['symbol_fail'], 0.0, 'failed to scan and ' .. 
+ 'retransmits exceed') + upstream:fail() + end + end + + if http_err then + cloudmark_requery() + else + -- Parse the response + if upstream then + upstream:ok() + end + if code ~= 200 then + rspamd_logger.errx(task, 'invalid HTTP code: %s, body: %s, headers: %s', code, body, headers) + task:insert_result(rule.symbol_fail, 1.0, 'Bad HTTP code: ' .. code) + return + end + parse_cloudmark_reply(task, rule, body) + end + end + + request_data.callback = cloudmark_callback + http.request(request_data) + end + + if common.condition_check_and_continue(task, content, rule, digest, + cloudmark_check_uncached, maybe_part) then + return + else + cloudmark_check_uncached() + end +end + +return { + type = { 'cloudmark', 'scanner' }, + description = 'Cloudmark cartridge interface', + configure = cloudmark_config, + check = cloudmark_check, + name = N, +} diff --git a/lualib/lua_scanners/common.lua b/lualib/lua_scanners/common.lua new file mode 100644 index 0000000..11f5e1f --- /dev/null +++ b/lualib/lua_scanners/common.lua @@ -0,0 +1,539 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2019, Carsten Rosenberg <c.rosenberg@heinlein-support.de> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_scanners_common +-- This module contains common external scanners functions +--]] + +local rspamd_logger = require "rspamd_logger" +local rspamd_regexp = require "rspamd_regexp" +local lua_util = require "lua_util" +local lua_redis = require "lua_redis" +local lua_magic_types = require "lua_magic/types" +local fun = require "fun" + +local exports = {} + +local function log_clean(task, rule, msg) + + msg = msg or 'message or mime_part is clean' + + if rule.log_clean then + rspamd_logger.infox(task, '%s: %s', rule.log_prefix, msg) + else + lua_util.debugm(rule.name, task, '%s: %s', rule.log_prefix, msg) + end + +end + +local function match_patterns(default_sym, found, patterns, dyn_weight) + if type(patterns) ~= 'table' then + return default_sym, dyn_weight + end + if not patterns[1] then + for sym, pat in pairs(patterns) do + if pat:match(found) then + return sym, '1' + end + end + return default_sym, dyn_weight + else + for _, p in ipairs(patterns) do + for sym, pat in pairs(p) do + if pat:match(found) then + return sym, '1' + end + end + end + return default_sym, dyn_weight + end +end + +local function yield_result(task, rule, vname, dyn_weight, is_fail, maybe_part) + local all_whitelisted = true + local patterns + local symbol + local threat_table + local threat_info + local flags + + if type(vname) == 'string' then + threat_table = { vname } + elseif type(vname) == 'table' then + threat_table = vname + end + + + -- This should be more generic + if not is_fail then + patterns = rule.patterns + symbol = rule.symbol + threat_info = rule.detection_category .. 
'found' + if not dyn_weight then + dyn_weight = 1.0 + end + elseif is_fail == 'fail' then + patterns = rule.patterns_fail + symbol = rule.symbol_fail + threat_info = "FAILED with error" + dyn_weight = 0.0 + elseif is_fail == 'encrypted' then + patterns = rule.patterns + symbol = rule.symbol_encrypted + threat_info = "Scan has returned that input was encrypted" + dyn_weight = 1.0 + elseif is_fail == 'macro' then + patterns = rule.patterns + symbol = rule.symbol_macro + threat_info = "Scan has returned that input contains macros" + dyn_weight = 1.0 + end + + for _, tm in ipairs(threat_table) do + local symname, symscore = match_patterns(symbol, tm, patterns, dyn_weight) + if rule.whitelist and rule.whitelist:get_key(tm) then + rspamd_logger.infox(task, '%s: "%s" is in whitelist', rule.log_prefix, tm) + else + all_whitelisted = false + rspamd_logger.infox(task, '%s: result - %s: "%s - score: %s"', + rule.log_prefix, threat_info, tm, symscore) + + if maybe_part and rule.show_attachments and maybe_part:get_filename() then + local fname = maybe_part:get_filename() + task:insert_result(symname, symscore, string.format("%s|%s", + tm, fname)) + else + task:insert_result(symname, symscore, tm) + end + + end + end + + if rule.action and is_fail ~= 'fail' and not all_whitelisted then + threat_table = table.concat(threat_table, '; ') + if rule.action ~= 'reject' then + flags = 'least' + end + task:set_pre_result(rule.action, + lua_util.template(rule.message or 'Rejected', { + SCANNER = rule.name, + VIRUS = threat_table, + }), rule.name, nil, nil, flags) + end +end + +local function message_not_too_large(task, content, rule) + local max_size = tonumber(rule.max_size) + if not max_size then + return true + end + if #content > max_size then + rspamd_logger.infox(task, "skip %s check as it is too large: %s (%s is allowed)", + rule.log_prefix, #content, max_size) + return false + end + return true +end + +local function message_not_too_small(task, content, rule) + local min_size = tonumber(rule.min_size) + if not min_size then + return true + end + if #content < min_size then + rspamd_logger.infox(task, "skip %s check as it is too small: %s (%s is allowed)", + rule.log_prefix, #content, min_size) + return false + end + return true +end + +local function message_min_words(task, rule) + if rule.text_part_min_words and tonumber(rule.text_part_min_words) > 0 then + local text_part_above_limit = false + local text_parts = task:get_text_parts() + + local filter_func = function(p) + return p:get_words_count() >= tonumber(rule.text_part_min_words) + end + + fun.each(function(p) + text_part_above_limit = true + end, fun.filter(filter_func, text_parts)) + + if not text_part_above_limit then + rspamd_logger.infox(task, '%s: #words in all text parts is below text_part_min_words limit: %s', + rule.log_prefix, rule.text_part_min_words) + end + + return text_part_above_limit + else + return true + end +end + +local function dynamic_scan(task, rule) + if rule.dynamic_scan then + if rule.action ~= 'reject' then + local metric_result = task:get_metric_score() + local metric_action = task:get_metric_action() + local has_pre_result = task:has_pre_result() + -- ToDo: needed? 
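--[[
A worked example of the early-exit logic below, assuming the stock
metric with a reject threshold of 15:

  score 31, no pre-result                -> abort (31 > 2 * 15)
  score 10, pre-result already 'reject'  -> abort (verdict is final)
  score 10, no pre-result                -> continue, 'undecided'

check_metric_results() further down in this file applies the same
score > 2 * reject_level test.
]]--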
+ -- Sometimes leads to FPs + --if rule.symbol_type == 'postfilter' and metric_action == 'reject' then + -- rspamd_logger.infox(task, '%s: aborting: %s', rule.log_prefix, "result is already reject") + -- return false + --elseif metric_result[1] > metric_result[2]*2 then + if metric_result[1] > metric_result[2] * 2 then + rspamd_logger.infox(task, '%s: aborting: %s', rule.log_prefix, 'score > 2 * reject_level: ' .. metric_result[1]) + return false + elseif has_pre_result and metric_action == 'reject' then + rspamd_logger.infox(task, '%s: aborting: %s', rule.log_prefix, 'pre_result reject is set') + return false + else + return true, 'undecided' + end + else + return true, 'dynamic_scan is not possible with config `action=reject;`' + end + else + return true + end +end + +local function need_check(task, content, rule, digest, fn, maybe_part) + + local uncached = true + local key = digest + + local function redis_av_cb(err, data) + if data and type(data) == 'string' then + -- Cached + data = lua_util.str_split(data, '\t') + local threat_string = lua_util.str_split(data[1], '\v') + local score = data[2] or rule.default_score + + if threat_string[1] ~= 'OK' then + if threat_string[1] == 'MACRO' then + yield_result(task, rule, 'File contains macros', + 0.0, 'macro', maybe_part) + elseif threat_string[1] == 'ENCRYPTED' then + yield_result(task, rule, 'File is encrypted', + 0.0, 'encrypted', maybe_part) + else + lua_util.debugm(rule.name, task, '%s: got cached threat result for %s: %s - score: %s', + rule.log_prefix, key, threat_string[1], score) + yield_result(task, rule, threat_string, score, false, maybe_part) + end + + else + lua_util.debugm(rule.name, task, '%s: got cached negative result for %s: %s', + rule.log_prefix, key, threat_string[1]) + end + uncached = false + else + if err then + rspamd_logger.errx(task, 'got error checking cache: %s', err) + end + end + + local f_message_not_too_large = message_not_too_large(task, content, rule) + local f_message_not_too_small = message_not_too_small(task, content, rule) + local f_message_min_words = message_min_words(task, rule) + local f_dynamic_scan = dynamic_scan(task, rule) + + if uncached and + f_message_not_too_large and + f_message_not_too_small and + f_message_min_words and + f_dynamic_scan then + + fn() + + end + + end + + if rule.redis_params and not rule.no_cache then + + key = rule.prefix .. 
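--[[
Cache layout shared by need_check() and save_cache(): the redis key is
the rule prefix (for instance 'rs_clamav_') followed by the content
digest, and the value is tab-separated -- threat names joined with
'\v', then the dynamic weight, then optionally the attachment filename.
Clean results are stored as the literal 'OK', encrypted and macro
verdicts as 'ENCRYPTED' and 'MACRO'. An illustrative entry (digest
shortened, virus name illustrative):

  SETEX rs_clamav_<digest> 3600 "Eicar-Test-Signature\t1.0"

Entries expire after rule.cache_expire seconds; 3600 matches the
one-hour default used by the antivirus rules above.
]]--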
key + + if lua_redis.redis_make_request(task, + rule.redis_params, -- connect params + key, -- hash key + false, -- is write + redis_av_cb, --callback + 'GET', -- command + { key } -- arguments) + ) then + return true + end + end + + return false + +end + +local function save_cache(task, digest, rule, to_save, dyn_weight, maybe_part) + local key = digest + if not dyn_weight then + dyn_weight = 1.0 + end + + local function redis_set_cb(err) + -- Do nothing + if err then + rspamd_logger.errx(task, 'failed to save %s cache for %s -> "%s": %s', + rule.detection_category, to_save, key, err) + else + lua_util.debugm(rule.name, task, '%s: saved cached result for %s: %s - score %s - ttl %s', + rule.log_prefix, key, to_save, dyn_weight, rule.cache_expire) + end + end + + if type(to_save) == 'table' then + to_save = table.concat(to_save, '\v') + end + + local value_tbl = { to_save, dyn_weight } + if maybe_part and rule.show_attachments and maybe_part:get_filename() then + local fname = maybe_part:get_filename() + table.insert(value_tbl, fname) + end + local value = table.concat(value_tbl, '\t') + + if rule.redis_params and rule.prefix then + key = rule.prefix .. key + + lua_redis.redis_make_request(task, + rule.redis_params, -- connect params + key, -- hash key + true, -- is write + redis_set_cb, --callback + 'SETEX', -- command + { key, rule.cache_expire or 0, value } + ) + end + + return false +end + +local function create_regex_table(patterns) + local regex_table = {} + if patterns[1] then + for i, p in ipairs(patterns) do + if type(p) == 'table' then + local new_set = {} + for k, v in pairs(p) do + new_set[k] = rspamd_regexp.create_cached(v) + end + regex_table[i] = new_set + else + regex_table[i] = {} + end + end + else + for k, v in pairs(patterns) do + regex_table[k] = rspamd_regexp.create_cached(v) + end + end + return regex_table +end + +local function match_filter(task, rule, found, patterns, pat_type) + if type(patterns) ~= 'table' or not found then + return false + end + if not patterns[1] then + for _, pat in pairs(patterns) do + if pat_type == 'ext' and tostring(pat) == tostring(found) then + return true + elseif pat_type == 'regex' and pat:match(found) then + return true + end + end + return false + else + for _, p in ipairs(patterns) do + for _, pat in ipairs(p) do + if pat_type == 'ext' and tostring(pat) == tostring(found) then + return true + elseif pat_type == 'regex' and pat:match(found) then + return true + end + end + end + return false + end +end + +-- borrowed from mime_types.lua +-- ext is the last extension, LOWERCASED +-- ext2 is the one before last extension LOWERCASED +local function gen_extension(fname) + local filename_parts = lua_util.str_split(fname, '.') + + local ext = {} + for n = 1, 2 do + ext[n] = #filename_parts > n and string.lower(filename_parts[#filename_parts + 1 - n]) or nil + end + return ext[1], ext[2], filename_parts +end + +local function check_parts_match(task, rule) + + local filter_func = function(p) + local mtype, msubtype = p:get_type() + local detected_ext = p:get_detected_ext() + local fname = p:get_filename() + local ext, ext2 + + if rule.scan_all_mime_parts == false then + -- check file extension and filename regex matching + --lua_util.debugm(rule.name, task, '%s: filename: |%s|%s|', rule.log_prefix, fname) + if fname ~= nil then + ext, ext2 = gen_extension(fname) + --lua_util.debugm(rule.name, task, '%s: extension, fname: |%s|%s|%s|', rule.log_prefix, ext, ext2, fname) + if match_filter(task, rule, ext, rule.mime_parts_filter_ext, 'ext') + 
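-- For reference, gen_extension() above returns the last two extensions
-- lowercased (plus the raw split), so both ext and ext2 are matched here
-- to catch double-extension tricks:
--   gen_extension('Invoice.PDF.exe') --> 'exe', 'pdf', {'Invoice','PDF','exe'}
--   gen_extension('archive.tar.gz')  --> 'gz', 'tar', {'archive','tar','gz'}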
or match_filter(task, rule, ext2, rule.mime_parts_filter_ext, 'ext') then + lua_util.debugm(rule.name, task, '%s: extension matched: |%s|%s|', rule.log_prefix, ext, ext2) + return true + elseif match_filter(task, rule, fname, rule.mime_parts_filter_regex, 'regex') then + lua_util.debugm(rule.name, task, '%s: filename regex matched', rule.log_prefix) + return true + end + end + -- check content type string regex matching + if mtype ~= nil and msubtype ~= nil then + local ct = string.format('%s/%s', mtype, msubtype):lower() + if match_filter(task, rule, ct, rule.mime_parts_filter_regex, 'regex') then + lua_util.debugm(rule.name, task, '%s: regex content-type: %s', rule.log_prefix, ct) + return true + end + end + -- check detected content type (libmagic) regex matching + if detected_ext then + local magic = lua_magic_types[detected_ext] or {} + if match_filter(task, rule, detected_ext, rule.mime_parts_filter_ext, 'ext') then + lua_util.debugm(rule.name, task, '%s: detected extension matched: |%s|', rule.log_prefix, detected_ext) + return true + elseif magic.ct and match_filter(task, rule, magic.ct, rule.mime_parts_filter_regex, 'regex') then + lua_util.debugm(rule.name, task, '%s: regex detected libmagic content-type: %s', + rule.log_prefix, magic.ct) + return true + end + end + -- check filenames in archives + if p:is_archive() then + local arch = p:get_archive() + local filelist = arch:get_files_full(1000) + for _, f in ipairs(filelist) do + ext, ext2 = gen_extension(f.name) + if match_filter(task, rule, ext, rule.mime_parts_filter_ext, 'ext') + or match_filter(task, rule, ext2, rule.mime_parts_filter_ext, 'ext') then + lua_util.debugm(rule.name, task, '%s: extension matched in archive: |%s|%s|', rule.log_prefix, ext, ext2) + --lua_util.debugm(rule.name, task, '%s: extension matched in archive: %s', rule.log_prefix, ext) + return true + elseif match_filter(task, rule, f.name, rule.mime_parts_filter_regex, 'regex') then + lua_util.debugm(rule.name, task, '%s: filename regex matched in archive', rule.log_prefix) + return true + end + end + end + end + + -- check text_part has more words than text_part_min_words_check + if rule.scan_text_mime and rule.text_part_min_words and p:is_text() and + p:get_words_count() >= tonumber(rule.text_part_min_words) then + return true + end + + if rule.scan_image_mime and p:is_image() then + return true + end + + if rule.scan_all_mime_parts ~= false then + local is_part_checkable = (p:is_attachment() and (not p:is_image() or rule.scan_image_mime)) + if detected_ext then + -- We know what to scan! + local magic = lua_magic_types[detected_ext] or {} + + if magic.av_check ~= false or is_part_checkable then + return true + end + elseif is_part_checkable then + -- Just rely on attachment property + return true + end + end + + return false + end + + return fun.filter(filter_func, task:get_parts()) +end + +local function check_metric_results(task, rule) + + if rule.action ~= 'reject' then + local metric_result = task:get_metric_score() + local metric_action = task:get_metric_action() + local has_pre_result = task:has_pre_result() + + if rule.symbol_type == 'postfilter' and metric_action == 'reject' then + return true, 'result is already reject' + elseif metric_result[1] > metric_result[2] * 2 then + return true, 'score > 2 * reject_level: ' .. 
metric_result[1] + elseif has_pre_result and metric_action == 'reject' then + return true, 'pre_result reject is set' + else + return false, 'undecided' + end + else + return false, 'dynamic_scan is not possible with config `action=reject;`' + end +end + +exports.log_clean = log_clean +exports.yield_result = yield_result +exports.match_patterns = match_patterns +exports.condition_check_and_continue = need_check +exports.save_cache = save_cache +exports.create_regex_table = create_regex_table +exports.check_parts_match = check_parts_match +exports.check_metric_results = check_metric_results + +setmetatable(exports, { + __call = function(t, override) + for k, v in pairs(t) do + if _G[k] ~= nil then + local msg = 'function ' .. k .. ' already exists in global scope.' + if override then + _G[k] = v + print('WARNING: ' .. msg .. ' Overwritten.') + else + print('NOTICE: ' .. msg .. ' Skipped.') + end + else + _G[k] = v + end + end + end, +}) + +return exports diff --git a/lualib/lua_scanners/dcc.lua b/lualib/lua_scanners/dcc.lua new file mode 100644 index 0000000..8d5e9e1 --- /dev/null +++ b/lualib/lua_scanners/dcc.lua @@ -0,0 +1,313 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2018, Carsten Rosenberg <c.rosenberg@heinlein-support.de> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module dcc +-- This module contains dcc access functions +--]] + +local lua_util = require "lua_util" +local tcp = require "rspamd_tcp" +local upstream_list = require "rspamd_upstream_list" +local rspamd_logger = require "rspamd_logger" +local common = require "lua_scanners/common" +local fun = require "fun" + +local N = 'dcc' + +local function dcc_config(opts) + + local dcc_conf = { + name = N, + default_port = 10045, + timeout = 5.0, + log_clean = false, + retransmits = 2, + cache_expire = 7200, -- expire redis in 2h + message = '${SCANNER}: bulk message found: "${VIRUS}"', + detection_category = "hash", + default_score = 1, + action = false, + client = '0.0.0.0', + symbol_fail = 'DCC_FAIL', + symbol = 'DCC_REJECT', + symbol_bulk = 'DCC_BULK', + body_max = 999999, + fuz1_max = 999999, + fuz2_max = 999999, + } + + dcc_conf = lua_util.override_defaults(dcc_conf, opts) + + if not dcc_conf.prefix then + dcc_conf.prefix = 'rs_' .. dcc_conf.name .. 
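--[[
The body_max/fuz1_max/fuz2_max defaults above are the per-checksum
thresholds used when parsing dccifd 'A' (accept) replies further down:
each of the body, fuz1 and fuz2 counters adds rep/3 to the DCC_BULK
score once it reaches its *_max limit, and the literal count 'many' is
treated as the limit itself. For example, with body_max = 1000000:

  body=many     -> counted towards DCC_BULK
  body=2500000  -> counted (2500000 >= 1000000)
  body=999999   -> ignored

so the shipped 999999 defaults effectively react only to counts in the
millions or to 'many'.
]]--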
'_' + end + + if not dcc_conf.log_prefix then + dcc_conf.log_prefix = dcc_conf.name + end + + if not dcc_conf.servers and dcc_conf.socket then + dcc_conf.servers = dcc_conf.socket + end + + if not dcc_conf.servers then + rspamd_logger.errx(rspamd_config, 'no servers defined') + + return nil + end + + dcc_conf.upstreams = upstream_list.create(rspamd_config, + dcc_conf.servers, + dcc_conf.default_port) + + if dcc_conf.upstreams then + lua_util.add_debug_alias('external_services', dcc_conf.name) + return dcc_conf + end + + rspamd_logger.errx(rspamd_config, 'cannot parse servers %s', + dcc_conf['servers']) + return nil +end + +local function dcc_check(task, content, digest, rule) + local function dcc_check_uncached () + local upstream = rule.upstreams:get_upstream_round_robin() + local addr = upstream:get_addr() + local retransmits = rule.retransmits + local client = rule.client + + local client_ip = task:get_from_ip() + if client_ip and client_ip:is_valid() then + client = client_ip:to_string() + end + local client_host = task:get_hostname() + if client_host then + client = client .. "\r" .. client_host + end + + -- HELO + local helo = task:get_helo() or '' + + -- Envelope From + local ef = task:get_from() + local envfrom = 'test@example.com' + if ef and ef[1] then + envfrom = ef[1]['addr'] + end + + -- Envelope To + local envrcpt = 'test@example.com' + local rcpts = task:get_recipients(); + if rcpts then + local dcc_recipients = table.concat(fun.totable(fun.map(function(rcpt) + return rcpt['addr'] + end, + rcpts)), '\n') + if dcc_recipients then + envrcpt = dcc_recipients + end + end + + -- Build the DCC query + -- https://www.dcc-servers.net/dcc/dcc-tree/dccifd.html#Protocol + local request_data = { + "header\n", + client .. "\n", + helo .. "\n", + envfrom .. "\n", + envrcpt .. "\n", + "\n", + content + } + + local function dcc_callback(err, data, conn) + + local function dcc_requery() + -- retry with another upstream until retransmits exceeds + if retransmits > 0 then + + retransmits = retransmits - 1 + + -- Select a different upstream! + upstream = rule.upstreams:get_upstream_round_robin() + addr = upstream:get_addr() + + lua_util.debugm(rule.name, task, '%s: error: %s; retry IP: %s; retries left: %s', + rule.log_prefix, err, addr, retransmits) + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + timeout = rule.timeout or 2.0, + upstream = upstream, + shutdown = true, + data = request_data, + callback = dcc_callback, + body_max = 999999, + fuz1_max = 999999, + fuz2_max = 999999, + }) + else + rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits ' .. 
+ 'exceed', rule.log_prefix) + common.yield_result(task, rule, 'failed to scan and retransmits exceed', 0.0, 'fail') + end + end + + if err then + + dcc_requery() + + else + -- Parse the response + local _, _, result, disposition, header = tostring(data):find("(.-)\n(.-)\n(.-)$") + lua_util.debugm(rule.name, task, 'DCC result=%1 disposition=%2 header="%3"', + result, disposition, header) + + if header then + -- Unfold header + header = header:gsub('\r?\n%s*', ' ') + local _, _, info = header:find("; (.-)$") + if (result == 'R') then + -- Reject + common.yield_result(task, rule, info, rule.default_score) + common.save_cache(task, digest, rule, info, rule.default_score) + elseif (result == 'T') then + -- Temporary failure + rspamd_logger.warnx(task, 'DCC returned a temporary failure result: %s', result) + dcc_requery() + elseif result == 'A' then + + local opts = {} + local score = 0.0 + info = info:lower() + local rep = info:match('rep=([^=%s]+)') + + -- Adjust reputation if available + if rep then + rep = tonumber(rep) + end + if not rep then + rep = 1.0 + end + + local function check_threshold(what, num, lim) + local rnum + if num == 'many' then + rnum = lim + else + rnum = tonumber(num) + end + + if rnum and rnum >= lim then + opts[#opts + 1] = string.format('%s=%s', what, num) + score = score + rep / 3.0 + end + end + + info = info:lower() + local body = info:match('body=([^=%s]+)') + + if body then + check_threshold('body', body, rule.body_max) + end + + local fuz1 = info:match('fuz1=([^=%s]+)') + + if fuz1 then + check_threshold('fuz1', fuz1, rule.fuz1_max) + end + + local fuz2 = info:match('fuz2=([^=%s]+)') + + if fuz2 then + check_threshold('fuz2', fuz2, rule.fuz2_max) + end + + if #opts > 0 and score > 0 then + task:insert_result(rule.symbol_bulk, + score, + opts) + common.save_cache(task, digest, rule, opts, score) + else + common.save_cache(task, digest, rule, 'OK') + if rule.log_clean then + rspamd_logger.infox(task, '%s: clean, returned result A - info: %s', + rule.log_prefix, info) + else + lua_util.debugm(rule.name, task, '%s: returned result A - info: %s', + rule.log_prefix, info) + end + end + elseif result == 'G' then + -- do nothing + common.save_cache(task, digest, rule, 'OK') + if rule.log_clean then + rspamd_logger.infox(task, '%s: clean, returned result G - info: %s', rule.log_prefix, info) + else + lua_util.debugm(rule.name, task, '%s: returned result G - info: %s', rule.log_prefix, info) + end + elseif result == 'S' then + -- do nothing + common.save_cache(task, digest, rule, 'OK') + if rule.log_clean then + rspamd_logger.infox(task, '%s: clean, returned result S - info: %s', rule.log_prefix, info) + else + lua_util.debugm(rule.name, task, '%s: returned result S - info: %s', rule.log_prefix, info) + end + else + -- Unknown result + rspamd_logger.warnx(task, '%s: result error: %1', rule.log_prefix, result); + common.yield_result(task, rule, 'error: ' .. 
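--[[
The three captures parsed above map onto dccifd's reply: an overall
result code (A accept, R reject, S accept for some recipients, G
greylist, T temporary failure), a per-recipient disposition line, and
the X-DCC-...-Metrics header. A hedged sketch of an 'A' reply whose
counters would trip the checks above (header text abbreviated, not
verbatim dccifd output):

  A
  A
  X-DCC-Metrics: mx.example.com; rep=0.9 Body=many Fuz1=many Fuz2=2

rep/body/fuz1/fuz2 are extracted from the part after '; ' of the
unfolded header and compared against the *_max limits from dcc_config().
]]--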
result, 0.0, 'fail') + end + end + end + end + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + timeout = rule.timeout or 2.0, + shutdown = true, + upstream = upstream, + data = request_data, + callback = dcc_callback, + body_max = 999999, + fuz1_max = 999999, + fuz2_max = 999999, + }) + end + + if common.condition_check_and_continue(task, content, rule, digest, dcc_check_uncached) then + return + else + dcc_check_uncached() + end + +end + +return { + type = { 'dcc', 'bulk', 'hash', 'scanner' }, + description = 'dcc bulk scanner', + configure = dcc_config, + check = dcc_check, + name = N +} diff --git a/lualib/lua_scanners/fprot.lua b/lualib/lua_scanners/fprot.lua new file mode 100644 index 0000000..5a469c3 --- /dev/null +++ b/lualib/lua_scanners/fprot.lua @@ -0,0 +1,181 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module fprot +-- This module contains fprot access functions +--]] + +local lua_util = require "lua_util" +local tcp = require "rspamd_tcp" +local upstream_list = require "rspamd_upstream_list" +local rspamd_logger = require "rspamd_logger" +local common = require "lua_scanners/common" + +local N = "fprot" + +local default_message = '${SCANNER}: virus found: "${VIRUS}"' + +local function fprot_config(opts) + local fprot_conf = { + name = N, + scan_mime_parts = true, + scan_text_mime = false, + scan_image_mime = false, + default_port = 10200, + timeout = 5.0, -- FIXME: this will break task_timeout! + log_clean = false, + detection_category = "virus", + retransmits = 2, + cache_expire = 3600, -- expire redis in one hour + message = default_message, + } + + fprot_conf = lua_util.override_defaults(fprot_conf, opts) + + if not fprot_conf.prefix then + fprot_conf.prefix = 'rs_' .. fprot_conf.name .. '_' + end + + if not fprot_conf.log_prefix then + if fprot_conf.name:lower() == fprot_conf.type:lower() then + fprot_conf.log_prefix = fprot_conf.name + else + fprot_conf.log_prefix = fprot_conf.name .. ' (' .. fprot_conf.type .. 
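--[[
fprot_check() below talks to the F-Prot scanning daemon over TCP: it
sends a single 'SCAN STREAM <id> SIZE <length>' line, the raw content
and a trailing newline, then reads one reply line starting with a
numeric code -- '0 <clean>' for clean mail, 1-3 for infected or
suspicious content (the virus name is taken from the '<...: Name>'
part), higher codes for scanner errors. A sketch of the exchange with
illustrative values:

  -> SCAN STREAM 4F2B1A SIZE 13772
  -> <13772 bytes of message content>
  <- 0 <clean>
]]--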
')' + end + end + + if not fprot_conf['servers'] then + rspamd_logger.errx(rspamd_config, 'no servers defined') + + return nil + end + + fprot_conf['upstreams'] = upstream_list.create(rspamd_config, + fprot_conf['servers'], + fprot_conf.default_port) + + if fprot_conf['upstreams'] then + lua_util.add_debug_alias('antivirus', fprot_conf.name) + return fprot_conf + end + + rspamd_logger.errx(rspamd_config, 'cannot parse servers %s', + fprot_conf['servers']) + return nil +end + +local function fprot_check(task, content, digest, rule, maybe_part) + local function fprot_check_uncached () + local upstream = rule.upstreams:get_upstream_round_robin() + local addr = upstream:get_addr() + local retransmits = rule.retransmits + local scan_id = task:get_queue_id() + if not scan_id then + scan_id = task:get_uid() + end + local header = string.format('SCAN STREAM %s SIZE %d\n', scan_id, + #content) + local footer = '\n' + + local function fprot_callback(err, data) + if err then + -- retry with another upstream until retransmits exceeds + if retransmits > 0 then + + retransmits = retransmits - 1 + + -- Select a different upstream! + upstream = rule.upstreams:get_upstream_round_robin() + addr = upstream:get_addr() + + lua_util.debugm(rule.name, task, '%s: error: %s; retry IP: %s; retries left: %s', + rule.log_prefix, err, addr, retransmits) + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule['timeout'], + callback = fprot_callback, + data = { header, content, footer }, + stop_pattern = '\n' + }) + else + rspamd_logger.errx(task, + '%s [%s]: failed to scan, maximum retransmits exceed', + rule['symbol'], rule['type']) + common.yield_result(task, rule, 'failed to scan and retransmits exceed', + 0.0, 'fail', maybe_part) + end + else + upstream:ok() + data = tostring(data) + local cached + local clean = string.match(data, '^0 <clean>') + if clean then + cached = 'OK' + if rule['log_clean'] then + rspamd_logger.infox(task, + '%s [%s]: message or mime_part is clean', + rule['symbol'], rule['type']) + end + else + -- returncodes: 1: infected, 2: suspicious, 3: both, 4-255: some error occurred + -- see http://www.f-prot.com/support/helpfiles/unix/appendix_c.html for more detail + local vname = string.match(data, '^[1-3] <[%w%s]-: (.-)>') + if not vname then + rspamd_logger.errx(task, 'Unhandled response: %s', data) + else + common.yield_result(task, rule, vname, 1.0, nil, maybe_part) + cached = vname + end + end + if cached then + common.save_cache(task, digest, rule, cached, 1.0, maybe_part) + end + end + end + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule['timeout'], + callback = fprot_callback, + data = { header, content, footer }, + stop_pattern = '\n' + }) + end + + if common.condition_check_and_continue(task, content, rule, digest, + fprot_check_uncached, maybe_part) then + return + else + fprot_check_uncached() + end + +end + +return { + type = 'antivirus', + description = 'fprot antivirus', + configure = fprot_config, + check = fprot_check, + name = N +} diff --git a/lualib/lua_scanners/icap.lua b/lualib/lua_scanners/icap.lua new file mode 100644 index 0000000..682562d --- /dev/null +++ b/lualib/lua_scanners/icap.lua @@ -0,0 +1,713 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2019, Carsten Rosenberg <c.rosenberg@heinlein-support.de> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use 
this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[ +@module icap +This module contains icap access functions. +Currently tested with + - C-ICAP Squidclamav / echo + - Checkpoint Sandblast + - F-Secure Internet Gatekeeper + - Kaspersky Web Traffic Security + - Kaspersky Scan Engine 2.0 + - McAfee Web Gateway 9/10/11 + - Sophos Savdi + - Symantec (Rspamd <3.2, >=3.2 untested) + - Trend Micro IWSVA 6.0 + - Trend Micro Web Gateway + +@TODO + - Preview / Continue + - Reqmod URL's + - Content-Type / Filename +]] -- + +--[[ +Configuration Notes: + +C-ICAP Squidclamav + scheme = "squidclamav"; + +Checkpoint Sandblast example: + scheme = "sandblast"; + +ESET Gateway Security / Antivirus for Linux example: + scheme = "scan"; + +F-Secure Internet Gatekeeper example: + scheme = "respmod"; + x_client_header = true; + x_rcpt_header = true; + x_from_header = true; + +Kaspersky Web Traffic Security example: + scheme = "av/respmod"; + x_client_header = true; + +Kaspersky Web Traffic Security (as configured in kavicapd.xml): + scheme = "resp"; + x_client_header = true; + +McAfee Web Gateway 10/11 (Headers must be activated with personal extra Rules) + scheme = "respmod"; + x_client_header = true; + +Sophos SAVDI example: + # scheme as configured in savdi.conf (name option in service section) + scheme = "respmod"; + +Symantec example: + scheme = "avscan"; + +Trend Micro IWSVA example (X-Virus-ID/X-Infection-Found headers must be activated): + scheme = "avscan"; + x_client_header = true; + +Trend Micro Web Gateway example (X-Virus-ID/X-Infection-Found headers must be activated): + scheme = "interscan"; + x_client_header = true; +]] -- + + +local lua_util = require "lua_util" +local tcp = require "rspamd_tcp" +local upstream_list = require "rspamd_upstream_list" +local rspamd_logger = require "rspamd_logger" +local common = require "lua_scanners/common" +local rspamd_util = require "rspamd_util" +local rspamd_version = rspamd_version + +local N = 'icap' + +local function icap_config(opts) + + local icap_conf = { + name = N, + scan_mime_parts = true, + scan_all_mime_parts = true, + scan_text_mime = false, + scan_image_mime = false, + scheme = "scan", + default_port = 1344, + ssl = false, + no_ssl_verify = false, + timeout = 10.0, + log_clean = false, + retransmits = 2, + cache_expire = 7200, -- expire redis in one hour + message = '${SCANNER}: threat found with icap scanner: "${VIRUS}"', + detection_category = "virus", + default_score = 1, + action = false, + dynamic_scan = false, + user_agent = "Rspamd", + x_client_header = false, + x_rcpt_header = false, + x_from_header = false, + req_headers_enabled = true, + req_fake_url = "http://127.0.0.1/mail", + http_headers_enabled = true, + use_http_result_header = true, + use_http_3xx_as_threat = false, + use_specific_content_type = false, -- Use content type from a part where possible + } + + icap_conf = lua_util.override_defaults(icap_conf, opts) + + if not icap_conf.prefix then + icap_conf.prefix = 'rs_' .. icap_conf.name .. '_' + end + + if not icap_conf.log_prefix then + icap_conf.log_prefix = icap_conf.name .. ' (' .. icap_conf.type .. 
')' + end + + if not icap_conf.log_prefix then + if icap_conf.name:lower() == icap_conf.type:lower() then + icap_conf.log_prefix = icap_conf.name + else + icap_conf.log_prefix = icap_conf.name .. ' (' .. icap_conf.type .. ')' + end + end + + if not icap_conf.servers then + rspamd_logger.errx(rspamd_config, 'no servers defined') + + return nil + end + + icap_conf.upstreams = upstream_list.create(rspamd_config, + icap_conf.servers, + icap_conf.default_port) + + if icap_conf.upstreams then + lua_util.add_debug_alias('external_services', icap_conf.name) + return icap_conf + end + + rspamd_logger.errx(rspamd_config, 'cannot parse servers %s', + icap_conf.servers) + return nil +end + +local function icap_check(task, content, digest, rule, maybe_part) + local function icap_check_uncached () + local upstream = rule.upstreams:get_upstream_round_robin() + local addr = upstream:get_addr() + local retransmits = rule.retransmits + local http_headers = {} + local req_headers = {} + local tcp_options = {} + local threat_table = {} + + -- Build extended User Agent + if rule.user_agent == "extended" then + rule.user_agent = string.format("Rspamd/%s-%s (%s/%s)", + rspamd_version('main'), + rspamd_version('id'), + rspamd_util.get_hostname(), + string.sub(task:get_uid(), 1, 6)) + end + + -- Build the icap queries + local options_request = { + string.format("OPTIONS icap://%s/%s ICAP/1.0\r\n", addr:to_string(), rule.scheme), + string.format('Host: %s\r\n', addr:to_string()), + string.format("User-Agent: %s\r\n", rule.user_agent), + "Connection: keep-alive\r\n", + "Encapsulated: null-body=0\r\n\r\n", + } + if rule.user_agent == "none" then + table.remove(options_request, 3) + end + + local respond_headers = { + -- Add main RESPMOD header before any other + string.format('RESPMOD icap://%s/%s ICAP/1.0\r\n', addr:to_string(), rule.scheme), + string.format('Host: %s\r\n', addr:to_string()), + } + + local size = tonumber(#content) + local chunked_size = string.format("%x", size) + + local function icap_callback(err, conn) + + local function icap_requery(err_m, info) + -- retry with another upstream until retransmits exceeds + if retransmits > 0 then + + retransmits = retransmits - 1 + + lua_util.debugm(rule.name, task, + '%s: %s Request Error: %s - retries left: %s', + rule.log_prefix, info, err_m, retransmits) + + -- Select a different upstream! + upstream = rule.upstreams:get_upstream_round_robin() + addr = upstream:get_addr() + + lua_util.debugm(rule.name, task, '%s: retry IP: %s:%s', + rule.log_prefix, addr, addr:get_port()) + + tcp_options.host = addr:to_string() + tcp_options.port = addr:get_port() + tcp_options.callback = icap_callback + tcp_options.data = options_request + tcp_options.upstream = upstream + + tcp.request(tcp_options) + + else + rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits ' .. + 'exceed - error: %s', rule.log_prefix, err_m or '') + common.yield_result(task, rule, string.format('failed - error: %s', err_m), + 0.0, 'fail', maybe_part) + end + end + + local function get_req_headers() + + local in_client_ip = task:get_from_ip() + local req_hlen = 2 + if maybe_part then + table.insert(req_headers, + string.format('GET http://%s/%s HTTP/1.0\r\n', in_client_ip, maybe_part:get_filename())) + if rule.use_specific_content_type then + table.insert(http_headers, string.format('Content-Type: %s/%s\r\n', maybe_part:get_detected_type())) + --else + -- To test: what content type is better for icap servers? 
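--[[
The RESPMOD request assembled by get_respond_query() below has this
overall shape; the offsets in the Encapsulated header are byte
positions of the encapsulated sections, which is why the req_hlen and
http_hlen counters start at 2 (the blank CRLF line closing each section
is counted). A sketch of the layout only, assuming the default options:

  RESPMOD icap://<server>/<scheme> ICAP/1.0
  Host: <server>
  User-Agent: Rspamd
  Encapsulated: req-hdr=0, res-hdr=<req_hlen>, res-body=<req_hlen + http_hlen>
  <blank line>
  GET http://127.0.0.1/mail HTTP/1.0        <- fake request headers
  <blank line>
  HTTP/1.0 200 OK                           <- fake response headers
  Content-Length: <size>
  <blank line>
  <size in hex> CRLF <scanned content> CRLF 0 CRLF CRLF

The exact header set depends on req_headers_enabled, http_headers_enabled
and the x_* options configured for the rule.
]]--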
+ --table.insert(http_headers, 'Content-Type: text/plain\r\n') + end + else + table.insert(req_headers, string.format('GET %s HTTP/1.0\r\n', rule.req_fake_url)) + table.insert(http_headers, string.format('Content-Type: application/octet-stream\r\n')) + end + table.insert(req_headers, string.format('Date: %s\r\n', rspamd_util.time_to_string(rspamd_util.get_time()))) + if rule.user_agent ~= "none" then + table.insert(req_headers, string.format("User-Agent: %s\r\n", rule.user_agent)) + end + + for _, h in ipairs(req_headers) do + req_hlen = req_hlen + tonumber(#h) + end + + return req_hlen, req_headers + + end + + local function get_http_headers() + local http_hlen = 2 + table.insert(http_headers, 'HTTP/1.0 200 OK\r\n') + table.insert(http_headers, string.format('Date: %s\r\n', rspamd_util.time_to_string(rspamd_util.get_time()))) + table.insert(http_headers, string.format('Server: %s\r\n', 'Apache/2.4')) + if rule.user_agent ~= "none" then + table.insert(http_headers, string.format("User-Agent: %s\r\n", rule.user_agent)) + end + --table.insert(http_headers, string.format('Content-Type: %s\r\n', 'text/html')) + table.insert(http_headers, string.format('Content-Length: %s\r\n', size)) + + for _, h in ipairs(http_headers) do + http_hlen = http_hlen + tonumber(#h) + end + + return http_hlen, http_headers + + end + + local function get_respond_query() + local req_hlen = 0 + local resp_req_headers + local http_hlen = 0 + local resp_http_headers + + -- Append all extra headers + if rule.user_agent ~= "none" then + table.insert(respond_headers, + string.format("User-Agent: %s\r\n", rule.user_agent)) + end + + if rule.req_headers_enabled then + req_hlen, resp_req_headers = get_req_headers() + end + if rule.http_headers_enabled then + http_hlen, resp_http_headers = get_http_headers() + end + + if rule.req_headers_enabled and rule.http_headers_enabled then + local res_body_hlen = req_hlen + http_hlen + table.insert(respond_headers, + string.format('Encapsulated: req-hdr=0, res-hdr=%s, res-body=%s\r\n', + req_hlen, res_body_hlen)) + elseif rule.http_headers_enabled then + table.insert(respond_headers, + string.format('Encapsulated: res-hdr=0, res-body=%s\r\n', + http_hlen)) + else + table.insert(respond_headers, 'Encapsulated: res-body=0\r\n') + end + + table.insert(respond_headers, '\r\n') + for _, h in ipairs(resp_req_headers) do + table.insert(respond_headers, h) + end + table.insert(respond_headers, '\r\n') + for _, h in ipairs(resp_http_headers) do + table.insert(respond_headers, h) + end + table.insert(respond_headers, '\r\n') + table.insert(respond_headers, chunked_size .. 
'\r\n') + table.insert(respond_headers, content) + table.insert(respond_headers, '\r\n0\r\n\r\n') + return respond_headers + end + + local function add_respond_header(name, value) + if name and value then + table.insert(respond_headers, string.format('%s: %s\r\n', name, value)) + end + end + + local function result_header_table(result) + local icap_headers = {} + for s in result:gmatch("[^\r\n]+") do + if string.find(s, '^ICAP') then + icap_headers['icap'] = tostring(s) + elseif string.find(s, '^HTTP') then + icap_headers['http'] = tostring(s) + elseif string.find(s, '[%a%d-+]-:') then + local _, _, key, value = tostring(s):find("([%a%d-+]-):%s?(.+)") + if key ~= nil then + icap_headers[key:lower()] = tostring(value) + end + end + end + lua_util.debugm(rule.name, task, '%s: icap_headers: %s', + rule.log_prefix, icap_headers) + return icap_headers + end + + local function threat_table_add(icap_threat, maybe_split) + + if maybe_split and string.find(icap_threat, ',') then + local threats = lua_util.str_split(string.gsub(icap_threat, "%s", ""), ',') or {} + + for _, v in ipairs(threats) do + table.insert(threat_table, v) + end + else + table.insert(threat_table, icap_threat) + end + return true + end + + local function icap_parse_result(headers) + + --[[ + @ToDo: handle type in response + + Generic Strings: + icap: X-Infection-Found: Type=0; Resolution=2; Threat=Troj/DocDl-OYC; + icap: X-Infection-Found: Type=0; Resolution=2; Threat=W97M.Downloader; + + Symantec String: + icap: X-Infection-Found: Type=2; Resolution=2; Threat=Container size violation + icap: X-Infection-Found: Type=2; Resolution=2; Threat=Encrypted container violation; + + Sophos Strings: + icap: X-Virus-ID: Troj/DocDl-OYC + http: X-Blocked: Virus found during virus scan + http: X-Blocked-By: Sophos Anti-Virus + + Kaspersky Web Traffic Security Strings: + icap: X-Virus-ID: HEUR:Backdoor.Java.QRat.gen + icap: X-Response-Info: blocked + icap: X-Virus-ID: no threats + icap: X-Response-Info: blocked + icap: X-Response-Info: passed + http: HTTP/1.1 403 Forbidden + + Kaspersky Scan Engine 2.0 (ICAP mode) + icap: X-Virus-ID: EICAR-Test-File + http: HTTP/1.0 403 Forbidden + + Trend Micro Strings: + icap: X-Virus-ID: Trojan.W97M.POWLOAD.SMTHF1 + icap: X-Infection-Found: Type=0; Resolution=2; Threat=Trojan.W97M.POWLOAD.SMTHF1; + http: HTTP/1.1 403 Forbidden (TMWS Blocked) + http: HTTP/1.1 403 Forbidden + + F-Secure Internet Gatekeeper Strings: + icap: X-FSecure-Scan-Result: infected + icap: X-FSecure-Infection-Name: "Malware.W97M/Agent.32584203" + icap: X-FSecure-Infected-Filename: "virus.doc" + + ESET File Security for Linux 7.0 + icap: X-Infection-Found: Type=0; Resolution=0; Threat=VBA/TrojanDownloader.Agent.JOA; + icap: X-Virus-ID: Trojaner + icap: X-Response-Info: Blocked + + McAfee Web Gateway 10/11 (Headers must be activated with personal extra Rules) + icap: X-Virus-ID: EICAR test file + icap: X-Media-Type: text/plain + icap: X-Block-Result: 80 + icap: X-Block-Reason: Malware found + icap: X-Block-Reason: Archive not supported + icap: X-Block-Reason: Media Type (Block List) + http: HTTP/1.0 403 VirusFound + + C-ICAP Squidclamav + icap/http: X-Infection-Found: Type=0; Resolution=2; Threat={HEX}EICAR.TEST.3.UNOFFICIAL; + icap/http: X-Virus-ID: {HEX}EICAR.TEST.3.UNOFFICIAL + http: HTTP/1.0 307 Temporary Redirect + ]] -- + + -- Generic ICAP Headers + if headers['x-infection-found'] then + local _, _, icap_type, _, icap_threat = headers['x-infection-found']:find("Type=(.-); Resolution=(.-); Threat=(.-);$") + + -- Type=2 is typical 
for scan error returns + if icap_type and icap_type == '2' then + lua_util.debugm(rule.name, task, + '%s: icap error X-Infection-Found: %s', rule.log_prefix, icap_threat) + common.yield_result(task, rule, icap_threat, 0, + 'fail', maybe_part) + return true + elseif icap_threat ~= nil then + lua_util.debugm(rule.name, task, + '%s: icap X-Infection-Found: %s', rule.log_prefix, icap_threat) + threat_table_add(icap_threat, false) + -- stupid workaround for unuseable x-infection-found header + -- but also x-virus-name set (McAfee Web Gateway 9) + elseif not icap_threat and headers['x-virus-name'] then + threat_table_add(headers['x-virus-name'], true) + else + threat_table_add(headers['x-infection-found'], true) + end + elseif headers['x-virus-name'] and headers['x-virus-name'] ~= "no threats" then + lua_util.debugm(rule.name, task, + '%s: icap X-Virus-Name: %s', rule.log_prefix, headers['x-virus-name']) + threat_table_add(headers['x-virus-name'], true) + elseif headers['x-virus-id'] and headers['x-virus-id'] ~= "no threats" then + lua_util.debugm(rule.name, task, + '%s: icap X-Virus-ID: %s', rule.log_prefix, headers['x-virus-id']) + threat_table_add(headers['x-virus-id'], true) + -- FSecure X-Headers + elseif headers['x-fsecure-scan-result'] and headers['x-fsecure-scan-result'] ~= "clean" then + + local infected_filename = "" + local infection_name = "-unknown-" + + if headers['x-fsecure-infected-filename'] then + infected_filename = string.gsub(headers['x-fsecure-infected-filename'], '[%s"]', '') + end + if headers['x-fsecure-infection-name'] then + infection_name = string.gsub(headers['x-fsecure-infection-name'], '[%s"]', '') + end + + lua_util.debugm(rule.name, task, + '%s: icap X-FSecure-Infection-Name (X-FSecure-Infected-Filename): %s (%s)', + rule.log_prefix, infection_name, infected_filename) + + threat_table_add(infection_name, true) + -- McAfee Web Gateway manual extra headers + elseif headers['x-mwg-block-reason'] and headers['x-mwg-block-reason'] ~= "" then + threat_table_add(headers['x-mwg-block-reason'], false) + -- Sophos SAVDI special http headers + elseif headers['x-blocked'] and headers['x-blocked'] ~= "" then + threat_table_add(headers['x-blocked'], false) + elseif headers['x-block-reason'] and headers['x-block-reason'] ~= "" then + threat_table_add(headers['x-block-reason'], false) + -- last try HTTP [4]xx return + elseif headers.http and string.find(headers.http, '^HTTP%/[12]%.. [4]%d%d') then + threat_table_add( + string.format("pseudo-virus (blocked): %s", string.gsub(headers.http, 'HTTP%/[12]%.. ', '')), false) + elseif rule.use_http_3xx_as_threat and + headers.http and + string.find(headers.http, '^HTTP%/[12]%.. [3]%d%d') + then + threat_table_add( + string.format("pseudo-virus (redirect): %s", + string.gsub(headers.http, 'HTTP%/[12]%.. ', '')), false) + end + + if #threat_table > 0 then + common.yield_result(task, rule, threat_table, rule.default_score, nil, maybe_part) + common.save_cache(task, digest, rule, threat_table, rule.default_score, maybe_part) + return true + else + return false + end + end + + local function icap_r_respond_http_cb(err_m, data, connection) + if err_m or connection == nil then + icap_requery(err_m, "icap_r_respond_http_cb") + else + local result = tostring(data) + + local icap_http_headers = result_header_table(result) or {} + -- Find HTTP/[12].x [234]xx response + if icap_http_headers.http and string.find(icap_http_headers.http, 'HTTP%/[12]%.. 
[234]%d%d') then + local icap_http_header_result = icap_parse_result(icap_http_headers) + if icap_http_header_result then + -- Threat found - close connection + connection:close() + else + common.save_cache(task, digest, rule, 'OK', 0, maybe_part) + common.log_clean(task, rule) + end + else + rspamd_logger.errx(task, '%s: unhandled response |%s|', + rule.log_prefix, string.gsub(result, "\r\n", ", ")) + common.yield_result(task, rule, string.format('unhandled icap response: %s', icap_http_headers.icap), + 0.0, 'fail', maybe_part) + end + end + end + + local function icap_r_respond_cb(err_m, data, connection) + if err_m or connection == nil then + icap_requery(err_m, "icap_r_respond_cb") + else + local result = tostring(data) + + local icap_headers = result_header_table(result) or {} + -- Find ICAP/1.x 2xx response + if icap_headers.icap and string.find(icap_headers.icap, 'ICAP%/1%.. 2%d%d') then + local icap_header_result = icap_parse_result(icap_headers) + if icap_header_result then + -- Threat found - close connection + connection:close() + elseif not icap_header_result + and rule.use_http_result_header + and icap_headers.encapsulated + and not string.find(icap_headers.encapsulated, 'null%-body=0') + then + -- Try to read encapsulated HTTP Headers + lua_util.debugm(rule.name, task, '%s: no ICAP virus header found - try HTTP headers', + rule.log_prefix) + connection:add_read(icap_r_respond_http_cb, '\r\n\r\n') + else + connection:close() + common.save_cache(task, digest, rule, 'OK', 0, maybe_part) + common.log_clean(task, rule) + end + elseif icap_headers.icap and string.find(icap_headers.icap, 'ICAP%/1%.. [45]%d%d') then + -- Find ICAP/1.x 5/4xx response + --[[ + Symantec String: + ICAP/1.0 539 Aborted - No AV scanning license + SquidClamAV/C-ICAP: + ICAP/1.0 500 Server error + Eset: + ICAP/1.0 405 Forbidden + TrendMicro: + ICAP/1.0 400 Bad request + McAfee: + ICAP/1.0 418 Bad composition + ]]-- + rspamd_logger.errx(task, '%s: ICAP ERROR: %s', rule.log_prefix, icap_headers.icap) + common.yield_result(task, rule, icap_headers.icap, 0.0, + 'fail', maybe_part) + return false + else + rspamd_logger.errx(task, '%s: unhandled response |%s|', + rule.log_prefix, string.gsub(result, "\r\n", ", ")) + common.yield_result(task, rule, string.format('unhandled icap response: %s', icap_headers.icap), + 0.0, 'fail', maybe_part) + end + end + end + + local function icap_w_respond_cb(err_m, connection) + if err_m or connection == nil then + icap_requery(err_m, "icap_w_respond_cb") + else + connection:add_read(icap_r_respond_cb, '\r\n\r\n') + end + end + + local function icap_r_options_cb(err_m, data, connection) + if err_m or connection == nil then + icap_requery(err_m, "icap_r_options_cb") + else + local icap_headers = result_header_table(tostring(data)) + + if icap_headers.icap and string.find(icap_headers.icap, 'ICAP%/1%.. 
2%d%d') then + if icap_headers['methods'] and string.find(icap_headers['methods'], 'RESPMOD') then + -- Allow "204 No Content" responses + -- https://datatracker.ietf.org/doc/html/rfc3507#section-4.6 + if icap_headers['allow'] and string.find(icap_headers['allow'], '204') then + add_respond_header('Allow', '204') + end + + if rule.x_client_header then + local client = task:get_from_ip() + if client then + add_respond_header('X-Client-IP', client:to_string()) + end + end + + -- F-Secure extra headers + if icap_headers['server'] and string.find(icap_headers['server'], 'f-secure icap server') then + + if rule.x_rcpt_header then + local rcpt_to = task:get_principal_recipient() + if rcpt_to then + add_respond_header('X-Rcpt-To', rcpt_to) + end + end + + if rule.x_from_header then + local mail_from = task:get_principal_recipient() + if mail_from and mail_from[1] then + add_respond_header('X-Rcpt-To', mail_from[1].addr) + end + end + + end + + if icap_headers.connection and icap_headers.connection:lower() == 'close' then + lua_util.debugm(rule.name, task, '%s: OPTIONS request Connection: %s - using new connection', + rule.log_prefix, icap_headers.connection) + connection:close() + tcp_options.callback = icap_w_respond_cb + tcp_options.data = get_respond_query() + tcp.request(tcp_options) + else + connection:add_write(icap_w_respond_cb, get_respond_query()) + end + + else + rspamd_logger.errx(task, '%s: RESPMOD method not advertised: Methods: %s', + rule.log_prefix, icap_headers['methods']) + common.yield_result(task, rule, 'NO RESPMOD', 0.0, + 'fail', maybe_part) + end + else + rspamd_logger.errx(task, '%s: OPTIONS query failed: %s', + rule.log_prefix, icap_headers.icap or "-") + common.yield_result(task, rule, 'OPTIONS query failed', 0.0, + 'fail', maybe_part) + end + end + end + + if err or conn == nil then + icap_requery(err, "options_request") + else + conn:add_read(icap_r_options_cb, '\r\n\r\n') + end + end + + tcp_options.task = task + tcp_options.stop_pattern = '\r\n' + tcp_options.read = false + tcp_options.timeout = rule.timeout + tcp_options.callback = icap_callback + tcp_options.data = options_request + + if rule.ssl then + tcp_options.ssl = true + if rule.no_ssl_verify then + tcp_options.no_ssl_verify = true + end + end + + tcp_options.host = addr:to_string() + tcp_options.port = addr:get_port() + tcp_options.upstream = upstream + + tcp.request(tcp_options) + end + + if common.condition_check_and_continue(task, content, rule, digest, + icap_check_uncached, maybe_part) then + return + else + icap_check_uncached() + end + +end + +return { + type = { N, 'virus', 'virus', 'scanner' }, + description = 'generic icap antivirus', + configure = icap_config, + check = icap_check, + name = N +} diff --git a/lualib/lua_scanners/init.lua b/lualib/lua_scanners/init.lua new file mode 100644 index 0000000..e47cebe --- /dev/null +++ b/lualib/lua_scanners/init.lua @@ -0,0 +1,75 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +--[[[ +-- @module lua_scanners +-- This module contains external scanners functions +--]] + +local fun = require "fun" + +local exports = { +} + +local function require_scanner(name) + local sc = require("lua_scanners/" .. name) + + exports[sc.name or name] = sc +end + +-- Antiviruses +require_scanner('clamav') +require_scanner('fprot') +require_scanner('kaspersky_av') +require_scanner('kaspersky_se') +require_scanner('savapi') +require_scanner('sophos') +require_scanner('virustotal') +require_scanner('avast') + +-- Other scanners +require_scanner('dcc') +require_scanner('oletools') +require_scanner('icap') +require_scanner('vadesecure') +require_scanner('spamassassin') +require_scanner('p0f') +require_scanner('razor') +require_scanner('pyzor') +require_scanner('cloudmark') + +exports.add_scanner = function(name, t, conf_func, check_func) + assert(type(conf_func) == 'function' and type(check_func) == 'function', + 'bad arguments') + exports[name] = { + type = t, + configure = conf_func, + check = check_func, + } +end + +exports.filter = function(t) + return fun.tomap(fun.filter(function(_, elt) + return type(elt) == 'table' and elt.type and ( + (type(elt.type) == 'string' and elt.type == t) or + (type(elt.type) == 'table' and fun.any(function(tt) + return tt == t + end, elt.type)) + ) + end, exports)) +end + +return exports diff --git a/lualib/lua_scanners/kaspersky_av.lua b/lualib/lua_scanners/kaspersky_av.lua new file mode 100644 index 0000000..d52cef0 --- /dev/null +++ b/lualib/lua_scanners/kaspersky_av.lua @@ -0,0 +1,197 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module kaspersky +-- This module contains kaspersky antivirus access functions +--]] + +local lua_util = require "lua_util" +local tcp = require "rspamd_tcp" +local upstream_list = require "rspamd_upstream_list" +local rspamd_util = require "rspamd_util" +local rspamd_logger = require "rspamd_logger" +local common = require "lua_scanners/common" + +local N = "kaspersky" + +local default_message = '${SCANNER}: virus found: "${VIRUS}"' + +local function kaspersky_config(opts) + local kaspersky_conf = { + name = N, + scan_mime_parts = true, + scan_text_mime = false, + scan_image_mime = false, + product_id = 0, + log_clean = false, + timeout = 5.0, + retransmits = 1, -- use local files, retransmits are useless + cache_expire = 3600, -- expire redis in one hour + message = default_message, + detection_category = "virus", + tmpdir = '/tmp', + } + + kaspersky_conf = lua_util.override_defaults(kaspersky_conf, opts) + + if not kaspersky_conf.prefix then + kaspersky_conf.prefix = 'rs_' .. kaspersky_conf.name .. '_' + end + + if not kaspersky_conf.log_prefix then + if kaspersky_conf.name:lower() == kaspersky_conf.type:lower() then + kaspersky_conf.log_prefix = kaspersky_conf.name + else + kaspersky_conf.log_prefix = kaspersky_conf.name .. ' (' .. kaspersky_conf.type .. 
')' + end + end + + if not kaspersky_conf['servers'] then + rspamd_logger.errx(rspamd_config, 'no servers defined') + + return nil + end + + kaspersky_conf['upstreams'] = upstream_list.create(rspamd_config, + kaspersky_conf['servers'], 0) + + if kaspersky_conf['upstreams'] then + lua_util.add_debug_alias('antivirus', kaspersky_conf.name) + return kaspersky_conf + end + + rspamd_logger.errx(rspamd_config, 'cannot parse servers %s', + kaspersky_conf['servers']) + return nil +end + +local function kaspersky_check(task, content, digest, rule, maybe_part) + local function kaspersky_check_uncached () + local upstream = rule.upstreams:get_upstream_round_robin() + local addr = upstream:get_addr() + local retransmits = rule.retransmits + local fname = string.format('%s/%s.tmp', + rule.tmpdir, rspamd_util.random_hex(32)) + local message_fd = rspamd_util.create_file(fname) + local clamav_compat_cmd = string.format("nSCAN %s\n", fname) + + if not message_fd then + rspamd_logger.errx('cannot store file for kaspersky scan: %s', fname) + return + end + + if type(content) == 'string' then + -- Create rspamd_text + local rspamd_text = require "rspamd_text" + content = rspamd_text.fromstring(content) + end + content:save_in_file(message_fd) + + -- Ensure file cleanup + task:get_mempool():add_destructor(function() + os.remove(fname) + rspamd_util.close_file(message_fd) + end) + + local function kaspersky_callback(err, data) + if err then + + -- retry with another upstream until retransmits exceeds + if retransmits > 0 then + + retransmits = retransmits - 1 + + -- Select a different upstream! + upstream = rule.upstreams:get_upstream_round_robin() + addr = upstream:get_addr() + + lua_util.debugm(rule.name, task, '%s: error: %s; retry IP: %s; retries left: %s', + rule.log_prefix, err, addr, retransmits) + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule['timeout'], + callback = kaspersky_callback, + data = { clamav_compat_cmd }, + stop_pattern = '\n' + }) + else + rspamd_logger.errx(task, + '%s [%s]: failed to scan, maximum retransmits exceed', + rule['symbol'], rule['type']) + common.yield_result(task, rule, + 'failed to scan and retransmits exceed', 0.0, 'fail', + maybe_part) + end + + else + data = tostring(data) + local cached + lua_util.debugm(rule.name, task, + '%s [%s]: got reply: %s', + rule['symbol'], rule['type'], data) + if data == 'stream: OK' or data == fname .. 
': OK' then + cached = 'OK' + common.log_clean(task, rule) + else + local vname = string.match(data, ': (.+) FOUND') + if vname then + common.yield_result(task, rule, vname, 1.0, nil, maybe_part) + cached = vname + else + rspamd_logger.errx(task, 'unhandled response: %s', data) + common.yield_result(task, rule, 'unhandled response', + 0.0, 'fail', maybe_part) + end + end + if cached then + common.save_cache(task, digest, rule, cached, 1.0, maybe_part) + end + end + end + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule['timeout'], + callback = kaspersky_callback, + data = { clamav_compat_cmd }, + stop_pattern = '\n' + }) + end + + if common.condition_check_and_continue(task, content, rule, digest, + kaspersky_check_uncached, maybe_part) then + return + else + kaspersky_check_uncached() + end + +end + +return { + type = 'antivirus', + description = 'kaspersky antivirus', + configure = kaspersky_config, + check = kaspersky_check, + name = N +} diff --git a/lualib/lua_scanners/kaspersky_se.lua b/lualib/lua_scanners/kaspersky_se.lua new file mode 100644 index 0000000..5e0f2ea --- /dev/null +++ b/lualib/lua_scanners/kaspersky_se.lua @@ -0,0 +1,287 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module kaspersky_se +-- This module contains Kaspersky Scan Engine integration support +-- https://www.kaspersky.com/scan-engine +--]] + +local lua_util = require "lua_util" +local rspamd_util = require "rspamd_util" +local http = require "rspamd_http" +local upstream_list = require "rspamd_upstream_list" +local rspamd_logger = require "rspamd_logger" +local common = require "lua_scanners/common" + +local N = 'kaspersky_se' + +local function kaspersky_se_config(opts) + + local default_conf = { + name = N, + default_port = 9999, + use_https = false, + use_files = false, + timeout = 5.0, + log_clean = false, + tmpdir = '/tmp', + retransmits = 1, + cache_expire = 7200, -- expire redis in 2h + message = '${SCANNER}: spam message found: "${VIRUS}"', + detection_category = "virus", + default_score = 1, + action = false, + scan_mime_parts = true, + scan_text_mime = false, + scan_image_mime = false, + } + + default_conf = lua_util.override_defaults(default_conf, opts) + + if not default_conf.prefix then + default_conf.prefix = 'rs_' .. default_conf.name .. '_' + end + + if not default_conf.log_prefix then + if default_conf.name:lower() == default_conf.type:lower() then + default_conf.log_prefix = default_conf.name + else + default_conf.log_prefix = default_conf.name .. ' (' .. default_conf.type .. 
')' + end + end + + if not default_conf.servers and default_conf.socket then + default_conf.servers = default_conf.socket + end + + if not default_conf.servers then + rspamd_logger.errx(rspamd_config, 'no servers defined') + + return nil + end + + default_conf.upstreams = upstream_list.create(rspamd_config, + default_conf.servers, + default_conf.default_port) + + if default_conf.upstreams then + lua_util.add_debug_alias('external_services', default_conf.name) + return default_conf + end + + rspamd_logger.errx(rspamd_config, 'cannot parse servers %s', + default_conf['servers']) + return nil +end + +local function kaspersky_se_check(task, content, digest, rule, maybe_part) + local function kaspersky_se_check_uncached() + local function make_url(addr) + local url + local suffix = '/scanmemory' + + if rule.use_files then + suffix = '/scanfile' + end + if rule.use_https then + url = string.format('https://%s:%d%s', tostring(addr), + addr:get_port(), suffix) + else + url = string.format('http://%s:%d%s', tostring(addr), + addr:get_port(), suffix) + end + + return url + end + + local upstream = rule.upstreams:get_upstream_round_robin() + local addr = upstream:get_addr() + local retransmits = rule.retransmits + + local url = make_url(addr) + local hdrs = { + ['X-KAV-ProtocolVersion'] = '1', + ['X-KAV-Timeout'] = tostring(rule.timeout * 1000), + } + + if task:has_from() then + hdrs['X-KAV-ObjectURL'] = string.format('[from:%s]', task:get_from()[1].addr) + end + + local req_body + + if rule.use_files then + local fname = string.format('%s/%s.tmp', + rule.tmpdir, rspamd_util.random_hex(32)) + local message_fd = rspamd_util.create_file(fname) + + if not message_fd then + rspamd_logger.errx('cannot store file for kaspersky_se scan: %s', fname) + return + end + + if type(content) == 'string' then + -- Create rspamd_text + local rspamd_text = require "rspamd_text" + content = rspamd_text.fromstring(content) + end + content:save_in_file(message_fd) + + -- Ensure cleanup + task:get_mempool():add_destructor(function() + os.remove(fname) + rspamd_util.close_file(message_fd) + end) + + req_body = fname + else + req_body = content + end + + local request_data = { + task = task, + url = url, + body = req_body, + headers = hdrs, + timeout = rule.timeout, + } + + local function kas_callback(http_err, code, body, headers) + + local function requery() + -- set current upstream to fail because an error occurred + upstream:fail() + + -- retry with another upstream until retransmits exceeds + if retransmits > 0 then + + retransmits = retransmits - 1 + + lua_util.debugm(rule.name, task, + '%s: Request Error: %s - retries left: %s', + rule.log_prefix, http_err, retransmits) + + -- Select a different upstream! + upstream = rule.upstreams:get_upstream_round_robin() + addr = upstream:get_addr() + url = make_url(addr) + + lua_util.debugm(rule.name, task, '%s: retry IP: %s:%s', + rule.log_prefix, addr, addr:get_port()) + request_data.url = url + request_data.upstream = upstream + + http.request(request_data) + else + rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits ' .. + 'exceed', rule.log_prefix) + task:insert_result(rule['symbol_fail'], 0.0, 'failed to scan and ' .. + 'retransmits exceed') + end + end + + if http_err then + requery() + else + -- Parse the response + if upstream then + upstream:ok() + end + if code ~= 200 then + rspamd_logger.errx(task, 'invalid HTTP code: %s, body: %s, headers: %s', code, body, headers) + task:insert_result(rule.symbol_fail, 1.0, 'Bad HTTP code: ' .. 
code) + return + end + local data = string.gsub(tostring(body), '[\r\n%s]$', '') + local cached + lua_util.debugm(rule.name, task, '%s: got reply data: "%s"', + rule.log_prefix, data) + + if data:find('^CLEAN') then + -- Handle CLEAN replies + if data == 'CLEAN' then + cached = 'OK' + if rule['log_clean'] then + rspamd_logger.infox(task, '%s: message or mime_part is clean', + rule.log_prefix) + else + lua_util.debugm(rule.name, task, '%s: message or mime_part is clean', + rule.log_prefix) + end + elseif data == 'CLEAN AND CONTAINS OFFICE MACRO' then + common.yield_result(task, rule, 'File contains macros', + 0.0, 'macro', maybe_part) + cached = 'MACRO' + else + rspamd_logger.errx(task, '%s: unhandled clean response: %s', rule.log_prefix, data) + common.yield_result(task, rule, 'unhandled response:' .. data, + 0.0, 'fail', maybe_part) + end + elseif data == 'SERVER_ERROR' then + rspamd_logger.errx(task, '%s: error: %s', rule.log_prefix, data) + common.yield_result(task, rule, 'error:' .. data, + 0.0, 'fail', maybe_part) + elseif string.match(data, 'DETECT (.+)') then + local vname = string.match(data, 'DETECT (.+)') + common.yield_result(task, rule, vname, 1.0, nil, maybe_part) + cached = vname + elseif string.match(data, 'NON_SCANNED %((.+)%)') then + local why = string.match(data, 'NON_SCANNED %((.+)%)') + + if why == 'PASSWORD PROTECTED' then + rspamd_logger.errx(task, '%s: File is encrypted', rule.log_prefix) + common.yield_result(task, rule, 'File is encrypted: ' .. why, + 0.0, 'encrypted', maybe_part) + cached = 'ENCRYPTED' + else + common.yield_result(task, rule, 'unhandled response:' .. data, + 0.0, 'fail', maybe_part) + end + else + rspamd_logger.errx(task, '%s: unhandled response: %s', rule.log_prefix, data) + common.yield_result(task, rule, 'unhandled response:' .. data, + 0.0, 'fail', maybe_part) + end + + if cached then + common.save_cache(task, digest, rule, cached, 1.0, maybe_part) + end + + end + end + + request_data.callback = kas_callback + http.request(request_data) + end + + if common.condition_check_and_continue(task, content, rule, digest, + kaspersky_se_check_uncached, maybe_part) then + return + else + + kaspersky_se_check_uncached() + end + +end + +return { + type = 'antivirus', + description = 'Kaspersky Scan Engine interface', + configure = kaspersky_se_config, + check = kaspersky_se_check, + name = N +} diff --git a/lualib/lua_scanners/oletools.lua b/lualib/lua_scanners/oletools.lua new file mode 100644 index 0000000..378e094 --- /dev/null +++ b/lualib/lua_scanners/oletools.lua @@ -0,0 +1,369 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2018, Carsten Rosenberg <c.rosenberg@heinlein-support.de> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module oletools +-- This module contains oletools access functions. 
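+--
+-- For orientation, a rough sketch of the exchange performed by oletools_check()
+-- below against an olefy service (linked underneath); the request is the literal
+-- header block built in that function, and the reply shape is inferred from its
+-- parsing logic:
+--
+--   OLEFY/1.0
+--   Method: oletools
+--   Rspamd-ID: <task uid>
+--   <blank line, then the raw attachment bytes>
+--
+-- The reply is a JSON array (parsed with ucl) carrying a MetaInformation entry
+-- (script name, version, return_code) plus entries with "macros" and "analysis"
+-- tables, which get condensed into the M/A/S/I/H/B/D/V category flags.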
+-- Olefy is needed: https://github.com/HeinleinSupport/olefy +--]] + +local lua_util = require "lua_util" +local tcp = require "rspamd_tcp" +local upstream_list = require "rspamd_upstream_list" +local rspamd_logger = require "rspamd_logger" +local ucl = require "ucl" +local common = require "lua_scanners/common" + +local N = 'oletools' + +local function oletools_config(opts) + + local oletools_conf = { + name = N, + scan_mime_parts = true, + scan_text_mime = false, + scan_image_mime = false, + default_port = 10050, + timeout = 15.0, + log_clean = false, + retransmits = 2, + cache_expire = 86400, -- expire redis in 1d + min_size = 500, + symbol = "OLETOOLS", + message = '${SCANNER}: Oletools threat message found: "${VIRUS}"', + detection_category = "office macro", + default_score = 1, + action = false, + extended = false, + symbol_type = 'postfilter', + dynamic_scan = true, + } + + oletools_conf = lua_util.override_defaults(oletools_conf, opts) + + if not oletools_conf.prefix then + oletools_conf.prefix = 'rs_' .. oletools_conf.name .. '_' + end + + if not oletools_conf.log_prefix then + if oletools_conf.name:lower() == oletools_conf.type:lower() then + oletools_conf.log_prefix = oletools_conf.name + else + oletools_conf.log_prefix = oletools_conf.name .. ' (' .. oletools_conf.type .. ')' + end + end + + if not oletools_conf.servers then + rspamd_logger.errx(rspamd_config, 'no servers defined') + + return nil + end + + oletools_conf.upstreams = upstream_list.create(rspamd_config, + oletools_conf.servers, + oletools_conf.default_port) + + if oletools_conf.upstreams then + lua_util.add_debug_alias('external_services', oletools_conf.name) + return oletools_conf + end + + rspamd_logger.errx(rspamd_config, 'cannot parse servers %s', + oletools_conf.servers) + return nil +end + +local function oletools_check(task, content, digest, rule, maybe_part) + local function oletools_check_uncached () + local upstream = rule.upstreams:get_upstream_round_robin() + local addr = upstream:get_addr() + local retransmits = rule.retransmits + local protocol = 'OLEFY/1.0\nMethod: oletools\nRspamd-ID: ' .. task:get_uid() .. '\n\n' + local json_response = "" + + local function oletools_callback(err, data, conn) + + local function oletools_requery(error) + + -- retry with another upstream until retransmits exceeds + if retransmits > 0 then + + retransmits = retransmits - 1 + + -- Select a different upstream! + upstream = rule.upstreams:get_upstream_round_robin() + addr = upstream:get_addr() + + lua_util.debugm(rule.name, task, '%s: error: %s; retry IP: %s; retries left: %s', + rule.log_prefix, err, addr, retransmits) + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule.timeout, + shutdown = true, + data = { protocol, content }, + callback = oletools_callback, + }) + else + rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits ' .. + 'exceed - err: %s', rule.log_prefix, error) + common.yield_result(task, rule, + 'failed to scan, maximum retransmits exceed - err: ' .. error, + 0.0, 'fail', maybe_part) + end + end + + if err then + + oletools_requery(err) + + else + json_response = json_response .. 
tostring(data) + + if not string.find(json_response, '\t\n\n\t') and #data == 8192 then + lua_util.debugm(rule.name, task, '%s: no stop word: add_read - #json: %s / current packet: %s', + rule.log_prefix, #json_response, #data) + conn:add_read(oletools_callback) + + else + local ucl_parser = ucl.parser() + local ok, ucl_err = ucl_parser:parse_string(tostring(json_response)) + if not ok then + rspamd_logger.errx(task, "%s: error parsing json response, retry: %s", + rule.log_prefix, ucl_err) + oletools_requery(ucl_err) + return + end + + local result = ucl_parser:get_object() + + local oletools_rc = { + [0] = 'RETURN_OK', + [1] = 'RETURN_WARNINGS', + [2] = 'RETURN_WRONG_ARGS', + [3] = 'RETURN_FILE_NOT_FOUND', + [4] = 'RETURN_XGLOB_ERR', + [5] = 'RETURN_OPEN_ERROR', + [6] = 'RETURN_PARSE_ERROR', + [7] = 'RETURN_SEVERAL_ERRS', + [8] = 'RETURN_UNEXPECTED', + [9] = 'RETURN_ENCRYPTED', + } + + -- M=Macros, A=Auto-executable, S=Suspicious keywords, I=IOCs, + -- H=Hex strings, B=Base64 strings, D=Dridex strings, V=VBA strings + -- Keep sorted to avoid dragons + local analysis_cat_table = { + autoexec = '-', + base64 = '-', + dridex = '-', + hex = '-', + iocs = '-', + macro_exist = '-', + suspicious = '-', + vba = '-' + } + local analysis_keyword_table = {} + + for _, v in ipairs(result) do + + if v.error ~= nil and v.type ~= 'error' then + -- olefy, not oletools error + rspamd_logger.errx(task, '%s: ERROR found: %s', rule.log_prefix, + v.error) + if v.error == 'File too small' then + common.save_cache(task, digest, rule, 'OK', 1.0, maybe_part) + common.log_clean(task, rule, 'File too small to be scanned for macros') + return + else + oletools_requery(v.error) + end + + elseif tostring(v.type) == "MetaInformation" and v.version ~= nil then + -- if MetaInformation section - check and print script and version + + lua_util.debugm(N, task, '%s: version: %s %s', rule.log_prefix, + tostring(v.script_name), tostring(v.version)) + + elseif tostring(v.type) == "MetaInformation" and v.return_code ~= nil then + -- if MetaInformation section - check return_code + + local oletools_rc_code = tonumber(v.return_code) + if oletools_rc_code == 9 then + rspamd_logger.warnx(task, '%s: File is encrypted.', rule.log_prefix) + common.yield_result(task, rule, + 'failed - err: ' .. oletools_rc[oletools_rc_code], + 0.0, 'encrypted', maybe_part) + common.save_cache(task, digest, rule, 'encrypted', 1.0, maybe_part) + return + elseif oletools_rc_code == 5 then + rspamd_logger.warnx(task, '%s: olefy could not open the file - error: %s', rule.log_prefix, + result[2]['message']) + common.yield_result(task, rule, + 'failed - err: ' .. oletools_rc[oletools_rc_code], + 0.0, 'fail', maybe_part) + return + elseif oletools_rc_code > 6 then + rspamd_logger.errx(task, '%s: MetaInfo section error code: %s', + rule.log_prefix, oletools_rc[oletools_rc_code]) + rspamd_logger.errx(task, '%s: MetaInfo section message: %s', + rule.log_prefix, result[2]['message']) + common.yield_result(task, rule, + 'failed - err: ' .. 
oletools_rc[oletools_rc_code], + 0.0, 'fail', maybe_part) + return + elseif oletools_rc_code > 1 then + rspamd_logger.errx(task, '%s: Error message: %s', + rule.log_prefix, result[2]['message']) + oletools_requery(oletools_rc[oletools_rc_code]) + end + + elseif tostring(v.type) == "error" then + -- error section found - check message + rspamd_logger.errx(task, '%s: Error section error code: %s', + rule.log_prefix, v.error) + rspamd_logger.errx(task, '%s: Error section message: %s', + rule.log_prefix, v.message) + --common.yield_result(task, rule, 'failed - err: ' .. v.error, 0.0, 'fail') + + elseif type(v.analysis) == 'table' and type(v.macros) == 'table' then + -- analysis + macro found - evaluate response + + if type(v.analysis) == 'table' and #v.analysis == 0 and #v.macros == 0 then + rspamd_logger.warnx(task, '%s: maybe unhandled python or oletools error', rule.log_prefix) + oletools_requery('oletools unhandled error') + + elseif #v.macros > 0 then + + analysis_cat_table.macro_exist = 'M' + + lua_util.debugm(rule.name, task, + '%s: filename: %s', rule.log_prefix, result[2]['file']) + lua_util.debugm(rule.name, task, + '%s: type: %s', rule.log_prefix, result[2]['type']) + + for _, m in ipairs(v.macros) do + lua_util.debugm(rule.name, task, '%s: macros found - code: %s, ole_stream: %s, ' .. + 'vba_filename: %s', rule.log_prefix, m.code, m.ole_stream, m.vba_filename) + end + + for _, a in ipairs(v.analysis) do + lua_util.debugm(rule.name, task, '%s: threat found - type: %s, keyword: %s, ' .. + 'description: %s', rule.log_prefix, a.type, a.keyword, a.description) + if a.type == 'AutoExec' then + analysis_cat_table.autoexec = 'A' + table.insert(analysis_keyword_table, a.keyword) + elseif a.type == 'Suspicious' then + if rule.extended == true or + (a.keyword ~= 'Base64 Strings' and a.keyword ~= 'Hex Strings') + then + analysis_cat_table.suspicious = 'S' + table.insert(analysis_keyword_table, a.keyword) + end + elseif a.type == 'IOC' then + analysis_cat_table.iocs = 'I' + elseif a.type == 'Hex strings' then + analysis_cat_table.hex = 'H' + elseif a.type == 'Base64 strings' then + analysis_cat_table.base64 = 'B' + elseif a.type == 'Dridex strings' then + analysis_cat_table.dridex = 'D' + elseif a.type == 'VBA strings' then + analysis_cat_table.vba = 'V' + end + end + end + end + end + + lua_util.debugm(N, task, '%s: analysis_keyword_table: %s', rule.log_prefix, analysis_keyword_table) + lua_util.debugm(N, task, '%s: analysis_cat_table: %s', rule.log_prefix, analysis_cat_table) + + if rule.extended == false and analysis_cat_table.autoexec == 'A' and analysis_cat_table.suspicious == 'S' then + -- use single string as virus name + local threat = 'AutoExec + Suspicious (' .. table.concat(analysis_keyword_table, ',') .. 
')' + lua_util.debugm(rule.name, task, '%s: threat result: %s', rule.log_prefix, threat) + common.yield_result(task, rule, threat, rule.default_score, nil, maybe_part) + common.save_cache(task, digest, rule, threat, rule.default_score, maybe_part) + + elseif rule.extended == true and #analysis_keyword_table > 0 then + -- report any flags (types) and any most keywords as individual virus name + local analysis_cat_table_values_sorted = {} + + -- see https://github.com/rspamd/rspamd/commit/6bd3e2b9f49d1de3ab882aeca9c30bc7d526ac9d#commitcomment-40130493 + -- for details + local analysis_cat_table_keys_sorted = lua_util.keys(analysis_cat_table) + table.sort(analysis_cat_table_keys_sorted) + + for _, v in ipairs(analysis_cat_table_keys_sorted) do + table.insert(analysis_cat_table_values_sorted, analysis_cat_table[v]) + end + + table.insert(analysis_keyword_table, 1, table.concat(analysis_cat_table_values_sorted)) + + lua_util.debugm(rule.name, task, '%s: extended threat result: %s', + rule.log_prefix, table.concat(analysis_keyword_table, ',')) + + common.yield_result(task, rule, analysis_keyword_table, + rule.default_score, nil, maybe_part) + common.save_cache(task, digest, rule, analysis_keyword_table, + rule.default_score, maybe_part) + + elseif analysis_cat_table.macro_exist == '-' and #analysis_keyword_table == 0 then + common.save_cache(task, digest, rule, 'OK', 1.0, maybe_part) + common.log_clean(task, rule, 'No macro found') + + else + common.save_cache(task, digest, rule, 'OK', 1.0, maybe_part) + common.log_clean(task, rule, 'Scanned Macro is OK') + end + end + end + end + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule.timeout, + shutdown = true, + data = { protocol, content }, + callback = oletools_callback, + }) + + end + + if common.condition_check_and_continue(task, content, rule, digest, + oletools_check_uncached, maybe_part) then + return + else + oletools_check_uncached() + end + +end + +return { + type = { N, 'attachment scanner', 'hash', 'scanner' }, + description = 'oletools office macro scanner', + configure = oletools_config, + check = oletools_check, + name = N +} diff --git a/lualib/lua_scanners/p0f.lua b/lualib/lua_scanners/p0f.lua new file mode 100644 index 0000000..7785f83 --- /dev/null +++ b/lualib/lua_scanners/p0f.lua @@ -0,0 +1,227 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2019, Denis Paavilainen <denpa@denpa.pro> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +--[[[ +-- @module p0f +-- This module contains p0f access functions +--]] + +local tcp = require "rspamd_tcp" +local rspamd_util = require "rspamd_util" +local rspamd_logger = require "rspamd_logger" +local lua_redis = require "lua_redis" +local lua_util = require "lua_util" +local common = require "lua_scanners/common" + +-- SEE: https://github.com/p0f/p0f/blob/v3.06b/docs/README#L317 +local S = { + BAD_QUERY = 0x0, + OK = 0x10, + NO_MATCH = 0x20 +} + +local N = 'p0f' + +local function p0f_check(task, ip, rule) + + local function ip2bin(addr) + addr = addr:to_table() + + for k, v in ipairs(addr) do + addr[k] = rspamd_util.pack('B', v) + end + + return table.concat(addr) + end + + local function trim(...) + local vars = { ... } + + for k, v in ipairs(vars) do + -- skip numbers, trim only strings + if tonumber(vars[k]) == nil then + vars[k] = string.gsub(v, '[^%w-_\\.\\(\\) ]', '') + end + end + + return lua_util.unpack(vars) + end + + local function parse_p0f_response(data) + --[[ + p0f_api_response[232]: magic, status, first_seen, last_seen, total_conn, + uptime_min, up_mod_days, last_nat, last_chg, distance, bad_sw, os_match_q, + os_name, os_flavor, http_name, http_flavor, link_type, language + ]]-- + + data = tostring(data) + + -- API response must be 232 bytes long + if #data ~= 232 then + rspamd_logger.errx(task, 'malformed response from p0f on %s, %s bytes', + rule.socket, #data) + + common.yield_result(task, rule, 'Malformed Response: ' .. rule.socket, + 0.0, 'fail') + return + end + + local _, status, _, _, _, uptime_min, _, _, _, distance, _, _, os_name, + os_flavor, _, _, link_type, _ = trim(rspamd_util.unpack( + 'I4I4I4I4I4I4I4I4I4hbbc32c32c32c32c32c32', data)) + + if status ~= S.OK then + if status == S.BAD_QUERY then + rspamd_logger.errx(task, 'malformed p0f query on %s', rule.socket) + common.yield_result(task, rule, 'Malformed Query: ' .. rule.socket, + 0.0, 'fail') + end + + return + end + + local os_string = #os_name == 0 and 'unknown' or os_name .. ' ' .. os_flavor + + task:get_mempool():set_variable('os_fingerprint', os_string, link_type, + uptime_min, distance) + + if link_type and #link_type > 0 then + common.yield_result(task, rule, { + os_string, + 'link=' .. link_type, + 'distance=' .. distance }, + 0.0) + else + common.yield_result(task, rule, { + os_string, + 'link=unknown', + 'distance=' .. distance }, + 0.0) + end + + return data + end + + local function make_p0f_request() + + local function check_p0f_cb(err, data) + + local function redis_set_cb(redis_set_err) + if redis_set_err then + rspamd_logger.errx(task, 'redis received an error: %s', redis_set_err) + end + end + + if err then + rspamd_logger.errx(task, 'p0f received an error: %s', err) + common.yield_result(task, rule, 'Error getting result: ' .. err, + 0.0, 'fail') + return + end + + data = parse_p0f_response(data) + + if rule.redis_params and data then + local key = rule.prefix .. 
ip:to_string() + local ret = lua_redis.redis_make_request(task, + rule.redis_params, + key, + true, + redis_set_cb, + 'SETEX', + { key, tostring(rule.expire), data } + ) + + if not ret then + rspamd_logger.warnx(task, 'error connecting to redis') + end + end + end + + local query = rspamd_util.pack('I4 I1 c16', 0x50304601, + ip:get_version(), ip2bin(ip)) + + tcp.request({ + host = rule.socket, + callback = check_p0f_cb, + data = { query }, + task = task, + timeout = rule.timeout + }) + end + + local function redis_get_cb(err, data) + if err or type(data) ~= 'string' then + make_p0f_request() + else + parse_p0f_response(data) + end + end + + local ret = nil + if rule.redis_params then + local key = rule.prefix .. ip:to_string() + ret = lua_redis.redis_make_request(task, + rule.redis_params, + key, + false, + redis_get_cb, + 'GET', + { key } + ) + end + + if not ret then + make_p0f_request() -- fallback to directly querying p0f + end +end + +local function p0f_config(opts) + local p0f_conf = { + name = N, + timeout = 5, + symbol = 'P0F', + symbol_fail = 'P0F_FAIL', + patterns = {}, + expire = 7200, + prefix = 'p0f', + detection_category = 'fingerprint', + message = '${SCANNER}: fingerprint matched: "${VIRUS}"' + } + + p0f_conf = lua_util.override_defaults(p0f_conf, opts) + p0f_conf.patterns = common.create_regex_table(p0f_conf.patterns) + + if not p0f_conf.log_prefix then + p0f_conf.log_prefix = p0f_conf.name + end + + if not p0f_conf.socket then + rspamd_logger.errx(rspamd_config, 'no servers defined') + return nil + end + + return p0f_conf +end + +return { + type = { N, 'fingerprint', 'scanner' }, + description = 'passive OS fingerprinter', + configure = p0f_config, + check = p0f_check, + name = N +} diff --git a/lualib/lua_scanners/pyzor.lua b/lualib/lua_scanners/pyzor.lua new file mode 100644 index 0000000..75c1b4a --- /dev/null +++ b/lualib/lua_scanners/pyzor.lua @@ -0,0 +1,206 @@ +--[[ +Copyright (c) 2021, defkev <defkev@gmail.com> +Copyright (c) 2018, Carsten Rosenberg <c.rosenberg@heinlein-support.de> +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module pyzor +-- This module contains pyzor access functions +--]] + +local lua_util = require "lua_util" +local tcp = require "rspamd_tcp" +local upstream_list = require "rspamd_upstream_list" +local rspamd_logger = require "rspamd_logger" +local common = require "lua_scanners/common" + +local N = 'pyzor' +local categories = { 'pyzor', 'bulk', 'hash', 'scanner' } + +local function pyzor_config(opts) + + local pyzor_conf = { + text_part_min_words = 2, + default_port = 5953, + timeout = 15.0, + log_clean = false, + retransmits = 2, + detection_category = "hash", + cache_expire = 7200, -- expire redis in one hour + message = '${SCANNER}: Pyzor bulk message found: "${VIRUS}"', + default_score = 1.5, + action = false, + } + + pyzor_conf = lua_util.override_defaults(pyzor_conf, opts) + + if not pyzor_conf.prefix then + pyzor_conf.prefix = 'rext_' .. N .. 
'_' + end + + if not pyzor_conf.log_prefix then + pyzor_conf.log_prefix = N .. ' (' .. pyzor_conf.detection_category .. ')' + end + + if not pyzor_conf['servers'] then + rspamd_logger.errx(rspamd_config, 'no servers defined') + + return nil + end + + pyzor_conf['upstreams'] = upstream_list.create(rspamd_config, + pyzor_conf['servers'], + pyzor_conf.default_port) + + if pyzor_conf['upstreams'] then + lua_util.add_debug_alias('external_services', N) + return pyzor_conf + end + + rspamd_logger.errx(rspamd_config, 'cannot parse servers %s', + pyzor_conf['servers']) + return nil +end + +local function pyzor_check(task, content, digest, rule) + local function pyzor_check_uncached () + local upstream = rule.upstreams:get_upstream_round_robin() + local addr = upstream:get_addr() + local retransmits = rule.retransmits + + local function pyzor_callback(err, data, conn) + + if err then + + -- retry with another upstream until retransmits exceeds + if retransmits > 0 then + + retransmits = retransmits - 1 + + -- Select a different upstream! + upstream = rule.upstreams:get_upstream_round_robin() + addr = upstream:get_addr() + + lua_util.debugm(N, task, '%s: retry IP: %s:%s err: %s', + rule.log_prefix, addr, addr:get_port(), err) + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule['timeout'], + shutdown = true, + data = content, + callback = pyzor_callback, + }) + else + rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits exceed', + rule['symbol'], rule['type']) + task:insert_result(rule['symbol_fail'], 0.0, + 'failed to scan and retransmits exceed') + end + else + -- pyzor output is unicode (\x09 -> tab, \0a -> newline) + -- public.pyzor.org:24441 (200, 'OK') 21285091 206759 + -- server:port Code Diag Count WL-Count + local str_data = tostring(data) + lua_util.debugm(N, task, '%s: returned data: %s', + rule.log_prefix, str_data) + -- If pyzor would return JSON this wouldn't be necessary + local resp = {} + for v in string.gmatch(str_data, '[^\t]+') do + table.insert(resp, v) + end + -- rspamd_logger.infox(task, 'resp: %s', resp) + if resp[2] ~= [[(200, 'OK')]] then + rspamd_logger.errx(task, "error parsing response: %s", str_data) + return + end + + local whitelisted = tonumber(resp[4]) + local reported = tonumber(resp[3]) + + --rspamd_logger.infox(task, "%s - count=%s wl=%s", addr:to_string(), reported, whitelisted) + + --[[ + Weight is Count - WL-Count of rule.default_score in percent, e.g. 
+ SPAM: + Count: 100 (100%) + WL-Count: 1 (1%) + rule.default_score: 1 + Weight: 0.99 + HAM: + Count: 10 (100%) + WL-Count: 10 (100%) + rule.default_score: 1 + Weight: 0 + ]] + local weight = tonumber(string.format("%.2f", + rule.default_score * (reported - whitelisted) / (reported + whitelisted))) + local info = string.format("count=%d wl=%d", reported, whitelisted) + local threat_string = string.format("bl_%d_wl_%d", + reported, whitelisted) + + if weight > 0 then + lua_util.debugm(N, task, '%s: returned result is spam - info: %s', + rule.log_prefix, info) + common.yield_result(task, rule, threat_string, weight) + common.save_cache(task, digest, rule, threat_string, weight) + else + if rule.log_clean then + rspamd_logger.infox(task, '%s: clean, returned result is ham - info: %s', + rule.log_prefix, info) + else + lua_util.debugm(N, task, '%s: returned result is ham - info: %s', + rule.log_prefix, info) + end + common.save_cache(task, digest, rule, 'OK', weight) + end + + end + end + + if digest == 'da39a3ee5e6b4b0d3255bfef95601890afd80709' then + rspamd_logger.infox(task, '%s: not checking default digest', rule.log_prefix) + return + end + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule.timeout, + shutdown = true, + data = content, + callback = pyzor_callback, + }) + end + if common.condition_check_and_continue(task, content, rule, digest, pyzor_check_uncached) then + return + else + pyzor_check_uncached() + end +end + +return { + type = categories, + description = 'pyzor bulk scanner', + configure = pyzor_config, + check = pyzor_check, + name = N +} diff --git a/lualib/lua_scanners/razor.lua b/lualib/lua_scanners/razor.lua new file mode 100644 index 0000000..fcc0a8e --- /dev/null +++ b/lualib/lua_scanners/razor.lua @@ -0,0 +1,181 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2018, Carsten Rosenberg <c.rosenberg@heinlein-support.de> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module razor +-- This module contains razor access functions +--]] + +local lua_util = require "lua_util" +local tcp = require "rspamd_tcp" +local upstream_list = require "rspamd_upstream_list" +local rspamd_logger = require "rspamd_logger" +local common = require "lua_scanners/common" + +local N = 'razor' + +local function razor_config(opts) + + local razor_conf = { + name = N, + default_port = 11342, + timeout = 5.0, + log_clean = false, + retransmits = 2, + cache_expire = 7200, -- expire redis in 2h + message = '${SCANNER}: spam message found: "${VIRUS}"', + detection_category = "hash", + default_score = 1, + action = false, + dynamic_scan = false, + symbol_fail = 'RAZOR_FAIL', + symbol = 'RAZOR', + } + + razor_conf = lua_util.override_defaults(razor_conf, opts) + + if not razor_conf.prefix then + razor_conf.prefix = 'rs_' .. razor_conf.name .. 
'_' + end + + if not razor_conf.log_prefix then + razor_conf.log_prefix = razor_conf.name + end + + if not razor_conf.servers and razor_conf.socket then + razor_conf.servers = razor_conf.socket + end + + if not razor_conf.servers then + rspamd_logger.errx(rspamd_config, 'no servers defined') + + return nil + end + + razor_conf.upstreams = upstream_list.create(rspamd_config, + razor_conf.servers, + razor_conf.default_port) + + if razor_conf.upstreams then + lua_util.add_debug_alias('external_services', razor_conf.name) + return razor_conf + end + + rspamd_logger.errx(rspamd_config, 'cannot parse servers %s', + razor_conf['servers']) + return nil +end + +local function razor_check(task, content, digest, rule) + local function razor_check_uncached () + local upstream = rule.upstreams:get_upstream_round_robin() + local addr = upstream:get_addr() + local retransmits = rule.retransmits + + local function razor_callback(err, data, conn) + + local function razor_requery() + -- retry with another upstream until retransmits exceeds + if retransmits > 0 then + + retransmits = retransmits - 1 + + lua_util.debugm(rule.name, task, '%s: Request Error: %s - retries left: %s', + rule.log_prefix, err, retransmits) + + -- Select a different upstream! + upstream = rule.upstreams:get_upstream_round_robin() + addr = upstream:get_addr() + + lua_util.debugm(rule.name, task, '%s: retry IP: %s:%s', + rule.log_prefix, addr, addr:get_port()) + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule.timeout or 2.0, + shutdown = true, + data = content, + callback = razor_callback, + }) + else + rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits ' .. + 'exceed', rule.log_prefix) + common.yield_result(task, rule, 'failed to scan and retransmits exceed', 0.0, 'fail') + end + end + + if err then + + razor_requery() + + else + --[[ + @todo: Razorsocket currently only returns ham or spam. When the wrapper is fixed we should add dynamic scores here. + Maybe check spamassassin implementation. + + This implementation is based on https://github.com/cgt/rspamd-plugins + Thanks @cgt! 
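+
+          As a purely hypothetical sketch of that @todo (razorfy does not send this
+          today): if the wrapper ever returned a confidence value instead of the bare
+          "spam"/"ham" string captured in threat_string below, it could be folded into
+          the score roughly like
+
+            local conf = tonumber(threat_string)  -- e.g. "0.9" rather than "spam"
+            if conf then
+              common.yield_result(task, rule, 'razor', rule.default_score * conf)
+              common.save_cache(task, digest, rule, 'razor', rule.default_score * conf)
+            end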
+ ]] -- + + local threat_string = tostring(data) + if threat_string == "spam" then + lua_util.debugm(N, task, '%s: returned result is spam', rule['symbol'], rule['type']) + common.yield_result(task, rule, threat_string, rule.default_score) + common.save_cache(task, digest, rule, threat_string, rule.default_score) + elseif threat_string == "ham" then + if rule.log_clean then + rspamd_logger.infox(task, '%s: returned result is ham', rule['symbol'], rule['type']) + else + lua_util.debugm(N, task, '%s: returned result is ham', rule['symbol'], rule['type']) + end + common.save_cache(task, digest, rule, 'OK', rule.default_score) + else + rspamd_logger.errx(task, "%s - unknown response from razorfy: %s", addr:to_string(), threat_string) + end + + end + end + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule.timeout or 2.0, + shutdown = true, + data = content, + callback = razor_callback, + }) + end + + if common.condition_check_and_continue(task, content, rule, digest, razor_check_uncached) then + return + else + razor_check_uncached() + end +end + +return { + type = { 'razor', 'spam', 'hash', 'scanner' }, + description = 'razor bulk scanner', + configure = razor_config, + check = razor_check, + name = N +} diff --git a/lualib/lua_scanners/savapi.lua b/lualib/lua_scanners/savapi.lua new file mode 100644 index 0000000..08f7b66 --- /dev/null +++ b/lualib/lua_scanners/savapi.lua @@ -0,0 +1,261 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module savapi +-- This module contains avira savapi antivirus access functions +--]] + +local lua_util = require "lua_util" +local tcp = require "rspamd_tcp" +local upstream_list = require "rspamd_upstream_list" +local rspamd_util = require "rspamd_util" +local rspamd_logger = require "rspamd_logger" +local common = require "lua_scanners/common" + +local N = "savapi" + +local default_message = '${SCANNER}: virus found: "${VIRUS}"' + +local function savapi_config(opts) + local savapi_conf = { + name = N, + scan_mime_parts = true, + scan_text_mime = false, + scan_image_mime = false, + default_port = 4444, -- note: You must set ListenAddress in savapi.conf + product_id = 0, + log_clean = false, + timeout = 15.0, -- FIXME: this will break task_timeout! + retransmits = 1, -- FIXME: useless, for local files + cache_expire = 3600, -- expire redis in one hour + message = default_message, + detection_category = "virus", + tmpdir = '/tmp', + } + + savapi_conf = lua_util.override_defaults(savapi_conf, opts) + + if not savapi_conf.prefix then + savapi_conf.prefix = 'rs_' .. savapi_conf.name .. '_' + end + + if not savapi_conf.log_prefix then + if savapi_conf.name:lower() == savapi_conf.type:lower() then + savapi_conf.log_prefix = savapi_conf.name + else + savapi_conf.log_prefix = savapi_conf.name .. ' (' .. savapi_conf.type .. 
')' + end + end + + if not savapi_conf['servers'] then + rspamd_logger.errx(rspamd_config, 'no servers defined') + + return nil + end + + savapi_conf['upstreams'] = upstream_list.create(rspamd_config, + savapi_conf['servers'], + savapi_conf.default_port) + + if savapi_conf['upstreams'] then + lua_util.add_debug_alias('antivirus', savapi_conf.name) + return savapi_conf + end + + rspamd_logger.errx(rspamd_config, 'cannot parse servers %s', + savapi_conf['servers']) + return nil +end + +local function savapi_check(task, content, digest, rule) + local function savapi_check_uncached () + local upstream = rule.upstreams:get_upstream_round_robin() + local addr = upstream:get_addr() + local retransmits = rule.retransmits + local fname = string.format('%s/%s.tmp', + rule.tmpdir, rspamd_util.random_hex(32)) + local message_fd = rspamd_util.create_file(fname) + + if not message_fd then + rspamd_logger.errx('cannot store file for savapi scan: %s', fname) + return + end + + if type(content) == 'string' then + -- Create rspamd_text + local rspamd_text = require "rspamd_text" + content = rspamd_text.fromstring(content) + end + content:save_in_file(message_fd) + + -- Ensure cleanup + task:get_mempool():add_destructor(function() + os.remove(fname) + rspamd_util.close_file(message_fd) + end) + + local vnames = {} + + -- Forward declaration for recursive calls + local savapi_scan1_cb + + local function savapi_fin_cb(err, conn) + local vnames_reordered = {} + -- Swap table + for virus, _ in pairs(vnames) do + table.insert(vnames_reordered, virus) + end + lua_util.debugm(rule.name, task, "%s: number of virus names found %s", rule['type'], #vnames_reordered) + if #vnames_reordered > 0 then + local vname = {} + for _, virus in ipairs(vnames_reordered) do + table.insert(vname, virus) + end + + common.yield_result(task, rule, vname) + common.save_cache(task, digest, rule, vname) + end + if conn then + conn:close() + end + end + + local function savapi_scan2_cb(err, data, conn) + local result = tostring(data) + lua_util.debugm(rule.name, task, "%s: got reply: %s", + rule.type, result) + + -- Terminal response - clean + if string.find(result, '200') or string.find(result, '210') then + if rule['log_clean'] then + rspamd_logger.infox(task, '%s: message or mime_part is clean', rule['type']) + end + common.save_cache(task, digest, rule, 'OK') + conn:add_write(savapi_fin_cb, 'QUIT\n') + + -- Terminal response - infected + elseif string.find(result, '319') then + conn:add_write(savapi_fin_cb, 'QUIT\n') + + -- Non-terminal response + elseif string.find(result, '310') then + local virus + virus = result:match "310.*<<<%s(.*)%s+;.*;.*" + if not virus then + virus = result:match "310%s(.*)%s+;.*;.*" + if not virus then + rspamd_logger.errx(task, "%s: virus result unparseable: %s", + rule['type'], result) + common.yield_result(task, rule, 'virus result unparseable: ' .. 
result, 0.0, 'fail') + return + end + end + -- Store unique virus names + vnames[virus] = 1 + -- More content is expected + conn:add_write(savapi_scan1_cb, '\n') + end + end + + savapi_scan1_cb = function(err, conn) + conn:add_read(savapi_scan2_cb, '\n') + end + + -- 100 PRODUCT:xyz + local function savapi_greet2_cb(err, data, conn) + local result = tostring(data) + if string.find(result, '100 PRODUCT') then + lua_util.debugm(rule.name, task, "%s: scanning file: %s", + rule['type'], fname) + conn:add_write(savapi_scan1_cb, { string.format('SCAN %s\n', + fname) }) + else + rspamd_logger.errx(task, '%s: invalid product id %s', rule['type'], + rule['product_id']) + common.yield_result(task, rule, 'invalid product id: ' .. result, 0.0, 'fail') + conn:add_write(savapi_fin_cb, 'QUIT\n') + end + end + + local function savapi_greet1_cb(err, conn) + conn:add_read(savapi_greet2_cb, '\n') + end + + local function savapi_callback_init(err, data, conn) + if err then + + -- retry with another upstream until retransmits exceeds + if retransmits > 0 then + + retransmits = retransmits - 1 + + -- Select a different upstream! + upstream = rule.upstreams:get_upstream_round_robin() + addr = upstream:get_addr() + + lua_util.debugm(rule.name, task, '%s: error: %s; retry IP: %s; retries left: %s', + rule.log_prefix, err, addr, retransmits) + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule['timeout'], + callback = savapi_callback_init, + stop_pattern = { '\n' }, + }) + else + rspamd_logger.errx(task, '%s [%s]: failed to scan, maximum retransmits exceed', rule['symbol'], rule['type']) + common.yield_result(task, rule, 'failed to scan and retransmits exceed', 0.0, 'fail') + end + else + local result = tostring(data) + + -- 100 SAVAPI:4.0 greeting + if string.find(result, '100') then + conn:add_write(savapi_greet1_cb, { string.format('SET PRODUCT %s\n', rule['product_id']) }) + end + end + end + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule['timeout'], + callback = savapi_callback_init, + stop_pattern = { '\n' }, + }) + end + + if common.condition_check_and_continue(task, content, rule, digest, savapi_check_uncached) then + return + else + savapi_check_uncached() + end + +end + +return { + type = 'antivirus', + description = 'savapi avira antivirus', + configure = savapi_config, + check = savapi_check, + name = N +} diff --git a/lualib/lua_scanners/sophos.lua b/lualib/lua_scanners/sophos.lua new file mode 100644 index 0000000..d9b64f1 --- /dev/null +++ b/lualib/lua_scanners/sophos.lua @@ -0,0 +1,192 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +--[[[ +-- @module savapi +-- This module contains avira savapi antivirus access functions +--]] + +local lua_util = require "lua_util" +local tcp = require "rspamd_tcp" +local upstream_list = require "rspamd_upstream_list" +local rspamd_logger = require "rspamd_logger" +local common = require "lua_scanners/common" + +local N = "sophos" + +local default_message = '${SCANNER}: virus found: "${VIRUS}"' + +local function sophos_config(opts) + local sophos_conf = { + name = N, + scan_mime_parts = true, + scan_text_mime = false, + scan_image_mime = false, + default_port = 4010, + timeout = 15.0, + log_clean = false, + retransmits = 2, + cache_expire = 3600, -- expire redis in one hour + message = default_message, + detection_category = "virus", + } + + sophos_conf = lua_util.override_defaults(sophos_conf, opts) + + if not sophos_conf.prefix then + sophos_conf.prefix = 'rs_' .. sophos_conf.name .. '_' + end + + if not sophos_conf.log_prefix then + if sophos_conf.name:lower() == sophos_conf.type:lower() then + sophos_conf.log_prefix = sophos_conf.name + else + sophos_conf.log_prefix = sophos_conf.name .. ' (' .. sophos_conf.type .. ')' + end + end + + if not sophos_conf['servers'] then + rspamd_logger.errx(rspamd_config, 'no servers defined') + + return nil + end + + sophos_conf['upstreams'] = upstream_list.create(rspamd_config, + sophos_conf['servers'], + sophos_conf.default_port) + + if sophos_conf['upstreams'] then + lua_util.add_debug_alias('antivirus', sophos_conf.name) + return sophos_conf + end + + rspamd_logger.errx(rspamd_config, 'cannot parse servers %s', + sophos_conf['servers']) + return nil +end + +local function sophos_check(task, content, digest, rule, maybe_part) + local function sophos_check_uncached () + local upstream = rule.upstreams:get_upstream_round_robin() + local addr = upstream:get_addr() + local retransmits = rule.retransmits + local protocol = 'SSSP/1.0\n' + local streamsize = string.format('SCANDATA %d\n', #content) + local bye = 'BYE\n' + + local function sophos_callback(err, data, conn) + + if err then + -- retry with another upstream until retransmits exceeds + if retransmits > 0 then + + retransmits = retransmits - 1 + + -- Select a different upstream! 
+ upstream = rule.upstreams:get_upstream_round_robin() + addr = upstream:get_addr() + + lua_util.debugm(rule.name, task, '%s: error: %s; retry IP: %s; retries left: %s', + rule.log_prefix, err, addr, retransmits) + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule['timeout'], + callback = sophos_callback, + data = { protocol, streamsize, content, bye } + }) + else + rspamd_logger.errx(task, '%s [%s]: failed to scan, maximum retransmits exceed', rule['symbol'], rule['type']) + common.yield_result(task, rule, 'failed to scan and retransmits exceed', + 0.0, 'fail', maybe_part) + end + else + data = tostring(data) + lua_util.debugm(rule.name, task, + '%s [%s]: got reply: %s', rule['symbol'], rule['type'], data) + local vname = string.match(data, 'VIRUS (%S+) ') + local cached + if vname then + common.yield_result(task, rule, vname, 1.0, nil, maybe_part) + common.save_cache(task, digest, rule, vname, 1.0, maybe_part) + else + if string.find(data, 'DONE OK') then + if rule['log_clean'] then + rspamd_logger.infox(task, '%s: message or mime_part is clean', rule.log_prefix) + else + lua_util.debugm(rule.name, task, + '%s: message or mime_part is clean', rule.log_prefix) + end + cached = 'OK' + -- not finished - continue + elseif string.find(data, 'ACC') or string.find(data, 'OK SSSP') then + conn:add_read(sophos_callback) + elseif string.find(data, 'FAIL 0212') then + rspamd_logger.warnx(task, 'Message is encrypted (FAIL 0212): %s', data) + common.yield_result(task, rule, 'SAVDI: Message is encrypted (FAIL 0212)', + 0.0, 'encrypted', maybe_part) + cached = 'ENCRYPTED' + elseif string.find(data, 'REJ 4') then + rspamd_logger.warnx(task, 'Message is oversized (REJ 4): %s', data) + common.yield_result(task, rule, 'SAVDI: Message oversized (REJ 4)', + 0.0, 'fail', maybe_part) + -- explicitly set REJ1 message when SAVDIreports a protocol error + elseif string.find(data, 'REJ 1') then + rspamd_logger.errx(task, 'SAVDI (Protocol error (REJ 1)): %s', data) + common.yield_result(task, rule, 'SAVDI: Protocol error (REJ 1)', + 0.0, 'fail', maybe_part) + else + rspamd_logger.errx(task, 'unhandled response: %s', data) + common.yield_result(task, rule, 'unhandled response: ' .. data, + 0.0, 'fail', maybe_part) + end + if cached then + common.save_cache(task, digest, rule, cached, 1.0, maybe_part) + end + end + end + end + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule['timeout'], + callback = sophos_callback, + data = { protocol, streamsize, content, bye } + }) + end + + if common.condition_check_and_continue(task, content, rule, digest, + sophos_check_uncached, maybe_part) then + return + else + sophos_check_uncached() + end + +end + +return { + type = 'antivirus', + description = 'sophos antivirus', + configure = sophos_config, + check = sophos_check, + name = N +} diff --git a/lualib/lua_scanners/spamassassin.lua b/lualib/lua_scanners/spamassassin.lua new file mode 100644 index 0000000..f425924 --- /dev/null +++ b/lualib/lua_scanners/spamassassin.lua @@ -0,0 +1,213 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2019, Carsten Rosenberg <c.rosenberg@heinlein-support.de> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module spamassassin +-- This module contains spamd access functions. +--]] + +local lua_util = require "lua_util" +local tcp = require "rspamd_tcp" +local upstream_list = require "rspamd_upstream_list" +local rspamd_logger = require "rspamd_logger" +local common = require "lua_scanners/common" + +local N = 'spamassassin' + +local function spamassassin_config(opts) + + local spamassassin_conf = { + N = N, + scan_mime_parts = false, + scan_text_mime = false, + scan_image_mime = false, + default_port = 783, + timeout = 15.0, + log_clean = false, + retransmits = 2, + cache_expire = 3600, -- expire redis in one hour + symbol = "SPAMD", + message = '${SCANNER}: Spamassassin bulk message found: "${VIRUS}"', + detection_category = "spam", + default_score = 1, + action = false, + extended = false, + symbol_type = 'postfilter', + dynamic_scan = true, + } + + spamassassin_conf = lua_util.override_defaults(spamassassin_conf, opts) + + if not spamassassin_conf.prefix then + spamassassin_conf.prefix = 'rs_' .. spamassassin_conf.name .. '_' + end + + if not spamassassin_conf.log_prefix then + if spamassassin_conf.name:lower() == spamassassin_conf.type:lower() then + spamassassin_conf.log_prefix = spamassassin_conf.name + else + spamassassin_conf.log_prefix = spamassassin_conf.name .. ' (' .. spamassassin_conf.type .. ')' + end + end + + if not spamassassin_conf.servers then + rspamd_logger.errx(rspamd_config, 'no servers defined') + + return nil + end + + spamassassin_conf.upstreams = upstream_list.create(rspamd_config, + spamassassin_conf.servers, + spamassassin_conf.default_port) + + if spamassassin_conf.upstreams then + lua_util.add_debug_alias('external_services', spamassassin_conf.N) + return spamassassin_conf + end + + rspamd_logger.errx(rspamd_config, 'cannot parse servers %s', + spamassassin_conf.servers) + return nil +end + +local function spamassassin_check(task, content, digest, rule) + local function spamassassin_check_uncached () + local upstream = rule.upstreams:get_upstream_round_robin() + local addr = upstream:get_addr() + local retransmits = rule.retransmits + + -- Build the spamd query + -- https://svn.apache.org/repos/asf/spamassassin/trunk/spamd/PROTOCOL + local request_data = { + "HEADERS SPAMC/1.5\r\n", + "User: root\r\n", + "Content-length: " .. #content .. "\r\n", + "\r\n", + content, + } + + local function spamassassin_callback(err, data) + + local function spamassassin_requery(error) + + -- retry with another upstream until retransmits exceeds + if retransmits > 0 then + + retransmits = retransmits - 1 + + lua_util.debugm(rule.N, task, '%s: Request Error: %s - retries left: %s', + rule.log_prefix, error, retransmits) + + -- Select a different upstream! 
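+        -- The same request_data table (SPAMC "HEADERS" command, Content-length
+        -- header and the raw message) is replayed verbatim against the newly
+        -- selected upstream; only the retransmit counter above changes.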
+ upstream = rule.upstreams:get_upstream_round_robin() + addr = upstream:get_addr() + + lua_util.debugm(rule.N, task, '%s: retry IP: %s:%s', + rule.log_prefix, addr, addr:get_port()) + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule['timeout'], + data = request_data, + callback = spamassassin_callback, + }) + else + rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits ' .. + 'exceed - err: %s', rule.log_prefix, error) + common.yield_result(task, rule, 'failed to scan and retransmits exceed: ' .. error, 0.0, 'fail') + end + end + + if err then + + spamassassin_requery(err) + + else + --lua_util.debugm(rule.N, task, '%s: returned result: %s', rule.log_prefix, data) + + --[[ + patterns tested against Spamassassin 3.4.6 + + X-Spam-Status: No, score=1.1 required=5.0 tests=HTML_MESSAGE,MIME_HTML_ONLY, + TVD_RCVD_SPACE_BRACKET,UNPARSEABLE_RELAY autolearn=no + autolearn_force=no version=3.4.6 + ]] -- + local header = string.gsub(tostring(data), "[\r\n]+[\t ]", " ") + --lua_util.debugm(rule.N, task, '%s: returned header: %s', rule.log_prefix, header) + + local symbols = "" + local spam_score = 0 + for s in header:gmatch("[^\r\n]+") do + if string.find(s, 'X%-Spam%-Status: %S+, score') then + local pattern_symbols = "X%-Spam%-Status: %S+, score%=([%-%d%.]+)%s.*tests%=(.*,?)(%s*%S+)%sautolearn.*" + spam_score = string.gsub(s, pattern_symbols, "%1") + symbols = string.gsub(s, pattern_symbols, "%2%3") + symbols = string.gsub(symbols, "%s", "") + end + end + + lua_util.debugm(rule.N, task, '%s: spam_score: %s, symbols: %s, int spam_score: |%s|, type spam_score: |%s|', + rule.log_prefix, spam_score, symbols, tonumber(spam_score), type(spam_score)) + + if tonumber(spam_score) > 0 and #symbols > 0 and symbols ~= "none" then + + if rule.extended == false then + common.yield_result(task, rule, symbols, spam_score) + common.save_cache(task, digest, rule, symbols, spam_score) + else + local symbols_table = lua_util.str_split(symbols, ",") + lua_util.debugm(rule.N, task, '%s: returned symbols as table: %s', rule.log_prefix, symbols_table) + + common.yield_result(task, rule, symbols_table, spam_score) + common.save_cache(task, digest, rule, symbols_table, spam_score) + end + else + common.save_cache(task, digest, rule, 'OK') + common.log_clean(task, rule, 'no spam detected - spam score: ' .. spam_score .. ', symbols: ' .. symbols) + end + end + end + + tcp.request({ + task = task, + host = addr:to_string(), + port = addr:get_port(), + upstream = upstream, + timeout = rule['timeout'], + data = request_data, + callback = spamassassin_callback, + }) + end + + if common.condition_check_and_continue(task, content, rule, digest, spamassassin_check_uncached) then + return + else + spamassassin_check_uncached() + end + +end + +return { + type = { N, 'spam', 'scanner' }, + description = 'spamassassin spam scanner', + configure = spamassassin_config, + check = spamassassin_check, + name = N +} diff --git a/lualib/lua_scanners/vadesecure.lua b/lualib/lua_scanners/vadesecure.lua new file mode 100644 index 0000000..826573a --- /dev/null +++ b/lualib/lua_scanners/vadesecure.lua @@ -0,0 +1,351 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module vadesecure +-- This module contains Vadesecure Filterd interface +--]] + +local lua_util = require "lua_util" +local http = require "rspamd_http" +local upstream_list = require "rspamd_upstream_list" +local rspamd_logger = require "rspamd_logger" +local ucl = require "ucl" +local common = require "lua_scanners/common" + +local N = 'vadesecure' + +local function vade_config(opts) + + local vade_conf = { + name = N, + default_port = 23808, + url = '/api/v1/scan', + use_https = false, + timeout = 5.0, + log_clean = false, + retransmits = 1, + cache_expire = 7200, -- expire redis in 2h + message = '${SCANNER}: spam message found: "${VIRUS}"', + detection_category = "hash", + default_score = 1, + action = false, + log_spamcause = true, + symbol_fail = 'VADE_FAIL', + symbol = 'VADE_CHECK', + settings_outbound = nil, -- Set when there is a settings id for outbound messages + symbols = { + clean = { + symbol = 'VADE_CLEAN', + score = -0.5, + description = 'VadeSecure decided message to be clean' + }, + spam = { + high = { + symbol = 'VADE_SPAM_HIGH', + score = 8.0, + description = 'VadeSecure decided message to be clearly spam' + }, + medium = { + symbol = 'VADE_SPAM_MEDIUM', + score = 5.0, + description = 'VadeSecure decided message to be highly likely spam' + }, + low = { + symbol = 'VADE_SPAM_LOW', + score = 2.0, + description = 'VadeSecure decided message to be likely spam' + }, + }, + malware = { + symbol = 'VADE_MALWARE', + score = 8.0, + description = 'VadeSecure decided message to be malware' + }, + scam = { + symbol = 'VADE_SCAM', + score = 7.0, + description = 'VadeSecure decided message to be scam' + }, + phishing = { + symbol = 'VADE_PHISHING', + score = 8.0, + description = 'VadeSecure decided message to be phishing' + }, + commercial = { + symbol = 'VADE_COMMERCIAL', + score = 0.0, + description = 'VadeSecure decided message to be commercial message' + }, + community = { + symbol = 'VADE_COMMUNITY', + score = 0.0, + description = 'VadeSecure decided message to be community message' + }, + transactional = { + symbol = 'VADE_TRANSACTIONAL', + score = 0.0, + description = 'VadeSecure decided message to be transactional message' + }, + suspect = { + symbol = 'VADE_SUSPECT', + score = 3.0, + description = 'VadeSecure decided message to be suspicious message' + }, + bounce = { + symbol = 'VADE_BOUNCE', + score = 0.0, + description = 'VadeSecure decided message to be bounce message' + }, + other = 'VADE_OTHER', + } + } + + vade_conf = lua_util.override_defaults(vade_conf, opts) + + if not vade_conf.prefix then + vade_conf.prefix = 'rs_' .. vade_conf.name .. '_' + end + + if not vade_conf.log_prefix then + if vade_conf.name:lower() == vade_conf.type:lower() then + vade_conf.log_prefix = vade_conf.name + else + vade_conf.log_prefix = vade_conf.name .. ' (' .. vade_conf.type .. 
')' + end + end + + if not vade_conf.servers and vade_conf.socket then + vade_conf.servers = vade_conf.socket + end + + if not vade_conf.servers then + rspamd_logger.errx(rspamd_config, 'no servers defined') + + return nil + end + + vade_conf.upstreams = upstream_list.create(rspamd_config, + vade_conf.servers, + vade_conf.default_port) + + if vade_conf.upstreams then + lua_util.add_debug_alias('external_services', vade_conf.name) + return vade_conf + end + + rspamd_logger.errx(rspamd_config, 'cannot parse servers %s', + vade_conf['servers']) + return nil +end + +local function vade_check(task, content, digest, rule, maybe_part) + local function vade_check_uncached() + local function vade_url(addr) + local url + if rule.use_https then + url = string.format('https://%s:%d%s', tostring(addr), + rule.default_port, rule.url) + else + url = string.format('http://%s:%d%s', tostring(addr), + rule.default_port, rule.url) + end + + return url + end + + local upstream = rule.upstreams:get_upstream_round_robin() + local addr = upstream:get_addr() + local retransmits = rule.retransmits + + local url = vade_url(addr) + local hdrs = {} + + local helo = task:get_helo() + if helo then + hdrs['X-Helo'] = helo + end + local mail_from = task:get_from('smtp') or {} + if mail_from[1] and #mail_from[1].addr > 1 then + hdrs['X-Mailfrom'] = mail_from[1].addr + end + + local rcpt_to = task:get_recipients('smtp') + if rcpt_to then + hdrs['X-Rcptto'] = {} + for _, r in ipairs(rcpt_to) do + table.insert(hdrs['X-Rcptto'], r.addr) + end + end + + local fip = task:get_from_ip() + if fip and fip:is_valid() then + hdrs['X-Inet'] = tostring(fip) + end + + if rule.settings_outbound then + local settings_id = task:get_settings_id() + + if settings_id then + local lua_settings = require "lua_settings" + -- Convert to string + settings_id = lua_settings.settings_by_id(settings_id) + + if settings_id then + settings_id = settings_id.name or '' + + if settings_id == rule.settings_outbound then + lua_util.debugm(rule.name, task, '%s settings has matched outbound', + settings_id) + hdrs['X-Params'] = 'mode=smtpout' + end + end + end + end + + local request_data = { + task = task, + url = url, + body = task:get_content(), + headers = hdrs, + timeout = rule.timeout, + } + + local function vade_callback(http_err, code, body, headers) + + local function vade_requery() + -- set current upstream to fail because an error occurred + upstream:fail() + + -- retry with another upstream until retransmits exceeds + if retransmits > 0 then + + retransmits = retransmits - 1 + + lua_util.debugm(rule.name, task, + '%s: Request Error: %s - retries left: %s', + rule.log_prefix, http_err, retransmits) + + -- Select a different upstream! + upstream = rule.upstreams:get_upstream_round_robin() + addr = upstream:get_addr() + url = vade_url(addr) + + lua_util.debugm(rule.name, task, '%s: retry IP: %s:%s', + rule.log_prefix, addr, addr:get_port()) + request_data.url = url + + http.request(request_data) + else + rspamd_logger.errx(task, '%s: failed to scan, maximum retransmits ' .. + 'exceed', rule.log_prefix) + task:insert_result(rule['symbol_fail'], 0.0, 'failed to scan and ' .. + 'retransmits exceed') + end + end + + if http_err then + vade_requery() + else + -- Parse the response + if upstream then + upstream:ok() + end + if code ~= 200 then + rspamd_logger.errx(task, 'invalid HTTP code: %s, body: %s, headers: %s', code, body, headers) + task:insert_result(rule.symbol_fail, 1.0, 'Bad HTTP code: ' .. 
code) + return + end + local parser = ucl.parser() + local ret, err = parser:parse_string(body) + if not ret then + rspamd_logger.errx(task, 'vade: bad response body (raw): %s', body) + task:insert_result(rule.symbol_fail, 1.0, 'Parser error: ' .. err) + return + end + local obj = parser:get_object() + local verdict = obj.verdict + if not verdict then + rspamd_logger.errx(task, 'vade: bad response JSON (no verdict): %s', obj) + task:insert_result(rule.symbol_fail, 1.0, 'No verdict/unknown verdict') + return + end + local vparts = lua_util.str_split(verdict, ":") + verdict = table.remove(vparts, 1) or verdict + + local sym = rule.symbols[verdict] + if not sym then + sym = rule.symbols.other + end + + if not sym.symbol then + -- Subcategory match + local lvl = 'low' + if vparts and vparts[1] then + lvl = vparts[1] + end + + if sym[lvl] then + sym = sym[lvl] + else + sym = rule.symbols.other + end + end + + local opts = {} + if obj.score then + table.insert(opts, 'score=' .. obj.score) + end + if obj.elapsed then + table.insert(opts, 'elapsed=' .. obj.elapsed) + end + + if rule.log_spamcause and obj.spamcause then + rspamd_logger.infox(task, 'vadesecure verdict="%s", score=%s, spamcause="%s", message-id="%s"', + verdict, obj.score, obj.spamcause, task:get_message_id()) + else + lua_util.debugm(rule.name, task, 'vadesecure returned verdict="%s", score=%s, spamcause="%s"', + verdict, obj.score, obj.spamcause) + end + + if #vparts > 0 then + table.insert(opts, 'verdict=' .. verdict .. ';' .. table.concat(vparts, ':')) + end + + task:insert_result(sym.symbol, 1.0, opts) + end + end + + request_data.callback = vade_callback + http.request(request_data) + end + + if common.condition_check_and_continue(task, content, rule, digest, + vade_check_uncached, maybe_part) then + return + else + vade_check_uncached() + end + +end + +return { + type = { 'vadesecure', 'scanner' }, + description = 'VadeSecure Filterd interface', + configure = vade_config, + check = vade_check, + name = N +} diff --git a/lualib/lua_scanners/virustotal.lua b/lualib/lua_scanners/virustotal.lua new file mode 100644 index 0000000..d937c41 --- /dev/null +++ b/lualib/lua_scanners/virustotal.lua @@ -0,0 +1,214 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +--[[[ +-- @module virustotal +-- This module contains Virustotal integration support +-- https://www.virustotal.com/ +--]] + +local lua_util = require "lua_util" +local http = require "rspamd_http" +local rspamd_cryptobox_hash = require "rspamd_cryptobox_hash" +local rspamd_logger = require "rspamd_logger" +local common = require "lua_scanners/common" + +local N = 'virustotal' + +local function virustotal_config(opts) + + local default_conf = { + name = N, + url = 'https://www.virustotal.com/vtapi/v2/file', + timeout = 5.0, + log_clean = false, + retransmits = 1, + cache_expire = 7200, -- expire redis in 2h + message = '${SCANNER}: spam message found: "${VIRUS}"', + detection_category = "virus", + default_score = 1, + action = false, + scan_mime_parts = true, + scan_text_mime = false, + scan_image_mime = false, + apikey = nil, -- Required to set by user + -- Specific for virustotal + minimum_engines = 3, -- Minimum required to get scored + full_score_engines = 7, -- After this number we set max score + } + + default_conf = lua_util.override_defaults(default_conf, opts) + + if not default_conf.prefix then + default_conf.prefix = 'rs_' .. default_conf.name .. '_' + end + + if not default_conf.log_prefix then + if default_conf.name:lower() == default_conf.type:lower() then + default_conf.log_prefix = default_conf.name + else + default_conf.log_prefix = default_conf.name .. ' (' .. default_conf.type .. ')' + end + end + + if not default_conf.apikey then + rspamd_logger.errx(rspamd_config, 'no apikey defined for virustotal, disable checks') + + return nil + end + + lua_util.add_debug_alias('external_services', default_conf.name) + return default_conf +end + +local function virustotal_check(task, content, digest, rule, maybe_part) + local function virustotal_check_uncached() + local function make_url(hash) + return string.format('%s/report?apikey=%s&resource=%s', + rule.url, rule.apikey, hash) + end + + local hash = rspamd_cryptobox_hash.create_specific('md5') + hash:update(content) + hash = hash:hex() + + local url = make_url(hash) + lua_util.debugm(N, task, "send request %s", url) + local request_data = { + task = task, + url = url, + timeout = rule.timeout, + } + + local function vt_http_callback(http_err, code, body, headers) + if http_err then + rspamd_logger.errx(task, 'HTTP error: %s, body: %s, headers: %s', http_err, body, headers) + else + local cached + local dyn_score + -- Parse the response + if code ~= 200 then + if code == 404 then + cached = 'OK' + if rule['log_clean'] then + rspamd_logger.infox(task, '%s: hash %s clean (not found)', + rule.log_prefix, hash) + else + lua_util.debugm(rule.name, task, '%s: hash %s clean (not found)', + rule.log_prefix, hash) + end + elseif code == 204 then + -- Request rate limit exceeded + rspamd_logger.infox(task, 'virustotal request rate limit exceeded') + task:insert_result(rule.symbol_fail, 1.0, 'rate limit exceeded') + return + else + rspamd_logger.errx(task, 'invalid HTTP code: %s, body: %s, headers: %s', code, body, headers) + task:insert_result(rule.symbol_fail, 1.0, 'Bad HTTP code: ' .. 
code) + return + end + else + local ucl = require "ucl" + local parser = ucl.parser() + local res, json_err = parser:parse_string(body) + + lua_util.debugm(rule.name, task, '%s: got reply data: "%s"', + rule.log_prefix, body) + + if res then + local obj = parser:get_object() + if not obj.positives or type(obj.positives) ~= 'number' then + if obj.response_code then + if obj.response_code == 0 then + cached = 'OK' + if rule['log_clean'] then + rspamd_logger.infox(task, '%s: hash %s clean (not found)', + rule.log_prefix, hash) + else + lua_util.debugm(rule.name, task, '%s: hash %s clean (not found)', + rule.log_prefix, hash) + end + else + rspamd_logger.errx(task, 'invalid JSON reply: %s, body: %s, headers: %s', + 'bad response code: ' .. tostring(obj.response_code), body, headers) + task:insert_result(rule.symbol_fail, 1.0, 'Bad JSON reply: no `positives` element') + return + end + else + rspamd_logger.errx(task, 'invalid JSON reply: %s, body: %s, headers: %s', + 'no response_code', body, headers) + task:insert_result(rule.symbol_fail, 1.0, 'Bad JSON reply: no `positives` element') + return + end + else + if obj.positives < rule.minimum_engines then + lua_util.debugm(rule.name, task, '%s: hash %s has not enough hits: %s where %s is min', + rule.log_prefix, obj.positives, rule.minimum_engines) + -- TODO: add proper hashing! + cached = 'OK' + else + if obj.positives > rule.full_score_engines then + dyn_score = 1.0 + else + local norm_pos = obj.positives - rule.minimum_engines + dyn_score = norm_pos / (rule.full_score_engines - rule.minimum_engines) + end + + if dyn_score < 0 or dyn_score > 1 then + dyn_score = 1.0 + end + local sopt = string.format("%s:%s/%s", + hash, obj.positives, obj.total) + common.yield_result(task, rule, sopt, dyn_score, nil, maybe_part) + cached = sopt + end + end + else + -- not res + rspamd_logger.errx(task, 'invalid JSON reply: %s, body: %s, headers: %s', + json_err, body, headers) + task:insert_result(rule.symbol_fail, 1.0, 'Bad JSON reply: ' .. json_err) + return + end + end + + if cached then + common.save_cache(task, digest, rule, cached, dyn_score, maybe_part) + end + end + end + + request_data.callback = vt_http_callback + http.request(request_data) + end + + if common.condition_check_and_continue(task, content, rule, digest, + virustotal_check_uncached) then + return + else + + virustotal_check_uncached() + end + +end + +return { + type = 'antivirus', + description = 'Virustotal integration', + configure = virustotal_config, + check = virustotal_check, + name = N +} diff --git a/lualib/lua_selectors/common.lua b/lualib/lua_selectors/common.lua new file mode 100644 index 0000000..7b2372d --- /dev/null +++ b/lualib/lua_selectors/common.lua @@ -0,0 +1,95 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +local ts = require("tableshape").types +local exports = {} +local cr_hash = require 'rspamd_cryptobox_hash' + +local blake2b_key = cr_hash.create_specific('blake2'):update('rspamd'):bin() + +local function digest_schema() + return { ts.one_of { 'hex', 'base32', 'bleach32', 'rbase32', 'base64' }:is_optional(), + ts.one_of { 'blake2', 'sha256', 'sha1', 'sha512', 'md5' }:is_optional() } +end + +exports.digest_schema = digest_schema + +local function create_raw_digest(data, args) + local ht = args[2] or 'blake2' + + local h + + if ht == 'blake2' then + -- Hack to be compatible with various 'get_digest' methods + h = cr_hash.create_keyed(blake2b_key):update(data) + else + h = cr_hash.create_specific(ht):update(data) + end + + return h +end + +local function encode_digest(h, args) + local encoding = args[1] or 'hex' + + local s + if encoding == 'hex' then + s = h:hex() + elseif encoding == 'base32' then + s = h:base32() + elseif encoding == 'bleach32' then + s = h:base32('bleach') + elseif encoding == 'rbase32' then + s = h:base32('rfc') + elseif encoding == 'base64' then + s = h:base64() + end + + return s +end + +local function create_digest(data, args) + local h = create_raw_digest(data, args) + return encode_digest(h, args) +end + +local function get_cached_or_raw_digest(task, idx, mime_part, args) + if #args == 0 then + -- Optimise as we already have this hash in the API + return mime_part:get_digest() + end + + local ht = args[2] or 'blake2' + local cache_key = 'mp_digest_' .. ht .. tostring(idx) + + local cached = task:cache_get(cache_key) + + if cached then + return encode_digest(cached, args) + end + + local h = create_raw_digest(mime_part:get_content('raw_parsed'), args) + task:cache_set(cache_key, h) + + return encode_digest(h, args) +end + +exports.create_digest = create_digest +exports.create_raw_digest = create_raw_digest +exports.get_cached_or_raw_digest = get_cached_or_raw_digest +exports.encode_digest = encode_digest + +return exports
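+--[[
+-- Usage sketch (illustrative, hypothetical values): a selector such as
+-- `attachments('hex', 'sha256')` passes its argument list straight to the
+-- helpers above, where args[1] picks the encoding and args[2] the hash type.
+--
+--   local common = require "lua_selectors/common"
+--   local args = { 'hex', 'sha256' }
+--   local digest = common.create_digest('example data', args)
+--   -- equivalent to:
+--   -- common.encode_digest(common.create_raw_digest('example data', args), args)
+--]]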
\ No newline at end of file diff --git a/lualib/lua_selectors/extractors.lua b/lualib/lua_selectors/extractors.lua new file mode 100644 index 0000000..81dfa9d --- /dev/null +++ b/lualib/lua_selectors/extractors.lua @@ -0,0 +1,565 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local fun = require 'fun' +local meta_functions = require "lua_meta" +local lua_util = require "lua_util" +local rspamd_url = require "rspamd_url" +local common = require "lua_selectors/common" +local ts = require("tableshape").types +local maps = require "lua_selectors/maps" +local E = {} +local M = "selectors" + +local url_flags_ts = ts.array_of(ts.one_of(lua_util.keys(rspamd_url.flags))):is_optional() + +local function gen_exclude_flags_filter(exclude_flags) + return function(u) + local got_flags = u:get_flags() + for _, flag in ipairs(exclude_flags) do + if got_flags[flag] then + return false + end + end + return true + end +end + +local extractors = { + -- Plain id function + ['id'] = { + ['get_value'] = function(_, args) + if args[1] then + return args[1], 'string' + end + + return '', 'string' + end, + ['description'] = [[Return value from function's argument or an empty string, +For example, `id('Something')` returns a string 'Something']], + ['args_schema'] = { ts.string:is_optional() } + }, + -- Similar but for making lists + ['list'] = { + ['get_value'] = function(_, args) + if args[1] then + return fun.map(tostring, args), 'string_list' + end + + return {}, 'string_list' + end, + ['description'] = [[Return a list from function's arguments or an empty list, +For example, `list('foo', 'bar')` returns a list {'foo', 'bar'}]], + }, + -- Get source IP address + ['ip'] = { + ['get_value'] = function(task) + local ip = task:get_ip() + if ip and ip:is_valid() then + return ip, 'userdata' + end + return nil + end, + ['description'] = [[Get source IP address]], + }, + -- Get MIME from + ['from'] = { + ['get_value'] = function(task, args) + local from + if type(args) == 'table' then + from = task:get_from(args) + else + from = task:get_from(0) + end + if ((from or E)[1] or E).addr then + return from[1], 'table' + end + return nil + end, + ['description'] = [[Get MIME or SMTP from (e.g. `from('smtp')` or `from('mime')`, +uses any type by default)]], + }, + ['rcpts'] = { + ['get_value'] = function(task, args) + local rcpts + if type(args) == 'table' then + rcpts = task:get_recipients(args) + else + rcpts = task:get_recipients(0) + end + if ((rcpts or E)[1] or E).addr then + return rcpts, 'table_list' + end + return nil + end, + ['description'] = [[Get MIME or SMTP rcpts (e.g. 
`rcpts('smtp')` or `rcpts('mime')`, +uses any type by default)]], + }, + -- Get country (ASN module must be executed first) + ['country'] = { + ['get_value'] = function(task) + local country = task:get_mempool():get_variable('country') + if not country then + return nil + else + return country, 'string' + end + end, + ['description'] = [[Get country (ASN module must be executed first)]], + }, + -- Get ASN number + ['asn'] = { + ['type'] = 'string', + ['get_value'] = function(task) + local asn = task:get_mempool():get_variable('asn') + if not asn then + return nil + else + return asn, 'string' + end + end, + ['description'] = [[Get AS number (ASN module must be executed first)]], + }, + -- Get authenticated username + ['user'] = { + ['get_value'] = function(task) + local auser = task:get_user() + if not auser then + return nil + else + return auser, 'string' + end + end, + ['description'] = 'Get authenticated user name', + }, + -- Get principal recipient + ['to'] = { + ['get_value'] = function(task) + return task:get_principal_recipient(), 'string' + end, + ['description'] = 'Get principal recipient', + }, + -- Get content digest + ['digest'] = { + ['get_value'] = function(task) + return task:get_digest(), 'string' + end, + ['description'] = 'Get content digest', + }, + -- Get list of all attachments digests + ['attachments'] = { + ['get_value'] = function(task, args) + local parts = task:get_parts() or E + local digests = {} + for i, p in ipairs(parts) do + if p:is_attachment() then + table.insert(digests, common.get_cached_or_raw_digest(task, i, p, args)) + end + end + + if #digests > 0 then + return digests, 'string_list' + end + + return nil + end, + ['description'] = [[Get list of all attachments digests. +The first optional argument is encoding (`hex`, `base32` (and forms `bleach32`, `rbase32`), `base64`), +the second optional argument is optional hash type (`blake2`, `sha256`, `sha1`, `sha512`, `md5`)]], + ['args_schema'] = common.digest_schema() + + }, + -- Get all attachments files + ['files'] = { + ['get_value'] = function(task) + local parts = task:get_parts() or E + local files = {} + + for _, p in ipairs(parts) do + local fname = p:get_filename() + if fname then + table.insert(files, fname) + end + end + + if #files > 0 then + return files, 'string_list' + end + + return nil + end, + ['description'] = 'Get all attachments files', + }, + -- Get languages for text parts + ['languages'] = { + ['get_value'] = function(task) + local text_parts = task:get_text_parts() or E + local languages = {} + + for _, p in ipairs(text_parts) do + local lang = p:get_language() + if lang then + table.insert(languages, lang) + end + end + + if #languages > 0 then + return languages, 'string_list' + end + + return nil + end, + ['description'] = 'Get languages for text parts', + }, + -- Get helo value + ['helo'] = { + ['get_value'] = function(task) + return task:get_helo(), 'string' + end, + ['description'] = 'Get helo value', + }, + -- Get header with the name that is expected as an argument. Returns list of + -- headers with this name + ['header'] = { + ['get_value'] = function(task, args) + local strong = false + if args[2] then + if args[2]:match('strong') then + strong = true + end + + if args[2]:match('full') then + return task:get_header_full(args[1], strong), 'table_list' + end + + return task:get_header(args[1], strong), 'string' + else + return task:get_header(args[1]), 'string' + end + end, + ['description'] = [[Get header with the name that is expected as an argument. 
+The optional second argument accepts list of flags: + - `full`: returns all headers with this name with all data (like task:get_header_full()) + - `strong`: use case sensitive match when matching header's name]], + ['args_schema'] = { ts.string, + (ts.pattern("strong") + ts.pattern("full")):is_optional() } + }, + -- Get list of received headers (returns list of tables) + ['received'] = { + ['get_value'] = function(task, args) + local rh = task:get_received_headers() + if not rh[1] then + return nil + end + if args[1] then + return fun.map(function(r) + return r[args[1]] + end, rh), 'string_list' + end + + return rh, 'table_list' + end, + ['description'] = [[Get list of received headers. +If no arguments specified, returns list of tables. Otherwise, selects a specific element, +e.g. `by_hostname`]], + }, + -- Get all urls + ['urls'] = { + ['get_value'] = function(task, args) + local urls = task:get_urls() + if not urls[1] then + return nil + end + if args[1] then + return fun.map(function(r) + return r[args[1]](r) + end, urls), 'string_list' + end + return urls, 'userdata_list' + end, + ['description'] = [[Get list of all urls. +If no arguments specified, returns list of url objects. Otherwise, calls a specific method, +e.g. `get_tld`]], + }, + -- Get specific urls + ['specific_urls'] = { + ['get_value'] = function(task, args) + local params = args[1] or {} + params.task = task + params.no_cache = true + if params.exclude_flags then + params.filter = gen_exclude_flags_filter(params.exclude_flags) + end + local urls = lua_util.extract_specific_urls(params) + if not urls[1] then + return nil + end + return urls, 'userdata_list' + end, + ['description'] = [[Get most specific urls. Arguments are equal to the Lua API function]], + ['args_schema'] = { ts.shape { + limit = ts.number + ts.string / tonumber, + esld_limit = (ts.number + ts.string / tonumber):is_optional(), + exclude_flags = url_flags_ts, + flags = url_flags_ts, + flags_mode = ts.one_of { 'explicit' }:is_optional(), + prefix = ts.string:is_optional(), + need_content = (ts.boolean + ts.string / lua_util.toboolean):is_optional(), + need_emails = (ts.boolean + ts.string / lua_util.toboolean):is_optional(), + need_images = (ts.boolean + ts.string / lua_util.toboolean):is_optional(), + ignore_redirected = (ts.boolean + ts.string / lua_util.toboolean):is_optional(), + } } + }, + ['specific_urls_filter_map'] = { + ['get_value'] = function(task, args) + local map = maps[args[1]] + if not map then + lua_util.debugm(M, "invalid/unknown map: %s", args[1]) + end + local params = args[2] or {} + params.task = task + params.no_cache = true + if params.exclude_flags then + params.filter = gen_exclude_flags_filter(params.exclude_flags) + end + local urls = lua_util.extract_specific_urls(params) + if not urls[1] then + return nil + end + return fun.filter(function(u) + return map:get_key(tostring(u)) + end, urls), 'userdata_list' + end, + ['description'] = [[Get most specific urls, filtered by some map. 
Arguments are equal to the Lua API function]], + ['args_schema'] = { ts.string, ts.shape { + limit = ts.number + ts.string / tonumber, + esld_limit = (ts.number + ts.string / tonumber):is_optional(), + exclude_flags = url_flags_ts, + flags = url_flags_ts, + flags_mode = ts.one_of { 'explicit' }:is_optional(), + prefix = ts.string:is_optional(), + need_content = (ts.boolean + ts.string / lua_util.toboolean):is_optional(), + need_emails = (ts.boolean + ts.string / lua_util.toboolean):is_optional(), + need_images = (ts.boolean + ts.string / lua_util.toboolean):is_optional(), + ignore_redirected = (ts.boolean + ts.string / lua_util.toboolean):is_optional(), + } } + }, + -- URLs filtered by flags + ['urls_filtered'] = { + ['get_value'] = function(task, args) + local urls = task:get_urls_filtered(args[1], args[2]) + if not urls[1] then + return nil + end + return urls, 'userdata_list' + end, + ['description'] = [[Get list of all urls filtered by flags_include/exclude +(see rspamd_task:get_urls_filtered for description)]], + ['args_schema'] = { ts.array_of { + url_flags_ts:is_optional(), url_flags_ts:is_optional() + } } + }, + -- Get all emails + ['emails'] = { + ['get_value'] = function(task, args) + local urls = task:get_emails() + if not urls[1] then + return nil + end + if args[1] then + return fun.map(function(r) + return r[args[1]](r) + end, urls), 'string_list' + end + return urls, 'userdata_list' + end, + ['description'] = [[Get list of all emails. +If no arguments specified, returns list of url objects. Otherwise, calls a specific method, +e.g. `get_user`]], + }, + -- Get specific pool var. The first argument must be variable name, + -- the second argument is optional and defines the type (string by default) + ['pool_var'] = { + ['get_value'] = function(task, args) + local type = args[2] or 'string' + return task:get_mempool():get_variable(args[1], type), (type) + end, + ['description'] = [[Get specific pool var. The first argument must be variable name, +the second argument is optional and defines the type (string by default)]], + ['args_schema'] = { ts.string, ts.string:is_optional() } + }, + -- Get value of specific key from task cache + ['task_cache'] = { + ['get_value'] = function(task, args) + local val = task:cache_get(args[1]) + if not val then + return + end + if type(val) == 'table' then + if not val[1] then + return + end + return val, 'string_list' + end + return val, 'string' + end, + ['description'] = [[Get value of specific key from task cache. The first argument must be +the key name]], + ['args_schema'] = { ts.string } + }, + -- Get specific HTTP request header. The first argument must be header name. + ['request_header'] = { + ['get_value'] = function(task, args) + local hdr = task:get_request_header(args[1]) + if hdr then + return hdr, 'string' + end + + return nil + end, + ['description'] = [[Get specific HTTP request header. +The first argument must be header name.]], + ['args_schema'] = { ts.string } + }, + -- Get task date, optionally formatted + ['time'] = { + ['get_value'] = function(task, args) + local what = args[1] or 'message' + local dt = task:get_date { format = what, gmt = true } + + if dt then + if args[2] then + -- Should be in format !xxx, as dt is in GMT + return os.date(args[2], dt), 'string' + end + + return tostring(dt), 'string' + end + + return nil + end, + ['description'] = [[Get task timestamp. 
The first argument is type: + - `connect`: connection timestamp (default) + - `message`: timestamp as defined by `Date` header + + The second argument is optional time format, see [os.date](http://pgl.yoyo.org/luai/i/os.date) description]], + ['args_schema'] = { ts.one_of { 'connect', 'message' }:is_optional(), + ts.string:is_optional() } + }, + -- Get text words from a message + ['words'] = { + ['get_value'] = function(task, args) + local how = args[1] or 'stem' + local tp = task:get_text_parts() + + if tp then + local rtype = 'string_list' + if how == 'full' then + rtype = 'table_list' + end + + return lua_util.flatten( + fun.map(function(p) + return p:get_words(how) + end, tp)), rtype + end + + return nil + end, + ['description'] = [[Get words from text parts + - `stem`: stemmed words (default) + - `raw`: raw words + - `norm`: normalised words (lowercased) + - `full`: list of tables + ]], + ['args_schema'] = { ts.one_of { 'stem', 'raw', 'norm', 'full' }:is_optional() }, + }, + -- Get queue ID + ['queueid'] = { + ['get_value'] = function(task) + local queueid = task:get_queue_id() + if queueid then + return queueid, 'string' + end + return nil + end, + ['description'] = [[Get queue ID]], + }, + -- Get ID of the task being processed + ['uid'] = { + ['get_value'] = function(task) + local uid = task:get_uid() + if uid then + return uid, 'string' + end + return nil + end, + ['description'] = [[Get ID of the task being processed]], + }, + -- Get message ID of the task being processed + ['messageid'] = { + ['get_value'] = function(task) + local mid = task:get_message_id() + if mid then + return mid, 'string' + end + return nil + end, + ['description'] = [[Get message ID]], + }, + -- Get specific symbol + ['symbol'] = { + ['get_value'] = function(task, args) + local symbol = task:get_symbol(args[1], args[2]) + if symbol then + return symbol[1], 'table' + end + end, + ['description'] = 'Get specific symbol. The first argument must be the symbol name. ' .. + 'The second argument is an optional shadow result name. ' .. + 'Returns the symbol table. See task:get_symbol()', + ['args_schema'] = { ts.string, ts.string:is_optional() } + }, + -- Get full scan result + ['scan_result'] = { + ['get_value'] = function(task, args) + local res = task:get_metric_result(args[1]) + if res then + return res, 'table' + end + end, + ['description'] = 'Get full scan result (either default or shadow if shadow result name is specified)' .. + 'Returns the result table. See task:get_metric_result()', + ['args_schema'] = { ts.string:is_optional() } + }, + -- Get list of metatokens as strings + ['metatokens'] = { + ['get_value'] = function(task) + local tokens = meta_functions.gen_metatokens(task) + if not tokens[1] then + return nil + end + local res = {} + for _, t in ipairs(tokens) do + table.insert(res, tostring(t)) + end + return res, 'string_list' + end, + ['description'] = 'Get metatokens for a message as strings', + }, +} + +return extractors diff --git a/lualib/lua_selectors/init.lua b/lualib/lua_selectors/init.lua new file mode 100644 index 0000000..5fcdb38 --- /dev/null +++ b/lualib/lua_selectors/init.lua @@ -0,0 +1,668 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +-- This module contains 'selectors' implementation: code to extract data +-- from Rspamd tasks and compose those together +-- +-- Read more at https://rspamd.com/doc/configuration/selectors.html + +--[[[ +-- @module lua_selectors +-- This module contains 'selectors' implementation: code to extract data +-- from Rspamd tasks and compose those together. +-- Typical selector looks like this: header(User).lower.substring(1, 2):ip +--]] + +local exports = { + maps = require "lua_selectors/maps" +} + +local logger = require 'rspamd_logger' +local fun = require 'fun' +local lua_util = require "lua_util" +local M = "selectors" +local rspamd_text = require "rspamd_text" +local unpack_function = table.unpack or unpack +local E = {} + +local extractors = require "lua_selectors/extractors" +local transform_function = require "lua_selectors/transforms" + +local text_cookie = rspamd_text.cookie + +local function pure_type(ltype) + return ltype:match('^(.*)_list$') +end + +local function implicit_tostring(t, ud_or_table) + if t == 'table' then + -- Table (very special) + if ud_or_table.value then + return ud_or_table.value, 'string' + elseif ud_or_table.addr then + return ud_or_table.addr, 'string' + end + + return logger.slog("%s", ud_or_table), 'string' + elseif (t == 'string' or t == 'text') and type(ud_or_table) == 'userdata' then + if ud_or_table.cookie and ud_or_table.cookie == text_cookie then + -- Preserve opaque + return ud_or_table, 'string' + else + return tostring(ud_or_table), 'string' + end + elseif t ~= 'nil' then + return tostring(ud_or_table), 'string' + end + + return nil +end + +local function process_selector(task, sel) + local function allowed_type(t) + if t == 'string' or t == 'string_list' then + return true + end + + return false + end + + local function list_type(t) + return pure_type(t) + end + + local input, etype = sel.selector.get_value(task, sel.selector.args) + + if not input then + lua_util.debugm(M, task, 'no value extracted for %s', sel.selector.name) + return nil + end + + lua_util.debugm(M, task, 'extracted %s, type %s', + sel.selector.name, etype) + + local pipe = sel.processor_pipe or E + local first_elt = pipe[1] + + if first_elt and (first_elt.method or + fun.any(function(t) + return t == 'userdata' or t == 'table' + end, first_elt.types)) then + -- Explicit conversion + local meth = first_elt + + if meth.types[etype] then + lua_util.debugm(M, task, 'apply method `%s` to %s', + meth.name, etype) + input, etype = meth.process(input, etype, meth.args) + else + local pt = pure_type(etype) + + if meth.types[pt] then + lua_util.debugm(M, task, 'map method `%s` to list of %s', + meth.name, pt) + -- Map method to a list of inputs, excluding empty elements + -- We need to fold it down here to get a proper type resolution + input = fun.totable(fun.filter(function(map_elt, _) + return map_elt + end, + fun.map(function(list_elt) + local ret, ty = meth.process(list_elt, pt, meth.args) + if ret then + etype = ty + end + return ret + end, input))) + if input and etype then + etype = etype .. 
"_list" + else + input = nil + end + end + end + -- Remove method from the pipeline + pipe = fun.drop_n(1, pipe) + elseif etype:match('^userdata') or etype:match('^table') then + -- Implicit conversion + local pt = pure_type(etype) + + if not pt then + lua_util.debugm(M, task, 'apply implicit conversion %s->string', etype) + input = implicit_tostring(etype, input) + etype = 'string' + else + lua_util.debugm(M, task, 'apply implicit map %s->string', pt) + input = fun.filter(function(map_elt) + return map_elt + end, + fun.map(function(list_elt) + local ret = implicit_tostring(pt, list_elt) + return ret + end, input)) + etype = 'string_list' + end + else + lua_util.debugm(M, task, 'avoid implicit conversion as the transformer accepts complex input') + end + + -- Now we fold elements using left fold + local function fold_function(acc, x) + if acc == nil or acc[1] == nil then + lua_util.debugm(M, task, 'do not apply %s, accumulator is nil', x.name) + return nil + end + + local value = acc[1] + local t = acc[2] + + if not x.types[t] then + local pt = pure_type(t) + + if pt and x.types['list'] then + -- Generic list processor + lua_util.debugm(M, task, 'apply list function `%s` to %s', x.name, t) + return { x.process(value, t, x.args) } + elseif pt and x.map_type and x.types[pt] then + local map_type = x.map_type .. '_list' + lua_util.debugm(M, task, 'map `%s` to list of %s resulting %s', + x.name, pt, map_type) + -- Apply map, filtering empty values + return { + fun.filter(function(map_elt) + return map_elt + end, + fun.map(function(list_elt) + if not list_elt then + return nil + end + local ret, _ = x.process(list_elt, pt, x.args) + return ret + end, value)), + map_type -- Returned type + } + end + logger.errx(task, 'cannot apply transform %s for type %s', x.name, t) + return nil + end + + lua_util.debugm(M, task, 'apply %s to %s', x.name, t) + return { x.process(value, t, x.args) } + end + + local res = fun.foldl(fold_function, + { input, etype }, + pipe) + + if not res or not res[1] then + return nil + end -- Pipeline failed + + if not allowed_type(res[2]) then + -- Search for implicit conversion + local pt = pure_type(res[2]) + + if pt then + lua_util.debugm(M, task, 'apply implicit map %s->string_list', pt) + res[1] = fun.map(function(e) + return implicit_tostring(pt, e) + end, res[1]) + res[2] = 'string_list' + else + res[1] = implicit_tostring(res[2], res[1]) + res[2] = 'string' + end + end + + if list_type(res[2]) then + -- Convert to table as it might have a functional form + res[1] = fun.totable(res[1]) + end + + lua_util.debugm(M, task, 'final selector type: %s, value: %s', res[2], res[1]) + + return res[1] +end + +local function make_grammar() + local l = require "lpeg" + local spc = l.S(" \t\n") ^ 0 + local cont = l.R("\128\191") -- continuation byte + local utf8_high = l.R("\194\223") * cont + + l.R("\224\239") * cont * cont + + l.R("\240\244") * cont * cont * cont + local atom_start = (l.R("az") + l.R("AZ") + l.R("09") + utf8_high + l.S "-") ^ 1 + local atom_end = (l.R("az") + l.R("AZ") + l.R("09") + l.S "-_" + utf8_high) ^ 1 + local atom_mid = (1 - l.S("'\r\n\f\\,)(}{= " .. 
'"')) ^ 1 + local atom_argument = l.C(atom_start * atom_mid ^ 0 * atom_end ^ 0) -- We allow more characters for the arguments + local atom = l.C(atom_start * atom_end ^ 0) -- We are more strict about selector names itself + local singlequoted_string = l.P "'" * l.C(((1 - l.S "'\r\n\f\\") + (l.P '\\' * 1)) ^ 0) * "'" + local doublequoted_string = l.P '"' * l.C(((1 - l.S '"\r\n\f\\') + (l.P '\\' * 1)) ^ 0) * '"' + local argument = atom_argument + singlequoted_string + doublequoted_string + local dot = l.P(".") + local semicolon = l.P(":") + local obrace = "(" * spc + local tbl_obrace = "{" * spc + local eqsign = spc * "=" * spc + local tbl_ebrace = spc * "}" + local ebrace = spc * ")" + local comma = spc * "," * spc + local sel_separator = spc * l.S ";*" * spc + + return l.P { + "LIST"; + LIST = l.Ct(l.V("EXPR")) * (sel_separator * l.Ct(l.V("EXPR"))) ^ 0, + EXPR = l.V("FUNCTION") * (semicolon * l.V("METHOD")) ^ -1 * (dot * l.V("PROCESSOR")) ^ 0, + PROCESSOR = l.Ct(atom * spc * (obrace * l.V("ARG_LIST") * ebrace) ^ 0), + FUNCTION = l.Ct(atom * spc * (obrace * l.V("ARG_LIST") * ebrace) ^ 0), + METHOD = l.Ct(atom / function(e) + return '__' .. e + end * spc * (obrace * l.V("ARG_LIST") * ebrace) ^ 0), + ARG_LIST = l.Ct((l.V("ARG") * comma ^ 0) ^ 0), + ARG = l.Cf(tbl_obrace * l.V("NAMED_ARG") * tbl_ebrace, rawset) + argument + l.V("LIST_ARGS"), + NAMED_ARG = (l.Ct("") * l.Cg(argument * eqsign * (argument + l.V("LIST_ARGS")) * comma ^ 0) ^ 0), + LIST_ARGS = l.Ct(tbl_obrace * l.V("LIST_ARG") * tbl_ebrace), + LIST_ARG = l.Cg(argument * comma ^ 0) ^ 0, + } +end + +local parser = make_grammar() + +--[[[ +-- @function lua_selectors.parse_selector(cfg, str) +--]] +exports.parse_selector = function(cfg, str) + local parsed = { parser:match(str) } + local output = {} + + if not parsed or not parsed[1] then + return nil + end + + local function check_args(name, schema, args) + if schema then + if getmetatable(schema) then + -- Schema covers all arguments + local res, err = schema:transform(args) + if not res then + logger.errx(rspamd_config, 'invalid arguments for %s: %s', name, err) + return false + else + for i, elt in ipairs(res) do + args[i] = elt + end + end + else + for i, selt in ipairs(schema) do + local res, err = selt:transform(args[i]) + + if err then + logger.errx(rspamd_config, 'invalid arguments for %s: argument number: %s, error: %s', name, i, err) + return false + else + args[i] = res + end + end + end + end + + return true + end + + -- Output AST format is the following: + -- table of individual selectors + -- each selector: list of functions + -- each function: function name + optional list of arguments + for _, sel in ipairs(parsed) do + local res = { + selector = {}, + processor_pipe = {}, + } + + local selector_tbl = sel[1] + if not selector_tbl then + logger.errx(cfg, 'no selector represented') + return nil + end + if not extractors[selector_tbl[1]] then + logger.errx(cfg, 'selector %s is unknown', selector_tbl[1]) + return nil + end + + res.selector = lua_util.shallowcopy(extractors[selector_tbl[1]]) + res.selector.name = selector_tbl[1] + res.selector.args = selector_tbl[2] or E + + if not check_args(res.selector.name, + res.selector.args_schema, + res.selector.args) then + return nil + end + + lua_util.debugm(M, cfg, 'processed selector %s, args: %s', + res.selector.name, res.selector.args) + + local pipeline_error = false + -- Now process processors pipe + fun.each(function(proc_tbl) + local proc_name = proc_tbl[1] + + if proc_name:match('^__') then + -- Special case - method + 
local method_name = proc_name:match('^__(.*)$') + -- Check array indexing... + if tonumber(method_name) then + method_name = tonumber(method_name) + end + local processor = { + name = tostring(method_name), + method = true, + args = proc_tbl[2] or E, + types = { + userdata = true, + table = true, + string = true, + }, + map_type = 'string', + process = function(inp, t, args) + local ret + if t == 'table' then + -- Plain table field + ret = inp[method_name] + else + -- We call method unpacking arguments and dropping all but the first result returned + ret = (inp[method_name](inp, unpack_function(args or E))) + end + + local ret_type = type(ret) + + if ret_type == 'nil' then + return nil + end + -- Now apply types heuristic + if ret_type == 'string' then + return ret, 'string' + elseif ret_type == 'table' then + -- TODO: we need to ensure that 1) table is numeric 2) table has merely strings + return ret, 'string_list' + else + return implicit_tostring(ret_type, ret) + end + end, + } + lua_util.debugm(M, cfg, 'attached method %s to selector %s, args: %s', + proc_name, res.selector.name, processor.args) + table.insert(res.processor_pipe, processor) + else + + if not transform_function[proc_name] then + logger.errx(cfg, 'processor %s is unknown', proc_name) + pipeline_error = proc_name + return nil + end + local processor = lua_util.shallowcopy(transform_function[proc_name]) + processor.name = proc_name + processor.args = proc_tbl[2] or E + + if not check_args(processor.name, processor.args_schema, processor.args) then + pipeline_error = 'args schema for ' .. proc_name + return nil + end + + lua_util.debugm(M, cfg, 'attached processor %s to selector %s, args: %s', + proc_name, res.selector.name, processor.args) + table.insert(res.processor_pipe, processor) + end + end, fun.tail(sel)) + + if pipeline_error then + logger.errx(cfg, 'unknown or invalid processor used: "%s", exiting', pipeline_error) + return nil + end + + table.insert(output, res) + end + + return output +end + +--[[[ +-- @function lua_selectors.register_extractor(cfg, name, selector) +--]] +exports.register_extractor = function(cfg, name, selector) + if selector.get_value then + if extractors[name] then + logger.warnx(cfg, 'redefining selector %s', name) + end + extractors[name] = selector + + return true + end + + logger.errx(cfg, 'bad selector %s', name) + return false +end + +--[[[ +-- @function lua_selectors.register_transform(cfg, name, transform) +--]] +exports.register_transform = function(cfg, name, transform) + if transform.process and transform.types then + if transform_function[name] then + logger.warnx(cfg, 'redefining transform function %s', name) + end + transform_function[name] = transform + + return true + end + + logger.errx(cfg, 'bad transform function %s', name) + return false +end + +--[[[ +-- @function lua_selectors.process_selectors(task, selectors_pipe) +--]] +exports.process_selectors = function(task, selectors_pipe) + local ret = {} + + for _, sel in ipairs(selectors_pipe) do + local r = process_selector(task, sel) + + -- If any element is nil, then the whole selector is nil + if not r then + return nil + end + table.insert(ret, r) + end + + return ret +end + +--[[[ +-- @function lua_selectors.combine_selectors(task, selectors, delimiter) +--]] +exports.combine_selectors = function(_, selectors, delimiter) + if not delimiter then + delimiter = '' + end + + if not selectors then + return nil + end + + local have_tables, have_userdata + + for _, s in ipairs(selectors) do + if type(s) == 'table' then + 
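+      -- A table element means this selector produced a list of values, so the
+      -- combination below has to expand into a Cartesian product rather than a
+      -- plain concatenation of scalar strings.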
have_tables = true + elseif type(s) == 'userdata' then + have_userdata = true + end + end + + if not have_tables then + if not have_userdata then + return table.concat(selectors, delimiter) + else + return rspamd_text.fromtable(selectors, delimiter) + end + else + -- We need to do a spill on each table selector and make a cortesian product + -- e.g. s:tbl:s -> s:telt1:s + s:telt2:s ... + local tbl = {} + local res = {} + + for i, s in ipairs(selectors) do + if type(s) == 'string' then + rawset(tbl, i, fun.duplicate(s)) + elseif type(s) == 'userdata' then + rawset(tbl, i, fun.duplicate(tostring(s))) + else + -- Raw table + rawset(tbl, i, fun.map(tostring, s)) + end + end + + fun.each(function(...) + table.insert(res, table.concat({ ... }, delimiter)) + end, fun.zip(lua_util.unpack(tbl))) + + return res + end +end + +--[[[ +-- @function lua_selectors.flatten_selectors(selectors) +-- Convert selectors to a flat table of elements +--]] +exports.flatten_selectors = function(_, selectors, _) + local res = {} + + local function fill(tbl) + for _, s in ipairs(tbl) do + if type(s) == 'string' then + rawset(res, #res + 1, s) + elseif type(s) == 'userdata' then + rawset(res, #res + 1, tostring(s)) + else + fill(s) + end + end + end + + fill(selectors) + + return res +end + +--[[[ +-- @function lua_selectors.kv_table_from_pairs(selectors) +-- Convert selectors to a table where the odd elements are keys and even are elements +-- Similarly to make a map from (k, v) pairs list +-- To specify the concrete constant keys, one can use the `id` extractor +--]] +exports.kv_table_from_pairs = function(log_obj, selectors, _) + local res = {} + local rspamd_logger = require "rspamd_logger" + + local function fill(tbl) + local tbl_len = #tbl + if tbl_len % 2 ~= 0 or tbl_len == 0 then + rspamd_logger.errx(log_obj, "invalid invocation of the `kv_table_from_pairs`: table length is invalid %s", + tbl_len) + return + end + for i = 1, tbl_len, 2 do + local k = tostring(tbl[i]) + local v = tbl[i + 1] + if type(v) == 'string' then + res[k] = v + elseif type(v) == 'userdata' then + res[k] = tostring(v) + else + res[k] = fun.totable(fun.map(function(elt) + return tostring(elt) + end, v)) + end + end + end + + fill(selectors) + + return res +end + + +--[[[ +-- @function lua_selectors.create_closure(log_obj, cfg, selector_str, delimiter, fn) +-- Creates a closure from a string selector, using the specific combinator function +--]] +exports.create_selector_closure_fn = function(log_obj, cfg, selector_str, delimiter, fn) + local selector = exports.parse_selector(cfg, selector_str) + + if not selector then + return nil + end + + return function(task) + local res = exports.process_selectors(task, selector) + + if res then + return fn(log_obj, res, delimiter) + end + + return nil + end +end + +--[[[ +-- @function lua_selectors.create_closure(cfg, selector_str, delimiter='', flatten=false) +-- Creates a closure from a string selector +--]] +exports.create_selector_closure = function(cfg, selector_str, delimiter, flatten) + local combinator_fn = flatten and exports.flatten_selectors or exports.combine_selectors + + return exports.create_selector_closure_fn(nil, cfg, selector_str, delimiter, combinator_fn) +end + +local function display_selectors(tbl) + return fun.tomap(fun.map(function(k, v) + return k, fun.tomap(fun.filter(function(kk, vv) + return type(vv) ~= 'function' + end, v)) + end, tbl)) +end + +exports.list_extractors = function() + return display_selectors(extractors) +end + +exports.list_transforms = function() + return 
display_selectors(transform_function) +end + +exports.add_map = function(name, map) + if not exports.maps[name] then + exports.maps[name] = map + else + logger.errx(rspamd_config, "duplicate map redefinition for the selectors: %s", name) + end +end + +-- Publish log target +exports.M = M + +return exports diff --git a/lualib/lua_selectors/maps.lua b/lualib/lua_selectors/maps.lua new file mode 100644 index 0000000..85b54a6 --- /dev/null +++ b/lualib/lua_selectors/maps.lua @@ -0,0 +1,19 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local maps = {} -- Shared within selectors, indexed by name + +return maps
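--[[
A minimal usage sketch of the registration and closure API above. The
extractor and transform names ('subject_len', 'reverse') and the selector
string are illustrative assumptions, not part of rspamd's built-in set; the
table fields follow what register_extractor/register_transform expect.
]]--
local lua_selectors = require "lua_selectors"

-- Custom extractor: must provide get_value, returning the value and its type
lua_selectors.register_extractor(rspamd_config, 'subject_len', {
  get_value = function(task)
    local subj = task:get_header('Subject') or ''
    return tostring(#subj), 'string'
  end,
  description = 'Length of the Subject header as a string',
})

-- Custom transform: must provide process and the accepted input types
lua_selectors.register_transform(rspamd_config, 'reverse', {
  types = { ['string'] = true },
  map_type = 'string',
  process = function(inp, t)
    return inp:reverse(), t
  end,
  description = 'Reverses the input string',
})

-- Build a closure that applies the whole pipeline to a task
local sel_fn = lua_selectors.create_selector_closure(rspamd_config,
    'subject_len.reverse', '', false)
-- sel_fn(task) returns the resulting string, or nil if any stage yields nil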
\ No newline at end of file diff --git a/lualib/lua_selectors/transforms.lua b/lualib/lua_selectors/transforms.lua new file mode 100644 index 0000000..6c6bc71 --- /dev/null +++ b/lualib/lua_selectors/transforms.lua @@ -0,0 +1,571 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local fun = require 'fun' +local lua_util = require "lua_util" +local rspamd_util = require "rspamd_util" +local ts = require("tableshape").types +local logger = require 'rspamd_logger' +local common = require "lua_selectors/common" +local M = "selectors" + +local maps = require "lua_selectors/maps" + +local function pure_type(ltype) + return ltype:match('^(.*)_list$') +end + +local transform_function = { + -- Returns the lowercased string + ['lower'] = { + ['types'] = { + ['string'] = true, + }, + ['map_type'] = 'string', + ['process'] = function(inp, _) + return inp:lower(), 'string' + end, + ['description'] = 'Returns the lowercased string', + }, + -- Returns the lowercased utf8 string + ['lower_utf8'] = { + ['types'] = { + ['string'] = true, + }, + ['map_type'] = 'string', + ['process'] = function(inp, t) + return rspamd_util.lower_utf8(inp), t + end, + ['description'] = 'Returns the lowercased utf8 string', + }, + -- Returns the first element + ['first'] = { + ['types'] = { + ['list'] = true, + }, + ['process'] = function(inp, t) + return fun.head(inp), pure_type(t) + end, + ['description'] = 'Returns the first element', + }, + -- Returns the last element + ['last'] = { + ['types'] = { + ['list'] = true, + }, + ['process'] = function(inp, t) + return fun.nth(fun.length(inp), inp), pure_type(t) + end, + ['description'] = 'Returns the last element', + }, + -- Returns the nth element + ['nth'] = { + ['types'] = { + ['list'] = true, + }, + ['process'] = function(inp, t, args) + return fun.nth(args[1] or 1, inp), pure_type(t) + end, + ['description'] = 'Returns the nth element', + ['args_schema'] = { ts.number + ts.string / tonumber } + }, + ['take_n'] = { + ['types'] = { + ['list'] = true, + }, + ['process'] = function(inp, t, args) + return fun.take_n(args[1] or 1, inp), t + end, + ['description'] = 'Returns the n first elements', + ['args_schema'] = { ts.number + ts.string / tonumber } + }, + ['drop_n'] = { + ['types'] = { + ['list'] = true, + }, + ['process'] = function(inp, t, args) + return fun.drop_n(args[1] or 1, inp), t + end, + ['description'] = 'Returns list without the first n elements', + ['args_schema'] = { ts.number + ts.string / tonumber } + }, + -- Joins strings into a single string using separator in the argument + ['join'] = { + ['types'] = { + ['string_list'] = true + }, + ['process'] = function(inp, _, args) + return table.concat(fun.totable(inp), args[1] or ''), 'string' + end, + ['description'] = 'Joins strings into a single string using separator in the argument', + ['args_schema'] = { ts.string:is_optional() } + }, + -- Joins strings into a set of strings using N elements and a separator in the argument + ['join_nth'] = { + ['types'] = { + 
['string_list'] = true + }, + ['process'] = function(inp, _, args) + local step = args[1] + local sep = args[2] or '' + local inp_t = fun.totable(inp) + local res = {} + + for i = 1, #inp_t, step do + table.insert(res, table.concat(inp_t, sep, i, i + step)) + end + return res, 'string_list' + end, + ['description'] = 'Joins strings into a set of strings using N elements and a separator in the argument', + ['args_schema'] = { ts.number + ts.string / tonumber, ts.string:is_optional() } + }, + -- Joins tables into a table of strings + ['join_tables'] = { + ['types'] = { + ['list'] = true + }, + ['process'] = function(inp, _, args) + local sep = args[1] or '' + return fun.map(function(t) + return table.concat(t, sep) + end, inp), 'string_list' + end, + ['description'] = 'Joins tables into a table of strings', + ['args_schema'] = { ts.string:is_optional() } + }, + -- Sort strings + ['sort'] = { + ['types'] = { + ['list'] = true + }, + ['process'] = function(inp, t, _) + table.sort(inp) + return inp, t + end, + ['description'] = 'Sort strings lexicographically', + }, + -- Return unique elements based on hashing (can work without sorting) + ['uniq'] = { + ['types'] = { + ['list'] = true + }, + ['process'] = function(inp, t, _) + local tmp = {} + fun.each(function(val) + tmp[val] = true + end, inp) + + return fun.map(function(k, _) + return k + end, tmp), t + end, + ['description'] = 'Returns a list of unique elements (using a hash table)', + }, + -- Create a digest from string or a list of strings + ['digest'] = { + ['types'] = { + ['string'] = true + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + return common.create_digest(inp, args), 'string' + end, + ['description'] = [[Create a digest from a string. +The first argument is encoding (`hex`, `base32` (and forms `bleach32`, `rbase32`), `base64`), +the second argument is optional hash type (`blake2`, `sha256`, `sha1`, `sha512`, `md5`)]], + ['args_schema'] = common.digest_schema() + }, + -- Extracts substring + ['substring'] = { + ['types'] = { + ['string'] = true + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + local start_pos = args[1] or 1 + local end_pos = args[2] or -1 + + return inp:sub(start_pos, end_pos), 'string' + end, + ['description'] = 'Extracts substring; the first argument is start, the second is the last (like in Lua)', + ['args_schema'] = { (ts.number + ts.string / tonumber):is_optional(), + (ts.number + ts.string / tonumber):is_optional() } + }, + -- Prepends a string or a strings list + ['prepend'] = { + ['types'] = { + ['string'] = true + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + local prepend = table.concat(args, '') + + return prepend .. inp, 'string' + end, + ['description'] = 'Prepends a string or a strings list', + }, + -- Appends a string or a strings list + ['append'] = { + ['types'] = { + ['string'] = true + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + local append = table.concat(args, '') + + return inp .. 
append, 'string' + end, + ['description'] = 'Appends a string or a strings list', + }, + -- Regexp matching + ['regexp'] = { + ['types'] = { + ['string'] = true + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + local rspamd_regexp = require "rspamd_regexp" + + local re = rspamd_regexp.create_cached(args[1]) + + if not re then + logger.errx('invalid regexp: %s', args[1]) + return nil + end + + local res = re:search(inp, false, true) + + if res then + -- Map all results in a single list + local flattened_table = {} + local function flatten_table(tbl) + for _, v in ipairs(tbl) do + if type(v) == 'table' then + flatten_table(v) + else + table.insert(flattened_table, v) + end + end + end + flatten_table(res) + return flattened_table, 'string_list' + end + + return nil + end, + ['description'] = 'Regexp matching, returns all matches flattened in a single list', + ['args_schema'] = { ts.string } + }, + -- Returns a value if it exists in some map (or acts like a `filter` function) + ['filter_map'] = { + ['types'] = { + ['string'] = true + }, + ['map_type'] = 'string', + ['process'] = function(inp, t, args) + local map = maps[args[1]] + + if not map then + logger.errx('invalid map name: %s', args[1]) + return nil + end + + local res = map:get_key(inp) + + if res then + return inp, t + end + + return nil + end, + ['description'] = 'Returns a value if it exists in some map (or acts like a `filter` function)', + ['args_schema'] = { ts.string } + }, + -- Returns a value if it exists in some map (or acts like a `filter` function) + ['except_map'] = { + ['types'] = { + ['string'] = true + }, + ['map_type'] = 'string', + ['process'] = function(inp, t, args) + local map = maps[args[1]] + + if not map then + logger.errx('invalid map name: %s', args[1]) + return nil + end + + local res = map:get_key(inp) + + if not res then + return inp, t + end + + return nil + end, + ['description'] = 'Returns a value if it does not exists in some map (or acts like a `except` function)', + ['args_schema'] = { ts.string } + }, + -- Returns a value from some map corresponding to some key (or acts like a `map` function) + ['apply_map'] = { + ['types'] = { + ['string'] = true + }, + ['map_type'] = 'string', + ['process'] = function(inp, t, args) + local map = maps[args[1]] + + if not map then + logger.errx('invalid map name: %s', args[1]) + return nil + end + + local res = map:get_key(inp) + + if res then + return res, t + end + + return nil + end, + ['description'] = 'Returns a value from some map corresponding to some key (or acts like a `map` function)', + ['args_schema'] = { ts.string } + }, + -- Drops input value and return values from function's arguments or an empty string + ['id'] = { + ['types'] = { + ['string'] = true, + ['list'] = true, + }, + ['map_type'] = 'string', + ['process'] = function(_, _, args) + if args[1] and args[2] then + return fun.map(tostring, args), 'string_list' + elseif args[1] then + return args[1], 'string' + end + + return '', 'string' + end, + ['description'] = 'Drops input value and return values from function\'s arguments or an empty string', + ['args_schema'] = (ts.string + ts.array_of(ts.string)):is_optional() + }, + ['equal'] = { + ['types'] = { + ['string'] = true, + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + if inp == args[1] then + return inp, 'string' + end + + return nil + end, + ['description'] = [[Boolean function equal. 
+Returns either nil or its argument if input is equal to argument]], + ['args_schema'] = { ts.string } + }, + -- Boolean function in, returns either nil or its input if input is in args list + ['in'] = { + ['types'] = { + ['string'] = true, + }, + ['map_type'] = 'string', + ['process'] = function(inp, t, args) + for _, a in ipairs(args) do + if a == inp then + return inp, t + end + end + return nil + end, + ['description'] = [[Boolean function in. +Returns either nil or its input if input is in args list]], + ['args_schema'] = ts.array_of(ts.string) + }, + ['not_in'] = { + ['types'] = { + ['string'] = true, + }, + ['map_type'] = 'string', + ['process'] = function(inp, t, args) + for _, a in ipairs(args) do + if a == inp then + return nil + end + end + return inp, t + end, + ['description'] = [[Boolean function not in. +Returns either nil or its input if input is not in args list]], + ['args_schema'] = ts.array_of(ts.string) + }, + ['inverse'] = { + ['types'] = { + ['string'] = true, + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + if inp then + return nil + else + return (args[1] or 'true'), 'string' + end + end, + ['description'] = [[Inverses input. +Empty string comes the first argument or 'true', non-empty string comes nil]], + ['args_schema'] = { ts.string:is_optional() } + }, + ['ipmask'] = { + ['types'] = { + ['string'] = true, + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + local rspamd_ip = require "rspamd_ip" + -- Non optimal: convert string to an IP address + local ip = rspamd_ip.from_string(inp) + + if not ip or not ip:is_valid() then + lua_util.debugm(M, "cannot convert %s to IP", inp) + return nil + end + + if ip:get_version() == 4 then + local mask = tonumber(args[1]) + + return ip:apply_mask(mask):to_string(), 'string' + else + -- IPv6 takes the second argument or the first one... + local mask_str = args[2] or args[1] + local mask = tonumber(mask_str) + + return ip:apply_mask(mask):to_string(), 'string' + end + end, + ['description'] = 'Applies mask to IP address.' .. + ' The first argument is the mask for IPv4 addresses, the second is the mask for IPv6 addresses.', + ['args_schema'] = { (ts.number + ts.string / tonumber), + (ts.number + ts.string / tonumber):is_optional() } + }, + -- Returns the string(s) with all non ascii chars replaced + ['to_ascii'] = { + ['types'] = { + ['string'] = true, + ['list'] = true, + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + if type(inp) == 'table' then + return fun.map( + function(s) + return string.gsub(tostring(s), '[\128-\255]', args[1] or '?') + end, inp), 'string_list' + else + return string.gsub(tostring(inp), '[\128-\255]', '?'), 'string' + end + end, + ['description'] = 'Returns the string with all non-ascii bytes replaced with the character ' .. 
+ 'given as second argument or `?`', + ['args_schema'] = { ts.string:is_optional() } + }, + -- Extracts tld from a hostname + ['get_tld'] = { + ['types'] = { + ['string'] = true + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, _) + return rspamd_util.get_tld(inp), 'string' + end, + ['description'] = 'Extracts tld from a hostname represented as a string', + ['args_schema'] = {} + }, + -- Converts list of strings to numbers and returns a packed string + ['pack_numbers'] = { + ['types'] = { + ['string_list'] = true + }, + ['map_type'] = 'string', + ['process'] = function(inp, _, args) + local fmt = args[1] or 'f' + local res = {} + for _, s in ipairs(inp) do + table.insert(res, tonumber(s)) + end + return rspamd_util.pack(string.rep(fmt, #res), lua_util.unpack(res)), 'string' + end, + ['description'] = 'Converts a list of strings to numbers & returns a packed string', + ['args_schema'] = { ts.string:is_optional() } + }, + -- Filter nils from a list + ['filter_string_nils'] = { + ['types'] = { + ['string_list'] = true + }, + ['process'] = function(inp, _, _) + return fun.filter(function(val) + return type(val) == 'string' and val ~= 'nil' + end, inp), 'string_list' + end, + ['description'] = 'Removes all nils from a list of strings (when converted implicitly)', + ['args_schema'] = {} + }, + -- Call a set of methods on a userdata object + ['apply_methods'] = { + ['types'] = { + ['userdata'] = true, + }, + ['process'] = function(inp, _, args) + local res = {} + for _, arg in ipairs(args) do + local meth = inp[arg] + local ret = meth(inp) + if ret then + table.insert(res, tostring(ret)) + end + end + return res, 'string_list' + end, + ['description'] = 'Apply a list of method calls to the userdata object', + }, + -- Apply method to list of userdata and use it as a filter, excluding elements for which method returns false/nil + ['filter_method'] = { + ['types'] = { + ['userdata_list'] = true + }, + ['process'] = function(inp, t, args) + local meth = args[1] + + if not meth then + logger.errx('invalid method name: %s', args[1]) + return nil + end + + return fun.filter(function(val) + return val[meth](val) + end, inp), 'userdata_list' + end, + ['description'] = 'Apply method to list of userdata and use it as a filter,' .. + ' excluding elements for which method returns false/nil', + ['args_schema'] = { ts.string } + }, +} + +transform_function.match = transform_function.regexp + +return transform_function diff --git a/lualib/lua_settings.lua b/lualib/lua_settings.lua new file mode 100644 index 0000000..d6d24d6 --- /dev/null +++ b/lualib/lua_settings.lua @@ -0,0 +1,309 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +--[[[ +-- @module lua_settings +-- This module contains internal helpers for the settings infrastructure in Rspamd +-- More details at https://rspamd.com/doc/configuration/settings.html +--]] + +local exports = {} +local known_ids = {} +local post_init_added = false +local post_init_performed = false +local all_symbols +local default_symbols + +local fun = require "fun" +local lua_util = require "lua_util" +local rspamd_logger = require "rspamd_logger" + +local function register_settings_cb(from_postload) + if not post_init_performed then + all_symbols = rspamd_config:get_symbols() + + default_symbols = fun.totable(fun.filter(function(_, v) + return not v.allowed_ids or #v.allowed_ids == 0 or v.flags.explicit_disable + end, all_symbols)) + + local explicit_symbols = lua_util.keys(fun.filter(function(k, v) + return v.flags.explicit_disable + end, all_symbols)) + + local symnames = lua_util.list_to_hash(lua_util.keys(all_symbols)) + + for _, set in pairs(known_ids) do + local s = set.settings.apply or {} + set.symbols = lua_util.shallowcopy(symnames) + local enabled_symbols = {} + local seen_enabled = false + local disabled_symbols = {} + local seen_disabled = false + + -- Enabled map + if s.symbols_enabled then + -- Remove all symbols from set.symbols aside of explicit_disable symbols + set.symbols = lua_util.list_to_hash(explicit_symbols) + seen_enabled = true + for _, sym in ipairs(s.symbols_enabled) do + enabled_symbols[sym] = true + set.symbols[sym] = true + end + end + if s.groups_enabled then + seen_enabled = true + for _, gr in ipairs(s.groups_enabled) do + local syms = rspamd_config:get_group_symbols(gr) + + if syms then + for _, sym in ipairs(syms) do + enabled_symbols[sym] = true + set.symbols[sym] = true + end + end + end + end + + -- Disabled map + if s.symbols_disabled then + seen_disabled = true + for _, sym in ipairs(s.symbols_disabled) do + disabled_symbols[sym] = true + set.symbols[sym] = false + end + end + if s.groups_disabled then + seen_disabled = true + for _, gr in ipairs(s.groups_disabled) do + local syms = rspamd_config:get_group_symbols(gr) + + if syms then + for _, sym in ipairs(syms) do + disabled_symbols[sym] = true + set.symbols[sym] = false + end + end + end + end + + -- Deal with complexity to avoid mess in C + if not seen_enabled then + enabled_symbols = nil + end + if not seen_disabled then + disabled_symbols = nil + end + + if enabled_symbols or disabled_symbols then + -- Specify what symbols are really enabled for this settings id + set.has_specific_symbols = true + end + + rspamd_config:register_settings_id(set.name, enabled_symbols, disabled_symbols) + + -- Remove to avoid clash + s.symbols_disabled = nil + s.symbols_enabled = nil + s.groups_enabled = nil + s.groups_disabled = nil + end + + -- We now iterate over all symbols and check for allowed_ids/forbidden_ids + for k, v in pairs(all_symbols) do + if v.allowed_ids and not v.flags.explicit_disable then + for _, id in ipairs(v.allowed_ids) do + if known_ids[id] then + local set = known_ids[id] + if not set.has_specific_symbols then + set.has_specific_symbols = true + end + set.symbols[k] = true + else + rspamd_logger.errx(rspamd_config, 'symbol %s is allowed at unknown settings id %s', + k, id) + end + end + end + if v.forbidden_ids then + for _, id in ipairs(v.forbidden_ids) do + if known_ids[id] then + local set = known_ids[id] + if not set.has_specific_symbols then + set.has_specific_symbols = true + end + set.symbols[k] = false + else + rspamd_logger.errx(rspamd_config, 'symbol %s is 
denied at unknown settings id %s', + k, id) + end + end + end + end + + -- Now we create lists of symbols for each settings and digest + for _, set in pairs(known_ids) do + set.symbols = lua_util.keys(fun.filter(function(_, v) + return v + end, set.symbols)) + table.sort(set.symbols) + set.digest = lua_util.table_digest(set.symbols) + end + + post_init_performed = true + end +end + +-- Returns numeric representation of the settings id +local function numeric_settings_id(str) + local cr = require "rspamd_cryptobox_hash" + local util = require "rspamd_util" + local ret = util.unpack("I4", + cr.create_specific('xxh64'):update(str):bin()) + + return ret +end + +exports.numeric_settings_id = numeric_settings_id + +-- Used to do the following: +-- If there is a group of symbols_allowed, it checks if that is an array +-- If that is a hash table then we transform it to a normal list, probably adding symbols to adjust scores +local function transform_settings_maybe(settings, name) + if settings.apply then + local apply = settings.apply + + if apply.symbols_enabled then + local senabled = apply.symbols_enabled + + if not senabled[1] then + -- Transform map to a list + local nlist = {} + if not apply.scores then + apply.scores = {} + end + for k, v in pairs(senabled) do + if tonumber(v) then + -- Move to symbols as well + apply.scores[k] = tonumber(v) + lua_util.debugm('settings', rspamd_config, + 'set symbol %s -> %s for settings %s', k, v, name) + end + nlist[#nlist + 1] = k + end + -- Convert + apply.symbols_enabled = nlist + end + + local symhash = lua_util.list_to_hash(apply.symbols_enabled) + + if apply.symbols then + -- Check if added symbols are enabled + for k, v in pairs(apply.symbols) do + local s + -- Check if we have ["sym1", "sym2" ...] or {"sym1": xx, "sym2": yy} + if type(k) == 'string' then + s = k + else + s = v + end + if not symhash[s] then + lua_util.debugm('settings', rspamd_config, + 'added symbol %s to symbols_enabled for %s', s, name) + apply.symbols_enabled[#apply.symbols_enabled + 1] = s + end + end + end + end + end + + return settings +end + +local function register_settings_id(str, settings, from_postload) + local numeric_id = numeric_settings_id(str) + + if known_ids[numeric_id] then + -- Might be either rewrite or a collision + if known_ids[numeric_id].name ~= str then + local logger = require "rspamd_logger" + + logger.errx(rspamd_config, 'settings ID clash! 
id %s maps to %s and conflicts with %s', + numeric_id, known_ids[numeric_id].name, str) + + return nil + end + else + known_ids[numeric_id] = { + name = str, + id = numeric_id, + settings = transform_settings_maybe(settings, str), + symbols = {} + } + end + + if not from_postload and not post_init_added then + -- Use high priority to ensure that settings are initialised early but not before all + -- plugins are loaded + rspamd_config:add_post_init(function() + register_settings_cb(true) + end, 150) + rspamd_config:add_config_unload(function() + if post_init_added then + known_ids = {} + post_init_added = false + end + post_init_performed = false + end) + + post_init_added = true + end + + return numeric_id +end + +exports.register_settings_id = register_settings_id + +local function settings_by_id(id) + if not post_init_performed then + register_settings_cb(false) + end + return known_ids[id] +end + +exports.settings_by_id = settings_by_id +exports.all_settings = function() + if not post_init_performed then + register_settings_cb(false) + end + return known_ids +end +exports.all_symbols = function() + if not post_init_performed then + register_settings_cb(false) + end + return all_symbols +end +-- What is enabled when no settings are there +exports.default_symbols = function() + if not post_init_performed then + register_settings_cb(false) + end + return default_symbols +end + +exports.load_all_settings = register_settings_cb + +return exports
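--[[
A minimal usage sketch of the lua_settings API above. The profile name and the
symbol/group names are illustrative assumptions; the 'apply' layout mirrors
the structure handled by transform_settings_maybe and register_settings_cb.
]]--
local lua_settings = require "lua_settings"

-- Register a named settings profile; the numeric id is the first 32 bits of
-- xxh64(name), as computed by numeric_settings_id above
local id = lua_settings.register_settings_id('partners-profile', {
  apply = {
    symbols_enabled = { 'R_SPF_ALLOW', 'DKIM_TRACE' },
    groups_disabled = { 'rbl' },
  },
})

-- Once the post-init callback has run, the resolved set can be inspected
local set = lua_settings.settings_by_id(id)
-- set.symbols then holds the sorted symbol list for this id and set.digest
-- a digest of that list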
\ No newline at end of file diff --git a/lualib/lua_smtp.lua b/lualib/lua_smtp.lua new file mode 100644 index 0000000..3c40349 --- /dev/null +++ b/lualib/lua_smtp.lua @@ -0,0 +1,201 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local rspamd_tcp = require "rspamd_tcp" +local lua_util = require "lua_util" + +local exports = {} + +local CRLF = '\r\n' +local default_timeout = 10.0 + +--[[[ +-- @function lua_smtp.sendmail(task, message, opts, callback) +--]] +local function sendmail(opts, message, callback) + local stage = 'connect' + + local function mail_cb(err, data, conn) + local function no_error_write(merr) + if merr then + callback(false, string.format('error on stage %s: %s', + stage, merr)) + if conn then + conn:close() + end + + return false + end + + return true + end + + local function no_error_read(merr, mdata, wantcode) + wantcode = wantcode or '2' + if merr then + callback(false, string.format('error on stage %s: %s', + stage, merr)) + if conn then + conn:close() + end + + return false + end + if mdata then + if type(mdata) ~= 'string' then + mdata = tostring(mdata) + end + if string.sub(mdata, 1, 1) ~= wantcode then + callback(false, string.format('bad smtp response on stage %s: "%s" when "%s" expected', + stage, mdata, wantcode)) + if conn then + conn:close() + end + return false + end + else + callback(false, string.format('no data on stage %s', + stage)) + if conn then + conn:close() + end + return false + end + return true + end + + -- After quit + local function all_done_cb(merr, mdata) + if conn then + conn:close() + end + + callback(true, nil) + + return true + end + + -- QUIT stage + local function quit_done_cb(_, _) + conn:add_read(all_done_cb, CRLF) + end + local function quit_cb(merr, mdata) + if no_error_read(merr, mdata) then + conn:add_write(quit_done_cb, 'QUIT' .. CRLF) + end + end + local function pre_quit_cb(merr, _) + if no_error_write(merr) then + stage = 'quit' + conn:add_read(quit_cb, CRLF) + end + end + + -- DATA stage + local function data_done_cb(merr, mdata) + if no_error_read(merr, mdata, '3') then + if type(message) == 'string' or type(message) == 'userdata' then + conn:add_write(pre_quit_cb, { message, CRLF .. '.' .. CRLF }) + else + table.insert(message, CRLF .. '.' .. CRLF) + conn:add_write(pre_quit_cb, message) + end + end + end + local function data_cb(merr, _) + if no_error_write(merr) then + conn:add_read(data_done_cb, CRLF) + end + end + + -- RCPT phase + local next_recipient + local function rcpt_done_cb_gen(i) + return function(merr, mdata) + if no_error_read(merr, mdata) then + if i == #opts.recipients then + conn:add_write(data_cb, 'DATA' .. 
CRLF) + else + next_recipient(i + 1) + end + end + end + end + + local function rcpt_cb_gen(i) + return function(merr, _) + if no_error_write(merr, '2') then + conn:add_read(rcpt_done_cb_gen(i), CRLF) + end + end + end + + next_recipient = function(i) + conn:add_write(rcpt_cb_gen(i), + string.format('RCPT TO: <%s>%s', opts.recipients[i], CRLF)) + end + + -- FROM stage + local function from_done_cb(merr, mdata) + -- We need to iterate over recipients sequentially + if no_error_read(merr, mdata, '2') then + stage = 'rcpt' + next_recipient(1) + end + end + local function from_cb(merr, _) + if no_error_write(merr) then + conn:add_read(from_done_cb, CRLF) + end + end + local function hello_done_cb(merr, mdata) + if no_error_read(merr, mdata) then + stage = 'from' + conn:add_write(from_cb, string.format( + 'MAIL FROM: <%s>%s', opts.from, CRLF)) + end + end + + -- HELO stage + local function hello_cb(merr) + if no_error_write(merr) then + conn:add_read(hello_done_cb, CRLF) + end + end + if no_error_read(err, data) then + stage = 'helo' + conn:add_write(hello_cb, string.format('HELO %s%s', + opts.helo, CRLF)) + end + end + + if type(opts.recipients) == 'string' then + opts.recipients = { opts.recipients } + end + + local tcp_opts = lua_util.shallowcopy(opts) + tcp_opts.stop_pattern = CRLF + tcp_opts.timeout = opts.timeout or default_timeout + tcp_opts.callback = mail_cb + + if not rspamd_tcp.request(tcp_opts) then + callback(false, 'cannot make a TCP connection') + end +end + +exports.sendmail = sendmail + +return exports
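--[[
A minimal usage sketch of lua_smtp.sendmail above. The host, port, addresses
and message are illustrative assumptions; the options table is passed through
to rspamd_tcp.request, so the usual task-related fields apply.
]]--
local lua_smtp = require "lua_smtp"
local rspamd_logger = require "rspamd_logger"

lua_smtp.sendmail({
  task = task,                        -- forwarded to rspamd_tcp.request
  host = '127.0.0.1',
  port = 25,
  helo = 'mail.example.com',
  from = 'postmaster@example.com',
  recipients = 'user@example.com',    -- a single string is wrapped into a list
  timeout = 5.0,
}, 'Subject: test\r\n\r\ntest body\r\n', function(ok, err)
  if not ok then
    rspamd_logger.errx(task, 'SMTP delivery failed: %s', err)
  end
end)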
\ No newline at end of file diff --git a/lualib/lua_stat.lua b/lualib/lua_stat.lua new file mode 100644 index 0000000..a0f3303 --- /dev/null +++ b/lualib/lua_stat.lua @@ -0,0 +1,869 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_stat +-- This module contains helper functions for supporting statistics +--]] + +local logger = require "rspamd_logger" +local sqlite3 = require "rspamd_sqlite3" +local util = require "rspamd_util" +local lua_redis = require "lua_redis" +local lua_util = require "lua_util" +local exports = {} + +local N = "stat_tools" -- luacheck: ignore (maybe unused) + +-- Performs synchronous conversion of redis schema +local function convert_bayes_schema(redis_params, symbol_spam, symbol_ham, expire) + + -- Old schema is the following one: + -- Keys are named <symbol>[<user>] + -- Elements are placed within hash: + -- BAYES_SPAM -> {<id1>: <num_hits>, <id2>: <num_hits> ...} + -- In new schema it is changed to a more extensible schema: + -- Keys are named RS[<user>]_<id> -> {'H': <ham_hits>, 'S': <spam_hits>} + -- So we can expire individual records, measure most popular elements by zranges, + -- add new fields, such as tokens etc + + local res, conn = lua_redis.redis_connect_sync(redis_params, true) + + if not res then + logger.errx("cannot connect to redis server") + return false + end + + -- KEYS[1]: key to check (e.g. 'BAYES_SPAM') + -- KEYS[2]: hash key ('S' or 'H') + -- KEYS[3]: expire + local lua_script = [[ +redis.replicate_commands() +local keys = redis.call('SMEMBERS', KEYS[1]..'_keys') +local nconverted = 0 +for _,k in ipairs(keys) do + local cursor = redis.call('HSCAN', k, 0) + local neutral_prefix = string.gsub(k, KEYS[1], 'RS') + local elts + while cursor[1] ~= "0" do + elts = cursor[2] + cursor = redis.call('HSCAN', k, cursor[1]) + local real_key + for i,v in ipairs(elts) do + if i % 2 ~= 0 then + real_key = v + else + local nkey = string.format('%s_%s', neutral_prefix, real_key) + redis.call('HSET', nkey, KEYS[2], v) + if KEYS[3] and tonumber(KEYS[3]) > 0 then + redis.call('EXPIRE', nkey, KEYS[3]) + end + nconverted = nconverted + 1 + end + end + end +end +return nconverted +]] + + conn:add_cmd('EVAL', { lua_script, '3', symbol_spam, 'S', tostring(expire) }) + local ret + ret, res = conn:exec() + + if not ret then + logger.errx('error converting symbol %s: %s', symbol_spam, res) + return false + else + logger.messagex('converted %s elements from symbol %s', res, symbol_spam) + end + + conn:add_cmd('EVAL', { lua_script, '3', symbol_ham, 'H', tostring(expire) }) + ret, res = conn:exec() + + if not ret then + logger.errx('error converting symbol %s: %s', symbol_ham, res) + return false + else + logger.messagex('converted %s elements from symbol %s', res, symbol_ham) + end + + -- We can now convert metadata: set + learned + version + -- KEYS[1]: key to check (e.g. 'BAYES_SPAM') + -- KEYS[2]: learn key (e.g. 
'learns_spam' or 'learns_ham') + lua_script = [[ +local keys = redis.call('SMEMBERS', KEYS[1]..'_keys') + +for _,k in ipairs(keys) do + local learns = redis.call('HGET', k, 'learns') or 0 + local neutral_prefix = string.gsub(k, KEYS[1], 'RS') + + redis.call('HSET', neutral_prefix, KEYS[2], learns) + redis.call('SADD', KEYS[1]..'_keys', neutral_prefix) + redis.call('SREM', KEYS[1]..'_keys', k) + redis.call('DEL', KEYS[1]) + redis.call('SET', k ..'_version', '2') +end +]] + + conn:add_cmd('EVAL', { lua_script, '2', symbol_spam, 'learns_spam' }) + ret, res = conn:exec() + + if not ret then + logger.errx('error converting metadata for symbol %s: %s', symbol_spam, res) + return false + end + + conn:add_cmd('EVAL', { lua_script, '2', symbol_ham, 'learns_ham' }) + ret, res = conn:exec() + + if not ret then + logger.errx('error converting metadata for symbol %s', symbol_ham, res) + return false + end + + return true +end + +exports.convert_bayes_schema = convert_bayes_schema + +-- It now accepts both ham and spam databases +-- parameters: +-- redis_params - how do we connect to a redis server +-- sqlite_db_spam - name for sqlite database with spam tokens +-- sqlite_db_ham - name for sqlite database with ham tokens +-- symbol_ham - name for symbol representing spam, e.g. BAYES_SPAM +-- symbol_spam - name for symbol representing ham, e.g. BAYES_HAM +-- learn_cache_spam - name for sqlite database with spam learn cache +-- learn_cache_ham - name for sqlite database with ham learn cache +-- reset_previous - if true, then the old database is flushed (slow) +local function convert_sqlite_to_redis(redis_params, + sqlite_db_spam, sqlite_db_ham, symbol_spam, symbol_ham, + learn_cache_db, expire, reset_previous) + local nusers = 0 + local lim = 1000 -- Update each 1000 tokens + local users_map = {} + local converted = 0 + + local db_spam = sqlite3.open(sqlite_db_spam) + if not db_spam then + logger.errx('Cannot open source db: %s', sqlite_db_spam) + return false + end + local db_ham = sqlite3.open(sqlite_db_ham) + if not db_ham then + logger.errx('Cannot open source db: %s', sqlite_db_ham) + return false + end + + local res, conn = lua_redis.redis_connect_sync(redis_params, true) + + if not res then + logger.errx("cannot connect to redis server") + return false + end + + if reset_previous then + -- Do a more complicated cleanup + -- execute a lua script that cleans up data + local script = [[ +local members = redis.call('SMEMBERS', KEYS[1]..'_keys') + +for _,prefix in ipairs(members) do + local keys = redis.call('KEYS', prefix..'*') + redis.call('DEL', keys) +end +]] + -- Common keys + for _, sym in ipairs({ symbol_spam, symbol_ham }) do + logger.messagex('Cleaning up old data for %s', sym) + conn:add_cmd('EVAL', { script, '1', sym }) + conn:exec() + conn:add_cmd('DEL', { sym .. "_version" }) + conn:add_cmd('DEL', { sym .. 
"_keys" }) + conn:exec() + end + + if learn_cache_db then + -- Cleanup learned_cache + logger.messagex('Cleaning up old data learned cache') + conn:add_cmd('DEL', { "learned_ids" }) + conn:exec() + end + end + + local function convert_db(db, is_spam) + -- Map users and languages + local what = 'ham' + if is_spam then + what = 'spam' + end + + local learns = {} + db:sql('BEGIN;') + -- Fill users mapping + for row in db:rows('SELECT * FROM users;') do + if row.id == '0' then + users_map[row.id] = '' + else + users_map[row.id] = row.name + end + learns[row.id] = row.learns + nusers = nusers + 1 + end + + -- Workaround for old databases + for row in db:rows('SELECT * FROM languages') do + if learns['0'] then + learns['0'] = learns['0'] + row.learns + else + learns['0'] = row.learns + end + end + + local function send_batch(tokens, prefix) + -- We use the new schema: RS[user]_token -> H=ham count + -- S=spam count + local hash_key = 'H' + if is_spam then + hash_key = 'S' + end + for _, tok in ipairs(tokens) do + -- tok schema: + -- tok[1] = token_id (uint64 represented as a string) + -- tok[2] = token value (number) + -- tok[3] = user_map[user_id] or '' + local rkey = string.format('%s%s_%s', prefix, tok[3], tok[1]) + conn:add_cmd('HINCRBYFLOAT', { rkey, hash_key, tostring(tok[2]) }) + + if expire and expire ~= 0 then + conn:add_cmd('EXPIRE', { rkey, tostring(expire) }) + end + end + + return conn:exec() + end + -- Fill tokens, sending data to redis each `lim` records + + local ntokens = db:query('SELECT count(*) as c FROM tokens')['c'] + local tokens = {} + local num = 0 + local total = 0 + + for row in db:rows('SELECT token,value,user FROM tokens;') do + local user = '' + if row.user ~= 0 and users_map[row.user] then + user = users_map[row.user] + end + + table.insert(tokens, { row.token, row.value, user }) + num = num + 1 + total = total + 1 + if num > lim then + -- TODO: we use the default 'RS' prefix, it can be false in case of + -- classifiers with labels + local ret, err_str = send_batch(tokens, 'RS') + if not ret then + logger.errx('Cannot send tokens to the redis server: ' .. err_str) + db:sql('COMMIT;') + return false + end + + num = 0 + tokens = {} + end + + io.write(string.format('Processed batch %s: %s/%s\r', what, total, ntokens)) + end + -- Last batch + if #tokens > 0 then + local ret, err_str = send_batch(tokens, 'RS') + if not ret then + logger.errx('Cannot send tokens to the redis server: ' .. err_str) + db:sql('COMMIT;') + return false + end + + io.write(string.format('Processed batch %s: %s/%s\r', what, total, ntokens)) + end + io.write('\n') + + converted = converted + total + + -- Close DB + db:sql('COMMIT;') + local symbol = symbol_ham + local learns_elt = "learns_ham" + + if is_spam then + symbol = symbol_spam + learns_elt = "learns_spam" + end + + for id, learned in pairs(learns) do + local user = users_map[id] + if not conn:add_cmd('HSET', { 'RS' .. user, learns_elt, learned }) then + logger.errx('Cannot update learns for user: ' .. user) + return false + end + if not conn:add_cmd('SADD', { symbol .. '_keys', 'RS' .. user }) then + logger.errx('Cannot update learns for user: ' .. user) + return false + end + end + -- Set version + conn:add_cmd('SET', { symbol .. 
'_version', '2' }) + return conn:exec() + end + + logger.messagex('Convert spam tokens') + if not convert_db(db_spam, true) then + return false + end + + logger.messagex('Convert ham tokens') + if not convert_db(db_ham, false) then + return false + end + + if learn_cache_db then + logger.messagex('Convert learned ids from %s', learn_cache_db) + local db = sqlite3.open(learn_cache_db) + local ret = true + local total = 0 + + if not db then + logger.errx('Cannot open cache database: ' .. learn_cache_db) + return false + end + + db:sql('BEGIN;') + + for row in db:rows('SELECT * FROM learns;') do + local is_spam + local digest = tostring(util.encode_base32(row.digest)) + + if row.flag == '0' then + is_spam = '-1' + else + is_spam = '1' + end + + if not conn:add_cmd('HSET', { 'learned_ids', digest, is_spam }) then + logger.errx('Cannot add hash: ' .. digest) + ret = false + else + total = total + 1 + end + end + db:sql('COMMIT;') + + if ret then + conn:exec() + end + + if ret then + logger.messagex('Converted %s cached items from sqlite3 learned cache to redis', + total) + else + logger.errx('Error occurred during sending data to redis') + end + end + + logger.messagex('Migrated %s tokens for %s users for symbols (%s, %s)', + converted, nusers, symbol_spam, symbol_ham) + return true +end + +exports.convert_sqlite_to_redis = convert_sqlite_to_redis + +-- Loads sqlite3 based classifiers and output data in form of array of objects: +-- [ +-- { +-- symbol_spam = XXX +-- symbol_ham = YYY +-- db_spam = XXX.sqlite +-- db_ham = YYY.sqlite +-- learn_cache = ZZZ.sqlite +-- per_user = true/false +-- label = str +-- } +-- ] +local function load_sqlite_config(cfg) + local result = {} + + local function parse_classifier(cls) + local tbl = {} + if cls.cache then + local cache = cls.cache + if cache.type == 'sqlite3' and (cache.file or cache.path) then + tbl.learn_cache = (cache.file or cache.path) + end + end + + if cls.per_user then + tbl.per_user = cls.per_user + end + + if cls.label then + tbl.label = cls.label + end + + local statfiles = cls.statfile + for _, stf in ipairs(statfiles) do + local path = (stf.file or stf.path or stf.db or stf.dbname) + local symbol = stf.symbol or 'undefined' + + if not path then + logger.errx('no path defined for statfile %s', symbol) + else + + local spam + if stf.spam then + spam = stf.spam + else + if string.match(symbol:upper(), 'SPAM') then + spam = true + else + spam = false + end + end + + if spam then + tbl.symbol_spam = symbol + tbl.db_spam = path + else + tbl.symbol_ham = symbol + tbl.db_ham = path + end + end + end + + if tbl.symbol_spam and tbl.symbol_ham and tbl.db_ham and tbl.db_spam then + table.insert(result, tbl) + end + end + + local classifier = cfg.classifier + + if classifier then + if classifier[1] then + for _, cls in ipairs(classifier) do + if cls.bayes then + cls = cls.bayes + end + if cls.backend and cls.backend == 'sqlite3' then + parse_classifier(cls) + end + end + else + if classifier.bayes then + classifier = classifier.bayes + if classifier[1] then + for _, cls in ipairs(classifier) do + if cls.backend and cls.backend == 'sqlite3' then + parse_classifier(cls) + end + end + else + if classifier.backend and classifier.backend == 'sqlite3' then + parse_classifier(classifier) + end + end + end + end + end + + return result +end + +exports.load_sqlite_config = load_sqlite_config + +-- A helper method that suggests a user how to configure Redis based +-- classifier based on the existing sqlite classifier +local function 
redis_classifier_from_sqlite(sqlite_classifier, expire) + local result = { + new_schema = true, + backend = 'redis', + cache = { + backend = 'redis' + }, + statfile = { + [sqlite_classifier.symbol_spam] = { + spam = true + }, + [sqlite_classifier.symbol_ham] = { + spam = false + } + } + } + + if expire then + result.expire = expire + end + + return { classifier = { bayes = result } } +end + +exports.redis_classifier_from_sqlite = redis_classifier_from_sqlite + +-- Reads statistics config and return preprocessed table +local function process_stat_config(cfg) + local opts_section = cfg:get_all_opt('options') or {} + + -- Check if we have a dedicated section for statistics + if opts_section.statistics then + opts_section = opts_section.statistics + end + + -- Default + local res_config = { + classify_headers = { + "User-Agent", + "X-Mailer", + "Content-Type", + "X-MimeOLE", + "Organization", + "Organisation" + }, + classify_images = true, + classify_mime_info = true, + classify_urls = true, + classify_meta = true, + classify_max_tlds = 10, + } + + res_config = lua_util.override_defaults(res_config, opts_section) + + -- Postprocess classify_headers + local classify_headers_parsed = {} + + for _, v in ipairs(res_config.classify_headers) do + local s1, s2 = v:match("^([A-Z])[^%-]+%-([A-Z]).*$") + + local hname + if s1 and s2 then + hname = string.format('%s-%s', s1, s2) + else + s1 = v:match("^X%-([A-Z].*)$") + + if s1 then + hname = string.format('x%s', s1:sub(1, 3):lower()) + else + hname = string.format('%s', v:sub(1, 3):lower()) + end + end + + if classify_headers_parsed[hname] then + table.insert(classify_headers_parsed[hname], v) + else + classify_headers_parsed[hname] = { v } + end + end + + res_config.classify_headers_parsed = classify_headers_parsed + + return res_config +end + +local function get_mime_stat_tokens(task, res, i) + local parts = task:get_parts() or {} + local seen_multipart = false + local seen_plain = false + local seen_html = false + local empty_plain = false + local empty_html = false + local online_text = false + + for _, part in ipairs(parts) do + local fname = part:get_filename() + + local sz = part:get_length() + + if sz > 0 then + rawset(res, i, string.format("#ps:%d", + math.floor(math.log(sz)))) + lua_util.debugm("bayes", task, "part size: %s", + res[i]) + i = i + 1 + end + + if fname then + rawset(res, i, "#f:" .. fname) + i = i + 1 + + lua_util.debugm("bayes", task, "added attachment: #f:%s", + fname) + end + + if part:is_text() then + local tp = part:get_text() + + if tp:is_html() then + seen_html = true + + if tp:get_length() == 0 then + empty_html = true + end + else + seen_plain = true + + if tp:get_length() == 0 then + empty_plain = true + end + end + + if tp:get_lines_count() < 2 then + online_text = true + end + + rawset(res, i, "#lang:" .. (tp:get_language() or 'unk')) + lua_util.debugm("bayes", task, "added language: %s", + res[i]) + i = i + 1 + + rawset(res, i, "#cs:" .. (tp:get_charset() or 'unk')) + lua_util.debugm("bayes", task, "added charset: %s", + res[i]) + i = i + 1 + + elseif part:is_multipart() then + seen_multipart = true; + end + end + + -- Create a special token depending on parts structure + local st_tok = "#unk" + if seen_multipart and seen_html and seen_plain then + st_tok = '#mpth' + end + + if seen_html and not seen_plain then + st_tok = "#ho" + end + + if seen_plain and not seen_html then + st_tok = "#to" + end + + local spec_tok = "" + if online_text then + spec_tok = "#ot" + end + + if empty_plain then + spec_tok = spec_tok .. 
"#ep" + end + + if empty_html then + spec_tok = spec_tok .. "#eh" + end + + rawset(res, i, string.format("#m:%s%s", st_tok, spec_tok)) + lua_util.debugm("bayes", task, "added mime token: %s", + res[i]) + i = i + 1 + + return i +end + +local function get_headers_stat_tokens(task, cf, res, i) + --[[ + -- As discussed with Alexander Moisseev, this feature can skew statistics + -- especially when learning is separated from scanning, so learning + -- has a different set of tokens where this token can have too high weight + local hdrs_cksum = task:get_mempool():get_variable("headers_hash") + + if hdrs_cksum then + rawset(res, i, string.format("#hh:%s", hdrs_cksum:sub(1, 7))) + lua_util.debugm("bayes", task, "added hdrs hash token: %s", + res[i]) + i = i + 1 + end + ]]-- + + for k, hdrs in pairs(cf.classify_headers_parsed) do + for _, hname in ipairs(hdrs) do + local value = task:get_header(hname) + + if value then + rawset(res, i, string.format("#h:%s:%s", k, value)) + lua_util.debugm("bayes", task, "added hdrs token: %s", + res[i]) + i = i + 1 + end + end + end + + local from = (task:get_from('mime') or {})[1] + + if from and from.name then + rawset(res, i, string.format("#F:%s", from.name)) + lua_util.debugm("bayes", task, "added from name token: %s", + res[i]) + i = i + 1 + end + + return i +end + +local function get_meta_stat_tokens(task, res, i) + local day_and_hour = os.date('%u:%H', + task:get_date { format = 'message', gmt = true }) + rawset(res, i, string.format("#dt:%s", day_and_hour)) + lua_util.debugm("bayes", task, "added day_of_week token: %s", + res[i]) + i = i + 1 + + local pol = {} + + -- Authentication results + if task:has_symbol('DKIM_TRACE') then + -- Autolearn or scan + if task:has_symbol('R_SPF_ALLOW') then + table.insert(pol, 's=pass') + end + + local trace = task:get_symbol('DKIM_TRACE') + local dkim_opts = trace[1]['options'] + if dkim_opts then + for _, o in ipairs(dkim_opts) do + local check_res = string.sub(o, -1) + local domain = string.sub(o, 1, -3) + + if check_res == '+' then + table.insert(pol, string.format('d=%s:%s', "pass", domain)) + end + end + end + else + -- Offline learn + local aur = task:get_header('Authentication-Results') + + if aur then + local spf = aur:match('spf=([a-z]+)') + local dkim, dkim_domain = aur:match('dkim=([a-z]+) header.d=([a-z.%-]+)') + + if spf then + table.insert(pol, 's=' .. spf) + end + if dkim and dkim_domain then + table.insert(pol, string.format('d=%s:%s', dkim, dkim_domain)) + end + end + end + + if #pol > 0 then + rawset(res, i, string.format("#aur:%s", table.concat(pol, ','))) + lua_util.debugm("bayes", task, "added policies token: %s", + res[i]) + i = i + 1 + end + + --[[ + -- Disabled. + -- 1. Depending on the source the message has a different set of Received + -- headers as the receiving MTA adds another Received header. + -- 2. The usefulness of the Received tokens is questionable. 
+ local rh = task:get_received_headers() + + if rh and #rh > 0 then + local lim = math.min(5, #rh) + for j =1,lim do + local rcvd = rh[j] + local ip = rcvd.real_ip + if ip and ip:is_valid() and ip:get_version() == 4 then + local masked = ip:apply_mask(24) + + rawset(res, i, string.format("#rcv:%s:%s", tostring(masked), + rcvd.proto)) + lua_util.debugm("bayes", task, "added received token: %s", + res[i]) + i = i + 1 + end + end + end + ]]-- + + return i +end + +local function get_stat_tokens(task, cf) + local res = {} + local E = {} + local i = 1 + + if cf.classify_images then + local images = task:get_images() or E + + for _, img in ipairs(images) do + rawset(res, i, "image") + i = i + 1 + rawset(res, i, tostring(img:get_height())) + i = i + 1 + rawset(res, i, tostring(img:get_width())) + i = i + 1 + rawset(res, i, tostring(img:get_type())) + i = i + 1 + + local fname = img:get_filename() + + if fname then + rawset(res, i, tostring(img:get_filename())) + i = i + 1 + end + + lua_util.debugm("bayes", task, "added image: %s", + fname) + end + end + + if cf.classify_mime_info then + i = get_mime_stat_tokens(task, res, i) + end + + if cf.classify_headers and #cf.classify_headers > 0 then + i = get_headers_stat_tokens(task, cf, res, i) + end + + if cf.classify_urls then + local urls = lua_util.extract_specific_urls { task = task, limit = 5, esld_limit = 1 } + + if urls then + for _, u in ipairs(urls) do + rawset(res, i, string.format("#u:%s", u:get_tld())) + lua_util.debugm("bayes", task, "added url token: %s", + res[i]) + i = i + 1 + end + end + end + + if cf.classify_meta then + i = get_meta_stat_tokens(task, res, i) + end + + return res +end + +exports.gen_stat_tokens = function(cfg) + local stat_config = process_stat_config(cfg) + + return function(task) + return get_stat_tokens(task, stat_config) + end +end + +return exports diff --git a/lualib/lua_tcp_sync.lua b/lualib/lua_tcp_sync.lua new file mode 100644 index 0000000..f8e6044 --- /dev/null +++ b/lualib/lua_tcp_sync.lua @@ -0,0 +1,213 @@ +local rspamd_tcp = require "rspamd_tcp" +local lua_util = require "lua_util" + +local exports = {} +local N = 'tcp_sync' + +local tcp_sync = { _conn = nil, _data = '', _eof = false, _addr = '' } +local metatable = { + __tostring = function(self) + return "class {tcp_sync connect to: " .. self._addr .. 
"}" + end +} + +function tcp_sync.new(connection) + local self = {} + + for name, method in pairs(tcp_sync) do + if name ~= 'new' then + self[name] = method + end + end + + self._conn = connection + + setmetatable(self, metatable) + + return self +end + +--[[[ +-- @method tcp_sync.read_once() +-- +-- Acts exactly like low-level tcp_sync.read_once() +-- the only exception is that if there is some pending data, +-- it's returned immediately and no underlying call is performed +-- +-- @return +-- true, {data} if everything is fine +-- false, {error message} otherwise +-- +--]] +function tcp_sync:read_once() + local is_ok, data + if self._data:len() > 0 then + data = self._data + self._data = nil + return true, data + end + + is_ok, data = self._conn:read_once() + + return is_ok, data +end + +--[[[ +-- @method tcp_sync.read_until(pattern) +-- +-- Reads data from the connection until pattern is found +-- returns all bytes before the pattern +-- +-- @param {pattern} Read data until pattern is found +-- @return +-- true, {data} if everything is fine +-- false, {error message} otherwise +-- @example +-- +--]] +function tcp_sync:read_until(pattern) + repeat + local pos_start, pos_end = self._data:find(pattern, 1, true) + if pos_start then + local data = self._data:sub(1, pos_start - 1) + self._data = self._data:sub(pos_end + 1) + return true, data + end + + local is_ok, more_data = self._conn:read_once() + if not is_ok then + return is_ok, more_data + end + + self._data = self._data .. more_data + until false +end + +--[[[ +-- @method tcp_sync.read_bytes(n) +-- +-- Reads {n} bytes from the stream +-- +-- @param {n} Number of bytes to read +-- @return +-- true, {data} if everything is fine +-- false, {error message} otherwise +-- +--]] +function tcp_sync:read_bytes(n) + repeat + if self._data:len() >= n then + local data = self._data:sub(1, n) + self._data = self._data:sub(n + 1) + return true, data + end + + local is_ok, more_data = self._conn:read_once() + if not is_ok then + return is_ok, more_data + end + + self._data = self._data .. more_data + until false +end + +--[[[ +-- @method tcp_sync.read_until_eof(n) +-- +-- Reads stream until EOF is reached +-- +-- @return +-- true, {data} if everything is fine +-- false, {error message} otherwise +-- +--]] +function tcp_sync:read_until_eof() + while not self:eof() do + local is_ok, more_data = self._conn:read_once() + if not is_ok then + if self:eof() then + -- this error is EOF (connection terminated) + -- exactly what we were waiting for + break + end + return is_ok, more_data + end + self._data = self._data .. more_data + end + + local data = self._data + self._data = '' + return true, data +end + +--[[[ +-- @method tcp_sync.write(n) +-- +-- Writes data into the stream. +-- +-- @return +-- true if everything is fine +-- false, {error message} otherwise +-- +--]] +function tcp_sync:write(data) + return self._conn:write(data) +end + +--[[[ +-- @method tcp_sync.close() +-- +-- Closes the connection. If the connection was created with task, +-- this method is called automatically as soon as the task is done +-- Calling this method helps to prevent connections leak. +-- The object is finally destroyed by garbage collector. 
+-- +-- @return +-- +--]] +function tcp_sync:close() + return self._conn:close() +end + +--[[[ +-- @method tcp_sync.eof() +-- +-- @return +-- true if last "read" operation ended with EOF +-- false otherwise +-- +--]] +function tcp_sync:eof() + if not self._eof and self._conn:eof() then + self._eof = true + end + return self._eof +end + +--[[[ +-- @function tcp_sync.shutdown(n) +-- +-- half-close socket +-- +-- @return +-- +--]] +function tcp_sync:shutdown() + return self._conn:shutdown() +end + +exports.connect = function(args) + local is_ok, connection = rspamd_tcp.connect_sync(args) + if not is_ok then + return is_ok, connection + end + + local instance = tcp_sync.new(connection) + instance._addr = string.format("%s:%s", tostring(args.host), tostring(args.port)) + + lua_util.debugm(N, args.task, 'Connected to %s', instance._addr) + + return true, instance +end + +return exports
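-- A minimal usage sketch (not part of the committed file): driving the synchronous
-- API above from a plugin callback. The host, port and the "PING" exchange are
-- illustrative assumptions; connection options are simply passed through to
-- rspamd_tcp.connect_sync() as exports.connect() does above.
local tcp_sync = require "lua_tcp_sync"

local function query_backend(task)
  local is_ok, conn = tcp_sync.connect {
    task = task,
    host = '127.0.0.1', -- hypothetical backend address
    port = 12345,       -- hypothetical backend port
    timeout = 5.0,
  }
  if not is_ok then
    return nil, conn -- on failure the second value is the error message
  end

  local ok, err = conn:write('PING\r\n')
  if not ok then
    conn:close()
    return nil, err
  end

  -- read_until() strips the pattern and keeps any remainder buffered
  local data
  ok, data = conn:read_until('\r\n')
  conn:close()

  if not ok then
    return nil, data
  end

  return data
end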
\ No newline at end of file diff --git a/lualib/lua_urls_compose.lua b/lualib/lua_urls_compose.lua new file mode 100644 index 0000000..1113421 --- /dev/null +++ b/lualib/lua_urls_compose.lua @@ -0,0 +1,286 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_urls_compose +-- This module contains functions to compose urls queries from hostname +-- to TLD part +--]] + +local N = "lua_urls_compose" +local lua_util = require "lua_util" +local rspamd_util = require "rspamd_util" +local bit = require "bit" +local rspamd_trie = require "rspamd_trie" +local fun = require "fun" +local rspamd_regexp = require "rspamd_regexp" + +local maps_cache = {} + +local exports = {} + +local function process_url(self, log_obj, url_tld, url_host) + local tld_elt = self.tlds[url_tld] + + if tld_elt then + lua_util.debugm(N, log_obj, 'found compose tld for %s (host = %s)', + url_tld, url_host) + + for _, excl in ipairs(tld_elt.except_rules) do + local matched, ret = excl[2](url_tld, url_host) + if matched then + lua_util.debugm(N, log_obj, 'found compose exclusion for %s (%s) -> %s', + url_host, excl[1], ret) + + return ret + end + end + + if tld_elt.multipattern_compose_rules then + local matches = tld_elt.multipattern_compose_rules:match(url_host) + + if matches then + local lua_pat_idx = math.huge + + for m, _ in pairs(matches) do + if m < lua_pat_idx then + lua_pat_idx = m + end + end + + if #tld_elt.compose_rules >= lua_pat_idx then + local lua_pat = tld_elt.compose_rules[lua_pat_idx] + local matched, ret = lua_pat[2](url_tld, url_host) + + if not matched then + lua_util.debugm(N, log_obj, 'NOT found compose inclusion for %s (%s) -> %s', + url_host, lua_pat[1], url_tld) + + return url_tld + else + lua_util.debugm(N, log_obj, 'found compose inclusion for %s (%s) -> %s', + url_host, lua_pat[1], ret) + + return ret + end + else + lua_util.debugm(N, log_obj, 'NOT found compose inclusion for %s (%s) -> %s', + url_host, lua_pat_idx, url_tld) + + return url_tld + end + end + else + -- Match one by one + for _, lua_pat in ipairs(tld_elt.compose_rules) do + local matched, ret = lua_pat[2](url_tld, url_host) + if matched then + lua_util.debugm(N, log_obj, 'found compose inclusion for %s (%s) -> %s', + url_host, lua_pat[1], ret) + + return ret + end + end + end + + lua_util.debugm(N, log_obj, 'not found compose inclusion for %s in %s -> %s', + url_host, url_tld, url_tld) + else + lua_util.debugm(N, log_obj, 'not found compose tld for %s in %s -> %s', + url_host, url_tld, url_tld) + end + + return url_tld +end + +local function tld_pattern_transform(tld_pat) + -- Convert tld like pattern to a lua match pattern + -- blah -> %.blah + -- *.blah -> .*%.blah + local ret + if tld_pat:sub(1, 2) == '*.' 
then + ret = string.format('^((?:[^.]+\\.)*%s)$', tld_pat:sub(3)) + else + ret = string.format('(?:^|\\.)((?:[^.]+\\.)?%s)$', tld_pat) + end + + lua_util.debugm(N, nil, 'added pattern %s -> %s', + tld_pat, ret) + + return ret +end + +local function include_elt_gen(pat) + pat = rspamd_regexp.create(tld_pattern_transform(pat), 'i') + return function(_, host) + local matches = pat:search(host, false, true) + if matches then + return true, matches[1][2] + end + + return false + end +end + +local function exclude_elt_gen(pat) + pat = rspamd_regexp.create(tld_pattern_transform(pat)) + return function(tld, host) + if pat:search(host) then + return true, tld + end + + return false + end +end + +local function compose_map_cb(self, map_text) + local lpeg = require "lpeg" + + local singleline_comment = lpeg.P '#' * (1 - lpeg.S '\r\n\f') ^ 0 + local comments_strip_grammar = lpeg.C((1 - lpeg.P '#') ^ 1) * lpeg.S(' \t') ^ 0 * singleline_comment ^ 0 + + local function process_tld_rule(tld_elt, l) + if l:sub(1, 1) == '!' then + -- Exclusion elt + table.insert(tld_elt.except_rules, { l, exclude_elt_gen(l:sub(2)) }) + else + table.insert(tld_elt.compose_rules, { l, include_elt_gen(l) }) + end + end + + local function process_map_line(l) + -- Skip empty lines and comments + if #l == 0 then + return + end + l = comments_strip_grammar:match(l) + if not l or #l == 0 then + return + end + + -- Get TLD + local tld = rspamd_util.get_tld(l) + + if tld then + local tld_elt = self.tlds[tld] + + if not tld_elt then + tld_elt = { + compose_rules = {}, + except_rules = {}, + multipattern_compose_rules = nil + } + + lua_util.debugm(N, rspamd_config, 'processed new tld rule for %s', tld) + self.tlds[tld] = tld_elt + end + + process_tld_rule(tld_elt, l) + else + lua_util.debugm(N, rspamd_config, 'cannot read tld from compose map line: %s', l) + end + end + + for line in map_text:lines() do + process_map_line(line) + end + + local multipattern_threshold = 1 + for tld, tld_elt in pairs(self.tlds) do + -- Sort patterns to have longest labels before shortest ones, + -- so we can ensure that they match before + table.sort(tld_elt.compose_rules, function(e1, e2) + local _, ndots1 = string.gsub(e1[1], '(%.)', '') + local _, ndots2 = string.gsub(e2[1], '(%.)', '') + + return ndots1 > ndots2 + end) + if rspamd_trie.has_hyperscan() and #tld_elt.compose_rules >= multipattern_threshold then + lua_util.debugm(N, rspamd_config, 'tld %s has %s rules, apply multipattern', + tld, #tld_elt.compose_rules) + local flags = bit.bor(rspamd_trie.flags.re, + rspamd_trie.flags.dot_all, + rspamd_trie.flags.no_start, + rspamd_trie.flags.icase) + + + -- We now convert our internal patterns to multipattern patterns + local mp_table = fun.totable(fun.map(function(pat_elt) + return tld_pattern_transform(pat_elt[1]) + end, tld_elt.compose_rules)) + tld_elt.multipattern_compose_rules = rspamd_trie.create(mp_table, flags) + end + end +end + +exports.add_composition_map = function(cfg, map_obj) + local hash_key = map_obj + if type(map_obj) == 'table' then + hash_key = lua_util.table_digest(map_obj) + end + + local map = maps_cache[hash_key] + + if not map then + local ret = { + process_url = process_url, + hash = hash_key, + tlds = {}, + } + + map = cfg:add_map { + type = 'callback', + description = 'URL compose map', + url = map_obj, + callback = function(input) + compose_map_cb(ret, input) + end, + opaque_data = true, + } + + ret.map = map + maps_cache[hash_key] = ret + map = ret + end + + return map +end + +exports.inject_composition_rules = function(cfg, 
rules) + local hash_key = rules + local rspamd_text = require "rspamd_text" + if type(rules) == 'table' then + hash_key = lua_util.table_digest(rules) + end + + local map = maps_cache[hash_key] + + if not map then + local ret = { + process_url = process_url, + hash = hash_key, + tlds = {}, + } + + compose_map_cb(ret, rspamd_text.fromtable(rules, '\n')) + maps_cache[hash_key] = ret + map = ret + end + + return map +end + +return exports
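-- A minimal usage sketch (not part of the committed file): injecting composition
-- rules directly from Lua and composing a host into the key used for tld-level
-- lookups. The rule strings and hosts are illustrative assumptions; production
-- setups normally load rules via add_composition_map() from a map file.
local urls_compose = require "lua_urls_compose"

local compose_map = urls_compose.inject_composition_rules(rspamd_config, {
  'example.com',         -- compose hosts one label under example.com
  '*.pages.example.net', -- compose every subdomain under pages.example.net
  -- a line starting with '!' would add an exclusion (handled by except_rules)
})

-- Typically called per url from a symbol callback; `task` is only used for logging.
-- When no rule for the url's tld matches, the plain url tld is returned unchanged.
local function compose_url_key(task, url)
  return compose_map:process_url(task, url:get_tld(), url:get_host())
end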
\ No newline at end of file diff --git a/lualib/lua_util.lua b/lualib/lua_util.lua new file mode 100644 index 0000000..6964b0f --- /dev/null +++ b/lualib/lua_util.lua @@ -0,0 +1,1639 @@ +--[[ +Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_util +-- This module contains utility functions for working with Lua and/or Rspamd +--]] + +local exports = {} +local lpeg = require 'lpeg' +local rspamd_util = require "rspamd_util" +local fun = require "fun" +local lupa = require "lupa" + +local split_grammar = {} +local spaces_split_grammar +local space = lpeg.S ' \t\n\v\f\r' +local nospace = 1 - space +local ptrim = space ^ 0 * lpeg.C((space ^ 0 * nospace ^ 1) ^ 0) +local match = lpeg.match + +lupa.configure('{%', '%}', '{=', '=}', '{#', '#}', { + keep_trailing_newline = true, + autoescape = false, +}) + +lupa.filters.pbkdf = function(s) + local cr = require "rspamd_cryptobox" + return cr.pbkdf(s) +end + +local function rspamd_str_split(s, sep) + local gr + if not sep then + if not spaces_split_grammar then + local _sep = space + local elem = lpeg.C((1 - _sep) ^ 0) + local p = lpeg.Ct(elem * (_sep * elem) ^ 0) + spaces_split_grammar = p + end + + gr = spaces_split_grammar + else + gr = split_grammar[sep] + + if not gr then + local _sep + if type(sep) == 'string' then + _sep = lpeg.S(sep) -- Assume set + else + _sep = sep -- Assume lpeg object + end + local elem = lpeg.C((1 - _sep) ^ 0) + local p = lpeg.Ct(elem * (_sep * elem) ^ 0) + gr = p + split_grammar[sep] = gr + end + end + + return gr:match(s) +end + +--[[[ +-- @function lua_util.str_split(text, delimiter) +-- Splits text into a numeric table by delimiter +-- @param {string} text delimited text +-- @param {string} delimiter the delimiter +-- @return {table} numeric table containing string parts +--]] + +exports.rspamd_str_split = rspamd_str_split +exports.str_split = rspamd_str_split + +local function rspamd_str_trim(s) + return match(ptrim, s) +end +exports.rspamd_str_trim = rspamd_str_trim +--[[[ +-- @function lua_util.str_trim(text) +-- Returns a string with no trailing and leading spaces +-- @param {string} text input text +-- @return {string} string with no trailing and leading spaces +--]] +exports.str_trim = rspamd_str_trim + +--[[[ +-- @function lua_util.str_startswith(text, prefix) +-- @param {string} text +-- @param {string} prefix +-- @return {boolean} true if text starts with the specified prefix, false otherwise +--]] +exports.str_startswith = function(s, prefix) + return s:sub(1, prefix:len()) == prefix +end + +--[[[ +-- @function lua_util.str_endswith(text, suffix) +-- @param {string} text +-- @param {string} suffix +-- @return {boolean} true if text ends with the specified suffix, false otherwise +--]] +exports.str_endswith = function(s, suffix) + return s:find(suffix, -suffix:len(), true) ~= nil +end + +--[[[ +-- @function lua_util.round(number, decimalPlaces) +-- Round number to fixed number of decimal points +-- @param {number} number number to round +-- 
@param {number} decimalPlaces number of decimal points +-- @return {number} rounded number +--]] + +-- modified version from Robert Jay Gould http://lua-users.org/wiki/SimpleRound +exports.round = function(num, numDecimalPlaces) + local mult = 10 ^ (numDecimalPlaces or 0) + if num >= 0 then + return math.floor(num * mult + 0.5) / mult + else + return math.ceil(num * mult - 0.5) / mult + end +end + +--[[[ +-- @function lua_util.template(text, replacements) +-- Replaces values in a text template +-- Variable names can contain letters, numbers and underscores, are prefixed with `$` and may or not use curly braces. +-- @param {string} text text containing variables +-- @param {table} replacements key/value pairs for replacements +-- @return {string} string containing replaced values +-- @example +-- local goop = lua_util.template("HELLO $FOO ${BAR}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'}) +-- -- goop contains "HELLO LUA WORLD!" +--]] + +exports.template = function(tmpl, keys) + local var_lit = lpeg.P { lpeg.R("az") + lpeg.R("AZ") + lpeg.R("09") + "_" } + local var = lpeg.P { (lpeg.P("$") / "") * ((var_lit ^ 1) / keys) } + local var_braced = lpeg.P { (lpeg.P("${") / "") * ((var_lit ^ 1) / keys) * (lpeg.P("}") / "") } + + local template_grammar = lpeg.Cs((var + var_braced + 1) ^ 0) + + return lpeg.match(template_grammar, tmpl) +end + +local function enrich_template_with_globals(env) + local newenv = exports.shallowcopy(env) + newenv.paths = rspamd_paths + newenv.env = rspamd_env + + return newenv +end +--[[[ +-- @function lua_util.jinja_template(text, env[, skip_global_env]) +-- Replaces values in a text template according to jinja2 syntax +-- @param {string} text text containing variables +-- @param {table} replacements key/value pairs for replacements +-- @param {boolean} skip_global_env don't export Rspamd superglobals +-- @return {string} string containing replaced values +-- @example +-- lua_util.jinja_template("HELLO {{FOO}} {{BAR}}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'}) +-- "HELLO LUA WORLD!" +--]] +exports.jinja_template = function(text, env, skip_global_env) + if not skip_global_env then + env = enrich_template_with_globals(env) + end + + return lupa.expand(text, env) +end + +--[[[ +-- @function lua_util.jinja_file(filename, env[, skip_global_env]) +-- Replaces values in a text template according to jinja2 syntax +-- @param {string} filename name of file to expand +-- @param {table} replacements key/value pairs for replacements +-- @param {boolean} skip_global_env don't export Rspamd superglobals +-- @return {string} string containing replaced values +-- @example +-- lua_util.jinja_template("HELLO {{FOO}} {{BAR}}!", {['FOO'] = 'LUA', ['BAR'] = 'WORLD'}) +-- "HELLO LUA WORLD!" 
+--]] +exports.jinja_template_file = function(filename, env, skip_global_env) + if not skip_global_env then + env = enrich_template_with_globals(env) + end + + return lupa.expand_file(filename, env) +end + +exports.remove_email_aliases = function(email_addr) + local function check_gmail_user(addr) + -- Remove all points + local no_dots_user = string.gsub(addr.user, '%.', '') + local cap, pluses = string.match(no_dots_user, '^([^%+][^%+]*)(%+.*)$') + if cap then + return cap, rspamd_str_split(pluses, '+'), nil + elseif no_dots_user ~= addr.user then + return no_dots_user, {}, nil + end + + return nil + end + + local function check_address(addr) + if addr.user then + local cap, pluses = string.match(addr.user, '^([^%+][^%+]*)(%+.*)$') + if cap then + return cap, rspamd_str_split(pluses, '+'), nil + end + end + + return nil + end + + local function set_addr(addr, new_user, new_domain) + if new_user then + addr.user = new_user + end + if new_domain then + addr.domain = new_domain + end + + if addr.domain then + addr.addr = string.format('%s@%s', addr.user, addr.domain) + else + addr.addr = string.format('%s@', addr.user) + end + + if addr.name and #addr.name > 0 then + addr.raw = string.format('"%s" <%s>', addr.name, addr.addr) + else + addr.raw = string.format('<%s>', addr.addr) + end + end + + local function check_gmail(addr) + local nu, tags, nd = check_gmail_user(addr) + + if nu then + return nu, tags, nd + end + + return nil + end + + local function check_googlemail(addr) + local nd = 'gmail.com' + local nu, tags = check_gmail_user(addr) + + if nu then + return nu, tags, nd + end + + return nil, nil, nd + end + + local specific_domains = { + ['gmail.com'] = check_gmail, + ['googlemail.com'] = check_googlemail, + } + + if email_addr then + if email_addr.domain and specific_domains[email_addr.domain] then + local nu, tags, nd = specific_domains[email_addr.domain](email_addr) + if nu or nd then + set_addr(email_addr, nu, nd) + + return nu, tags + end + else + local nu, tags, nd = check_address(email_addr) + if nu or nd then + set_addr(email_addr, nu, nd) + + return nu, tags + end + end + + return nil + end +end + +exports.is_rspamc_or_controller = function(task) + local ua = task:get_request_header('User-Agent') or '' + local pwd = task:get_request_header('Password') + local is_rspamc = false + if tostring(ua) == 'rspamc' or pwd then + is_rspamc = true + end + + return is_rspamc +end + +--[[[ +-- @function lua_util.unpack(table) +-- Converts numeric table to varargs +-- This is `unpack` on Lua 5.1/5.2/LuaJIT and `table.unpack` on Lua 5.3 +-- @param {table} table numerically indexed table to unpack +-- @return {varargs} unpacked table elements +--]] + +local unpack_function = table.unpack or unpack +exports.unpack = function(t) + return unpack_function(t) +end + +--[[[ +-- @function lua_util.flatten(table) +-- Flatten underlying tables in a single table +-- @param {table} table table of tables +-- @return {table} flattened table +--]] +exports.flatten = function(t) + local res = {} + for _, e in fun.iter(t) do + for _, v in fun.iter(e) do + res[#res + 1] = v + end + end + + return res +end + +--[[[ +-- @function lua_util.spairs(table) +-- Like `pairs` but keys are sorted lexicographically +-- @param {table} table table containing key/value pairs +-- @return {function} generator function returning key/value pairs +--]] + +-- Sorted iteration: +-- for k,v in spairs(t) do ... 
end +-- +-- or with custom comparison: +-- for k, v in spairs(t, function(t, a, b) return t[a] < t[b] end) +-- +-- optional limit is also available (e.g. return top X elements) +local function spairs(t, order, lim) + -- collect the keys + local keys = {} + for k in pairs(t) do + keys[#keys + 1] = k + end + + -- if order function given, sort by it by passing the table and keys a, b, + -- otherwise just sort the keys + if order then + table.sort(keys, function(a, b) + return order(t, a, b) + end) + else + table.sort(keys) + end + + -- return the iterator function + local i = 0 + return function() + i = i + 1 + if not lim or i <= lim then + if keys[i] then + return keys[i], t[keys[i]] + end + end + end +end + +exports.spairs = spairs + +local lua_cfg_utils = require "lua_cfg_utils" + +exports.config_utils = lua_cfg_utils +exports.disable_module = lua_cfg_utils.disable_module + +--[[[ +-- @function lua_util.disable_module(modname) +-- Checks experimental plugins state and disable if needed +-- @param {string} modname name of plugin to check +-- @return {boolean} true if plugin should be enabled, false otherwise +--]] +local function check_experimental(modname) + if rspamd_config:experimental_enabled() then + return true + else + lua_cfg_utils.disable_module(modname, 'experimental') + end + + return false +end + +exports.check_experimental = check_experimental + +--[[[ +-- @function lua_util.list_to_hash(list) +-- Converts numerically-indexed table to table indexed by values +-- @param {table} list numerically-indexed table or string, which is treated as a one-element list +-- @return {table} table indexed by values +-- @example +-- local h = lua_util.list_to_hash({"a", "b"}) +-- -- h contains {a = true, b = true} +--]] +local function list_to_hash(list) + if type(list) == 'table' then + if list[1] then + local h = {} + for _, e in ipairs(list) do + h[e] = true + end + return h + else + return list + end + elseif type(list) == 'string' then + local h = {} + h[list] = true + return h + end +end + +exports.list_to_hash = list_to_hash + +--[[[ +-- @function lua_util.nkeys(table|gen, param, state) +-- Returns number of keys in a table (i.e. 
from both the array and hash parts combined) +-- @param {table} list numerically-indexed table or string, which is treated as a one-element list +-- @return {number} number of keys +-- @example +-- print(lua_util.nkeys({})) -- 0 +-- print(lua_util.nkeys({ "a", nil, "b" })) -- 2 +-- print(lua_util.nkeys({ dog = 3, cat = 4, bird = nil })) -- 2 +-- print(lua_util.nkeys({ "a", dog = 3, cat = 4 })) -- 3 +-- +--]] +local function nkeys(gen, param, state) + local n = 0 + if not param then + for _, _ in pairs(gen) do + n = n + 1 + end + else + for _, _ in fun.iter(gen, param, state) do + n = n + 1 + end + end + return n +end + +exports.nkeys = nkeys + +--[[[ +-- @function lua_util.parse_time_interval(str) +-- Parses human readable time interval +-- Accepts 's' for seconds, 'm' for minutes, 'h' for hours, 'd' for days, +-- 'w' for weeks, 'y' for years +-- @param {string} str input string +-- @return {number|nil} parsed interval as seconds (might be fractional) +--]] +local function parse_time_interval(str) + local function parse_time_suffix(s) + if s == 's' then + return 1 + elseif s == 'm' then + return 60 + elseif s == 'h' then + return 3600 + elseif s == 'd' then + return 86400 + elseif s == 'w' then + return 86400 * 7 + elseif s == 'y' then + return 365 * 86400; + end + end + + local digit = lpeg.R("09") + local parser = {} + parser.integer = (lpeg.S("+-") ^ -1) * + (digit ^ 1) + parser.fractional = (lpeg.P(".")) * + (digit ^ 1) + parser.number = (parser.integer * + (parser.fractional ^ -1)) + + (lpeg.S("+-") * parser.fractional) + parser.time = lpeg.Cf(lpeg.Cc(1) * + (parser.number / tonumber) * + ((lpeg.S("smhdwy") / parse_time_suffix) ^ -1), + function(acc, val) + return acc * val + end) + + local t = lpeg.match(parser.time, str) + + return t +end + +exports.parse_time_interval = parse_time_interval + +--[[[ +-- @function lua_util.dehumanize_number(str) +-- Parses human readable number +-- Accepts 'k' for thousands, 'm' for millions, 'g' for billions, 'b' suffix for 1024 multiplier, +-- e.g. 
`10mb` equal to `10 * 1024 * 1024` +-- @param {string} str input string +-- @return {number|nil} parsed number +--]] +local function dehumanize_number(str) + local function parse_suffix(s) + if s == 'k' then + return 1000 + elseif s == 'm' then + return 1000000 + elseif s == 'g' then + return 1e9 + elseif s == 'kb' then + return 1024 + elseif s == 'mb' then + return 1024 * 1024 + elseif s == 'gb' then + return 1024 * 1024; + end + end + + local digit = lpeg.R("09") + local parser = {} + parser.integer = (lpeg.S("+-") ^ -1) * + (digit ^ 1) + parser.fractional = (lpeg.P(".")) * + (digit ^ 1) + parser.number = (parser.integer * + (parser.fractional ^ -1)) + + (lpeg.S("+-") * parser.fractional) + parser.humanized_number = lpeg.Cf(lpeg.Cc(1) * + (parser.number / tonumber) * + (((lpeg.S("kmg") * (lpeg.P("b") ^ -1)) / parse_suffix) ^ -1), + function(acc, val) + return acc * val + end) + + local t = lpeg.match(parser.humanized_number, str) + + return t +end + +exports.dehumanize_number = dehumanize_number + +--[[[ +-- @function lua_util.table_cmp(t1, t2) +-- Compare two tables deeply +--]] +local function table_cmp(table1, table2) + local avoid_loops = {} + local function recurse(t1, t2) + if type(t1) ~= type(t2) then + return false + end + if type(t1) ~= "table" then + return t1 == t2 + end + + if avoid_loops[t1] then + return avoid_loops[t1] == t2 + end + avoid_loops[t1] = t2 + -- Copy keys from t2 + local t2keys = {} + local t2tablekeys = {} + for k, _ in pairs(t2) do + if type(k) == "table" then + table.insert(t2tablekeys, k) + end + t2keys[k] = true + end + -- Let's iterate keys from t1 + for k1, v1 in pairs(t1) do + local v2 = t2[k1] + if type(k1) == "table" then + -- if key is a table, we need to find an equivalent one. + local ok = false + for i, tk in ipairs(t2tablekeys) do + if table_cmp(k1, tk) and recurse(v1, t2[tk]) then + table.remove(t2tablekeys, i) + t2keys[tk] = nil + ok = true + break + end + end + if not ok then + return false + end + else + -- t1 has a key which t2 doesn't have, fail. + if v2 == nil then + return false + end + t2keys[k1] = nil + if not recurse(v1, v2) then + return false + end + end + end + -- if t2 has a key which t1 doesn't have, fail. 
+ if next(t2keys) then + return false + end + return true + end + return recurse(table1, table2) +end + +exports.table_cmp = table_cmp + +--[[[ +-- @function lua_util.table_merge(t1, t2) +-- Merge two tables +--]] +local function table_merge(t1, t2) + local res = {} + local nidx = 1 -- for numeric indicies + local it_func = function(k, v) + if type(k) == 'number' then + res[nidx] = v + nidx = nidx + 1 + else + res[k] = v + end + end + for k, v in pairs(t1) do + it_func(k, v) + end + for k, v in pairs(t2) do + it_func(k, v) + end + return res +end + +exports.table_merge = table_merge + +--[[[ +-- @function lua_util.table_cmp(task, name, value, stop_chars) +-- Performs header folding +--]] +exports.fold_header = function(task, name, value, stop_chars) + + local how + + if task:has_flag("milter") then + how = "lf" + else + how = task:get_newlines_type() + end + + return rspamd_util.fold_header(name, value, how, stop_chars) +end + +--[[[ +-- @function lua_util.override_defaults(defaults, override) +-- Overrides values from defaults with override +--]] +local function override_defaults(def, override) + -- Corner cases + if not override or type(override) ~= 'table' then + return def + end + if not def or type(def) ~= 'table' then + return override + end + + local res = {} + + for k, v in pairs(override) do + if type(v) == 'table' then + if def[k] and type(def[k]) == 'table' then + -- Recursively override elements + res[k] = override_defaults(def[k], v) + else + res[k] = v + end + else + res[k] = v + end + end + + for k, v in pairs(def) do + if type(res[k]) == 'nil' then + res[k] = v + end + end + + return res +end + +exports.override_defaults = override_defaults + +--[[[ +-- @function lua_util.filter_specific_urls(urls, params) +-- params: { +- - task - if needed to save in the cache +- - limit <int> (default = 9999) +- - esld_limit <int> (default = 9999) n domains per eSLD (effective second level domain) + works only if number of unique eSLD less than `limit` +- - need_emails <bool> (default = false) +- - filter <callback> (default = nil) +- - prefix <string> cache prefix (default = nil) +-- } +-- Apply heuristic in extracting of urls from `urls` table, this function +-- tries its best to extract specific number of urls from a task based on +-- their characteristics +--]] +exports.filter_specific_urls = function(urls, params) + local cache_key + + if params.task and not params.no_cache then + if params.prefix then + cache_key = params.prefix + else + cache_key = string.format('sp_urls_%d%s%s%s', params.limit, + tostring(params.need_emails or false), + tostring(params.need_images or false), + tostring(params.need_content or false)) + end + local cached = params.task:cache_get(cache_key) + + if cached then + return cached + end + end + + if not urls then + return {} + end + + if params.filter then + urls = fun.totable(fun.filter(params.filter, urls)) + end + + -- Filter by tld: + local tlds = {} + local eslds = {} + local ntlds, neslds = 0, 0 + + local res = {} + local nres = 0 + + local function insert_url(str, u) + if not res[str] then + res[str] = u + nres = nres + 1 + + return true + end + + return false + end + + local function process_single_url(u, default_priority) + local priority = default_priority or 1 -- Normal priority + local flags = u:get_flags() + if params.ignore_ip and flags.numeric then + return + end + + if flags.redirected then + local redir = u:get_redirected() -- get the real url + + if params.ignore_redirected then + -- Replace `u` with redir + u = redir + priority = 2 + 
else + -- Process both redirected url and the original one + process_single_url(redir, 2) + end + end + + if flags.image then + if not params.need_images then + -- Ignore url + return + else + -- Penalise images in urls + priority = 0 + end + end + + local esld = u:get_tld() + local str_hash = tostring(u) + + if esld then + -- Special cases + if (u:get_protocol() ~= 'mailto') and (not flags.html_displayed) then + if flags.obscured then + priority = 3 + else + if (flags.has_user or flags.has_port) then + priority = 2 + elseif (flags.subject or flags.phished) then + priority = 2 + end + end + elseif flags.html_displayed then + priority = 0 + end + + if not eslds[esld] then + eslds[esld] = { { str_hash, u, priority } } + neslds = neslds + 1 + else + if #eslds[esld] < params.esld_limit then + table.insert(eslds[esld], { str_hash, u, priority }) + end + end + + + -- eSLD - 1 part => tld + local parts = rspamd_str_split(esld, '.') + local tld = table.concat(fun.totable(fun.tail(parts)), '.') + + if not tlds[tld] then + tlds[tld] = { { str_hash, u, priority } } + ntlds = ntlds + 1 + else + table.insert(tlds[tld], { str_hash, u, priority }) + end + end + end + + for _, u in ipairs(urls) do + process_single_url(u) + end + + local limit = params.limit + limit = limit - nres + if limit < 0 then + limit = 0 + end + + if limit == 0 then + res = exports.values(res) + if params.task and not params.no_cache then + params.task:cache_set(cache_key, res) + end + return res + end + + -- Sort eSLDs and tlds + local function sort_stuff(tbl) + -- Sort according to max priority + table.sort(tbl, function(e1, e2) + -- Sort by priority so max priority is at the end + table.sort(e1, function(tr1, tr2) + return tr1[3] < tr2[3] + end) + table.sort(e2, function(tr1, tr2) + return tr1[3] < tr2[3] + end) + + if e1[#e1][3] ~= e2[#e2][3] then + -- Sort by priority so max priority is at the beginning + return e1[#e1][3] > e2[#e2][3] + else + -- Prefer less urls to more urls per esld + return #e1 < #e2 + end + + end) + + return tbl + end + + eslds = sort_stuff(exports.values(eslds)) + neslds = #eslds + + if neslds <= limit then + -- Number of eslds < limit + repeat + local item_found = false + + for _, lurls in ipairs(eslds) do + if #lurls > 0 then + local last = table.remove(lurls) + insert_url(last[1], last[2]) + limit = limit - 1 + item_found = true + end + end + + until limit <= 0 or not item_found + + res = exports.values(res) + if params.task and not params.no_cache then + params.task:cache_set(cache_key, res) + end + return res + end + + tlds = sort_stuff(exports.values(tlds)) + ntlds = #tlds + + -- Number of tlds < limit + while limit > 0 do + for _, lurls in ipairs(tlds) do + if #lurls > 0 then + local last = table.remove(lurls) + insert_url(last[1], last[2]) + limit = limit - 1 + end + if limit == 0 then + break + end + end + end + + res = exports.values(res) + if params.task and not params.no_cache then + params.task:cache_set(cache_key, res) + end + return res +end + +--[[[ +-- @function lua_util.extract_specific_urls(params) +-- params: { +- - task +- - limit <int> (default = 9999) +- - esld_limit <int> (default = 9999) n domains per eSLD (effective second level domain) + works only if number of unique eSLD less than `limit` +- - need_emails <bool> (default = false) +- - filter <callback> (default = nil) +- - prefix <string> cache prefix (default = nil) +- - ignore_redirected <bool> (default = false) +- - need_images <bool> (default = false) +- - need_content <bool> (default = false) +-- } +-- Apply heuristic in 
extracting of urls from task, this function +-- tries its best to extract specific number of urls from a task based on +-- their characteristics +--]] +-- exports.extract_specific_urls = function(params_or_task, limit, need_emails, filter, prefix) +exports.extract_specific_urls = function(params_or_task, lim, need_emails, filter, prefix) + local default_params = { + limit = 9999, + esld_limit = 9999, + need_emails = false, + need_images = false, + need_content = false, + filter = nil, + prefix = nil, + ignore_ip = false, + ignore_redirected = false, + no_cache = false, + } + + local params + if type(params_or_task) == 'table' and type(lim) == 'nil' then + params = params_or_task + else + -- Deprecated call + params = { + task = params_or_task, + limit = lim, + need_emails = need_emails, + filter = filter, + prefix = prefix + } + end + for k, v in pairs(default_params) do + if type(params[k]) == 'nil' and v ~= nil then + params[k] = v + end + end + local url_params = { + emails = params.need_emails, + images = params.need_images, + content = params.need_content, + flags = params.flags, -- maybe nil + flags_mode = params.flags_mode, -- maybe nil + } + + -- Shortcut for cached stuff + if params.task and not params.no_cache then + local cache_key + if params.prefix then + cache_key = params.prefix + else + local cache_key_suffix + if params.flags then + cache_key_suffix = table.concat(params.flags) .. (params.flags_mode or '') + else + cache_key_suffix = string.format('%s%s%s', + tostring(params.need_emails or false), + tostring(params.need_images or false), + tostring(params.need_content or false)) + end + cache_key = string.format('sp_urls_%d%s', params.limit, cache_key_suffix) + end + local cached = params.task:cache_get(cache_key) + + if cached then + return cached + end + end + + -- No cache version + local urls = params.task:get_urls(url_params) + + return exports.filter_specific_urls(urls, params) +end + +--[[[ +-- @function lua_util.deepcopy(table) +-- params: { +- - table +-- } +-- Performs deep copy of the table. 
Including metatables +--]] +local function deepcopy(orig) + local orig_type = type(orig) + local copy + if orig_type == 'table' then + copy = {} + for orig_key, orig_value in next, orig, nil do + copy[deepcopy(orig_key)] = deepcopy(orig_value) + end + if getmetatable(orig) then + setmetatable(copy, deepcopy(getmetatable(orig))) + end + else + -- number, string, boolean, etc + copy = orig + end + return copy +end + +exports.deepcopy = deepcopy + +--[[[ +-- @function lua_util.deepsort(table) +-- params: { +- - table +-- } +-- Performs recursive in-place sort of a table +--]] +local function default_sort_cmp(e1, e2) + if type(e1) == type(e2) then + return e1 < e2 + else + return type(e1) < type(e2) + end +end + +local function deepsort(tbl, sort_func) + local orig_type = type(tbl) + if orig_type == 'table' then + table.sort(tbl, sort_func or default_sort_cmp) + for _, orig_value in next, tbl, nil do + deepsort(orig_value) + end + end +end + +exports.deepsort = deepsort + +--[[[ +-- @function lua_util.shallowcopy(tbl) +-- Performs shallow (and fast) copy of a table or another Lua type +--]] +exports.shallowcopy = function(orig) + local orig_type = type(orig) + local copy + if orig_type == 'table' then + copy = {} + for orig_key, orig_value in pairs(orig) do + copy[orig_key] = orig_value + end + else + copy = orig + end + return copy +end + +-- Debugging support +local logger = require "rspamd_logger" +local unconditional_debug = logger.log_level() == 'debug' +local debug_modules = {} +local debug_aliases = {} +local log_level = 384 -- debug + forced (1 << 7 | 1 << 8) + + +exports.init_debug_logging = function(config) + -- Fill debug modules from the config + if not unconditional_debug then + local log_config = config:get_all_opt('logging') + if log_config then + local log_level_str = log_config.level + if log_level_str then + if log_level_str == 'debug' then + unconditional_debug = true + end + end + if log_config.debug_modules then + for _, m in ipairs(log_config.debug_modules) do + debug_modules[m] = true + logger.infox(config, 'enable debug for Lua module %s', m) + end + end + + if #debug_aliases > 0 then + for alias, mod in pairs(debug_aliases) do + if debug_modules[mod] then + debug_modules[alias] = true + logger.infox(config, 'enable debug for Lua module %s (%s aliased)', + alias, mod) + end + end + end + end + end +end + +exports.enable_debug_logging = function() + unconditional_debug = true +end + +exports.enable_debug_modules = function(...) + for _, m in ipairs({ ... }) do + debug_modules[m] = true + end +end + +exports.disable_debug_logging = function() + unconditional_debug = false +end + +--[[[ +-- @function lua_util.debugm(module, [log_object], format, ...) +-- Performs fast debug log for a specific module +--]] +exports.debugm = function(mod, obj_or_fmt, fmt_or_something, ...) + if unconditional_debug or debug_modules[mod] then + if type(obj_or_fmt) == 'string' then + logger.logx(log_level, mod, '', 2, obj_or_fmt, fmt_or_something, ...) + else + logger.logx(log_level, mod, obj_or_fmt, 2, fmt_or_something, ...) 
+ end + end +end + +--[[[ +-- @function lua_util.add_debug_alias(mod, alias) +-- Add debugging alias so logging to `alias` will be treated as logging to `mod` +--]] +exports.add_debug_alias = function(mod, alias) + debug_aliases[alias] = mod + + if debug_modules[mod] then + debug_modules[alias] = true + logger.infox(rspamd_config, 'enable debug for Lua module %s (%s aliased)', + alias, mod) + end +end +---[[[ +-- @function lua_util.get_task_verdict(task) +-- Returns verdict for a task + score if certain, must be called from idempotent filters only +-- Returns string: +-- * `spam`: if message have over reject threshold and has more than one positive rule +-- * `junk`: if a message has between score between [add_header/rewrite subject] to reject thresholds and has more than two positive rules +-- * `passthrough`: if a message has been passed through some short-circuit rule +-- * `ham`: if a message has overall score below junk level **and** more than three negative rule, or negative total score +-- * `uncertain`: all other cases +--]] +exports.get_task_verdict = function(task) + local lua_verdict = require "lua_verdict" + + return lua_verdict.get_default_verdict(task) +end + +---[[[ +-- @function lua_util.maybe_obfuscate_string(subject, settings, prefix) +-- Obfuscate string if enabled in settings. Also checks utf8 validity - if +-- string is not valid utf8 then '???' is returned. Empty string returned as is. +-- Supported settings: +-- * <prefix>_privacy = false - subject privacy is off +-- * <prefix>_privacy_alg = 'blake2' - default hash-algorithm to obfuscate subject +-- * <prefix>_privacy_prefix = 'obf' - prefix to show it's obfuscated +-- * <prefix>_privacy_length = 16 - cut the length of the hash; if 0 or fasle full hash is returned +-- @return obfuscated or validated subject +--]] + +exports.maybe_obfuscate_string = function(subject, settings, prefix) + local hash = require 'rspamd_cryptobox_hash' + if not subject or subject == '' then + return subject + elseif not rspamd_util.is_valid_utf8(subject) then + subject = '???' + elseif settings[prefix .. '_privacy'] then + local hash_alg = settings[prefix .. '_privacy_alg'] or 'blake2' + local subject_hash = hash.create_specific(hash_alg, subject) + + local strip_len = settings[prefix .. '_privacy_length'] + if strip_len and strip_len > 0 then + subject = subject_hash:hex():sub(1, strip_len) + else + subject = subject_hash:hex() + end + + local privacy_prefix = settings[prefix .. '_privacy_prefix'] + if privacy_prefix and #privacy_prefix > 0 then + subject = privacy_prefix .. ':' .. subject + end + end + + return subject +end + +---[[[ +-- @function lua_util.callback_from_string(str) +-- Converts a string like `return function(...) end` to lua function and return true and this function +-- or returns false + error message +-- @return status code and function object or an error message +--]]] +exports.callback_from_string = function(s) + local loadstring = loadstring or load + + if not s or #s == 0 then + return false, 'invalid or empty string' + end + + s = exports.rspamd_str_trim(s) + local inp + + if s:match('^return%s*function') then + -- 'return function', can be evaluated directly + inp = s + elseif s:match('^function%s*%(') then + inp = 'return ' .. s + else + -- Just a plain sequence + inp = 'return function(...)\n' .. s .. 
'; end' + end + + local ret, res_or_err = pcall(loadstring(inp)) + + if not ret or type(res_or_err) ~= 'function' then + return false, res_or_err + end + + return ret, res_or_err +end + +---[[[ +-- @function lua_util.keys(t) +-- Returns all keys from a specific table +-- @param {table} t input table (or iterator triplet) +-- @return array of keys +--]]] +exports.keys = function(gen, param, state) + local keys = {} + local i = 1 + + if param then + for k, _ in fun.iter(gen, param, state) do + rawset(keys, i, k) + i = i + 1 + end + else + for k, _ in pairs(gen) do + rawset(keys, i, k) + i = i + 1 + end + end + + return keys +end + +---[[[ +-- @function lua_util.values(t) +-- Returns all values from a specific table +-- @param {table} t input table +-- @return array of values +--]]] +exports.values = function(gen, param, state) + local values = {} + local i = 1 + + if param then + for _, v in fun.iter(gen, param, state) do + rawset(values, i, v) + i = i + 1 + end + else + for _, v in pairs(gen) do + rawset(values, i, v) + i = i + 1 + end + end + + return values +end + +---[[[ +-- @function lua_util.distance_sorted(t1, t2) +-- Returns distance between two sorted tables t1 and t2 +-- @param {table} t1 input table +-- @param {table} t2 input table +-- @return distance between `t1` and `t2` +--]]] +exports.distance_sorted = function(t1, t2) + local ncomp = #t1 + local ndiff = 0 + local i, j = 1, 1 + + if ncomp < #t2 then + ncomp = #t2 + end + + for _ = 1, ncomp do + if j > #t2 then + ndiff = ndiff + ncomp - #t2 + if i > j then + ndiff = ndiff - (i - j) + end + break + elseif i > #t1 then + ndiff = ndiff + ncomp - #t1 + if j > i then + ndiff = ndiff - (j - i) + end + break + end + + if t1[i] == t2[j] then + i = i + 1 + j = j + 1 + elseif t1[i] < t2[j] then + i = i + 1 + ndiff = ndiff + 1 + else + j = j + 1 + ndiff = ndiff + 1 + end + end + + return ndiff +end + +---[[[ +-- @function lua_util.table_digest(t) +-- Returns hash of all values if t[1] is string or all keys/values otherwise +-- @param {table} t input array or map +-- @return {string} base32 representation of blake2b hash of all strings +--]]] +local function table_digest(t) + local cr = require "rspamd_cryptobox_hash" + local h = cr.create() + + if t[1] then + for _, e in ipairs(t) do + if type(e) == 'table' then + h:update(table_digest(e)) + else + h:update(tostring(e)) + end + end + else + for k, v in pairs(t) do + h:update(tostring(k)) + + if type(v) == 'string' then + h:update(v) + elseif type(v) == 'table' then + h:update(table_digest(v)) + end + end + end + return h:base32() +end + +exports.table_digest = table_digest + +---[[[ +-- @function lua_util.toboolean(v) +-- Converts a string or a number to boolean +-- @param {string|number} v +-- @return {boolean} v converted to boolean +--]]] +exports.toboolean = function(v) + local true_t = { + ['1'] = true, + ['true'] = true, + ['TRUE'] = true, + ['True'] = true, + }; + local false_t = { + ['0'] = false, + ['false'] = false, + ['FALSE'] = false, + ['False'] = false, + }; + + if type(v) == 'string' then + if true_t[v] == true then + return true; + elseif false_t[v] == false then + return false; + else + return false, string.format('cannot convert %q to boolean', v); + end + elseif type(v) == 'number' then + return v ~= 0 + else + return false, string.format('cannot convert %q to boolean', v); + end +end + +---[[[ +-- @function lua_util.config_check_local_or_authed(config, modname) +-- Reads check_local and check_authed from the config as this is used in many modules +-- @param 
{rspamd_config} config `rspamd_config` global +-- @param {name} module name +-- @return {boolean} v converted to boolean +--]]] +exports.config_check_local_or_authed = function(rspamd_config, modname, def_local, def_authed) + local check_local = def_local or false + local check_authed = def_authed or false + + local function try_section(where) + local ret = false + local opts = rspamd_config:get_all_opt(where) + if type(opts) == 'table' then + if type(opts['check_local']) == 'boolean' then + check_local = opts['check_local'] + ret = true + end + if type(opts['check_authed']) == 'boolean' then + check_authed = opts['check_authed'] + ret = true + end + end + + return ret + end + + if not try_section(modname) then + try_section('options') + end + + return { check_local, check_authed } +end + +---[[[ +-- @function lua_util.is_skip_local_or_authed(task, conf[, ip]) +-- Returns `true` if local or authenticated task should be skipped for this module +-- @param {rspamd_task} task +-- @param {table} conf table returned from `config_check_local_or_authed` +-- @param {rspamd_ip} ip optional ip address (can be obtained from a task) +-- @return {boolean} true if check should be skipped +--]]] +exports.is_skip_local_or_authed = function(task, conf, ip) + if not ip then + ip = task:get_from_ip() + end + if not conf then + conf = { false, false } + end + if ((not conf[2] and task:get_user()) or + (not conf[1] and type(ip) == 'userdata' and ip:is_local())) then + return true + end + + return false +end + +---[[[ +-- @function lua_util.maybe_smtp_quote_value(str) +-- Checks string for the forbidden elements (tspecials in RFC and quote string if needed) +-- @param {string} str input string +-- @return {string} original or quoted string +--]]] +local tspecial = lpeg.S "()<>,;:\\\"/[]?= \t\v" +local special_match = lpeg.P((1 - tspecial) ^ 0 * tspecial ^ 1) +exports.maybe_smtp_quote_value = function(str) + if special_match:match(str) then + return string.format('"%s"', str:gsub('"', '\\"')) + end + + return str +end + +---[[[ +-- @function lua_util.shuffle(table) +-- Performs in-place shuffling of a table +-- @param {table} tbl table to shuffle +-- @return {table} same table +--]]] +exports.shuffle = function(tbl) + local size = #tbl + for i = size, 1, -1 do + local rand = math.random(size) + tbl[i], tbl[rand] = tbl[rand], tbl[i] + end + return tbl +end + +-- +local hex_table = {} +for idx = 0, 255 do + hex_table[("%02X"):format(idx)] = string.char(idx) + hex_table[("%02x"):format(idx)] = string.char(idx) +end + +---[[[ +-- @function lua_util.unhex(str) +-- Decode hex encoded string +-- @param {string} str string to decode +-- @return {string} hex decoded string (valid hex pairs are decoded, everything else is printed as is) +--]]] +exports.unhex = function(str) + return str:gsub('(..)', hex_table) +end + +local http_upstream_lists = {} +local function http_upstreams_by_url(pool, url) + local rspamd_url = require "rspamd_url" + + local cached = http_upstream_lists[url] + if cached then + return cached + end + + local real_url = rspamd_url.create(pool, url) + + if not real_url then + return nil + end + + local host = real_url:get_host() + local proto = real_url:get_protocol() or 'http' + local port = real_url:get_port() or (proto == 'https' and 443 or 80) + local upstream_list = require "rspamd_upstream_list" + local upstreams = upstream_list.create(host, port) + + if upstreams then + http_upstream_lists[url] = upstreams + return upstreams + end + + return nil +end +---[[[ +-- @function 
lua_util.http_upstreams_by_url(pool, url) +-- Returns a cached or new upstreams list that corresponds to the specific url +-- @param {mempool} pool memory pool to use (typically static pool from rspamd_config) +-- @param {string} url full url +-- @return {upstreams_list} object to get upstream from an url +--]]] +exports.http_upstreams_by_url = http_upstreams_by_url + +---[[[ +-- @function lua_util.dns_timeout_augmentation(cfg) +-- Returns an augmentation suitable to define DNS timeout for a module +-- @return {string} a string in format 'timeout=x' where `x` is a number of seconds for DNS timeout +--]]] +local function dns_timeout_augmentation(cfg) + return string.format('timeout=%f', cfg:get_dns_timeout() or 0.0) +end + +exports.dns_timeout_augmentation = dns_timeout_augmentation + +---[[[ +--- @function lua_util.strip_lua_comments(lua_code) +-- Strips single-line and multi-line comments from a given Lua code string and removes +-- any extra spaces or newlines. +-- +-- @param lua_code The Lua code string to strip comments from. +-- @return The resulting Lua code string with comments and extra spaces removed. +-- +---]]] +local function strip_lua_comments(lua_code) + -- Remove single-line comments + lua_code = lua_code:gsub("%-%-[^\r\n]*", "") + + -- Remove multi-line comments + lua_code = lua_code:gsub("%-%-%[%[.-%]%]", "") + + -- Remove extra spaces and newlines + lua_code = lua_code:gsub("%s+", " ") + + return lua_code +end + +exports.strip_lua_comments = strip_lua_comments + +---[[[ +-- @function lua_util.join_path(...) +-- Joins path components into a single path string using the appropriate separator +-- for the current operating system. +-- +-- @param ... Any number of path components to join together. +-- @return A single path string, with components separated by the appropriate separator. +-- +---]]] +local path_sep = package.config:sub(1, 1) or '/' +local function join_path(...) + local components = { ... } + + -- Join components using separator + return table.concat(components, path_sep) +end +exports.join_path = join_path + +-- Short unit test for sanity +if path_sep == '/' then + assert(join_path('/path', 'to', 'file') == '/path/to/file') +else + assert(join_path('C:', 'path', 'to', 'file') == 'C:\\path\\to\\file') +end + +-- Defines symbols priorities for common usage in prefilters/postfilters +exports.symbols_priorities = { + top = 10, -- Symbols must be executed first (or last), such as settings + high = 9, -- Example: asn + medium = 5, -- Everything should use this as default + low = 0, +} + +return exports diff --git a/lualib/lua_verdict.lua b/lualib/lua_verdict.lua new file mode 100644 index 0000000..6ce99e6 --- /dev/null +++ b/lualib/lua_verdict.lua @@ -0,0 +1,208 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +local exports = {} + +---[[[ +-- @function lua_verdict.get_default_verdict(task) +-- Returns verdict for a task + score if certain, must be called from idempotent filters only +-- Returns string: +-- * `spam`: if message have over reject threshold and has more than one positive rule +-- * `junk`: if a message has between score between [add_header/rewrite subject] to reject thresholds and has more than two positive rules +-- * `passthrough`: if a message has been passed through some short-circuit rule +-- * `ham`: if a message has overall score below junk level **and** more than three negative rule, or negative total score +-- * `uncertain`: all other cases +--]] +local function default_verdict_function(task) + local result = task:get_metric_result() + + if result then + + if result.passthrough then + return 'passthrough', nil + end + + local score = result.score + + local action = result.action + + if action == 'reject' and result.npositive > 1 then + return 'spam', score + elseif action == 'no action' then + if score < 0 or result.nnegative > 3 then + return 'ham', score + end + else + -- All colors of junk + if action == 'add header' or action == 'rewrite subject' then + if result.npositive > 2 then + return 'junk', score + end + end + end + + return 'uncertain', score + end +end + +local default_possible_verdicts = { + passthrough = { + can_learn = false, + description = 'message has passthrough result', + }, + spam = { + can_learn = 'spam', + description = 'message is likely spam', + }, + junk = { + can_learn = 'spam', + description = 'message is likely possible spam', + }, + ham = { + can_learn = 'ham', + description = 'message is likely ham', + }, + uncertain = { + can_learn = false, + description = 'not certainty in verdict' + } +} + +-- Verdict functions specific for modules +local specific_verdicts = { + default = { + callback = default_verdict_function, + possible_verdicts = default_possible_verdicts + } +} + +local default_verdict = specific_verdicts.default + +exports.get_default_verdict = default_verdict.callback +exports.set_verdict_function = function(func, what) + assert(type(func) == 'function') + if not what then + -- Default verdict + local existing = specific_verdicts.default.callback + specific_verdicts.default.callback = func + exports.get_default_verdict = func + + return existing + else + local existing = specific_verdicts[what] + + if not existing then + specific_verdicts[what] = { + callback = func, + possible_verdicts = default_possible_verdicts + } + else + existing = existing.callback + end + + specific_verdicts[what].callback = func + return existing + end +end + +exports.set_verdict_table = function(verdict_tbl, what) + assert(type(verdict_tbl) == 'table' and + type(verdict_tbl.callback) == 'function' and + type(verdict_tbl.possible_verdicts) == 'table') + + if not what then + -- Default verdict + local existing = specific_verdicts.default + specific_verdicts.default = verdict_tbl + exports.get_default_verdict = specific_verdicts.default.callback + + return existing + else + local existing = specific_verdicts[what] + specific_verdicts[what] = verdict_tbl + return existing + end +end + +exports.get_specific_verdict = function(what, task) + if specific_verdicts[what] then + return specific_verdicts[what].callback(task) + end + + return exports.get_default_verdict(task) +end + +exports.get_possible_verdicts = function(what) + local lua_util = require "lua_util" + if what then + if specific_verdicts[what] then + return 
lua_util.keys(specific_verdicts[what].possible_verdicts) + end + else + return lua_util.keys(specific_verdicts.default.possible_verdicts) + end + + return nil +end + +exports.can_learn = function(verdict, what) + if what then + if specific_verdicts[what] and specific_verdicts[what].possible_verdicts[verdict] then + return specific_verdicts[what].possible_verdicts[verdict].can_learn + end + else + if specific_verdicts.default.possible_verdicts[verdict] then + return specific_verdicts.default.possible_verdicts[verdict].can_learn + end + end + + return nil -- To distinguish from `false` that could happen in can_learn +end + +exports.describe = function(verdict, what) + if what then + if specific_verdicts[what] and specific_verdicts[what].possible_verdicts[verdict] then + return specific_verdicts[what].possible_verdicts[verdict].description + end + else + if specific_verdicts.default.possible_verdicts[verdict] then + return specific_verdicts.default.possible_verdicts[verdict].description + end + end + + return nil +end + +---[[[ +-- @function lua_verdict.adjust_passthrough_action(task) +-- If an action is `soft reject` then this function extracts a module that has set this action +-- and returns an adjusted action (e.g. 'greylist' or 'ratelimit'). +-- Otherwise an action is returned as is. +--]] +exports.adjust_passthrough_action = function(task) + local action = task:get_metric_action() + if action == 'soft reject' then + local has_pr, _, _, module = task:has_pre_result() + + if has_pr and module then + action = module + end + end + + return action +end + +return exports
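-- A minimal usage sketch (not part of the committed file): consuming the verdict
-- API from an idempotent symbol, as the docs above require. The symbol name is a
-- hypothetical placeholder; only the lua_verdict calls mirror the exports above.
local rspamd_logger = require "rspamd_logger"
local lua_verdict = require "lua_verdict"

rspamd_config:register_symbol {
  name = 'VERDICT_LOGGER', -- hypothetical logging symbol
  type = 'idempotent',
  callback = function(task)
    local verdict, score = lua_verdict.get_default_verdict(task)

    rspamd_logger.infox(task, 'verdict: %s (%s); score: %s; can learn as: %s',
        verdict, lua_verdict.describe(verdict), score,
        lua_verdict.can_learn(verdict))
  end,
}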
\ No newline at end of file diff --git a/lualib/plugins/dmarc.lua b/lualib/plugins/dmarc.lua new file mode 100644 index 0000000..7791f4e --- /dev/null +++ b/lualib/plugins/dmarc.lua @@ -0,0 +1,359 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2015-2016, Andrew Lewis <nerf@judo.za.org> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +-- Common dmarc stuff +local rspamd_logger = require "rspamd_logger" +local lua_util = require "lua_util" +local N = "dmarc" + +local exports = {} + +exports.default_settings = { + auth_and_local_conf = false, + symbols = { + spf_allow_symbol = 'R_SPF_ALLOW', + spf_deny_symbol = 'R_SPF_FAIL', + spf_softfail_symbol = 'R_SPF_SOFTFAIL', + spf_neutral_symbol = 'R_SPF_NEUTRAL', + spf_tempfail_symbol = 'R_SPF_DNSFAIL', + spf_permfail_symbol = 'R_SPF_PERMFAIL', + spf_na_symbol = 'R_SPF_NA', + + dkim_allow_symbol = 'R_DKIM_ALLOW', + dkim_deny_symbol = 'R_DKIM_REJECT', + dkim_tempfail_symbol = 'R_DKIM_TEMPFAIL', + dkim_na_symbol = 'R_DKIM_NA', + dkim_permfail_symbol = 'R_DKIM_PERMFAIL', + + -- DMARC symbols + allow = 'DMARC_POLICY_ALLOW', + badpolicy = 'DMARC_BAD_POLICY', + dnsfail = 'DMARC_DNSFAIL', + na = 'DMARC_NA', + reject = 'DMARC_POLICY_REJECT', + softfail = 'DMARC_POLICY_SOFTFAIL', + quarantine = 'DMARC_POLICY_QUARANTINE', + }, + no_sampling_domains = nil, + no_reporting_domains = nil, + reporting = { + report_local_controller = false, -- Store reports for local/controller scans (for testing only) + redis_keys = { + index_prefix = 'dmarc_idx', + report_prefix = 'dmarc_rpt', + join_char = ';', + }, + helo = 'rspamd.localhost', + smtp = '127.0.0.1', + smtp_port = 25, + retries = 2, + from_name = 'Rspamd', + msgid_from = 'rspamd', + enabled = false, + max_entries = 1000, + keys_expire = 172800, + only_domains = nil, + }, + actions = {}, +} + + +-- Returns a key used to be inserted into dmarc report sample +exports.dmarc_report = function(task, settings, data) + local rspamd_lua_utils = require "lua_util" + local E = {} + + local ip = task:get_from_ip() + if not ip or not ip:is_valid() then + rspamd_logger.infox(task, 'cannot store dmarc report for %s: no valid source IP', + data.domain) + return nil + end + + ip = ip:to_string() + + if rspamd_lua_utils.is_rspamc_or_controller(task) and not settings.reporting.report_local_controller then + rspamd_logger.infox(task, 'cannot store dmarc report for %s from IP %s: has come from controller/rspamc', + data.domain, ip) + return + end + + local dkim_pass = table.concat(data.dkim_results.pass or E, '|') + local dkim_fail = table.concat(data.dkim_results.fail or E, '|') + local dkim_temperror = table.concat(data.dkim_results.temperror or E, '|') + local dkim_permerror = table.concat(data.dkim_results.permerror or E, '|') + local disposition_to_return = data.disposition + local res = table.concat({ + ip, data.spf_ok, data.dkim_ok, + disposition_to_return, (data.sampled_out and 'sampled_out' or ''), data.domain, + dkim_pass, dkim_fail, dkim_temperror, dkim_permerror, data.spf_domain, data.spf_result }, 
',') + + return res +end + +exports.gen_munging_callback = function(munging_opts, settings) + local rspamd_util = require "rspamd_util" + local lua_mime = require "lua_mime" + return function(task) + if munging_opts.mitigate_allow_only then + if not task:has_symbol(settings.symbols.allow) then + lua_util.debugm(N, task, 'skip munging, no %s symbol', + settings.symbols.allow) + -- Excepted + return + end + else + local has_dmarc = task:has_symbol(settings.symbols.allow) or + task:has_symbol(settings.symbols.quarantine) or + task:has_symbol(settings.symbols.reject) or + task:has_symbol(settings.symbols.softfail) + + if not has_dmarc then + lua_util.debugm(N, task, 'skip munging, no %s symbol', + settings.symbols.allow) + -- Excepted + return + end + end + if munging_opts.mitigate_strict_only then + local s = task:get_symbol(settings.symbols.allow) or { [1] = {} } + local sopts = s[1].options or {} + + local seen_strict + for _, o in ipairs(sopts) do + if o == 'reject' or o == 'quarantine' then + seen_strict = true + break + end + end + + if not seen_strict then + lua_util.debugm(N, task, 'skip munging, no strict policy found in %s', + settings.symbols.allow) + -- Excepted + return + end + end + if munging_opts.munge_map_condition then + local accepted, trace = munging_opts.munge_map_condition:process(task) + if not accepted then + lua_util.debugm(N, task, 'skip munging, maps condition not satisfied: (%s)', + trace) + -- Excepted + return + end + end + -- Now, look for domain for munging + local mr = task:get_recipients({ 'mime', 'orig' }) + local rcpt_found + if mr then + for _, r in ipairs(mr) do + if r.domain and munging_opts.list_map:get_key(r.addr) then + rcpt_found = r + break + end + end + end + + if not rcpt_found then + lua_util.debugm(N, task, 'skip munging, recipients are not in list_map') + -- Excepted + return + end + + local from = task:get_from({ 'mime', 'orig' }) + + if not from or not from[1] then + lua_util.debugm(N, task, 'skip munging, from is bad') + -- Excepted + return + end + + from = from[1] + local via_user = rcpt_found.user + local via_addr = rcpt_found.addr + local via_name + + if from.name == "" then + via_name = string.format('%s via %s', from.user or 'unknown', via_user) + else + via_name = string.format('%s via %s', from.name, via_user) + end + + local hdr_encoded = rspamd_util.fold_header('From', + rspamd_util.mime_header_encode(string.format('%s <%s>', + via_name, via_addr)), task:get_newlines_type()) + local orig_from_encoded = rspamd_util.fold_header('X-Original-From', + rspamd_util.mime_header_encode(string.format('%s <%s>', + from.name or '', from.addr)), task:get_newlines_type()) + local add_hdrs = { + ['From'] = { order = 1, value = hdr_encoded }, + } + local remove_hdrs = { ['From'] = 0 } + + local nreply = from.addr + if munging_opts.reply_goes_to_list then + -- Reply-to goes to the list + nreply = via_addr + end + + if task:has_header('Reply-To') then + -- If we have reply-to header, then we need to insert an additional + -- address there + local orig_reply = task:get_header_full('Reply-To')[1] + if orig_reply.value then + nreply = string.format('%s, %s', orig_reply.value, nreply) + end + remove_hdrs['Reply-To'] = 1 + end + + add_hdrs['Reply-To'] = { order = 0, value = nreply } + + add_hdrs['X-Original-From'] = { order = 0, value = orig_from_encoded } + lua_mime.modify_headers(task, { + remove = remove_hdrs, + add = add_hdrs + }) + lua_util.debugm(N, task, 'munged DMARC header for %s: %s -> %s', + from.domain, hdr_encoded, from.addr) + 
rspamd_logger.infox(task, 'munged DMARC header for %s', from.addr) + task:insert_result('DMARC_MUNGED', 1.0, from.addr) + end +end + +local function gen_dmarc_grammar() + local lpeg = require "lpeg" + lpeg.locale(lpeg) + local space = lpeg.space ^ 0 + local name = lpeg.C(lpeg.alpha ^ 1) * space + local sep = space * (lpeg.S("\\;") * space) + (lpeg.space ^ 1) + local value = lpeg.C(lpeg.P(lpeg.graph - sep) ^ 1) + local pair = lpeg.Cg(name * "=" * space * value) * sep ^ -1 + local list = lpeg.Cf(lpeg.Ct("") * pair ^ 0, rawset) + local version = lpeg.P("v") * space * lpeg.P("=") * space * lpeg.P("DMARC1") + local record = version * sep * list + + return record +end + +local dmarc_grammar = gen_dmarc_grammar() + +local function dmarc_key_value_case(elts) + if type(elts) ~= "table" then + return elts + end + local result = {} + for k, v in pairs(elts) do + k = k:lower() + if k ~= "v" then + v = v:lower() + end + + result[k] = v + end + + return result +end + +--[[ +-- Used to check dmarc record, check elements and produce dmarc policy processed +-- result. +-- Returns: +-- false,false - record is garbage +-- false,error_message - record is invalid +-- true,policy_table - record is valid and parsed +]] +local function dmarc_check_record(log_obj, record, is_tld) + local failed_policy + local result = { + dmarc_policy = 'none' + } + + local elts = dmarc_grammar:match(record) + lua_util.debugm(N, log_obj, "got DMARC record: %s, tld_flag=%s, processed=%s", + record, is_tld, elts) + + if elts then + elts = dmarc_key_value_case(elts) + + local dkim_pol = elts['adkim'] + if dkim_pol then + if dkim_pol == 's' then + result.strict_dkim = true + elseif dkim_pol ~= 'r' then + failed_policy = 'adkim tag has invalid value: ' .. dkim_pol + return false, failed_policy + end + end + + local spf_pol = elts['aspf'] + if spf_pol then + if spf_pol == 's' then + result.strict_spf = true + elseif spf_pol ~= 'r' then + failed_policy = 'aspf tag has invalid value: ' .. spf_pol + return false, failed_policy + end + end + + local policy = elts['p'] + if policy then + if (policy == 'reject') then + result.dmarc_policy = 'reject' + elseif (policy == 'quarantine') then + result.dmarc_policy = 'quarantine' + elseif (policy ~= 'none') then + failed_policy = 'p tag has invalid value: ' .. policy + return false, failed_policy + end + end + + -- Adjust policy if we are in tld mode + local subdomain_policy = elts['sp'] + if elts['sp'] and is_tld then + result.subdomain_policy = elts['sp'] + + if (subdomain_policy == 'reject') then + result.dmarc_policy = 'reject' + elseif (subdomain_policy == 'quarantine') then + result.dmarc_policy = 'quarantine' + elseif (subdomain_policy == 'none') then + result.dmarc_policy = 'none' + elseif (subdomain_policy ~= 'none') then + failed_policy = 'sp tag has invalid value: ' .. 
subdomain_policy + return false, failed_policy + end + end + result.pct = elts['pct'] + if result.pct then + result.pct = tonumber(result.pct) + end + + if elts.rua then + result.rua = elts['rua'] + end + result.raw_elts = elts + else + return false, false -- Ignore garbage + end + + return true, result +end + +exports.dmarc_check_record = dmarc_check_record + +return exports diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua new file mode 100644 index 0000000..6e88ef2 --- /dev/null +++ b/lualib/plugins/neural.lua @@ -0,0 +1,892 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local fun = require "fun" +local lua_redis = require "lua_redis" +local lua_settings = require "lua_settings" +local lua_util = require "lua_util" +local meta_functions = require "lua_meta" +local rspamd_kann = require "rspamd_kann" +local rspamd_logger = require "rspamd_logger" +local rspamd_tensor = require "rspamd_tensor" +local rspamd_util = require "rspamd_util" +local ucl = require "ucl" + +local N = 'neural' + +-- Used in prefix to avoid wrong ANN to be loaded +local plugin_ver = '2' + +-- Module vars +local default_options = { + train = { + max_trains = 1000, + max_epoch = 1000, + max_usages = 10, + max_iterations = 25, -- Torch style + mse = 0.001, + autotrain = true, + train_prob = 1.0, + learn_threads = 1, + learn_mode = 'balanced', -- Possible values: balanced, proportional + learning_rate = 0.01, + classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias) + spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1) + ham_skip_prob = 0.0, -- proportional mode: ham skip probability + store_pool_only = false, -- store tokens in cache only (disables autotrain); + -- neural_vec_mpack stores vector of training data in messagepack neural_profile_digest stores profile digest + }, + watch_interval = 60.0, + lock_expire = 600, + learning_spawned = false, + ann_expire = 60 * 60 * 24 * 2, -- 2 days + hidden_layer_mult = 1.5, -- number of neurons in the hidden layer + roc_enabled = false, -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds. + roc_misclassification_cost = 0.5, -- Cost of misclassifying a spam message (must be 0..1). 
+ spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable) + ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable) + flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached + symbol_spam = 'NEURAL_SPAM', + symbol_ham = 'NEURAL_HAM', + max_inputs = nil, -- when PCA is used + blacklisted_symbols = {}, -- list of symbols skipped in neural processing +} + +-- Rule structure: +-- * static config fields (see `default_options`) +-- * prefix - name or defined prefix +-- * settings - table of settings indexed by settings id, -1 is used when no settings defined + +-- Rule settings element defines elements for specific settings id: +-- * symbols - static symbols profile (defined by config or extracted from symcache) +-- * name - name of settings id +-- * digest - digest of all symbols +-- * ann - dynamic ANN configuration loaded from Redis +-- * train - train data for ANN (e.g. the currently trained ANN) + +-- Settings ANN table is loaded from Redis and represents dynamic profile for ANN +-- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically +-- * version - version of ANN loaded from redis +-- * redis_key - name of ANN key in Redis +-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols) +-- * distance - distance between set.symbols and set.ann.symbols +-- * ann - kann object + +local settings = { + rules = {}, + prefix = 'rn', -- Neural network default prefix + max_profiles = 3, -- Maximum number of NN profiles stored +} + +-- Get module & Redis configuration +local module_config = rspamd_config:get_all_opt(N) +settings = lua_util.override_defaults(settings, module_config) +local redis_params = lua_redis.parse_redis_server('neural') + +local redis_lua_script_vectors_len = "neural_train_size.lua" +local redis_lua_script_maybe_invalidate = "neural_maybe_invalidate.lua" +local redis_lua_script_maybe_lock = "neural_maybe_lock.lua" +local redis_lua_script_save_unlock = "neural_save_unlock.lua" + +local redis_script_id = {} + +local function load_scripts() + redis_script_id.vectors_len = lua_redis.load_redis_script_from_file(redis_lua_script_vectors_len, + redis_params) + redis_script_id.maybe_invalidate = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_invalidate, + redis_params) + redis_script_id.maybe_lock = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_lock, + redis_params) + redis_script_id.save_unlock = lua_redis.load_redis_script_from_file(redis_lua_script_save_unlock, + redis_params) +end + +local function create_ann(n, nlayers, rule) + -- We ignore number of layers so far when using kann + local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0) + local t = rspamd_kann.layer.input(n) + t = rspamd_kann.transform.relu(t) + t = rspamd_kann.layer.dense(t, nhidden); + t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg) + return rspamd_kann.new.kann(t) +end + +-- Fills ANN data for a specific settings element +local function fill_set_ann(set, ann_key) + if not set.ann then + set.ann = { + symbols = set.symbols, + distance = 0, + digest = set.digest, + redis_key = ann_key, + version = 0, + } + end +end + +-- This function takes all inputs, applies PCA transformation and returns the final +-- PCA matrix as rspamd_tensor +local function learn_pca(inputs, max_inputs) + local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs)) + local eigenvals = scatter_matrix:eigen() 
+ -- scatter matrix is not filled with eigenvectors + lua_util.debugm(N, 'eigenvalues: %s', eigenvals) + local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1]) + for i = 1, max_inputs do + w[i] = scatter_matrix[#scatter_matrix - i + 1] + end + + lua_util.debugm(N, 'pca matrix: %s', w) + + return w +end + +-- This function computes optimal threshold using ROC for the given set of inputs. +-- Returns a threshold that minimizes: +-- alpha * (false_positive_rate) + beta * (false_negative_rate) +-- Where alpha is cost of false positive result +-- beta is cost of false negative result +local function get_roc_thresholds(ann, inputs, outputs, alpha, beta) + + -- Sorts list x and list y based on the values in list x. + local sort_relative = function(x, y) + + local r = {} + + assert(#x == #y) + local n = #x + + local a = {} + local b = {} + for i = 1, n do + r[i] = i + end + + local cmp = function(p, q) + return p < q + end + + table.sort(r, function(p, q) + return cmp(x[p], x[q]) + end) + + for i = 1, n do + a[i] = x[r[i]] + b[i] = y[r[i]] + end + + return a, b + end + + local function get_scores(nn, input_vectors) + local scores = {} + for i = 1, #inputs do + local score = nn:apply1(input_vectors[i], nn.pca)[1] + scores[#scores + 1] = score + end + + return scores + end + + local fpr = {} + local fnr = {} + local scores = get_scores(ann, inputs) + + scores, outputs = sort_relative(scores, outputs) + + local n_samples = #outputs + local n_spam = 0 + local n_ham = 0 + local ham_count_ahead = {} + local spam_count_ahead = {} + local ham_count_behind = {} + local spam_count_behind = {} + + ham_count_ahead[n_samples + 1] = 0 + spam_count_ahead[n_samples + 1] = 0 + + for i = n_samples, 1, -1 do + + if outputs[i][1] == 0 then + n_ham = n_ham + 1 + ham_count_ahead[i] = 1 + spam_count_ahead[i] = 0 + else + n_spam = n_spam + 1 + ham_count_ahead[i] = 0 + spam_count_ahead[i] = 1 + end + + ham_count_ahead[i] = ham_count_ahead[i] + ham_count_ahead[i + 1] + spam_count_ahead[i] = spam_count_ahead[i] + spam_count_ahead[i + 1] + end + + for i = 1, n_samples do + if outputs[i][1] == 0 then + ham_count_behind[i] = 1 + spam_count_behind[i] = 0 + else + ham_count_behind[i] = 0 + spam_count_behind[i] = 1 + end + + if i ~= 1 then + ham_count_behind[i] = ham_count_behind[i] + ham_count_behind[i - 1] + spam_count_behind[i] = spam_count_behind[i] + spam_count_behind[i - 1] + end + end + + for i = 1, n_samples do + fpr[i] = 0 + fnr[i] = 0 + + if (ham_count_ahead[i + 1] + ham_count_behind[i]) ~= 0 then + fpr[i] = ham_count_ahead[i + 1] / (ham_count_ahead[i + 1] + ham_count_behind[i]) + end + + if (spam_count_behind[i] + spam_count_ahead[i + 1]) ~= 0 then + fnr[i] = spam_count_behind[i] / (spam_count_behind[i] + spam_count_ahead[i + 1]) + end + end + + local p = n_spam / (n_spam + n_ham) + + local cost = {} + local min_cost_idx = 0 + local min_cost = math.huge + for i = 1, n_samples do + cost[i] = ((1 - p) * alpha * fpr[i]) + (p * beta * fnr[i]) + if min_cost >= cost[i] then + min_cost = cost[i] + min_cost_idx = i + end + end + + return scores[min_cost_idx] +end + +-- This function is intended to extend lock for ANN during training +-- It registers periodic that increases locked key each 30 seconds unless +-- `set.learning_spawned` is set to `true` +local function register_lock_extender(rule, set, ev_base, ann_key) + rspamd_config:add_periodic(ev_base, 30.0, + function() + local function redis_lock_extend_cb(err, _) + if err then + rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s', + ann_key, err) 
+ else + rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds', + ann_key) + end + end + + if set.learning_spawned then + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + redis_lock_extend_cb, --callback + 'HINCRBY', -- command + { ann_key, 'lock', '30' } + ) + else + lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false") + return false -- do not plan any more updates + end + + return true + end + ) +end + +local function can_push_train_vector(rule, task, learn_type, nspam, nham) + local train_opts = rule.train + local coin = math.random() + + if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then + rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin) + return false + end + + if train_opts.learn_mode == 'balanced' then + -- Keep balanced training set based on number of spam and ham samples + if learn_type == 'spam' then + if nspam <= train_opts.max_trains then + if nspam > nham then + -- Apply sampling + local skip_rate = 1.0 - nham / (nspam + 1) + if coin < skip_rate - train_opts.classes_bias then + rspamd_logger.infox(task, + 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored', + learn_type, + skip_rate - train_opts.classes_bias, + nspam, nham) + return false + end + end + return true + else + -- Enough learns + rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s', + learn_type, + nspam) + end + else + if nham <= train_opts.max_trains then + if nham > nspam then + -- Apply sampling + local skip_rate = 1.0 - nspam / (nham + 1) + if coin < skip_rate - train_opts.classes_bias then + rspamd_logger.infox(task, + 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored', + learn_type, + skip_rate - train_opts.classes_bias, + nspam, nham) + return false + end + end + return true + else + rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type, + nham) + end + end + else + -- Probabilistic learn mode, we just skip learn if we already have enough samples or + -- if our coin drop is less than desired probability + if learn_type == 'spam' then + if nspam <= train_opts.max_trains then + if train_opts.spam_skip_prob then + if coin <= train_opts.spam_skip_prob then + rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type, + coin, train_opts.spam_skip_prob) + return false + end + + return true + end + else + rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type, + nspam, train_opts.max_trains) + end + else + if nham <= train_opts.max_trains then + if train_opts.ham_skip_prob then + if coin <= train_opts.ham_skip_prob then + rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type, + coin, train_opts.ham_skip_prob) + return false + end + + return true + end + else + rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type, + nham, train_opts.max_trains) + end + end + end + + return false +end + +-- Closure generator for unlock function +local function gen_unlock_cb(rule, set, ann_key) + return function(err) + if err then + rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s', + rule.prefix, set.name, ann_key, err) + else + lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s', + rule.prefix, 
set.name, ann_key) + end + end +end + +-- Used to generate new ANN key for specific profile +local function new_ann_key(rule, set, version) + local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix, + rule.prefix, set.name, set.digest:sub(1, 8), tostring(version)) + + return ann_key +end + +local function redis_ann_prefix(rule, settings_name) + -- We also need to count metatokens: + local n = meta_functions.version + return string.format('%s%d_%s_%d_%s', + settings.prefix, plugin_ver, rule.prefix, n, settings_name) +end + +-- This function receives training vectors, checks them, spawn learning and saves ANN in Redis +local function spawn_train(params) + -- Check training data sanity + -- Now we need to join inputs and create the appropriate test vectors + local n = #params.set.symbols + + meta_functions.rspamd_count_metatokens() + + -- Now we can train ann + local train_ann = create_ann(params.rule.max_inputs or n, 3, params.rule) + + if #params.ham_vec + #params.spam_vec < params.rule.train.max_trains / 2 then + -- Invalidate ANN as it is definitely invalid + -- TODO: add invalidation + assert(false) + else + local inputs, outputs = {}, {} + + -- Used to show parsed vectors in a convenient format (for debugging only) + local function debug_vec(t) + local ret = {} + for i, v in ipairs(t) do + if v ~= 0 then + ret[#ret + 1] = string.format('%d=%.2f', i, v) + end + end + + return ret + end + + -- Make training set by joining vectors + -- KANN automatically shuffles those samples + -- 1.0 is used for spam and -1.0 is used for ham + -- It implies that output layer can express that (e.g. tanh output) + for _, e in ipairs(params.spam_vec) do + inputs[#inputs + 1] = e + outputs[#outputs + 1] = { 1.0 } + --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e)) + end + for _, e in ipairs(params.ham_vec) do + inputs[#inputs + 1] = e + outputs[#outputs + 1] = { -1.0 } + --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e)) + end + + -- Called in child process + local function train() + local log_thresh = params.rule.train.max_iterations / 10 + local seen_nan = false + + local function train_cb(iter, train_cost, value_cost) + if (iter * (params.rule.train.max_iterations / log_thresh)) % (params.rule.train.max_iterations) == 0 then + if train_cost ~= train_cost and not seen_nan then + -- We have nan :( try to log lot's of stuff to dig into a problem + seen_nan = true + rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s', + params.rule.prefix, params.set.name, + value_cost) + for i, e in ipairs(inputs) do + lua_util.debugm(N, rspamd_config, 'train vector %s -> %s', + debug_vec(e), outputs[i][1]) + end + end + + rspamd_logger.infox(rspamd_config, + "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s", + params.rule.prefix, params.set.name, + params.ann_key, + iter, + train_cost, + value_cost) + end + end + + lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started", + params.rule.prefix, params.set.name) + + local pca + if params.rule.max_inputs then + -- Train PCA in the main process, presumably it is not that long + lua_util.debugm(N, rspamd_config, "start PCA train for ANN %s:%s", + params.rule.prefix, params.set.name) + pca = learn_pca(inputs, params.rule.max_inputs) + end + + lua_util.debugm(N, rspamd_config, "start neural train for ANN %s:%s", + params.rule.prefix, params.set.name) + local ret, err = pcall(train_ann.train1, train_ann, + inputs, 
outputs, { + lr = params.rule.train.learning_rate, + max_epoch = params.rule.train.max_iterations, + cb = train_cb, + pca = pca + }) + + if not ret then + rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s", + params.rule.prefix, params.set.name, err) + + return nil + else + lua_util.debugm(N, rspamd_config, "finished neural train for ANN %s:%s", + params.rule.prefix, params.set.name) + end + + local roc_thresholds = {} + if params.rule.roc_enabled then + local spam_threshold = get_roc_thresholds(train_ann, + inputs, + outputs, + 1 - params.rule.roc_misclassification_cost, + params.rule.roc_misclassification_cost) + local ham_threshold = get_roc_thresholds(train_ann, + inputs, + outputs, + params.rule.roc_misclassification_cost, + 1 - params.rule.roc_misclassification_cost) + roc_thresholds = { spam_threshold, ham_threshold } + + rspamd_logger.messagex("ROC thresholds: (spam_threshold: %s, ham_threshold: %s)", + roc_thresholds[1], roc_thresholds[2]) + end + + if not seen_nan then + -- Convert to strings as ucl cannot rspamd_text properly + local pca_data + if pca then + pca_data = tostring(pca:save()) + end + local out = { + ann_data = tostring(train_ann:save()), + pca_data = pca_data, + roc_thresholds = roc_thresholds, + } + + local final_data = ucl.to_format(out, 'msgpack') + lua_util.debugm(N, rspamd_config, "subprocess for ANN %s:%s returned %s bytes", + params.rule.prefix, params.set.name, #final_data) + return final_data + else + return nil + end + end + + params.set.learning_spawned = true + + local function redis_save_cb(err) + if err then + rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s', + params.rule.prefix, params.set.name, params.ann_key, err) + lua_redis.redis_make_request_taskless(params.ev_base, + rspamd_config, + params.rule.redis, + nil, + false, -- is write + gen_unlock_cb(params.rule, params.set, params.ann_key), --callback + 'HDEL', -- command + { params.ann_key, 'lock' } + ) + else + rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s', + params.rule.prefix, params.set.name, params.set.ann.redis_key) + end + end + + local function ann_trained(err, data) + params.set.learning_spawned = false + if err then + rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s', + params.rule.prefix, params.set.name, err) + lua_redis.redis_make_request_taskless(params.ev_base, + rspamd_config, + params.rule.redis, + nil, + true, -- is write + gen_unlock_cb(params.rule, params.set, params.ann_key), --callback + 'HDEL', -- command + { params.ann_key, 'lock' } + ) + else + local parser = ucl.parser() + local ok, parse_err = parser:parse_text(data, 'msgpack') + assert(ok, parse_err) + local parsed = parser:get_object() + local ann_data = rspamd_util.zstd_compress(parsed.ann_data) + local pca_data = parsed.pca_data + local roc_thresholds = parsed.roc_thresholds + + fill_set_ann(params.set, params.ann_key) + if pca_data then + params.set.ann.pca = rspamd_tensor.load(pca_data) + pca_data = rspamd_util.zstd_compress(pca_data) + end + + if roc_thresholds then + params.set.ann.roc_thresholds = roc_thresholds + end + + + -- Deserialise ANN from the child process + ann_trained = rspamd_kann.load(parsed.ann_data) + local version = (params.set.ann.version or 0) + 1 + params.set.ann.version = version + params.set.ann.ann = ann_trained + params.set.ann.symbols = params.set.symbols + params.set.ann.redis_key = new_ann_key(params.rule, params.set, version) + + local profile = { + symbols = params.set.symbols, + digest = params.set.digest, + 
redis_key = params.set.ann.redis_key, + version = version + } + + local profile_serialized = ucl.to_format(profile, 'json-compact', true) + local roc_thresholds_serialized = ucl.to_format(roc_thresholds, 'json-compact', true) + + rspamd_logger.infox(rspamd_config, + 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)', + params.rule.prefix, params.set.name, + #data, #ann_data, + #(params.set.ann.pca or {}), #(pca_data or {}), + params.set.ann.redis_key, params.ann_key) + + lua_redis.exec_redis_script(redis_script_id.save_unlock, + { ev_base = params.ev_base, is_write = true }, + redis_save_cb, + { profile.redis_key, + redis_ann_prefix(params.rule, params.set.name), + ann_data, + profile_serialized, + tostring(params.rule.ann_expire), + tostring(os.time()), + params.ann_key, -- old key to unlock... + roc_thresholds_serialized, + pca_data, + }) + end + end + + if params.rule.max_inputs then + fill_set_ann(params.set, params.ann_key) + end + + params.worker:spawn_process { + func = train, + on_complete = ann_trained, + proctitle = string.format("ANN train for %s/%s", params.rule.prefix, params.set.name), + } + -- Spawn learn and register lock extension + params.set.learning_spawned = true + register_lock_extender(params.rule, params.set, params.ev_base, params.ann_key) + return + + end +end + +-- This function is used to adjust profiles and allowed setting ids for each rule +-- It must be called when all settings are already registered (e.g. at post-init for config) +local function process_rules_settings() + local function process_settings_elt(rule, selt) + local profile = rule.profile[selt.name] + if profile then + -- Use static user defined profile + -- Ensure that we have an array... + lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s", + rule.prefix, selt.name, profile) + if not profile[1] then + profile = lua_util.keys(profile) + end + selt.symbols = profile + else + lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)", + rule.prefix, selt.name) + end + + local function filter_symbols_predicate(sname) + if settings.blacklisted_symbols and settings.blacklisted_symbols[sname] then + return false + end + local fl = rspamd_config:get_symbol_flags(sname) + if fl then + fl = lua_util.list_to_hash(fl) + + return not (fl.nostat or fl.idempotent or fl.skip or fl.composite) + end + + return false + end + + -- Generic stuff + if not profile then + -- Do filtering merely if we are using a dynamic profile + selt.symbols = fun.totable(fun.filter(filter_symbols_predicate, selt.symbols)) + end + + table.sort(selt.symbols) + + selt.digest = lua_util.table_digest(selt.symbols) + selt.prefix = redis_ann_prefix(rule, selt.name) + + rspamd_logger.messagex(rspamd_config, + 'use NN prefix for rule %s; settings id "%s"; symbols digest: "%s"', + selt.prefix, selt.name, selt.digest) + + lua_redis.register_prefix(selt.prefix, N, + string.format('NN prefix for rule "%s"; settings id "%s"', + selt.prefix, selt.name), { + persistent = true, + type = 'zlist', + }) + -- Versions + lua_redis.register_prefix(selt.prefix .. '_\\d+', N, + string.format('NN storage for rule "%s"; settings id "%s"', + selt.prefix, selt.name), { + persistent = true, + type = 'hash', + }) + lua_redis.register_prefix(selt.prefix .. '_\\d+_spam_set', N, + string.format('NN learning set (spam) for rule "%s"; settings id "%s"', + selt.prefix, selt.name), { + persistent = true, + type = 'set', + }) + lua_redis.register_prefix(selt.prefix .. 
'_\\d+_ham_set', N, + string.format('NN learning set (spam) for rule "%s"; settings id "%s"', + rule.prefix, selt.name), { + persistent = true, + type = 'set', + }) + end + + for k, rule in pairs(settings.rules) do + if not rule.allowed_settings then + rule.allowed_settings = {} + elseif rule.allowed_settings == 'all' then + -- Extract all settings ids + rule.allowed_settings = lua_util.keys(lua_settings.all_settings()) + end + + -- Convert to a map <setting_id> -> true + rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings) + + -- Check if we can work without settings + if k == 'default' or type(rule.default) ~= 'boolean' then + rule.default = true + end + + rule.settings = {} + + if rule.default then + local default_settings = { + symbols = lua_settings.default_symbols(), + name = 'default' + } + + process_settings_elt(rule, default_settings) + rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32 + end + + -- Now, for each allowed settings, we store sorted symbols + digest + -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest } + for s, _ in pairs(rule.allowed_settings) do + -- Here, we have a name, set of symbols and + local settings_id = s + if type(settings_id) ~= 'number' then + settings_id = lua_settings.numeric_settings_id(s) + end + local selt = lua_settings.settings_by_id(settings_id) + + local nelt = { + symbols = selt.symbols, -- Already sorted + name = selt.name + } + + process_settings_elt(rule, nelt) + for id, ex in pairs(rule.settings) do + if type(ex) == 'table' then + if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then + -- Equal symbols, add reference + lua_util.debugm(N, rspamd_config, + 'added reference from settings id %s to %s; same symbols', + nelt.name, ex.name) + rule.settings[settings_id] = id + nelt = nil + end + end + end + + if nelt then + rule.settings[settings_id] = nelt + lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s', + nelt.name, settings_id, rule.prefix) + end + end + end +end + +-- Extract settings element for a specific settings id +local function get_rule_settings(task, rule) + local sid = task:get_settings_id() or -1 + local set = rule.settings[sid] + + if not set then + return nil + end + + while type(set) == 'number' do + -- Reference to another settings! 
+ set = rule.settings[set] + end + + return set +end + +local function result_to_vector(task, profile) + if not profile.zeros then + -- Fill zeros vector + local zeros = {} + for i = 1, meta_functions.count_metatokens() do + zeros[i] = 0.0 + end + for _, _ in ipairs(profile.symbols) do + zeros[#zeros + 1] = 0.0 + end + profile.zeros = zeros + end + + local vec = lua_util.shallowcopy(profile.zeros) + local mt = meta_functions.rspamd_gen_metatokens(task) + + for i, v in ipairs(mt) do + vec[i] = v + end + + task:process_ann_tokens(profile.symbols, vec, #mt, 0.1) + + return vec +end + +return { + can_push_train_vector = can_push_train_vector, + create_ann = create_ann, + default_options = default_options, + gen_unlock_cb = gen_unlock_cb, + get_rule_settings = get_rule_settings, + load_scripts = load_scripts, + module_config = module_config, + new_ann_key = new_ann_key, + plugin_ver = plugin_ver, + process_rules_settings = process_rules_settings, + redis_ann_prefix = redis_ann_prefix, + redis_params = redis_params, + redis_script_id = redis_script_id, + result_to_vector = result_to_vector, + settings = settings, + spawn_train = spawn_train, +} diff --git a/lualib/plugins/rbl.lua b/lualib/plugins/rbl.lua new file mode 100644 index 0000000..af5d6bd --- /dev/null +++ b/lualib/plugins/rbl.lua @@ -0,0 +1,232 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +local ts = require("tableshape").types +local lua_maps = require "lua_maps" +local lua_util = require "lua_util" + +-- Common RBL plugin definitions + +local check_types = { + from = { + connfilter = true, + }, + received = {}, + helo = { + connfilter = true, + }, + urls = {}, + content_urls = {}, + numeric_urls = {}, + emails = {}, + replyto = {}, + dkim = {}, + rdns = { + connfilter = true, + }, + selector = { + require_argument = true, + }, +} + +local default_options = { + ['default_enabled'] = true, + ['default_ipv4'] = true, + ['default_ipv6'] = true, + ['default_unknown'] = false, + ['default_dkim_domainonly'] = true, + ['default_emails_domainonly'] = false, + ['default_exclude_users'] = false, + ['default_exclude_local'] = true, + ['default_no_ip'] = false, + ['default_dkim_match_from'] = false, + ['default_selector_flatten'] = true, +} + +local return_codes_schema = ts.map_of( + ts.string / string.upper, -- Symbol name + ( + ts.array_of(ts.string) + + (ts.string / function(s) + return { s } + end) -- List of IP patterns + ) +) +local return_bits_schema = ts.map_of( + ts.string / string.upper, -- Symbol name + ( + ts.array_of(ts.number + ts.string / tonumber) + + (ts.string / function(s) + return { tonumber(s) } + end) + + (ts.number / function(s) + return { s } + end) + ) +) + +local rule_schema_tbl = { + content_urls = ts.boolean:is_optional(), + disable_monitoring = ts.boolean:is_optional(), + disabled = ts.boolean:is_optional(), + dkim = ts.boolean:is_optional(), + dkim_domainonly = ts.boolean:is_optional(), + dkim_match_from = ts.boolean:is_optional(), + emails = ts.boolean:is_optional(), + emails_delimiter = ts.string:is_optional(), + emails_domainonly = ts.boolean:is_optional(), + enabled = ts.boolean:is_optional(), + exclude_local = ts.boolean:is_optional(), + exclude_users = ts.boolean:is_optional(), + from = ts.boolean:is_optional(), + hash = ts.one_of { "sha1", "sha256", "sha384", "sha512", "md5", "blake2" }:is_optional(), + hash_format = ts.one_of { "hex", "base32", "base64" }:is_optional(), + hash_len = (ts.integer + ts.string / tonumber):is_optional(), + helo = ts.boolean:is_optional(), + ignore_default = ts.boolean:is_optional(), -- alias + ignore_defaults = ts.boolean:is_optional(), + ignore_url_whitelist = ts.boolean:is_optional(), + ignore_whitelist = ts.boolean:is_optional(), + ignore_whitelists = ts.boolean:is_optional(), -- alias + images = ts.boolean:is_optional(), + ipv4 = ts.boolean:is_optional(), + ipv6 = ts.boolean:is_optional(), + is_whitelist = ts.boolean:is_optional(), + local_exclude_ip_map = ts.string:is_optional(), + monitored_address = ts.string:is_optional(), + no_ip = ts.boolean:is_optional(), + process_script = ts.string:is_optional(), + random_monitored = ts.boolean:is_optional(), + rbl = ts.string, + rdns = ts.boolean:is_optional(), + received = ts.boolean:is_optional(), + received_flags = ts.array_of(ts.string):is_optional(), + received_max_pos = ts.number:is_optional(), + received_min_pos = ts.number:is_optional(), + received_nflags = ts.array_of(ts.string):is_optional(), + replyto = ts.boolean:is_optional(), + requests_limit = (ts.integer + ts.string / tonumber):is_optional(), + require_symbols = ( + ts.array_of(ts.string) + (ts.string / function(s) + return { s } + end) + ):is_optional(), + resolve_ip = ts.boolean:is_optional(), + return_bits = return_bits_schema:is_optional(), + return_codes = return_codes_schema:is_optional(), + returnbits = return_bits_schema:is_optional(), + returncodes = return_codes_schema:is_optional(), + 
returncodes_matcher = ts.one_of { "equality", "glob", "luapattern", "radix", "regexp" }:is_optional(), + selector = ts.one_of { ts.string, ts.table }:is_optional(), + selector_flatten = ts.boolean:is_optional(), + symbol = ts.string:is_optional(), + symbols_prefixes = ts.map_of(ts.string, ts.string):is_optional(), + unknown = ts.boolean:is_optional(), + url_compose_map = lua_maps.map_schema:is_optional(), + url_full_hostname = ts.boolean:is_optional(), + url_whitelist = lua_maps.map_schema:is_optional(), + urls = ts.boolean:is_optional(), + whitelist = lua_maps.map_schema:is_optional(), + whitelist_exception = ( + ts.array_of(ts.string) + (ts.string / function(s) + return { s } + end) + ):is_optional(), + checks = ts.array_of(ts.one_of(lua_util.keys(check_types))):is_optional(), + exclude_checks = ts.array_of(ts.one_of(lua_util.keys(check_types))):is_optional(), +} + +local function convert_checks(rule, name) + local rspamd_logger = require "rspamd_logger" + if rule.checks then + local all_connfilter = true + local exclude_checks = lua_util.list_to_hash(rule.exclude_checks or {}) + for _, check in ipairs(rule.checks) do + if not exclude_checks[check] then + local check_type = check_types[check] + if check_type.require_argument then + if not rule[check] then + rspamd_logger.errx(rspamd_config, 'rbl rule %s has check %s which requires an argument', + name, check) + return nil + end + end + + rule[check] = check_type + + if not check_type.connfilter then + all_connfilter = false + end + + if not check_type then + rspamd_logger.errx(rspamd_config, 'rbl rule %s has invalid check type: %s', + name, check) + return nil + end + else + rspamd_logger.infox(rspamd_config, 'disable check %s in %s: excluded explicitly', + check, name) + end + end + rule.connfilter = all_connfilter + end + + -- Now check if we have any check enabled at all + local check_found = false + for k, _ in pairs(check_types) do + if type(rule[k]) ~= 'nil' then + check_found = true + break + end + end + + if not check_found then + -- Enable implicit `from` check to allow upgrade + rspamd_logger.warnx(rspamd_config, 'rbl rule %s has no check enabled, enable default `from` check', + name) + rule.from = true + end + + if rule.returncodes and not rule.returncodes_matcher then + for _, v in pairs(rule.returncodes) do + for _, e in ipairs(v) do + if e:find('[%%%[]') then + rspamd_logger.warn(rspamd_config, 'implicitly enabling luapattern returncodes_matcher for rule %s', name) + rule.returncodes_matcher = 'luapattern' + break + end + end + if rule.returncodes_matcher then + break + end + end + end + + return rule +end + + +-- Add default boolean flags to the schema +for def_k, _ in pairs(default_options) do + rule_schema_tbl[def_k:sub(#('default_') + 1)] = ts.boolean:is_optional() +end + +return { + check_types = check_types, + rule_schema = ts.shape(rule_schema_tbl), + default_options = default_options, + convert_checks = convert_checks, +} diff --git a/lualib/plugins_stats.lua b/lualib/plugins_stats.lua new file mode 100644 index 0000000..2497fb9 --- /dev/null +++ b/lualib/plugins_stats.lua @@ -0,0 +1,48 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local ansicolors = require "ansicolors" + +local function printf(fmt, ...) + print(string.format(fmt, ...)) +end + +local function highlight(str) + return ansicolors.white .. str .. ansicolors.reset +end + +local function print_plugins_table(tbl, what) + local mods = {} + for k, _ in pairs(tbl) do + table.insert(mods, k) + end + + printf("Modules %s: %s", highlight(what), table.concat(mods, ", ")) +end + +return function(args, _) + print_plugins_table(rspamd_plugins_state.enabled, "enabled") + print_plugins_table(rspamd_plugins_state.disabled_explicitly, + "disabled (explicitly)") + print_plugins_table(rspamd_plugins_state.disabled_unconfigured, + "disabled (unconfigured)") + print_plugins_table(rspamd_plugins_state.disabled_redis, + "disabled (no Redis)") + print_plugins_table(rspamd_plugins_state.disabled_experimental, + "disabled (experimental)") + print_plugins_table(rspamd_plugins_state.disabled_failed, + "disabled (failed)") +end
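As a rough illustration of how this module is meant to be driven, here is a hedged, hypothetical invocation inside the rspamd Lua environment; rspamd itself normally populates the rspamd_plugins_state global, and the values below are invented.

-- Hypothetical standalone invocation; rspamd normally fills rspamd_plugins_state itself
rspamd_plugins_state = {
  enabled = { dmarc = {}, rbl = {} },
  disabled_explicitly = {},
  disabled_unconfigured = { neural = {} },
  disabled_redis = {},
  disabled_experimental = {},
  disabled_failed = {},
}

local plugins_stats = require "plugins_stats"
plugins_stats({}, nil) -- prints one "Modules ..." line per state table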
\ No newline at end of file diff --git a/lualib/redis_scripts/bayes_cache_check.lua b/lualib/redis_scripts/bayes_cache_check.lua new file mode 100644 index 0000000..f1ffc2b --- /dev/null +++ b/lualib/redis_scripts/bayes_cache_check.lua @@ -0,0 +1,20 @@
+-- Lua script to perform cache checking for bayes classification
+-- This script accepts the following parameters:
+-- key1 - cache id
+-- key2 - configuration table in message pack
+
+local cache_id = KEYS[1]
+local conf = cmsgpack.unpack(KEYS[2])
+cache_id = string.sub(cache_id, 1, conf.cache_elt_len)
+
+-- Try each prefix that is in Redis
+for i = 0, conf.cache_max_keys do
+  local prefix = conf.cache_prefix .. string.rep("X", i)
+  local have = redis.call('HGET', prefix, cache_id)
+
+  if have then
+    return tonumber(have)
+  end
+end
+
+return nil diff --git a/lualib/redis_scripts/bayes_cache_learn.lua b/lualib/redis_scripts/bayes_cache_learn.lua new file mode 100644 index 0000000..8811f3c --- /dev/null +++ b/lualib/redis_scripts/bayes_cache_learn.lua @@ -0,0 +1,61 @@
+-- Lua script to perform cache learning for bayes classification
+-- This script accepts the following parameters:
+-- key1 - cache id
+-- key2 - is spam (1 or 0)
+-- key3 - configuration table in message pack
+
+local cache_id = KEYS[1]
+local is_spam = KEYS[2]
+local conf = cmsgpack.unpack(KEYS[3])
+cache_id = string.sub(cache_id, 1, conf.cache_elt_len)
+
+-- Try each prefix that is in Redis (as some other instance might have set it)
+for i = 0, conf.cache_max_keys do
+  local prefix = conf.cache_prefix .. string.rep("X", i)
+  local have = redis.call('HGET', prefix, cache_id)
+
+  if have then
+    -- Already in cache
+    return false
+  end
+end
+
+local added = false
+local lim = conf.cache_max_elt
+for i = 0, conf.cache_max_keys do
+  if not added then
+    local prefix = conf.cache_prefix .. string.rep("X", i)
+    local count = redis.call('HLEN', prefix)

+    if count < lim then
+      -- We can add it to this prefix
+      redis.call('HSET', prefix, cache_id, is_spam)
+      added = true
+    end
+  end
+end
+
+if not added then
+  -- Need to expire some keys
+  local expired = false
+  for i = 0, conf.cache_max_keys do
+    local prefix = conf.cache_prefix .. string.rep("X", i)
+    local exists = redis.call('EXISTS', prefix)
+
+    if exists then
+      if not expired then
+        redis.call('DEL', prefix)
+        redis.call('HSET', prefix, cache_id, is_spam)
+
+        -- Do not expire anything else
+        expired = true
+      elseif i > 0 then
+        -- Move key to a shorter prefix, so we will rotate them eventually from lower to upper
+        local new_prefix = conf.cache_prefix .. string.rep("X", i - 1)
+        redis.call('RENAME', prefix, new_prefix)
+      end
+    end
+  end
+end
+
+return true
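Both cache scripts above unpack the same msgpack-encoded configuration table. A hedged sketch of the fields they rely on; the field names come from the scripts themselves, while the concrete values are purely illustrative.

local ucl = require "ucl"

local cache_conf = {
  cache_prefix = 'learned_ids', -- base hash name; each generation appends one more 'X'
  cache_elt_len = 28,           -- how many leading bytes of the cache id are kept
  cache_max_keys = 5,           -- number of hash generations to probe and rotate
  cache_max_elt = 10000,        -- entries allowed per hash before rotation kicks in
}

-- packed form passed as the last KEYS element of both cache scripts
local packed_conf = ucl.to_format(cache_conf, 'msgpack')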
\ No newline at end of file diff --git a/lualib/redis_scripts/bayes_classify.lua b/lualib/redis_scripts/bayes_classify.lua new file mode 100644 index 0000000..e94f645 --- /dev/null +++ b/lualib/redis_scripts/bayes_classify.lua @@ -0,0 +1,37 @@ +-- Lua script to perform bayes classification +-- This script accepts the following parameters: +-- key1 - prefix for bayes tokens (e.g. for per-user classification) +-- key2 - set of tokens encoded in messagepack array of strings + +local prefix = KEYS[1] +local output_spam = {} +local output_ham = {} + +local learned_ham = tonumber(redis.call('HGET', prefix, 'learns_ham')) or 0 +local learned_spam = tonumber(redis.call('HGET', prefix, 'learns_spam')) or 0 + +-- Output is a set of pairs (token_index, token_count), tokens that are not +-- found are not filled. +-- This optimisation will save a lot of space for sparse tokens, and in Bayes that assumption is normally held + +if learned_ham > 0 and learned_spam > 0 then + local input_tokens = cmsgpack.unpack(KEYS[2]) + for i, token in ipairs(input_tokens) do + local token_data = redis.call('HMGET', token, 'H', 'S') + + if token_data then + local ham_count = token_data[1] + local spam_count = token_data[2] + + if ham_count then + table.insert(output_ham, { i, tonumber(ham_count) }) + end + + if spam_count then + table.insert(output_spam, { i, tonumber(spam_count) }) + end + end + end +end + +return { learned_ham, learned_spam, output_ham, output_spam }
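The classify script replies with the two learn counters plus sparse (index, count) pairs. A small caller-side sketch of expanding that reply; the helper names are made up.

-- Hypothetical helper: expand sparse (index, count) pairs into an index -> count map
local function sparse_to_map(pairs_list)
  local res = {}
  for _, p in ipairs(pairs_list) do
    res[p[1]] = p[2]
  end
  return res
end

-- reply = { learned_ham, learned_spam, output_ham, output_spam }
local function parse_classify_reply(reply)
  return {
    learns_ham = reply[1],
    learns_spam = reply[2],
    ham_counts = sparse_to_map(reply[3]),
    spam_counts = sparse_to_map(reply[4]),
  }
end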
\ No newline at end of file diff --git a/lualib/redis_scripts/bayes_learn.lua b/lualib/redis_scripts/bayes_learn.lua new file mode 100644 index 0000000..80d86d8 --- /dev/null +++ b/lualib/redis_scripts/bayes_learn.lua @@ -0,0 +1,44 @@
+-- Lua script to perform bayes learning
+-- This script accepts the following parameters:
+-- key1 - prefix for bayes tokens (e.g. for per-user classification)
+-- key2 - boolean is_spam
+-- key3 - string symbol
+-- key4 - boolean is_unlearn
+-- key5 - set of tokens encoded in messagepack array of strings
+-- key6 - set of text tokens (if any) encoded in messagepack array of strings (size must be twice the size of `KEYS[5]`)
+
+local prefix = KEYS[1]
+local is_spam = KEYS[2] == 'true' and true or false
+local symbol = KEYS[3]
+local is_unlearn = KEYS[4] == 'true' and true or false
+local input_tokens = cmsgpack.unpack(KEYS[5])
+local text_tokens
+
+if KEYS[6] then
+  text_tokens = cmsgpack.unpack(KEYS[6])
+end
+
+local hash_key = is_spam and 'S' or 'H'
+local learned_key = is_spam and 'learns_spam' or 'learns_ham'
+
+redis.call('SADD', symbol .. '_keys', prefix)
+redis.call('HSET', prefix, 'version', '2') -- new schema
+redis.call('HINCRBY', prefix, learned_key, is_unlearn and -1 or 1) -- increase or decrease learned count
+
+for i, token in ipairs(input_tokens) do
+  redis.call('HINCRBY', token, hash_key, 1)
+  if text_tokens then
+    local tok1 = text_tokens[i * 2 - 1]
+    local tok2 = text_tokens[i * 2]
+
+    if tok1 then
+      if tok2 then
+        redis.call('HSET', token, 'tokens', string.format('%s:%s', tok1, tok2))
+      else
+        redis.call('HSET', token, 'tokens', tok1)
+      end
+
+      redis.call('ZINCRBY', prefix .. '_z', is_unlearn and -1 or 1, token)
+    end
+  end
+end
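To make the key layout described in the header comment concrete, here is a hedged sketch of how the six KEYS might be assembled by a caller, using ucl's msgpack output as the neural plugin does elsewhere in this diff; the prefix, symbol and token values are invented.

local ucl = require "ucl"

local tokens = { '12345', '67890' }                   -- hashed token ids
local text_tokens = { 'cheap', 'meds', 'buy', 'now' } -- two words per token, so twice the size

local keys = {
  'BAYES_SPAM_user@example.com',         -- KEYS[1]: per-user statistics prefix
  'true',                                -- KEYS[2]: is_spam
  'BAYES_SPAM',                          -- KEYS[3]: symbol
  'false',                               -- KEYS[4]: is_unlearn
  ucl.to_format(tokens, 'msgpack'),      -- KEYS[5]: tokens
  ucl.to_format(text_tokens, 'msgpack'), -- KEYS[6]: optional text tokens
}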
\ No newline at end of file diff --git a/lualib/redis_scripts/bayes_stat.lua b/lualib/redis_scripts/bayes_stat.lua new file mode 100644 index 0000000..31e5128 --- /dev/null +++ b/lualib/redis_scripts/bayes_stat.lua @@ -0,0 +1,19 @@ +-- Lua script to perform bayes stats +-- This script accepts the following parameters: +-- key1 - current cursor +-- key2 - symbol to examine +-- key3 - learn key (e.g. learns_ham or learns_spam) +-- key4 - max users + +local cursor = tonumber(KEYS[1]) + +local ret = redis.call('SSCAN', KEYS[2] .. '_keys', cursor, 'COUNT', tonumber(KEYS[4])) + +local new_cursor = tonumber(ret[1]) +local nkeys = #ret[2] +local learns = 0 +for _, key in ipairs(ret[2]) do + learns = learns + (tonumber(redis.call('HGET', key, KEYS[3])) or 0) +end + +return { new_cursor, nkeys, learns }
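Because the stat script is built on SSCAN, a caller is expected to keep invoking it until the returned cursor wraps back to zero. A minimal sketch, where exec() is a stand-in for whatever actually runs the script:

-- exec(keys) is assumed to run bayes_stat.lua and return { new_cursor, nkeys, learns }
local function total_learns(exec, symbol, learn_key, max_users)
  local cursor, learns = 0, 0
  repeat
    local ret = exec({ tostring(cursor), symbol, learn_key, tostring(max_users) })
    cursor = ret[1]
    learns = learns + ret[3]
  until cursor == 0
  return learns
end

-- e.g. total_learns(exec, 'BAYES_SPAM', 'learns_spam', 1000)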
\ No newline at end of file diff --git a/lualib/redis_scripts/neural_maybe_invalidate.lua b/lualib/redis_scripts/neural_maybe_invalidate.lua new file mode 100644 index 0000000..517fa01 --- /dev/null +++ b/lualib/redis_scripts/neural_maybe_invalidate.lua @@ -0,0 +1,25 @@
+-- Lua script to invalidate ANNs by rank
+-- Uses the following keys
+-- key1 - prefix for keys
+-- key2 - number of elements to leave
+
+local card = redis.call('ZCARD', KEYS[1])
+local lim = tonumber(KEYS[2])
+if card > lim then
+  local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
+  if to_delete then
+    for _, k in ipairs(to_delete) do
+      local tb = cjson.decode(k)
+      if type(tb) == 'table' and type(tb.redis_key) == 'string' then
+        redis.call('DEL', tb.redis_key)
+        -- Also delete the training vectors
+        redis.call('DEL', tb.redis_key .. '_spam_set')
+        redis.call('DEL', tb.redis_key .. '_ham_set')
+      end
+    end
+  end
+  redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
+  return to_delete
+else
+  return {}
+end
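Each zset member decoded above is the JSON-compact ANN profile that the neural plugin stores via the save/unlock script; roughly the following shape, with illustrative values:

-- Roughly what cjson.decode(k) yields for one zset member (values are made up)
local tb = {
  redis_key = 'rn2_default_0a1b2c3d_3', -- ANN data hash; the _spam_set/_ham_set keys derive from it
  symbols = { 'R_DKIM_ALLOW', 'R_SPF_ALLOW' },
  digest = '0a1b2c3d',
  version = 3,
}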
\ No newline at end of file diff --git a/lualib/redis_scripts/neural_maybe_lock.lua b/lualib/redis_scripts/neural_maybe_lock.lua new file mode 100644 index 0000000..f705115 --- /dev/null +++ b/lualib/redis_scripts/neural_maybe_lock.lua @@ -0,0 +1,19 @@
+-- Lua script to lock ANN for learning
+-- Uses the following keys
+-- key1 - prefix for keys
+-- key2 - current time
+-- key3 - key expire
+-- key4 - hostname
+
+local locked = redis.call('HGET', KEYS[1], 'lock')
+local now = tonumber(KEYS[2])
+if locked then
+  locked = tonumber(locked)
+  local expire = tonumber(KEYS[3])
+  if now > locked and (now - locked) < expire then
+    return { tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown' }
+  end
+end
+redis.call('HSET', KEYS[1], 'lock', tostring(now))
+redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
+return 1
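The lock script returns 1 when the lock is taken and a { lock_time, hostname } pair when another host still holds it; a tiny caller-side sketch of telling the two apart:

local function on_lock_reply(data)
  if type(data) == 'table' then
    -- Still locked elsewhere: data = { lock_time, hostname }
    return false, string.format('locked since %s on %s', data[1], data[2])
  end
  return true -- the script returned 1: lock acquired
end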
\ No newline at end of file diff --git a/lualib/redis_scripts/neural_save_unlock.lua b/lualib/redis_scripts/neural_save_unlock.lua new file mode 100644 index 0000000..5af1ddc --- /dev/null +++ b/lualib/redis_scripts/neural_save_unlock.lua @@ -0,0 +1,24 @@ +-- Lua script to save and unlock ANN in redis +-- Uses the following keys +-- key1 - prefix for ANN +-- key2 - prefix for profile +-- key3 - compressed ANN +-- key4 - profile as JSON +-- key5 - expire in seconds +-- key6 - current time +-- key7 - old key +-- key8 - ROC Thresholds +-- key9 - optional PCA +local now = tonumber(KEYS[6]) +redis.call('ZADD', KEYS[2], now, KEYS[4]) +redis.call('HSET', KEYS[1], 'ann', KEYS[3]) +redis.call('DEL', KEYS[1] .. '_spam_set') +redis.call('DEL', KEYS[1] .. '_ham_set') +redis.call('HDEL', KEYS[1], 'lock') +redis.call('HDEL', KEYS[7], 'lock') +redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5])) +redis.call('HSET', KEYS[1], 'roc_thresholds', KEYS[8]) +if KEYS[9] then + redis.call('HSET', KEYS[1], 'pca', KEYS[9]) +end +return 1
\ No newline at end of file diff --git a/lualib/redis_scripts/neural_train_size.lua b/lualib/redis_scripts/neural_train_size.lua new file mode 100644 index 0000000..45ad6a9 --- /dev/null +++ b/lualib/redis_scripts/neural_train_size.lua @@ -0,0 +1,24 @@
+-- Lua script that checks if we can store a new training vector
+-- Uses the following keys:
+-- key1 - ann key
+-- returns {nspam, nham} (or a 'hostname:lock_time' string if the ANN is locked)
+
+local prefix = KEYS[1]
+local locked = redis.call('HGET', prefix, 'lock')
+if locked then
+  local host = redis.call('HGET', prefix, 'hostname') or 'unknown'
+  return string.format('%s:%s', host, locked)
+end
+local nspam = 0
+local nham = 0
+
+local ret = redis.call('SCARD', prefix .. '_spam_set')
+if ret then
+  nspam = tonumber(ret)
+end
+ret = redis.call('SCARD', prefix .. '_ham_set')
+if ret then
+  nham = tonumber(ret)
+end
+
+return { nspam, nham }
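The train-size script likewise replies either with a 'hostname:lock_time' string (locked) or with the { nspam, nham } pair; a minimal caller-side sketch:

local function parse_train_size(reply)
  if type(reply) == 'string' then
    return nil, 'ANN is locked by ' .. reply -- "hostname:lock_time"
  end
  return { nspam = reply[1], nham = reply[2] }
end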
\ No newline at end of file diff --git a/lualib/redis_scripts/ratelimit_check.lua b/lualib/redis_scripts/ratelimit_check.lua new file mode 100644 index 0000000..d39cdf1 --- /dev/null +++ b/lualib/redis_scripts/ratelimit_check.lua @@ -0,0 +1,85 @@
+-- This Lua script is a rate limiter for Redis using the token bucket algorithm.
+-- The script checks if a message should be rate-limited and updates the bucket status accordingly.
+-- Input keys:
+-- KEYS[1]: A prefix for the Redis keys, e.g., RL_<triplet>_<seconds>
+-- KEYS[2]: The current time in milliseconds
+-- KEYS[3]: The bucket leak rate (messages per millisecond)
+-- KEYS[4]: The maximum allowed burst
+-- KEYS[5]: The expiration time for a bucket
+-- KEYS[6]: The number of recipients for the message
+
+-- Redis keys used:
+-- l: Last hit (time in milliseconds)
+-- b: Current burst (number of tokens in the bucket)
+-- p: Pending messages (number of messages in processing)
+-- dr: Current dynamic rate multiplier (*10000)
+-- db: Current dynamic burst multiplier (*10000)
+
+-- Returns:
+-- An array containing:
+-- 1. 1 if the message should be rate-limited or 0 if not
+-- 2. The current burst value after processing the message
+-- 3. The dynamic rate multiplier
+-- 4. The dynamic burst multiplier
+-- 5. The number of tokens leaked during processing
+
+local last = redis.call('HGET', KEYS[1], 'l')
+local now = tonumber(KEYS[2])
+local nrcpt = tonumber(KEYS[6])
+local leak_rate = tonumber(KEYS[3])
+local max_burst = tonumber(KEYS[4])
+local prefix = KEYS[1]
+local dynr, dynb, leaked = 0, 0, 0
+if not last then
+  -- New bucket
+  redis.call('HMSET', prefix, 'l', tostring(now), 'b', '0', 'dr', '10000', 'db', '10000', 'p', tostring(nrcpt))
+  redis.call('EXPIRE', prefix, KEYS[5])
+  return { 0, '0', '1', '1', '0' }
+end
+last = tonumber(last)
+
+local burst, pending = unpack(redis.call('HMGET', prefix, 'b', 'p'))
+burst, pending = tonumber(burst or '0'), tonumber(pending or '0')
+-- Sanity to avoid races
+if burst < 0 then
+  burst = 0
+end
+if pending < 0 then
+  pending = 0
+end
+pending = pending + nrcpt -- this message
+-- Perform leak
+if burst + pending > 0 then
+  -- If any time has passed
+  if burst > 0 and last < now then
+    dynr = tonumber(redis.call('HGET', prefix, 'dr')) / 10000.0
+    if dynr == 0 then
+      dynr = 0.0001
+    end
+    leak_rate = leak_rate * dynr
+    leaked = ((now - last) * leak_rate)
+    if leaked > burst then
+      leaked = burst
+    end
+    burst = burst - leaked
+    redis.call('HINCRBYFLOAT', prefix, 'b', -(leaked))
+    redis.call('HSET', prefix, 'l', tostring(now))
+  end
+
+  dynb = tonumber(redis.call('HGET', prefix, 'db')) / 10000.0
+  if dynb == 0 then
+    dynb = 0.0001
+  end
+
+  burst = burst + pending
+  if burst > 0 and burst > max_burst * dynb then
+    return { 1, tostring(burst - pending), tostring(dynr), tostring(dynb), tostring(leaked) }
+  end
+  -- Increase pending if the message is allowed through
+  redis.call('HINCRBY', prefix, 'p', nrcpt)
+else
+  burst = 0
+  redis.call('HMSET', prefix, 'b', '0', 'p', tostring(nrcpt))
+end
+
+return { 0, tostring(burst), tostring(dynr), tostring(dynb), tostring(leaked) }
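A short worked example of the leak step above, in plain Lua without Redis and with arbitrary numbers: at one message per second and five seconds since the last hit, up to five tokens drain, capped at the bucket's current level.

local leak_rate = 1 / 1000 -- KEYS[3]: messages per millisecond (i.e. one message per second)
local last, now = 0, 5000  -- five seconds between hits
local burst = 3            -- tokens currently in the bucket

local leaked = (now - last) * leak_rate -- 5 tokens would leak
if leaked > burst then
  leaked = burst                        -- capped at the current burst, i.e. 3
end
burst = burst - leaked                  -- bucket drains to 0, so the message is allowed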
\ No newline at end of file diff --git a/lualib/redis_scripts/ratelimit_cleanup_pending.lua b/lualib/redis_scripts/ratelimit_cleanup_pending.lua new file mode 100644 index 0000000..698a3ec --- /dev/null +++ b/lualib/redis_scripts/ratelimit_cleanup_pending.lua @@ -0,0 +1,33 @@ +-- This script cleans up the pending requests in Redis. + +-- KEYS: Input parameters +-- KEYS[1] - prefix: The Redis key prefix used to store the bucket information. +-- KEYS[2] - now: The current time in milliseconds. +-- KEYS[3] - expire: The expiration time for the Redis key storing the bucket information, in seconds. +-- KEYS[4] - number_of_recipients: The number of requests to be allowed (or the increase rate). + +-- 1. Retrieve the last hit time and initialize variables +local prefix = KEYS[1] +local last = redis.call('HGET', prefix, 'l') +local nrcpt = tonumber(KEYS[4]) +if not last then + -- No bucket, no cleanup + return 0 +end + + +-- 2. Update the pending values based on the number of recipients (requests) +local pending = redis.call('HGET', prefix, 'p') +pending = tonumber(pending or '0') +if pending < nrcpt then + pending = 0 +else + pending = pending - nrcpt +end + +-- 3. Set the updated values back to Redis and update the expiration time for the bucket +redis.call('HMSET', prefix, 'p', tostring(pending), 'l', KEYS[2]) +redis.call('EXPIRE', prefix, KEYS[3]) + +-- 4. Return the updated pending value +return pending
\ No newline at end of file diff --git a/lualib/redis_scripts/ratelimit_update.lua b/lualib/redis_scripts/ratelimit_update.lua new file mode 100644 index 0000000..caee8fb --- /dev/null +++ b/lualib/redis_scripts/ratelimit_update.lua @@ -0,0 +1,93 @@ +-- This script updates a token bucket rate limiter with dynamic rate and burst multipliers in Redis. + +-- KEYS: Input parameters +-- KEYS[1] - prefix: The Redis key prefix used to store the bucket information. +-- KEYS[2] - now: The current time in milliseconds. +-- KEYS[3] - dynamic_rate_multiplier: A multiplier to adjust the rate limit dynamically. +-- KEYS[4] - dynamic_burst_multiplier: A multiplier to adjust the burst limit dynamically. +-- KEYS[5] - max_dyn_rate: The maximum allowed value for the dynamic rate multiplier. +-- KEYS[6] - max_burst_rate: The maximum allowed value for the dynamic burst multiplier. +-- KEYS[7] - expire: The expiration time for the Redis key storing the bucket information, in seconds. +-- KEYS[8] - number_of_recipients: The number of requests to be allowed (or the increase rate). + +-- 1. Retrieve the last hit time and initialize variables +local prefix = KEYS[1] +local last = redis.call('HGET', prefix, 'l') +local now = tonumber(KEYS[2]) +local nrcpt = tonumber(KEYS[8]) +if not last then + -- 2. Initialize a new bucket if the last hit time is not found (must not happen) + redis.call('HMSET', prefix, 'l', tostring(now), 'b', tostring(nrcpt), 'dr', '10000', 'db', '10000', 'p', '0') + redis.call('EXPIRE', prefix, KEYS[7]) + return { 1, 1, 1 } +end + +-- 3. Update the dynamic rate multiplier based on input parameters +local dr, db = 1.0, 1.0 + +local max_dr = tonumber(KEYS[5]) + +if max_dr > 1 then + local rate_mult = tonumber(KEYS[3]) + dr = tonumber(redis.call('HGET', prefix, 'dr')) / 10000 + + if rate_mult > 1.0 and dr < max_dr then + dr = dr * rate_mult + if dr > 0.0001 then + redis.call('HSET', prefix, 'dr', tostring(math.floor(dr * 10000))) + else + redis.call('HSET', prefix, 'dr', '1') + end + elseif rate_mult < 1.0 and dr > (1.0 / max_dr) then + dr = dr * rate_mult + if dr > 0.0001 then + redis.call('HSET', prefix, 'dr', tostring(math.floor(dr * 10000))) + else + redis.call('HSET', prefix, 'dr', '1') + end + end +end + +-- 4. Update the dynamic burst multiplier based on input parameters +local max_db = tonumber(KEYS[6]) +if max_db > 1 then + local rate_mult = tonumber(KEYS[4]) + db = tonumber(redis.call('HGET', prefix, 'db')) / 10000 + + if rate_mult > 1.0 and db < max_db then + db = db * rate_mult + if db > 0.0001 then + redis.call('HSET', prefix, 'db', tostring(math.floor(db * 10000))) + else + redis.call('HSET', prefix, 'db', '1') + end + elseif rate_mult < 1.0 and db > (1.0 / max_db) then + db = db * rate_mult + if db > 0.0001 then + redis.call('HSET', prefix, 'db', tostring(math.floor(db * 10000))) + else + redis.call('HSET', prefix, 'db', '1') + end + end +end + +-- 5. Update the burst and pending values based on the number of recipients (requests) +local burst, pending = unpack(redis.call('HMGET', prefix, 'b', 'p')) +burst, pending = tonumber(burst or '0'), tonumber(pending or '0') +if burst < 0 then + burst = nrcpt +else + burst = burst + nrcpt +end +if pending < nrcpt then + pending = 0 +else + pending = pending - nrcpt +end + +-- 6. Set the updated values back to Redis and update the expiration time for the bucket +redis.call('HMSET', prefix, 'b', tostring(burst), 'p', tostring(pending), 'l', KEYS[2]) +redis.call('EXPIRE', prefix, KEYS[7]) + +-- 7. 
Return the updated burst value, dynamic rate multiplier, and dynamic burst multiplier +return { tostring(burst), tostring(dr), tostring(db) }
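The dynamic multipliers above are stored as integers scaled by 10000, which keeps the hash fields plain numbers while still allowing fractional rates. Below is a small, hedged sketch of that fixed-point update in plain Lua; it mirrors steps 3 and 4, and the function name is illustrative.

  -- The multiplier grows while rate_mult > 1 (capped at max_mult), shrinks while
  -- rate_mult < 1 (floored at 1/max_mult), and never drops below 0.0001.
  local function update_multiplier(stored, rate_mult, max_mult)
    local cur = stored / 10000
    if (rate_mult > 1.0 and cur < max_mult)
        or (rate_mult < 1.0 and cur > 1.0 / max_mult) then
      cur = cur * rate_mult
    end
    if cur > 0.0001 then
      return math.floor(cur * 10000), cur
    end
    return 1, 0.0001
  end

  -- usage: start at 1.0 (stored as 10000), speed up by 1.5x with a 2x cap
  print(update_multiplier(10000, 1.5, 2)) -- 15000  1.5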
\ No newline at end of file diff --git a/lualib/rspamadm/clickhouse.lua b/lualib/rspamadm/clickhouse.lua new file mode 100644 index 0000000..b22d800 --- /dev/null +++ b/lualib/rspamadm/clickhouse.lua @@ -0,0 +1,528 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" +local lua_clickhouse = require "lua_clickhouse" +local lua_util = require "lua_util" +local rspamd_http = require "rspamd_http" +local rspamd_upstream_list = require "rspamd_upstream_list" +local rspamd_logger = require "rspamd_logger" +local ucl = require "ucl" + +local E = {} + +-- Define command line options +local parser = argparse() + :name 'rspamadm clickhouse' + :description 'Retrieve information from Clickhouse' + :help_description_margin(30) + :command_target('command') + :require_command(true) + +parser:option '-c --config' + :description 'Path to config file' + :argname('config_file') + :default(rspamd_paths['CONFDIR'] .. '/rspamd.conf') +parser:option '-d --database' + :description 'Name of Clickhouse database to use' + :argname('database') + :default('default') +parser:flag '--no-ssl-verify' + :description 'Disable SSL verification' + :argname('no_ssl_verify') +parser:mutex( + parser:option '-p --password' + :description 'Password to use for Clickhouse' + :argname('password'), + parser:flag '-a --ask-password' + :description 'Ask password from the terminal' + :argname('ask_password') +) +parser:option '-s --server' + :description 'Address[:port] to connect to Clickhouse with' + :argname('server') +parser:option '-u --user' + :description 'Username to use for Clickhouse' + :argname('user') +parser:option '--use-gzip' + :description 'Use Gzip with Clickhouse' + :argname('use_gzip') + :default(true) +parser:flag '--use-https' + :description 'Use HTTPS with Clickhouse' + :argname('use_https') + +local neural_profile = parser:command 'neural_profile' + :description 'Generate symbols profile using data from Clickhouse' +neural_profile:option '-w --where' + :description 'WHERE clause for Clickhouse query' + :argname('where') +neural_profile:flag '-j --json' + :description 'Write output as JSON' + :argname('json') +neural_profile:option '--days' + :description 'Number of days to collect stats for' + :argname('days') + :default('7') +neural_profile:option '--limit -l' + :description 'Maximum rows to fetch per day' + :argname('limit') +neural_profile:option '--settings-id' + :description 'Settings ID to query' + :argname('settings_id') + :default('') + +local neural_train = parser:command 'neural_train' + :description 'Train neural using data from Clickhouse' +neural_train:option '--days' + :description 'Number of days to query data for' + :argname('days') + :default('7') +neural_train:option '--column-name-digest' + :description 'Name of neural profile digest column in Clickhouse' + :argname('column_name_digest') + :default('NeuralDigest') +neural_train:option '--column-name-vector' + :description 'Name of neural training vector column in 
Clickhouse' + :argname('column_name_vector') + :default('NeuralMpack') +neural_train:option '--limit -l' + :description 'Maximum rows to fetch per day' + :argname('limit') +neural_train:option '--profile -p' + :description 'Profile to use for training' + :argname('profile') + :default('default') +neural_train:option '--rule -r' + :description 'Rule to train' + :argname('rule') + :default('default') +neural_train:option '--spam -s' + :description 'WHERE clause to use for spam' + :argname('spam') + :default("Action == 'reject'") +neural_train:option '--ham -h' + :description 'WHERE clause to use for ham' + :argname('ham') + :default('Score < 0') +neural_train:option '--url -u' + :description 'URL to use for training' + :argname('url') + :default('http://127.0.0.1:11334/plugins/neural/learn') + +local http_params = { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + resolver = rspamadm_dns_resolver, +} + +local function load_config(config_file) + local _r, err = rspamd_config:load_ucl(config_file) + + if not _r then + rspamd_logger.errx('cannot load %s: %s', config_file, err) + os.exit(1) + end + + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + rspamd_logger.errx('cannot process %s: %s', config_file, err) + os.exit(1) + end + + if not rspamd_config:init_modules() then + rspamd_logger.errx('cannot init modules when parsing %s', config_file) + os.exit(1) + end + + rspamd_config:init_subsystem('symcache') +end + +local function days_list(days) + -- Create list of days to query starting with yesterday + local query_days = {} + local previous_date = os.time() - 86400 + local num_days = tonumber(days) + for _ = 1, num_days do + table.insert(query_days, os.date('%Y-%m-%d', previous_date)) + previous_date = previous_date - 86400 + end + return query_days +end + +local function get_excluded_symbols(known_symbols, correlations, seen_total) + -- Walk results once to collect all symbols & count occurrences + + local remove = {} + local known_symbols_list = {} + local composites = rspamd_config:get_all_opt('composites') + local all_symbols = rspamd_config:get_symbols() + local skip_flags = { + nostat = true, + skip = true, + idempotent = true, + composite = true, + } + for k, v in pairs(known_symbols) do + local lower_count, higher_count + if v.seen_spam > v.seen_ham then + lower_count = v.seen_ham + higher_count = v.seen_spam + else + lower_count = v.seen_spam + higher_count = v.seen_ham + end + + if composites[k] then + remove[k] = 'composite symbol' + elseif lower_count / higher_count >= 0.95 then + remove[k] = 'weak ham/spam correlation' + elseif v.seen / seen_total >= 0.9 then + remove[k] = 'omnipresent symbol' + elseif not all_symbols[k] then + remove[k] = 'nonexistent symbol' + else + for fl, _ in pairs(all_symbols[k].flags or {}) do + if skip_flags[fl] then + remove[k] = fl .. 
' symbol' + break + end + end + end + known_symbols_list[v.id] = { + seen = v.seen, + name = k, + } + end + + -- Walk correlation matrix and check total counts + for sym_id, row in pairs(correlations) do + for inner_sym_id, count in pairs(row) do + local known = known_symbols_list[sym_id] + local inner = known_symbols_list[inner_sym_id] + if known and count == known.seen and not remove[inner.name] and not remove[known.name] then + remove[known.name] = string.format("overlapped by %s", + known_symbols_list[inner_sym_id].name) + end + end + end + + return remove +end + +local function handle_neural_profile(args) + + local known_symbols, correlations = {}, {} + local symbols_count, seen_total = 0, 0 + + local function process_row(r) + local is_spam = true + if r['Action'] == 'no action' or r['Action'] == 'greylist' then + is_spam = false + end + seen_total = seen_total + 1 + + local nsym = #r['Symbols.Names'] + + for i = 1, nsym do + local sym = r['Symbols.Names'][i] + local t = known_symbols[sym] + if not t then + local spam_count, ham_count = 0, 0 + if is_spam then + spam_count = spam_count + 1 + else + ham_count = ham_count + 1 + end + known_symbols[sym] = { + id = symbols_count, + seen = 1, + seen_ham = ham_count, + seen_spam = spam_count, + } + symbols_count = symbols_count + 1 + else + known_symbols[sym].seen = known_symbols[sym].seen + 1 + if is_spam then + known_symbols[sym].seen_spam = known_symbols[sym].seen_spam + 1 + else + known_symbols[sym].seen_ham = known_symbols[sym].seen_ham + 1 + end + end + end + + -- Fill correlations + for i = 1, nsym do + for j = 1, nsym do + if i ~= j then + local sym = r['Symbols.Names'][i] + local inner_sym_name = r['Symbols.Names'][j] + local known_sym = known_symbols[sym] + local inner_sym = known_symbols[inner_sym_name] + if known_sym and inner_sym then + if not correlations[known_sym.id] then + correlations[known_sym.id] = {} + end + local n = correlations[known_sym.id][inner_sym.id] or 0 + n = n + 1 + correlations[known_sym.id][inner_sym.id] = n + end + end + end + end + end + + local query_days = days_list(args.days) + local conditions = {} + table.insert(conditions, string.format("SettingsId = '%s'", args.settings_id)) + local limit = '' + local num_limit = tonumber(args.limit) + if num_limit then + limit = string.format(' LIMIT %d', num_limit) -- Contains leading space + end + if args.where then + table.insert(conditions, args.where) + end + + local query_fmt = 'SELECT Action, Symbols.Names FROM rspamd WHERE %s%s' + for _, query_day in ipairs(query_days) do + -- Date should be the last condition + table.insert(conditions, string.format("Date = '%s'", query_day)) + local query = string.format(query_fmt, table.concat(conditions, ' AND '), limit) + local upstream = args.upstream:get_upstream_round_robin() + local err = lua_clickhouse.select_sync(upstream, args, http_params, query, process_row) + if err ~= nil then + io.stderr:write(string.format('Error querying Clickhouse: %s\n', err)) + os.exit(1) + end + conditions[#conditions] = nil -- remove Date condition + end + + local remove = get_excluded_symbols(known_symbols, correlations, seen_total) + if not args.json then + for k in pairs(known_symbols) do + if not remove[k] then + io.stdout:write(string.format('%s\n', k)) + end + end + os.exit(0) + end + + local json_output = { + all_symbols = {}, + removed_symbols = {}, + used_symbols = {}, + } + for k in pairs(known_symbols) do + table.insert(json_output.all_symbols, k) + local why_removed = remove[k] + if why_removed then + 
json_output.removed_symbols[k] = why_removed + else + table.insert(json_output.used_symbols, k) + end + end + io.stdout:write(ucl.to_format(json_output, 'json')) +end + +local function post_neural_training(url, rule, spam_rows, ham_rows) + -- Prepare JSON payload + local payload = ucl.to_format( + { + ham_vec = ham_rows, + rule = rule, + spam_vec = spam_rows, + }, 'json') + + -- POST the payload + local err, response = rspamd_http.request({ + body = payload, + config = rspamd_config, + ev_base = rspamadm_ev_base, + log_obj = rspamd_config, + resolver = rspamadm_dns_resolver, + session = rspamadm_session, + url = url, + }) + + if err then + io.stderr:write(string.format('HTTP error: %s\n', err)) + os.exit(1) + end + if response.code ~= 200 then + io.stderr:write(string.format('bad HTTP code: %d\n', response.code)) + os.exit(1) + end + io.stdout:write(string.format('%s\n', response.content)) +end + +local function handle_neural_train(args) + + local this_where -- which class of messages are we collecting data for + local ham_rows, spam_rows = {}, {} + local want_spam, want_ham = true, true -- keep collecting while true + + -- Try find profile in config + local neural_opts = rspamd_config:get_all_opt('neural') + local symbols_profile = ((((neural_opts or E).rules or E)[args.rule] or E).profile or E)[args.profile] + if not symbols_profile then + io.stderr:write(string.format("Couldn't find profile %s in rule %s\n", args.profile, args.rule)) + os.exit(1) + end + -- Try find max_trains + local max_trains = (neural_opts.rules[args.rule].train or E).max_trains or 1000 + + -- Callback used to process rows from Clickhouse + local function process_row(r) + local destination -- which table to collect this information in + if this_where == args.ham then + destination = ham_rows + if #destination >= max_trains then + want_ham = false + return + end + else + destination = spam_rows + if #destination >= max_trains then + want_spam = false + return + end + end + local ucl_parser = ucl.parser() + local ok, err = ucl_parser:parse_string(r[args.column_name_vector], 'msgpack') + if not ok then + io.stderr:write(string.format("Couldn't parse [%s]: %s", r[args.column_name_vector], err)) + os.exit(1) + end + table.insert(destination, ucl_parser:get_object()) + end + + -- Generate symbols digest + table.sort(symbols_profile) + local symbols_digest = lua_util.table_digest(symbols_profile) + -- Create list of days to query data for + local query_days = days_list(args.days) + -- Set value for limit + local limit = '' + local num_limit = tonumber(args.limit) + if num_limit then + limit = string.format(' LIMIT %d', num_limit) -- Contains leading space + end + -- Prepare query elements + local conditions = { string.format("%s = '%s'", args.column_name_digest, symbols_digest) } + local query_fmt = 'SELECT %s FROM rspamd WHERE %s%s' + + -- Run queries + for _, the_where in ipairs({ args.ham, args.spam }) do + -- Inform callback which group of vectors we're collecting + this_where = the_where + table.insert(conditions, the_where) -- should be 2nd from last condition + -- Loop over days and try collect data + for _, query_day in ipairs(query_days) do + -- Break the loop if we have enough data already + if this_where == args.ham then + if not want_ham then + break + end + else + if not want_spam then + break + end + end + -- Date should be the last condition + table.insert(conditions, string.format("Date = '%s'", query_day)) + local query = string.format(query_fmt, args.column_name_vector, table.concat(conditions, ' AND '), 
limit) + local upstream = args.upstream:get_upstream_round_robin() + local err = lua_clickhouse.select_sync(upstream, args, http_params, query, process_row) + if err ~= nil then + io.stderr:write(string.format('Error querying Clickhouse: %s\n', err)) + os.exit(1) + end + conditions[#conditions] = nil -- remove Date condition + end + conditions[#conditions] = nil -- remove spam/ham condition + end + + -- Make sure we collected enough data for training + if #ham_rows < max_trains then + io.stderr:write(string.format('Insufficient ham rows: %d/%d\n', #ham_rows, max_trains)) + os.exit(1) + end + if #spam_rows < max_trains then + io.stderr:write(string.format('Insufficient spam rows: %d/%d\n', #spam_rows, max_trains)) + os.exit(1) + end + + return post_neural_training(args.url, args.rule, spam_rows, ham_rows) +end + +local command_handlers = { + neural_profile = handle_neural_profile, + neural_train = handle_neural_train, +} + +local function handler(args) + local cmd_opts = parser:parse(args) + + load_config(cmd_opts.config_file) + local cfg_opts = rspamd_config:get_all_opt('clickhouse') + + if cmd_opts.ask_password then + local rspamd_util = require "rspamd_util" + + io.write('Password: ') + cmd_opts.password = rspamd_util.readpassphrase() + end + + local function override_settings(params) + for _, which in ipairs(params) do + if cmd_opts[which] == nil then + cmd_opts[which] = cfg_opts[which] + end + end + end + + override_settings({ + 'database', 'no_ssl_verify', 'password', 'server', + 'use_gzip', 'use_https', 'user', + }) + + local servers = cmd_opts['server'] or cmd_opts['servers'] + if not servers then + parser:error("server(s) unspecified & couldn't be fetched from config") + end + + cmd_opts.upstream = rspamd_upstream_list.create(rspamd_config, servers, 8123) + + if not cmd_opts.upstream then + io.stderr:write(string.format("can't parse clickhouse address: %s\n", servers)) + os.exit(1) + end + + local f = command_handlers[cmd_opts.command] + if not f then + parser:error(string.format("command isn't implemented: %s", + cmd_opts.command)) + end + f(cmd_opts) +end + +return { + handler = handler, + description = parser._description, + name = 'clickhouse' +} diff --git a/lualib/rspamadm/configgraph.lua b/lualib/rspamadm/configgraph.lua new file mode 100644 index 0000000..07f14a9 --- /dev/null +++ b/lualib/rspamadm/configgraph.lua @@ -0,0 +1,172 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local rspamd_logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local rspamd_regexp = require "rspamd_regexp" +local argparse = require "argparse" + +-- Define command line options +local parser = argparse() + :name "rspamadm configgraph" + :description "Produces graph of Rspamd includes" + :help_description_margin(30) +parser:option "-c --config" + :description "Path to config file" + :argname("<file>") + :default(rspamd_paths["CONFDIR"] .. "/" .. 
"rspamd.conf") +parser:flag "-a --all" + :description('Show all nodes, not just existing ones') + +local function process_filename(fname) + local cdir = rspamd_paths['CONFDIR'] .. '/' + fname = fname:gsub(cdir, '') + return fname +end + +local function output_dot(opts, nodes, adjacency) + rspamd_logger.messagex("digraph rspamd {") + for k, node in pairs(nodes) do + local attrs = { "shape=box" } + local skip = false + if node.exists then + if node.priority >= 10 then + attrs[#attrs + 1] = "color=red" + elseif node.priority > 0 then + attrs[#attrs + 1] = "color=blue" + end + else + if opts.all then + attrs[#attrs + 1] = "style=dotted" + else + skip = true + end + end + + if not skip then + rspamd_logger.messagex("\"%s\" [%s];", process_filename(k), + table.concat(attrs, ',')) + end + end + for _, adj in ipairs(adjacency) do + local attrs = {} + local skip = false + + if adj.to.exists then + if adj.to.merge then + attrs[#attrs + 1] = "arrowhead=diamond" + attrs[#attrs + 1] = "label=\"+\"" + elseif adj.to.priority > 1 then + attrs[#attrs + 1] = "color=red" + end + else + if opts.all then + attrs[#attrs + 1] = "style=dotted" + else + skip = true + end + end + + if not skip then + rspamd_logger.messagex("\"%s\" -> \"%s\" [%s];", process_filename(adj.from), + adj.to.short_path, table.concat(attrs, ',')) + end + end + rspamd_logger.messagex("}") +end + +local function load_config_traced(opts) + local glob_traces = {} + local adjacency = {} + local nodes = {} + + local function maybe_match_glob(file) + for _, gl in ipairs(glob_traces) do + if gl.re:match(file) then + return gl + end + end + + return nil + end + + local function add_dep(from, node, args) + adjacency[#adjacency + 1] = { + from = from, + to = node, + args = args + } + end + + local function process_node(fname, args) + local node = nodes[fname] + if not node then + node = { + path = fname, + short_path = process_filename(fname), + exists = rspamd_util.file_exists(fname), + merge = args.duplicate and args.duplicate == 'merge', + priority = args.priority or 0, + glob = args.glob, + try = args.try, + } + nodes[fname] = node + end + + return node + end + + local function trace_func(cur_file, included_file, args, parent) + if args.glob then + glob_traces[#glob_traces + 1] = { + re = rspamd_regexp.import_glob(included_file, ''), + parent = cur_file, + args = args, + seen = {}, + } + else + local node = process_node(included_file, args) + if opts.all or node.exists then + local gl_parent = maybe_match_glob(included_file) + if gl_parent and not gl_parent.seen[cur_file] then + add_dep(gl_parent.parent, nodes[cur_file], gl_parent.args) + gl_parent.seen[cur_file] = true + end + add_dep(cur_file, node, args) + end + end + end + + local _r, err = rspamd_config:load_ucl(opts['config'], trace_func) + if not _r then + rspamd_logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + output_dot(opts, nodes, adjacency) +end + +local function handler(args) + local res = parser:parse(args) + + load_config_traced(res) +end + +return { + handler = handler, + description = parser._description, + name = 'configgraph' +}
\ No newline at end of file diff --git a/lualib/rspamadm/confighelp.lua b/lualib/rspamadm/confighelp.lua new file mode 100644 index 0000000..38b26b6 --- /dev/null +++ b/lualib/rspamadm/confighelp.lua @@ -0,0 +1,123 @@ +local opts +local known_attrs = { + data = 1, + example = 1, + type = 1, + required = 1, + default = 1, +} +local argparse = require "argparse" +local ansicolors = require "ansicolors" + +local parser = argparse() + :name "rspamadm confighelp" + :description "Shows help for the specified configuration options" + :help_description_margin(32) +parser:argument "path":args "*" + :description('Optional config paths') +parser:flag "--no-color" + :description "Disable coloured output" +parser:flag "--short" + :description "Show only option names" +parser:flag "--no-examples" + :description "Do not show examples (implied by --short)" + +local function maybe_print_color(key) + if not opts['no-color'] then + return ansicolors.white .. key .. ansicolors.reset + else + return key + end +end + +local function sort_values(tbl) + local res = {} + for k, v in pairs(tbl) do + table.insert(res, { key = k, value = v }) + end + + -- Sort order + local order = { + options = 1, + dns = 2, + upstream = 3, + logging = 4, + metric = 5, + composite = 6, + classifier = 7, + modules = 8, + lua = 9, + worker = 10, + workers = 11, + } + + table.sort(res, function(a, b) + local oa = order[a['key']] + local ob = order[b['key']] + + if oa and ob then + return oa < ob + elseif oa then + return -1 < 0 + elseif ob then + return 1 < 0 + else + return a['key'] < b['key'] + end + + end) + + return res +end + +local function print_help(key, value, tabs) + print(string.format('%sConfiguration element: %s', tabs, maybe_print_color(key))) + + if not opts['short'] then + if value['data'] then + local nv = string.match(value['data'], '^#%s*(.*)%s*$') or value.data + print(string.format('%s\tDescription: %s', tabs, nv)) + end + if type(value['type']) == 'string' then + print(string.format('%s\tType: %s', tabs, value['type'])) + end + if type(value['required']) == 'boolean' then + if value['required'] then + print(string.format('%s\tRequired: %s', tabs, + maybe_print_color(tostring(value['required'])))) + else + print(string.format('%s\tRequired: %s', tabs, + tostring(value['required']))) + end + end + if value['default'] then + print(string.format('%s\tDefault: %s', tabs, value['default'])) + end + if not opts['no-examples'] and value['example'] then + local nv = string.match(value['example'], '^%s*(.*[^%s])%s*$') or value.example + print(string.format('%s\tExample:\n%s', tabs, nv)) + end + if value.type and value.type == 'object' then + print('') + end + end + + local sorted = sort_values(value) + for _, v in ipairs(sorted) do + if not known_attrs[v['key']] then + -- We need to go deeper + print_help(v['key'], v['value'], tabs .. '\t') + end + end +end + +return function(args, res) + opts = parser:parse(args) + + local sorted = sort_values(res) + + for _, v in ipairs(sorted) do + print_help(v['key'], v['value'], '') + print('') + end +end diff --git a/lualib/rspamadm/configwizard.lua b/lualib/rspamadm/configwizard.lua new file mode 100644 index 0000000..2637036 --- /dev/null +++ b/lualib/rspamadm/configwizard.lua @@ -0,0 +1,849 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local ansicolors = require "ansicolors" +local local_conf = rspamd_paths['CONFDIR'] +local rspamd_util = require "rspamd_util" +local rspamd_logger = require "rspamd_logger" +local lua_util = require "lua_util" +local lua_stat_tools = require "lua_stat" +local lua_redis = require "lua_redis" +local ucl = require "ucl" +local argparse = require "argparse" +local fun = require "fun" + +local plugins_stat = require "plugins_stats" + +local rspamd_logo = [[ + ____ _ + | _ \ ___ _ __ __ _ _ __ ___ __| | + | |_) |/ __|| '_ \ / _` || '_ ` _ \ / _` | + | _ < \__ \| |_) || (_| || | | | | || (_| | + |_| \_\|___/| .__/ \__,_||_| |_| |_| \__,_| + |_| +]] + +local parser = argparse() + :name "rspamadm configwizard" + :description "Perform guided configuration for Rspamd daemon" + :help_description_margin(32) +parser:option "-c --config" + :description "Path to config file" + :argname("<file>") + :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf") +parser:argument "checks" + :description "Checks to do (or 'list')" + :argname("<checks>") + :args "*" + +local redis_params + +local function printf(fmt, ...) + if fmt then + io.write(string.format(fmt, ...)) + end + io.write('\n') +end + +local function highlight(str) + return ansicolors.white .. str .. ansicolors.reset +end + +local function ask_yes_no(greet, default) + local def_str + if default then + greet = greet .. "[Y/n]: " + def_str = "yes" + else + greet = greet .. "[y/N]: " + def_str = "no" + end + + local reply = rspamd_util.readline(greet) + + if not reply then + os.exit(0) + end + if #reply == 0 then + reply = def_str + end + reply = reply:lower() + if reply == 'y' or reply == 'yes' then + return true + end + + return false +end + +local function readline_default(greet, def_value) + local reply = rspamd_util.readline(greet) + if not reply then + os.exit(0) + end + + if #reply == 0 then + return def_value + end + + return reply +end + +local function readline_expire() + local expire = '100d' + repeat + expire = readline_default("Expire time for new tokens [" .. expire .. "]: ", + expire) + expire = lua_util.parse_time_interval(expire) + + if not expire then + expire = '100d' + elseif expire > 2147483647 then + printf("The maximum possible value is 2147483647 (about 68y)") + expire = '68y' + elseif expire < -1 then + printf("The value must be a non-negative integer or -1") + expire = -1 + elseif expire ~= math.floor(expire) then + printf("The value must be an integer") + expire = math.floor(expire) + else + return expire + end + until false +end + +local function print_changes(changes) + local function print_change(k, c, where) + printf('File: %s, changes list:', highlight(local_conf .. '/' + .. where .. '/' .. 
k)) + + for ek, ev in pairs(c) do + printf("%s => %s", highlight(ek), rspamd_logger.slog("%s", ev)) + end + end + for k, v in pairs(changes.l) do + print_change(k, v, 'local.d') + if changes.o[k] then + v = changes.o[k] + print_change(k, v, 'override.d') + end + print() + end +end + +local function apply_changes(changes) + local function dirname(fname) + if fname:match(".-/.-") then + return string.gsub(fname, "(.*/)(.*)", "%1") + else + return nil + end + end + + local function apply_change(k, c, where) + local fname = local_conf .. '/' .. where .. '/' .. k + + if not rspamd_util.file_exists(fname) then + printf("Create file %s", highlight(fname)) + + local dname = dirname(fname) + + if dname then + local ret, err = rspamd_util.mkdir(dname, true) + + if not ret then + printf("Cannot make directory %s: %s", dname, highlight(err)) + os.exit(1) + end + end + end + + local f = io.open(fname, "a+") + + if not f then + printf("Cannot open file %s, aborting", highlight(fname)) + os.exit(1) + end + + f:write(ucl.to_config(c)) + + f:close() + end + for k, v in pairs(changes.l) do + apply_change(k, v, 'local.d') + if changes.o[k] then + v = changes.o[k] + apply_change(k, v, 'override.d') + end + end +end + +local function setup_controller(controller, changes) + printf("Setup %s and controller worker:", highlight("WebUI")) + + if not controller.password or controller.password == 'q1' then + if ask_yes_no("Controller password is not set, do you want to set one?", true) then + local pw_encrypted = rspamadm.pw_encrypt() + if pw_encrypted then + printf("Set encrypted password to: %s", highlight(pw_encrypted)) + changes.l['worker-controller.inc'] = { + password = pw_encrypted + } + end + end + end +end + +local function setup_redis(cfg, changes) + local function parse_servers(servers) + local ls = lua_util.rspamd_str_split(servers, ",") + + return ls + end + + printf("%s servers are not set:", highlight("Redis")) + printf("The following modules will be enabled if you add Redis servers:") + + for k, _ in pairs(rspamd_plugins_state.disabled_redis) do + printf("\t* %s", highlight(k)) + end + + if ask_yes_no("Do you wish to set Redis servers?", true) then + local read_servers = readline_default("Input read only servers separated by `,` [default: localhost]: ", + "localhost") + + local rs = parse_servers(read_servers) + if rs and #rs > 0 then + changes.l['redis.conf'] = { + read_servers = table.concat(rs, ",") + } + end + local write_servers = readline_default("Input write only servers separated by `,` [default: " + .. read_servers .. 
"]: ", read_servers) + + if not write_servers or #write_servers == 0 then + printf("Use read servers %s as write servers", highlight(table.concat(rs, ","))) + write_servers = read_servers + end + + redis_params = { + read_servers = rs, + } + + local ws = parse_servers(write_servers) + if ws and #ws > 0 then + changes.l['redis.conf']['write_servers'] = table.concat(ws, ",") + redis_params['write_servers'] = ws + end + + if ask_yes_no('Do you have any username set for your Redis (ACL SETUSER and Redis 6.0+)') then + local username = readline_default("Enter Redis username:", nil) + + if username then + changes.l['redis.conf'].username = username + redis_params.username = username + end + + local passwd = readline_default("Enter Redis password:", nil) + + if passwd then + changes.l['redis.conf']['password'] = passwd + redis_params['password'] = passwd + end + elseif ask_yes_no('Do you have any password set for your Redis?') then + local passwd = readline_default("Enter Redis password:", nil) + + if passwd then + changes.l['redis.conf']['password'] = passwd + redis_params['password'] = passwd + end + end + + if ask_yes_no('Do you have any specific database for your Redis?') then + local db = readline_default("Enter Redis database:", nil) + + if db then + changes.l['redis.conf']['db'] = db + redis_params['db'] = db + end + end + end +end + +local function setup_dkim_signing(cfg, changes) + -- Remove the trailing slash of a pathname, if present. + local function remove_trailing_slash(path) + if string.sub(path, -1) ~= "/" then + return path + end + return string.sub(path, 1, string.len(path) - 1) + end + + printf('How would you like to set up DKIM signing?') + printf('1. Use domain from %s for sign', highlight('mime from header')) + printf('2. Use domain from %s for sign', highlight('SMTP envelope from')) + printf('3. Use domain from %s for sign', highlight('authenticated user')) + printf('4. Sign all mail from %s', highlight('specific networks')) + printf() + + local sign_type = readline_default('Enter your choice (1, 2, 3, 4) [default: 1]: ', '1') + local sign_networks + local allow_mismatch + local sign_authenticated + local use_esld + local sign_domain = 'pet luacheck' + + local defined_auth_types = { 'header', 'envelope', 'auth', 'recipient' } + + if sign_type == '4' then + repeat + sign_networks = readline_default('Enter list of networks to perform dkim signing: ', + '') + until #sign_networks ~= 0 + + sign_networks = fun.totable(fun.map(lua_util.rspamd_str_trim, + lua_util.str_split(sign_networks, ',; '))) + printf('What domain would you like to use for signing?') + printf('* %s to use mime from domain', highlight('header')) + printf('* %s to use SMTP from domain', highlight('envelope')) + printf('* %s to use domain from SMTP auth', highlight('auth')) + printf('* %s to use domain from SMTP recipient', highlight('recipient')) + printf('* anything else to use as a %s domain (e.g. `example.com`)', highlight('static')) + printf() + + sign_domain = readline_default('Enter your choice [default: header]: ', 'header') + else + if sign_type == '1' then + sign_domain = 'header' + elseif sign_type == '2' then + sign_domain = 'envelope' + else + sign_domain = 'auth' + end + end + + if sign_type ~= '3' then + sign_authenticated = ask_yes_no( + string.format('Do you want to sign mail from %s? 
', + highlight('authenticated users')), true) + else + sign_authenticated = true + end + + if fun.any(function(s) + return s == sign_domain + end, defined_auth_types) then + -- Allow mismatch + allow_mismatch = ask_yes_no( + string.format('Allow data %s, e.g. if mime from domain is not equal to authenticated user domain? ', + highlight('mismatch')), true) + -- ESLD check + use_esld = ask_yes_no( + string.format('Do you want to use %s domain (e.g. example.com instead of foo.example.com)? ', + highlight('effective')), true) + else + allow_mismatch = true + end + + local domains = {} + local has_domains = false + + local dkim_keys_dir = rspamd_paths["DBDIR"] .. "/dkim/" + + local prompt = string.format("Enter output directory for the keys [default: %s]: ", + highlight(dkim_keys_dir)) + dkim_keys_dir = remove_trailing_slash(readline_default(prompt, dkim_keys_dir)) + + local ret, err = rspamd_util.mkdir(dkim_keys_dir, true) + + if not ret then + printf("Cannot make directory %s: %s", dkim_keys_dir, highlight(err)) + os.exit(1) + end + + local function print_domains() + printf("Domains configured:") + for k, v in pairs(domains) do + printf("Domain: %s, selector: %s, privkey: %s", highlight(k), + v.selector, v.privkey) + end + printf("--") + end + local function print_public_key(pk) + local base64_pk = tostring(rspamd_util.encode_base64(pk)) + printf('v=DKIM1; k=rsa; p=%s\n', base64_pk) + end + repeat + if has_domains then + print_domains() + end + + local domain + repeat + domain = rspamd_util.readline("Enter domain to sign: ") + if not domain then + os.exit(1) + end + until #domain ~= 0 + + local selector = readline_default("Enter selector [default: dkim]: ", 'dkim') + if not selector then + selector = 'dkim' + end + + local privkey_file = string.format("%s/%s.%s.key", dkim_keys_dir, domain, + selector) + if not rspamd_util.file_exists(privkey_file) then + if ask_yes_no("Do you want to create privkey " .. 
highlight(privkey_file), + true) then + local rsa = require "rspamd_rsa" + local sk, pk = rsa.keypair(2048) + sk:save(privkey_file, 'pem') + print("You need to chown private key file to rspamd user!!") + print("To make dkim signing working, to place the following record in your DNS zone:") + print_public_key(tostring(pk)) + end + end + + domains[domain] = { + selector = selector, + path = privkey_file, + } + until not ask_yes_no("Do you wish to add another DKIM domain?") + + changes.l['dkim_signing.conf'] = { domain = domains } + local res_tbl = changes.l['dkim_signing.conf'] + + if sign_networks then + res_tbl.sign_networks = sign_networks + res_tbl.use_domain_sign_networks = sign_domain + else + res_tbl.use_domain = sign_domain + end + + if allow_mismatch then + res_tbl.allow_hdrfrom_mismatch = true + res_tbl.allow_hdrfrom_mismatch_sign_networks = true + res_tbl.allow_username_mismatch = true + end + + res_tbl.use_esld = use_esld + res_tbl.sign_authenticated = sign_authenticated +end + +local function check_redis_classifier(cls, changes) + local symbol_spam, symbol_ham + -- Load symbols from statfiles + local statfiles = cls.statfile + for _, stf in ipairs(statfiles) do + local symbol = stf.symbol or 'undefined' + + local spam + if stf.spam then + spam = stf.spam + else + if string.match(symbol:upper(), 'SPAM') then + spam = true + else + spam = false + end + end + + if spam then + symbol_spam = symbol + else + symbol_ham = symbol + end + end + + if not symbol_spam or not symbol_ham then + printf("Classifier has no symbols defined") + return + end + + local parsed_redis = lua_redis.try_load_redis_servers(cls, nil) + + if not parsed_redis and redis_params then + parsed_redis = lua_redis.try_load_redis_servers(redis_params, nil) + if not parsed_redis then + printf("Cannot parse Redis params") + return + end + end + + local function try_convert(update_config) + if ask_yes_no("Do you wish to convert data to the new schema?", true) then + local expire = readline_expire() + if not lua_stat_tools.convert_bayes_schema(parsed_redis, symbol_spam, + symbol_ham, expire) then + printf("Conversion failed") + else + printf("Conversion succeed") + if update_config then + changes.l['classifier-bayes.conf'] = { + new_schema = true, + } + + if expire then + changes.l['classifier-bayes.conf'].expire = expire + end + end + end + end + end + + local function get_version(conn) + conn:add_cmd("SMEMBERS", { "RS_keys" }) + + local ret, members = conn:exec() + + -- Empty db + if not ret or #members == 0 then + return false, 0 + end + + -- We still need to check versions + local lua_script = [[ +local ver = 0 + +local tst = redis.call('GET', KEYS[1]..'_version') +if tst then + ver = tonumber(tst) or 0 +end + +return ver +]] + conn:add_cmd('EVAL', { lua_script, '1', 'RS' }) + local _, ver = conn:exec() + + return true, tonumber(ver) + end + + local function check_expire(conn) + -- We still need to check versions + local lua_script = [[ +local ttl = 0 + +local sc = redis.call('SCAN', 0, 'MATCH', 'RS*_*', 'COUNT', 1) +local _,key = sc[1], sc[2] + +if key and key[1] then + ttl = redis.call('TTL', key[1]) +end + +return ttl +]] + conn:add_cmd('EVAL', { lua_script, '0' }) + local _, ttl = conn:exec() + + return tonumber(ttl) + end + + local res, conn = lua_redis.redis_connect_sync(parsed_redis, true) + if not res then + printf("Cannot connect to Redis server") + return false + end + + if not cls.new_schema then + local r, ver = get_version(conn) + if not r then + return false + end + if ver ~= 2 then + if not ver then + 
printf('Key "RS_version" has not been found in Redis for %s/%s', + symbol_ham, symbol_spam) + else + printf("You are using an old schema version: %s for %s/%s", + ver, symbol_ham, symbol_spam) + end + try_convert(true) + else + printf("You have configured an old schema for %s/%s but your data has new layout", + symbol_ham, symbol_spam) + + if ask_yes_no("Switch config to the new schema?", true) then + changes.l['classifier-bayes.conf'] = { + new_schema = true, + } + + local expire = check_expire(conn) + if expire then + changes.l['classifier-bayes.conf'].expire = expire + end + end + end + else + local r, ver = get_version(conn) + if not r then + return false + end + if ver ~= 2 then + printf("You have configured new schema for %s/%s but your DB has old version: %s", + symbol_spam, symbol_ham, ver) + try_convert(false) + else + printf( + 'You have configured new schema for %s/%s and your DB already has new layout (v. %s).' .. + ' DB conversion is not needed.', + symbol_spam, symbol_ham, ver) + end + end +end + +local function setup_statistic(cfg, changes) + local sqlite_configs = lua_stat_tools.load_sqlite_config(cfg) + + if #sqlite_configs > 0 then + + if not redis_params then + printf('You have %d sqlite classifiers, but you have no Redis servers being set', + #sqlite_configs) + return false + end + + local parsed_redis = lua_redis.try_load_redis_servers(redis_params, nil) + if parsed_redis then + printf('You have %d sqlite classifiers', #sqlite_configs) + local expire = readline_expire() + + local reset_previous = ask_yes_no("Reset previous data?") + if ask_yes_no('Do you wish to convert them to Redis?', true) then + + for _, cls in ipairs(sqlite_configs) do + if rspamd_util.file_exists(cls.db_spam) and rspamd_util.file_exists(cls.db_ham) then + if not lua_stat_tools.convert_sqlite_to_redis(parsed_redis, cls.db_spam, + cls.db_ham, cls.symbol_spam, cls.symbol_ham, cls.learn_cache, expire, + reset_previous) then + rspamd_logger.errx('conversion failed') + + return false + end + else + rspamd_logger.messagex('cannot find %s and %s, skip conversion', + cls.db_spam, cls.db_ham) + end + + rspamd_logger.messagex('Converted classifier to the from sqlite to redis') + changes.l['classifier-bayes.conf'] = { + backend = 'redis', + new_schema = true, + } + + if expire then + changes.l['classifier-bayes.conf'].expire = expire + end + + if cls.learn_cache then + changes.l['classifier-bayes.conf'].cache = { + backend = 'redis' + } + end + end + end + end + else + -- Check sanity for the existing Redis classifiers + local classifier = cfg.classifier + + if classifier then + if classifier[1] then + for _, cls in ipairs(classifier) do + if cls.bayes then + cls = cls.bayes + end + if cls.backend and cls.backend == 'redis' then + check_redis_classifier(cls, changes) + end + end + else + if classifier.bayes then + + classifier = classifier.bayes + if classifier[1] then + for _, cls in ipairs(classifier) do + if cls.backend and cls.backend == 'redis' then + check_redis_classifier(cls, changes) + end + end + else + if classifier.backend and classifier.backend == 'redis' then + check_redis_classifier(classifier, changes) + end + end + end + end + end + end +end + +local function find_worker(cfg, wtype) + if cfg.worker then + for k, s in pairs(cfg.worker) do + if type(k) == 'number' and type(s) == 'table' then + if s[wtype] then + return s[wtype] + end + end + if type(s) == 'table' and s.type and s.type == wtype then + return s + end + if type(k) == 'string' and k == wtype then + return s + end + end + end + + 
return nil +end + +return { + handler = function(cmd_args) + local changes = { + l = {}, -- local changes + o = {}, -- override changes + } + + local interactive_start = true + local checks = {} + local all_checks = { + 'controller', + 'redis', + 'dkim', + 'statistic', + } + + local opts = parser:parse(cmd_args) + local args = opts['checks'] or {} + + local _r, err = rspamd_config:load_ucl(opts['config']) + + if not _r then + rspamd_logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + rspamd_logger.errx('cannot process %s: %s', opts['config'], err) + os.exit(1) + end + + local cfg = rspamd_config:get_ucl() + + if not rspamd_config:init_modules() then + rspamd_logger.errx('cannot init modules when parsing %s', opts['config']) + os.exit(1) + end + + if #args > 0 then + interactive_start = false + + for _, arg in ipairs(args) do + if arg == 'all' then + checks = all_checks + elseif arg == 'list' then + printf(highlight(rspamd_logo)) + printf('Available modules') + for _, c in ipairs(all_checks) do + printf('- %s', c) + end + return + else + table.insert(checks, arg) + end + end + else + checks = all_checks + end + + local function has_check(check) + for _, c in ipairs(checks) do + if c == check then + return true + end + end + + return false + end + + rspamd_util.umask('022') + if interactive_start then + printf(highlight(rspamd_logo)) + printf("Welcome to the configuration tool") + printf("We use %s configuration file, writing results to %s", + highlight(opts['config']), highlight(local_conf)) + plugins_stat(nil, nil) + end + + if not interactive_start or + ask_yes_no("Do you wish to continue?", true) then + + if has_check('controller') then + local controller = find_worker(cfg, 'controller') + if controller then + setup_controller(controller, changes) + end + end + + if has_check('redis') then + if not cfg.redis or (not cfg.redis.servers and not cfg.redis.read_servers) then + setup_redis(cfg, changes) + else + redis_params = cfg.redis + end + else + redis_params = cfg.redis + end + + if has_check('dkim') then + if cfg.dkim_signing and not cfg.dkim_signing.domain then + if ask_yes_no('Do you want to setup dkim signing feature?') then + setup_dkim_signing(cfg, changes) + end + end + end + + if has_check('statistic') or has_check('statistics') then + setup_statistic(cfg, changes) + end + + local nchanges = 0 + for _, _ in pairs(changes.l) do + nchanges = nchanges + 1 + end + for _, _ in pairs(changes.o) do + nchanges = nchanges + 1 + end + + if nchanges > 0 then + print_changes(changes) + if ask_yes_no("Apply changes?", true) then + apply_changes(changes) + printf("%d changes applied, the wizard is finished now", nchanges) + printf("*** Please reload the Rspamd configuration ***") + else + printf("No changes applied, the wizard is finished now") + end + else + printf("No changes found, the wizard is finished now") + end + end + end, + name = 'configwizard', + description = parser._description, +} diff --git a/lualib/rspamadm/cookie.lua b/lualib/rspamadm/cookie.lua new file mode 100644 index 0000000..7e0526a --- /dev/null +++ b/lualib/rspamadm/cookie.lua @@ -0,0 +1,125 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" + + +-- Define command line options +local parser = argparse() + :name "rspamadm cookie" + :description "Produces cookies or message ids" + :help_description_margin(30) + +parser:mutex( + parser:option "-k --key" + :description('Key to load') + :argname "<32hex>", + parser:flag "-K --new-key" + :description('Generates a new key') +) + +parser:option "-d --domain" + :description('Use specified domain and generate full message id') + :argname "<domain>" +parser:flag "-D --decrypt" + :description('Decrypt cookie instead of encrypting one') +parser:flag "-t --timestamp" + :description('Show cookie timestamp (valid for decrypting only)') +parser:argument "cookie":args "?" + :description('Use specified cookie') + +local function gen_cookie(args, key) + local cr = require "rspamd_cryptobox" + + if not args.cookie then + return + end + + local function encrypt() + if #args.cookie > 31 then + print('cookie too long (>31 characters), cannot encrypt') + os.exit(1) + end + + local enc_cookie = cr.encrypt_cookie(key, args.cookie) + if args.domain then + print(string.format('<%s@%s>', enc_cookie, args.domain)) + else + print(enc_cookie) + end + end + + local function decrypt() + local extracted_cookie = args.cookie:match('^%<?([^@]+)@.*$') + if not extracted_cookie then + -- Assume full message id as a cookie + extracted_cookie = args.cookie + end + + local dec_cookie, ts = cr.decrypt_cookie(key, extracted_cookie) + + if dec_cookie then + if args.timestamp then + print(string.format('%s %s', dec_cookie, ts)) + else + print(dec_cookie) + end + else + print('cannot decrypt cookie') + os.exit(1) + end + end + + if args.decrypt then + decrypt() + else + encrypt() + end +end + +local function handler(args) + local res = parser:parse(args) + + if not (res.key or res['new_key']) then + parser:error('--key or --new-key must be specified') + end + + if res.key then + local pattern = { '^' } + for i = 1, 32 do + pattern[i + 1] = '[a-zA-Z0-9]' + end + pattern[34] = '$' + + if not res.key:match(table.concat(pattern, '')) then + parser:error('invalid key: ' .. res.key) + end + + gen_cookie(res, res.key) + else + local util = require "rspamd_util" + local key = util.random_hex(32) + + print(key) + gen_cookie(res, res.key) + end +end + +return { + handler = handler, + description = parser._description, + name = 'cookie' +}
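The cryptobox primitives this command relies on can also be called directly. A minimal sketch follows, assuming Rspamd's Lua environment where rspamd_cryptobox and rspamd_util are available; encrypt_cookie, decrypt_cookie and random_hex are the same functions used above, while the addresses are placeholders.

  local cr = require "rspamd_cryptobox"
  local util = require "rspamd_util"

  -- generate a 32-character hex key, as `rspamadm cookie --new-key` does
  local key = util.random_hex(32)

  -- encrypt a short (at most 31 characters) cookie and wrap it into a message id
  local enc = cr.encrypt_cookie(key, 'user@example.com')
  local msgid = string.format('<%s@mail.example.com>', enc)

  -- decrypt it back; the second return value is the embedded timestamp
  local dec, ts = cr.decrypt_cookie(key, enc)
  assert(dec == 'user@example.com')
  print(msgid, dec, ts)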
\ No newline at end of file diff --git a/lualib/rspamadm/corpus_test.lua b/lualib/rspamadm/corpus_test.lua new file mode 100644 index 0000000..0e63f9f --- /dev/null +++ b/lualib/rspamadm/corpus_test.lua @@ -0,0 +1,185 @@ +local rspamd_logger = require "rspamd_logger" +local ucl = require "ucl" +local lua_util = require "lua_util" +local argparse = require "argparse" + +local parser = argparse() + :name "rspamadm corpus_test" + :description "Create logs files from email corpus" + :help_description_margin(32) + +parser:option "-H --ham" + :description("Ham directory") + :argname("<dir>") +parser:option "-S --spam" + :description("Spam directory") + :argname("<dir>") +parser:option "-n --conns" + :description("Number of parallel connections") + :argname("<N>") + :convert(tonumber) + :default(10) +parser:option "-o --output" + :description("Output file") + :argname("<file>") + :default('results.log') +parser:option "-t --timeout" + :description("Timeout for client connections") + :argname("<sec>") + :convert(tonumber) + :default(60) +parser:option "-c --connect" + :description("Connect to specific host") + :argname("<host>") + :default('localhost:11334') +parser:option "-r --rspamc" + :description("Use specific rspamc path") + :argname("<path>") + :default('rspamc') + +local HAM = "HAM" +local SPAM = "SPAM" +local opts + +local function scan_email(n_parallel, path, timeout) + + local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s", + opts.rspamc, opts.connect, n_parallel, timeout, path) + local result = assert(io.popen(rspamc_command)) + result = result:read("*all") + return result +end + +local function write_results(results, file) + + local f = io.open(file, 'w') + + for _, result in pairs(results) do + local log_line = string.format("%s %.2f %s", + result.type, result.score, result.action) + + for _, sym in pairs(result.symbols) do + log_line = log_line .. " " .. sym + end + + log_line = log_line .. " " .. result.scan_time .. " " .. file .. ':' .. result.filename + + log_line = log_line .. 
"\r\n" + + f:write(log_line) + end + + f:close() +end + +local function encoded_json_to_log(result) + -- Returns table containing score, action, list of symbols + + local filtered_result = {} + local ucl_parser = ucl.parser() + + local is_good, err = ucl_parser:parse_string(result) + + if not is_good then + rspamd_logger.errx("Parser error: %1", err) + return nil + end + + result = ucl_parser:get_object() + + filtered_result.score = result.score + if not result.action then + rspamd_logger.errx("Bad JSON: %1", result) + return nil + end + local action = result.action:gsub("%s+", "_") + filtered_result.action = action + + filtered_result.symbols = {} + + for sym, _ in pairs(result.symbols) do + table.insert(filtered_result.symbols, sym) + end + + filtered_result.filename = result.filename + filtered_result.scan_time = result.scan_time + + return filtered_result +end + +local function scan_results_to_logs(results, actual_email_type) + + local logs = {} + + results = lua_util.rspamd_str_split(results, "\n") + + if results[#results] == "" then + results[#results] = nil + end + + for _, result in pairs(results) do + result = encoded_json_to_log(result) + if result then + result['type'] = actual_email_type + table.insert(logs, result) + end + end + + return logs +end + +local function handler(args) + opts = parser:parse(args) + local ham_directory = opts['ham'] + local spam_directory = opts['spam'] + local connections = opts["conns"] + local output = opts["output"] + + local results = {} + + local start_time = os.time() + local no_of_ham = 0 + local no_of_spam = 0 + + if ham_directory then + rspamd_logger.messagex("Scanning ham corpus...") + local ham_results = scan_email(connections, ham_directory, opts["timeout"]) + ham_results = scan_results_to_logs(ham_results, HAM) + + no_of_ham = #ham_results + + for _, result in pairs(ham_results) do + table.insert(results, result) + end + end + + if spam_directory then + rspamd_logger.messagex("Scanning spam corpus...") + local spam_results = scan_email(connections, spam_directory, opts.timeout) + spam_results = scan_results_to_logs(spam_results, SPAM) + + no_of_spam = #spam_results + + for _, result in pairs(spam_results) do + table.insert(results, result) + end + end + + rspamd_logger.messagex("Writing results to %s", output) + write_results(results, output) + + rspamd_logger.messagex("Stats: ") + local elapsed_time = os.time() - start_time + local total_msgs = no_of_ham + no_of_spam + rspamd_logger.messagex("Elapsed time: %ss", elapsed_time) + rspamd_logger.messagex("No of ham: %s", no_of_ham) + rspamd_logger.messagex("No of spam: %s", no_of_spam) + rspamd_logger.messagex("Messages/sec: %s", (total_msgs / elapsed_time)) +end + +return { + name = 'corpustest', + aliases = { 'corpus_test', 'corpus' }, + handler = handler, + description = parser._description +} diff --git a/lualib/rspamadm/dkim_keygen.lua b/lualib/rspamadm/dkim_keygen.lua new file mode 100644 index 0000000..211094d --- /dev/null +++ b/lualib/rspamadm/dkim_keygen.lua @@ -0,0 +1,178 @@ +--[[ +Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" +local rspamd_util = require "rspamd_util" +local rspamd_cryptobox = require "rspamd_cryptobox" + +local parser = argparse() + :name 'rspamadm dkim_keygen' + :description 'Create key pairs for dkim signing' + :help_description_margin(30) +parser:option '-d --domain' + :description 'Create a key for a specific domain' + :default "example.com" +parser:option '-s --selector' + :description 'Create a key for a specific DKIM selector' + :default "mail" +parser:option '-k --privkey' + :description 'Save private key to file instead of printing it to stdout' +parser:option '-b --bits' + :convert(function(input) + local n = tonumber(input) + if not n or n < 512 or n > 4096 then + return nil + end + return n +end) + :description 'Generate an RSA key with the specified number of bits (512 to 4096)' + :default(1024) +parser:option '-t --type' + :description 'Key type: RSA, ED25519 or ED25119-seed' + :convert { + ['rsa'] = 'rsa', + ['RSA'] = 'rsa', + ['ed25519'] = 'ed25519', + ['ED25519'] = 'ed25519', + ['ed25519-seed'] = 'ed25519-seed', + ['ED25519-seed'] = 'ed25519-seed', +} + :default 'rsa' +parser:option '-o --output' + :description 'Output public key in the following format: dns, dnskey or plain' + :convert { + ['dns'] = 'dns', + ['plain'] = 'plain', + ['dnskey'] = 'dnskey', +} + :default 'dns' +parser:option '--priv-output' + :description 'Output private key in the following format: PEM or DER (for RSA)' + :convert { + ['pem'] = 'pem', + ['der'] = 'der', +} + :default 'pem' +parser:flag '-f --force' + :description 'Force overwrite of existing files' + +local function split_string(input, max_length) + max_length = max_length or 253 + local pieces = {} + local index = 1 + + while index <= #input do + local piece = input:sub(index, index + max_length - 1) + table.insert(pieces, piece) + index = index + max_length + end + + return pieces +end + +local function print_public_key_dns(opts, base64_pk) + local key_type = opts.type == 'rsa' and 'rsa' or 'ed25519' + if #base64_pk < 255 - 2 then + io.write(string.format('%s._domainkey IN TXT ( "v=DKIM1; k=%s;" \n\t"p=%s" ) ;\n', + opts.selector, key_type, base64_pk)) + else + -- Split it by parts + local parts = split_string(base64_pk) + io.write(string.format('%s._domainkey IN TXT ( "v=DKIM1; k=%s; "\n', opts.selector, key_type)) + for i, part in ipairs(parts) do + if i == 1 then + io.write(string.format('\t"p=%s"\n', part)) + else + io.write(string.format('\t"%s"\n', part)) + end + end + io.write(") ; \n") + end + +end + +local function print_public_key(opts, pk, need_base64) + local key_type = opts.type == 'rsa' and 'rsa' or 'ed25519' + local base64_pk = need_base64 and tostring(rspamd_util.encode_base64(pk)) or tostring(pk) + if opts.output == 'plain' then + io.write(base64_pk) + io.write("\n") + elseif opts.output == 'dns' then + print_public_key_dns(opts, base64_pk, false) + elseif opts.output == 'dnskey' then + io.write(string.format('v=DKIM1; k=%s; p=%s\n', key_type, base64_pk)) + end +end + +local function gen_rsa_key(opts) + local rsa = require "rspamd_rsa" + + local sk, pk = rsa.keypair(opts.bits or 1024) + if opts.privkey then + if opts.force then + os.remove(opts.privkey) + end + sk:save(opts.privkey, opts.priv_output) + else + sk:save("-", opts.priv_output) + end + + -- We generate key directly via lua_rsa and it returns unencoded raw data + print_public_key(opts, tostring(pk), true) +end + +local 
function gen_eddsa_key(opts) + local sk, pk = rspamd_cryptobox.gen_dkim_keypair(opts.type) + + if opts.privkey and opts.force then + os.remove(opts.privkey) + end + if not sk:save_in_file(opts.privkey, tonumber('0600', 8)) then + io.stderr:write('cannot save private key to ' .. (opts.privkey or 'stdout') .. '\n') + os.exit(1) + end + + if not opts.privkey then + io.write("\n") + io.flush() + end + + -- gen_dkim_keypair function returns everything encoded in base64, so no need to do anything + print_public_key(opts, tostring(pk), false) +end + +local function handler(args) + local opts = parser:parse(args) + + if not opts then + os.exit(1) + end + + if opts.type == 'rsa' then + gen_rsa_key(opts) + else + gen_eddsa_key(opts) + end +end + +return { + name = 'dkim_keygen', + aliases = { 'dkimkeygen' }, + handler = handler, + description = parser._description +} + + diff --git a/lualib/rspamadm/dmarc_report.lua b/lualib/rspamadm/dmarc_report.lua new file mode 100644 index 0000000..42c801e --- /dev/null +++ b/lualib/rspamadm/dmarc_report.lua @@ -0,0 +1,737 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" +local lua_util = require "lua_util" +local logger = require "rspamd_logger" +local lua_redis = require "lua_redis" +local dmarc_common = require "plugins/dmarc" +local lupa = require "lupa" +local rspamd_mempool = require "rspamd_mempool" +local rspamd_url = require "rspamd_url" +local rspamd_text = require "rspamd_text" +local rspamd_util = require "rspamd_util" +local rspamd_dns = require "rspamd_dns" + +local N = 'dmarc_report' + +-- Define command line options +local parser = argparse() + :name "rspamadm dmarc_report" + :description "Dmarc reports sending tool" + :help_description_margin(30) + +parser:option "-c --config" + :description "Path to config file" + :argname("<cfg>") + :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf") + +parser:flag "-v --verbose" + :description "Enable dmarc specific logging" + +parser:flag "-n --no-opt" + :description "Do not reset reporting data/send reports" + +parser:argument "date" + :description "Date to process (today by default)" + :argname "<YYYYMMDD>" + :args "*" +parser:option "-b --batch-size" + :description "Send reports in batches up to <batch-size> messages" + :argname "<number>" + :convert(tonumber) + :default "10" + +local report_template = [[From: "{= from_name =}" <{= from_addr =}> +To: {= rcpt =} +{%+ if is_string(bcc) %}Bcc: {= bcc =}{%- endif %} +Subject: Report Domain: {= reporting_domain =} + Submitter: {= submitter =} + Report-ID: {= report_id =} +Date: {= report_date =} +MIME-Version: 1.0 +Message-ID: <{= message_id =}> +Content-Type: multipart/mixed; + boundary="----=_NextPart_{= uuid =}" + +This is a multipart message in MIME format. + +------=_NextPart_{= uuid =} +Content-Type: text/plain; charset="us-ascii" +Content-Transfer-Encoding: 7bit + +This is an aggregate report from {= submitter =}. 
+ +Report domain: {= reporting_domain =} +Submitter: {= submitter =} +Report ID: {= report_id =} + +------=_NextPart_{= uuid =} +Content-Type: application/gzip +Content-Transfer-Encoding: base64 +Content-Disposition: attachment; + filename="{= submitter =}!{= reporting_domain =}!{= report_start =}!{= report_end =}.xml.gz" + +]] +local report_footer = [[ + +------=_NextPart_{= uuid =}--]] + +local dmarc_settings = {} +local redis_params +local redis_attrs = { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + log_obj = rspamd_config, + resolver = rspamadm_dns_resolver, +} +local pool + +local function load_config(opts) + local _r, err = rspamd_config:load_ucl(opts['config']) + + if not _r then + logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + logger.errx('cannot process %s: %s', opts['config'], err) + os.exit(1) + end +end + +-- Concat elements using redis_keys.join_char +local function redis_prefix(...) + return table.concat({ ... }, dmarc_settings.reporting.redis_keys.join_char) +end + +local function get_rua(rep_key) + local parts = lua_util.str_split(rep_key, dmarc_settings.reporting.redis_keys.join_char) + + if #parts >= 3 then + return parts[3] + end + + return nil +end + +local function get_domain(rep_key) + local parts = lua_util.str_split(rep_key, dmarc_settings.reporting.redis_keys.join_char) + + if #parts >= 3 then + return parts[2] + end + + return nil +end + +local function gen_uuid() + local template = 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx' + return string.gsub(template, '[xy]', function(c) + local v = (c == 'x') and math.random(0, 0xf) or math.random(8, 0xb) + return string.format('%x', v) + end) +end + +local function gen_xml_grammar() + local lpeg = require 'lpeg' + local lt = lpeg.P('<') / '<' + local gt = lpeg.P('>') / '>' + local amp = lpeg.P('&') / '&' + local quot = lpeg.P('"') / '"' + local apos = lpeg.P("'") / ''' + local special = lt + gt + amp + quot + apos + local grammar = lpeg.Cs((special + 1) ^ 0) + return grammar +end + +local xml_grammar = gen_xml_grammar() + +local function escape_xml(input) + if type(input) == 'string' or type(input) == 'userdata' then + return xml_grammar:match(input) + else + input = tostring(input) + + if input then + return xml_grammar:match(input) + end + end + + return '' +end +-- Enable xml escaping in lupa templates +lupa.filters.escape_xml = escape_xml + +-- Creates report XML header +local function report_header(reporting_domain, report_start, report_end, domain_policy) + local report_id = string.format('%s.%d.%d', + reporting_domain, report_start, report_end) + local xml_template = [[ +<?xml version="1.0" encoding="UTF-8" ?> +<feedback> + <report_metadata> + <org_name>{= report_settings.org_name | escape_xml =}</org_name> + <email>{= report_settings.email | escape_xml =}</email> + <report_id>{= report_id =}</report_id> + <date_range> + <begin>{= report_start =}</begin> + <end>{= report_end =}</end> + </date_range> + </report_metadata> + <policy_published> + <domain>{= reporting_domain | escape_xml =}</domain> + <adkim>{= domain_policy.adkim | escape_xml =}</adkim> + <aspf>{= domain_policy.aspf | escape_xml =}</aspf> + <p>{= domain_policy.p | escape_xml =}</p> + <sp>{= domain_policy.sp | escape_xml =}</sp> + <pct>{= domain_policy.pct | escape_xml =}</pct> + </policy_published> +]] + return lua_util.jinja_template(xml_template, { + report_settings = dmarc_settings.reporting, + report_id = 
report_id, + report_start = report_start, + report_end = report_end, + domain_policy = domain_policy, + reporting_domain = reporting_domain, + }, true) +end + +-- Generate xml entry for a preprocessed redis row +local function entry_to_xml(data) + local xml_template = [[<record> + <row> + <source_ip>{= data.ip =}</source_ip> + <count>{= data.count =}</count> + <policy_evaluated> + <disposition>{= data.disposition =}</disposition> + <dkim>{= data.dkim_disposition =}</dkim> + <spf>{= data.spf_disposition =}</spf> + {% if data.override and data.override ~= '' -%} + <reason><type>{= data.override =}</type></reason> + {%- endif %} + </policy_evaluated> + </row> + <identifiers> + <header_from>{= data.header_from =}</header_from> + </identifiers> + <auth_results> + {% if data.dkim_results[1] -%} + {% for d in data.dkim_results -%} + <dkim> + <domain>{= d.domain =}</domain> + <result>{= d.result =}</result> + </dkim> + {%- endfor %} + {%- endif %} + <spf> + <domain>{= data.spf_domain =}</domain> + <result>{= data.spf_result =}</result> + </spf> + </auth_results> +</record> +]] + return lua_util.jinja_template(xml_template, { data = data }, true) +end + +-- Process a report entry stored in Redis splitting it to a lua table +local function process_report_entry(data, score) + local split = lua_util.str_split(data, ',') + local row = { + ip = split[1], + spf_disposition = split[2], + dkim_disposition = split[3], + disposition = split[4], + override = split[5], + header_from = split[6], + dkim_results = {}, + spf_domain = split[11], + spf_result = split[12], + count = tonumber(score), + } + -- Process dkim entries + local function dkim_entries_process(dkim_data, result) + if dkim_data and dkim_data ~= '' then + local dkim_elts = lua_util.str_split(dkim_data, '|') + for _, d in ipairs(dkim_elts) do + table.insert(row.dkim_results, { domain = d, result = result }) + end + end + end + dkim_entries_process(split[7], 'pass') + dkim_entries_process(split[8], 'fail') + dkim_entries_process(split[9], 'temperror') + dkim_entries_process(split[9], 'permerror') + + return row +end + +-- Process a single rua entry, validating in DNS if needed +local function process_rua(dmarc_domain, rua) + local parts = lua_util.str_split(rua, ',') + + -- Remove size limitation, as we don't care about them + local addrs = {} + for _, rua_part in ipairs(parts) do + local u = rspamd_url.create(pool, rua_part:gsub('!%d+[kmg]?$', '')) + local u2 = rspamd_url.create(pool, dmarc_domain) + if u and (u:get_protocol() or '') == 'mailto' and u:get_user() then + -- Check each address for sanity + if u:get_tld() == u2:get_tld() then + -- Same eSLD - always include + table.insert(addrs, u) + else + -- We need to check authority + local resolve_str = string.format('%s._report._dmarc.%s', + dmarc_domain, u:get_host()) + local is_ok, results = rspamd_dns.request({ + config = rspamd_config, + session = rspamadm_session, + type = 'txt', + name = resolve_str, + }) + + if not is_ok then + logger.errx('cannot resolve %s: %s; exclude %s', resolve_str, results, rua_part) + else + local found = false + for _, t in ipairs(results) do + if string.match(t, 'v=DMARC1') then + found = true + break + end + end + + if not found then + logger.errx('%s is not authorized to process reports on %s', dmarc_domain, u:get_host()) + else + -- All good + table.insert(addrs, u) + end + end + end + else + logger.errx('invalid rua url: "%s""', tostring(u or 'null')) + end + end + + if #addrs > 0 then + return addrs + end + + return nil +end + +-- Validate reporting domain, 
extracting rua and checking 3rd party report domains +-- This function returns a full dmarc record processed + rua as a list of url objects +local function validate_reporting_domain(reporting_domain) + local is_ok, results = rspamd_dns.request({ + config = rspamd_config, + session = rspamadm_session, + type = 'txt', + name = '_dmarc.' .. reporting_domain, + }) + + if not is_ok or not results then + logger.errx('cannot resolve _dmarc.%s: %s', reporting_domain, results) + return nil + end + + for _, r in ipairs(results) do + local processed, rec = dmarc_common.dmarc_check_record(rspamd_config, r, false) + if processed and rec.rua then + -- We need to check or alter rua if needed + local processed_rua = process_rua(reporting_domain, rec.rua) + if processed_rua then + rec = rec.raw_elts + rec.rua = processed_rua + + -- Fill defaults in a record to avoid nils in a report + rec['pct'] = rec['pct'] or 100 + rec['adkim'] = rec['adkim'] or 'r' + rec['aspf'] = rec['aspf'] or 'r' + rec['p'] = rec['p'] or 'none' + rec['sp'] = rec['sp'] or 'none' + return rec + end + return nil + end + end + + return nil +end + +-- Returns a list of recipients from a table as a string processing elements if needed +local function rcpt_list(tbl, func) + local res = {} + for _, r in ipairs(tbl) do + if func then + table.insert(res, func(r)) + else + table.insert(res, r) + end + end + + return table.concat(res, ',') +end + +-- Synchronous smtp send function +local function send_reports_by_smtp(opts, reports, finish_cb) + local lua_smtp = require "lua_smtp" + local reports_failed = 0 + local reports_sent = 0 + local report_settings = dmarc_settings.reporting + + local function gen_sendmail_cb(report, args) + return function(ret, err) + -- We modify this from all callbacks + args.nreports = args.nreports - 1 + if not ret then + logger.errx("Couldn't send mail for %s: %s", report.reporting_domain, err) + reports_failed = reports_failed + 1 + else + reports_sent = reports_sent + 1 + lua_util.debugm(N, 'successfully sent a report for %s: %s bytes sent', + report.reporting_domain, #report.message) + end + + -- Tail call to the next batch or to the final function + if args.nreports == 0 then + if args.next_start > #reports then + finish_cb(reports_sent, reports_failed) + else + args.cont_func(args.next_start) + end + end + end + end + + local function send_data_in_batches(cur_batch) + local nreports = math.min(#reports - cur_batch + 1, opts.batch_size) + local next_start = cur_batch + nreports + lua_util.debugm(N, 'send data for %s domains (from %s to %s)', + nreports, cur_batch, next_start - 1) + -- Shared across all closures + local gen_args = { + cont_func = send_data_in_batches, + nreports = nreports, + next_start = next_start + } + for i = cur_batch, next_start - 1 do + local report = reports[i] + local send_opts = { + ev_base = rspamadm_ev_base, + session = rspamadm_session, + config = rspamd_config, + host = report_settings.smtp, + port = report_settings.smtp_port or 25, + resolver = rspamadm_dns_resolver, + from = report_settings.email, + recipients = report.rcpts, + helo = report_settings.helo or 'rspamd.localhost', + } + + lua_smtp.sendmail(send_opts, + report.message, + gen_sendmail_cb(report, gen_args)) + end + end + + send_data_in_batches(1) +end + +local function prepare_report(opts, start_time, end_time, rep_key) + local rua = get_rua(rep_key) + local reporting_domain = get_domain(rep_key) + + if not rua then + logger.errx('report %s has no valid rua, skip it', rep_key) + return nil + end + if not reporting_domain 
then + logger.errx('report %s has no valid reporting_domain, skip it', rep_key) + return nil + end + + local ret, results = lua_redis.request(redis_params, redis_attrs, + { 'EXISTS', rep_key }) + + if not ret or not results or results == 0 then + return nil + end + + -- Rename report key to avoid races + if not opts.no_opt then + lua_redis.request(redis_params, redis_attrs, + { 'RENAME', rep_key, rep_key .. '_processing' }) + rep_key = rep_key .. '_processing' + end + + local dmarc_record = validate_reporting_domain(reporting_domain) + lua_util.debugm(N, 'process reporting domain %s: %s', reporting_domain, dmarc_record) + + if not dmarc_record then + if not opts.no_opt then + lua_redis.request(redis_params, redis_attrs, + { 'DEL', rep_key }) + end + logger.messagex('Cannot process reports for domain %s; invalid dmarc record', reporting_domain) + return nil + end + + -- Get all reports for a domain + ret, results = lua_redis.request(redis_params, redis_attrs, + { 'ZRANGE', rep_key, '0', '-1', 'WITHSCORES' }) + local report_entries = {} + table.insert(report_entries, + report_header(reporting_domain, start_time, end_time, dmarc_record)) + for i = 1, #results, 2 do + local xml_record = entry_to_xml(process_report_entry(results[i], results[i + 1])) + table.insert(report_entries, xml_record) + end + table.insert(report_entries, '</feedback>') + local xml_to_compress = rspamd_text.fromtable(report_entries) + lua_util.debugm(N, 'got xml: %s', xml_to_compress) + + -- Prepare SMTP message + local report_settings = dmarc_settings.reporting + local rcpt_string = rcpt_list(dmarc_record.rua, function(rua_elt) + return string.format('%s@%s', rua_elt:get_user(), rua_elt:get_host()) + end) + local bcc_string + if report_settings.bcc_addrs then + bcc_string = rcpt_list(report_settings.bcc_addrs) + end + local uuid = gen_uuid() + local rhead = lua_util.jinja_template(report_template, { + from_name = report_settings.from_name, + from_addr = report_settings.email, + rcpt = rcpt_string, + bcc = bcc_string, + uuid = uuid, + reporting_domain = reporting_domain, + submitter = report_settings.domain, + report_id = string.format('%s.%d.%d', reporting_domain, start_time, + end_time), + report_date = rspamd_util.time_to_string(rspamd_util.get_time()), + message_id = rspamd_util.random_hex(16) .. '@' .. 
report_settings.msgid_from, + report_start = start_time, + report_end = end_time + }, true) + local rfooter = lua_util.jinja_template(report_footer, { + uuid = uuid, + }, true) + local message = rspamd_text.fromtable { + (rhead:gsub("\n", "\r\n")), + rspamd_util.encode_base64(rspamd_util.gzip_compress(xml_to_compress), 73), + rfooter:gsub("\n", "\r\n"), + } + + lua_util.debugm(N, 'got final message: %s', message) + + if not opts.no_opt then + lua_redis.request(redis_params, redis_attrs, + { 'DEL', rep_key }) + end + + local report_rcpts = lua_util.str_split(rcpt_string, ',') + + if report_settings.bcc_addrs then + for _, b in ipairs(report_settings.bcc_addrs) do + table.insert(report_rcpts, b) + end + end + + return { + message = message, + rcpts = report_rcpts, + reporting_domain = reporting_domain + } +end + +local function process_report_date(opts, start_time, end_time, date) + local idx_key = redis_prefix(dmarc_settings.reporting.redis_keys.index_prefix, date) + local ret, results = lua_redis.request(redis_params, redis_attrs, + { 'EXISTS', idx_key }) + + if not ret or not results or results == 0 then + logger.messagex('No reports for %s', date) + return {} + end + + -- Rename index key to avoid races + if not opts.no_opt then + lua_redis.request(redis_params, redis_attrs, + { 'RENAME', idx_key, idx_key .. '_processing' }) + idx_key = idx_key .. '_processing' + end + ret, results = lua_redis.request(redis_params, redis_attrs, + { 'SMEMBERS', idx_key }) + + if not ret or not results then + -- Remove bad key + if not opts.no_opt then + lua_redis.request(redis_params, redis_attrs, + { 'DEL', idx_key }) + end + logger.messagex('Cannot get reports for %s', date) + return {} + end + + local reports = {} + for _, rep in ipairs(results) do + local report = prepare_report(opts, start_time, end_time, rep) + + if report then + table.insert(reports, report) + end + end + + -- Shuffle reports to make sending more fair + lua_util.shuffle(reports) + -- Remove processed key + if not opts.no_opt then + lua_redis.request(redis_params, redis_attrs, + { 'DEL', idx_key }) + end + + return reports +end + + +-- Returns a day before today at 00:00 as unix seconds +local function yesterday_midnight() + local piecewise_time = os.date("*t") + piecewise_time.day = piecewise_time.day - 1 -- Lua allows negative values here + piecewise_time.hour = 0 + piecewise_time.sec = 0 + piecewise_time.min = 0 + return os.time(piecewise_time) +end + +-- Returns today time at 00:00 as unix seconds +local function today_midnight() + local piecewise_time = os.date("*t") + piecewise_time.hour = 0 + piecewise_time.sec = 0 + piecewise_time.min = 0 + return os.time(piecewise_time) +end + +local function handler(args) + local start_time + -- Preserve start time as report sending might take some time + local start_collection = today_midnight() + + local opts = parser:parse(args) + + pool = rspamd_mempool.create() + load_config(opts) + rspamd_url.init(rspamd_config:get_tld_path()) + + if opts.verbose then + lua_util.enable_debug_modules('dmarc', N) + end + + dmarc_settings = rspamd_config:get_all_opt('dmarc') + if not dmarc_settings or not dmarc_settings.reporting or not dmarc_settings.reporting.enabled then + logger.errx('dmarc reporting is not enabled, exiting') + os.exit(1) + end + + dmarc_settings = lua_util.override_defaults(dmarc_common.default_settings, dmarc_settings) + redis_params = lua_redis.parse_redis_server('dmarc', dmarc_settings) + + if not redis_params then + logger.errx('Redis is not configured, exiting') + os.exit(1) + 
end + + for _, e in ipairs({ 'email', 'domain', 'org_name' }) do + if not dmarc_settings.reporting[e] then + logger.errx('Missing required setting: dmarc.reporting.%s', e) + return + end + end + + local ret, results = lua_redis.request(redis_params, redis_attrs, { + 'GET', 'rspamd_dmarc_last_collection' + }) + + if not ret or not tonumber(results) then + start_time = yesterday_midnight() + else + start_time = tonumber(results) + end + + lua_util.debugm(N, 'previous last report date is %s', start_time) + + if not opts.date or #opts.date == 0 then + opts.date = {} + table.insert(opts.date, os.date('%Y%m%d', yesterday_midnight())) + end + + local ndates = 0 + local nreports = 0 + local all_reports = {} + for _, date in ipairs(opts.date) do + lua_util.debugm(N, 'Process date %s', date) + local reports_for_date = process_report_date(opts, start_time, start_collection, date) + if #reports_for_date > 0 then + ndates = ndates + 1 + nreports = nreports + #reports_for_date + + for _, r in ipairs(reports_for_date) do + table.insert(all_reports, r) + end + end + end + + local function finish_cb(nsuccess, nfail) + if not opts.no_opt then + lua_util.debugm(N, 'set last report date to %s', start_collection) + -- Hack to avoid coroutines + async functions mess: we use async redis call here + redis_attrs.callback = function() + logger.messagex('Reporting collection has finished %s dates processed, %s reports: %s completed, %s failed', + ndates, nreports, nsuccess, nfail) + end + lua_redis.request(redis_params, redis_attrs, + { 'SETEX', 'rspamd_dmarc_last_collection', dmarc_settings.reporting.keys_expire * 2, + tostring(start_collection) }) + else + logger.messagex('Reporting collection has finished %s dates processed, %s reports: %s completed, %s failed', + ndates, nreports, nsuccess, nfail) + end + + pool:destroy() + end + if not opts.no_opt then + send_reports_by_smtp(opts, all_reports, finish_cb) + else + logger.messagex('Skip sending mails due to -n / --no-opt option') + end +end + +return { + name = 'dmarc_report', + aliases = { 'dmarc_reporting' }, + handler = handler, + description = parser._description +} diff --git a/lualib/rspamadm/dns_tool.lua b/lualib/rspamadm/dns_tool.lua new file mode 100644 index 0000000..3eb09a8 --- /dev/null +++ b/lualib/rspamadm/dns_tool.lua @@ -0,0 +1,232 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + + +local argparse = require "argparse" +local rspamd_logger = require "rspamd_logger" +local ansicolors = require "ansicolors" +local bit = require "bit" + +local parser = argparse() + :name "rspamadm dnstool" + :description "DNS tools provided by Rspamd" + :help_description_margin(30) + :command_target("command") + :require_command(true) + +parser:option "-c --config" + :description "Path to config file" + :argname("<cfg>") + :default(rspamd_paths["CONFDIR"] .. "/" .. 
"rspamd.conf") + +local spf = parser:command "spf" + :description "Extracts spf records" +spf:mutex( + spf:option "-d --domain" + :description "Domain to use" + :argname("<domain>"), + spf:option "-f --from" + :description "SMTP from to use" + :argname("<from>") +) + +spf:option "-i --ip" + :description "Source IP address to use" + :argname("<ip>") +spf:flag "-a --all" + :description "Print all records" + +local function printf(fmt, ...) + if fmt then + io.write(string.format(fmt, ...)) + end + io.write('\n') +end + +local function highlight(str) + return ansicolors.white .. str .. ansicolors.reset +end + +local function green(str) + return ansicolors.green .. str .. ansicolors.reset +end + +local function red(str) + return ansicolors.red .. str .. ansicolors.reset +end + +local function load_config(opts) + local _r, err = rspamd_config:load_ucl(opts['config']) + + if not _r then + rspamd_logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + rspamd_logger.errx('cannot process %s: %s', opts['config'], err) + os.exit(1) + end +end + +local function spf_handler(opts) + local rspamd_spf = require "rspamd_spf" + local rspamd_task = require "rspamd_task" + local rspamd_ip = require "rspamd_ip" + + local task = rspamd_task:create(rspamd_config, rspamadm_ev_base) + task:set_session(rspamadm_session) + task:set_resolver(rspamadm_dns_resolver) + + if opts.ip then + opts.ip = rspamd_ip.fromstring(opts.ip) + task:set_from_ip(opts.ip) + else + opts.all = true + end + + if opts.from then + local rspamd_parsers = require "rspamd_parsers" + local addr_parsed = rspamd_parsers.parse_mail_address(opts.from) + if addr_parsed then + task:set_from('smtp', addr_parsed[1]) + else + io.stderr:write('Invalid from addr\n') + os.exit(1) + end + elseif opts.domain then + task:set_from('smtp', { user = 'user', domain = opts.domain }) + else + io.stderr:write('Neither domain nor from specified\n') + os.exit(1) + end + + local function flag_to_str(fl) + if bit.band(fl, rspamd_spf.flags.temp_fail) ~= 0 then + return "temporary failure" + elseif bit.band(fl, rspamd_spf.flags.perm_fail) ~= 0 then + return "permanent failure" + elseif bit.band(fl, rspamd_spf.flags.na) ~= 0 then + return "no spf record" + end + + return "unknown flag: " .. 
tostring(fl) + end + + local function display_spf_results(elt, colored) + local dec = function(e) + return e + end + local policy_decode = function(e) + if e == rspamd_spf.policy.fail then + return 'reject' + elseif e == rspamd_spf.policy.pass then + return 'pass' + elseif e == rspamd_spf.policy.soft_fail then + return 'soft fail' + elseif e == rspamd_spf.policy.neutral then + return 'neutral' + end + + return 'unknown' + end + + if colored then + dec = function(e) + return highlight(e) + end + + if elt.result == rspamd_spf.policy.pass then + dec = function(e) + return green(e) + end + elseif elt.result == rspamd_spf.policy.fail then + dec = function(e) + return red(e) + end + end + + end + printf('%s: %s', highlight('Policy'), dec(policy_decode(elt.result))) + printf('%s: %s', highlight('Network'), dec(elt.addr)) + + if elt.str then + printf('%s: %s', highlight('Original'), elt.str) + end + end + + local function cb(record, flags, err) + if record then + local result, flag_or_policy, error_or_addr + if opts.ip then + result, flag_or_policy, error_or_addr = record:check_ip(opts.ip) + elseif opts.all then + result = true + end + if opts.ip and not opts.all then + if result then + display_spf_results(error_or_addr, true) + else + printf('Not matched: %s', error_or_addr) + end + + os.exit(0) + end + + if result then + printf('SPF record for %s; digest: %s', + highlight(opts.domain or opts.from), highlight(record:get_digest())) + for _, elt in ipairs(record:get_elts()) do + if result and error_or_addr and elt.str and elt.str == error_or_addr.str then + printf("%s", highlight('*** Matched ***')) + display_spf_results(elt, true) + printf('------') + else + display_spf_results(elt, false) + printf('------') + end + end + else + printf('Error getting SPF record: %s (%s flag)', err, + flag_to_str(flag_or_policy or flags)) + end + else + printf('Cannot get SPF record: %s', err) + end + end + rspamd_spf.resolve(task, cb) +end + +local function handler(args) + local opts = parser:parse(args) + load_config(opts) + + local command = opts.command + + if command == 'spf' then + spf_handler(opts) + else + parser:error('command %s is not implemented', command) + end +end + +return { + name = 'dnstool', + aliases = { 'dns', 'dns_tool' }, + handler = handler, + description = parser._description +}
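-- Editor's sketch (not part of the diff): flag_to_str() in dns_tool.lua above decodes the
-- bitmask that rspamd_spf.resolve() passes to its callback by ANDing it against
-- rspamd_spf.flags.{temp_fail,perm_fail,na}. A standalone illustration with made-up masks:
local bit = require "bit"

local function has_flag(flags, mask)
  -- a flag is considered set when ANDing with its mask yields a non-zero value
  return bit.band(flags, mask) ~= 0
end

-- Illustrative mask values only; the real ones come from rspamd_spf.flags
local SPF_TEMP_FAIL, SPF_PERM_FAIL, SPF_NA = 0x1, 0x2, 0x4
assert(has_flag(0x5, SPF_TEMP_FAIL))
assert(not has_flag(0x5, SPF_PERM_FAIL))
-- Typical invocation of the subcommand itself: rspamadm dnstool spf -d example.com -i 1.2.3.4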
\ No newline at end of file diff --git a/lualib/rspamadm/fuzzy_convert.lua b/lualib/rspamadm/fuzzy_convert.lua new file mode 100644 index 0000000..fab3995 --- /dev/null +++ b/lualib/rspamadm/fuzzy_convert.lua @@ -0,0 +1,208 @@ +local sqlite3 = require "rspamd_sqlite3" +local redis = require "rspamd_redis" +local util = require "rspamd_util" + +local function connect_redis(server, username, password, db) + local ret + local conn, err = redis.connect_sync({ + host = server, + }) + + if not conn then + return nil, 'Cannot connect: ' .. err + end + + if username then + if password then + ret = conn:add_cmd('AUTH', { username, password }) + if not ret then + return nil, 'Cannot queue command' + end + else + return nil, 'Redis requires a password when username is supplied' + end + elseif password then + ret = conn:add_cmd('AUTH', { password }) + if not ret then + return nil, 'Cannot queue command' + end + end + if db then + ret = conn:add_cmd('SELECT', { db }) + if not ret then + return nil, 'Cannot queue command' + end + end + + return conn, nil +end + +local function send_digests(digests, redis_host, redis_username, redis_password, redis_db) + local conn, err = connect_redis(redis_host, redis_username, redis_password, redis_db) + if err then + print(err) + return false + end + local ret + for _, v in ipairs(digests) do + ret = conn:add_cmd('HMSET', { + 'fuzzy' .. v[1], + 'F', v[2], + 'V', v[3], + }) + if not ret then + print('Cannot batch command') + return false + end + ret = conn:add_cmd('EXPIRE', { + 'fuzzy' .. v[1], + tostring(v[4]), + }) + if not ret then + print('Cannot batch command') + return false + end + end + ret, err = conn:exec() + if not ret then + print('Cannot execute batched commands: ' .. err) + return false + end + return true +end + +local function send_shingles(shingles, redis_host, redis_username, redis_password, redis_db) + local conn, err = connect_redis(redis_host, redis_username, redis_password, redis_db) + if err then + print("Redis error: " .. err) + return false + end + local ret + for _, v in ipairs(shingles) do + ret = conn:add_cmd('SET', { + 'fuzzy_' .. v[2] .. '_' .. v[1], + v[4], + }) + if not ret then + print('Cannot batch SET command: ' .. err) + return false + end + ret = conn:add_cmd('EXPIRE', { + 'fuzzy_' .. v[2] .. '_' .. v[1], + tostring(v[3]), + }) + if not ret then + print('Cannot batch command') + return false + end + end + ret, err = conn:exec() + if not ret then + print('Cannot execute batched commands: ' .. err) + return false + end + return true +end + +local function update_counters(total, redis_host, redis_username, redis_password, redis_db) + local conn, err = connect_redis(redis_host, redis_username, redis_password, redis_db) + if err then + print(err) + return false + end + local ret + ret = conn:add_cmd('SET', { + 'fuzzylocal', + total, + }) + if not ret then + print('Cannot batch command') + return false + end + ret = conn:add_cmd('SET', { + 'fuzzy_count', + total, + }) + if not ret then + print('Cannot batch command') + return false + end + ret, err = conn:exec() + if not ret then + print('Cannot execute batched commands: ' .. 
err) + return false + end + return true +end + +return function(_, res) + local db = sqlite3.open(res['source_db']) + local shingles = {} + local digests = {} + local num_batch_digests = 0 + local num_batch_shingles = 0 + local total_digests = 0 + local total_shingles = 0 + local lim_batch = 1000 -- Update each 1000 entries + local redis_username = res['redis_username'] + local redis_password = res['redis_password'] + local redis_db = nil + + if res['redis_db'] then + redis_db = tostring(res['redis_db']) + end + + if not db then + print('Cannot open source db: ' .. res['source_db']) + return + end + + local now = util.get_time() + for row in db:rows('SELECT id, flag, digest, value, time FROM digests') do + + local expire_in = math.floor(now - row.time + res['expiry']) + if expire_in >= 1 then + table.insert(digests, { row.digest, row.flag, row.value, expire_in }) + num_batch_digests = num_batch_digests + 1 + total_digests = total_digests + 1 + for srow in db:rows('SELECT value, number FROM shingles WHERE digest_id = ' .. row.id) do + table.insert(shingles, { srow.value, srow.number, expire_in, row.digest }) + total_shingles = total_shingles + 1 + num_batch_shingles = num_batch_shingles + 1 + end + end + if num_batch_digests >= lim_batch then + if not send_digests(digests, res['redis_host'], redis_username, redis_password, redis_db) then + return + end + num_batch_digests = 0 + digests = {} + end + if num_batch_shingles >= lim_batch then + if not send_shingles(shingles, res['redis_host'], redis_username, redis_password, redis_db) then + return + end + num_batch_shingles = 0 + shingles = {} + end + end + if digests[1] then + if not send_digests(digests, res['redis_host'], redis_username, redis_password, redis_db) then + return + end + end + if shingles[1] then + if not send_shingles(shingles, res['redis_host'], redis_username, redis_password, redis_db) then + return + end + end + + local message = string.format( + 'Migrated %d digests and %d shingles', + total_digests, total_shingles + ) + if not update_counters(total_digests, res['redis_host'], redis_username, redis_password, redis_db) then + message = message .. ' but failed to update counters' + end + print(message) +end diff --git a/lualib/rspamadm/fuzzy_ping.lua b/lualib/rspamadm/fuzzy_ping.lua new file mode 100644 index 0000000..e0345da --- /dev/null +++ b/lualib/rspamadm/fuzzy_ping.lua @@ -0,0 +1,259 @@ +--[[ +Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" +local ansicolors = require "ansicolors" +local rspamd_logger = require "rspamd_logger" +local lua_util = require "lua_util" + +local E = {} + +local parser = argparse() + :name 'rspamadm fuzzy_ping' + :description 'Pings fuzzy storage' + :help_description_margin(30) +parser:option "-c --config" + :description "Path to config file" + :argname("<cfg>") + :default(rspamd_paths["CONFDIR"] .. "/" .. 
"rspamd.conf") +parser:option "-r --rule" + :description "Storage to ping (must be configured in Rspamd configuration)" + :argname("<name>") + :default("rspamd.com") +parser:flag "-f --flood" + :description "Flood mode (no waiting for replies)" +parser:flag "-S --silent" + :description "Silent mode (statistics only)" +parser:option "-t --timeout" + :description "Timeout for requests" + :argname("<timeout>") + :convert(tonumber) + :default(5) +parser:option "-s --server" + :description "Override server to ping" + :argname("<name>") +parser:option "-n --number" + :description "Timeout for requests" + :argname("<number>") + :convert(tonumber) + :default(5) +parser:flag "-l --list" + :description "List configured storages" + +local function load_config(opts) + local _r, err = rspamd_config:load_ucl(opts['config']) + + if not _r then + rspamd_logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + -- Init the real structure excluding logging and workers + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + rspamd_logger.errx('cannot process %s: %s', opts['config'], err) + os.exit(1) + end + + _r, err = rspamd_config:init_modules() + if not _r then + rspamd_logger.errx('cannot init modules from %s: %s', opts['config'], err) + os.exit(1) + end +end + +local function highlight(fmt, ...) + return ansicolors.white .. string.format(fmt, ...) .. ansicolors.reset +end + +local function highlight_err(fmt, ...) + return ansicolors.red .. string.format(fmt, ...) .. ansicolors.reset +end + +local function print_storages(rules) + for n, rule in pairs(rules) do + print(highlight('Rule: %s', n)) + print(string.format("\tRead only: %s", rule.read_only)) + print(string.format("\tServers: %s", table.concat(lua_util.values(rule.servers), ','))) + print("\tFlags:") + + for fl, id in pairs(rule.flags or E) do + print(string.format("\t\t%s: %s", fl, id)) + end + end +end + +local function std_mean(tbl) + local function mean() + local sum = 0 + local count = 0 + + for _, v in ipairs(tbl) do + sum = sum + v + count = count + 1 + end + + return (sum / count) + end + + local m + local vm + local sum = 0 + local count = 0 + local result + + m = mean(tbl) + + for _, v in ipairs(tbl) do + vm = v - m + sum = sum + (vm * vm) + count = count + 1 + end + + result = math.sqrt(sum / (count - 1)) + + return result, m +end + +local function maxmin(tbl) + local max = -math.huge + local min = math.huge + + for _, v in ipairs(tbl) do + max = math.max(max, v) + min = math.min(min, v) + end + + return max, min +end + +local function print_results(results) + local servers = {} + local err_servers = {} + for _, res in ipairs(results) do + if res.success then + if servers[res.server] then + table.insert(servers[res.server], res.latency) + else + servers[res.server] = { res.latency } + end + else + if err_servers[res.server] then + err_servers[res.server] = err_servers[res.server] + 1 + else + err_servers[res.server] = 1 + end + -- For the case if no successful replies are detected + if not servers[res.server] then + servers[res.server] = {} + end + end + end + + for s, l in pairs(servers) do + local total = #l + (err_servers[s] or 0) + print(highlight('Summary for %s: %d packets transmitted, %d packets received, %.1f%% packet loss', + s, total, #l, (total - #l) * 100.0 / total)) + local mean, std = std_mean(l) + local max, min = maxmin(l) + print(string.format('round-trip min/avg/max/std-dev = %.2f/%.2f/%.2f/%.2f ms', + min, mean, + max, std)) + end +end + +local function handler(args) + 
local opts = parser:parse(args) + + load_config(opts) + + if opts.list then + print_storages(rspamd_plugins.fuzzy_check.list_storages(rspamd_config)) + os.exit(0) + end + + -- Perform ping using a fake task from async stuff provided by rspamadm + local rspamd_task = require "rspamd_task" + + -- TODO: this task is not cleared at the end, do something about it some day + local task = rspamd_task.create(rspamd_config, rspamadm_ev_base) + task:set_session(rspamadm_session) + task:set_resolver(rspamadm_dns_resolver) + + local replied = 0 + local results = {} + local ping_fuzzy + + local function gen_ping_fuzzy_cb(num) + return function(success, server, latency_or_err) + if not success then + if not opts.silent then + print(highlight_err('error from %s: %s', server, latency_or_err)) + end + results[num] = { + success = false, + error = latency_or_err, + server = tostring(server), + } + else + if not opts.silent then + local adjusted_latency = math.floor(latency_or_err * 1000) * 1.0 / 1000; + print(highlight('reply from %s: %s ms', server, adjusted_latency)) + + end + results[num] = { + success = true, + latency = latency_or_err, + server = tostring(server), + } + end + + if replied == opts.number - 1 then + print_results(results) + else + replied = replied + 1 + if not opts.flood then + ping_fuzzy(replied + 1) + end + end + end + end + + ping_fuzzy = function(num) + local ret, err = rspamd_plugins.fuzzy_check.ping_storage(task, gen_ping_fuzzy_cb(num), + opts.rule, opts.timeout, opts.server) + + if not ret then + print(highlight_err('error from %s: %s', opts.server, err)) + opts.number = opts.number - 1 -- To avoid issues with waiting for other replies + end + end + + if opts.flood then + for i = 1, opts.number do + ping_fuzzy(i) + end + else + ping_fuzzy(1) + end +end + +return { + name = 'fuzzy_ping', + aliases = { 'fuzzyping' }, + handler = handler, + description = parser._description +}
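-- Editor's sketch (not part of the diff): std_mean() in fuzzy_ping.lua above returns the
-- sample standard deviation first and the mean second. A reduced standalone version with
-- hypothetical sample data, destructured in that order:
local function std_mean_demo(samples)
  local sum = 0
  for _, v in ipairs(samples) do
    sum = sum + v
  end
  local mean = sum / #samples
  local acc = 0
  for _, v in ipairs(samples) do
    acc = acc + (v - mean) ^ 2
  end
  return math.sqrt(acc / (#samples - 1)), mean -- (std-dev, mean), matching std_mean()
end

local std, mean = std_mean_demo({ 1.2, 1.4, 1.1 })
print(string.format('avg=%.2f std-dev=%.2f', mean, std))
-- Note that print_results() above reads the pair as `local mean, std = std_mean(l)`, so the
-- avg and std-dev columns of its summary line come out swapped.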
\ No newline at end of file diff --git a/lualib/rspamadm/fuzzy_stat.lua b/lualib/rspamadm/fuzzy_stat.lua new file mode 100644 index 0000000..ef8a5de --- /dev/null +++ b/lualib/rspamadm/fuzzy_stat.lua @@ -0,0 +1,366 @@ +local rspamd_util = require "rspamd_util" +local lua_util = require "lua_util" +local opts = {} + +local argparse = require "argparse" +local parser = argparse() + :name "rspamadm control fuzzystat" + :description "Shows help for the specified configuration options" + :help_description_margin(32) +parser:flag "--no-ips" + :description "No IPs stats" +parser:flag "--no-keys" + :description "No keys stats" +parser:flag "--short" + :description "Short output mode" +parser:flag "-n --number" + :description "Disable numbers humanization" +parser:option "--sort" + :description "Sort order" + :convert { + checked = "checked", + matched = "matched", + errors = "errors", + name = "name" +} + +local function add_data(target, src) + for k, v in pairs(src) do + if type(v) == 'number' then + if target[k] then + target[k] = target[k] + v + else + target[k] = v + end + elseif k == 'ips' then + if not target['ips'] then + target['ips'] = {} + end + -- Iterate over IPs + for ip, st in pairs(v) do + if not target['ips'][ip] then + target['ips'][ip] = {} + end + add_data(target['ips'][ip], st) + end + elseif k == 'flags' then + if not target['flags'] then + target['flags'] = {} + end + -- Iterate over Flags + for flag, st in pairs(v) do + if not target['flags'][flag] then + target['flags'][flag] = {} + end + add_data(target['flags'][flag], st) + end + elseif k == 'keypair' then + if type(v.extensions) == 'table' then + if type(v.extensions.name) == 'string' then + target.name = v.extensions.name + end + end + end + end +end + +local function print_num(num) + if num then + if opts['n'] or opts['number'] then + return tostring(num) + else + return rspamd_util.humanize_number(num) + end + else + return 'na' + end +end + +local function print_stat(st, tabs) + if st['checked'] then + if st.checked_per_hour then + print(string.format('%sChecked: %s (%s per hour in average)', tabs, + print_num(st['checked']), print_num(st['checked_per_hour']))) + else + print(string.format('%sChecked: %s', tabs, print_num(st['checked']))) + end + end + if st['matched'] then + if st.checked and st.checked > 0 and st.checked <= st.matched then + local percentage = st.matched / st.checked * 100.0 + if st.matched_per_hour then + print(string.format('%sMatched: %s - %s percent (%s per hour in average)', tabs, + print_num(st['matched']), percentage, print_num(st['matched_per_hour']))) + else + print(string.format('%sMatched: %s - %s percent', tabs, print_num(st['matched']), percentage)) + end + else + if st.matched_per_hour then + print(string.format('%sMatched: %s (%s per hour in average)', tabs, + print_num(st['matched']), print_num(st['matched_per_hour']))) + else + print(string.format('%sMatched: %s', tabs, print_num(st['matched']))) + end + end + end + if st['errors'] then + print(string.format('%sErrors: %s', tabs, print_num(st['errors']))) + end + if st['added'] then + print(string.format('%sAdded: %s', tabs, print_num(st['added']))) + end + if st['deleted'] then + print(string.format('%sDeleted: %s', tabs, print_num(st['deleted']))) + end +end + +-- Sort by checked +local function sort_hash_table(tbl, sort_opts, key_key) + local res = {} + for k, v in pairs(tbl) do + table.insert(res, { [key_key] = k, data = v }) + end + + local function sort_order(elt) + local key = 'checked' + local sort_res = 0 + + if 
sort_opts['sort'] then + if sort_opts['sort'] == 'matched' then + key = 'matched' + elseif sort_opts['sort'] == 'errors' then + key = 'errors' + elseif sort_opts['sort'] == 'name' then + return elt[key_key] + end + end + + if elt.data[key] then + sort_res = elt.data[key] + end + + return sort_res + end + + table.sort(res, function(a, b) + return sort_order(a) > sort_order(b) + end) + + return res +end + +local function add_result(dst, src, k) + if type(src) == 'table' then + if type(dst) == 'number' then + -- Convert dst to table + dst = { dst } + elseif type(dst) == 'nil' then + dst = {} + end + + for i, v in ipairs(src) do + if dst[i] and k ~= 'fuzzy_stored' then + dst[i] = dst[i] + v + else + dst[i] = v + end + end + else + if type(dst) == 'table' then + if k ~= 'fuzzy_stored' then + dst[1] = dst[1] + src + else + dst[1] = src + end + else + if dst and k ~= 'fuzzy_stored' then + dst = dst + src + else + dst = src + end + end + end + + return dst +end + +local function print_result(r) + local function num_to_epoch(num) + if num == 1 then + return 'v0.6' + elseif num == 2 then + return 'v0.8' + elseif num == 3 then + return 'v0.9' + elseif num == 4 then + return 'v1.0+' + elseif num == 5 then + return 'v1.7+' + end + return '???' + end + if type(r) == 'table' then + local res = {} + for i, num in ipairs(r) do + res[i] = string.format('(%s: %s)', num_to_epoch(i), print_num(num)) + end + + return table.concat(res, ', ') + end + + return print_num(r) +end + +return function(args, res) + local res_ips = {} + local res_databases = {} + local wrk = res['workers'] + opts = parser:parse(args) + + if wrk then + for _, pr in pairs(wrk) do + -- processes cycle + if pr['data'] then + local id = pr['id'] + + if id then + local res_db = res_databases[id] + if not res_db then + res_db = { + keys = {} + } + res_databases[id] = res_db + end + + -- General stats + for k, v in pairs(pr['data']) do + if k ~= 'keys' and k ~= 'errors_ips' then + res_db[k] = add_result(res_db[k], v, k) + elseif k == 'errors_ips' then + -- Errors ips + if not res_db['errors_ips'] then + res_db['errors_ips'] = {} + end + for ip, nerrors in pairs(v) do + if not res_db['errors_ips'][ip] then + res_db['errors_ips'][ip] = nerrors + else + res_db['errors_ips'][ip] = nerrors + res_db['errors_ips'][ip] + end + end + end + end + + if pr['data']['keys'] then + local res_keys = res_db['keys'] + if not res_keys then + res_keys = {} + res_db['keys'] = res_keys + end + -- Go through keys in input + for k, elts in pairs(pr['data']['keys']) do + -- keys cycle + if not res_keys[k] then + res_keys[k] = {} + end + + add_data(res_keys[k], elts) + + if elts['ips'] then + for ip, v in pairs(elts['ips']) do + if not res_ips[ip] then + res_ips[ip] = {} + end + add_data(res_ips[ip], v) + end + end + end + end + end + end + end + end + + -- General stats + for db, st in pairs(res_databases) do + print(string.format('Statistics for storage %s', db)) + + for k, v in pairs(st) do + if k ~= 'keys' and k ~= 'errors_ips' then + print(string.format('%s: %s', k, print_result(v))) + end + end + print('') + + local res_keys = st['keys'] + if res_keys and not opts['no_keys'] and not opts['short'] then + print('Keys statistics:') + -- Convert into an array to allow sorting + local sorted_keys = sort_hash_table(res_keys, opts, 'key') + + for _, key in ipairs(sorted_keys) do + local key_stat = key.data + if key_stat.name then + print(string.format('Key id: %s, name: %s', key.key, key_stat.name)) + else + print(string.format('Key id: %s', key.key)) + end + + 
print_stat(key_stat, '\t') + + if key_stat['ips'] and not opts['no_ips'] then + print('') + print('\tIPs stat:') + local sorted_ips = sort_hash_table(key_stat['ips'], opts, 'ip') + + for _, v in ipairs(sorted_ips) do + print(string.format('\t%s', v['ip'])) + print_stat(v['data'], '\t\t') + print('') + end + end + + if key_stat.flags then + print('') + print('\tFlags stat:') + for flag, v in pairs(key_stat.flags) do + print(string.format('\t[%s]:', flag)) + -- Remove irrelevant fields + v.checked = nil + print_stat(v, '\t\t') + print('') + end + end + + print('') + end + end + if st['errors_ips'] and not opts['no_ips'] and not opts['short'] then + print('') + print('Errors IPs statistics:') + local ip_stat = st['errors_ips'] + local ips = lua_util.keys(ip_stat) + -- Reverse sort by number of errors + table.sort(ips, function(a, b) + return ip_stat[a] > ip_stat[b] + end) + for _, ip in ipairs(ips) do + print(string.format('%s: %s', ip, print_result(ip_stat[ip]))) + end + print('') + end + end + + if not opts['no_ips'] and not opts['short'] then + print('') + print('IPs statistics:') + + local sorted_ips = sort_hash_table(res_ips, opts, 'ip') + for _, v in ipairs(sorted_ips) do + print(string.format('%s', v['ip'])) + print_stat(v['data'], '\t') + print('') + end + end +end + diff --git a/lualib/rspamadm/grep.lua b/lualib/rspamadm/grep.lua new file mode 100644 index 0000000..6ed0569 --- /dev/null +++ b/lualib/rspamadm/grep.lua @@ -0,0 +1,174 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" + + +-- Define command line options +local parser = argparse() + :name "rspamadm grep" + :description "Search for patterns in rspamd logs" + :help_description_margin(30) +parser:mutex( + parser:option "-s --string" + :description('Plain string to search (case-insensitive)') + :argname "<str>", + parser:option "-p --pattern" + :description('Pattern to search for (regex)') + :argname "<re>" +) +parser:flag "-l --lua" + :description('Use Lua patterns in string search') + +parser:argument "input":args "*" + :description('Process specified inputs') + :default("stdin") +parser:flag "-S --sensitive" + :description('Enable case-sensitivity in string search') +parser:flag "-o --orphans" + :description('Print orphaned logs') +parser:flag "-P --partial" + :description('Print partial logs') + +local function handler(args) + + local rspamd_regexp = require 'rspamd_regexp' + local res = parser:parse(args) + + if not res['string'] and not res['pattern'] then + parser:error('string or pattern options must be specified') + end + + if res['string'] and res['pattern'] then + parser:error('string and pattern options are mutually exclusive') + end + + local buffer = {} + local matches = {} + + local pattern = res['pattern'] + local re + if pattern then + re = rspamd_regexp.create(pattern) + if not re then + io.stderr:write("Couldn't compile regex: " .. pattern .. 
'\n') + os.exit(1) + end + end + + local plainm = true + if res['lua'] then + plainm = false + end + local orphans = res['orphans'] + local search_str = res['string'] + local sensitive = res['sensitive'] + local partial = res['partial'] + if search_str and not sensitive then + search_str = string.lower(search_str) + end + local inputs = res['input'] or { 'stdin' } + + for _, n in ipairs(inputs) do + local h, err + if string.match(n, '%.xz$') then + h, err = io.popen('xzcat ' .. n, 'r') + elseif string.match(n, '%.bz2$') then + h, err = io.popen('bzcat ' .. n, 'r') + elseif string.match(n, '%.gz$') then + h, err = io.popen('zcat ' .. n, 'r') + elseif string.match(n, '%.zst$') then + h, err = io.popen('zstdcat ' .. n, 'r') + elseif n == 'stdin' then + h = io.input() + else + h, err = io.open(n, 'r') + end + if not h then + if err then + io.stderr:write("Couldn't open file (" .. n .. '): ' .. err .. '\n') + else + io.stderr:write("Couldn't open file (" .. n .. '): no error\n') + end + else + for line in h:lines() do + local hash = string.match(line, '<(%x+)>') + local already_matching = false + if hash then + if matches[hash] then + table.insert(matches[hash], line) + already_matching = true + else + if buffer[hash] then + table.insert(buffer[hash], line) + else + buffer[hash] = { line } + end + end + end + local ismatch = false + if re then + ismatch = re:match(line) + elseif sensitive and search_str then + ismatch = string.find(line, search_str, 1, plainm) + elseif search_str then + local lwr = string.lower(line) + ismatch = string.find(lwr, search_str, 1, plainm) + end + if ismatch then + if not hash then + if orphans then + print('*** orphaned ***') + print(line) + print() + end + elseif not already_matching then + matches[hash] = buffer[hash] + end + end + local is_end = string.match(line, '<%x+>; task; rspamd_protocol_http_reply:') + if is_end then + buffer[hash] = nil + if matches[hash] then + for _, v in ipairs(matches[hash]) do + print(v) + end + print() + matches[hash] = nil + end + end + end + if partial then + for k, v in pairs(matches) do + print('*** partial ***') + for _, vv in ipairs(v) do + print(vv) + end + print() + matches[k] = nil + end + else + matches = {} + end + end + end +end + +return { + handler = handler, + description = parser._description, + name = 'grep' +}
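-- Editor's sketch (not part of the diff): the grep handler above groups log lines by the
-- message hash in angle brackets and flushes a group once the terminating
-- 'rspamd_protocol_http_reply' line appears. The buffering, reduced to a standalone form
-- (the real code additionally requires a string or regex match before a group is printed):
local buffer = {}

local function feed(line)
  local hash = line:match('<(%x+)>')
  if not hash then
    return
  end
  buffer[hash] = buffer[hash] or {}
  table.insert(buffer[hash], line)
  if line:match('<%x+>; task; rspamd_protocol_http_reply:') then
    for _, l in ipairs(buffer[hash]) do
      print(l)
    end
    buffer[hash] = nil -- message is complete, release its buffered lines
  end
end

-- Hypothetical log lines, for illustration only:
feed('2024-01-01 10:00:00 <deadbeef>; task; got message')
feed('2024-01-01 10:00:01 <deadbeef>; task; rspamd_protocol_http_reply: writing reply')
-- Typical CLI use: rspamadm grep -s DMARC /var/log/rspamd/rspamd.log.1.gz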
\ No newline at end of file diff --git a/lualib/rspamadm/keypair.lua b/lualib/rspamadm/keypair.lua new file mode 100644 index 0000000..f0716a2 --- /dev/null +++ b/lualib/rspamadm/keypair.lua @@ -0,0 +1,508 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" +local rspamd_keypair = require "rspamd_cryptobox_keypair" +local rspamd_pubkey = require "rspamd_cryptobox_pubkey" +local rspamd_signature = require "rspamd_cryptobox_signature" +local rspamd_crypto = require "rspamd_cryptobox" +local rspamd_util = require "rspamd_util" +local ucl = require "ucl" +local logger = require "rspamd_logger" + +-- Define command line options +local parser = argparse() + :name "rspamadm keypair" + :description "Manages keypairs for Rspamd" + :help_description_margin(30) + :command_target("command") + :require_command(false) + +-- Generate subcommand +local generate = parser:command "generate gen g" + :description "Creates a new keypair" +generate:flag "-s --sign" + :description "Generates a sign keypair instead of the encryption one" +generate:flag "-n --nist" + :description "Uses nist encryption algorithm" +generate:option "-o --output" + :description "Write keypair to file" + :argname "<file>" +generate:mutex( + generate:flag "-j --json" + :description "Output JSON instead of UCL", + generate:flag "-u --ucl" + :description "Output UCL" + :default(true) +) +generate:option "--name" + :description "Adds name extension" + :argname "<name>" + +-- Sign subcommand + +local sign = parser:command "sign sig s" + :description "Signs a file using keypair" +sign:option "-k --keypair" + :description "Keypair to use" + :argname "<file>" +sign:option "-s --suffix" + :description "Suffix for signature" + :argname "<suffix>" + :default("sig") +sign:argument "file" + :description "File to sign" + :argname "<file>" + :args "*" + +-- Verify subcommand + +local verify = parser:command "verify ver v" + :description "Verifies a file using keypair or a public key" +verify:mutex( + verify:option "-p --pubkey" + :description "Load pubkey from the specified file" + :argname "<file>", + verify:option "-P --pubstring" + :description "Load pubkey from the base32 encoded string" + :argname "<base32>", + verify:option "-k --keypair" + :description "Get pubkey from the keypair file" + :argname "<file>" +) +verify:argument "file" + :description "File to verify" + :argname "<file>" + :args "*" +verify:flag "-n --nist" + :description "Uses nistp curves (P256)" +verify:option "-s --suffix" + :description "Suffix for signature" + :argname "<suffix>" + :default("sig") + +-- Encrypt subcommand + +local encrypt = parser:command "encrypt crypt enc e" + :description "Encrypts a file using keypair (or a pubkey)" +encrypt:mutex( + encrypt:option "-p --pubkey" + :description "Load pubkey from the specified file" + :argname "<file>", + encrypt:option "-P --pubstring" + :description "Load pubkey from the base32 encoded string" + :argname "<base32>", + encrypt:option "-k 
--keypair" + :description "Get pubkey from the keypair file" + :argname "<file>" +) +encrypt:option "-s --suffix" + :description "Suffix for encrypted file" + :argname "<suffix>" + :default("enc") +encrypt:argument "file" + :description "File to encrypt" + :argname "<file>" + :args "*" +encrypt:flag "-r --rm" + :description "Remove unencrypted file" +encrypt:flag "-f --force" + :description "Remove destination file if it exists" + +-- Decrypt subcommand + +local decrypt = parser:command "decrypt dec d" + :description "Decrypts a file using keypair" +decrypt:option "-k --keypair" + :description "Get pubkey from the keypair file" + :argname "<file>" +decrypt:flag "-S --keep-suffix" + :description "Preserve suffix for decrypted file (overwrite encrypted)" +decrypt:argument "file" + :description "File to encrypt" + :argname "<file>" + :args "*" +decrypt:flag "-f --force" + :description "Remove destination file if it exists (implied with -S)" +decrypt:flag "-r --rm" + :description "Remove encrypted file" + +-- Default command is generate, so duplicate options to be compatible + +parser:flag "-s --sign" + :description "Generates a sign keypair instead of the encryption one" +parser:flag "-n --nist" + :description "Uses nistp curves (P256)" +parser:mutex( + parser:flag "-j --json" + :description "Output JSON instead of UCL", + parser:flag "-u --ucl" + :description "Output UCL" + :default(true) +) +parser:option "-o --output" + :description "Write keypair to file" + :argname "<file>" + +local function fatal(...) + logger.errx(...) + os.exit(1) +end + +local function ask_yes_no(greet, default) + local def_str + if default then + greet = greet .. "[Y/n]: " + def_str = "yes" + else + greet = greet .. "[y/N]: " + def_str = "no" + end + + local reply = rspamd_util.readline(greet) + + if not reply then + os.exit(0) + end + if #reply == 0 then + reply = def_str + end + reply = reply:lower() + if reply == 'y' or reply == 'yes' then + return true + end + + return false +end + +local function generate_handler(opts) + local mode = 'encryption' + if opts.sign then + mode = 'sign' + end + local alg = 'curve25519' + if opts.nist then + alg = 'nist' + end + -- TODO: probably, do it in a more safe way + local kp = rspamd_keypair.create(mode, alg):totable() + + if opts.name then + kp.keypair.extensions = { + name = opts.name + } + end + + local format = 'ucl' + + if opts.json then + format = 'json' + end + + if opts.output then + local out = io.open(opts.output, 'w') + if not out then + fatal('cannot open output to write: ' .. opts.output) + end + out:write(ucl.to_format(kp, format)) + out:close() + else + io.write(ucl.to_format(kp, format)) + end +end + +local function sign_handler(opts) + if opts.file then + if type(opts.file) == 'string' then + opts.file = { opts.file } + end + else + parser:error('no files to sign') + end + if not opts.keypair then + parser:error("no keypair specified") + end + + local ucl_parser = ucl.parser() + local res, err = ucl_parser:parse_file(opts.keypair) + + if not res then + fatal(string.format('cannot load %s: %s', opts.keypair, err)) + end + + local kp = rspamd_keypair.load(ucl_parser:get_object()) + + if not kp then + fatal("cannot load keypair: " .. 
opts.keypair) + end + + for _, fname in ipairs(opts.file) do + local sig = rspamd_crypto.sign_file(kp, fname) + + if not sig then + fatal(string.format("cannot sign %s\n", fname)) + end + + local out = string.format('%s.%s', fname, opts.suffix or 'sig') + local of = io.open(out, 'w') + if not of then + fatal('cannot open output to write: ' .. out) + end + of:write(sig:bin()) + of:close() + io.write(string.format('signed %s -> %s (%s)\n', fname, out, sig:hex())) + end +end + +local function verify_handler(opts) + if opts.file then + if type(opts.file) == 'string' then + opts.file = { opts.file } + end + else + parser:error('no files to verify') + end + + local pk + local alg = 'curve25519' + + if opts.keypair then + local ucl_parser = ucl.parser() + local res, err = ucl_parser:parse_file(opts.keypair) + + if not res then + fatal(string.format('cannot load %s: %s', opts.keypair, err)) + end + + local kp = rspamd_keypair.load(ucl_parser:get_object()) + + if not kp then + fatal("cannot load keypair: " .. opts.keypair) + end + + pk = kp:pk() + alg = kp:alg() + elseif opts.pubkey then + if opts.nist then + alg = 'nist' + end + pk = rspamd_pubkey.load(opts.pubkey, 'sign', alg) + elseif opts.pubstr then + if opts.nist then + alg = 'nist' + end + pk = rspamd_pubkey.create(opts.pubstr, 'sign', alg) + end + + if not pk then + fatal("cannot create pubkey") + end + + local valid = true + + for _, fname in ipairs(opts.file) do + + local sig_fname = string.format('%s.%s', fname, opts.suffix or 'sig') + local sig = rspamd_signature.load(sig_fname, alg) + + if not sig then + fatal(string.format("cannot load signature for %s -> %s", + fname, sig_fname)) + end + + if rspamd_crypto.verify_file(pk, sig, fname, alg) then + io.write(string.format('verified %s -> %s (%s)\n', fname, sig_fname, sig:hex())) + else + valid = false + io.write(string.format('FAILED to verify %s -> %s (%s)\n', fname, + sig_fname, sig:hex())) + end + end + + if not valid then + os.exit(1) + end +end + +local function encrypt_handler(opts) + if opts.file then + if type(opts.file) == 'string' then + opts.file = { opts.file } + end + else + parser:error('no files to sign') + end + + local pk + local alg = 'curve25519' + + if opts.keypair then + local ucl_parser = ucl.parser() + local res, err = ucl_parser:parse_file(opts.keypair) + + if not res then + fatal(string.format('cannot load %s: %s', opts.keypair, err)) + end + + local kp = rspamd_keypair.load(ucl_parser:get_object()) + + if not kp then + fatal("cannot load keypair: " .. opts.keypair) + end + + pk = kp:pk() + alg = kp:alg() + elseif opts.pubkey then + if opts.nist then + alg = 'nist' + end + pk = rspamd_pubkey.load(opts.pubkey, 'sign', alg) + elseif opts.pubstr then + if opts.nist then + alg = 'nist' + end + pk = rspamd_pubkey.create(opts.pubstr, 'sign', alg) + end + + if not pk then + fatal("cannot load keypair: " .. 
opts.keypair) + end + + for _, fname in ipairs(opts.file) do + local enc = rspamd_crypto.encrypt_file(pk, fname, alg) + + if not enc then + fatal(string.format("cannot encrypt %s\n", fname)) + end + + local out + if opts.suffix and #opts.suffix > 0 then + out = string.format('%s.%s', fname, opts.suffix) + else + out = string.format('%s', fname) + end + + if rspamd_util.file_exists(out) then + if opts.force or ask_yes_no(string.format('File %s already exists, overwrite?', + out), true) then + os.remove(out) + else + os.exit(1) + end + end + + enc:save_in_file(out) + + if opts.rm then + os.remove(fname) + io.write(string.format('encrypted %s (deleted) -> %s\n', fname, out)) + else + io.write(string.format('encrypted %s -> %s\n', fname, out)) + end + end +end + +local function decrypt_handler(opts) + if opts.file then + if type(opts.file) == 'string' then + opts.file = { opts.file } + end + else + parser:error('no files to decrypt') + end + if not opts.keypair then + parser:error("no keypair specified") + end + + local ucl_parser = ucl.parser() + local res, err = ucl_parser:parse_file(opts.keypair) + + if not res then + fatal(string.format('cannot load %s: %s', opts.keypair, err)) + end + + local kp = rspamd_keypair.load(ucl_parser:get_object()) + + if not kp then + fatal("cannot load keypair: " .. opts.keypair) + end + + for _, fname in ipairs(opts.file) do + local decrypted = rspamd_crypto.decrypt_file(kp, fname) + + if not decrypted then + fatal(string.format("cannot decrypt %s\n", fname)) + end + + local out + if not opts['keep-suffix'] then + -- Strip the last suffix + out = fname:match("^(.+)%..+$") + else + out = fname + end + + local removed = false + + if rspamd_util.file_exists(out) then + if (opts.force or opts['keep-suffix']) + or ask_yes_no(string.format('File %s already exists, overwrite?', out), true) then + os.remove(out) + removed = true + else + os.exit(1) + end + end + + if opts.rm then + os.remove(fname) + removed = true + end + + if removed then + io.write(string.format('decrypted %s (removed) -> %s\n', fname, out)) + else + io.write(string.format('decrypted %s -> %s\n', fname, out)) + end + end +end + +local function handler(args) + local opts = parser:parse(args) + + local command = opts.command or "generate" + + if command == 'generate' then + generate_handler(opts) + elseif command == 'sign' then + sign_handler(opts) + elseif command == 'verify' then + verify_handler(opts) + elseif command == 'encrypt' then + encrypt_handler(opts) + elseif command == 'decrypt' then + decrypt_handler(opts) + else + parser:error('command %s is not implemented', command) + end +end + +return { + name = 'keypair', + aliases = { 'kp', 'key' }, + handler = handler, + description = parser._description +}
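The sign and verify handlers above reduce to a small core: load the keypair from its UCL file, sign with the private half, verify with the public half. A minimal standalone sketch of that flow follows; the file names and the ".sig" suffix are assumptions for illustration, not part of this module:

-- Illustrative sketch only: sign a file with a keypair produced by
-- `rspamadm keypair generate -s -o kp.conf`, then verify the signature.
-- "kp.conf", "message.eml" and the ".sig" suffix are hypothetical names.
local ucl = require "ucl"
local rspamd_keypair = require "rspamd_cryptobox_keypair"
local rspamd_signature = require "rspamd_cryptobox_signature"
local rspamd_crypto = require "rspamd_cryptobox"

local p = ucl.parser()
assert(p:parse_file("kp.conf"))
local kp = assert(rspamd_keypair.load(p:get_object()))

-- Sign: the same call sign_handler uses
local sig = assert(rspamd_crypto.sign_file(kp, "message.eml"))
local of = assert(io.open("message.eml.sig", "w"))
of:write(sig:bin())
of:close()

-- Verify: only the public half and the algorithm are needed
local loaded = assert(rspamd_signature.load("message.eml.sig", kp:alg()))
assert(rspamd_crypto.verify_file(kp:pk(), loaded, "message.eml", kp:alg()))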
\ No newline at end of file diff --git a/lualib/rspamadm/mime.lua b/lualib/rspamadm/mime.lua new file mode 100644 index 0000000..6a589d6 --- /dev/null +++ b/lualib/rspamadm/mime.lua @@ -0,0 +1,1012 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" +local ansicolors = require "ansicolors" +local rspamd_util = require "rspamd_util" +local rspamd_task = require "rspamd_task" +local rspamd_text = require "rspamd_text" +local rspamd_logger = require "rspamd_logger" +local lua_meta = require "lua_meta" +local rspamd_url = require "rspamd_url" +local lua_util = require "lua_util" +local lua_mime = require "lua_mime" +local ucl = require "ucl" + +-- Define command line options +local parser = argparse() + :name "rspamadm mime" + :description "Mime manipulations provided by Rspamd" + :help_description_margin(30) + :command_target("command") + :require_command(true) + +parser:option "-c --config" + :description "Path to config file" + :argname("<cfg>") + :default(rspamd_paths["CONFDIR"] .. "/" .. "rspamd.conf") +parser:mutex( + parser:flag "-j --json" + :description "JSON output", + parser:flag "-U --ucl" + :description "UCL output", + parser:flag "-M --messagepack" + :description "MessagePack output" +) +parser:flag "-C --compact" + :description "Use compact format" +parser:flag "--no-file" + :description "Do not print filename" + +-- Extract subcommand +local extract = parser:command "extract ex e" + :description "Extracts data from MIME messages" +extract:argument "file" + :description "File to process" + :argname "<file>" + :args "+" + +extract:flag "-t --text" + :description "Extracts plain text data from a message" +extract:flag "-H --html" + :description "Extracts htm data from a message" +extract:option "-o --output" + :description "Output format ('raw', 'content', 'oneline', 'decoded', 'decoded_utf')" + :argname("<type>") + :convert { + raw = "raw", + content = "content", + oneline = "content_oneline", + decoded = "raw_parsed", + decoded_utf = "raw_utf" +} + :default "content" +extract:flag "-w --words" + :description "Extracts words" +extract:flag "-p --part" + :description "Show part info" +extract:flag "-s --structure" + :description "Show structure info (e.g. 
HTML tags)" +extract:flag "-i --invisible" + :description "Show invisible content for HTML parts" +extract:option "-F --words-format" + :description "Words format ('stem', 'norm', 'raw', 'full')" + :argname("<type>") + :convert { + stem = "stem", + norm = "norm", + raw = "raw", + full = "full", +} + :default "stem" + +local stat = parser:command "stat st s" + :description "Extracts statistical data from MIME messages" +stat:argument "file" + :description "File to process" + :argname "<file>" + :args "+" +stat:mutex( + stat:flag "-m --meta" + :description "Lua metatokens", + stat:flag "-b --bayes" + :description "Bayes tokens", + stat:flag "-F --fuzzy" + :description "Fuzzy hashes" +) +stat:flag "-s --shingles" + :description "Show shingles for fuzzy hashes" + +local urls = parser:command "urls url u" + :description "Extracts URLs from MIME messages" +urls:argument "file" + :description "File to process" + :argname "<file>" + :args "+" +urls:mutex( + urls:flag "-t --tld" + :description "Get TLDs only", + urls:flag "-H --host" + :description "Get hosts only", + urls:flag "-f --full" + :description "Show piecewise urls as processed by Rspamd" +) + +urls:flag "-u --unique" + :description "Print only unique urls" +urls:flag "-s --sort" + :description "Sort output" +urls:flag "--count" + :description "Print count of each printed element" +urls:flag "-r --reverse" + :description "Reverse sort order" + +local modify = parser:command "modify mod m" + :description "Modifies MIME message" +modify:argument "file" + :description "File to process" + :argname "<file>" + :args "+" + +modify:option "-a --add-header" + :description "Adds specific header" + :argname "<header=value>" + :count "*" +modify:option "-r --remove-header" + :description "Removes specific header (all occurrences)" + :argname "<header>" + :count "*" +modify:option "-R --rewrite-header" + :description "Rewrites specific header, uses Lua string.format pattern" + :argname "<header=pattern>" + :count "*" +modify:option "-t --text-footer" + :description "Adds footer to text/plain parts from a specific file" + :argname "<file>" +modify:option "-H --html-footer" + :description "Adds footer to text/html parts from a specific file" + :argname "<file>" + +local sign = parser:command "sign" + :description "Performs DKIM signing" +sign:argument "file" + :description "File to process" + :argname "<file>" + :args "+" + +sign:option "-d --domain" + :description "Use specific domain" + :argname "<domain>" + :count "1" +sign:option "-s --selector" + :description "Use specific selector" + :argname "<selector>" + :count "1" +sign:option "-k --key" + :description "Use specific key of file" + :argname "<key>" + :count "1" +sign:option "-t --type" + :description "ARC or DKIM signing" + :argname("<arc|dkim>") + :convert { + ['arc'] = 'arc', + ['dkim'] = 'dkim', +} + :default 'dkim' +sign:option "-o --output" + :description "Output format" + :argname("<message|signature>") + :convert { + ['message'] = 'message', + ['signature'] = 'signature', +} + :default 'message' + +local dump = parser:command "dump" + :description "Dumps a raw message in different formats" +dump:argument "file" + :description "File to process" + :argname "<file>" + :args "+" +-- Duplicate format for convenience +dump:mutex( + parser:flag "-j --json" + :description "JSON output", + parser:flag "-U --ucl" + :description "UCL output", + parser:flag "-M --messagepack" + :description "MessagePack output" +) +dump:flag "-s --split" + :description "Split the output file contents such that no 
content is embedded" + +dump:option "-o --outdir" + :description "Output directory" + :argname("<directory>") + +local function load_config(opts) + local _r, err = rspamd_config:load_ucl(opts['config']) + + if not _r then + rspamd_logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + rspamd_logger.errx('cannot process %s: %s', opts['config'], err) + os.exit(1) + end +end + +local function load_task(opts, fname) + if not fname then + fname = '-' + end + + local res, task = rspamd_task.load_from_file(fname, rspamd_config) + + if not res then + parser:error(string.format('cannot read message from %s: %s', fname, + task)) + end + + if not task:process_message() then + parser:error(string.format('cannot read message from %s: %s', fname, + 'failed to parse')) + end + + return task +end + +local function highlight(fmt, ...) + return ansicolors.white .. string.format(fmt, ...) .. ansicolors.reset +end + +local function maybe_print_fname(opts, fname) + if not opts.json and not opts['no-file'] then + rspamd_logger.messagex(highlight('File: %s', fname)) + end +end + +local function output_fmt(opts) + local fmt = 'json' + if opts.compact then + fmt = 'json-compact' + end + if opts.ucl then + fmt = 'ucl' + end + if opts.messagepack then + fmt = 'msgpack' + end + + return fmt +end + +-- Print elements in form +-- filename -> table of elements +local function print_elts(elts, opts, func) + local fun = require "fun" + + if opts.json or opts.ucl then + io.write(ucl.to_format(elts, output_fmt(opts))) + else + fun.each(function(fname, elt) + + if not opts.json and not opts.ucl then + if func then + elt = fun.map(func, elt) + end + maybe_print_fname(opts, fname) + fun.each(function(e) + io.write(e) + io.write("\n") + end, elt) + end + end, elts) + end +end + +local function extract_handler(opts) + local out_elts = {} + local tasks = {} + local process_func + + if opts.words then + -- Enable stemming and urls detection + load_config(opts) + rspamd_url.init(rspamd_config:get_tld_path()) + rspamd_config:init_subsystem('langdet') + end + + local function maybe_print_text_part_info(part, out) + local fun = require "fun" + if opts.part then + local t = 'plain text' + if part:is_html() then + t = 'html' + end + + if not opts.json and not opts.ucl then + table.insert(out, + rspamd_logger.slog('Part: %s: %s, language: %s, size: %s (%s raw), words: %s', + part:get_mimepart():get_digest():sub(1, 8), + t, + part:get_language(), + part:get_length(), part:get_raw_length(), + part:get_words_count())) + table.insert(out, + rspamd_logger.slog('Stats: %s', + fun.foldl(function(acc, k, v) + if acc ~= '' then + return string.format('%s, %s:%s', acc, k, v) + else + return string.format('%s:%s', k, v) + end + end, '', part:get_stats()))) + end + end + end + + local function maybe_print_mime_part_info(part, out) + if opts.part then + + if not opts.json and not opts.ucl then + local mtype, msubtype = part:get_type() + local det_mtype, det_msubtype = part:get_detected_type() + table.insert(out, + rspamd_logger.slog('Mime Part: %s: %s/%s (%s/%s detected), filename: %s (%s detected ext), size: %s', + part:get_digest():sub(1, 8), + mtype, msubtype, + det_mtype, det_msubtype, + part:get_filename(), + part:get_detected_ext(), + part:get_length())) + end + end + end + + local function print_words(words, full) + local fun = require "fun" + + if not full then + return table.concat(words, ' ') + else + return table.concat( + fun.totable( + 
fun.map(function(w) + -- [1] - stemmed word + -- [2] - normalised word + -- [3] - raw word + -- [4] - flags (table of strings) + return string.format('%s|%s|%s(%s)', + w[3], w[2], w[1], table.concat(w[4], ',')) + end, words) + ), + ' ' + ) + end + end + + for _, fname in ipairs(opts.file) do + local task = load_task(opts, fname) + out_elts[fname] = {} + + if not opts.text and not opts.html then + opts.text = true + opts.html = true + end + + if opts.words then + local how_words = opts['words_format'] or 'stem' + table.insert(out_elts[fname], 'meta_words: ' .. + print_words(task:get_meta_words(how_words), how_words == 'full')) + end + + if opts.text or opts.html then + local mp = task:get_parts() or {} + + for _, mime_part in ipairs(mp) do + local how = opts.output + local part + if mime_part:is_text() then + part = mime_part:get_text() + end + + if part and opts.text and not part:is_html() then + maybe_print_text_part_info(part, out_elts[fname]) + maybe_print_mime_part_info(mime_part, out_elts[fname]) + if not opts.json and not opts.ucl then + table.insert(out_elts[fname], '\n') + end + + if opts.words then + local how_words = opts['words_format'] or 'stem' + table.insert(out_elts[fname], print_words(part:get_words(how_words), + how_words == 'full')) + else + table.insert(out_elts[fname], tostring(part:get_content(how))) + end + elseif part and opts.html and part:is_html() then + maybe_print_text_part_info(part, out_elts[fname]) + maybe_print_mime_part_info(mime_part, out_elts[fname]) + if not opts.json and not opts.ucl then + table.insert(out_elts[fname], '\n') + end + + if opts.words then + local how_words = opts['words_format'] or 'stem' + table.insert(out_elts[fname], print_words(part:get_words(how_words), + how_words == 'full')) + else + if opts.structure then + local hc = part:get_html() + local res = {} + process_func = function(elt) + local fun = require "fun" + if type(elt) == 'table' then + return table.concat(fun.totable( + fun.map( + function(t) + return rspamd_logger.slog("%s", t) + end, + elt)), '\n') + else + return rspamd_logger.slog("%s", elt) + end + end + + hc:foreach_tag('any', function(tag) + local elt = {} + local ex = tag:get_extra() + elt.tag = tag:get_type() + if ex then + elt.extra = ex + end + local content = tag:get_content() + if content then + elt.content = tostring(content) + end + local style = tag:get_style() + if style then + elt.style = style + end + table.insert(res, elt) + end) + table.insert(out_elts[fname], res) + else + -- opts.structure + table.insert(out_elts[fname], tostring(part:get_content(how))) + end + if opts.invisible then + local hc = part:get_html() + table.insert(out_elts[fname], string.format('invisible content: %s', + tostring(hc:get_invisible()))) + end + end + end + + if not part then + maybe_print_mime_part_info(mime_part, out_elts[fname]) + end + end + end + + table.insert(out_elts[fname], "") + table.insert(tasks, task) + end + + print_elts(out_elts, opts, process_func) + -- To avoid use after free we postpone tasks destruction + for _, task in ipairs(tasks) do + task:destroy() + end +end + +local function stat_handler(opts) + local fun = require "fun" + local out_elts = {} + + load_config(opts) + rspamd_url.init(rspamd_config:get_tld_path()) + rspamd_config:init_subsystem('langdet,stat') -- Needed to gen stat tokens + + local process_func + + for _, fname in ipairs(opts.file) do + local task = load_task(opts, fname) + out_elts[fname] = {} + + if opts.meta then + local mt = lua_meta.gen_metatokens_table(task) + out_elts[fname] = mt 
+ process_func = function(k, v) + return string.format("%s = %s", k, v) + end + elseif opts.bayes then + local bt = task:get_stat_tokens() + out_elts[fname] = bt + process_func = function(e) + return string.format('%s (%d): "%s"+"%s", [%s]', e.data, e.win, e.t1 or "", + e.t2 or "", table.concat(fun.totable( + fun.map(function(k) + return k + end, e.flags)), ",")) + end + elseif opts.fuzzy then + local parts = task:get_parts() or {} + out_elts[fname] = {} + process_func = function(e) + local ret = string.format('part: %s(%s): %s', e.type, e.file or "", e.digest) + if opts.shingles and e.shingles then + local sgl = {} + for _, s in ipairs(e.shingles) do + table.insert(sgl, string.format('%s: %s+%s+%s', s[1], s[2], s[3], s[4])) + end + + ret = ret .. '\n' .. table.concat(sgl, '\n') + end + return ret + end + for _, part in ipairs(parts) do + if not part:is_multipart() then + local text = part:get_text() + + if text then + local digest, shingles = text:get_fuzzy_hashes(task:get_mempool()) + table.insert(out_elts[fname], { + digest = digest, + shingles = shingles, + type = string.format('%s/%s', + ({ part:get_type() })[1], + ({ part:get_type() })[2]) + }) + else + table.insert(out_elts[fname], { + digest = part:get_digest(), + file = part:get_filename(), + type = string.format('%s/%s', + ({ part:get_type() })[1], + ({ part:get_type() })[2]) + }) + end + end + end + end + + task:destroy() -- No automatic dtor + end + + print_elts(out_elts, opts, process_func) +end + +local function urls_handler(opts) + load_config(opts) + rspamd_url.init(rspamd_config:get_tld_path()) + local out_elts = {} + + if opts.json then + rspamd_logger.messagex('[') + end + + for _, fname in ipairs(opts.file) do + out_elts[fname] = {} + local task = load_task(opts, fname) + local elts = {} + + local function process_url(u) + local s + if opts.tld then + s = u:get_tld() + elseif opts.host then + s = u:get_host() + elseif opts.full then + s = rspamd_logger.slog('%s: %s', u:get_text(), u:to_table()) + else + s = u:get_text() + end + + if opts.unique then + if elts[s] then + elts[s].count = elts[s].count + 1 + else + elts[s] = { + count = 1, + url = u:to_table() + } + end + else + if opts.json then + table.insert(elts, u) + else + table.insert(elts, s) + end + end + end + + for _, u in ipairs(task:get_urls(true)) do + process_url(u) + end + + local json_elts = {} + + local function process_elt(s, u) + if opts.unique then + -- s is string, u is {url = url, count = count } + if not opts.json then + if opts.count then + table.insert(json_elts, string.format('%s : %s', s, u.count)) + else + table.insert(json_elts, s) + end + else + local tb = u.url + tb.count = u.count + table.insert(json_elts, tb) + end + else + -- s is index, u is url or string + if opts.json then + table.insert(json_elts, u) + else + table.insert(json_elts, u) + end + end + end + + if opts.sort then + local sfunc + if opts.unique then + sfunc = function(t, a, b) + if t[a].count ~= t[b].count then + if opts.reverse then + return t[a].count > t[b].count + else + return t[a].count < t[b].count + end + else + -- Sort lexicography + if opts.reverse then + return a > b + else + return a < b + end + end + end + else + sfunc = function(t, a, b) + local va, vb + if opts.json then + va = t[a]:get_text() + vb = t[b]:get_text() + else + va = t[a] + vb = t[b] + end + if opts.reverse then + return va > vb + else + return va < vb + end + end + end + + for s, u in lua_util.spairs(elts, sfunc) do + process_elt(s, u) + end + else + for s, u in pairs(elts) do + process_elt(s, u) 
+ end + end + + out_elts[fname] = json_elts + + task:destroy() -- No automatic dtor + end + + print_elts(out_elts, opts) +end + +local function newline(task) + local t = task:get_newlines_type() + + if t == 'cr' then + return '\r' + elseif t == 'lf' then + return '\n' + end + + return '\r\n' +end + +local function modify_handler(opts) + load_config(opts) + rspamd_url.init(rspamd_config:get_tld_path()) + + local function read_file(file) + local f = assert(io.open(file, "rb")) + local content = f:read("*all") + f:close() + return content + end + + local text_footer, html_footer + + if opts['text_footer'] then + text_footer = read_file(opts['text_footer']) + end + + if opts['html_footer'] then + html_footer = read_file(opts['html_footer']) + end + + for _, fname in ipairs(opts.file) do + local task = load_task(opts, fname) + local newline_s = newline(task) + local seen_cte + + local rewrite = lua_mime.add_text_footer(task, html_footer, text_footer) or {} + local out = {} -- Start with headers + + local function process_headers_cb(name, hdr) + for _, h in ipairs(opts['remove_header']) do + if name:match(h) then + return + end + end + + for _, h in ipairs(opts['rewrite_header']) do + local hname, hpattern = h:match('^([^=]+)=(.+)$') + if hname == name then + local new_value = string.format(hpattern, hdr.decoded) + new_value = string.format('%s:%s%s', + name, hdr.separator, + rspamd_util.fold_header(name, + rspamd_util.mime_header_encode(new_value), + task:get_newlines_type())) + out[#out + 1] = new_value + return + end + end + + if rewrite.need_rewrite_ct then + if name:lower() == 'content-type' then + local nct = string.format('%s: %s/%s; charset=utf-8', + 'Content-Type', rewrite.new_ct.type, rewrite.new_ct.subtype) + out[#out + 1] = nct + return + elseif name:lower() == 'content-transfer-encoding' then + out[#out + 1] = string.format('%s: %s', + 'Content-Transfer-Encoding', rewrite.new_cte or 'quoted-printable') + seen_cte = true + return + end + end + + out[#out + 1] = hdr.raw:gsub('\r?\n?$', '') + end + + task:headers_foreach(process_headers_cb, { full = true }) + + for _, h in ipairs(opts['add_header']) do + local hname, hvalue = h:match('^([^=]+)=(.+)$') + + if hname and hvalue then + out[#out + 1] = string.format('%s: %s', hname, + rspamd_util.fold_header(hname, hvalue, task:get_newlines_type())) + end + end + + if not seen_cte and rewrite.need_rewrite_ct then + out[#out + 1] = string.format('%s: %s', + 'Content-Transfer-Encoding', rewrite.new_cte or 'quoted-printable') + end + + -- End of headers + out[#out + 1] = '' + + if rewrite.out then + for _, o in ipairs(rewrite.out) do + out[#out + 1] = o + end + else + out[#out + 1] = { task:get_rawbody(), false } + end + + for _, o in ipairs(out) do + if type(o) == 'string' then + io.write(o) + io.write(newline_s) + elseif type(o) == 'table' then + io.flush() + if type(o[1]) == 'string' then + io.write(o[1]) + else + o[1]:save_in_file(1) + end + + if o[2] then + io.write(newline_s) + end + else + o:save_in_file(1) + io.write(newline_s) + end + end + + task:destroy() -- No automatic dtor + end +end + +local function sign_handler(opts) + load_config(opts) + rspamd_url.init(rspamd_config:get_tld_path()) + + local lua_dkim = require("lua_ffi").dkim + + if not lua_dkim then + io.stderr:write('FFI support is required: please use LuaJIT or install lua-ffi') + os.exit(1) + end + + local sign_key + if rspamd_util.file_exists(opts.key) then + sign_key = lua_dkim.load_sign_key(opts.key, 'file') + else + sign_key = lua_dkim.load_sign_key(opts.key, 
'base64') + end + + if not sign_key then + io.stderr:write('Cannot load key: ' .. opts.key .. '\n') + os.exit(1) + end + + for _, fname in ipairs(opts.file) do + local task = load_task(opts, fname) + local ctx = lua_dkim.create_sign_context(task, sign_key, nil, opts.algorithm) + + if not ctx then + io.stderr:write('Cannot init signing\n') + os.exit(1) + end + + local sig = lua_dkim.do_sign(task, ctx, opts.selector, opts.domain) + + if not sig then + io.stderr:write('Cannot create signature\n') + os.exit(1) + end + + if opts.output == 'signature' then + io.write(sig) + io.write('\n') + io.flush() + else + local dkim_hdr = string.format('%s: %s%s', + 'DKIM-Signature', + rspamd_util.fold_header('DKIM-Signature', + rspamd_util.mime_header_encode(sig), + task:get_newlines_type()), + newline(task)) + io.write(dkim_hdr) + io.flush() + task:get_content():save_in_file(1) + end + + task:destroy() -- No automatic dtor + end +end + +-- Strips directories and .extensions (if present) from a filepath +local function filename_only(filepath) + local filename = filepath:match(".*%/([^%.]+)") + if not filename then + filename = filepath:match("([^%.]+)") + end + return filename +end + +assert(filename_only("very_simple") == "very_simple") +assert(filename_only("/home/very_simple.eml") == "very_simple") +assert(filename_only("very_simple.eml") == "very_simple") +assert(filename_only("very_simple.example.eml") == "very_simple") +assert(filename_only("/home/very_simple") == "very_simple") +assert(filename_only("home/very_simple") == "very_simple") +assert(filename_only("./home/very_simple") == "very_simple") +assert(filename_only("../home/very_simple.eml") == "very_simple") +assert(filename_only("/home/dir.with.dots/very_simple.eml") == "very_simple") + +--Write the dump content to file or standard out +local function write_dump_content(dump_content, fname, extension, outdir) + if type(dump_content) == "string" then + dump_content = rspamd_text.fromstring(dump_content) + end + + local wrote_filepath = nil + if outdir then + if outdir:sub(-1) ~= "/" then + outdir = outdir .. "/" + end + + local outpath = string.format("%s%s.%s", outdir, filename_only(fname), extension) + if rspamd_util.file_exists(outpath) then + os.remove(outpath) + end + if dump_content:save_in_file(outpath) then + wrote_filepath = outpath + io.write(wrote_filepath .. 
"\n") + else + io.stderr:write(string.format("Unable to save dump content to file: %s\n", outpath)) + end + else + dump_content:save_in_file(1) + end + return wrote_filepath +end + +-- Get the formatted ucl (split or unsplit) or the raw task content +local function get_dump_content(task, opts, fname) + if opts.ucl or opts.json or opts.messagepack then + local ucl_object = lua_mime.message_to_ucl(task) + -- Split out the content field into separate raws and update the ucl + if opts.split then + for i, part in ipairs(ucl_object.parts) do + if part.content then + local part_filename = string.format("%s-part%d", filename_only(fname), i) + local part_path = write_dump_content(part.content, part_filename, "raw", opts.outdir) + if part_path then + part.content = ucl.null + part.content_path = part_path + end + end + end + end + local extension = output_fmt(opts) + return ucl.to_format(ucl_object, extension), extension + end + return task:get_content(), "mime" +end + +local function dump_handler(opts) + load_config(opts) + rspamd_url.init(rspamd_config:get_tld_path()) + + for _, fname in ipairs(opts.file) do + local task = load_task(opts, fname) + local data, extension = get_dump_content(task, opts, fname) + write_dump_content(data, fname, extension, opts.outdir) + + task:destroy() -- No automatic dtor + end +end + +local function handler(args) + local opts = parser:parse(args) + + local command = opts.command + + if type(opts.file) == 'string' then + opts.file = { opts.file } + elseif type(opts.file) == 'none' then + opts.file = {} + end + + if command == 'extract' then + extract_handler(opts) + elseif command == 'stat' then + stat_handler(opts) + elseif command == 'urls' then + urls_handler(opts) + elseif command == 'modify' then + modify_handler(opts) + elseif command == 'sign' then + sign_handler(opts) + elseif command == 'dump' then + dump_handler(opts) + else + parser:error('command %s is not implemented', command) + end +end + +return { + name = 'mime', + aliases = { 'mime_tool' }, + handler = handler, + description = parser._description +} diff --git a/lualib/rspamadm/neural_test.lua b/lualib/rspamadm/neural_test.lua new file mode 100644 index 0000000..31d21a9 --- /dev/null +++ b/lualib/rspamadm/neural_test.lua @@ -0,0 +1,228 @@ +local rspamd_logger = require "rspamd_logger" +local argparse = require "argparse" +local lua_util = require "lua_util" +local ucl = require "ucl" + +local parser = argparse() + :name "rspamadm neural_test" + :description "Test the neural network with labelled dataset" + :help_description_margin(32) + +parser:option "-c --config" + :description "Path to config file" + :argname("<cfg>") + :default(rspamd_paths["CONFDIR"] .. "/" .. 
"rspamd.conf") +parser:option "-H --hamdir" + :description("Ham directory") + :argname("<dir>") +parser:option "-S --spamdir" + :description("Spam directory") + :argname("<dir>") +parser:option "-t --timeout" + :description("Timeout for client connections") + :argname("<sec>") + :convert(tonumber) + :default(60) +parser:option "-n --conns" + :description("Number of parallel connections") + :argname("<N>") + :convert(tonumber) + :default(10) +parser:option "-c --connect" + :description("Connect to specific host") + :argname("<host>") + :default('localhost:11334') +parser:option "-r --rspamc" + :description("Use specific rspamc path") + :argname("<path>") + :default('rspamc') +parser:option '--rule' + :description 'Rule to test' + :argname('<rule>') + +local HAM = "HAM" +local SPAM = "SPAM" + +local function load_config(opts) + local _r, err = rspamd_config:load_ucl(opts['config']) + + if not _r then + rspamd_logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + rspamd_logger.errx('cannot process %s: %s', opts['config'], err) + os.exit(1) + end +end + +local function scan_email(rspamc_path, host, n_parallel, path, timeout) + + local rspamc_command = string.format("%s --connect %s -j --compact -n %s -t %.3f %s", + rspamc_path, host, n_parallel, timeout, path) + local result = assert(io.popen(rspamc_command)) + result = result:read("*all") + return result +end + +local function encoded_json_to_log(result) + -- Returns table containing score, action, list of symbols + + local filtered_result = {} + local ucl_parser = ucl.parser() + + local is_good, err = ucl_parser:parse_string(result) + + if not is_good then + rspamd_logger.errx("Parser error: %1", err) + return nil + end + + result = ucl_parser:get_object() + + filtered_result.score = result.score + if not result.action then + rspamd_logger.errx("Bad JSON: %1", result) + return nil + end + local action = result.action:gsub("%s+", "_") + filtered_result.action = action + + filtered_result.symbols = {} + + for sym, _ in pairs(result.symbols) do + table.insert(filtered_result.symbols, sym) + end + + filtered_result.filename = result.filename + filtered_result.scan_time = result.scan_time + + return filtered_result +end + +local function filter_scan_results(results, actual_email_type) + + local logs = {} + + results = lua_util.rspamd_str_split(results, "\n") + + if results[#results] == "" then + results[#results] = nil + end + + for _, result in pairs(results) do + result = encoded_json_to_log(result) + if result then + result['type'] = actual_email_type + table.insert(logs, result) + end + end + + return logs +end + +local function get_stats_from_scan_results(results, rules) + + local rule_stats = {} + for rule, _ in pairs(rules) do + rule_stats[rule] = { tp = 0, tn = 0, fp = 0, fn = 0 } + end + + for _, result in ipairs(results) do + for _, symbol in ipairs(result["symbols"]) do + for name, rule in pairs(rules) do + if rule.symbol_spam and rule.symbol_spam == symbol then + if result.type == HAM then + rule_stats[name].fp = rule_stats[name].fp + 1 + elseif result.type == SPAM then + rule_stats[name].tp = rule_stats[name].tp + 1 + end + elseif rule.symbol_ham and rule.symbol_ham == symbol then + if result.type == HAM then + rule_stats[name].tn = rule_stats[name].tn + 1 + elseif result.type == SPAM then + rule_stats[name].fn = rule_stats[name].fn + 1 + end + end + end + end + end + + for rule, _ in pairs(rules) do + rule_stats[rule].fpr = 
rule_stats[rule].fp / (rule_stats[rule].fp + rule_stats[rule].tn) + rule_stats[rule].fnr = rule_stats[rule].fn / (rule_stats[rule].fn + rule_stats[rule].tp) + end + + return rule_stats +end + +local function print_neural_stats(neural_stats) + for rule, stats in pairs(neural_stats) do + rspamd_logger.messagex("\nStats for rule: %s", rule) + rspamd_logger.messagex("False positive rate: %s%%", stats.fpr * 100) + rspamd_logger.messagex("False negative rate: %s%%", stats.fnr * 100) + end +end + +local function handler(args) + local opts = parser:parse(args) + + local ham_directory = opts['hamdir'] + local spam_directory = opts['spamdir'] + local connections = opts["conns"] + + load_config(opts) + + local neural_opts = rspamd_config:get_all_opt('neural') + + if opts["rule"] then + local found = false + for rule_name, _ in pairs(neural_opts.rules) do + if string.lower(rule_name) == string.lower(opts["rule"]) then + found = true + else + neural_opts.rules[rule_name] = nil + end + end + + if not found then + rspamd_logger.errx("Couldn't find the rule %s", opts["rule"]) + return + end + end + + local results = {} + + if ham_directory then + rspamd_logger.messagex("Scanning ham corpus...") + local ham_results = scan_email(opts.rspamc, opts.connect, connections, ham_directory, opts.timeout) + ham_results = filter_scan_results(ham_results, HAM) + + for _, result in pairs(ham_results) do + table.insert(results, result) + end + end + + if spam_directory then + rspamd_logger.messagex("Scanning spam corpus...") + local spam_results = scan_email(opts.rspamc, opts.connect, connections, spam_directory, opts.timeout) + spam_results = filter_scan_results(spam_results, SPAM) + + for _, result in pairs(spam_results) do + table.insert(results, result) + end + end + + local neural_stats = get_stats_from_scan_results(results, neural_opts.rules) + print_neural_stats(neural_stats) + +end + +return { + name = "neuraltest", + aliases = { "neural_test" }, + handler = handler, + description = parser._description +}
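The per-rule statistics computed above are the usual confusion-matrix rates. A self-contained restatement of the same formulas with made-up counts:

-- Hypothetical counts for a single rule; the formulas mirror
-- get_stats_from_scan_results above.
local stats = { tp = 90, tn = 95, fp = 5, fn = 10 }

local fpr = stats.fp / (stats.fp + stats.tn)  -- 5 / 100  = 0.05
local fnr = stats.fn / (stats.fn + stats.tp)  -- 10 / 100 = 0.10

print(string.format("False positive rate: %.1f%%", fpr * 100))  -- 5.0%
print(string.format("False negative rate: %.1f%%", fnr * 100))  -- 10.0%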
\ No newline at end of file diff --git a/lualib/rspamadm/publicsuffix.lua b/lualib/rspamadm/publicsuffix.lua new file mode 100644 index 0000000..96bf069 --- /dev/null +++ b/lualib/rspamadm/publicsuffix.lua @@ -0,0 +1,82 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local argparse = require "argparse" +local rspamd_logger = require "rspamd_logger" + +-- Define command line options +local parser = argparse() + :name 'rspamadm publicsuffix' + :description 'Do manipulations with the publicsuffix list' + :help_description_margin(30) + :command_target('command') + :require_command(true) + +parser:option '-c --config' + :description 'Path to config file' + :argname('config_file') + :default(rspamd_paths['CONFDIR'] .. '/rspamd.conf') + +parser:command 'compile' + :description 'Compile publicsuffix list if needed' + +local function load_config(config_file) + local _r, err = rspamd_config:load_ucl(config_file) + + if not _r then + rspamd_logger.errx('cannot load %s: %s', config_file, err) + os.exit(1) + end + + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + rspamd_logger.errx('cannot process %s: %s', config_file, err) + os.exit(1) + end +end + +local function compile_handler(_) + local rspamd_url = require "rspamd_url" + local tld_file = rspamd_config:get_tld_path() + + if not tld_file then + rspamd_logger.errx('missing `url_tld` option, cannot continue') + os.exit(1) + end + + rspamd_logger.messagex('loading public suffix file from %s', tld_file) + rspamd_url.init(tld_file) + rspamd_logger.messagex('public suffix file has been loaded') +end + +local function handler(args) + local cmd_opts = parser:parse(args) + + load_config(cmd_opts.config_file) + + if cmd_opts.command == 'compile' then + compile_handler(cmd_opts) + else + rspamd_logger.errx('unknown command: %s', cmd_opts.command) + os.exit(1) + end +end + +return { + handler = handler, + description = parser._description, + name = 'publicsuffix' +}
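Functionally, the compile subcommand is a call to rspamd_url.init() with the configured TLD file, which forces the public suffix list to be (re)compiled. A standalone equivalent, with a hypothetical path in place of rspamd_config:get_tld_path():

-- Sketch of what `rspamadm publicsuffix compile` boils down to; the path is
-- a hypothetical example and is normally taken from the `url_tld` option.
local rspamd_url = require "rspamd_url"
rspamd_url.init("/usr/share/rspamd/effective_tld_names.dat")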
\ No newline at end of file diff --git a/lualib/rspamadm/stat_convert.lua b/lualib/rspamadm/stat_convert.lua new file mode 100644 index 0000000..62a19a2 --- /dev/null +++ b/lualib/rspamadm/stat_convert.lua @@ -0,0 +1,38 @@ +local lua_redis = require "lua_redis" +local stat_tools = require "lua_stat" +local ucl = require "ucl" +local logger = require "rspamd_logger" +local lua_util = require "lua_util" + +return function(_, res) + local redis_params = lua_redis.try_load_redis_servers(res.redis, nil) + if res.expire then + res.expire = lua_util.parse_time_interval(res.expire) + end + if not redis_params then + logger.errx('cannot load redis server definition') + + return false + end + + local sqlite_params = stat_tools.load_sqlite_config(res) + + if #sqlite_params == 0 then + logger.errx('cannot load sqlite classifiers') + return false + end + + for _, cls in ipairs(sqlite_params) do + if not stat_tools.convert_sqlite_to_redis(redis_params, cls.db_spam, + cls.db_ham, cls.symbol_spam, cls.symbol_ham, cls.learn_cache, res.expire, + res.reset_previous) then + logger.errx('conversion failed') + + return false + end + logger.messagex('Converted classifier to the from sqlite to redis') + logger.messagex('Suggested configuration:') + logger.messagex(ucl.to_format(stat_tools.redis_classifier_from_sqlite(cls, res.expire), + 'config')) + end +end diff --git a/lualib/rspamadm/statistics_dump.lua b/lualib/rspamadm/statistics_dump.lua new file mode 100644 index 0000000..6bc0458 --- /dev/null +++ b/lualib/rspamadm/statistics_dump.lua @@ -0,0 +1,544 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local lua_redis = require "lua_redis" +local rspamd_logger = require "rspamd_logger" +local argparse = require "argparse" +local rspamd_zstd = require "rspamd_zstd" +local rspamd_text = require "rspamd_text" +local rspamd_util = require "rspamd_util" +local rspamd_cdb = require "rspamd_cdb" +local lua_util = require "lua_util" +local rspamd_i64 = require "rspamd_int64" +local ucl = require "ucl" + +local N = "statistics_dump" +local E = {} +local classifiers = {} + +-- Define command line options +local parser = argparse() + :name "rspamadm statistics_dump" + :description "Dump/restore Rspamd statistics" + :help_description_margin(30) + :command_target("command") + :require_command(false) + +parser:option "-c --config" + :description "Path to config file" + :argname("<cfg>") + :default(rspamd_paths["CONFDIR"] .. "/" .. 
"rspamd.conf") + +-- Extract subcommand +local dump = parser:command "dump d" + :description "Dump bayes statistics" +dump:mutex( + dump:flag "-j --json" + :description "Json output", + dump:flag "-C --cdb" + :description "CDB output" +) +dump:flag "-c --compress" + :description "Compress output" +dump:option "-b --batch-size" + :description "Number of entires to process at once" + :argname("<elts>") + :convert(tonumber) + :default(1000) + + +-- Restore +local restore = parser:command "restore r" + :description "Restore bayes statistics" +restore:argument "file" + :description "Input file to process" + :argname "<file>" + :args "*" +restore:option "-b --batch-size" + :description "Number of entires to process at once" + :argname("<elts>") + :convert(tonumber) + :default(1000) +restore:option "-m --mode" + :description "Number of entires to process at once" + :argname("<append|subtract|replace>") + :convert { + ['append'] = 'append', + ['subtract'] = 'subtract', + ['replace'] = 'replace', +} + :default 'append' +restore:flag "-n --no-operation" + :description "Only show redis commands to be issued" + +local function load_config(opts) + local _r, err = rspamd_config:load_ucl(opts['config']) + + if not _r then + rspamd_logger.errx('cannot parse %s: %s', opts['config'], err) + os.exit(1) + end + + _r, err = rspamd_config:parse_rcl({ 'logging', 'worker' }) + if not _r then + rspamd_logger.errx('cannot process %s: %s', opts['config'], err) + os.exit(1) + end +end + +local function check_redis_classifier(cls, cfg) + -- Skip old classifiers + if cls.new_schema then + local symbol_spam, symbol_ham + -- Load symbols from statfiles + + local function check_statfile_table(tbl, def_sym) + local symbol = tbl.symbol or def_sym + + local spam + if tbl.spam then + spam = tbl.spam + else + if string.match(symbol:upper(), 'SPAM') then + spam = true + else + spam = false + end + end + + if spam then + symbol_spam = symbol + else + symbol_ham = symbol + end + end + + local statfiles = cls.statfile + if statfiles[1] then + for _, stf in ipairs(statfiles) do + if not stf.symbol then + for k, v in pairs(stf) do + check_statfile_table(v, k) + end + else + check_statfile_table(stf, 'undefined') + end + end + else + for stn, stf in pairs(statfiles) do + check_statfile_table(stf, stn) + end + end + + local redis_params + redis_params = lua_redis.try_load_redis_servers(cls, + rspamd_config, false, 'bayes') + if not redis_params then + redis_params = lua_redis.try_load_redis_servers(cfg[N] or E, + rspamd_config, false, 'bayes') + if not redis_params then + redis_params = lua_redis.try_load_redis_servers(cfg[N] or E, + rspamd_config, true) + if not redis_params then + return false + end + end + end + + table.insert(classifiers, { + symbol_spam = symbol_spam, + symbol_ham = symbol_ham, + redis_params = redis_params, + }) + end +end + +local function redis_map_zip(ar) + local data = {} + for j = 1, #ar, 2 do + data[ar[j]] = ar[j + 1] + end + + return data +end + +-- Used to clear tables +local clear_fcn = table.clear or function(tbl) + local keys = lua_util.keys(tbl) + for _, k in ipairs(keys) do + tbl[k] = nil + end +end + +local compress_ctx + +local function dump_out(out, opts, last) + if opts.compress and not compress_ctx then + compress_ctx = rspamd_zstd.compress_ctx() + end + + if compress_ctx then + if last then + compress_ctx:stream(rspamd_text.fromtable(out), 'end'):write() + else + compress_ctx:stream(rspamd_text.fromtable(out), 'flush'):write() + end + else + for _, o in ipairs(out) do + io.write(o) + end + end 
+end + +local function dump_cdb(out, opts, last, pattern) + local results = out[pattern] + + if not out.cdb_builder then + -- First invocation + out.cdb_builder = rspamd_cdb.build(string.format('%s.cdb', pattern)) + out.cdb_builder:add('_lrnspam', rspamd_i64.fromstring(results.learns_spam or '0')) + out.cdb_builder:add('_lrnham_', rspamd_i64.fromstring(results.learns_ham or '0')) + end + + for _, o in ipairs(results.elts) do + out.cdb_builder:add(o.key, o.value) + end + + if last then + out.cdb_builder:finalize() + out.cdb_builder = nil + end +end + +local function dump_pattern(conn, pattern, opts, out, key) + local cursor = 0 + + repeat + conn:add_cmd('SCAN', { tostring(cursor), + 'MATCH', pattern, + 'COUNT', tostring(opts.batch_size) }) + local ret, results = conn:exec() + + if not ret then + rspamd_logger.errx("cannot connect execute scan command: %s", results) + os.exit(1) + end + + cursor = tonumber(results[1]) + + local elts = results[2] + local tokens = {} + + for _, e in ipairs(elts) do + conn:add_cmd('HGETALL', { e }) + end + -- This function returns many results, each for each command + -- So if we have batch 1000, then we would have 1000 tables in form + -- [result, {hash_content}] + local all_results = { conn:exec() } + + for i = 1, #all_results, 2 do + local r, hash_content = all_results[i], all_results[i + 1] + + if r then + -- List to a hash map + local data = redis_map_zip(hash_content) + tokens[#tokens + 1] = { key = elts[(i + 1) / 2], data = data } + end + end + + -- Output keeping track of the commas + for i, d in ipairs(tokens) do + if cursor == 0 and i == #tokens or not opts.json then + if opts.cdb then + table.insert(out[key].elts, { + key = rspamd_i64.fromstring(string.match(d.key, '%d+')), + value = rspamd_util.pack('ff', tonumber(d.data["S"] or '0') or 0, + tonumber(d.data["H"] or '0')) + }) + else + out[#out + 1] = rspamd_logger.slog('"%s": %s\n', d.key, + ucl.to_format(d.data, "json-compact")) + end + else + out[#out + 1] = rspamd_logger.slog('"%s": %s,\n', d.key, + ucl.to_format(d.data, "json-compact")) + end + + end + + if opts.json and cursor == 0 then + out[#out + 1] = '}}\n' + end + + -- Do not write the last chunk of out as it will be processed afterwards + if cursor ~= 0 then + if opts.cdb then + dump_out(out, opts, false) + clear_fcn(out) + else + dump_cdb(out, opts, false, key) + out[key].elts = {} + end + elseif opts.cdb then + dump_cdb(out, opts, true, key) + end + + until cursor == 0 +end + +local function dump_handler(opts) + local patterns_seen = {} + for _, cls in ipairs(classifiers) do + local res, conn = lua_redis.redis_connect_sync(cls.redis_params, false) + + if not res then + rspamd_logger.errx("cannot connect to redis server: %s", cls.redis_params) + os.exit(1) + end + + local out = {} + local function check_keys(sym) + local sym_keys_pattern = string.format("%s_keys", sym) + conn:add_cmd('SMEMBERS', { sym_keys_pattern }) + local ret, keys = conn:exec() + + if not ret then + rspamd_logger.errx("cannot execute command to get keys: %s", keys) + os.exit(1) + end + + if not opts.json then + out[#out + 1] = string.format('"%s": %s\n', sym_keys_pattern, + ucl.to_format(keys, 'json-compact')) + end + for _, k in ipairs(keys) do + local pat = string.format('%s*_*', k) + if not patterns_seen[pat] then + conn:add_cmd('HGETALL', { k }) + local _ret, additional_keys = conn:exec() + + if _ret then + if opts.json then + out[#out + 1] = string.format('{"pattern": "%s", "meta": %s, "elts": {\n', + k, ucl.to_format(redis_map_zip(additional_keys), 
'json-compact')) + elseif opts.cdb then + out[k] = redis_map_zip(additional_keys) + out[k].elts = {} + else + out[#out + 1] = string.format('"%s": %s\n', k, + ucl.to_format(redis_map_zip(additional_keys), 'json-compact')) + end + dump_pattern(conn, pat, opts, out, k) + patterns_seen[pat] = true + end + end + end + end + + check_keys(cls.symbol_spam) + check_keys(cls.symbol_ham) + + if #out > 0 then + dump_out(out, opts, true) + end + end +end + +local function obj_to_redis_arguments(obj, opts, cmd_pipe) + local key, value = next(obj) + + if type(key) == 'string' then + if type(value) == 'table' then + if not value[1] then + if opts.mode == 'replace' then + local cmd = 'HMSET' + local params = { key } + for k, v in pairs(value) do + table.insert(params, k) + table.insert(params, v) + end + table.insert(cmd_pipe, { cmd, params }) + else + local cmd = 'HINCRBYFLOAT' + local mult = 1.0 + if opts.mode == 'subtract' then + mult = (-mult) + end + + for k, v in pairs(value) do + if tonumber(v) then + v = tonumber(v) + table.insert(cmd_pipe, { cmd, { key, k, tostring(v * mult) } }) + else + table.insert(cmd_pipe, { 'HSET', { key, k, v } }) + end + end + end + else + -- Numeric table of elements (e.g. _keys) - it is actually a set in Redis + for _, elt in ipairs(value) do + table.insert(cmd_pipe, { 'SADD', { key, elt } }) + end + end + end + end + + return cmd_pipe +end + +local function execute_batch(batch, conns, opts) + local cmd_pipe = {} + + for _, cmd in ipairs(batch) do + obj_to_redis_arguments(cmd, opts, cmd_pipe) + end + + if opts.no_operation then + for _, cmd in ipairs(cmd_pipe) do + rspamd_logger.messagex('%s %s', cmd[1], table.concat(cmd[2], ' ')) + end + else + for _, conn in ipairs(conns) do + for _, cmd in ipairs(cmd_pipe) do + local is_ok, err = conn:add_cmd(cmd[1], cmd[2]) + + if not is_ok then + rspamd_logger.errx("cannot add command: %s with args: %s: %s", cmd[1], cmd[2], err) + end + end + + conn:exec() + end + end +end + +local function restore_handler(opts) + local files = opts.file or { '-' } + local conns = {} + + for _, cls in ipairs(classifiers) do + local res, conn = lua_redis.redis_connect_sync(cls.redis_params, true) + + if not res then + rspamd_logger.errx("cannot connect to redis server: %s", cls.redis_params) + os.exit(1) + end + + table.insert(conns, conn) + end + + local batch = {} + + for _, f in ipairs(files) do + local fd + if f ~= '-' then + fd = io.open(f, 'r') + io.input(fd) + end + + local cur_line = 1 + for line in io.lines() do + local ucl_parser = ucl.parser() + local res, err + res, err = ucl_parser:parse_string(line) + + if not res then + rspamd_logger.errx("%s: cannot read line %s: %s", f, cur_line, err) + os.exit(1) + end + + table.insert(batch, ucl_parser:get_object()) + cur_line = cur_line + 1 + + if cur_line % opts.batch_size == 0 then + execute_batch(batch, conns, opts) + batch = {} + end + end + + if fd then + fd:close() + end + end + + if #batch > 0 then + execute_batch(batch, conns, opts) + end +end + +local function handler(args) + local opts = parser:parse(args) + + local command = opts.command or 'dump' + + load_config(opts) + rspamd_config:init_subsystem('stat') + + local obj = rspamd_config:get_ucl() + + local classifier = obj.classifier + + if classifier then + if classifier[1] then + for _, cls in ipairs(classifier) do + if cls.bayes then + cls = cls.bayes + end + if cls.backend and cls.backend == 'redis' then + check_redis_classifier(cls, obj) + end + end + else + if classifier.bayes then + + classifier = classifier.bayes + if 
classifier[1] then + for _, cls in ipairs(classifier) do + if cls.backend and cls.backend == 'redis' then + check_redis_classifier(cls, obj) + end + end + else + if classifier.backend and classifier.backend == 'redis' then + check_redis_classifier(classifier, obj) + end + end + end + end + end + + if type(opts.file) == 'string' then + opts.file = { opts.file } + elseif type(opts.file) == 'none' then + opts.file = {} + end + + if command == 'dump' then + dump_handler(opts) + elseif command == 'restore' then + restore_handler(opts) + else + parser:error('command %s is not implemented', command) + end +end + +return { + name = 'statistics_dump', + aliases = { 'stat_dump', 'bayes_dump' }, + handler = handler, + description = parser._description +}
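To make the restore path concrete, the sketch below replays in plain Lua how a single dumped hash line becomes Redis commands in the default append mode; the token key and counters are made up. Replace mode would issue one HMSET instead, and subtract mode negates the numeric values first.

-- Hypothetical dumped line: "RS123456_bayes": {"S": "10", "H": "2"}
local line = { RS123456_bayes = { S = "10", H = "2" } }

local cmd_pipe = {}
for key, hash in pairs(line) do
  for field, value in pairs(hash) do
    -- Numeric fields are applied incrementally, mirroring obj_to_redis_arguments
    table.insert(cmd_pipe, { 'HINCRBYFLOAT', { key, field, tostring(tonumber(value)) } })
  end
end

-- Roughly what the --no-operation flag would print for this input
for _, cmd in ipairs(cmd_pipe) do
  print(cmd[1], table.concat(cmd[2], ' '))
end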
diff --git a/lualib/rspamadm/template.lua b/lualib/rspamadm/template.lua
new file mode 100644
index 0000000..ca1779a
--- /dev/null
+++ b/lualib/rspamadm/template.lua
@@ -0,0 +1,131 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local argparse = require "argparse"
+
+
+-- Define command line options
+local parser = argparse()
+    :name "rspamadm template"
+    :description "Apply jinja templates for strings/files"
+    :help_description_margin(30)
+parser:argument "file"
+      :description "File to process"
+      :argname "<file>"
+      :args "*"
+
+parser:flag "-n --no-vars"
+      :description "Don't add Rspamd internal variables"
+parser:option "-e --env"
+      :description "Load additional environment vars from specific file (name=value)"
+      :argname "<filename>"
+      :count "*"
+parser:option "-l --lua-env"
+      :description "Load additional environment vars from specific file (lua source)"
+      :argname "<filename>"
+      :count "*"
+parser:mutex(
+    parser:option "-s --suffix"
+          :description "Store files with the new suffix"
+          :argname "<suffix>",
+    parser:flag "-i --inplace"
+          :description "Replace input file(s)"
+)
+
+local lua_util = require "lua_util"
+
+local function set_env(opts, env)
+  if opts.env then
+    for _, fname in ipairs(opts.env) do
+      for kv in assert(io.open(fname)):lines() do
+        if not kv:match('%s*#.*') then
+          local k, v = kv:match('([^=%s]+)%s*=%s*(.+)')
+
+          if k and v then
+            env[k] = v
+          else
+            io.write(string.format('invalid env line in %s: %s\n', fname, kv))
+          end
+        end
+      end
+    end
+  end
+
+  if opts.lua_env then
+    for _, fname in ipairs(opts.lua_env) do
+      local ret, res_or_err = pcall(loadfile(fname))
+
+      if not ret then
+        io.write(string.format('cannot load %s: %s\n', fname, res_or_err))
+      else
+        if type(res_or_err) == 'table' then
+          for k, v in pairs(res_or_err) do
+            env[k] = lua_util.deepcopy(v)
+          end
+        else
+          io.write(string.format('cannot load %s: not a table\n', fname))
+        end
+      end
+    end
+  end
+end
+
+local function read_file(file)
+  local content
+  if file == '-' then
+    content = io.read("*all")
+  else
+    local f = assert(io.open(file, "rb"))
+    content = f:read("*all")
+    f:close()
+  end
+  return content
+end
+
+local function handler(args)
+  local opts = parser:parse(args)
+  local env = {}
+  set_env(opts, env)
+
+  if not opts.file or #opts.file == 0 then
+    opts.file = { '-' }
+  end
+  for _, fname in ipairs(opts.file) do
+    local content = read_file(fname)
+    local res = lua_util.jinja_template(content, env, opts.no_vars)
+
+    if opts.inplace then
+      local nfile = string.format('%s.new', fname)
+      local out = assert(io.open(nfile, 'w'))
+      out:write(res)
+      out:close()
+      os.rename(nfile, fname)
+    elseif opts.suffix then
+      local nfile = string.format('%s.%s', fname, opts.suffix)
+      local out = assert(io.open(nfile, 'w'))
+      out:write(res)
+      out:close()
+    else
+      io.write(res)
+    end
+  end
+end
+
+return {
+  handler = handler,
+  description = parser._description,
+  name = 'template'
+}
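The --env files read by set_env() above are plain name=value lines with # comments; the snippet below is a standalone illustration of that parsing and of the variables table that handler() then hands to lua_util.jinja_template(). The variable names and sample lines are invented for the example.

-- Hedged sketch of the name=value parsing used by set_env(); the input is
-- inlined here instead of being read from a file passed via "-e <filename>".
local env = {}
local sample_lines = {
  'CONFDIR = /etc/rspamd',        -- hypothetical variable
  '# comments are skipped',
  'not a key value pair',         -- reported as invalid
}

for _, kv in ipairs(sample_lines) do
  if not kv:match('%s*#.*') then
    local k, v = kv:match('([^=%s]+)%s*=%s*(.+)')
    if k and v then
      env[k] = v
    else
      io.write(string.format('invalid env line: %s\n', kv))
    end
  end
end

-- env is now { CONFDIR = '/etc/rspamd' }; handler() passes such a table as the
-- variables argument of lua_util.jinja_template(content, env, opts.no_vars)
print(env.CONFDIR)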
\ No newline at end of file diff --git a/lualib/rspamadm/vault.lua b/lualib/rspamadm/vault.lua new file mode 100644 index 0000000..840e504 --- /dev/null +++ b/lualib/rspamadm/vault.lua @@ -0,0 +1,579 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + + +local rspamd_logger = require "rspamd_logger" +local ansicolors = require "ansicolors" +local ucl = require "ucl" +local argparse = require "argparse" +local fun = require "fun" +local rspamd_http = require "rspamd_http" +local cr = require "rspamd_cryptobox" + +local parser = argparse() + :name "rspamadm vault" + :description "Perform Hashicorp Vault management" + :help_description_margin(32) + :command_target("command") + :require_command(true) + +parser:flag "-s --silent" + :description "Do not output extra information" +parser:option "-a --addr" + :description "Vault address (if not defined in VAULT_ADDR env)" +parser:option "-t --token" + :description "Vault token (not recommended, better define VAULT_TOKEN env)" +parser:option "-p --path" + :description "Path to work with in the vault" + :default "dkim" +parser:option "-o --output" + :description "Output format ('ucl', 'json', 'json-compact', 'yaml')" + :argname("<type>") + :convert { + ucl = "ucl", + json = "json", + ['json-compact'] = "json-compact", + yaml = "yaml", +} + :default "ucl" + +parser:command "list ls l" + :description "List elements in the vault" + +local show = parser:command "show get" + :description "Extract element from the vault" +show:argument "domain" + :description "Domain to create key for" + :args "+" + +local delete = parser:command "delete del rm remove" + :description "Delete element from the vault" +delete:argument "domain" + :description "Domain to create delete key(s) for" + :args "+" + +local newkey = parser:command "newkey new create" + :description "Add new key to the vault" +newkey:argument "domain" + :description "Domain to create key for" + :args "+" +newkey:option "-s --selector" + :description "Selector to use" + :count "?" +newkey:option "-A --algorithm" + :argname("<type>") + :convert { + rsa = "rsa", + ed25519 = "ed25519", + eddsa = "ed25519", +} + :default "rsa" +newkey:option "-b --bits" + :argname("<nbits>") + :convert(tonumber) + :default "1024" +newkey:option "-x --expire" + :argname("<days>") + :convert(tonumber) +newkey:flag "-r --rewrite" + +local roll = parser:command "roll rollover" + :description "Perform keys rollover" +roll:argument "domain" + :description "Domain to roll key(s) for" + :args "+" +roll:option "-T --ttl" + :description "Validity period for old keys (days)" + :convert(tonumber) + :default "1" +roll:flag "-r --remove-expired" + :description "Remove expired keys" +roll:option "-x --expire" + :argname("<days>") + :convert(tonumber) + +local function printf(fmt, ...) + if fmt then + io.write(rspamd_logger.slog(fmt, ...)) + end + io.write('\n') +end + +local function maybe_printf(opts, fmt, ...) + if not opts.silent then + printf(fmt, ...) 
+ end +end + +local function highlight(str, color) + return ansicolors[color or 'white'] .. str .. ansicolors.reset +end + +local function vault_url(opts, path) + if path then + return string.format('%s/v1/%s/%s', opts.addr, opts.path, path) + end + + return string.format('%s/v1/%s', opts.addr, opts.path) +end + +local function is_http_error(err, data) + return err or (math.floor(data.code / 100) ~= 2) +end + +local function parse_vault_reply(data) + local p = ucl.parser() + local res, parser_err = p:parse_string(data) + + if not res then + return nil, parser_err + else + return p:get_object(), nil + end +end + +local function maybe_print_vault_data(opts, data, func) + if data then + local res, parser_err = parse_vault_reply(data) + + if not res then + printf('vault reply for cannot be parsed: %s', parser_err) + else + if func then + printf(ucl.to_format(func(res), opts.output)) + else + printf(ucl.to_format(res, opts.output)) + end + end + else + printf('no data received') + end +end + +local function print_dkim_txt_record(b64, selector, alg) + local labels = {} + local prefix = string.format("v=DKIM1; k=%s; p=", alg) + b64 = prefix .. b64 + if #b64 < 255 then + labels = { '"' .. b64 .. '"' } + else + for sl = 1, #b64, 256 do + table.insert(labels, '"' .. b64:sub(sl, sl + 255) .. '"') + end + end + + printf("%s._domainkey IN TXT ( %s )", selector, + table.concat(labels, "\n\t")) +end + +local function show_handler(opts, domain) + local uri = vault_url(opts, domain) + local err, data = rspamd_http.request { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + resolver = rspamadm_dns_resolver, + url = uri, + headers = { + ['X-Vault-Token'] = opts.token + } + } + + if is_http_error(err, data) then + printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code) + maybe_print_vault_data(opts, err) + os.exit(1) + else + maybe_print_vault_data(opts, data.content, function(obj) + return obj.data.selectors + end) + end +end + +local function delete_handler(opts, domain) + local uri = vault_url(opts, domain) + local err, data = rspamd_http.request { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + resolver = rspamadm_dns_resolver, + url = uri, + method = 'delete', + headers = { + ['X-Vault-Token'] = opts.token + } + } + + if is_http_error(err, data) then + printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code) + maybe_print_vault_data(opts, err) + os.exit(1) + else + printf('deleted key(s) for %s', domain) + end +end + +local function list_handler(opts) + local uri = vault_url(opts) + local err, data = rspamd_http.request { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + resolver = rspamadm_dns_resolver, + url = uri .. 
'?list=true', + headers = { + ['X-Vault-Token'] = opts.token + } + } + + if is_http_error(err, data) then + printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code) + maybe_print_vault_data(opts, err) + os.exit(1) + else + maybe_print_vault_data(opts, data.content, function(obj) + return obj.data.keys + end) + end +end + +-- Returns pair privkey+pubkey +local function genkey(opts) + return cr.gen_dkim_keypair(opts.algorithm, opts.bits) +end + +local function create_and_push_key(opts, domain, existing) + local uri = vault_url(opts, domain) + local sk, pk = genkey(opts) + + local res = { + selectors = { + [1] = { + selector = opts.selector, + domain = domain, + key = tostring(sk), + pubkey = tostring(pk), + alg = opts.algorithm, + bits = opts.bits or 0, + valid_start = os.time(), + } + } + } + + for _, sel in ipairs(existing) do + res.selectors[#res.selectors + 1] = sel + end + + if opts.expire then + res.selectors[1].valid_end = os.time() + opts.expire * 3600 * 24 + end + + local err, data = rspamd_http.request { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + resolver = rspamadm_dns_resolver, + url = uri, + method = 'put', + headers = { + ['Content-Type'] = 'application/json', + ['X-Vault-Token'] = opts.token + }, + body = { + ucl.to_format(res, 'json-compact') + }, + } + + if is_http_error(err, data) then + printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code) + maybe_print_vault_data(opts, data.content) + os.exit(1) + else + maybe_printf(opts, 'stored key for: %s, selector: %s', domain, opts.selector) + maybe_printf(opts, 'please place the corresponding public key as following:') + + if opts.silent then + printf('%s', pk) + else + print_dkim_txt_record(tostring(pk), opts.selector, opts.algorithm) + end + end +end + +local function newkey_handler(opts, domain) + local uri = vault_url(opts, domain) + + if not opts.selector then + opts.selector = string.format('%s-%s', opts.algorithm, + os.date("!%Y%m%d")) + end + + local err, data = rspamd_http.request { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + resolver = rspamadm_dns_resolver, + url = uri, + method = 'get', + headers = { + ['X-Vault-Token'] = opts.token + } + } + + if is_http_error(err, data) or not data.content then + create_and_push_key(opts, domain, {}) + else + -- Key exists + local rep = parse_vault_reply(data.content) + + if not rep or not rep.data then + printf('cannot parse reply for %s: %s', uri, data.content) + os.exit(1) + end + + local elts = rep.data.selectors + + if not elts then + create_and_push_key(opts, domain, {}) + os.exit(0) + end + + for _, sel in ipairs(elts) do + if sel.alg == opts.algorithm then + printf('key with the specific algorithm %s is already presented at %s selector for %s domain', + opts.algorithm, sel.selector, domain) + os.exit(1) + else + create_and_push_key(opts, domain, elts) + end + end + end +end + +local function roll_handler(opts, domain) + local uri = vault_url(opts, domain) + local res = { + selectors = {} + } + + local err, data = rspamd_http.request { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + resolver = rspamadm_dns_resolver, + url = uri, + method = 'get', + headers = { + ['X-Vault-Token'] = opts.token + } + } + + if is_http_error(err, data) or not data.content then + printf("No keys to roll for domain %s", domain) + os.exit(1) + else + local rep = parse_vault_reply(data.content) + + if not rep or not 
rep.data then + printf('cannot parse reply for %s: %s', uri, data.content) + os.exit(1) + end + + local elts = rep.data.selectors + + if not elts then + printf("No keys to roll for domain %s", domain) + os.exit(1) + end + + local nkeys = {} -- indexed by algorithm + + local function insert_key(sel, add_expire) + if not nkeys[sel.alg] then + nkeys[sel.alg] = {} + end + + if add_expire then + sel.valid_end = os.time() + opts.ttl * 3600 * 24 + end + + table.insert(nkeys[sel.alg], sel) + end + + for _, sel in ipairs(elts) do + if sel.valid_end and sel.valid_end < os.time() then + if not opts.remove_expired then + insert_key(sel, false) + else + maybe_printf(opts, 'removed expired key for %s (selector %s, expire "%s"', + domain, sel.selector, os.date('%c', sel.valid_end)) + end + else + insert_key(sel, true) + end + end + + -- Now we need to ensure that all but one selectors have either expired or just a single key + for alg, keys in pairs(nkeys) do + table.sort(keys, function(k1, k2) + if k1.valid_end and k2.valid_end then + return k1.valid_end > k2.valid_end + elseif k1.valid_end then + return true + elseif k2.valid_end then + return false + end + return false + end) + -- Exclude the key with the highest expiration date and examine the rest + if not (#keys == 1 or fun.all(function(k) + return k.valid_end and k.valid_end < os.time() + end, fun.tail(keys))) then + printf('bad keys list for %s and %s algorithm', domain, alg) + fun.each(function(k) + if not k.valid_end then + printf('selector %s, algorithm %s has a key with no expire', + k.selector, k.alg) + elseif k.valid_end >= os.time() then + printf('selector %s, algorithm %s has a key that not yet expired: %s', + k.selector, k.alg, os.date('%c', k.valid_end)) + end + end, fun.tail(keys)) + os.exit(1) + end + -- Do not create new keys, if we only want to remove expired keys + if not opts.remove_expired then + -- OK to process + -- Insert keys for each algorithm in pairs <old_key(s)>, <new_key> + local sk, pk = genkey({ algorithm = alg, bits = keys[1].bits }) + local selector = string.format('%s-%s', alg, + os.date("!%Y%m%d")) + + if selector == keys[1].selector then + selector = selector .. 
'-1' + end + local nelt = { + selector = selector, + domain = domain, + key = tostring(sk), + pubkey = tostring(pk), + alg = alg, + bits = keys[1].bits, + valid_start = os.time(), + } + + if opts.expire then + nelt.valid_end = os.time() + opts.expire * 3600 * 24 + end + + table.insert(res.selectors, nelt) + end + for _, k in ipairs(keys) do + table.insert(res.selectors, k) + end + end + end + + -- We can now store res in the vault + err, data = rspamd_http.request { + config = rspamd_config, + ev_base = rspamadm_ev_base, + session = rspamadm_session, + resolver = rspamadm_dns_resolver, + url = uri, + method = 'put', + headers = { + ['Content-Type'] = 'application/json', + ['X-Vault-Token'] = opts.token + }, + body = { + ucl.to_format(res, 'json-compact') + }, + } + + if is_http_error(err, data) then + printf('cannot put request to the vault (%s), HTTP error code %s', uri, data.code) + maybe_print_vault_data(opts, data.content) + os.exit(1) + else + for _, key in ipairs(res.selectors) do + if not key.valid_end or key.valid_end > os.time() + opts.ttl * 3600 * 24 then + maybe_printf(opts, 'rolled key for: %s, new selector: %s', domain, key.selector) + maybe_printf(opts, 'please place the corresponding public key as following:') + + if opts.silent then + printf('%s', key.pubkey) + else + print_dkim_txt_record(key.pubkey, key.selector, key.alg) + end + + end + end + + maybe_printf(opts, 'your old keys will be valid until %s', + os.date('%c', os.time() + opts.ttl * 3600 * 24)) + end +end + +local function handler(args) + local opts = parser:parse(args) + + if not opts.addr then + opts.addr = os.getenv('VAULT_ADDR') + end + + if not opts.token then + opts.token = os.getenv('VAULT_TOKEN') + else + maybe_printf(opts, 'defining token via command line is insecure, define it via environment variable %s', + highlight('VAULT_TOKEN', 'red')) + end + + if not opts.token or not opts.addr then + printf('no token or/and vault addr has been specified, exiting') + os.exit(1) + end + + local command = opts.command + + if command == 'list' then + list_handler(opts) + elseif command == 'show' then + fun.each(function(d) + show_handler(opts, d) + end, opts.domain) + elseif command == 'newkey' then + fun.each(function(d) + newkey_handler(opts, d) + end, opts.domain) + elseif command == 'roll' then + fun.each(function(d) + roll_handler(opts, d) + end, opts.domain) + elseif command == 'delete' then + fun.each(function(d) + delete_handler(opts, d) + end, opts.domain) + else + parser:error(string.format('command %s is not implemented', command)) + end +end + +return { + handler = handler, + description = parser._description, + name = 'vault' +} |
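For reference, a single DNS TXT character-string is limited to 255 bytes, which is why print_dkim_txt_record() above emits long p= values as several quoted chunks. The standalone sketch below mirrors that splitting loop; the selector, algorithm and the repeated filler "key" are placeholders, not real key material.

-- Hedged sketch mirroring the quoting/splitting done by print_dkim_txt_record();
-- all values below are placeholders for illustration only.
local selector, alg = 'rsa-20240101', 'rsa'
local b64 = string.rep('A', 400)   -- stand-in for a base64-encoded public key

local record = string.format('v=DKIM1; k=%s; p=', alg) .. b64
local labels = {}

if #record < 255 then
  labels = { '"' .. record .. '"' }
else
  -- split into quoted chunks, using the same stride as the loop above
  for sl = 1, #record, 256 do
    table.insert(labels, '"' .. record:sub(sl, sl + 255) .. '"')
  end
end

print(string.format('%s._domainkey IN TXT ( %s )', selector,
  table.concat(labels, '\n\t')))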