summaryrefslogtreecommitdiffstats
path: root/src/libserver/symcache/symcache_impl.cxx
diff options
context:
space:
mode:
Diffstat (limited to 'src/libserver/symcache/symcache_impl.cxx')
-rw-r--r--src/libserver/symcache/symcache_impl.cxx1316
1 files changed, 1316 insertions, 0 deletions
diff --git a/src/libserver/symcache/symcache_impl.cxx b/src/libserver/symcache/symcache_impl.cxx
new file mode 100644
index 0000000..93675ac
--- /dev/null
+++ b/src/libserver/symcache/symcache_impl.cxx
@@ -0,0 +1,1316 @@
+/*
+ * Copyright 2023 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lua/lua_common.h"
+#include "symcache_internal.hxx"
+#include "symcache_item.hxx"
+#include "symcache_runtime.hxx"
+#include "unix-std.h"
+#include "libutil/cxx/file_util.hxx"
+#include "libutil/cxx/util.hxx"
+#include "fmt/core.h"
+#include "contrib/t1ha/t1ha.h"
+
+#ifdef __has_include
+#if __has_include(<version>)
+#include <version>
+#endif
+#endif
+#include <cmath>
+
+namespace rspamd::symcache {
+
+INIT_LOG_MODULE_PUBLIC(symcache)
+
+auto symcache::init() -> bool
+{
+ auto res = true;
+ reload_time = cfg->cache_reload_time;
+
+ if (cfg->cache_filename != nullptr) {
+ msg_debug_cache("loading symcache saved data from %s", cfg->cache_filename);
+ load_items();
+ }
+
+ ankerl::unordered_dense::set<int> disabled_ids;
+ /* Process enabled/disabled symbols */
+ for (const auto &[id, it]: items_by_id) {
+ if (disabled_symbols) {
+ /*
+ * Due to the ability to add patterns, this is now O(N^2), but it is done
+ * once on configuration and the amount of static patterns is usually low
+ * The possible optimization is to store non patterns in a different set to check it
+ * quickly. However, it is unlikely that this would be used to something really heavy.
+ */
+ for (const auto &disable_pat: *disabled_symbols) {
+ if (disable_pat.matches(it->get_name())) {
+ msg_debug_cache("symbol %s matches %*s disable pattern", it->get_name().c_str(),
+ (int) disable_pat.to_string_view().size(), disable_pat.to_string_view().data());
+ auto need_disable = true;
+
+ if (enabled_symbols) {
+ for (const auto &enable_pat: *enabled_symbols) {
+ if (enable_pat.matches(it->get_name())) {
+ msg_debug_cache("symbol %s matches %*s enable pattern; skip disabling", it->get_name().c_str(),
+ (int) enable_pat.to_string_view().size(), enable_pat.to_string_view().data());
+ need_disable = false;
+ break;
+ }
+ }
+ }
+
+ if (need_disable) {
+ disabled_ids.insert(it->id);
+
+ if (it->is_virtual()) {
+ auto real_elt = it->get_parent(*this);
+
+ if (real_elt) {
+ disabled_ids.insert(real_elt->id);
+
+ const auto *children = real_elt->get_children();
+ if (children != nullptr) {
+ for (const auto &cld: *children) {
+ msg_debug_cache("symbol %s is a virtual sibling of the disabled symbol %s",
+ cld->get_name().c_str(), it->get_name().c_str());
+ disabled_ids.insert(cld->id);
+ }
+ }
+ }
+ }
+ else {
+ /* Also disable all virtual children of this element */
+ const auto *children = it->get_children();
+
+ if (children != nullptr) {
+ for (const auto &cld: *children) {
+ msg_debug_cache("symbol %s is a virtual child of the disabled symbol %s",
+ cld->get_name().c_str(), it->get_name().c_str());
+ disabled_ids.insert(cld->id);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ /* Deal with the delayed dependencies */
+ msg_debug_cache("resolving delayed dependencies: %d in list", (int) delayed_deps->size());
+ for (const auto &delayed_dep: *delayed_deps) {
+ auto virt_item = get_item_by_name(delayed_dep.from, false);
+ auto real_item = get_item_by_name(delayed_dep.from, true);
+
+ if (virt_item == nullptr || real_item == nullptr) {
+ msg_err_cache("cannot register delayed dependency between %s and %s: "
+ "%s is missing",
+ delayed_dep.from.data(),
+ delayed_dep.to.data(), delayed_dep.from.data());
+ }
+ else {
+
+ if (!disabled_ids.contains(real_item->id)) {
+ msg_debug_cache("delayed between %s(%d:%d) -> %s",
+ delayed_dep.from.data(),
+ real_item->id, virt_item->id,
+ delayed_dep.to.data());
+ add_dependency(real_item->id, delayed_dep.to,
+ virt_item != real_item ? virt_item->id : -1);
+ }
+ else {
+ msg_debug_cache("no delayed between %s(%d:%d) -> %s; %s is disabled",
+ delayed_dep.from.data(),
+ real_item->id, virt_item->id,
+ delayed_dep.to.data(),
+ delayed_dep.from.data());
+ }
+ }
+ }
+
+ /* Remove delayed dependencies, as they are no longer needed at this point */
+ delayed_deps.reset();
+
+ /* Physically remove ids that are disabled statically */
+ for (auto id_to_disable: disabled_ids) {
+ /*
+ * This erasure is inefficient, we can swap the last element with the removed id
+ * But in this way, our ids are still sorted by addition
+ */
+
+ /* Preserve refcount here */
+ auto deleted_element_refcount = items_by_id[id_to_disable];
+ items_by_id.erase(id_to_disable);
+ items_by_symbol.erase(deleted_element_refcount->get_name());
+
+ auto &additional_vec = get_item_specific_vector(*deleted_element_refcount);
+#if defined(__cpp_lib_erase_if)
+ std::erase_if(additional_vec, [id_to_disable](cache_item *elt) {
+ return elt->id == id_to_disable;
+ });
+#else
+ auto it = std::remove_if(additional_vec.begin(),
+ additional_vec.end(), [id_to_disable](cache_item *elt) {
+ return elt->id == id_to_disable;
+ });
+ additional_vec.erase(it, additional_vec.end());
+#endif
+
+ /* Refcount is dropped, so the symbol should be freed, ensure that nothing else owns this symbol */
+ g_assert(deleted_element_refcount.use_count() == 1);
+ }
+
+ /* Remove no longer used stuff */
+ enabled_symbols.reset();
+ disabled_symbols.reset();
+
+ /* Deal with the delayed conditions */
+ msg_debug_cache("resolving delayed conditions: %d in list", (int) delayed_conditions->size());
+ for (const auto &delayed_cond: *delayed_conditions) {
+ auto it = get_item_by_name_mut(delayed_cond.sym, true);
+
+ if (it == nullptr) {
+ msg_err_cache(
+ "cannot register delayed condition for %s",
+ delayed_cond.sym.c_str());
+ luaL_unref(delayed_cond.L, LUA_REGISTRYINDEX, delayed_cond.cbref);
+ }
+ else {
+ if (!it->add_condition(delayed_cond.L, delayed_cond.cbref)) {
+ msg_err_cache(
+ "cannot register delayed condition for %s: virtual parent; qed",
+ delayed_cond.sym.c_str());
+ g_abort();
+ }
+
+ msg_debug_cache("added a condition to the symbol %s", it->symbol.c_str());
+ }
+ }
+ delayed_conditions.reset();
+
+ msg_debug_cache("process dependencies");
+ for (const auto &[_id, it]: items_by_id) {
+ it->process_deps(*this);
+ }
+
+ /* Sorting stuff */
+ constexpr auto postfilters_cmp = [](const auto &it1, const auto &it2) -> bool {
+ return it1->priority < it2->priority;
+ };
+ constexpr auto prefilters_cmp = [](const auto &it1, const auto &it2) -> bool {
+ return it1->priority > it2->priority;
+ };
+
+ msg_debug_cache("sorting stuff");
+ std::stable_sort(std::begin(connfilters), std::end(connfilters), prefilters_cmp);
+ std::stable_sort(std::begin(prefilters), std::end(prefilters), prefilters_cmp);
+ std::stable_sort(std::begin(postfilters), std::end(postfilters), postfilters_cmp);
+ std::stable_sort(std::begin(idempotent), std::end(idempotent), postfilters_cmp);
+
+ resort();
+
+ /* Connect metric symbols with symcache symbols */
+ if (cfg->symbols) {
+ msg_debug_cache("connect metrics");
+ g_hash_table_foreach(cfg->symbols,
+ symcache::metric_connect_cb,
+ (void *) this);
+ }
+
+ return res;
+}
+
+auto symcache::load_items() -> bool
+{
+ auto cached_map = util::raii_mmaped_file::mmap_shared(cfg->cache_filename,
+ O_RDONLY, PROT_READ);
+
+ if (!cached_map.has_value()) {
+ if (cached_map.error().category == util::error_category::CRITICAL) {
+ msg_err_cache("%s", cached_map.error().error_message.data());
+ }
+ else {
+ msg_info_cache("%s", cached_map.error().error_message.data());
+ }
+ return false;
+ }
+
+
+ if (cached_map->get_size() < (gint) sizeof(symcache_header)) {
+ msg_info_cache("cannot use file %s, truncated: %z", cfg->cache_filename,
+ errno, strerror(errno));
+ return false;
+ }
+
+ const auto *hdr = (struct symcache_header *) cached_map->get_map();
+
+ if (memcmp(hdr->magic, symcache_magic,
+ sizeof(symcache_magic)) != 0) {
+ msg_info_cache("cannot use file %s, bad magic", cfg->cache_filename);
+
+ return false;
+ }
+
+ auto *parser = ucl_parser_new(0);
+ const auto *p = (const std::uint8_t *) (hdr + 1);
+
+ if (!ucl_parser_add_chunk(parser, p, cached_map->get_size() - sizeof(*hdr))) {
+ msg_info_cache("cannot use file %s, cannot parse: %s", cfg->cache_filename,
+ ucl_parser_get_error(parser));
+ ucl_parser_free(parser);
+
+ return false;
+ }
+
+ auto *top = ucl_parser_get_object(parser);
+ ucl_parser_free(parser);
+
+ if (top == nullptr || ucl_object_type(top) != UCL_OBJECT) {
+ msg_info_cache("cannot use file %s, bad object", cfg->cache_filename);
+ ucl_object_unref(top);
+
+ return false;
+ }
+
+ auto it = ucl_object_iterate_new(top);
+ const ucl_object_t *cur;
+ while ((cur = ucl_object_iterate_safe(it, true)) != nullptr) {
+ auto item_it = items_by_symbol.find(ucl_object_key(cur));
+
+ if (item_it != items_by_symbol.end()) {
+ auto item = item_it->second;
+ /* Copy saved info */
+ /*
+ * XXX: don't save or load weight, it should be obtained from the
+ * metric
+ */
+#if 0
+ elt = ucl_object_lookup (cur, "weight");
+
+ if (elt) {
+ w = ucl_object_todouble (elt);
+ if (w != 0) {
+ item->weight = w;
+ }
+ }
+#endif
+ const auto *elt = ucl_object_lookup(cur, "time");
+ if (elt) {
+ item->st->avg_time = ucl_object_todouble(elt);
+ }
+
+ elt = ucl_object_lookup(cur, "count");
+ if (elt) {
+ item->st->total_hits = ucl_object_toint(elt);
+ item->last_count = item->st->total_hits;
+ }
+
+ elt = ucl_object_lookup(cur, "frequency");
+ if (elt && ucl_object_type(elt) == UCL_OBJECT) {
+ const ucl_object_t *freq_elt;
+
+ freq_elt = ucl_object_lookup(elt, "avg");
+
+ if (freq_elt) {
+ item->st->avg_frequency = ucl_object_todouble(freq_elt);
+ }
+ freq_elt = ucl_object_lookup(elt, "stddev");
+
+ if (freq_elt) {
+ item->st->stddev_frequency = ucl_object_todouble(freq_elt);
+ }
+ }
+
+ if (item->is_virtual() && !item->is_ghost()) {
+ const auto &parent = item->get_parent(*this);
+
+ if (parent) {
+ if (parent->st->weight < item->st->weight) {
+ parent->st->weight = item->st->weight;
+ }
+ }
+ /*
+ * We maintain avg_time for virtual symbols equal to the
+ * parent item avg_time
+ */
+ item->st->avg_time = parent->st->avg_time;
+ }
+
+ total_weight += fabs(item->st->weight);
+ total_hits += item->st->total_hits;
+ }
+ }
+
+ ucl_object_iterate_free(it);
+ ucl_object_unref(top);
+
+ return true;
+}
+
+template<typename T>
+static constexpr auto round_to_hundreds(T x)
+{
+ return (::floor(x) * 100.0) / 100.0;
+}
+
+bool symcache::save_items() const
+{
+ if (cfg->cache_filename == nullptr) {
+ return false;
+ }
+
+ auto file_sink = util::raii_file_sink::create(cfg->cache_filename,
+ O_WRONLY | O_TRUNC, 00644);
+
+ if (!file_sink.has_value()) {
+ if (errno == EEXIST) {
+ /* Some other process is already writing data, give up silently */
+ return false;
+ }
+
+ msg_err_cache("%s", file_sink.error().error_message.data());
+
+ return false;
+ }
+
+ struct symcache_header hdr;
+ memset(&hdr, 0, sizeof(hdr));
+ memcpy(hdr.magic, symcache_magic, sizeof(symcache_magic));
+
+ if (write(file_sink->get_fd(), &hdr, sizeof(hdr)) == -1) {
+ msg_err_cache("cannot write to file %s, error %d, %s", cfg->cache_filename,
+ errno, strerror(errno));
+
+ return false;
+ }
+
+ auto *top = ucl_object_typed_new(UCL_OBJECT);
+
+ for (const auto &it: items_by_symbol) {
+ auto item = it.second;
+ auto elt = ucl_object_typed_new(UCL_OBJECT);
+ ucl_object_insert_key(elt,
+ ucl_object_fromdouble(round_to_hundreds(item->st->weight)),
+ "weight", 0, false);
+ ucl_object_insert_key(elt,
+ ucl_object_fromdouble(round_to_hundreds(item->st->time_counter.mean)),
+ "time", 0, false);
+ ucl_object_insert_key(elt, ucl_object_fromint(item->st->total_hits),
+ "count", 0, false);
+
+ auto *freq = ucl_object_typed_new(UCL_OBJECT);
+ ucl_object_insert_key(freq,
+ ucl_object_fromdouble(round_to_hundreds(item->st->frequency_counter.mean)),
+ "avg", 0, false);
+ ucl_object_insert_key(freq,
+ ucl_object_fromdouble(round_to_hundreds(item->st->frequency_counter.stddev)),
+ "stddev", 0, false);
+ ucl_object_insert_key(elt, freq, "frequency", 0, false);
+
+ ucl_object_insert_key(top, elt, it.first.data(), 0, true);
+ }
+
+ auto fp = fdopen(file_sink->get_fd(), "a");
+ auto *efunc = ucl_object_emit_file_funcs(fp);
+ auto ret = ucl_object_emit_full(top, UCL_EMIT_JSON_COMPACT, efunc, nullptr);
+ ucl_object_emit_funcs_free(efunc);
+ ucl_object_unref(top);
+ fclose(fp);
+
+ return ret;
+}
+
+auto symcache::metric_connect_cb(void *k, void *v, void *ud) -> void
+{
+ auto *cache = (symcache *) ud;
+ const auto *sym = (const char *) k;
+ auto *s = (struct rspamd_symbol *) v;
+ auto weight = *s->weight_ptr;
+ auto *item = cache->get_item_by_name_mut(sym, false);
+
+ if (item) {
+ item->st->weight = weight;
+ s->cache_item = (void *) item;
+ }
+}
+
+
+auto symcache::get_item_by_id(int id, bool resolve_parent) const -> const cache_item *
+{
+ if (id < 0 || id >= items_by_id.size()) {
+ msg_err_cache("internal error: requested item with id %d, when we have just %d items in the cache",
+ id, (int) items_by_id.size());
+ return nullptr;
+ }
+
+ const auto &maybe_item = rspamd::find_map(items_by_id, id);
+
+ if (!maybe_item.has_value()) {
+ msg_err_cache("internal error: requested item with id %d but it is empty; qed",
+ id);
+ return nullptr;
+ }
+
+ const auto &item = maybe_item.value().get();
+
+ if (resolve_parent && item->is_virtual()) {
+ return item->get_parent(*this);
+ }
+
+ return item.get();
+}
+
+auto symcache::get_item_by_id_mut(int id, bool resolve_parent) const -> cache_item *
+{
+ if (id < 0 || id >= items_by_id.size()) {
+ msg_err_cache("internal error: requested item with id %d, when we have just %d items in the cache",
+ id, (int) items_by_id.size());
+ return nullptr;
+ }
+
+ const auto &maybe_item = rspamd::find_map(items_by_id, id);
+
+ if (!maybe_item.has_value()) {
+ msg_err_cache("internal error: requested item with id %d but it is empty; qed",
+ id);
+ return nullptr;
+ }
+
+ const auto &item = maybe_item.value().get();
+
+ if (resolve_parent && item->is_virtual()) {
+ return const_cast<cache_item *>(item->get_parent(*this));
+ }
+
+ return item.get();
+}
+
+auto symcache::get_item_by_name(std::string_view name, bool resolve_parent) const -> const cache_item *
+{
+ auto it = items_by_symbol.find(name);
+
+ if (it == items_by_symbol.end()) {
+ return nullptr;
+ }
+
+ if (resolve_parent && it->second->is_virtual()) {
+ it->second->resolve_parent(*this);
+ return it->second->get_parent(*this);
+ }
+
+ return it->second;
+}
+
+auto symcache::get_item_by_name_mut(std::string_view name, bool resolve_parent) const -> cache_item *
+{
+ auto it = items_by_symbol.find(name);
+
+ if (it == items_by_symbol.end()) {
+ return nullptr;
+ }
+
+ if (resolve_parent && it->second->is_virtual()) {
+ return (cache_item *) it->second->get_parent(*this);
+ }
+
+ return it->second;
+}
+
+auto symcache::add_dependency(int id_from, std::string_view to, int virtual_id_from) -> void
+{
+ g_assert(id_from >= 0 && id_from < (gint) items_by_id.size());
+ const auto &source = items_by_id[id_from];
+ g_assert(source.get() != nullptr);
+
+ source->deps.emplace_back(nullptr,
+ std::string(to),
+ id_from,
+ -1);
+
+
+ if (virtual_id_from >= 0) {
+ g_assert(virtual_id_from < (gint) items_by_id.size());
+ /* We need that for settings id propagation */
+ const auto &vsource = items_by_id[virtual_id_from];
+ g_assert(vsource.get() != nullptr);
+ vsource->deps.emplace_back(nullptr,
+ std::string(to),
+ -1,
+ virtual_id_from);
+ }
+}
+
+auto symcache::resort() -> void
+{
+ auto log_func = RSPAMD_LOG_FUNC;
+ auto ord = std::make_shared<order_generation>(filters.size() +
+ prefilters.size() +
+ composites.size() +
+ postfilters.size() +
+ idempotent.size() +
+ connfilters.size() +
+ classifiers.size(),
+ cur_order_gen);
+
+ for (auto &it: filters) {
+ if (it) {
+ total_hits += it->st->total_hits;
+ /* Unmask topological order */
+ it->order = 0;
+ ord->d.emplace_back(it->getptr());
+ }
+ }
+
+ enum class tsort_mask {
+ PERM,
+ TEMP
+ };
+
+ constexpr auto tsort_unmask = [](cache_item *it) -> auto {
+ return (it->order & ~((1u << 31) | (1u << 30)));
+ };
+
+ /* Recursive topological sort helper */
+ const auto tsort_visit = [&](cache_item *it, unsigned cur_order, auto &&rec) {
+ constexpr auto tsort_mark = [](cache_item *it, tsort_mask how) {
+ switch (how) {
+ case tsort_mask::PERM:
+ it->order |= (1u << 31);
+ break;
+ case tsort_mask::TEMP:
+ it->order |= (1u << 30);
+ break;
+ }
+ };
+ constexpr auto tsort_is_marked = [](cache_item *it, tsort_mask how) {
+ switch (how) {
+ case tsort_mask::PERM:
+ return (it->order & (1u << 31));
+ case tsort_mask::TEMP:
+ return (it->order & (1u << 30));
+ }
+
+ return 100500u; /* Because fuck compilers, that's why */
+ };
+
+ if (tsort_is_marked(it, tsort_mask::PERM)) {
+ if (cur_order > tsort_unmask(it)) {
+ /* Need to recalculate the whole chain */
+ it->order = cur_order; /* That also removes all masking */
+ }
+ else {
+ /* We are fine, stop DFS */
+ return;
+ }
+ }
+ else if (tsort_is_marked(it, tsort_mask::TEMP)) {
+ msg_err_cache_lambda("cyclic dependencies found when checking '%s'!",
+ it->symbol.c_str());
+ return;
+ }
+
+ tsort_mark(it, tsort_mask::TEMP);
+ msg_debug_cache_lambda("visiting node: %s (%d)", it->symbol.c_str(), cur_order);
+
+ for (const auto &dep: it->deps) {
+ msg_debug_cache_lambda("visiting dep: %s (%d)", dep.item->symbol.c_str(), cur_order + 1);
+ rec(dep.item, cur_order + 1, rec);
+ }
+
+ it->order = cur_order;
+ tsort_mark(it, tsort_mask::PERM);
+ };
+ /*
+ * Topological sort
+ */
+ total_hits = 0;
+ auto used_items = ord->d.size();
+
+ for (const auto &it: ord->d) {
+ if (it->order == 0) {
+ tsort_visit(it.get(), 0, tsort_visit);
+ }
+ }
+
+
+ /* Main sorting comparator */
+ constexpr auto score_functor = [](auto w, auto f, auto t) -> auto {
+ auto time_alpha = 1.0, weight_alpha = 0.1, freq_alpha = 0.01;
+
+ return ((w > 0.0 ? w : weight_alpha) * (f > 0.0 ? f : freq_alpha) /
+ (t > time_alpha ? t : time_alpha));
+ };
+
+ auto cache_order_cmp = [&](const auto &it1, const auto &it2) -> auto {
+ constexpr const auto topology_mult = 1e7,
+ priority_mult = 1e6,
+ augmentations1_mult = 1e5;
+ auto w1 = tsort_unmask(it1.get()) * topology_mult,
+ w2 = tsort_unmask(it2.get()) * topology_mult;
+
+ w1 += it1->priority * priority_mult;
+ w2 += it2->priority * priority_mult;
+ w1 += it1->get_augmentation_weight() * augmentations1_mult;
+ w2 += it2->get_augmentation_weight() * augmentations1_mult;
+
+ auto avg_freq = ((double) total_hits / used_items);
+ auto avg_weight = (total_weight / used_items);
+ auto f1 = (double) it1->st->total_hits / avg_freq;
+ auto f2 = (double) it2->st->total_hits / avg_freq;
+ auto weight1 = std::fabs(it1->st->weight) / avg_weight;
+ auto weight2 = std::fabs(it2->st->weight) / avg_weight;
+ auto t1 = it1->st->avg_time;
+ auto t2 = it2->st->avg_time;
+ w1 += score_functor(weight1, f1, t1);
+ w2 += score_functor(weight2, f2, t2);
+
+ return w1 > w2;
+ };
+
+ std::stable_sort(std::begin(ord->d), std::end(ord->d), cache_order_cmp);
+ /*
+ * Here lives some ugly legacy!
+ * We have several filters classes, connfilters, prefilters, filters... etc
+ *
+ * Our order is meaningful merely for filters, but we have to add other classes
+ * to understand if those symbols are checked or disabled.
+ * We can disable symbols for almost everything but not for virtual symbols.
+ * The rule of thumb is that if a symbol has explicit parent, then it is a
+ * virtual symbol that follows it's special rules
+ */
+
+ /*
+ * We enrich ord with all other symbol types without any sorting,
+ * as it is done in another place
+ */
+ constexpr auto append_items_vec = [](const auto &vec, auto &out) {
+ for (const auto &it: vec) {
+ if (it) {
+ out.emplace_back(it->getptr());
+ }
+ }
+ };
+
+ append_items_vec(connfilters, ord->d);
+ append_items_vec(prefilters, ord->d);
+ append_items_vec(postfilters, ord->d);
+ append_items_vec(idempotent, ord->d);
+ append_items_vec(composites, ord->d);
+ append_items_vec(classifiers, ord->d);
+
+ /* After sorting is done, we can assign all elements in the by_symbol hash */
+ for (const auto [i, it]: rspamd::enumerate(ord->d)) {
+ ord->by_symbol.emplace(it->get_name(), i);
+ ord->by_cache_id[it->id] = i;
+ }
+ /* Finally set the current order */
+ std::swap(ord, items_by_order);
+}
+
+auto symcache::add_symbol_with_callback(std::string_view name,
+ int priority,
+ symbol_func_t func,
+ void *user_data,
+ int flags_and_type) -> int
+{
+ auto real_type_pair_maybe = item_type_from_c(flags_and_type);
+
+ if (!real_type_pair_maybe.has_value()) {
+ msg_err_cache("incompatible flags when adding %s: %s", name.data(),
+ real_type_pair_maybe.error().c_str());
+ return -1;
+ }
+
+ auto real_type_pair = real_type_pair_maybe.value();
+
+ if (real_type_pair.first != symcache_item_type::FILTER) {
+ real_type_pair.second |= SYMBOL_TYPE_NOSTAT;
+ }
+ if (real_type_pair.second & (SYMBOL_TYPE_GHOST | SYMBOL_TYPE_CALLBACK)) {
+ real_type_pair.second |= SYMBOL_TYPE_NOSTAT;
+ }
+
+ if (real_type_pair.first == symcache_item_type::VIRTUAL) {
+ msg_err_cache("trying to add virtual symbol %s as real (no parent)", name.data());
+ return -1;
+ }
+
+ std::string static_string_name;
+
+ if (name.empty()) {
+ static_string_name = fmt::format("AUTO_{}_{}", (void *) func, user_data);
+ msg_warn_cache("trying to add an empty symbol name, convert it to %s",
+ static_string_name.c_str());
+ }
+ else {
+ static_string_name = name;
+ }
+
+ if (real_type_pair.first == symcache_item_type::IDEMPOTENT && priority != 0) {
+ msg_warn_cache("priority has been set for idempotent symbol %s: %d",
+ static_string_name.c_str(), priority);
+ }
+
+ if ((real_type_pair.second & SYMBOL_TYPE_FINE) && priority == 0) {
+ /* Adjust priority for negative weighted symbols */
+ priority = 1;
+ }
+
+ if (items_by_symbol.contains(static_string_name)) {
+ msg_err_cache("duplicate symbol name: %s", static_string_name.data());
+ return -1;
+ }
+
+ auto id = items_by_id.size();
+
+ auto item = cache_item::create_with_function(static_pool, id,
+ std::move(static_string_name),
+ priority, func, user_data,
+ real_type_pair.first, real_type_pair.second);
+
+ items_by_symbol.emplace(item->get_name(), item.get());
+ get_item_specific_vector(*item).push_back(item.get());
+ items_by_id.emplace(id, std::move(item));// Takes ownership
+
+ if (!(real_type_pair.second & SYMBOL_TYPE_NOSTAT)) {
+ cksum = t1ha(name.data(), name.size(), cksum);
+ stats_symbols_count++;
+ }
+
+ return id;
+}
+
+auto symcache::add_virtual_symbol(std::string_view name, int parent_id, int flags_and_type) -> int
+{
+ if (name.empty()) {
+ msg_err_cache("cannot register a virtual symbol with no name; qed");
+ return -1;
+ }
+
+ auto real_type_pair_maybe = item_type_from_c(flags_and_type);
+
+ if (!real_type_pair_maybe.has_value()) {
+ msg_err_cache("incompatible flags when adding %s: %s", name.data(),
+ real_type_pair_maybe.error().c_str());
+ return -1;
+ }
+
+ auto real_type_pair = real_type_pair_maybe.value();
+
+ if (items_by_symbol.contains(name)) {
+ msg_err_cache("duplicate symbol name: %s", name.data());
+ return -1;
+ }
+
+ if (items_by_id.size() < parent_id) {
+ msg_err_cache("parent id %d is out of bounds for virtual symbol %s", parent_id, name.data());
+ return -1;
+ }
+
+ auto id = items_by_id.size();
+
+ auto item = cache_item::create_with_virtual(static_pool,
+ id,
+ std::string{name},
+ parent_id, real_type_pair.first, real_type_pair.second);
+ const auto &parent = items_by_id[parent_id].get();
+ parent->add_child(item.get());
+ items_by_symbol.emplace(item->get_name(), item.get());
+ get_item_specific_vector(*item).push_back(item.get());
+ items_by_id.emplace(id, std::move(item));// Takes ownership
+
+ return id;
+}
+
+auto symcache::set_peak_cb(int cbref) -> void
+{
+ if (peak_cb != -1) {
+ luaL_unref(L, LUA_REGISTRYINDEX, peak_cb);
+ }
+
+ peak_cb = cbref;
+ msg_info_cache("registered peak callback");
+}
+
+auto symcache::add_delayed_condition(std::string_view sym, int cbref) -> void
+{
+ delayed_conditions->emplace_back(sym, cbref, (lua_State *) cfg->lua_state);
+}
+
+auto symcache::validate(bool strict) -> bool
+{
+ total_weight = 1.0;
+
+ for (auto &pair: items_by_symbol) {
+ auto &item = pair.second;
+ auto ghost = item->st->weight == 0 ? true : false;
+ auto skipped = !ghost;
+
+ if (item->is_scoreable() && g_hash_table_lookup(cfg->symbols, item->symbol.c_str()) == nullptr) {
+ if (!std::isnan(cfg->unknown_weight)) {
+ item->st->weight = cfg->unknown_weight;
+ auto *s = rspamd_mempool_alloc0_type(static_pool,
+ struct rspamd_symbol);
+ /* Legit as we actually never modify this data */
+ s->name = (char *) item->symbol.c_str();
+ s->weight_ptr = &item->st->weight;
+ g_hash_table_insert(cfg->symbols, (void *) s->name, (void *) s);
+
+ msg_info_cache("adding unknown symbol %s with weight: %.2f",
+ item->symbol.c_str(), cfg->unknown_weight);
+ ghost = false;
+ skipped = false;
+ }
+ else {
+ skipped = true;
+ }
+ }
+ else {
+ skipped = false;
+ }
+
+ if (!ghost && skipped) {
+ if (!(item->flags & SYMBOL_TYPE_SKIPPED)) {
+ item->flags |= SYMBOL_TYPE_SKIPPED;
+ msg_warn_cache("symbol %s has no score registered, skip its check",
+ item->symbol.c_str());
+ }
+ }
+
+ if (ghost) {
+ msg_debug_cache("symbol %s is registered as ghost symbol, it won't be inserted "
+ "to any metric",
+ item->symbol.c_str());
+ }
+
+ if (item->st->weight < 0 && item->priority == 0) {
+ item->priority++;
+ }
+
+ if (item->is_virtual()) {
+ if (!(item->flags & SYMBOL_TYPE_GHOST)) {
+ auto *parent = const_cast<cache_item *>(item->get_parent(*this));
+
+ if (parent == nullptr) {
+ item->resolve_parent(*this);
+ parent = const_cast<cache_item *>(item->get_parent(*this));
+ }
+
+ if (::fabs(parent->st->weight) < ::fabs(item->st->weight)) {
+ parent->st->weight = item->st->weight;
+ }
+
+ auto p1 = ::abs(item->priority);
+ auto p2 = ::abs(parent->priority);
+
+ if (p1 != p2) {
+ parent->priority = MAX(p1, p2);
+ item->priority = parent->priority;
+ }
+ }
+ }
+
+ total_weight += fabs(item->st->weight);
+ }
+
+ /* Now check each metric item and find corresponding symbol in a cache */
+ auto ret = true;
+ GHashTableIter it;
+ void *k, *v;
+ g_hash_table_iter_init(&it, cfg->symbols);
+
+ while (g_hash_table_iter_next(&it, &k, &v)) {
+ auto ignore_symbol = false;
+ auto sym_def = (struct rspamd_symbol *) v;
+
+ if (sym_def && (sym_def->flags &
+ (RSPAMD_SYMBOL_FLAG_IGNORE_METRIC | RSPAMD_SYMBOL_FLAG_DISABLED))) {
+ ignore_symbol = true;
+ }
+
+ if (!ignore_symbol) {
+ if (!items_by_symbol.contains((const char *) k)) {
+ msg_debug_cache(
+ "symbol '%s' has its score defined but there is no "
+ "corresponding rule registered",
+ k);
+ }
+ }
+ else if (sym_def->flags & RSPAMD_SYMBOL_FLAG_DISABLED) {
+ auto item = get_item_by_name_mut((const char *) k, false);
+
+ if (item) {
+ item->enabled = FALSE;
+ }
+ }
+ }
+
+ return ret;
+}
+
+auto symcache::counters() const -> ucl_object_t *
+{
+ auto *top = ucl_object_typed_new(UCL_ARRAY);
+ constexpr const auto round_float = [](const auto x, const int digits) -> auto {
+ const auto power10 = ::pow(10, digits);
+ return (::floor(x * power10) / power10);
+ };
+
+ for (auto &pair: items_by_symbol) {
+ auto &item = pair.second;
+ auto symbol = pair.first;
+
+ auto *obj = ucl_object_typed_new(UCL_OBJECT);
+ ucl_object_insert_key(obj, ucl_object_fromlstring(symbol.data(), symbol.size()),
+ "symbol", 0, false);
+
+ if (item->is_virtual()) {
+ if (!(item->flags & SYMBOL_TYPE_GHOST)) {
+ const auto *parent = item->get_parent(*this);
+ ucl_object_insert_key(obj,
+ ucl_object_fromdouble(round_float(item->st->weight, 3)),
+ "weight", 0, false);
+ ucl_object_insert_key(obj,
+ ucl_object_fromdouble(round_float(parent->st->avg_frequency, 3)),
+ "frequency", 0, false);
+ ucl_object_insert_key(obj,
+ ucl_object_fromint(parent->st->total_hits),
+ "hits", 0, false);
+ ucl_object_insert_key(obj,
+ ucl_object_fromdouble(round_float(parent->st->avg_time, 3)),
+ "time", 0, false);
+ }
+ else {
+ ucl_object_insert_key(obj,
+ ucl_object_fromdouble(round_float(item->st->weight, 3)),
+ "weight", 0, false);
+ ucl_object_insert_key(obj,
+ ucl_object_fromdouble(0.0),
+ "frequency", 0, false);
+ ucl_object_insert_key(obj,
+ ucl_object_fromdouble(0.0),
+ "hits", 0, false);
+ ucl_object_insert_key(obj,
+ ucl_object_fromdouble(0.0),
+ "time", 0, false);
+ }
+ }
+ else {
+ ucl_object_insert_key(obj,
+ ucl_object_fromdouble(round_float(item->st->weight, 3)),
+ "weight", 0, false);
+ ucl_object_insert_key(obj,
+ ucl_object_fromdouble(round_float(item->st->avg_frequency, 3)),
+ "frequency", 0, false);
+ ucl_object_insert_key(obj,
+ ucl_object_fromint(item->st->total_hits),
+ "hits", 0, false);
+ ucl_object_insert_key(obj,
+ ucl_object_fromdouble(round_float(item->st->avg_time, 3)),
+ "time", 0, false);
+ }
+
+ ucl_array_append(top, obj);
+ }
+
+ return top;
+}
+
+auto symcache::periodic_resort(struct ev_loop *ev_loop, double cur_time, double last_resort) -> void
+{
+ for (const auto &item: filters) {
+
+ if (item->update_counters_check_peak(L, ev_loop, cur_time, last_resort)) {
+ auto cur_value = (item->st->total_hits - item->last_count) /
+ (cur_time - last_resort);
+ auto cur_err = (item->st->avg_frequency - cur_value);
+ cur_err *= cur_err;
+ msg_debug_cache("peak found for %s is %.2f, avg: %.2f, "
+ "stddev: %.2f, error: %.2f, peaks: %d",
+ item->symbol.c_str(), cur_value,
+ item->st->avg_frequency,
+ item->st->stddev_frequency,
+ cur_err,
+ item->frequency_peaks);
+
+ if (peak_cb != -1) {
+ struct ev_loop **pbase;
+
+ lua_rawgeti(L, LUA_REGISTRYINDEX, peak_cb);
+ pbase = (struct ev_loop **) lua_newuserdata(L, sizeof(*pbase));
+ *pbase = ev_loop;
+ rspamd_lua_setclass(L, "rspamd{ev_base}", -1);
+ lua_pushlstring(L, item->symbol.c_str(), item->symbol.size());
+ lua_pushnumber(L, item->st->avg_frequency);
+ lua_pushnumber(L, ::sqrt(item->st->stddev_frequency));
+ lua_pushnumber(L, cur_value);
+ lua_pushnumber(L, cur_err);
+
+ if (lua_pcall(L, 6, 0, 0) != 0) {
+ msg_info_cache("call to peak function for %s failed: %s",
+ item->symbol.c_str(), lua_tostring(L, -1));
+ lua_pop(L, 1);
+ }
+ }
+ }
+ }
+}
+
+symcache::~symcache()
+{
+ if (peak_cb != -1) {
+ luaL_unref(L, LUA_REGISTRYINDEX, peak_cb);
+ }
+}
+
+auto symcache::maybe_resort() -> bool
+{
+ if (items_by_order->generation_id != cur_order_gen) {
+ /*
+ * Cache has been modified, need to resort it
+ */
+ msg_info_cache("symbols cache has been modified since last check:"
+ " old id: %ud, new id: %ud",
+ items_by_order->generation_id, cur_order_gen);
+ resort();
+
+ return true;
+ }
+
+ return false;
+}
+
+auto symcache::get_item_specific_vector(const cache_item &it) -> symcache::items_ptr_vec &
+{
+ switch (it.get_type()) {
+ case symcache_item_type::CONNFILTER:
+ return connfilters;
+ case symcache_item_type::FILTER:
+ return filters;
+ case symcache_item_type::IDEMPOTENT:
+ return idempotent;
+ case symcache_item_type::PREFILTER:
+ return prefilters;
+ case symcache_item_type::POSTFILTER:
+ return postfilters;
+ case symcache_item_type::COMPOSITE:
+ return composites;
+ case symcache_item_type::CLASSIFIER:
+ return classifiers;
+ case symcache_item_type::VIRTUAL:
+ return virtual_symbols;
+ }
+
+ RSPAMD_UNREACHABLE;
+}
+
+auto symcache::process_settings_elt(struct rspamd_config_settings_elt *elt) -> void
+{
+
+ auto id = elt->id;
+
+ if (elt->symbols_disabled) {
+ /* Process denied symbols */
+ ucl_object_iter_t iter = nullptr;
+ const ucl_object_t *cur;
+
+ while ((cur = ucl_object_iterate(elt->symbols_disabled, &iter, true)) != NULL) {
+ const auto *sym = ucl_object_key(cur);
+ auto *item = get_item_by_name_mut(sym, false);
+
+ if (item != nullptr) {
+ if (item->is_virtual()) {
+ /*
+ * Virtual symbols are special:
+ * we ignore them in symcache but prevent them from being
+ * inserted.
+ */
+ item->forbidden_ids.add_id(id);
+ msg_debug_cache("deny virtual symbol %s for settings %ud (%s); "
+ "parent can still be executed",
+ sym, id, elt->name);
+ }
+ else {
+ /* Normal symbol, disable it */
+ item->forbidden_ids.add_id(id);
+ msg_debug_cache("deny symbol %s for settings %ud (%s)",
+ sym, id, elt->name);
+ }
+ }
+ else {
+ msg_warn_cache("cannot find a symbol to disable %s "
+ "when processing settings %ud (%s)",
+ sym, id, elt->name);
+ }
+ }
+ }
+
+ if (elt->symbols_enabled) {
+ ucl_object_iter_t iter = nullptr;
+ const ucl_object_t *cur;
+
+ while ((cur = ucl_object_iterate(elt->symbols_enabled, &iter, true)) != nullptr) {
+ /* Here, we resolve parent and explicitly allow it */
+ const auto *sym = ucl_object_key(cur);
+
+ auto *item = get_item_by_name_mut(sym, false);
+
+ if (item != nullptr) {
+ if (item->is_virtual()) {
+ auto *parent = get_item_by_name_mut(sym, true);
+
+ if (parent) {
+ if (elt->symbols_disabled &&
+ ucl_object_lookup(elt->symbols_disabled, parent->symbol.data())) {
+ msg_err_cache("conflict in %s: cannot enable disabled symbol %s, "
+ "wanted to enable symbol %s",
+ elt->name, parent->symbol.data(), sym);
+ continue;
+ }
+
+ parent->exec_only_ids.add_id(id);
+ msg_debug_cache("allow just execution of symbol %s for settings %ud (%s)",
+ parent->symbol.data(), id, elt->name);
+ }
+ }
+
+ item->allowed_ids.add_id(id);
+ msg_debug_cache("allow execution of symbol %s for settings %ud (%s)",
+ sym, id, elt->name);
+ }
+ else {
+ msg_warn_cache("cannot find a symbol to enable %s "
+ "when processing settings %ud (%s)",
+ sym, id, elt->name);
+ }
+ }
+ }
+}
+
+auto symcache::get_max_timeout(std::vector<std::pair<double, const cache_item *>> &elts) const -> double
+{
+ auto accumulated_timeout = 0.0;
+ auto log_func = RSPAMD_LOG_FUNC;
+ ankerl::unordered_dense::set<const cache_item *> seen_items;
+
+ auto get_item_timeout = [](cache_item *it) {
+ return it->get_numeric_augmentation("timeout").value_or(0.0);
+ };
+
+ /* This function returns the timeout for an item and all it's dependencies */
+ auto get_filter_timeout = [&](cache_item *it, auto self) -> double {
+ auto own_timeout = get_item_timeout(it);
+ auto max_child_timeout = 0.0;
+
+ for (const auto &dep: it->deps) {
+ auto cld_timeout = self(dep.item, self);
+
+ if (cld_timeout > max_child_timeout) {
+ max_child_timeout = cld_timeout;
+ }
+ }
+
+ return own_timeout + max_child_timeout;
+ };
+
+ /* For prefilters and postfilters, we just care about priorities */
+ auto pre_postfilter_iter = [&](const items_ptr_vec &vec) -> double {
+ auto saved_priority = -1;
+ auto max_timeout = 0.0, added_timeout = 0.0;
+ const cache_item *max_elt = nullptr;
+ for (const auto &it: vec) {
+ if (it->priority != saved_priority && max_elt != nullptr && max_timeout > 0) {
+ if (!seen_items.contains(max_elt)) {
+ accumulated_timeout += max_timeout;
+ added_timeout += max_timeout;
+
+ msg_debug_cache_lambda("added %.2f to the timeout (%.2f) as the priority has changed (%d -> %d); "
+ "symbol: %s",
+ max_timeout, accumulated_timeout, saved_priority, it->priority,
+ max_elt->symbol.c_str());
+ elts.emplace_back(max_timeout, max_elt);
+ seen_items.insert(max_elt);
+ }
+ max_timeout = 0;
+ saved_priority = it->priority;
+ max_elt = nullptr;
+ }
+
+ auto timeout = get_item_timeout(it);
+
+ if (timeout > max_timeout) {
+ max_timeout = timeout;
+ max_elt = it;
+ }
+ }
+
+ if (max_elt != nullptr && max_timeout > 0) {
+ if (!seen_items.contains(max_elt)) {
+ accumulated_timeout += max_timeout;
+ added_timeout += max_timeout;
+
+ msg_debug_cache_lambda("added %.2f to the timeout (%.2f) end of processing; "
+ "symbol: %s",
+ max_timeout, accumulated_timeout,
+ max_elt->symbol.c_str());
+ elts.emplace_back(max_timeout, max_elt);
+ seen_items.insert(max_elt);
+ }
+ }
+
+ return added_timeout;
+ };
+
+ auto prefilters_timeout = pre_postfilter_iter(this->prefilters);
+
+ /* For normal filters, we check the maximum chain of the dependencies
+ * This function might have O(N^2) complexity if all symbols are in a single
+ * dependencies chain. But it is not the case in practice
+ */
+ double max_filters_timeout = 0;
+ for (const auto &it: this->filters) {
+ auto timeout = get_filter_timeout(it, get_filter_timeout);
+
+ if (timeout > max_filters_timeout) {
+ max_filters_timeout = timeout;
+ if (!seen_items.contains(it)) {
+ elts.emplace_back(timeout, it);
+ seen_items.insert(it);
+ }
+ }
+ }
+
+ accumulated_timeout += max_filters_timeout;
+
+ auto postfilters_timeout = pre_postfilter_iter(this->postfilters);
+ auto idempotent_timeout = pre_postfilter_iter(this->idempotent);
+
+ /* Sort in decreasing order by timeout */
+ std::stable_sort(std::begin(elts), std::end(elts),
+ [](const auto &p1, const auto &p2) {
+ return p1.first > p2.first;
+ });
+
+ msg_debug_cache("overall cache timeout: %.2f, %.2f from prefilters,"
+ " %.2f from postfilters, %.2f from idempotent filters,"
+ " %.2f from normal filters",
+ accumulated_timeout, prefilters_timeout, postfilters_timeout,
+ idempotent_timeout, max_filters_timeout);
+
+ return accumulated_timeout;
+}
+
+}// namespace rspamd::symcache \ No newline at end of file