summaryrefslogtreecommitdiffstats
path: root/tests/integration/deckard/tools/answer_checker.py
blob: 3754ef357ca7fa447c0193a6e827c202f73de7c1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""Functions for sending DNS queries and checking the received answers."""
# pylint: disable=C0301
# flake8: noqa

from ipaddress import IPv4Address, IPv6Address
import random
from typing import Iterable, Optional, Set, Union

import dns.flags
import dns.message
import dns.name

import pydnstest.matchpart
import pydnstest.mock_client

def unset_flag(message: dns.message.Message, flag: int) -> dns.message.Message:
    """Clear the bits of ``flag`` in the header flags of ``message``.

    The message is modified in place; the same object is handed back so the
    call can be used in an assignment or chained.
    """
    remaining = message.flags & ~flag
    message.flags = remaining
    return message


def send_and_check(question: Union[dns.message.Message, bytes],  # pylint: disable=R0913
                   expected: dns.message.Message,
                   server: Union[IPv4Address, IPv6Address],
                   match_fields: Set[str],
                   port: int = 53,
                   tcp: bool = False,
                   timeout: int = pydnstest.mock_client.SOCKET_OPERATION_TIMEOUT,
                   unset_flags: Iterable[int] = tuple()) -> bool:
    """Send ``question`` to ``server`` and compare the reply with ``expected``.

    Every flag listed in ``unset_flags`` is cleared from the received answer
    before comparison. Each entry of ``match_fields`` is then handed to
    ``pydnstest.matchpart.match_part``; see that module for the supported
    field names. The query, the answer and the match criteria are printed
    for debugging.

    Returns True on success. A failure is signalled by an exception raised
    from the helpers this function calls (e.g. ``match_part``), never by a
    False return.
    """
    print("Sending query:\n%s\n" % (str(question),))
    response = get_answer(question, server, port, tcp, timeout=timeout)

    # Strip flags the caller does not want to take part in the comparison.
    for flag_to_clear in unset_flags:
        response = unset_flag(response, flag_to_clear)

    print("Got answer:\n%s\n" % response)
    print("Matching:\n%s\n%s\n" % (match_fields, expected))
    for field_name in match_fields:
        pydnstest.matchpart.match_part(expected, response, field_name)

    return True


def get_answer(question: Union[dns.message.Message, bytes],
               server: Union[IPv4Address, IPv6Address],
               port: int = 53,
               tcp: bool = False,
               timeout: int = pydnstest.mock_client.SOCKET_OPERATION_TIMEOUT) -> dns.message.Message:
    """Send ``question`` to ``server``:``port`` and return the DNS message received back.

    ``tcp`` is forwarded to ``pydnstest.mock_client.setup_socket`` to pick the
    transport, and ``timeout`` bounds the wait for the reply. The socket is
    used as a context manager so it is released once the reply is read.
    """
    client = pydnstest.mock_client
    connection = client.setup_socket(str(server), port, tcp=tcp)
    with connection:
        client.send_query(connection, question)
        return client.get_dns_message(connection, timeout=timeout)


def string_answer(question: Union[dns.message.Message, bytes],
                  server: Union[IPv4Address, IPv6Address],
                  port: int = 53,
                  tcp: bool = False,
                  timeout: int = pydnstest.mock_client.SOCKET_OPERATION_TIMEOUT) -> str:
    """Return the server's answer to ``question`` as text. Handy for generating tests.

    The reply is fetched with ``get_answer`` and rendered with its
    ``to_text()`` method; nothing is printed.

    ``timeout`` is forwarded to ``get_answer``. This matches ``get_answer``
    and ``send_and_check``, so slow servers can be queried here too. Its
    default is the same as theirs, so existing callers are unaffected.
    """
    return get_answer(question, server, port, tcp, timeout=timeout).to_text()


def randomize_case(label: bytes) -> bytes:
    """Return a copy of ``label`` with the case of each byte flipped at random.

    For every byte, one ``random.randint(0, 1)`` draw decides whether the byte
    goes through ``bytes.swapcase`` (which only touches ASCII letters) or is
    copied as-is.
    """
    return b''.join(
        bytes([octet]).swapcase() if random.randint(0, 1) else bytes([octet])
        for octet in label
    )


def make_random_case_query(name: str, *args, **kwargs) -> dns.message.Message:
    """Proxy for dns.message.make_query with rANdoM-cASe.

    Accepts the same arguments as ``dns.message.make_query``. The returned
    query's question name has the case of each label randomized with
    ``randomize_case``.

    The previous implementation only rebound the loop variable, so the query
    name was never actually changed.
    """
    # Let make_query parse ``name`` first, so any parsing-related keyword
    # arguments it accepts are honoured exactly as before.
    template = dns.message.make_query(name, *args, **kwargs)
    parsed = template.question[0].name
    # dnspython names are immutable: build a new Name from the randomized
    # labels. The trailing empty root label passes through unchanged.
    mixed_case = dns.name.Name([randomize_case(label) for label in parsed.labels])
    # Rebuild the query around the new name rather than mutating the
    # question RRset in place, which could leave the message's internal
    # RRset index pointing at the old name.
    return dns.message.make_query(mixed_case, *args, **kwargs)