summaryrefslogtreecommitdiffstats
path: root/src/libstat
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-10 21:30:40 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-10 21:30:40 +0000
commit133a45c109da5310add55824db21af5239951f93 (patch)
treeba6ac4c0a950a0dda56451944315d66409923918 /src/libstat
parentInitial commit. (diff)
downloadrspamd-133a45c109da5310add55824db21af5239951f93.tar.xz
rspamd-133a45c109da5310add55824db21af5239951f93.zip
Adding upstream version 3.8.1.upstream/3.8.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/libstat')
-rw-r--r--src/libstat/CMakeLists.txt25
-rw-r--r--src/libstat/backends/backends.h127
-rw-r--r--src/libstat/backends/cdb_backend.cxx491
-rw-r--r--src/libstat/backends/http_backend.cxx440
-rw-r--r--src/libstat/backends/mmaped_file.c1113
-rw-r--r--src/libstat/backends/redis_backend.cxx1132
-rw-r--r--src/libstat/backends/sqlite3_backend.c907
-rw-r--r--src/libstat/classifiers/bayes.c551
-rw-r--r--src/libstat/classifiers/classifiers.h109
-rw-r--r--src/libstat/classifiers/lua_classifier.c237
-rw-r--r--src/libstat/learn_cache/learn_cache.h79
-rw-r--r--src/libstat/learn_cache/redis_cache.cxx254
-rw-r--r--src/libstat/learn_cache/sqlite3_cache.c274
-rw-r--r--src/libstat/stat_api.h147
-rw-r--r--src/libstat/stat_config.c603
-rw-r--r--src/libstat/stat_internal.h134
-rw-r--r--src/libstat/stat_process.c1250
-rw-r--r--src/libstat/tokenizers/osb.c424
-rw-r--r--src/libstat/tokenizers/tokenizers.c955
-rw-r--r--src/libstat/tokenizers/tokenizers.h100
20 files changed, 9352 insertions, 0 deletions
diff --git a/src/libstat/CMakeLists.txt b/src/libstat/CMakeLists.txt
new file mode 100644
index 0000000..64d572a
--- /dev/null
+++ b/src/libstat/CMakeLists.txt
@@ -0,0 +1,25 @@
+# Librspamdserver
+SET(LIBSTATSRC ${CMAKE_CURRENT_SOURCE_DIR}/stat_config.c
+ ${CMAKE_CURRENT_SOURCE_DIR}/stat_process.c)
+
+SET(TOKENIZERSSRC ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/tokenizers.c
+ ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/osb.c)
+
+SET(CLASSIFIERSSRC ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/bayes.c
+ ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/lua_classifier.c)
+
+SET(BACKENDSSRC ${CMAKE_CURRENT_SOURCE_DIR}/backends/mmaped_file.c
+ ${CMAKE_CURRENT_SOURCE_DIR}/backends/sqlite3_backend.c
+ ${CMAKE_CURRENT_SOURCE_DIR}/backends/cdb_backend.cxx
+ ${CMAKE_CURRENT_SOURCE_DIR}/backends/http_backend.cxx
+ ${CMAKE_CURRENT_SOURCE_DIR}/backends/redis_backend.cxx)
+
+SET(CACHESSRC ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/sqlite3_cache.c
+ ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/redis_cache.cxx)
+
+SET(RSPAMD_STAT ${LIBSTATSRC}
+ ${TOKENIZERSSRC}
+ ${CLASSIFIERSSRC}
+ ${BACKENDSSRC}
+ ${CACHESSRC} PARENT_SCOPE)
+
diff --git a/src/libstat/backends/backends.h b/src/libstat/backends/backends.h
new file mode 100644
index 0000000..4b16950
--- /dev/null
+++ b/src/libstat/backends/backends.h
@@ -0,0 +1,127 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef BACKENDS_H_
+#define BACKENDS_H_
+
+#include "config.h"
+#include "ucl.h"
+
+#define RSPAMD_DEFAULT_BACKEND "mmap"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* Forwarded declarations */
+struct rspamd_classifier_config;
+struct rspamd_statfile_config;
+struct rspamd_config;
+struct rspamd_stat_ctx;
+struct rspamd_token_result;
+struct rspamd_statfile;
+struct rspamd_task;
+
+struct rspamd_stat_backend {
+ const char *name;
+ bool read_only;
+
+ gpointer (*init)(struct rspamd_stat_ctx *ctx, struct rspamd_config *cfg,
+ struct rspamd_statfile *st);
+
+ gpointer (*runtime)(struct rspamd_task *task,
+ struct rspamd_statfile_config *stcf,
+ gboolean learn, gpointer ctx,
+ gint id);
+
+ gboolean (*process_tokens)(struct rspamd_task *task, GPtrArray *tokens,
+ gint id,
+ gpointer ctx);
+
+ gboolean (*finalize_process)(struct rspamd_task *task,
+ gpointer runtime, gpointer ctx);
+
+ gboolean (*learn_tokens)(struct rspamd_task *task, GPtrArray *tokens,
+ gint id,
+ gpointer ctx);
+
+ gulong (*total_learns)(struct rspamd_task *task,
+ gpointer runtime, gpointer ctx);
+
+ gboolean (*finalize_learn)(struct rspamd_task *task,
+ gpointer runtime, gpointer ctx, GError **err);
+
+ gulong (*inc_learns)(struct rspamd_task *task,
+ gpointer runtime, gpointer ctx);
+
+ gulong (*dec_learns)(struct rspamd_task *task,
+ gpointer runtime, gpointer ctx);
+
+ ucl_object_t *(*get_stat)(gpointer runtime, gpointer ctx);
+
+ void (*close)(gpointer ctx);
+
+ gpointer (*load_tokenizer_config)(gpointer runtime, gsize *sz);
+
+ gpointer ctx;
+};
+
+#define RSPAMD_STAT_BACKEND_DEF(name) \
+ gpointer rspamd_##name##_init(struct rspamd_stat_ctx *ctx, \
+ struct rspamd_config *cfg, struct rspamd_statfile *st); \
+ gpointer rspamd_##name##_runtime(struct rspamd_task *task, \
+ struct rspamd_statfile_config *stcf, \
+ gboolean learn, gpointer ctx, gint id); \
+ gboolean rspamd_##name##_process_tokens(struct rspamd_task *task, \
+ GPtrArray *tokens, gint id, \
+ gpointer runtime); \
+ gboolean rspamd_##name##_finalize_process(struct rspamd_task *task, \
+ gpointer runtime, \
+ gpointer ctx); \
+ gboolean rspamd_##name##_learn_tokens(struct rspamd_task *task, \
+ GPtrArray *tokens, gint id, \
+ gpointer runtime); \
+ gboolean rspamd_##name##_finalize_learn(struct rspamd_task *task, \
+ gpointer runtime, \
+ gpointer ctx, GError **err); \
+ gulong rspamd_##name##_total_learns(struct rspamd_task *task, \
+ gpointer runtime, \
+ gpointer ctx); \
+ gulong rspamd_##name##_inc_learns(struct rspamd_task *task, \
+ gpointer runtime, \
+ gpointer ctx); \
+ gulong rspamd_##name##_dec_learns(struct rspamd_task *task, \
+ gpointer runtime, \
+ gpointer ctx); \
+ gulong rspamd_##name##_learns(struct rspamd_task *task, \
+ gpointer runtime, \
+ gpointer ctx); \
+ ucl_object_t *rspamd_##name##_get_stat(gpointer runtime, \
+ gpointer ctx); \
+ gpointer rspamd_##name##_load_tokenizer_config(gpointer runtime, \
+ gsize *len); \
+ void rspamd_##name##_close(gpointer ctx)
+
+RSPAMD_STAT_BACKEND_DEF(mmaped_file);
+RSPAMD_STAT_BACKEND_DEF(sqlite3);
+RSPAMD_STAT_BACKEND_DEF(cdb);
+RSPAMD_STAT_BACKEND_DEF(redis);
+RSPAMD_STAT_BACKEND_DEF(http);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* BACKENDS_H_ */
diff --git a/src/libstat/backends/cdb_backend.cxx b/src/libstat/backends/cdb_backend.cxx
new file mode 100644
index 0000000..81d87f3
--- /dev/null
+++ b/src/libstat/backends/cdb_backend.cxx
@@ -0,0 +1,491 @@
+/*-
+ * Copyright 2021 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * CDB read only statistics backend
+ */
+
+#include "config.h"
+#include "stat_internal.h"
+#include "contrib/cdb/cdb.h"
+
+#include <utility>
+#include <memory>
+#include <string>
+#include <optional>
+#include "contrib/expected/expected.hpp"
+#include "contrib/ankerl/unordered_dense.h"
+#include "fmt/core.h"
+
+namespace rspamd::stat::cdb {
+
+/*
+ * Utility class to share cdb instances over statfiles instances, as each
+ * cdb has tokens for both ham and spam classes
+ */
+class cdb_shared_storage {
+public:
+ using cdb_element_t = std::shared_ptr<struct cdb>;
+ cdb_shared_storage() = default;
+
+ auto get_cdb(const char *path) const -> std::optional<cdb_element_t>
+ {
+ auto found = elts.find(path);
+
+ if (found != elts.end()) {
+ if (!found->second.expired()) {
+ return found->second.lock();
+ }
+ }
+
+ return std::nullopt;
+ }
+ /* Create a new smart pointer over POD cdb structure */
+ static auto new_cdb() -> cdb_element_t
+ {
+ auto ret = cdb_element_t(new struct cdb, cdb_deleter());
+ memset(ret.get(), 0, sizeof(struct cdb));
+ return ret;
+ }
+ /* Enclose cdb into storage */
+ auto push_cdb(const char *path, cdb_element_t cdbp) -> cdb_element_t
+ {
+ auto found = elts.find(path);
+
+ if (found != elts.end()) {
+ if (found->second.expired()) {
+ /* OK, move in lieu of the expired weak pointer */
+
+ found->second = cdbp;
+ return cdbp;
+ }
+ else {
+ /*
+ * Existing and not expired, return the existing one
+ */
+ return found->second.lock();
+ }
+ }
+ else {
+ /* Not existing, make a weak ptr and return the original */
+ elts.emplace(path, std::weak_ptr<struct cdb>(cdbp));
+ return cdbp;
+ }
+ }
+
+private:
+ /*
+ * We store weak pointers here to allow owning cdb statfiles to free
+ * expensive cdb before this cache is terminated (e.g. on dynamic cdb reload)
+ */
+ ankerl::unordered_dense::map<std::string, std::weak_ptr<struct cdb>> elts;
+
+ struct cdb_deleter {
+ void operator()(struct cdb *c) const
+ {
+ cdb_free(c);
+ delete c;
+ }
+ };
+};
+
+static cdb_shared_storage cdb_shared_storage;
+
+class ro_backend final {
+public:
+ explicit ro_backend(struct rspamd_statfile *_st, cdb_shared_storage::cdb_element_t _db)
+ : st(_st), db(std::move(_db))
+ {
+ }
+ ro_backend() = delete;
+ ro_backend(const ro_backend &) = delete;
+ ro_backend(ro_backend &&other) noexcept
+ {
+ *this = std::move(other);
+ }
+ ro_backend &operator=(ro_backend &&other) noexcept
+ {
+ std::swap(st, other.st);
+ std::swap(db, other.db);
+ std::swap(loaded, other.loaded);
+ std::swap(learns_spam, other.learns_spam);
+ std::swap(learns_ham, other.learns_ham);
+
+ return *this;
+ }
+ ~ro_backend()
+ {
+ }
+
+ auto load_cdb() -> tl::expected<bool, std::string>;
+ auto process_token(const rspamd_token_t *tok) const -> std::optional<float>;
+ constexpr auto is_spam() const -> bool
+ {
+ return st->stcf->is_spam;
+ }
+ auto get_learns() const -> std::uint64_t
+ {
+ if (is_spam()) {
+ return learns_spam;
+ }
+ else {
+ return learns_ham;
+ }
+ }
+ auto get_total_learns() const -> std::uint64_t
+ {
+ return learns_spam + learns_ham;
+ }
+
+private:
+ struct rspamd_statfile *st;
+ cdb_shared_storage::cdb_element_t db;
+ bool loaded = false;
+ std::uint64_t learns_spam = 0;
+ std::uint64_t learns_ham = 0;
+};
+
+template<typename T>
+static inline auto
+cdb_get_key_as_int64(struct cdb *cdb, T key) -> std::optional<std::int64_t>
+{
+ auto pos = cdb_find(cdb, (void *) &key, sizeof(key));
+
+ if (pos > 0) {
+ auto vpos = cdb_datapos(cdb);
+ auto vlen = cdb_datalen(cdb);
+
+ if (vlen == sizeof(std::int64_t)) {
+ std::int64_t ret;
+ cdb_read(cdb, (void *) &ret, vlen, vpos);
+
+ return ret;
+ }
+ }
+
+ return std::nullopt;
+}
+
+template<typename T>
+static inline auto
+cdb_get_key_as_float_pair(struct cdb *cdb, T key) -> std::optional<std::pair<float, float>>
+{
+ auto pos = cdb_find(cdb, (void *) &key, sizeof(key));
+
+ if (pos > 0) {
+ auto vpos = cdb_datapos(cdb);
+ auto vlen = cdb_datalen(cdb);
+
+ if (vlen == sizeof(float) * 2) {
+ union {
+ struct {
+ float v1;
+ float v2;
+ } d;
+ char c[sizeof(float) * 2];
+ } u;
+ cdb_read(cdb, (void *) u.c, vlen, vpos);
+
+ return std::make_pair(u.d.v1, u.d.v2);
+ }
+ }
+
+ return std::nullopt;
+}
+
+
+auto ro_backend::load_cdb() -> tl::expected<bool, std::string>
+{
+ if (!db) {
+ return tl::make_unexpected("no database loaded");
+ }
+
+ /* Now get number of learns */
+ std::int64_t cdb_key;
+ static const char learn_spam_key[9] = "_lrnspam", learn_ham_key[9] = "_lrnham_";
+
+ auto check_key = [&](const char *key, std::uint64_t &target) -> tl::expected<bool, std::string> {
+ memcpy((void *) &cdb_key, key, sizeof(cdb_key));
+
+ auto maybe_value = cdb_get_key_as_int64(db.get(), cdb_key);
+
+ if (!maybe_value) {
+ return tl::make_unexpected(fmt::format("missing {} key", key));
+ }
+
+ target = (std::uint64_t) maybe_value.value();
+
+ return true;
+ };
+
+ auto res = check_key(learn_spam_key, learns_spam);
+
+ if (!res) {
+ return res;
+ }
+
+ res = check_key(learn_ham_key, learns_ham);
+
+ if (!res) {
+ return res;
+ }
+
+ loaded = true;
+
+ return true;// expected
+}
+
+auto ro_backend::process_token(const rspamd_token_t *tok) const -> std::optional<float>
+{
+ if (!loaded) {
+ return std::nullopt;
+ }
+
+ auto maybe_value = cdb_get_key_as_float_pair(db.get(), tok->data);
+
+ if (maybe_value) {
+ auto [spam_count, ham_count] = maybe_value.value();
+
+ if (is_spam()) {
+ return spam_count;
+ }
+ else {
+ return ham_count;
+ }
+ }
+
+ return std::nullopt;
+}
+
+auto open_cdb(struct rspamd_statfile *st) -> tl::expected<ro_backend, std::string>
+{
+ const char *path = nullptr;
+ const auto *stf = st->stcf;
+
+ auto get_filename = [](const ucl_object_t *obj) -> const char * {
+ const auto *filename = ucl_object_lookup_any(obj,
+ "filename", "path", "cdb", nullptr);
+
+ if (filename && ucl_object_type(filename) == UCL_STRING) {
+ return ucl_object_tostring(filename);
+ }
+
+ return nullptr;
+ };
+
+ /* First search in backend configuration */
+ const auto *obj = ucl_object_lookup(st->classifier->cfg->opts, "backend");
+ if (obj != NULL && ucl_object_type(obj) == UCL_OBJECT) {
+ path = get_filename(obj);
+ }
+
+ /* Now try statfiles config */
+ if (!path && stf->opts) {
+ path = get_filename(stf->opts);
+ }
+
+ /* Now try classifier config */
+ if (!path && st->classifier->cfg->opts) {
+ path = get_filename(st->classifier->cfg->opts);
+ }
+
+ if (!path) {
+ return tl::make_unexpected("missing/malformed filename attribute");
+ }
+
+ auto cached_cdb_maybe = cdb_shared_storage.get_cdb(path);
+ cdb_shared_storage::cdb_element_t cdbp;
+
+ if (!cached_cdb_maybe) {
+
+ auto fd = rspamd_file_xopen(path, O_RDONLY, 0, true);
+
+ if (fd == -1) {
+ return tl::make_unexpected(fmt::format("cannot open {}: {}",
+ path, strerror(errno)));
+ }
+
+ cdbp = cdb_shared_storage::new_cdb();
+
+ if (cdb_init(cdbp.get(), fd) == -1) {
+ close(fd);
+
+ return tl::make_unexpected(fmt::format("cannot init cdb in {}: {}",
+ path, strerror(errno)));
+ }
+
+ cdbp = cdb_shared_storage.push_cdb(path, cdbp);
+
+ close(fd);
+ }
+ else {
+ cdbp = cached_cdb_maybe.value();
+ }
+
+ if (!cdbp) {
+ return tl::make_unexpected(fmt::format("cannot init cdb in {}: internal error",
+ path));
+ }
+
+ ro_backend bk{st, std::move(cdbp)};
+
+ auto res = bk.load_cdb();
+
+ if (!res) {
+ return tl::make_unexpected(res.error());
+ }
+
+ return bk;
+}
+
+}// namespace rspamd::stat::cdb
+
+#define CDB_FROM_RAW(p) (reinterpret_cast<rspamd::stat::cdb::ro_backend *>(p))
+
+/* C exports */
+gpointer
+rspamd_cdb_init(struct rspamd_stat_ctx *ctx,
+ struct rspamd_config *cfg,
+ struct rspamd_statfile *st)
+{
+ auto maybe_backend = rspamd::stat::cdb::open_cdb(st);
+
+ if (maybe_backend) {
+ /* Move into a new pointer */
+ auto *result = new rspamd::stat::cdb::ro_backend(std::move(maybe_backend.value()));
+
+ return result;
+ }
+ else {
+ msg_err_config("cannot load cdb backend: %s", maybe_backend.error().c_str());
+ }
+
+ return nullptr;
+}
+gpointer
+rspamd_cdb_runtime(struct rspamd_task *task,
+ struct rspamd_statfile_config *stcf,
+ gboolean learn,
+ gpointer ctx,
+ gint _id)
+{
+ /* In CDB we don't have any dynamic stuff */
+ return ctx;
+}
+
+gboolean
+rspamd_cdb_process_tokens(struct rspamd_task *task,
+ GPtrArray *tokens,
+ gint id,
+ gpointer runtime)
+{
+ auto *cdbp = CDB_FROM_RAW(runtime);
+ bool seen_values = false;
+
+ for (auto i = 0u; i < tokens->len; i++) {
+ rspamd_token_t *tok;
+ tok = reinterpret_cast<rspamd_token_t *>(g_ptr_array_index(tokens, i));
+
+ auto res = cdbp->process_token(tok);
+
+ if (res) {
+ tok->values[id] = res.value();
+ seen_values = true;
+ }
+ else {
+ tok->values[id] = 0;
+ }
+ }
+
+ if (seen_values) {
+ if (cdbp->is_spam()) {
+ task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS;
+ }
+ else {
+ task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
+ }
+ }
+
+ return true;
+}
+gboolean
+rspamd_cdb_finalize_process(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ return true;
+}
+gboolean
+rspamd_cdb_learn_tokens(struct rspamd_task *task,
+ GPtrArray *tokens,
+ gint id,
+ gpointer ctx)
+{
+ return false;
+}
+gboolean
+rspamd_cdb_finalize_learn(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx,
+ GError **err)
+{
+ return false;
+}
+
+gulong rspamd_cdb_total_learns(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ auto *cdbp = CDB_FROM_RAW(ctx);
+ return cdbp->get_total_learns();
+}
+gulong
+rspamd_cdb_inc_learns(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ return (gulong) -1;
+}
+gulong
+rspamd_cdb_dec_learns(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ return (gulong) -1;
+}
+gulong
+rspamd_cdb_learns(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ auto *cdbp = CDB_FROM_RAW(ctx);
+ return cdbp->get_learns();
+}
+ucl_object_t *
+rspamd_cdb_get_stat(gpointer runtime, gpointer ctx)
+{
+ return nullptr;
+}
+gpointer
+rspamd_cdb_load_tokenizer_config(gpointer runtime, gsize *len)
+{
+ return nullptr;
+}
+void rspamd_cdb_close(gpointer ctx)
+{
+ auto *cdbp = CDB_FROM_RAW(ctx);
+ delete cdbp;
+} \ No newline at end of file
diff --git a/src/libstat/backends/http_backend.cxx b/src/libstat/backends/http_backend.cxx
new file mode 100644
index 0000000..075e508
--- /dev/null
+++ b/src/libstat/backends/http_backend.cxx
@@ -0,0 +1,440 @@
+/*-
+ * Copyright 2022 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "config.h"
+#include "stat_internal.h"
+#include "libserver/http/http_connection.h"
+#include "libserver/mempool_vars_internal.h"
+#include "upstream.h"
+#include "contrib/ankerl/unordered_dense.h"
+#include <algorithm>
+#include <vector>
+
+namespace rspamd::stat::http {
+
+#define msg_debug_stat_http(...) rspamd_conditional_debug_fast(NULL, NULL, \
+ rspamd_stat_http_log_id, "stat_http", task->task_pool->tag.uid, \
+ RSPAMD_LOG_FUNC, \
+ __VA_ARGS__)
+
+INIT_LOG_MODULE(stat_http)
+
+/* Represents all http backends defined in some configuration */
+class http_backends_collection {
+ std::vector<struct rspamd_statfile *> backends;
+ double timeout = 1.0; /* Default timeout */
+ struct upstream_list *read_servers = nullptr;
+ struct upstream_list *write_servers = nullptr;
+
+public:
+ static auto get() -> http_backends_collection &
+ {
+ static http_backends_collection *singleton = nullptr;
+
+ if (singleton == nullptr) {
+ singleton = new http_backends_collection;
+ }
+
+ return *singleton;
+ }
+
+ /**
+ * Add a new backend and (optionally initialize the basic backend parameters
+ * @param ctx
+ * @param cfg
+ * @param st
+ * @return
+ */
+ auto add_backend(struct rspamd_stat_ctx *ctx,
+ struct rspamd_config *cfg,
+ struct rspamd_statfile *st) -> bool;
+ /**
+ * Remove a statfile cleaning things up if the last statfile is removed
+ * @param st
+ * @return
+ */
+ auto remove_backend(struct rspamd_statfile *st) -> bool;
+
+ upstream *get_upstream(bool is_learn);
+
+private:
+ http_backends_collection() = default;
+ auto first_init(struct rspamd_stat_ctx *ctx,
+ struct rspamd_config *cfg,
+ struct rspamd_statfile *st) -> bool;
+};
+
+/*
+ * Created one per each task
+ */
+class http_backend_runtime final {
+public:
+ static auto create(struct rspamd_task *task, bool is_learn) -> http_backend_runtime *;
+ /* Add a new statfile with a specific id to the list of statfiles */
+ auto notice_statfile(int id, const struct rspamd_statfile_config *st) -> void
+ {
+ seen_statfiles[id] = st;
+ }
+
+ auto process_tokens(struct rspamd_task *task,
+ GPtrArray *tokens,
+ gint id,
+ bool learn) -> bool;
+
+private:
+ http_backends_collection *all_backends;
+ ankerl::unordered_dense::map<int, const struct rspamd_statfile_config *> seen_statfiles;
+ struct upstream *selected;
+
+private:
+ http_backend_runtime(struct rspamd_task *task, bool is_learn)
+ : all_backends(&http_backends_collection::get())
+ {
+ selected = all_backends->get_upstream(is_learn);
+ }
+ ~http_backend_runtime() = default;
+ static auto dtor(void *p) -> void
+ {
+ ((http_backend_runtime *) p)->~http_backend_runtime();
+ }
+};
+
+/*
+ * Efficient way to make a messagepack payload from stat tokens,
+ * avoiding any intermediate libraries, as we would send many tokens
+ * all together
+ */
+static auto
+stat_tokens_to_msgpack(GPtrArray *tokens) -> std::vector<std::uint8_t>
+{
+ std::vector<std::uint8_t> ret;
+ rspamd_token_t *cur;
+ int i;
+
+ /*
+ * We define array, it's size and N elements each is uint64_t
+ * Layout:
+ * 0xdd - array marker
+ * [4 bytes be] - size of the array
+ * [ 0xcf + <8 bytes BE integer>] * N - array elements
+ */
+ ret.resize(tokens->len * (sizeof(std::uint64_t) + 1) + 5);
+ ret.push_back('\xdd');
+ std::uint32_t ulen = GUINT32_TO_BE(tokens->len);
+ std::copy((const std::uint8_t *) &ulen,
+ ((const std::uint8_t *) &ulen) + sizeof(ulen), std::back_inserter(ret));
+
+ PTR_ARRAY_FOREACH(tokens, i, cur)
+ {
+ ret.push_back('\xcf');
+ std::uint64_t val = GUINT64_TO_BE(cur->data);
+ std::copy((const std::uint8_t *) &val,
+ ((const std::uint8_t *) &val) + sizeof(val), std::back_inserter(ret));
+ }
+
+ return ret;
+}
+
+auto http_backend_runtime::create(struct rspamd_task *task, bool is_learn) -> http_backend_runtime *
+{
+ /* Alloc type provide proper size and alignment */
+ auto *allocated_runtime = rspamd_mempool_alloc_type(task->task_pool, http_backend_runtime);
+
+ rspamd_mempool_add_destructor(task->task_pool, http_backend_runtime::dtor, allocated_runtime);
+
+ return new (allocated_runtime) http_backend_runtime{task, is_learn};
+}
+
+auto http_backend_runtime::process_tokens(struct rspamd_task *task, GPtrArray *tokens, gint id, bool learn) -> bool
+{
+ if (!learn) {
+ if (id == seen_statfiles.size() - 1) {
+ /* Emit http request on the last statfile */
+ }
+ }
+ else {
+ /* On learn we need to learn all statfiles that we were requested to learn */
+ if (seen_statfiles.empty()) {
+ /* Request has been already set, or nothing to learn */
+ return true;
+ }
+ else {
+ seen_statfiles.clear();
+ }
+ }
+
+ return true;
+}
+
+auto http_backends_collection::add_backend(struct rspamd_stat_ctx *ctx,
+ struct rspamd_config *cfg,
+ struct rspamd_statfile *st) -> bool
+{
+ /* On empty list of backends we know that we need to load backend data actually */
+ if (backends.empty()) {
+ if (!first_init(ctx, cfg, st)) {
+ return false;
+ }
+ }
+
+ backends.push_back(st);
+
+ return true;
+}
+
+auto http_backends_collection::first_init(struct rspamd_stat_ctx *ctx,
+ struct rspamd_config *cfg,
+ struct rspamd_statfile *st) -> bool
+{
+ auto try_load_backend_config = [&](const ucl_object_t *obj) -> bool {
+ if (!obj || ucl_object_type(obj) != UCL_OBJECT) {
+ return false;
+ }
+
+ /* First try to load read servers */
+ auto *rs = ucl_object_lookup_any(obj, "read_servers", "servers", nullptr);
+ if (rs) {
+ read_servers = rspamd_upstreams_create(cfg->ups_ctx);
+
+ if (read_servers == nullptr) {
+ return false;
+ }
+
+ if (!rspamd_upstreams_from_ucl(read_servers, rs, 80, this)) {
+ rspamd_upstreams_destroy(read_servers);
+ return false;
+ }
+ }
+ auto *ws = ucl_object_lookup_any(obj, "write_servers", "servers", nullptr);
+ if (ws) {
+ write_servers = rspamd_upstreams_create(cfg->ups_ctx);
+
+ if (write_servers == nullptr) {
+ return false;
+ }
+
+ if (!rspamd_upstreams_from_ucl(write_servers, rs, 80, this)) {
+ rspamd_upstreams_destroy(write_servers);
+ return false;
+ }
+ }
+
+ auto *tim = ucl_object_lookup(obj, "timeout");
+
+ if (tim) {
+ timeout = ucl_object_todouble(tim);
+ }
+
+ return true;
+ };
+
+ auto ret = false;
+ auto obj = ucl_object_lookup(st->classifier->cfg->opts, "backend");
+ if (obj != nullptr) {
+ ret = try_load_backend_config(obj);
+ }
+
+ /* Now try statfiles config */
+ if (!ret && st->stcf->opts) {
+ ret = try_load_backend_config(st->stcf->opts);
+ }
+
+ /* Now try classifier config */
+ if (!ret && st->classifier->cfg->opts) {
+ ret = try_load_backend_config(st->classifier->cfg->opts);
+ }
+
+ return ret;
+}
+
+auto http_backends_collection::remove_backend(struct rspamd_statfile *st) -> bool
+{
+ auto backend_it = std::remove(std::begin(backends), std::end(backends), st);
+
+ if (backend_it != std::end(backends)) {
+ /* Fast erasure with no order preservation */
+ std::swap(*backend_it, backends.back());
+ backends.pop_back();
+
+ if (backends.empty()) {
+ /* De-init collection - likely config reload */
+ if (read_servers) {
+ rspamd_upstreams_destroy(read_servers);
+ read_servers = nullptr;
+ }
+
+ if (write_servers) {
+ rspamd_upstreams_destroy(write_servers);
+ write_servers = nullptr;
+ }
+ }
+
+ return true;
+ }
+
+ return false;
+}
+
+upstream *http_backends_collection::get_upstream(bool is_learn)
+{
+ auto *ups_list = read_servers;
+ if (is_learn) {
+ ups_list = write_servers;
+ }
+
+ return rspamd_upstream_get(ups_list, RSPAMD_UPSTREAM_ROUND_ROBIN, nullptr, 0);
+}
+
+}// namespace rspamd::stat::http
+
+/* C API */
+
+gpointer
+rspamd_http_init(struct rspamd_stat_ctx *ctx,
+ struct rspamd_config *cfg,
+ struct rspamd_statfile *st)
+{
+ auto &collections = rspamd::stat::http::http_backends_collection::get();
+
+ if (!collections.add_backend(ctx, cfg, st)) {
+ msg_err_config("cannot load http backend");
+
+ return nullptr;
+ }
+
+ return (void *) &collections;
+}
+gpointer
+rspamd_http_runtime(struct rspamd_task *task,
+ struct rspamd_statfile_config *stcf,
+ gboolean learn,
+ gpointer ctx,
+ gint id)
+{
+ auto maybe_existing = rspamd_mempool_get_variable(task->task_pool, RSPAMD_MEMPOOL_HTTP_STAT_BACKEND_RUNTIME);
+
+ if (maybe_existing != nullptr) {
+ auto real_runtime = (rspamd::stat::http::http_backend_runtime *) maybe_existing;
+ real_runtime->notice_statfile(id, stcf);
+
+ return maybe_existing;
+ }
+
+ auto runtime = rspamd::stat::http::http_backend_runtime::create(task, learn);
+
+ if (runtime) {
+ runtime->notice_statfile(id, stcf);
+ rspamd_mempool_set_variable(task->task_pool, RSPAMD_MEMPOOL_HTTP_STAT_BACKEND_RUNTIME,
+ (void *) runtime, nullptr);
+ }
+
+ return (void *) runtime;
+}
+
+gboolean
+rspamd_http_process_tokens(struct rspamd_task *task,
+ GPtrArray *tokens,
+ gint id,
+ gpointer runtime)
+{
+ auto real_runtime = (rspamd::stat::http::http_backend_runtime *) runtime;
+
+ if (real_runtime) {
+ return real_runtime->process_tokens(task, tokens, id, false);
+ }
+
+
+ return false;
+}
+gboolean
+rspamd_http_finalize_process(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ /* Not needed */
+ return true;
+}
+
+gboolean
+rspamd_http_learn_tokens(struct rspamd_task *task,
+ GPtrArray *tokens,
+ gint id,
+ gpointer runtime)
+{
+ auto real_runtime = (rspamd::stat::http::http_backend_runtime *) runtime;
+
+ if (real_runtime) {
+ return real_runtime->process_tokens(task, tokens, id, true);
+ }
+
+
+ return false;
+}
+gboolean
+rspamd_http_finalize_learn(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx,
+ GError **err)
+{
+ return false;
+}
+
+gulong rspamd_http_total_learns(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ /* TODO */
+ return 0;
+}
+gulong
+rspamd_http_inc_learns(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ /* TODO */
+ return 0;
+}
+gulong
+rspamd_http_dec_learns(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ /* TODO */
+ return (gulong) -1;
+}
+gulong
+rspamd_http_learns(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ /* TODO */
+ return 0;
+}
+ucl_object_t *
+rspamd_http_get_stat(gpointer runtime, gpointer ctx)
+{
+ /* TODO */
+ return nullptr;
+}
+gpointer
+rspamd_http_load_tokenizer_config(gpointer runtime, gsize *len)
+{
+ return nullptr;
+}
+void rspamd_http_close(gpointer ctx)
+{
+ /* TODO */
+} \ No newline at end of file
diff --git a/src/libstat/backends/mmaped_file.c b/src/libstat/backends/mmaped_file.c
new file mode 100644
index 0000000..5c20207
--- /dev/null
+++ b/src/libstat/backends/mmaped_file.c
@@ -0,0 +1,1113 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "config.h"
+#include "stat_internal.h"
+#include "unix-std.h"
+
+#define CHAIN_LENGTH 128
+
+/* Section types */
+#define STATFILE_SECTION_COMMON 1
+
+/**
+ * Common statfile header
+ */
+struct stat_file_header {
+ u_char magic[3]; /**< magic signature ('r' 's' 'd') */
+ u_char version[2]; /**< version of statfile */
+ u_char padding[3]; /**< padding */
+ guint64 create_time; /**< create time (time_t->guint64) */
+ guint64 revision; /**< revision number */
+ guint64 rev_time; /**< revision time */
+ guint64 used_blocks; /**< used blocks number */
+ guint64 total_blocks; /**< total number of blocks */
+ guint64 tokenizer_conf_len; /**< length of tokenizer configuration */
+ u_char unused[231]; /**< some bytes that can be used in future */
+};
+
+/**
+ * Section header
+ */
+struct stat_file_section {
+ guint64 code; /**< section's code */
+ guint64 length; /**< section's length in blocks */
+};
+
+/**
+ * Block of data in statfile
+ */
+struct stat_file_block {
+ guint32 hash1; /**< hash1 (also acts as index) */
+ guint32 hash2; /**< hash2 */
+ double value; /**< double value */
+};
+
+/**
+ * Statistic file
+ */
+struct stat_file {
+ struct stat_file_header header; /**< header */
+ struct stat_file_section section; /**< first section */
+ struct stat_file_block blocks[1]; /**< first block of data */
+};
+
+/**
+ * Common view of statfile object
+ */
+typedef struct {
+#ifdef HAVE_PATH_MAX
+ gchar filename[PATH_MAX]; /**< name of file */
+#else
+ gchar filename[MAXPATHLEN]; /**< name of file */
+#endif
+ rspamd_mempool_t *pool;
+ gint fd; /**< descriptor */
+ void *map; /**< mmaped area */
+ off_t seek_pos; /**< current seek position */
+ struct stat_file_section cur_section; /**< current section */
+ size_t len; /**< length of file(in bytes) */
+ struct rspamd_statfile_config *cf;
+} rspamd_mmaped_file_t;
+
+
+#define RSPAMD_STATFILE_VERSION \
+ { \
+ '1', '2' \
+ }
+#define BACKUP_SUFFIX ".old"
+
+static void rspamd_mmaped_file_set_block_common(rspamd_mempool_t *pool,
+ rspamd_mmaped_file_t *file,
+ guint32 h1, guint32 h2, double value);
+
+rspamd_mmaped_file_t *rspamd_mmaped_file_open(rspamd_mempool_t *pool,
+ const gchar *filename, size_t size,
+ struct rspamd_statfile_config *stcf);
+gint rspamd_mmaped_file_create(const gchar *filename, size_t size,
+ struct rspamd_statfile_config *stcf,
+ rspamd_mempool_t *pool);
+gint rspamd_mmaped_file_close_file(rspamd_mempool_t *pool,
+ rspamd_mmaped_file_t *file);
+
+double
+rspamd_mmaped_file_get_block(rspamd_mmaped_file_t *file,
+ guint32 h1,
+ guint32 h2)
+{
+ struct stat_file_block *block;
+ guint i, blocknum;
+ u_char *c;
+
+ if (!file->map) {
+ return 0;
+ }
+
+ blocknum = h1 % file->cur_section.length;
+ c = (u_char *) file->map + file->seek_pos + blocknum * sizeof(struct stat_file_block);
+ block = (struct stat_file_block *) c;
+
+ for (i = 0; i < CHAIN_LENGTH; i++) {
+ if (i + blocknum >= file->cur_section.length) {
+ break;
+ }
+ if (block->hash1 == h1 && block->hash2 == h2) {
+ return block->value;
+ }
+ c += sizeof(struct stat_file_block);
+ block = (struct stat_file_block *) c;
+ }
+
+
+ return 0;
+}
+
+static void
+rspamd_mmaped_file_set_block_common(rspamd_mempool_t *pool,
+ rspamd_mmaped_file_t *file,
+ guint32 h1, guint32 h2, double value)
+{
+ struct stat_file_block *block, *to_expire = NULL;
+ struct stat_file_header *header;
+ guint i, blocknum;
+ u_char *c;
+ double min = G_MAXDOUBLE;
+
+ if (!file->map) {
+ return;
+ }
+
+ blocknum = h1 % file->cur_section.length;
+ header = (struct stat_file_header *) file->map;
+ c = (u_char *) file->map + file->seek_pos + blocknum * sizeof(struct stat_file_block);
+ block = (struct stat_file_block *) c;
+
+ for (i = 0; i < CHAIN_LENGTH; i++) {
+ if (i + blocknum >= file->cur_section.length) {
+ /* Need to expire some block in chain */
+ msg_info_pool("chain %ud is full in statfile %s, starting expire",
+ blocknum,
+ file->filename);
+ break;
+ }
+ /* First try to find block in chain */
+ if (block->hash1 == h1 && block->hash2 == h2) {
+ msg_debug_pool("%s found existing block %ud in chain %ud, value %.2f",
+ file->filename,
+ i,
+ blocknum,
+ value);
+ block->value = value;
+ return;
+ }
+ /* Check whether we have a free block in chain */
+ if (block->hash1 == 0 && block->hash2 == 0) {
+ /* Write new block here */
+ msg_debug_pool("%s found free block %ud in chain %ud, set h1=%ud, h2=%ud",
+ file->filename,
+ i,
+ blocknum,
+ h1,
+ h2);
+ block->hash1 = h1;
+ block->hash2 = h2;
+ block->value = value;
+ header->used_blocks++;
+
+ return;
+ }
+
+ /* Expire block with minimum value otherwise */
+ if (block->value < min) {
+ to_expire = block;
+ min = block->value;
+ }
+ c += sizeof(struct stat_file_block);
+ block = (struct stat_file_block *) c;
+ }
+
+ /* Try expire some block */
+ if (to_expire) {
+ block = to_expire;
+ }
+ else {
+ /* Expire first block in chain */
+ c = (u_char *) file->map + file->seek_pos + blocknum * sizeof(struct stat_file_block);
+ block = (struct stat_file_block *) c;
+ }
+
+ block->hash1 = h1;
+ block->hash2 = h2;
+ block->value = value;
+}
+
+void rspamd_mmaped_file_set_block(rspamd_mempool_t *pool,
+ rspamd_mmaped_file_t *file,
+ guint32 h1,
+ guint32 h2,
+ double value)
+{
+ rspamd_mmaped_file_set_block_common(pool, file, h1, h2, value);
+}
+
+gboolean
+rspamd_mmaped_file_set_revision(rspamd_mmaped_file_t *file, guint64 rev, time_t time)
+{
+ struct stat_file_header *header;
+
+ if (file == NULL || file->map == NULL) {
+ return FALSE;
+ }
+
+ header = (struct stat_file_header *) file->map;
+
+ header->revision = rev;
+ header->rev_time = time;
+
+ return TRUE;
+}
+
+gboolean
+rspamd_mmaped_file_inc_revision(rspamd_mmaped_file_t *file)
+{
+ struct stat_file_header *header;
+
+ if (file == NULL || file->map == NULL) {
+ return FALSE;
+ }
+
+ header = (struct stat_file_header *) file->map;
+
+ header->revision++;
+
+ return TRUE;
+}
+
+gboolean
+rspamd_mmaped_file_dec_revision(rspamd_mmaped_file_t *file)
+{
+ struct stat_file_header *header;
+
+ if (file == NULL || file->map == NULL) {
+ return FALSE;
+ }
+
+ header = (struct stat_file_header *) file->map;
+
+ header->revision--;
+
+ return TRUE;
+}
+
+
+gboolean
+rspamd_mmaped_file_get_revision(rspamd_mmaped_file_t *file, guint64 *rev, time_t *time)
+{
+ struct stat_file_header *header;
+
+ if (file == NULL || file->map == NULL) {
+ return FALSE;
+ }
+
+ header = (struct stat_file_header *) file->map;
+
+ if (rev != NULL) {
+ *rev = header->revision;
+ }
+ if (time != NULL) {
+ *time = header->rev_time;
+ }
+
+ return TRUE;
+}
+
+guint64
+rspamd_mmaped_file_get_used(rspamd_mmaped_file_t *file)
+{
+ struct stat_file_header *header;
+
+ if (file == NULL || file->map == NULL) {
+ return (guint64) -1;
+ }
+
+ header = (struct stat_file_header *) file->map;
+
+ return header->used_blocks;
+}
+
+guint64
+rspamd_mmaped_file_get_total(rspamd_mmaped_file_t *file)
+{
+ struct stat_file_header *header;
+
+ if (file == NULL || file->map == NULL) {
+ return (guint64) -1;
+ }
+
+ header = (struct stat_file_header *) file->map;
+
+ /* If total blocks is 0 we have old version of header, so set total blocks correctly */
+ if (header->total_blocks == 0) {
+ header->total_blocks = file->cur_section.length;
+ }
+
+ return header->total_blocks;
+}
+
+/* Check whether specified file is statistic file and calculate its len in blocks */
+static gint
+rspamd_mmaped_file_check(rspamd_mempool_t *pool, rspamd_mmaped_file_t *file)
+{
+ struct stat_file *f;
+ gchar *c;
+ static gchar valid_version[] = RSPAMD_STATFILE_VERSION;
+
+
+ if (!file || !file->map) {
+ return -1;
+ }
+
+ if (file->len < sizeof(struct stat_file)) {
+ msg_info_pool("file %s is too short to be stat file: %z",
+ file->filename,
+ file->len);
+ return -1;
+ }
+
+ f = (struct stat_file *) file->map;
+ c = &f->header.magic[0];
+ /* Check magic and version */
+ if (*c++ != 'r' || *c++ != 's' || *c++ != 'd') {
+ msg_info_pool("file %s is invalid stat file", file->filename);
+ return -1;
+ }
+
+ c = &f->header.version[0];
+ /* Now check version and convert old version to new one (that can be used for sync */
+ if (*c == 1 && *(c + 1) == 0) {
+ return -1;
+ }
+ else if (memcmp(c, valid_version, sizeof(valid_version)) != 0) {
+ /* Unknown version */
+ msg_info_pool("file %s has invalid version %c.%c",
+ file->filename,
+ '0' + *c,
+ '0' + *(c + 1));
+ return -1;
+ }
+
+ /* Check first section and set new offset */
+ file->cur_section.code = f->section.code;
+ file->cur_section.length = f->section.length;
+ if (file->cur_section.length * sizeof(struct stat_file_block) >
+ file->len) {
+ msg_info_pool("file %s is truncated: %z, must be %z",
+ file->filename,
+ file->len,
+ file->cur_section.length * sizeof(struct stat_file_block));
+ return -1;
+ }
+ file->seek_pos = sizeof(struct stat_file) -
+ sizeof(struct stat_file_block);
+
+ return 0;
+}
+
+
+static rspamd_mmaped_file_t *
+rspamd_mmaped_file_reindex(rspamd_mempool_t *pool,
+ const gchar *filename,
+ size_t old_size,
+ size_t size,
+ struct rspamd_statfile_config *stcf)
+{
+ gchar *backup, *lock;
+ gint fd, lock_fd;
+ rspamd_mmaped_file_t *new, *old = NULL;
+ u_char *map, *pos;
+ struct stat_file_block *block;
+ struct stat_file_header *header, *nh;
+ struct timespec sleep_ts = {
+ .tv_sec = 0,
+ .tv_nsec = 1000000};
+
+ if (size <
+ sizeof(struct stat_file_header) + sizeof(struct stat_file_section) +
+ sizeof(block)) {
+ msg_err_pool("file %s is too small to carry any statistic: %z",
+ filename,
+ size);
+ return NULL;
+ }
+
+ lock = g_strconcat(filename, ".lock", NULL);
+ lock_fd = open(lock, O_WRONLY | O_CREAT | O_EXCL, 00600);
+
+ while (lock_fd == -1) {
+ /* Wait for lock */
+ lock_fd = open(lock, O_WRONLY | O_CREAT | O_EXCL, 00600);
+ if (lock_fd != -1) {
+ unlink(lock);
+ close(lock_fd);
+ g_free(lock);
+
+ return rspamd_mmaped_file_open(pool, filename, size, stcf);
+ }
+ else {
+ nanosleep(&sleep_ts, NULL);
+ }
+ }
+
+ backup = g_strconcat(filename, ".old", NULL);
+ if (rename(filename, backup) == -1) {
+ msg_err_pool("cannot rename %s to %s: %s", filename, backup, strerror(errno));
+ g_free(backup);
+ unlink(lock);
+ g_free(lock);
+ close(lock_fd);
+
+ return NULL;
+ }
+
+ old = rspamd_mmaped_file_open(pool, backup, old_size, stcf);
+
+ if (old == NULL) {
+ msg_warn_pool("old file %s is invalid mmapped file, just move it",
+ backup);
+ }
+
+ /* We need to release our lock here */
+ unlink(lock);
+ close(lock_fd);
+ g_free(lock);
+
+ /* Now create new file with required size */
+ if (rspamd_mmaped_file_create(filename, size, stcf, pool) != 0) {
+ msg_err_pool("cannot create new file");
+ rspamd_mmaped_file_close(old);
+ g_free(backup);
+
+ return NULL;
+ }
+
+ new = rspamd_mmaped_file_open(pool, filename, size, stcf);
+
+ if (old) {
+ /* Now open new file and start copying */
+ fd = open(backup, O_RDONLY);
+ if (fd == -1 || new == NULL) {
+ if (fd != -1) {
+ close(fd);
+ }
+
+ msg_err_pool("cannot open file: %s", strerror(errno));
+ rspamd_mmaped_file_close(old);
+ g_free(backup);
+ return NULL;
+ }
+
+
+ /* Now start reading blocks from old statfile */
+ if ((map =
+ mmap(NULL, old_size, PROT_READ, MAP_SHARED, fd, 0)) == MAP_FAILED) {
+ msg_err_pool("cannot mmap file: %s", strerror(errno));
+ close(fd);
+ rspamd_mmaped_file_close(old);
+ g_free(backup);
+ return NULL;
+ }
+
+ pos = map + (sizeof(struct stat_file) - sizeof(struct stat_file_block));
+
+ if (pos - map < (gssize) old_size) {
+ while ((gssize) old_size - (pos - map) >= (gssize) sizeof(struct stat_file_block)) {
+ block = (struct stat_file_block *) pos;
+ if (block->hash1 != 0 && block->value != 0) {
+ rspamd_mmaped_file_set_block_common(pool,
+ new, block->hash1,
+ block->hash2, block->value);
+ }
+ pos += sizeof(block);
+ }
+ }
+
+ header = (struct stat_file_header *) map;
+ rspamd_mmaped_file_set_revision(new, header->revision, header->rev_time);
+ nh = new->map;
+ /* Copy tokenizer configuration */
+ memcpy(nh->unused, header->unused, sizeof(header->unused));
+ nh->tokenizer_conf_len = header->tokenizer_conf_len;
+
+ munmap(map, old_size);
+ close(fd);
+ rspamd_mmaped_file_close_file(pool, old);
+ }
+
+ unlink(backup);
+ g_free(backup);
+
+ return new;
+}
+
+/*
+ * Pre-load mmaped file into memory
+ */
+static void
+rspamd_mmaped_file_preload(rspamd_mmaped_file_t *file)
+{
+ guint8 *pos, *end;
+ volatile guint8 t;
+ gsize size;
+
+ pos = (guint8 *) file->map;
+ end = (guint8 *) file->map + file->len;
+
+ if (madvise(pos, end - pos, MADV_SEQUENTIAL) == -1) {
+ msg_info("madvise failed: %s", strerror(errno));
+ }
+ else {
+ /* Load pages of file */
+#ifdef HAVE_GETPAGESIZE
+ size = getpagesize();
+#else
+ size = sysconf(_SC_PAGESIZE);
+#endif
+ while (pos < end) {
+ t = *pos;
+ (void) t;
+ pos += size;
+ }
+ }
+}
+
+rspamd_mmaped_file_t *
+rspamd_mmaped_file_open(rspamd_mempool_t *pool,
+ const gchar *filename, size_t size,
+ struct rspamd_statfile_config *stcf)
+{
+ struct stat st;
+ rspamd_mmaped_file_t *new_file;
+ gchar *lock;
+ gint lock_fd;
+
+ lock = g_strconcat(filename, ".lock", NULL);
+ lock_fd = open(lock, O_WRONLY | O_CREAT | O_EXCL, 00600);
+
+ if (lock_fd == -1) {
+ g_free(lock);
+ msg_info_pool("cannot open file %s, it is locked by another process",
+ filename);
+ return NULL;
+ }
+
+ close(lock_fd);
+ unlink(lock);
+ g_free(lock);
+
+ if (stat(filename, &st) == -1) {
+ msg_info_pool("cannot stat file %s, error %s, %d", filename, strerror(errno), errno);
+ return NULL;
+ }
+
+ if (labs((glong) size - st.st_size) > (long) sizeof(struct stat_file) * 2 && size > sizeof(struct stat_file)) {
+ msg_warn_pool("need to reindex statfile old size: %Hz, new size: %Hz",
+ (size_t) st.st_size, size);
+ return rspamd_mmaped_file_reindex(pool, filename, st.st_size, size, stcf);
+ }
+ else if (size < sizeof(struct stat_file)) {
+ msg_err_pool("requested to shrink statfile to %Hz but it is too small",
+ size);
+ }
+
+ new_file = g_malloc0(sizeof(rspamd_mmaped_file_t));
+ if ((new_file->fd = open(filename, O_RDWR)) == -1) {
+ msg_info_pool("cannot open file %s, error %d, %s",
+ filename,
+ errno,
+ strerror(errno));
+ g_free(new_file);
+ return NULL;
+ }
+
+ if ((new_file->map =
+ mmap(NULL, st.st_size, PROT_READ | PROT_WRITE, MAP_SHARED,
+ new_file->fd, 0)) == MAP_FAILED) {
+ close(new_file->fd);
+ msg_info_pool("cannot mmap file %s, error %d, %s",
+ filename,
+ errno,
+ strerror(errno));
+ g_free(new_file);
+ return NULL;
+ }
+
+ rspamd_strlcpy(new_file->filename, filename, sizeof(new_file->filename));
+ new_file->len = st.st_size;
+ /* Try to lock pages in RAM */
+
+ /* Acquire lock for this operation */
+ if (!rspamd_file_lock(new_file->fd, FALSE)) {
+ close(new_file->fd);
+ munmap(new_file->map, st.st_size);
+ msg_info_pool("cannot lock file %s, error %d, %s",
+ filename,
+ errno,
+ strerror(errno));
+ g_free(new_file);
+ return NULL;
+ }
+
+ if (rspamd_mmaped_file_check(pool, new_file) == -1) {
+ close(new_file->fd);
+ rspamd_file_unlock(new_file->fd, FALSE);
+ munmap(new_file->map, st.st_size);
+ g_free(new_file);
+ return NULL;
+ }
+
+ rspamd_file_unlock(new_file->fd, FALSE);
+ new_file->cf = stcf;
+ new_file->pool = pool;
+ rspamd_mmaped_file_preload(new_file);
+
+ g_assert(stcf->clcf != NULL);
+
+ msg_debug_pool("opened statfile %s of size %l", filename, (long) size);
+
+ return new_file;
+}
+
+gint rspamd_mmaped_file_close_file(rspamd_mempool_t *pool,
+ rspamd_mmaped_file_t *file)
+{
+ if (file->map) {
+ msg_info_pool("syncing statfile %s", file->filename);
+ msync(file->map, file->len, MS_ASYNC);
+ munmap(file->map, file->len);
+ }
+ if (file->fd != -1) {
+ close(file->fd);
+ }
+
+ g_free(file);
+
+ return 0;
+}
+
+gint rspamd_mmaped_file_create(const gchar *filename,
+ size_t size,
+ struct rspamd_statfile_config *stcf,
+ rspamd_mempool_t *pool)
+{
+ struct stat_file_header header = {
+ .magic = {'r', 's', 'd'},
+ .version = RSPAMD_STATFILE_VERSION,
+ .padding = {0, 0, 0},
+ .revision = 0,
+ .rev_time = 0,
+ .used_blocks = 0};
+ struct stat_file_section section = {
+ .code = STATFILE_SECTION_COMMON,
+ };
+ struct stat_file_block block = {0, 0, 0};
+ struct rspamd_stat_tokenizer *tokenizer;
+ gint fd, lock_fd;
+ guint buflen = 0, nblocks;
+ gchar *buf = NULL, *lock;
+ struct stat sb;
+ gpointer tok_conf;
+ gsize tok_conf_len;
+ struct timespec sleep_ts = {
+ .tv_sec = 0,
+ .tv_nsec = 1000000};
+
+ if (size <
+ sizeof(struct stat_file_header) + sizeof(struct stat_file_section) +
+ sizeof(block)) {
+ msg_err_pool("file %s is too small to carry any statistic: %z",
+ filename,
+ size);
+ return -1;
+ }
+
+ lock = g_strconcat(filename, ".lock", NULL);
+ lock_fd = open(lock, O_WRONLY | O_CREAT | O_EXCL, 00600);
+
+ while (lock_fd == -1) {
+ /* Wait for lock */
+ lock_fd = open(lock, O_WRONLY | O_CREAT | O_EXCL, 00600);
+ if (lock_fd != -1) {
+ if (stat(filename, &sb) != -1) {
+ /* File has been created by some other process */
+ unlink(lock);
+ close(lock_fd);
+ g_free(lock);
+
+ return 0;
+ }
+
+ /* We still need to create it */
+ goto create;
+ }
+ else {
+ nanosleep(&sleep_ts, NULL);
+ }
+ }
+
+create:
+
+ msg_debug_pool("create statfile %s of size %l", filename, (long) size);
+ nblocks =
+ (size - sizeof(struct stat_file_header) -
+ sizeof(struct stat_file_section)) /
+ sizeof(struct stat_file_block);
+ header.total_blocks = nblocks;
+
+ if ((fd =
+ open(filename, O_RDWR | O_TRUNC | O_CREAT, S_IWUSR | S_IRUSR)) == -1) {
+ msg_info_pool("cannot create file %s, error %d, %s",
+ filename,
+ errno,
+ strerror(errno));
+ unlink(lock);
+ close(lock_fd);
+ g_free(lock);
+
+ return -1;
+ }
+
+ rspamd_fallocate(fd,
+ 0,
+ sizeof(header) + sizeof(section) + sizeof(block) * nblocks);
+
+ header.create_time = (guint64) time(NULL);
+ g_assert(stcf->clcf != NULL);
+ g_assert(stcf->clcf->tokenizer != NULL);
+ tokenizer = rspamd_stat_get_tokenizer(stcf->clcf->tokenizer->name);
+ g_assert(tokenizer != NULL);
+ tok_conf = tokenizer->get_config(pool, stcf->clcf->tokenizer, &tok_conf_len);
+ header.tokenizer_conf_len = tok_conf_len;
+ g_assert(tok_conf_len < sizeof(header.unused) - sizeof(guint64));
+ memcpy(header.unused, tok_conf, tok_conf_len);
+
+ if (write(fd, &header, sizeof(header)) == -1) {
+ msg_info_pool("cannot write header to file %s, error %d, %s",
+ filename,
+ errno,
+ strerror(errno));
+ close(fd);
+ unlink(lock);
+ close(lock_fd);
+ g_free(lock);
+
+ return -1;
+ }
+
+ section.length = (guint64) nblocks;
+ if (write(fd, &section, sizeof(section)) == -1) {
+ msg_info_pool("cannot write section header to file %s, error %d, %s",
+ filename,
+ errno,
+ strerror(errno));
+ close(fd);
+ unlink(lock);
+ close(lock_fd);
+ g_free(lock);
+
+ return -1;
+ }
+
+ /* Buffer for write 256 blocks at once */
+ if (nblocks > 256) {
+ buflen = sizeof(block) * 256;
+ buf = g_malloc0(buflen);
+ }
+
+ while (nblocks) {
+ if (nblocks > 256) {
+ /* Just write buffer */
+ if (write(fd, buf, buflen) == -1) {
+ msg_info_pool("cannot write blocks buffer to file %s, error %d, %s",
+ filename,
+ errno,
+ strerror(errno));
+ close(fd);
+ g_free(buf);
+ unlink(lock);
+ close(lock_fd);
+ g_free(lock);
+
+ return -1;
+ }
+ nblocks -= 256;
+ }
+ else {
+ if (write(fd, &block, sizeof(block)) == -1) {
+ msg_info_pool("cannot write block to file %s, error %d, %s",
+ filename,
+ errno,
+ strerror(errno));
+ close(fd);
+ if (buf) {
+ g_free(buf);
+ }
+
+ unlink(lock);
+ close(lock_fd);
+ g_free(lock);
+
+ return -1;
+ }
+ nblocks--;
+ }
+ }
+
+ close(fd);
+
+ if (buf) {
+ g_free(buf);
+ }
+
+ unlink(lock);
+ close(lock_fd);
+ g_free(lock);
+ msg_debug_pool("created statfile %s of size %l", filename, (long) size);
+
+ return 0;
+}
+
+gpointer
+rspamd_mmaped_file_init(struct rspamd_stat_ctx *ctx,
+ struct rspamd_config *cfg, struct rspamd_statfile *st)
+{
+ struct rspamd_statfile_config *stf = st->stcf;
+ rspamd_mmaped_file_t *mf;
+ const ucl_object_t *filenameo, *sizeo;
+ const gchar *filename;
+ gsize size;
+
+ filenameo = ucl_object_lookup(stf->opts, "filename");
+
+ if (filenameo == NULL || ucl_object_type(filenameo) != UCL_STRING) {
+ filenameo = ucl_object_lookup(stf->opts, "path");
+
+ if (filenameo == NULL || ucl_object_type(filenameo) != UCL_STRING) {
+ msg_err_config("statfile %s has no filename defined", stf->symbol);
+ return NULL;
+ }
+ }
+
+ filename = ucl_object_tostring(filenameo);
+
+ sizeo = ucl_object_lookup(stf->opts, "size");
+
+ if (sizeo == NULL || ucl_object_type(sizeo) != UCL_INT) {
+ msg_err_config("statfile %s has no size defined", stf->symbol);
+ return NULL;
+ }
+
+ size = ucl_object_toint(sizeo);
+ mf = rspamd_mmaped_file_open(cfg->cfg_pool, filename, size, stf);
+
+ if (mf != NULL) {
+ mf->pool = cfg->cfg_pool;
+ }
+ else {
+ /* Create file here */
+
+ filenameo = ucl_object_find_key(stf->opts, "filename");
+ if (filenameo == NULL || ucl_object_type(filenameo) != UCL_STRING) {
+ filenameo = ucl_object_find_key(stf->opts, "path");
+ if (filenameo == NULL || ucl_object_type(filenameo) != UCL_STRING) {
+ msg_err_config("statfile %s has no filename defined", stf->symbol);
+ return NULL;
+ }
+ }
+
+ filename = ucl_object_tostring(filenameo);
+
+ sizeo = ucl_object_find_key(stf->opts, "size");
+ if (sizeo == NULL || ucl_object_type(sizeo) != UCL_INT) {
+ msg_err_config("statfile %s has no size defined", stf->symbol);
+ return NULL;
+ }
+
+ size = ucl_object_toint(sizeo);
+
+ if (rspamd_mmaped_file_create(filename, size, stf, cfg->cfg_pool) != 0) {
+ msg_err_config("cannot create new file");
+ }
+
+ mf = rspamd_mmaped_file_open(cfg->cfg_pool, filename, size, stf);
+ }
+
+ return (gpointer) mf;
+}
+
+void rspamd_mmaped_file_close(gpointer p)
+{
+ rspamd_mmaped_file_t *mf = p;
+
+
+ if (mf) {
+ rspamd_mmaped_file_close_file(mf->pool, mf);
+ }
+}
+
+gpointer
+rspamd_mmaped_file_runtime(struct rspamd_task *task,
+ struct rspamd_statfile_config *stcf,
+ gboolean learn,
+ gpointer p,
+ gint _id)
+{
+ rspamd_mmaped_file_t *mf = p;
+
+ return (gpointer) mf;
+}
+
+gboolean
+rspamd_mmaped_file_process_tokens(struct rspamd_task *task, GPtrArray *tokens,
+ gint id,
+ gpointer p)
+{
+ rspamd_mmaped_file_t *mf = p;
+ guint32 h1, h2;
+ rspamd_token_t *tok;
+ guint i;
+
+ g_assert(tokens != NULL);
+ g_assert(p != NULL);
+
+ for (i = 0; i < tokens->len; i++) {
+ tok = g_ptr_array_index(tokens, i);
+ memcpy(&h1, (guchar *) &tok->data, sizeof(h1));
+ memcpy(&h2, ((guchar *) &tok->data) + sizeof(h1), sizeof(h2));
+ tok->values[id] = rspamd_mmaped_file_get_block(mf, h1, h2);
+ }
+
+ if (mf->cf->is_spam) {
+ task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS;
+ }
+ else {
+ task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
+ }
+
+ return TRUE;
+}
+
+gboolean
+rspamd_mmaped_file_learn_tokens(struct rspamd_task *task, GPtrArray *tokens,
+ gint id,
+ gpointer p)
+{
+ rspamd_mmaped_file_t *mf = p;
+ guint32 h1, h2;
+ rspamd_token_t *tok;
+ guint i;
+
+ g_assert(tokens != NULL);
+ g_assert(p != NULL);
+
+ for (i = 0; i < tokens->len; i++) {
+ tok = g_ptr_array_index(tokens, i);
+ memcpy(&h1, (guchar *) &tok->data, sizeof(h1));
+ memcpy(&h2, ((guchar *) &tok->data) + sizeof(h1), sizeof(h2));
+ rspamd_mmaped_file_set_block(task->task_pool, mf, h1, h2,
+ tok->values[id]);
+ }
+
+ return TRUE;
+}
+
+gulong
+rspamd_mmaped_file_total_learns(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx)
+{
+ rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *) runtime;
+ guint64 rev = 0;
+ time_t t;
+
+ if (mf != NULL) {
+ rspamd_mmaped_file_get_revision(mf, &rev, &t);
+ }
+
+ return rev;
+}
+
+gulong
+rspamd_mmaped_file_inc_learns(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx)
+{
+ rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *) runtime;
+ guint64 rev = 0;
+ time_t t;
+
+ if (mf != NULL) {
+ rspamd_mmaped_file_inc_revision(mf);
+ rspamd_mmaped_file_get_revision(mf, &rev, &t);
+ }
+
+ return rev;
+}
+
+gulong
+rspamd_mmaped_file_dec_learns(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx)
+{
+ rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *) runtime;
+ guint64 rev = 0;
+ time_t t;
+
+ if (mf != NULL) {
+ rspamd_mmaped_file_dec_revision(mf);
+ rspamd_mmaped_file_get_revision(mf, &rev, &t);
+ }
+
+ return rev;
+}
+
+
+ucl_object_t *
+rspamd_mmaped_file_get_stat(gpointer runtime,
+ gpointer ctx)
+{
+ ucl_object_t *res = NULL;
+ guint64 rev;
+ rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *) runtime;
+
+ if (mf != NULL) {
+ res = ucl_object_typed_new(UCL_OBJECT);
+ rspamd_mmaped_file_get_revision(mf, &rev, NULL);
+ ucl_object_insert_key(res, ucl_object_fromint(rev), "revision",
+ 0, false);
+ ucl_object_insert_key(res, ucl_object_fromint(mf->len), "size",
+ 0, false);
+ ucl_object_insert_key(res, ucl_object_fromint(rspamd_mmaped_file_get_total(mf)), "total", 0, false);
+ ucl_object_insert_key(res, ucl_object_fromint(rspamd_mmaped_file_get_used(mf)), "used", 0, false);
+ ucl_object_insert_key(res, ucl_object_fromstring(mf->cf->symbol),
+ "symbol", 0, false);
+ ucl_object_insert_key(res, ucl_object_fromstring("mmap"),
+ "type", 0, false);
+ ucl_object_insert_key(res, ucl_object_fromint(0),
+ "languages", 0, false);
+ ucl_object_insert_key(res, ucl_object_fromint(0),
+ "users", 0, false);
+
+ if (mf->cf->label) {
+ ucl_object_insert_key(res, ucl_object_fromstring(mf->cf->label),
+ "label", 0, false);
+ }
+ }
+
+ return res;
+}
+
+gboolean
+rspamd_mmaped_file_finalize_learn(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx, GError **err)
+{
+ rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *) runtime;
+
+ if (mf != NULL) {
+ msync(mf->map, mf->len, MS_INVALIDATE | MS_ASYNC);
+ }
+
+ return TRUE;
+}
+
+gboolean
+rspamd_mmaped_file_finalize_process(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx)
+{
+ return TRUE;
+}
+
+gpointer
+rspamd_mmaped_file_load_tokenizer_config(gpointer runtime,
+ gsize *len)
+{
+ rspamd_mmaped_file_t *mf = runtime;
+ struct stat_file_header *header;
+
+ g_assert(mf != NULL);
+ header = mf->map;
+
+ if (len) {
+ *len = header->tokenizer_conf_len;
+ }
+
+ return header->unused;
+}
diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx
new file mode 100644
index 0000000..cd0c379
--- /dev/null
+++ b/src/libstat/backends/redis_backend.cxx
@@ -0,0 +1,1132 @@
+/*
+ * Copyright 2024 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "config.h"
+#include "lua/lua_common.h"
+#include "rspamd.h"
+#include "stat_internal.h"
+#include "upstream.h"
+#include "libserver/mempool_vars_internal.h"
+#include "fmt/core.h"
+
+#include "libutil/cxx/error.hxx"
+
+#include <string>
+#include <cstdint>
+#include <vector>
+#include <optional>
+
+#define msg_debug_stat_redis(...) rspamd_conditional_debug_fast(nullptr, nullptr, \
+ rspamd_stat_redis_log_id, "stat_redis", task->task_pool->tag.uid, \
+ RSPAMD_LOG_FUNC, \
+ __VA_ARGS__)
+
+INIT_LOG_MODULE(stat_redis)
+
+#define REDIS_CTX(p) (reinterpret_cast<struct redis_stat_ctx *>(p))
+#define REDIS_RUNTIME(p) (reinterpret_cast<struct redis_stat_runtime<float> *>(p))
+#define REDIS_DEFAULT_OBJECT "%s%l"
+#define REDIS_DEFAULT_USERS_OBJECT "%s%l%r"
+#define REDIS_DEFAULT_TIMEOUT 0.5
+#define REDIS_STAT_TIMEOUT 30
+#define REDIS_MAX_USERS 1000
+
+struct redis_stat_ctx {
+ lua_State *L;
+ struct rspamd_statfile_config *stcf;
+ const char *redis_object = REDIS_DEFAULT_OBJECT;
+ bool enable_users = false;
+ bool store_tokens = false;
+ bool enable_signatures = false;
+ int cbref_user = -1;
+
+ int cbref_classify = -1;
+ int cbref_learn = -1;
+
+ ucl_object_t *cur_stat = nullptr;
+
+ explicit redis_stat_ctx(lua_State *_L)
+ : L(_L)
+ {
+ }
+
+ ~redis_stat_ctx()
+ {
+ if (cbref_user != -1) {
+ luaL_unref(L, LUA_REGISTRYINDEX, cbref_user);
+ }
+
+ if (cbref_classify != -1) {
+ luaL_unref(L, LUA_REGISTRYINDEX, cbref_classify);
+ }
+
+ if (cbref_learn != -1) {
+ luaL_unref(L, LUA_REGISTRYINDEX, cbref_learn);
+ }
+ }
+};
+
+
+template<class T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true>
+struct redis_stat_runtime {
+ struct redis_stat_ctx *ctx;
+ struct rspamd_task *task;
+ struct rspamd_statfile_config *stcf;
+ GPtrArray *tokens = nullptr;
+ const char *redis_object_expanded;
+ std::uint64_t learned = 0;
+ int id;
+ std::vector<std::pair<int, T>> *results = nullptr;
+ bool need_redis_call = true;
+ std::optional<rspamd::util::error> err;
+
+ using result_type = std::vector<std::pair<int, T>>;
+
+private:
+ /* Called on connection termination */
+ static void rt_dtor(gpointer data)
+ {
+ auto *rt = REDIS_RUNTIME(data);
+
+ delete rt;
+ }
+
+ /* Avoid occasional deletion */
+ ~redis_stat_runtime()
+ {
+ if (tokens) {
+ g_ptr_array_unref(tokens);
+ }
+
+ delete results;
+ }
+
+public:
+ explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded)
+ : ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded)
+ {
+ rspamd_mempool_add_destructor(task->task_pool, redis_stat_runtime<T>::rt_dtor, this);
+ }
+
+ static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded,
+ bool is_spam) -> std::optional<redis_stat_runtime<T> *>
+ {
+ auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+ auto *res = rspamd_mempool_get_variable(task->task_pool, var_name.c_str());
+
+ if (res) {
+ msg_debug_bayes("recovered runtime from mempool at %s", var_name.c_str());
+ return reinterpret_cast<redis_stat_runtime<T> *>(res);
+ }
+ else {
+ msg_debug_bayes("no runtime at %s", var_name.c_str());
+ return std::nullopt;
+ }
+ }
+
+ void set_results(std::vector<std::pair<int, T>> *results)
+ {
+ this->results = results;
+ }
+
+ /* Propagate results from internal representation to the tokens array */
+ auto process_tokens(GPtrArray *tokens) const -> bool
+ {
+ rspamd_token_t *tok;
+
+ if (!results) {
+ return false;
+ }
+
+ for (auto [idx, val]: *results) {
+ tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx - 1);
+ tok->values[id] = val;
+ }
+
+ return true;
+ }
+
+ auto save_in_mempool(bool is_spam) const
+ {
+ auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+ /* We do not set destructor for the variable, as it should be already added on creation */
+ rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr);
+ msg_debug_bayes("saved runtime in mempool at %s", var_name.c_str());
+ }
+};
+
+#define GET_TASK_ELT(task, elt) (task == nullptr ? nullptr : (task)->elt)
+
+static const gchar *M = "redis statistics";
+
+static GQuark
+rspamd_redis_stat_quark(void)
+{
+ return g_quark_from_static_string(M);
+}
+
+/*
+ * Non-static for lua unit testing
+ */
+gsize rspamd_redis_expand_object(const gchar *pattern,
+ struct redis_stat_ctx *ctx,
+ struct rspamd_task *task,
+ gchar **target)
+{
+ gsize tlen = 0;
+ const gchar *p = pattern, *elt;
+ gchar *d, *end;
+ enum {
+ just_char,
+ percent_char,
+ mod_char
+ } state = just_char;
+ struct rspamd_statfile_config *stcf;
+ lua_State *L = nullptr;
+ struct rspamd_task **ptask;
+ const gchar *rcpt = nullptr;
+ gint err_idx;
+
+ g_assert(ctx != nullptr);
+ g_assert(task != nullptr);
+ stcf = ctx->stcf;
+
+ L = RSPAMD_LUA_CFG_STATE(task->cfg);
+ g_assert(L != nullptr);
+
+ if (ctx->enable_users) {
+ if (ctx->cbref_user == -1) {
+ rcpt = rspamd_task_get_principal_recipient(task);
+ }
+ else {
+ /* Execute lua function to get userdata */
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ err_idx = lua_gettop(L);
+
+ lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->cbref_user);
+ ptask = (struct rspamd_task **) lua_newuserdata(L, sizeof(struct rspamd_task *));
+ *ptask = task;
+ rspamd_lua_setclass(L, "rspamd{task}", -1);
+
+ if (lua_pcall(L, 1, 1, err_idx) != 0) {
+ msg_err_task("call to user extraction script failed: %s",
+ lua_tostring(L, -1));
+ }
+ else {
+ rcpt = rspamd_mempool_strdup(task->task_pool, lua_tostring(L, -1));
+ }
+
+ /* Result + error function */
+ lua_settop(L, err_idx - 1);
+ }
+
+ if (rcpt) {
+ rspamd_mempool_set_variable(task->task_pool, "stat_user",
+ (gpointer) rcpt, nullptr);
+ }
+ }
+
+ /* Length calculation */
+ while (*p) {
+ switch (state) {
+ case just_char:
+ if (*p == '%') {
+ state = percent_char;
+ }
+ else {
+ tlen++;
+ }
+ p++;
+ break;
+ case percent_char:
+ switch (*p) {
+ case '%':
+ tlen++;
+ state = just_char;
+ break;
+ case 'u':
+ elt = GET_TASK_ELT(task, auth_user);
+ if (elt) {
+ tlen += strlen(elt);
+ }
+ break;
+ case 'r':
+
+ if (rcpt == nullptr) {
+ elt = rspamd_task_get_principal_recipient(task);
+ }
+ else {
+ elt = rcpt;
+ }
+
+ if (elt) {
+ tlen += strlen(elt);
+ }
+ break;
+ case 'l':
+ if (stcf->label) {
+ tlen += strlen(stcf->label);
+ }
+ /* Label miss is OK */
+ break;
+ case 's':
+ tlen += sizeof("RS") - 1;
+ break;
+ default:
+ state = just_char;
+ tlen++;
+ break;
+ }
+
+ if (state == percent_char) {
+ state = mod_char;
+ }
+ p++;
+ break;
+
+ case mod_char:
+ switch (*p) {
+ case 'd':
+ p++;
+ state = just_char;
+ break;
+ default:
+ state = just_char;
+ break;
+ }
+ break;
+ }
+ }
+
+
+ if (target == nullptr) {
+ return -1;
+ }
+
+ *target = (gchar *) rspamd_mempool_alloc(task->task_pool, tlen + 1);
+ d = *target;
+ end = d + tlen + 1;
+ d[tlen] = '\0';
+ p = pattern;
+ state = just_char;
+
+ /* Expand string */
+ while (*p && d < end) {
+ switch (state) {
+ case just_char:
+ if (*p == '%') {
+ state = percent_char;
+ }
+ else {
+ *d++ = *p;
+ }
+ p++;
+ break;
+ case percent_char:
+ switch (*p) {
+ case '%':
+ *d++ = *p;
+ state = just_char;
+ break;
+ case 'u':
+ elt = GET_TASK_ELT(task, auth_user);
+ if (elt) {
+ d += rspamd_strlcpy(d, elt, end - d);
+ }
+ break;
+ case 'r':
+ if (rcpt == nullptr) {
+ elt = rspamd_task_get_principal_recipient(task);
+ }
+ else {
+ elt = rcpt;
+ }
+
+ if (elt) {
+ d += rspamd_strlcpy(d, elt, end - d);
+ }
+ break;
+ case 'l':
+ if (stcf->label) {
+ d += rspamd_strlcpy(d, stcf->label, end - d);
+ }
+ break;
+ case 's':
+ d += rspamd_strlcpy(d, "RS", end - d);
+ break;
+ default:
+ state = just_char;
+ *d++ = *p;
+ break;
+ }
+
+ if (state == percent_char) {
+ state = mod_char;
+ }
+ p++;
+ break;
+
+ case mod_char:
+ switch (*p) {
+ case 'd':
+ /* TODO: not supported yet */
+ p++;
+ state = just_char;
+ break;
+ default:
+ state = just_char;
+ break;
+ }
+ break;
+ }
+ }
+
+ return tlen;
+}
+
+static int
+rspamd_redis_stat_cb(lua_State *L)
+{
+ const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
+ auto *cfg = lua_check_config(L, 1);
+ auto *backend = REDIS_CTX(rspamd_mempool_get_variable(cfg->cfg_pool, cookie));
+
+ if (backend == nullptr) {
+ msg_err("internal error: cookie %s is not found", cookie);
+
+ return 0;
+ }
+
+ auto *cur_obj = ucl_object_lua_import(L, 2);
+ msg_debug_bayes_cfg("got stat object for %s", backend->stcf->symbol);
+ /* Enrich with some default values that are meaningless for redis */
+ ucl_object_insert_key(cur_obj,
+ ucl_object_typed_new(UCL_INT), "used", 0, false);
+ ucl_object_insert_key(cur_obj,
+ ucl_object_typed_new(UCL_INT), "total", 0, false);
+ ucl_object_insert_key(cur_obj,
+ ucl_object_typed_new(UCL_INT), "size", 0, false);
+ ucl_object_insert_key(cur_obj,
+ ucl_object_fromstring(backend->stcf->symbol),
+ "symbol", 0, false);
+ ucl_object_insert_key(cur_obj, ucl_object_fromstring("redis"),
+ "type", 0, false);
+ ucl_object_insert_key(cur_obj, ucl_object_fromint(0),
+ "languages", 0, false);
+
+ if (backend->cur_stat) {
+ ucl_object_unref(backend->cur_stat);
+ }
+
+ backend->cur_stat = cur_obj;
+
+ return 0;
+}
+
+static void
+rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend,
+ const ucl_object_t *statfile_obj,
+ const ucl_object_t *classifier_obj,
+ struct rspamd_config *cfg)
+{
+ const gchar *lua_script;
+ const ucl_object_t *elt, *users_enabled;
+ auto *L = RSPAMD_LUA_CFG_STATE(cfg);
+
+ users_enabled = ucl_object_lookup_any(classifier_obj, "per_user",
+ "users_enabled", nullptr);
+
+ if (users_enabled != nullptr) {
+ if (ucl_object_type(users_enabled) == UCL_BOOLEAN) {
+ backend->enable_users = ucl_object_toboolean(users_enabled);
+ backend->cbref_user = -1;
+ }
+ else if (ucl_object_type(users_enabled) == UCL_STRING) {
+ lua_script = ucl_object_tostring(users_enabled);
+
+ if (luaL_dostring(L, lua_script) != 0) {
+ msg_err_config("cannot execute lua script for users "
+ "extraction: %s",
+ lua_tostring(L, -1));
+ }
+ else {
+ if (lua_type(L, -1) == LUA_TFUNCTION) {
+ backend->enable_users = TRUE;
+ backend->cbref_user = luaL_ref(L,
+ LUA_REGISTRYINDEX);
+ }
+ else {
+ msg_err_config("lua script must return "
+ "function(task) and not %s",
+ lua_typename(L, lua_type(L, -1)));
+ }
+ }
+ }
+ }
+ else {
+ backend->enable_users = FALSE;
+ backend->cbref_user = -1;
+ }
+
+ elt = ucl_object_lookup(classifier_obj, "prefix");
+ if (elt == nullptr || ucl_object_type(elt) != UCL_STRING) {
+ /* Default non-users statistics */
+ if (backend->enable_users || backend->cbref_user != -1) {
+ backend->redis_object = REDIS_DEFAULT_USERS_OBJECT;
+ }
+ else {
+ backend->redis_object = REDIS_DEFAULT_OBJECT;
+ }
+ }
+ else {
+ /* XXX: sanity check */
+ backend->redis_object = ucl_object_tostring(elt);
+ }
+
+ elt = ucl_object_lookup(classifier_obj, "store_tokens");
+ if (elt) {
+ backend->store_tokens = ucl_object_toboolean(elt);
+ }
+ else {
+ backend->store_tokens = FALSE;
+ }
+
+ elt = ucl_object_lookup(classifier_obj, "signatures");
+ if (elt) {
+ backend->enable_signatures = ucl_object_toboolean(elt);
+ }
+ else {
+ backend->enable_signatures = FALSE;
+ }
+}
+
+gpointer
+rspamd_redis_init(struct rspamd_stat_ctx *ctx,
+ struct rspamd_config *cfg, struct rspamd_statfile *st)
+{
+ auto *L = RSPAMD_LUA_CFG_STATE(cfg);
+
+ auto backend = std::make_unique<struct redis_stat_ctx>(L);
+ lua_settop(L, 0);
+
+ rspamd_redis_parse_classifier_opts(backend.get(), st->stcf->opts, st->classifier->cfg->opts, cfg);
+
+ st->stcf->clcf->flags |= RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND;
+ backend->stcf = st->stcf;
+
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ auto err_idx = lua_gettop(L);
+
+ /* Obtain function */
+ if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_statfile")) {
+ msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_statfile");
+ lua_settop(L, err_idx - 1);
+
+ return nullptr;
+ }
+
+ /* Push arguments */
+ ucl_object_push_lua(L, st->classifier->cfg->opts, false);
+ ucl_object_push_lua(L, st->stcf->opts, false);
+ lua_pushstring(L, backend->stcf->symbol);
+ lua_pushboolean(L, backend->stcf->is_spam);
+ auto **pev_base = (struct ev_loop **) lua_newuserdata(L, sizeof(struct ev_loop *));
+ *pev_base = ctx->event_loop;
+ rspamd_lua_setclass(L, "rspamd{ev_base}", -1);
+
+ /* Store backend in random cookie */
+ char *cookie = (char *) rspamd_mempool_alloc(cfg->cfg_pool, 16);
+ rspamd_random_hex(cookie, 16);
+ cookie[15] = '\0';
+ rspamd_mempool_set_variable(cfg->cfg_pool, cookie, backend.get(), nullptr);
+ /* Callback + 1 upvalue */
+ lua_pushstring(L, cookie);
+ lua_pushcclosure(L, &rspamd_redis_stat_cb, 1);
+
+ if (lua_pcall(L, 6, 2, err_idx) != 0) {
+ msg_err("call to lua_bayes_init_classifier "
+ "script failed: %s",
+ lua_tostring(L, -1));
+ lua_settop(L, err_idx - 1);
+
+ return nullptr;
+ }
+
+ /* Results are in the stack:
+ * top - 1 - classifier function (idx = -2)
+ * top - learn function (idx = -1)
+ */
+
+ lua_pushvalue(L, -2);
+ backend->cbref_classify = luaL_ref(L, LUA_REGISTRYINDEX);
+
+ lua_pushvalue(L, -1);
+ backend->cbref_learn = luaL_ref(L, LUA_REGISTRYINDEX);
+
+ lua_settop(L, err_idx - 1);
+
+ return backend.release();
+}
+
+gpointer
+rspamd_redis_runtime(struct rspamd_task *task,
+ struct rspamd_statfile_config *stcf,
+ gboolean learn, gpointer c, gint _id)
+{
+ struct redis_stat_ctx *ctx = REDIS_CTX(c);
+ char *object_expanded = nullptr;
+
+ g_assert(ctx != nullptr);
+ g_assert(stcf != nullptr);
+
+ if (rspamd_redis_expand_object(ctx->redis_object, ctx, task,
+ &object_expanded) == 0) {
+ msg_err_task("expansion for %s failed for symbol %s "
+ "(maybe learning per user classifier with no user or recipient)",
+ learn ? "learning" : "classifying",
+ stcf->symbol);
+ return nullptr;
+ }
+
+ /* Look for the cached results */
+ if (!learn) {
+ auto maybe_existing = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ object_expanded, stcf->is_spam);
+
+ if (maybe_existing) {
+ auto *rt = maybe_existing.value();
+ /* Update stcf and ctx to correspond to what we have been asked */
+ rt->stcf = stcf;
+ rt->ctx = ctx;
+ return rt;
+ }
+ }
+
+ /* No cached result (or learn), create new one */
+ auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+
+ if (!learn) {
+ /*
+ * For check, we also need to create the opposite class runtime to avoid
+ * double call for Redis scripts.
+ * This runtime will be filled later.
+ */
+ auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ object_expanded,
+ !stcf->is_spam);
+
+ if (!maybe_opposite_rt) {
+ auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded);
+ opposite_rt->save_in_mempool(!stcf->is_spam);
+ opposite_rt->need_redis_call = false;
+ }
+ }
+
+ rt->save_in_mempool(stcf->is_spam);
+
+ return rt;
+}
+
+void rspamd_redis_close(gpointer p)
+{
+ struct redis_stat_ctx *ctx = REDIS_CTX(p);
+ delete ctx;
+}
+
+static constexpr auto
+msgpack_emit_str(const std::string_view st, char *out) -> std::size_t
+{
+ auto len = st.size();
+ constexpr const unsigned char fix_mask = 0xA0, l8_ch = 0xd9, l16_ch = 0xda, l32_ch = 0xdb;
+ auto blen = 0;
+ if (len <= 0x1F) {
+ blen = 1;
+ out[0] = (len | fix_mask) & 0xff;
+ }
+ else if (len <= 0xff) {
+ blen = 2;
+ out[0] = l8_ch;
+ out[1] = len & 0xff;
+ }
+ else if (len <= 0xffff) {
+ uint16_t bl = GUINT16_TO_BE(len);
+
+ blen = 3;
+ out[0] = l16_ch;
+ memcpy(&out[1], &bl, sizeof(bl));
+ }
+ else {
+ uint32_t bl = GUINT32_TO_BE(len);
+
+ blen = 5;
+ out[0] = l32_ch;
+ memcpy(&out[1], &bl, sizeof(bl));
+ }
+
+ memcpy(&out[blen], st.data(), st.size());
+
+ return blen + len;
+}
+
+static constexpr auto
+msgpack_str_len(std::size_t len) -> std::size_t
+{
+ if (len <= 0x1F) {
+ return 1 + len;
+ }
+ else if (len <= 0xff) {
+ return 2 + len;
+ }
+ else if (len <= 0xffff) {
+ return 3 + len;
+ }
+ else {
+ return 4 + len;
+ }
+}
+
+/*
+ * Serialise stat tokens to message pack
+ */
+static char *
+rspamd_redis_serialize_tokens(struct rspamd_task *task, const gchar *prefix, GPtrArray *tokens, gsize *ser_len)
+{
+ /* Each token is int64_t that requires 10 bytes (2 int32_t) + 4 bytes array len + 1 byte array magic */
+ char max_int64_str[] = "18446744073709551615";
+ auto prefix_len = strlen(prefix);
+ std::size_t req_len = 5;
+ rspamd_token_t *tok;
+
+ /* Calculate required length */
+ req_len += tokens->len * (msgpack_str_len(sizeof(max_int64_str) + prefix_len) + 1);
+
+ auto *buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len);
+ auto *p = buf;
+
+ /* Array */
+ *p++ = (gchar) 0xdd;
+ /* Length in big-endian (4 bytes) */
+ *p++ = (gchar) ((tokens->len >> 24) & 0xff);
+ *p++ = (gchar) ((tokens->len >> 16) & 0xff);
+ *p++ = (gchar) ((tokens->len >> 8) & 0xff);
+ *p++ = (gchar) (tokens->len & 0xff);
+
+
+ int i;
+ auto numbuf_len = sizeof(max_int64_str) + prefix_len + 1;
+ auto *numbuf = (char *) g_alloca(numbuf_len);
+
+ PTR_ARRAY_FOREACH(tokens, i, tok)
+ {
+ std::size_t r = rspamd_snprintf(numbuf, numbuf_len, "%s_%uL", prefix, tok->data);
+ auto shift = msgpack_emit_str({numbuf, r}, p);
+ p += shift;
+ }
+
+ *ser_len = p - buf;
+
+ return buf;
+}
+
+static char *
+rspamd_redis_serialize_text_tokens(struct rspamd_task *task, GPtrArray *tokens, gsize *ser_len)
+{
+ rspamd_token_t *tok;
+ auto req_len = 5; /* Messagepack array prefix */
+ int i;
+
+ /*
+ * First we need to determine the requested length
+ */
+ PTR_ARRAY_FOREACH(tokens, i, tok)
+ {
+ if (tok->t1 && tok->t2) {
+ /* Two tokens */
+ req_len += msgpack_str_len(tok->t1->stemmed.len) + msgpack_str_len(tok->t2->stemmed.len);
+ }
+ else if (tok->t1) {
+ req_len += msgpack_str_len(tok->t1->stemmed.len);
+ req_len += 1; /* null */
+ }
+ else {
+ req_len += 2; /* 2 nulls */
+ }
+ }
+
+ auto *buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len);
+ auto *p = buf;
+
+ /* Array */
+ std::uint32_t nlen = tokens->len * 2;
+ nlen = GUINT32_TO_BE(nlen);
+ *p++ = (gchar) 0xdd;
+ /* Length in big-endian (4 bytes) */
+ memcpy(p, &nlen, sizeof(nlen));
+ p += sizeof(nlen);
+
+ PTR_ARRAY_FOREACH(tokens, i, tok)
+ {
+ if (tok->t1 && tok->t2) {
+ auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p);
+ p += step;
+ step = msgpack_emit_str({tok->t2->stemmed.begin, tok->t2->stemmed.len}, p);
+ p += step;
+ }
+ else if (tok->t1) {
+ auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p);
+ p += step;
+ *p++ = 0xc0;
+ }
+ else {
+ *p++ = 0xc0;
+ *p++ = 0xc0;
+ }
+ }
+
+ *ser_len = p - buf;
+
+ return buf;
+}
+
+static gint
+rspamd_redis_classified(lua_State *L)
+{
+ const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
+ auto *task = lua_check_task(L, 1);
+ auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
+
+ if (rt == nullptr) {
+ msg_err_task("internal error: cannot find runtime for cookie %s", cookie);
+
+ return 0;
+ }
+
+ bool result = lua_toboolean(L, 2);
+
+ if (result) {
+ /* Indexes:
+ * 3 - learned_ham (int)
+ * 4 - learned_spam (int)
+ * 5 - ham_tokens (pair<int, int>)
+ * 6 - spam_tokens (pair<int, int>)
+ */
+
+ /*
+ * We need to fill our runtime AND the opposite runtime
+ */
+ auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) {
+ rt->learned = learned;
+ redis_stat_runtime<float>::result_type *res;
+
+ res = new redis_stat_runtime<float>::result_type();
+
+ for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) {
+ lua_rawgeti(L, -1, 1);
+ auto idx = lua_tointeger(L, -1);
+ lua_pop(L, 1);
+
+ lua_rawgeti(L, -1, 2);
+ auto value = lua_tonumber(L, -1);
+ lua_pop(L, 1);
+
+ res->emplace_back(idx, value);
+ }
+
+ rt->set_results(res);
+ };
+
+ auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+ rt->redis_object_expanded,
+ !rt->stcf->is_spam);
+
+ if (!opposite_rt_maybe) {
+ msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie);
+
+ return 0;
+ }
+
+ if (rt->stcf->is_spam) {
+ filler_func(rt, L, lua_tointeger(L, 4), 6);
+ filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5);
+ }
+ else {
+ filler_func(rt, L, lua_tointeger(L, 3), 5);
+ filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6);
+ }
+
+ /* Mark task as being processed */
+ task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS | RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
+
+ /* Process all tokens */
+ g_assert(rt->tokens != nullptr);
+ rt->process_tokens(rt->tokens);
+ opposite_rt_maybe.value()->process_tokens(rt->tokens);
+ }
+ else {
+ /* Error message is on index 3 */
+ const auto *err_msg = lua_tostring(L, 3);
+ rt->err = rspamd::util::error(err_msg, 500);
+ msg_err_task("cannot classify task: %s",
+ err_msg);
+ }
+
+ return 0;
+}
+
+gboolean
+rspamd_redis_process_tokens(struct rspamd_task *task,
+ GPtrArray *tokens,
+ gint id, gpointer p)
+{
+ auto *rt = REDIS_RUNTIME(p);
+ auto *L = rt->ctx->L;
+
+ if (rspamd_session_blocked(task->s)) {
+ return FALSE;
+ }
+
+ if (tokens == nullptr || tokens->len == 0) {
+ return FALSE;
+ }
+
+ if (!rt->need_redis_call) {
+ /* No need to do anything, as it is already done in the opposite class processing */
+ /* However, we need to store id as it is needed for further tokens processing */
+ rt->id = id;
+ rt->tokens = g_ptr_array_ref(tokens);
+
+ return TRUE;
+ }
+
+ gsize tokens_len;
+ gchar *tokens_buf = rspamd_redis_serialize_tokens(task, rt->redis_object_expanded, tokens, &tokens_len);
+ rt->id = id;
+
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ gint err_idx = lua_gettop(L);
+
+ /* Function arguments */
+ lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_classify);
+ rspamd_lua_task_push(L, task);
+ lua_pushstring(L, rt->redis_object_expanded);
+ lua_pushinteger(L, id);
+ lua_pushboolean(L, rt->stcf->is_spam);
+ lua_new_text(L, tokens_buf, tokens_len, false);
+
+ /* Store rt in random cookie */
+ char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16);
+ rspamd_random_hex(cookie, 16);
+ cookie[15] = '\0';
+ rspamd_mempool_set_variable(task->task_pool, cookie, rt, nullptr);
+ /* Callback */
+ lua_pushstring(L, cookie);
+ lua_pushcclosure(L, &rspamd_redis_classified, 1);
+
+ if (lua_pcall(L, 6, 0, err_idx) != 0) {
+ msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
+ lua_settop(L, err_idx - 1);
+ return FALSE;
+ }
+
+ rt->tokens = g_ptr_array_ref(tokens);
+
+ lua_settop(L, err_idx - 1);
+ return TRUE;
+}
+
+gboolean
+rspamd_redis_finalize_process(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx)
+{
+ auto *rt = REDIS_RUNTIME(runtime);
+
+ return !rt->err.has_value();
+}
+
+
+static gint
+rspamd_redis_learned(lua_State *L)
+{
+ const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
+ auto *task = lua_check_task(L, 1);
+ auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
+
+ if (rt == nullptr) {
+ msg_err_task("internal error: cannot find runtime for cookie %s", cookie);
+
+ return 0;
+ }
+
+ bool result = lua_toboolean(L, 2);
+
+ if (result) {
+ /* TODO: write it */
+ }
+ else {
+ /* Error message is on index 3 */
+ const auto *err_msg = lua_tostring(L, 3);
+ rt->err = rspamd::util::error(err_msg, 500);
+ msg_err_task("cannot learn task: %s", err_msg);
+ }
+
+ return 0;
+}
+
+gboolean
+rspamd_redis_learn_tokens(struct rspamd_task *task,
+ GPtrArray *tokens,
+ gint id, gpointer p)
+{
+ auto *rt = REDIS_RUNTIME(p);
+ auto *L = rt->ctx->L;
+
+ if (rspamd_session_blocked(task->s)) {
+ return FALSE;
+ }
+
+ if (tokens == nullptr || tokens->len == 0) {
+ return FALSE;
+ }
+
+ gsize tokens_len;
+ gchar *tokens_buf = rspamd_redis_serialize_tokens(task, rt->redis_object_expanded, tokens, &tokens_len);
+
+ rt->id = id;
+
+ gsize text_tokens_len = 0;
+ gchar *text_tokens_buf = nullptr;
+
+ if (rt->ctx->store_tokens) {
+ text_tokens_buf = rspamd_redis_serialize_text_tokens(task, tokens, &text_tokens_len);
+ }
+
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ gint err_idx = lua_gettop(L);
+ auto nargs = 8;
+
+ /* Function arguments */
+ lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_learn);
+ rspamd_lua_task_push(L, task);
+ lua_pushstring(L, rt->redis_object_expanded);
+ lua_pushinteger(L, id);
+ lua_pushboolean(L, rt->stcf->is_spam);
+ lua_pushstring(L, rt->stcf->symbol);
+
+ /* Detect unlearn */
+ auto *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, 0);
+
+ if (tok->values[id] > 0) {
+ lua_pushboolean(L, FALSE);// Learn
+ }
+ else {
+ lua_pushboolean(L, TRUE);// Unlearn
+ }
+ lua_new_text(L, tokens_buf, tokens_len, false);
+
+ /* Store rt in random cookie */
+ char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16);
+ rspamd_random_hex(cookie, 16);
+ cookie[15] = '\0';
+ rspamd_mempool_set_variable(task->task_pool, cookie, rt, nullptr);
+ /* Callback */
+ lua_pushstring(L, cookie);
+ lua_pushcclosure(L, &rspamd_redis_learned, 1);
+
+ if (text_tokens_len) {
+ nargs = 9;
+ lua_new_text(L, text_tokens_buf, text_tokens_len, false);
+ }
+
+ if (lua_pcall(L, nargs, 0, err_idx) != 0) {
+ msg_err_task("call to script failed: %s", lua_tostring(L, -1));
+ lua_settop(L, err_idx - 1);
+ return FALSE;
+ }
+
+ rt->tokens = g_ptr_array_ref(tokens);
+
+ lua_settop(L, err_idx - 1);
+ return TRUE;
+}
+
+
+gboolean
+rspamd_redis_finalize_learn(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx, GError **err)
+{
+ auto *rt = REDIS_RUNTIME(runtime);
+
+ if (rt->err.has_value()) {
+ rt->err->into_g_error_set(rspamd_redis_stat_quark(), err);
+
+ return FALSE;
+ }
+
+ return TRUE;
+}
+
+gulong
+rspamd_redis_total_learns(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx)
+{
+ auto *rt = REDIS_RUNTIME(runtime);
+
+ return rt->learned;
+}
+
+gulong
+rspamd_redis_inc_learns(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx)
+{
+ auto *rt = REDIS_RUNTIME(runtime);
+
+ /* XXX: may cause races */
+ return rt->learned + 1;
+}
+
+gulong
+rspamd_redis_dec_learns(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx)
+{
+ auto *rt = REDIS_RUNTIME(runtime);
+
+ /* XXX: may cause races */
+ return rt->learned + 1;
+}
+
+gulong
+rspamd_redis_learns(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx)
+{
+ auto *rt = REDIS_RUNTIME(runtime);
+
+ return rt->learned;
+}
+
+ucl_object_t *
+rspamd_redis_get_stat(gpointer runtime,
+ gpointer ctx)
+{
+ auto *rt = REDIS_RUNTIME(runtime);
+
+ return ucl_object_ref(rt->ctx->cur_stat);
+}
+
+gpointer
+rspamd_redis_load_tokenizer_config(gpointer runtime,
+ gsize *len)
+{
+ return nullptr;
+}
diff --git a/src/libstat/backends/sqlite3_backend.c b/src/libstat/backends/sqlite3_backend.c
new file mode 100644
index 0000000..2fd34d8
--- /dev/null
+++ b/src/libstat/backends/sqlite3_backend.c
@@ -0,0 +1,907 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "config.h"
+#include "rspamd.h"
+#include "sqlite3.h"
+#include "libutil/sqlite_utils.h"
+#include "libstat/stat_internal.h"
+#include "libmime/message.h"
+#include "lua/lua_common.h"
+#include "unix-std.h"
+
+#define SQLITE3_BACKEND_TYPE "sqlite3"
+#define SQLITE3_SCHEMA_VERSION "1"
+#define SQLITE3_DEFAULT "default"
+
+struct rspamd_stat_sqlite3_db {
+ sqlite3 *sqlite;
+ gchar *fname;
+ GArray *prstmt;
+ lua_State *L;
+ rspamd_mempool_t *pool;
+ gboolean in_transaction;
+ gboolean enable_users;
+ gboolean enable_languages;
+ gint cbref_user;
+ gint cbref_language;
+};
+
+struct rspamd_stat_sqlite3_rt {
+ struct rspamd_task *task;
+ struct rspamd_stat_sqlite3_db *db;
+ struct rspamd_statfile_config *cf;
+ gint64 user_id;
+ gint64 lang_id;
+};
+
+static const char *create_tables_sql =
+ "BEGIN IMMEDIATE;"
+ "CREATE TABLE tokenizer(data BLOB);"
+ "CREATE TABLE users("
+ "id INTEGER PRIMARY KEY,"
+ "name TEXT,"
+ "learns INTEGER"
+ ");"
+ "CREATE TABLE languages("
+ "id INTEGER PRIMARY KEY,"
+ "name TEXT,"
+ "learns INTEGER"
+ ");"
+ "CREATE TABLE tokens("
+ "token INTEGER NOT NULL,"
+ "user INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,"
+ "language INTEGER NOT NULL REFERENCES languages(id) ON DELETE CASCADE,"
+ "value INTEGER,"
+ "modified INTEGER,"
+ "CONSTRAINT tid UNIQUE (token, user, language) ON CONFLICT REPLACE"
+ ");"
+ "CREATE UNIQUE INDEX IF NOT EXISTS un ON users(name);"
+ "CREATE INDEX IF NOT EXISTS tok ON tokens(token);"
+ "CREATE UNIQUE INDEX IF NOT EXISTS ln ON languages(name);"
+ "PRAGMA user_version=" SQLITE3_SCHEMA_VERSION ";"
+ "INSERT INTO users(id, name, learns) VALUES(0, '" SQLITE3_DEFAULT "',0);"
+ "INSERT INTO languages(id, name, learns) VALUES(0, '" SQLITE3_DEFAULT "',0);"
+ "COMMIT;";
+
+enum rspamd_stat_sqlite3_stmt_idx {
+ RSPAMD_STAT_BACKEND_TRANSACTION_START_IM = 0,
+ RSPAMD_STAT_BACKEND_TRANSACTION_START_DEF,
+ RSPAMD_STAT_BACKEND_TRANSACTION_START_EXCL,
+ RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT,
+ RSPAMD_STAT_BACKEND_TRANSACTION_ROLLBACK,
+ RSPAMD_STAT_BACKEND_GET_TOKEN_FULL,
+ RSPAMD_STAT_BACKEND_GET_TOKEN_SIMPLE,
+ RSPAMD_STAT_BACKEND_SET_TOKEN,
+ RSPAMD_STAT_BACKEND_INC_LEARNS_LANG,
+ RSPAMD_STAT_BACKEND_INC_LEARNS_USER,
+ RSPAMD_STAT_BACKEND_DEC_LEARNS_LANG,
+ RSPAMD_STAT_BACKEND_DEC_LEARNS_USER,
+ RSPAMD_STAT_BACKEND_GET_LEARNS,
+ RSPAMD_STAT_BACKEND_GET_LANGUAGE,
+ RSPAMD_STAT_BACKEND_GET_USER,
+ RSPAMD_STAT_BACKEND_INSERT_USER,
+ RSPAMD_STAT_BACKEND_INSERT_LANGUAGE,
+ RSPAMD_STAT_BACKEND_SAVE_TOKENIZER,
+ RSPAMD_STAT_BACKEND_LOAD_TOKENIZER,
+ RSPAMD_STAT_BACKEND_NTOKENS,
+ RSPAMD_STAT_BACKEND_NLANGUAGES,
+ RSPAMD_STAT_BACKEND_NUSERS,
+ RSPAMD_STAT_BACKEND_MAX
+};
+
+static struct rspamd_sqlite3_prstmt prepared_stmts[RSPAMD_STAT_BACKEND_MAX] =
+ {
+ [RSPAMD_STAT_BACKEND_TRANSACTION_START_IM] = {
+ .idx = RSPAMD_STAT_BACKEND_TRANSACTION_START_IM,
+ .sql = "BEGIN IMMEDIATE TRANSACTION;",
+ .args = "",
+ .stmt = NULL,
+ .result = SQLITE_DONE,
+ .flags = 0,
+ .ret = "",
+ },
+ [RSPAMD_STAT_BACKEND_TRANSACTION_START_DEF] = {.idx = RSPAMD_STAT_BACKEND_TRANSACTION_START_DEF, .sql = "BEGIN DEFERRED TRANSACTION;", .args = "", .stmt = NULL, .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ [RSPAMD_STAT_BACKEND_TRANSACTION_START_EXCL] = {.idx = RSPAMD_STAT_BACKEND_TRANSACTION_START_EXCL, .sql = "BEGIN EXCLUSIVE TRANSACTION;", .args = "", .stmt = NULL, .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ [RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT] = {.idx = RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT, .sql = "COMMIT;", .args = "", .stmt = NULL, .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ [RSPAMD_STAT_BACKEND_TRANSACTION_ROLLBACK] = {.idx = RSPAMD_STAT_BACKEND_TRANSACTION_ROLLBACK, .sql = "ROLLBACK;", .args = "", .stmt = NULL, .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ [RSPAMD_STAT_BACKEND_GET_TOKEN_FULL] = {.idx = RSPAMD_STAT_BACKEND_GET_TOKEN_FULL, .sql = "SELECT value FROM tokens "
+ "LEFT JOIN languages ON tokens.language=languages.id "
+ "LEFT JOIN users ON tokens.user=users.id "
+ "WHERE token=?1 AND (users.id=?2) "
+ "AND (languages.id=?3 OR languages.id=0);",
+ .stmt = NULL,
+ .args = "III",
+ .result = SQLITE_ROW,
+ .flags = 0,
+ .ret = "I"},
+ [RSPAMD_STAT_BACKEND_GET_TOKEN_SIMPLE] = {.idx = RSPAMD_STAT_BACKEND_GET_TOKEN_SIMPLE, .sql = "SELECT value FROM tokens WHERE token=?1", .stmt = NULL, .args = "I", .result = SQLITE_ROW, .flags = 0, .ret = "I"},
+ [RSPAMD_STAT_BACKEND_SET_TOKEN] = {.idx = RSPAMD_STAT_BACKEND_SET_TOKEN, .sql = "INSERT OR REPLACE INTO tokens (token, user, language, value, modified) "
+ "VALUES (?1, ?2, ?3, ?4, strftime('%s','now'))",
+ .stmt = NULL,
+ .args = "IIII",
+ .result = SQLITE_DONE,
+ .flags = 0,
+ .ret = ""},
+ [RSPAMD_STAT_BACKEND_INC_LEARNS_LANG] = {.idx = RSPAMD_STAT_BACKEND_INC_LEARNS_LANG, .sql = "UPDATE languages SET learns=learns + 1 WHERE id=?1", .stmt = NULL, .args = "I", .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ [RSPAMD_STAT_BACKEND_INC_LEARNS_USER] = {.idx = RSPAMD_STAT_BACKEND_INC_LEARNS_USER, .sql = "UPDATE users SET learns=learns + 1 WHERE id=?1", .stmt = NULL, .args = "I", .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ [RSPAMD_STAT_BACKEND_DEC_LEARNS_LANG] = {.idx = RSPAMD_STAT_BACKEND_DEC_LEARNS_LANG, .sql = "UPDATE languages SET learns=MAX(0, learns - 1) WHERE id=?1", .stmt = NULL, .args = "I", .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ [RSPAMD_STAT_BACKEND_DEC_LEARNS_USER] = {.idx = RSPAMD_STAT_BACKEND_DEC_LEARNS_USER, .sql = "UPDATE users SET learns=MAX(0, learns - 1) WHERE id=?1", .stmt = NULL, .args = "I", .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ [RSPAMD_STAT_BACKEND_GET_LEARNS] = {.idx = RSPAMD_STAT_BACKEND_GET_LEARNS, .sql = "SELECT SUM(MAX(0, learns)) FROM languages", .stmt = NULL, .args = "", .result = SQLITE_ROW, .flags = 0, .ret = "I"},
+ [RSPAMD_STAT_BACKEND_GET_LANGUAGE] = {.idx = RSPAMD_STAT_BACKEND_GET_LANGUAGE, .sql = "SELECT id FROM languages WHERE name=?1", .stmt = NULL, .args = "T", .result = SQLITE_ROW, .flags = 0, .ret = "I"},
+ [RSPAMD_STAT_BACKEND_GET_USER] = {.idx = RSPAMD_STAT_BACKEND_GET_USER, .sql = "SELECT id FROM users WHERE name=?1", .stmt = NULL, .args = "T", .result = SQLITE_ROW, .flags = 0, .ret = "I"},
+ [RSPAMD_STAT_BACKEND_INSERT_USER] = {.idx = RSPAMD_STAT_BACKEND_INSERT_USER, .sql = "INSERT INTO users (name, learns) VALUES (?1, 0)", .stmt = NULL, .args = "T", .result = SQLITE_DONE, .flags = 0, .ret = "L"},
+ [RSPAMD_STAT_BACKEND_INSERT_LANGUAGE] = {.idx = RSPAMD_STAT_BACKEND_INSERT_LANGUAGE, .sql = "INSERT INTO languages (name, learns) VALUES (?1, 0)", .stmt = NULL, .args = "T", .result = SQLITE_DONE, .flags = 0, .ret = "L"},
+ [RSPAMD_STAT_BACKEND_SAVE_TOKENIZER] = {.idx = RSPAMD_STAT_BACKEND_SAVE_TOKENIZER, .sql = "INSERT INTO tokenizer(data) VALUES (?1)", .stmt = NULL, .args = "B", .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ [RSPAMD_STAT_BACKEND_LOAD_TOKENIZER] = {.idx = RSPAMD_STAT_BACKEND_LOAD_TOKENIZER, .sql = "SELECT data FROM tokenizer", .stmt = NULL, .args = "", .result = SQLITE_ROW, .flags = 0, .ret = "B"},
+ [RSPAMD_STAT_BACKEND_NTOKENS] = {.idx = RSPAMD_STAT_BACKEND_NTOKENS, .sql = "SELECT COUNT(*) FROM tokens", .stmt = NULL, .args = "", .result = SQLITE_ROW, .flags = 0, .ret = "I"},
+ [RSPAMD_STAT_BACKEND_NLANGUAGES] = {.idx = RSPAMD_STAT_BACKEND_NLANGUAGES, .sql = "SELECT COUNT(*) FROM languages", .stmt = NULL, .args = "", .result = SQLITE_ROW, .flags = 0, .ret = "I"},
+ [RSPAMD_STAT_BACKEND_NUSERS] = {.idx = RSPAMD_STAT_BACKEND_NUSERS, .sql = "SELECT COUNT(*) FROM users", .stmt = NULL, .args = "", .result = SQLITE_ROW, .flags = 0, .ret = "I"}};
+
+static GQuark
+rspamd_sqlite3_backend_quark(void)
+{
+ return g_quark_from_static_string("sqlite3-stat-backend");
+}
+
+static gint64
+rspamd_sqlite3_get_user(struct rspamd_stat_sqlite3_db *db,
+ struct rspamd_task *task, gboolean learn)
+{
+ gint64 id = 0; /* Default user is 0 */
+ gint rc, err_idx;
+ const gchar *user = NULL;
+ struct rspamd_task **ptask;
+ lua_State *L = db->L;
+
+ if (db->cbref_user == -1) {
+ user = rspamd_task_get_principal_recipient(task);
+ }
+ else {
+ /* Execute lua function to get userdata */
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ err_idx = lua_gettop(L);
+
+ lua_rawgeti(L, LUA_REGISTRYINDEX, db->cbref_user);
+ ptask = lua_newuserdata(L, sizeof(struct rspamd_task *));
+ *ptask = task;
+ rspamd_lua_setclass(L, "rspamd{task}", -1);
+
+ if (lua_pcall(L, 1, 1, err_idx) != 0) {
+ msg_err_task("call to user extraction script failed: %s",
+ lua_tostring(L, -1));
+ }
+ else {
+ user = rspamd_mempool_strdup(task->task_pool, lua_tostring(L, -1));
+ }
+
+ /* Result + error function */
+ lua_settop(L, err_idx - 1);
+ }
+
+
+ if (user != NULL) {
+ rspamd_mempool_set_variable(task->task_pool, "stat_user",
+ (gpointer) user, NULL);
+
+ rc = rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+ RSPAMD_STAT_BACKEND_GET_USER, user, &id);
+
+ if (rc != SQLITE_OK && learn) {
+ /* We need to insert a new user */
+ if (!db->in_transaction) {
+ rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+ RSPAMD_STAT_BACKEND_TRANSACTION_START_IM);
+ db->in_transaction = TRUE;
+ }
+
+ rc = rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+ RSPAMD_STAT_BACKEND_INSERT_USER, user, &id);
+ }
+ }
+
+ return id;
+}
+
+static gint64
+rspamd_sqlite3_get_language(struct rspamd_stat_sqlite3_db *db,
+ struct rspamd_task *task, gboolean learn)
+{
+ gint64 id = 0; /* Default language is 0 */
+ gint rc, err_idx;
+ guint i;
+ const gchar *language = NULL;
+ struct rspamd_mime_text_part *tp;
+ struct rspamd_task **ptask;
+ lua_State *L = db->L;
+
+ if (db->cbref_language == -1) {
+ PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, tp)
+ {
+
+ if (tp->language != NULL && tp->language[0] != '\0' &&
+ strcmp(tp->language, "en") != 0) {
+ language = tp->language;
+ break;
+ }
+ }
+ }
+ else {
+ /* Execute lua function to get userdata */
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ err_idx = lua_gettop(L);
+
+ lua_rawgeti(L, LUA_REGISTRYINDEX, db->cbref_language);
+ ptask = lua_newuserdata(L, sizeof(struct rspamd_task *));
+ *ptask = task;
+ rspamd_lua_setclass(L, "rspamd{task}", -1);
+
+ if (lua_pcall(L, 1, 1, err_idx) != 0) {
+ msg_err_task("call to language extraction script failed: %s",
+ lua_tostring(L, -1));
+ }
+ else {
+ language = rspamd_mempool_strdup(task->task_pool,
+ lua_tostring(L, -1));
+ }
+
+ /* Result + error function */
+ lua_settop(L, err_idx - 1);
+ }
+
+
+ /* XXX: We ignore multiple languages but default + extra */
+ if (language != NULL) {
+ rc = rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+ RSPAMD_STAT_BACKEND_GET_LANGUAGE, language, &id);
+
+ if (rc != SQLITE_OK && learn) {
+ /* We need to insert a new language */
+ if (!db->in_transaction) {
+ rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+ RSPAMD_STAT_BACKEND_TRANSACTION_START_IM);
+ db->in_transaction = TRUE;
+ }
+
+ rc = rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+ RSPAMD_STAT_BACKEND_INSERT_LANGUAGE, language, &id);
+ }
+ }
+
+ return id;
+}
+
+static struct rspamd_stat_sqlite3_db *
+rspamd_sqlite3_opendb(rspamd_mempool_t *pool,
+ struct rspamd_statfile_config *stcf,
+ const gchar *path, const ucl_object_t *opts,
+ gboolean create, GError **err)
+{
+ struct rspamd_stat_sqlite3_db *bk;
+ struct rspamd_stat_tokenizer *tokenizer;
+ gpointer tk_conf;
+ gsize sz = 0;
+ gint64 sz64 = 0;
+ gchar *tok_conf_encoded;
+ gint ret, ntries = 0;
+ const gint max_tries = 100;
+ struct timespec sleep_ts = {
+ .tv_sec = 0,
+ .tv_nsec = 1000000};
+
+ bk = g_malloc0(sizeof(*bk));
+ bk->sqlite = rspamd_sqlite3_open_or_create(pool, path, create_tables_sql,
+ 0, err);
+ bk->pool = pool;
+
+ if (bk->sqlite == NULL) {
+ g_free(bk);
+
+ return NULL;
+ }
+
+ bk->fname = g_strdup(path);
+
+ bk->prstmt = rspamd_sqlite3_init_prstmt(bk->sqlite, prepared_stmts,
+ RSPAMD_STAT_BACKEND_MAX, err);
+
+ if (bk->prstmt == NULL) {
+ sqlite3_close(bk->sqlite);
+ g_free(bk);
+
+ return NULL;
+ }
+
+ /* Check tokenizer configuration */
+ if (rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_LOAD_TOKENIZER, &sz64, &tk_conf) != SQLITE_OK ||
+ sz64 == 0) {
+
+ while ((ret = rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_TRANSACTION_START_EXCL)) == SQLITE_BUSY &&
+ ++ntries <= max_tries) {
+ nanosleep(&sleep_ts, NULL);
+ }
+
+ msg_info_pool("absent tokenizer conf in %s, creating a new one",
+ bk->fname);
+ g_assert(stcf->clcf->tokenizer != NULL);
+ tokenizer = rspamd_stat_get_tokenizer(stcf->clcf->tokenizer->name);
+ g_assert(tokenizer != NULL);
+ tk_conf = tokenizer->get_config(pool, stcf->clcf->tokenizer, &sz);
+
+ /* Encode to base32 */
+ tok_conf_encoded = rspamd_encode_base32(tk_conf, sz, RSPAMD_BASE32_DEFAULT);
+
+ if (rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_SAVE_TOKENIZER,
+ (gint64) strlen(tok_conf_encoded),
+ tok_conf_encoded) != SQLITE_OK) {
+ sqlite3_close(bk->sqlite);
+ g_free(bk);
+ g_free(tok_conf_encoded);
+
+ return NULL;
+ }
+
+ rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT);
+ g_free(tok_conf_encoded);
+ }
+ else {
+ g_free(tk_conf);
+ }
+
+ return bk;
+}
+
+gpointer
+rspamd_sqlite3_init(struct rspamd_stat_ctx *ctx,
+ struct rspamd_config *cfg,
+ struct rspamd_statfile *st)
+{
+ struct rspamd_classifier_config *clf = st->classifier->cfg;
+ struct rspamd_statfile_config *stf = st->stcf;
+ const ucl_object_t *filenameo, *lang_enabled, *users_enabled;
+ const gchar *filename, *lua_script;
+ struct rspamd_stat_sqlite3_db *bk;
+ GError *err = NULL;
+
+ filenameo = ucl_object_lookup(stf->opts, "filename");
+ if (filenameo == NULL || ucl_object_type(filenameo) != UCL_STRING) {
+ filenameo = ucl_object_lookup(stf->opts, "path");
+ if (filenameo == NULL || ucl_object_type(filenameo) != UCL_STRING) {
+ msg_err_config("statfile %s has no filename defined", stf->symbol);
+ return NULL;
+ }
+ }
+
+ filename = ucl_object_tostring(filenameo);
+
+ if ((bk = rspamd_sqlite3_opendb(cfg->cfg_pool, stf, filename,
+ stf->opts, TRUE, &err)) == NULL) {
+ msg_err_config("cannot open sqlite3 db %s: %e", filename, err);
+ g_error_free(err);
+ return NULL;
+ }
+
+ bk->L = cfg->lua_state;
+
+ users_enabled = ucl_object_lookup_any(clf->opts, "per_user",
+ "users_enabled", NULL);
+ if (users_enabled != NULL) {
+ if (ucl_object_type(users_enabled) == UCL_BOOLEAN) {
+ bk->enable_users = ucl_object_toboolean(users_enabled);
+ bk->cbref_user = -1;
+ }
+ else if (ucl_object_type(users_enabled) == UCL_STRING) {
+ lua_script = ucl_object_tostring(users_enabled);
+
+ if (luaL_dostring(cfg->lua_state, lua_script) != 0) {
+ msg_err_config("cannot execute lua script for users "
+ "extraction: %s",
+ lua_tostring(cfg->lua_state, -1));
+ }
+ else {
+ if (lua_type(cfg->lua_state, -1) == LUA_TFUNCTION) {
+ bk->enable_users = TRUE;
+ bk->cbref_user = luaL_ref(cfg->lua_state,
+ LUA_REGISTRYINDEX);
+ }
+ else {
+ msg_err_config("lua script must return "
+ "function(task) and not %s",
+ lua_typename(cfg->lua_state, lua_type(
+ cfg->lua_state, -1)));
+ }
+ }
+ }
+ }
+ else {
+ bk->enable_users = FALSE;
+ }
+
+ lang_enabled = ucl_object_lookup_any(clf->opts,
+ "per_language", "languages_enabled", NULL);
+
+ if (lang_enabled != NULL) {
+ if (ucl_object_type(lang_enabled) == UCL_BOOLEAN) {
+ bk->enable_languages = ucl_object_toboolean(lang_enabled);
+ bk->cbref_language = -1;
+ }
+ else if (ucl_object_type(lang_enabled) == UCL_STRING) {
+ lua_script = ucl_object_tostring(lang_enabled);
+
+ if (luaL_dostring(cfg->lua_state, lua_script) != 0) {
+ msg_err_config(
+ "cannot execute lua script for languages "
+ "extraction: %s",
+ lua_tostring(cfg->lua_state, -1));
+ }
+ else {
+ if (lua_type(cfg->lua_state, -1) == LUA_TFUNCTION) {
+ bk->enable_languages = TRUE;
+ bk->cbref_language = luaL_ref(cfg->lua_state,
+ LUA_REGISTRYINDEX);
+ }
+ else {
+ msg_err_config("lua script must return "
+ "function(task) and not %s",
+ lua_typename(cfg->lua_state,
+ lua_type(cfg->lua_state, -1)));
+ }
+ }
+ }
+ }
+ else {
+ bk->enable_languages = FALSE;
+ }
+
+ if (bk->enable_languages) {
+ msg_info_config("enable per language statistics for %s",
+ stf->symbol);
+ }
+
+ if (bk->enable_users) {
+ msg_info_config("enable per users statistics for %s",
+ stf->symbol);
+ }
+
+
+ return (gpointer) bk;
+}
+
+void rspamd_sqlite3_close(gpointer p)
+{
+ struct rspamd_stat_sqlite3_db *bk = p;
+
+ if (bk->sqlite) {
+ if (bk->in_transaction) {
+ rspamd_sqlite3_run_prstmt(bk->pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT);
+ }
+
+ rspamd_sqlite3_close_prstmt(bk->sqlite, bk->prstmt);
+ sqlite3_close(bk->sqlite);
+ g_free(bk->fname);
+ g_free(bk);
+ }
+}
+
+gpointer
+rspamd_sqlite3_runtime(struct rspamd_task *task,
+ struct rspamd_statfile_config *stcf, gboolean learn, gpointer p, gint _id)
+{
+ struct rspamd_stat_sqlite3_rt *rt = NULL;
+ struct rspamd_stat_sqlite3_db *bk = p;
+
+ if (bk) {
+ rt = rspamd_mempool_alloc(task->task_pool, sizeof(*rt));
+ rt->db = bk;
+ rt->task = task;
+ rt->user_id = -1;
+ rt->lang_id = -1;
+ rt->cf = stcf;
+ }
+
+ return rt;
+}
+
+gboolean
+rspamd_sqlite3_process_tokens(struct rspamd_task *task,
+ GPtrArray *tokens,
+ gint id, gpointer p)
+{
+ struct rspamd_stat_sqlite3_db *bk;
+ struct rspamd_stat_sqlite3_rt *rt = p;
+ gint64 iv = 0;
+ guint i;
+ rspamd_token_t *tok;
+
+ g_assert(p != NULL);
+ g_assert(tokens != NULL);
+
+ bk = rt->db;
+
+ for (i = 0; i < tokens->len; i++) {
+ tok = g_ptr_array_index(tokens, i);
+
+ if (bk == NULL) {
+ /* Statfile is does not exist, so all values are zero */
+ tok->values[id] = 0.0f;
+ continue;
+ }
+
+ if (!bk->in_transaction) {
+ rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_TRANSACTION_START_DEF);
+ bk->in_transaction = TRUE;
+ }
+
+ if (rt->user_id == -1) {
+ if (bk->enable_users) {
+ rt->user_id = rspamd_sqlite3_get_user(bk, task, FALSE);
+ }
+ else {
+ rt->user_id = 0;
+ }
+ }
+
+ if (rt->lang_id == -1) {
+ if (bk->enable_languages) {
+ rt->lang_id = rspamd_sqlite3_get_language(bk, task, FALSE);
+ }
+ else {
+ rt->lang_id = 0;
+ }
+ }
+
+ if (bk->enable_languages || bk->enable_users) {
+ if (rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_GET_TOKEN_FULL,
+ tok->data, rt->user_id, rt->lang_id, &iv) == SQLITE_OK) {
+ tok->values[id] = iv;
+ }
+ else {
+ tok->values[id] = 0.0f;
+ }
+ }
+ else {
+ if (rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_GET_TOKEN_SIMPLE,
+ tok->data, &iv) == SQLITE_OK) {
+ tok->values[id] = iv;
+ }
+ else {
+ tok->values[id] = 0.0f;
+ }
+ }
+
+ if (rt->cf->is_spam) {
+ task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS;
+ }
+ else {
+ task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
+ }
+ }
+
+
+ return TRUE;
+}
+
+gboolean
+rspamd_sqlite3_finalize_process(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx)
+{
+ struct rspamd_stat_sqlite3_rt *rt = runtime;
+ struct rspamd_stat_sqlite3_db *bk;
+
+ g_assert(rt != NULL);
+ bk = rt->db;
+
+ if (bk->in_transaction) {
+ rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT);
+ bk->in_transaction = FALSE;
+ }
+
+ rt->lang_id = -1;
+ rt->user_id = -1;
+
+ return TRUE;
+}
+
+gboolean
+rspamd_sqlite3_learn_tokens(struct rspamd_task *task, GPtrArray *tokens,
+ gint id, gpointer p)
+{
+ struct rspamd_stat_sqlite3_db *bk;
+ struct rspamd_stat_sqlite3_rt *rt = p;
+ gint64 iv = 0;
+ guint i;
+ rspamd_token_t *tok;
+
+ g_assert(tokens != NULL);
+ g_assert(p != NULL);
+
+ bk = rt->db;
+
+ for (i = 0; i < tokens->len; i++) {
+ tok = g_ptr_array_index(tokens, i);
+ if (bk == NULL) {
+ /* Statfile is does not exist, so all values are zero */
+ return FALSE;
+ }
+
+ if (!bk->in_transaction) {
+ rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_TRANSACTION_START_IM);
+ bk->in_transaction = TRUE;
+ }
+
+ if (rt->user_id == -1) {
+ if (bk->enable_users) {
+ rt->user_id = rspamd_sqlite3_get_user(bk, task, TRUE);
+ }
+ else {
+ rt->user_id = 0;
+ }
+ }
+
+ if (rt->lang_id == -1) {
+ if (bk->enable_languages) {
+ rt->lang_id = rspamd_sqlite3_get_language(bk, task, TRUE);
+ }
+ else {
+ rt->lang_id = 0;
+ }
+ }
+
+ iv = tok->values[id];
+
+ if (rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_SET_TOKEN,
+ tok->data, rt->user_id, rt->lang_id, iv) != SQLITE_OK) {
+ rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_TRANSACTION_ROLLBACK);
+ bk->in_transaction = FALSE;
+
+ return FALSE;
+ }
+ }
+
+ return TRUE;
+}
+
+gboolean
+rspamd_sqlite3_finalize_learn(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx, GError **err)
+{
+ struct rspamd_stat_sqlite3_rt *rt = runtime;
+ struct rspamd_stat_sqlite3_db *bk;
+ gint wal_frames, wal_checkpointed, mode;
+
+ g_assert(rt != NULL);
+ bk = rt->db;
+
+ if (bk->in_transaction) {
+ rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT);
+ bk->in_transaction = FALSE;
+ }
+
+#ifdef SQLITE_OPEN_WAL
+#ifdef SQLITE_CHECKPOINT_TRUNCATE
+ mode = SQLITE_CHECKPOINT_TRUNCATE;
+#elif defined(SQLITE_CHECKPOINT_RESTART)
+ mode = SQLITE_CHECKPOINT_RESTART;
+#elif defined(SQLITE_CHECKPOINT_FULL)
+ mode = SQLITE_CHECKPOINT_FULL;
+#endif
+ /* Perform wal checkpoint (might be long) */
+ if (sqlite3_wal_checkpoint_v2(bk->sqlite,
+ NULL,
+ mode,
+ &wal_frames,
+ &wal_checkpointed) != SQLITE_OK) {
+ msg_warn_task("cannot commit checkpoint: %s",
+ sqlite3_errmsg(bk->sqlite));
+
+ g_set_error(err, rspamd_sqlite3_backend_quark(), 500,
+ "cannot commit checkpoint: %s",
+ sqlite3_errmsg(bk->sqlite));
+ return FALSE;
+ }
+#endif
+
+ return TRUE;
+}
+
+gulong
+rspamd_sqlite3_total_learns(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx)
+{
+ struct rspamd_stat_sqlite3_rt *rt = runtime;
+ struct rspamd_stat_sqlite3_db *bk;
+ guint64 res;
+
+ g_assert(rt != NULL);
+ bk = rt->db;
+ rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_GET_LEARNS, &res);
+
+ return res;
+}
+
+gulong
+rspamd_sqlite3_inc_learns(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx)
+{
+ struct rspamd_stat_sqlite3_rt *rt = runtime;
+ struct rspamd_stat_sqlite3_db *bk;
+ guint64 res;
+
+ g_assert(rt != NULL);
+ bk = rt->db;
+ rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_INC_LEARNS_LANG,
+ rt->lang_id);
+ rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_INC_LEARNS_USER,
+ rt->user_id);
+
+ if (bk->in_transaction) {
+ rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT);
+ bk->in_transaction = FALSE;
+ }
+
+ rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_GET_LEARNS, &res);
+
+ return res;
+}
+
+gulong
+rspamd_sqlite3_dec_learns(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx)
+{
+ struct rspamd_stat_sqlite3_rt *rt = runtime;
+ struct rspamd_stat_sqlite3_db *bk;
+ guint64 res;
+
+ g_assert(rt != NULL);
+ bk = rt->db;
+ rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_DEC_LEARNS_LANG,
+ rt->lang_id);
+ rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_DEC_LEARNS_USER,
+ rt->user_id);
+
+ if (bk->in_transaction) {
+ rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT);
+ bk->in_transaction = FALSE;
+ }
+
+ rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_GET_LEARNS, &res);
+
+ return res;
+}
+
+gulong
+rspamd_sqlite3_learns(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx)
+{
+ struct rspamd_stat_sqlite3_rt *rt = runtime;
+ struct rspamd_stat_sqlite3_db *bk;
+ guint64 res;
+
+ g_assert(rt != NULL);
+ bk = rt->db;
+ rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_GET_LEARNS, &res);
+
+ return res;
+}
+
+ucl_object_t *
+rspamd_sqlite3_get_stat(gpointer runtime,
+ gpointer ctx)
+{
+ ucl_object_t *res = NULL;
+ struct rspamd_stat_sqlite3_rt *rt = runtime;
+ struct rspamd_stat_sqlite3_db *bk;
+ rspamd_mempool_t *pool;
+ struct stat st;
+ gint64 rev;
+
+ g_assert(rt != NULL);
+ bk = rt->db;
+ pool = bk->pool;
+
+ (void) stat(bk->fname, &st);
+ rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_GET_LEARNS, &rev);
+
+ res = ucl_object_typed_new(UCL_OBJECT);
+ ucl_object_insert_key(res, ucl_object_fromint(rev), "revision",
+ 0, false);
+ ucl_object_insert_key(res, ucl_object_fromint(st.st_size), "size",
+ 0, false);
+ rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_NTOKENS, &rev);
+ ucl_object_insert_key(res, ucl_object_fromint(rev), "total", 0, false);
+ ucl_object_insert_key(res, ucl_object_fromint(rev), "used", 0, false);
+ ucl_object_insert_key(res, ucl_object_fromstring(rt->cf->symbol),
+ "symbol", 0, false);
+ ucl_object_insert_key(res, ucl_object_fromstring("sqlite3"),
+ "type", 0, false);
+ rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_NLANGUAGES, &rev);
+ ucl_object_insert_key(res, ucl_object_fromint(rev),
+ "languages", 0, false);
+ rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_NUSERS, &rev);
+ ucl_object_insert_key(res, ucl_object_fromint(rev),
+ "users", 0, false);
+
+ if (rt->cf->label) {
+ ucl_object_insert_key(res, ucl_object_fromstring(rt->cf->label),
+ "label", 0, false);
+ }
+
+ return res;
+}
+
+gpointer
+rspamd_sqlite3_load_tokenizer_config(gpointer runtime,
+ gsize *len)
+{
+ gpointer tk_conf, copied_conf;
+ guint64 sz;
+ struct rspamd_stat_sqlite3_rt *rt = runtime;
+ struct rspamd_stat_sqlite3_db *bk;
+
+ g_assert(rt != NULL);
+ bk = rt->db;
+
+ g_assert(rspamd_sqlite3_run_prstmt(rt->db->pool, bk->sqlite, bk->prstmt,
+ RSPAMD_STAT_BACKEND_LOAD_TOKENIZER, &sz, &tk_conf) == SQLITE_OK);
+ g_assert(sz > 0);
+ /*
+ * Here we can have either decoded or undecoded version of tokenizer config
+ * XXX: dirty hack to check if we have osb magic here
+ */
+ if (sz > 7 && memcmp(tk_conf, "osbtokv", 7) == 0) {
+ copied_conf = rspamd_mempool_alloc(rt->task->task_pool, sz);
+ memcpy(copied_conf, tk_conf, sz);
+ g_free(tk_conf);
+ }
+ else {
+ /* Need to decode */
+ copied_conf = rspamd_decode_base32(tk_conf, sz, len, RSPAMD_BASE32_DEFAULT);
+ g_free(tk_conf);
+ rspamd_mempool_add_destructor(rt->task->task_pool, g_free, copied_conf);
+ }
+
+ if (len) {
+ *len = sz;
+ }
+
+ return copied_conf;
+}
diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c
new file mode 100644
index 0000000..513db9a
--- /dev/null
+++ b/src/libstat/classifiers/bayes.c
@@ -0,0 +1,551 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/*
+ * Bayesian classifier
+ */
+#include "classifiers.h"
+#include "rspamd.h"
+#include "stat_internal.h"
+#include "math.h"
+
+#define msg_err_bayes(...) rspamd_default_log_function(G_LOG_LEVEL_CRITICAL, \
+ "bayes", task->task_pool->tag.uid, \
+ RSPAMD_LOG_FUNC, \
+ __VA_ARGS__)
+#define msg_warn_bayes(...) rspamd_default_log_function(G_LOG_LEVEL_WARNING, \
+ "bayes", task->task_pool->tag.uid, \
+ RSPAMD_LOG_FUNC, \
+ __VA_ARGS__)
+#define msg_info_bayes(...) rspamd_default_log_function(G_LOG_LEVEL_INFO, \
+ "bayes", task->task_pool->tag.uid, \
+ RSPAMD_LOG_FUNC, \
+ __VA_ARGS__)
+
+INIT_LOG_MODULE_PUBLIC(bayes)
+
+static inline GQuark
+bayes_error_quark(void)
+{
+ return g_quark_from_static_string("bayes-error");
+}
+
+/**
+ * Returns probability of chisquare > value with specified number of freedom
+ * degrees
+ * @param value value to test
+ * @param freedom_deg number of degrees of freedom
+ * @return
+ */
+static gdouble
+inv_chi_square(struct rspamd_task *task, gdouble value, gint freedom_deg)
+{
+ double prob, sum, m;
+ gint i;
+
+ errno = 0;
+ m = -value;
+ prob = exp(value);
+
+ if (errno == ERANGE) {
+ /*
+ * e^x where x is large *NEGATIVE* number is OK, so we have a very strong
+ * confidence that inv-chi-square is close to zero
+ */
+ msg_debug_bayes("exp overflow");
+
+ if (value < 0) {
+ return 0;
+ }
+ else {
+ return 1.0;
+ }
+ }
+
+ sum = prob;
+
+ msg_debug_bayes("m: %f, probability: %g", m, prob);
+
+ /*
+ * m is our confidence in class
+ * prob is e ^ x (small value since x is normally less than zero
+ * So we integrate over degrees of freedom and produce the total result
+ * from 1.0 (no confidence) to 0.0 (full confidence)
+ */
+ for (i = 1; i < freedom_deg; i++) {
+ prob *= m / (gdouble) i;
+ sum += prob;
+ msg_debug_bayes("i=%d, probability: %g, sum: %g", i, prob, sum);
+ }
+
+ return MIN(1.0, sum);
+}
+
+struct bayes_task_closure {
+ double ham_prob;
+ double spam_prob;
+ gdouble meta_skip_prob;
+ guint64 processed_tokens;
+ guint64 total_hits;
+ guint64 text_tokens;
+ struct rspamd_task *task;
+};
+
+/*
+ * Mathematically we use pow(complexity, complexity), where complexity is the
+ * window index
+ */
+static const double feature_weight[] = {0, 3125, 256, 27, 1, 0, 0, 0};
+
+#define PROB_COMBINE(prob, cnt, weight, assumed) (((weight) * (assumed) + (cnt) * (prob)) / ((weight) + (cnt)))
+/*
+ * In this callback we calculate local probabilities for tokens
+ */
+static void
+bayes_classify_token(struct rspamd_classifier *ctx,
+ rspamd_token_t *tok, struct bayes_task_closure *cl)
+{
+ guint i;
+ gint id;
+ guint spam_count = 0, ham_count = 0, total_count = 0;
+ struct rspamd_statfile *st;
+ struct rspamd_task *task;
+ const gchar *token_type = "txt";
+ double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob,
+ ham_prob, fw, w, val;
+
+ task = cl->task;
+
+#if 0
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_LUA_META) {
+ /* Ignore lua metatokens for now */
+ return;
+ }
+#endif
+
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_META && cl->meta_skip_prob > 0) {
+ val = rspamd_random_double_fast();
+
+ if (val <= cl->meta_skip_prob) {
+ if (tok->t1 && tok->t2) {
+ msg_debug_bayes(
+ "token(meta) %uL <%*s:%*s> probabilistically skipped",
+ tok->data,
+ (int) tok->t1->original.len, tok->t1->original.begin,
+ (int) tok->t2->original.len, tok->t2->original.begin);
+ }
+
+ return;
+ }
+ }
+
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ id = g_array_index(ctx->statfiles_ids, gint, i);
+ st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ g_assert(st != NULL);
+ val = tok->values[id];
+
+ if (val > 0) {
+ if (st->stcf->is_spam) {
+ spam_count += val;
+ }
+ else {
+ ham_count += val;
+ }
+
+ total_count += val;
+ cl->total_hits += val;
+ }
+ }
+
+ /* Probability for this token */
+ if (total_count >= ctx->cfg->min_token_hits) {
+ spam_freq = ((double) spam_count / MAX(1., (double) ctx->spam_learns));
+ ham_freq = ((double) ham_count / MAX(1., (double) ctx->ham_learns));
+ spam_prob = spam_freq / (spam_freq + ham_freq);
+ ham_prob = ham_freq / (spam_freq + ham_freq);
+
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UNIGRAM) {
+ fw = 1.0;
+ }
+ else {
+ fw = feature_weight[tok->window_idx %
+ G_N_ELEMENTS(feature_weight)];
+ }
+
+
+ w = (fw * total_count) / (1.0 + fw * total_count);
+
+ bayes_spam_prob = PROB_COMBINE(spam_prob, total_count, w, 0.5);
+
+ if ((bayes_spam_prob > 0.5 && bayes_spam_prob < 0.5 + ctx->cfg->min_prob_strength) ||
+ (bayes_spam_prob < 0.5 && bayes_spam_prob > 0.5 - ctx->cfg->min_prob_strength)) {
+ msg_debug_bayes(
+ "token %uL <%*s:%*s> skipped, probability not in range: %f",
+ tok->data,
+ (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
+ (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
+ bayes_spam_prob);
+
+ return;
+ }
+
+ bayes_ham_prob = PROB_COMBINE(ham_prob, total_count, w, 0.5);
+
+ cl->spam_prob += log(bayes_spam_prob);
+ cl->ham_prob += log(bayes_ham_prob);
+ cl->processed_tokens++;
+
+ if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
+ cl->text_tokens++;
+ }
+ else {
+ token_type = "meta";
+ }
+
+ if (tok->t1 && tok->t2) {
+ msg_debug_bayes("token(%s) %uL <%*s:%*s>: weight: %f, cf: %f, "
+ "total_count: %ud, "
+ "spam_count: %ud, ham_count: %ud,"
+ "spam_prob: %.3f, ham_prob: %.3f, "
+ "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, "
+ "current spam probability: %.3f, current ham probability: %.3f",
+ token_type,
+ tok->data,
+ (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
+ (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
+ fw, w, total_count, spam_count, ham_count,
+ spam_prob, ham_prob,
+ bayes_spam_prob, bayes_ham_prob,
+ cl->spam_prob, cl->ham_prob);
+ }
+ else {
+ msg_debug_bayes("token(%s) %uL <?:?>: weight: %f, cf: %f, "
+ "total_count: %ud, "
+ "spam_count: %ud, ham_count: %ud,"
+ "spam_prob: %.3f, ham_prob: %.3f, "
+ "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, "
+ "current spam probability: %.3f, current ham probability: %.3f",
+ token_type,
+ tok->data,
+ fw, w, total_count, spam_count, ham_count,
+ spam_prob, ham_prob,
+ bayes_spam_prob, bayes_ham_prob,
+ cl->spam_prob, cl->ham_prob);
+ }
+ }
+}
+
+
+gboolean
+bayes_init(struct rspamd_config *cfg,
+ struct ev_loop *ev_base,
+ struct rspamd_classifier *cl)
+{
+ cl->cfg->flags |= RSPAMD_FLAG_CLASSIFIER_INTEGER;
+
+ return TRUE;
+}
+
+void bayes_fin(struct rspamd_classifier *cl)
+{
+}
+
+gboolean
+bayes_classify(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task)
+{
+ double final_prob, h, s, *pprob;
+ gchar sumbuf[32];
+ struct rspamd_statfile *st = NULL;
+ struct bayes_task_closure cl;
+ rspamd_token_t *tok;
+ guint i, text_tokens = 0;
+ gint id;
+
+ g_assert(ctx != NULL);
+ g_assert(tokens != NULL);
+
+ memset(&cl, 0, sizeof(cl));
+ cl.task = task;
+
+ /* Check min learns */
+ if (ctx->cfg->min_learns > 0) {
+ if (ctx->ham_learns < ctx->cfg->min_learns) {
+ msg_info_task("not classified as ham. The ham class needs more "
+ "training samples. Currently: %ul; minimum %ud required",
+ ctx->ham_learns, ctx->cfg->min_learns);
+
+ return TRUE;
+ }
+ if (ctx->spam_learns < ctx->cfg->min_learns) {
+ msg_info_task("not classified as spam. The spam class needs more "
+ "training samples. Currently: %ul; minimum %ud required",
+ ctx->spam_learns, ctx->cfg->min_learns);
+
+ return TRUE;
+ }
+ }
+
+ for (i = 0; i < tokens->len; i++) {
+ tok = g_ptr_array_index(tokens, i);
+ if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
+ text_tokens++;
+ }
+ }
+
+ if (text_tokens == 0) {
+ msg_info_task("skipped classification as there are no text tokens. "
+ "Total tokens: %ud",
+ tokens->len);
+
+ return TRUE;
+ }
+
+ /*
+ * Skip some metatokens if we don't have enough text tokens
+ */
+ if (text_tokens > tokens->len - text_tokens) {
+ cl.meta_skip_prob = 0.0;
+ }
+ else {
+ cl.meta_skip_prob = 1.0 - text_tokens / tokens->len;
+ }
+
+ for (i = 0; i < tokens->len; i++) {
+ tok = g_ptr_array_index(tokens, i);
+
+ bayes_classify_token(ctx, tok, &cl);
+ }
+
+ if (cl.processed_tokens == 0) {
+ msg_info_bayes("no tokens found in bayes database "
+ "(%ud total tokens, %ud text tokens), ignore stats",
+ tokens->len, text_tokens);
+
+ return TRUE;
+ }
+
+ if (ctx->cfg->min_tokens > 0 &&
+ cl.text_tokens < (gint) (ctx->cfg->min_tokens * 0.1)) {
+ msg_info_bayes("ignore bayes probability since we have "
+ "found too few text tokens: %uL (of %ud checked), "
+ "at least %d required",
+ cl.text_tokens,
+ text_tokens,
+ (gint) (ctx->cfg->min_tokens * 0.1));
+
+ return TRUE;
+ }
+
+ if (cl.spam_prob > -300 && cl.ham_prob > -300) {
+ /* Fisher value is low enough to apply inv_chi_square */
+ h = 1 - inv_chi_square(task, cl.spam_prob, cl.processed_tokens);
+ s = 1 - inv_chi_square(task, cl.ham_prob, cl.processed_tokens);
+ }
+ else {
+ /* Use naive method */
+ if (cl.spam_prob < cl.ham_prob) {
+ h = (1.0 - exp(cl.spam_prob - cl.ham_prob)) /
+ (1.0 + exp(cl.spam_prob - cl.ham_prob));
+ s = 1.0 - h;
+ }
+ else {
+ s = (1.0 - exp(cl.ham_prob - cl.spam_prob)) /
+ (1.0 + exp(cl.ham_prob - cl.spam_prob));
+ h = 1.0 - s;
+ }
+ }
+
+ if (isfinite(s) && isfinite(h)) {
+ final_prob = (s + 1.0 - h) / 2.;
+ msg_debug_bayes(
+ "got ham probability %.2f -> %.2f and spam probability %.2f -> %.2f,"
+ " %L tokens processed of %ud total tokens;"
+ " %uL text tokens found of %ud text tokens)",
+ cl.ham_prob,
+ h,
+ cl.spam_prob,
+ s,
+ cl.processed_tokens,
+ tokens->len,
+ cl.text_tokens,
+ text_tokens);
+ }
+ else {
+ /*
+ * We have some overflow, hence we need to check which class
+ * is NaN
+ */
+ if (isfinite(h)) {
+ final_prob = 1.0;
+ msg_debug_bayes("spam class is full: no"
+ " ham samples");
+ }
+ else if (isfinite(s)) {
+ final_prob = 0.0;
+ msg_debug_bayes("ham class is full: no"
+ " spam samples");
+ }
+ else {
+ final_prob = 0.5;
+ msg_warn_bayes("spam and ham classes are both full");
+ }
+ }
+
+ pprob = rspamd_mempool_alloc(task->task_pool, sizeof(*pprob));
+ *pprob = final_prob;
+ rspamd_mempool_set_variable(task->task_pool, "bayes_prob", pprob, NULL);
+
+ if (cl.processed_tokens > 0 && fabs(final_prob - 0.5) > 0.05) {
+ /* Now we can have exactly one HAM and exactly one SPAM statfiles per classifier */
+ for (i = 0; i < ctx->statfiles_ids->len; i++) {
+ id = g_array_index(ctx->statfiles_ids, gint, i);
+ st = g_ptr_array_index(ctx->ctx->statfiles, id);
+
+ if (final_prob > 0.5 && st->stcf->is_spam) {
+ break;
+ }
+ else if (final_prob < 0.5 && !st->stcf->is_spam) {
+ break;
+ }
+ }
+
+ /* Correctly scale HAM */
+ if (final_prob < 0.5) {
+ final_prob = 1.0 - final_prob;
+ }
+
+ /*
+ * Bayes p is from 0.5 to 1.0, but confidence is from 0 to 1, so
+ * we need to rescale it to display correctly
+ */
+ rspamd_snprintf(sumbuf, sizeof(sumbuf), "%.2f%%",
+ (final_prob - 0.5) * 200.);
+ final_prob = rspamd_normalize_probability(final_prob, 0.5);
+ g_assert(st != NULL);
+
+ if (final_prob > 1 || final_prob < 0) {
+ msg_err_bayes("internal error: probability %f is outside of the "
+ "allowed range [0..1]",
+ final_prob);
+
+ if (final_prob > 1) {
+ final_prob = 1.0;
+ }
+ else {
+ final_prob = 0.0;
+ }
+ }
+
+ rspamd_task_insert_result(task,
+ st->stcf->symbol,
+ final_prob,
+ sumbuf);
+ }
+
+ return TRUE;
+}
+
+gboolean
+bayes_learn_spam(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task,
+ gboolean is_spam,
+ gboolean unlearn,
+ GError **err)
+{
+ guint i, j, total_cnt, spam_cnt, ham_cnt;
+ gint id;
+ struct rspamd_statfile *st;
+ rspamd_token_t *tok;
+ gboolean incrementing;
+
+ g_assert(ctx != NULL);
+ g_assert(tokens != NULL);
+
+ incrementing = ctx->cfg->flags & RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND;
+
+ for (i = 0; i < tokens->len; i++) {
+ total_cnt = 0;
+ spam_cnt = 0;
+ ham_cnt = 0;
+ tok = g_ptr_array_index(tokens, i);
+
+ for (j = 0; j < ctx->statfiles_ids->len; j++) {
+ id = g_array_index(ctx->statfiles_ids, gint, j);
+ st = g_ptr_array_index(ctx->ctx->statfiles, id);
+ g_assert(st != NULL);
+
+ if (!!st->stcf->is_spam == !!is_spam) {
+ if (incrementing) {
+ tok->values[id] = 1;
+ }
+ else {
+ tok->values[id]++;
+ }
+
+ total_cnt += tok->values[id];
+
+ if (st->stcf->is_spam) {
+ spam_cnt += tok->values[id];
+ }
+ else {
+ ham_cnt += tok->values[id];
+ }
+ }
+ else {
+ if (tok->values[id] > 0 && unlearn) {
+ /* Unlearning */
+ if (incrementing) {
+ tok->values[id] = -1;
+ }
+ else {
+ tok->values[id]--;
+ }
+
+ if (st->stcf->is_spam) {
+ spam_cnt += tok->values[id];
+ }
+ else {
+ ham_cnt += tok->values[id];
+ }
+ total_cnt += tok->values[id];
+ }
+ else if (incrementing) {
+ tok->values[id] = 0;
+ }
+ }
+ }
+
+ if (tok->t1 && tok->t2) {
+ msg_debug_bayes("token %uL <%*s:%*s>: window: %d, total_count: %d, "
+ "spam_count: %d, ham_count: %d",
+ tok->data,
+ (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
+ (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
+ tok->window_idx, total_cnt, spam_cnt, ham_cnt);
+ }
+ else {
+ msg_debug_bayes("token %uL <?:?>: window: %d, total_count: %d, "
+ "spam_count: %d, ham_count: %d",
+ tok->data,
+ tok->window_idx, total_cnt, spam_cnt, ham_cnt);
+ }
+ }
+
+ return TRUE;
+}
diff --git a/src/libstat/classifiers/classifiers.h b/src/libstat/classifiers/classifiers.h
new file mode 100644
index 0000000..949408c
--- /dev/null
+++ b/src/libstat/classifiers/classifiers.h
@@ -0,0 +1,109 @@
+/*
+ * Copyright 2023 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CLASSIFIERS_H
+#define CLASSIFIERS_H
+
+#include "config.h"
+#include "mem_pool.h"
+#include "contrib/libev/ev.h"
+
+#define RSPAMD_DEFAULT_CLASSIFIER "bayes"
+/* Consider this value as 0 */
+#define ALPHA 0.0001
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct rspamd_classifier_config;
+struct rspamd_task;
+struct rspamd_config;
+struct rspamd_classifier;
+
+struct token_node_s;
+
+struct rspamd_stat_classifier {
+ char *name;
+
+ gboolean (*init_func)(struct rspamd_config *cfg,
+ struct ev_loop *ev_base,
+ struct rspamd_classifier *cl);
+
+ gboolean (*classify_func)(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task);
+
+ gboolean (*learn_spam_func)(struct rspamd_classifier *ctx,
+ GPtrArray *input,
+ struct rspamd_task *task,
+ gboolean is_spam,
+ gboolean unlearn,
+ GError **err);
+
+ void (*fin_func)(struct rspamd_classifier *cl);
+};
+
+/* Bayes algorithm */
+gboolean bayes_init(struct rspamd_config *cfg,
+ struct ev_loop *ev_base,
+ struct rspamd_classifier *);
+
+gboolean bayes_classify(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task);
+
+gboolean bayes_learn_spam(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task,
+ gboolean is_spam,
+ gboolean unlearn,
+ GError **err);
+
+void bayes_fin(struct rspamd_classifier *);
+
+/* Generic lua classifier */
+gboolean lua_classifier_init(struct rspamd_config *cfg,
+ struct ev_loop *ev_base,
+ struct rspamd_classifier *);
+
+gboolean lua_classifier_classify(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task);
+
+gboolean lua_classifier_learn_spam(struct rspamd_classifier *ctx,
+ GPtrArray *tokens,
+ struct rspamd_task *task,
+ gboolean is_spam,
+ gboolean unlearn,
+ GError **err);
+
+extern gint rspamd_bayes_log_id;
+#define msg_debug_bayes(...) rspamd_conditional_debug_fast(NULL, task->from_addr, \
+ rspamd_bayes_log_id, "bayes", task->task_pool->tag.uid, \
+ G_STRFUNC, \
+ __VA_ARGS__)
+#define msg_debug_bayes_cfg(...) rspamd_conditional_debug_fast(NULL, NULL, \
+ rspamd_bayes_log_id, "bayes", cfg->cfg_pool->tag.uid, \
+ G_STRFUNC, \
+ __VA_ARGS__)
+
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/src/libstat/classifiers/lua_classifier.c b/src/libstat/classifiers/lua_classifier.c
new file mode 100644
index 0000000..b74330d
--- /dev/null
+++ b/src/libstat/classifiers/lua_classifier.c
@@ -0,0 +1,237 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "classifiers.h"
+#include "cfg_file.h"
+#include "stat_internal.h"
+#include "lua/lua_common.h"
+
+struct rspamd_lua_classifier_ctx {
+ gchar *name;
+ gint classify_ref;
+ gint learn_ref;
+};
+
+static GHashTable *lua_classifiers = NULL;
+
+#define msg_err_luacl(...) rspamd_default_log_function(G_LOG_LEVEL_CRITICAL, \
+ "luacl", task->task_pool->tag.uid, \
+ RSPAMD_LOG_FUNC, \
+ __VA_ARGS__)
+#define msg_warn_luacl(...) rspamd_default_log_function(G_LOG_LEVEL_WARNING, \
+ "luacl", task->task_pool->tag.uid, \
+ RSPAMD_LOG_FUNC, \
+ __VA_ARGS__)
+#define msg_info_luacl(...) rspamd_default_log_function(G_LOG_LEVEL_INFO, \
+ "luacl", task->task_pool->tag.uid, \
+ RSPAMD_LOG_FUNC, \
+ __VA_ARGS__)
+#define msg_debug_luacl(...) rspamd_conditional_debug_fast(NULL, task->from_addr, \
+ rspamd_luacl_log_id, "luacl", task->task_pool->tag.uid, \
+ RSPAMD_LOG_FUNC, \
+ __VA_ARGS__)
+
+INIT_LOG_MODULE(luacl)
+
+gboolean
+lua_classifier_init(struct rspamd_config *cfg,
+ struct ev_loop *ev_base,
+ struct rspamd_classifier *cl)
+{
+ struct rspamd_lua_classifier_ctx *ctx;
+ lua_State *L = cl->ctx->cfg->lua_state;
+ gint cb_classify = -1, cb_learn = -1;
+
+ if (lua_classifiers == NULL) {
+ lua_classifiers = g_hash_table_new_full(rspamd_strcase_hash,
+ rspamd_strcase_equal, g_free, g_free);
+ }
+
+ ctx = g_hash_table_lookup(lua_classifiers, cl->subrs->name);
+
+ if (ctx != NULL) {
+ msg_err_config("duplicate lua classifier definition: %s",
+ cl->subrs->name);
+
+ return FALSE;
+ }
+
+ lua_getglobal(L, "rspamd_classifiers");
+ if (lua_type(L, -1) != LUA_TTABLE) {
+ msg_err_config("cannot register classifier %s: no rspamd_classifier global",
+ cl->subrs->name);
+ lua_pop(L, 1);
+
+ return FALSE;
+ }
+
+ lua_pushstring(L, cl->subrs->name);
+ lua_gettable(L, -2);
+
+ if (lua_type(L, -1) != LUA_TTABLE) {
+ msg_err_config("cannot register classifier %s: bad lua type: %s",
+ cl->subrs->name, lua_typename(L, lua_type(L, -1)));
+ lua_pop(L, 2);
+
+ return FALSE;
+ }
+
+ lua_pushstring(L, "classify");
+ lua_gettable(L, -2);
+
+ if (lua_type(L, -1) != LUA_TFUNCTION) {
+ msg_err_config("cannot register classifier %s: bad lua type for classify: %s",
+ cl->subrs->name, lua_typename(L, lua_type(L, -1)));
+ lua_pop(L, 3);
+
+ return FALSE;
+ }
+
+ cb_classify = luaL_ref(L, LUA_REGISTRYINDEX);
+
+ lua_pushstring(L, "learn");
+ lua_gettable(L, -2);
+
+ if (lua_type(L, -1) != LUA_TFUNCTION) {
+ msg_err_config("cannot register classifier %s: bad lua type for learn: %s",
+ cl->subrs->name, lua_typename(L, lua_type(L, -1)));
+ lua_pop(L, 3);
+
+ return FALSE;
+ }
+
+ cb_learn = luaL_ref(L, LUA_REGISTRYINDEX);
+ lua_pop(L, 2); /* Table + global */
+
+ ctx = g_malloc0(sizeof(*ctx));
+ ctx->name = g_strdup(cl->subrs->name);
+ ctx->classify_ref = cb_classify;
+ ctx->learn_ref = cb_learn;
+ cl->cfg->flags |= RSPAMD_FLAG_CLASSIFIER_NO_BACKEND;
+ g_hash_table_insert(lua_classifiers, ctx->name, ctx);
+
+ return TRUE;
+}
+gboolean
+lua_classifier_classify(struct rspamd_classifier *cl,
+ GPtrArray *tokens,
+ struct rspamd_task *task)
+{
+ struct rspamd_lua_classifier_ctx *ctx;
+ struct rspamd_task **ptask;
+ struct rspamd_classifier_config **pcfg;
+ lua_State *L;
+ rspamd_token_t *tok;
+ guint i;
+ guint64 v;
+
+ ctx = g_hash_table_lookup(lua_classifiers, cl->subrs->name);
+ g_assert(ctx != NULL);
+ L = task->cfg->lua_state;
+
+ lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->classify_ref);
+ ptask = lua_newuserdata(L, sizeof(*ptask));
+ *ptask = task;
+ rspamd_lua_setclass(L, "rspamd{task}", -1);
+ pcfg = lua_newuserdata(L, sizeof(*pcfg));
+ *pcfg = cl->cfg;
+ rspamd_lua_setclass(L, "rspamd{classifier}", -1);
+
+ lua_createtable(L, tokens->len, 0);
+
+ for (i = 0; i < tokens->len; i++) {
+ tok = g_ptr_array_index(tokens, i);
+ v = tok->data;
+ lua_createtable(L, 3, 0);
+ /* High word, low word, order */
+ lua_pushinteger(L, (guint32) (v >> 32));
+ lua_rawseti(L, -2, 1);
+ lua_pushinteger(L, (guint32) (v));
+ lua_rawseti(L, -2, 2);
+ lua_pushinteger(L, tok->window_idx);
+ lua_rawseti(L, -2, 3);
+ lua_rawseti(L, -2, i + 1);
+ }
+
+ if (lua_pcall(L, 3, 0, 0) != 0) {
+ msg_err_luacl("error running classify function for %s: %s", ctx->name,
+ lua_tostring(L, -1));
+ lua_pop(L, 1);
+
+ return FALSE;
+ }
+
+ return TRUE;
+}
+
+gboolean
+lua_classifier_learn_spam(struct rspamd_classifier *cl,
+ GPtrArray *tokens,
+ struct rspamd_task *task,
+ gboolean is_spam,
+ gboolean unlearn,
+ GError **err)
+{
+ struct rspamd_lua_classifier_ctx *ctx;
+ struct rspamd_task **ptask;
+ struct rspamd_classifier_config **pcfg;
+ lua_State *L;
+ rspamd_token_t *tok;
+ guint i;
+ guint64 v;
+
+ ctx = g_hash_table_lookup(lua_classifiers, cl->subrs->name);
+ g_assert(ctx != NULL);
+ L = task->cfg->lua_state;
+
+ lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref);
+ ptask = lua_newuserdata(L, sizeof(*ptask));
+ *ptask = task;
+ rspamd_lua_setclass(L, "rspamd{task}", -1);
+ pcfg = lua_newuserdata(L, sizeof(*pcfg));
+ *pcfg = cl->cfg;
+ rspamd_lua_setclass(L, "rspamd{classifier}", -1);
+
+ lua_createtable(L, tokens->len, 0);
+
+ for (i = 0; i < tokens->len; i++) {
+ tok = g_ptr_array_index(tokens, i);
+ v = 0;
+ v = tok->data;
+ lua_createtable(L, 3, 0);
+ /* High word, low word, order */
+ lua_pushinteger(L, (guint32) (v >> 32));
+ lua_rawseti(L, -2, 1);
+ lua_pushinteger(L, (guint32) (v));
+ lua_rawseti(L, -2, 2);
+ lua_pushinteger(L, tok->window_idx);
+ lua_rawseti(L, -2, 3);
+ lua_rawseti(L, -2, i + 1);
+ }
+
+ lua_pushboolean(L, is_spam);
+ lua_pushboolean(L, unlearn);
+
+ if (lua_pcall(L, 5, 0, 0) != 0) {
+ msg_err_luacl("error running learn function for %s: %s", ctx->name,
+ lua_tostring(L, -1));
+ lua_pop(L, 1);
+
+ return FALSE;
+ }
+
+ return TRUE;
+}
diff --git a/src/libstat/learn_cache/learn_cache.h b/src/libstat/learn_cache/learn_cache.h
new file mode 100644
index 0000000..11a66fc
--- /dev/null
+++ b/src/libstat/learn_cache/learn_cache.h
@@ -0,0 +1,79 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef LEARN_CACHE_H_
+#define LEARN_CACHE_H_
+
+#include "config.h"
+#include "ucl.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define RSPAMD_DEFAULT_CACHE "sqlite3"
+
+struct rspamd_task;
+struct rspamd_stat_ctx;
+struct rspamd_config;
+struct rspamd_statfile;
+
+struct rspamd_stat_cache {
+ const char *name;
+
+ gpointer (*init)(struct rspamd_stat_ctx *ctx,
+ struct rspamd_config *cfg,
+ struct rspamd_statfile *st,
+ const ucl_object_t *cf);
+
+ gpointer (*runtime)(struct rspamd_task *task,
+ gpointer ctx, gboolean learn);
+
+ gint (*check)(struct rspamd_task *task,
+ gboolean is_spam,
+ gpointer runtime);
+
+ gint (*learn)(struct rspamd_task *task,
+ gboolean is_spam,
+ gpointer runtime);
+
+ void (*close)(gpointer ctx);
+
+ gpointer ctx;
+};
+
+#define RSPAMD_STAT_CACHE_DEF(name) \
+ gpointer rspamd_stat_cache_##name##_init(struct rspamd_stat_ctx *ctx, \
+ struct rspamd_config *cfg, \
+ struct rspamd_statfile *st, \
+ const ucl_object_t *cf); \
+ gpointer rspamd_stat_cache_##name##_runtime(struct rspamd_task *task, \
+ gpointer ctx, gboolean learn); \
+ gint rspamd_stat_cache_##name##_check(struct rspamd_task *task, \
+ gboolean is_spam, \
+ gpointer runtime); \
+ gint rspamd_stat_cache_##name##_learn(struct rspamd_task *task, \
+ gboolean is_spam, \
+ gpointer runtime); \
+ void rspamd_stat_cache_##name##_close(gpointer ctx)
+
+RSPAMD_STAT_CACHE_DEF(sqlite3);
+RSPAMD_STAT_CACHE_DEF(redis);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* LEARN_CACHE_H_ */
diff --git a/src/libstat/learn_cache/redis_cache.cxx b/src/libstat/learn_cache/redis_cache.cxx
new file mode 100644
index 0000000..0be56bc
--- /dev/null
+++ b/src/libstat/learn_cache/redis_cache.cxx
@@ -0,0 +1,254 @@
+/*
+ * Copyright 2024 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "config.h"
+// Include early to avoid `extern "C"` issues
+#include "lua/lua_common.h"
+#include "learn_cache.h"
+#include "rspamd.h"
+#include "stat_api.h"
+#include "stat_internal.h"
+#include "cryptobox.h"
+#include "ucl.h"
+#include "libmime/message.h"
+
+#include <memory>
+
+struct rspamd_redis_cache_ctx {
+ lua_State *L;
+ struct rspamd_statfile_config *stcf;
+ int check_ref = -1;
+ int learn_ref = -1;
+
+ rspamd_redis_cache_ctx() = delete;
+ explicit rspamd_redis_cache_ctx(lua_State *L)
+ : L(L)
+ {
+ }
+
+ ~rspamd_redis_cache_ctx()
+ {
+ if (check_ref != -1) {
+ luaL_unref(L, LUA_REGISTRYINDEX, check_ref);
+ }
+
+ if (learn_ref != -1) {
+ luaL_unref(L, LUA_REGISTRYINDEX, learn_ref);
+ }
+ }
+};
+
+static void
+rspamd_stat_cache_redis_generate_id(struct rspamd_task *task)
+{
+ rspamd_cryptobox_hash_state_t st;
+ rspamd_cryptobox_hash_init(&st, nullptr, 0);
+
+ const auto *user = (const char *) rspamd_mempool_get_variable(task->task_pool, "stat_user");
+ /* Use dedicated hash space for per users cache */
+ if (user != nullptr) {
+ rspamd_cryptobox_hash_update(&st, (const unsigned char *) user, strlen(user));
+ }
+
+ for (auto i = 0; i < task->tokens->len; i++) {
+ const auto *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, i);
+ rspamd_cryptobox_hash_update(&st, (const unsigned char *) &tok->data,
+ sizeof(tok->data));
+ }
+
+ guchar out[rspamd_cryptobox_HASHBYTES];
+ rspamd_cryptobox_hash_final(&st, out);
+
+ auto *b32out = rspamd_mempool_alloc_array_type(task->task_pool,
+ sizeof(out) * 8 / 5 + 3, char);
+ auto out_sz = rspamd_encode_base32_buf(out, sizeof(out), b32out,
+ sizeof(out) * 8 / 5 + 2, RSPAMD_BASE32_DEFAULT);
+
+ if (out_sz > 0) {
+ /* Zero terminate */
+ b32out[out_sz] = '\0';
+ rspamd_mempool_set_variable(task->task_pool, "words_hash", b32out, nullptr);
+ }
+}
+
+gpointer
+rspamd_stat_cache_redis_init(struct rspamd_stat_ctx *ctx,
+ struct rspamd_config *cfg,
+ struct rspamd_statfile *st,
+ const ucl_object_t *cf)
+{
+ std::unique_ptr<rspamd_redis_cache_ctx> cache_ctx = std::make_unique<rspamd_redis_cache_ctx>(RSPAMD_LUA_CFG_STATE(cfg));
+
+ auto *L = RSPAMD_LUA_CFG_STATE(cfg);
+ lua_settop(L, 0);
+
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ auto err_idx = lua_gettop(L);
+
+ /* Obtain function */
+ if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_cache")) {
+ msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_cache");
+ lua_settop(L, err_idx - 1);
+
+ return nullptr;
+ }
+
+ /* Push arguments */
+ ucl_object_push_lua(L, st->classifier->cfg->opts, false);
+ ucl_object_push_lua(L, st->stcf->opts, false);
+
+ if (lua_pcall(L, 2, 2, err_idx) != 0) {
+ msg_err("call to lua_bayes_init_cache "
+ "script failed: %s",
+ lua_tostring(L, -1));
+ lua_settop(L, err_idx - 1);
+
+ return nullptr;
+ }
+
+ /*
+ * Results are in the stack:
+ * top - 1 - check function (idx = -2)
+ * top - learn function (idx = -1)
+ */
+ lua_pushvalue(L, -2);
+ cache_ctx->check_ref = luaL_ref(L, LUA_REGISTRYINDEX);
+
+ lua_pushvalue(L, -1);
+ cache_ctx->learn_ref = luaL_ref(L, LUA_REGISTRYINDEX);
+
+ lua_settop(L, err_idx - 1);
+
+ return (gpointer) cache_ctx.release();
+}
+
+gpointer
+rspamd_stat_cache_redis_runtime(struct rspamd_task *task,
+ gpointer c, gboolean learn)
+{
+ auto *ctx = (struct rspamd_redis_cache_ctx *) c;
+
+ if (task->tokens == nullptr || task->tokens->len == 0) {
+ return nullptr;
+ }
+
+ if (!learn) {
+ /* On check, we produce words_hash variable, on learn it is guaranteed to be set */
+ rspamd_stat_cache_redis_generate_id(task);
+ }
+
+ return (void *) ctx;
+}
+
+static gint
+rspamd_stat_cache_checked(lua_State *L)
+{
+ auto *task = lua_check_task(L, 1);
+ auto res = lua_toboolean(L, 2);
+
+ if (res) {
+ auto val = lua_tointeger(L, 3);
+
+ if ((val > 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM)) ||
+ (val <= 0 && (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM))) {
+ /* Already learned */
+ msg_info_task("<%s> has been already "
+ "learned as %s, ignore it",
+ MESSAGE_FIELD(task, message_id),
+ (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) ? "spam" : "ham");
+ task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
+ }
+ else if (val != 0) {
+ /* Unlearn flag */
+ task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
+ }
+ }
+
+ /* Ignore errors for now, as we can do nothing about them at the moment */
+
+ return 0;
+}
+
+gint rspamd_stat_cache_redis_check(struct rspamd_task *task,
+ gboolean is_spam,
+ gpointer runtime)
+{
+ auto *ctx = (struct rspamd_redis_cache_ctx *) runtime;
+ auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash");
+
+ if (h == nullptr) {
+ return RSPAMD_LEARN_IGNORE;
+ }
+
+ auto *L = ctx->L;
+
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ gint err_idx = lua_gettop(L);
+
+ /* Function arguments */
+ lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->check_ref);
+ rspamd_lua_task_push(L, task);
+ lua_pushstring(L, h);
+
+ lua_pushcclosure(L, &rspamd_stat_cache_checked, 0);
+
+ if (lua_pcall(L, 3, 0, err_idx) != 0) {
+ msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
+ lua_settop(L, err_idx - 1);
+ return RSPAMD_LEARN_IGNORE;
+ }
+
+ /* We need to return OK every time */
+ return RSPAMD_LEARN_OK;
+}
+
+gint rspamd_stat_cache_redis_learn(struct rspamd_task *task,
+ gboolean is_spam,
+ gpointer runtime)
+{
+ auto *ctx = (struct rspamd_redis_cache_ctx *) runtime;
+
+ if (rspamd_session_blocked(task->s)) {
+ return RSPAMD_LEARN_IGNORE;
+ }
+
+ auto *h = (char *) rspamd_mempool_get_variable(task->task_pool, "words_hash");
+ g_assert(h != nullptr);
+ auto *L = ctx->L;
+
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ gint err_idx = lua_gettop(L);
+
+ /* Function arguments */
+ lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref);
+ rspamd_lua_task_push(L, task);
+ lua_pushstring(L, h);
+ lua_pushboolean(L, is_spam);
+
+ if (lua_pcall(L, 3, 0, err_idx) != 0) {
+ msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
+ lua_settop(L, err_idx - 1);
+ return RSPAMD_LEARN_IGNORE;
+ }
+
+ /* We need to return OK every time */
+ return RSPAMD_LEARN_OK;
+}
+
+void rspamd_stat_cache_redis_close(gpointer c)
+{
+ auto *ctx = (struct rspamd_redis_cache_ctx *) c;
+ delete ctx;
+}
diff --git a/src/libstat/learn_cache/sqlite3_cache.c b/src/libstat/learn_cache/sqlite3_cache.c
new file mode 100644
index 0000000..d8ad20a
--- /dev/null
+++ b/src/libstat/learn_cache/sqlite3_cache.c
@@ -0,0 +1,274 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "config.h"
+#include "learn_cache.h"
+#include "rspamd.h"
+#include "stat_api.h"
+#include "stat_internal.h"
+#include "cryptobox.h"
+#include "ucl.h"
+#include "fstring.h"
+#include "message.h"
+#include "libutil/sqlite_utils.h"
+
+static const char *create_tables_sql =
+ ""
+ "CREATE TABLE IF NOT EXISTS learns("
+ "id INTEGER PRIMARY KEY,"
+ "flag INTEGER NOT NULL,"
+ "digest TEXT NOT NULL);"
+ "CREATE UNIQUE INDEX IF NOT EXISTS d ON learns(digest);"
+ "";
+
+#define SQLITE_CACHE_PATH RSPAMD_DBDIR "/learn_cache.sqlite"
+
+enum rspamd_stat_sqlite3_stmt_idx {
+ RSPAMD_STAT_CACHE_TRANSACTION_START_IM = 0,
+ RSPAMD_STAT_CACHE_TRANSACTION_START_DEF,
+ RSPAMD_STAT_CACHE_TRANSACTION_COMMIT,
+ RSPAMD_STAT_CACHE_TRANSACTION_ROLLBACK,
+ RSPAMD_STAT_CACHE_GET_LEARN,
+ RSPAMD_STAT_CACHE_ADD_LEARN,
+ RSPAMD_STAT_CACHE_UPDATE_LEARN,
+ RSPAMD_STAT_CACHE_MAX
+};
+
+static struct rspamd_sqlite3_prstmt prepared_stmts[RSPAMD_STAT_CACHE_MAX] =
+ {
+ {.idx = RSPAMD_STAT_CACHE_TRANSACTION_START_IM,
+ .sql = "BEGIN IMMEDIATE TRANSACTION;",
+ .args = "",
+ .stmt = NULL,
+ .result = SQLITE_DONE,
+ .ret = ""},
+ {.idx = RSPAMD_STAT_CACHE_TRANSACTION_START_DEF,
+ .sql = "BEGIN DEFERRED TRANSACTION;",
+ .args = "",
+ .stmt = NULL,
+ .result = SQLITE_DONE,
+ .ret = ""},
+ {.idx = RSPAMD_STAT_CACHE_TRANSACTION_COMMIT,
+ .sql = "COMMIT;",
+ .args = "",
+ .stmt = NULL,
+ .result = SQLITE_DONE,
+ .ret = ""},
+ {.idx = RSPAMD_STAT_CACHE_TRANSACTION_ROLLBACK,
+ .sql = "ROLLBACK;",
+ .args = "",
+ .stmt = NULL,
+ .result = SQLITE_DONE,
+ .ret = ""},
+ {.idx = RSPAMD_STAT_CACHE_GET_LEARN,
+ .sql = "SELECT flag FROM learns WHERE digest=?1",
+ .args = "V",
+ .stmt = NULL,
+ .result = SQLITE_ROW,
+ .ret = "I"},
+ {.idx = RSPAMD_STAT_CACHE_ADD_LEARN,
+ .sql = "INSERT INTO learns(digest, flag) VALUES (?1, ?2);",
+ .args = "VI",
+ .stmt = NULL,
+ .result = SQLITE_DONE,
+ .ret = ""},
+ {.idx = RSPAMD_STAT_CACHE_UPDATE_LEARN,
+ .sql = "UPDATE learns SET flag=?1 WHERE digest=?2;",
+ .args = "IV",
+ .stmt = NULL,
+ .result = SQLITE_DONE,
+ .ret = ""}};
+
+struct rspamd_stat_sqlite3_ctx {
+ sqlite3 *db;
+ GArray *prstmt;
+};
+
+gpointer
+rspamd_stat_cache_sqlite3_init(struct rspamd_stat_ctx *ctx,
+ struct rspamd_config *cfg,
+ struct rspamd_statfile *st,
+ const ucl_object_t *cf)
+{
+ struct rspamd_stat_sqlite3_ctx *new = NULL;
+ const ucl_object_t *elt;
+ gchar dbpath[PATH_MAX];
+ const gchar *path = SQLITE_CACHE_PATH;
+ sqlite3 *sqlite;
+ GError *err = NULL;
+
+ if (cf) {
+ elt = ucl_object_lookup_any(cf, "path", "file", NULL);
+
+ if (elt != NULL) {
+ path = ucl_object_tostring(elt);
+ }
+ }
+
+ rspamd_snprintf(dbpath, sizeof(dbpath), "%s", path);
+
+ sqlite = rspamd_sqlite3_open_or_create(cfg->cfg_pool,
+ dbpath, create_tables_sql, 0, &err);
+
+ if (sqlite == NULL) {
+ msg_err("cannot open sqlite3 cache: %e", err);
+ g_error_free(err);
+ err = NULL;
+ }
+ else {
+ new = g_malloc0(sizeof(*new));
+ new->db = sqlite;
+ new->prstmt = rspamd_sqlite3_init_prstmt(sqlite, prepared_stmts,
+ RSPAMD_STAT_CACHE_MAX, &err);
+
+ if (new->prstmt == NULL) {
+ msg_err("cannot open sqlite3 cache: %e", err);
+ g_error_free(err);
+ err = NULL;
+ sqlite3_close(sqlite);
+ g_free(new);
+ new = NULL;
+ }
+ }
+
+ return new;
+}
+
+gpointer
+rspamd_stat_cache_sqlite3_runtime(struct rspamd_task *task,
+ gpointer ctx, gboolean learn)
+{
+ /* No need of runtime for this type of classifier */
+ return ctx;
+}
+
+gint rspamd_stat_cache_sqlite3_check(struct rspamd_task *task,
+ gboolean is_spam,
+ gpointer runtime)
+{
+ struct rspamd_stat_sqlite3_ctx *ctx = runtime;
+ rspamd_cryptobox_hash_state_t st;
+ rspamd_token_t *tok;
+ guchar *out;
+ gchar *user = NULL;
+ guint i;
+ gint rc;
+ gint64 flag;
+
+ if (task->tokens == NULL || task->tokens->len == 0) {
+ return RSPAMD_LEARN_IGNORE;
+ }
+
+ if (ctx != NULL && ctx->db != NULL) {
+ out = rspamd_mempool_alloc(task->task_pool, rspamd_cryptobox_HASHBYTES);
+
+ rspamd_cryptobox_hash_init(&st, NULL, 0);
+
+ user = rspamd_mempool_get_variable(task->task_pool, "stat_user");
+ /* Use dedicated hash space for per users cache */
+ if (user != NULL) {
+ rspamd_cryptobox_hash_update(&st, user, strlen(user));
+ }
+
+ for (i = 0; i < task->tokens->len; i++) {
+ tok = g_ptr_array_index(task->tokens, i);
+ rspamd_cryptobox_hash_update(&st, (guchar *) &tok->data,
+ sizeof(tok->data));
+ }
+
+ rspamd_cryptobox_hash_final(&st, out);
+
+ rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_TRANSACTION_START_DEF);
+ rc = rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_GET_LEARN, (gint64) rspamd_cryptobox_HASHBYTES,
+ out, &flag);
+ rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_TRANSACTION_COMMIT);
+
+ /* Save hash into variables */
+ rspamd_mempool_set_variable(task->task_pool, "words_hash", out, NULL);
+
+ if (rc == SQLITE_OK) {
+ /* We have some existing record in the table */
+ if (!!flag == !!is_spam) {
+ /* Already learned */
+ msg_warn_task("already seen stat hash: %*bs",
+ rspamd_cryptobox_HASHBYTES, out);
+ return RSPAMD_LEARN_IGNORE;
+ }
+ else {
+ /* Need to relearn */
+ return RSPAMD_LEARN_UNLEARN;
+ }
+ }
+ }
+
+ return RSPAMD_LEARN_OK;
+}
+
+gint rspamd_stat_cache_sqlite3_learn(struct rspamd_task *task,
+ gboolean is_spam,
+ gpointer runtime)
+{
+ struct rspamd_stat_sqlite3_ctx *ctx = runtime;
+ gboolean unlearn = !!(task->flags & RSPAMD_TASK_FLAG_UNLEARN);
+ guchar *h;
+ gint64 flag;
+
+ h = rspamd_mempool_get_variable(task->task_pool, "words_hash");
+
+ if (h == NULL) {
+ return RSPAMD_LEARN_IGNORE;
+ }
+
+ flag = !!is_spam ? 1 : 0;
+
+ if (!unlearn) {
+ /* Insert result new id */
+ rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_TRANSACTION_START_IM);
+ rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_ADD_LEARN,
+ (gint64) rspamd_cryptobox_HASHBYTES, h, flag);
+ rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_TRANSACTION_COMMIT);
+ }
+ else {
+ rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_TRANSACTION_START_IM);
+ rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_UPDATE_LEARN,
+ flag,
+ (gint64) rspamd_cryptobox_HASHBYTES, h);
+ rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_TRANSACTION_COMMIT);
+ }
+
+ rspamd_sqlite3_sync(ctx->db, NULL, NULL);
+
+ return RSPAMD_LEARN_OK;
+}
+
+void rspamd_stat_cache_sqlite3_close(gpointer c)
+{
+ struct rspamd_stat_sqlite3_ctx *ctx = (struct rspamd_stat_sqlite3_ctx *) c;
+
+ if (ctx != NULL) {
+ rspamd_sqlite3_close_prstmt(ctx->db, ctx->prstmt);
+ sqlite3_close(ctx->db);
+ g_free(ctx);
+ }
+}
diff --git a/src/libstat/stat_api.h b/src/libstat/stat_api.h
new file mode 100644
index 0000000..1badb20
--- /dev/null
+++ b/src/libstat/stat_api.h
@@ -0,0 +1,147 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef STAT_API_H_
+#define STAT_API_H_
+
+#include "config.h"
+#include "task.h"
+#include "lua/lua_common.h"
+#include "contrib/libev/ev.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * @file stat_api.h
+ * High level statistics API
+ */
+
+#define RSPAMD_STAT_TOKEN_FLAG_TEXT (1u << 0)
+#define RSPAMD_STAT_TOKEN_FLAG_META (1u << 1)
+#define RSPAMD_STAT_TOKEN_FLAG_LUA_META (1u << 2)
+#define RSPAMD_STAT_TOKEN_FLAG_EXCEPTION (1u << 3)
+#define RSPAMD_STAT_TOKEN_FLAG_HEADER (1u << 4)
+#define RSPAMD_STAT_TOKEN_FLAG_UNIGRAM (1u << 5)
+#define RSPAMD_STAT_TOKEN_FLAG_UTF (1u << 6)
+#define RSPAMD_STAT_TOKEN_FLAG_NORMALISED (1u << 7)
+#define RSPAMD_STAT_TOKEN_FLAG_STEMMED (1u << 8)
+#define RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE (1u << 9)
+#define RSPAMD_STAT_TOKEN_FLAG_STOP_WORD (1u << 10)
+#define RSPAMD_STAT_TOKEN_FLAG_SKIPPED (1u << 11)
+#define RSPAMD_STAT_TOKEN_FLAG_INVISIBLE_SPACES (1u << 12)
+#define RSPAMD_STAT_TOKEN_FLAG_EMOJI (1u << 13)
+
+typedef struct rspamd_stat_token_s {
+ rspamd_ftok_t original; /* utf8 raw */
+ rspamd_ftok_unicode_t unicode; /* array of unicode characters, normalized, lowercased */
+ rspamd_ftok_t normalized; /* normalized and lowercased utf8 */
+ rspamd_ftok_t stemmed; /* stemmed utf8 */
+ guint flags;
+} rspamd_stat_token_t;
+
+typedef struct token_node_s {
+ guint64 data;
+ guint window_idx;
+ guint flags;
+ rspamd_stat_token_t *t1;
+ rspamd_stat_token_t *t2;
+ float values[];
+} rspamd_token_t;
+
+struct rspamd_stat_ctx;
+
+/**
+ * The results of statistics processing:
+ * - error
+ * - need to do additional job for processing
+ * - all processed
+ */
+typedef enum rspamd_stat_result_e {
+ RSPAMD_STAT_PROCESS_ERROR = 0,
+ RSPAMD_STAT_PROCESS_DELAYED = 1,
+ RSPAMD_STAT_PROCESS_OK
+} rspamd_stat_result_t;
+
+/**
+ * Initialise statistics modules
+ * @param cfg
+ */
+void rspamd_stat_init(struct rspamd_config *cfg, struct ev_loop *ev_base);
+
+/**
+ * Finalize statistics
+ */
+void rspamd_stat_close(void);
+
+/**
+ * Tokenize task
+ * @param st_ctx
+ * @param task
+ */
+void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task);
+
+/**
+ * Classify the task specified and insert symbols if needed
+ * @param task
+ * @param L lua state
+ * @param err error returned
+ * @return TRUE if task has been classified
+ */
+rspamd_stat_result_t rspamd_stat_classify(struct rspamd_task *task,
+ lua_State *L, guint stage, GError **err);
+
+
+/**
+ * Check if a task should be learned and set the appropriate flags for it
+ * @param task
+ * @return
+ */
+gboolean rspamd_stat_check_autolearn(struct rspamd_task *task);
+
+/**
+ * Learn task as spam or ham, task must be processed prior to this call
+ * @param task task to learn
+ * @param spam if TRUE learn spam, otherwise learn ham
+ * @param L lua state
+ * @param classifier NULL to learn all classifiers, name to learn a specific one
+ * @param err error returned
+ * @return TRUE if task has been learned
+ */
+rspamd_stat_result_t rspamd_stat_learn(struct rspamd_task *task,
+ gboolean spam, lua_State *L, const gchar *classifier,
+ guint stage,
+ GError **err);
+
+/**
+ * Get the overall statistics for all statfile backends
+ * @param cfg configuration
+ * @param total_learns the total number of learns is stored here
+ * @return array of statistical information
+ */
+rspamd_stat_result_t rspamd_stat_statistics(struct rspamd_task *task,
+ struct rspamd_config *cfg,
+ guint64 *total_learns,
+ ucl_object_t **res);
+
+void rspamd_stat_unload(void);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* STAT_API_H_ */
diff --git a/src/libstat/stat_config.c b/src/libstat/stat_config.c
new file mode 100644
index 0000000..2748044
--- /dev/null
+++ b/src/libstat/stat_config.c
@@ -0,0 +1,603 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "config.h"
+#include "stat_api.h"
+#include "rspamd.h"
+#include "cfg_rcl.h"
+#include "stat_internal.h"
+#include "lua/lua_common.h"
+
+static struct rspamd_stat_ctx *stat_ctx = NULL;
+
+static struct rspamd_stat_classifier lua_classifier = {
+ .name = "lua",
+ .init_func = lua_classifier_init,
+ .classify_func = lua_classifier_classify,
+ .learn_spam_func = lua_classifier_learn_spam,
+ .fin_func = NULL,
+};
+
+static struct rspamd_stat_classifier stat_classifiers[] = {
+ {
+ .name = "bayes",
+ .init_func = bayes_init,
+ .classify_func = bayes_classify,
+ .learn_spam_func = bayes_learn_spam,
+ .fin_func = bayes_fin,
+ }};
+
+static struct rspamd_stat_tokenizer stat_tokenizers[] = {
+ {
+ .name = "osb-text",
+ .get_config = rspamd_tokenizer_osb_get_config,
+ .tokenize_func = rspamd_tokenizer_osb,
+ },
+ {
+ .name = "osb",
+ .get_config = rspamd_tokenizer_osb_get_config,
+ .tokenize_func = rspamd_tokenizer_osb,
+ },
+};
+
+#define RSPAMD_STAT_BACKEND_ELT(nam, eltn) \
+ { \
+ .name = #nam, \
+ .read_only = false, \
+ .init = rspamd_##eltn##_init, \
+ .runtime = rspamd_##eltn##_runtime, \
+ .process_tokens = rspamd_##eltn##_process_tokens, \
+ .finalize_process = rspamd_##eltn##_finalize_process, \
+ .learn_tokens = rspamd_##eltn##_learn_tokens, \
+ .finalize_learn = rspamd_##eltn##_finalize_learn, \
+ .total_learns = rspamd_##eltn##_total_learns, \
+ .inc_learns = rspamd_##eltn##_inc_learns, \
+ .dec_learns = rspamd_##eltn##_dec_learns, \
+ .get_stat = rspamd_##eltn##_get_stat, \
+ .load_tokenizer_config = rspamd_##eltn##_load_tokenizer_config, \
+ .close = rspamd_##eltn##_close \
+ }
+#define RSPAMD_STAT_BACKEND_ELT_READONLY(nam, eltn) \
+ { \
+ .name = #nam, \
+ .read_only = true, \
+ .init = rspamd_##eltn##_init, \
+ .runtime = rspamd_##eltn##_runtime, \
+ .process_tokens = rspamd_##eltn##_process_tokens, \
+ .finalize_process = rspamd_##eltn##_finalize_process, \
+ .learn_tokens = NULL, \
+ .finalize_learn = NULL, \
+ .total_learns = rspamd_##eltn##_total_learns, \
+ .inc_learns = NULL, \
+ .dec_learns = NULL, \
+ .get_stat = rspamd_##eltn##_get_stat, \
+ .load_tokenizer_config = rspamd_##eltn##_load_tokenizer_config, \
+ .close = rspamd_##eltn##_close \
+ }
+
+static struct rspamd_stat_backend stat_backends[] = {
+ RSPAMD_STAT_BACKEND_ELT(mmap, mmaped_file),
+ RSPAMD_STAT_BACKEND_ELT(sqlite3, sqlite3),
+ RSPAMD_STAT_BACKEND_ELT_READONLY(cdb, cdb),
+ RSPAMD_STAT_BACKEND_ELT(redis, redis)};
+
+#define RSPAMD_STAT_CACHE_ELT(nam, eltn) \
+ { \
+ .name = #nam, \
+ .init = rspamd_stat_cache_##eltn##_init, \
+ .runtime = rspamd_stat_cache_##eltn##_runtime, \
+ .check = rspamd_stat_cache_##eltn##_check, \
+ .learn = rspamd_stat_cache_##eltn##_learn, \
+ .close = rspamd_stat_cache_##eltn##_close \
+ }
+
+static struct rspamd_stat_cache stat_caches[] = {
+ RSPAMD_STAT_CACHE_ELT(sqlite3, sqlite3),
+ RSPAMD_STAT_CACHE_ELT(redis, redis),
+};
+
+void rspamd_stat_init(struct rspamd_config *cfg, struct ev_loop *ev_base)
+{
+ GList *cur, *curst;
+ struct rspamd_classifier_config *clf;
+ struct rspamd_statfile_config *stf;
+ struct rspamd_stat_backend *bk;
+ struct rspamd_statfile *st;
+ struct rspamd_classifier *cl;
+ const ucl_object_t *cache_obj = NULL, *cache_name_obj;
+ const gchar *cache_name = NULL;
+ lua_State *L = cfg->lua_state;
+ guint lua_classifiers_cnt = 0, i;
+ gboolean skip_cache = FALSE;
+
+ if (stat_ctx == NULL) {
+ stat_ctx = g_malloc0(sizeof(*stat_ctx));
+ }
+
+ lua_getglobal(L, "rspamd_classifiers");
+
+ if (lua_type(L, -1) == LUA_TTABLE) {
+ lua_pushnil(L);
+
+ while (lua_next(L, -2) != 0) {
+ lua_classifiers_cnt++;
+ lua_pop(L, 1);
+ }
+ }
+
+ lua_pop(L, 1);
+
+ stat_ctx->classifiers_count = G_N_ELEMENTS(stat_classifiers) +
+ lua_classifiers_cnt;
+ stat_ctx->classifiers_subrs = g_new0(struct rspamd_stat_classifier,
+ stat_ctx->classifiers_count);
+
+ for (i = 0; i < G_N_ELEMENTS(stat_classifiers); i++) {
+ memcpy(&stat_ctx->classifiers_subrs[i], &stat_classifiers[i],
+ sizeof(struct rspamd_stat_classifier));
+ }
+
+ lua_getglobal(L, "rspamd_classifiers");
+
+ if (lua_type(L, -1) == LUA_TTABLE) {
+ lua_pushnil(L);
+
+ while (lua_next(L, -2) != 0) {
+ lua_pushvalue(L, -2);
+ memcpy(&stat_ctx->classifiers_subrs[i], &lua_classifier,
+ sizeof(struct rspamd_stat_classifier));
+ stat_ctx->classifiers_subrs[i].name = g_strdup(lua_tostring(L, -1));
+ i++;
+ lua_pop(L, 2);
+ }
+ }
+
+ lua_pop(L, 1);
+ stat_ctx->backends_subrs = stat_backends;
+ stat_ctx->backends_count = G_N_ELEMENTS(stat_backends);
+
+ stat_ctx->tokenizers_subrs = stat_tokenizers;
+ stat_ctx->tokenizers_count = G_N_ELEMENTS(stat_tokenizers);
+ stat_ctx->caches_subrs = stat_caches;
+ stat_ctx->caches_count = G_N_ELEMENTS(stat_caches);
+ stat_ctx->cfg = cfg;
+ stat_ctx->statfiles = g_ptr_array_new();
+ stat_ctx->classifiers = g_ptr_array_new();
+ stat_ctx->async_elts = g_queue_new();
+ stat_ctx->event_loop = ev_base;
+ stat_ctx->lua_stat_tokens_ref = -1;
+
+ /* Interact with lua_stat */
+ if (luaL_dostring(L, "return require \"lua_stat\"") != 0) {
+ msg_err_config("cannot require lua_stat: %s",
+ lua_tostring(L, -1));
+ }
+ else {
+#if LUA_VERSION_NUM >= 504
+ lua_settop(L, -2);
+#endif
+ if (lua_type(L, -1) != LUA_TTABLE) {
+ msg_err_config("lua stat must return "
+ "table and not %s",
+ lua_typename(L, lua_type(L, -1)));
+ }
+ else {
+ lua_pushstring(L, "gen_stat_tokens");
+ lua_gettable(L, -2);
+
+ if (lua_type(L, -1) != LUA_TFUNCTION) {
+ msg_err_config("gen_stat_tokens must return "
+ "function and not %s",
+ lua_typename(L, lua_type(L, -1)));
+ }
+ else {
+ /* Call this function to obtain closure */
+ gint err_idx, ret;
+ struct rspamd_config **pcfg;
+
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ err_idx = lua_gettop(L);
+ lua_pushvalue(L, err_idx - 1);
+
+ pcfg = lua_newuserdata(L, sizeof(*pcfg));
+ *pcfg = cfg;
+ rspamd_lua_setclass(L, "rspamd{config}", -1);
+
+ if ((ret = lua_pcall(L, 1, 1, err_idx)) != 0) {
+ msg_err_config("call to gen_stat_tokens lua "
+ "script failed (%d): %s",
+ ret,
+ lua_tostring(L, -1));
+ }
+ else {
+ if (lua_type(L, -1) != LUA_TFUNCTION) {
+ msg_err_config("gen_stat_tokens invocation must return "
+ "function and not %s",
+ lua_typename(L, lua_type(L, -1)));
+ }
+ else {
+ stat_ctx->lua_stat_tokens_ref = luaL_ref(L, LUA_REGISTRYINDEX);
+ }
+ }
+ }
+ }
+ }
+
+ /* Cleanup mess */
+ lua_settop(L, 0);
+
+ /* Create statfiles from the classifiers */
+ cur = cfg->classifiers;
+
+ while (cur) {
+ bk = NULL;
+ clf = cur->data;
+ cl = g_malloc0(sizeof(*cl));
+ cl->cfg = clf;
+ cl->ctx = stat_ctx;
+ cl->statfiles_ids = g_array_new(FALSE, FALSE, sizeof(gint));
+ cl->subrs = rspamd_stat_get_classifier(clf->classifier);
+
+ if (cl->subrs == NULL) {
+ g_free(cl);
+ msg_err_config("cannot init classifier type %s", clf->name);
+ cur = g_list_next(cur);
+ continue;
+ }
+
+ if (!cl->subrs->init_func(cfg, ev_base, cl)) {
+ g_free(cl);
+ msg_err_config("cannot init classifier type %s", clf->name);
+ cur = g_list_next(cur);
+ continue;
+ }
+
+ if (!(clf->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND)) {
+ bk = rspamd_stat_get_backend(clf->backend);
+
+ if (bk == NULL) {
+ msg_err_config("cannot get backend of type %s, so disable classifier"
+ " %s completely",
+ clf->backend, clf->name);
+ cur = g_list_next(cur);
+ continue;
+ }
+ }
+ else {
+ /* This actually is not implemented so it should never happen */
+ g_free(cl);
+ cur = g_list_next(cur);
+ continue;
+ }
+
+ /* XXX:
+ * Here we get the first classifier tokenizer config as the only one
+ * We NO LONGER support multiple tokenizers per rspamd instance
+ */
+ if (stat_ctx->tkcf == NULL) {
+ stat_ctx->tokenizer = rspamd_stat_get_tokenizer(clf->tokenizer->name);
+ g_assert(stat_ctx->tokenizer != NULL);
+ stat_ctx->tkcf = stat_ctx->tokenizer->get_config(cfg->cfg_pool,
+ clf->tokenizer, NULL);
+ }
+
+ /* Init classifier cache */
+ cache_name = NULL;
+
+ if (!bk->read_only) {
+ if (clf->opts) {
+ cache_obj = ucl_object_lookup(clf->opts, "cache");
+ cache_name_obj = NULL;
+
+ if (cache_obj && ucl_object_type(cache_obj) == UCL_NULL) {
+ skip_cache = TRUE;
+ }
+ else {
+ if (cache_obj) {
+ cache_name_obj = ucl_object_lookup_any(cache_obj,
+ "name", "type", NULL);
+ }
+
+ if (cache_name_obj) {
+ cache_name = ucl_object_tostring(cache_name_obj);
+ }
+ }
+ }
+ }
+ else {
+ skip_cache = true;
+ }
+
+ if (cache_name == NULL && !skip_cache) {
+ /* We assume that learn cache is the same as backend */
+ cache_name = clf->backend;
+ }
+
+ curst = clf->statfiles;
+
+ while (curst) {
+ stf = curst->data;
+ st = g_malloc0(sizeof(*st));
+ st->classifier = cl;
+ st->stcf = stf;
+
+ if (!(cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND)) {
+ st->backend = bk;
+ st->bkcf = bk->init(stat_ctx, cfg, st);
+ msg_info_config("added backend %s for symbol %s",
+ bk->name, stf->symbol);
+ }
+ else {
+ msg_debug_config("added backend-less statfile for symbol %s",
+ stf->symbol);
+ }
+
+ /* XXX: bad hack to pass statfiles configuration to cache */
+ if (cl->cache == NULL && !skip_cache) {
+ cl->cache = rspamd_stat_get_cache(cache_name);
+ g_assert(cl->cache != NULL);
+ cl->cachecf = cl->cache->init(stat_ctx, cfg, st, cache_obj);
+
+ if (cl->cachecf == NULL) {
+ msg_err_config("error adding cache %s for symbol %s",
+ cl->cache->name, stf->symbol);
+ cl->cache = NULL;
+ }
+ else {
+ msg_debug_config("added cache %s for symbol %s",
+ cl->cache->name, stf->symbol);
+ }
+ }
+
+ if (st->bkcf == NULL &&
+ !(cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND)) {
+ msg_err_config("cannot init backend %s for statfile %s",
+ clf->backend, stf->symbol);
+
+ g_free(st);
+ }
+ else {
+ st->id = stat_ctx->statfiles->len;
+ g_ptr_array_add(stat_ctx->statfiles, st);
+ g_array_append_val(cl->statfiles_ids, st->id);
+ }
+
+ curst = curst->next;
+ }
+
+ g_ptr_array_add(stat_ctx->classifiers, cl);
+
+ cur = cur->next;
+ }
+}
+
+void rspamd_stat_close(void)
+{
+ struct rspamd_classifier *cl;
+ struct rspamd_statfile *st;
+ struct rspamd_stat_ctx *st_ctx;
+ struct rspamd_stat_async_elt *aelt;
+ GList *cur;
+ guint i, j;
+ gint id;
+
+ st_ctx = rspamd_stat_get_ctx();
+ g_assert(st_ctx != NULL);
+
+ for (i = 0; i < st_ctx->classifiers->len; i++) {
+ cl = g_ptr_array_index(st_ctx->classifiers, i);
+
+ for (j = 0; j < cl->statfiles_ids->len; j++) {
+ id = g_array_index(cl->statfiles_ids, gint, j);
+ st = g_ptr_array_index(st_ctx->statfiles, id);
+ if (!(st->classifier->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND)) {
+ st->backend->close(st->bkcf);
+ }
+
+ g_free(st);
+ }
+
+ if (cl->cache && cl->cachecf) {
+ cl->cache->close(cl->cachecf);
+ }
+
+ g_array_free(cl->statfiles_ids, TRUE);
+
+ if (cl->subrs->fin_func) {
+ cl->subrs->fin_func(cl);
+ }
+
+ g_free(cl);
+ }
+
+ cur = st_ctx->async_elts->head;
+
+ while (cur) {
+ aelt = cur->data;
+ REF_RELEASE(aelt);
+ cur = g_list_next(cur);
+ }
+
+ g_queue_free(stat_ctx->async_elts);
+ g_ptr_array_free(st_ctx->statfiles, TRUE);
+ g_ptr_array_free(st_ctx->classifiers, TRUE);
+
+ if (st_ctx->lua_stat_tokens_ref != -1) {
+ luaL_unref(st_ctx->cfg->lua_state, LUA_REGISTRYINDEX,
+ st_ctx->lua_stat_tokens_ref);
+ }
+
+ g_free(st_ctx->classifiers_subrs);
+ g_free(st_ctx);
+
+ /* Set global var to NULL */
+ stat_ctx = NULL;
+}
+
+struct rspamd_stat_ctx *
+rspamd_stat_get_ctx(void)
+{
+ return stat_ctx;
+}
+
+struct rspamd_stat_classifier *
+rspamd_stat_get_classifier(const gchar *name)
+{
+ guint i;
+
+ if (name == NULL || name[0] == '\0') {
+ name = RSPAMD_DEFAULT_CLASSIFIER;
+ }
+
+ for (i = 0; i < stat_ctx->classifiers_count; i++) {
+ if (strcmp(name, stat_ctx->classifiers_subrs[i].name) == 0) {
+ return &stat_ctx->classifiers_subrs[i];
+ }
+ }
+
+ msg_err("cannot find classifier named %s", name);
+
+ return NULL;
+}
+
+struct rspamd_stat_backend *
+rspamd_stat_get_backend(const gchar *name)
+{
+ guint i;
+
+ if (name == NULL || name[0] == '\0') {
+ name = RSPAMD_DEFAULT_BACKEND;
+ }
+
+ for (i = 0; i < stat_ctx->backends_count; i++) {
+ if (strcmp(name, stat_ctx->backends_subrs[i].name) == 0) {
+ return &stat_ctx->backends_subrs[i];
+ }
+ }
+
+ msg_err("cannot find backend named %s", name);
+
+ return NULL;
+}
+
+struct rspamd_stat_tokenizer *
+rspamd_stat_get_tokenizer(const gchar *name)
+{
+ guint i;
+
+ if (name == NULL || name[0] == '\0') {
+ name = RSPAMD_DEFAULT_TOKENIZER;
+ }
+
+ for (i = 0; i < stat_ctx->tokenizers_count; i++) {
+ if (strcmp(name, stat_ctx->tokenizers_subrs[i].name) == 0) {
+ return &stat_ctx->tokenizers_subrs[i];
+ }
+ }
+
+ msg_err("cannot find tokenizer named %s", name);
+
+ return NULL;
+}
+
+struct rspamd_stat_cache *
+rspamd_stat_get_cache(const gchar *name)
+{
+ guint i;
+
+ if (name == NULL || name[0] == '\0') {
+ name = RSPAMD_DEFAULT_CACHE;
+ }
+
+ for (i = 0; i < stat_ctx->caches_count; i++) {
+ if (strcmp(name, stat_ctx->caches_subrs[i].name) == 0) {
+ return &stat_ctx->caches_subrs[i];
+ }
+ }
+
+ msg_err("cannot find cache named %s", name);
+
+ return NULL;
+}
+
+static void
+rspamd_async_elt_dtor(struct rspamd_stat_async_elt *elt)
+{
+ if (elt->cleanup) {
+ elt->cleanup(elt, elt->ud);
+ }
+
+ ev_timer_stop(elt->event_loop, &elt->timer_ev);
+ g_free(elt);
+}
+
+static void
+rspamd_async_elt_on_timer(EV_P_ ev_timer *w, int revents)
+{
+ struct rspamd_stat_async_elt *elt = (struct rspamd_stat_async_elt *) w->data;
+ gdouble jittered_time;
+
+
+ if (elt->enabled) {
+ elt->handler(elt, elt->ud);
+ }
+
+ jittered_time = rspamd_time_jitter(elt->timeout, 0);
+ elt->timer_ev.repeat = jittered_time;
+ ev_timer_again(EV_A_ w);
+}
+
+struct rspamd_stat_async_elt *
+rspamd_stat_ctx_register_async(rspamd_stat_async_handler handler,
+ rspamd_stat_async_cleanup cleanup,
+ gpointer d,
+ gdouble timeout)
+{
+ struct rspamd_stat_async_elt *elt;
+ struct rspamd_stat_ctx *st_ctx;
+
+ st_ctx = rspamd_stat_get_ctx();
+ g_assert(st_ctx != NULL);
+
+ elt = g_malloc0(sizeof(*elt));
+ elt->handler = handler;
+ elt->cleanup = cleanup;
+ elt->ud = d;
+ elt->timeout = timeout;
+ elt->event_loop = st_ctx->event_loop;
+ REF_INIT_RETAIN(elt, rspamd_async_elt_dtor);
+ /* Enabled by default */
+
+
+ if (st_ctx->event_loop) {
+ elt->enabled = TRUE;
+ /*
+ * First we set timeval to zero as we want cb to be executed as
+ * fast as possible
+ */
+ elt->timer_ev.data = elt;
+ ev_timer_init(&elt->timer_ev, rspamd_async_elt_on_timer,
+ 0.1, 0.0);
+ ev_timer_start(st_ctx->event_loop, &elt->timer_ev);
+ }
+ else {
+ elt->enabled = FALSE;
+ }
+
+ g_queue_push_tail(st_ctx->async_elts, elt);
+
+ return elt;
+}
diff --git a/src/libstat/stat_internal.h b/src/libstat/stat_internal.h
new file mode 100644
index 0000000..8d0ebd4
--- /dev/null
+++ b/src/libstat/stat_internal.h
@@ -0,0 +1,134 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef STAT_INTERNAL_H_
+#define STAT_INTERNAL_H_
+
+#include "config.h"
+#include "task.h"
+#include "ref.h"
+#include "classifiers/classifiers.h"
+#include "tokenizers/tokenizers.h"
+#include "backends/backends.h"
+#include "learn_cache/learn_cache.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct rspamd_statfile_runtime {
+ struct rspamd_statfile_config *st;
+ gpointer backend_runtime;
+ guint64 hits;
+ guint64 total_hits;
+};
+
+/* Common classifier structure */
+struct rspamd_classifier {
+ struct rspamd_stat_ctx *ctx;
+ GArray *statfiles_ids; /* int */
+ struct rspamd_stat_cache *cache;
+ gpointer cachecf;
+ gulong spam_learns;
+ gulong ham_learns;
+ gint autolearn_cbref;
+ struct rspamd_classifier_config *cfg;
+ struct rspamd_stat_classifier *subrs;
+ gpointer specific;
+};
+
+struct rspamd_statfile {
+ gint id;
+ struct rspamd_statfile_config *stcf;
+ struct rspamd_classifier *classifier;
+ struct rspamd_stat_backend *backend;
+ gpointer bkcf;
+};
+
+struct rspamd_stat_async_elt;
+
+typedef void (*rspamd_stat_async_handler)(struct rspamd_stat_async_elt *elt,
+ gpointer ud);
+
+typedef void (*rspamd_stat_async_cleanup)(struct rspamd_stat_async_elt *elt,
+ gpointer ud);
+
+struct rspamd_stat_async_elt {
+ rspamd_stat_async_handler handler;
+ rspamd_stat_async_cleanup cleanup;
+ struct ev_loop *event_loop;
+ ev_timer timer_ev;
+ gdouble timeout;
+ gboolean enabled;
+ gpointer ud;
+ ref_entry_t ref;
+};
+
+struct rspamd_stat_ctx {
+ /* Subroutines for all objects */
+ struct rspamd_stat_classifier *classifiers_subrs;
+ guint classifiers_count;
+ struct rspamd_stat_tokenizer *tokenizers_subrs;
+ guint tokenizers_count;
+ struct rspamd_stat_backend *backends_subrs;
+ guint backends_count;
+ struct rspamd_stat_cache *caches_subrs;
+ guint caches_count;
+
+ /* Runtime configuration */
+ GPtrArray *statfiles; /* struct rspamd_statfile */
+ GPtrArray *classifiers; /* struct rspamd_classifier */
+ GQueue *async_elts; /* struct rspamd_stat_async_elt */
+ struct rspamd_config *cfg;
+
+ gint lua_stat_tokens_ref;
+
+ /* Global tokenizer */
+ struct rspamd_stat_tokenizer *tokenizer;
+ gpointer tkcf;
+
+ struct ev_loop *event_loop;
+};
+
+typedef enum rspamd_learn_cache_result {
+ RSPAMD_LEARN_OK = 0,
+ RSPAMD_LEARN_UNLEARN,
+ RSPAMD_LEARN_IGNORE
+} rspamd_learn_t;
+
+struct rspamd_stat_ctx *rspamd_stat_get_ctx(void);
+
+struct rspamd_stat_classifier *rspamd_stat_get_classifier(const gchar *name);
+
+struct rspamd_stat_backend *rspamd_stat_get_backend(const gchar *name);
+
+struct rspamd_stat_tokenizer *rspamd_stat_get_tokenizer(const gchar *name);
+
+struct rspamd_stat_cache *rspamd_stat_get_cache(const gchar *name);
+
+struct rspamd_stat_async_elt *rspamd_stat_ctx_register_async(
+ rspamd_stat_async_handler handler, rspamd_stat_async_cleanup cleanup,
+ gpointer d, gdouble timeout);
+
+static GQuark rspamd_stat_quark(void)
+{
+ return g_quark_from_static_string("rspamd-statistics");
+}
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* STAT_INTERNAL_H_ */
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
new file mode 100644
index 0000000..8c1d8ff
--- /dev/null
+++ b/src/libstat/stat_process.c
@@ -0,0 +1,1250 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "config.h"
+#include "stat_api.h"
+#include "rspamd.h"
+#include "stat_internal.h"
+#include "libmime/message.h"
+#include "libmime/images.h"
+#include "libserver/html/html.h"
+#include "lua/lua_common.h"
+#include "libserver/mempool_vars_internal.h"
+#include "utlist.h"
+#include <math.h>
+
+#define RSPAMD_CLASSIFY_OP 0
+#define RSPAMD_LEARN_OP 1
+#define RSPAMD_UNLEARN_OP 2
+
+static const gdouble similarity_threshold = 80.0;
+
+static void
+rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task)
+{
+ GArray *ar;
+ rspamd_stat_token_t elt;
+ guint i;
+ lua_State *L = task->cfg->lua_state;
+
+ ar = g_array_sized_new(FALSE, FALSE, sizeof(elt), 16);
+ memset(&elt, 0, sizeof(elt));
+ elt.flags = RSPAMD_STAT_TOKEN_FLAG_META;
+
+ if (st_ctx->lua_stat_tokens_ref != -1) {
+ gint err_idx, ret;
+ struct rspamd_task **ptask;
+
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ err_idx = lua_gettop(L);
+ lua_rawgeti(L, LUA_REGISTRYINDEX, st_ctx->lua_stat_tokens_ref);
+
+ ptask = lua_newuserdata(L, sizeof(*ptask));
+ *ptask = task;
+ rspamd_lua_setclass(L, "rspamd{task}", -1);
+
+ if ((ret = lua_pcall(L, 1, 1, err_idx)) != 0) {
+ msg_err_task("call to stat_tokens lua "
+ "script failed (%d): %s",
+ ret, lua_tostring(L, -1));
+ }
+ else {
+ if (lua_type(L, -1) != LUA_TTABLE) {
+ msg_err_task("stat_tokens invocation must return "
+ "table and not %s",
+ lua_typename(L, lua_type(L, -1)));
+ }
+ else {
+ guint vlen;
+ rspamd_ftok_t tok;
+
+ vlen = rspamd_lua_table_size(L, -1);
+
+ for (i = 0; i < vlen; i++) {
+ lua_rawgeti(L, -1, i + 1);
+ tok.begin = lua_tolstring(L, -1, &tok.len);
+
+ if (tok.begin && tok.len > 0) {
+ elt.original.begin =
+ rspamd_mempool_ftokdup(task->task_pool, &tok);
+ elt.original.len = tok.len;
+ elt.stemmed.begin = elt.original.begin;
+ elt.stemmed.len = elt.original.len;
+ elt.normalized.begin = elt.original.begin;
+ elt.normalized.len = elt.original.len;
+
+ g_array_append_val(ar, elt);
+ }
+
+ lua_pop(L, 1);
+ }
+ }
+ }
+
+ lua_settop(L, 0);
+ }
+
+
+ if (ar->len > 0) {
+ st_ctx->tokenizer->tokenize_func(st_ctx,
+ task,
+ ar,
+ TRUE,
+ "M",
+ task->tokens);
+ }
+
+ rspamd_mempool_add_destructor(task->task_pool,
+ rspamd_array_free_hard, ar);
+}
+
+/*
+ * Tokenize task using the tokenizer specified
+ */
+void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task)
+{
+ struct rspamd_mime_text_part *part;
+ rspamd_cryptobox_hash_state_t hst;
+ rspamd_token_t *st_tok;
+ guint i, reserved_len = 0;
+ gdouble *pdiff;
+ guchar hout[rspamd_cryptobox_HASHBYTES];
+ gchar *b32_hout;
+
+ if (st_ctx == NULL) {
+ st_ctx = rspamd_stat_get_ctx();
+ }
+
+ g_assert(st_ctx != NULL);
+
+ PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part)
+ {
+ if (!IS_TEXT_PART_EMPTY(part) && part->utf_words != NULL) {
+ reserved_len += part->utf_words->len;
+ }
+ /* XXX: normal window size */
+ reserved_len += 5;
+ }
+
+ task->tokens = g_ptr_array_sized_new(reserved_len);
+ rspamd_mempool_add_destructor(task->task_pool,
+ rspamd_ptr_array_free_hard, task->tokens);
+ rspamd_mempool_notify_alloc(task->task_pool, reserved_len * sizeof(gpointer));
+ pdiff = rspamd_mempool_get_variable(task->task_pool, "parts_distance");
+
+ PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part)
+ {
+ if (!IS_TEXT_PART_EMPTY(part) && part->utf_words != NULL) {
+ st_ctx->tokenizer->tokenize_func(st_ctx, task,
+ part->utf_words, IS_TEXT_PART_UTF(part),
+ NULL, task->tokens);
+ }
+
+
+ if (pdiff != NULL && (1.0 - *pdiff) * 100.0 > similarity_threshold) {
+ msg_debug_bayes("message has two common parts (%.2f), so skip the last one",
+ *pdiff);
+ break;
+ }
+ }
+
+ if (task->meta_words != NULL) {
+ st_ctx->tokenizer->tokenize_func(st_ctx,
+ task,
+ task->meta_words,
+ TRUE,
+ "SUBJECT",
+ task->tokens);
+ }
+
+ rspamd_stat_tokenize_parts_metadata(st_ctx, task);
+
+ /* Produce signature */
+ rspamd_cryptobox_hash_init(&hst, NULL, 0);
+
+ PTR_ARRAY_FOREACH(task->tokens, i, st_tok)
+ {
+ rspamd_cryptobox_hash_update(&hst, (guchar *) &st_tok->data,
+ sizeof(st_tok->data));
+ }
+
+ rspamd_cryptobox_hash_final(&hst, hout);
+ b32_hout = rspamd_encode_base32(hout, sizeof(hout), RSPAMD_BASE32_DEFAULT);
+ /*
+ * We need to strip it to 32 characters providing ~160 bits of
+ * hash distribution
+ */
+ b32_hout[32] = '\0';
+ rspamd_mempool_set_variable(task->task_pool, RSPAMD_MEMPOOL_STAT_SIGNATURE,
+ b32_hout, g_free);
+}
+
+static gboolean
+rspamd_stat_classifier_is_skipped(struct rspamd_task *task,
+ struct rspamd_classifier *cl, gboolean is_learn, gboolean is_spam)
+{
+ GList *cur = is_learn ? cl->cfg->learn_conditions : cl->cfg->classify_conditions;
+ lua_State *L = task->cfg->lua_state;
+ gboolean ret = FALSE;
+
+ while (cur) {
+ gint cb_ref = GPOINTER_TO_INT(cur->data);
+ gint old_top = lua_gettop(L);
+ gint nargs;
+
+ lua_rawgeti(L, LUA_REGISTRYINDEX, cb_ref);
+ /* Push task and two booleans: is_spam and is_unlearn */
+ struct rspamd_task **ptask = lua_newuserdata(L, sizeof(*ptask));
+ *ptask = task;
+ rspamd_lua_setclass(L, "rspamd{task}", -1);
+
+ if (is_learn) {
+ lua_pushboolean(L, is_spam);
+ lua_pushboolean(L,
+ task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false);
+ nargs = 3;
+ }
+ else {
+ nargs = 1;
+ }
+
+ if (lua_pcall(L, nargs, LUA_MULTRET, 0) != 0) {
+ msg_err_task("call to %s failed: %s",
+ "condition callback",
+ lua_tostring(L, -1));
+ }
+ else {
+ if (lua_isboolean(L, 1)) {
+ if (!lua_toboolean(L, 1)) {
+ ret = TRUE;
+ }
+ }
+
+ if (lua_isstring(L, 2)) {
+ if (ret) {
+ msg_notice_task("%s condition for classifier %s returned: %s; skip classifier",
+ is_learn ? "learn" : "classify", cl->cfg->name,
+ lua_tostring(L, 2));
+ }
+ else {
+ msg_info_task("%s condition for classifier %s returned: %s",
+ is_learn ? "learn" : "classify", cl->cfg->name,
+ lua_tostring(L, 2));
+ }
+ }
+ else if (ret) {
+ msg_notice_task("%s condition for classifier %s returned false; skip classifier",
+ is_learn ? "learn" : "classify", cl->cfg->name);
+ }
+
+ if (ret) {
+ lua_settop(L, old_top);
+ break;
+ }
+ }
+
+ lua_settop(L, old_top);
+ cur = g_list_next(cur);
+ }
+
+ return ret;
+}
+
+static void
+rspamd_stat_preprocess(struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task, gboolean is_learn, gboolean is_spam)
+{
+ guint i;
+ struct rspamd_statfile *st;
+ gpointer bk_run;
+
+ if (task->tokens == NULL) {
+ rspamd_stat_process_tokenize(st_ctx, task);
+ }
+
+ task->stat_runtimes = g_ptr_array_sized_new(st_ctx->statfiles->len);
+ g_ptr_array_set_size(task->stat_runtimes, st_ctx->statfiles->len);
+ rspamd_mempool_add_destructor(task->task_pool,
+ rspamd_ptr_array_free_hard, task->stat_runtimes);
+
+ /* Temporary set all stat_runtimes to some max size to distinguish from NULL */
+ for (i = 0; i < st_ctx->statfiles->len; i++) {
+ g_ptr_array_index(task->stat_runtimes, i) = GSIZE_TO_POINTER(G_MAXSIZE);
+ }
+
+ for (i = 0; i < st_ctx->classifiers->len; i++) {
+ struct rspamd_classifier *cl = g_ptr_array_index(st_ctx->classifiers, i);
+ gboolean skip_classifier = FALSE;
+
+ if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
+ skip_classifier = TRUE;
+ }
+ else {
+ if (rspamd_stat_classifier_is_skipped(task, cl, is_learn, is_spam)) {
+ skip_classifier = TRUE;
+ }
+ }
+
+ if (skip_classifier) {
+ /* Set NULL for all statfiles indexed by id */
+ for (int j = 0; j < cl->statfiles_ids->len; j++) {
+ int id = g_array_index(cl->statfiles_ids, gint, j);
+ g_ptr_array_index(task->stat_runtimes, id) = NULL;
+ }
+ }
+ }
+
+ for (i = 0; i < st_ctx->statfiles->len; i++) {
+ st = g_ptr_array_index(st_ctx->statfiles, i);
+ g_assert(st != NULL);
+
+ if (g_ptr_array_index(task->stat_runtimes, i) == NULL) {
+ /* The whole classifier is skipped */
+ continue;
+ }
+
+ if (is_learn && st->backend->read_only) {
+ /* Read only backend, skip it */
+ g_ptr_array_index(task->stat_runtimes, i) = NULL;
+ continue;
+ }
+
+ if (!is_learn && !rspamd_symcache_is_symbol_enabled(task, task->cfg->cache,
+ st->stcf->symbol)) {
+ g_ptr_array_index(task->stat_runtimes, i) = NULL;
+ msg_debug_bayes("symbol %s is disabled, skip classification",
+ st->stcf->symbol);
+ continue;
+ }
+
+ bk_run = st->backend->runtime(task, st->stcf, is_learn, st->bkcf, i);
+
+ if (bk_run == NULL) {
+ msg_err_task("cannot init backend %s for statfile %s",
+ st->backend->name, st->stcf->symbol);
+ }
+
+ g_ptr_array_index(task->stat_runtimes, i) = bk_run;
+ }
+}
+
+static void
+rspamd_stat_backends_process(struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task)
+{
+ guint i;
+ struct rspamd_statfile *st;
+ gpointer bk_run;
+
+ g_assert(task->stat_runtimes != NULL);
+
+ for (i = 0; i < st_ctx->statfiles->len; i++) {
+ st = g_ptr_array_index(st_ctx->statfiles, i);
+ bk_run = g_ptr_array_index(task->stat_runtimes, i);
+
+ if (bk_run != NULL) {
+ st->backend->process_tokens(task, task->tokens, i, bk_run);
+ }
+ }
+}
+
+static void
+rspamd_stat_classifiers_process(struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task)
+{
+ guint i, j, id;
+ struct rspamd_classifier *cl;
+ struct rspamd_statfile *st;
+ gpointer bk_run;
+ gboolean skip;
+
+ if (st_ctx->classifiers->len == 0) {
+ return;
+ }
+
+ /*
+ * Do not classify a message if some class is missing
+ */
+ if (!(task->flags & RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS)) {
+ msg_info_task("skip statistics as SPAM class is missing");
+
+ return;
+ }
+ if (!(task->flags & RSPAMD_TASK_FLAG_HAS_HAM_TOKENS)) {
+ msg_info_task("skip statistics as HAM class is missing");
+
+ return;
+ }
+
+ for (i = 0; i < st_ctx->classifiers->len; i++) {
+ cl = g_ptr_array_index(st_ctx->classifiers, i);
+ cl->spam_learns = 0;
+ cl->ham_learns = 0;
+ }
+
+ g_assert(task->stat_runtimes != NULL);
+
+ for (i = 0; i < st_ctx->statfiles->len; i++) {
+ st = g_ptr_array_index(st_ctx->statfiles, i);
+ cl = st->classifier;
+
+ bk_run = g_ptr_array_index(task->stat_runtimes, i);
+ g_assert(st != NULL);
+
+ if (bk_run != NULL) {
+ if (st->stcf->is_spam) {
+ cl->spam_learns += st->backend->total_learns(task,
+ bk_run,
+ st_ctx);
+ }
+ else {
+ cl->ham_learns += st->backend->total_learns(task,
+ bk_run,
+ st_ctx);
+ }
+ }
+ }
+
+ for (i = 0; i < st_ctx->classifiers->len; i++) {
+ cl = g_ptr_array_index(st_ctx->classifiers, i);
+
+ g_assert(cl != NULL);
+
+ skip = FALSE;
+
+ /* Do not process classifiers on backend failures */
+ for (j = 0; j < cl->statfiles_ids->len; j++) {
+ id = g_array_index(cl->statfiles_ids, gint, j);
+ bk_run = g_ptr_array_index(task->stat_runtimes, id);
+ st = g_ptr_array_index(st_ctx->statfiles, id);
+
+ if (bk_run != NULL) {
+ if (!st->backend->finalize_process(task, bk_run, st_ctx)) {
+ skip = TRUE;
+ break;
+ }
+ }
+ }
+
+ /* Ensure that all symbols enabled */
+ if (!skip && !(cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND)) {
+ for (j = 0; j < cl->statfiles_ids->len; j++) {
+ id = g_array_index(cl->statfiles_ids, gint, j);
+ bk_run = g_ptr_array_index(task->stat_runtimes, id);
+ st = g_ptr_array_index(st_ctx->statfiles, id);
+
+ if (bk_run == NULL) {
+ skip = TRUE;
+ msg_debug_bayes("disable classifier %s as statfile symbol %s is disabled",
+ cl->cfg->name, st->stcf->symbol);
+ break;
+ }
+ }
+ }
+
+ if (!skip) {
+ if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) {
+ msg_debug_bayes(
+ "contains less tokens than required for %s classifier: "
+ "%ud < %ud",
+ cl->cfg->name,
+ task->tokens->len,
+ cl->cfg->min_tokens);
+ continue;
+ }
+ else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) {
+ msg_debug_bayes(
+ "contains more tokens than allowed for %s classifier: "
+ "%ud > %ud",
+ cl->cfg->name,
+ task->tokens->len,
+ cl->cfg->max_tokens);
+ continue;
+ }
+
+ cl->subrs->classify_func(cl, task->tokens, task);
+ }
+ }
+}
+
+rspamd_stat_result_t
+rspamd_stat_classify(struct rspamd_task *task, lua_State *L, guint stage,
+ GError **err)
+{
+ struct rspamd_stat_ctx *st_ctx;
+ rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK;
+
+ st_ctx = rspamd_stat_get_ctx();
+ g_assert(st_ctx != NULL);
+
+ if (st_ctx->classifiers->len == 0) {
+ task->processed_stages |= stage;
+ return ret;
+ }
+
+ if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_PRE) {
+ /* Preprocess tokens */
+ rspamd_stat_preprocess(st_ctx, task, FALSE, FALSE);
+ }
+ else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS) {
+ /* Process backends */
+ rspamd_stat_backends_process(st_ctx, task);
+ }
+ else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_POST) {
+ /* Process classifiers */
+ rspamd_stat_classifiers_process(st_ctx, task);
+ }
+
+ task->processed_stages |= stage;
+
+ return ret;
+}
+
+static gboolean
+rspamd_stat_cache_check(struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task,
+ const gchar *classifier,
+ gboolean spam,
+ GError **err)
+{
+ rspamd_learn_t learn_res = RSPAMD_LEARN_OK;
+ struct rspamd_classifier *cl, *sel = NULL;
+ gpointer rt;
+ guint i;
+
+ /* Check whether we have learned that file */
+ for (i = 0; i < st_ctx->classifiers->len; i++) {
+ cl = g_ptr_array_index(st_ctx->classifiers, i);
+
+ /* Skip other classifiers if they are not needed */
+ if (classifier != NULL && (cl->cfg->name == NULL ||
+ g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
+ continue;
+ }
+
+ sel = cl;
+
+ if (sel->cache && sel->cachecf) {
+ rt = cl->cache->runtime(task, sel->cachecf, FALSE);
+ learn_res = cl->cache->check(task, spam, rt);
+ }
+
+ if (learn_res == RSPAMD_LEARN_IGNORE) {
+ /* Do not learn twice */
+ g_set_error(err, rspamd_stat_quark(), 404, "<%s> has been already "
+ "learned as %s, ignore it",
+ MESSAGE_FIELD(task, message_id),
+ spam ? "spam" : "ham");
+ task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED;
+
+ return FALSE;
+ }
+ else if (learn_res == RSPAMD_LEARN_UNLEARN) {
+ task->flags |= RSPAMD_TASK_FLAG_UNLEARN;
+ break;
+ }
+ }
+
+ if (sel == NULL) {
+ if (classifier) {
+ g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier "
+ "with name %s",
+ classifier);
+ }
+ else {
+ g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined");
+ }
+
+ return FALSE;
+ }
+
+ return TRUE;
+}
+
+static gboolean
+rspamd_stat_classifiers_learn(struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task,
+ const gchar *classifier,
+ gboolean spam,
+ GError **err)
+{
+ struct rspamd_classifier *cl, *sel = NULL;
+ guint i;
+ gboolean learned = FALSE, too_small = FALSE, too_large = FALSE;
+
+ if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL &&
+ *err == NULL) {
+ /* Do not learn twice */
+ g_set_error(err, rspamd_stat_quark(), 208, "<%s> has been already "
+ "learned as %s, ignore it",
+ MESSAGE_FIELD(task, message_id),
+ spam ? "spam" : "ham");
+
+ return FALSE;
+ }
+
+ /* Check whether we have learned that file */
+ for (i = 0; i < st_ctx->classifiers->len; i++) {
+ cl = g_ptr_array_index(st_ctx->classifiers, i);
+
+ /* Skip other classifiers if they are not needed */
+ if (classifier != NULL && (cl->cfg->name == NULL ||
+ g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
+ continue;
+ }
+
+ sel = cl;
+
+ /* Now check max and min tokens */
+ if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) {
+ msg_info_task(
+ "<%s> contains less tokens than required for %s classifier: "
+ "%ud < %ud",
+ MESSAGE_FIELD(task, message_id),
+ cl->cfg->name,
+ task->tokens->len,
+ cl->cfg->min_tokens);
+ too_small = TRUE;
+ continue;
+ }
+ else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) {
+ msg_info_task(
+ "<%s> contains more tokens than allowed for %s classifier: "
+ "%ud > %ud",
+ MESSAGE_FIELD(task, message_id),
+ cl->cfg->name,
+ task->tokens->len,
+ cl->cfg->max_tokens);
+ too_large = TRUE;
+ continue;
+ }
+
+ if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam,
+ task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+ learned = TRUE;
+ }
+ }
+
+ if (sel == NULL) {
+ if (classifier) {
+ g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier "
+ "with name %s",
+ classifier);
+ }
+ else {
+ g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined");
+ }
+
+ return FALSE;
+ }
+
+ if (!learned && err && *err == NULL) {
+ if (too_large) {
+ g_set_error(err, rspamd_stat_quark(), 204,
+ "<%s> contains more tokens than allowed for %s classifier: "
+ "%d > %d",
+ MESSAGE_FIELD(task, message_id),
+ sel->cfg->name,
+ task->tokens->len,
+ sel->cfg->max_tokens);
+ }
+ else if (too_small) {
+ g_set_error(err, rspamd_stat_quark(), 204,
+ "<%s> contains less tokens than required for %s classifier: "
+ "%d < %d",
+ MESSAGE_FIELD(task, message_id),
+ sel->cfg->name,
+ task->tokens->len,
+ sel->cfg->min_tokens);
+ }
+ }
+
+ return learned;
+}
+
+static gboolean
+rspamd_stat_backends_learn(struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task,
+ const gchar *classifier,
+ gboolean spam,
+ GError **err)
+{
+ struct rspamd_classifier *cl, *sel = NULL;
+ struct rspamd_statfile *st;
+ gpointer bk_run;
+ guint i, j;
+ gint id;
+ gboolean res = FALSE, backend_found = FALSE;
+
+ for (i = 0; i < st_ctx->classifiers->len; i++) {
+ cl = g_ptr_array_index(st_ctx->classifiers, i);
+
+ /* Skip other classifiers if they are not needed */
+ if (classifier != NULL && (cl->cfg->name == NULL ||
+ g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
+ continue;
+ }
+
+ if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
+ res = TRUE;
+ continue;
+ }
+
+ sel = cl;
+
+ for (j = 0; j < cl->statfiles_ids->len; j++) {
+ id = g_array_index(cl->statfiles_ids, gint, j);
+ st = g_ptr_array_index(st_ctx->statfiles, id);
+ bk_run = g_ptr_array_index(task->stat_runtimes, id);
+
+ g_assert(st != NULL);
+
+ if (bk_run == NULL) {
+ /* XXX: must be error */
+ if (task->result->passthrough_result) {
+ /* Passthrough email, cannot learn */
+ g_set_error(err, rspamd_stat_quark(), 204,
+ "Cannot learn statistics when passthrough "
+ "result has been set; not classified");
+
+ res = FALSE;
+ goto end;
+ }
+
+ msg_debug_task("no runtime for backend %s; classifier %s; symbol %s",
+ st->backend->name, cl->cfg->name, st->stcf->symbol);
+ continue;
+ }
+
+ /* We set sel merely when we have runtime */
+ backend_found = TRUE;
+
+ if (!(task->flags & RSPAMD_TASK_FLAG_UNLEARN)) {
+ if (!!spam != !!st->stcf->is_spam) {
+ /* If we are not unlearning, then do not touch another class */
+ continue;
+ }
+ }
+
+ if (!st->backend->learn_tokens(task, task->tokens, id, bk_run)) {
+ g_set_error(err, rspamd_stat_quark(), 500,
+ "Cannot push "
+ "learned results to the backend");
+
+ res = FALSE;
+ goto end;
+ }
+ else {
+ if (!!spam == !!st->stcf->is_spam) {
+ st->backend->inc_learns(task, bk_run, st_ctx);
+ }
+ else if (task->flags & RSPAMD_TASK_FLAG_UNLEARN) {
+ st->backend->dec_learns(task, bk_run, st_ctx);
+ }
+
+ res = TRUE;
+ }
+ }
+ }
+
+end:
+
+ if (!res) {
+ if (err && *err) {
+ /* Error has been set already */
+ return res;
+ }
+
+ if (sel == NULL) {
+ if (classifier) {
+ g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier "
+ "with name %s",
+ classifier);
+ }
+ else {
+ g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined");
+ }
+
+ return FALSE;
+ }
+ else if (!backend_found) {
+ g_set_error(err, rspamd_stat_quark(), 204, "all learn conditions "
+ "denied learning %s in %s",
+ spam ? "spam" : "ham",
+ classifier ? classifier : "default classifier");
+ }
+ else {
+ g_set_error(err, rspamd_stat_quark(), 404, "cannot find statfile "
+ "backend to learn %s in %s",
+ spam ? "spam" : "ham",
+ classifier ? classifier : "default classifier");
+ }
+ }
+
+ return res;
+}
+
+static gboolean
+rspamd_stat_backends_post_learn(struct rspamd_stat_ctx *st_ctx,
+ struct rspamd_task *task,
+ const gchar *classifier,
+ gboolean spam,
+ GError **err)
+{
+ struct rspamd_classifier *cl;
+ struct rspamd_statfile *st;
+ gpointer bk_run, cache_run;
+ guint i, j;
+ gint id;
+ gboolean res = TRUE;
+
+ for (i = 0; i < st_ctx->classifiers->len; i++) {
+ cl = g_ptr_array_index(st_ctx->classifiers, i);
+
+ /* Skip other classifiers if they are not needed */
+ if (classifier != NULL && (cl->cfg->name == NULL ||
+ g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
+ continue;
+ }
+
+ if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
+ res = TRUE;
+ continue;
+ }
+
+ for (j = 0; j < cl->statfiles_ids->len; j++) {
+ id = g_array_index(cl->statfiles_ids, gint, j);
+ st = g_ptr_array_index(st_ctx->statfiles, id);
+ bk_run = g_ptr_array_index(task->stat_runtimes, id);
+
+ g_assert(st != NULL);
+
+ if (bk_run == NULL) {
+ /* XXX: must be error */
+ continue;
+ }
+
+ if (!st->backend->finalize_learn(task, bk_run, st_ctx, err)) {
+ return RSPAMD_STAT_PROCESS_ERROR;
+ }
+ }
+
+ if (cl->cache) {
+ cache_run = cl->cache->runtime(task, cl->cachecf, TRUE);
+ cl->cache->learn(task, spam, cache_run);
+ }
+ }
+
+ g_atomic_int_add(&task->worker->srv->stat->messages_learned, 1);
+
+ return res;
+}
+
+rspamd_stat_result_t
+rspamd_stat_learn(struct rspamd_task *task,
+ gboolean spam, lua_State *L, const gchar *classifier, guint stage,
+ GError **err)
+{
+ struct rspamd_stat_ctx *st_ctx;
+ rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK;
+
+ /*
+ * We assume now that a task has been already classified before
+ * coming to learn
+ */
+ g_assert(RSPAMD_TASK_IS_CLASSIFIED(task));
+
+ st_ctx = rspamd_stat_get_ctx();
+ g_assert(st_ctx != NULL);
+
+ if (st_ctx->classifiers->len == 0) {
+ task->processed_stages |= stage;
+ return ret;
+ }
+
+ if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) {
+ /* Process classifiers */
+ rspamd_stat_preprocess(st_ctx, task, TRUE, spam);
+
+ if (!rspamd_stat_cache_check(st_ctx, task, classifier, spam, err)) {
+ return RSPAMD_STAT_PROCESS_ERROR;
+ }
+ }
+ else if (stage == RSPAMD_TASK_STAGE_LEARN) {
+ /* Process classifiers */
+ if (!rspamd_stat_classifiers_learn(st_ctx, task, classifier,
+ spam, err)) {
+ if (err && *err == NULL) {
+ g_set_error(err, rspamd_stat_quark(), 500,
+ "Unknown statistics error, found when learning classifiers;"
+ " classifier: %s",
+ task->classifier);
+ }
+ return RSPAMD_STAT_PROCESS_ERROR;
+ }
+
+ /* Process backends */
+ if (!rspamd_stat_backends_learn(st_ctx, task, classifier, spam, err)) {
+ if (err && *err == NULL) {
+ g_set_error(err, rspamd_stat_quark(), 500,
+ "Unknown statistics error, found when storing data on backend;"
+ " classifier: %s",
+ task->classifier);
+ }
+ return RSPAMD_STAT_PROCESS_ERROR;
+ }
+ }
+ else if (stage == RSPAMD_TASK_STAGE_LEARN_POST) {
+ if (!rspamd_stat_backends_post_learn(st_ctx, task, classifier, spam, err)) {
+ return RSPAMD_STAT_PROCESS_ERROR;
+ }
+ }
+
+ task->processed_stages |= stage;
+
+ return ret;
+}
+
+static gboolean
+rspamd_stat_has_classifier_symbols(struct rspamd_task *task,
+ struct rspamd_scan_result *mres,
+ struct rspamd_classifier *cl)
+{
+ guint i;
+ gint id;
+ struct rspamd_statfile *st;
+ struct rspamd_stat_ctx *st_ctx;
+ gboolean is_spam;
+
+ if (mres == NULL) {
+ return FALSE;
+ }
+
+ st_ctx = rspamd_stat_get_ctx();
+ is_spam = !!(task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM);
+
+ for (i = 0; i < cl->statfiles_ids->len; i++) {
+ id = g_array_index(cl->statfiles_ids, gint, i);
+ st = g_ptr_array_index(st_ctx->statfiles, id);
+
+ if (rspamd_task_find_symbol_result(task, st->stcf->symbol, NULL)) {
+ if (is_spam == !!st->stcf->is_spam) {
+ msg_debug_bayes("do not autolearn %s as symbol %s is already "
+ "added",
+ is_spam ? "spam" : "ham", st->stcf->symbol);
+
+ return TRUE;
+ }
+ }
+ }
+
+ return FALSE;
+}
+
+gboolean
+rspamd_stat_check_autolearn(struct rspamd_task *task)
+{
+ struct rspamd_stat_ctx *st_ctx;
+ struct rspamd_classifier *cl;
+ const ucl_object_t *obj, *elt1, *elt2;
+ struct rspamd_scan_result *mres = NULL;
+ struct rspamd_task **ptask;
+ lua_State *L;
+ guint i;
+ gint err_idx;
+ gboolean ret = FALSE;
+ gdouble ham_score, spam_score;
+ const gchar *lua_script, *lua_ret;
+
+ g_assert(RSPAMD_TASK_IS_CLASSIFIED(task));
+ st_ctx = rspamd_stat_get_ctx();
+ g_assert(st_ctx != NULL);
+
+ L = task->cfg->lua_state;
+
+ for (i = 0; i < st_ctx->classifiers->len; i++) {
+ cl = g_ptr_array_index(st_ctx->classifiers, i);
+ ret = FALSE;
+
+ if (cl->cfg->opts) {
+ obj = ucl_object_lookup(cl->cfg->opts, "autolearn");
+
+ if (ucl_object_type(obj) == UCL_BOOLEAN) {
+ /* Legacy true/false */
+ if (ucl_object_toboolean(obj)) {
+ /*
+ * Default learning algorithm:
+ *
+ * - We learn spam if action is ACTION_REJECT
+ * - We learn ham if score is less than zero
+ */
+ mres = task->result;
+
+ if (mres) {
+ if (mres->score > rspamd_task_get_required_score(task, mres)) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+
+ ret = TRUE;
+ }
+ else if (mres->score < 0) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ ret = TRUE;
+ }
+ }
+ }
+ }
+ else if (ucl_object_type(obj) == UCL_ARRAY && obj->len == 2) {
+ /* Legacy thresholds */
+ /*
+ * We have an array of 2 elements, treat it as a
+ * ham_score, spam_score
+ */
+ elt1 = ucl_array_find_index(obj, 0);
+ elt2 = ucl_array_find_index(obj, 1);
+
+ if ((ucl_object_type(elt1) == UCL_FLOAT ||
+ ucl_object_type(elt1) == UCL_INT) &&
+ (ucl_object_type(elt2) == UCL_FLOAT ||
+ ucl_object_type(elt2) == UCL_INT)) {
+ ham_score = ucl_object_todouble(elt1);
+ spam_score = ucl_object_todouble(elt2);
+
+ if (ham_score > spam_score) {
+ gdouble t;
+
+ t = ham_score;
+ ham_score = spam_score;
+ spam_score = t;
+ }
+
+ mres = task->result;
+
+ if (mres) {
+ if (mres->score >= spam_score) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+
+ ret = TRUE;
+ }
+ else if (mres->score <= ham_score) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ ret = TRUE;
+ }
+ }
+ }
+ }
+ else if (ucl_object_type(obj) == UCL_STRING) {
+ /* Legacy script */
+ lua_script = ucl_object_tostring(obj);
+
+ if (luaL_dostring(L, lua_script) != 0) {
+ msg_err_task("cannot execute lua script for autolearn "
+ "extraction: %s",
+ lua_tostring(L, -1));
+ }
+ else {
+ if (lua_type(L, -1) == LUA_TFUNCTION) {
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ err_idx = lua_gettop(L);
+ lua_pushvalue(L, -2); /* Function itself */
+
+ ptask = lua_newuserdata(L, sizeof(struct rspamd_task *));
+ *ptask = task;
+ rspamd_lua_setclass(L, "rspamd{task}", -1);
+
+ if (lua_pcall(L, 1, 1, err_idx) != 0) {
+ msg_err_task("call to autolearn script failed: "
+ "%s",
+ lua_tostring(L, -1));
+ }
+ else {
+ lua_ret = lua_tostring(L, -1);
+
+ /* We can have immediate results */
+ if (lua_ret) {
+ if (strcmp(lua_ret, "ham") == 0) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ ret = TRUE;
+ }
+ else if (strcmp(lua_ret, "spam") == 0) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+ ret = TRUE;
+ }
+ }
+ }
+
+ /* Result + error function + original function */
+ lua_pop(L, 3);
+ }
+ else {
+ msg_err_task("lua script must return "
+ "function(task) and not %s",
+ lua_typename(L, lua_type(
+ L, -1)));
+ }
+ }
+ }
+ else if (ucl_object_type(obj) == UCL_OBJECT) {
+ /* Try to find autolearn callback */
+ if (cl->autolearn_cbref == 0) {
+ /* We don't have preprocessed cb id, so try to get it */
+ if (!rspamd_lua_require_function(L, "lua_bayes_learn",
+ "autolearn")) {
+ msg_err_task("cannot get autolearn library from "
+ "`lua_bayes_learn`");
+ }
+ else {
+ cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX);
+ }
+ }
+
+ if (cl->autolearn_cbref != -1) {
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ err_idx = lua_gettop(L);
+ lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref);
+
+ ptask = lua_newuserdata(L, sizeof(struct rspamd_task *));
+ *ptask = task;
+ rspamd_lua_setclass(L, "rspamd{task}", -1);
+ /* Push the whole object as well */
+ ucl_object_push_lua(L, obj, true);
+
+ if (lua_pcall(L, 2, 1, err_idx) != 0) {
+ msg_err_task("call to autolearn script failed: "
+ "%s",
+ lua_tostring(L, -1));
+ }
+ else {
+ lua_ret = lua_tostring(L, -1);
+
+ if (lua_ret) {
+ if (strcmp(lua_ret, "ham") == 0) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ ret = TRUE;
+ }
+ else if (strcmp(lua_ret, "spam") == 0) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+ ret = TRUE;
+ }
+ }
+ }
+
+ lua_settop(L, err_idx - 1);
+ }
+ }
+
+ if (ret) {
+ /* Do not autolearn if we have this symbol already */
+ if (rspamd_stat_has_classifier_symbols(task, mres, cl)) {
+ ret = FALSE;
+ task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM |
+ RSPAMD_TASK_FLAG_LEARN_SPAM);
+ }
+ else if (mres != NULL) {
+ if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
+ msg_info_task("<%s>: autolearn ham for classifier "
+ "'%s' as message's "
+ "score is negative: %.2f",
+ MESSAGE_FIELD(task, message_id), cl->cfg->name,
+ mres->score);
+ }
+ else {
+ msg_info_task("<%s>: autolearn spam for classifier "
+ "'%s' as message's "
+ "action is reject, score: %.2f",
+ MESSAGE_FIELD(task, message_id), cl->cfg->name,
+ mres->score);
+ }
+
+ task->classifier = cl->cfg->name;
+ break;
+ }
+ }
+ }
+ }
+
+ return ret;
+}
+
+/**
+ * Get the overall statistics for all statfile backends
+ * @param cfg configuration
+ * @param total_learns the total number of learns is stored here
+ * @return array of statistical information
+ */
+rspamd_stat_result_t
+rspamd_stat_statistics(struct rspamd_task *task,
+ struct rspamd_config *cfg,
+ guint64 *total_learns,
+ ucl_object_t **target)
+{
+ struct rspamd_stat_ctx *st_ctx;
+ struct rspamd_classifier *cl;
+ struct rspamd_statfile *st;
+ gpointer backend_runtime;
+ ucl_object_t *res = NULL, *elt;
+ guint64 learns = 0;
+ guint i, j;
+ gint id;
+
+ st_ctx = rspamd_stat_get_ctx();
+ g_assert(st_ctx != NULL);
+
+ res = ucl_object_typed_new(UCL_ARRAY);
+
+ for (i = 0; i < st_ctx->classifiers->len; i++) {
+ cl = g_ptr_array_index(st_ctx->classifiers, i);
+
+ if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
+ continue;
+ }
+
+ for (j = 0; j < cl->statfiles_ids->len; j++) {
+ id = g_array_index(cl->statfiles_ids, gint, j);
+ st = g_ptr_array_index(st_ctx->statfiles, id);
+ backend_runtime = st->backend->runtime(task, st->stcf, FALSE,
+ st->bkcf, id);
+ elt = st->backend->get_stat(backend_runtime, st->bkcf);
+
+ if (elt && ucl_object_type(elt) == UCL_OBJECT) {
+ const ucl_object_t *rev = ucl_object_lookup(elt, "revision");
+
+ learns += ucl_object_toint(rev);
+ }
+ else {
+ learns += st->backend->total_learns(task, backend_runtime,
+ st->bkcf);
+ }
+
+ if (elt != NULL) {
+ ucl_array_append(res, elt);
+ }
+ }
+ }
+
+ if (total_learns != NULL) {
+ *total_learns = learns;
+ }
+
+ if (target) {
+ *target = res;
+ }
+ else {
+ ucl_object_unref(res);
+ }
+
+ return RSPAMD_STAT_PROCESS_OK;
+}
diff --git a/src/libstat/tokenizers/osb.c b/src/libstat/tokenizers/osb.c
new file mode 100644
index 0000000..d871c7a
--- /dev/null
+++ b/src/libstat/tokenizers/osb.c
@@ -0,0 +1,424 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/*
+ * OSB tokenizer
+ */
+
+
+#include "tokenizers.h"
+#include "stat_internal.h"
+#include "libmime/lang_detection.h"
+
+/* Size for features pipe */
+#define DEFAULT_FEATURE_WINDOW_SIZE 5
+#define DEFAULT_OSB_VERSION 2
+
+static const int primes[] = {
+ 1,
+ 7,
+ 3,
+ 13,
+ 5,
+ 29,
+ 11,
+ 51,
+ 23,
+ 101,
+ 47,
+ 203,
+ 97,
+ 407,
+ 197,
+ 817,
+ 397,
+ 1637,
+ 797,
+ 3277,
+};
+
+static const guchar osb_tokenizer_magic[] = {'o', 's', 'b', 't', 'o', 'k', 'v', '2'};
+
+enum rspamd_osb_hash_type {
+ RSPAMD_OSB_HASH_COMPAT = 0,
+ RSPAMD_OSB_HASH_XXHASH,
+ RSPAMD_OSB_HASH_SIPHASH
+};
+
+struct rspamd_osb_tokenizer_config {
+ guchar magic[8];
+ gshort version;
+ gshort window_size;
+ enum rspamd_osb_hash_type ht;
+ guint64 seed;
+ rspamd_sipkey_t sk;
+};
+
+/*
+ * Return default config
+ */
+static struct rspamd_osb_tokenizer_config *
+rspamd_tokenizer_osb_default_config(void)
+{
+ static struct rspamd_osb_tokenizer_config def;
+
+ if (memcmp(def.magic, osb_tokenizer_magic, sizeof(osb_tokenizer_magic)) != 0) {
+ memset(&def, 0, sizeof(def));
+ memcpy(def.magic, osb_tokenizer_magic, sizeof(osb_tokenizer_magic));
+ def.version = DEFAULT_OSB_VERSION;
+ def.window_size = DEFAULT_FEATURE_WINDOW_SIZE;
+ def.ht = RSPAMD_OSB_HASH_XXHASH;
+ def.seed = 0xdeadbabe;
+ }
+
+ return &def;
+}
+
+static struct rspamd_osb_tokenizer_config *
+rspamd_tokenizer_osb_config_from_ucl(rspamd_mempool_t *pool,
+ const ucl_object_t *obj)
+{
+ const ucl_object_t *elt;
+ struct rspamd_osb_tokenizer_config *cf, *def;
+ guchar *key = NULL;
+ gsize keylen;
+
+
+ if (pool != NULL) {
+ cf = rspamd_mempool_alloc0(pool, sizeof(*cf));
+ }
+ else {
+ cf = g_malloc0(sizeof(*cf));
+ }
+
+ /* Use default config */
+ def = rspamd_tokenizer_osb_default_config();
+ memcpy(cf, def, sizeof(*cf));
+
+ elt = ucl_object_lookup(obj, "hash");
+ if (elt != NULL && ucl_object_type(elt) == UCL_STRING) {
+ if (g_ascii_strncasecmp(ucl_object_tostring(elt), "xxh", 3) == 0) {
+ cf->ht = RSPAMD_OSB_HASH_XXHASH;
+ elt = ucl_object_lookup(obj, "seed");
+ if (elt != NULL && ucl_object_type(elt) == UCL_INT) {
+ cf->seed = ucl_object_toint(elt);
+ }
+ }
+ else if (g_ascii_strncasecmp(ucl_object_tostring(elt), "sip", 3) == 0) {
+ cf->ht = RSPAMD_OSB_HASH_SIPHASH;
+ elt = ucl_object_lookup(obj, "key");
+
+ if (elt != NULL && ucl_object_type(elt) == UCL_STRING) {
+ key = rspamd_decode_base32(ucl_object_tostring(elt),
+ 0, &keylen, RSPAMD_BASE32_DEFAULT);
+ if (keylen < sizeof(rspamd_sipkey_t)) {
+ msg_warn("siphash key is too short: %z", keylen);
+ g_free(key);
+ }
+ else {
+ memcpy(cf->sk, key, sizeof(cf->sk));
+ g_free(key);
+ }
+ }
+ else {
+ msg_warn_pool("siphash cannot be used without key");
+ }
+ }
+ }
+ else {
+ elt = ucl_object_lookup(obj, "compat");
+ if (elt != NULL && ucl_object_toboolean(elt)) {
+ cf->ht = RSPAMD_OSB_HASH_COMPAT;
+ }
+ }
+
+ elt = ucl_object_lookup(obj, "window");
+ if (elt != NULL && ucl_object_type(elt) == UCL_INT) {
+ cf->window_size = ucl_object_toint(elt);
+ if (cf->window_size > DEFAULT_FEATURE_WINDOW_SIZE * 4) {
+ msg_err_pool("too large window size: %d", cf->window_size);
+ cf->window_size = DEFAULT_FEATURE_WINDOW_SIZE;
+ }
+ }
+
+ return cf;
+}
+
+gpointer
+rspamd_tokenizer_osb_get_config(rspamd_mempool_t *pool,
+ struct rspamd_tokenizer_config *cf,
+ gsize *len)
+{
+ struct rspamd_osb_tokenizer_config *osb_cf, *def;
+
+ if (cf != NULL && cf->opts != NULL) {
+ osb_cf = rspamd_tokenizer_osb_config_from_ucl(pool, cf->opts);
+ }
+ else {
+ def = rspamd_tokenizer_osb_default_config();
+ osb_cf = rspamd_mempool_alloc(pool, sizeof(*osb_cf));
+ memcpy(osb_cf, def, sizeof(*osb_cf));
+ /* Do not write sipkey to statfile */
+ }
+
+ if (osb_cf->ht == RSPAMD_OSB_HASH_SIPHASH) {
+ msg_info_pool("siphash key is not stored into statfiles, so you'd "
+ "need to keep it inside the configuration");
+ }
+
+ memset(osb_cf->sk, 0, sizeof(osb_cf->sk));
+
+ if (len != NULL) {
+ *len = sizeof(*osb_cf);
+ }
+
+ return osb_cf;
+}
+
+#if 0
+gboolean
+rspamd_tokenizer_osb_compatible_config (struct rspamd_tokenizer_runtime *rt,
+ gpointer ptr, gsize len)
+{
+ struct rspamd_osb_tokenizer_config *osb_cf, *test_cf;
+ gboolean ret = FALSE;
+
+ test_cf = rt->config;
+ g_assert (test_cf != NULL);
+
+ if (len == sizeof (*osb_cf)) {
+ osb_cf = ptr;
+
+ if (memcmp (osb_cf, osb_tokenizer_magic, sizeof (osb_tokenizer_magic)) != 0) {
+ ret = test_cf->ht == RSPAMD_OSB_HASH_COMPAT;
+ }
+ else {
+ if (osb_cf->version == DEFAULT_OSB_VERSION) {
+ /* We can compare them directly now */
+ ret = (memcmp (osb_cf, test_cf, sizeof (*osb_cf)
+ - sizeof (osb_cf->sk))) == 0;
+ }
+ }
+ }
+ else {
+ /* We are compatible now merely with fallback config */
+ if (test_cf->ht == RSPAMD_OSB_HASH_COMPAT) {
+ ret = TRUE;
+ }
+ }
+
+ return ret;
+}
+
+gboolean
+rspamd_tokenizer_osb_load_config (rspamd_mempool_t *pool,
+ struct rspamd_tokenizer_runtime *rt,
+ gpointer ptr, gsize len)
+{
+ struct rspamd_osb_tokenizer_config *osb_cf;
+
+ if (ptr == NULL || len == 0) {
+ osb_cf = rspamd_tokenizer_osb_config_from_ucl (pool, rt->tkcf->opts);
+
+ if (osb_cf->ht != RSPAMD_OSB_HASH_COMPAT) {
+ /* Trying to load incompatible configuration */
+ msg_err_pool ("cannot load tokenizer configuration from a legacy "
+ "statfile; maybe you have forgotten to set 'compat' option"
+ " in the tokenizer configuration");
+
+ return FALSE;
+ }
+ }
+ else {
+ g_assert (len == sizeof (*osb_cf));
+ osb_cf = ptr;
+ }
+
+ rt->config = osb_cf;
+ rt->conf_len = sizeof (*osb_cf);
+
+ return TRUE;
+}
+
+gboolean
+rspamd_tokenizer_osb_is_compat (struct rspamd_tokenizer_runtime *rt)
+{
+ struct rspamd_osb_tokenizer_config *osb_cf = rt->config;
+
+ return (osb_cf->ht == RSPAMD_OSB_HASH_COMPAT);
+}
+#endif
+
+struct token_pipe_entry {
+ guint64 h;
+ rspamd_stat_token_t *t;
+};
+
+gint rspamd_tokenizer_osb(struct rspamd_stat_ctx *ctx,
+ struct rspamd_task *task,
+ GArray *words,
+ gboolean is_utf,
+ const gchar *prefix,
+ GPtrArray *result)
+{
+ rspamd_token_t *new_tok = NULL;
+ rspamd_stat_token_t *token;
+ struct rspamd_osb_tokenizer_config *osb_cf;
+ guint64 cur, seed;
+ struct token_pipe_entry *hashpipe;
+ guint32 h1, h2;
+ gsize token_size;
+ guint processed = 0, i, w, window_size, token_flags = 0;
+
+ if (words == NULL) {
+ return FALSE;
+ }
+
+ osb_cf = ctx->tkcf;
+ window_size = osb_cf->window_size;
+
+ if (prefix) {
+ seed = rspamd_cryptobox_fast_hash_specific(RSPAMD_CRYPTOBOX_XXHASH64,
+ prefix, strlen(prefix), osb_cf->seed);
+ }
+ else {
+ seed = osb_cf->seed;
+ }
+
+ hashpipe = g_alloca(window_size * sizeof(hashpipe[0]));
+ for (i = 0; i < window_size; i++) {
+ hashpipe[i].h = 0xfe;
+ hashpipe[i].t = NULL;
+ }
+
+ token_size = sizeof(rspamd_token_t) +
+ sizeof(gdouble) * ctx->statfiles->len;
+ g_assert(token_size > 0);
+
+ for (w = 0; w < words->len; w++) {
+ token = &g_array_index(words, rspamd_stat_token_t, w);
+ token_flags = token->flags;
+ const gchar *begin;
+ gsize len;
+
+ if (token->flags &
+ (RSPAMD_STAT_TOKEN_FLAG_STOP_WORD | RSPAMD_STAT_TOKEN_FLAG_SKIPPED)) {
+ /* Skip stop/skipped words */
+ continue;
+ }
+
+ if (token->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT) {
+ begin = token->stemmed.begin;
+ len = token->stemmed.len;
+ }
+ else {
+ begin = token->original.begin;
+ len = token->original.len;
+ }
+
+ if (osb_cf->ht == RSPAMD_OSB_HASH_COMPAT) {
+ rspamd_ftok_t ftok;
+
+ ftok.begin = begin;
+ ftok.len = len;
+ cur = rspamd_fstrhash_lc(&ftok, is_utf);
+ }
+ else {
+ /* We know that the words are normalized */
+ if (osb_cf->ht == RSPAMD_OSB_HASH_XXHASH) {
+ cur = rspamd_cryptobox_fast_hash_specific(RSPAMD_CRYPTOBOX_XXHASH64,
+ begin, len, osb_cf->seed);
+ }
+ else {
+ rspamd_cryptobox_siphash((guchar *) &cur, begin,
+ len, osb_cf->sk);
+
+ if (prefix) {
+ cur ^= seed;
+ }
+ }
+ }
+
+ if (token_flags & RSPAMD_STAT_TOKEN_FLAG_UNIGRAM) {
+ new_tok = rspamd_mempool_alloc0(task->task_pool, token_size);
+ new_tok->flags = token_flags;
+ new_tok->t1 = token;
+ new_tok->t2 = token;
+ new_tok->data = cur;
+ new_tok->window_idx = 0;
+ g_ptr_array_add(result, new_tok);
+
+ continue;
+ }
+
+#define ADD_TOKEN \
+ do { \
+ new_tok = rspamd_mempool_alloc0(task->task_pool, token_size); \
+ new_tok->flags = token_flags; \
+ new_tok->t1 = hashpipe[0].t; \
+ new_tok->t2 = hashpipe[i].t; \
+ if (osb_cf->ht == RSPAMD_OSB_HASH_COMPAT) { \
+ h1 = ((guint32) hashpipe[0].h) * primes[0] + \
+ ((guint32) hashpipe[i].h) * primes[i << 1]; \
+ h2 = ((guint32) hashpipe[0].h) * primes[1] + \
+ ((guint32) hashpipe[i].h) * primes[(i << 1) - 1]; \
+ memcpy((guchar *) &new_tok->data, &h1, sizeof(h1)); \
+ memcpy(((guchar *) &new_tok->data) + sizeof(h1), &h2, sizeof(h2)); \
+ } \
+ else { \
+ new_tok->data = hashpipe[0].h * primes[0] + hashpipe[i].h * primes[i << 1]; \
+ } \
+ new_tok->window_idx = i; \
+ g_ptr_array_add(result, new_tok); \
+ } while (0)
+
+ if (processed < window_size) {
+ /* Just fill a hashpipe */
+ ++processed;
+ hashpipe[window_size - processed].h = cur;
+ hashpipe[window_size - processed].t = token;
+ }
+ else {
+ /* Shift hashpipe */
+ for (i = window_size - 1; i > 0; i--) {
+ hashpipe[i] = hashpipe[i - 1];
+ }
+ hashpipe[0].h = cur;
+ hashpipe[0].t = token;
+
+ processed++;
+
+ for (i = 1; i < window_size; i++) {
+ if (!(hashpipe[i].t->flags & RSPAMD_STAT_TOKEN_FLAG_EXCEPTION)) {
+ ADD_TOKEN;
+ }
+ }
+ }
+ }
+
+ if (processed > 1 && processed <= window_size) {
+ processed--;
+ memmove(hashpipe, &hashpipe[window_size - processed],
+ processed * sizeof(hashpipe[0]));
+
+ for (i = 1; i < processed; i++) {
+ ADD_TOKEN;
+ }
+ }
+
+#undef ADD_TOKEN
+
+ return TRUE;
+}
diff --git a/src/libstat/tokenizers/tokenizers.c b/src/libstat/tokenizers/tokenizers.c
new file mode 100644
index 0000000..ee7234d
--- /dev/null
+++ b/src/libstat/tokenizers/tokenizers.c
@@ -0,0 +1,955 @@
+/*
+ * Copyright 2023 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/*
+ * Common tokenization functions
+ */
+
+#include "rspamd.h"
+#include "tokenizers.h"
+#include "stat_internal.h"
+#include "contrib/mumhash/mum.h"
+#include "libmime/lang_detection.h"
+#include "libstemmer.h"
+
+#include <unicode/utf8.h>
+#include <unicode/uchar.h>
+#include <unicode/uiter.h>
+#include <unicode/ubrk.h>
+#include <unicode/ucnv.h>
+#if U_ICU_VERSION_MAJOR_NUM >= 44
+#include <unicode/unorm2.h>
+#endif
+
+#include <math.h>
+
+typedef gboolean (*token_get_function)(rspamd_stat_token_t *buf, gchar const **pos,
+ rspamd_stat_token_t *token,
+ GList **exceptions, gsize *rl, gboolean check_signature);
+
+const gchar t_delimiters[256] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+ 1, 0, 0, 1, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 1, 0, 1, 1, 1, 1, 1, 0,
+ 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
+ 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0};
+
+/* Get next word from specified f_str_t buf */
+static gboolean
+rspamd_tokenizer_get_word_raw(rspamd_stat_token_t *buf,
+ gchar const **cur, rspamd_stat_token_t *token,
+ GList **exceptions, gsize *rl, gboolean unused)
+{
+ gsize remain, pos;
+ const gchar *p;
+ struct rspamd_process_exception *ex = NULL;
+
+ if (buf == NULL) {
+ return FALSE;
+ }
+
+ g_assert(cur != NULL);
+
+ if (exceptions != NULL && *exceptions != NULL) {
+ ex = (*exceptions)->data;
+ }
+
+ if (token->original.begin == NULL || *cur == NULL) {
+ if (ex != NULL) {
+ if (ex->pos == 0) {
+ token->original.begin = buf->original.begin + ex->len;
+ token->original.len = ex->len;
+ token->flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION;
+ }
+ else {
+ token->original.begin = buf->original.begin;
+ token->original.len = 0;
+ }
+ }
+ else {
+ token->original.begin = buf->original.begin;
+ token->original.len = 0;
+ }
+ *cur = token->original.begin;
+ }
+
+ token->original.len = 0;
+
+ pos = *cur - buf->original.begin;
+ if (pos >= buf->original.len) {
+ return FALSE;
+ }
+
+ remain = buf->original.len - pos;
+ p = *cur;
+
+ /* Skip non delimiters symbols */
+ do {
+ if (ex != NULL && ex->pos == pos) {
+ /* Go to the next exception */
+ *exceptions = g_list_next(*exceptions);
+ *cur = p + ex->len;
+ return TRUE;
+ }
+ pos++;
+ p++;
+ remain--;
+ } while (remain > 0 && t_delimiters[(guchar) *p]);
+
+ token->original.begin = p;
+
+ while (remain > 0 && !t_delimiters[(guchar) *p]) {
+ if (ex != NULL && ex->pos == pos) {
+ *exceptions = g_list_next(*exceptions);
+ *cur = p + ex->len;
+ return TRUE;
+ }
+ token->original.len++;
+ pos++;
+ remain--;
+ p++;
+ }
+
+ if (remain == 0) {
+ return FALSE;
+ }
+
+ if (rl) {
+ *rl = token->original.len;
+ }
+
+ token->flags = RSPAMD_STAT_TOKEN_FLAG_TEXT;
+
+ *cur = p;
+
+ return TRUE;
+}
+
+static inline gboolean
+rspamd_tokenize_check_limit(gboolean decay,
+ guint word_decay,
+ guint nwords,
+ guint64 *hv,
+ guint64 *prob,
+ const rspamd_stat_token_t *token,
+ gssize remain,
+ gssize total)
+{
+ static const gdouble avg_word_len = 6.0;
+
+ if (!decay) {
+ if (token->original.len >= sizeof(guint64)) {
+ guint64 tmp;
+ memcpy(&tmp, token->original.begin, sizeof(tmp));
+ *hv = mum_hash_step(*hv, tmp);
+ }
+
+ /* Check for decay */
+ if (word_decay > 0 && nwords > word_decay && remain < (gssize) total) {
+ /* Start decay */
+ gdouble decay_prob;
+
+ *hv = mum_hash_finish(*hv);
+
+ /* We assume that word is 6 symbols length in average */
+ decay_prob = (gdouble) word_decay / ((total - (remain)) / avg_word_len) * 10;
+ decay_prob = floor(decay_prob) / 10.0;
+
+ if (decay_prob >= 1.0) {
+ *prob = G_MAXUINT64;
+ }
+ else {
+ *prob = (guint64) (decay_prob * (double) G_MAXUINT64);
+ }
+
+ return TRUE;
+ }
+ }
+ else {
+ /* Decaying probability */
+ /* LCG64 x[n] = a x[n - 1] + b mod 2^64 */
+ *hv = (*hv) * 2862933555777941757ULL + 3037000493ULL;
+
+ if (*hv > *prob) {
+ return TRUE;
+ }
+ }
+
+ return FALSE;
+}
+
+static inline gboolean
+rspamd_utf_word_valid(const guchar *text, const guchar *end,
+ gint32 start, gint32 finish)
+{
+ const guchar *st = text + start, *fin = text + finish;
+ UChar32 c;
+
+ if (st >= end || fin > end || st >= fin) {
+ return FALSE;
+ }
+
+ U8_NEXT(text, start, finish, c);
+
+ if (u_isJavaIDPart(c)) {
+ return TRUE;
+ }
+
+ return FALSE;
+}
+#define SHIFT_EX \
+ do { \
+ cur = g_list_next(cur); \
+ if (cur) { \
+ ex = (struct rspamd_process_exception *) cur->data; \
+ } \
+ else { \
+ ex = NULL; \
+ } \
+ } while (0)
+
+static inline void
+rspamd_tokenize_exception(struct rspamd_process_exception *ex, GArray *res)
+{
+ rspamd_stat_token_t token;
+
+ memset(&token, 0, sizeof(token));
+
+ if (ex->type == RSPAMD_EXCEPTION_GENERIC) {
+ token.original.begin = "!!EX!!";
+ token.original.len = sizeof("!!EX!!") - 1;
+ token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION;
+
+ g_array_append_val(res, token);
+ token.flags = 0;
+ }
+ else if (ex->type == RSPAMD_EXCEPTION_URL) {
+ struct rspamd_url *uri;
+
+ uri = ex->ptr;
+
+ if (uri && uri->tldlen > 0) {
+ token.original.begin = rspamd_url_tld_unsafe(uri);
+ token.original.len = uri->tldlen;
+ }
+ else {
+ token.original.begin = "!!EX!!";
+ token.original.len = sizeof("!!EX!!") - 1;
+ }
+
+ token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION;
+ g_array_append_val(res, token);
+ token.flags = 0;
+ }
+}
+
+
+GArray *
+rspamd_tokenize_text(const gchar *text, gsize len,
+ const UText *utxt,
+ enum rspamd_tokenize_type how,
+ struct rspamd_config *cfg,
+ GList *exceptions,
+ guint64 *hash,
+ GArray *cur_words,
+ rspamd_mempool_t *pool)
+{
+ rspamd_stat_token_t token, buf;
+ const gchar *pos = NULL;
+ gsize l = 0;
+ GArray *res;
+ GList *cur = exceptions;
+ guint min_len = 0, max_len = 0, word_decay = 0, initial_size = 128;
+ guint64 hv = 0;
+ gboolean decay = FALSE, long_text_mode = FALSE;
+ guint64 prob = 0;
+ static UBreakIterator *bi = NULL;
+ static const gsize long_text_limit = 1 * 1024 * 1024;
+ static const ev_tstamp max_exec_time = 0.2; /* 200 ms */
+ ev_tstamp start;
+
+ if (text == NULL) {
+ return cur_words;
+ }
+
+ if (len > long_text_limit) {
+ /*
+ * In this mode we do additional checks to avoid performance issues
+ */
+ long_text_mode = TRUE;
+ start = ev_time();
+ }
+
+ buf.original.begin = text;
+ buf.original.len = len;
+ buf.flags = 0;
+
+ memset(&token, 0, sizeof(token));
+
+ if (cfg != NULL) {
+ min_len = cfg->min_word_len;
+ max_len = cfg->max_word_len;
+ word_decay = cfg->words_decay;
+ initial_size = word_decay * 2;
+ }
+
+ if (!cur_words) {
+ res = g_array_sized_new(FALSE, FALSE, sizeof(rspamd_stat_token_t),
+ initial_size);
+ }
+ else {
+ res = cur_words;
+ }
+
+ if (G_UNLIKELY(how == RSPAMD_TOKENIZE_RAW || utxt == NULL)) {
+ while (rspamd_tokenizer_get_word_raw(&buf, &pos, &token, &cur, &l, FALSE)) {
+ if (l == 0 || (min_len > 0 && l < min_len) ||
+ (max_len > 0 && l > max_len)) {
+ token.original.begin = pos;
+ continue;
+ }
+
+ if (token.original.len > 0 &&
+ rspamd_tokenize_check_limit(decay, word_decay, res->len,
+ &hv, &prob, &token, pos - text, len)) {
+ if (!decay) {
+ decay = TRUE;
+ }
+ else {
+ token.original.begin = pos;
+ continue;
+ }
+ }
+
+ if (long_text_mode) {
+ if ((res->len + 1) % 16 == 0) {
+ ev_tstamp now = ev_time();
+
+ if (now - start > max_exec_time) {
+ msg_warn_pool_check(
+ "too long time has been spent on tokenization:"
+ " %.1f ms, limit is %.1f ms; %d words added so far",
+ (now - start) * 1e3, max_exec_time * 1e3,
+ res->len);
+
+ goto end;
+ }
+ }
+ }
+
+ g_array_append_val(res, token);
+
+ if (((gsize) res->len) * sizeof(token) > (0x1ull << 30u)) {
+ /* Due to bug in glib ! */
+ msg_err_pool_check(
+ "too many words found: %d, stop tokenization to avoid DoS",
+ res->len);
+
+ goto end;
+ }
+
+ token.original.begin = pos;
+ }
+ }
+ else {
+ /* UTF8 boundaries */
+ UErrorCode uc_err = U_ZERO_ERROR;
+ int32_t last, p;
+ struct rspamd_process_exception *ex = NULL;
+
+ if (bi == NULL) {
+ bi = ubrk_open(UBRK_WORD, NULL, NULL, 0, &uc_err);
+
+ g_assert(U_SUCCESS(uc_err));
+ }
+
+ ubrk_setUText(bi, (UText *) utxt, &uc_err);
+ last = ubrk_first(bi);
+ p = last;
+
+ if (cur) {
+ ex = (struct rspamd_process_exception *) cur->data;
+ }
+
+ while (p != UBRK_DONE) {
+ start_over:
+ token.original.len = 0;
+
+ if (p > last) {
+ if (ex && cur) {
+ /* Check exception */
+ if (ex->pos >= last && ex->pos <= p) {
+ /* We have an exception within boundary */
+ /* First, start to drain exceptions from the start */
+ while (cur && ex->pos <= last) {
+ /* We have an exception at the beginning, skip those */
+ last += ex->len;
+ rspamd_tokenize_exception(ex, res);
+
+ if (last > p) {
+ /* Exception spread over the boundaries */
+ while (last > p && p != UBRK_DONE) {
+ gint32 old_p = p;
+ p = ubrk_next(bi);
+
+ if (p != UBRK_DONE && p <= old_p) {
+ msg_warn_pool_check(
+ "tokenization reversed back on position %d,"
+ "%d new position (%d backward), likely libicu bug!",
+ (gint) (p), (gint) (old_p), old_p - p);
+
+ goto end;
+ }
+ }
+
+ /* We need to reset our scan with new p and last */
+ SHIFT_EX;
+ goto start_over;
+ }
+
+ SHIFT_EX;
+ }
+
+ /* Now, we can have an exception within boundary again */
+ if (cur && ex->pos >= last && ex->pos <= p) {
+ /* Append the first part */
+ if (rspamd_utf_word_valid(text, text + len, last,
+ ex->pos)) {
+ token.original.begin = text + last;
+ token.original.len = ex->pos - last;
+ token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT |
+ RSPAMD_STAT_TOKEN_FLAG_UTF;
+ }
+
+ /* Process the current exception */
+ last += ex->len + (ex->pos - last);
+
+ rspamd_tokenize_exception(ex, res);
+
+ if (last > p) {
+ /* Exception spread over the boundaries */
+ while (last > p && p != UBRK_DONE) {
+ gint32 old_p = p;
+ p = ubrk_next(bi);
+ if (p != UBRK_DONE && p <= old_p) {
+ msg_warn_pool_check(
+ "tokenization reversed back on position %d,"
+ "%d new position (%d backward), likely libicu bug!",
+ (gint) (p), (gint) (old_p), old_p - p);
+
+ goto end;
+ }
+ }
+ /* We need to reset our scan with new p and last */
+ SHIFT_EX;
+ goto start_over;
+ }
+
+ SHIFT_EX;
+ }
+ else if (p > last) {
+ if (rspamd_utf_word_valid(text, text + len, last, p)) {
+ token.original.begin = text + last;
+ token.original.len = p - last;
+ token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT |
+ RSPAMD_STAT_TOKEN_FLAG_UTF;
+ }
+ }
+ }
+ else if (ex->pos < last) {
+ /* Forward exceptions list */
+ while (cur && ex->pos <= last) {
+ /* We have an exception at the beginning, skip those */
+ SHIFT_EX;
+ }
+
+ if (rspamd_utf_word_valid(text, text + len, last, p)) {
+ token.original.begin = text + last;
+ token.original.len = p - last;
+ token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT |
+ RSPAMD_STAT_TOKEN_FLAG_UTF;
+ }
+ }
+ else {
+ /* No exceptions within boundary */
+ if (rspamd_utf_word_valid(text, text + len, last, p)) {
+ token.original.begin = text + last;
+ token.original.len = p - last;
+ token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT |
+ RSPAMD_STAT_TOKEN_FLAG_UTF;
+ }
+ }
+ }
+ else {
+ if (rspamd_utf_word_valid(text, text + len, last, p)) {
+ token.original.begin = text + last;
+ token.original.len = p - last;
+ token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT |
+ RSPAMD_STAT_TOKEN_FLAG_UTF;
+ }
+ }
+
+ if (token.original.len > 0 &&
+ rspamd_tokenize_check_limit(decay, word_decay, res->len,
+ &hv, &prob, &token, p, len)) {
+ if (!decay) {
+ decay = TRUE;
+ }
+ else {
+ token.flags |= RSPAMD_STAT_TOKEN_FLAG_SKIPPED;
+ }
+ }
+ }
+
+ if (token.original.len > 0) {
+ /* Additional check for number of words */
+ if (((gsize) res->len) * sizeof(token) > (0x1ull << 30u)) {
+ /* Due to bug in glib ! */
+ msg_err("too many words found: %d, stop tokenization to avoid DoS",
+ res->len);
+
+ goto end;
+ }
+
+ g_array_append_val(res, token);
+ }
+
+ /* Also check for long text mode */
+ if (long_text_mode) {
+ /* Check time each 128 words added */
+ const int words_check_mask = 0x7F;
+
+ if ((res->len & words_check_mask) == words_check_mask) {
+ ev_tstamp now = ev_time();
+
+ if (now - start > max_exec_time) {
+ msg_warn_pool_check(
+ "too long time has been spent on tokenization:"
+ " %.1f ms, limit is %.1f ms; %d words added so far",
+ (now - start) * 1e3, max_exec_time * 1e3,
+ res->len);
+
+ goto end;
+ }
+ }
+ }
+
+ last = p;
+ p = ubrk_next(bi);
+
+ if (p != UBRK_DONE && p <= last) {
+ msg_warn_pool_check("tokenization reversed back on position %d,"
+ "%d new position (%d backward), likely libicu bug!",
+ (gint) (p), (gint) (last), last - p);
+
+ goto end;
+ }
+ }
+ }
+
+end:
+ if (!decay) {
+ hv = mum_hash_finish(hv);
+ }
+
+ if (hash) {
+ *hash = hv;
+ }
+
+ return res;
+}
+
+#undef SHIFT_EX
+
+static void
+rspamd_add_metawords_from_str(const gchar *beg, gsize len,
+ struct rspamd_task *task)
+{
+ UText utxt = UTEXT_INITIALIZER;
+ UErrorCode uc_err = U_ZERO_ERROR;
+ guint i = 0;
+ UChar32 uc;
+ gboolean valid_utf = TRUE;
+
+ while (i < len) {
+ U8_NEXT(beg, i, len, uc);
+
+ if (((gint32) uc) < 0) {
+ valid_utf = FALSE;
+ break;
+ }
+
+#if U_ICU_VERSION_MAJOR_NUM < 50
+ if (u_isalpha(uc)) {
+ gint32 sc = ublock_getCode(uc);
+
+ if (sc == UBLOCK_THAI) {
+ valid_utf = FALSE;
+ msg_info_task("enable workaround for Thai characters for old libicu");
+ break;
+ }
+ }
+#endif
+ }
+
+ if (valid_utf) {
+ utext_openUTF8(&utxt,
+ beg,
+ len,
+ &uc_err);
+
+ task->meta_words = rspamd_tokenize_text(beg, len,
+ &utxt, RSPAMD_TOKENIZE_UTF,
+ task->cfg, NULL, NULL,
+ task->meta_words,
+ task->task_pool);
+
+ utext_close(&utxt);
+ }
+ else {
+ task->meta_words = rspamd_tokenize_text(beg, len,
+ NULL, RSPAMD_TOKENIZE_RAW,
+ task->cfg, NULL, NULL, task->meta_words,
+ task->task_pool);
+ }
+}
+
+void rspamd_tokenize_meta_words(struct rspamd_task *task)
+{
+ guint i = 0;
+ rspamd_stat_token_t *tok;
+
+ if (MESSAGE_FIELD(task, subject)) {
+ rspamd_add_metawords_from_str(MESSAGE_FIELD(task, subject),
+ strlen(MESSAGE_FIELD(task, subject)), task);
+ }
+
+ if (MESSAGE_FIELD(task, from_mime) && MESSAGE_FIELD(task, from_mime)->len > 0) {
+ struct rspamd_email_address *addr;
+
+ addr = g_ptr_array_index(MESSAGE_FIELD(task, from_mime), 0);
+
+ if (addr->name) {
+ rspamd_add_metawords_from_str(addr->name, strlen(addr->name), task);
+ }
+ }
+
+ if (task->meta_words != NULL) {
+ const gchar *language = NULL;
+
+ if (MESSAGE_FIELD(task, text_parts) &&
+ MESSAGE_FIELD(task, text_parts)->len > 0) {
+ struct rspamd_mime_text_part *tp = g_ptr_array_index(
+ MESSAGE_FIELD(task, text_parts), 0);
+
+ if (tp->language) {
+ language = tp->language;
+ }
+ }
+
+ rspamd_normalize_words(task->meta_words, task->task_pool);
+ rspamd_stem_words(task->meta_words, task->task_pool, language,
+ task->lang_det);
+
+ for (i = 0; i < task->meta_words->len; i++) {
+ tok = &g_array_index(task->meta_words, rspamd_stat_token_t, i);
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_HEADER;
+ }
+ }
+}
+
+static inline void
+rspamd_uchars_to_ucs32(const UChar *src, gsize srclen,
+ rspamd_stat_token_t *tok,
+ rspamd_mempool_t *pool)
+{
+ UChar32 *dest, t, *d;
+ gint32 i = 0;
+
+ dest = rspamd_mempool_alloc(pool, srclen * sizeof(UChar32));
+ d = dest;
+
+ while (i < srclen) {
+ U16_NEXT_UNSAFE(src, i, t);
+
+ if (u_isgraph(t)) {
+ UCharCategory cat;
+
+ cat = u_charType(t);
+#if U_ICU_VERSION_MAJOR_NUM >= 57
+ if (u_hasBinaryProperty(t, UCHAR_EMOJI)) {
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_EMOJI;
+ }
+#endif
+
+ if ((cat >= U_UPPERCASE_LETTER && cat <= U_OTHER_NUMBER) ||
+ cat == U_CONNECTOR_PUNCTUATION ||
+ cat == U_MATH_SYMBOL ||
+ cat == U_CURRENCY_SYMBOL) {
+ *d++ = u_tolower(t);
+ }
+ }
+ else {
+ /* Invisible spaces ! */
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_INVISIBLE_SPACES;
+ }
+ }
+
+ tok->unicode.begin = dest;
+ tok->unicode.len = d - dest;
+}
+
+static inline void
+rspamd_ucs32_to_normalised(rspamd_stat_token_t *tok,
+ rspamd_mempool_t *pool)
+{
+ guint i, doff = 0;
+ gsize utflen = 0;
+ gchar *dest;
+ UChar32 t;
+
+ for (i = 0; i < tok->unicode.len; i++) {
+ utflen += U8_LENGTH(tok->unicode.begin[i]);
+ }
+
+ dest = rspamd_mempool_alloc(pool, utflen + 1);
+
+ for (i = 0; i < tok->unicode.len; i++) {
+ t = tok->unicode.begin[i];
+ U8_APPEND_UNSAFE(dest, doff, t);
+ }
+
+ g_assert(doff <= utflen);
+ dest[doff] = '\0';
+
+ tok->normalized.len = doff;
+ tok->normalized.begin = dest;
+}
+
+void rspamd_normalize_single_word(rspamd_stat_token_t *tok, rspamd_mempool_t *pool)
+{
+ UErrorCode uc_err = U_ZERO_ERROR;
+ UConverter *utf8_converter;
+ UChar tmpbuf[1024]; /* Assume that we have no longer words... */
+ gsize ulen;
+
+ utf8_converter = rspamd_get_utf8_converter();
+
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UTF) {
+ ulen = ucnv_toUChars(utf8_converter,
+ tmpbuf,
+ G_N_ELEMENTS(tmpbuf),
+ tok->original.begin,
+ tok->original.len,
+ &uc_err);
+
+ /* Now, we need to understand if we need to normalise the word */
+ if (!U_SUCCESS(uc_err)) {
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE;
+ tok->unicode.begin = NULL;
+ tok->unicode.len = 0;
+ tok->normalized.begin = NULL;
+ tok->normalized.len = 0;
+ }
+ else {
+#if U_ICU_VERSION_MAJOR_NUM >= 44
+ const UNormalizer2 *norm = rspamd_get_unicode_normalizer();
+ gint32 end;
+
+ /* We can now check if we need to decompose */
+ end = unorm2_spanQuickCheckYes(norm, tmpbuf, ulen, &uc_err);
+
+ if (!U_SUCCESS(uc_err)) {
+ rspamd_uchars_to_ucs32(tmpbuf, ulen, tok, pool);
+ tok->normalized.begin = NULL;
+ tok->normalized.len = 0;
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE;
+ }
+ else {
+ if (end == ulen) {
+ /* Already normalised, just lowercase */
+ rspamd_uchars_to_ucs32(tmpbuf, ulen, tok, pool);
+ rspamd_ucs32_to_normalised(tok, pool);
+ }
+ else {
+ /* Perform normalization */
+ UChar normbuf[1024];
+
+ g_assert(end < G_N_ELEMENTS(normbuf));
+ /* First part */
+ memcpy(normbuf, tmpbuf, end * sizeof(UChar));
+ /* Second part */
+ ulen = unorm2_normalizeSecondAndAppend(norm,
+ normbuf, end,
+ G_N_ELEMENTS(normbuf),
+ tmpbuf + end,
+ ulen - end,
+ &uc_err);
+
+ if (!U_SUCCESS(uc_err)) {
+ if (uc_err != U_BUFFER_OVERFLOW_ERROR) {
+ msg_warn_pool_check("cannot normalise text '%*s': %s",
+ (gint) tok->original.len, tok->original.begin,
+ u_errorName(uc_err));
+ rspamd_uchars_to_ucs32(tmpbuf, ulen, tok, pool);
+ rspamd_ucs32_to_normalised(tok, pool);
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE;
+ }
+ }
+ else {
+ /* Copy normalised back */
+ rspamd_uchars_to_ucs32(normbuf, ulen, tok, pool);
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_NORMALISED;
+ rspamd_ucs32_to_normalised(tok, pool);
+ }
+ }
+ }
+#else
+ /* Legacy version with no unorm2 interface */
+ rspamd_uchars_to_ucs32(tmpbuf, ulen, tok, pool);
+ rspamd_ucs32_to_normalised(tok, pool);
+#endif
+ }
+ }
+ else {
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT) {
+ /* Simple lowercase */
+ gchar *dest;
+
+ dest = rspamd_mempool_alloc(pool, tok->original.len + 1);
+ rspamd_strlcpy(dest, tok->original.begin, tok->original.len + 1);
+ rspamd_str_lc(dest, tok->original.len);
+ tok->normalized.len = tok->original.len;
+ tok->normalized.begin = dest;
+ }
+ }
+}
+
+void rspamd_normalize_words(GArray *words, rspamd_mempool_t *pool)
+{
+ rspamd_stat_token_t *tok;
+ guint i;
+
+ for (i = 0; i < words->len; i++) {
+ tok = &g_array_index(words, rspamd_stat_token_t, i);
+ rspamd_normalize_single_word(tok, pool);
+ }
+}
+
+void rspamd_stem_words(GArray *words, rspamd_mempool_t *pool,
+ const gchar *language,
+ struct rspamd_lang_detector *lang_detector)
+{
+ static GHashTable *stemmers = NULL;
+ struct sb_stemmer *stem = NULL;
+ guint i;
+ rspamd_stat_token_t *tok;
+ gchar *dest;
+ gsize dlen;
+
+ if (!stemmers) {
+ stemmers = g_hash_table_new(rspamd_strcase_hash,
+ rspamd_strcase_equal);
+ }
+
+ if (language && language[0] != '\0') {
+ stem = g_hash_table_lookup(stemmers, language);
+
+ if (stem == NULL) {
+
+ stem = sb_stemmer_new(language, "UTF_8");
+
+ if (stem == NULL) {
+ msg_debug_pool(
+ "cannot create lemmatizer for %s language",
+ language);
+ g_hash_table_insert(stemmers, g_strdup(language),
+ GINT_TO_POINTER(-1));
+ }
+ else {
+ g_hash_table_insert(stemmers, g_strdup(language),
+ stem);
+ }
+ }
+ else if (stem == GINT_TO_POINTER(-1)) {
+ /* Negative cache */
+ stem = NULL;
+ }
+ }
+ for (i = 0; i < words->len; i++) {
+ tok = &g_array_index(words, rspamd_stat_token_t, i);
+
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UTF) {
+ if (stem) {
+ const gchar *stemmed = NULL;
+
+ stemmed = sb_stemmer_stem(stem,
+ tok->normalized.begin, tok->normalized.len);
+
+ dlen = sb_stemmer_length(stem);
+
+ if (stemmed != NULL && dlen > 0) {
+ dest = rspamd_mempool_alloc(pool, dlen);
+ memcpy(dest, stemmed, dlen);
+ tok->stemmed.len = dlen;
+ tok->stemmed.begin = dest;
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_STEMMED;
+ }
+ else {
+ /* Fallback */
+ tok->stemmed.len = tok->normalized.len;
+ tok->stemmed.begin = tok->normalized.begin;
+ }
+ }
+ else {
+ tok->stemmed.len = tok->normalized.len;
+ tok->stemmed.begin = tok->normalized.begin;
+ }
+
+ if (tok->stemmed.len > 0 && lang_detector != NULL &&
+ rspamd_language_detector_is_stop_word(lang_detector, tok->stemmed.begin, tok->stemmed.len)) {
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_STOP_WORD;
+ }
+ }
+ else {
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT) {
+ /* Raw text, lowercase */
+ tok->stemmed.len = tok->normalized.len;
+ tok->stemmed.begin = tok->normalized.begin;
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/src/libstat/tokenizers/tokenizers.h b/src/libstat/tokenizers/tokenizers.h
new file mode 100644
index 0000000..d696364
--- /dev/null
+++ b/src/libstat/tokenizers/tokenizers.h
@@ -0,0 +1,100 @@
+/*
+ * Copyright 2023 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TOKENIZERS_H
+#define TOKENIZERS_H
+
+#include "config.h"
+#include "mem_pool.h"
+#include "fstring.h"
+#include "rspamd.h"
+#include "stat_api.h"
+
+#include <unicode/utext.h>
+
+#define RSPAMD_DEFAULT_TOKENIZER "osb"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct rspamd_tokenizer_runtime;
+struct rspamd_stat_ctx;
+
+/* Common tokenizer structure */
+struct rspamd_stat_tokenizer {
+ gchar *name;
+
+ gpointer (*get_config)(rspamd_mempool_t *pool,
+ struct rspamd_tokenizer_config *cf, gsize *len);
+
+ gint (*tokenize_func)(struct rspamd_stat_ctx *ctx,
+ struct rspamd_task *task,
+ GArray *words,
+ gboolean is_utf,
+ const gchar *prefix,
+ GPtrArray *result);
+};
+
+enum rspamd_tokenize_type {
+ RSPAMD_TOKENIZE_UTF = 0,
+ RSPAMD_TOKENIZE_RAW,
+ RSPAMD_TOKENIZE_UNICODE
+};
+
+/* Compare two token nodes */
+gint token_node_compare_func(gconstpointer a, gconstpointer b);
+
+
+/* Tokenize text into array of words (rspamd_stat_token_t type) */
+GArray *rspamd_tokenize_text(const gchar *text, gsize len,
+ const UText *utxt,
+ enum rspamd_tokenize_type how,
+ struct rspamd_config *cfg,
+ GList *exceptions,
+ guint64 *hash,
+ GArray *cur_words,
+ rspamd_mempool_t *pool);
+
+/* OSB tokenize function */
+gint rspamd_tokenizer_osb(struct rspamd_stat_ctx *ctx,
+ struct rspamd_task *task,
+ GArray *words,
+ gboolean is_utf,
+ const gchar *prefix,
+ GPtrArray *result);
+
+gpointer rspamd_tokenizer_osb_get_config(rspamd_mempool_t *pool,
+ struct rspamd_tokenizer_config *cf,
+ gsize *len);
+
+struct rspamd_lang_detector;
+
+void rspamd_normalize_single_word(rspamd_stat_token_t *tok, rspamd_mempool_t *pool);
+
+void rspamd_normalize_words(GArray *words, rspamd_mempool_t *pool);
+
+void rspamd_stem_words(GArray *words, rspamd_mempool_t *pool,
+ const gchar *language,
+ struct rspamd_lang_detector *lang_detector);
+
+void rspamd_tokenize_meta_words(struct rspamd_task *task);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif