diff options
Diffstat (limited to 'src/libmime/lang_detection_fasttext.cxx')
-rw-r--r-- | src/libmime/lang_detection_fasttext.cxx | 269 |
1 files changed, 269 insertions, 0 deletions
diff --git a/src/libmime/lang_detection_fasttext.cxx b/src/libmime/lang_detection_fasttext.cxx new file mode 100644 index 0000000..c973ed7 --- /dev/null +++ b/src/libmime/lang_detection_fasttext.cxx @@ -0,0 +1,269 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lang_detection_fasttext.h" + +#ifdef WITH_FASTTEXT +#include "fasttext/fasttext.h" +#include "libserver/cfg_file.h" +#include "libserver/logger.h" +#include "fmt/core.h" +#include "stat_api.h" +#include <exception> +#include <string_view> +#include <vector> +#endif + +#ifdef WITH_FASTTEXT + +EXTERN_LOG_MODULE_DEF(langdet); +#define msg_debug_lang_det(...) rspamd_conditional_debug_fast(nullptr, nullptr, \ + rspamd_langdet_log_id, "langdet", task->task_pool->tag.uid, \ + __FUNCTION__, \ + __VA_ARGS__) + +namespace rspamd::langdet { +class fasttext_langdet { +private: + fasttext::FastText ft; + std::string model_fname; + bool loaded = false; + +public: + explicit fasttext_langdet(struct rspamd_config *cfg) + { + const auto *ucl_obj = cfg->cfg_ucl_obj; + const auto *opts_section = ucl_object_find_key(ucl_obj, "lang_detection"); + + if (opts_section) { + const auto *model = ucl_object_find_key(opts_section, "fasttext_model"); + + if (model) { + try { + ft.loadModel(ucl_object_tostring(model)); + loaded = true; + model_fname = std::string{ucl_object_tostring(model)}; + } catch (std::exception &e) { + auto err_message = fmt::format("cannot load fasttext model: {}", e.what()); + msg_err_config("%s", err_message.c_str()); + loaded = false; + } + } + } + } + + /* Disallow multiple initialisation */ + fasttext_langdet() = delete; + fasttext_langdet(const fasttext_langdet &) = delete; + fasttext_langdet(fasttext_langdet &&) = delete; + + ~fasttext_langdet() = default; + + auto is_enabled() const -> bool + { + return loaded; + } + auto word2vec(const char *in, std::size_t len, std::vector<std::int32_t> &word_ngramms) const + { + if (!loaded) { + return; + } + + std::string tok{in, len}; + const auto &dic = ft.getDictionary(); + auto h = dic->hash(tok); + auto wid = dic->getId(tok, h); + auto type = wid < 0 ? dic->getType(tok) : dic->getType(wid); + + if (type == fasttext::entry_type::word) { + if (wid < 0) { + auto pipelined_word = fmt::format("{}{}{}", fasttext::Dictionary::BOW, tok, fasttext::Dictionary::EOW); + dic->computeSubwords(pipelined_word, word_ngramms); + } + else { + if (ft.getArgs().maxn <= 0) { + word_ngramms.push_back(wid); + } + else { + const auto ngrams = dic->getSubwords(wid); + word_ngramms.insert(word_ngramms.end(), ngrams.cbegin(), ngrams.cend()); + } + } + } + } + auto detect_language(std::vector<std::int32_t> &words, int k) + -> std::vector<std::pair<fasttext::real, std::string>> * + { + if (!loaded) { + return nullptr; + } + + auto predictions = new std::vector<std::pair<fasttext::real, std::string>>; + predictions->reserve(k); + fasttext::Predictions line_predictions; + line_predictions.reserve(k); + ft.predict(k, words, line_predictions, 0.0f); + const auto *dict = ft.getDictionary().get(); + + for (const auto &pred: line_predictions) { + predictions->push_back(std::make_pair(std::exp(pred.first), dict->getLabel(pred.second))); + } + return predictions; + } + + auto model_info(void) const -> const std::string + { + if (!loaded) { + static const auto not_loaded = std::string{"fasttext model is not loaded"}; + return not_loaded; + } + else { + return fmt::format("fasttext model {}: {} languages, {} tokens", model_fname, + ft.getDictionary()->nlabels(), ft.getDictionary()->ntokens()); + } + } +}; +}// namespace rspamd::langdet +#endif + +/* C API part */ +G_BEGIN_DECLS + +#define FASTTEXT_MODEL_TO_C_API(p) reinterpret_cast<rspamd::langdet::fasttext_langdet *>(p) +#define FASTTEXT_RESULT_TO_C_API(res) reinterpret_cast<std::vector<std::pair<fasttext::real, std::string>> *>(res) + +void *rspamd_lang_detection_fasttext_init(struct rspamd_config *cfg) +{ +#ifndef WITH_FASTTEXT + return nullptr; +#else + return (void *) new rspamd::langdet::fasttext_langdet(cfg); +#endif +} + +char *rspamd_lang_detection_fasttext_show_info(void *ud) +{ +#ifndef WITH_FASTTEXT + return g_strdup("fasttext is not compiled in"); +#else + auto model_info = FASTTEXT_MODEL_TO_C_API(ud)->model_info(); + + return g_strdup(model_info.c_str()); +#endif +} + +bool rspamd_lang_detection_fasttext_is_enabled(void *ud) +{ +#ifdef WITH_FASTTEXT + auto *real_model = FASTTEXT_MODEL_TO_C_API(ud); + + if (real_model) { + return real_model->is_enabled(); + } +#endif + + return false; +} + +rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud, + struct rspamd_task *task, + GArray *utf_words, + int k) +{ +#ifndef WITH_FASTTEXT + return nullptr; +#else + /* Avoid too long inputs */ + static const guint max_fasttext_input_len = 1024 * 1024; + auto *real_model = FASTTEXT_MODEL_TO_C_API(ud); + std::vector<std::int32_t> words_vec; + words_vec.reserve(utf_words->len); + + for (auto i = 0; i < std::min(utf_words->len, max_fasttext_input_len); i++) { + const auto *w = &g_array_index(utf_words, rspamd_stat_token_t, i); + if (w->original.len > 0) { + real_model->word2vec(w->original.begin, w->original.len, words_vec); + } + } + + msg_debug_lang_det("fasttext: got %z word tokens from %ud words", words_vec.size(), utf_words->len); + + auto *res = real_model->detect_language(words_vec, k); + + return (rspamd_fasttext_predict_result_t) res; +#endif +} + +void rspamd_lang_detection_fasttext_destroy(void *ud) +{ +#ifdef WITH_FASTTEXT + delete FASTTEXT_MODEL_TO_C_API(ud); +#endif +} + + +guint rspamd_lang_detection_fasttext_get_nlangs(rspamd_fasttext_predict_result_t res) +{ +#ifdef WITH_FASTTEXT + auto *real_res = FASTTEXT_RESULT_TO_C_API(res); + + if (real_res) { + return real_res->size(); + } +#endif + return 0; +} + +const char * +rspamd_lang_detection_fasttext_get_lang(rspamd_fasttext_predict_result_t res, unsigned int idx) +{ +#ifdef WITH_FASTTEXT + auto *real_res = FASTTEXT_RESULT_TO_C_API(res); + + if (real_res && real_res->size() > idx) { + /* Fasttext returns result in form __label__<lang>, so we need to remove __label__ prefix */ + auto lang = std::string_view{real_res->at(idx).second}; + if (lang.size() > sizeof("__label__") && lang.substr(0, sizeof("__label__") - 1) == "__label__") { + lang.remove_prefix(sizeof("__label__") - 1); + } + return lang.data(); + } +#endif + return nullptr; +} + +float rspamd_lang_detection_fasttext_get_prob(rspamd_fasttext_predict_result_t res, unsigned int idx) +{ +#ifdef WITH_FASTTEXT + auto *real_res = FASTTEXT_RESULT_TO_C_API(res); + + if (real_res && real_res->size() > idx) { + return real_res->at(idx).first; + } +#endif + return 0.0f; +} + +void rspamd_fasttext_predict_result_destroy(rspamd_fasttext_predict_result_t res) +{ +#ifdef WITH_FASTTEXT + auto *real_res = FASTTEXT_RESULT_TO_C_API(res); + + delete real_res; +#endif +} + +G_END_DECLS
\ No newline at end of file |