diff options
Diffstat (limited to 'tests/integration/deckard/pydnstest/matchpart.py')
-rw-r--r-- | tests/integration/deckard/pydnstest/matchpart.py | 238 |
1 files changed, 238 insertions, 0 deletions
diff --git a/tests/integration/deckard/pydnstest/matchpart.py b/tests/integration/deckard/pydnstest/matchpart.py new file mode 100644 index 0000000..4a9d8a0 --- /dev/null +++ b/tests/integration/deckard/pydnstest/matchpart.py @@ -0,0 +1,238 @@ +"""matchpart is used to compare two DNS messages using a single criterion""" + +from typing import ( # noqa + Any, Hashable, Sequence, Tuple, Union) + +import dns.edns +import dns.rcode +import dns.set + +MismatchValue = Union[str, Sequence[Any]] + + +class DataMismatch(Exception): + def __init__(self, exp_val, got_val): + super().__init__() + self.exp_val = exp_val + self.got_val = got_val + + @staticmethod + def format_value(value: MismatchValue) -> str: + if isinstance(value, list): + return ' '.join([str(val) for val in value]) + else: + return str(value) + + def __str__(self) -> str: + return 'expected "{}" got "{}"'.format( + self.format_value(self.exp_val), + self.format_value(self.got_val)) + + def __eq__(self, other): + return (isinstance(other, DataMismatch) + and self.exp_val == other.exp_val + and self.got_val == other.got_val) + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def key(self) -> Tuple[Hashable, Hashable]: + def make_hashable(value): + if isinstance(value, (list, dns.set.Set)): + value = (make_hashable(item) for item in value) + value = tuple(value) + return value + + return (make_hashable(self.exp_val), make_hashable(self.got_val)) + + def __hash__(self) -> int: + return hash(self.key) + + +def compare_val(exp, got): + """Compare arbitraty objects, throw exception if different. """ + if exp != got: + raise DataMismatch(exp, got) + return True + + +def compare_rrs(expected, got): + """ Compare lists of RR sets, throw exception if different. """ + for rr in expected: + if rr not in got: + raise DataMismatch(expected, got) + for rr in got: + if rr not in expected: + raise DataMismatch(expected, got) + if len(expected) != len(got): + raise DataMismatch(expected, got) + return True + + +def compare_rrs_types(exp_val, got_val, skip_rrsigs): + """sets of RR types in both sections must match""" + def rr_ordering_key(rrset): + if rrset.covers: + return rrset.covers, 1 # RRSIGs go to the end of RRtype list + else: + return rrset.rdtype, 0 + + def key_to_text(rrtype, rrsig): + if not rrsig: + return dns.rdatatype.to_text(rrtype) + else: + return 'RRSIG(%s)' % dns.rdatatype.to_text(rrtype) + + if skip_rrsigs: + exp_val = (rrset for rrset in exp_val + if rrset.rdtype != dns.rdatatype.RRSIG) + got_val = (rrset for rrset in got_val + if rrset.rdtype != dns.rdatatype.RRSIG) + + exp_types = frozenset(rr_ordering_key(rrset) for rrset in exp_val) + got_types = frozenset(rr_ordering_key(rrset) for rrset in got_val) + if exp_types != got_types: + exp_types = tuple(key_to_text(*i) for i in sorted(exp_types)) + got_types = tuple(key_to_text(*i) for i in sorted(got_types)) + raise DataMismatch(exp_types, got_types) + + +def check_question(question): + if len(question) > 2: + raise NotImplementedError("More than one record in QUESTION SECTION.") + + +def match_opcode(exp, got): + return compare_val(exp.opcode(), + got.opcode()) + + +def match_qtype(exp, got): + check_question(exp.question) + check_question(got.question) + if not exp.question and not got.question: + return True + if not exp.question: + raise DataMismatch("<empty question>", got.question[0].rdtype) + if not got.question: + raise DataMismatch(exp.question[0].rdtype, "<empty question>") + return compare_val(exp.question[0].rdtype, + got.question[0].rdtype) + + +def match_qname(exp, got): + check_question(exp.question) + check_question(got.question) + if not exp.question and not got.question: + return True + if not exp.question: + raise DataMismatch("<empty question>", got.question[0].name) + if not got.question: + raise DataMismatch(exp.question[0].name, "<empty question>") + return compare_val(exp.question[0].name, + got.question[0].name) + + +def match_qcase(exp, got): + check_question(exp.question) + check_question(got.question) + if not exp.question and not got.question: + return True + if not exp.question: + raise DataMismatch("<empty question>", got.question[0].name.labels) + if not got.question: + raise DataMismatch(exp.question[0].name.labels, "<empty question>") + return compare_val(exp.question[0].name.labels, + got.question[0].name.labels) + + +def match_subdomain(exp, got): + if not exp.question: + return True + if got.question: + qname = got.question[0].name + else: + qname = dns.name.root + if exp.question[0].name.is_superdomain(qname): + return True + raise DataMismatch(exp, got) + + +def match_flags(exp, got): + return compare_val(dns.flags.to_text(exp.flags), + dns.flags.to_text(got.flags)) + + +def match_rcode(exp, got): + return compare_val(dns.rcode.to_text(exp.rcode()), + dns.rcode.to_text(got.rcode())) + + +def match_answer(exp, got): + return compare_rrs(exp.answer, + got.answer) + + +def match_answertypes(exp, got): + return compare_rrs_types(exp.answer, + got.answer, skip_rrsigs=True) + + +def match_answerrrsigs(exp, got): + return compare_rrs_types(exp.answer, + got.answer, skip_rrsigs=False) + + +def match_authority(exp, got): + return compare_rrs(exp.authority, + got.authority) + + +def match_additional(exp, got): + return compare_rrs(exp.additional, + got.additional) + + +def match_edns(exp, got): + if got.edns != exp.edns: + raise DataMismatch(exp.edns, + got.edns) + if got.payload != exp.payload: + raise DataMismatch(exp.payload, + got.payload) + + +def match_nsid(exp, got): + nsid_opt = None + for opt in exp.options: + if opt.otype == dns.edns.NSID: + nsid_opt = opt + break + # Find matching NSID + for opt in got.options: + if opt.otype == dns.edns.NSID: + if not nsid_opt: + raise DataMismatch(None, opt.data) + if opt == nsid_opt: + return True + else: + raise DataMismatch(nsid_opt.data, opt.data) + if nsid_opt: + raise DataMismatch(nsid_opt.data, None) + return True + + +MATCH = {"opcode": match_opcode, "qtype": match_qtype, "qname": match_qname, "qcase": match_qcase, + "subdomain": match_subdomain, "flags": match_flags, "rcode": match_rcode, + "answer": match_answer, "answertypes": match_answertypes, + "answerrrsigs": match_answerrrsigs, "authority": match_authority, + "additional": match_additional, "edns": match_edns, + "nsid": match_nsid} + + +def match_part(exp, got, code): + try: + return MATCH[code](exp, got) + except KeyError as ex: + raise NotImplementedError('unknown match request "%s"' % code) from ex |