summaryrefslogtreecommitdiffstats
path: root/src/lua/lua_kann.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/lua/lua_kann.c')
-rw-r--r--src/lua/lua_kann.c1361
1 files changed, 1361 insertions, 0 deletions
diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
new file mode 100644
index 0000000..e42fbfb
--- /dev/null
+++ b/src/lua/lua_kann.c
@@ -0,0 +1,1361 @@
+/*
+ * Copyright 2023 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lua_common.h"
+#include "lua_tensor.h"
+#include "contrib/kann/kann.h"
+
+/***
+ * @module rspamd_kann
+ * `rspamd_kann` is a Lua interface to kann library
+ */
+
+#define KANN_NODE_CLASS "rspamd{kann_node}"
+#define KANN_NETWORK_CLASS "rspamd{kann}"
+
+/* Simple macros to define behaviour */
+#define KANN_LAYER_DEF(name) static int lua_kann_layer_##name(lua_State *L)
+#define KANN_LAYER_INTERFACE(name) \
+ { \
+ #name, lua_kann_layer_##name \
+ }
+
+#define KANN_TRANSFORM_DEF(name) static int lua_kann_transform_##name(lua_State *L)
+#define KANN_TRANSFORM_INTERFACE(name) \
+ { \
+ #name, lua_kann_transform_##name \
+ }
+
+#define KANN_LOSS_DEF(name) static int lua_kann_loss_##name(lua_State *L)
+#define KANN_LOSS_INTERFACE(name) \
+ { \
+ #name, lua_kann_loss_##name \
+ }
+
+#define KANN_NEW_DEF(name) static int lua_kann_new_##name(lua_State *L)
+#define KANN_NEW_INTERFACE(name) \
+ { \
+ #name, lua_kann_new_##name \
+ }
+
+
+/*
+ * Forwarded declarations
+ */
+static kad_node_t *lua_check_kann_node(lua_State *L, int pos);
+
+/* Layers */
+KANN_LAYER_DEF(input);
+KANN_LAYER_DEF(dense);
+KANN_LAYER_DEF(layernorm);
+KANN_LAYER_DEF(rnn);
+KANN_LAYER_DEF(lstm);
+KANN_LAYER_DEF(gru);
+KANN_LAYER_DEF(conv2d);
+KANN_LAYER_DEF(conv1d);
+KANN_LAYER_DEF(cost);
+
+static luaL_reg rspamd_kann_layers_f[] = {
+ KANN_LAYER_INTERFACE(input),
+ KANN_LAYER_INTERFACE(dense),
+ KANN_LAYER_INTERFACE(layernorm),
+ KANN_LAYER_INTERFACE(rnn),
+ KANN_LAYER_INTERFACE(lstm),
+ KANN_LAYER_INTERFACE(gru),
+ KANN_LAYER_INTERFACE(conv2d),
+ KANN_LAYER_INTERFACE(conv1d),
+ KANN_LAYER_INTERFACE(cost),
+ {NULL, NULL},
+};
+
+/* Transition and composition functions */
+
+/* General transform */
+KANN_TRANSFORM_DEF(add);
+KANN_TRANSFORM_DEF(sub);
+KANN_TRANSFORM_DEF(mul);
+KANN_TRANSFORM_DEF(cmul);
+KANN_TRANSFORM_DEF(matmul);
+
+KANN_TRANSFORM_DEF(square);
+KANN_TRANSFORM_DEF(sigm);
+KANN_TRANSFORM_DEF(tanh);
+KANN_TRANSFORM_DEF(relu);
+KANN_TRANSFORM_DEF(softmax);
+KANN_TRANSFORM_DEF(1minus);
+KANN_TRANSFORM_DEF(exp);
+KANN_TRANSFORM_DEF(log);
+KANN_TRANSFORM_DEF(sin);
+static luaL_reg rspamd_kann_transform_f[] = {
+ KANN_TRANSFORM_INTERFACE(add),
+ KANN_TRANSFORM_INTERFACE(sub),
+ KANN_TRANSFORM_INTERFACE(mul),
+ KANN_TRANSFORM_INTERFACE(cmul),
+ KANN_TRANSFORM_INTERFACE(matmul),
+
+ KANN_TRANSFORM_INTERFACE(square),
+ KANN_TRANSFORM_INTERFACE(sigm),
+ KANN_TRANSFORM_INTERFACE(tanh),
+ KANN_TRANSFORM_INTERFACE(relu),
+ KANN_TRANSFORM_INTERFACE(softmax),
+ KANN_TRANSFORM_INTERFACE(1minus),
+ KANN_TRANSFORM_INTERFACE(exp),
+ KANN_TRANSFORM_INTERFACE(log),
+ KANN_TRANSFORM_INTERFACE(sin),
+ {NULL, NULL},
+};
+
+/* Loss functions */
+KANN_LOSS_DEF(mse);
+KANN_LOSS_DEF(ce_multi);
+KANN_LOSS_DEF(ce_bin);
+KANN_LOSS_DEF(ce_bin_neg);
+KANN_LOSS_DEF(ce_multi_weighted);
+static luaL_reg rspamd_kann_loss_f[] = {
+ KANN_LOSS_INTERFACE(mse),
+ KANN_LOSS_INTERFACE(ce_multi),
+ KANN_LOSS_INTERFACE(ce_bin),
+ KANN_LOSS_INTERFACE(ce_bin_neg),
+ KANN_LOSS_INTERFACE(ce_multi_weighted),
+ {NULL, NULL},
+};
+
+/* Creation functions */
+KANN_NEW_DEF(leaf);
+KANN_NEW_DEF(scalar);
+KANN_NEW_DEF(weight);
+KANN_NEW_DEF(bias);
+KANN_NEW_DEF(weight_conv2d);
+KANN_NEW_DEF(weight_conv1d);
+KANN_NEW_DEF(kann);
+
+static luaL_reg rspamd_kann_new_f[] = {
+ KANN_NEW_INTERFACE(leaf),
+ KANN_NEW_INTERFACE(scalar),
+ KANN_NEW_INTERFACE(weight),
+ KANN_NEW_INTERFACE(bias),
+ KANN_NEW_INTERFACE(weight_conv2d),
+ KANN_NEW_INTERFACE(weight_conv1d),
+ KANN_NEW_INTERFACE(kann),
+ {NULL, NULL},
+};
+
+LUA_FUNCTION_DEF(kann, load);
+LUA_FUNCTION_DEF(kann, destroy);
+LUA_FUNCTION_DEF(kann, save);
+LUA_FUNCTION_DEF(kann, train1);
+LUA_FUNCTION_DEF(kann, apply1);
+
+static luaL_reg rspamd_kann_m[] = {
+ LUA_INTERFACE_DEF(kann, save),
+ LUA_INTERFACE_DEF(kann, train1),
+ LUA_INTERFACE_DEF(kann, apply1),
+ {"__gc", lua_kann_destroy},
+ {NULL, NULL},
+};
+
+static int
+rspamd_kann_table_to_flags(lua_State *L, int table_pos)
+{
+ int result = 0;
+
+ lua_pushvalue(L, table_pos);
+
+ for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) {
+ int fl = lua_tointeger(L, -1);
+
+ result |= fl;
+ }
+
+ lua_pop(L, 1);
+
+ return result;
+}
+
+static gint
+lua_load_kann(lua_State *L)
+{
+ lua_newtable(L);
+
+ /* Flags */
+ lua_pushstring(L, "flag");
+ lua_newtable(L);
+ lua_pushinteger(L, KANN_F_IN);
+ lua_setfield(L, -2, "in");
+ lua_pushinteger(L, KANN_F_COST);
+ lua_setfield(L, -2, "cost");
+ lua_pushinteger(L, KANN_F_OUT);
+ lua_setfield(L, -2, "out");
+ lua_pushinteger(L, KANN_F_TRUTH);
+ lua_setfield(L, -2, "truth");
+ lua_settable(L, -3);
+
+ /* Cost type */
+ lua_pushstring(L, "cost");
+ lua_newtable(L);
+ /* binary cross-entropy cost, used with sigmoid */
+ lua_pushinteger(L, KANN_C_CEB);
+ lua_setfield(L, -2, "ceb");
+ /* multi-class cross-entropy cost, used with softmax */
+ lua_pushinteger(L, KANN_C_CEM);
+ lua_setfield(L, -2, "cem");
+ /* binary cross-entropy-like cost, used with tanh */
+ lua_pushinteger(L, KANN_C_CEB_NEG);
+ lua_setfield(L, -2, "ceb_neg");
+ lua_pushinteger(L, KANN_C_MSE);
+ lua_setfield(L, -2, "mse");
+ lua_settable(L, -3);
+
+ /* RNN flag */
+ lua_pushstring(L, "rnn");
+ lua_newtable(L);
+ /* apply layer normalization */
+ lua_pushinteger(L, KANN_RNN_NORM);
+ lua_setfield(L, -2, "norm");
+ /* take the initial hidden values as variables */
+ lua_pushinteger(L, KANN_RNN_VAR_H0);
+ lua_setfield(L, -2, "var_h0");
+ lua_settable(L, -3);
+
+ /* Layers */
+ lua_pushstring(L, "layer");
+ lua_newtable(L);
+ luaL_register(L, NULL, rspamd_kann_layers_f);
+ lua_settable(L, -3);
+
+ /* Transforms */
+ lua_pushstring(L, "transform");
+ lua_newtable(L);
+ luaL_register(L, NULL, rspamd_kann_transform_f);
+ lua_settable(L, -3);
+
+ /* Cost */
+ lua_pushstring(L, "loss");
+ lua_newtable(L);
+ luaL_register(L, NULL, rspamd_kann_loss_f);
+ lua_settable(L, -3);
+
+ /* Create functions */
+ lua_pushstring(L, "new");
+ lua_newtable(L);
+ luaL_register(L, NULL, rspamd_kann_new_f);
+ lua_settable(L, -3);
+
+ /* Load ann from memory or file */
+ lua_pushstring(L, "load");
+ lua_pushcfunction(L, lua_kann_load);
+ lua_settable(L, -3);
+
+ return 1;
+}
+
+static kad_node_t *
+lua_check_kann_node(lua_State *L, int pos)
+{
+ void *ud = rspamd_lua_check_udata(L, pos, KANN_NODE_CLASS);
+ luaL_argcheck(L, ud != NULL, pos, "'kann_node' expected");
+ return ud ? *((kad_node_t **) ud) : NULL;
+}
+
+static kann_t *
+lua_check_kann(lua_State *L, int pos)
+{
+ void *ud = rspamd_lua_check_udata(L, pos, KANN_NETWORK_CLASS);
+ luaL_argcheck(L, ud != NULL, pos, "'kann' expected");
+ return ud ? *((kann_t **) ud) : NULL;
+}
+
+void luaopen_kann(lua_State *L)
+{
+ /* Metatables */
+ rspamd_lua_new_class(L, KANN_NODE_CLASS, NULL); /* TODO: add methods */
+ lua_pop(L, 1); /* No need in metatable... */
+ rspamd_lua_new_class(L, KANN_NETWORK_CLASS, rspamd_kann_m);
+ lua_pop(L, 1); /* No need in metatable... */
+ rspamd_lua_add_preload(L, "rspamd_kann", lua_load_kann);
+ lua_settop(L, 0);
+}
+
+/* Layers implementation */
+#define PUSH_KAD_NODE(n) \
+ do { \
+ kad_node_t **pt; \
+ pt = lua_newuserdata(L, sizeof(kad_node_t *)); \
+ *pt = (n); \
+ rspamd_lua_setclass(L, KANN_NODE_CLASS, -1); \
+ } while (0)
+
+#define PUSH_KAN_NETWORK(n) \
+ do { \
+ kann_t **pn; \
+ pn = lua_newuserdata(L, sizeof(kann_t *)); \
+ *pn = (n); \
+ rspamd_lua_setclass(L, KANN_NETWORK_CLASS, -1); \
+ } while (0)
+
+#define PROCESS_KAD_FLAGS(n, pos) \
+ do { \
+ int fl = 0; \
+ if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags(L, (pos)); } \
+ else if (lua_type(L, (pos)) == LUA_TNUMBER) { \
+ fl = lua_tointeger(L, (pos)); \
+ } \
+ (n)->ext_flag |= fl; \
+ } while (0)
+
+/***
+ * @function kann.layer.input(ninputs[, flags])
+ * Creates an input layer for ANN
+ * @param {int} ninputs number of inputs
+ * @param {table|int} flags optional flags
+ * @return {kann_node} kann node object (should be used to combine ANN)
+*/
+static int
+lua_kann_layer_input(lua_State *L)
+{
+ gint nnodes = luaL_checkinteger(L, 1);
+
+ if (nnodes > 0) {
+ kad_node_t *t;
+
+ t = kann_layer_input(nnodes);
+
+ PROCESS_KAD_FLAGS(t, 2);
+ PUSH_KAD_NODE(t);
+ }
+ else {
+ return luaL_error(L, "invalid arguments, nnodes required");
+ }
+
+ return 1;
+}
+
+/***
+ * @function kann.layer.dense(in, ninputs[, flags])
+ * Creates a dense layer (e.g. for hidden layer)
+ * @param {kann_node} in kann node
+ * @param {int} ninputs number of dense nodes
+ * @param {table|int} flags optional flags
+ * @return {kann_node} kann node object (should be used to combine ANN)
+*/
+static int
+lua_kann_layer_dense(lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node(L, 1);
+ gint nnodes = luaL_checkinteger(L, 2);
+
+ if (in != NULL && nnodes > 0) {
+ kad_node_t *t;
+
+ t = kann_layer_dense(in, nnodes);
+
+ PROCESS_KAD_FLAGS(t, 3);
+ PUSH_KAD_NODE(t);
+ }
+ else {
+ return luaL_error(L, "invalid arguments, input + nnodes required");
+ }
+
+ return 1;
+}
+
+/***
+ * @function kann.layer.dropout(in, ratio[, flags])
+ * Creates a dropout layer
+ * @param {kann_node} in kann node
+ * @param {float} ratio drop ratio
+ * @param {table|int} flags optional flags
+ * @return {kann_node} kann node object (should be used to combine ANN)
+*/
+static int
+lua_kann_layer_layerdropout(lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node(L, 1);
+ double r = luaL_checknumber(L, 2);
+
+ if (in != NULL) {
+ kad_node_t *t;
+
+ t = kann_layer_dropout(in, r);
+
+ PROCESS_KAD_FLAGS(t, 3);
+ PUSH_KAD_NODE(t);
+ }
+ else {
+ return luaL_error(L, "invalid arguments, input + rate required");
+ }
+
+ return 1;
+}
+
+/***
+ * @function kann.layer.dropout(in [, flags])
+ * Creates a normalisation layer
+ * @param {kann_node} in kann node
+ * @param {table|int} flags optional flags
+ * @return {kann_node} kann node object (should be used to combine ANN)
+*/
+static int
+lua_kann_layer_layernorm(lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node(L, 1);
+
+ if (in != NULL) {
+ kad_node_t *t;
+
+ t = kann_layer_layernorm(in);
+
+ PROCESS_KAD_FLAGS(t, 2);
+ PUSH_KAD_NODE(t);
+ }
+ else {
+ return luaL_error(L, "invalid arguments, input required");
+ }
+
+ return 1;
+}
+
+/***
+ * @function kann.layer.rnn(in, nnodes[, rnn_flags, [, flags]])
+ * Creates a recursive NN layer
+ * @param {kann_node} in kann node
+ * @param {int} nnodes number of cells
+ * @param {int} rnnflags rnn flags
+ * @param {table|int} flags optional flags
+ * @return {kann_node} kann node object (should be used to combine ANN)
+*/
+static int
+lua_kann_layer_rnn(lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node(L, 1);
+ gint nnodes = luaL_checkinteger(L, 2);
+ gint rnnflags = 0;
+
+ if (in != NULL && nnodes > 0) {
+ kad_node_t *t;
+
+ if (lua_type(L, 3) == LUA_TNUMBER) {
+ rnnflags = lua_tointeger(L, 3);
+ }
+
+ t = kann_layer_rnn(in, nnodes, rnnflags);
+
+ PROCESS_KAD_FLAGS(t, 4);
+ PUSH_KAD_NODE(t);
+ }
+ else {
+ return luaL_error(L, "invalid arguments, input + nnodes required");
+ }
+
+ return 1;
+}
+
+/***
+ * @function kann.layer.lstm(in, nnodes[, rnn_flags, [, flags]])
+ * Creates a recursive NN layer using LSTM cells
+ * @param {kann_node} in kann node
+ * @param {int} nnodes number of cells
+ * @param {int} rnnflags rnn flags
+ * @param {table|int} flags optional flags
+ * @return {kann_node} kann node object (should be used to combine ANN)
+*/
+static int
+lua_kann_layer_lstm(lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node(L, 1);
+ gint nnodes = luaL_checkinteger(L, 2);
+ gint rnnflags = 0;
+
+ if (in != NULL && nnodes > 0) {
+ kad_node_t *t;
+
+ if (lua_type(L, 3) == LUA_TNUMBER) {
+ rnnflags = lua_tointeger(L, 3);
+ }
+
+ t = kann_layer_lstm(in, nnodes, rnnflags);
+
+ PROCESS_KAD_FLAGS(t, 4);
+ PUSH_KAD_NODE(t);
+ }
+ else {
+ return luaL_error(L, "invalid arguments, input + nnodes required");
+ }
+
+ return 1;
+}
+
+/***
+ * @function kann.layer.rnn(in, nnodes[, rnn_flags, [, flags]])
+ * Creates a recursive NN layer using GRU cells
+ * @param {kann_node} in kann node
+ * @param {int} nnodes number of cells
+ * @param {int} rnnflags rnn flags
+ * @param {table|int} flags optional flags
+ * @return {kann_node} kann node object (should be used to combine ANN)
+*/
+static int
+lua_kann_layer_gru(lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node(L, 1);
+ gint nnodes = luaL_checkinteger(L, 2);
+ gint rnnflags = 0;
+
+ if (in != NULL && nnodes > 0) {
+ kad_node_t *t;
+
+ if (lua_type(L, 3) == LUA_TNUMBER) {
+ rnnflags = lua_tointeger(L, 3);
+ }
+
+ t = kann_layer_gru(in, nnodes, rnnflags);
+
+ PROCESS_KAD_FLAGS(t, 4);
+ PUSH_KAD_NODE(t);
+ }
+ else {
+ return luaL_error(L, "invalid arguments, input + nnodes required");
+ }
+
+ return 1;
+}
+
+/***
+ * @function kann.layer.conv2d(in, n_flt, k_rows, k_cols, stride_rows, stride_cols, pad_rows, pad_columns[, flags])
+ * Creates a 2D convolution layer
+ * @param {kann_node} in kann node
+ * @param {int} n_flt number of filters
+ * @param {int} k_rows kernel rows
+ * @param {int} k_cols kernel columns
+ * @param {int} stride_rows stride rows
+ * @param {int} stride_cols stride columns
+ * @param {int} pad_rows padding rows
+ * @param {int} pad_columns padding columns
+ * @param {table|int} flags optional flags
+ * @return {kann_node} kann node object (should be used to combine ANN)
+*/
+static int
+lua_kann_layer_conv2d(lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node(L, 1);
+ int n_flt = luaL_checkinteger(L, 2);
+ int k_rows = luaL_checkinteger(L, 3);
+ int k_cols = luaL_checkinteger(L, 4);
+ int stride_r = luaL_checkinteger(L, 5);
+ int stride_c = luaL_checkinteger(L, 6);
+ int pad_r = luaL_checkinteger(L, 7);
+ int pad_c = luaL_checkinteger(L, 8);
+
+ if (in != NULL) {
+ kad_node_t *t;
+ t = kann_layer_conv2d(in, n_flt, k_rows, k_cols, stride_r, stride_c,
+ pad_r, pad_c);
+
+ PROCESS_KAD_FLAGS(t, 9);
+ PUSH_KAD_NODE(t);
+ }
+ else {
+ return luaL_error(L, "invalid arguments, input, nflt, kx, ky, stridex, stridey, padx, pady are required");
+ }
+
+ return 1;
+}
+
+/***
+ * @function kann.layer.conv1d(in, n_flt, kern_size, stride_size, pad_size[, flags])
+ * Creates 1D convolution layer
+ * @param {kann_node} in kann node
+ * @param {int} n_flt number of filters
+ * @param {int} kern_size kernel rows
+ * @param {int} stride_size stride rows
+ * @param {int} pad_size padding rows
+ * @param {table|int} flags optional flags
+ * @return {kann_node} kann node object (should be used to combine ANN)
+*/
+static int
+lua_kann_layer_conv1d(lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node(L, 1);
+ int n_flt = luaL_checkinteger(L, 2);
+ int k_size = luaL_checkinteger(L, 3);
+ int stride = luaL_checkinteger(L, 4);
+ int pad = luaL_checkinteger(L, 5);
+
+ if (in != NULL) {
+ kad_node_t *t;
+ t = kann_layer_conv1d(in, n_flt, k_size, stride, pad);
+
+ PROCESS_KAD_FLAGS(t, 6);
+ PUSH_KAD_NODE(t);
+ }
+ else {
+ return luaL_error(L, "invalid arguments, input, nflt, k, stride, pad required");
+ }
+
+ return 1;
+}
+
+/***
+ * @function kann.layer.cost(in, nout, cost_type[, flags])
+ * Creates 1D convolution layer
+ * @param {kann_node} in kann node
+ * @param {int} nout number of outputs
+ * @param {int} cost_type see kann.cost table
+ * @param {table|int} flags optional flags
+ * @return {kann_node} kann node object (should be used to combine ANN)
+*/
+static int
+lua_kann_layer_cost(lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node(L, 1);
+ int nout = luaL_checkinteger(L, 2);
+ int cost_type = luaL_checkinteger(L, 3);
+
+ if (in != NULL && nout > 0) {
+ kad_node_t *t;
+ t = kann_layer_cost(in, nout, cost_type);
+
+ PROCESS_KAD_FLAGS(t, 4);
+ PUSH_KAD_NODE(t);
+ }
+ else {
+ return luaL_error(L, "invalid arguments, input, nout and cost_type are required");
+ }
+
+ return 1;
+}
+
+/* Generic helpers */
+static int
+lua_kann_call_unary_function(lua_State *L, const char *name,
+ kad_node_t *(*func)(kad_node_t *) )
+{
+ kad_node_t *in = lua_check_kann_node(L, 1);
+
+ if (in != NULL) {
+ kad_node_t *t;
+ t = func(in);
+
+ PUSH_KAD_NODE(t);
+ }
+ else {
+ return luaL_error(L, "invalid arguments for %s, input required", name);
+ }
+
+ return 1;
+}
+static int
+lua_kann_call_binary_function(lua_State *L, const char *name,
+ kad_node_t *(*func)(kad_node_t *, kad_node_t *) )
+{
+ kad_node_t *x = lua_check_kann_node(L, 1);
+ kad_node_t *y = lua_check_kann_node(L, 2);
+
+ if (x != NULL && y != NULL) {
+ kad_node_t *t;
+ t = func(x, y);
+
+ PUSH_KAD_NODE(t);
+ }
+ else {
+ return luaL_error(L, "invalid arguments for %s, 2 inputs required", name);
+ }
+
+ return 1;
+}
+
+#define LUA_UNARY_TRANSFORM_FUNC_IMPL(name) \
+ static int lua_kann_transform_##name(lua_State *L) \
+ { \
+ return lua_kann_call_unary_function(L, #name, kad_##name); \
+ }
+
+#define LUA_BINARY_TRANSFORM_FUNC_IMPL(name) \
+ static int lua_kann_transform_##name(lua_State *L) \
+ { \
+ return lua_kann_call_binary_function(L, #name, kad_##name); \
+ }
+
+#define LUA_LOSS_FUNC_IMPL(name) \
+ static int lua_kann_loss_##name(lua_State *L) \
+ { \
+ return lua_kann_call_binary_function(L, #name, kad_##name); \
+ }
+
+/* Transform functions registered via macro helpers */
+LUA_BINARY_TRANSFORM_FUNC_IMPL(add)
+LUA_BINARY_TRANSFORM_FUNC_IMPL(sub)
+LUA_BINARY_TRANSFORM_FUNC_IMPL(mul)
+LUA_BINARY_TRANSFORM_FUNC_IMPL(cmul)
+LUA_BINARY_TRANSFORM_FUNC_IMPL(matmul)
+
+LUA_UNARY_TRANSFORM_FUNC_IMPL(square)
+LUA_UNARY_TRANSFORM_FUNC_IMPL(sigm)
+LUA_UNARY_TRANSFORM_FUNC_IMPL(tanh)
+LUA_UNARY_TRANSFORM_FUNC_IMPL(relu)
+LUA_UNARY_TRANSFORM_FUNC_IMPL(softmax)
+LUA_UNARY_TRANSFORM_FUNC_IMPL(1minus)
+LUA_UNARY_TRANSFORM_FUNC_IMPL(exp)
+LUA_UNARY_TRANSFORM_FUNC_IMPL(log)
+LUA_UNARY_TRANSFORM_FUNC_IMPL(sin)
+
+/* Generic cost functions */
+LUA_LOSS_FUNC_IMPL(mse)
+LUA_LOSS_FUNC_IMPL(ce_multi)
+LUA_LOSS_FUNC_IMPL(ce_bin)
+LUA_LOSS_FUNC_IMPL(ce_bin_neg)
+
+/* The only case of ternary weight function */
+static int
+lua_kann_loss_ce_multi_weighted(lua_State *L)
+{
+ kad_node_t *pred = lua_check_kann_node(L, 1);
+ kad_node_t *truth = lua_check_kann_node(L, 2);
+ kad_node_t *weight = lua_check_kann_node(L, 3);
+
+ if (pred != NULL && truth != NULL && weight != NULL) {
+ kad_node_t *t;
+ t = kad_ce_multi_weighted(pred, truth, weight);
+
+ PUSH_KAD_NODE(t);
+ }
+ else {
+ return luaL_error(L, "invalid arguments for ce_multi_weighted, 3 inputs required");
+ }
+
+ return 1;
+}
+
+/* Creation functions */
+static int
+lua_kann_new_scalar(lua_State *L)
+{
+ gint flag = luaL_checkinteger(L, 1);
+ double x = luaL_checknumber(L, 2);
+ kad_node_t *t;
+
+ t = kann_new_scalar(flag, x);
+
+ PROCESS_KAD_FLAGS(t, 3);
+ PUSH_KAD_NODE(t);
+
+ return 1;
+}
+
+static int
+lua_kann_new_weight(lua_State *L)
+{
+ gint nrow = luaL_checkinteger(L, 1);
+ gint ncol = luaL_checkinteger(L, 2);
+ kad_node_t *t;
+
+ t = kann_new_weight(nrow, ncol);
+
+ PROCESS_KAD_FLAGS(t, 3);
+ PUSH_KAD_NODE(t);
+
+ return 1;
+}
+
+static int
+lua_kann_new_bias(lua_State *L)
+{
+ gint n = luaL_checkinteger(L, 1);
+ kad_node_t *t;
+
+ t = kann_new_bias(n);
+
+ PROCESS_KAD_FLAGS(t, 2);
+ PUSH_KAD_NODE(t);
+
+ return 1;
+}
+
+static int
+lua_kann_new_weight_conv2d(lua_State *L)
+{
+ gint nout = luaL_checkinteger(L, 1);
+ gint nin = luaL_checkinteger(L, 2);
+ gint krow = luaL_checkinteger(L, 3);
+ gint kcol = luaL_checkinteger(L, 4);
+ kad_node_t *t;
+
+ t = kann_new_weight_conv2d(nout, nin, krow, kcol);
+
+ PROCESS_KAD_FLAGS(t, 5);
+ PUSH_KAD_NODE(t);
+
+ return 1;
+}
+
+static int
+lua_kann_new_weight_conv1d(lua_State *L)
+{
+ gint nout = luaL_checkinteger(L, 1);
+ gint nin = luaL_checkinteger(L, 2);
+ gint klen = luaL_checkinteger(L, 3);
+ kad_node_t *t;
+
+ t = kann_new_weight_conv1d(nout, nin, klen);
+
+ PROCESS_KAD_FLAGS(t, 4);
+ PUSH_KAD_NODE(t);
+
+ return 1;
+}
+
+static int
+lua_kann_new_leaf(lua_State *L)
+{
+ int dim = luaL_checkinteger(L, 1), i, *ar;
+ kad_node_t *t;
+
+ if (dim >= 1 && dim < KAD_MAX_DIM && lua_istable(L, 2)) {
+ ar = g_new0(int, KAD_MAX_DIM);
+
+ for (i = 0; i < dim; i++) {
+ lua_rawgeti(L, 2, i + 1);
+ ar[i] = lua_tointeger(L, -1);
+ lua_pop(L, 1);
+ }
+
+ t = kann_new_leaf_array(NULL, NULL, 0, 0.0, dim, ar);
+
+ PROCESS_KAD_FLAGS(t, 3);
+ PUSH_KAD_NODE(t);
+
+ g_free(ar);
+ }
+ else {
+ return luaL_error(L, "invalid arguments for new.leaf, "
+ "dim and vector of elements are required");
+ }
+
+ return 1;
+}
+
+static int
+lua_kann_new_kann(lua_State *L)
+{
+ kad_node_t *cost = lua_check_kann_node(L, 1);
+ kann_t *k;
+
+ if (cost) {
+ k = kann_new(cost, 0);
+
+ PUSH_KAN_NETWORK(k);
+ }
+ else {
+ return luaL_error(L, "invalid arguments for new.kann, "
+ "cost node is required");
+ }
+
+ return 1;
+}
+
+static int
+lua_kann_destroy(lua_State *L)
+{
+ kann_t *k = lua_check_kann(L, 1);
+
+ kann_delete(k);
+
+ return 0;
+}
+
+static int
+lua_kann_save(lua_State *L)
+{
+ kann_t *k = lua_check_kann(L, 1);
+
+ if (k) {
+ if (lua_istable(L, 2)) {
+ lua_getfield(L, 2, "filename");
+
+ if (lua_isstring(L, -1)) {
+ const gchar *fname = lua_tostring(L, -1);
+ FILE *f;
+
+ f = fopen(fname, "w");
+
+ if (!f) {
+ lua_pop(L, 1);
+
+ return luaL_error(L, "cannot open %s for writing: %s",
+ fname, strerror(errno));
+ }
+
+ kann_save_fp(f, k);
+ fclose(f);
+
+ lua_pushboolean(L, true);
+ }
+ else {
+ lua_pop(L, 1);
+
+ return luaL_error(L, "invalid arguments: missing filename");
+ }
+
+ lua_pop(L, 1);
+ }
+ else {
+ /* Save to Rspamd text */
+#ifndef HAVE_OPENMEMSTREAM
+ return luaL_error(L, "no support of saving to memory on your system");
+#endif
+ FILE *f;
+ char *buf = NULL;
+ size_t buflen;
+ struct rspamd_lua_text *t;
+
+ f = open_memstream(&buf, &buflen);
+ g_assert(f != NULL);
+
+ kann_save_fp(f, k);
+ fclose(f);
+
+ t = lua_newuserdata(L, sizeof(*t));
+ rspamd_lua_setclass(L, "rspamd{text}", -1);
+ t->flags = RSPAMD_TEXT_FLAG_OWN;
+ t->start = (const gchar *) buf;
+ t->len = buflen;
+ }
+ }
+ else {
+ return luaL_error(L, "invalid arguments");
+ }
+
+ return 1;
+}
+
+static int
+lua_kann_load(lua_State *L)
+{
+ kann_t *k;
+ FILE *f = NULL;
+
+ if (lua_istable(L, 1)) {
+ lua_getfield(L, 2, "filename");
+
+ if (lua_isstring(L, -1)) {
+ const gchar *fname = lua_tostring(L, -1);
+
+ f = fopen(fname, "rb");
+ }
+ else {
+ lua_pop(L, 1);
+
+ return luaL_error(L, "invalid arguments: missing filename");
+ }
+
+ lua_pop(L, 1);
+ }
+ else if (lua_isstring(L, 1)) {
+ gsize dlen;
+ const gchar *data;
+
+ data = lua_tolstring(L, 1, &dlen);
+
+#ifndef HAVE_FMEMOPEN
+ return luaL_error(L, "no support of loading from memory on your system");
+#endif
+ f = fmemopen((void *) data, dlen, "rb");
+ }
+ else if (lua_isuserdata(L, 1)) {
+ struct rspamd_lua_text *t;
+
+ t = lua_check_text(L, 1);
+
+ if (!t) {
+ return luaL_error(L, "invalid arguments");
+ }
+
+#ifndef HAVE_FMEMOPEN
+ return luaL_error(L, "no support of loading from memory on your system");
+#endif
+ f = fmemopen((void *) t->start, t->len, "rb");
+ }
+
+ if (f == NULL) {
+ return luaL_error(L, "invalid arguments or cannot open file");
+ }
+
+ k = kann_load_fp(f);
+ fclose(f);
+
+ if (k == NULL) {
+ lua_pushnil(L);
+ }
+ else {
+ PUSH_KAN_NETWORK(k);
+ }
+
+ return 1;
+}
+
+struct rspamd_kann_train_cbdata {
+ lua_State *L;
+ kann_t *k;
+ gint cbref;
+};
+
+static void
+lua_kann_train_cb(int iter, float train_cost, float val_cost, void *ud)
+{
+ struct rspamd_kann_train_cbdata *cbd = (struct rspamd_kann_train_cbdata *) ud;
+
+ if (cbd->cbref != -1) {
+ gint err_idx;
+ lua_State *L = cbd->L;
+
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ err_idx = lua_gettop(L);
+
+ lua_rawgeti(L, LUA_REGISTRYINDEX, cbd->cbref);
+ lua_pushinteger(L, iter);
+ lua_pushnumber(L, train_cost);
+ lua_pushnumber(L, val_cost);
+
+ if (lua_pcall(L, 3, 0, err_idx) != 0) {
+ msg_err("cannot run lua train callback: %s",
+ lua_tostring(L, -1));
+ }
+
+ lua_settop(L, err_idx - 1);
+ }
+}
+
+#define FREE_VEC(a, n) \
+ do { \
+ for (int i = 0; i < (n); i++) g_free((a)[i]); \
+ g_free(a); \
+ } while (0)
+
+static int
+lua_kann_train1(lua_State *L)
+{
+ kann_t *k = lua_check_kann(L, 1);
+ struct rspamd_lua_tensor *pca = NULL;
+
+ /* Default train params */
+ double lr = 0.001;
+ gint64 mini_size = 64;
+ gint64 max_epoch = 25;
+ gint64 max_drop_streak = 10;
+ double frac_val = 0.1;
+ gint cbref = -1;
+
+ if (k && lua_istable(L, 2) && lua_istable(L, 3)) {
+ int n = rspamd_lua_table_size(L, 2);
+ int n_in = kann_dim_in(k);
+ int n_out = kann_dim_out(k);
+
+ if (n_in <= 0) {
+ return luaL_error(L, "invalid inputs count: %d", n_in);
+ }
+
+ if (n_out <= 0) {
+ return luaL_error(L, "invalid outputs count: %d", n_out);
+ }
+
+ if (n != rspamd_lua_table_size(L, 3) || n == 0) {
+ return luaL_error(L, "invalid dimensions: outputs size must be "
+ "equal to inputs and non zero");
+ }
+
+ if (lua_istable(L, 4)) {
+ GError *err = NULL;
+
+ if (!rspamd_lua_parse_table_arguments(L, 4, &err,
+ RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING,
+ "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F;pca=u{tensor}",
+ &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref, &pca)) {
+ n = luaL_error(L, "invalid params: %s",
+ err ? err->message : "unknown error");
+ g_error_free(err);
+
+ return n;
+ }
+ }
+
+ if (pca) {
+ /* Check pca matrix validity */
+ if (pca->ndims != 2) {
+ return luaL_error(L, "invalid pca tensor: matrix expected, got a row");
+ }
+
+ if (pca->dim[0] != n_in) {
+ return luaL_error(L, "invalid pca tensor: "
+ "matrix must have %d rows and it has %d rows instead",
+ n_in, pca->dim[0]);
+ }
+ }
+
+ float **x, **y, *tmp_row = NULL;
+
+ /* Fill vectors row by row */
+ x = (float **) g_malloc0(sizeof(float *) * n);
+ y = (float **) g_malloc0(sizeof(float *) * n);
+
+ if (pca) {
+ tmp_row = g_malloc(sizeof(float) * pca->dim[1]);
+ }
+
+ for (int s = 0; s < n; s++) {
+ /* Inputs */
+ lua_rawgeti(L, 2, s + 1);
+ x[s] = (float *) g_malloc(sizeof(float) * n_in);
+
+ if (pca == NULL) {
+ if (rspamd_lua_table_size(L, -1) != n_in) {
+ FREE_VEC(x, n);
+ FREE_VEC(y, n);
+
+ n = luaL_error(L, "invalid params at pos %d: "
+ "bad input dimension %d; %d expected",
+ s + 1,
+ (int) rspamd_lua_table_size(L, -1),
+ n_in);
+ lua_pop(L, 1);
+
+ return n;
+ }
+
+ for (int i = 0; i < n_in; i++) {
+ lua_rawgeti(L, -1, i + 1);
+ x[s][i] = lua_tonumber(L, -1);
+
+ lua_pop(L, 1);
+ }
+ }
+ else {
+ if (rspamd_lua_table_size(L, -1) != pca->dim[1]) {
+ FREE_VEC(x, n);
+ FREE_VEC(y, n);
+ g_free(tmp_row);
+
+ n = luaL_error(L, "(pca on) invalid params at pos %d: "
+ "bad input dimension %d; %d expected",
+ s + 1,
+ (int) rspamd_lua_table_size(L, -1),
+ pca->dim[1]);
+ lua_pop(L, 1);
+
+ return n;
+ }
+
+
+ for (int i = 0; i < pca->dim[1]; i++) {
+ lua_rawgeti(L, -1, i + 1);
+ tmp_row[i] = lua_tonumber(L, -1);
+
+ lua_pop(L, 1);
+ }
+
+ kad_sgemm_simple(0, 1, 1, n_in,
+ pca->dim[1], tmp_row, pca->data,
+ x[s]);
+ }
+
+ lua_pop(L, 1);
+
+ /* Outputs */
+ y[s] = (float *) g_malloc(sizeof(float) * n_out);
+ lua_rawgeti(L, 3, s + 1);
+
+ if (rspamd_lua_table_size(L, -1) != n_out) {
+ FREE_VEC(x, n);
+ FREE_VEC(y, n);
+ g_free(tmp_row);
+
+ n = luaL_error(L, "invalid params at pos %d: "
+ "bad output dimension %d; "
+ "%d expected",
+ s + 1,
+ (int) rspamd_lua_table_size(L, -1),
+ n_out);
+ lua_pop(L, 1);
+
+ return n;
+ }
+
+ for (int i = 0; i < n_out; i++) {
+ lua_rawgeti(L, -1, i + 1);
+ y[s][i] = lua_tonumber(L, -1);
+
+ lua_pop(L, 1);
+ }
+
+ lua_pop(L, 1);
+ }
+
+ struct rspamd_kann_train_cbdata cbd;
+
+ cbd.cbref = cbref;
+ cbd.k = k;
+ cbd.L = L;
+
+ int niters = kann_train_fnn1(k, lr,
+ mini_size, max_epoch, max_drop_streak,
+ frac_val, n, x, y, lua_kann_train_cb, &cbd);
+
+ lua_pushinteger(L, niters);
+
+ FREE_VEC(x, n);
+ FREE_VEC(y, n);
+ g_free(tmp_row);
+ }
+ else {
+ return luaL_error(L, "invalid arguments: kann, inputs, outputs and"
+ " optional params are expected");
+ }
+
+ return 1;
+}
+
+static int
+lua_kann_apply1(lua_State *L)
+{
+ kann_t *k = lua_check_kann(L, 1);
+ struct rspamd_lua_tensor *pca = NULL;
+
+ if (k) {
+ if (lua_istable(L, 2)) {
+ gsize vec_len = rspamd_lua_table_size(L, 2);
+ float *vec = (float *) g_malloc(sizeof(float) * vec_len),
+ *pca_out = NULL;
+ int i_out;
+ int n_in = kann_dim_in(k);
+
+ if (n_in <= 0) {
+ g_free(vec);
+ return luaL_error(L, "invalid inputs count: %d", n_in);
+ }
+
+ if (lua_isuserdata(L, 3)) {
+ pca = lua_check_tensor(L, 3);
+
+ if (pca) {
+ if (pca->ndims != 2) {
+ g_free(vec);
+ return luaL_error(L, "invalid pca tensor: matrix expected, got a row");
+ }
+
+ if (pca->dim[0] != n_in) {
+ g_free(vec);
+ return luaL_error(L, "invalid pca tensor: "
+ "matrix must have %d rows and it has %d rows instead",
+ n_in, pca->dim[0]);
+ }
+ }
+ else {
+ g_free(vec);
+ return luaL_error(L, "invalid params: pca matrix expected");
+ }
+ }
+ else {
+ if (n_in != vec_len) {
+ g_free(vec);
+ return luaL_error(L, "invalid params: bad input dimension %d; %d expected",
+ (int) vec_len, n_in);
+ }
+ }
+
+ for (gsize i = 0; i < vec_len; i++) {
+ lua_rawgeti(L, 2, i + 1);
+ vec[i] = lua_tonumber(L, -1);
+ lua_pop(L, 1);
+ }
+
+ i_out = kann_find(k, KANN_F_OUT, 0);
+
+ if (i_out <= 0) {
+ g_free(vec);
+ return luaL_error(L, "invalid ANN: output layer is missing or is "
+ "at the input pos");
+ }
+
+ kann_set_batch_size(k, 1);
+ if (pca) {
+ pca_out = g_malloc(sizeof(float) * n_in);
+
+ kad_sgemm_simple(0, 1, 1, n_in,
+ vec_len, vec, pca->data,
+ pca_out);
+
+ kann_feed_bind(k, KANN_F_IN, 0, &pca_out);
+ }
+ else {
+ kann_feed_bind(k, KANN_F_IN, 0, &vec);
+ }
+
+ kad_eval_at(k->n, k->v, i_out);
+
+ gsize outlen = kad_len(k->v[i_out]);
+ lua_createtable(L, outlen, 0);
+
+ for (gsize i = 0; i < outlen; i++) {
+ lua_pushnumber(L, k->v[i_out]->x[i]);
+ lua_rawseti(L, -2, i + 1);
+ }
+
+ g_free(vec);
+ g_free(pca_out);
+ }
+ else if (lua_isuserdata(L, 2)) {
+ struct rspamd_lua_tensor *t = lua_check_tensor(L, 2);
+
+ if (t && t->ndims == 1) {
+ int i_out;
+ int n_in = kann_dim_in(k);
+
+ if (n_in != t->dim[0]) {
+ return luaL_error(L, "invalid params: bad input dimension %d; %d expected",
+ (int) t->dim[0], n_in);
+ }
+
+ i_out = kann_find(k, KANN_F_OUT, 0);
+
+ if (i_out <= 0) {
+ return luaL_error(L, "invalid ANN: output layer is missing or is "
+ "at the input pos");
+ }
+
+ kann_set_batch_size(k, 1);
+ kann_feed_bind(k, KANN_F_IN, 0, &t->data);
+ kad_eval_at(k->n, k->v, i_out);
+
+ gint outlen = kad_len(k->v[i_out]);
+ struct rspamd_lua_tensor *out;
+ out = lua_newtensor(L, 1, &outlen, false, false);
+ /* Ensure that kann and tensor have the same understanding of floats */
+ G_STATIC_ASSERT(sizeof(float) == sizeof(rspamd_tensor_num_t));
+ memcpy(out->data, k->v[i_out]->x, outlen * sizeof(float));
+ }
+ else {
+ return luaL_error(L, "invalid arguments: 1D rspamd{tensor} expected");
+ }
+ }
+ else {
+ return luaL_error(L, "invalid arguments: 1D rspamd{tensor} expected");
+ }
+ }
+ else {
+ return luaL_error(L, "invalid arguments: rspamd{kann} expected");
+ }
+
+ return 1;
+} \ No newline at end of file