summaryrefslogtreecommitdiffstats
path: root/src/libstat/learn_cache/sqlite3_cache.c
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-10 21:30:40 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-10 21:30:40 +0000
commit133a45c109da5310add55824db21af5239951f93 (patch)
treeba6ac4c0a950a0dda56451944315d66409923918 /src/libstat/learn_cache/sqlite3_cache.c
parentInitial commit. (diff)
downloadrspamd-upstream.tar.xz
rspamd-upstream.zip
Adding upstream version 3.8.1.upstream/3.8.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/libstat/learn_cache/sqlite3_cache.c')
-rw-r--r--src/libstat/learn_cache/sqlite3_cache.c274
1 files changed, 274 insertions, 0 deletions
diff --git a/src/libstat/learn_cache/sqlite3_cache.c b/src/libstat/learn_cache/sqlite3_cache.c
new file mode 100644
index 0000000..d8ad20a
--- /dev/null
+++ b/src/libstat/learn_cache/sqlite3_cache.c
@@ -0,0 +1,274 @@
+/*-
+ * Copyright 2016 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "config.h"
+#include "learn_cache.h"
+#include "rspamd.h"
+#include "stat_api.h"
+#include "stat_internal.h"
+#include "cryptobox.h"
+#include "ucl.h"
+#include "fstring.h"
+#include "message.h"
+#include "libutil/sqlite_utils.h"
+
+static const char *create_tables_sql =
+ ""
+ "CREATE TABLE IF NOT EXISTS learns("
+ "id INTEGER PRIMARY KEY,"
+ "flag INTEGER NOT NULL,"
+ "digest TEXT NOT NULL);"
+ "CREATE UNIQUE INDEX IF NOT EXISTS d ON learns(digest);"
+ "";
+
+#define SQLITE_CACHE_PATH RSPAMD_DBDIR "/learn_cache.sqlite"
+
+enum rspamd_stat_sqlite3_stmt_idx {
+ RSPAMD_STAT_CACHE_TRANSACTION_START_IM = 0,
+ RSPAMD_STAT_CACHE_TRANSACTION_START_DEF,
+ RSPAMD_STAT_CACHE_TRANSACTION_COMMIT,
+ RSPAMD_STAT_CACHE_TRANSACTION_ROLLBACK,
+ RSPAMD_STAT_CACHE_GET_LEARN,
+ RSPAMD_STAT_CACHE_ADD_LEARN,
+ RSPAMD_STAT_CACHE_UPDATE_LEARN,
+ RSPAMD_STAT_CACHE_MAX
+};
+
+static struct rspamd_sqlite3_prstmt prepared_stmts[RSPAMD_STAT_CACHE_MAX] =
+ {
+ {.idx = RSPAMD_STAT_CACHE_TRANSACTION_START_IM,
+ .sql = "BEGIN IMMEDIATE TRANSACTION;",
+ .args = "",
+ .stmt = NULL,
+ .result = SQLITE_DONE,
+ .ret = ""},
+ {.idx = RSPAMD_STAT_CACHE_TRANSACTION_START_DEF,
+ .sql = "BEGIN DEFERRED TRANSACTION;",
+ .args = "",
+ .stmt = NULL,
+ .result = SQLITE_DONE,
+ .ret = ""},
+ {.idx = RSPAMD_STAT_CACHE_TRANSACTION_COMMIT,
+ .sql = "COMMIT;",
+ .args = "",
+ .stmt = NULL,
+ .result = SQLITE_DONE,
+ .ret = ""},
+ {.idx = RSPAMD_STAT_CACHE_TRANSACTION_ROLLBACK,
+ .sql = "ROLLBACK;",
+ .args = "",
+ .stmt = NULL,
+ .result = SQLITE_DONE,
+ .ret = ""},
+ {.idx = RSPAMD_STAT_CACHE_GET_LEARN,
+ .sql = "SELECT flag FROM learns WHERE digest=?1",
+ .args = "V",
+ .stmt = NULL,
+ .result = SQLITE_ROW,
+ .ret = "I"},
+ {.idx = RSPAMD_STAT_CACHE_ADD_LEARN,
+ .sql = "INSERT INTO learns(digest, flag) VALUES (?1, ?2);",
+ .args = "VI",
+ .stmt = NULL,
+ .result = SQLITE_DONE,
+ .ret = ""},
+ {.idx = RSPAMD_STAT_CACHE_UPDATE_LEARN,
+ .sql = "UPDATE learns SET flag=?1 WHERE digest=?2;",
+ .args = "IV",
+ .stmt = NULL,
+ .result = SQLITE_DONE,
+ .ret = ""}};
+
+struct rspamd_stat_sqlite3_ctx {
+ sqlite3 *db;
+ GArray *prstmt;
+};
+
+gpointer
+rspamd_stat_cache_sqlite3_init(struct rspamd_stat_ctx *ctx,
+ struct rspamd_config *cfg,
+ struct rspamd_statfile *st,
+ const ucl_object_t *cf)
+{
+ struct rspamd_stat_sqlite3_ctx *new = NULL;
+ const ucl_object_t *elt;
+ gchar dbpath[PATH_MAX];
+ const gchar *path = SQLITE_CACHE_PATH;
+ sqlite3 *sqlite;
+ GError *err = NULL;
+
+ if (cf) {
+ elt = ucl_object_lookup_any(cf, "path", "file", NULL);
+
+ if (elt != NULL) {
+ path = ucl_object_tostring(elt);
+ }
+ }
+
+ rspamd_snprintf(dbpath, sizeof(dbpath), "%s", path);
+
+ sqlite = rspamd_sqlite3_open_or_create(cfg->cfg_pool,
+ dbpath, create_tables_sql, 0, &err);
+
+ if (sqlite == NULL) {
+ msg_err("cannot open sqlite3 cache: %e", err);
+ g_error_free(err);
+ err = NULL;
+ }
+ else {
+ new = g_malloc0(sizeof(*new));
+ new->db = sqlite;
+ new->prstmt = rspamd_sqlite3_init_prstmt(sqlite, prepared_stmts,
+ RSPAMD_STAT_CACHE_MAX, &err);
+
+ if (new->prstmt == NULL) {
+ msg_err("cannot open sqlite3 cache: %e", err);
+ g_error_free(err);
+ err = NULL;
+ sqlite3_close(sqlite);
+ g_free(new);
+ new = NULL;
+ }
+ }
+
+ return new;
+}
+
+gpointer
+rspamd_stat_cache_sqlite3_runtime(struct rspamd_task *task,
+ gpointer ctx, gboolean learn)
+{
+ /* No need of runtime for this type of classifier */
+ return ctx;
+}
+
+gint rspamd_stat_cache_sqlite3_check(struct rspamd_task *task,
+ gboolean is_spam,
+ gpointer runtime)
+{
+ struct rspamd_stat_sqlite3_ctx *ctx = runtime;
+ rspamd_cryptobox_hash_state_t st;
+ rspamd_token_t *tok;
+ guchar *out;
+ gchar *user = NULL;
+ guint i;
+ gint rc;
+ gint64 flag;
+
+ if (task->tokens == NULL || task->tokens->len == 0) {
+ return RSPAMD_LEARN_IGNORE;
+ }
+
+ if (ctx != NULL && ctx->db != NULL) {
+ out = rspamd_mempool_alloc(task->task_pool, rspamd_cryptobox_HASHBYTES);
+
+ rspamd_cryptobox_hash_init(&st, NULL, 0);
+
+ user = rspamd_mempool_get_variable(task->task_pool, "stat_user");
+ /* Use dedicated hash space for per users cache */
+ if (user != NULL) {
+ rspamd_cryptobox_hash_update(&st, user, strlen(user));
+ }
+
+ for (i = 0; i < task->tokens->len; i++) {
+ tok = g_ptr_array_index(task->tokens, i);
+ rspamd_cryptobox_hash_update(&st, (guchar *) &tok->data,
+ sizeof(tok->data));
+ }
+
+ rspamd_cryptobox_hash_final(&st, out);
+
+ rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_TRANSACTION_START_DEF);
+ rc = rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_GET_LEARN, (gint64) rspamd_cryptobox_HASHBYTES,
+ out, &flag);
+ rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_TRANSACTION_COMMIT);
+
+ /* Save hash into variables */
+ rspamd_mempool_set_variable(task->task_pool, "words_hash", out, NULL);
+
+ if (rc == SQLITE_OK) {
+ /* We have some existing record in the table */
+ if (!!flag == !!is_spam) {
+ /* Already learned */
+ msg_warn_task("already seen stat hash: %*bs",
+ rspamd_cryptobox_HASHBYTES, out);
+ return RSPAMD_LEARN_IGNORE;
+ }
+ else {
+ /* Need to relearn */
+ return RSPAMD_LEARN_UNLEARN;
+ }
+ }
+ }
+
+ return RSPAMD_LEARN_OK;
+}
+
+gint rspamd_stat_cache_sqlite3_learn(struct rspamd_task *task,
+ gboolean is_spam,
+ gpointer runtime)
+{
+ struct rspamd_stat_sqlite3_ctx *ctx = runtime;
+ gboolean unlearn = !!(task->flags & RSPAMD_TASK_FLAG_UNLEARN);
+ guchar *h;
+ gint64 flag;
+
+ h = rspamd_mempool_get_variable(task->task_pool, "words_hash");
+
+ if (h == NULL) {
+ return RSPAMD_LEARN_IGNORE;
+ }
+
+ flag = !!is_spam ? 1 : 0;
+
+ if (!unlearn) {
+ /* Insert result new id */
+ rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_TRANSACTION_START_IM);
+ rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_ADD_LEARN,
+ (gint64) rspamd_cryptobox_HASHBYTES, h, flag);
+ rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_TRANSACTION_COMMIT);
+ }
+ else {
+ rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_TRANSACTION_START_IM);
+ rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_UPDATE_LEARN,
+ flag,
+ (gint64) rspamd_cryptobox_HASHBYTES, h);
+ rspamd_sqlite3_run_prstmt(task->task_pool, ctx->db, ctx->prstmt,
+ RSPAMD_STAT_CACHE_TRANSACTION_COMMIT);
+ }
+
+ rspamd_sqlite3_sync(ctx->db, NULL, NULL);
+
+ return RSPAMD_LEARN_OK;
+}
+
+void rspamd_stat_cache_sqlite3_close(gpointer c)
+{
+ struct rspamd_stat_sqlite3_ctx *ctx = (struct rspamd_stat_sqlite3_ctx *) c;
+
+ if (ctx != NULL) {
+ rspamd_sqlite3_close_prstmt(ctx->db, ctx->prstmt);
+ sqlite3_close(ctx->db);
+ g_free(ctx);
+ }
+}