239 lines
7 KiB
Python
239 lines
7 KiB
Python
"""matchpart is used to compare two DNS messages using a single criterion"""
|
|
|
|
from typing import (  # noqa
    Any, Hashable, Sequence, Tuple, Union)

import dns.edns
import dns.flags
import dns.name
import dns.rcode
import dns.rdatatype
import dns.set
|
|
|
|
MismatchValue = Union[str, Sequence[Any]]


class DataMismatch(Exception):
    """Raised when an expected part of a DNS message differs from the received one.

    ``exp_val`` and ``got_val`` hold the expected and received values. They are
    also forwarded to ``Exception.__init__`` so ``args`` and ``repr()`` carry
    them and the exception can be pickled and unpickled.
    """

    def __init__(self, exp_val, got_val):
        # Populate Exception.args: pickle rebuilds exceptions as cls(*args),
        # which would fail with an empty args tuple.
        super().__init__(exp_val, got_val)
        self.exp_val = exp_val
        self.got_val = got_val

    @staticmethod
    def format_value(value: MismatchValue) -> str:
        """Render *value* for the human-readable message.

        Lists and tuples (e.g. the type-name tuples built by compare_rrs_types)
        become their items' ``str()`` joined by single spaces; any other value
        is passed through ``str()``.
        """
        if isinstance(value, (list, tuple)):
            return ' '.join(str(val) for val in value)
        return str(value)

    def __str__(self) -> str:
        return (
            f'expected "{self.format_value(self.exp_val)}" '
            f'got "{self.format_value(self.got_val)}"'
        )

    def __eq__(self, other):
        return (isinstance(other, DataMismatch)
                and self.exp_val == other.exp_val
                and self.got_val == other.got_val)

    def __ne__(self, other):
        return not self.__eq__(other)

    @property
    def key(self) -> Tuple[Hashable, Hashable]:
        """Return a hashable ``(expected, got)`` pair.

        Lists and ``dns.set.Set`` instances are converted (recursively) to
        tuples; other values are returned unchanged and must already be
        hashable.
        """
        def make_hashable(value):
            if isinstance(value, (list, dns.set.Set)):
                value = tuple(make_hashable(item) for item in value)
            return value

        return (make_hashable(self.exp_val), make_hashable(self.got_val))

    def __hash__(self) -> int:
        return hash(self.key)
|
|
|
|
|
|
def compare_val(exp, got):
    """Compare two arbitrary objects; raise DataMismatch if they differ.

    Returns True when the objects compare equal (``exp != got`` is false).
    """
    mismatch = exp != got
    if not mismatch:
        return True
    raise DataMismatch(exp, got)
|
|
|
|
|
|
def compare_rrs(expected, got):
    """Compare two lists of RR sets as multisets; raise DataMismatch if different.

    Order is ignored but multiplicity is not: every element of *expected* must
    be matched by a distinct equal element of *got*, with nothing left over.
    Containment checks plus a length check would otherwise accept e.g.
    ``[a, a, b]`` against ``[a, b, b]``.

    Returns:
        True when both lists hold the same elements with the same counts.

    Raises:
        DataMismatch: carrying the full *expected* and *got* lists.
    """
    unmatched = list(got)
    for rr in expected:
        try:
            # list.remove() compares with == and drops exactly one match, so
            # duplicates are paired one-for-one and elements need not be
            # hashable.
            unmatched.remove(rr)
        except ValueError:
            raise DataMismatch(expected, got) from None
    if unmatched:
        raise DataMismatch(expected, got)
    return True
