summaryrefslogtreecommitdiffstats
path: root/src/libstat/backends
diff options
context:
space:
mode:
Diffstat (limited to 'src/libstat/backends')
-rw-r--r--src/libstat/backends/backends.h127
-rw-r--r--src/libstat/backends/cdb_backend.cxx491
-rw-r--r--src/libstat/backends/http_backend.cxx440
-rw-r--r--src/libstat/backends/mmaped_file.c1113
-rw-r--r--src/libstat/backends/redis_backend.cxx1132
-rw-r--r--src/libstat/backends/sqlite3_backend.c907
6 files changed, 4210 insertions, 0 deletions
diff --git a/src/libstat/backends/backends.h b/src/libstat/backends/backends.h
new file mode 100644
index 0000000..4b16950
--- /dev/null
+++ b/src/libstat/backends/backends.h
@@ -0,0 +1,127 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef BACKENDS_H_
+#define BACKENDS_H_
+
+#include "config.h"
+#include "ucl.h"
+
+#define RSPAMD_DEFAULT_BACKEND "mmap"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* Forward declarations */
+struct rspamd_classifier_config;
+struct rspamd_statfile_config;
+struct rspamd_config;
+struct rspamd_stat_ctx;
+struct rspamd_token_result;
+struct rspamd_statfile;
+struct rspamd_task;
+
+/*
+ * Table of operations exposed by a statistics backend.
+ *
+ * The function pointer slots mirror the rspamd_<name>_* prototypes generated
+ * by RSPAMD_STAT_BACKEND_DEF in this header. Note that the last argument of
+ * process_tokens/learn_tokens is called `ctx` here, whereas the generated
+ * prototypes call it `runtime`.
+ */
+struct rspamd_stat_backend {
+ /* Backend name; presumably the string matched against the configured
+  * backend (cf. RSPAMD_DEFAULT_BACKEND) - TODO confirm */
+ const char *name;
+ /* NOTE(review): presumably set for backends that cannot learn - confirm
+  * against the code that reads this flag */
+ bool read_only;
+
+ /* Per-statfile initialisation; returns an opaque pointer */
+ gpointer (*init)(struct rspamd_stat_ctx *ctx, struct rspamd_config *cfg,
+ struct rspamd_statfile *st);
+
+ /* Obtain the per-task runtime pointer for statfile `id` */
+ gpointer (*runtime)(struct rspamd_task *task,
+ struct rspamd_statfile_config *stcf,
+ gboolean learn, gpointer ctx,
+ gint id);
+
+ /* Handle `tokens` for statfile `id` when classifying */
+ gboolean (*process_tokens)(struct rspamd_task *task, GPtrArray *tokens,
+ gint id,
+ gpointer ctx);
+
+ gboolean (*finalize_process)(struct rspamd_task *task,
+ gpointer runtime, gpointer ctx);
+
+ /* Handle `tokens` for statfile `id` when learning */
+ gboolean (*learn_tokens)(struct rspamd_task *task, GPtrArray *tokens,
+ gint id,
+ gpointer ctx);
+
+ gulong (*total_learns)(struct rspamd_task *task,
+ gpointer runtime, gpointer ctx);
+
+ /* Unlike finalize_process, this slot can report an error through `err` */
+ gboolean (*finalize_learn)(struct rspamd_task *task,
+ gpointer runtime, gpointer ctx, GError **err);
+
+ gulong (*inc_learns)(struct rspamd_task *task,
+ gpointer runtime, gpointer ctx);
+
+ gulong (*dec_learns)(struct rspamd_task *task,
+ gpointer runtime, gpointer ctx);
+
+ ucl_object_t *(*get_stat)(gpointer runtime, gpointer ctx);
+
+ void (*close)(gpointer ctx);
+
+ /* Returns an opaque pointer and writes a size through `sz` */
+ gpointer (*load_tokenizer_config)(gpointer runtime, gsize *sz);
+
+ /* Opaque backend-specific pointer */
+ gpointer ctx;
+};
+
+/*
+ * Declares the full set of rspamd_<name>_* entry points for one backend.
+ *
+ * The expansion intentionally ends without a semicolon so that the macro is
+ * used as a statement: RSPAMD_STAT_BACKEND_DEF(foo);
+ *
+ * NOTE(review): rspamd_<name>_learns is declared here although
+ * struct rspamd_stat_backend has no matching `learns` slot - confirm where
+ * it is referenced.
+ */
+#define RSPAMD_STAT_BACKEND_DEF(name) \
+ gpointer rspamd_##name##_init(struct rspamd_stat_ctx *ctx, \
+ struct rspamd_config *cfg, struct rspamd_statfile *st); \
+ gpointer rspamd_##name##_runtime(struct rspamd_task *task, \
+ struct rspamd_statfile_config *stcf, \
+ gboolean learn, gpointer ctx, gint id); \
+ gboolean rspamd_##name##_process_tokens(struct rspamd_task *task, \
+ GPtrArray *tokens, gint id, \
+ gpointer runtime); \
+ gboolean rspamd_##name##_finalize_process(struct rspamd_task *task, \
+ gpointer runtime, \
+ gpointer ctx); \
+ gboolean rspamd_##name##_learn_tokens(struct rspamd_task *task, \
+ GPtrArray *tokens, gint id, \
+ gpointer runtime); \
+ gboolean rspamd_##name##_finalize_learn(struct rspamd_task *task, \
+ gpointer runtime, \
+ gpointer ctx, GError **err); \
+ gulong rspamd_##name##_total_learns(struct rspamd_task *task, \
+ gpointer runtime, \
+ gpointer ctx); \
+ gulong rspamd_##name##_inc_learns(struct rspamd_task *task, \
+ gpointer runtime, \
+ gpointer ctx); \
+ gulong rspamd_##name##_dec_learns(struct rspamd_task *task, \
+ gpointer runtime, \
+ gpointer ctx); \
+ gulong rspamd_##name##_learns(struct rspamd_task *task, \
+ gpointer runtime, \
+ gpointer ctx); \
+ ucl_object_t *rspamd_##name##_get_stat(gpointer runtime, \
+ gpointer ctx); \
+ gpointer rspamd_##name##_load_tokenizer_config(gpointer runtime, \
+ gsize *len); \
+ void rspamd_##name##_close(gpointer ctx)
+
+RSPAMD_STAT_BACKEND_DEF(mmaped_file);
+RSPAMD_STAT_BACKEND_DEF(sqlite3);
+RSPAMD_STAT_BACKEND_DEF(cdb);
+RSPAMD_STAT_BACKEND_DEF(redis);
+RSPAMD_STAT_BACKEND_DEF(http);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* BACKENDS_H_ */
diff --git a/src/libstat/backends/cdb_backend.cxx b/src/libstat/backends/cdb_backend.cxx
new file mode 100644
index 0000000..81d87f3
--- /dev/null
+++ b/src/libstat/backends/cdb_backend.cxx
@@ -0,0 +1,491 @@
+/*-
+ * Copyright 2021 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * CDB read only statistics backend
+ */
+
+#include "config.h"
+#include "stat_internal.h"
+#include "contrib/cdb/cdb.h"
+
+#include <utility>
+#include <memory>
+#include <string>
+#include <optional>
+#include "contrib/expected/expected.hpp"
+#include "contrib/ankerl/unordered_dense.h"
+#include "fmt/core.h"
+
+namespace rspamd::stat::cdb {
+
+/*
+ * Utility class to share cdb instances over statfiles instances, as each
+ * cdb has tokens for both ham and spam classes
+ */
+class cdb_shared_storage {
+public:
+	using cdb_element_t = std::shared_ptr<struct cdb>;
+	cdb_shared_storage() = default;
+
+	/*
+	 * Look up a live cdb by path; yields std::nullopt when the path is
+	 * unknown or when the remembered instance has already been released
+	 */
+	auto get_cdb(const char *path) const -> std::optional<cdb_element_t>
+	{
+		const auto it = elts.find(path);
+
+		if (it == elts.end() || it->second.expired()) {
+			return std::nullopt;
+		}
+
+		return it->second.lock();
+	}
+
+	/* Allocate a zero-filled POD cdb structure owned by a shared pointer */
+	static auto new_cdb() -> cdb_element_t
+	{
+		cdb_element_t fresh{new struct cdb, cdb_deleter()};
+		memset(fresh.get(), 0, sizeof(struct cdb));
+
+		return fresh;
+	}
+
+	/*
+	 * Register cdbp under path and return the instance callers must use:
+	 * a still-alive instance registered earlier wins over cdbp
+	 */
+	auto push_cdb(const char *path, cdb_element_t cdbp) -> cdb_element_t
+	{
+		auto it = elts.find(path);
+
+		if (it == elts.end()) {
+			/* Unknown path: remember a weak reference, hand back the original */
+			elts.emplace(path, std::weak_ptr<struct cdb>(cdbp));
+			return cdbp;
+		}
+
+		if (!it->second.expired()) {
+			/* Known and alive: reuse what is already stored */
+			return it->second.lock();
+		}
+
+		/* Known but expired: replace the stale weak reference */
+		it->second = cdbp;
+		return cdbp;
+	}
+
+private:
+	/*
+	 * Only weak pointers are kept, so the statfiles owning a cdb can release
+	 * it before this cache goes away (e.g. on dynamic cdb reload)
+	 */
+	ankerl::unordered_dense::map<std::string, std::weak_ptr<struct cdb>> elts;
+
+	/* Releases the cdb resources, then the structure itself */
+	struct cdb_deleter {
+		void operator()(struct cdb *c) const
+		{
+			cdb_free(c);
+			delete c;
+		}
+	};
+};
+
+static cdb_shared_storage cdb_shared_storage;
+
+/*
+ * Read-only view of one statfile stored in a (possibly shared) cdb.
+ * Move-only: copying is disabled, moving exchanges all state with the source.
+ */
+class ro_backend final {
+public:
+	explicit ro_backend(struct rspamd_statfile *_st, cdb_shared_storage::cdb_element_t _db)
+		: st(_st), db(std::move(_db))
+	{
+	}
+	ro_backend() = delete;
+	ro_backend(const ro_backend &) = delete;
+	/*
+	 * Implemented through move assignment (a member-wise swap). This is well
+	 * defined only because every member, including `st`, has a default
+	 * member initializer.
+	 */
+	ro_backend(ro_backend &&other) noexcept
+	{
+		*this = std::move(other);
+	}
+	ro_backend &operator=(ro_backend &&other) noexcept
+	{
+		std::swap(st, other.st);
+		std::swap(db, other.db);
+		std::swap(loaded, other.loaded);
+		std::swap(learns_spam, other.learns_spam);
+		std::swap(learns_ham, other.learns_ham);
+
+		return *this;
+	}
+	~ro_backend() = default;
+
+	/* Read the learn counters from the cdb; on success marks the backend as loaded */
+	auto load_cdb() -> tl::expected<bool, std::string>;
+	/* Look up one token; std::nullopt when not loaded or when the token is absent */
+	auto process_token(const rspamd_token_t *tok) const -> std::optional<float>;
+	constexpr auto is_spam() const -> bool
+	{
+		return st->stcf->is_spam;
+	}
+	/* Number of learns for this statfile's own class (spam or ham) */
+	auto get_learns() const -> std::uint64_t
+	{
+		if (is_spam()) {
+			return learns_spam;
+		}
+		else {
+			return learns_ham;
+		}
+	}
+	/* Sum of the spam and ham learn counters */
+	auto get_total_learns() const -> std::uint64_t
+	{
+		return learns_spam + learns_ham;
+	}
+
+private:
+	/* Initialised to nullptr so the swap-based move constructor never reads an indeterminate pointer */
+	struct rspamd_statfile *st = nullptr;
+	cdb_shared_storage::cdb_element_t db;
+	bool loaded = false;
+	std::uint64_t learns_spam = 0;
+	std::uint64_t learns_ham = 0;
+};
+
+/*
+ * Look up `key` (its raw object bytes are used as the cdb key) and return the
+ * stored value as a 64 bit integer.
+ * Returns std::nullopt when the key is not found, when the stored value is
+ * not exactly 8 bytes long, or when reading the value fails.
+ */
+template<typename T>
+static inline auto
+cdb_get_key_as_int64(struct cdb *cdb, T key) -> std::optional<std::int64_t>
+{
+	auto pos = cdb_find(cdb, (void *) &key, sizeof(key));
+
+	if (pos > 0) {
+		auto vpos = cdb_datapos(cdb);
+		auto vlen = cdb_datalen(cdb);
+
+		if (vlen == sizeof(std::int64_t)) {
+			std::int64_t ret = 0;
+
+			/* cdb_read reports failure with a non-zero result; never return garbage */
+			if (cdb_read(cdb, (void *) &ret, vlen, vpos) == 0) {
+				return ret;
+			}
+		}
+	}
+
+	return std::nullopt;
+}
+
+/*
+ * Look up `key` (its raw object bytes are used as the cdb key) and return the
+ * stored value as two consecutive floats.
+ * Returns std::nullopt when the key is not found, when the stored value is
+ * not exactly two floats long, or when reading the value fails.
+ */
+template<typename T>
+static inline auto
+cdb_get_key_as_float_pair(struct cdb *cdb, T key) -> std::optional<std::pair<float, float>>
+{
+	auto pos = cdb_find(cdb, (void *) &key, sizeof(key));
+
+	if (pos > 0) {
+		auto vpos = cdb_datapos(cdb);
+		auto vlen = cdb_datalen(cdb);
+
+		if (vlen == sizeof(float) * 2) {
+			/* Read straight into a float array instead of punning through a union */
+			float values[2] = {0.0f, 0.0f};
+
+			/* cdb_read reports failure with a non-zero result; never return garbage */
+			if (cdb_read(cdb, (void *) values, vlen, vpos) == 0) {
+				return std::make_pair(values[0], values[1]);
+			}
+		}
+	}
+
+	return std::nullopt;
+}
+
+
+/*
+ * Fetch the spam and ham learn counters from the cdb.
+ * Yields an error string when there is no database or when either counter
+ * key is missing; otherwise marks the backend as loaded and yields true.
+ */
+auto ro_backend::load_cdb() -> tl::expected<bool, std::string>
+{
+	if (!db) {
+		return tl::make_unexpected("no database loaded");
+	}
+
+	/* Both counter keys are exactly 8 characters, i.e. one 64 bit cdb key */
+	static const char learn_spam_key[9] = "_lrnspam", learn_ham_key[9] = "_lrnham_";
+
+	auto fetch_learns = [this](const char *key, std::uint64_t &target) -> tl::expected<bool, std::string> {
+		std::int64_t packed_key;
+		memcpy((void *) &packed_key, key, sizeof(packed_key));
+
+		const auto stored = cdb_get_key_as_int64(db.get(), packed_key);
+
+		if (!stored) {
+			return tl::make_unexpected(fmt::format("missing {} key", key));
+		}
+
+		target = (std::uint64_t) stored.value();
+
+		return true;
+	};
+
+	if (auto spam_res = fetch_learns(learn_spam_key, learns_spam); !spam_res) {
+		return spam_res;
+	}
+
+	if (auto ham_res = fetch_learns(learn_ham_key, learns_ham); !ham_res) {
+		return ham_res;
+	}
+
+	loaded = true;
+
+	return true;
+}
+
+/*
+ * Return the stored count for `tok` in this statfile's class: the first
+ * float of the stored pair for spam statfiles, the second one for ham.
+ * Yields std::nullopt when the backend is not loaded or the token is absent.
+ */
+auto ro_backend::process_token(const rspamd_token_t *tok) const -> std::optional<float>
+{
+	if (loaded) {
+		const auto counts = cdb_get_key_as_float_pair(db.get(), tok->data);
+
+		if (counts) {
+			const auto [spam_count, ham_count] = counts.value();
+
+			return is_spam() ? spam_count : ham_count;
+		}
+	}
+
+	return std::nullopt;
+}
+
+/*
+ * Build a read-only backend for statfile `st`.
+ *
+ * The cdb path is taken from the first source that provides a string under
+ * "filename", "path" or "cdb": the classifier's "backend" object, then the
+ * statfile options, then the classifier options. An already opened cdb for
+ * the same path is reused through cdb_shared_storage.
+ *
+ * Returns the loaded backend, or an error string describing the failure.
+ */
+auto open_cdb(struct rspamd_statfile *st) -> tl::expected<ro_backend, std::string>
+{
+	const char *path = nullptr;
+	const auto *stf = st->stcf;
+
+	auto get_filename = [](const ucl_object_t *obj) -> const char * {
+		const auto *filename = ucl_object_lookup_any(obj,
+													 "filename", "path", "cdb", nullptr);
+
+		if (filename && ucl_object_type(filename) == UCL_STRING) {
+			return ucl_object_tostring(filename);
+		}
+
+		return nullptr;
+	};
+
+	/* First search in backend configuration */
+	const auto *obj = ucl_object_lookup(st->classifier->cfg->opts, "backend");
+	if (obj != NULL && ucl_object_type(obj) == UCL_OBJECT) {
+		path = get_filename(obj);
+	}
+
+	/* Now try statfiles config */
+	if (!path && stf->opts) {
+		path = get_filename(stf->opts);
+	}
+
+	/* Now try classifier config */
+	if (!path && st->classifier->cfg->opts) {
+		path = get_filename(st->classifier->cfg->opts);
+	}
+
+	if (!path) {
+		return tl::make_unexpected("missing/malformed filename attribute");
+	}
+
+	auto cached_cdb_maybe = cdb_shared_storage.get_cdb(path);
+	cdb_shared_storage::cdb_element_t cdbp;
+
+	if (!cached_cdb_maybe) {
+
+		auto fd = rspamd_file_xopen(path, O_RDONLY, 0, true);
+
+		if (fd == -1) {
+			return tl::make_unexpected(fmt::format("cannot open {}: {}",
+												   path, strerror(errno)));
+		}
+
+		cdbp = cdb_shared_storage::new_cdb();
+
+		if (cdb_init(cdbp.get(), fd) == -1) {
+			/* close() may overwrite errno, so capture the cdb_init error first */
+			const auto saved_errno = errno;
+			close(fd);
+
+			return tl::make_unexpected(fmt::format("cannot init cdb in {}: {}",
+												   path, strerror(saved_errno)));
+		}
+
+		/* The storage may hand back another live instance for the same path */
+		cdbp = cdb_shared_storage.push_cdb(path, cdbp);
+
+		close(fd);
+	}
+	else {
+		cdbp = cached_cdb_maybe.value();
+	}
+
+	if (!cdbp) {
+		return tl::make_unexpected(fmt::format("cannot init cdb in {}: internal error",
+											   path));
+	}
+
+	ro_backend bk{st, std::move(cdbp)};
+
+	auto res = bk.load_cdb();
+
+	if (!res) {
+		return tl::make_unexpected(res.error());
+	}
+
+	return bk;
+}
+
+}// namespace rspamd::stat::cdb
+
+#define CDB_FROM_RAW(p) (reinterpret_cast<rspamd::stat::cdb::ro_backend *>(p))
+
+/* C exports */
+/*
+ * C entry point: open the cdb backend for `st`.
+ * Returns a heap-allocated ro_backend (released by rspamd_cdb_close), or
+ * nullptr after logging the reason when the backend cannot be opened.
+ */
+gpointer
+rspamd_cdb_init(struct rspamd_stat_ctx *ctx,
+				struct rspamd_config *cfg,
+				struct rspamd_statfile *st)
+{
+	auto opened = rspamd::stat::cdb::open_cdb(st);
+
+	if (!opened) {
+		msg_err_config("cannot load cdb backend: %s", opened.error().c_str());
+
+		return nullptr;
+	}
+
+	/* Transfer the backend from the expected<> wrapper to the heap */
+	return new rspamd::stat::cdb::ro_backend(std::move(opened.value()));
+}
+/*
+ * C entry point: the runtime pointer for cdb is the backend context itself,
+ * so `task`, `stcf`, `learn` and `_id` are ignored.
+ */
+gpointer
+rspamd_cdb_runtime(struct rspamd_task *task,
+ struct rspamd_statfile_config *stcf,
+ gboolean learn,
+ gpointer ctx,
+ gint _id)
+{
+ /* In CDB we don't have any dynamic stuff */
+ return ctx;
+}
+
+/*
+ * C entry point: fill tok->values[id] for every token in `tokens`.
+ * Tokens missing from the cdb get 0. When at least one token was found, the
+ * task is flagged as having spam or ham tokens depending on the statfile
+ * class. Always returns true.
+ */
+gboolean
+rspamd_cdb_process_tokens(struct rspamd_task *task,
+						  GPtrArray *tokens,
+						  gint id,
+						  gpointer runtime)
+{
+	auto *backend = CDB_FROM_RAW(runtime);
+	auto any_hit = false;
+
+	for (auto idx = 0u; idx < tokens->len; idx++) {
+		auto *tok = reinterpret_cast<rspamd_token_t *>(g_ptr_array_index(tokens, idx));
+		const auto found = backend->process_token(tok);
+
+		if (!found) {
+			tok->values[id] = 0;
+			continue;
+		}
+
+		tok->values[id] = found.value();
+		any_hit = true;
+	}
+
+	if (!any_hit) {
+		return true;
+	}
+
+	if (backend->is_spam()) {
+		task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS;
+	}
+	else {
+		task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
+	}
+
+	return true;
+}
+/* C entry point: nothing to finalize after classification; always succeeds */
+gboolean
+rspamd_cdb_finalize_process(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ return true;
+}
+/* C entry point: learning is refused (this file implements a read only backend) */
+gboolean
+rspamd_cdb_learn_tokens(struct rspamd_task *task,
+ GPtrArray *tokens,
+ gint id,
+ gpointer ctx)
+{
+ return false;
+}
+/* C entry point: always fails; `err` is left untouched */
+gboolean
+rspamd_cdb_finalize_learn(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx,
+ GError **err)
+{
+ return false;
+}
+
+
+/* C entry point: spam + ham learn counters of the backend passed as `ctx` */
+gulong rspamd_cdb_total_learns(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ auto *cdbp = CDB_FROM_RAW(ctx);
+ return cdbp->get_total_learns();
+}
+/* C entry point: counters cannot be modified; returns (gulong) -1 */
+gulong
+rspamd_cdb_inc_learns(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ return (gulong) -1;
+}
+/* C entry point: counters cannot be modified; returns (gulong) -1 */
+gulong
+rspamd_cdb_dec_learns(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ return (gulong) -1;
+}
+/* C entry point: learn counter for the class (spam or ham) of the backend passed as `ctx` */
+gulong
+rspamd_cdb_learns(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ auto *cdbp = CDB_FROM_RAW(ctx);
+ return cdbp->get_learns();
+}
+/* C entry point: no statistics object is produced for cdb; returns nullptr */
+ucl_object_t *
+rspamd_cdb_get_stat(gpointer runtime, gpointer ctx)
+{
+ return nullptr;
+}
+/* C entry point: no tokenizer configuration is provided; returns nullptr and leaves `len` untouched */
+gpointer
+rspamd_cdb_load_tokenizer_config(gpointer runtime, gsize *len)
+{
+ return nullptr;
+}
+/* C entry point: destroy the ro_backend object passed as `ctx` (as allocated by rspamd_cdb_init) */
+void rspamd_cdb_close(gpointer ctx)
+{
+ auto *cdbp = CDB_FROM_RAW(ctx);
+ delete cdbp;
+} \ No newline at end of file
diff --git a/src/libstat/backends/http_backend.cxx b/src/libstat/backends/http_backend.cxx
new file mode 100644
index 0000000..075e508
--- /dev/null
+++ b/src/libstat/backends/http_backend.cxx
@@ -0,0 +1,440 @@
+/*-
+ * Copyright 2022 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "config.h"
+#include "stat_internal.h"
+#include "libserver/http/http_connection.h"
+#include "libserver/mempool_vars_internal.h"
+#include "upstream.h"
+#include "contrib/ankerl/unordered_dense.h"
+#include <algorithm>
+#include <vector>
+
+namespace rspamd::stat::http {
+
+#define msg_debug_stat_http(...) rspamd_conditional_debug_fast(NULL, NULL, \
+ rspamd_stat_http_log_id, "stat_http", task->task_pool->tag.uid, \
+ RSPAMD_LOG_FUNC, \
+ __VA_ARGS__)
+
+INIT_LOG_MODULE(stat_http)
+
+/* Represents all http backends defined in some configuration */
+class http_backends_collection {
+public:
+	/**
+	 * Access the process-wide collection; it is heap-allocated on first use
+	 * and never deleted
+	 * @return reference to the single instance
+	 */
+	static auto get() -> http_backends_collection &
+	{
+		static http_backends_collection *instance = nullptr;
+
+		if (!instance) {
+			instance = new http_backends_collection;
+		}
+
+		return *instance;
+	}
+
+	/**
+	 * Add a new backend and (optionally) initialize the basic backend parameters
+	 * @param ctx
+	 * @param cfg
+	 * @param st
+	 * @return
+	 */
+	auto add_backend(struct rspamd_stat_ctx *ctx,
+					 struct rspamd_config *cfg,
+					 struct rspamd_statfile *st) -> bool;
+	/**
+	 * Remove a statfile, cleaning things up if the last statfile is removed
+	 * @param st
+	 * @return
+	 */
+	auto remove_backend(struct rspamd_statfile *st) -> bool;
+
+	/**
+	 * Pick an upstream from the write servers when is_learn is set,
+	 * otherwise from the read servers
+	 */
+	upstream *get_upstream(bool is_learn);
+
+private:
+	std::vector<struct rspamd_statfile *> backends;
+	double timeout = 1.0; /* Default timeout */
+	struct upstream_list *read_servers = nullptr;
+	struct upstream_list *write_servers = nullptr;
+
+	http_backends_collection() = default;
+	auto first_init(struct rspamd_stat_ctx *ctx,
+					struct rspamd_config *cfg,
+					struct rspamd_statfile *st) -> bool;
+};
+
+/*
+ * Created one per each task
+ */
+class http_backend_runtime final {
+public:
+	/* Build a runtime inside the task memory pool */
+	static auto create(struct rspamd_task *task, bool is_learn) -> http_backend_runtime *;
+
+	/* Remember (or overwrite) the statfile configuration registered under id */
+	auto notice_statfile(int id, const struct rspamd_statfile_config *st) -> void
+	{
+		seen_statfiles[id] = st;
+	}
+
+	auto process_tokens(struct rspamd_task *task,
+						GPtrArray *tokens,
+						gint id,
+						bool learn) -> bool;
+
+private:
+	http_backends_collection *all_backends;
+	ankerl::unordered_dense::map<int, const struct rspamd_statfile_config *> seen_statfiles;
+	struct upstream *selected;
+
+	/* Construction and destruction are private: instances come from create() */
+	http_backend_runtime(struct rspamd_task *task, bool is_learn)
+		: all_backends(&http_backends_collection::get()),
+		  selected(all_backends->get_upstream(is_learn))
+	{
+	}
+	~http_backend_runtime() = default;
+
+	/* Destructor trampoline with the plain void(void *) shape */
+	static auto dtor(void *p) -> void
+	{
+		static_cast<http_backend_runtime *>(p)->~http_backend_runtime();
+	}
+};
+
+/*
+ * Efficient way to make a messagepack payload from stat tokens,
+ * avoiding any intermediate libraries, as we would send many tokens
+ * all together
+ */
+/*
+ * Serialise the `data` field of every token in `tokens` as a messagepack
+ * array of unsigned 64 bit integers.
+ * Returns the encoded bytes: a 5 byte array header followed by 9 bytes per
+ * token.
+ */
+static auto
+stat_tokens_to_msgpack(GPtrArray *tokens) -> std::vector<std::uint8_t>
+{
+	std::vector<std::uint8_t> ret;
+	rspamd_token_t *cur;
+	int i;
+
+	/*
+	 * We define an array, its size and N elements, each being uint64_t
+	 * Layout:
+	 * 0xdd - array marker
+	 * [4 bytes be] - size of the array
+	 * [ 0xcf + <8 bytes BE integer>] * N - array elements
+	 */
+	/* reserve(), not resize(): resize() would put zero bytes in front of the payload */
+	ret.reserve(tokens->len * (sizeof(std::uint64_t) + 1) + 5);
+	ret.push_back(static_cast<std::uint8_t>(0xdd));
+
+	std::uint32_t ulen = GUINT32_TO_BE(tokens->len);
+	const auto *len_bytes = reinterpret_cast<const std::uint8_t *>(&ulen);
+	ret.insert(ret.end(), len_bytes, len_bytes + sizeof(ulen));
+
+	PTR_ARRAY_FOREACH(tokens, i, cur)
+	{
+		ret.push_back(static_cast<std::uint8_t>(0xcf));
+
+		std::uint64_t val = GUINT64_TO_BE(cur->data);
+		const auto *val_bytes = reinterpret_cast<const std::uint8_t *>(&val);
+		ret.insert(ret.end(), val_bytes, val_bytes + sizeof(val));
+	}
+
+	return ret;
+}
+
+/*
+ * Allocate storage for a runtime from the task pool, register its destructor
+ * trampoline with that pool and construct the object in place.
+ * NOTE(review): the destructor is registered before the object is
+ * constructed - confirm the constructor cannot fail in between.
+ */
+auto http_backend_runtime::create(struct rspamd_task *task, bool is_learn) -> http_backend_runtime *
+{
+ /* Alloc type provides proper size and alignment */
+ auto *allocated_runtime = rspamd_mempool_alloc_type(task->task_pool, http_backend_runtime);
+
+ rspamd_mempool_add_destructor(task->task_pool, http_backend_runtime::dtor, allocated_runtime);
+
+ return new (allocated_runtime) http_backend_runtime{task, is_learn};
+}
+
+/*
+ * Placeholder for the classify/learn request logic: no request is emitted
+ * yet and the function always returns true. On learn, the set of noticed
+ * statfiles is cleared on the first call.
+ */
+auto http_backend_runtime::process_tokens(struct rspamd_task *task, GPtrArray *tokens, gint id, bool learn) -> bool
+{
+ if (!learn) {
+ /* NOTE(review): compares a signed id with an unsigned size() - 1, which wraps when no statfile is noticed */
+ if (id == seen_statfiles.size() - 1) {
+ /* Emit http request on the last statfile */
+ }
+ }
+ else {
+ /* On learn we need to learn all statfiles that we were requested to learn */
+ if (seen_statfiles.empty()) {
+ /* Request has been already set, or nothing to learn */
+ return true;
+ }
+ else {
+ seen_statfiles.clear();
+ }
+ }
+
+ return true;
+}
+
+/*
+ * Register statfile `st`. The shared backend settings are loaded when the
+ * collection is empty; if that fails, nothing is registered and false is
+ * returned.
+ */
+auto http_backends_collection::add_backend(struct rspamd_stat_ctx *ctx,
+										   struct rspamd_config *cfg,
+										   struct rspamd_statfile *st) -> bool
+{
+	/* An empty list means the backend data has not been loaded yet */
+	const auto needs_init = backends.empty();
+
+	if (needs_init && !first_init(ctx, cfg, st)) {
+		return false;
+	}
+
+	backends.push_back(st);
+
+	return true;
+}
+
+/*
+ * Load the shared settings (read servers, write servers, timeout) from the
+ * first usable source: the classifier's "backend" object, then the statfile
+ * options, then the classifier options.
+ * Returns true when one of the sources was accepted. After a failed attempt
+ * both server lists are left as nullptr.
+ */
+auto http_backends_collection::first_init(struct rspamd_stat_ctx *ctx,
+										  struct rspamd_config *cfg,
+										  struct rspamd_statfile *st) -> bool
+{
+	/* Release whatever was built so far, leaving no dangling or leaked lists */
+	auto reset_servers = [this]() {
+		if (read_servers) {
+			rspamd_upstreams_destroy(read_servers);
+			read_servers = nullptr;
+		}
+
+		if (write_servers) {
+			rspamd_upstreams_destroy(write_servers);
+			write_servers = nullptr;
+		}
+	};
+
+	/* Build one upstream list from `def`; nullptr on any failure */
+	auto make_list = [&](const ucl_object_t *def) -> struct upstream_list * {
+		auto *ups = rspamd_upstreams_create(cfg->ups_ctx);
+
+		if (ups == nullptr) {
+			return nullptr;
+		}
+
+		if (!rspamd_upstreams_from_ucl(ups, def, 80, this)) {
+			rspamd_upstreams_destroy(ups);
+			return nullptr;
+		}
+
+		return ups;
+	};
+
+	auto try_load_backend_config = [&](const ucl_object_t *obj) -> bool {
+		if (!obj || ucl_object_type(obj) != UCL_OBJECT) {
+			return false;
+		}
+
+		/* First try to load read servers */
+		auto *rs = ucl_object_lookup_any(obj, "read_servers", "servers", nullptr);
+		if (rs) {
+			read_servers = make_list(rs);
+
+			if (read_servers == nullptr) {
+				reset_servers();
+				return false;
+			}
+		}
+
+		auto *ws = ucl_object_lookup_any(obj, "write_servers", "servers", nullptr);
+		if (ws) {
+			/* Parse the write definition (`ws`), not the read one */
+			write_servers = make_list(ws);
+
+			if (write_servers == nullptr) {
+				reset_servers();
+				return false;
+			}
+		}
+
+		auto *tim = ucl_object_lookup(obj, "timeout");
+
+		if (tim) {
+			timeout = ucl_object_todouble(tim);
+		}
+
+		return true;
+	};
+
+	auto ret = false;
+	auto obj = ucl_object_lookup(st->classifier->cfg->opts, "backend");
+	if (obj != nullptr) {
+		ret = try_load_backend_config(obj);
+	}
+
+	/* Now try statfiles config */
+	if (!ret && st->stcf->opts) {
+		ret = try_load_backend_config(st->stcf->opts);
+	}
+
+	/* Now try classifier config */
+	if (!ret && st->classifier->cfg->opts) {
+		ret = try_load_backend_config(st->classifier->cfg->opts);
+	}
+
+	return ret;
+}
+
+/*
+ * Unregister statfile `st`. When the last statfile goes away, both upstream
+ * lists are destroyed and reset.
+ * Returns true when `st` was registered, false otherwise.
+ */
+auto http_backends_collection::remove_backend(struct rspamd_statfile *st) -> bool
+{
+	/* std::find locates the element; std::remove would shift the tail instead */
+	auto backend_it = std::find(std::begin(backends), std::end(backends), st);
+
+	if (backend_it == std::end(backends)) {
+		return false;
+	}
+
+	/* Fast erasure with no order preservation */
+	std::swap(*backend_it, backends.back());
+	backends.pop_back();
+
+	if (backends.empty()) {
+		/* De-init collection - likely config reload */
+		if (read_servers) {
+			rspamd_upstreams_destroy(read_servers);
+			read_servers = nullptr;
+		}
+
+		if (write_servers) {
+			rspamd_upstreams_destroy(write_servers);
+			write_servers = nullptr;
+		}
+	}
+
+	return true;
+}
+
+/*
+ * Select an upstream in round-robin mode: from the write servers when
+ * learning, from the read servers otherwise.
+ */
+upstream *http_backends_collection::get_upstream(bool is_learn)
+{
+	auto *ups_list = is_learn ? write_servers : read_servers;
+
+	return rspamd_upstream_get(ups_list, RSPAMD_UPSTREAM_ROUND_ROBIN, nullptr, 0);
+}
+
+}// namespace rspamd::stat::http
+
+/* C API */
+
+/*
+ * C entry point: register `st` in the shared http backends collection.
+ * Returns a pointer to that collection, or nullptr (after logging) when the
+ * statfile cannot be registered.
+ */
+gpointer
+rspamd_http_init(struct rspamd_stat_ctx *ctx,
+				 struct rspamd_config *cfg,
+				 struct rspamd_statfile *st)
+{
+	auto &collection = rspamd::stat::http::http_backends_collection::get();
+
+	if (collection.add_backend(ctx, cfg, st)) {
+		return (void *) &collection;
+	}
+
+	msg_err_config("cannot load http backend");
+
+	return nullptr;
+}
+/*
+ * C entry point: return the per-task http runtime, creating it and storing
+ * it as a task pool variable on first use. Statfile `id` is recorded in the
+ * runtime on every call.
+ */
+gpointer
+rspamd_http_runtime(struct rspamd_task *task,
+					struct rspamd_statfile_config *stcf,
+					gboolean learn,
+					gpointer ctx,
+					gint id)
+{
+	using runtime_t = rspamd::stat::http::http_backend_runtime;
+
+	auto *rt = static_cast<runtime_t *>(
+		rspamd_mempool_get_variable(task->task_pool, RSPAMD_MEMPOOL_HTTP_STAT_BACKEND_RUNTIME));
+
+	if (rt == nullptr) {
+		/* First statfile of this task: build the runtime and publish it */
+		rt = runtime_t::create(task, learn);
+
+		if (rt != nullptr) {
+			rspamd_mempool_set_variable(task->task_pool, RSPAMD_MEMPOOL_HTTP_STAT_BACKEND_RUNTIME,
+										(void *) rt, nullptr);
+		}
+	}
+
+	if (rt != nullptr) {
+		rt->notice_statfile(id, stcf);
+	}
+
+	return (void *) rt;
+}
+
+/*
+ * C entry point: forward classification of `tokens` to the task runtime.
+ * Returns false when no runtime is supplied.
+ */
+gboolean
+rspamd_http_process_tokens(struct rspamd_task *task,
+						   GPtrArray *tokens,
+						   gint id,
+						   gpointer runtime)
+{
+	auto *rt = static_cast<rspamd::stat::http::http_backend_runtime *>(runtime);
+
+	if (rt == nullptr) {
+		return false;
+	}
+
+	return rt->process_tokens(task, tokens, id, false);
+}
+/* C entry point: no finalization step for classification; always returns true */
+gboolean
+rspamd_http_finalize_process(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ /* Not needed */
+ return true;
+}
+
+/*
+ * C entry point: forward learning of `tokens` to the task runtime.
+ * Returns false when no runtime is supplied.
+ */
+gboolean
+rspamd_http_learn_tokens(struct rspamd_task *task,
+						 GPtrArray *tokens,
+						 gint id,
+						 gpointer runtime)
+{
+	auto *rt = static_cast<rspamd::stat::http::http_backend_runtime *>(runtime);
+
+	if (rt == nullptr) {
+		return false;
+	}
+
+	return rt->process_tokens(task, tokens, id, true);
+}
+/* C entry point: always returns false; `err` is left untouched */
+gboolean
+rspamd_http_finalize_learn(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx,
+ GError **err)
+{
+ return false;
+}
+
+/* C entry point: not implemented yet; always returns 0 */
+gulong rspamd_http_total_learns(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ /* TODO */
+ return 0;
+}
+/*
+ * C entry point: not implemented yet; always returns 0.
+ * NOTE(review): rspamd_http_dec_learns returns (gulong) -1 instead - confirm
+ * which placeholder value is intended.
+ */
+gulong
+rspamd_http_inc_learns(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ /* TODO */
+ return 0;
+}
+/* C entry point: not implemented yet; always returns (gulong) -1 */
+gulong
+rspamd_http_dec_learns(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ /* TODO */
+ return (gulong) -1;
+}
+/* C entry point: not implemented yet; always returns 0 */
+gulong
+rspamd_http_learns(struct rspamd_task *task,
+ gpointer runtime,
+ gpointer ctx)
+{
+ /* TODO */
+ return 0;
+}
+/* C entry point: not implemented yet; always returns nullptr */
+ucl_object_t *
+rspamd_http_get_stat(gpointer runtime, gpointer ctx)
+{
+ /* TODO */
+ return nullptr;
+}
+/* C entry point: no tokenizer configuration is provided; returns nullptr and leaves `len` untouched */
+gpointer
+rspamd_http_load_tokenizer_config(gpointer runtime, gsize *len)
+{
+ return nullptr;
+}
+/*
+ * C entry point: currently a no-op.
+ * NOTE(review): nothing here unregisters the statfile from
+ * http_backends_collection - confirm whether remove_backend should be called.
+ */
+void rspamd_http_close(gpointer ctx)
+{
+ /* TODO */
+} \ No newline at end of file
diff --git a/src/libstat/backends/mmaped_file.c b/src/libstat/backends/mmaped_file.c
new file mode 100644
index 0000000..5c20207
--- /dev/null
+++ b/src/libstat/backends/mmaped_file.c
@@ -0,0 +1,1113 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "config.h"
+#include "stat_internal.h"
+#include "unix-std.h"
+
+#define CHAIN_LENGTH 128
+
+/* Section types */
+#define STATFILE_SECTION_COMMON 1
+
+/**
+ * Common statfile header.
+ *
+ * rspamd_mmaped_file_create() writes this structure verbatim at the start of
+ * the file, and the rest of this file accesses it in place through the mapping,
+ * so the field order and sizes define the on-disk format.
+ */
+struct stat_file_header {
+ u_char magic[3]; /**< magic signature ('r' 's' 'd') */
+ u_char version[2]; /**< version of statfile */
+ u_char padding[3]; /**< padding */
+ guint64 create_time; /**< create time (time_t->guint64) */
+ guint64 revision; /**< revision number */
+ guint64 rev_time; /**< revision time */
+ guint64 used_blocks; /**< used blocks number */
+ guint64 total_blocks; /**< total number of blocks */
+ guint64 tokenizer_conf_len; /**< length of tokenizer configuration */
+ u_char unused[231]; /**< spare bytes; the first tokenizer_conf_len bytes hold the tokenizer configuration */
+};
+
+/**
+ * Section header.
+ *
+ * Stored right after the file header; `length` is the number of blocks in the
+ * section and is used as the modulus when a token hash is mapped to a block.
+ */
+struct stat_file_section {
+ guint64 code; /**< section's code */
+ guint64 length; /**< section's length in blocks */
+};
+
+/**
+ * Block of data in statfile.
+ *
+ * A block whose hash1 and hash2 are both zero is treated as free by
+ * rspamd_mmaped_file_set_block_common().
+ */
+struct stat_file_block {
+ guint32 hash1; /**< hash1 (also acts as index) */
+ guint32 hash2; /**< hash2 */
+ double value; /**< double value */
+};
+
+/**
+ * Statistic file.
+ *
+ * Overlay for the beginning of the mapped file. `blocks[1]` only marks where
+ * the block array starts: the offset of the first block is computed as
+ * sizeof(struct stat_file) - sizeof(struct stat_file_block), so the array
+ * bound must stay 1 to keep that arithmetic valid.
+ */
+struct stat_file {
+ struct stat_file_header header; /**< header */
+ struct stat_file_section section; /**< first section */
+ struct stat_file_block blocks[1]; /**< first block of data */
+};
+
+/**
+ * Common view of statfile object.
+ *
+ * In-memory handle allocated by rspamd_mmaped_file_open() and released by
+ * rspamd_mmaped_file_close_file().
+ */
+typedef struct {
+#ifdef HAVE_PATH_MAX
+ gchar filename[PATH_MAX]; /**< name of file */
+#else
+ gchar filename[MAXPATHLEN]; /**< name of file */
+#endif
+ rspamd_mempool_t *pool; /**< pool passed to open(), used for logging */
+ gint fd; /**< descriptor */
+ void *map; /**< mmaped area */
+ off_t seek_pos; /**< current seek position (offset of the first block, set by the file check) */
+ struct stat_file_section cur_section; /**< current section */
+ size_t len; /**< length of file(in bytes) */
+ struct rspamd_statfile_config *cf; /**< statfile configuration passed to open() */
+} rspamd_mmaped_file_t;
+
+
+#define RSPAMD_STATFILE_VERSION \
+ { \
+ '1', '2' \
+ }
+#define BACKUP_SUFFIX ".old"
+
+static void rspamd_mmaped_file_set_block_common(rspamd_mempool_t *pool,
+ rspamd_mmaped_file_t *file,
+ guint32 h1, guint32 h2, double value);
+
+rspamd_mmaped_file_t *rspamd_mmaped_file_open(rspamd_mempool_t *pool,
+ const gchar *filename, size_t size,
+ struct rspamd_statfile_config *stcf);
+gint rspamd_mmaped_file_create(const gchar *filename, size_t size,
+ struct rspamd_statfile_config *stcf,
+ rspamd_mempool_t *pool);
+gint rspamd_mmaped_file_close_file(rspamd_mempool_t *pool,
+ rspamd_mmaped_file_t *file);
+
+/**
+ * Look up the value stored for the token (h1, h2).
+ *
+ * The chain starts at block h1 % section length and spans at most
+ * CHAIN_LENGTH consecutive blocks; it never wraps past the end of the section.
+ *
+ * @param file opened statfile
+ * @param h1 first token hash (selects the chain)
+ * @param h2 second token hash
+ * @return stored value, or 0 when the token is absent, the file is not mapped
+ * or the section has no blocks
+ */
+double
+rspamd_mmaped_file_get_block(rspamd_mmaped_file_t *file,
+                             guint32 h1,
+                             guint32 h2)
+{
+    struct stat_file_block *chain;
+    guint64 nblocks;
+    guint i, blocknum;
+
+    if (!file->map) {
+        return 0;
+    }
+
+    nblocks = file->cur_section.length;
+
+    if (nblocks == 0) {
+        /* An empty section holds no tokens; this also avoids h1 % 0 */
+        return 0;
+    }
+
+    blocknum = h1 % nblocks;
+    chain = (struct stat_file_block *) ((u_char *) file->map + file->seek_pos);
+    chain += blocknum;
+
+    /* The bound is evaluated in 64 bits so that blocknum + i cannot wrap */
+    for (i = 0; i < CHAIN_LENGTH && (guint64) blocknum + i < nblocks; i++) {
+        if (chain[i].hash1 == h1 && chain[i].hash2 == h2) {
+            return chain[i].value;
+        }
+    }
+
+    return 0;
+}
+
+/**
+ * Store `value` for the token (h1, h2) in the chain selected by h1.
+ *
+ * At most CHAIN_LENGTH consecutive blocks starting at h1 % section length are
+ * inspected (the chain does not wrap past the end of the section):
+ *  - a block that already holds (h1, h2) is overwritten in place;
+ *  - otherwise the first free block (both hashes zero) is claimed and the
+ *    header's used_blocks counter is incremented;
+ *  - otherwise the block with the smallest value seen in the chain is evicted
+ *    and reused.
+ *
+ * Nothing is written when the file is not mapped or its section has no blocks.
+ *
+ * @param pool pool used by the logging macros
+ * @param file opened statfile
+ * @param h1 first token hash (selects the chain)
+ * @param h2 second token hash
+ * @param value value to store
+ */
+static void
+rspamd_mmaped_file_set_block_common(rspamd_mempool_t *pool,
+                                    rspamd_mmaped_file_t *file,
+                                    guint32 h1, guint32 h2, double value)
+{
+    struct stat_file_block *block, *to_expire = NULL;
+    struct stat_file_header *header;
+    guint i, blocknum;
+    u_char *c;
+    double min = G_MAXDOUBLE;
+
+    /* A zero-length section would make the modulo below divide by zero */
+    if (!file->map || file->cur_section.length == 0) {
+        return;
+    }
+
+    blocknum = h1 % file->cur_section.length;
+    header = (struct stat_file_header *) file->map;
+    c = (u_char *) file->map + file->seek_pos + blocknum * sizeof(struct stat_file_block);
+    block = (struct stat_file_block *) c;
+
+    for (i = 0; i < CHAIN_LENGTH; i++) {
+        /* Compare in 64 bits so that i + blocknum cannot wrap */
+        if ((guint64) i + blocknum >= file->cur_section.length) {
+            /* Need to expire some block in chain */
+            msg_info_pool("chain %ud is full in statfile %s, starting expire",
+                          blocknum,
+                          file->filename);
+            break;
+        }
+        /* First try to find block in chain */
+        if (block->hash1 == h1 && block->hash2 == h2) {
+            msg_debug_pool("%s found existing block %ud in chain %ud, value %.2f",
+                           file->filename,
+                           i,
+                           blocknum,
+                           value);
+            block->value = value;
+            return;
+        }
+        /* Check whether we have a free block in chain */
+        if (block->hash1 == 0 && block->hash2 == 0) {
+            /* Write new block here */
+            msg_debug_pool("%s found free block %ud in chain %ud, set h1=%ud, h2=%ud",
+                           file->filename,
+                           i,
+                           blocknum,
+                           h1,
+                           h2);
+            block->hash1 = h1;
+            block->hash2 = h2;
+            block->value = value;
+            header->used_blocks++;
+
+            return;
+        }
+
+        /* Remember the block with the minimum value as eviction candidate */
+        if (block->value < min) {
+            to_expire = block;
+            min = block->value;
+        }
+        c += sizeof(struct stat_file_block);
+        block = (struct stat_file_block *) c;
+    }
+
+    /* Try expire some block */
+    if (to_expire) {
+        block = to_expire;
+    }
+    else {
+        /* No candidate was recorded: fall back to the first block in chain */
+        c = (u_char *) file->map + file->seek_pos + blocknum * sizeof(struct stat_file_block);
+        block = (struct stat_file_block *) c;
+    }
+
+    /* used_blocks is left alone: an occupied block is being replaced */
+    block->hash1 = h1;
+    block->hash2 = h2;
+    block->value = value;
+}
+
+/*
+ * Public wrapper around rspamd_mmaped_file_set_block_common(): stores `value`
+ * for the token (h1, h2) in `file`; `pool` is forwarded unchanged.
+ */
+void rspamd_mmaped_file_set_block(rspamd_mempool_t *pool,
+ rspamd_mmaped_file_t *file,
+ guint32 h1,
+ guint32 h2,
+ double value)
+{
+ rspamd_mmaped_file_set_block_common(pool, file, h1, h2, value);
+}
+
+/*
+ * Overwrite the revision number and revision time kept in the mapped header.
+ * Returns FALSE when `file` is NULL or not mapped, TRUE otherwise.
+ */
+gboolean
+rspamd_mmaped_file_set_revision(rspamd_mmaped_file_t *file, guint64 rev, time_t time)
+{
+    if (file != NULL && file->map != NULL) {
+        struct stat_file_header *hdr = (struct stat_file_header *) file->map;
+
+        hdr->revision = rev;
+        hdr->rev_time = time;
+
+        return TRUE;
+    }
+
+    return FALSE;
+}
+
+/*
+ * Add one to the revision counter kept in the mapped header.
+ * Returns FALSE when `file` is NULL or not mapped, TRUE otherwise.
+ */
+gboolean
+rspamd_mmaped_file_inc_revision(rspamd_mmaped_file_t *file)
+{
+    if (file != NULL && file->map != NULL) {
+        struct stat_file_header *hdr = (struct stat_file_header *) file->map;
+
+        hdr->revision += 1;
+
+        return TRUE;
+    }
+
+    return FALSE;
+}
+
+/*
+ * Subtract one from the revision counter kept in the mapped header.
+ *
+ * The counter is unsigned, so it is left at zero instead of being decremented
+ * below it: wrapping would turn an empty statfile into one that reports
+ * 2^64 - 1 revisions.
+ *
+ * Returns FALSE when `file` is NULL or not mapped, TRUE otherwise.
+ */
+gboolean
+rspamd_mmaped_file_dec_revision(rspamd_mmaped_file_t *file)
+{
+    struct stat_file_header *header;
+
+    if (file == NULL || file->map == NULL) {
+        return FALSE;
+    }
+
+    header = (struct stat_file_header *) file->map;
+
+    if (header->revision > 0) {
+        header->revision--;
+    }
+
+    return TRUE;
+}
+
+
+/*
+ * Read the revision number and/or revision time from the mapped header.
+ * Either output pointer may be NULL, in which case that value is skipped.
+ * Returns FALSE (outputs untouched) when `file` is NULL or not mapped.
+ */
+gboolean
+rspamd_mmaped_file_get_revision(rspamd_mmaped_file_t *file, guint64 *rev, time_t *time)
+{
+    const struct stat_file_header *hdr;
+
+    if (file == NULL || file->map == NULL) {
+        return FALSE;
+    }
+
+    hdr = (const struct stat_file_header *) file->map;
+
+    if (rev != NULL) {
+        *rev = hdr->revision;
+    }
+
+    if (time != NULL) {
+        *time = hdr->rev_time;
+    }
+
+    return TRUE;
+}
+
+/*
+ * Return the used_blocks counter from the mapped header, or (guint64) -1
+ * when `file` is NULL or not mapped.
+ */
+guint64
+rspamd_mmaped_file_get_used(rspamd_mmaped_file_t *file)
+{
+    if (file != NULL && file->map != NULL) {
+        const struct stat_file_header *hdr = (const struct stat_file_header *) file->map;
+
+        return hdr->used_blocks;
+    }
+
+    return (guint64) -1;
+}
+
+/*
+ * Return the total number of blocks recorded in the mapped header, or
+ * (guint64) -1 when `file` is NULL or not mapped.
+ *
+ * Side effect: a header whose total_blocks is zero (old header layout) is
+ * repaired in place from the current section length before it is returned.
+ */
+guint64
+rspamd_mmaped_file_get_total(rspamd_mmaped_file_t *file)
+{
+    struct stat_file_header *hdr;
+
+    if (file != NULL && file->map != NULL) {
+        hdr = (struct stat_file_header *) file->map;
+
+        if (hdr->total_blocks == 0) {
+            /* Old version of header: fill in the missing counter */
+            hdr->total_blocks = file->cur_section.length;
+        }
+
+        return hdr->total_blocks;
+    }
+
+    return (guint64) -1;
+}
+
+/**
+ * Check whether the mapped file is a valid statistic file and initialise
+ * the handle's section data from it.
+ *
+ * Validates the minimum length, the 'rsd' magic and the version, then copies
+ * the first section header into file->cur_section and verifies that every
+ * block announced by the section actually fits into the mapped area. On
+ * success file->seek_pos is set to the offset of the first block.
+ *
+ * @param pool pool used by the logging macros
+ * @param file handle with map and len already filled in
+ * @return 0 when the file is usable, -1 otherwise
+ */
+static gint
+rspamd_mmaped_file_check(rspamd_mempool_t *pool, rspamd_mmaped_file_t *file)
+{
+    struct stat_file *f;
+    const u_char *c;
+    static const gchar valid_version[] = RSPAMD_STATFILE_VERSION;
+    /* Offset of the first block: header plus section header */
+    const size_t data_offset = sizeof(struct stat_file) -
+                               sizeof(struct stat_file_block);
+
+    if (!file || !file->map) {
+        return -1;
+    }
+
+    if (file->len < sizeof(struct stat_file)) {
+        msg_info_pool("file %s is too short to be stat file: %z",
+                      file->filename,
+                      file->len);
+        return -1;
+    }
+
+    f = (struct stat_file *) file->map;
+    c = f->header.magic;
+    /* Check magic and version */
+    if (c[0] != 'r' || c[1] != 's' || c[2] != 'd') {
+        msg_info_pool("file %s is invalid stat file", file->filename);
+        return -1;
+    }
+
+    c = f->header.version;
+    /* Version 1.0 (stored as raw bytes 1, 0) is the old format and is rejected */
+    if (c[0] == 1 && c[1] == 0) {
+        return -1;
+    }
+    else if (memcmp(c, valid_version, sizeof(valid_version)) != 0) {
+        /* Unknown version */
+        msg_info_pool("file %s has invalid version %c.%c",
+                      file->filename,
+                      '0' + c[0],
+                      '0' + c[1]);
+        return -1;
+    }
+
+    /* Check first section and set new offset */
+    file->cur_section.code = f->section.code;
+    file->cur_section.length = f->section.length;
+
+    if (file->cur_section.length == 0) {
+        /* Block lookups compute hash % length, so an empty section is unusable */
+        msg_info_pool("file %s has no blocks in its first section",
+                      file->filename);
+        return -1;
+    }
+
+    /*
+     * Blocks start after the header and section header, so only
+     * len - data_offset bytes are available for them. Dividing instead of
+     * multiplying keeps the comparison free of integer overflow for
+     * arbitrary (possibly corrupted) section lengths.
+     */
+    if (file->cur_section.length >
+        (file->len - data_offset) / sizeof(struct stat_file_block)) {
+        msg_info_pool("file %s is truncated: %z, must be %z",
+                      file->filename,
+                      file->len,
+                      (size_t) (data_offset +
+                                file->cur_section.length * sizeof(struct stat_file_block)));
+        return -1;
+    }
+
+    file->seek_pos = data_offset;
+
+    return 0;
+}
+
+
+/**
+ * Rebuild a statfile with a new size.
+ *
+ * The existing file is renamed to "<filename>.old", a fresh file of `size`
+ * bytes is created under the original name, and every non-empty block of the
+ * old file is re-inserted into the new one. Revision data and the tokenizer
+ * configuration are copied from the old header and the backup is removed.
+ *
+ * A "<filename>.lock" file guards the rename. If another process holds the
+ * lock, this function waits for it to disappear and then simply opens the
+ * file without reindexing.
+ *
+ * @param pool pool used for logging and passed to open/create
+ * @param filename path of the statfile
+ * @param old_size current size of the file on disk, in bytes
+ * @param size requested new size, in bytes
+ * @param stcf statfile configuration
+ * @return handle of the new file or NULL on failure
+ */
+static rspamd_mmaped_file_t *
+rspamd_mmaped_file_reindex(rspamd_mempool_t *pool,
+                           const gchar *filename,
+                           size_t old_size,
+                           size_t size,
+                           struct rspamd_statfile_config *stcf)
+{
+    gchar *backup, *lock;
+    gint fd, lock_fd;
+    rspamd_mmaped_file_t *new_file, *old_file = NULL;
+    u_char *map, *pos;
+    struct stat_file_block *block;
+    struct stat_file_header *header, *nh;
+    struct timespec sleep_ts = {
+        .tv_sec = 0,
+        .tv_nsec = 1000000};
+
+    /*
+     * `block` is a pointer in this function, so the block size must be taken
+     * from the structure type rather than from sizeof(block).
+     */
+    if (size <
+        sizeof(struct stat_file_header) + sizeof(struct stat_file_section) +
+            sizeof(struct stat_file_block)) {
+        msg_err_pool("file %s is too small to carry any statistic: %z",
+                     filename,
+                     size);
+        return NULL;
+    }
+
+    lock = g_strconcat(filename, ".lock", NULL);
+    lock_fd = open(lock, O_WRONLY | O_CREAT | O_EXCL, 00600);
+
+    while (lock_fd == -1) {
+        /* Wait for lock */
+        lock_fd = open(lock, O_WRONLY | O_CREAT | O_EXCL, 00600);
+        if (lock_fd != -1) {
+            /* The previous lock holder has finished: just open the file */
+            unlink(lock);
+            close(lock_fd);
+            g_free(lock);
+
+            return rspamd_mmaped_file_open(pool, filename, size, stcf);
+        }
+        else {
+            nanosleep(&sleep_ts, NULL);
+        }
+    }
+
+    backup = g_strconcat(filename, ".old", NULL);
+    if (rename(filename, backup) == -1) {
+        msg_err_pool("cannot rename %s to %s: %s", filename, backup, strerror(errno));
+        g_free(backup);
+        unlink(lock);
+        g_free(lock);
+        close(lock_fd);
+
+        return NULL;
+    }
+
+    old_file = rspamd_mmaped_file_open(pool, backup, old_size, stcf);
+
+    if (old_file == NULL) {
+        msg_warn_pool("old file %s is invalid mmapped file, just move it",
+                      backup);
+    }
+
+    /* We need to release our lock here */
+    unlink(lock);
+    close(lock_fd);
+    g_free(lock);
+
+    /* Now create new file with required size */
+    if (rspamd_mmaped_file_create(filename, size, stcf, pool) != 0) {
+        msg_err_pool("cannot create new file");
+        rspamd_mmaped_file_close(old_file);
+        g_free(backup);
+
+        return NULL;
+    }
+
+    new_file = rspamd_mmaped_file_open(pool, filename, size, stcf);
+
+    if (old_file) {
+        /* Now open new file and start copying */
+        fd = open(backup, O_RDONLY);
+        if (fd == -1 || new_file == NULL) {
+            if (fd != -1) {
+                close(fd);
+            }
+
+            msg_err_pool("cannot open file: %s", strerror(errno));
+            rspamd_mmaped_file_close(old_file);
+            /* Do not leak the freshly opened handle (close is NULL-safe) */
+            rspamd_mmaped_file_close(new_file);
+            g_free(backup);
+            return NULL;
+        }
+
+        /* Now start reading blocks from old statfile */
+        if ((map =
+                 mmap(NULL, old_size, PROT_READ, MAP_SHARED, fd, 0)) == MAP_FAILED) {
+            msg_err_pool("cannot mmap file: %s", strerror(errno));
+            close(fd);
+            rspamd_mmaped_file_close(old_file);
+            rspamd_mmaped_file_close(new_file);
+            g_free(backup);
+            return NULL;
+        }
+
+        pos = map + (sizeof(struct stat_file) - sizeof(struct stat_file_block));
+
+        /* Visit every complete block that fits into the old mapping */
+        while ((gssize) old_size - (pos - map) >= (gssize) sizeof(struct stat_file_block)) {
+            block = (struct stat_file_block *) pos;
+            if (block->hash1 != 0 && block->value != 0) {
+                rspamd_mmaped_file_set_block_common(pool,
+                                                    new_file, block->hash1,
+                                                    block->hash2, block->value);
+            }
+            /* Step by a whole block, not by the size of a pointer */
+            pos += sizeof(struct stat_file_block);
+        }
+
+        header = (struct stat_file_header *) map;
+        rspamd_mmaped_file_set_revision(new_file, header->revision, header->rev_time);
+        nh = new_file->map;
+        /* Copy tokenizer configuration */
+        memcpy(nh->unused, header->unused, sizeof(header->unused));
+        nh->tokenizer_conf_len = header->tokenizer_conf_len;
+
+        munmap(map, old_size);
+        close(fd);
+        rspamd_mmaped_file_close_file(pool, old_file);
+    }
+
+    unlink(backup);
+    g_free(backup);
+
+    return new_file;
+}
+
+/*
+ * Pre-load mmaped file into memory: advise sequential access for the whole
+ * mapping and, if that succeeds, read one byte per page. When madvise fails
+ * the error is logged and no page is touched.
+ */
+static void
+rspamd_mmaped_file_preload(rspamd_mmaped_file_t *file)
+{
+    guint8 *start = (guint8 *) file->map;
+    guint8 *limit = (guint8 *) file->map + file->len;
+    guint8 *cur;
+    volatile guint8 sink;
+    gsize page;
+
+    if (madvise(start, limit - start, MADV_SEQUENTIAL) == -1) {
+        msg_info("madvise failed: %s", strerror(errno));
+        return;
+    }
+
+#ifdef HAVE_GETPAGESIZE
+    page = getpagesize();
+#else
+    page = sysconf(_SC_PAGESIZE);
+#endif
+
+    /* The volatile sink keeps the per-page reads from being optimised away */
+    for (cur = start; cur < limit; cur += page) {
+        sink = *cur;
+        (void) sink;
+    }
+}
+
+/**
+ * Open and map an existing statfile.
+ *
+ * Fails immediately when "<filename>.lock" exists (another process is creating
+ * or reindexing the file). When the size on disk differs from the requested
+ * `size` by more than two struct stat_file, the work is delegated to
+ * rspamd_mmaped_file_reindex(). Otherwise the file is mapped read/write,
+ * validated under a file lock and pre-loaded into memory.
+ *
+ * @param pool pool used for logging and remembered in the handle
+ * @param filename path of the statfile
+ * @param size requested size in bytes
+ * @param stcf statfile configuration remembered in the handle
+ * @return new handle (release with rspamd_mmaped_file_close_file) or NULL
+ */
+rspamd_mmaped_file_t *
+rspamd_mmaped_file_open(rspamd_mempool_t *pool,
+                        const gchar *filename, size_t size,
+                        struct rspamd_statfile_config *stcf)
+{
+    struct stat st;
+    rspamd_mmaped_file_t *new_file;
+    gchar *lock;
+    gint lock_fd;
+
+    /* Probe the lock file: it is created and removed again right away */
+    lock = g_strconcat(filename, ".lock", NULL);
+    lock_fd = open(lock, O_WRONLY | O_CREAT | O_EXCL, 00600);
+
+    if (lock_fd == -1) {
+        g_free(lock);
+        msg_info_pool("cannot open file %s, it is locked by another process",
+                      filename);
+        return NULL;
+    }
+
+    close(lock_fd);
+    unlink(lock);
+    g_free(lock);
+
+    if (stat(filename, &st) == -1) {
+        msg_info_pool("cannot stat file %s, error %s, %d", filename, strerror(errno), errno);
+        return NULL;
+    }
+
+    if (labs((glong) size - st.st_size) > (long) sizeof(struct stat_file) * 2 && size > sizeof(struct stat_file)) {
+        msg_warn_pool("need to reindex statfile old size: %Hz, new size: %Hz",
+                      (size_t) st.st_size, size);
+        return rspamd_mmaped_file_reindex(pool, filename, st.st_size, size, stcf);
+    }
+    else if (size < sizeof(struct stat_file)) {
+        /* Only reported: the existing file is still opened as is */
+        msg_err_pool("requested to shrink statfile to %Hz but it is too small",
+                     size);
+    }
+
+    new_file = g_malloc0(sizeof(rspamd_mmaped_file_t));
+    if ((new_file->fd = open(filename, O_RDWR)) == -1) {
+        msg_info_pool("cannot open file %s, error %d, %s",
+                      filename,
+                      errno,
+                      strerror(errno));
+        g_free(new_file);
+        return NULL;
+    }
+
+    if ((new_file->map =
+             mmap(NULL, st.st_size, PROT_READ | PROT_WRITE, MAP_SHARED,
+                  new_file->fd, 0)) == MAP_FAILED) {
+        /* Log first: the cleanup calls below may overwrite errno */
+        msg_info_pool("cannot mmap file %s, error %d, %s",
+                      filename,
+                      errno,
+                      strerror(errno));
+        close(new_file->fd);
+        g_free(new_file);
+        return NULL;
+    }
+
+    rspamd_strlcpy(new_file->filename, filename, sizeof(new_file->filename));
+    new_file->len = st.st_size;
+
+    /* Acquire lock for this operation */
+    if (!rspamd_file_lock(new_file->fd, FALSE)) {
+        msg_info_pool("cannot lock file %s, error %d, %s",
+                      filename,
+                      errno,
+                      strerror(errno));
+        munmap(new_file->map, st.st_size);
+        close(new_file->fd);
+        g_free(new_file);
+        return NULL;
+    }
+
+    if (rspamd_mmaped_file_check(pool, new_file) == -1) {
+        /* Unlock while the descriptor is still open, then release it */
+        rspamd_file_unlock(new_file->fd, FALSE);
+        munmap(new_file->map, st.st_size);
+        close(new_file->fd);
+        g_free(new_file);
+        return NULL;
+    }
+
+    rspamd_file_unlock(new_file->fd, FALSE);
+    new_file->cf = stcf;
+    new_file->pool = pool;
+    rspamd_mmaped_file_preload(new_file);
+
+    g_assert(stcf->clcf != NULL);
+
+    msg_debug_pool("opened statfile %s of size %l", filename, (long) size);
+
+    return new_file;
+}
+
+/*
+ * Release a statfile handle: schedule an asynchronous sync of the mapping,
+ * unmap it, close the descriptor (unless it is -1) and free the handle.
+ * `file` must not be NULL. Always returns 0.
+ */
+gint rspamd_mmaped_file_close_file(rspamd_mempool_t *pool,
+                                   rspamd_mmaped_file_t *file)
+{
+    void *mapping = file->map;
+    const gint descriptor = file->fd;
+
+    if (mapping != NULL) {
+        msg_info_pool("syncing statfile %s", file->filename);
+        msync(mapping, file->len, MS_ASYNC);
+        munmap(mapping, file->len);
+    }
+
+    if (descriptor != -1) {
+        close(descriptor);
+    }
+
+    g_free(file);
+
+    return 0;
+}
+
+/**
+ * Create (or truncate) a statfile of `size` bytes and fill it with an
+ * initialised header, one section header and zeroed blocks.
+ *
+ * A "<filename>.lock" file serialises creation between processes. If the lock
+ * was initially held by someone else and the statfile exists once the lock is
+ * obtained, the file is assumed to have been created by that process and 0 is
+ * returned without touching it.
+ *
+ * @param filename path of the statfile
+ * @param size requested file size in bytes; determines the number of blocks
+ * @param stcf statfile configuration; its classifier's tokenizer supplies the
+ * configuration blob that is stored in the header
+ * @param pool pool used for logging and passed to the tokenizer
+ * @return 0 on success, -1 on failure (the lock file is removed in both cases)
+ */
+gint rspamd_mmaped_file_create(const gchar *filename,
+ size_t size,
+ struct rspamd_statfile_config *stcf,
+ rspamd_mempool_t *pool)
+{
+ struct stat_file_header header = {
+ .magic = {'r', 's', 'd'},
+ .version = RSPAMD_STATFILE_VERSION,
+ .padding = {0, 0, 0},
+ .revision = 0,
+ .rev_time = 0,
+ .used_blocks = 0};
+ struct stat_file_section section = {
+ .code = STATFILE_SECTION_COMMON,
+ };
+ /* All-zero block: both hashes zero marks a block as free */
+ struct stat_file_block block = {0, 0, 0};
+ struct rspamd_stat_tokenizer *tokenizer;
+ gint fd, lock_fd;
+ guint buflen = 0, nblocks;
+ gchar *buf = NULL, *lock;
+ struct stat sb;
+ gpointer tok_conf;
+ gsize tok_conf_len;
+ struct timespec sleep_ts = {
+ .tv_sec = 0,
+ .tv_nsec = 1000000};
+
+ /* The file must have room for the headers and at least one block */
+ if (size <
+ sizeof(struct stat_file_header) + sizeof(struct stat_file_section) +
+ sizeof(block)) {
+ msg_err_pool("file %s is too small to carry any statistic: %z",
+ filename,
+ size);
+ return -1;
+ }
+
+ lock = g_strconcat(filename, ".lock", NULL);
+ lock_fd = open(lock, O_WRONLY | O_CREAT | O_EXCL, 00600);
+
+ /* Entered only when the first attempt to take the lock failed */
+ while (lock_fd == -1) {
+ /* Wait for lock */
+ lock_fd = open(lock, O_WRONLY | O_CREAT | O_EXCL, 00600);
+ if (lock_fd != -1) {
+ if (stat(filename, &sb) != -1) {
+ /* File has been created by some other process */
+ unlink(lock);
+ close(lock_fd);
+ g_free(lock);
+
+ return 0;
+ }
+
+ /* We still need to create it */
+ goto create;
+ }
+ else {
+ nanosleep(&sleep_ts, NULL);
+ }
+ }
+
+create:
+
+ msg_debug_pool("create statfile %s of size %l", filename, (long) size);
+ /* Number of whole blocks that fit after the header and section header */
+ nblocks =
+ (size - sizeof(struct stat_file_header) -
+ sizeof(struct stat_file_section)) /
+ sizeof(struct stat_file_block);
+ header.total_blocks = nblocks;
+
+ if ((fd =
+ open(filename, O_RDWR | O_TRUNC | O_CREAT, S_IWUSR | S_IRUSR)) == -1) {
+ msg_info_pool("cannot create file %s, error %d, %s",
+ filename,
+ errno,
+ strerror(errno));
+ unlink(lock);
+ close(lock_fd);
+ g_free(lock);
+
+ return -1;
+ }
+
+ /* The result of the preallocation is not checked */
+ rspamd_fallocate(fd,
+ 0,
+ sizeof(header) + sizeof(section) + sizeof(block) * nblocks);
+
+ header.create_time = (guint64) time(NULL);
+ g_assert(stcf->clcf != NULL);
+ g_assert(stcf->clcf->tokenizer != NULL);
+ tokenizer = rspamd_stat_get_tokenizer(stcf->clcf->tokenizer->name);
+ g_assert(tokenizer != NULL);
+ /* The tokenizer configuration is stored inside the header's `unused` area */
+ tok_conf = tokenizer->get_config(pool, stcf->clcf->tokenizer, &tok_conf_len);
+ header.tokenizer_conf_len = tok_conf_len;
+ g_assert(tok_conf_len < sizeof(header.unused) - sizeof(guint64));
+ memcpy(header.unused, tok_conf, tok_conf_len);
+
+ if (write(fd, &header, sizeof(header)) == -1) {
+ msg_info_pool("cannot write header to file %s, error %d, %s",
+ filename,
+ errno,
+ strerror(errno));
+ close(fd);
+ unlink(lock);
+ close(lock_fd);
+ g_free(lock);
+
+ return -1;
+ }
+
+ section.length = (guint64) nblocks;
+ if (write(fd, &section, sizeof(section)) == -1) {
+ msg_info_pool("cannot write section header to file %s, error %d, %s",
+ filename,
+ errno,
+ strerror(errno));
+ close(fd);
+ unlink(lock);
+ close(lock_fd);
+ g_free(lock);
+
+ return -1;
+ }
+
+ /* Buffer for write 256 blocks at once */
+ if (nblocks > 256) {
+ buflen = sizeof(block) * 256;
+ buf = g_malloc0(buflen);
+ }
+
+ /*
+ * Zeroed blocks are written 256 at a time while more than 256 remain;
+ * the rest is written one block per call.
+ */
+ while (nblocks) {
+ if (nblocks > 256) {
+ /* Just write buffer */
+ if (write(fd, buf, buflen) == -1) {
+ msg_info_pool("cannot write blocks buffer to file %s, error %d, %s",
+ filename,
+ errno,
+ strerror(errno));
+ close(fd);
+ g_free(buf);
+ unlink(lock);
+ close(lock_fd);
+ g_free(lock);
+
+ return -1;
+ }
+ nblocks -= 256;
+ }
+ else {
+ if (write(fd, &block, sizeof(block)) == -1) {
+ msg_info_pool("cannot write block to file %s, error %d, %s",
+ filename,
+ errno,
+ strerror(errno));
+ close(fd);
+ if (buf) {
+ g_free(buf);
+ }
+
+ unlink(lock);
+ close(lock_fd);
+ g_free(lock);
+
+ return -1;
+ }
+ nblocks--;
+ }
+ }
+
+ close(fd);
+
+ if (buf) {
+ g_free(buf);
+ }
+
+ /* Release the creation lock */
+ unlink(lock);
+ close(lock_fd);
+ g_free(lock);
+ msg_debug_pool("created statfile %s of size %l", filename, (long) size);
+
+ return 0;
+}
+
+/**
+ * Backend init hook: open the statfile described by the statfile options,
+ * creating it first when it cannot be opened.
+ *
+ * The path is taken from the "filename" option (falling back to "path") and
+ * the size from the integer "size" option; both are mandatory.
+ *
+ * @param ctx statistics context (not used here)
+ * @param cfg configuration; its pool is used for logging and for the handle
+ * @param st statfile whose configuration is read
+ * @return rspamd_mmaped_file_t handle as gpointer, or NULL on failure
+ */
+gpointer
+rspamd_mmaped_file_init(struct rspamd_stat_ctx *ctx,
+                        struct rspamd_config *cfg, struct rspamd_statfile *st)
+{
+    struct rspamd_statfile_config *stf = st->stcf;
+    rspamd_mmaped_file_t *mf;
+    const ucl_object_t *filenameo, *sizeo;
+    const gchar *filename;
+    gsize size;
+
+    filenameo = ucl_object_lookup(stf->opts, "filename");
+
+    if (filenameo == NULL || ucl_object_type(filenameo) != UCL_STRING) {
+        filenameo = ucl_object_lookup(stf->opts, "path");
+
+        if (filenameo == NULL || ucl_object_type(filenameo) != UCL_STRING) {
+            msg_err_config("statfile %s has no filename defined", stf->symbol);
+            return NULL;
+        }
+    }
+
+    filename = ucl_object_tostring(filenameo);
+
+    sizeo = ucl_object_lookup(stf->opts, "size");
+
+    if (sizeo == NULL || ucl_object_type(sizeo) != UCL_INT) {
+        msg_err_config("statfile %s has no size defined", stf->symbol);
+        return NULL;
+    }
+
+    size = ucl_object_toint(sizeo);
+    mf = rspamd_mmaped_file_open(cfg->cfg_pool, filename, size, stf);
+
+    if (mf == NULL) {
+        /*
+         * Create file here. The filename and size validated above are reused;
+         * looking them up a second time cannot yield anything different.
+         */
+        if (rspamd_mmaped_file_create(filename, size, stf, cfg->cfg_pool) != 0) {
+            msg_err_config("cannot create new file");
+        }
+
+        /* Still attempt to open: the file may exist despite the failure */
+        mf = rspamd_mmaped_file_open(cfg->cfg_pool, filename, size, stf);
+    }
+
+    /* open() already stored cfg->cfg_pool in the handle */
+    return (gpointer) mf;
+}
+
+/*
+ * Backend close hook: releases the statfile handle `p` using the pool stored
+ * in it. A NULL handle is ignored.
+ */
+void rspamd_mmaped_file_close(gpointer p)
+{
+    rspamd_mmaped_file_t *handle = p;
+
+    if (handle == NULL) {
+        return;
+    }
+
+    rspamd_mmaped_file_close_file(handle->pool, handle);
+}
+
+/*
+ * Backend runtime hook: no per-task state is built, the statfile handle `p`
+ * itself is returned as the runtime. `task`, `stcf`, `learn` and `_id` are
+ * not used.
+ */
+gpointer
+rspamd_mmaped_file_runtime(struct rspamd_task *task,
+ struct rspamd_statfile_config *stcf,
+ gboolean learn,
+ gpointer p,
+ gint _id)
+{
+ rspamd_mmaped_file_t *mf = p;
+
+ return (gpointer) mf;
+}
+
+/*
+ * Classification hook: for every token, split the first eight bytes of its
+ * data into two 32-bit hashes, look them up in the statfile and store the
+ * result in tok->values[id]. Afterwards the task is flagged as having spam or
+ * ham tokens depending on the statfile's is_spam setting. Always returns TRUE.
+ */
+gboolean
+rspamd_mmaped_file_process_tokens(struct rspamd_task *task, GPtrArray *tokens,
+                                  gint id,
+                                  gpointer p)
+{
+    rspamd_mmaped_file_t *mf = p;
+    guint idx;
+
+    g_assert(tokens != NULL);
+    g_assert(p != NULL);
+
+    for (idx = 0; idx < tokens->len; idx++) {
+        rspamd_token_t *tok = g_ptr_array_index(tokens, idx);
+        const guchar *raw = (const guchar *) &tok->data;
+        guint32 h1, h2;
+
+        /* memcpy avoids aliasing/alignment assumptions about tok->data */
+        memcpy(&h1, raw, sizeof(h1));
+        memcpy(&h2, raw + sizeof(h1), sizeof(h2));
+        tok->values[id] = rspamd_mmaped_file_get_block(mf, h1, h2);
+    }
+
+    task->flags |= mf->cf->is_spam ? RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS
+                                   : RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
+
+    return TRUE;
+}
+
+/*
+ * Learn hook: for every token, split the first eight bytes of its data into
+ * two 32-bit hashes and write tok->values[id] into the statfile under that
+ * key. Always returns TRUE.
+ */
+gboolean
+rspamd_mmaped_file_learn_tokens(struct rspamd_task *task, GPtrArray *tokens,
+                                gint id,
+                                gpointer p)
+{
+    rspamd_mmaped_file_t *mf = p;
+    guint idx;
+
+    g_assert(tokens != NULL);
+    g_assert(p != NULL);
+
+    for (idx = 0; idx < tokens->len; idx++) {
+        rspamd_token_t *tok = g_ptr_array_index(tokens, idx);
+        const guchar *raw = (const guchar *) &tok->data;
+        guint32 h1, h2;
+
+        /* memcpy avoids aliasing/alignment assumptions about tok->data */
+        memcpy(&h1, raw, sizeof(h1));
+        memcpy(&h2, raw + sizeof(h1), sizeof(h2));
+        rspamd_mmaped_file_set_block(task->task_pool, mf, h1, h2,
+                                     tok->values[id]);
+    }
+
+    return TRUE;
+}
+
+/*
+ * Return the statfile's revision counter, which this backend reports as the
+ * number of learns; 0 when no runtime is given or the file is not mapped.
+ */
+gulong
+rspamd_mmaped_file_total_learns(struct rspamd_task *task, gpointer runtime,
+                                gpointer ctx)
+{
+    rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *) runtime;
+    guint64 learns = 0;
+
+    if (mf == NULL) {
+        return learns;
+    }
+
+    /* The revision time is not needed, so no output is requested for it */
+    rspamd_mmaped_file_get_revision(mf, &learns, NULL);
+
+    return learns;
+}
+
+/*
+ * Increment the statfile's revision counter and return the new value;
+ * 0 when no runtime is given or the file is not mapped.
+ */
+gulong
+rspamd_mmaped_file_inc_learns(struct rspamd_task *task, gpointer runtime,
+                              gpointer ctx)
+{
+    rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *) runtime;
+    guint64 learns = 0;
+
+    if (mf == NULL) {
+        return learns;
+    }
+
+    rspamd_mmaped_file_inc_revision(mf);
+    /* The revision time is not needed, so no output is requested for it */
+    rspamd_mmaped_file_get_revision(mf, &learns, NULL);
+
+    return learns;
+}
+
+/*
+ * Decrement the statfile's revision counter and return the new value;
+ * 0 when no runtime is given or the file is not mapped.
+ */
+gulong
+rspamd_mmaped_file_dec_learns(struct rspamd_task *task, gpointer runtime,
+                              gpointer ctx)
+{
+    rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *) runtime;
+    guint64 learns = 0;
+
+    if (mf == NULL) {
+        return learns;
+    }
+
+    rspamd_mmaped_file_dec_revision(mf);
+    /* The revision time is not needed, so no output is requested for it */
+    rspamd_mmaped_file_get_revision(mf, &learns, NULL);
+
+    return learns;
+}
+
+
+/**
+ * Build a UCL object describing the statfile: revision, size in bytes, total
+ * and used block counts, symbol, backend type ("mmap") and, when configured,
+ * the label. "languages" and "users" are always reported as 0.
+ *
+ * @param runtime statfile handle (may be NULL)
+ * @param ctx not used
+ * @return newly created object owned by the caller, or NULL when runtime is NULL
+ */
+ucl_object_t *
+rspamd_mmaped_file_get_stat(gpointer runtime,
+                            gpointer ctx)
+{
+    ucl_object_t *res = NULL;
+    /* Stays 0 if the revision cannot be read (file not mapped) */
+    guint64 rev = 0;
+    rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *) runtime;
+
+    if (mf != NULL) {
+        res = ucl_object_typed_new(UCL_OBJECT);
+        rspamd_mmaped_file_get_revision(mf, &rev, NULL);
+        ucl_object_insert_key(res, ucl_object_fromint(rev), "revision",
+                              0, false);
+        ucl_object_insert_key(res, ucl_object_fromint(mf->len), "size",
+                              0, false);
+        ucl_object_insert_key(res, ucl_object_fromint(rspamd_mmaped_file_get_total(mf)), "total", 0, false);
+        ucl_object_insert_key(res, ucl_object_fromint(rspamd_mmaped_file_get_used(mf)), "used", 0, false);
+        ucl_object_insert_key(res, ucl_object_fromstring(mf->cf->symbol),
+                              "symbol", 0, false);
+        ucl_object_insert_key(res, ucl_object_fromstring("mmap"),
+                              "type", 0, false);
+        ucl_object_insert_key(res, ucl_object_fromint(0),
+                              "languages", 0, false);
+        ucl_object_insert_key(res, ucl_object_fromint(0),
+                              "users", 0, false);
+
+        if (mf->cf->label) {
+            ucl_object_insert_key(res, ucl_object_fromstring(mf->cf->label),
+                                  "label", 0, false);
+        }
+    }
+
+    return res;
+}
+
+/*
+ * Learn finalizer: requests an asynchronous, invalidating sync of the whole
+ * mapping when a runtime is present. The msync result is not checked, *err is
+ * never set and TRUE is always returned.
+ */
+gboolean
+rspamd_mmaped_file_finalize_learn(struct rspamd_task *task, gpointer runtime,
+                                  gpointer ctx, GError **err)
+{
+    rspamd_mmaped_file_t *mf = (rspamd_mmaped_file_t *) runtime;
+
+    if (mf == NULL) {
+        return TRUE;
+    }
+
+    msync(mf->map, mf->len, MS_INVALIDATE | MS_ASYNC);
+
+    return TRUE;
+}
+
+/* Classification finalizer of the mmap backend: does no work and always returns TRUE. */
+gboolean
+rspamd_mmaped_file_finalize_process(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx)
+{
+ return TRUE;
+}
+
+/*
+ * Return a pointer to the tokenizer configuration stored inside the mapped
+ * header (its `unused` area) and, when `len` is not NULL, store the
+ * configuration length there. The pointer refers to the mapping itself and is
+ * not a copy. `runtime` must not be NULL.
+ */
+gpointer
+rspamd_mmaped_file_load_tokenizer_config(gpointer runtime,
+                                         gsize *len)
+{
+    rspamd_mmaped_file_t *mf = runtime;
+    struct stat_file_header *hdr;
+
+    g_assert(mf != NULL);
+    hdr = mf->map;
+
+    if (len != NULL) {
+        *len = hdr->tokenizer_conf_len;
+    }
+
+    return hdr->unused;
+}
diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx
new file mode 100644
index 0000000..cd0c379
--- /dev/null
+++ b/src/libstat/backends/redis_backend.cxx
@@ -0,0 +1,1132 @@
+/*
+ * Copyright 2024 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "config.h"
+#include "lua/lua_common.h"
+#include "rspamd.h"
+#include "stat_internal.h"
+#include "upstream.h"
+#include "libserver/mempool_vars_internal.h"
+#include "fmt/core.h"
+
+#include "libutil/cxx/error.hxx"
+
+#include <string>
+#include <cstdint>
+#include <vector>
+#include <optional>
+
+/*
+ * Debug logging helper for this module. It expands to an expression that
+ * reads `task->task_pool->tag.uid`, so a variable named `task` must be in
+ * scope wherever the macro is used.
+ */
+#define msg_debug_stat_redis(...) rspamd_conditional_debug_fast(nullptr, nullptr, \
+ rspamd_stat_redis_log_id, "stat_redis", task->task_pool->tag.uid, \
+ RSPAMD_LOG_FUNC, \
+ __VA_ARGS__)
+
+/* NOTE(review): presumably defines rspamd_stat_redis_log_id used above - confirm */
+INIT_LOG_MODULE(stat_redis)
+
+/* Cast an opaque backend pointer to the statfile context */
+#define REDIS_CTX(p) (reinterpret_cast<struct redis_stat_ctx *>(p))
+/* Cast an opaque runtime pointer to the (float-valued) per-task runtime */
+#define REDIS_RUNTIME(p) (reinterpret_cast<struct redis_stat_runtime<float> *>(p))
+/* Default key pattern; placeholders are expanded by rspamd_redis_expand_object */
+#define REDIS_DEFAULT_OBJECT "%s%l"
+/* Default key pattern when per-user statistics are enabled (adds the recipient) */
+#define REDIS_DEFAULT_USERS_OBJECT "%s%l%r"
+/* NOTE(review): the three constants below are not referenced in the visible code; units are presumably seconds / entries - confirm */
+#define REDIS_DEFAULT_TIMEOUT 0.5
+#define REDIS_STAT_TIMEOUT 30
+#define REDIS_MAX_USERS 1000
+
+/*
+ * Per-statfile context of the Redis backend.
+ *
+ * The object owns three Lua registry references (user extraction, classify
+ * and learn callbacks) and one reference to the last statistics object
+ * received from Lua; all of them are released in the destructor.
+ */
+struct redis_stat_ctx {
+	lua_State *L;
+	/* Filled by rspamd_redis_init after construction */
+	struct rspamd_statfile_config *stcf = nullptr;
+	/* Key pattern expanded by rspamd_redis_expand_object */
+	const char *redis_object = REDIS_DEFAULT_OBJECT;
+	bool enable_users = false;
+	bool store_tokens = false;
+	bool enable_signatures = false;
+	/* Lua registry reference of the user extraction function, -1 if unset */
+	int cbref_user = -1;
+
+	/* Lua registry references of the classify/learn functions, -1 if unset */
+	int cbref_classify = -1;
+	int cbref_learn = -1;
+
+	/* Last statistics object stored by rspamd_redis_stat_cb (owned reference) */
+	ucl_object_t *cur_stat = nullptr;
+
+	explicit redis_stat_ctx(lua_State *_L)
+		: L(_L)
+	{
+	}
+
+	~redis_stat_ctx()
+	{
+		if (cbref_user != -1) {
+			luaL_unref(L, LUA_REGISTRYINDEX, cbref_user);
+		}
+
+		if (cbref_classify != -1) {
+			luaL_unref(L, LUA_REGISTRYINDEX, cbref_classify);
+		}
+
+		if (cbref_learn != -1) {
+			luaL_unref(L, LUA_REGISTRYINDEX, cbref_learn);
+		}
+
+		/* Drop our reference to the cached statistics, otherwise it leaks */
+		if (cur_stat != nullptr) {
+			ucl_object_unref(cur_stat);
+		}
+	}
+};
+
+
+/*
+ * Per-task runtime of one statfile in the Redis backend.
+ *
+ * Instances are created with `new` and register rt_dtor in the task memory
+ * pool, which deletes them; the destructor is private so that they cannot be
+ * deleted from anywhere else by accident.
+ *
+ * T is the type used for token values returned by the Lua/Redis script.
+ */
+template<class T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true>
+struct redis_stat_runtime {
+	struct redis_stat_ctx *ctx;
+	struct rspamd_task *task;
+	struct rspamd_statfile_config *stcf;
+	/* Referenced token array, released in the destructor */
+	GPtrArray *tokens = nullptr;
+	const char *redis_object_expanded;
+	std::uint64_t learned = 0;
+	/* Index into rspamd_token_t::values; set before tokens are processed */
+	int id = 0;
+	/* Pairs of (1-based token index, value); owned by the runtime */
+	std::vector<std::pair<int, T>> *results = nullptr;
+	/* false for the "opposite class" runtime that is filled by its sibling */
+	bool need_redis_call = true;
+	std::optional<rspamd::util::error> err;
+
+	using result_type = std::vector<std::pair<int, T>>;
+
+private:
+	/* Memory pool destructor: deletes the runtime registered in the constructor */
+	static void rt_dtor(gpointer data)
+	{
+		/* Cast to this instantiation, not to a hard-coded <float> one */
+		auto *rt = reinterpret_cast<redis_stat_runtime<T> *>(data);
+
+		delete rt;
+	}
+
+	/* Private on purpose: only rt_dtor may destroy a runtime */
+	~redis_stat_runtime()
+	{
+		if (tokens) {
+			g_ptr_array_unref(tokens);
+		}
+
+		delete results;
+	}
+
+public:
+	explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded)
+		: ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded)
+	{
+		rspamd_mempool_add_destructor(task->task_pool, redis_stat_runtime<T>::rt_dtor, this);
+	}
+
+	/*
+	 * Look up a runtime previously stored by save_in_mempool under the key
+	 * "<expanded object>_S" (spam) or "<expanded object>_H" (ham).
+	 * Returns std::nullopt when no such variable exists in the task pool.
+	 */
+	static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded,
+										   bool is_spam) -> std::optional<redis_stat_runtime<T> *>
+	{
+		auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+		auto *res = rspamd_mempool_get_variable(task->task_pool, var_name.c_str());
+
+		if (res) {
+			msg_debug_bayes("recovered runtime from mempool at %s", var_name.c_str());
+			return reinterpret_cast<redis_stat_runtime<T> *>(res);
+		}
+		else {
+			msg_debug_bayes("no runtime at %s", var_name.c_str());
+			return std::nullopt;
+		}
+	}
+
+	/* Take ownership of `results`, releasing any vector stored before */
+	void set_results(std::vector<std::pair<int, T>> *results)
+	{
+		if (this->results != results) {
+			delete this->results;
+		}
+
+		this->results = results;
+	}
+
+	/*
+	 * Propagate results from internal representation to the tokens array.
+	 * Returns false when there is nothing to propagate.
+	 */
+	auto process_tokens(GPtrArray *tokens) const -> bool
+	{
+		rspamd_token_t *tok;
+
+		if (!results || !tokens) {
+			return false;
+		}
+
+		for (auto [idx, val]: *results) {
+			/*
+			 * Indexes come from the script and are 1-based; ignore anything
+			 * that does not address an element of the token array instead
+			 * of writing out of bounds.
+			 */
+			if (idx < 1 || static_cast<guint>(idx) > tokens->len) {
+				continue;
+			}
+
+			tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx - 1);
+			tok->values[id] = val;
+		}
+
+		return true;
+	}
+
+	/* Publish this runtime in the task pool under the "_S"/"_H" suffixed key */
+	auto save_in_mempool(bool is_spam) const
+	{
+		auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H");
+		/* We do not set destructor for the variable, as it should be already added on creation */
+		rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr);
+		msg_debug_bayes("saved runtime in mempool at %s", var_name.c_str());
+	}
+};
+
+/* Read member `elt` of `task`, yielding nullptr when `task` itself is null */
+#define GET_TASK_ELT(task, elt) (task == nullptr ? nullptr : (task)->elt)
+
+/* Static string from which the error quark below is created */
+static const gchar *M = "redis statistics";
+
+/* GQuark identifying the error domain of this backend */
+static GQuark
+rspamd_redis_stat_quark(void)
+{
+ return g_quark_from_static_string(M);
+}
+
+/*
+ * Expand a key pattern into a string allocated from the task memory pool.
+ *
+ * Recognised placeholders:
+ *   %%  literal '%'
+ *   %u  task->auth_user (nothing if unset)
+ *   %r  recipient: the value produced by the user extraction step below or,
+ *       if there is none, rspamd_task_get_principal_recipient(task)
+ *   %l  statfile label (nothing if unset)
+ *   %s  the fixed string "RS"
+ * Any other character after '%' is copied as is (the '%' is dropped).
+ * A 'd' directly following %u, %r, %l or %s is consumed and ignored.
+ *
+ * When ctx->enable_users is set, the user is determined first (by the Lua
+ * callback if one is registered) and stored in the "stat_user" pool variable.
+ *
+ * pattern: pattern to expand
+ * ctx:     backend context, must not be null (asserted)
+ * task:    current task, must not be null (asserted)
+ * target:  receives the NUL-terminated expanded string
+ *
+ * Returns the length of the expanded string without the terminator. If
+ * `target` is null, -1 converted to gsize is returned and nothing is
+ * allocated.
+ *
+ * Non-static for lua unit testing
+ */
+gsize rspamd_redis_expand_object(const gchar *pattern,
+ struct redis_stat_ctx *ctx,
+ struct rspamd_task *task,
+ gchar **target)
+{
+ gsize tlen = 0;
+ const gchar *p = pattern, *elt;
+ gchar *d, *end;
+ /* Parser states: plain text, just after '%', just after a placeholder letter */
+ enum {
+ just_char,
+ percent_char,
+ mod_char
+ } state = just_char;
+ struct rspamd_statfile_config *stcf;
+ lua_State *L = nullptr;
+ struct rspamd_task **ptask;
+ const gchar *rcpt = nullptr;
+ gint err_idx;
+
+ g_assert(ctx != nullptr);
+ g_assert(task != nullptr);
+ stcf = ctx->stcf;
+
+ L = RSPAMD_LUA_CFG_STATE(task->cfg);
+ g_assert(L != nullptr);
+
+ if (ctx->enable_users) {
+ if (ctx->cbref_user == -1) {
+ /* No custom extractor: use the principal recipient */
+ rcpt = rspamd_task_get_principal_recipient(task);
+ }
+ else {
+ /* Execute lua function to get userdata */
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ err_idx = lua_gettop(L);
+
+ lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->cbref_user);
+ ptask = (struct rspamd_task **) lua_newuserdata(L, sizeof(struct rspamd_task *));
+ *ptask = task;
+ rspamd_lua_setclass(L, "rspamd{task}", -1);
+
+ if (lua_pcall(L, 1, 1, err_idx) != 0) {
+ msg_err_task("call to user extraction script failed: %s",
+ lua_tostring(L, -1));
+ }
+ else {
+ /* Copy the result into the task pool before it is popped from the Lua stack */
+ rcpt = rspamd_mempool_strdup(task->task_pool, lua_tostring(L, -1));
+ }
+
+ /* Result + error function */
+ lua_settop(L, err_idx - 1);
+ }
+
+ if (rcpt) {
+ rspamd_mempool_set_variable(task->task_pool, "stat_user",
+ (gpointer) rcpt, nullptr);
+ }
+ }
+
+ /* Length calculation */
+ while (*p) {
+ switch (state) {
+ case just_char:
+ if (*p == '%') {
+ state = percent_char;
+ }
+ else {
+ tlen++;
+ }
+ p++;
+ break;
+ case percent_char:
+ switch (*p) {
+ case '%':
+ tlen++;
+ state = just_char;
+ break;
+ case 'u':
+ elt = GET_TASK_ELT(task, auth_user);
+ if (elt) {
+ tlen += strlen(elt);
+ }
+ break;
+ case 'r':
+
+ if (rcpt == nullptr) {
+ elt = rspamd_task_get_principal_recipient(task);
+ }
+ else {
+ elt = rcpt;
+ }
+
+ if (elt) {
+ tlen += strlen(elt);
+ }
+ break;
+ case 'l':
+ if (stcf->label) {
+ tlen += strlen(stcf->label);
+ }
+ /* Label miss is OK */
+ break;
+ case 's':
+ tlen += sizeof("RS") - 1;
+ break;
+ default:
+ state = just_char;
+ tlen++;
+ break;
+ }
+
+ /* State unchanged means a known placeholder was seen: look for a modifier next */
+ if (state == percent_char) {
+ state = mod_char;
+ }
+ p++;
+ break;
+
+ case mod_char:
+ switch (*p) {
+ case 'd':
+ p++;
+ state = just_char;
+ break;
+ default:
+ /* Not a modifier: reprocess this character as plain text */
+ state = just_char;
+ break;
+ }
+ break;
+ }
+ }
+
+
+ if (target == nullptr) {
+ return -1;
+ }
+
+ *target = (gchar *) rspamd_mempool_alloc(task->task_pool, tlen + 1);
+ d = *target;
+ /* `end` is one past the terminator slot, i.e. the size limit for copies below */
+ end = d + tlen + 1;
+ d[tlen] = '\0';
+ p = pattern;
+ state = just_char;
+
+ /* Expand string */
+ while (*p && d < end) {
+ switch (state) {
+ case just_char:
+ if (*p == '%') {
+ state = percent_char;
+ }
+ else {
+ *d++ = *p;
+ }
+ p++;
+ break;
+ case percent_char:
+ switch (*p) {
+ case '%':
+ *d++ = *p;
+ state = just_char;
+ break;
+ case 'u':
+ elt = GET_TASK_ELT(task, auth_user);
+ if (elt) {
+ d += rspamd_strlcpy(d, elt, end - d);
+ }
+ break;
+ case 'r':
+ if (rcpt == nullptr) {
+ elt = rspamd_task_get_principal_recipient(task);
+ }
+ else {
+ elt = rcpt;
+ }
+
+ if (elt) {
+ d += rspamd_strlcpy(d, elt, end - d);
+ }
+ break;
+ case 'l':
+ if (stcf->label) {
+ d += rspamd_strlcpy(d, stcf->label, end - d);
+ }
+ break;
+ case 's':
+ d += rspamd_strlcpy(d, "RS", end - d);
+ break;
+ default:
+ state = just_char;
+ *d++ = *p;
+ break;
+ }
+
+ if (state == percent_char) {
+ state = mod_char;
+ }
+ p++;
+ break;
+
+ case mod_char:
+ switch (*p) {
+ case 'd':
+ /* TODO: not supported yet */
+ p++;
+ state = just_char;
+ break;
+ default:
+ state = just_char;
+ break;
+ }
+ break;
+ }
+ }
+
+ return tlen;
+}
+
+/*
+ * Lua callback that receives a statistics object for a statfile.
+ *
+ * Upvalue 1: cookie naming the config pool variable that holds the backend.
+ * Argument 1: rspamd config; argument 2: the statistics table.
+ *
+ * The imported object is extended with the keys other backends report as
+ * well and then replaces the previously cached object of the backend.
+ * Returns nothing to Lua.
+ */
+static int
+rspamd_redis_stat_cb(lua_State *L)
+{
+	const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
+	auto *cfg = lua_check_config(L, 1);
+	auto *backend = REDIS_CTX(rspamd_mempool_get_variable(cfg->cfg_pool, cookie));
+
+	if (backend == nullptr) {
+		msg_err("internal error: cookie %s is not found", cookie);
+
+		return 0;
+	}
+
+	auto *stat_obj = ucl_object_lua_import(L, 2);
+	msg_debug_bayes_cfg("got stat object for %s", backend->stcf->symbol);
+
+	/* Enrich with some default values that are meaningless for redis */
+	for (const char *zero_key: {"used", "total", "size"}) {
+		ucl_object_insert_key(stat_obj, ucl_object_typed_new(UCL_INT), zero_key, 0, false);
+	}
+
+	ucl_object_insert_key(stat_obj, ucl_object_fromstring(backend->stcf->symbol),
+						  "symbol", 0, false);
+	ucl_object_insert_key(stat_obj, ucl_object_fromstring("redis"),
+						  "type", 0, false);
+	ucl_object_insert_key(stat_obj, ucl_object_fromint(0),
+						  "languages", 0, false);
+
+	/* Swap the cached object, dropping the reference to the old one */
+	auto *previous = backend->cur_stat;
+	backend->cur_stat = stat_obj;
+
+	if (previous) {
+		ucl_object_unref(previous);
+	}
+
+	return 0;
+}
+
+/*
+ * Read the Redis related classifier options into `backend`.
+ *
+ * Options taken from `classifier_obj`:
+ *   per_user / users_enabled  boolean, or a Lua chunk returning function(task)
+ *   prefix                    key pattern; defaults depend on per-user mode
+ *   store_tokens              boolean
+ *   signatures                boolean
+ *
+ * `statfile_obj` is accepted for interface compatibility and is not read.
+ * The Lua stack is left exactly as it was found, also when the user script
+ * fails or returns something that is not a function.
+ */
+static void
+rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend,
+								   const ucl_object_t *statfile_obj,
+								   const ucl_object_t *classifier_obj,
+								   struct rspamd_config *cfg)
+{
+	const gchar *lua_script;
+	const ucl_object_t *elt, *users_enabled;
+	auto *L = RSPAMD_LUA_CFG_STATE(cfg);
+
+	users_enabled = ucl_object_lookup_any(classifier_obj, "per_user",
+										  "users_enabled", nullptr);
+
+	if (users_enabled != nullptr) {
+		if (ucl_object_type(users_enabled) == UCL_BOOLEAN) {
+			backend->enable_users = ucl_object_toboolean(users_enabled);
+			backend->cbref_user = -1;
+		}
+		else if (ucl_object_type(users_enabled) == UCL_STRING) {
+			lua_script = ucl_object_tostring(users_enabled);
+			/* Remember the stack level to drop whatever the chunk leaves behind */
+			const auto old_top = lua_gettop(L);
+
+			if (luaL_dostring(L, lua_script) != 0) {
+				msg_err_config("cannot execute lua script for users "
+							   "extraction: %s",
+							   lua_tostring(L, -1));
+			}
+			else {
+				/* A chunk may return nothing at all; do not peek below our level */
+				const auto ret_type = lua_gettop(L) > old_top ? lua_type(L, -1) : LUA_TNONE;
+
+				if (ret_type == LUA_TFUNCTION) {
+					backend->enable_users = TRUE;
+					backend->cbref_user = luaL_ref(L,
+												   LUA_REGISTRYINDEX);
+				}
+				else {
+					msg_err_config("lua script must return "
+								   "function(task) and not %s",
+								   lua_typename(L, ret_type));
+				}
+			}
+
+			/* Error messages and extra return values must not stay on the stack */
+			lua_settop(L, old_top);
+		}
+	}
+	else {
+		backend->enable_users = FALSE;
+		backend->cbref_user = -1;
+	}
+
+	elt = ucl_object_lookup(classifier_obj, "prefix");
+	if (elt == nullptr || ucl_object_type(elt) != UCL_STRING) {
+		/* No usable prefix: choose the default pattern for the current mode */
+		if (backend->enable_users || backend->cbref_user != -1) {
+			backend->redis_object = REDIS_DEFAULT_USERS_OBJECT;
+		}
+		else {
+			backend->redis_object = REDIS_DEFAULT_OBJECT;
+		}
+	}
+	else {
+		/* XXX: sanity check */
+		backend->redis_object = ucl_object_tostring(elt);
+	}
+
+	elt = ucl_object_lookup(classifier_obj, "store_tokens");
+	backend->store_tokens = elt ? ucl_object_toboolean(elt) : false;
+
+	elt = ucl_object_lookup(classifier_obj, "signatures");
+	backend->enable_signatures = elt ? ucl_object_toboolean(elt) : false;
+}
+
+/*
+ * Create the backend context for a statfile.
+ *
+ * Parses the classifier options, then calls
+ * lua_bayes_redis.lua_bayes_init_statfile with the classifier options, the
+ * statfile options, the symbol, the spam flag, the event loop and a
+ * statistics callback. The two functions returned by Lua (classify, learn)
+ * are stored as registry references in the context.
+ *
+ * Returns the new context (released later by rspamd_redis_close) or nullptr
+ * if the Lua function cannot be found or fails.
+ */
+gpointer
+rspamd_redis_init(struct rspamd_stat_ctx *ctx,
+				  struct rspamd_config *cfg, struct rspamd_statfile *st)
+{
+	auto *L = RSPAMD_LUA_CFG_STATE(cfg);
+
+	auto backend = std::make_unique<struct redis_stat_ctx>(L);
+	lua_settop(L, 0);
+
+	rspamd_redis_parse_classifier_opts(backend.get(), st->stcf->opts, st->classifier->cfg->opts, cfg);
+
+	st->stcf->clcf->flags |= RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND;
+	backend->stcf = st->stcf;
+
+	lua_pushcfunction(L, &rspamd_lua_traceback);
+	auto err_idx = lua_gettop(L);
+
+	/* Obtain function */
+	if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_statfile")) {
+		msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_statfile");
+		lua_settop(L, err_idx - 1);
+
+		return nullptr;
+	}
+
+	/* Push arguments */
+	ucl_object_push_lua(L, st->classifier->cfg->opts, false);
+	ucl_object_push_lua(L, st->stcf->opts, false);
+	lua_pushstring(L, backend->stcf->symbol);
+	lua_pushboolean(L, backend->stcf->is_spam);
+	auto **pev_base = (struct ev_loop **) lua_newuserdata(L, sizeof(struct ev_loop *));
+	*pev_base = ctx->event_loop;
+	rspamd_lua_setclass(L, "rspamd{ev_base}", -1);
+
+	/* Store backend in random cookie */
+	char *cookie = (char *) rspamd_mempool_alloc(cfg->cfg_pool, 16);
+	rspamd_random_hex(cookie, 16);
+	cookie[15] = '\0';
+	rspamd_mempool_set_variable(cfg->cfg_pool, cookie, backend.get(), nullptr);
+	/* Callback + 1 upvalue */
+	lua_pushstring(L, cookie);
+	lua_pushcclosure(L, &rspamd_redis_stat_cb, 1);
+
+	if (lua_pcall(L, 6, 2, err_idx) != 0) {
+		msg_err("call to lua_bayes_init_statfile "
+				"script failed: %s",
+				lua_tostring(L, -1));
+		lua_settop(L, err_idx - 1);
+		/*
+		 * The backend is destroyed on return; make sure the cookie no longer
+		 * refers to it, so a late statistics callback finds nothing instead
+		 * of a dangling pointer.
+		 */
+		rspamd_mempool_set_variable(cfg->cfg_pool, cookie, nullptr, nullptr);
+
+		return nullptr;
+	}
+
+	/* Results are in the stack:
+	 * top - 1 - classifier function (idx = -2)
+	 * top - learn function (idx = -1)
+	 */
+
+	lua_pushvalue(L, -2);
+	backend->cbref_classify = luaL_ref(L, LUA_REGISTRYINDEX);
+
+	lua_pushvalue(L, -1);
+	backend->cbref_learn = luaL_ref(L, LUA_REGISTRYINDEX);
+
+	lua_settop(L, err_idx - 1);
+
+	return backend.release();
+}
+
+/*
+ * Obtain the per-task runtime for a statfile.
+ *
+ * The key pattern is expanded first; an empty expansion is an error and
+ * yields nullptr. For learning a fresh runtime is always created. For
+ * classification a runtime cached in the task pool is reused if present;
+ * otherwise a new one is created together with a placeholder runtime for the
+ * opposite class (marked as not needing its own Redis call), so that one
+ * script call serves both classes.
+ */
+gpointer
+rspamd_redis_runtime(struct rspamd_task *task,
+					 struct rspamd_statfile_config *stcf,
+					 gboolean learn, gpointer c, gint _id)
+{
+	using runtime_t = redis_stat_runtime<float>;
+
+	struct redis_stat_ctx *ctx = REDIS_CTX(c);
+	char *object_expanded = nullptr;
+
+	g_assert(ctx != nullptr);
+	g_assert(stcf != nullptr);
+
+	const auto expanded_len = rspamd_redis_expand_object(ctx->redis_object, ctx, task,
+														 &object_expanded);
+
+	if (expanded_len == 0) {
+		msg_err_task("expansion for %s failed for symbol %s "
+					 "(maybe learning per user classifier with no user or recipient)",
+					 learn ? "learning" : "classifying",
+					 stcf->symbol);
+		return nullptr;
+	}
+
+	if (learn) {
+		/* Learning never reuses cached runtimes */
+		auto *learn_rt = new runtime_t(ctx, task, object_expanded);
+		learn_rt->save_in_mempool(stcf->is_spam);
+
+		return learn_rt;
+	}
+
+	/* Look for the cached results */
+	if (auto cached = runtime_t::maybe_recover_from_mempool(task, object_expanded, stcf->is_spam)) {
+		auto *cached_rt = *cached;
+		/* Update stcf and ctx to correspond to what we have been asked */
+		cached_rt->stcf = stcf;
+		cached_rt->ctx = ctx;
+
+		return cached_rt;
+	}
+
+	/* No cached result, create new one */
+	auto *rt = new runtime_t(ctx, task, object_expanded);
+
+	/*
+	 * The opposite class runtime is created here as well to avoid a second
+	 * call of the Redis scripts; it is filled when this one is classified.
+	 */
+	if (!runtime_t::maybe_recover_from_mempool(task, object_expanded, !stcf->is_spam)) {
+		auto *opposite_rt = new runtime_t(ctx, task, object_expanded);
+		opposite_rt->save_in_mempool(!stcf->is_spam);
+		opposite_rt->need_redis_call = false;
+	}
+
+	rt->save_in_mempool(stcf->is_spam);
+
+	return rt;
+}
+
+/* Destroy a backend context previously returned by rspamd_redis_init */
+void rspamd_redis_close(gpointer p)
+{
+	delete REDIS_CTX(p);
+}
+
+/*
+ * Write `st` to `out` as a MessagePack string: a header selected by the
+ * length (fixstr, str 8, str 16 or str 32, lengths in big-endian order)
+ * followed by the raw bytes. The caller provides enough room in `out`.
+ *
+ * Returns the number of bytes written (header plus payload).
+ */
+static constexpr auto
+msgpack_emit_str(const std::string_view st, char *out) -> std::size_t
+{
+	constexpr const unsigned char fix_mask = 0xA0, l8_ch = 0xd9, l16_ch = 0xda, l32_ch = 0xdb;
+	const auto len = st.size();
+	char *cursor = out;
+
+	if (len <= 0x1F) {
+		/* fixstr: the length lives in the low bits of the marker itself */
+		*cursor++ = (len | fix_mask) & 0xff;
+	}
+	else if (len <= 0xff) {
+		*cursor++ = l8_ch;
+		*cursor++ = len & 0xff;
+	}
+	else if (len <= 0xffff) {
+		uint16_t be_len = GUINT16_TO_BE(len);
+
+		*cursor++ = l16_ch;
+		memcpy(cursor, &be_len, sizeof(be_len));
+		cursor += sizeof(be_len);
+	}
+	else {
+		uint32_t be_len = GUINT32_TO_BE(len);
+
+		*cursor++ = l32_ch;
+		memcpy(cursor, &be_len, sizeof(be_len));
+		cursor += sizeof(be_len);
+	}
+
+	memcpy(cursor, st.data(), len);
+
+	return (cursor - out) + len;
+}
+
+/*
+ * Number of bytes msgpack_emit_str produces for a string of `len` bytes:
+ * the payload plus a header of 1 (fixstr), 2 (str 8), 3 (str 16) or
+ * 5 (str 32: marker and a 4-byte length) bytes.
+ */
+static constexpr auto
+msgpack_str_len(std::size_t len) -> std::size_t
+{
+	if (len <= 0x1F) {
+		return 1 + len;
+	}
+	else if (len <= 0xff) {
+		return 2 + len;
+	}
+	else if (len <= 0xffff) {
+		return 3 + len;
+	}
+	else {
+		/* Marker byte + 32-bit length; must match msgpack_emit_str */
+		return 5 + len;
+	}
+}
+
+/*
+ * Serialise stat tokens to a MessagePack array of strings.
+ *
+ * Every token becomes the string "<prefix>_<token data as decimal>". The
+ * buffer is allocated from the task pool; its used size is stored in
+ * `ser_len`.
+ */
+static char *
+rspamd_redis_serialize_tokens(struct rspamd_task *task, const gchar *prefix, GPtrArray *tokens, gsize *ser_len)
+{
+	/* Widest decimal form of a 64-bit value; sizeof() also counts the NUL */
+	char max_int64_str[] = "18446744073709551615";
+	const auto prefix_len = strlen(prefix);
+	/* Upper bound for one serialised key plus one spare byte */
+	const auto per_token = msgpack_str_len(sizeof(max_int64_str) + prefix_len) + 1;
+	/* 5 bytes are needed for the array marker and its 32-bit length */
+	std::size_t req_len = 5;
+	req_len += tokens->len * per_token;
+
+	auto *buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len);
+	auto *out = buf;
+
+	/* Array marker followed by the element count in big-endian order */
+	*out++ = (gchar) 0xdd;
+
+	for (int shift: {24, 16, 8, 0}) {
+		*out++ = (gchar) ((tokens->len >> shift) & 0xff);
+	}
+
+	/* Scratch buffer for the textual key of a single token */
+	const auto numbuf_len = sizeof(max_int64_str) + prefix_len + 1;
+	auto *numbuf = (char *) g_alloca(numbuf_len);
+	rspamd_token_t *tok;
+	int i;
+
+	PTR_ARRAY_FOREACH(tokens, i, tok)
+	{
+		std::size_t key_len = rspamd_snprintf(numbuf, numbuf_len, "%s_%uL", prefix, tok->data);
+		out += msgpack_emit_str({numbuf, key_len}, out);
+	}
+
+	*ser_len = out - buf;
+
+	return buf;
+}
+
+/*
+ * Serialise the textual form of tokens to a MessagePack array.
+ *
+ * Each token contributes two elements: the stemmed text of its first and
+ * second word, with nil standing in for a missing word. The buffer is
+ * allocated from the task pool; its used size is stored in `ser_len`.
+ */
+static char *
+rspamd_redis_serialize_text_tokens(struct rspamd_task *task, GPtrArray *tokens, gsize *ser_len)
+{
+	rspamd_token_t *tok;
+	/* Byte count, so keep it unsigned and wide; starts with the array prefix */
+	std::size_t req_len = 5;
+	int i;
+
+	/*
+	 * First we need to determine the requested length
+	 */
+	PTR_ARRAY_FOREACH(tokens, i, tok)
+	{
+		if (tok->t1 && tok->t2) {
+			/* Two tokens */
+			req_len += msgpack_str_len(tok->t1->stemmed.len) + msgpack_str_len(tok->t2->stemmed.len);
+		}
+		else if (tok->t1) {
+			req_len += msgpack_str_len(tok->t1->stemmed.len);
+			req_len += 1; /* null */
+		}
+		else {
+			req_len += 2; /* 2 nulls */
+		}
+	}
+
+	auto *buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len);
+	auto *p = buf;
+
+	/* Array marker and element count (two per token) in big-endian order */
+	std::uint32_t nlen = tokens->len * 2;
+	nlen = GUINT32_TO_BE(nlen);
+	*p++ = (gchar) 0xdd;
+	memcpy(p, &nlen, sizeof(nlen));
+	p += sizeof(nlen);
+
+	PTR_ARRAY_FOREACH(tokens, i, tok)
+	{
+		if (tok->t1 && tok->t2) {
+			auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p);
+			p += step;
+			step = msgpack_emit_str({tok->t2->stemmed.begin, tok->t2->stemmed.len}, p);
+			p += step;
+		}
+		else if (tok->t1) {
+			auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p);
+			p += step;
+			*p++ = (gchar) 0xc0; /* nil */
+		}
+		else {
+			*p++ = (gchar) 0xc0; /* nil */
+			*p++ = (gchar) 0xc0; /* nil */
+		}
+	}
+
+	*ser_len = p - buf;
+
+	return buf;
+}
+
+/*
+ * Lua callback invoked with the outcome of a classification request.
+ *
+ * Upvalue 1: cookie naming the task pool variable that holds the runtime.
+ * Arguments: 1 - task, 2 - success flag; on success 3 - learned_ham,
+ * 4 - learned_spam, 5 - ham tokens, 6 - spam tokens (tables of
+ * {index, value} pairs); on failure 3 - error message.
+ *
+ * On success both this runtime and the runtime of the opposite class are
+ * filled and their values are written to the token array. On failure the
+ * error is recorded in the runtime. Returns nothing to Lua.
+ */
+static gint
+rspamd_redis_classified(lua_State *L)
+{
+	const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
+	auto *task = lua_check_task(L, 1);
+	auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
+
+	if (rt == nullptr) {
+		msg_err_task("internal error: cannot find runtime for cookie %s", cookie);
+
+		return 0;
+	}
+
+	bool result = lua_toboolean(L, 2);
+
+	if (result) {
+		/* Indexes:
+		 * 3 - learned_ham (int)
+		 * 4 - learned_spam (int)
+		 * 5 - ham_tokens (pair<int, int>)
+		 * 6 - spam_tokens (pair<int, int>)
+		 */
+
+		/*
+		 * We need to fill our runtime AND the opposite runtime
+		 */
+		auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, std::uint64_t learned, int tokens_pos) {
+			rt->learned = learned;
+			redis_stat_runtime<float>::result_type *res;
+
+			res = new redis_stat_runtime<float>::result_type();
+
+			/* lua_next and lua_rawgeti require tables; skip anything else */
+			if (lua_istable(L, tokens_pos)) {
+				for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) {
+					if (!lua_istable(L, -1)) {
+						continue;
+					}
+
+					lua_rawgeti(L, -1, 1);
+					auto idx = lua_tointeger(L, -1);
+					lua_pop(L, 1);
+
+					lua_rawgeti(L, -1, 2);
+					auto value = lua_tonumber(L, -1);
+					lua_pop(L, 1);
+
+					res->emplace_back(idx, value);
+				}
+			}
+
+			rt->set_results(res);
+		};
+
+		auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task,
+																					   rt->redis_object_expanded,
+																					   !rt->stcf->is_spam);
+
+		if (!opposite_rt_maybe) {
+			msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie);
+
+			return 0;
+		}
+
+		if (rt->stcf->is_spam) {
+			filler_func(rt, L, lua_tointeger(L, 4), 6);
+			filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5);
+		}
+		else {
+			filler_func(rt, L, lua_tointeger(L, 3), 5);
+			filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6);
+		}
+
+		/* Mark task as being processed */
+		task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS | RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
+
+		/* Process all tokens */
+		g_assert(rt->tokens != nullptr);
+		rt->process_tokens(rt->tokens);
+		opposite_rt_maybe.value()->process_tokens(rt->tokens);
+	}
+	else {
+		/* Error message is on index 3 */
+		const char *err_msg = lua_tostring(L, 3);
+
+		/*
+		 * The string returned by lua_tostring is only guaranteed to stay
+		 * valid while the value is on the Lua stack, and it is NULL when
+		 * no message was passed. The error outlives this callback, so keep
+		 * a copy in the task pool (or a static fallback).
+		 */
+		err_msg = err_msg != nullptr ? rspamd_mempool_strdup(task->task_pool, err_msg) : "unknown error";
+		rt->err = rspamd::util::error(err_msg, 500);
+		msg_err_task("cannot classify task: %s",
+					 err_msg);
+	}
+
+	return 0;
+}
+
+/*
+ * Start classification of `tokens` for the statfile runtime `p`.
+ *
+ * Returns FALSE when the session is blocked, there are no tokens or the Lua
+ * call fails. A runtime that does not need its own Redis call only records
+ * the id and the token array. Otherwise the tokens are serialised and the
+ * classify function is called with (task, expanded key, id, is_spam,
+ * serialised tokens, callback); results arrive in rspamd_redis_classified.
+ */
+gboolean
+rspamd_redis_process_tokens(struct rspamd_task *task,
+							GPtrArray *tokens,
+							gint id, gpointer p)
+{
+	auto *rt = REDIS_RUNTIME(p);
+	auto *L = rt->ctx->L;
+
+	if (rspamd_session_blocked(task->s) || tokens == nullptr || tokens->len == 0) {
+		return FALSE;
+	}
+
+	/* The id is needed for further tokens processing in either case */
+	rt->id = id;
+
+	if (!rt->need_redis_call) {
+		/* No need to do anything else, as it is done in the opposite class processing */
+		rt->tokens = g_ptr_array_ref(tokens);
+
+		return TRUE;
+	}
+
+	gsize serialized_len;
+	gchar *serialized = rspamd_redis_serialize_tokens(task, rt->redis_object_expanded, tokens, &serialized_len);
+
+	/* Store rt in random cookie so the callback can find it again */
+	char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16);
+	rspamd_random_hex(cookie, 16);
+	cookie[15] = '\0';
+	rspamd_mempool_set_variable(task->task_pool, cookie, rt, nullptr);
+
+	lua_pushcfunction(L, &rspamd_lua_traceback);
+	const gint err_idx = lua_gettop(L);
+
+	/* Function and its six arguments */
+	lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_classify);
+	rspamd_lua_task_push(L, task);
+	lua_pushstring(L, rt->redis_object_expanded);
+	lua_pushinteger(L, id);
+	lua_pushboolean(L, rt->stcf->is_spam);
+	lua_new_text(L, serialized, serialized_len, false);
+	/* Callback */
+	lua_pushstring(L, cookie);
+	lua_pushcclosure(L, &rspamd_redis_classified, 1);
+
+	const bool call_ok = lua_pcall(L, 6, 0, err_idx) == 0;
+
+	if (call_ok) {
+		rt->tokens = g_ptr_array_ref(tokens);
+	}
+	else {
+		msg_err_task("call to redis failed: %s", lua_tostring(L, -1));
+	}
+
+	lua_settop(L, err_idx - 1);
+
+	return call_ok ? TRUE : FALSE;
+}
+
+gboolean
+rspamd_redis_finalize_process(struct rspamd_task *task, gpointer runtime,
+ gpointer ctx)
+{
+ auto *rt = REDIS_RUNTIME(runtime);
+
+ return !rt->err.has_value();
+}
+
+
+/*
+ * Lua callback invoked with the outcome of a learn request.
+ *
+ * Upvalue 1: cookie naming the task pool variable that holds the runtime.
+ * Arguments: 1 - task, 2 - success flag, 3 - error message on failure.
+ *
+ * A failure is recorded in the runtime and reported later by
+ * rspamd_redis_finalize_learn. Returns nothing to Lua.
+ */
+static gint
+rspamd_redis_learned(lua_State *L)
+{
+	const auto *cookie = lua_tostring(L, lua_upvalueindex(1));
+	auto *task = lua_check_task(L, 1);
+	auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie));
+
+	if (rt == nullptr) {
+		msg_err_task("internal error: cannot find runtime for cookie %s", cookie);
+
+		return 0;
+	}
+
+	bool result = lua_toboolean(L, 2);
+
+	if (result) {
+		/* TODO: write it */
+	}
+	else {
+		/* Error message is on index 3 */
+		const char *err_msg = lua_tostring(L, 3);
+
+		/*
+		 * The string returned by lua_tostring is only guaranteed to stay
+		 * valid while the value is on the Lua stack, and it is NULL when
+		 * no message was passed. The error outlives this callback, so keep
+		 * a copy in the task pool (or a static fallback).
+		 */
+		err_msg = err_msg != nullptr ? rspamd_mempool_strdup(task->task_pool, err_msg) : "unknown error";
+		rt->err = rspamd::util::error(err_msg, 500);
+		msg_err_task("cannot learn task: %s", err_msg);
+	}
+
+	return 0;
+}
+
+/*
+ * Start learning of `tokens` for the statfile runtime `p`.
+ *
+ * Returns FALSE when the session is blocked, there are no tokens or the Lua
+ * call fails. The learn function is called with (task, expanded key, id,
+ * is_spam, symbol, is_unlearn, serialised tokens, callback) and, when token
+ * storing is enabled, the serialised text tokens as a ninth argument. The
+ * outcome arrives in rspamd_redis_learned.
+ */
+gboolean
+rspamd_redis_learn_tokens(struct rspamd_task *task,
+						  GPtrArray *tokens,
+						  gint id, gpointer p)
+{
+	auto *rt = REDIS_RUNTIME(p);
+	auto *L = rt->ctx->L;
+
+	if (rspamd_session_blocked(task->s)) {
+		return FALSE;
+	}
+
+	if (tokens == nullptr || tokens->len == 0) {
+		return FALSE;
+	}
+
+	gsize tokens_len;
+	gchar *tokens_buf = rspamd_redis_serialize_tokens(task, rt->redis_object_expanded, tokens, &tokens_len);
+
+	rt->id = id;
+
+	gsize text_tokens_len = 0;
+	gchar *text_tokens_buf = nullptr;
+
+	if (rt->ctx->store_tokens) {
+		text_tokens_buf = rspamd_redis_serialize_text_tokens(task, tokens, &text_tokens_len);
+	}
+
+	lua_pushcfunction(L, &rspamd_lua_traceback);
+	gint err_idx = lua_gettop(L);
+	auto nargs = 8;
+
+	/* Function arguments */
+	lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_learn);
+	rspamd_lua_task_push(L, task);
+	lua_pushstring(L, rt->redis_object_expanded);
+	lua_pushinteger(L, id);
+	lua_pushboolean(L, rt->stcf->is_spam);
+	lua_pushstring(L, rt->stcf->symbol);
+
+	/*
+	 * Detect unlearn from the value of the first token. Use the array that
+	 * was validated above rather than task->tokens, which is not checked
+	 * in this function.
+	 */
+	auto *tok = (rspamd_token_t *) g_ptr_array_index(tokens, 0);
+
+	if (tok->values[id] > 0) {
+		lua_pushboolean(L, FALSE); /* Learn */
+	}
+	else {
+		lua_pushboolean(L, TRUE); /* Unlearn */
+	}
+	lua_new_text(L, tokens_buf, tokens_len, false);
+
+	/* Store rt in random cookie */
+	char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16);
+	rspamd_random_hex(cookie, 16);
+	cookie[15] = '\0';
+	rspamd_mempool_set_variable(task->task_pool, cookie, rt, nullptr);
+	/* Callback */
+	lua_pushstring(L, cookie);
+	lua_pushcclosure(L, &rspamd_redis_learned, 1);
+
+	if (text_tokens_len) {
+		/* Optional ninth argument: the text form of the tokens */
+		nargs = 9;
+		lua_new_text(L, text_tokens_buf, text_tokens_len, false);
+	}
+
+	if (lua_pcall(L, nargs, 0, err_idx) != 0) {
+		msg_err_task("call to script failed: %s", lua_tostring(L, -1));
+		lua_settop(L, err_idx - 1);
+		return FALSE;
+	}
+
+	rt->tokens = g_ptr_array_ref(tokens);
+
+	lua_settop(L, err_idx - 1);
+	return TRUE;
+}
+
+
+/*
+ * Report the outcome of learning: TRUE when no error was recorded in the
+ * runtime, otherwise the error is converted into `err` and FALSE is returned.
+ */
+gboolean
+rspamd_redis_finalize_learn(struct rspamd_task *task, gpointer runtime,
+							gpointer ctx, GError **err)
+{
+	auto *rt = REDIS_RUNTIME(runtime);
+
+	if (!rt->err.has_value()) {
+		return TRUE;
+	}
+
+	rt->err->into_g_error_set(rspamd_redis_stat_quark(), err);
+
+	return FALSE;
+}
+
+/* Number of learns stored in the runtime by the classification callback */
+gulong
+rspamd_redis_total_learns(struct rspamd_task *task, gpointer runtime,
+						  gpointer ctx)
+{
+	return REDIS_RUNTIME(runtime)->learned;
+}
+
+/*
+ * Value of the learns counter after one more learn. The runtime itself is
+ * not modified; only the computed value is returned.
+ */
+gulong
+rspamd_redis_inc_learns(struct rspamd_task *task, gpointer runtime,
+						gpointer ctx)
+{
+	const auto *rt = REDIS_RUNTIME(runtime);
+	const auto current = rt->learned;
+
+	/* XXX: may cause races */
+	return current + 1;
+}
+
+/*
+ * Value of the learns counter after one learn has been removed. The runtime
+ * itself is not modified; only the computed value is returned, and it never
+ * goes below zero.
+ */
+gulong
+rspamd_redis_dec_learns(struct rspamd_task *task, gpointer runtime,
+						gpointer ctx)
+{
+	auto *rt = REDIS_RUNTIME(runtime);
+
+	/* XXX: may cause races */
+	return rt->learned > 0 ? rt->learned - 1 : 0;
+}
+
+/* Learns counter as currently stored in the runtime */
+gulong
+rspamd_redis_learns(struct rspamd_task *task, gpointer runtime,
+					gpointer ctx)
+{
+	return REDIS_RUNTIME(runtime)->learned;
+}
+
+/*
+ * Return a new reference to the statistics object cached in the backend
+ * context of this runtime.
+ */
+ucl_object_t *
+rspamd_redis_get_stat(gpointer runtime,
+					  gpointer ctx)
+{
+	auto *backend = REDIS_RUNTIME(runtime)->ctx;
+
+	return ucl_object_ref(backend->cur_stat);
+}
+
+/*
+ * Tokenizer configuration hook of the Redis backend.
+ * It always returns nullptr; `runtime` is ignored and `len` is left untouched.
+ */
+gpointer
+rspamd_redis_load_tokenizer_config(gpointer runtime,
+ gsize *len)
+{
+ return nullptr;
+}
diff --git a/src/libstat/backends/sqlite3_backend.c b/src/libstat/backends/sqlite3_backend.c
new file mode 100644
index 0000000..2fd34d8
--- /dev/null
+++ b/src/libstat/backends/sqlite3_backend.c
@@ -0,0 +1,907 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "config.h"
+#include "rspamd.h"
+#include "sqlite3.h"
+#include "libutil/sqlite_utils.h"
+#include "libstat/stat_internal.h"
+#include "libmime/message.h"
+#include "lua/lua_common.h"
+#include "unix-std.h"
+
+/* Name of this backend type */
+#define SQLITE3_BACKEND_TYPE "sqlite3"
+/* Value written to `PRAGMA user_version` when the schema is created */
+#define SQLITE3_SCHEMA_VERSION "1"
+/* Name of the row with id 0 inserted into both `users` and `languages` */
+#define SQLITE3_DEFAULT "default"
+
+/*
+ * Per-statfile backend context: one open sqlite3 database together with its
+ * prepared statements and the per-user / per-language configuration.
+ */
+struct rspamd_stat_sqlite3_db {
+ sqlite3 *sqlite; /* open database handle */
+ gchar *fname; /* g_strdup'ed copy of the database path */
+ GArray *prstmt; /* prepared statements (see prepared_stmts) */
+ lua_State *L; /* Lua state used to run the extraction callbacks */
+ rspamd_mempool_t *pool; /* pool passed to rspamd_sqlite3_opendb */
+ gboolean in_transaction; /* TRUE while a transaction opened by this backend is pending */
+ gboolean enable_users; /* per-user statistics enabled */
+ gboolean enable_languages; /* per-language statistics enabled */
+ gint cbref_user; /* Lua registry ref of the user callback, -1 if none */
+ gint cbref_language; /* Lua registry ref of the language callback, -1 if none */
+};
+
+/*
+ * Per-task runtime created by rspamd_sqlite3_runtime and allocated from the
+ * task memory pool.
+ */
+struct rspamd_stat_sqlite3_rt {
+ struct rspamd_task *task; /* task this runtime belongs to */
+ struct rspamd_stat_sqlite3_db *db; /* backend context */
+ struct rspamd_statfile_config *cf; /* statfile configuration */
+ gint64 user_id; /* resolved users.id, -1 until resolved */
+ gint64 lang_id; /* resolved languages.id, -1 until resolved */
+};
+
+/*
+ * Schema creation script passed to rspamd_sqlite3_open_or_create.
+ * Within a single immediate transaction it creates the `tokenizer`, `users`,
+ * `languages` and `tokens` tables plus their indexes, stores the schema
+ * version in `user_version` and inserts the default user and language rows
+ * (both with id 0).
+ */
+static const char *create_tables_sql =
+ "BEGIN IMMEDIATE;"
+ "CREATE TABLE tokenizer(data BLOB);"
+ "CREATE TABLE users("
+ "id INTEGER PRIMARY KEY,"
+ "name TEXT,"
+ "learns INTEGER"
+ ");"
+ "CREATE TABLE languages("
+ "id INTEGER PRIMARY KEY,"
+ "name TEXT,"
+ "learns INTEGER"
+ ");"
+ "CREATE TABLE tokens("
+ "token INTEGER NOT NULL,"
+ "user INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,"
+ "language INTEGER NOT NULL REFERENCES languages(id) ON DELETE CASCADE,"
+ "value INTEGER,"
+ "modified INTEGER,"
+ "CONSTRAINT tid UNIQUE (token, user, language) ON CONFLICT REPLACE"
+ ");"
+ "CREATE UNIQUE INDEX IF NOT EXISTS un ON users(name);"
+ "CREATE INDEX IF NOT EXISTS tok ON tokens(token);"
+ "CREATE UNIQUE INDEX IF NOT EXISTS ln ON languages(name);"
+ "PRAGMA user_version=" SQLITE3_SCHEMA_VERSION ";"
+ "INSERT INTO users(id, name, learns) VALUES(0, '" SQLITE3_DEFAULT "',0);"
+ "INSERT INTO languages(id, name, learns) VALUES(0, '" SQLITE3_DEFAULT "',0);"
+ "COMMIT;";
+
+/*
+ * Indexes of the prepared statements in `prepared_stmts`. The enumerators are
+ * used as designated array indexes there, and RSPAMD_STAT_BACKEND_MAX is the
+ * array size, so it must remain the last enumerator.
+ */
+enum rspamd_stat_sqlite3_stmt_idx {
+ RSPAMD_STAT_BACKEND_TRANSACTION_START_IM = 0,
+ RSPAMD_STAT_BACKEND_TRANSACTION_START_DEF,
+ RSPAMD_STAT_BACKEND_TRANSACTION_START_EXCL,
+ RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT,
+ RSPAMD_STAT_BACKEND_TRANSACTION_ROLLBACK,
+ RSPAMD_STAT_BACKEND_GET_TOKEN_FULL,
+ RSPAMD_STAT_BACKEND_GET_TOKEN_SIMPLE,
+ RSPAMD_STAT_BACKEND_SET_TOKEN,
+ RSPAMD_STAT_BACKEND_INC_LEARNS_LANG,
+ RSPAMD_STAT_BACKEND_INC_LEARNS_USER,
+ RSPAMD_STAT_BACKEND_DEC_LEARNS_LANG,
+ RSPAMD_STAT_BACKEND_DEC_LEARNS_USER,
+ RSPAMD_STAT_BACKEND_GET_LEARNS,
+ RSPAMD_STAT_BACKEND_GET_LANGUAGE,
+ RSPAMD_STAT_BACKEND_GET_USER,
+ RSPAMD_STAT_BACKEND_INSERT_USER,
+ RSPAMD_STAT_BACKEND_INSERT_LANGUAGE,
+ RSPAMD_STAT_BACKEND_SAVE_TOKENIZER,
+ RSPAMD_STAT_BACKEND_LOAD_TOKENIZER,
+ RSPAMD_STAT_BACKEND_NTOKENS,
+ RSPAMD_STAT_BACKEND_NLANGUAGES,
+ RSPAMD_STAT_BACKEND_NUSERS,
+ RSPAMD_STAT_BACKEND_MAX
+};
+
+/*
+ * Prepared statement templates, indexed by enum rspamd_stat_sqlite3_stmt_idx.
+ * Each entry carries the SQL text, an argument signature (`args`), the sqlite
+ * step result that counts as success (`result`) and a result signature
+ * (`ret`).
+ * NOTE(review): the signature letters appear to mean I = 64-bit integer,
+ * T = text, B = blob (length + data) and L = last inserted row id; confirm
+ * against rspamd_sqlite3_run_prstmt in libutil/sqlite_utils.
+ */
+static struct rspamd_sqlite3_prstmt prepared_stmts[RSPAMD_STAT_BACKEND_MAX] =
+ {
+ [RSPAMD_STAT_BACKEND_TRANSACTION_START_IM] = {
+ .idx = RSPAMD_STAT_BACKEND_TRANSACTION_START_IM,
+ .sql = "BEGIN IMMEDIATE TRANSACTION;",
+ .args = "",
+ .stmt = NULL,
+ .result = SQLITE_DONE,
+ .flags = 0,
+ .ret = "",
+ },
+ [RSPAMD_STAT_BACKEND_TRANSACTION_START_DEF] = {.idx = RSPAMD_STAT_BACKEND_TRANSACTION_START_DEF, .sql = "BEGIN DEFERRED TRANSACTION;", .args = "", .stmt = NULL, .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ [RSPAMD_STAT_BACKEND_TRANSACTION_START_EXCL] = {.idx = RSPAMD_STAT_BACKEND_TRANSACTION_START_EXCL, .sql = "BEGIN EXCLUSIVE TRANSACTION;", .args = "", .stmt = NULL, .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ [RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT] = {.idx = RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT, .sql = "COMMIT;", .args = "", .stmt = NULL, .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ [RSPAMD_STAT_BACKEND_TRANSACTION_ROLLBACK] = {.idx = RSPAMD_STAT_BACKEND_TRANSACTION_ROLLBACK, .sql = "ROLLBACK;", .args = "", .stmt = NULL, .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ /* Token lookup restricted to a user and to a language (or the default language 0) */
+ [RSPAMD_STAT_BACKEND_GET_TOKEN_FULL] = {.idx = RSPAMD_STAT_BACKEND_GET_TOKEN_FULL, .sql = "SELECT value FROM tokens "
+ "LEFT JOIN languages ON tokens.language=languages.id "
+ "LEFT JOIN users ON tokens.user=users.id "
+ "WHERE token=?1 AND (users.id=?2) "
+ "AND (languages.id=?3 OR languages.id=0);",
+ .stmt = NULL,
+ .args = "III",
+ .result = SQLITE_ROW,
+ .flags = 0,
+ .ret = "I"},
+ /* Token lookup that ignores user and language */
+ [RSPAMD_STAT_BACKEND_GET_TOKEN_SIMPLE] = {.idx = RSPAMD_STAT_BACKEND_GET_TOKEN_SIMPLE, .sql = "SELECT value FROM tokens WHERE token=?1", .stmt = NULL, .args = "I", .result = SQLITE_ROW, .flags = 0, .ret = "I"},
+ [RSPAMD_STAT_BACKEND_SET_TOKEN] = {.idx = RSPAMD_STAT_BACKEND_SET_TOKEN, .sql = "INSERT OR REPLACE INTO tokens (token, user, language, value, modified) "
+ "VALUES (?1, ?2, ?3, ?4, strftime('%s','now'))",
+ .stmt = NULL,
+ .args = "IIII",
+ .result = SQLITE_DONE,
+ .flags = 0,
+ .ret = ""},
+ [RSPAMD_STAT_BACKEND_INC_LEARNS_LANG] = {.idx = RSPAMD_STAT_BACKEND_INC_LEARNS_LANG, .sql = "UPDATE languages SET learns=learns + 1 WHERE id=?1", .stmt = NULL, .args = "I", .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ [RSPAMD_STAT_BACKEND_INC_LEARNS_USER] = {.idx = RSPAMD_STAT_BACKEND_INC_LEARNS_USER, .sql = "UPDATE users SET learns=learns + 1 WHERE id=?1", .stmt = NULL, .args = "I", .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ [RSPAMD_STAT_BACKEND_DEC_LEARNS_LANG] = {.idx = RSPAMD_STAT_BACKEND_DEC_LEARNS_LANG, .sql = "UPDATE languages SET learns=MAX(0, learns - 1) WHERE id=?1", .stmt = NULL, .args = "I", .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ [RSPAMD_STAT_BACKEND_DEC_LEARNS_USER] = {.idx = RSPAMD_STAT_BACKEND_DEC_LEARNS_USER, .sql = "UPDATE users SET learns=MAX(0, learns - 1) WHERE id=?1", .stmt = NULL, .args = "I", .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ /* Total learns are computed over the languages table only */
+ [RSPAMD_STAT_BACKEND_GET_LEARNS] = {.idx = RSPAMD_STAT_BACKEND_GET_LEARNS, .sql = "SELECT SUM(MAX(0, learns)) FROM languages", .stmt = NULL, .args = "", .result = SQLITE_ROW, .flags = 0, .ret = "I"},
+ [RSPAMD_STAT_BACKEND_GET_LANGUAGE] = {.idx = RSPAMD_STAT_BACKEND_GET_LANGUAGE, .sql = "SELECT id FROM languages WHERE name=?1", .stmt = NULL, .args = "T", .result = SQLITE_ROW, .flags = 0, .ret = "I"},
+ [RSPAMD_STAT_BACKEND_GET_USER] = {.idx = RSPAMD_STAT_BACKEND_GET_USER, .sql = "SELECT id FROM users WHERE name=?1", .stmt = NULL, .args = "T", .result = SQLITE_ROW, .flags = 0, .ret = "I"},
+ [RSPAMD_STAT_BACKEND_INSERT_USER] = {.idx = RSPAMD_STAT_BACKEND_INSERT_USER, .sql = "INSERT INTO users (name, learns) VALUES (?1, 0)", .stmt = NULL, .args = "T", .result = SQLITE_DONE, .flags = 0, .ret = "L"},
+ [RSPAMD_STAT_BACKEND_INSERT_LANGUAGE] = {.idx = RSPAMD_STAT_BACKEND_INSERT_LANGUAGE, .sql = "INSERT INTO languages (name, learns) VALUES (?1, 0)", .stmt = NULL, .args = "T", .result = SQLITE_DONE, .flags = 0, .ret = "L"},
+ [RSPAMD_STAT_BACKEND_SAVE_TOKENIZER] = {.idx = RSPAMD_STAT_BACKEND_SAVE_TOKENIZER, .sql = "INSERT INTO tokenizer(data) VALUES (?1)", .stmt = NULL, .args = "B", .result = SQLITE_DONE, .flags = 0, .ret = ""},
+ [RSPAMD_STAT_BACKEND_LOAD_TOKENIZER] = {.idx = RSPAMD_STAT_BACKEND_LOAD_TOKENIZER, .sql = "SELECT data FROM tokenizer", .stmt = NULL, .args = "", .result = SQLITE_ROW, .flags = 0, .ret = "B"},
+ [RSPAMD_STAT_BACKEND_NTOKENS] = {.idx = RSPAMD_STAT_BACKEND_NTOKENS, .sql = "SELECT COUNT(*) FROM tokens", .stmt = NULL, .args = "", .result = SQLITE_ROW, .flags = 0, .ret = "I"},
+ [RSPAMD_STAT_BACKEND_NLANGUAGES] = {.idx = RSPAMD_STAT_BACKEND_NLANGUAGES, .sql = "SELECT COUNT(*) FROM languages", .stmt = NULL, .args = "", .result = SQLITE_ROW, .flags = 0, .ret = "I"},
+ [RSPAMD_STAT_BACKEND_NUSERS] = {.idx = RSPAMD_STAT_BACKEND_NUSERS, .sql = "SELECT COUNT(*) FROM users", .stmt = NULL, .args = "", .result = SQLITE_ROW, .flags = 0, .ret = "I"}};
+
+static GQuark
+rspamd_sqlite3_backend_quark(void)
+{
+ return g_quark_from_static_string("sqlite3-stat-backend");
+}
+
+/*
+ * Resolves the `users.id` to use for a task.
+ *
+ * The user name is taken from the Lua callback referenced by `cbref_user`
+ * when one is configured, or from the task's principal recipient otherwise.
+ * The name is stored in the task pool as the "stat_user" variable and then
+ * looked up in the database; when the lookup fails and `learn` is set, a new
+ * row is inserted (opening an immediate transaction if none is pending).
+ *
+ * Returns the resolved id, or 0 (the default user) when no name is available.
+ */
+static gint64
+rspamd_sqlite3_get_user(struct rspamd_stat_sqlite3_db *db,
+                        struct rspamd_task *task, gboolean learn)
+{
+    gint64 user_id = 0; /* 0 is the default user */
+    const gchar *user_name = NULL;
+    lua_State *L = db->L;
+
+    if (db->cbref_user != -1) {
+        gint traceback_idx;
+        struct rspamd_task **task_ud;
+
+        /* Call the Lua extraction function with the task as its argument */
+        lua_pushcfunction(L, &rspamd_lua_traceback);
+        traceback_idx = lua_gettop(L);
+
+        lua_rawgeti(L, LUA_REGISTRYINDEX, db->cbref_user);
+        task_ud = lua_newuserdata(L, sizeof(struct rspamd_task *));
+        *task_ud = task;
+        rspamd_lua_setclass(L, "rspamd{task}", -1);
+
+        if (lua_pcall(L, 1, 1, traceback_idx) == 0) {
+            user_name = rspamd_mempool_strdup(task->task_pool,
+                                              lua_tostring(L, -1));
+        }
+        else {
+            msg_err_task("call to user extraction script failed: %s",
+                         lua_tostring(L, -1));
+        }
+
+        /* Drop the result and the traceback function */
+        lua_settop(L, traceback_idx - 1);
+    }
+    else {
+        user_name = rspamd_task_get_principal_recipient(task);
+    }
+
+    if (user_name == NULL) {
+        return user_id;
+    }
+
+    rspamd_mempool_set_variable(task->task_pool, "stat_user",
+                                (gpointer) user_name, NULL);
+
+    if (rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+                                  RSPAMD_STAT_BACKEND_GET_USER,
+                                  user_name, &user_id) == SQLITE_OK ||
+        !learn) {
+        /* Either the user is known or we must not create it */
+        return user_id;
+    }
+
+    /* Unknown user while learning: insert a new row */
+    if (!db->in_transaction) {
+        rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+                                  RSPAMD_STAT_BACKEND_TRANSACTION_START_IM);
+        db->in_transaction = TRUE;
+    }
+
+    (void) rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+                                     RSPAMD_STAT_BACKEND_INSERT_USER,
+                                     user_name, &user_id);
+
+    return user_id;
+}
+
+/*
+ * Resolves the `languages.id` to use for a task.
+ *
+ * The language name is taken from the Lua callback referenced by
+ * `cbref_language` when one is configured; otherwise the first text part
+ * whose language is set, non-empty and different from "en" is used. The name
+ * is then looked up in the database; when the lookup fails and `learn` is
+ * set, a new row is inserted (opening an immediate transaction if none is
+ * pending).
+ *
+ * Returns the resolved id, or 0 (the default language) when no name is
+ * available.
+ */
+static gint64
+rspamd_sqlite3_get_language(struct rspamd_stat_sqlite3_db *db,
+                            struct rspamd_task *task, gboolean learn)
+{
+    gint64 lang_id = 0; /* 0 is the default language */
+    const gchar *lang_name = NULL;
+    lua_State *L = db->L;
+
+    if (db->cbref_language != -1) {
+        gint traceback_idx;
+        struct rspamd_task **task_ud;
+
+        /* Call the Lua extraction function with the task as its argument */
+        lua_pushcfunction(L, &rspamd_lua_traceback);
+        traceback_idx = lua_gettop(L);
+
+        lua_rawgeti(L, LUA_REGISTRYINDEX, db->cbref_language);
+        task_ud = lua_newuserdata(L, sizeof(struct rspamd_task *));
+        *task_ud = task;
+        rspamd_lua_setclass(L, "rspamd{task}", -1);
+
+        if (lua_pcall(L, 1, 1, traceback_idx) == 0) {
+            lang_name = rspamd_mempool_strdup(task->task_pool,
+                                              lua_tostring(L, -1));
+        }
+        else {
+            msg_err_task("call to language extraction script failed: %s",
+                         lua_tostring(L, -1));
+        }
+
+        /* Drop the result and the traceback function */
+        lua_settop(L, traceback_idx - 1);
+    }
+    else {
+        guint part_idx;
+        struct rspamd_mime_text_part *part;
+
+        PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), part_idx, part)
+        {
+            if (part->language != NULL && part->language[0] != '\0') {
+                if (strcmp(part->language, "en") != 0) {
+                    lang_name = part->language;
+                    break;
+                }
+            }
+        }
+    }
+
+    /* XXX: We ignore multiple languages but default + extra */
+    if (lang_name == NULL) {
+        return lang_id;
+    }
+
+    if (rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+                                  RSPAMD_STAT_BACKEND_GET_LANGUAGE,
+                                  lang_name, &lang_id) == SQLITE_OK ||
+        !learn) {
+        /* Either the language is known or we must not create it */
+        return lang_id;
+    }
+
+    /* Unknown language while learning: insert a new row */
+    if (!db->in_transaction) {
+        rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+                                  RSPAMD_STAT_BACKEND_TRANSACTION_START_IM);
+        db->in_transaction = TRUE;
+    }
+
+    (void) rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+                                     RSPAMD_STAT_BACKEND_INSERT_LANGUAGE,
+                                     lang_name, &lang_id);
+
+    return lang_id;
+}
+
+/*
+ * Opens (creating if needed) the sqlite3 database at `path`, prepares all
+ * backend statements and makes sure a tokenizer configuration is stored.
+ *
+ * When no tokenizer configuration can be loaded, the configuration of the
+ * classifier's tokenizer is base32-encoded and saved inside an exclusive
+ * transaction (retried for up to `max_tries` one-millisecond sleeps while
+ * the database is busy).
+ *
+ * `opts` and `create` are accepted for interface compatibility and are not
+ * used by this function.
+ *
+ * Returns a g_malloc'ed context that must be released with
+ * rspamd_sqlite3_close, or NULL on failure (with `err` set). On every failure
+ * path the prepared statements, the database handle, the file name copy and
+ * the context itself are released.
+ */
+static struct rspamd_stat_sqlite3_db *
+rspamd_sqlite3_opendb(rspamd_mempool_t *pool,
+                      struct rspamd_statfile_config *stcf,
+                      const gchar *path, const ucl_object_t *opts,
+                      gboolean create, GError **err)
+{
+    struct rspamd_stat_sqlite3_db *bk;
+    struct rspamd_stat_tokenizer *tokenizer;
+    gpointer tk_conf = NULL;
+    gsize sz = 0;
+    gint64 sz64 = 0;
+    gchar *tok_conf_encoded = NULL;
+    gint ret, ntries = 0;
+    const gint max_tries = 100;
+    struct timespec sleep_ts = {
+        .tv_sec = 0,
+        .tv_nsec = 1000000};
+
+    bk = g_malloc0(sizeof(*bk));
+    bk->sqlite = rspamd_sqlite3_open_or_create(pool, path, create_tables_sql,
+                                               0, err);
+    bk->pool = pool;
+
+    if (bk->sqlite == NULL) {
+        g_free(bk);
+
+        return NULL;
+    }
+
+    bk->fname = g_strdup(path);
+
+    bk->prstmt = rspamd_sqlite3_init_prstmt(bk->sqlite, prepared_stmts,
+                                            RSPAMD_STAT_BACKEND_MAX, err);
+
+    if (bk->prstmt == NULL) {
+        goto fail;
+    }
+
+    /* Check tokenizer configuration */
+    if (rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt,
+                                  RSPAMD_STAT_BACKEND_LOAD_TOKENIZER,
+                                  &sz64, &tk_conf) != SQLITE_OK ||
+        sz64 == 0) {
+
+        /* Wait for the exclusive lock, backing off while the db is busy */
+        while ((ret = rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt,
+                                                RSPAMD_STAT_BACKEND_TRANSACTION_START_EXCL)) == SQLITE_BUSY &&
+               ++ntries <= max_tries) {
+            nanosleep(&sleep_ts, NULL);
+        }
+
+        msg_info_pool("absent tokenizer conf in %s, creating a new one",
+                      bk->fname);
+        g_assert(stcf->clcf->tokenizer != NULL);
+        tokenizer = rspamd_stat_get_tokenizer(stcf->clcf->tokenizer->name);
+        g_assert(tokenizer != NULL);
+        tk_conf = tokenizer->get_config(pool, stcf->clcf->tokenizer, &sz);
+
+        /* Encode to base32 */
+        tok_conf_encoded = rspamd_encode_base32(tk_conf, sz, RSPAMD_BASE32_DEFAULT);
+
+        if (tok_conf_encoded == NULL) {
+            /* Do not pass NULL to strlen below */
+            g_set_error(err, rspamd_sqlite3_backend_quark(), 500,
+                        "cannot encode tokenizer configuration for %s",
+                        bk->fname);
+            goto fail;
+        }
+
+        if (rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt,
+                                      RSPAMD_STAT_BACKEND_SAVE_TOKENIZER,
+                                      (gint64) strlen(tok_conf_encoded),
+                                      tok_conf_encoded) != SQLITE_OK) {
+            /* Callers expect `err` to be set whenever NULL is returned */
+            g_set_error(err, rspamd_sqlite3_backend_quark(), 500,
+                        "cannot save tokenizer configuration to %s: %s",
+                        bk->fname, sqlite3_errmsg(bk->sqlite));
+            goto fail;
+        }
+
+        rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt,
+                                  RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT);
+        g_free(tok_conf_encoded);
+    }
+    else {
+        /* The stored configuration is only probed here, not used */
+        g_free(tk_conf);
+    }
+
+    return bk;
+
+fail:
+    /*
+     * Statements must be finalized before sqlite3_close, otherwise the
+     * connection cannot be closed and is leaked.
+     */
+    if (bk->prstmt != NULL) {
+        rspamd_sqlite3_close_prstmt(bk->sqlite, bk->prstmt);
+    }
+
+    sqlite3_close(bk->sqlite);
+    g_free(tok_conf_encoded);
+    g_free(bk->fname);
+    g_free(bk);
+
+    return NULL;
+}
+
+/*
+ * Creates the sqlite3 backend context for a statfile.
+ *
+ * The database path is read from the statfile option "filename" (or "path").
+ * Per-user and per-language statistics are configured from the classifier
+ * options "per_user"/"users_enabled" and "per_language"/"languages_enabled":
+ * a boolean toggles the built-in extraction, a string is executed as a Lua
+ * chunk that must return a function(task), which is then stored in the Lua
+ * registry.
+ *
+ * Returns the backend context, or NULL if the path is missing or the
+ * database cannot be opened.
+ */
+gpointer
+rspamd_sqlite3_init(struct rspamd_stat_ctx *ctx,
+                    struct rspamd_config *cfg,
+                    struct rspamd_statfile *st)
+{
+    struct rspamd_classifier_config *clf = st->classifier->cfg;
+    struct rspamd_statfile_config *stf = st->stcf;
+    const ucl_object_t *filenameo, *lang_enabled, *users_enabled;
+    const gchar *filename, *lua_script;
+    struct rspamd_stat_sqlite3_db *bk;
+    lua_State *L = cfg->lua_state;
+    gint old_top, ret_type;
+    GError *err = NULL;
+
+    filenameo = ucl_object_lookup(stf->opts, "filename");
+    if (filenameo == NULL || ucl_object_type(filenameo) != UCL_STRING) {
+        filenameo = ucl_object_lookup(stf->opts, "path");
+        if (filenameo == NULL || ucl_object_type(filenameo) != UCL_STRING) {
+            msg_err_config("statfile %s has no filename defined", stf->symbol);
+            return NULL;
+        }
+    }
+
+    filename = ucl_object_tostring(filenameo);
+
+    if ((bk = rspamd_sqlite3_opendb(cfg->cfg_pool, stf, filename,
+                                    stf->opts, TRUE, &err)) == NULL) {
+        msg_err_config("cannot open sqlite3 db %s: %e", filename, err);
+
+        /* g_error_free must not be called with NULL */
+        if (err != NULL) {
+            g_error_free(err);
+        }
+
+        return NULL;
+    }
+
+    bk->L = L;
+    /* -1 means "no Lua callback"; the context is zero-filled by default */
+    bk->cbref_user = -1;
+    bk->cbref_language = -1;
+
+    users_enabled = ucl_object_lookup_any(clf->opts, "per_user",
+                                          "users_enabled", NULL);
+    if (users_enabled != NULL) {
+        if (ucl_object_type(users_enabled) == UCL_BOOLEAN) {
+            bk->enable_users = ucl_object_toboolean(users_enabled);
+            bk->cbref_user = -1;
+        }
+        else if (ucl_object_type(users_enabled) == UCL_STRING) {
+            lua_script = ucl_object_tostring(users_enabled);
+            old_top = lua_gettop(L);
+
+            if (luaL_dostring(L, lua_script) != 0) {
+                msg_err_config("cannot execute lua script for users "
+                               "extraction: %s",
+                               lua_tostring(L, -1));
+            }
+            else {
+                /* Only inspect a value that the script has really returned */
+                ret_type = lua_gettop(L) > old_top ? lua_type(L, -1) : LUA_TNONE;
+
+                if (ret_type == LUA_TFUNCTION) {
+                    bk->enable_users = TRUE;
+                    bk->cbref_user = luaL_ref(L, LUA_REGISTRYINDEX);
+                }
+                else {
+                    msg_err_config("lua script must return "
+                                   "function(task) and not %s",
+                                   lua_typename(L, ret_type));
+                }
+            }
+
+            /* Drop error messages and any unused results of the chunk */
+            lua_settop(L, old_top);
+        }
+    }
+    else {
+        bk->enable_users = FALSE;
+    }
+
+    lang_enabled = ucl_object_lookup_any(clf->opts,
+                                         "per_language", "languages_enabled", NULL);
+
+    if (lang_enabled != NULL) {
+        if (ucl_object_type(lang_enabled) == UCL_BOOLEAN) {
+            bk->enable_languages = ucl_object_toboolean(lang_enabled);
+            bk->cbref_language = -1;
+        }
+        else if (ucl_object_type(lang_enabled) == UCL_STRING) {
+            lua_script = ucl_object_tostring(lang_enabled);
+            old_top = lua_gettop(L);
+
+            if (luaL_dostring(L, lua_script) != 0) {
+                msg_err_config(
+                    "cannot execute lua script for languages "
+                    "extraction: %s",
+                    lua_tostring(L, -1));
+            }
+            else {
+                /* Only inspect a value that the script has really returned */
+                ret_type = lua_gettop(L) > old_top ? lua_type(L, -1) : LUA_TNONE;
+
+                if (ret_type == LUA_TFUNCTION) {
+                    bk->enable_languages = TRUE;
+                    bk->cbref_language = luaL_ref(L, LUA_REGISTRYINDEX);
+                }
+                else {
+                    msg_err_config("lua script must return "
+                                   "function(task) and not %s",
+                                   lua_typename(L, ret_type));
+                }
+            }
+
+            /* Drop error messages and any unused results of the chunk */
+            lua_settop(L, old_top);
+        }
+    }
+    else {
+        bk->enable_languages = FALSE;
+    }
+
+    if (bk->enable_languages) {
+        msg_info_config("enable per language statistics for %s",
+                        stf->symbol);
+    }
+
+    if (bk->enable_users) {
+        msg_info_config("enable per users statistics for %s",
+                        stf->symbol);
+    }
+
+
+    return (gpointer) bk;
+}
+
+/*
+ * Releases a backend context: commits a pending transaction, finalizes the
+ * prepared statements, closes the database and frees the context together
+ * with its file name copy. Nothing is done when the context has no open
+ * database handle.
+ */
+void rspamd_sqlite3_close(gpointer p)
+{
+    struct rspamd_stat_sqlite3_db *db = p;
+
+    if (!db->sqlite) {
+        return;
+    }
+
+    if (db->in_transaction) {
+        /* Do not lose the changes made so far */
+        rspamd_sqlite3_run_prstmt(db->pool, db->sqlite, db->prstmt,
+                                  RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT);
+    }
+
+    rspamd_sqlite3_close_prstmt(db->sqlite, db->prstmt);
+    sqlite3_close(db->sqlite);
+    g_free(db->fname);
+    g_free(db);
+}
+
+/*
+ * Creates the per-task runtime for a statfile. The runtime is allocated from
+ * the task pool and starts with unresolved (-1) user and language ids.
+ * Returns NULL when no backend context is supplied.
+ */
+gpointer
+rspamd_sqlite3_runtime(struct rspamd_task *task,
+                       struct rspamd_statfile_config *stcf, gboolean learn, gpointer p, gint _id)
+{
+    struct rspamd_stat_sqlite3_db *db = p;
+    struct rspamd_stat_sqlite3_rt *new_rt;
+
+    if (db == NULL) {
+        return NULL;
+    }
+
+    new_rt = rspamd_mempool_alloc(task->task_pool, sizeof(*new_rt));
+    new_rt->task = task;
+    new_rt->db = db;
+    new_rt->cf = stcf;
+    /* Resolved lazily when tokens are processed or learned */
+    new_rt->user_id = -1;
+    new_rt->lang_id = -1;
+
+    return new_rt;
+}
+
+/*
+ * Loads the stored value of every token into `tok->values[id]`.
+ *
+ * A deferred transaction is opened on demand, and the user and language ids
+ * are resolved once per runtime (without creating new rows). Tokens that are
+ * not found get a zero value. For every processed token the task is flagged
+ * as having spam or ham tokens depending on the statfile class.
+ *
+ * Always returns TRUE.
+ */
+gboolean
+rspamd_sqlite3_process_tokens(struct rspamd_task *task,
+                              GPtrArray *tokens,
+                              gint id, gpointer p)
+{
+    struct rspamd_stat_sqlite3_rt *rt = p;
+    struct rspamd_stat_sqlite3_db *db;
+    rspamd_token_t *cur;
+    gint64 stored = 0;
+    gint rc;
+    guint n;
+
+    g_assert(p != NULL);
+    g_assert(tokens != NULL);
+
+    db = rt->db;
+
+    for (n = 0; n < tokens->len; n++) {
+        cur = g_ptr_array_index(tokens, n);
+
+        if (db == NULL) {
+            /* Statfile does not exist, so all values are zero */
+            cur->values[id] = 0.0f;
+            continue;
+        }
+
+        if (!db->in_transaction) {
+            rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+                                      RSPAMD_STAT_BACKEND_TRANSACTION_START_DEF);
+            db->in_transaction = TRUE;
+        }
+
+        /* Resolve the user and the language lazily, once per runtime */
+        if (rt->user_id == -1) {
+            if (!db->enable_users) {
+                rt->user_id = 0;
+            }
+            else {
+                rt->user_id = rspamd_sqlite3_get_user(db, task, FALSE);
+            }
+        }
+
+        if (rt->lang_id == -1) {
+            if (!db->enable_languages) {
+                rt->lang_id = 0;
+            }
+            else {
+                rt->lang_id = rspamd_sqlite3_get_language(db, task, FALSE);
+            }
+        }
+
+        if (db->enable_languages || db->enable_users) {
+            rc = rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+                                           RSPAMD_STAT_BACKEND_GET_TOKEN_FULL,
+                                           cur->data, rt->user_id, rt->lang_id, &stored);
+        }
+        else {
+            rc = rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+                                           RSPAMD_STAT_BACKEND_GET_TOKEN_SIMPLE,
+                                           cur->data, &stored);
+        }
+
+        if (rc != SQLITE_OK) {
+            /* Unknown token */
+            cur->values[id] = 0.0f;
+        }
+        else {
+            cur->values[id] = stored;
+        }
+
+        if (!rt->cf->is_spam) {
+            task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
+        }
+        else {
+            task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS;
+        }
+    }
+
+    return TRUE;
+}
+
+/*
+ * Finishes token processing for a runtime: commits the pending transaction,
+ * if any, and resets the cached user and language ids to "unresolved".
+ * Always returns TRUE.
+ */
+gboolean
+rspamd_sqlite3_finalize_process(struct rspamd_task *task, gpointer runtime,
+                                gpointer ctx)
+{
+    struct rspamd_stat_sqlite3_rt *rt = runtime;
+    struct rspamd_stat_sqlite3_db *db;
+
+    g_assert(rt != NULL);
+    db = rt->db;
+
+    if (db->in_transaction) {
+        rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+                                  RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT);
+        db->in_transaction = FALSE;
+    }
+
+    /* Force a fresh lookup on the next use of this runtime */
+    rt->user_id = -1;
+    rt->lang_id = -1;
+
+    return TRUE;
+}
+
+/*
+ * Stores `tok->values[id]` of every token in the database.
+ *
+ * An immediate transaction is opened on demand, and the user and language
+ * ids are resolved once per runtime (creating rows when needed). If a token
+ * cannot be stored, the transaction is rolled back and FALSE is returned.
+ * FALSE is also returned when there are tokens to learn but the runtime has
+ * no database.
+ */
+gboolean
+rspamd_sqlite3_learn_tokens(struct rspamd_task *task, GPtrArray *tokens,
+                            gint id, gpointer p)
+{
+    struct rspamd_stat_sqlite3_rt *rt = p;
+    struct rspamd_stat_sqlite3_db *db;
+    rspamd_token_t *cur;
+    gint64 new_value = 0;
+    guint n;
+
+    g_assert(tokens != NULL);
+    g_assert(p != NULL);
+
+    db = rt->db;
+
+    for (n = 0; n < tokens->len; n++) {
+        cur = g_ptr_array_index(tokens, n);
+
+        if (db == NULL) {
+            /* Statfile does not exist, so nothing can be learned */
+            return FALSE;
+        }
+
+        if (!db->in_transaction) {
+            rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+                                      RSPAMD_STAT_BACKEND_TRANSACTION_START_IM);
+            db->in_transaction = TRUE;
+        }
+
+        /* Resolve the user and the language lazily, once per runtime */
+        if (rt->user_id == -1) {
+            if (!db->enable_users) {
+                rt->user_id = 0;
+            }
+            else {
+                rt->user_id = rspamd_sqlite3_get_user(db, task, TRUE);
+            }
+        }
+
+        if (rt->lang_id == -1) {
+            if (!db->enable_languages) {
+                rt->lang_id = 0;
+            }
+            else {
+                rt->lang_id = rspamd_sqlite3_get_language(db, task, TRUE);
+            }
+        }
+
+        new_value = cur->values[id];
+
+        if (rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+                                      RSPAMD_STAT_BACKEND_SET_TOKEN,
+                                      cur->data, rt->user_id, rt->lang_id,
+                                      new_value) == SQLITE_OK) {
+            continue;
+        }
+
+        /* Storing has failed: discard everything written so far */
+        rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+                                  RSPAMD_STAT_BACKEND_TRANSACTION_ROLLBACK);
+        db->in_transaction = FALSE;
+
+        return FALSE;
+    }
+
+    return TRUE;
+}
+
+/*
+ * Finishes a learn operation: commits the pending transaction, if any, and,
+ * when the sqlite3 headers provide WAL support, runs a WAL checkpoint using
+ * the strongest checkpoint mode known at compile time.
+ *
+ * Returns FALSE and sets `err` when the checkpoint fails, TRUE otherwise.
+ */
+gboolean
+rspamd_sqlite3_finalize_learn(struct rspamd_task *task, gpointer runtime,
+                              gpointer ctx, GError **err)
+{
+    struct rspamd_stat_sqlite3_rt *rt = runtime;
+    struct rspamd_stat_sqlite3_db *db;
+    gint nframes, ncheckpointed, ckpt_mode;
+
+    g_assert(rt != NULL);
+    db = rt->db;
+
+    if (db->in_transaction) {
+        rspamd_sqlite3_run_prstmt(task->task_pool, db->sqlite, db->prstmt,
+                                  RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT);
+        db->in_transaction = FALSE;
+    }
+
+#ifdef SQLITE_OPEN_WAL
+#ifdef SQLITE_CHECKPOINT_TRUNCATE
+    ckpt_mode = SQLITE_CHECKPOINT_TRUNCATE;
+#elif defined(SQLITE_CHECKPOINT_RESTART)
+    ckpt_mode = SQLITE_CHECKPOINT_RESTART;
+#elif defined(SQLITE_CHECKPOINT_FULL)
+    ckpt_mode = SQLITE_CHECKPOINT_FULL;
+#endif
+    /* Perform wal checkpoint (might be long) */
+    if (sqlite3_wal_checkpoint_v2(db->sqlite,
+                                  NULL,
+                                  ckpt_mode,
+                                  &nframes,
+                                  &ncheckpointed) == SQLITE_OK) {
+        return TRUE;
+    }
+
+    msg_warn_task("cannot commit checkpoint: %s",
+                  sqlite3_errmsg(db->sqlite));
+
+    g_set_error(err, rspamd_sqlite3_backend_quark(), 500,
+                "cannot commit checkpoint: %s",
+                sqlite3_errmsg(db->sqlite));
+
+    return FALSE;
+#else
+    return TRUE;
+#endif
+}
+
+/*
+ * Returns the total number of learns stored in the database (the sum over
+ * the `languages` table), or 0 when the query does not produce a value.
+ */
+gulong
+rspamd_sqlite3_total_learns(struct rspamd_task *task, gpointer runtime,
+                            gpointer ctx)
+{
+    struct rspamd_stat_sqlite3_rt *rt = runtime;
+    struct rspamd_stat_sqlite3_db *bk;
+    /* Initialized so that a failed query cannot return garbage */
+    guint64 res = 0;
+
+    g_assert(rt != NULL);
+    bk = rt->db;
+    rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+                              RSPAMD_STAT_BACKEND_GET_LEARNS, &res);
+
+    return res;
+}
+
+/*
+ * Increments the learns counters of the runtime's language and user,
+ * commits the pending transaction, if any, and returns the new total number
+ * of learns (0 when the total cannot be queried).
+ */
+gulong
+rspamd_sqlite3_inc_learns(struct rspamd_task *task, gpointer runtime,
+                          gpointer ctx)
+{
+    struct rspamd_stat_sqlite3_rt *rt = runtime;
+    struct rspamd_stat_sqlite3_db *bk;
+    /* Initialized so that a failed query cannot return garbage */
+    guint64 res = 0;
+
+    g_assert(rt != NULL);
+    bk = rt->db;
+    rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+                              RSPAMD_STAT_BACKEND_INC_LEARNS_LANG,
+                              rt->lang_id);
+    rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+                              RSPAMD_STAT_BACKEND_INC_LEARNS_USER,
+                              rt->user_id);
+
+    if (bk->in_transaction) {
+        rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+                                  RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT);
+        bk->in_transaction = FALSE;
+    }
+
+    rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+                              RSPAMD_STAT_BACKEND_GET_LEARNS, &res);
+
+    return res;
+}
+
+/*
+ * Decrements (never below zero) the learns counters of the runtime's
+ * language and user, commits the pending transaction, if any, and returns
+ * the new total number of learns (0 when the total cannot be queried).
+ */
+gulong
+rspamd_sqlite3_dec_learns(struct rspamd_task *task, gpointer runtime,
+                          gpointer ctx)
+{
+    struct rspamd_stat_sqlite3_rt *rt = runtime;
+    struct rspamd_stat_sqlite3_db *bk;
+    /* Initialized so that a failed query cannot return garbage */
+    guint64 res = 0;
+
+    g_assert(rt != NULL);
+    bk = rt->db;
+    rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+                              RSPAMD_STAT_BACKEND_DEC_LEARNS_LANG,
+                              rt->lang_id);
+    rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+                              RSPAMD_STAT_BACKEND_DEC_LEARNS_USER,
+                              rt->user_id);
+
+    if (bk->in_transaction) {
+        rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+                                  RSPAMD_STAT_BACKEND_TRANSACTION_COMMIT);
+        bk->in_transaction = FALSE;
+    }
+
+    rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+                              RSPAMD_STAT_BACKEND_GET_LEARNS, &res);
+
+    return res;
+}
+
+/*
+ * Returns the total number of learns stored in the database (the sum over
+ * the `languages` table), or 0 when the query does not produce a value.
+ */
+gulong
+rspamd_sqlite3_learns(struct rspamd_task *task, gpointer runtime,
+                      gpointer ctx)
+{
+    struct rspamd_stat_sqlite3_rt *rt = runtime;
+    struct rspamd_stat_sqlite3_db *bk;
+    /* Initialized so that a failed query cannot return garbage */
+    guint64 res = 0;
+
+    g_assert(rt != NULL);
+    bk = rt->db;
+    rspamd_sqlite3_run_prstmt(task->task_pool, bk->sqlite, bk->prstmt,
+                              RSPAMD_STAT_BACKEND_GET_LEARNS, &res);
+
+    return res;
+}
+
+/*
+ * Builds a UCL object describing the statfile: "revision" (total learns),
+ * "size" (database file size in bytes), "total" and "used" (number of
+ * tokens), "symbol", "type", "languages", "users" and, when configured,
+ * "label".
+ *
+ * A counter whose query fails, and the size of a file that cannot be
+ * stat'ed, are reported as 0 instead of an indeterminate value.
+ *
+ * The caller owns the returned object.
+ */
+ucl_object_t *
+rspamd_sqlite3_get_stat(gpointer runtime,
+                        gpointer ctx)
+{
+    ucl_object_t *res = NULL;
+    struct rspamd_stat_sqlite3_rt *rt = runtime;
+    struct rspamd_stat_sqlite3_db *bk;
+    rspamd_mempool_t *pool;
+    struct stat st;
+    gint64 rev, file_size = 0;
+
+    g_assert(rt != NULL);
+    bk = rt->db;
+    pool = bk->pool;
+
+    /* `st` is only meaningful when stat() succeeds */
+    if (stat(bk->fname, &st) == 0) {
+        file_size = st.st_size;
+    }
+
+    res = ucl_object_typed_new(UCL_OBJECT);
+
+    /* `rev` is reset before each query so a failure never reuses old data */
+    rev = 0;
+    rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt,
+                              RSPAMD_STAT_BACKEND_GET_LEARNS, &rev);
+    ucl_object_insert_key(res, ucl_object_fromint(rev), "revision",
+                          0, false);
+    ucl_object_insert_key(res, ucl_object_fromint(file_size), "size",
+                          0, false);
+
+    rev = 0;
+    rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt,
+                              RSPAMD_STAT_BACKEND_NTOKENS, &rev);
+    ucl_object_insert_key(res, ucl_object_fromint(rev), "total", 0, false);
+    ucl_object_insert_key(res, ucl_object_fromint(rev), "used", 0, false);
+    ucl_object_insert_key(res, ucl_object_fromstring(rt->cf->symbol),
+                          "symbol", 0, false);
+    ucl_object_insert_key(res, ucl_object_fromstring("sqlite3"),
+                          "type", 0, false);
+
+    rev = 0;
+    rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt,
+                              RSPAMD_STAT_BACKEND_NLANGUAGES, &rev);
+    ucl_object_insert_key(res, ucl_object_fromint(rev),
+                          "languages", 0, false);
+
+    rev = 0;
+    rspamd_sqlite3_run_prstmt(pool, bk->sqlite, bk->prstmt,
+                              RSPAMD_STAT_BACKEND_NUSERS, &rev);
+    ucl_object_insert_key(res, ucl_object_fromint(rev),
+                          "users", 0, false);
+
+    if (rt->cf->label) {
+        ucl_object_insert_key(res, ucl_object_fromstring(rt->cf->label),
+                              "label", 0, false);
+    }
+
+    return res;
+}
+
+/*
+ * Loads the tokenizer configuration stored in the database.
+ *
+ * The stored blob is either raw (it then starts with the "osbtokv" magic)
+ * or base32-encoded. A raw blob is copied into the task pool; an encoded
+ * blob is decoded and the decoded buffer is freed together with the task
+ * pool.
+ *
+ * If `len` is not NULL, it receives the size of the returned buffer: the
+ * blob size for a raw configuration, the decoded size for an encoded one,
+ * and 0 when decoding fails (NULL is returned in that case).
+ *
+ * Aborts (g_assert) if the configuration cannot be loaded or is empty.
+ */
+gpointer
+rspamd_sqlite3_load_tokenizer_config(gpointer runtime,
+                                     gsize *len)
+{
+    gpointer tk_conf = NULL, copied_conf;
+    guint64 sz = 0;
+    gsize out_len = 0;
+    gint rc;
+    struct rspamd_stat_sqlite3_rt *rt = runtime;
+    struct rspamd_stat_sqlite3_db *bk;
+
+    g_assert(rt != NULL);
+    bk = rt->db;
+
+    /*
+     * The query is kept outside of g_assert: an assertion that is compiled
+     * out must not remove the query itself.
+     */
+    rc = rspamd_sqlite3_run_prstmt(rt->db->pool, bk->sqlite, bk->prstmt,
+                                   RSPAMD_STAT_BACKEND_LOAD_TOKENIZER, &sz, &tk_conf);
+    g_assert(rc == SQLITE_OK);
+    g_assert(sz > 0);
+    /*
+     * Here we can have either decoded or undecoded version of tokenizer config
+     * XXX: dirty hack to check if we have osb magic here
+     */
+    if (sz > 7 && memcmp(tk_conf, "osbtokv", 7) == 0) {
+        copied_conf = rspamd_mempool_alloc(rt->task->task_pool, sz);
+        memcpy(copied_conf, tk_conf, sz);
+        g_free(tk_conf);
+        out_len = sz;
+    }
+    else {
+        /* Need to decode; the decoded size differs from the stored size */
+        copied_conf = rspamd_decode_base32(tk_conf, sz, &out_len, RSPAMD_BASE32_DEFAULT);
+        g_free(tk_conf);
+
+        if (copied_conf != NULL) {
+            rspamd_mempool_add_destructor(rt->task->task_pool, g_free, copied_conf);
+        }
+        else {
+            out_len = 0;
+        }
+    }
+
+    if (len) {
+        *len = out_len;
+    }
+
+    return copied_conf;
+}