1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
|
"""Module takes care of sending and receiving DNS messages as a mock client"""
import errno
import socket
import struct
import time
from typing import Optional, Tuple, Union
import dns.message
import dns.inet
# Default timeout, in seconds, for socket operations performed by this module.
SOCKET_OPERATION_TIMEOUT = 5

# Largest number of bytes requested from a single datagram read (65535).
RECEIVE_MESSAGE_SIZE = (1 << 16) - 1

# Back-off interval, in seconds; matches the 0.1 s pause used by the ENOBUFS
# retry loops in this module.
THROTTLE_BY = 0.1
def handle_socket_timeout(sock: socket.socket, deadline: float):
    """Set the timeout of ``sock`` to the time left until ``deadline``.

    Args:
        sock: socket whose timeout is adjusted.
        deadline: absolute point in time on the ``time.monotonic()`` clock.

    Raises:
        RuntimeError: if ``deadline`` has already been reached.
    """
    now = time.monotonic()
    time_left = deadline - now
    if time_left <= 0:
        # Nothing left to wait for: report it the same way a timed-out
        # receive is reported elsewhere in this module.
        raise RuntimeError("Server took too long to respond")
    sock.settimeout(time_left)
def recv_n_bytes_from_tcp(stream: socket.socket, n: int, deadline: float) -> bytes:
    """Read exactly ``n`` bytes from a stream socket.

    Args:
        stream: socket to read from.
        n: number of bytes to read; ``0`` returns ``b""`` without touching
            the socket.
        deadline: absolute ``time.monotonic()`` timestamp; the socket timeout
            is re-armed from it before every ``recv`` call.

    Returns:
        The ``n`` bytes received, in order.

    Raises:
        RuntimeError: from ``handle_socket_timeout`` when the deadline passed.
        OSError: if the socket is closed before ``n`` bytes have arrived
            (``errno`` is not set on this exception).
    """
    chunks = []
    missing = n
    while missing != 0:
        handle_socket_timeout(stream, deadline)
        chunk = stream.recv(missing)
        # Empty bytes from socket.recv mean that socket is closed
        if not chunk:
            raise OSError(
                "socket closed with %d of %d expected bytes still missing"
                % (missing, n))
        missing -= len(chunk)
        chunks.append(chunk)
    # Joining once avoids re-copying the accumulated data for every chunk.
    return b"".join(chunks)
def recvfrom_blob(sock: socket.socket,
                  timeout: int = SOCKET_OPERATION_TIMEOUT) -> Tuple[bytes, str]:
    """
    Receive DNS message from TCP/UDP socket.

    For datagram sockets a single ``recvfrom`` of up to
    ``RECEIVE_MESSAGE_SIZE`` bytes is performed.  For stream sockets the
    2-byte big-endian length prefix is read first, followed by exactly that
    many bytes of message.

    Args:
        sock: socket to read from.
        timeout: overall time budget in seconds, measured from the moment
            this function is called.

    Returns:
        ``(data, addr)``.  For stream sockets ``addr`` is
        ``sock.getpeername()[0]``; for datagram sockets it is the address
        exactly as returned by ``sock.recvfrom``.
        NOTE(review): the annotation says ``str``, but the datagram address is
        not reduced to its first element here - confirm what callers expect.

    Raises:
        RuntimeError: if the time budget is exhausted.
        NotImplementedError: for sockets that are neither datagram nor stream.
        OSError: any socket error other than ``ENOBUFS``.
    """
    # deadline is always time.monotonic
    deadline = time.monotonic() + timeout

    while True:
        try:
            if sock.type & socket.SOCK_DGRAM:
                handle_socket_timeout(sock, deadline)
                data, addr = sock.recvfrom(RECEIVE_MESSAGE_SIZE)
            elif sock.type & socket.SOCK_STREAM:
                # First 2 bytes of TCP packet are the size of the message
                # See https://tools.ietf.org/html/rfc1035#section-4.2.2
                data = recv_n_bytes_from_tcp(sock, 2, deadline)
                msg_len = struct.unpack_from("!H", data)[0]
                data = recv_n_bytes_from_tcp(sock, msg_len, deadline)
                addr = sock.getpeername()[0]
            else:
                raise NotImplementedError("[recvfrom_blob]: unknown socket type '%i'" % sock.type)
            return data, addr
        except socket.timeout as ex:
            raise RuntimeError("Server took too long to respond") from ex
        except OSError as ex:
            if ex.errno != errno.ENOBUFS:
                raise
            # ENOBUFS: back off for the module-wide throttle interval and
            # retry; the deadline check at the top of the next attempt keeps
            # this loop bounded.
            time.sleep(THROTTLE_BY)
