summaryrefslogtreecommitdiffstats
path: root/modules/daf/daf.lua
blob: c3b089bf8e7b979fba2d9cd89f716ec343b30d7a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
-- SPDX-License-Identifier: GPL-3.0-or-later

local ffi = require('ffi')

-- Load dependent modules
if not view then modules.load('view') end
if not policy then modules.load('policy') end

-- Actions: each entry is called with the token iterator `g` positioned right
-- after the action keyword, consumes its arguments from it and returns the
-- policy action to apply. Malformed arguments raise an error.
local actions = {
	pass = function() return policy.PASS end,
	deny = function () return policy.DENY end,
	drop = function() return policy.DROP end,
	tc = function() return policy.TC end,
	truncate = function() return policy.TC end,
	-- forward <addr>[,<addr>...]
	forward = function (g)
		local tok = g()
		local addrs = {}
		if tok then
			for addr in tok:gmatch('[^,]+') do
				addrs[#addrs + 1] = addr
			end
		end
		-- Reject a missing or empty (e.g. ",") address list up front
		if #addrs == 0 then
			error('forward requires at least one address')
		end
		return policy.FORWARD(addrs)
	end,
	-- mirror <addr>
	mirror = function (g)
		local addr = g()
		if not addr then
			error('mirror requires an address')
		end
		return policy.MIRROR(addr)
	end,
	-- reroute <from>-<to> [<from>-<to> ...]
	reroute = function (g)
		local rules = {}
		local tok = g()
		while tok do
			local from, to = tok:match '([^-]+)-(%S+)'
			-- Without this check a token lacking '-' fails as "table index is nil"
			if not from then
				error(string.format('invalid reroute rule "%s", expected <from>-<to>', tok))
			end
			rules[from] = to
			tok = g()
		end
		return policy.REROUTE(rules)
	end,
	-- rewrite <owner> <type> <addr> [<owner> <type> <addr> ...]
	rewrite = function (g)
		local rules = {}
		local tok = g()
		while tok do
			-- This is currently limited to A/AAAA rewriting
			-- in fixed format '<owner> <type> <addr>'; <type> is skipped
			local _, to = g(), g()
			rules[tok] = to
			tok = g()
		end
		-- Same constructor as reroute; the second argument `true` presumably
		-- selects rewrite mode in policy.REROUTE (confirm in the policy module)
		return policy.REROUTE(rules, true)
	end,
}

-- Build the parser for an address column (src/dst). Only the '=' operator is
-- accepted; `builder` names the `view` function that builds the filter from
-- the address and is looked up when the rule is parsed.
local function address_filter(builder)
	return function (g)
		if g() ~= '=' then
			error('address supports only "=" operator')
		end
		return view[builder](true, g())
	end
end

-- Filter rules per column: each parser consumes '<op> <operand>' from the
-- token iterator `g` and returns the filter built by the policy/view module.
local filters = {
	-- QNAME: '=' is a suffix match, '~' a pattern match
	qname = function (g)
		local op = g()
		local name = todname(g())
		if op == '=' then
			return policy.suffix(true, {name})
		elseif op == '~' then
			-- Skip leading label length
			return policy.pattern(true, name:sub(2))
		end
		error(string.format('invalid operator "%s" on qname', op))
	end,
	-- Source address
	src = address_filter('rule_src'),
	-- Destination address
	dst = address_filter('rule_dst'),
}

-- Look up the filter named by `tok` (case-insensitive) and let it consume its
-- operator and operand from the token iterator `g`.
-- `prev` is the preceding token; it only serves to word the error raised
-- when `tok` is missing.
local function parse_filter(tok, g, prev)
	if not tok then
		error(string.format('expected filter after "%s"', prev))
	end
	local build = filters[tok:lower()]
	if not build then
		error(string.format('invalid filter "%s"', tok))
	end
	return build(g)
end

-- Parse the optional filter chain at the start of a rule.
-- Filters may be joined with AND / OR (case-insensitive); the chain is folded
-- left to right with no operator precedence.
-- @param g token iterator
-- @return the first token after the filter chain (normally the action keyword;
--         nil when the input ends there) and the composed filter function, or
--         nil as the filter when the rule starts directly with an action
local function parse_rule(g)
	local tok = g()
	if tok == nil then
		error('empty rule is not allowed')
	end
	-- Allow action without filter
	if not filters[tok:lower()] then
		return tok, nil
	end
	local composed = parse_filter(tok, g)
	local conj = g()
	while conj do
		local kind = conj:lower()
		if kind ~= 'and' and kind ~= 'or' then
			break
		end
		local lhs, rhs = composed, parse_filter(g(), g, conj)
		if kind == 'and' then
			composed = function (req, qry) return lhs(req, qry) and rhs(req, qry) end
		else
			composed = function (req, qry) return lhs(req, qry) or rhs(req, qry) end
		end
		conj = g()
	end
	return conj, composed
end

-- Parse a tokenized query: an optional filter chain followed by an action
-- keyword and the action's arguments.
-- @param g token iterator (e.g. from string.gmatch)
-- @return lower-cased action name, action value and filter function (nil when
--         the rule has no filter) on success; nil and an error message on
--         any parse failure (filter, missing/unknown action, bad arguments)
local function parse_query(g)
	local ok, actid, filter = pcall(parse_rule, g)
	if not ok then return nil, actid end
	-- A filter chain with nothing after it leaves no action token; without this
	-- check `actid:lower()` would raise instead of returning an error
	if actid == nil then return nil, 'missing action after filter' end
	actid = actid:lower()
	local action = actions[actid]
	if not action then return nil, string.format('invalid action "%s"', actid) end
	-- Parse and interpret action arguments; malformed arguments are reported
	-- the same way as malformed filters instead of escaping as a raw error
	if type(action) == 'function' then
		local parsed
		ok, parsed = pcall(action, g)
		if not ok then return nil, parsed end
		action = parsed
	end
	return actid, action, filter
end

-- Compile a rule described by query language
-- The query language is modelled by iptables/nftables; tokens are separated
-- by whitespace and keywords are case-insensitive:
--   conj   = AND | OR                  (folded left to right, no precedence)
--   op     = '=' | '~'                 ('~' is a pattern match, QNAME only)
--   filter = <key> <op> <expr>         (key is QNAME, SRC or DST)
--   rule   = <filter> | <filter> <conj> <rule>
--   action = PASS | DENY | DROP | TC | TRUNCATE | FORWARD | MIRROR | REROUTE | REWRITE
--   query  = [<rule>] <action> [<action arguments>]
-- @param query rule text
-- @return action name, action value and filter (or nil) on success;
--         nil and an error message on failure
local function compile(query)
	if type(query) ~= 'string' then
		return nil, string.format('rule must be a string, got %s', type(query))
	end
	return parse_query(string.gmatch(query, '%S+'))
end

-- @function Describe given rule for presentation
-- Returns a plain table with the rule text, its policy rule id, whether it is
-- active (i.e. `suspended` is not exactly true) and its match counter.
local function rule_info(r)
	local rule = r.rule
	return {
		info = r.info,
		id = rule.id,
		active = rule.suspended ~= true,
		count = rule.count,
	}
end

-- Module declaration
-- `rules` holds the descriptors created by M.add, in the order they were added.
local M = { rules = {} }

-- @function Cleanup module
-- Unregisters the '/daf' and '/daf.js' web management endpoints and the
-- '/daf' snippet; does nothing when the http module is not loaded.
function M.deinit()
	if not http then
		return
	end
	local webmgmt_endpoints = http.configs._builtin.webmgmt.endpoints
	webmgmt_endpoints['/daf'] = nil
	webmgmt_endpoints['/daf.js'] = nil
	http.snippets['/daf'] = nil
end

-- @function Add rule
-- Compile `rule` (query-language text), register it in the policy module and
-- remember it in M.rules. Adding text identical to an existing rule returns
-- the existing descriptor instead of adding a second copy.
-- Raises an error carrying the compiler's message if the rule does not compile.
-- @return descriptor {info = rule, policy = <function>, rule = <result of policy.add>}
function M.add(rule)
	-- Ignore duplicates
	for _, known in ipairs(M.rules) do
		if known.info == rule then
			return known
		end
	end
	local id, action, filter = compile(rule)
	if not id then
		error(action)
	end
	-- Policy callback: yields the action when the filter matches (always, for
	-- filterless rules), otherwise the filter's falsy result
	local decide
	if not filter then
		decide = function ()
			return action
		end
	else
		decide = function (req, qry)
			return filter(req, qry) and action
		end
	end
	local desc = { info = rule, policy = decide }
	if id == 'reroute' or id == 'rewrite' then
		-- Special actions are enforced as postrules
		desc.rule = policy.add(decide, true)
	else
		desc.rule = policy.add(decide)
	end
	M.rules[#M.rules + 1] = desc
	return desc
end

-- @function Remove a rule
-- Delete the rule with policy id `id` from the policy module and from
-- M.rules. Returns true on success, nil when no such rule exists.
function M.del(id)
	local found
	for idx, desc in ipairs(M.rules) do
		if desc.rule.id == id then
			found = idx
			break
		end
	end
	if found == nil then
		return nil
	end
	policy.del(id)
	table.remove(M.rules, found)
	return true
end

-- @function Remove all rules
-- Delete every known rule from the policy module, then forget them locally.
-- Always returns true.
function M.clear()
	local known = M.rules
	for _, desc in ipairs(known) do
		local policy_id = desc.rule.id
		policy.del(policy_id)
	end
	M.rules = {}
	return true
end

-- @function Find a rule
-- Returns the descriptor whose policy rule id equals `id`, or nil.
function M.get(id)
	for _, desc in ipairs(M.rules) do
		if desc.rule.id == id then return desc end
	end
	return nil
end

-- @function Enable/disable a rule
-- Mark the rule with policy id `id` as active (`val` truthy) or suspended
-- (`val` falsy). Returns true if the rule exists, nil otherwise.
function M.toggle(id, val)
	local desc = M.get(id)
	if desc == nil then
		return nil
	end
	desc.rule.suspended = not val
	return true
end

-- @function Enable/disable a rule
-- Shorthands for M.toggle(id, false) and M.toggle(id, true); both return
-- whatever M.toggle returns.
M.disable = function (id)
	return M.toggle(id, false)
end
M.enable = function (id)
	return M.toggle(id, true)
end

-- Run the Lua snippet string.format(op, ...) through map() and combine the
-- per-worker results with `and`: returns false when there are no results,
-- otherwise the first falsy result or, if all are truthy, the last result.
local function consensus(op, ...)
	local results = map(string.format(op, ...))
	local verdict = results.n > 0  -- init to true for non-empty results
	local idx = 1
	while verdict and idx <= results.n do
		verdict = results[idx]
		idx = idx + 1
	end
	return verdict
end

-- @function Public-facing API
-- HTTP handler for the '/daf' endpoint:
--   GET    .../<id>                one rule (404 if unknown); without a numeric id, all rules
--   DELETE .../<id>                delete the rule on all workers (404 if unknown, 400 without id)
--   POST   (body = rule text)      add the rule here and on all workers (500 on compile error)
--   PATCH  .../<id>/active/<val>   enable (val 'true') or disable (anything else) on all workers
-- @param h request headers (':method', ':path')
-- @param stream HTTP stream, used to read the POST body
-- Returns either a value for the HTTP module to serialize, or a status code
-- optionally followed by a JSON-encoded body.
local function api(h, stream)
	local m = h:get(':method')
	-- GET method
	if m == 'GET' then
		local path = h:get(':path')
		local id = tonumber(path:match '/([^/]*)$')
		if id then
			local r = M.get(id)
			if r then
				return rule_info(r)
			end
			return 404, '"No such rule"' -- Not found
		else
			local ret = {}
			for _, r in ipairs(M.rules) do
				table.insert(ret, rule_info(r))
			end
			return ret
		end
	-- DELETE method
	elseif m == 'DELETE' then
		local path = h:get(':path')
		local id = tonumber(path:match '/([^/]*)$')
		if id then
			if consensus('daf.del(%s)', id) then
				return tojson(true)
			end
			return 404, '"No such rule"' -- Not found
		end
		return 400 -- Request doesn't have numeric id
	-- POST method
	elseif m == 'POST' then
		local query = stream:get_body_as_string()
		if query then
			local ok, r = pcall(M.add, query)
			if not ok then
				-- Keep only the part after the last '/' (drops the source path
				-- prefix), or the whole message when it has none. tojson() quotes
				-- and escapes it, so messages containing '"' remain valid JSON.
				local msg = tostring(r)
				return 500, tojson(msg:match('/([^/]+)$') or msg)
			end
			-- Dispatch to all other workers:
			-- we ignore return values except error() because they are not serializable.
			-- %q emits an escaped Lua string literal, so quotes, backslashes or
			-- newlines in the request body cannot break out of the generated code.
			consensus('daf.add %q and true', query)
			return rule_info(r)
		end
		return 400
	-- PATCH method
	elseif m == 'PATCH' then
		local path = h:get(':path')
		local id, action, val = path:match '(%d+)/([^/]*)/([^/]*)$'
		id = tonumber(id)
		if not id or not action or not val then
			return 400 -- Request not well formatted
		end
		-- We do not support more actions
		if action == 'active' then
			-- Send a Lua boolean literal ('true' / 'false') to daf.toggle
			if consensus('daf.toggle(%d, %s)', id, tostring(val == 'true')) then
				return tojson(true)
			else
				return 404, '"No such rule"'
			end
		else
			return 501, '"Action not implemented"'
		end
	end
end

-- Sum the per-rule match counters reported by every instance (via map()).
-- Returns a table mapping the rule id *as a string* to its total count;
-- string keys make it serialize as a JSON object rather than an array.
local function getmatches()
	local snippet = table.concat({
		'ret = {} ',
		'for _, rule in ipairs(daf.rules) do ',
		'ret[tostring(rule.rule.id)] = rule.rule.count ',
		'end ',
		'return ret',
	})
	local per_instance = map(snippet)
	local totals = {}
	for w = 1, per_instance.n do
		for rule_id, hits in pairs(per_instance[w]) do
			totals[rule_id] = (totals[rule_id] or 0) + hits
		end
	end
	return totals
end

-- Counters from `cur` that are new or larger than in `prev`; empty when
-- there is no previous snapshot yet.
local function counter_delta(prev, cur)
	local delta = {}
	if prev then
		for rule_id, hits in pairs(cur) do
			local before = prev[rule_id]
			if not before or before < hits then
				delta[rule_id] = hits
			end
		end
	end
	return delta
end

-- @function Publish DAF statistics
-- Once per second (worker.sleep(1)), poll the aggregated match counters and
-- push the changed ones over the websocket as JSON; send a ping when nothing
-- changed. Stops after a send/ping reports failure. A failing counter poll is
-- skipped and keeps the previous snapshot.
local function publish(_, ws)
	local alive, snapshot = true, nil
	repeat
		local delta = {}
		local fetched, current = pcall(getmatches)
		if fetched then
			delta = counter_delta(snapshot, current)
			snapshot = current
		end
		-- Update counters when there is a new data
		if next(delta) == nil then
			alive = ws:send_ping()
		else
			alive = ws:send(tojson(delta))
		end
		worker.sleep(1)
	until not alive
end

-- @function Initialize module
-- Schedules M.config through event.after with a zero delay instead of calling
-- it directly, to avoid an ordering problem between the HTTP and daf modules.
function M.init()
	local delay = 0
	event.after(delay, M.config)
end

-- @function Configure module
-- Register the DAF REST endpoint ('/daf', served by `api` with `publish` as
-- its data publisher), its JavaScript page ('/daf.js') and the web management
-- snippet. Without the http module only a warning is logged.
function M.config()
	if not http then
		log_warn(ffi.C.LOG_GRP_DAF,
			'HTTP API unavailable because HTTP module is not loaded, use modules.load("http")')
		return
	end
	local endpoints = http.configs._builtin.webmgmt.endpoints
	-- Export API and data publisher
	local script_page = http.page('daf.js', 'daf')
	endpoints['/daf.js'] = script_page
	endpoints['/daf'] = {'application/json', api, publish}
	-- Export snippet
	local snippet_html = [[
		<script type="text/javascript" src="daf.js"></script>
		<div class="row" style="margin-bottom: 5px">
			<form id="daf-builder-form">
				<div class="col-md-11">
					<input type="text" id="daf-builder" class="form-control" aria-label="..." />
				</div>
				<div class="col-md-1">
					<button type="button" id="daf-add" class="btn btn-default btn-sm">Add</button>
				</div>
			</form>
		</div>
		<div class="row">
			<div class="col-md-12">
				<table id="daf-rules" class="table table-striped table-responsive">
				<th><td>No rules here yet.</td></th>
				</table>
			</div>
		</div>
	]]
	http.snippets['/daf'] = {'Application Firewall', snippet_html}
end

-- Export the module table
return M