summaryrefslogtreecommitdiffstats
path: root/test/lua/unit/kann.lua
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-10 21:30:40 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-10 21:30:40 +0000
commit133a45c109da5310add55824db21af5239951f93 (patch)
treeba6ac4c0a950a0dda56451944315d66409923918 /test/lua/unit/kann.lua
parentInitial commit. (diff)
downloadrspamd-133a45c109da5310add55824db21af5239951f93.tar.xz
rspamd-133a45c109da5310add55824db21af5239951f93.zip
Adding upstream version 3.8.1.upstream/3.8.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'test/lua/unit/kann.lua')
-rw-r--r--test/lua/unit/kann.lua46
1 files changed, 46 insertions, 0 deletions
diff --git a/test/lua/unit/kann.lua b/test/lua/unit/kann.lua
new file mode 100644
index 0000000..4f8185b
--- /dev/null
+++ b/test/lua/unit/kann.lua
@@ -0,0 +1,46 @@
+-- Simple kann test (xor function vs 2 layer MLP)
+
+context("Kann test", function()
+ local kann = require "rspamd_kann"
+ local k
+ local inputs = {
+ {0, 0},
+ {0, 1},
+ {1, 0},
+ {1, 1}
+ }
+
+ local outputs = {
+ {0},
+ {1},
+ {1},
+ {0}
+ }
+
+ local t = kann.layer.input(2)
+ t = kann.transform.relu(t)
+ t = kann.transform.tanh(kann.layer.dense(t, 2));
+ t = kann.layer.cost(t, 1, kann.cost.mse)
+ k = kann.new.kann(t)
+
+ local iters = 500
+ local niter = k:train1(inputs, outputs, {
+ lr = 0.01,
+ max_epoch = iters,
+ mini_size = 80,
+ })
+
+ local ser = k:save()
+ k = kann.load(ser)
+
+ for i,inp in ipairs(inputs) do
+ test(string.format("Check XOR MLP %s ^ %s == %s", inp[1], inp[2], outputs[i][1]),
+ function()
+ local res = math.floor(k:apply1(inp)[1] + 0.5)
+ assert_equal(outputs[i][1], res,
+ tostring(outputs[i][1]) .. " but test returned " .. tostring(res))
+ end)
+ end
+
+
+end) \ No newline at end of file