summaryrefslogtreecommitdiffstats
path: root/lualib/lua_fuzzy.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lualib/lua_fuzzy.lua')
-rw-r--r--lualib/lua_fuzzy.lua355
1 files changed, 355 insertions, 0 deletions
diff --git a/lualib/lua_fuzzy.lua b/lualib/lua_fuzzy.lua
new file mode 100644
index 0000000..986d1a0
--- /dev/null
+++ b/lualib/lua_fuzzy.lua
@@ -0,0 +1,355 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_fuzzy
+-- This module contains helper functions for supporting fuzzy check module
+--]]
+
+
+local N = "lua_fuzzy"
+local lua_util = require "lua_util"
+local rspamd_regexp = require "rspamd_regexp"
+local fun = require "fun"
+local rspamd_logger = require "rspamd_logger"
+local ts = require("tableshape").types
+
+-- Filled by C code, indexed by number in this table
+local rules = {}
+
+-- Pre-defined rules options
+local policies = {
+ recommended = {
+ min_bytes = 1024,
+ min_height = 500,
+ min_width = 500,
+ min_length = 64,
+ text_multiplier = 4.0, -- divide min_bytes by 4 for texts
+ mime_types = { "application/*" },
+ scan_archives = true,
+ short_text_direct_hash = true,
+ text_shingles = true,
+ skip_images = false,
+ }
+}
+
+local default_policy = policies.recommended
+
+local schema_fields = {
+ min_bytes = ts.number + ts.string / tonumber,
+ min_height = ts.number + ts.string / tonumber,
+ min_width = ts.number + ts.string / tonumber,
+ min_length = ts.number + ts.string / tonumber,
+ text_multiplier = ts.number,
+ mime_types = ts.array_of(ts.string),
+ scan_archives = ts.boolean,
+ short_text_direct_hash = ts.boolean,
+ text_shingles = ts.boolean,
+ skip_images = ts.boolean,
+}
+local policy_schema = ts.shape(schema_fields)
+
+local policy_schema_open = ts.shape(schema_fields, {
+ open = true,
+})
+
+local exports = {}
+
+
+--[[[
+-- @function lua_fuzzy.register_policy(name, policy)
+-- Adds a new policy with name `name`. Must be valid, checked using policy_schema
+--]]
+exports.register_policy = function(name, policy)
+ if policies[name] then
+ rspamd_logger.warnx(rspamd_config, "overriding policy %s", name)
+ end
+
+ local parsed_policy, err = policy_schema:transform(policy)
+
+ if not parsed_policy then
+ rspamd_logger.errx(rspamd_config, 'invalid fuzzy rule policy %s: %s',
+ name, err)
+
+ return
+ else
+ policies.name = parsed_policy
+ end
+end
+
+--[[[
+-- @function lua_fuzzy.process_rule(rule)
+-- Processes fuzzy rule (applying policies or defaults if needed). Returns policy id
+--]]
+exports.process_rule = function(rule)
+ local processed_rule = lua_util.shallowcopy(rule)
+ local policy = default_policy
+
+ if processed_rule.policy then
+ policy = policies[processed_rule.policy]
+ end
+
+ if policy then
+ processed_rule = lua_util.override_defaults(policy, processed_rule)
+
+ local parsed_policy, err = policy_schema_open:transform(processed_rule)
+
+ if not parsed_policy then
+ rspamd_logger.errx(rspamd_config, 'invalid fuzzy rule default fields: %s', err)
+ else
+ processed_rule = parsed_policy
+ end
+ else
+ rspamd_logger.warnx(rspamd_config, "unknown policy %s", processed_rule.policy)
+ end
+
+ if processed_rule.mime_types then
+ processed_rule.mime_types = fun.totable(fun.map(function(gl)
+ return rspamd_regexp.import_glob(gl, 'i')
+ end, processed_rule.mime_types))
+ end
+
+ table.insert(rules, processed_rule)
+ return #rules
+end
+
+local function check_length(task, part, rule)
+ local bytes = part:get_length()
+ local length_ok = bytes > 0
+
+ local id = part:get_id()
+ lua_util.debugm(N, task, 'check size of part %s', id)
+
+ if length_ok and rule.min_bytes > 0 then
+
+ local adjusted_bytes = bytes
+
+ if part:is_text() then
+ -- Fuzzy plugin uses stripped utf content to get an exact hash, that
+ -- corresponds to `get_content_oneline()`
+ -- However, in the case of empty parts this method returns `nil`, so extra
+ -- sanity check is required.
+ bytes = #(part:get_text():get_content_oneline() or '')
+
+ -- Short hashing algorithm also use subject unless explicitly denied
+ if not rule.no_subject then
+ local subject = task:get_subject() or ''
+ bytes = bytes + #subject
+ end
+
+ if rule.text_multiplier then
+ adjusted_bytes = bytes * rule.text_multiplier
+ end
+ end
+
+ if rule.min_bytes > adjusted_bytes then
+ lua_util.debugm(N, task, 'skip part of length %s (%s adjusted) ' ..
+ 'as it has less than %s bytes',
+ bytes, adjusted_bytes, rule.min_bytes)
+ length_ok = false
+ else
+ lua_util.debugm(N, task, 'allow part of length %s (%s adjusted)',
+ bytes, adjusted_bytes, rule.min_bytes)
+ end
+ else
+ lua_util.debugm(N, task, 'allow part %s, no length limits', id)
+ end
+
+ return length_ok
+end
+
+local function check_text_part(task, part, rule, text)
+ local allow_direct, allow_shingles = false, false
+
+ local id = part:get_id()
+ lua_util.debugm(N, task, 'check text part %s', id)
+ local wcnt = text:get_words_count()
+
+ if rule.text_shingles then
+ -- Check number of words
+ local min_words = rule.min_length or 0
+ if min_words < 32 then
+ min_words = 32 -- Minimum for shingles
+ end
+ if wcnt < min_words then
+ lua_util.debugm(N, task, 'text has less than %s words: %s; disable shingles',
+ rule.min_length, wcnt)
+ allow_shingles = false
+ else
+ lua_util.debugm(N, task, 'allow shingles in text %s, %s words',
+ id, wcnt)
+ allow_shingles = true
+ end
+
+ if not rule.short_text_direct_hash and not allow_shingles then
+ allow_direct = false
+ else
+ if not allow_shingles then
+ lua_util.debugm(N, task,
+ 'allow direct hash for short text %s, %s words',
+ id, wcnt)
+ allow_direct = check_length(task, part, rule)
+ else
+ allow_direct = wcnt > 0
+ end
+ end
+ else
+ lua_util.debugm(N, task,
+ 'disable shingles in text %s', id)
+ allow_direct = check_length(task, part, rule)
+ end
+
+ return allow_direct, allow_shingles
+end
+
+--local function has_sane_text_parts(task)
+-- local text_parts = task:get_text_parts() or {}
+-- return fun.any(function(tp) return tp:get_words_count() > 32 end, text_parts)
+--end
+
+local function check_image_part(task, part, rule, image)
+ if rule.skip_images then
+ lua_util.debugm(N, task, 'skip image part as images are disabled')
+ return false, false
+ end
+
+ local id = part:get_id()
+ lua_util.debugm(N, task, 'check image part %s', id)
+
+ if rule.min_width > 0 or rule.min_height > 0 then
+ -- Check dimensions
+ local min_width = rule.min_width or rule.min_height
+ local min_height = rule.min_height or rule.min_width
+ local height = image:get_height()
+ local width = image:get_width()
+
+ if height and width then
+ if height < min_height or width < min_width then
+ lua_util.debugm(N, task, 'skip image part %s as it does not meet minimum sizes: %sx%s < %sx%s',
+ id, width, height, min_width, min_height)
+ return false, false
+ else
+ lua_util.debugm(N, task, 'allow image part %s: %sx%s',
+ id, width, height)
+ end
+ end
+ end
+
+ return check_length(task, part, rule), false
+end
+
+local function mime_types_check(task, part, rule)
+ local t, st = part:get_type()
+
+ if not t then
+ return false, false
+ end
+
+ local ct = string.format('%s/%s', t, st)
+
+ local detected_ct
+ t, st = part:get_detected_type()
+ if t then
+ detected_ct = string.format('%s/%s', t, st)
+ else
+ detected_ct = ct
+ end
+
+ local id = part:get_id()
+ lua_util.debugm(N, task, 'check binary part %s: %s', id, ct)
+
+ -- For bad mime parts we implicitly enable fuzzy check
+ local mime_trace = (task:get_symbol('MIME_TRACE') or {})[1]
+ local opts = {}
+
+ if mime_trace then
+ opts = mime_trace.options or opts
+ end
+ opts = fun.tomap(fun.map(function(opt)
+ local elts = lua_util.str_split(opt, ':')
+ return elts[1], elts[2]
+ end, opts))
+
+ if opts[id] and opts[id] == '-' then
+ lua_util.debugm(N, task, 'explicitly check binary part %s: bad mime type %s', id, ct)
+ return check_length(task, part, rule), false
+ end
+
+ if rule.mime_types then
+
+ if fun.any(function(gl_re)
+ if gl_re:match(ct) or (detected_ct and gl_re:match(detected_ct)) then
+ return true
+ else
+ return false
+ end
+ end, rule.mime_types) then
+ lua_util.debugm(N, task, 'found mime type match for part %s: %s (%s detected)',
+ id, ct, detected_ct)
+ return check_length(task, part, rule), false
+ end
+
+ return false, false
+ end
+
+ return false, false
+end
+
+exports.check_mime_part = function(task, part, rule_id)
+ local rule = rules[rule_id]
+
+ if not rule then
+ rspamd_logger.errx(task, 'cannot find rule with id %s', rule_id)
+
+ return false, false
+ end
+
+ if part:is_text() then
+ return check_text_part(task, part, rule, part:get_text())
+ end
+
+ if part:is_image() then
+ return check_image_part(task, part, rule, part:get_image())
+ end
+
+ if part:is_archive() and rule.scan_archives then
+ -- Always send archives
+ lua_util.debugm(N, task, 'check archive part %s', part:get_id())
+
+ return true, false
+ end
+
+ if part:is_specific() then
+ local sp = part:get_specific()
+
+ if type(sp) == 'table' and sp.fuzzy_hashes then
+ lua_util.debugm(N, task, 'check specific part %s', part:get_id())
+ return true, false
+ end
+ end
+
+ if part:is_attachment() then
+ return mime_types_check(task, part, rule)
+ end
+
+ return false, false
+end
+
+exports.cleanup_rules = function()
+ rules = {}
+end
+
+return exports