|
|
|
|
|
|
def compare_rrs_types(exp_val, got_val, skip_rrsigs):
    """Check that two sections contain the same set of RR types.

    Each RRset is reduced to a key ``(type, is_rrsig)``. An RRset with a truthy
    ``covers`` (an RRSIG) is keyed by the type it covers and sorts after the
    plain types. Only the sets of keys are compared, so RRset counts and
    contents do not matter. If *skip_rrsigs* is true, RRsets whose rdtype is
    RRSIG are dropped from both sections first.

    Returns:
        True when the key sets are equal, matching compare_val/compare_rrs.

    Raises:
        DataMismatch: with sorted tuples of type names such as ``'A'`` or
        ``'RRSIG(A)'`` for the expected and the received section.
    """
    def rr_ordering_key(rrset):
        if rrset.covers:
            return rrset.covers, 1  # RRSIGs go to the end of RRtype list
        return rrset.rdtype, 0

    def key_to_text(rrtype, rrsig):
        type_text = dns.rdatatype.to_text(rrtype)
        return f'RRSIG({type_text})' if rrsig else type_text

    if skip_rrsigs:
        exp_val = (rrset for rrset in exp_val
                   if rrset.rdtype != dns.rdatatype.RRSIG)
        got_val = (rrset for rrset in got_val
                   if rrset.rdtype != dns.rdatatype.RRSIG)

    exp_types = frozenset(rr_ordering_key(rrset) for rrset in exp_val)
    got_types = frozenset(rr_ordering_key(rrset) for rrset in got_val)
    if exp_types != got_types:
        raise DataMismatch(
            tuple(key_to_text(*key) for key in sorted(exp_types)),
            tuple(key_to_text(*key) for key in sorted(got_types)))
    return True
|
|
|
|
|
|
def check_question(question):
    """Reject QUESTION sections holding more than one record.

    The qtype/qname/qcase matchers in this module only inspect
    ``question[0]``, so any further records would be silently ignored.

    Raises:
        NotImplementedError: if *question* contains more than one record.
    """
    if len(question) > 1:
        raise NotImplementedError("More than one record in QUESTION SECTION.")
|
|
|
|
|
|
def match_opcode(exp, got):
    """Require both messages to carry the same opcode."""
    expected_opcode = exp.opcode()
    received_opcode = got.opcode()
    return compare_val(expected_opcode, received_opcode)
|
|
|
|
|
|
def match_qtype(exp, got):
    """Compare the QTYPE of the first question record of both messages.

    Two empty QUESTION sections match. If only one side is empty, the
    mismatch reports "<empty question>" for that side.
    """
    exp_q, got_q = exp.question, got.question
    check_question(exp_q)
    check_question(got_q)
    if exp_q and got_q:
        return compare_val(exp_q[0].rdtype, got_q[0].rdtype)
    if exp_q:
        raise DataMismatch(exp_q[0].rdtype, "<empty question>")
    if got_q:
        raise DataMismatch("<empty question>", got_q[0].rdtype)
    return True
|
|
|
|
|
|
def match_qname(exp, got):
    """Compare the QNAME of the first question record of both messages.

    Two empty QUESTION sections match. If only one side is empty, the
    mismatch reports "<empty question>" for that side.
    """
    exp_q, got_q = exp.question, got.question
    check_question(exp_q)
    check_question(got_q)
    if exp_q and got_q:
        return compare_val(exp_q[0].name, got_q[0].name)
    if exp_q:
        raise DataMismatch(exp_q[0].name, "<empty question>")
    if got_q:
        raise DataMismatch("<empty question>", got_q[0].name)
    return True
|
|
|
|
|
|
def match_qcase(exp, got):
    """Compare the raw labels of the first QNAME of both messages.

    This compares ``name.labels`` rather than the names themselves, so the
    check can catch differences that name equality would hide (presumably
    letter case, as the function name suggests). Two empty QUESTION sections
    match. If only one side is empty, the mismatch reports "<empty question>"
    for that side.
    """
    exp_q, got_q = exp.question, got.question
    check_question(exp_q)
    check_question(got_q)
    if exp_q and got_q:
        return compare_val(exp_q[0].name.labels, got_q[0].name.labels)
    if exp_q:
        raise DataMismatch(exp_q[0].name.labels, "<empty question>")
    if got_q:
        raise DataMismatch("<empty question>", got_q[0].name.labels)
    return True
|
|
|
|
|
|
def match_subdomain(exp, got):
    """Require the received QNAME to lie within the expected QNAME's domain.

    The check uses ``exp_name.is_superdomain(qname)``. An empty expected
    QUESTION section matches anything. An empty received QUESTION section is
    treated as the root name.

    Returns:
        True on success.

    Raises:
        DataMismatch: with the expected (parent) name and the received QNAME.
    """
    if not exp.question:
        return True
    exp_name = exp.question[0].name
    qname = got.question[0].name if got.question else dns.name.root
    if exp_name.is_superdomain(qname):
        return True
    # Report the two names rather than the whole messages: names print
    # compactly, and DataMismatch.key/__hash__ need hashable values.
    # NOTE(review): message objects define __eq__ and may well be unhashable.
    raise DataMismatch(exp_name, qname)
