author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-05-06 00:55:53 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-05-06 00:55:53 +0000
commit     3d0386f27ca66379acf50199e1d1298386eeeeb8 (patch)
tree       f87bd4a126b3a843858eb447e8fd5893c3ee3882 /tests/deckard/pydnstest
parent     Initial commit. (diff)
Adding upstream version 3.2.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat
-rw-r--r--  tests/deckard/pydnstest/__init__.py                    0
-rw-r--r--  tests/deckard/pydnstest/augwrap.py                   227
-rw-r--r--  tests/deckard/pydnstest/deckard.aug                   94
-rw-r--r--  tests/deckard/pydnstest/empty.rpl                     20
-rw-r--r--  tests/deckard/pydnstest/matchpart.py                 238
-rw-r--r--  tests/deckard/pydnstest/scenario.py                 1058
-rw-r--r--  tests/deckard/pydnstest/tests/__init__.py              0
-rw-r--r--  tests/deckard/pydnstest/tests/test_parse_config.py    17
-rw-r--r--  tests/deckard/pydnstest/tests/test_scenario.py        55
-rw-r--r--  tests/deckard/pydnstest/testserver.py                278
10 files changed, 1987 insertions, 0 deletions
diff --git a/tests/deckard/pydnstest/__init__.py b/tests/deckard/pydnstest/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/deckard/pydnstest/__init__.py
diff --git a/tests/deckard/pydnstest/augwrap.py b/tests/deckard/pydnstest/augwrap.py
new file mode 100644
index 0000000..20e7857
--- /dev/null
+++ b/tests/deckard/pydnstest/augwrap.py
@@ -0,0 +1,227 @@
+#!/usr/bin/python3
+
+# Copyright (C) 2017
+
+import posixpath
+import logging
+import os
+import collections.abc
+
+from augeas import Augeas
+
+AUGEAS_LOAD_PATH = '/augeas/load/'
+AUGEAS_FILES_PATH = '/files/'
+AUGEAS_ERROR_PATH = '//error'
+
+log = logging.getLogger('augeas')
+
+
+def join(*paths):
+ """
+    Join Augeas tree paths.
+
+ FIXME: Beware: // is normalized to /
+ """
+ norm_paths = [posixpath.normpath(path) for path in paths]
+ # first path must be absolute
+ assert norm_paths[0][0] == '/'
+ new_paths = [norm_paths[0]]
+ # relativize all other paths so join works as expected
+ for path in norm_paths[1:]:
+ if path.startswith('/'):
+ path = path[1:]
+ new_paths.append(path)
+ new_path = posixpath.join(*new_paths)
+ log.debug("join: new_path %s", new_path)
+ return posixpath.normpath(new_path)
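+
+# Illustrative examples (values chosen for this note, not taken from the original docs):
+#     join('/augeas/load', 'Deckard', 'lens')  ->  '/augeas/load/Deckard/lens'
+#     join('/augeas', 'load//lens')            ->  '/augeas/load/lens'   (the // is collapsed)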
+
+
+class AugeasWrapper:
+ """python-augeas higher-level wrapper.
+
+    Loads a single Augeas lens and configuration file.
+    Exposes the configuration file as an AugeasNode object with a dict-like interface.
+
+    AugeasWrapper can be used as a context manager (with statement), like a file object.
+ """
+
+ def __init__(self, confpath, lens, root=None, loadpath=None,
+ flags=Augeas.NO_MODL_AUTOLOAD | Augeas.NO_LOAD | Augeas.ENABLE_SPAN):
+ """Parse configuration file using given lens.
+
+ Params:
+ confpath (str): Absolute path to the configuration file
+ lens (str): Name of module containing Augeas lens
+            root: passed down to original Augeas
+            loadpath: passed down to original Augeas
+            flags: passed down to original Augeas
+ """
+ log.debug('loadpath: %s', loadpath)
+ log.debug('confpath: %s', confpath)
+ self._aug = Augeas(root=root, loadpath=loadpath, flags=flags)
+
+ # /augeas/load/{lens}
+ aug_load_path = join(AUGEAS_LOAD_PATH, lens)
+ # /augeas/load/{lens}/lens = {lens}.lns
+ self._aug.set(join(aug_load_path, 'lens'), '%s.lns' % lens)
+ # /augeas/load/{lens}/incl[0] = {confpath}
+ self._aug.set(join(aug_load_path, 'incl[0]'), confpath)
+ self._aug.load()
+
+ errors = self._aug.match(AUGEAS_ERROR_PATH)
+ if errors:
+ err_msg = '\n'.join(
+ ["{}: {}".format(e, self._aug.get(e)) for e in errors]
+ )
+ raise RuntimeError(err_msg)
+
+ path = join(AUGEAS_FILES_PATH, confpath)
+ paths = self._aug.match(path)
+ if len(paths) != 1:
+ raise ValueError('path %s did not match exactly once' % path)
+ self.tree = AugeasNode(self._aug, path)
+ self._loaded = True
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.save()
+ self.close()
+
+ def save(self):
+ """Save Augeas tree to its original file."""
+ assert self._loaded
+ try:
+ self._aug.save()
+ except IOError as exc:
+ log.exception(exc)
+ for err_path in self._aug.match('//error'):
+ log.error('%s: %s', err_path,
+ self._aug.get(os.path.join(err_path, 'message')))
+ raise
+
+ def close(self):
+ """
+        Close the Augeas library.
+
+ After calling close() the object must not be used anymore.
+ """
+ assert self._loaded
+ self._aug.close()
+ del self._aug
+ self._loaded = False
+
+ def match(self, path):
+ """Yield AugeasNodes matching given expression."""
+ assert self._loaded
+ assert path
+ log.debug('tree match %s', path)
+ for matched_path in self._aug.match(path):
+ yield AugeasNode(self._aug, matched_path)
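+
+# Minimal usage sketch (the file name is hypothetical; the lens and loadpath mirror what
+# pydnstest.scenario.parse_file() does): open a scenario file with the Deckard lens and
+# read the parsed tree through the dict-like AugeasNode interface.
+#
+#     aug = AugeasWrapper(confpath='/tmp/example.rpl', lens='Deckard',
+#                         loadpath=os.path.dirname(__file__))
+#     scenario_node = aug.tree["/scenario"]
+#     steps = list(scenario_node.match("/step"))
+#     aug.close()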
+
+
+class AugeasNode(collections.abc.MutableMapping):
+ """One Augeas tree node with dict-like interface."""
+
+ def __init__(self, aug, path):
+ """
+ Args:
+ aug (AugeasWrapper or Augeas): Augeas library instance
+ path (str): absolute path in Augeas tree matching single node
+
+ BEWARE: There are no sanity checks of given path for performance reasons.
+ """
+ assert aug
+ assert path
+ assert path.startswith('/')
+ self._aug = aug
+ self._path = path
+ self._span = None
+
+ @property
+ def path(self):
+ """canonical path in Augeas tree, read-only"""
+ return self._path
+
+ @property
+ def value(self):
+ """
+ get value of this node in Augeas tree
+ """
+ value = self._aug.get(self._path)
+ log.debug('tree get: %s = %s', self._path, value)
+ return value
+
+ @value.setter
+ def value(self, value):
+ """
+ set value of this node in Augeas tree
+ """
+ log.debug('tree set: %s = %s', self._path, value)
+ self._aug.set(self._path, value)
+
+ @property
+ def span(self):
+ if self._span is None:
+ self._span = "char position %s" % self._aug.span(self._path)[5]
+ return self._span
+
+ @property
+ def char(self):
+ return self._aug.span(self._path)[5]
+
+ def __len__(self):
+ """
+ number of items matching this path
+
+        It is always 1 after __init__() but it may change
+        as the Augeas tree changes.
+ """
+ return len(self._aug.match(self._path))
+
+ def __getitem__(self, key):
+ if isinstance(key, int):
+ # int is a shortcut to write [int]
+ target_path = '%s[%s]' % (self._path, key)
+ else:
+ target_path = self._path + key
+ log.debug('tree getitem: target_path %s', target_path)
+ paths = self._aug.match(target_path)
+ if len(paths) != 1:
+ raise KeyError('path %s did not match exactly once' % target_path)
+ return AugeasNode(self._aug, target_path)
+
+ def __delitem__(self, key):
+ log.debug('tree delitem: %s + %s', self._path, key)
+ target_path = self._path + key
+ log.debug('tree delitem: target_path %s', target_path)
+ self._aug.remove(target_path)
+
+ def __setitem__(self, key, value):
+ assert isinstance(value, AugeasNode)
+ target_path = self.path + key
+ self._aug.copy(value.path, target_path)
+
+ def __iter__(self):
+ self_path_len = len(self._path)
+ assert self_path_len > 0
+
+ log.debug('tree iter: %s', self._path)
+ for new_path in self._aug.match(self._path):
+ if len(new_path) == self_path_len:
+ yield ''
+ else:
+ yield new_path[self_path_len - 1:]
+
+ def match(self, subpath):
+ """Yield AugeasNodes matching given sub-expression."""
+ assert subpath.startswith("/")
+ match_path = "%s%s" % (self._path, subpath)
+ log.debug('tree match %s: %s', match_path, self._path)
+ for matched_path in self._aug.match(match_path):
+ yield AugeasNode(self._aug, matched_path)
+
+ def __repr__(self):
+ return 'AugeasNode(%s)' % self._path
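+
+# Access sketch (illustrative paths): for a node wrapping an ENTRY, node["/raw"].value
+# returns the stored RAW payload, node.match("/match") yields the MATCH option nodes,
+# and len(node) reflects how many tree nodes currently match the wrapped path.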
diff --git a/tests/deckard/pydnstest/deckard.aug b/tests/deckard/pydnstest/deckard.aug
new file mode 100644
index 0000000..9e2d167
--- /dev/null
+++ b/tests/deckard/pydnstest/deckard.aug
@@ -0,0 +1,94 @@
+module Deckard =
+ autoload xfm
+
+let del_str = Util.del_str
+
+let space = del /[ \t]+/ " "
+let tab = del /[ \t]+/ "\t"
+let ws = del /[\t ]*/ ""
+let word = /[^\t\n\/; ]+/
+
+let comment = del /[;]/ ";" . [label "comment" . store /[^\n]+/]
+
+let eol = del /([ \t]*([;][^\n]*)?\n)+/ "\n" . Util.indent
+let comment_or_eol = ws . comment? . del_str "\n" . del /([ \t]*([;][^\n]*)?\n)*/ "" . Util.indent
+
+
+(*let comment_or_eol = [ label "#comment" . counter "comment" . (ws . [del /[;#]/ ";" . label "" . store /[^\n]*/ ]? . del_str "\n")]+ . Util.indent
+*)
+
+
+let domain_re = (/[^.\t\n\/; ]+(\.[^.\t\n\/; ]+)*\.?/ | ".") - "SECTION" (*quick n dirty, sorry to whoever will ever own SECTION TLD*)
+let class_re = /CLASS[0-9]+/ | "IN" | "CH" | "HS" | "NONE" | "ANY"
+let domain = [ label "domain" . store domain_re ]
+let ttl = [label "ttl" . store /[0-9]+/]
+let class = [label "class" . store class_re ]
+let type = [label "type" . store ((/[^0-9;\n \t][^\t\n\/; ]*/) - class_re) ]
+(* RFC 3597 section 5 rdata syntax is "\# 1 ab"*)
+let data_re = /((\\#[ \t])?[^ \t\n;][^\n;]*[^ \t\n;])|[^ \t\n;]/ (*Cannot start or end with whitespace but may contain whitespace in the middle. The disjunction is there so we also match strings of length one.*)
+let data = [label "data" . store data_re ]
+
+let ip_re = /[0-9a-f.:]+/
+let hex_re = /[0-9a-fA-F]+/
+
+
+let match_option = "opcode" | "qtype" | "qcase" | "qname" | "subdomain" | "flags" | "rcode" | "question" | "answer" | "authority" | "additional" | "all" | "edns"
+let adjust_option = "copy_id" | "copy_query" | "raw_id" | "do_not_answer"
+let reply_option = "QR" | "TC" | "AA" | "AD" | "RD" | "RA" | "CD" | "DO" | "NOERROR" | "FORMERR" | "SERVFAIL" | "NXDOMAIN" | "NOTIMP" | "REFUSED" | "YXDOMAIN" | "YXRRSET" | "NXRRSET" | "NOTAUTH" | "NOTZONE" | "BADVERS" | "BADSIG" | "BADKEY" | "BADTIME" | "BADMODE" | "BADNAME" | "BADALG" | "BADTRUNC" | "BADCOOKIE"
+let step_option = "REPLY" | "QUERY" | "CHECK_ANSWER" | "CHECK_OUT_QUERY" | /TIME_PASSES[ \t]+ELAPSE/
+
+let mandatory = [del_str "MANDATORY" . label "mandatory" . value "true" . comment_or_eol]
+let tsig = [del_str "TSIG" . label "tsig" . space . [label "keyname" . store word] . space . [label "secret" . store word] . comment_or_eol]
+
+let match = (mandatory | tsig)* . [ label "match_present" . value "true" . del_str "MATCH" ] . [space . label "match" . store match_option ]+ . comment_or_eol
+let adjust = (mandatory | tsig)* . del_str "ADJUST" . [space . label "adjust" . store adjust_option ]+ . comment_or_eol
+let reply = (mandatory | tsig)* . del ("REPLY" | "FLAGS") "REPLY" . [space . label "reply" . store reply_option ]+ . comment_or_eol
+
+
+let question = [label "record" . domain . tab . (class . tab)? . type . comment_or_eol ]
+let record = [label "record" . domain . tab . (ttl . tab)? . (class . tab)? . type . tab . data . comment_or_eol]
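+
+(* Illustrative lines accepted by the lenses above (example values, not from the original file):
+   a "question" line:  "example.com. IN A"
+   a "record" line:    "example.com. 3600 IN A 192.0.2.1"
+   Fields are tab- or space-separated; TTL and class are optional, and data may contain
+   inner whitespace as allowed by data_re. *)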
+
+let section_question = [ label "question" . del_str "SECTION QUESTION" .
+ comment_or_eol . question? ]
+let section_answer = [ label "answer" . del_str "SECTION ANSWER" .
+ comment_or_eol . record* ]
+let section_authority = [ label "authority" . del_str "SECTION AUTHORITY" .
+ comment_or_eol . record* ]
+let section_additional = [ label "additional" . del_str "SECTION ADDITIONAL" .
+ comment_or_eol . record* ]
+let sections = [label "section" . section_question? . section_answer? . section_authority? . section_additional?]
+let raw = [del_str "RAW" . comment_or_eol . label "raw" . store hex_re ] . comment_or_eol
+
+(* This is quite a dirty hack to match every combination of options given to an entry, since 'let dnsmsg = ((match | adjust | reply | mandatory | tsig)* . sections)' simply is not possible *)
+
+let dnsmsg = (match . (adjust . reply? | reply . adjust?)? | adjust . (match . reply? | reply . match?)? | reply . (match . adjust? | adjust . match?)?)? . (mandatory | tsig)* . sections
+
+let entry = [label "entry" . del_str "ENTRY_BEGIN" . comment_or_eol . dnsmsg . raw? . del_str "ENTRY_END" . eol]
+
+let single_address = [ label "address" . space . store ip_re ]
+
+let addresses = [label "address" . counter "address" . [seq "address" . del_str "ADDRESS" . space . store ip_re . comment_or_eol]+]
+
+let range = [label "range" . del_str "RANGE_BEGIN" . space . [ label "from" . store /[0-9]+/] . space .
+ [ label "to" . store /[0-9]+/] . single_address? . comment_or_eol . addresses? . entry* . del_str "RANGE_END" . eol]
+
+let step = [label "step" . del_str "STEP" . space . store /[0-9]+/ . space . [label "type" . store step_option] . [space . label "timestamp" . store /[0-9]+/]? . comment_or_eol .
+ entry? ]
+
+let config_record = /[^\n]*/ - ("CONFIG_END" | /STEP.*/ | /SCENARIO.*/ | /RANGE.*/ | /ENTRY.*/)
+
+let config = [ label "config" . counter "config" . [seq "config" . store config_record . del_str "\n"]* . del_str "CONFIG_END" . comment_or_eol ]
+
+let guts = (step | range )*
+
+let scenario = [label "scenario" . del_str "SCENARIO_BEGIN" . space . store data_re . comment_or_eol . guts . del_str "SCENARIO_END" . eol]
+
+let lns = config? . scenario
+
+(* TODO: REPLAY step *)
+(* TODO: store all comments into the tree instead of ignoring them *)
+
+(*let filter = incl "/home/test/*.rpl"*)
+let filter = incl "/home/sbalazik/nic/deckard/git/sets/resolver/*.rpl"
+
+let xfm = transform lns filter
diff --git a/tests/deckard/pydnstest/empty.rpl b/tests/deckard/pydnstest/empty.rpl
new file mode 100644
index 0000000..295d5a5
--- /dev/null
+++ b/tests/deckard/pydnstest/empty.rpl
@@ -0,0 +1,20 @@
+stub-addr: 127.0.0.10
+CONFIG_END
+
+SCENARIO_BEGIN empty replies
+
+RANGE_BEGIN 0 100
+ ADDRESS 127.0.0.10
+ENTRY_BEGIN
+MATCH subdomain
+ADJUST copy_id copy_query
+SECTION QUESTION
+. IN A
+ENTRY_END
+RANGE_END
+
+STEP 1 QUERY
+ENTRY_BEGIN
+ENTRY_END
+
+SCENARIO_END
diff --git a/tests/deckard/pydnstest/matchpart.py b/tests/deckard/pydnstest/matchpart.py
new file mode 100644
index 0000000..294e64c
--- /dev/null
+++ b/tests/deckard/pydnstest/matchpart.py
@@ -0,0 +1,238 @@
+"""matchpart is used to compare two DNS messages using a single criterion"""
+
+from typing import ( # noqa
+ Any, Hashable, Sequence, Tuple, Union)
+
+import dns.edns
+import dns.rcode
+import dns.set
+
+MismatchValue = Union[str, Sequence[Any]]
+
+
+class DataMismatch(Exception):
+ def __init__(self, exp_val, got_val):
+ super().__init__()
+ self.exp_val = exp_val
+ self.got_val = got_val
+
+ @staticmethod
+ def format_value(value: MismatchValue) -> str:
+ if isinstance(value, list):
+ return ' '.join([str(val) for val in value])
+ else:
+ return str(value)
+
+ def __str__(self) -> str:
+ return 'expected "{}" got "{}"'.format(
+ self.format_value(self.exp_val),
+ self.format_value(self.got_val))
+
+ def __eq__(self, other):
+ return (isinstance(other, DataMismatch)
+ and self.exp_val == other.exp_val
+ and self.got_val == other.got_val)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ @property
+ def key(self) -> Tuple[Hashable, Hashable]:
+ def make_hashable(value):
+ if isinstance(value, (list, dns.set.Set)):
+ value = (make_hashable(item) for item in value)
+ value = tuple(value)
+ return value
+
+ return (make_hashable(self.exp_val), make_hashable(self.got_val))
+
+ def __hash__(self) -> int:
+ return hash(self.key)
+
+
+def compare_val(exp, got):
+    """Compare arbitrary objects; raise DataMismatch if they differ."""
+ if exp != got:
+ raise DataMismatch(exp, got)
+ return True
+
+
+def compare_rrs(expected, got):
+ """ Compare lists of RR sets, throw exception if different. """
+ for rr in expected:
+ if rr not in got:
+ raise DataMismatch(expected, got)
+ for rr in got:
+ if rr not in expected:
+ raise DataMismatch(expected, got)
+ if len(expected) != len(got):
+ raise DataMismatch(expected, got)
+ return True
+
+
+def compare_rrs_types(exp_val, got_val, skip_rrsigs):
+ """sets of RR types in both sections must match"""
+ def rr_ordering_key(rrset):
+ if rrset.covers:
+ return rrset.covers, 1 # RRSIGs go to the end of RRtype list
+ else:
+ return rrset.rdtype, 0
+
+ def key_to_text(rrtype, rrsig):
+ if not rrsig:
+ return dns.rdatatype.to_text(rrtype)
+ else:
+ return 'RRSIG(%s)' % dns.rdatatype.to_text(rrtype)
+
+ if skip_rrsigs:
+ exp_val = (rrset for rrset in exp_val
+ if rrset.rdtype != dns.rdatatype.RRSIG)
+ got_val = (rrset for rrset in got_val
+ if rrset.rdtype != dns.rdatatype.RRSIG)
+
+ exp_types = frozenset(rr_ordering_key(rrset) for rrset in exp_val)
+ got_types = frozenset(rr_ordering_key(rrset) for rrset in got_val)
+ if exp_types != got_types:
+ exp_types = tuple(key_to_text(*i) for i in sorted(exp_types))
+ got_types = tuple(key_to_text(*i) for i in sorted(got_types))
+ raise DataMismatch(exp_types, got_types)
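+
+# Illustrative behaviour (example RR types, not from the original docstrings): if the expected
+# answer holds A and RRSIG(A) rrsets while the received one holds only A, then
+# compare_rrs_types(exp, got, skip_rrsigs=True) passes, whereas skip_rrsigs=False raises
+# DataMismatch(('A', 'RRSIG(A)'), ('A',)).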
+
+
+def check_question(question):
+ if len(question) > 2:
+ raise NotImplementedError("More than one record in QUESTION SECTION.")
+
+
+def match_opcode(exp, got):
+ return compare_val(exp.opcode(),
+ got.opcode())
+
+
+def match_qtype(exp, got):
+ check_question(exp.question)
+ check_question(got.question)
+ if not exp.question and not got.question:
+ return True
+ if not exp.question:
+ raise DataMismatch("<empty question>", got.question[0].rdtype)
+ if not got.question:
+ raise DataMismatch(exp.question[0].rdtype, "<empty question>")
+ return compare_val(exp.question[0].rdtype,
+ got.question[0].rdtype)
+
+
+def match_qname(exp, got):
+ check_question(exp.question)
+ check_question(got.question)
+ if not exp.question and not got.question:
+ return True
+ if not exp.question:
+ raise DataMismatch("<empty question>", got.question[0].name)
+ if not got.question:
+ raise DataMismatch(exp.question[0].name, "<empty question>")
+ return compare_val(exp.question[0].name,
+ got.question[0].name)
+
+
+def match_qcase(exp, got):
+ check_question(exp.question)
+ check_question(got.question)
+ if not exp.question and not got.question:
+ return True
+ if not exp.question:
+ raise DataMismatch("<empty question>", got.question[0].name.labels)
+ if not got.question:
+ raise DataMismatch(exp.question[0].name.labels, "<empty question>")
+ return compare_val(exp.question[0].name.labels,
+ got.question[0].name.labels)
+
+
+def match_subdomain(exp, got):
+ if not exp.question:
+ return True
+ if got.question:
+ qname = got.question[0].name
+ else:
+ qname = dns.name.root
+ if exp.question[0].name.is_superdomain(qname):
+ return True
+ raise DataMismatch(exp, got)
+
+
+def match_flags(exp, got):
+ return compare_val(dns.flags.to_text(exp.flags),
+ dns.flags.to_text(got.flags))
+
+
+def match_rcode(exp, got):
+ return compare_val(dns.rcode.to_text(exp.rcode()),
+ dns.rcode.to_text(got.rcode()))
+
+
+def match_answer(exp, got):
+ return compare_rrs(exp.answer,
+ got.answer)
+
+
+def match_answertypes(exp, got):
+ return compare_rrs_types(exp.answer,
+ got.answer, skip_rrsigs=True)
+
+
+def match_answerrrsigs(exp, got):
+ return compare_rrs_types(exp.answer,
+ got.answer, skip_rrsigs=False)
+
+
+def match_authority(exp, got):
+ return compare_rrs(exp.authority,
+ got.authority)
+
+
+def match_additional(exp, got):
+ return compare_rrs(exp.additional,
+ got.additional)
+
+
+def match_edns(exp, got):
+ if got.edns != exp.edns:
+ raise DataMismatch(exp.edns,
+ got.edns)
+ if got.payload != exp.payload:
+ raise DataMismatch(exp.payload,
+ got.payload)
+
+
+def match_nsid(exp, got):
+ nsid_opt = None
+ for opt in exp.options:
+ if opt.otype == dns.edns.NSID:
+ nsid_opt = opt
+ break
+ # Find matching NSID
+ for opt in got.options:
+ if opt.otype == dns.edns.NSID:
+ if not nsid_opt:
+ raise DataMismatch(None, opt.data)
+ if opt == nsid_opt:
+ return True
+ else:
+ raise DataMismatch(nsid_opt.data, opt.data)
+ if nsid_opt:
+ raise DataMismatch(nsid_opt.data, None)
+ return True
+
+
+MATCH = {"opcode": match_opcode, "qtype": match_qtype, "qname": match_qname, "qcase": match_qcase,
+ "subdomain": match_subdomain, "flags": match_flags, "rcode": match_rcode,
+ "answer": match_answer, "answertypes": match_answertypes,
+ "answerrrsigs": match_answerrrsigs, "authority": match_authority,
+ "additional": match_additional, "edns": match_edns,
+ "nsid": match_nsid}
+
+
+def match_part(exp, got, code):
+ try:
+ return MATCH[code](exp, got)
+ except KeyError:
+ raise NotImplementedError('unknown match request "%s"' % code)
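+
+# Usage sketch (variable names are illustrative): given two dns.message.Message objects,
+#     match_part(expected, received, "qname")
+# returns True when the question names agree and raises DataMismatch otherwise;
+# an unknown criterion raises NotImplementedError.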
diff --git a/tests/deckard/pydnstest/scenario.py b/tests/deckard/pydnstest/scenario.py
new file mode 100644
index 0000000..5e0661b
--- /dev/null
+++ b/tests/deckard/pydnstest/scenario.py
@@ -0,0 +1,1058 @@
+# FIXME pylint: disable=too-many-lines
+from abc import ABC
+import binascii
+import calendar
+from datetime import datetime
+import errno
+import logging
+import os
+import posixpath
+import random
+import socket
+import string
+import struct
+import time
+from typing import Optional
+
+import dns.dnssec
+import dns.message
+import dns.name
+import dns.rcode
+import dns.rrset
+import dns.tsigkeyring
+
+import pydnstest.augwrap
+import pydnstest.matchpart
+
+
+def str2bool(v):
+ """ Return conversion of JSON-ish string value to boolean. """
+ return v.lower() in ('yes', 'true', 'on', '1')
+
+
+# Global statistics
+g_rtt = 0.0
+g_nqueries = 0
+
+
+def recvfrom_msg(stream, raw=False):
+ """
+ Receive DNS message from TCP/UDP socket.
+
+ Returns:
+ if raw == False: (DNS message object, peer address)
+ if raw == True: (blob, peer address)
+ """
+ if stream.type & socket.SOCK_DGRAM:
+ data, addr = stream.recvfrom(4096)
+ elif stream.type & socket.SOCK_STREAM:
+ data = stream.recv(2)
+ if not data:
+ return None, None
+ msg_len = struct.unpack_from("!H", data)[0]
+ data = b""
+ received = 0
+ while received < msg_len:
+ next_chunk = stream.recv(4096)
+ if not next_chunk:
+ return None, None
+ data += next_chunk
+ received += len(next_chunk)
+ addr = stream.getpeername()[0]
+ else:
+ raise NotImplementedError("[recvfrom_msg]: unknown socket type '%i'" % stream.type)
+ if raw:
+ return data, addr
+ else:
+ msg = dns.message.from_wire(data, one_rr_per_rrset=True)
+ return msg, addr
+
+
+def sendto_msg(stream, message, addr=None):
+    """ Send a DNS message over a UDP or TCP socket. """
+ try:
+ if stream.type & socket.SOCK_DGRAM:
+ if addr is None:
+ stream.send(message)
+ else:
+ stream.sendto(message, addr)
+ elif stream.type & socket.SOCK_STREAM:
+ data = struct.pack("!H", len(message)) + message
+ stream.send(data)
+ else:
+ raise NotImplementedError("[sendto_msg]: unknown socket type '%i'" % stream.type)
+ except socket.error as ex:
+ if ex.errno != errno.ECONNREFUSED: # TODO Investigate how this can happen
+ raise
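+
+# Framing note with an illustrative value: over TCP the payload is length-prefixed, so a
+# 30-byte DNS message goes out as struct.pack("!H", 30) + message, matching the 2-byte
+# length that recvfrom_msg() above reads back.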
+
+
+def replay_rrs(rrs, nqueries, destination, args=None):
+ """ Replay list of queries and report statistics. """
+ if args is None:
+ args = []
+ navail, queries = len(rrs), []
+ chunksize = 16
+ for i in range(nqueries if 'RAND' in args else navail):
+ rr = rrs[i % navail]
+ name = rr.name
+ if 'RAND' in args:
+ prefix = ''.join([random.choice(string.ascii_letters + string.digits)
+ for _ in range(8)])
+ name = prefix + '.' + rr.name.to_text()
+ msg = dns.message.make_query(name, rr.rdtype, rr.rdclass)
+ if 'DO' in args:
+ msg.want_dnssec(True)
+ queries.append(msg.to_wire())
+ # Make a UDP connected socket to the destination
+ family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET
+ sock = socket.socket(family, socket.SOCK_DGRAM)
+ sock.connect(destination)
+ sock.setblocking(False)
+ # Play the query set
+    # @NOTE: this is only good for relatively low-speed replay
+    rcvbuf = bytearray(512)  # 512 zero bytes; bytearray(str) would need an encoding in Python 3
+ nsent, nrcvd, nwait, navail = 0, 0, 0, len(queries)
+ fdset = [sock]
+ import select
+ while nsent - nwait < nqueries:
+ to_read, to_write, _ = select.select(fdset, fdset if nwait < chunksize else [], [], 0.5)
+ if to_write:
+ try:
+ while nsent < nqueries and nwait < chunksize:
+ sock.send(queries[nsent % navail])
+ nwait += 1
+ nsent += 1
+ except socket.error:
+ pass # EINVAL
+ if to_read:
+ try:
+ while nwait > 0:
+ sock.recv_into(rcvbuf)
+ nwait -= 1
+ nrcvd += 1
+ except socket.error:
+ pass
+ if not to_write and not to_read:
+ nwait = 0 # Timeout, started dropping packets
+ break
+ return nsent, nrcvd
+
+
+class DNSBlob(ABC):
+ def to_wire(self) -> bytes:
+ raise NotImplementedError
+
+ def __str__(self) -> str:
+ return '<DNSBlob>'
+
+
+class DNSMessage(DNSBlob):
+ def __init__(self, message: dns.message.Message) -> None:
+ assert message is not None
+ self.message = message
+
+ def to_wire(self) -> bytes:
+ return self.message.to_wire(max_size=65535)
+
+ def __str__(self) -> str:
+ return str(self.message)
+
+
+class DNSReply(DNSMessage):
+ def __init__(
+ self,
+ message: dns.message.Message,
+ query: Optional[dns.message.Message] = None,
+ copy_id: bool = False,
+ copy_query: bool = False
+ ) -> None:
+ super().__init__(message)
+ if copy_id or copy_query:
+ if query is None:
+ raise ValueError("query must be provided to adjust copy_id/copy_query")
+ self.adjust_reply(query, copy_id, copy_query)
+
+ def adjust_reply(
+ self,
+ query: dns.message.Message,
+ copy_id: bool = True,
+ copy_query: bool = True
+ ) -> None:
+ answer = dns.message.from_wire(self.message.to_wire(),
+ xfr=self.message.xfr,
+ one_rr_per_rrset=True)
+ answer.use_edns(query.edns, query.ednsflags, options=self.message.options)
+ if copy_id:
+ answer.id = query.id
+ # Copy letter-case if the template has QD
+ if answer.question:
+ answer.question[0].name = query.question[0].name
+ if copy_query:
+ answer.question = query.question
+ # Re-set, as the EDNS might have reset the ext-rcode
+ answer.set_rcode(self.message.rcode())
+
+ # sanity check: adjusted answer should be almost the same
+ assert len(answer.answer) == len(self.message.answer)
+ assert len(answer.authority) == len(self.message.authority)
+ assert len(answer.additional) == len(self.message.additional)
+ self.message = answer
+
+
+class DNSReplyRaw(DNSBlob):
+ def __init__(
+ self,
+ wire: bytes,
+ query: Optional[dns.message.Message] = None,
+ copy_id: bool = False
+ ) -> None:
+ assert wire is not None
+ self.wire = wire
+ if copy_id:
+ self.adjust_reply(query, copy_id)
+
+ def adjust_reply(
+ self,
+ query: dns.message.Message,
+ copy_id: bool = True
+ ) -> None:
+ if copy_id:
+ if len(self.wire) < 2:
+ raise ValueError(
+ 'wire data must contain at least 2 bytes to adjust query id')
+ raw_answer = bytearray(self.wire)
+ struct.pack_into('!H', raw_answer, 0, query.id)
+ self.wire = bytes(raw_answer)
+
+ def to_wire(self) -> bytes:
+ return self.wire
+
+ def __str__(self) -> str:
+ return '<DNSReplyRaw>'
+
+
+class DNSReplyServfail(DNSMessage):
+ def __init__(self, query: dns.message.Message) -> None:
+ message = dns.message.make_response(query)
+ message.set_rcode(dns.rcode.SERVFAIL)
+ super().__init__(message)
+
+
+class Entry:
+ """
+ Data entry represents scripted message and extra metadata,
+ notably match criteria and reply adjustments.
+ """
+
+ # Globals
+ default_ttl = 3600
+ default_cls = 'IN'
+ default_rc = 'NOERROR'
+
+ def __init__(self, node):
+ """ Initialize data entry. """
+ self.node = node
+ self.origin = '.'
+ self.message = dns.message.Message()
+ self.message.use_edns(edns=0, payload=4096)
+ self.fired = 0
+
+ # RAW
+ self.raw_data = None # type: Optional[bytes]
+ self.is_raw_data_entry = self.process_raw()
+
+ # MATCH
+ self.match_fields = self.process_match()
+
+ # FLAGS
+ self.process_reply_line()
+
+ # ADJUST
+ self.adjust_fields = {m.value for m in node.match("/adjust")}
+
+ # MANDATORY
+ try:
+ self.mandatory = list(node.match("/mandatory"))[0]
+ except (KeyError, IndexError):
+ self.mandatory = None
+
+ # TSIG
+ self.process_tsig()
+
+ # SECTIONS & RECORDS
+ self.sections = self.process_sections()
+
+ def process_raw(self):
+ try:
+ self.raw_data = binascii.unhexlify(self.node["/raw"].value)
+ return True
+ except KeyError:
+ return False
+
+ def process_match(self):
+ try:
+ self.node["/match_present"]
+ except KeyError:
+ return None
+
+ fields = set(m.value for m in self.node.match("/match"))
+
+ if 'all' in fields:
+ fields.remove("all")
+ fields |= set(["opcode", "qtype", "qname", "flags",
+ "rcode", "answer", "authority", "additional"])
+
+ if 'question' in fields:
+ fields.remove("question")
+ fields |= set(["qtype", "qname"])
+
+ return fields
+
+ def process_reply_line(self):
+        """Extract flags, rcode and opcode from the given node and adjust the DNS message accordingly."""
+ self.fields = [f.value for f in self.node.match("/reply")]
+ if 'DO' in self.fields:
+ self.message.want_dnssec(True)
+ opcode = self.get_opcode(fields=self.fields)
+ rcode = self.get_rcode(fields=self.fields)
+ self.message.flags = self.get_flags(fields=self.fields)
+ if rcode is not None:
+ self.message.set_rcode(rcode)
+ if opcode is not None:
+ self.message.set_opcode(opcode)
+
+ def process_tsig(self):
+ try:
+ tsig = list(self.node.match("/tsig"))[0]
+ tsig_keyname = tsig["/keyname"].value
+ tsig_secret = tsig["/secret"].value
+ keyring = dns.tsigkeyring.from_text({tsig_keyname: tsig_secret})
+ self.message.use_tsig(keyring=keyring, keyname=tsig_keyname)
+ except (KeyError, IndexError):
+ pass
+
+ def process_sections(self):
+ sections = set()
+ for section in self.node.match("/section/*"):
+ section_name = posixpath.basename(section.path)
+ sections.add(section_name)
+ for record in section.match("/record"):
+ owner = record['/domain'].value
+ if not owner.endswith("."):
+ owner += self.origin
+ try:
+ ttl = dns.ttl.from_text(record['/ttl'].value)
+ except KeyError:
+ ttl = self.default_ttl
+ try:
+ rdclass = dns.rdataclass.from_text(record['/class'].value)
+ except KeyError:
+ rdclass = dns.rdataclass.from_text(self.default_cls)
+ rdtype = dns.rdatatype.from_text(record['/type'].value)
+ rr = dns.rrset.from_text(owner, ttl, rdclass, rdtype)
+ if section_name != "question":
+ rd = record['/data'].value.split()
+ if rd:
+ if rdtype == dns.rdatatype.DS:
+ rd[1] = str(dns.dnssec.algorithm_from_text(rd[1]))
+ rd = dns.rdata.from_text(rr.rdclass, rr.rdtype, ' '.join(
+ rd), origin=dns.name.from_text(self.origin), relativize=False)
+ rr.add(rd)
+ if section_name == 'question':
+ if rr.rdtype == dns.rdatatype.AXFR:
+ self.message.xfr = True
+ self.message.question.append(rr)
+ elif section_name == 'answer':
+ self.message.answer.append(rr)
+ elif section_name == 'authority':
+ self.message.authority.append(rr)
+ elif section_name == 'additional':
+ self.message.additional.append(rr)
+ return sections
+
+ def __str__(self):
+ txt = 'ENTRY_BEGIN\n'
+ if not self.is_raw_data_entry:
+ txt += 'MATCH {0}\n'.format(' '.join(self.match_fields))
+ txt += 'ADJUST {0}\n'.format(' '.join(self.adjust_fields))
+ txt += 'REPLY {rcode} {flags}\n'.format(
+ rcode=dns.rcode.to_text(self.message.rcode()),
+ flags=' '.join([dns.flags.to_text(self.message.flags),
+ dns.flags.edns_to_text(self.message.ednsflags)])
+ )
+ for sect_name in ['question', 'answer', 'authority', 'additional']:
+ sect = getattr(self.message, sect_name)
+ if not sect:
+ continue
+ txt += 'SECTION {n}\n'.format(n=sect_name.upper())
+ for rr in sect:
+ txt += str(rr)
+ txt += '\n'
+ if self.is_raw_data_entry:
+ txt += 'RAW\n'
+ if self.raw_data:
+ txt += binascii.hexlify(self.raw_data)
+ else:
+ txt += 'NULL'
+ txt += '\n'
+ txt += 'ENTRY_END\n'
+ return txt
+
+ @classmethod
+ def get_flags(cls, fields):
+ """From `fields` extracts and returns flags"""
+ flags = []
+ for code in fields:
+ try:
+ dns.flags.from_text(code) # throws KeyError on failure
+ flags.append(code)
+ except KeyError:
+ pass
+ return dns.flags.from_text(' '.join(flags))
+
+ @classmethod
+ def get_rcode(cls, fields):
+ """
+        Extract and return the rcode from `fields`.
+        Raises `ValueError` if there is more than one rcode.
+ """
+ rcodes = []
+ for code in fields:
+ try:
+ rcodes.append(dns.rcode.from_text(code))
+ except dns.rcode.UnknownRcode:
+ pass
+ if len(rcodes) > 1:
+ raise ValueError("Parse failed, too many rcode values.", rcodes)
+ if not rcodes:
+ return None
+ return rcodes[0]
+
+ @classmethod
+ def get_opcode(cls, fields):
+ """
+        Extract and return the opcode from `fields`.
+        Raises `ValueError` if there is more than one opcode.
+ """
+ opcodes = []
+ for code in fields:
+ try:
+ opcodes.append(dns.opcode.from_text(code))
+ except dns.opcode.UnknownOpcode:
+ pass
+ if len(opcodes) > 1:
+ raise ValueError("Parse failed, too many opcode values.")
+ if not opcodes:
+ return None
+ return opcodes[0]
+
+ def match(self, msg):
+ """ Compare scripted reply to given message based on match criteria. """
+ for code in self.match_fields:
+ try:
+ pydnstest.matchpart.match_part(self.message, msg, code)
+ except pydnstest.matchpart.DataMismatch as ex:
+ errstr = '%s in the response:\n%s' % (str(ex), msg.to_text())
+                # TODO: line numbers
+ raise ValueError("%s, \"%s\": %s" % (self.node.span, code, errstr))
+
+ def cmp_raw(self, raw_value):
+ assert self.is_raw_data_entry
+ expected = None
+ if self.raw_data is not None:
+ expected = binascii.hexlify(self.raw_data)
+ got = None
+ if raw_value is not None:
+ got = binascii.hexlify(raw_value)
+ if expected != got:
+            raise ValueError("raw message comparison failed: expected %s got %s" % (expected, got))
+
+ def reply(self, query) -> Optional[DNSBlob]:
+ if 'do_not_answer' in self.adjust_fields:
+ return None
+ if self.is_raw_data_entry:
+ copy_id = 'raw_data' in self.adjust_fields
+ assert self.raw_data is not None
+ return DNSReplyRaw(self.raw_data, query, copy_id)
+ copy_id = 'copy_id' in self.adjust_fields
+ copy_query = 'copy_query' in self.adjust_fields
+ return DNSReply(self.message, query, copy_id, copy_query)
+
+ def set_edns(self, fields):
+ """ Set EDNS version and bufsize. """
+ version = 0
+ bufsize = 4096
+ if fields and fields[0].isdigit():
+ version = int(fields.pop(0))
+ if fields and fields[0].isdigit():
+ bufsize = int(fields.pop(0))
+ if bufsize == 0:
+ self.message.use_edns(False)
+ return
+ opts = []
+ for v in fields:
+ k, v = tuple(v.split('=')) if '=' in v else (v, True)
+ if k.lower() == 'nsid':
+ opts.append(dns.edns.GenericOption(dns.edns.NSID, '' if v is True else v))
+ if k.lower() == 'subnet':
+ net = v.split('/')
+ subnet_addr = net[0]
+ family = socket.AF_INET6 if ':' in subnet_addr else socket.AF_INET
+ addr = socket.inet_pton(family, subnet_addr)
+ prefix = len(addr) * 8
+ if len(net) > 1:
+ prefix = int(net[1])
+                addr = addr[0:(prefix + 7) // 8]  # integer division; '/' would yield a float in Python 3
+                if prefix % 8 != 0:  # Mask the last byte (keep only the top prefix % 8 bits)
+                    addr = addr[:-1] + bytes([addr[-1] & ((0xFF << (8 - prefix % 8)) & 0xFF)])
+ opts.append(dns.edns.GenericOption(8, struct.pack(
+ "!HBB", 1 if family == socket.AF_INET else 2, prefix, 0) + addr))
+ self.message.use_edns(edns=version, payload=bufsize, options=opts)
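+
+    # Illustrative field list (made up for this note):
+    #     set_edns(['0', '4096', 'nsid', 'subnet=192.0.2.0/24'])
+    # selects EDNS version 0, a 4096-byte buffer, an empty NSID option and a
+    # client-subnet option covering 192.0.2.0/24.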
+
+
+class Range:
+ """
+ Range represents a set of scripted queries valid for given step range.
+ """
+ log = logging.getLogger('pydnstest.scenario.Range')
+
+ def __init__(self, node):
+ """ Initialize reply range. """
+ self.node = node
+ self.a = int(node['/from'].value)
+ self.b = int(node['/to'].value)
+ assert self.a <= self.b
+
+ address = node["/address"].value
+ self.addresses = {address} if address is not None else set()
+ self.addresses |= {a.value for a in node.match("/address/*")}
+ self.stored = [Entry(n) for n in node.match("/entry")]
+ self.args = {}
+ self.received = 0
+ self.sent = 0
+
+ def __del__(self):
+ self.log.info('[ RANGE %d-%d ] %s received: %d sent: %d',
+ self.a, self.b, self.addresses, self.received, self.sent)
+
+ def __str__(self):
+ txt = '\nRANGE_BEGIN {a} {b}\n'.format(a=self.a, b=self.b)
+ for addr in self.addresses:
+ txt += ' ADDRESS {0}\n'.format(addr)
+
+ for entry in self.stored:
+ txt += '\n'
+ txt += str(entry)
+ txt += 'RANGE_END\n\n'
+ return txt
+
+ def eligible(self, ident, address):
+ """ Return true if this range is eligible for fetching reply. """
+ if self.a <= ident <= self.b:
+            return (address is None
+                    or not self.addresses
+                    or address in self.addresses)
+ return False
+
+ def reply(self, query: dns.message.Message) -> Optional[DNSBlob]:
+ """Get answer for given query (adjusted if needed)."""
+ self.received += 1
+ for candidate in self.stored:
+ try:
+ candidate.match(query)
+ resp = candidate.reply(query)
+ # Probabilistic loss
+ if 'LOSS' in self.args:
+ if random.random() < float(self.args['LOSS']):
+ return DNSReplyServfail(query)
+ self.sent += 1
+ candidate.fired += 1
+ return resp
+ except ValueError:
+ pass
+ return DNSReplyServfail(query)
+
+
+class StepLogger(logging.LoggerAdapter): # pylint: disable=too-few-public-methods
+ """
+    Prepend step identification to each log message.
+ """
+ def process(self, msg, kwargs):
+ return '[STEP %s %s] %s' % (self.extra['id'], self.extra['type'], msg), kwargs
+
+
+class Step:
+ """
+    Step represents one scripted action at a given moment;
+    each step has an order identifier, a type, and optionally a data entry.
+ """
+ require_data = ['QUERY', 'CHECK_ANSWER', 'REPLY']
+
+ def __init__(self, node):
+ """ Initialize single scenario step. """
+ self.node = node
+ self.id = int(node.value)
+ self.type = node["/type"].value
+ self.log = StepLogger(logging.getLogger('pydnstest.scenario.Step'),
+ {'id': self.id, 'type': self.type})
+ try:
+ self.delay = int(node["/timestamp"].value)
+ except KeyError:
+ pass
+ self.data = [Entry(n) for n in node.match("/entry")]
+ self.queries = []
+ self.has_data = self.type in Step.require_data
+ self.answer = None
+ self.raw_answer = None
+ self.repeat_if_fail = 0
+ self.pause_if_fail = 0
+ self.next_if_fail = -1
+
+ # TODO Parser currently can't parse CHECK_ANSWER args, player doesn't understand them anyway
+ # if type == 'CHECK_ANSWER':
+ # for arg in extra_args:
+ # param = arg.split('=')
+ # try:
+ # if param[0] == 'REPEAT':
+ # self.repeat_if_fail = int(param[1])
+ # elif param[0] == 'PAUSE':
+ # self.pause_if_fail = float(param[1])
+ # elif param[0] == 'NEXT':
+ # self.next_if_fail = int(param[1])
+ # except Exception as e:
+ # raise Exception('step %d - wrong %s arg: %s' % (self.id, param[0], str(e)))
+
+ def __str__(self):
+ txt = '\nSTEP {i} {t}'.format(i=self.id, t=self.type)
+ if self.repeat_if_fail:
+ txt += ' REPEAT {v}'.format(v=self.repeat_if_fail)
+ elif self.pause_if_fail:
+ txt += ' PAUSE {v}'.format(v=self.pause_if_fail)
+ elif self.next_if_fail != -1:
+ txt += ' NEXT {v}'.format(v=self.next_if_fail)
+ # if self.args:
+ # txt += ' '
+ # txt += ' '.join(self.args)
+ txt += '\n'
+
+ for data in self.data:
+ # from IPython.core.debugger import Tracer
+ # Tracer()()
+ txt += str(data)
+ return txt
+
+ def play(self, ctx):
+ """ Play one step from a scenario. """
+ if self.type == 'QUERY':
+ self.log.info('')
+ self.log.debug(self.data[0].message.to_text())
+ # Parse QUERY-specific parameters
+ choice, tcp, source = None, False, None
+ return self.__query(ctx, tcp=tcp, choice=choice, source=source)
+ elif self.type == 'CHECK_OUT_QUERY': # ignore
+ self.log.info('')
+ return None
+ elif self.type == 'CHECK_ANSWER' or self.type == 'ANSWER':
+ self.log.info('')
+ return self.__check_answer(ctx)
+ elif self.type == 'TIME_PASSES ELAPSE':
+ self.log.info('')
+ return self.__time_passes()
+ elif self.type == 'REPLY' or self.type == 'MOCK':
+ self.log.info('')
+ return None
+ # Parser currently doesn't support step types LOG, REPLAY and ASSERT.
+ # No test uses them.
+ # elif self.type == 'LOG':
+ # if not ctx.log:
+ # raise Exception('scenario has no log interface')
+ # return ctx.log.match(self.args)
+ # elif self.type == 'REPLAY':
+ # self.__replay(ctx)
+ # elif self.type == 'ASSERT':
+ # self.__assert(ctx)
+ else:
+ raise NotImplementedError('step %03d type %s unsupported' % (self.id, self.type))
+
+ def __check_answer(self, ctx):
+ """ Compare answer from previously resolved query. """
+ if not self.data:
+ raise ValueError("response definition required")
+ expected = self.data[0]
+ if expected.is_raw_data_entry is True:
+ self.log.debug("raw answer: %s", ctx.last_raw_answer.to_text())
+ expected.cmp_raw(ctx.last_raw_answer)
+ else:
+ if ctx.last_answer is None:
+ raise ValueError("no answer from preceding query")
+ self.log.debug("answer: %s", ctx.last_answer.to_text())
+ expected.match(ctx.last_answer)
+
+ # def __replay(self, ctx, chunksize=8):
+ # nqueries = len(self.queries)
+ # if len(self.args) > 0 and self.args[0].isdigit():
+ # nqueries = int(self.args.pop(0))
+ # destination = ctx.client[ctx.client.keys()[0]]
+ # self.log.info('replaying %d queries to %s@%d (%s)',
+ # nqueries, destination[0], destination[1], ' '.join(self.args))
+ # if 'INTENSIFY' in os.environ:
+ # nqueries *= int(os.environ['INTENSIFY'])
+ # tstart = datetime.now()
+ # nsent, nrcvd = replay_rrs(self.queries, nqueries, destination, self.args)
+ # # Keep/print the statistics
+ # rtt = (datetime.now() - tstart).total_seconds() * 1000
+ # pps = 1000 * nrcvd / rtt
+ # self.log.debug('sent: %d, received: %d (%d ms, %d p/s)', nsent, nrcvd, rtt, pps)
+ # tag = None
+ # for arg in self.args:
+ # if arg.upper().startswith('PRINT'):
+ # _, tag = tuple(arg.split('=')) if '=' in arg else (None, 'replay')
+ # if tag:
+ # self.log.info('[ REPLAY ] test: %s pps: %5d time: %4d sent: %5d received: %5d',
+ # tag.ljust(11), pps, rtt, nsent, nrcvd)
+
+ def __query(self, ctx, tcp=False, choice=None, source=None):
+ """
+ Send query and wait for an answer (if the query is not RAW).
+
+ The received answer is stored in self.answer and ctx.last_answer.
+ """
+ if not self.data:
+ raise ValueError("query definition required")
+ if self.data[0].is_raw_data_entry is True:
+ data_to_wire = self.data[0].raw_data
+ else:
+ # Don't use a message copy as the EDNS data portion is not copied.
+ data_to_wire = self.data[0].message.to_wire()
+ if choice is None or not choice:
+ choice = list(ctx.client.keys())[0]
+ if choice not in ctx.client:
+ raise ValueError('step %03d invalid QUERY target: %s' % (self.id, choice))
+ # Create socket to test subject
+ sock = None
+ destination = ctx.client[choice]
+ family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET
+ sock = socket.socket(family, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ if tcp:
+ sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
+ sock.settimeout(3)
+ if source:
+ sock.bind((source, 0))
+ sock.connect(destination)
+ # Send query to client and wait for response
+ tstart = datetime.now()
+ while True:
+ try:
+ sendto_msg(sock, data_to_wire)
+ break
+ except OSError as ex:
+ # ENOBUFS, throttle sending
+ if ex.errno == errno.ENOBUFS:
+ time.sleep(0.1)
+ # Wait for a response for a reasonable time
+ answer = None
+ if not self.data[0].is_raw_data_entry:
+ while True:
+ if (datetime.now() - tstart).total_seconds() > 5:
+ raise RuntimeError("Server took too long to respond")
+ try:
+ answer, _ = recvfrom_msg(sock, True)
+ break
+ except OSError as ex:
+ if ex.errno == errno.ENOBUFS:
+ time.sleep(0.1)
+ # Track RTT
+ rtt = (datetime.now() - tstart).total_seconds() * 1000
+ global g_rtt, g_nqueries
+ g_nqueries += 1
+ g_rtt += rtt
+ # Remember last answer for checking later
+ self.raw_answer = answer
+ ctx.last_raw_answer = answer
+ if self.raw_answer is not None:
+ self.answer = dns.message.from_wire(self.raw_answer, one_rr_per_rrset=True)
+ else:
+ self.answer = None
+ ctx.last_answer = self.answer
+
+ def __time_passes(self):
+ """ Modify system time. """
+ file_old = os.environ["FAKETIME_TIMESTAMP_FILE"]
+ file_next = os.environ["FAKETIME_TIMESTAMP_FILE"] + ".next"
+ with open(file_old, 'r') as time_file:
+ line = time_file.readline().strip()
+ t = time.mktime(datetime.strptime(line, '@%Y-%m-%d %H:%M:%S').timetuple())
+ t += self.delay
+ with open(file_next, 'w') as time_file:
+ time_file.write(datetime.fromtimestamp(t).strftime('@%Y-%m-%d %H:%M:%S') + "\n")
+ time_file.flush()
+ os.replace(file_next, file_old)
+
+ # def __assert(self, ctx):
+ # """ Assert that a passed expression evaluates to True. """
+ # result = eval(' '.join(self.args), {'SCENARIO': ctx, 'RANGE': ctx.ranges})
+ # # Evaluate subexpressions for clarity
+ # subexpr = []
+ # for expr in self.args:
+ # try:
+ # ee = eval(expr, {'SCENARIO': ctx, 'RANGE': ctx.ranges})
+ # subexpr.append(str(ee))
+ # except:
+ # subexpr.append(expr)
+ # assert result is True, '"%s" assertion fails (%s)' % (
+ # ' '.join(self.args), ' '.join(subexpr))
+
+
+class Scenario:
+    log = logging.getLogger('pydnstest.scenario.Scenario')
+
+ def __init__(self, node, filename):
+ """ Initialize scenario with description. """
+ self.node = node
+ self.info = node.value
+ self.file = filename
+ self.ranges = [Range(n) for n in node.match("/range")]
+ self.current_range = None
+ self.steps = [Step(n) for n in node.match("/step")]
+ self.current_step = None
+ self.client = {}
+
+ def __str__(self):
+ txt = 'SCENARIO_BEGIN'
+ if self.info:
+ txt += ' {0}'.format(self.info)
+ txt += '\n'
+ for range_ in self.ranges:
+ txt += str(range_)
+ for step in self.steps:
+ txt += str(step)
+ txt += "\nSCENARIO_END"
+ return txt
+
+ def reply(self, query: dns.message.Message, address=None) -> Optional[DNSBlob]:
+ """Generate answer packet for given query."""
+ current_step_id = self.current_step.id
+ # Unknown address, select any match
+ # TODO: workaround until the server supports stub zones
+ all_addresses = set() # type: ignore
+ for rng in self.ranges:
+ all_addresses.update(rng.addresses)
+ if address not in all_addresses:
+ address = None
+ # Find current valid query response range
+ for rng in self.ranges:
+ if rng.eligible(current_step_id, address):
+ self.current_range = rng
+ return rng.reply(query)
+ # Find any prescripted one-shot replies
+ for step in self.steps:
+ if step.id < current_step_id or step.type != 'REPLY':
+ continue
+ try:
+ candidate = step.data[0]
+ candidate.match(query)
+ step.data.remove(candidate)
+ return candidate.reply(query)
+ except (IndexError, ValueError):
+ pass
+ return DNSReplyServfail(query)
+
+ def play(self, paddr):
+ """ Play given scenario. """
+ # Store test subject => address mapping
+ self.client = paddr
+
+ step = None
+ i = 0
+ while i < len(self.steps):
+ step = self.steps[i]
+ self.current_step = step
+ try:
+ step.play(self)
+ except ValueError as ex:
+ if step.repeat_if_fail > 0:
+ self.log.info("[play] step %d: exception - '%s', retrying step %d (%d left)",
+ step.id, ex, step.next_if_fail, step.repeat_if_fail)
+ step.repeat_if_fail -= 1
+ if step.pause_if_fail > 0:
+ time.sleep(step.pause_if_fail)
+ if step.next_if_fail != -1:
+ next_steps = [j for j in range(len(self.steps)) if self.steps[
+ j].id == step.next_if_fail]
+ if not next_steps:
+ raise ValueError('step %d: wrong NEXT value "%d"' %
+ (step.id, step.next_if_fail))
+ next_step = next_steps[0]
+ if next_step < len(self.steps):
+ i = next_step
+ else:
+                        raise ValueError("step %d: can't branch to NEXT value \"%d\"" %
+                                         (step.id, step.next_if_fail))
+ continue
+ else:
+ raise ValueError('%s step %d %s' % (self.file, step.id, str(ex)))
+ i += 1
+
+ for r in self.ranges:
+ for e in r.stored:
+ if e.mandatory and e.fired == 0:
+                    # TODO: line numbers
+ raise ValueError('Mandatory section at %s not fired' % e.mandatory.span)
+
+
+def get_next(file_in, skip_empty=True):
+ """ Return next token from the input stream. """
+ while True:
+ line = file_in.readline()
+ if not line:
+ return False
+ quoted, escaped = False, False
+ for i, char in enumerate(line):
+ if char == '\\':
+ escaped = not escaped
+ if not escaped and char == '"':
+ quoted = not quoted
+ if char == ';' and not quoted:
+ line = line[0:i]
+ break
+ if char != '\\':
+ escaped = False
+ tokens = ' '.join(line.strip().split()).split()
+ if not tokens:
+ if skip_empty:
+ continue
+ else:
+ return '', []
+ op = tokens.pop(0)
+ return op, tokens
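+
+# Illustrative input (made up): for a line 'STEP 1 QUERY   ; first query', get_next()
+# strips the trailing comment and collapses whitespace, returning ('STEP', ['1', 'QUERY']).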
+
+
+def parse_config(scn_cfg, qmin, installdir): # FIXME: pylint: disable=too-many-statements
+ """
+    Transform scenario config (key, value) pairs into a dict filled with defaults.
+ Returns tuple:
+ context dict: {Jinja2 variable: value}
+ trust anchor dict: {domain: [TA lines for particular domain]}
+ """
+ # defaults
+ do_not_query_localhost = True
+ harden_glue = True
+ sockfamily = 0 # auto-select value for socket.getaddrinfo
+ trust_anchor_list = []
+ trust_anchor_files = {}
+ negative_ta_list = []
+ stub_addr = None
+ override_timestamp = None
+
+ features = {}
+ feature_list_delimiter = ';'
+ feature_pair_delimiter = '='
+
+ for k, v in scn_cfg:
+ # Enable selectively for some tests
+ if k == 'do-not-query-localhost':
+ do_not_query_localhost = str2bool(v)
+ elif k == 'domain-insecure':
+ negative_ta_list.append(v)
+ elif k == 'harden-glue':
+ harden_glue = str2bool(v)
+ elif k == 'query-minimization':
+ qmin = str2bool(v)
+ elif k == 'trust-anchor':
+ trust_anchor = v.strip('"\'')
+ trust_anchor_list.append(trust_anchor)
+ domain = dns.name.from_text(trust_anchor.split()[0]).canonicalize()
+ if domain not in trust_anchor_files:
+ trust_anchor_files[domain] = []
+ trust_anchor_files[domain].append(trust_anchor)
+ elif k == 'val-override-timestamp':
+ override_timestamp_str = v.strip('"\'')
+ override_timestamp = int(override_timestamp_str)
+ elif k == 'val-override-date':
+ override_date_str = v.strip('"\'')
+ ovr_yr = override_date_str[0:4]
+ ovr_mnt = override_date_str[4:6]
+ ovr_day = override_date_str[6:8]
+ ovr_hr = override_date_str[8:10]
+ ovr_min = override_date_str[10:12]
+ ovr_sec = override_date_str[12:]
+ override_date_str_arg = '{0} {1} {2} {3} {4} {5}'.format(
+ ovr_yr, ovr_mnt, ovr_day, ovr_hr, ovr_min, ovr_sec)
+ override_date = time.strptime(override_date_str_arg, "%Y %m %d %H %M %S")
+ override_timestamp = calendar.timegm(override_date)
+ elif k == 'stub-addr':
+ stub_addr = v.strip('"\'')
+ elif k == 'features':
+ feature_list = v.split(feature_list_delimiter)
+ try:
+ for f_item in feature_list:
+ if f_item.find(feature_pair_delimiter) != -1:
+ f_key, f_value = [x.strip()
+ for x
+ in f_item.split(feature_pair_delimiter, 1)]
+ else:
+ f_key = f_item.strip()
+ f_value = ""
+ features[f_key] = f_value
+ except KeyError as ex:
+ raise KeyError("can't parse features (%s) in config section (%s)" % (v, str(ex)))
+ elif k == 'feature-list':
+ try:
+ f_key, f_value = [x.strip() for x in v.split(feature_pair_delimiter, 1)]
+ if f_key not in features:
+ features[f_key] = []
+ f_value = f_value.replace("{{INSTALL_DIR}}", installdir)
+ features[f_key].append(f_value)
+ except KeyError as ex:
+ raise KeyError("can't parse feature-list (%s) in config section (%s)"
+ % (v, str(ex)))
+ elif k == 'force-ipv6' and v.upper() == 'TRUE':
+ sockfamily = socket.AF_INET6
+ else:
+ raise NotImplementedError('unsupported CONFIG key "%s"' % k)
+
+ ctx = {
+ "DO_NOT_QUERY_LOCALHOST": str(do_not_query_localhost).lower(),
+ "NEGATIVE_TRUST_ANCHORS": negative_ta_list,
+ "FEATURES": features,
+ "HARDEN_GLUE": str(harden_glue).lower(),
+ "INSTALL_DIR": installdir,
+ "QMIN": str(qmin).lower(),
+ "TRUST_ANCHORS": trust_anchor_list,
+ "TRUST_ANCHOR_FILES": trust_anchor_files.keys()
+ }
+ if stub_addr:
+ ctx['ROOT_ADDR'] = stub_addr
+ # determine and verify socket family for specified root address
+ gai = socket.getaddrinfo(stub_addr, 53, sockfamily, 0,
+ socket.IPPROTO_UDP, socket.AI_NUMERICHOST)
+ assert len(gai) == 1
+ sockfamily = gai[0][0]
+ if not sockfamily:
+ sockfamily = socket.AF_INET # default to IPv4
+ ctx['_SOCKET_FAMILY'] = sockfamily
+ if override_timestamp:
+ ctx['_OVERRIDE_TIMESTAMP'] = override_timestamp
+ return (ctx, trust_anchor_files)
+
+
+def parse_file(path):
+ """ Parse scenario from a file. """
+
+ aug = pydnstest.augwrap.AugeasWrapper(
+ confpath=path, lens='Deckard', loadpath=os.path.dirname(__file__))
+ node = aug.tree
+ config = []
+ for line in [c.value for c in node.match("/config/*")]:
+ if line:
+ if not line.startswith(';'):
+ if '#' in line:
+ line = line[0:line.index('#')]
+ # Break to key-value pairs
+ # e.g.: ['minimization', 'on']
+ kv = [x.strip() for x in line.split(':', 1)]
+ if len(kv) >= 2:
+ config.append(kv)
+ scenario = Scenario(node["/scenario"], posixpath.basename(node.path))
+ return scenario, config
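+
+# Usage sketch (the file name refers to empty.rpl shipped alongside this module):
+#     scenario, config = parse_file(os.path.join(os.path.dirname(__file__), 'empty.rpl'))
+# yields a Scenario object plus config pairs such as [['stub-addr', '127.0.0.10']],
+# which parse_config() then turns into the Jinja2 context dict.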
diff --git a/tests/deckard/pydnstest/tests/__init__.py b/tests/deckard/pydnstest/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/deckard/pydnstest/tests/__init__.py
diff --git a/tests/deckard/pydnstest/tests/test_parse_config.py b/tests/deckard/pydnstest/tests/test_parse_config.py
new file mode 100644
index 0000000..0668760
--- /dev/null
+++ b/tests/deckard/pydnstest/tests/test_parse_config.py
@@ -0,0 +1,17 @@
+""" This is unittest file for parse methods in scenario.py """
+import os
+
+from pydnstest.scenario import parse_config
+
+
+def test_parse_config__trust_anchor():
+ """Checks if trust-anchors are separated into files according to domain."""
+ anchor1 = u'domain1.com.\t3600\tIN\tDS\t11901 7 1 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'
+ anchor2 = u'domain2.net.\t3600\tIN\tDS\t59835 7 1 cccccccccccccccccccccccccccccccccccccccc'
+ anchor3 = u'domain1.com.\t3600\tIN\tDS\t11902 7 1 1111111111111111111111111111111111111111'
+ anchors = [[u'trust-anchor', u'"{}"'.format(anchor1)],
+ [u'trust-anchor', u'"{}"'.format(anchor2)],
+ [u'trust-anchor', u'"{}"'.format(anchor3)]]
+ args = (anchors, True, os.getcwd())
+ _, ta_files = parse_config(*args)
+ assert sorted(ta_files.values()) == sorted([[anchor1, anchor3], [anchor2]])
diff --git a/tests/deckard/pydnstest/tests/test_scenario.py b/tests/deckard/pydnstest/tests/test_scenario.py
new file mode 100644
index 0000000..454cb5c
--- /dev/null
+++ b/tests/deckard/pydnstest/tests/test_scenario.py
@@ -0,0 +1,55 @@
+""" This is unittest file for scenario.py """
+
+import pytest
+
+from pydnstest.scenario import Entry
+
+RCODE_FLAGS = ['NOERROR', 'FORMERR', 'SERVFAIL', 'NXDOMAIN', 'NOTIMP', 'REFUSED', 'YXDOMAIN',
+ 'YXRRSET', 'NXRRSET', 'NOTAUTH', 'NOTZONE', 'BADVERS']
+OPCODE_FLAGS = ['QUERY', 'IQUERY', 'STATUS', 'NOTIFY', 'UPDATE']
+FLAGS = ['QR', 'TC', 'AA', 'AD', 'RD', 'RA', 'CD']
+
+
+def test_entry__get_flags():
+ """Checks if all rcodes and opcodes are filtered out"""
+ expected_flags = Entry.get_flags(FLAGS)
+ for flag in RCODE_FLAGS + OPCODE_FLAGS:
+ rcode_flags = Entry.get_flags(FLAGS + [flag])
+ assert rcode_flags == expected_flags, \
+            'Entry.get_flags does not filter out "{flag}"'.format(flag=flag)
+
+
+def test_entry__get_rcode():
+ """
+ Checks if the error is raised for multiple rcodes
+ checks if None is returned for no rcode
+    checks if flags and opcodes are filtered out
+ """
+ with pytest.raises(ValueError):
+ Entry.get_rcode(RCODE_FLAGS[:2])
+
+ assert Entry.get_rcode(FLAGS) is None
+ assert Entry.get_rcode([]) is None
+
+ for rcode in RCODE_FLAGS:
+ given_rcode = Entry.get_rcode(FLAGS + OPCODE_FLAGS + [rcode])
+ assert given_rcode is not None, 'Entry.get_rcode does not recognize {rcode}'.format(
+ rcode=rcode)
+
+
+def test_entry__get_opcode():
+ """
+ Checks if the error is raised for multiple opcodes
+ checks if None is returned for no opcode
+    checks if flags and rcodes are filtered out
+ """
+ with pytest.raises(ValueError):
+ Entry.get_opcode(OPCODE_FLAGS[:2])
+
+ assert Entry.get_opcode(FLAGS) is None
+ assert Entry.get_opcode([]) is None
+
+ for opcode in OPCODE_FLAGS:
+ given_rcode = Entry.get_opcode(FLAGS + RCODE_FLAGS + [opcode])
+ assert given_rcode is not None, 'Entry.get_opcode does not recognize {opcode}'.format(
+ opcode=opcode)
diff --git a/tests/deckard/pydnstest/testserver.py b/tests/deckard/pydnstest/testserver.py
new file mode 100644
index 0000000..8767644
--- /dev/null
+++ b/tests/deckard/pydnstest/testserver.py
@@ -0,0 +1,278 @@
+import argparse
+import itertools
+import logging
+import os
+import signal
+import selectors
+import socket
+import sys
+import threading
+import time
+
+import dns.message
+import dns.rdatatype
+
+from pydnstest import scenario
+
+
+class TestServer:
+    """ Simulates a DNS server (UDP and TCP) returning scripted or mirrored DNS responses. """
+
+ def __init__(self, test_scenario, root_addr, addr_family):
+ """ Initialize server instance. """
+ self.thread = None
+ self.srv_socks = []
+ self.client_socks = []
+ self.connections = []
+ self.active = False
+ self.active_lock = threading.Lock()
+ self.condition = threading.Condition()
+ self.scenario = test_scenario
+ self.addr_map = []
+ self.start_iface = 2
+ self.cur_iface = self.start_iface
+ self.kroot_local = root_addr
+ self.addr_family = addr_family
+ self.undefined_answers = 0
+
+ def __del__(self):
+ """ Cleanup after deletion. """
+ with self.active_lock:
+ active = self.active
+ if active:
+ self.stop()
+
+ def start(self, port=53):
+ """ Synchronous start """
+ with self.active_lock:
+ if self.active:
+ raise Exception('TestServer already started')
+ with self.active_lock:
+ self.active = True
+ addr, _ = self.start_srv((self.kroot_local, port), self.addr_family)
+ self.start_srv(addr, self.addr_family, socket.IPPROTO_TCP)
+ self._bind_sockets()
+
+ def stop(self):
+ """ Stop socket server operation. """
+ with self.active_lock:
+ self.active = False
+ if self.thread:
+ self.thread.join()
+ for conn in self.connections:
+ conn.close()
+ for srv_sock in self.srv_socks:
+ srv_sock.close()
+ for client_sock in self.client_socks:
+ client_sock.close()
+ self.client_socks = []
+ self.srv_socks = []
+ self.connections = []
+ self.scenario = None
+
+ def address(self):
+        """ Return the list of bound server socket addresses. """
+ addrlist = []
+ for s in self.srv_socks:
+ addrlist.append(s.getsockname())
+ return addrlist
+
+ def handle_query(self, client):
+ """
+ Receive query from client socket and send an answer.
+
+ Returns:
+ True if client socket should be closed by caller
+ False if client socket should be kept open
+ """
+ log = logging.getLogger('pydnstest.testserver.handle_query')
+ server_addr = client.getsockname()[0]
+ query, client_addr = scenario.recvfrom_msg(client)
+ if query is None:
+ return False
+ log.debug('server %s received query from %s: %s', server_addr, client_addr, query)
+
+ message = self.scenario.reply(query, server_addr)
+ if not message:
+ log.debug('ignoring')
+ return True
+ elif isinstance(message, scenario.DNSReplyServfail):
+ self.undefined_answers += 1
+ self.scenario.current_step.log.error(
+ 'server %s has no response for question %s, answering with SERVFAIL',
+ server_addr,
+ '; '.join([str(rr) for rr in query.question]))
+ else:
+ log.debug('response: %s', message)
+
+ scenario.sendto_msg(client, message.to_wire(), client_addr)
+ return True
+
+ def query_io(self):
+ """ Main server process """
+ self.undefined_answers = 0
+ with self.active_lock:
+ if not self.active:
+ raise Exception("[query_io] Test server not active")
+ while True:
+ with self.condition:
+ self.condition.notify()
+ with self.active_lock:
+ if not self.active:
+ break
+ objects = self.srv_socks + self.connections
+ sel = selectors.DefaultSelector()
+ for obj in objects:
+ sel.register(obj, selectors.EVENT_READ)
+ items = sel.select(0.1)
+ for key, event in items:
+ sock = key.fileobj
+ if event & selectors.EVENT_READ:
+ if sock in self.srv_socks:
+ if sock.proto == socket.IPPROTO_TCP:
+ conn, _ = sock.accept()
+ self.connections.append(conn)
+ else:
+ self.handle_query(sock)
+ elif sock in self.connections:
+ if not self.handle_query(sock):
+ sock.close()
+ self.connections.remove(sock)
+ else:
+ raise Exception(
+ "[query_io] Socket IO internal error {}, exit"
+ .format(sock.getsockname()))
+ else:
+ raise Exception("[query_io] Socket IO error {}, exit"
+ .format(sock.getsockname()))
+
+ def start_srv(self, address, family, proto=socket.IPPROTO_UDP):
+        """ Start the listening thread if necessary and bind a server socket for the given address. """
+ assert address
+ assert address[0] # host
+ assert address[1] # port
+ assert family
+ assert proto
+ if family == socket.AF_INET6:
+ if not socket.has_ipv6:
+ raise NotImplementedError("[start_srv] IPv6 is not supported by socket {0}"
+ .format(socket))
+ elif family != socket.AF_INET:
+ raise NotImplementedError("[start_srv] unsupported protocol family {0}".format(family))
+
+ if proto == socket.IPPROTO_TCP:
+ socktype = socket.SOCK_STREAM
+ elif proto == socket.IPPROTO_UDP:
+ socktype = socket.SOCK_DGRAM
+ else:
+ raise NotImplementedError("[start_srv] unsupported protocol {0}".format(proto))
+
+ if self.thread is None:
+ self.thread = threading.Thread(target=self.query_io)
+ self.thread.start()
+ with self.condition:
+ self.condition.wait()
+
+ for srv_sock in self.srv_socks:
+ if (srv_sock.family == family
+ and srv_sock.getsockname() == address
+ and srv_sock.proto == proto):
+                return srv_sock.getsockname(), proto  # same return shape as the new-socket path below
+
+ sock = socket.socket(family, socktype, proto)
+ sock.bind(address)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ if proto == socket.IPPROTO_TCP:
+ sock.listen(5)
+ self.srv_socks.append(sock)
+ sockname = sock.getsockname()
+ return sockname, proto
+
+ def _bind_sockets(self):
+ """
+ Bind test server to port 53 on all addresses referenced by test scenario.
+ """
+ # Bind to test servers
+ for r in self.scenario.ranges:
+ for addr in r.addresses:
+ family = socket.AF_INET6 if ':' in addr else socket.AF_INET
+ self.start_srv((addr, 53), family)
+
+ # Bind addresses in ad-hoc REPLYs
+ for s in self.scenario.steps:
+ if s.type == 'REPLY':
+ reply = s.data[0].message
+ for rr in itertools.chain(reply.answer,
+ reply.additional,
+ reply.question,
+ reply.authority):
+ for rd in rr:
+ if rd.rdtype == dns.rdatatype.A:
+ self.start_srv((rd.address, 53), socket.AF_INET)
+ elif rd.rdtype == dns.rdatatype.AAAA:
+ self.start_srv((rd.address, 53), socket.AF_INET6)
+
+ def play(self, subject_addr):
+ self.scenario.play({'': (subject_addr, 53)})
+
+
+def empty_test_case():
+ """
+ Return (scenario, config) pair which answers to any query on 127.0.0.10.
+ """
+ # Mirror server
+ empty_test_path = os.path.dirname(os.path.realpath(__file__)) + "/empty.rpl"
+ test_config = {'ROOT_ADDR': '127.0.0.10',
+ '_SOCKET_FAMILY': socket.AF_INET}
+ return scenario.parse_file(empty_test_path)[0], test_config
+
+
+def standalone_self_test():
+ """
+ Self-test code
+
+ Usage:
+ LD_PRELOAD=libsocket_wrapper.so SOCKET_WRAPPER_DIR=/tmp $PYTHON -m pydnstest.testserver --help
+ """
+ logging.basicConfig(level=logging.DEBUG)
+ argparser = argparse.ArgumentParser()
+ argparser.add_argument('--scenario', help='absolute path to test scenario',
+ required=False)
+ argparser.add_argument('--step', help='step # in the scenario (default: first)',
+ required=False, type=int)
+ args = argparser.parse_args()
+ if args.scenario:
+ test_scenario, test_config_text = scenario.parse_file(args.scenario)
+ test_config, _ = scenario.parse_config(test_config_text, True, os.getcwd())
+ else:
+ test_scenario, test_config = empty_test_case()
+
+ if args.step:
+ for step in test_scenario.steps:
+ if step.id == args.step:
+ test_scenario.current_step = step
+ if not test_scenario.current_step:
+ raise ValueError('step ID %s not found in scenario' % args.step)
+ else:
+ test_scenario.current_step = test_scenario.steps[0]
+
+ server = TestServer(test_scenario, test_config['ROOT_ADDR'], test_config['_SOCKET_FAMILY'])
+ server.start()
+
+ logging.info("[==========] Mirror server running at %s", server.address())
+
+ def kill(signum, frame): # pylint: disable=unused-argument
+ logging.info("[==========] Shutdown.")
+ server.stop()
+ sys.exit(128 + signum)
+
+ signal.signal(signal.SIGINT, kill)
+ signal.signal(signal.SIGTERM, kill)
+
+ while True:
+ time.sleep(0.5)
+
+
+if __name__ == '__main__':
+ # this is done to avoid creating global variables
+ standalone_self_test()