1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
|
# SPDX-License-Identifier: GPL-3.0-or-later
from contextlib import contextmanager
import random
import ssl
import struct
import time
import dns
import dns.message
import pytest
# Timeout budget in seconds. kresd's defaults are net.tcp_in_idle = 10 s and
# TCP_DEFER_ACCEPT = 3 s; the remaining margin covers Python-side overhead
# and edge cases.
MAX_TIMEOUT = 16
def _recv_exact(sock, length):
    """Read exactly ``length`` bytes from ``sock``.

    Loops over ``sock.recv`` because a stream socket may return fewer bytes
    than requested. Returns ``None`` if the peer closes the connection
    (``recv`` returns an empty result) before ``length`` bytes arrive.
    """
    buf = bytearray()
    while len(buf) < length:
        chunk = sock.recv(length - len(buf))
        if not chunk:
            return None
        buf += chunk
    return bytes(buf)


def receive_answer(sock):
    """Receive one DNS-over-TCP message body from ``sock``.

    The message is framed with a 2-byte big-endian length prefix. Both the
    prefix and the body are read fully, even when ``recv`` delivers them in
    several pieces.

    Returns:
        The message body as ``bytes`` (possibly empty for a zero-length
        frame), or ``None`` if the connection is closed before a complete
        prefix or body has been received.
    """
    header = _recv_exact(sock, 2)
    if header is None:
        return None
    (answer_total_len,) = struct.unpack("!H", header)
    # Count bytes per received chunk; the previous version added the
    # cumulative buffer length and could stop early with a truncated answer.
    return _recv_exact(sock, answer_total_len)
def receive_parse_answer(sock):
    """Receive one length-prefixed DNS message from ``sock`` and parse it.

    The wire data is parsed with ``dns.message.from_wire`` using
    ``one_rr_per_rrset=True``.

    Raises:
        BrokenPipeError: if ``receive_answer`` reports that the connection
            was closed (returns ``None``).
    """
    wire = receive_answer(sock)
    if wire is not None:
        return dns.message.from_wire(wire, one_rr_per_rrset=True)
    raise BrokenPipeError("kresd closed connection")
def prepare_wire(
        qname='localhost.',
        qtype=dns.rdatatype.A,
        qclass=dns.rdataclass.IN,
        msgid=None):
    """Build an EDNS-enabled DNS query and return ``(wire_bytes, message_id)``.

    If ``msgid`` is given it overrides the ID chosen by ``make_query``; the
    returned ID is always the one actually encoded in the wire data.
    """
    query = dns.message.make_query(qname, qtype, qclass, use_edns=True)
    if msgid is not None:
        query.id = msgid
    wire = query.to_wire()
    return wire, query.id
def prepare_buffer(wire, datalen=None):
    """Frame ``wire`` for DNS over TCP by prepending a 2-byte big-endian length.

    ``datalen`` overrides the advertised length (defaults to ``len(wire)``),
    which allows deliberately mismatched frames to be built.
    """
    assert isinstance(wire, bytes)
    length = len(wire) if datalen is None else datalen
    return b''.join((struct.pack("!H", length), wire))
def get_msgbuff(qname='localhost.', qtype=dns.rdatatype.A, msgid=None):
    """Return ``(tcp_buffer, message_id)`` for a length-prefixed DNS query."""
    wire, query_id = prepare_wire(qname, qtype, msgid=msgid)
    return prepare_buffer(wire), query_id
def get_garbage(length):
    """Return ``length`` pseudo-random bytes drawn from the ``random`` module.

    Not suitable for anything security-sensitive; intended as junk payload.
    """
    return bytes([random.getrandbits(8) for _ in range(length)])
def get_prefixed_garbage(length):
    """Return ``length`` random bytes framed with a correct 2-byte length prefix."""
    return prepare_buffer(get_garbage(length))
def try_ping_alive(sock, msgid=None, close=False):
    """Run ``ping_alive`` on ``sock`` and report success as a boolean.

    Only an ``AssertionError`` (answer ID mismatch) is turned into ``False``;
    any other exception propagates. If ``close`` is true the socket is closed
    afterwards in every case.
    """
    try:
        ping_alive(sock, msgid)
        alive = True
    except AssertionError:
        alive = False
    finally:
        if close:
            sock.close()
    return alive
def ping_alive(sock, msgid=None):
    """Send a default query over ``sock`` and assert the answer echoes its ID.

    Raises ``AssertionError`` on an ID mismatch and ``BrokenPipeError`` (via
    ``receive_parse_answer``) if the connection is closed.
    """
    query, expected_id = get_msgbuff(msgid=msgid)
    sock.sendall(query)
    assert receive_parse_answer(sock).id == expected_id
@contextmanager
def expect_kresd_close(rst_ok=False):
    """Context manager asserting that kresd closes the connection within the body.

    The body passes if it raises ``BrokenPipeError`` or ``ssl.SSLEOFError``
    (for instance ``receive_parse_answer`` raises ``BrokenPipeError`` once the
    peer has closed); either is absorbed by the surrounding ``pytest.raises``.

    Args:
        rst_ok: if true, a ``ConnectionResetError`` (connection closed with
            TCP RST) is also accepted as a close; if false, such a reset
            skips the test via ``pytest.skip``.

    If the body completes without raising, the test fails via ``pytest.fail``.
    """
    with pytest.raises((BrokenPipeError, ssl.SSLEOFError)):
        try:
            time.sleep(0.2)  # give kresd time to close connection with TCP FIN
            yield
        except ConnectionResetError as ex:
            if rst_ok:
                # re-raise as an exception type pytest.raises accepts
                raise BrokenPipeError from ex
            pytest.skip("kresd closed connection with TCP RST")
        # reached only when the body raised nothing: the connection stayed open
        pytest.fail("kresd didn't close the connection")
def make_ssl_context(insecure=False, verify_location=None,
                     minimum_tls=ssl.TLSVersion.TLSv1_2,
                     maximum_tls=ssl.TLSVersion.MAXIMUM_SUPPORTED):
    """Create a client-side ``ssl.SSLContext`` bounded to the given TLS versions.

    Args:
        insecure: if true, disable hostname checking and certificate
            verification; otherwise require a valid certificate and a
            matching hostname.
        verify_location: optional CA file passed to
            ``load_verify_locations`` (loaded regardless of ``insecure``).
        minimum_tls, maximum_tls: ``ssl.TLSVersion`` bounds for the handshake.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ctx.minimum_version, ctx.maximum_version = minimum_tls, maximum_tls
    if not insecure:
        ctx.verify_mode = ssl.CERT_REQUIRED
        ctx.check_hostname = True
    else:
        # check_hostname has to be off before verify_mode may drop to CERT_NONE
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
    if verify_location is not None:
        ctx.load_verify_locations(verify_location)
    return ctx
|