122 lines
3.4 KiB
Lua
122 lines
3.4 KiB
Lua
-- SPDX-License-Identifier: GPL-3.0-or-later
|
|
local kres = require('kres')
|
|
local ffi = require('ffi')
|
|
local C = ffi.C
|
|
|
|
-- Module declaration
|
|
local view = {
|
|
key = {}, -- map from :owner() to list of policy rules
|
|
src = {},
|
|
dst = {},
|
|
}
|
|
|
|
-- @function View based on TSIG key name.
|
|
function view.tsig(_, tsig, rule)
|
|
if view.key[tsig] == nil then
|
|
view.key[tsig] = { rule }
|
|
else
|
|
table.insert(view.key[tsig], rule)
|
|
end
|
|
end
|
|
|
|
-- @function View based on source IP subnet.
|
|
function view.addr(_, subnet, rules, dst)
|
|
local subnet_cd = ffi.new('char[16]')
|
|
local family = C.kr_straddr_family(subnet)
|
|
local bitlen = C.kr_straddr_subnet(subnet_cd, subnet)
|
|
if bitlen < 0 then
|
|
error(string.format('failed to parse subnet %s', subnet))
|
|
end
|
|
local t = {family, subnet_cd, bitlen, rules}
|
|
table.insert(dst and view.dst or view.src, t)
|
|
return t
|
|
end
|
|
|
|
-- @function Match IP against given subnet
|
|
local function match_subnet(family, subnet, bitlen, addr)
|
|
return (family == addr:family()) and (C.kr_bitcmp(subnet, addr:ip(), bitlen) == 0)
|
|
end
|
|
|
|
-- @function Execute a policy callback (may be nil);
|
|
-- return boolean: whether to continue trying further rules.
|
|
local function execute(state, req, match_cb)
|
|
if match_cb == nil then return false end
|
|
local action = match_cb(req, req:current())
|
|
if action == nil then return false end
|
|
local next_state = action(state, req)
|
|
if next_state then -- Not a chain rule,
|
|
req.state = next_state
|
|
return true
|
|
else
|
|
return false
|
|
end
|
|
end
|
|
|
|
-- @function Try all the rules in order, until a non-chain rule gets executed.
|
|
local function evaluate(state, req)
|
|
-- Try :tsig rules first.
|
|
local client_key = req.qsource.packet.tsig_rr
|
|
local match_cbs = (client_key ~= nil) and view.key[client_key:owner()] or {}
|
|
for _, match_cb in ipairs(match_cbs) do
|
|
if execute(state, req, match_cb) then return end
|
|
end
|
|
-- Then try :addr by the source.
|
|
if req.qsource.addr ~= nil then
|
|
for i = 1, #view.src do
|
|
local pair = view.src[i]
|
|
if match_subnet(pair[1], pair[2], pair[3], req.qsource.addr) then
|
|
local match_cb = pair[4]
|
|
if execute(state, req, match_cb) then return end
|
|
end
|
|
end
|
|
end
|
|
-- Finally try :addr by the destination.
|
|
if req.qsource.dst_addr ~= nil then
|
|
for i = 1, #view.dst do
|
|
local pair = view.dst[i]
|
|
if match_subnet(pair[1], pair[2], pair[3], req.qsource.dst_addr) then
|
|
local match_cb = pair[4]
|
|
if execute(state, req, match_cb) then return end
|
|
end
|
|
end
|
|
end
|
|
end
|
|
|
|
-- @function Return policy based on source address
|
|
function view.rule_src(action, subnet)
|
|
local subnet_cd = ffi.new('char[16]')
|
|
local family = C.kr_straddr_family(subnet)
|
|
local bitlen = C.kr_straddr_subnet(subnet_cd, subnet)
|
|
return function(req, _)
|
|
local addr = req.qsource.addr
|
|
if addr ~= nil and match_subnet(family, subnet_cd, bitlen, addr) then
|
|
return action
|
|
end
|
|
end
|
|
end
|
|
|
|
-- @function Return policy based on destination address
|
|
function view.rule_dst(action, subnet)
|
|
local subnet_cd = ffi.new('char[16]')
|
|
local family = C.kr_straddr_family(subnet)
|
|
local bitlen = C.kr_straddr_subnet(subnet_cd, subnet)
|
|
return function(req, _)
|
|
local addr = req.qsource.dst_addr
|
|
if addr ~= nil and match_subnet(family, subnet_cd, bitlen, addr) then
|
|
return action
|
|
end
|
|
end
|
|
end
|
|
|
|
-- @function Module layers
|
|
view.layer = {
|
|
begin = function(state, req)
|
|
-- Don't act on "finished" cases.
|
|
if bit.band(state, bit.bor(kres.FAIL, kres.DONE)) ~= 0 then return state end
|
|
|
|
evaluate(state, req)
|
|
return req.state
|
|
end
|
|
}
|
|
|
|
return view
|