|
|
|
|
|
|
def match_flags(exp, got):
    """Compare the header flags of both messages via their textual form."""
    to_text = dns.flags.to_text
    return compare_val(to_text(exp.flags), to_text(got.flags))
|
|
|
|
|
|
def match_rcode(exp, got):
    """Compare the response codes of both messages via their textual form."""
    to_text = dns.rcode.to_text
    return compare_val(to_text(exp.rcode()), to_text(got.rcode()))
|
|
|
|
|
|
def match_answer(exp, got):
    """Require the ANSWER sections to hold the same RRsets, in any order."""
    expected_rrs, received_rrs = exp.answer, got.answer
    return compare_rrs(expected_rrs, received_rrs)
|
|
|
|
|
|
def match_answertypes(exp, got):
    """Require the ANSWER sections to contain the same RR types, ignoring RRSIGs."""
    return compare_rrs_types(exp_val=exp.answer, got_val=got.answer,
                             skip_rrsigs=True)
|
|
|
|
|
|
def match_answerrrsigs(exp, got):
    """Require the ANSWER sections to contain the same RR types, RRSIGs included.

    RRSIG RRsets take part in the comparison, keyed by the type they cover.
    """
    return compare_rrs_types(exp_val=exp.answer, got_val=got.answer,
                             skip_rrsigs=False)
|
|
|
|
|
|
def match_authority(exp, got):
    """Require the AUTHORITY sections to hold the same RRsets, in any order."""
    expected_rrs, received_rrs = exp.authority, got.authority
    return compare_rrs(expected_rrs, received_rrs)
|
|
|
|
|
|
def match_additional(exp, got):
    """Require the ADDITIONAL sections to hold the same RRsets, in any order."""
    expected_rrs, received_rrs = exp.additional, got.additional
    return compare_rrs(expected_rrs, received_rrs)
|
|
|
|
|
|
def match_edns(exp, got):
    """Compare the ``edns`` and ``payload`` attributes of both messages.

    Returns:
        True when both attributes match, like the other match_* functions.
        Previously this fell through and returned None.

    Raises:
        DataMismatch: for the first differing attribute (``edns`` is checked
        before ``payload``).
    """
    if got.edns != exp.edns:
        raise DataMismatch(exp.edns, got.edns)
    if got.payload != exp.payload:
        raise DataMismatch(exp.payload, got.payload)
    return True
|
|
|
|
|
|
def match_nsid(exp, got):
    """Compare the first NSID EDNS option of both messages.

    The check passes when neither message carries NSID, or when both first
    NSID options compare equal. If only one side has NSID, the missing side is
    reported as None in the DataMismatch.
    """
    exp_nsid = next((opt for opt in exp.options
                     if opt.otype == dns.edns.NSID), None)
    got_nsid = next((opt for opt in got.options
                     if opt.otype == dns.edns.NSID), None)

    if got_nsid is None:
        if exp_nsid:
            raise DataMismatch(exp_nsid.data, None)
        return True
    if not exp_nsid:
        raise DataMismatch(None, got_nsid.data)
    if got_nsid == exp_nsid:
        return True
    raise DataMismatch(exp_nsid.data, got_nsid.data)
|
|
|
|
|
|
# Dispatch table used by match_part: criterion name -> matcher(exp, got),
# which raises (normally DataMismatch) when the messages differ.
MATCH = dict(
    opcode=match_opcode,
    qtype=match_qtype,
    qname=match_qname,
    qcase=match_qcase,
    subdomain=match_subdomain,
    flags=match_flags,
    rcode=match_rcode,
    answer=match_answer,
    answertypes=match_answertypes,
    answerrrsigs=match_answerrrsigs,
    authority=match_authority,
    additional=match_additional,
    edns=match_edns,
    nsid=match_nsid,
)
|
|
|
|
|
|
def match_part(exp, got, code):
    """Compare *exp* and *got* using the single criterion named *code*.

    Returns:
        Whatever the selected matcher from MATCH returns.

    Raises:
        NotImplementedError: if *code* is not a key of MATCH.
        DataMismatch: propagated from the matcher when the messages differ.
    """
    try:
        matcher = MATCH[code]
    except KeyError as ex:
        raise NotImplementedError(f'unknown match request "{code}"') from ex
    # Call outside the try block so that a KeyError raised inside a matcher
    # is not misreported as an unknown match request.
    return matcher(exp, got)
|