def recvfrom_msg(sock: socket.socket,
                 timeout: int = SOCKET_OPERATION_TIMEOUT) -> Tuple[dns.message.Message, str]:
    """Receive one message via ``recvfrom_blob`` and parse it.

    The wire data is decoded with ``dns.message.from_wire`` using
    ``one_rr_per_rrset=True``.

    Returns:
        ``(message, addr)`` where ``addr`` is passed through unchanged from
        ``recvfrom_blob``.
    """
    wire, peer = recvfrom_blob(sock, timeout=timeout)
    return dns.message.from_wire(wire, one_rr_per_rrset=True), peer
def sendto_msg(sock: socket.socket, message: bytes, addr: Optional[str] = None) -> None:
    """Send an already encoded message over a datagram or stream socket.

    Datagram sockets send ``message`` as is, using ``sock.sendto`` when
    ``addr`` is given and ``sock.send`` otherwise.  Stream sockets send the
    message prefixed with its length as a 2-byte big-endian integer.

    Args:
        sock: socket to send on.
        message: wire-format bytes to send.
        addr: optional destination, handed unchanged to ``sock.sendto``.
            NOTE(review): annotated as ``str``; confirm the address shape
            callers actually pass.

    Raises:
        NotImplementedError: for sockets that are neither datagram nor stream.
        OSError: any socket error except ``ECONNREFUSED``, which is ignored.
    """
    try:
        if sock.type & socket.SOCK_DGRAM:
            if addr is not None:
                sock.sendto(message, addr)
            else:
                sock.send(message)
        elif sock.type & socket.SOCK_STREAM:
            length_prefix = struct.pack("!H", len(message))
            sock.sendall(length_prefix + message)
        else:
            raise NotImplementedError("[sendto_msg]: unknown socket type '%i'" % sock.type)
    except OSError as ex:
        # Reference: http://lkml.iu.edu/hypermail/linux/kernel/0002.3/0709.html
        if ex.errno == errno.ECONNREFUSED:
            return
        raise
def setup_socket(address: str,
                 port: int,
                 tcp: bool = False,
                 src_address: Optional[str] = None) -> socket.socket:
    """Create a socket connected to ``(address, port)``.

    Args:
        address: destination IP address; its address family (as reported by
            ``dns.inet.af_for_address``) selects the socket family.
        port: destination port.
        tcp: create a stream socket with ``TCP_NODELAY`` set when true,
            otherwise a datagram socket.
        src_address: optional local address to bind to before connecting.

    Returns:
        The connected socket, with its timeout set to
        ``SOCKET_OPERATION_TIMEOUT`` seconds.

    Raises:
        OSError: if configuring, binding or connecting the socket fails; the
            socket is closed before the error propagates.
    """
    family = dns.inet.af_for_address(address)
    sock = socket.socket(family, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM)
    try:
        if tcp:
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
        if src_address is not None:
            sock.bind((src_address, 0))  # random source port
        sock.settimeout(SOCKET_OPERATION_TIMEOUT)
        sock.connect((address, port))
    except BaseException:
        # The caller never receives the socket on failure, so release the
        # file descriptor here instead of leaking it; the error is re-raised.
        sock.close()
        raise
    return sock
def send_query(sock: socket.socket, query: Union[dns.message.Message, bytes]) -> None:
    """Send a query, retrying while the socket reports ``ENOBUFS``.

    Args:
        sock: socket to send on (handed to ``sendto_msg``).
        query: either wire-format bytes, sent as is, or an object whose
            ``to_wire()`` result is sent.

    Raises:
        OSError: any error from ``sendto_msg`` other than ``ENOBUFS``.
    """
    message = query if isinstance(query, bytes) else query.to_wire()
    while True:
        try:
            sendto_msg(sock, message)
            break
        except OSError as ex:
            if ex.errno != errno.ENOBUFS:
                raise
            # ENOBUFS, throttle sending
            time.sleep(THROTTLE_BY)
def get_answer(sock: socket.socket, timeout: int = SOCKET_OPERATION_TIMEOUT) -> bytes:
    """Compatibility function: return only the data part of ``recvfrom_blob``.

    The address element of the ``recvfrom_blob`` result is discarded.
    """
    return recvfrom_blob(sock, timeout=timeout)[0]
def get_dns_message(sock: socket.socket,
                    timeout: int = SOCKET_OPERATION_TIMEOUT) -> dns.message.Message:
    """Receive one answer with ``get_answer`` and parse it from wire format.

    Parsing uses ``dns.message.from_wire`` with its default options.
    """
    wire = get_answer(sock, timeout=timeout)
    return dns.message.from_wire(wire)
|