diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-10 21:30:40 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-10 21:30:40 +0000 |
commit | 133a45c109da5310add55824db21af5239951f93 (patch) | |
tree | ba6ac4c0a950a0dda56451944315d66409923918 /src/libstat | |
parent | Initial commit. (diff) | |
download | rspamd-133a45c109da5310add55824db21af5239951f93.tar.xz rspamd-133a45c109da5310add55824db21af5239951f93.zip |
Adding upstream version 3.8.1.upstream/3.8.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/libstat')
-rw-r--r-- | src/libstat/CMakeLists.txt | 25 | ||||
-rw-r--r-- | src/libstat/backends/backends.h | 127 | ||||
-rw-r--r-- | src/libstat/backends/cdb_backend.cxx | 491 | ||||
-rw-r--r-- | src/libstat/backends/http_backend.cxx | 440 | ||||
-rw-r--r-- | src/libstat/backends/mmaped_file.c | 1113 | ||||
-rw-r--r-- | src/libstat/backends/redis_backend.cxx | 1132 | ||||
-rw-r--r-- | src/libstat/backends/sqlite3_backend.c | 907 | ||||
-rw-r--r-- | src/libstat/classifiers/bayes.c | 551 | ||||
-rw-r--r-- | src/libstat/classifiers/classifiers.h | 109 | ||||
-rw-r--r-- | src/libstat/classifiers/lua_classifier.c | 237 | ||||
-rw-r--r-- | src/libstat/learn_cache/learn_cache.h | 79 | ||||
-rw-r--r-- | src/libstat/learn_cache/redis_cache.cxx | 254 | ||||
-rw-r--r-- | src/libstat/learn_cache/sqlite3_cache.c | 274 | ||||
-rw-r--r-- | src/libstat/stat_api.h | 147 | ||||
-rw-r--r-- | src/libstat/stat_config.c | 603 | ||||
-rw-r--r-- | src/libstat/stat_internal.h | 134 | ||||
-rw-r--r-- | src/libstat/stat_process.c | 1250 | ||||
-rw-r--r-- | src/libstat/tokenizers/osb.c | 424 | ||||
-rw-r--r-- | src/libstat/tokenizers/tokenizers.c | 955 | ||||
-rw-r--r-- | src/libstat/tokenizers/tokenizers.h | 100 |
20 files changed, 9352 insertions, 0 deletions
diff --git a/src/libstat/CMakeLists.txt b/src/libstat/CMakeLists.txt new file mode 100644 index 0000000..64d572a --- /dev/null +++ b/src/libstat/CMakeLists.txt @@ -0,0 +1,25 @@ +# Librspamdserver +SET(LIBSTATSRC ${CMAKE_CURRENT_SOURCE_DIR}/stat_config.c + ${CMAKE_CURRENT_SOURCE_DIR}/stat_process.c) + +SET(TOKENIZERSSRC ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/tokenizers.c + ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/osb.c) + +SET(CLASSIFIERSSRC ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/bayes.c + ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/lua_classifier.c) + +SET(BACKENDSSRC ${CMAKE_CURRENT_SOURCE_DIR}/backends/mmaped_file.c + ${CMAKE_CURRENT_SOURCE_DIR}/backends/sqlite3_backend.c + ${CMAKE_CURRENT_SOURCE_DIR}/backends/cdb_backend.cxx + ${CMAKE_CURRENT_SOURCE_DIR}/backends/http_backend.cxx + ${CMAKE_CURRENT_SOURCE_DIR}/backends/redis_backend.cxx) + +SET(CACHESSRC ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/sqlite3_cache.c + ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/redis_cache.cxx) + +SET(RSPAMD_STAT ${LIBSTATSRC} + ${TOKENIZERSSRC} + ${CLASSIFIERSSRC} + ${BACKENDSSRC} + ${CACHESSRC} PARENT_SCOPE) + diff --git a/src/libstat/backends/backends.h b/src/libstat/backends/backends.h new file mode 100644 index 0000000..4b16950 --- /dev/null +++ b/src/libstat/backends/backends.h @@ -0,0 +1,127 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef BACKENDS_H_ +#define BACKENDS_H_ + +#include "config.h" +#include "ucl.h" + +#define RSPAMD_DEFAULT_BACKEND "mmap" + +#ifdef __cplusplus +extern "C" { +#endif + +/* Forwarded declarations */ +struct rspamd_classifier_config; +struct rspamd_statfile_config; +struct rspamd_config; +struct rspamd_stat_ctx; +struct rspamd_token_result; +struct rspamd_statfile; +struct rspamd_task; + +struct rspamd_stat_backend { + const char *name; + bool read_only; + + gpointer (*init)(struct rspamd_stat_ctx *ctx, struct rspamd_config *cfg, + struct rspamd_statfile *st); + + gpointer (*runtime)(struct rspamd_task *task, + struct rspamd_statfile_config *stcf, + gboolean learn, gpointer ctx, + gint id); + + gboolean (*process_tokens)(struct rspamd_task *task, GPtrArray *tokens, + gint id, + gpointer ctx); + + gboolean (*finalize_process)(struct rspamd_task *task, + gpointer runtime, gpointer ctx); + + gboolean (*learn_tokens)(struct rspamd_task *task, GPtrArray *tokens, + gint id, + gpointer ctx); + + gulong (*total_learns)(struct rspamd_task *task, + gpointer runtime, gpointer ctx); + + gboolean (*finalize_learn)(struct rspamd_task *task, + gpointer runtime, gpointer ctx, GError **err); + + gulong (*inc_learns)(struct rspamd_task *task, + gpointer runtime, gpointer ctx); + + gulong (*dec_learns)(struct rspamd_task *task, + gpointer runtime, gpointer ctx); + + ucl_object_t *(*get_stat)(gpointer runtime, gpointer ctx); + + void (*close)(gpointer ctx); + + gpointer (*load_tokenizer_config)(gpointer runtime, gsize *sz); + + gpointer ctx; +}; + +#define RSPAMD_STAT_BACKEND_DEF(name) \ + gpointer rspamd_##name##_init(struct rspamd_stat_ctx *ctx, \ + struct rspamd_config *cfg, struct rspamd_statfile *st); \ + gpointer rspamd_##name##_runtime(struct rspamd_task *task, \ + struct rspamd_statfile_config *stcf, \ + gboolean learn, gpointer ctx, gint id); \ + gboolean rspamd_##name##_process_tokens(struct rspamd_task *task, \ + GPtrArray *tokens, gint id, \ + gpointer runtime); \ + gboolean rspamd_##name##_finalize_process(struct rspamd_task *task, \ + gpointer runtime, \ + gpointer ctx); \ + gboolean rspamd_##name##_learn_tokens(struct rspamd_task *task, \ + GPtrArray *tokens, gint id, \ + gpointer runtime); \ + gboolean rspamd_##name##_finalize_learn(struct rspamd_task *task, \ + gpointer runtime, \ + gpointer ctx, GError **err); \ + gulong rspamd_##name##_total_learns(struct rspamd_task *task, \ + gpointer runtime, \ + gpointer ctx); \ + gulong rspamd_##name##_inc_learns(struct rspamd_task *task, \ + gpointer runtime, \ + gpointer ctx); \ + gulong rspamd_##name##_dec_learns(struct rspamd_task *task, \ + gpointer runtime, \ + gpointer ctx); \ + gulong rspamd_##name##_learns(struct rspamd_task *task, \ + gpointer runtime, \ + gpointer ctx); \ + ucl_object_t *rspamd_##name##_get_stat(gpointer runtime, \ + gpointer ctx); \ + gpointer rspamd_##name##_load_tokenizer_config(gpointer runtime, \ + gsize *len); \ + void rspamd_##name##_close(gpointer ctx) + +RSPAMD_STAT_BACKEND_DEF(mmaped_file); +RSPAMD_STAT_BACKEND_DEF(sqlite3); +RSPAMD_STAT_BACKEND_DEF(cdb); +RSPAMD_STAT_BACKEND_DEF(redis); +RSPAMD_STAT_BACKEND_DEF(http); + +#ifdef __cplusplus +} +#endif + +#endif /* BACKENDS_H_ */ diff --git a/src/libstat/backends/cdb_backend.cxx b/src/libstat/backends/cdb_backend.cxx new file mode 100644 index 0000000..81d87f3 --- /dev/null +++ b/src/libstat/backends/cdb_backend.cxx @@ -0,0 +1,491 @@ +/*- + * Copyright 2021 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * CDB read only statistics backend + */ + +#include "config.h" +#include "stat_internal.h" +#include "contrib/cdb/cdb.h" + +#include <utility> +#include <memory> +#include <string> +#include <optional> +#include "contrib/expected/expected.hpp" +#include "contrib/ankerl/unordered_dense.h" +#include "fmt/core.h" + +namespace rspamd::stat::cdb { + +/* + * Utility class to share cdb instances over statfiles instances, as each + * cdb has tokens for both ham and spam classes + */ +class cdb_shared_storage { +public: + using cdb_element_t = std::shared_ptr<struct cdb>; + cdb_shared_storage() = default; + + auto get_cdb(const char *path) const -> std::optional<cdb_element_t> + { + auto found = elts.find(path); + + if (found != elts.end()) { + if (!found->second.expired()) { + return found->second.lock(); + } + } + + return std::nullopt; + } + /* Create a new smart pointer over POD cdb structure */ + static auto new_cdb() -> cdb_element_t + { + auto ret = cdb_element_t(new struct cdb, cdb_deleter()); + memset(ret.get(), 0, sizeof(struct cdb)); + return ret; + } + /* Enclose cdb into storage */ + auto push_cdb(const char *path, cdb_element_t cdbp) -> cdb_element_t + { + auto found = elts.find(path); + + if (found != elts.end()) { + if (found->second.expired()) { + /* OK, move in lieu of the expired weak pointer */ + + found->second = cdbp; + return cdbp; + } + else { + /* + * Existing and not expired, return the existing one + */ + return found->second.lock(); + } + } + else { + /* Not existing, make a weak ptr and return the original */ + elts.emplace(path, std::weak_ptr<struct cdb>(cdbp)); + return cdbp; + } + } + +private: + /* + * We store weak pointers here to allow owning cdb statfiles to free + * expensive cdb before this cache is terminated (e.g. on dynamic cdb reload) + */ + ankerl::unordered_dense::map<std::string, std::weak_ptr<struct cdb>> elts; + + struct cdb_deleter { + void operator()(struct cdb *c) const + { + cdb_free(c); + delete c; + } + }; +}; + +static cdb_shared_storage cdb_shared_storage; + +class ro_backend final { +public: + explicit ro_backend(struct rspamd_statfile *_st, cdb_shared_storage::cdb_element_t _db) + : st(_st), db(std::move(_db)) + { + } + ro_backend() = delete; + ro_backend(const ro_backend &) = delete; + ro_backend(ro_backend &&other) noexcept + { + *this = std::move(other); + } + ro_backend &operator=(ro_backend &&other) noexcept + { + std::swap(st, other.st); + std::swap(db, other.db); + std::swap(loaded, other.loaded); + std::swap(learns_spam, other.learns_spam); + std::swap(learns_ham, other.learns_ham); + + return *this; + } + ~ro_backend() + { + } + + auto load_cdb() -> tl::expected<bool, std::string>; + auto process_token(const rspamd_token_t *tok) const -> std::optional<float>; + constexpr auto is_spam() const -> bool + { + return st->stcf->is_spam; + } + auto get_learns() const -> std::uint64_t + { + if (is_spam()) { + return learns_spam; + } + else { + return learns_ham; + } + } + auto get_total_learns() const -> std::uint64_t + { + return learns_spam + learns_ham; + } + +private: + struct rspamd_statfile *st; + cdb_shared_storage::cdb_element_t db; + bool loaded = false; + std::uint64_t learns_spam = 0; + std::uint64_t learns_ham = 0; +}; + +template<typename T> +static inline auto +cdb_get_key_as_int64(struct cdb *cdb, T key) -> std::optional<std::int64_t> +{ + auto pos = cdb_find(cdb, (void *) &key, sizeof(key)); + + if (pos > 0) { + auto vpos = cdb_datapos(cdb); + auto vlen = cdb_datalen(cdb); + + if (vlen == sizeof(std::int64_t)) { + std::int64_t ret; + cdb_read(cdb, (void *) &ret, vlen, vpos); + + return ret; + } + } + + return std::nullopt; +} + +template<typename T> +static inline auto +cdb_get_key_as_float_pair(struct cdb *cdb, T key) -> std::optional<std::pair<float, float>> +{ + auto pos = cdb_find(cdb, (void *) &key, sizeof(key)); + + if (pos > 0) { + auto vpos = cdb_datapos(cdb); + auto vlen = cdb_datalen(cdb); + + if (vlen == sizeof(float) * 2) { + union { + struct { + float v1; + float v2; + } d; + char c[sizeof(float) * 2]; + } u; + cdb_read(cdb, (void *) u.c, vlen, vpos); + + return std::make_pair(u.d.v1, u.d.v2); + } + } + + return std::nullopt; +} + + +auto ro_backend::load_cdb() -> tl::expected<bool, std::string> +{ + if (!db) { + return tl::make_unexpected("no database loaded"); + } + + /* Now get number of learns */ + std::int64_t cdb_key; + static const char learn_spam_key[9] = "_lrnspam", learn_ham_key[9] = "_lrnham_"; + + auto check_key = [&](const char *key, std::uint64_t &target) -> tl::expected<bool, std::string> { + memcpy((void *) &cdb_key, key, sizeof(cdb_key)); + + auto maybe_value = cdb_get_key_as_int64(db.get(), cdb_key); + + if (!maybe_value) { + return tl::make_unexpected(fmt::format("missing {} key", key)); + } + + target = (std::uint64_t) maybe_value.value(); + + return true; + }; + + auto res = check_key(learn_spam_key, learns_spam); + + if (!res) { + return res; + } + + res = check_key(learn_ham_key, learns_ham); + + if (!res) { + return res; + } + + loaded = true; + + return true;// expected +} + +auto ro_backend::process_token(const rspamd_token_t *tok) const -> std::optional<float> +{ + if (!loaded) { + return std::nullopt; + } + + auto maybe_value = cdb_get_key_as_float_pair(db.get(), tok->data); + + if (maybe_value) { + auto [spam_count, ham_count] = maybe_value.value(); + + if (is_spam()) { + return spam_count; + } + else { + return ham_count; + } + } + + return std::nullopt; +} + +auto open_cdb(struct rspamd_statfile *st) -> tl::expected<ro_backend, std::string> +{ + const char *path = nullptr; + const auto *stf = st->stcf; + + auto get_filename = [](const ucl_object_t *obj) -> const char * { + const auto *filename = ucl_object_lookup_any(obj, + "filename", "path", "cdb", nullptr); + + if (filename && ucl_object_type(filename) == UCL_STRING) { + return ucl_object_tostring(filename); + } + + return nullptr; + }; + + /* First search in backend configuration */ + const auto *obj = ucl_object_lookup(st->classifier->cfg->opts, "backend"); + if (obj != NULL && ucl_object_type(obj) == UCL_OBJECT) { + path = get_filename(obj); + } + + /* Now try statfiles config */ + if (!path && stf->opts) { + path = get_filename(stf->opts); + } + + /* Now try classifier config */ + if (!path && st->classifier->cfg->opts) { + path = get_filename(st->classifier->cfg->opts); + } + + if (!path) { + return tl::make_unexpected("missing/malformed filename attribute"); + } + + auto cached_cdb_maybe = cdb_shared_storage.get_cdb(path); + cdb_shared_storage::cdb_element_t cdbp; + + if (!cached_cdb_maybe) { + + auto fd = rspamd_file_xopen(path, O_RDONLY, 0, true); + + if (fd == -1) { + return tl::make_unexpected(fmt::format("cannot open {}: {}", + path, strerror(errno))); + } + + cdbp = cdb_shared_storage::new_cdb(); + + if (cdb_init(cdbp.get(), fd) == -1) { + close(fd); + + return tl::make_unexpected(fmt::format("cannot init cdb in {}: {}", + path, strerror(errno))); + } + + cdbp = cdb_shared_storage.push_cdb(path, cdbp); + + close(fd); + } + else { + cdbp = cached_cdb_maybe.value(); + } + + if (!cdbp) { + return tl::make_unexpected(fmt::format("cannot init cdb in {}: internal error", + path)); + } + + ro_backend bk{st, std::move(cdbp)}; + + auto res = bk.load_cdb(); + + if (!res) { + return tl::make_unexpected(res.error()); + } + + return bk; +} + +}// namespace rspamd::stat::cdb + +#define CDB_FROM_RAW(p) (reinterpret_cast<rspamd::stat::cdb::ro_backend *>(p)) + +/* C exports */ +gpointer +rspamd_cdb_init(struct rspamd_stat_ctx *ctx, + struct rspamd_config *cfg, + struct rspamd_statfile *st) +{ + auto maybe_backend = rspamd::stat::cdb::open_cdb(st); + + if (maybe_backend) { + /* Move into a new pointer */ + auto *result = new rspamd::stat::cdb::ro_backend(std::move(maybe_backend.value())); + + return result; + } + else { + msg_err_config("cannot load cdb backend: %s", maybe_backend.error().c_str()); + } + + return nullptr; +} +gpointer +rspamd_cdb_runtime(struct rspamd_task *task, + struct rspamd_statfile_config *stcf, + gboolean learn, + gpointer ctx, + gint _id) +{ + /* In CDB we don't have any dynamic stuff */ + return ctx; +} + +gboolean +rspamd_cdb_process_tokens(struct rspamd_task *task, + GPtrArray *tokens, + gint id, + gpointer runtime) +{ + auto *cdbp = CDB_FROM_RAW(runtime); + bool seen_values = false; + + for (auto i = 0u; i < tokens->len; i++) { + rspamd_token_t *tok; + tok = reinterpret_cast<rspamd_token_t *>(g_ptr_array_index(tokens, i)); + + auto res = cdbp->process_token(tok); + + if (res) { + tok->values[id] = res.value(); + seen_values = true; + } + else { + tok->values[id] = 0; + } + } + + if (seen_values) { + if (cdbp->is_spam()) { + task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS; + } + else { + task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; + } + } + + return true; +} +gboolean +rspamd_cdb_finalize_process(struct rspamd_task *task, + gpointer runtime, + gpointer ctx) +{ + return true; +} +gboolean +rspamd_cdb_learn_tokens(struct rspamd_task *task, + GPtrArray *tokens, + gint id, + gpointer ctx) +{ + return false; +} +gboolean +rspamd_cdb_finalize_learn(struct rspamd_task *task, + gpointer runtime, + gpointer ctx, + GError **err) +{ + return false; +} + +gulong rspamd_cdb_total_learns(struct rspamd_task *task, + gpointer runtime, + gpointer ctx) +{ + auto *cdbp = CDB_FROM_RAW(ctx); + return cdbp->get_total_learns(); +} +gulong +rspamd_cdb_inc_learns(struct rspamd_task *task, + gpointer runtime, + gpointer ctx) +{ + return (gulong) -1; +} +gulong +rspamd_cdb_dec_learns(struct rspamd_task *task, + gpointer runtime, + gpointer ctx) +{ + return (gulong) -1; +} +gulong +rspamd_cdb_learns(struct rspamd_task *task, + gpointer runtime, + gpointer ctx) +{ + auto *cdbp = CDB_FROM_RAW(ctx); + return cdbp->get_learns(); +} +ucl_object_t * +rspamd_cdb_get_stat(gpointer runtime, gpointer ctx) +{ + return nullptr; +} +gpointer +rspamd_cdb_load_tokenizer_config(gpointer runtime, gsize *len) +{ + return nullptr; +} +void rspamd_cdb_close(gpointer ctx) +{ + auto *cdbp = CDB_FROM_RAW(ctx); + delete cdbp; +}
\ No newline at end of file diff --git a/src/libstat/backends/http_backend.cxx b/src/libstat/backends/http_backend.cxx new file mode 100644 index 0000000..075e508 --- /dev/null +++ b/src/libstat/backends/http_backend.cxx @@ -0,0 +1,440 @@ +/*- + * Copyright 2022 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "config.h" +#include "stat_internal.h" +#include "libserver/http/http_connection.h" +#include "libserver/mempool_vars_internal.h" +#include "upstream.h" +#include "contrib/ankerl/unordered_dense.h" +#include <algorithm> +#include <vector> + +namespace rspamd::stat::http { + +#define msg_debug_stat_http(...) rspamd_conditional_debug_fast(NULL, NULL, \ + rspamd_stat_http_log_id, "stat_http", task->task_pool->tag.uid, \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) + +INIT_LOG_MODULE(stat_http) + +/* Represents all http backends defined in some configuration */ +class http_backends_collection { + std::vector<struct rspamd_statfile *> backends; + double timeout = 1.0; /* Default timeout */ + struct upstream_list *read_servers = nullptr; + struct upstream_list *write_servers = nullptr; + +public: + static auto get() -> http_backends_collection & + { + static http_backends_collection *singleton = nullptr; + + if (singleton == nullptr) { + singleton = new http_backends_collection; + } + + return *singleton; + } + + /** + * Add a new backend and (optionally initialize the basic backend parameters + * @param ctx + * @param cfg + * @param st + * @return + */ + auto add_backend(struct rspamd_stat_ctx *ctx, + struct rspamd_config *cfg, + struct rspamd_statfile *st) -> bool; + /** + * Remove a statfile cleaning things up if the last statfile is removed + * @param st + * @return + */ + auto remove_backend(struct rspamd_statfile *st) -> bool; + + upstream *get_upstream(bool is_learn); + +private: + http_backends_collection() = default; + auto first_init(struct rspamd_stat_ctx *ctx, + struct rspamd_config *cfg, + struct rspamd_statfile *st) -> bool; +}; + +/* + * Created one per each task + */ +class http_backend_runtime final { +public: + static auto create(struct rspamd_task *task, bool is_learn) -> http_backend_runtime *; + /* Add a new statfile with a specific id to the list of statfiles */ + auto notice_statfile(int id, const struct rspamd_statfile_config *st) -> void + { + seen_statfiles[id] = st; + } + + auto process_tokens(struct rspamd_task *task, + GPtrArray *tokens, + gint id, + bool learn) -> bool; + +private: + http_backends_collection *all_backends; + ankerl::unordered_dense::map<int, const struct rspamd_statfile_config *> seen_statfiles; + struct upstream *selected; + +private: + http_backend_runtime(struct rspamd_task *task, bool is_learn) + : all_backends(&http_backends_collection::get()) + { + selected = all_backends->get_upstream(is_learn); + } + ~http_backend_runtime() = default; + static auto dtor(void *p) -> void + { + ((http_backend_runtime *) p)->~http_backend_runtime(); + } +}; + +/* + * Efficient way to make a messagepack payload from stat tokens, + * avoiding any intermediate libraries, as we would send many tokens + * all together + */ +static auto +stat_tokens_to_msgpack(GPtrArray *tokens) -> std::vector<std::uint8_t> +{ + std::vector<std::uint8_t> ret; + rspamd_token_t *cur; + int i; + + /* + * We define array, it's size and N elements each is uint64_t + * Layout: + * 0xdd - array marker + * [4 bytes be] - size of the array + * [ 0xcf + <8 bytes BE integer>] * N - array elements + */ + ret.resize(tokens->len * (sizeof(std::uint64_t) + 1) + 5); + ret.push_back('\xdd'); + std::uint32_t ulen = GUINT32_TO_BE(tokens->len); + std::copy((const std::uint8_t *) &ulen, + ((const std::uint8_t *) &ulen) + sizeof(ulen), std::back_inserter(ret)); + + PTR_ARRAY_FOREACH(tokens, i, cur) + { + ret.push_back('\xcf'); + std::uint64_t val = GUINT64_TO_BE(cur->data); + std::copy((const std::uint8_t *) &val, + ((const std::uint8_t *) &val) + sizeof(val), std::back_inserter(ret)); + } + + return ret; +} + +auto http_backend_runtime::create(struct rspamd_task *task, bool is_learn) -> http_backend_runtime * +{ + /* Alloc type provide proper size and alignment */ + auto *allocated_runtime = rspamd_mempool_alloc_type(task->task_pool, http_backend_runtime); + + rspamd_mempool_add_destructor(task->task_pool, http_backend_runtime::dtor, allocated_runtime); + + return new (allocated_runtime) http_backend_runtime{task, is_learn}; +} + +auto http_backend_runtime::process_tokens(struct rspamd_task *task, GPtrArray *tokens, gint id, bool learn) -> bool +{ + if (!learn) { + if (id == seen_statfiles.size() - 1) { + /* Emit http request on the last statfile */ + } + } + else { + /* On learn we need to learn all statfiles that we were requested to learn */ + if (seen_statfiles.empty()) { + /* Request has been already set, or nothing to learn */ + return true; + } + else { + seen_statfiles.clear(); + } + } + + return true; +} + +auto http_backends_collection::add_backend(struct rspamd_stat_ctx *ctx, + struct rspamd_config *cfg, + struct rspamd_statfile *st) -> bool +{ + /* On empty list of backends we know that we need to load backend data actually */ + if (backends.empty()) { + if (!first_init(ctx, cfg, st)) { + return false; + } + } + + backends.push_back(st); + + return true; +} + +auto http_backends_collection::first_init(struct rspamd_stat_ctx *ctx, + struct rspamd_config *cfg, + struct rspamd_statfile *st) -> bool +{ + auto try_load_backend_config = [&](const ucl_object_t *obj) -> bool { + if (!obj || ucl_object_type(obj) != UCL_OBJECT) { + return false; + } + + /* First try to load read servers */ + auto *rs = ucl_object_lookup_any(obj, "read_servers", "servers", nullptr); + if (rs) { + read_servers = rspamd_upstreams_create(cfg->ups_ctx); + + if (read_servers == nullptr) { + return false; + } + + if (!rspamd_upstreams_from_ucl(read_servers, rs, 80, this)) { + rspamd_upstreams_destroy(read_servers); + return false; + } + } + auto *ws = ucl_object_lookup_any(obj, "write_servers", "servers", nullptr); + if (ws) { + write_servers = rspamd_upstreams_create(cfg->ups_ctx); + + if (write_servers == nullptr) { + return false; + } + + if (!rspamd_upstreams_from_ucl(write_servers, rs, 80, this)) { + rspamd_upstreams_destroy(write_servers); + return false; + } + } + + auto *tim = ucl_object_lookup(obj, "timeout"); + + if (tim) { + timeout = ucl_object_todouble(tim); + } + + return true; + }; + + auto ret = false; + auto obj = ucl_object_lookup(st->classifier->cfg->opts, "backend"); + if (obj != nullptr) { + ret = try_load_backend_config(obj); + } + + /* Now try statfiles config */ + if (!ret && st->stcf->opts) { + ret = try_load_backend_config(st->stcf->opts); + } + + /* Now try classifier config */ + if (!ret && st->classifier->cfg->opts) { + ret = try_load_backend_config(st->classifier->cfg->opts); + } + + return ret; +} + +auto http_backends_collection::remove_backend(struct rspamd_statfile *st) -> bool +{ + auto backend_it = std::remove(std::begin(backends), std::end(backends), st); + + if (backend_it != std::end(backends)) { + /* Fast erasure with no order preservation */ + std::swap(*backend_it, backends.back()); + backends.pop_back(); + + if (backends.empty()) { + /* De-init collection - likely config reload */ + if (read_servers) { + rspamd_upstreams_destroy(read_servers); + read_servers = nullptr; + } + + if (write_servers) { + rspamd_upstreams_destroy(write_servers); + write_servers = nullptr; + } + } + + return true; + } + + return false; +} + +upstream *http_backends_collection::get_upstream(bool is_learn) +{ + auto *ups_list = read_servers; + if (is_learn) { + ups_list = write_servers; + } + + return rspamd_upstream_get(ups_list, RSPAMD_UPSTREAM_ROUND_ROBIN, nullptr, 0); +} + +}// namespace rspamd::stat::http + +/* C API */ + +gpointer +rspamd_http_init(struct rspamd_stat_ctx *ctx, + struct rspamd_config *cfg, + struct rspamd_statfile *st) +{ + auto &collections = rspamd::stat::http::http_backends_collection::get(); + + if (!collections.add_backend(ctx, cfg, st)) { + msg_err_config("cannot load http backend"); + + return nullptr; + } + + return (void *) &collections; +} +gpointer +rspamd_http_runtime(struct rspamd_task *task, + struct rspamd_statfile_config *stcf, + gboolean learn, + gpointer ctx, + gint id) +{ + auto maybe_existing = rspamd_mempool_get_variable(task->task_pool, RSPAMD_MEMPOOL_HTTP_STAT_BACKEND_RUNTIME); + + if (maybe_existing != nullptr) { + auto real_runtime = (rspamd::stat::http::http_backend_runtime *) maybe_existing; + real_runtime->notice_statfile(id, stcf); + + return maybe_existing; + } + + auto runtime = rspamd::stat::http::http_backend_runtime::create(task, learn); + + if (runtime) { + runtime->notice_statfile(id, stcf); + rspamd_mempool_set_variable(task->task_pool, RSPAMD_MEMPOOL_HTTP_STAT_BACKEND_RUNTIME, + (void *) runtime, nullptr); + } + + return (void *) runtime; +} + +gboolean +rspamd_http_process_tokens(struct rspamd_task *task, + GPtrArray *tokens, + gint id, + gpointer runtime) +{ + auto real_runtime = (rspamd::stat::http::http_backend_runtime *) runtime; + + if (real_runtime) { + return real_runtime->process_tokens(task, tokens, id, false); + } + + + return false; +} +gboolean +rspamd_http_finalize_process(struct rspamd_task *task, + gpointer runtime, + gpointer ctx) +{ + /* Not needed */ + return true; +} + +gboolean +rspamd_http_learn_tokens(struct rspamd_task *task, + GPtrArray *tokens, + gint id, + gpointer runtime) +{ + auto real_runtime = (rspamd::stat::http::http_backend_runtime *) runtime; + + if (real_runtime) { + return real_runtime->process_tokens(task, tokens, id, true); + } + + + return false; +} +gboolean +rspamd_http_finalize_learn(struct rspamd_task *task, + gpointer runtime, + gpointer ctx, + GError **err) +{ + return false; +} + +gulong rspamd_http_total_learns(struct rspamd_task *task, + gpointer runtime, + gpointer ctx) +{ + /* TODO */ + return 0; +} +gulong +rspamd_http_inc_learns(struct rspamd_task *task, + gpointer runtime, + gpointer ctx) +{ + /* TODO */ + return 0; +} +gulong +rspamd_http_dec_learns(struct rspamd_task *task, + gpointer runtime, + gpointer ctx) +{ + /* TODO */ + return (gulong) -1; +} +gulong +rspamd_http_learns(struct rspamd_task *task, + gpointer runtime, + gpointer ctx) +{ + /* TODO */ + return 0; +} +ucl_object_t * +rspamd_http_get_stat(gpointer runtime, gpointer ctx) +{ + /* TODO */ + return nullptr; +} +gpointer +rspamd_http_load_tokenizer_config(gpointer runtime, gsize *len) +{ + return nullptr; +} +void rspamd_http_close(gpointer ctx) +{ + /* TODO */ +}
\ No newline at end of file diff --git a/src/libstat/backends/mmaped_file.c b/src/libstat/backends/mmaped_file.c new file mode 100644 index 0000000..5c20207 --- /dev/null +++ b/src/libstat/backends/mmaped_file.c @@ -0,0 +1,1113 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "config.h" +#include "stat_internal.h" +#include "unix-std.h" + +#define CHAIN_LENGTH 128 + +/* Section types */ +#define STATFILE_SECTION_COMMON 1 + +/** + * Common statfile header + */ +struct stat_file_header { + u_char magic[3]; /**< magic signature ('r' 's' 'd') */ + u_char version[2]; /**< version of statfile */ + u_char padding[3]; /**< padding */ + guint64 create_time; /**< create time (time_t->guint64) */ + guint64 revision; /**< revision number */ + guint64 rev_time; /**< revision time */ + guint64 used_blocks; /**< used blocks number */ + guint64 total_blocks; /**< total number of blocks */ + guint64 tokenizer_conf_len; /**< length of tokenizer configuration */ + u_char unused[231]; /**< some bytes that can be used in future */ +}; + +/** + * Section header + */ +struct stat_file_section { + guint64 code; /**< section's code */ + guint64 length; /**< section's length in blocks */ +}; + +/** + * Block of data in statfile + */ +struct stat_file_block { + guint32 hash1; /**< hash1 (also acts as index) */ + guint32 hash2; /**< hash2 */ + double value; /**< double value */ +}; + +/** + * Statistic file + */ +struct stat_file { + struct stat_file_header header; /**< header */ + struct stat_file_section section; /**< first section */ + struct stat_file_block blocks[1]; /**< first block of data */ +}; + +/** + * Common view of statfile object + */ +typedef struct { +#ifdef HAVE_PATH_MAX + gchar filename[PATH_MAX]; /**< name of file */ +#else + gchar filename[MAXPATHLEN]; /**< name of file */ +#endif + rspamd_mempool_t *pool; + gint fd; /**< descriptor */ + void *map; /**< mmaped area */ + off_t seek_pos; /**< current seek position */ + struct stat_file_section cur_section; /**< current section */ + size_t len; /**< length of file(in bytes) */ + struct rspamd_statfile_config *cf; +} rspamd_mmaped_file_t; + + +#define RSPAMD_STATFILE_VERSION \ + { \ + '1', '2' \ + } +#define BACKUP_SUFFIX ".old" + +static void rspamd_mmaped_file_set_block_common(rspamd_mempool_t *pool, + rspamd_mmaped_file_t *file, + guint32 h1, guint32 h2, double value); + +rspamd_mmaped_file_t *rspamd_mmaped_file_open(rspamd_mempool_t *pool, + const gchar *filename, size_t size, + struct rspamd_statfile_config *stcf); +gint rspamd_mmaped_file_create(const gchar *filename, size_t size, + struct rspamd_statfile_config *stcf, + rspamd_mempool_t *pool); +gint rspamd_mmaped_file_close_file(rspamd_mempool_t *pool, + rspamd_mmaped_file_t *file); + +double +rspamd_mmaped_file_get_block(rspamd_mmaped_file_t *file, + guint32 h1, + guint32 h2) +{ + struct stat_file_block *block; + guint i, blocknum; + u_char *c; + + if (!file->map) { + return 0; + } + + blocknum = h1 % file->cur_section.length; + c = (u_char *) file->map + file->seek_pos + blocknum * sizeof(struct stat_file_block); + block = (struct stat_file_block *) c; + + for (i = 0; i < CHAIN_LENGTH; i++) { + if (i + blocknum >= file->cur_section.length) { + break; + } + if (block->hash1 == h1 && block->hash2 == h2) { + return block->value; + } + c += sizeof(struct stat_file_block); + block = (struct stat_file_block *) c; + } + + + return 0; +} + +static void +rspamd_mmaped_file_set_block_common(rspamd_mempool_t *pool, + rspamd_mmaped_file_t *file, + guint32 h1, guint32 h2, double value) +{ + struct stat_file_block *block, *to_expire = NULL; + struct stat_file_header *header; + guint i, blocknum; + u_char *c; + double min = G_MAXDOUBLE; + + if (!file->map) { + return; + } + + blocknum = h1 % file->cur_section.length; + header = (struct stat_file_header *) file->map; + c = (u_char *) file->map + file->seek_pos + blocknum * sizeof(struct stat_file_block); + block = (struct stat_file_block *) c; + + for (i = 0; i < CHAIN_LENGTH; i++) { + if (i + blocknum >= file->cur_section.length) { + /* Need to expire some block in chain */ + msg_info_pool("chain %ud is full in statfile %s, starting expire", + blocknum, + file->filename); + break; + } + /* First try to find block in chain */ + if (block->hash1 == h1 && block->hash2 == h2) { + msg_debug_pool("%s found existing block %ud in chain %ud, value %.2f", + file->filename, + i, + blocknum, + value); + block->value = value; + return; + } + /* Check whether we have a free block in chain */ + if (block->hash1 == 0 && block->hash2 == 0) { + /* Write new block here */ + msg_debug_pool("%s found free block %ud in chain %ud, set h1=%ud, h2=%ud", + file->filename, + i, + blocknum, + h1, + h2); + block->hash1 = h1; + block->hash2 = h2; + block->value = value; + header->used_blocks++; + + return; + } + + /* Expire block with minimum value otherwise */ + if (block->value < min) { + to_expire = block; + min = block->value; + } + c += sizeof(struct stat_file_block); + block = (struct stat_file_block *) c; + } + + /* Try expire some block */ + if (to_expire) { + block = to_expire; + } + else { + /* Expire first block in chain */ + c = (u_char *) file->map + file->seek_pos + blocknum * sizeof(struct stat_file_block); + block = (struct stat_file_block *) c; + } + + block->hash1 = h1; + block->hash2 = h2; + block->value = value; +} + +void rspamd_mmaped_file_set_block(rspamd_mempool_t *pool, + rspamd_mmaped_file_t *file, + guint32 h1, + guint32 h2, + double value) +{ + rspamd_mmaped_file_set_block_common(pool, file, h1, h2, value); +} + +gboolean +rspamd_mmaped_file_set_revision(rspamd_mmaped_file_t *file, guint64 rev, time_t time) +{ + struct stat_file_header *header; + + if (file == NULL || file->map == NULL) { + return FALSE; + } + + header = (struct stat_file_header *) file->map; + + header->revision = rev; + header->rev_time = time; + + return TRUE; +} + +gboolean +rspamd_mmaped_file_inc_revision(rspamd_mmaped_file_t *file) +{ + struct stat_file_header *header; + + if (file == NULL || file->map == NULL) { + return FALSE; + } + + header = (struct stat_file_header *) file->map; + + header->revision++; + + return TRUE; +} + +gboolean +rspamd_mmaped_file_dec_revision(rspamd_mmaped_file_t *file) +{ + struct stat_file_header *header; + + if (file == NULL || file->map == NULL) { + return FALSE; + } + + header = (struct stat_file_header *) file->map; + + header->revision--; + + return TRUE; +} + + +gboolean +rspamd_mmaped_file_get_revision(rspamd_mmaped_file_t *file, guint64 *rev, time_t *time) +{ + struct stat_file_header *header; + + if (file == NULL || file->map == NULL) { + return FALSE; + } + + header = (struct stat_file_header *) file->map; + + if (rev != NULL) { + *rev = header->revision; + } + if (time != NULL) { + *time = header->rev_time; + } + + return TRUE; +} + +guint64 +rspamd_mmaped_file_get_used(rspamd_mmaped_file_t *file) +{ + struct stat_file_header *header; + + if (file == NULL || file->map == NULL) { + return (guint64) -1; + } + + header = (struct stat_file_header *) file->map; + + return header->used_blocks; +} + +guint64 +rspamd_mmaped_file_get_total(rspamd_mmaped_file_t *file) +{ + struct stat_file_header *header; + + if (file == NULL || file->map == NULL) { + return (guint64) -1; + } + + header = (struct stat_file_header *) file->map; + + /* If total blocks is 0 we have old version of header, so set total blocks correctly */ + if (header->total_blocks == 0) { + header->total_blocks = file->cur_section.length; + } + + return header->total_blocks; +} + +/* Check whether specified file is statistic file and calculate its len in blocks */ +static gint +rspamd_mmaped_file_check(rspamd_mempool_t *pool, rspamd_mmaped_file_t *file) +{ + struct stat_file *f; + gchar *c; + static gchar valid_version[] = RSPAMD_STATFILE_VERSION; + + + if (!file || !file->map) { + return -1; + } + + if (file->len < sizeof(struct stat_file)) { + msg_info_pool("file %s is too short to be stat file: %z", + file->filename, + file->len); + return -1; + } + + f = (struct stat_file *) file->map; + c = &f->header.magic[0]; + /* Check magic and version */ + if (*c++ != 'r' || *c++ != 's' || *c++ != 'd') { + msg_info_pool("file %s is invalid stat file", file->filename); + return -1; + } + + c = &f->header.version[0]; + /* Now check version and convert old version to new one (that can be used for sync */ + if (*c == 1 && *(c + 1) == 0) { + return -1; + } + else if (memcmp(c, valid_version, sizeof(valid_version)) != 0) { + /* Unknown version */ + msg_info_pool("file %s has invalid version %c.%c", + file->filename, + '0' + *c, + '0' + *(c + 1)); + return -1; + } + + /* Check first section and set new offset */ + file->cur_section.code = f->section.code; + file->cur_section.length = f->section.length; + if (file->cur_section.length * sizeof(struct stat_file_block) > + file->len) { + msg_info_pool("file %s is truncated: %z, must be %z", + file->filename, + file->len, + file->cur_section.length * sizeof(struct stat_file_block)); + return -1; + } + file->seek_pos = sizeof(struct stat_file) - + sizeof(struct stat_file_block); + + return 0; +} + + +static rspamd_mmaped_file_t * +rspamd_mmaped_file_reindex(rspamd_mempool_t *pool, + const gchar *filename, + size_t old_size, + size_t size, + struct rspamd_statfile_config *stcf) +{ + gchar *backup, *lock; + gint fd, lock_fd; + rspamd_mmaped_file_t *new, *old = NULL; + u_char *map, *pos; + struct stat_file_block *block; + struct stat_file_header *header, *nh; + struct timespec sleep_ts = { + .tv_sec = 0, + .tv_nsec = 1000000}; + + if (size < + sizeof(struct stat_file_header) + sizeof(struct stat_file_section) + + sizeof(block)) { + msg_err_pool("file %s is too small to carry any statistic: %z", + filename, + size); + return NULL; + } + + lock = g_strconcat(filename, ".lock", NULL); + lock_fd = open(lock, O_WRONLY | O_CREAT | O_EXCL, 00600); + + while (lock_fd == -1) { + /* Wait for lock */ + lock_fd = open(lock, O_WRONLY | O_CREAT | O_EXCL, 00600); + if (lock_fd != -1) { + unlink(lock); + close(lock_fd); + g_free(lock); + + return rspamd_mmaped_file_open(pool, filename, size, stcf); + } + else { + nanosleep(&sleep_ts, NULL); + } + } + + backup = g_strconcat(filename, ".old", NULL); + if (rename(filename, backup) == -1) { + msg_err_pool("cannot rename %s to %s: %s", filename, backup, strerror(errno)); + g_free(backup); + unlink(lock); + g_free(lock); + close(lock_fd); + + return NULL; + } + + old = rspamd_mmaped_file_open(pool, backup, old_size, stcf); + + if (old == NULL) { + msg_warn_pool("old file %s is invalid mmapped file, just move it", + backup); + } + + /* We need to release our lock here */ + unlink(lock); + close(lock_fd); + g_free(lock); + + /* Now create new file with required size */ + if (rspamd_mmaped_file_create(filename, size, stcf, pool) != 0) { + msg_err_pool("cannot create new file"); + rspamd_mmaped_file_close(old); + g_free(backup); + + return NULL; + } + + new = rspamd_mmaped_file_open(pool, filename, size, stcf); + + if (old) { + /* Now open new file and start copying */ + fd = open(backup, O_RDONLY); + if (fd == -1 || new == NULL) { + if (fd != -1) { + close(fd); + } + + msg_err_pool("cannot open file: %s", strerror(errno)); + rspamd_mmaped_file_close(old); + g_free(backup); + return NULL; + } + + + /* Now start reading blocks from old statfile */ + if ((map = + mmap(NULL, old_size, PROT_READ, MAP_SHARED, fd, 0)) == MAP_FAILED) { + msg_err_pool("cannot mmap file: %s", strerror(errno)); + close(fd); + rspamd_mmaped_file_close(old); + g_free(backup); + return NULL; + } + + pos = map + (sizeof(struct stat_file) - sizeof(struct stat_file_block)); + + if (pos - map < (gssize) old_size) { + while ((gssize) old_size - (pos - map) >= (gssize) sizeof(struct stat_file_block)) { + block = (struct stat_file_block *) pos; + if (block->hash1 != 0 && block->value != 0) { + rspamd_mmaped_file_set_block_common(pool, + new, block->hash1, + block->hash2, block->value); + } + pos += sizeof(block); + } + } + + header = (struct stat_file_header *) map; + rspamd_mmaped_file_set_revision(new, header->revision, header->rev_time); + nh = new->map; + /* Copy tokenizer configuration */ + memcpy(nh->unused, header->unused, sizeof(header->unused)); + nh->tokenizer_conf_len = header->tokenizer_conf_len; + + munmap(map, old_size); + close(fd); + rspamd_mmaped_file_close_file(pool, old); + } + + unlink(backup); + g_free(backup); + + return new; +} + +/* + * Pre-load mmaped file into memory + */ +static void +rspamd_mmaped_file_preload(rspamd_mmaped_file_t *file) +{ + guint8 *pos, *end; + volatile guint8 t; + gsize size; + + pos = (guint8 *) file->map; + end = (guint8 *) file->map + file->len; + + if (madvise(pos, end - pos, MADV_SEQUENTIAL) == -1) { + msg_info("madvise failed: %s", strerror(errno)); + } + else { + /* Load pages of file */ +#ifdef HAVE_GETPAGESIZE + size = getpagesize(); +#else + size = sysconf(_SC_PAGESIZE); +#endif + while (pos < end) { + t = *pos; + (void) t; + pos += size; + } + } +} + +rspamd_mmaped_file_t * +rspamd_mmaped_file_open(rspamd_mempool_t *pool, + const gchar *filename, size_t size, + struct rspamd_statfile_config *stcf) +{ + struct stat st; + rspamd_mmaped_file_t *new_file; + gchar *lock; + gint lock_fd; + + lock = g_strconcat(filename, ".lock", NULL); + lock_fd = open(lock, O_WRONLY | O_CREAT | O_EXCL, 00600); + + if (lock_fd == -1) { + g_free(lock); + msg_info_pool("cannot open file %s, it is locked by another process", + filename); + return NULL; + } + + close(lock_fd); + unlink(lock); + g_free(lock); + + if (stat(filename, &st) == -1) { + msg_info_pool("cannot stat file %s, error %s, %d", filename, strerror(errno), errno); + return NULL; + } + + if (labs((glong) size - st.st_size) > (long) sizeof(struct stat_file) * 2 && size > sizeof(struct stat_file)) { + msg_warn_pool("need to reindex statfile old size: %Hz, new size: %Hz", + (size_t) st.st_size, size); + return rspamd_mmaped_file_reindex(pool, filename, st.st_size, size, stcf); + } + else if (size < sizeof(struct stat_file)) { + msg_err_pool("requested to shrink statfile to %Hz but it is too small", + size); + } + + new_file = g_malloc0(sizeof(rspamd_mmaped_file_t)); + if ((new_file->fd = open(filename, O_RDWR)) == -1) { + msg_info_pool("cannot open file %s, error %d, %s", + filename, + errno, + strerror(errno)); + g_free(new_file); + return NULL; + } + + if ((new_file->map = + mmap(NULL, st.st_size, PROT_READ | PROT_WRITE, MAP_SHARED, + new_file->fd, 0)) == MAP_FAILED) { + close(new_file->fd); + msg_info_pool("cannot mmap file %s, error %d, %s", + filename, + errno, + strerror(errno)); + g_free(new_file); + return NULL; + } + + rspamd_strlcpy(new_file->filename, filename, sizeof(new_file->filename)); + new_file->len = st.st_size; + /* Try to lock pages in RAM */ + + /* Acquire lock for this operation */ + if (!rspamd_file_lock(new_file->fd, FALSE)) { + close(new_file->fd); + munmap(new_file->map, st.st_size); + msg_info_pool("cannot lock file %s, error %d, %s", + filename, + errno, + strerror(errno)); + g_free(new_file); + return NULL; + } + + if (rspamd_mmaped_file_check(pool, new_file) == -1) { + close(new_file->fd); + rspamd_file_unlock(new_file->fd, FALSE); + munmap(new_file->map, st.st_size); + g_free(new_file); + return NULL; + } + + rspamd_file_unlock(new_file->fd, FALSE); + new_file->cf = stcf; + new_file->pool = pool; + rspamd_mmaped_file_preload(new_file); + + g_assert(stcf->clcf != NULL); + + msg_debug_pool("opened statfile %s of size %l", filename, (long) size); + + return new_file; +} + +gint rspamd_mmaped_file_close_file(rspamd_mempool_t *pool, + rspamd_mmaped_file_t *file) +{ + if (file->map) { + msg_info_pool("syncing statfile %s", file->filename); + msync(file->map, file->len, MS_ASYNC); + munmap(file->map, file->len); + } + if (file->fd != -1) { + close(file->fd); + } + + g_free(file); + + return 0; +} + +gint rspamd_mmaped_file_create(const gchar *filename, + size_t size, + struct rspamd_statfile_config *stcf, + rspamd_mempool_t *pool) +{ + struct stat_file_header header = { + .magic = {'r', 's', 'd'}, + .version = RSPAMD_STATFILE_VERSION, + .padding = {0, 0, 0}, + .revision = 0, + .rev_time = 0, + .used_blocks = 0}; + struct stat_file_section section = { + .code = STATFILE_SECTION_COMMON, + }; + struct stat_file_block block = {0, 0, 0}; + struct rspamd_stat_tokenizer *tokenizer; + gint fd, lock_fd; + guint buflen = 0, nblocks; + gchar *buf = NULL, *lock; + struct stat sb; + gpointer tok_conf; + gsize tok_conf_len; + struct timespec sleep_ts = { + .tv_sec = 0, + .tv_nsec = 1000000}; + + if (size < + sizeof(struct stat_file_header) + sizeof(struct stat_file_section) + + sizeof(block)) { + msg_err_pool("file %s is too small to carry any statistic: %z", + filename, + size); + return -1; + } + + lock = g_strconcat(filename, ".lock", NULL); + lock_fd = open(lock, O_WRONLY | O_CREAT | O_EXCL, 00600); + + while (lock_fd == -1) { + /* Wait for lock */ + lock_fd = open(lock, O_WRONLY | O_CREAT | O_EXCL, 00600); + if (lock_fd != -1) { + if (stat(filename, &sb) != -1) { + /* File has been created by some other process */ + unlink(lock); + close(lock_fd); + g_free(lock); + + return 0; + } + + /* We still need to create it */ + goto create; + } + else { + nanosleep(&sleep_ts, NULL); + } + } + +create: + + msg_debug_pool("create statfile %s of size %l", filename, (long) size); + nblocks = + (size - sizeof(struct stat_file_header) - + sizeof(struct stat_file_section)) / + sizeof(struct stat_file_block); + header.total_blocks = nblocks; + + if ((fd = + open(filename, O_RDWR | O_TRUNC | O_CREAT, S_IWUSR | S_IRUSR)) == -1) { + msg_info_pool("cannot create file %s, error %d, %s", + filename, + errno, + strerror(errno)); + unlink(lock); + close(lock_fd); + g_free(lock); + + return -1; + } + + rspamd_fallocate(fd, + 0, + sizeof(header) + sizeof(section) + sizeof(block) * nblocks); + + header.create_time = (guint64) time(NULL); + g_assert(stcf->clcf != NULL); + g_assert(stcf->clcf->tokenizer != NULL); + tokenizer = rspamd_stat_get_tokenizer(stcf->clcf->tokenizer->name); + g_assert(tokenizer != NULL); + tok_conf = tokenizer->get_config(pool, stcf->clcf->tokenizer, &tok_conf_len); + header.tokenizer_conf_len = tok_conf_len; + g_assert(tok_conf_len < sizeof(header.unused) - sizeof(guint64)); + memcpy(header.unused, tok_conf, tok_conf_len); + + if (write(fd, &header, sizeof(header)) == -1) { + msg_info_pool("cannot write header to file %s, error %d, %s", + filename, + errno, + strerror(errno)); + close(fd); + unlink(lock); + close(lock_fd); + g_free(lock); + + return -1; + } + + section.length = (guint64) nblocks; + if (write(fd, §ion, sizeof(section)) == -1) { + msg_info_pool("cannot write section header to file %s, error %d, %s", + filename, + errno, + strerror(errno)); + close(fd); + unlink(lock); + close(lock_fd); + g_free(lock); + + return -1; + } + + /* Buffer for write 256 blocks at once */ + if (nblocks > 256) { + buflen = sizeof(block) * 256; + buf = g_malloc0(buflen); + } + + while (nblocks) { + if (nblocks > 256) { + /* Just write buffer */ + if (write(fd, buf, buflen) == -1) { + msg_info_pool("cannot write blocks buffer to file %s, error %d, %s", + filename, + errno, + strerror(errno)); + close(fd); + g_free(buf); + unlink(lock); + close(lock_fd); + g_free(lock); + + return -1; + } + nblocks -= 256; + } + else { + if (write(fd, &block, sizeof(block)) == -1) { + msg_info_pool("cannot write block to file %s, error %d, %s", + filename, + errno, + strerror(errno)); + close(fd); + if (buf) { + g_free(buf); + } + + unlink(lock); + close(lock_fd); + g_free(lock); + + return -1; + } + nblocks--; + } + } + + close(fd); + + if (buf) { + g_free(buf); + } + + unlink(lock); + close(lock_fd); + g_free(lock); + msg_debug_pool("created statfile %s of size %l", filename, (long) size); + + return 0; +} + +gpointer +rspamd_mmaped_file_init(struct rspamd_stat_ctx *ctx, + struct rspamd_config *cfg, struct rspamd_statfile *st) +{ + struct rspamd_statfile_config *stf = st->stcf; + rspamd_mmaped_file_t *mf; + const ucl_object_t *filenameo, *sizeo; + const gchar *filename; + gsize size; + + filenameo = ucl_object_lookup(stf->opts, "filename"); + + if (filenameo == NULL || ucl_object_type(filenameo) != UCL_STRING) { + filenameo = ucl_object_lookup(stf->opts, "path"); + + if (filenameo == NULL || ucl_object_type(filenameo) != UCL_STRING) { + msg_err_config("statfile %s has no filename defined", stf->symbol); + return NULL; + } + } + + filename = ucl_object_tostring(filenameo); + + sizeo = ucl_object_lookup(stf->opts, "size"); + + if (sizeo == NULL || ucl_object_type(sizeo) != UCL_INT) { + msg_err_config("statfile %s has no size defined", stf->symbol); + return NULL; + } + + size = ucl_object_toint(sizeo); + mf = rspamd_mmaped_file_open(cfg->cfg_pool, filename, size, stf); + + if (mf != NULL) { + mf->pool = cfg->cfg_pool; + } + else { + /* Create file here */ + + filenameo = ucl_object_find_key(stf->opts, "filename"); + if (filenameo == NULL || ucl_object_type(filenameo) != UCL_STRING) { + filenameo = ucl_object_find_key(stf->opts, "path"); + if (filenameo == NULL || ucl_object_type(filenameo) != UCL_STRING) { + msg_err_config("statfile %s has no filename defined", stf->symbol); + return NULL; + } + } + + filename = ucl_object_tostring(filenameo); + + sizeo = ucl_object_find_key(stf->opts, "size"); + if (sizeo == NULL || ucl_object_type(sizeo) != UCL_INT) { + msg_err_config("statfile %s has no size defined", stf->symbol); + return NULL; + } + + size = ucl_object_toint(sizeo); + + if (rspamd_mmaped_file_create(filename, size, stf, cfg->cfg_pool) != 0) { + msg_err_config("cannot create new file"); + } + + mf = rspamd_mmaped_file_open(cfg->cfg_pool, filename, size, stf); + } + + return (gpointer) mf; +} + +void rspamd_mmaped_file_close(gpointer p) +{ + rspamd_mmaped_file_t *mf = p; + + + if (mf) { + rspamd_mmaped_file_close_file(mf->pool, mf); + } +} + +gpointer +rspamd_mmaped_file_runtime(struct rspamd_task *task, + struct rspamd_statfile_config *stcf, + gboolean learn, + gpointer p, + gint _id) +{ + rspamd_mmaped_file_t *mf = p; + + return (gpointer) mf; +} + +gboolean +rspamd_mmaped_file_process_tokens(struct rspamd_task *task, GPtrArray *tokens, + gint id, + gpointer p) +{ + rspamd_mmaped_file_t *mf = p; + guint32 h1, h2; + rspamd_token_t *tok; + guint i; + + g_assert(tokens != NULL); + g_assert(p != NULL); + + for (i = 0; i < tokens->len; i++) { + tok = g_ptr_array_index(tokens, i); + memcpy(&h1, (guchar *) &tok->data, sizeof(h1)); + memcpy(&h2, ((guchar *) &tok->data) + sizeof(h1), sizeof(h2)); + tok->values[id] = rspamd_mmaped_file_get_block(mf, h1, h2); + } + + if (mf->cf->is_spam) { + task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS; + } + else { + task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; + } + + return TRUE; +} + +gboolean +rspamd_mmaped_file_learn_tokens(struct rspamd_task *task, GPtrArray *tokens, + gint id, + gpointer p) +{ + rspamd_mmaped_file_t *mf = p; + guint32 h1, h2; + rspamd_token_t *tok; + guint i; + + g_assert(tokens != NULL); + g_assert(p != NULL); + + for (i = 0; i < tokens->len; i++) { + tok = g_ptr_array_index(tokens, i); + memcpy(&h1, (guchar *) &tok->data, sizeof(h1)); + memcpy(&h2, ((guchar *) &tok->data) + sizeof(h1), sizeof(h2)); + rspamd_mmaped_file_set_block(task->task_pool, mf, h1, h2, + tok->values[id]); + } + + return TRUE; +} + +gulong +rspamd_mmaped_file_total_learns(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *) runtime; + guint64 rev = 0; + time_t t; + + if (mf != NULL) { + rspamd_mmaped_file_get_revision(mf, &rev, &t); + } + + return rev; +} + +gulong +rspamd_mmaped_file_inc_learns(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *) runtime; + guint64 rev = 0; + time_t t; + + if (mf != NULL) { + rspamd_mmaped_file_inc_revision(mf); + rspamd_mmaped_file_get_revision(mf, &rev, &t); + } + + return rev; +} + +gulong +rspamd_mmaped_file_dec_learns(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *) runtime; + guint64 rev = 0; + time_t t; + + if (mf != NULL) { + rspamd_mmaped_file_dec_revision(mf); + rspamd_mmaped_file_get_revision(mf, &rev, &t); + } + + return rev; +} + + +ucl_object_t * +rspamd_mmaped_file_get_stat(gpointer runtime, + gpointer ctx) +{ + ucl_object_t *res = NULL; + guint64 rev; + rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *) runtime; + + if (mf != NULL) { + res = ucl_object_typed_new(UCL_OBJECT); + rspamd_mmaped_file_get_revision(mf, &rev, NULL); + ucl_object_insert_key(res, ucl_object_fromint(rev), "revision", + 0, false); + ucl_object_insert_key(res, ucl_object_fromint(mf->len), "size", + 0, false); + ucl_object_insert_key(res, ucl_object_fromint(rspamd_mmaped_file_get_total(mf)), "total", 0, false); + ucl_object_insert_key(res, ucl_object_fromint(rspamd_mmaped_file_get_used(mf)), "used", 0, false); + ucl_object_insert_key(res, ucl_object_fromstring(mf->cf->symbol), + "symbol", 0, false); + ucl_object_insert_key(res, ucl_object_fromstring("mmap"), + "type", 0, false); + ucl_object_insert_key(res, ucl_object_fromint(0), + "languages", 0, false); + ucl_object_insert_key(res, ucl_object_fromint(0), + "users", 0, false); + + if (mf->cf->label) { + ucl_object_insert_key(res, ucl_object_fromstring(mf->cf->label), + "label", 0, false); + } + } + + return res; +} + +gboolean +rspamd_mmaped_file_finalize_learn(struct rspamd_task *task, gpointer runtime, + gpointer ctx, GError **err) +{ + rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *) runtime; + + if (mf != NULL) { + msync(mf->map, mf->len, MS_INVALIDATE | MS_ASYNC); + } + + return TRUE; +} + +gboolean +rspamd_mmaped_file_finalize_process(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + return TRUE; +} + +gpointer +rspamd_mmaped_file_load_tokenizer_config(gpointer runtime, + gsize *len) +{ + rspamd_mmaped_file_t *mf = runtime; + struct stat_file_header *header; + + g_assert(mf != NULL); + header = mf->map; + + if (len) { + *len = header->tokenizer_conf_len; + } + + return header->unused; +} diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx new file mode 100644 index 0000000..cd0c379 --- /dev/null +++ b/src/libstat/backends/redis_backend.cxx @@ -0,0 +1,1132 @@ +/* + * Copyright 2024 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "config.h" +#include "lua/lua_common.h" +#include "rspamd.h" +#include "stat_internal.h" +#include "upstream.h" +#include "libserver/mempool_vars_internal.h" +#include "fmt/core.h" + +#include "libutil/cxx/error.hxx" + +#include <string> +#include <cstdint> +#include <vector> +#include <optional> + +#define msg_debug_stat_redis(...) rspamd_conditional_debug_fast(nullptr, nullptr, \ + rspamd_stat_redis_log_id, "stat_redis", task->task_pool->tag.uid, \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) + +INIT_LOG_MODULE(stat_redis) + +#define REDIS_CTX(p) (reinterpret_cast<struct redis_stat_ctx *>(p)) +#define REDIS_RUNTIME(p) (reinterpret_cast<struct redis_stat_runtime<float> *>(p)) +#define REDIS_DEFAULT_OBJECT "%s%l" +#define REDIS_DEFAULT_USERS_OBJECT "%s%l%r" +#define REDIS_DEFAULT_TIMEOUT 0.5 +#define REDIS_STAT_TIMEOUT 30 +#define REDIS_MAX_USERS 1000 + +struct redis_stat_ctx { + lua_State *L; + struct rspamd_statfile_config *stcf; + const char *redis_object = REDIS_DEFAULT_OBJECT; + bool enable_users = false; + bool store_tokens = false; + bool enable_signatures = false; + int cbref_user = -1; + + int cbref_classify = -1; + int cbref_learn = -1; + + ucl_object_t *cur_stat = nullptr; + + explicit redis_stat_ctx(lua_State *_L) + : L(_L) + { + } + + ~redis_stat_ctx() + { + if (cbref_user != -1) { + luaL_unref(L, LUA_REGISTRYINDEX, cbref_user); + } + + if (cbref_classify != -1) { + luaL_unref(L, LUA_REGISTRYINDEX, cbref_classify); + } + + if (cbref_learn != -1) { + luaL_unref(L, LUA_REGISTRYINDEX, cbref_learn); + } + } +}; + + +template<class T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true> +struct redis_stat_runtime { + struct redis_stat_ctx *ctx; + struct rspamd_task *task; + struct rspamd_statfile_config *stcf; + GPtrArray *tokens = nullptr; + const char *redis_object_expanded; + std::uint64_t learned = 0; + int id; + std::vector<std::pair<int, T>> *results = nullptr; + bool need_redis_call = true; + std::optional<rspamd::util::error> err; + + using result_type = std::vector<std::pair<int, T>>; + +private: + /* Called on connection termination */ + static void rt_dtor(gpointer data) + { + auto *rt = REDIS_RUNTIME(data); + + delete rt; + } + + /* Avoid occasional deletion */ + ~redis_stat_runtime() + { + if (tokens) { + g_ptr_array_unref(tokens); + } + + delete results; + } + +public: + explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded) + : ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded) + { + rspamd_mempool_add_destructor(task->task_pool, redis_stat_runtime<T>::rt_dtor, this); + } + + static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded, + bool is_spam) -> std::optional<redis_stat_runtime<T> *> + { + auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H"); + auto *res = rspamd_mempool_get_variable(task->task_pool, var_name.c_str()); + + if (res) { + msg_debug_bayes("recovered runtime from mempool at %s", var_name.c_str()); + return reinterpret_cast<redis_stat_runtime<T> *>(res); + } + else { + msg_debug_bayes("no runtime at %s", var_name.c_str()); + return std::nullopt; + } + } + + void set_results(std::vector<std::pair<int, T>> *results) + { + this->results = results; + } + + /* Propagate results from internal representation to the tokens array */ + auto process_tokens(GPtrArray *tokens) const -> bool + { + rspamd_token_t *tok; + + if (!results) { + return false; + } + + for (auto [idx, val]: *results) { + tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx - 1); + tok->values[id] = val; + } + + return true; + } + + auto save_in_mempool(bool is_spam) const + { + auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H"); + /* We do not set destructor for the variable, as it should be already added on creation */ + rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr); + msg_debug_bayes("saved runtime in mempool at %s", var_name.c_str()); + } +}; + +#define GET_TASK_ELT(task, elt) (task == nullptr ? nullptr : (task)->elt) + +static const gchar *M = "redis statistics"; + +static GQuark +rspamd_redis_stat_quark(void) +{ + return g_quark_from_static_string(M); +} + +/* + * Non-static for lua unit testing + */ +gsize rspamd_redis_expand_object(const gchar *pattern, + struct redis_stat_ctx *ctx, + struct rspamd_task *task, + gchar **target) +{ + gsize tlen = 0; + const gchar *p = pattern, *elt; + gchar *d, *end; + enum { + just_char, + percent_char, + mod_char + } state = just_char; + struct rspamd_statfile_config *stcf; + lua_State *L = nullptr; + struct rspamd_task **ptask; + const gchar *rcpt = nullptr; + gint err_idx; + + g_assert(ctx != nullptr); + g_assert(task != nullptr); + stcf = ctx->stcf; + + L = RSPAMD_LUA_CFG_STATE(task->cfg); + g_assert(L != nullptr); + + if (ctx->enable_users) { + if (ctx->cbref_user == -1) { + rcpt = rspamd_task_get_principal_recipient(task); + } + else { + /* Execute lua function to get userdata */ + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->cbref_user); + ptask = (struct rspamd_task **) lua_newuserdata(L, sizeof(struct rspamd_task *)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + + if (lua_pcall(L, 1, 1, err_idx) != 0) { + msg_err_task("call to user extraction script failed: %s", + lua_tostring(L, -1)); + } + else { + rcpt = rspamd_mempool_strdup(task->task_pool, lua_tostring(L, -1)); + } + + /* Result + error function */ + lua_settop(L, err_idx - 1); + } + + if (rcpt) { + rspamd_mempool_set_variable(task->task_pool, "stat_user", + (gpointer) rcpt, nullptr); + } + } + + /* Length calculation */ + while (*p) { + switch (state) { + case just_char: + if (*p == '%') { + state = percent_char; + } + else { + tlen++; + } + p++; + break; + case percent_char: + switch (*p) { + case '%': + tlen++; + state = just_char; + break; + case 'u': + elt = GET_TASK_ELT(task, auth_user); + if (elt) { + tlen += strlen(elt); + } + break; + case 'r': + + if (rcpt == nullptr) { + elt = rspamd_task_get_principal_recipient(task); + } + else { + elt = rcpt; + } + + if (elt) { + tlen += strlen(elt); + } + break; + case 'l': + if (stcf->label) { + tlen += strlen(stcf->label); + } + /* Label miss is OK */ + break; + case 's': + tlen += sizeof("RS") - 1; + break; + default: + state = just_char; + tlen++; + break; + } + + if (state == percent_char) { + state = mod_char; + } + p++; + break; + + case mod_char: + switch (*p) { + case 'd': + p++; + state = just_char; + break; + default: + state = just_char; + break; + } + break; + } + } + + + if (target == nullptr) { + return -1; + } + + *target = (gchar *) rspamd_mempool_alloc(task->task_pool, tlen + 1); + d = *target; + end = d + tlen + 1; + d[tlen] = '\0'; + p = pattern; + state = just_char; + + /* Expand string */ + while (*p && d < end) { + switch (state) { + case just_char: + if (*p == '%') { + state = percent_char; + } + else { + *d++ = *p; + } + p++; + break; + case percent_char: + switch (*p) { + case '%': + *d++ = *p; + state = just_char; + break; + case 'u': + elt = GET_TASK_ELT(task, auth_user); + if (elt) { + d += rspamd_strlcpy(d, elt, end - d); + } + break; + case 'r': + if (rcpt == nullptr) { + elt = rspamd_task_get_principal_recipient(task); + } + else { + elt = rcpt; + } + + if (elt) { + d += rspamd_strlcpy(d, elt, end - d); + } + break; + case 'l': + if (stcf->label) { + d += rspamd_strlcpy(d, stcf->label, end - d); + } + break; + case 's': + d += rspamd_strlcpy(d, "RS", end - d); + break; + default: + state = just_char; + *d++ = *p; + break; + } + + if (state == percent_char) { + state = mod_char; + } + p++; + break; + + case mod_char: + switch (*p) { + case 'd': + /* TODO: not supported yet */ + p++; + state = just_char; + break; + default: + state = just_char; + break; + } + break; + } + } + + return tlen; +} + +static int +rspamd_redis_stat_cb(lua_State *L) +{ + const auto *cookie = lua_tostring(L, lua_upvalueindex(1)); + auto *cfg = lua_check_config(L, 1); + auto *backend = REDIS_CTX(rspamd_mempool_get_variable(cfg->cfg_pool, cookie)); + + if (backend == nullptr) { + msg_err("internal error: cookie %s is not found", cookie); + + return 0; + } + + auto *cur_obj = ucl_object_lua_import(L, 2); + msg_debug_bayes_cfg("got stat object for %s", backend->stcf->symbol); + /* Enrich with some default values that are meaningless for redis */ + ucl_object_insert_key(cur_obj, + ucl_object_typed_new(UCL_INT), "used", 0, false); + ucl_object_insert_key(cur_obj, + ucl_object_typed_new(UCL_INT), "total", 0, false); + ucl_object_insert_key(cur_obj, + ucl_object_typed_new(UCL_INT), "size", 0, false); + ucl_object_insert_key(cur_obj, + ucl_object_fromstring(backend->stcf->symbol), + "symbol", 0, false); + ucl_object_insert_key(cur_obj, ucl_object_fromstring("redis"), + "type", 0, false); + ucl_object_insert_key(cur_obj, ucl_object_fromint(0), + "languages", 0, false); + + if (backend->cur_stat) { + ucl_object_unref(backend->cur_stat); + } + + backend->cur_stat = cur_obj; + + return 0; +} + +static void +rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, + const ucl_object_t *statfile_obj, + const ucl_object_t *classifier_obj, + struct rspamd_config *cfg) +{ + const gchar *lua_script; + const ucl_object_t *elt, *users_enabled; + auto *L = RSPAMD_LUA_CFG_STATE(cfg); + + users_enabled = ucl_object_lookup_any(classifier_obj, "per_user", + "users_enabled", nullptr); + + if (users_enabled != nullptr) { + if (ucl_object_type(users_enabled) == UCL_BOOLEAN) { + backend->enable_users = ucl_object_toboolean(users_enabled); + backend->cbref_user = -1; + } + else if (ucl_object_type(users_enabled) == UCL_STRING) { + lua_script = ucl_object_tostring(users_enabled); + + if (luaL_dostring(L, lua_script) != 0) { + msg_err_config("cannot execute lua script for users " + "extraction: %s", + lua_tostring(L, -1)); + } + else { + if (lua_type(L, -1) == LUA_TFUNCTION) { + backend->enable_users = TRUE; + backend->cbref_user = luaL_ref(L, + LUA_REGISTRYINDEX); + } + else { + msg_err_config("lua script must return " + "function(task) and not %s", + lua_typename(L, lua_type(L, -1))); + } + } + } + } + else { + backend->enable_users = FALSE; + backend->cbref_user = -1; + } + + elt = ucl_object_lookup(classifier_obj, "prefix"); + if (elt == nullptr || ucl_object_type(elt) != UCL_STRING) { + /* Default non-users statistics */ + if (backend->enable_users || backend->cbref_user != -1) { + backend->redis_object = REDIS_DEFAULT_USERS_OBJECT; + } + else { + backend->redis_object = REDIS_DEFAULT_OBJECT; + } + } + else { + /* XXX: sanity check */ + backend->redis_object = ucl_object_tostring(elt); + } + + elt = ucl_object_lookup(classifier_obj, "store_tokens"); + if (elt) { + backend->store_tokens = ucl_object_toboolean(elt); + } + else { + backend->store_tokens = FALSE; + } + + elt = ucl_object_lookup(classifier_obj, "signatures"); + if (elt) { + backend->enable_signatures = ucl_object_toboolean(elt); + } + else { + backend->enable_signatures = FALSE; + } +} + +gpointer +rspamd_redis_init(struct rspamd_stat_ctx *ctx, + struct rspamd_config *cfg, struct rspamd_statfile *st) +{ + auto *L = RSPAMD_LUA_CFG_STATE(cfg); + + auto backend = std::make_unique<struct redis_stat_ctx>(L); + lua_settop(L, 0); + + rspamd_redis_parse_classifier_opts(backend.get(), st->stcf->opts, st->classifier->cfg->opts, cfg); + + st->stcf->clcf->flags |= RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND; + backend->stcf = st->stcf; + + lua_pushcfunction(L, &rspamd_lua_traceback); + auto err_idx = lua_gettop(L); + + /* Obtain function */ + if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_statfile")) { + msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_statfile"); + lua_settop(L, err_idx - 1); + + return nullptr; + } + + /* Push arguments */ + ucl_object_push_lua(L, st->classifier->cfg->opts, false); + ucl_object_push_lua(L, st->stcf->opts, false); + lua_pushstring(L, backend->stcf->symbol); + lua_pushboolean(L, backend->stcf->is_spam); + auto **pev_base = (struct ev_loop **) lua_newuserdata(L, sizeof(struct ev_loop *)); + *pev_base = ctx->event_loop; + rspamd_lua_setclass(L, "rspamd{ev_base}", -1); + + /* Store backend in random cookie */ + char *cookie = (char *) rspamd_mempool_alloc(cfg->cfg_pool, 16); + rspamd_random_hex(cookie, 16); + cookie[15] = '\0'; + rspamd_mempool_set_variable(cfg->cfg_pool, cookie, backend.get(), nullptr); + /* Callback + 1 upvalue */ + lua_pushstring(L, cookie); + lua_pushcclosure(L, &rspamd_redis_stat_cb, 1); + + if (lua_pcall(L, 6, 2, err_idx) != 0) { + msg_err("call to lua_bayes_init_classifier " + "script failed: %s", + lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + + return nullptr; + } + + /* Results are in the stack: + * top - 1 - classifier function (idx = -2) + * top - learn function (idx = -1) + */ + + lua_pushvalue(L, -2); + backend->cbref_classify = luaL_ref(L, LUA_REGISTRYINDEX); + + lua_pushvalue(L, -1); + backend->cbref_learn = luaL_ref(L, LUA_REGISTRYINDEX); + + lua_settop(L, err_idx - 1); + + return backend.release(); +} + +gpointer +rspamd_redis_runtime(struct rspamd_task *task, + struct rspamd_statfile_config *stcf, + gboolean learn, gpointer c, gint _id) +{ + struct redis_stat_ctx *ctx = REDIS_CTX(c); + char *object_expanded = nullptr; + + g_assert(ctx != nullptr); + g_assert(stcf != nullptr); + + if (rspamd_redis_expand_object(ctx->redis_object, ctx, task, + &object_expanded) == 0) { + msg_err_task("expansion for %s failed for symbol %s " + "(maybe learning per user classifier with no user or recipient)", + learn ? "learning" : "classifying", + stcf->symbol); + return nullptr; + } + + /* Look for the cached results */ + if (!learn) { + auto maybe_existing = redis_stat_runtime<float>::maybe_recover_from_mempool(task, + object_expanded, stcf->is_spam); + + if (maybe_existing) { + auto *rt = maybe_existing.value(); + /* Update stcf and ctx to correspond to what we have been asked */ + rt->stcf = stcf; + rt->ctx = ctx; + return rt; + } + } + + /* No cached result (or learn), create new one */ + auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded); + + if (!learn) { + /* + * For check, we also need to create the opposite class runtime to avoid + * double call for Redis scripts. + * This runtime will be filled later. + */ + auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task, + object_expanded, + !stcf->is_spam); + + if (!maybe_opposite_rt) { + auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded); + opposite_rt->save_in_mempool(!stcf->is_spam); + opposite_rt->need_redis_call = false; + } + } + + rt->save_in_mempool(stcf->is_spam); + + return rt; +} + +void rspamd_redis_close(gpointer p) +{ + struct redis_stat_ctx *ctx = REDIS_CTX(p); + delete ctx; +} + +static constexpr auto +msgpack_emit_str(const std::string_view st, char *out) -> std::size_t +{ + auto len = st.size(); + constexpr const unsigned char fix_mask = 0xA0, l8_ch = 0xd9, l16_ch = 0xda, l32_ch = 0xdb; + auto blen = 0; + if (len <= 0x1F) { + blen = 1; + out[0] = (len | fix_mask) & 0xff; + } + else if (len <= 0xff) { + blen = 2; + out[0] = l8_ch; + out[1] = len & 0xff; + } + else if (len <= 0xffff) { + uint16_t bl = GUINT16_TO_BE(len); + + blen = 3; + out[0] = l16_ch; + memcpy(&out[1], &bl, sizeof(bl)); + } + else { + uint32_t bl = GUINT32_TO_BE(len); + + blen = 5; + out[0] = l32_ch; + memcpy(&out[1], &bl, sizeof(bl)); + } + + memcpy(&out[blen], st.data(), st.size()); + + return blen + len; +} + +static constexpr auto +msgpack_str_len(std::size_t len) -> std::size_t +{ + if (len <= 0x1F) { + return 1 + len; + } + else if (len <= 0xff) { + return 2 + len; + } + else if (len <= 0xffff) { + return 3 + len; + } + else { + return 4 + len; + } +} + +/* + * Serialise stat tokens to message pack + */ +static char * +rspamd_redis_serialize_tokens(struct rspamd_task *task, const gchar *prefix, GPtrArray *tokens, gsize *ser_len) +{ + /* Each token is int64_t that requires 10 bytes (2 int32_t) + 4 bytes array len + 1 byte array magic */ + char max_int64_str[] = "18446744073709551615"; + auto prefix_len = strlen(prefix); + std::size_t req_len = 5; + rspamd_token_t *tok; + + /* Calculate required length */ + req_len += tokens->len * (msgpack_str_len(sizeof(max_int64_str) + prefix_len) + 1); + + auto *buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len); + auto *p = buf; + + /* Array */ + *p++ = (gchar) 0xdd; + /* Length in big-endian (4 bytes) */ + *p++ = (gchar) ((tokens->len >> 24) & 0xff); + *p++ = (gchar) ((tokens->len >> 16) & 0xff); + *p++ = (gchar) ((tokens->len >> 8) & 0xff); + *p++ = (gchar) (tokens->len & 0xff); + + + int i; + auto numbuf_len = sizeof(max_int64_str) + prefix_len + 1; + auto *numbuf = (char *) g_alloca(numbuf_len); + + PTR_ARRAY_FOREACH(tokens, i, tok) + { + std::size_t r = rspamd_snprintf(numbuf, numbuf_len, "%s_%uL", prefix, tok->data); + auto shift = msgpack_emit_str({numbuf, r}, p); + p += shift; + } + + *ser_len = p - buf; + + return buf; +} + +static char * +rspamd_redis_serialize_text_tokens(struct rspamd_task *task, GPtrArray *tokens, gsize *ser_len) +{ + rspamd_token_t *tok; + auto req_len = 5; /* Messagepack array prefix */ + int i; + + /* + * First we need to determine the requested length + */ + PTR_ARRAY_FOREACH(tokens, i, tok) + { + if (tok->t1 && tok->t2) { + /* Two tokens */ + req_len += msgpack_str_len(tok->t1->stemmed.len) + msgpack_str_len(tok->t2->stemmed.len); + } + else if (tok->t1) { + req_len += msgpack_str_len(tok->t1->stemmed.len); + req_len += 1; /* null */ + } + else { + req_len += 2; /* 2 nulls */ + } + } + + auto *buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len); + auto *p = buf; + + /* Array */ + std::uint32_t nlen = tokens->len * 2; + nlen = GUINT32_TO_BE(nlen); + *p++ = (gchar) 0xdd; + /* Length in big-endian (4 bytes) */ + memcpy(p, &nlen, sizeof(nlen)); + p += sizeof(nlen); + + PTR_ARRAY_FOREACH(tokens, i, tok) + { + if (tok->t1 && tok->t2) { + auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p); + p += step; + step = msgpack_emit_str({tok->t2->stemmed.begin, tok->t2->stemmed.len}, p); + p += step; + } + else if (tok->t1) { + auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p); + p += step; + *p++ = 0xc0; + } + else { + *p++ = 0xc0; + *p++ = 0xc0; + } + } + + *ser_len = p - buf; + + return buf; +} + +static gint +rspamd_redis_classified(lua_State *L) +{ + const auto *cookie = lua_tostring(L, lua_upvalueindex(1)); + auto *task = lua_check_task(L, 1); + auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie)); + + if (rt == nullptr) { + msg_err_task("internal error: cannot find runtime for cookie %s", cookie); + + return 0; + } + + bool result = lua_toboolean(L, 2); + + if (result) { + /* Indexes: + * 3 - learned_ham (int) + * 4 - learned_spam (int) + * 5 - ham_tokens (pair<int, int>) + * 6 - spam_tokens (pair<int, int>) + */ + + /* + * We need to fill our runtime AND the opposite runtime + */ + auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) { + rt->learned = learned; + redis_stat_runtime<float>::result_type *res; + + res = new redis_stat_runtime<float>::result_type(); + + for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) { + lua_rawgeti(L, -1, 1); + auto idx = lua_tointeger(L, -1); + lua_pop(L, 1); + + lua_rawgeti(L, -1, 2); + auto value = lua_tonumber(L, -1); + lua_pop(L, 1); + + res->emplace_back(idx, value); + } + + rt->set_results(res); + }; + + auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task, + rt->redis_object_expanded, + !rt->stcf->is_spam); + + if (!opposite_rt_maybe) { + msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie); + + return 0; + } + + if (rt->stcf->is_spam) { + filler_func(rt, L, lua_tointeger(L, 4), 6); + filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5); + } + else { + filler_func(rt, L, lua_tointeger(L, 3), 5); + filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6); + } + + /* Mark task as being processed */ + task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS | RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; + + /* Process all tokens */ + g_assert(rt->tokens != nullptr); + rt->process_tokens(rt->tokens); + opposite_rt_maybe.value()->process_tokens(rt->tokens); + } + else { + /* Error message is on index 3 */ + const auto *err_msg = lua_tostring(L, 3); + rt->err = rspamd::util::error(err_msg, 500); + msg_err_task("cannot classify task: %s", + err_msg); + } + + return 0; +} + +gboolean +rspamd_redis_process_tokens(struct rspamd_task *task, + GPtrArray *tokens, + gint id, gpointer p) +{ + auto *rt = REDIS_RUNTIME(p); + auto *L = rt->ctx->L; + + if (rspamd_session_blocked(task->s)) { + return FALSE; + } + + if (tokens == nullptr || tokens->len == 0) { + return FALSE; + } + + if (!rt->need_redis_call) { + /* No need to do anything, as it is already done in the opposite class processing */ + /* However, we need to store id as it is needed for further tokens processing */ + rt->id = id; + rt->tokens = g_ptr_array_ref(tokens); + + return TRUE; + } + + gsize tokens_len; + gchar *tokens_buf = rspamd_redis_serialize_tokens(task, rt->redis_object_expanded, tokens, &tokens_len); + rt->id = id; + + lua_pushcfunction(L, &rspamd_lua_traceback); + gint err_idx = lua_gettop(L); + + /* Function arguments */ + lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_classify); + rspamd_lua_task_push(L, task); + lua_pushstring(L, rt->redis_object_expanded); + lua_pushinteger(L, id); + lua_pushboolean(L, rt->stcf->is_spam); + lua_new_text(L, tokens_buf, tokens_len, false); + + /* Store rt in random cookie */ + char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16); + rspamd_random_hex(cookie, 16); + cookie[15] = '\0'; + rspamd_mempool_set_variable(task->task_pool, cookie, rt, nullptr); + /* Callback */ + lua_pushstring(L, cookie); + lua_pushcclosure(L, &rspamd_redis_classified, 1); + + if (lua_pcall(L, 6, 0, err_idx) != 0) { + msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + return FALSE; + } + + rt->tokens = g_ptr_array_ref(tokens); + + lua_settop(L, err_idx - 1); + return TRUE; +} + +gboolean +rspamd_redis_finalize_process(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + auto *rt = REDIS_RUNTIME(runtime); + + return !rt->err.has_value(); +} + + +static gint +rspamd_redis_learned(lua_State *L) +{ + const auto *cookie = lua_tostring(L, lua_upvalueindex(1)); + auto *task = lua_check_task(L, 1); + auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie)); + + if (rt == nullptr) { + msg_err_task("internal error: cannot find runtime for cookie %s", cookie); + + return 0; + } + + bool result = lua_toboolean(L, 2); + + if (result) { + /* TODO: write it */ + } + else { + /* Error message is on index 3 */ + const auto *err_msg = lua_tostring(L, 3); + rt->err = rspamd::util::error(err_msg, 500); + msg_err_task("cannot learn task: %s", err_msg); + } + + return 0; +} + +gboolean +rspamd_redis_learn_tokens(struct rspamd_task *task, + GPtrArray *tokens, + gint id, gpointer p) +{ + auto *rt = REDIS_RUNTIME(p); + auto *L = rt->ctx->L; + + if (rspamd_session_blocked(task->s)) { + return FALSE; + } + + if (tokens == nullptr || tokens->len == 0) { + return FALSE; + } + + gsize tokens_len; + gchar *tokens_buf = rspamd_redis_serialize_tokens(task, rt->redis_object_expanded, tokens, &tokens_len); + + rt->id = id; + + gsize text_tokens_len = 0; + gchar *text_tokens_buf = nullptr; + + if (rt->ctx->store_tokens) { + text_tokens_buf = rspamd_redis_serialize_text_tokens(task, tokens, &text_tokens_len); + } + + lua_pushcfunction(L, &rspamd_lua_traceback); + gint err_idx = lua_gettop(L); + auto nargs = 8; + + /* Function arguments */ + lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_learn); + rspamd_lua_task_push(L, task); + lua_pushstring(L, rt->redis_object_expanded); + lua_pushinteger(L, id); + lua_pushboolean(L, rt->stcf->is_spam); + lua_pushstring(L, rt->stcf->symbol); + + /* Detect unlearn */ + auto *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, 0); + + if (tok->values[id] > 0) { + lua_pushboolean(L, FALSE);// Learn + } + else { + lua_pushboolean(L, TRUE);// Unlearn + } + lua_new_text(L, tokens_buf, tokens_len, false); + + /* Store rt in random cookie */ + char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16); + rspamd_random_hex(cookie, 16); + cookie[15] = '\0'; + rspamd_mempool_set_variable(task->task_pool, cookie, rt, nullptr); + /* Callback */ + lua_pushstring(L, cookie); + lua_pushcclosure(L, &rspamd_redis_learned, 1); + + if (text_tokens_len) { + nargs = 9; + lua_new_text(L, text_tokens_buf, text_tokens_len, false); + } + + if (lua_pcall(L, nargs, 0, err_idx) != 0) { + msg_err_task("call to script failed: %s", lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + return FALSE; + } + + rt->tokens = g_ptr_array_ref(tokens); + + lua_settop(L, err_idx - 1); + return TRUE; +} + + +gboolean +rspamd_redis_finalize_learn(struct rspamd_task *task, gpointer runtime, + gpointer ctx, GError **err) +{ + auto *rt = REDIS_RUNTIME(runtime); + + if (rt->err.has_value()) { + rt->err->into_g_error_set(rspamd_redis_stat_quark(), err); + + return FALSE; + } + + return TRUE; +} + +gulong +rspamd_redis_total_learns(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + auto *rt = REDIS_RUNTIME(runtime); + + return rt->learned; +} + +gulong +rspamd_redis_inc_learns(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + auto *rt = REDIS_RUNTIME(runtime); + + /* XXX: may cause races */ + return rt->learned + 1; +} + +gulong +rspamd_redis_dec_learns(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + auto *rt = REDIS_RUNTIME(runtime); + + /* XXX: may cause races */ + return rt->learned + 1; +} + +gulong +rspamd_redis_learns(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + auto *rt = REDIS_RUNTIME(runtime); + + return rt->learned; +} + +ucl_object_t * +rspamd_redis_get_stat(gpointer runtime, + gpointer ctx) +{ + auto *rt = REDIS_RUNTIME(runtime); + + return ucl_object_ref(rt->ctx->cur_stat); +} + +gpointer +rspamd_redis_load_tokenizer_config(gpointer runtime, + gsize *len) +{ + return nullptr; +} diff --git a/src/libstat/backends/sqlite3_backend.c b/src/libstat/backends/sqlite3_backend.c new file mode 100644 index 0000000..2fd34d8 --- /dev/null +++ b/src/libstat/backends/sqlite3_backend.c @@ -0,0 +1,907 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "config.h" +#include "rspamd.h" +#include "sqlite3.h" +#include "libutil/sqlite_utils.h" +#include "libstat/stat_internal.h" +#include "libmime/message.h" +#include "lua/lua_common.h" +#include "unix-std.h" + +#define SQLITE3_BACKEND_TYPE "sqlite3" +#define SQLITE3_SCHEMA_VERSION "1" +#define SQLITE3_DEFAULT "default" + +struct rspamd_stat_sqlite3_db { + sqlite3 *sqlite; + gchar *fname; + GArray *prstmt; + lua_State *L; + rspamd_mempool_t *pool; + gboolean in_transaction; + gboolean enable_users; + gboolean enable_languages; + gint cbref_user; + gint cbref_language; +}; + +struct rspamd_stat_sqlite3_rt { + struct rspamd_task *task; + struct rspamd_stat_sqlite3_db *db; + struct rspamd_statfile_config *cf; + gint64 user_id; + gint64 lang_id; +}; + +static const char *create_tables_sql = + "BEGIN IMMEDIATE;" + "CREATE TABLE tokenizer(data BLOB);" + "CREATE TABLE users(" + "id INTEGER PRIMARY KEY," + "name TEXT," + "learns INTEGER" + ");" + "CREATE TABLE languages(" + "id INTEGER PRIMARY KEY," + "name TEXT," + "learns INTEGER" + ");" + "CREATE TABLE tokens(" + "token INTEGER NOT NULL," + "user INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE," + "language INTEGER NOT NULL REFERENCES languages(id) ON DELETE CASCADE," + "value INTEGER," + "modified INTEGER," + "CONSTRAINT tid UNIQUE (token, user, language) ON CONFLICT REPLACE" + ");" + "CREATE UNIQUE INDEX IF NOT EXISTS un ON users(name);" + "CREATE INDEX IF NOT EXISTS tok ON tokens(token);" + "CREATE UNIQUE INDEX IF NOT EXISTS ln ON languages(name);" + "PRAGMA user_version=" SQLITE3_SCHEMA_VERSION ";" + "INSERT INTO users(id, name, learns) VALUES(0, '" SQLITE3_DEFAULT "',0);" + "INSERT INTO languages(id, name, learns) VALUES(0, '" SQLITE3_DEFAULT "',0);" + "COMMIT;"; + +enum rspamd_stat_sqlite3_stmt_idx { + RSPAMD_STAT_BACKEND_TRANSACTION_START_IM = 0, + RSPAMD_STAT_BACKEND_TRANSACTION_START_DEF, + RSPAMD_STAT_BACKEND_TRANSACTION_START_EXCL, + RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT, + RSPAMD_STAT_BACKEND_TRANSACTION_ROLLBACK, + RSPAMD_STAT_BACKEND_GET_TOKEN_FULL, + RSPAMD_STAT_BACKEND_GET_TOKEN_SIMPLE, + RSPAMD_STAT_BACKEND_SET_TOKEN, + RSPAMD_STAT_BACKEND_INC_LEARNS_LANG, + RSPAMD_STAT_BACKEND_INC_LEARNS_USER, + RSPAMD_STAT_BACKEND_DEC_LEARNS_LANG, + RSPAMD_STAT_BACKEND_DEC_LEARNS_USER, + RSPAMD_STAT_BACKEND_GET_LEARNS, + RSPAMD_STAT_BACKEND_GET_LANGUAGE, + RSPAMD_STAT_BACKEND_GET_USER, + RSPAMD_STAT_BACKEND_INSERT_USER, + RSPAMD_STAT_BACKEND_INSERT_LANGUAGE, + RSPAMD_STAT_BACKEND_SAVE_TOKENIZER, + RSPAMD_STAT_BACKEND_LOAD_TOKENIZER, + RSPAMD_STAT_BACKEND_NTOKENS, + RSPAMD_STAT_BACKEND_NLANGUAGES, + RSPAMD_STAT_BACKEND_NUSERS, + RSPAMD_STAT_BACKEND_MAX +}; + +static struct rspamd_sqlite3_prstmt prepared_stmts[RSPAMD_STAT_BACKEND_MAX] = + { + [RSPAMD_STAT_BACKEND_TRANSACTION_START_IM] = { + .idx = RSPAMD_STAT_BACKEND_TRANSACTION_START_IM, + .sql = "BEGIN IMMEDIATE TRANSACTION;", + .args = "", + .stmt = NULL, + .result = SQLITE_DONE, + .flags = 0, + .ret = "", + }, + [RSPAMD_STAT_BACKEND_TRANSACTION_START_DEF] = {.idx = RSPAMD_STAT_BACKEND_TRANSACTION_START_DEF, .sql = "BEGIN DEFERRED TRANSACTION;", .args = "", .stmt = NULL, .result = SQLITE_DONE, .flags = 0, .ret = ""}, + [RSPAMD_STAT_BACKEND_TRANSACTION_START_EXCL] = {.idx = RSPAMD_STAT_BACKEND_TRANSACTION_START_EXCL, .sql = "BEGIN EXCLUSIVE TRANSACTION;", .args = "", .stmt = NULL, .result = SQLITE_DONE, .flags = 0, .ret = ""}, + [RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT] = {.idx = RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT, .sql = "COMMIT;", .args = "", .stmt = NULL, .result = SQLITE_DONE, .flags = 0, .ret = ""}, + [RSPAMD_STAT_BACKEND_TRANSACTION_ROLLBACK] = {.idx = RSPAMD_STAT_BACKEND_TRANSACTION_ROLLBACK, .sql = "ROLLBACK;", .args = "", .stmt = NULL, .result = SQLITE_DONE, .flags = 0, .ret = ""}, + [RSPAMD_STAT_BACKEND_GET_TOKEN_FULL] = {.idx = RSPAMD_STAT_BACKEND_GET_TOKEN_FULL, .sql = "SELECT value FROM tokens " + "LEFT JOIN languages ON tokens.language=languages.id " + "LEFT JOIN users ON tokens.user=users.id " + "WHERE token=?1 AND (users.id=?2) " + "AND (languages.id=?3 OR languages.id=0);", + .stmt = NULL, + .args = "III", + .result = SQLITE_ROW, + .flags = 0, + .ret = "I"}, + [RSPAMD_STAT_BACKEND_GET_TOKEN_SIMPLE] = {.idx = RSPAMD_STAT_BACKEND_GET_TOKEN_SIMPLE, .sql = "SELECT value FROM tokens WHERE token=?1", .stmt = NULL, .args = "I", .result = SQLITE_ROW, .flags = 0, .ret = "I"}, + [RSPAMD_STAT_BACKEND_SET_TOKEN] = {.idx = RSPAMD_STAT_BACKEND_SET_TOKEN, .sql = "INSERT OR REPLACE INTO tokens (token, user, language, value, modified) " + "VALUES (?1, ?2, ?3, ?4, strftime('%s','now'))", + .stmt = NULL, + .args = "IIII", + .result = SQLITE_DONE, + .flags = 0, + .ret = ""}, + [RSPAMD_STAT_BACKEND_INC_LEARNS_LANG] = {.idx = RSPAMD_STAT_BACKEND_INC_LEARNS_LANG, .sql = "UPDATE languages SET learns=learns + 1 WHERE id=?1", .stmt = NULL, .args = "I", .result = SQLITE_DONE, .flags = 0, .ret = ""}, + [RSPAMD_STAT_BACKEND_INC_LEARNS_USER] = {.idx = RSPAMD_STAT_BACKEND_INC_LEARNS_USER, .sql = "UPDATE users SET learns=learns + 1 WHERE id=?1", .stmt = NULL, .args = "I", .result = SQLITE_DONE, .flags = 0, .ret = ""}, + [RSPAMD_STAT_BACKEND_DEC_LEARNS_LANG] = {.idx = RSPAMD_STAT_BACKEND_DEC_LEARNS_LANG, .sql = "UPDATE languages SET learns=MAX(0, learns - 1) WHERE id=?1", .stmt = NULL, .args = "I", .result = SQLITE_DONE, .flags = 0, .ret = ""}, + [RSPAMD_STAT_BACKEND_DEC_LEARNS_USER] = {.idx = RSPAMD_STAT_BACKEND_DEC_LEARNS_USER, .sql = "UPDATE users SET learns=MAX(0, learns - 1) WHERE id=?1", .stmt = NULL, .args = "I", .result = SQLITE_DONE, .flags = 0, .ret = ""}, + [RSPAMD_STAT_BACKEND_GET_LEARNS] = {.idx = RSPAMD_STAT_BACKEND_GET_LEARNS, .sql = "SELECT SUM(MAX(0, learns)) FROM languages", .stmt = NULL, .args = "", .result = SQLITE_ROW, .flags = 0, .ret = "I"}, + [RSPAMD_STAT_BACKEND_GET_LANGUAGE] = {.idx = RSPAMD_STAT_BACKEND_GET_LANGUAGE, .sql = "SELECT id FROM languages WHERE name=?1", .stmt = NULL, .args = "T", .result = SQLITE_ROW, .flags = 0, .ret = "I"}, + [RSPAMD_STAT_BACKEND_GET_USER] = {.idx = RSPAMD_STAT_BACKEND_GET_USER, .sql = "SELECT id FROM users WHERE name=?1", .stmt = NULL, .args = "T", .result = SQLITE_ROW, .flags = 0, .ret = "I"}, + [RSPAMD_STAT_BACKEND_INSERT_USER] = {.idx = RSPAMD_STAT_BACKEND_INSERT_USER, .sql = "INSERT INTO users (name, learns) VALUES (?1, 0)", .stmt = NULL, .args = "T", .result = SQLITE_DONE, .flags = 0, .ret = "L"}, + [RSPAMD_STAT_BACKEND_INSERT_LANGUAGE] = {.idx = RSPAMD_STAT_BACKEND_INSERT_LANGUAGE, .sql = "INSERT INTO languages (name, learns) VALUES (?1, 0)", .stmt = NULL, .args = "T", .result = SQLITE_DONE, .flags = 0, .ret = "L"}, + [RSPAMD_STAT_BACKEND_SAVE_TOKENIZER] = {.idx = RSPAMD_STAT_BACKEND_SAVE_TOKENIZER, .sql = "INSERT INTO tokenizer(data) VALUES (?1)", .stmt = NULL, .args = "B", .result = SQLITE_DONE, .flags = 0, .ret = ""}, + [RSPAMD_STAT_BACKEND_LOAD_TOKENIZER] = {.idx = RSPAMD_STAT_BACKEND_LOAD_TOKENIZER, .sql = "SELECT data FROM tokenizer", .stmt = NULL, .args = "", .result = SQLITE_ROW, .flags = 0, .ret = "B"}, + [RSPAMD_STAT_BACKEND_NTOKENS] = {.idx = RSPAMD_STAT_BACKEND_NTOKENS, .sql = "SELECT COUNT(*) FROM tokens", .stmt = NULL, .args = "", .result = SQLITE_ROW, .flags = 0, .ret = "I"}, + [RSPAMD_STAT_BACKEND_NLANGUAGES] = {.idx = RSPAMD_STAT_BACKEND_NLANGUAGES, .sql = "SELECT COUNT(*) FROM languages", .stmt = NULL, .args = "", .result = SQLITE_ROW, .flags = 0, .ret = "I"}, + [RSPAMD_STAT_BACKEND_NUSERS] = {.idx = RSPAMD_STAT_BACKEND_NUSERS, .sql = "SELECT COUNT(*) FROM users", .stmt = NULL, .args = "", .result = SQLITE_ROW, .flags = 0, .ret = "I"}}; + +static GQuark +rspamd_sqlite3_backend_quark(void) +{ + return g_quark_from_static_string("sqlite3-stat-backend"); +} + +static gint64 +rspamd_sqlite3_get_user(struct rspamd_stat_sqlite3_db *db, + struct rspamd_task *task, gboolean learn) +{ + gint64 id = 0; /* Default user is 0 */ + gint rc, err_idx; + const gchar *user = NULL; + struct rspamd_task **ptask; + lua_State *L = db->L; + + if (db->cbref_user == -1) { + user = rspamd_task_get_principal_recipient(task); + } + else { + /* Execute lua function to get userdata */ + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + lua_rawgeti(L, LUA_REGISTRYINDEX, db->cbref_user); + ptask = lua_newuserdata(L, sizeof(struct rspamd_task *)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + + if (lua_pcall(L, 1, 1, err_idx) != 0) { + msg_err_task("call to user extraction script failed: %s", + lua_tostring(L, -1)); + } + else { + user = rspamd_mempool_strdup(task->task_pool, lua_tostring(L, -1)); + } + + /* Result + error function */ + lua_settop(L, err_idx - 1); + } + + + if (user != NULL) { + rspamd_mempool_set_variable(task->task_pool, "stat_user", + (gpointer) user, NULL); + + rc = rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt, + RSPAMD_STAT_BACKEND_GET_USER, user, &id); + + if (rc != SQLITE_OK && learn) { + /* We need to insert a new user */ + if (!db->in_transaction) { + rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt, + RSPAMD_STAT_BACKEND_TRANSACTION_START_IM); + db->in_transaction = TRUE; + } + + rc = rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt, + RSPAMD_STAT_BACKEND_INSERT_USER, user, &id); + } + } + + return id; +} + +static gint64 +rspamd_sqlite3_get_language(struct rspamd_stat_sqlite3_db *db, + struct rspamd_task *task, gboolean learn) +{ + gint64 id = 0; /* Default language is 0 */ + gint rc, err_idx; + guint i; + const gchar *language = NULL; + struct rspamd_mime_text_part *tp; + struct rspamd_task **ptask; + lua_State *L = db->L; + + if (db->cbref_language == -1) { + PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, tp) + { + + if (tp->language != NULL && tp->language[0] != '\0' && + strcmp(tp->language, "en") != 0) { + language = tp->language; + break; + } + } + } + else { + /* Execute lua function to get userdata */ + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + lua_rawgeti(L, LUA_REGISTRYINDEX, db->cbref_language); + ptask = lua_newuserdata(L, sizeof(struct rspamd_task *)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + + if (lua_pcall(L, 1, 1, err_idx) != 0) { + msg_err_task("call to language extraction script failed: %s", + lua_tostring(L, -1)); + } + else { + language = rspamd_mempool_strdup(task->task_pool, + lua_tostring(L, -1)); + } + + /* Result + error function */ + lua_settop(L, err_idx - 1); + } + + + /* XXX: We ignore multiple languages but default + extra */ + if (language != NULL) { + rc = rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt, + RSPAMD_STAT_BACKEND_GET_LANGUAGE, language, &id); + + if (rc != SQLITE_OK && learn) { + /* We need to insert a new language */ + if (!db->in_transaction) { + rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt, + RSPAMD_STAT_BACKEND_TRANSACTION_START_IM); + db->in_transaction = TRUE; + } + + rc = rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt, + RSPAMD_STAT_BACKEND_INSERT_LANGUAGE, language, &id); + } + } + + return id; +} + +static struct rspamd_stat_sqlite3_db * +rspamd_sqlite3_opendb(rspamd_mempool_t *pool, + struct rspamd_statfile_config *stcf, + const gchar *path, const ucl_object_t *opts, + gboolean create, GError **err) +{ + struct rspamd_stat_sqlite3_db *bk; + struct rspamd_stat_tokenizer *tokenizer; + gpointer tk_conf; + gsize sz = 0; + gint64 sz64 = 0; + gchar *tok_conf_encoded; + gint ret, ntries = 0; + const gint max_tries = 100; + struct timespec sleep_ts = { + .tv_sec = 0, + .tv_nsec = 1000000}; + + bk = g_malloc0(sizeof(*bk)); + bk->sqlite = rspamd_sqlite3_open_or_create(pool, path, create_tables_sql, + 0, err); + bk->pool = pool; + + if (bk->sqlite == NULL) { + g_free(bk); + + return NULL; + } + + bk->fname = g_strdup(path); + + bk->prstmt = rspamd_sqlite3_init_prstmt(bk->sqlite, prepared_stmts, + RSPAMD_STAT_BACKEND_MAX, err); + + if (bk->prstmt == NULL) { + sqlite3_close(bk->sqlite); + g_free(bk); + + return NULL; + } + + /* Check tokenizer configuration */ + if (rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_LOAD_TOKENIZER, &sz64, &tk_conf) != SQLITE_OK || + sz64 == 0) { + + while ((ret = rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_TRANSACTION_START_EXCL)) == SQLITE_BUSY && + ++ntries <= max_tries) { + nanosleep(&sleep_ts, NULL); + } + + msg_info_pool("absent tokenizer conf in %s, creating a new one", + bk->fname); + g_assert(stcf->clcf->tokenizer != NULL); + tokenizer = rspamd_stat_get_tokenizer(stcf->clcf->tokenizer->name); + g_assert(tokenizer != NULL); + tk_conf = tokenizer->get_config(pool, stcf->clcf->tokenizer, &sz); + + /* Encode to base32 */ + tok_conf_encoded = rspamd_encode_base32(tk_conf, sz, RSPAMD_BASE32_DEFAULT); + + if (rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_SAVE_TOKENIZER, + (gint64) strlen(tok_conf_encoded), + tok_conf_encoded) != SQLITE_OK) { + sqlite3_close(bk->sqlite); + g_free(bk); + g_free(tok_conf_encoded); + + return NULL; + } + + rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT); + g_free(tok_conf_encoded); + } + else { + g_free(tk_conf); + } + + return bk; +} + +gpointer +rspamd_sqlite3_init(struct rspamd_stat_ctx *ctx, + struct rspamd_config *cfg, + struct rspamd_statfile *st) +{ + struct rspamd_classifier_config *clf = st->classifier->cfg; + struct rspamd_statfile_config *stf = st->stcf; + const ucl_object_t *filenameo, *lang_enabled, *users_enabled; + const gchar *filename, *lua_script; + struct rspamd_stat_sqlite3_db *bk; + GError *err = NULL; + + filenameo = ucl_object_lookup(stf->opts, "filename"); + if (filenameo == NULL || ucl_object_type(filenameo) != UCL_STRING) { + filenameo = ucl_object_lookup(stf->opts, "path"); + if (filenameo == NULL || ucl_object_type(filenameo) != UCL_STRING) { + msg_err_config("statfile %s has no filename defined", stf->symbol); + return NULL; + } + } + + filename = ucl_object_tostring(filenameo); + + if ((bk = rspamd_sqlite3_opendb(cfg->cfg_pool, stf, filename, + stf->opts, TRUE, &err)) == NULL) { + msg_err_config("cannot open sqlite3 db %s: %e", filename, err); + g_error_free(err); + return NULL; + } + + bk->L = cfg->lua_state; + + users_enabled = ucl_object_lookup_any(clf->opts, "per_user", + "users_enabled", NULL); + if (users_enabled != NULL) { + if (ucl_object_type(users_enabled) == UCL_BOOLEAN) { + bk->enable_users = ucl_object_toboolean(users_enabled); + bk->cbref_user = -1; + } + else if (ucl_object_type(users_enabled) == UCL_STRING) { + lua_script = ucl_object_tostring(users_enabled); + + if (luaL_dostring(cfg->lua_state, lua_script) != 0) { + msg_err_config("cannot execute lua script for users " + "extraction: %s", + lua_tostring(cfg->lua_state, -1)); + } + else { + if (lua_type(cfg->lua_state, -1) == LUA_TFUNCTION) { + bk->enable_users = TRUE; + bk->cbref_user = luaL_ref(cfg->lua_state, + LUA_REGISTRYINDEX); + } + else { + msg_err_config("lua script must return " + "function(task) and not %s", + lua_typename(cfg->lua_state, lua_type( + cfg->lua_state, -1))); + } + } + } + } + else { + bk->enable_users = FALSE; + } + + lang_enabled = ucl_object_lookup_any(clf->opts, + "per_language", "languages_enabled", NULL); + + if (lang_enabled != NULL) { + if (ucl_object_type(lang_enabled) == UCL_BOOLEAN) { + bk->enable_languages = ucl_object_toboolean(lang_enabled); + bk->cbref_language = -1; + } + else if (ucl_object_type(lang_enabled) == UCL_STRING) { + lua_script = ucl_object_tostring(lang_enabled); + + if (luaL_dostring(cfg->lua_state, lua_script) != 0) { + msg_err_config( + "cannot execute lua script for languages " + "extraction: %s", + lua_tostring(cfg->lua_state, -1)); + } + else { + if (lua_type(cfg->lua_state, -1) == LUA_TFUNCTION) { + bk->enable_languages = TRUE; + bk->cbref_language = luaL_ref(cfg->lua_state, + LUA_REGISTRYINDEX); + } + else { + msg_err_config("lua script must return " + "function(task) and not %s", + lua_typename(cfg->lua_state, + lua_type(cfg->lua_state, -1))); + } + } + } + } + else { + bk->enable_languages = FALSE; + } + + if (bk->enable_languages) { + msg_info_config("enable per language statistics for %s", + stf->symbol); + } + + if (bk->enable_users) { + msg_info_config("enable per users statistics for %s", + stf->symbol); + } + + + return (gpointer) bk; +} + +void rspamd_sqlite3_close(gpointer p) +{ + struct rspamd_stat_sqlite3_db *bk = p; + + if (bk->sqlite) { + if (bk->in_transaction) { + rspamd_sqlite3_run_prstmt(bk->pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT); + } + + rspamd_sqlite3_close_prstmt(bk->sqlite, bk->prstmt); + sqlite3_close(bk->sqlite); + g_free(bk->fname); + g_free(bk); + } +} + +gpointer +rspamd_sqlite3_runtime(struct rspamd_task *task, + struct rspamd_statfile_config *stcf, gboolean learn, gpointer p, gint _id) +{ + struct rspamd_stat_sqlite3_rt *rt = NULL; + struct rspamd_stat_sqlite3_db *bk = p; + + if (bk) { + rt = rspamd_mempool_alloc(task->task_pool, sizeof(*rt)); + rt->db = bk; + rt->task = task; + rt->user_id = -1; + rt->lang_id = -1; + rt->cf = stcf; + } + + return rt; +} + +gboolean +rspamd_sqlite3_process_tokens(struct rspamd_task *task, + GPtrArray *tokens, + gint id, gpointer p) +{ + struct rspamd_stat_sqlite3_db *bk; + struct rspamd_stat_sqlite3_rt *rt = p; + gint64 iv = 0; + guint i; + rspamd_token_t *tok; + + g_assert(p != NULL); + g_assert(tokens != NULL); + + bk = rt->db; + + for (i = 0; i < tokens->len; i++) { + tok = g_ptr_array_index(tokens, i); + + if (bk == NULL) { + /* Statfile is does not exist, so all values are zero */ + tok->values[id] = 0.0f; + continue; + } + + if (!bk->in_transaction) { + rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_TRANSACTION_START_DEF); + bk->in_transaction = TRUE; + } + + if (rt->user_id == -1) { + if (bk->enable_users) { + rt->user_id = rspamd_sqlite3_get_user(bk, task, FALSE); + } + else { + rt->user_id = 0; + } + } + + if (rt->lang_id == -1) { + if (bk->enable_languages) { + rt->lang_id = rspamd_sqlite3_get_language(bk, task, FALSE); + } + else { + rt->lang_id = 0; + } + } + + if (bk->enable_languages || bk->enable_users) { + if (rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_GET_TOKEN_FULL, + tok->data, rt->user_id, rt->lang_id, &iv) == SQLITE_OK) { + tok->values[id] = iv; + } + else { + tok->values[id] = 0.0f; + } + } + else { + if (rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_GET_TOKEN_SIMPLE, + tok->data, &iv) == SQLITE_OK) { + tok->values[id] = iv; + } + else { + tok->values[id] = 0.0f; + } + } + + if (rt->cf->is_spam) { + task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS; + } + else { + task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; + } + } + + + return TRUE; +} + +gboolean +rspamd_sqlite3_finalize_process(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + struct rspamd_stat_sqlite3_rt *rt = runtime; + struct rspamd_stat_sqlite3_db *bk; + + g_assert(rt != NULL); + bk = rt->db; + + if (bk->in_transaction) { + rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT); + bk->in_transaction = FALSE; + } + + rt->lang_id = -1; + rt->user_id = -1; + + return TRUE; +} + +gboolean +rspamd_sqlite3_learn_tokens(struct rspamd_task *task, GPtrArray *tokens, + gint id, gpointer p) +{ + struct rspamd_stat_sqlite3_db *bk; + struct rspamd_stat_sqlite3_rt *rt = p; + gint64 iv = 0; + guint i; + rspamd_token_t *tok; + + g_assert(tokens != NULL); + g_assert(p != NULL); + + bk = rt->db; + + for (i = 0; i < tokens->len; i++) { + tok = g_ptr_array_index(tokens, i); + if (bk == NULL) { + /* Statfile is does not exist, so all values are zero */ + return FALSE; + } + + if (!bk->in_transaction) { + rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_TRANSACTION_START_IM); + bk->in_transaction = TRUE; + } + + if (rt->user_id == -1) { + if (bk->enable_users) { + rt->user_id = rspamd_sqlite3_get_user(bk, task, TRUE); + } + else { + rt->user_id = 0; + } + } + + if (rt->lang_id == -1) { + if (bk->enable_languages) { + rt->lang_id = rspamd_sqlite3_get_language(bk, task, TRUE); + } + else { + rt->lang_id = 0; + } + } + + iv = tok->values[id]; + + if (rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_SET_TOKEN, + tok->data, rt->user_id, rt->lang_id, iv) != SQLITE_OK) { + rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_TRANSACTION_ROLLBACK); + bk->in_transaction = FALSE; + + return FALSE; + } + } + + return TRUE; +} + +gboolean +rspamd_sqlite3_finalize_learn(struct rspamd_task *task, gpointer runtime, + gpointer ctx, GError **err) +{ + struct rspamd_stat_sqlite3_rt *rt = runtime; + struct rspamd_stat_sqlite3_db *bk; + gint wal_frames, wal_checkpointed, mode; + + g_assert(rt != NULL); + bk = rt->db; + + if (bk->in_transaction) { + rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT); + bk->in_transaction = FALSE; + } + +#ifdef SQLITE_OPEN_WAL +#ifdef SQLITE_CHECKPOINT_TRUNCATE + mode = SQLITE_CHECKPOINT_TRUNCATE; +#elif defined(SQLITE_CHECKPOINT_RESTART) + mode = SQLITE_CHECKPOINT_RESTART; +#elif defined(SQLITE_CHECKPOINT_FULL) + mode = SQLITE_CHECKPOINT_FULL; +#endif + /* Perform wal checkpoint (might be long) */ + if (sqlite3_wal_checkpoint_v2(bk->sqlite, + NULL, + mode, + &wal_frames, + &wal_checkpointed) != SQLITE_OK) { + msg_warn_task("cannot commit checkpoint: %s", + sqlite3_errmsg(bk->sqlite)); + + g_set_error(err, rspamd_sqlite3_backend_quark(), 500, + "cannot commit checkpoint: %s", + sqlite3_errmsg(bk->sqlite)); + return FALSE; + } +#endif + + return TRUE; +} + +gulong +rspamd_sqlite3_total_learns(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + struct rspamd_stat_sqlite3_rt *rt = runtime; + struct rspamd_stat_sqlite3_db *bk; + guint64 res; + + g_assert(rt != NULL); + bk = rt->db; + rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_GET_LEARNS, &res); + + return res; +} + +gulong +rspamd_sqlite3_inc_learns(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + struct rspamd_stat_sqlite3_rt *rt = runtime; + struct rspamd_stat_sqlite3_db *bk; + guint64 res; + + g_assert(rt != NULL); + bk = rt->db; + rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_INC_LEARNS_LANG, + rt->lang_id); + rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_INC_LEARNS_USER, + rt->user_id); + + if (bk->in_transaction) { + rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT); + bk->in_transaction = FALSE; + } + + rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_GET_LEARNS, &res); + + return res; +} + +gulong +rspamd_sqlite3_dec_learns(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + struct rspamd_stat_sqlite3_rt *rt = runtime; + struct rspamd_stat_sqlite3_db *bk; + guint64 res; + + g_assert(rt != NULL); + bk = rt->db; + rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_DEC_LEARNS_LANG, + rt->lang_id); + rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_DEC_LEARNS_USER, + rt->user_id); + + if (bk->in_transaction) { + rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT); + bk->in_transaction = FALSE; + } + + rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_GET_LEARNS, &res); + + return res; +} + +gulong +rspamd_sqlite3_learns(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + struct rspamd_stat_sqlite3_rt *rt = runtime; + struct rspamd_stat_sqlite3_db *bk; + guint64 res; + + g_assert(rt != NULL); + bk = rt->db; + rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_GET_LEARNS, &res); + + return res; +} + +ucl_object_t * +rspamd_sqlite3_get_stat(gpointer runtime, + gpointer ctx) +{ + ucl_object_t *res = NULL; + struct rspamd_stat_sqlite3_rt *rt = runtime; + struct rspamd_stat_sqlite3_db *bk; + rspamd_mempool_t *pool; + struct stat st; + gint64 rev; + + g_assert(rt != NULL); + bk = rt->db; + pool = bk->pool; + + (void) stat(bk->fname, &st); + rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_GET_LEARNS, &rev); + + res = ucl_object_typed_new(UCL_OBJECT); + ucl_object_insert_key(res, ucl_object_fromint(rev), "revision", + 0, false); + ucl_object_insert_key(res, ucl_object_fromint(st.st_size), "size", + 0, false); + rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_NTOKENS, &rev); + ucl_object_insert_key(res, ucl_object_fromint(rev), "total", 0, false); + ucl_object_insert_key(res, ucl_object_fromint(rev), "used", 0, false); + ucl_object_insert_key(res, ucl_object_fromstring(rt->cf->symbol), + "symbol", 0, false); + ucl_object_insert_key(res, ucl_object_fromstring("sqlite3"), + "type", 0, false); + rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_NLANGUAGES, &rev); + ucl_object_insert_key(res, ucl_object_fromint(rev), + "languages", 0, false); + rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_NUSERS, &rev); + ucl_object_insert_key(res, ucl_object_fromint(rev), + "users", 0, false); + + if (rt->cf->label) { + ucl_object_insert_key(res, ucl_object_fromstring(rt->cf->label), + "label", 0, false); + } + + return res; +} + +gpointer +rspamd_sqlite3_load_tokenizer_config(gpointer runtime, + gsize *len) +{ + gpointer tk_conf, copied_conf; + guint64 sz; + struct rspamd_stat_sqlite3_rt *rt = runtime; + struct rspamd_stat_sqlite3_db *bk; + + g_assert(rt != NULL); + bk = rt->db; + + g_assert(rspamd_sqlite3_run_prstmt(rt->db->pool, bk->sqlite, bk->prstmt, + RSPAMD_STAT_BACKEND_LOAD_TOKENIZER, &sz, &tk_conf) == SQLITE_OK); + g_assert(sz > 0); + /* + * Here we can have either decoded or undecoded version of tokenizer config + * XXX: dirty hack to check if we have osb magic here + */ + if (sz > 7 && memcmp(tk_conf, "osbtokv", 7) == 0) { + copied_conf = rspamd_mempool_alloc(rt->task->task_pool, sz); + memcpy(copied_conf, tk_conf, sz); + g_free(tk_conf); + } + else { + /* Need to decode */ + copied_conf = rspamd_decode_base32(tk_conf, sz, len, RSPAMD_BASE32_DEFAULT); + g_free(tk_conf); + rspamd_mempool_add_destructor(rt->task->task_pool, g_free, copied_conf); + } + + if (len) { + *len = sz; + } + + return copied_conf; +} diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c new file mode 100644 index 0000000..513db9a --- /dev/null +++ b/src/libstat/classifiers/bayes.c @@ -0,0 +1,551 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Bayesian classifier + */ +#include "classifiers.h" +#include "rspamd.h" +#include "stat_internal.h" +#include "math.h" + +#define msg_err_bayes(...) rspamd_default_log_function(G_LOG_LEVEL_CRITICAL, \ + "bayes", task->task_pool->tag.uid, \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) +#define msg_warn_bayes(...) rspamd_default_log_function(G_LOG_LEVEL_WARNING, \ + "bayes", task->task_pool->tag.uid, \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) +#define msg_info_bayes(...) rspamd_default_log_function(G_LOG_LEVEL_INFO, \ + "bayes", task->task_pool->tag.uid, \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) + +INIT_LOG_MODULE_PUBLIC(bayes) + +static inline GQuark +bayes_error_quark(void) +{ + return g_quark_from_static_string("bayes-error"); +} + +/** + * Returns probability of chisquare > value with specified number of freedom + * degrees + * @param value value to test + * @param freedom_deg number of degrees of freedom + * @return + */ +static gdouble +inv_chi_square(struct rspamd_task *task, gdouble value, gint freedom_deg) +{ + double prob, sum, m; + gint i; + + errno = 0; + m = -value; + prob = exp(value); + + if (errno == ERANGE) { + /* + * e^x where x is large *NEGATIVE* number is OK, so we have a very strong + * confidence that inv-chi-square is close to zero + */ + msg_debug_bayes("exp overflow"); + + if (value < 0) { + return 0; + } + else { + return 1.0; + } + } + + sum = prob; + + msg_debug_bayes("m: %f, probability: %g", m, prob); + + /* + * m is our confidence in class + * prob is e ^ x (small value since x is normally less than zero + * So we integrate over degrees of freedom and produce the total result + * from 1.0 (no confidence) to 0.0 (full confidence) + */ + for (i = 1; i < freedom_deg; i++) { + prob *= m / (gdouble) i; + sum += prob; + msg_debug_bayes("i=%d, probability: %g, sum: %g", i, prob, sum); + } + + return MIN(1.0, sum); +} + +struct bayes_task_closure { + double ham_prob; + double spam_prob; + gdouble meta_skip_prob; + guint64 processed_tokens; + guint64 total_hits; + guint64 text_tokens; + struct rspamd_task *task; +}; + +/* + * Mathematically we use pow(complexity, complexity), where complexity is the + * window index + */ +static const double feature_weight[] = {0, 3125, 256, 27, 1, 0, 0, 0}; + +#define PROB_COMBINE(prob, cnt, weight, assumed) (((weight) * (assumed) + (cnt) * (prob)) / ((weight) + (cnt))) +/* + * In this callback we calculate local probabilities for tokens + */ +static void +bayes_classify_token(struct rspamd_classifier *ctx, + rspamd_token_t *tok, struct bayes_task_closure *cl) +{ + guint i; + gint id; + guint spam_count = 0, ham_count = 0, total_count = 0; + struct rspamd_statfile *st; + struct rspamd_task *task; + const gchar *token_type = "txt"; + double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob, + ham_prob, fw, w, val; + + task = cl->task; + +#if 0 + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_LUA_META) { + /* Ignore lua metatokens for now */ + return; + } +#endif + + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_META && cl->meta_skip_prob > 0) { + val = rspamd_random_double_fast(); + + if (val <= cl->meta_skip_prob) { + if (tok->t1 && tok->t2) { + msg_debug_bayes( + "token(meta) %uL <%*s:%*s> probabilistically skipped", + tok->data, + (int) tok->t1->original.len, tok->t1->original.begin, + (int) tok->t2->original.len, tok->t2->original.begin); + } + + return; + } + } + + for (i = 0; i < ctx->statfiles_ids->len; i++) { + id = g_array_index(ctx->statfiles_ids, gint, i); + st = g_ptr_array_index(ctx->ctx->statfiles, id); + g_assert(st != NULL); + val = tok->values[id]; + + if (val > 0) { + if (st->stcf->is_spam) { + spam_count += val; + } + else { + ham_count += val; + } + + total_count += val; + cl->total_hits += val; + } + } + + /* Probability for this token */ + if (total_count >= ctx->cfg->min_token_hits) { + spam_freq = ((double) spam_count / MAX(1., (double) ctx->spam_learns)); + ham_freq = ((double) ham_count / MAX(1., (double) ctx->ham_learns)); + spam_prob = spam_freq / (spam_freq + ham_freq); + ham_prob = ham_freq / (spam_freq + ham_freq); + + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UNIGRAM) { + fw = 1.0; + } + else { + fw = feature_weight[tok->window_idx % + G_N_ELEMENTS(feature_weight)]; + } + + + w = (fw * total_count) / (1.0 + fw * total_count); + + bayes_spam_prob = PROB_COMBINE(spam_prob, total_count, w, 0.5); + + if ((bayes_spam_prob > 0.5 && bayes_spam_prob < 0.5 + ctx->cfg->min_prob_strength) || + (bayes_spam_prob < 0.5 && bayes_spam_prob > 0.5 - ctx->cfg->min_prob_strength)) { + msg_debug_bayes( + "token %uL <%*s:%*s> skipped, probability not in range: %f", + tok->data, + (int) tok->t1->stemmed.len, tok->t1->stemmed.begin, + (int) tok->t2->stemmed.len, tok->t2->stemmed.begin, + bayes_spam_prob); + + return; + } + + bayes_ham_prob = PROB_COMBINE(ham_prob, total_count, w, 0.5); + + cl->spam_prob += log(bayes_spam_prob); + cl->ham_prob += log(bayes_ham_prob); + cl->processed_tokens++; + + if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) { + cl->text_tokens++; + } + else { + token_type = "meta"; + } + + if (tok->t1 && tok->t2) { + msg_debug_bayes("token(%s) %uL <%*s:%*s>: weight: %f, cf: %f, " + "total_count: %ud, " + "spam_count: %ud, ham_count: %ud," + "spam_prob: %.3f, ham_prob: %.3f, " + "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, " + "current spam probability: %.3f, current ham probability: %.3f", + token_type, + tok->data, + (int) tok->t1->stemmed.len, tok->t1->stemmed.begin, + (int) tok->t2->stemmed.len, tok->t2->stemmed.begin, + fw, w, total_count, spam_count, ham_count, + spam_prob, ham_prob, + bayes_spam_prob, bayes_ham_prob, + cl->spam_prob, cl->ham_prob); + } + else { + msg_debug_bayes("token(%s) %uL <?:?>: weight: %f, cf: %f, " + "total_count: %ud, " + "spam_count: %ud, ham_count: %ud," + "spam_prob: %.3f, ham_prob: %.3f, " + "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, " + "current spam probability: %.3f, current ham probability: %.3f", + token_type, + tok->data, + fw, w, total_count, spam_count, ham_count, + spam_prob, ham_prob, + bayes_spam_prob, bayes_ham_prob, + cl->spam_prob, cl->ham_prob); + } + } +} + + +gboolean +bayes_init(struct rspamd_config *cfg, + struct ev_loop *ev_base, + struct rspamd_classifier *cl) +{ + cl->cfg->flags |= RSPAMD_FLAG_CLASSIFIER_INTEGER; + + return TRUE; +} + +void bayes_fin(struct rspamd_classifier *cl) +{ +} + +gboolean +bayes_classify(struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task) +{ + double final_prob, h, s, *pprob; + gchar sumbuf[32]; + struct rspamd_statfile *st = NULL; + struct bayes_task_closure cl; + rspamd_token_t *tok; + guint i, text_tokens = 0; + gint id; + + g_assert(ctx != NULL); + g_assert(tokens != NULL); + + memset(&cl, 0, sizeof(cl)); + cl.task = task; + + /* Check min learns */ + if (ctx->cfg->min_learns > 0) { + if (ctx->ham_learns < ctx->cfg->min_learns) { + msg_info_task("not classified as ham. The ham class needs more " + "training samples. Currently: %ul; minimum %ud required", + ctx->ham_learns, ctx->cfg->min_learns); + + return TRUE; + } + if (ctx->spam_learns < ctx->cfg->min_learns) { + msg_info_task("not classified as spam. The spam class needs more " + "training samples. Currently: %ul; minimum %ud required", + ctx->spam_learns, ctx->cfg->min_learns); + + return TRUE; + } + } + + for (i = 0; i < tokens->len; i++) { + tok = g_ptr_array_index(tokens, i); + if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) { + text_tokens++; + } + } + + if (text_tokens == 0) { + msg_info_task("skipped classification as there are no text tokens. " + "Total tokens: %ud", + tokens->len); + + return TRUE; + } + + /* + * Skip some metatokens if we don't have enough text tokens + */ + if (text_tokens > tokens->len - text_tokens) { + cl.meta_skip_prob = 0.0; + } + else { + cl.meta_skip_prob = 1.0 - text_tokens / tokens->len; + } + + for (i = 0; i < tokens->len; i++) { + tok = g_ptr_array_index(tokens, i); + + bayes_classify_token(ctx, tok, &cl); + } + + if (cl.processed_tokens == 0) { + msg_info_bayes("no tokens found in bayes database " + "(%ud total tokens, %ud text tokens), ignore stats", + tokens->len, text_tokens); + + return TRUE; + } + + if (ctx->cfg->min_tokens > 0 && + cl.text_tokens < (gint) (ctx->cfg->min_tokens * 0.1)) { + msg_info_bayes("ignore bayes probability since we have " + "found too few text tokens: %uL (of %ud checked), " + "at least %d required", + cl.text_tokens, + text_tokens, + (gint) (ctx->cfg->min_tokens * 0.1)); + + return TRUE; + } + + if (cl.spam_prob > -300 && cl.ham_prob > -300) { + /* Fisher value is low enough to apply inv_chi_square */ + h = 1 - inv_chi_square(task, cl.spam_prob, cl.processed_tokens); + s = 1 - inv_chi_square(task, cl.ham_prob, cl.processed_tokens); + } + else { + /* Use naive method */ + if (cl.spam_prob < cl.ham_prob) { + h = (1.0 - exp(cl.spam_prob - cl.ham_prob)) / + (1.0 + exp(cl.spam_prob - cl.ham_prob)); + s = 1.0 - h; + } + else { + s = (1.0 - exp(cl.ham_prob - cl.spam_prob)) / + (1.0 + exp(cl.ham_prob - cl.spam_prob)); + h = 1.0 - s; + } + } + + if (isfinite(s) && isfinite(h)) { + final_prob = (s + 1.0 - h) / 2.; + msg_debug_bayes( + "got ham probability %.2f -> %.2f and spam probability %.2f -> %.2f," + " %L tokens processed of %ud total tokens;" + " %uL text tokens found of %ud text tokens)", + cl.ham_prob, + h, + cl.spam_prob, + s, + cl.processed_tokens, + tokens->len, + cl.text_tokens, + text_tokens); + } + else { + /* + * We have some overflow, hence we need to check which class + * is NaN + */ + if (isfinite(h)) { + final_prob = 1.0; + msg_debug_bayes("spam class is full: no" + " ham samples"); + } + else if (isfinite(s)) { + final_prob = 0.0; + msg_debug_bayes("ham class is full: no" + " spam samples"); + } + else { + final_prob = 0.5; + msg_warn_bayes("spam and ham classes are both full"); + } + } + + pprob = rspamd_mempool_alloc(task->task_pool, sizeof(*pprob)); + *pprob = final_prob; + rspamd_mempool_set_variable(task->task_pool, "bayes_prob", pprob, NULL); + + if (cl.processed_tokens > 0 && fabs(final_prob - 0.5) > 0.05) { + /* Now we can have exactly one HAM and exactly one SPAM statfiles per classifier */ + for (i = 0; i < ctx->statfiles_ids->len; i++) { + id = g_array_index(ctx->statfiles_ids, gint, i); + st = g_ptr_array_index(ctx->ctx->statfiles, id); + + if (final_prob > 0.5 && st->stcf->is_spam) { + break; + } + else if (final_prob < 0.5 && !st->stcf->is_spam) { + break; + } + } + + /* Correctly scale HAM */ + if (final_prob < 0.5) { + final_prob = 1.0 - final_prob; + } + + /* + * Bayes p is from 0.5 to 1.0, but confidence is from 0 to 1, so + * we need to rescale it to display correctly + */ + rspamd_snprintf(sumbuf, sizeof(sumbuf), "%.2f%%", + (final_prob - 0.5) * 200.); + final_prob = rspamd_normalize_probability(final_prob, 0.5); + g_assert(st != NULL); + + if (final_prob > 1 || final_prob < 0) { + msg_err_bayes("internal error: probability %f is outside of the " + "allowed range [0..1]", + final_prob); + + if (final_prob > 1) { + final_prob = 1.0; + } + else { + final_prob = 0.0; + } + } + + rspamd_task_insert_result(task, + st->stcf->symbol, + final_prob, + sumbuf); + } + + return TRUE; +} + +gboolean +bayes_learn_spam(struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task, + gboolean is_spam, + gboolean unlearn, + GError **err) +{ + guint i, j, total_cnt, spam_cnt, ham_cnt; + gint id; + struct rspamd_statfile *st; + rspamd_token_t *tok; + gboolean incrementing; + + g_assert(ctx != NULL); + g_assert(tokens != NULL); + + incrementing = ctx->cfg->flags & RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND; + + for (i = 0; i < tokens->len; i++) { + total_cnt = 0; + spam_cnt = 0; + ham_cnt = 0; + tok = g_ptr_array_index(tokens, i); + + for (j = 0; j < ctx->statfiles_ids->len; j++) { + id = g_array_index(ctx->statfiles_ids, gint, j); + st = g_ptr_array_index(ctx->ctx->statfiles, id); + g_assert(st != NULL); + + if (!!st->stcf->is_spam == !!is_spam) { + if (incrementing) { + tok->values[id] = 1; + } + else { + tok->values[id]++; + } + + total_cnt += tok->values[id]; + + if (st->stcf->is_spam) { + spam_cnt += tok->values[id]; + } + else { + ham_cnt += tok->values[id]; + } + } + else { + if (tok->values[id] > 0 && unlearn) { + /* Unlearning */ + if (incrementing) { + tok->values[id] = -1; + } + else { + tok->values[id]--; + } + + if (st->stcf->is_spam) { + spam_cnt += tok->values[id]; + } + else { + ham_cnt += tok->values[id]; + } + total_cnt += tok->values[id]; + } + else if (incrementing) { + tok->values[id] = 0; + } + } + } + + if (tok->t1 && tok->t2) { + msg_debug_bayes("token %uL <%*s:%*s>: window: %d, total_count: %d, " + "spam_count: %d, ham_count: %d", + tok->data, + (int) tok->t1->stemmed.len, tok->t1->stemmed.begin, + (int) tok->t2->stemmed.len, tok->t2->stemmed.begin, + tok->window_idx, total_cnt, spam_cnt, ham_cnt); + } + else { + msg_debug_bayes("token %uL <?:?>: window: %d, total_count: %d, " + "spam_count: %d, ham_count: %d", + tok->data, + tok->window_idx, total_cnt, spam_cnt, ham_cnt); + } + } + + return TRUE; +} diff --git a/src/libstat/classifiers/classifiers.h b/src/libstat/classifiers/classifiers.h new file mode 100644 index 0000000..949408c --- /dev/null +++ b/src/libstat/classifiers/classifiers.h @@ -0,0 +1,109 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CLASSIFIERS_H +#define CLASSIFIERS_H + +#include "config.h" +#include "mem_pool.h" +#include "contrib/libev/ev.h" + +#define RSPAMD_DEFAULT_CLASSIFIER "bayes" +/* Consider this value as 0 */ +#define ALPHA 0.0001 + +#ifdef __cplusplus +extern "C" { +#endif + +struct rspamd_classifier_config; +struct rspamd_task; +struct rspamd_config; +struct rspamd_classifier; + +struct token_node_s; + +struct rspamd_stat_classifier { + char *name; + + gboolean (*init_func)(struct rspamd_config *cfg, + struct ev_loop *ev_base, + struct rspamd_classifier *cl); + + gboolean (*classify_func)(struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task); + + gboolean (*learn_spam_func)(struct rspamd_classifier *ctx, + GPtrArray *input, + struct rspamd_task *task, + gboolean is_spam, + gboolean unlearn, + GError **err); + + void (*fin_func)(struct rspamd_classifier *cl); +}; + +/* Bayes algorithm */ +gboolean bayes_init(struct rspamd_config *cfg, + struct ev_loop *ev_base, + struct rspamd_classifier *); + +gboolean bayes_classify(struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task); + +gboolean bayes_learn_spam(struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task, + gboolean is_spam, + gboolean unlearn, + GError **err); + +void bayes_fin(struct rspamd_classifier *); + +/* Generic lua classifier */ +gboolean lua_classifier_init(struct rspamd_config *cfg, + struct ev_loop *ev_base, + struct rspamd_classifier *); + +gboolean lua_classifier_classify(struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task); + +gboolean lua_classifier_learn_spam(struct rspamd_classifier *ctx, + GPtrArray *tokens, + struct rspamd_task *task, + gboolean is_spam, + gboolean unlearn, + GError **err); + +extern gint rspamd_bayes_log_id; +#define msg_debug_bayes(...) rspamd_conditional_debug_fast(NULL, task->from_addr, \ + rspamd_bayes_log_id, "bayes", task->task_pool->tag.uid, \ + G_STRFUNC, \ + __VA_ARGS__) +#define msg_debug_bayes_cfg(...) rspamd_conditional_debug_fast(NULL, NULL, \ + rspamd_bayes_log_id, "bayes", cfg->cfg_pool->tag.uid, \ + G_STRFUNC, \ + __VA_ARGS__) + + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/src/libstat/classifiers/lua_classifier.c b/src/libstat/classifiers/lua_classifier.c new file mode 100644 index 0000000..b74330d --- /dev/null +++ b/src/libstat/classifiers/lua_classifier.c @@ -0,0 +1,237 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "classifiers.h" +#include "cfg_file.h" +#include "stat_internal.h" +#include "lua/lua_common.h" + +struct rspamd_lua_classifier_ctx { + gchar *name; + gint classify_ref; + gint learn_ref; +}; + +static GHashTable *lua_classifiers = NULL; + +#define msg_err_luacl(...) rspamd_default_log_function(G_LOG_LEVEL_CRITICAL, \ + "luacl", task->task_pool->tag.uid, \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) +#define msg_warn_luacl(...) rspamd_default_log_function(G_LOG_LEVEL_WARNING, \ + "luacl", task->task_pool->tag.uid, \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) +#define msg_info_luacl(...) rspamd_default_log_function(G_LOG_LEVEL_INFO, \ + "luacl", task->task_pool->tag.uid, \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) +#define msg_debug_luacl(...) rspamd_conditional_debug_fast(NULL, task->from_addr, \ + rspamd_luacl_log_id, "luacl", task->task_pool->tag.uid, \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) + +INIT_LOG_MODULE(luacl) + +gboolean +lua_classifier_init(struct rspamd_config *cfg, + struct ev_loop *ev_base, + struct rspamd_classifier *cl) +{ + struct rspamd_lua_classifier_ctx *ctx; + lua_State *L = cl->ctx->cfg->lua_state; + gint cb_classify = -1, cb_learn = -1; + + if (lua_classifiers == NULL) { + lua_classifiers = g_hash_table_new_full(rspamd_strcase_hash, + rspamd_strcase_equal, g_free, g_free); + } + + ctx = g_hash_table_lookup(lua_classifiers, cl->subrs->name); + + if (ctx != NULL) { + msg_err_config("duplicate lua classifier definition: %s", + cl->subrs->name); + + return FALSE; + } + + lua_getglobal(L, "rspamd_classifiers"); + if (lua_type(L, -1) != LUA_TTABLE) { + msg_err_config("cannot register classifier %s: no rspamd_classifier global", + cl->subrs->name); + lua_pop(L, 1); + + return FALSE; + } + + lua_pushstring(L, cl->subrs->name); + lua_gettable(L, -2); + + if (lua_type(L, -1) != LUA_TTABLE) { + msg_err_config("cannot register classifier %s: bad lua type: %s", + cl->subrs->name, lua_typename(L, lua_type(L, -1))); + lua_pop(L, 2); + + return FALSE; + } + + lua_pushstring(L, "classify"); + lua_gettable(L, -2); + + if (lua_type(L, -1) != LUA_TFUNCTION) { + msg_err_config("cannot register classifier %s: bad lua type for classify: %s", + cl->subrs->name, lua_typename(L, lua_type(L, -1))); + lua_pop(L, 3); + + return FALSE; + } + + cb_classify = luaL_ref(L, LUA_REGISTRYINDEX); + + lua_pushstring(L, "learn"); + lua_gettable(L, -2); + + if (lua_type(L, -1) != LUA_TFUNCTION) { + msg_err_config("cannot register classifier %s: bad lua type for learn: %s", + cl->subrs->name, lua_typename(L, lua_type(L, -1))); + lua_pop(L, 3); + + return FALSE; + } + + cb_learn = luaL_ref(L, LUA_REGISTRYINDEX); + lua_pop(L, 2); /* Table + global */ + + ctx = g_malloc0(sizeof(*ctx)); + ctx->name = g_strdup(cl->subrs->name); + ctx->classify_ref = cb_classify; + ctx->learn_ref = cb_learn; + cl->cfg->flags |= RSPAMD_FLAG_CLASSIFIER_NO_BACKEND; + g_hash_table_insert(lua_classifiers, ctx->name, ctx); + + return TRUE; +} +gboolean +lua_classifier_classify(struct rspamd_classifier *cl, + GPtrArray *tokens, + struct rspamd_task *task) +{ + struct rspamd_lua_classifier_ctx *ctx; + struct rspamd_task **ptask; + struct rspamd_classifier_config **pcfg; + lua_State *L; + rspamd_token_t *tok; + guint i; + guint64 v; + + ctx = g_hash_table_lookup(lua_classifiers, cl->subrs->name); + g_assert(ctx != NULL); + L = task->cfg->lua_state; + + lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->classify_ref); + ptask = lua_newuserdata(L, sizeof(*ptask)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + pcfg = lua_newuserdata(L, sizeof(*pcfg)); + *pcfg = cl->cfg; + rspamd_lua_setclass(L, "rspamd{classifier}", -1); + + lua_createtable(L, tokens->len, 0); + + for (i = 0; i < tokens->len; i++) { + tok = g_ptr_array_index(tokens, i); + v = tok->data; + lua_createtable(L, 3, 0); + /* High word, low word, order */ + lua_pushinteger(L, (guint32) (v >> 32)); + lua_rawseti(L, -2, 1); + lua_pushinteger(L, (guint32) (v)); + lua_rawseti(L, -2, 2); + lua_pushinteger(L, tok->window_idx); + lua_rawseti(L, -2, 3); + lua_rawseti(L, -2, i + 1); + } + + if (lua_pcall(L, 3, 0, 0) != 0) { + msg_err_luacl("error running classify function for %s: %s", ctx->name, + lua_tostring(L, -1)); + lua_pop(L, 1); + + return FALSE; + } + + return TRUE; +} + +gboolean +lua_classifier_learn_spam(struct rspamd_classifier *cl, + GPtrArray *tokens, + struct rspamd_task *task, + gboolean is_spam, + gboolean unlearn, + GError **err) +{ + struct rspamd_lua_classifier_ctx *ctx; + struct rspamd_task **ptask; + struct rspamd_classifier_config **pcfg; + lua_State *L; + rspamd_token_t *tok; + guint i; + guint64 v; + + ctx = g_hash_table_lookup(lua_classifiers, cl->subrs->name); + g_assert(ctx != NULL); + L = task->cfg->lua_state; + + lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref); + ptask = lua_newuserdata(L, sizeof(*ptask)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + pcfg = lua_newuserdata(L, sizeof(*pcfg)); + *pcfg = cl->cfg; + rspamd_lua_setclass(L, "rspamd{classifier}", -1); + + lua_createtable(L, tokens->len, 0); + + for (i = 0; i < tokens->len; i++) { + tok = g_ptr_array_index(tokens, i); + v = 0; + v = tok->data; + lua_createtable(L, 3, 0); + /* High word, low word, order */ + lua_pushinteger(L, (guint32) (v >> 32)); + lua_rawseti(L, -2, 1); + lua_pushinteger(L, (guint32) (v)); + lua_rawseti(L, -2, 2); + lua_pushinteger(L, tok->window_idx); + lua_rawseti(L, -2, 3); + lua_rawseti(L, -2, i + 1); + } + + lua_pushboolean(L, is_spam); + lua_pushboolean(L, unlearn); + + if (lua_pcall(L, 5, 0, 0) != 0) { + msg_err_luacl("error running learn function for %s: %s", ctx->name, + lua_tostring(L, -1)); + lua_pop(L, 1); + + return FALSE; + } + + return TRUE; +} diff --git a/src/libstat/learn_cache/learn_cache.h b/src/libstat/learn_cache/learn_cache.h new file mode 100644 index 0000000..11a66fc --- /dev/null +++ b/src/libstat/learn_cache/learn_cache.h @@ -0,0 +1,79 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef LEARN_CACHE_H_ +#define LEARN_CACHE_H_ + +#include "config.h" +#include "ucl.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define RSPAMD_DEFAULT_CACHE "sqlite3" + +struct rspamd_task; +struct rspamd_stat_ctx; +struct rspamd_config; +struct rspamd_statfile; + +struct rspamd_stat_cache { + const char *name; + + gpointer (*init)(struct rspamd_stat_ctx *ctx, + struct rspamd_config *cfg, + struct rspamd_statfile *st, + const ucl_object_t *cf); + + gpointer (*runtime)(struct rspamd_task *task, + gpointer ctx, gboolean learn); + + gint (*check)(struct rspamd_task *task, + gboolean is_spam, + gpointer runtime); + + gint (*learn)(struct rspamd_task *task, + gboolean is_spam, + gpointer runtime); + + void (*close)(gpointer ctx); + + gpointer ctx; +}; + +#define RSPAMD_STAT_CACHE_DEF(name) \ + gpointer rspamd_stat_cache_##name##_init(struct rspamd_stat_ctx *ctx, \ + struct rspamd_config *cfg, \ + struct rspamd_statfile *st, \ + const ucl_object_t *cf); \ + gpointer rspamd_stat_cache_##name##_runtime(struct rspamd_task *task, \ + gpointer ctx, gboolean learn); \ + gint rspamd_stat_cache_##name##_check(struct rspamd_task *task, \ + gboolean is_spam, \ + gpointer runtime); \ + gint rspamd_stat_cache_##name##_learn(struct rspamd_task *task, \ + gboolean is_spam, \ + gpointer runtime); \ + void rspamd_stat_cache_##name##_close(gpointer ctx) + +RSPAMD_STAT_CACHE_DEF(sqlite3); +RSPAMD_STAT_CACHE_DEF(redis); + +#ifdef __cplusplus +} +#endif + +#endif /* LEARN_CACHE_H_ */ diff --git a/src/libstat/learn_cache/redis_cache.cxx b/src/libstat/learn_cache/redis_cache.cxx new file mode 100644 index 0000000..0be56bc --- /dev/null +++ b/src/libstat/learn_cache/redis_cache.cxx @@ -0,0 +1,254 @@ +/* + * Copyright 2024 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "config.h" +// Include early to avoid `extern "C"` issues +#include "lua/lua_common.h" +#include "learn_cache.h" +#include "rspamd.h" +#include "stat_api.h" +#include "stat_internal.h" +#include "cryptobox.h" +#include "ucl.h" +#include "libmime/message.h" + +#include <memory> + +struct rspamd_redis_cache_ctx { + lua_State *L; + struct rspamd_statfile_config *stcf; + int check_ref = -1; + int learn_ref = -1; + + rspamd_redis_cache_ctx() = delete; + explicit rspamd_redis_cache_ctx(lua_State *L) + : L(L) + { + } + + ~rspamd_redis_cache_ctx() + { + if (check_ref != -1) { + luaL_unref(L, LUA_REGISTRYINDEX, check_ref); + } + + if (learn_ref != -1) { + luaL_unref(L, LUA_REGISTRYINDEX, learn_ref); + } + } +}; + +static void +rspamd_stat_cache_redis_generate_id(struct rspamd_task *task) +{ + rspamd_cryptobox_hash_state_t st; + rspamd_cryptobox_hash_init(&st, nullptr, 0); + + const auto *user = (const char *) rspamd_mempool_get_variable(task->task_pool, "stat_user"); + /* Use dedicated hash space for per users cache */ + if (user != nullptr) { + rspamd_cryptobox_hash_update(&st, (const unsigned char *) user, strlen(user)); + } + + for (auto i = 0; i < task->tokens->len; i++) { + const auto *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, i); + rspamd_cryptobox_hash_update(&st, (const unsigned char *) &tok->data, + sizeof(tok->data)); + } + + guchar out[rspamd_cryptobox_HASHBYTES]; + rspamd_cryptobox_hash_final(&st, out); + + auto *b32out = rspamd_mempool_alloc_array_type(task->task_pool, + sizeof(out) * 8 / 5 + 3, char); + auto out_sz = rspamd_encode_base32_buf(out, sizeof(out), b32out, + sizeof(out) * 8 / 5 + 2, RSPAMD_BASE32_DEFAULT); + + if (out_sz > 0) { + /* Zero terminate */ + b32out[out_sz] = '\0'; + rspamd_mempool_set_variable(task->task_pool, "words_hash", b32out, nullptr); + } +} + +gpointer +rspamd_stat_cache_redis_init(struct rspamd_stat_ctx *ctx, + struct rspamd_config *cfg, + struct rspamd_statfile *st, + const ucl_object_t *cf) +{ + std::unique_ptr<rspamd_redis_cache_ctx> cache_ctx = std::make_unique<rspamd_redis_cache_ctx>(RSPAMD_LUA_CFG_STATE(cfg)); + + auto *L = RSPAMD_LUA_CFG_STATE(cfg); + lua_settop(L, 0); + + lua_pushcfunction(L, &rspamd_lua_traceback); + auto err_idx = lua_gettop(L); + + /* Obtain function */ + if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_cache")) { + msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_cache"); + lua_settop(L, err_idx - 1); + + return nullptr; + } + + /* Push arguments */ + ucl_object_push_lua(L, st->classifier->cfg->opts, false); + ucl_object_push_lua(L, st->stcf->opts, false); + + if (lua_pcall(L, 2, 2, err_idx) != 0) { + msg_err("call to lua_bayes_init_cache " + "script failed: %s", + lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + + return nullptr; + } + + /* + * Results are in the stack: + * top - 1 - check function (idx = -2) + * top - learn function (idx = -1) + */ + lua_pushvalue(L, -2); + cache_ctx->check_ref = luaL_ref(L, LUA_REGISTRYINDEX); + + lua_pushvalue(L, -1); + cache_ctx->learn_ref = luaL_ref(L, LUA_REGISTRYINDEX); + + lua_settop(L, err_idx - 1); + + return (gpointer) cache_ctx.release(); +} + +gpointer +rspamd_stat_cache_redis_runtime(struct rspamd_task *task, + gpointer c, gboolean learn) +{ + auto *ctx = (struct rspamd_redis_cache_ctx *) c; + + if (task->tokens == nullptr || task->tokens->len == 0) { + return nullptr; + } + + if (!learn) { + /* On check, we produce words_hash variable, on learn it is guaranteed to be set */ + rspamd_stat_cache_redis_generate_id(task); + } + + return (void *) ctx; +} + +static gint +rspamd_stat_cache_checked(lua_State *L) +{ + auto *task = lua_check_task(L, 1); + auto res = lua_toboolean(L, 2); + + if (res) { + auto val = lua_tointeger(L, 3); + + if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) || + (val <= 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) { + /* Already learned */ + msg_info_task("<%s> has been already " + "learned as %s, ignore it", + MESSAGE_FIELD(task, message_id), + (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham"); + task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED; + } + else if (val != 0) { + /* Unlearn flag */ + task->flags |= RSPAMD_TASK_FLAG_UNLEARN; + } + } + + /* Ignore errors for now, as we can do nothing about them at the moment */ + + return 0; +} + +gint rspamd_stat_cache_redis_check(struct rspamd_task *task, + gboolean is_spam, + gpointer runtime) +{ + auto *ctx = (struct rspamd_redis_cache_ctx *) runtime; + auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash"); + + if (h == nullptr) { + return RSPAMD_LEARN_IGNORE; + } + + auto *L = ctx->L; + + lua_pushcfunction(L, &rspamd_lua_traceback); + gint err_idx = lua_gettop(L); + + /* Function arguments */ + lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->check_ref); + rspamd_lua_task_push(L, task); + lua_pushstring(L, h); + + lua_pushcclosure(L, &rspamd_stat_cache_checked, 0); + + if (lua_pcall(L, 3, 0, err_idx) != 0) { + msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + return RSPAMD_LEARN_IGNORE; + } + + /* We need to return OK every time */ + return RSPAMD_LEARN_OK; +} + +gint rspamd_stat_cache_redis_learn(struct rspamd_task *task, + gboolean is_spam, + gpointer runtime) +{ + auto *ctx = (struct rspamd_redis_cache_ctx *) runtime; + + if (rspamd_session_blocked(task->s)) { + return RSPAMD_LEARN_IGNORE; + } + + auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash"); + g_assert(h != nullptr); + auto *L = ctx->L; + + lua_pushcfunction(L, &rspamd_lua_traceback); + gint err_idx = lua_gettop(L); + + /* Function arguments */ + lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref); + rspamd_lua_task_push(L, task); + lua_pushstring(L, h); + lua_pushboolean(L, is_spam); + + if (lua_pcall(L, 3, 0, err_idx) != 0) { + msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + return RSPAMD_LEARN_IGNORE; + } + + /* We need to return OK every time */ + return RSPAMD_LEARN_OK; +} + +void rspamd_stat_cache_redis_close(gpointer c) +{ + auto *ctx = (struct rspamd_redis_cache_ctx *) c; + delete ctx; +} diff --git a/src/libstat/learn_cache/sqlite3_cache.c b/src/libstat/learn_cache/sqlite3_cache.c new file mode 100644 index 0000000..d8ad20a --- /dev/null +++ b/src/libstat/learn_cache/sqlite3_cache.c @@ -0,0 +1,274 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "config.h" +#include "learn_cache.h" +#include "rspamd.h" +#include "stat_api.h" +#include "stat_internal.h" +#include "cryptobox.h" +#include "ucl.h" +#include "fstring.h" +#include "message.h" +#include "libutil/sqlite_utils.h" + +static const char *create_tables_sql = + "" + "CREATE TABLE IF NOT EXISTS learns(" + "id INTEGER PRIMARY KEY," + "flag INTEGER NOT NULL," + "digest TEXT NOT NULL);" + "CREATE UNIQUE INDEX IF NOT EXISTS d ON learns(digest);" + ""; + +#define SQLITE_CACHE_PATH RSPAMD_DBDIR "/learn_cache.sqlite" + +enum rspamd_stat_sqlite3_stmt_idx { + RSPAMD_STAT_CACHE_TRANSACTION_START_IM = 0, + RSPAMD_STAT_CACHE_TRANSACTION_START_DEF, + RSPAMD_STAT_CACHE_TRANSACTION_COMMIT, + RSPAMD_STAT_CACHE_TRANSACTION_ROLLBACK, + RSPAMD_STAT_CACHE_GET_LEARN, + RSPAMD_STAT_CACHE_ADD_LEARN, + RSPAMD_STAT_CACHE_UPDATE_LEARN, + RSPAMD_STAT_CACHE_MAX +}; + +static struct rspamd_sqlite3_prstmt prepared_stmts[RSPAMD_STAT_CACHE_MAX] = + { + {.idx = RSPAMD_STAT_CACHE_TRANSACTION_START_IM, + .sql = "BEGIN IMMEDIATE TRANSACTION;", + .args = "", + .stmt = NULL, + .result = SQLITE_DONE, + .ret = ""}, + {.idx = RSPAMD_STAT_CACHE_TRANSACTION_START_DEF, + .sql = "BEGIN DEFERRED TRANSACTION;", + .args = "", + .stmt = NULL, + .result = SQLITE_DONE, + .ret = ""}, + {.idx = RSPAMD_STAT_CACHE_TRANSACTION_COMMIT, + .sql = "COMMIT;", + .args = "", + .stmt = NULL, + .result = SQLITE_DONE, + .ret = ""}, + {.idx = RSPAMD_STAT_CACHE_TRANSACTION_ROLLBACK, + .sql = "ROLLBACK;", + .args = "", + .stmt = NULL, + .result = SQLITE_DONE, + .ret = ""}, + {.idx = RSPAMD_STAT_CACHE_GET_LEARN, + .sql = "SELECT flag FROM learns WHERE digest=?1", + .args = "V", + .stmt = NULL, + .result = SQLITE_ROW, + .ret = "I"}, + {.idx = RSPAMD_STAT_CACHE_ADD_LEARN, + .sql = "INSERT INTO learns(digest, flag) VALUES (?1, ?2);", + .args = "VI", + .stmt = NULL, + .result = SQLITE_DONE, + .ret = ""}, + {.idx = RSPAMD_STAT_CACHE_UPDATE_LEARN, + .sql = "UPDATE learns SET flag=?1 WHERE digest=?2;", + .args = "IV", + .stmt = NULL, + .result = SQLITE_DONE, + .ret = ""}}; + +struct rspamd_stat_sqlite3_ctx { + sqlite3 *db; + GArray *prstmt; +}; + +gpointer +rspamd_stat_cache_sqlite3_init(struct rspamd_stat_ctx *ctx, + struct rspamd_config *cfg, + struct rspamd_statfile *st, + const ucl_object_t *cf) +{ + struct rspamd_stat_sqlite3_ctx *new = NULL; + const ucl_object_t *elt; + gchar dbpath[PATH_MAX]; + const gchar *path = SQLITE_CACHE_PATH; + sqlite3 *sqlite; + GError *err = NULL; + + if (cf) { + elt = ucl_object_lookup_any(cf, "path", "file", NULL); + + if (elt != NULL) { + path = ucl_object_tostring(elt); + } + } + + rspamd_snprintf(dbpath, sizeof(dbpath), "%s", path); + + sqlite = rspamd_sqlite3_open_or_create(cfg->cfg_pool, + dbpath, create_tables_sql, 0, &err); + + if (sqlite == NULL) { + msg_err("cannot open sqlite3 cache: %e", err); + g_error_free(err); + err = NULL; + } + else { + new = g_malloc0(sizeof(*new)); + new->db = sqlite; + new->prstmt = rspamd_sqlite3_init_prstmt(sqlite, prepared_stmts, + RSPAMD_STAT_CACHE_MAX, &err); + + if (new->prstmt == NULL) { + msg_err("cannot open sqlite3 cache: %e", err); + g_error_free(err); + err = NULL; + sqlite3_close(sqlite); + g_free(new); + new = NULL; + } + } + + return new; +} + +gpointer +rspamd_stat_cache_sqlite3_runtime(struct rspamd_task *task, + gpointer ctx, gboolean learn) +{ + /* No need of runtime for this type of classifier */ + return ctx; +} + +gint rspamd_stat_cache_sqlite3_check(struct rspamd_task *task, + gboolean is_spam, + gpointer runtime) +{ + struct rspamd_stat_sqlite3_ctx *ctx = runtime; + rspamd_cryptobox_hash_state_t st; + rspamd_token_t *tok; + guchar *out; + gchar *user = NULL; + guint i; + gint rc; + gint64 flag; + + if (task->tokens == NULL || task->tokens->len == 0) { + return RSPAMD_LEARN_IGNORE; + } + + if (ctx != NULL && ctx->db != NULL) { + out = rspamd_mempool_alloc(task->task_pool, rspamd_cryptobox_HASHBYTES); + + rspamd_cryptobox_hash_init(&st, NULL, 0); + + user = rspamd_mempool_get_variable(task->task_pool, "stat_user"); + /* Use dedicated hash space for per users cache */ + if (user != NULL) { + rspamd_cryptobox_hash_update(&st, user, strlen(user)); + } + + for (i = 0; i < task->tokens->len; i++) { + tok = g_ptr_array_index(task->tokens, i); + rspamd_cryptobox_hash_update(&st, (guchar *) &tok->data, + sizeof(tok->data)); + } + + rspamd_cryptobox_hash_final(&st, out); + + rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_TRANSACTION_START_DEF); + rc = rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_GET_LEARN, (gint64) rspamd_cryptobox_HASHBYTES, + out, &flag); + rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_TRANSACTION_COMMIT); + + /* Save hash into variables */ + rspamd_mempool_set_variable(task->task_pool, "words_hash", out, NULL); + + if (rc == SQLITE_OK) { + /* We have some existing record in the table */ + if (!!flag == !!is_spam) { + /* Already learned */ + msg_warn_task("already seen stat hash: %*bs", + rspamd_cryptobox_HASHBYTES, out); + return RSPAMD_LEARN_IGNORE; + } + else { + /* Need to relearn */ + return RSPAMD_LEARN_UNLEARN; + } + } + } + + return RSPAMD_LEARN_OK; +} + +gint rspamd_stat_cache_sqlite3_learn(struct rspamd_task *task, + gboolean is_spam, + gpointer runtime) +{ + struct rspamd_stat_sqlite3_ctx *ctx = runtime; + gboolean unlearn = !!(task->flags & RSPAMD_TASK_FLAG_UNLEARN); + guchar *h; + gint64 flag; + + h = rspamd_mempool_get_variable(task->task_pool, "words_hash"); + + if (h == NULL) { + return RSPAMD_LEARN_IGNORE; + } + + flag = !!is_spam ? 1 : 0; + + if (!unlearn) { + /* Insert result new id */ + rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_TRANSACTION_START_IM); + rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_ADD_LEARN, + (gint64) rspamd_cryptobox_HASHBYTES, h, flag); + rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_TRANSACTION_COMMIT); + } + else { + rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_TRANSACTION_START_IM); + rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_UPDATE_LEARN, + flag, + (gint64) rspamd_cryptobox_HASHBYTES, h); + rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt, + RSPAMD_STAT_CACHE_TRANSACTION_COMMIT); + } + + rspamd_sqlite3_sync(ctx->db, NULL, NULL); + + return RSPAMD_LEARN_OK; +} + +void rspamd_stat_cache_sqlite3_close(gpointer c) +{ + struct rspamd_stat_sqlite3_ctx *ctx = (struct rspamd_stat_sqlite3_ctx *) c; + + if (ctx != NULL) { + rspamd_sqlite3_close_prstmt(ctx->db, ctx->prstmt); + sqlite3_close(ctx->db); + g_free(ctx); + } +} diff --git a/src/libstat/stat_api.h b/src/libstat/stat_api.h new file mode 100644 index 0000000..1badb20 --- /dev/null +++ b/src/libstat/stat_api.h @@ -0,0 +1,147 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef STAT_API_H_ +#define STAT_API_H_ + +#include "config.h" +#include "task.h" +#include "lua/lua_common.h" +#include "contrib/libev/ev.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @file stat_api.h + * High level statistics API + */ + +#define RSPAMD_STAT_TOKEN_FLAG_TEXT (1u << 0) +#define RSPAMD_STAT_TOKEN_FLAG_META (1u << 1) +#define RSPAMD_STAT_TOKEN_FLAG_LUA_META (1u << 2) +#define RSPAMD_STAT_TOKEN_FLAG_EXCEPTION (1u << 3) +#define RSPAMD_STAT_TOKEN_FLAG_HEADER (1u << 4) +#define RSPAMD_STAT_TOKEN_FLAG_UNIGRAM (1u << 5) +#define RSPAMD_STAT_TOKEN_FLAG_UTF (1u << 6) +#define RSPAMD_STAT_TOKEN_FLAG_NORMALISED (1u << 7) +#define RSPAMD_STAT_TOKEN_FLAG_STEMMED (1u << 8) +#define RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE (1u << 9) +#define RSPAMD_STAT_TOKEN_FLAG_STOP_WORD (1u << 10) +#define RSPAMD_STAT_TOKEN_FLAG_SKIPPED (1u << 11) +#define RSPAMD_STAT_TOKEN_FLAG_INVISIBLE_SPACES (1u << 12) +#define RSPAMD_STAT_TOKEN_FLAG_EMOJI (1u << 13) + +typedef struct rspamd_stat_token_s { + rspamd_ftok_t original; /* utf8 raw */ + rspamd_ftok_unicode_t unicode; /* array of unicode characters, normalized, lowercased */ + rspamd_ftok_t normalized; /* normalized and lowercased utf8 */ + rspamd_ftok_t stemmed; /* stemmed utf8 */ + guint flags; +} rspamd_stat_token_t; + +typedef struct token_node_s { + guint64 data; + guint window_idx; + guint flags; + rspamd_stat_token_t *t1; + rspamd_stat_token_t *t2; + float values[]; +} rspamd_token_t; + +struct rspamd_stat_ctx; + +/** + * The results of statistics processing: + * - error + * - need to do additional job for processing + * - all processed + */ +typedef enum rspamd_stat_result_e { + RSPAMD_STAT_PROCESS_ERROR = 0, + RSPAMD_STAT_PROCESS_DELAYED = 1, + RSPAMD_STAT_PROCESS_OK +} rspamd_stat_result_t; + +/** + * Initialise statistics modules + * @param cfg + */ +void rspamd_stat_init(struct rspamd_config *cfg, struct ev_loop *ev_base); + +/** + * Finalize statistics + */ +void rspamd_stat_close(void); + +/** + * Tokenize task + * @param st_ctx + * @param task + */ +void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task); + +/** + * Classify the task specified and insert symbols if needed + * @param task + * @param L lua state + * @param err error returned + * @return TRUE if task has been classified + */ +rspamd_stat_result_t rspamd_stat_classify(struct rspamd_task *task, + lua_State *L, guint stage, GError **err); + + +/** + * Check if a task should be learned and set the appropriate flags for it + * @param task + * @return + */ +gboolean rspamd_stat_check_autolearn(struct rspamd_task *task); + +/** + * Learn task as spam or ham, task must be processed prior to this call + * @param task task to learn + * @param spam if TRUE learn spam, otherwise learn ham + * @param L lua state + * @param classifier NULL to learn all classifiers, name to learn a specific one + * @param err error returned + * @return TRUE if task has been learned + */ +rspamd_stat_result_t rspamd_stat_learn(struct rspamd_task *task, + gboolean spam, lua_State *L, const gchar *classifier, + guint stage, + GError **err); + +/** + * Get the overall statistics for all statfile backends + * @param cfg configuration + * @param total_learns the total number of learns is stored here + * @return array of statistical information + */ +rspamd_stat_result_t rspamd_stat_statistics(struct rspamd_task *task, + struct rspamd_config *cfg, + guint64 *total_learns, + ucl_object_t **res); + +void rspamd_stat_unload(void); + +#ifdef __cplusplus +} +#endif + +#endif /* STAT_API_H_ */ diff --git a/src/libstat/stat_config.c b/src/libstat/stat_config.c new file mode 100644 index 0000000..2748044 --- /dev/null +++ b/src/libstat/stat_config.c @@ -0,0 +1,603 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "config.h" +#include "stat_api.h" +#include "rspamd.h" +#include "cfg_rcl.h" +#include "stat_internal.h" +#include "lua/lua_common.h" + +static struct rspamd_stat_ctx *stat_ctx = NULL; + +static struct rspamd_stat_classifier lua_classifier = { + .name = "lua", + .init_func = lua_classifier_init, + .classify_func = lua_classifier_classify, + .learn_spam_func = lua_classifier_learn_spam, + .fin_func = NULL, +}; + +static struct rspamd_stat_classifier stat_classifiers[] = { + { + .name = "bayes", + .init_func = bayes_init, + .classify_func = bayes_classify, + .learn_spam_func = bayes_learn_spam, + .fin_func = bayes_fin, + }}; + +static struct rspamd_stat_tokenizer stat_tokenizers[] = { + { + .name = "osb-text", + .get_config = rspamd_tokenizer_osb_get_config, + .tokenize_func = rspamd_tokenizer_osb, + }, + { + .name = "osb", + .get_config = rspamd_tokenizer_osb_get_config, + .tokenize_func = rspamd_tokenizer_osb, + }, +}; + +#define RSPAMD_STAT_BACKEND_ELT(nam, eltn) \ + { \ + .name = #nam, \ + .read_only = false, \ + .init = rspamd_##eltn##_init, \ + .runtime = rspamd_##eltn##_runtime, \ + .process_tokens = rspamd_##eltn##_process_tokens, \ + .finalize_process = rspamd_##eltn##_finalize_process, \ + .learn_tokens = rspamd_##eltn##_learn_tokens, \ + .finalize_learn = rspamd_##eltn##_finalize_learn, \ + .total_learns = rspamd_##eltn##_total_learns, \ + .inc_learns = rspamd_##eltn##_inc_learns, \ + .dec_learns = rspamd_##eltn##_dec_learns, \ + .get_stat = rspamd_##eltn##_get_stat, \ + .load_tokenizer_config = rspamd_##eltn##_load_tokenizer_config, \ + .close = rspamd_##eltn##_close \ + } +#define RSPAMD_STAT_BACKEND_ELT_READONLY(nam, eltn) \ + { \ + .name = #nam, \ + .read_only = true, \ + .init = rspamd_##eltn##_init, \ + .runtime = rspamd_##eltn##_runtime, \ + .process_tokens = rspamd_##eltn##_process_tokens, \ + .finalize_process = rspamd_##eltn##_finalize_process, \ + .learn_tokens = NULL, \ + .finalize_learn = NULL, \ + .total_learns = rspamd_##eltn##_total_learns, \ + .inc_learns = NULL, \ + .dec_learns = NULL, \ + .get_stat = rspamd_##eltn##_get_stat, \ + .load_tokenizer_config = rspamd_##eltn##_load_tokenizer_config, \ + .close = rspamd_##eltn##_close \ + } + +static struct rspamd_stat_backend stat_backends[] = { + RSPAMD_STAT_BACKEND_ELT(mmap, mmaped_file), + RSPAMD_STAT_BACKEND_ELT(sqlite3, sqlite3), + RSPAMD_STAT_BACKEND_ELT_READONLY(cdb, cdb), + RSPAMD_STAT_BACKEND_ELT(redis, redis)}; + +#define RSPAMD_STAT_CACHE_ELT(nam, eltn) \ + { \ + .name = #nam, \ + .init = rspamd_stat_cache_##eltn##_init, \ + .runtime = rspamd_stat_cache_##eltn##_runtime, \ + .check = rspamd_stat_cache_##eltn##_check, \ + .learn = rspamd_stat_cache_##eltn##_learn, \ + .close = rspamd_stat_cache_##eltn##_close \ + } + +static struct rspamd_stat_cache stat_caches[] = { + RSPAMD_STAT_CACHE_ELT(sqlite3, sqlite3), + RSPAMD_STAT_CACHE_ELT(redis, redis), +}; + +void rspamd_stat_init(struct rspamd_config *cfg, struct ev_loop *ev_base) +{ + GList *cur, *curst; + struct rspamd_classifier_config *clf; + struct rspamd_statfile_config *stf; + struct rspamd_stat_backend *bk; + struct rspamd_statfile *st; + struct rspamd_classifier *cl; + const ucl_object_t *cache_obj = NULL, *cache_name_obj; + const gchar *cache_name = NULL; + lua_State *L = cfg->lua_state; + guint lua_classifiers_cnt = 0, i; + gboolean skip_cache = FALSE; + + if (stat_ctx == NULL) { + stat_ctx = g_malloc0(sizeof(*stat_ctx)); + } + + lua_getglobal(L, "rspamd_classifiers"); + + if (lua_type(L, -1) == LUA_TTABLE) { + lua_pushnil(L); + + while (lua_next(L, -2) != 0) { + lua_classifiers_cnt++; + lua_pop(L, 1); + } + } + + lua_pop(L, 1); + + stat_ctx->classifiers_count = G_N_ELEMENTS(stat_classifiers) + + lua_classifiers_cnt; + stat_ctx->classifiers_subrs = g_new0(struct rspamd_stat_classifier, + stat_ctx->classifiers_count); + + for (i = 0; i < G_N_ELEMENTS(stat_classifiers); i++) { + memcpy(&stat_ctx->classifiers_subrs[i], &stat_classifiers[i], + sizeof(struct rspamd_stat_classifier)); + } + + lua_getglobal(L, "rspamd_classifiers"); + + if (lua_type(L, -1) == LUA_TTABLE) { + lua_pushnil(L); + + while (lua_next(L, -2) != 0) { + lua_pushvalue(L, -2); + memcpy(&stat_ctx->classifiers_subrs[i], &lua_classifier, + sizeof(struct rspamd_stat_classifier)); + stat_ctx->classifiers_subrs[i].name = g_strdup(lua_tostring(L, -1)); + i++; + lua_pop(L, 2); + } + } + + lua_pop(L, 1); + stat_ctx->backends_subrs = stat_backends; + stat_ctx->backends_count = G_N_ELEMENTS(stat_backends); + + stat_ctx->tokenizers_subrs = stat_tokenizers; + stat_ctx->tokenizers_count = G_N_ELEMENTS(stat_tokenizers); + stat_ctx->caches_subrs = stat_caches; + stat_ctx->caches_count = G_N_ELEMENTS(stat_caches); + stat_ctx->cfg = cfg; + stat_ctx->statfiles = g_ptr_array_new(); + stat_ctx->classifiers = g_ptr_array_new(); + stat_ctx->async_elts = g_queue_new(); + stat_ctx->event_loop = ev_base; + stat_ctx->lua_stat_tokens_ref = -1; + + /* Interact with lua_stat */ + if (luaL_dostring(L, "return require \"lua_stat\"") != 0) { + msg_err_config("cannot require lua_stat: %s", + lua_tostring(L, -1)); + } + else { +#if LUA_VERSION_NUM >= 504 + lua_settop(L, -2); +#endif + if (lua_type(L, -1) != LUA_TTABLE) { + msg_err_config("lua stat must return " + "table and not %s", + lua_typename(L, lua_type(L, -1))); + } + else { + lua_pushstring(L, "gen_stat_tokens"); + lua_gettable(L, -2); + + if (lua_type(L, -1) != LUA_TFUNCTION) { + msg_err_config("gen_stat_tokens must return " + "function and not %s", + lua_typename(L, lua_type(L, -1))); + } + else { + /* Call this function to obtain closure */ + gint err_idx, ret; + struct rspamd_config **pcfg; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_pushvalue(L, err_idx - 1); + + pcfg = lua_newuserdata(L, sizeof(*pcfg)); + *pcfg = cfg; + rspamd_lua_setclass(L, "rspamd{config}", -1); + + if ((ret = lua_pcall(L, 1, 1, err_idx)) != 0) { + msg_err_config("call to gen_stat_tokens lua " + "script failed (%d): %s", + ret, + lua_tostring(L, -1)); + } + else { + if (lua_type(L, -1) != LUA_TFUNCTION) { + msg_err_config("gen_stat_tokens invocation must return " + "function and not %s", + lua_typename(L, lua_type(L, -1))); + } + else { + stat_ctx->lua_stat_tokens_ref = luaL_ref(L, LUA_REGISTRYINDEX); + } + } + } + } + } + + /* Cleanup mess */ + lua_settop(L, 0); + + /* Create statfiles from the classifiers */ + cur = cfg->classifiers; + + while (cur) { + bk = NULL; + clf = cur->data; + cl = g_malloc0(sizeof(*cl)); + cl->cfg = clf; + cl->ctx = stat_ctx; + cl->statfiles_ids = g_array_new(FALSE, FALSE, sizeof(gint)); + cl->subrs = rspamd_stat_get_classifier(clf->classifier); + + if (cl->subrs == NULL) { + g_free(cl); + msg_err_config("cannot init classifier type %s", clf->name); + cur = g_list_next(cur); + continue; + } + + if (!cl->subrs->init_func(cfg, ev_base, cl)) { + g_free(cl); + msg_err_config("cannot init classifier type %s", clf->name); + cur = g_list_next(cur); + continue; + } + + if (!(clf->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND)) { + bk = rspamd_stat_get_backend(clf->backend); + + if (bk == NULL) { + msg_err_config("cannot get backend of type %s, so disable classifier" + " %s completely", + clf->backend, clf->name); + cur = g_list_next(cur); + continue; + } + } + else { + /* This actually is not implemented so it should never happen */ + g_free(cl); + cur = g_list_next(cur); + continue; + } + + /* XXX: + * Here we get the first classifier tokenizer config as the only one + * We NO LONGER support multiple tokenizers per rspamd instance + */ + if (stat_ctx->tkcf == NULL) { + stat_ctx->tokenizer = rspamd_stat_get_tokenizer(clf->tokenizer->name); + g_assert(stat_ctx->tokenizer != NULL); + stat_ctx->tkcf = stat_ctx->tokenizer->get_config(cfg->cfg_pool, + clf->tokenizer, NULL); + } + + /* Init classifier cache */ + cache_name = NULL; + + if (!bk->read_only) { + if (clf->opts) { + cache_obj = ucl_object_lookup(clf->opts, "cache"); + cache_name_obj = NULL; + + if (cache_obj && ucl_object_type(cache_obj) == UCL_NULL) { + skip_cache = TRUE; + } + else { + if (cache_obj) { + cache_name_obj = ucl_object_lookup_any(cache_obj, + "name", "type", NULL); + } + + if (cache_name_obj) { + cache_name = ucl_object_tostring(cache_name_obj); + } + } + } + } + else { + skip_cache = true; + } + + if (cache_name == NULL && !skip_cache) { + /* We assume that learn cache is the same as backend */ + cache_name = clf->backend; + } + + curst = clf->statfiles; + + while (curst) { + stf = curst->data; + st = g_malloc0(sizeof(*st)); + st->classifier = cl; + st->stcf = stf; + + if (!(cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND)) { + st->backend = bk; + st->bkcf = bk->init(stat_ctx, cfg, st); + msg_info_config("added backend %s for symbol %s", + bk->name, stf->symbol); + } + else { + msg_debug_config("added backend-less statfile for symbol %s", + stf->symbol); + } + + /* XXX: bad hack to pass statfiles configuration to cache */ + if (cl->cache == NULL && !skip_cache) { + cl->cache = rspamd_stat_get_cache(cache_name); + g_assert(cl->cache != NULL); + cl->cachecf = cl->cache->init(stat_ctx, cfg, st, cache_obj); + + if (cl->cachecf == NULL) { + msg_err_config("error adding cache %s for symbol %s", + cl->cache->name, stf->symbol); + cl->cache = NULL; + } + else { + msg_debug_config("added cache %s for symbol %s", + cl->cache->name, stf->symbol); + } + } + + if (st->bkcf == NULL && + !(cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND)) { + msg_err_config("cannot init backend %s for statfile %s", + clf->backend, stf->symbol); + + g_free(st); + } + else { + st->id = stat_ctx->statfiles->len; + g_ptr_array_add(stat_ctx->statfiles, st); + g_array_append_val(cl->statfiles_ids, st->id); + } + + curst = curst->next; + } + + g_ptr_array_add(stat_ctx->classifiers, cl); + + cur = cur->next; + } +} + +void rspamd_stat_close(void) +{ + struct rspamd_classifier *cl; + struct rspamd_statfile *st; + struct rspamd_stat_ctx *st_ctx; + struct rspamd_stat_async_elt *aelt; + GList *cur; + guint i, j; + gint id; + + st_ctx = rspamd_stat_get_ctx(); + g_assert(st_ctx != NULL); + + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + + for (j = 0; j < cl->statfiles_ids->len; j++) { + id = g_array_index(cl->statfiles_ids, gint, j); + st = g_ptr_array_index(st_ctx->statfiles, id); + if (!(st->classifier->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND)) { + st->backend->close(st->bkcf); + } + + g_free(st); + } + + if (cl->cache && cl->cachecf) { + cl->cache->close(cl->cachecf); + } + + g_array_free(cl->statfiles_ids, TRUE); + + if (cl->subrs->fin_func) { + cl->subrs->fin_func(cl); + } + + g_free(cl); + } + + cur = st_ctx->async_elts->head; + + while (cur) { + aelt = cur->data; + REF_RELEASE(aelt); + cur = g_list_next(cur); + } + + g_queue_free(stat_ctx->async_elts); + g_ptr_array_free(st_ctx->statfiles, TRUE); + g_ptr_array_free(st_ctx->classifiers, TRUE); + + if (st_ctx->lua_stat_tokens_ref != -1) { + luaL_unref(st_ctx->cfg->lua_state, LUA_REGISTRYINDEX, + st_ctx->lua_stat_tokens_ref); + } + + g_free(st_ctx->classifiers_subrs); + g_free(st_ctx); + + /* Set global var to NULL */ + stat_ctx = NULL; +} + +struct rspamd_stat_ctx * +rspamd_stat_get_ctx(void) +{ + return stat_ctx; +} + +struct rspamd_stat_classifier * +rspamd_stat_get_classifier(const gchar *name) +{ + guint i; + + if (name == NULL || name[0] == '\0') { + name = RSPAMD_DEFAULT_CLASSIFIER; + } + + for (i = 0; i < stat_ctx->classifiers_count; i++) { + if (strcmp(name, stat_ctx->classifiers_subrs[i].name) == 0) { + return &stat_ctx->classifiers_subrs[i]; + } + } + + msg_err("cannot find classifier named %s", name); + + return NULL; +} + +struct rspamd_stat_backend * +rspamd_stat_get_backend(const gchar *name) +{ + guint i; + + if (name == NULL || name[0] == '\0') { + name = RSPAMD_DEFAULT_BACKEND; + } + + for (i = 0; i < stat_ctx->backends_count; i++) { + if (strcmp(name, stat_ctx->backends_subrs[i].name) == 0) { + return &stat_ctx->backends_subrs[i]; + } + } + + msg_err("cannot find backend named %s", name); + + return NULL; +} + +struct rspamd_stat_tokenizer * +rspamd_stat_get_tokenizer(const gchar *name) +{ + guint i; + + if (name == NULL || name[0] == '\0') { + name = RSPAMD_DEFAULT_TOKENIZER; + } + + for (i = 0; i < stat_ctx->tokenizers_count; i++) { + if (strcmp(name, stat_ctx->tokenizers_subrs[i].name) == 0) { + return &stat_ctx->tokenizers_subrs[i]; + } + } + + msg_err("cannot find tokenizer named %s", name); + + return NULL; +} + +struct rspamd_stat_cache * +rspamd_stat_get_cache(const gchar *name) +{ + guint i; + + if (name == NULL || name[0] == '\0') { + name = RSPAMD_DEFAULT_CACHE; + } + + for (i = 0; i < stat_ctx->caches_count; i++) { + if (strcmp(name, stat_ctx->caches_subrs[i].name) == 0) { + return &stat_ctx->caches_subrs[i]; + } + } + + msg_err("cannot find cache named %s", name); + + return NULL; +} + +static void +rspamd_async_elt_dtor(struct rspamd_stat_async_elt *elt) +{ + if (elt->cleanup) { + elt->cleanup(elt, elt->ud); + } + + ev_timer_stop(elt->event_loop, &elt->timer_ev); + g_free(elt); +} + +static void +rspamd_async_elt_on_timer(EV_P_ ev_timer *w, int revents) +{ + struct rspamd_stat_async_elt *elt = (struct rspamd_stat_async_elt *) w->data; + gdouble jittered_time; + + + if (elt->enabled) { + elt->handler(elt, elt->ud); + } + + jittered_time = rspamd_time_jitter(elt->timeout, 0); + elt->timer_ev.repeat = jittered_time; + ev_timer_again(EV_A_ w); +} + +struct rspamd_stat_async_elt * +rspamd_stat_ctx_register_async(rspamd_stat_async_handler handler, + rspamd_stat_async_cleanup cleanup, + gpointer d, + gdouble timeout) +{ + struct rspamd_stat_async_elt *elt; + struct rspamd_stat_ctx *st_ctx; + + st_ctx = rspamd_stat_get_ctx(); + g_assert(st_ctx != NULL); + + elt = g_malloc0(sizeof(*elt)); + elt->handler = handler; + elt->cleanup = cleanup; + elt->ud = d; + elt->timeout = timeout; + elt->event_loop = st_ctx->event_loop; + REF_INIT_RETAIN(elt, rspamd_async_elt_dtor); + /* Enabled by default */ + + + if (st_ctx->event_loop) { + elt->enabled = TRUE; + /* + * First we set timeval to zero as we want cb to be executed as + * fast as possible + */ + elt->timer_ev.data = elt; + ev_timer_init(&elt->timer_ev, rspamd_async_elt_on_timer, + 0.1, 0.0); + ev_timer_start(st_ctx->event_loop, &elt->timer_ev); + } + else { + elt->enabled = FALSE; + } + + g_queue_push_tail(st_ctx->async_elts, elt); + + return elt; +} diff --git a/src/libstat/stat_internal.h b/src/libstat/stat_internal.h new file mode 100644 index 0000000..8d0ebd4 --- /dev/null +++ b/src/libstat/stat_internal.h @@ -0,0 +1,134 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef STAT_INTERNAL_H_ +#define STAT_INTERNAL_H_ + +#include "config.h" +#include "task.h" +#include "ref.h" +#include "classifiers/classifiers.h" +#include "tokenizers/tokenizers.h" +#include "backends/backends.h" +#include "learn_cache/learn_cache.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct rspamd_statfile_runtime { + struct rspamd_statfile_config *st; + gpointer backend_runtime; + guint64 hits; + guint64 total_hits; +}; + +/* Common classifier structure */ +struct rspamd_classifier { + struct rspamd_stat_ctx *ctx; + GArray *statfiles_ids; /* int */ + struct rspamd_stat_cache *cache; + gpointer cachecf; + gulong spam_learns; + gulong ham_learns; + gint autolearn_cbref; + struct rspamd_classifier_config *cfg; + struct rspamd_stat_classifier *subrs; + gpointer specific; +}; + +struct rspamd_statfile { + gint id; + struct rspamd_statfile_config *stcf; + struct rspamd_classifier *classifier; + struct rspamd_stat_backend *backend; + gpointer bkcf; +}; + +struct rspamd_stat_async_elt; + +typedef void (*rspamd_stat_async_handler)(struct rspamd_stat_async_elt *elt, + gpointer ud); + +typedef void (*rspamd_stat_async_cleanup)(struct rspamd_stat_async_elt *elt, + gpointer ud); + +struct rspamd_stat_async_elt { + rspamd_stat_async_handler handler; + rspamd_stat_async_cleanup cleanup; + struct ev_loop *event_loop; + ev_timer timer_ev; + gdouble timeout; + gboolean enabled; + gpointer ud; + ref_entry_t ref; +}; + +struct rspamd_stat_ctx { + /* Subroutines for all objects */ + struct rspamd_stat_classifier *classifiers_subrs; + guint classifiers_count; + struct rspamd_stat_tokenizer *tokenizers_subrs; + guint tokenizers_count; + struct rspamd_stat_backend *backends_subrs; + guint backends_count; + struct rspamd_stat_cache *caches_subrs; + guint caches_count; + + /* Runtime configuration */ + GPtrArray *statfiles; /* struct rspamd_statfile */ + GPtrArray *classifiers; /* struct rspamd_classifier */ + GQueue *async_elts; /* struct rspamd_stat_async_elt */ + struct rspamd_config *cfg; + + gint lua_stat_tokens_ref; + + /* Global tokenizer */ + struct rspamd_stat_tokenizer *tokenizer; + gpointer tkcf; + + struct ev_loop *event_loop; +}; + +typedef enum rspamd_learn_cache_result { + RSPAMD_LEARN_OK = 0, + RSPAMD_LEARN_UNLEARN, + RSPAMD_LEARN_IGNORE +} rspamd_learn_t; + +struct rspamd_stat_ctx *rspamd_stat_get_ctx(void); + +struct rspamd_stat_classifier *rspamd_stat_get_classifier(const gchar *name); + +struct rspamd_stat_backend *rspamd_stat_get_backend(const gchar *name); + +struct rspamd_stat_tokenizer *rspamd_stat_get_tokenizer(const gchar *name); + +struct rspamd_stat_cache *rspamd_stat_get_cache(const gchar *name); + +struct rspamd_stat_async_elt *rspamd_stat_ctx_register_async( + rspamd_stat_async_handler handler, rspamd_stat_async_cleanup cleanup, + gpointer d, gdouble timeout); + +static GQuark rspamd_stat_quark(void) +{ + return g_quark_from_static_string("rspamd-statistics"); +} + +#ifdef __cplusplus +} +#endif + +#endif /* STAT_INTERNAL_H_ */ diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c new file mode 100644 index 0000000..8c1d8ff --- /dev/null +++ b/src/libstat/stat_process.c @@ -0,0 +1,1250 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "config.h" +#include "stat_api.h" +#include "rspamd.h" +#include "stat_internal.h" +#include "libmime/message.h" +#include "libmime/images.h" +#include "libserver/html/html.h" +#include "lua/lua_common.h" +#include "libserver/mempool_vars_internal.h" +#include "utlist.h" +#include <math.h> + +#define RSPAMD_CLASSIFY_OP 0 +#define RSPAMD_LEARN_OP 1 +#define RSPAMD_UNLEARN_OP 2 + +static const gdouble similarity_threshold = 80.0; + +static void +rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task) +{ + GArray *ar; + rspamd_stat_token_t elt; + guint i; + lua_State *L = task->cfg->lua_state; + + ar = g_array_sized_new(FALSE, FALSE, sizeof(elt), 16); + memset(&elt, 0, sizeof(elt)); + elt.flags = RSPAMD_STAT_TOKEN_FLAG_META; + + if (st_ctx->lua_stat_tokens_ref != -1) { + gint err_idx, ret; + struct rspamd_task **ptask; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, st_ctx->lua_stat_tokens_ref); + + ptask = lua_newuserdata(L, sizeof(*ptask)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + + if ((ret = lua_pcall(L, 1, 1, err_idx)) != 0) { + msg_err_task("call to stat_tokens lua " + "script failed (%d): %s", + ret, lua_tostring(L, -1)); + } + else { + if (lua_type(L, -1) != LUA_TTABLE) { + msg_err_task("stat_tokens invocation must return " + "table and not %s", + lua_typename(L, lua_type(L, -1))); + } + else { + guint vlen; + rspamd_ftok_t tok; + + vlen = rspamd_lua_table_size(L, -1); + + for (i = 0; i < vlen; i++) { + lua_rawgeti(L, -1, i + 1); + tok.begin = lua_tolstring(L, -1, &tok.len); + + if (tok.begin && tok.len > 0) { + elt.original.begin = + rspamd_mempool_ftokdup(task->task_pool, &tok); + elt.original.len = tok.len; + elt.stemmed.begin = elt.original.begin; + elt.stemmed.len = elt.original.len; + elt.normalized.begin = elt.original.begin; + elt.normalized.len = elt.original.len; + + g_array_append_val(ar, elt); + } + + lua_pop(L, 1); + } + } + } + + lua_settop(L, 0); + } + + + if (ar->len > 0) { + st_ctx->tokenizer->tokenize_func(st_ctx, + task, + ar, + TRUE, + "M", + task->tokens); + } + + rspamd_mempool_add_destructor(task->task_pool, + rspamd_array_free_hard, ar); +} + +/* + * Tokenize task using the tokenizer specified + */ +void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task) +{ + struct rspamd_mime_text_part *part; + rspamd_cryptobox_hash_state_t hst; + rspamd_token_t *st_tok; + guint i, reserved_len = 0; + gdouble *pdiff; + guchar hout[rspamd_cryptobox_HASHBYTES]; + gchar *b32_hout; + + if (st_ctx == NULL) { + st_ctx = rspamd_stat_get_ctx(); + } + + g_assert(st_ctx != NULL); + + PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part) + { + if (!IS_TEXT_PART_EMPTY(part) && part->utf_words != NULL) { + reserved_len += part->utf_words->len; + } + /* XXX: normal window size */ + reserved_len += 5; + } + + task->tokens = g_ptr_array_sized_new(reserved_len); + rspamd_mempool_add_destructor(task->task_pool, + rspamd_ptr_array_free_hard, task->tokens); + rspamd_mempool_notify_alloc(task->task_pool, reserved_len * sizeof(gpointer)); + pdiff = rspamd_mempool_get_variable(task->task_pool, "parts_distance"); + + PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part) + { + if (!IS_TEXT_PART_EMPTY(part) && part->utf_words != NULL) { + st_ctx->tokenizer->tokenize_func(st_ctx, task, + part->utf_words, IS_TEXT_PART_UTF(part), + NULL, task->tokens); + } + + + if (pdiff != NULL && (1.0 - *pdiff) * 100.0 > similarity_threshold) { + msg_debug_bayes("message has two common parts (%.2f), so skip the last one", + *pdiff); + break; + } + } + + if (task->meta_words != NULL) { + st_ctx->tokenizer->tokenize_func(st_ctx, + task, + task->meta_words, + TRUE, + "SUBJECT", + task->tokens); + } + + rspamd_stat_tokenize_parts_metadata(st_ctx, task); + + /* Produce signature */ + rspamd_cryptobox_hash_init(&hst, NULL, 0); + + PTR_ARRAY_FOREACH(task->tokens, i, st_tok) + { + rspamd_cryptobox_hash_update(&hst, (guchar *) &st_tok->data, + sizeof(st_tok->data)); + } + + rspamd_cryptobox_hash_final(&hst, hout); + b32_hout = rspamd_encode_base32(hout, sizeof(hout), RSPAMD_BASE32_DEFAULT); + /* + * We need to strip it to 32 characters providing ~160 bits of + * hash distribution + */ + b32_hout[32] = '\0'; + rspamd_mempool_set_variable(task->task_pool, RSPAMD_MEMPOOL_STAT_SIGNATURE, + b32_hout, g_free); +} + +static gboolean +rspamd_stat_classifier_is_skipped(struct rspamd_task *task, + struct rspamd_classifier *cl, gboolean is_learn, gboolean is_spam) +{ + GList *cur = is_learn ? cl->cfg->learn_conditions : cl->cfg->classify_conditions; + lua_State *L = task->cfg->lua_state; + gboolean ret = FALSE; + + while (cur) { + gint cb_ref = GPOINTER_TO_INT(cur->data); + gint old_top = lua_gettop(L); + gint nargs; + + lua_rawgeti(L, LUA_REGISTRYINDEX, cb_ref); + /* Push task and two booleans: is_spam and is_unlearn */ + struct rspamd_task **ptask = lua_newuserdata(L, sizeof(*ptask)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + + if (is_learn) { + lua_pushboolean(L, is_spam); + lua_pushboolean(L, + task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false); + nargs = 3; + } + else { + nargs = 1; + } + + if (lua_pcall(L, nargs, LUA_MULTRET, 0) != 0) { + msg_err_task("call to %s failed: %s", + "condition callback", + lua_tostring(L, -1)); + } + else { + if (lua_isboolean(L, 1)) { + if (!lua_toboolean(L, 1)) { + ret = TRUE; + } + } + + if (lua_isstring(L, 2)) { + if (ret) { + msg_notice_task("%s condition for classifier %s returned: %s; skip classifier", + is_learn ? "learn" : "classify", cl->cfg->name, + lua_tostring(L, 2)); + } + else { + msg_info_task("%s condition for classifier %s returned: %s", + is_learn ? "learn" : "classify", cl->cfg->name, + lua_tostring(L, 2)); + } + } + else if (ret) { + msg_notice_task("%s condition for classifier %s returned false; skip classifier", + is_learn ? "learn" : "classify", cl->cfg->name); + } + + if (ret) { + lua_settop(L, old_top); + break; + } + } + + lua_settop(L, old_top); + cur = g_list_next(cur); + } + + return ret; +} + +static void +rspamd_stat_preprocess(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task, gboolean is_learn, gboolean is_spam) +{ + guint i; + struct rspamd_statfile *st; + gpointer bk_run; + + if (task->tokens == NULL) { + rspamd_stat_process_tokenize(st_ctx, task); + } + + task->stat_runtimes = g_ptr_array_sized_new(st_ctx->statfiles->len); + g_ptr_array_set_size(task->stat_runtimes, st_ctx->statfiles->len); + rspamd_mempool_add_destructor(task->task_pool, + rspamd_ptr_array_free_hard, task->stat_runtimes); + + /* Temporary set all stat_runtimes to some max size to distinguish from NULL */ + for (i = 0; i < st_ctx->statfiles->len; i++) { + g_ptr_array_index(task->stat_runtimes, i) = GSIZE_TO_POINTER(G_MAXSIZE); + } + + for (i = 0; i < st_ctx->classifiers->len; i++) { + struct rspamd_classifier *cl = g_ptr_array_index(st_ctx->classifiers, i); + gboolean skip_classifier = FALSE; + + if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { + skip_classifier = TRUE; + } + else { + if (rspamd_stat_classifier_is_skipped(task, cl, is_learn, is_spam)) { + skip_classifier = TRUE; + } + } + + if (skip_classifier) { + /* Set NULL for all statfiles indexed by id */ + for (int j = 0; j < cl->statfiles_ids->len; j++) { + int id = g_array_index(cl->statfiles_ids, gint, j); + g_ptr_array_index(task->stat_runtimes, id) = NULL; + } + } + } + + for (i = 0; i < st_ctx->statfiles->len; i++) { + st = g_ptr_array_index(st_ctx->statfiles, i); + g_assert(st != NULL); + + if (g_ptr_array_index(task->stat_runtimes, i) == NULL) { + /* The whole classifier is skipped */ + continue; + } + + if (is_learn && st->backend->read_only) { + /* Read only backend, skip it */ + g_ptr_array_index(task->stat_runtimes, i) = NULL; + continue; + } + + if (!is_learn && !rspamd_symcache_is_symbol_enabled(task, task->cfg->cache, + st->stcf->symbol)) { + g_ptr_array_index(task->stat_runtimes, i) = NULL; + msg_debug_bayes("symbol %s is disabled, skip classification", + st->stcf->symbol); + continue; + } + + bk_run = st->backend->runtime(task, st->stcf, is_learn, st->bkcf, i); + + if (bk_run == NULL) { + msg_err_task("cannot init backend %s for statfile %s", + st->backend->name, st->stcf->symbol); + } + + g_ptr_array_index(task->stat_runtimes, i) = bk_run; + } +} + +static void +rspamd_stat_backends_process(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task) +{ + guint i; + struct rspamd_statfile *st; + gpointer bk_run; + + g_assert(task->stat_runtimes != NULL); + + for (i = 0; i < st_ctx->statfiles->len; i++) { + st = g_ptr_array_index(st_ctx->statfiles, i); + bk_run = g_ptr_array_index(task->stat_runtimes, i); + + if (bk_run != NULL) { + st->backend->process_tokens(task, task->tokens, i, bk_run); + } + } +} + +static void +rspamd_stat_classifiers_process(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task) +{ + guint i, j, id; + struct rspamd_classifier *cl; + struct rspamd_statfile *st; + gpointer bk_run; + gboolean skip; + + if (st_ctx->classifiers->len == 0) { + return; + } + + /* + * Do not classify a message if some class is missing + */ + if (!(task->flags & RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS)) { + msg_info_task("skip statistics as SPAM class is missing"); + + return; + } + if (!(task->flags & RSPAMD_TASK_FLAG_HAS_HAM_TOKENS)) { + msg_info_task("skip statistics as HAM class is missing"); + + return; + } + + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + cl->spam_learns = 0; + cl->ham_learns = 0; + } + + g_assert(task->stat_runtimes != NULL); + + for (i = 0; i < st_ctx->statfiles->len; i++) { + st = g_ptr_array_index(st_ctx->statfiles, i); + cl = st->classifier; + + bk_run = g_ptr_array_index(task->stat_runtimes, i); + g_assert(st != NULL); + + if (bk_run != NULL) { + if (st->stcf->is_spam) { + cl->spam_learns += st->backend->total_learns(task, + bk_run, + st_ctx); + } + else { + cl->ham_learns += st->backend->total_learns(task, + bk_run, + st_ctx); + } + } + } + + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + + g_assert(cl != NULL); + + skip = FALSE; + + /* Do not process classifiers on backend failures */ + for (j = 0; j < cl->statfiles_ids->len; j++) { + id = g_array_index(cl->statfiles_ids, gint, j); + bk_run = g_ptr_array_index(task->stat_runtimes, id); + st = g_ptr_array_index(st_ctx->statfiles, id); + + if (bk_run != NULL) { + if (!st->backend->finalize_process(task, bk_run, st_ctx)) { + skip = TRUE; + break; + } + } + } + + /* Ensure that all symbols enabled */ + if (!skip && !(cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND)) { + for (j = 0; j < cl->statfiles_ids->len; j++) { + id = g_array_index(cl->statfiles_ids, gint, j); + bk_run = g_ptr_array_index(task->stat_runtimes, id); + st = g_ptr_array_index(st_ctx->statfiles, id); + + if (bk_run == NULL) { + skip = TRUE; + msg_debug_bayes("disable classifier %s as statfile symbol %s is disabled", + cl->cfg->name, st->stcf->symbol); + break; + } + } + } + + if (!skip) { + if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) { + msg_debug_bayes( + "contains less tokens than required for %s classifier: " + "%ud < %ud", + cl->cfg->name, + task->tokens->len, + cl->cfg->min_tokens); + continue; + } + else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) { + msg_debug_bayes( + "contains more tokens than allowed for %s classifier: " + "%ud > %ud", + cl->cfg->name, + task->tokens->len, + cl->cfg->max_tokens); + continue; + } + + cl->subrs->classify_func(cl, task->tokens, task); + } + } +} + +rspamd_stat_result_t +rspamd_stat_classify(struct rspamd_task *task, lua_State *L, guint stage, + GError **err) +{ + struct rspamd_stat_ctx *st_ctx; + rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK; + + st_ctx = rspamd_stat_get_ctx(); + g_assert(st_ctx != NULL); + + if (st_ctx->classifiers->len == 0) { + task->processed_stages |= stage; + return ret; + } + + if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_PRE) { + /* Preprocess tokens */ + rspamd_stat_preprocess(st_ctx, task, FALSE, FALSE); + } + else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS) { + /* Process backends */ + rspamd_stat_backends_process(st_ctx, task); + } + else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_POST) { + /* Process classifiers */ + rspamd_stat_classifiers_process(st_ctx, task); + } + + task->processed_stages |= stage; + + return ret; +} + +static gboolean +rspamd_stat_cache_check(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task, + const gchar *classifier, + gboolean spam, + GError **err) +{ + rspamd_learn_t learn_res = RSPAMD_LEARN_OK; + struct rspamd_classifier *cl, *sel = NULL; + gpointer rt; + guint i; + + /* Check whether we have learned that file */ + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + + /* Skip other classifiers if they are not needed */ + if (classifier != NULL && (cl->cfg->name == NULL || + g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) { + continue; + } + + sel = cl; + + if (sel->cache && sel->cachecf) { + rt = cl->cache->runtime(task, sel->cachecf, FALSE); + learn_res = cl->cache->check(task, spam, rt); + } + + if (learn_res == RSPAMD_LEARN_IGNORE) { + /* Do not learn twice */ + g_set_error(err, rspamd_stat_quark(), 404, "<%s> has been already " + "learned as %s, ignore it", + MESSAGE_FIELD(task, message_id), + spam ? "spam" : "ham"); + task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED; + + return FALSE; + } + else if (learn_res == RSPAMD_LEARN_UNLEARN) { + task->flags |= RSPAMD_TASK_FLAG_UNLEARN; + break; + } + } + + if (sel == NULL) { + if (classifier) { + g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier " + "with name %s", + classifier); + } + else { + g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined"); + } + + return FALSE; + } + + return TRUE; +} + +static gboolean +rspamd_stat_classifiers_learn(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task, + const gchar *classifier, + gboolean spam, + GError **err) +{ + struct rspamd_classifier *cl, *sel = NULL; + guint i; + gboolean learned = FALSE, too_small = FALSE, too_large = FALSE; + + if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL && + *err == NULL) { + /* Do not learn twice */ + g_set_error(err, rspamd_stat_quark(), 208, "<%s> has been already " + "learned as %s, ignore it", + MESSAGE_FIELD(task, message_id), + spam ? "spam" : "ham"); + + return FALSE; + } + + /* Check whether we have learned that file */ + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + + /* Skip other classifiers if they are not needed */ + if (classifier != NULL && (cl->cfg->name == NULL || + g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) { + continue; + } + + sel = cl; + + /* Now check max and min tokens */ + if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) { + msg_info_task( + "<%s> contains less tokens than required for %s classifier: " + "%ud < %ud", + MESSAGE_FIELD(task, message_id), + cl->cfg->name, + task->tokens->len, + cl->cfg->min_tokens); + too_small = TRUE; + continue; + } + else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) { + msg_info_task( + "<%s> contains more tokens than allowed for %s classifier: " + "%ud > %ud", + MESSAGE_FIELD(task, message_id), + cl->cfg->name, + task->tokens->len, + cl->cfg->max_tokens); + too_large = TRUE; + continue; + } + + if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam, + task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) { + learned = TRUE; + } + } + + if (sel == NULL) { + if (classifier) { + g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier " + "with name %s", + classifier); + } + else { + g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined"); + } + + return FALSE; + } + + if (!learned && err && *err == NULL) { + if (too_large) { + g_set_error(err, rspamd_stat_quark(), 204, + "<%s> contains more tokens than allowed for %s classifier: " + "%d > %d", + MESSAGE_FIELD(task, message_id), + sel->cfg->name, + task->tokens->len, + sel->cfg->max_tokens); + } + else if (too_small) { + g_set_error(err, rspamd_stat_quark(), 204, + "<%s> contains less tokens than required for %s classifier: " + "%d < %d", + MESSAGE_FIELD(task, message_id), + sel->cfg->name, + task->tokens->len, + sel->cfg->min_tokens); + } + } + + return learned; +} + +static gboolean +rspamd_stat_backends_learn(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task, + const gchar *classifier, + gboolean spam, + GError **err) +{ + struct rspamd_classifier *cl, *sel = NULL; + struct rspamd_statfile *st; + gpointer bk_run; + guint i, j; + gint id; + gboolean res = FALSE, backend_found = FALSE; + + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + + /* Skip other classifiers if they are not needed */ + if (classifier != NULL && (cl->cfg->name == NULL || + g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) { + continue; + } + + if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { + res = TRUE; + continue; + } + + sel = cl; + + for (j = 0; j < cl->statfiles_ids->len; j++) { + id = g_array_index(cl->statfiles_ids, gint, j); + st = g_ptr_array_index(st_ctx->statfiles, id); + bk_run = g_ptr_array_index(task->stat_runtimes, id); + + g_assert(st != NULL); + + if (bk_run == NULL) { + /* XXX: must be error */ + if (task->result->passthrough_result) { + /* Passthrough email, cannot learn */ + g_set_error(err, rspamd_stat_quark(), 204, + "Cannot learn statistics when passthrough " + "result has been set; not classified"); + + res = FALSE; + goto end; + } + + msg_debug_task("no runtime for backend %s; classifier %s; symbol %s", + st->backend->name, cl->cfg->name, st->stcf->symbol); + continue; + } + + /* We set sel merely when we have runtime */ + backend_found = TRUE; + + if (!(task->flags & RSPAMD_TASK_FLAG_UNLEARN)) { + if (!!spam != !!st->stcf->is_spam) { + /* If we are not unlearning, then do not touch another class */ + continue; + } + } + + if (!st->backend->learn_tokens(task, task->tokens, id, bk_run)) { + g_set_error(err, rspamd_stat_quark(), 500, + "Cannot push " + "learned results to the backend"); + + res = FALSE; + goto end; + } + else { + if (!!spam == !!st->stcf->is_spam) { + st->backend->inc_learns(task, bk_run, st_ctx); + } + else if (task->flags & RSPAMD_TASK_FLAG_UNLEARN) { + st->backend->dec_learns(task, bk_run, st_ctx); + } + + res = TRUE; + } + } + } + +end: + + if (!res) { + if (err && *err) { + /* Error has been set already */ + return res; + } + + if (sel == NULL) { + if (classifier) { + g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier " + "with name %s", + classifier); + } + else { + g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined"); + } + + return FALSE; + } + else if (!backend_found) { + g_set_error(err, rspamd_stat_quark(), 204, "all learn conditions " + "denied learning %s in %s", + spam ? "spam" : "ham", + classifier ? classifier : "default classifier"); + } + else { + g_set_error(err, rspamd_stat_quark(), 404, "cannot find statfile " + "backend to learn %s in %s", + spam ? "spam" : "ham", + classifier ? classifier : "default classifier"); + } + } + + return res; +} + +static gboolean +rspamd_stat_backends_post_learn(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task, + const gchar *classifier, + gboolean spam, + GError **err) +{ + struct rspamd_classifier *cl; + struct rspamd_statfile *st; + gpointer bk_run, cache_run; + guint i, j; + gint id; + gboolean res = TRUE; + + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + + /* Skip other classifiers if they are not needed */ + if (classifier != NULL && (cl->cfg->name == NULL || + g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) { + continue; + } + + if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { + res = TRUE; + continue; + } + + for (j = 0; j < cl->statfiles_ids->len; j++) { + id = g_array_index(cl->statfiles_ids, gint, j); + st = g_ptr_array_index(st_ctx->statfiles, id); + bk_run = g_ptr_array_index(task->stat_runtimes, id); + + g_assert(st != NULL); + + if (bk_run == NULL) { + /* XXX: must be error */ + continue; + } + + if (!st->backend->finalize_learn(task, bk_run, st_ctx, err)) { + return RSPAMD_STAT_PROCESS_ERROR; + } + } + + if (cl->cache) { + cache_run = cl->cache->runtime(task, cl->cachecf, TRUE); + cl->cache->learn(task, spam, cache_run); + } + } + + g_atomic_int_add(&task->worker->srv->stat->messages_learned, 1); + + return res; +} + +rspamd_stat_result_t +rspamd_stat_learn(struct rspamd_task *task, + gboolean spam, lua_State *L, const gchar *classifier, guint stage, + GError **err) +{ + struct rspamd_stat_ctx *st_ctx; + rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK; + + /* + * We assume now that a task has been already classified before + * coming to learn + */ + g_assert(RSPAMD_TASK_IS_CLASSIFIED(task)); + + st_ctx = rspamd_stat_get_ctx(); + g_assert(st_ctx != NULL); + + if (st_ctx->classifiers->len == 0) { + task->processed_stages |= stage; + return ret; + } + + if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) { + /* Process classifiers */ + rspamd_stat_preprocess(st_ctx, task, TRUE, spam); + + if (!rspamd_stat_cache_check(st_ctx, task, classifier, spam, err)) { + return RSPAMD_STAT_PROCESS_ERROR; + } + } + else if (stage == RSPAMD_TASK_STAGE_LEARN) { + /* Process classifiers */ + if (!rspamd_stat_classifiers_learn(st_ctx, task, classifier, + spam, err)) { + if (err && *err == NULL) { + g_set_error(err, rspamd_stat_quark(), 500, + "Unknown statistics error, found when learning classifiers;" + " classifier: %s", + task->classifier); + } + return RSPAMD_STAT_PROCESS_ERROR; + } + + /* Process backends */ + if (!rspamd_stat_backends_learn(st_ctx, task, classifier, spam, err)) { + if (err && *err == NULL) { + g_set_error(err, rspamd_stat_quark(), 500, + "Unknown statistics error, found when storing data on backend;" + " classifier: %s", + task->classifier); + } + return RSPAMD_STAT_PROCESS_ERROR; + } + } + else if (stage == RSPAMD_TASK_STAGE_LEARN_POST) { + if (!rspamd_stat_backends_post_learn(st_ctx, task, classifier, spam, err)) { + return RSPAMD_STAT_PROCESS_ERROR; + } + } + + task->processed_stages |= stage; + + return ret; +} + +static gboolean +rspamd_stat_has_classifier_symbols(struct rspamd_task *task, + struct rspamd_scan_result *mres, + struct rspamd_classifier *cl) +{ + guint i; + gint id; + struct rspamd_statfile *st; + struct rspamd_stat_ctx *st_ctx; + gboolean is_spam; + + if (mres == NULL) { + return FALSE; + } + + st_ctx = rspamd_stat_get_ctx(); + is_spam = !!(task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM); + + for (i = 0; i < cl->statfiles_ids->len; i++) { + id = g_array_index(cl->statfiles_ids, gint, i); + st = g_ptr_array_index(st_ctx->statfiles, id); + + if (rspamd_task_find_symbol_result(task, st->stcf->symbol, NULL)) { + if (is_spam == !!st->stcf->is_spam) { + msg_debug_bayes("do not autolearn %s as symbol %s is already " + "added", + is_spam ? "spam" : "ham", st->stcf->symbol); + + return TRUE; + } + } + } + + return FALSE; +} + +gboolean +rspamd_stat_check_autolearn(struct rspamd_task *task) +{ + struct rspamd_stat_ctx *st_ctx; + struct rspamd_classifier *cl; + const ucl_object_t *obj, *elt1, *elt2; + struct rspamd_scan_result *mres = NULL; + struct rspamd_task **ptask; + lua_State *L; + guint i; + gint err_idx; + gboolean ret = FALSE; + gdouble ham_score, spam_score; + const gchar *lua_script, *lua_ret; + + g_assert(RSPAMD_TASK_IS_CLASSIFIED(task)); + st_ctx = rspamd_stat_get_ctx(); + g_assert(st_ctx != NULL); + + L = task->cfg->lua_state; + + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + ret = FALSE; + + if (cl->cfg->opts) { + obj = ucl_object_lookup(cl->cfg->opts, "autolearn"); + + if (ucl_object_type(obj) == UCL_BOOLEAN) { + /* Legacy true/false */ + if (ucl_object_toboolean(obj)) { + /* + * Default learning algorithm: + * + * - We learn spam if action is ACTION_REJECT + * - We learn ham if score is less than zero + */ + mres = task->result; + + if (mres) { + if (mres->score > rspamd_task_get_required_score(task, mres)) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; + + ret = TRUE; + } + else if (mres->score < 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + ret = TRUE; + } + } + } + } + else if (ucl_object_type(obj) == UCL_ARRAY && obj->len == 2) { + /* Legacy thresholds */ + /* + * We have an array of 2 elements, treat it as a + * ham_score, spam_score + */ + elt1 = ucl_array_find_index(obj, 0); + elt2 = ucl_array_find_index(obj, 1); + + if ((ucl_object_type(elt1) == UCL_FLOAT || + ucl_object_type(elt1) == UCL_INT) && + (ucl_object_type(elt2) == UCL_FLOAT || + ucl_object_type(elt2) == UCL_INT)) { + ham_score = ucl_object_todouble(elt1); + spam_score = ucl_object_todouble(elt2); + + if (ham_score > spam_score) { + gdouble t; + + t = ham_score; + ham_score = spam_score; + spam_score = t; + } + + mres = task->result; + + if (mres) { + if (mres->score >= spam_score) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; + + ret = TRUE; + } + else if (mres->score <= ham_score) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + ret = TRUE; + } + } + } + } + else if (ucl_object_type(obj) == UCL_STRING) { + /* Legacy script */ + lua_script = ucl_object_tostring(obj); + + if (luaL_dostring(L, lua_script) != 0) { + msg_err_task("cannot execute lua script for autolearn " + "extraction: %s", + lua_tostring(L, -1)); + } + else { + if (lua_type(L, -1) == LUA_TFUNCTION) { + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_pushvalue(L, -2); /* Function itself */ + + ptask = lua_newuserdata(L, sizeof(struct rspamd_task *)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + + if (lua_pcall(L, 1, 1, err_idx) != 0) { + msg_err_task("call to autolearn script failed: " + "%s", + lua_tostring(L, -1)); + } + else { + lua_ret = lua_tostring(L, -1); + + /* We can have immediate results */ + if (lua_ret) { + if (strcmp(lua_ret, "ham") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + ret = TRUE; + } + else if (strcmp(lua_ret, "spam") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; + ret = TRUE; + } + } + } + + /* Result + error function + original function */ + lua_pop(L, 3); + } + else { + msg_err_task("lua script must return " + "function(task) and not %s", + lua_typename(L, lua_type( + L, -1))); + } + } + } + else if (ucl_object_type(obj) == UCL_OBJECT) { + /* Try to find autolearn callback */ + if (cl->autolearn_cbref == 0) { + /* We don't have preprocessed cb id, so try to get it */ + if (!rspamd_lua_require_function(L, "lua_bayes_learn", + "autolearn")) { + msg_err_task("cannot get autolearn library from " + "`lua_bayes_learn`"); + } + else { + cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX); + } + } + + if (cl->autolearn_cbref != -1) { + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref); + + ptask = lua_newuserdata(L, sizeof(struct rspamd_task *)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + /* Push the whole object as well */ + ucl_object_push_lua(L, obj, true); + + if (lua_pcall(L, 2, 1, err_idx) != 0) { + msg_err_task("call to autolearn script failed: " + "%s", + lua_tostring(L, -1)); + } + else { + lua_ret = lua_tostring(L, -1); + + if (lua_ret) { + if (strcmp(lua_ret, "ham") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + ret = TRUE; + } + else if (strcmp(lua_ret, "spam") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; + ret = TRUE; + } + } + } + + lua_settop(L, err_idx - 1); + } + } + + if (ret) { + /* Do not autolearn if we have this symbol already */ + if (rspamd_stat_has_classifier_symbols(task, mres, cl)) { + ret = FALSE; + task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM | + RSPAMD_TASK_FLAG_LEARN_SPAM); + } + else if (mres != NULL) { + if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) { + msg_info_task("<%s>: autolearn ham for classifier " + "'%s' as message's " + "score is negative: %.2f", + MESSAGE_FIELD(task, message_id), cl->cfg->name, + mres->score); + } + else { + msg_info_task("<%s>: autolearn spam for classifier " + "'%s' as message's " + "action is reject, score: %.2f", + MESSAGE_FIELD(task, message_id), cl->cfg->name, + mres->score); + } + + task->classifier = cl->cfg->name; + break; + } + } + } + } + + return ret; +} + +/** + * Get the overall statistics for all statfile backends + * @param cfg configuration + * @param total_learns the total number of learns is stored here + * @return array of statistical information + */ +rspamd_stat_result_t +rspamd_stat_statistics(struct rspamd_task *task, + struct rspamd_config *cfg, + guint64 *total_learns, + ucl_object_t **target) +{ + struct rspamd_stat_ctx *st_ctx; + struct rspamd_classifier *cl; + struct rspamd_statfile *st; + gpointer backend_runtime; + ucl_object_t *res = NULL, *elt; + guint64 learns = 0; + guint i, j; + gint id; + + st_ctx = rspamd_stat_get_ctx(); + g_assert(st_ctx != NULL); + + res = ucl_object_typed_new(UCL_ARRAY); + + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + + if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { + continue; + } + + for (j = 0; j < cl->statfiles_ids->len; j++) { + id = g_array_index(cl->statfiles_ids, gint, j); + st = g_ptr_array_index(st_ctx->statfiles, id); + backend_runtime = st->backend->runtime(task, st->stcf, FALSE, + st->bkcf, id); + elt = st->backend->get_stat(backend_runtime, st->bkcf); + + if (elt && ucl_object_type(elt) == UCL_OBJECT) { + const ucl_object_t *rev = ucl_object_lookup(elt, "revision"); + + learns += ucl_object_toint(rev); + } + else { + learns += st->backend->total_learns(task, backend_runtime, + st->bkcf); + } + + if (elt != NULL) { + ucl_array_append(res, elt); + } + } + } + + if (total_learns != NULL) { + *total_learns = learns; + } + + if (target) { + *target = res; + } + else { + ucl_object_unref(res); + } + + return RSPAMD_STAT_PROCESS_OK; +} diff --git a/src/libstat/tokenizers/osb.c b/src/libstat/tokenizers/osb.c new file mode 100644 index 0000000..d871c7a --- /dev/null +++ b/src/libstat/tokenizers/osb.c @@ -0,0 +1,424 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * OSB tokenizer + */ + + +#include "tokenizers.h" +#include "stat_internal.h" +#include "libmime/lang_detection.h" + +/* Size for features pipe */ +#define DEFAULT_FEATURE_WINDOW_SIZE 5 +#define DEFAULT_OSB_VERSION 2 + +static const int primes[] = { + 1, + 7, + 3, + 13, + 5, + 29, + 11, + 51, + 23, + 101, + 47, + 203, + 97, + 407, + 197, + 817, + 397, + 1637, + 797, + 3277, +}; + +static const guchar osb_tokenizer_magic[] = {'o', 's', 'b', 't', 'o', 'k', 'v', '2'}; + +enum rspamd_osb_hash_type { + RSPAMD_OSB_HASH_COMPAT = 0, + RSPAMD_OSB_HASH_XXHASH, + RSPAMD_OSB_HASH_SIPHASH +}; + +struct rspamd_osb_tokenizer_config { + guchar magic[8]; + gshort version; + gshort window_size; + enum rspamd_osb_hash_type ht; + guint64 seed; + rspamd_sipkey_t sk; +}; + +/* + * Return default config + */ +static struct rspamd_osb_tokenizer_config * +rspamd_tokenizer_osb_default_config(void) +{ + static struct rspamd_osb_tokenizer_config def; + + if (memcmp(def.magic, osb_tokenizer_magic, sizeof(osb_tokenizer_magic)) != 0) { + memset(&def, 0, sizeof(def)); + memcpy(def.magic, osb_tokenizer_magic, sizeof(osb_tokenizer_magic)); + def.version = DEFAULT_OSB_VERSION; + def.window_size = DEFAULT_FEATURE_WINDOW_SIZE; + def.ht = RSPAMD_OSB_HASH_XXHASH; + def.seed = 0xdeadbabe; + } + + return &def; +} + +static struct rspamd_osb_tokenizer_config * +rspamd_tokenizer_osb_config_from_ucl(rspamd_mempool_t *pool, + const ucl_object_t *obj) +{ + const ucl_object_t *elt; + struct rspamd_osb_tokenizer_config *cf, *def; + guchar *key = NULL; + gsize keylen; + + + if (pool != NULL) { + cf = rspamd_mempool_alloc0(pool, sizeof(*cf)); + } + else { + cf = g_malloc0(sizeof(*cf)); + } + + /* Use default config */ + def = rspamd_tokenizer_osb_default_config(); + memcpy(cf, def, sizeof(*cf)); + + elt = ucl_object_lookup(obj, "hash"); + if (elt != NULL && ucl_object_type(elt) == UCL_STRING) { + if (g_ascii_strncasecmp(ucl_object_tostring(elt), "xxh", 3) == 0) { + cf->ht = RSPAMD_OSB_HASH_XXHASH; + elt = ucl_object_lookup(obj, "seed"); + if (elt != NULL && ucl_object_type(elt) == UCL_INT) { + cf->seed = ucl_object_toint(elt); + } + } + else if (g_ascii_strncasecmp(ucl_object_tostring(elt), "sip", 3) == 0) { + cf->ht = RSPAMD_OSB_HASH_SIPHASH; + elt = ucl_object_lookup(obj, "key"); + + if (elt != NULL && ucl_object_type(elt) == UCL_STRING) { + key = rspamd_decode_base32(ucl_object_tostring(elt), + 0, &keylen, RSPAMD_BASE32_DEFAULT); + if (keylen < sizeof(rspamd_sipkey_t)) { + msg_warn("siphash key is too short: %z", keylen); + g_free(key); + } + else { + memcpy(cf->sk, key, sizeof(cf->sk)); + g_free(key); + } + } + else { + msg_warn_pool("siphash cannot be used without key"); + } + } + } + else { + elt = ucl_object_lookup(obj, "compat"); + if (elt != NULL && ucl_object_toboolean(elt)) { + cf->ht = RSPAMD_OSB_HASH_COMPAT; + } + } + + elt = ucl_object_lookup(obj, "window"); + if (elt != NULL && ucl_object_type(elt) == UCL_INT) { + cf->window_size = ucl_object_toint(elt); + if (cf->window_size > DEFAULT_FEATURE_WINDOW_SIZE * 4) { + msg_err_pool("too large window size: %d", cf->window_size); + cf->window_size = DEFAULT_FEATURE_WINDOW_SIZE; + } + } + + return cf; +} + +gpointer +rspamd_tokenizer_osb_get_config(rspamd_mempool_t *pool, + struct rspamd_tokenizer_config *cf, + gsize *len) +{ + struct rspamd_osb_tokenizer_config *osb_cf, *def; + + if (cf != NULL && cf->opts != NULL) { + osb_cf = rspamd_tokenizer_osb_config_from_ucl(pool, cf->opts); + } + else { + def = rspamd_tokenizer_osb_default_config(); + osb_cf = rspamd_mempool_alloc(pool, sizeof(*osb_cf)); + memcpy(osb_cf, def, sizeof(*osb_cf)); + /* Do not write sipkey to statfile */ + } + + if (osb_cf->ht == RSPAMD_OSB_HASH_SIPHASH) { + msg_info_pool("siphash key is not stored into statfiles, so you'd " + "need to keep it inside the configuration"); + } + + memset(osb_cf->sk, 0, sizeof(osb_cf->sk)); + + if (len != NULL) { + *len = sizeof(*osb_cf); + } + + return osb_cf; +} + +#if 0 +gboolean +rspamd_tokenizer_osb_compatible_config (struct rspamd_tokenizer_runtime *rt, + gpointer ptr, gsize len) +{ + struct rspamd_osb_tokenizer_config *osb_cf, *test_cf; + gboolean ret = FALSE; + + test_cf = rt->config; + g_assert (test_cf != NULL); + + if (len == sizeof (*osb_cf)) { + osb_cf = ptr; + + if (memcmp (osb_cf, osb_tokenizer_magic, sizeof (osb_tokenizer_magic)) != 0) { + ret = test_cf->ht == RSPAMD_OSB_HASH_COMPAT; + } + else { + if (osb_cf->version == DEFAULT_OSB_VERSION) { + /* We can compare them directly now */ + ret = (memcmp (osb_cf, test_cf, sizeof (*osb_cf) + - sizeof (osb_cf->sk))) == 0; + } + } + } + else { + /* We are compatible now merely with fallback config */ + if (test_cf->ht == RSPAMD_OSB_HASH_COMPAT) { + ret = TRUE; + } + } + + return ret; +} + +gboolean +rspamd_tokenizer_osb_load_config (rspamd_mempool_t *pool, + struct rspamd_tokenizer_runtime *rt, + gpointer ptr, gsize len) +{ + struct rspamd_osb_tokenizer_config *osb_cf; + + if (ptr == NULL || len == 0) { + osb_cf = rspamd_tokenizer_osb_config_from_ucl (pool, rt->tkcf->opts); + + if (osb_cf->ht != RSPAMD_OSB_HASH_COMPAT) { + /* Trying to load incompatible configuration */ + msg_err_pool ("cannot load tokenizer configuration from a legacy " + "statfile; maybe you have forgotten to set 'compat' option" + " in the tokenizer configuration"); + + return FALSE; + } + } + else { + g_assert (len == sizeof (*osb_cf)); + osb_cf = ptr; + } + + rt->config = osb_cf; + rt->conf_len = sizeof (*osb_cf); + + return TRUE; +} + +gboolean +rspamd_tokenizer_osb_is_compat (struct rspamd_tokenizer_runtime *rt) +{ + struct rspamd_osb_tokenizer_config *osb_cf = rt->config; + + return (osb_cf->ht == RSPAMD_OSB_HASH_COMPAT); +} +#endif + +struct token_pipe_entry { + guint64 h; + rspamd_stat_token_t *t; +}; + +gint rspamd_tokenizer_osb(struct rspamd_stat_ctx *ctx, + struct rspamd_task *task, + GArray *words, + gboolean is_utf, + const gchar *prefix, + GPtrArray *result) +{ + rspamd_token_t *new_tok = NULL; + rspamd_stat_token_t *token; + struct rspamd_osb_tokenizer_config *osb_cf; + guint64 cur, seed; + struct token_pipe_entry *hashpipe; + guint32 h1, h2; + gsize token_size; + guint processed = 0, i, w, window_size, token_flags = 0; + + if (words == NULL) { + return FALSE; + } + + osb_cf = ctx->tkcf; + window_size = osb_cf->window_size; + + if (prefix) { + seed = rspamd_cryptobox_fast_hash_specific(RSPAMD_CRYPTOBOX_XXHASH64, + prefix, strlen(prefix), osb_cf->seed); + } + else { + seed = osb_cf->seed; + } + + hashpipe = g_alloca(window_size * sizeof(hashpipe[0])); + for (i = 0; i < window_size; i++) { + hashpipe[i].h = 0xfe; + hashpipe[i].t = NULL; + } + + token_size = sizeof(rspamd_token_t) + + sizeof(gdouble) * ctx->statfiles->len; + g_assert(token_size > 0); + + for (w = 0; w < words->len; w++) { + token = &g_array_index(words, rspamd_stat_token_t, w); + token_flags = token->flags; + const gchar *begin; + gsize len; + + if (token->flags & + (RSPAMD_STAT_TOKEN_FLAG_STOP_WORD | RSPAMD_STAT_TOKEN_FLAG_SKIPPED)) { + /* Skip stop/skipped words */ + continue; + } + + if (token->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT) { + begin = token->stemmed.begin; + len = token->stemmed.len; + } + else { + begin = token->original.begin; + len = token->original.len; + } + + if (osb_cf->ht == RSPAMD_OSB_HASH_COMPAT) { + rspamd_ftok_t ftok; + + ftok.begin = begin; + ftok.len = len; + cur = rspamd_fstrhash_lc(&ftok, is_utf); + } + else { + /* We know that the words are normalized */ + if (osb_cf->ht == RSPAMD_OSB_HASH_XXHASH) { + cur = rspamd_cryptobox_fast_hash_specific(RSPAMD_CRYPTOBOX_XXHASH64, + begin, len, osb_cf->seed); + } + else { + rspamd_cryptobox_siphash((guchar *) &cur, begin, + len, osb_cf->sk); + + if (prefix) { + cur ^= seed; + } + } + } + + if (token_flags & RSPAMD_STAT_TOKEN_FLAG_UNIGRAM) { + new_tok = rspamd_mempool_alloc0(task->task_pool, token_size); + new_tok->flags = token_flags; + new_tok->t1 = token; + new_tok->t2 = token; + new_tok->data = cur; + new_tok->window_idx = 0; + g_ptr_array_add(result, new_tok); + + continue; + } + +#define ADD_TOKEN \ + do { \ + new_tok = rspamd_mempool_alloc0(task->task_pool, token_size); \ + new_tok->flags = token_flags; \ + new_tok->t1 = hashpipe[0].t; \ + new_tok->t2 = hashpipe[i].t; \ + if (osb_cf->ht == RSPAMD_OSB_HASH_COMPAT) { \ + h1 = ((guint32) hashpipe[0].h) * primes[0] + \ + ((guint32) hashpipe[i].h) * primes[i << 1]; \ + h2 = ((guint32) hashpipe[0].h) * primes[1] + \ + ((guint32) hashpipe[i].h) * primes[(i << 1) - 1]; \ + memcpy((guchar *) &new_tok->data, &h1, sizeof(h1)); \ + memcpy(((guchar *) &new_tok->data) + sizeof(h1), &h2, sizeof(h2)); \ + } \ + else { \ + new_tok->data = hashpipe[0].h * primes[0] + hashpipe[i].h * primes[i << 1]; \ + } \ + new_tok->window_idx = i; \ + g_ptr_array_add(result, new_tok); \ + } while (0) + + if (processed < window_size) { + /* Just fill a hashpipe */ + ++processed; + hashpipe[window_size - processed].h = cur; + hashpipe[window_size - processed].t = token; + } + else { + /* Shift hashpipe */ + for (i = window_size - 1; i > 0; i--) { + hashpipe[i] = hashpipe[i - 1]; + } + hashpipe[0].h = cur; + hashpipe[0].t = token; + + processed++; + + for (i = 1; i < window_size; i++) { + if (!(hashpipe[i].t->flags & RSPAMD_STAT_TOKEN_FLAG_EXCEPTION)) { + ADD_TOKEN; + } + } + } + } + + if (processed > 1 && processed <= window_size) { + processed--; + memmove(hashpipe, &hashpipe[window_size - processed], + processed * sizeof(hashpipe[0])); + + for (i = 1; i < processed; i++) { + ADD_TOKEN; + } + } + +#undef ADD_TOKEN + + return TRUE; +} diff --git a/src/libstat/tokenizers/tokenizers.c b/src/libstat/tokenizers/tokenizers.c new file mode 100644 index 0000000..ee7234d --- /dev/null +++ b/src/libstat/tokenizers/tokenizers.c @@ -0,0 +1,955 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Common tokenization functions + */ + +#include "rspamd.h" +#include "tokenizers.h" +#include "stat_internal.h" +#include "contrib/mumhash/mum.h" +#include "libmime/lang_detection.h" +#include "libstemmer.h" + +#include <unicode/utf8.h> +#include <unicode/uchar.h> +#include <unicode/uiter.h> +#include <unicode/ubrk.h> +#include <unicode/ucnv.h> +#if U_ICU_VERSION_MAJOR_NUM >= 44 +#include <unicode/unorm2.h> +#endif + +#include <math.h> + +typedef gboolean (*token_get_function)(rspamd_stat_token_t *buf, gchar const **pos, + rspamd_stat_token_t *token, + GList **exceptions, gsize *rl, gboolean check_signature); + +const gchar t_delimiters[256] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0}; + +/* Get next word from specified f_str_t buf */ +static gboolean +rspamd_tokenizer_get_word_raw(rspamd_stat_token_t *buf, + gchar const **cur, rspamd_stat_token_t *token, + GList **exceptions, gsize *rl, gboolean unused) +{ + gsize remain, pos; + const gchar *p; + struct rspamd_process_exception *ex = NULL; + + if (buf == NULL) { + return FALSE; + } + + g_assert(cur != NULL); + + if (exceptions != NULL && *exceptions != NULL) { + ex = (*exceptions)->data; + } + + if (token->original.begin == NULL || *cur == NULL) { + if (ex != NULL) { + if (ex->pos == 0) { + token->original.begin = buf->original.begin + ex->len; + token->original.len = ex->len; + token->flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; + } + else { + token->original.begin = buf->original.begin; + token->original.len = 0; + } + } + else { + token->original.begin = buf->original.begin; + token->original.len = 0; + } + *cur = token->original.begin; + } + + token->original.len = 0; + + pos = *cur - buf->original.begin; + if (pos >= buf->original.len) { + return FALSE; + } + + remain = buf->original.len - pos; + p = *cur; + + /* Skip non delimiters symbols */ + do { + if (ex != NULL && ex->pos == pos) { + /* Go to the next exception */ + *exceptions = g_list_next(*exceptions); + *cur = p + ex->len; + return TRUE; + } + pos++; + p++; + remain--; + } while (remain > 0 && t_delimiters[(guchar) *p]); + + token->original.begin = p; + + while (remain > 0 && !t_delimiters[(guchar) *p]) { + if (ex != NULL && ex->pos == pos) { + *exceptions = g_list_next(*exceptions); + *cur = p + ex->len; + return TRUE; + } + token->original.len++; + pos++; + remain--; + p++; + } + + if (remain == 0) { + return FALSE; + } + + if (rl) { + *rl = token->original.len; + } + + token->flags = RSPAMD_STAT_TOKEN_FLAG_TEXT; + + *cur = p; + + return TRUE; +} + +static inline gboolean +rspamd_tokenize_check_limit(gboolean decay, + guint word_decay, + guint nwords, + guint64 *hv, + guint64 *prob, + const rspamd_stat_token_t *token, + gssize remain, + gssize total) +{ + static const gdouble avg_word_len = 6.0; + + if (!decay) { + if (token->original.len >= sizeof(guint64)) { + guint64 tmp; + memcpy(&tmp, token->original.begin, sizeof(tmp)); + *hv = mum_hash_step(*hv, tmp); + } + + /* Check for decay */ + if (word_decay > 0 && nwords > word_decay && remain < (gssize) total) { + /* Start decay */ + gdouble decay_prob; + + *hv = mum_hash_finish(*hv); + + /* We assume that word is 6 symbols length in average */ + decay_prob = (gdouble) word_decay / ((total - (remain)) / avg_word_len) * 10; + decay_prob = floor(decay_prob) / 10.0; + + if (decay_prob >= 1.0) { + *prob = G_MAXUINT64; + } + else { + *prob = (guint64) (decay_prob * (double) G_MAXUINT64); + } + + return TRUE; + } + } + else { + /* Decaying probability */ + /* LCG64 x[n] = a x[n - 1] + b mod 2^64 */ + *hv = (*hv) * 2862933555777941757ULL + 3037000493ULL; + + if (*hv > *prob) { + return TRUE; + } + } + + return FALSE; +} + +static inline gboolean +rspamd_utf_word_valid(const guchar *text, const guchar *end, + gint32 start, gint32 finish) +{ + const guchar *st = text + start, *fin = text + finish; + UChar32 c; + + if (st >= end || fin > end || st >= fin) { + return FALSE; + } + + U8_NEXT(text, start, finish, c); + + if (u_isJavaIDPart(c)) { + return TRUE; + } + + return FALSE; +} +#define SHIFT_EX \ + do { \ + cur = g_list_next(cur); \ + if (cur) { \ + ex = (struct rspamd_process_exception *) cur->data; \ + } \ + else { \ + ex = NULL; \ + } \ + } while (0) + +static inline void +rspamd_tokenize_exception(struct rspamd_process_exception *ex, GArray *res) +{ + rspamd_stat_token_t token; + + memset(&token, 0, sizeof(token)); + + if (ex->type == RSPAMD_EXCEPTION_GENERIC) { + token.original.begin = "!!EX!!"; + token.original.len = sizeof("!!EX!!") - 1; + token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; + + g_array_append_val(res, token); + token.flags = 0; + } + else if (ex->type == RSPAMD_EXCEPTION_URL) { + struct rspamd_url *uri; + + uri = ex->ptr; + + if (uri && uri->tldlen > 0) { + token.original.begin = rspamd_url_tld_unsafe(uri); + token.original.len = uri->tldlen; + } + else { + token.original.begin = "!!EX!!"; + token.original.len = sizeof("!!EX!!") - 1; + } + + token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; + g_array_append_val(res, token); + token.flags = 0; + } +} + + +GArray * +rspamd_tokenize_text(const gchar *text, gsize len, + const UText *utxt, + enum rspamd_tokenize_type how, + struct rspamd_config *cfg, + GList *exceptions, + guint64 *hash, + GArray *cur_words, + rspamd_mempool_t *pool) +{ + rspamd_stat_token_t token, buf; + const gchar *pos = NULL; + gsize l = 0; + GArray *res; + GList *cur = exceptions; + guint min_len = 0, max_len = 0, word_decay = 0, initial_size = 128; + guint64 hv = 0; + gboolean decay = FALSE, long_text_mode = FALSE; + guint64 prob = 0; + static UBreakIterator *bi = NULL; + static const gsize long_text_limit = 1 * 1024 * 1024; + static const ev_tstamp max_exec_time = 0.2; /* 200 ms */ + ev_tstamp start; + + if (text == NULL) { + return cur_words; + } + + if (len > long_text_limit) { + /* + * In this mode we do additional checks to avoid performance issues + */ + long_text_mode = TRUE; + start = ev_time(); + } + + buf.original.begin = text; + buf.original.len = len; + buf.flags = 0; + + memset(&token, 0, sizeof(token)); + + if (cfg != NULL) { + min_len = cfg->min_word_len; + max_len = cfg->max_word_len; + word_decay = cfg->words_decay; + initial_size = word_decay * 2; + } + + if (!cur_words) { + res = g_array_sized_new(FALSE, FALSE, sizeof(rspamd_stat_token_t), + initial_size); + } + else { + res = cur_words; + } + + if (G_UNLIKELY(how == RSPAMD_TOKENIZE_RAW || utxt == NULL)) { + while (rspamd_tokenizer_get_word_raw(&buf, &pos, &token, &cur, &l, FALSE)) { + if (l == 0 || (min_len > 0 && l < min_len) || + (max_len > 0 && l > max_len)) { + token.original.begin = pos; + continue; + } + + if (token.original.len > 0 && + rspamd_tokenize_check_limit(decay, word_decay, res->len, + &hv, &prob, &token, pos - text, len)) { + if (!decay) { + decay = TRUE; + } + else { + token.original.begin = pos; + continue; + } + } + + if (long_text_mode) { + if ((res->len + 1) % 16 == 0) { + ev_tstamp now = ev_time(); + + if (now - start > max_exec_time) { + msg_warn_pool_check( + "too long time has been spent on tokenization:" + " %.1f ms, limit is %.1f ms; %d words added so far", + (now - start) * 1e3, max_exec_time * 1e3, + res->len); + + goto end; + } + } + } + + g_array_append_val(res, token); + + if (((gsize) res->len) * sizeof(token) > (0x1ull << 30u)) { + /* Due to bug in glib ! */ + msg_err_pool_check( + "too many words found: %d, stop tokenization to avoid DoS", + res->len); + + goto end; + } + + token.original.begin = pos; + } + } + else { + /* UTF8 boundaries */ + UErrorCode uc_err = U_ZERO_ERROR; + int32_t last, p; + struct rspamd_process_exception *ex = NULL; + + if (bi == NULL) { + bi = ubrk_open(UBRK_WORD, NULL, NULL, 0, &uc_err); + + g_assert(U_SUCCESS(uc_err)); + } + + ubrk_setUText(bi, (UText *) utxt, &uc_err); + last = ubrk_first(bi); + p = last; + + if (cur) { + ex = (struct rspamd_process_exception *) cur->data; + } + + while (p != UBRK_DONE) { + start_over: + token.original.len = 0; + + if (p > last) { + if (ex && cur) { + /* Check exception */ + if (ex->pos >= last && ex->pos <= p) { + /* We have an exception within boundary */ + /* First, start to drain exceptions from the start */ + while (cur && ex->pos <= last) { + /* We have an exception at the beginning, skip those */ + last += ex->len; + rspamd_tokenize_exception(ex, res); + + if (last > p) { + /* Exception spread over the boundaries */ + while (last > p && p != UBRK_DONE) { + gint32 old_p = p; + p = ubrk_next(bi); + + if (p != UBRK_DONE && p <= old_p) { + msg_warn_pool_check( + "tokenization reversed back on position %d," + "%d new position (%d backward), likely libicu bug!", + (gint) (p), (gint) (old_p), old_p - p); + + goto end; + } + } + + /* We need to reset our scan with new p and last */ + SHIFT_EX; + goto start_over; + } + + SHIFT_EX; + } + + /* Now, we can have an exception within boundary again */ + if (cur && ex->pos >= last && ex->pos <= p) { + /* Append the first part */ + if (rspamd_utf_word_valid(text, text + len, last, + ex->pos)) { + token.original.begin = text + last; + token.original.len = ex->pos - last; + token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT | + RSPAMD_STAT_TOKEN_FLAG_UTF; + } + + /* Process the current exception */ + last += ex->len + (ex->pos - last); + + rspamd_tokenize_exception(ex, res); + + if (last > p) { + /* Exception spread over the boundaries */ + while (last > p && p != UBRK_DONE) { + gint32 old_p = p; + p = ubrk_next(bi); + if (p != UBRK_DONE && p <= old_p) { + msg_warn_pool_check( + "tokenization reversed back on position %d," + "%d new position (%d backward), likely libicu bug!", + (gint) (p), (gint) (old_p), old_p - p); + + goto end; + } + } + /* We need to reset our scan with new p and last */ + SHIFT_EX; + goto start_over; + } + + SHIFT_EX; + } + else if (p > last) { + if (rspamd_utf_word_valid(text, text + len, last, p)) { + token.original.begin = text + last; + token.original.len = p - last; + token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT | + RSPAMD_STAT_TOKEN_FLAG_UTF; + } + } + } + else if (ex->pos < last) { + /* Forward exceptions list */ + while (cur && ex->pos <= last) { + /* We have an exception at the beginning, skip those */ + SHIFT_EX; + } + + if (rspamd_utf_word_valid(text, text + len, last, p)) { + token.original.begin = text + last; + token.original.len = p - last; + token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT | + RSPAMD_STAT_TOKEN_FLAG_UTF; + } + } + else { + /* No exceptions within boundary */ + if (rspamd_utf_word_valid(text, text + len, last, p)) { + token.original.begin = text + last; + token.original.len = p - last; + token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT | + RSPAMD_STAT_TOKEN_FLAG_UTF; + } + } + } + else { + if (rspamd_utf_word_valid(text, text + len, last, p)) { + token.original.begin = text + last; + token.original.len = p - last; + token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT | + RSPAMD_STAT_TOKEN_FLAG_UTF; + } + } + + if (token.original.len > 0 && + rspamd_tokenize_check_limit(decay, word_decay, res->len, + &hv, &prob, &token, p, len)) { + if (!decay) { + decay = TRUE; + } + else { + token.flags |= RSPAMD_STAT_TOKEN_FLAG_SKIPPED; + } + } + } + + if (token.original.len > 0) { + /* Additional check for number of words */ + if (((gsize) res->len) * sizeof(token) > (0x1ull << 30u)) { + /* Due to bug in glib ! */ + msg_err("too many words found: %d, stop tokenization to avoid DoS", + res->len); + + goto end; + } + + g_array_append_val(res, token); + } + + /* Also check for long text mode */ + if (long_text_mode) { + /* Check time each 128 words added */ + const int words_check_mask = 0x7F; + + if ((res->len & words_check_mask) == words_check_mask) { + ev_tstamp now = ev_time(); + + if (now - start > max_exec_time) { + msg_warn_pool_check( + "too long time has been spent on tokenization:" + " %.1f ms, limit is %.1f ms; %d words added so far", + (now - start) * 1e3, max_exec_time * 1e3, + res->len); + + goto end; + } + } + } + + last = p; + p = ubrk_next(bi); + + if (p != UBRK_DONE && p <= last) { + msg_warn_pool_check("tokenization reversed back on position %d," + "%d new position (%d backward), likely libicu bug!", + (gint) (p), (gint) (last), last - p); + + goto end; + } + } + } + +end: + if (!decay) { + hv = mum_hash_finish(hv); + } + + if (hash) { + *hash = hv; + } + + return res; +} + +#undef SHIFT_EX + +static void +rspamd_add_metawords_from_str(const gchar *beg, gsize len, + struct rspamd_task *task) +{ + UText utxt = UTEXT_INITIALIZER; + UErrorCode uc_err = U_ZERO_ERROR; + guint i = 0; + UChar32 uc; + gboolean valid_utf = TRUE; + + while (i < len) { + U8_NEXT(beg, i, len, uc); + + if (((gint32) uc) < 0) { + valid_utf = FALSE; + break; + } + +#if U_ICU_VERSION_MAJOR_NUM < 50 + if (u_isalpha(uc)) { + gint32 sc = ublock_getCode(uc); + + if (sc == UBLOCK_THAI) { + valid_utf = FALSE; + msg_info_task("enable workaround for Thai characters for old libicu"); + break; + } + } +#endif + } + + if (valid_utf) { + utext_openUTF8(&utxt, + beg, + len, + &uc_err); + + task->meta_words = rspamd_tokenize_text(beg, len, + &utxt, RSPAMD_TOKENIZE_UTF, + task->cfg, NULL, NULL, + task->meta_words, + task->task_pool); + + utext_close(&utxt); + } + else { + task->meta_words = rspamd_tokenize_text(beg, len, + NULL, RSPAMD_TOKENIZE_RAW, + task->cfg, NULL, NULL, task->meta_words, + task->task_pool); + } +} + +void rspamd_tokenize_meta_words(struct rspamd_task *task) +{ + guint i = 0; + rspamd_stat_token_t *tok; + + if (MESSAGE_FIELD(task, subject)) { + rspamd_add_metawords_from_str(MESSAGE_FIELD(task, subject), + strlen(MESSAGE_FIELD(task, subject)), task); + } + + if (MESSAGE_FIELD(task, from_mime) && MESSAGE_FIELD(task, from_mime)->len > 0) { + struct rspamd_email_address *addr; + + addr = g_ptr_array_index(MESSAGE_FIELD(task, from_mime), 0); + + if (addr->name) { + rspamd_add_metawords_from_str(addr->name, strlen(addr->name), task); + } + } + + if (task->meta_words != NULL) { + const gchar *language = NULL; + + if (MESSAGE_FIELD(task, text_parts) && + MESSAGE_FIELD(task, text_parts)->len > 0) { + struct rspamd_mime_text_part *tp = g_ptr_array_index( + MESSAGE_FIELD(task, text_parts), 0); + + if (tp->language) { + language = tp->language; + } + } + + rspamd_normalize_words(task->meta_words, task->task_pool); + rspamd_stem_words(task->meta_words, task->task_pool, language, + task->lang_det); + + for (i = 0; i < task->meta_words->len; i++) { + tok = &g_array_index(task->meta_words, rspamd_stat_token_t, i); + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_HEADER; + } + } +} + +static inline void +rspamd_uchars_to_ucs32(const UChar *src, gsize srclen, + rspamd_stat_token_t *tok, + rspamd_mempool_t *pool) +{ + UChar32 *dest, t, *d; + gint32 i = 0; + + dest = rspamd_mempool_alloc(pool, srclen * sizeof(UChar32)); + d = dest; + + while (i < srclen) { + U16_NEXT_UNSAFE(src, i, t); + + if (u_isgraph(t)) { + UCharCategory cat; + + cat = u_charType(t); +#if U_ICU_VERSION_MAJOR_NUM >= 57 + if (u_hasBinaryProperty(t, UCHAR_EMOJI)) { + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_EMOJI; + } +#endif + + if ((cat >= U_UPPERCASE_LETTER && cat <= U_OTHER_NUMBER) || + cat == U_CONNECTOR_PUNCTUATION || + cat == U_MATH_SYMBOL || + cat == U_CURRENCY_SYMBOL) { + *d++ = u_tolower(t); + } + } + else { + /* Invisible spaces ! */ + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_INVISIBLE_SPACES; + } + } + + tok->unicode.begin = dest; + tok->unicode.len = d - dest; +} + +static inline void +rspamd_ucs32_to_normalised(rspamd_stat_token_t *tok, + rspamd_mempool_t *pool) +{ + guint i, doff = 0; + gsize utflen = 0; + gchar *dest; + UChar32 t; + + for (i = 0; i < tok->unicode.len; i++) { + utflen += U8_LENGTH(tok->unicode.begin[i]); + } + + dest = rspamd_mempool_alloc(pool, utflen + 1); + + for (i = 0; i < tok->unicode.len; i++) { + t = tok->unicode.begin[i]; + U8_APPEND_UNSAFE(dest, doff, t); + } + + g_assert(doff <= utflen); + dest[doff] = '\0'; + + tok->normalized.len = doff; + tok->normalized.begin = dest; +} + +void rspamd_normalize_single_word(rspamd_stat_token_t *tok, rspamd_mempool_t *pool) +{ + UErrorCode uc_err = U_ZERO_ERROR; + UConverter *utf8_converter; + UChar tmpbuf[1024]; /* Assume that we have no longer words... */ + gsize ulen; + + utf8_converter = rspamd_get_utf8_converter(); + + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UTF) { + ulen = ucnv_toUChars(utf8_converter, + tmpbuf, + G_N_ELEMENTS(tmpbuf), + tok->original.begin, + tok->original.len, + &uc_err); + + /* Now, we need to understand if we need to normalise the word */ + if (!U_SUCCESS(uc_err)) { + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE; + tok->unicode.begin = NULL; + tok->unicode.len = 0; + tok->normalized.begin = NULL; + tok->normalized.len = 0; + } + else { +#if U_ICU_VERSION_MAJOR_NUM >= 44 + const UNormalizer2 *norm = rspamd_get_unicode_normalizer(); + gint32 end; + + /* We can now check if we need to decompose */ + end = unorm2_spanQuickCheckYes(norm, tmpbuf, ulen, &uc_err); + + if (!U_SUCCESS(uc_err)) { + rspamd_uchars_to_ucs32(tmpbuf, ulen, tok, pool); + tok->normalized.begin = NULL; + tok->normalized.len = 0; + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE; + } + else { + if (end == ulen) { + /* Already normalised, just lowercase */ + rspamd_uchars_to_ucs32(tmpbuf, ulen, tok, pool); + rspamd_ucs32_to_normalised(tok, pool); + } + else { + /* Perform normalization */ + UChar normbuf[1024]; + + g_assert(end < G_N_ELEMENTS(normbuf)); + /* First part */ + memcpy(normbuf, tmpbuf, end * sizeof(UChar)); + /* Second part */ + ulen = unorm2_normalizeSecondAndAppend(norm, + normbuf, end, + G_N_ELEMENTS(normbuf), + tmpbuf + end, + ulen - end, + &uc_err); + + if (!U_SUCCESS(uc_err)) { + if (uc_err != U_BUFFER_OVERFLOW_ERROR) { + msg_warn_pool_check("cannot normalise text '%*s': %s", + (gint) tok->original.len, tok->original.begin, + u_errorName(uc_err)); + rspamd_uchars_to_ucs32(tmpbuf, ulen, tok, pool); + rspamd_ucs32_to_normalised(tok, pool); + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE; + } + } + else { + /* Copy normalised back */ + rspamd_uchars_to_ucs32(normbuf, ulen, tok, pool); + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_NORMALISED; + rspamd_ucs32_to_normalised(tok, pool); + } + } + } +#else + /* Legacy version with no unorm2 interface */ + rspamd_uchars_to_ucs32(tmpbuf, ulen, tok, pool); + rspamd_ucs32_to_normalised(tok, pool); +#endif + } + } + else { + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT) { + /* Simple lowercase */ + gchar *dest; + + dest = rspamd_mempool_alloc(pool, tok->original.len + 1); + rspamd_strlcpy(dest, tok->original.begin, tok->original.len + 1); + rspamd_str_lc(dest, tok->original.len); + tok->normalized.len = tok->original.len; + tok->normalized.begin = dest; + } + } +} + +void rspamd_normalize_words(GArray *words, rspamd_mempool_t *pool) +{ + rspamd_stat_token_t *tok; + guint i; + + for (i = 0; i < words->len; i++) { + tok = &g_array_index(words, rspamd_stat_token_t, i); + rspamd_normalize_single_word(tok, pool); + } +} + +void rspamd_stem_words(GArray *words, rspamd_mempool_t *pool, + const gchar *language, + struct rspamd_lang_detector *lang_detector) +{ + static GHashTable *stemmers = NULL; + struct sb_stemmer *stem = NULL; + guint i; + rspamd_stat_token_t *tok; + gchar *dest; + gsize dlen; + + if (!stemmers) { + stemmers = g_hash_table_new(rspamd_strcase_hash, + rspamd_strcase_equal); + } + + if (language && language[0] != '\0') { + stem = g_hash_table_lookup(stemmers, language); + + if (stem == NULL) { + + stem = sb_stemmer_new(language, "UTF_8"); + + if (stem == NULL) { + msg_debug_pool( + "cannot create lemmatizer for %s language", + language); + g_hash_table_insert(stemmers, g_strdup(language), + GINT_TO_POINTER(-1)); + } + else { + g_hash_table_insert(stemmers, g_strdup(language), + stem); + } + } + else if (stem == GINT_TO_POINTER(-1)) { + /* Negative cache */ + stem = NULL; + } + } + for (i = 0; i < words->len; i++) { + tok = &g_array_index(words, rspamd_stat_token_t, i); + + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UTF) { + if (stem) { + const gchar *stemmed = NULL; + + stemmed = sb_stemmer_stem(stem, + tok->normalized.begin, tok->normalized.len); + + dlen = sb_stemmer_length(stem); + + if (stemmed != NULL && dlen > 0) { + dest = rspamd_mempool_alloc(pool, dlen); + memcpy(dest, stemmed, dlen); + tok->stemmed.len = dlen; + tok->stemmed.begin = dest; + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_STEMMED; + } + else { + /* Fallback */ + tok->stemmed.len = tok->normalized.len; + tok->stemmed.begin = tok->normalized.begin; + } + } + else { + tok->stemmed.len = tok->normalized.len; + tok->stemmed.begin = tok->normalized.begin; + } + + if (tok->stemmed.len > 0 && lang_detector != NULL && + rspamd_language_detector_is_stop_word(lang_detector, tok->stemmed.begin, tok->stemmed.len)) { + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_STOP_WORD; + } + } + else { + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT) { + /* Raw text, lowercase */ + tok->stemmed.len = tok->normalized.len; + tok->stemmed.begin = tok->normalized.begin; + } + } + } +}
\ No newline at end of file diff --git a/src/libstat/tokenizers/tokenizers.h b/src/libstat/tokenizers/tokenizers.h new file mode 100644 index 0000000..d696364 --- /dev/null +++ b/src/libstat/tokenizers/tokenizers.h @@ -0,0 +1,100 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TOKENIZERS_H +#define TOKENIZERS_H + +#include "config.h" +#include "mem_pool.h" +#include "fstring.h" +#include "rspamd.h" +#include "stat_api.h" + +#include <unicode/utext.h> + +#define RSPAMD_DEFAULT_TOKENIZER "osb" + +#ifdef __cplusplus +extern "C" { +#endif + +struct rspamd_tokenizer_runtime; +struct rspamd_stat_ctx; + +/* Common tokenizer structure */ +struct rspamd_stat_tokenizer { + gchar *name; + + gpointer (*get_config)(rspamd_mempool_t *pool, + struct rspamd_tokenizer_config *cf, gsize *len); + + gint (*tokenize_func)(struct rspamd_stat_ctx *ctx, + struct rspamd_task *task, + GArray *words, + gboolean is_utf, + const gchar *prefix, + GPtrArray *result); +}; + +enum rspamd_tokenize_type { + RSPAMD_TOKENIZE_UTF = 0, + RSPAMD_TOKENIZE_RAW, + RSPAMD_TOKENIZE_UNICODE +}; + +/* Compare two token nodes */ +gint token_node_compare_func(gconstpointer a, gconstpointer b); + + +/* Tokenize text into array of words (rspamd_stat_token_t type) */ +GArray *rspamd_tokenize_text(const gchar *text, gsize len, + const UText *utxt, + enum rspamd_tokenize_type how, + struct rspamd_config *cfg, + GList *exceptions, + guint64 *hash, + GArray *cur_words, + rspamd_mempool_t *pool); + +/* OSB tokenize function */ +gint rspamd_tokenizer_osb(struct rspamd_stat_ctx *ctx, + struct rspamd_task *task, + GArray *words, + gboolean is_utf, + const gchar *prefix, + GPtrArray *result); + +gpointer rspamd_tokenizer_osb_get_config(rspamd_mempool_t *pool, + struct rspamd_tokenizer_config *cf, + gsize *len); + +struct rspamd_lang_detector; + +void rspamd_normalize_single_word(rspamd_stat_token_t *tok, rspamd_mempool_t *pool); + +void rspamd_normalize_words(GArray *words, rspamd_mempool_t *pool); + +void rspamd_stem_words(GArray *words, rspamd_mempool_t *pool, + const gchar *language, + struct rspamd_lang_detector *lang_detector); + +void rspamd_tokenize_meta_words(struct rspamd_task *task); + +#ifdef __cplusplus +} +#endif + +#endif |