summaryrefslogtreecommitdiffstats
path: root/modules/view/view.lua
blob: f5e186261d3fb4aa268566172543b4e046d38904 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
-- SPDX-License-Identifier: GPL-3.0-or-later
local kres = require('kres')
local ffi = require('ffi')
local C = ffi.C

-- Module declaration.
-- Rule stores, consulted by evaluate() in this order: key, src, dst.
local view = {}
-- TSIG key name (as returned by :owner()) -> list of policy rules.
view.key = {}
-- Entries {family, subnet_bytes, bitlen, rules} matched on the source address.
view.src = {}
-- Entries of the same shape, matched on the destination address.
view.dst = {}

-- @function View based on TSIG key name.
-- Appends `rule` to the list of rules registered under `tsig`,
-- creating that list on first use. Rules keep registration order.
function view.tsig(_, tsig, rule)
	local registered = view.key[tsig]
	if registered == nil then
		view.key[tsig] = { rule }
		return
	end
	registered[#registered + 1] = rule
end

-- @function View based on an IP subnet.
-- Parses `subnet` and registers `rules` against it. The entry goes into
-- view.dst when `dst` is truthy (match on destination address), otherwise
-- into view.src (match on source address).
-- Raises an error when the subnet string cannot be parsed.
-- Returns the stored entry {family, subnet_bytes, bitlen, rules}.
function view.addr(_, subnet, rules, dst)
	local prefix_buf = ffi.new('char[16]')
	local addr_family = C.kr_straddr_family(subnet)
	local prefix_len = C.kr_straddr_subnet(prefix_buf, subnet)
	-- A negative prefix length signals a parse failure.
	if prefix_len < 0 then
		error(string.format('failed to parse subnet %s', subnet))
	end
	local entry = { addr_family, prefix_buf, prefix_len, rules }
	local bucket
	if dst then
		bucket = view.dst
	else
		bucket = view.src
	end
	bucket[#bucket + 1] = entry
	return entry
end

-- @function Match IP against given subnet
-- True when `addr` has address family `family` and its first `bitlen`
-- bits equal those of `subnet` (compared via C.kr_bitcmp).
-- Returns false without comparing bits when the families differ.
local function match_subnet(family, subnet, bitlen, addr)
	if addr:family() ~= family then
		return false
	end
	return C.kr_bitcmp(subnet, addr:ip(), bitlen) == 0
end

-- @function Execute a policy callback (may be nil).
-- `match_cb(req, query)` returns an action or nil. A returned action is
-- invoked as `action(state, req)`. A truthy result is stored in req.state
-- and ends rule processing; a falsy result means a chain rule.
-- Returns true when further rules should NOT be tried.
local function execute(state, req, match_cb)
	if match_cb ~= nil then
		local chosen = match_cb(req, req:current())
		if chosen ~= nil then
			local produced = chosen(state, req)
			if produced then
				req.state = produced
				return true
			end
		end
	end
	-- No callback, no match, or a chain rule: keep evaluating.
	return false
end

-- @function Run the rules of every entry in `list` whose subnet contains
-- `addr` (entries are {family, subnet_bytes, bitlen, rules}).
-- Returns true once a non-chain rule has been executed; false if `addr`
-- is nil or no entry stopped evaluation.
local function try_subnet_rules(state, req, list, addr)
	if addr == nil then return false end
	for i = 1, #list do
		local entry = list[i]
		if match_subnet(entry[1], entry[2], entry[3], addr)
				and execute(state, req, entry[4]) then
			return true
		end
	end
	return false
end

-- @function Try all the rules in order, until a non-chain rule gets executed.
-- Order: :tsig rules for the request's TSIG key, then :addr rules matched
-- on the source address, then :addr(..., true) rules matched on the
-- destination address. The destination pass runs whenever the earlier
-- passes did not stop evaluation. Previously it was an `elseif` that was
-- reached only when the source address was nil, so destination rules never
-- applied to ordinary client requests.
local function evaluate(state, req)
	-- Try :tsig rules first.
	local client_key = req.qsource.packet.tsig_rr
	local match_cbs = (client_key ~= nil) and view.key[client_key:owner()] or {}
	for _, match_cb in ipairs(match_cbs) do
		if execute(state, req, match_cb) then return end
	end
	-- Then try :addr by the source.
	if try_subnet_rules(state, req, view.src, req.qsource.addr) then return end
	-- Finally try :addr by the destination.
	try_subnet_rules(state, req, view.dst, req.qsource.dst_addr)
end

-- @function Return policy based on source address
-- Builds a match callback `function(req, query)` that returns `action` when
-- req.qsource.addr is present and lies within `subnet`, and nil otherwise.
-- Raises an error if `subnet` cannot be parsed, as view.addr does.
function view.rule_src(action, subnet)
	local subnet_cd = ffi.new('char[16]')
	local family = C.kr_straddr_family(subnet)
	local bitlen = C.kr_straddr_subnet(subnet_cd, subnet)
	-- A negative bitlen signals a parse failure. Reject it here instead of
	-- building a rule around an unparsed prefix.
	if bitlen < 0 then
		error(string.format('failed to parse subnet %s', subnet))
	end
	return function(req, _)
		local addr = req.qsource.addr
		if addr ~= nil and match_subnet(family, subnet_cd, bitlen, addr) then
			return action
		end
	end
end

-- @function Return policy based on destination address
-- Builds a match callback `function(req, query)` that returns `action` when
-- req.qsource.dst_addr is present and lies within `subnet`, and nil otherwise.
-- Raises an error if `subnet` cannot be parsed, as view.addr does.
function view.rule_dst(action, subnet)
	local subnet_cd = ffi.new('char[16]')
	local family = C.kr_straddr_family(subnet)
	local bitlen = C.kr_straddr_subnet(subnet_cd, subnet)
	-- A negative bitlen signals a parse failure. Reject it here instead of
	-- building a rule around an unparsed prefix.
	if bitlen < 0 then
		error(string.format('failed to parse subnet %s', subnet))
	end
	return function(req, _)
		local addr = req.qsource.dst_addr
		if addr ~= nil and match_subnet(family, subnet_cd, bitlen, addr) then
			return action
		end
	end
end

-- @function Module layers
-- `begin`: evaluate view rules for requests that are not yet in the
-- FAIL or DONE state, then report req.state (which a matching non-chain
-- rule may have updated). States with FAIL/DONE set are passed through.
view.layer = {
	begin = function(state, req)
		local finished_mask = bit.bor(kres.FAIL, kres.DONE)
		if bit.band(state, finished_mask) == 0 then
			evaluate(state, req)
			return req.state
		end
		return state
	end
}

return view