diff options
Diffstat (limited to '')
-rw-r--r-- | src/libserver/symcache/symcache_item.cxx | 652 |
1 files changed, 652 insertions, 0 deletions
diff --git a/src/libserver/symcache/symcache_item.cxx b/src/libserver/symcache/symcache_item.cxx new file mode 100644 index 0000000..ac901f5 --- /dev/null +++ b/src/libserver/symcache/symcache_item.cxx @@ -0,0 +1,652 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lua/lua_common.h" +#include "symcache_internal.hxx" +#include "symcache_item.hxx" +#include "fmt/core.h" +#include "libserver/task.h" +#include "libutil/cxx/util.hxx" +#include <numeric> +#include <functional> + +namespace rspamd::symcache { + +enum class augmentation_value_type { + NO_VALUE, + STRING_VALUE, + NUMBER_VALUE, +}; + +struct augmentation_info { + int weight = 0; + int implied_flags = 0; + augmentation_value_type value_type = augmentation_value_type::NO_VALUE; +}; + +/* A list of internal augmentations that are known to Rspamd with their weight */ +static const auto known_augmentations = + ankerl::unordered_dense::map<std::string, augmentation_info, rspamd::smart_str_hash, rspamd::smart_str_equal>{ + {"passthrough", {.weight = 10, .implied_flags = SYMBOL_TYPE_IGNORE_PASSTHROUGH}}, + {"single_network", {.weight = 1, .implied_flags = 0}}, + {"no_network", {.weight = 0, .implied_flags = 0}}, + {"many_network", {.weight = 1, .implied_flags = 0}}, + {"important", {.weight = 5, .implied_flags = SYMBOL_TYPE_FINE}}, + {"timeout", { + .weight = 0, + .implied_flags = 0, + .value_type = augmentation_value_type::NUMBER_VALUE, + }}}; + +auto cache_item::get_parent(const symcache &cache) const -> const cache_item * +{ + if (is_virtual()) { + const auto &virtual_sp = std::get<virtual_item>(specific); + + return virtual_sp.get_parent(cache); + } + + return nullptr; +} + +auto cache_item::get_parent_mut(const symcache &cache) -> cache_item * +{ + if (is_virtual()) { + auto &virtual_sp = std::get<virtual_item>(specific); + + return virtual_sp.get_parent_mut(cache); + } + + return nullptr; +} + +auto cache_item::process_deps(const symcache &cache) -> void +{ + /* Allow logging macros to work */ + auto log_tag = [&]() { return cache.log_tag(); }; + + for (auto &dep: deps) { + msg_debug_cache("process real dependency %s on %s", symbol.c_str(), dep.sym.c_str()); + auto *dit = cache.get_item_by_name_mut(dep.sym, true); + + if (dep.vid >= 0) { + /* Case of the virtual symbol that depends on another (maybe virtual) symbol */ + const auto *vdit = cache.get_item_by_name(dep.sym, false); + + if (!vdit) { + if (dit) { + msg_err_cache("cannot add dependency from %s on %s: no dependency symbol registered", + dep.sym.c_str(), dit->symbol.c_str()); + } + } + else { + msg_debug_cache("process virtual dependency %s(%d) on %s(%d)", symbol.c_str(), + dep.vid, vdit->symbol.c_str(), vdit->id); + + unsigned nids = 0; + + /* Propagate ids */ + msg_debug_cache("check id propagation for dependency %s from %s", + symbol.c_str(), dit->symbol.c_str()); + + const auto *ids = dit->allowed_ids.get_ids(nids); + + if (nids > 0) { + msg_debug_cache("propagate allowed ids from %s to %s", + dit->symbol.c_str(), symbol.c_str()); + + allowed_ids.set_ids(ids, nids); + } + + ids = dit->forbidden_ids.get_ids(nids); + + if (nids > 0) { + msg_debug_cache("propagate forbidden ids from %s to %s", + dit->symbol.c_str(), symbol.c_str()); + + forbidden_ids.set_ids(ids, nids); + } + } + } + + if (dit != nullptr) { + if (!dit->is_filter()) { + /* + * Check sanity: + * - filters -> prefilter dependency is OK and always satisfied + * - postfilter -> (filter, prefilter) dep is ok + * - idempotent -> (any) dep is OK + * + * Otherwise, emit error + * However, even if everything is fine this dep is useless ¯\_(ツ)_/¯ + */ + auto ok_dep = false; + + if (dit->get_type() == type) { + ok_dep = true; + } + else if (type < dit->get_type()) { + ok_dep = true; + } + + if (!ok_dep) { + msg_err_cache("cannot add dependency from %s on %s: invalid symbol types", + dep.sym.c_str(), symbol.c_str()); + + continue; + } + } + else { + if (dit->id == id) { + msg_err_cache("cannot add dependency on self: %s -> %s " + "(resolved to %s)", + symbol.c_str(), dep.sym.c_str(), dit->symbol.c_str()); + } + else { + /* Create a reverse dep */ + if (is_virtual()) { + auto *parent = get_parent_mut(cache); + + if (parent) { + dit->rdeps.emplace_back(parent, parent->symbol, parent->id, -1); + dep.item = dit; + dep.id = dit->id; + + msg_debug_cache("added reverse dependency from %d on %d", parent->id, + dit->id); + } + } + else { + dep.item = dit; + dep.id = dit->id; + dit->rdeps.emplace_back(this, symbol, id, -1); + msg_debug_cache("added reverse dependency from %d on %d", id, + dit->id); + } + } + } + } + else if (dep.id >= 0) { + msg_err_cache("cannot find dependency on symbol %s for symbol %s", + dep.sym.c_str(), symbol.c_str()); + + continue; + } + } + + // Remove empty deps + deps.erase(std::remove_if(std::begin(deps), std::end(deps), + [](const auto &dep) { return !dep.item; }), + std::end(deps)); +} + +auto cache_item::resolve_parent(const symcache &cache) -> bool +{ + auto log_tag = [&]() { return cache.log_tag(); }; + + if (is_virtual()) { + auto &virt = std::get<virtual_item>(specific); + + if (virt.get_parent(cache)) { + msg_debug_cache("trying to resolve parent twice for %s", symbol.c_str()); + + return false; + } + + return virt.resolve_parent(cache); + } + else { + msg_warn_cache("trying to resolve a parent for non-virtual symbol %s", symbol.c_str()); + } + + return false; +} + +auto cache_item::update_counters_check_peak(lua_State *L, + struct ev_loop *ev_loop, + double cur_time, + double last_resort) -> bool +{ + auto ret = false; + static const double decay_rate = 0.25; + + st->total_hits += st->hits; + g_atomic_int_set(&st->hits, 0); + + if (last_count > 0) { + auto cur_value = (st->total_hits - last_count) / + (cur_time - last_resort); + rspamd_set_counter_ema(&st->frequency_counter, + cur_value, decay_rate); + st->avg_frequency = st->frequency_counter.mean; + st->stddev_frequency = st->frequency_counter.stddev; + + auto cur_err = (st->avg_frequency - cur_value); + cur_err *= cur_err; + + if (st->frequency_counter.number > 10 && + cur_err > ::sqrt(st->stddev_frequency) * 3) { + frequency_peaks++; + ret = true; + } + } + + last_count = st->total_hits; + + if (cd->number > 0) { + if (!is_virtual()) { + st->avg_time = cd->mean; + rspamd_set_counter_ema(&st->time_counter, + st->avg_time, decay_rate); + st->avg_time = st->time_counter.mean; + memset(cd, 0, sizeof(*cd)); + } + } + + return ret; +} + +auto cache_item::inc_frequency(const char *sym_name, symcache &cache) -> void +{ + if (sym_name && symbol != sym_name) { + if (is_filter()) { + const auto *children = get_children(); + if (children) { + /* Likely a callback symbol with some virtual symbol that needs to be adjusted */ + for (const auto &cld: *children) { + if (cld->get_name() == sym_name) { + cld->inc_frequency(sym_name, cache); + } + } + } + } + else { + /* Name not equal to symbol name, so we need to find the proper name */ + auto *another_item = cache.get_item_by_name_mut(sym_name, false); + if (another_item != nullptr) { + another_item->inc_frequency(sym_name, cache); + } + } + } + else { + /* Symbol and sym name are the same */ + g_atomic_int_inc(&st->hits); + } +} + +auto cache_item::get_type_str() const -> const char * +{ + switch (type) { + case symcache_item_type::CONNFILTER: + return "connfilter"; + case symcache_item_type::FILTER: + return "filter"; + case symcache_item_type::IDEMPOTENT: + return "idempotent"; + case symcache_item_type::PREFILTER: + return "prefilter"; + case symcache_item_type::POSTFILTER: + return "postfilter"; + case symcache_item_type::COMPOSITE: + return "composite"; + case symcache_item_type::CLASSIFIER: + return "classifier"; + case symcache_item_type::VIRTUAL: + return "virtual"; + } + + RSPAMD_UNREACHABLE; +} + +auto cache_item::is_allowed(struct rspamd_task *task, bool exec_only) const -> bool +{ + const auto *what = "execution"; + + if (!exec_only) { + what = "symbol insertion"; + } + + /* Static checks */ + if (!enabled || + (RSPAMD_TASK_IS_EMPTY(task) && !(flags & SYMBOL_TYPE_EMPTY)) || + (flags & SYMBOL_TYPE_MIME_ONLY && !RSPAMD_TASK_IS_MIME(task))) { + + if (!enabled) { + msg_debug_cache_task("skipping %s of %s as it is permanently disabled", + what, symbol.c_str()); + + return false; + } + else { + /* + * If we check merely execution (not insertion), then we disallow + * mime symbols for non mime tasks and vice versa + */ + if (exec_only) { + msg_debug_cache_task("skipping check of %s as it cannot be " + "executed for this task type", + symbol.c_str()); + + return FALSE; + } + } + } + + /* Settings checks */ + if (task->settings_elt != nullptr) { + if (forbidden_ids.check_id(task->settings_elt->id)) { + msg_debug_cache_task("deny %s of %s as it is forbidden for " + "settings id %ud", + what, + symbol.c_str(), + task->settings_elt->id); + + return false; + } + + if (!(flags & SYMBOL_TYPE_EXPLICIT_DISABLE)) { + if (!allowed_ids.check_id(task->settings_elt->id)) { + + if (task->settings_elt->policy == RSPAMD_SETTINGS_POLICY_IMPLICIT_ALLOW) { + msg_debug_cache_task("allow execution of %s settings id %ud " + "allows implicit execution of the symbols;", + symbol.c_str(), + id); + + return true; + } + + if (exec_only) { + /* + * Special case if any of our virtual children are enabled + */ + if (exec_only_ids.check_id(task->settings_elt->id)) { + return true; + } + } + + msg_debug_cache_task("deny %s of %s as it is not listed " + "as allowed for settings id %ud", + what, + symbol.c_str(), + task->settings_elt->id); + return false; + } + } + else { + msg_debug_cache_task("allow %s of %s for " + "settings id %ud as it can be only disabled explicitly", + what, + symbol.c_str(), + task->settings_elt->id); + } + } + else if (flags & SYMBOL_TYPE_EXPLICIT_ENABLE) { + msg_debug_cache_task("deny %s of %s as it must be explicitly enabled", + what, + symbol.c_str()); + return false; + } + + /* Allow all symbols with no settings id */ + return true; +} + +auto cache_item::add_augmentation(const symcache &cache, std::string_view augmentation, + std::optional<std::string_view> value) -> bool +{ + auto log_tag = [&]() { return cache.log_tag(); }; + + if (augmentations.contains(augmentation)) { + msg_warn_cache("duplicate augmentation: %s", augmentation.data()); + + return false; + } + + auto maybe_known = rspamd::find_map(known_augmentations, augmentation); + + if (maybe_known.has_value()) { + auto &known_info = maybe_known.value().get(); + + if (known_info.implied_flags) { + if ((known_info.implied_flags & flags) == 0) { + msg_info_cache("added implied flags (%bd) for symbol %s as it has %s augmentation", + known_info.implied_flags, symbol.data(), augmentation.data()); + flags |= known_info.implied_flags; + } + } + + if (known_info.value_type == augmentation_value_type::NO_VALUE) { + if (value.has_value()) { + msg_err_cache("value specified for augmentation %s, that has no value", + augmentation.data()); + + return false; + } + return augmentations.try_emplace(augmentation, known_info.weight).second; + } + else { + if (!value.has_value()) { + msg_err_cache("value is not specified for augmentation %s, that requires explicit value", + augmentation.data()); + + return false; + } + + if (known_info.value_type == augmentation_value_type::STRING_VALUE) { + return augmentations.try_emplace(augmentation, std::string{value.value()}, + known_info.weight) + .second; + } + else if (known_info.value_type == augmentation_value_type::NUMBER_VALUE) { + /* I wish it was supported properly */ + //auto conv_res = std::from_chars(value->data(), value->size(), num); + char numbuf[128], *endptr = nullptr; + rspamd_strlcpy(numbuf, value->data(), MIN(value->size(), sizeof(numbuf))); + auto num = g_ascii_strtod(numbuf, &endptr); + + if (fabs(num) >= G_MAXFLOAT || std::isnan(num)) { + msg_err_cache("value for augmentation %s is not numeric: %*s", + augmentation.data(), + (int) value->size(), value->data()); + return false; + } + + return augmentations.try_emplace(augmentation, num, + known_info.weight) + .second; + } + } + } + else { + msg_debug_cache("added unknown augmentation %s for symbol %s", + "unknown", augmentation.data(), symbol.data()); + return augmentations.try_emplace(augmentation, 0).second; + } + + // Should not be reached + return false; +} + +auto cache_item::get_augmentation_weight() const -> int +{ + return std::accumulate(std::begin(augmentations), std::end(augmentations), + 0, [](int acc, const auto &map_pair) { + return acc + map_pair.second.weight; + }); +} + +auto cache_item::get_numeric_augmentation(std::string_view name) const -> std::optional<double> +{ + const auto augmentation_value_maybe = rspamd::find_map(this->augmentations, name); + + if (augmentation_value_maybe.has_value()) { + const auto &augmentation = augmentation_value_maybe.value().get(); + + if (std::holds_alternative<double>(augmentation.value)) { + return std::get<double>(augmentation.value); + } + } + + return std::nullopt; +} + + +auto virtual_item::get_parent(const symcache &cache) const -> const cache_item * +{ + if (parent) { + return parent; + } + + return cache.get_item_by_id(parent_id, false); +} + +auto virtual_item::get_parent_mut(const symcache &cache) -> cache_item * +{ + if (parent) { + return parent; + } + + return const_cast<cache_item *>(cache.get_item_by_id(parent_id, false)); +} + +auto virtual_item::resolve_parent(const symcache &cache) -> bool +{ + if (parent) { + return false; + } + + auto item_ptr = cache.get_item_by_id(parent_id, true); + + if (item_ptr) { + parent = const_cast<cache_item *>(item_ptr); + + return true; + } + + return false; +} + +auto item_type_from_c(int type) -> tl::expected<std::pair<symcache_item_type, int>, std::string> +{ + constexpr const auto trivial_types = SYMBOL_TYPE_CONNFILTER | SYMBOL_TYPE_PREFILTER | SYMBOL_TYPE_POSTFILTER | SYMBOL_TYPE_IDEMPOTENT | SYMBOL_TYPE_COMPOSITE | SYMBOL_TYPE_CLASSIFIER | SYMBOL_TYPE_VIRTUAL; + + constexpr auto all_but_one_ty = [&](int type, int exclude_bit) -> auto { + return (type & trivial_types) & (trivial_types & ~exclude_bit); + }; + + if (type & trivial_types) { + auto check_trivial = [&](auto flag, + symcache_item_type ty) -> tl::expected<std::pair<symcache_item_type, int>, std::string> { + if (all_but_one_ty(type, flag)) { + return tl::make_unexpected(fmt::format("invalid flags for a symbol: {}", (int) type)); + } + + return std::make_pair(ty, type & ~flag); + }; + if (type & SYMBOL_TYPE_CONNFILTER) { + return check_trivial(SYMBOL_TYPE_CONNFILTER, symcache_item_type::CONNFILTER); + } + else if (type & SYMBOL_TYPE_PREFILTER) { + return check_trivial(SYMBOL_TYPE_PREFILTER, symcache_item_type::PREFILTER); + } + else if (type & SYMBOL_TYPE_POSTFILTER) { + return check_trivial(SYMBOL_TYPE_POSTFILTER, symcache_item_type::POSTFILTER); + } + else if (type & SYMBOL_TYPE_IDEMPOTENT) { + return check_trivial(SYMBOL_TYPE_IDEMPOTENT, symcache_item_type::IDEMPOTENT); + } + else if (type & SYMBOL_TYPE_COMPOSITE) { + return check_trivial(SYMBOL_TYPE_COMPOSITE, symcache_item_type::COMPOSITE); + } + else if (type & SYMBOL_TYPE_CLASSIFIER) { + return check_trivial(SYMBOL_TYPE_CLASSIFIER, symcache_item_type::CLASSIFIER); + } + else if (type & SYMBOL_TYPE_VIRTUAL) { + return check_trivial(SYMBOL_TYPE_VIRTUAL, symcache_item_type::VIRTUAL); + } + + return tl::make_unexpected(fmt::format("internal error: impossible flags combination: {}", (int) type)); + } + + /* Maybe check other flags combination here? */ + return std::make_pair(symcache_item_type::FILTER, type); +} + +bool operator<(symcache_item_type lhs, symcache_item_type rhs) +{ + auto ret = false; + switch (lhs) { + case symcache_item_type::CONNFILTER: + break; + case symcache_item_type::PREFILTER: + if (rhs == symcache_item_type::CONNFILTER) { + ret = true; + } + break; + case symcache_item_type::FILTER: + if (rhs == symcache_item_type::CONNFILTER || rhs == symcache_item_type::PREFILTER) { + ret = true; + } + break; + case symcache_item_type::POSTFILTER: + if (rhs != symcache_item_type::IDEMPOTENT) { + ret = true; + } + break; + case symcache_item_type::IDEMPOTENT: + default: + break; + } + + return ret; +} + +item_condition::~item_condition() +{ + if (cb != -1 && L != nullptr) { + luaL_unref(L, LUA_REGISTRYINDEX, cb); + } +} + +auto item_condition::check(std::string_view sym_name, struct rspamd_task *task) const -> bool +{ + if (cb != -1 && L != nullptr) { + auto ret = false; + + lua_pushcfunction(L, &rspamd_lua_traceback); + auto err_idx = lua_gettop(L); + + lua_rawgeti(L, LUA_REGISTRYINDEX, cb); + rspamd_lua_task_push(L, task); + + if (lua_pcall(L, 1, 1, err_idx) != 0) { + msg_info_task("call to condition for %s failed: %s", + sym_name.data(), lua_tostring(L, -1)); + } + else { + ret = lua_toboolean(L, -1); + } + + lua_settop(L, err_idx - 1); + + return ret; + } + + return true; +} + +}// namespace rspamd::symcache |