summaryrefslogtreecommitdiffstats
path: root/ml/Database.h
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--ml/Database.h131
1 files changed, 131 insertions, 0 deletions
diff --git a/ml/Database.h b/ml/Database.h
new file mode 100644
index 000000000..cc7b75872
--- /dev/null
+++ b/ml/Database.h
@@ -0,0 +1,131 @@
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+#ifndef ML_DATABASE_H
+#define ML_DATABASE_H
+
+#include "Dimension.h"
+#include "ml-private.h"
+
+#include "json/single_include/nlohmann/json.hpp"
+
+namespace ml {
+
+class Statement {
+public:
+ using RowCallback = std::function<void(sqlite3_stmt *Stmt)>;
+
+public:
+ Statement(const char *RawStmt) : RawStmt(RawStmt), ParsedStmt(nullptr) {}
+
+ template<typename ...ArgTypes>
+ bool exec(sqlite3 *Conn, RowCallback RowCb, ArgTypes ...Args) {
+ if (!prepare(Conn))
+ return false;
+
+ switch (bind(1, Args...)) {
+ case 0:
+ return false;
+ case sizeof...(Args):
+ break;
+ default:
+ return resetAndClear(false);
+ }
+
+ while (true) {
+ switch (int RC = sqlite3_step(ParsedStmt)) {
+ case SQLITE_BUSY: case SQLITE_LOCKED:
+ usleep(SQLITE_INSERT_DELAY * USEC_PER_MS);
+ continue;
+ case SQLITE_ROW:
+ RowCb(ParsedStmt);
+ continue;
+ case SQLITE_DONE:
+ return resetAndClear(true);
+ default:
+ error("Stepping through '%s' returned rc=%d", RawStmt, RC);
+ return resetAndClear(false);
+ }
+ }
+ }
+
+ ~Statement() {
+ if (!ParsedStmt)
+ return;
+
+ int RC = sqlite3_finalize(ParsedStmt);
+ if (RC != SQLITE_OK)
+ error("Could not properly finalize statement (rc=%d)", RC);
+ }
+
+private:
+ bool prepare(sqlite3 *Conn);
+
+ bool bindValue(size_t Pos, const int Value);
+ bool bindValue(size_t Pos, const std::string &Value);
+
+ template<typename ArgType, typename ...ArgTypes>
+ size_t bind(size_t Pos, ArgType T) {
+ return bindValue(Pos, T);
+ }
+
+ template<typename ArgType, typename ...ArgTypes>
+ size_t bind(size_t Pos, ArgType T, ArgTypes ...Args) {
+ return bindValue(Pos, T) + bind(Pos + 1, Args...);
+ }
+
+ bool resetAndClear(bool Ret);
+
+private:
+ const char *RawStmt;
+ sqlite3_stmt *ParsedStmt;
+};
+
+class Database {
+private:
+ static const char *SQL_CREATE_ANOMALIES_TABLE;
+ static const char *SQL_INSERT_ANOMALY;
+ static const char *SQL_SELECT_ANOMALY;
+ static const char *SQL_SELECT_ANOMALY_EVENTS;
+
+public:
+ Database(const std::string &Path);
+
+ ~Database();
+
+ template<typename ...ArgTypes>
+ bool insertAnomaly(ArgTypes... Args) {
+ Statement::RowCallback RowCb = [](sqlite3_stmt *Stmt) { (void) Stmt; };
+ return InsertAnomalyStmt.exec(Conn, RowCb, Args...);
+ }
+
+ template<typename ...ArgTypes>
+ bool getAnomalyInfo(nlohmann::json &Json, ArgTypes&&... Args) {
+ Statement::RowCallback RowCb = [&](sqlite3_stmt *Stmt) {
+ const char *Text = static_cast<const char *>(sqlite3_column_blob(Stmt, 0));
+ Json = nlohmann::json::parse(Text);
+ };
+ return GetAnomalyInfoStmt.exec(Conn, RowCb, Args...);
+ }
+
+ template<typename ...ArgTypes>
+ bool getAnomaliesInRange(std::vector<std::pair<time_t, time_t>> &V, ArgTypes&&... Args) {
+ Statement::RowCallback RowCb = [&](sqlite3_stmt *Stmt) {
+ V.push_back({
+ sqlite3_column_int64(Stmt, 0),
+ sqlite3_column_int64(Stmt, 1)
+ });
+ };
+ return GetAnomaliesInRangeStmt.exec(Conn, RowCb, Args...);
+ }
+
+private:
+ sqlite3 *Conn;
+
+ Statement InsertAnomalyStmt{SQL_INSERT_ANOMALY};
+ Statement GetAnomalyInfoStmt{SQL_SELECT_ANOMALY};
+ Statement GetAnomaliesInRangeStmt{SQL_SELECT_ANOMALY_EVENTS};
+};
+
+}
+
+#endif /* ML_DATABASE_H */