summary | refs | log | tree | commit | diff | stats
path: root/src/libstat/classifiers/lua_classifier.c
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r-- src/libstat/classifiers/lua_classifier.c 237
1 files changed, 237 insertions, 0 deletions
diff --git a/src/libstat/classifiers/lua_classifier.c b/src/libstat/classifiers/lua_classifier.c
new file mode 100644
index 0000000..b74330d
--- /dev/null
+++ b/src/libstat/classifiers/lua_classifier.c
@@ -0,0 +1,237 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "classifiers.h"
+#include "cfg_file.h"
+#include "stat_internal.h"
+#include "lua/lua_common.h"
+
+struct rspamd_lua_classifier_ctx { /* Registration record for one Lua-implemented classifier */
+ gchar *name; /* g_strdup()'d copy of cl->subrs->name; also used as the key in lua_classifiers */
+ gint classify_ref; /* LUA_REGISTRYINDEX reference to the Lua "classify" function */
+ gint learn_ref; /* LUA_REGISTRYINDEX reference to the Lua "learn" function */
+};
+
+static GHashTable *lua_classifiers = NULL; /* name -> ctx map, created lazily in lua_classifier_init(); g_free is the destructor for both key (ctx->name) and value (ctx) */
+
+#define msg_err_luacl(...) rspamd_default_log_function(G_LOG_LEVEL_CRITICAL, \
+ "luacl", task->task_pool->tag.uid, \
+ RSPAMD_LOG_FUNC, \
+ __VA_ARGS__) /* critical-level "luacl" log entry; expands to `task->...`, so a variable named `task` must be in scope at the call site */
+#define msg_warn_luacl(...) rspamd_default_log_function(G_LOG_LEVEL_WARNING, \
+ "luacl", task->task_pool->tag.uid, \
+ RSPAMD_LOG_FUNC, \
+ __VA_ARGS__) /* warning-level variant; same `task` requirement */
+#define msg_info_luacl(...) rspamd_default_log_function(G_LOG_LEVEL_INFO, \
+ "luacl", task->task_pool->tag.uid, \
+ RSPAMD_LOG_FUNC, \
+ __VA_ARGS__) /* info-level variant; same `task` requirement */
+#define msg_debug_luacl(...) rspamd_conditional_debug_fast(NULL, task->from_addr, \
+ rspamd_luacl_log_id, "luacl", task->task_pool->tag.uid, \
+ RSPAMD_LOG_FUNC, \
+ __VA_ARGS__) /* debug variant: goes through rspamd_conditional_debug_fast with task->from_addr and rspamd_luacl_log_id */
+
+INIT_LOG_MODULE(luacl) /* NOTE(review): presumably provides rspamd_luacl_log_id referenced by msg_debug_luacl - confirm against the logger header */
+
+/**
+ * Register a Lua-implemented classifier.
+ *
+ * Looks up the table rspamd_classifiers[cl->subrs->name] in the Lua state and
+ * requires it to contain two functions, "classify" and "learn". References to
+ * both are stored in the Lua registry and remembered in the module-level
+ * lua_classifiers hash table under the classifier name. On success the
+ * RSPAMD_FLAG_CLASSIFIER_NO_BACKEND flag is set on cl->cfg.
+ *
+ * The Lua stack is restored to its entry height on every exit path, and no
+ * registry reference is left behind when registration fails.
+ *
+ * @param cfg     configuration object (used for error logging)
+ * @param ev_base event loop; not used by this function
+ * @param cl      classifier being initialised
+ * @return TRUE when the classifier has been registered, FALSE on a duplicate
+ *         name or when the Lua side does not provide the expected table and
+ *         functions (the reason is logged)
+ */
+gboolean
+lua_classifier_init(struct rspamd_config *cfg,
+					struct ev_loop *ev_base,
+					struct rspamd_classifier *cl)
+{
+	struct rspamd_lua_classifier_ctx *ctx;
+	lua_State *L = cl->ctx->cfg->lua_state;
+	gint cb_classify = LUA_NOREF, cb_learn = LUA_NOREF;
+	gint top;
+
+	if (lua_classifiers == NULL) {
+		/* Key (ctx->name) and value (ctx) are separate allocations */
+		lua_classifiers = g_hash_table_new_full(rspamd_strcase_hash,
+												rspamd_strcase_equal, g_free, g_free);
+	}
+
+	ctx = g_hash_table_lookup(lua_classifiers, cl->subrs->name);
+
+	if (ctx != NULL) {
+		msg_err_config("duplicate lua classifier definition: %s",
+					   cl->subrs->name);
+
+		return FALSE;
+	}
+
+	/* Remember the stack height so that every exit can restore it */
+	top = lua_gettop(L);
+
+	lua_getglobal(L, "rspamd_classifiers");
+	if (lua_type(L, -1) != LUA_TTABLE) {
+		msg_err_config("cannot register classifier %s: no rspamd_classifiers global",
+					   cl->subrs->name);
+		goto fail;
+	}
+
+	lua_pushstring(L, cl->subrs->name);
+	lua_gettable(L, -2);
+
+	if (lua_type(L, -1) != LUA_TTABLE) {
+		msg_err_config("cannot register classifier %s: bad lua type: %s",
+					   cl->subrs->name, lua_typename(L, lua_type(L, -1)));
+		goto fail;
+	}
+
+	lua_getfield(L, -1, "classify");
+
+	if (lua_type(L, -1) != LUA_TFUNCTION) {
+		msg_err_config("cannot register classifier %s: bad lua type for classify: %s",
+					   cl->subrs->name, lua_typename(L, lua_type(L, -1)));
+		goto fail;
+	}
+
+	/* luaL_ref pops the function, leaving the classifier table on top */
+	cb_classify = luaL_ref(L, LUA_REGISTRYINDEX);
+
+	lua_getfield(L, -1, "learn");
+
+	if (lua_type(L, -1) != LUA_TFUNCTION) {
+		msg_err_config("cannot register classifier %s: bad lua type for learn: %s",
+					   cl->subrs->name, lua_typename(L, lua_type(L, -1)));
+		goto fail;
+	}
+
+	cb_learn = luaL_ref(L, LUA_REGISTRYINDEX);
+	lua_settop(L, top); /* Table + global */
+
+	ctx = g_malloc0(sizeof(*ctx));
+	ctx->name = g_strdup(cl->subrs->name);
+	ctx->classify_ref = cb_classify;
+	ctx->learn_ref = cb_learn;
+	cl->cfg->flags |= RSPAMD_FLAG_CLASSIFIER_NO_BACKEND;
+	g_hash_table_insert(lua_classifiers, ctx->name, ctx);
+
+	return TRUE;
+
+fail:
+	/*
+	 * Drop the "classify" reference if it has already been taken; luaL_unref
+	 * is a no-op for LUA_NOREF, so this is safe on the earlier failures too.
+	 */
+	luaL_unref(L, LUA_REGISTRYINDEX, cb_classify);
+	lua_settop(L, top);
+
+	return FALSE;
+}
+/**
+ * Run the Lua "classify" callback registered for this classifier.
+ *
+ * The callback is invoked as classify(task, classifier_cfg, tokens), where
+ * `tokens` is an array with one entry per element of the input array; each
+ * entry is a three-element array holding the upper 32 bits of the token
+ * data, the lower 32 bits, and the token's window_idx.
+ *
+ * The classifier must have been registered beforehand: a missing entry in
+ * lua_classifiers trips g_assert().
+ *
+ * @param cl     classifier whose subrs->name selects the Lua callbacks
+ * @param tokens array of rspamd_token_t pointers
+ * @param task   task being processed; supplies the Lua state
+ * @return TRUE when the callback ran to completion, FALSE when it raised a
+ *         Lua error (the error is logged and removed from the stack)
+ */
+gboolean
+lua_classifier_classify(struct rspamd_classifier *cl,
+						GPtrArray *tokens,
+						struct rspamd_task *task)
+{
+	struct rspamd_lua_classifier_ctx *lctx;
+	struct rspamd_task **task_ud;
+	struct rspamd_classifier_config **cfg_ud;
+	lua_State *L;
+	guint pos;
+
+	lctx = g_hash_table_lookup(lua_classifiers, cl->subrs->name);
+	g_assert(lctx != NULL);
+	L = task->cfg->lua_state;
+
+	/* Function to call, followed by its three arguments */
+	lua_rawgeti(L, LUA_REGISTRYINDEX, lctx->classify_ref);
+
+	task_ud = lua_newuserdata(L, sizeof(*task_ud));
+	*task_ud = task;
+	rspamd_lua_setclass(L, "rspamd{task}", -1);
+
+	cfg_ud = lua_newuserdata(L, sizeof(*cfg_ud));
+	*cfg_ud = cl->cfg;
+	rspamd_lua_setclass(L, "rspamd{classifier}", -1);
+
+	lua_createtable(L, tokens->len, 0);
+
+	for (pos = 0; pos < tokens->len; pos++) {
+		rspamd_token_t *cur = g_ptr_array_index(tokens, pos);
+		guint64 raw = cur->data;
+		guint32 hi = (guint32) (raw >> 32);
+		guint32 lo = (guint32) (raw);
+
+		/* { high word, low word, order } */
+		lua_createtable(L, 3, 0);
+		lua_pushinteger(L, hi);
+		lua_rawseti(L, -2, 1);
+		lua_pushinteger(L, lo);
+		lua_rawseti(L, -2, 2);
+		lua_pushinteger(L, cur->window_idx);
+		lua_rawseti(L, -2, 3);
+
+		/* Lua arrays are 1-based */
+		lua_rawseti(L, -2, pos + 1);
+	}
+
+	if (lua_pcall(L, 3, 0, 0) == 0) {
+		return TRUE;
+	}
+
+	msg_err_luacl("error running classify function for %s: %s", lctx->name,
+				  lua_tostring(L, -1));
+	lua_pop(L, 1);
+
+	return FALSE;
+}
+
+/**
+ * Run the Lua "learn" callback registered for this classifier.
+ *
+ * The callback is invoked as learn(task, classifier_cfg, tokens, is_spam,
+ * unlearn), where `tokens` is an array with one entry per element of the
+ * input array; each entry is a three-element array holding the upper 32 bits
+ * of the token data, the lower 32 bits, and the token's window_idx.
+ *
+ * The classifier must have been registered beforehand: a missing entry in
+ * lua_classifiers trips g_assert().
+ *
+ * @param cl      classifier whose subrs->name selects the Lua callbacks
+ * @param tokens  array of rspamd_token_t pointers
+ * @param task    task being learned; supplies the Lua state
+ * @param is_spam passed to Lua as a boolean
+ * @param unlearn passed to Lua as a boolean
+ * @param err     not written by this function; a failure is only logged
+ * @return TRUE when the callback ran to completion, FALSE when it raised a
+ *         Lua error (the error is logged and removed from the stack)
+ */
+gboolean
+lua_classifier_learn_spam(struct rspamd_classifier *cl,
+						  GPtrArray *tokens,
+						  struct rspamd_task *task,
+						  gboolean is_spam,
+						  gboolean unlearn,
+						  GError **err)
+{
+	struct rspamd_lua_classifier_ctx *ctx;
+	struct rspamd_task **ptask;
+	struct rspamd_classifier_config **pcfg;
+	lua_State *L;
+	rspamd_token_t *tok;
+	guint i;
+	guint64 v;
+
+	ctx = g_hash_table_lookup(lua_classifiers, cl->subrs->name);
+	g_assert(ctx != NULL);
+	L = task->cfg->lua_state;
+
+	/* Function to call, followed by its five arguments */
+	lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->learn_ref);
+	ptask = lua_newuserdata(L, sizeof(*ptask));
+	*ptask = task;
+	rspamd_lua_setclass(L, "rspamd{task}", -1);
+	pcfg = lua_newuserdata(L, sizeof(*pcfg));
+	*pcfg = cl->cfg;
+	rspamd_lua_setclass(L, "rspamd{classifier}", -1);
+
+	lua_createtable(L, tokens->len, 0);
+
+	for (i = 0; i < tokens->len; i++) {
+		tok = g_ptr_array_index(tokens, i);
+		v = tok->data;
+		lua_createtable(L, 3, 0);
+		/* High word, low word, order */
+		lua_pushinteger(L, (guint32) (v >> 32));
+		lua_rawseti(L, -2, 1);
+		lua_pushinteger(L, (guint32) (v));
+		lua_rawseti(L, -2, 2);
+		lua_pushinteger(L, tok->window_idx);
+		lua_rawseti(L, -2, 3);
+		/* Lua arrays are 1-based */
+		lua_rawseti(L, -2, i + 1);
+	}
+
+	lua_pushboolean(L, is_spam);
+	lua_pushboolean(L, unlearn);
+
+	if (lua_pcall(L, 5, 0, 0) != 0) {
+		msg_err_luacl("error running learn function for %s: %s", ctx->name,
+					  lua_tostring(L, -1));
+		lua_pop(L, 1);
+
+		return FALSE;
+	}
+
+	return TRUE;
+}