"""matchpart is used to compare two DNS messages using a single criterion"""

from typing import (  # noqa
    Any, Hashable, Sequence, Tuple, Union)

import dns.edns
import dns.rcode
import dns.set

# Type of a value accepted by DataMismatch.format_value: either a string or
# a sequence of arbitrary items.
MismatchValue = Union[str, Sequence[Any]]


class DataMismatch(Exception):
    """Raised when an expected value differs from the value actually received.

    Two instances are equal when both their expected and their received
    values are equal; instances are hashable through :attr:`key`.
    """

    def __init__(self, exp_val, got_val):
        # Hand the values over to Exception so that ``args`` is populated.
        # Default exception pickling/copying re-creates the object by calling
        # the class with ``args``; with empty ``args`` that call fails because
        # this constructor requires two arguments. It also makes repr() useful.
        super().__init__(exp_val, got_val)
        self.exp_val = exp_val
        self.got_val = got_val

    @staticmethod
    def format_value(value: MismatchValue) -> str:
        """Return `value` as text; items of a list or tuple are space-joined."""
        # Only list and tuple are joined: str is a sequence as well but has
        # to be printed verbatim.
        if isinstance(value, (list, tuple)):
            return ' '.join(str(val) for val in value)
        return str(value)

    def __str__(self) -> str:
        return 'expected "{}" got "{}"'.format(
            self.format_value(self.exp_val),
            self.format_value(self.got_val))

    def __eq__(self, other):
        return (isinstance(other, DataMismatch)
                and self.exp_val == other.exp_val
                and self.got_val == other.got_val)

    def __ne__(self, other):
        return not self.__eq__(other)

    @property
    def key(self) -> Tuple[Hashable, Hashable]:
        """Pair (expected, got) with nested containers converted to tuples."""
        def make_hashable(value):
            # Lists and dns.set.Set instances are turned into tuples
            # recursively; tuples are walked too because they may contain
            # such items.
            if isinstance(value, (list, tuple, dns.set.Set)):
                return tuple(make_hashable(item) for item in value)
            return value

        return (make_hashable(self.exp_val), make_hashable(self.got_val))

    def __hash__(self) -> int:
        return hash(self.key)


def compare_val(exp, got):
    """Compare two arbitrary objects.

    Returns True when they are equal, raises DataMismatch(exp, got) otherwise.
    """
    differs = exp != got
    if differs:
        raise DataMismatch(exp, got)
    return True


def compare_rrs(expected, got):
    """Compare two lists of RR sets regardless of their order.

    Every item has to be present in the other list and both lists must have
    the same length. Returns True on match, raises DataMismatch carrying both
    lists otherwise.
    """
    missing = any(rr not in got for rr in expected)
    if (missing
            or any(rr not in expected for rr in got)
            or len(expected) != len(got)):
        raise DataMismatch(expected, got)
    return True


def compare_rrs_types(exp_val, got_val, skip_rrsigs):
    """Check that the sets of RR types in both sections match.

    RR sets with a non-zero ``covers`` attribute are treated as signatures and
    are distinguished by the type they cover.

    :param exp_val: iterable of expected RR sets
    :param got_val: iterable of received RR sets
    :param skip_rrsigs: when true, RR sets of type RRSIG are ignored
    :return: True when both sets of types are equal
    :raises DataMismatch: with tuples of textual RR types when they differ
    """
    def rr_ordering_key(rrset):
        # The second item flags a signature; when sorted, RRSIG(X) lands
        # right after X.
        if rrset.covers:
            return rrset.covers, 1
        else:
            return rrset.rdtype, 0

    def key_to_text(rrtype, rrsig):
        if not rrsig:
            return dns.rdatatype.to_text(rrtype)
        else:
            return 'RRSIG(%s)' % dns.rdatatype.to_text(rrtype)

    if skip_rrsigs:
        exp_val = (rrset for rrset in exp_val
                   if rrset.rdtype != dns.rdatatype.RRSIG)
        got_val = (rrset for rrset in got_val
                   if rrset.rdtype != dns.rdatatype.RRSIG)

    exp_types = frozenset(rr_ordering_key(rrset) for rrset in exp_val)
    got_types = frozenset(rr_ordering_key(rrset) for rrset in got_val)
    if exp_types != got_types:
        exp_types = tuple(key_to_text(*i) for i in sorted(exp_types))
        got_types = tuple(key_to_text(*i) for i in sorted(got_types))
        raise DataMismatch(exp_types, got_types)
    # Report success explicitly, the same way compare_val/compare_rrs do.
    return True


def check_question(question):
    """Reject question sections holding more than one record.

    :param question: question section (a sized collection of records)
    :raises NotImplementedError: if there is more than one record
    """
    # The matchers only ever inspect question[0], so a second record must not
    # slip through unnoticed.
    if len(question) > 1:
        raise NotImplementedError("More than one record in QUESTION SECTION.")


def match_opcode(exp, got):
    """Compare the opcodes of both messages."""
    expected_opcode = exp.opcode()
    received_opcode = got.opcode()
    return compare_val(expected_opcode, received_opcode)


def match_qtype(exp, got):
    """Compare the RR type of the first question record of both messages.

    Two empty question sections match; an empty section on one side only
    raises DataMismatch with the "<empty question>" placeholder.
    """
    for message in (exp, got):
        check_question(message.question)
    if exp.question and got.question:
        return compare_val(exp.question[0].rdtype,
                           got.question[0].rdtype)
    if exp.question:
        raise DataMismatch(exp.question[0].rdtype, "<empty question>")
    if got.question:
        raise DataMismatch("<empty question>", got.question[0].rdtype)
    return True


