summaryrefslogtreecommitdiffstats
path: root/test/functional/lua/neural.lua
blob: 5a09c50fccba4e138624dd5f3f164f27b08e4d3f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
local logger = require "rspamd_logger"

-- Register 14 numbered pairs of symbols whose callbacks always fire:
-- SPAM_SYMBOL<n> with score 5.0 followed by HAM_SYMBOL<n> with score -3.0.
do
  -- Registers one symbol named `prefix .. n` with the given score and a
  -- fresh callback that unconditionally returns true and a fixed option.
  local function register_always_firing(prefix, n, weight)
    rspamd_config:register_symbol({
      name = prefix .. tostring(n),
      score = weight,
      callback = function()
        return true, 'Fires always'
      end
    })
  end

  for n = 1, 14 do
    register_always_firing('SPAM_SYMBOL', n, 5.0)
    register_always_firing('HAM_SYMBOL', n, -3.0)
  end
end



-- An always-firing symbol with score 1.0, registered with the
-- 'explicit_disable' flag.
local neutral_symbol = {
  name = 'NEUTRAL_SYMBOL',
  score = 1.0,
  flags = 'explicit_disable',
}

function neutral_symbol.callback()
  return true, 'Fires always'
end

rspamd_config:register_symbol(neutral_symbol)

-- Obtains a temporary file name, stores it in the task cache under
-- 'nn_row_tmpfile', and fires the symbol with weight 1.0 and the file
-- name as its option.
local function save_nn_row(task)
  local tmp_path = os.tmpname()
  task:cache_set('nn_row_tmpfile', tmp_path)
  return true, 1.0, tmp_path
end

rspamd_config.SAVE_NN_ROW = {
  callback = save_nn_row
}

-- Idempotent rule that writes the task's cached 'SHORT_neural_vec_mpack'
-- value, hex-encoded, into the file whose name is stored in the task cache
-- under 'nn_row_tmpfile'. Does nothing when no file name is cached; writes
-- an empty file when the vector is absent.
rspamd_config.SAVE_NN_ROW_IDEMPOTENT = {
  callback = function(task)
    -- Encodes every byte of `str` as two uppercase hexadecimal digits.
    local function tohex(str)
      return (str:gsub('.', function (c)
        return string.format('%02X', string.byte(c))
      end))
    end
    local fname = task:cache_get('nn_row_tmpfile')
    if not fname then
      return
    end
    local f, err = io.open(fname, 'w')
    if not f then
      -- Pass the message as an argument, not as the format string, so a
      -- '%' inside the path or OS message is not treated as a placeholder.
      logger.errx(task, '%s', err)
      return
    end
    local ok, werr = f:write(tohex(task:cache_get('SHORT_neural_vec_mpack') or ''))
    f:close()
    if not ok then
      -- Report a failed dump instead of silently leaving a bad file behind.
      logger.errx(task, '%s', werr)
    end
    return
  end,
  type = 'idempotent',
  flags = 'explicit_disable',
  priority = 100,
}

-- Execute the controller rules init script located under the install root.
local controller_init_path = rspamd_env.INSTALLROOT .. "/share/rspamd/rules/controller/init.lua"
dofile(controller_init_path)