diff options
Diffstat (limited to '')
-rw-r--r-- | modules/policy/policy.lua | 1109 |
1 files changed, 1109 insertions, 0 deletions
diff --git a/modules/policy/policy.lua b/modules/policy/policy.lua new file mode 100644 index 0000000..47e436f --- /dev/null +++ b/modules/policy/policy.lua @@ -0,0 +1,1109 @@ +-- SPDX-License-Identifier: GPL-3.0-or-later +local kres = require('kres') +local ffi = require('ffi') + +local LOG_GRP_POLICY_TAG = ffi.string(ffi.C.kr_log_grp2name(ffi.C.LOG_GRP_POLICY)) +local LOG_GRP_REQDBG_TAG = ffi.string(ffi.C.kr_log_grp2name(ffi.C.LOG_GRP_REQDBG)) + +local todname = kres.str2dname -- not available during module load otherwise + +-- Counter of unique rules +local nextid = 0 +local function getruleid() + local newid = nextid + nextid = nextid + 1 + return newid +end + +-- Support for client sockets from inside policy actions +local socket_client = function () + return error("missing lua-cqueues library, can't create socket client") +end +local has_socket, socket = pcall(require, 'cqueues.socket') +if has_socket then + socket_client = function (host, port) + local s, err, status + + s = socket.connect({ host = host, port = port, type = socket.SOCK_DGRAM }) + s:setmode('bn', 'bn') + status, err = pcall(s.connect, s) + + if not status then + return status, err + end + return s + end +end + +-- Split address and port from a combined string. +local function addr_split_port(target, default_port) + assert(default_port and type(default_port) == 'number') + local port = ffi.new('uint16_t[1]', default_port) + local addr = ffi.new('char[47]') -- INET6_ADDRSTRLEN + 1 + local ret = ffi.C.kr_straddr_split(target, addr, port) + if ret ~= 0 then + error('failed to parse address ' .. target) + end + return addr, tonumber(port[0]) +end + +-- String address@port -> sockaddr. +local function addr2sock(target, default_port) + local addr, port = addr_split_port(target, default_port) + local sock = ffi.gc(ffi.C.kr_straddr_socket(addr, port, nil), ffi.C.free); + if sock == nil then + error("target '"..target..'" is not a valid IP address') + end + return sock +end + +-- Debug logging for taken policy actions +local function log_policy_action(req, name) + if ffi.C.kr_log_is_debug_fun(ffi.C.LOG_GRP_POLICY, req) then + local qry = req:current() + ffi.C.kr_log_req1( + req, qry.uid, 2, ffi.C.LOG_GRP_POLICY, LOG_GRP_POLICY_TAG, + "%s applied for %s %s\n", + name, kres.dname2str(qry.sname), kres.tostring.type[qry.stype]) + end +end + +-- policy functions are defined below +local policy = {} + +function policy.PASS(state, _) + return state +end + +-- Mirror request elsewhere, and continue solving +function policy.MIRROR(target) + local addr, port = addr_split_port(target, 53) + local sink, err = socket_client(ffi.string(addr), port) + if not sink then panic('MIRROR target %s is not a valid: %s', target, err) end + return function(state, req) + if state == kres.FAIL then return state end + local query = req.qsource.packet + if query ~= nil then + sink:send(ffi.string(query.wire, query.size), 1, tonumber(query.size)) + end + return -- Chain action to next + end +end + +-- Override the list of nameservers (forwarders) +local function set_nslist(req, list) + local ns_i = 0 + for _, ns in ipairs(list) do + if ffi.C.kr_forward_add_target(req, ns) == 0 then + ns_i = ns_i + 1 + end + end + if ns_i == 0 then + -- would use assert() but don't want to compose the message if not triggered + error('no usable address in NS set (check net.ipv4 and ' + .. 'net.ipv6 config):\n' .. table_print(list, 2)) + end +end + +-- Forward request, and solve as stub query +function policy.STUB(target) + local list = {} + if type(target) == 'table' then + for _, v in pairs(target) do + table.insert(list, addr2sock(v, 53)) + end + else + table.insert(list, addr2sock(target, 53)) + end + return function(state, req) + local qry = req:current() + -- Switch mode to stub resolver, do not track origin zone cut since it's not real authority NS + qry.flags.STUB = true + qry.flags.ALWAYS_CUT = false + set_nslist(req, list) + return state + end +end + +-- Forward request and all subrequests to upstream; validate answers +function policy.FORWARD(target) + local list = {} + if type(target) == 'table' then + for _, v in pairs(target) do + table.insert(list, addr2sock(v, 53)) + end + else + table.insert(list, addr2sock(target, 53)) + end + return function(state, req) + local qry = req:current() + req.options.FORWARD = true + req.options.NO_MINIMIZE = true + qry.flags.FORWARD = true + qry.flags.ALWAYS_CUT = false + qry.flags.NO_MINIMIZE = true + qry.flags.AWAIT_CUT = true + set_nslist(req, list) + return state + end +end + +-- Forward request and all subrequests to upstream over TLS; validate answers +function policy.TLS_FORWARD(targets) + if type(targets) ~= 'table' or #targets < 1 then + error('TLS_FORWARD argument must be a non-empty table') + end + + local sockaddr_c_set = {} + local nslist = {} -- to persist in closure of the returned function + for idx, target in pairs(targets) do + if type(target) ~= 'table' or type(target[1]) ~= 'string' then + error(string.format('TLS_FORWARD configuration at position ' .. + '%d must be a table starting with an IP address', idx)) + end + -- Note: some functions have checks with error() calls inside. + local sockaddr_c = addr2sock(target[1], 853) + + -- Refuse repeated addresses in the same set. + local sockaddr_lua = ffi.string(sockaddr_c, ffi.C.kr_sockaddr_len(sockaddr_c)) + if sockaddr_c_set[sockaddr_lua] then + error('TLS_FORWARD configuration cannot declare two configs for IP address ' + .. target[1]) + else + sockaddr_c_set[sockaddr_lua] = true; + end + + table.insert(nslist, sockaddr_c) + net.tls_client(target) + end + + return function(state, req) + local qry = req:current() + req.options.FORWARD = true + req.options.NO_MINIMIZE = true + qry.flags.FORWARD = true + qry.flags.ALWAYS_CUT = false + qry.flags.NO_MINIMIZE = true + qry.flags.AWAIT_CUT = true + req.options.TCP = true + qry.flags.TCP = true + set_nslist(req, nslist) + return state + end +end + +-- Rewrite records in packet +function policy.REROUTE(tbl, names) + -- Import renumbering rules + local ren = require('kres_modules.renumber') + local prefixes = {} + for from, to in pairs(tbl) do + local prefix = names and ren.name(from, to) or ren.prefix(from, to) + table.insert(prefixes, prefix) + end + -- Return rule closure + return ren.rule(prefixes) +end + +-- Set and clear some query flags +function policy.FLAGS(opts_set, opts_clear) + return function(_, req) + -- We assume to be running in the begin phase, so to truly apply + -- to the whole request we need to change both kr_request and kr_query. + local qry = req:current() + for _, flags in pairs({qry.flags, req.options}) do + ffi.C.kr_qflags_set (flags, kres.mk_qflags(opts_set or {})) + ffi.C.kr_qflags_clear(flags, kres.mk_qflags(opts_clear or {})) + end + return nil -- chain rule + end +end + +local function mkauth_soa(answer, dname, mname, ttl) + if mname == nil then + mname = dname + end + return answer:put(dname, ttl or 10800, answer:qclass(), kres.type.SOA, + mname .. '\6nobody\7invalid\0\0\0\0\1\0\0\14\16\0\0\4\176\0\9\58\128\0\0\42\48') +end + +-- Create answer with passed arguments +function policy.ANSWER(rtable, nodata) + return function(_, req) + local qry = req:current() + local data = rtable[qry.stype] + if data == nil and nodata ~= true then + return nil + end + -- now we're certain we want to generate an answer + + local answer = req:ensure_answer() + if answer == nil then return nil end + ffi.C.kr_pkt_make_auth_header(answer) + local ttl = (data or {}).ttl or 1 + answer:rcode(kres.rcode.NOERROR) + req:set_extended_error(kres.extended_error.FORGED, "5DO5") + + if data == nil then -- want NODATA, i.e. just a SOA + answer:begin(kres.section.AUTHORITY) + local soa = rtable[kres.type.SOA] + if soa ~= nil then + answer:put(qry.sname, soa.ttl or ttl, qry.sclass, kres.type.SOA, + soa.rdata[1] or soa.rdata) + else + mkauth_soa(answer, kres.dname2wire(qry.sname), nil, ttl) + end + log_policy_action(req, 'ANSWER (nodata)') + else + answer:begin(kres.section.ANSWER) + if type(data.rdata) == 'table' then + for _, entry in ipairs(data.rdata) do + answer:put(qry.sname, ttl, qry.sclass, qry.stype, entry) + end + else + answer:put(qry.sname, ttl, qry.sclass, qry.stype, data.rdata) + end + log_policy_action(req, 'ANSWER (forged)') + end + return kres.DONE + end +end + +local dname_localhost = todname('localhost.') + +-- Rule for localhost. zone; see RFC6303, sec. 3 +local function localhost(_, req) + local qry = req:current() + local answer = req:ensure_answer() + if answer == nil then return nil end + ffi.C.kr_pkt_make_auth_header(answer) + + local is_exact = ffi.C.knot_dname_is_equal(qry.sname, dname_localhost) + + answer:rcode(kres.rcode.NOERROR) + answer:begin(kres.section.ANSWER) + if qry.stype == kres.type.AAAA then + answer:put(qry.sname, 900, answer:qclass(), kres.type.AAAA, + '\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\1') + elseif qry.stype == kres.type.A then + answer:put(qry.sname, 900, answer:qclass(), kres.type.A, '\127\0\0\1') + elseif is_exact and qry.stype == kres.type.SOA then + mkauth_soa(answer, dname_localhost) + elseif is_exact and qry.stype == kres.type.NS then + answer:put(dname_localhost, 900, answer:qclass(), kres.type.NS, dname_localhost) + else + answer:begin(kres.section.AUTHORITY) + mkauth_soa(answer, dname_localhost) + end + return kres.DONE +end + +local dname_rev4_localhost = todname('1.0.0.127.in-addr.arpa'); +local dname_rev4_localhost_apex = todname('127.in-addr.arpa'); + +-- Rule for reverse localhost. +-- Answer with locally served minimal 127.in-addr.arpa domain, only having +-- a PTR record in 1.0.0.127.in-addr.arpa, and with 1.0...0.ip6.arpa. zone. +-- TODO: much of this would better be left to the hints module (or coordinated). +local function localhost_reversed(_, req) + local qry = req:current() + local answer = req:ensure_answer() + if answer == nil then return nil end + + -- classify qry.sname: + local is_exact -- exact dname for localhost + local is_apex -- apex of a locally-served localhost zone + local is_nonterm -- empty non-terminal name + if ffi.C.knot_dname_in_bailiwick(qry.sname, todname('ip6.arpa.')) > 0 then + -- exact ::1 query (relying on the calling rule) + is_exact = true + is_apex = true + else + -- within 127.in-addr.arpa. + local labels = ffi.C.knot_dname_labels(qry.sname, nil) + if labels == 3 then + is_exact = false + is_apex = true + elseif labels == 4+2 and ffi.C.knot_dname_is_equal( + qry.sname, dname_rev4_localhost) then + is_exact = true + else + is_exact = false + is_apex = false + is_nonterm = ffi.C.knot_dname_in_bailiwick(dname_rev4_localhost, qry.sname) > 0 + end + end + + ffi.C.kr_pkt_make_auth_header(answer) + answer:rcode(kres.rcode.NOERROR) + answer:begin(kres.section.ANSWER) + if is_exact and qry.stype == kres.type.PTR then + answer:put(qry.sname, 900, answer:qclass(), kres.type.PTR, dname_localhost) + elseif is_apex and qry.stype == kres.type.SOA then + mkauth_soa(answer, dname_rev4_localhost_apex, dname_localhost) + elseif is_apex and qry.stype == kres.type.NS then + answer:put(dname_rev4_localhost_apex, 900, answer:qclass(), kres.type.NS, + dname_localhost) + else + if not is_nonterm then + answer:rcode(kres.rcode.NXDOMAIN) + end + answer:begin(kres.section.AUTHORITY) + mkauth_soa(answer, dname_rev4_localhost_apex, dname_localhost) + end + return kres.DONE +end + +-- All requests +function policy.all(action) + return function(_, _) return action end +end + +-- Requests whose QNAME is exactly the provided domain +function policy.domains(action, dname_list) + return function(_, query) + local qname = query:name() + for _, dname in ipairs(dname_list) do + if ffi.C.knot_dname_is_equal(qname, dname) then + return action + end + end + return nil + end +end + +-- Requests whose QNAME matches given zone list (i.e. suffix match) +function policy.suffix(action, zone_list) + local AC = require('ahocorasick') + local tree = AC.create(zone_list) + return function(_, query) + local match = AC.match(tree, query:name(), false) + if match ~= nil then + return action + end + return nil + end +end + +-- Check for common suffix first, then suffix match (specialized version of suffix match) +function policy.suffix_common(action, suffix_list, common_suffix) + local common_len = string.len(common_suffix) + local suffix_count = #suffix_list + return function(_, query) + -- Preliminary check + local qname = query:name() + if not string.find(qname, common_suffix, -common_len, true) then + return nil + end + -- String match + for i = 1, suffix_count do + local zone = suffix_list[i] + if string.find(qname, zone, -string.len(zone), true) then + return action + end + end + return nil + end +end + +-- Filter QNAME pattern +function policy.pattern(action, pattern) + return function(_, query) + if string.find(query:name(), pattern) then + return action + end + return nil + end +end + +local function rpz_parse(action, path) + local rules = {} + local new_actions = {} + local action_map = { + -- RPZ Policy Actions + ['\0'] = action, + ['\1*\0'] = policy.ANSWER({}, true), + ['\012rpz-passthru\0'] = policy.PASS, -- the grammar... + ['\008rpz-drop\0'] = policy.DROP, + ['\012rpz-tcp-only\0'] = policy.TC, + -- Policy triggers @NYI@ + } + -- RR types to be skipped; boolean denoting whether to throw a warning even for RPZ apex. + local rrtype_bad = { + [kres.type.DNAME] = true, + [kres.type.NS] = false, + [kres.type.DNSKEY] = true, + [kres.type.DS] = true, + [kres.type.RRSIG] = true, + [kres.type.NSEC] = true, + [kres.type.NSEC3] = true, + } + + -- We generally don't know what zone should be in the file; we try to detect it. + -- Fortunately, it's typical that SOA is the first record, even required for AXFR. + local origin_soa = nil + local warned_soa, warned_bailiwick + + local parser = require('zonefile').new() + local ok, errstr = parser:open(path) + if not ok then + error(string.format('failed to parse "%s": %s', path, errstr or "unknown error")) + end + while true do + ok, errstr = parser:parse() + if errstr then + log_warn(ffi.C.LOG_GRP_POLICY, 'RPZ %s:%d: %s', + path, tonumber(parser.line_counter), errstr) + end + if not ok then break end + + local full_name = ffi.gc(ffi.C.knot_dname_copy(parser.r_owner, nil), ffi.C.free) + local rdata = ffi.string(parser.r_data, parser.r_data_length) + ffi.C.knot_dname_to_lower(full_name) + + local origin = origin_soa or parser.zone_origin + local prefix_labels = ffi.C.knot_dname_in_bailiwick(full_name, origin) + if prefix_labels < 0 then + if not warned_bailiwick then + warned_bailiwick = true + log_warn(ffi.C.LOG_GRP_POLICY, + 'RPZ %s:%d: RR owner "%s" outside the zone (ignored; reported once per file)', + path, tonumber(parser.line_counter), kres.dname2str(full_name)) + end + goto continue + end + + local bytes = ffi.C.knot_dname_size(full_name) - ffi.C.knot_dname_size(origin) + local name = ffi.string(full_name, bytes) .. '\0' + + if parser.r_type == kres.type.CNAME then + if action_map[rdata] then + rules[name] = action_map[rdata] + else + log_warn(ffi.C.LOG_GRP_POLICY, + 'RPZ %s:%d: CNAME with custom target in RPZ is not supported yet (ignored)', + path, tonumber(parser.line_counter)) + end + else + if #name then + local is_bad = rrtype_bad[parser.r_type] + + if parser.r_type == kres.type.SOA then + if origin_soa == nil then + origin_soa = ffi.gc(ffi.C.knot_dname_copy(parser.r_owner, nil), ffi.C.free) + goto continue -- we don't want to modify `new_actions` + else + is_bad = true -- maybe provide more info, but it seems rare + end + elseif origin_soa == nil and not warned_soa then + warned_soa = true + log_warn(ffi.C.LOG_GRP_POLICY, + 'RPZ %s:%d warning: SOA missing as the first record', + path, tonumber(parser.line_counter)) + end + + if is_bad == true or (is_bad == false and prefix_labels ~= 0) then + log_warn(ffi.C.LOG_GRP_POLICY, 'RPZ %s:%d warning: RR type %s is not allowed in RPZ (ignored)', + path, tonumber(parser.line_counter), kres.tostring.type[parser.r_type]) + elseif is_bad == nil then + if new_actions[name] == nil then new_actions[name] = {} end + local act = new_actions[name][parser.r_type] + if act == nil then + new_actions[name][parser.r_type] = { ttl=parser.r_ttl, rdata=rdata } + else -- multiple RRs: no reordering or deduplication + if type(act.rdata) ~= 'table' then + act.rdata = { act.rdata } + end + table.insert(act.rdata, rdata) + if parser.r_ttl ~= act.ttl then -- be conservative + log_warn(ffi.C.LOG_GRP_POLICY, 'RPZ %s:%d warning: different TTLs in a set (minimum taken)', + path, tonumber(parser.line_counter)) + act.ttl = math.min(act.ttl, parser.r_ttl) + end + end + else + assert(is_bad == false and prefix_labels == 0) + end + end + end + + ::continue:: + end + collectgarbage() + for qname, rrsets in pairs(new_actions) do + rules[qname] = policy.ANSWER(rrsets, true) + end + return rules +end + +-- Split path into dirname and basename (like the shell utilities) +local function get_dir_and_file(path) + local dir, file = string.match(path, "(.*)/([^/]+)") + + -- If regex doesn't match then path must be the file directly (i.e. doesn't contain '/') + -- This assumes that the file exists (rpz_parse() would fail if it doesn't) + if not dir and not file then + dir = '.' + file = path + end + + return dir, file +end + +-- RPZ policy set +-- Create RPZ from zone file and optionally watch the file for changes +function policy.rpz(action, path, watch) + local rules = rpz_parse(action, path) + + if watch ~= false then + local has_notify, notify = pcall(require, 'cqueues.notify') + if has_notify then + local bit = require('bit') + + local dir, file = get_dir_and_file(path) + local watcher = notify.opendir(dir) + watcher:add(file, bit.bxor(notify.CREATE, notify.MODIFY)) + + worker.coroutine(function () + for _, name in watcher:changes() do + -- Limit to changes on file we're interested in + -- Watcher will also fire for changes to the directory itself + if name == file then + -- If the file changes then reparse and replace the existing ruleset + log_info(ffi.C.LOG_GRP_POLICY, 'RPZ reloading: ' .. name) + rules = rpz_parse(action, path) + end + end + end) + elseif watch then -- explicitly requested and failed + error('[poli] lua-cqueues required to watch and reload RPZ file') + else + log_info(ffi.C.LOG_GRP_POLICY, 'lua-cqueues required to watch and reload RPZ file, continuing without watching') + end + end + + return function(_, query) + local label = query:name() + local rule = rules[label] + while rule == nil and string.len(label) > 0 do + label = string.sub(label, string.byte(label) + 2) + rule = rules['\1*'..label] + end + return rule + end +end + +-- Apply an action when query belongs to a slice (determined by slice_func()) +function policy.slice(slice_func, ...) + local actions = {...} + if #actions <= 0 then + error('[poli] at least one action must be provided to policy.slice()') + end + + return function(_, query) + local index = slice_func(query, #actions) + return actions[index] + end +end + +-- Initializes slicing function that randomly assigns queries to a slice based on their registrable domain +function policy.slice_randomize_psl(seed) + local has_psl, psl_lib = pcall(require, 'psl') + if not has_psl then + error('[poli] lua-psl is required for policy.slice_randomize_psl()') + end + -- load psl + local has_latest, psl = pcall(psl_lib.latest) + if not has_latest then -- compatibility with lua-psl < 0.15 + psl = psl_lib.builtin() + end + + if seed == nil then + seed = os.time() / (3600 * 24 * 7) + end + seed = math.floor(seed) -- convert to int + + return function(query, length) + assert(length > 0) + + local domain = kres.dname2str(query:name()) + if domain == nil then -- invalid data: auto-select first action + return 1 + end + if domain:len() > 1 then --remove trailing dot + domain = domain:sub(0, -2) + end + + -- do psl lookup for registrable domain + local reg_domain = psl:registrable_domain(domain) + if reg_domain == nil then -- fallback to unreg. domain + reg_domain = psl:unregistrable_domain(domain) + if reg_domain == nil then -- shouldn't happen: safe fallback + return 1 + end + end + + local rand_seed = seed + -- create deterministic seed for pseudo-random slice assignment + for i = 1, #reg_domain do + rand_seed = rand_seed + reg_domain:byte(i) + end + + -- use linear congruential generator with values from ANSI C + rand_seed = rand_seed % 0x80000000 -- ensure seed is positive 32b int + local rand = (1103515245 * rand_seed + 12345) % 0x10000 + return 1 + rand % length + end +end + +-- Prepare for making an answer from scratch. (Return the packet for convenience.) +local function answer_clear(req) + -- If we're in postrules, previous resolving might have chosen some RRs + -- for inclusion in the answer, so we need to avoid those. + -- *_selected arrays are in mempool, so explicit deallocation is not necessary. + req.answ_selected.len = 0 + req.auth_selected.len = 0 + req.add_selected.len = 0 + + -- Let's be defensive and clear the answer, too. + local pkt = req:ensure_answer() + if pkt == nil then return nil end + pkt:clear_payload() + req:ensure_edns() + return pkt +end + +function policy.DENY_MSG(msg, extended_error) + if msg and (type(msg) ~= 'string' or #msg >= 255) then + error('DENY_MSG: optional msg must be string shorter than 256 characters') + end + if extended_error == nil then + extended_error = kres.extended_error.BLOCKED + end + local action_name = msg and 'DENY_MSG' or 'DENY' + + return function (_, req) + -- Write authority information + local answer = answer_clear(req) + if answer == nil then return nil end + ffi.C.kr_pkt_make_auth_header(answer) + answer:rcode(kres.rcode.NXDOMAIN) + answer:begin(kres.section.AUTHORITY) + mkauth_soa(answer, answer:qname()) + if msg then + answer:begin(kres.section.ADDITIONAL) + answer:put('\11explanation\7invalid', 10800, answer:qclass(), kres.type.TXT, + string.char(#msg) .. msg) + + end + req:set_extended_error(extended_error, "CR36") + log_policy_action(req, action_name) + return kres.DONE + end +end + +local function free_cb(func) + func:free() +end + +local debug_logline_cb = ffi.cast('trace_log_f', function (_, msg) + jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed + ffi.C.kr_log_fmt( + ffi.C.LOG_GRP_REQDBG, -- but the original [group] tag also remains in the string + LOG_DEBUG, + 'CODE_FILE=policy.lua', 'CODE_LINE=', 'CODE_FUNC=policy.DEBUG_ALWAYS', -- no meaningful locations + '[%-6s]%s', LOG_GRP_REQDBG_TAG, msg) -- msg should end with newline already +end) +ffi.gc(debug_logline_cb, free_cb) + +-- LOG_DEBUG without log_trace and without code locations +local function log_notrace(req, fmt, ...) + ffi.C.kr_log_fmt(ffi.C.LOG_GRP_REQDBG, LOG_DEBUG, + 'CODE_FILE=policy.lua', 'CODE_LINE=', 'CODE_FUNC=', -- no meaningful locations + '%s', string.format( -- convert in lua, as integers in C varargs would pass as double + '[%-6s][%-6s][%05u.00] ' .. fmt, + LOG_GRP_REQDBG_TAG, LOG_GRP_POLICY_TAG, req.uid, ...) + ) +end + +local debug_logfinish_cb = ffi.cast('trace_callback_f', function (req) + jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed + log_notrace(req, 'following rrsets were marked as interesting:\n%s\n', + req:selected_tostring()) + if req.answer ~= nil then + log_notrace(req, 'answer packet:\n%s\n', req.answer) + else + log_notrace(req, 'answer packet DROPPED\n') + end +end) +ffi.gc(debug_logfinish_cb, free_cb) + +-- log request packet +function policy.REQTRACE(_, req) + log_notrace(req, 'request packet:\n%s', req.qsource.packet) +end + +-- log how the request arrived, notably the client's IP +function policy.IPTRACE(_, req) + if req.qsource.addr == nil then + log_notrace(req, 'request packet arrived internally\n') + else + -- stringify transport flags: struct kr_request_qsource_flags + local qf = req.qsource.flags + local qf_str = qf.tcp and 'TCP' or 'UDP' + if qf.tls then qf_str = qf_str .. ' + TLS' end + if qf.http then qf_str = qf_str .. ' + HTTP' end + if qf.xdp then qf_str = qf_str .. ' + XDP' end + + log_notrace(req, 'request packet arrived from %s to %s (%s)\n', + req.qsource.addr, req.qsource.dst_addr, qf_str) + end + return nil -- chain rule +end + +function policy.DEBUG_ALWAYS(state, req) + policy.QTRACE(state, req) + req:trace_chain_callbacks(debug_logline_cb, debug_logfinish_cb) + policy.REQTRACE(state, req) +end + +local debug_stashlog_cb = ffi.cast('trace_log_f', function (req, msg) + jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed + + -- stash messages for conditional logging in trace_finish + local stash = req:vars()['policy_debug_stash'] + table.insert(stash, ffi.string(msg)) +end) +ffi.gc(debug_stashlog_cb, free_cb) + +-- buffer debug logs and print then only if test() returns a truthy value +function policy.DEBUG_IF(test) + local debug_finish_cb = ffi.cast('trace_callback_f', function (cbreq) + jit.off(true, true) -- JIT for (C -> lua)^2 nesting isn't allowed + if test(cbreq) then + policy.REQTRACE(nil, cbreq) + debug_logfinish_cb(cbreq) -- unconditional version + + local stash = cbreq:vars()['policy_debug_stash'] + for _, line in ipairs(stash) do -- don't want one huge entry + ffi.C.kr_log_fmt(ffi.C.LOG_GRP_REQDBG, LOG_DEBUG, + 'CODE_FILE=policy.lua', 'CODE_LINE=', 'CODE_FUNC=', -- no meaningful locations + '[%-6s]%s', LOG_GRP_REQDBG_TAG, line) + end + end + end) + ffi.gc(debug_finish_cb, function (func) func:free() end) + + return function (state, req) + req:vars()['policy_debug_stash'] = {} + policy.QTRACE(state, req) + req:trace_chain_callbacks(debug_stashlog_cb, debug_finish_cb) + return + end +end + +policy.DEBUG_CACHE_MISS = policy.DEBUG_IF( + function(req) + return not req:all_from_cache() + end +) + +policy.DENY = policy.DENY_MSG() -- compatibility with < 2.0 + +function policy.DROP(_, req) + local answer = answer_clear(req) + if answer == nil then return nil end + req:set_extended_error(kres.extended_error.PROHIBITED, "U5KL") + log_policy_action(req, 'DROP') + return kres.FAIL +end + +function policy.NO_ANSWER(_, req) + req.options.NO_ANSWER = true + log_policy_action(req, 'NO_ANSWER') + return kres.FAIL +end + +function policy.REFUSE(_, req) + local answer = answer_clear(req) + if answer == nil then return nil end + answer:rcode(kres.rcode.REFUSED) + answer:ad(false) + req:set_extended_error(kres.extended_error.PROHIBITED, "EIM4") + log_policy_action(req, 'REFUSE') + return kres.DONE +end + +function policy.TC(state, req) + -- Avoid non-UDP queries + if req.qsource.addr == nil or req.qsource.flags.tcp then + return state + end + + local answer = answer_clear(req) + if answer == nil then return nil end + answer:tc(1) + answer:ad(false) + log_policy_action(req, 'TC') + return kres.DONE +end + +function policy.QTRACE(_, req) + local qry = req:current() + req.options.TRACE = true + qry.flags.TRACE = true + return -- this allows to continue iterating over policy list +end + +-- Evaluate packet in given rules to determine policy action +function policy.evaluate(rules, req, query, state) + for i = 1, #rules do + local rule = rules[i] + if not rule.suspended then + local action = rule.cb(req, query) + if action ~= nil then + rule.count = rule.count + 1 + local next_state = action(state, req) + if next_state then -- Not a chain rule, + return next_state -- stop on first match + end + end + end + end + return +end + +-- Add rule to policy list +function policy.add(rule, postrule) + -- Compatibility with 1.0.0 API + -- it will be dropped in 1.2.0 + if rule == policy then + rule = postrule + postrule = nil + end + -- End of compatibility shim + local desc = {id=getruleid(), cb=rule, count=0} + table.insert(postrule and policy.postrules or policy.rules, desc) + return desc +end + +-- Remove rule from a list +local function delrule(rules, id) + for i, r in ipairs(rules) do + if r.id == id then + table.remove(rules, i) + return true + end + end + return false +end + +-- Delete rule from policy list +function policy.del(id) + if not delrule(policy.rules, id) then + if not delrule(policy.postrules, id) then + return false + end + end + return true +end + +-- Convert list of string names to domain names +function policy.todnames(names) + for i, v in ipairs(names) do + names[i] = kres.str2dname(v) + end + return names +end + +-- RFC1918 Private, local, broadcast, test and special zones +-- Considerations: RFC6761, sec 6.1. +-- https://www.iana.org/assignments/locally-served-dns-zones +local private_zones = { + -- RFC6303 + '10.in-addr.arpa.', + '16.172.in-addr.arpa.', + '17.172.in-addr.arpa.', + '18.172.in-addr.arpa.', + '19.172.in-addr.arpa.', + '20.172.in-addr.arpa.', + '21.172.in-addr.arpa.', + '22.172.in-addr.arpa.', + '23.172.in-addr.arpa.', + '24.172.in-addr.arpa.', + '25.172.in-addr.arpa.', + '26.172.in-addr.arpa.', + '27.172.in-addr.arpa.', + '28.172.in-addr.arpa.', + '29.172.in-addr.arpa.', + '30.172.in-addr.arpa.', + '31.172.in-addr.arpa.', + '168.192.in-addr.arpa.', + '0.in-addr.arpa.', + '254.169.in-addr.arpa.', + '2.0.192.in-addr.arpa.', + '100.51.198.in-addr.arpa.', + '113.0.203.in-addr.arpa.', + '255.255.255.255.in-addr.arpa.', + -- RFC7793 + '64.100.in-addr.arpa.', + '65.100.in-addr.arpa.', + '66.100.in-addr.arpa.', + '67.100.in-addr.arpa.', + '68.100.in-addr.arpa.', + '69.100.in-addr.arpa.', + '70.100.in-addr.arpa.', + '71.100.in-addr.arpa.', + '72.100.in-addr.arpa.', + '73.100.in-addr.arpa.', + '74.100.in-addr.arpa.', + '75.100.in-addr.arpa.', + '76.100.in-addr.arpa.', + '77.100.in-addr.arpa.', + '78.100.in-addr.arpa.', + '79.100.in-addr.arpa.', + '80.100.in-addr.arpa.', + '81.100.in-addr.arpa.', + '82.100.in-addr.arpa.', + '83.100.in-addr.arpa.', + '84.100.in-addr.arpa.', + '85.100.in-addr.arpa.', + '86.100.in-addr.arpa.', + '87.100.in-addr.arpa.', + '88.100.in-addr.arpa.', + '89.100.in-addr.arpa.', + '90.100.in-addr.arpa.', + '91.100.in-addr.arpa.', + '92.100.in-addr.arpa.', + '93.100.in-addr.arpa.', + '94.100.in-addr.arpa.', + '95.100.in-addr.arpa.', + '96.100.in-addr.arpa.', + '97.100.in-addr.arpa.', + '98.100.in-addr.arpa.', + '99.100.in-addr.arpa.', + '100.100.in-addr.arpa.', + '101.100.in-addr.arpa.', + '102.100.in-addr.arpa.', + '103.100.in-addr.arpa.', + '104.100.in-addr.arpa.', + '105.100.in-addr.arpa.', + '106.100.in-addr.arpa.', + '107.100.in-addr.arpa.', + '108.100.in-addr.arpa.', + '109.100.in-addr.arpa.', + '110.100.in-addr.arpa.', + '111.100.in-addr.arpa.', + '112.100.in-addr.arpa.', + '113.100.in-addr.arpa.', + '114.100.in-addr.arpa.', + '115.100.in-addr.arpa.', + '116.100.in-addr.arpa.', + '117.100.in-addr.arpa.', + '118.100.in-addr.arpa.', + '119.100.in-addr.arpa.', + '120.100.in-addr.arpa.', + '121.100.in-addr.arpa.', + '122.100.in-addr.arpa.', + '123.100.in-addr.arpa.', + '124.100.in-addr.arpa.', + '125.100.in-addr.arpa.', + '126.100.in-addr.arpa.', + '127.100.in-addr.arpa.', + + -- RFC6303 + -- localhost_reversed handles ::1 + '0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.', + 'd.f.ip6.arpa.', + '8.e.f.ip6.arpa.', + '9.e.f.ip6.arpa.', + 'a.e.f.ip6.arpa.', + 'b.e.f.ip6.arpa.', + '8.b.d.0.1.0.0.2.ip6.arpa.', + -- RFC8375 + 'home.arpa.', +} +policy.todnames(private_zones) + +-- @var Default rules +policy.rules = {} +policy.postrules = {} +policy.special_names = { + -- XXX: beware of special_names_optim() when modifying these filters + { + cb=policy.suffix_common(policy.DENY_MSG( + 'Blocking is mandated by standards, see references on ' + .. 'https://www.iana.org/assignments/' + .. 'locally-served-dns-zones/locally-served-dns-zones.xhtml', + kres.extended_error.NOTSUP), + private_zones, todname('arpa.')), + count=0 + }, + { + cb=policy.suffix(policy.DENY_MSG( + 'Blocking is mandated by standards, see references on ' + .. 'https://www.iana.org/assignments/' + .. 'special-use-domain-names/special-use-domain-names.xhtml', + kres.extended_error.NOTSUP), + { + todname('test.'), + todname('onion.'), + todname('invalid.'), + todname('local.'), -- RFC 8375.4 + }), + count=0 + }, + { + cb=policy.suffix(localhost, {dname_localhost}), + count=0 + }, + { + cb=policy.suffix_common(localhost_reversed, { + todname('127.in-addr.arpa.'), + todname('1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.')}, + todname('arpa.')), + count=0 + }, +} + +-- Return boolean; false = no special name may apply, true = some might apply. +-- The point is to *efficiently* filter almost all QNAMEs that do not apply. +local function special_names_optim(req, sname) + local qname_size = req.qsource.packet.qname_size + if qname_size < 9 then return true end -- don't want to special-case bad array access + local root = sname + qname_size - 1 + return + -- .a???. or .t???. + (root[-5] == 4 and (root[-4] == 97 or root[-4] == 116)) + -- .on???. or .in?????. or lo???. or *ost. + or (root[-6] == 5 and root[-5] == 111 and root[-4] == 110) + or (root[-8] == 7 and root[-7] == 105 and root[-6] == 110) + or (root[-6] == 5 and root[-5] == 108 and root[-4] == 111) + or (root[-3] == 111 and root[-2] == 115 and root[-1] == 116) +end + +-- Top-down policy list walk until we hit a match +-- the caller is responsible for reordering policy list +-- from most specific to least specific. +-- Some rules may be chained, in this case they are evaluated +-- as a dependency chain, e.g. r1,r2,r3 -> r3(r2(r1(state))) +policy.layer = { + begin = function(state, req) + -- Don't act on "finished" cases. + if bit.band(state, bit.bor(kres.FAIL, kres.DONE)) ~= 0 then return state end + local qry = req:initial() -- same as :current() but more descriptive + return policy.evaluate(policy.rules, req, qry, state) + or (special_names_optim(req, qry.sname) + and policy.evaluate(policy.special_names, req, qry, state)) + or state + end, + finish = function(state, req) + -- Optimization for the typical case + if #policy.postrules == 0 then return state end + -- Don't act on failed cases. + if bit.band(state, kres.FAIL) ~= 0 then return state end + return policy.evaluate(policy.postrules, req, req:initial(), state) or state + end +} + +return policy |