def match_qname(exp, got):
    """Compare the name of the first question record of both messages.

    Two empty question sections match; an empty section on one side only
    raises DataMismatch with the "<empty question>" placeholder.
    """
    for message in (exp, got):
        check_question(message.question)
    if exp.question and got.question:
        return compare_val(exp.question[0].name,
                           got.question[0].name)
    if exp.question:
        raise DataMismatch(exp.question[0].name, "<empty question>")
    if got.question:
        raise DataMismatch("<empty question>", got.question[0].name)
    return True


def match_qcase(exp, got):
    """Compare the labels of the first question name of both messages.

    The ``labels`` attributes are compared directly (presumably to make the
    name comparison case-sensitive - confirm against dnspython). Two empty
    question sections match; an empty section on one side only raises
    DataMismatch with the "<empty question>" placeholder.
    """
    for message in (exp, got):
        check_question(message.question)
    if exp.question and got.question:
        return compare_val(exp.question[0].name.labels,
                           got.question[0].name.labels)
    if exp.question:
        raise DataMismatch(exp.question[0].name.labels, "<empty question>")
    if got.question:
        raise DataMismatch("<empty question>", got.question[0].name.labels)
    return True


def match_subdomain(exp, got):
    """Check that the received question name lies under the expected one.

    An expected message without a question matches anything. When the
    received message has no question, the root name is used in its place.

    :return: True when the expected name is a superdomain of the received one
    :raises DataMismatch: carrying the expected and the received name
    """
    if not exp.question:
        return True
    exp_name = exp.question[0].name
    qname = got.question[0].name if got.question else dns.name.root
    if exp_name.is_superdomain(qname):
        return True
    # Report just the two names that were compared, not the whole messages:
    # the error stays readable and the exception carries plain values.
    raise DataMismatch(exp_name, qname)


def match_flags(exp, got):
    """Compare the flags of both messages in their textual form."""
    expected_flags = dns.flags.to_text(exp.flags)
    received_flags = dns.flags.to_text(got.flags)
    return compare_val(expected_flags, received_flags)


def match_rcode(exp, got):
    """Compare the response codes of both messages in their textual form."""
    expected_rcode = dns.rcode.to_text(exp.rcode())
    received_rcode = dns.rcode.to_text(got.rcode())
    return compare_val(expected_rcode, received_rcode)


def match_answer(exp, got):
    """Compare the RR sets in the answer sections of both messages."""
    expected_rrs = exp.answer
    received_rrs = got.answer
    return compare_rrs(expected_rrs, received_rrs)


def match_answertypes(exp, got):
    """Compare the RR types present in the answer sections, RRSIGs left out."""
    expected_rrs = exp.answer
    received_rrs = got.answer
    return compare_rrs_types(expected_rrs, received_rrs, skip_rrsigs=True)


def match_answerrrsigs(exp, got):
    """Compare the RR types present in the answer sections, RRSIGs included."""
    expected_rrs = exp.answer
    received_rrs = got.answer
    return compare_rrs_types(expected_rrs, received_rrs, skip_rrsigs=False)


def match_authority(exp, got):
    """Compare the RR sets in the authority sections of both messages."""
    expected_rrs = exp.authority
    received_rrs = got.authority
    return compare_rrs(expected_rrs, received_rrs)


def match_additional(exp, got):
    """Compare the RR sets in the additional sections of both messages."""
    expected_rrs = exp.additional
    received_rrs = got.additional
    return compare_rrs(expected_rrs, received_rrs)


def match_edns(exp, got):
    """Compare the ``edns`` and ``payload`` attributes of both messages.

    :return: True when both attributes are equal
    :raises DataMismatch: carrying the first pair of values that differs
    """
    if got.edns != exp.edns:
        raise DataMismatch(exp.edns,
                           got.edns)
    if got.payload != exp.payload:
        raise DataMismatch(exp.payload,
                           got.payload)
    # Report success explicitly instead of falling off the end with None.
    return True


def match_nsid(exp, got):
    """Compare the NSID option of both messages.

    Only the first option of type NSID in each message is considered.
    Returns True when neither message has one or when both options are equal;
    otherwise DataMismatch is raised with the option data, None standing for
    a missing option.
    """
    def first_nsid(message):
        return next((option for option in message.options
                     if option.otype == dns.edns.NSID), None)

    wanted = first_nsid(exp)
    received = first_nsid(got)
    if received is None:
        if wanted:
            raise DataMismatch(wanted.data, None)
        return True
    if not wanted:
        raise DataMismatch(None, received.data)
    if received == wanted:
        return True
    raise DataMismatch(wanted.data, received.data)


# Dispatch table: name of a match criterion -> function comparing the
# expected and the received message by that criterion.
MATCH = {
    "opcode": match_opcode,
    "qtype": match_qtype,
    "qname": match_qname,
    "qcase": match_qcase,
    "subdomain": match_subdomain,
    "flags": match_flags,
    "rcode": match_rcode,
    "answer": match_answer,
    "answertypes": match_answertypes,
    "answerrrsigs": match_answerrrsigs,
    "authority": match_authority,
    "additional": match_additional,
    "edns": match_edns,
    "nsid": match_nsid,
}


def match_part(exp, got, code):
    """Compare two messages using the criterion named `code`.

    :param exp: expected message
    :param got: received message
    :param code: key into the MATCH table
    :return: result of the selected matcher
    :raises NotImplementedError: if `code` is not a known criterion
    """
    # Guard only the table lookup: a KeyError raised inside a matcher is a
    # different problem and must not be reported as an unknown criterion.
    try:
        matcher = MATCH[code]
    except KeyError as ex:
        raise NotImplementedError('unknown match request "%s"' % code) from ex
    return matcher(exp, got)