summaryrefslogtreecommitdiffstats
path: root/lualib/lua_magic/init.lua
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--lualib/lua_magic/init.lua388
1 files changed, 388 insertions, 0 deletions
diff --git a/lualib/lua_magic/init.lua b/lualib/lua_magic/init.lua
new file mode 100644
index 0000000..38bfddb
--- /dev/null
+++ b/lualib/lua_magic/init.lua
@@ -0,0 +1,388 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+--[[[
+-- @module lua_magic
+-- This module contains file types detection logic
+--]]
+
+local patterns = require "lua_magic/patterns"
+local types = require "lua_magic/types"
+local heuristics = require "lua_magic/heuristics"
+local fun = require "fun"
+local lua_util = require "lua_util"
+
+local rspamd_text = require "rspamd_text"
+local rspamd_trie = require "rspamd_trie"
+
+local N = "lua_magic" -- module name passed to lua_util.debugm
+local exports = {} -- public module table returned at the end of the file
+-- Compiled trie objects (long, short and tail sets), created once by process_patterns()
+local compiled_patterns
+local compiled_short_patterns
+local compiled_tail_patterns
+-- Lists of {<str>, <match_object>, <pattern_object>} indexed by pattern number within each trie
+local processed_patterns = {}
+local short_patterns = {}
+local tail_patterns = {}
+
+local short_match_limit = 128 -- numeric positions below this value are filed as short patterns
+local max_short_offset = -1 -- largest position among short patterns; length of the head chunk in detect()
+local min_tail_offset = math.huge -- smallest `tail` among tail patterns; length of the tail chunk in detect()
+
+-- Compiles every match rule from `lua_magic/patterns` into three tries
+-- (tail, short/head and long) and fills the corresponding lookup tables.
+-- The work is done only once: subsequent calls return immediately.
+-- @param log_obj logging object forwarded to lua_util.debugm
+local function process_patterns(log_obj)
+  if compiled_patterns then
+    -- Tries have already been built
+    return
+  end
+
+  -- Files a single regexp string into the tail, short or long pattern list
+  local function add_processed(str, match, pattern)
+    local entry = { str, match, pattern }
+    local has_numeric_position = type(match.position) == 'number'
+
+    if has_numeric_position and match.tail then
+      -- Tail pattern: remember the smallest tail size seen so far
+      tail_patterns[#tail_patterns + 1] = entry
+      if match.tail < min_tail_offset then
+        min_tail_offset = match.tail
+      end
+
+      lua_util.debugm(N, log_obj, 'add tail pattern %s for ext %s',
+          str, pattern.ext)
+    elseif has_numeric_position and match.position < short_match_limit then
+      -- Short pattern: remember the largest position seen so far
+      short_patterns[#short_patterns + 1] = entry
+      if match.position > max_short_offset then
+        max_short_offset = match.position
+      end
+
+      local is_anchored = str:sub(1, 1) == '^'
+      if is_anchored then
+        lua_util.debugm(N, log_obj, 'add head pattern %s for ext %s',
+            str, pattern.ext)
+      else
+        lua_util.debugm(N, log_obj, 'add short pattern %s for ext %s',
+            str, pattern.ext)
+      end
+    else
+      -- Everything else (no position, non-numeric position, large position)
+      processed_patterns[#processed_patterns + 1] = entry
+
+      lua_util.debugm(N, log_obj, 'add long pattern %s for ext %s',
+          str, pattern.ext)
+    end
+  end
+
+  -- Converts a hex string into a sequence of \x{..} escapes, optionally
+  -- anchored to the start with '^'
+  local function hex_to_regexp(hex, anchored)
+    local pieces = {}
+    if anchored then
+      pieces[1] = '^'
+    end
+    for i = 1, #hex, 2 do
+      pieces[#pieces + 1] = string.format('\\x{%s}', hex:sub(i, i + 1))
+    end
+    return table.concat(pieces)
+  end
+
+  -- Extracts the regexp strings (first element) from a list of entries
+  local function pattern_strings(entries)
+    return fun.totable(fun.map(function(t)
+      return t[1]
+    end, entries))
+  end
+
+  for ext, pattern in pairs(patterns) do
+    assert(types[ext], 'not found type: ' .. ext)
+    pattern.ext = ext
+
+    for _, match in ipairs(pattern.matches) do
+      local needs_position = match.relative_position and not match.position
+
+      if match.string then
+        if needs_position then
+          -- Position is derived from the string as written in the rule,
+          -- i.e. before any '^' anchor is prepended
+          match.position = match.relative_position + #match.string
+
+          if match.relative_position == 0 and match.string:sub(1, 1) ~= '^' then
+            match.string = '^' .. match.string
+          end
+        end
+        add_processed(match.string, match, pattern)
+      elseif match.hex then
+        local re = hex_to_regexp(match.hex, match.relative_position == 0)
+
+        if needs_position then
+          match.position = match.relative_position + #match.hex / 2
+        end
+        add_processed(re, match, pattern)
+      end
+    end
+  end
+
+  local bit = require "bit"
+  local trie_flags = rspamd_trie.flags
+  local compile_flags = bit.bor(trie_flags.re, trie_flags.dot_all,
+      trie_flags.single_match, trie_flags.no_start)
+
+  compiled_patterns = rspamd_trie.create(pattern_strings(processed_patterns),
+      compile_flags)
+  compiled_short_patterns = rspamd_trie.create(pattern_strings(short_patterns),
+      compile_flags)
+  compiled_tail_patterns = rspamd_trie.create(pattern_strings(tail_patterns),
+      compile_flags)
+
+  lua_util.debugm(N, log_obj,
+      'compiled %s (%s short; %s long; %s tail) patterns',
+      #processed_patterns + #short_patterns + #tail_patterns,
+      #short_patterns, #processed_patterns, #tail_patterns)
+end
+
+-- Build the tries when the module is loaded
+process_patterns(rspamd_config)
+
+-- Comparison operators accepted in table-style positions such as {'>=', 10}
+local position_comparators = {
+  ['>'] = function(a, b)
+    return a > b
+  end,
+  ['>='] = function(a, b)
+    return a >= b
+  end,
+  ['<'] = function(a, b)
+    return a < b
+  end,
+  ['<='] = function(a, b)
+    return a <= b
+  end,
+  ['!='] = function(a, b)
+    return a ~= b
+  end,
+}
+
+-- Default comparison used for plain numeric positions and unknown operators
+local function positions_equal(a, b)
+  return a == b
+end
+
+-- Runs `trie` over `chunk` and accumulates weights of matching extensions
+-- in `res`.
+-- @param chunk text that is actually scanned by the trie
+-- @param input the whole input, passed through to match heuristics
+-- @param tlen total length of the input, used to resolve negative positions
+-- @param offset value added to each reported match position before it is
+--   compared with the expected position
+-- @param trie compiled trie to match with
+-- @param processed_tbl list of {str, match, pattern} entries indexed by the
+--   pattern numbers reported by the trie
+-- @param log_obj logging object forwarded to lua_util.debugm and heuristics
+-- @param res table ext -> accumulated weight, updated in place
+-- @param part object passed through to match heuristics
+local function match_chunk(chunk, input, tlen, offset, trie, processed_tbl, log_obj, res, part)
+  local matches = trie:match(chunk)
+
+  -- Adds `weight` (1 if not given) to the score of `ext`
+  local function add_result(weight, ext)
+    res[ext] = (res[ext] or 0) + (weight or 1)
+
+    lua_util.debugm(N, log_obj, 'add pattern for %s, weight %s, total weight %s',
+        ext, weight, res[ext])
+  end
+
+  -- Checks a found position against an expected one; `expected` is either
+  -- a number or a pair like {'>', 0}
+  local function match_position(pos, expected)
+    local cmp = positions_equal
+
+    if type(expected) == 'table' then
+      cmp = position_comparators[expected[1]] or positions_equal
+      expected = expected[2]
+    end
+
+    -- Negative positions are counted from the end of the whole input
+    if expected < 0 then
+      expected = tlen + expected + 1
+    end
+    return cmp(pos, expected)
+  end
+
+  -- Records a hit for the rule, either directly or through its heuristic.
+  -- Returns true if something has been added to `res`.
+  local function report(match, pattern, abs_pos)
+    if not match.heuristic then
+      add_result(match.weight, pattern.ext)
+      return true
+    end
+
+    local ext, weight = match.heuristic(input, log_obj, abs_pos, part)
+    if ext then
+      add_result(weight, ext)
+      return true
+    end
+
+    return false
+  end
+
+  for npat, matched_positions in pairs(matches) do
+    local pat_data = processed_tbl[npat]
+    local match, pattern = pat_data[2], pat_data[3]
+
+    if match.position then
+      -- Single position: the first found position that fits (and is accepted
+      -- by the heuristic, if any) is enough
+      for _, pos in ipairs(matched_positions) do
+        lua_util.debugm(N, log_obj, 'found match %s at offset %s(from %s)',
+            pattern.ext, pos, offset)
+        if match_position(pos + offset, match.position)
+            and report(match, pattern, pos + offset) then
+          break
+        end
+      end
+    elseif match.positions then
+      -- Match all positions: every expected position must be satisfied by
+      -- at least one found position
+      local all_right = true
+      local matched_pos = 0
+
+      for _, position in ipairs(match.positions) do
+        local matched = false
+        for _, pos in ipairs(matched_positions) do
+          lua_util.debugm(N, log_obj, 'found match %s at offset %s(from %s)',
+              pattern.ext, pos, offset)
+          -- A constraint counts as matched when the position satisfies it
+          -- (the previous code had this test negated)
+          if match_position(pos + offset, position) then
+            matched = true
+            matched_pos = pos
+            break
+          end
+        end
+        if not matched then
+          all_right = false
+          break
+        end
+      end
+
+      if all_right then
+        -- No `break` here: it would leave the loop over all matched patterns
+        -- and drop the ones that have not been processed yet
+        report(match, pattern, matched_pos + offset)
+      end
+    end
+  end
+end
+
+-- Ranks detected extensions by their accumulated weight.
+-- @param res table ext -> weight
+-- @return list of extensions sorted by descending weight and the weight of
+--   the best one, or nil if `res` has no entries
+local function process_detected(res)
+  local ranked = lua_util.keys(res)
+
+  if #ranked == 0 then
+    return nil
+  end
+
+  local function heavier_first(lhs, rhs)
+    return res[lhs] > res[rhs]
+  end
+  table.sort(ranked, heavier_first)
+
+  local best = ranked[1]
+  return ranked, res[best]
+end
+
+-- Detects the type of a part by matching its content against the compiled
+-- patterns: tail patterns first, then short (head) patterns and, if those are
+-- not conclusive, long patterns over the start and the end of the input.
+-- @param part object providing `get_content()`; also passed to heuristics
+-- @param log_obj optional logging object, defaults to rspamd_config
+-- @return the best extension and its entry in `types`, or nil if nothing
+--   has been detected
+exports.detect = function(part, log_obj)
+  if not log_obj then
+    log_obj = rspamd_config
+  end
+  local input = part:get_content()
+
+  if type(input) == 'string' then
+    -- Convert to rspamd_text
+    input = rspamd_text.fromstring(input)
+  end
+
+  if type(input) ~= 'userdata' then
+    -- Table input is NYI. The former `assert(0, ...)` could never fire as 0
+    -- is truthy in Lua, so such input effectively produced no detection;
+    -- keep that outcome but make it explicit.
+    lua_util.debugm(N, log_obj, 'unsupported input type for match: %s',
+        type(input))
+    return nil
+  end
+
+  local inplen = #input
+
+  if inplen == 0 then
+    -- Nothing can match an empty input, do not try to take spans of it
+    return nil
+  end
+
+  local res = {}
+
+  -- Check tail matches
+  if inplen > min_tail_offset then
+    local tail = input:span(inplen - min_tail_offset, min_tail_offset)
+    match_chunk(tail, input, inplen, inplen - min_tail_offset,
+        compiled_tail_patterns, tail_patterns, log_obj, res, part)
+  end
+
+  -- Try short match
+  local head = input:span(1, math.min(max_short_offset, inplen))
+  match_chunk(head, input, inplen, 0,
+      compiled_short_patterns, short_patterns, log_obj, res, part)
+
+  -- Check if we have enough data or go to long patterns
+  local extensions, confidence = process_detected(res)
+
+  if extensions and #extensions > 0 and confidence > 30 then
+    -- We are done on short patterns
+    return extensions[1], types[extensions[1]]
+  end
+
+  -- No way, let's check data in chunks or just the whole input if it is small enough
+  local chunk_size = exports.chunk_size
+
+  if inplen > chunk_size * 3 then
+    -- Chunked version as input is too long: two chunks from the start
+    -- and one chunk from the end
+    local chunk1 = input:span(1, chunk_size * 2)
+    local chunk2 = input:span(inplen - chunk_size, chunk_size)
+    local offset1, offset2 = 0, inplen - chunk_size
+
+    match_chunk(chunk1, input, inplen,
+        offset1, compiled_patterns, processed_patterns, log_obj, res, part)
+    match_chunk(chunk2, input, inplen,
+        offset2, compiled_patterns, processed_patterns, log_obj, res, part)
+  else
+    -- Input is short enough to match it at all
+    match_chunk(input, input, inplen, 0,
+        compiled_patterns, processed_patterns, log_obj, res, part)
+  end
+
+  extensions = process_detected(res)
+
+  if extensions and #extensions > 0 then
+    return extensions[1], types[extensions[1]]
+  end
+
+  -- Nothing found
+  return nil
+end
+
+-- Detects the type of a mime part in three stages: the mime part heuristic,
+-- content based detection and finally the text part heuristic.
+-- @param part part to inspect
+-- @param log_obj logging object forwarded to every stage
+-- @return extension and its entry in `types`; nothing if no stage succeeds
+exports.detect_mime_part = function(part, log_obj)
+  -- A heuristic result is trusted only with a weight above 20
+  local function is_confident(ext, weight)
+    return ext and weight and weight > 20
+  end
+
+  local mime_ext, mime_weight = heuristics.mime_part_heuristic(part, log_obj)
+  if is_confident(mime_ext, mime_weight) then
+    return mime_ext, types[mime_ext]
+  end
+
+  local magic_ext = exports.detect(part, log_obj)
+  if magic_ext then
+    return magic_ext, types[magic_ext]
+  end
+
+  -- Text/html and other parts
+  local text_ext, text_weight = heuristics.text_part_heuristic(part, log_obj)
+  if is_confident(text_ext, text_weight) then
+    return text_ext, types[text_ext]
+  end
+end
+
+-- Chunk size in bytes for long pattern matching: inputs longer than 3 chunks are not scanned entirely,
+-- exports.detect checks only 2 chunks at the start and 1 chunk at the end of such inputs
+exports.chunk_size = 32768
+
+exports.types = types -- re-export of the `lua_magic/types` table
+
+return exports \ No newline at end of file