diff options
Diffstat (limited to '')
54 files changed, 34142 insertions, 0 deletions
diff --git a/src/plugins/chartable.cxx b/src/plugins/chartable.cxx new file mode 100644 index 0000000..704f12a --- /dev/null +++ b/src/plugins/chartable.cxx @@ -0,0 +1,2122 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/***MODULE:chartable + * rspamd module that makes marks based on symbol chains + * + * Allowed options: + * - symbol (string): symbol to insert (default: 'R_MIXED_CHARSET') + * - threshold (double): value that would be used as threshold in expression characters_changed / total_characters + * (e.g. if threshold is 0.1 then a charset change should occur more often than once in 10 symbols), default: 0.1 + */ + + +#include "config.h" +#include "libmime/message.h" +#include "rspamd.h" +#include "libstat/stat_api.h" +#include "libmime/lang_detection.h" + +#include "unicode/utf8.h" +#include "unicode/uchar.h" +#include "contrib/ankerl/unordered_dense.h" + +#define DEFAULT_SYMBOL "R_MIXED_CHARSET" +#define DEFAULT_URL_SYMBOL "R_MIXED_CHARSET_URL" +#define DEFAULT_THRESHOLD 0.1 + +#define msg_debug_chartable(...) 
rspamd_conditional_debug_fast(nullptr, task->from_addr, \ + rspamd_chartable_log_id, "chartable", task->task_pool->tag.uid, \ + G_STRFUNC, \ + __VA_ARGS__) + +INIT_LOG_MODULE(chartable) + +/* Initialization */ +gint chartable_module_init(struct rspamd_config *cfg, struct module_ctx **ctx); + +gint chartable_module_config(struct rspamd_config *cfg, bool validate); + +gint chartable_module_reconfig(struct rspamd_config *cfg); + +module_t chartable_module = { + "chartable", + chartable_module_init, + chartable_module_config, + chartable_module_reconfig, + nullptr, + RSPAMD_MODULE_VER, + (guint) -1, +}; + +struct chartable_ctx { + struct module_ctx ctx; + const gchar *symbol; + const gchar *url_symbol; + double threshold; + guint max_word_len; +}; + +static inline struct chartable_ctx * +chartable_get_context(struct rspamd_config *cfg) +{ + return (struct chartable_ctx *) g_ptr_array_index(cfg->c_modules, + chartable_module.ctx_offset); +} + +static void chartable_symbol_callback(struct rspamd_task *task, + struct rspamd_symcache_dynamic_item *item, + void *unused); + +static void chartable_url_symbol_callback(struct rspamd_task *task, + struct rspamd_symcache_dynamic_item *item, + void *unused); + +gint chartable_module_init(struct rspamd_config *cfg, struct module_ctx **ctx) +{ + struct chartable_ctx *chartable_module_ctx; + + chartable_module_ctx = rspamd_mempool_alloc0_type(cfg->cfg_pool, + struct chartable_ctx); + chartable_module_ctx->max_word_len = 10; + + *ctx = (struct module_ctx *) chartable_module_ctx; + + return 0; +} + + +gint chartable_module_config(struct rspamd_config *cfg, bool _) +{ + const ucl_object_t *value; + gint res = TRUE; + struct chartable_ctx *chartable_module_ctx = chartable_get_context(cfg); + + if (!rspamd_config_is_module_enabled(cfg, "chartable")) { + return TRUE; + } + + if ((value = + rspamd_config_get_module_opt(cfg, "chartable", "symbol")) != nullptr) { + chartable_module_ctx->symbol = ucl_obj_tostring(value); + } + else { + 
chartable_module_ctx->symbol = DEFAULT_SYMBOL; + } + if ((value = + rspamd_config_get_module_opt(cfg, "chartable", "url_symbol")) != nullptr) { + chartable_module_ctx->url_symbol = ucl_obj_tostring(value); + } + else { + chartable_module_ctx->url_symbol = DEFAULT_URL_SYMBOL; + } + if ((value = + rspamd_config_get_module_opt(cfg, "chartable", "threshold")) != nullptr) { + if (!ucl_obj_todouble_safe(value, &chartable_module_ctx->threshold)) { + msg_warn_config("invalid numeric value"); + chartable_module_ctx->threshold = DEFAULT_THRESHOLD; + } + } + else { + chartable_module_ctx->threshold = DEFAULT_THRESHOLD; + } + if ((value = + rspamd_config_get_module_opt(cfg, "chartable", "max_word_len")) != nullptr) { + chartable_module_ctx->max_word_len = ucl_object_toint(value); + } + else { + chartable_module_ctx->threshold = DEFAULT_THRESHOLD; + } + + rspamd_symcache_add_symbol(cfg->cache, + chartable_module_ctx->symbol, + 0, + chartable_symbol_callback, + nullptr, + SYMBOL_TYPE_NORMAL, + -1); + rspamd_symcache_add_symbol(cfg->cache, + chartable_module_ctx->url_symbol, + 0, + chartable_url_symbol_callback, + nullptr, + SYMBOL_TYPE_NORMAL, + -1); + + msg_info_config("init internal chartable module"); + + return res; +} + +gint chartable_module_reconfig(struct rspamd_config *cfg) +{ + return chartable_module_config(cfg, false); +} + +static const auto latin_confusable = ankerl::unordered_dense::set<int>{ + 0x02028, + 0x02029, + 0x01680, + 0x02000, + 0x02001, + 0x02002, + 0x02003, + 0x02004, + 0x02005, + 0x02006, + 0x02008, + 0x02009, + 0x0200a, + 0x0205f, + 0x000a0, + 0x02007, + 0x0202f, + 0x007fa, + 0x0fe4d, + 0x0fe4e, + 0x0fe4f, + 0x02010, + 0x02011, + 0x02012, + 0x02013, + 0x0fe58, + 0x006d4, + 0x02043, + 0x002d7, + 0x02212, + 0x02796, + 0x02cba, + 0x0060d, + 0x0066b, + 0x0201a, + 0x000b8, + 0x0a4f9, + 0x0037e, + 0x00903, + 0x00a83, + 0x0ff1a, + 0x00589, + 0x00703, + 0x00704, + 0x016ec, + 0x0fe30, + 0x01803, + 0x01809, + 0x0205a, + 0x005c3, + 0x002f8, + 0x0a789, + 0x02236, 
+ 0x002d0, + 0x0a4fd, + 0x0ff01, + 0x001c3, + 0x02d51, + 0x00294, + 0x00241, + 0x0097d, + 0x013ae, + 0x0a6eb, + 0x1d16d, + 0x02024, + 0x00701, + 0x00702, + 0x0a60e, + 0x10a50, + 0x00660, + 0x006f0, + 0x0a4f8, + 0x0055d, + 0x0ff07, + 0x02018, + 0x02019, + 0x0201b, + 0x02032, + 0x02035, + 0x0055a, + 0x005f3, + 0x00060, + 0x01fef, + 0x0ff40, + 0x000b4, + 0x00384, + 0x01ffd, + 0x01fbd, + 0x01fbf, + 0x01ffe, + 0x002b9, + 0x00374, + 0x002c8, + 0x002ca, + 0x002cb, + 0x002f4, + 0x002bb, + 0x002bd, + 0x002bc, + 0x002be, + 0x0a78c, + 0x005d9, + 0x007f4, + 0x007f5, + 0x0144a, + 0x016cc, + 0x16f51, + 0x16f52, + 0x0ff3b, + 0x02768, + 0x02772, + 0x03014, + 0x0fd3e, + 0x0ff3d, + 0x02769, + 0x02773, + 0x03015, + 0x0fd3f, + 0x02774, + 0x1d114, + 0x02775, + 0x0204e, + 0x0066d, + 0x02217, + 0x1031f, + 0x01735, + 0x02041, + 0x02215, + 0x02044, + 0x02571, + 0x027cb, + 0x029f8, + 0x1d23a, + 0x031d3, + 0x03033, + 0x02cc6, + 0x030ce, + 0x04e3f, + 0x02f03, + 0x0ff3c, + 0x0fe68, + 0x02216, + 0x027cd, + 0x029f5, + 0x029f9, + 0x1d20f, + 0x1d23b, + 0x031d4, + 0x04e36, + 0x02f02, + 0x0a778, + 0x002c4, + 0x002c6, + 0x016ed, + 0x02795, + 0x1029b, + 0x02039, + 0x0276e, + 0x002c2, + 0x1d236, + 0x01438, + 0x016b2, + 0x01400, + 0x02e40, + 0x030a0, + 0x0a4ff, + 0x0203a, + 0x0276f, + 0x002c3, + 0x1d237, + 0x01433, + 0x16f3f, + 0x02053, + 0x002dc, + 0x01fc0, + 0x0223c, + 0x1d7d0, + 0x1d7da, + 0x1d7e4, + 0x1d7ee, + 0x1d7f8, + 0x0a75a, + 0x001a7, + 0x003e8, + 0x0a644, + 0x014bf, + 0x0a6ef, + 0x1d206, + 0x1d7d1, + 0x1d7db, + 0x1d7e5, + 0x1d7ef, + 0x1d7f9, + 0x0a7ab, + 0x0021c, + 0x001b7, + 0x0a76a, + 0x02ccc, + 0x00417, + 0x004e0, + 0x16f3b, + 0x118ca, + 0x1d7d2, + 0x1d7dc, + 0x1d7e6, + 0x1d7f0, + 0x1d7fa, + 0x013ce, + 0x118af, + 0x1d7d3, + 0x1d7dd, + 0x1d7e7, + 0x1d7f1, + 0x1d7fb, + 0x001bc, + 0x118bb, + 0x1d7d4, + 0x1d7de, + 0x1d7e8, + 0x1d7f2, + 0x1d7fc, + 0x02cd2, + 0x00431, + 0x013ee, + 0x118d5, + 0x1d212, + 0x1d7d5, + 0x1d7df, + 0x1d7e9, + 0x1d7f3, + 0x1d7fd, + 0x104d2, + 0x118c6, + 0x00b03, + 
0x009ea, + 0x00a6a, + 0x1e8cb, + 0x1d7d6, + 0x1d7e0, + 0x1d7ea, + 0x1d7f4, + 0x1d7fe, + 0x00223, + 0x00222, + 0x1031a, + 0x00a67, + 0x00b68, + 0x009ed, + 0x00d6d, + 0x1d7d7, + 0x1d7e1, + 0x1d7eb, + 0x1d7f5, + 0x1d7ff, + 0x0a76e, + 0x02cca, + 0x118cc, + 0x118ac, + 0x118d6, + 0x0237a, + 0x0ff41, + 0x1d41a, + 0x1d44e, + 0x1d482, + 0x1d4b6, + 0x1d4ea, + 0x1d51e, + 0x1d552, + 0x1d586, + 0x1d5ba, + 0x1d5ee, + 0x1d622, + 0x1d656, + 0x1d68a, + 0x00251, + 0x003b1, + 0x1d6c2, + 0x1d6fc, + 0x1d736, + 0x1d770, + 0x1d7aa, + 0x00430, + 0x0ff21, + 0x1d400, + 0x1d434, + 0x1d468, + 0x1d49c, + 0x1d4d0, + 0x1d504, + 0x1d538, + 0x1d56c, + 0x1d5a0, + 0x1d5d4, + 0x1d608, + 0x1d63c, + 0x1d670, + 0x00391, + 0x1d6a8, + 0x1d6e2, + 0x1d71c, + 0x1d756, + 0x1d790, + 0x00410, + 0x013aa, + 0x015c5, + 0x0a4ee, + 0x16f40, + 0x102a0, + 0x1d41b, + 0x1d44f, + 0x1d483, + 0x1d4b7, + 0x1d4eb, + 0x1d51f, + 0x1d553, + 0x1d587, + 0x1d5bb, + 0x1d5ef, + 0x1d623, + 0x1d657, + 0x1d68b, + 0x00184, + 0x0042c, + 0x013cf, + 0x015af, + 0x0ff22, + 0x0212c, + 0x1d401, + 0x1d435, + 0x1d469, + 0x1d4d1, + 0x1d505, + 0x1d539, + 0x1d56d, + 0x1d5a1, + 0x1d5d5, + 0x1d609, + 0x1d63d, + 0x1d671, + 0x0a7b4, + 0x00392, + 0x1d6a9, + 0x1d6e3, + 0x1d71d, + 0x1d757, + 0x1d791, + 0x00412, + 0x013f4, + 0x015f7, + 0x0a4d0, + 0x10282, + 0x102a1, + 0x10301, + 0x0ff43, + 0x0217d, + 0x1d41c, + 0x1d450, + 0x1d484, + 0x1d4b8, + 0x1d4ec, + 0x1d520, + 0x1d554, + 0x1d588, + 0x1d5bc, + 0x1d5f0, + 0x1d624, + 0x1d658, + 0x1d68c, + 0x01d04, + 0x003f2, + 0x02ca5, + 0x00441, + 0x0abaf, + 0x1043d, + 0x1f74c, + 0x118f2, + 0x118e9, + 0x0ff23, + 0x0216d, + 0x02102, + 0x0212d, + 0x1d402, + 0x1d436, + 0x1d46a, + 0x1d49e, + 0x1d4d2, + 0x1d56e, + 0x1d5a2, + 0x1d5d6, + 0x1d60a, + 0x1d63e, + 0x1d672, + 0x003f9, + 0x02ca4, + 0x00421, + 0x013df, + 0x0a4da, + 0x102a2, + 0x10302, + 0x10415, + 0x1051c, + 0x0217e, + 0x02146, + 0x1d41d, + 0x1d451, + 0x1d485, + 0x1d4b9, + 0x1d4ed, + 0x1d521, + 0x1d555, + 0x1d589, + 0x1d5bd, + 0x1d5f1, + 0x1d625, + 0x1d659, + 0x1d68d, 
+ 0x00501, + 0x013e7, + 0x0146f, + 0x0a4d2, + 0x0216e, + 0x02145, + 0x1d403, + 0x1d437, + 0x1d46b, + 0x1d49f, + 0x1d4d3, + 0x1d507, + 0x1d53b, + 0x1d56f, + 0x1d5a3, + 0x1d5d7, + 0x1d60b, + 0x1d63f, + 0x1d673, + 0x013a0, + 0x015de, + 0x015ea, + 0x0a4d3, + 0x0212e, + 0x0ff45, + 0x0212f, + 0x02147, + 0x1d41e, + 0x1d452, + 0x1d486, + 0x1d4ee, + 0x1d522, + 0x1d556, + 0x1d58a, + 0x1d5be, + 0x1d5f2, + 0x1d626, + 0x1d65a, + 0x1d68e, + 0x0ab32, + 0x00435, + 0x004bd, + 0x022ff, + 0x0ff25, + 0x02130, + 0x1d404, + 0x1d438, + 0x1d46c, + 0x1d4d4, + 0x1d508, + 0x1d53c, + 0x1d570, + 0x1d5a4, + 0x1d5d8, + 0x1d60c, + 0x1d640, + 0x1d674, + 0x00395, + 0x1d6ac, + 0x1d6e6, + 0x1d720, + 0x1d75a, + 0x1d794, + 0x00415, + 0x02d39, + 0x013ac, + 0x0a4f0, + 0x118a6, + 0x118ae, + 0x10286, + 0x1d41f, + 0x1d453, + 0x1d487, + 0x1d4bb, + 0x1d4ef, + 0x1d523, + 0x1d557, + 0x1d58b, + 0x1d5bf, + 0x1d5f3, + 0x1d627, + 0x1d65b, + 0x1d68f, + 0x0ab35, + 0x0a799, + 0x0017f, + 0x01e9d, + 0x00584, + 0x1d213, + 0x02131, + 0x1d405, + 0x1d439, + 0x1d46d, + 0x1d4d5, + 0x1d509, + 0x1d53d, + 0x1d571, + 0x1d5a5, + 0x1d5d9, + 0x1d60d, + 0x1d641, + 0x1d675, + 0x0a798, + 0x003dc, + 0x1d7ca, + 0x015b4, + 0x0a4dd, + 0x118c2, + 0x118a2, + 0x10287, + 0x102a5, + 0x10525, + 0x0ff47, + 0x0210a, + 0x1d420, + 0x1d454, + 0x1d488, + 0x1d4f0, + 0x1d524, + 0x1d558, + 0x1d58c, + 0x1d5c0, + 0x1d5f4, + 0x1d628, + 0x1d65c, + 0x1d690, + 0x00261, + 0x01d83, + 0x0018d, + 0x00581, + 0x1d406, + 0x1d43a, + 0x1d46e, + 0x1d4a2, + 0x1d4d6, + 0x1d50a, + 0x1d53e, + 0x1d572, + 0x1d5a6, + 0x1d5da, + 0x1d60e, + 0x1d642, + 0x1d676, + 0x0050c, + 0x013c0, + 0x013f3, + 0x0a4d6, + 0x0ff48, + 0x0210e, + 0x1d421, + 0x1d489, + 0x1d4bd, + 0x1d4f1, + 0x1d525, + 0x1d559, + 0x1d58d, + 0x1d5c1, + 0x1d5f5, + 0x1d629, + 0x1d65d, + 0x1d691, + 0x004bb, + 0x00570, + 0x013c2, + 0x0ff28, + 0x0210b, + 0x0210c, + 0x0210d, + 0x1d407, + 0x1d43b, + 0x1d46f, + 0x1d4d7, + 0x1d573, + 0x1d5a7, + 0x1d5db, + 0x1d60f, + 0x1d643, + 0x1d677, + 0x00397, + 0x1d6ae, + 0x1d6e8, + 
0x1d722, + 0x1d75c, + 0x1d796, + 0x02c8e, + 0x0041d, + 0x013bb, + 0x0157c, + 0x0a4e7, + 0x102cf, + 0x002db, + 0x02373, + 0x0ff49, + 0x02170, + 0x02139, + 0x02148, + 0x1d422, + 0x1d456, + 0x1d48a, + 0x1d4be, + 0x1d4f2, + 0x1d526, + 0x1d55a, + 0x1d58e, + 0x1d5c2, + 0x1d5f6, + 0x1d62a, + 0x1d65e, + 0x1d692, + 0x00131, + 0x1d6a4, + 0x0026a, + 0x00269, + 0x003b9, + 0x01fbe, + 0x0037a, + 0x1d6ca, + 0x1d704, + 0x1d73e, + 0x1d778, + 0x1d7b2, + 0x00456, + 0x0a647, + 0x004cf, + 0x0ab75, + 0x013a5, + 0x118c3, + 0x0ff4a, + 0x02149, + 0x1d423, + 0x1d457, + 0x1d48b, + 0x1d4bf, + 0x1d4f3, + 0x1d527, + 0x1d55b, + 0x1d58f, + 0x1d5c3, + 0x1d5f7, + 0x1d62b, + 0x1d65f, + 0x1d693, + 0x003f3, + 0x00458, + 0x0ff2a, + 0x1d409, + 0x1d43d, + 0x1d471, + 0x1d4a5, + 0x1d4d9, + 0x1d50d, + 0x1d541, + 0x1d575, + 0x1d5a9, + 0x1d5dd, + 0x1d611, + 0x1d645, + 0x1d679, + 0x0a7b2, + 0x0037f, + 0x00408, + 0x013ab, + 0x0148d, + 0x0a4d9, + 0x1d424, + 0x1d458, + 0x1d48c, + 0x1d4c0, + 0x1d4f4, + 0x1d528, + 0x1d55c, + 0x1d590, + 0x1d5c4, + 0x1d5f8, + 0x1d62c, + 0x1d660, + 0x1d694, + 0x0212a, + 0x0ff2b, + 0x1d40a, + 0x1d43e, + 0x1d472, + 0x1d4a6, + 0x1d4da, + 0x1d50e, + 0x1d542, + 0x1d576, + 0x1d5aa, + 0x1d5de, + 0x1d612, + 0x1d646, + 0x1d67a, + 0x0039a, + 0x1d6b1, + 0x1d6eb, + 0x1d725, + 0x1d75f, + 0x1d799, + 0x02c94, + 0x0041a, + 0x013e6, + 0x016d5, + 0x0a4d7, + 0x10518, + 0x005c0, + 0x0007c, + 0x02223, + 0x023fd, + 0x0ffe8, + 0x00031, + 0x00661, + 0x006f1, + 0x10320, + 0x1e8c7, + 0x1d7cf, + 0x1d7d9, + 0x1d7e3, + 0x1d7ed, + 0x1d7f7, + 0x00049, + 0x0ff29, + 0x02160, + 0x02110, + 0x02111, + 0x1d408, + 0x1d43c, + 0x1d470, + 0x1d4d8, + 0x1d540, + 0x1d574, + 0x1d5a8, + 0x1d5dc, + 0x1d610, + 0x1d644, + 0x1d678, + 0x00196, + 0x0ff4c, + 0x0217c, + 0x02113, + 0x1d425, + 0x1d459, + 0x1d48d, + 0x1d4c1, + 0x1d4f5, + 0x1d529, + 0x1d55d, + 0x1d591, + 0x1d5c5, + 0x1d5f9, + 0x1d62d, + 0x1d661, + 0x1d695, + 0x001c0, + 0x00399, + 0x1d6b0, + 0x1d6ea, + 0x1d724, + 0x1d75e, + 0x1d798, + 0x02c92, + 0x00406, + 0x004c0, + 0x005d5, 
+ 0x005df, + 0x00627, + 0x1ee00, + 0x1ee80, + 0x0fe8e, + 0x0fe8d, + 0x007ca, + 0x02d4f, + 0x016c1, + 0x0a4f2, + 0x16f28, + 0x1028a, + 0x10309, + 0x1d22a, + 0x0216c, + 0x02112, + 0x1d40b, + 0x1d43f, + 0x1d473, + 0x1d4db, + 0x1d50f, + 0x1d543, + 0x1d577, + 0x1d5ab, + 0x1d5df, + 0x1d613, + 0x1d647, + 0x1d67b, + 0x02cd0, + 0x013de, + 0x014aa, + 0x0a4e1, + 0x16f16, + 0x118a3, + 0x118b2, + 0x1041b, + 0x10526, + 0x0ff2d, + 0x0216f, + 0x02133, + 0x1d40c, + 0x1d440, + 0x1d474, + 0x1d4dc, + 0x1d510, + 0x1d544, + 0x1d578, + 0x1d5ac, + 0x1d5e0, + 0x1d614, + 0x1d648, + 0x1d67c, + 0x0039c, + 0x1d6b3, + 0x1d6ed, + 0x1d727, + 0x1d761, + 0x1d79b, + 0x003fa, + 0x02c98, + 0x0041c, + 0x013b7, + 0x015f0, + 0x016d6, + 0x0a4df, + 0x102b0, + 0x10311, + 0x1d427, + 0x1d45b, + 0x1d48f, + 0x1d4c3, + 0x1d4f7, + 0x1d52b, + 0x1d55f, + 0x1d593, + 0x1d5c7, + 0x1d5fb, + 0x1d62f, + 0x1d663, + 0x1d697, + 0x00578, + 0x0057c, + 0x0ff2e, + 0x02115, + 0x1d40d, + 0x1d441, + 0x1d475, + 0x1d4a9, + 0x1d4dd, + 0x1d511, + 0x1d579, + 0x1d5ad, + 0x1d5e1, + 0x1d615, + 0x1d649, + 0x1d67d, + 0x0039d, + 0x1d6b4, + 0x1d6ee, + 0x1d728, + 0x1d762, + 0x1d79c, + 0x02c9a, + 0x0a4e0, + 0x10513, + 0x00c02, + 0x00c82, + 0x00d02, + 0x00d82, + 0x00966, + 0x00a66, + 0x00ae6, + 0x00be6, + 0x00c66, + 0x00ce6, + 0x00d66, + 0x00e50, + 0x00ed0, + 0x01040, + 0x00665, + 0x006f5, + 0x0ff4f, + 0x02134, + 0x1d428, + 0x1d45c, + 0x1d490, + 0x1d4f8, + 0x1d52c, + 0x1d560, + 0x1d594, + 0x1d5c8, + 0x1d5fc, + 0x1d630, + 0x1d664, + 0x1d698, + 0x01d0f, + 0x01d11, + 0x0ab3d, + 0x003bf, + 0x1d6d0, + 0x1d70a, + 0x1d744, + 0x1d77e, + 0x1d7b8, + 0x003c3, + 0x1d6d4, + 0x1d70e, + 0x1d748, + 0x1d782, + 0x1d7bc, + 0x02c9f, + 0x0043e, + 0x010ff, + 0x00585, + 0x005e1, + 0x00647, + 0x1ee24, + 0x1ee64, + 0x1ee84, + 0x0feeb, + 0x0feec, + 0x0feea, + 0x0fee9, + 0x006be, + 0x0fbac, + 0x0fbad, + 0x0fbab, + 0x0fbaa, + 0x006c1, + 0x0fba8, + 0x0fba9, + 0x0fba7, + 0x0fba6, + 0x006d5, + 0x00d20, + 0x0101d, + 0x104ea, + 0x118c8, + 0x118d7, + 0x1042c, + 0x00030, + 
0x007c0, + 0x009e6, + 0x00b66, + 0x03007, + 0x114d0, + 0x118e0, + 0x1d7ce, + 0x1d7d8, + 0x1d7e2, + 0x1d7ec, + 0x1d7f6, + 0x0ff2f, + 0x1d40e, + 0x1d442, + 0x1d476, + 0x1d4aa, + 0x1d4de, + 0x1d512, + 0x1d546, + 0x1d57a, + 0x1d5ae, + 0x1d5e2, + 0x1d616, + 0x1d64a, + 0x1d67e, + 0x0039f, + 0x1d6b6, + 0x1d6f0, + 0x1d72a, + 0x1d764, + 0x1d79e, + 0x02c9e, + 0x0041e, + 0x00555, + 0x02d54, + 0x012d0, + 0x00b20, + 0x104c2, + 0x0a4f3, + 0x118b5, + 0x10292, + 0x102ab, + 0x10404, + 0x10516, + 0x02374, + 0x0ff50, + 0x1d429, + 0x1d45d, + 0x1d491, + 0x1d4c5, + 0x1d4f9, + 0x1d52d, + 0x1d561, + 0x1d595, + 0x1d5c9, + 0x1d5fd, + 0x1d631, + 0x1d665, + 0x1d699, + 0x003c1, + 0x003f1, + 0x1d6d2, + 0x1d6e0, + 0x1d70c, + 0x1d71a, + 0x1d746, + 0x1d754, + 0x1d780, + 0x1d78e, + 0x1d7ba, + 0x1d7c8, + 0x02ca3, + 0x00440, + 0x0ff30, + 0x02119, + 0x1d40f, + 0x1d443, + 0x1d477, + 0x1d4ab, + 0x1d4df, + 0x1d513, + 0x1d57b, + 0x1d5af, + 0x1d5e3, + 0x1d617, + 0x1d64b, + 0x1d67f, + 0x003a1, + 0x1d6b8, + 0x1d6f2, + 0x1d72c, + 0x1d766, + 0x1d7a0, + 0x02ca2, + 0x00420, + 0x013e2, + 0x0146d, + 0x0a4d1, + 0x10295, + 0x1d42a, + 0x1d45e, + 0x1d492, + 0x1d4c6, + 0x1d4fa, + 0x1d52e, + 0x1d562, + 0x1d596, + 0x1d5ca, + 0x1d5fe, + 0x1d632, + 0x1d666, + 0x1d69a, + 0x0051b, + 0x00563, + 0x00566, + 0x0211a, + 0x1d410, + 0x1d444, + 0x1d478, + 0x1d4ac, + 0x1d4e0, + 0x1d514, + 0x1d57c, + 0x1d5b0, + 0x1d5e4, + 0x1d618, + 0x1d64c, + 0x1d680, + 0x02d55, + 0x1d42b, + 0x1d45f, + 0x1d493, + 0x1d4c7, + 0x1d4fb, + 0x1d52f, + 0x1d563, + 0x1d597, + 0x1d5cb, + 0x1d5ff, + 0x1d633, + 0x1d667, + 0x1d69b, + 0x0ab47, + 0x0ab48, + 0x01d26, + 0x02c85, + 0x00433, + 0x0ab81, + 0x1d216, + 0x0211b, + 0x0211c, + 0x0211d, + 0x1d411, + 0x1d445, + 0x1d479, + 0x1d4e1, + 0x1d57d, + 0x1d5b1, + 0x1d5e5, + 0x1d619, + 0x1d64d, + 0x1d681, + 0x001a6, + 0x013a1, + 0x013d2, + 0x104b4, + 0x01587, + 0x0a4e3, + 0x16f35, + 0x0ff53, + 0x1d42c, + 0x1d460, + 0x1d494, + 0x1d4c8, + 0x1d4fc, + 0x1d530, + 0x1d564, + 0x1d598, + 0x1d5cc, + 0x1d600, + 0x1d634, + 0x1d668, 
+ 0x1d69c, + 0x0a731, + 0x001bd, + 0x00455, + 0x0abaa, + 0x118c1, + 0x10448, + 0x0ff33, + 0x1d412, + 0x1d446, + 0x1d47a, + 0x1d4ae, + 0x1d4e2, + 0x1d516, + 0x1d54a, + 0x1d57e, + 0x1d5b2, + 0x1d5e6, + 0x1d61a, + 0x1d64e, + 0x1d682, + 0x00405, + 0x0054f, + 0x013d5, + 0x013da, + 0x0a4e2, + 0x16f3a, + 0x10296, + 0x10420, + 0x1d42d, + 0x1d461, + 0x1d495, + 0x1d4c9, + 0x1d4fd, + 0x1d531, + 0x1d565, + 0x1d599, + 0x1d5cd, + 0x1d601, + 0x1d635, + 0x1d669, + 0x1d69d, + 0x022a4, + 0x027d9, + 0x1f768, + 0x0ff34, + 0x1d413, + 0x1d447, + 0x1d47b, + 0x1d4af, + 0x1d4e3, + 0x1d517, + 0x1d54b, + 0x1d57f, + 0x1d5b3, + 0x1d5e7, + 0x1d61b, + 0x1d64f, + 0x1d683, + 0x003a4, + 0x1d6bb, + 0x1d6f5, + 0x1d72f, + 0x1d769, + 0x1d7a3, + 0x02ca6, + 0x00422, + 0x013a2, + 0x0a4d4, + 0x16f0a, + 0x118bc, + 0x10297, + 0x102b1, + 0x10315, + 0x1d42e, + 0x1d462, + 0x1d496, + 0x1d4ca, + 0x1d4fe, + 0x1d532, + 0x1d566, + 0x1d59a, + 0x1d5ce, + 0x1d602, + 0x1d636, + 0x1d66a, + 0x1d69e, + 0x0a79f, + 0x01d1c, + 0x0ab4e, + 0x0ab52, + 0x0028b, + 0x003c5, + 0x1d6d6, + 0x1d710, + 0x1d74a, + 0x1d784, + 0x1d7be, + 0x0057d, + 0x104f6, + 0x118d8, + 0x0222a, + 0x022c3, + 0x1d414, + 0x1d448, + 0x1d47c, + 0x1d4b0, + 0x1d4e4, + 0x1d518, + 0x1d54c, + 0x1d580, + 0x1d5b4, + 0x1d5e8, + 0x1d61c, + 0x1d650, + 0x1d684, + 0x0054d, + 0x01200, + 0x104ce, + 0x0144c, + 0x0a4f4, + 0x16f42, + 0x118b8, + 0x02228, + 0x022c1, + 0x0ff56, + 0x02174, + 0x1d42f, + 0x1d463, + 0x1d497, + 0x1d4cb, + 0x1d4ff, + 0x1d533, + 0x1d567, + 0x1d59b, + 0x1d5cf, + 0x1d603, + 0x1d637, + 0x1d66b, + 0x1d69f, + 0x01d20, + 0x003bd, + 0x1d6ce, + 0x1d708, + 0x1d742, + 0x1d77c, + 0x1d7b6, + 0x00475, + 0x005d8, + 0x11706, + 0x0aba9, + 0x118c0, + 0x1d20d, + 0x00667, + 0x006f7, + 0x02164, + 0x1d415, + 0x1d449, + 0x1d47d, + 0x1d4b1, + 0x1d4e5, + 0x1d519, + 0x1d54d, + 0x1d581, + 0x1d5b5, + 0x1d5e9, + 0x1d61d, + 0x1d651, + 0x1d685, + 0x00474, + 0x02d38, + 0x013d9, + 0x0142f, + 0x0a6df, + 0x0a4e6, + 0x16f08, + 0x118a0, + 0x1051d, + 0x0026f, + 0x1d430, + 0x1d464, + 
0x1d498, + 0x1d4cc, + 0x1d500, + 0x1d534, + 0x1d568, + 0x1d59c, + 0x1d5d0, + 0x1d604, + 0x1d638, + 0x1d66c, + 0x1d6a0, + 0x01d21, + 0x00461, + 0x0051d, + 0x00561, + 0x1170a, + 0x1170e, + 0x1170f, + 0x0ab83, + 0x118ef, + 0x118e6, + 0x1d416, + 0x1d44a, + 0x1d47e, + 0x1d4b2, + 0x1d4e6, + 0x1d51a, + 0x1d54e, + 0x1d582, + 0x1d5b6, + 0x1d5ea, + 0x1d61e, + 0x1d652, + 0x1d686, + 0x0051c, + 0x013b3, + 0x013d4, + 0x0a4ea, + 0x0166e, + 0x000d7, + 0x0292b, + 0x0292c, + 0x02a2f, + 0x0ff58, + 0x02179, + 0x1d431, + 0x1d465, + 0x1d499, + 0x1d4cd, + 0x1d501, + 0x1d535, + 0x1d569, + 0x1d59d, + 0x1d5d1, + 0x1d605, + 0x1d639, + 0x1d66d, + 0x1d6a1, + 0x00445, + 0x01541, + 0x0157d, + 0x0166d, + 0x02573, + 0x10322, + 0x118ec, + 0x0ff38, + 0x02169, + 0x1d417, + 0x1d44b, + 0x1d47f, + 0x1d4b3, + 0x1d4e7, + 0x1d51b, + 0x1d54f, + 0x1d583, + 0x1d5b7, + 0x1d5eb, + 0x1d61f, + 0x1d653, + 0x1d687, + 0x0a7b3, + 0x003a7, + 0x1d6be, + 0x1d6f8, + 0x1d732, + 0x1d76c, + 0x1d7a6, + 0x02cac, + 0x00425, + 0x02d5d, + 0x016b7, + 0x0a4eb, + 0x10290, + 0x102b4, + 0x10317, + 0x10527, + 0x00263, + 0x01d8c, + 0x0ff59, + 0x1d432, + 0x1d466, + 0x1d49a, + 0x1d4ce, + 0x1d502, + 0x1d536, + 0x1d56a, + 0x1d59e, + 0x1d5d2, + 0x1d606, + 0x1d63a, + 0x1d66e, + 0x1d6a2, + 0x0028f, + 0x01eff, + 0x0ab5a, + 0x003b3, + 0x0213d, + 0x1d6c4, + 0x1d6fe, + 0x1d738, + 0x1d772, + 0x1d7ac, + 0x00443, + 0x004af, + 0x010e7, + 0x118dc, + 0x0ff39, + 0x1d418, + 0x1d44c, + 0x1d480, + 0x1d4b4, + 0x1d4e8, + 0x1d51c, + 0x1d550, + 0x1d584, + 0x1d5b8, + 0x1d5ec, + 0x1d620, + 0x1d654, + 0x1d688, + 0x003a5, + 0x003d2, + 0x1d6bc, + 0x1d6f6, + 0x1d730, + 0x1d76a, + 0x1d7a4, + 0x02ca8, + 0x00423, + 0x004ae, + 0x013a9, + 0x013bd, + 0x0a4ec, + 0x16f43, + 0x118a4, + 0x102b2, + 0x1d433, + 0x1d467, + 0x1d49b, + 0x1d4cf, + 0x1d503, + 0x1d537, + 0x1d56b, + 0x1d59f, + 0x1d5d3, + 0x1d607, + 0x1d63b, + 0x1d66f, + 0x1d6a3, + 0x01d22, + 0x0ab93, + 0x118c4, + 0x102f5, + 0x118e5, + 0x0ff3a, + 0x02124, + 0x02128, + 0x1d419, + 0x1d44d, + 0x1d481, + 0x1d4b5, + 0x1d4e9, 
+ 0x1d585, + 0x1d5b9, + 0x1d5ed, + 0x1d621, + 0x1d655, + 0x1d689, + 0x00396, + 0x1d6ad, + 0x1d6e7, + 0x1d721, + 0x1d75b, + 0x1d795, + 0x013c3, + 0x0a4dc, + 0x118a9, +}; + +static gboolean +rspamd_can_alias_latin(gint ch) +{ + return latin_confusable.contains(ch); +} + +static gdouble +rspamd_chartable_process_word_utf(struct rspamd_task *task, + rspamd_stat_token_t *w, + gboolean is_url, + guint *ncap, + struct chartable_ctx *chartable_module_ctx, + gboolean ignore_diacritics) +{ + const UChar32 *p, *end; + gdouble badness = 0.0; + UChar32 uc; + UBlockCode sc; + guint cat; + gint last_is_latin = -1; + guint same_script_count = 0, nsym = 0, nspecial = 0; + enum { + start_process = 0, + got_alpha, + got_digit, + got_unknown, + } state = start_process, + prev_state = start_process; + + p = w->unicode.begin; + end = p + w->unicode.len; + + /* We assume that w is normalized */ + + while (p < end) { + uc = *p++; + + if (((gint32) uc) < 0) { + break; + } + + sc = ublock_getCode(uc); + cat = u_charType(uc); + + if (!ignore_diacritics) { + if (cat == U_NON_SPACING_MARK || + (sc == UBLOCK_LATIN_1_SUPPLEMENT) || + (sc == UBLOCK_LATIN_EXTENDED_A) || + (sc == UBLOCK_LATIN_EXTENDED_ADDITIONAL) || + (sc == UBLOCK_LATIN_EXTENDED_B) || + (sc == UBLOCK_COMBINING_DIACRITICAL_MARKS)) { + nspecial++; + } + } + + if (u_isalpha(uc)) { + + if (sc <= UBLOCK_COMBINING_DIACRITICAL_MARKS || + sc == UBLOCK_LATIN_EXTENDED_ADDITIONAL) { + /* + * Assume all latin, IPA, diacritic and space modifiers + * characters as basic latin + */ + sc = UBLOCK_BASIC_LATIN; + } + + if (sc != UBLOCK_BASIC_LATIN && u_isupper(uc)) { + if (ncap) { + (*ncap)++; + } + } + + if (state == got_digit) { + /* Penalize digit -> alpha translations */ + if (!is_url && sc != UBLOCK_BASIC_LATIN && + prev_state != start_process) { + badness += 0.25; + } + } + else if (state == got_alpha) { + /* Check script */ + if (same_script_count > 0) { + if (sc != UBLOCK_BASIC_LATIN && last_is_latin) { + + if (rspamd_can_alias_latin(uc)) { 
+ badness += 1.0 / (gdouble) same_script_count; + } + + last_is_latin = 0; + same_script_count = 1; + } + else { + same_script_count++; + } + } + else { + last_is_latin = sc == UBLOCK_BASIC_LATIN; + same_script_count = 1; + } + } + + prev_state = state; + state = got_alpha; + } + else if (u_isdigit(uc)) { + if (state != got_digit) { + prev_state = state; + } + + state = got_digit; + same_script_count = 0; + } + else { + /* We don't care about unknown characters here */ + if (state != got_unknown) { + prev_state = state; + } + + state = got_unknown; + same_script_count = 0; + } + + nsym++; + } + + if (nspecial > 0) { + if (!ignore_diacritics) { + /* Count diacritics */ + badness += nspecial; + } + else if (nspecial > 1) { + badness += (nspecial - 1.0) / 2.0; + } + } + + /* Try to avoid FP for long words */ + if (nsym > chartable_module_ctx->max_word_len) { + badness = 0; + } + else { + if (badness > 4.0) { + badness = 4.0; + } + } + + msg_debug_chartable("word %*s, badness: %.2f", + (gint) w->normalized.len, w->normalized.begin, + badness); + + return badness; +} + +static gdouble +rspamd_chartable_process_word_ascii(struct rspamd_task *task, + rspamd_stat_token_t *w, + gboolean is_url, + struct chartable_ctx *chartable_module_ctx) +{ + gdouble badness = 0.0; + enum { + ascii = 1, + non_ascii + } sc, + last_sc; + gint same_script_count = 0, seen_alpha = FALSE; + enum { + start_process = 0, + got_alpha, + got_digit, + got_unknown, + } state = start_process; + + const auto *p = (const unsigned char *) w->normalized.begin; + const auto *end = p + w->normalized.len; + last_sc = non_ascii; + + if (w->normalized.len > chartable_module_ctx->max_word_len) { + return 0.0; + } + + /* We assume that w is normalized */ + while (p < end) { + if (g_ascii_isalpha(*p) || *p > 0x7f) { + + if (state == got_digit) { + /* Penalize digit -> alpha translations */ + if (seen_alpha && !is_url && !g_ascii_isxdigit(*p)) { + badness += 0.25; + } + } + else if (state == got_alpha) { + /* Check 
script */ + sc = (*p > 0x7f) ? ascii : non_ascii; + + if (same_script_count > 0) { + if (sc != last_sc) { + badness += 1.0 / (gdouble) same_script_count; + last_sc = sc; + same_script_count = 1; + } + else { + same_script_count++; + } + } + else { + last_sc = sc; + same_script_count = 1; + } + } + + seen_alpha = TRUE; + state = got_alpha; + } + else if (g_ascii_isdigit(*p)) { + state = got_digit; + same_script_count = 0; + } + else { + /* We don't care about unknown characters here */ + state = got_unknown; + same_script_count = 0; + } + + p++; + } + + if (badness > 4.0) { + badness = 4.0; + } + + msg_debug_chartable("word %*s, badness: %.2f", + (gint) w->normalized.len, w->normalized.begin, + badness); + + return badness; +} + +static gboolean +rspamd_chartable_process_part(struct rspamd_task *task, + struct rspamd_mime_text_part *part, + struct chartable_ctx *chartable_module_ctx, + gboolean ignore_diacritics) +{ + rspamd_stat_token_t *w; + guint i, ncap = 0; + gdouble cur_score = 0.0; + + if (part == nullptr || part->utf_words == nullptr || + part->utf_words->len == 0 || part->nwords == 0) { + return FALSE; + } + + for (i = 0; i < part->utf_words->len; i++) { + w = &g_array_index(part->utf_words, rspamd_stat_token_t, i); + + if ((w->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT)) { + + if (w->flags & RSPAMD_STAT_TOKEN_FLAG_UTF) { + cur_score += rspamd_chartable_process_word_utf(task, w, FALSE, + &ncap, chartable_module_ctx, ignore_diacritics); + } + else { + cur_score += rspamd_chartable_process_word_ascii(task, w, + FALSE, chartable_module_ctx); + } + } + } + + /* + * TODO: perhaps, we should do this analysis somewhere else and get + * something like: <SYM_SC><SYM_SC><SYM_SC> representing classes for all + * symbols in the text + */ + part->capital_letters += ncap; + + cur_score /= (gdouble) part->nwords; + + if (cur_score > 1.0) { + cur_score = 1.0; + } + + if (cur_score > chartable_module_ctx->threshold) { + rspamd_task_insert_result(task, chartable_module_ctx->symbol, 
+ cur_score, nullptr); + return TRUE; + } + + return FALSE; +} + +static void +chartable_symbol_callback(struct rspamd_task *task, + struct rspamd_symcache_dynamic_item *item, + void *_) +{ + guint i; + struct rspamd_mime_text_part *part; + struct chartable_ctx *chartable_module_ctx = chartable_get_context(task->cfg); + gboolean ignore_diacritics = TRUE, seen_violated_part = FALSE; + + /* Check if we have parts with diacritic symbols language */ + PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part) + { + if (part->languages && part->languages->len > 0) { + auto *lang = (struct rspamd_lang_detector_res *) g_ptr_array_index(part->languages, 0); + gint flags; + + flags = rspamd_language_detector_elt_flags(lang->elt); + + if ((flags & RS_LANGUAGE_DIACRITICS)) { + ignore_diacritics = TRUE; + } + else if (lang->prob > 0.75) { + ignore_diacritics = FALSE; + } + } + + if (rspamd_chartable_process_part(task, part, chartable_module_ctx, ignore_diacritics)) { + seen_violated_part = TRUE; + } + } + + if (MESSAGE_FIELD(task, text_parts)->len == 0) { + /* No text parts, assume that we should ignore diacritics checks for metatokens */ + ignore_diacritics = TRUE; + } + + if (task->meta_words != nullptr && task->meta_words->len > 0) { + rspamd_stat_token_t *w; + gdouble cur_score = 0; + gsize arlen = task->meta_words->len; + + for (i = 0; i < arlen; i++) { + w = &g_array_index(task->meta_words, rspamd_stat_token_t, i); + cur_score += rspamd_chartable_process_word_utf(task, w, FALSE, + nullptr, chartable_module_ctx, ignore_diacritics); + } + + cur_score /= (gdouble) (arlen + 1); + + if (cur_score > 1.0) { + cur_score = 1.0; + } + + if (cur_score > chartable_module_ctx->threshold) { + if (!seen_violated_part) { + /* Further penalise */ + if (cur_score > 0.25) { + cur_score = 0.25; + } + } + + rspamd_task_insert_result(task, chartable_module_ctx->symbol, + cur_score, "subject"); + } + } + + rspamd_symcache_finalize_item(task, item); +} + +static void 
+chartable_url_symbol_callback(struct rspamd_task *task, + struct rspamd_symcache_dynamic_item *item, + void *unused) +{ + /* XXX: TODO: unbreak module once URLs unicode project is over */ +#if 0 + struct rspamd_url *u; + GHashTableIter it; + gpointer k, v; + rspamd_stat_token_t w; + gdouble cur_score = 0.0; + struct chartable_ctx *chartable_module_ctx = chartable_get_context (task->cfg); + + g_hash_table_iter_init (&it, task->urls); + + while (g_hash_table_iter_next (&it, &k, &v)) { + u = v; + + if (cur_score > 2.0) { + cur_score = 2.0; + break; + } + + if (u->hostlen > 0) { + w.stemmed.begin = u->host; + w.stemmed.len = u->hostlen; + + if (g_utf8_validate (w.stemmed.begin, w.stemmed.len, nullptr)) { + cur_score += rspamd_chartable_process_word_utf (task, &w, + TRUE, nullptr, chartable_module_ctx); + } + else { + cur_score += rspamd_chartable_process_word_ascii (task, &w, + TRUE, chartable_module_ctx); + } + } + } + + g_hash_table_iter_init (&it, task->emails); + + while (g_hash_table_iter_next (&it, &k, &v)) { + u = v; + + if (cur_score > 2.0) { + cur_score = 2.0; + break; + } + + if (u->hostlen > 0) { + w.stemmed.begin = u->host; + w.stemmed.len = u->hostlen; + + if (g_utf8_validate (w.stemmed.begin, w.stemmed.len, nullptr)) { + cur_score += rspamd_chartable_process_word_utf (task, &w, + TRUE, nullptr, chartable_module_ctx); + } + else { + cur_score += rspamd_chartable_process_word_ascii (task, &w, + TRUE, chartable_module_ctx); + } + } + } + + if (cur_score > chartable_module_ctx->threshold) { + rspamd_task_insert_result (task, chartable_module_ctx->symbol, + cur_score, nullptr); + + } +#endif + rspamd_symcache_finalize_item(task, item); +} diff --git a/src/plugins/dkim_check.c b/src/plugins/dkim_check.c new file mode 100644 index 0000000..29ab34d --- /dev/null +++ b/src/plugins/dkim_check.c @@ -0,0 +1,1620 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/***MODULE:dkim + * rspamd module that checks dkim records of incoming email + * + * Allowed options: + * - symbol_allow (string): symbol to insert in case of allow (default: 'R_DKIM_ALLOW') + * - symbol_reject (string): symbol to insert (default: 'R_DKIM_REJECT') + * - symbol_tempfail (string): symbol to insert in case of temporary fail (default: 'R_DKIM_TEMPFAIL') + * - symbol_permfail (string): symbol to insert in case of permanent failure (default: 'R_DKIM_PERMFAIL') + * - symbol_na (string): symbol to insert in case of no signing (default: 'R_DKIM_NA') + * - whitelist (map): map of whitelisted networks + * - domains (map): map of domains to check + * - strict_multiplier (number): multiplier for strict domains + * - time_jitter (number): jitter in seconds to allow time diff while checking + * - trusted_only (flag): check signatures only for domains in 'domains' map + */ + + +#include "config.h" +#include "libmime/message.h" +#include "libserver/dkim.h" +#include "libutil/hash.h" +#include "libserver/maps/map.h" +#include "libserver/maps/map_helpers.h" +#include "rspamd.h" +#include "utlist.h" +#include "unix-std.h" +#include "lua/lua_common.h" +#include "libserver/mempool_vars_internal.h" + +#define DEFAULT_SYMBOL_REJECT "R_DKIM_REJECT" +#define DEFAULT_SYMBOL_TEMPFAIL "R_DKIM_TEMPFAIL" +#define DEFAULT_SYMBOL_ALLOW "R_DKIM_ALLOW" +#define DEFAULT_SYMBOL_NA "R_DKIM_NA" +#define DEFAULT_SYMBOL_PERMFAIL "R_DKIM_PERMFAIL" +#define DEFAULT_CACHE_SIZE 2048 +#define DEFAULT_TIME_JITTER 60 +#define DEFAULT_MAX_SIGS 5 
+ +static const gchar *M = "rspamd dkim plugin"; + +static const gchar default_sign_headers[] = "" + "(o)from:(x)sender:(o)reply-to:(o)subject:(x)date:(x)message-id:" + "(o)to:(o)cc:(x)mime-version:(x)content-type:(x)content-transfer-encoding:" + "resent-to:resent-cc:resent-from:resent-sender:resent-message-id:" + "(x)in-reply-to:(x)references:list-id:list-help:list-owner:list-unsubscribe:" + "list-unsubscribe-post:list-subscribe:list-post:(x)openpgp:(x)autocrypt"; +static const gchar default_arc_sign_headers[] = "" + "(o)from:(x)sender:(o)reply-to:(o)subject:(x)date:(x)message-id:" + "(o)to:(o)cc:(x)mime-version:(x)content-type:(x)content-transfer-encoding:" + "resent-to:resent-cc:resent-from:resent-sender:resent-message-id:" + "(x)in-reply-to:(x)references:list-id:list-help:list-owner:list-unsubscribe:" + "list-unsubscribe-post:list-subscribe:list-post:dkim-signature:(x)openpgp:" + "(x)autocrypt"; + +struct dkim_ctx { + struct module_ctx ctx; + const gchar *symbol_reject; + const gchar *symbol_tempfail; + const gchar *symbol_allow; + const gchar *symbol_na; + const gchar *symbol_permfail; + + struct rspamd_radix_map_helper *whitelist_ip; + struct rspamd_hash_map_helper *dkim_domains; + guint strict_multiplier; + guint time_jitter; + rspamd_lru_hash_t *dkim_hash; + rspamd_lru_hash_t *dkim_sign_hash; + const gchar *sign_headers; + const gchar *arc_sign_headers; + guint max_sigs; + gboolean trusted_only; + gboolean check_local; + gboolean check_authed; +}; + +struct dkim_check_result { + rspamd_dkim_context_t *ctx; + rspamd_dkim_key_t *key; + struct rspamd_task *task; + struct rspamd_dkim_check_result *res; + gdouble mult_allow; + gdouble mult_deny; + struct rspamd_symcache_dynamic_item *item; + struct dkim_check_result *next, *prev, *first; +}; + +static void dkim_symbol_callback(struct rspamd_task *task, + struct rspamd_symcache_dynamic_item *item, + void *unused); + +static gint lua_dkim_sign_handler(lua_State *L); +static gint lua_dkim_verify_handler(lua_State 
*L); +static gint lua_dkim_canonicalize_handler(lua_State *L); + +/* Initialization */ +gint dkim_module_init(struct rspamd_config *cfg, struct module_ctx **ctx); +gint dkim_module_config(struct rspamd_config *cfg, bool validate); +gint dkim_module_reconfig(struct rspamd_config *cfg); + +module_t dkim_module = { + "dkim", + dkim_module_init, + dkim_module_config, + dkim_module_reconfig, + NULL, + RSPAMD_MODULE_VER, + (guint) -1, +}; + +static inline struct dkim_ctx * +dkim_get_context(struct rspamd_config *cfg) +{ + return (struct dkim_ctx *) g_ptr_array_index(cfg->c_modules, + dkim_module.ctx_offset); +} + +static void +dkim_module_key_dtor(gpointer k) +{ + rspamd_dkim_key_t *key = k; + + rspamd_dkim_key_unref(key); +} + +static void +dkim_module_free_list(gpointer k) +{ + g_list_free_full((GList *) k, rspamd_gstring_free_hard); +} + +gint dkim_module_init(struct rspamd_config *cfg, struct module_ctx **ctx) +{ + struct dkim_ctx *dkim_module_ctx; + + dkim_module_ctx = rspamd_mempool_alloc0(cfg->cfg_pool, + sizeof(*dkim_module_ctx)); + dkim_module_ctx->sign_headers = default_sign_headers; + dkim_module_ctx->arc_sign_headers = default_arc_sign_headers; + dkim_module_ctx->max_sigs = DEFAULT_MAX_SIGS; + + *ctx = (struct module_ctx *) dkim_module_ctx; + + rspamd_rcl_add_doc_by_path(cfg, + NULL, + "DKIM check plugin", + "dkim", + UCL_OBJECT, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "dkim", + "Map of IP addresses that should be excluded from DKIM checks", + "whitelist", + UCL_STRING, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "dkim", + "Symbol that is added if DKIM check is successful", + "symbol_allow", + UCL_STRING, + NULL, + 0, + DEFAULT_SYMBOL_ALLOW, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "dkim", + "Symbol that is added if DKIM check is unsuccessful", + "symbol_reject", + UCL_STRING, + NULL, + 0, + DEFAULT_SYMBOL_REJECT, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "dkim", + "Symbol that is added if DKIM check can't be 
completed (e.g. DNS failure)", + "symbol_tempfail", + UCL_STRING, + NULL, + 0, + DEFAULT_SYMBOL_TEMPFAIL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "dkim", + "Symbol that is added if mail is not signed", + "symbol_na", + UCL_STRING, + NULL, + 0, + DEFAULT_SYMBOL_NA, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "dkim", + "Symbol that is added if permanent failure encountered", + "symbol_permfail", + UCL_STRING, + NULL, + 0, + DEFAULT_SYMBOL_PERMFAIL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "dkim", + "Size of DKIM keys cache", + "dkim_cache_size", + UCL_INT, + NULL, + 0, + G_STRINGIFY(DEFAULT_CACHE_SIZE), + 0); + rspamd_rcl_add_doc_by_path(cfg, + "dkim", + "Allow this time difference when checking DKIM signature time validity", + "time_jitter", + UCL_TIME, + NULL, + 0, + G_STRINGIFY(DEFAULT_TIME_JITTER), + 0); + rspamd_rcl_add_doc_by_path(cfg, + "dkim", + "Domains to check DKIM for (check all domains if this option is empty)", + "domains", + UCL_STRING, + NULL, + 0, + "empty", + 0); + rspamd_rcl_add_doc_by_path(cfg, + "dkim", + "Map of domains that are treated as 'trusted' meaning that DKIM policy failure has more significant score", + "trusted_domains", + UCL_STRING, + NULL, + 0, + "empty", + 0); + rspamd_rcl_add_doc_by_path(cfg, + "dkim", + "Multiply dkim score by this factor for trusted domains", + "strict_multiplier", + UCL_FLOAT, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "dkim", + "Check DKIM policies merely for `trusted_domains`", + "trusted_only", + UCL_BOOLEAN, + NULL, + 0, + "false", + 0); + rspamd_rcl_add_doc_by_path(cfg, + "dkim", + "Lua script that tells if a message should be signed and with what params (obsoleted)", + "sign_condition", + UCL_STRING, + NULL, + 0, + "empty", + 0); + rspamd_rcl_add_doc_by_path(cfg, + "dkim", + "Obsoleted: maximum number of DKIM signatures to check", + "max_sigs", + UCL_INT, + NULL, + 0, + "n/a", + 0); + rspamd_rcl_add_doc_by_path(cfg, + "dkim", + "Headers used in signing", + "sign_headers", + 
UCL_STRING, + NULL, + 0, + default_sign_headers, + 0); + + return 0; +} + +gint dkim_module_config(struct rspamd_config *cfg, bool validate) +{ + const ucl_object_t *value; + gint res = TRUE, cb_id = -1; + guint cache_size, sign_cache_size; + gboolean got_trusted = FALSE; + struct dkim_ctx *dkim_module_ctx = dkim_get_context(cfg); + + /* Register the dkim methods table inside the global rspamd_plugins table (only if that global is a table) */ + lua_getglobal(cfg->lua_state, "rspamd_plugins"); + + if (lua_type(cfg->lua_state, -1) == LUA_TTABLE) { + lua_pushstring(cfg->lua_state, "dkim"); + lua_createtable(cfg->lua_state, 0, 1); + /* Set methods: sign, verify, canon_header_relaxed */ + lua_pushstring(cfg->lua_state, "sign"); + lua_pushcfunction(cfg->lua_state, lua_dkim_sign_handler); + lua_settable(cfg->lua_state, -3); + lua_pushstring(cfg->lua_state, "verify"); + lua_pushcfunction(cfg->lua_state, lua_dkim_verify_handler); + lua_settable(cfg->lua_state, -3); + lua_pushstring(cfg->lua_state, "canon_header_relaxed"); + lua_pushcfunction(cfg->lua_state, lua_dkim_canonicalize_handler); + lua_settable(cfg->lua_state, -3); + /* Store the methods table under the "dkim" key of rspamd_plugins */ + lua_settable(cfg->lua_state, -3); + } + + lua_pop(cfg->lua_state, 1); /* Pop the rspamd_plugins value pushed by lua_getglobal() */ + dkim_module_ctx->whitelist_ip = NULL; + + value = rspamd_config_get_module_opt(cfg, "dkim", "check_local"); + + if (value == NULL) { + value = rspamd_config_get_module_opt(cfg, "options", "check_local"); + } + + if (value != NULL) { + dkim_module_ctx->check_local = ucl_object_toboolean(value); + } + else { + dkim_module_ctx->check_local = FALSE; + } + + value = rspamd_config_get_module_opt(cfg, "dkim", + "check_authed"); + + if (value == NULL) { + value = rspamd_config_get_module_opt(cfg, "options", + "check_authed"); + } + + if (value != NULL) { + dkim_module_ctx->check_authed = ucl_object_toboolean(value); + } + else { + dkim_module_ctx->check_authed = FALSE; + } + if ((value = + rspamd_config_get_module_opt(cfg, "dkim", "symbol_reject")) != NULL) { + dkim_module_ctx->symbol_reject = ucl_object_tostring(value); + } + else { + 
dkim_module_ctx->symbol_reject = DEFAULT_SYMBOL_REJECT; + } + if ((value = + rspamd_config_get_module_opt(cfg, "dkim", + "symbol_tempfail")) != NULL) { + dkim_module_ctx->symbol_tempfail = ucl_object_tostring(value); + } + else { + dkim_module_ctx->symbol_tempfail = DEFAULT_SYMBOL_TEMPFAIL; + } + if ((value = + rspamd_config_get_module_opt(cfg, "dkim", "symbol_allow")) != NULL) { + dkim_module_ctx->symbol_allow = ucl_object_tostring(value); + } + else { + dkim_module_ctx->symbol_allow = DEFAULT_SYMBOL_ALLOW; + } + if ((value = + rspamd_config_get_module_opt(cfg, "dkim", "symbol_na")) != NULL) { + dkim_module_ctx->symbol_na = ucl_object_tostring(value); + } + else { + dkim_module_ctx->symbol_na = DEFAULT_SYMBOL_NA; + } + if ((value = + rspamd_config_get_module_opt(cfg, "dkim", "symbol_permfail")) != NULL) { + dkim_module_ctx->symbol_permfail = ucl_object_tostring(value); + } + else { + dkim_module_ctx->symbol_permfail = DEFAULT_SYMBOL_PERMFAIL; + } + if ((value = + rspamd_config_get_module_opt(cfg, "dkim", + "dkim_cache_size")) != NULL) { + cache_size = ucl_object_toint(value); + } + else { + cache_size = DEFAULT_CACHE_SIZE; + } + + if ((value = + rspamd_config_get_module_opt(cfg, "dkim", + "sign_cache_size")) != NULL) { + sign_cache_size = ucl_object_toint(value); + } + else { + sign_cache_size = 128; + } + + if ((value = + rspamd_config_get_module_opt(cfg, "dkim", "time_jitter")) != NULL) { + dkim_module_ctx->time_jitter = ucl_object_todouble(value); + } + else { + dkim_module_ctx->time_jitter = DEFAULT_TIME_JITTER; + } + + if ((value = + rspamd_config_get_module_opt(cfg, "dkim", "max_sigs")) != NULL) { + dkim_module_ctx->max_sigs = ucl_object_toint(value); + } + + if ((value = + rspamd_config_get_module_opt(cfg, "dkim", "whitelist")) != NULL) { + + rspamd_config_radix_from_ucl(cfg, value, "DKIM whitelist", + &dkim_module_ctx->whitelist_ip, NULL, NULL, "dkim whitelist"); + } + + if ((value = + rspamd_config_get_module_opt(cfg, "dkim", "domains")) != NULL) { + if 
(!rspamd_map_add_from_ucl(cfg, value, + "DKIM domains", + rspamd_kv_list_read, + rspamd_kv_list_fin, + rspamd_kv_list_dtor, + (void **) &dkim_module_ctx->dkim_domains, + NULL, RSPAMD_MAP_DEFAULT)) { + msg_warn_config("cannot load dkim domains list from %s", + ucl_object_tostring(value)); + } + else { + got_trusted = TRUE; + } + } + + if (!got_trusted && (value = + rspamd_config_get_module_opt(cfg, "dkim", "trusted_domains")) != NULL) { + if (!rspamd_map_add_from_ucl(cfg, value, + "DKIM domains", + rspamd_kv_list_read, + rspamd_kv_list_fin, + rspamd_kv_list_dtor, + (void **) &dkim_module_ctx->dkim_domains, + NULL, RSPAMD_MAP_DEFAULT)) { + msg_warn_config("cannot load dkim domains list from %s", + ucl_object_tostring(value)); + + if (validate) { + return FALSE; + } + } + else { + got_trusted = TRUE; + } + } + + if ((value = + rspamd_config_get_module_opt(cfg, "dkim", + "strict_multiplier")) != NULL) { + dkim_module_ctx->strict_multiplier = ucl_object_toint(value); + } + else { + dkim_module_ctx->strict_multiplier = 1; + } + + if ((value = + rspamd_config_get_module_opt(cfg, "dkim", "trusted_only")) != NULL) { + dkim_module_ctx->trusted_only = ucl_object_toboolean(value); + } + else { + dkim_module_ctx->trusted_only = FALSE; + } + + if ((value = + rspamd_config_get_module_opt(cfg, "dkim", "sign_headers")) != NULL) { + dkim_module_ctx->sign_headers = ucl_object_tostring(value); + } + + if ((value = + rspamd_config_get_module_opt(cfg, "arc", "sign_headers")) != NULL) { + dkim_module_ctx->arc_sign_headers = ucl_object_tostring(value); + } + + if (cache_size > 0) { + dkim_module_ctx->dkim_hash = rspamd_lru_hash_new( + cache_size, + g_free, + dkim_module_key_dtor); + rspamd_mempool_add_destructor(cfg->cfg_pool, + (rspamd_mempool_destruct_t) rspamd_lru_hash_destroy, + dkim_module_ctx->dkim_hash); + } + + if (sign_cache_size > 0) { + dkim_module_ctx->dkim_sign_hash = rspamd_lru_hash_new( + sign_cache_size, + g_free, + (GDestroyNotify) rspamd_dkim_sign_key_unref); + 
rspamd_mempool_add_destructor(cfg->cfg_pool, + (rspamd_mempool_destruct_t) rspamd_lru_hash_destroy, + dkim_module_ctx->dkim_sign_hash); + } + + if (dkim_module_ctx->trusted_only && !got_trusted) { + msg_err_config("trusted_only option is set and no trusted domains are defined"); + if (validate) { + return FALSE; + } + } + else { + if (!rspamd_config_is_module_enabled(cfg, "dkim")) { + return TRUE; + } + + cb_id = rspamd_symcache_add_symbol(cfg->cache, + "DKIM_CHECK", + 0, + dkim_symbol_callback, + NULL, + SYMBOL_TYPE_CALLBACK, + -1); + rspamd_config_add_symbol(cfg, + "DKIM_CHECK", + 0.0, + "DKIM check callback", + "policies", + RSPAMD_SYMBOL_FLAG_IGNORE_METRIC, + 1, + 1); + rspamd_config_add_symbol_group(cfg, "DKIM_CHECK", "dkim"); + rspamd_symcache_add_symbol(cfg->cache, + dkim_module_ctx->symbol_reject, + 0, + NULL, + NULL, + SYMBOL_TYPE_VIRTUAL | SYMBOL_TYPE_FINE, + cb_id); + rspamd_symcache_add_symbol(cfg->cache, + dkim_module_ctx->symbol_na, + 0, + NULL, NULL, + SYMBOL_TYPE_VIRTUAL | SYMBOL_TYPE_FINE, + cb_id); + rspamd_symcache_add_symbol(cfg->cache, + dkim_module_ctx->symbol_permfail, + 0, + NULL, NULL, + SYMBOL_TYPE_VIRTUAL | SYMBOL_TYPE_FINE, + cb_id); + rspamd_symcache_add_symbol(cfg->cache, + dkim_module_ctx->symbol_tempfail, + 0, + NULL, NULL, + SYMBOL_TYPE_VIRTUAL | SYMBOL_TYPE_FINE, + cb_id); + rspamd_symcache_add_symbol(cfg->cache, + dkim_module_ctx->symbol_allow, + 0, + NULL, NULL, + SYMBOL_TYPE_VIRTUAL | SYMBOL_TYPE_FINE, + cb_id); + + rspamd_symcache_add_symbol(cfg->cache, + "DKIM_TRACE", + 0, + NULL, NULL, + SYMBOL_TYPE_VIRTUAL | SYMBOL_TYPE_NOSTAT, + cb_id); + rspamd_config_add_symbol(cfg, + "DKIM_TRACE", + 0.0, + "DKIM trace symbol", + "policies", + RSPAMD_SYMBOL_FLAG_IGNORE_METRIC, + 1, + 1); + rspamd_config_add_symbol_group(cfg, "DKIM_TRACE", "dkim"); + + msg_info_config("init internal dkim module"); +#ifndef HAVE_OPENSSL + msg_warn_config( + "openssl is not found so dkim rsa check is disabled, only check body hash, it is NOT safe to trust 
these results"); +#endif + } + + return res; +} + + +/** + * Grab a private key from the cache + * or from the key content provided + */ +rspamd_dkim_sign_key_t * +dkim_module_load_key_format(struct rspamd_task *task, + struct dkim_ctx *dkim_module_ctx, + const gchar *key, gsize keylen, + enum rspamd_dkim_key_format key_format) + +{ + guchar h[rspamd_cryptobox_HASHBYTES], + hex_hash[rspamd_cryptobox_HASHBYTES * 2 + 1]; + rspamd_dkim_sign_key_t *ret = NULL; + GError *err = NULL; + struct stat st; + + memset(hex_hash, 0, sizeof(hex_hash)); + rspamd_cryptobox_hash(h, key, keylen, NULL, 0); + rspamd_encode_hex_buf(h, sizeof(h), hex_hash, sizeof(hex_hash)); + + if (dkim_module_ctx->dkim_sign_hash) { + ret = rspamd_lru_hash_lookup(dkim_module_ctx->dkim_sign_hash, + hex_hash, time(NULL)); + } + + /* + * This fails for paths that are also valid base64. + * Maybe the caller should have specified a format. + */ + if (key_format == RSPAMD_DKIM_KEY_UNKNOWN) { + if (key[0] == '.' || key[0] == '/') { + if (!rspamd_cryptobox_base64_is_valid(key, keylen)) { + key_format = RSPAMD_DKIM_KEY_FILE; + } + } + else if (rspamd_cryptobox_base64_is_valid(key, keylen)) { + key_format = RSPAMD_DKIM_KEY_BASE64; + } + } + + + if (ret != NULL && key_format == RSPAMD_DKIM_KEY_FILE) { + msg_debug_task("checking for stale file key"); + + if (stat(key, &st) != 0) { + msg_err_task("cannot stat key file: %s", strerror(errno)); + return NULL; + } + + if (rspamd_dkim_sign_key_maybe_invalidate(ret, st.st_mtime)) { + msg_debug_task("removing stale file key"); + /* + * Invalidate DKIM key + * removal from lru cache also cleanup the key and value + */ + if (dkim_module_ctx->dkim_sign_hash) { + rspamd_lru_hash_remove(dkim_module_ctx->dkim_sign_hash, + hex_hash); + } + ret = NULL; + } + } + + /* found key; done */ + if (ret != NULL) { + return ret; + } + + ret = rspamd_dkim_sign_key_load(key, keylen, key_format, &err); + + if (ret == NULL) { + msg_err_task("cannot load dkim key %s: %e", + key, err); + 
g_error_free(err); + } + else if (dkim_module_ctx->dkim_sign_hash) { + rspamd_lru_hash_insert(dkim_module_ctx->dkim_sign_hash, + g_strdup(hex_hash), ret, time(NULL), 0); + } + + return ret; +} + +static gint +lua_dkim_sign_handler(lua_State *L) +{ + struct rspamd_task *task = lua_check_task(L, 1); + gint64 arc_idx = 0, expire = 0; + enum rspamd_dkim_type sign_type = RSPAMD_DKIM_NORMAL; + GError *err = NULL; + GString *hdr; + GList *sigs = NULL; + const gchar *selector = NULL, *domain = NULL, *key = NULL, *rawkey = NULL, + *headers = NULL, *sign_type_str = NULL, *arc_cv = NULL, + *pubkey = NULL; + rspamd_dkim_sign_context_t *ctx; + rspamd_dkim_sign_key_t *dkim_key; + gsize rawlen = 0, keylen = 0; + gboolean no_cache = FALSE, strict_pubkey_check = FALSE; + struct dkim_ctx *dkim_module_ctx; + + luaL_argcheck(L, lua_type(L, 2) == LUA_TTABLE, 2, "'table' expected"); + /* + * Get the following elements: + * - selector + * - domain + * - key + */ + if (!rspamd_lua_parse_table_arguments(L, 2, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT, + "key=V;rawkey=V;*domain=S;*selector=S;no_cache=B;headers=S;" + "sign_type=S;arc_idx=I;arc_cv=S;expire=I;pubkey=S;" + "strict_pubkey_check=B", + &keylen, &key, &rawlen, &rawkey, &domain, + &selector, &no_cache, &headers, + &sign_type_str, &arc_idx, &arc_cv, &expire, &pubkey, + &strict_pubkey_check)) { + msg_err_task("cannot parse table arguments: %e", + err); + g_error_free(err); + + lua_pushboolean(L, FALSE); + return 1; + } + + dkim_module_ctx = dkim_get_context(task->cfg); + + if (key) { + dkim_key = dkim_module_load_key_format(task, dkim_module_ctx, key, + keylen, RSPAMD_DKIM_KEY_UNKNOWN); + } + else if (rawkey) { + dkim_key = dkim_module_load_key_format(task, dkim_module_ctx, rawkey, + rawlen, RSPAMD_DKIM_KEY_UNKNOWN); + } + else { + msg_err_task("neither key nor rawkey are specified"); + lua_pushboolean(L, FALSE); + + return 1; + } + + if (dkim_key == NULL) { + lua_pushboolean(L, FALSE); + return 1; + } + + if (sign_type_str) { + if 
(strcmp(sign_type_str, "dkim") == 0) { + sign_type = RSPAMD_DKIM_NORMAL; + + if (headers == NULL) { + headers = dkim_module_ctx->sign_headers; + } + } + else if (strcmp(sign_type_str, "arc-sign") == 0) { + sign_type = RSPAMD_DKIM_ARC_SIG; + + if (headers == NULL) { + headers = dkim_module_ctx->arc_sign_headers; + } + + if (arc_idx == 0) { + lua_settop(L, 0); + return luaL_error(L, "no arc idx specified"); + } + } + else if (strcmp(sign_type_str, "arc-seal") == 0) { + sign_type = RSPAMD_DKIM_ARC_SEAL; + if (arc_cv == NULL) { + lua_settop(L, 0); + return luaL_error(L, "no arc cv specified"); + } + if (arc_idx == 0) { + lua_settop(L, 0); + return luaL_error(L, "no arc idx specified"); + } + } + else { + lua_settop(L, 0); + return luaL_error(L, "unknown sign type: %s", + sign_type_str); + } + } + else { + /* Unspecified sign type, assume plain dkim */ + if (headers == NULL) { + headers = dkim_module_ctx->sign_headers; + } + } + + if (pubkey != NULL) { + /* Also check if private and public keys match */ + rspamd_dkim_key_t *pk; + keylen = strlen(pubkey); + + pk = rspamd_dkim_parse_key(pubkey, &keylen, NULL); + + if (pk == NULL) { + if (strict_pubkey_check) { + msg_err_task("cannot parse pubkey from string: %s, skip signing", + pubkey); + lua_pushboolean(L, FALSE); + + return 1; + } + else { + msg_warn_task("cannot parse pubkey from string: %s", + pubkey); + } + } + else { + GError *te = NULL; + + /* We have parsed the key, so try to check keys */ + if (!rspamd_dkim_match_keys(pk, dkim_key, &te)) { + if (strict_pubkey_check) { + msg_err_task("public key for %s/%s does not match private " + "key: %e, skip signing", + domain, selector, te); + g_error_free(te); + lua_pushboolean(L, FALSE); + rspamd_dkim_key_unref(pk); + + return 1; + } + else { + msg_warn_task("public key for %s/%s does not match private " + "key: %e", + domain, selector, te); + g_error_free(te); + } + } + + rspamd_dkim_key_unref(pk); + } + } + + ctx = rspamd_create_dkim_sign_context(task, dkim_key, + 
DKIM_CANON_RELAXED, DKIM_CANON_RELAXED, + headers, sign_type, &err); + + if (ctx == NULL) { + msg_err_task("cannot create sign context: %e", + err); + g_error_free(err); + + lua_pushboolean(L, FALSE); + return 1; + } + + hdr = rspamd_dkim_sign(task, selector, domain, 0, + expire, arc_idx, arc_cv, ctx); + + if (hdr) { + + if (!no_cache) { + sigs = rspamd_mempool_get_variable(task->task_pool, "dkim-signature"); + + if (sigs == NULL) { + sigs = g_list_append(sigs, hdr); + rspamd_mempool_set_variable(task->task_pool, "dkim-signature", + sigs, dkim_module_free_list); + } + else { + sigs = g_list_append(sigs, hdr); + (void) sigs; + } + } + + lua_pushboolean(L, TRUE); + lua_pushlstring(L, hdr->str, hdr->len); + + if (no_cache) { + g_string_free(hdr, TRUE); + } + + return 2; + } + + + lua_pushboolean(L, FALSE); + lua_pushnil(L); + + return 2; +} + +gint dkim_module_reconfig(struct rspamd_config *cfg) +{ + return dkim_module_config(cfg, false); +} + +/* + * Parse a strict value for a domain in the format 'deny_multiplier:allow_multiplier'. + * The text before the first ':' is parsed with strtod() and stored in *deny, + * the text after it is parsed the same way and stored in *allow; each part is + * first copied into a 64-byte scratch buffer. + * Returns TRUE only if a colon is present and strtod() consumed both parts + * completely, FALSE otherwise. Note that *deny may already have been written + * when only the second part is rejected. + */ +static gboolean +dkim_module_parse_strict(const gchar *value, gdouble *allow, gdouble *deny) +{ + const gchar *colon; + gchar *err = NULL; + gdouble val; + gchar numbuf[64]; + + colon = strchr(value, ':'); + if (colon) { + rspamd_strlcpy(numbuf, value, + MIN(sizeof(numbuf), (colon - value) + 1)); + val = strtod(numbuf, &err); + + if (err == NULL || *err == '\0') { + *deny = val; + colon++; + rspamd_strlcpy(numbuf, colon, sizeof(numbuf)); + err = NULL; + val = strtod(numbuf, &err); + + if (err == NULL || *err == '\0') { + *allow = val; + return TRUE; + } + } + } + return FALSE; +} + +static void +dkim_module_check(struct dkim_check_result *res) +{ + gboolean all_done = TRUE; + const gchar *strict_value; + struct dkim_check_result *first, *cur = NULL; + struct dkim_ctx *dkim_module_ctx = dkim_get_context(res->task->cfg); + struct rspamd_task *task = res->task; + + first = res->first; + + DL_FOREACH(first, cur) + { + if (cur->ctx == NULL) { + 
continue; + } + + if (cur->key != NULL && cur->res == NULL) { + cur->res = rspamd_dkim_check(cur->ctx, cur->key, task); + + if (dkim_module_ctx->dkim_domains != NULL) { + /* Perform strict check */ + const gchar *domain = rspamd_dkim_get_domain(cur->ctx); + + if ((strict_value = + rspamd_match_hash_map(dkim_module_ctx->dkim_domains, + domain, + strlen(domain))) != NULL) { + if (!dkim_module_parse_strict(strict_value, &cur->mult_allow, + &cur->mult_deny)) { + cur->mult_allow = dkim_module_ctx->strict_multiplier; + cur->mult_deny = dkim_module_ctx->strict_multiplier; + } + } + } + } + } + + DL_FOREACH(first, cur) + { + if (cur->ctx == NULL) { + continue; + } + if (cur->res == NULL) { + /* Still need a key */ + all_done = FALSE; + } + } + + if (all_done) { + /* Create zero terminated array of results */ + struct rspamd_dkim_check_result **pres; + guint nres = 0, i = 0; + + DL_FOREACH(first, cur) + { + if (cur->ctx == NULL || cur->res == NULL) { + continue; + } + + nres++; + } + + pres = rspamd_mempool_alloc(task->task_pool, sizeof(*pres) * (nres + 1)); + pres[nres] = NULL; + + DL_FOREACH(first, cur) + { + const gchar *symbol = NULL, *trace = NULL; + gdouble symbol_weight = 1.0; + + if (cur->ctx == NULL || cur->res == NULL) { + continue; + } + + pres[i++] = cur->res; + + if (cur->res->rcode == DKIM_REJECT) { + symbol = dkim_module_ctx->symbol_reject; + trace = "-"; + symbol_weight = cur->mult_deny * 1.0; + } + else if (cur->res->rcode == DKIM_CONTINUE) { + symbol = dkim_module_ctx->symbol_allow; + trace = "+"; + symbol_weight = cur->mult_allow * 1.0; + } + else if (cur->res->rcode == DKIM_PERM_ERROR) { + trace = "~"; + symbol = dkim_module_ctx->symbol_permfail; + } + else if (cur->res->rcode == DKIM_TRYAGAIN) { + trace = "?"; + symbol = dkim_module_ctx->symbol_tempfail; + } + + if (symbol != NULL) { + const gchar *domain = rspamd_dkim_get_domain(cur->ctx); + const gchar *selector = rspamd_dkim_get_selector(cur->ctx); + gsize tracelen; + gchar *tracebuf; + + tracelen = 
strlen(domain) + strlen(selector) + 4; + tracebuf = rspamd_mempool_alloc(task->task_pool, + tracelen); + rspamd_snprintf(tracebuf, tracelen, "%s:%s", domain, trace); + + rspamd_task_insert_result(cur->task, + "DKIM_TRACE", + 0.0, + tracebuf); + + rspamd_snprintf(tracebuf, tracelen, "%s:s=%s", domain, selector); + rspamd_task_insert_result(task, + symbol, + symbol_weight, + tracebuf); + } + } + + rspamd_mempool_set_variable(task->task_pool, + RSPAMD_MEMPOOL_DKIM_CHECK_RESULTS, + pres, NULL); + } +} + +static void +dkim_module_key_handler(rspamd_dkim_key_t *key, + gsize keylen, + rspamd_dkim_context_t *ctx, + gpointer ud, + GError *err) +{ + struct dkim_check_result *res = ud; + struct rspamd_task *task; + struct dkim_ctx *dkim_module_ctx; + + task = res->task; + dkim_module_ctx = dkim_get_context(task->cfg); + + if (key != NULL) { + /* Take an extra reference that is stored in the check result */ + res->key = rspamd_dkim_key_ref(key); + /* + * We actually receive the key with refcount = 1, so we just assume that + * the LRU hash owns this object once it is inserted below + */ + /* Release the extra reference when the task pool is destroyed */ + rspamd_mempool_add_destructor(res->task->task_pool, + dkim_module_key_dtor, res->key); + + if (dkim_module_ctx->dkim_hash) { + rspamd_lru_hash_insert(dkim_module_ctx->dkim_hash, + g_strdup(rspamd_dkim_get_dns_key(ctx)), + key, res->task->task_timestamp, rspamd_dkim_key_get_ttl(key)); + + msg_info_task("stored DKIM key for %s in LRU cache for %d seconds, " + "%d/%d elements in the cache", + rspamd_dkim_get_dns_key(ctx), + rspamd_dkim_key_get_ttl(key), + rspamd_lru_hash_size(dkim_module_ctx->dkim_hash), + rspamd_lru_hash_capacity(dkim_module_ctx->dkim_hash)); + } + } + else { + /* No key: log it and record a failure result chosen by the error code (no symbol is inserted here). NOTE(review): when err is NULL no result is created, unlike dkim_module_lua_on_key - confirm this is intended */ + msg_info_task("cannot get key for domain %s: %e", + rspamd_dkim_get_dns_key(ctx), err); + + if (err != NULL) { + if (err->code == DKIM_SIGERROR_NOKEY) { + res->res = rspamd_dkim_create_result(ctx, DKIM_TRYAGAIN, task); + res->res->fail_reason = "DNS error when getting key"; + } + else { + 
res->res = rspamd_dkim_create_result(ctx, DKIM_PERM_ERROR, task); + res->res->fail_reason = "invalid DKIM record"; + } + } + } + + if (err) { + g_error_free(err); + } + + dkim_module_check(res); +} + +static void +dkim_symbol_callback(struct rspamd_task *task, + struct rspamd_symcache_dynamic_item *item, + void *unused) +{ + rspamd_dkim_context_t *ctx; + rspamd_dkim_key_t *key; + GError *err = NULL; + struct rspamd_mime_header *rh, *rh_cur; + struct dkim_check_result *res = NULL, *cur; + guint checked = 0; + gdouble *dmarc_checks; + struct dkim_ctx *dkim_module_ctx = dkim_get_context(task->cfg); + + /* Allow dmarc */ + dmarc_checks = rspamd_mempool_get_variable(task->task_pool, + RSPAMD_MEMPOOL_DMARC_CHECKS); + + if (dmarc_checks) { + (*dmarc_checks)++; + } + else { + dmarc_checks = rspamd_mempool_alloc(task->task_pool, + sizeof(*dmarc_checks)); + *dmarc_checks = 1; + rspamd_mempool_set_variable(task->task_pool, + RSPAMD_MEMPOOL_DMARC_CHECKS, + dmarc_checks, NULL); + } + + /* First check if plugin should be enabled */ + if ((!dkim_module_ctx->check_authed && task->auth_user != NULL) || (!dkim_module_ctx->check_local && + rspamd_ip_is_local_cfg(task->cfg, task->from_addr))) { + msg_info_task("skip DKIM checks for local networks and authorized users"); + rspamd_symcache_finalize_item(task, item); + + return; + } + /* Check whitelist */ + if (rspamd_match_radix_map_addr(dkim_module_ctx->whitelist_ip, + task->from_addr) != NULL) { + msg_info_task("skip DKIM checks for whitelisted address"); + rspamd_symcache_finalize_item(task, item); + + return; + } + + rspamd_symcache_item_async_inc(task, item, M); + + /* Now check if a message has its signature */ + rh = rspamd_message_get_header_array(task, RSPAMD_DKIM_SIGNHEADER, FALSE); + if (rh) { + msg_debug_task("dkim signature found"); + + DL_FOREACH(rh, rh_cur) + { + if (rh_cur->decoded == NULL || rh_cur->decoded[0] == '\0') { + msg_info_task("cannot load empty DKIM signature"); + continue; + } + + cur = 
rspamd_mempool_alloc0(task->task_pool, sizeof(*cur)); + cur->first = res; + cur->res = NULL; + cur->task = task; + cur->mult_allow = 1.0; + cur->mult_deny = 1.0; + cur->item = item; + + ctx = rspamd_create_dkim_context(rh_cur->decoded, + task->task_pool, + task->resolver, + dkim_module_ctx->time_jitter, + RSPAMD_DKIM_NORMAL, + &err); + + if (res == NULL) { + res = cur; + res->first = res; + res->prev = res; + } + else { + DL_APPEND(res, cur); + } + + if (ctx == NULL) { + if (err != NULL) { + msg_info_task("cannot parse DKIM signature: %e", + err); + g_error_free(err); + err = NULL; + } + else { + msg_info_task("cannot parse DKIM signature: " + "unknown error"); + } + + continue; + } + else { + /* Get key */ + cur->ctx = ctx; + const gchar *domain = rspamd_dkim_get_domain(cur->ctx); + + if (dkim_module_ctx->trusted_only && + (dkim_module_ctx->dkim_domains == NULL || + rspamd_match_hash_map(dkim_module_ctx->dkim_domains, + domain, strlen(domain)) == NULL)) { + msg_debug_task("skip dkim check for %s domain", + rspamd_dkim_get_domain(ctx)); + + continue; + } + + if (dkim_module_ctx->dkim_hash) { + key = rspamd_lru_hash_lookup(dkim_module_ctx->dkim_hash, + rspamd_dkim_get_dns_key(ctx), + task->task_timestamp); + } + else { + key = NULL; + } + + if (key != NULL) { + cur->key = rspamd_dkim_key_ref(key); + /* Release key when task is processed */ + rspamd_mempool_add_destructor(task->task_pool, + dkim_module_key_dtor, cur->key); + } + else { + if (!rspamd_get_dkim_key(ctx, + task, + dkim_module_key_handler, + cur)) { + continue; + } + } + } + + checked++; + + if (checked > dkim_module_ctx->max_sigs) { + msg_info_task("message has multiple signatures but we" + " stopped after %d checked signatures as limit" + " is reached", + checked); + break; + } + } + } + else { + rspamd_task_insert_result(task, + dkim_module_ctx->symbol_na, + 1.0, + NULL); + } + + if (res != NULL) { + dkim_module_check(res); + } + + rspamd_symcache_item_async_dec_check(task, item, M); +} + +struct 
rspamd_dkim_lua_verify_cbdata { + rspamd_dkim_context_t *ctx; + struct rspamd_task *task; + lua_State *L; + rspamd_dkim_key_t *key; + gint cbref; +}; + +static void +dkim_module_lua_push_verify_result(struct rspamd_dkim_lua_verify_cbdata *cbd, + struct rspamd_dkim_check_result *res, GError *err) +{ + struct rspamd_task **ptask, *task; + const gchar *error_str = "unknown error"; + gboolean success = FALSE; + + task = cbd->task; + + switch (res->rcode) { + case DKIM_CONTINUE: + error_str = NULL; + success = TRUE; + break; + case DKIM_REJECT: + if (err) { + error_str = err->message; + } + else { + error_str = "reject"; + } + break; + case DKIM_TRYAGAIN: + if (err) { + error_str = err->message; + } + else { + error_str = "tempfail"; + } + break; + case DKIM_NOTFOUND: + if (err) { + error_str = err->message; + } + else { + error_str = "not found"; + } + break; + case DKIM_RECORD_ERROR: + if (err) { + error_str = err->message; + } + else { + error_str = "bad record"; + } + break; + case DKIM_PERM_ERROR: + if (err) { + error_str = err->message; + } + else { + error_str = "permanent error"; + } + break; + default: + break; + } + + lua_rawgeti(cbd->L, LUA_REGISTRYINDEX, cbd->cbref); + ptask = lua_newuserdata(cbd->L, sizeof(*ptask)); + *ptask = task; + lua_pushboolean(cbd->L, success); + + if (error_str) { + lua_pushstring(cbd->L, error_str); + } + else { + lua_pushnil(cbd->L); + } + + if (cbd->ctx) { + if (res->domain) { + lua_pushstring(cbd->L, res->domain); + } + else { + lua_pushnil(cbd->L); + } + + if (res->selector) { + lua_pushstring(cbd->L, res->selector); + } + else { + lua_pushnil(cbd->L); + } + + if (res->short_b) { + lua_pushstring(cbd->L, res->short_b); + } + else { + lua_pushnil(cbd->L); + } + + if (res->fail_reason) { + lua_pushstring(cbd->L, res->fail_reason); + } + else { + lua_pushnil(cbd->L); + } + } + else { + lua_pushnil(cbd->L); + lua_pushnil(cbd->L); + lua_pushnil(cbd->L); + lua_pushnil(cbd->L); + } + + if (lua_pcall(cbd->L, 7, 0, 0) != 0) { + 
msg_err_task("call to verify callback failed: %s", + lua_tostring(cbd->L, -1)); + lua_pop(cbd->L, 1); + } + + luaL_unref(cbd->L, LUA_REGISTRYINDEX, cbd->cbref); +} + +static void +dkim_module_lua_on_key(rspamd_dkim_key_t *key, + gsize keylen, + rspamd_dkim_context_t *ctx, + gpointer ud, + GError *err) +{ + struct rspamd_dkim_lua_verify_cbdata *cbd = ud; + struct rspamd_task *task; + struct rspamd_dkim_check_result *res; + struct dkim_ctx *dkim_module_ctx; + + task = cbd->task; + dkim_module_ctx = dkim_get_context(task->cfg); + + if (key != NULL) { + /* Another ref belongs to the check context */ + cbd->key = rspamd_dkim_key_ref(key); + /* + * We actually receive key with refcount = 1, so we just assume that + * lru hash owns this object now + */ + + if (dkim_module_ctx->dkim_hash) { + rspamd_lru_hash_insert(dkim_module_ctx->dkim_hash, + g_strdup(rspamd_dkim_get_dns_key(ctx)), + key, cbd->task->task_timestamp, rspamd_dkim_key_get_ttl(key)); + } + /* Release key when task is processed */ + rspamd_mempool_add_destructor(cbd->task->task_pool, + dkim_module_key_dtor, cbd->key); + } + else { + /* Insert tempfail symbol */ + msg_info_task("cannot get key for domain %s: %e", + rspamd_dkim_get_dns_key(ctx), err); + + if (err != NULL) { + if (err->code == DKIM_SIGERROR_NOKEY) { + res = rspamd_dkim_create_result(ctx, DKIM_TRYAGAIN, task); + res->fail_reason = "DNS error when getting key"; + } + else { + res = rspamd_dkim_create_result(ctx, DKIM_PERM_ERROR, task); + res->fail_reason = "invalid DKIM record"; + } + } + else { + res = rspamd_dkim_create_result(ctx, DKIM_TRYAGAIN, task); + res->fail_reason = "DNS error when getting key"; + } + + dkim_module_lua_push_verify_result(cbd, res, err); + + if (err) { + g_error_free(err); + } + + return; + } + + res = rspamd_dkim_check(cbd->ctx, cbd->key, cbd->task); + dkim_module_lua_push_verify_result(cbd, res, NULL); +} + +static gint +lua_dkim_verify_handler(lua_State *L) +{ + struct rspamd_task *task = lua_check_task(L, 1); + const 
gchar *sig = luaL_checkstring(L, 2); + rspamd_dkim_context_t *ctx; + struct rspamd_dkim_lua_verify_cbdata *cbd; + rspamd_dkim_key_t *key; + struct rspamd_dkim_check_result *ret; + GError *err = NULL; + const gchar *type_str = NULL; + enum rspamd_dkim_type type = RSPAMD_DKIM_NORMAL; + struct dkim_ctx *dkim_module_ctx; + + if (task && sig && lua_isfunction(L, 3)) { + if (lua_isstring(L, 4)) { + type_str = lua_tostring(L, 4); + + if (type_str) { + if (strcmp(type_str, "dkim") == 0) { + type = RSPAMD_DKIM_NORMAL; + } + else if (strcmp(type_str, "arc-sign") == 0) { + type = RSPAMD_DKIM_ARC_SIG; + } + else if (strcmp(type_str, "arc-seal") == 0) { + type = RSPAMD_DKIM_ARC_SEAL; + } + else { + lua_settop(L, 0); + return luaL_error(L, "unknown sign type: %s", + type_str); + } + } + } + + dkim_module_ctx = dkim_get_context(task->cfg); + + ctx = rspamd_create_dkim_context(sig, + task->task_pool, + task->resolver, + dkim_module_ctx->time_jitter, + type, + &err); + + if (ctx == NULL) { + lua_pushboolean(L, false); + + if (err) { + lua_pushstring(L, err->message); + g_error_free(err); + } + else { + lua_pushstring(L, "unknown error"); + } + + return 2; + } + + cbd = rspamd_mempool_alloc(task->task_pool, sizeof(*cbd)); + cbd->L = L; + cbd->task = task; + lua_pushvalue(L, 3); + cbd->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + cbd->ctx = ctx; + cbd->key = NULL; + + if (dkim_module_ctx->dkim_hash) { + key = rspamd_lru_hash_lookup(dkim_module_ctx->dkim_hash, + rspamd_dkim_get_dns_key(ctx), + task->task_timestamp); + } + else { + key = NULL; + } + + if (key != NULL) { + cbd->key = rspamd_dkim_key_ref(key); + /* Release key when task is processed */ + rspamd_mempool_add_destructor(task->task_pool, + dkim_module_key_dtor, cbd->key); + ret = rspamd_dkim_check(cbd->ctx, cbd->key, cbd->task); + dkim_module_lua_push_verify_result(cbd, ret, NULL); + } + else { + rspamd_get_dkim_key(ctx, + task, + dkim_module_lua_on_key, + cbd); + } + } + else { + return luaL_error(L, "invalid arguments"); + } 
+ + lua_pushboolean(L, TRUE); + lua_pushnil(L); + + return 2; +} + +static gint +lua_dkim_canonicalize_handler(lua_State *L) +{ + gsize nlen, vlen; + const gchar *hname = luaL_checklstring(L, 1, &nlen), + *hvalue = luaL_checklstring(L, 2, &vlen); + static gchar st_buf[8192]; + gchar *buf; + guint inlen; + gboolean allocated = FALSE; + goffset r; + + if (hname && hvalue && nlen > 0) { + inlen = nlen + vlen + sizeof(":" CRLF); + + if (inlen > sizeof(st_buf)) { + buf = g_malloc(inlen); + allocated = TRUE; + } + else { + /* Faster */ + buf = st_buf; + } + + r = rspamd_dkim_canonize_header_relaxed_str(hname, hvalue, buf, inlen); + + if (r == -1) { + lua_pushnil(L); + } + else { + lua_pushlstring(L, buf, r); + } + + if (allocated) { + g_free(buf); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} diff --git a/src/plugins/fuzzy_check.c b/src/plugins/fuzzy_check.c new file mode 100644 index 0000000..85db83d --- /dev/null +++ b/src/plugins/fuzzy_check.c @@ -0,0 +1,4695 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +/***MODULE:fuzzy + * rspamd module that checks fuzzy checksums for messages + * + * Allowed options: + * - symbol (string): symbol to insert (default: 'R_FUZZY') + * - max_score (double): maximum score to that weights of hashes would be normalized (default: 0 - no normalization) + * + * - fuzzy_map (string): a string that contains map in format { fuzzy_key => [ symbol, weight ] } where fuzzy_key is number of + * fuzzy list. This string itself should be in format 1:R_FUZZY_SAMPLE1:10,2:R_FUZZY_SAMPLE2:1 etc, where first number is fuzzy + * key, second is symbol to insert and third - weight for normalization + * + * - min_length (integer): minimum length (in characters) for text part to be checked for fuzzy hash (default: 0 - no limit) + * - whitelist (map string): map of ip addresses that should not be checked with this module + * - servers (string): list of fuzzy servers in format "server1:port,server2:port" - these servers would be used for checking and storing + * fuzzy hashes + */ + +#include "config.h" +#include "libmime/message.h" +#include "libserver/maps/map.h" +#include "libserver/maps/map_helpers.h" +#include "libmime/images.h" +#include "libserver/worker_util.h" +#include "libserver/mempool_vars_internal.h" +#include "fuzzy_wire.h" +#include "utlist.h" +#include "ottery.h" +#include "lua/lua_common.h" +#include "unix-std.h" +#include "libserver/http/http_private.h" +#include "libserver/http/http_router.h" +#include "libstat/stat_api.h" +#include <math.h> +#include "libutil/libev_helper.h" + +#define DEFAULT_SYMBOL "R_FUZZY_HASH" + +#define DEFAULT_IO_TIMEOUT 1.0 +#define DEFAULT_RETRANSMITS 3 +#define DEFAULT_MAX_ERRORS 4 +#define DEFAULT_REVIVE_TIME 60 +#define DEFAULT_PORT 11335 + +#define RSPAMD_FUZZY_PLUGIN_VERSION RSPAMD_FUZZY_VERSION + +static const gint rspamd_fuzzy_hash_len = 5; +static const gchar *M = "fuzzy check"; +struct fuzzy_ctx; + +struct fuzzy_mapping { + guint64 fuzzy_flag; + const gchar *symbol; + double weight; +}; + +struct 
fuzzy_rule { + struct upstream_list *servers; + const gchar *symbol; + const gchar *algorithm_str; + const gchar *name; + const ucl_object_t *ucl_obj; + enum rspamd_shingle_alg alg; + GHashTable *mappings; + GPtrArray *fuzzy_headers; + GString *hash_key; + GString *shingles_key; + gdouble io_timeout; + struct rspamd_cryptobox_keypair *local_key; + struct rspamd_cryptobox_pubkey *peer_key; + double max_score; + double weight_threshold; + gboolean read_only; + gboolean skip_unknown; + gboolean no_share; + gboolean no_subject; + gint learn_condition_cb; + guint32 retransmits; + struct rspamd_hash_map_helper *skip_map; + struct fuzzy_ctx *ctx; + gint lua_id; +}; + +struct fuzzy_ctx { + struct module_ctx ctx; + rspamd_mempool_t *fuzzy_pool; + GPtrArray *fuzzy_rules; + struct rspamd_config *cfg; + const gchar *default_symbol; + struct rspamd_radix_map_helper *whitelist; + struct rspamd_keypair_cache *keypairs_cache; + guint max_errors; + gdouble revive_time; + gdouble io_timeout; + gint check_mime_part_ref; /* Lua callback */ + gint process_rule_ref; /* Lua callback */ + gint cleanup_rules_ref; + guint32 retransmits; + gboolean enabled; +}; + +enum fuzzy_result_type { + FUZZY_RESULT_TXT, + FUZZY_RESULT_IMG, + FUZZY_RESULT_CONTENT, + FUZZY_RESULT_BIN +}; + +struct fuzzy_client_result { + const gchar *symbol; + gchar *option; + gdouble score; + gdouble prob; + enum fuzzy_result_type type; +}; + +struct fuzzy_client_session { + GPtrArray *commands; + GPtrArray *results; + struct rspamd_task *task; + struct rspamd_symcache_dynamic_item *item; + struct upstream *server; + struct fuzzy_rule *rule; + struct ev_loop *event_loop; + struct rspamd_io_ev ev; + gint state; + gint fd; + guint retransmits; +}; + +struct fuzzy_learn_session { + GPtrArray *commands; + gint *saved; + struct { + const gchar *error_message; + gint error_code; + } err; + struct rspamd_http_connection_entry *http_entry; + struct rspamd_async_session *session; + struct upstream *server; + struct fuzzy_rule 
*rule; + struct rspamd_task *task; + struct ev_loop *event_loop; + struct rspamd_io_ev ev; + gint fd; + guint retransmits; +}; + +#define FUZZY_CMD_FLAG_REPLIED (1 << 0) +#define FUZZY_CMD_FLAG_SENT (1 << 1) +#define FUZZY_CMD_FLAG_IMAGE (1 << 2) +#define FUZZY_CMD_FLAG_CONTENT (1 << 3) + +#define FUZZY_CHECK_FLAG_NOIMAGES (1 << 0) +#define FUZZY_CHECK_FLAG_NOATTACHMENTS (1 << 1) +#define FUZZY_CHECK_FLAG_NOTEXT (1 << 2) + +struct fuzzy_cmd_io { + guint32 tag; + guint32 flags; + struct iovec io; + struct rspamd_mime_part *part; + struct rspamd_fuzzy_cmd cmd; +}; + + +static const char *default_headers = "Subject,Content-Type,Reply-To,X-Mailer"; + +static void fuzzy_symbol_callback(struct rspamd_task *task, + struct rspamd_symcache_dynamic_item *item, + void *unused); + +/* Initialization */ +gint fuzzy_check_module_init(struct rspamd_config *cfg, + struct module_ctx **ctx); +gint fuzzy_check_module_config(struct rspamd_config *cfg, bool valdate); +gint fuzzy_check_module_reconfig(struct rspamd_config *cfg); +static gint fuzzy_attach_controller(struct module_ctx *ctx, + GHashTable *commands); +static gint fuzzy_lua_learn_handler(lua_State *L); +static gint fuzzy_lua_unlearn_handler(lua_State *L); +static gint fuzzy_lua_gen_hashes_handler(lua_State *L); +static gint fuzzy_lua_hex_hashes_handler(lua_State *L); +static gint fuzzy_lua_list_storages(lua_State *L); +static gint fuzzy_lua_ping_storage(lua_State *L); + +module_t fuzzy_check_module = { + "fuzzy_check", + fuzzy_check_module_init, + fuzzy_check_module_config, + fuzzy_check_module_reconfig, + fuzzy_attach_controller, + RSPAMD_MODULE_VER, + (guint) -1, +}; + +static inline struct fuzzy_ctx * +fuzzy_get_context(struct rspamd_config *cfg) +{ + return (struct fuzzy_ctx *) g_ptr_array_index(cfg->c_modules, + fuzzy_check_module.ctx_offset); +} + +static void +parse_flags(struct fuzzy_rule *rule, + struct rspamd_config *cfg, + const ucl_object_t *val, + gint cb_id) +{ + const ucl_object_t *elt; + struct fuzzy_mapping 
*map; + const gchar *sym = NULL; + + if (val->type == UCL_STRING) { + msg_err_config( + "string mappings are deprecated and no longer supported, use new style configuration"); + } + else if (val->type == UCL_OBJECT) { + elt = ucl_object_lookup(val, "symbol"); + if (elt == NULL || !ucl_object_tostring_safe(elt, &sym)) { + sym = ucl_object_key(val); + } + if (sym != NULL) { + map = + rspamd_mempool_alloc(cfg->cfg_pool, + sizeof(struct fuzzy_mapping)); + map->symbol = sym; + elt = ucl_object_lookup(val, "flag"); + + if (elt != NULL) { + map->fuzzy_flag = ucl_obj_toint(elt); + + elt = ucl_object_lookup(val, "max_score"); + + if (elt != NULL) { + map->weight = ucl_obj_todouble(elt); + } + else { + map->weight = rule->max_score; + } + /* Add flag to hash table */ + g_hash_table_insert(rule->mappings, + GINT_TO_POINTER(map->fuzzy_flag), map); + rspamd_symcache_add_symbol(cfg->cache, + map->symbol, 0, + NULL, NULL, + SYMBOL_TYPE_VIRTUAL | SYMBOL_TYPE_FINE, + cb_id); + } + else { + msg_err_config("fuzzy_map parameter has no flag definition"); + } + } + else { + msg_err_config("fuzzy_map parameter has no symbol definition"); + } + } + else { + msg_err_config("fuzzy_map parameter is of an unsupported type"); + } +} + +static GPtrArray * +parse_fuzzy_headers(struct rspamd_config *cfg, const gchar *str) +{ + gchar **strvec; + gint num, i; + GPtrArray *res; + + strvec = g_strsplit_set(str, ",", 0); + num = g_strv_length(strvec); + res = g_ptr_array_sized_new(num); + + for (i = 0; i < num; i++) { + g_strstrip(strvec[i]); + g_ptr_array_add(res, rspamd_mempool_strdup( + cfg->cfg_pool, strvec[i])); + } + + g_strfreev(strvec); + + return res; +} + +static double +fuzzy_normalize(gint32 in, double weight) +{ + if (weight == 0) { + return 0; + } +#ifdef HAVE_TANH + return tanh(G_E * (double) in / weight); +#else + return (in < weight ? 
in / weight : weight); +#endif +} + +static struct fuzzy_rule * +fuzzy_rule_new(const char *default_symbol, rspamd_mempool_t *pool) +{ + struct fuzzy_rule *rule; + + rule = rspamd_mempool_alloc0(pool, sizeof(struct fuzzy_rule)); + + rule->mappings = g_hash_table_new(g_direct_hash, g_direct_equal); + rule->symbol = default_symbol; + rspamd_mempool_add_destructor(pool, + (rspamd_mempool_destruct_t) g_hash_table_unref, + rule->mappings); + rule->read_only = FALSE; + rule->weight_threshold = NAN; + + return rule; +} + +static void +fuzzy_free_rule(gpointer r) +{ + struct fuzzy_rule *rule = (struct fuzzy_rule *) r; + + g_string_free(rule->hash_key, TRUE); + g_string_free(rule->shingles_key, TRUE); + + if (rule->local_key) { + rspamd_keypair_unref(rule->local_key); + } + + if (rule->peer_key) { + rspamd_pubkey_unref(rule->peer_key); + } +} + +static gint +fuzzy_parse_rule(struct rspamd_config *cfg, const ucl_object_t *obj, + const gchar *name, gint cb_id) +{ + const ucl_object_t *value, *cur; + struct fuzzy_rule *rule; + ucl_object_iter_t it = NULL; + const char *k = NULL, *key_str = NULL, *shingles_key_str = NULL, *lua_script; + struct fuzzy_ctx *fuzzy_module_ctx = fuzzy_get_context(cfg); + + if (obj->type != UCL_OBJECT) { + msg_err_config("invalid rule definition"); + return -1; + } + + if ((value = ucl_object_lookup_any(obj, "enabled", "enable", NULL)) != NULL) { + if (!ucl_object_toboolean(value)) { + msg_info_config("fuzzy rule %s is disabled by configuration", name); + + return 0; + } + } + + rule = fuzzy_rule_new(fuzzy_module_ctx->default_symbol, + cfg->cfg_pool); + rule->ucl_obj = obj; + rule->ctx = fuzzy_module_ctx; + rule->learn_condition_cb = -1; + rule->alg = RSPAMD_SHINGLES_OLD; + rule->skip_map = NULL; + + if ((value = ucl_object_lookup(obj, "skip_hashes")) != NULL) { + rspamd_map_add_from_ucl(cfg, value, + "Fuzzy hashes whitelist", + rspamd_kv_list_read, + rspamd_kv_list_fin, + rspamd_kv_list_dtor, + (void **) &rule->skip_map, + NULL, RSPAMD_MAP_DEFAULT); 
+ } + + if ((value = ucl_object_lookup(obj, "headers")) != NULL) { + it = NULL; + while ((cur = ucl_object_iterate(value, &it, value->type == UCL_ARRAY)) != NULL) { + GPtrArray *tmp; + guint i; + gpointer ptr; + + tmp = parse_fuzzy_headers(cfg, ucl_obj_tostring(cur)); + + if (tmp) { + if (rule->fuzzy_headers) { + PTR_ARRAY_FOREACH(tmp, i, ptr) + { + g_ptr_array_add(rule->fuzzy_headers, ptr); + } + + g_ptr_array_free(tmp, TRUE); + } + else { + rule->fuzzy_headers = tmp; + } + } + } + } + else { + rule->fuzzy_headers = parse_fuzzy_headers(cfg, default_headers); + } + + if (rule->fuzzy_headers != NULL) { + rspamd_mempool_add_destructor(cfg->cfg_pool, + (rspamd_mempool_destruct_t) rspamd_ptr_array_free_hard, + rule->fuzzy_headers); + } + + + if ((value = ucl_object_lookup(obj, "max_score")) != NULL) { + rule->max_score = ucl_obj_todouble(value); + } + + if ((value = ucl_object_lookup(obj, "retransmits")) != NULL) { + rule->retransmits = ucl_obj_toint(value); + } + else { + rule->retransmits = fuzzy_module_ctx->retransmits; + } + + if ((value = ucl_object_lookup(obj, "timeout")) != NULL) { + rule->io_timeout = ucl_obj_todouble(value); + } + else { + rule->io_timeout = fuzzy_module_ctx->io_timeout; + } + + if ((value = ucl_object_lookup(obj, "symbol")) != NULL) { + rule->symbol = ucl_obj_tostring(value); + } + + if (name) { + rule->name = name; + } + else { + rule->name = rule->symbol; + } + + + if ((value = ucl_object_lookup(obj, "read_only")) != NULL) { + rule->read_only = ucl_obj_toboolean(value); + } + + if ((value = ucl_object_lookup(obj, "skip_unknown")) != NULL) { + rule->skip_unknown = ucl_obj_toboolean(value); + } + + if ((value = ucl_object_lookup(obj, "no_share")) != NULL) { + rule->no_share = ucl_obj_toboolean(value); + } + + if ((value = ucl_object_lookup(obj, "no_subject")) != NULL) { + rule->no_subject = ucl_obj_toboolean(value); + } + + if ((value = ucl_object_lookup(obj, "algorithm")) != NULL) { + rule->algorithm_str = ucl_object_tostring(value); + + if 
(rule->algorithm_str) { + if (g_ascii_strcasecmp(rule->algorithm_str, "old") == 0 || + g_ascii_strcasecmp(rule->algorithm_str, "siphash") == 0) { + rule->alg = RSPAMD_SHINGLES_OLD; + } + else if (g_ascii_strcasecmp(rule->algorithm_str, "xxhash") == 0) { + rule->alg = RSPAMD_SHINGLES_XXHASH; + } + else if (g_ascii_strcasecmp(rule->algorithm_str, "mumhash") == 0) { + rule->alg = RSPAMD_SHINGLES_MUMHASH; + } + else if (g_ascii_strcasecmp(rule->algorithm_str, "fasthash") == 0 || + g_ascii_strcasecmp(rule->algorithm_str, "fast") == 0) { + rule->alg = RSPAMD_SHINGLES_FAST; + } + else { + msg_warn_config("unknown algorithm: %s, use siphash by default", + rule->algorithm_str); + } + } + } + + /* Set a consistent and short string name */ + switch (rule->alg) { + case RSPAMD_SHINGLES_OLD: + rule->algorithm_str = "sip"; + break; + case RSPAMD_SHINGLES_XXHASH: + rule->algorithm_str = "xx"; + break; + case RSPAMD_SHINGLES_MUMHASH: + rule->algorithm_str = "mum"; + break; + case RSPAMD_SHINGLES_FAST: + rule->algorithm_str = "fast"; + break; + } + + if ((value = ucl_object_lookup(obj, "servers")) != NULL) { + rule->servers = rspamd_upstreams_create(cfg->ups_ctx); + /* pass max_error and revive_time configuration in upstream for fuzzy storage + * it allows to configure error_rate threshold and upstream dead timer + */ + rspamd_upstreams_set_limits(rule->servers, + (gdouble) fuzzy_module_ctx->revive_time, NAN, NAN, NAN, + (guint) fuzzy_module_ctx->max_errors, 0); + + rspamd_mempool_add_destructor(cfg->cfg_pool, + (rspamd_mempool_destruct_t) rspamd_upstreams_destroy, + rule->servers); + if (!rspamd_upstreams_from_ucl(rule->servers, value, DEFAULT_PORT, NULL)) { + msg_err_config("cannot read servers definition"); + return -1; + } + } + if ((value = ucl_object_lookup(obj, "fuzzy_map")) != NULL) { + it = NULL; + while ((cur = ucl_object_iterate(value, &it, true)) != NULL) { + parse_flags(rule, cfg, cur, cb_id); + } + } + + if ((value = ucl_object_lookup(obj, "encryption_key")) != NULL) 
{ + /* Create key from user's input */ + k = ucl_object_tostring(value); + + if (k == NULL || (rule->peer_key = + rspamd_pubkey_from_base32(k, 0, RSPAMD_KEYPAIR_KEX, + RSPAMD_CRYPTOBOX_MODE_25519)) == NULL) { + msg_err_config("bad encryption key value: %s", + k); + return -1; + } + + rule->local_key = rspamd_keypair_new(RSPAMD_KEYPAIR_KEX, + RSPAMD_CRYPTOBOX_MODE_25519); + } + + if ((value = ucl_object_lookup(obj, "learn_condition")) != NULL) { + lua_script = ucl_object_tostring(value); + + if (lua_script) { + if (luaL_dostring(cfg->lua_state, lua_script) != 0) { + msg_err_config("cannot execute lua script for fuzzy " + "learn condition: %s", + lua_tostring(cfg->lua_state, -1)); + } + else { + if (lua_type(cfg->lua_state, -1) == LUA_TFUNCTION) { + rule->learn_condition_cb = luaL_ref(cfg->lua_state, + LUA_REGISTRYINDEX); + msg_info_config("loaded learn condition script for fuzzy rule:" + " %s", + rule->name); + } + else { + msg_err_config("lua script must return " + "function(task) and not %s", + lua_typename(cfg->lua_state, + lua_type(cfg->lua_state, -1))); + } + } + } + } + + key_str = NULL; + if ((value = ucl_object_lookup(obj, "fuzzy_key")) != NULL) { + /* Create key from user's input */ + key_str = ucl_object_tostring(value); + } + + /* Setup keys */ + if (key_str == NULL) { + /* Use some default key for all ops */ + key_str = "rspamd"; + } + + rule->hash_key = g_string_sized_new(rspamd_cryptobox_HASHBYTES); + rspamd_cryptobox_hash(rule->hash_key->str, key_str, strlen(key_str), NULL, 0); + rule->hash_key->len = rspamd_cryptobox_HASHKEYBYTES; + + shingles_key_str = NULL; + if ((value = ucl_object_lookup(obj, "fuzzy_shingles_key")) != NULL) { + shingles_key_str = ucl_object_tostring(value); + } + if (shingles_key_str == NULL) { + shingles_key_str = "rspamd"; + } + + rule->shingles_key = g_string_sized_new(rspamd_cryptobox_HASHBYTES); + rspamd_cryptobox_hash(rule->shingles_key->str, shingles_key_str, + strlen(shingles_key_str), NULL, 0); + rule->shingles_key->len 
= 16; + + if (rspamd_upstreams_count(rule->servers) == 0) { + msg_err_config("no servers defined for fuzzy rule with name: %s", + rule->name); + return -1; + } + else { + g_ptr_array_add(fuzzy_module_ctx->fuzzy_rules, rule); + + if (rule->symbol != fuzzy_module_ctx->default_symbol) { + int vid = rspamd_symcache_add_symbol(cfg->cache, rule->symbol, + 0, + NULL, NULL, + SYMBOL_TYPE_VIRTUAL | SYMBOL_TYPE_FINE, + cb_id); + + if (rule->io_timeout > 0) { + char timeout_buf[32]; + rspamd_snprintf(timeout_buf, sizeof(timeout_buf), "%f", + rule->io_timeout); + rspamd_symcache_add_symbol_augmentation(cfg->cache, + vid, "timeout", + timeout_buf); + } + } + + msg_info_config("added fuzzy rule %s, key: %*xs, " + "shingles_key: %*xs, algorithm: %s", + rule->symbol, + 6, rule->hash_key->str, + 6, rule->shingles_key->str, + rule->algorithm_str); + } + + if ((value = ucl_object_lookup(obj, "weight_threshold")) != NULL) { + rule->weight_threshold = ucl_object_todouble(value); + } + + /* + * Process rule in Lua + */ + gint err_idx, ret; + lua_State *L = (lua_State *) cfg->lua_state; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, fuzzy_module_ctx->process_rule_ref); + ucl_object_push_lua(L, obj, true); + + if ((ret = lua_pcall(L, 1, 1, err_idx)) != 0) { + msg_err_config("call to process_rule lua " + "script failed (%d): %s", + ret, lua_tostring(L, -1)); + + rule->lua_id = -1; + } + else { + rule->lua_id = lua_tonumber(L, -1); + } + + lua_settop(L, err_idx - 1); + + rspamd_mempool_add_destructor(cfg->cfg_pool, fuzzy_free_rule, + rule); + + return 0; +} + +gint fuzzy_check_module_init(struct rspamd_config *cfg, struct module_ctx **ctx) +{ + struct fuzzy_ctx *fuzzy_module_ctx; + + fuzzy_module_ctx = rspamd_mempool_alloc0(cfg->cfg_pool, + sizeof(struct fuzzy_ctx)); + + fuzzy_module_ctx->fuzzy_pool = rspamd_mempool_new(rspamd_mempool_suggest_size(), + NULL, 0); + /* TODO: this should match rules count actually */ + 
fuzzy_module_ctx->keypairs_cache = rspamd_keypair_cache_new(32); + fuzzy_module_ctx->fuzzy_rules = g_ptr_array_new(); + fuzzy_module_ctx->cfg = cfg; + fuzzy_module_ctx->process_rule_ref = -1; + fuzzy_module_ctx->check_mime_part_ref = -1; + fuzzy_module_ctx->cleanup_rules_ref = -1; + + rspamd_mempool_add_destructor(cfg->cfg_pool, + (rspamd_mempool_destruct_t) rspamd_mempool_delete, + fuzzy_module_ctx->fuzzy_pool); + rspamd_mempool_add_destructor(cfg->cfg_pool, + (rspamd_mempool_destruct_t) rspamd_keypair_cache_destroy, + fuzzy_module_ctx->keypairs_cache); + rspamd_mempool_add_destructor(cfg->cfg_pool, + (rspamd_mempool_destruct_t) rspamd_ptr_array_free_hard, + fuzzy_module_ctx->fuzzy_rules); + + *ctx = (struct module_ctx *) fuzzy_module_ctx; + + rspamd_rcl_add_doc_by_path(cfg, + NULL, + "Fuzzy check plugin", + "fuzzy_check", + UCL_OBJECT, + NULL, + 0, + NULL, + 0); + + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check", + "Default symbol", + "symbol", + UCL_STRING, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check", + "Minimum number of *words* to check a text part", + "min_length", + UCL_INT, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check", + "Minimum number of *bytes* to check a non-text part", + "min_bytes", + UCL_INT, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check", + "Multiplier for bytes limit when checking for text parts", + "text_multiplier", + UCL_FLOAT, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check", + "Minimum height in pixels for embedded images to check using fuzzy storage", + "min_height", + UCL_INT, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check", + "Minimum width in pixels for embedded images to check using fuzzy storage", + "min_width", + UCL_INT, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check", + "Timeout for waiting reply from a fuzzy server", + "timeout", + UCL_TIME, + NULL, + 0, 
+ NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check", + "Maximum number of retransmits for a single request", + "retransmits", + UCL_INT, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check", + "Maximum number of upstream errors, affects error rate threshold", + "max_errors", + UCL_INT, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check", + "Time to lapse before re-resolve faulty upstream", + "revive_time", + UCL_FLOAT, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check", + "Whitelisted IPs map", + "whitelist", + UCL_STRING, + NULL, + 0, + NULL, + 0); + /* Rules doc strings */ + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check", + "Fuzzy check rule", + "rule", + UCL_OBJECT, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "Headers that are used to make a separate hash", + "headers", + UCL_ARRAY, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "Whitelisted hashes map", + "skip_hashes", + UCL_STRING, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "Set of mime types (in form type/subtype, or type/*, or *) to check with fuzzy", + "mime_types", + UCL_ARRAY, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "Maximum value for fuzzy hash when weight of symbol is exactly 1.0 (if value is higher then score is still 1.0)", + "max_score", + UCL_INT, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "List of servers to check (or learn)", + "servers", + UCL_STRING, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "If true then never try to learn this fuzzy storage", + "read_only", + UCL_BOOLEAN, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "If true then ignore unknown flags and not add the default fuzzy symbol", + "skip_unknown", + UCL_BOOLEAN, 
+ NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "Default symbol for rule (if no flags defined or matched)", + "symbol", + UCL_STRING, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "Base32 value for the protocol encryption public key", + "encryption_key", + UCL_STRING, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "Base32 value for the hashing key (for private storages)", + "fuzzy_key", + UCL_STRING, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "Base32 value for the shingles hashing key (for private storages)", + "fuzzy_shingles_key", + UCL_STRING, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "Lua script that returns boolean function to check if this task " + "should be considered when learning fuzzy storage", + "learn_condition", + UCL_STRING, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "Map of SYMBOL -> data for flags configuration", + "fuzzy_map", + UCL_OBJECT, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "Use direct hash for short texts", + "short_text_direct_hash", + UCL_BOOLEAN, + NULL, + 0, + "true", + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "Override module default min bytes for this rule", + "min_bytes", + UCL_INT, + NULL, + 0, + NULL, + 0); + /* Fuzzy map doc strings */ + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule.fuzzy_map", + "Maximum score for this flag", + "max_score", + UCL_INT, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule.fuzzy_map", + "Flag number", + "flag", + UCL_INT, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "Do no use subject to distinguish short text hashes", + "no_subject", + UCL_BOOLEAN, + NULL, + 0, + "false", + 0); + rspamd_rcl_add_doc_by_path(cfg, + 
"fuzzy_check.rule", + "Disable sharing message stats with the fuzzy server", + "no_share", + UCL_BOOLEAN, + NULL, + 0, + "false", + 0); + + return 0; +} + +gint fuzzy_check_module_config(struct rspamd_config *cfg, bool validate) +{ + const ucl_object_t *value, *cur, *elt; + ucl_object_iter_t it; + gint res = TRUE, cb_id, nrules = 0; + lua_State *L = cfg->lua_state; + struct fuzzy_ctx *fuzzy_module_ctx = fuzzy_get_context(cfg); + + if (!rspamd_config_is_module_enabled(cfg, "fuzzy_check")) { + return TRUE; + } + + fuzzy_module_ctx->enabled = TRUE; + fuzzy_module_ctx->check_mime_part_ref = -1; + fuzzy_module_ctx->process_rule_ref = -1; + fuzzy_module_ctx->cleanup_rules_ref = -1; + + /* Interact with lua_fuzzy */ + if (luaL_dostring(L, "return require \"lua_fuzzy\"") != 0) { + msg_err_config("cannot require lua_fuzzy: %s", + lua_tostring(L, -1)); + fuzzy_module_ctx->enabled = FALSE; + } + else { +#if LUA_VERSION_NUM >= 504 + lua_settop(L, -2); +#endif + if (lua_type(L, -1) != LUA_TTABLE) { + msg_err_config("lua fuzzy must return " + "table and not %s", + lua_typename(L, lua_type(L, -1))); + fuzzy_module_ctx->enabled = FALSE; + } + else { + lua_pushstring(L, "process_rule"); + lua_gettable(L, -2); + + if (lua_type(L, -1) != LUA_TFUNCTION) { + msg_err_config("process_rule must return " + "function and not %s", + lua_typename(L, lua_type(L, -1))); + fuzzy_module_ctx->enabled = FALSE; + } + else { + fuzzy_module_ctx->process_rule_ref = luaL_ref(L, LUA_REGISTRYINDEX); + } + + lua_pushstring(L, "check_mime_part"); + lua_gettable(L, -2); + + if (lua_type(L, -1) != LUA_TFUNCTION) { + msg_err_config("check_mime_part must return " + "function and not %s", + lua_typename(L, lua_type(L, -1))); + fuzzy_module_ctx->enabled = FALSE; + } + else { + fuzzy_module_ctx->check_mime_part_ref = luaL_ref(L, LUA_REGISTRYINDEX); + } + + lua_pushstring(L, "cleanup_rules"); + lua_gettable(L, -2); + + if (lua_type(L, -1) != LUA_TFUNCTION) { + msg_err_config("cleanup_rules must return " + "function 
and not %s", + lua_typename(L, lua_type(L, -1))); + fuzzy_module_ctx->enabled = FALSE; + } + else { + fuzzy_module_ctx->cleanup_rules_ref = luaL_ref(L, LUA_REGISTRYINDEX); + } + } + } + + lua_settop(L, 0); + + if (!fuzzy_module_ctx->enabled) { + return TRUE; + } + + if ((value = + rspamd_config_get_module_opt(cfg, "fuzzy_check", "symbol")) != NULL) { + fuzzy_module_ctx->default_symbol = ucl_obj_tostring(value); + } + else { + fuzzy_module_ctx->default_symbol = DEFAULT_SYMBOL; + } + + if ((value = + rspamd_config_get_module_opt(cfg, "fuzzy_check", "timeout")) != NULL) { + fuzzy_module_ctx->io_timeout = ucl_obj_todouble(value); + } + else { + fuzzy_module_ctx->io_timeout = DEFAULT_IO_TIMEOUT; + } + + if ((value = + rspamd_config_get_module_opt(cfg, + "fuzzy_check", + "retransmits")) != NULL) { + fuzzy_module_ctx->retransmits = ucl_obj_toint(value); + } + else { + fuzzy_module_ctx->retransmits = DEFAULT_RETRANSMITS; + } + + if ((value = + rspamd_config_get_module_opt(cfg, "fuzzy_check", + "max_errors")) != NULL) { + fuzzy_module_ctx->max_errors = ucl_obj_toint(value); + } + else { + fuzzy_module_ctx->max_errors = DEFAULT_MAX_ERRORS; + } + + if ((value = + rspamd_config_get_module_opt(cfg, "fuzzy_check", + "revive_time")) != NULL) { + fuzzy_module_ctx->revive_time = ucl_obj_todouble(value); + } + else { + fuzzy_module_ctx->revive_time = DEFAULT_REVIVE_TIME; + } + + if ((value = + rspamd_config_get_module_opt(cfg, "fuzzy_check", + "whitelist")) != NULL) { + rspamd_config_radix_from_ucl(cfg, value, "Fuzzy whitelist", + &fuzzy_module_ctx->whitelist, + NULL, + NULL, "fuzzy ip whitelist"); + } + else { + fuzzy_module_ctx->whitelist = NULL; + } + + if ((value = + rspamd_config_get_module_opt(cfg, "fuzzy_check", "rule")) != NULL) { + + cb_id = rspamd_symcache_add_symbol(cfg->cache, + "FUZZY_CALLBACK", 0, fuzzy_symbol_callback, NULL, + SYMBOL_TYPE_CALLBACK | SYMBOL_TYPE_FINE, + -1); + rspamd_config_add_symbol(cfg, + "FUZZY_CALLBACK", + 0.0, + "Fuzzy check callback", + "fuzzy", 
+ RSPAMD_SYMBOL_FLAG_IGNORE_METRIC, + 1, + 1); + + /* + * Here we can have 2 possibilities: + * + * unnamed rules: + * + * rule { + * ... + * } + * rule { + * ... + * } + * + * - or - named rules: + * + * rule { + * "rule1": { + * ... + * } + * "rule2": { + * ... + * } + * } + * + * So, for each element, we check, if there 'servers' key. If 'servers' is + * presented, then we treat it as unnamed rule, otherwise we treat it as + * named rule. + */ + LL_FOREACH(value, cur) + { + + if (ucl_object_lookup(cur, "servers")) { + /* Unnamed rule */ + fuzzy_parse_rule(cfg, cur, NULL, cb_id); + nrules++; + } + else { + /* Named rule */ + it = NULL; + + while ((elt = ucl_object_iterate(cur, &it, true)) != NULL) { + fuzzy_parse_rule(cfg, elt, ucl_object_key(elt), cb_id); + nrules++; + } + } + } + + /* We want that to check bad mime attachments */ + rspamd_symcache_add_delayed_dependency(cfg->cache, + "FUZZY_CALLBACK", "MIME_TYPES_CALLBACK"); + } + + if (fuzzy_module_ctx->fuzzy_rules == NULL) { + msg_warn_config("fuzzy module is enabled but no rules are defined"); + } + + msg_info_config("init internal fuzzy_check module, %d rules loaded", + nrules); + + /* Register global methods */ + lua_getglobal(L, "rspamd_plugins"); + + if (lua_type(L, -1) == LUA_TTABLE) { + lua_pushstring(L, "fuzzy_check"); + lua_createtable(L, 0, 3); + /* Set methods */ + lua_pushstring(L, "unlearn"); + lua_pushcfunction(L, fuzzy_lua_unlearn_handler); + lua_settable(L, -3); + lua_pushstring(L, "learn"); + lua_pushcfunction(L, fuzzy_lua_learn_handler); + lua_settable(L, -3); + lua_pushstring(L, "gen_hashes"); + lua_pushcfunction(L, fuzzy_lua_gen_hashes_handler); + lua_settable(L, -3); + lua_pushstring(L, "hex_hashes"); + lua_pushcfunction(L, fuzzy_lua_hex_hashes_handler); + lua_settable(L, -3); + lua_pushstring(L, "list_storages"); + lua_pushcfunction(L, fuzzy_lua_list_storages); + lua_settable(L, -3); + lua_pushstring(L, "ping_storage"); + lua_pushcfunction(L, fuzzy_lua_ping_storage); + lua_settable(L, 
-3); + /* Finish fuzzy_check key */ + lua_settable(L, -3); + } + + lua_settop(L, 0); + + return res; +} + +gint fuzzy_check_module_reconfig(struct rspamd_config *cfg) +{ + struct fuzzy_ctx *fuzzy_module_ctx = fuzzy_get_context(cfg); + + if (fuzzy_module_ctx->cleanup_rules_ref != -1) { + /* Sync lua_fuzzy rules */ + gint err_idx, ret; + lua_State *L = (lua_State *) cfg->lua_state; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, fuzzy_module_ctx->cleanup_rules_ref); + + if ((ret = lua_pcall(L, 0, 0, err_idx)) != 0) { + msg_err_config("call to cleanup_rules lua " + "script failed (%d): %s", + ret, lua_tostring(L, -1)); + } + + luaL_unref(cfg->lua_state, LUA_REGISTRYINDEX, + fuzzy_module_ctx->cleanup_rules_ref); + lua_settop(L, 0); + } + + if (fuzzy_module_ctx->check_mime_part_ref != -1) { + luaL_unref(cfg->lua_state, LUA_REGISTRYINDEX, + fuzzy_module_ctx->check_mime_part_ref); + } + + if (fuzzy_module_ctx->process_rule_ref != -1) { + luaL_unref(cfg->lua_state, LUA_REGISTRYINDEX, + fuzzy_module_ctx->process_rule_ref); + } + + return fuzzy_check_module_config(cfg, false); +} + +/* Finalize IO */ +static void +fuzzy_io_fin(void *ud) +{ + struct fuzzy_client_session *session = ud; + + if (session->commands) { + g_ptr_array_free(session->commands, TRUE); + } + + if (session->results) { + g_ptr_array_free(session->results, TRUE); + } + + rspamd_ev_watcher_stop(session->event_loop, &session->ev); + close(session->fd); +} + +static GArray * +fuzzy_preprocess_words(struct rspamd_mime_text_part *part, rspamd_mempool_t *pool) +{ + return part->utf_words; +} + +static void +fuzzy_encrypt_cmd(struct fuzzy_rule *rule, + struct rspamd_fuzzy_encrypted_req_hdr *hdr, + guchar *data, gsize datalen) +{ + const guchar *pk; + guint pklen; + + g_assert(hdr != NULL); + g_assert(data != NULL); + g_assert(rule != NULL); + + /* Encrypt data */ + memcpy(hdr->magic, + fuzzy_encrypted_magic, + sizeof(hdr->magic)); + 
ottery_rand_bytes(hdr->nonce, sizeof(hdr->nonce)); + pk = rspamd_keypair_component(rule->local_key, + RSPAMD_KEYPAIR_COMPONENT_PK, &pklen); + memcpy(hdr->pubkey, pk, MIN(pklen, sizeof(hdr->pubkey))); + pk = rspamd_pubkey_get_pk(rule->peer_key, &pklen); + memcpy(hdr->key_id, pk, MIN(sizeof(hdr->key_id), pklen)); + rspamd_keypair_cache_process(rule->ctx->keypairs_cache, + rule->local_key, rule->peer_key); + rspamd_cryptobox_encrypt_nm_inplace(data, datalen, + hdr->nonce, rspamd_pubkey_get_nm(rule->peer_key, rule->local_key), + hdr->mac, + rspamd_pubkey_alg(rule->peer_key)); +} + +static struct fuzzy_cmd_io * +fuzzy_cmd_stat(struct fuzzy_rule *rule, + int c, + gint flag, + guint32 weight, + rspamd_mempool_t *pool) +{ + struct rspamd_fuzzy_cmd *cmd; + struct rspamd_fuzzy_encrypted_cmd *enccmd = NULL; + struct fuzzy_cmd_io *io; + + if (rule->peer_key) { + enccmd = rspamd_mempool_alloc0(pool, sizeof(*enccmd)); + cmd = &enccmd->cmd; + } + else { + cmd = rspamd_mempool_alloc0(pool, sizeof(*cmd)); + } + + cmd->cmd = c; + cmd->version = RSPAMD_FUZZY_PLUGIN_VERSION; + cmd->shingles_count = 0; + cmd->tag = ottery_rand_uint32(); + + io = rspamd_mempool_alloc(pool, sizeof(*io)); + io->flags = 0; + io->tag = cmd->tag; + memcpy(&io->cmd, cmd, sizeof(io->cmd)); + + if (rule->peer_key && enccmd) { + fuzzy_encrypt_cmd(rule, &enccmd->hdr, (guchar *) cmd, sizeof(*cmd)); + io->io.iov_base = enccmd; + io->io.iov_len = sizeof(*enccmd); + } + else { + io->io.iov_base = cmd; + io->io.iov_len = sizeof(*cmd); + } + + return io; +} + +static inline double +fuzzy_milliseconds_since_midnight(void) +{ + double now = rspamd_get_calendar_ticks(); + double ms = now - (int64_t) now; + now = (((int64_t) now % 86400) + ms) * 1000; + + return now; +} + +static struct fuzzy_cmd_io * +fuzzy_cmd_ping(struct fuzzy_rule *rule, + rspamd_mempool_t *pool) +{ + struct rspamd_fuzzy_cmd *cmd; + struct rspamd_fuzzy_encrypted_cmd *enccmd = NULL; + struct fuzzy_cmd_io *io; + + if (rule->peer_key) { + enccmd = 
rspamd_mempool_alloc0(pool, sizeof(*enccmd)); + cmd = &enccmd->cmd; + } + else { + cmd = rspamd_mempool_alloc0(pool, sizeof(*cmd)); + } + + /* Get milliseconds since midnight */ + + + cmd->cmd = FUZZY_PING; + cmd->version = RSPAMD_FUZZY_PLUGIN_VERSION; + cmd->shingles_count = 0; + cmd->value = fuzzy_milliseconds_since_midnight(); /* Record timestamp */ + cmd->tag = ottery_rand_uint32(); + + io = rspamd_mempool_alloc(pool, sizeof(*io)); + io->flags = 0; + io->tag = cmd->tag; + memcpy(&io->cmd, cmd, sizeof(io->cmd)); + + if (rule->peer_key && enccmd) { + fuzzy_encrypt_cmd(rule, &enccmd->hdr, (guchar *) cmd, sizeof(*cmd)); + io->io.iov_base = enccmd; + io->io.iov_len = sizeof(*enccmd); + } + else { + io->io.iov_base = cmd; + io->io.iov_len = sizeof(*cmd); + } + + return io; +} + +static struct fuzzy_cmd_io * +fuzzy_cmd_hash(struct fuzzy_rule *rule, + int c, + const rspamd_ftok_t *hash, + gint flag, + guint32 weight, + rspamd_mempool_t *pool) +{ + struct rspamd_fuzzy_cmd *cmd; + struct rspamd_fuzzy_encrypted_cmd *enccmd = NULL; + struct fuzzy_cmd_io *io; + + if (rule->peer_key) { + enccmd = rspamd_mempool_alloc0(pool, sizeof(*enccmd)); + cmd = &enccmd->cmd; + } + else { + cmd = rspamd_mempool_alloc0(pool, sizeof(*cmd)); + } + + if (hash->len == sizeof(cmd->digest) * 2) { + /* It is hex encoding */ + if (rspamd_decode_hex_buf(hash->begin, hash->len, cmd->digest, + sizeof(cmd->digest)) == -1) { + msg_err_pool("cannot decode hash, wrong encoding"); + return NULL; + } + } + else { + msg_err_pool("cannot decode hash, wrong length: %z", hash->len); + return NULL; + } + + cmd->cmd = c; + cmd->version = RSPAMD_FUZZY_PLUGIN_VERSION; + cmd->shingles_count = 0; + cmd->tag = ottery_rand_uint32(); + + io = rspamd_mempool_alloc(pool, sizeof(*io)); + io->flags = 0; + io->tag = cmd->tag; + + memcpy(&io->cmd, cmd, sizeof(io->cmd)); + + if (rule->peer_key && enccmd) { + fuzzy_encrypt_cmd(rule, &enccmd->hdr, (guchar *) cmd, sizeof(*cmd)); + io->io.iov_base = enccmd; + io->io.iov_len = 
sizeof(*enccmd); + } + else { + io->io.iov_base = cmd; + io->io.iov_len = sizeof(*cmd); + } + + return io; +} + +struct rspamd_cached_shingles { + struct rspamd_shingle *sh; + guchar digest[rspamd_cryptobox_HASHBYTES]; + guint additional_length; + guchar *additional_data; +}; + + +static struct rspamd_cached_shingles * +fuzzy_cmd_get_cached(struct fuzzy_rule *rule, + struct rspamd_task *task, + struct rspamd_mime_part *mp) +{ + gchar key[32]; + gint key_part; + struct rspamd_cached_shingles **cached; + + memcpy(&key_part, rule->shingles_key->str, sizeof(key_part)); + rspamd_snprintf(key, sizeof(key), "%s%d", rule->algorithm_str, + key_part); + + cached = (struct rspamd_cached_shingles **) rspamd_mempool_get_variable( + task->task_pool, key); + + if (cached && cached[mp->part_number]) { + return cached[mp->part_number]; + } + + return NULL; +} + +static void +fuzzy_cmd_set_cached(struct fuzzy_rule *rule, + struct rspamd_task *task, + struct rspamd_mime_part *mp, + struct rspamd_cached_shingles *data) +{ + gchar key[32]; + gint key_part; + struct rspamd_cached_shingles **cached; + + memcpy(&key_part, rule->shingles_key->str, sizeof(key_part)); + rspamd_snprintf(key, sizeof(key), "%s%d", rule->algorithm_str, + key_part); + + cached = (struct rspamd_cached_shingles **) rspamd_mempool_get_variable( + task->task_pool, key); + + if (cached) { + cached[mp->part_number] = data; + } + else { + cached = rspamd_mempool_alloc0(task->task_pool, sizeof(*cached) * + (MESSAGE_FIELD(task, parts)->len + 1)); + cached[mp->part_number] = data; + + rspamd_mempool_set_variable(task->task_pool, key, cached, NULL); + } +} + +static gboolean +fuzzy_rule_check_mimepart(struct rspamd_task *task, + struct fuzzy_rule *rule, + struct rspamd_mime_part *part, + gboolean *need_check, + gboolean *fuzzy_check) +{ + lua_State *L = (lua_State *) task->cfg->lua_state; + + gint old_top = lua_gettop(L); + + if (rule->lua_id != -1 && rule->ctx->check_mime_part_ref != -1) { + gint err_idx, ret; + + struct 
rspamd_task **ptask; + struct rspamd_mime_part **ppart; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, rule->ctx->check_mime_part_ref); + + ptask = lua_newuserdata(L, sizeof(*ptask)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + + ppart = lua_newuserdata(L, sizeof(*ppart)); + *ppart = part; + rspamd_lua_setclass(L, "rspamd{mimepart}", -1); + + lua_pushnumber(L, rule->lua_id); + + if ((ret = lua_pcall(L, 3, 2, err_idx)) != 0) { + msg_err_task("call to check_mime_part lua " + "script failed (%d): %s", + ret, lua_tostring(L, -1)); + + ret = FALSE; + } + else { + ret = TRUE; + *need_check = lua_toboolean(L, -2); + *fuzzy_check = lua_toboolean(L, -1); + } + + lua_settop(L, old_top); + + return ret; + } + + return FALSE; +} + +#define MAX_FUZZY_DOMAIN 64 + +static guint +fuzzy_cmd_extension_length(struct rspamd_task *task, + struct fuzzy_rule *rule) +{ + guint total = 0; + + if (rule->no_share) { + return 0; + } + + /* From domain */ + if (MESSAGE_FIELD(task, from_mime) && MESSAGE_FIELD(task, from_mime)->len > 0) { + struct rspamd_email_address *addr = g_ptr_array_index(MESSAGE_FIELD(task, + from_mime), + 0); + + if (addr->domain_len > 0) { + total += 2; /* 2 bytes: type + length */ + total += MIN(MAX_FUZZY_DOMAIN, addr->domain_len); + } + } + + if (task->from_addr && rspamd_inet_address_get_af(task->from_addr) == AF_INET) { + total += sizeof(struct in_addr) + 1; + } + else if (task->from_addr && rspamd_inet_address_get_af(task->from_addr) == AF_INET6) { + total += sizeof(struct in6_addr) + 1; + } + + return total; +} + +static guint +fuzzy_cmd_write_extensions(struct rspamd_task *task, + struct fuzzy_rule *rule, + guchar *dest, + gsize available) +{ + guint written = 0; + + if (rule->no_share) { + return 0; + } + + if (MESSAGE_FIELD(task, from_mime) && MESSAGE_FIELD(task, from_mime)->len > 0) { + struct rspamd_email_address *addr = g_ptr_array_index(MESSAGE_FIELD(task, + from_mime), + 
0); + guint to_write = MIN(MAX_FUZZY_DOMAIN, addr->domain_len) + 2; + + if (to_write > 0 && to_write <= available) { + *dest++ = RSPAMD_FUZZY_EXT_SOURCE_DOMAIN; + *dest++ = to_write - 2; + + if (addr->domain_len < MAX_FUZZY_DOMAIN) { + memcpy(dest, addr->domain, addr->domain_len); + dest += addr->domain_len; + } + else { + /* Trim from left */ + memcpy(dest, + addr->domain + (addr->domain_len - MAX_FUZZY_DOMAIN), + MAX_FUZZY_DOMAIN); + dest += MAX_FUZZY_DOMAIN; + } + + available -= to_write; + written += to_write; + } + } + + if (task->from_addr && rspamd_inet_address_get_af(task->from_addr) == AF_INET) { + if (available >= sizeof(struct in_addr) + 1) { + guint klen; + guchar *inet_data = rspamd_inet_address_get_hash_key(task->from_addr, &klen); + + *dest++ = RSPAMD_FUZZY_EXT_SOURCE_IP4; + + memcpy(dest, inet_data, klen); + dest += klen; + + available -= klen + 1; + written += klen + 1; + } + } + else if (task->from_addr && rspamd_inet_address_get_af(task->from_addr) == AF_INET6) { + if (available >= sizeof(struct in6_addr) + 1) { + guint klen; + guchar *inet_data = rspamd_inet_address_get_hash_key(task->from_addr, &klen); + + *dest++ = RSPAMD_FUZZY_EXT_SOURCE_IP6; + + memcpy(dest, inet_data, klen); + dest += klen; + + available -= klen + 1; + written += klen + 1; + } + } + + return written; +} + +/* + * Create fuzzy command from a text part + */ +static struct fuzzy_cmd_io * +fuzzy_cmd_from_text_part(struct rspamd_task *task, + struct fuzzy_rule *rule, + int c, + gint flag, + guint32 weight, + gboolean short_text, + struct rspamd_mime_text_part *part, + struct rspamd_mime_part *mp) +{ + struct rspamd_fuzzy_shingle_cmd *shcmd = NULL; + struct rspamd_fuzzy_cmd *cmd = NULL; + struct rspamd_fuzzy_encrypted_shingle_cmd *encshcmd = NULL; + struct rspamd_fuzzy_encrypted_cmd *enccmd = NULL; + struct rspamd_cached_shingles *cached = NULL; + struct rspamd_shingle *sh = NULL; + guint i; + rspamd_cryptobox_hash_state_t st; + rspamd_stat_token_t *word; + GArray *words; + 
struct fuzzy_cmd_io *io; + guint additional_length; + guchar *additional_data; + + cached = fuzzy_cmd_get_cached(rule, task, mp); + + /* + * Important note: + * + * We assume that fuzzy io is a consistent memory layout to fit into + * iov structure of size 1 + * + * However, there are 4 possibilities: + * 1) non encrypted, non shingle command - just one cmd + * 2) encrypted, non shingle command - encryption hdr + cmd + * 3) non encrypted, shingle command - cmd + shingle + * 4) encrypted, shingle command - encryption hdr + cmd + shingle + * + * Extensions are always at the end, but since we also have caching (sigh, meh...) + * then we have one piece that looks like cmd (+ shingle) + extensions + * To encrypt it optionally we take this memory and prepend encryption header + * + * In case of cached version we do the same: allocate, copy from cached (including extra) + * and optionally encrypt. + * + * However, there should be no extensions in case of unencrypted connection + * (for sanity + privacy). 
+ */ + if (cached) { + additional_length = cached->additional_length; + additional_data = cached->additional_data; + + /* Copy cached */ + if (short_text) { + enccmd = rspamd_mempool_alloc0(task->task_pool, + sizeof(*enccmd) + additional_length); + cmd = &enccmd->cmd; + memcpy(cmd->digest, cached->digest, + sizeof(cached->digest)); + cmd->shingles_count = 0; + memcpy(((guchar *) enccmd) + sizeof(*enccmd), additional_data, + additional_length); + } + else if (cached->sh) { + encshcmd = rspamd_mempool_alloc0(task->task_pool, + additional_length + sizeof(*encshcmd)); + shcmd = &encshcmd->cmd; + memcpy(&shcmd->sgl, cached->sh, sizeof(struct rspamd_shingle)); + memcpy(shcmd->basic.digest, cached->digest, + sizeof(cached->digest)); + memcpy(((guchar *) encshcmd) + sizeof(*encshcmd), additional_data, + additional_length); + shcmd->basic.shingles_count = RSPAMD_SHINGLE_SIZE; + } + else { + return NULL; + } + } + else { + additional_length = fuzzy_cmd_extension_length(task, rule); + cached = rspamd_mempool_alloc0(task->task_pool, sizeof(*cached) + + additional_length); + /* + * Allocate extensions and never touch it except copying to avoid + * occasional encryption + */ + cached->additional_length = additional_length; + cached->additional_data = ((guchar *) cached) + sizeof(*cached); + + if (additional_length > 0) { + fuzzy_cmd_write_extensions(task, rule, cached->additional_data, + additional_length); + } + + if (short_text) { + enccmd = rspamd_mempool_alloc0(task->task_pool, + sizeof(*enccmd) + additional_length); + cmd = &enccmd->cmd; + rspamd_cryptobox_hash_init(&st, rule->hash_key->str, + rule->hash_key->len); + + rspamd_cryptobox_hash_update(&st, part->utf_stripped_content->data, + part->utf_stripped_content->len); + + if (!rule->no_subject && (MESSAGE_FIELD(task, subject))) { + /* We also include subject */ + rspamd_cryptobox_hash_update(&st, MESSAGE_FIELD(task, subject), + strlen(MESSAGE_FIELD(task, subject))); + } + + rspamd_cryptobox_hash_final(&st, cmd->digest); 
+ memcpy(cached->digest, cmd->digest, sizeof(cached->digest)); + cached->sh = NULL; + + additional_data = ((guchar *) enccmd) + sizeof(*enccmd); + memcpy(additional_data, cached->additional_data, additional_length); + } + else { + encshcmd = rspamd_mempool_alloc0(task->task_pool, + sizeof(*encshcmd) + additional_length); + shcmd = &encshcmd->cmd; + + /* + * Generate hash from all words in the part + */ + rspamd_cryptobox_hash_init(&st, rule->hash_key->str, rule->hash_key->len); + words = fuzzy_preprocess_words(part, task->task_pool); + + for (i = 0; i < words->len; i++) { + word = &g_array_index(words, rspamd_stat_token_t, i); + + if (!((word->flags & RSPAMD_STAT_TOKEN_FLAG_SKIPPED) || word->stemmed.len == 0)) { + rspamd_cryptobox_hash_update(&st, word->stemmed.begin, + word->stemmed.len); + } + } + + rspamd_cryptobox_hash_final(&st, shcmd->basic.digest); + + msg_debug_task("loading shingles of type %s with key %*xs", + rule->algorithm_str, + 16, rule->shingles_key->str); + sh = rspamd_shingles_from_text(words, + rule->shingles_key->str, task->task_pool, + rspamd_shingles_default_filter, NULL, + rule->alg); + if (sh != NULL) { + memcpy(&shcmd->sgl, sh, sizeof(shcmd->sgl)); + shcmd->basic.shingles_count = RSPAMD_SHINGLE_SIZE; + } + else { + /* No shingles, no check */ + return NULL; + } + + cached->sh = sh; + memcpy(cached->digest, shcmd->basic.digest, sizeof(cached->digest)); + additional_data = ((guchar *) encshcmd) + sizeof(*encshcmd); + memcpy(additional_data, cached->additional_data, additional_length); + } + + /* + * We always save encrypted command as it can handle both + * encrypted and unencrypted requests. + * + * Since it is copied when obtained from the cache, it is safe to use + * it this way. 
+ */ + fuzzy_cmd_set_cached(rule, task, mp, cached); + } + + io = rspamd_mempool_alloc(task->task_pool, sizeof(*io)); + io->part = mp; + + if (!short_text) { + shcmd->basic.tag = ottery_rand_uint32(); + shcmd->basic.cmd = c; + shcmd->basic.version = RSPAMD_FUZZY_PLUGIN_VERSION; + + if (c != FUZZY_CHECK) { + shcmd->basic.flag = flag; + shcmd->basic.value = weight; + } + io->tag = shcmd->basic.tag; + memcpy(&io->cmd, &shcmd->basic, sizeof(io->cmd)); + } + else { + cmd->tag = ottery_rand_uint32(); + cmd->cmd = c; + cmd->version = RSPAMD_FUZZY_PLUGIN_VERSION; + + if (c != FUZZY_CHECK) { + cmd->flag = flag; + cmd->value = weight; + } + io->tag = cmd->tag; + memcpy(&io->cmd, cmd, sizeof(io->cmd)); + } + + io->flags = 0; + + + if (rule->peer_key) { + /* Encrypt data */ + if (!short_text) { + fuzzy_encrypt_cmd(rule, &encshcmd->hdr, (guchar *) shcmd, + sizeof(*shcmd) + additional_length); + io->io.iov_base = encshcmd; + io->io.iov_len = sizeof(*encshcmd) + additional_length; + } + else { + fuzzy_encrypt_cmd(rule, &enccmd->hdr, (guchar *) cmd, + sizeof(*cmd) + additional_length); + io->io.iov_base = enccmd; + io->io.iov_len = sizeof(*enccmd) + additional_length; + } + } + else { + + if (!short_text) { + io->io.iov_base = shcmd; + io->io.iov_len = sizeof(*shcmd) + additional_length; + } + else { + io->io.iov_base = cmd; + io->io.iov_len = sizeof(*cmd) + additional_length; + } + } + + return io; +} + +#if 0 +static struct fuzzy_cmd_io * +fuzzy_cmd_from_image_part (struct fuzzy_rule *rule, + int c, + gint flag, + guint32 weight, + struct rspamd_task *task, + struct rspamd_image *img, + struct rspamd_mime_part *mp) +{ + struct rspamd_fuzzy_shingle_cmd *shcmd; + struct rspamd_fuzzy_encrypted_shingle_cmd *encshcmd; + struct fuzzy_cmd_io *io; + struct rspamd_shingle *sh; + struct rspamd_cached_shingles *cached; + + cached = fuzzy_cmd_get_cached (rule, task, mp); + + if (cached) { + /* Copy cached */ + encshcmd = rspamd_mempool_alloc0 (task->task_pool, sizeof (*encshcmd)); + shcmd = 
&encshcmd->cmd; + memcpy (&shcmd->sgl, cached->sh, sizeof (struct rspamd_shingle)); + memcpy (shcmd->basic.digest, cached->digest, + sizeof (cached->digest)); + shcmd->basic.shingles_count = RSPAMD_SHINGLE_SIZE; + } + else { + encshcmd = rspamd_mempool_alloc0 (task->task_pool, sizeof (*encshcmd)); + shcmd = &encshcmd->cmd; + + /* + * Generate shingles + */ + sh = rspamd_shingles_from_image (img->dct, + rule->shingles_key->str, task->task_pool, + rspamd_shingles_default_filter, NULL, + rule->alg); + if (sh != NULL) { + memcpy (&shcmd->sgl, sh->hashes, sizeof (shcmd->sgl)); + shcmd->basic.shingles_count = RSPAMD_SHINGLE_SIZE; +#if 0 + for (unsigned int i = 0; i < RSPAMD_SHINGLE_SIZE; i ++) { + msg_err ("shingle %d: %L", i, sh->hashes[i]); + } +#endif + } + + rspamd_cryptobox_hash (shcmd->basic.digest, + (const guchar *)img->dct, RSPAMD_DCT_LEN / NBBY, + rule->hash_key->str, rule->hash_key->len); + + msg_debug_task ("loading shingles of type %s with key %*xs", + rule->algorithm_str, + 16, rule->shingles_key->str); + + /* + * We always save encrypted command as it can handle both + * encrypted and unencrypted requests. + * + * Since it is copied when obtained from the cache, it is safe to use + * it this way. 
+ */ + cached = rspamd_mempool_alloc (task->task_pool, sizeof (*cached)); + cached->sh = sh; + memcpy (cached->digest, shcmd->basic.digest, sizeof (cached->digest)); + fuzzy_cmd_set_cached (rule, task, mp, cached); + } + + shcmd->basic.tag = ottery_rand_uint32 (); + shcmd->basic.cmd = c; + shcmd->basic.version = RSPAMD_FUZZY_PLUGIN_VERSION; + + if (c != FUZZY_CHECK) { + shcmd->basic.flag = flag; + shcmd->basic.value = weight; + } + + io = rspamd_mempool_alloc (task->task_pool, sizeof (*io)); + io->part = mp; + io->tag = shcmd->basic.tag; + io->flags = FUZZY_CMD_FLAG_IMAGE; + memcpy (&io->cmd, &shcmd->basic, sizeof (io->cmd)); + + if (rule->peer_key) { + /* Encrypt data */ + fuzzy_encrypt_cmd (rule, &encshcmd->hdr, (guchar *) shcmd, sizeof (*shcmd)); + io->io.iov_base = encshcmd; + io->io.iov_len = sizeof (*encshcmd); + } + else { + io->io.iov_base = shcmd; + io->io.iov_len = sizeof (*shcmd); + } + + return io; +} +#endif + +static struct fuzzy_cmd_io * +fuzzy_cmd_from_data_part(struct fuzzy_rule *rule, + int c, + gint flag, + guint32 weight, + struct rspamd_task *task, + guchar digest[rspamd_cryptobox_HASHBYTES], + struct rspamd_mime_part *mp) +{ + struct rspamd_fuzzy_cmd *cmd; + struct rspamd_fuzzy_encrypted_cmd *enccmd = NULL; + struct fuzzy_cmd_io *io; + guint additional_length; + guchar *additional_data; + + additional_length = fuzzy_cmd_extension_length(task, rule); + + if (rule->peer_key) { + enccmd = rspamd_mempool_alloc0(task->task_pool, + sizeof(*enccmd) + additional_length); + cmd = &enccmd->cmd; + additional_data = ((guchar *) enccmd) + sizeof(*enccmd); + } + else { + cmd = rspamd_mempool_alloc0(task->task_pool, + sizeof(*cmd) + additional_length); + additional_data = ((guchar *) cmd) + sizeof(*cmd); + } + + cmd->cmd = c; + cmd->version = RSPAMD_FUZZY_PLUGIN_VERSION; + if (c != FUZZY_CHECK) { + cmd->flag = flag; + cmd->value = weight; + } + cmd->shingles_count = 0; + cmd->tag = ottery_rand_uint32(); + memcpy(cmd->digest, digest, sizeof(cmd->digest)); + + 
io = rspamd_mempool_alloc(task->task_pool, sizeof(*io)); + io->flags = 0; + io->tag = cmd->tag; + io->part = mp; + memcpy(&io->cmd, cmd, sizeof(io->cmd)); + + if (additional_length > 0) { + fuzzy_cmd_write_extensions(task, rule, additional_data, + additional_length); + } + + if (rule->peer_key) { + g_assert(enccmd != NULL); + fuzzy_encrypt_cmd(rule, &enccmd->hdr, (guchar *) cmd, + sizeof(*cmd) + additional_length); + io->io.iov_base = enccmd; + io->io.iov_len = sizeof(*enccmd) + additional_length; + } + else { + io->io.iov_base = cmd; + io->io.iov_len = sizeof(*cmd) + additional_length; + } + + return io; +} + +static gboolean +fuzzy_cmd_to_wire(gint fd, struct iovec *io) +{ + struct msghdr msg; + + memset(&msg, 0, sizeof(msg)); + msg.msg_iov = io; + msg.msg_iovlen = 1; + + while (sendmsg(fd, &msg, 0) == -1) { + if (errno == EINTR) { + continue; + } + return FALSE; + } + + return TRUE; +} + +static gboolean +fuzzy_cmd_vector_to_wire(gint fd, GPtrArray *v) +{ + guint i; + gboolean all_sent = TRUE, all_replied = TRUE; + struct fuzzy_cmd_io *io; + gboolean processed = FALSE; + + /* First try to resend unsent commands */ + for (i = 0; i < v->len; i++) { + io = g_ptr_array_index(v, i); + + if (io->flags & FUZZY_CMD_FLAG_REPLIED) { + continue; + } + + all_replied = FALSE; + + if (!(io->flags & FUZZY_CMD_FLAG_SENT)) { + if (!fuzzy_cmd_to_wire(fd, &io->io)) { + return FALSE; + } + processed = TRUE; + io->flags |= FUZZY_CMD_FLAG_SENT; + all_sent = FALSE; + } + } + + if (all_sent && !all_replied) { + /* Now try to resend each command in the vector */ + for (i = 0; i < v->len; i++) { + io = g_ptr_array_index(v, i); + + if (!(io->flags & FUZZY_CMD_FLAG_REPLIED)) { + io->flags &= ~FUZZY_CMD_FLAG_SENT; + } + } + + return fuzzy_cmd_vector_to_wire(fd, v); + } + + return processed; +} + +/* + * Read replies one-by-one and remove them from req array + */ +static const struct rspamd_fuzzy_reply * +fuzzy_process_reply(guchar **pos, gint *r, GPtrArray *req, + struct fuzzy_rule *rule, 
struct rspamd_fuzzy_cmd **pcmd, + struct fuzzy_cmd_io **pio) +{ + guchar *p = *pos; + gint remain = *r; + guint i, required_size; + struct fuzzy_cmd_io *io; + const struct rspamd_fuzzy_reply *rep; + struct rspamd_fuzzy_encrypted_reply encrep; + gboolean found = FALSE; + + if (rule->peer_key) { + required_size = sizeof(encrep); + } + else { + required_size = sizeof(*rep); + } + + if (remain <= 0 || (guint) remain < required_size) { + return NULL; + } + + if (rule->peer_key) { + memcpy(&encrep, p, sizeof(encrep)); + *pos += required_size; + *r -= required_size; + + /* Try to decrypt reply */ + rspamd_keypair_cache_process(rule->ctx->keypairs_cache, + rule->local_key, rule->peer_key); + + if (!rspamd_cryptobox_decrypt_nm_inplace((guchar *) &encrep.rep, + sizeof(encrep.rep), + encrep.hdr.nonce, + rspamd_pubkey_get_nm(rule->peer_key, rule->local_key), + encrep.hdr.mac, + rspamd_pubkey_alg(rule->peer_key))) { + msg_info("cannot decrypt reply"); + return NULL; + } + + /* Copy decrypted over the input wire */ + memcpy(p, &encrep.rep, sizeof(encrep.rep)); + } + else { + + *pos += required_size; + *r -= required_size; + } + + rep = (const struct rspamd_fuzzy_reply *) p; + /* + * Search for tag + */ + for (i = 0; i < req->len; i++) { + io = g_ptr_array_index(req, i); + + if (io->tag == rep->v1.tag) { + if (!(io->flags & FUZZY_CMD_FLAG_REPLIED)) { + io->flags |= FUZZY_CMD_FLAG_REPLIED; + + if (pcmd) { + *pcmd = &io->cmd; + } + + if (pio) { + *pio = io; + } + + return rep; + } + found = TRUE; + } + } + + if (!found) { + msg_info("unexpected tag: %ud", rep->v1.tag); + } + + return NULL; +} + +static void +fuzzy_insert_result(struct fuzzy_client_session *session, + const struct rspamd_fuzzy_reply *rep, + struct rspamd_fuzzy_cmd *cmd, + struct fuzzy_cmd_io *io, + guint flag) +{ + const gchar *symbol; + struct fuzzy_mapping *map; + struct rspamd_task *task = session->task; + double weight; + double nval; + guchar buf[2048]; + const gchar *type = "bin"; + struct fuzzy_client_result 
*res; + gboolean is_fuzzy = FALSE; + gchar hexbuf[rspamd_cryptobox_HASHBYTES * 2 + 1]; + /* Discriminate scores for small images */ + static const guint short_image_limit = 32 * 1024; + + /* Get mapping by flag */ + if ((map = + g_hash_table_lookup(session->rule->mappings, + GINT_TO_POINTER(rep->v1.flag))) == NULL) { + /* Default symbol and default weight */ + symbol = session->rule->symbol; + weight = session->rule->max_score; + } + else { + /* Get symbol and weight from map */ + symbol = map->symbol; + weight = map->weight; + } + + res = rspamd_mempool_alloc0(task->task_pool, sizeof(*res)); + res->prob = rep->v1.prob; + res->symbol = symbol; + /* + * Hash is assumed to be found if probability is more than 0.5 + * In that case `value` means number of matches + * Otherwise `value` means error code + */ + + nval = fuzzy_normalize(rep->v1.value, weight); + + if (io) { + if ((io->flags & FUZZY_CMD_FLAG_IMAGE)) { + if (!io->part || io->part->parsed_data.len <= short_image_limit) { + nval *= rspamd_normalize_probability(rep->v1.prob, 0.5); + } + + type = "img"; + res->type = FUZZY_RESULT_IMG; + } + else { + /* Calc real probability */ + nval *= sqrtf(rep->v1.prob); + + if (cmd->shingles_count > 0) { + type = "txt"; + res->type = FUZZY_RESULT_TXT; + } + else { + if (io->flags & FUZZY_CMD_FLAG_CONTENT) { + type = "content"; + res->type = FUZZY_RESULT_CONTENT; + } + else { + res->type = FUZZY_RESULT_BIN; + } + } + } + } + + res->score = nval; + + if (memcmp(rep->digest, cmd->digest, sizeof(rep->digest)) != 0) { + is_fuzzy = TRUE; + } + + if (map != NULL || !session->rule->skip_unknown) { + GList *fuzzy_var; + rspamd_fstring_t *hex_result; + gchar timebuf[64]; + struct tm tm_split; + + if (session->rule->skip_map) { + rspamd_encode_hex_buf(cmd->digest, sizeof(cmd->digest), + hexbuf, sizeof(hexbuf) - 1); + hexbuf[sizeof(hexbuf) - 1] = '\0'; + if (rspamd_match_hash_map(session->rule->skip_map, hexbuf, + sizeof(hexbuf) - 1)) { + return; + } + } + + 
rspamd_encode_hex_buf(rep->digest, sizeof(rep->digest), + hexbuf, sizeof(hexbuf) - 1); + hexbuf[sizeof(hexbuf) - 1] = '\0'; + + rspamd_gmtime(rep->ts, &tm_split); + rspamd_snprintf(timebuf, sizeof(timebuf), "%02d.%02d.%4d %02d:%02d:%02d GMT", + tm_split.tm_mday, + tm_split.tm_mon + 1, + tm_split.tm_year + 1900, + tm_split.tm_hour, tm_split.tm_min, tm_split.tm_sec); + + if (is_fuzzy) { + msg_notice_task( + "found fuzzy hash(%s) %s (%*xs requested) with weight: " + "%.2f, probability %.2f, in list: %s:%d%s; added on %s", + type, + hexbuf, + (gint) sizeof(cmd->digest), cmd->digest, + nval, + (gdouble) rep->v1.prob, + symbol, + rep->v1.flag, + map == NULL ? "(unknown)" : "", + timebuf); + } + else { + msg_notice_task( + "found exact fuzzy hash(%s) %s with weight: " + "%.2f, probability %.2f, in list: %s:%d%s; added on %s", + type, + hexbuf, + nval, + (gdouble) rep->v1.prob, + symbol, + rep->v1.flag, + map == NULL ? "(unknown)" : "", + timebuf); + } + + rspamd_snprintf(buf, + sizeof(buf), + "%d:%*s:%.2f:%s", + rep->v1.flag, + (gint) MIN(rspamd_fuzzy_hash_len * 2, sizeof(rep->digest) * 2), hexbuf, + rep->v1.prob, + type); + res->option = rspamd_mempool_strdup(task->task_pool, buf); + g_ptr_array_add(session->results, res); + + /* Store hex string in pool variable */ + hex_result = rspamd_mempool_alloc(task->task_pool, + sizeof(rspamd_fstring_t) + sizeof(hexbuf)); + memcpy(hex_result->str, hexbuf, sizeof(hexbuf)); + hex_result->len = sizeof(hexbuf) - 1; + hex_result->allocated = (gsize) -1; + fuzzy_var = rspamd_mempool_get_variable(task->task_pool, + RSPAMD_MEMPOOL_FUZZY_RESULT); + + if (fuzzy_var == NULL) { + fuzzy_var = g_list_prepend(NULL, hex_result); + rspamd_mempool_set_variable(task->task_pool, + RSPAMD_MEMPOOL_FUZZY_RESULT, fuzzy_var, + (rspamd_mempool_destruct_t) g_list_free); + } + else { + /* Not very efficient, but we don't really use it intensively */ + fuzzy_var = g_list_append(fuzzy_var, hex_result); + } + } +} + +static gint +fuzzy_check_try_read(struct 
fuzzy_client_session *session) +{ + struct rspamd_task *task; + const struct rspamd_fuzzy_reply *rep; + struct rspamd_fuzzy_cmd *cmd = NULL; + struct fuzzy_cmd_io *io = NULL; + gint r, ret; + guchar buf[2048], *p; + + task = session->task; + + if ((r = read(session->fd, buf, sizeof(buf) - 1)) == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { + return 0; + } + else { + return -1; + } + } + else { + p = buf; + + ret = 0; + + while ((rep = fuzzy_process_reply(&p, &r, + session->commands, session->rule, &cmd, &io)) != NULL) { + if (rep->v1.prob > 0.5) { + if (cmd->cmd == FUZZY_CHECK) { + fuzzy_insert_result(session, rep, cmd, io, rep->v1.flag); + } + else if (cmd->cmd == FUZZY_STAT) { + /* + * We store fuzzy stat in the following way: + * 1) We store fuzzy hashes as a hash of rspamd_fuzzy_stat_entry + * 2) We store the resulting hash table inside pool variable `fuzzy_stat` + */ + struct rspamd_fuzzy_stat_entry *pval; + GHashTable *stats_hash; + + stats_hash = (GHashTable *) rspamd_mempool_get_variable(task->task_pool, + RSPAMD_MEMPOOL_FUZZY_STAT); + + if (stats_hash == NULL) { + stats_hash = g_hash_table_new(rspamd_str_hash, rspamd_str_equal); + rspamd_mempool_set_variable(task->task_pool, RSPAMD_MEMPOOL_FUZZY_STAT, + stats_hash, + (rspamd_mempool_destruct_t) g_hash_table_destroy); + } + + pval = g_hash_table_lookup(stats_hash, session->rule->name); + + if (pval == NULL) { + pval = rspamd_mempool_alloc(task->task_pool, + sizeof(*pval)); + pval->name = rspamd_mempool_strdup(task->task_pool, + session->rule->name); + /* Safe, as pval->name is owned by the pool */ + g_hash_table_insert(stats_hash, (char *) pval->name, pval); + } + + pval->fuzzy_cnt = (((guint64) rep->v1.value) << 32) + rep->v1.flag; + } + } + else if (rep->v1.value == 403) { + rspamd_task_insert_result(task, "FUZZY_BLOCKED", 0.0, + session->rule->name); + } + else if (rep->v1.value == 401) { + if (cmd->cmd != FUZZY_CHECK) { + msg_info_task( + "fuzzy check error for %d: skipped by 
server", + rep->v1.flag); + } + } + else if (rep->v1.value != 0) { + msg_info_task( + "fuzzy check error for %d: unknown error (%d)", + rep->v1.flag, + rep->v1.value); + } + + ret = 1; + } + } + + return ret; +} + +static void +fuzzy_insert_metric_results(struct rspamd_task *task, struct fuzzy_rule *rule, + GPtrArray *results) +{ + struct fuzzy_client_result *res; + guint i; + gboolean seen_text_hash = FALSE, + seen_img_hash = FALSE, + seen_text_part = FALSE, + seen_long_text = FALSE; + gdouble prob_txt = 0.0, mult; + struct rspamd_mime_text_part *tp; + + /* About 5 words */ + static const unsigned int text_length_cutoff = 25; + + PTR_ARRAY_FOREACH(results, i, res) + { + if (res->type == FUZZY_RESULT_TXT) { + seen_text_hash = TRUE; + prob_txt = MAX(prob_txt, res->prob); + } + else if (res->type == FUZZY_RESULT_IMG) { + seen_img_hash = TRUE; + } + } + + if (task->message) { + PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, tp) + { + if (!IS_TEXT_PART_EMPTY(tp) && tp->utf_words != NULL && tp->utf_words->len > 0) { + seen_text_part = TRUE; + + if (tp->utf_stripped_text.magic == UTEXT_MAGIC) { + if (utext_isLengthExpensive(&tp->utf_stripped_text)) { + seen_long_text = + utext_nativeLength(&tp->utf_stripped_text) > + text_length_cutoff; + } + else { + /* Cannot directly calculate length */ + seen_long_text = + (tp->utf_stripped_content->len / 2) > + text_length_cutoff; + } + } + } + } + } + + PTR_ARRAY_FOREACH(results, i, res) + { + mult = 1.0; + + if (res->type == FUZZY_RESULT_IMG) { + if (!seen_text_hash) { + if (seen_long_text) { + mult *= 0.25; + } + else if (seen_text_part) { + /* We have some short text + image */ + mult *= 0.9; + } + /* Otherwise apply full score */ + } + else if (prob_txt < 0.75) { + /* Penalize sole image without matching text */ + if (prob_txt > 0.5) { + mult *= prob_txt; + } + else { + mult *= 0.5; /* cutoff */ + } + } + } + else if (res->type == FUZZY_RESULT_TXT) { + if (seen_img_hash) { + /* Slightly increase score */ + mult = 1.1; + 
} + } + + gdouble weight = res->score * mult; + + if (!isnan(rule->weight_threshold)) { + if (weight >= rule->weight_threshold) { + rspamd_task_insert_result_single(task, res->symbol, + weight, res->option); + } + else { + msg_info_task("%s is not added: weight=%.4f below threshold", + res->symbol, weight); + } + } + else { + rspamd_task_insert_result_single(task, res->symbol, + weight, res->option); + } + } +} + +static gboolean +fuzzy_check_session_is_completed(struct fuzzy_client_session *session) +{ + struct fuzzy_cmd_io *io; + guint nreplied = 0, i; + + rspamd_upstream_ok(session->server); + + for (i = 0; i < session->commands->len; i++) { + io = g_ptr_array_index(session->commands, i); + + if (io->flags & FUZZY_CMD_FLAG_REPLIED) { + nreplied++; + } + } + + if (nreplied == session->commands->len) { + fuzzy_insert_metric_results(session->task, session->rule, session->results); + + if (session->item) { + rspamd_symcache_item_async_dec_check(session->task, session->item, M); + } + + rspamd_session_remove_event(session->task->s, fuzzy_io_fin, session); + + return TRUE; + } + + return FALSE; +} + +/* Fuzzy check timeout callback */ +static void +fuzzy_check_timer_callback(gint fd, short what, void *arg) +{ + struct fuzzy_client_session *session = arg; + struct rspamd_task *task; + + task = session->task; + + /* We might be here because of other checks being slow */ + if (fuzzy_check_try_read(session) > 0) { + if (fuzzy_check_session_is_completed(session)) { + return; + } + } + + if (session->retransmits >= session->rule->retransmits) { + msg_err_task("got IO timeout with server %s(%s), after %d/%d retransmits", + rspamd_upstream_name(session->server), + rspamd_inet_address_to_string_pretty( + rspamd_upstream_addr_cur(session->server)), + session->retransmits, + session->rule->retransmits); + rspamd_upstream_fail(session->server, TRUE, "timeout"); + + if (session->item) { + rspamd_symcache_item_async_dec_check(session->task, session->item, M); + } + 
rspamd_session_remove_event(session->task->s, fuzzy_io_fin, session); + } + else { + /* Plan write event */ + rspamd_ev_watcher_reschedule(session->event_loop, + &session->ev, EV_READ | EV_WRITE); + session->retransmits++; + } +} + +/* Fuzzy check callback */ +static void +fuzzy_check_io_callback(gint fd, short what, void *arg) +{ + struct fuzzy_client_session *session = arg; + struct rspamd_task *task; + gint r; + + enum { + return_error = 0, + return_want_more, + return_finished + } ret = return_error; + + task = session->task; + + if ((what & EV_READ) || session->state == 1) { + /* Try to read reply */ + r = fuzzy_check_try_read(session); + + switch (r) { + case 0: + if (what & EV_READ) { + ret = return_want_more; + } + else { + if (what & EV_WRITE) { + /* Retransmit attempt */ + if (!fuzzy_cmd_vector_to_wire(fd, session->commands)) { + ret = return_error; + } + else { + session->state = 1; + ret = return_want_more; + } + } + else { + /* It is actually time out */ + fuzzy_check_timer_callback(fd, what, arg); + return; + } + } + break; + case 1: + ret = return_finished; + break; + default: + ret = return_error; + break; + } + } + else if (what & EV_WRITE) { + if (!fuzzy_cmd_vector_to_wire(fd, session->commands)) { + ret = return_error; + } + else { + session->state = 1; + ret = return_want_more; + } + } + else { + fuzzy_check_timer_callback(fd, what, arg); + return; + } + + if (ret == return_want_more) { + /* Processed write, switch to reading */ + rspamd_ev_watcher_reschedule(session->event_loop, + &session->ev, EV_READ); + } + else if (ret == return_error) { + /* Error state */ + msg_err_task("got error on IO with server %s(%s), on %s, %d, %s", + rspamd_upstream_name(session->server), + rspamd_inet_address_to_string_pretty( + rspamd_upstream_addr_cur(session->server)), + session->state == 1 ? 
"read" : "write", + errno, + strerror(errno)); + rspamd_upstream_fail(session->server, TRUE, strerror(errno)); + + if (session->item) { + rspamd_symcache_item_async_dec_check(session->task, session->item, M); + } + + rspamd_session_remove_event(session->task->s, fuzzy_io_fin, session); + } + else { + /* Read something from network */ + if (!fuzzy_check_session_is_completed(session)) { + /* Need to read more */ + rspamd_ev_watcher_reschedule(session->event_loop, + &session->ev, EV_READ); + } + } +} + + +static void +fuzzy_controller_lua_fin(void *ud) +{ + struct fuzzy_learn_session *session = ud; + + (*session->saved)--; + + rspamd_ev_watcher_stop(session->event_loop, &session->ev); + close(session->fd); +} + +/* Controller IO */ + +static void +fuzzy_controller_timer_callback(gint fd, short what, void *arg) +{ + struct fuzzy_learn_session *session = arg; + struct rspamd_task *task; + + task = session->task; + + if (session->retransmits >= session->rule->retransmits) { + rspamd_upstream_fail(session->server, TRUE, "timeout"); + msg_err_task_check("got IO timeout with server %s(%s), " + "after %d/%d retransmits", + rspamd_upstream_name(session->server), + rspamd_inet_address_to_string_pretty( + rspamd_upstream_addr_cur(session->server)), + session->retransmits, + session->rule->retransmits); + + if (session->session) { + rspamd_session_remove_event(session->session, fuzzy_controller_lua_fin, + session); + } + else { + if (session->http_entry) { + rspamd_controller_send_error(session->http_entry, + 500, "IO timeout with fuzzy storage"); + } + + if (*session->saved > 0) { + (*session->saved)--; + if (*session->saved == 0) { + if (session->http_entry) { + rspamd_task_free(session->task); + } + + session->task = NULL; + } + } + + if (session->http_entry) { + rspamd_http_connection_unref(session->http_entry->conn); + } + + rspamd_ev_watcher_stop(session->event_loop, + &session->ev); + close(session->fd); + } + } + else { + /* Plan write event */ + 
rspamd_ev_watcher_reschedule(session->event_loop, + &session->ev, EV_READ | EV_WRITE); + session->retransmits++; + } +} + +static void +fuzzy_controller_io_callback(gint fd, short what, void *arg) +{ + struct fuzzy_learn_session *session = arg; + const struct rspamd_fuzzy_reply *rep; + struct fuzzy_mapping *map; + struct rspamd_task *task; + guchar buf[2048], *p; + struct fuzzy_cmd_io *io; + struct rspamd_fuzzy_cmd *cmd = NULL; + const gchar *symbol, *ftype; + gint r; + enum { + return_error = 0, + return_want_more, + return_finished + } ret = return_want_more; + guint i, nreplied; + const gchar *op = "process"; + + task = session->task; + + if (what & EV_READ) { + if ((r = read(fd, buf, sizeof(buf) - 1)) == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { + rspamd_ev_watcher_reschedule(session->event_loop, + &session->ev, EV_READ); + return; + } + + msg_info_task("cannot process fuzzy hash for message: %s", + strerror(errno)); + session->err.error_message = "read socket error"; + session->err.error_code = errno; + + ret = return_error; + } + else { + p = buf; + ret = return_want_more; + + while ((rep = fuzzy_process_reply(&p, &r, + session->commands, session->rule, &cmd, &io)) != NULL) { + if ((map = + g_hash_table_lookup(session->rule->mappings, + GINT_TO_POINTER(rep->v1.flag))) == NULL) { + /* Default symbol and default weight */ + symbol = session->rule->symbol; + } + else { + /* Get symbol and weight from map */ + symbol = map->symbol; + } + + ftype = "bin"; + + if (io) { + if ((io->flags & FUZZY_CMD_FLAG_IMAGE)) { + ftype = "img"; + } + else if (io->flags & FUZZY_CMD_FLAG_CONTENT) { + ftype = "content"; + } + else if (cmd->shingles_count > 0) { + ftype = "txt"; + } + + if (io->cmd.cmd == FUZZY_WRITE) { + op = "added"; + } + else if (io->cmd.cmd == FUZZY_DEL) { + op = "deleted"; + } + } + + if (rep->v1.prob > 0.5) { + msg_info_task("%s fuzzy hash (%s) %*xs, list: %s:%d for " + "message <%s>", + op, + ftype, + (gint) sizeof(rep->digest), 
rep->digest, + symbol, + rep->v1.flag, + MESSAGE_FIELD_CHECK(session->task, message_id)); + } + else { + if (rep->v1.value == 401) { + msg_info_task( + "fuzzy hash (%s) for message cannot be %s" + "<%s>, %*xs, " + "list %s:%d, skipped by server", + ftype, + op, + MESSAGE_FIELD_CHECK(session->task, message_id), + (gint) sizeof(rep->digest), rep->digest, + symbol, + rep->v1.flag); + + session->err.error_message = "fuzzy hash is skipped"; + session->err.error_code = rep->v1.value; + } + else { + msg_info_task( + "fuzzy hash (%s) for message cannot be %s" + "<%s>, %*xs, " + "list %s:%d, error: %d", + ftype, + op, + MESSAGE_FIELD_CHECK(session->task, message_id), + (gint) sizeof(rep->digest), rep->digest, + symbol, + rep->v1.flag, + rep->v1.value); + + session->err.error_message = "process fuzzy error"; + session->err.error_code = rep->v1.value; + } + + ret = return_finished; + } + } + + nreplied = 0; + + for (i = 0; i < session->commands->len; i++) { + io = g_ptr_array_index(session->commands, i); + + if (io->flags & FUZZY_CMD_FLAG_REPLIED) { + nreplied++; + } + } + + if (nreplied == session->commands->len) { + ret = return_finished; + } + } + } + else if (what & EV_WRITE) { + /* Send commands to storage */ + if (!fuzzy_cmd_vector_to_wire(fd, session->commands)) { + session->err.error_message = "write socket error"; + session->err.error_code = errno; + ret = return_error; + } + } + else { + fuzzy_controller_timer_callback(fd, what, arg); + + return; + } + + if (ret == return_want_more) { + rspamd_ev_watcher_reschedule(session->event_loop, + &session->ev, EV_READ); + + return; + } + else if (ret == return_error) { + msg_err_task("got error in IO with server %s(%s), %d, %s", + rspamd_upstream_name(session->server), + rspamd_inet_address_to_string_pretty( + rspamd_upstream_addr_cur(session->server)), + errno, strerror(errno)); + rspamd_upstream_fail(session->server, FALSE, strerror(errno)); + } + + /* + * XXX: actually, we check merely a single reply, which is not 
correct... + * XXX: when we send a command, we do not check if *all* commands have been + * written + * XXX: please, please, change this code some day + */ + + if (session->session == NULL) { + (*session->saved)--; + + if (session->http_entry) { + rspamd_http_connection_unref(session->http_entry->conn); + } + + rspamd_ev_watcher_stop(session->event_loop, &session->ev); + close(session->fd); + + if (*session->saved == 0) { + goto cleanup; + } + } + else { + /* Lua handler */ + rspamd_session_remove_event(session->session, fuzzy_controller_lua_fin, session); + } + + return; + +cleanup: + /* + * When we send learn commands to fuzzy storages, this code is executed + * *once* when we have queried all storages. We also don't know which + * storage has been failed. + * + * Therefore, we cleanup sessions earlier and actually this code is wrong. + */ + + if (session->err.error_code != 0) { + if (session->http_entry) { + rspamd_controller_send_error(session->http_entry, + session->err.error_code, session->err.error_message); + } + } + else { + rspamd_upstream_ok(session->server); + + if (session->http_entry) { + ucl_object_t *reply, *hashes; + gchar hexbuf[rspamd_cryptobox_HASHBYTES * 2 + 1]; + + reply = ucl_object_typed_new(UCL_OBJECT); + + ucl_object_insert_key(reply, ucl_object_frombool(true), + "success", 0, false); + hashes = ucl_object_typed_new(UCL_ARRAY); + + for (i = 0; i < session->commands->len; i++) { + io = g_ptr_array_index(session->commands, i); + + rspamd_snprintf(hexbuf, sizeof(hexbuf), "%*xs", + (gint) sizeof(io->cmd.digest), io->cmd.digest); + ucl_array_append(hashes, ucl_object_fromstring(hexbuf)); + } + + ucl_object_insert_key(reply, hashes, "hashes", 0, false); + rspamd_controller_send_ucl(session->http_entry, reply); + ucl_object_unref(reply); + } + } + + if (session->task != NULL) { + if (session->http_entry) { + rspamd_task_free(session->task); + } + + session->task = NULL; + } +} + +static GPtrArray * +fuzzy_generate_commands(struct rspamd_task 
*task, struct fuzzy_rule *rule, + gint c, gint flag, guint32 value, guint flags) +{ + struct rspamd_mime_text_part *part; + struct rspamd_mime_part *mime_part; + struct rspamd_image *image; + struct fuzzy_cmd_io *io, *cur; + guint i, j; + GPtrArray *res = NULL; + gboolean check_part, fuzzy_check; + + if (c == FUZZY_STAT) { + res = g_ptr_array_sized_new(1); + + io = fuzzy_cmd_stat(rule, c, flag, value, task->task_pool); + if (io) { + g_ptr_array_add(res, io); + } + + goto end; + } + else if (c == FUZZY_PING) { + res = g_ptr_array_sized_new(1); + + io = fuzzy_cmd_ping(rule, task->task_pool); + if (io) { + g_ptr_array_add(res, io); + } + + goto end; + } + + if (task->message == NULL) { + goto end; + } + + res = g_ptr_array_sized_new(MESSAGE_FIELD(task, parts)->len + 1); + + PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, parts), i, mime_part) + { + check_part = FALSE; + fuzzy_check = FALSE; + + if (fuzzy_rule_check_mimepart(task, rule, mime_part, &check_part, + &fuzzy_check)) { + io = NULL; + + if (check_part) { + if (mime_part->part_type == RSPAMD_MIME_PART_TEXT && + !(flags & FUZZY_CHECK_FLAG_NOTEXT)) { + part = mime_part->specific.txt; + + io = fuzzy_cmd_from_text_part(task, rule, + c, + flag, + value, + !fuzzy_check, + part, + mime_part); + } + else if (mime_part->part_type == RSPAMD_MIME_PART_IMAGE && + !(flags & FUZZY_CHECK_FLAG_NOIMAGES)) { + image = mime_part->specific.img; + + io = fuzzy_cmd_from_data_part(rule, c, flag, value, + task, + image->parent->digest, + mime_part); + io->flags |= FUZZY_CMD_FLAG_IMAGE; + } + else if (mime_part->part_type == RSPAMD_MIME_PART_CUSTOM_LUA) { + const struct rspamd_lua_specific_part *lua_spec; + + lua_spec = &mime_part->specific.lua_specific; + + if (lua_spec->type == RSPAMD_LUA_PART_TABLE) { + lua_State *L = (lua_State *) task->cfg->lua_state; + gint old_top; + + old_top = lua_gettop(L); + /* Push table */ + lua_rawgeti(L, LUA_REGISTRYINDEX, lua_spec->cbref); + lua_pushstring(L, "fuzzy_hashes"); + lua_gettable(L, -2); + + if 
(lua_type(L, -1) == LUA_TTABLE) { + gint tbl_pos = lua_gettop(L); + + for (lua_pushnil(L); lua_next(L, tbl_pos); + lua_pop(L, 1)) { + const gchar *h = NULL; + gsize hlen = 0; + + if (lua_isstring(L, -1)) { + h = lua_tolstring(L, -1, &hlen); + } + else if (lua_type(L, -1) == LUA_TUSERDATA) { + struct rspamd_lua_text *t; + + t = lua_check_text(L, -1); + + if (t) { + h = t->start; + hlen = t->len; + } + } + + if (hlen == rspamd_cryptobox_HASHBYTES) { + io = fuzzy_cmd_from_data_part(rule, c, + flag, value, + task, + (guchar *) h, + mime_part); + + if (io) { + io->flags |= FUZZY_CMD_FLAG_CONTENT; + g_ptr_array_add(res, io); + } + } + } + } + + lua_settop(L, old_top); + + /* + * Add part itself as well + */ + io = fuzzy_cmd_from_data_part(rule, c, + flag, value, + task, + mime_part->digest, + mime_part); + } + } + else { + io = fuzzy_cmd_from_data_part(rule, c, flag, value, + task, + mime_part->digest, mime_part); + } + + if (io) { + gboolean skip_existing = FALSE; + + PTR_ARRAY_FOREACH(res, j, cur) + { + if (memcmp(cur->cmd.digest, io->cmd.digest, + sizeof(io->cmd.digest)) == 0) { + skip_existing = TRUE; + break; + } + } + + if (!skip_existing) { + g_ptr_array_add(res, io); + } + } + } + } + } + +end: + if (res && res->len == 0) { + g_ptr_array_free(res, TRUE); + + return NULL; + } + + return res; +} + + +static inline void +register_fuzzy_client_call(struct rspamd_task *task, + struct fuzzy_rule *rule, + GPtrArray *commands) +{ + struct fuzzy_client_session *session; + struct upstream *selected; + rspamd_inet_addr_t *addr; + gint sock; + + if (!rspamd_session_blocked(task->s)) { + /* Get upstream */ + selected = rspamd_upstream_get(rule->servers, RSPAMD_UPSTREAM_ROUND_ROBIN, + NULL, 0); + if (selected) { + addr = rspamd_upstream_addr_next(selected); + if ((sock = rspamd_inet_address_connect(addr, SOCK_DGRAM, TRUE)) == -1) { + msg_warn_task("cannot connect to %s(%s), %d, %s", + rspamd_upstream_name(selected), + rspamd_inet_address_to_string_pretty(addr), + errno, + 
strerror(errno)); + rspamd_upstream_fail(selected, TRUE, strerror(errno)); + g_ptr_array_free(commands, TRUE); + } + else { + /* Create session for a socket */ + session = + rspamd_mempool_alloc0(task->task_pool, + sizeof(struct fuzzy_client_session)); + session->state = 0; + session->commands = commands; + session->task = task; + session->fd = sock; + session->server = selected; + session->rule = rule; + session->results = g_ptr_array_sized_new(32); + session->event_loop = task->event_loop; + + rspamd_ev_watcher_init(&session->ev, + sock, + EV_WRITE, + fuzzy_check_io_callback, + session); + rspamd_ev_watcher_start(session->event_loop, &session->ev, + rule->io_timeout); + + rspamd_session_add_event(task->s, fuzzy_io_fin, session, M); + session->item = rspamd_symcache_get_cur_item(task); + + if (session->item) { + rspamd_symcache_item_async_inc(task, session->item, M); + } + } + } + } +} + +/* This callback is called when we check message in fuzzy hashes storage */ +static void +fuzzy_symbol_callback(struct rspamd_task *task, + struct rspamd_symcache_dynamic_item *item, + void *unused) +{ + struct fuzzy_rule *rule; + guint i; + GPtrArray *commands; + struct fuzzy_ctx *fuzzy_module_ctx = fuzzy_get_context(task->cfg); + + if (!fuzzy_module_ctx->enabled) { + rspamd_symcache_finalize_item(task, item); + + return; + } + + /* Check whitelist */ + if (fuzzy_module_ctx->whitelist) { + if (rspamd_match_radix_map_addr(fuzzy_module_ctx->whitelist, + task->from_addr) != NULL) { + msg_info_task("<%s>, address %s is whitelisted, skip fuzzy check", + MESSAGE_FIELD(task, message_id), + rspamd_inet_address_to_string(task->from_addr)); + rspamd_symcache_finalize_item(task, item); + + return; + } + } + + rspamd_symcache_item_async_inc(task, item, M); + + PTR_ARRAY_FOREACH(fuzzy_module_ctx->fuzzy_rules, i, rule) + { + commands = fuzzy_generate_commands(task, rule, FUZZY_CHECK, 0, 0, 0); + + if (commands != NULL) { + register_fuzzy_client_call(task, rule, commands); + } + } + + 
rspamd_symcache_item_async_dec_check(task, item, M); +} + +void fuzzy_stat_command(struct rspamd_task *task) +{ + struct fuzzy_rule *rule; + guint i; + GPtrArray *commands; + struct fuzzy_ctx *fuzzy_module_ctx = fuzzy_get_context(task->cfg); + + if (!fuzzy_module_ctx->enabled) { + return; + } + + PTR_ARRAY_FOREACH(fuzzy_module_ctx->fuzzy_rules, i, rule) + { + commands = fuzzy_generate_commands(task, rule, FUZZY_STAT, 0, 0, 0); + if (commands != NULL) { + register_fuzzy_client_call(task, rule, commands); + } + } +} + +static inline gint +register_fuzzy_controller_call(struct rspamd_http_connection_entry *entry, + struct fuzzy_rule *rule, + struct rspamd_task *task, + GPtrArray *commands, + gint *saved) +{ + struct fuzzy_learn_session *s; + struct upstream *selected; + rspamd_inet_addr_t *addr; + struct rspamd_controller_session *session = entry->ud; + gint sock; + gint ret = -1; + + /* Get upstream */ + + while ((selected = rspamd_upstream_get_forced(rule->servers, + RSPAMD_UPSTREAM_SEQUENTIAL, NULL, 0))) { + /* Create UDP socket */ + addr = rspamd_upstream_addr_next(selected); + + if ((sock = rspamd_inet_address_connect(addr, + SOCK_DGRAM, TRUE)) == -1) { + msg_warn_task("cannot connect to fuzzy storage %s (%s rule): %s", + rspamd_inet_address_to_string_pretty(addr), + rule->name, + strerror(errno)); + rspamd_upstream_fail(selected, TRUE, strerror(errno)); + } + else { + s = + rspamd_mempool_alloc0(session->pool, + sizeof(struct fuzzy_learn_session)); + + s->task = task; + s->commands = commands; + s->http_entry = entry; + s->server = selected; + s->saved = saved; + s->fd = sock; + s->rule = rule; + s->event_loop = task->event_loop; + /* We ref connection to avoid freeing before we process fuzzy rule */ + rspamd_http_connection_ref(entry->conn); + + rspamd_ev_watcher_init(&s->ev, + sock, + EV_WRITE, + fuzzy_controller_io_callback, + s); + rspamd_ev_watcher_start(s->event_loop, &s->ev, rule->io_timeout); + + (*saved)++; + ret = 1; + } + } + + return ret; +} + 
+static void +fuzzy_process_handler(struct rspamd_http_connection_entry *conn_ent, + struct rspamd_http_message *msg, gint cmd, gint value, gint flag, + struct fuzzy_ctx *ctx, gboolean is_hash, guint flags) +{ + struct fuzzy_rule *rule; + struct rspamd_controller_session *session = conn_ent->ud; + struct rspamd_task *task, **ptask; + gboolean processed = FALSE, skip = FALSE; + gint res = 0; + guint i; + GPtrArray *commands; + lua_State *L; + gint r, *saved, rules = 0, err_idx; + struct fuzzy_ctx *fuzzy_module_ctx; + + /* Prepare task */ + task = rspamd_task_new(session->wrk, session->cfg, NULL, + session->lang_det, conn_ent->rt->event_loop, FALSE); + task->cfg = ctx->cfg; + saved = rspamd_mempool_alloc0(session->pool, sizeof(gint)); + fuzzy_module_ctx = fuzzy_get_context(ctx->cfg); + + if (!is_hash) { + /* Allocate message from string */ + /* XXX: what about encrypted messages ? */ + task->msg.begin = msg->body_buf.begin; + task->msg.len = msg->body_buf.len; + + r = rspamd_message_parse(task); + + if (r == -1) { + msg_warn_task("<%s>: cannot process message for fuzzy", + MESSAGE_FIELD(task, message_id)); + rspamd_task_free(task); + rspamd_controller_send_error(conn_ent, 400, + "Message processing error"); + + return; + } + + rspamd_message_process(task); + } + + PTR_ARRAY_FOREACH(fuzzy_module_ctx->fuzzy_rules, i, rule) + { + if (rule->read_only) { + continue; + } + + /* Check for flag */ + if (g_hash_table_lookup(rule->mappings, + GINT_TO_POINTER(flag)) == NULL) { + msg_info_task("skip rule %s as it has no flag %d defined" + " false", + rule->name, flag); + continue; + } + + /* Check learn condition */ + if (rule->learn_condition_cb != -1) { + skip = FALSE; + L = session->cfg->lua_state; + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + lua_rawgeti(L, LUA_REGISTRYINDEX, rule->learn_condition_cb); + ptask = lua_newuserdata(L, sizeof(struct rspamd_task *)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + + if (lua_pcall(L, 
1, LUA_MULTRET, err_idx) != 0) { + msg_err_task("call to fuzzy learn condition failed: %s", + lua_tostring(L, -1)); + } + else { + if (lua_gettop(L) > err_idx + 1) { + /* 2 return values */ + skip = !(lua_toboolean(L, err_idx + 1)); + + if (lua_isnumber(L, err_idx + 2)) { + msg_info_task("learn condition changed flag from %d to " + "%d", + flag, + (gint) lua_tonumber(L, err_idx + 2)); + flag = lua_tonumber(L, err_idx + 2); + } + } + else { + if (lua_isboolean(L, err_idx + 1)) { + skip = !(lua_toboolean(L, err_idx + 1)); + } + else { + msg_warn_task("set skip for rule %s as its condition " + "callback returned" + " a valid boolean", + rule->name); + skip = TRUE; + } + } + } + + /* Result + error function */ + lua_settop(L, err_idx - 1); + + if (skip) { + msg_info_task("skip rule %s by condition callback", + rule->name); + continue; + } + } + + rules++; + + res = 0; + + if (is_hash) { + GPtrArray *args; + const rspamd_ftok_t *arg; + guint j; + + args = rspamd_http_message_find_header_multiple(msg, "Hash"); + + if (args) { + struct fuzzy_cmd_io *io; + commands = g_ptr_array_sized_new(args->len); + + for (j = 0; j < args->len; j++) { + arg = g_ptr_array_index(args, j); + io = fuzzy_cmd_hash(rule, cmd, arg, flag, value, + task->task_pool); + + if (io) { + g_ptr_array_add(commands, io); + } + } + + res = register_fuzzy_controller_call(conn_ent, + rule, + task, + commands, + saved); + rspamd_mempool_add_destructor(task->task_pool, + rspamd_ptr_array_free_hard, commands); + g_ptr_array_free(args, TRUE); + } + else { + rspamd_controller_send_error(conn_ent, 400, + "No hash defined"); + rspamd_task_free(task); + return; + } + } + else { + commands = fuzzy_generate_commands(task, rule, cmd, flag, value, + flags); + if (commands != NULL) { + res = register_fuzzy_controller_call(conn_ent, + rule, + task, + commands, + saved); + rspamd_mempool_add_destructor(task->task_pool, + rspamd_ptr_array_free_hard, commands); + } + } + + if (res > 0) { + processed = TRUE; + } + } + + if 
(res == -1) { + if (!processed) { + msg_warn_task("cannot send fuzzy request: %s", + strerror(errno)); + rspamd_controller_send_error(conn_ent, 400, "Message sending error"); + rspamd_task_free(task); + + return; + } + else { + /* Some rules failed and some rules are OK */ + msg_warn_task("some rules are not processed, but we still sent this request"); + } + } + else if (!processed) { + if (rules) { + msg_warn_task("no content to generate fuzzy"); + rspamd_controller_send_error(conn_ent, 404, + "No content to generate fuzzy for flag %d", flag); + } + else { + if (skip) { + rspamd_controller_send_error(conn_ent, 403, + "Message is conditionally skipped for flag %d", flag); + } + else { + msg_warn_task("no fuzzy rules found for flag %d", flag); + rspamd_controller_send_error(conn_ent, 404, + "No fuzzy rules matched for flag %d", flag); + } + } + rspamd_task_free(task); + } +} + +static int +fuzzy_controller_handler(struct rspamd_http_connection_entry *conn_ent, + struct rspamd_http_message *msg, struct module_ctx *ctx, gint cmd, + gboolean is_hash) +{ + const rspamd_ftok_t *arg; + glong value = 1, flag = 0, send_flags = 0; + struct fuzzy_ctx *fuzzy_module_ctx = (struct fuzzy_ctx *) ctx; + + if (!fuzzy_module_ctx->enabled) { + msg_err("fuzzy_check module is not enabled"); + rspamd_controller_send_error(conn_ent, 500, "Module disabled"); + return 0; + } + + if (fuzzy_module_ctx->fuzzy_rules == NULL) { + msg_err("fuzzy_check module has no rules defined"); + rspamd_controller_send_error(conn_ent, 500, "Module has no rules"); + return 0; + } + + /* Get size */ + arg = rspamd_http_message_find_header(msg, "Weight"); + if (arg) { + errno = 0; + + if (!rspamd_strtol(arg->begin, arg->len, &value)) { + msg_info("error converting numeric argument %T", arg); + } + } + + arg = rspamd_http_message_find_header(msg, "Flag"); + if (arg) { + errno = 0; + + if (!rspamd_strtol(arg->begin, arg->len, &flag)) { + msg_info("error converting numeric argument %T", arg); + flag = 0; + } + } + 
else { + flag = 0; + arg = rspamd_http_message_find_header(msg, "Symbol"); + + /* Search flag by symbol */ + if (arg) { + struct fuzzy_rule *rule; + guint i; + GHashTableIter it; + gpointer k, v; + struct fuzzy_mapping *map; + + PTR_ARRAY_FOREACH(fuzzy_module_ctx->fuzzy_rules, i, rule) + { + if (flag != 0) { + break; + } + + g_hash_table_iter_init(&it, rule->mappings); + + while (g_hash_table_iter_next(&it, &k, &v)) { + map = v; + + if (strlen(map->symbol) == arg->len && + rspamd_lc_cmp(map->symbol, arg->begin, arg->len) == 0) { + flag = map->fuzzy_flag; + break; + } + } + } + } + } + + if (flag == 0) { + msg_err("no flag defined to learn fuzzy"); + rspamd_controller_send_error(conn_ent, 404, "Unknown or missing flag"); + return 0; + } + + arg = rspamd_http_message_find_header(msg, "Skip-Images"); + if (arg) { + send_flags |= FUZZY_CHECK_FLAG_NOIMAGES; + } + + arg = rspamd_http_message_find_header(msg, "Skip-Attachments"); + if (arg) { + send_flags |= FUZZY_CHECK_FLAG_NOATTACHMENTS; + } + + arg = rspamd_http_message_find_header(msg, "Skip-Text"); + if (arg) { + send_flags |= FUZZY_CHECK_FLAG_NOTEXT; + } + + fuzzy_process_handler(conn_ent, msg, cmd, value, flag, + (struct fuzzy_ctx *) ctx, is_hash, send_flags); + + return 0; +} + +static inline gint +fuzzy_check_send_lua_learn(struct fuzzy_rule *rule, + struct rspamd_task *task, + GPtrArray *commands, + gint *saved) +{ + struct fuzzy_learn_session *s; + struct upstream *selected; + rspamd_inet_addr_t *addr; + gint sock; + gint ret = -1; + + /* Get upstream */ + if (!rspamd_session_blocked(task->s)) { + while ((selected = rspamd_upstream_get(rule->servers, + RSPAMD_UPSTREAM_SEQUENTIAL, NULL, 0))) { + /* Create UDP socket */ + addr = rspamd_upstream_addr_next(selected); + + if ((sock = rspamd_inet_address_connect(addr, + SOCK_DGRAM, TRUE)) == -1) { + rspamd_upstream_fail(selected, TRUE, strerror(errno)); + } + else { + s = + rspamd_mempool_alloc0(task->task_pool, + sizeof(struct fuzzy_learn_session)); + s->task = 
task; + s->commands = commands; + s->http_entry = NULL; + s->server = selected; + s->saved = saved; + s->fd = sock; + s->rule = rule; + s->session = task->s; + s->event_loop = task->event_loop; + + rspamd_ev_watcher_init(&s->ev, + sock, + EV_WRITE, + fuzzy_controller_io_callback, + s); + rspamd_ev_watcher_start(s->event_loop, &s->ev, + rule->io_timeout); + + rspamd_session_add_event(task->s, + fuzzy_controller_lua_fin, + s, + M); + + (*saved)++; + ret = 1; + } + } + } + + return ret; +} + +static gboolean +fuzzy_check_lua_process_learn(struct rspamd_task *task, + gint cmd, gint value, gint flag, guint send_flags) +{ + struct fuzzy_rule *rule; + gboolean processed = FALSE, res = TRUE; + guint i; + GPtrArray *commands; + gint *saved, rules = 0; + struct fuzzy_ctx *fuzzy_module_ctx = fuzzy_get_context(task->cfg); + + saved = rspamd_mempool_alloc0(task->task_pool, sizeof(gint)); + + PTR_ARRAY_FOREACH(fuzzy_module_ctx->fuzzy_rules, i, rule) + { + if (!res) { + break; + } + if (rule->read_only) { + continue; + } + + /* Check for flag */ + if (g_hash_table_lookup(rule->mappings, + GINT_TO_POINTER(flag)) == NULL) { + msg_info_task("skip rule %s as it has no flag %d defined" + " false", + rule->name, flag); + continue; + } + + rules++; + + res = 0; + commands = fuzzy_generate_commands(task, rule, cmd, flag, + value, send_flags); + + if (commands != NULL) { + res = fuzzy_check_send_lua_learn(rule, task, commands, + saved); + rspamd_mempool_add_destructor(task->task_pool, + rspamd_ptr_array_free_hard, commands); + } + + if (res) { + processed = TRUE; + } + } + + if (res == -1) { + msg_warn_task("cannot send fuzzy request: %s", + strerror(errno)); + } + else if (!processed) { + if (rules) { + msg_warn_task("no content to generate fuzzy"); + + return FALSE; + } + else { + msg_warn_task("no fuzzy rules found for flag %d", flag); + return FALSE; + } + } + + return TRUE; +} + +static gint +fuzzy_lua_learn_handler(lua_State *L) +{ + struct rspamd_task *task = lua_check_task(L, 1); 
+ + if (task == NULL) { + return luaL_error(L, "invalid arguments"); + } + + guint flag = 0, weight = 1, send_flags = 0; + const gchar *symbol; + struct fuzzy_ctx *fuzzy_module_ctx = fuzzy_get_context(task->cfg); + + if (lua_type(L, 2) == LUA_TNUMBER) { + flag = lua_tointeger(L, 2); + } + else if (lua_type(L, 2) == LUA_TSTRING) { + struct fuzzy_rule *rule; + guint i; + GHashTableIter it; + gpointer k, v; + struct fuzzy_mapping *map; + + symbol = lua_tostring(L, 2); + + PTR_ARRAY_FOREACH(fuzzy_module_ctx->fuzzy_rules, i, rule) + { + if (flag != 0) { + break; + } + + g_hash_table_iter_init(&it, rule->mappings); + + while (g_hash_table_iter_next(&it, &k, &v)) { + map = v; + + if (g_ascii_strcasecmp(symbol, map->symbol) == 0) { + flag = map->fuzzy_flag; + break; + } + } + } + } + + if (flag == 0) { + return luaL_error(L, "bad flag"); + } + + if (lua_type(L, 3) == LUA_TNUMBER) { + weight = lua_tonumber(L, 3); + } + + if (lua_type(L, 4) == LUA_TTABLE) { + const gchar *sf; + + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + sf = lua_tostring(L, -1); + + if (sf) { + if (g_ascii_strcasecmp(sf, "noimages") == 0) { + send_flags |= FUZZY_CHECK_FLAG_NOIMAGES; + } + else if (g_ascii_strcasecmp(sf, "noattachments") == 0) { + send_flags |= FUZZY_CHECK_FLAG_NOATTACHMENTS; + } + else if (g_ascii_strcasecmp(sf, "notext") == 0) { + send_flags |= FUZZY_CHECK_FLAG_NOTEXT; + } + } + } + } + + lua_pushboolean(L, + fuzzy_check_lua_process_learn(task, FUZZY_WRITE, weight, flag, + send_flags)); + return 1; +} + +static gint +fuzzy_lua_unlearn_handler(lua_State *L) +{ + struct rspamd_task *task = lua_check_task(L, 1); + if (task == NULL) { + return luaL_error(L, "invalid arguments"); + } + + guint flag = 0, weight = 1.0, send_flags = 0; + const gchar *symbol; + struct fuzzy_ctx *fuzzy_module_ctx = fuzzy_get_context(task->cfg); + + if (lua_type(L, 2) == LUA_TNUMBER) { + flag = lua_tointeger(L, 2); + } + else if (lua_type(L, 2) == LUA_TSTRING) { + struct fuzzy_rule *rule; + guint i; + 
GHashTableIter it; + gpointer k, v; + struct fuzzy_mapping *map; + + symbol = lua_tostring(L, 2); + + PTR_ARRAY_FOREACH(fuzzy_module_ctx->fuzzy_rules, i, rule) + { + + if (flag != 0) { + break; + } + + g_hash_table_iter_init(&it, rule->mappings); + + while (g_hash_table_iter_next(&it, &k, &v)) { + map = v; + + if (g_ascii_strcasecmp(symbol, map->symbol) == 0) { + flag = map->fuzzy_flag; + break; + } + } + } + } + + if (flag == 0) { + return luaL_error(L, "bad flag"); + } + + if (lua_type(L, 3) == LUA_TNUMBER) { + weight = lua_tonumber(L, 3); + } + + if (lua_type(L, 4) == LUA_TTABLE) { + const gchar *sf; + + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + sf = lua_tostring(L, -1); + + if (sf) { + if (g_ascii_strcasecmp(sf, "noimages") == 0) { + send_flags |= FUZZY_CHECK_FLAG_NOIMAGES; + } + else if (g_ascii_strcasecmp(sf, "noattachments") == 0) { + send_flags |= FUZZY_CHECK_FLAG_NOATTACHMENTS; + } + else if (g_ascii_strcasecmp(sf, "notext") == 0) { + send_flags |= FUZZY_CHECK_FLAG_NOTEXT; + } + } + } + } + + lua_pushboolean(L, + fuzzy_check_lua_process_learn(task, FUZZY_DEL, weight, flag, + send_flags)); + + return 1; +} + +static gint +fuzzy_lua_gen_hashes_handler(lua_State *L) +{ + struct rspamd_task *task = lua_check_task(L, 1); + + if (task == NULL) { + return luaL_error(L, "invalid arguments"); + } + + guint flag = 0, weight = 1, send_flags = 0; + const gchar *symbol; + struct fuzzy_ctx *fuzzy_module_ctx = fuzzy_get_context(task->cfg); + struct fuzzy_rule *rule; + GPtrArray *commands; + gint cmd = FUZZY_WRITE; + gint i; + + if (lua_type(L, 2) == LUA_TNUMBER) { + flag = lua_tonumber(L, 2); + } + else if (lua_type(L, 2) == LUA_TSTRING) { + struct fuzzy_rule *rule; + GHashTableIter it; + gpointer k, v; + struct fuzzy_mapping *map; + + symbol = lua_tostring(L, 2); + + PTR_ARRAY_FOREACH(fuzzy_module_ctx->fuzzy_rules, i, rule) + { + if (flag != 0) { + break; + } + + g_hash_table_iter_init(&it, rule->mappings); + + while (g_hash_table_iter_next(&it, &k, &v)) 
{ + map = v; + + if (g_ascii_strcasecmp(symbol, map->symbol) == 0) { + flag = map->fuzzy_flag; + break; + } + } + } + } + + if (flag == 0) { + return luaL_error(L, "bad flag"); + } + + if (lua_type(L, 3) == LUA_TNUMBER) { + weight = lua_tonumber(L, 3); + } + + /* Flags */ + if (lua_type(L, 4) == LUA_TTABLE) { + const gchar *sf; + + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + sf = lua_tostring(L, -1); + + if (sf) { + if (g_ascii_strcasecmp(sf, "noimages") == 0) { + send_flags |= FUZZY_CHECK_FLAG_NOIMAGES; + } + else if (g_ascii_strcasecmp(sf, "noattachments") == 0) { + send_flags |= FUZZY_CHECK_FLAG_NOATTACHMENTS; + } + else if (g_ascii_strcasecmp(sf, "notext") == 0) { + send_flags |= FUZZY_CHECK_FLAG_NOTEXT; + } + } + } + } + + /* Type */ + if (lua_type(L, 5) == LUA_TSTRING) { + const gchar *cmd_name = lua_tostring(L, 5); + + if (strcmp(cmd_name, "add") == 0 || strcmp(cmd_name, "write") == 0) { + cmd = FUZZY_WRITE; + } + else if (strcmp(cmd_name, "delete") == 0 || strcmp(cmd_name, "remove") == 0) { + cmd = FUZZY_DEL; + } + else { + return luaL_error(L, "invalid command: %s", cmd_name); + } + } + + lua_createtable(L, 0, fuzzy_module_ctx->fuzzy_rules->len); + + PTR_ARRAY_FOREACH(fuzzy_module_ctx->fuzzy_rules, i, rule) + { + if (rule->read_only) { + continue; + } + + /* Check for flag */ + if (g_hash_table_lookup(rule->mappings, + GINT_TO_POINTER(flag)) == NULL) { + msg_info_task("skip rule %s as it has no flag %d defined" + " false", + rule->name, flag); + continue; + } + + commands = fuzzy_generate_commands(task, rule, cmd, flag, + weight, send_flags); + + if (commands != NULL) { + struct fuzzy_cmd_io *io; + gint j; + + lua_pushstring(L, rule->name); + lua_createtable(L, commands->len, 0); + + PTR_ARRAY_FOREACH(commands, j, io) + { + lua_pushlstring(L, io->io.iov_base, io->io.iov_len); + lua_rawseti(L, -2, j + 1); + } + + lua_settable(L, -3); /* ret[rule->name] = {raw_fuzzy1, ..., raw_fuzzyn} */ + + g_ptr_array_free(commands, TRUE); + } + } + + + return 
1; +} + +static gint +fuzzy_lua_hex_hashes_handler(lua_State *L) +{ + struct rspamd_task *task = lua_check_task(L, 1); + + if (task == NULL) { + return luaL_error(L, "invalid arguments"); + } + + guint flag = 0, weight = 1, send_flags = 0; + const gchar *symbol; + struct fuzzy_ctx *fuzzy_module_ctx = fuzzy_get_context(task->cfg); + struct fuzzy_rule *rule; + GPtrArray *commands; + gint i; + + if (lua_type(L, 2) == LUA_TNUMBER) { + flag = lua_tonumber(L, 2); + } + else if (lua_type(L, 2) == LUA_TSTRING) { + struct fuzzy_rule *rule; + GHashTableIter it; + gpointer k, v; + struct fuzzy_mapping *map; + + symbol = lua_tostring(L, 2); + + PTR_ARRAY_FOREACH(fuzzy_module_ctx->fuzzy_rules, i, rule) + { + if (flag != 0) { + break; + } + + g_hash_table_iter_init(&it, rule->mappings); + + while (g_hash_table_iter_next(&it, &k, &v)) { + map = v; + + if (g_ascii_strcasecmp(symbol, map->symbol) == 0) { + flag = map->fuzzy_flag; + break; + } + } + } + } + + if (flag == 0) { + return luaL_error(L, "bad flag"); + } + + lua_createtable(L, 0, fuzzy_module_ctx->fuzzy_rules->len); + + PTR_ARRAY_FOREACH(fuzzy_module_ctx->fuzzy_rules, i, rule) + { + /* Check for flag */ + if (g_hash_table_lookup(rule->mappings, + GINT_TO_POINTER(flag)) == NULL) { + msg_debug_task("skip rule %s as it has no flag %d defined" + " false", + rule->name, flag); + continue; + } + + commands = fuzzy_generate_commands(task, rule, FUZZY_CHECK, flag, + weight, send_flags); + + lua_pushstring(L, rule->name); + + if (commands != NULL) { + lua_createtable(L, commands->len, 0); + /* + * We have all commands cached, so we can just read their cached value to + * get hex hashes + */ + struct rspamd_mime_part *mp; + gint j, part_idx = 1; + + PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, parts), j, mp) + { + struct rspamd_cached_shingles *cached; + + cached = fuzzy_cmd_get_cached(rule, task, mp); + + if (cached) { + gchar hexbuf[rspamd_cryptobox_HASHBYTES * 2 + 1]; + gint r = rspamd_encode_hex_buf(cached->digest, 
sizeof(cached->digest), hexbuf, + sizeof(hexbuf)); + lua_pushlstring(L, hexbuf, r); + lua_rawseti(L, -2, part_idx++); + } + } + + g_ptr_array_free(commands, TRUE); + } + else { + lua_pushnil(L); + } + + /* res[rule->name] = {hex_hash1, ..., hex_hashn} */ + lua_settable(L, -3); + } + + return 1; +} + +static gboolean +fuzzy_add_handler(struct rspamd_http_connection_entry *conn_ent, + struct rspamd_http_message *msg, struct module_ctx *ctx) +{ + return fuzzy_controller_handler(conn_ent, msg, + ctx, FUZZY_WRITE, FALSE); +} + +static gboolean +fuzzy_delete_handler(struct rspamd_http_connection_entry *conn_ent, + struct rspamd_http_message *msg, struct module_ctx *ctx) +{ + return fuzzy_controller_handler(conn_ent, msg, + ctx, FUZZY_DEL, FALSE); +} + +static gboolean +fuzzy_deletehash_handler(struct rspamd_http_connection_entry *conn_ent, + struct rspamd_http_message *msg, struct module_ctx *ctx) +{ + return fuzzy_controller_handler(conn_ent, msg, + ctx, FUZZY_DEL, TRUE); +} + +static int +fuzzy_attach_controller(struct module_ctx *ctx, GHashTable *commands) +{ + struct fuzzy_ctx *fctx = (struct fuzzy_ctx *) ctx; + struct rspamd_custom_controller_command *cmd; + + cmd = rspamd_mempool_alloc(fctx->fuzzy_pool, sizeof(*cmd)); + cmd->privileged = TRUE; + cmd->require_message = TRUE; + cmd->handler = fuzzy_add_handler; + cmd->ctx = ctx; + g_hash_table_insert(commands, "/fuzzyadd", cmd); + + cmd = rspamd_mempool_alloc(fctx->fuzzy_pool, sizeof(*cmd)); + cmd->privileged = TRUE; + cmd->require_message = TRUE; + cmd->handler = fuzzy_delete_handler; + cmd->ctx = ctx; + g_hash_table_insert(commands, "/fuzzydel", cmd); + + cmd = rspamd_mempool_alloc(fctx->fuzzy_pool, sizeof(*cmd)); + cmd->privileged = TRUE; + cmd->require_message = FALSE; + cmd->handler = fuzzy_deletehash_handler; + cmd->ctx = ctx; + g_hash_table_insert(commands, "/fuzzydelhash", cmd); + + return 0; +} + +/* Lua handlers */ +/* TODO: move to a separate unit, as this file is now a bit too hard to read */ + +static 
void +lua_upstream_str_inserter(struct upstream *up, guint idx, void *ud) +{ + lua_State *L = (lua_State *) ud; + + lua_pushstring(L, rspamd_upstream_name(up)); + lua_rawseti(L, -2, idx + 1); +} + +static gint +fuzzy_lua_list_storages(lua_State *L) +{ + struct rspamd_config *cfg = lua_check_config(L, 1); + + if (cfg == NULL) { + return luaL_error(L, "invalid arguments"); + } + + struct fuzzy_ctx *fuzzy_module_ctx = fuzzy_get_context(cfg); + struct fuzzy_rule *rule; + guint i; + + lua_createtable(L, 0, fuzzy_module_ctx->fuzzy_rules->len); + PTR_ARRAY_FOREACH(fuzzy_module_ctx->fuzzy_rules, i, rule) + { + lua_newtable(L); + + lua_pushboolean(L, rule->read_only); + lua_setfield(L, -2, "read_only"); + + /* Push servers */ + lua_createtable(L, rspamd_upstreams_count(rule->servers), 0); + rspamd_upstreams_foreach(rule->servers, lua_upstream_str_inserter, L); + lua_setfield(L, -2, "servers"); + + /* Push flags */ + GHashTableIter it; + + lua_createtable(L, 0, g_hash_table_size(rule->mappings)); + gpointer k, v; + struct fuzzy_mapping *map; + + g_hash_table_iter_init(&it, rule->mappings); + while (g_hash_table_iter_next(&it, &k, &v)) { + map = v; + + lua_pushinteger(L, map->fuzzy_flag); + lua_setfield(L, -2, map->symbol); + } + lua_setfield(L, -2, "flags"); + + /* Final table */ + lua_setfield(L, -2, rule->name); + } + + return 1; +} + +struct fuzzy_lua_session { + struct rspamd_task *task; + lua_State *L; + rspamd_inet_addr_t *addr; + GPtrArray *commands; + struct fuzzy_rule *rule; + struct rspamd_io_ev ev; + gint cbref; + gint fd; +}; + +static void +fuzzy_lua_session_fin(void *ud) +{ + struct fuzzy_lua_session *session = ud; + + if (session->commands) { + g_ptr_array_free(session->commands, TRUE); + } + + rspamd_ev_watcher_stop(session->task->event_loop, &session->ev); + luaL_unref(session->L, LUA_REGISTRYINDEX, session->cbref); +} + +static gboolean +fuzzy_lua_session_is_completed(struct fuzzy_lua_session *session) +{ + struct fuzzy_cmd_io *io; + guint nreplied = 0, i; 
+ + + for (i = 0; i < session->commands->len; i++) { + io = g_ptr_array_index(session->commands, i); + + if (io->flags & FUZZY_CMD_FLAG_REPLIED) { + nreplied++; + } + } + + if (nreplied == session->commands->len) { + + rspamd_session_remove_event(session->task->s, fuzzy_lua_session_fin, session); + + return TRUE; + } + + return FALSE; +} + +static void +fuzzy_lua_push_result(struct fuzzy_lua_session *session, gdouble latency) +{ + lua_rawgeti(session->L, LUA_REGISTRYINDEX, session->cbref); + lua_pushboolean(session->L, TRUE); + rspamd_lua_ip_push(session->L, session->addr); + lua_pushnumber(session->L, latency); + + /* TODO: check results maybe? */ + lua_pcall(session->L, 3, 0, 0); +} + +#ifdef __GNUC__ +static void +fuzzy_lua_push_error(struct fuzzy_lua_session *session, const gchar *err_fmt, ...) __attribute__((format(printf, 2, 3))); +#endif + +static void +fuzzy_lua_push_error(struct fuzzy_lua_session *session, const gchar *err_fmt, ...) +{ + va_list v; + + va_start(v, err_fmt); + lua_rawgeti(session->L, LUA_REGISTRYINDEX, session->cbref); + lua_pushboolean(session->L, FALSE); + rspamd_lua_ip_push(session->L, session->addr); + lua_pushvfstring(session->L, err_fmt, v); + va_end(v); + + /* TODO: check results maybe? 
*/ + lua_pcall(session->L, 3, 0, 0); +} + +static gint +fuzzy_lua_try_read(struct fuzzy_lua_session *session) +{ + const struct rspamd_fuzzy_reply *rep; + struct rspamd_fuzzy_cmd *cmd = NULL; + struct fuzzy_cmd_io *io = NULL; + gint r, ret; + guchar buf[2048], *p; + + if ((r = read(session->fd, buf, sizeof(buf) - 1)) == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { + return 0; + } + else { + fuzzy_lua_push_error(session, "cannot read from socket: %s", strerror(errno)); + return -1; + } + } + else { + p = buf; + + ret = 0; + + while ((rep = fuzzy_process_reply(&p, &r, + session->commands, session->rule, &cmd, &io)) != NULL) { + + if (rep->v1.prob > 0.5) { + if (cmd->cmd == FUZZY_PING) { + fuzzy_lua_push_result(session, fuzzy_milliseconds_since_midnight() - rep->v1.value); + } + else { + fuzzy_lua_push_error(session, "unsupported"); + } + } + else { + fuzzy_lua_push_error(session, "invalid reply from server: %d", rep->v1.value); + } + + ret = 1; + } + } + + return ret; +} + +/* Fuzzy check callback */ +static void +fuzzy_lua_io_callback(gint fd, short what, void *arg) +{ + struct fuzzy_lua_session *session = arg; + gint r; + + enum { + return_error = 0, + return_want_more, + return_finished + } ret = return_error; + + if (what & EV_READ) { + /* Try to read reply */ + r = fuzzy_lua_try_read(session); + + switch (r) { + case 0: + if (what & EV_READ) { + ret = return_want_more; + } + else { + if (what & EV_WRITE) { + /* Retransmit attempt */ + if (!fuzzy_cmd_vector_to_wire(fd, session->commands)) { + fuzzy_lua_push_error(session, "cannot write to socket"); + ret = return_error; + } + else { + ret = return_want_more; + } + } + } + break; + case 1: + ret = return_finished; + break; + default: + ret = return_error; + break; + } + } + else if (what & EV_WRITE) { + if (!fuzzy_cmd_vector_to_wire(fd, session->commands)) { + fuzzy_lua_push_error(session, "cannot write to socket"); + ret = return_error; + } + else { + ret = return_want_more; + } + } + 
else { + /* Timeout */ + fuzzy_lua_push_error(session, "timeout waiting for the reply"); + ret = return_error; + } + + if (ret == return_want_more) { + /* Processed write, switch to reading */ + rspamd_ev_watcher_reschedule(session->task->event_loop, + &session->ev, EV_READ); + } + else if (ret == return_error) { + rspamd_session_remove_event(session->task->s, fuzzy_lua_session_fin, session); + } + else { + /* Read something from network */ + if (!fuzzy_lua_session_is_completed(session)) { + /* Need to read more */ + rspamd_ev_watcher_reschedule(session->task->event_loop, + &session->ev, EV_READ); + } + } +} + +/*** + * @function fuzzy_check.ping_storage(task, callback, rule, timeout[, server_override]) + * @return + */ +static gint +fuzzy_lua_ping_storage(lua_State *L) +{ + struct rspamd_task *task = lua_check_task(L, 1); + + if (task == NULL) { + return luaL_error(L, "invalid arguments: task"); + } + + /* Other arguments sanity */ + if (lua_type(L, 2) != LUA_TFUNCTION || lua_type(L, 3) != LUA_TSTRING || lua_type(L, 4) != LUA_TNUMBER) { + return luaL_error(L, "invalid arguments: callback/rule/timeout argument"); + } + + struct fuzzy_ctx *fuzzy_module_ctx = fuzzy_get_context(task->cfg); + struct fuzzy_rule *rule, *rule_found = NULL; + int i; + const char *rule_name = lua_tostring(L, 3); + + PTR_ARRAY_FOREACH(fuzzy_module_ctx->fuzzy_rules, i, rule) + { + if (strcmp(rule->name, rule_name) == 0) { + rule_found = rule; + break; + } + } + + if (rule_found == NULL) { + return luaL_error(L, "invalid arguments: no such rule defined"); + } + + rspamd_inet_addr_t *addr = NULL; + + if (lua_type(L, 5) == LUA_TSTRING) { + const gchar *server_name = lua_tostring(L, 5); + enum rspamd_parse_host_port_result res; + GPtrArray *addrs = g_ptr_array_new(); + + /* We resolve address synchronously here! Why? Because it is an override... 
*/ + res = rspamd_parse_host_port_priority(server_name, &addrs, 0, NULL, + 11335, FALSE, task->task_pool); + + if (res == RSPAMD_PARSE_ADDR_FAIL) { + lua_pushboolean(L, FALSE); + lua_pushfstring(L, "invalid arguments: cannot resolve %s", server_name); + return 2; + } + + /* Get random address */ + addr = rspamd_inet_address_copy(g_ptr_array_index(addrs, rspamd_random_uint64_fast() % addrs->len), + task->task_pool); + rspamd_mempool_add_destructor(task->task_pool, + rspamd_ptr_array_free_hard, addrs); + } + else { + struct upstream *selected = rspamd_upstream_get(rule_found->servers, + RSPAMD_UPSTREAM_ROUND_ROBIN, NULL, 0); + addr = rspamd_upstream_addr_next(selected); + } + + if (addr != NULL) { + int sock; + GPtrArray *commands = fuzzy_generate_commands(task, rule, FUZZY_PING, 0, 0, 0); + + if ((sock = rspamd_inet_address_connect(addr, SOCK_DGRAM, TRUE)) == -1) { + lua_pushboolean(L, FALSE); + lua_pushfstring(L, "cannot connect to %s, %s", + rspamd_inet_address_to_string_pretty(addr), + strerror(errno)); + return 2; + } + else { + /* Create a dedicated ping session for a socket */ + struct fuzzy_lua_session *session = + rspamd_mempool_alloc0(task->task_pool, + sizeof(struct fuzzy_lua_session)); + session->task = task; + session->fd = sock; + session->addr = addr; + session->commands = commands; + session->L = L; + session->rule = rule_found; + /* Store callback */ + lua_pushvalue(L, 2); + session->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + + rspamd_session_add_event(task->s, fuzzy_lua_session_fin, session, M); + rspamd_ev_watcher_init(&session->ev, + sock, + EV_WRITE, + fuzzy_lua_io_callback, + session); + rspamd_ev_watcher_start(session->task->event_loop, &session->ev, + lua_tonumber(L, 4)); + } + } + + lua_pushboolean(L, TRUE); + return 1; +}
\ No newline at end of file diff --git a/src/plugins/lua/antivirus.lua b/src/plugins/lua/antivirus.lua new file mode 100644 index 0000000..e39ddc5 --- /dev/null +++ b/src/plugins/lua/antivirus.lua @@ -0,0 +1,348 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]] -- + +local rspamd_logger = require "rspamd_logger" +local lua_util = require "lua_util" +local rspamd_util = require "rspamd_util" +local lua_redis = require "lua_redis" +local fun = require "fun" +local lua_antivirus = require("lua_scanners").filter('antivirus') +local common = require "lua_scanners/common" +local redis_params + +local N = "antivirus" + +if confighelp then + rspamd_config:add_example(nil, 'antivirus', + "Check messages for viruses", + [[ + antivirus { + # multiple scanners could be checked, for each we create a configuration block with an arbitrary name + clamav { + # If set force this action if any virus is found (default unset: no action is forced) + # action = "reject"; + # If set, then rejection message is set to this value (mention single quotes) + # message = '${SCANNER}: virus found: "${VIRUS}"'; + # Scan mime_parts separately - otherwise the complete mail will be transferred to AV Scanner + #scan_mime_parts = true; + # Scanning Text is suitable for some av scanner databases (e.g. 
Sanesecurity) + #scan_text_mime = false; + #scan_image_mime = false; + # If `max_size` is set, messages > n bytes in size are not scanned + max_size = 20000000; + # symbol to add (add it to metric if you want non-zero weight) + symbol = "CLAM_VIRUS"; + # type of scanner: "clamav", "fprot", "sophos" or "savapi" + type = "clamav"; + # For "savapi" you must also specify the following variable + product_id = 12345; + # You can enable logging for clean messages + log_clean = true; + # servers to query (if port is unspecified, scanner-specific default is used) + # can be specified multiple times to pool servers + # can be set to a path to a unix socket + # Enable this in local.d/antivirus.conf + servers = "127.0.0.1:3310"; + # if `patterns` is specified virus name will be matched against provided regexes and the related + # symbol will be yielded if a match is found. If no match is found, default symbol is yielded. + patterns { + # symbol_name = "pattern"; + JUST_EICAR = "^Eicar-Test-Signature$"; + } + # `whitelist` points to a map of IP addresses. Mail from these addresses is not scanned. 
+ whitelist = "/etc/rspamd/antivirus.wl"; + # Replace content that exactly matches the following string to the EICAR pattern + # Useful for E2E testing when another party removes/blocks EICAR attachments + #eicar_fake_pattern = 'testpatterneicar'; + } + } + ]]) + return +end + +-- Encode as base32 in the source to avoid crappy stuff +local eicar_pattern = rspamd_util.decode_base32( + [[akp6woykfbonrepmwbzyfpbmibpone3mj3pgwbffzj9e1nfjdkorisckwkohrnfe1nt41y3jwk1cirjki4w4nkieuni4ndfjcktnn1yjmb1wn]] +) + +local function add_antivirus_rule(sym, opts) + if not opts.type then + rspamd_logger.errx(rspamd_config, 'unknown type for AV rule %s', sym) + return nil + end + + if not opts.symbol then + opts.symbol = sym:upper() + end + local cfg = lua_antivirus[opts.type] + + if not cfg then + rspamd_logger.errx(rspamd_config, 'unknown antivirus type: %s', + opts.type) + return nil + end + + if not opts.symbol_fail then + opts.symbol_fail = opts.symbol .. '_FAIL' + end + if not opts.symbol_encrypted then + opts.symbol_encrypted = opts.symbol .. '_ENCRYPTED' + end + if not opts.symbol_macro then + opts.symbol_macro = opts.symbol .. '_MACRO' + end + + -- WORKAROUND for deprecated attachments_only + if opts.attachments_only ~= nil then + opts.scan_mime_parts = opts.attachments_only + rspamd_logger.warnx(rspamd_config, '%s [%s]: Using attachments_only is deprecated. ' .. 
+ 'Please use scan_mime_parts = %s instead', opts.symbol, opts.type, opts.attachments_only) + end + -- WORKAROUND for deprecated attachments_only + + local rule = cfg.configure(opts) + if not rule then + return nil + end + + rule.type = opts.type + rule.symbol_fail = opts.symbol_fail + rule.symbol_encrypted = opts.symbol_encrypted + rule.redis_params = redis_params + + if not rule then + rspamd_logger.errx(rspamd_config, 'cannot configure %s for %s', + opts.type, opts.symbol) + return nil + end + + rule.patterns = common.create_regex_table(opts.patterns or {}) + rule.patterns_fail = common.create_regex_table(opts.patterns_fail or {}) + + lua_redis.register_prefix(rule.prefix .. '_*', N, + string.format('Antivirus cache for rule "%s"', + rule.type), { + type = 'string', + }) + + -- if any mime_part filter defined, do not scan all attachments + if opts.mime_parts_filter_regex ~= nil + or opts.mime_parts_filter_ext ~= nil then + rule.scan_all_mime_parts = false + else + rule.scan_all_mime_parts = true + end + + rule.patterns = common.create_regex_table(opts.patterns or {}) + rule.patterns_fail = common.create_regex_table(opts.patterns_fail or {}) + + rule.mime_parts_filter_regex = common.create_regex_table(opts.mime_parts_filter_regex or {}) + + rule.mime_parts_filter_ext = common.create_regex_table(opts.mime_parts_filter_ext or {}) + + if opts.whitelist then + rule.whitelist = rspamd_config:add_hash_map(opts.whitelist) + end + + return function(task) + if rule.scan_mime_parts then + + fun.each(function(p) + local content = p:get_content() + local clen = #content + if content and clen > 0 then + if opts.eicar_fake_pattern then + if type(opts.eicar_fake_pattern) == 'string' then + -- Convert it to Rspamd text + local rspamd_text = require "rspamd_text" + opts.eicar_fake_pattern = rspamd_text.fromstring(opts.eicar_fake_pattern) + end + + if clen == #opts.eicar_fake_pattern and content == opts.eicar_fake_pattern then + rspamd_logger.infox(task, 'found eicar fake 
replacement part in the part (filename="%s")', + p:get_filename()) + content = eicar_pattern + end + end + cfg.check(task, content, p:get_digest(), rule, p) + end + end, common.check_parts_match(task, rule)) + + else + cfg.check(task, task:get_content(), task:get_digest(), rule) + end + end +end + +-- Registration +local opts = rspamd_config:get_all_opt(N) +if opts and type(opts) == 'table' then + redis_params = lua_redis.parse_redis_server(N) + local has_valid = false + for k, m in pairs(opts) do + if type(m) == 'table' then + if not m.type then + m.type = k + end + if not m.name then + m.name = k + end + local cb = add_antivirus_rule(k, m) + + if not cb then + rspamd_logger.errx(rspamd_config, 'cannot add rule: "' .. k .. '"') + lua_util.config_utils.push_config_error(N, 'cannot add AV rule: "' .. k .. '"') + else + rspamd_logger.infox(rspamd_config, 'added antivirus engine %s -> %s', k, m.symbol) + local t = { + name = m.symbol, + callback = cb, + score = 0.0, + group = N + } + + if m.symbol_type == 'postfilter' then + t.type = 'postfilter' + t.priority = lua_util.symbols_priorities.medium + else + t.type = 'normal' + end + + t.augmentations = {} + + if type(m.timeout) == 'number' then + -- Here, we ignore possible DNS timeout and timeout from multiple retries + -- as these situations are not usual nor likely for the antivirus module + table.insert(t.augmentations, string.format("timeout=%f", m.timeout)) + end + + local id = rspamd_config:register_symbol(t) + + rspamd_config:register_symbol({ + type = 'virtual', + name = m['symbol_fail'], + parent = id, + score = 0.0, + group = N + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = m['symbol_encrypted'], + parent = id, + score = 0.0, + group = N + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = m['symbol_macro'], + parent = id, + score = 0.0, + group = N + }) + has_valid = true + if type(m['patterns']) == 'table' then + if m['patterns'][1] then + for _, p in ipairs(m['patterns']) 
do + if type(p) == 'table' then + for sym in pairs(p) do + rspamd_logger.debugm(N, rspamd_config, 'registering: %1', { + type = 'virtual', + name = sym, + parent = m['symbol'], + parent_id = id, + group = N + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = sym, + parent = id, + score = 0.0, + group = N + }) + end + end + end + else + for sym in pairs(m['patterns']) do + rspamd_config:register_symbol({ + type = 'virtual', + name = sym, + parent = id, + score = 0.0, + group = N + }) + end + end + end + if type(m['patterns_fail']) == 'table' then + if m['patterns_fail'][1] then + for _, p in ipairs(m['patterns_fail']) do + if type(p) == 'table' then + for sym in pairs(p) do + rspamd_logger.debugm(N, rspamd_config, 'registering: %1', { + type = 'virtual', + name = sym, + parent = m['symbol'], + parent_id = id, + group = N + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = sym, + parent = id, + score = 0.0, + group = N + }) + end + end + end + else + for sym in pairs(m['patterns_fail']) do + rspamd_config:register_symbol({ + type = 'virtual', + name = sym, + parent = id, + score = 0.0, + group = N + }) + end + end + end + if m['score'] then + -- Register metric symbol + local description = 'antivirus symbol' + local group = N + if m['description'] then + description = m['description'] + end + if m['group'] then + group = m['group'] + end + rspamd_config:set_metric_symbol({ + name = m['symbol'], + score = m['score'], + description = description, + group = group or 'antivirus' + }) + end + end + end + end + + if not has_valid then + lua_util.disable_module(N, 'config') + end +end diff --git a/src/plugins/lua/arc.lua b/src/plugins/lua/arc.lua new file mode 100644 index 0000000..ff19aef --- /dev/null +++ b/src/plugins/lua/arc.lua @@ -0,0 +1,853 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the 
License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]] -- + +local rspamd_logger = require "rspamd_logger" +local lua_util = require "lua_util" +local dkim_sign_tools = require "lua_dkim_tools" +local rspamd_util = require "rspamd_util" +local rspamd_rsa_privkey = require "rspamd_rsa_privkey" +local rspamd_rsa = require "rspamd_rsa" +local fun = require "fun" +local lua_auth_results = require "lua_auth_results" +local hash = require "rspamd_cryptobox_hash" +local lua_mime = require "lua_mime" + +if confighelp then + return +end + +local N = 'arc' +local AR_TRUSTED_CACHE_KEY = 'arc_trusted_aar' + +if not rspamd_plugins.dkim then + rspamd_logger.errx(rspamd_config, "cannot enable arc plugin: dkim is disabled") + return +end + +local dkim_verify = rspamd_plugins.dkim.verify +local dkim_sign = rspamd_plugins.dkim.sign +local dkim_canonicalize = rspamd_plugins.dkim.canon_header_relaxed +local redis_params + +if not dkim_verify or not dkim_sign or not dkim_canonicalize then + rspamd_logger.errx(rspamd_config, "cannot enable arc plugin: dkim is disabled") + return +end + +local arc_symbols = { + allow = 'ARC_ALLOW', + invalid = 'ARC_INVALID', + dnsfail = 'ARC_DNSFAIL', + na = 'ARC_NA', + reject = 'ARC_REJECT', +} + +local settings = { + allow_envfrom_empty = true, + allow_hdrfrom_mismatch = false, + allow_hdrfrom_mismatch_local = false, + allow_hdrfrom_mismatch_sign_networks = false, + allow_hdrfrom_multiple = false, + allow_username_mismatch = false, + sign_authenticated = true, + domain = {}, + path = string.format('%s/%s/%s', rspamd_paths['DBDIR'], 'arc', '$domain.$selector.key'), + sign_local = true, + selector = 
'arc', + sign_symbol = 'ARC_SIGNED', + try_fallback = true, + use_domain = 'header', + use_esld = true, + use_redis = false, + key_prefix = 'arc_keys', -- default hash name + reuse_auth_results = false, -- Reuse the existing authentication results + whitelisted_signers_map = nil, -- Trusted signers domains + adjust_dmarc = true, -- Adjust DMARC rejected policy for trusted forwarders + allowed_ids = nil, -- Allowed settings id + forbidden_ids = nil, -- Banned settings id +} + +-- To match normal AR +local ar_settings = lua_auth_results.default_settings + +local function parse_arc_header(hdr, target, is_aar) + -- Split elements by ';' and trim spaces + local arr = fun.totable(fun.map( + function(val) + return fun.totable(fun.map(lua_util.rspamd_str_trim, + fun.filter(function(v) + return v and #v > 0 + end, + lua_util.rspamd_str_split(val.decoded, ';') + ) + )) + end, hdr + )) + + -- v[1] is the key and v[2] is the value + local function fill_arc_header_table(v, t) + if v[1] and v[2] then + local key = lua_util.rspamd_str_trim(v[1]) + local value = lua_util.rspamd_str_trim(v[2]) + t[key] = value + end + end + + -- Now we have two tables in format: + -- [arc_header] -> [{arc_header1_elts}, {arc_header2_elts}...] 
+ for i, elts in ipairs(arr) do + if not target[i] then + target[i] = {} + end + if not is_aar then + -- For normal ARC headers we split by kv pair, like k=v + fun.each(function(v) + fill_arc_header_table(v, target[i]) + end, + fun.map(function(elt) + return lua_util.rspamd_str_split(elt, '=') + end, elts) + ) + else + -- For AAR we check special case of i=%d and pass everything else to + -- AAR specific parser + for _, elt in ipairs(elts) do + if string.match(elt, "%s*i%s*=%s*%d+%s*") then + local pair = lua_util.rspamd_str_split(elt, '=') + fill_arc_header_table(pair, target[i]) + else + -- Normal element + local ar_elt = lua_auth_results.parse_ar_element(elt) + + if ar_elt then + if not target[i].ar then + target[i].ar = {} + end + table.insert(target[i].ar, ar_elt) + end + end + end + end + target[i].header = hdr[i].decoded + target[i].raw_header = hdr[i].value + end + + -- sort by i= attribute + table.sort(target, function(a, b) + return (a.i or 0) < (b.i or 0) + end) +end + +local function arc_validate_seals(task, seals, sigs, seal_headers, sig_headers) + local fail_reason + for i = 1, #seals do + if (sigs[i].i or 0) ~= i then + fail_reason = string.format('bad i for signature: %d, expected %d; d=%s', + sigs[i].i, i, sigs[i].d) + rspamd_logger.infox(task, fail_reason) + task:insert_result(arc_symbols['invalid'], 1.0, fail_reason) + return false, fail_reason + end + if (seals[i].i or 0) ~= i then + fail_reason = string.format('bad i for seal: %d, expected %d; d=%s', + seals[i].i, i, seals[i].d) + rspamd_logger.infox(task, fail_reason) + task:insert_result(arc_symbols['invalid'], 1.0, fail_reason) + return false, fail_reason + end + + if not seals[i].cv then + fail_reason = string.format('no cv on i=%d', i) + task:insert_result(arc_symbols['invalid'], 1.0, fail_reason) + return false, fail_reason + end + + if i == 1 then + -- We need to ensure that cv of seal is equal to 'none' + if seals[i].cv ~= 'none' then + fail_reason = 'cv is not "none" for i=1' + 
task:insert_result(arc_symbols['invalid'], 1.0, fail_reason) + return false, fail_reason + end + else + if seals[i].cv ~= 'pass' then + fail_reason = string.format('cv is %s on i=%d', seals[i].cv, i) + task:insert_result(arc_symbols['reject'], 1.0, fail_reason) + return true, fail_reason + end + end + end + + return true, nil +end + +local function arc_callback(task) + local arc_sig_headers = task:get_header_full('ARC-Message-Signature') + local arc_seal_headers = task:get_header_full('ARC-Seal') + local arc_ar_headers = task:get_header_full('ARC-Authentication-Results') + + if not arc_sig_headers or not arc_seal_headers then + task:insert_result(arc_symbols['na'], 1.0) + return + end + + if #arc_sig_headers ~= #arc_seal_headers then + -- We mandate that count of seals is equal to count of signatures + rspamd_logger.infox(task, 'number of seals (%s) is not equal to number of signatures (%s)', + #arc_seal_headers, #arc_sig_headers) + task:insert_result(arc_symbols['invalid'], 1.0, 'invalid count of seals and signatures') + return + end + + local cbdata = { + seals = {}, + sigs = {}, + ars = {}, + res = 'success', + errors = {}, + allowed_by_trusted = false + } + + parse_arc_header(arc_seal_headers, cbdata.seals, false) + parse_arc_header(arc_sig_headers, cbdata.sigs, false) + + if arc_ar_headers then + parse_arc_header(arc_ar_headers, cbdata.ars, true) + end + + -- Fix i type + fun.each(function(hdr) + hdr.i = tonumber(hdr.i) or 0 + end, cbdata.seals) + + fun.each(function(hdr) + hdr.i = tonumber(hdr.i) or 0 + end, cbdata.sigs) + + -- Now we need to sort elements according to their [i] value + table.sort(cbdata.seals, function(e1, e2) + return (e1.i or 0) < (e2.i or 0) + end) + table.sort(cbdata.sigs, function(e1, e2) + return (e1.i or 0) < (e2.i or 0) + end) + + lua_util.debugm(N, task, 'got %s arc sections', #cbdata.seals) + + -- Now check sanity of what we have + local valid, validation_error = arc_validate_seals(task, cbdata.seals, cbdata.sigs, + 
arc_seal_headers, arc_sig_headers) + if not valid then + task:cache_set('arc-failure', validation_error) + return + end + + task:cache_set('arc-sigs', cbdata.sigs) + task:cache_set('arc-seals', cbdata.seals) + task:cache_set('arc-authres', cbdata.ars) + + if validation_error then + -- ARC rejection but no strong failure for signing + return + end + + local function gen_arc_seal_cb(index, sig) + return function(_, res, err, domain) + lua_util.debugm(N, task, 'checked arc seal: %s(%s), %s processed', + res, err, index) + + if not res then + cbdata.res = 'fail' + if err and domain then + table.insert(cbdata.errors, string.format('sig:%s:%s', domain, err)) + end + end + + if settings.whitelisted_signers_map and cbdata.res == 'success' then + if settings.whitelisted_signers_map:get_key(sig.d) then + -- Whitelisted signer has been found in a valid chain + local mult = 1.0 + local cur_aar = cbdata.ars[index] + if not cur_aar then + rspamd_logger.warnx(task, "cannot find Arc-Authentication-Results for trusted " .. 
+ "forwarder %s on i=%s", domain, cbdata.index) + else + task:cache_set(AR_TRUSTED_CACHE_KEY, cur_aar) + local seen_dmarc + for _, ar in ipairs(cur_aar.ar) do + if ar.dmarc then + local dmarc_fwd = ar.dmarc + seen_dmarc = true + if dmarc_fwd == 'reject' or dmarc_fwd == 'fail' or dmarc_fwd == 'quarantine' then + lua_util.debugm(N, "found rejected dmarc on forwarding") + mult = 0.0 + elseif dmarc_fwd == 'pass' then + mult = 1.0 + end + elseif ar.spf then + local spf_fwd = ar.spf + if spf_fwd == 'reject' or spf_fwd == 'fail' or spf_fwd == 'quarantine' then + lua_util.debugm(N, "found rejected spf on forwarding") + if not seen_dmarc then + mult = mult * 0.5 + end + end + end + end + end + task:insert_result(arc_symbols.trusted_allow, mult, + string.format('%s:s=%s:i=%d', domain, sig.s, index)) + end + end + + if index == #arc_sig_headers then + if cbdata.res == 'success' then + local arc_allow_result = string.format('%s:s=%s:i=%d', + domain, sig.s, index) + task:insert_result(arc_symbols.allow, 1.0, arc_allow_result) + task:cache_set('arc-allow', arc_allow_result) + else + task:insert_result(arc_symbols.reject, 1.0, + rspamd_logger.slog('seal check failed: %s, %s', cbdata.res, + cbdata.errors)) + end + end + end + end + + local function arc_signature_cb(_, res, err, domain) + lua_util.debugm(N, task, 'checked arc signature %s: %s(%s)', + domain, res, err) + + if not res then + cbdata.res = 'fail' + if err and domain then + table.insert(cbdata.errors, string.format('sig:%s:%s', domain, err)) + end + end + if cbdata.res == 'success' then + -- Verify seals + for i, sig in ipairs(cbdata.seals) do + local ret, lerr = dkim_verify(task, sig.header, gen_arc_seal_cb(i, sig), 'arc-seal') + if not ret then + cbdata.res = 'fail' + table.insert(cbdata.errors, string.format('seal:%s:s=%s:i=%s:%s', + sig.d or '', sig.s or '', sig.i or '', lerr)) + lua_util.debugm(N, task, 'checked arc seal %s: %s(%s), %s processed', + sig.d, ret, lerr, i) + end + end + else + 
task:insert_result(arc_symbols['reject'], 1.0, + rspamd_logger.slog('signature check failed: %s, %s', cbdata.res, + cbdata.errors)) + end + end + + --[[ + 1. Collect all ARC Sets currently attached to the message. If there + are none, the Chain Validation Status is "none" and the algorithm + stops here. The maximum number of ARC Sets that can be attached + to a message is 50. If more than the maximum number exist the + Chain Validation Status is "fail" and the algorithm stops here. + In the following algorithm, the maximum ARC instance value is + referred to as "N". + + 2. If the Chain Validation Status of the highest instance value ARC + Set is "fail", then the Chain Validation status is "fail" and the + algorithm stops here. + + 3. Validate the structure of the Authenticated Received Chain. A + valid ARC has the following conditions: + + 1. Each ARC Set MUST contain exactly one each of the three ARC + header fields (AAR, AMS, and AS). + + 2. The instance values of the ARC Sets MUST form a continuous + sequence from 1..N with no gaps or repetition. + + 3. The "cv" value for all ARC-Seal header fields must be non- + failing. For instance values > 1, the value must be "pass". + For instance value = 1, the value must be "none". + + * If any of these conditions are not met, the Chain Validation + Status is "fail" and the algorithm stops here. + + 4. Validate the AMS with the greatest instance value (most recent). + If validation fails, then the Chain Validation Status is "fail" + and the algorithm stops here. + + 5 - 7. Optional, not implemented + 8. Validate each AS beginning with the greatest instance value and + proceeding in decreasing order to the AS with the instance value + of 1. If any AS fails to validate, the Chain Validation Status + is "fail" and the algorithm stops here. + 9. If the algorithm reaches this step, then the Chain Validation + Status is "pass", and the algorithm is complete. 
+ ]]-- + + local processed = 0 + local sig = cbdata.sigs[#cbdata.sigs] -- last AMS + local ret, err = dkim_verify(task, sig.header, arc_signature_cb, 'arc-sign') + + if not ret then + cbdata.res = 'fail' + table.insert(cbdata.errors, string.format('sig:%s:%s', sig.d or '', err)) + else + processed = processed + 1 + lua_util.debugm(N, task, 'processed arc signature %s[%s]: %s(%s), %s total', + sig.d, sig.i, ret, err, #cbdata.seals) + end + + if processed == 0 then + task:insert_result(arc_symbols['reject'], 1.0, + rspamd_logger.slog('cannot verify %s of %s signatures: %s', + #arc_sig_headers - processed, #arc_sig_headers, cbdata.errors)) + end +end + +local opts = rspamd_config:get_all_opt('arc') +if not opts or type(opts) ~= 'table' then + return +end + +if opts['symbols'] then + for k, _ in pairs(arc_symbols) do + if opts['symbols'][k] then + arc_symbols[k] = opts['symbols'][k] + end + end +end + +local id = rspamd_config:register_symbol({ + name = 'ARC_CHECK', + type = 'callback', + group = 'policies', + groups = { 'arc' }, + callback = arc_callback, + augmentations = { lua_util.dns_timeout_augmentation(rspamd_config) }, +}) +rspamd_config:register_symbol({ + name = 'ARC_CALLBACK', -- compatibility symbol + type = 'virtual,skip', + parent = id, +}) + +rspamd_config:register_symbol({ + name = arc_symbols['allow'], + parent = id, + type = 'virtual', + score = -1.0, + group = 'policies', + groups = { 'arc' }, +}) +rspamd_config:register_symbol({ + name = arc_symbols['reject'], + parent = id, + type = 'virtual', + score = 2.0, + group = 'policies', + groups = { 'arc' }, +}) +rspamd_config:register_symbol({ + name = arc_symbols['invalid'], + parent = id, + type = 'virtual', + score = 1.0, + group = 'policies', + groups = { 'arc' }, +}) +rspamd_config:register_symbol({ + name = arc_symbols['dnsfail'], + parent = id, + type = 'virtual', + score = 0.0, + group = 'policies', + groups = { 'arc' }, +}) +rspamd_config:register_symbol({ + name = arc_symbols['na'], + parent = 
id, + type = 'virtual', + score = 0.0, + group = 'policies', + groups = { 'arc' }, +}) + +rspamd_config:register_dependency('ARC_CHECK', 'SPF_CHECK') +rspamd_config:register_dependency('ARC_CHECK', 'DKIM_CHECK') + +local function arc_sign_seal(task, params, header) + local arc_sigs = task:cache_get('arc-sigs') + local arc_seals = task:cache_get('arc-seals') + local arc_auth_results = task:cache_get('arc-authres') + local cur_auth_results + local privkey + + if params.rawkey then + -- Distinguish between pem and base64 + if string.match(params.rawkey, '^-----BEGIN') then + privkey = rspamd_rsa_privkey.load_pem(params.rawkey) + else + privkey = rspamd_rsa_privkey.load_base64(params.rawkey) + end + elseif params.key then + privkey = rspamd_rsa_privkey.load_file(params.key) + end + + if not privkey then + rspamd_logger.errx(task, 'cannot load private key for signing') + return + end + + if settings.reuse_auth_results then + local ar_header = task:get_header('Authentication-Results') + + if ar_header then + rspamd_logger.debugm(N, task, 'reuse authentication results header for ARC') + cur_auth_results = ar_header + else + rspamd_logger.debugm(N, task, 'cannot reuse authentication results, header is missing') + cur_auth_results = lua_auth_results.gen_auth_results(task, ar_settings) or '' + end + else + cur_auth_results = lua_auth_results.gen_auth_results(task, ar_settings) or '' + end + + local sha_ctx = hash.create_specific('sha256') + + -- Update using previous seals + sigs + AAR + local cur_idx = 1 + if arc_seals then + cur_idx = #arc_seals + 1 + -- We use the cached version per each ARC-* header field individually, already sorted by instance + -- value in ascending order + for i = 1, #arc_seals, 1 do + if arc_auth_results[i] then + local s = dkim_canonicalize('ARC-Authentication-Results', + arc_auth_results[i].raw_header) + sha_ctx:update(s) + lua_util.debugm(N, task, 'update signature with header: %s', s) + end + if arc_sigs[i] then + local s = 
dkim_canonicalize('ARC-Message-Signature', + arc_sigs[i].raw_header) + sha_ctx:update(s) + lua_util.debugm(N, task, 'update signature with header: %s', s) + end + if arc_seals[i] then + local s = dkim_canonicalize('ARC-Seal', arc_seals[i].raw_header) + sha_ctx:update(s) + lua_util.debugm(N, task, 'update signature with header: %s', s) + end + end + end + + header = lua_util.fold_header(task, + 'ARC-Message-Signature', + header) + + cur_auth_results = string.format('i=%d; %s', cur_idx, cur_auth_results) + cur_auth_results = lua_util.fold_header(task, + 'ARC-Authentication-Results', + cur_auth_results, ';') + + local s = dkim_canonicalize('ARC-Authentication-Results', + cur_auth_results) + sha_ctx:update(s) + lua_util.debugm(N, task, 'update signature with header: %s', s) + s = dkim_canonicalize('ARC-Message-Signature', header) + sha_ctx:update(s) + lua_util.debugm(N, task, 'update signature with header: %s', s) + + local cur_arc_seal = string.format('i=%d; s=%s; d=%s; t=%d; a=rsa-sha256; cv=%s; b=', + cur_idx, + params.selector, + params.domain, + math.floor(rspamd_util.get_time()), params.arc_cv) + s = string.format('%s:%s', 'arc-seal', cur_arc_seal) + sha_ctx:update(s) + lua_util.debugm(N, task, 'initial update signature with header: %s', s) + + local nl_type + if task:has_flag("milter") then + nl_type = "lf" + else + nl_type = task:get_newlines_type() + end + + local sig = rspamd_rsa.sign_memory(privkey, sha_ctx:bin()) + cur_arc_seal = string.format('%s%s', cur_arc_seal, + sig:base64(70, nl_type)) + + lua_mime.modify_headers(task, { + add = { + ['ARC-Authentication-Results'] = { order = 1, value = cur_auth_results }, + ['ARC-Message-Signature'] = { order = 1, value = header }, + ['ARC-Seal'] = { order = 1, value = lua_util.fold_header(task, + 'ARC-Seal', cur_arc_seal) } + }, + -- RFC requires a strict order for these headers to be inserted + order = { 'ARC-Authentication-Results', 'ARC-Message-Signature', 'ARC-Seal' }, + }) + 
task:insert_result(settings.sign_symbol, 1.0, + string.format('%s:s=%s:i=%d', params.domain, params.selector, cur_idx)) +end + +local function prepare_arc_selector(task, sel) + local arc_seals = task:cache_get('arc-seals') + + if not arc_seals then + -- Check if our arc is broken + local failure_reason = task:cache_get('arc-failure') + if failure_reason then + rspamd_logger.infox(task, 'skip ARC as the existing chain is broken: %s', failure_reason) + return false + end + end + + sel.arc_cv = 'none' + sel.arc_idx = 1 + sel.no_cache = true + sel.sign_type = 'arc-sign' + + if arc_seals then + sel.arc_idx = #arc_seals + 1 + + local function default_arc_cv() + if task:cache_get('arc-allow') then + sel.arc_cv = 'pass' + else + sel.arc_cv = 'fail' + end + end + + if settings.reuse_auth_results then + local ar_header = task:get_header('Authentication-Results') + + if ar_header then + local arc_match = string.match(ar_header, 'arc=(%w+)') + + if arc_match then + if arc_match == 'none' or arc_match == 'pass' then + -- none should be converted to `pass` + sel.arc_cv = 'pass' + else + sel.arc_cv = 'fail' + end + else + default_arc_cv() + end + else + -- Cannot reuse, use normal path + default_arc_cv() + end + else + default_arc_cv() + end + + end + + return true +end + +local function do_sign(task, sign_params) + if sign_params.alg and sign_params.alg ~= 'rsa' then + -- No support for ed25519 keys + return + end + + if not prepare_arc_selector(task, sign_params) then + -- Broken arc + return + end + + if settings.check_pubkey then + local resolve_name = sign_params.selector .. "._domainkey." .. 
sign_params.domain + task:get_resolver():resolve_txt({ + task = task, + name = resolve_name, + callback = function(_, _, results, err) + if not err and results and results[1] then + sign_params.pubkey = results[1] + sign_params.strict_pubkey_check = not settings.allow_pubkey_mismatch + elseif not settings.allow_pubkey_mismatch then + rspamd_logger.errx('public key for domain %s/%s is not found: %s, skip signing', + sign_params.domain, sign_params.selector, err) + return + else + rspamd_logger.infox('public key for domain %s/%s is not found: %s', + sign_params.domain, sign_params.selector, err) + end + + local dret, hdr = dkim_sign(task, sign_params) + if dret then + arc_sign_seal(task, sign_params, hdr) + end + + end, + forced = true + }) + else + local dret, hdr = dkim_sign(task, sign_params) + if dret then + arc_sign_seal(task, sign_params, hdr) + end + end +end + +local function sign_error(task, msg) + rspamd_logger.errx(task, 'signing failure: %s', msg) +end + +local function arc_signing_cb(task) + local ret, selectors = dkim_sign_tools.prepare_dkim_signing(N, task, settings) + + if not ret then + return + end + + if settings.use_redis then + dkim_sign_tools.sign_using_redis(N, task, settings, selectors, do_sign, sign_error) + else + if selectors.vault then + dkim_sign_tools.sign_using_vault(N, task, settings, selectors, do_sign, sign_error) + else + -- TODO: no support for multiple sigs + local cur_selector = selectors[1] + prepare_arc_selector(task, cur_selector) + if ((cur_selector.key or cur_selector.rawkey) and cur_selector.selector) then + if cur_selector.key then + cur_selector.key = lua_util.template(cur_selector.key, { + domain = cur_selector.domain, + selector = cur_selector.selector + }) + + local exists, err = rspamd_util.file_exists(cur_selector.key) + if not exists then + if err and err == 'No such file or directory' then + lua_util.debugm(N, task, 'cannot read key from %s: %s', cur_selector.key, err) + else + rspamd_logger.warnx(task, 'cannot 
read key from %s: %s', cur_selector.key, err) + end + return false + end + end + + do_sign(task, cur_selector) + else + rspamd_logger.infox(task, 'key path or dkim selector unconfigured; no signing') + return false + end + end + end +end + +dkim_sign_tools.process_signing_settings(N, settings, opts) + +if not dkim_sign_tools.validate_signing_settings(settings) then + rspamd_logger.infox(rspamd_config, 'mandatory parameters missing, disable arc signing') + return +end + +local ar_opts = rspamd_config:get_all_opt('milter_headers') + +if ar_opts and ar_opts.routines then + local routines = ar_opts.routines + + if routines['authentication-results'] then + ar_settings = lua_util.override_defaults(ar_settings, + routines['authentication-results']) + end +end + +if settings.use_redis then + redis_params = rspamd_parse_redis_server('arc') + + if not redis_params then + rspamd_logger.errx(rspamd_config, 'no servers are specified, ' .. + 'but module is configured to load keys from redis, disable arc signing') + return + end + + settings.redis_params = redis_params +end + +local sym_reg_tbl = { + name = settings['sign_symbol'], + callback = arc_signing_cb, + groups = { "policies", "arc" }, + flags = 'ignore_passthrough', + score = 0.0, +} +if type(settings.allowed_ids) == 'table' then + sym_reg_tbl.allowed_ids = settings.allowed_ids +end +if type(settings.forbidden_ids) == 'table' then + sym_reg_tbl.forbidden_ids = settings.forbidden_ids +end + +if settings.whitelisted_signers_map then + arc_symbols.trusted_allow = arc_symbols.trusted_allow or 'ARC_ALLOW_TRUSTED' + rspamd_config:register_symbol({ + name = arc_symbols.trusted_allow, + parent = id, + type = 'virtual', + score = -2.0, + group = 'policies', + groups = { 'arc' }, + }) +end + +rspamd_config:register_symbol(sym_reg_tbl) + +-- Do not sign unless checked +rspamd_config:register_dependency(settings['sign_symbol'], 'ARC_CHECK') +-- We need to check dmarc before signing as we have to produce valid AAR header +-- see 
#3613 +rspamd_config:register_dependency(settings['sign_symbol'], 'DMARC_CHECK') + +if settings.adjust_dmarc and settings.whitelisted_signers_map then + local function arc_dmarc_adjust_cb(task) + local trusted_arc_ar = task:cache_get(AR_TRUSTED_CACHE_KEY) + local sym_to_adjust + if task:has_symbol(ar_settings.dmarc_symbols.reject) then + sym_to_adjust = ar_settings.dmarc_symbols.reject + elseif task:has_symbol(ar_settings.dmarc_symbols.quarantine) then + sym_to_adjust = ar_settings.dmarc_symbols.quarantine + end + if sym_to_adjust and trusted_arc_ar and trusted_arc_ar.ar then + for _, ar in ipairs(trusted_arc_ar.ar) do + if ar.dmarc then + local dmarc_fwd = ar.dmarc + if dmarc_fwd == 'pass' then + rspamd_logger.infox(task, "adjust dmarc reject score as trusted forwarder " + .. "proved DMARC validity for %s", ar['header.from']) + task:adjust_result(sym_to_adjust, 0.1, + 'ARC trusted') + end + end + end + end + end + rspamd_config:register_symbol({ + name = 'ARC_DMARC_ADJUSTMENT', + callback = arc_dmarc_adjust_cb, + type = 'callback', + }) + rspamd_config:register_dependency('ARC_DMARC_ADJUSTMENT', 'DMARC_CHECK') + rspamd_config:register_dependency('ARC_DMARC_ADJUSTMENT', 'ARC_CHECK') +end diff --git a/src/plugins/lua/asn.lua b/src/plugins/lua/asn.lua new file mode 100644 index 0000000..24da19e --- /dev/null +++ b/src/plugins/lua/asn.lua @@ -0,0 +1,168 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2016, Andrew Lewis <nerf@judo.za.org> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local rspamd_logger = require "rspamd_logger" +local rspamd_regexp = require "rspamd_regexp" +local lua_util = require "lua_util" +local N = "asn" + +if confighelp then + return +end + +local options = { + provider_type = 'rspamd', + provider_info = { + ip4 = 'asn.rspamd.com', + ip6 = 'asn6.rspamd.com', + }, + symbol = 'ASN', + check_local = false, +} + +local rspamd_re = rspamd_regexp.create_cached("[\\|\\s]") + +local function asn_check(task) + + local function asn_set(asn, ipnet, country) + local descr_t = {} + local mempool = task:get_mempool() + if asn then + if tonumber(asn) ~= nil then + mempool:set_variable("asn", asn) + table.insert(descr_t, "asn:" .. asn) + else + rspamd_logger.errx(task, 'malformed ASN "%s" for ip %s', asn, task:get_from_ip()) + end + end + if ipnet then + mempool:set_variable("ipnet", ipnet) + table.insert(descr_t, "ipnet:" .. ipnet) + end + if country then + mempool:set_variable("country", country) + table.insert(descr_t, "country:" .. country) + end + if options['symbol'] then + task:insert_result(options['symbol'], 0.0, table.concat(descr_t, ', ')) + end + end + + local asn_check_func = {} + asn_check_func.rspamd = function(ip) + local dnsbl = options['provider_info']['ip' .. 
ip:get_version()] + local req_name = string.format("%s.%s", + table.concat(ip:inversed_str_octets(), '.'), dnsbl) + local function rspamd_dns_cb(_, _, results, dns_err, _, _, serv) + if dns_err and (dns_err ~= 'requested record is not found' and dns_err ~= 'no records with this name') then + rspamd_logger.errx(task, 'error querying dns "%s" on %s: %s', + req_name, serv, dns_err) + task:insert_result(options['symbol_fail'], 0, string.format('%s:%s', req_name, dns_err)) + return + end + if not results or not results[1] then + rspamd_logger.infox(task, 'no ASN information is available for the IP address "%s" on %s', + req_name, serv) + return + end + + lua_util.debugm(N, task, 'got reply from %s when requesting %s: %s', + serv, req_name, results[1]) + + local parts = rspamd_re:split(results[1]) + -- "15169 | 8.8.8.0/24 | US | arin |" for 8.8.8.8 + asn_set(parts[1], parts[2], parts[3]) + end + + task:get_resolver():resolve_txt({ + task = task, + name = req_name, + callback = rspamd_dns_cb + }) + end + + local ip = task:get_from_ip() + if not (ip and ip:is_valid()) or + (not options.check_local and ip:is_local()) then + return + end + + asn_check_func[options['provider_type']](ip) +end + +-- Configuration options +local configure_asn_module = function() + local opts = rspamd_config:get_all_opt('asn') + if opts then + for k, v in pairs(opts) do + options[k] = v + end + end + + local auth_and_local_conf = lua_util.config_check_local_or_authed(rspamd_config, N, + false, true) + options.check_local = auth_and_local_conf[1] + options.check_authed = auth_and_local_conf[2] + + if options['provider_type'] == 'rspamd' then + if not options['provider_info'] and options['provider_info']['ip4'] and + options['provider_info']['ip6'] then + rspamd_logger.errx("Missing required provider_info for rspamd") + return false + end + else + rspamd_logger.errx("Unknown provider_type: %s", options['provider_type']) + return false + end + + if options['symbol'] then + options['symbol_fail'] = 
options['symbol'] .. '_FAIL' + else + options['symbol_fail'] = 'ASN_FAIL' + end + + return true +end + +if configure_asn_module() then + local id = rspamd_config:register_symbol({ + name = 'ASN_CHECK', + type = 'prefilter', + callback = asn_check, + priority = lua_util.symbols_priorities.high, + flags = 'empty,nostat', + augmentations = { lua_util.dns_timeout_augmentation(rspamd_config) }, + }) + if options['symbol'] then + rspamd_config:register_symbol({ + name = options['symbol'], + parent = id, + type = 'virtual', + flags = 'empty,nostat', + score = 0, + }) + end + rspamd_config:register_symbol { + name = options['symbol_fail'], + parent = id, + type = 'virtual', + flags = 'empty,nostat', + score = 0, + } +else + lua_util.disable_module(N, 'config') +end diff --git a/src/plugins/lua/aws_s3.lua b/src/plugins/lua/aws_s3.lua new file mode 100644 index 0000000..30e88d2 --- /dev/null +++ b/src/plugins/lua/aws_s3.lua @@ -0,0 +1,269 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +local N = "aws_s3" +local lua_util = require "lua_util" +local lua_aws = require "lua_aws" +local rspamd_logger = require "rspamd_logger" +local ts = (require "tableshape").types +local rspamd_text = require "rspamd_text" +local rspamd_http = require "rspamd_http" +local rspamd_util = require "rspamd_util" + +local settings = { + s3_bucket = nil, + s3_region = 'us-east-1', + s3_host = 's3.amazonaws.com', + s3_secret_key = nil, + s3_key_id = nil, + s3_timeout = 10, + save_raw = true, + save_structure = false, + inline_content_limit = nil, +} + +local settings_schema = ts.shape { + s3_bucket = ts.string, + s3_region = ts.string, + s3_host = ts.string, + s3_secret_key = ts.string, + s3_key_id = ts.string, + s3_timeout = ts.number + ts.string / lua_util.parse_time_interval, + enabled = ts.boolean:is_optional(), + fail_action = ts.string:is_optional(), + zstd_compress = ts.boolean:is_optional(), + save_raw = ts.boolean:is_optional(), + save_structure = ts.boolean:is_optional(), + inline_content_limit = ts.number:is_optional(), +} + +local function raw_data(task, nonce, queue_id) + local ext, content, content_type + + if settings.zstd_compress then + ext = 'eml.zst' + content = rspamd_util.zstd_compress(task:get_content()) + content_type = 'application/zstd' + else + ext = 'eml' + content = task:get_content() + content_type = 'message/rfc-822' + end + + local path = string.format('/%s-%s.%s', queue_id, nonce, ext) + + return path, content, content_type +end + +local function gen_ext(base) + local ext = base + if settings.zstd_compress then + ext = base .. 
'.zst' + end + + return ext +end + +local function convert_to_ref(task, nonce, queue_id, part, external_refs) + local path = string.format('/%s-%s-%s.%s', queue_id, nonce, + rspamd_text.randombytes(8):base32(), gen_ext('raw')) + local content = part.content + + if settings.zstd_compress then + external_refs[path] = rspamd_util.zstd_compress(content) + else + external_refs[path] = content + end + + part.content = nil + part.content_path = path + + return path +end + +local function structured_data(task, nonce, queue_id) + local content, content_type + local external_refs = {} + local lua_mime = require "lua_mime" + local ucl = require "ucl" + + local message_split = lua_mime.message_to_ucl(task) + if settings.inline_content_limit and settings.inline_content_limit > 0 then + + for i, part in ipairs(message_split.parts or {}) do + if part.content and #part.content >= settings.inline_content_limit then + local ref = convert_to_ref(task, nonce, queue_id, part, external_refs) + lua_util.debugm(N, task, "convert part number %s to a reference %s", + i, ref) + end + end + end + + if settings.zstd_compress then + content = rspamd_util.zstd_compress(ucl.to_format(message_split, 'msgpack')) + content_type = 'application/zstd' + else + content = ucl.to_format(message_split, 'msgpack') + content_type = 'application/msgpack' + end + + local path = string.format('/%s-%s.%s', queue_id, nonce, gen_ext('msgpack')) + + return path, content, content_type, external_refs +end + +local function s3_aws_callback(task) + local uri = string.format('https://%s.%s', settings.s3_bucket, settings.s3_host) + -- Create a nonce + local nonce = rspamd_text.randombytes(16):base32() + local queue_id = task:get_queue_id() + if not queue_id then + queue_id = rspamd_text.randombytes(8):base32() + end + -- Hack to pass host + local aws_host = string.format('%s.%s', settings.s3_bucket, settings.s3_host) + + local function gen_s3_http_callback(path, what) + return function(http_err, code, body, headers) + + 
if http_err then + if settings.fail_action then + task:set_pre_result(settings.fail_action, + string.format('S3 save failed: %s', http_err), N, + nil, nil, 'least') + end + rspamd_logger.errx(task, 'cannot save %s to AWS S3: %s', path, http_err) + else + rspamd_logger.messagex(task, 'saved %s successfully in S3 object %s', what, path) + end + lua_util.debugm(N, task, 'obj=%s, err=%s, code=%s, body=%s, headers=%s', + path, http_err, code, body, headers) + end + end + + if settings.save_raw then + local path, content, content_type = raw_data(task, nonce, queue_id) + local hdrs = lua_aws.aws_request_enrich({ + region = settings.s3_region, + headers = { + ['Content-Type'] = content_type, + ['Host'] = aws_host + }, + uri = path, + key_id = settings.s3_key_id, + secret_key = settings.s3_secret_key, + method = 'PUT', + }, content) + rspamd_http.request({ + url = uri .. path, + task = task, + method = 'PUT', + body = content, + callback = gen_s3_http_callback(path, 'raw message'), + headers = hdrs, + timeout = settings.s3_timeout, + }) + end + if settings.save_structure then + local path, content, content_type, external_refs = structured_data(task, nonce, queue_id) + local hdrs = lua_aws.aws_request_enrich({ + region = settings.s3_region, + headers = { + ['Content-Type'] = content_type, + ['Host'] = aws_host + }, + uri = path, + key_id = settings.s3_key_id, + secret_key = settings.s3_secret_key, + method = 'PUT', + }, content) + rspamd_http.request({ + url = uri .. 
path, + task = task, + method = 'PUT', + body = content, + callback = gen_s3_http_callback(path, 'structured message'), + headers = hdrs, + upstream = settings.upstreams:get_upstream_round_robin(), + timeout = settings.s3_timeout, + }) + + for ref, part_content in pairs(external_refs) do + local part_hdrs = lua_aws.aws_request_enrich({ + region = settings.s3_region, + headers = { + ['Content-Type'] = content_type, + ['Host'] = aws_host + }, + uri = ref, + key_id = settings.s3_key_id, + secret_key = settings.s3_secret_key, + method = 'PUT', + }, part_content) + rspamd_http.request({ + url = uri .. ref, + task = task, + upstream = settings.upstreams:get_upstream_round_robin(), + method = 'PUT', + body = part_content, + callback = gen_s3_http_callback(ref, 'part content'), + headers = part_hdrs, + timeout = settings.s3_timeout, + }) + end + end + + +end + +local opts = rspamd_config:get_all_opt('aws_s3') +if not opts then + return +end + +settings = lua_util.override_defaults(settings, opts) +local res, err = settings_schema:transform(settings) + +if not res then + rspamd_logger.warnx(rspamd_config, 'plugin is misconfigured: %s', err) + lua_util.disable_module(N, "config") + return +end + +rspamd_logger.infox(rspamd_config, 'enabled AWS s3 dump to %s', res.s3_bucket) + +settings = res + +settings.upstreams = lua_util.http_upstreams_by_url(rspamd_config:get_mempool(), + string.format('https://%s.%s', settings.s3_bucket, settings.s3_host)) + +if not settings.upstreams then + rspamd_logger.warnx(rspamd_config, 'cannot parse hostname: %s', + string.format('https://%s.%s', settings.s3_bucket, settings.s3_host)) + lua_util.disable_module(N, "config") + return +end + +local is_postfilter = settings.fail_action ~= nil + +rspamd_config:register_symbol({ + name = 'EXPORT_AWS_S3', + type = is_postfilter and 'postfilter' or 'idempotent', + callback = s3_aws_callback, + augmentations = { string.format("timeout=%f", settings.s3_timeout) }, + priority = is_postfilter and 
lua_util.symbols_priorities.high or nil, + flags = 'empty,explicit_disable,ignore_passthrough,nostat', +})
\ No newline at end of file diff --git a/src/plugins/lua/bayes_expiry.lua b/src/plugins/lua/bayes_expiry.lua new file mode 100644 index 0000000..44ff9da --- /dev/null +++ b/src/plugins/lua/bayes_expiry.lua @@ -0,0 +1,503 @@ +--[[ +Copyright (c) 2017, Andrew Lewis <nerf@judo.za.org> +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]] -- + +if confighelp then + return +end + +local N = 'bayes_expiry' +local E = {} +local logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local lutil = require "lua_util" +local lredis = require "lua_redis" + +local settings = { + interval = 60, -- one iteration step per minute + count = 1000, -- check up to 1000 keys on each iteration + epsilon_common = 0.01, -- eliminate common if spam to ham rate is equal to this epsilon + common_ttl = 10 * 86400, -- TTL of discriminated common elements + significant_factor = 3.0 / 4.0, -- which tokens should we update + classifiers = {}, + cluster_nodes = 0, +} + +local template = {} + +local function check_redis_classifier(cls, cfg) + -- Skip old classifiers + if cls.new_schema then + local symbol_spam, symbol_ham + local expiry = (cls.expiry or cls.expire) + if type(expiry) == 'table' then + expiry = expiry[1] + end + + -- Load symbols from statfiles + + local function check_statfile_table(tbl, def_sym) + local symbol = tbl.symbol or def_sym + + local spam + if tbl.spam then + spam = tbl.spam + else + if string.match(symbol:upper(), 
'SPAM') then + spam = true + else + spam = false + end + end + + if spam then + symbol_spam = symbol + else + symbol_ham = symbol + end + end + + local statfiles = cls.statfile + if statfiles[1] then + for _, stf in ipairs(statfiles) do + if not stf.symbol then + for k, v in pairs(stf) do + check_statfile_table(v, k) + end + else + check_statfile_table(stf, 'undefined') + end + end + else + for stn, stf in pairs(statfiles) do + check_statfile_table(stf, stn) + end + end + + if not symbol_spam or not symbol_ham or type(expiry) ~= 'number' then + logger.debugm(N, rspamd_config, + 'disable expiry for classifier %s: no expiry %s', + symbol_spam, cls) + return + end + -- Now try to load redis_params if needed + + local redis_params + redis_params = lredis.try_load_redis_servers(cls, rspamd_config, false, 'bayes') + if not redis_params then + redis_params = lredis.try_load_redis_servers(cfg[N] or E, rspamd_config, false, 'bayes') + if not redis_params then + redis_params = lredis.try_load_redis_servers(cfg[N] or E, rspamd_config, true) + if not redis_params then + return false + end + end + end + + if redis_params['read_only'] then + logger.infox(rspamd_config, 'disable expiry for classifier %s: read only redis configuration', + symbol_spam) + return + end + + logger.debugm(N, rspamd_config, "enabled expiry for %s/%s -> %s expiry", + symbol_spam, symbol_ham, expiry) + + table.insert(settings.classifiers, { + symbol_spam = symbol_spam, + symbol_ham = symbol_ham, + redis_params = redis_params, + expiry = expiry + }) + end +end + +-- Check classifiers and try find the appropriate ones +local obj = rspamd_config:get_ucl() + +local classifier = obj.classifier + +if classifier then + if classifier[1] then + for _, cls in ipairs(classifier) do + if cls.bayes then + cls = cls.bayes + end + if cls.backend and cls.backend == 'redis' then + check_redis_classifier(cls, obj) + end + end + else + if classifier.bayes then + + classifier = classifier.bayes + if classifier[1] then + for 
_, cls in ipairs(classifier) do + if cls.backend and cls.backend == 'redis' then + check_redis_classifier(cls, obj) + end + end + else + if classifier.backend and classifier.backend == 'redis' then + check_redis_classifier(classifier, obj) + end + end + end + end +end + +local opts = rspamd_config:get_all_opt(N) + +if opts then + for k, v in pairs(opts) do + settings[k] = v + end +end + +-- In clustered setup, we need to increase interval of expiration +-- according to number of nodes in a cluster +if settings.cluster_nodes == 0 then + local neighbours = obj.neighbours or {} + local n_neighbours = 0 + for _, _ in pairs(neighbours) do + n_neighbours = n_neighbours + 1 + end + settings.cluster_nodes = n_neighbours +end + +-- Fill template +template.count = settings.count +template.threshold = settings.threshold +template.common_ttl = settings.common_ttl +template.epsilon_common = settings.epsilon_common +template.significant_factor = settings.significant_factor +template.expire_step = settings.interval +template.hostname = rspamd_util.get_hostname() + +for k, v in pairs(template) do + template[k] = tostring(v) +end + +-- Arguments: +-- [1] = symbol pattern +-- [2] = expire value +-- [3] = cursor +-- returns {cursor for the next step, step number, step statistic counters, cycle statistic counters, tokens occurrences distribution} +local expiry_script = [[ + local unpack_function = table.unpack or unpack + + local hash2list = function (hash) + local res = {} + for k, v in pairs(hash) do + table.insert(res, k) + table.insert(res, v) + end + return res + end + + local function merge_list(table, list) + local k + for i, v in ipairs(list) do + if i % 2 == 1 then + k = v + else + table[k] = v + end + end + end + + local expire = math.floor(KEYS[2]) + local pattern_sha1 = redis.sha1hex(KEYS[1]) + + local lock_key = pattern_sha1 .. '_lock' -- Check locking + local lock = redis.call('GET', lock_key) + + if lock then + if lock ~= '${hostname}' then + return 'locked by ' .. 
lock + end + end + + redis.replicate_commands() + redis.call('SETEX', lock_key, ${expire_step}, '${hostname}') + + local cursor_key = pattern_sha1 .. '_cursor' + local cursor = tonumber(redis.call('GET', cursor_key) or 0) + + local step = 1 + local step_key = pattern_sha1 .. '_step' + if cursor > 0 then + step = redis.call('GET', step_key) + step = step and (tonumber(step) + 1) or 1 + end + + local ret = redis.call('SCAN', cursor, 'MATCH', KEYS[1], 'COUNT', '${count}') + local next_cursor = ret[1] + local keys = ret[2] + local tokens = {} + + -- Tokens occurrences distribution counters + local occur = { + ham = {}, + spam = {}, + total = {} + } + + -- Expiry step statistics counters + local nelts, extended, discriminated, sum, sum_squares, common, significant, + infrequent, infrequent_ttls_set, insignificant, insignificant_ttls_set = + 0,0,0,0,0,0,0,0,0,0,0 + + for _,key in ipairs(keys) do + local t = redis.call('TYPE', key)["ok"] + if t == 'hash' then + local values = redis.call('HMGET', key, 'H', 'S') + local ham = tonumber(values[1]) or 0 + local spam = tonumber(values[2]) or 0 + local ttl = redis.call('TTL', key) + tokens[key] = { + ham, + spam, + ttl + } + local total = spam + ham + sum = sum + total + sum_squares = sum_squares + total * total + nelts = nelts + 1 + + for k,v in pairs({['ham']=ham, ['spam']=spam, ['total']=total}) do + if tonumber(v) > 19 then v = 20 end + occur[k][v] = occur[k][v] and occur[k][v] + 1 or 1 + end + end + end + + local mean, stddev = 0, 0 + + if nelts > 0 then + mean = sum / nelts + stddev = math.sqrt(sum_squares / nelts - mean * mean) + end + + for key,token in pairs(tokens) do + local ham, spam, ttl = token[1], token[2], tonumber(token[3]) + local threshold = mean + local total = spam + ham + + local function set_ttl() + if expire < 0 then + if ttl ~= -1 then + redis.call('PERSIST', key) + return 1 + end + elseif ttl == -1 or ttl > expire then + redis.call('EXPIRE', key, expire) + return 1 + end + return 0 + end + + if total == 
0 or math.abs(ham - spam) <= total * ${epsilon_common} then + common = common + 1 + if ttl > ${common_ttl} then + discriminated = discriminated + 1 + redis.call('EXPIRE', key, ${common_ttl}) + end + elseif total >= threshold and total > 0 then + if ham / total > ${significant_factor} or spam / total > ${significant_factor} then + significant = significant + 1 + if ttl ~= -1 then + redis.call('PERSIST', key) + extended = extended + 1 + end + else + insignificant = insignificant + 1 + insignificant_ttls_set = insignificant_ttls_set + set_ttl() + end + else + infrequent = infrequent + 1 + infrequent_ttls_set = infrequent_ttls_set + set_ttl() + end + end + + -- Expiry cycle statistics counters + local c = {nelts = 0, extended = 0, discriminated = 0, sum = 0, sum_squares = 0, + common = 0, significant = 0, infrequent = 0, infrequent_ttls_set = 0, insignificant = 0, insignificant_ttls_set = 0} + + local counters_key = pattern_sha1 .. '_counters' + + if cursor ~= 0 then + merge_list(c, redis.call('HGETALL', counters_key)) + end + + c.nelts = c.nelts + nelts + c.extended = c.extended + extended + c.discriminated = c.discriminated + discriminated + c.sum = c.sum + sum + c.sum_squares = c.sum_squares + sum_squares + c.common = c.common + common + c.significant = c.significant + significant + c.infrequent = c.infrequent + infrequent + c.infrequent_ttls_set = c.infrequent_ttls_set + infrequent_ttls_set + c.insignificant = c.insignificant + insignificant + c.insignificant_ttls_set = c.insignificant_ttls_set + insignificant_ttls_set + + redis.call('HMSET', counters_key, unpack_function(hash2list(c))) + redis.call('SET', cursor_key, tostring(next_cursor)) + redis.call('SET', step_key, tostring(step)) + redis.call('DEL', lock_key) + + local occ_distr = {} + for _,cl in pairs({'ham', 'spam', 'total'}) do + local occur_key = pattern_sha1 .. '_occurrence_' .. 
cl + + if cursor ~= 0 then + local n + for i,v in ipairs(redis.call('HGETALL', occur_key)) do + if i % 2 == 1 then + n = tonumber(v) + else + occur[cl][n] = occur[cl][n] and occur[cl][n] + v or v + end + end + + local str = '' + if occur[cl][0] ~= nil then + str = '0:' .. occur[cl][0] .. ',' + end + for k,v in ipairs(occur[cl]) do + if k == 20 then k = '>19' end + str = str .. k .. ':' .. v .. ',' + end + table.insert(occ_distr, str) + else + redis.call('DEL', occur_key) + end + + if next(occur[cl]) ~= nil then + redis.call('HMSET', occur_key, unpack_function(hash2list(occur[cl]))) + end + end + + return { + next_cursor, step, + {nelts, extended, discriminated, mean, stddev, common, significant, infrequent, + infrequent_ttls_set, insignificant, insignificant_ttls_set}, + {c.nelts, c.extended, c.discriminated, c.sum, c.sum_squares, c.common, + c.significant, c.infrequent, c.infrequent_ttls_set, c.insignificant, c.insignificant_ttls_set}, + occ_distr + } +]] + +local function expire_step(cls, ev_base, worker) + local function redis_step_cb(err, args) + if err then + logger.errx(rspamd_config, 'cannot perform expiry step: %s', err) + elseif type(args) == 'table' then + local cur = tonumber(args[1]) + local step = args[2] + local data = args[3] + local c_data = args[4] + local occ_distr = args[5] + + local function log_stat(cycle) + local infrequent_action = (cls.expiry < 0) and 'made persistent' or 'ttls set' + + local c_mean, c_stddev = 0, 0 + if cycle and c_data[1] ~= 0 then + c_mean = c_data[4] / c_data[1] + c_stddev = math.floor(.5 + math.sqrt(c_data[5] / c_data[1] - c_mean * c_mean)) + c_mean = math.floor(.5 + c_mean) + end + + local d = cycle and { + 'cycle in ' .. step .. ' steps', c_data[1], + c_data[7], c_data[2], 'made persistent', + c_data[10], c_data[11], infrequent_action, + c_data[6], c_data[3], + c_data[8], c_data[9], infrequent_action, + c_mean, + c_stddev + } or { + 'step ' .. 
step, data[1], + data[7], data[2], 'made persistent', + data[10], data[11], infrequent_action, + data[6], data[3], + data[8], data[9], infrequent_action, + data[4], + data[5] + } + logger.infox(rspamd_config, + 'finished expiry %s: %s items checked, %s significant (%s %s), ' .. + '%s insignificant (%s %s), %s common (%s discriminated), ' .. + '%s infrequent (%s %s), %s mean, %s std', + lutil.unpack(d)) + if cycle then + for i, cl in ipairs({ 'in ham', 'in spam', 'total' }) do + logger.infox(rspamd_config, 'tokens occurrences, %s: {%s}', cl, occ_distr[i]) + end + end + end + log_stat(false) + if cur == 0 then + log_stat(true) + end + elseif type(args) == 'string' then + logger.infox(rspamd_config, 'skip expiry step: %s', args) + end + end + lredis.exec_redis_script(cls.script, + { ev_base = ev_base, is_write = true }, + redis_step_cb, + { 'RS*_*', cls.expiry } + ) +end + +rspamd_config:add_on_load(function(_, ev_base, worker) + -- Exit unless we're the first 'controller' worker + if not worker:is_primary_controller() then + return + end + + local unique_redis_params = {} + -- Push redis script to all unique redis servers + for _, cls in ipairs(settings.classifiers) do + if not unique_redis_params[cls.redis_params.hash] then + unique_redis_params[cls.redis_params.hash] = cls.redis_params + end + end + + for h, rp in pairs(unique_redis_params) do + local script_id = lredis.add_redis_script(lutil.template(expiry_script, + template), rp) + + for _, cls in ipairs(settings.classifiers) do + if cls.redis_params.hash == h then + cls.script = script_id + end + end + end + + -- Expire tokens at regular intervals + for _, cls in ipairs(settings.classifiers) do + rspamd_config:add_periodic(ev_base, + settings['interval'], + function() + expire_step(cls, ev_base, worker) + return true + end, true) + end +end) diff --git a/src/plugins/lua/bimi.lua b/src/plugins/lua/bimi.lua new file mode 100644 index 0000000..2783590 --- /dev/null +++ b/src/plugins/lua/bimi.lua @@ -0,0 +1,391 @@ 
+--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local N = "bimi" +local lua_util = require "lua_util" +local rspamd_logger = require "rspamd_logger" +local ts = (require "tableshape").types +local lua_redis = require "lua_redis" +local ucl = require "ucl" +local lua_mime = require "lua_mime" +local rspamd_http = require "rspamd_http" +local rspamd_util = require "rspamd_util" + +local settings = { + helper_url = "http://127.0.0.1:3030", + helper_timeout = 5, + helper_sync = true, + vmc_only = true, + redis_prefix = 'rs_bimi', + redis_min_expiry = 24 * 3600, +} +local redis_params + +local settings_schema = lua_redis.enrich_schema({ + helper_url = ts.string, + helper_timeout = ts.number + ts.string / lua_util.parse_time_interval, + helper_sync = ts.boolean, + vmc_only = ts.boolean, + redis_min_expiry = ts.number + ts.string / lua_util.parse_time_interval, + redis_prefix = ts.string, + enabled = ts.boolean:is_optional(), +}) + +local function check_dmarc_policy(task) + local dmarc_sym = task:get_symbol('DMARC_POLICY_ALLOW') + + if not dmarc_sym then + lua_util.debugm(N, task, "no DMARC allow symbol") + return nil + end + + local opts = dmarc_sym[1].options or {} + if not opts[1] or #opts ~= 2 then + lua_util.debugm(N, task, "DMARC options are bogus: %s", opts) + return nil + end + + -- opts[1] - domain; opts[2] - policy + local dom, policy = opts[1], opts[2] + + if policy ~= 'reject' and policy ~= 'quarantine' then + 
lua_util.debugm(N, task, "DMARC policy for domain %s is not strict: %s", + dom, policy) + return nil + end + + return dom +end + +local function gen_bimi_grammar() + local lpeg = require "lpeg" + lpeg.locale(lpeg) + local space = lpeg.space ^ 0 + local name = lpeg.C(lpeg.alpha ^ 1) * space + local sep = (lpeg.S("\\;") * space) + (lpeg.space ^ 1) + local value = lpeg.C(lpeg.P(lpeg.graph - sep) ^ 1) + local pair = lpeg.Cg(name * "=" * space * value) * sep ^ -1 + local list = lpeg.Cf(lpeg.Ct("") * pair ^ 0, rawset) + local version = lpeg.P("v") * space * lpeg.P("=") * space * lpeg.P("BIMI1") + local record = version * sep * list + + return record +end + +local bimi_grammar = gen_bimi_grammar() + +local function check_bimi_record(task, rec) + local elts = bimi_grammar:match(rec) + + if elts then + lua_util.debugm(N, task, "got BIMI record: %s, processed=%s", + rec, elts) + local res = {} + + if type(elts.l) == 'string' then + res.l = elts.l + end + if type(elts.a) == 'string' then + res.a = elts.a + end + + if res.l or res.a then + return res + end + end +end + +local function insert_bimi_headers(task, domain, bimi_content) + local hdr_name = 'BIMI-Indicator' + -- Re-encode base64... 
+ local content = rspamd_util.encode_base64(rspamd_util.decode_base64(bimi_content), + 73, task:get_newlines_type()) + lua_mime.modify_headers(task, { + remove = { [hdr_name] = 0 }, + add = { + [hdr_name] = { + order = 0, + value = rspamd_util.fold_header(hdr_name, content, + task:get_newlines_type()) + } + } + }) + task:insert_result('BIMI_VALID', 1.0, { domain }) +end + +local function process_bimi_json(task, domain, redis_data) + local parser = ucl.parser() + local _, err = parser:parse_string(redis_data) + + if err then + rspamd_logger.errx(task, "cannot parse BIMI result from Redis for %s: %s", + domain, err) + else + local d = parser:get_object() + if d.content then + insert_bimi_headers(task, domain, d.content) + elseif d.error then + lua_util.debugm(N, task, "invalid BIMI for %s: %s", + domain, d.error) + end + end +end + +local function make_helper_request(task, domain, record, redis_server) + local is_sync = settings.helper_sync + local helper_url = string.format('%s/v1/check', settings.helper_url) + local redis_key = string.format('%s%s', settings.redis_prefix, + domain) + + local function http_helper_callback(http_err, code, body, _) + if http_err then + rspamd_logger.warnx(task, 'got error reply from helper %s: code=%s; reply=%s', + helper_url, code, http_err) + return + end + if code ~= 200 then + -- http_err is falsy on this path (handled above), so log the reply body instead + rspamd_logger.warnx(task, 'got non 200 reply from helper %s: code=%s; reply=%s', + helper_url, code, body) + return + end + if is_sync then + local parser = ucl.parser() + local _, err = parser:parse_string(body) + + if err then + rspamd_logger.errx(task, "cannot parse BIMI result from helper for %s: %s", + domain, err) + else + local d = parser:get_object() + if d.content then + insert_bimi_headers(task, domain, d.content) + elseif d.error then + lua_util.debugm(N, task, "invalid BIMI for %s: %s", + domain, d.error) + end + + local ret, upstream + local function redis_set_cb(redis_err, _) + if redis_err then + rspamd_logger.warnx(task, 'cannot get reply 
from Redis when storing image %s: %s', + upstream:get_addr():to_string(), redis_err) + upstream:fail() + else + lua_util.debugm(N, task, 'stored bimi image in Redis for domain %s; key=%s', + domain, redis_key) + end + end + + ret, _, upstream = lua_redis.redis_make_request(task, + redis_params, -- connect params + redis_key, -- hash key + true, -- is write + redis_set_cb, --callback + 'PSETEX', -- command + { redis_key, tostring(settings.redis_min_expiry * 1000.0), + ucl.to_format(d, "json-compact") }) + + if not ret then + rspamd_logger.warnx(task, 'cannot make request to Redis when storing image; domain %s', + domain) + end + end + else + -- In async mode we skip request and use merely Redis to insert indicators + lua_util.debugm(N, task, "sent request to resolve %s to %s", + domain, helper_url) + end + end + + local request_data = { + url = record.a, + sync = is_sync, + domain = domain + } + + if not is_sync then + -- Allow bimi helper to save data in Redis + request_data.redis_server = redis_server + request_data.redis_prefix = settings.redis_prefix + request_data.redis_expiry = settings.redis_min_expiry * 1000.0 + else + request_data.skip_redis = true + end + + local serialised = ucl.to_format(request_data, 'json-compact') + lua_util.debugm(N, task, "send request to BIMI helper: %s", + serialised) + rspamd_http.request({ + task = task, + mime_type = 'application/json', + timeout = settings.helper_timeout, + body = serialised, + url = helper_url, + callback = http_helper_callback, + keepalive = true, + }) +end + +local function check_bimi_vmc(task, domain, record) + local redis_key = string.format('%s%s', settings.redis_prefix, + domain) + local ret, _, upstream + + local function redis_cached_cb(err, data) + if err then + rspamd_logger.warnx(task, 'cannot get reply from Redis %s: %s', + upstream:get_addr():to_string(), err) + upstream:fail() + else + if type(data) == 'string' then + -- We got a cached record, good stuff + lua_util.debugm(N, task, "got valid 
cached BIMI result for domain: %s", + domain) + process_bimi_json(task, domain, data) + else + -- Get server addr + port + -- We need to fix IPv6 address as redis-rs has no support of + -- the braced IPv6 addresses + local db, password = '', '' + if redis_params.db then + db = string.format('/%s', redis_params.db) + end + if redis_params.username then + if redis_params.password then + password = string.format( '%s:%s@', redis_params.username, redis_params.password) + else + rspamd_logger.warnx(task, "Redis requires a password when username is supplied") + end + elseif redis_params.password then + password = string.format(':%s@', redis_params.password) + end + local redis_server = string.format('redis://%s%s:%s%s', + password, + upstream:get_name(), upstream:get_port(), + db) + make_helper_request(task, domain, record, redis_server) + end + end + end + + -- We first check Redis and then try to use helper + ret, _, upstream = lua_redis.redis_make_request(task, + redis_params, -- connect params + redis_key, -- hash key + false, -- is write + redis_cached_cb, --callback + 'GET', -- command + { redis_key }) + + if not ret then + rspamd_logger.warnx(task, 'cannot make request to Redis; domain %s', domain) + end +end + +local function check_bimi_dns(task, domain) + local resolve_name = string.format('default._bimi.%s', domain) + local dns_cb = function(_, _, results, err) + if err then + lua_util.debugm(N, task, "cannot resolve bimi for %s: %s", + domain, err) + else + for _, rec in ipairs(results) do + local res = check_bimi_record(task, rec) + + if res then + if settings.vmc_only and not res.a then + lua_util.debugm(N, task, "BIMI for domain %s has no VMC, skip it", + domain) + + return + end + + if res.a then + check_bimi_vmc(task, domain, res) + elseif res.l then + -- TODO: add l check + lua_util.debugm(N, task, "l only BIMI for domain %s is not implemented yet", + domain) + end + end + end + end + end + task:get_resolver():resolve_txt({ + task = task, + name = 
resolve_name, + callback = dns_cb, + forced = true + }) +end + +local function bimi_callback(task) + local dmarc_domain_maybe = check_dmarc_policy(task) + + if not dmarc_domain_maybe then + return + end + + + -- We can either check BIMI via DNS or check Redis cache + -- BIMI check is an external check, so we might prefer Redis to be checked + -- first. On the other hand, DNS request is cheaper and counting low BIMI + -- adaptation we would need to have both Redis and DNS request to hit no + -- result. So, it might be better to check DNS first at this stage... + check_bimi_dns(task, dmarc_domain_maybe) +end + +local opts = rspamd_config:get_all_opt('bimi') +if not opts then + lua_util.disable_module(N, "config") + return +end + +settings = lua_util.override_defaults(settings, opts) +local res, err = settings_schema:transform(settings) + +if not res then + rspamd_logger.warnx(rspamd_config, 'plugin is misconfigured: %s', err) + -- res is falsy in this branch; the schema failure text is carried in err + local err_msg = string.format("schema error: %s", err) + lua_util.config_utils.push_config_error(N, err_msg) + lua_util.disable_module(N, "failed", err_msg) + return +end + +rspamd_logger.infox(rspamd_config, 'enabled BIMI plugin') + +settings = res +redis_params = lua_redis.parse_redis_server(N, opts) + +if redis_params then + local id = rspamd_config:register_symbol({ + name = 'BIMI_CHECK', + type = 'normal', + callback = bimi_callback, + augmentations = { string.format("timeout=%f", settings.helper_timeout or + redis_params.timeout or 0.0) } + }) + rspamd_config:register_symbol { + name = 'BIMI_VALID', + type = 'virtual', + parent = id, + score = 0.0 + } + + rspamd_config:register_dependency('BIMI_CHECK', 'DMARC_CHECK') +else + lua_util.disable_module(N, "redis") +end diff --git a/src/plugins/lua/clickhouse.lua b/src/plugins/lua/clickhouse.lua new file mode 100644 index 0000000..25eabc7 --- /dev/null +++ b/src/plugins/lua/clickhouse.lua @@ -0,0 +1,1556 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under 
the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local rspamd_logger = require 'rspamd_logger' +local upstream_list = require "rspamd_upstream_list" +local lua_util = require "lua_util" +local lua_clickhouse = require "lua_clickhouse" +local lua_settings = require "lua_settings" +local fun = require "fun" + +local N = "clickhouse" + +if confighelp then + return +end + +local data_rows = {} +local custom_rows = {} +local nrows = 0 +local used_memory = 0 +local last_collection = 0 +local final_call = false -- If the final collection has been started +local schema_version = 9 -- Current schema version + +local settings = { + limits = { -- Collection limits + max_rows = 1000, -- How many rows are allowed (0 for disable this) + max_memory = 50 * 1024 * 1024, -- How many memory should be occupied before sending collection + max_interval = 60, -- Maximum collection interval + }, + collect_garbage = false, -- Perform GC collection after sending the data + check_timeout = 10.0, -- Periodic timeout + timeout = 5.0, + bayes_spam_symbols = { 'BAYES_SPAM' }, + bayes_ham_symbols = { 'BAYES_HAM' }, + ann_symbols_spam = { 'NEURAL_SPAM' }, + ann_symbols_ham = { 'NEURAL_HAM' }, + fuzzy_symbols = { 'FUZZY_DENIED' }, + whitelist_symbols = { 'WHITELIST_DKIM', 'WHITELIST_SPF_DKIM', 'WHITELIST_DMARC' }, + dkim_allow_symbols = { 'R_DKIM_ALLOW' }, + dkim_reject_symbols = { 'R_DKIM_REJECT' }, + dkim_dnsfail_symbols = { 'R_DKIM_TEMPFAIL', 'R_DKIM_PERMFAIL' }, + dkim_na_symbols = { 'R_DKIM_NA' }, + dmarc_allow_symbols = { 
'DMARC_POLICY_ALLOW' }, + dmarc_reject_symbols = { 'DMARC_POLICY_REJECT' }, + dmarc_quarantine_symbols = { 'DMARC_POLICY_QUARANTINE' }, + dmarc_softfail_symbols = { 'DMARC_POLICY_SOFTFAIL' }, + dmarc_na_symbols = { 'DMARC_NA' }, + spf_allow_symbols = { 'R_SPF_ALLOW' }, + spf_reject_symbols = { 'R_SPF_FAIL' }, + spf_dnsfail_symbols = { 'R_SPF_DNSFAIL', 'R_SPF_PERMFAIL' }, + spf_neutral_symbols = { 'R_SPF_NEUTRAL', 'R_SPF_SOFTFAIL' }, -- SPF (not DKIM) symbols feed the IsSpf 'neutral' value + spf_na_symbols = { 'R_SPF_NA' }, + stop_symbols = {}, + ipmask = 19, + ipmask6 = 48, + full_urls = false, + from_tables = nil, + enable_symbols = false, + database = 'default', + use_https = false, + use_gzip = true, + allow_local = false, + insert_subject = false, + subject_privacy = false, -- subject privacy is off + subject_privacy_alg = 'blake2', -- default hash-algorithm to obfuscate subject + subject_privacy_prefix = 'obf', -- prefix to show it's obfuscated + subject_privacy_length = 16, -- cut the length of the hash + schema_additions = {}, -- additional SQL statements to be executed when schema is uploaded + user = nil, + password = nil, + no_ssl_verify = false, + custom_rules = {}, + enable_digest = false, + exceptions = nil, + retention = { + enable = false, + method = 'detach', + period_months = 3, + run_every = '7d', + }, + extra_columns = {}, +} + +--- @language SQL +local clickhouse_schema = { [[ +CREATE TABLE IF NOT EXISTS rspamd +( + Date Date COMMENT 'Date (used for partitioning)', + TS DateTime COMMENT 'Date and time of the request start (UTC)', + From String COMMENT 'Domain part of the return address (RFC5321.MailFrom)', + MimeFrom String COMMENT 'Domain part of the address in From: header (RFC5322.From)', + IP String COMMENT 'SMTP client IP as provided by MTA or from Received: header', + Helo String COMMENT 'Full hostname as sent by the SMTP client (RFC5321.HELO/.EHLO)', + Score Float32 COMMENT 'Message score', + NRcpt UInt8 COMMENT 'Number of envelope recipients (RFC5321.RcptTo)', + Size UInt32 COMMENT 'Message 
size in bytes', + IsWhitelist Enum8('blacklist' = 0, 'whitelist' = 1, 'unknown' = 2) DEFAULT 'unknown' COMMENT 'Based on symbols configured in `whitelist_symbols` module option', + IsBayes Enum8('ham' = 0, 'spam' = 1, 'unknown' = 2) DEFAULT 'unknown' COMMENT 'Based on symbols configured in `bayes_spam_symbols` and `bayes_ham_symbols` module options', + IsFuzzy Enum8('whitelist' = 0, 'deny' = 1, 'unknown' = 2) DEFAULT 'unknown' COMMENT 'Based on symbols configured in `fuzzy_symbols` module option', + IsFann Enum8('ham' = 0, 'spam' = 1, 'unknown' = 2) DEFAULT 'unknown' COMMENT 'Based on symbols configured in `ann_symbols_spam` and `ann_symbols_ham` module options', + IsDkim Enum8('reject' = 0, 'allow' = 1, 'unknown' = 2, 'dnsfail' = 3, 'na' = 4) DEFAULT 'unknown' COMMENT 'Based on symbols configured in dkim_* module options', + IsDmarc Enum8('reject' = 0, 'allow' = 1, 'unknown' = 2, 'softfail' = 3, 'na' = 4, 'quarantine' = 5) DEFAULT 'unknown' COMMENT 'Based on symbols configured in dmarc_* module options', + IsSpf Enum8('reject' = 0, 'allow' = 1, 'neutral' = 2, 'dnsfail' = 3, 'na' = 4, 'unknown' = 5) DEFAULT 'unknown' COMMENT 'Based on symbols configured in spf_* module options', + NUrls Int32 COMMENT 'Number of URLs and email extracted from the message', + Action Enum8('reject' = 0, 'rewrite subject' = 1, 'add header' = 2, 'greylist' = 3, 'no action' = 4, 'soft reject' = 5, 'custom' = 6) DEFAULT 'no action' COMMENT 'Action returned for the message; if action is not predefined actual action will be in `CustomAction` field', + CustomAction LowCardinality(String) COMMENT 'Action string for custom action', + FromUser String COMMENT 'Local part of the return address (RFC5321.MailFrom)', + MimeUser String COMMENT 'Local part of the address in From: header (RFC5322.From)', + RcptUser String COMMENT '[Deprecated] Local part of the first envelope recipient (RFC5321.RcptTo)', + RcptDomain String COMMENT '[Deprecated] Domain part of the first envelope recipient 
(RFC5321.RcptTo)', + SMTPRecipients Array(String) COMMENT 'List of envelope recipients (RFC5321.RcptTo)', + MimeRecipients Array(String) COMMENT 'List of recipients from headers (RFC5322.To/.CC/.BCC)', + MessageId String COMMENT 'Message-ID header', + ListId String COMMENT 'List-Id header', + Subject String COMMENT 'Subject header (or hash if `subject_privacy` module option enabled)', + `Attachments.FileName` Array(String) COMMENT 'Attachment name', + `Attachments.ContentType` Array(String) COMMENT 'Attachment Content-Type', + `Attachments.Length` Array(UInt32) COMMENT 'Attachment size in bytes', + `Attachments.Digest` Array(FixedString(16)) COMMENT 'First 16 characters of hash returned by mime_part:get_digest()', + `Urls.Tld` Array(String) COMMENT 'Effective second level domain part of the URL host', + `Urls.Url` Array(String) COMMENT 'Full URL if `full_urls` module option enabled, host part of URL otherwise', + `Urls.Flags` Array(UInt32) COMMENT 'Corresponding url flags, see `enum rspamd_url_flags` in libserver/url.h for details', + Emails Array(String) COMMENT 'List of emails extracted from the message', + ASN UInt32 COMMENT 'BGP AS number for SMTP client IP (returned by asn.rspamd.com or asn6.rspamd.com)', + Country FixedString(2) COMMENT 'Country for SMTP client IP (returned by asn.rspamd.com or asn6.rspamd.com)', + IPNet String, + `Symbols.Names` Array(LowCardinality(String)) COMMENT 'Symbol name', + `Symbols.Scores` Array(Float32) COMMENT 'Symbol score', + `Symbols.Options` Array(String) COMMENT 'Symbol options (comma separated list)', + `Groups.Names` Array(LowCardinality(String)) COMMENT 'Group name', + `Groups.Scores` Array(Float32) COMMENT 'Group score', + ScanTimeReal UInt32 COMMENT 'Request time in milliseconds', + ScanTimeVirtual UInt32 COMMENT 'Deprecated do not use', + AuthUser String COMMENT 'Username for authenticated SMTP client', + SettingsId LowCardinality(String) COMMENT 'ID for the settings profile', + Digest FixedString(32) COMMENT 
'[Deprecated]', + SMTPFrom ALIAS if(From = '', '', concat(FromUser, '@', From)) COMMENT 'Return address (RFC5321.MailFrom)', + SMTPRcpt ALIAS SMTPRecipients[1] COMMENT 'The first envelope recipient (RFC5321.RcptTo)', + MIMEFrom ALIAS if(MimeFrom = '', '', concat(MimeUser, '@', MimeFrom)) COMMENT 'Address in From: header (RFC5322.From)', + MIMERcpt ALIAS MimeRecipients[1] COMMENT 'The first recipient from headers (RFC5322.To/.CC/.BCC)' +) ENGINE = MergeTree() +PARTITION BY toMonday(Date) +ORDER BY TS +]], + [[CREATE TABLE IF NOT EXISTS rspamd_version ( Version UInt32) ENGINE = TinyLog]], + { [[INSERT INTO rspamd_version (Version) Values (${SCHEMA_VERSION})]], true }, +} + +-- This describes SQL queries to migrate between versions +local migrations = { + [1] = { + -- Move to a wide fat table + [[ALTER TABLE rspamd + ADD COLUMN IF NOT EXISTS `Attachments.FileName` Array(String) AFTER ListId, + ADD COLUMN IF NOT EXISTS `Attachments.ContentType` Array(String) AFTER `Attachments.FileName`, + ADD COLUMN IF NOT EXISTS `Attachments.Length` Array(UInt32) AFTER `Attachments.ContentType`, + ADD COLUMN IF NOT EXISTS `Attachments.Digest` Array(FixedString(16)) AFTER `Attachments.Length`, + ADD COLUMN IF NOT EXISTS `Urls.Tld` Array(String) AFTER `Attachments.Digest`, + ADD COLUMN IF NOT EXISTS `Urls.Url` Array(String) AFTER `Urls.Tld`, + ADD COLUMN IF NOT EXISTS Emails Array(String) AFTER `Urls.Url`, + ADD COLUMN IF NOT EXISTS ASN UInt32 AFTER Emails, + ADD COLUMN IF NOT EXISTS Country FixedString(2) AFTER ASN, + ADD COLUMN IF NOT EXISTS IPNet String AFTER Country, + ADD COLUMN IF NOT EXISTS `Symbols.Names` Array(String) AFTER IPNet, + ADD COLUMN IF NOT EXISTS `Symbols.Scores` Array(Float64) AFTER `Symbols.Names`, + ADD COLUMN IF NOT EXISTS `Symbols.Options` Array(String) AFTER `Symbols.Scores`]], + -- Add explicit version + [[CREATE TABLE rspamd_version ( Version UInt32) ENGINE = TinyLog]], + [[INSERT INTO rspamd_version (Version) Values (2)]], + }, + [2] = { + -- Add `Subject` 
column + [[ALTER TABLE rspamd + ADD COLUMN IF NOT EXISTS Subject String AFTER ListId]], + -- New version + [[INSERT INTO rspamd_version (Version) Values (3)]], + }, + [3] = { + [[ALTER TABLE rspamd + ADD COLUMN IF NOT EXISTS IsSpf Enum8('reject' = 0, 'allow' = 1, 'neutral' = 2, 'dnsfail' = 3, 'na' = 4, 'unknown' = 5) DEFAULT 'unknown' AFTER IsDmarc, + MODIFY COLUMN IsDkim Enum8('reject' = 0, 'allow' = 1, 'unknown' = 2, 'dnsfail' = 3, 'na' = 4) DEFAULT 'unknown', + MODIFY COLUMN IsDmarc Enum8('reject' = 0, 'allow' = 1, 'unknown' = 2, 'softfail' = 3, 'na' = 4, 'quarantine' = 5) DEFAULT 'unknown', + ADD COLUMN IF NOT EXISTS MimeRecipients Array(String) AFTER RcptDomain, + ADD COLUMN IF NOT EXISTS MessageId String AFTER MimeRecipients, + ADD COLUMN IF NOT EXISTS ScanTimeReal UInt32 AFTER `Symbols.Options`, + ADD COLUMN IF NOT EXISTS ScanTimeVirtual UInt32 AFTER ScanTimeReal]], + -- Add aliases + [[ALTER TABLE rspamd + ADD COLUMN IF NOT EXISTS SMTPFrom ALIAS if(From = '', '', concat(FromUser, '@', From)), + ADD COLUMN IF NOT EXISTS SMTPRcpt ALIAS if(RcptDomain = '', '', concat(RcptUser, '@', RcptDomain)), + ADD COLUMN IF NOT EXISTS MIMEFrom ALIAS if(MimeFrom = '', '', concat(MimeUser, '@', MimeFrom)), + ADD COLUMN IF NOT EXISTS MIMERcpt ALIAS MimeRecipients[1] + ]], + -- New version + [[INSERT INTO rspamd_version (Version) Values (4)]], + }, + [4] = { + [[ALTER TABLE rspamd + MODIFY COLUMN Action Enum8('reject' = 0, 'rewrite subject' = 1, 'add header' = 2, 'greylist' = 3, 'no action' = 4, 'soft reject' = 5, 'custom' = 6) DEFAULT 'no action', + ADD COLUMN IF NOT EXISTS CustomAction String AFTER Action + ]], + -- New version + [[INSERT INTO rspamd_version (Version) Values (5)]], + }, + [5] = { + [[ALTER TABLE rspamd + ADD COLUMN IF NOT EXISTS AuthUser String AFTER ScanTimeVirtual, + ADD COLUMN IF NOT EXISTS SettingsId LowCardinality(String) AFTER AuthUser + ]], + -- New version + [[INSERT INTO rspamd_version (Version) Values (6)]], + }, + [6] = { + -- Add new columns + 
[[ALTER TABLE rspamd
      ADD COLUMN IF NOT EXISTS Helo String AFTER IP,
      ADD COLUMN IF NOT EXISTS SMTPRecipients Array(String) AFTER RcptDomain
    ]],
    -- Modify SMTPRcpt alias
    [[
      ALTER TABLE rspamd
      MODIFY COLUMN SMTPRcpt ALIAS SMTPRecipients[1]
    ]],
    -- New version
    [[INSERT INTO rspamd_version (Version) Values (7)]],
  },
  [7] = {
    -- Add new columns
    [[ALTER TABLE rspamd
      ADD COLUMN IF NOT EXISTS `Groups.Names` Array(LowCardinality(String)) AFTER `Symbols.Options`,
      ADD COLUMN IF NOT EXISTS `Groups.Scores` Array(Float32) AFTER `Groups.Names`
    ]],
    -- New version
    [[INSERT INTO rspamd_version (Version) Values (8)]],
  },
  [8] = {
    -- Add new columns
    [[ALTER TABLE rspamd
      ADD COLUMN IF NOT EXISTS `Urls.Flags` Array(UInt32) AFTER `Urls.Url`
    ]],
    -- New version
    [[INSERT INTO rspamd_version (Version) Values (9)]],
  },
}

-- Action names that are stored verbatim in the `Action` column. Any other
-- action is stored by clickhouse_collect as 'custom', with its real name in
-- `CustomAction`.
local predefined_actions = {
  ['reject'] = true,
  ['rewrite subject'] = true,
  ['add header'] = true,
  ['greylist'] = true,
  ['no action'] = true,
  ['soft reject'] = true
}

-- Appends the names of the main (scalar) columns to the array `res`.
-- The order here must match the order of values placed into `row` by
-- clickhouse_collect, since both are used for the same INSERT.
local function clickhouse_main_row(res)
  local fields = {
    'Date',
    'TS',
    'From',
    'MimeFrom',
    'IP',
    'Helo',
    'Score',
    'NRcpt',
    'Size',
    'IsWhitelist',
    'IsBayes',
    'IsFuzzy',
    'IsFann',
    'IsDkim',
    'IsDmarc',
    'NUrls',
    'Action',
    'FromUser',
    'MimeUser',
    'RcptUser',
    'RcptDomain',
    'SMTPRecipients',
    'ListId',
    'Subject',
    'Digest',
    -- 1.9.2 +
    'IsSpf',
    'MimeRecipients',
    'MessageId',
    'ScanTimeReal',
    -- 1.9.3 +
    'CustomAction',
    -- 2.0 +
    'AuthUser',
    'SettingsId',
  }

  for _, v in ipairs(fields) do
    table.insert(res, v)
  end
end

-- Appends the names of the attachment array columns to `res`.
local function clickhouse_attachments_row(res)
  local fields = {
    'Attachments.FileName',
    'Attachments.ContentType',
    'Attachments.Length',
    'Attachments.Digest',
  }

  for _, v in ipairs(fields) do
    table.insert(res, v)
  end
end

-- Appends the names of the URL array columns to `res`.
local function clickhouse_urls_row(res)
  local fields = {
    'Urls.Tld',
    'Urls.Url',
    'Urls.Flags',
  }
  for _, v
in ipairs(fields) do
    table.insert(res, v)
  end
end

-- Appends the name of the e-mails column to `res`.
local function clickhouse_emails_row(res)
  local fields = {
    'Emails',
  }
  for _, v in ipairs(fields) do
    table.insert(res, v)
  end
end

-- Appends the names of the per-symbol array columns to `res`.
local function clickhouse_symbols_row(res)
  local fields = {
    'Symbols.Names',
    'Symbols.Scores',
    'Symbols.Options',
  }
  for _, v in ipairs(fields) do
    table.insert(res, v)
  end
end

-- Appends the names of the per-group array columns to `res`.
local function clickhouse_groups_row(res)
  local fields = {
    'Groups.Names',
    'Groups.Scores',
  }
  for _, v in ipairs(fields) do
    table.insert(res, v)
  end
end

-- Appends the names of the ASN-related columns to `res`.
local function clickhouse_asn_row(res)
  local fields = {
    'ASN',
    'Country',
    'IPNet',
  }
  for _, v in ipairs(fields) do
    table.insert(res, v)
  end
end

-- Appends the names of the user-configured extra columns to `res`.
local function clickhouse_extra_columns(res)
  for _, v in ipairs(settings.extra_columns) do
    table.insert(res, v.name)
  end
end

-- Formats a unix timestamp as a UTC `YYYY-MM-DD` date string.
local function today(ts)
  return os.date('!%Y-%m-%d', ts)
end

-- Looks for the first symbol from `settings[settings_field_name]` that is
-- present in the task and records a verdict in `fields_table[field_name]`.
--
-- When `value_negative` is given, the verdict is `value` if the symbol score
-- is positive and `value_negative` otherwise; without it the verdict is
-- always `value`.
--
-- Returns true if some configured symbol was found (and the field was set),
-- false otherwise (the field is left untouched).
local function clickhouse_check_symbol(task, settings_field_name, fields_table,
                                       field_name, value, value_negative)
  for _, s in ipairs(settings[settings_field_name] or {}) do
    if task:has_symbol(s) then
      if value_negative then
        local sym = task:get_symbol(s)[1]
        if sym['score'] > 0 then
          fields_table[field_name] = value
        else
          fields_table[field_name] = value_negative
        end
      else
        fields_table[field_name] = value
      end

      return true
    end
  end

  return false
end

-- Sends the accumulated rows to a ClickHouse upstream.
--   task      - task to bind the request to, or nil when called from a periodic
--   ev_base   - event base, used when there is no task
--   why       - human readable reason, only used in log messages
--   gen_rows  - array of rows for the main `rspamd` table
--   cust_rows - map: custom rule name -> array of rows for that rule
local function clickhouse_send_data(task, ev_base, why, gen_rows, cust_rows)
  local log_object = task or rspamd_config
  local upstream = settings.upstream:get_upstream_round_robin()
  local ip_addr = upstream:get_addr():to_string(true)

  -- `cust_rows` is keyed by rule name, so `#cust_rows` is always 0; count the
  -- rows of every rule explicitly to log the real total.
  local ncustom = 0
  for _, crows in pairs(cust_rows) do
    ncustom = ncustom + #crows
  end

  rspamd_logger.infox(log_object, "trying to send %s rows to clickhouse server %s; started as %s",
      #gen_rows + ncustom, ip_addr, why)

  local function gen_success_cb(what, how_many)
    return function(_, _)
      rspamd_logger.messagex(log_object, "sent %s rows of %s to clickhouse server %s; started as %s",
          how_many, what, ip_addr,
why)
      upstream:ok()
    end
  end

  -- Builds the failure callback: logs the error and marks the upstream as failed
  local function gen_fail_cb(what, how_many)
    return function(_, err)
      rspamd_logger.errx(log_object, "cannot send %s rows of %s data to clickhouse server %s: %s; started as %s",
          how_many, what, ip_addr, err, why)
      upstream:fail()
    end
  end

  -- Issues a single INSERT (`query`) with the rows from `tbl`; `what` is only
  -- used to describe the data in log messages
  local function send_data(what, tbl, query)
    local ch_params = {}
    if task then
      ch_params.task = task
    else
      -- No task: the request has to be bound to the config and event base
      ch_params.config = rspamd_config
      ch_params.ev_base = ev_base
    end

    local ret = lua_clickhouse.insert(upstream, settings, ch_params,
        query, tbl,
        gen_success_cb(what, #tbl),
        gen_fail_cb(what, #tbl))
    if not ret then
      rspamd_logger.errx(log_object, "cannot send %s rows of %s data to clickhouse server %s: %s",
          #tbl, what, ip_addr, 'cannot make HTTP request')
    end
  end

  -- Column list for the main table; must follow the same order in which
  -- clickhouse_collect fills a row
  local fields = {}
  clickhouse_main_row(fields)
  clickhouse_attachments_row(fields)
  clickhouse_urls_row(fields)
  clickhouse_emails_row(fields)
  clickhouse_asn_row(fields)

  if settings.enable_symbols then
    clickhouse_symbols_row(fields)
    clickhouse_groups_row(fields)
  end

  if #settings.extra_columns > 0 then
    clickhouse_extra_columns(fields)
  end

  send_data('generic data', gen_rows,
      string.format('INSERT INTO rspamd (%s)',
          table.concat(fields, ',')))

  for k, crows in pairs(cust_rows) do
    -- Send whenever there is at least one row: the caller has already reset
    -- its buffers, so a single buffered row would otherwise be lost
    if #crows > 0 then
      send_data('custom data (' .. k ..
')', crows,
          settings.custom_rules[k].first_row())
    end
  end
end

-- Idempotent symbol callback: converts a scanned task into one row for the
-- main table (plus one row per custom rule) and appends it to the in-memory
-- buffers that are flushed later.
local function clickhouse_collect(task)
  if task:has_flag('skip') then
    return
  end

  -- Scans coming from rspamc/controller are only stored when `allow_local` is set
  if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then
    return
  end

  -- Any of the configured stop symbols disables storage for this message
  for _, sym in ipairs(settings.stop_symbols) do
    if task:has_symbol(sym) then
      rspamd_logger.infox(task, 'skip Clickhouse storage for message: symbol %s has fired', sym)
      return
    end
  end

  if settings.exceptions then
    local excepted, trace = settings.exceptions:process(task)
    if excepted then
      rspamd_logger.infox(task, 'skipped Clickhouse storage for message: excepted (%s)',
          trace)
      -- Excepted
      return
    end
  end

  -- Envelope sender (domain is lowercased, user part is kept as is)
  local from_domain = ''
  local from_user = ''
  if task:has_from('smtp') then
    local from = task:get_from({ 'smtp', 'orig' })[1]

    if from then
      from_domain = from['domain']:lower()
      from_user = from['user']
    end
  end

  -- Sender from the MIME headers
  local mime_domain = ''
  local mime_user = ''
  if task:has_from('mime') then
    local from = task:get_from({ 'mime', 'orig' })[1]
    if from then
      mime_domain = from['domain']:lower()
      mime_user = from['user']
    end
  end

  -- Recipients from the MIME headers, as `user@domain` strings
  local mime_recipients = {}
  if task:has_recipients('mime') then
    local recipients = task:get_recipients({ 'mime', 'orig' })
    for _, rcpt in ipairs(recipients) do
      table.insert(mime_recipients, rcpt['user'] .. '@' ..
rcpt['domain']:lower())
    end
  end

  -- Sender IP is stored masked with the configured ipv4/ipv6 mask
  local ip_str = 'undefined'
  local ip = task:get_from_ip()
  if ip and ip:is_valid() then
    local ipnet
    if ip:get_version() == 4 then
      ipnet = ip:apply_mask(settings['ipmask'])
    else
      ipnet = ip:apply_mask(settings['ipmask6'])
    end
    ip_str = ipnet:to_string()
  end

  local helo = task:get_helo() or ''

  local rcpt_user = ''
  local rcpt_domain = ''
  local smtp_recipients = {}
  if task:has_recipients('smtp') then
    local recipients = task:get_recipients('smtp')
    -- for compatibility with an old table structure
    rcpt_user = recipients[1]['user']
    rcpt_domain = recipients[1]['domain']:lower()

    for _, rcpt in ipairs(recipients) do
      table.insert(smtp_recipients, rcpt['user'] .. '@' .. rcpt['domain']:lower())
    end
  end

  local list_id = task:get_header('List-Id') or ''
  local message_id = lua_util.maybe_obfuscate_string(task:get_message_id() or '',
      settings, 'mid')

  local score = task:get_metric_score()[1];

  -- Verdict columns; every one stays 'unknown' unless a configured symbol fired
  local fields = {
    bayes = 'unknown',
    fuzzy = 'unknown',
    ann = 'unknown',
    whitelist = 'unknown',
    dkim = 'unknown',
    dmarc = 'unknown',
    spf = 'unknown',
  }

  -- `ret` tells whether the previous check of the current family matched, so
  -- that the remaining alternatives of that family are skipped
  local ret

  ret = clickhouse_check_symbol(task, 'bayes_spam_symbols', fields,
      'bayes', 'spam')
  if not ret then
    clickhouse_check_symbol(task, 'bayes_ham_symbols', fields,
        'bayes', 'ham')
  end

  -- The result of the ANN spam check must be captured here: otherwise the
  -- ham check below would be driven by the stale Bayes result
  ret = clickhouse_check_symbol(task, 'ann_symbols_spam', fields,
      'ann', 'spam')
  if not ret then
    clickhouse_check_symbol(task, 'ann_symbols_ham', fields,
        'ann', 'ham')
  end

  -- Positive score means blacklisted, non-positive means whitelisted
  clickhouse_check_symbol(task, 'whitelist_symbols', fields,
      'whitelist', 'blacklist', 'whitelist')

  clickhouse_check_symbol(task, 'fuzzy_symbols', fields,
      'fuzzy', 'deny')

  ret = clickhouse_check_symbol(task, 'dkim_allow_symbols', fields,
      'dkim', 'allow')
  if not ret then
    ret = clickhouse_check_symbol(task, 'dkim_reject_symbols', fields,
        'dkim', 'reject')
  end
  if not ret then
    ret = clickhouse_check_symbol(task,
'dkim_dnsfail_symbols', fields,
        'dkim', 'dnsfail')
  end
  if not ret then
    clickhouse_check_symbol(task, 'dkim_na_symbols', fields,
        'dkim', 'na')
  end

  -- DMARC verdict: the first matching family wins
  ret = clickhouse_check_symbol(task, 'dmarc_allow_symbols', fields,
      'dmarc', 'allow')
  if not ret then
    ret = clickhouse_check_symbol(task, 'dmarc_reject_symbols', fields,
        'dmarc', 'reject')
  end
  if not ret then
    ret = clickhouse_check_symbol(task, 'dmarc_quarantine_symbols', fields,
        'dmarc', 'quarantine')
  end
  if not ret then
    ret = clickhouse_check_symbol(task, 'dmarc_softfail_symbols', fields,
        'dmarc', 'softfail')
  end
  if not ret then
    clickhouse_check_symbol(task, 'dmarc_na_symbols', fields,
        'dmarc', 'na')
  end

  -- SPF verdict: the first matching family wins
  ret = clickhouse_check_symbol(task, 'spf_allow_symbols', fields,
      'spf', 'allow')
  if not ret then
    ret = clickhouse_check_symbol(task, 'spf_reject_symbols', fields,
        'spf', 'reject')
  end
  if not ret then
    ret = clickhouse_check_symbol(task, 'spf_neutral_symbols', fields,
        'spf', 'neutral')
  end
  if not ret then
    ret = clickhouse_check_symbol(task, 'spf_dnsfail_symbols', fields,
        'spf', 'dnsfail')
  end
  if not ret then
    clickhouse_check_symbol(task, 'spf_na_symbols', fields,
        'spf', 'na')
  end

  -- Number of envelope recipients
  local nrcpts = 0
  if task:has_recipients('smtp') then
    nrcpts = #task:get_recipients('smtp')
  end

  -- URLs of the message; `task_urls` is reused below to fill the Urls.* columns
  local nurls = 0
  local task_urls = task:get_urls({
    content = true,
    images = true,
    emails = false,
    sort = true,
  }) or {}

  nurls = #task_urls

  local timestamp = math.floor(task:get_date({
    format = 'connect',
    gmt = true, -- The only sane way to sync stuff with different timezones
  }))

  local action = task:get_metric_action()
  local custom_action = ''

  -- Actions outside of the predefined set are stored as 'custom' and their
  -- real name goes to a separate column
  if not predefined_actions[action] then
    custom_action = action
    action = 'custom'
  end

  local digest = ''

  if settings.enable_digest then
    digest = task:get_digest()
  end

  local subject = ''
  if settings.insert_subject then
    subject =
lua_util.maybe_obfuscate_string(task:get_subject() or '', settings, 'subject')
  end

  -- Scan time is multiplied by 1000 and floored before being stored; the log
  -- message below reports it in ms. Negative values are reset to 0.
  local scan_real = task:get_scan_time()
  scan_real = math.floor(scan_real * 1000)
  if scan_real < 0 then
    rspamd_logger.messagex(task,
        'clock skew detected for message: %s ms real scan time (reset to 0)',
        scan_real)
    scan_real = 0
  end

  local auth_user = task:get_user() or ''
  local settings_id = task:get_settings_id()

  if settings_id then
    -- Convert to string
    settings_id = lua_settings.settings_by_id(settings_id)

    if settings_id then
      settings_id = settings_id.name
    end
  end

  -- Unknown or absent settings id is stored as an empty string
  if not settings_id then
    settings_id = ''
  end

  -- Values of the main columns, in the same order as the names produced by
  -- clickhouse_main_row
  local row = {
    today(timestamp),
    timestamp,
    from_domain,
    mime_domain,
    ip_str,
    helo,
    score,
    nrcpts,
    task:get_size(),
    fields.whitelist,
    fields.bayes,
    fields.fuzzy,
    fields.ann,
    fields.dkim,
    fields.dmarc,
    nurls,
    action,
    from_user,
    mime_user,
    rcpt_user,
    rcpt_domain,
    smtp_recipients,
    list_id,
    subject,
    digest,
    fields.spf,
    mime_recipients,
    message_id,
    scan_real,
    custom_action,
    auth_user,
    settings_id
  }

  -- Attachments step
  local attachments_fnames = {}
  local attachments_ctypes = {}
  local attachments_lengths = {}
  local attachments_digests = {}
  for _, part in ipairs(task:get_parts()) do
    if part:is_attachment() then
      table.insert(attachments_fnames, part:get_filename() or '')
      local mime_type, mime_subtype = part:get_type()
      table.insert(attachments_ctypes, string.format("%s/%s", mime_type, mime_subtype))
      table.insert(attachments_lengths, part:get_length())
      -- Only the first 16 characters of the part digest are stored
      table.insert(attachments_digests, string.sub(part:get_digest(), 1, 16))
    end
  end

  -- The four attachment columns are always appended, as empty arrays when
  -- the message has no attachments
  if #attachments_fnames > 0 then
    table.insert(row, attachments_fnames)
    table.insert(row, attachments_ctypes)
    table.insert(row, attachments_lengths)
    table.insert(row, attachments_digests)
  else
    table.insert(row, {})
    table.insert(row, {})
    table.insert(row, {})
    table.insert(row, {})
  end

  -- Urls
step
  local urls_urls = {}
  local urls_tlds = {}
  local urls_flags = {}

  if settings.full_urls then
    -- Store every url with its full text
    for i, u in ipairs(task_urls) do
      urls_urls[i] = u:get_text()
      urls_tlds[i] = u:get_tld() or u:get_host()
      urls_flags[i] = u:get_flags_num()
    end
  else
    -- We need to store unique
    -- The metatable below makes `urls_idx` behave as a set keyed by
    -- host .. flags of the url object used as the index
    local mt = {
      ord_tbl = {}, -- ordered list of urls
      idx_tbl = {}, -- indexed by host + flags, reference to an index in ord_tbl
      __newindex = function(t, k, v)
        local idx = getmetatable(t).idx_tbl
        local ord = getmetatable(t).ord_tbl
        local key = k:get_host() .. tostring(k:get_flags_num())
        if idx[key] then
          ord[idx[key]] = v -- replace
        else
          ord[#ord + 1] = v
          idx[key] = #ord
        end
      end,
      __index = function(t, k)
        local ord = getmetatable(t).ord_tbl
        if type(k) == 'number' then
          return ord[k]
        else
          local idx = getmetatable(t).idx_tbl
          local key = k:get_host() .. tostring(k:get_flags_num())
          if idx[key] then
            return ord[idx[key]]
          end
        end
      end,
    }
    -- Extra index needed for making this unique
    local urls_idx = {}
    setmetatable(urls_idx, mt)
    for _, u in ipairs(task_urls) do
      -- Only the first url for each host/flags pair is stored, and only its host
      if not urls_idx[u] then
        urls_idx[u] = u
        urls_urls[#urls_urls + 1] = u:get_host()
        urls_tlds[#urls_tlds + 1] = u:get_tld() or u:get_host()
        urls_flags[#urls_flags + 1] = u:get_flags_num()
      end
    end
  end


  -- Get tlds
  table.insert(row, urls_tlds)
  -- Get hosts/full urls
  table.insert(row, urls_urls)
  -- Numeric flags
  table.insert(row, urls_flags)

  -- Emails step
  if task:has_urls(true) then
    local emails = task:get_emails() or {}
    local emails_formatted = {}
    for i, u in ipairs(emails) do
      emails_formatted[i] = string.format('%s@%s', u:get_user(), u:get_host())
    end
    table.insert(row, emails_formatted)
  else
    table.insert(row, {})
  end

  -- ASN information
  -- Defaults are used when the corresponding mempool variables are not set
  local asn, country, ipnet = 0, '--', '--'
  local pool = task:get_mempool()
  ret = pool:get_variable("asn")
  if ret then
    asn = ret
  end
  ret = pool:get_variable("country")
  if
ret then
    -- Only the first two characters of the country value are stored
    country = ret:sub(1, 2)
  end
  ret = pool:get_variable("ipnet")
  if ret then
    ipnet = ret
  end
  table.insert(row, asn)
  table.insert(row, country)
  table.insert(row, ipnet)

  -- Symbols info
  if settings.enable_symbols then
    local symbols = task:get_symbols_all()
    local syms_tab = {}
    local scores_tab = {}
    local options_tab = {}

    -- Three parallel arrays: name, score and comma-joined options per symbol
    for _, s in ipairs(symbols) do
      table.insert(syms_tab, s.name or '')
      table.insert(scores_tab, s.score)

      if s.options then
        table.insert(options_tab, table.concat(s.options, ','))
      else
        table.insert(options_tab, '');
      end
    end
    table.insert(row, syms_tab)
    table.insert(row, scores_tab)
    table.insert(row, options_tab)

    -- Groups data
    local groups = task:get_groups()
    local groups_tab = {}
    local gr_scores_tab = {}
    for gr, sc in pairs(groups) do
      table.insert(groups_tab, gr)
      table.insert(gr_scores_tab, sc)
    end
    table.insert(row, groups_tab)
    table.insert(row, gr_scores_tab)
  end

  -- Extra columns
  if #settings.extra_columns > 0 then
    for _, col in ipairs(settings.extra_columns) do
      local elts = col.real_selector(task)

      -- Fall back to the column default when the selector yields nothing
      if elts then
        table.insert(row, elts)
      else
        table.insert(row, col.default_value)
      end
    end
  end

  -- Custom data
  for k, rule in pairs(settings.custom_rules) do
    if not custom_rows[k] then
      custom_rows[k] = {}
    end
    table.insert(custom_rows[k], lua_clickhouse.row_to_tsv(rule.get_row(task)))
  end

  -- Buffer the serialized row and update the counters that are compared
  -- against settings.limits by the periodic sender
  local tsv_row = lua_clickhouse.row_to_tsv(row)
  used_memory = used_memory + #tsv_row
  data_rows[#data_rows + 1] = tsv_row
  nrows = nrows + 1
  lua_util.debugm(N, task,
      "add clickhouse row %s / %s; used memory: %s / %s",
      nrows, settings.limits.max_rows,
      used_memory, settings.limits.max_memory)
end

-- Removes a single partition of `table_name`, using the method chosen by
-- settings.retention.method
local function do_remove_partition(ev_base, cfg, table_name, partition)
  lua_util.debugm(N, rspamd_config, "removing partition %s.%s", table_name, partition)
  local upstream = settings.upstream:get_upstream_round_robin()
  local
remove_partition_sql = "ALTER TABLE ${table_name} ${remove_method} PARTITION '${partition}'"
  -- Anything but an explicit 'drop' method results in DETACH
  local remove_method = (settings.retention.method == 'drop') and 'DROP' or 'DETACH'
  local sql_params = {
    ['table_name'] = table_name,
    ['remove_method'] = remove_method,
    ['partition'] = partition
  }

  local sql = lua_util.template(remove_partition_sql, sql_params)

  local ch_params = {
    body = sql,
    ev_base = ev_base,
    config = cfg,
  }

  -- Report the upstream that was really used: settings['server'] is not set
  -- when the servers are configured via the `servers` option
  local server_addr = upstream:get_addr():to_string(true)

  local err, _ = lua_clickhouse.generic_sync(upstream, settings, ch_params, sql)
  if err then
    -- Log the actual method (DROP or DETACH) instead of always claiming a detach
    rspamd_logger.errx(rspamd_config,
        "cannot %s partition %s:%s on server %s: %s",
        remove_method, table_name, partition,
        server_addr, err)
    return
  end

  rspamd_logger.infox(rspamd_config,
      'removed partition %s:%s on server %s using %s',
      table_name, partition, server_addr, remove_method)

end

--[[
  nil   - file is not writable, do not perform removal
  0     - it's time to perform removal
  <int> - how many seconds wait until next run
]]
local function get_last_removal_ago()
  local ts_file = string.format('%s/%s', rspamd_paths['DBDIR'], 'clickhouse_retention_run')
  local last_ts
  local current_ts = os.time()

  -- Stores `current_ts` in the timestamp file; returns true on success and
  -- nil (after logging) when the file cannot be opened or written
  local function write_ts_to_file()
    local write_file, err = io.open(ts_file, 'w')
    if err then
      rspamd_logger.errx(rspamd_config, 'Failed to open %s, will not perform retention: %s', ts_file, err)
      return nil
    end

    local res
    res, err = write_file:write(tostring(current_ts))
    if err or res == nil then
      write_file:close()
      rspamd_logger.errx(rspamd_config, 'Failed to write %s, will not perform retention: %s', ts_file, err)
      return nil
    end
    write_file:close()

    return true
  end

  -- A missing or unreadable file just leaves `last_ts` as nil
  local f, err = io.open(ts_file, 'r')
  if err then
    lua_util.debugm(N, rspamd_config, 'Failed to open %s: %s', ts_file, err)
  else
    last_ts = tonumber(f:read('*number'))
    f:close()
  end

  -- Never ran before or the retention period has elapsed: run now, provided
  -- that the new timestamp could be stored (nil is returned otherwise)
  if last_ts == nil or (last_ts + settings.retention.period) <= current_ts then
    return write_ts_to_file() and 0
  end

  if last_ts >
current_ts then
    -- Clock skew detected, overwrite last_ts with current_ts and wait for the next
    -- retention period
    rspamd_logger.errx(rspamd_config, 'Last collection time is in future: %s; overwrite it with %s in %s',
        last_ts, current_ts, ts_file)
    -- NOTE(review): -1 is returned here rather than a number of seconds, and
    -- the caller passes any non-zero value on as the next timeout; confirm
    -- that a negative timeout is what is intended
    return write_ts_to_file() and -1
  end

  -- Seconds left until the retention period elapses
  return (last_ts + settings.retention.period) - current_ts
end

-- Periodic callback: flushes the buffered rows to ClickHouse when any of the
-- configured limits (rows, time since the last collection, memory) is
-- exceeded. Returns the timeout for the next run.
local function clickhouse_maybe_send_data_periodic(cfg, ev_base, now)
  local need_collect = false
  local reason

  if nrows == 0 then
    lua_util.debugm(N, cfg, "no need to send data, as there are no rows to collect")
    return settings.check_timeout
  end

  if final_call then
    lua_util.debugm(N, cfg, "no need to send data, final call has been issued")
    return 0
  end

  -- Each limit is only active when it is configured as a positive number
  if settings.limits.max_rows > 0 then
    if nrows > settings.limits.max_rows then
      need_collect = true
      reason = string.format('limit of rows has been reached: %d', nrows)
    end
  end

  if last_collection > 0 and settings.limits.max_interval > 0 then
    if now - last_collection > settings.limits.max_interval then
      need_collect = true
      reason = string.format('limit of time since last collection has been reached: %d seconds passed ' ..
'(%d seconds trigger)',
          (now - last_collection), settings.limits.max_interval)
    end
  end

  if settings.limits.max_memory > 0 then
    if used_memory >= settings.limits.max_memory then
      need_collect = true
      reason = string.format('limit of memory has been reached: %d bytes used',
          used_memory)
    end
  end

  -- First run with data: start counting the interval from now
  if last_collection == 0 then
    last_collection = now
  end

  if need_collect then
    -- Do it atomic
    -- The buffers are swapped for fresh ones before sending, so rows collected
    -- afterwards go to the new tables
    local saved_rows = data_rows
    local saved_custom = custom_rows
    nrows = 0
    last_collection = now
    used_memory = 0
    data_rows = {}
    custom_rows = {}

    clickhouse_send_data(nil, ev_base, reason, saved_rows, saved_custom)

    if settings.collect_garbage then
      collectgarbage()
    end
  end

  return settings.check_timeout
end

-- Periodic callback for retention: removes partitions of the `rspamd` table
-- whose data is older than settings.retention.period_months months.
-- Returns false when the last run time cannot be obtained, otherwise the
-- value to be used as the next timeout.
local function clickhouse_remove_old_partitions(cfg, ev_base)
  local last_time_ago = get_last_removal_ago()
  if last_time_ago == nil then
    rspamd_logger.errx(rspamd_config, "Failed to get last run time. Disabling retention")
    return false
  elseif last_time_ago ~= 0 then
    -- Not yet time to run: pass the value from get_last_removal_ago on
    return last_time_ago
  end

  local upstream = settings.upstream:get_upstream_round_robin()
  local partition_to_remove_sql = "SELECT partition, table " ..
      "FROM system.parts WHERE table IN ('${tables}') " ..
      "GROUP BY partition, table " ..
"HAVING max(max_date) < toDate(now() - interval ${month} month)"

  local table_names = { 'rspamd' }
  -- Joined with "', '" because the template already supplies the outer quotes
  local tables = table.concat(table_names, "', '")
  local sql_params = {
    tables = tables,
    month = settings.retention.period_months,
  }
  local sql = lua_util.template(partition_to_remove_sql, sql_params)

  local ch_params = {
    ev_base = ev_base,
    config = cfg,
  }
  local err, rows = lua_clickhouse.select_sync(upstream, settings, ch_params, sql)
  if err then
    -- NOTE(review): settings['server'] is logged here, while the module also
    -- accepts a `servers` option; other functions log upstream:get_addr()
    rspamd_logger.errx(rspamd_config,
        "cannot send data to clickhouse server %s: %s",
        settings['server'], err)
  else
    fun.each(function(row)
      do_remove_partition(ev_base, cfg, row.table, row.partition)
    end, rows)
  end

  -- settings.retention.period is added on initialisation, see below
  return settings.retention.period
end

-- Uploads the statements from `clickhouse_schema` followed by
-- settings.schema_additions to `upstream`, one by one. `initial` selects
-- which of the {statement, boolean} schema elements are applied.
local function upload_clickhouse_schema(upstream, ev_base, cfg, initial)
  local ch_params = {
    ev_base = ev_base,
    config = cfg,
  }

  -- Set after the first failure; all the following elements are then skipped
  local errored = false

  -- Upload a single element of the schema
  local function upload_schema_elt(v)
    if errored then
      rspamd_logger.errx(rspamd_config, "cannot upload schema '%s' on clickhouse server %s: due to previous errors",
          v, upstream:get_addr():to_string(true))
      return
    end
    local sql = v
    local err, reply = lua_clickhouse.generic_sync(upstream, settings, ch_params, sql)

    if err then
      rspamd_logger.errx(rspamd_config, "cannot upload schema '%s' on clickhouse server %s: %s",
          sql, upstream:get_addr():to_string(true), err)
      errored = true
      return
    end
    rspamd_logger.debugm(N, rspamd_config, 'uploaded clickhouse schema element %s to %s: %s',
        v, upstream:get_addr():to_string(true), reply)
  end

  -- Process element and return nil if statement should be skipped
  local function preprocess_schema_elt(v)
    if type(v) == 'string' then
      return lua_util.template(v, { SCHEMA_VERSION = tostring(schema_version) })
    elseif type(v) == 'table' then
      -- Pair of statement + boolean
      if initial == v[2] then
return lua_util.template(v[1], { SCHEMA_VERSION = tostring(schema_version) })
      else
        rspamd_logger.debugm(N, rspamd_config, 'skip clickhouse schema element %s: schema already exists',
            v)
      end
    end

    return nil
  end

  -- Apply schema elements sequentially, users additions are concatenated to the tail
  fun.each(upload_schema_elt,
  -- Also template schema version
      fun.filter(function(v)
        return v ~= nil
      end,
          fun.map(preprocess_schema_elt,
              fun.chain(clickhouse_schema, settings.schema_additions)
          )
      )
  )
end

-- Applies the statements of `migrations[version]` up to
-- `migrations[schema_version - 1]` to `upstream`. Requests are chained
-- through their success callbacks, so the next statement is only sent after
-- the previous one succeeded; the chain stops on the first error.
local function maybe_apply_migrations(upstream, ev_base, cfg, version)
  local ch_params = {
    ev_base = ev_base,
    config = cfg,
  }
  -- Apply migrations sequentially
  local function migration_recursor(i)
    if i < schema_version then
      if migrations[i] then
        -- We also need to apply statements sequentially
        local function sql_recursor(j)
          if migrations[i][j] then
            local sql = migrations[i][j]
            local ret = lua_clickhouse.generic(upstream, settings, ch_params, sql,
                function(_, _)
                  rspamd_logger.infox(rspamd_config,
                      'applied migration to version %s from version %s: %s',
                      i + 1, version, sql:gsub('[\n%s]+', ' '))
                  if j == #migrations[i] then
                    -- Go to the next migration
                    migration_recursor(i + 1)
                  else
                    -- Apply the next statement
                    sql_recursor(j + 1)
                  end
                end,
                function(_, err)
                  rspamd_logger.errx(rspamd_config,
                      "cannot apply migration %s: '%s' on clickhouse server %s: %s",
                      i, sql, upstream:get_addr():to_string(true), err)
                end)
            if not ret then
              rspamd_logger.errx(rspamd_config,
                  "cannot apply migration %s: '%s' on clickhouse server %s: cannot make request",
                  i, sql, upstream:get_addr():to_string(true))
            end
          end
        end

        sql_recursor(1)
      else
        -- Try another migration
        migration_recursor(i + 1)
      end
    end
  end

  migration_recursor(version)
end

-- Adds the user-configured extra columns to the `rspamd` table, one ALTER
-- statement per column
local function add_extra_columns(upstream, ev_base, cfg)
  local ch_params = {
    ev_base = ev_base,
    config = cfg,
  }
  -- Apply migrations
sequentially
  local function columns_recursor(i)
    if i <= #settings.extra_columns then
      local col = settings.extra_columns[i]
      -- Every column is placed after the previous extra column; the first one
      -- goes after `MIMERcpt`
      local prev_column
      if i == 1 then
        prev_column = 'MIMERcpt'
      else
        prev_column = settings.extra_columns[i - 1].name
      end
      local sql = string.format('ALTER TABLE rspamd ADD COLUMN IF NOT EXISTS `%s` %s AFTER `%s`',
          col.name, col.type, prev_column)
      if col.comment then
        sql = sql .. string.format(", COMMENT COLUMN IF EXISTS `%s` '%s'", col.name, col.comment)
      end

      -- The next column is only requested from the success callback
      local ret = lua_clickhouse.generic(upstream, settings, ch_params, sql,
          function(_, _)
            rspamd_logger.infox(rspamd_config,
                'added extra column %s (%s) after %s',
                col.name, col.type, prev_column)
            -- Apply the next statement
            columns_recursor(i + 1)
          end,
          function(_, err)
            rspamd_logger.errx(rspamd_config,
                "cannot apply add column alter %s: '%s' on clickhouse server %s: %s",
                i, sql, upstream:get_addr():to_string(true), err)
          end)
      if not ret then
        rspamd_logger.errx(rspamd_config,
            "cannot apply add column alter %s: '%s' on clickhouse server %s: cannot make request",
            i, sql, upstream:get_addr():to_string(true))
      end
    end
  end

  columns_recursor(1)
end

-- Called when the `rspamd_version` table is missing: checks whether the
-- `rspamd` table exists. If it does, migrations are applied starting from
-- version 1; otherwise the full schema is uploaded.
local function check_rspamd_table(upstream, ev_base, cfg)
  local ch_params = {
    ev_base = ev_base,
    config = cfg,
  }
  local sql = [[EXISTS TABLE rspamd]]
  local err, rows = lua_clickhouse.select_sync(upstream, settings, ch_params, sql)
  if err then
    rspamd_logger.errx(rspamd_config, "cannot check rspamd table in clickhouse server %s: %s",
        upstream:get_addr():to_string(true), err)
    return
  end

  if rows[1] and rows[1].result then
    if tonumber(rows[1].result) == 1 then
      -- Apply migration
      upload_clickhouse_schema(upstream, ev_base, cfg, false)
      rspamd_logger.infox(rspamd_config, 'table rspamd exists, check if we need to apply migrations')
      maybe_apply_migrations(upstream, ev_base, cfg, 1)
    else
      -- Upload schema
      rspamd_logger.infox(rspamd_config, 'table rspamd does 
not exists, upload full schema')
      upload_clickhouse_schema(upstream, ev_base, cfg, true)
    end
  else
    rspamd_logger.errx(rspamd_config,
        "unexpected reply on EXISTS command from server %s: %s",
        upstream:get_addr():to_string(true), rows)
  end
end

-- Prepares a single upstream: uploads the schemas of the custom rules, then
-- checks the stored schema version and uploads the schema or applies the
-- migrations as needed, and finally adds the configured extra columns.
local function check_clickhouse_upstream(upstream, ev_base, cfg)
  local ch_params = {
    ev_base = ev_base,
    config = cfg,
  }
  -- If we have some custom rules, we just send its schema to the upstream
  for k, rule in pairs(settings.custom_rules) do
    if rule.schema then
      local sql = lua_util.template(rule.schema, settings)
      local err, _ = lua_clickhouse.generic_sync(upstream, settings, ch_params, sql)
      if err then
        rspamd_logger.errx(rspamd_config, 'cannot send custom schema %s to clickhouse server %s: ' ..
            'cannot make request (%s)',
            k, upstream:get_addr():to_string(true), err)
      end
    end
  end

  -- Now check the main schema and apply migrations if needed
  local sql = [[SELECT MAX(Version) as v FROM rspamd_version]]
  local err, rows = lua_clickhouse.select_sync(upstream, settings, ch_params, sql)
  if err then
    -- A reply with code 404 is treated as "the version table does not exist"
    if rows and rows.code == 404 then
      rspamd_logger.infox(rspamd_config,
          'table rspamd_version does not exist, check rspamd table')
      check_rspamd_table(upstream, ev_base, cfg)
    else
      rspamd_logger.errx(rspamd_config,
          "cannot get version on clickhouse server %s: %s",
          upstream:get_addr():to_string(true), err)
    end
  else
    upload_clickhouse_schema(upstream, ev_base, cfg, false)
    -- NOTE(review): assumes the reply always has a first row with a numeric
    -- `v`; a nil version would fail in maybe_apply_migrations - confirm
    local version = tonumber(rows[1].v)
    maybe_apply_migrations(upstream, ev_base, cfg, version)
  end

  if #settings.extra_columns > 0 then
    add_extra_columns(upstream, ev_base, cfg)
  end
end

local opts = rspamd_config:get_all_opt('clickhouse')
if opts then
  -- Legacy `limit` options
  if opts.limit and not opts.limits then
    settings.limits.max_rows = opts.limit
  end
  for k, v in pairs(opts) do
    if k == 'custom_rules' then
      if not v[1] then
        v = { v }
      end

      for i, rule in ipairs(v) do
        if
rule.schema and rule.first_row and rule.get_row then + local first_row, get_row + local loadstring = loadstring or load + local ret, res_or_err = pcall(loadstring(rule.first_row)) + + if not ret or type(res_or_err) ~= 'function' then + rspamd_logger.errx(rspamd_config, 'invalid first_row (%s) - must be a function', + res_or_err) + else + first_row = res_or_err + end + + ret, res_or_err = pcall(loadstring(rule.get_row)) + + if not ret or type(res_or_err) ~= 'function' then + rspamd_logger.errx(rspamd_config, + 'invalid get_row (%s) - must be a function', + res_or_err) + else + get_row = res_or_err + end + + if first_row and get_row then + local name = rule.name or tostring(i) + settings.custom_rules[name] = { + schema = rule.schema, + first_row = first_row, + get_row = get_row, + } + end + else + rspamd_logger.errx(rspamd_config, 'custom rule has no required attributes: schema, first_row and get_row') + end + end + else + settings[k] = lua_util.deepcopy(v) + end + end + + if not settings['server'] and not settings['servers'] then + rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module') + lua_util.disable_module(N, "config") + else + local lua_maps = require "lua_maps" + settings['from_map'] = lua_maps.map_add('clickhouse', 'from_tables', + 'regexp', 'clickhouse specific domains') + + settings.upstream = upstream_list.create(rspamd_config, + settings['server'] or settings['servers'], 8123) + + if not settings.upstream then + rspamd_logger.errx(rspamd_config, 'cannot parse clickhouse address: %s', + settings['server'] or settings['servers']) + lua_util.disable_module(N, "config") + return + end + + if settings.exceptions then + local maps_expressions = require "lua_maps_expressions" + + settings.exceptions = maps_expressions.create(rspamd_config, + settings.exceptions, N) + end + + if settings.extra_columns then + -- Check sanity and create selector closures + local lua_selectors = require "lua_selectors" + local columns_transformed = {} + 
local need_sort = false + -- Select traverse function depending on what we have + local iter_func = settings.extra_columns[1] and ipairs or pairs + + for col_name, col_data in iter_func(settings.extra_columns) do + -- Array based extra columns + if col_data.name then + col_name = col_data.name + end + if not col_data.selector or not col_data.type then + rspamd_logger.errx(rspamd_config, 'cannot add clickhouse extra row %s: no type or no selector', + col_name) + else + local is_array = false + + if col_data.type:lower():match('^array') then + is_array = true + end + + local selector = lua_selectors.create_selector_closure(rspamd_config, + col_data.selector, col_data.delimiter or '', is_array) + + if not selector then + rspamd_logger.errx(rspamd_config, 'cannot add clickhouse extra row %s: bad selector: %s', + col_name, col_data.selector) + else + if not col_data.default_value then + if is_array then + col_data.default_value = {} + else + col_data.default_value = '' + end + end + col_data.real_selector = selector + if not col_data.name then + col_data.name = col_name + need_sort = true + end + table.insert(columns_transformed, col_data) + end + end + end + + -- Convert extra columns from a map to an array sorted by column name to + -- preserve strict order when doing altering + if need_sort then + rspamd_logger.infox(rspamd_config, 'sort extra columns as they are not configured as an array') + table.sort(columns_transformed, function(c1, c2) + return c1.name < c2.name + end) + end + settings.extra_columns = columns_transformed + end + + rspamd_config:register_symbol({ + name = 'CLICKHOUSE_COLLECT', + type = 'idempotent', + callback = clickhouse_collect, + flags = 'empty,explicit_disable,ignore_passthrough', + augmentations = { string.format("timeout=%f", settings.timeout) }, + }) + rspamd_config:register_finish_script(function(task) + if nrows > 0 then + final_call = true + local saved_rows = data_rows + local saved_custom = custom_rows + + nrows = 0 + data_rows = {} 
+        used_memory = 0
+        custom_rows = {}
+
+        clickhouse_send_data(task, nil, 'final collection',
+            saved_rows, saved_custom)
+
+        if settings.collect_garbage then
+          collectgarbage()
+        end
+      end
+    end)
+    -- Create tables on load
+    rspamd_config:add_on_load(function(cfg, ev_base, worker)
+      -- Scanners flush collected rows periodically
+      if worker:is_scanner() then
+        rspamd_config:add_periodic(ev_base, 0,
+            clickhouse_maybe_send_data_periodic, true)
+      end
+      -- Schema checks and retention are driven by the primary controller only
+      if worker:is_primary_controller() then
+        local upstreams = settings.upstream:all_upstreams()
+
+        for _, up in ipairs(upstreams) do
+          check_clickhouse_upstream(up, ev_base, cfg)
+        end
+
+        if settings.retention.enable and settings.retention.method ~= 'drop' and
+            settings.retention.method ~= 'detach' then
+          rspamd_logger.errx(rspamd_config,
+              "retention.method should be either 'drop' or 'detach' (now: %s). Disabling retention",
+              settings.retention.method)
+          settings.retention.enable = false
+        end
+        -- `and` binds tighter than `or`: the range test must be parenthesised,
+        -- otherwise the error is reported even when retention is disabled
+        if settings.retention.enable and (settings.retention.period_months < 1 or
+            settings.retention.period_months > 1000) then
+          rspamd_logger.errx(rspamd_config,
+              "please, set retention.period_months between 1 and 1000 (now: %s). Disabling retention",
+              settings.retention.period_months)
+          settings.retention.enable = false
+        end
+        local period = lua_util.parse_time_interval(settings.retention.run_every)
+        if settings.retention.enable and period == nil then
+          rspamd_logger.errx(rspamd_config, "invalid value for retention.run_every (%s). Disabling retention",
+              settings.retention.run_every)
+          settings.retention.enable = false
+        end
+
+        if settings.retention.enable then
+          settings.retention.period = period
+          rspamd_logger.infox(rspamd_config,
+              "retention will be performed each %s seconds for %s month with method %s",
+              period, settings.retention.period_months, settings.retention.method)
+          rspamd_config:add_periodic(ev_base, 0, clickhouse_remove_old_partitions, false)
+        end
+      end
+    end)
+  end
+end
diff --git a/src/plugins/lua/clustering.lua b/src/plugins/lua/clustering.lua
new file mode 100644
index 0000000..d97bdb9
--- /dev/null
+++ b/src/plugins/lua/clustering.lua
@@ -0,0 +1,322 @@
+--[[
+Copyright (c) 2018, Vsevolod Stakhov
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+if confighelp then
+  return
+end
+
+-- Plugin for finding patterns in email flows
+
+local N = 'clustering'
+
+local rspamd_logger = require "rspamd_logger"
+local lua_util = require "lua_util"
+local lua_verdict = require "lua_verdict"
+local lua_redis = require "lua_redis"
+local lua_selectors = require "lua_selectors"
+local ts = require("tableshape").types
+
+local redis_params
+
+local rules = {} -- Rules placement
+
+local default_rule = {
+  max_elts = 100, -- Maximum elements in a cluster
+  expire = 3600, -- Expire for a bucket when limit is not reached
+  expire_overflow = 36000, -- Expire for a bucket when limit is reached
+  spam_mult = 1.0, -- Increase on spam hit
+  junk_mult = 0.5, -- Increase on junk
+  ham_mult = -0.1, -- Increase on ham
+  size_mult = 0.01, -- Reaches 1.0 on `max_elts`
+  score_mult = 0.1,
+}
+
+-- Validation/normalisation schema applied to every configured rule
+local rule_schema = ts.shape {
+  max_elts = ts.number + ts.string / tonumber,
+  expire = ts.number + ts.string / lua_util.parse_time_interval,
+  expire_overflow = ts.number + ts.string / lua_util.parse_time_interval,
+  spam_mult = ts.number,
+  junk_mult = ts.number,
+  ham_mult = ts.number,
+  size_mult = ts.number,
+  score_mult = ts.number,
+  source_selector = ts.string,
+  cluster_selector = ts.string,
+  symbol = ts.string:is_optional(),
+  prefix = ts.string:is_optional(),
+  -- Read by the idempotent callback; it has to be accepted here, otherwise a
+  -- rule that sets it is rejected and the option can never be enabled
+  allow_local = ts.boolean:is_optional(),
+}
+
+-- Redis scripts
+
+-- Queries for a cluster's data
+-- Arguments:
+-- 1. Source selector (string)
+-- 2. Cluster selector (string)
+-- Returns: {cur_elts, total_score, element_score}
+local query_cluster_script = [[
+local sz = redis.call('HLEN', KEYS[1])
+
+if not sz or not tonumber(sz) then
+  -- New bucket, will update on idempotent phase
+  return {0, '0', '0'}
+end
+
+local total_score = redis.call('HGET', KEYS[1], '__s')
+total_score = tonumber(total_score) or 0
+local score = redis.call('HGET', KEYS[1], KEYS[2])
+if not score or not tonumber(score) then
+  return {sz, tostring(total_score), '0'}
+end
+return {sz, tostring(total_score), tostring(score)}
+]]
+local query_cluster_id
+
+-- Updates cluster's data
+-- Arguments:
+-- 1. Source selector (string)
+-- 2. Cluster selector (string)
+-- 3. Score (number)
+-- 4. Max buckets (number)
+-- 5. Expire (number)
+-- 6. Expire overflow (number)
+-- Returns: nothing
+-- NOTE(review): the trailing EXPIRE with KEYS[5] always overrides the
+-- EXPIRE with KEYS[6] issued in the `else` branch, so the overflow expiry
+-- looks ineffective - confirm the intended behaviour before changing it.
+local update_cluster_script = [[
+local sz = redis.call('HLEN', KEYS[1])
+
+if not sz or not tonumber(sz) then
+  -- Create bucket
+  redis.call('HSET', KEYS[1], KEYS[2], math.abs(KEYS[3]))
+  redis.call('HSET', KEYS[1], '__s', KEYS[3])
+  redis.call('EXPIRE', KEYS[1], KEYS[5])
+
+  return
+end
+
+sz = tonumber(sz)
+local lim = tonumber(KEYS[4])
+
+if sz > lim then
+  -- Bucket is full: do not add new elements, only bump the ones already stored
+  local k = redis.call('HEXISTS', KEYS[1], KEYS[2])
+
+  if k == 1 then
+    -- Existing key
+    redis.call('HINCRBYFLOAT', KEYS[1], KEYS[2], math.abs(KEYS[3]))
+  end
+else
+  redis.call('HINCRBYFLOAT', KEYS[1], KEYS[2], math.abs(KEYS[3]))
+  redis.call('EXPIRE', KEYS[1], KEYS[6])
+end
+
+redis.call('HINCRBYFLOAT', KEYS[1], '__s', KEYS[3])
+redis.call('EXPIRE', KEYS[1], KEYS[5])
+]]
+local update_cluster_id
+
+-- Callbacks and logic
+
+-- Filter stage: queries the stored cluster data for the task's selectors and
+-- inserts `rule.symbol` when the combined score exceeds 0.1
+local function clusterting_filter_cb(task, rule)
+  local source_selector = rule.source_selector(task)
+  local cluster_selector
+
+  if source_selector then
+    cluster_selector = rule.cluster_selector(task)
+  end
+
+  if not cluster_selector or not source_selector then
+    rspamd_logger.debugm(N, task, 'skip rule %s, selectors: source="%s", cluster="%s"',
+        rule.name, source_selector, cluster_selector)
+    return
+  end
+
+  -- Combines the bucket size and the accumulated score into a result capped at 1.0
+  local function combine_scores(cur_elts, total_score, element_score)
+    local final_score
+
+    local size_score = cur_elts * rule.size_mult
+    local cluster_score = total_score * rule.score_mult
+
+    if element_score > 0 then
+      -- We have seen this element mostly in junk/spam
+      final_score = math.min(1.0, size_score + cluster_score)
+    else
+      -- We have seen this element in ham mostly, so subtract average it from the size score
+      final_score = math.min(1.0, size_score - cluster_score / cur_elts)
+    end
+    rspamd_logger.debugm(N, task,
+        'processed rule %s, selectors: source="%s", cluster="%s"; data: %s elts, %s score, %s elt score',
+        rule.name, source_selector, cluster_selector, cur_elts, total_score, element_score)
+    if final_score > 0.1 then
+      task:insert_result(rule.symbol, final_score, { source_selector,
+                                                     tostring(size_score),
+                                                     tostring(cluster_score) })
+    end
+  end
+
+  local function redis_get_cb(err, data)
+    if data then
+      if type(data) == 'table' then
+        combine_scores(tonumber(data[1]), tonumber(data[2]), tonumber(data[3]))
+      else
+        rspamd_logger.errx(task, 'invalid type while getting clustering keys %s: %s',
+            source_selector, type(data))
+      end
+
+    elseif err then
+      rspamd_logger.errx(task, 'got error while getting clustering keys %s: %s',
+          source_selector, err)
+    else
+      rspamd_logger.errx(task, 'got error while getting clustering keys %s: %s',
+          source_selector, "unknown error")
+    end
+  end
+
+  lua_redis.exec_redis_script(query_cluster_id,
+      { task = task, is_write = false, key = source_selector },
+      redis_get_cb,
+      { source_selector, cluster_selector })
+end
+
+-- Idempotent stage: maps the final verdict to a per-rule score increment and
+-- stores it in Redis for the task's source/cluster selectors
+local function clusterting_idempotent_cb(task, rule)
+  if task:has_flag('skip') then
+    return
+  end
+  if not rule.allow_local and lua_util.is_rspamc_or_controller(task) then
+    return
+  end
+
+  local verdict = lua_verdict.get_specific_verdict(N, task)
+  local score
+
+  if verdict == 'ham' then
+    score = rule.ham_mult
+  elseif verdict == 'spam' then
+    score = rule.spam_mult
+  elseif verdict == 'junk' then
+    score = rule.junk_mult
+  else
+    rspamd_logger.debugm(N, task, 'skip rule %s, verdict=%s',
+        rule.name, verdict)
+    return
+  end
+
+  local source_selector = rule.source_selector(task)
+  local cluster_selector
+
+  if source_selector then
+    cluster_selector = rule.cluster_selector(task)
+  end
+
+  if not cluster_selector or not source_selector then
+    rspamd_logger.debugm(N, task, 'skip rule %s, selectors: source="%s", cluster="%s"',
+        rule.name, source_selector, cluster_selector)
+    return
+  end
+
+  local function redis_set_cb(err, data)
+    if err then
+      rspamd_logger.errx(task, 'got error while setting clustering keys %s: %s',
+          source_selector, err)
+    else
+      -- Four placeholders: rule name, source, cluster and the stored increment
+      rspamd_logger.debugm(N, task, 'set clustering key for %s: %s{%s} = %s',
+          rule.name, source_selector, cluster_selector, score)
+    end
+  end
+
+  lua_redis.exec_redis_script(update_cluster_id,
+      { task = task, is_write = true, key = source_selector },
+      redis_set_cb,
+      {
+        source_selector,
+        cluster_selector,
+        tostring(score),
+        tostring(rule.max_elts),
+        tostring(rule.expire),
+        tostring(rule.expire_overflow)
+      }
+  )
+end
+-- Init part
+redis_params = lua_redis.parse_redis_server('clustering')
+local opts = rspamd_config:get_all_opt("clustering")
+
+-- Initialization part
+if not (opts and type(opts) == 'table') then
+  lua_util.disable_module(N, "config")
+  return
+end
+
+if not redis_params then
+  lua_util.disable_module(N, "redis")
+  return
+end
+
+if opts['rules'] then
+  for k, v in pairs(opts['rules']) do
+    local raw_rule = lua_util.override_defaults(default_rule, v)
+
+    local rule, err = rule_schema:transform(raw_rule)
+
+    if not rule then
+      rspamd_logger.errx(rspamd_config, 'invalid clustering rule %s: %s',
+          k, err)
+    else
+
+      if not rule.symbol then
+        rule.symbol = k
+      end
+      if not rule.prefix then
+        rule.prefix = k .. "_"
+      end
+
+      -- Selector strings are replaced by their compiled closures
+      rule.source_selector = lua_selectors.create_selector_closure(rspamd_config,
+          rule.source_selector, '')
+      rule.cluster_selector = lua_selectors.create_selector_closure(rspamd_config,
+          rule.cluster_selector, '')
+      if rule.source_selector and rule.cluster_selector then
+        rule.name = k
+        table.insert(rules, rule)
+      end
+    end
+  end
+
+  if #rules > 0 then
+
+    query_cluster_id = lua_redis.add_redis_script(query_cluster_script, redis_params)
+    update_cluster_id = lua_redis.add_redis_script(update_cluster_script, redis_params)
+    -- Binds a rule to a generic (task, rule) callback
+    local function callback_gen(f, rule)
+      return function(task)
+        return f(task, rule)
+      end
+    end
+
+    for _, rule in ipairs(rules) do
+      rspamd_config:register_symbol {
+        name = rule.symbol,
+        type = 'normal',
+        callback = callback_gen(clusterting_filter_cb, rule),
+      }
+      rspamd_config:register_symbol {
+        name = rule.symbol .. '_STORE',
+        type = 'idempotent',
+        flags = 'empty,explicit_disable,ignore_passthrough',
+        callback = callback_gen(clusterting_idempotent_cb, rule),
+        augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) }
+      }
+    end
+  else
+    lua_util.disable_module(N, "config")
+  end
+else
+  lua_util.disable_module(N, "config")
+end
diff --git a/src/plugins/lua/dcc.lua b/src/plugins/lua/dcc.lua
new file mode 100644
index 0000000..8508320
--- /dev/null
+++ b/src/plugins/lua/dcc.lua
@@ -0,0 +1,119 @@
+--[[
+Copyright (c) 2016, Steve Freegard <steve.freegard@fsl.com>
+Copyright (c) 2016, Vsevolod Stakhov
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+-- Check messages for 'bulkiness' using DCC
+
+local N = 'dcc'
+local symbol_bulk = "DCC_BULK"
+local symbol = "DCC_REJECT"
+local opts = rspamd_config:get_all_opt(N)
+local lua_util = require "lua_util"
+local rspamd_logger = require "rspamd_logger"
+local dcc = require("lua_scanners").filter('dcc').dcc
+
+if confighelp then
+  rspamd_config:add_example(nil, 'dcc',
+      "Check messages for 'bulkiness' using DCC",
+      [[
+  dcc {
+    socket = "/var/dcc/dccifd"; # Unix socket
+    servers = "127.0.0.1:10045" # OR TCP upstreams
+    timeout = 2s; # Timeout to wait for checks
+    body_max = 999999; # Bulkness threshold for body
+    fuz1_max = 999999; # Bulkness threshold for fuz1
+    fuz2_max = 999999; # Bulkness threshold for fuz2
+  }
+  ]])
+  return
+end
+
+local rule
+
+-- Symbol callback: passes the whole message content to the DCC scanner
+local function check_dcc (task)
+  dcc.check(task, task:get_content(), nil, rule)
+end
+
+-- Configuration
+
+-- No `dcc` section at all: disable the module instead of indexing a nil value
+if type(opts) ~= 'table' then
+  lua_util.disable_module(N, "config")
+  rspamd_logger.infox(rspamd_config, 'DCC module not configured')
+  return
+end
+
+-- WORKAROUND for deprecated host and port settings
+if opts['host'] ~= nil and opts['port'] ~= nil then
+  opts['servers'] = opts['host'] .. ':' .. opts['port']
+  rspamd_logger.warnx(rspamd_config, 'Using host and port parameters is deprecated. ' ..
+      'Please use servers = "%s:%s"; instead', opts['host'], opts['port'])
+end
+if opts['host'] ~= nil and not opts['port'] then
+  opts['socket'] = opts['host']
+  rspamd_logger.warnx(rspamd_config, 'Using host parameters is deprecated. ' ..
+      'Please use socket = "%s"; instead', opts['host'])
+end
+-- WORKAROUND for deprecated host and port settings
+
+if not opts.symbol_bulk then
+  opts.symbol_bulk = symbol_bulk
+end
+if not opts.symbol then
+  opts.symbol = symbol
+end
+
+rule = dcc.configure(opts)
+
+if rule then
+  -- One real callback symbol; result symbols are virtual children of it
+  local id = rspamd_config:register_symbol({
+    name = 'DCC_CHECK',
+    callback = check_dcc,
+    type = 'callback',
+  })
+  rspamd_config:register_symbol {
+    type = 'virtual',
+    parent = id,
+    name = opts.symbol
+  }
+  rspamd_config:register_symbol {
+    type = 'virtual',
+    parent = id,
+    name = opts.symbol_bulk
+  }
+  rspamd_config:register_symbol {
+    type = 'virtual',
+    parent = id,
+    name = 'DCC_FAIL'
+  }
+  rspamd_config:set_metric_symbol({
+    group = N,
+    score = 1.0,
+    description = 'Detected as bulk mail by DCC',
+    one_shot = true,
+    name = opts.symbol_bulk,
+  })
+  rspamd_config:set_metric_symbol({
+    group = N,
+    score = 2.0,
+    description = 'Rejected by DCC',
+    one_shot = true,
+    name = opts.symbol,
+  })
+  rspamd_config:set_metric_symbol({
+    group = N,
+    score = 0.0,
+    description = 'DCC failure',
+    one_shot = true,
+    name = 'DCC_FAIL',
+  })
+else
+  lua_util.disable_module(N, "config")
+  rspamd_logger.infox(rspamd_config, 'DCC module not configured');
+end
diff --git a/src/plugins/lua/dkim_signing.lua b/src/plugins/lua/dkim_signing.lua
new file mode 100644
index 0000000..6c05520
--- /dev/null
+++ b/src/plugins/lua/dkim_signing.lua
@@ -0,0 +1,186 @@
+--[[
+Copyright (c) 2016, Andrew Lewis <nerf@judo.za.org>
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local lua_util = require "lua_util"
+local rspamd_logger = require "rspamd_logger"
+local dkim_sign_tools = require "lua_dkim_tools"
+local lua_redis = require "lua_redis"
+local lua_mime = require "lua_mime"
+
+if confighelp then
+  return
+end
+
+local settings = {
+  allow_envfrom_empty = true,
+  allow_hdrfrom_mismatch = false,
+  allow_hdrfrom_mismatch_local = false,
+  allow_hdrfrom_mismatch_sign_networks = false,
+  allow_hdrfrom_multiple = false,
+  allow_username_mismatch = false,
+  allow_pubkey_mismatch = true,
+  sign_authenticated = true,
+  allowed_ids = nil,
+  forbidden_ids = nil,
+  check_pubkey = false,
+  domain = {},
+  path = string.format('%s/%s/%s', rspamd_paths['DBDIR'], 'dkim', '$domain.$selector.key'),
+  sign_local = true,
+  selector = 'dkim',
+  symbol = 'DKIM_SIGNED',
+  try_fallback = true,
+  use_domain = 'header',
+  use_esld = true,
+  use_redis = false,
+  key_prefix = 'dkim_keys', -- default hash name
+  use_milter_headers = false, -- use milter headers instead of `dkim_signature`
+}
+
+local N = 'dkim_signing'
+local redis_params
+local sign_func = rspamd_plugins.dkim.sign
+
+-- Records the outcome of one signing attempt: optionally adds the signature
+-- as a milter header and inserts the result symbol on success
+local function insert_sign_results(task, ret, hdr, dkim_params)
+  -- Only add the header when signing actually produced one; a failed attempt
+  -- would otherwise request a DKIM-Signature header with a nil value
+  if settings.use_milter_headers and ret and hdr then
+    lua_mime.modify_headers(task, {
+      add = {
+        ['DKIM-Signature'] = { order = 1, value = hdr },
+      }
+    })
+  end
+  if ret then
+    task:insert_result(settings.symbol, 1.0, string.format('%s:s=%s',
+        dkim_params.domain, dkim_params.selector))
+  end
+end
+
+-- Signs the task with parameters `p`; when `check_pubkey` is set, the
+-- published key is resolved from DNS first and attached to `p`
+local function do_sign(task, p)
+  if settings.use_milter_headers then
+    p.no_cache = true -- Disable caching in rspamd_mempool
+  end
+  if settings.check_pubkey then
+    local resolve_name = p.selector .. "._domainkey." .. p.domain
+    task:get_resolver():resolve_txt({
+      task = task,
+      name = resolve_name,
+      callback = function(_, _, results, err)
+        if not err and results and results[1] then
+          p.pubkey = results[1]
+          p.strict_pubkey_check = not settings.allow_pubkey_mismatch
+        elseif not settings.allow_pubkey_mismatch then
+          rspamd_logger.infox(task, 'public key for domain %s/%s is not found: %s, skip signing',
+              p.domain, p.selector, err)
+          return
+        else
+          rspamd_logger.infox(task, 'public key for domain %s/%s is not found: %s',
+              p.domain, p.selector, err)
+        end
+
+        local sret, hdr = sign_func(task, p)
+        insert_sign_results(task, sret, hdr, p)
+      end,
+      forced = true
+    })
+  else
+    local sret, hdr = sign_func(task, p)
+    insert_sign_results(task, sret, hdr, p)
+  end
+end
+
+local function sign_error(task, msg)
+  rspamd_logger.errx(task, 'signing failure: %s', msg)
+end
+
+-- Symbol callback: selects the key source (redis, vault or configured
+-- selectors) and signs with every applicable selector
+local function dkim_signing_cb(task)
+  local ret, selectors = dkim_sign_tools.prepare_dkim_signing(N, task, settings)
+
+  if not ret then
+    return
+  end
+
+  if settings.use_redis then
+    dkim_sign_tools.sign_using_redis(N, task, settings, selectors, do_sign, sign_error)
+  else
+    if selectors.vault then
+      dkim_sign_tools.sign_using_vault(N, task, settings, selectors, do_sign, sign_error)
+    else
+      if #selectors > 0 then
+        for _, k in ipairs(selectors) do
+          -- templates
+          if k.key then
+            k.key = lua_util.template(k.key, {
+              domain = k.domain,
+              selector = k.selector
+            })
+            lua_util.debugm(N, task, 'using key "%s", use selector "%s" for domain "%s"',
+                k.key, k.selector, k.domain)
+          end
+
+          do_sign(task, k)
+        end
+      else
+        rspamd_logger.infox(task, 'key path or dkim selector unconfigured; no signing')
+        return false
+      end
+    end
+  end
+end
+
+local opts = rspamd_config:get_all_opt('dkim_signing')
+if not opts then
+  return
+end
+
+dkim_sign_tools.process_signing_settings(N, settings, opts)
+
+if not dkim_sign_tools.validate_signing_settings(settings) then
+  rspamd_logger.infox(rspamd_config, 'mandatory parameters missing, disable dkim signing')
+  lua_util.disable_module(N, "config")
+  return
+end
+
+if settings.use_redis then
+  redis_params = lua_redis.parse_redis_server('dkim_signing')
+
+  if not redis_params then
+    rspamd_logger.errx(rspamd_config,
+        'no servers are specified, but module is configured to load keys from redis, disable dkim signing')
+    lua_util.disable_module(N, "redis")
+    return
+  end
+
+  settings.redis_params = redis_params
+end
+
+local sym_reg_tbl = {
+  name = settings['symbol'],
+  callback = dkim_signing_cb,
+  groups = { "policies", "dkim" },
+  flags = 'ignore_passthrough',
+  score = 0.0,
+}
+
+if type(settings.allowed_ids) == 'table' then
+  sym_reg_tbl.allowed_ids = settings.allowed_ids
+end
+if type(settings.forbidden_ids) == 'table' then
+  sym_reg_tbl.forbidden_ids = settings.forbidden_ids
+end
+
+rspamd_config:register_symbol(sym_reg_tbl)
+-- Add dependency on DKIM checks
+rspamd_config:register_dependency(settings['symbol'], 'DKIM_CHECK')
diff --git a/src/plugins/lua/dmarc.lua b/src/plugins/lua/dmarc.lua
new file mode 100644
index 0000000..792672b
--- /dev/null
+++ b/src/plugins/lua/dmarc.lua
@@ -0,0 +1,685 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+Copyright (c) 2015-2016, Andrew Lewis <nerf@judo.za.org>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]-- + +-- Dmarc policy filter + +local rspamd_logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local lua_redis = require "lua_redis" +local lua_util = require "lua_util" +local dmarc_common = require "plugins/dmarc" + +if confighelp then + return +end + +local N = 'dmarc' + +local settings = dmarc_common.default_settings + +local redis_params = nil + +local E = {} + +-- Keys: +-- 1 = index key (string) +-- 2 = report key (string) +-- 3 = max report elements (number) +-- 4 = expiry time for elements (number) +-- Arguments +-- 1 = dmarc domain +-- 2 = dmarc report +local take_report_id +local take_report_script = [[ +local index_key = KEYS[1] +local report_key = KEYS[2] +local max_entries = -(tonumber(KEYS[3]) + 1) +local keys_expiry = tonumber(KEYS[4]) +local dmarc_domain = ARGV[1] +local report = ARGV[2] +redis.call('SADD', index_key, report_key) +redis.call('EXPIRE', index_key, 172800) +redis.call('ZINCRBY', report_key, 1, report) +redis.call('ZREMRANGEBYRANK', report_key, 0, max_entries) +redis.call('EXPIRE', report_key, 172800) +]] + +local function maybe_force_action(task, disposition) + if disposition then + local force_action = settings.actions[disposition] + if force_action then + -- Set least action + task:set_pre_result(force_action, 'Action set by DMARC', N, nil, nil, 'least') + end + end +end + +local function dmarc_validate_policy(task, policy, hdrfromdom, dmarc_esld) + local reason = {} + + -- Check dkim and spf symbols + local spf_ok = false + local dkim_ok = false + local spf_tmpfail = false + local dkim_tmpfail = false + + local spf_domain = ((task:get_from(1) or E)[1] or E).domain + + if not spf_domain or spf_domain == '' then + spf_domain = task:get_helo() or '' + end + + if task:has_symbol(settings.symbols['spf_allow_symbol']) then + if policy.strict_spf then + if rspamd_util.strequal_caseless(spf_domain, hdrfromdom) then + spf_ok = true + else + table.insert(reason, "SPF not aligned (strict)") + end + else + local 
spf_tld = rspamd_util.get_tld(spf_domain) + if rspamd_util.strequal_caseless(spf_tld, dmarc_esld) then + spf_ok = true + else + table.insert(reason, "SPF not aligned (relaxed)") + end + end + else + if task:has_symbol(settings.symbols['spf_tempfail_symbol']) then + if policy.strict_spf then + if rspamd_util.strequal_caseless(spf_domain, hdrfromdom) then + spf_tmpfail = true + end + else + local spf_tld = rspamd_util.get_tld(spf_domain) + if rspamd_util.strequal_caseless(spf_tld, dmarc_esld) then + spf_tmpfail = true + end + end + end + + table.insert(reason, "No valid SPF") + end + + local opts = ((task:get_symbol('DKIM_TRACE') or E)[1] or E).options + local dkim_results = { + pass = {}, + temperror = {}, + permerror = {}, + fail = {}, + } + + if opts then + dkim_results.pass = {} + local dkim_violated + + for _, opt in ipairs(opts) do + local check_res = string.sub(opt, -1) + local domain = string.sub(opt, 1, -3):lower() + + if check_res == '+' then + table.insert(dkim_results.pass, domain) + + if policy.strict_dkim then + if rspamd_util.strequal_caseless(hdrfromdom, domain) then + dkim_ok = true + else + dkim_violated = "DKIM not aligned (strict)" + end + else + local dkim_tld = rspamd_util.get_tld(domain) + + if rspamd_util.strequal_caseless(dkim_tld, dmarc_esld) then + dkim_ok = true + else + dkim_violated = "DKIM not aligned (relaxed)" + end + end + elseif check_res == '?' 
then + -- Check for dkim tempfail + if not dkim_ok then + if policy.strict_dkim then + if rspamd_util.strequal_caseless(hdrfromdom, domain) then + dkim_tmpfail = true + end + else + local dkim_tld = rspamd_util.get_tld(domain) + + if rspamd_util.strequal_caseless(dkim_tld, dmarc_esld) then + dkim_tmpfail = true + end + end + end + table.insert(dkim_results.temperror, domain) + elseif check_res == '-' then + table.insert(dkim_results.fail, domain) + else + table.insert(dkim_results.permerror, domain) + end + end + + if not dkim_ok and dkim_violated then + table.insert(reason, dkim_violated) + end + else + table.insert(reason, "No valid DKIM") + end + + lua_util.debugm(N, task, + "validated dmarc policy for %s: %s; dkim_ok=%s, dkim_tempfail=%s, spf_ok=%s, spf_tempfail=%s", + policy.domain, policy.dmarc_policy, + dkim_ok, dkim_tmpfail, + spf_ok, spf_tmpfail) + + local disposition = 'none' + local sampled_out = false + + local function handle_dmarc_failure(what, reason_str) + if not policy.pct or policy.pct == 100 then + task:insert_result(settings.symbols[what], 1.0, + policy.domain .. ' : ' .. reason_str, policy.dmarc_policy) + disposition = what + else + local coin = math.random(100) + if (coin > policy.pct) then + if (not settings.no_sampling_domains or + not settings.no_sampling_domains:get_key(policy.domain)) then + + if what == 'reject' then + disposition = 'quarantine' + else + disposition = 'softfail' + end + + task:insert_result(settings.symbols[disposition], 1.0, + policy.domain .. ' : ' .. reason_str, policy.dmarc_policy, "sampled_out") + sampled_out = true + lua_util.debugm(N, task, + 'changed dmarc policy from %s to %s, sampled out: %s < %s', + what, disposition, coin, policy.pct) + else + task:insert_result(settings.symbols[what], 1.0, + policy.domain .. ' : ' .. reason_str, policy.dmarc_policy, "local_policy") + disposition = what + end + else + task:insert_result(settings.symbols[what], 1.0, + policy.domain .. ' : ' .. 
reason_str, policy.dmarc_policy) + disposition = what + end + end + + maybe_force_action(task, disposition) + end + + if spf_ok or dkim_ok then + --[[ + https://tools.ietf.org/html/rfc7489#section-6.6.2 + DMARC evaluation can only yield a "pass" result after one of the + underlying authentication mechanisms passes for an aligned + identifier. + ]]-- + task:insert_result(settings.symbols['allow'], 1.0, policy.domain, + policy.dmarc_policy) + else + --[[ + https://tools.ietf.org/html/rfc7489#section-6.6.2 + + If neither passes and one or both of them fail due to a + temporary error, the Receiver evaluating the message is unable to + conclude that the DMARC mechanism had a permanent failure; they + therefore cannot apply the advertised DMARC policy. + ]]-- + if spf_tmpfail or dkim_tmpfail then + task:insert_result(settings.symbols['dnsfail'], 1.0, policy.domain .. + ' : ' .. 'SPF/DKIM temp error', policy.dmarc_policy) + else + -- We can now check the failed policy and maybe send report data elt + local reason_str = table.concat(reason, ', ') + + if policy.dmarc_policy == 'quarantine' then + handle_dmarc_failure('quarantine', reason_str) + elseif policy.dmarc_policy == 'reject' then + handle_dmarc_failure('reject', reason_str) + else + task:insert_result(settings.symbols['softfail'], 1.0, + policy.domain .. ' : ' .. 
reason_str, + policy.dmarc_policy) + end + end + end + + if policy.rua and redis_params and settings.reporting.enabled then + if settings.reporting.exclude_domains then + if settings.reporting.exclude_domains:get_key(policy.domain) or + settings.reporting.exclude_domains:get_key(rspamd_util.get_tld(policy.domain)) then + rspamd_logger.info(task, 'DMARC reporting suppressed for sender domain %s', policy.domain) + return + end + end + if settings.reporting.exclude_recipients then + local rcpt = task:get_principal_recipient() + if rcpt and settings.reporting.exclude_recipients:get_key(rcpt) then + rspamd_logger.info(task, 'DMARC reporting suppressed for recipient %s', rcpt) + return + end + end + + local function dmarc_report_cb(err) + if not err then + rspamd_logger.infox(task, 'dmarc report saved for %s (rua = %s)', + hdrfromdom, policy.rua) + else + rspamd_logger.errx(task, 'dmarc report is not saved for %s: %s', + hdrfromdom, err) + end + end + + local spf_result + if spf_ok then + spf_result = 'pass' + elseif spf_tmpfail then + spf_result = 'temperror' + else + if task:has_symbol(settings.symbols.spf_deny_symbol) then + spf_result = 'fail' + elseif task:has_symbol(settings.symbols.spf_softfail_symbol) then + spf_result = 'softfail' + elseif task:has_symbol(settings.symbols.spf_neutral_symbol) then + spf_result = 'neutral' + elseif task:has_symbol(settings.symbols.spf_permfail_symbol) then + spf_result = 'permerror' + else + spf_result = 'none' + end + end + + -- Prepare and send redis report element + local period = os.date('%Y%m%d', + task:get_date({ format = 'connect', gmt = false })) + + -- Dmarc domain key must include dmarc domain, rua and period + local dmarc_domain_key = table.concat( + { settings.reporting.redis_keys.report_prefix, policy.domain, policy.rua, period }, + settings.reporting.redis_keys.join_char) + local report_data = dmarc_common.dmarc_report(task, settings, { + spf_ok = spf_ok and 'pass' or 'fail', + dkim_ok = dkim_ok and 'pass' or 'fail', 
+ disposition = (disposition == "softfail") and "none" or disposition, + sampled_out = sampled_out, + domain = hdrfromdom, + spf_domain = spf_domain, + dkim_results = dkim_results, + spf_result = spf_result + }) + + local idx_key = table.concat({ settings.reporting.redis_keys.index_prefix, period }, + settings.reporting.redis_keys.join_char) + + if report_data then + lua_redis.exec_redis_script(take_report_id, + { task = task, is_write = true }, + dmarc_report_cb, + { idx_key, dmarc_domain_key, + tostring(settings.reporting.max_entries), tostring(settings.reporting.keys_expire) }, + { hdrfromdom, report_data }) + end + end +end + +local function dmarc_callback(task) + local from = task:get_from(2) + local hfromdom = ((from or E)[1] or E).domain + local dmarc_domain + local ip_addr = task:get_ip() + local dmarc_checks = task:get_mempool():get_variable('dmarc_checks', 'double') or 0 + local seen_invalid = false + + if dmarc_checks ~= 2 then + rspamd_logger.infox(task, "skip DMARC checks as either SPF or DKIM were not checked") + return + end + + if lua_util.is_skip_local_or_authed(task, settings.auth_and_local_conf, ip_addr) then + rspamd_logger.infox(task, "skip DMARC checks for local networks and authorized users") + return + end + + -- Do some initial sanity checks, detect tld domain if different + if hfromdom and hfromdom ~= '' and not (from or E)[2] then + -- Lowercase domain as per #3940 + hfromdom = hfromdom:lower() + dmarc_domain = rspamd_util.get_tld(hfromdom) + elseif (from or E)[2] then + task:insert_result(settings.symbols['na'], 1.0, 'Duplicate From header') + return maybe_force_action(task, 'na') + elseif (from or E)[1] then + task:insert_result(settings.symbols['na'], 1.0, 'No domain in From header') + return maybe_force_action(task, 'na') + else + task:insert_result(settings.symbols['na'], 1.0, 'No From header') + return maybe_force_action(task, 'na') + end + + local dns_checks_inflight = 0 + local dmarc_domain_policy = {} + local dmarc_tld_policy = 
{} + + local function process_dmarc_policy(policy, final) + lua_util.debugm(N, task, "validate DMARC policy (final=%s): %s", + true, policy) + if policy.err and policy.symbol then + -- In case of fatal errors or final check for tld, we give up and + -- insert result + if final or policy.fatal then + task:insert_result(policy.symbol, 1.0, policy.err) + maybe_force_action(task, policy.disposition) + + return true + end + elseif policy.dmarc_policy then + dmarc_validate_policy(task, policy, hfromdom, dmarc_domain) + + return true -- We have a more specific version, use it + end + + return false -- Missing record + end + + local function gen_dmarc_cb(lookup_domain, is_tld) + local policy_target = dmarc_domain_policy + if is_tld then + policy_target = dmarc_tld_policy + end + + return function(_, _, results, err) + dns_checks_inflight = dns_checks_inflight - 1 + + if not seen_invalid then + policy_target.domain = lookup_domain + + if err then + if (err ~= 'requested record is not found' and + err ~= 'no records with this name') then + policy_target.err = lookup_domain .. ' : ' .. err + policy_target.symbol = settings.symbols['dnsfail'] + else + policy_target.err = lookup_domain + policy_target.symbol = settings.symbols['na'] + end + else + local has_valid_policy = false + + for _, rec in ipairs(results) do + local ret, results_or_err = dmarc_common.dmarc_check_record(task, rec, is_tld) + + if not ret then + if results_or_err then + -- We have a fatal parsing error, give up + policy_target.err = lookup_domain .. ' : ' .. results_or_err + policy_target.symbol = settings.symbols['badpolicy'] + policy_target.fatal = true + seen_invalid = true + end + else + if has_valid_policy then + policy_target.err = lookup_domain .. ' : ' .. 
+ 'Multiple policies defined in DNS' + policy_target.symbol = settings.symbols['badpolicy'] + policy_target.fatal = true + seen_invalid = true + end + has_valid_policy = true + + for k, v in pairs(results_or_err) do + policy_target[k] = v + end + end + end + + if not has_valid_policy and not seen_invalid then + policy_target.err = lookup_domain .. ':' .. ' no valid DMARC record' + policy_target.symbol = settings.symbols['na'] + end + end + end + + if dns_checks_inflight == 0 then + lua_util.debugm(N, task, "finished DNS queries, validate policies") + -- We have checked both tld and real domain (if different) + if not process_dmarc_policy(dmarc_domain_policy, false) then + -- Try tld policy as well + if not process_dmarc_policy(dmarc_tld_policy, true) then + process_dmarc_policy(dmarc_domain_policy, true) + end + end + end + end + end + + local resolve_name = '_dmarc.' .. hfromdom + + task:get_resolver():resolve_txt({ + task = task, + name = resolve_name, + callback = gen_dmarc_cb(hfromdom, false), + forced = true + }) + dns_checks_inflight = dns_checks_inflight + 1 + + if dmarc_domain ~= hfromdom then + resolve_name = '_dmarc.' .. dmarc_domain + + task:get_resolver():resolve_txt({ + task = task, + name = resolve_name, + callback = gen_dmarc_cb(dmarc_domain, true), + forced = true + }) + + dns_checks_inflight = dns_checks_inflight + 1 + end +end + +local opts = rspamd_config:get_all_opt('dmarc') +settings = lua_util.override_defaults(settings, opts) + +settings.auth_and_local_conf = lua_util.config_check_local_or_authed(rspamd_config, N, + false, false) + +-- Legacy... 
-- Map the old `no_reporting_domains` option onto `reporting.exclude_domains`.
-- `reporting` may be a plain boolean in old-style configs, so only index it
-- when it is really a table.
if type(settings.reporting) == 'table' and not settings.reporting.exclude_domains
    and settings.no_reporting_domains then
  settings.reporting.exclude_domains = settings.no_reporting_domains
end

local lua_maps = require "lua_maps"
lua_maps.fill_config_maps(N, settings, {
  no_sampling_domains = {
    optional = true,
    type = 'map',
    description = 'Domains not to apply DMARC sampling to'
  },
})

if type(settings.reporting) == 'table' then
  lua_maps.fill_config_maps(N, settings.reporting, {
    exclude_domains = {
      optional = true,
      type = 'map',
      description = 'Domains not to store DMARC reports about'
    },
    exclude_recipients = {
      optional = true,
      type = 'map',
      description = 'Recipients not to store DMARC reports for'
    },
  })
end

if settings.reporting == true then
  rspamd_logger.errx(rspamd_config, 'old style dmarc reporting is NO LONGER supported, please read the documentation')
elseif type(settings.reporting) == 'table' and settings.reporting.enabled then
  redis_params = lua_redis.parse_redis_server('dmarc', opts)
  if not redis_params then
    rspamd_logger.errx(rspamd_config, 'cannot parse servers parameter')
  else
    rspamd_logger.infox(rspamd_config, 'dmarc reporting is enabled')
    take_report_id = lua_redis.add_redis_script(take_report_script, redis_params)
  end
end

--- Copies a symbol name customised in another module's options.
-- @param var key in `settings.symbols` to overwrite
-- @param m_opts options table of the other module (spf/dkim)
-- @param name option name holding the custom symbol
local function check_mopt(var, m_opts, name)
  if m_opts[name] then
    settings.symbols[var] = tostring(m_opts[name])
  end
end

-- Check spf and dkim sections for changed symbols
local spf_opts = rspamd_config:get_all_opt('spf')
if spf_opts then
  check_mopt('spf_deny_symbol', spf_opts, 'symbol_fail')
  check_mopt('spf_allow_symbol', spf_opts, 'symbol_allow')
  check_mopt('spf_softfail_symbol', spf_opts, 'symbol_softfail')
  check_mopt('spf_neutral_symbol', spf_opts, 'symbol_neutral')
  check_mopt('spf_tempfail_symbol', spf_opts, 'symbol_dnsfail')
  check_mopt('spf_na_symbol', spf_opts, 'symbol_na')
end

local dkim_opts = rspamd_config:get_all_opt('dkim')
if dkim_opts then
  check_mopt('dkim_deny_symbol', dkim_opts, 'symbol_reject')
  check_mopt('dkim_allow_symbol', dkim_opts, 'symbol_allow')
  check_mopt('dkim_tempfail_symbol', dkim_opts, 'symbol_tempfail')
  check_mopt('dkim_na_symbol', dkim_opts, 'symbol_na')
end

local id = rspamd_config:register_symbol({
  name = 'DMARC_CHECK',
  type = 'callback',
  callback = dmarc_callback
})
rspamd_config:register_symbol({
  name = 'DMARC_CALLBACK', -- compatibility symbol
  type = 'virtual,skip',
  parent = id,
})

-- All result symbols are virtual children of DMARC_CHECK and share the same
-- attributes; they are registered in the original order.
for _, kind in ipairs({ 'allow', 'reject', 'quarantine', 'softfail',
                        'dnsfail', 'badpolicy', 'na' }) do
  rspamd_config:register_symbol({
    name = settings.symbols[kind],
    parent = id,
    group = 'policies',
    groups = { 'dmarc' },
    type = 'virtual'
  })
end

rspamd_config:register_dependency('DMARC_CHECK', settings.symbols['spf_allow_symbol'])
rspamd_config:register_dependency('DMARC_CHECK', settings.symbols['dkim_allow_symbol'])

-- DMARC munging support
if settings.munging then
  local lua_maps_expressions = require "lua_maps_expressions"

  local munging_defaults = {
    reply_goes_to_list = false,
    mitigate_allow_only = true, -- perform munging based on DMARC_POLICY_ALLOW only
    mitigate_strict_only = false, -- perform munging merely for reject/quarantine policies
    munge_from = true, -- replace from with something like <orig name> via <rcpt user>
    list_map = nil, -- map of maillist domains
    munge_map_condition = nil, -- maps expression to enable munging
  }

  local munging_opts = lua_util.override_defaults(munging_defaults, settings.munging)

  if not munging_opts.list_map then
    rspamd_logger.errx(rspamd_config, 'cannot enable DMARC munging with no list_map parameter')

    return
  end

  munging_opts.list_map = lua_maps.map_add_from_ucl(munging_opts.list_map,
      'set', 'DMARC munging map of the recipients addresses to munge')

  if not munging_opts.list_map then
    rspamd_logger.errx(rspamd_config, 'cannot enable DMARC munging with invalid list_map (invalid map)')

    return
  end

  if munging_opts.munge_map_condition then
    munging_opts.munge_map_condition = lua_maps_expressions.create(rspamd_config,
        munging_opts.munge_map_condition, N)
  end

  rspamd_config:register_symbol({
    name = 'DMARC_MUNGED',
    type = 'normal',
    flags = 'nostat',
    score = 0,
    group = 'policies',
    groups = { 'dmarc' },
    callback = dmarc_common.gen_munging_callback(munging_opts, settings),
    augmentations = { lua_util.dns_timeout_augmentation(rspamd_config) },
  })

  rspamd_config:register_dependency('DMARC_MUNGED', 'DMARC_CHECK')
  -- To avoid dkim signing issues
  rspamd_config:register_dependency('DKIM_SIGNED', 'DMARC_MUNGED')
  rspamd_config:register_dependency('ARC_SIGNED', 'DMARC_MUNGED')

  rspamd_logger.infox(rspamd_config, 'enabled DMARC munging')
end

-- diff --git a/src/plugins/lua/dynamic_conf.lua b/src/plugins/lua/dynamic_conf.lua
-- new file mode 100644
-- index 0000000..5af26a9
-- --- /dev/null
-- +++ b/src/plugins/lua/dynamic_conf.lua
-- @@ -0,0 +1,333 @@
--[[
Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

local rspamd_logger = require "rspamd_logger"
local redis_params
local ucl = require "ucl"
local fun = require "fun"
local lua_util = require "lua_util"
local rspamd_redis = require "lua_redis"
local N = "dynamic_conf"

if confighelp then
  return
end

local settings = {
  redis_key = "dynamic_conf",
  redis_watch_interval = 10.0,
  priority = 10
}

-- Local state: last seen version, the merged data (set lazily) and the
-- names of symbols/actions changed locally but not yet saved to Redis.
local cur_settings = {
  version = 0,
  updates = {
    symbols = {},
    actions = {},
    has_updates = false
  }
}

--- Compares two scores with a 0.001 tolerance.
-- Non-numeric operands (e.g. a table-valued entry) never compare equal;
-- previously they raised an arithmetic error.
-- @return true when both values are numbers closer than 0.001
local function alpha_cmp(v1, v2)
  if type(v1) ~= 'number' or type(v2) ~= 'number' then
    return false
  end

  return math.abs(v1 - v2) < 0.001
end

--- Applies action thresholds received from Redis to the running config.
-- Entries equal to the current threshold or pending a local update are skipped.
-- @param acts map of action name -> score, or -> table of action attributes
local function apply_dynamic_actions(_, acts)
  fun.each(function(k, v)
    if type(v) == 'table' then
      v['name'] = k
      -- set_metric_action identifies the action by the `action` key
      v['action'] = k
      if not v['priority'] then
        v['priority'] = settings.priority
      end
      rspamd_config:set_metric_action(v)
    else
      -- Scalar thresholds are actions too (set_metric_symbol was called
      -- here before, which would register them as symbol scores)
      rspamd_config:set_metric_action({
        action = k,
        score = v,
        priority = settings.priority
      })
    end
  end, fun.filter(function(k, v)
    local act = rspamd_config:get_metric_action(k)
    if (act and alpha_cmp(act, v)) or cur_settings.updates.actions[k] then
      return false
    end

    return true
  end, acts))
end

--- Applies symbol scores received from Redis to the running config.
-- @param sc map of symbol name -> score, or -> table of symbol attributes
local function apply_dynamic_scores(_, sc)
  fun.each(function(k, v)
    if type(v) == 'table' then
      v['name'] = k
      if not v['priority'] then
        v['priority'] = settings.priority
      end
      rspamd_config:set_metric_symbol(v)
    else
      rspamd_config:set_metric_symbol({
        name = k,
        score = v,
        priority = settings.priority
      })
    end
  end, fun.filter(function(k, v)
    -- Select elts with scores that are different from local ones
    local sym = rspamd_config:get_symbol(k)
    if (sym and alpha_cmp(sym.score, v)) or cur_settings.updates.symbols[k] then
      return false
    end

    return true
  end, sc))
end

--- Applies a whole dynamic configuration object.
-- @param cfg rspamd config object
-- @param data table with optional `scores`, `actions`, `symbols_enabled`
--   and `symbols_disabled` members
local function apply_dynamic_conf(cfg, data)
  if data['scores'] then
    -- Apply scores changes
    apply_dynamic_scores(cfg, data['scores'])
  end

  if data['actions'] then
    apply_dynamic_actions(cfg, data['actions'])
  end

  if data['symbols_enabled'] then
    fun.each(function(_, v)
      cfg:enable_symbol(v)
    end, data['symbols_enabled'])
  end

  if data['symbols_disabled'] then
    fun.each(function(_, v)
      cfg:disable_symbol(v)
    end, data['symbols_disabled'])
  end
end

--- Saves local updates to Redis: writes the data to hash field `d`, then
-- increments field `v`, and clears the pending updates on success.
-- @param cfg rspamd config object
-- @param ev_base event base for the Redis requests
-- @param recv optional data just loaded from Redis to merge with local data
local function update_dynamic_conf(cfg, ev_base, recv)
  local function redis_version_set_cb(err, data)
    if err then
      rspamd_logger.errx(cfg, "cannot save dynamic conf version to redis: %s", err)
    else
      rspamd_logger.infox(cfg, "saved dynamic conf version: %s", data)
      cur_settings.updates.has_updates = false
      cur_settings.updates.symbols = {}
      cur_settings.updates.actions = {}
    end
  end
  local function redis_data_set_cb(err)
    if err then
      rspamd_logger.errx(cfg, "cannot save dynamic conf to redis: %s", err)
    else
      rspamd_redis.redis_make_request_taskless(ev_base,
          cfg,
          redis_params,
          settings.redis_key,
          true,
          redis_version_set_cb,
          'HINCRBY', { settings.redis_key, 'v', '1' })
    end
  end

  if recv then
    -- We need to merge two configs; locally updated entries take precedence
    if recv['scores'] then
      if not cur_settings.data.scores then
        cur_settings.data.scores = {}
      end
      fun.each(function(k, v)
        cur_settings.data.scores[k] = v
      end,
          fun.filter(function(k)
            if cur_settings.updates.symbols[k] then
              return false
            end
            return true
          end, recv['scores']))
    end
    if recv['actions'] then
      if not cur_settings.data.actions then
        cur_settings.data.actions = {}
      end
      fun.each(function(k, v)
        cur_settings.data.actions[k] = v
      end,
          fun.filter(function(k)
            if cur_settings.updates.actions[k] then
              return false
            end
            return true
          end, recv['actions']))
    end
  end
  local newdata = ucl.to_format(cur_settings.data, 'json-compact')
  rspamd_redis.redis_make_request_taskless(ev_base, cfg, redis_params,
      settings.redis_key, true,
      redis_data_set_cb, 'HSET', { settings.redis_key, 'd', newdata })
end

--- Periodic check: reads the version field `v` from Redis, loads the data
-- when the remote version is newer and saves pending local updates.
-- @param cfg rspamd config object
-- @param ev_base event base for the Redis requests
local function check_dynamic_conf(cfg, ev_base)
  local function redis_load_cb(redis_err, data)
    if redis_err then
      rspamd_logger.errx(cfg, "cannot read dynamic conf from redis: %s", redis_err)
    elseif data and type(data) == 'string' then
      local parser = ucl.parser()
      local _, err = parser:parse_string(data)

      if err then
        rspamd_logger.errx(cfg, "cannot load dynamic conf from redis: %s", err)
      else
        local d = parser:get_object()
        apply_dynamic_conf(cfg, d)
        if cur_settings.updates.has_updates then
          -- Need to send our updates to Redis
          update_dynamic_conf(cfg, ev_base, d)
        else
          cur_settings.data = d
        end
      end
    end
  end
  local function redis_check_cb(err, data)
    if not err and type(data) == 'string' then
      local rver = tonumber(data)

      if not cur_settings.version or (rver and rver > cur_settings.version) then
        rspamd_logger.infox(cfg, "need to load fresh dynamic settings with version %s, local version is %s",
            rver, cur_settings.version)
        cur_settings.version = rver
        rspamd_redis.redis_make_request_taskless(ev_base, cfg, redis_params,
            settings.redis_key, false,
            redis_load_cb, 'HGET', { settings.redis_key, 'd' })
      elseif cur_settings.updates.has_updates then
        -- Need to send our updates to Redis
        update_dynamic_conf(cfg, ev_base)
      end
    elseif cur_settings.updates.has_updates then
      -- Need to send our updates to Redis
      update_dynamic_conf(cfg, ev_base)
    end
  end

  rspamd_redis.redis_make_request_taskless(ev_base, cfg, redis_params,
      settings.redis_key, false,
      redis_check_cb, 'HGET', { settings.redis_key, 'v' })
end

local section = rspamd_config:get_all_opt("dynamic_conf")
if section then
  redis_params = rspamd_redis.parse_redis_server('dynamic_conf')
  if not redis_params then
    rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module')
    return
  end

  for k, v in pairs(section) do
    settings[k] = v
  end

  rspamd_config:add_on_load(function(_, ev_base, worker)
    if worker:is_scanner() then
      rspamd_config:add_periodic(ev_base, 0.0,
          function(cfg, _)
            check_dynamic_conf(cfg, ev_base)
            return settings.redis_watch_interval
          end, true)
    end
  end)
end

-- Updates part

--- Records a locally changed symbol score to be saved to Redis.
-- @param sym symbol name
-- @param score new score
-- @return true if the score differs from the stored one and was recorded
local function add_dynamic_symbol(_, sym, score)
  local add = false
  if not cur_settings.data then
    cur_settings.data = {}
  end

  if not cur_settings.data.scores then
    cur_settings.data.scores = {}
    cur_settings.data.scores[sym] = score
    add = true
  else
    if cur_settings.data.scores[sym] then
      if cur_settings.data.scores[sym] ~= score then
        add = true
      end
    else
      cur_settings.data.scores[sym] = score
      add = true
    end
  end

  if add then
    cur_settings.data.scores[sym] = score
    -- Pending updates are looked up by name (`updates.symbols[k]`), so the
    -- name must be the key; table.insert stored it under an integer index.
    cur_settings.updates.symbols[sym] = true
    cur_settings.updates.has_updates = true
  end

  return add
end

--- Records a locally changed action threshold to be saved to Redis.
-- @param act action name
-- @param score new threshold
-- @return true if the threshold differs from the stored one and was recorded
local function add_dynamic_action(_, act, score)
  local add = false
  if not cur_settings.data then
    cur_settings.data = {}
    cur_settings.version = 0
  end

  if not cur_settings.data.actions then
    cur_settings.data.actions = {}
    cur_settings.data.actions[act] = score
    add = true
  else
    if cur_settings.data.actions[act] then
      if cur_settings.data.actions[act] ~= score then
        add = true
      end
    else
      cur_settings.data.actions[act] = score
      add = true
    end
  end

  if add then
    cur_settings.data.actions[act] = score
    -- Keyed by name for the same reason as in add_dynamic_symbol
    cur_settings.updates.actions[act] = true
    cur_settings.updates.has_updates = true
  end

  return add
end

if section then
  if redis_params then
    rspamd_plugins["dynamic_conf"] = {
      add_symbol = add_dynamic_symbol,
      add_action = add_dynamic_action,
    }
  else
    lua_util.disable_module(N, "redis")
  end
else
  lua_util.disable_module(N, "config")
end
-- \ No newline at end of file
-- diff --git a/src/plugins/lua/elastic.lua b/src/plugins/lua/elastic.lua
-- new file mode 100644
-- index 0000000..ccbb7c1
-- --- /dev/null
-- +++ b/src/plugins/lua/elastic.lua
-- @@ -0,0 +1,544 @@
--[[
Copyright (c) 2017, Veselin Iordanov
Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

local rspamd_logger = require 'rspamd_logger'
local rspamd_http = require "rspamd_http"
local lua_util = require "lua_util"
local util = require "rspamd_util"
local ucl = require "ucl"
local rspamd_redis = require "lua_redis"
local upstream_list = require "rspamd_upstream_list"
local lua_settings = require "lua_settings"

if confighelp then
  return
end

-- Rows collected since the last bulk push and their count
local rows = {}
local nrows = 0
-- Consecutive bulk pushes that could not be started
local failed_sends = 0
local elastic_template
local redis_params
local N = "elastic"
local E = {}
local HOSTNAME = util.get_hostname()
local connect_prefix = 'http://'
local enabled = true
local ingest_geoip_type = 'plugins'
local settings = {
  limit = 500,
  index_pattern = 'rspamd-%Y.%m.%d',
  template_file = rspamd_paths['SHAREDIR'] .. '/elastic/rspamd_template.json',
  kibana_file = rspamd_paths['SHAREDIR'] .. '/elastic/kibana.json',
  key_prefix = 'elastic-',
  expire = 3600,
  timeout = 5.0,
  failover = false,
  import_kibana = false,
  use_https = false,
  use_gzip = true,
  allow_local = false,
  user = nil,
  password = nil,
  no_ssl_verify = false,
  max_fail = 3,
  ingest_module = false,
  elasticsearch_version = 6,
}

--- Reads a whole file in binary mode.
-- @param path file name
-- @return file content, or nil if the file cannot be opened
local function read_file(path)
  local file = io.open(path, "rb")
  if not file then
    return nil
  end
  local content = file:read "*a"
  file:close()
  return content
end

--- Pushes all collected rows to the `_bulk` endpoint of one upstream.
-- @param task rspamd task used for the HTTP request and logging
-- @return result of `rspamd_http.request`
local function elastic_send_data(task)
  local es_index = os.date(settings['index_pattern'])
  local tbl = {}
  for _, value in pairs(rows) do
    -- The `_type` member is only emitted for versions below 7
    if settings.elasticsearch_version >= 7 then
      table.insert(tbl, '{ "index" : { "_index" : "' .. es_index ..
          '","pipeline": "rspamd-geoip"} }')
    else
      table.insert(tbl, '{ "index" : { "_index" : "' .. es_index ..
          '", "_type" : "_doc" ,"pipeline": "rspamd-geoip"} }')
    end
    table.insert(tbl, ucl.to_format(value, 'json-compact'))
  end

  table.insert(tbl, '') -- For last \n

  local upstream = settings.upstream:get_upstream_round_robin()
  local ip_addr = upstream:get_addr():to_string(true)

  local push_url = connect_prefix .. ip_addr .. '/' .. es_index .. '/_bulk'
  local bulk_json = table.concat(tbl, "\n")

  local function http_callback(err, code, _, _)
    if err then
      rspamd_logger.infox(task, "cannot push data to elastic backend (%s): %s; failed attempts: %s/%s",
          push_url, err, failed_sends, settings.max_fail)
    else
      if code ~= 200 then
        -- Arguments follow the format order: http code first, then the error
        -- (they were swapped, so the code was logged as nil)
        rspamd_logger.infox(task,
            "cannot push data to elastic backend (%s): wrong http code %s (%s); failed attempts: %s/%s",
            push_url, code, err, failed_sends, settings.max_fail)
      else
        lua_util.debugm(N, task, "successfully sent %s (%s bytes) rows to ES",
            nrows, #bulk_json)
      end
    end
  end

  return rspamd_http.request({
    url = push_url,
    headers = {
      ['Content-Type'] = 'application/x-ndjson',
    },
    body = bulk_json,
    callback = http_callback,
    task = task,
    method = 'post',
    gzip = settings.use_gzip,
    no_ssl_verify = settings.no_ssl_verify,
    user = settings.user,
    password = settings.password,
    timeout = settings.timeout,
  })
end

--- Builds the `rspamd_meta` document for a task.
-- @param task rspamd task
-- @return table with addresses, headers, symbols, ASN data and scan time
local function get_general_metadata(task)
  local r = {}
  local ip_addr = task:get_ip()

  if ip_addr and ip_addr:is_valid() then
    r.is_local = ip_addr:is_local()
    r.ip = tostring(ip_addr)
  else
    r.ip = '127.0.0.1'
  end

  r.webmail = false
  r.sender_ip = 'unknown'
  local origin = task:get_header('X-Originating-IP')
  if origin then
    -- Strip the square brackets around the address
    origin = origin:gsub('%[', ''):gsub('%]', '')
    local rspamd_ip = require "rspamd_ip"
    local origin_ip = rspamd_ip.from_string(origin)
    if origin_ip and origin_ip:is_valid() then
      r.webmail = true
      r.sender_ip = origin -- use string here
    end
  end

  r.direction = "Inbound"
  r.user = task:get_user() or 'unknown'
  r.qid = task:get_queue_id() or 'unknown'
  r.action = task:get_metric_action()
  r.rspamd_server = HOSTNAME
  -- A known user marks the message as outbound
  if r.user ~= 'unknown' then
    r.direction = "Outbound"
  end
  local s = task:get_metric_score()[1]
  r.score = s

  local rcpt = task:get_recipients('smtp')
  if rcpt then
    local l = {}
    for _, a in ipairs(rcpt) do
      table.insert(l, a['addr'])
    end
    r.rcpt = l
  else
    r.rcpt = 'unknown'
  end

  local from = task:get_from { 'smtp', 'orig' }
  if ((from or E)[1] or E).addr then
    r.from = from[1].addr
  else
    r.from = 'unknown'
  end

  local mime_from = task:get_from { 'mime', 'orig' }
  if ((mime_from or E)[1] or E).addr then
    r.mime_from = mime_from[1].addr
  else
    r.mime_from = 'unknown'
  end

  local syminf = task:get_symbols_all()
  r.symbols = syminf
  r.asn = {}
  local pool = task:get_mempool()
  r.asn.country = pool:get_variable("country") or 'unknown'
  r.asn.asn = pool:get_variable("asn") or 0
  r.asn.ipnet = pool:get_variable("ipnet") or 'unknown'

  -- Returns the decoded values of all headers `name`, or 'unknown'
  local function process_header(name)
    local hdr = task:get_header_full(name)
    if hdr then
      local l = {}
      for _, h in ipairs(hdr) do
        table.insert(l, h.decoded)
      end
      return l
    else
      return 'unknown'
    end
  end

  r.header_from = process_header('from')
  r.header_to = process_header('to')
  r.header_subject = process_header('subject')
  r.header_date = process_header('date')
  r.message_id = task:get_message_id()
  local hname = task:get_hostname() or 'unknown'
  r.hostname = hname

  local settings_id = task:get_settings_id()

  if settings_id then
    -- Convert to string
    settings_id = lua_settings.settings_by_id(settings_id)

    if settings_id then
      settings_id = settings_id.name
    end
  end

  if not settings_id then
    settings_id = ''
  end

  r.settings_id = settings_id

  -- Scan time is stored in whole milliseconds
  local scan_real = task:get_scan_time()
  scan_real = math.floor(scan_real * 1000)
  if scan_real < 0 then
    rspamd_logger.messagex(task,
        'clock skew detected for message: %s ms real scan time (reset to 0)',
        scan_real)
    scan_real = 0
  end

  r.scan_time = scan_real

  return r
end

--- Callback of the ELASTIC_COLLECT symbol: queues one row per task and
-- pushes the queue once it holds more than `settings.limit` rows.
-- @param task rspamd task
local function elastic_collect(task)
  if not enabled then
    return
  end
  if task:has_flag('skip') then
    return
  end
  if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then
    return
  end

  local row = { ['rspamd_meta'] = get_general_metadata(task),
                ['@timestamp'] = tostring(util.get_time() * 1000) }
  table.insert(rows, row)
  nrows = nrows + 1
  if nrows > settings['limit'] then
    lua_util.debugm(N, task, 'send elastic search rows: %s', nrows)
    if elastic_send_data(task) then
      nrows = 0
      rows = {}
      failed_sends = 0
    else
      failed_sends = failed_sends + 1

      -- Drop the queue after too many failures
      if failed_sends > settings.max_fail then
        rspamd_logger.errx(task, 'cannot send %s rows to ES %s times, stop trying',
            nrows, failed_sends)
        nrows = 0
        rows = {}
        failed_sends = 0
      end
    end
  end
end

local opts = rspamd_config:get_all_opt('elastic')

--- Queries `/_nodes/<plugins|modules>` and disables the module unless every
-- node lists the `ingest-geoip` plugin.
-- @param cfg rspamd config object
-- @param ev_base event base for the HTTP request
local function check_elastic_server(cfg, ev_base, _)
  local upstream = settings.upstream:get_upstream_round_robin()
  local ip_addr = upstream:get_addr():to_string(true)
  local plugins_url = connect_prefix .. ip_addr .. '/_nodes/' .. ingest_geoip_type
  local function http_callback(err, code, body, _)
    if code == 200 then
      local parser = ucl.parser()
      local res, ucl_err = parser:parse_string(body)
      if not res then
        rspamd_logger.infox(rspamd_config, 'failed to parse reply from %s: %s',
            plugins_url, ucl_err)
        enabled = false
        return
      end
      local obj = parser:get_object()
      for node, value in pairs(obj['nodes']) do
        local plugin_found = false
        for _, plugin in pairs(value['plugins']) do
          if plugin['name'] == 'ingest-geoip' then
            plugin_found = true
            lua_util.debugm(N, "ingest-geoip plugin has been found")
          end
        end
        if not plugin_found then
          rspamd_logger.infox(rspamd_config,
              'Unable to find ingest-geoip on %1 node, disabling module', node)
          enabled = false
          return
        end
      end
    else
      rspamd_logger.errx('cannot get plugins from %s: %s(%s) (%s)', plugins_url,
          err, code, body)
      enabled = false
    end
  end
  rspamd_http.request({
    url = plugins_url,
    ev_base = ev_base,
    config = cfg,
    method = 'get',
    callback = http_callback,
    no_ssl_verify = settings.no_ssl_verify,
    user = settings.user,
    password = settings.password,
    timeout = settings.timeout,
  })
end

--- Imports the ingest pipeline, the index template and (optionally) the
-- kibana dashboard/visualization. Does nothing unless the worker is the
-- primary controller.
-- @param cfg rspamd config object
-- @param ev_base event base for the HTTP requests
-- @param worker rspamd worker object
local function initial_setup(cfg, ev_base, worker)
  if not worker:is_primary_controller() then
    return
  end

  local upstream = settings.upstream:get_upstream_round_robin()
  local ip_addr = upstream:get_addr():to_string(true)

  local function push_kibana_template()
    -- add kibana dashboard and visualizations
    if settings['import_kibana'] then
      local kibana_mappings = read_file(settings['kibana_file'])
      if kibana_mappings then
        local parser = ucl.parser()
        local res, parser_err = parser:parse_string(kibana_mappings)
        if not res then
          rspamd_logger.infox(rspamd_config, 'kibana template cannot be parsed: %s',
              parser_err)
          enabled = false

          return
        end
        local obj = parser:get_object()
        local tbl = {}
        for _, item in ipairs(obj) do
          table.insert(tbl, '{ "index" : { "_index" : ".kibana", "_type" : "doc" ,"_id": "' ..
              item['_type'] .. ':' .. item["_id"] .. '"} }')
          table.insert(tbl, ucl.to_format(item['_source'], 'json-compact'))
        end
        table.insert(tbl, '') -- For last \n

        local kibana_url = connect_prefix .. ip_addr .. '/.kibana/_bulk'
        local function kibana_template_callback(err, code, body, _)
          if code ~= 200 then
            rspamd_logger.errx('cannot put template to %s: %s(%s) (%s)', kibana_url,
                err, code, body)
            enabled = false
          else
            lua_util.debugm(N, 'pushed kibana template: %s', body)
          end
        end

        rspamd_http.request({
          url = kibana_url,
          ev_base = ev_base,
          config = cfg,
          headers = {
            ['Content-Type'] = 'application/x-ndjson',
          },
          body = table.concat(tbl, "\n"),
          method = 'post',
          gzip = settings.use_gzip,
          callback = kibana_template_callback,
          no_ssl_verify = settings.no_ssl_verify,
          user = settings.user,
          password = settings.password,
          timeout = settings.timeout,
        })
      else
        rspamd_logger.infox(rspamd_config, 'kibana template file %s not found', settings['kibana_file'])
      end
    end
  end

  if enabled then
    -- create ingest pipeline
    local geoip_url = connect_prefix .. ip_addr .. '/_ingest/pipeline/rspamd-geoip'
    local function geoip_cb(err, code, body, _)
      if code ~= 200 then
        rspamd_logger.errx('cannot get data from %s: %s(%s) (%s)',
            geoip_url, err, code, body)
        enabled = false
      end
    end
    local template = {
      description = "Add geoip info for rspamd",
      processors = {
        {
          geoip = {
            field = "rspamd_meta.ip",
            target_field = "rspamd_meta.geoip"
          }
        }
      }
    }
    rspamd_http.request({
      url = geoip_url,
      ev_base = ev_base,
      config = cfg,
      callback = geoip_cb,
      headers = {
        ['Content-Type'] = 'application/json',
      },
      gzip = settings.use_gzip,
      body = ucl.to_format(template, 'json-compact'),
      method = 'put',
      no_ssl_verify = settings.no_ssl_verify,
      user = settings.user,
      password = settings.password,
      timeout = settings.timeout,
    })
    -- create template mappings if not exist
    local template_url = connect_prefix .. ip_addr .. '/_template/rspamd'
    local function http_template_put_callback(err, code, body, _)
      if code ~= 200 then
        rspamd_logger.errx('cannot put template to %s: %s(%s) (%s)',
            template_url, err, code, body)
        enabled = false
      else
        lua_util.debugm(N, 'pushed rspamd template: %s', body)
        push_kibana_template()
      end
    end
    -- A non-200 reply to the HEAD request means the template must be uploaded
    local function http_template_exist_callback(_, code, _, _)
      if code ~= 200 then
        rspamd_http.request({
          url = template_url,
          ev_base = ev_base,
          config = cfg,
          body = elastic_template,
          method = 'put',
          headers = {
            ['Content-Type'] = 'application/json',
          },
          gzip = settings.use_gzip,
          callback = http_template_put_callback,
          no_ssl_verify = settings.no_ssl_verify,
          user = settings.user,
          password = settings.password,
          timeout = settings.timeout,
        })
      else
        push_kibana_template()
      end
    end

    rspamd_http.request({
      url = template_url,
      ev_base = ev_base,
      config = cfg,
      method = 'head',
      callback = http_template_exist_callback,
      no_ssl_verify = settings.no_ssl_verify,
      user = settings.user,
      password = settings.password,
      timeout = settings.timeout,
    })

  end
end

redis_params = rspamd_redis.parse_redis_server('elastic')

if redis_params and opts then
  for k, v in pairs(opts) do
    settings[k] = v
  end

  if not settings['server'] and not settings['servers'] then
    rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module')
    lua_util.disable_module(N, "config")
  else
    if settings.use_https then
      connect_prefix = 'https://'
    end

    if settings.ingest_module then
      ingest_geoip_type = 'modules'
    end

    settings.upstream = upstream_list.create(rspamd_config,
        settings['server'] or settings['servers'], 9200)

    if not settings.upstream then
      rspamd_logger.errx('cannot parse elastic address: %s',
          settings['server'] or settings['servers'])
      lua_util.disable_module(N, "config")
      return
    end
    if not settings['template_file'] then
      rspamd_logger.infox(rspamd_config, 'elastic template_file is required, disabling module')
      lua_util.disable_module(N, "config")
      return
    end

    elastic_template = read_file(settings['template_file'])
    if not elastic_template then
      rspamd_logger.infox(rspamd_config, 'elastic unable to read %s, disabling module',
          settings['template_file'])
      lua_util.disable_module(N, "config")
      return
    end

    rspamd_config:register_symbol({
      name = 'ELASTIC_COLLECT',
      type = 'idempotent',
      callback = elastic_collect,
      flags = 'empty,explicit_disable,ignore_passthrough',
      augmentations = { string.format("timeout=%f", settings.timeout) },
    })

    rspamd_config:add_on_load(function(cfg, ev_base, worker)
      if worker:is_scanner() then
        check_elastic_server(cfg, ev_base, worker) -- check for elasticsearch requirements
        initial_setup(cfg, ev_base, worker) -- import mappings pipeline and visualizations
      end
    end)
  end

end

-- diff --git a/src/plugins/lua/emails.lua b/src/plugins/lua/emails.lua
-- new file mode 100644
-- index 0000000..5f25e69
-- --- /dev/null
-- +++ b/src/plugins/lua/emails.lua
-- @@ -0,0 +1,4 @@
-- This module is deprecated and must not be used.
-- This file serves as a tombstone to prevent old emails to be loaded

return
\ No newline at end of file diff --git a/src/plugins/lua/external_relay.lua b/src/plugins/lua/external_relay.lua new file mode 100644 index 0000000..3660f92 --- /dev/null +++ b/src/plugins/lua/external_relay.lua @@ -0,0 +1,285 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[ +external_relay plugin - sets IP/hostname from Received headers +]]-- + +if confighelp then + return +end + +local lua_maps = require "lua_maps" +local lua_util = require "lua_util" +local rspamd_logger = require "rspamd_logger" +local ts = require("tableshape").types + +local E = {} +local N = "external_relay" + +local settings = { + rules = {}, +} + +local config_schema = ts.shape { + enabled = ts.boolean:is_optional(), + rules = ts.map_of( + ts.string, ts.one_of { + ts.shape { + priority = ts.number:is_optional(), + strategy = 'authenticated', + symbol = ts.string:is_optional(), + user_map = lua_maps.map_schema:is_optional(), + }, + ts.shape { + count = ts.number, + priority = ts.number:is_optional(), + strategy = 'count', + symbol = ts.string:is_optional(), + }, + ts.shape { + priority = ts.number:is_optional(), + strategy = 'local', + symbol = ts.string:is_optional(), + }, + ts.shape { + hostname_map = lua_maps.map_schema, + priority = ts.number:is_optional(), + strategy = 'hostname_map', + symbol = ts.string:is_optional(), + }, + ts.shape { + ip_map = lua_maps.map_schema, + priority = ts.number:is_optional(), + strategy = 'ip_map', 
+ symbol = ts.string:is_optional(), + }, + } + ), +} + +local function set_from_rcvd(task, rcvd) + local rcvd_ip = rcvd.real_ip + if not (rcvd_ip and rcvd_ip:is_valid()) then + rspamd_logger.errx(task, 'no IP in header: %s', rcvd) + return + end + task:set_from_ip(rcvd_ip) + if rcvd.from_hostname then + task:set_hostname(rcvd.from_hostname) + task:set_helo(rcvd.from_hostname) -- use fake value for HELO + else + rspamd_logger.warnx(task, "couldn't get hostname from headers") + local ipstr = string.format('[%s]', rcvd_ip) + task:set_hostname(ipstr) -- returns nil from task:get_hostname() + task:set_helo(ipstr) + end + return true +end + +local strategies = {} + +strategies.authenticated = function(rule) + local user_map + if rule.user_map then + user_map = lua_maps.map_add_from_ucl(rule.user_map, 'set', 'external relay usernames') + if not user_map then + rspamd_logger.errx(rspamd_config, "couldn't add map %s; won't register symbol %s", + rule.user_map, rule.symbol) + return + end + end + + return function(task) + local user = task:get_user() + if not user then + lua_util.debugm(N, task, 'sender is unauthenticated') + return + end + if user_map then + if not user_map:get_key(user) then + lua_util.debugm(N, task, 'sender (%s) is not in user_map', user) + return + end + end + + local rcvd_hdrs = task:get_received_headers() + -- Try find end of authentication chain + for _, rcvd in ipairs(rcvd_hdrs) do + if not rcvd.flags.authenticated then + -- Found unauthenticated hop, use this header + return set_from_rcvd(task, rcvd) + end + end + + rspamd_logger.errx(task, 'found nothing useful in Received headers') + end +end + +strategies.count = function(rule) + return function(task) + local rcvd_hdrs = task:get_received_headers() + -- Reduce count by 1 if artificial header is present + local hdr_count + if ((rcvd_hdrs[1] or E).flags or E).artificial then + hdr_count = rule.count - 1 + else + hdr_count = rule.count + end + + local rcvd = rcvd_hdrs[hdr_count] + if not rcvd then 
+ rspamd_logger.errx(task, 'found no received header #%s', hdr_count) + return + end + + return set_from_rcvd(task, rcvd) + end +end + +strategies.hostname_map = function(rule) + local hostname_map = lua_maps.map_add_from_ucl(rule.hostname_map, 'map', 'external relay hostnames') + if not hostname_map then + rspamd_logger.errx(rspamd_config, "couldn't add map %s; won't register symbol %s", + rule.hostname_map, rule.symbol) + return + end + + return function(task) + local from_hn = task:get_hostname() + if not from_hn then + lua_util.debugm(N, task, 'sending hostname is missing') + return + end + + if not hostname_map:get_key(from_hn) then + lua_util.debugm(N, task, 'sender\'s hostname (%s) is not a relay', from_hn) + return + end + + local rcvd_hdrs = task:get_received_headers() + -- Try find sending hostname in Received headers + for _, rcvd in ipairs(rcvd_hdrs) do + if rcvd.by_hostname == from_hn and rcvd.real_ip then + if not hostname_map:get_key(rcvd.from_hostname) then + -- Remote hostname is not another relay, use this header + return set_from_rcvd(task, rcvd) + else + -- Keep checking with new hostname + from_hn = rcvd.from_hostname + end + end + end + + rspamd_logger.errx(task, 'found nothing useful in Received headers') + end +end + +strategies.ip_map = function(rule) + local ip_map = lua_maps.map_add_from_ucl(rule.ip_map, 'radix', 'external relay IPs') + if not ip_map then + rspamd_logger.errx(rspamd_config, "couldn't add map %s; won't register symbol %s", + rule.ip_map, rule.symbol) + return + end + + return function(task) + local from_ip = task:get_from_ip() + if not (from_ip and from_ip:is_valid()) then + lua_util.debugm(N, task, 'sender\'s IP is missing') + return + end + + if not ip_map:get_key(from_ip) then + lua_util.debugm(N, task, 'sender\'s ip (%s) is not a relay', from_ip) + return + end + + local rcvd_hdrs = task:get_received_headers() + local num_rcvd = #rcvd_hdrs + -- Try find sending IP in Received headers + for i, rcvd in ipairs(rcvd_hdrs) 
do + if rcvd.real_ip then + local rcvd_ip = rcvd.real_ip + if rcvd_ip:is_valid() and (not ip_map:get_key(rcvd_ip) or i == num_rcvd) then + return set_from_rcvd(task, rcvd) + end + end + end + + rspamd_logger.errx(task, 'found nothing useful in Received headers') + end +end + +strategies['local'] = function(rule) + return function(task) + local from_ip = task:get_from_ip() + if not from_ip then + lua_util.debugm(N, task, 'sending IP is missing') + return + end + + if not from_ip:is_local() then + lua_util.debugm(N, task, 'sending IP (%s) is non-local', from_ip) + return + end + + local rcvd_hdrs = task:get_received_headers() + local num_rcvd = #rcvd_hdrs + -- Try find first non-local IP in Received headers + for i, rcvd in ipairs(rcvd_hdrs) do + if rcvd.real_ip then + local rcvd_ip = rcvd.real_ip + if rcvd_ip and rcvd_ip:is_valid() and (not rcvd_ip:is_local() or i == num_rcvd) then + return set_from_rcvd(task, rcvd) + end + end + end + + rspamd_logger.errx(task, 'found nothing useful in Received headers') + end +end + +local opts = rspamd_config:get_all_opt(N) +if opts then + settings = lua_util.override_defaults(settings, opts) + + local ok, schema_err = config_schema:transform(settings) + if not ok then + rspamd_logger.errx(rspamd_config, 'config schema error: %s', schema_err) + lua_util.disable_module(N, "config") + return + end + + for k, rule in pairs(settings.rules) do + + if not rule.symbol then + rule.symbol = k + end + + local cb = strategies[rule.strategy](rule) + + if cb then + rspamd_config:register_symbol({ + name = rule.symbol, + type = 'prefilter', + priority = rule.priority or lua_util.symbols_priorities.top + 1, + group = N, + callback = cb, + }) + end + end +end diff --git a/src/plugins/lua/external_services.lua b/src/plugins/lua/external_services.lua new file mode 100644 index 0000000..e299d9f --- /dev/null +++ b/src/plugins/lua/external_services.lua @@ -0,0 +1,408 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 
2019, Carsten Rosenberg <c.rosenberg@heinlein-support.de> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]] -- + +local rspamd_logger = require "rspamd_logger" +local lua_util = require "lua_util" +local lua_redis = require "lua_redis" +local fun = require "fun" +local lua_scanners = require("lua_scanners").filter('scanner') +local common = require "lua_scanners/common" +local redis_params + +local N = "external_services" + +if confighelp then + rspamd_config:add_example(nil, 'external_services', + "Check messages using external services (e.g. OEM AS engines, DCC, Pyzor etc)", + [[ + external_services { + # multiple scanners could be checked, for each we create a configuration block with an arbitrary name + + oletools { + # If set force this action if any virus is found (default unset: no action is forced) + # action = "reject"; + # If set, then rejection message is set to this value (mention single quotes) + # If `max_size` is set, messages > n bytes in size are not scanned + # max_size = 20000000; + # log_clean = true; + # servers = "127.0.0.1:10050"; + # cache_expire = 86400; + # scan_mime_parts = true; + # extended = false; + # if `patterns` is specified virus name will be matched against provided regexes and the related + # symbol will be yielded if a match is found. If no match is found, default symbol is yielded. 
+ patterns { + # symbol_name = "pattern"; + JUST_EICAR = "^Eicar-Test-Signature$"; + } + # mime-part regex matching in content-type or filename + mime_parts_filter_regex { + #GEN1 = "application\/octet-stream"; + DOC2 = "application\/msword"; + DOC3 = "application\/vnd\.ms-word.*"; + XLS = "application\/vnd\.ms-excel.*"; + PPT = "application\/vnd\.ms-powerpoint.*"; + GEN2 = "application\/vnd\.openxmlformats-officedocument.*"; + } + # Mime-Part filename extension matching (no regex) + mime_parts_filter_ext { + doc = "doc"; + dot = "dot"; + docx = "docx"; + dotx = "dotx"; + docm = "docm"; + dotm = "dotm"; + xls = "xls"; + xlt = "xlt"; + xla = "xla"; + xlsx = "xlsx"; + xltx = "xltx"; + xlsm = "xlsm"; + xltm = "xltm"; + xlam = "xlam"; + xlsb = "xlsb"; + ppt = "ppt"; + pot = "pot"; + pps = "pps"; + ppa = "ppa"; + pptx = "pptx"; + potx = "potx"; + ppsx = "ppsx"; + ppam = "ppam"; + pptm = "pptm"; + potm = "potm"; + ppsm = "ppsm"; + } + # `whitelist` points to a map of IP addresses. Mail from these addresses is not scanned. + whitelist = "/etc/rspamd/antivirus.wl"; + } + dcc { + # If set force this action if any virus is found (default unset: no action is forced) + # action = "reject"; + # If set, then rejection message is set to this value (mention single quotes) + # If `max_size` is set, messages > n bytes in size are not scanned + max_size = 20000000; + #servers = "127.0.0.1:10045; + # if `patterns` is specified virus name will be matched against provided regexes and the related + # symbol will be yielded if a match is found. If no match is found, default symbol is yielded. + patterns { + # symbol_name = "pattern"; + JUST_EICAR = "^Eicar-Test-Signature$"; + } + # `whitelist` points to a map of IP addresses. Mail from these addresses is not scanned. 
+ whitelist = "/etc/rspamd/antivirus.wl"; + } + } + ]]) + return +end + +local function add_scanner_rule(sym, opts) + if not opts.type then + rspamd_logger.errx(rspamd_config, 'unknown type for external scanner rule %s', sym) + return nil + end + + local cfg = lua_scanners[opts.type] + + if not cfg then + rspamd_logger.errx(rspamd_config, 'unknown external scanner type: %s', + opts.type) + return nil + end + + local rule = cfg.configure(opts) + + if not rule then + rspamd_logger.errx(rspamd_config, 'cannot configure %s for %s', + opts.type, opts.symbol or sym:upper()) -- rule is nil here, so it cannot supply the symbol name + return nil + end + + rule.type = opts.type + -- Fill missing symbols + if not rule.symbol then + rule.symbol = sym:upper() + end + if not rule.symbol_fail then + rule.symbol_fail = rule.symbol .. '_FAIL' + end + if not rule.symbol_encrypted then + rule.symbol_encrypted = rule.symbol .. '_ENCRYPTED' + end + if not rule.symbol_macro then + rule.symbol_macro = rule.symbol .. '_MACRO' + end + + rule.redis_params = redis_params + + lua_redis.register_prefix(rule.prefix .. 
'_*', N, + string.format('External services cache for rule "%s"', + rule.type), { + type = 'string', + }) + + -- if any mime_part filter defined, do not scan all attachments + if opts.mime_parts_filter_regex ~= nil + or opts.mime_parts_filter_ext ~= nil then + rule.scan_all_mime_parts = false + else + rule.scan_all_mime_parts = true + end + + rule.patterns = common.create_regex_table(opts.patterns or {}) + rule.patterns_fail = common.create_regex_table(opts.patterns_fail or {}) + + rule.mime_parts_filter_regex = common.create_regex_table(opts.mime_parts_filter_regex or {}) + + rule.mime_parts_filter_ext = common.create_regex_table(opts.mime_parts_filter_ext or {}) + + if opts.whitelist then + rule.whitelist = rspamd_config:add_hash_map(opts.whitelist) + end + + local function scan_cb(task) + if rule.scan_mime_parts then + + fun.each(function(p) + local content = p:get_content() + if content and #content > 0 then + cfg.check(task, content, p:get_digest(), rule, p) + end + end, common.check_parts_match(task, rule)) + + else + cfg.check(task, task:get_content(), task:get_digest(), rule, nil) + end + end + + rspamd_logger.infox(rspamd_config, 'registered external services rule: symbol %s; type %s', + rule.symbol, rule.type) + + return scan_cb, rule +end + +-- Registration +local opts = rspamd_config:get_all_opt(N) +if opts and type(opts) == 'table' then + redis_params = lua_redis.parse_redis_server(N) + local has_valid = false + for k, m in pairs(opts) do + if type(m) == 'table' and m.servers then + if not m.type then + m.type = k + end + if not m.name then + m.name = k + end + local cb, nrule = add_scanner_rule(k, m) + + if not cb then + rspamd_logger.errx(rspamd_config, 'cannot add rule: "' .. k .. 
'"') + else + m = nrule + + local t = { + name = m.symbol, + callback = cb, + score = 0.0, + group = N + } + + if m.symbol_type == 'postfilter' then + t.type = 'postfilter' + t.priority = lua_util.symbols_priorities.medium + else + t.type = 'normal' + end + + t.augmentations = {} + + if type(m.timeout) == 'number' then + -- Here, we ignore possible DNS timeout and timeout from multiple retries + -- as these situations are not usual nor likely for the external_services module + table.insert(t.augmentations, string.format("timeout=%f", m.timeout)) + end + + local id = rspamd_config:register_symbol(t) + + if m.symbol_fail then + rspamd_config:register_symbol({ + type = 'virtual', + name = m['symbol_fail'], + parent = id, + score = 0.0, + group = N + }) + end + + if m.symbol_encrypted then + rspamd_config:register_symbol({ + type = 'virtual', + name = m['symbol_encrypted'], + parent = id, + score = 0.0, + group = N + }) + end + if m.symbol_macro then + rspamd_config:register_symbol({ + type = 'virtual', + name = m['symbol_macro'], + parent = id, + score = 0.0, + group = N + }) + end + has_valid = true + if type(m['patterns']) == 'table' then + if m['patterns'][1] then + for _, p in ipairs(m['patterns']) do + if type(p) == 'table' then + for sym in pairs(p) do + rspamd_logger.debugm(N, rspamd_config, 'registering: %1', { + type = 'virtual', + name = sym, + parent = m['symbol'], + parent_id = id, + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = sym, + parent = id, + score = 0.0, + group = N + }) + end + end + end + else + for sym in pairs(m['patterns']) do + rspamd_config:register_symbol({ + type = 'virtual', + name = sym, + parent = id, + score = 0.0, + group = N + }) + end + end + end + if type(m['patterns_fail']) == 'table' then + if m['patterns_fail'][1] then + for _, p in ipairs(m['patterns_fail']) do + if type(p) == 'table' then + for sym in pairs(p) do + rspamd_logger.debugm(N, rspamd_config, 'registering: %1', { + type = 'virtual', + name = 
sym, + parent = m['symbol'], + parent_id = id, + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = sym, + parent = id, + score = 0.0, + group = N + }) + end + end + end + else + for sym in pairs(m['patterns_fail']) do + rspamd_config:register_symbol({ + type = 'virtual', + name = sym, + parent = id, + score = 0.0, + group = N + }) + end + end + end + if m.symbols then + local function reg_symbols(tbl) + for _, sym in pairs(tbl) do + if type(sym) == 'string' then + rspamd_config:register_symbol({ + type = 'virtual', + name = sym, + parent = id, + group = N + }) + elseif type(sym) == 'table' then + if sym.symbol then + rspamd_config:register_symbol({ + type = 'virtual', + name = sym.symbol, + parent = id, + group = N + }) + + if sym.score then + rspamd_config:set_metric_symbol({ + name = sym.symbol, + score = sym.score, + description = sym.description, + group = sym.group or N, + }) + end + else + reg_symbols(sym) + end + end + end + end + + reg_symbols(m.symbols) + end + + if m['score'] then + -- Register metric symbol + local description = 'external services symbol' + local group = N + if m['description'] then + description = m['description'] + end + if m['group'] then + group = m['group'] + end + rspamd_config:set_metric_symbol({ + name = m['symbol'], + score = m['score'], + description = description, + group = group + }) + end + + -- Add preloads if a module requires that + if type(m.preloads) == 'table' then + for _, preload in ipairs(m.preloads) do + rspamd_config:add_on_load(function(cfg, ev_base, worker) + preload(m, cfg, ev_base, worker) + end) + end + end + end + end + end + + if not has_valid then + lua_util.disable_module(N, 'config') + end +end diff --git a/src/plugins/lua/force_actions.lua b/src/plugins/lua/force_actions.lua new file mode 100644 index 0000000..4a87cf5 --- /dev/null +++ b/src/plugins/lua/force_actions.lua @@ -0,0 +1,227 @@ +--[[ +Copyright (c) 2017, Andrew Lewis <nerf@judo.za.org> +Copyright (c) 2022, Vsevolod Stakhov 
<vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +-- A plugin that forces actions + +if confighelp then + return +end + +local E = {} +local N = 'force_actions' +local selector_cache = {} + +local fun = require "fun" +local lua_util = require "lua_util" +local rspamd_cryptobox_hash = require "rspamd_cryptobox_hash" +local rspamd_expression = require "rspamd_expression" +local rspamd_logger = require "rspamd_logger" +local lua_selectors = require "lua_selectors" + +-- Params table fields: +-- expr, act, pool, message, subject, raction, honor, limit, flags +local function gen_cb(params) + + local function parse_atom(str) + local atom = table.concat(fun.totable(fun.take_while(function(c) + if string.find(', \t()><+!|&\n', c, 1, true) then + return false + end + return true + end, fun.iter(str))), '') + return atom + end + + local function process_atom(atom, task) + local f_ret = task:has_symbol(atom) + if f_ret then + f_ret = math.abs(task:get_symbol(atom)[1].score) + if f_ret < 0.001 then + -- Adjust some low score to distinguish from pure zero + f_ret = 0.001 + end + return f_ret + end + return 0 + end + + local e, err = rspamd_expression.create(params.expr, { parse_atom, process_atom }, params.pool) + if err then + rspamd_logger.errx(rspamd_config, 'Couldnt create expression [%1]: %2', params.expr, err) + return + end + + return function(task) + + local function process_message_selectors(repl, selector_expr) + -- create/reuse selector to extract value for this 
placeholder + local selector = selector_cache[selector_expr] + if not selector then + selector_cache[selector_expr] = lua_selectors.create_selector_closure(rspamd_config, selector_expr, '', true) + selector = selector_cache[selector_expr] + if not selector then + rspamd_logger.errx(task, 'could not create selector [%1]', selector_expr) + return "((could not create selector))" + end + end + local extracted = selector(task) + if extracted then + if type(extracted) == 'table' then + extracted = table.concat(extracted, ',') + end + else + rspamd_logger.errx(task, 'could not extract value with selector [%1]', selector_expr) + extracted = '((error extracting value))' + end + return extracted + end + + local cact = task:get_metric_action() + if not params.message and not params.subject and params.act and cact == params.act then + return false + end + if params.honor and params.honor[cact] then + return false + elseif params.raction and not params.raction[cact] then + return false + end + + local ret = e:process(task) + lua_util.debugm(N, task, "expression %s returned %s", params.expr, ret) + if (not params.limit and ret > 0) or (ret > (params.limit or 0)) then + if params.subject then + task:set_metric_subject(params.subject) + end + + local flags = params.flags or "" + + if type(params.message) == 'string' then + -- process selector expressions in the message + local message = string.gsub(params.message, '(${(.-)})', process_message_selectors) + task:set_pre_result { action = params.act, message = message, module = N, flags = flags } + else + task:set_pre_result { action = params.act, module = N, flags = flags } + end + return true, params.act + end + + end, e:atoms() + +end + +local function configure_module() + local opts = rspamd_config:get_all_opt(N) + if not opts then + return false + end + if type(opts.actions) == 'table' then + rspamd_logger.warnx(rspamd_config, 'Processing legacy config') + for action, expressions in pairs(opts.actions) do + if type(expressions) 
== 'table' then + for _, expr in ipairs(expressions) do + local message, subject + if type(expr) == 'table' then + subject = expr[3] + message = expr[2] + expr = expr[1] + else + message = (opts.messages or E)[expr] + end + if type(expr) == 'string' then + -- expr, act, pool, message, subject, raction, honor, limit, flags + local cb, atoms = gen_cb { expr = expr, + act = action, + pool = rspamd_config:get_mempool(), + message = message, + subject = subject } + if cb and atoms then + local h = rspamd_cryptobox_hash.create() + h:update(expr) + local name = 'FORCE_ACTION_' .. string.upper(string.sub(h:hex(), 1, 12)) + rspamd_config:register_symbol({ + type = 'normal', + name = name, + callback = cb, + flags = 'empty', + group = N, + }) + for _, a in ipairs(atoms) do + rspamd_config:register_dependency(name, a) + end + rspamd_logger.infox(rspamd_config, 'Registered symbol %1 <%2> with dependencies [%3]', + name, expr, table.concat(atoms, ',')) + end + end + end + end + end + elseif type(opts.rules) == 'table' then + for name, sett in pairs(opts.rules) do + local action = sett.action + local expr = sett.expression + + if action and expr then + local flags = {} + if sett.least then + table.insert(flags, "least") + end + if sett.process_all then + table.insert(flags, "process_all") + end + local raction = lua_util.list_to_hash(sett.require_action) + local honor = lua_util.list_to_hash(sett.honor_action) + local cb, atoms = gen_cb { expr = expr, + act = action, + pool = rspamd_config:get_mempool(), + message = sett.message, + subject = sett.subject, + raction = raction, + honor = honor, + limit = sett.limit, + flags = table.concat(flags, ',') } + if cb and atoms then + local t = {} + if (raction or honor) then + t.type = 'postfilter' + t.priority = lua_util.symbols_priorities.high + else + t.type = 'normal' + if not sett.least then + t.augmentations = { 'passthrough', 'important' } + end + end + t.name = 'FORCE_ACTION_' .. 
name + t.callback = cb + t.flags = 'empty, ignore_passthrough' + t.group = N + rspamd_config:register_symbol(t) + if t.type == 'normal' then + for _, a in ipairs(atoms) do + rspamd_config:register_dependency(t.name, a) + end + rspamd_logger.infox(rspamd_config, 'Registered symbol %1 <%2> with dependencies [%3]', + t.name, expr, table.concat(atoms, ',')) + else + rspamd_logger.infox(rspamd_config, 'Registered symbol %1 <%2> as postfilter', t.name, expr) + end + end + end + end + end +end + +configure_module() diff --git a/src/plugins/lua/forged_recipients.lua b/src/plugins/lua/forged_recipients.lua new file mode 100644 index 0000000..0d51db3 --- /dev/null +++ b/src/plugins/lua/forged_recipients.lua @@ -0,0 +1,183 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +-- Plugin for comparing smtp dialog recipients and sender with recipients and sender +-- in mime headers + +if confighelp then + rspamd_config:add_example(nil, 'forged_recipients', + "Check forged recipients and senders (e.g. 
mime and smtp recipients mismatch)", + [[ + forged_recipients { + symbol_sender = "FORGED_SENDER"; # Symbol for a forged sender + symbol_rcpt = "FORGED_RECIPIENTS"; # Symbol for a forged recipients + } + ]]) +end + +local symbol_rcpt = 'FORGED_RECIPIENTS' +local symbol_sender = 'FORGED_SENDER' +local rspamd_util = require "rspamd_util" + +local E = {} + +local function check_forged_headers(task) + local auser = task:get_user() + local delivered_to = task:get_header('Delivered-To') + local smtp_rcpts = task:get_recipients(1) + local smtp_from = task:get_from(1) + + if not smtp_rcpts then + return + end + if #smtp_rcpts == 0 then + return + end + + local mime_rcpts = task:get_recipients({ 'mime', 'orig' }) + + if not mime_rcpts then + return + elseif #mime_rcpts == 0 then + return + end + + -- Find pair for each smtp recipient in To or Cc headers + if #smtp_rcpts > 100 or #mime_rcpts > 100 then + -- Trim array, suggested by Anton Yuzhaninov + smtp_rcpts[100] = nil + mime_rcpts[100] = nil + end + + -- map smtp recipient domains to a list of addresses for this domain + local smtp_rcpt_domain_map = {} + local smtp_rcpt_map = {} + for _, smtp_rcpt in ipairs(smtp_rcpts) do + local addr = smtp_rcpt.addr + + if addr and addr ~= '' then + local dom = string.lower(smtp_rcpt.domain) + addr = addr:lower() + + local dom_map = smtp_rcpt_domain_map[dom] + if not dom_map then + dom_map = {} + smtp_rcpt_domain_map[dom] = dom_map + end + + dom_map[addr] = smtp_rcpt + smtp_rcpt_map[addr] = smtp_rcpt + + if auser and auser == addr then + smtp_rcpt.matched = true + end + if ((smtp_from or E)[1] or E).addr and + smtp_from[1]['addr'] == addr then + -- allow sender to BCC themselves + smtp_rcpt.matched = true + end + end + end + + for _, mime_rcpt in ipairs(mime_rcpts) do + if mime_rcpt.addr and mime_rcpt.addr ~= '' then + local addr = string.lower(mime_rcpt.addr) + local dom = string.lower(mime_rcpt.domain) + local matched_smtp_addr = smtp_rcpt_map[addr] + if matched_smtp_addr then + -- 
Direct match, go forward + matched_smtp_addr.matched = true + mime_rcpt.matched = true + elseif delivered_to and delivered_to == addr then + mime_rcpt.matched = true + elseif auser and auser == addr then + -- allow user to BCC themselves + mime_rcpt.matched = true + else + local matched_smtp_domain = smtp_rcpt_domain_map[dom] + + if matched_smtp_domain then + -- Same domain but another user, it is likely okay due to aliases substitution + mime_rcpt.matched = true + -- Special field + matched_smtp_domain._seen_mime_domain = true + end + end + end + end + + -- Now go through all lists one more time and find unmatched stuff + local opts = {} + local seen_mime_unmatched = false + local seen_smtp_unmatched = false + for _, mime_rcpt in ipairs(mime_rcpts) do + if not mime_rcpt.matched then + seen_mime_unmatched = true + table.insert(opts, 'm:' .. mime_rcpt.addr) + end + end + for _, smtp_rcpt in ipairs(smtp_rcpts) do + if not smtp_rcpt.matched then + if not (smtp_rcpt_domain_map[(smtp_rcpt.domain or ''):lower()] or E)._seen_mime_domain then -- the domain map only has entries for recipients with a non-empty addr + seen_smtp_unmatched = true + table.insert(opts, 's:' .. 
smtp_rcpt.addr) + end + end + end + + if seen_smtp_unmatched and seen_mime_unmatched then + task:insert_result(symbol_rcpt, 1.0, opts) + end + + -- Check sender + if smtp_from and smtp_from[1] and smtp_from[1]['addr'] ~= '' then + local mime_from = task:get_from(2) + if not mime_from or not mime_from[1] or + not rspamd_util.strequal_caseless_utf8(mime_from[1]['addr'], smtp_from[1]['addr']) then + task:insert_result(symbol_sender, 1, ((mime_from or E)[1] or E).addr or '', smtp_from[1].addr) + end + end +end + +-- Configuration +local opts = rspamd_config:get_all_opt('forged_recipients') +if opts then + if opts['symbol_rcpt'] or opts['symbol_sender'] then + local id = rspamd_config:register_symbol({ + name = 'FORGED_CALLBACK', + callback = check_forged_headers, + type = 'callback', + group = 'headers', + score = 0.0, + }) + if opts['symbol_rcpt'] then + symbol_rcpt = opts['symbol_rcpt'] + rspamd_config:register_symbol({ + name = symbol_rcpt, + type = 'virtual', + parent = id, + }) + end + if opts['symbol_sender'] then + symbol_sender = opts['symbol_sender'] + rspamd_config:register_symbol({ + name = symbol_sender, + type = 'virtual', + parent = id, + }) + end + end +end diff --git a/src/plugins/lua/fuzzy_collect.lua b/src/plugins/lua/fuzzy_collect.lua new file mode 100644 index 0000000..132ace9 --- /dev/null +++ b/src/plugins/lua/fuzzy_collect.lua @@ -0,0 +1,193 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]] -- + +if confighelp then + return +end + +local rspamd_logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local rspamd_http = require "rspamd_http" +local rspamd_keypairlib = require "rspamd_cryptobox_keypair" +local rspamd_cryptolib = require "rspamd_cryptobox" +local fun = require "fun" + +local settings = { + sync_time = 60.0, + saved_cookie = '', + timeout = 10.0, +} + +local function send_data_mirror(m, cfg, ev_base, body) + local function store_callback(err, _, _, _) + if err then + rspamd_logger.errx(cfg, 'cannot save data on %(%s): %s', m.server, m.name, err) + else + rspamd_logger.infox(cfg, 'saved data on %s(%s)', m.server, m.name) + end + end + rspamd_http.request { + url = string.format('http://%s//update_v1/%s', m.server, m.name), + resolver = cfg:get_resolver(), + config = cfg, + ev_base = ev_base, + timeout = settings.timeout, + callback = store_callback, + body = body, + peer_key = m.pubkey, + keypair = m.keypair, + } +end + +local function collect_fuzzy_hashes(cfg, ev_base) + local function data_callback(err, _, body, _) + if not body or err then + rspamd_logger.errx(cfg, 'cannot load data: %s', err) + else + -- Here, we actually copy body once for each mirror + fun.each(function(_, v) + send_data_mirror(v, cfg, ev_base, body) + end, + settings.mirrors) + end + end + + local function cookie_callback(err, _, body, _) + if not body or err then + rspamd_logger.errx(cfg, 'cannot load cookie: %s', err) + else + if settings.saved_cookie ~= tostring(body) then + settings.saved_cookie = tostring(body) + rspamd_logger.infox(cfg, 'received collection cookie %s', + tostring(rspamd_util.encode_base32(settings.saved_cookie:sub(1, 6)))) + local sig = rspamd_cryptolib.sign_memory(settings.sign_keypair, + settings.saved_cookie) + if not sig then + rspamd_logger.info(cfg, 'cannot sign cookie') + else + rspamd_http.request { + url = string.format('http://%s/data', settings.collect_server), + resolver = cfg:get_resolver(), + config = 
cfg, + ev_base = ev_base, + timeout = settings.timeout, + callback = data_callback, + peer_key = settings.collect_pubkey, + headers = { + Signature = sig:hex() + }, + opaque_body = true, + } + end + else + rspamd_logger.info(cfg, 'cookie has not changed, do not update') + end + end + end + rspamd_logger.infox(cfg, 'start fuzzy collection, next sync in %s seconds', + settings.sync_time) + rspamd_http.request { + url = string.format('http://%s/cookie', settings.collect_server), + resolver = cfg:get_resolver(), + config = cfg, + ev_base = ev_base, + timeout = settings.timeout, + callback = cookie_callback, + peer_key = settings.collect_pubkey, + } + + return settings.sync_time +end + +local function test_mirror_config(k, m) + if not m.server then + rspamd_logger.errx(rspamd_config, 'server is missing for the mirror') + return false + end + + if not m.pubkey then + rspamd_logger.errx(rspamd_config, 'pubkey is missing for the mirror') + return false + end + + if type(k) ~= 'string' and not m.name then + rspamd_logger.errx(rspamd_config, 'name is missing for the mirror') + return false + end + + if not m.keypair then + rspamd_logger.errx(rspamd_config, 'keypair is missing for the mirror') + return false + end + + if not m.name then + m.name = k + end + + return true +end + +local opts = rspamd_config:get_all_opt('fuzzy_collect') + +if opts and type(opts) == 'table' then + for k, v in pairs(opts) do + settings[k] = v + end + local sane_config = true + + if not settings['sign_keypair'] then + rspamd_logger.errx(rspamd_config, 'sign_keypair is missing') + sane_config = false + end + + settings['sign_keypair'] = rspamd_keypairlib.create(settings['sign_keypair']) + if not settings['sign_keypair'] then + rspamd_logger.errx(rspamd_config, 'sign_keypair is invalid') + sane_config = false + end + + if not settings['collect_server'] then + rspamd_logger.errx(rspamd_config, 'collect_server is missing') + sane_config = false + end + + if not settings['collect_pubkey'] then + 
rspamd_logger.errx(rspamd_config, 'collect_pubkey is missing') + sane_config = false + end + + if not settings['mirrors'] then + rspamd_logger.errx(rspamd_config, 'collect_pubkey is missing') + sane_config = false + end + + if not fun.all(test_mirror_config, settings['mirrors']) then + sane_config = false + end + + if sane_config then + rspamd_config:add_on_load(function(_, ev_base, worker) + if worker:is_primary_controller() then + rspamd_config:add_periodic(ev_base, 0.0, + function(cfg, _) + return collect_fuzzy_hashes(cfg, ev_base) + end) + end + end) + else + rspamd_logger.errx(rspamd_config, 'module is not configured properly') + end +end diff --git a/src/plugins/lua/greylist.lua b/src/plugins/lua/greylist.lua new file mode 100644 index 0000000..6e221b3 --- /dev/null +++ b/src/plugins/lua/greylist.lua @@ -0,0 +1,542 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2016, Alexey Savelyev <info@homeweb.ru> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +--[[ +Example domains whitelist config: +greylist { + # Search "example.com" and "mail.example.com" for "mx.out.mail.example.com": + whitelist_domains_url = [ + "$LOCAL_CONFDIR/local.d/maps.d/greylist-whitelist-domains.inc", + "${CONFDIR}/maps.d/maillist.inc", + "${CONFDIR}/maps.d/redirectors.inc", + "${CONFDIR}/maps.d/dmarc_whitelist.inc", + "${CONFDIR}/maps.d/spf_dkim_whitelist.inc", + "${CONFDIR}/maps.d/surbl-whitelist.inc", + "https://maps.rspamd.com/freemail/free.txt.zst" + ]; +} +Example config for exim users: +greylist { + action = "greylist"; +} +--]] + +if confighelp then + rspamd_config:add_example(nil, 'greylist', + "Performs adaptive greylisting using Redis", + [[ +greylist { + # Buckets expire (1 day by default) + expire = 1d; + # Greylisting timeout + timeout = 5m; + # Redis prefix + key_prefix = 'rg'; + # Use body hash up to this value of bytes for greylisting + max_data_len = 10k; + # Default greylisting message + message = 'Try again later'; + # Append symbol on greylisting + symbol = 'GREYLIST'; + # Default action change (for Exim use `greylist`) + action = 'soft reject'; + # Skip greylisting if one of the following symbols has been found + whitelist_symbols = []; + # Mask bits for ipv4 + ipv4_mask = 19; + # Mask bits for ipv6 + ipv6_mask = 64; + # Tell when greylisting is expired (appended to `message`) + report_time = false; + # Greylist local messages + check_local = false; + # Greylist messages from authenticated users + check_authed = false; +} + ]]) + return +end + +-- A plugin that implements greylisting using redis + +local redis_params +local whitelisted_ip +local whitelist_domains_map +local toint = math.ifloor or math.floor +local settings = { + expire = 86400, -- 1 day by default + timeout = 300, -- 5 minutes by default + key_prefix = 'rg', -- default hash name + max_data_len = 10240, -- default data limit to hash + message = 'Try again later', -- default greylisted message + symbol = 'GREYLIST', + action = 'soft reject', -- 
default greylisted action + whitelist_symbols = {}, -- whitelist when specific symbols have been found + ipv4_mask = 19, -- Mask bits for ipv4 + ipv6_mask = 64, -- Mask bits for ipv6 + report_time = false, -- Tell when greylisting is expired (appended to `message`) + check_local = false, + check_authed = false, +} + +local rspamd_logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local lua_redis = require "lua_redis" +local lua_util = require "lua_util" +local fun = require "fun" +local hash = require "rspamd_cryptobox_hash" +local rspamd_lua_utils = require "lua_util" +local lua_map = require "lua_maps" +local N = "greylist" + +local function data_key(task) + local cached = task:get_mempool():get_variable("grey_bodyhash") + if cached then + return cached + end + + local body = task:get_rawbody() + + if not body then + return nil + end + + local len = body:len() + if len > settings['max_data_len'] then + len = settings['max_data_len'] + end + + local h = hash.create() + h:update(body, len) + + local b32 = settings['key_prefix'] .. 'b' .. h:base32():sub(1, 20) + task:get_mempool():set_variable("grey_bodyhash", b32) + return b32 +end + +local function envelope_key(task) + local cached = task:get_mempool():get_variable("grey_metahash") + if cached then + return cached + end + + local from = task:get_from('smtp') + local h = hash.create() + + local addr = '<>' + if from and from[1] then + addr = from[1]['addr'] + end + + h:update(addr) + local rcpt = task:get_recipients('smtp') + if rcpt then + table.sort(rcpt, function(r1, r2) + return r1['addr'] < r2['addr'] + end) + + fun.each(function(r) + h:update(r['addr']) + end, rcpt) + end + + local ip = task:get_ip() + + if ip and ip:is_valid() then + local s + if ip:get_version() == 4 then + s = tostring(ip:apply_mask(settings['ipv4_mask'])) + else + s = tostring(ip:apply_mask(settings['ipv6_mask'])) + end + h:update(s) + end + + local b32 = settings['key_prefix'] .. 'm' .. 
h:base32():sub(1, 20) + task:get_mempool():set_variable("grey_metahash", b32) + return b32 +end + +-- Returns pair of booleans: found,greylisted +local function check_time(task, tm, type, now) + local t = tonumber(tm) + + if not t then + rspamd_logger.errx(task, 'not a valid number: %s', tm) + return false, false + end + + if now - t < settings['timeout'] then + return true, true + else + -- We just set variable to pass when in post-filter stage + task:get_mempool():set_variable("grey_whitelisted", type) + + return true, false + end +end + +local function greylist_message(task, end_time, why) + task:insert_result(settings['symbol'], 0.0, 'greylisted', end_time, why) + + if not settings.check_local and rspamd_lua_utils.is_rspamc_or_controller(task) then + return + end + + if settings.message_func then + task:set_pre_result(settings['action'], + settings.message_func(task, end_time), N) + else + local message = settings['message'] + if settings.report_time then + message = string.format("%s: %s", message, end_time) + end + task:set_pre_result(settings['action'], message, N) + end + + task:set_flag('greylisted') +end + +local function greylist_check(task) + local ip = task:get_ip() + + if ((not settings.check_authed and task:get_user()) or + (not settings.check_local and ip and ip:is_local())) then + rspamd_logger.infox(task, "skip greylisting for local networks and/or authorized users"); + return + end + + if ip and ip:is_valid() and whitelisted_ip then + if whitelisted_ip:get_key(ip) then + -- Do not check whitelisted ip + rspamd_logger.infox(task, 'skip greylisting for whitelisted IP') + return + end + end + + local body_key = data_key(task) + local meta_key = envelope_key(task) + local hash_key = body_key .. 
meta_key + + local function redis_get_cb(err, data) + local ret_body = false + local greylisted_body = false + local ret_meta = false + local greylisted_meta = false + + if data then + local end_time_body, end_time_meta + local now = rspamd_util.get_time() + + if data[1] and type(data[1]) ~= 'userdata' then + local tm = tonumber(data[1]) or now + ret_body, greylisted_body = check_time(task, data[1], 'body', now) + if greylisted_body then + end_time_body = tm + settings['timeout'] + task:get_mempool():set_variable("grey_greylisted_body", + rspamd_util.time_to_string(end_time_body)) + end + end + + if data[2] and type(data[2]) ~= 'userdata' then + if not ret_body or greylisted_body then + local tm = tonumber(data[2]) or now + ret_meta, greylisted_meta = check_time(task, data[2], 'meta', now) + + if greylisted_meta then + end_time_meta = tm + settings['timeout'] + task:get_mempool():set_variable("grey_greylisted_meta", + rspamd_util.time_to_string(end_time_meta)) + end + end + end + + local how + local end_time_str + + if not ret_body and not ret_meta then + -- no record found + task:get_mempool():set_variable("grey_greylisted", 'true') + elseif greylisted_body and greylisted_meta then + end_time_str = rspamd_util.time_to_string( + math.min(end_time_body, end_time_meta)) + how = 'meta and body' + elseif greylisted_body then + end_time_str = rspamd_util.time_to_string(end_time_body) + how = 'body only' + elseif greylisted_meta then + end_time_str = rspamd_util.time_to_string(end_time_meta) + how = 'meta only' + end + + if how and end_time_str then + rspamd_logger.infox(task, 'greylisted until "%s" (%s)', + end_time_str, how) + greylist_message(task, end_time_str, 'too early') + end + elseif err then + rspamd_logger.errx(task, 'got error while getting greylisting keys: %1', err) + return + end + end + + local ret = lua_redis.redis_make_request(task, + redis_params, -- connect params + hash_key, -- hash key + false, -- is write + redis_get_cb, --callback + 'MGET', -- 
command + { body_key, meta_key } -- arguments + ) + if not ret then + rspamd_logger.errx(task, 'cannot make redis request to check results') + end +end + +local function greylist_set(task) + local action = task:get_metric_action() + local ip = task:get_ip() + + -- Don't do anything if pre-result has been already set + if task:has_pre_result() then + return + end + + -- Check whitelist_symbols + for _, sym in ipairs(settings.whitelist_symbols) do + if task:has_symbol(sym) then + rspamd_logger.infox(task, 'skip greylisting as we have found symbol %s', sym) + if action == 'greylist' then + -- We are going to accept message + rspamd_logger.infox(task, 'downgrading metric action from "greylist" to "no action"') + task:disable_action('greylist') + end + return + end + end + + if settings.greylist_min_score then + local score = task:get_metric_score('default')[1] + if score < settings.greylist_min_score then + rspamd_logger.infox(task, 'Score too low - skip greylisting') + if action == 'greylist' then + -- We are going to accept message + rspamd_logger.infox(task, 'Downgrading metric action from "greylist" to "no action"') + task:disable_action('greylist') + end + return + end + end + + if ((not settings.check_authed and task:get_user()) or + (not settings.check_local and ip and ip:is_local())) then + if action == 'greylist' then + -- We are going to accept message + rspamd_logger.infox(task, 'Downgrading metric action from "greylist" to "no action"') + task:disable_action('greylist') + end + return + end + + if ip and ip:is_valid() and whitelisted_ip then + if whitelisted_ip:get_key(ip) then + if action == 'greylist' then + -- We are going to accept message + rspamd_logger.infox(task, 'Downgrading metric action from "greylist" to "no action"') + task:disable_action('greylist') + end + return + end + end + + local is_whitelisted = task:get_mempool():get_variable("grey_whitelisted") + local do_greylisting = task:get_mempool():get_variable("grey_greylisted") + local 
do_greylisting_required = task:get_mempool():get_variable("grey_greylisted_required") + + -- Third and second level domains whitelist + if not is_whitelisted and whitelist_domains_map then + local hostname = task:get_hostname() + if hostname then + local domain = rspamd_util.get_tld(hostname) + if whitelist_domains_map:get_key(hostname) or (domain and whitelist_domains_map:get_key(domain)) then + is_whitelisted = 'meta' + rspamd_logger.infox(task, 'skip greylisting for whitelisted domain') + end + end + end + + if action == 'reject' or + not do_greylisting_required and action == 'no action' then + return + end + local body_key = data_key(task) + local meta_key = envelope_key(task) + local upstream, ret, conn + local hash_key = body_key .. meta_key + + local function redis_set_cb(err) + if err then + rspamd_logger.errx(task, 'got error %s when setting greylisting record on server %s', + err, upstream:get_addr()) + end + end + + local is_rspamc = rspamd_lua_utils.is_rspamc_or_controller(task) + + if is_whitelisted then + if action == 'greylist' then + -- We are going to accept message + rspamd_logger.infox(task, 'Downgrading metric action from "greylist" to "no action"') + task:disable_action('greylist') + end + + task:insert_result(settings['symbol'], 0.0, 'pass', is_whitelisted) + rspamd_logger.infox(task, 'greylisting pass (%s) until %s', + is_whitelisted, + rspamd_util.time_to_string(rspamd_util.get_time() + settings['expire'])) + + if not settings.check_local and is_rspamc then + return + end + + ret, conn, upstream = lua_redis.redis_make_request(task, + redis_params, -- connect params + hash_key, -- hash key + true, -- is write + redis_set_cb, --callback + 'EXPIRE', -- command + { body_key, tostring(toint(settings['expire'])) } -- arguments + ) + -- Update greylisting record expire + if ret then + conn:add_cmd('EXPIRE', { + meta_key, tostring(toint(settings['expire'])) + }) + else + rspamd_logger.errx(task, 'got error while connecting to redis') + end + elseif 
do_greylisting or do_greylisting_required then + if not settings.check_local and is_rspamc then + return + end + local t = tostring(toint(rspamd_util.get_time())) + local end_time = rspamd_util.time_to_string(t + settings['timeout']) + rspamd_logger.infox(task, 'greylisted until "%s", new record', end_time) + greylist_message(task, end_time, 'new record') + -- Create new record + ret, conn, upstream = lua_redis.redis_make_request(task, + redis_params, -- connect params + hash_key, -- hash key + true, -- is write + redis_set_cb, --callback + 'SETEX', -- command + { body_key, tostring(toint(settings['expire'])), t } -- arguments + ) + + if ret then + conn:add_cmd('SETEX', { + meta_key, tostring(toint(settings['expire'])), t + }) + else + rspamd_logger.errx(task, 'got error while connecting to redis') + end + else + if action ~= 'no action' and action ~= 'reject' then + local grey_res = task:get_mempool():get_variable("grey_greylisted_body") + + if grey_res then + -- We need to delay message, hence set a temporary result + rspamd_logger.infox(task, 'greylisting delayed until "%s": body', grey_res) + greylist_message(task, grey_res, 'body') + else + grey_res = task:get_mempool():get_variable("grey_greylisted_meta") + if grey_res then + greylist_message(task, grey_res, 'meta') + end + end + else + task:insert_result(settings['symbol'], 0.0, 'greylisted', 'passed') + end + end +end + +local opts = rspamd_config:get_all_opt('greylist') +if opts then + if opts['message_func'] then + settings.message_func = assert(load(opts['message_func']))() + end + + for k, v in pairs(opts) do + if k ~= 'message_func' then + settings[k] = v + end + end + + local auth_and_local_conf = lua_util.config_check_local_or_authed(rspamd_config, N, + false, false) + settings.check_local = auth_and_local_conf[1] + settings.check_authed = auth_and_local_conf[2] + + if settings['greylist_min_score'] then + settings['greylist_min_score'] = tonumber(settings['greylist_min_score']) + else + local 
greylist_threshold = rspamd_config:get_metric_action('greylist') + if greylist_threshold then + settings['greylist_min_score'] = greylist_threshold + end + end + + whitelisted_ip = lua_map.rspamd_map_add(N, 'whitelisted_ip', 'radix', + 'Greylist whitelist ip map') + whitelist_domains_map = lua_map.rspamd_map_add(N, 'whitelist_domains_url', + 'map', 'Greylist whitelist domains map') + + redis_params = lua_redis.parse_redis_server(N) + if not redis_params then + rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module') + rspamd_lua_utils.disable_module(N, "redis") + else + lua_redis.register_prefix(settings.key_prefix .. 'b[a-z0-9]{20}', N, + 'Greylisting elements (body hashes)"', { + type = 'string', + }) + lua_redis.register_prefix(settings.key_prefix .. 'm[a-z0-9]{20}', N, + 'Greylisting elements (meta hashes)"', { + type = 'string', + }) + rspamd_config:register_symbol({ + name = 'GREYLIST_SAVE', + type = 'postfilter', + callback = greylist_set, + priority = lua_util.symbols_priorities.medium, + augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) }, + }) + local id = rspamd_config:register_symbol({ + name = 'GREYLIST_CHECK', + type = 'prefilter', + callback = greylist_check, + priority = lua_util.symbols_priorities.medium, + augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) } + }) + rspamd_config:register_symbol({ + name = settings.symbol, + type = 'virtual', + parent = id, + score = 0, + }) + end +end diff --git a/src/plugins/lua/hfilter.lua b/src/plugins/lua/hfilter.lua new file mode 100644 index 0000000..8c132f5 --- /dev/null +++ b/src/plugins/lua/hfilter.lua @@ -0,0 +1,622 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2013-2015, Alexey Savelyev <info@homeweb.ru> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + + +-- Weight for checks_hellohost and checks_hello: 5 - very hard, 4 - hard, 3 - medium, 2 - low, 1 - very low. +-- From HFILTER_HELO_* and HFILTER_HOSTNAME_* symbols the maximum weight is selected in case of their actuating. + +if confighelp then + return +end + +local rspamd_regexp = require "rspamd_regexp" +local lua_util = require "lua_util" +local rspamc_local_helo = "rspamc.local" +local checks_hellohost = [[ +/[-.0-9][0-9][.-]?nat/i 5 +/homeuser[.-][0-9]/i 5 +/[-.0-9][0-9][.-]?unused-addr/i 3 +/[-.0-9][0-9][.-]?pppoe/i 5 +/[-.0-9][0-9][.-]?dynamic/i 5 +/[.-]catv[.-]/i 5 +/unused-addr[.-][0-9]/i 3 +/comcast[.-][0-9]/i 5 +/[.-]broadband[.-]/i 5 +/[0-9][.-]?fbx/i 4 +/[.-]peer[.-]/i 1 +/[.-]homeuser[.-]/i 5 +/[-.0-9][0-9][.-]?catv/i 5 +/customers?[.-][0-9]/i 1 +/[.-]wifi[.-]/i 5 +/[0-9][.-]?kabel/i 3 +/dynip[.-][0-9]/i 5 +/[.-]broad[.-]/i 5 +/[a|x]?dsl-line[.-]?[0-9]/i 4 +/[-.0-9][0-9][.-]?ppp/i 5 +/pool[.-][0-9]/i 4 +/[.-]nat[.-]/i 5 +/gprs[.-][0-9]/i 5 +/brodband[.-][0-9]/i 5 +/[.-]gprs[.-]/i 5 +/[.-]user[.-]/i 1 +/[-.0-9][0-9][.-]?in-?addr/i 4 +/[.-]host[.-]/i 2 +/[.-]fbx[.-]/i 4 +/dynamic[.-][0-9]/i 5 +/[-.0-9][0-9][.-]?peer/i 1 +/[-.0-9][0-9][.-]?pool/i 4 +/[-.0-9][0-9][.-]?user/i 1 +/[.-]cdma[.-]/i 5 +/user[.-][0-9]/i 1 +/[-.0-9][0-9][.-]?customers?/i 1 +/ppp[.-][0-9]/i 5 +/kabel[.-][0-9]/i 3 +/dhcp[.-][0-9]/i 5 +/peer[.-][0-9]/i 1 +/[-.0-9][0-9][.-]?host/i 2 +/clients?[.-][0-9]{2,}/i 5 +/host[.-][0-9]/i 2 +/[.-]ppp[.-]/i 5 +/[.-]dhcp[.-]/i 5 +/[.-]comcast[.-]/i 5 +/cable[.-][0-9]/i 3 +/[-.0-9][0-9][.-]?dial-?up/i 5 +/[-.0-9][0-9][.-]?bredband/i 5 
+/[-.0-9][0-9][.-]?[a|x]?dsl-line/i 4 +/[.-]dial-?up[.-]/i 5 +/[.-]cablemodem[.-]/i 5 +/pppoe[.-][0-9]/i 5 +/[.-]unused-addr[.-]/i 3 +/pptp[.-][0-9]/i 5 +/broadband[.-][0-9]/i 5 +/[.-][a|x]?dsl-line[.-]/i 4 +/[.-]customers?[.-]/i 1 +/[-.0-9][0-9][.-]?fibertel/i 4 +/[-.0-9][0-9][.-]?comcast/i 5 +/[.-]dynamic[.-]/i 5 +/cdma[.-][0-9]/i 5 +/[0-9][.-]?broad/i 5 +/fbx[.-][0-9]/i 4 +/catv[.-][0-9]/i 5 +/[-.0-9][0-9][.-]?homeuser/i 5 +/[-.0-9][.-]pppoe[.-]/i 5 +/[-.0-9][.-]dynip[.-]/i 5 +/[-.0-9][0-9][.-]?[a|x]?dsl/i 4 +/[-.0-9][0-9]{3,}[.-]?clients?/i 5 +/[-.0-9][0-9][.-]?pptp/i 5 +/[.-]clients?[.-]/i 1 +/[.-]in-?addr[.-]/i 4 +/[.-]pool[.-]/i 4 +/[a|x]?dsl[.-]?[0-9]/i 4 +/[.-][a|x]?dsl[.-]/i 4 +/[-.0-9][0-9][.-]?[a|x]?dsl-dynamic/i 5 +/dial-?up[.-][0-9]/i 5 +/[-.0-9][0-9][.-]?cablemodem/i 5 +/[a|x]?dsl-dynamic[.-]?[0-9]/i 5 +/[.-]pptp[.-]/i 5 +/[.-][a|x]?dsl-dynamic[.-]/i 5 +/[0-9][.-]?wifi/i 5 +/fibertel[.-][0-9]/i 4 +/dyn[.-][0-9][-.0-9]/i 5 +/[-.0-9][0-9][.-]broadband/i 5 +/[-.0-9][0-9][.-]cable/i 3 +/broad[.-][0-9]/i 5 +/[-.0-9][0-9][.-]gprs/i 5 +/cablemodem[.-][0-9]/i 5 +/[-.0-9][0-9][.-]modem/i 5 +/[-.0-9][0-9][.-]dyn/i 5 +/[-.0-9][0-9][.-]dynip/i 5 +/[-.0-9][0-9][.-]cdma/i 5 +/[.-]modem[.-]/i 5 +/[.-]kabel[.-]/i 3 +/[.-]cable[.-]/i 3 +/in-?addr[.-][0-9]/i 4 +/nat[.-][0-9]/i 5 +/[.-]fibertel[.-]/i 4 +/[.-]bredband[.-]/i 5 +/modem[.-][0-9]/i 5 +/[0-9][.-]?dhcp/i 5 +/wifi[.-][0-9]/i 5 +]] +local checks_hellohost_map + +local checks_hello = [[ +/^[^\.]+$/i 5 # for helo=COMPUTER, ANNA, etc... 
Without dot in helo +/^(dsl)?(device|speedtouch)\.lan$/i 5 +/\.(lan|local|home|localdomain|intra|in-addr.arpa|priv|user|veloxzon)$/i 5 +]] +local checks_hello_map + +local checks_hello_badip = [[ +/^\d\.\d\.\d\.255$/i 1 +/^192\.0\.0\./i 1 +/^2001:db8::/i 1 +/^10\./i 1 +/^192\.0\.2\./i 1 +/^172\.1[6-9]\./i 1 +/^192\.168\./i 1 +/^::1$/i 1 # loopback ipv4, ipv6 +/^ffxx::/i 1 +/^fc00::/i 1 +/^203\.0\.113\./i 1 +/^fe[cdf][0-9a-f]:/i 1 +/^100.12[0-7]\d\./i 1 +/^fe[89ab][0-9a-f]::/i 1 +/^169\.254\./i 1 +/^0\./i 1 +/^198\.51\.100\./i 1 +/^172\.3[01]\./i 1 +/^100.[7-9]\d\./i 1 +/^100.1[01]\d\./i 1 +/^127\./i 1 +/^100.6[4-9]\./i 1 +/^192\.88\.99\./i 1 +/^172\.2[0-9]\./i 1 +]] +local checks_hello_badip_map + +local checks_hello_bareip = [[ +/^\d+[x.-]\d+[x.-]\d+[x.-]\d+$/ +/^[0-9a-f]+:/ +]] +local checks_hello_bareip_map + +local config = { + ['helo_enabled'] = false, + ['hostname_enabled'] = false, + ['from_enabled'] = false, + ['rcpt_enabled'] = false, + ['mid_enabled'] = false, + ['url_enabled'] = false +} + +local compiled_regexp = {} -- cache of regexps +local check_local = false +local check_authed = false +local N = "hfilter" + +local function check_regexp(str, regexp_text) + local re = compiled_regexp[regexp_text] + if not re then + re = rspamd_regexp.create(regexp_text, 'i') + compiled_regexp[regexp_text] = re + end + + return re:match(str) +end + +local function add_static_map(data) + return rspamd_config:add_map { + type = 'regexp_multi', + url = { + upstreams = 'static', + data = data, + } + } +end + +local function check_fqdn(domain) + if check_regexp(domain, + '(?=^.{4,253}$)(^((?!-)[a-zA-Z0-9-]{1,63}(?<!-)\\.)+[a-zA-Z0-9-]{2,63}\\.?$)') then + return true + end + return false +end + +-- host: host for check +-- symbol_suffix: suffix for symbol +-- eq_ip: ip for comparing or empty string +-- eq_host: host for comparing or empty string +local function check_host(task, host, symbol_suffix, eq_ip, eq_host) + local failed_address = 0 + local resolved_address = {} + 
+ local function check_host_cb_mx(_, to_resolve, results, err) + if err and (err ~= 'requested record is not found' and err ~= 'no records with this name') then + lua_util.debugm(N, task, 'error looking up %s: %s', to_resolve, err) + end + if not results then + task:insert_result('HFILTER_' .. symbol_suffix .. '_NORES_A_OR_MX', 1.0, + to_resolve) + else + for _, mx in pairs(results) do + if mx['name'] then + local failed_mx_address = 0 + -- Capture failed_mx_address + local function check_host_cb_mx_a(_, _, mx_results) + if not mx_results then + failed_mx_address = failed_mx_address + 1 + end + + if failed_mx_address >= 2 then + task:insert_result('HFILTER_' .. symbol_suffix .. '_NORESOLVE_MX', + 1.0, mx['name']) + end + end + + task:get_resolver():resolve('a', { + task = task, + name = mx['name'], + callback = check_host_cb_mx_a + }) + task:get_resolver():resolve('aaaa', { + task = task, + name = mx['name'], + callback = check_host_cb_mx_a + }) + end + end + end + end + local function check_host_cb_a(_, _, results) + if not results then + failed_address = failed_address + 1 + else + for _, result in pairs(results) do + table.insert(resolved_address, result:to_string()) + end + end + + if failed_address >= 2 then + -- No A or AAAA records + if eq_ip and eq_ip ~= '' then + for _, result in pairs(resolved_address) do + if result == eq_ip then + return true + end + end + task:insert_result('HFILTER_' .. symbol_suffix .. 
'_IP_A', 1.0, host) + end + task:get_resolver():resolve_mx({ + task = task, + name = host, + callback = check_host_cb_mx + }) + end + end + + if host then + host = string.lower(host) + else + return false + end + if eq_host then + eq_host = string.lower(eq_host) + else + eq_host = '' + end + + if check_fqdn(host) then + if eq_host == '' or eq_host ~= host then + task:get_resolver():resolve('a', { + task = task, + name = host, + callback = check_host_cb_a + }) + -- Check ipv6 as well + task:get_resolver():resolve('aaaa', { + task = task, + name = host, + callback = check_host_cb_a + }) + end + else + task:insert_result('HFILTER_' .. symbol_suffix .. '_NOT_FQDN', 1.0, host) + end + + return true +end + +-- +local function hfilter_callback(task) + -- Links checks + if config['url_enabled'] then + local parts = task:get_text_parts() + if parts then + local plain_text_part, html_text_part + + for _, p in ipairs(parts) do + if p:is_html() then + html_text_part = p + else + plain_text_part = p + end + end + + local function check_text_part(part, ty) + local url_len = part:get_urls_length() + local plen = part:get_length() + + if plen > 0 and url_len > 0 then + local rel = url_len / plen + if rel > 0.8 then + local sc = (rel - 0.8) * 5.0 + if sc > 1.0 then + sc = 1.0 + end + task:insert_result('HFILTER_URL_ONLY', sc, tostring(sc)) + local lines = part:get_lines_count() + if lines > 0 and lines < 2 then + task:insert_result('HFILTER_URL_ONELINE', 1.00, + string.format('%s:%d:%d', ty, math.floor(rel), lines)) + end + end + end + end + if html_text_part then + check_text_part(html_text_part, 'html') + elseif plain_text_part then + check_text_part(plain_text_part, 'plain') + end + end + end + + --No more checks for auth user or local network + local rip = task:get_from_ip() + if ((not check_authed and task:get_user()) or + (not check_local and rip and rip:is_local())) then + return false + end + + --local message = task:get_message() + local ip = false + if rip and 
rip:is_valid() then + ip = rip:to_string() + end + + -- Check's HELO + local weight_helo = 0 + local helo + if config['helo_enabled'] then + helo = task:get_helo() + if helo then + if helo ~= rspamc_local_helo then + helo = string.gsub(helo, '[%[%]]', '') + -- Regexp check HELO (checks_hello_badip) + local find_badip = false + local values = checks_hello_badip_map:get_key(helo) + if values then + task:insert_result('HFILTER_HELO_BADIP', 1.0, helo, values) + find_badip = true + end + + -- Regexp check HELO (checks_hello_bareip) + local find_bareip = false + if not find_badip then + values = checks_hello_bareip_map:get_key(helo) + if values then + task:insert_result('HFILTER_HELO_BAREIP', 1.0, helo, values) + find_bareip = true + end + end + + if not find_badip and not find_bareip then + -- Regexp check HELO (checks_hello) + local weights = checks_hello_map:get_key(helo) + for _, weight in ipairs(weights or {}) do + weight = tonumber(weight) or 0 + if weight > weight_helo then + weight_helo = weight + end + end + -- Regexp check HELO (checks_hellohost) + weights = checks_hellohost_map:get_key(helo) + for _, weight in ipairs(weights or {}) do + weight = tonumber(weight) or 0 + if weight > weight_helo then + weight_helo = weight + end + end + --FQDN check HELO + if ip and helo and weight_helo == 0 then + check_host(task, helo, 'HELO', ip) + end + end + end + end + end + + -- Check's HOSTNAME + local weight_hostname = 0 + local hostname = task:get_hostname() + + if config['hostname_enabled'] then + if hostname then + -- Check regexp HOSTNAME + local weights = checks_hellohost_map:get_key(hostname) + for _, weight in ipairs(weights or {}) do + weight = tonumber(weight) or 0 + if weight > weight_hostname then + weight_hostname = weight + end + end + else + task:insert_result('HFILTER_HOSTNAME_UNKNOWN', 1.00) + end + end + + --Insert weight's for HELO or HOSTNAME + if weight_helo > 0 and weight_helo >= weight_hostname then + task:insert_result('HFILTER_HELO_' .. 
weight_helo, 1.0, helo) + elseif weight_hostname > 0 and weight_hostname > weight_helo then + task:insert_result('HFILTER_HOSTNAME_' .. weight_hostname, 1.0, hostname) + end + + -- MAILFROM checks -- + local frombounce = false + if config['from_enabled'] then + local from = task:get_from(1) + if from then + --FROM host check + for _, fr in ipairs(from) do + local fr_split = rspamd_str_split(fr['addr'], '@') + if #fr_split == 2 then + check_host(task, fr_split[2], 'FROMHOST', '', '') + if fr_split[1] == 'postmaster' then + frombounce = true + end + end + end + else + if helo and helo ~= rspamc_local_helo then + task:insert_result('HFILTER_FROM_BOUNCE', 1.00, helo) + frombounce = true + end + end + end + + -- Recipients checks -- + if config['rcpt_enabled'] then + local rcpt = task:get_recipients() + if rcpt then + local count_rcpt = #rcpt + if frombounce then + if count_rcpt > 1 then + task:insert_result('HFILTER_RCPT_BOUNCEMOREONE', 1.00, + tostring(count_rcpt)) + end + end + end + end + + --Message ID host check + if config['mid_enabled'] then + local message_id = task:get_message_id() + if message_id then + local mid_split = rspamd_str_split(message_id, '@') + if #mid_split == 2 and not string.find(mid_split[2], 'local') then + check_host(task, mid_split[2], 'MID') + end + end + end + + return false +end + +local symbols_enabled = {} + +local symbols_helo = { + "HFILTER_HELO_BAREIP", + "HFILTER_HELO_BADIP", + "HFILTER_HELO_1", + "HFILTER_HELO_2", + "HFILTER_HELO_3", + "HFILTER_HELO_4", + "HFILTER_HELO_5", + "HFILTER_HELO_NORESOLVE_MX", + "HFILTER_HELO_NORES_A_OR_MX", + "HFILTER_HELO_IP_A", + "HFILTER_HELO_NOT_FQDN" +} +local symbols_hostname = { + "HFILTER_HOSTNAME_1", + "HFILTER_HOSTNAME_2", + "HFILTER_HOSTNAME_3", + "HFILTER_HOSTNAME_4", + "HFILTER_HOSTNAME_5", + "HFILTER_HOSTNAME_UNKNOWN" +} +local symbols_rcpt = { + "HFILTER_RCPT_BOUNCEMOREONE" +} +local symbols_mid = { + "HFILTER_MID_NORESOLVE_MX", + "HFILTER_MID_NORES_A_OR_MX", + "HFILTER_MID_NOT_FQDN" +} 
+local symbols_url = { + "HFILTER_URL_ONLY", + "HFILTER_URL_ONELINE" +} +local symbols_from = { + "HFILTER_FROMHOST_NORESOLVE_MX", + "HFILTER_FROMHOST_NORES_A_OR_MX", + "HFILTER_FROMHOST_NOT_FQDN", + "HFILTER_FROM_BOUNCE" +} + +local auth_and_local_conf = lua_util.config_check_local_or_authed(rspamd_config, N, + false, false) +check_local = auth_and_local_conf[1] +check_authed = auth_and_local_conf[2] +local timeout = 0.0 + +local opts = rspamd_config:get_all_opt('hfilter') +if opts then + for k, v in pairs(opts) do + config[k] = v + end +end + +local function append_t(t, a) + for _, v in ipairs(a) do + table.insert(t, v) + end +end +if config['helo_enabled'] then + checks_hello_bareip_map = add_static_map(checks_hello_bareip) + checks_hello_badip_map = add_static_map(checks_hello_badip) + checks_hellohost_map = add_static_map(checks_hellohost) + checks_hello_map = add_static_map(checks_hello) + append_t(symbols_enabled, symbols_helo) + timeout = math.max(timeout, rspamd_config:get_dns_timeout() * 3) +end +if config['hostname_enabled'] then + if not checks_hellohost_map then + checks_hellohost_map = add_static_map(checks_hellohost) + end + append_t(symbols_enabled, symbols_hostname) + timeout = math.max(timeout, rspamd_config:get_dns_timeout()) +end +if config['from_enabled'] then + append_t(symbols_enabled, symbols_from) + timeout = math.max(timeout, rspamd_config:get_dns_timeout()) +end +if config['rcpt_enabled'] then + append_t(symbols_enabled, symbols_rcpt) +end +if config['mid_enabled'] then + append_t(symbols_enabled, symbols_mid) +end +if config['url_enabled'] then + append_t(symbols_enabled, symbols_url) +end + +--dumper(symbols_enabled) +if #symbols_enabled > 0 then + local id = rspamd_config:register_symbol { + name = 'HFILTER_CHECK', + callback = hfilter_callback, + type = 'callback', + augmentations = { string.format("timeout=%f", timeout) }, + } + for _, sym in ipairs(symbols_enabled) do + rspamd_config:register_symbol { + type = 'virtual', + score = 
1.0, + parent = id, + name = sym, + } + rspamd_config:set_metric_symbol({ + name = sym, + score = 0.0, + group = 'hfilter' + }) + end +else + lua_util.disable_module(N, "config") +end diff --git a/src/plugins/lua/history_redis.lua b/src/plugins/lua/history_redis.lua new file mode 100644 index 0000000..d0aa5ae --- /dev/null +++ b/src/plugins/lua/history_redis.lua @@ -0,0 +1,314 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +if confighelp then + rspamd_config:add_example(nil, 'history_redis', + "Store history of checks for WebUI using Redis", + [[ +redis_history { + # History key name + key_prefix = 'rs_history'; + # History expire in seconds + expire = 0; + # History rows limit + nrows = 200; + # Use zstd compression when storing data in redis + compress = true; + # Obfuscate subjects for privacy + subject_privacy = false; + # Default hash-algorithm to obfuscate subject + subject_privacy_alg = 'blake2'; + # Prefix to show it's obfuscated + subject_privacy_prefix = 'obf'; + # Cut the length of the hash if desired + subject_privacy_length = 16; +} + ]]) + return +end + +local rspamd_logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local lua_util = require "lua_util" +local lua_redis = require "lua_redis" +local fun = require "fun" +local ucl = require "ucl" +local ts = (require "tableshape").types +local lua_verdict = require "lua_verdict" +local E = {} +local N = "history_redis" +local 
hostname = rspamd_util.get_hostname() + +local redis_params + +local settings = { + key_prefix = 'rs_history', -- default key name + expire = nil, -- default no expire + nrows = 200, -- default rows limit + compress = true, -- use zstd compression when storing data in redis + subject_privacy = false, -- subject privacy is off + subject_privacy_alg = 'blake2', -- default hash-algorithm to obfuscate subject + subject_privacy_prefix = 'obf', -- prefix to show it's obfuscated + subject_privacy_length = 16, -- cut the length of the hash +} + +local settings_schema = lua_redis.enrich_schema({ + key_prefix = ts.string, + expire = (ts.number + ts.string / lua_util.parse_time_interval):is_optional(), + nrows = ts.number, + compress = ts.boolean, + subject_privacy = ts.boolean:is_optional(), + subject_privacy_alg = ts.string:is_optional(), + subject_privacy_prefix = ts.string:is_optional(), + subject_privacy_length = ts.number:is_optional(), +}) + +local function process_addr(addr) + if addr then + return addr.addr + end + + return 'unknown' +end + +local function normalise_results(tbl, task) + local metric = tbl.default + + -- Convert stupid metric object + if metric then + tbl.symbols = {} + local symbols, others = fun.partition(function(_, v) + return type(v) == 'table' and v.score + end, metric) + + fun.each(function(k, v) + v.name = nil; + tbl.symbols[k] = v; + end, symbols) + fun.each(function(k, v) + tbl[k] = v + end, others) + + -- Reset the original metric + tbl.default = nil + end + + -- Now, add recipients and senders + tbl.sender_smtp = process_addr((task:get_from('smtp') or E)[1]) + tbl.sender_mime = process_addr((task:get_from('mime') or E)[1]) + tbl.rcpt_smtp = fun.totable(fun.map(process_addr, task:get_recipients('smtp') or {})) + tbl.rcpt_mime = fun.totable(fun.map(process_addr, task:get_recipients('mime') or {})) + tbl.user = task:get_user() or 'unknown' + tbl.rmilter = nil + tbl.messages = nil + tbl.urls = nil + tbl.action = 
lua_verdict.adjust_passthrough_action(task) + + local seconds = task:get_timeval()['tv_sec'] + tbl.unix_time = seconds + + local subject = task:get_header('subject') or 'unknown' + tbl.subject = lua_util.maybe_obfuscate_string(subject, settings, 'subject') + tbl.size = task:get_size() + local ip = task:get_from_ip() + if ip and ip:is_valid() then + tbl.ip = tostring(ip) + else + tbl.ip = 'unknown' + end + + tbl.user = task:get_user() or 'unknown' +end + +local function history_save(task) + local function redis_llen_cb(err, _) + if err then + rspamd_logger.errx(task, 'got error %s when writing history row: %s', + err) + end + end + + -- We skip saving it to the history + if task:has_flag('no_log') then + return + end + + local data = task:get_protocol_reply { 'metrics', 'basic' } + local prefix = settings.key_prefix .. hostname + + if data then + normalise_results(data, task) + else + rspamd_logger.errx('cannot get protocol reply, skip saving in history') + return + end + + local json = ucl.to_format(data, 'json-compact') + + if settings.compress then + json = rspamd_util.zstd_compress(json) + -- Distinguish between compressed and non-compressed options + prefix = prefix .. '_zst' + end + + local ret, conn, _ = lua_redis.rspamd_redis_make_request(task, + redis_params, -- connect params + nil, -- hash key + true, -- is write + redis_llen_cb, --callback + 'LPUSH', -- command + { prefix, json } -- arguments + ) + + if ret then + conn:add_cmd('LTRIM', { prefix, '0', string.format('%d', settings.nrows - 1) }) + + if settings.expire and settings.expire > 0 then + conn:add_cmd('EXPIRE', { prefix, string.format('%d', settings.expire) }) + end + end +end + +local function handle_history_request(task, conn, from, to, reset) + local prefix = settings.key_prefix .. hostname + if settings.compress then + -- Distinguish between compressed and non-compressed options + prefix = prefix .. 
'_zst' + end + + if reset then + local function redis_ltrim_cb(err, _) + if err then + rspamd_logger.errx(task, 'got error %s when resetting history: %s', + err) + conn:send_error(504, '{"error": "' .. err .. '"}') + else + conn:send_string('{"success":true}') + end + end + lua_redis.rspamd_redis_make_request(task, + redis_params, -- connect params + nil, -- hash key + true, -- is write + redis_ltrim_cb, --callback + 'LTRIM', -- command + { prefix, '0', '0' } -- arguments + ) + else + local function redis_lrange_cb(err, data) + if data then + local reply = { + version = 2, + } + if settings.compress then + local t1 = rspamd_util:get_ticks() + + data = fun.totable(fun.filter(function(e) + return e ~= nil + end, + fun.map(function(e) + local _, dec = rspamd_util.zstd_decompress(e) + if dec then + return dec + end + return nil + end, data))) + lua_util.debugm(N, task, 'decompress took %s ms', + (rspamd_util:get_ticks() - t1) * 1000.0) + collectgarbage() + end + -- Parse elements using ucl + local t1 = rspamd_util:get_ticks() + data = fun.totable( + fun.map(function(_, obj) + return obj + end, + fun.filter(function(res, obj) + if res then + return true + end + return false + end, + fun.map(function(elt) + local parser = ucl.parser() + local res, _ = parser:parse_text(elt) + + if res then + return true, parser:get_object() + else + return false, nil + end + end, data)))) + lua_util.debugm(N, task, 'parse took %s ms', + (rspamd_util:get_ticks() - t1) * 1000.0) + collectgarbage() + t1 = rspamd_util:get_ticks() + reply.rows = data + conn:send_ucl(reply) + lua_util.debugm(N, task, 'process + sending took %s ms', + (rspamd_util:get_ticks() - t1) * 1000.0) + collectgarbage() + else + rspamd_logger.errx(task, 'got error %s when getting history: %s', + err) + conn:send_error(504, '{"error": "' .. err .. 
'"}') + end + end + lua_redis.rspamd_redis_make_request(task, + redis_params, -- connect params + nil, -- hash key + false, -- is write + redis_lrange_cb, --callback + 'LRANGE', -- command + { prefix, string.format('%d', from), string.format('%d', to) }, -- arguments + { opaque_data = true } + ) + end +end + +local opts = rspamd_config:get_all_opt('history_redis') +if opts then + settings = lua_util.override_defaults(settings, opts) + local res, err = settings_schema:transform(settings) + + if not res then + rspamd_logger.warnx(rspamd_config, '%s: plugin is misconfigured: %s', N, err) + lua_util.disable_module(N, "config") + return + end + settings = res + + redis_params = lua_redis.parse_redis_server('history_redis') + if not redis_params then + rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module') + lua_util.disable_module(N, "redis") + else + rspamd_config:register_symbol({ + name = 'HISTORY_SAVE', + type = 'idempotent', + callback = history_save, + flags = 'empty,explicit_disable,ignore_passthrough', + augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) } + }) + lua_redis.register_prefix(settings.key_prefix .. hostname, N, + "Redis history", { + type = 'list', + }) + rspamd_plugins['history'] = { + handler = handle_history_request + } + end +end diff --git a/src/plugins/lua/http_headers.lua b/src/plugins/lua/http_headers.lua new file mode 100644 index 0000000..1c6494a --- /dev/null +++ b/src/plugins/lua/http_headers.lua @@ -0,0 +1,198 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local logger = require "rspamd_logger" +local ucl = require "ucl" + +local spf_symbols = { + symbol_allow = 'R_SPF_ALLOW', + symbol_deny = 'R_SPF_FAIL', + symbol_softfail = 'R_SPF_SOFTFAIL', + symbol_neutral = 'R_SPF_NEUTRAL', + symbol_tempfail = 'R_SPF_DNSFAIL', + symbol_na = 'R_SPF_NA', + symbol_permfail = 'R_SPF_PERMFAIL', +} + +local dkim_symbols = { + symbol_allow = 'R_DKIM_ALLOW', + symbol_deny = 'R_DKIM_REJECT', + symbol_tempfail = 'R_DKIM_TEMPFAIL', + symbol_na = 'R_DKIM_NA', + symbol_permfail = 'R_DKIM_PERMFAIL', + symbol_trace = 'DKIM_TRACE', +} + +local dkim_trace = { + pass = '+', + fail = '-', + temperror = '?', + permerror = '~', +} + +local dmarc_symbols = { + allow = 'DMARC_POLICY_ALLOW', + badpolicy = 'DMARC_BAD_POLICY', + dnsfail = 'DMARC_DNSFAIL', + na = 'DMARC_NA', + reject = 'DMARC_POLICY_REJECT', + softfail = 'DMARC_POLICY_SOFTFAIL', + quarantine = 'DMARC_POLICY_QUARANTINE', +} + +local opts = rspamd_config:get_all_opt('dmarc') +if opts and opts['symbols'] then + for k, _ in pairs(dmarc_symbols) do + if opts['symbols'][k] then + dmarc_symbols[k] = opts['symbols'][k] + end + end +end + +opts = rspamd_config:get_all_opt('dkim') +if opts then + for k, _ in pairs(dkim_symbols) do + if opts[k] then + dkim_symbols[k] = opts[k] + end + end +end + +opts = rspamd_config:get_all_opt('spf') +if opts then + for k, _ in pairs(spf_symbols) do + if opts[k] then + spf_symbols[k] = opts[k] + end + end +end + +-- Disable DKIM checks if passed via HTTP headers +rspamd_config:add_condition("DKIM_CHECK", function(task) + local hdr = 
task:get_request_header('DKIM') + + if hdr then + local parser = ucl.parser() + local res, err = parser:parse_string(tostring(hdr)) + if not res then + logger.infox(task, "cannot parse DKIM header: %1", err) + return true + end + + local p_obj = parser:get_object() + local results = p_obj['results'] + if not results and p_obj['result'] then + results = { { result = p_obj['result'], domain = 'unknown' } } + end + + if results then + for _, obj in ipairs(results) do + local dkim_domain = obj['domain'] or 'unknown' + if obj['result'] == 'pass' or obj['result'] == 'allow' then + task:insert_result(dkim_symbols['symbol_allow'], 1.0, 'http header') + task:insert_result(dkim_symbols['symbol_trace'], 1.0, + string.format('%s:%s', dkim_domain, dkim_trace.pass)) + elseif obj['result'] == 'fail' or obj['result'] == 'reject' then + task:insert_result(dkim_symbols['symbol_deny'], 1.0, 'http header') + task:insert_result(dkim_symbols['symbol_trace'], 1.0, + string.format('%s:%s', dkim_domain, dkim_trace.fail)) + elseif obj['result'] == 'tempfail' or obj['result'] == 'softfail' then + task:insert_result(dkim_symbols['symbol_tempfail'], 1.0, 'http header') + task:insert_result(dkim_symbols['symbol_trace'], 1.0, + string.format('%s:%s', dkim_domain, dkim_trace.temperror)) + elseif obj['result'] == 'permfail' then + task:insert_result(dkim_symbols['symbol_permfail'], 1.0, 'http header') + task:insert_result(dkim_symbols['symbol_trace'], 1.0, + string.format('%s:%s', dkim_domain, dkim_trace.permerror)) + elseif obj['result'] == 'na' then + task:insert_result(dkim_symbols['symbol_na'], 1.0, 'http header') + end + end + end + end + + return false +end) + +-- Disable SPF checks if passed via HTTP headers +rspamd_config:add_condition("SPF_CHECK", function(task) + local hdr = task:get_request_header('SPF') + + if hdr then + local parser = ucl.parser() + local res, err = parser:parse_string(tostring(hdr)) + if not res then + logger.infox(task, "cannot parse SPF header: %1", err) + return 
true + end + + local obj = parser:get_object() + + if obj['result'] then + if obj['result'] == 'pass' or obj['result'] == 'allow' then + task:insert_result(spf_symbols['symbol_allow'], 1.0, 'http header') + elseif obj['result'] == 'fail' or obj['result'] == 'reject' then + task:insert_result(spf_symbols['symbol_deny'], 1.0, 'http header') + elseif obj['result'] == 'neutral' then + task:insert_result(spf_symbols['symbol_neutral'], 1.0, 'http header') + elseif obj['result'] == 'softfail' then + task:insert_result(spf_symbols['symbol_softfail'], 1.0, 'http header') + elseif obj['result'] == 'permfail' then + task:insert_result(spf_symbols['symbol_permfail'], 1.0, 'http header') + elseif obj['result'] == 'na' then + task:insert_result(spf_symbols['symbol_na'], 1.0, 'http header') + end + end + end + + return false +end) + +rspamd_config:add_condition("DMARC_CALLBACK", function(task) + local hdr = task:get_request_header('DMARC') + + if hdr then + local parser = ucl.parser() + local res, err = parser:parse_string(tostring(hdr)) + if not res then + logger.infox(task, "cannot parse DMARC header: %1", err) + return true + end + + local obj = parser:get_object() + + if obj['result'] then + if obj['result'] == 'pass' or obj['result'] == 'allow' then + task:insert_result(dmarc_symbols['allow'], 1.0, 'http header') + elseif obj['result'] == 'fail' or obj['result'] == 'reject' then + task:insert_result(dmarc_symbols['reject'], 1.0, 'http header') + elseif obj['result'] == 'quarantine' then + task:insert_result(dmarc_symbols['quarantine'], 1.0, 'http header') + elseif obj['result'] == 'tempfail' then + task:insert_result(dmarc_symbols['dnsfail'], 1.0, 'http header') + elseif obj['result'] == 'softfail' or obj['result'] == 'none' then + task:insert_result(dmarc_symbols['softfail'], 1.0, 'http header') + elseif obj['result'] == 'permfail' or obj['result'] == 'badpolicy' then + task:insert_result(dmarc_symbols['badpolicy'], 1.0, 'http header') + elseif obj['result'] == 'na' then + 
task:insert_result(dmarc_symbols['na'], 1.0, 'http header') + end + end + end + + return false +end) + diff --git a/src/plugins/lua/ip_score.lua b/src/plugins/lua/ip_score.lua new file mode 100644 index 0000000..e43fa3b --- /dev/null +++ b/src/plugins/lua/ip_score.lua @@ -0,0 +1,4 @@ +-- This module is deprecated and must not be used. +-- This file serves as a tombstone to prevent old ip_score to be loaded + +return
\ No newline at end of file diff --git a/src/plugins/lua/known_senders.lua b/src/plugins/lua/known_senders.lua new file mode 100644 index 0000000..d26a1df --- /dev/null +++ b/src/plugins/lua/known_senders.lua @@ -0,0 +1,245 @@ +--[[ +Copyright (c) 2023, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +-- This plugin implements known senders logic for Rspamd + +local rspamd_logger = require "rspamd_logger" +local ts = (require "tableshape").types +local N = 'known_senders' +local lua_util = require "lua_util" +local lua_redis = require "lua_redis" +local lua_maps = require "lua_maps" +local rspamd_cryptobox_hash = require "rspamd_cryptobox_hash" + +if confighelp then + rspamd_config:add_example(nil, 'known_senders', + "Maintain a list of known senders using Redis", + [[ +known_senders { + # Domains to track senders + domains = "https://maps.rspamd.com/freemail/free.txt.zst"; + # Maximum number of elements + max_senders = 100000; + # Maximum time to live (when not using bloom filters) + max_ttl = 30d; + # Use bloom filters (must be enabled in Redis as a plugin) + use_bloom = false; + # Insert symbol for new senders from the specific domains + symbol_unknown = 'UNKNOWN_SENDER'; +} + ]]) + return +end + +local redis_params +local settings = { + domains = {}, + max_senders = 100000, + max_ttl = 30 * 86400, + use_bloom = false, + symbol = 'KNOWN_SENDER', + symbol_unknown = 'UNKNOWN_SENDER', + redis_key = 'rs_known_senders', +} + +local 
settings_schema = lua_redis.enrich_schema({ + domains = lua_maps.map_schema, + enabled = ts.boolean:is_optional(), + max_senders = (ts.integer + ts.string / tonumber):is_optional(), + max_ttl = (ts.integer + ts.string / tonumber):is_optional(), + use_bloom = ts.boolean:is_optional(), + redis_key = ts.string:is_optional(), + symbol = ts.string:is_optional(), + symbol_unknown = ts.string:is_optional(), +}) + +local function make_key(input) + local hash = rspamd_cryptobox_hash.create_specific('md5') + hash:update(input.addr) + return hash:hex() +end + +local function check_redis_key(task, key, key_ty) + lua_util.debugm(N, task, 'check key %s, type: %s', key, key_ty) + local function redis_zset_callback(err, data) + lua_util.debugm(N, task, 'got data: %s', data) + if err then + rspamd_logger.errx(task, 'redis error: %s', err) + elseif data then + if type(data) ~= 'userdata' then + -- non-null reply + task:insert_result(settings.symbol, 1.0, string.format("%s:%s", key_ty, key)) + else + if settings.symbol_unknown then + task:insert_result(settings.symbol_unknown, 1.0, string.format("%s:%s", key_ty, key)) + end + lua_util.debugm(N, task, 'insert key %s, type: %s', key, key_ty) + -- Insert key to zset and trim it's cardinality + lua_redis.redis_make_request(task, + redis_params, -- connect params + key, -- hash key + true, -- is write + nil, --callback + 'ZADD', -- command + { settings.redis_key, tostring(task:get_timeval(true)), key } -- arguments + ) + lua_redis.redis_make_request(task, + redis_params, -- connect params + key, -- hash key + true, -- is write + nil, --callback + 'ZREMRANGEBYRANK', -- command + { settings.redis_key, '0', + tostring(-(settings.max_senders + 1)) } -- arguments + ) + end + end + end + + local function redis_bloom_callback(err, data) + lua_util.debugm(N, task, 'got data: %s', data) + if err then + rspamd_logger.errx(task, 'redis error: %s', err) + elseif data then + if type(data) ~= 'userdata' and data == 1 then + -- non-null reply equal to 
`1` + task:insert_result(settings.symbol, 1.0, string.format("%s:%s", key_ty, key)) + else + if settings.symbol_unknown then + task:insert_result(settings.symbol_unknown, 1.0, string.format("%s:%s", key_ty, key)) + end + lua_util.debugm(N, task, 'insert key %s, type: %s', key, key_ty) + -- Reserve bloom filter space + lua_redis.redis_make_request(task, + redis_params, -- connect params + key, -- hash key + true, -- is write + nil, --callback + 'BF.RESERVE', -- command + { settings.redis_key, tostring(settings.max_senders), '0.01', '1000', 'NONSCALING' } -- arguments + ) + -- Insert key and adjust bloom filter + lua_redis.redis_make_request(task, + redis_params, -- connect params + key, -- hash key + true, -- is write + nil, --callback + 'BF.ADD', -- command + { settings.redis_key, key } -- arguments + ) + end + end + end + + if settings.use_bloom then + lua_redis.redis_make_request(task, + redis_params, -- connect params + key, -- hash key + false, -- is write + redis_bloom_callback, --callback + 'BF.EXISTS', -- command + { settings.redis_key, key } -- arguments + ) + else + lua_redis.redis_make_request(task, + redis_params, -- connect params + key, -- hash key + false, -- is write + redis_zset_callback, --callback + 'ZSCORE', -- command + { settings.redis_key, key } -- arguments + ) + end +end + +local function known_senders_callback(task) + local mime_from = (task:get_from('mime') or {})[1] + local smtp_from = (task:get_from('smtp') or {})[1] + local mime_key, smtp_key + if mime_from and mime_from.addr then + if settings.domains:get_key(mime_from.domain) then + mime_key = make_key(mime_from) + else + lua_util.debugm(N, task, 'skip mime from domain %s', mime_from.domain) + end + end + if smtp_from and smtp_from.addr then + if settings.domains:get_key(smtp_from.domain) then + smtp_key = make_key(smtp_from) + else + lua_util.debugm(N, task, 'skip smtp from domain %s', smtp_from.domain) + end + end + + if mime_key and smtp_key and mime_key ~= smtp_key then + -- Check 
both keys + check_redis_key(task, mime_key, 'mime') + check_redis_key(task, smtp_key, 'smtp') + elseif mime_key then + -- Check mime key + check_redis_key(task, mime_key, 'mime') + elseif smtp_key then + -- Check smtp key + check_redis_key(task, smtp_key, 'smtp') + end +end + +local opts = rspamd_config:get_all_opt('known_senders') +if opts then + settings = lua_util.override_defaults(settings, opts) + local res, err = settings_schema:transform(settings) + if not res then + rspamd_logger.errx(rspamd_config, 'cannot parse known_senders options: %1', err) + else + settings = res + end + redis_params = lua_redis.parse_redis_server(N, opts) + + if redis_params then + local map_conf = settings.domains + settings.domains = lua_maps.map_add_from_ucl(settings.domains, 'set', 'domains to track senders from') + if not settings.domains then + rspamd_logger.errx(rspamd_config, "couldn't add map %s, disable module", + map_conf) + lua_util.disable_module(N, "config") + return + end + lua_redis.register_prefix(settings.redis_key, N, + 'Known elements redis key', { + type = 'zset/bloom filter', + }) + local id = rspamd_config:register_symbol({ + name = settings.symbol, + type = 'normal', + callback = known_senders_callback, + one_shot = true, + score = -1.0, + augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) } + }) + + if settings.symbol_unknown and #settings.symbol_unknown > 0 then + rspamd_config:register_symbol({ + name = settings.symbol_unknown, + type = 'virtual', + parent = id, + one_shot = true, + score = 0.5, + }) + end + else + lua_util.disable_module(N, "redis") + end +end diff --git a/src/plugins/lua/maillist.lua b/src/plugins/lua/maillist.lua new file mode 100644 index 0000000..be1401c --- /dev/null +++ b/src/plugins/lua/maillist.lua @@ -0,0 +1,235 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the 
License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +if confighelp then + return +end + +-- Module for checking mail list headers +local N = 'maillist' +local symbol = 'MAILLIST' +local lua_util = require "lua_util" +-- EZMLM +-- Mailing-List: .*run by ezmlm +-- Precedence: bulk +-- List-Post: <mailto: +-- List-Help: <mailto: +-- List-Unsubscribe: <mailto:[a-zA-Z\.-]+-unsubscribe@ +-- List-Subscribe: <mailto:[a-zA-Z\.-]+-subscribe@ +-- RFC 2919 headers exist +local function check_ml_ezmlm(task) + -- Mailing-List + local header = task:get_header('mailing-list') + if not header or not string.find(header, 'ezmlm$') then + return false + end + -- Precedence + header = task:get_header('precedence') + if not header or not string.match(header, '^bulk$') then + return false + end + -- Other headers + header = task:get_header('list-post') + if not header or not string.find(header, '^<mailto:') then + return false + end + header = task:get_header('list-help') + if not header or not string.find(header, '^<mailto:') then + return false + end + -- Subscribe and unsubscribe + header = task:get_header('list-subscribe') + if not header or not string.find(header, '<mailto:[a-zA-Z.-]+-subscribe@') then + return false + end + header = task:get_header('list-unsubscribe') + if not header or not string.find(header, '<mailto:[a-zA-Z.-]+-unsubscribe@') then + return false + end + + return true +end + +-- GNU Mailman +-- Two major versions currently in use and they use slightly different headers +-- Mailman2: https://code.launchpad.net/~mailman-coders/mailman/2.1 +-- Mailman3: https://gitlab.com/mailman/mailman +local function 
check_ml_mailman(task) + local header = task:get_header('X-Mailman-Version') + if not header then + return false + end + local mm_version = header:match('^([23])%.') + if not mm_version then + lua_util.debugm(N, task, 'unknown Mailman version: %s', header) + return false + end + lua_util.debugm(N, task, 'checking Mailman %s headers', mm_version) + + -- XXX Some messages may not contain Precedence, but they are rare: + -- http://bazaar.launchpad.net/~mailman-coders/mailman/2.1/revision/1339 + header = task:get_header('Precedence') + if not header or (header ~= 'bulk' and header ~= 'list') then + return false + end + + -- Mailman 3 allows to disable all List-* headers in settings, but by default it adds them. + -- In all other cases all Mailman message should have List-Id header + if not task:has_header('List-Id') then + return false + end + + if mm_version == '2' then + -- X-BeenThere present in all Mailman2 messages + if not task:has_header('X-BeenThere') then + return false + end + -- X-List-Administrivia: is only added to messages Mailman creates and + -- sends out of its own accord + header = task:get_header('X-List-Administrivia') + if header and header == 'yes' then + -- not much elase we can check, Subjects can be changed in settings + return true + end + else + -- Mailman 3 + -- XXX not Mailman3 admin messages have this headers, but one + -- which don't usually have List-* headers examined below + if task:has_header('List-Administrivia') then + return true + end + end + + -- List-Archive and List-Post are optional, check other headers + for _, h in ipairs({ 'List-Help', 'List-Subscribe', 'List-Unsubscribe' }) do + header = task:get_header(h) + if not (header and header:find('<mailto:', 1, true)) then + return false + end + end + + return true +end + +-- Google groups detector +-- header exists X-Google-Loop +-- RFC 2919 headers exist +-- +local function check_ml_googlegroup(task) + return task:has_header('X-Google-Loop') or 
task:has_header('X-Google-Group-Id') +end + +-- CGP detector +-- X-Listserver = CommuniGate Pro LIST +-- RFC 2919 headers exist +-- +local function check_ml_cgp(task) + local header = task:get_header('X-Listserver') + + if not header or string.sub(header, 0, 20) ~= 'CommuniGate Pro LIST' then + return false + end + + return true +end + +-- RFC 2919 headers +local function check_generic_list_headers(task) + local score = 0 + local has_subscribe, has_unsubscribe + + local common_list_headers = { + ['List-Id'] = 0.75, + ['List-Archive'] = 0.125, + ['List-Owner'] = 0.125, + ['List-Help'] = 0.125, + ['List-Post'] = 0.125, + ['X-Loop'] = 0.125, + ['List-Subscribe'] = function() + has_subscribe = true + return 0.125 + end, + ['List-Unsubscribe'] = function() + has_unsubscribe = true + return 0.125 + end, + ['Precedence'] = function() + local header = task:get_header('Precedence') + if header and (header == 'list' or header == 'bulk') then + return 0.25 + end + end, + } + + for hname, hscore in pairs(common_list_headers) do + if task:has_header(hname) then + if type(hscore) == 'number' then + score = score + hscore + lua_util.debugm(N, task, 'has %s header, score = %s', hname, score) + else + local score_change = hscore() + if score and score_change then + score = score + score_change + lua_util.debugm(N, task, 'has %s header, score = %s', hname, score) + end + end + end + end + + if has_subscribe and has_unsubscribe then + score = score + 0.25 + end + + lua_util.debugm(N, task, 'final maillist score %s', score) + return score +end + + +-- RFC 2919 headers exist +local function check_maillist(task) + local score = check_generic_list_headers(task) + if score >= 1 then + if check_ml_ezmlm(task) then + task:insert_result(symbol, 1, 'ezmlm') + elseif check_ml_mailman(task) then + task:insert_result(symbol, 1, 'mailman') + elseif check_ml_googlegroup(task) then + task:insert_result(symbol, 1, 'googlegroups') + elseif check_ml_cgp(task) then + task:insert_result(symbol, 1, 
'cgp') + else + if score > 2 then + score = 2 + end + task:insert_result(symbol, 0.5 * score, 'generic') + end + end +end + + + +-- Configuration +local opts = rspamd_config:get_all_opt('maillist') +if opts then + if opts['symbol'] then + symbol = opts['symbol'] + rspamd_config:register_symbol({ + name = symbol, + callback = check_maillist, + flags = 'nice' + }) + end +end diff --git a/src/plugins/lua/maps_stats.lua b/src/plugins/lua/maps_stats.lua new file mode 100644 index 0000000..d418810 --- /dev/null +++ b/src/plugins/lua/maps_stats.lua @@ -0,0 +1,133 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +if confighelp then + rspamd_config:add_example(nil, 'maps_stats', + "Stores maps statistics in Redis", [[ +maps_stats { + # one iteration step per 2 minutes + interval = 2m; + # how many elements to store in Redis + count = 1k; + # common prefix for elements + prefix = 'rm_'; +} +]]) +end + +local redis_params +local lua_util = require "lua_util" +local rspamd_logger = require "rspamd_logger" +local lua_redis = require "lua_redis" +local N = "maps_stats" + +local settings = { + interval = 120, -- one iteration step per 2 minutes + count = 1000, -- how many elements to store in Redis + prefix = 'rm_', -- common prefix for elements +} + +local function process_map(map, ev_base, _) + if map:get_nelts() > 0 and map:get_uri() ~= 'static' then + local key = settings.prefix .. 
map:get_uri() + + local function redis_zrange_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot delete extra elements in %s: %s', + key, err) + elseif data then + rspamd_logger.infox(rspamd_config, 'cleared %s elements from %s', + data, key) + end + end + local function redis_card_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot get number of elements in %s: %s', + key, err) + elseif data then + if settings.count > 0 and tonumber(data) > settings.count then + lua_redis.rspamd_redis_make_request_taskless(ev_base, + rspamd_config, + redis_params, -- connect params + key, -- hash key + true, -- is write + redis_zrange_cb, --callback + 'ZREMRANGEBYRANK', -- command + { key, '0', tostring(-(settings.count) - 1) } -- arguments + ) + end + end + end + local ret, conn, _ = lua_redis.rspamd_redis_make_request_taskless(ev_base, + rspamd_config, + redis_params, -- connect params + key, -- hash key + true, -- is write + redis_card_cb, --callback + 'ZCARD', -- command + { key } -- arguments + ) + + if ret and conn then + local stats = map:get_stats(true) + for k, s in pairs(stats) do + if s > 0 then + conn:add_cmd('ZINCRBY', { key, tostring(s), k }) + end + end + end + end +end + +if not lua_util.check_experimental(N) then + return +end + +local opts = rspamd_config:get_all_opt(N) + +if opts then + for k, v in pairs(opts) do + settings[k] = v + end +end + +redis_params = lua_redis.parse_redis_server(N, opts) +-- XXX, this is a poor approach as not all maps are defined here... +local tmaps = rspamd_config:get_maps() +for _, m in ipairs(tmaps) do + if m:get_uri() ~= 'static' then + lua_redis.register_prefix(settings.prefix .. 
m:get_uri(), N, + 'Maps stats data', { + type = 'zlist', + persistent = true, + }) + end +end + +if redis_params then + rspamd_config:add_on_load(function(_, ev_base, worker) + local maps = rspamd_config:get_maps() + + for _, m in ipairs(maps) do + rspamd_config:add_periodic(ev_base, + settings['interval'], + function() + process_map(m, ev_base, worker) + return true + end, true) + end + end) +end
\ No newline at end of file diff --git a/src/plugins/lua/metadata_exporter.lua b/src/plugins/lua/metadata_exporter.lua new file mode 100644 index 0000000..7b353b8 --- /dev/null +++ b/src/plugins/lua/metadata_exporter.lua @@ -0,0 +1,707 @@ +--[[ +Copyright (c) 2016, Andrew Lewis <nerf@judo.za.org> +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +if confighelp then + return +end + +-- A plugin that pushes metadata (or whole messages) to external services + +local redis_params +local lua_util = require "lua_util" +local rspamd_http = require "rspamd_http" +local rspamd_util = require "rspamd_util" +local rspamd_logger = require "rspamd_logger" +local rspamd_tcp = require "rspamd_tcp" +local ucl = require "ucl" +local E = {} +local N = 'metadata_exporter' +local HOSTNAME = rspamd_util.get_hostname() + +local settings = { + pusher_enabled = {}, + pusher_format = {}, + pusher_select = {}, + mime_type = 'text/plain', + defer = false, + mail_from = '', + mail_to = 'postmaster@localhost', + helo = 'rspamd', + email_template = [[From: "Rspamd" <$mail_from> +To: $mail_to +Subject: Spam alert +Date: $date +MIME-Version: 1.0 +Message-ID: <$our_message_id> +Content-type: text/plain; charset=utf-8 +Content-Transfer-Encoding: 8bit + +Authenticated username: $user +IP: $ip +Queue ID: $qid +SMTP FROM: $from +SMTP RCPT: $rcpt +MIME From: $header_from +MIME To: $header_to +MIME Date: $header_date +Subject: $header_subject +Message-ID: 
$message_id +Action: $action +Score: $score +Symbols: $symbols]], +} + +local function get_general_metadata(task, flatten, no_content) + local r = {} + local ip = task:get_from_ip() + if ip and ip:is_valid() then + r.ip = tostring(ip) + else + r.ip = 'unknown' + end + r.user = task:get_user() or 'unknown' + r.qid = task:get_queue_id() or 'unknown' + r.subject = task:get_subject() or 'unknown' + r.action = task:get_metric_action() + r.rspamd_server = HOSTNAME + + local s = task:get_metric_score()[1] + r.score = flatten and string.format('%.2f', s) or s + + local fuzzy = task:get_mempool():get_variable("fuzzy_hashes", "fstrings") + if fuzzy and #fuzzy > 0 then + local fz = {} + for _, h in ipairs(fuzzy) do + table.insert(fz, h) + end + if not flatten then + r.fuzzy = fz + else + r.fuzzy = table.concat(fz, ', ') + end + else + if not flatten then + r.fuzzy = {} + else + r.fuzzy = '' + end + end + + local rcpt = task:get_recipients('smtp') + if rcpt then + local l = {} + for _, a in ipairs(rcpt) do + table.insert(l, a['addr']) + end + if not flatten then + r.rcpt = l + else + r.rcpt = table.concat(l, ', ') + end + else + r.rcpt = 'unknown' + end + local from = task:get_from('smtp') + if ((from or E)[1] or E).addr then + r.from = from[1].addr + else + r.from = 'unknown' + end + local syminf = task:get_symbols_all() + if flatten then + local l = {} + for _, sym in ipairs(syminf) do + local txt + if sym.options then + local topt = table.concat(sym.options, ', ') + txt = sym.name .. '(' .. string.format('%.2f', sym.score) .. ')' .. ' [' .. topt .. ']' + else + txt = sym.name .. '(' .. string.format('%.2f', sym.score) .. 
')'
      end
      table.insert(l, txt)
    end
    r.symbols = table.concat(l, '\n\t')
  else
    r.symbols = syminf
  end

  -- Collects every decoded instance of header `name`: a list when not
  -- flattening, a newline-joined string when flattening, 'unknown' if absent.
  local function process_header(name)
    local hdr = task:get_header_full(name)
    if hdr then
      local l = {}
      for _, h in ipairs(hdr) do
        table.insert(l, h.decoded)
      end
      if not flatten then
        return l
      else
        return table.concat(l, '\n')
      end
    else
      return 'unknown'
    end
  end

  -- Scan time is exported in whole milliseconds; negative values are clamped to 0
  local scan_real = task:get_scan_time()
  scan_real = math.floor(scan_real * 1000)
  if scan_real < 0 then
    rspamd_logger.messagex(task,
        'clock skew detected for message: %s ms real sca time (reset to 0)',
        scan_real)
    scan_real = 0
  end

  r.scan_time = scan_real
  local content = task:get_content()
  r.size = content and content:len() or 0

  -- Header-derived fields are skipped when the caller asks for no content
  if not no_content then
    r.header_from = process_header('from')
    r.header_to = process_header('to')
    r.header_subject = process_header('subject')
    r.header_date = process_header('date')
    r.message_id = task:get_message_id()
  end
  return r
end

-- Formatters turn a task into the payload handed to a pusher.
-- Each one returns `formatted[, extra]`; `extra` is passed through to the pusher.
local formatters = {
  -- Raw message content, no extra data
  default = function(task)
    return task:get_content(), {}
  end,
  -- Renders the e-mail template and returns the list of SMTP recipients
  -- in `extra.mail_targets` (plain address strings only)
  email_alert = function(task, rule, extra)
    local meta = get_general_metadata(task, true)
    local display_emails = {}
    local mail_targets = {}
    meta.mail_from = rule.mail_from or settings.mail_from
    local mail_rcpt = rule.mail_to or settings.mail_to
    if type(mail_rcpt) ~= 'table' then
      table.insert(display_emails, string.format('<%s>', mail_rcpt))
      table.insert(mail_targets, mail_rcpt)
    else
      for _, e in ipairs(mail_rcpt) do
        table.insert(display_emails, string.format('<%s>', e))
        table.insert(mail_targets, e)
      end
    end
    if rule.email_alert_sender then
      local x = task:get_from('smtp')
      -- Take the address string of the first SMTP sender. The whole address
      -- table must not end up in the recipients list, and an empty result
      -- from get_from must not raise on `x[1].addr`.
      local sender = ((x or E)[1] or E).addr
      if sender and string.len(sender) > 0 then
        table.insert(mail_targets, sender)
        table.insert(display_emails, string.format('<%s>', sender))
      end
    end
    if rule.email_alert_user then
      local x = task:get_user()
      if x then
        table.insert(mail_targets, x)
        table.insert(display_emails, string.format('<%s>', x))
      end
    end
    if rule.email_alert_recipients then
      local x = task:get_recipients('smtp')
      if x then
        for _, e in ipairs(x) do
          if string.len(e.addr) > 0 then
            table.insert(mail_targets, e.addr)
            table.insert(display_emails, string.format('<%s>', e.addr))
          end
        end
      end
    end
    meta.mail_to = table.concat(display_emails, ', ')
    meta.our_message_id = rspamd_util.random_hex(12) .. '@rspamd'
    meta.date = rspamd_util.time_to_string(rspamd_util.get_time())
    return lua_util.template(rule.email_template or settings.email_template, meta), { mail_targets = mail_targets }
  end,
  -- Full (non-flattened) metadata serialised as compact JSON
  json = function(task)
    return ucl.to_format(get_general_metadata(task), 'json-compact')
  end
}

-- True for the actions this module treats as "spam"
local function is_spam(action)
  return (action == 'reject' or action == 'add header' or action == 'rewrite subject')
end

-- Selectors decide whether a task is exported at all
local selectors = {
  default = function(task)
    return true
  end,
  is_spam = function(task)
    local action = task:get_metric_action()
    return is_spam(action)
  end,
  is_spam_authed = function(task)
    if not task:get_user() then
      return false
    end
    local action = task:get_metric_action()
    return is_spam(action)
  end,
  is_reject = function(task)
    local action = task:get_metric_action()
    return (action == 'reject')
  end,
  is_reject_authed = function(task)
    if not task:get_user() then
      return false
    end
    local action = task:get_metric_action()
    return (action == 'reject')
  end,
  is_not_soft_reject = function(task)
    local action = task:get_metric_action()
    return (action ~= 'soft reject')
  end,
}

-- Soft-rejects the message when the rule asks to defer on export failure
local function maybe_defer(task, rule)
  if rule.defer then
    rspamd_logger.warnx(task, 'deferring message')
    task:set_pre_result('soft reject', 'deferred', N)
  end
end

-- Pushers deliver the formatted payload to an external service
local pushers = {
  redis_pubsub = function(task, formatted, rule)
    local _, ret, upstream
    local function redis_pub_cb(err)
      if err then
        rspamd_logger.errx(task, 'got error %s when publishing on server %s',
            err, upstream:get_addr())
        return
maybe_defer(task, rule) + end + return true + end + ret, _, upstream = rspamd_redis_make_request(task, + redis_params, -- connect params + nil, -- hash key + true, -- is write + redis_pub_cb, --callback + 'PUBLISH', -- command + { rule.channel, formatted } -- arguments + ) + if not ret then + rspamd_logger.errx(task, 'error connecting to redis') + maybe_defer(task, rule) + end + end, + http = function(task, formatted, rule) + local function http_callback(err, code) + local valid_status = { 200, 201, 202, 204 } + + if err then + rspamd_logger.errx(task, 'got error %s in http callback', err) + return maybe_defer(task, rule) + end + for _, v in ipairs(valid_status) do + if v == code then + return true + end + end + rspamd_logger.errx(task, 'got unexpected http status: %s', code) + return maybe_defer(task, rule) + end + local hdrs = {} + if rule.meta_headers then + local gm = get_general_metadata(task, false, true) + local pfx = rule.meta_header_prefix or 'X-Rspamd-' + for k, v in pairs(gm) do + if type(v) == 'table' then + hdrs[pfx .. k] = ucl.to_format(v, 'json-compact') + else + hdrs[pfx .. 
k] = v + end + end + end + rspamd_http.request({ + task = task, + url = rule.url, + user = rule.user, + password = rule.password, + body = formatted, + callback = http_callback, + mime_type = rule.mime_type or settings.mime_type, + headers = hdrs, + }) + end, + send_mail = function(task, formatted, rule, extra) + local lua_smtp = require "lua_smtp" + local function sendmail_cb(ret, err) + if not ret then + rspamd_logger.errx(task, 'SMTP export error: %s', err) + maybe_defer(task, rule) + end + end + + lua_smtp.sendmail({ + task = task, + host = rule.smtp, + port = rule.smtp_port or settings.smtp_port or 25, + from = rule.mail_from or settings.mail_from, + recipients = extra.mail_targets or rule.mail_to or settings.mail_to, + helo = rule.helo or settings.helo, + timeout = rule.timeout or settings.timeout, + }, formatted, sendmail_cb) + end, + json_raw_tcp = function(task, formatted, rule) + local function json_raw_tcp_callback(err, code) + if err then + rspamd_logger.errx(task, 'got error %s in json_raw_tcp callback', err) + return maybe_defer(task, rule) + end + return true + end + rspamd_tcp.request({ + task=task, + host=rule.host, + port=rule.port, + data=formatted, + callback=json_raw_tcp_callback, + read=false, + }) + end, +} + +local opts = rspamd_config:get_all_opt(N) +if not opts then + return +end +local process_settings = { + select = function(val) + selectors.custom = assert(load(val))() + end, + format = function(val) + formatters.custom = assert(load(val))() + end, + push = function(val) + pushers.custom = assert(load(val))() + end, + custom_push = function(val) + if type(val) == 'table' then + for k, v in pairs(val) do + pushers[k] = assert(load(v))() + end + end + end, + custom_select = function(val) + if type(val) == 'table' then + for k, v in pairs(val) do + selectors[k] = assert(load(v))() + end + end + end, + custom_format = function(val) + if type(val) == 'table' then + for k, v in pairs(val) do + formatters[k] = assert(load(v))() + end + end + 
end, + pusher_enabled = function(val) + if type(val) == 'string' then + if pushers[val] then + settings.pusher_enabled[val] = true + else + rspamd_logger.errx(rspamd_config, 'Pusher type: %s is invalid', val) + end + elseif type(val) == 'table' then + for _, v in ipairs(val) do + if pushers[v] then + settings.pusher_enabled[v] = true + else + rspamd_logger.errx(rspamd_config, 'Pusher type: %s is invalid', val) + end + end + end + end, +} +for k, v in pairs(opts) do + local f = process_settings[k] + if f then + f(opts[k]) + else + settings[k] = v + end +end +if type(settings.rules) ~= 'table' then + -- Legacy config + settings.rules = {} + if not next(settings.pusher_enabled) then + if pushers.custom then + rspamd_logger.infox(rspamd_config, 'Custom pusher implicitly enabled') + settings.pusher_enabled.custom = true + else + -- Check legacy options + if settings.url then + rspamd_logger.warnx(rspamd_config, 'HTTP pusher implicitly enabled') + settings.pusher_enabled.http = true + end + if settings.channel then + rspamd_logger.warnx(rspamd_config, 'Redis Pubsub pusher implicitly enabled') + settings.pusher_enabled.redis_pubsub = true + end + if settings.smtp and settings.mail_to then + rspamd_logger.warnx(rspamd_config, 'SMTP pusher implicitly enabled') + settings.pusher_enabled.send_mail = true + end + end + end + if not next(settings.pusher_enabled) then + rspamd_logger.errx(rspamd_config, 'No push backend enabled') + return + end + if settings.formatter then + settings.format = formatters[settings.formatter] + if not settings.format then + rspamd_logger.errx(rspamd_config, 'No such formatter: %s', settings.formatter) + return + end + end + if settings.selector then + settings.select = selectors[settings.selector] + if not settings.select then + rspamd_logger.errx(rspamd_config, 'No such selector: %s', settings.selector) + return + end + end + for k in pairs(settings.pusher_enabled) do + local formatter = settings.pusher_format[k] + local selector = 
settings.pusher_select[k] + if not formatter then + settings.pusher_format[k] = settings.formatter or 'default' + rspamd_logger.infox(rspamd_config, 'Using default formatter for %s pusher', k) + else + if not formatters[formatter] then + rspamd_logger.errx(rspamd_config, 'No such formatter: %s - disabling %s', formatter, k) + settings.pusher_enabled.k = nil + end + end + if not selector then + settings.pusher_select[k] = settings.selector or 'default' + rspamd_logger.infox(rspamd_config, 'Using default selector for %s pusher', k) + else + if not selectors[selector] then + rspamd_logger.errx(rspamd_config, 'No such selector: %s - disabling %s', selector, k) + settings.pusher_enabled.k = nil + end + end + end + if settings.pusher_enabled.redis_pubsub then + redis_params = rspamd_parse_redis_server(N) + if not redis_params then + rspamd_logger.errx(rspamd_config, 'No redis servers are specified') + settings.pusher_enabled.redis_pubsub = nil + else + local r = {} + r.backend = 'redis_pubsub' + r.channel = settings.channel + r.defer = settings.defer + r.selector = settings.pusher_select.redis_pubsub + r.formatter = settings.pusher_format.redis_pubsub + r.timeout = redis_params.timeout + settings.rules[r.backend:upper()] = r + end + end + if settings.pusher_enabled.http then + if not settings.url then + rspamd_logger.errx(rspamd_config, 'No URL is specified') + settings.pusher_enabled.http = nil + else + local r = {} + r.backend = 'http' + r.url = settings.url + r.mime_type = settings.mime_type + r.defer = settings.defer + r.selector = settings.pusher_select.http + r.formatter = settings.pusher_format.http + r.timeout = settings.timeout or 0.0 + settings.rules[r.backend:upper()] = r + end + end + if settings.pusher_enabled.send_mail then + if not (settings.mail_to and settings.smtp) then + rspamd_logger.errx(rspamd_config, 'No mail_to and/or smtp setting is specified') + settings.pusher_enabled.send_mail = nil + else + local r = {} + r.backend = 'send_mail' + r.mail_to = 
settings.mail_to
      r.mail_from = settings.mail_from
      -- The option is spelled `helo` (see the defaults and the send_mail pusher)
      r.helo = settings.helo
      r.smtp = settings.smtp
      r.smtp_port = settings.smtp_port
      r.email_template = settings.email_template
      r.defer = settings.defer
      r.timeout = settings.timeout or 0.0
      r.selector = settings.pusher_select.send_mail
      r.formatter = settings.pusher_format.send_mail
      settings.rules[r.backend:upper()] = r
    end
  end
  if settings.pusher_enabled.json_raw_tcp then
    if not (settings.host and settings.port) then
      rspamd_logger.errx(rspamd_config, 'No host and/or port is specified')
      settings.pusher_enabled.json_raw_tcp = nil
    else
      local r = {}
      r.backend = 'json_raw_tcp'
      r.host = settings.host
      r.port = settings.port
      r.defer = settings.defer
      r.selector = settings.pusher_select.json_raw_tcp
      r.formatter = settings.pusher_format.json_raw_tcp
      settings.rules[r.backend:upper()] = r
    end
  end
  if not next(settings.pusher_enabled) then
    rspamd_logger.errx(rspamd_config, 'No push backend enabled')
    return
  end
elseif not next(settings.rules) then
  lua_util.debugm(N, rspamd_config, 'No rules enabled')
  return
end
if not settings.rules or not next(settings.rules) then
  rspamd_logger.errx(rspamd_config, 'No rules enabled')
  return
end
-- Settings every rule of a given backend must provide
-- NOTE(review): the 'smtp' key never matches a rule, since the mail pusher
-- is registered as 'send_mail'; confirm whether the requirement should be
-- enforced before renaming it, as that would reject currently accepted rules.
local backend_required_elements = {
  http = {
    'url',
  },
  smtp = {
    'mail_to',
    'smtp',
  },
  redis_pubsub = {
    'channel',
  },
  json_raw_tcp = {
    'host',
    'port',
  },
}
-- Per-setting validators, called as f(rule_name, value); return false on error
local check_element = {
  selector = function(k, v)
    if not selectors[v] then
      rspamd_logger.errx(rspamd_config, 'Rule %s has invalid selector %s', k, v)
      return false
    else
      return true
    end
  end,
  formatter = function(k, v)
    if not formatters[v] then
      rspamd_logger.errx(rspamd_config, 'Rule %s has invalid formatter %s', k, v)
      return false
    else
      return true
    end
  end,
}
-- Per-backend rule validation; invalid rules are removed from settings.rules
local backend_check = {
  default = function(k, rule)
    local reqset = backend_required_elements[rule.backend]
    if reqset then
      for _, e in ipairs(reqset) do
        if not rule[e] then
          rspamd_logger.errx(rspamd_config, 'Rule %s misses required setting %s', k, e)
          settings.rules[k] = nil
        end
      end
    end
    for sett, v in pairs(rule) do
      local f = check_element[sett]
      if f then
        -- Pass the rule name so the error message identifies the rule,
        -- not the setting key
        if not f(k, v) then
          settings.rules[k] = nil
        end
      end
    end
  end,
}
-- Redis rules additionally need parsed server parameters and inherit
-- their timeout
backend_check.redis_pubsub = function(k, rule)
  if not redis_params then
    redis_params = rspamd_parse_redis_server(N)
  end
  if not redis_params then
    rspamd_logger.errx(rspamd_config, 'No redis servers are specified')
    settings.rules[k] = nil
  else
    backend_check.default(k, rule)
    rule.timeout = redis_params.timeout
  end
end
-- Backends without a dedicated checker fall back to the default one
setmetatable(backend_check, {
  __index = function()
    return backend_check.default
  end,
})
for k, v in pairs(settings.rules) do
  if type(v) == 'table' then
    local backend = v.backend
    if not backend then
      rspamd_logger.errx(rspamd_config, 'Rule %s has no backend', k)
      settings.rules[k] = nil
    elseif not pushers[backend] then
      rspamd_logger.errx(rspamd_config, 'Rule %s has invalid backend %s', k, backend)
      settings.rules[k] = nil
    else
      local f = backend_check[backend]
      f(k, v)
    end
  else
    rspamd_logger.errx(rspamd_config, 'Rule %s has bad type: %s', k, type(v))
    settings.rules[k] = nil
  end
end

-- Builds the symbol callback for one rule: select, format, then push
local function gen_exporter(rule)
  return function(task)
    if task:has_flag('skip') then
      return
    end
    local selector = rule.selector or 'default'
    local selected = selectors[selector](task)
    if selected then
      lua_util.debugm(N, task, 'Message selected for processing')
      local formatter = rule.formatter or 'default'
      local formatted, extra = formatters[formatter](task, rule)
      if formatted then
        pushers[rule.backend](task, formatted, rule, extra)
      else
        lua_util.debugm(N, task, 'Formatter [%s] returned non-truthy value [%s]', formatter, formatted)
      end
    else
      lua_util.debugm(N, task, 'Selector [%s] returned non-truthy value [%s]', selector, selected)
    end
  end
end

if not next(settings.rules) then
rspamd_logger.errx(rspamd_config, 'No rules enabled') + lua_util.disable_module(N, "config") +end +for k, r in pairs(settings.rules) do + rspamd_config:register_symbol({ + name = 'EXPORT_METADATA_' .. k, + type = 'idempotent', + callback = gen_exporter(r), + flags = 'empty,explicit_disable,ignore_passthrough', + augmentations = { string.format("timeout=%f", r.timeout or 0.0) } + }) +end diff --git a/src/plugins/lua/metric_exporter.lua b/src/plugins/lua/metric_exporter.lua new file mode 100644 index 0000000..7588551 --- /dev/null +++ b/src/plugins/lua/metric_exporter.lua @@ -0,0 +1,252 @@ +--[[ +Copyright (c) 2016, Andrew Lewis <nerf@judo.za.org> +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]] -- + +if confighelp then + return +end + +local N = 'metric_exporter' +local logger = require "rspamd_logger" +local mempool = require "rspamd_mempool" +local util = require "rspamd_util" +local tcp = require "rspamd_tcp" +local lua_util = require "lua_util" + +local pool +local settings = { + interval = 120, + timeout = 15, + statefile = string.format('%s/%s', rspamd_paths['DBDIR'], 'metric_exporter_last_push') +} + +local VAR_NAME = 'metric_exporter_last_push' + +local valid_metrics = { + 'actions.add header', + 'actions.greylist', + 'actions.no action', + 'actions.reject', + 'actions.rewrite subject', + 'actions.soft reject', + 'bytes_allocated', + 'chunks_allocated', + 'chunks_freed', + 'chunks_oversized', + 'connections', + 'control_connections', + 'ham_count', + 'learned', + 'pools_allocated', + 'pools_freed', + 'scanned', + 'shared_chunks_allocated', + 'spam_count', +} + +local function validate_metrics(settings_metrics) + if type(settings_metrics) ~= 'table' or #settings_metrics == 0 then + logger.errx(rspamd_config, 'No metrics specified for collection') + return false + end + for _, v in ipairs(settings_metrics) do + local isvalid = false + for _, vm in ipairs(valid_metrics) do + if vm == v then + isvalid = true + break + end + end + if not isvalid then + logger.errx('Invalid metric: %s', v) + return false + end + local split = rspamd_str_split(v, '.') + if #split > 2 then + logger.errx('Too many dots in metric name: %s', v) + return false + end + end + return true +end + +local function load_defaults(defaults) + for k, v in pairs(defaults) do + if settings[k] == nil then + settings[k] = v + end + end +end + +local function graphite_config() + load_defaults({ + host = 'localhost', + port = 2003, + metric_prefix = 'rspamd' + }) + return validate_metrics(settings['metrics']) +end + +local function graphite_push(kwargs) + local stamp + if kwargs['time'] then + stamp = math.floor(kwargs['time']) + else + stamp = math.floor(util.get_time()) + end + local 
metrics_str = {}
  -- One "<prefix>.<metric> <value> <timestamp>" line per configured metric;
  -- dotted metric names address a nested table in the stats
  for _, v in ipairs(settings['metrics']) do
    local mvalue
    local mname = string.format('%s.%s', settings['metric_prefix'], v:gsub(' ', '_'))
    local split = rspamd_str_split(v, '.')
    if #split == 1 then
      mvalue = kwargs['stats'][v]
    elseif #split == 2 then
      mvalue = kwargs['stats'][split[1]][split[2]]
    end
    table.insert(metrics_str, string.format('%s %s %s', mname, mvalue, stamp))
  end

  metrics_str = table.concat(metrics_str, '\n')

  tcp.request({
    ev_base = kwargs['ev_base'],
    config = rspamd_config,
    host = settings['host'],
    port = settings['port'],
    timeout = settings['timeout'],
    read = false,
    data = {
      metrics_str, '\n',
    },
    callback = (function(err)
      if err then
        logger.errx('Push failed: %1', err)
        return
      end
      -- Remember the time of the last successful push
      pool:set_variable(VAR_NAME, stamp)
    end)
  })
end

local backends = {
  graphite = {
    configure = graphite_config,
    push = graphite_push,
  },
}

-- Copies module options into `settings` and runs the backend-specific
-- configuration. Returns a falsy value when the module must be disabled.
local function configure_metric_exporter()
  -- The module may be absent from the configuration entirely
  local opts = rspamd_config:get_all_opt(N) or {}
  local be = opts['backend']
  if not be then
    logger.debugm(N, rspamd_config, 'Backend is unspecified')
    return
  end
  if not backends[be] then
    logger.errx(rspamd_config, 'Backend is invalid: ' .. be)
    return false
  end
  for k, v in pairs(opts) do
    settings[k] = v
  end
  return backends[be]['configure']()
end

if not configure_metric_exporter() then
  lua_util.disable_module(N, "config")
  return
end

rspamd_config:add_on_load(function(_, ev_base, worker)
  -- Exit unless we're the first 'controller' worker
  if not worker:is_primary_controller() then
    return
  end
  -- Persist mempool variable to statefile on shutdown
  pool = mempool.create()
  rspamd_config:register_finish_script(function()
    local stamp = pool:get_variable(VAR_NAME, 'double')
    if not stamp then
      logger.warn('No last metric exporter push to persist to disk')
      return
    end
    local f, err = io.open(settings['statefile'], 'w')
    if err then
      logger.errx('Unable to write statefile to disk: %s', err)
      return
    end
    if f then
      f:write(pool:get_variable(VAR_NAME, 'double'))
      f:close()
    end
    pool:destroy()
  end)
  -- Push metrics to backend
  local function push_metrics(time)
    logger.infox('Pushing metrics to %s backend', settings['backend'])
    local args = {
      ev_base = ev_base,
      stats = worker:get_stat(),
    }
    if time then
      -- The backend reads the timestamp from the named `time` field
      args.time = time
    end
    backends[settings['backend']]['push'](args)
  end
  -- Push metrics at regular intervals
  local function schedule_regular_push()
    rspamd_config:add_periodic(ev_base, settings['interval'], function()
      push_metrics()
      return true
    end)
  end
  -- Push metrics to backend and reschedule check
  local function schedule_intermediate_push(when)
    rspamd_config:add_periodic(ev_base, when, function()
      push_metrics()
      schedule_regular_push()
      return false
    end)
  end
  -- Try read statefile on startup
  local stamp
  local f, err = io.open(settings['statefile'], 'r')
  if err then
    logger.errx('Failed to open statefile: %s', err)
  end
  if f then
    -- Read through the handle and close it; do not replace the process
    -- default input file
    stamp = tonumber(f:read())
    f:close()
    if stamp then
      pool:set_variable(VAR_NAME, stamp)
    end
  end
  if not stamp then
    logger.debugm(N, rspamd_config, 'No state found - pushing
stats immediately') + push_metrics() + schedule_regular_push() + return + end + local time = util.get_time() + local delta = stamp - time + settings['interval'] + if delta <= 0 then + logger.debugm(N, rspamd_config, 'Last push is too old - pushing stats immediately') + push_metrics(time) + schedule_regular_push() + return + end + logger.debugm(N, rspamd_config, 'Scheduling next push in %s seconds', delta) + schedule_intermediate_push(delta) +end) diff --git a/src/plugins/lua/mid.lua b/src/plugins/lua/mid.lua new file mode 100644 index 0000000..b8650c8 --- /dev/null +++ b/src/plugins/lua/mid.lua @@ -0,0 +1,123 @@ +--[[ +Copyright (c) 2016, Alexander Moisseev <moiseev@mezonplus.ru> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
]]--

--[[
MID plugin - suppress INVALID_MSGID and MISSING_MID for messages originating
from listed valid DKIM domains with missed or known proprietary Message-IDs
]]--

if confighelp then
  return
end

local rspamd_logger = require "rspamd_logger"
local rspamd_regexp = require "rspamd_regexp"
local lua_util = require "lua_util"
local N = "mid"

-- Defaults; every key can be overridden from the `mid` configuration section
local settings = {
  url = '',
  symbol_known_mid = 'KNOWN_MID',
  symbol_known_no_mid = 'KNOWN_NO_MID',
  symbol_invalid_msgid = 'INVALID_MSGID',
  symbol_missing_mid = 'MISSING_MID',
  symbol_dkim_allow = 'R_DKIM_ALLOW',
  csymbol_invalid_msgid_allowed = 'INVALID_MSGID_ALLOWED',
  csymbol_missing_mid_allowed = 'MISSING_MID_ALLOWED',
}

-- Map object created from the `source` option (set during configuration below)
local map

-- Shared empty table used for nil-safe indexing
local E = {}

-- Symbol callback: walks the options of the DKIM allow symbol, looks each one
-- up in the map and inserts at most one result:
--   * map value == ''  and no Message-Id header   -> symbol_known_no_mid
--   * map value ~= ''  and it matches (as a regexp) the Message-Id header
--                                                  -> symbol_known_mid
local function known_mid_cb(task)
  local header = task:get_header('Message-Id')
  local das = task:get_symbol(settings['symbol_dkim_allow'])
  local dkim_domains = ((das or E)[1] or E).options
  if not dkim_domains then
    return
  end

  for _, dkim_domain in ipairs(dkim_domains) do
    -- The map key is the part of the option before the first ':'.
    -- `match` yields nil for an empty option, so never query the map with nil.
    local key = dkim_domain and dkim_domain:match("[^:]+")
    local v = key and map:get_key(key)
    if v then
      if v == '' then
        if not header then
          task:insert_result(settings['symbol_known_no_mid'], 1, dkim_domain)
          return
        end
      else
        local re = rspamd_regexp.create_cached(v)
        if header and re and re:match(header) then
          task:insert_result(settings['symbol_known_mid'], 1, dkim_domain)
          return
        end
      end
    end
  end
end

local opts = rspamd_config:get_all_opt('mid')
if opts then
  for k, v in pairs(opts) do
    settings[k] = v
  end

  if not opts.source then
    rspamd_logger.infox(rspamd_config, 'mid module requires "source" parameter')
    lua_util.disable_module(N, "config")
    return
  end

  map = rspamd_config:add_map {
    url = opts.source,
    description = "Message-IDs map",
    type = 'map'
  }
  if map then
    local id = rspamd_config:register_symbol({
      name = 'KNOWN_MID_CALLBACK',
      type = 'callback',
      group = 'mid',
      callback = known_mid_cb
    })
    rspamd_config:register_symbol({
      name = settings['symbol_known_mid'],
      parent = id,
      group = 'mid',
      type = 'virtual'
    })
    rspamd_config:register_symbol({
      name = settings['symbol_known_no_mid'],
      parent = id,
      group = 'mid',
      type = 'virtual'
    })
    rspamd_config:add_composite(settings['csymbol_invalid_msgid_allowed'],
        string.format('~%s & ^%s',
            settings['symbol_known_mid'],
            settings['symbol_invalid_msgid']))
    rspamd_config:add_composite(settings['csymbol_missing_mid_allowed'],
        string.format('~%s & ^%s',
            settings['symbol_known_no_mid'],
            settings['symbol_missing_mid']))

    -- The callback reads the DKIM allow symbol, so it must run after DKIM_CHECK
    rspamd_config:register_dependency('KNOWN_MID_CALLBACK', 'DKIM_CHECK')
  else
    rspamd_logger.infox(rspamd_config, 'source is not a valid map definition, disabling module')
    lua_util.disable_module(N, "config")
  end
end
diff --git a/src/plugins/lua/milter_headers.lua b/src/plugins/lua/milter_headers.lua new file mode 100644 index 0000000..b53a454 --- /dev/null +++ b/src/plugins/lua/milter_headers.lua @@ -0,0 +1,762 @@
--[[
Copyright (c) 2016, Andrew Lewis <nerf@judo.za.org>
Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

if confighelp then
  return
end

-- A plugin that provides common header manipulations

local logger = require "rspamd_logger"
local util = require "rspamd_util"
local N = 'milter_headers'
local lua_util = require "lua_util"
local lua_maps = require "lua_maps"
local lua_mime = require "lua_mime"
local ts = require("tableshape").types
local E = {}

local HOSTNAME = util.get_hostname()

local settings = {
  remove_upstream_spam_flag = true;
  skip_local = true,
  skip_authenticated = true,
  skip_all = false,
  local_headers = {},
  authenticated_headers = {},
  headers_modify_mode = 'compat', -- To avoid compatibility issues on upgrade
  default_headers_order = nil, -- Insert at the end (set 1 to insert just after the first received)
  routines = {
    ['remove-headers'] = {
      headers = {},
    },
    ['add-headers'] = {
      headers = {},
      remove = 0,
    },
    ['remove-header'] = {
      remove = 0,
    },
    ['x-spamd-result'] = {
      header = 'X-Spamd-Result',
      remove = 0,
      stop_chars = ' ',
      sort_by = 'score',
    },
    ['x-rspamd-server'] = {
      header = 'X-Rspamd-Server',
      remove = 0,
      hostname = nil, -- Get the local computer host name
    },
    ['x-rspamd-queue-id'] = {
      header = 'X-Rspamd-Queue-Id',
      remove = 0,
    },
    ['x-rspamd-pre-result'] = {
      header = 'X-Rspamd-Pre-Result',
      remove = 0,
    },
    ['x-rspamd-action'] = {
      header = 'X-Rspamd-Action',
      remove = 0,
    },
    ['remove-spam-flag'] = {
      header = 'X-Spam',
    },
    ['spam-header'] = {
      header = 'Deliver-To',
      value = 'Junk',
      remove = 0,
    },
    ['x-virus'] = {
      header = 'X-Virus',
      remove = 0,
      status_clean = nil,
      status_infected = nil,
      status_fail = nil,
      symbols_fail = {},
      symbols = {}, -- needs config
    },
    ['x-os-fingerprint'] = {
      header = 'X-OS-Fingerprint',
      remove = 0,
    },
    ['x-spamd-bar'] = {
      header = 'X-Spamd-Bar',
      positive = '+',
      negative = '-',
      neutral = '/',
      remove = 0,
    },
    ['x-spam-level'] = {
      header = 'X-Spam-Level',
      char = '*',
      remove = 0,
    },
    ['x-spam-status'] = {
      header = 'X-Spam-Status',
      remove = 0,
    },
    ['authentication-results'] = {
      header = 'Authentication-Results',
      remove = 0,
      add_smtp_user = true,
      stop_chars = ';',
    },
    ['stat-signature'] = {
      header = 'X-Stat-Signature',
      remove = 0,
    },
    ['fuzzy-hashes'] = {
      header = 'X-Rspamd-Fuzzy',
    },
  },
}

-- Names of the routines enabled by configuration, in activation order
local active_routines = {}
-- User supplied routines loaded from the `custom` option, keyed by name
local custom_routines = {}

-- Idempotent symbol callback: runs every enabled routine, collects the headers
-- they want to add/remove and passes the result to lua_mime.modify_headers
local function milter_headers(task)

  -- Set when the routine list comes from per-task user settings
  local settings_override = false

  -- Returns true when the routine registered under `hdr` must NOT produce
  -- output for this task (local/authenticated sender, skip_all, ...)
  local function skip_wanted(hdr)
    if settings_override then
      -- Routines were requested explicitly through user settings, so the
      -- generic skip rules must not suppress them. The original returned
      -- `true` here, which made every overridden routine exit immediately.
      return false
    end
    -- Normal checks
    local function match_extended_headers_rcpt()
      local rcpts = task:get_recipients('smtp')
      if not rcpts then
        return false
      end
      local found
      for _, r in ipairs(rcpts) do
        found = false
        -- Try full addr match
        if r.addr and r.domain and r.user then
          if settings.extended_headers_rcpt:get_key(r.addr) then
            lua_util.debugm(N, task, 'found full addr in recipients for extended headers: %s',
                r.addr)
            found = true
          end
          -- Try user as plain match
          if not found and settings.extended_headers_rcpt:get_key(r.user) then
            lua_util.debugm(N, task, 'found user in recipients for extended headers: %s (%s)',
                r.user, r.addr)
            found = true
          end
          -- Try @domain to match domain
          if not found and settings.extended_headers_rcpt:get_key('@' .. r.domain) then
            lua_util.debugm(N, task, 'found domain in recipients for extended headers: @%s (%s)',
                r.domain, r.addr)
            found = true
          end
        end
        if found then
          break
        end
      end
      return found
    end

    if settings.extended_headers_rcpt and match_extended_headers_rcpt() then
      return false
    end

    if settings.skip_local and not settings.local_headers[hdr] then
      local ip = task:get_ip()
      if (ip and ip:is_local()) then
        return true
      end
    end

    if settings.skip_authenticated and not settings.authenticated_headers[hdr] then
      if task:get_user() ~= nil then
        return true
      end
    end

    if settings.skip_all then
      return true
    end

    return false

  end

  -- XXX: fix this crap one day
  -- routines - are closures that encloses all environment including task
  -- common - a common environment shared between routines
  -- add - add headers table (filled by routines)
  -- remove - remove headers table (filled by routines)
  local routines, common, add, remove = {}, {}, {}, {}

  -- Queue a (folded) value for the header configured for routine `name`
  local function add_header(name, value, stop_chars, order)
    local hname = settings.routines[name].header
    if not add[hname] then
      add[hname] = {}
    end
    table.insert(add[hname], {
      order = (order or settings.default_headers_order or -1),
      value = lua_util.fold_header(task, hname, value, stop_chars)
    })
  end

  routines['x-spamd-result'] = function()
    local local_mod = settings.routines['x-spamd-result']
    if skip_wanted('x-spamd-result') then
      return
    end
    if not common.symbols then
      common.symbols = task:get_symbols_all()
    end
    if not common['metric_score'] then
      common['metric_score'] = task:get_metric_score()
    end
    if not common['metric_action'] then
      common['metric_action'] = task:get_metric_action()
    end
    if local_mod.remove then
      remove[local_mod.header] = local_mod.remove
    end

    local buf = {}
    local verdict = string.format('default: %s [%.2f / %.2f]',
    --TODO: (common.metric_action == 'no action') and 'False' or 'True',
        (common.metric_action == 'reject') and 'True' or 'False',
        common.metric_score[1], common.metric_score[2])
    table.insert(buf, verdict)

    -- Deal with symbols
    table.sort(common.symbols, function(s1, s2)
      local res
      if local_mod.sort_by == 'name' then
        res = s1.name < s2.name
      else
        -- inverse order to show important symbols first
        res = math.abs(s1.score) > math.abs(s2.score)
      end

      return res
    end)

    for _, s in ipairs(common.symbols) do
      local sym_str = string.format('%s(%.2f)[%s]',
          s.name, s.score, table.concat(s.options or {}, ','))
      table.insert(buf, sym_str)
    end
    add_header('x-spamd-result', table.concat(buf, '; '), ';')

    local has_pr, action, message, module = task:has_pre_result()

    if has_pr then
      local pr_header = {}
      if action then
        table.insert(pr_header, string.format('action=%s', action))
      end
      if module then
        table.insert(pr_header, string.format('module=%s', module))
      end
      if message then
        table.insert(pr_header, message)
      end
      add_header('x-rspamd-pre-result', table.concat(pr_header, '; '), ';')
    end
  end

  routines['x-rspamd-queue-id'] = function()
    if skip_wanted('x-rspamd-queue-id') then
      return
    end
    -- `false` is cached in `common` to remember that no queue id is available
    if common.queue_id ~= false then
      common.queue_id = task:get_queue_id()
      if not common.queue_id then
        common.queue_id = false
      end
    end
    if settings.routines['x-rspamd-queue-id'].remove then
      remove[settings.routines['x-rspamd-queue-id'].header] = settings.routines['x-rspamd-queue-id'].remove
    end
    if common.queue_id then
      add[settings.routines['x-rspamd-queue-id'].header] = common.queue_id
    end
  end

  routines['remove-header'] = function()
    if skip_wanted('remove-header') then
      return
    end
    if settings.routines['remove-header'].header and settings.routines['remove-header'].remove then
      remove[settings.routines['remove-header'].header] = settings.routines['remove-header'].remove
    end
  end

  routines['remove-headers'] = function()
    if skip_wanted('remove-headers') then
      return
    end
    for h, r in pairs(settings.routines['remove-headers'].headers) do
      remove[h] = r
    end
  end

  routines['add-headers'] = function()
    if skip_wanted('add-headers') then
      return
    end
    for h, r in pairs(settings.routines['add-headers'].headers) do
      add[h] = r
      remove[h] = settings.routines['add-headers'].remove
    end
  end

  routines['x-rspamd-server'] = function()
    local local_mod = settings.routines['x-rspamd-server']
    if skip_wanted('x-rspamd-server') then
      return
    end
    if local_mod.remove then
      remove[local_mod.header] = local_mod.remove
    end
    local hostname = local_mod.hostname
    add[local_mod.header] = hostname and hostname or HOSTNAME
  end

  routines['x-spamd-bar'] = function()
    local local_mod = settings.routines['x-spamd-bar']
    -- This routine used to be checked under the misspelled name
    -- 'x-rspamd-bar', so an entry for 'x-spamd-bar' in local_headers or
    -- authenticated_headers had no effect. Both spellings are honoured now:
    -- the routine is skipped only when neither of them is allowed.
    if skip_wanted('x-spamd-bar') and skip_wanted('x-rspamd-bar') then
      return
    end
    if not common['metric_score'] then
      common['metric_score'] = task:get_metric_score()
    end
    local score = common['metric_score'][1]
    local spambar
    if score <= -1 then
      spambar = string.rep(local_mod.negative, math.floor(score * -1))
    elseif score >= 1 then
      spambar = string.rep(local_mod.positive, math.floor(score))
    else
      spambar = local_mod.neutral
    end
    if local_mod.remove then
      remove[local_mod.header] = local_mod.remove
    end
    if spambar ~= '' then
      add[local_mod.header] = spambar
    end
  end

  routines['x-spam-level'] = function()
    local local_mod = settings.routines['x-spam-level']
    if skip_wanted('x-spam-level') then
      return
    end
    if not common['metric_score'] then
      common['metric_score'] = task:get_metric_score()
    end
    local score = common['metric_score'][1]
    if score < 1 then
      -- Nothing to add or remove; return values of built-in routines are ignored
      return
    end
    if local_mod.remove then
      remove[local_mod.header] = local_mod.remove
    end
    add[local_mod.header] = string.rep(local_mod.char, math.floor(score))
  end

  routines['x-rspamd-action'] = function()
    local local_mod = settings.routines['x-rspamd-action']
    if skip_wanted('x-rspamd-action') then
      return
    end
    if not common['metric_action'] then
      common['metric_action'] = task:get_metric_action()
    end
    local action = common['metric_action']
    if local_mod.remove then
      remove[local_mod.header] = local_mod.remove
    end
    add[local_mod.header] = action
  end

  -- Adds header `name: value` unless the action is 'no action' or 'greylist'
  local function spam_header (class, name, value, remove_v)
    if skip_wanted(class) then
      return
    end
    if not common['metric_action'] then
      common['metric_action'] = task:get_metric_action()
    end
    if remove_v then
      remove[name] = remove_v
    end
    local action = common['metric_action']
    if action ~= 'no action' and action ~= 'greylist' then
      add[name] = value
    end
  end

  routines['spam-header'] = function()
    spam_header('spam-header',
        settings.routines['spam-header'].header,
        settings.routines['spam-header'].value,
        settings.routines['spam-header'].remove)
  end

  routines['remove-spam-flag'] = function()
    remove[settings.routines['remove-spam-flag'].header] = 0
  end

  routines['x-virus'] = function()
    local local_mod = settings.routines['x-virus']
    if skip_wanted('x-virus') then
      return
    end
    if not common.symbols_hash then
      if not common.symbols then
        common.symbols = task:get_symbols_all()
      end
      local h = {}
      for _, s in ipairs(common.symbols) do
        h[s.name] = s
      end
      common.symbols_hash = h
    end
    if local_mod.remove then
      remove[local_mod.header] = local_mod.remove
    end
    local virii = {}
    for _, sym in ipairs(local_mod.symbols) do
      local s = common.symbols_hash[sym]
      if s then
        if (s.options or E)[1] then
          table.insert(virii, table.concat(s.options, ','))
        else
          table.insert(virii, 'unknown')
        end
      end
    end
    if #virii > 0 then
      local virusstatus = table.concat(virii, ',')
      if local_mod.status_infected then
        virusstatus = local_mod.status_infected .. ', ' .. virusstatus
      end
      add_header('x-virus', virusstatus)
    else
      local failed = false
      local fail_reason = 'unknown'
      for _, sym in ipairs(local_mod.symbols_fail) do
        local s = common.symbols_hash[sym]
        if s then
          failed = true
          if (s.options or E)[1] then
            fail_reason = table.concat(s.options, ',')
          end
        end
      end
      if not failed then
        if local_mod.status_clean then
          add_header('x-virus', local_mod.status_clean)
        end
      else
        -- The failure header is built from status_fail, so that is the option
        -- that has to be set (the original tested status_clean here)
        if local_mod.status_fail then
          add_header('x-virus', string.format('%s(%s)',
              local_mod.status_fail, fail_reason))
        end
      end
    end
  end

  routines['x-os-fingerprint'] = function()
    if skip_wanted('x-os-fingerprint') then
      return
    end
    local local_mod = settings.routines['x-os-fingerprint']

    local os_string, link_type, uptime_min, distance = task:get_mempool():get_variable('os_fingerprint',
        'string, string, double, double');

    if not os_string then
      return
    end

    local value = string.format('%s, (up: %i min), (distance %i, link: %s)',
        os_string, uptime_min, distance, link_type)

    if local_mod.remove then
      remove[local_mod.header] = local_mod.remove
    end

    add_header('x-os-fingerprint', value)
  end

  routines['x-spam-status'] = function()
    if skip_wanted('x-spam-status') then
      return
    end
    if not common['metric_score'] then
      common['metric_score'] = task:get_metric_score()
    end
    if not common['metric_action'] then
      common['metric_action'] = task:get_metric_action()
    end
    local score = common['metric_score'][1]
    local action = common['metric_action']
    local is_spam
    local spamstatus
    if action ~= 'no action' and action ~= 'greylist' then
      is_spam = 'Yes'
    else
      is_spam = 'No'
    end
    spamstatus = is_spam .. ', score=' .. string.format('%.2f', score)

    if settings.routines['x-spam-status'].remove then
      remove[settings.routines['x-spam-status'].header] = settings.routines['x-spam-status'].remove
    end
    add_header('x-spam-status', spamstatus)
  end

  routines['authentication-results'] = function()
    if skip_wanted('authentication-results') then
      return
    end
    local ar = require "lua_auth_results"

    if settings.routines['authentication-results'].remove then
      remove[settings.routines['authentication-results'].header] = settings.routines['authentication-results'].remove
    end

    local res = ar.gen_auth_results(task,
        lua_util.override_defaults(ar.default_settings,
            settings.routines['authentication-results']))

    if res then
      add_header('authentication-results', res, ';', 1)
    end
  end

  routines['stat-signature'] = function()
    if skip_wanted('stat-signature') then
      return
    end
    if settings.routines['stat-signature'].remove then
      remove[settings.routines['stat-signature'].header] = settings.routines['stat-signature'].remove
    end
    local res = task:get_mempool():get_variable("stat_signature")
    if res then
      add[settings.routines['stat-signature'].header] = res
    end
  end

  routines['fuzzy-hashes'] = function()
    local res = task:get_mempool():get_variable("fuzzy_hashes", "fstrings")

    if res and #res > 0 then
      for _, h in ipairs(res) do
        add_header('fuzzy-hashes', h)
      end
    end
  end

  local routines_enabled = active_routines
  local user_settings = task:cache_get('settings')
  if user_settings and user_settings.plugins then
    user_settings = user_settings.plugins.milter_headers or E
  end

  if user_settings and type(user_settings.routines) == 'table' then
    lua_util.debugm(N, task, 'override routines to %s from user settings',
        user_settings.routines)
    routines_enabled = user_settings.routines
    settings_override = true
  end

  for _, n in ipairs(routines_enabled) do
    local ok, err
    if custom_routines[n] then
      local to_add, to_remove, common_in
      ok, err, to_add, to_remove, common_in = pcall(custom_routines[n], task, common)
      if ok then
        -- A custom routine may leave any of its result tables out; treat a
        -- missing table as empty instead of failing inside pairs()
        for k, v in pairs(to_add or E) do
          add[k] = v
        end
        for k, v in pairs(to_remove or E) do
          remove[k] = v
        end
        for k, v in pairs(common_in or E) do
          if type(v) == 'table' then
            if not common[k] then
              common[k] = {}
            end
            for kk, vv in pairs(v) do
              common[k][kk] = vv
            end
          else
            common[k] = v
          end
        end
      end
    else
      ok, err = pcall(routines[n])
    end
    if not ok then
      logger.errx(task, 'call to %s failed: %s', n, err)
    end
  end

  if not next(add) then
    add = nil
  end
  if not next(remove) then
    remove = nil
  end
  if add or remove then

    lua_mime.modify_headers(task, {
      add = add,
      remove = remove
    }, settings.headers_modify_mode)
  end
end

local config_schema = ts.shape({
  use = ts.array_of(ts.string) + ts.string / function(s)
    return { s }
  end,
  remove_upstream_spam_flag = ts.boolean:is_optional(),
  extended_spam_headers = ts.boolean:is_optional(),
  skip_local = ts.boolean:is_optional(),
  skip_authenticated = ts.boolean:is_optional(),
  local_headers = ts.array_of(ts.string):is_optional(),
  authenticated_headers = ts.array_of(ts.string):is_optional(),
  extended_headers_rcpt = lua_maps.map_schema:is_optional(),
  custom = ts.map_of(ts.string, ts.string):is_optional(),
}, {
  extra_fields = ts.map_of(ts.string, ts.any)
})

local opts = rspamd_config:get_all_opt(N) or
    rspamd_config:get_all_opt('rmilter_headers')

if not opts then
  return
end

-- Process config
do
  local res, err = config_schema:transform(opts)
  if not res then
    logger.errx(rspamd_config, 'invalid config for %s: %s', N, err)
    return
  else
    opts = res
  end
end

local have_routine = {}
-- Enables routine `s` once and merges its options from opts.routines[s]
local function activate_routine(s)
  if settings.routines[s] or custom_routines[s] then
    if not have_routine[s] then
      have_routine[s] = true
      table.insert(active_routines, s)
      if (opts.routines and opts.routines[s]) then
        settings.routines[s] = lua_util.override_defaults(settings.routines[s],
            opts.routines[s])
      end
    end
  else
    logger.errx(rspamd_config, 'routine "%s" does not exist', s)
  end
end

if opts.remove_upstream_spam_flag ~= nil then
  settings.remove_upstream_spam_flag = opts.remove_upstream_spam_flag
end

if opts.extended_spam_headers then
  activate_routine('x-spamd-result')
  activate_routine('x-rspamd-server')
  activate_routine('x-rspamd-queue-id')
  activate_routine('x-rspamd-action')
end

if opts.local_headers then
  for _, h in ipairs(opts.local_headers) do
    settings.local_headers[h] = true
  end
end
if opts.authenticated_headers then
  for _, h in ipairs(opts.authenticated_headers) do
    settings.authenticated_headers[h] = true
  end
end
if opts.custom then
  for k, v in pairs(opts['custom']) do
    -- NOTE(review): this executes Lua source taken from the configuration
    local f, err = load(v)
    if not f then
      logger.errx(rspamd_config, 'could not load "%s": %s', k, err)
    else
      custom_routines[k] = f()
    end
  end
end

if type(opts['skip_local']) == 'boolean' then
  settings.skip_local = opts['skip_local']
end

if type(opts['skip_authenticated']) == 'boolean' then
  settings.skip_authenticated = opts['skip_authenticated']
end

if type(opts['skip_all']) == 'boolean' then
  settings.skip_all = opts['skip_all']
end

for _, s in ipairs(opts['use']) do
  if not have_routine[s] then
    activate_routine(s)
  end
end

if settings.remove_upstream_spam_flag then
  activate_routine('remove-spam-flag')
end

if (#active_routines < 1) then
  logger.errx(rspamd_config, 'no active routines')
  return
end

logger.infox(rspamd_config, 'active routines [%s]',
    table.concat(active_routines, ','))

if opts.extended_headers_rcpt then
  settings.extended_headers_rcpt = lua_maps.rspamd_map_add_from_ucl(opts.extended_headers_rcpt,
      'set', 'Extended headers recipients')
end

rspamd_config:register_symbol({
  name = 'MILTER_HEADERS',
  type = 'idempotent',
  callback = milter_headers,
  flags = 'empty,explicit_disable,ignore_passthrough',
})
diff --git a/src/plugins/lua/mime_types.lua b/src/plugins/lua/mime_types.lua new file mode 100644 index 0000000..167ed38 --- /dev/null +++ b/src/plugins/lua/mime_types.lua @@ -0,0 +1,737 @@
--[[
Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

if confighelp then
  return
end

-- This plugin implements mime types checks for mail messages
local logger = require "rspamd_logger"
local lua_util = require "lua_util"
local rspamd_util = require "rspamd_util"
local lua_maps = require "lua_maps"
local lua_mime_types = require "lua_mime_types"
local lua_magic_types = require "lua_magic/types"
local fun = require "fun"

local N = "mime_types"
local settings = {
  file = '',
  symbol_unknown = 'MIME_UNKNOWN',
  symbol_bad = 'MIME_BAD',
  symbol_good = 'MIME_GOOD',
  symbol_attachment = 'MIME_BAD_ATTACHMENT',
  symbol_encrypted_archive = 'MIME_ENCRYPTED_ARCHIVE',
  symbol_obfuscated_archive = 'MIME_OBFUSCATED_ARCHIVE',
  symbol_exe_in_gen_split_rar = 'MIME_EXE_IN_GEN_SPLIT_RAR',
  symbol_archive_in_archive = 'MIME_ARCHIVE_IN_ARCHIVE',
  symbol_double_extension = 'MIME_DOUBLE_BAD_EXTENSION',
  symbol_bad_extension = 'MIME_BAD_EXTENSION',
  symbol_bad_unicode = 'MIME_BAD_UNICODE',
  regexp = false,
  extension_map = { -- extension -> mime_type
    html = 'text/html',
    htm = 'text/html',
    pdf = 'application/pdf',
    shtm = 'text/html',
    shtml = 'text/html',
    txt = 'text/plain'
  },

  bad_extensions = {
    cue = 2,
    exe = 1,
    iso = 4,
    jar = 2,
    zpaq = 2,
    -- In contrast to HTML MIME parts, dedicated HTML attachments are considered harmful
    htm = 1,
    html = 1,
    shtm = 1,
    shtml = 1,
    -- Have you ever seen that in legit email?
    ace = 4,
    arj = 2,
    aspx = 1,
    asx = 2,
    cab = 3,
    dll = 4,
    dqy = 2,
    iqy = 2,
    mht = 2,
    mhtml = 2,
    oqy = 2,
    rqy = 2,
    sfx = 2,
    slk = 2,
    vst = 2,
    vss = 2,
    wim = 2,
    -- Additional bad extensions from Gmail
    ade = 4,
    adp = 4,
    cmd = 4,
    cpl = 4,
    ins = 4,
    isp = 4,
    js = 4,
    jse = 4,
    lib = 4,
    mde = 4,
    msc = 4,
    msi = 4,
    msp = 4,
    mst = 4,
    nsh = 4,
    pif = 4,
    sct = 4,
    shb = 4,
    sys = 4,
    vb = 4,
    vbe = 4,
    vbs = 4,
    vxd = 4,
    wsc = 4,
    wsh = 4,
    -- Additional bad extensions from Outlook
    app = 4,
    asp = 4,
    bas = 4,
    bat = 4,
    chm = 4,
    cnt = 4,
    com = 4,
    csh = 4,
    diagcab = 4,
    fxp = 4,
    gadget = 4,
    grp = 4,
    hlp = 4,
    hpj = 4,
    hta = 4,
    htc = 4,
    inf = 4,
    its = 4,
    jnlp = 4,
    lnk = 4,
    ksh = 4,
    mad = 4,
    maf = 4,
    mag = 4,
    mam = 4,
    maq = 4,
    mar = 4,
    mas = 4,
    mat = 4,
    mau = 4,
    mav = 4,
    maw = 4,
    mcf = 4,
    mda = 4,
    mdb = 4,
    mdt = 4,
    mdw = 4,
    mdz = 4,
    msh = 4,
    msh1 = 4,
    msh2 = 4,
    mshxml = 4,
    msh1xml = 4,
    msh2xml = 4,
    msu = 4,
    ops = 4,
    osd = 4,
    pcd = 4,
    pl = 4,
    plg = 4,
    prf = 4,
    prg = 4,
    printerexport = 4,
    ps1 = 4,
    ps1xml = 4,
    ps2 = 4,
    ps2xml = 4,
    psc1 = 4,
    psc2 = 4,
    psd1 = 4,
    psdm1 = 4,
    pst = 4,
    pyc = 4,
    pyo = 4,
    pyw = 4,
    pyz = 4,
    pyzw = 4,
    reg = 4,
    scf = 4,
    scr = 4,
    shs = 4,
    theme = 4,
    url = 4,
    vbp = 4,
    vhd = 4,
    vhdx = 4,
    vsmacros = 4,
    vsw = 4,
    webpnp = 4,
    website = 4,
    ws = 4,
    wsf = 4,
    xbap = 4,
    xll = 4,
    xnk = 4,
  },

  -- Something that should not be in archive
  bad_archive_extensions = {
    docx = 0.1,
    hta = 4,
    jar = 3,
    js = 0.5,
    pdf = 0.1,
    pptx = 0.1,
    vbs = 4,
    wsf = 4,
    xlsx = 0.1,
  },

  archive_extensions = {
    ['7z'] = 1,
    ace = 1,
    alz = 1,
    arj = 1,
    bz2 = 1,
    cab = 1,
    egg = 1,
    lz = 1,
    rar = 1,
    xz = 1,
    zip = 1,
    zpaq = 1,
  },

  -- Not really archives
  archive_exceptions = {
    docx = true,
    odp = true,
    ods = true,
    odt = true,
    pptx = true,
    vsdx = true,
    xlsx = true,
    -- jar = true,
  },

  -- Multiplier for full extension_map mismatch
  other_extensions_mult = 0.4,
}

local map = nil

-- Symbol callback: inspects every MIME part of the task (declared content
-- type, attachment file name, archive contents) and inserts the mime_types
-- symbols configured in `settings`
local function check_mime_type(task)
  -- Returns the last extension, the one before it (both lowercased, nil when
  -- absent) and the list of dot separated parts of `fname`
  local function gen_extension(fname)
    local parts = lua_util.str_split(fname or '', '.')

    local ext = {}
    for n = 1, 2 do
      ext[n] = #parts > n and string.lower(parts[#parts + 1 - n]) or nil
    end

    return ext[1], ext[2], parts
  end

  -- Checks a single file name.
  --   fname        file name (attachment name or name of a file in an archive)
  --   ct           declared content type, may be nil
  --   is_archive   true when `fname` comes from an archive listing
  --   part         the MIME part the name belongs to
  --   detected_ext extension detected from the content, may be nil
  --   nfiles       number of files in the enclosing archive (1 otherwise)
  local function check_filename(fname, ct, is_archive, part, detected_ext, nfiles)

    lua_util.debugm(N, task, "check filename: %s, ct=%s, is_archive=%s, detected_ext=%s, nfiles=%s",
        fname, ct, is_archive, detected_ext, nfiles)
    local has_bad_unicode, char, ch_pos = rspamd_util.has_obscured_unicode(fname)
    if has_bad_unicode then
      -- '%x' already prints the code point in hex; the original '%xd' appended
      -- a stray literal 'd' to it
      task:insert_result(settings.symbol_bad_unicode, 1.0,
          string.format("0x%x after %s", char,
              fname:sub(1, ch_pos)))
    end

    -- Decode hex encoded characters
    fname = string.gsub(fname, '%%(%x%x)',
        function(hex)
          return string.char(tonumber(hex, 16))
        end)

    -- Replace potentially bad characters with '?'
    fname = fname:gsub('[^%s%g]', '?')

    -- Check file is in filename whitelist
    if settings.filename_whitelist and
        settings.filename_whitelist:get_key(fname) then
      logger.debugm("mime_types", task, "skip checking of %s - file is in filename whitelist",
          fname)
      return
    end

    local ext, ext2, parts = gen_extension(fname)
    -- ext is the last extension, LOWERCASED
    -- ext2 is the one before last extension LOWERCASED

    local detected

    if not is_archive and detected_ext then
      detected = lua_magic_types[detected_ext]
    end

    if detected_ext and ((not ext) or ext ~= detected_ext) then
      -- Try to find extension by real content type.
      -- `detected` is nil when the extension has no lua_magic_types entry,
      -- so do not index it blindly
      check_filename('detected.' .. detected_ext, detected and detected.ct or nil,
          false, part, nil, 1)
    end

    if not ext then
      return
    end

    -- Inserts the double/bad extension symbols; each argument is the badness
    -- of `ext` / `ext2` or nil when that extension is not listed
    local function check_extension(badness_mult, badness_mult2)
      if not badness_mult and not badness_mult2 then
        return
      end
      if #parts > 2 then
        -- We need to ensure that next-to-last extension is an extension,
        -- so we check for its length and if it is not a number or date
        if #ext2 > 0 and #ext2 <= 4 and not string.match(ext2, '^%d+[%]%)]?$') then

          -- Use the greatest badness multiplier
          if not badness_mult or
              (badness_mult2 and badness_mult < badness_mult2) then
            badness_mult = badness_mult2
          end

          -- Double extension + bad extension == VERY bad
          task:insert_result(settings['symbol_double_extension'], badness_mult,
              string.format(".%s.%s", ext2, ext))
          task:insert_result('MIME_TRACE', 0.0,
              string.format("%s:%s", part:get_id(), '-'))
          return
        end
      end
      if badness_mult then
        -- Just bad extension
        task:insert_result(settings['symbol_bad_extension'], badness_mult, ext)
        task:insert_result('MIME_TRACE', 0.0,
            string.format("%s:%s", part:get_id(), '-'))
      end
    end

    -- Process settings
    local extra_table = {}
    local extra_archive_table = {}
    local user_settings = task:cache_get('settings')
    if user_settings and user_settings.plugins then
      user_settings = user_settings.plugins.mime_types
    end

    if user_settings then
      logger.infox(task, 'using special tables from user settings')
      if user_settings.bad_extensions then
        if user_settings.bad_extensions[1] then
          -- Convert to a key-value map
          extra_table = fun.tomap(
              fun.map(function(e)
                return e, 1.0
              end,
                  user_settings.bad_extensions))
        else
          extra_table = user_settings.bad_extensions
        end
      end
      if user_settings.bad_archive_extensions then
        if user_settings.bad_archive_extensions[1] then
          -- Convert to a key-value map
          extra_archive_table = fun.tomap(fun.map(
              function(e)
                return e, 1.0
              end,
              user_settings.bad_archive_extensions))
        else
          extra_archive_table = user_settings.bad_archive_extensions
        end
      end
    end

    -- Badness of extension `e`; user tables take precedence over defaults
    local function check_tables(e)
      if is_archive then
        return extra_archive_table[e] or (nfiles < 2 and settings.bad_archive_extensions[e]) or
            extra_table[e] or settings.bad_extensions[e]
      end

      return extra_table[e] or settings.bad_extensions[e]
    end

    -- Also check for archive bad extension
    if is_archive then
      if ext2 then
        local score1 = check_tables(ext)
        local score2 = check_tables(ext2)
        check_extension(score1, score2)
      else
        local score1 = check_tables(ext)
        check_extension(score1, nil)
      end

      if settings['archive_extensions'][ext] then
        -- Archive in archive
        task:insert_result(settings['symbol_archive_in_archive'], 1.0, ext)
        task:insert_result('MIME_TRACE', 0.0,
            string.format("%s:%s", part:get_id(), '-'))
      end
    else
      if ext2 then
        local score1 = check_tables(ext)
        local score2 = check_tables(ext2)
        check_extension(score1, score2)
        -- Check for archive cloaking like .zip.gz
        if settings['archive_extensions'][ext2]
            -- Exclude multipart archive extensions, e.g. .zip.001
            and not string.match(ext, '^%d+$')
        then
          task:insert_result(settings['symbol_archive_in_archive'],
              1.0, string.format(".%s.%s", ext2, ext))
          task:insert_result('MIME_TRACE', 0.0,
              string.format("%s:%s", part:get_id(), '-'))
        end
      else
        local score1 = check_tables(ext)
        check_extension(score1, nil)
      end
    end

    local mt = settings['extension_map'][ext]
    if mt and ct and ct ~= 'application/octet-stream' then
      local found
      local mult
      for _, v in ipairs(mt) do
        mult = v.mult
        if ct == v.ct then
          found = true
          break
        end
      end

      if not found then
        task:insert_result(settings['symbol_attachment'], mult, string.format('%s:%s',
            ext, ct))
      end
    end
  end

  local parts = task:get_parts()

  if parts then
    for _, p in ipairs(parts) do
      local mtype, subtype = p:get_type()

      if not mtype then
        lua_util.debugm(N, task, "no content type for part: %s", p:get_id())
        task:insert_result(settings['symbol_unknown'], 1.0, 'missing content type')
        task:insert_result('MIME_TRACE', 0.0,
            string.format("%s:%s", p:get_id(), '~'))
      else
        -- Check for attachment
        local filename = p:get_filename()
        local ct = string.format('%s/%s', mtype, subtype):lower()
        local detected_ext = p:get_detected_ext()

        if filename then
          check_filename(filename, ct, false, p, detected_ext, 1)
        end

        if p:is_archive() then
          local check = true
          if detected_ext then
            local detected_type = lua_magic_types[detected_ext]

            -- Guard against extensions without a lua_magic_types entry; the
            -- message has three placeholders, so pass the detected extension too
            if detected_type and detected_type.type ~= 'archive' then
              logger.debugm("mime_types", task, "skip checking of %s as archive, %s is not archive but %s",
                  filename, detected_ext, detected_type.type)
              check = false
            end
          end
          if check and filename then
            local ext = gen_extension(filename)

            if ext and settings.archive_exceptions[ext] then
              check = false
              logger.debugm("mime_types", task, "skip checking of %s as archive, %s is whitelisted",
                  filename, ext)
            end
          end
          local arch = p:get_archive()

          -- TODO: migrate to flags once C part is ready
          if arch:is_encrypted() then
            task:insert_result(settings.symbol_encrypted_archive, 1.0, filename)
            task:insert_result('MIME_TRACE', 0.0,
                string.format("%s:%s", p:get_id(), '-'))
          elseif arch:is_unreadable() then
            task:insert_result(settings.symbol_encrypted_archive, 0.5, {
              'compressed header',
              filename,
            })
            task:insert_result('MIME_TRACE', 0.0,
                string.format("%s:%s", p:get_id(), '-'))
          elseif arch:is_obfuscated() then
            task:insert_result(settings.symbol_obfuscated_archive, 1.0, {
              'obfuscated archive',
              filename,
            })
            task:insert_result('MIME_TRACE', 0.0,
                string.format("%s:%s", p:get_id(), '-'))
          end

          if check then
            local is_gen_split_rar = false
            if filename then
              local ext = gen_extension(filename)
              is_gen_split_rar = ext and (string.match(ext, '^%d%d%d$')) and (arch:get_type() == 'rar')
            end

            local fl = arch:get_files_full(1000)

            local nfiles = #fl

            for _, f in ipairs(fl) do
              if f['encrypted'] then
                task:insert_result(settings['symbol_encrypted_archive'],
                    1.0, f['name'])
                task:insert_result('MIME_TRACE', 0.0,
                    string.format("%s:%s", p:get_id(), '-'))
              end

              if f['name'] then
                if is_gen_split_rar and (gen_extension(f['name']) or '') == 'exe' then
                  task:insert_result(settings['symbol_exe_in_gen_split_rar'], 1.0, f['name'])
                else
                  check_filename(f['name'], nil,
                      true, p, nil, nfiles)
                end
              end
            end

            if nfiles == 1 and fl[1].name then
              -- We check that extension of the file inside archive is
              -- the same as double extension of the file
              local _, ext2 = gen_extension(filename)

              if ext2 and #ext2 > 0 then
                local enc_ext = gen_extension(fl[1].name)

                if enc_ext
                    and settings['bad_extensions'][enc_ext]
                    and not tonumber(ext2)
                    and enc_ext ~= ext2 then
                  task:insert_result(settings['symbol_double_extension'], 2.0,
                      string.format("%s!=%s", ext2, enc_ext))
                end
              end
            end
          end
        end

        if map then
          local v = map:get_key(ct)
          local detected_different = false

          local detected_type
          if detected_ext then
            detected_type = lua_magic_types[detected_ext]
          end

          if detected_type and detected_type.ct ~= ct then
            local v_detected = map:get_key(detected_type.ct)
            -- NOTE(review): if the map returns strings this `>` compares them
            -- lexicographically, not numerically - confirm against the map type
            if not v or v_detected and v_detected > v then
              v = v_detected
            end
            detected_different = true
          end
          if v then
            local n = tonumber(v)

            if n then
              if n > 0 then
                if detected_different then
                  -- Penalize case
                  n = n * 1.5
                  task:insert_result(settings['symbol_bad'], n,
                      string.format('%s:%s', ct, detected_type.ct))
                else
                  task:insert_result(settings['symbol_bad'], n, ct)
                end
                task:insert_result('MIME_TRACE', 0.0,
                    string.format("%s:%s", p:get_id(), '-'))
              elseif n < 0 then
                task:insert_result(settings['symbol_good'], -n, ct)
                task:insert_result('MIME_TRACE', 0.0,
                    string.format("%s:%s", p:get_id(), '+'))
              else
                -- Neutral content type
                task:insert_result('MIME_TRACE', 0.0,
                    string.format("%s:%s", p:get_id(), '~'))
              end
            else
              logger.warnx(task, 'unknown value: "%s" for content type %s in the map',
                  v, ct)
            end
          else
            task:insert_result(settings['symbol_unknown'], 1.0, ct)
            task:insert_result('MIME_TRACE', 0.0,
                string.format("%s:%s", p:get_id(), '~'))
          end
        end
      end
    end
  end
end

local opts = rspamd_config:get_all_opt('mime_types')
if opts then
  for k, v in pairs(opts) do
    settings[k] = v
  end

  settings.filename_whitelist = lua_maps.rspamd_map_add('mime_types', 'filename_whitelist', 'regexp',
      'filename whitelist')

  local function change_extension_map_entry(ext, ct, mult)
    if type(ct) == 'table' then
      local tbl = {}
      for _, elt in ipairs(ct) do
        table.insert(tbl, {
          ct = elt,
          mult = mult,
        })
      end
      settings.extension_map[ext] = tbl
    else
      settings.extension_map[ext] = { [1] = {
        ct = ct,
        mult = mult
      } }
    end
  end

  -- Transform extension_map
  for ext, ct in pairs(settings.extension_map) do
    change_extension_map_entry(ext, ct, 1.0)
  end

  -- Add all extensions
  for _, pair in
ipairs(lua_mime_types.full_extensions_map) do + local ext, ct = pair[1], pair[2] + if not settings.extension_map[ext] then + change_extension_map_entry(ext, ct, settings.other_extensions_mult) + end + end + + local map_type = 'map' + if settings['regexp'] then + map_type = 'regexp' + end + map = lua_maps.rspamd_map_add('mime_types', 'file', map_type, + 'mime types map') + if map then + local id = rspamd_config:register_symbol({ + name = 'MIME_TYPES_CALLBACK', + callback = check_mime_type, + type = 'callback', + flags = 'nostat', + group = 'mime_types', + }) + + rspamd_config:register_symbol({ + type = 'virtual', + name = settings['symbol_unknown'], + parent = id, + group = 'mime_types', + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = settings['symbol_bad'], + parent = id, + group = 'mime_types', + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = settings['symbol_good'], + flags = 'nice', + parent = id, + group = 'mime_types', + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = settings['symbol_attachment'], + parent = id, + group = 'mime_types', + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = settings['symbol_encrypted_archive'], + parent = id, + group = 'mime_types', + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = settings['symbol_obfuscated_archive'], + parent = id, + group = 'mime_types', + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = settings['symbol_exe_in_gen_split_rar'], + parent = id, + group = 'mime_types', + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = settings['symbol_archive_in_archive'], + parent = id, + group = 'mime_types', + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = settings['symbol_double_extension'], + parent = id, + group = 'mime_types', + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = settings['symbol_bad_extension'], + parent = id, + group = 'mime_types', + }) + 
rspamd_config:register_symbol({ + type = 'virtual', + name = settings['symbol_bad_unicode'], + parent = id, + group = 'mime_types', + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = 'MIME_TRACE', + parent = id, + group = 'mime_types', + flags = 'nostat', + score = 0, + }) + else + lua_util.disable_module(N, "config") + end +end diff --git a/src/plugins/lua/multimap.lua b/src/plugins/lua/multimap.lua new file mode 100644 index 0000000..53b2732 --- /dev/null +++ b/src/plugins/lua/multimap.lua @@ -0,0 +1,1403 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +if confighelp then + return +end + +-- Multimap is rspamd module designed to define and operate with different maps + +local rules = {} +local rspamd_logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local rspamd_regexp = require "rspamd_regexp" +local rspamd_expression = require "rspamd_expression" +local rspamd_ip = require "rspamd_ip" +local lua_util = require "lua_util" +local lua_selectors = require "lua_selectors" +local lua_maps = require "lua_maps" +local redis_params +local fun = require "fun" +local N = 'multimap' + +local multimap_grammar +-- Parse result in form: <symbol>:<score>|<symbol>|<score> +local function parse_multimap_value(parse_rule, p_ret) + if p_ret and type(p_ret) == 'string' then + local lpeg = require "lpeg" + + if not multimap_grammar then + local number = {} + + local digit = lpeg.R("09") + number.integer = (lpeg.S("+-") ^ -1) * + (digit ^ 1) + + -- Matches: .6, .899, .9999873 + number.fractional = (lpeg.P(".")) * + (digit ^ 1) + + -- Matches: 55.97, -90.8, .9 + number.decimal = (number.integer * -- Integer + (number.fractional ^ -1)) + -- Fractional + (lpeg.S("+-") * number.fractional) -- Completely fractional number + + local sym_start = lpeg.R("az", "AZ") + lpeg.S("_") + local sym_elt = sym_start + lpeg.R("09") + local symbol = sym_start * sym_elt ^ 0 + local symbol_cap = lpeg.Cg(symbol, 'symbol') + local score_cap = lpeg.Cg(number.decimal, 'score') + local opts_cap = lpeg.Cg(lpeg.Ct(lpeg.C(symbol) * (lpeg.P(",") * lpeg.C(symbol)) ^ 0), 'opts') + local symscore_cap = (symbol_cap * lpeg.P(":") * score_cap) + local symscoreopt_cap = symscore_cap * lpeg.P(":") * opts_cap + local grammar = symscoreopt_cap + symscore_cap + symbol_cap + score_cap + multimap_grammar = lpeg.Ct(grammar) + end + local tbl = multimap_grammar:match(p_ret) + + if tbl then + local sym + local score = 1.0 + local opts = {} + + if tbl.symbol then + sym = tbl.symbol + end + if tbl.score then + score = tonumber(tbl.score) + end + 
if tbl.opts then + opts = tbl.opts + end + + return true, sym, score, opts + else + if p_ret ~= '' then + rspamd_logger.infox(rspamd_config, '%s: cannot parse string "%s"', + parse_rule.symbol, p_ret) + end + + return true, nil, 1.0, {} + end + elseif type(p_ret) == 'boolean' then + return p_ret, nil, 1.0, {} + end + + return false, nil, 0.0, {} +end + +local value_types = { + ip = { + get_value = function(ip) + return ip:to_string() + end, + }, + from = { + get_value = function(val) + return val + end, + }, + helo = { + get_value = function(val) + return val + end, + }, + header = { + get_value = function(val) + return val + end, + }, + rcpt = { + get_value = function(val) + return val + end, + }, + user = { + get_value = function(val) + return val + end, + }, + url = { + get_value = function(val) + return val + end, + }, + dnsbl = { + get_value = function(ip) + return ip:to_string() + end, + }, + filename = { + get_value = function(val) + return val + end, + }, + content = { + get_value = function() + return nil + end, + }, + hostname = { + get_value = function(val) + return val + end, + }, + asn = { + get_value = function(val) + return val + end, + }, + country = { + get_value = function(val) + return val + end, + }, + received = { + get_value = function(val) + return val + end, + }, + mempool = { + get_value = function(val) + return val + end, + }, + selector = { + get_value = function(val) + return val + end, + }, + symbol_options = { + get_value = function(val) + return val + end, + }, +} + +local function ip_to_rbl(ip, rbl) + return table.concat(ip:inversed_str_octets(), ".") .. '.' .. 
rbl +end + +local function apply_hostname_filter(task, filter, hostname, r) + if filter == 'tld' then + local tld = rspamd_util.get_tld(hostname) + return tld + elseif filter == 'top' then + local tld = rspamd_util.get_tld(hostname) + return tld:match('[^.]*$') or tld + else + if not r['re_filter'] then + local pat = string.match(filter, 'tld:regexp:(.+)') + if not pat then + rspamd_logger.errx(task, 'bad search filter: %s', filter) + return + end + r['re_filter'] = rspamd_regexp.create_cached(pat) + if not r['re_filter'] then + rspamd_logger.errx(task, 'couldnt create regex: %s', pat) + return + end + end + local tld = rspamd_util.get_tld(hostname) + local res = r['re_filter']:search(tld) + if res then + return res[1] + else + return nil + end + end +end + +local function apply_url_filter(task, filter, url, r) + if not filter then + return url:get_host() + end + + if filter == 'tld' then + return url:get_tld() + elseif filter == 'top' then + local tld = url:get_tld() + return tld:match('[^.]*$') or tld + elseif filter == 'full' then + return url:get_text() + elseif filter == 'is_phished' then + if url:is_phished() then + return url:get_host() + else + return nil + end + elseif filter == 'is_redirected' then + if url:is_redirected() then + return url:get_host() + else + return nil + end + elseif filter == 'is_obscured' then + if url:is_obscured() then + return url:get_host() + else + return nil + end + elseif filter == 'path' then + return url:get_path() + elseif filter == 'query' then + return url:get_query() + elseif string.find(filter, 'tag:') then + local tags = url:get_tags() + local want_tag = string.match(filter, 'tag:(.*)') + for _, t in ipairs(tags) do + if t == want_tag then + return url:get_host() + end + end + return nil + elseif string.find(filter, 'tld:regexp:') then + if not r['re_filter'] then + local type, pat = string.match(filter, '(regexp:)(.+)') + if type and pat then + r['re_filter'] = rspamd_regexp.create_cached(pat) + end + end + + if not 
r['re_filter'] then + rspamd_logger.errx(task, 'bad search filter: %s', filter) + else + local results = r['re_filter']:search(url:get_tld()) + if results then + return results[1] + else + return nil + end + end + elseif string.find(filter, 'full:regexp:') then + if not r['re_filter'] then + local type, pat = string.match(filter, '(regexp:)(.+)') + if type and pat then + r['re_filter'] = rspamd_regexp.create_cached(pat) + end + end + + if not r['re_filter'] then + rspamd_logger.errx(task, 'bad search filter: %s', filter) + else + local results = r['re_filter']:search(url:get_text()) + if results then + return results[1] + else + return nil + end + end + elseif string.find(filter, 'regexp:') then + if not r['re_filter'] then + local type, pat = string.match(filter, '(regexp:)(.+)') + if type and pat then + r['re_filter'] = rspamd_regexp.create_cached(pat) + end + end + + if not r['re_filter'] then + rspamd_logger.errx(task, 'bad search filter: %s', filter) + else + local results = r['re_filter']:search(url:get_host()) + if results then + return results[1] + else + return nil + end + end + elseif string.find(filter, '^template:') then + if not r['template'] then + r['template'] = string.match(filter, '^template:(.+)') + end + + if r['template'] then + return lua_util.template(r['template'], url:to_table()) + end + end + + return url:get_host() +end + +local function apply_addr_filter(task, filter, input, rule) + if filter == 'email:addr' or filter == 'email' then + local addr = rspamd_util.parse_mail_address(input, task:get_mempool(), 1024) + if addr and addr[1] then + return fun.totable(fun.map(function(a) + return a.addr + end, addr)) + end + elseif filter == 'email:user' then + local addr = rspamd_util.parse_mail_address(input, task:get_mempool(), 1024) + if addr and addr[1] then + return fun.totable(fun.map(function(a) + return a.user + end, addr)) + end + elseif filter == 'email:domain' then + local addr = rspamd_util.parse_mail_address(input, 
task:get_mempool(), 1024) + if addr and addr[1] then + return fun.totable(fun.map(function(a) + return a.domain + end, addr)) + end + elseif filter == 'email:domain:tld' then + local addr = rspamd_util.parse_mail_address(input, task:get_mempool(), 1024) + if addr and addr[1] then + return fun.totable(fun.map(function(a) + return rspamd_util.get_tld(a.domain) + end, addr)) + end + elseif filter == 'email:name' then + local addr = rspamd_util.parse_mail_address(input, task:get_mempool(), 1024) + if addr and addr[1] then + return fun.totable(fun.map(function(a) + return a.name + end, addr)) + end + elseif filter == 'ip_addr' then + local ip_addr = rspamd_ip.from_string(input) + + if ip_addr and ip_addr:is_valid() then + return ip_addr + end + else + -- regexp case + if not rule['re_filter'] then + local type, pat = string.match(filter, '(regexp:)(.+)') + if type and pat then + rule['re_filter'] = rspamd_regexp.create_cached(pat) + end + end + + if not rule['re_filter'] then + rspamd_logger.errx(task, 'bad search filter: %s', filter) + else + local results = rule['re_filter']:search(input) + if results then + return results[1] + end + end + end + + return input +end +local function apply_filename_filter(task, filter, fn, r) + if filter == 'extension' or filter == 'ext' then + return string.match(fn, '%.([^.]+)$') + elseif string.find(filter, 'regexp:') then + if not r['re_filter'] then + local type, pat = string.match(filter, '(regexp:)(.+)') + if type and pat then + r['re_filter'] = rspamd_regexp.create_cached(pat) + end + end + + if not r['re_filter'] then + rspamd_logger.errx(task, 'bad search filter: %s', filter) + else + local results = r['re_filter']:search(fn) + if results then + return results[1] + else + return nil + end + end + end + + return fn +end + +local function apply_regexp_filter(task, filter, fn, r) + if string.find(filter, 'regexp:') then + if not r['re_filter'] then + local type, pat = string.match(filter, '(regexp:)(.+)') + if type and pat then + 
r['re_filter'] = rspamd_regexp.create_cached(pat) + end + end + + if not r['re_filter'] then + rspamd_logger.errx(task, 'bad search filter: %s', filter) + else + local results = r['re_filter']:search(fn, false, true) + if results then + return results[1][2] + else + return nil + end + end + end + + return fn +end + +local function apply_content_filter(task, filter) + if filter == 'body' then + return { task:get_rawbody() } + elseif filter == 'full' then + return { task:get_content() } + elseif filter == 'headers' then + return { task:get_raw_headers() } + elseif filter == 'text' then + local ret = {} + for _, p in ipairs(task:get_text_parts()) do + table.insert(ret, p:get_content()) + end + return ret + elseif filter == 'rawtext' then + local ret = {} + for _, p in ipairs(task:get_text_parts()) do + table.insert(ret, p:get_content('raw_parsed')) + end + return ret + elseif filter == 'oneline' then + local ret = {} + for _, p in ipairs(task:get_text_parts()) do + table.insert(ret, p:get_content_oneline()) + end + return ret + else + rspamd_logger.errx(task, 'bad search filter: %s', filter) + end + + return {} +end + +local multimap_filters = { + from = apply_addr_filter, + rcpt = apply_addr_filter, + helo = apply_hostname_filter, + symbol_options = apply_regexp_filter, + header = apply_addr_filter, + url = apply_url_filter, + filename = apply_filename_filter, + mempool = apply_regexp_filter, + selector = apply_regexp_filter, + hostname = apply_hostname_filter, + --content = apply_content_filter, -- Content filters are special :( +} + +local function multimap_query_redis(key, task, value, callback) + local cmd = 'HGET' + if type(value) == 'userdata' and value.class == 'rspamd{ip}' then + cmd = 'HMGET' + end + + local srch = { key } + + -- Insert all ips for some mask :( + if type(value) == 'userdata' and value.class == 'rspamd{ip}' then + srch[#srch + 1] = tostring(value) + -- IPv6 case + local maxbits = 128 + local minbits = 64 + if value:get_version() == 4 then + 
maxbits = 32 + minbits = 8 + end + for i = maxbits, minbits, -1 do + local nip = value:apply_mask(i):tostring() .. "/" .. i + srch[#srch + 1] = nip + end + else + srch[#srch + 1] = value + end + + local function redis_map_cb(err, data) + lua_util.debugm(N, task, 'got reply from Redis when trying to get key %s: err=%s, data=%s', + key, err, data) + if not err and type(data) ~= 'userdata' then + callback(data) + end + end + + return rspamd_redis_make_request(task, + redis_params, -- connect params + key, -- hash key + false, -- is write + redis_map_cb, --callback + cmd, -- command + srch -- arguments + ) +end + +local function multimap_callback(task, rule) + local function match_element(r, value, callback) + if not value then + return false + end + + local function get_key_callback(ret, err_or_data, err_code) + lua_util.debugm(N, task, 'got return "%s" (err code = %s) for multimap %s', + err_or_data, + err_code, + rule.symbol) + + if ret then + if type(err_or_data) == 'table' then + for _, elt in ipairs(err_or_data) do + callback(elt) + end + else + callback(err_or_data) + end + elseif err_code ~= 404 then + rspamd_logger.infox(task, "map %s: get key returned error %s: %s", + rule.symbol, err_code, err_or_data) + end + end + + lua_util.debugm(N, task, 'check value %s for multimap %s', value, + rule.symbol) + + local ret = false + + if r.redis_key then + -- Deal with hash name here: it can be either plain string or a selector + if type(r.redis_key) == 'string' then + ret = multimap_query_redis(r.redis_key, task, value, callback) + else + -- Here we have a selector + local results = r.redis_key(task) + + -- Here we need to spill this function into multiple queries + if type(results) == 'table' then + for _, res in ipairs(results) do + ret = multimap_query_redis(res, task, value, callback) + + if not ret then + break + end + end + else + ret = multimap_query_redis(results, task, value, callback) + end + end + + return ret + elseif r.map_obj then + 
r.map_obj:get_key(value, get_key_callback, task) + end + end + + local function insert_results(result, opt) + local _, symbol, score, opts = parse_multimap_value(rule, result) + local forced = false + if symbol then + if rule.symbols_set then + if not rule.symbols_set[symbol] then + rspamd_logger.infox(task, 'symbol %s is not registered for map %s, ' .. + 'replace it with just %s', + symbol, rule.symbol, rule.symbol) + symbol = rule.symbol + end + elseif rule.disable_multisymbol then + symbol = rule.symbol + if type(opt) == 'table' then + table.insert(opt, result) + elseif type(opt) ~= nil then + opt = { opt, result } + else + opt = { result } + end + else + forced = not rule.dynamic_symbols + end + else + symbol = rule.symbol + end + + if opts and #opts > 0 then + -- Options come from the map itself + task:insert_result(forced, symbol, score, opts) + else + if opt then + if type(opt) == 'table' then + task:insert_result(forced, symbol, score, fun.totable(fun.map(tostring, opt))) + else + task:insert_result(forced, symbol, score, tostring(opt)) + end + + else + task:insert_result(forced, symbol, score) + end + end + + if rule.action then + local message = rule.message + if rule.message_func then + message = rule.message_func(task, rule.symbol, opt) + end + if message then + task:set_pre_result(rule.action, message, N) + else + task:set_pre_result(rule.action, 'Matched map: ' .. 
rule.symbol, N) + end + end + end + + -- Match a single value for against a single rule + local function match_rule(r, value) + local function rule_callback(result) + if result then + if type(result) == 'table' then + for _, rs in ipairs(result) do + if type(rs) ~= 'userdata' then + rule_callback(rs) + end + end + return + end + local opt = value_types[r['type']].get_value(value) + insert_results(result, opt) + end + end + + if r.filter or r.type == 'url' then + local fn = multimap_filters[r.type] + + if fn then + + local filtered_value = fn(task, r.filter, value, r) + lua_util.debugm(N, task, 'apply filter %s for rule %s: %s -> %s', + r.filter, r.symbol, value, filtered_value) + value = filtered_value + end + end + + if type(value) == 'table' then + fun.each(function(elt) + match_element(r, elt, rule_callback) + end, value) + else + match_element(r, value, rule_callback) + end + end + + -- Match list of values according to the field + local function match_list(r, ls, fields) + if ls then + if fields then + fun.each(function(e) + local match = e[fields[1]] + if match then + if fields[2] then + match = fields[2](match) + end + match_rule(r, match) + end + end, ls) + else + fun.each(function(e) + match_rule(r, e) + end, ls) + end + end + end + + local function match_addr(r, addr) + match_list(r, addr, { 'addr' }) + + if not r.filter then + match_list(r, addr, { 'domain' }) + match_list(r, addr, { 'user' }) + end + end + + local function match_url(r, url) + match_rule(r, url) + end + + local function match_hostname(r, hostname) + match_rule(r, hostname) + end + + local function match_filename(r, fn) + match_rule(r, fn) + end + + local function match_received_header(r, pos, total, h) + local use_tld = false + local filter = r['filter'] or 'real_ip' + if filter:match('^tld:') then + filter = filter:sub(5) + use_tld = true + end + local v = h[filter] + if v then + local min_pos = tonumber(r['min_pos']) + local max_pos = tonumber(r['max_pos']) + if min_pos then + if 
min_pos < 0 then + if min_pos == -1 then + if (pos ~= total) then + return + end + else + if pos <= (total - (min_pos * -1)) then + return + end + end + elseif pos < min_pos then + return + end + end + if max_pos then + if max_pos < -1 then + if (total - (max_pos * -1)) >= pos then + return + end + elseif max_pos > 0 then + if pos > max_pos then + return + end + end + end + local match_flags = r['flags'] + local nmatch_flags = r['nflags'] + if match_flags or nmatch_flags then + local got_flags = h['flags'] + if match_flags then + for _, flag in ipairs(match_flags) do + if not got_flags[flag] then + return + end + end + end + if nmatch_flags then + for _, flag in ipairs(nmatch_flags) do + if got_flags[flag] then + return + end + end + end + end + if filter == 'real_ip' or filter == 'from_ip' then + if type(v) == 'string' then + v = rspamd_ip.from_string(v) + end + if v and v:is_valid() then + match_rule(r, v) + end + else + if use_tld and type(v) == 'string' then + v = rspamd_util.get_tld(v) + end + match_rule(r, v) + end + end + end + + local function match_content(r) + local data + + if r['filter'] then + data = apply_content_filter(task, r['filter'], r) + else + data = { task:get_content() } + end + + for _, v in ipairs(data) do + match_rule(r, v) + end + end + + if rule.expression and not rule.combined then + local res, trace = rule['expression']:process_traced(task) + + if not res or res == 0 then + lua_util.debugm(N, task, 'condition is false for %s', + rule.symbol) + return + else + lua_util.debugm(N, task, 'condition is true for %s: %s', + rule.symbol, + trace) + end + end + + local process_rule_funcs = { + ip = function() + local ip = task:get_from_ip() + if ip and ip:is_valid() then + match_rule(rule, ip) + end + end, + dnsbl = function() + local ip = task:get_from_ip() + if ip and ip:is_valid() then + local to_resolve = ip_to_rbl(ip, rule['map']) + local function dns_cb(_, _, results, err) + lua_util.debugm(N, rspamd_config, + 'resolve() finished: 
results=%1, err=%2, to_resolve=%3', + results, err, to_resolve) + + if err and + (err ~= 'requested record is not found' and + err ~= 'no records with this name') then + rspamd_logger.errx(task, 'error looking up %s: %s', to_resolve, results) + elseif results then + task:insert_result(rule['symbol'], 1, rule['map']) + if rule.action then + task:set_pre_result(rule['action'], + 'Matched map: ' .. rule['symbol'], N) + end + end + end + + task:get_resolver():resolve_a({ + task = task, + name = to_resolve, + callback = dns_cb, + forced = true + }) + end + end, + header = function() + if type(rule['header']) == 'table' then + for _, rh in ipairs(rule['header']) do + local hv = task:get_header_full(rh) + match_list(rule, hv, { 'decoded' }) + end + else + local hv = task:get_header_full(rule['header']) + match_list(rule, hv, { 'decoded' }) + end + end, + rcpt = function() + if task:has_recipients('smtp') then + local rcpts = task:get_recipients('smtp') + match_addr(rule, rcpts) + elseif task:has_recipients('mime') then + local rcpts = task:get_recipients('mime') + match_addr(rule, rcpts) + end + end, + from = function() + if task:has_from('smtp') then + local from = task:get_from('smtp') + match_addr(rule, from) + elseif task:has_from('mime') then + local from = task:get_from('mime') + match_addr(rule, from) + end + end, + helo = function() + local helo = task:get_helo() + if helo then + match_hostname(rule, helo) + end + end, + url = function() + if task:has_urls() then + local msg_urls = task:get_urls() + + for _, url in ipairs(msg_urls) do + match_url(rule, url) + end + end + end, + user = function() + local user = task:get_user() + if user then + match_rule(rule, user) + end + end, + filename = function() + local parts = task:get_parts() + + local function filter_parts(p) + return p:is_attachment() or (not p:is_text()) and (not p:is_multipart()) + end + + local function filter_archive(p) + local ext = p:get_detected_ext() + local det_type = 'unknown' + + if ext then + 
local lua_magic_types = require "lua_magic/types" + local det_t = lua_magic_types[ext] + + if det_t then + det_type = det_t.type + end + end + + return p:is_archive() and det_type == 'archive' and not rule.skip_archives + end + + for _, p in fun.iter(fun.filter(filter_parts, parts)) do + if filter_archive(p) then + local fnames = p:get_archive():get_files(1000) + + for _, fn in ipairs(fnames) do + match_filename(rule, fn) + end + end + + local fn = p:get_filename() + if fn then + match_filename(rule, fn) + end + -- Also deal with detected content type + if not rule.skip_detected then + local ext = p:get_detected_ext() + + if ext then + local fake_fname = string.format('detected.%s', ext) + lua_util.debugm(N, task, 'detected filename %s', + fake_fname) + match_filename(rule, fake_fname) + end + end + end + end, + + content = function() + match_content(rule) + end, + hostname = function() + local hostname = task:get_hostname() + if hostname then + match_hostname(rule, hostname) + end + end, + asn = function() + local asn = task:get_mempool():get_variable('asn') + if asn then + match_rule(rule, asn) + end + end, + country = function() + local country = task:get_mempool():get_variable('country') + if country then + match_rule(rule, country) + end + end, + mempool = function() + local var = task:get_mempool():get_variable(rule['variable']) + if var then + match_rule(rule, var) + end + end, + symbol_options = function() + local sym = task:get_symbol(rule['target_symbol']) + if sym and sym[1].options then + for _, o in ipairs(sym[1].options) do + match_rule(rule, o) + end + end + end, + received = function() + local hdrs = task:get_received_headers() + if hdrs and hdrs[1] then + if not rule['artificial'] then + hdrs = fun.filter(function(h) + return not h['flags']['artificial'] + end, hdrs):totable() + end + for pos, h in ipairs(hdrs) do + match_received_header(rule, pos, #hdrs, h) + end + end + end, + selector = function() + local elts = rule.selector(task) + + if elts 
then + if type(elts) == 'table' then + for _, elt in ipairs(elts) do + match_rule(rule, elt) + end + else + match_rule(rule, elts) + end + end + end, + combined = function() + local ret, trace = rule.combined:process(task) + if ret and ret ~= 0 then + for n, t in pairs(trace) do + insert_results(t.value, string.format("%s=%s", + n, t.matched)) + end + end + end, + } + + local rt = rule.type + local process_func = process_rule_funcs[rt] + if process_func then + process_func() + else + rspamd_logger.errx(task, 'Unrecognised rule type: %s', rt) + end +end + +local function gen_multimap_callback(rule) + return function(task) + multimap_callback(task, rule) + end +end + +local function multimap_on_load_gen(rule) + return function() + lua_util.debugm(N, rspamd_config, "loaded map object for rule %s", rule['symbol']) + local known_symbols = {} + rule.map_obj:foreach(function(key, value) + local r, symbol, score, _ = parse_multimap_value(rule, value) + + if r and symbol and not known_symbols[symbol] then + lua_util.debugm(N, rspamd_config, "%s: adding new symbol %s (score = %s), triggered by %s", + rule.symbol, symbol, score, key) + rspamd_config:register_symbol { + name = value, + parent = rule.callback_id, + type = 'virtual', + score = score, + } + rspamd_config:set_metric_symbol({ + group = N, + score = 1.0, -- In future, we will parse score from `get_value` and use it as multiplier + description = 'Automatic symbol generated by rule: ' .. 
rule.symbol, + name = value, + }) + known_symbols[value] = true + end + end) + end +end + +local function add_multimap_rule(key, newrule) + local ret = false + + local function multimap_load_kv_map(rule) + if rule['regexp'] then + if rule['multi'] then + rule.map_obj = lua_maps.map_add_from_ucl(rule.map, 'regexp_multi', + rule.description) + else + rule.map_obj = lua_maps.map_add_from_ucl(rule.map, 'regexp', + rule.description) + end + elseif rule['glob'] then + if rule['multi'] then + rule.map_obj = lua_maps.map_add_from_ucl(rule.map, 'glob_multi', + rule.description) + else + rule.map_obj = lua_maps.map_add_from_ucl(rule.map, 'glob', + rule.description) + end + else + rule.map_obj = lua_maps.map_add_from_ucl(rule.map, 'hash', + rule.description) + end + end + + local known_generic_types = { + header = true, + rcpt = true, + from = true, + helo = true, + symbol_options = true, + filename = true, + url = true, + user = true, + content = true, + hostname = true, + asn = true, + country = true, + mempool = true, + selector = true, + combined = true + } + + if newrule['message_func'] then + newrule['message_func'] = assert(load(newrule['message_func']))() + end + if newrule['url'] and not newrule['map'] then + newrule['map'] = newrule['url'] + end + if not (newrule.map or newrule.rules) then + rspamd_logger.errx(rspamd_config, 'incomplete rule, missing map') + return nil + end + if not newrule['symbol'] and key then + newrule['symbol'] = key + elseif not newrule['symbol'] then + rspamd_logger.errx(rspamd_config, 'incomplete rule, missing symbol') + return nil + end + if not newrule['description'] then + newrule['description'] = string.format('multimap, type %s: %s', newrule['type'], + newrule['symbol']) + end + if newrule['type'] == 'mempool' and not newrule['variable'] then + rspamd_logger.errx(rspamd_config, 'mempool map requires variable') + return nil + end + if newrule['type'] == 'selector' then + if not newrule['selector'] then + 
rspamd_logger.errx(rspamd_config, 'selector map requires selector definition') + return nil + else + local selector = lua_selectors.create_selector_closure( + rspamd_config, newrule['selector'], newrule['delimiter'] or "") + + if not selector then + rspamd_logger.errx(rspamd_config, 'selector map has invalid selector: "%s", symbol: %s', + newrule['selector'], newrule['symbol']) + return nil + end + + newrule.selector = selector + end + end + if type(newrule['map']) == 'string' and + string.find(newrule['map'], '^redis://.*$') then + if not redis_params then + rspamd_logger.infox(rspamd_config, 'no redis servers are specified, ' .. + 'cannot add redis map %s: %s', newrule['symbol'], newrule['map']) + return nil + end + + newrule['redis_key'] = string.match(newrule['map'], '^redis://(.*)$') + + if newrule['redis_key'] then + ret = true + end + elseif type(newrule['map']) == 'string' and + string.find(newrule['map'], '^redis%+selector://.*$') then + if not redis_params then + rspamd_logger.infox(rspamd_config, 'no redis servers are specified, ' .. + 'cannot add redis map %s: %s', newrule['symbol'], newrule['map']) + return nil + end + + local selector_str = string.match(newrule['map'], '^redis%+selector://(.*)$') + local selector = lua_selectors.create_selector_closure( + rspamd_config, selector_str, newrule['delimiter'] or "") + + if not selector then + rspamd_logger.errx(rspamd_config, 'redis selector map has invalid selector: "%s", symbol: %s', + selector_str, newrule['symbol']) + return nil + end + + newrule['redis_key'] = selector + ret = true + elseif newrule.type == 'combined' then + local lua_maps_expressions = require "lua_maps_expressions" + newrule.combined = lua_maps_expressions.create(rspamd_config, + { + rules = newrule.rules, + expression = newrule.expression, + on_load = newrule.dynamic_symbols and multimap_on_load_gen(newrule) or nil, + }, N, 'Combined map for ' .. 
newrule.symbol) + if not newrule.combined then + rspamd_logger.errx(rspamd_config, 'cannot add combined map for %s', newrule.symbol) + else + ret = true + end + else + if newrule['type'] == 'ip' then + newrule.map_obj = lua_maps.map_add_from_ucl(newrule.map, 'radix', + newrule.description) + if newrule.map_obj then + ret = true + else + rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %1', + newrule['map']) + end + elseif newrule['type'] == 'received' then + if type(newrule['flags']) == 'table' and newrule['flags'][1] then + newrule['flags'] = newrule['flags'] + elseif type(newrule['flags']) == 'string' then + newrule['flags'] = { newrule['flags'] } + end + if type(newrule['nflags']) == 'table' and newrule['nflags'][1] then + newrule['nflags'] = newrule['nflags'] + elseif type(newrule['nflags']) == 'string' then + newrule['nflags'] = { newrule['nflags'] } + end + local filter = newrule['filter'] or 'real_ip' + if filter == 'real_ip' or filter == 'from_ip' then + newrule.map_obj = lua_maps.map_add_from_ucl(newrule.map, 'radix', + newrule.description) + if newrule.map_obj then + ret = true + else + rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %1', + newrule['map']) + end + else + multimap_load_kv_map(newrule) + + if newrule.map_obj then + ret = true + else + rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %1', + newrule['map']) + end + end + elseif known_generic_types[newrule.type] then + + if newrule.filter == 'ip_addr' then + newrule.map_obj = lua_maps.map_add_from_ucl(newrule.map, 'radix', + newrule.description) + elseif not newrule.combined then + multimap_load_kv_map(newrule) + end + + if newrule.map_obj then + ret = true + else + rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %1', + newrule['map']) + end + elseif newrule['type'] == 'dnsbl' then + ret = true + end + end + + if ret then + if newrule.map_obj and newrule.dynamic_symbols then + 
newrule.map_obj:on_load(multimap_on_load_gen(newrule)) + end + if newrule['type'] == 'symbol_options' then + rspamd_config:register_dependency(newrule['symbol'], newrule['target_symbol']) + end + if newrule['require_symbols'] then + local atoms = {} + + local function parse_atom(str) + local atom = table.concat(fun.totable(fun.take_while(function(c) + if string.find(', \t()><+!|&\n', c, 1, true) then + return false + end + return true + end, fun.iter(str))), '') + table.insert(atoms, atom) + return atom + end + + local function process_atom(atom, task) + local f_ret = task:has_symbol(atom) + lua_util.debugm(N, rspamd_config, 'check for symbol %s: %s', atom, f_ret) + + if f_ret then + return 1 + end + + return 0 + end + + local expression = rspamd_expression.create(newrule['require_symbols'], + { parse_atom, process_atom }, rspamd_config:get_mempool()) + if expression then + newrule['expression'] = expression + + fun.each(function(v) + lua_util.debugm(N, rspamd_config, 'add dependency %s -> %s', + newrule['symbol'], v) + rspamd_config:register_dependency(newrule['symbol'], v) + end, atoms) + end + end + return newrule + end + + return nil +end + +-- Registration +local opts = rspamd_config:get_all_opt(N) +if opts and type(opts) == 'table' then + redis_params = rspamd_parse_redis_server(N) + for k, m in pairs(opts) do + if type(m) == 'table' and m['type'] then + local rule = add_multimap_rule(k, m) + if not rule then + rspamd_logger.errx(rspamd_config, 'cannot add rule: "' .. k .. 
'"') + else + rspamd_logger.infox(rspamd_config, 'added multimap rule: %s (%s)', + k, rule.type) + table.insert(rules, rule) + end + end + end + -- add fake symbol to check all maps inside a single callback + fun.each(function(rule) + local augmentations = {} + + if rule.action then + table.insert(augmentations, 'passthrough') + end + + local id = rspamd_config:register_symbol({ + type = 'normal', + name = rule['symbol'], + augmentations = augmentations, + callback = gen_multimap_callback(rule), + }) + + rule.callback_id = id + + if rule['symbols'] then + -- Find allowed symbols by this map + rule['symbols_set'] = {} + fun.each(function(s) + rspamd_config:register_symbol({ + type = 'virtual', + name = s, + parent = id, + score = tonumber(rule.score or "0") or 0, -- Default score + }) + rule['symbols_set'][s] = 1 + end, rule['symbols']) + end + if not rule.score then + rspamd_logger.infox(rspamd_config, 'set default score 0 for multimap rule %s', rule.symbol) + rule.score = 0 + end + if rule.score then + -- Register metric symbol + rule.name = rule.symbol + rule.description = rule.description or 'multimap symbol' + rule.group = rule.group or N + + local tmp_flags + tmp_flags = rule.flags + + if rule.type == 'received' and rule.flags then + -- XXX: hack to allow received flags/nflags + -- See issue #3526 on GH + rule.flags = nil + end + + -- XXX: for combined maps we use trace, so flags must include one_shot to avoid scores multiplication + if rule.combined and not rule.flags then + rule.flags = 'one_shot' + end + rspamd_config:set_metric_symbol(rule) + rule.flags = tmp_flags + end + end, rules) + + if #rules == 0 then + lua_util.disable_module(N, "config") + end +end diff --git a/src/plugins/lua/mx_check.lua b/src/plugins/lua/mx_check.lua new file mode 100644 index 0000000..71892b9 --- /dev/null +++ b/src/plugins/lua/mx_check.lua @@ -0,0 +1,392 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the 
"License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +if confighelp then + return +end + +-- MX check plugin +local rspamd_logger = require "rspamd_logger" +local rspamd_tcp = require "rspamd_tcp" +local rspamd_util = require "rspamd_util" +local lua_util = require "lua_util" +local lua_redis = require "lua_redis" +local N = "mx_check" +local fun = require "fun" + +local settings = { + timeout = 1.0, -- connect timeout + symbol_bad_mx = 'MX_INVALID', + symbol_no_mx = 'MX_MISSING', + symbol_good_mx = 'MX_GOOD', + symbol_white_mx = 'MX_WHITE', + expire = 86400, -- 1 day by default + expire_novalid = 7200, -- 2 hours by default for no valid mxes + greylist_invalid = true, -- Greylist first message with invalid MX (require greylist plugin) + key_prefix = 'rmx', + max_mx_a_records = 5, -- Maximum number of A records to check per MX request + wait_for_greeting = false, -- Wait for SMTP greeting and emit `quit` command +} +local redis_params +local exclude_domains + +local E = {} +local CRLF = '\r\n' +local mx_miss_cache_prefix = 'mx_miss:' + +local function mx_check(task) + local ip_addr = task:get_ip() + if task:get_user() or (ip_addr and ip_addr:is_local()) then + return + end + + local from = task:get_from('smtp') + local mx_domain + if ((from or E)[1] or E).domain and not from[2] then + mx_domain = from[1]['domain'] + else + mx_domain = task:get_helo() + + if mx_domain then + mx_domain = rspamd_util.get_tld(mx_domain) + end + end + + if not mx_domain then + return + end + + if exclude_domains then + if exclude_domains:get_key(mx_domain) then 
+ rspamd_logger.infox(task, 'skip mx check for %s, excluded', mx_domain) + task:insert_result(settings.symbol_white_mx, 1.0, mx_domain) + return + end + end + + local valid = false + + local function check_results(mxes) + if fun.all(function(_, elt) + return elt.checked + end, mxes) then + -- Save cache + local key = settings.key_prefix .. mx_domain + local function redis_cache_cb(err) + if err ~= nil then + rspamd_logger.errx(task, 'redis_cache_cb received error: %1', err) + return + end + end + if not valid then + -- Greylist message + if settings.greylist_invalid then + task:get_mempool():set_variable("grey_greylisted_required", "1") + lua_util.debugm(N, task, "advice to greylist a message") + task:insert_result(settings.symbol_bad_mx, 1.0, "greylisted") + else + task:insert_result(settings.symbol_bad_mx, 1.0) + end + local ret = rspamd_redis_make_request(task, + redis_params, -- connect params + key, -- hash key + true, -- is write + redis_cache_cb, --callback + 'SETEX', -- command + { key, tostring(settings.expire_novalid), '0' } -- arguments + ) + lua_util.debugm(N, task, "set redis cache key: %s; invalid MX", key) + if not ret then + rspamd_logger.errx(task, 'got error connecting to redis') + end + else + local valid_mx = {} + fun.each(function(k) + table.insert(valid_mx, k) + end, fun.filter(function(_, elt) + return elt.working + end, mxes)) + task:insert_result(settings.symbol_good_mx, 1.0, valid_mx) + local value = table.concat(valid_mx, ';') + if mxes[mx_domain] and type(mxes[mx_domain]) == 'table' and mxes[mx_domain].mx_missing then + value = mx_miss_cache_prefix .. 
value + end + local ret = rspamd_redis_make_request(task, + redis_params, -- connect params + key, -- hash key + true, -- is write + redis_cache_cb, --callback + 'SETEX', -- command + { key, tostring(settings.expire), value } -- arguments + ) + lua_util.debugm(N, task, "set redis cache key: %s; %s", key, value) + if not ret then + rspamd_logger.errx(task, 'error connecting to redis') + end + end + end + end + + local function gen_mx_a_callback(name, mxes) + return function(_, _, results, err) + lua_util.debugm(N, task, "got DNS results for %s: %s", name, results) + mxes[name].ips = results + + local function io_cb(io_err, _, conn) + lua_util.debugm(N, task, "TCP IO callback for %s, error: %s", name, io_err) + if io_err then + mxes[name].checked = true + conn:close() + else + mxes[name].checked = true + mxes[name].working = true + valid = true + if settings.wait_for_greeting then + conn:add_write(function(_) + conn:close() + end, string.format('QUIT%s', CRLF)) + end + end + check_results(mxes) + end + local function on_connect_cb(conn) + lua_util.debugm(N, task, "TCP connect callback for %s, error: %s", name, err) + if err then + mxes[name].checked = true + conn:close() + check_results(mxes) + else + mxes[name].checked = true + valid = true + mxes[name].working = true + end + + -- Disconnect without SMTP dialog + if not settings.wait_for_greeting then + check_results(mxes) + conn:close() + end + end + + if err or not results or #results == 0 then + mxes[name].checked = true + else + -- Try to open TCP connection to port 25 for a random IP address + -- see #3839 on GitHub + lua_util.shuffle(results) + local str_ip = results[1]:to_string() + lua_util.debugm(N, task, "trying to connect to IP %s", str_ip) + local t_ret = rspamd_tcp.new({ + task = task, + host = str_ip, + callback = io_cb, + stop_pattern = CRLF, + on_connect = on_connect_cb, + timeout = settings.timeout, + port = 25 + }) + + if not t_ret then + mxes[name].checked = true + end + end + check_results(mxes) 
+ end + end + + local function mx_callback(_, _, results, err) + local mxes = {} + if err or not results then + local r = task:get_resolver() + -- XXX: maybe add ipv6? + -- fallback to implicit mx + if not err and not results then + err = 'no MX records found' + end + + lua_util.debugm(N, task, "cannot find MX record for %s: %s, use implicit fallback", + mx_domain, err) + mxes[mx_domain] = { checked = false, working = false, ips = {}, mx_missing = true } + r:resolve('a', { + name = mx_domain, + callback = gen_mx_a_callback(mx_domain, mxes), + task = task, + forced = true + }) + task:insert_result(settings.symbol_no_mx, 1.0, err) + else + -- Inverse sort by priority + table.sort(results, function(r1, r2) + return r1['priority'] > r2['priority'] + end) + + local max_mx_to_resolve = math.min(#results, settings.max_mx_a_records) + lua_util.debugm(N, task, 'check %s MX records (%d actually returned)', + max_mx_to_resolve, #results) + for i = 1, max_mx_to_resolve do + local mx = results[i] + mxes[mx.name] = { checked = false, working = false, ips = {} } + local r = task:get_resolver() + -- XXX: maybe add ipv6? 
+ r:resolve('a', { + name = mx.name, + callback = gen_mx_a_callback(mx.name, mxes), + task = task, + forced = true + }) + end + check_results(mxes) + end + end + + if not redis_params then + local r = task:get_resolver() + r:resolve('mx', { + name = mx_domain, + callback = mx_callback, + task = task, + forced = true + }) + else + local function redis_cache_get_cb(err, data) + if err or type(data) ~= 'string' then + local r = task:get_resolver() + r:resolve('mx', { + name = mx_domain, + callback = mx_callback, + task = task, + forced = true + }) + else + if data == '0' then + task:insert_result(settings.symbol_bad_mx, 1.0, 'cached') + else + if lua_util.str_startswith(data, mx_miss_cache_prefix) then + task:insert_result(settings.symbol_no_mx, 1.0, 'cached') + data = string.sub(data, #mx_miss_cache_prefix + 1) + end + local mxes = lua_util.str_split(data, ';') + task:insert_result(settings.symbol_good_mx, 1.0, 'cached: ' .. mxes[1]) + end + end + end + + local key = settings.key_prefix .. mx_domain + local ret = rspamd_redis_make_request(task, + redis_params, -- connect params + key, -- hash key + false, -- is write + redis_cache_get_cb, --callback + 'GET', -- command + { key } -- arguments + ) + + if not ret then + local r = task:get_resolver() + r:resolve('mx', { + name = mx_domain, + callback = mx_callback, + task = task, + forced = true + }) + end + end +end + +-- Module setup +local opts = rspamd_config:get_all_opt('mx_check') +if not (opts and type(opts) == 'table') then + rspamd_logger.infox(rspamd_config, 'module is unconfigured') + return +end +if opts then + redis_params = lua_redis.parse_redis_server('mx_check') + if not redis_params then + rspamd_logger.errx(rspamd_config, 'no redis servers are specified, disabling module') + lua_util.disable_module(N, "redis") + return + end + + settings = lua_util.override_defaults(settings, opts) + lua_redis.register_prefix(settings.key_prefix .. 
'*', N, + 'MX check cache', { + type = 'string', + }) + + local id = rspamd_config:register_symbol({ + name = settings.symbol_bad_mx, + type = 'normal', + callback = mx_check, + flags = 'empty', + augmentations = { string.format("timeout=%f", settings.timeout + rspamd_config:get_dns_timeout() or 0.0) }, + }) + rspamd_config:register_symbol({ + name = settings.symbol_no_mx, + type = 'virtual', + parent = id + }) + rspamd_config:register_symbol({ + name = settings.symbol_good_mx, + type = 'virtual', + parent = id + }) + rspamd_config:register_symbol({ + name = settings.symbol_white_mx, + type = 'virtual', + parent = id + }) + + rspamd_config:set_metric_symbol({ + name = settings.symbol_bad_mx, + score = 0.5, + description = 'Domain has no working MX', + group = 'MX', + one_shot = true, + one_param = true, + }) + rspamd_config:set_metric_symbol({ + name = settings.symbol_good_mx, + score = -0.01, + description = 'Domain has working MX', + group = 'MX', + one_shot = true, + one_param = true, + }) + rspamd_config:set_metric_symbol({ + name = settings.symbol_white_mx, + score = 0.0, + description = 'Domain is whitelisted from MX check', + group = 'MX', + one_shot = true, + one_param = true, + }) + rspamd_config:set_metric_symbol({ + name = settings.symbol_no_mx, + score = 3.5, + description = 'Domain has no resolvable MX', + group = 'MX', + one_shot = true, + one_param = true, + }) + + if settings.exclude_domains then + exclude_domains = rspamd_config:add_map { + type = 'set', + description = 'Exclude specific domains from MX checks', + url = settings.exclude_domains, + } + end +end diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua new file mode 100644 index 0000000..f3b26f1 --- /dev/null +++ b/src/plugins/lua/neural.lua @@ -0,0 +1,1000 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + + +if confighelp then + return +end + +local fun = require "fun" +local lua_redis = require "lua_redis" +local lua_util = require "lua_util" +local lua_verdict = require "lua_verdict" +local neural_common = require "plugins/neural" +local rspamd_kann = require "rspamd_kann" +local rspamd_logger = require "rspamd_logger" +local rspamd_tensor = require "rspamd_tensor" +local rspamd_text = require "rspamd_text" +local rspamd_util = require "rspamd_util" +local ts = require("tableshape").types + +local N = "neural" + +local settings = neural_common.settings + +local redis_profile_schema = ts.shape { + digest = ts.string, + symbols = ts.array_of(ts.string), + version = ts.number, + redis_key = ts.string, + distance = ts.number:is_optional(), +} + +local has_blas = rspamd_tensor.has_blas() +local text_cookie = rspamd_text.cookie + +-- Creates and stores ANN profile in Redis +local function new_ann_profile(task, rule, set, version) + local ann_key = neural_common.new_ann_key(rule, set, version, settings) + + local profile = { + symbols = set.symbols, + redis_key = ann_key, + version = version, + digest = set.digest, + distance = 0 -- Since we are using our own profile + } + + local ucl = require "ucl" + local profile_serialized = ucl.to_format(profile, 'json-compact', true) + + local function add_cb(err, _) + if err then + rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s', + rule.prefix, set.name, profile.redis_key, err) + else + rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s', + rule.prefix, set.name, 
profile.redis_key) + end + end + + lua_redis.redis_make_request(task, + rule.redis, + nil, + true, -- is write + add_cb, --callback + 'ZADD', -- command + { set.prefix, tostring(rspamd_util.get_time()), profile_serialized } + ) + + return profile +end + + +-- ANN filter function, used to insert scores based on the existing symbols +local function ann_scores_filter(task) + + for _, rule in pairs(settings.rules) do + local sid = task:get_settings_id() or -1 + local ann + local profile + + local set = neural_common.get_rule_settings(task, rule) + if set then + if set.ann then + ann = set.ann.ann + profile = set.ann + else + lua_util.debugm(N, task, 'no ann loaded for %s:%s', + rule.prefix, set.name) + end + else + lua_util.debugm(N, task, 'no ann defined in %s for settings id %s', + rule.prefix, sid) + end + + if ann then + local vec = neural_common.result_to_vector(task, profile) + + local score + local out = ann:apply1(vec, set.ann.pca) + score = out[1] + + local symscore = string.format('%.3f', score) + task:cache_set(rule.prefix .. '_neural_score', score) + lua_util.debugm(N, task, '%s:%s:%s ann score: %s', + rule.prefix, set.name, set.ann.version, symscore) + + if score > 0 then + local result = score + + -- If spam_score_threshold is defined, override all other thresholds. + local spam_threshold = 0 + if rule.spam_score_threshold then + spam_threshold = rule.spam_score_threshold + elseif rule.roc_enabled and not set.ann.roc_thresholds then + spam_threshold = set.ann.roc_thresholds[1] + end + + if result >= spam_threshold then + if rule.flat_threshold_curve then + task:insert_result(rule.symbol_spam, 1.0, symscore) + else + task:insert_result(rule.symbol_spam, result, symscore) + end + else + lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)', + rule.prefix, set.name, set.ann.version, symscore, + spam_threshold) + end + else + local result = -(score) + + -- If ham_score_threshold is defined, override all other thresholds. 
+ local ham_threshold = 0 + if rule.ham_score_threshold then + ham_threshold = rule.ham_score_threshold + elseif rule.roc_enabled and not set.ann.roc_thresholds then + ham_threshold = set.ann.roc_thresholds[2] + end + + if result >= ham_threshold then + if rule.flat_threshold_curve then + task:insert_result(rule.symbol_ham, 1.0, symscore) + else + task:insert_result(rule.symbol_ham, result, symscore) + end + else + lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)', + rule.prefix, set.name, set.ann.version, result, + ham_threshold) + end + end + end + end +end + +local function ann_push_task_result(rule, task, verdict, score, set) + local train_opts = rule.train + local learn_spam, learn_ham + local skip_reason = 'unknown' + + if not train_opts.store_pool_only and train_opts.autotrain then + if train_opts.spam_score then + learn_spam = score >= train_opts.spam_score + + if not learn_spam then + skip_reason = string.format('score < spam_score: %f < %f', + score, train_opts.spam_score) + end + else + learn_spam = verdict == 'spam' or verdict == 'junk' + + if not learn_spam then + skip_reason = string.format('verdict: %s', + verdict) + end + end + + if train_opts.ham_score then + learn_ham = score <= train_opts.ham_score + if not learn_ham then + skip_reason = string.format('score > ham_score: %f > %f', + score, train_opts.ham_score) + end + else + learn_ham = verdict == 'ham' + + if not learn_ham then + skip_reason = string.format('verdict: %s', + verdict) + end + end + else + -- Train by request header + local hdr = task:get_request_header('ANN-Train') + + if hdr then + if hdr:lower() == 'spam' then + learn_spam = true + elseif hdr:lower() == 'ham' then + learn_ham = true + else + skip_reason = 'no explicit header' + end + elseif train_opts.store_pool_only then + local ucl = require "ucl" + learn_ham = false + learn_spam = false + + -- Explicitly store tokens in cache + local vec = neural_common.result_to_vector(task, set) + 
task:cache_set(rule.prefix .. '_neural_vec_mpack', ucl.to_format(vec, 'msgpack')) + task:cache_set(rule.prefix .. '_neural_profile_digest', set.digest) + skip_reason = 'store_pool_only has been set' + end + end + + if learn_spam or learn_ham then + local learn_type + if learn_spam then + learn_type = 'spam' + else + learn_type = 'ham' + end + + local function vectors_len_cb(err, data) + if not err and type(data) == 'table' then + local nspam, nham = data[1], data[2] + + if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then + local vec = neural_common.result_to_vector(task, set) + + local str = rspamd_util.zstd_compress(table.concat(vec, ';')) + local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set' + + local function learn_vec_cb(redis_err) + if redis_err then + rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s', + rule.prefix, set.name, redis_err) + else + lua_util.debugm(N, task, + "add train data for ANN rule " .. + "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed", + rule.prefix, set.name, learn_type, #vec, target_key, #str) + end + end + + lua_redis.redis_make_request(task, + rule.redis, + nil, + true, -- is write + learn_vec_cb, --callback + 'SADD', -- command + { target_key, str } -- arguments + ) + else + lua_util.debugm(N, task, + "do not add %s train data for ANN rule " .. + "%s:%s", + learn_type, rule.prefix, set.name) + end + else + if err then + rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s', + rule.prefix, set.name, err) + elseif type(data) == 'string' then + -- nil return value + rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s", + learn_type, rule.prefix, set.name, set.ann.redis_key, data) + else + rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' .. 
+ 'please remove this key from Redis manually if you perform upgrade from the previous version', + rule.prefix, set.name, set.ann.redis_key, type(data)) + end + end + end + + -- Check if we can learn + if set.can_store_vectors then + if not set.ann then + -- Need to create or load a profile corresponding to the current configuration + set.ann = new_ann_profile(task, rule, set, 0) + lua_util.debugm(N, task, + 'requested new profile for %s, set.ann is missing', + set.name) + end + + lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len, + { task = task, is_write = false }, + vectors_len_cb, + { + set.ann.redis_key, + }) + else + lua_util.debugm(N, task, + 'do not push data: train condition not satisfied; reason: not checked existing ANNs') + end + else + lua_util.debugm(N, task, + 'do not push data to key %s: train condition not satisfied; reason: %s', + (set.ann or {}).redis_key, + skip_reason) + end +end + +--- Offline training logic + +-- Utility to extract and split saved training vectors to a table of tables +local function process_training_vectors(data) + return fun.totable(fun.map(function(tok) + local _, str = rspamd_util.zstd_decompress(tok) + return fun.totable(fun.map(tonumber, lua_util.str_split(tostring(str), ';'))) + end, data)) +end + +-- This function does the following: +-- * Tries to lock ANN +-- * Loads spam and ham vectors +-- * Spawn learning process +local function do_train_ann(worker, ev_base, rule, set, ann_key) + local spam_elts = {} + local ham_elts = {} + + local function redis_ham_cb(err, data) + if err or type(data) ~= 'table' then + rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s', + ann_key, err) + -- Unlock on error + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + neural_common.gen_unlock_cb(rule, set, ann_key), --callback + 'HDEL', -- command + { ann_key, 'lock' } + ) + else + -- Decompress and convert to numbers each 
training vector + ham_elts = process_training_vectors(data) + neural_common.spawn_train({ worker = worker, ev_base = ev_base, + rule = rule, set = set, ann_key = ann_key, ham_vec = ham_elts, + spam_vec = spam_elts }) + end + end + + -- Spam vectors received + local function redis_spam_cb(err, data) + if err or type(data) ~= 'table' then + rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s', + ann_key, err) + -- Unlock ANN on error + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + neural_common.gen_unlock_cb(rule, set, ann_key), --callback + 'HDEL', -- command + { ann_key, 'lock' } + ) + else + -- Decompress and convert to numbers each training vector + spam_elts = process_training_vectors(data) + -- Now get ham vectors... + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_ham_cb, --callback + 'SMEMBERS', -- command + { ann_key .. '_ham_set' } + ) + end + end + + local function redis_lock_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s', + ann_key, err) + elseif type(data) == 'number' and data == 1 then + -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_spam_cb, --callback + 'SMEMBERS', -- command + { ann_key .. '_spam_set' } + ) + + rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning', + rule.prefix, set.name, ann_key) + else + local lock_tm = tonumber(data[1]) + rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' .. 
+ 'locked by another host %s at %s', rule.prefix, set.name, ann_key, + data[2], os.date('%c', lock_tm)) + end + end + + -- Check if we are already learning this network + if set.learning_spawned then + rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', + ann_key) + return + end + + -- Call Redis script that tries to acquire a lock + -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when + -- ANN is locked by another host (or a process, meh) + lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock, + { ev_base = ev_base, is_write = true }, + redis_lock_cb, + { + ann_key, + tostring(os.time()), + tostring(math.max(10.0, rule.watch_interval * 2)), + rspamd_util.get_hostname() + }) +end + +-- This function loads new ann from Redis +-- This is based on `profile` attribute. +-- ANN is loaded from `profile.redis_key` +-- Rank of `profile` key is also increased, unfortunately, it means that we need to +-- serialize profile one more time and set its rank to the current time +-- set.ann fields are set according to Redis data received +local function load_new_ann(rule, ev_base, set, profile, min_diff) + local ann_key = profile.redis_key + + local function data_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s', + ann_key, err) + else + if type(data) == 'table' then + if type(data[1]) == 'userdata' and data[1].cookie == text_cookie then + local _err, ann_data = rspamd_util.zstd_decompress(data[1]) + local ann + + if _err or not ann_data then + rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s', + rule.prefix .. ':' .. 
set.name, ann_key, _err) + return + else + ann = rspamd_kann.load(ann_data) + + if ann then + set.ann = { + digest = profile.digest, + version = profile.version, + symbols = profile.symbols, + distance = min_diff, + redis_key = profile.redis_key + } + + local ucl = require "ucl" + local profile_serialized = ucl.to_format(profile, 'json-compact', true) + set.ann.ann = ann -- To avoid serialization + + local function rank_cb(_, _) + -- TODO: maybe add some logging + end + -- Also update rank for the loaded ANN to avoid removal + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + rank_cb, --callback + 'ZADD', -- command + { set.prefix, tostring(rspamd_util.get_time()), profile_serialized } + ) + rspamd_logger.infox(rspamd_config, + 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s', + rule.prefix, set.name, ann_key, #data[1], profile.version) + else + rspamd_logger.errx(rspamd_config, + 'cannot unpack/deserialise ANN for %s:%s from Redis key %s', + rule.prefix, set.name, ann_key) + end + end + else + lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s', + rule.prefix, set.name, ann_key) + end + + if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then + if rule.roc_enabled then + local ucl = require "ucl" + local parser = ucl.parser() + local ok, parse_err = parser:parse_text(data[2]) + assert(ok, parse_err) + local roc_thresholds = parser:get_object() + set.ann.roc_thresholds = roc_thresholds + rspamd_logger.infox(rspamd_config, + 'loaded ROC thresholds for %s:%s; version=%s', + rule.prefix, set.name, profile.version) + rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds) + end + end + + if set.ann and set.ann.ann and type(data[3]) == 'userdata' and data[3].cookie == text_cookie then + -- PCA table + local _err, pca_data = rspamd_util.zstd_decompress(data[3]) + if pca_data then + if rule.max_inputs then + -- We can use PCA + 
set.ann.pca = rspamd_tensor.load(pca_data) + rspamd_logger.infox(rspamd_config, + 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s', + rule.prefix, set.name, ann_key, #data[3], profile.version) + else + -- no need in pca, why is it there? + rspamd_logger.warnx(rspamd_config, + 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined', + rule.prefix, set.name, ann_key) + end + else + -- pca can be missing merely if we have no max_inputs + if rule.max_inputs then + rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s', + rule.prefix, set.name, ann_key, _err) + set.ann.ann = nil + else + -- It is okay + set.ann.pca = nil + end + end + end + + else + lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s', + rule.prefix, set.name, ann_key) + end + end + end + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + data_cb, --callback + 'HMGET', -- command + { ann_key, 'ann', 'roc_thresholds', 'pca' }, -- arguments + { opaque_data = true } + ) +end + +-- Used to check an element in Redis serialized as JSON +-- for some specific rule + some specific setting +-- This function tries to load more fresh or more specific ANNs in lieu of +-- the existing ones. +-- Use this function to load ANNs as `callback` parameter for `check_anns` function +local function process_existing_ann(_, ev_base, rule, set, profiles) + local my_symbols = set.symbols + local min_diff = math.huge + local sel_elt + + for _, elt in fun.iter(profiles) do + if elt and elt.symbols then + local dist = lua_util.distance_sorted(elt.symbols, my_symbols) + -- Check distance + if dist < #my_symbols * .3 then + if dist < min_diff then + min_diff = dist + sel_elt = elt + end + end + end + end + + if sel_elt then + -- We can load element from ANN + if set.ann then + -- We have an existing ANN, probably the same... 
+ if set.ann.digest == sel_elt.digest then + -- Same ANN, check version + if set.ann.version < sel_elt.version then + -- Load new ann + rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' .. + 'our version = %s, remote version = %s', + rule.prefix .. ':' .. set.name, + set.ann.version, + sel_elt.version) + load_new_ann(rule, ev_base, set, sel_elt, min_diff) + else + lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' .. + 'our version = %s, remote version = %s', + rule.prefix .. ':' .. set.name, + set.ann.version, + sel_elt.version) + end + else + -- We have some different ANN, so we need to compare distance + if set.ann.distance > min_diff then + -- Load more specific ANN + rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' .. + 'our distance = %s, remote distance = %s', + rule.prefix .. ':' .. set.name, + set.ann.distance, + min_diff) + load_new_ann(rule, ev_base, set, sel_elt, min_diff) + else + lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' .. + 'our distance = %s, remote distance = %s', + rule.prefix .. ':' .. set.name, + set.ann.distance, + min_diff) + end + end + else + -- We have no ANN, load new one + load_new_ann(rule, ev_base, set, sel_elt, min_diff) + end + end +end + + +-- This function checks all profiles and selects if we can train our +-- ANN. By our we mean that it has exactly the same symbols in profile. 
+-- Use this function to train ANN as `callback` parameter for `check_anns` function +local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) + local my_symbols = set.symbols + local sel_elt + local lens = { + spam = 0, + ham = 0, + } + + for _, elt in fun.iter(profiles) do + if elt and elt.symbols then + local dist = lua_util.distance_sorted(elt.symbols, my_symbols) + -- Check distance + if dist == 0 then + sel_elt = elt + break + end + end + end + + if sel_elt then + -- We have our ANN and that's train vectors, check if we can learn + local ann_key = sel_elt.redis_key + + lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained", + ann_key) + + -- Create continuation closure + local redis_len_cb_gen = function(cont_cb, what, is_final) + return function(err, data) + if err then + rspamd_logger.errx(rspamd_config, + 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err) + elseif data and type(data) == 'number' or type(data) == 'string' then + local ntrains = tonumber(data) or 0 + lens[what] = ntrains + if is_final then + -- Ensure that we have the following: + -- one class has reached max_trains + -- other class(es) are at least as full as classes_bias + -- e.g. 
if classes_bias = 0.25 and we have 10 max_trains then + -- one class must have 10 or more trains whilst another should have + -- at least (10 * (1 - 0.25)) = 8 trains + + local max_len = math.max(lua_util.unpack(lua_util.values(lens))) + local min_len = math.min(lua_util.unpack(lua_util.values(lens))) + + if rule.train.learn_type == 'balanced' then + local len_bias_check_pred = function(_, l) + return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias) + end + if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then + rspamd_logger.debugm(N, rspamd_config, + 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', + ann_key, lens, rule.train.max_trains, what) + cont_cb() + else + rspamd_logger.debugm(N, rspamd_config, + 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', + ann_key, what, lens, rule.train.max_trains) + end + else + -- Probabilistic mode, just ensure that at least one vector is okay + if min_len > 0 and max_len >= rule.train.max_trains then + rspamd_logger.debugm(N, rspamd_config, + 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', + ann_key, lens, rule.train.max_trains, what) + cont_cb() + else + rspamd_logger.debugm(N, rspamd_config, + 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', + ann_key, what, lens, rule.train.max_trains) + end + end + + else + rspamd_logger.debugm(N, rspamd_config, + 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors', + what, ann_key, ntrains, rule.train.max_trains) + cont_cb() + end + end + end + + end + + local function initiate_train() + rspamd_logger.infox(rspamd_config, + 'need to learn ANN %s after %s required learn vectors', + ann_key, lens) + do_train_ann(worker, ev_base, rule, set, ann_key) + end + + -- Spam vector is OK, check ham vector length + local function check_ham_len() 
+ lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_len_cb_gen(initiate_train, 'ham', true), --callback + 'SCARD', -- command + { ann_key .. '_ham_set' } + ) + end + + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_len_cb_gen(check_ham_len, 'spam', false), --callback + 'SCARD', -- command + { ann_key .. '_spam_set' } + ) + end +end + +-- Used to deserialise ANN element from a list +local function load_ann_profile(element) + local ucl = require "ucl" + + local parser = ucl.parser() + local res, ucl_err = parser:parse_string(element) + if not res then + rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s', + ucl_err) + return nil + else + local profile = parser:get_object() + local checked, schema_err = redis_profile_schema:transform(profile) + if not checked then + rspamd_logger.errx(rspamd_config, "cannot parse profile schema: %s", schema_err) + + return nil + end + return checked + end +end + +-- Function to check or load ANNs from Redis +local function check_anns(worker, cfg, ev_base, rule, process_callback, what) + for _, set in pairs(rule.settings) do + local function members_cb(err, data) + if err then + rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s', + err) + set.can_store_vectors = true + elseif type(data) == 'table' then + lua_util.debugm(N, cfg, '%s: process element %s:%s', + what, rule.prefix, set.name) + process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data)) + set.can_store_vectors = true + end + end + + if type(set) == 'table' then + -- Extract all profiles for some specific settings id + -- Get the last `max_profiles` recently used + -- Select the most appropriate to our profile but it should not differ by more + -- than 30% of symbols + lua_redis.redis_make_request_taskless(ev_base, + cfg, + rule.redis, + nil, + false, -- is write + members_cb, --callback + 'ZREVRANGE', 
-- command + { set.prefix, '0', tostring(settings.max_profiles) } -- arguments + ) + end + end -- Cycle over all settings + + return rule.watch_interval +end + +-- Function to clean up old ANNs +local function cleanup_anns(rule, cfg, ev_base) + for _, set in pairs(rule.settings) do + local function invalidate_cb(err, data) + if err then + rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s', + err) + elseif type(data) == 'table' then + for _, expired in ipairs(data) do + local profile = load_ann_profile(expired) + rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s', + rule.prefix .. ':' .. set.name, + profile.redis_key, + profile.version) + end + end + end + + if type(set) == 'table' then + lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate, + { ev_base = ev_base, is_write = true }, + invalidate_cb, + { set.prefix, tostring(settings.max_profiles) }) + end + end +end + +local function ann_push_vector(task) + if task:has_flag('skip') then + lua_util.debugm(N, task, 'do not push data for skipped task') + return + end + if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then + lua_util.debugm(N, task, 'do not push data for manual scan') + return + end + + local verdict, score = lua_verdict.get_specific_verdict(N, task) + + if verdict == 'passthrough' then + lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)', + verdict, score) + + return + end + + if score ~= score then + lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)', + verdict) + + return + end + + for _, rule in pairs(settings.rules) do + local set = neural_common.get_rule_settings(task, rule) + + if set then + ann_push_task_result(rule, task, verdict, score, set) + else + lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix) + end + + end +end + + +-- Initialization part +if not (neural_common.module_config and type(neural_common.module_config) == 'table') + or not 
neural_common.redis_params then + rspamd_logger.infox(rspamd_config, 'Module is unconfigured') + lua_util.disable_module(N, "redis") + return +end + +local rules = neural_common.module_config['rules'] + +if not rules then + -- Use legacy configuration + rules = {} + rules['default'] = neural_common.module_config +end + +local id = rspamd_config:register_symbol({ + name = 'NEURAL_CHECK', + type = 'postfilter,callback', + flags = 'nostat', + priority = lua_util.symbols_priorities.medium, + callback = ann_scores_filter +}) + +neural_common.settings.rules = {} -- Reset unless validated further in the cycle + +if settings.blacklisted_symbols and settings.blacklisted_symbols[1] then + -- Transform to hash for simplicity + settings.blacklisted_symbols = lua_util.list_to_hash(settings.blacklisted_symbols) +end + +-- Check all rules +for k, r in pairs(rules) do + local rule_elt = lua_util.override_defaults(neural_common.default_options, r) + rule_elt['redis'] = neural_common.redis_params + rule_elt['anns'] = {} -- Store ANNs here + + if not rule_elt.prefix then + rule_elt.prefix = k + end + if not rule_elt.name then + rule_elt.name = k + end + if rule_elt.train.max_train and not rule_elt.train.max_trains then + rule_elt.train.max_trains = rule_elt.train.max_train + end + + if not rule_elt.profile then + rule_elt.profile = {} + end + + if rule_elt.max_inputs and not has_blas then + rspamd_logger.errx('cannot set max inputs to %s as BLAS is not compiled in', + rule_elt.name, rule_elt.max_inputs) + rule_elt.max_inputs = nil + end + + rspamd_logger.infox(rspamd_config, "register ann rule %s", k) + settings.rules[k] = rule_elt + rspamd_config:set_metric_symbol({ + name = rule_elt.symbol_spam, + score = 0.0, + description = 'Neural network SPAM', + group = 'neural' + }) + rspamd_config:register_symbol({ + name = rule_elt.symbol_spam, + type = 'virtual', + flags = 'nostat', + parent = id + }) + + rspamd_config:set_metric_symbol({ + name = rule_elt.symbol_ham, + score = -0.0, + 
description = 'Neural network HAM', + group = 'neural' + }) + rspamd_config:register_symbol({ + name = rule_elt.symbol_ham, + type = 'virtual', + flags = 'nostat', + parent = id + }) +end + +rspamd_config:register_symbol({ + name = 'NEURAL_LEARN', + type = 'idempotent,callback', + flags = 'nostat,explicit_disable,ignore_passthrough', + callback = ann_push_vector +}) + +-- We also need to deal with settings +rspamd_config:add_post_init(neural_common.process_rules_settings) + +-- Add training scripts +for _, rule in pairs(settings.rules) do + neural_common.load_scripts(rule.redis) + -- This function will check ANNs in Redis when a worker is loaded + rspamd_config:add_on_load(function(cfg, ev_base, worker) + if worker:is_scanner() then + rspamd_config:add_periodic(ev_base, 0.0, + function(_, _) + return check_anns(worker, cfg, ev_base, rule, process_existing_ann, + 'try_load_ann') + end) + end + + if worker:is_primary_controller() then + -- We also want to train neural nets when they have enough data + rspamd_config:add_periodic(ev_base, 0.0, + function(_, _) + -- Clean old ANNs + cleanup_anns(rule, cfg, ev_base) + return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann, + 'try_train_ann') + end) + end + end) +end diff --git a/src/plugins/lua/once_received.lua b/src/plugins/lua/once_received.lua new file mode 100644 index 0000000..2a5552a --- /dev/null +++ b/src/plugins/lua/once_received.lua @@ -0,0 +1,230 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +]]-- + +if confighelp then + return +end + +-- 0 or 1 received: = spam + +local symbol = 'ONCE_RECEIVED' +local symbol_rdns = 'RDNS_NONE' +local symbol_rdns_dnsfail = 'RDNS_DNSFAIL' +local symbol_mx = 'DIRECT_TO_MX' +-- Symbol for strict checks +local symbol_strict = nil +local bad_hosts = {} +local good_hosts = {} +local whitelist = nil + +local rspamd_logger = require "rspamd_logger" +local lua_util = require "lua_util" +local fun = require "fun" +local N = 'once_received' + +local check_local = false +local check_authed = false + +local function check_quantity_received (task) + local recvh = task:get_received_headers() + + local nreceived = fun.reduce(function(acc, _) + return acc + 1 + end, 0, fun.filter(function(h) + return not h['flags']['artificial'] + end, recvh)) + + local function recv_dns_cb(_, to_resolve, results, err) + if err and (err ~= 'requested record is not found' and err ~= 'no records with this name') then + rspamd_logger.errx(task, 'error looking up %s: %s', to_resolve, err) + task:insert_result(symbol_rdns_dnsfail, 1.0) + end + + if not results then + if nreceived <= 1 then + task:insert_result(symbol, 1) + -- Avoid strict symbol inserting as the remaining symbols have already + -- quote a significant weight, so a message could be rejected by just + -- this property. + --task:insert_result(symbol_strict, 1) + -- Check for MUAs + local ua = task:get_header('User-Agent') + local xm = task:get_header('X-Mailer') + if (ua or xm) then + task:insert_result(symbol_mx, 1, (ua or xm)) + end + end + task:insert_result(symbol_rdns, 1) + else + rspamd_logger.infox(task, 'source hostname has not been passed to Rspamd from MTA, ' .. 
+ 'but we could resolve source IP address PTR %s as "%s"', + to_resolve, results[1]) + task:set_hostname(results[1]) + + if good_hosts then + for _, gh in ipairs(good_hosts) do + if string.find(results[1], gh) then + return + end + end + end + + if nreceived <= 1 then + task:insert_result(symbol, 1) + for _, h in ipairs(bad_hosts) do + if string.find(results[1], h) then + + task:insert_result(symbol_strict, 1, h) + return + end + end + end + end + end + + local task_ip = task:get_ip() + + if ((not check_authed and task:get_user()) or + (not check_local and task_ip and task_ip:is_local())) then + rspamd_logger.infox(task, 'Skipping once_received for authenticated user or local network') + return + end + if whitelist and task_ip and whitelist:get_key(task_ip) then + rspamd_logger.infox(task, 'whitelisted mail from %s', + task_ip:to_string()) + return + end + + local hn = task:get_hostname() + -- Here we don't care about received + if (not hn) and task_ip and task_ip:is_valid() then + task:get_resolver():resolve_ptr({ task = task, + name = task_ip:to_string(), + callback = recv_dns_cb, + forced = true + }) + return + end + + if nreceived <= 1 then + local ret = true + local r = recvh[1] + + if not r then + return + end + + if r['real_hostname'] then + local rhn = string.lower(r['real_hostname']) + -- Check for good hostname + if rhn and good_hosts then + for _, gh in ipairs(good_hosts) do + if string.find(rhn, gh) then + ret = false + break + end + end + end + end + + if ret then + -- Strict checks + if symbol_strict then + -- Unresolved host + task:insert_result(symbol, 1) + + if not hn then + return + end + for _, h in ipairs(bad_hosts) do + if string.find(hn, h) then + task:insert_result(symbol_strict, 1, h) + return + end + end + else + task:insert_result(symbol, 1) + end + end + end +end + +local auth_and_local_conf = lua_util.config_check_local_or_authed(rspamd_config, N, + false, false) +check_local = auth_and_local_conf[1] +check_authed = 
auth_and_local_conf[2] + +-- Configuration +local opts = rspamd_config:get_all_opt(N) +if opts then + if opts['symbol'] then + symbol = opts['symbol'] + + local id = rspamd_config:register_symbol({ + name = symbol, + callback = check_quantity_received, + }) + + for n, v in pairs(opts) do + if n == 'symbol_strict' then + symbol_strict = v + elseif n == 'symbol_rdns' then + symbol_rdns = v + elseif n == 'symbol_rdns_dnsfail' then + symbol_rdns_dnsfail = v + elseif n == 'bad_host' then + if type(v) == 'string' then + bad_hosts[1] = v + else + bad_hosts = v + end + elseif n == 'good_host' then + if type(v) == 'string' then + good_hosts[1] = v + else + good_hosts = v + end + elseif n == 'whitelist' then + local lua_maps = require "lua_maps" + whitelist = lua_maps.map_add('once_received', 'whitelist', 'radix', + 'once received whitelist') + elseif n == 'symbol_mx' then + symbol_mx = v + end + end + + rspamd_config:register_symbol({ + name = symbol_rdns, + type = 'virtual', + parent = id + }) + rspamd_config:register_symbol({ + name = symbol_rdns_dnsfail, + type = 'virtual', + parent = id + }) + rspamd_config:register_symbol({ + name = symbol_strict, + type = 'virtual', + parent = id + }) + rspamd_config:register_symbol({ + name = symbol_mx, + type = 'virtual', + parent = id + }) + end +end diff --git a/src/plugins/lua/p0f.lua b/src/plugins/lua/p0f.lua new file mode 100644 index 0000000..97757c2 --- /dev/null +++ b/src/plugins/lua/p0f.lua @@ -0,0 +1,124 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2019, Denis Paavilainen <denpa@denpa.pro> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +-- Detect remote OS via passive fingerprinting + +local lua_util = require "lua_util" +local lua_redis = require "lua_redis" +local rspamd_logger = require "rspamd_logger" +local p0f = require("lua_scanners").filter('p0f').p0f + +local N = 'p0f' + +if confighelp then + rspamd_config:add_example(nil, N, + 'Detect remote OS via passive fingerprinting', + [[ + p0f { + # Enable module + enabled = true + + # Path to the unix socket that p0f listens on + socket = '/var/run/p0f.sock'; + + # Connection timeout + timeout = 5s; + + # If defined, insert symbol with lookup results + symbol = 'P0F'; + + # Patterns to match against results returned by p0f + # Symbol will be yielded on OS string, link type or distance matches + patterns = { + WINDOWS = '^Windows.*'; + #DSL = '^DSL$'; + #DISTANCE10 = '^distance:10$'; + } + + # Cache lifetime in seconds (default - 2 hours) + expire = 7200; + + # Cache key prefix + prefix = 'p0f'; + } + ]]) + return +end + +local rule + +local function check_p0f(task) + local ip = task:get_from_ip() + + if not (ip and ip:is_valid()) or ip:is_local() then + return + end + + p0f.check(task, ip, rule) +end + +local opts = rspamd_config:get_all_opt(N) + +rule = p0f.configure(opts) + +if rule then + rule.redis_params = lua_redis.parse_redis_server(N) + + lua_redis.register_prefix(rule.prefix .. 
'*', N, + 'P0f check cache', { + type = 'string', + }) + + local id = rspamd_config:register_symbol({ + name = 'P0F_CHECK', + type = 'prefilter', + callback = check_p0f, + priority = lua_util.symbols_priorities.medium, + flags = 'empty,nostat', + group = N, + augmentations = { string.format("timeout=%f", rule.timeout or 0.0) }, + + }) + + if rule.symbol then + rspamd_config:register_symbol({ + name = rule.symbol, + parent = id, + type = 'virtual', + flags = 'empty', + group = N + }) + end + + for sym in pairs(rule.patterns) do + rspamd_logger.debugm(N, rspamd_config, 'registering: %1', { + type = 'virtual', + name = sym, + parent = id, + group = N + }) + rspamd_config:register_symbol({ + type = 'virtual', + name = sym, + parent = id, + group = N + }) + end +else + lua_util.disable_module(N, 'config') + rspamd_logger.infox('p0f module not configured'); +end diff --git a/src/plugins/lua/phishing.lua b/src/plugins/lua/phishing.lua new file mode 100644 index 0000000..05e08c0 --- /dev/null +++ b/src/plugins/lua/phishing.lua @@ -0,0 +1,667 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +if confighelp then + return +end + +local rspamd_logger = require "rspamd_logger" +local util = require "rspamd_util" +local lua_util = require "lua_util" +local lua_maps = require "lua_maps" + +-- Phishing detection interface for selecting phished urls and inserting corresponding symbol +-- +-- +local N = 'phishing' +local symbol = 'PHISHED_URL' +local phishing_feed_exclusion_symbol = 'PHISHED_EXCLUDED' +local generic_service_symbol = 'PHISHED_GENERIC_SERVICE' +local openphish_symbol = 'PHISHED_OPENPHISH' +local phishtank_symbol = 'PHISHED_PHISHTANK' +local generic_service_name = 'generic service' +local domains = nil +local phishing_exceptions_maps = {} +local anchor_exceptions_maps = {} +local strict_domains_maps = {} +local phishing_feed_exclusion_map = nil +local generic_service_map = nil +local openphish_map = 'https://www.openphish.com/feed.txt' +local phishtank_suffix = 'phishtank.rspamd.com' +-- Not enabled by default as their feed is quite large +local openphish_premium = false +-- Published via DNS +local phishtank_enabled = false +local phishing_feed_exclusion_hash +local generic_service_hash +local openphish_hash +local phishing_feed_exclusion_data = {} +local generic_service_data = {} +local openphish_data = {} + +local opts = rspamd_config:get_all_opt(N) +if not (opts and type(opts) == 'table') then + rspamd_logger.infox(rspamd_config, 'Module is unconfigured') + return +end + +local function is_host_excluded(exclusion_map, host) + if exclusion_map and host then + local excluded = exclusion_map[host] + if excluded then + return true + end + return false + end +end + +local function phishing_cb(task) + local function check_phishing_map(table) + local phishing_data = {} + for k,v in pairs(table) do + phishing_data[k] = v + end + local url = phishing_data.url + local host = url:get_host() + + if is_host_excluded(phishing_data.exclusion_map, host) then + task:insert_result(phishing_data.excl_symbol, 1.0, host) + return + end + + if host then + 
local elt = phishing_data.map[host] + local found_path = false + local found_query = false + local data = nil + + if elt then + local path = url:get_path() + local query = url:get_query() + + if path then + for _, d in ipairs(elt) do + if d['path'] == path then + found_path = true + data = d['data'] + + if query and d['query'] and query == d['query'] then + found_query = true + elseif not d['query'] then + found_query = true + end + end + end + else + for _, d in ipairs(elt) do + if not d['path'] then + found_path = true + end + + if query and d['query'] and query == d['query'] then + found_query = true + elseif not d['query'] then + found_query = true + end + end + end + + if found_path then + local args + + if type(data) == 'table' then + args = { + data['tld'], + data['sector'], + data['brand'], + } + elseif type(data) == 'string' then + args = data + else + args = host + end + + if found_query then + -- Query + path match + task:insert_result(phishing_data.phish_symbol, 1.0, args) + else + -- Host + path match + if path then + task:insert_result(phishing_data.phish_symbol, 0.3, args) + end + -- No path, no symbol + end + else + if url:is_phished() then + -- Only host matches + task:insert_result(phishing_data.phish_symbol, 0.1, host) + end + end + end + end + end + + local function check_phishing_dns(table) + local phishing_data = {} + for k,v in pairs(table) do + phishing_data[k] = v + end + local url = phishing_data.url + local host = url:get_host() + + if is_host_excluded(phishing_data.exclusion_map, host) then + task:insert_result(phishing_data.excl_symbol, 1.0, host) + return + end + + local function compose_dns_query(elts) + local cr = require "rspamd_cryptobox_hash" + local h = cr.create() + for _, elt in ipairs(elts) do + h:update(elt) + end + return string.format("%s.%s", h:base32():sub(1, 32), phishing_data.dns_suffix) + end + + local r = task:get_resolver() + local path = url:get_path() + local query = url:get_query() + + if host and path then + 
local function host_host_path_cb(_, _, results, err) + if not err and results then + if not query then + task:insert_result(phishing_data.phish_symbol, 1.0, results) + else + task:insert_result(phishing_data.phish_symbol, 0.3, results) + end + end + end + + local to_resolve_hp = compose_dns_query({ host, path }) + rspamd_logger.debugm(N, task, 'try to resolve {%s, %s} -> %s', + host, path, to_resolve_hp) + r:resolve_txt({ + task = task, + name = to_resolve_hp, + callback = host_host_path_cb }) + + if query then + local function host_host_path_query_cb(_, _, results, err) + if not err and results then + task:insert_result(phishing_data.phish_symbol, 1.0, results) + end + end + + local to_resolve_hpq = compose_dns_query({ host, path, query }) + rspamd_logger.debugm(N, task, 'try to resolve {%s, %s, %s} -> %s', + host, path, query, to_resolve_hpq) + r:resolve_txt({ + task = task, + name = to_resolve_hpq, + callback = host_host_path_query_cb }) + end + + end + end + + -- Process all urls + local dmarc_dom + local dsym = task:get_symbol('DMARC_POLICY_ALLOW') + if dsym then + dsym = dsym[1] -- legacy stuff, need to take the first element + if dsym.options then + dmarc_dom = dsym.options[1] + end + end + + local urls = task:get_urls() or {} + for _, url_iter in ipairs(urls) do + local function do_loop_iter() + -- to emulate continue + local url = url_iter + local phishing_data = {} + phishing_data.url = url + phishing_data.exclusion_map = phishing_feed_exclusion_data + phishing_data.excl_symbol = phishing_feed_exclusion_symbol + if generic_service_hash then + phishing_data.map = generic_service_data + phishing_data.phish_symbol = generic_service_symbol + check_phishing_map(phishing_data) + end + + if openphish_hash then + phishing_data.map = openphish_data + phishing_data.phish_symbol = openphish_symbol + check_phishing_map(phishing_data) + end + + if phishtank_enabled then + phishing_data.dns_suffix = phishtank_suffix + phishing_data.phish_symbol = phishtank_symbol + 
check_phishing_dns(phishing_data) + end + + if url:is_phished() then + local purl + + if url:is_redirected() then + local rspamd_url = require "rspamd_url" + -- Examine the real redirect target instead of the url + local redirected_url = url:get_redirected() + if not redirected_url then + return + end + + purl = rspamd_url.create(task:get_mempool(), url:get_visible()) + url = redirected_url + else + purl = url:get_phished() + end + + if not purl then + return + end + + local tld = url:get_tld() + local ptld = purl:get_tld() + + if not ptld or not tld then + return + end + + if dmarc_dom and tld == dmarc_dom then + lua_util.debugm(N, 'exclude phishing from %s -> %s by dmarc domain', tld, + ptld) + return + end + + -- Now we can safely remove the last dot component if it is the same + local b, _ = string.find(tld, '%.[^%.]+$') + local b1, _ = string.find(ptld, '%.[^%.]+$') + + local stripped_tld, stripped_ptld = tld, ptld + if b1 and b then + if string.sub(tld, b) == string.sub(ptld, b1) then + stripped_ptld = string.gsub(ptld, '%.[^%.]+$', '') + stripped_tld = string.gsub(tld, '%.[^%.]+$', '') + end + + if #ptld == 0 or #tld == 0 then + return false + end + end + + local weight = 1.0 + local spoofed, why = util.is_utf_spoofed(tld, ptld) + if spoofed then + lua_util.debugm(N, task, "confusable: %1 -> %2: %3", tld, ptld, why) + weight = 1.0 + else + local dist = util.levenshtein_distance(stripped_tld, stripped_ptld, 2) + dist = 2 * dist / (#stripped_tld + #stripped_ptld) + + if dist > 0.3 and dist <= 1.0 then + -- Use distance to penalize the total weight + weight = util.tanh(3 * (1 - dist + 0.1)) + elseif dist > 1 then + -- We also check if two labels are in the same ascii/non-ascii representation + local a1, a2 = false, false + + if string.match(tld, '^[\001-\127]*$') then + a1 = true + end + if string.match(ptld, '^[\001-\127]*$') then + a2 = true + end + + if a1 ~= a2 then + weight = 1 + lua_util.debugm(N, task, "confusable: %1 -> %2: different characters", + tld, 
ptld, why) + else + -- We have totally different strings in tld, so penalize it somehow + weight = 0.5 + end + end + + lua_util.debugm(N, task, "distance: %1 -> %2: %3", tld, ptld, dist) + end + + local function is_url_in_map(map, furl) + for _, dn in ipairs({ furl:get_tld(), furl:get_host() }) do + if map:get_key(dn) then + return true, dn + end + end + + return false + end + local function found_in_map(map, furl, sweight) + if not furl then + furl = url + end + if not sweight then + sweight = weight + end + if #map > 0 then + for _, rule in ipairs(map) do + local found, dn = is_url_in_map(rule.map, furl) + if found then + task:insert_result(rule.symbol, sweight, string.format("%s->%s:%s", ptld, tld, dn)) + return true + end + end + end + end + + found_in_map(strict_domains_maps, purl, 1.0) + if not found_in_map(anchor_exceptions_maps) then + if not found_in_map(phishing_exceptions_maps, purl, 1.0) then + if domains then + if is_url_in_map(domains, purl) then + task:insert_result(symbol, weight, ptld .. '->' .. tld) + end + else + task:insert_result(symbol, weight, ptld .. '->' .. tld) + end + end + end + end + end + + do_loop_iter() + end +end + +local function phishing_map(mapname, phishmap, id) + if opts[mapname] then + local xd + if type(opts[mapname]) == 'table' then + xd = opts[mapname] + else + rspamd_logger.errx(rspamd_config, 'invalid exception table') + end + + for sym, map_data in pairs(xd) do + local rmap = lua_maps.map_add_from_ucl(map_data, 'set', + 'Phishing ' .. mapname .. 
' map') + if rmap then + rspamd_config:register_virtual_symbol(sym, 1, id) + local rule = { symbol = sym, map = rmap } + table.insert(phishmap, rule) + else + rspamd_logger.infox(rspamd_config, 'cannot add map for symbol: %s', sym) + end + end + end +end + +local function rspamd_str_split_fun(s, sep, func) + local lpeg = require "lpeg" + sep = lpeg.P(sep) + local elem = lpeg.P((1 - sep) ^ 0 / func) + local p = lpeg.P(elem * (sep * elem) ^ 0) + return p:match(s) +end + +local function insert_url_from_string(pool, tbl, str, data) + local rspamd_url = require "rspamd_url" + + local u = rspamd_url.create(pool, str) + + if u then + local host = u:get_host() + if host then + local elt = { + data = data, + path = u:get_path(), + query = u:get_query() + } + + if tbl[host] then + table.insert(tbl[host], elt) + else + tbl[host] = { elt } + end + + return true + end + end + + return false +end + +local function phishing_feed_exclusion_plain_cb(string) + local nelts = 0 + local new_data = {} + local rspamd_mempool = require "rspamd_mempool" + local pool = rspamd_mempool.create() + + local function phishing_feed_exclusion_elt_parser(cap) + if insert_url_from_string(pool, new_data, cap, nil) then + nelts = nelts + 1 + end + end + + rspamd_str_split_fun(string, '\n', phishing_feed_exclusion_elt_parser) + + phishing_feed_exclusion_data = new_data + rspamd_logger.infox(phishing_feed_exclusion_hash, "parsed %s elements from phishing feed exclusions", + nelts) + pool:destroy() +end + +local function generic_service_plain_cb(string) + local nelts = 0 + local new_data = {} + local rspamd_mempool = require "rspamd_mempool" + local pool = rspamd_mempool.create() + + local function generic_service_elt_parser(cap) + if insert_url_from_string(pool, new_data, cap, nil) then + nelts = nelts + 1 + end + end + + rspamd_str_split_fun(string, '\n', generic_service_elt_parser) + + generic_service_data = new_data + rspamd_logger.infox(generic_service_hash, "parsed %s elements from %s feed", + 
nelts, generic_service_name) + pool:destroy() +end + +local function openphish_json_cb(string) + local ucl = require "ucl" + local rspamd_mempool = require "rspamd_mempool" + local nelts = 0 + local new_json_map = {} + local valid = true + + local pool = rspamd_mempool.create() + + local function openphish_elt_parser(cap) + if valid then + local parser = ucl.parser() + local res, err = parser:parse_string(cap) + if not res then + valid = false + rspamd_logger.warnx(openphish_hash, 'cannot parse openphish map: ' .. err) + else + local obj = parser:get_object() + + if obj['url'] then + if insert_url_from_string(pool, new_json_map, obj['url'], obj) then + nelts = nelts + 1 + end + end + end + end + end + + rspamd_str_split_fun(string, '\n', openphish_elt_parser) + + if valid then + openphish_data = new_json_map + rspamd_logger.infox(openphish_hash, "parsed %s elements from openphish feed", + nelts) + end + + pool:destroy() +end + +local function openphish_plain_cb(s) + local nelts = 0 + local new_data = {} + local rspamd_mempool = require "rspamd_mempool" + local pool = rspamd_mempool.create() + + local function openphish_elt_parser(cap) + if insert_url_from_string(pool, new_data, cap, nil) then + nelts = nelts + 1 + end + end + + rspamd_str_split_fun(s, '\n', openphish_elt_parser) + + openphish_data = new_data + rspamd_logger.infox(openphish_hash, "parsed %s elements from openphish feed", + nelts) + pool:destroy() +end + +if opts then + local id + if opts['symbol'] then + symbol = opts['symbol'] + -- Register symbol's callback + id = rspamd_config:register_symbol({ + name = symbol, + callback = phishing_cb + }) + + -- To exclude from domains for dmarc verified messages + rspamd_config:register_dependency(symbol, 'DMARC_CHECK') + + if opts['phishing_feed_exclusion_symbol'] then + phishing_feed_exclusion_symbol = opts['phishing_feed_exclusion_symbol'] + end + if opts['phishing_feed_exclusion_map'] then + phishing_feed_exclusion_map = opts['phishing_feed_exclusion_map'] 
+ end + + if opts['phishing_feed_exclusion_enabled'] then + phishing_feed_exclusion_hash = rspamd_config:add_map({ + type = 'callback', + url = phishing_feed_exclusion_map, + callback = phishing_feed_exclusion_plain_cb, + description = 'Phishing feed exclusions' + }) + end + + if opts['generic_service_symbol'] then + generic_service_symbol = opts['generic_service_symbol'] + end + if opts['generic_service_map'] then + generic_service_map = opts['generic_service_map'] + end + if opts['generic_service_url'] then + generic_service_map = opts['generic_service_url'] + end + if opts['generic_service_name'] then + generic_service_name = opts['generic_service_name'] + end + + if opts['generic_service_enabled'] then + generic_service_hash = rspamd_config:add_map({ + type = 'callback', + url = generic_service_map, + callback = generic_service_plain_cb, + description = 'Generic feed' + }) + end + + if opts['openphish_map'] then + openphish_map = opts['openphish_map'] + end + if opts['openphish_url'] then + openphish_map = opts['openphish_url'] + end + + if opts['openphish_premium'] then + openphish_premium = true + end + + if opts['openphish_enabled'] then + if not openphish_premium then + openphish_hash = rspamd_config:add_map({ + type = 'callback', + url = openphish_map, + callback = openphish_plain_cb, + description = 'Open phishing feed map (see https://www.openphish.com for details)', + opaque_data = true, + }) + else + openphish_hash = rspamd_config:add_map({ + type = 'callback', + url = openphish_map, + callback = openphish_json_cb, + opaque_data = true, + description = 'Open phishing premium feed map (see https://www.openphish.com for details)' + }) + end + end + + if opts['phishtank_enabled'] then + phishtank_enabled = true + if opts['phishtank_suffix'] then + phishtank_suffix = opts['phishtank_suffix'] + end + end + + rspamd_config:register_symbol({ + type = 'virtual', + parent = id, + name = generic_service_symbol, + }) + + rspamd_config:register_symbol({ + type = 
'virtual', + parent = id, + name = phishing_feed_exclusion_symbol, + }) + + rspamd_config:register_symbol({ + type = 'virtual', + parent = id, + name = openphish_symbol, + }) + + rspamd_config:register_symbol({ + type = 'virtual', + parent = id, + name = phishtank_symbol, + }) + end + if opts['domains'] and type(opts['domains']) == 'string' then + domains = lua_maps.map_add_from_ucl(opts['domains'], 'set', + 'Phishing domains') + end + phishing_map('phishing_exceptions', phishing_exceptions_maps, id) + phishing_map('exceptions', anchor_exceptions_maps, id) + phishing_map('strict_domains', strict_domains_maps, id) +end diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua new file mode 100644 index 0000000..add5741 --- /dev/null +++ b/src/plugins/lua/ratelimit.lua @@ -0,0 +1,868 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2016-2017, Andrew Lewis <nerf@judo.za.org> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +if confighelp then + return +end + +local rspamd_logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local rspamd_lua_utils = require "lua_util" +local lua_redis = require "lua_redis" +local fun = require "fun" +local lua_maps = require "lua_maps" +local lua_util = require "lua_util" +local lua_verdict = require "lua_verdict" +local rspamd_hash = require "rspamd_cryptobox_hash" +local lua_selectors = require "lua_selectors" +local ts = require("tableshape").types + +-- A plugin that implements ratelimits using redis + +local E = {} +local N = 'ratelimit' +local redis_params +-- Senders that are considered as bounce +local settings = { + bounce_senders = { 'postmaster', 'mailer-daemon', '', 'null', 'fetchmail-daemon', 'mdaemon' }, + -- Do not check ratelimits for these recipients + whitelisted_rcpts = { 'postmaster', 'mailer-daemon' }, + prefix = 'RL', + ham_factor_rate = 1.01, + spam_factor_rate = 0.99, + ham_factor_burst = 1.02, + spam_factor_burst = 0.98, + max_rate_mult = 5, + max_bucket_mult = 10, + expire = 60 * 60 * 24 * 2, -- 2 days by default + limits = {}, + allow_local = false, + prefilter = true, +} + +local bucket_check_script = "ratelimit_check.lua" +local bucket_check_id + +local bucket_update_script = "ratelimit_update.lua" +local bucket_update_id + +local bucket_cleanup_script = "ratelimit_cleanup_pending.lua" +local bucket_cleanup_id + +-- message_func(task, limit_type, prefix, bucket, limit_key) +local message_func = function(_, limit_type, _, _, _) + return string.format('Ratelimit "%s" exceeded', limit_type) +end + +local function load_scripts(_, _) + bucket_check_id = lua_redis.load_redis_script_from_file(bucket_check_script, redis_params) + bucket_update_id = lua_redis.load_redis_script_from_file(bucket_update_script, redis_params) + bucket_cleanup_id = lua_redis.load_redis_script_from_file(bucket_cleanup_script, redis_params) +end + +local limit_parser +local function parse_string_limit(lim, no_error) + local 
function parse_time_suffix(s) + if s == 's' then + return 1 + elseif s == 'm' then + return 60 + elseif s == 'h' then + return 3600 + elseif s == 'd' then + return 86400 + end + end + local function parse_num_suffix(s) + if s == '' then + return 1 + elseif s == 'k' then + return 1000 + elseif s == 'm' then + return 1000000 + elseif s == 'g' then + return 1000000000 + end + end + local lpeg = require "lpeg" + + if not limit_parser then + local digit = lpeg.R("09") + limit_parser = {} + limit_parser.integer = (lpeg.S("+-") ^ -1) * + (digit ^ 1) + limit_parser.fractional = (lpeg.P(".")) * + (digit ^ 1) + limit_parser.number = (limit_parser.integer * + (limit_parser.fractional ^ -1)) + + (lpeg.S("+-") * limit_parser.fractional) + limit_parser.time = lpeg.Cf(lpeg.Cc(1) * + (limit_parser.number / tonumber) * + ((lpeg.S("smhd") / parse_time_suffix) ^ -1), + function(acc, val) + return acc * val + end) + limit_parser.suffixed_number = lpeg.Cf(lpeg.Cc(1) * + (limit_parser.number / tonumber) * + ((lpeg.S("kmg") / parse_num_suffix) ^ -1), + function(acc, val) + return acc * val + end) + limit_parser.limit = lpeg.Ct(limit_parser.suffixed_number * + (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0) * + limit_parser.time) + end + local t = lpeg.match(limit_parser.limit, lim) + + if t and t[1] and t[2] and t[2] ~= 0 then + return t[2], t[1] + end + + if not no_error then + rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim) + end + + return nil +end + +local function str_to_rate(str) + local divider, divisor = parse_string_limit(str, false) + + if not divisor then + rspamd_logger.errx(rspamd_config, 'bad rate string: %s', str) + + return nil + end + + return divisor / divider +end + +local bucket_schema = ts.shape { + burst = ts.number + ts.string / lua_util.dehumanize_number, + rate = ts.number + ts.string / str_to_rate, + skip_recipients = ts.boolean:is_optional(), + symbol = ts.string:is_optional(), + message = ts.string:is_optional(), + skip_soft_reject = 
ts.boolean:is_optional(), +} + +local function parse_limit(name, data) + if type(data) == 'table' then + -- 2 cases here: + -- * old limit in format [burst, rate] + -- * vector of strings in Andrew's string format (removed from 1.8.2) + -- * proper bucket table + if #data == 2 and tonumber(data[1]) and tonumber(data[2]) then + -- Old style ratelimit + rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', name) + if tonumber(data[1]) > 0 and tonumber(data[2]) > 0 then + return { + burst = data[1], + rate = data[2] + } + elseif data[1] ~= 0 then + rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', name) + else + rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', name) + end + + return nil + else + local parsed_bucket, err = bucket_schema:transform(data) + + if not parsed_bucket or err then + rspamd_logger.errx(rspamd_config, 'cannot parse bucket for %s: %s; original value: %s', + name, err, data) + else + return parsed_bucket + end + end + elseif type(data) == 'string' then + local rep_rate, burst = parse_string_limit(data) + rspamd_logger.warnx(rspamd_config, 'old style rate bucket config detected for %s: %s', + name, data) + if rep_rate and burst then + return { + burst = burst, + rate = burst / rep_rate -- reciprocal + } + end + end + + return nil +end + +--- Check whether this addr is bounce +local function check_bounce(from) + return fun.any(function(b) + return b == from + end, settings.bounce_senders) +end + +local keywords = { + ['ip'] = { + ['get_value'] = function(task) + local ip = task:get_ip() + if ip and ip:is_valid() then + return tostring(ip) + end + return nil + end, + }, + ['rip'] = { + ['get_value'] = function(task) + local ip = task:get_ip() + if ip and ip:is_valid() and not ip:is_local() then + return tostring(ip) + end + return nil + end, + }, + ['from'] = { + ['get_value'] = function(task) + local from = task:get_from(0) + if ((from or E)[1] or E).addr then + return string.lower(from[1]['addr']) + end + 
return nil + end, + }, + ['bounce'] = { + ['get_value'] = function(task) + local from = task:get_from(0) + if not ((from or E)[1] or E).user then + return '_' + end + if check_bounce(from[1]['user']) then + return '_' + else + return nil + end + end, + }, + ['asn'] = { + ['get_value'] = function(task) + local asn = task:get_mempool():get_variable('asn') + if not asn then + return nil + else + return asn + end + end, + }, + ['user'] = { + ['get_value'] = function(task) + local auser = task:get_user() + if not auser then + return nil + else + return auser + end + end, + }, + ['to'] = { + ['get_value'] = function(task) + return task:get_principal_recipient() + end, + }, + ['digest'] = { + ['get_value'] = function(task) + return task:get_digest() + end, + }, + ['attachments'] = { + ['get_value'] = function(task) + local parts = task:get_parts() or E + local digests = {} + + for _, p in ipairs(parts) do + if p:get_filename() then + table.insert(digests, p:get_digest()) + end + end + + if #digests > 0 then + return table.concat(digests, '') + end + + return nil + end, + }, + ['files'] = { + ['get_value'] = function(task) + local parts = task:get_parts() or E + local files = {} + + for _, p in ipairs(parts) do + local fname = p:get_filename() + if fname then + table.insert(files, fname) + end + end + + if #files > 0 then + return table.concat(files, ':') + end + + return nil + end, + }, +} + +local function gen_rate_key(task, rtype, bucket) + local key_t = { tostring(lua_util.round(100000.0 / bucket.burst)) } + local key_keywords = lua_util.str_split(rtype, '_') + local have_user = false + + for _, v in ipairs(key_keywords) do + local ret + + if keywords[v] and type(keywords[v]['get_value']) == 'function' then + ret = keywords[v]['get_value'](task) + end + if not ret then + return nil + end + if v == 'user' then + have_user = true + end + if type(ret) ~= 'string' then + ret = tostring(ret) + end + table.insert(key_t, ret) + end + + if have_user and not task:get_user() 
then + return nil + end + + return table.concat(key_t, ":") +end + +local function make_prefix(redis_key, name, bucket) + local hash_len = 24 + if hash_len > #redis_key then + hash_len = #redis_key + end + local hash = settings.prefix .. + string.sub(rspamd_hash.create(redis_key):base32(), 1, hash_len) + -- Fill defaults + if not bucket.spam_factor_rate then + bucket.spam_factor_rate = settings.spam_factor_rate + end + if not bucket.ham_factor_rate then + bucket.ham_factor_rate = settings.ham_factor_rate + end + if not bucket.spam_factor_burst then + bucket.spam_factor_burst = settings.spam_factor_burst + end + if not bucket.ham_factor_burst then + bucket.ham_factor_burst = settings.ham_factor_burst + end + + return { + bucket = bucket, + name = name, + hash = hash + } +end + +local function limit_to_prefixes(task, k, v, prefixes) + local n = 0 + for _, bucket in ipairs(v.buckets) do + if v.selector then + local selectors = lua_selectors.process_selectors(task, v.selector) + if selectors then + local combined = lua_selectors.combine_selectors(task, selectors, ':') + if type(combined) == 'string' then + prefixes[combined] = make_prefix(combined, k, bucket) + n = n + 1 + else + fun.each(function(p) + prefixes[p] = make_prefix(p, k, bucket) + n = n + 1 + end, combined) + end + end + else + local prefix = gen_rate_key(task, k, bucket) + if prefix then + if type(prefix) == 'string' then + prefixes[prefix] = make_prefix(prefix, k, bucket) + n = n + 1 + else + fun.each(function(p) + prefixes[p] = make_prefix(p, k, bucket) + n = n + 1 + end, prefix) + end + end + end + end + + return n +end + +local function ratelimit_cb(task) + if not settings.allow_local and + rspamd_lua_utils.is_rspamc_or_controller(task) then + lua_util.debugm(N, task, 'skip ratelimit for local request') + return + end + + -- Get initial task data + local ip = task:get_from_ip() + if ip and ip:is_valid() and settings.whitelisted_ip then + if settings.whitelisted_ip:get_key(ip) then + -- Do not check 
whitelisted ip + rspamd_logger.infox(task, 'skip ratelimit for whitelisted IP') + return + end + end + -- Parse all rcpts + local rcpts = task:get_recipients() + local rcpts_user = {} + if rcpts then + fun.each(function(r) + fun.each(function(type) + table.insert(rcpts_user, r[type]) + end, { 'user', 'addr' }) + end, rcpts) + + if fun.any(function(r) + return settings.whitelisted_rcpts:get_key(r) + end, rcpts_user) then + rspamd_logger.infox(task, 'skip ratelimit for whitelisted recipient') + return + end + end + -- Get user (authuser) + if settings.whitelisted_user then + local auser = task:get_user() + if settings.whitelisted_user:get_key(auser) then + rspamd_logger.infox(task, 'skip ratelimit for whitelisted user') + return + end + end + -- Now create all ratelimit prefixes + local prefixes = {} + local nprefixes = 0 + + for k, v in pairs(settings.limits) do + nprefixes = nprefixes + limit_to_prefixes(task, k, v, prefixes) + end + + for k, hdl in pairs(settings.custom_keywords or E) do + local ret, redis_key, bd = pcall(hdl, task) + + if ret then + local bucket = parse_limit(k, bd) + if bucket then + prefixes[redis_key] = make_prefix(redis_key, k, bucket) + end + nprefixes = nprefixes + 1 + else + rspamd_logger.errx(task, 'cannot call handler for %s: %s', + k, redis_key) + end + end + + local function gen_check_cb(prefix, bucket, lim_name, lim_key) + return function(err, data) + if err then + rspamd_logger.errx('cannot check limit %s: %s %s', prefix, err, data) + elseif type(data) == 'table' and data[1] then + lua_util.debugm(N, task, + "got reply for limit %s (%s / %s); %s burst, %s:%s dyn, %s leaked", + prefix, bucket.burst, bucket.rate, + data[2], data[3], data[4], data[5]) + + task:cache_set('ratelimit_bucket_touched', true) + if data[1] == 1 then + -- set symbol only and do NOT soft reject + if bucket.symbol then + -- Per bucket symbol + task:insert_result(bucket.symbol, 1.0, + string.format('%s(%s)', lim_name, lim_key)) + else + if settings.symbol then + 
task:insert_result(settings.symbol, 1.0, + string.format('%s(%s)', lim_name, lim_key)) + elseif settings.info_symbol then + task:insert_result(settings.info_symbol, 1.0, + string.format('%s(%s)', lim_name, lim_key)) + end + end + rspamd_logger.infox(task, + 'ratelimit "%s(%s)" exceeded, (%s / %s): %s (%s:%s dyn); redis key: %s', + lim_name, prefix, + bucket.burst, bucket.rate, + data[2], data[3], data[4], lim_key) + + if not (bucket.symbol or settings.symbol) and not bucket.skip_soft_reject then + if not bucket.message then + task:set_pre_result('soft reject', + message_func(task, lim_name, prefix, bucket, lim_key), N) + else + task:set_pre_result('soft reject', bucket.message) + end + end + end + end + end + end + + -- Don't do anything if pre-result has been already set + if task:has_pre_result() then + return + end + + local _, nrcpt = task:has_recipients('smtp') + if not nrcpt or nrcpt <= 0 then + nrcpt = 1 + end + + if nprefixes > 0 then + -- Save prefixes to the cache to allow update + task:cache_set('ratelimit_prefixes', prefixes) + local now = rspamd_util.get_time() + now = lua_util.round(now * 1000.0) -- Get milliseconds + -- Now call check script for all defined prefixes + + for pr, value in pairs(prefixes) do + local bucket = value.bucket + local rate = (bucket.rate) / 1000.0 -- Leak rate in messages/ms + local bincr = nrcpt + if bucket.skip_recipients then + bincr = 1 + end + + lua_util.debugm(N, task, "check limit %s:%s -> %s (%s/%s)", + value.name, pr, value.hash, bucket.burst, bucket.rate) + lua_redis.exec_redis_script(bucket_check_id, + { key = value.hash, task = task, is_write = true }, + gen_check_cb(pr, bucket, value.name, value.hash), + { value.hash, tostring(now), tostring(rate), tostring(bucket.burst), + tostring(settings.expire), tostring(bincr) }) + end + end +end + + +-- This function is used to clean up pending bucket when +-- the task is somehow being skipped (e.g. 
greylisting/ratelimit/whatever) +-- but the ratelimit buckets for this task are touched (e.g. pending has been increased) +-- See https://github.com/rspamd/rspamd/issues/4467 for more context +local function maybe_cleanup_pending(task) + if task:cache_get('ratelimit_bucket_touched') then + local prefixes = task:cache_get('ratelimit_prefixes') + if prefixes then + for k, v in pairs(prefixes) do + local bucket = v.bucket + local function cleanup_cb(err, data) + if err then + rspamd_logger.errx('cannot cleanup limit %s: %s %s', k, err, data) + else + lua_util.debugm(N, task, 'cleaned pending bucked for %s: %s', k, data) + end + end + local _, nrcpt = task:has_recipients('smtp') + if not nrcpt or nrcpt <= 0 then + nrcpt = 1 + end + local bincr = nrcpt + if bucket.skip_recipients then + bincr = 1 + end + local now = task:get_timeval(true) + now = lua_util.round(now * 1000.0) -- Get milliseconds + lua_redis.exec_redis_script(bucket_cleanup_id, + { key = v.hash, task = task, is_write = true }, + cleanup_cb, + { v.hash, tostring(now), tostring(settings.expire), tostring(bincr) }) + end + end + end +end + +local function ratelimit_update_cb(task) + if task:has_flag('skip') then + maybe_cleanup_pending(task) + return + end + if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then + maybe_cleanup_pending(task) + end + + local prefixes = task:cache_get('ratelimit_prefixes') + + if prefixes then + if task:has_pre_result() then + -- Already rate limited/greylisted, do nothing + lua_util.debugm(N, task, 'pre-action has been set, do not update') + maybe_cleanup_pending(task) + return + end + + local verdict = lua_verdict.get_specific_verdict(N, task) + local _, nrcpt = task:has_recipients('smtp') + if not nrcpt or nrcpt <= 0 then + nrcpt = 1 + end + + -- Update each bucket + for k, v in pairs(prefixes) do + local bucket = v.bucket + local function update_bucket_cb(err, data) + if err then + rspamd_logger.errx(task, 'cannot update rate bucket %s: %s', + k, err) 
+ else + lua_util.debugm(N, task, + "updated limit %s:%s -> %s (%s/%s), burst: %s, dyn_rate: %s, dyn_burst: %s", + v.name, k, v.hash, + bucket.burst, bucket.rate, + data[1], data[2], data[3]) + end + end + local now = task:get_timeval(true) + now = lua_util.round(now * 1000.0) -- Get milliseconds + local mult_burst = 1.0 + local mult_rate = 1.0 + + if verdict == 'spam' or verdict == 'junk' then + mult_burst = bucket.spam_factor_burst or 1.0 + mult_rate = bucket.spam_factor_rate or 1.0 + elseif verdict == 'ham' then + mult_burst = bucket.ham_factor_burst or 1.0 + mult_rate = bucket.ham_factor_rate or 1.0 + end + + local bincr = nrcpt + if bucket.skip_recipients then + bincr = 1 + end + + lua_redis.exec_redis_script(bucket_update_id, + { key = v.hash, task = task, is_write = true }, + update_bucket_cb, + { v.hash, tostring(now), tostring(mult_rate), tostring(mult_burst), + tostring(settings.max_rate_mult), tostring(settings.max_bucket_mult), + tostring(settings.expire), tostring(bincr) }) + end + end +end + +local opts = rspamd_config:get_all_opt(N) +if opts then + + settings = lua_util.override_defaults(settings, opts) + + if opts['limit'] then + rspamd_logger.errx(rspamd_config, 'Legacy ratelimit config format no longer supported') + end + + if opts['rates'] and type(opts['rates']) == 'table' then + -- new way of setting limits + fun.each(function(t, lim) + local buckets = {} + + if type(lim) == 'table' and lim.bucket then + + if lim.bucket[1] then + for _, bucket in ipairs(lim.bucket) do + local b = parse_limit(t, bucket) + + if not b then + rspamd_logger.errx(rspamd_config, 'bad ratelimit bucket for %s: "%s"', + t, b) + return + end + + table.insert(buckets, b) + end + else + local bucket = parse_limit(t, lim.bucket) + + if not bucket then + rspamd_logger.errx(rspamd_config, 'bad ratelimit bucket for %s: "%s"', + t, lim.bucket) + return + end + + buckets = { bucket } + end + + settings.limits[t] = { + buckets = buckets + } + + if lim.selector then + local 
selector = lua_selectors.parse_selector(rspamd_config, lim.selector) + if not selector then + rspamd_logger.errx(rspamd_config, 'bad ratelimit selector for %s: "%s"', + t, lim.selector) + settings.limits[t] = nil + return + end + + settings.limits[t].selector = selector + end + else + rspamd_logger.warnx(rspamd_config, 'old syntax for ratelimits: %s', lim) + buckets = parse_limit(t, lim) + if buckets then + settings.limits[t] = { + buckets = { buckets } + } + end + end + end, opts['rates']) + end + + -- Display what's enabled + fun.each(function(s) + rspamd_logger.infox(rspamd_config, 'enabled ratelimit: %s', s) + end, fun.map(function(n, d) + return string.format('%s [%s]', n, + table.concat(fun.totable(fun.map(function(v) + return string.format('symbol: %s, %s msgs burst, %s msgs/sec rate', + v.symbol, v.burst, v.rate) + end, d.buckets)), '; ') + ) + end, settings.limits)) + + -- Ret, ret, ret: stupid legacy stuff: + -- If we have a string with commas then load it as as static map + -- otherwise, apply normal logic of Rspamd maps + + local wrcpts = opts['whitelisted_rcpts'] + if type(wrcpts) == 'string' then + if string.find(wrcpts, ',') then + settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl( + lua_util.rspamd_str_split(wrcpts, ','), 'set', 'Ratelimit whitelisted rcpts') + else + settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set', + 'Ratelimit whitelisted rcpts') + end + elseif type(opts['whitelisted_rcpts']) == 'table' then + settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set', + 'Ratelimit whitelisted rcpts') + else + -- Stupid default... 
+ settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl( + settings.whitelisted_rcpts, 'set', 'Ratelimit whitelisted rcpts') + end + + if opts['whitelisted_ip'] then + settings.whitelisted_ip = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_ip', 'radix', + 'Ratelimit whitelist ip map') + end + + if opts['whitelisted_user'] then + settings.whitelisted_user = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_user', 'set', + 'Ratelimit whitelist user map') + end + + settings.custom_keywords = {} + if opts['custom_keywords'] then + local ret, res_or_err = pcall(loadfile(opts['custom_keywords'])) + + if ret then + opts['custom_keywords'] = {} + if type(res_or_err) == 'table' then + for k, hdl in pairs(res_or_err) do + settings['custom_keywords'][k] = hdl + end + elseif type(res_or_err) == 'function' then + settings['custom_keywords']['custom'] = res_or_err + end + else + rspamd_logger.errx(rspamd_config, 'cannot execute %s: %s', + opts['custom_keywords'], res_or_err) + settings['custom_keywords'] = {} + end + end + + if opts['message_func'] then + message_func = assert(load(opts['message_func']))() + end + + redis_params = lua_redis.parse_redis_server('ratelimit') + + if not redis_params then + rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module') + lua_util.disable_module(N, "redis") + else + local s = { + type = settings.prefilter and 'prefilter' or 'callback', + name = 'RATELIMIT_CHECK', + priority = lua_util.symbols_priorities.medium, + callback = ratelimit_cb, + flags = 'empty,nostat', + augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) }, + } + + local id = rspamd_config:register_symbol(s) + + -- Register per bucket symbols + -- Display what's enabled + fun.each(function(set, lim) + if type(lim.buckets) == 'table' then + for _, b in ipairs(lim.buckets) do + if b.symbol then + rspamd_config:register_symbol { + type = 'virtual', + name = b.symbol, + score = 0.0, + parent = id + } + end + end + end + end, 
settings.limits) + + if settings.info_symbol then + rspamd_config:register_symbol { + type = 'virtual', + name = settings.info_symbol, + score = 0.0, + parent = id + } + end + if settings.symbol then + rspamd_config:register_symbol { + type = 'virtual', + name = settings.symbol, + score = 0.0, -- Might be overridden if needed + parent = id + } + end + + rspamd_config:register_symbol { + type = 'idempotent', + name = 'RATELIMIT_UPDATE', + flags = 'explicit_disable,ignore_passthrough', + callback = ratelimit_update_cb, + augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) }, + } + end +end + +rspamd_config:add_on_load(function(cfg, ev_base, _) + load_scripts(cfg, ev_base) +end) diff --git a/src/plugins/lua/rbl.lua b/src/plugins/lua/rbl.lua new file mode 100644 index 0000000..b2ccf86 --- /dev/null +++ b/src/plugins/lua/rbl.lua @@ -0,0 +1,1425 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2013-2015, Andrew Lewis <nerf@judo.za.org> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +if confighelp then + return +end + +local hash = require 'rspamd_cryptobox_hash' +local rspamd_logger = require 'rspamd_logger' +local rspamd_util = require 'rspamd_util' +local rspamd_ip = require "rspamd_ip" +local fun = require 'fun' +local lua_util = require 'lua_util' +local selectors = require "lua_selectors" +local bit = require 'bit' +local lua_maps = require "lua_maps" +local rbl_common = require "plugins/rbl" +local rspamd_url = require "rspamd_url" + +-- This plugin implements various types of RBL checks +-- Documentation can be found here: +-- https://rspamd.com/doc/modules/rbl.html + +local E = {} +local N = 'rbl' + +-- Checks that could be performed by rbl module +local local_exclusions +local white_symbols = {} +local black_symbols = {} +local monitored_addresses = {} +local known_selectors = {} -- map from selector string to selector id +local url_flag_bits = rspamd_url.flags + +local function get_monitored(rbl) + local function is_random_monitored() + -- Explicit definition + if type(rbl.random_monitored) == 'boolean' then + return rbl.random_monitored + end + + -- We check 127.0.0.1 for merely RBLs with `from` or `received` and only if + -- they don't have `no_ip` attribute at the same time + -- + -- Convert to a boolean variable using the common idiom + return (not (rbl.from or rbl.received) + or rbl.no_ip) + and true or false + end + + local default_monitored = '1.0.0.127' + local ret = { + rcode = 'nxdomain', + prefix = default_monitored, + random = is_random_monitored(), + } + + if rbl.monitored_address then + ret.prefix = rbl.monitored_address + end + + lua_util.debugm(N, rspamd_config, + 'added monitored address: %s (%s random)', + ret.prefix, ret.random) + + return ret +end + +local function validate_dns(lstr) + if lstr:match('%.%.') then + -- two dots in a row + return false, "two dots in a row" + end + if not rspamd_util.is_valid_utf8(lstr) then + -- invalid utf8 detected + return false, "invalid utf8" + end + for v in 
lstr:gmatch('[^%.]+') do + if v:len() > 63 then + -- too long label + return false, "too long label" + end + if v:match('^-') or v:match('-$') then + -- dash at the beginning or end of label + return false, "dash at the beginning or end of label" + end + end + return true +end + +local function maybe_make_hash(data, rule) + if rule.hash then + local h = hash.create_specific(rule.hash, data) + local s + if rule.hash_format then + if rule.hash_format == 'base32' then + s = h:base32() + elseif rule.hash_format == 'base64' then + s = h:base64() + else + s = h:hex() + end + else + s = h:hex() + end + + if rule.hash_len then + s = s:sub(1, rule.hash_len) + end + + return s + else + return data + end +end + +local function is_excluded_ip(rip) + if local_exclusions and local_exclusions:get_key(rip) then + return true + end + return false +end + +local function ip_to_rbl(ip) + return table.concat(ip:inversed_str_octets(), '.') +end + +local function gen_check_rcvd_conditions(rbl, received_total) + local min_pos = tonumber(rbl.received_min_pos) + local max_pos = tonumber(rbl.received_max_pos) + local match_flags = rbl.received_flags + local nmatch_flags = rbl.received_nflags + + local function basic_received_check(rh) + if not (rh.real_ip and rh.real_ip:is_valid()) then + return false + end + if ((rh.real_ip:get_version() == 6 and rbl.ipv6) or + (rh.real_ip:get_version() == 4 and rbl.ipv4)) and + ((rbl.exclude_local and not rh.real_ip:is_local() or is_excluded_ip(rh.real_ip)) or not rbl.exclude_local) then + return true + else + return false + end + end + + local function positioned_received_check(rh, pos) + if not rh or not basic_received_check(rh) then + return false + end + local got_flags = rh.flags or E + if min_pos then + if min_pos < 0 then + if min_pos == -1 then + if (pos ~= received_total) then + return false + end + else + if pos <= (received_total - math.abs(min_pos)) then + return false + end + end + elseif pos < min_pos then + return false + end + end + if 
max_pos then + if max_pos < -1 then + if (received_total - math.abs(max_pos)) >= pos then + return false + end + elseif max_pos > 0 then + if pos > max_pos then + return false + end + end + end + if match_flags then + for _, flag in ipairs(match_flags) do + if not got_flags[flag] then + return false + end + end + end + if nmatch_flags then + for _, flag in ipairs(nmatch_flags) do + if got_flags[flag] then + return false + end + end + end + return true + end + + if not (max_pos or min_pos or match_flags or nmatch_flags) then + return basic_received_check + else + return positioned_received_check + end +end + +local matchers = {} + +matchers.radix = function(_, _, real_ip, map) + return map and map:get_key(real_ip) or false +end + +matchers.equality = function(codes, to_match) + if type(codes) ~= 'table' then return codes == to_match end + for _, ip in ipairs(codes) do + if to_match == ip then + return true + end + end + return false +end + +matchers.luapattern = function(codes, to_match) + if type(codes) ~= 'table' then + return string.find(to_match, '^' .. codes .. '$') and true or false + end + for _, pattern in ipairs(codes) do + if string.find(to_match, '^' .. pattern .. 
'$') then + return true + end + end + return false +end + +matchers.regexp = function(_, to_match, _, map) + return map and map:get_key(to_match) or false +end + +matchers.glob = function(_, to_match, _, map) + return map and map:get_key(to_match) or false +end + +local function rbl_dns_process(task, rbl, to_resolve, results, err, resolve_table_elt, match) + local function make_option(ip, label) + if ip then + return string.format('%s:%s:%s', + resolve_table_elt.orig, + label, + ip) + else + return string.format('%s:%s', + resolve_table_elt.orig, + label) + end + end + + local function insert_result(s, ip, label) + if rbl.symbols_prefixes then + local prefix = rbl.symbols_prefixes[label] + + if not prefix then + rspamd_logger.warnx(task, 'unlisted symbol prefix for %s', label) + task:insert_result(s, 1.0, make_option(ip, label)) + else + task:insert_result(prefix .. '_' .. s, 1.0, make_option(ip, label)) + end + else + task:insert_result(s, 1.0, make_option(ip, label)) + end + end + + local function insert_results(s, ip) + for label in pairs(resolve_table_elt.what) do + insert_result(s, ip, label) + end + end + + if err and (err ~= 'requested record is not found' and + err ~= 'no records with this name') then + rspamd_logger.infox(task, 'error looking up %s: %s', to_resolve, err) + task:insert_result(rbl.symbol .. 
'_FAIL', 1, string.format('%s:%s', + resolve_table_elt.orig, err)) + return + end + + if not results then + lua_util.debugm(N, task, + 'DNS RESPONSE: label=%1 results=%2 error=%3 rbl=%4', + to_resolve, false, err, rbl.symbol) + return + else + lua_util.debugm(N, task, + 'DNS RESPONSE: label=%1 results=%2 error=%3 rbl=%4', + to_resolve, true, err, rbl.symbol) + end + + if rbl.returncodes == nil and rbl.returnbits == nil and rbl.symbol ~= nil then + insert_results(rbl.symbol) + return + end + + local returncodes_maps = rbl.returncodes_maps or {} + + for _, result in ipairs(results) do + local ipstr = result:to_string() + lua_util.debugm(N, task, '%s DNS result %s', to_resolve, ipstr) + local foundrc = false + -- Check return codes + if rbl.returnbits then + local ipnum = result:to_number() + for s, bits in pairs(rbl.returnbits) do + for _, check_bit in ipairs(bits) do + if bit.band(ipnum, check_bit) == check_bit then + foundrc = true + insert_results(s) + -- Here, we continue with other bits + end + end + end + elseif rbl.returncodes then + for s, codes in pairs(rbl.returncodes) do + local res = match(codes, ipstr, result, returncodes_maps[s]) + if res then + foundrc = true + insert_results(s) + end + end + end + + if not foundrc then + if rbl.unknown and rbl.symbol then + insert_results(rbl.symbol, ipstr) + else + lua_util.debugm(N, task, '%1 returned unknown result: %2', + to_resolve, ipstr) + end + end + end + +end + +local function gen_rbl_callback(rule) + local function is_whitelisted(task, req, req_str, whitelist, what) + if rule.ignore_whitelist then + lua_util.debugm(N, task, + 'ignore whitelisting checks to %s by %s: ignore whitelist is being set', + req_str, rule.symbol) + return false + end + + if rule.whitelist then + if rule.whitelist:get_key(req) then + lua_util.debugm(N, task, + 'whitelisted %s on %s', + req_str, rule.symbol) + + return true + end + end + + -- Maybe whitelisted by some other rbl rule + if whitelist then + local wl = whitelist[req_str] 
+ if wl then + lua_util.debugm(N, task, + 'whitelisted request to %s by %s (%s) rbl rule (%s checked type, %s whitelist type)', + req_str, wl.type, wl.symbol, what, wl.type) + if wl.type == what then + -- This was decided to be a bad idea as in case of whitelisting a request to blacklist + -- is not even sent + --task:adjust_result(wl.symbol, 0.0 / 0.0, rule.symbol) + + return true + end + end + end + + return false + end + + local function add_dns_request(task, req, forced, is_ip, requests_table, label, whitelist) + local req_str = req + if is_ip then + req_str = tostring(req) + end + + if whitelist and is_whitelisted(task, req, req_str, whitelist, label) then + return + end + + if is_ip then + req = ip_to_rbl(req) + end + + if requests_table[req] then + -- Duplicate request + local nreq = requests_table[req] + if forced and not nreq.forced then + nreq.forced = true + end + if not nreq.what[label] then + nreq.what[label] = true + end + + return true, nreq -- Duplicate + else + local nreq + + local resolve_ip = rule.resolve_ip and not is_ip + if rule.process_script then + local processed = rule.process_script(req, rule.rbl, task, resolve_ip) + + if processed then + nreq = { + forced = forced, + n = processed, + orig = req_str, + resolve_ip = resolve_ip, + what = { [label] = true }, + } + requests_table[req] = nreq + end + else + local to_resolve + local origin = req + + if not resolve_ip then + origin = maybe_make_hash(req, rule) + to_resolve = string.format('%s.%s', + origin, + rule.rbl) + else + -- First, resolve origin stuff without hashing or anything + to_resolve = origin + end + + nreq = { + forced = forced, + n = to_resolve, + orig = req_str, + resolve_ip = resolve_ip, + what = { [label] = true }, + } + requests_table[req] = nreq + end + return false, nreq + end + end + + -- Here, we have functional approach: we form a pipeline of functions + -- f1, f2, ... fn. 
Each function accepts task and return boolean value + -- that allows to process pipeline further + -- Each function in the pipeline can add something to `dns_req` vector as a side effect + local function is_alive(_, _) + if rule.monitored then + if not rule.monitored:alive() then + return false + end + end + + return true + end + + local function check_required_symbols(task, _) + if rule.require_symbols then + return fun.all(function(sym) + task:has_symbol(sym) + end, rule.require_symbols) + end + + return true + end + + local function check_user(task, _) + if task:get_user() then + return false + end + + return true + end + + local function check_local(task, _) + local ip = task:get_from_ip() + + if ip and not ip:is_valid() then + ip = nil + end + + if ip and ip:is_local() or is_excluded_ip(ip) then + return false + end + + return true + end + + local function check_helo(task, requests_table, whitelist) + local helo = task:get_helo() + + if not helo then + -- Avoid pipeline breaking + return true + end + + add_dns_request(task, helo, true, false, requests_table, + 'helo', whitelist) + + return true + end + + local function check_dkim(task, requests_table, whitelist) + local das = task:get_symbol('DKIM_TRACE') + local mime_from_domain + + if das and das[1] and das[1].options then + + if rule.dkim_match_from then + -- We check merely mime from + mime_from_domain = ((task:get_from('mime') or E)[1] or E).domain + if mime_from_domain then + local mime_from_domain_tld = rule.url_full_hostname and + mime_from_domain or rspamd_util.get_tld(mime_from_domain) + + if rule.url_compose_map then + mime_from_domain = rule.url_compose_map:process_url(task, mime_from_domain_tld, mime_from_domain) + else + mime_from_domain = mime_from_domain_tld + end + end + end + + for _, d in ipairs(das[1].options) do + + local domain, result = d:match('^([^%:]*):([%+%-%~])$') + + -- We must ignore bad signatures, omg + if domain and result and result == '+' then + if rule.dkim_match_from then + 
-- We check merely mime from + local domain_tld = domain + if not rule.dkim_domainonly then + -- Adjust + domain_tld = rspamd_util.get_tld(domain) + + if rule.url_compose_map then + domain_tld = rule.url_compose_map:process_url(task, domain_tld, domain) + elseif rule.url_full_hostname then + domain_tld = domain + end + end + + if mime_from_domain and mime_from_domain == domain_tld then + add_dns_request(task, domain_tld, true, false, requests_table, + 'dkim', whitelist) + end + else + if rule.dkim_domainonly then + local domain_tld = rspamd_util.get_tld(domain) + if rule.url_compose_map then + domain_tld = rule.url_compose_map:process_url(task, domain_tld, domain) + elseif rule.url_full_hostname then + domain_tld = domain + end + add_dns_request(task, domain_tld, + false, false, requests_table, 'dkim', whitelist) + else + add_dns_request(task, domain, false, false, requests_table, + 'dkim', whitelist) + end + end + end + end + end + + return true + end + + local function check_urls(task, requests_table, whitelist) + local esld_lim = 1 + + if rule.url_compose_map then + esld_lim = nil -- Avoid esld limit as we use custom composition rules + end + local ex_params = { + task = task, + limit = rule.requests_limit, + ignore_redirected = true, + ignore_ip = rule.no_ip, + need_images = rule.images, + need_emails = false, + need_content = rule.content_urls or false, + esld_limit = esld_lim, + no_cache = true, + } + + if rule.numeric_urls then + if rule.content_urls then + if not rule.images then + ex_params.flags_mode = 'explicit' + ex_params.flags = { 'numeric' } + ex_params.filter = function(url) + return (bit.band(url:get_flags_num(), url_flag_bits.image) == 0) + end + else + ex_params.filter = function(url) + return (bit.band(url:get_flags_num(), url_flag_bits.numeric) ~= 0) + end + end + elseif rule.images then + ex_params.filter = function(url) + return (bit.band(url:get_flags_num(), url_flag_bits.numeric) ~= 0) + end + else + ex_params.flags_mode = 'explicit' + 
ex_params.flags = { 'numeric' } + ex_params.filter = function(url) + return (bit.band(url:get_flags_num(), url_flag_bits.content) == 0) + end + end + elseif not rule.urls and (rule.content_urls or rule.images) then + ex_params.flags_mode = 'explicit' + ex_params.flags = {} + if rule.content_urls then + table.insert(ex_params.flags, 'content') + end + if rule.images then + table.insert(ex_params.flags, 'image') + end + end + + local urls = lua_util.extract_specific_urls(ex_params) + + for _, u in ipairs(urls) do + local flags = u:get_flags_num() + + if bit.band(flags, url_flag_bits.numeric) ~= 0 then + -- For numeric urls we convert data to the ip address and + -- reverse octets. See #3948 for details + local to_resolve = u:get_host() + local addr = rspamd_ip.from_string(to_resolve) + + if addr then + to_resolve = table.concat(addr:inversed_str_octets(), ".") + end + add_dns_request(task, to_resolve, false, + false, requests_table, 'url', whitelist) + else + local url_hostname = u:get_host() + local url_tld = rule.url_full_hostname and url_hostname or u:get_tld() + if rule.url_compose_map then + url_tld = rule.url_compose_map:process_url(task, url_tld, url_hostname) + end + add_dns_request(task, url_tld, false, + false, requests_table, 'url', whitelist) + end + end + + return true + end + + local function check_from(task, requests_table, whitelist) + local ip = task:get_from_ip() + + if not ip or not ip:is_valid() then + return true + end + if (ip:get_version() == 6 and rule.ipv6) or + (ip:get_version() == 4 and rule.ipv4) then + add_dns_request(task, ip, true, true, + requests_table, 'from', + whitelist) + end + + return true + end + + local function check_received(task, requests_table, whitelist) + local received = fun .filter(function(h) + return not h['flags']['artificial'] + end, task:get_received_headers()):totable() + + local received_total = #received + local check_conditions = gen_check_rcvd_conditions(rule, received_total) + + for pos, rh in 
ipairs(received) do + if check_conditions(rh, pos) then + add_dns_request(task, rh.real_ip, false, true, + requests_table, 'received', + whitelist) + end + end + + return true + end + + local function check_rdns(task, requests_table, whitelist) + local hostname = task:get_hostname() + if hostname == nil or hostname == 'unknown' then + return true + end + + add_dns_request(task, hostname, true, false, + requests_table, 'rdns', whitelist) + + return true + end + + local function check_selector(task, requests_table, whitelist) + for selector_label, selector in pairs(rule.selectors) do + local res = selector(task) + + if res and type(res) == 'table' then + for _, r in ipairs(res) do + add_dns_request(task, r, false, false, requests_table, + selector_label, whitelist) + end + elseif res then + add_dns_request(task, res, false, false, + requests_table, selector_label, whitelist) + end + end + + return true + end + + local function check_email_table(task, email_tbl, requests_table, whitelist, what) + lua_util.remove_email_aliases(email_tbl) + email_tbl.domain = email_tbl.domain:lower() + email_tbl.user = email_tbl.user:lower() + + if email_tbl.domain == '' or email_tbl.user == '' then + rspamd_logger.infox(task, "got an email with some empty parts: '%s@%s'; skip it in the checks", + email_tbl.user, email_tbl.domain) + return + end + + if rule.emails_domainonly then + add_dns_request(task, email_tbl.domain, false, false, requests_table, + what, whitelist) + else + -- Also check WL for domain only + if is_whitelisted(task, + email_tbl.domain, + email_tbl.domain, + whitelist, + what) then + return + end + local delimiter = '.' 
+ if rule.emails_delimiter then + delimiter = rule.emails_delimiter + else + if rule.hash then + delimiter = '@' + end + end + add_dns_request(task, string.format('%s%s%s', + email_tbl.user, delimiter, email_tbl.domain), false, false, + requests_table, what, whitelist) + end + end + + local function check_emails(task, requests_table, whitelist) + local ex_params = { + task = task, + limit = rule.requests_limit, + filter = function(u) + return u:get_protocol() == 'mailto' + end, + need_emails = true, + prefix = 'rbl_email' + } + + if rule.emails_domainonly then + if not rule.url_compose_map then + ex_params.esld_limit = 1 + end + ex_params.prefix = 'rbl_email_domainonly' + end + + local emails = lua_util.extract_specific_urls(ex_params) + + for _, email in ipairs(emails) do + local domain + if rule.emails_domainonly and not rule.url_full_hostname then + if rule.url_compose_map then + domain = rule.url_compose_map:process_url(task, email:get_tld(), email:get_host()) + else + domain = email:get_tld() + end + else + domain = email:get_host() + end + + local email_tbl = { + domain = domain or '', + user = email:get_user() or '', + addr = tostring(email), + } + check_email_table(task, email_tbl, requests_table, whitelist, 'email') + end + + return true + end + + local function check_replyto(task, requests_table, whitelist) + local function get_raw_header(name) + return ((task:get_header_full(name) or {})[1] or {})['value'] + end + + local replyto = get_raw_header('Reply-To') + if replyto then + local rt = rspamd_util.parse_mail_address(replyto, task:get_mempool()) + lua_util.debugm(N, task, 'check replyto %s', rt[1]) + + if rt and rt[1] and (rt[1].addr and #rt[1].addr > 0) then + check_email_table(task, rt[1], requests_table, whitelist, 'replyto') + end + end + + return true + end + + -- Create function pipeline depending on rbl settings + local pipeline = { + is_alive, -- check monitored status + check_required_symbols -- if we have require_symbols then check those 
symbols + } + local description = { + 'alive', + } + + if rule.exclude_users then + pipeline[#pipeline + 1] = check_user + description[#description + 1] = 'user' + end + + if rule.exclude_local then + pipeline[#pipeline + 1] = check_local + description[#description + 1] = 'local' + end + + if rule.helo then + pipeline[#pipeline + 1] = check_helo + description[#description + 1] = 'helo' + end + + if rule.dkim then + pipeline[#pipeline + 1] = check_dkim + description[#description + 1] = 'dkim' + end + + if rule.emails then + pipeline[#pipeline + 1] = check_emails + description[#description + 1] = 'emails' + end + if rule.replyto then + pipeline[#pipeline + 1] = check_replyto + description[#description + 1] = 'replyto' + end + + if rule.urls or rule.content_urls or rule.images or rule.numeric_urls then + pipeline[#pipeline + 1] = check_urls + description[#description + 1] = 'urls' + end + + if rule.from then + pipeline[#pipeline + 1] = check_from + description[#description + 1] = 'ip' + end + + if rule.received then + pipeline[#pipeline + 1] = check_received + description[#description + 1] = 'received' + end + + if rule.rdns then + pipeline[#pipeline + 1] = check_rdns + description[#description + 1] = 'rdns' + end + + if rule.selector then + pipeline[#pipeline + 1] = check_selector + description[#description + 1] = 'selector' + end + + if not rule.returncodes_matcher then + rule.returncodes_matcher = 'equality' + end + local match = matchers[rule.returncodes_matcher] + + local callback_f = function(task) + -- DNS requests to issue (might be hashed afterwards) + local dns_req = {} + local whitelist = task:cache_get('rbl_whitelisted') or {} + + local function gen_rbl_dns_callback(resolve_table_elt) + return function(_, to_resolve, results, err) + rbl_dns_process(task, rule, to_resolve, results, err, resolve_table_elt, match) + end + end + + -- Execute functions pipeline + for i, f in ipairs(pipeline) do + if not f(task, dns_req, whitelist) then + lua_util.debugm(N, 
task, + "skip rbl check: %s; pipeline condition %s returned false", + rule.symbol, i) + return + end + end + + -- Now check all DNS requests pending and emit them + local r = task:get_resolver() + -- Used for 2 passes ip resolution + local resolved_req = {} + local nresolved = 0 + + -- This is called when doing resolve_ip phase... + local function gen_rbl_ip_dns_callback(orig_resolve_table_elt) + return function(_, _, results, err) + if not err then + for _, dns_res in ipairs(results) do + -- Check if we have rspamd{ip} userdata + if type(dns_res) == 'userdata' then + -- Add result as an actual RBL request + local label = next(orig_resolve_table_elt.what) + local dup, nreq = add_dns_request(task, dns_res, false, true, + resolved_req, label) + -- Add original name + if not dup then + nreq.orig = nreq.orig .. ':' .. orig_resolve_table_elt.n + end + end + end + end + + nresolved = nresolved - 1 + + if nresolved == 0 then + -- Emit real RBL requests as there are no ip resolution requests + for name, req in pairs(resolved_req) do + local val_res, val_error = validate_dns(req.n) + if val_res then + lua_util.debugm(N, task, "rbl %s; resolve %s -> %s", + rule.symbol, name, req.n) + r:resolve_a({ + task = task, + name = req.n, + callback = gen_rbl_dns_callback(req), + forced = req.forced + }) + else + rspamd_logger.warnx(task, 'cannot send invalid DNS request %s for %s: %s', + req.n, rule.symbol, val_error) + end + end + end + end + end + + for name, req in pairs(dns_req) do + local val_res, val_error = validate_dns(req.n) + if val_res then + lua_util.debugm(N, task, "rbl %s; resolve %s -> %s", + rule.symbol, name, req.n) + + if req.resolve_ip then + -- Deal with both ipv4 and ipv6 + -- Resolve names first + if r:resolve_a({ + task = task, + name = req.n, + callback = gen_rbl_ip_dns_callback(req), + forced = req.forced + }) then + nresolved = nresolved + 1 + end + if r:resolve('aaaa', { + task = task, + name = req.n, + callback = gen_rbl_ip_dns_callback(req), + forced = 
req.forced + }) then + nresolved = nresolved + 1 + end + else + r:resolve_a({ + task = task, + name = req.n, + callback = gen_rbl_dns_callback(req), + forced = req.forced + }) + end + + else + rspamd_logger.warnx(task, 'cannot send invalid DNS request %s for %s: %s', + req.n, rule.symbol, val_error) + end + end + end + + return callback_f, string.format('checks: %s', table.concat(description, ',')) +end + +local map_match_types = { + glob = true, + radix = true, + regexp = true, +} + +local function add_rbl(key, rbl, global_opts) + if not rbl.symbol then + rbl.symbol = key:upper() + end + + local flags_tbl = { 'no_squeeze' } + if rbl.is_whitelist then + flags_tbl[#flags_tbl + 1] = 'nice' + end + + -- Check if rbl is available for empty tasks + if not (rbl.emails or rbl.urls or rbl.dkim or rbl.received or rbl.selector or rbl.replyto) or + rbl.is_empty then + flags_tbl[#flags_tbl + 1] = 'empty' + end + + if rbl.selector then + + rbl.selectors = {} + if type(rbl.selector) ~= 'table' then + rbl.selector = { ['selector'] = rbl.selector } + end + + for selector_label, selector in pairs(rbl.selector) do + if known_selectors[selector] then + lua_util.debugm(N, rspamd_config, 'reuse selector id %s', + known_selectors[selector].id) + rbl.selectors[selector_label] = known_selectors[selector].selector + else + + if type(rbl.selector_flatten) ~= 'boolean' then + -- Fail-safety + rbl.selector_flatten = true + end + local sel = selectors.create_selector_closure(rspamd_config, selector, '', + rbl.selector_flatten) + + if not sel then + rspamd_logger.errx('invalid selector for rbl rule %s: %s', key, selector) + return false + end + + rbl.selector = sel + known_selectors[selector] = { + selector = sel, + id = #lua_util.keys(known_selectors) + 1, + } + rbl.selectors[selector_label] = known_selectors[selector].selector + end + end + + end + + if rbl.process_script then + local ret, f = lua_util.callback_from_string(rbl.process_script) + + if ret then + rbl.process_script = f + else + 
rspamd_logger.errx(rspamd_config, + 'invalid process script for rbl rule %s: %s; %s', + key, rbl.process_script, f) + return false + end + end + + if rbl.whitelist then + local def_type = 'set' + if rbl.from or rbl.received then + def_type = 'radix' + end + rbl.whitelist = lua_maps.map_add_from_ucl(rbl.whitelist, def_type, + 'RBL whitelist for ' .. rbl.symbol) + rspamd_logger.infox(rspamd_config, 'added %s whitelist for RBL %s', + def_type, rbl.symbol) + end + + local match_type = rbl.returncodes_matcher + if match_type and rbl.returncodes and map_match_types[match_type] then + if not rbl.returncodes_maps then + rbl.returncodes_maps = {} + end + for label, v in pairs(rbl.returncodes) do + if type(v) ~= 'table' then + v = {v} + end + rbl.returncodes_maps[label] = lua_maps.map_add_from_ucl(v, match_type, string.format('%s_%s RBL returncodes', label, rbl.symbol)) + end + end + + if rbl.url_compose_map then + local lua_urls_compose = require "lua_urls_compose" + rbl.url_compose_map = lua_urls_compose.add_composition_map(rspamd_config, rbl.url_compose_map) + + if rbl.url_compose_map then + rspamd_logger.infox(rspamd_config, 'added url composition map for RBL %s', + rbl.symbol) + end + end + + if not rbl.whitelist and not rbl.ignore_url_whitelist and (global_opts.url_whitelist or rbl.url_whitelist) and + (rbl.urls or rbl.emails or rbl.dkim or rbl.replyto) and + not (rbl.from or rbl.received) then + local def_type = 'set' + rbl.whitelist = lua_maps.map_add_from_ucl(rbl.url_whitelist or global_opts.url_whitelist, def_type, + 'RBL url whitelist for ' .. rbl.symbol) + rspamd_logger.infox(rspamd_config, 'added URL whitelist for RBL %s', + rbl.symbol) + end + + local callback, description = gen_rbl_callback(rbl) + + if callback then + local id + + if rbl.symbols_prefixes then + id = rspamd_config:register_symbol { + type = 'callback', + callback = callback, + groups = { 'rbl' }, + name = rbl.symbol .. 
'_CHECK', + flags = table.concat(flags_tbl, ',') + } + + for _, prefix in pairs(rbl.symbols_prefixes) do + -- For unknown results... + rspamd_config:register_symbol { + type = 'virtual', + parent = id, + group = 'rbl', + score = 0, + name = prefix .. '_' .. rbl.symbol, + } + end + if not (rbl.is_whitelist or rbl.ignore_whitelist) then + table.insert(black_symbols, rbl.symbol .. '_CHECK') + else + lua_util.debugm(N, rspamd_config, 'rule %s ignores whitelists: rbl.is_whitelist = %s, ' .. + 'rbl.ignore_whitelist = %s', + rbl.symbol, rbl.is_whitelist, rbl.ignore_whitelist) + end + else + id = rspamd_config:register_symbol { + type = 'callback', + callback = callback, + name = rbl.symbol, + groups = { 'rbl' }, + group = 'rbl', + score = 0, + flags = table.concat(flags_tbl, ',') + } + if not (rbl.is_whitelist or rbl.ignore_whitelist) then + table.insert(black_symbols, rbl.symbol) + else + lua_util.debugm(N, rspamd_config, 'rule %s ignores whitelists: rbl.is_whitelist = %s, ' .. + 'rbl.ignore_whitelist = %s', + rbl.symbol, rbl.is_whitelist, rbl.ignore_whitelist) + end + end + + rspamd_logger.infox(rspamd_config, 'added rbl rule %s: %s', + rbl.symbol, description) + lua_util.debugm(N, rspamd_config, 'rule dump for %s: %s', + rbl.symbol, rbl) + + local check_sym = rbl.symbols_prefixes and rbl.symbol .. '_CHECK' or rbl.symbol + + if rbl.dkim then + rspamd_config:register_dependency(check_sym, 'DKIM_CHECK') + end + + if rbl.require_symbols then + for _, dep in ipairs(rbl.require_symbols) do + rspamd_config:register_dependency(check_sym, dep) + end + end + + -- Failure symbol + rspamd_config:register_symbol { + type = 'virtual', + flags = 'nostat', + name = rbl.symbol .. 
'_FAIL', + parent = id, + score = 0.0, + } + + local function process_return_code(suffix) + local function process_specific_suffix(s) + if s ~= rbl.symbol then + -- hack + + rspamd_config:register_symbol { + type = 'virtual', + parent = id, + name = s, + group = 'rbl', + score = 0, + } + end + if rbl.is_whitelist then + if rbl.whitelist_exception then + local found_exception = false + for _, e in ipairs(rbl.whitelist_exception) do + if e == s then + found_exception = true + break + end + end + if not found_exception then + table.insert(white_symbols, s) + end + else + table.insert(white_symbols, s) + end + else + if not rbl.ignore_whitelist then + table.insert(black_symbols, s) + end + end + end + + if rbl.symbols_prefixes then + for _, prefix in pairs(rbl.symbols_prefixes) do + process_specific_suffix(prefix .. '_' .. suffix) + end + else + process_specific_suffix(suffix) + end + + end + + if rbl.returncodes then + for s, _ in pairs(rbl.returncodes) do + process_return_code(s) + end + end + + if rbl.returnbits then + for s, _ in pairs(rbl.returnbits) do + process_return_code(s) + end + end + + -- Process monitored + if not rbl.disable_monitoring then + if not monitored_addresses[rbl.rbl] then + monitored_addresses[rbl.rbl] = true + rbl.monitored = rspamd_config:register_monitored(rbl.rbl, 'dns', + get_monitored(rbl)) + end + end + return true + end + + return false +end + +-- Configuration +local opts = rspamd_config:get_all_opt(N) +if not (opts and type(opts) == 'table') then + rspamd_logger.infox(rspamd_config, 'Module is unconfigured') + lua_util.disable_module(N, "config") + return +end + +-- Plugin defaults should not be changed - override these in config +-- New defaults should not alter behaviour + + +opts = lua_util.override_defaults(rbl_common.default_options, opts) + +if opts.rules and opts.rbls then + -- Common issue :( + rspamd_logger.infox(rspamd_config, 'merging `rules` and `rbls` keys for compatibility') + opts.rbls = 
lua_util.override_defaults(opts.rbls, opts.rules) +end + +if (opts['local_exclude_ip_map'] ~= nil) then + local_exclusions = lua_maps.map_add(N, 'local_exclude_ip_map', 'radix', + 'RBL exclusions map') +end + +-- TODO: this code should be universal for all modules that use selectors to allow +-- maps usage from selectors registered for a specific module +if type(opts.attached_maps) == 'table' then + opts.attached_maps_processed = {} + for i, map in ipairs(opts.attached_maps) do + -- Store maps in the configuration table to keep lifetime track + opts.attached_maps_processed[i] = lua_maps.map_add_from_ucl(map) + if opts.attached_maps_processed[i] == nil then + rspamd_logger.warnx(rspamd_config, "cannot parse attached map: %s", map) + end + end +end + +for key, rbl in pairs(opts.rbls) do + if type(rbl) ~= 'table' or rbl.disabled == true or rbl.enabled == false then + rspamd_logger.infox(rspamd_config, 'disable rbl "%s"', key) + else + -- Aliases + if type(rbl.ignore_default) == 'boolean' then + rbl.ignore_defaults = rbl.ignore_default + end + if type(rbl.ignore_whitelists) == 'boolean' then + rbl.ignore_whitelist = rbl.ignore_whitelists + end + -- Propagate default options from opts to rule + if not rbl.ignore_defaults then + for default_opt_key, _ in pairs(rbl_common.default_options) do + local rbl_opt = default_opt_key:sub(#('default_') + 1) + if rbl[rbl_opt] == nil then + rbl[rbl_opt] = opts[default_opt_key] + end + end + end + + if not rbl.requests_limit then + rbl.requests_limit = rspamd_config:get_dns_max_requests() + end + + local res, err = rbl_common.rule_schema:transform(rbl) + if not res then + rspamd_logger.errx(rspamd_config, 'invalid config for %s: %s, RBL is DISABLED', + key, err) + else + res = rbl_common.convert_checks(res, rbl.symbol or key:upper()) + -- Aliases + if res.return_codes then + res.returncodes = res.return_codes + end + if res.return_bits then + res.returnbits = res.return_bits + end + + if not res then + 
rspamd_logger.errx(rspamd_config, 'invalid config for %s: %s, RBL is DISABLED', + key, err) + else + add_rbl(key, res, opts) + end + end + end -- rbl.enabled +end + +-- We now create two symbols: +-- * RBL_CALLBACK_WHITE that depends on all symbols white +-- * RBL_CALLBACK that depends on all symbols black to participate in depends chains +local function rbl_callback_white(task) + local whitelisted_elements = {} + for _, w in ipairs(white_symbols) do + local ws = task:get_symbol(w) + if ws and ws[1] then + ws = ws[1] + if not ws.options then + ws.options = {} + end + for _, opt in ipairs(ws.options) do + local elt, what = opt:match('^([^:]+):([^:]+)') + lua_util.debugm(N, task, 'found whitelist from %s: %s(%s)', w, + elt, what) + if elt and what then + whitelisted_elements[elt] = { + type = what, + symbol = w, + } + end + end + end + end + + task:cache_set('rbl_whitelisted', whitelisted_elements) + + lua_util.debugm(N, task, "finished rbl whitelists processing") +end + +local function rbl_callback_fin(task) + -- Do nothing + lua_util.debugm(N, task, "finished rbl processing") +end + +rspamd_config:register_symbol { + type = 'callback', + callback = rbl_callback_white, + name = 'RBL_CALLBACK_WHITE', + flags = 'nice,empty,no_squeeze', + groups = { 'rbl' }, + augmentations = { string.format("timeout=%f", rspamd_config:get_dns_timeout() or 0.0) }, +} + +rspamd_config:register_symbol { + type = 'callback', + callback = rbl_callback_fin, + name = 'RBL_CALLBACK', + flags = 'empty,no_squeeze', + groups = { 'rbl' }, + augmentations = { string.format("timeout=%f", rspamd_config:get_dns_timeout() or 0.0) }, +} + +for _, w in ipairs(white_symbols) do + rspamd_config:register_dependency('RBL_CALLBACK_WHITE', w) +end + +for _, b in ipairs(black_symbols) do + rspamd_config:register_dependency(b, 'RBL_CALLBACK_WHITE') + rspamd_config:register_dependency('RBL_CALLBACK', b) +end diff --git a/src/plugins/lua/replies.lua b/src/plugins/lua/replies.lua new file mode 100644 index 
0000000..c4df9c9 --- /dev/null +++ b/src/plugins/lua/replies.lua @@ -0,0 +1,328 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2016, Andrew Lewis <nerf@judo.za.org> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +if confighelp then + return +end + +local rspamd_logger = require 'rspamd_logger' +local hash = require 'rspamd_cryptobox_hash' +local lua_util = require 'lua_util' +local lua_redis = require 'lua_redis' +local fun = require "fun" + +-- A plugin that implements replies check using redis + +-- Default port for redis upstreams +local redis_params +local settings = { + action = nil, + expire = 86400, -- 1 day by default + key_prefix = 'rr', + key_size = 20, + message = 'Message is reply to one we originated', + symbol = 'REPLY', + score = -4, -- Default score + use_auth = true, + use_local = true, + cookie = nil, + cookie_key = nil, + cookie_is_pattern = false, + cookie_valid_time = '2w', -- 2 weeks by default + min_message_id = 2, -- minimum length of the message-id header +} + +local N = "replies" + +local function make_key(goop, sz, prefix) + local h = hash.create() + h:update(goop) + local key + if sz then + key = h:base32():sub(1, sz) + else + key = h:base32() + end + + if prefix then + key = prefix .. 
key + end + + return key +end + +local function replies_check(task) + local in_reply_to + local function check_recipient(stored_rcpt) + local rcpts = task:get_recipients('mime') + + if rcpts then + local filter_predicate = function(input_rcpt) + local real_rcpt_h = make_key(input_rcpt:lower(), 8) + + return real_rcpt_h == stored_rcpt + end + + if fun.any(filter_predicate, fun.map(function(rcpt) + return rcpt.addr or '' + end, rcpts)) then + lua_util.debugm(N, task, 'reply to %s validated', in_reply_to) + return true + end + + rspamd_logger.infox(task, 'ignoring reply to %s as no recipients are matching hash %s', + in_reply_to, stored_rcpt) + else + rspamd_logger.infox(task, 'ignoring reply to %s as recipient cannot be detected for hash %s', + in_reply_to, stored_rcpt) + end + + return false + end + + local function redis_get_cb(err, data, addr) + if err ~= nil then + rspamd_logger.errx(task, 'redis_get_cb error when reading data from %s: %s', addr:get_addr(), err) + return + end + if data and type(data) == 'string' and check_recipient(data) then + -- Hash was found + task:insert_result(settings['symbol'], 1.0) + if settings['action'] ~= nil then + local ip_addr = task:get_ip() + if (settings.use_auth and + task:get_user()) or + (settings.use_local and ip_addr and ip_addr:is_local()) then + rspamd_logger.infox(task, "not forcing action for local network or authorized user"); + else + task:set_pre_result(settings['action'], settings['message'], N) + end + end + end + end + -- If in-reply-to header not present return + in_reply_to = task:get_header_raw('in-reply-to') + if not in_reply_to then + return + end + -- Create hash of in-reply-to and query redis + local key = make_key(in_reply_to, settings.key_size, settings.key_prefix) + + local ret = lua_redis.redis_make_request(task, + redis_params, -- connect params + key, -- hash key + false, -- is write + redis_get_cb, --callback + 'GET', -- command + { key } -- arguments + ) + + if not ret then + 
rspamd_logger.errx(task, "redis request wasn't scheduled") + end +end + +local function replies_set(task) + local function redis_set_cb(err, _, addr) + if err ~= nil then + rspamd_logger.errx(task, 'redis_set_cb error when writing data to %s: %s', addr:get_addr(), err) + end + end + -- If sender is unauthenticated return + local ip = task:get_ip() + if settings.use_auth and task:get_user() then + lua_util.debugm(N, task, 'sender is authenticated') + elseif settings.use_local and (ip and ip:is_local()) then + lua_util.debugm(N, task, 'sender is from local network') + else + return + end + -- If no message-id present return + local msg_id = task:get_header_raw('message-id') + if msg_id == nil or msg_id:len() <= (settings.min_message_id or 2) then + return + end + -- Create hash of message-id and store to redis + local key = make_key(msg_id, settings.key_size, settings.key_prefix) + + local sender = task:get_reply_sender() + + if sender then + local sender_hash = make_key(sender:lower(), 8) + lua_util.debugm(N, task, 'storing id: %s (%s), reply-to: %s (%s) for replies check', + msg_id, key, sender, sender_hash) + local ret = lua_redis.redis_make_request(task, + redis_params, -- connect params + key, -- hash key + true, -- is write + redis_set_cb, --callback + 'PSETEX', -- command + { key, tostring(math.floor(settings['expire'] * 1000)), sender_hash } -- arguments + ) + if not ret then + rspamd_logger.errx(task, "redis request wasn't scheduled") + end + else + rspamd_logger.infox(task, "cannot find reply sender address") + end +end + +local function replies_check_cookie(task) + local function cookie_matched(extra, ts) + local dt = task:get_date { format = 'connect', gmt = true } + + if dt < ts then + rspamd_logger.infox(task, 'ignore cookie as its date is in future') + + return + end + + if settings.cookie_valid_time then + if dt - ts > settings.cookie_valid_time then + rspamd_logger.infox(task, + 'ignore cookie as its timestamp is too old: %s (%s current time)', + ts, 
dt) + + return + end + end + + if extra then + task:insert_result(settings['symbol'], 1.0, + string.format('cookie:%s:%s', extra, ts)) + else + task:insert_result(settings['symbol'], 1.0, + string.format('cookie:%s', ts)) + end + if settings['action'] ~= nil then + local ip_addr = task:get_ip() + if (settings.use_auth and + task:get_user()) or + (settings.use_local and ip_addr and ip_addr:is_local()) then + rspamd_logger.infox(task, "not forcing action for local network or authorized user"); + else + task:set_pre_result(settings['action'], settings['message'], N) + end + end + end + + -- If in-reply-to header not present return + local irt = task:get_header('in-reply-to') + if irt == nil then + return + end + + local cr = require "rspamd_cryptobox" + -- Extract user part if needed + local extracted_cookie = irt:match('^%<?([^@]+)@.*$') + if not extracted_cookie then + -- Assume full message id as a cookie + extracted_cookie = irt + end + + local dec_cookie, ts = cr.decrypt_cookie(settings.cookie_key, extracted_cookie) + + if dec_cookie then + -- We have something that looks like a cookie + if settings.cookie_is_pattern then + local m = dec_cookie:match(settings.cookie) + + if m then + cookie_matched(m, ts) + end + else + -- Direct match + if dec_cookie == settings.cookie then + cookie_matched(nil, ts) + end + end + end +end + +local opts = rspamd_config:get_all_opt('replies') +if not (opts and type(opts) == 'table') then + rspamd_logger.infox(rspamd_config, 'module is unconfigured') + return +end +if opts then + settings = lua_util.override_defaults(settings, opts) + redis_params = lua_redis.parse_redis_server('replies') + if not redis_params then + if not (settings.cookie and settings.cookie_key) then + rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module') + lua_util.disable_module(N, "redis") + else + -- Cookies mode + -- Check key sanity: + local pattern = { '^' } + for i = 1, 32 do + pattern[i + 1] = '[a-zA-Z0-9]' + end + pattern[34] 
= '$' + if not settings.cookie_key:match(table.concat(pattern, '')) then + rspamd_logger.errx(rspamd_config, + 'invalid cookies key: %s, must be 32 hex digits', settings.cookie_key) + lua_util.disable_module(N, "config") + + return + end + + if settings.cookie_valid_time then + settings.cookie_valid_time = lua_util.parse_time_interval(settings.cookie_valid_time) + end + + local id = rspamd_config:register_symbol({ + name = 'REPLIES_CHECK', + type = 'prefilter', + callback = replies_check_cookie, + flags = 'nostat', + priority = lua_util.symbols_priorities.medium, + group = "replies" + }) + rspamd_config:register_symbol({ + name = settings['symbol'], + parent = id, + type = 'virtual', + score = settings.score, + group = "replies", + }) + end + else + rspamd_config:register_symbol({ + name = 'REPLIES_SET', + type = 'idempotent', + callback = replies_set, + group = 'replies', + flags = 'explicit_disable,ignore_passthrough', + }) + local id = rspamd_config:register_symbol({ + name = 'REPLIES_CHECK', + type = 'prefilter', + flags = 'nostat', + callback = replies_check, + priority = lua_util.symbols_priorities.medium, + group = "replies" + }) + rspamd_config:register_symbol({ + name = settings['symbol'], + parent = id, + type = 'virtual', + score = settings.score, + group = "replies", + }) + end +end diff --git a/src/plugins/lua/reputation.lua b/src/plugins/lua/reputation.lua new file mode 100644 index 0000000..a3af26c --- /dev/null +++ b/src/plugins/lua/reputation.lua @@ -0,0 +1,1390 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +if confighelp then + return +end + +-- A generic plugin for reputation handling + +local E = {} +local N = 'reputation' + +local rspamd_logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local lua_util = require "lua_util" +local lua_maps = require "lua_maps" +local lua_maps_exprs = require "lua_maps_expressions" +local hash = require 'rspamd_cryptobox_hash' +local lua_redis = require "lua_redis" +local fun = require "fun" +local lua_selectors = require "lua_selectors" +local ts = require("tableshape").types + +local redis_params = nil +local default_expiry = 864000 -- 10 day by default +local default_prefix = 'RR:' -- Rspamd Reputation + +local tanh = math.tanh or rspamd_util.tanh + +-- Get reputation from ham/spam/probable hits +local function generic_reputation_calc(token, rule, mult, task) + local cfg = rule.selector.config or E + local reject_threshold = task:get_metric_score()[2] or 10.0 + + if cfg.score_calc_func then + return cfg.score_calc_func(rule, token, mult) + end + + if tonumber(token[1]) < cfg.lower_bound then + lua_util.debugm(N, task, "not enough matches %s < %s for rule %s", + token[1], cfg.lower_bound, rule.symbol) + return 0 + end + + -- Get average score + local avg_score = fun.foldl(function(acc, v) + return acc + v + end, 0.0, fun.map(tonumber, token[2])) / #token[2] + + -- Apply function tanh(x / reject_score * atanh(0.95) - atanh(0.5)) + -- 1.83178 0.5493 + local score = tanh(avg_score / reject_threshold * 1.83178 - 0.5493) * mult + lua_util.debugm(N, task, "got generic average score %s (reject threshold=%s, 
mult=%s) -> %s for rule %s", + avg_score, reject_threshold, mult, score, rule.symbol) + return score +end + +local function add_symbol_score(task, rule, mult, params) + if not params then + params = { tostring(mult) } + end + + if rule.selector.config.split_symbols then + local sym_spam = rule.symbol .. '_SPAM' + local sym_ham = rule.symbol .. '_HAM' + if not rule.static_symbols then + rule.static_symbols = {} + rule.static_symbols.ham = rspamd_config:get_symbol(sym_ham) + rule.static_symbols.spam = rspamd_config:get_symbol(sym_spam) + end + if mult >= 0 then + task:insert_result(sym_spam, mult, params) + else + -- Avoid multiplication of negative the `mult` by negative static score of the + -- ham symbol + if rule.static_symbols.ham and rule.static_symbols.ham.score then + if rule.static_symbols.ham.score < 0 then + mult = math.abs(mult) + end + end + task:insert_result(sym_ham, mult, params) + end + else + task:insert_result(rule.symbol, mult, params) + end +end + +local function sub_symbol_score(task, rule, score) + local function sym_score(sym) + local s = task:get_symbol(sym)[1] + return s.score + end + if rule.selector.config.split_symbols then + local spam_sym = rule.symbol .. '_SPAM' + local ham_sym = rule.symbol .. 
'_HAM' + + if task:has_symbol(spam_sym) then + score = score - sym_score(spam_sym) + elseif task:has_symbol(ham_sym) then + score = score - sym_score(ham_sym) + end + else + if task:has_symbol(rule.symbol) then + score = score - sym_score(rule.symbol) + end + end + + return score +end + +-- Extracts task score and subtracts score of the rule itself +local function extract_task_score(task, rule) + local lua_verdict = require "lua_verdict" + local verdict, score = lua_verdict.get_specific_verdict(N, task) + + if not score or verdict == 'passthrough' then + return nil + end + + return sub_symbol_score(task, rule, score) +end + +-- DKIM Selector functions +local gr +local function gen_dkim_queries(task, rule) + local dkim_trace = (task:get_symbol('DKIM_TRACE') or E)[1] + local lpeg = require 'lpeg' + local ret = {} + + if not gr then + local semicolon = lpeg.P(':') + local domain = lpeg.C((1 - semicolon) ^ 1) + local res = lpeg.S '+-?~' + + local function res_to_label(ch) + if ch == '+' then + return 'a' + elseif ch == '-' then + return 'r' + end + + return 'u' + end + + gr = domain * semicolon * (lpeg.C(res ^ 1) / res_to_label) + end + + if dkim_trace and dkim_trace.options then + for _, opt in ipairs(dkim_trace.options) do + local dom, res = lpeg.match(gr, opt) + + if dom and res then + local tld = rspamd_util.get_tld(dom) + ret[tld] = res + end + end + end + + return ret +end + +local function dkim_reputation_filter(task, rule) + local requests = gen_dkim_queries(task, rule) + local results = {} + local dkim_tlds = lua_util.keys(requests) + local requests_left = #dkim_tlds + local rep_accepted = 0.0 + + lua_util.debugm(N, task, 'dkim reputation tokens: %s', requests) + + local function tokens_cb(err, token, values) + requests_left = requests_left - 1 + + if values then + results[token] = values + end + + if requests_left == 0 then + for k, v in pairs(results) do + -- `k` in results is a prefixed and suffixed tld, so we need to look through + -- all requests to find 
any request with the matching tld + local sel_tld + for _, tld in ipairs(dkim_tlds) do + if k:find(tld, 1, true) then + sel_tld = tld + break + end + end + + if sel_tld and requests[sel_tld] then + if requests[sel_tld] == 'a' then + rep_accepted = rep_accepted + generic_reputation_calc(v, rule, 1.0, task) + end + else + rspamd_logger.warnx(task, "cannot find the requested tld for a request: %s (%s tlds noticed)", + k, dkim_tlds) + end + end + + -- Set local reputation symbol + local rep_accepted_abs = math.abs(rep_accepted or 0) + lua_util.debugm(N, task, "dkim reputation accepted: %s", + rep_accepted_abs) + if rep_accepted_abs then + local final_rep = rep_accepted + if rep_accepted > 1.0 then + final_rep = 1.0 + end + if rep_accepted < -1.0 then + final_rep = -1.0 + end + add_symbol_score(task, rule, final_rep) + + -- Store results for future DKIM results adjustments + task:get_mempool():set_variable("dkim_reputation_accept", tostring(rep_accepted)) + end + end + end + + for dom, res in pairs(requests) do + -- tld + "." + check_result, e.g. example.com.+ - reputation for valid sigs + local query = string.format('%s.%s', dom, res) + rule.backend.get_token(task, rule, nil, query, tokens_cb, 'string') + end +end + +local function dkim_reputation_idempotent(task, rule) + local requests = gen_dkim_queries(task, rule) + local sc = extract_task_score(task, rule) + + if sc then + for dom, res in pairs(requests) do + -- tld + "." + check_result, e.g. 
example.com.+ - reputation for valid sigs + local query = string.format('%s.%s', dom, res) + rule.backend.set_token(task, rule, nil, query, sc) + end + end +end + +local function dkim_reputation_postfilter(task, rule) + local sym_accepted = (task:get_symbol('R_DKIM_ALLOW') or E)[1] + local accept_adjustment = task:get_mempool():get_variable("dkim_reputation_accept") + local cfg = rule.selector.config or E + + if sym_accepted and sym_accepted.score and + accept_adjustment and type(cfg.max_accept_adjustment) == 'number' then + local final_adjustment = cfg.max_accept_adjustment * + rspamd_util.tanh(tonumber(accept_adjustment) or 0) + lua_util.debugm(N, task, "adjust DKIM_ALLOW: " .. + "cfg.max_accept_adjustment=%s accept_adjustment=%s final_adjustment=%s sym_accepted.score=%s", + cfg.max_accept_adjustment, accept_adjustment, final_adjustment, + sym_accepted.score) + + task:adjust_result('R_DKIM_ALLOW', sym_accepted.score + final_adjustment) + end +end + +local dkim_selector = { + config = { + symbol = 'DKIM_SCORE', -- symbol to be inserted + lower_bound = 10, -- minimum number of messages to be scored + min_score = nil, + max_score = nil, + outbound = true, + inbound = true, + max_accept_adjustment = 2.0, -- How to adjust accepted DKIM score + }, + dependencies = { "DKIM_TRACE" }, + filter = dkim_reputation_filter, -- used to get scores + postfilter = dkim_reputation_postfilter, -- used to adjust DKIM scores + idempotent = dkim_reputation_idempotent, -- used to set scores +} + +-- URL Selector functions + +local function gen_url_queries(task, rule) + local domains = {} + + fun.each(function(u) + if u:is_redirected() then + local redir = u:get_redirected() -- get the original url + local redir_tld = redir:get_tld() + if domains[redir_tld] then + domains[redir_tld] = domains[redir_tld] - 1 + end + end + local dom = u:get_tld() + if not domains[dom] then + domains[dom] = 1 + else + domains[dom] = domains[dom] + 1 + end + end, fun.filter(function(u) + return not 
u:is_html_displayed() + end, + task:get_urls(true))) + + local results = {} + for k, v in lua_util.spairs(domains, + function(t, a, b) + return t[a] > t[b] + end, rule.selector.config.max_urls) do + if v > 0 then + table.insert(results, { k, v }) + end + end + + return results +end + +local function url_reputation_filter(task, rule) + local requests = gen_url_queries(task, rule) + local url_keys = lua_util.keys(requests) + local requests_left = #url_keys + local results = {} + + local function indexed_tokens_cb(err, index, values) + requests_left = requests_left - 1 + + if values then + results[index] = values + end + + if requests_left == 0 then + -- Check the url with maximum hits + local mhits = 0 + + for i, res in ipairs(results) do + local req = requests[i] + if req then + local hits = tonumber(res[1]) + if hits > mhits then + mhits = hits + end + else + rspamd_logger.warnx(task, "cannot find the requested response for a request: %s (%s requests noticed)", + i, #requests) + end + end + + if mhits > 0 then + local score = 0 + for i, res in pairs(results) do + local req = requests[i] + if req then + local url_score = generic_reputation_calc(res, rule, + req[2] / mhits, task) + lua_util.debugm(N, task, "score for url %s is %s, score=%s", req[1], url_score, score) + score = score + url_score + end + end + + if math.abs(score) > 1e-3 then + -- TODO: add description + add_symbol_score(task, rule, score) + end + end + end + end + + for i, req in ipairs(requests) do + local function tokens_cb(err, token, values) + indexed_tokens_cb(err, i, values) + end + + rule.backend.get_token(task, rule, nil, req[1], tokens_cb, 'string') + end +end + +local function url_reputation_idempotent(task, rule) + local requests = gen_url_queries(task, rule) + local sc = extract_task_score(task, rule) + + if sc then + for _, tld in ipairs(requests) do + rule.backend.set_token(task, rule, nil, tld[1], sc) + end + end +end + +local url_selector = { + config = { + symbol = 'URL_SCORE', -- 
symbol to be inserted + lower_bound = 10, -- minimum number of messages to be scored + min_score = nil, + max_score = nil, + max_urls = 10, + check_from = true, + outbound = true, + inbound = true, + }, + filter = url_reputation_filter, -- used to get scores + idempotent = url_reputation_idempotent -- used to set scores +} +-- IP Selector functions + +local function ip_reputation_init(rule) + local cfg = rule.selector.config + + if cfg.asn_cc_whitelist then + cfg.asn_cc_whitelist = lua_maps.map_add('reputation', + 'asn_cc_whitelist', + 'map', + 'IP score whitelisted ASNs/countries') + end + + return true +end + +local function ip_reputation_filter(task, rule) + + local ip = task:get_from_ip() + + if not ip or not ip:is_valid() then + return + end + if lua_util.is_rspamc_or_controller(task) then + return + end + + local cfg = rule.selector.config + + if ip:get_version() == 4 and cfg.ipv4_mask then + ip = ip:apply_mask(cfg.ipv4_mask) + elseif cfg.ipv6_mask then + ip = ip:apply_mask(cfg.ipv6_mask) + end + + local pool = task:get_mempool() + local asn = pool:get_variable("asn") + local country = pool:get_variable("country") + + if country and cfg.asn_cc_whitelist then + if cfg.asn_cc_whitelist:get_key(country) then + return + end + if asn and cfg.asn_cc_whitelist:get_key(asn) then + return + end + end + + -- These variables are used to define if we have some specific token + local has_asn = not asn + local has_country = not country + local has_ip = false + + local asn_stats, country_stats, ip_stats + + local function ipstats_check() + local score = 0.0 + local description_t = {} + + if asn_stats then + local asn_score = generic_reputation_calc(asn_stats, rule, cfg.scores.asn, task) + score = score + asn_score + table.insert(description_t, string.format('asn: %s(%.2f)', + asn, asn_score)) + end + if country_stats then + local country_score = generic_reputation_calc(country_stats, rule, + cfg.scores.country, task) + score = score + country_score + 
table.insert(description_t, string.format('country: %s(%.2f)', + country, country_score)) + end + if ip_stats then + local ip_score = generic_reputation_calc(ip_stats, rule, cfg.scores.ip, + task) + score = score + ip_score + table.insert(description_t, string.format('ip: %s(%.2f)', + tostring(ip), ip_score)) + end + + if math.abs(score) > 0.001 then + add_symbol_score(task, rule, score, table.concat(description_t, ', ')) + end + end + + local function gen_token_callback(what) + return function(err, _, values) + if not err and values then + if what == 'asn' then + has_asn = true + asn_stats = values + elseif what == 'country' then + has_country = true + country_stats = values + elseif what == 'ip' then + has_ip = true + ip_stats = values + end + else + if what == 'asn' then + has_asn = true + elseif what == 'country' then + has_country = true + elseif what == 'ip' then + has_ip = true + end + end + + if has_asn and has_country and has_ip then + -- Check reputation + ipstats_check() + end + end + end + + if asn then + rule.backend.get_token(task, rule, cfg.asn_prefix, asn, + gen_token_callback('asn'), 'string') + end + if country then + rule.backend.get_token(task, rule, cfg.country_prefix, country, + gen_token_callback('country'), 'string') + end + + rule.backend.get_token(task, rule, cfg.ip_prefix, ip, + gen_token_callback('ip'), 'ip') +end + +-- Used to set scores +local function ip_reputation_idempotent(task, rule) + if not rule.backend.set_token then + return + end -- Read only backend + local ip = task:get_from_ip() + local cfg = rule.selector.config + + if not ip or not ip:is_valid() then + return + end + + if lua_util.is_rspamc_or_controller(task) then + return + end + + if ip:get_version() == 4 and cfg.ipv4_mask then + ip = ip:apply_mask(cfg.ipv4_mask) + elseif cfg.ipv6_mask then + ip = ip:apply_mask(cfg.ipv6_mask) + end + + local pool = task:get_mempool() + local asn = pool:get_variable("asn") + local country = pool:get_variable("country") + + if country 
and cfg.asn_cc_whitelist then + if cfg.asn_cc_whitelist:get_key(country) then + return + end + if asn and cfg.asn_cc_whitelist:get_key(asn) then + return + end + end + local sc = extract_task_score(task, rule) + if sc then + if asn then + rule.backend.set_token(task, rule, cfg.asn_prefix, asn, sc, nil, 'string') + end + if country then + rule.backend.set_token(task, rule, cfg.country_prefix, country, sc, nil, 'string') + end + + rule.backend.set_token(task, rule, cfg.ip_prefix, ip, sc, nil, 'ip') + end +end + +-- Selectors are used to extract reputation tokens +local ip_selector = { + config = { + scores = { -- how each component is evaluated + ['asn'] = 0.4, + ['country'] = 0.01, + ['ip'] = 1.0 + }, + symbol = 'SENDER_REP', -- symbol to be inserted + split_symbols = true, + asn_prefix = 'a:', -- prefix for ASN hashes + country_prefix = 'c:', -- prefix for country hashes + ip_prefix = 'i:', + lower_bound = 10, -- minimum number of messages to be scored + min_score = nil, + max_score = nil, + score_divisor = 1, + outbound = false, + inbound = true, + ipv4_mask = 32, -- Mask bits for ipv4 + ipv6_mask = 64, -- Mask bits for ipv6 + }, + --dependencies = {"ASN"}, -- ASN is a prefilter now... 
+ init = ip_reputation_init, + filter = ip_reputation_filter, -- used to get scores + idempotent = ip_reputation_idempotent, -- used to set scores +} + +-- SPF Selector functions + +local function spf_reputation_filter(task, rule) + local spf_record = task:get_mempool():get_variable('spf_record') + local spf_allow = task:has_symbol('R_SPF_ALLOW') + + -- Don't care about bad/missing spf + if not spf_record or not spf_allow then + return + end + + local cr = require "rspamd_cryptobox_hash" + local hkey = cr.create(spf_record):base32():sub(1, 32) + + lua_util.debugm(N, task, 'check spf record %s -> %s', spf_record, hkey) + + local function tokens_cb(err, token, values) + if values then + local score = generic_reputation_calc(values, rule, 1.0, task) + + if math.abs(score) > 1e-3 then + -- TODO: add description + add_symbol_score(task, rule, score) + end + end + end + + rule.backend.get_token(task, rule, nil, hkey, tokens_cb, 'string') +end + +local function spf_reputation_idempotent(task, rule) + local sc = extract_task_score(task, rule) + local spf_record = task:get_mempool():get_variable('spf_record') + local spf_allow = task:has_symbol('R_SPF_ALLOW') + + if not spf_record or not spf_allow or not sc then + return + end + + local cr = require "rspamd_cryptobox_hash" + local hkey = cr.create(spf_record):base32():sub(1, 32) + + lua_util.debugm(N, task, 'set spf record %s -> %s = %s', + spf_record, hkey, sc) + rule.backend.set_token(task, rule, nil, hkey, sc) +end + +local spf_selector = { + config = { + symbol = 'SPF_REP', -- symbol to be inserted + split_symbols = true, + lower_bound = 10, -- minimum number of messages to be scored + min_score = nil, + max_score = nil, + outbound = true, + inbound = true, + }, + dependencies = { "R_SPF_ALLOW" }, + filter = spf_reputation_filter, -- used to get scores + idempotent = spf_reputation_idempotent, -- used to set scores +} + +-- Generic selector based on lua_selectors framework + +local function generic_reputation_init(rule) 
+ local cfg = rule.selector.config + + if not cfg.selector then + rspamd_logger.errx(rspamd_config, 'cannot configure generic rule: no selector specified') + return false + end + + local selector = lua_selectors.create_selector_closure(rspamd_config, + cfg.selector, cfg.delimiter) + + if not selector then + rspamd_logger.errx(rspamd_config, 'cannot configure generic rule: bad selector: %s', + cfg.selector) + return false + end + + cfg.selector = selector -- Replace with closure + + if cfg.whitelist then + cfg.whitelist = lua_maps.map_add('reputation', + 'generic_whitelist', + 'map', + 'Whitelisted selectors') + end + + return true +end + +local function generic_reputation_filter(task, rule) + local cfg = rule.selector.config + local selector_res = cfg.selector(task) + + local function tokens_cb(err, token, values) + if values then + local score = generic_reputation_calc(values, rule, 1.0, task) + + if math.abs(score) > 1e-3 then + -- TODO: add description + add_symbol_score(task, rule, score) + end + end + end + + if selector_res then + if type(selector_res) == 'table' then + fun.each(function(e) + lua_util.debugm(N, task, 'check generic reputation (%s) %s', + rule['symbol'], e) + rule.backend.get_token(task, rule, nil, e, tokens_cb, 'string') + end, selector_res) + else + lua_util.debugm(N, task, 'check generic reputation (%s) %s', + rule['symbol'], selector_res) + rule.backend.get_token(task, rule, nil, selector_res, tokens_cb, 'string') + end + end +end + +local function generic_reputation_idempotent(task, rule) + local sc = extract_task_score(task, rule) + local cfg = rule.selector.config + + local selector_res = cfg.selector(task) + if not selector_res then + return + end + + if sc then + if type(selector_res) == 'table' then + fun.each(function(e) + lua_util.debugm(N, task, 'set generic selector (%s) %s = %s', + rule['symbol'], e, sc) + rule.backend.set_token(task, rule, nil, e, sc) + end, selector_res) + else + lua_util.debugm(N, task, 'set generic selector 
(%s) %s = %s', + rule['symbol'], selector_res, sc) + rule.backend.set_token(task, rule, nil, selector_res, sc) + end + end +end + +local generic_selector = { + schema = ts.shape { + lower_bound = ts.number + ts.string / tonumber, + max_score = ts.number:is_optional(), + min_score = ts.number:is_optional(), + outbound = ts.boolean, + inbound = ts.boolean, + selector = ts.string, + delimiter = ts.string, + whitelist = ts.one_of(lua_maps.map_schema, lua_maps_exprs.schema):is_optional(), + }, + config = { + lower_bound = 10, -- minimum number of messages to be scored + min_score = nil, + max_score = nil, + outbound = true, + inbound = true, + selector = nil, + delimiter = ':', + whitelist = nil + }, + init = generic_reputation_init, + filter = generic_reputation_filter, -- used to get scores + idempotent = generic_reputation_idempotent -- used to set scores +} + +local selectors = { + ip = ip_selector, + sender = ip_selector, -- Better name + url = url_selector, + dkim = dkim_selector, + spf = spf_selector, + generic = generic_selector +} + +local function reputation_dns_init(rule, _, _, _) + if not rule.backend.config.list then + rspamd_logger.errx(rspamd_config, "rule %s with DNS backend has no `list` parameter defined", + rule.symbol) + return false + end + + return true +end + +local function gen_token_key(prefix, token, rule) + if prefix then + token = prefix .. 
token + end + local res = token + if rule.backend.config.hashed then + local hash_alg = rule.backend.config.hash_alg or "blake2" + local encoding = "base32" + + if rule.backend.config.hash_encoding then + encoding = rule.backend.config.hash_encoding + end + + local h = hash.create_specific(hash_alg, res) + if encoding == 'hex' then + res = h:hex() + elseif encoding == 'base64' then + res = h:base64() + else + res = h:base32() + end + end + + if rule.backend.config.hashlen then + res = string.sub(res, 1, rule.backend.config.hashlen) + end + + if rule.backend.config.prefix then + res = rule.backend.config.prefix .. res + end + + return res +end + +--[[ +-- Generic interface for get and set tokens functions: +-- get_token(task, rule, prefix, token, continuation, token_type), where `continuation` is the following function: +-- +-- function(err, token, values) ... end +-- `err`: string value for error (similar to redis or DNS callbacks) +-- `token`: string value of a token +-- `values`: table of key=number, parsed from backend. It is selector's duty +-- to deal with missing, invalid or other values +-- +-- set_token(task, rule, token, values, continuation_cb) +-- This function takes values, encodes them using whatever suitable format +-- and calls for continuation: +-- +-- function(err, token) ... end +-- `err`: string value for error (similar to redis or DNS callbacks) +-- `token`: string value of a token +-- +-- example of tokens: {'s': 0, 'h': 0, 'p': 1} +--]] + +local function reputation_dns_get_token(task, rule, prefix, token, continuation_cb, token_type) + -- local r = task:get_resolver() + -- In DNS we never ever use prefix as prefix, we use if as a suffix! + if token_type == 'ip' then + token = table.concat(token:inversed_str_octets(), '.') + end + + local key = gen_token_key(nil, token, rule) + local dns_name = key .. '.' .. 
rule.backend.config.list + + if prefix then + dns_name = string.format('%s.%s.%s', key, prefix, + rule.backend.config.list) + else + dns_name = string.format('%s.%s', key, rule.backend.config.list) + end + + local function dns_cb(_, _, results, err) + if err and (err ~= 'requested record is not found' and + err ~= 'no records with this name') then + rspamd_logger.warnx(task, 'error looking up %s: %s', dns_name, err) + end + + lua_util.debugm(N, task, 'DNS RESPONSE: label=%1 results=%2 err=%3 list=%4', + dns_name, results, err, rule.backend.config.list) + + -- Now split tokens to list of values + if results and results[1] then + -- Format: num_messages;sc1;sc2...scn + local dns_tokens = lua_util.rspamd_str_split(results[1], ";") + -- Convert all to numbers excluding any possible non-numbers + dns_tokens = fun.totable(fun.filter(function(e) + return type(e) == 'number' + end, + fun.map(function(e) + local n = tonumber(e) + if n then + return n + end + return "BAD" + end, dns_tokens))) + + if #dns_tokens < 2 then + rspamd_logger.warnx(task, 'cannot parse response for reputation token %s: %s', + dns_name, results[1]) + continuation_cb(results, dns_name, nil) + else + local cnt = table.remove(dns_tokens, 1) + continuation_cb(nil, dns_name, { cnt, dns_tokens }) + end + else + rspamd_logger.messagex(task, 'invalid response for reputation token %s: %s', + dns_name, results[1]) + continuation_cb(results, dns_name, nil) + end + end + + task:get_resolver():resolve_a({ + task = task, + name = dns_name, + callback = dns_cb, + forced = true, + }) +end + +local function reputation_redis_init(rule, cfg, ev_base, worker) + local our_redis_params = {} + + our_redis_params = lua_redis.try_load_redis_servers(rule.backend.config, rspamd_config, + true) + if not our_redis_params then + our_redis_params = redis_params + end + if not our_redis_params then + rspamd_logger.errx(rspamd_config, 'cannot init redis for reputation rule: %s', + rule) + return false + end + -- Init scripts for 
buckets + -- Redis script to extract data from Redis buckets + -- KEYS[1] - key to extract + -- Value returned - table of scores as a strings vector + number of scores + local redis_get_script_tpl = [[ + local cnt = redis.call('HGET', KEYS[1], 'n') + local results = {} + if cnt then + {% for w in windows %} + local sc = tonumber(redis.call('HGET', KEYS[1], 'v' .. '{= w.name =}')) + table.insert(results, tostring(sc * {= w.mult =})) + {% endfor %} + else + {% for w in windows %} + table.insert(results, '0') + {% endfor %} + end + + return {cnt or 0, results} + ]] + + local get_script = lua_util.jinja_template(redis_get_script_tpl, + { windows = rule.backend.config.buckets }) + rspamd_logger.debugm(N, rspamd_config, 'added extraction script %s', get_script) + rule.backend.script_get = lua_redis.add_redis_script(get_script, our_redis_params) + + -- Redis script to update Redis buckets + -- KEYS[1] - key to update + -- KEYS[2] - current time in milliseconds + -- KEYS[3] - message score + -- KEYS[4] - expire for a bucket + -- Value returned - table of scores as a strings vector + local redis_adaptive_emea_script_tpl = [[ + local last = redis.call('HGET', KEYS[1], 'l') + local score = tonumber(KEYS[3]) + local now = tonumber(KEYS[2]) + local scores = {} + + if last then + {% for w in windows %} + local last_value = tonumber(redis.call('HGET', KEYS[1], 'v' .. '{= w.name =}')) + local window = {= w.time =} + -- Adjust alpha + local time_diff = now - last + if time_diff < 0 then + time_diff = 0 + end + local alpha = 1.0 - math.exp((-time_diff) / (1000 * window)) + local nscore = alpha * score + (1.0 - alpha) * last_value + table.insert(scores, tostring(nscore * {= w.mult =})) + {% endfor %} + else + {% for w in windows %} + table.insert(scores, tostring(score * {= w.mult =})) + {% endfor %} + end + + local i = 1 + {% for w in windows %} + redis.call('HSET', KEYS[1], 'v' .. 
--[[
-- Redis backend getter: runs the extraction script for the key derived from
-- `prefix` and `token`, then calls `continuation_cb(err, key, values)`.
-- `values` is whatever table the script returned; it is `nil` on any failure.
--]]
local function reputation_redis_get_token(task, rule, prefix, token, continuation_cb, token_type)
  if token_type == 'ip' then
    -- IP objects are stored by their string representation
    token = tostring(token)
  end

  local key = gen_token_key(prefix, token, rule)

  local function on_reply(err, data)
    if data then
      if type(data) ~= 'table' then
        rspamd_logger.errx(task, 'rule %s - invalid type while getting reputation keys %s: %s',
            rule['symbol'], key, type(data))
        continuation_cb("invalid type", key, nil)
        return
      end

      lua_util.debugm(N, task, 'rule %s - got values for key %s -> %s',
          rule['symbol'], key, data)
      continuation_cb(nil, key, data)
      return
    end

    -- No data at all: report the redis error or a generic one
    local reason = err or "unknown error"
    rspamd_logger.errx(task, 'rule %s - got error while getting reputation keys %s: %s',
        rule['symbol'], key, reason)
    continuation_cb(reason, key, nil)
  end

  local request_attrs = { task = task, is_write = false }
  local script_keys = { key }
  local scheduled = lua_redis.exec_redis_script(rule.backend.script_get,
      request_attrs, on_reply, script_keys)

  if not scheduled then
    rspamd_logger.errx(task, 'cannot make redis request to check results')
  end
end
redis_set_cb(err, data) + if err then + rspamd_logger.errx(task, 'rule %s - got error while setting reputation keys %s: %s', + rule['symbol'], key, err) + if continuation_cb then + continuation_cb(err, key) + end + else + if continuation_cb then + continuation_cb(nil, key) + end + end + end + + lua_util.debugm(N, task, 'rule %s - set values for key %s -> %s', + rule['symbol'], key, sc) + local ret = lua_redis.exec_redis_script(rule.backend.script_set, + { task = task, is_write = true }, + redis_set_cb, + { key, tostring(os.time() * 1000), + tostring(sc), + tostring(rule.backend.config.expiry) }) + if not ret then + rspamd_logger.errx(task, 'got error while connecting to redis') + end +end + +--[[ Backends are responsible for getting reputation tokens + -- Common config options: + -- `hashed`: if `true` then apply hash function to the key + -- `hash_alg`: use specific hash type (`blake2` by default) + -- `hash_len`: strip hash to this amount of bytes (no strip by default) + -- `hash_encoding`: use specific hash encoding (base32 by default) +--]] +local backends = { + redis = { + schema = lua_redis.enrich_schema({ + prefix = ts.string:is_optional(), + expiry = (ts.number + ts.string / lua_util.parse_time_interval):is_optional(), + buckets = ts.array_of(ts.shape { + time = ts.number + ts.string / lua_util.parse_time_interval, + name = ts.string, + mult = ts.number + ts.string / tonumber + }) :is_optional(), + }), + config = { + expiry = default_expiry, + prefix = default_prefix, + buckets = { + { + time = 60 * 60 * 24 * 30, + name = '1m', + mult = 1.0, + } + }, -- What buckets should be used, default 1h and 1month + }, + init = reputation_redis_init, + get_token = reputation_redis_get_token, + set_token = reputation_redis_set_token, + }, + dns = { + schema = ts.shape { + list = ts.string, + }, + config = { + -- list = rep.example.com + }, + get_token = reputation_dns_get_token, + -- No set token for DNS + init = reputation_dns_init, + } +} + +local function 
--[[
-- Builds a reputation rule from its configuration table and registers all
-- symbols that belong to it.
--
-- `name` - rule name from the configuration (default symbol name)
-- `tbl`  - rule configuration; `tbl.selector` and `tbl.backend` are tables of
--          the form `{ <type> = <config> }`. Both fields are removed from `tbl`
--          once parsed, the rest is treated as generic rule options.
--
-- Returns `true` when the rule has been registered and `false` on any
-- configuration error (the error is logged).
--]]
local function parse_rule(name, tbl)
  local sel_type, sel_conf = fun.head(tbl.selector)
  local selector = selectors[sel_type]

  if not selector then
    rspamd_logger.errx(rspamd_config, "unknown selector defined for rule %s: %s", name,
        sel_type)
    return false
  end

  if type(tbl.backend) ~= 'table' then
    -- fun.head would raise on a missing backend definition
    rspamd_logger.errx(rspamd_config, "unknown backend defined for rule %s: %s", name,
        tbl.backend)
    return false
  end

  local bk_type, bk_conf = fun.head(tbl.backend)

  local backend = backends[bk_type]
  if not backend then
    -- The backend type is the key of `tbl.backend`, not a `type` field
    rspamd_logger.errx(rspamd_config, "unknown backend defined for rule %s: %s", name,
        bk_type)
    return false
  end
  -- Allow config override
  local rule = {
    selector = lua_util.shallowcopy(selector),
    backend = lua_util.shallowcopy(backend),
    config = {}
  }

  -- Override default config params
  rule.backend.config = lua_util.override_defaults(rule.backend.config, bk_conf)
  if backend.schema then
    local checked, schema_err = backend.schema:transform(rule.backend.config)
    if not checked then
      rspamd_logger.errx(rspamd_config, "cannot parse backend config for %s: %s",
          bk_type, schema_err)

      return false
    end

    rule.backend.config = checked
  end

  rule.selector.config = lua_util.override_defaults(rule.selector.config, sel_conf)
  if selector.schema then
    local checked, schema_err = selector.schema:transform(rule.selector.config)

    if not checked then
      rspamd_logger.errx(rspamd_config, "cannot parse selector config for %s: %s (%s)",
          sel_type,
          schema_err, sel_conf)
      return false
    end

    rule.selector.config = checked
  end
  -- Generic options
  tbl.selector = nil
  tbl.backend = nil
  rule.config = lua_util.override_defaults(rule.config, tbl)

  if rule.config.whitelist then
    if lua_maps_exprs.schema(rule.config.whitelist) then
      rule.config.whitelist_map = lua_maps_exprs.create(rspamd_config,
          rule.config.whitelist, N)
    elseif lua_maps.map_schema(rule.config.whitelist) then
      local map = lua_maps.map_add_from_ucl(rule.config.whitelist,
          'radix',
          sel_type .. ' reputation whitelist')

      if not map then
        rspamd_logger.errx(rspamd_config, "cannot parse whitelist map config for %s: (%s)",
            sel_type,
            rule.config.whitelist)
        return false
      end

      -- Wrap the plain map so that it exposes the same `process(task)` method
      -- as an expression map does
      rule.config.whitelist_map = {
        process = function(_, task)
          -- Hack: we assume that it is an ip whitelist :(
          local ip = task:get_from_ip()

          if ip and map:get_key(ip) then
            return true
          end
          return false
        end
      }
    else
      rspamd_logger.errx(rspamd_config, "cannot parse whitelist map config for %s: (%s)",
          sel_type,
          rule.config.whitelist)
      return false
    end
  end

  local symbol = rule.selector.config.symbol or name
  if tbl.symbol then
    symbol = tbl.symbol
  end

  rule.symbol = symbol
  -- Rules that need initialisation stay disabled until on_load has succeeded
  rule.enabled = not (rule.selector.init or rule.backend.init)

  -- Perform additional initialization if needed
  rspamd_config:add_on_load(function(cfg, ev_base, worker)
    local initialized = true

    if rule.selector.init then
      if not rule.selector.init(rule, cfg, ev_base, worker) then
        initialized = false
        rspamd_logger.errx(rspamd_config, 'Cannot init selector %s (backend %s) for symbol %s',
            sel_type, bk_type, rule.symbol)
      end
    end
    if rule.backend.init then
      if not rule.backend.init(rule, cfg, ev_base, worker) then
        initialized = false
        rspamd_logger.errx(rspamd_config, 'Cannot init backend (%s) for rule %s for symbol %s',
            bk_type, sel_type, rule.symbol)
      end
    end

    -- A successful backend init must not re-enable a rule whose selector failed
    rule.enabled = initialized

    if rule.enabled then
      rspamd_logger.infox(rspamd_config, 'Enable %s (%s backend) rule for symbol %s (split symbols: %s)',
          sel_type, bk_type, rule.symbol,
          rule.selector.config.split_symbols)
    end
  end)

  -- `redis_params` is nil when no redis server is configured (e.g. DNS only)
  local timeout_augmentation = string.format("timeout=%f",
      (redis_params and redis_params.timeout) or 0.0)

  -- We now generate symbol for checking
  local rule_type = 'normal'
  if rule.selector.config.split_symbols then
    rule_type = 'callback'
  end

  local id = rspamd_config:register_symbol {
    name = rule.symbol,
    type = rule_type,
    callback = callback_gen(reputation_filter_cb, rule),
    augmentations = { timeout_augmentation },
  }

  if rule.selector.config.split_symbols then
    rspamd_config:register_symbol {
      name = rule.symbol .. '_HAM',
      type = 'virtual',
      parent = id,
    }
    rspamd_config:register_symbol {
      name = rule.symbol .. '_SPAM',
      type = 'virtual',
      parent = id,
    }
  end

  if rule.selector.dependencies then
    fun.each(function(d)
      rspamd_config:register_dependency(symbol, d)
    end, rule.selector.dependencies)
  end

  if rule.selector.postfilter then
    -- Also register a postfilter
    rspamd_config:register_symbol {
      name = rule.symbol .. '_POST',
      type = 'postfilter',
      flags = 'nostat,explicit_disable,ignore_passthrough',
      callback = callback_gen(reputation_postfilter_cb, rule),
      augmentations = { timeout_augmentation },
    }
  end

  if rule.selector.idempotent then
    -- Has also idempotent component (e.g. saving data to the backend)
    rspamd_config:register_symbol {
      name = rule.symbol .. '_IDEMPOTENT',
      type = 'idempotent',
      flags = 'explicit_disable,ignore_passthrough',
      callback = callback_gen(reputation_idempotent_cb, rule),
      augmentations = { timeout_augmentation },
    }
  end

  return true
end
-- Applies `symbol name -> score` pairs from an update with the given priority
local function process_symbols(obj, priority)
  fun.each(function(sym, score)
    rspamd_config:set_metric_symbol({
      name = sym,
      score = score,
      priority = priority
    })
  end, obj)
end

-- Applies `action name -> score` pairs from an update with the given priority
local function process_actions(obj, priority)
  fun.each(function(act, score)
    rspamd_config:set_metric_action({
      action = act,
      score = score,
      priority = priority
    })
  end, obj)
end

-- Compiles and runs every `key -> Lua source` pair from an update.
-- NOTE(review): this executes code received from a map; it must only be
-- reachable for trusted (signed) maps.
local function process_rules(obj)
  fun.each(function(key, code)
    local f, load_err = load(code)
    if f then
      f()
    else
      -- `rspamd_logger` is a module table and cannot be called directly
      rspamd_logger.errx(rspamd_config, 'cannot load rules for %s: %s', key, load_err)
    end
  end, obj)
end
-- Configuration part: register one callback map per configured source and
-- make sure that maps fetched over HTTP(S) are signed
local section = rspamd_config:get_all_opt("rspamd_update")
if section and section.rules then
  -- Optional key used for maps that do not define their own sign key
  local trusted_key = section.key

  -- A single source may be given as a scalar
  if type(section.rules) ~= 'table' then
    section.rules = { section.rules }
  end

  fun.each(function(elt)
    local map = rspamd_config:add_map(elt, "rspamd updates map", nil, "callback")
    if not map then
      rspamd_logger.errx(rspamd_config, 'cannot load updates from %1', elt)
    else
      map:set_callback(gen_callback())
      -- Index by the source itself so that every map is kept and checked below
      maps[elt] = map
    end
  end, section.rules)

  fun.each(function(k, map)
    -- Check sanity for maps
    local proto = map:get_proto()
    if (proto == 'http' or proto == 'https') and not map:get_sign_key() then
      if trusted_key then
        map:set_sign_key(trusted_key)
      else
        rspamd_logger.warnx(rspamd_config, 'Map %s is loaded by HTTP and it is not signed', k)
      end
    end
  end, maps)
else
  rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
  lua_util.disable_module(N, "config")
end
--[[
-- Applies a settings object to the task.
--
-- `to_apply` - settings table (may carry `add_headers`, `remove_headers`,
--              `flags`, `symbols`, `subject` and `messages`)
-- `id`       - optional settings id to store in the task
-- `name`     - name of the rule the settings come from (used for logging)
--
-- Only one settings object may be applied per task: returns `false` when some
-- settings have already been applied and `true` otherwise.
--]]
local function apply_settings(task, to_apply, id, name)
  local applied_by = task:cache_get('settings_name')
  if applied_by then
    local previous = task:cache_get('settings')
    rspamd_logger.warnx(task, "cannot apply settings rule %s (id=%s):" ..
        " settings has been already applied by rule %s (id=%s)",
        name, id, applied_by, previous.id)
    return false
  end

  task:set_settings(to_apply)
  -- Remember what has been applied to reject any further attempts
  task:cache_set('settings', to_apply)
  task:cache_set('settings_name', name or 'unknown')

  if id then
    task:set_settings_id(id)
  end

  local add_hdrs, remove_hdrs = to_apply['add_headers'], to_apply['remove_headers']
  if add_hdrs or remove_hdrs then
    task:set_rmilter_reply({
      add_headers = add_hdrs or {},
      remove_headers = remove_hdrs or {},
    })
  end

  local flags = to_apply.flags
  if flags and type(flags) == 'table' then
    for _, flag in ipairs(flags) do
      task:set_flag(flag)
    end
  end

  local symbols = to_apply.symbols
  if symbols then
    if #symbols > 0 then
      -- Array form: plain list of symbol names, each inserted with score 1.0
      for _, sym in ipairs(symbols) do
        task:insert_result(sym, 1.0)
      end
    else
      -- Object form: name -> score or name -> { score = ..., options = ... }
      for sym, spec in pairs(symbols) do
        if type(spec) == 'table' then
          task:insert_result(sym, spec.score or 1.0, spec.options or {})
        elseif tonumber(spec) then
          task:insert_result(sym, tonumber(spec))
        end
      end
    end
  end

  if to_apply.subject then
    task:set_metric_subject(to_apply.subject)
  end

  -- E.g.
  -- messages = { smtp_message = "5.3.1 Go away" }
  local messages = to_apply.messages
  if messages and type(messages) == 'table' then
    fun.each(function(category, message)
      task:append_message(message, category)
    end, messages)
  end

  return true
end
+ "settings id has priority", + tostring(settings_id), tostring(query_maxscore)) + end + -- We have score limits redefined by request + local ms = tonumber(tostring(query_maxscore)) + if ms then + nset = { + actions = { + reject = ms + } + } + + local query_softscore = task:get_request_header('softscore') + if query_softscore then + local ss = tonumber(tostring(query_softscore)) + nset.actions['add header'] = ss + end + + if not settings_id then + rspamd_logger.infox(task, 'apply maxscore = %s', nset.actions) + -- Maxscore is low priority + return nset, nil, 1 + end + end + end + + if settings_id and settings_initialized then + local cached = lua_settings.settings_by_id(settings_id) + lua_util.debugm(N, task, "check settings id for %s", settings_id) + + if cached then + local elt = cached.settings + if elt['whitelist'] then + elt['apply'] = { whitelist = true } + end + + if elt.apply then + if nset then + elt.apply = lua_util.override_defaults(nset, elt.apply) + end + end + return elt.apply, cached, cached.priority or 1 + else + rspamd_logger.warnx(task, 'no settings id "%s" has been found', settings_id) + if nset then + rspamd_logger.infox(task, 'apply maxscore = %s', nset.actions) + return nset, nil, 1 + end + end + else + if nset then + rspamd_logger.infox(task, 'apply maxscore = %s', nset.actions) + return nset, nil, 1 + end + end + + return false +end + +local function check_addr_setting(expected, addr) + local function check_specific_addr(elt) + if expected.name then + if lua_maps.rspamd_maybe_check_map(expected.name, elt.addr) then + return true + end + end + if expected.user then + if lua_maps.rspamd_maybe_check_map(expected.user, elt.user) then + return true + end + end + if expected.domain and elt.domain then + if lua_maps.rspamd_maybe_check_map(expected.domain, elt.domain) then + return true + end + end + if expected.regexp then + if expected.regexp:match(elt.addr) then + return true + end + end + return false + end + + for _, e in ipairs(addr) do + if 
-- Matches a string against a parsed string condition: a regexp takes
-- precedence, otherwise the `check` value is matched (possibly as a map)
local function check_string_setting(expected, str)
  if expected.regexp then
    return expected.regexp:match(str) and true or false
  end

  if expected.check then
    return lua_maps.rspamd_maybe_check_map(expected.check, str) and true or false
  end

  return false
end

-- Matches an IP object against a parsed `{ addr, mask }` condition.
-- Without a mask element `addr` is checked as a map or a plain string;
-- a zero mask means exact comparison; any other mask is applied to `ip` first.
local function check_ip_setting(expected, ip)
  local addr, mask = expected[1], expected[2]

  if not mask then
    return lua_maps.rspamd_maybe_check_map(addr, ip:to_string()) and true or false
  end

  if mask == 0 then
    return ip:to_string() == addr
  end

  local masked = ip:apply_mask(mask)
  if masked and masked:to_string() == addr then
    return true
  end

  return false
end

-- Looks up `input` in a map object and returns the lookup result as is
local function check_map_setting(map, input)
  return map:get_key(input)
end

-- Converts a numeric priority to its name: >= 3 is "high", >= 2 is "medium",
-- everything else (including a missing priority) is "low"
local function priority_to_string(pri)
  if not pri then
    return "low"
  end

  if pri >= 3 then
    return "high"
  end

  if pri >= 2 then
    return "medium"
  end

  return "low"
end
"applied settings id %s(%s); priority %s", + id_elt.name, id_elt.id, priority_to_string(priority)) + else + apply_settings(task, query_apply, nil, 'HTTP query') + rspamd_logger.infox(task, "applied settings from query; priority %s", + priority_to_string(priority)) + end + end + end + + local min_pri = 1 + if query_apply then + if priority >= min_pri then + -- Do not check lower or equal priorities + min_pri = priority + 1 + end + + if priority > max_pri then + -- Our internal priorities are lower then a priority from query, so no need to check + maybe_apply_query_settings() + + return + end + elseif id_elt and type(id_elt.settings) == 'table' and id_elt.settings.external_map then + local external_map = id_elt.settings.external_map + local selector_result = external_map.selector(task) + + if selector_result then + external_map.map:get_key(selector_result, nil, task) + -- No more selection logic + return + else + rspamd_logger.infox("cannot query selector to make external map request") + end + end + + -- Do not waste resources + if not settings_initialized then + maybe_apply_query_settings() + return + end + + -- Match rules according their order + local applied = false + + for pri = max_pri, min_pri, -1 do + if not applied and settings[pri] then + for _, s in ipairs(settings[pri]) do + local matched = {} + + local result = check_specific_setting(s.rule, matched) + lua_util.debugm(N, task, "check for settings element %s; result = %s", + s.name, result) + -- Can use xor here but more complicated for reading + if result then + if s.rule.apply then + if s.rule.id then + -- Extract static settings + local cached = lua_settings.settings_by_id(s.rule.id) + + if not cached or not cached.settings or not cached.settings.apply then + rspamd_logger.errx(task, 'unregistered settings id found: %s!', s.rule.id) + else + rspamd_logger.infox(task, "<%s> apply static settings %s (id = %s); %s matched; priority %s", + task:get_message_id(), + cached.name, s.rule.id, + 
--[[
-- Creates a callback for an external settings map called `name`.
-- The callback receives `(result, err_or_data, code, task)`: when `result` is
-- truthy, `err_or_data` is parsed as UCL text and applied to the task as
-- settings; otherwise `err_or_data` and `code` are only logged.
--]]
local function gen_settings_external_cb(name)
  return function(result, err_or_data, code, task)
    if result then
      local parser = ucl.parser()

      local res, ucl_err = parser:parse_text(err_or_data)
      if not res then
        rspamd_logger.warnx(task, 'cannot parse settings from the external map %s: %s',
            name, ucl_err)
      else
        local obj = parser:get_object()
        -- Argument order follows the format string: message id first, map name second
        rspamd_logger.infox(task, "<%s> apply settings according to the external map %s",
            task:get_message_id(), name)
        apply_settings(task, obj, nil, 'external_map')
      end
    else
      rspamd_logger.infox(task, "<%s> no settings returned from the external map %s: %s (code = %s)",
          task:get_message_id(), name, err_or_data, code)
    end
  end
end
-- Process email like condition, converted to a table with fields:
-- name - full email or a `map:` definition
-- user - user part (string without `@`)
-- domain - domain part (string in form `@domain`)
-- regexp - compiled regexp (string starting with `/`)
-- A table input is converted element by element into a list of such tables.
-- Returns nil for unsupported input types and for regexps that fail to compile.
local function process_email_condition(addr)
  local kind = type(addr)

  if kind == "table" then
    local list = {}
    for _, item in ipairs(addr) do
      table.insert(list, process_email_condition(item))
    end
    return list
  end

  if kind ~= "string" then
    return nil
  end

  local parsed = {}

  if string.sub(addr, 1, 4) == "map:" then
    -- It is map, don't apply any extra logic
    parsed['name'] = addr
    return parsed
  end

  local first = string.sub(addr, 1, 1)

  if first == '/' then
    -- It is a regexp
    local re = rspamd_regexp.create(addr)
    if not re then
      rspamd_logger.errx(rspamd_config, "bad regexp: " .. addr)
      return nil
    end
    parsed['regexp'] = re
  elseif first == '@' then
    -- It is a domain in form @domain
    parsed['domain'] = string.sub(addr, 2)
  elseif string.find(addr, '@') then
    -- It is full address
    parsed['name'] = addr
  else
    -- It is a user
    parsed['user'] = addr
  end

  return parsed
end
-- Used to create a checking closure: the returned function is true when the
-- given value matches `expected` according to `check_func(expected_elt, value)`
-- (plain equality when `check_func` is not given). If `expected` is a table,
-- a match with any of its elements is sufficient. A value that is a function
-- is called first and its result is checked instead.
local function gen_check_closure(expected, check_func)
  return function(value)
    if not value then
      return false
    end

    if type(value) == 'function' then
      value = value()
    end

    if not value then
      return false
    end

    -- Default comparator is installed lazily and kept for subsequent calls
    check_func = check_func or function(a, b)
      return a == b
    end

    local matched
    if type(expected) == 'table' then
      matched = fun.any(function(candidate)
        return check_func(candidate, value)
      end, expected)
    else
      matched = check_func(expected, value)
    end

    return matched and true or false
  end
end
name) + + if ips_map then + lua_util.debugm(N, rspamd_config, 'added ip_map condition to "%s"', + name) + checks.ip_map = { + check = gen_check_closure(ips_map, check_map_setting), + extract = function(task) + local ip = task:get_from_ip() + if ip and ip:is_valid() then + return ip + end + return nil + end, + } + end + end + + if elt.client_ip then + local client_ips_table = process_ip_condition(elt.client_ip) + + if client_ips_table then + lua_util.debugm(N, rspamd_config, 'added client_ip condition to "%s": %s', + name, client_ips_table) + checks.client_ip = { + check = gen_check_closure(convert_to_table(elt.client_ip, client_ips_table), + check_ip_setting), + extract = function(task) + local ip = task:get_client_ip() + if ip:is_valid() then + return ip + end + return nil + end, + } + end + end + if elt.client_ip_map then + local ips_map = lua_maps.map_add_from_ucl(elt.ip_map, 'radix', + 'settings client ip map for ' .. name) + + if ips_map then + lua_util.debugm(N, rspamd_config, 'added client ip_map condition to "%s"', + name) + checks.client_ip_map = { + check = gen_check_closure(ips_map, check_map_setting), + extract = function(task) + local ip = task:get_client_ip() + if ip and ip:is_valid() then + return ip + end + return nil + end, + } + end + end + + if elt.from then + local from_condition = process_email_condition(elt.from) + + if from_condition then + lua_util.debugm(N, rspamd_config, 'added from condition to "%s": %s', + name, from_condition) + checks.from = { + check = gen_check_closure(convert_to_table(elt.from, from_condition), + check_addr_setting), + extract = function(task) + return task:get_from(1) + end, + } + end + end + + if elt.rcpt then + local rcpt_condition = process_email_condition(elt.rcpt) + if rcpt_condition then + lua_util.debugm(N, rspamd_config, 'added rcpt condition to "%s": %s', + name, rcpt_condition) + checks.rcpt = { + check = gen_check_closure(convert_to_table(elt.rcpt, rcpt_condition), + check_addr_setting), + extract = 
function(task) + return task:get_recipients(1) + end, + } + end + end + + if elt.from_mime then + local from_mime_condition = process_email_condition(elt.from_mime) + + if from_mime_condition then + lua_util.debugm(N, rspamd_config, 'added from_mime condition to "%s": %s', + name, from_mime_condition) + checks.from_mime = { + check = gen_check_closure(convert_to_table(elt.from_mime, from_mime_condition), + check_addr_setting), + extract = function(task) + return task:get_from(2) + end, + } + end + end + + if elt.rcpt_mime then + local rcpt_mime_condition = process_email_condition(elt.rcpt_mime) + if rcpt_mime_condition then + lua_util.debugm(N, rspamd_config, 'added rcpt mime condition to "%s": %s', + name, rcpt_mime_condition) + checks.rcpt_mime = { + check = gen_check_closure(convert_to_table(elt.rcpt_mime, rcpt_mime_condition), + check_addr_setting), + extract = function(task) + return task:get_recipients(2) + end, + } + end + end + + if elt.user then + local user_condition = process_email_condition(elt.user) + if user_condition then + lua_util.debugm(N, rspamd_config, 'added user condition to "%s": %s', + name, user_condition) + checks.user = { + check = gen_check_closure(convert_to_table(elt.user, user_condition), + check_addr_setting), + extract = function(task) + local uname = task:get_user() + local user = {} + if uname then + user[1] = {} + local localpart, domainpart = string.gmatch(uname, "(.+)@(.+)")() + if localpart then + user[1]["user"] = localpart + user[1]["domain"] = domainpart + user[1]["addr"] = uname + else + user[1]["user"] = uname + user[1]["addr"] = uname + end + + return user + end + + return nil + end, + } + end + end + + if elt.hostname then + local hostname_condition = process_string_condition(elt.hostname) + if hostname_condition then + lua_util.debugm(N, rspamd_config, 'added hostname condition to "%s": %s', + name, hostname_condition) + checks.hostname = { + check = gen_check_closure(convert_to_table(elt.hostname, hostname_condition), 
+ check_string_setting), + extract = function(task) + return task:get_hostname() or '' + end, + } + end + end + + if elt.authenticated then + lua_util.debugm(N, rspamd_config, 'added authenticated condition to "%s"', + name) + checks.authenticated = { + check = function(value) + if value then + return true + end + return false + end, + extract = function(task) + return task:get_user() + end + } + end + + if elt['local'] then + lua_util.debugm(N, rspamd_config, 'added local condition to "%s"', + name) + checks['local'] = { + check = function(value) + if value then + return true + end + return false + end, + extract = function(task) + local ip = task:get_from_ip() + if not ip or not ip:is_valid() then + return nil + end + + if ip:is_local() then + return true + else + return nil + end + end + } + end + + local aliases = {} + -- This function is used to convert compound condition with + -- generic type and specific part (e.g. `header`, `Content-Transfer-Encoding`) + -- to a set of usable check elements: + -- `generic:specific` - most common part + -- `generic:<order>` - e.g. `header:1` for the first header + -- `generic:safe` - replace unsafe stuff with safe + lowercase + -- also aliases entry is set to avoid implicit expression + local function process_compound_condition(cond, generic, specific) + local full_key = generic .. ':' .. specific + checks[full_key] = cond + + -- Try numeric key + for i = 1, 1000 do + local num_key = generic .. ':' .. tostring(i) + if not checks[num_key] then + checks[num_key] = cond + aliases[num_key] = true + break + end + end + + local safe_key = generic .. ':' .. 
+ specific:gsub('[:%-+&|><]', '_') + :gsub('%(', '[') + :gsub('%)', ']') + :lower() + + if not checks[safe_key] then + checks[safe_key] = cond + aliases[full_key] = true + end + + return safe_key + end + -- Headers are tricky: + -- We create an closure with extraction function depending on header name + -- We also inserts it into `checks` table as an atom in form header:<hname> + -- Check function depends on the input: + -- * for something that looks like `header = "/bar/"` we create a regexp + -- * for something that looks like `header = true` we just check the existence + local function process_header_elt(table_element, extractor_func) + if elt[table_element] then + for k, v in pairs(elt[table_element]) do + if type(v) == 'string' then + local re = rspamd_regexp.create(v) + if re then + local cond = { + check = function(values) + return fun.any(function(c) + return re:match(c) + end, values) + end, + extract = extractor_func(k), + } + local skey = process_compound_condition(cond, table_element, + k) + lua_util.debugm(N, rspamd_config, 'added %s condition to "%s": %s =~ %s', + skey, name, k, v) + end + elseif type(v) == 'boolean' then + local cond = { + check = function(values) + if #values == 0 then + return (not v) + end + return v + end, + extract = extractor_func(k), + } + + local skey = process_compound_condition(cond, table_element, + k) + lua_util.debugm(N, rspamd_config, 'added %s condition to "%s": %s == %s', + skey, name, k, v) + else + rspamd_logger.errx(rspamd_config, 'invalid %s %s = %s', table_element, k, v) + end + end + end + end + + process_header_elt('request_header', function(hname) + return function(task) + local rh = task:get_request_header(hname) + if rh then + return { rh } + end + return {} + end + end) + process_header_elt('header', function(hname) + return function(task) + local rh = task:get_header_full(hname) + if rh then + return fun.totable(fun.map(function(h) + return h.decoded + end, rh)) + end + return {} + end + end) + + if 
elt.selector then + local sel = lua_selectors.create_selector_closure(rspamd_config, elt.selector, + elt.delimiter or "") + + if sel then + local cond = { + check = function(values) + return fun.any(function(c) + return c + end, values) + end, + extract = sel, + } + local skey = process_compound_condition(cond, 'selector', elt.selector) + lua_util.debugm(N, rspamd_config, 'added selector condition to "%s": %s', + name, skey) + end + + end + + -- Special, special case! + local inverse = false + if elt.inverse then + lua_util.debugm(N, rspamd_config, 'added inverse condition to "%s"', + name) + inverse = true + end + + -- Count checks and create Rspamd expression from a set of rules + local nchecks = 0 + for k, _ in pairs(checks) do + if not aliases[k] then + nchecks = nchecks + 1 + end + end + + if nchecks > 0 then + -- Now we can deal with the expression! + if not elt.expression then + -- Artificial & expression to deal with the legacy parts + -- Here we get all keys and concatenate them with '&&' + local s = ' && ' + -- By De Morgan laws + if inverse then + s = ' || ' + end + -- Exclude aliases and join all checks by key + local expr_str = table.concat(lua_util.keys(fun.filter( + function(k, _) + return not aliases[k] + end, + checks)), s) + + if inverse then + expr_str = string.format('!(%s)', expr_str) + end + + elt.expression = expr_str + lua_util.debugm(N, rspamd_config, 'added implicit settings expression for %s: %s', + name, expr_str) + end + + -- Parse expression's sanity + local function parse_atom(str) + local atom = table.concat(fun.totable(fun.take_while(function(c) + if string.find(', \t()><+!|&\n', c, 1, true) then + return false + end + return true + end, fun.iter(str))), '') + + if checks[atom] then + return atom + end + + rspamd_logger.errx(rspamd_config, + 'use of undefined element "%s" when parsing settings expression, known checks: %s', + atom, table.concat(fun.totable(fun.map(function(k, _) + return k + end, checks)), ',')) + + return nil + end 
+ + local rspamd_expression = require "rspamd_expression" + out.expression = rspamd_expression.create(elt.expression, parse_atom, + mempool) + out.checks = checks + + if not out.expression then + rspamd_logger.errx(rspamd_config, 'cannot parse expression %s for %s', + elt.expression, name) + else + lua_util.debugm(N, rspamd_config, 'registered settings %s with %s checks', + name, nchecks) + end + else + if not elt.disabled and elt.external_map then + lua_util.debugm(N, rspamd_config, 'registered settings %s with no checks, assume it as implicit', + name) + out.implicit = 1 + end + end + + -- Process symbols part/apply part + if elt['symbols'] then + lua_util.debugm(N, rspamd_config, 'added symbols condition to "%s": %s', + name, elt.symbols) + out['symbols'] = elt['symbols'] + end + + --[[ + external_map = { + map = { ... }; + selector = "..."; + } + --]] + if type(elt.external_map) == 'table' + and elt.external_map.map and elt.external_map.selector then + local maybe_external_map = {} + maybe_external_map.map = lua_maps.map_add_from_ucl(elt.external_map.map, "", + string.format("External map for settings element %s", name), + gen_settings_external_cb(name)) + maybe_external_map.selector = lua_selectors.create_selector_closure_fn(rspamd_config, + rspamd_config, elt.external_map.selector, ";", lua_selectors.kv_table_from_pairs) + + if maybe_external_map.map and maybe_external_map.selector then + rspamd_logger.infox(rspamd_config, "added external map for user's settings %s", name) + out.external_map = maybe_external_map + else + local incorrect_element + if not maybe_external_map.map then + incorrect_element = "map definition" + else + incorrect_element = "selector definition" + end + rspamd_logger.warnx(rspamd_config, "cannot add external map for user's settings; incorrect element: %s", + incorrect_element) + out.external_map = nil + end + end + + if not elt.external_map then + if elt['apply'] then + -- Just insert all metric results to the action key + out['apply'] 
= elt['apply'] + elseif elt['whitelist'] or elt['want_spam'] then + out['whitelist'] = true + else + rspamd_logger.errx(rspamd_config, "no actions in settings: " .. name) + return nil + end + end + + if allow_ids then + if not elt.id then + elt.id = name + end + + if elt['id'] then + -- We are here from a postload script + out.id = lua_settings.register_settings_id(elt.id, out, true) + lua_util.debugm(N, rspamd_config, + 'added settings id to "%s": %s -> %s', + name, elt.id, out.id) + end + + if not is_static then + -- If we apply that from map + -- In fact, it is useless and evil but who cares... + if elt.apply and elt.apply.symbols then + -- Register virtual symbols + for k, v in pairs(elt.apply.symbols) do + local rtb = { + type = 'virtual', + parent = module_sym_id, + } + if type(k) == 'number' and type(v) == 'string' then + rtb.name = v + elseif type(k) == 'string' then + rtb.name = k + end + if out.id then + rtb.allowed_ids = tostring(elt.id) + end + rspamd_config:register_symbol(rtb) + end + end + end + else + if elt['id'] then + rspamd_logger.errx(rspamd_config, + 'cannot set static IDs from dynamic settings, please read the docs') + end + end + + return out + end + + settings_initialized = false + -- filter trash in the input + local ft = fun.filter( + function(_, elt) + if type(elt) == "table" then + return true + end + return false + end, tbl) + + -- clear all settings + max_pri = 0 + local nrules = 0 + for k in pairs(settings) do + settings[k] = {} + end + -- fill new settings by priority + fun.for_each(function(k, v) + local pri = get_priority(v) + if pri > max_pri then + max_pri = pri + end + if not settings[pri] then + settings[pri] = {} + end + local s = process_setting_elt(k, v) + if s then + table.insert(settings[pri], { name = k, rule = s }) + nrules = nrules + 1 + end + end, ft) + -- sort settings with equal priorities in alphabetical order + for pri, _ in pairs(settings) do + table.sort(settings[pri], function(a, b) + return a.name < b.name + 
end) + end + + settings_initialized = true + lua_settings.load_all_settings(true) + rspamd_logger.infox(rspamd_config, 'loaded %s elements of settings', nrules) + + return true +end + +-- Parse settings map from the ucl line +local settings_map_pool + +local function process_settings_map(map_text) + local parser = ucl.parser() + local res, err = parser:parse_text(map_text) + + if not res then + rspamd_logger.warnx(rspamd_config, 'cannot parse settings map: ' .. err) + else + if settings_map_pool then + settings_map_pool:destroy() + end + + settings_map_pool = rspamd_mempool.create() + local obj = parser:get_object() + if obj['settings'] then + process_settings_table(obj['settings'], false, + settings_map_pool, false) + else + process_settings_table(obj, false, settings_map_pool, + false) + end + end + + return res +end + +local function gen_redis_callback(handler, id) + return function(task) + local key = handler(task) + + local function redis_settings_cb(err, data) + if not err and type(data) == 'table' then + for _, d in ipairs(data) do + if type(d) == 'string' then + local parser = ucl.parser() + local res, ucl_err = parser:parse_text(d) + if not res then + rspamd_logger.warnx(rspamd_config, 'cannot parse settings from redis: %s', + ucl_err) + else + local obj = parser:get_object() + rspamd_logger.infox(task, "<%1> apply settings according to redis rule %2", + task:get_message_id(), id) + apply_settings(task, obj, nil, 'redis') + break + end + end + end + elseif err then + rspamd_logger.errx(task, 'Redis error: %1', err) + end + end + + if not key then + lua_util.debugm(N, task, 'handler number %s returned nil', id) + return + end + + local keys + if type(key) == 'table' then + keys = key + else + keys = { key } + end + key = keys[1] + + local ret, _, _ = rspamd_redis_make_request(task, + redis_params, -- connect params + key, -- hash key + false, -- is write + redis_settings_cb, --callback + 'MGET', -- command + keys -- arguments + ) + if not ret then + 
rspamd_logger.errx(task, 'Redis MGET failed: %s', ret) + end + end +end + +local redis_section = rspamd_config:get_all_opt("settings_redis") +local redis_key_handlers = {} + +if redis_section then + redis_params = rspamd_parse_redis_server('settings_redis') + if redis_params then + local handlers = redis_section.handlers + + for id, h in pairs(handlers) do + local chunk, err = load(h) + + if not chunk then + rspamd_logger.errx(rspamd_config, 'Cannot load handler from string: %s', + tostring(err)) + else + local res, func = pcall(chunk) + if not res then + rspamd_logger.errx(rspamd_config, 'Cannot add handler from string: %s', + tostring(func)) + else + redis_key_handlers[id] = func + end + end + end + end + + fun.each(function(id, h) + rspamd_config:register_symbol({ + name = 'REDIS_SETTINGS' .. tostring(id), + type = 'prefilter', + callback = gen_redis_callback(h, id), + priority = lua_util.symbols_priorities.top, + flags = 'empty,nostat', + augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) }, + }) + end, redis_key_handlers) +end + +module_sym_id = rspamd_config:register_symbol({ + name = 'SETTINGS_CHECK', + type = 'prefilter', + callback = check_settings, + priority = lua_util.symbols_priorities.top, + flags = 'empty,nostat,explicit_disable,ignore_passthrough', +}) + +local set_section = rspamd_config:get_all_opt("settings") + +if set_section and set_section[1] and type(set_section[1]) == "string" then + -- Just a map of ucl + local map_attrs = { + url = set_section[1], + description = "settings map", + callback = process_settings_map, + opaque_data = true + } + if not rspamd_config:add_map(map_attrs) then + rspamd_logger.errx(rspamd_config, 'cannot load settings from %1', set_section) + end +elseif set_section and type(set_section) == "table" then + settings_map_pool = rspamd_mempool.create() + -- We need to check this table and register static symbols first + -- Postponed settings init is needed to ensure that all symbols have been + -- 
registered BEFORE settings plugin. Otherwise, we can have inconsistent settings expressions + fun.each(function(_, elt) + if elt.register_symbols then + for k, v in pairs(elt.register_symbols) do + local rtb = { + type = 'virtual', + parent = module_sym_id, + } + if type(k) == 'number' and type(v) == 'string' then + rtb.name = v + elseif type(k) == 'string' then + rtb.name = k + if type(v) == 'table' then + for kk, vv in pairs(v) do + -- Enrich table wih extra values + rtb[kk] = vv + end + end + end + rspamd_config:register_symbol(rtb) + end + end + if elt.apply and elt.apply.symbols then + -- Register virtual symbols + for k, v in pairs(elt.apply.symbols) do + local rtb = { + type = 'virtual', + parent = module_sym_id, + } + if type(k) == 'number' and type(v) == 'string' then + rtb.name = v + elseif type(k) == 'string' then + rtb.name = k + end + rspamd_config:register_symbol(rtb) + end + end + end, + -- Include only settings, exclude all maps + fun.filter( + function(_, elt) + if type(elt) == "table" then + return true + end + return false + end, set_section) + ) + + rspamd_config:add_post_init(function() + process_settings_table(set_section, true, settings_map_pool, true) + end, 100) +end + +rspamd_config:add_config_unload(function() + if settings_map_pool then + settings_map_pool:destroy() + end +end) diff --git a/src/plugins/lua/spamassassin.lua b/src/plugins/lua/spamassassin.lua new file mode 100644 index 0000000..3ea7944 --- /dev/null +++ b/src/plugins/lua/spamassassin.lua @@ -0,0 +1,1774 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +if confighelp then + return +end + +-- This plugin is intended to read and parse spamassassin rules with regexp +-- rules. SA plugins or statistics are not supported + +local E = {} +local N = 'spamassassin' + +local rspamd_logger = require "rspamd_logger" +local rspamd_regexp = require "rspamd_regexp" +local rspamd_expression = require "rspamd_expression" +local rspamd_trie = require "rspamd_trie" +local util = require "rspamd_util" +local lua_util = require "lua_util" +local fun = require "fun" + +-- Known plugins +local known_plugins = { + 'Mail::SpamAssassin::Plugin::FreeMail', + 'Mail::SpamAssassin::Plugin::HeaderEval', + 'Mail::SpamAssassin::Plugin::ReplaceTags', + 'Mail::SpamAssassin::Plugin::RelayEval', + 'Mail::SpamAssassin::Plugin::MIMEEval', + 'Mail::SpamAssassin::Plugin::BodyEval', + 'Mail::SpamAssassin::Plugin::MIMEHeader', + 'Mail::SpamAssassin::Plugin::WLBLEval', + 'Mail::SpamAssassin::Plugin::HTMLEval', +} + +-- Table that replaces SA symbol with rspamd equivalent +-- Used for dependency resolution +local symbols_replacements = { + -- SPF replacements + USER_IN_SPF_WHITELIST = 'WHITELIST_SPF', + USER_IN_DEF_SPF_WL = 'WHITELIST_SPF', + SPF_PASS = 'R_SPF_ALLOW', + SPF_FAIL = 'R_SPF_FAIL', + SPF_SOFTFAIL = 'R_SPF_SOFTFAIL', + SPF_HELO_PASS = 'R_SPF_ALLOW', + SPF_HELLO_FAIL = 'R_SPF_FAIL', + SPF_HELLO_SOFTFAIL = 'R_SPF_SOFTFAIL', + -- DKIM replacements + USER_IN_DKIM_WHITELIST = 'WHITELIST_DKIM', + USER_IN_DEF_DKIM_WL = 'WHITELIST_DKIM', + DKIM_VALID = 'R_DKIM_ALLOW', + -- SURBL replacements + URIBL_SBL_A = 'URIBL_SBL', + URIBL_DBL_SPAM = 'DBL_SPAM', 
+ URIBL_DBL_PHISH = 'DBL_PHISH', + URIBL_DBL_MALWARE = 'DBL_MALWARE', + URIBL_DBL_BOTNETCC = 'DBL_BOTNET', + URIBL_DBL_ABUSE_SPAM = 'DBL_ABUSE', + URIBL_DBL_ABUSE_REDIR = 'DBL_ABUSE_REDIR', + URIBL_DBL_ABUSE_MALW = 'DBL_ABUSE_MALWARE', + URIBL_DBL_ABUSE_BOTCC = 'DBL_ABUSE_BOTNET', + URIBL_WS_SURBL = 'WS_SURBL_MULTI', + URIBL_PH_SURBL = 'PH_SURBL_MULTI', + URIBL_MW_SURBL = 'MW_SURBL_MULTI', + URIBL_CR_SURBL = 'CRACKED_SURBL', + URIBL_ABUSE_SURBL = 'ABUSE_SURBL', + -- Misc rules + BODY_URI_ONLY = 'R_EMPTY_IMAGE', + HTML_IMAGE_ONLY_04 = 'HTML_SHORT_LINK_IMG_1', + HTML_IMAGE_ONLY_08 = 'HTML_SHORT_LINK_IMG_1', + HTML_IMAGE_ONLY_12 = 'HTML_SHORT_LINK_IMG_1', + HTML_IMAGE_ONLY_16 = 'HTML_SHORT_LINK_IMG_2', + HTML_IMAGE_ONLY_20 = 'HTML_SHORT_LINK_IMG_2', + HTML_IMAGE_ONLY_24 = 'HTML_SHORT_LINK_IMG_3', + HTML_IMAGE_ONLY_28 = 'HTML_SHORT_LINK_IMG_3', + HTML_IMAGE_ONLY_32 = 'HTML_SHORT_LINK_IMG_3', +} + +-- Internal variables +local rules = {} +local atoms = {} +local scores = {} +local scores_added = {} +local external_deps = {} +local freemail_domains = {} +local pcre_only_regexps = {} +local freemail_trie +local replace = { + tags = {}, + pre = {}, + inter = {}, + post = {}, + rules = {}, +} +local internal_regexp = { + date_shift = rspamd_regexp.create("^\\(\\s*'((?:-?\\d+)|(?:undef))'\\s*,\\s*'((?:-?\\d+)|(?:undef))'\\s*\\)$") +} + +-- Mail::SpamAssassin::Plugin::WLBLEval plugin +local sa_lists = { + from_blacklist = {}, + from_whitelist = {}, + from_def_whitelist = {}, + to_blacklist = {}, + to_whitelist = {}, + elts = 0, +} + +local func_cache = {} +local section = rspamd_config:get_all_opt("spamassassin") +if not (section and type(section) == 'table') then + rspamd_logger.infox(rspamd_config, 'Module is unconfigured') +end + +-- Minimum score to treat symbols as meta +local meta_score_alpha = 0.5 + +-- Maximum size of regexp checked +local match_limit = 0 + +-- Default priority of the scores registered in the metric +-- Historically this is set to 2 allowing SA scores 
to override Rspamd scores +local scores_priority = 2 + +local function split(str, delim) + local result = {} + + if not delim then + delim = '[^%s]+' + end + + for token in string.gmatch(str, delim) do + table.insert(result, token) + end + + return result +end + +local function replace_symbol(s) + local rspamd_symbol = symbols_replacements[s] + if not rspamd_symbol then + return s, false + end + return rspamd_symbol, true +end + +local ffi +if type(jit) == 'table' then + ffi = require("ffi") + ffi.cdef [[ + int rspamd_re_cache_type_from_string (const char *str); + int rspamd_re_cache_process_ffi (void *ptask, + void *pre, + int type, + const char *type_data, + int is_strong); +]] +end + +local function process_regexp_opt(re, task, re_type, header, strong) + --[[ + -- This is now broken with lua regexp conditions! + if type(jit) == 'table' then + -- Use ffi call + local itype = ffi.C.rspamd_re_cache_type_from_string(re_type) + + if not strong then + strong = 0 + else + strong = 1 + end + local iret = ffi.C.rspamd_re_cache_process_ffi (task, re, itype, header, strong) + + return tonumber(iret) + else + return task:process_regexp(re, re_type, header, strong) + end + --]] + return task:process_regexp(re, re_type, header, strong) +end + +local function is_pcre_only(name) + if pcre_only_regexps[name] then + rspamd_logger.infox(rspamd_config, 'mark re %s as PCRE only', name) + return true + end + return false +end + +local function handle_header_def(hline, cur_rule) + --Now check for modifiers inside header's name + local hdrs = split(hline, '[^|]+') + local hdr_params = {} + local cur_param = {} + -- Check if an re is an ordinary re + local ordinary = true + + for _, h in ipairs(hdrs) do + if h == 'ALL' or h == 'ALL:raw' then + ordinary = false + cur_rule['type'] = 'function' + -- Pack closure + local re = cur_rule['re'] + -- Rule to match all headers + rspamd_config:register_regexp({ + re = re, + type = 'allheader', + pcre_only = is_pcre_only(cur_rule['symbol']), + }) + 
cur_rule['function'] = function(task) + if not re then + rspamd_logger.errx(task, 're is missing for rule %1', h) + return 0 + end + + return process_regexp_opt(re, task, 'allheader') + end + else + local args = split(h, '[^:]+') + cur_param['strong'] = false + cur_param['raw'] = false + cur_param['header'] = args[1] + + if args[2] then + -- We have some ops that are required for the header, so it's not ordinary + ordinary = false + end + + fun.each(function(func) + if func == 'addr' then + cur_param['function'] = function(str) + local addr_parsed = util.parse_mail_address(str) + local ret = {} + if addr_parsed then + for _, elt in ipairs(addr_parsed) do + if elt['addr'] then + table.insert(ret, elt['addr']) + end + end + end + + return ret + end + elseif func == 'name' then + cur_param['function'] = function(str) + local addr_parsed = util.parse_mail_address(str) + local ret = {} + if addr_parsed then + for _, elt in ipairs(addr_parsed) do + if elt['name'] then + table.insert(ret, elt['name']) + end + end + end + + return ret + end + elseif func == 'raw' then + cur_param['raw'] = true + elseif func == 'case' then + cur_param['strong'] = true + else + rspamd_logger.warnx(rspamd_config, 'Function %1 is not supported in %2', + func, cur_rule['symbol']) + end + end, fun.tail(args)) + + local function split_hdr_param(param, headers) + for _, hh in ipairs(headers) do + local nparam = {} + for k, v in pairs(param) do + if k ~= 'header' then + nparam[k] = v + end + end + + nparam['header'] = hh + table.insert(hdr_params, nparam) + end + end + -- Some header rules require splitting to check of multiple headers + if cur_param['header'] == 'MESSAGEID' then + -- Special case for spamassassin + ordinary = false + split_hdr_param(cur_param, { + 'Message-ID', + 'X-Message-ID', + 'Resent-Message-ID' }) + elseif cur_param['header'] == 'ToCc' then + ordinary = false + split_hdr_param(cur_param, { 'To', 'Cc', 'Bcc' }) + else + table.insert(hdr_params, cur_param) + end + end + + 
cur_rule['ordinary'] = ordinary + cur_rule['header'] = hdr_params + end +end + +local function freemail_search(input) + local res = 0 + local function trie_callback(number, pos) + lua_util.debugm(N, rspamd_config, 'Matched pattern %1 at pos %2', freemail_domains[number], pos) + res = res + 1 + end + + if input then + freemail_trie:match(input, trie_callback, true) + end + + return res +end + +local function gen_eval_rule(arg) + local eval_funcs = { + { 'check_freemail_from', function(task) + local from = task:get_from('mime') + if from and from[1] then + return freemail_search(string.lower(from[1]['addr'])) + end + return 0 + end }, + { 'check_freemail_replyto', + function(task) + return freemail_search(task:get_header('Reply-To')) + end + }, + { 'check_freemail_header', + function(task, remain) + -- Remain here contains one or two args: header and regexp to match + local larg = string.match(remain, "^%(%s*['\"]([^%s]+)['\"]%s*%)$") + local re = nil + if not larg then + larg, re = string.match(remain, "^%(%s*['\"]([^%s]+)['\"]%s*,%s*['\"]([^%s]+)['\"]%s*%)$") + end + + if larg then + local h + if larg == 'EnvelopeFrom' then + h = task:get_from('smtp') + if h then + h = h[1]['addr'] + end + else + h = task:get_header(larg) + end + if h then + local hdr_freemail = freemail_search(string.lower(h)) + + if hdr_freemail > 0 and re then + local r = rspamd_regexp.create_cached(re) + if r then + if r:match(h) then + return 1 + end + return 0 + else + rspamd_logger.infox(rspamd_config, 'cannot create regexp %1', re) + return 0 + end + end + + return hdr_freemail + end + end + + return 0 + end + }, + { + 'check_for_missing_to_header', + function(task) + local th = task:get_recipients('mime') + if not th or #th == 0 then + return 1 + end + + return 0 + end + }, + { + 'check_relays_unparseable', + function(task) + local rh_mime = task:get_header_full('Received') + local rh_parsed = task:get_received_headers() + + local rh_cnt = 0 + if rh_mime then + rh_cnt = #rh_mime + end + 
local parsed_cnt = 0 + if rh_parsed then + parsed_cnt = #rh_parsed + end + + return rh_cnt - parsed_cnt + end + }, + { + 'check_for_shifted_date', + function(task, remain) + -- Remain here contains two args: start and end hours shift + local matches = internal_regexp['date_shift']:search(remain, true, true) + if matches and matches[1] then + local min_diff = matches[1][2] + local max_diff = matches[1][3] + + if min_diff == 'undef' then + min_diff = 0 + else + min_diff = tonumber(min_diff) * 3600 + end + if max_diff == 'undef' then + max_diff = 0 + else + max_diff = tonumber(max_diff) * 3600 + end + + -- Now get the difference between Date and message received date + local dm = task:get_date { format = 'message', gmt = true } + local dt = task:get_date { format = 'connect', gmt = true } + local diff = dm - dt + + if (max_diff == 0 and diff >= min_diff) or + (min_diff == 0 and diff <= max_diff) or + (diff >= min_diff and diff <= max_diff) then + return 1 + end + end + + return 0 + end + }, + { + 'check_for_mime', + function(task, remain) + local larg = string.match(remain, "^%(%s*['\"]([^%s]+)['\"]%s*%)$") + + if larg then + if larg == 'mime_attachment' then + local parts = task:get_parts() + if parts then + for _, p in ipairs(parts) do + if p:get_filename() then + return 1 + end + end + end + else + rspamd_logger.infox(task, 'unimplemented mime check %1', arg) + end + end + + return 0 + end + }, + { + 'check_from_in_blacklist', + function(task) + local from = task:get_from('mime') + if ((from or E)[1] or E).addr then + if sa_lists['from_blacklist'][string.lower(from[1]['addr'])] then + return 1 + end + end + + return 0 + end + }, + { + 'check_from_in_whitelist', + function(task) + local from = task:get_from('mime') + if ((from or E)[1] or E).addr then + if sa_lists['from_whitelist'][string.lower(from[1]['addr'])] then + return 1 + end + end + + return 0 + end + }, + { + 'check_from_in_default_whitelist', + function(task) + local from = task:get_from('mime') + if 
((from or E)[1] or E).addr then + if sa_lists['from_def_whitelist'][string.lower(from[1]['addr'])] then + return 1 + end + end + + return 0 + end + }, + { + 'check_to_in_blacklist', + function(task) + local rcpt = task:get_recipients('mime') + if rcpt then + for _, r in ipairs(rcpt) do + if sa_lists['to_blacklist'][string.lower(r['addr'])] then + return 1 + end + end + end + + return 0 + end + }, + { + 'check_to_in_whitelist', + function(task) + local rcpt = task:get_recipients('mime') + if rcpt then + for _, r in ipairs(rcpt) do + if sa_lists['to_whitelist'][string.lower(r['addr'])] then + return 1 + end + end + end + + return 0 + end + }, + { + 'html_tag_exists', + function(task, remain) + local tp = task:get_text_parts() + + for _, p in ipairs(tp) do + if p:is_html() then + local hc = p:get_html() + + if hc:has_tag(remain) then + return 1 + end + end + end + + return 0 + end + } + } + + for _, f in ipairs(eval_funcs) do + local pat = string.format('^%s', f[1]) + local first, last = string.find(arg, pat) + + if first then + local func_arg = string.sub(arg, last + 1) + return function(task) + return f[2](task, func_arg) + end + end + end +end + +-- Returns parser function or nil +local function maybe_parse_sa_function(line) + local arg + local elts = split(line, '[^:]+') + arg = elts[2] + + lua_util.debugm(N, rspamd_config, 'trying to parse SA function %1 with args %2', + elts[1], elts[2]) + local substitutions = { + { '^exists:', + function(task) + -- filter + local hdrs_check + if arg == 'MESSAGEID' then + hdrs_check = { + 'Message-ID', + 'X-Message-ID', + 'Resent-Message-ID' + } + elseif arg == 'ToCc' then + hdrs_check = { 'To', 'Cc', 'Bcc' } + else + hdrs_check = { arg } + end + + for _, h in ipairs(hdrs_check) do + if task:has_header(h) then + return 1 + end + end + return 0 + end, + }, + { '^eval:', + function(task) + local func = func_cache[arg] + if not func then + func = gen_eval_rule(arg) + func_cache[arg] = func + end + + if not func then + 
rspamd_logger.errx(task, 'cannot find appropriate eval rule for function %1', + arg) + else + return func(task) + end + + return 0 + end + }, + } + + for _, s in ipairs(substitutions) do + if string.find(line, s[1]) then + return s[2] + end + end + + return nil +end + +local function words_to_re(words, start) + return table.concat(fun.totable(fun.drop_n(start, words)), " "); +end + +local function process_tflags(rule, flags) + fun.each(function(flag) + if flag == 'publish' then + rule['publish'] = true + elseif flag == 'multiple' then + rule['multiple'] = true + elseif string.match(flag, '^maxhits=(%d+)$') then + rule['maxhits'] = tonumber(string.match(flag, '^maxhits=(%d+)$')) + elseif flag == 'nice' then + rule['nice'] = true + end + end, fun.drop_n(1, flags)) + + if rule['re'] then + if rule['maxhits'] then + rule['re']:set_max_hits(rule['maxhits']) + elseif rule['multiple'] then + rule['re']:set_max_hits(0) + else + rule['re']:set_max_hits(1) + end + end +end + +local function process_replace(words, tbl) + local re = words_to_re(words, 2) + tbl[words[2]] = re +end + +local function process_sa_conf(f) + local cur_rule = {} + local valid_rule = false + + local function insert_cur_rule() + if cur_rule['type'] ~= 'meta' and cur_rule['publish'] then + -- Create meta rule from this rule + local nsym = '__fake' .. 
cur_rule['symbol'] + local nrule = { + type = 'meta', + symbol = cur_rule['symbol'], + score = cur_rule['score'], + meta = nsym, + description = cur_rule['description'], + } + rules[nrule['symbol']] = nrule + cur_rule['symbol'] = nsym + end + -- We have previous rule valid + if not cur_rule['symbol'] then + rspamd_logger.errx(rspamd_config, 'bad rule definition: %1', cur_rule) + end + rules[cur_rule['symbol']] = cur_rule + cur_rule = {} + valid_rule = false + end + + local function parse_score(words) + if #words == 3 then + -- score rule <x> + lua_util.debugm(N, rspamd_config, 'found score for %1: %2', words[2], words[3]) + return tonumber(words[3]) + elseif #words == 6 then + -- score rule <x1> <x2> <x3> <x4> + -- we assume here that bayes and network are enabled and select <x4> + lua_util.debugm(N, rspamd_config, 'found score for %1: %2', words[2], words[6]) + return tonumber(words[6]) + else + rspamd_logger.errx(rspamd_config, 'invalid score for %1', words[2]) + end + + return 0 + end + + local skip_to_endif = false + local if_nested = 0 + for l in f:lines() do + (function() + l = lua_util.rspamd_str_trim(l) + -- Replace bla=~/re/ with bla =~ /re/ (#2372) + l = l:gsub('([^%s])%s*([=!]~)%s*([^%s])', '%1 %2 %3') + + if string.len(l) == 0 or string.sub(l, 1, 1) == '#' then + return + end + + -- Unbalanced if/endif + if if_nested < 0 then + if_nested = 0 + end + if skip_to_endif then + if string.match(l, '^endif') then + if_nested = if_nested - 1 + + if if_nested == 0 then + skip_to_endif = false + end + elseif string.match(l, '^if') then + if_nested = if_nested + 1 + elseif string.match(l, '^else') then + -- Else counterpart for if + skip_to_endif = false + end + return + else + if string.match(l, '^ifplugin') then + local ls = split(l) + + if not fun.any(function(pl) + if pl == ls[2] then + return true + end + return false + end, known_plugins) then + skip_to_endif = true + end + if_nested = if_nested + 1 + elseif string.match(l, '^if !plugin%(') then + local 
pname = string.match(l, '^if !plugin%(([A-Za-z:]+)%)') + if fun.any(function(pl) + if pl == pname then + return true + end + return false + end, known_plugins) then + skip_to_endif = true + end + if_nested = if_nested + 1 + elseif string.match(l, '^if') then + -- Unknown if + skip_to_endif = true + if_nested = if_nested + 1 + elseif string.match(l, '^else') then + -- Else counterpart for if + skip_to_endif = true + elseif string.match(l, '^endif') then + if_nested = if_nested - 1 + end + end + + -- Skip comments + local words = fun.totable(fun.take_while( + function(w) + return string.sub(w, 1, 1) ~= '#' + end, + fun.filter(function(w) + return w ~= "" + end, + fun.iter(split(l))))) + + if words[1] == "header" or words[1] == 'mimeheader' then + -- header SYMBOL Header ~= /regexp/ + if valid_rule then + insert_cur_rule() + end + if words[4] and (words[4] == '=~' or words[4] == '!~') then + cur_rule['type'] = 'header' + cur_rule['symbol'] = words[2] + + if words[4] == '!~' then + cur_rule['not'] = true + end + + cur_rule['re_expr'] = words_to_re(words, 4) + local unset_comp = string.find(cur_rule['re_expr'], '%s+%[if%-unset:') + if unset_comp then + -- We have optional part that needs to be processed + local unset = string.match(string.sub(cur_rule['re_expr'], unset_comp), + '%[if%-unset:%s*([^%]%s]+)]') + cur_rule['unset'] = unset + -- Cut it down + cur_rule['re_expr'] = string.sub(cur_rule['re_expr'], 1, unset_comp - 1) + end + + cur_rule['re'] = rspamd_regexp.create(cur_rule['re_expr']) + + if not cur_rule['re'] then + rspamd_logger.warnx(rspamd_config, "Cannot parse regexp '%1' for %2", + cur_rule['re_expr'], cur_rule['symbol']) + else + cur_rule['re']:set_max_hits(1) + handle_header_def(words[3], cur_rule) + end + + if cur_rule['unset'] then + cur_rule['ordinary'] = false + end + + if words[1] == 'mimeheader' then + cur_rule['mime'] = true + else + cur_rule['mime'] = false + end + + if cur_rule['re'] and cur_rule['symbol'] and + (cur_rule['header'] or 
cur_rule['function']) then + valid_rule = true + cur_rule['re']:set_max_hits(1) + if cur_rule['header'] and cur_rule['ordinary'] then + for _, h in ipairs(cur_rule['header']) do + if type(h) == 'string' then + if cur_rule['mime'] then + rspamd_config:register_regexp({ + re = cur_rule['re'], + type = 'mimeheader', + header = h, + pcre_only = is_pcre_only(cur_rule['symbol']), + }) + else + rspamd_config:register_regexp({ + re = cur_rule['re'], + type = 'header', + header = h, + pcre_only = is_pcre_only(cur_rule['symbol']), + }) + end + else + h['mime'] = cur_rule['mime'] + if cur_rule['mime'] then + rspamd_config:register_regexp({ + re = cur_rule['re'], + type = 'mimeheader', + header = h['header'], + pcre_only = is_pcre_only(cur_rule['symbol']), + }) + else + if h['raw'] then + rspamd_config:register_regexp({ + re = cur_rule['re'], + type = 'rawheader', + header = h['header'], + pcre_only = is_pcre_only(cur_rule['symbol']), + }) + else + rspamd_config:register_regexp({ + re = cur_rule['re'], + type = 'header', + header = h['header'], + pcre_only = is_pcre_only(cur_rule['symbol']), + }) + end + end + end + end + cur_rule['re']:set_limit(match_limit) + cur_rule['re']:set_max_hits(1) + end + end + else + -- Maybe we know the function and can convert it + local args = words_to_re(words, 2) + local func = maybe_parse_sa_function(args) + + if func then + cur_rule['type'] = 'function' + cur_rule['symbol'] = words[2] + cur_rule['function'] = func + valid_rule = true + else + rspamd_logger.infox(rspamd_config, 'unknown function %1', args) + end + end + elseif words[1] == "body" then + -- body SYMBOL /regexp/ + if valid_rule then + insert_cur_rule() + end + + cur_rule['symbol'] = words[2] + if words[3] and (string.sub(words[3], 1, 1) == '/' + or string.sub(words[3], 1, 1) == 'm') then + cur_rule['type'] = 'sabody' + cur_rule['re_expr'] = words_to_re(words, 2) + cur_rule['re'] = rspamd_regexp.create(cur_rule['re_expr']) + if cur_rule['re'] then + + 
rspamd_config:register_regexp({ + re = cur_rule['re'], + type = 'sabody', + pcre_only = is_pcre_only(cur_rule['symbol']), + }) + valid_rule = true + cur_rule['re']:set_limit(match_limit) + cur_rule['re']:set_max_hits(1) + end + else + -- might be function + local args = words_to_re(words, 2) + local func = maybe_parse_sa_function(args) + + if func then + cur_rule['type'] = 'function' + cur_rule['symbol'] = words[2] + cur_rule['function'] = func + valid_rule = true + else + rspamd_logger.infox(rspamd_config, 'unknown function %1', args) + end + end + elseif words[1] == "rawbody" then + -- body SYMBOL /regexp/ + if valid_rule then + insert_cur_rule() + end + + cur_rule['symbol'] = words[2] + if words[3] and (string.sub(words[3], 1, 1) == '/' + or string.sub(words[3], 1, 1) == 'm') then + cur_rule['type'] = 'sarawbody' + cur_rule['re_expr'] = words_to_re(words, 2) + cur_rule['re'] = rspamd_regexp.create(cur_rule['re_expr']) + if cur_rule['re'] then + + rspamd_config:register_regexp({ + re = cur_rule['re'], + type = 'sarawbody', + pcre_only = is_pcre_only(cur_rule['symbol']), + }) + valid_rule = true + cur_rule['re']:set_limit(match_limit) + cur_rule['re']:set_max_hits(1) + end + else + -- might be function + local args = words_to_re(words, 2) + local func = maybe_parse_sa_function(args) + + if func then + cur_rule['type'] = 'function' + cur_rule['symbol'] = words[2] + cur_rule['function'] = func + valid_rule = true + else + rspamd_logger.infox(rspamd_config, 'unknown function %1', args) + end + end + elseif words[1] == "full" then + -- body SYMBOL /regexp/ + if valid_rule then + insert_cur_rule() + end + + cur_rule['symbol'] = words[2] + + if words[3] and (string.sub(words[3], 1, 1) == '/' + or string.sub(words[3], 1, 1) == 'm') then + cur_rule['type'] = 'message' + cur_rule['re_expr'] = words_to_re(words, 2) + cur_rule['re'] = rspamd_regexp.create(cur_rule['re_expr']) + cur_rule['raw'] = true + if cur_rule['re'] then + valid_rule = true + 
rspamd_config:register_regexp({ + re = cur_rule['re'], + type = 'body', + pcre_only = is_pcre_only(cur_rule['symbol']), + }) + cur_rule['re']:set_limit(match_limit) + cur_rule['re']:set_max_hits(1) + end + else + -- might be function + local args = words_to_re(words, 2) + local func = maybe_parse_sa_function(args) + + if func then + cur_rule['type'] = 'function' + cur_rule['symbol'] = words[2] + cur_rule['function'] = func + valid_rule = true + else + rspamd_logger.infox(rspamd_config, 'unknown function %1', args) + end + end + elseif words[1] == "uri" then + -- uri SYMBOL /regexp/ + if valid_rule then + insert_cur_rule() + end + cur_rule['type'] = 'uri' + cur_rule['symbol'] = words[2] + cur_rule['re_expr'] = words_to_re(words, 2) + cur_rule['re'] = rspamd_regexp.create(cur_rule['re_expr']) + if cur_rule['re'] and cur_rule['symbol'] then + valid_rule = true + rspamd_config:register_regexp({ + re = cur_rule['re'], + type = 'url', + pcre_only = is_pcre_only(cur_rule['symbol']), + }) + cur_rule['re']:set_limit(match_limit) + cur_rule['re']:set_max_hits(1) + end + elseif words[1] == "meta" then + -- meta SYMBOL expression + if valid_rule then + insert_cur_rule() + end + cur_rule['type'] = 'meta' + cur_rule['symbol'] = words[2] + cur_rule['meta'] = words_to_re(words, 2) + if cur_rule['meta'] and cur_rule['symbol'] + and cur_rule['meta'] ~= '0' then + valid_rule = true + end + elseif words[1] == "describe" and valid_rule then + cur_rule['description'] = words_to_re(words, 2) + elseif words[1] == "score" then + scores[words[2]] = parse_score(words) + elseif words[1] == 'freemail_domains' then + fun.each(function(dom) + table.insert(freemail_domains, '@' .. 
dom) + end, fun.drop_n(1, words)) + elseif words[1] == 'blacklist_from' then + sa_lists['from_blacklist'][words[2]] = 1 + sa_lists['elts'] = sa_lists['elts'] + 1 + elseif words[1] == 'whitelist_from' then + sa_lists['from_whitelist'][words[2]] = 1 + sa_lists['elts'] = sa_lists['elts'] + 1 + elseif words[1] == 'whitelist_to' then + sa_lists['to_whitelist'][words[2]] = 1 + sa_lists['elts'] = sa_lists['elts'] + 1 + elseif words[1] == 'blacklist_to' then + sa_lists['to_blacklist'][words[2]] = 1 + sa_lists['elts'] = sa_lists['elts'] + 1 + elseif words[1] == 'tflags' then + process_tflags(cur_rule, words) + elseif words[1] == 'replace_tag' then + process_replace(words, replace['tags']) + elseif words[1] == 'replace_pre' then + process_replace(words, replace['pre']) + elseif words[1] == 'replace_inter' then + process_replace(words, replace['inter']) + elseif words[1] == 'replace_post' then + process_replace(words, replace['post']) + elseif words[1] == 'replace_rules' then + fun.each(function(r) + table.insert(replace['rules'], r) + end, + fun.drop_n(1, words)) + end + end)() + end + if valid_rule then + insert_cur_rule() + end +end + +-- Now check all valid rules and add the according rspamd rules + +local function calculate_score(sym, rule) + if fun.all(function(c) + return c == '_' + end, fun.take_n(2, fun.iter(sym))) then + return 0.0 + end + + if rule['nice'] or (rule['score'] and rule['score'] < 0.0) then + return -1.0 + end + + return 1.0 +end + +local function add_sole_meta(sym, rule) + local r = { + type = 'meta', + meta = rule['symbol'], + score = rule['score'], + description = rule['description'] + } + rules[sym] = r +end + +local function sa_regexp_match(data, re, raw, rule) + local res = 0 + if not re then + return 0 + end + if rule['multiple'] then + local lim = -1 + if rule['maxhits'] then + lim = rule['maxhits'] + end + res = res + re:matchn(data, lim, raw) + else + if re:match(data, raw) then + res = 1 + end + end + + return res +end + +local function 
apply_replacements(str) + local pre = "" + local post = "" + local inter = "" + + local function check_specific_tag(prefix, s, tbl) + local replacement = nil + local ret = s + fun.each(function(n, t) + local ns, matches = string.gsub(s, string.format("<%s%s>", prefix, n), "") + if matches > 0 then + replacement = t + ret = ns + end + end, tbl) + + return ret, replacement + end + + local repl + str, repl = check_specific_tag("pre ", str, replace['pre']) + if repl then + pre = repl + end + str, repl = check_specific_tag("inter ", str, replace['inter']) + if repl then + inter = repl + end + str, repl = check_specific_tag("post ", str, replace['post']) + if repl then + post = repl + end + + -- XXX: ugly hack + if inter then + str = string.gsub(str, "><", string.format(">%s<", inter)) + end + + local function replace_all_tags(s) + local sstr + sstr = s + fun.each(function(n, t) + local rep = string.format("%s%s%s", pre, t, post) + rep = string.gsub(rep, '%%', '%%%%') + sstr = string.gsub(sstr, string.format("<%s>", n), rep) + end, replace['tags']) + + return sstr + end + + local s = replace_all_tags(str) + + if str ~= s then + return true, s + end + + return false, str +end + +local function parse_atom(str) + local atom = table.concat(fun.totable(fun.take_while(function(c) + if string.find(', \t()><+!|&\n', c, 1, true) then + return false + end + return true + end, fun.iter(str))), '') + + return atom +end + +local function gen_process_atom_cb(result_name, task) + return function(atom) + local atom_cb = atoms[atom] + + if atom_cb then + local res = atom_cb(task, result_name) + + if not res then + lua_util.debugm(N, task, 'metric: %s, atom: %s, NULL result', result_name, atom) + elseif res > 0 then + lua_util.debugm(N, task, 'metric: %s, atom: %s, result: %s', result_name, atom, res) + end + return res + else + -- This is likely external atom + local real_sym = atom + if symbols_replacements[atom] then + real_sym = symbols_replacements[atom] + end + if 
task:has_symbol(real_sym, result_name) then + lua_util.debugm(N, task, 'external atom: %s, result: 1, named_result: %s', real_sym, result_name) + return 1 + end + lua_util.debugm(N, task, 'external atom: %s, result: 0, , named_result: %s', real_sym, result_name) + end + return 0 + end +end + +local function post_process() + -- Replace rule tags + local ntags = {} + local function rec_replace_tags(tag, tagv) + if ntags[tag] then + return ntags[tag] + end + fun.each(function(n, t) + if n ~= tag then + local s, matches = string.gsub(tagv, string.format("<%s>", n), t) + if matches > 0 then + ntags[tag] = rec_replace_tags(tag, s) + end + end + end, replace['tags']) + + if not ntags[tag] then + ntags[tag] = tagv + end + return ntags[tag] + end + + fun.each(function(n, t) + rec_replace_tags(n, t) + end, replace['tags']) + fun.each(function(n, t) + replace['tags'][n] = t + end, ntags) + + fun.each(function(r) + local rule = rules[r] + + if rule['re_expr'] and rule['re'] then + local res, nexpr = apply_replacements(rule['re_expr']) + if res then + local nre = rspamd_regexp.create(nexpr) + if not nre then + rspamd_logger.errx(rspamd_config, 'cannot apply replacement for rule %1', r) + --rule['re'] = nil + else + local old_max_hits = rule['re']:get_max_hits() + lua_util.debugm(N, rspamd_config, 'replace %1 -> %2', r, nexpr) + rspamd_config:replace_regexp({ + old_re = rule['re'], + new_re = nre, + pcre_only = is_pcre_only(rule['symbol']), + }) + rule['re'] = nre + rule['re_expr'] = nexpr + nre:set_limit(match_limit) + nre:set_max_hits(old_max_hits) + end + end + end + end, replace['rules']) + + fun.each(function(key, score) + if rules[key] then + rules[key]['score'] = score + end + end, scores) + + -- Header rules + fun.each(function(k, r) + local f = function(task) + + local raw = false + local check = {} + -- Cached path for ordinary expressions + if r['ordinary'] then + local h = r['header'][1] + local t = 'header' + + if h['raw'] then + t = 'rawheader' + end + + if not 
r['re'] then + rspamd_logger.errx(task, 're is missing for rule %1 (%2 header)', k, + h['header']) + return 0 + end + + local ret = process_regexp_opt(r.re, task, t, h.header, h.strong) + + if r['not'] then + if ret ~= 0 then + ret = 0 + else + ret = 1 + end + end + + return ret + end + + -- Slow path + fun.each(function(h) + local hname = h['header'] + + local hdr + if h['mime'] then + local parts = task:get_parts() + for _, p in ipairs(parts) do + local m_hdr = p:get_header_full(hname, h['strong']) + + if m_hdr then + if not hdr then + hdr = {} + end + for _, mh in ipairs(m_hdr) do + table.insert(hdr, mh) + end + end + end + else + hdr = task:get_header_full(hname, h['strong']) + end + + if hdr then + for _, rh in ipairs(hdr) do + -- Subject for optimization + local str + if h['raw'] then + str = rh['value'] + raw = true + else + str = rh['decoded'] + end + if not str then + return 0 + end + + if h['function'] then + str = h['function'](str) + end + + if type(str) == 'string' then + table.insert(check, str) + else + for _, c in ipairs(str) do + table.insert(check, c) + end + end + end + elseif r['unset'] then + table.insert(check, r['unset']) + end + end, r['header']) + + if #check == 0 then + if r['not'] then + return 1 + end + return 0 + end + + local ret = 0 + for _, c in ipairs(check) do + local match = sa_regexp_match(c, r['re'], raw, r) + if (match > 0 and not r['not']) or (match == 0 and r['not']) then + ret = 1 + end + end + + return ret + end + if r['score'] then + local real_score = r['score'] * calculate_score(k, r) + if math.abs(real_score) > meta_score_alpha then + add_sole_meta(k, r) + end + end + atoms[k] = f + end, + fun.filter(function(_, r) + return r['type'] == 'header' and r['header'] + end, + rules)) + + -- Custom function rules + fun.each(function(k, r) + local f = function(task) + local res = r['function'](task) + if res and res > 0 then + return res + end + return 0 + end + if r['score'] then + local real_score = r['score'] * 
calculate_score(k, r) + if math.abs(real_score) > meta_score_alpha then + add_sole_meta(k, r) + end + end + atoms[k] = f + end, + fun.filter(function(_, r) + return r['type'] == 'function' and r['function'] + end, + rules)) + + -- Parts rules + fun.each(function(k, r) + local f = function(task) + if not r['re'] then + rspamd_logger.errx(task, 're is missing for rule %1', k) + return 0 + end + + local t = 'mime' + if r['raw'] then + t = 'rawmime' + end + + return process_regexp_opt(r.re, task, t) + end + if r['score'] then + local real_score = r['score'] * calculate_score(k, r) + if math.abs(real_score) > meta_score_alpha then + add_sole_meta(k, r) + end + end + atoms[k] = f + end, + fun.filter(function(_, r) + return r['type'] == 'part' + end, rules)) + + -- SA body rules + fun.each(function(k, r) + local f = function(task) + if not r['re'] then + rspamd_logger.errx(task, 're is missing for rule %1', k) + return 0 + end + + local t = r['type'] + + local ret = process_regexp_opt(r.re, task, t) + return ret + end + if r['score'] then + local real_score = r['score'] * calculate_score(k, r) + if math.abs(real_score) > meta_score_alpha then + add_sole_meta(k, r) + end + end + atoms[k] = f + end, + fun.filter(function(_, r) + return r['type'] == 'sabody' or r['type'] == 'message' or r['type'] == 'sarawbody' + end, rules)) + + -- URL rules + fun.each(function(k, r) + local f = function(task) + if not r['re'] then + rspamd_logger.errx(task, 're is missing for rule %1', k) + return 0 + end + + return process_regexp_opt(r.re, task, 'url') + end + if r['score'] then + local real_score = r['score'] * calculate_score(k, r) + if math.abs(real_score) > meta_score_alpha then + add_sole_meta(k, r) + end + end + atoms[k] = f + end, + fun.filter(function(_, r) + return r['type'] == 'uri' + end, + rules)) + -- Meta rules + fun.each(function(k, r) + local expression = nil + -- Meta function callback + -- Here are dragons! 
+ -- This function can be called from 2 DIFFERENT type of invocations: + -- 1) Invocation from Rspamd itself where `res_name` will be nil + -- 2) Invocation from other meta during expression:process_traced call + -- So we need to distinguish that and return different stuff to be able to deal with atoms + local meta_cb = function(task, res_name) + lua_util.debugm(N, task, 'meta callback for %s; result name: %s', k, res_name) + local cached = task:cache_get('sa_metas_processed') + + -- We avoid many task methods invocations here (likely) + if not cached then + cached = {} + task:cache_set('sa_metas_processed', cached) + end + + local already_processed = cached[k] + + -- Exclude elements that are named in the same way as the symbol itself + local function exclude_sym_filter(sopt) + return sopt ~= k + end + + if not (already_processed and already_processed[res_name or 'default']) then + -- Execute symbol + local function exec_symbol(cur_res) + local res, trace = expression:process_traced(gen_process_atom_cb(cur_res, task)) + lua_util.debugm(N, task, 'meta result for %s: %s; result name: %s', k, res, cur_res) + if res > 0 then + -- Symbol should be one shot to make it working properly + task:insert_result_named(cur_res, k, res, fun.totable(fun.filter(exclude_sym_filter, trace))) + end + + if not cached[k] then + cached[k] = {} + end + + cached[k][cur_res] = res + end + + if not res_name then + -- Invoke for all named results + local named_results = task:get_all_named_results() + for _, cur_res in ipairs(named_results) do + exec_symbol(cur_res) + end + else + -- Invoked from another meta + exec_symbol(res_name) + return cached[k][res_name] or 0 + end + else + -- We have cached the result + local res = already_processed[res_name or 'default'] or 0 + lua_util.debugm(N, task, 'cached meta result for %s: %s; result name: %s', + k, res, res_name) + + if res_name then + return res + end + end + + -- No return if invoked directly from Rspamd as we use task:insert_result_named 
directly + end + + expression = rspamd_expression.create(r['meta'], parse_atom, rspamd_config:get_mempool()) + if not expression then + rspamd_logger.errx(rspamd_config, 'Cannot parse expression ' .. r['meta']) + else + + if r['score'] then + rspamd_config:set_metric_symbol { + name = k, score = r['score'], + description = r['description'], + priority = scores_priority, + one_shot = true + } + scores_added[k] = 1 + rspamd_config:register_symbol { + name = k, + weight = calculate_score(k, r), + callback = meta_cb + } + else + -- Add 0 score to avoid issues + rspamd_config:register_symbol { + name = k, + weight = calculate_score(k, r), + callback = meta_cb, + score = 0, + } + end + + r['expression'] = expression + + if not atoms[k] then + atoms[k] = meta_cb + end + end + end, + fun.filter(function(_, r) + return r['type'] == 'meta' + end, + rules)) + + -- Check meta rules for foreign symbols and register dependencies + -- First direct dependencies: + fun.each(function(k, r) + if r['expression'] then + local expr_atoms = r['expression']:atoms() + + for _, a in ipairs(expr_atoms) do + if not atoms[a] then + local rspamd_symbol = replace_symbol(a) + if not external_deps[k] then + external_deps[k] = {} + end + + if not external_deps[k][rspamd_symbol] then + rspamd_config:register_dependency(k, rspamd_symbol) + external_deps[k][rspamd_symbol] = true + lua_util.debugm(N, rspamd_config, + 'atom %1 is a direct foreign dependency, ' .. + 'register dependency for %2 on %3', + a, k, rspamd_symbol) + end + end + end + end + end, + fun.filter(function(_, r) + return r['type'] == 'meta' + end, + rules)) + + -- ... And then indirect ones ... 
+ local nchanges + repeat + nchanges = 0 + fun.each(function(k, r) + if r['expression'] then + local expr_atoms = r['expression']:atoms() + for _, a in ipairs(expr_atoms) do + if type(external_deps[a]) == 'table' then + for dep in pairs(external_deps[a]) do + if not external_deps[k] then + external_deps[k] = {} + end + if not external_deps[k][dep] then + rspamd_config:register_dependency(k, dep) + external_deps[k][dep] = true + lua_util.debugm(N, rspamd_config, + 'atom %1 is an indirect foreign dependency, ' .. + 'register dependency for %2 on %3', + a, k, dep) + nchanges = nchanges + 1 + end + end + else + local rspamd_symbol, replaced_symbol = replace_symbol(a) + if replaced_symbol then + external_deps[a] = { [rspamd_symbol] = true } + else + external_deps[a] = {} + end + end + end + end + end, + fun.filter(function(_, r) + return r['type'] == 'meta' + end, + rules)) + until nchanges == 0 + + -- Set missing symbols + fun.each(function(key, score) + if not scores_added[key] then + rspamd_config:set_metric_symbol({ + name = key, score = score, + priority = 2, flags = 'ignore' }) + end + end, scores) + + -- Logging output + if freemail_domains then + freemail_trie = rspamd_trie.create(freemail_domains) + rspamd_logger.infox(rspamd_config, 'loaded %1 freemail domains definitions', + #freemail_domains) + end + rspamd_logger.infox(rspamd_config, 'loaded %1 blacklist/whitelist elements', + sa_lists['elts']) +end + +local has_rules = false + +if type(section) == "table" then + local keywords = { + pcre_only = { 'table', function(v) + pcre_only_regexps = lua_util.list_to_hash(v) + end }, + alpha = { 'number', function(v) + meta_score_alpha = tonumber(v) + end }, + match_limit = { 'number', function(v) + match_limit = tonumber(v) + end }, + scores_priority = { 'number', function(v) + scores_priority = tonumber(v) + end }, + } + + for k, fn in pairs(section) do + local kw = keywords[k] + if kw and type(fn) == kw[1] then + kw[2](fn) + else + -- SA rule file + if type(fn) == 
'table' then + for _, elt in ipairs(fn) do + local files = util.glob(elt) + + if not files or #files == 0 then + rspamd_logger.errx(rspamd_config, "cannot find any files matching pattern %s", elt) + else + for _, matched in ipairs(files) do + local f = io.open(matched, "r") + if f then + rspamd_logger.infox(rspamd_config, 'loading SA rules from %s', matched) + process_sa_conf(f) + has_rules = true + else + rspamd_logger.errx(rspamd_config, "cannot open %1", matched) + end + end + end + end + else + -- assume string + local files = util.glob(fn) + + if not files or #files == 0 then + rspamd_logger.errx(rspamd_config, "cannot find any files matching pattern %s", fn) + else + for _, matched in ipairs(files) do + local f = io.open(matched, "r") + if f then + rspamd_logger.infox(rspamd_config, 'loading SA rules from %s', matched) + process_sa_conf(f) + has_rules = true + else + rspamd_logger.errx(rspamd_config, "cannot open %1", matched) + end + end + end + end + end + end +end + +if has_rules then + post_process() +else + lua_util.disable_module(N, "config") +end diff --git a/src/plugins/lua/spamtrap.lua b/src/plugins/lua/spamtrap.lua new file mode 100644 index 0000000..cd3b296 --- /dev/null +++ b/src/plugins/lua/spamtrap.lua @@ -0,0 +1,200 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2016, Andrew Lewis <nerf@judo.za.org> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +-- A plugin that triggers, if a spam trapped email address was detected + +local rspamd_logger = require "rspamd_logger" +local redis_params +local use_redis = false; +local M = 'spamtrap' +local lua_util = require "lua_util" +local fun = require "fun" + +local settings = { + symbol = 'SPAMTRAP', + score = 0.0, + learn_fuzzy = false, + learn_spam = false, + fuzzy_flag = 1, + fuzzy_weight = 10.0, + key_prefix = 'sptr_', + allow_multiple_rcpts = false, +} + +local check_authed = true +local check_local = true + +local function spamtrap_cb(task) + local rcpts = task:get_recipients('smtp') + local authed_user = task:get_user() + local ip_addr = task:get_ip() + local called_for_domain = false + + if ((not check_authed and authed_user) or + (not check_local and ip_addr and ip_addr:is_local())) then + rspamd_logger.infox(task, "skip spamtrap checks for local networks or authenticated user"); + return + end + + local function do_action(rcpt) + if settings['learn_fuzzy'] then + rspamd_plugins.fuzzy_check.learn(task, + settings['fuzzy_flag'], + settings['fuzzy_weight']) + end + local act_flags = '' + if settings['learn_spam'] then + task:set_flag("learn_spam") + -- Allow processing as we still need to learn and do other stuff + act_flags = 'process_all' + end + task:insert_result(settings['symbol'], 1, rcpt) + + if settings.action then + rspamd_logger.infox(task, 'spamtrap found: <%s>', rcpt) + local smtp_message + if settings.smtp_message then + smtp_message = lua_util.template(settings.smtp_message, { rcpt = rcpt }) + else + smtp_message = 'unknown error' + if settings.action == 'no action' then + smtp_message = 'message accepted' + elseif settings.action == 'reject' then + smtp_message = 'message rejected' + end + end + task:set_pre_result { action = settings.action, + message = smtp_message, + module = 'spamtrap', + flags = act_flags } + end + + return true + end + + local function gen_redis_spamtrap_cb(target) + return function(err, data) + if err ~= nil then + 
rspamd_logger.errx(task, 'redis_spamtrap_cb received error: %1', err) + return + end + + if data and type(data) ~= 'userdata' then + do_action(target) + else + if not called_for_domain then + -- Recurse for @catchall domain + target = rcpts[1]['domain']:lower() + local key = settings['key_prefix'] .. '@' .. target + local ret = rspamd_redis_make_request(task, + redis_params, -- connect params + key, -- hash key + false, -- is write + gen_redis_spamtrap_cb(target), -- callback + 'GET', -- command + { key } -- arguments + ) + if not ret then + rspamd_logger.errx(task, "redis request wasn't scheduled") + end + called_for_domain = true + else + lua_util.debugm(M, task, 'skip spamtrap for %s', target) + end + end + end + end + + -- Do not risk a FP by checking for more than one recipient + if rcpts and (#rcpts == 1 or (#rcpts > 0 and settings.allow_multiple_rcpts)) then + local targets = fun.map(function(r) + return r['addr']:lower() + end, rcpts) + if use_redis then + fun.each(function(target) + local key = settings['key_prefix'] .. 
target + local ret = rspamd_redis_make_request(task, + redis_params, -- connect params + key, -- hash key + false, -- is write + gen_redis_spamtrap_cb(target), -- callback + 'GET', -- command + { key } -- arguments + ) + if not ret then + rspamd_logger.errx(task, "redis request wasn't scheduled") + end + end, targets) + + elseif settings['map'] then + local function check_map_functor(target) + if settings['map']:get_key(target) then + return do_action(target) + end + end + if not fun.any(check_map_functor, targets) then + lua_util.debugm(M, task, 'skip spamtrap') + end + end + end +end + +-- Module setup + +local opts = rspamd_config:get_all_opt('spamtrap') +if not (opts and type(opts) == 'table') then + rspamd_logger.infox(rspamd_config, 'module is unconfigured') + return +end + +local auth_and_local_conf = lua_util.config_check_local_or_authed(rspamd_config, 'spamtrap', + false, false) +check_local = auth_and_local_conf[1] +check_authed = auth_and_local_conf[2] + +if opts then + for k, v in pairs(opts) do + settings[k] = v + end + if settings['map'] then + settings['map'] = rspamd_config:add_map { + url = settings['map'], + description = string.format("Spamtrap map for %s", settings['symbol']), + type = "regexp" + } + else + redis_params = rspamd_parse_redis_server('spamtrap') + if not redis_params then + rspamd_logger.errx( + rspamd_config, 'no redis servers are specified, disabling module') + return + end + use_redis = true; + end + + local id = rspamd_config:register_symbol({ + name = "SPAMTRAP_CHECK", + type = "callback,postfilter", + callback = spamtrap_cb + }) + rspamd_config:register_symbol({ + name = settings['symbol'], + parent = id, + type = 'virtual', + score = settings.score + }) +end diff --git a/src/plugins/lua/spf.lua b/src/plugins/lua/spf.lua new file mode 100644 index 0000000..5e15128 --- /dev/null +++ b/src/plugins/lua/spf.lua @@ -0,0 +1,242 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache 
License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local N = "spf" +local lua_util = require "lua_util" +local rspamd_spf = require "rspamd_spf" +local bit = require "bit" +local rspamd_logger = require "rspamd_logger" + +if confighelp then + rspamd_config:add_example(nil, N, + 'Performs SPF checks', + [[ +spf { + # Enable module + enabled = true + # Number of elements in the cache of parsed SPF records + spf_cache_size = 2048; + # Default max expire for an element in this cache + spf_cache_expire = 1d; + # Whitelist IPs from checks + whitelist = "/path/to/some/file"; + # Maximum number of recursive DNS subrequests (e.g. 
includes chain length) + max_dns_nesting = 10; + # Maximum count of DNS requests per record + max_dns_requests = 30; + # Minimum TTL enforced for all elements in SPF records + min_cache_ttl = 5m; + # Disable all IPv6 lookups + disable_ipv6 = false; + # Use IP address from a received header produced by this relay (using by attribute) + external_relay = ["192.168.1.1"]; +} + ]]) + return +end + +local symbols = { + fail = "R_SPF_FAIL", + softfail = "R_SPF_SOFTFAIL", + neutral = "R_SPF_NEUTRAL", + allow = "R_SPF_ALLOW", + dnsfail = "R_SPF_DNSFAIL", + permfail = "R_SPF_PERMFAIL", + na = "R_SPF_NA", +} + +local default_config = { + spf_cache_size = 2048, + max_dns_nesting = 10, + max_dns_requests = 30, + whitelist = nil, + min_cache_ttl = 60 * 5, + disable_ipv6 = false, + symbols = symbols, + external_relay = nil, +} + +local local_config = rspamd_config:get_all_opt('spf') +local auth_and_local_conf = lua_util.config_check_local_or_authed(rspamd_config, N, + false, false) + +if local_config then + local_config = lua_util.override_defaults(default_config, local_config) +else + local_config = default_config +end + +local function spf_check_callback(task) + + local ip + + if local_config.external_relay then + -- Search received headers to get header produced by an external relay + local rh = task:get_received_headers() or {} + local found = false + + for i, hdr in ipairs(rh) do + if hdr.real_ip and local_config.external_relay:get_key(hdr.real_ip) then + -- We can use the next header as a source of IP address + if rh[i + 1] then + local nhdr = rh[i + 1] + lua_util.debugm(N, task, 'found external relay %s at received header number %s -> %s', + local_config.external_relay, i, nhdr.real_ip) + + if nhdr.real_ip then + ip = nhdr.real_ip + found = true + end + end + + break + end + end + if not found then + ip = task:get_from_ip() + rspamd_logger.warnx(task, + "cannot find external relay for SPF checks in received headers; use the original IP: %s", + tostring(ip)) + end + else + 
ip = task:get_from_ip() + end + + -- Translate resolver error flags (temp_fail / perm_fail / na) into the configured symbol name; + -- falls back to 'SPF_UNKNOWN' when none of these bits is set + local function flag_to_symbol(fl) + if bit.band(fl, rspamd_spf.flags.temp_fail) ~= 0 then + return local_config.symbols.dnsfail + elseif bit.band(fl, rspamd_spf.flags.perm_fail) ~= 0 then + return local_config.symbols.permfail + elseif bit.band(fl, rspamd_spf.flags.na) ~= 0 then + return local_config.symbols.na + end + + return 'SPF_UNKNOWN' + end + + -- Translate an SPF policy value into a pair: configured symbol name and one-character qualifier + local function policy_decode(res) + if res == rspamd_spf.policy.fail then + return local_config.symbols.fail, '-' + elseif res == rspamd_spf.policy.pass then + return local_config.symbols.allow, '+' + elseif res == rspamd_spf.policy.soft_fail then + return local_config.symbols.softfail, '~' + elseif res == rspamd_spf.policy.neutral then + return local_config.symbols.neutral, '?' + end + + return 'SPF_UNKNOWN', '?' + end + + -- Callback passed to rspamd_spf.resolve: checks `ip` against the resolved record and inserts the matching symbol + local function spf_resolved_cb(record, flags, err) + lua_util.debugm(N, task, 'got spf results: %s flags, %s err', + flags, err) + + if record then + local result, flag_or_policy, error_or_addr = record:check_ip(ip) + + -- Log the values returned by check_ip so that they match the labels in the format string + lua_util.debugm(N, task, + 'checked ip %s: result=%s, flag_or_policy=%s, error_or_addr=%s', + ip, result, flag_or_policy, error_or_addr) + + if result then + local sym, code = policy_decode(flag_or_policy) + local opt = string.format('%s%s', code, error_or_addr.str or '???') + if bit.band(flags, rspamd_spf.flags.cached) ~= 0 then + opt = opt .. 
':c' + rspamd_logger.infox(task, + "use cached record for %s (0x%s) in LRU cache for %s seconds", + record:get_domain(), + record:get_digest(), + record:get_ttl() - math.floor(task:get_timeval(true) - + record:get_timestamp())); + end + task:insert_result(sym, 1.0, opt) + else + local sym = flag_to_symbol(flag_or_policy) + task:insert_result(sym, 1.0, error_or_addr) + end + else + local sym = flag_to_symbol(flags) + task:insert_result(sym, 1.0, err) + end + end + + if ip then + if local_config.whitelist and ip and local_config.whitelist:get_key(ip) then + rspamd_logger.infox(task, 'whitelisted SPF checks from %s', + tostring(ip)) + return + end + + if lua_util.is_skip_local_or_authed(task, auth_and_local_conf, ip) then + rspamd_logger.infox(task, 'skip SPF checks for local networks and authorized users') + return + end + + rspamd_spf.resolve(task, spf_resolved_cb) + else + lua_util.debugm(N, task, "spf checks are not possible as no source IP address is defined") + end + + -- FIXME: we actually need to set this variable when we really checked SPF + -- However, the old C module has set it all the times + -- Hence, we follow the same rule for now. It should be better designed at some day + local mpool = task:get_mempool() + local dmarc_checks = mpool:get_variable('dmarc_checks', 'double') or 0 + dmarc_checks = dmarc_checks + 1 + mpool:set_variable('dmarc_checks', dmarc_checks) +end + +-- Register all symbols and init rspamd_spf library +rspamd_spf.config(local_config) +local sym_id = rspamd_config:register_symbol { + name = 'SPF_CHECK', + type = 'callback', + flags = 'fine,empty', + groups = { 'policies', 'spf' }, + score = 0.0, + callback = spf_check_callback, + -- We can merely estimate timeout here, as it is possible to construct an SPF record that would cause + -- many DNS requests. 
But we won't like to set the maximum value for that all the time, as + -- the majority of requests will typically have 1-4 subrequests + -- Apply the `or 0.0` fallback before multiplying: `x * 4 or 0.0` parses as `(x * 4) or 0.0`, + -- which never reaches the fallback and raises an error if the timeout is nil + augmentations = { string.format("timeout=%f", (rspamd_config:get_dns_timeout() or 0.0) * 4) }, +} + +if local_config.whitelist then + local lua_maps = require "lua_maps" + + local_config.whitelist = lua_maps.map_add_from_ucl(local_config.whitelist, + "radix", "SPF whitelist map") +end + +if local_config.external_relay then + local lua_maps = require "lua_maps" + + local_config.external_relay = lua_maps.map_add_from_ucl(local_config.external_relay, + "radix", "External IP SPF map") +end + +for _, sym in pairs(local_config.symbols) do + rspamd_config:register_symbol { + name = sym, + type = 'virtual', + parent = sym_id, + groups = { 'policies', 'spf' }, + } +end + + diff --git a/src/plugins/lua/trie.lua b/src/plugins/lua/trie.lua new file mode 100644 index 0000000..7ba4552 --- /dev/null +++ b/src/plugins/lua/trie.lua @@ -0,0 +1,184 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +if confighelp then + return +end + +-- Trie is rspamd module designed to define and operate with suffix trie + +local N = 'trie' +local rspamd_logger = require "rspamd_logger" +local rspamd_trie = require "rspamd_trie" +local fun = require "fun" +local lua_util = require "lua_util" + +local mime_trie +local raw_trie +local body_trie + +-- here we store all patterns as text +local mime_patterns = {} +local raw_patterns = {} +local body_patterns = {} + +-- here we store params for each pattern, so for each i = 1..n patterns[i] +-- should have corresponding params[i] +local mime_params = {} +local raw_params = {} +local body_params = {} + +local function tries_callback(task) + + local matched = {} + + local function gen_trie_cb(type) + local patterns = mime_patterns + local params = mime_params + if type == 'rawmessage' then + patterns = raw_patterns + params = raw_params + elseif type == 'rawbody' then + patterns = body_patterns + params = body_params + end + + return function(idx, pos) + local param = params[idx] + local pattern = patterns[idx] + local pattern_idx = pattern .. tostring(idx) .. 
type + + if param['multi'] or not matched[pattern_idx] then + lua_util.debugm(N, task, "<%1> matched pattern %2 at pos %3", + task:get_message_id(), pattern, pos) + task:insert_result(param['symbol'], 1.0, type) + if not param['multi'] then + matched[pattern_idx] = true + end + end + end + end + + if mime_trie then + mime_trie:search_mime(task, gen_trie_cb('mime')) + end + if raw_trie then + raw_trie:search_rawmsg(task, gen_trie_cb('rawmessage')) + end + if body_trie then + body_trie:search_rawbody(task, gen_trie_cb('rawbody')) + end +end + +-- Append a pattern and its { symbol, multi } params to the raw, body or mime tables, +-- selected by cf['raw'] / cf['body']; a nil pattern is ignored +local function process_single_pattern(pat, symbol, cf) + if pat then + local multi = false + if cf['multi'] then + multi = true + end + + if cf['raw'] then + table.insert(raw_patterns, pat) + table.insert(raw_params, { symbol = symbol, multi = multi }) + elseif cf['body'] then + table.insert(body_patterns, pat) + table.insert(body_params, { symbol = symbol, multi = multi }) + else + table.insert(mime_patterns, pat) + table.insert(mime_params, { symbol = symbol, multi = multi }) + end + end +end + +-- Load patterns for `symbol` from the text file cf['file'], one pattern per line; +-- lines that do not match the extraction regexp (e.g. 
starting with '#') yield no pattern +local function process_trie_file(symbol, cf) + local file = io.open(cf['file']) + + if not file then + rspamd_logger.errx(rspamd_config, 'Cannot open trie file %1', cf['file']) + else + if cf['binary'] then + rspamd_logger.errx(rspamd_config, 'binary trie patterns are not implemented yet: %1', + cf['file']) + else + for line in file:lines() do + local pat = string.match(line, '^([^#].*[^%s])%s*$') + process_single_pattern(pat, symbol, cf) + end + end + + -- Close the handle explicitly instead of leaving it open until garbage collection + file:close() + end +end + +-- Validate the per-symbol config table and load its patterns from cf['file'] or cf['patterns'] +local function process_trie_conf(symbol, cf) + if type(cf) ~= 'table' then + rspamd_logger.errx(rspamd_config, 'invalid value for symbol %1: "%2", expected table', + symbol, cf) + return + end + + if cf['file'] then + process_trie_file(symbol, cf) + elseif cf['patterns'] then + fun.each(function(pat) + process_single_pattern(pat, symbol, cf) + end, cf['patterns']) + end +end + +local opts = rspamd_config:get_all_opt("trie") +if opts then + for sym, opt in pairs(opts) do + 
process_trie_conf(sym, opt) + end + + if #raw_patterns > 0 then + raw_trie = rspamd_trie.create(raw_patterns) + rspamd_logger.infox(rspamd_config, 'registered raw search trie from %1 patterns', #raw_patterns) + end + + if #mime_patterns > 0 then + mime_trie = rspamd_trie.create(mime_patterns) + rspamd_logger.infox(rspamd_config, 'registered mime search trie from %1 patterns', #mime_patterns) + end + + if #body_patterns > 0 then + body_trie = rspamd_trie.create(body_patterns) + rspamd_logger.infox(rspamd_config, 'registered body search trie from %1 patterns', #body_patterns) + end + + local id = -1 + if mime_trie or raw_trie or body_trie then + id = rspamd_config:register_symbol({ + name = 'TRIE_CALLBACK', + type = 'callback', + callback = tries_callback + }) + else + rspamd_logger.infox(rspamd_config, 'no tries defined') + end + + if id ~= -1 then + for sym in pairs(opts) do + rspamd_config:register_symbol({ + name = sym, + type = 'virtual', + parent = id + }) + end + end +else + rspamd_logger.infox(rspamd_config, "Module is unconfigured") + lua_util.disable_module(N, "config") +end diff --git a/src/plugins/lua/url_redirector.lua b/src/plugins/lua/url_redirector.lua new file mode 100644 index 0000000..10b5fb2 --- /dev/null +++ b/src/plugins/lua/url_redirector.lua @@ -0,0 +1,422 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +if confighelp then + return +end + +local rspamd_logger = require "rspamd_logger" +local rspamd_http = require "rspamd_http" +local hash = require "rspamd_cryptobox_hash" +local rspamd_url = require "rspamd_url" +local lua_util = require "lua_util" +local lua_redis = require "lua_redis" +local N = "url_redirector" + +-- Some popular UA +local default_ua = { + 'Mozilla/5.0 (compatible; Yahoo! Slurp; http://help.yahoo.com/help/us/ysearch/slurp)', + 'Mozilla/5.0 (compatible; YandexBot/3.0; +http://yandex.com/bots)', + 'Wget/1.9.1', + 'Mozilla/5.0 (Android; Linux armv7l; rv:9.0) Gecko/20111216 Firefox/9.0 Fennec/9.0', + 'Mozilla/5.0 (Windows NT 5.2; RW; rv:7.0a1) Gecko/20091211 SeaMonkey/9.23a1pre', + 'Mozilla/5.0 (Windows NT 6.1; WOW64; Trident/7.0; AS; rv:11.0) like Gecko', + 'W3C-checklink/4.5 [4.160] libwww-perl/5.823', + 'Lynx/2.8.8dev.3 libwww-FM/2.14 SSL-MM/1.4.1', +} + +local redis_params + +local settings = { + expire = 86400, -- 1 day by default + timeout = 10, -- 10 seconds by default + nested_limit = 5, -- How many redirects to follow + --proxy = "http://example.com:3128", -- Send request through proxy + key_prefix = 'rdr:', -- default hash name + check_ssl = false, -- check ssl certificates + max_urls = 5, -- how many urls to check + max_size = 10 * 1024, -- maximum body to process + user_agent = default_ua, + redirector_symbol = nil, -- insert symbol if redirected url has been found + redirector_symbol_nested = "URL_REDIRECTOR_NESTED", -- insert symbol if nested limit has been reached + redirectors_only = true, -- follow merely redirectors + top_urls_key = 'rdr:top_urls', -- key for top urls + top_urls_count = 200, -- how many top urls to save + redirector_hosts_map = nil -- check only those redirectors +} + +local function adjust_url(task, orig_url, redir_url) + local mempool = task:get_mempool() + if type(redir_url) == 'string' then + redir_url = rspamd_url.create(mempool, redir_url, { 'redirect_target' }) + end + + if redir_url then + 
orig_url:set_redirected(redir_url, mempool) + task:inject_url(redir_url) + if settings.redirector_symbol then + task:insert_result(settings.redirector_symbol, 1.0, + string.format('%s->%s', orig_url:get_host(), redir_url:get_host())) + end + else + rspamd_logger.infox(task, 'bad url %s as redirection for %s', redir_url, orig_url) + end +end + +local function cache_url(task, orig_url, url, key, prefix) + -- String representation + local str_orig_url = tostring(orig_url) + local str_url = tostring(url) + + if str_url ~= str_orig_url then + -- Set redirected url + adjust_url(task, orig_url, url) + end + + local function redis_trim_cb(err, _) + if err then + rspamd_logger.errx(task, 'got error while getting top urls count: %s', err) + else + rspamd_logger.infox(task, 'trimmed url set to %s elements', + settings.top_urls_count) + end + end + + -- Cleanup logic + local function redis_card_cb(err, data) + if err then + rspamd_logger.errx(task, 'got error while getting top urls count: %s', err) + else + if data then + if tonumber(data) > settings.top_urls_count * 2 then + local ret = lua_redis.redis_make_request(task, + redis_params, -- connect params + key, -- hash key + true, -- is write + redis_trim_cb, --callback + 'ZREMRANGEBYRANK', -- command + { settings.top_urls_key, '0', + tostring(-(settings.top_urls_count + 1)) } -- arguments + ) + if not ret then + rspamd_logger.errx(task, 'cannot trim top urls set') + else + rspamd_logger.infox(task, 'need to trim urls set from %s to %s elements', + data, + settings.top_urls_count) + return + end + end + end + end + end + + local function redis_set_cb(err, _) + if err then + rspamd_logger.errx(task, 'got error while setting redirect keys: %s', err) + else + local ret = lua_redis.redis_make_request(task, + redis_params, -- connect params + key, -- hash key + false, -- is write + redis_card_cb, --callback + 'ZCARD', -- command + { settings.top_urls_key } -- arguments + ) + if not ret then + rspamd_logger.errx(task, 'cannot make 
redis request to cache results') + end + end + end + + if prefix then + -- Save url with prefix + str_url = string.format('^%s:%s', prefix, str_url) + end + local ret, conn, _ = lua_redis.redis_make_request(task, + redis_params, -- connect params + key, -- hash key + true, -- is write + redis_set_cb, --callback + 'SETEX', -- command + { key, tostring(settings.expire), str_url } -- arguments + ) + + if not ret then + rspamd_logger.errx(task, 'cannot make redis request to cache results') + else + conn:add_cmd('ZINCRBY', { settings.top_urls_key, '1', str_url }) + end +end + +-- Reduce length of a string to a given length (16 by default) +local function maybe_trim_url(url, limit) + if not limit then + limit = 16 + end + if #url > limit then + return string.sub(url, 1, limit) .. '...' + else + return url + end +end + +-- Resolve maybe cached url +-- Orig url is the original url object +-- url should be a new url object... +local function resolve_cached(task, orig_url, url, key, ntries) + local str_url = tostring(url or "") + local function resolve_url() + if ntries > settings.nested_limit then + -- We cannot resolve more, stop + rspamd_logger.debugm(N, task, 'cannot get more requests to resolve %s, stop on %s after %s attempts', + orig_url, url, ntries) + cache_url(task, orig_url, url, key, 'nested') + local str_orig_url = tostring(orig_url) + task:insert_result(settings.redirector_symbol_nested, 1.0, + string.format('%s->%s:%d', maybe_trim_url(str_orig_url), maybe_trim_url(str_url), ntries)) + + return + end + + local redirection_codes = { + [301] = true, -- moved permanently + [302] = true, -- found + [303] = true, -- see other + [307] = true, -- temporary redirect + [308] = true, -- permanent redirect + } + + local function http_callback(err, code, _, headers) + if err then + rspamd_logger.infox(task, 'found redirect error from %s to %s, err message: %s', + orig_url, url, err) + cache_url(task, orig_url, url, key) + else + if code == 200 then + if orig_url == url 
then + rspamd_logger.infox(task, 'direct url %s, err code 200', + url) + else + rspamd_logger.infox(task, 'found redirect from %s to %s, err code 200', + orig_url, url) + end + + cache_url(task, orig_url, url, key) + + elseif redirection_codes[code] then + local loc = headers['location'] + local redir_url + if loc then + redir_url = rspamd_url.create(task:get_mempool(), loc) + end + rspamd_logger.debugm(N, task, 'found redirect from %s to %s, err code %s', + orig_url, loc, code) + + if redir_url then + if settings.redirectors_only then + if settings.redirector_hosts_map:get_key(redir_url:get_host()) then + resolve_cached(task, orig_url, redir_url, key, ntries + 1) + else + lua_util.debugm(N, task, + "stop resolving redirects as %s is not a redirector", loc) + cache_url(task, orig_url, redir_url, key) + end + else + resolve_cached(task, orig_url, redir_url, key, ntries + 1) + end + else + rspamd_logger.debugm(N, task, "no location, headers: %s", headers) + cache_url(task, orig_url, url, key) + end + else + rspamd_logger.debugm(N, task, 'found redirect error from %s to %s, err code: %s', + orig_url, url, code) + cache_url(task, orig_url, url, key) + end + end + end + + local ua + if type(settings.user_agent) == 'string' then + ua = settings.user_agent + else + ua = settings.user_agent[math.random(#settings.user_agent)] + end + + lua_util.debugm(N, task, 'select user agent %s', ua) + + rspamd_http.request { + headers = { + ['User-Agent'] = ua, + }, + url = str_url, + task = task, + method = 'head', + max_size = settings.max_size, + timeout = settings.timeout, + opaque_body = true, + no_ssl_verify = not settings.check_ssl, + callback = http_callback + } + end + local function redis_get_cb(err, data) + if not err then + if type(data) == 'string' then + if data ~= 'processing' then + -- Got cached result + rspamd_logger.debugm(N, task, 'found cached redirect from %s to %s', + url, data) + if data:sub(1, 1) == '^' then + -- Prefixed url stored + local prefix, new_url = 
data:match('^%^(%a+):(.+)$') + if prefix == 'nested' then + task:insert_result(settings.redirector_symbol_nested, 1.0, + string.format('%s->%s:cached', maybe_trim_url(str_url), maybe_trim_url(new_url))) + end + data = new_url + end + if data ~= tostring(orig_url) then + adjust_url(task, orig_url, data) + end + return + end + end + end + local function redis_reserve_cb(nerr, ndata) + if nerr then + rspamd_logger.errx(task, 'got error while setting redirect keys: %s', nerr) + elseif ndata == 'OK' then + resolve_url() + end + end + + if ntries == 1 then + -- Reserve key in Redis that we are processing this redirection + local ret = lua_redis.redis_make_request(task, + redis_params, -- connect params + key, -- hash key + true, -- is write + redis_reserve_cb, --callback + 'SET', -- command + { key, 'processing', 'EX', tostring(settings.timeout * 2), 'NX' } -- arguments + ) + if not ret then + rspamd_logger.errx(task, 'Couldn\'t schedule SET') + end + else + -- Just continue resolving + resolve_url() + end + + end + local ret = lua_redis.redis_make_request(task, + redis_params, -- connect params + key, -- hash key + false, -- is write + redis_get_cb, --callback + 'GET', -- command + { key } -- arguments + ) + if not ret then + rspamd_logger.errx(task, 'cannot make redis request to check results') + end +end + +local function url_redirector_process_url(task, url) + local url_str = url:get_raw() + -- 32 base32 characters are roughly 20 bytes of data or 160 bits + local key = settings.key_prefix .. 
hash.create(url_str):base32():sub(1, 32) + resolve_cached(task, url, url, key, 1) +end + +local function url_redirector_handler(task) + local sp_urls = lua_util.extract_specific_urls({ + task = task, + limit = settings.max_urls, + filter = function(url) + local host = url:get_host() + if settings.redirector_hosts_map:get_key(host) then + lua_util.debugm(N, task, 'check url %s', tostring(url)) + return true + end + end, + no_cache = true, + need_content = true, + }) + + if sp_urls then + for _, u in ipairs(sp_urls) do + url_redirector_process_url(task, u) + end + end +end + +local opts = rspamd_config:get_all_opt('url_redirector') +if opts then + settings = lua_util.override_defaults(settings, opts) + redis_params = lua_redis.parse_redis_server('url_redirector', settings) + + if not redis_params then + rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module') + lua_util.disable_module(N, "redis") + else + + if not settings.redirector_hosts_map then + rspamd_logger.infox(rspamd_config, 'no redirector_hosts_map option is specified, disabling module') + lua_util.disable_module(N, "config") + else + local lua_maps = require "lua_maps" + settings.redirector_hosts_map = lua_maps.map_add_from_ucl(settings.redirector_hosts_map, + 'set', 'Redirectors definitions') + + lua_redis.register_prefix(settings.key_prefix .. '[a-z0-9]{32}', N, + 'URL redirector hashes', { + type = 'string', + }) + if settings.top_urls_key then + lua_redis.register_prefix(settings.top_urls_key, N, + 'URL redirector top urls', { + type = 'zlist', + }) + end + local id = rspamd_config:register_symbol { + name = 'URL_REDIRECTOR_CHECK', + type = 'callback,prefilter', + priority = lua_util.symbols_priorities.medium, + callback = url_redirector_handler, + -- In fact, the real timeout is nested_limit * timeout... 
+ augmentations = { string.format("timeout=%f", settings.timeout) } + } + + rspamd_config:register_symbol { + name = settings.redirector_symbol_nested, + type = 'virtual', + parent = id, + score = 0, + } + + if settings.redirector_symbol then + rspamd_config:register_symbol { + name = settings.redirector_symbol, + type = 'virtual', + parent = id, + score = 0, + } + end + end + end +end diff --git a/src/plugins/lua/whitelist.lua b/src/plugins/lua/whitelist.lua new file mode 100644 index 0000000..fa76da8 --- /dev/null +++ b/src/plugins/lua/whitelist.lua @@ -0,0 +1,443 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +if confighelp then + return +end + +local rspamd_logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local fun = require "fun" +local lua_util = require "lua_util" + +local N = "whitelist" + +local options = { + dmarc_allow_symbol = 'DMARC_POLICY_ALLOW', + spf_allow_symbol = 'R_SPF_ALLOW', + dkim_allow_symbol = 'R_DKIM_ALLOW', + check_local = false, + check_authed = false, + rules = {} +} + +local E = {} + +local function whitelist_cb(symbol, rule, task) + + local domains = {} + + local function find_domain(dom, check) + local mult + local how = 'wl' + + -- Can be overridden + if rule.blacklist then + how = 'bl' + end + + local function parse_val(val) + local how_override + -- Strict is 'special' + if rule.strict then + how_override = 'both' + end + if val then + lua_util.debugm(N, task, "found whitelist key: %s=%s", dom, val) + if val == '' then + return (how_override or how), 1.0 + elseif val:match('^bl:') then + return (how_override or 'bl'), (tonumber(val:sub(4)) or 1.0) + elseif val:match('^wl:') then + return (how_override or 'wl'), (tonumber(val:sub(4)) or 1.0) + elseif val:match('^both:') then + return (how_override or 'both'), (tonumber(val:sub(6)) or 1.0) + else + return (how_override or how), (tonumber(val) or 1.0) + end + end + + return (how_override or how), 1.0 + end + + if rule['map'] then + local val = rule['map']:get_key(dom) + if val then + how, mult = parse_val(val) + + if not domains[check] then + domains[check] = {} + end + + domains[check] = { + [dom] = { how, mult } + } + + lua_util.debugm(N, task, "final result: %s: %s->%s", + dom, how, mult) + return true, mult, how + end + elseif rule['maps'] then + for _, v in pairs(rule['maps']) do + local map = v.map + if map then + local val = map:get_key(dom) + if val then + how, mult = parse_val(val) + + if not domains[check] then + domains[check] = {} + end + + domains[check] = { + [dom] = { how, mult } + } + + lua_util.debugm(N, task, "final result: %s: %s->%s", + 
dom, how, mult) + return true, mult, how + end + end + end + else + mult = rule['domains'][dom] + if mult then + if not domains[check] then + domains[check] = {} + end + + domains[check] = { + [dom] = { how, mult } + } + + return true, mult, how + end + end + + return false, 0.0, how + end + + local spf_violated = false + local dmarc_violated = false + local dkim_violated = false + local ip_addr = task:get_ip() + + if rule.valid_spf then + if not task:has_symbol(options['spf_allow_symbol']) then + -- Not whitelisted + spf_violated = true + end + -- Now we can check from domain or helo + local from = task:get_from(1) + + if ((from or E)[1] or E).domain then + local tld = rspamd_util.get_tld(from[1]['domain']) + + if tld then + find_domain(tld, 'spf') + end + else + local helo = task:get_helo() + + if helo then + local tld = rspamd_util.get_tld(helo) + + if tld then + find_domain(tld, 'spf') + end + end + end + end + + if rule.valid_dkim then + if task:has_symbol('DKIM_TRACE') then + local sym = task:get_symbol('DKIM_TRACE') + local dkim_opts = sym[1]['options'] + if dkim_opts then + fun.each(function(val) + if val[2] == '+' then + local tld = rspamd_util.get_tld(val[1]) + find_domain(tld, 'dkim_success') + elseif val[2] == '-' then + local tld = rspamd_util.get_tld(val[1]) + find_domain(tld, 'dkim_fail') + end + end, + fun.map(function(s) + return lua_util.rspamd_str_split(s, ':') + end, dkim_opts)) + end + end + end + + if rule.valid_dmarc then + if not task:has_symbol(options.dmarc_allow_symbol) then + dmarc_violated = true + end + + local from = task:get_from(2) + + if ((from or E)[1] or E).domain then + local tld = rspamd_util.get_tld(from[1]['domain']) + + if tld then + local found = find_domain(tld, 'dmarc') + if not found then + find_domain(from[1]['domain'], 'dmarc') + end + end + end + end + + local final_mult = 1.0 + local found_wl, found_bl = false, false + local opts = {} + + if rule.valid_dkim then + dkim_violated = true + + for dom, val in 
pairs(domains.dkim_success or E) do + if val[1] == 'wl' or val[1] == 'both' then + -- We have valid and whitelisted signature + table.insert(opts, dom .. ':d:+') + found_wl = true + dkim_violated = false + + if not found_bl then + final_mult = val[2] + end + end + end + + -- Blacklist counterpart + for dom, val in pairs(domains.dkim_fail or E) do + if val[1] == 'bl' or val[1] == 'both' then + -- We have valid and whitelisted signature + table.insert(opts, dom .. ':d:-') + found_bl = true + final_mult = val[2] + else + -- Even in the case of whitelisting we need to indicate dkim failure + dkim_violated = true + end + end + end + + local function check_domain_violation(what, dom, val, violated) + if violated then + if val[1] == 'both' or val[1] == 'bl' then + found_bl = true + final_mult = val[2] + table.insert(opts, string.format("%s:%s:-", dom, what)) + end + else + if val[1] == 'both' or val[1] == 'wl' then + found_wl = true + table.insert(opts, string.format("%s:%s:+", dom, what)) + if not found_bl then + final_mult = val[2] + end + end + end + end + + if rule.valid_dmarc then + + found_wl = false + + for dom, val in pairs(domains.dmarc or E) do + check_domain_violation('D', dom, val, + (dmarc_violated or dkim_violated)) + end + end + + if rule.valid_spf then + found_wl = false + + for dom, val in pairs(domains.spf or E) do + check_domain_violation('s', dom, val, + (spf_violated or dkim_violated)) + end + end + + lua_util.debugm(N, task, "final mult: %s", final_mult) + + local function add_symbol(violated, mult) + local sym = symbol + + if violated then + if rule.inverse_symbol then + sym = rule.inverse_symbol + elseif not rule.blacklist then + mult = -mult + end + + if rule.inverse_multiplier then + mult = mult * rule.inverse_multiplier + end + + task:insert_result(sym, mult, opts) + else + task:insert_result(sym, mult, opts) + end + end + + if found_bl then + if not ((not options.check_authed and task:get_user()) or + (not options.check_local and ip_addr and 
ip_addr:is_local())) then + add_symbol(true, final_mult) + else + if rule.valid_spf or rule.valid_dmarc then + rspamd_logger.infox(task, "skip DMARC/SPF blacklists for local networks and/or authorized users") + else + add_symbol(true, final_mult) + end + end + elseif found_wl then + add_symbol(false, final_mult) + end + +end + +local function gen_whitelist_cb(symbol, rule) + return function(task) + whitelist_cb(symbol, rule, task) + end +end + +local configure_whitelist_module = function() + local opts = rspamd_config:get_all_opt('whitelist') + if opts then + for k, v in pairs(opts) do + options[k] = v + end + + local auth_and_local_conf = lua_util.config_check_local_or_authed(rspamd_config, N, + false, false) + options.check_local = auth_and_local_conf[1] + options.check_authed = auth_and_local_conf[2] + else + rspamd_logger.infox(rspamd_config, 'Module is unconfigured') + return + end + + if options['rules'] then + fun.each(function(symbol, rule) + if rule['domains'] then + if type(rule['domains']) == 'string' then + rule['map'] = rspamd_config:add_map { + url = rule['domains'], + description = "Whitelist map for " .. symbol, + type = 'map' + } + elseif type(rule['domains']) == 'table' then + -- Transform ['domain1', 'domain2' ...] to indexes: + -- {'domain1' = 1, 'domain2' = 1 ...] + local is_domains_list = fun.all(function(v) + if type(v) == 'table' then + return true + elseif type(v) == 'string' and not (string.match(v, '^https?://') or + string.match(v, '^ftp://') or string.match(v, '^[./]')) then + return true + end + + return false + end, rule.domains) + + if is_domains_list then + rule['domains'] = fun.tomap(fun.map(function(d) + if type(d) == 'table' then + return d[1], d[2] + end + + return d, 1.0 + end, rule['domains'])) + else + rule['map'] = rspamd_config:add_map { + url = rule['domains'], + description = "Whitelist map for " .. 
symbol, + type = 'map' + } + end + else + rspamd_logger.errx(rspamd_config, 'whitelist %s has bad "domains" value', + symbol) + return + end + + local flags = 'nice,empty' + if rule['blacklist'] then + flags = 'empty' + end + + local id = rspamd_config:register_symbol({ + name = symbol, + flags = flags, + callback = gen_whitelist_cb(symbol, rule), + score = rule.score or 0, + }) + + if rule.inverse_symbol then + rspamd_config:register_symbol({ + name = rule.inverse_symbol, + type = 'virtual', + parent = id, + score = rule.score and -(rule.score) or 0, + }) + end + + local spf_dep = false + local dkim_dep = false + if rule['valid_spf'] then + rspamd_config:register_dependency(symbol, options['spf_allow_symbol']) + spf_dep = true + end + if rule['valid_dkim'] then + rspamd_config:register_dependency(symbol, options['dkim_allow_symbol']) + dkim_dep = true + end + if rule['valid_dmarc'] then + if not spf_dep then + rspamd_config:register_dependency(symbol, options['spf_allow_symbol']) + end + if not dkim_dep then + rspamd_config:register_dependency(symbol, options['dkim_allow_symbol']) + end + rspamd_config:register_dependency(symbol, 'DMARC_CALLBACK') + end + + if rule['score'] then + if not rule['group'] then + rule['group'] = 'whitelist' + end + rule['name'] = symbol + rspamd_config:set_metric_symbol(rule) + + if rule.inverse_symbol then + local inv_rule = lua_util.shallowcopy(rule) + inv_rule.name = rule.inverse_symbol + inv_rule.score = -rule.score + rspamd_config:set_metric_symbol(inv_rule) + end + end + end + end, options['rules']) + else + lua_util.disable_module(N, "config") + end +end + +configure_whitelist_module() diff --git a/src/plugins/regexp.c b/src/plugins/regexp.c new file mode 100644 index 0000000..59a84c5 --- /dev/null +++ b/src/plugins/regexp.c @@ -0,0 +1,564 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/***MODULE:regexp + * rspamd module that implements different regexp rules + */ + + +#include "config.h" +#include "libmime/message.h" +#include "expression.h" +#include "mime_expressions.h" +#include "libserver/maps/map.h" +#include "lua/lua_common.h" + +static const guint64 rspamd_regexp_cb_magic = 0xca9d9649fc3e2659ULL; + +struct regexp_module_item { + guint64 magic; + struct rspamd_expression *expr; + const gchar *symbol; + struct ucl_lua_funcdata *lua_function; +}; + +struct regexp_ctx { + struct module_ctx ctx; + gsize max_size; +}; + +static void process_regexp_item(struct rspamd_task *task, + struct rspamd_symcache_dynamic_item *item, + void *user_data); + + +/* Initialization */ +gint regexp_module_init(struct rspamd_config *cfg, struct module_ctx **ctx); +gint regexp_module_config(struct rspamd_config *cfg, bool validate); +gint regexp_module_reconfig(struct rspamd_config *cfg); + +module_t regexp_module = { + "regexp", + regexp_module_init, + regexp_module_config, + regexp_module_reconfig, + NULL, + RSPAMD_MODULE_VER, + (guint) -1, +}; + + +static inline struct regexp_ctx * +regexp_get_context(struct rspamd_config *cfg) +{ + return (struct regexp_ctx *) g_ptr_array_index(cfg->c_modules, + regexp_module.ctx_offset); +} + +/* Process regexp expression */ +static gboolean +read_regexp_expression(rspamd_mempool_t *pool, + struct regexp_module_item *chain, + const gchar *symbol, + const gchar *line, + struct rspamd_mime_expr_ud *ud) +{ + struct rspamd_expression *e = NULL; + GError *err = NULL; + + if (!rspamd_parse_expression(line, 0, 
&mime_expr_subr, ud, pool, &err, + &e)) { + msg_warn_pool("%s = \"%s\" is invalid regexp expression: %e", symbol, + line, + err); + g_error_free(err); + + return FALSE; + } + + g_assert(e != NULL); + chain->expr = e; + + return TRUE; +} + + +/* Init function */ +gint regexp_module_init(struct rspamd_config *cfg, struct module_ctx **ctx) +{ + struct regexp_ctx *regexp_module_ctx; + + regexp_module_ctx = rspamd_mempool_alloc0(cfg->cfg_pool, + sizeof(*regexp_module_ctx)); + + *ctx = (struct module_ctx *) regexp_module_ctx; + + rspamd_rcl_add_doc_by_path(cfg, + NULL, + "Regular expressions rules plugin", + "regexp", + UCL_OBJECT, + NULL, + 0, + NULL, + 0); + + rspamd_rcl_add_doc_by_path(cfg, + "regexp", + "Maximum size of data chunk scanned with any regexp (further data is truncated)", + "max_size", + UCL_INT, + NULL, + 0, + NULL, + 0); + + return 0; +} + +gint regexp_module_config(struct rspamd_config *cfg, bool validate) +{ + struct regexp_ctx *regexp_module_ctx = regexp_get_context(cfg); + struct regexp_module_item *cur_item = NULL; + const ucl_object_t *sec, *value, *elt; + ucl_object_iter_t it = NULL; + gint res = TRUE, nre = 0, nlua = 0, nshots = cfg->default_max_shots; + + if (!rspamd_config_is_module_enabled(cfg, "regexp")) { + return TRUE; + } + + sec = ucl_object_lookup(cfg->cfg_ucl_obj, "regexp"); + if (sec == NULL) { + msg_err_config("regexp module enabled, but no rules are defined"); + return TRUE; + } + + regexp_module_ctx->max_size = 0; + + while ((value = ucl_object_iterate(sec, &it, true)) != NULL) { + if (g_ascii_strncasecmp(ucl_object_key(value), "max_size", + sizeof("max_size") - 1) == 0) { + regexp_module_ctx->max_size = ucl_obj_toint(value); + rspamd_re_cache_set_limit(cfg->re_cache, regexp_module_ctx->max_size); + } + else if (g_ascii_strncasecmp(ucl_object_key(value), "max_threads", + sizeof("max_threads") - 1) == 0) { + msg_warn_config("regexp module is now single threaded, max_threads is ignored"); + } + else if (value->type == UCL_STRING) { + 
struct rspamd_mime_expr_ud ud; + + cur_item = rspamd_mempool_alloc0(cfg->cfg_pool, + sizeof(struct regexp_module_item)); + cur_item->symbol = ucl_object_key(value); + cur_item->magic = rspamd_regexp_cb_magic; + + ud.conf_obj = NULL; + ud.cfg = cfg; + + if (!read_regexp_expression(cfg->cfg_pool, + cur_item, ucl_object_key(value), + ucl_obj_tostring(value), &ud)) { + if (validate) { + return FALSE; + } + } + else { + rspamd_symcache_add_symbol(cfg->cache, + cur_item->symbol, + 0, + process_regexp_item, + cur_item, + SYMBOL_TYPE_NORMAL, -1); + nre++; + } + } + else if (value->type == UCL_USERDATA) { + /* Just a lua function */ + cur_item = rspamd_mempool_alloc0(cfg->cfg_pool, + sizeof(struct regexp_module_item)); + cur_item->magic = rspamd_regexp_cb_magic; + cur_item->symbol = ucl_object_key(value); + cur_item->lua_function = ucl_object_toclosure(value); + + rspamd_symcache_add_symbol(cfg->cache, + cur_item->symbol, + 0, + process_regexp_item, + cur_item, + SYMBOL_TYPE_NORMAL, -1); + nlua++; + } + else if (value->type == UCL_OBJECT) { + const gchar *description = NULL, *group = NULL; + gdouble score = 0.0; + guint flags = 0, priority = 0; + gboolean is_lua = FALSE, valid_expression = TRUE; + struct rspamd_mime_expr_ud ud; + + /* We have some lua table, extract its arguments */ + elt = ucl_object_lookup(value, "callback"); + + if (elt == NULL || elt->type != UCL_USERDATA) { + + /* Try plain regexp expression */ + elt = ucl_object_lookup_any(value, "regexp", "re", NULL); + + if (elt != NULL && ucl_object_type(elt) == UCL_STRING) { + cur_item = rspamd_mempool_alloc0(cfg->cfg_pool, + sizeof(struct regexp_module_item)); + cur_item->symbol = ucl_object_key(value); + cur_item->magic = rspamd_regexp_cb_magic; + ud.cfg = cfg; + ud.conf_obj = value; + + if (!read_regexp_expression(cfg->cfg_pool, + cur_item, ucl_object_key(value), + ucl_obj_tostring(elt), &ud)) { + if (validate) { + return FALSE; + } + } + else { + valid_expression = TRUE; + nre++; + } + } + else { + 
msg_err_config( + "no callback/expression defined for regexp symbol: " + "%s", + ucl_object_key(value)); + } + } + else { + is_lua = TRUE; + nlua++; + cur_item = rspamd_mempool_alloc0( + cfg->cfg_pool, + sizeof(struct regexp_module_item)); + cur_item->magic = rspamd_regexp_cb_magic; + cur_item->symbol = ucl_object_key(value); + cur_item->lua_function = ucl_object_toclosure(value); + } + + if (cur_item && (is_lua || valid_expression)) { + + flags = SYMBOL_TYPE_NORMAL; + elt = ucl_object_lookup(value, "mime_only"); + + if (elt) { + if (ucl_object_type(elt) != UCL_BOOLEAN) { + msg_err_config( + "mime_only attribute is not boolean for symbol: '%s'", + cur_item->symbol); + + if (validate) { + return FALSE; + } + } + else { + if (ucl_object_toboolean(elt)) { + flags |= SYMBOL_TYPE_MIME_ONLY; + } + } + } + + rspamd_symcache_add_symbol(cfg->cache, + cur_item->symbol, + 0, + process_regexp_item, + cur_item, + flags, -1); + + /* Reset flags */ + flags = 0; + + elt = ucl_object_lookup(value, "condition"); + + if (elt != NULL && ucl_object_type(elt) == UCL_USERDATA) { + struct ucl_lua_funcdata *conddata; + + g_assert(cur_item->symbol != NULL); + conddata = ucl_object_toclosure(elt); + rspamd_symcache_add_condition_delayed(cfg->cache, + cur_item->symbol, + conddata->L, conddata->idx); + } + + elt = ucl_object_lookup(value, "description"); + + if (elt) { + description = ucl_object_tostring(elt); + } + + elt = ucl_object_lookup(value, "group"); + + if (elt) { + group = ucl_object_tostring(elt); + } + + elt = ucl_object_lookup(value, "score"); + + if (elt) { + if (ucl_object_type(elt) != UCL_FLOAT && ucl_object_type(elt) != UCL_INT) { + msg_err_config( + "score attribute is not numeric for symbol: '%s'", + cur_item->symbol); + + if (validate) { + return FALSE; + } + } + else { + score = ucl_object_todouble(elt); + } + } + + elt = ucl_object_lookup(value, "one_shot"); + + if (elt) { + if (ucl_object_type(elt) != UCL_BOOLEAN) { + msg_err_config( + "one_shot attribute is not boolean 
for symbol: '%s'", + cur_item->symbol); + + if (validate) { + return FALSE; + } + } + else { + if (ucl_object_toboolean(elt)) { + nshots = 1; + } + } + } + + if ((elt = ucl_object_lookup(value, "any_shot")) != NULL) { + if (ucl_object_type(elt) != UCL_BOOLEAN) { + msg_err_config( + "any_shot attribute is not boolean for symbol: '%s'", + cur_item->symbol); + + if (validate) { + return FALSE; + } + } + else { + if (ucl_object_toboolean(elt)) { + nshots = -1; + } + } + } + + if ((elt = ucl_object_lookup(value, "nshots")) != NULL) { + if (ucl_object_type(elt) != UCL_FLOAT && ucl_object_type(elt) != UCL_INT) { + msg_err_config( + "nshots attribute is not numeric for symbol: '%s'", + cur_item->symbol); + + if (validate) { + return FALSE; + } + } + else { + nshots = ucl_object_toint(elt); + } + } + + elt = ucl_object_lookup(value, "one_param"); + + if (elt) { + if (ucl_object_type(elt) != UCL_BOOLEAN) { + msg_err_config( + "one_param attribute is not boolean for symbol: '%s'", + cur_item->symbol); + + if (validate) { + return FALSE; + } + } + else { + if (ucl_object_toboolean(elt)) { + flags |= RSPAMD_SYMBOL_FLAG_ONEPARAM; + } + } + } + + elt = ucl_object_lookup(value, "priority"); + + if (elt) { + if (ucl_object_type(elt) != UCL_FLOAT && ucl_object_type(elt) != UCL_INT) { + msg_err_config( + "priority attribute is not numeric for symbol: '%s'", + cur_item->symbol); + + if (validate) { + return FALSE; + } + } + else { + priority = ucl_object_toint(elt); + } + } + else { + priority = 0; + } + + rspamd_config_add_symbol(cfg, cur_item->symbol, + score, description, group, flags, priority, nshots); + + elt = ucl_object_lookup(value, "groups"); + + if (elt) { + ucl_object_iter_t gr_it; + const ucl_object_t *cur_gr; + + gr_it = ucl_object_iterate_new(elt); + + while ((cur_gr = ucl_object_iterate_safe(gr_it, true)) != NULL) { + rspamd_config_add_symbol_group(cfg, cur_item->symbol, + ucl_object_tostring(cur_gr)); + } + + ucl_object_iterate_free(gr_it); + } + } + } + else { + 
msg_warn_config("unknown type of attribute %s for regexp module", + ucl_object_key(value)); + } + } + + if (res) { + msg_info_config("init internal regexp module, %d regexp rules and %d " + "lua rules are loaded", + nre, nlua); + } + else { + msg_err_config("fatal regexp module error"); + } + + return res; +} + +gint regexp_module_reconfig(struct rspamd_config *cfg) +{ + return regexp_module_config(cfg, false); +} + +static gboolean +rspamd_lua_call_expression_func(struct ucl_lua_funcdata *lua_data, + struct rspamd_task *task, + GArray *args, gdouble *res, + const gchar *symbol) +{ + lua_State *L = lua_data->L; + struct rspamd_task **ptask; + struct expression_argument *arg; + gint pop = 0, i, nargs = 0; + + lua_rawgeti(L, LUA_REGISTRYINDEX, lua_data->idx); + /* Now we got function in top of stack */ + ptask = lua_newuserdata(L, sizeof(struct rspamd_task *)); + rspamd_lua_setclass(L, "rspamd{task}", -1); + *ptask = task; + + /* Now push all arguments */ + if (args) { + for (i = 0; i < (gint) args->len; i++) { + arg = &g_array_index(args, struct expression_argument, i); + if (arg) { + switch (arg->type) { + case EXPRESSION_ARGUMENT_NORMAL: + lua_pushstring(L, (const gchar *) arg->data); + break; + case EXPRESSION_ARGUMENT_BOOL: + lua_pushboolean(L, (gboolean) GPOINTER_TO_SIZE(arg->data)); + break; + default: + msg_err_task("%s: cannot pass custom params to lua function", + symbol); + return FALSE; + } + } + } + nargs = args->len; + } + + if (lua_pcall(L, nargs + 1, 1, 0) != 0) { + msg_info_task("%s: call to lua function failed: %s", symbol, + lua_tostring(L, -1)); + lua_pop(L, 1); + + return FALSE; + } + + pop++; + + if (lua_type(L, -1) == LUA_TNUMBER) { + *res = lua_tonumber(L, -1); + } + else if (lua_type(L, -1) == LUA_TBOOLEAN) { + *res = lua_toboolean(L, -1); + } + else { + msg_info_task("%s: lua function must return a boolean", symbol); + *res = FALSE; + } + + lua_pop(L, pop); + + return TRUE; +} + + +static void +process_regexp_item(struct rspamd_task *task, + 
struct rspamd_symcache_dynamic_item *symcache_item, + void *user_data) +{ + struct regexp_module_item *item = user_data; + gdouble res = FALSE; + + /* Non-threaded version */ + if (item->lua_function) { + /* Just call function */ + res = FALSE; + if (!rspamd_lua_call_expression_func(item->lua_function, task, NULL, + &res, item->symbol)) { + msg_err_task("error occurred when checking symbol %s", + item->symbol); + } + } + else { + /* Process expression */ + if (item->expr) { + res = rspamd_process_expression(item->expr, 0, task); + } + else { + msg_warn_task("FIXME: %s symbol is broken with new expressions", + item->symbol); + } + } + + if (res != 0) { + rspamd_task_insert_result(task, item->symbol, res, NULL); + } + + rspamd_symcache_finalize_item(task, symcache_item); +} |