131 lines
3.5 KiB
Python
131 lines
3.5 KiB
Python
# SPDX-License-Identifier: GPL-3.0-or-later
|
|
from contextlib import contextmanager
|
|
import random
|
|
import ssl
|
|
import struct
|
|
import time
|
|
|
|
import dns
|
|
import dns.message
|
|
import pytest
|
|
|
|
|
|
# default net.tcp_in_idle is 10s, TCP_DEFER_ACCEPT 3s, some extra for
|
|
# Python handling / edge cases
|
|
MAX_TIMEOUT = 16
|
|
|
|
|
|
def receive_answer(sock):
|
|
answer_total_len = 0
|
|
data = sock.recv(2)
|
|
if not data:
|
|
return None
|
|
answer_total_len = struct.unpack_from("!H", data)[0]
|
|
|
|
answer_received_len = 0
|
|
data_answer = b''
|
|
while answer_received_len < answer_total_len:
|
|
data_chunk = sock.recv(answer_total_len - answer_received_len)
|
|
if not data_chunk:
|
|
return None
|
|
data_answer += data_chunk
|
|
answer_received_len += len(data_answer)
|
|
|
|
return data_answer
|
|
|
|
|
|
def receive_parse_answer(sock):
|
|
data_answer = receive_answer(sock)
|
|
|
|
if data_answer is None:
|
|
raise BrokenPipeError("kresd closed connection")
|
|
|
|
msg_answer = dns.message.from_wire(data_answer, one_rr_per_rrset=True)
|
|
return msg_answer
|
|
|
|
|
|
def prepare_wire(
|
|
qname='localhost.',
|
|
qtype=dns.rdatatype.A,
|
|
qclass=dns.rdataclass.IN,
|
|
msgid=None):
|
|
"""Utility function to generate DNS wire format message"""
|
|
msg = dns.message.make_query(qname, qtype, qclass, use_edns=True)
|
|
if msgid is not None:
|
|
msg.id = msgid
|
|
return msg.to_wire(), msg.id
|
|
|
|
|
|
def prepare_buffer(wire, datalen=None):
|
|
"""Utility function to prepare TCP buffer from DNS message in wire format"""
|
|
assert isinstance(wire, bytes)
|
|
if datalen is None:
|
|
datalen = len(wire)
|
|
return struct.pack("!H", datalen) + wire
|
|
|
|
|
|
def get_msgbuff(qname='localhost.', qtype=dns.rdatatype.A, msgid=None):
|
|
wire, msgid = prepare_wire(qname, qtype, msgid=msgid)
|
|
buff = prepare_buffer(wire)
|
|
return buff, msgid
|
|
|
|
|
|
def get_garbage(length):
|
|
return bytes(random.getrandbits(8) for _ in range(length))
|
|
|
|
|
|
def get_prefixed_garbage(length):
|
|
data = get_garbage(length)
|
|
return prepare_buffer(data)
|
|
|
|
|
|
def try_ping_alive(sock, msgid=None, close=False):
|
|
try:
|
|
ping_alive(sock, msgid)
|
|
except AssertionError:
|
|
return False
|
|
finally:
|
|
if close:
|
|
sock.close()
|
|
return True
|
|
|
|
|
|
def ping_alive(sock, msgid=None):
|
|
buff, msgid = get_msgbuff(msgid=msgid)
|
|
sock.sendall(buff)
|
|
answer = receive_parse_answer(sock)
|
|
assert answer.id == msgid
|
|
|
|
|
|
@contextmanager
|
|
def expect_kresd_close(rst_ok=False):
|
|
with pytest.raises((BrokenPipeError, ssl.SSLEOFError)):
|
|
try:
|
|
time.sleep(0.2) # give kresd time to close connection with TCP FIN
|
|
yield
|
|
except ConnectionResetError as ex:
|
|
if rst_ok:
|
|
raise BrokenPipeError from ex
|
|
pytest.skip("kresd closed connection with TCP RST")
|
|
pytest.fail("kresd didn't close the connection")
|
|
|
|
|
|
def make_ssl_context(insecure=False, verify_location=None,
|
|
minimum_tls=ssl.TLSVersion.TLSv1_2,
|
|
maximum_tls=ssl.TLSVersion.MAXIMUM_SUPPORTED):
|
|
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
|
context.minimum_version = minimum_tls
|
|
context.maximum_version = maximum_tls
|
|
|
|
if insecure:
|
|
# turn off certificate verification
|
|
context.check_hostname = False
|
|
context.verify_mode = ssl.CERT_NONE
|
|
else:
|
|
context.verify_mode = ssl.CERT_REQUIRED
|
|
context.check_hostname = True
|
|
|
|
if verify_location is not None:
|
|
context.load_verify_locations(verify_location)
|
|
|
|
return context
|