summaryrefslogtreecommitdiffstats
path: root/tests/config/test_utils.lua
diff options
context:
space:
mode:
Diffstat (limited to 'tests/config/test_utils.lua')
-rw-r--r--tests/config/test_utils.lua121
1 files changed, 121 insertions, 0 deletions
diff --git a/tests/config/test_utils.lua b/tests/config/test_utils.lua
new file mode 100644
index 0000000..4389293
--- /dev/null
+++ b/tests/config/test_utils.lua
@@ -0,0 +1,121 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+local M = {}
+
+function M.test(f, ...)
+ local res, exception = xpcall(f, debug.traceback, ...)
+ if not res then
+ io.stderr:write(string.format('%s\n', exception))
+ os.exit(2)
+ end
+ return res
+end
+
+function M.table_keys_to_lower(table)
+ local res = {}
+ for k, v in pairs(table) do
+ res[k:lower()] = v
+ end
+ return res
+end
+
+local function contains(pass, fail, table, value, message)
+ message = message or string.format('table contains "%s"', value)
+ for _, v in pairs(table) do
+ if v == value then
+ pass(message)
+ return
+ end
+ end
+ fail(message)
+ return
+end
+
+function M.contains(table, value, message)
+ return contains(pass, fail, table, value, message)
+end
+
+function M.not_contains(table, value, message)
+ return contains(fail, pass, table, value, message)
+end
+
+local function answer2table(pkt)
+ local got_answers = {}
+ local ans_rrs = pkt:rrsets(kres.section.ANSWER)
+ for i = 1, #ans_rrs do
+ rrs = ans_rrs[i]
+ for rri = 0, rrs:rdcount() - 1 do
+ local rr = ans_rrs[i]:txt_fields(rri)
+ got_answers[rr.owner] = got_answers[rr.owner] or {}
+ got_answers[rr.owner][rr.type] = got_answers[rr.owner][rr.type] or {}
+ table.insert(got_answers[rr.owner][rr.type], rr.rdata)
+ table.sort(got_answers[rr.owner][rr.type])
+ end
+ end
+ return got_answers
+end
+
+M.NODATA = -1
+-- Resolve a name and check the answer. Do *not* return until finished.
+-- expected_rdata is one string or a table of strings in presentation format
+function M.check_answer(desc, qname, qtype, expected_rcode, expected_rdata)
+ assert(type(qtype) == 'number')
+ local qtype_str = kres.tostring.type[qtype]
+ qname = string.lower(qname)
+
+ local expected_answer = {}
+ if expected_rdata ~= nil then
+ if type(expected_rdata) ~= 'table' then
+ expected_rdata = { expected_rdata }
+ end
+ if #expected_rdata > 0 then
+ table.sort(expected_rdata)
+ expected_answer = {
+ [qname] = {
+ [qtype_str] =
+ expected_rdata
+ }
+ }
+ end
+ end
+
+ local wire_rcode = expected_rcode
+ if expected_rcode == kres.rcode.NOERROR and type(expected_rdata) == 'table'
+ and #expected_rdata == 0 then
+ expected_rcode = M.NODATA
+ end
+ if expected_rcode == M.NODATA then wire_rcode = kres.rcode.NOERROR end
+
+ local done = false
+ local callback = function(pkt)
+ ok(pkt, 'answer not dropped')
+ same(pkt:rcode(), wire_rcode,
+ desc .. ': expecting answer for query ' .. qname .. ' ' .. qtype_str
+ .. ' with rcode ' .. kres.tostring.rcode[wire_rcode])
+
+ ok((pkt:ancount() > 0) == (expected_rcode == kres.rcode.NOERROR),
+ desc ..': checking number of answers for ' .. qname .. ' ' .. qtype_str)
+
+ if expected_rdata then
+ same(expected_answer, answer2table(pkt), 'ANSWER section matches')
+ end
+ done = true
+ end
+ resolve(qname, qtype, kres.class.IN, {},
+ function(...)
+ local ok, err = xpcall(callback, debug.traceback, ...)
+ if not ok then
+ fail('error in check_answer callback function')
+ io.stderr:write(string.format('%s\n', err))
+ os.exit(2)
+ end
+ end
+ )
+
+ for delay = 0.1, 5, 0.5 do -- total max 23.5s in 9 steps
+ if done then return end
+ worker.sleep(delay)
+ end
+ if not done then fail('check_answer() timed out') end
+end
+
+return M