summaryrefslogtreecommitdiffstats
path: root/test/lua/unit/kann.lua
blob: 4f8185b023be67f107606b67422211936c083488 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
-- Simple kann test (xor function vs 2 layer MLP).
-- Trains a small network on the XOR truth table, round-trips it through
-- save()/load(), and checks that the reloaded model reproduces XOR.

context("Kann test", function()
  local kann = require "rspamd_kann"

  -- XOR truth table: inputs[i] is expected to produce outputs[i]
  local inputs = {
    {0, 0},
    {0, 1},
    {1, 0},
    {1, 1}
  }

  local outputs = {
    {0},
    {1},
    {1},
    {0}
  }

  -- Topology: input(2) -> relu -> dense(2) -> tanh -> cost(1 output, MSE).
  -- NOTE(review): relu on the raw inputs is effectively a no-op here, since
  -- every input value is 0 or 1.
  -- NOTE(review): in upstream KANN, kann_layer_cost() inserts its own dense
  -- layer of the given width before the cost node, which would make this a
  -- 2-layer MLP; confirm the same holds for rspamd_kann.
  local t = kann.layer.input(2)
  t = kann.transform.relu(t)
  t = kann.transform.tanh(kann.layer.dense(t, 2))
  t = kann.layer.cost(t, 1, kann.cost.mse)
  local model = kann.new.kann(t)

  local iters = 500
  -- Only the side effect (updated weights) is needed; the return value of
  -- train1 is not used by this test.
  -- NOTE(review): mini_size (80) exceeds the 4 training samples; presumably
  -- every mini-batch is then the whole set -- confirm against train1.
  model:train1(inputs, outputs, {
    lr = 0.01,
    max_epoch = iters,
    mini_size = 80,
  })

  -- Run the checks against the deserialized copy so that save()/load() is
  -- covered in addition to training.
  local ser = model:save()
  local k = kann.load(ser)

  for i, inp in ipairs(inputs) do
    local expected = outputs[i][1]
    test(string.format("Check XOR MLP %s ^ %s == %s", inp[1], inp[2], expected),
        function()
          -- Round the continuous network output to the nearest integer label.
          local res = math.floor(k:apply1(inp)[1] + 0.5)
          assert_equal(expected, res,
              string.format("expected %s but test returned %s",
                  tostring(expected), tostring(res)))
        end)
  end
end)