diff options
Diffstat (limited to '')
-rw-r--r-- | tests/deckard/pydnstest/__init__.py | 0 | ||||
-rw-r--r-- | tests/deckard/pydnstest/augwrap.py | 227 | ||||
-rw-r--r-- | tests/deckard/pydnstest/deckard.aug | 94 | ||||
-rw-r--r-- | tests/deckard/pydnstest/empty.rpl | 20 | ||||
-rw-r--r-- | tests/deckard/pydnstest/matchpart.py | 238 | ||||
-rw-r--r-- | tests/deckard/pydnstest/scenario.py | 1058 | ||||
-rw-r--r-- | tests/deckard/pydnstest/tests/__init__.py | 0 | ||||
-rw-r--r-- | tests/deckard/pydnstest/tests/test_parse_config.py | 17 | ||||
-rw-r--r-- | tests/deckard/pydnstest/tests/test_scenario.py | 55 | ||||
-rw-r--r-- | tests/deckard/pydnstest/testserver.py | 278 |
10 files changed, 1987 insertions, 0 deletions
diff --git a/tests/deckard/pydnstest/__init__.py b/tests/deckard/pydnstest/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/deckard/pydnstest/__init__.py diff --git a/tests/deckard/pydnstest/augwrap.py b/tests/deckard/pydnstest/augwrap.py new file mode 100644 index 0000000..20e7857 --- /dev/null +++ b/tests/deckard/pydnstest/augwrap.py @@ -0,0 +1,227 @@ +#!/usr/bin/python3 + +# Copyright (C) 2017 + +import posixpath +import logging +import os +import collections + +from augeas import Augeas + +AUGEAS_LOAD_PATH = '/augeas/load/' +AUGEAS_FILES_PATH = '/files/' +AUGEAS_ERROR_PATH = '//error' + +log = logging.getLogger('augeas') + + +def join(*paths): + """ + join two Augeas tree paths + + FIXME: Beware: // is normalized to / + """ + norm_paths = [posixpath.normpath(path) for path in paths] + # first path must be absolute + assert norm_paths[0][0] == '/' + new_paths = [norm_paths[0]] + # relativize all other paths so join works as expected + for path in norm_paths[1:]: + if path.startswith('/'): + path = path[1:] + new_paths.append(path) + new_path = posixpath.join(*new_paths) + log.debug("join: new_path %s", new_path) + return posixpath.normpath(new_path) + + +class AugeasWrapper: + """python-augeas higher-level wrapper. + + Load single augeas lens and configuration file. + Exposes configuration file as AugeasNode object with dict-like interface. + + AugeasWrapper can be used in with statement in the same way as file does. + """ + + def __init__(self, confpath, lens, root=None, loadpath=None, + flags=Augeas.NO_MODL_AUTOLOAD | Augeas.NO_LOAD | Augeas.ENABLE_SPAN): + """Parse configuration file using given lens. + + Params: + confpath (str): Absolute path to the configuration file + lens (str): Name of module containing Augeas lens + root: passed down to original Augeas + flags: passed down to original Augeas + loadpath: passed down to original Augeas + flags: passed down to original Augeas + """ + log.debug('loadpath: %s', loadpath) + log.debug('confpath: %s', confpath) + self._aug = Augeas(root=root, loadpath=loadpath, flags=flags) + + # /augeas/load/{lens} + aug_load_path = join(AUGEAS_LOAD_PATH, lens) + # /augeas/load/{lens}/lens = {lens}.lns + self._aug.set(join(aug_load_path, 'lens'), '%s.lns' % lens) + # /augeas/load/{lens}/incl[0] = {confpath} + self._aug.set(join(aug_load_path, 'incl[0]'), confpath) + self._aug.load() + + errors = self._aug.match(AUGEAS_ERROR_PATH) + if errors: + err_msg = '\n'.join( + ["{}: {}".format(e, self._aug.get(e)) for e in errors] + ) + raise RuntimeError(err_msg) + + path = join(AUGEAS_FILES_PATH, confpath) + paths = self._aug.match(path) + if len(paths) != 1: + raise ValueError('path %s did not match exactly once' % path) + self.tree = AugeasNode(self._aug, path) + self._loaded = True + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.save() + self.close() + + def save(self): + """Save Augeas tree to its original file.""" + assert self._loaded + try: + self._aug.save() + except IOError as exc: + log.exception(exc) + for err_path in self._aug.match('//error'): + log.error('%s: %s', err_path, + self._aug.get(os.path.join(err_path, 'message'))) + raise + + def close(self): + """ + close Augeas library + + After calling close() the object must not be used anymore. + """ + assert self._loaded + self._aug.close() + del self._aug + self._loaded = False + + def match(self, path): + """Yield AugeasNodes matching given expression.""" + assert self._loaded + assert path + log.debug('tree match %s', path) + for matched_path in self._aug.match(path): + yield AugeasNode(self._aug, matched_path) + + +class AugeasNode(collections.MutableMapping): + """One Augeas tree node with dict-like interface.""" + + def __init__(self, aug, path): + """ + Args: + aug (AugeasWrapper or Augeas): Augeas library instance + path (str): absolute path in Augeas tree matching single node + + BEWARE: There are no sanity checks of given path for performance reasons. + """ + assert aug + assert path + assert path.startswith('/') + self._aug = aug + self._path = path + self._span = None + + @property + def path(self): + """canonical path in Augeas tree, read-only""" + return self._path + + @property + def value(self): + """ + get value of this node in Augeas tree + """ + value = self._aug.get(self._path) + log.debug('tree get: %s = %s', self._path, value) + return value + + @value.setter + def value(self, value): + """ + set value of this node in Augeas tree + """ + log.debug('tree set: %s = %s', self._path, value) + self._aug.set(self._path, value) + + @property + def span(self): + if self._span is None: + self._span = "char position %s" % self._aug.span(self._path)[5] + return self._span + + @property + def char(self): + return self._aug.span(self._path)[5] + + def __len__(self): + """ + number of items matching this path + + It is always 1 after __init__() but it may change + as Augeas tree changes. + """ + return len(self._aug.match(self._path)) + + def __getitem__(self, key): + if isinstance(key, int): + # int is a shortcut to write [int] + target_path = '%s[%s]' % (self._path, key) + else: + target_path = self._path + key + log.debug('tree getitem: target_path %s', target_path) + paths = self._aug.match(target_path) + if len(paths) != 1: + raise KeyError('path %s did not match exactly once' % target_path) + return AugeasNode(self._aug, target_path) + + def __delitem__(self, key): + log.debug('tree delitem: %s + %s', self._path, key) + target_path = self._path + key + log.debug('tree delitem: target_path %s', target_path) + self._aug.remove(target_path) + + def __setitem__(self, key, value): + assert isinstance(value, AugeasNode) + target_path = self.path + key + self._aug.copy(value.path, target_path) + + def __iter__(self): + self_path_len = len(self._path) + assert self_path_len > 0 + + log.debug('tree iter: %s', self._path) + for new_path in self._aug.match(self._path): + if len(new_path) == self_path_len: + yield '' + else: + yield new_path[self_path_len - 1:] + + def match(self, subpath): + """Yield AugeasNodes matching given sub-expression.""" + assert subpath.startswith("/") + match_path = "%s%s" % (self._path, subpath) + log.debug('tree match %s: %s', match_path, self._path) + for matched_path in self._aug.match(match_path): + yield AugeasNode(self._aug, matched_path) + + def __repr__(self): + return 'AugeasNode(%s)' % self._path diff --git a/tests/deckard/pydnstest/deckard.aug b/tests/deckard/pydnstest/deckard.aug new file mode 100644 index 0000000..9e2d167 --- /dev/null +++ b/tests/deckard/pydnstest/deckard.aug @@ -0,0 +1,94 @@ +module Deckard = + autoload xfm + +let del_str = Util.del_str + +let space = del /[ \t]+/ " " +let tab = del /[ \t]+/ "\t" +let ws = del /[\t ]*/ "" +let word = /[^\t\n\/; ]+/ + +let comment = del /[;]/ ";" . [label "comment" . store /[^\n]+/] + +let eol = del /([ \t]*([;][^\n]*)?\n)+/ "\n" . Util.indent +let comment_or_eol = ws . comment? . del_str "\n" . del /([ \t]*([;][^\n]*)?\n)*/ "" . Util.indent + + +(*let comment_or_eol = [ label "#comment" . counter "comment" . (ws . [del /[;#]/ ";" . label "" . store /[^\n]*/ ]? . del_str "\n")]+ . Util.indent +*) + + +let domain_re = (/[^.\t\n\/; ]+(\.[^.\t\n\/; ]+)*\.?/ | ".") - "SECTION" (*quick n dirty, sorry to whoever will ever own SECTION TLD*) +let class_re = /CLASS[0-9]+/ | "IN" | "CH" | "HS" | "NONE" | "ANY" +let domain = [ label "domain" . store domain_re ] +let ttl = [label "ttl" . store /[0-9]+/] +let class = [label "class" . store class_re ] +let type = [label "type" . store ((/[^0-9;\n \t][^\t\n\/; ]*/) - class_re) ] +(* RFC 3597 section 5 rdata syntax is "\# 1 ab"*) +let data_re = /((\\#[ \t])?[^ \t\n;][^\n;]*[^ \t\n;])|[^ \t\n;]/ (*Can not start nor end with whitespace but can have whitespace in the middle. Disjunction is there so we match strings of length one.*) +let data = [label "data" . store data_re ] + +let ip_re = /[0-9a-f.:]+/ +let hex_re = /[0-9a-fA-F]+/ + + +let match_option = "opcode" | "qtype" | "qcase" | "qname" | "subdomain" | "flags" | "rcode" | "question" | "answer" | "authority" | "additional" | "all" | "edns" +let adjust_option = "copy_id" | "copy_query" | "raw_id" | "do_not_answer" +let reply_option = "QR" | "TC" | "AA" | "AD" | "RD" | "RA" | "CD" | "DO" | "NOERROR" | "FORMERR" | "SERVFAIL" | "NXDOMAIN" | "NOTIMP" | "REFUSED" | "YXDOMAIN" | "YXRRSET" | "NXRRSET" | "NOTAUTH" | "NOTZONE" | "BADVERS" | "BADSIG" | "BADKEY" | "BADTIME" | "BADMODE" | "BADNAME" | "BADALG" | "BADTRUNC" | "BADCOOKIE" +let step_option = "REPLY" | "QUERY" | "CHECK_ANSWER" | "CHECK_OUT_QUERY" | /TIME_PASSES[ \t]+ELAPSE/ + +let mandatory = [del_str "MANDATORY" . label "mandatory" . value "true" . comment_or_eol] +let tsig = [del_str "TSIG" . label "tsig" . space . [label "keyname" . store word] . space . [label "secret" . store word] . comment_or_eol] + +let match = (mandatory | tsig)* . [ label "match_present" . value "true" . del_str "MATCH" ] . [space . label "match" . store match_option ]+ . comment_or_eol +let adjust = (mandatory | tsig)* . del_str "ADJUST" . [space . label "adjust" . store adjust_option ]+ . comment_or_eol +let reply = (mandatory | tsig)* . del ("REPLY" | "FLAGS") "REPLY" . [space . label "reply" . store reply_option ]+ . comment_or_eol + + +let question = [label "record" . domain . tab . (class . tab)? . type . comment_or_eol ] +let record = [label "record" . domain . tab . (ttl . tab)? . (class . tab)? . type . tab . data . comment_or_eol] + +let section_question = [ label "question" . del_str "SECTION QUESTION" . + comment_or_eol . question? ] +let section_answer = [ label "answer" . del_str "SECTION ANSWER" . + comment_or_eol . record* ] +let section_authority = [ label "authority" . del_str "SECTION AUTHORITY" . + comment_or_eol . record* ] +let section_additional = [ label "additional" . del_str "SECTION ADDITIONAL" . + comment_or_eol . record* ] +let sections = [label "section" . section_question? . section_answer? . section_authority? . section_additional?] +let raw = [del_str "RAW" . comment_or_eol . label "raw" . store hex_re ] . comment_or_eol + +(* This is quite dirty hack to match every combination of options given to entry since 'let dnsmsg = ((match | adjust | reply | mandatory | tsig)* . sections)' just is not possible *) + +let dnsmsg = (match . (adjust . reply? | reply . adjust?)? | adjust . (match . reply? | reply . match?)? | reply . (match . adjust? | adjust . match?)?)? . (mandatory | tsig)* . sections + +let entry = [label "entry" . del_str "ENTRY_BEGIN" . comment_or_eol . dnsmsg . raw? . del_str "ENTRY_END" . eol] + +let single_address = [ label "address" . space . store ip_re ] + +let addresses = [label "address" . counter "address" . [seq "address" . del_str "ADDRESS" . space . store ip_re . comment_or_eol]+] + +let range = [label "range" . del_str "RANGE_BEGIN" . space . [ label "from" . store /[0-9]+/] . space . + [ label "to" . store /[0-9]+/] . single_address? . comment_or_eol . addresses? . entry* . del_str "RANGE_END" . eol] + +let step = [label "step" . del_str "STEP" . space . store /[0-9]+/ . space . [label "type" . store step_option] . [space . label "timestamp" . store /[0-9]+/]? . comment_or_eol . + entry? ] + +let config_record = /[^\n]*/ - ("CONFIG_END" | /STEP.*/ | /SCENARIO.*/ | /RANGE.*/ | /ENTRY.*/) + +let config = [ label "config" . counter "config" . [seq "config" . store config_record . del_str "\n"]* . del_str "CONFIG_END" . comment_or_eol ] + +let guts = (step | range )* + +let scenario = [label "scenario" . del_str "SCENARIO_BEGIN" . space . store data_re . comment_or_eol . guts . del_str "SCENARIO_END" . eol] + +let lns = config? . scenario + +(* TODO: REPLAY step *) +(* TODO: store all comments into the tree instead of ignoring them *) + +(*let filter = incl "/home/test/*.rpl"*) +let filter = incl "/home/sbalazik/nic/deckard/git/sets/resolver/*.rpl" + +let xfm = transform lns filter diff --git a/tests/deckard/pydnstest/empty.rpl b/tests/deckard/pydnstest/empty.rpl new file mode 100644 index 0000000..295d5a5 --- /dev/null +++ b/tests/deckard/pydnstest/empty.rpl @@ -0,0 +1,20 @@ +stub-addr: 127.0.0.10 +CONFIG_END + +SCENARIO_BEGIN empty replies + +RANGE_BEGIN 0 100 + ADDRESS 127.0.0.10 +ENTRY_BEGIN +MATCH subdomain +ADJUST copy_id copy_query +SECTION QUESTION +. IN A +ENTRY_END +RANGE_END + +STEP 1 QUERY +ENTRY_BEGIN +ENTRY_END + +SCENARIO_END diff --git a/tests/deckard/pydnstest/matchpart.py b/tests/deckard/pydnstest/matchpart.py new file mode 100644 index 0000000..294e64c --- /dev/null +++ b/tests/deckard/pydnstest/matchpart.py @@ -0,0 +1,238 @@ +"""matchpart is used to compare two DNS messages using a single criterion""" + +from typing import ( # noqa + Any, Hashable, Sequence, Tuple, Union) + +import dns.edns +import dns.rcode +import dns.set + +MismatchValue = Union[str, Sequence[Any]] + + +class DataMismatch(Exception): + def __init__(self, exp_val, got_val): + super().__init__() + self.exp_val = exp_val + self.got_val = got_val + + @staticmethod + def format_value(value: MismatchValue) -> str: + if isinstance(value, list): + return ' '.join([str(val) for val in value]) + else: + return str(value) + + def __str__(self) -> str: + return 'expected "{}" got "{}"'.format( + self.format_value(self.exp_val), + self.format_value(self.got_val)) + + def __eq__(self, other): + return (isinstance(other, DataMismatch) + and self.exp_val == other.exp_val + and self.got_val == other.got_val) + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def key(self) -> Tuple[Hashable, Hashable]: + def make_hashable(value): + if isinstance(value, (list, dns.set.Set)): + value = (make_hashable(item) for item in value) + value = tuple(value) + return value + + return (make_hashable(self.exp_val), make_hashable(self.got_val)) + + def __hash__(self) -> int: + return hash(self.key) + + +def compare_val(exp, got): + """Compare arbitraty objects, throw exception if different. """ + if exp != got: + raise DataMismatch(exp, got) + return True + + +def compare_rrs(expected, got): + """ Compare lists of RR sets, throw exception if different. """ + for rr in expected: + if rr not in got: + raise DataMismatch(expected, got) + for rr in got: + if rr not in expected: + raise DataMismatch(expected, got) + if len(expected) != len(got): + raise DataMismatch(expected, got) + return True + + +def compare_rrs_types(exp_val, got_val, skip_rrsigs): + """sets of RR types in both sections must match""" + def rr_ordering_key(rrset): + if rrset.covers: + return rrset.covers, 1 # RRSIGs go to the end of RRtype list + else: + return rrset.rdtype, 0 + + def key_to_text(rrtype, rrsig): + if not rrsig: + return dns.rdatatype.to_text(rrtype) + else: + return 'RRSIG(%s)' % dns.rdatatype.to_text(rrtype) + + if skip_rrsigs: + exp_val = (rrset for rrset in exp_val + if rrset.rdtype != dns.rdatatype.RRSIG) + got_val = (rrset for rrset in got_val + if rrset.rdtype != dns.rdatatype.RRSIG) + + exp_types = frozenset(rr_ordering_key(rrset) for rrset in exp_val) + got_types = frozenset(rr_ordering_key(rrset) for rrset in got_val) + if exp_types != got_types: + exp_types = tuple(key_to_text(*i) for i in sorted(exp_types)) + got_types = tuple(key_to_text(*i) for i in sorted(got_types)) + raise DataMismatch(exp_types, got_types) + + +def check_question(question): + if len(question) > 2: + raise NotImplementedError("More than one record in QUESTION SECTION.") + + +def match_opcode(exp, got): + return compare_val(exp.opcode(), + got.opcode()) + + +def match_qtype(exp, got): + check_question(exp.question) + check_question(got.question) + if not exp.question and not got.question: + return True + if not exp.question: + raise DataMismatch("<empty question>", got.question[0].rdtype) + if not got.question: + raise DataMismatch(exp.question[0].rdtype, "<empty question>") + return compare_val(exp.question[0].rdtype, + got.question[0].rdtype) + + +def match_qname(exp, got): + check_question(exp.question) + check_question(got.question) + if not exp.question and not got.question: + return True + if not exp.question: + raise DataMismatch("<empty question>", got.question[0].name) + if not got.question: + raise DataMismatch(exp.question[0].name, "<empty question>") + return compare_val(exp.question[0].name, + got.question[0].name) + + +def match_qcase(exp, got): + check_question(exp.question) + check_question(got.question) + if not exp.question and not got.question: + return True + if not exp.question: + raise DataMismatch("<empty question>", got.question[0].name.labels) + if not got.question: + raise DataMismatch(exp.question[0].name.labels, "<empty question>") + return compare_val(exp.question[0].name.labels, + got.question[0].name.labels) + + +def match_subdomain(exp, got): + if not exp.question: + return True + if got.question: + qname = got.question[0].name + else: + qname = dns.name.root + if exp.question[0].name.is_superdomain(qname): + return True + raise DataMismatch(exp, got) + + +def match_flags(exp, got): + return compare_val(dns.flags.to_text(exp.flags), + dns.flags.to_text(got.flags)) + + +def match_rcode(exp, got): + return compare_val(dns.rcode.to_text(exp.rcode()), + dns.rcode.to_text(got.rcode())) + + +def match_answer(exp, got): + return compare_rrs(exp.answer, + got.answer) + + +def match_answertypes(exp, got): + return compare_rrs_types(exp.answer, + got.answer, skip_rrsigs=True) + + +def match_answerrrsigs(exp, got): + return compare_rrs_types(exp.answer, + got.answer, skip_rrsigs=False) + + +def match_authority(exp, got): + return compare_rrs(exp.authority, + got.authority) + + +def match_additional(exp, got): + return compare_rrs(exp.additional, + got.additional) + + +def match_edns(exp, got): + if got.edns != exp.edns: + raise DataMismatch(exp.edns, + got.edns) + if got.payload != exp.payload: + raise DataMismatch(exp.payload, + got.payload) + + +def match_nsid(exp, got): + nsid_opt = None + for opt in exp.options: + if opt.otype == dns.edns.NSID: + nsid_opt = opt + break + # Find matching NSID + for opt in got.options: + if opt.otype == dns.edns.NSID: + if not nsid_opt: + raise DataMismatch(None, opt.data) + if opt == nsid_opt: + return True + else: + raise DataMismatch(nsid_opt.data, opt.data) + if nsid_opt: + raise DataMismatch(nsid_opt.data, None) + return True + + +MATCH = {"opcode": match_opcode, "qtype": match_qtype, "qname": match_qname, "qcase": match_qcase, + "subdomain": match_subdomain, "flags": match_flags, "rcode": match_rcode, + "answer": match_answer, "answertypes": match_answertypes, + "answerrrsigs": match_answerrrsigs, "authority": match_authority, + "additional": match_additional, "edns": match_edns, + "nsid": match_nsid} + + +def match_part(exp, got, code): + try: + return MATCH[code](exp, got) + except KeyError: + raise NotImplementedError('unknown match request "%s"' % code) diff --git a/tests/deckard/pydnstest/scenario.py b/tests/deckard/pydnstest/scenario.py new file mode 100644 index 0000000..5e0661b --- /dev/null +++ b/tests/deckard/pydnstest/scenario.py @@ -0,0 +1,1058 @@ +# FIXME pylint: disable=too-many-lines +from abc import ABC +import binascii +import calendar +from datetime import datetime +import errno +import logging +import os +import posixpath +import random +import socket +import string +import struct +import time +from typing import Optional + +import dns.dnssec +import dns.message +import dns.name +import dns.rcode +import dns.rrset +import dns.tsigkeyring + +import pydnstest.augwrap +import pydnstest.matchpart + + +def str2bool(v): + """ Return conversion of JSON-ish string value to boolean. """ + return v.lower() in ('yes', 'true', 'on', '1') + + +# Global statistics +g_rtt = 0.0 +g_nqueries = 0 + + +def recvfrom_msg(stream, raw=False): + """ + Receive DNS message from TCP/UDP socket. + + Returns: + if raw == False: (DNS message object, peer address) + if raw == True: (blob, peer address) + """ + if stream.type & socket.SOCK_DGRAM: + data, addr = stream.recvfrom(4096) + elif stream.type & socket.SOCK_STREAM: + data = stream.recv(2) + if not data: + return None, None + msg_len = struct.unpack_from("!H", data)[0] + data = b"" + received = 0 + while received < msg_len: + next_chunk = stream.recv(4096) + if not next_chunk: + return None, None + data += next_chunk + received += len(next_chunk) + addr = stream.getpeername()[0] + else: + raise NotImplementedError("[recvfrom_msg]: unknown socket type '%i'" % stream.type) + if raw: + return data, addr + else: + msg = dns.message.from_wire(data, one_rr_per_rrset=True) + return msg, addr + + +def sendto_msg(stream, message, addr=None): + """ Send DNS/UDP/TCP message. """ + try: + if stream.type & socket.SOCK_DGRAM: + if addr is None: + stream.send(message) + else: + stream.sendto(message, addr) + elif stream.type & socket.SOCK_STREAM: + data = struct.pack("!H", len(message)) + message + stream.send(data) + else: + raise NotImplementedError("[sendto_msg]: unknown socket type '%i'" % stream.type) + except socket.error as ex: + if ex.errno != errno.ECONNREFUSED: # TODO Investigate how this can happen + raise + + +def replay_rrs(rrs, nqueries, destination, args=None): + """ Replay list of queries and report statistics. """ + if args is None: + args = [] + navail, queries = len(rrs), [] + chunksize = 16 + for i in range(nqueries if 'RAND' in args else navail): + rr = rrs[i % navail] + name = rr.name + if 'RAND' in args: + prefix = ''.join([random.choice(string.ascii_letters + string.digits) + for _ in range(8)]) + name = prefix + '.' + rr.name.to_text() + msg = dns.message.make_query(name, rr.rdtype, rr.rdclass) + if 'DO' in args: + msg.want_dnssec(True) + queries.append(msg.to_wire()) + # Make a UDP connected socket to the destination + family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET + sock = socket.socket(family, socket.SOCK_DGRAM) + sock.connect(destination) + sock.setblocking(False) + # Play the query set + # @NOTE: this is only good for relative low-speed replay + rcvbuf = bytearray('\x00' * 512) + nsent, nrcvd, nwait, navail = 0, 0, 0, len(queries) + fdset = [sock] + import select + while nsent - nwait < nqueries: + to_read, to_write, _ = select.select(fdset, fdset if nwait < chunksize else [], [], 0.5) + if to_write: + try: + while nsent < nqueries and nwait < chunksize: + sock.send(queries[nsent % navail]) + nwait += 1 + nsent += 1 + except socket.error: + pass # EINVAL + if to_read: + try: + while nwait > 0: + sock.recv_into(rcvbuf) + nwait -= 1 + nrcvd += 1 + except socket.error: + pass + if not to_write and not to_read: + nwait = 0 # Timeout, started dropping packets + break + return nsent, nrcvd + + +class DNSBlob(ABC): + def to_wire(self) -> bytes: + raise NotImplementedError + + def __str__(self) -> str: + return '<DNSBlob>' + + +class DNSMessage(DNSBlob): + def __init__(self, message: dns.message.Message) -> None: + assert message is not None + self.message = message + + def to_wire(self) -> bytes: + return self.message.to_wire(max_size=65535) + + def __str__(self) -> str: + return str(self.message) + + +class DNSReply(DNSMessage): + def __init__( + self, + message: dns.message.Message, + query: Optional[dns.message.Message] = None, + copy_id: bool = False, + copy_query: bool = False + ) -> None: + super().__init__(message) + if copy_id or copy_query: + if query is None: + raise ValueError("query must be provided to adjust copy_id/copy_query") + self.adjust_reply(query, copy_id, copy_query) + + def adjust_reply( + self, + query: dns.message.Message, + copy_id: bool = True, + copy_query: bool = True + ) -> None: + answer = dns.message.from_wire(self.message.to_wire(), + xfr=self.message.xfr, + one_rr_per_rrset=True) + answer.use_edns(query.edns, query.ednsflags, options=self.message.options) + if copy_id: + answer.id = query.id + # Copy letter-case if the template has QD + if answer.question: + answer.question[0].name = query.question[0].name + if copy_query: + answer.question = query.question + # Re-set, as the EDNS might have reset the ext-rcode + answer.set_rcode(self.message.rcode()) + + # sanity check: adjusted answer should be almost the same + assert len(answer.answer) == len(self.message.answer) + assert len(answer.authority) == len(self.message.authority) + assert len(answer.additional) == len(self.message.additional) + self.message = answer + + +class DNSReplyRaw(DNSBlob): + def __init__( + self, + wire: bytes, + query: Optional[dns.message.Message] = None, + copy_id: bool = False + ) -> None: + assert wire is not None + self.wire = wire + if copy_id: + self.adjust_reply(query, copy_id) + + def adjust_reply( + self, + query: dns.message.Message, + copy_id: bool = True + ) -> None: + if copy_id: + if len(self.wire) < 2: + raise ValueError( + 'wire data must contain at least 2 bytes to adjust query id') + raw_answer = bytearray(self.wire) + struct.pack_into('!H', raw_answer, 0, query.id) + self.wire = bytes(raw_answer) + + def to_wire(self) -> bytes: + return self.wire + + def __str__(self) -> str: + return '<DNSReplyRaw>' + + +class DNSReplyServfail(DNSMessage): + def __init__(self, query: dns.message.Message) -> None: + message = dns.message.make_response(query) + message.set_rcode(dns.rcode.SERVFAIL) + super().__init__(message) + + +class Entry: + """ + Data entry represents scripted message and extra metadata, + notably match criteria and reply adjustments. + """ + + # Globals + default_ttl = 3600 + default_cls = 'IN' + default_rc = 'NOERROR' + + def __init__(self, node): + """ Initialize data entry. """ + self.node = node + self.origin = '.' + self.message = dns.message.Message() + self.message.use_edns(edns=0, payload=4096) + self.fired = 0 + + # RAW + self.raw_data = None # type: Optional[bytes] + self.is_raw_data_entry = self.process_raw() + + # MATCH + self.match_fields = self.process_match() + + # FLAGS + self.process_reply_line() + + # ADJUST + self.adjust_fields = {m.value for m in node.match("/adjust")} + + # MANDATORY + try: + self.mandatory = list(node.match("/mandatory"))[0] + except (KeyError, IndexError): + self.mandatory = None + + # TSIG + self.process_tsig() + + # SECTIONS & RECORDS + self.sections = self.process_sections() + + def process_raw(self): + try: + self.raw_data = binascii.unhexlify(self.node["/raw"].value) + return True + except KeyError: + return False + + def process_match(self): + try: + self.node["/match_present"] + except KeyError: + return None + + fields = set(m.value for m in self.node.match("/match")) + + if 'all' in fields: + fields.remove("all") + fields |= set(["opcode", "qtype", "qname", "flags", + "rcode", "answer", "authority", "additional"]) + + if 'question' in fields: + fields.remove("question") + fields |= set(["qtype", "qname"]) + + return fields + + def process_reply_line(self): + """Extracts flags, rcode and opcode from given node and adjust dns message accordingly""" + self.fields = [f.value for f in self.node.match("/reply")] + if 'DO' in self.fields: + self.message.want_dnssec(True) + opcode = self.get_opcode(fields=self.fields) + rcode = self.get_rcode(fields=self.fields) + self.message.flags = self.get_flags(fields=self.fields) + if rcode is not None: + self.message.set_rcode(rcode) + if opcode is not None: + self.message.set_opcode(opcode) + + def process_tsig(self): + try: + tsig = list(self.node.match("/tsig"))[0] + tsig_keyname = tsig["/keyname"].value + tsig_secret = tsig["/secret"].value + keyring = dns.tsigkeyring.from_text({tsig_keyname: tsig_secret}) + self.message.use_tsig(keyring=keyring, keyname=tsig_keyname) + except (KeyError, IndexError): + pass + + def process_sections(self): + sections = set() + for section in self.node.match("/section/*"): + section_name = posixpath.basename(section.path) + sections.add(section_name) + for record in section.match("/record"): + owner = record['/domain'].value + if not owner.endswith("."): + owner += self.origin + try: + ttl = dns.ttl.from_text(record['/ttl'].value) + except KeyError: + ttl = self.default_ttl + try: + rdclass = dns.rdataclass.from_text(record['/class'].value) + except KeyError: + rdclass = dns.rdataclass.from_text(self.default_cls) + rdtype = dns.rdatatype.from_text(record['/type'].value) + rr = dns.rrset.from_text(owner, ttl, rdclass, rdtype) + if section_name != "question": + rd = record['/data'].value.split() + if rd: + if rdtype == dns.rdatatype.DS: + rd[1] = str(dns.dnssec.algorithm_from_text(rd[1])) + rd = dns.rdata.from_text(rr.rdclass, rr.rdtype, ' '.join( + rd), origin=dns.name.from_text(self.origin), relativize=False) + rr.add(rd) + if section_name == 'question': + if rr.rdtype == dns.rdatatype.AXFR: + self.message.xfr = True + self.message.question.append(rr) + elif section_name == 'answer': + self.message.answer.append(rr) + elif section_name == 'authority': + self.message.authority.append(rr) + elif section_name == 'additional': + self.message.additional.append(rr) + return sections + + def __str__(self): + txt = 'ENTRY_BEGIN\n' + if not self.is_raw_data_entry: + txt += 'MATCH {0}\n'.format(' '.join(self.match_fields)) + txt += 'ADJUST {0}\n'.format(' '.join(self.adjust_fields)) + txt += 'REPLY {rcode} {flags}\n'.format( + rcode=dns.rcode.to_text(self.message.rcode()), + flags=' '.join([dns.flags.to_text(self.message.flags), + dns.flags.edns_to_text(self.message.ednsflags)]) + ) + for sect_name in ['question', 'answer', 'authority', 'additional']: + sect = getattr(self.message, sect_name) + if not sect: + continue + txt += 'SECTION {n}\n'.format(n=sect_name.upper()) + for rr in sect: + txt += str(rr) + txt += '\n' + if self.is_raw_data_entry: + txt += 'RAW\n' + if self.raw_data: + txt += binascii.hexlify(self.raw_data) + else: + txt += 'NULL' + txt += '\n' + txt += 'ENTRY_END\n' + return txt + + @classmethod + def get_flags(cls, fields): + """From `fields` extracts and returns flags""" + flags = [] + for code in fields: + try: + dns.flags.from_text(code) # throws KeyError on failure + flags.append(code) + except KeyError: + pass + return dns.flags.from_text(' '.join(flags)) + + @classmethod + def get_rcode(cls, fields): + """ + From `fields` extracts and returns rcode. + Throws `ValueError` if there are more then one rcodes + """ + rcodes = [] + for code in fields: + try: + rcodes.append(dns.rcode.from_text(code)) + except dns.rcode.UnknownRcode: + pass + if len(rcodes) > 1: + raise ValueError("Parse failed, too many rcode values.", rcodes) + if not rcodes: + return None + return rcodes[0] + + @classmethod + def get_opcode(cls, fields): + """ + From `fields` extracts and returns opcode. + Throws `ValueError` if there are more then one opcodes + """ + opcodes = [] + for code in fields: + try: + opcodes.append(dns.opcode.from_text(code)) + except dns.opcode.UnknownOpcode: + pass + if len(opcodes) > 1: + raise ValueError("Parse failed, too many opcode values.") + if not opcodes: + return None + return opcodes[0] + + def match(self, msg): + """ Compare scripted reply to given message based on match criteria. """ + for code in self.match_fields: + try: + pydnstest.matchpart.match_part(self.message, msg, code) + except pydnstest.matchpart.DataMismatch as ex: + errstr = '%s in the response:\n%s' % (str(ex), msg.to_text()) + # TODO: cisla radku + raise ValueError("%s, \"%s\": %s" % (self.node.span, code, errstr)) + + def cmp_raw(self, raw_value): + assert self.is_raw_data_entry + expected = None + if self.raw_data is not None: + expected = binascii.hexlify(self.raw_data) + got = None + if raw_value is not None: + got = binascii.hexlify(raw_value) + if expected != got: + raise ValueError("raw message comparsion failed: expected %s got %s" % (expected, got)) + + def reply(self, query) -> Optional[DNSBlob]: + if 'do_not_answer' in self.adjust_fields: + return None + if self.is_raw_data_entry: + copy_id = 'raw_data' in self.adjust_fields + assert self.raw_data is not None + return DNSReplyRaw(self.raw_data, query, copy_id) + copy_id = 'copy_id' in self.adjust_fields + copy_query = 'copy_query' in self.adjust_fields + return DNSReply(self.message, query, copy_id, copy_query) + + def set_edns(self, fields): + """ Set EDNS version and bufsize. """ + version = 0 + bufsize = 4096 + if fields and fields[0].isdigit(): + version = int(fields.pop(0)) + if fields and fields[0].isdigit(): + bufsize = int(fields.pop(0)) + if bufsize == 0: + self.message.use_edns(False) + return + opts = [] + for v in fields: + k, v = tuple(v.split('=')) if '=' in v else (v, True) + if k.lower() == 'nsid': + opts.append(dns.edns.GenericOption(dns.edns.NSID, '' if v is True else v)) + if k.lower() == 'subnet': + net = v.split('/') + subnet_addr = net[0] + family = socket.AF_INET6 if ':' in subnet_addr else socket.AF_INET + addr = socket.inet_pton(family, subnet_addr) + prefix = len(addr) * 8 + if len(net) > 1: + prefix = int(net[1]) + addr = addr[0: (prefix + 7) / 8] + if prefix % 8 != 0: # Mask the last byte + addr = addr[:-1] + chr(ord(addr[-1]) & 0xFF << (8 - prefix % 8)) + opts.append(dns.edns.GenericOption(8, struct.pack( + "!HBB", 1 if family == socket.AF_INET else 2, prefix, 0) + addr)) + self.message.use_edns(edns=version, payload=bufsize, options=opts) + + +class Range: + """ + Range represents a set of scripted queries valid for given step range. + """ + log = logging.getLogger('pydnstest.scenario.Range') + + def __init__(self, node): + """ Initialize reply range. """ + self.node = node + self.a = int(node['/from'].value) + self.b = int(node['/to'].value) + assert self.a <= self.b + + address = node["/address"].value + self.addresses = {address} if address is not None else set() + self.addresses |= {a.value for a in node.match("/address/*")} + self.stored = [Entry(n) for n in node.match("/entry")] + self.args = {} + self.received = 0 + self.sent = 0 + + def __del__(self): + self.log.info('[ RANGE %d-%d ] %s received: %d sent: %d', + self.a, self.b, self.addresses, self.received, self.sent) + + def __str__(self): + txt = '\nRANGE_BEGIN {a} {b}\n'.format(a=self.a, b=self.b) + for addr in self.addresses: + txt += ' ADDRESS {0}\n'.format(addr) + + for entry in self.stored: + txt += '\n' + txt += str(entry) + txt += 'RANGE_END\n\n' + return txt + + def eligible(self, ident, address): + """ Return true if this range is eligible for fetching reply. """ + if self.a <= ident <= self.b: + return (None is address + or set() == self.addresses + or address in self.addresses) + return False + + def reply(self, query: dns.message.Message) -> Optional[DNSBlob]: + """Get answer for given query (adjusted if needed).""" + self.received += 1 + for candidate in self.stored: + try: + candidate.match(query) + resp = candidate.reply(query) + # Probabilistic loss + if 'LOSS' in self.args: + if random.random() < float(self.args['LOSS']): + return DNSReplyServfail(query) + self.sent += 1 + candidate.fired += 1 + return resp + except ValueError: + pass + return DNSReplyServfail(query) + + +class StepLogger(logging.LoggerAdapter): # pylint: disable=too-few-public-methods + """ + Prepent Step identification before each log message. + """ + def process(self, msg, kwargs): + return '[STEP %s %s] %s' % (self.extra['id'], self.extra['type'], msg), kwargs + + +class Step: + """ + Step represents one scripted action in a given moment, + each step has an order identifier, type and optionally data entry. + """ + require_data = ['QUERY', 'CHECK_ANSWER', 'REPLY'] + + def __init__(self, node): + """ Initialize single scenario step. """ + self.node = node + self.id = int(node.value) + self.type = node["/type"].value + self.log = StepLogger(logging.getLogger('pydnstest.scenario.Step'), + {'id': self.id, 'type': self.type}) + try: + self.delay = int(node["/timestamp"].value) + except KeyError: + pass + self.data = [Entry(n) for n in node.match("/entry")] + self.queries = [] + self.has_data = self.type in Step.require_data + self.answer = None + self.raw_answer = None + self.repeat_if_fail = 0 + self.pause_if_fail = 0 + self.next_if_fail = -1 + + # TODO Parser currently can't parse CHECK_ANSWER args, player doesn't understand them anyway + # if type == 'CHECK_ANSWER': + # for arg in extra_args: + # param = arg.split('=') + # try: + # if param[0] == 'REPEAT': + # self.repeat_if_fail = int(param[1]) + # elif param[0] == 'PAUSE': + # self.pause_if_fail = float(param[1]) + # elif param[0] == 'NEXT': + # self.next_if_fail = int(param[1]) + # except Exception as e: + # raise Exception('step %d - wrong %s arg: %s' % (self.id, param[0], str(e))) + + def __str__(self): + txt = '\nSTEP {i} {t}'.format(i=self.id, t=self.type) + if self.repeat_if_fail: + txt += ' REPEAT {v}'.format(v=self.repeat_if_fail) + elif self.pause_if_fail: + txt += ' PAUSE {v}'.format(v=self.pause_if_fail) + elif self.next_if_fail != -1: + txt += ' NEXT {v}'.format(v=self.next_if_fail) + # if self.args: + # txt += ' ' + # txt += ' '.join(self.args) + txt += '\n' + + for data in self.data: + # from IPython.core.debugger import Tracer + # Tracer()() + txt += str(data) + return txt + + def play(self, ctx): + """ Play one step from a scenario. """ + if self.type == 'QUERY': + self.log.info('') + self.log.debug(self.data[0].message.to_text()) + # Parse QUERY-specific parameters + choice, tcp, source = None, False, None + return self.__query(ctx, tcp=tcp, choice=choice, source=source) + elif self.type == 'CHECK_OUT_QUERY': # ignore + self.log.info('') + return None + elif self.type == 'CHECK_ANSWER' or self.type == 'ANSWER': + self.log.info('') + return self.__check_answer(ctx) + elif self.type == 'TIME_PASSES ELAPSE': + self.log.info('') + return self.__time_passes() + elif self.type == 'REPLY' or self.type == 'MOCK': + self.log.info('') + return None + # Parser currently doesn't support step types LOG, REPLAY and ASSERT. + # No test uses them. + # elif self.type == 'LOG': + # if not ctx.log: + # raise Exception('scenario has no log interface') + # return ctx.log.match(self.args) + # elif self.type == 'REPLAY': + # self.__replay(ctx) + # elif self.type == 'ASSERT': + # self.__assert(ctx) + else: + raise NotImplementedError('step %03d type %s unsupported' % (self.id, self.type)) + + def __check_answer(self, ctx): + """ Compare answer from previously resolved query. """ + if not self.data: + raise ValueError("response definition required") + expected = self.data[0] + if expected.is_raw_data_entry is True: + self.log.debug("raw answer: %s", ctx.last_raw_answer.to_text()) + expected.cmp_raw(ctx.last_raw_answer) + else: + if ctx.last_answer is None: + raise ValueError("no answer from preceding query") + self.log.debug("answer: %s", ctx.last_answer.to_text()) + expected.match(ctx.last_answer) + + # def __replay(self, ctx, chunksize=8): + # nqueries = len(self.queries) + # if len(self.args) > 0 and self.args[0].isdigit(): + # nqueries = int(self.args.pop(0)) + # destination = ctx.client[ctx.client.keys()[0]] + # self.log.info('replaying %d queries to %s@%d (%s)', + # nqueries, destination[0], destination[1], ' '.join(self.args)) + # if 'INTENSIFY' in os.environ: + # nqueries *= int(os.environ['INTENSIFY']) + # tstart = datetime.now() + # nsent, nrcvd = replay_rrs(self.queries, nqueries, destination, self.args) + # # Keep/print the statistics + # rtt = (datetime.now() - tstart).total_seconds() * 1000 + # pps = 1000 * nrcvd / rtt + # self.log.debug('sent: %d, received: %d (%d ms, %d p/s)', nsent, nrcvd, rtt, pps) + # tag = None + # for arg in self.args: + # if arg.upper().startswith('PRINT'): + # _, tag = tuple(arg.split('=')) if '=' in arg else (None, 'replay') + # if tag: + # self.log.info('[ REPLAY ] test: %s pps: %5d time: %4d sent: %5d received: %5d', + # tag.ljust(11), pps, rtt, nsent, nrcvd) + + def __query(self, ctx, tcp=False, choice=None, source=None): + """ + Send query and wait for an answer (if the query is not RAW). + + The received answer is stored in self.answer and ctx.last_answer. + """ + if not self.data: + raise ValueError("query definition required") + if self.data[0].is_raw_data_entry is True: + data_to_wire = self.data[0].raw_data + else: + # Don't use a message copy as the EDNS data portion is not copied. + data_to_wire = self.data[0].message.to_wire() + if choice is None or not choice: + choice = list(ctx.client.keys())[0] + if choice not in ctx.client: + raise ValueError('step %03d invalid QUERY target: %s' % (self.id, choice)) + # Create socket to test subject + sock = None + destination = ctx.client[choice] + family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET + sock = socket.socket(family, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if tcp: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) + sock.settimeout(3) + if source: + sock.bind((source, 0)) + sock.connect(destination) + # Send query to client and wait for response + tstart = datetime.now() + while True: + try: + sendto_msg(sock, data_to_wire) + break + except OSError as ex: + # ENOBUFS, throttle sending + if ex.errno == errno.ENOBUFS: + time.sleep(0.1) + # Wait for a response for a reasonable time + answer = None + if not self.data[0].is_raw_data_entry: + while True: + if (datetime.now() - tstart).total_seconds() > 5: + raise RuntimeError("Server took too long to respond") + try: + answer, _ = recvfrom_msg(sock, True) + break + except OSError as ex: + if ex.errno == errno.ENOBUFS: + time.sleep(0.1) + # Track RTT + rtt = (datetime.now() - tstart).total_seconds() * 1000 + global g_rtt, g_nqueries + g_nqueries += 1 + g_rtt += rtt + # Remember last answer for checking later + self.raw_answer = answer + ctx.last_raw_answer = answer + if self.raw_answer is not None: + self.answer = dns.message.from_wire(self.raw_answer, one_rr_per_rrset=True) + else: + self.answer = None + ctx.last_answer = self.answer + + def __time_passes(self): + """ Modify system time. """ + file_old = os.environ["FAKETIME_TIMESTAMP_FILE"] + file_next = os.environ["FAKETIME_TIMESTAMP_FILE"] + ".next" + with open(file_old, 'r') as time_file: + line = time_file.readline().strip() + t = time.mktime(datetime.strptime(line, '@%Y-%m-%d %H:%M:%S').timetuple()) + t += self.delay + with open(file_next, 'w') as time_file: + time_file.write(datetime.fromtimestamp(t).strftime('@%Y-%m-%d %H:%M:%S') + "\n") + time_file.flush() + os.replace(file_next, file_old) + + # def __assert(self, ctx): + # """ Assert that a passed expression evaluates to True. """ + # result = eval(' '.join(self.args), {'SCENARIO': ctx, 'RANGE': ctx.ranges}) + # # Evaluate subexpressions for clarity + # subexpr = [] + # for expr in self.args: + # try: + # ee = eval(expr, {'SCENARIO': ctx, 'RANGE': ctx.ranges}) + # subexpr.append(str(ee)) + # except: + # subexpr.append(expr) + # assert result is True, '"%s" assertion fails (%s)' % ( + # ' '.join(self.args), ' '.join(subexpr)) + + +class Scenario: + log = logging.getLogger('pydnstest.scenatio.Scenario') + + def __init__(self, node, filename): + """ Initialize scenario with description. """ + self.node = node + self.info = node.value + self.file = filename + self.ranges = [Range(n) for n in node.match("/range")] + self.current_range = None + self.steps = [Step(n) for n in node.match("/step")] + self.current_step = None + self.client = {} + + def __str__(self): + txt = 'SCENARIO_BEGIN' + if self.info: + txt += ' {0}'.format(self.info) + txt += '\n' + for range_ in self.ranges: + txt += str(range_) + for step in self.steps: + txt += str(step) + txt += "\nSCENARIO_END" + return txt + + def reply(self, query: dns.message.Message, address=None) -> Optional[DNSBlob]: + """Generate answer packet for given query.""" + current_step_id = self.current_step.id + # Unknown address, select any match + # TODO: workaround until the server supports stub zones + all_addresses = set() # type: ignore + for rng in self.ranges: + all_addresses.update(rng.addresses) + if address not in all_addresses: + address = None + # Find current valid query response range + for rng in self.ranges: + if rng.eligible(current_step_id, address): + self.current_range = rng + return rng.reply(query) + # Find any prescripted one-shot replies + for step in self.steps: + if step.id < current_step_id or step.type != 'REPLY': + continue + try: + candidate = step.data[0] + candidate.match(query) + step.data.remove(candidate) + return candidate.reply(query) + except (IndexError, ValueError): + pass + return DNSReplyServfail(query) + + def play(self, paddr): + """ Play given scenario. """ + # Store test subject => address mapping + self.client = paddr + + step = None + i = 0 + while i < len(self.steps): + step = self.steps[i] + self.current_step = step + try: + step.play(self) + except ValueError as ex: + if step.repeat_if_fail > 0: + self.log.info("[play] step %d: exception - '%s', retrying step %d (%d left)", + step.id, ex, step.next_if_fail, step.repeat_if_fail) + step.repeat_if_fail -= 1 + if step.pause_if_fail > 0: + time.sleep(step.pause_if_fail) + if step.next_if_fail != -1: + next_steps = [j for j in range(len(self.steps)) if self.steps[ + j].id == step.next_if_fail] + if not next_steps: + raise ValueError('step %d: wrong NEXT value "%d"' % + (step.id, step.next_if_fail)) + next_step = next_steps[0] + if next_step < len(self.steps): + i = next_step + else: + raise ValueError('step %d: Can''t branch to NEXT value "%d"' % + (step.id, step.next_if_fail)) + continue + else: + raise ValueError('%s step %d %s' % (self.file, step.id, str(ex))) + i += 1 + + for r in self.ranges: + for e in r.stored: + if e.mandatory and e.fired == 0: + # TODO: cisla radku + raise ValueError('Mandatory section at %s not fired' % e.mandatory.span) + + +def get_next(file_in, skip_empty=True): + """ Return next token from the input stream. """ + while True: + line = file_in.readline() + if not line: + return False + quoted, escaped = False, False + for i, char in enumerate(line): + if char == '\\': + escaped = not escaped + if not escaped and char == '"': + quoted = not quoted + if char == ';' and not quoted: + line = line[0:i] + break + if char != '\\': + escaped = False + tokens = ' '.join(line.strip().split()).split() + if not tokens: + if skip_empty: + continue + else: + return '', [] + op = tokens.pop(0) + return op, tokens + + +def parse_config(scn_cfg, qmin, installdir): # FIXME: pylint: disable=too-many-statements + """ + Transform scene config (key, value) pairs into dict filled with defaults. + Returns tuple: + context dict: {Jinja2 variable: value} + trust anchor dict: {domain: [TA lines for particular domain]} + """ + # defaults + do_not_query_localhost = True + harden_glue = True + sockfamily = 0 # auto-select value for socket.getaddrinfo + trust_anchor_list = [] + trust_anchor_files = {} + negative_ta_list = [] + stub_addr = None + override_timestamp = None + + features = {} + feature_list_delimiter = ';' + feature_pair_delimiter = '=' + + for k, v in scn_cfg: + # Enable selectively for some tests + if k == 'do-not-query-localhost': + do_not_query_localhost = str2bool(v) + elif k == 'domain-insecure': + negative_ta_list.append(v) + elif k == 'harden-glue': + harden_glue = str2bool(v) + elif k == 'query-minimization': + qmin = str2bool(v) + elif k == 'trust-anchor': + trust_anchor = v.strip('"\'') + trust_anchor_list.append(trust_anchor) + domain = dns.name.from_text(trust_anchor.split()[0]).canonicalize() + if domain not in trust_anchor_files: + trust_anchor_files[domain] = [] + trust_anchor_files[domain].append(trust_anchor) + elif k == 'val-override-timestamp': + override_timestamp_str = v.strip('"\'') + override_timestamp = int(override_timestamp_str) + elif k == 'val-override-date': + override_date_str = v.strip('"\'') + ovr_yr = override_date_str[0:4] + ovr_mnt = override_date_str[4:6] + ovr_day = override_date_str[6:8] + ovr_hr = override_date_str[8:10] + ovr_min = override_date_str[10:12] + ovr_sec = override_date_str[12:] + override_date_str_arg = '{0} {1} {2} {3} {4} {5}'.format( + ovr_yr, ovr_mnt, ovr_day, ovr_hr, ovr_min, ovr_sec) + override_date = time.strptime(override_date_str_arg, "%Y %m %d %H %M %S") + override_timestamp = calendar.timegm(override_date) + elif k == 'stub-addr': + stub_addr = v.strip('"\'') + elif k == 'features': + feature_list = v.split(feature_list_delimiter) + try: + for f_item in feature_list: + if f_item.find(feature_pair_delimiter) != -1: + f_key, f_value = [x.strip() + for x + in f_item.split(feature_pair_delimiter, 1)] + else: + f_key = f_item.strip() + f_value = "" + features[f_key] = f_value + except KeyError as ex: + raise KeyError("can't parse features (%s) in config section (%s)" % (v, str(ex))) + elif k == 'feature-list': + try: + f_key, f_value = [x.strip() for x in v.split(feature_pair_delimiter, 1)] + if f_key not in features: + features[f_key] = [] + f_value = f_value.replace("{{INSTALL_DIR}}", installdir) + features[f_key].append(f_value) + except KeyError as ex: + raise KeyError("can't parse feature-list (%s) in config section (%s)" + % (v, str(ex))) + elif k == 'force-ipv6' and v.upper() == 'TRUE': + sockfamily = socket.AF_INET6 + else: + raise NotImplementedError('unsupported CONFIG key "%s"' % k) + + ctx = { + "DO_NOT_QUERY_LOCALHOST": str(do_not_query_localhost).lower(), + "NEGATIVE_TRUST_ANCHORS": negative_ta_list, + "FEATURES": features, + "HARDEN_GLUE": str(harden_glue).lower(), + "INSTALL_DIR": installdir, + "QMIN": str(qmin).lower(), + "TRUST_ANCHORS": trust_anchor_list, + "TRUST_ANCHOR_FILES": trust_anchor_files.keys() + } + if stub_addr: + ctx['ROOT_ADDR'] = stub_addr + # determine and verify socket family for specified root address + gai = socket.getaddrinfo(stub_addr, 53, sockfamily, 0, + socket.IPPROTO_UDP, socket.AI_NUMERICHOST) + assert len(gai) == 1 + sockfamily = gai[0][0] + if not sockfamily: + sockfamily = socket.AF_INET # default to IPv4 + ctx['_SOCKET_FAMILY'] = sockfamily + if override_timestamp: + ctx['_OVERRIDE_TIMESTAMP'] = override_timestamp + return (ctx, trust_anchor_files) + + +def parse_file(path): + """ Parse scenario from a file. """ + + aug = pydnstest.augwrap.AugeasWrapper( + confpath=path, lens='Deckard', loadpath=os.path.dirname(__file__)) + node = aug.tree + config = [] + for line in [c.value for c in node.match("/config/*")]: + if line: + if not line.startswith(';'): + if '#' in line: + line = line[0:line.index('#')] + # Break to key-value pairs + # e.g.: ['minimization', 'on'] + kv = [x.strip() for x in line.split(':', 1)] + if len(kv) >= 2: + config.append(kv) + scenario = Scenario(node["/scenario"], posixpath.basename(node.path)) + return scenario, config diff --git a/tests/deckard/pydnstest/tests/__init__.py b/tests/deckard/pydnstest/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/deckard/pydnstest/tests/__init__.py diff --git a/tests/deckard/pydnstest/tests/test_parse_config.py b/tests/deckard/pydnstest/tests/test_parse_config.py new file mode 100644 index 0000000..0668760 --- /dev/null +++ b/tests/deckard/pydnstest/tests/test_parse_config.py @@ -0,0 +1,17 @@ +""" This is unittest file for parse methods in scenario.py """ +import os + +from pydnstest.scenario import parse_config + + +def test_parse_config__trust_anchor(): + """Checks if trust-anchors are separated into files according to domain.""" + anchor1 = u'domain1.com.\t3600\tIN\tDS\t11901 7 1 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa' + anchor2 = u'domain2.net.\t3600\tIN\tDS\t59835 7 1 cccccccccccccccccccccccccccccccccccccccc' + anchor3 = u'domain1.com.\t3600\tIN\tDS\t11902 7 1 1111111111111111111111111111111111111111' + anchors = [[u'trust-anchor', u'"{}"'.format(anchor1)], + [u'trust-anchor', u'"{}"'.format(anchor2)], + [u'trust-anchor', u'"{}"'.format(anchor3)]] + args = (anchors, True, os.getcwd()) + _, ta_files = parse_config(*args) + assert sorted(ta_files.values()) == sorted([[anchor1, anchor3], [anchor2]]) diff --git a/tests/deckard/pydnstest/tests/test_scenario.py b/tests/deckard/pydnstest/tests/test_scenario.py new file mode 100644 index 0000000..454cb5c --- /dev/null +++ b/tests/deckard/pydnstest/tests/test_scenario.py @@ -0,0 +1,55 @@ +""" This is unittest file for scenario.py """ + +import pytest + +from pydnstest.scenario import Entry + +RCODE_FLAGS = ['NOERROR', 'FORMERR', 'SERVFAIL', 'NXDOMAIN', 'NOTIMP', 'REFUSED', 'YXDOMAIN', + 'YXRRSET', 'NXRRSET', 'NOTAUTH', 'NOTZONE', 'BADVERS'] +OPCODE_FLAGS = ['QUERY', 'IQUERY', 'STATUS', 'NOTIFY', 'UPDATE'] +FLAGS = ['QR', 'TC', 'AA', 'AD', 'RD', 'RA', 'CD'] + + +def test_entry__get_flags(): + """Checks if all rcodes and opcodes are filtered out""" + expected_flags = Entry.get_flags(FLAGS) + for flag in RCODE_FLAGS + OPCODE_FLAGS: + rcode_flags = Entry.get_flags(FLAGS + [flag]) + assert rcode_flags == expected_flags, \ + 'Entry._get_flags does not filter out "{flag}"'.format(flag=flag) + + +def test_entry__get_rcode(): + """ + Checks if the error is raised for multiple rcodes + checks if None is returned for no rcode + checks if flags and opcode are filtered out + """ + with pytest.raises(ValueError): + Entry.get_rcode(RCODE_FLAGS[:2]) + + assert Entry.get_rcode(FLAGS) is None + assert Entry.get_rcode([]) is None + + for rcode in RCODE_FLAGS: + given_rcode = Entry.get_rcode(FLAGS + OPCODE_FLAGS + [rcode]) + assert given_rcode is not None, 'Entry.get_rcode does not recognize {rcode}'.format( + rcode=rcode) + + +def test_entry__get_opcode(): + """ + Checks if the error is raised for multiple opcodes + checks if None is returned for no opcode + checks if flags and opcode are filtered out + """ + with pytest.raises(ValueError): + Entry.get_opcode(OPCODE_FLAGS[:2]) + + assert Entry.get_opcode(FLAGS) is None + assert Entry.get_opcode([]) is None + + for opcode in OPCODE_FLAGS: + given_rcode = Entry.get_opcode(FLAGS + RCODE_FLAGS + [opcode]) + assert given_rcode is not None, 'Entry.get_opcode does not recognize {opcode}'.format( + opcode=opcode) diff --git a/tests/deckard/pydnstest/testserver.py b/tests/deckard/pydnstest/testserver.py new file mode 100644 index 0000000..8767644 --- /dev/null +++ b/tests/deckard/pydnstest/testserver.py @@ -0,0 +1,278 @@ +import argparse +import itertools +import logging +import os +import signal +import selectors +import socket +import sys +import threading +import time + +import dns.message +import dns.rdatatype + +from pydnstest import scenario + + +class TestServer: + """ This simulates UDP DNS server returning scripted or mirror DNS responses. """ + + def __init__(self, test_scenario, root_addr, addr_family): + """ Initialize server instance. """ + self.thread = None + self.srv_socks = [] + self.client_socks = [] + self.connections = [] + self.active = False + self.active_lock = threading.Lock() + self.condition = threading.Condition() + self.scenario = test_scenario + self.addr_map = [] + self.start_iface = 2 + self.cur_iface = self.start_iface + self.kroot_local = root_addr + self.addr_family = addr_family + self.undefined_answers = 0 + + def __del__(self): + """ Cleanup after deletion. """ + with self.active_lock: + active = self.active + if active: + self.stop() + + def start(self, port=53): + """ Synchronous start """ + with self.active_lock: + if self.active: + raise Exception('TestServer already started') + with self.active_lock: + self.active = True + addr, _ = self.start_srv((self.kroot_local, port), self.addr_family) + self.start_srv(addr, self.addr_family, socket.IPPROTO_TCP) + self._bind_sockets() + + def stop(self): + """ Stop socket server operation. """ + with self.active_lock: + self.active = False + if self.thread: + self.thread.join() + for conn in self.connections: + conn.close() + for srv_sock in self.srv_socks: + srv_sock.close() + for client_sock in self.client_socks: + client_sock.close() + self.client_socks = [] + self.srv_socks = [] + self.connections = [] + self.scenario = None + + def address(self): + """ Returns opened sockets list """ + addrlist = [] + for s in self.srv_socks: + addrlist.append(s.getsockname()) + return addrlist + + def handle_query(self, client): + """ + Receive query from client socket and send an answer. + + Returns: + True if client socket should be closed by caller + False if client socket should be kept open + """ + log = logging.getLogger('pydnstest.testserver.handle_query') + server_addr = client.getsockname()[0] + query, client_addr = scenario.recvfrom_msg(client) + if query is None: + return False + log.debug('server %s received query from %s: %s', server_addr, client_addr, query) + + message = self.scenario.reply(query, server_addr) + if not message: + log.debug('ignoring') + return True + elif isinstance(message, scenario.DNSReplyServfail): + self.undefined_answers += 1 + self.scenario.current_step.log.error( + 'server %s has no response for question %s, answering with SERVFAIL', + server_addr, + '; '.join([str(rr) for rr in query.question])) + else: + log.debug('response: %s', message) + + scenario.sendto_msg(client, message.to_wire(), client_addr) + return True + + def query_io(self): + """ Main server process """ + self.undefined_answers = 0 + with self.active_lock: + if not self.active: + raise Exception("[query_io] Test server not active") + while True: + with self.condition: + self.condition.notify() + with self.active_lock: + if not self.active: + break + objects = self.srv_socks + self.connections + sel = selectors.DefaultSelector() + for obj in objects: + sel.register(obj, selectors.EVENT_READ) + items = sel.select(0.1) + for key, event in items: + sock = key.fileobj + if event & selectors.EVENT_READ: + if sock in self.srv_socks: + if sock.proto == socket.IPPROTO_TCP: + conn, _ = sock.accept() + self.connections.append(conn) + else: + self.handle_query(sock) + elif sock in self.connections: + if not self.handle_query(sock): + sock.close() + self.connections.remove(sock) + else: + raise Exception( + "[query_io] Socket IO internal error {}, exit" + .format(sock.getsockname())) + else: + raise Exception("[query_io] Socket IO error {}, exit" + .format(sock.getsockname())) + + def start_srv(self, address, family, proto=socket.IPPROTO_UDP): + """ Starts listening thread if necessary """ + assert address + assert address[0] # host + assert address[1] # port + assert family + assert proto + if family == socket.AF_INET6: + if not socket.has_ipv6: + raise NotImplementedError("[start_srv] IPv6 is not supported by socket {0}" + .format(socket)) + elif family != socket.AF_INET: + raise NotImplementedError("[start_srv] unsupported protocol family {0}".format(family)) + + if proto == socket.IPPROTO_TCP: + socktype = socket.SOCK_STREAM + elif proto == socket.IPPROTO_UDP: + socktype = socket.SOCK_DGRAM + else: + raise NotImplementedError("[start_srv] unsupported protocol {0}".format(proto)) + + if self.thread is None: + self.thread = threading.Thread(target=self.query_io) + self.thread.start() + with self.condition: + self.condition.wait() + + for srv_sock in self.srv_socks: + if (srv_sock.family == family + and srv_sock.getsockname() == address + and srv_sock.proto == proto): + return srv_sock.getsockname() + + sock = socket.socket(family, socktype, proto) + sock.bind(address) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if proto == socket.IPPROTO_TCP: + sock.listen(5) + self.srv_socks.append(sock) + sockname = sock.getsockname() + return sockname, proto + + def _bind_sockets(self): + """ + Bind test server to port 53 on all addresses referenced by test scenario. + """ + # Bind to test servers + for r in self.scenario.ranges: + for addr in r.addresses: + family = socket.AF_INET6 if ':' in addr else socket.AF_INET + self.start_srv((addr, 53), family) + + # Bind addresses in ad-hoc REPLYs + for s in self.scenario.steps: + if s.type == 'REPLY': + reply = s.data[0].message + for rr in itertools.chain(reply.answer, + reply.additional, + reply.question, + reply.authority): + for rd in rr: + if rd.rdtype == dns.rdatatype.A: + self.start_srv((rd.address, 53), socket.AF_INET) + elif rd.rdtype == dns.rdatatype.AAAA: + self.start_srv((rd.address, 53), socket.AF_INET6) + + def play(self, subject_addr): + self.scenario.play({'': (subject_addr, 53)}) + + +def empty_test_case(): + """ + Return (scenario, config) pair which answers to any query on 127.0.0.10. + """ + # Mirror server + empty_test_path = os.path.dirname(os.path.realpath(__file__)) + "/empty.rpl" + test_config = {'ROOT_ADDR': '127.0.0.10', + '_SOCKET_FAMILY': socket.AF_INET} + return scenario.parse_file(empty_test_path)[0], test_config + + +def standalone_self_test(): + """ + Self-test code + + Usage: + LD_PRELOAD=libsocket_wrapper.so SOCKET_WRAPPER_DIR=/tmp $PYTHON -m pydnstest.testserver --help + """ + logging.basicConfig(level=logging.DEBUG) + argparser = argparse.ArgumentParser() + argparser.add_argument('--scenario', help='absolute path to test scenario', + required=False) + argparser.add_argument('--step', help='step # in the scenario (default: first)', + required=False, type=int) + args = argparser.parse_args() + if args.scenario: + test_scenario, test_config_text = scenario.parse_file(args.scenario) + test_config, _ = scenario.parse_config(test_config_text, True, os.getcwd()) + else: + test_scenario, test_config = empty_test_case() + + if args.step: + for step in test_scenario.steps: + if step.id == args.step: + test_scenario.current_step = step + if not test_scenario.current_step: + raise ValueError('step ID %s not found in scenario' % args.step) + else: + test_scenario.current_step = test_scenario.steps[0] + + server = TestServer(test_scenario, test_config['ROOT_ADDR'], test_config['_SOCKET_FAMILY']) + server.start() + + logging.info("[==========] Mirror server running at %s", server.address()) + + def kill(signum, frame): # pylint: disable=unused-argument + logging.info("[==========] Shutdown.") + server.stop() + sys.exit(128 + signum) + + signal.signal(signal.SIGINT, kill) + signal.signal(signal.SIGTERM, kill) + + while True: + time.sleep(0.5) + + +if __name__ == '__main__': + # this is done to avoid creating global variables + standalone_self_test() |