316 lines
11 KiB
Python
316 lines
11 KiB
Python
# SPDX-License-Identifier: GPL-3.0-or-later
|
|
|
|
from collections import namedtuple
|
|
from contextlib import ContextDecorator, contextmanager
|
|
import os
|
|
from pathlib import Path
|
|
import random
|
|
import re
|
|
import shutil
|
|
import socket
|
|
import subprocess
|
|
import time
|
|
|
|
import jinja2
|
|
|
|
import utils
|
|
|
|
|
|
# Locations of the test resources, all resolved relative to this file.
PYTESTS_DIR = os.path.dirname(os.path.realpath(__file__))
CERTS_DIR = os.path.join(PYTESTS_DIR, 'certs')
TEMPLATES_DIR = os.path.join(PYTESTS_DIR, 'templates')

# Name of the kresd configuration template, looked up inside TEMPLATES_DIR.
KRESD_CONF_TEMPLATE = 'kresd.conf.j2'

# special unique ID at the start of the "test" log
KRESD_STARTUP_MSGID = 10005

# Directory holding one marker file per port reserved for a kresd instance.
KRESD_PORTDIR = '/tmp/pytest-kresd-portdir'

# Inclusive range from which random test ports are drawn.
KRESD_TESTPORT_MIN = 1024
KRESD_TESTPORT_MAX = 32768  # avoid overlap with docker ephemeral port range
|
|
|
|
|
|
def init_portdir():
    """Recreate ``KRESD_PORTDIR`` as an empty directory.

    Any existing directory tree at that path is deleted first, which drops
    every port marker file it contained.
    """
    try:
        shutil.rmtree(KRESD_PORTDIR)
    except FileNotFoundError:
        # Nothing to clean up: the directory has not been created yet.
        pass

    os.makedirs(KRESD_PORTDIR)
|
|
|
|
|
|
def create_file_from_template(template_path, dest, data):
    """Render a Jinja2 template and write the result to a file.

    Args:
        template_path: Template name, resolved inside ``TEMPLATES_DIR``.
        dest: Path of the output file; it is created or overwritten.
        data: Mapping whose items are passed to the template as variables.
    """
    loader = jinja2.FileSystemLoader(TEMPLATES_DIR)
    environment = jinja2.Environment(loader=loader)
    content = environment.get_template(template_path).render(**data)

    with open(dest, "w", encoding='UTF-8') as out:
        out.write(content)
|
|
|
|
|
|
# Immutable record with the fields proto, ip, port, hostname and ca_file.
# NOTE(review): presumably describes a forwarding target handed to Kresd via
# its ``forward`` argument -- confirm against the config template.
Forward = namedtuple('Forward', 'proto ip port hostname ca_file')
|
|
|
|
|
|
class Kresd(ContextDecorator):
    """Run one ``kresd`` subprocess for the duration of a ``with`` block.

    Entering the context reserves the plain and TLS ports, renders
    ``KRESD_CONF_TEMPLATE`` into ``workdir``, starts the daemon with its
    stderr redirected to ``kresd.log`` and waits until it answers on every
    configured address.  Leaving the context checks that the daemon still
    answers and then releases the sockets, the process and the log file.
    """

    def __init__(
            self, workdir, port=None, tls_port=None, ip=None, ip6=None, certname=None,
            verbose=True, hints=None, forward=None, policy_test_pass=False, rr=False,
            valgrind=False):
        """Store the settings; nothing is started until ``__enter__``.

        Args:
            workdir: Directory holding ``kresd.conf`` and ``kresd.log``; it is
                also passed to ``kresd`` as its last command-line argument.
            port: Plain port to use; ``None`` picks one via ``make_port``.
            tls_port: TLS port to use; ``None`` picks one via ``make_port``.
            ip: IPv4 address to talk to, or ``None``.
            ip6: IPv6 address to talk to, or ``None``.
            certname: Base name of a ``<certname>.cert.pem`` /
                ``<certname>.key.pem`` pair inside ``CERTS_DIR``; a falsy
                value leaves both paths ``None``.
            verbose, hints, forward, policy_test_pass: Stored as attributes.
                NOTE(review): presumably read by the config template, which
                receives this object as ``kresd`` -- confirm in the template.
            rr: Prefix the command line with ``rr record --``.
            valgrind: Prefix the command line with ``valgrind --``.

        Raises:
            ValueError: If neither ``ip`` nor ``ip6`` is given.
        """
        if ip is None and ip6 is None:
            raise ValueError("IPv4 or IPv6 must be specified!")
        self.workdir = str(workdir)
        self.port = port
        self.tls_port = tls_port
        self.ip = ip
        self.ip6 = ip6
        self.process = None
        self.sockets = []
        self.logfile = None
        self.verbose = verbose
        self.hints = {} if hints is None else hints
        self.forward = forward
        self.policy_test_pass = policy_test_pass
        self.rr = rr
        self.valgrind = valgrind

        if certname:
            self.tls_cert_path = os.path.join(CERTS_DIR, certname + '.cert.pem')
            self.tls_key_path = os.path.join(CERTS_DIR, certname + '.key.pem')
        else:
            self.tls_cert_path = None
            self.tls_key_path = None

    @property
    def config_path(self):
        """Path of the rendered ``kresd.conf`` inside ``workdir``."""
        return str(os.path.join(self.workdir, 'kresd.conf'))

    @property
    def logfile_path(self):
        """Path of ``kresd.log`` (the daemon's stderr) inside ``workdir``."""
        return str(os.path.join(self.workdir, 'kresd.log'))

    def __enter__(self):
        """Reserve ports, write the config, start kresd and wait for it.

        Returns:
            This instance.

        Raises:
            RuntimeError: If kresd does not come up, does not answer on all
                ports, or has already exited.  Other errors raised while
                probing (e.g. ``ConnectionError``) propagate unchanged.
        """
        if self.port is not None:
            take_port(self.port, self.ip, self.ip6, timeout=120)
        else:
            self.port = make_port(self.ip, self.ip6)
        if self.tls_port is not None:
            take_port(self.tls_port, self.ip, self.ip6, timeout=120)
        else:
            self.tls_port = make_port(self.ip, self.ip6)

        create_file_from_template(KRESD_CONF_TEMPLATE, self.config_path, {'kresd': self})
        self.logfile = open(self.logfile_path, 'w', encoding='UTF-8')

        proc_args = ['kresd', '-c', self.config_path, '-n', self.workdir]
        if self.rr:
            proc_args = ['rr', 'record', '--'] + proc_args
        if self.valgrind:
            proc_args = ['valgrind', '--'] + proc_args

        try:
            self.process = subprocess.Popen(
                proc_args,
                stderr=self.logfile, env=os.environ.copy())
            self._check_startup()
        except BaseException:
            # __exit__ is never called when __enter__ raises, so the child
            # process, the log file handle and any probe sockets have to be
            # released here.  The port marker files are left in place.
            self._release_resources()
            raise

        return self

    def _check_startup(self):
        """Wait for kresd to answer, then mark the start of the test log."""
        try:
            self._wait_for_tcp_port()  # wait for ports to be up and responding
            if not self.all_ports_alive(msgid=10001):
                raise RuntimeError("Kresd not listening on all ports")

            # issue special msgid to mark start of test log
            sock = self.ip_tcp_socket() if self.ip else self.ip6_tcp_socket()
            # Send the ping outside the assert so it still happens when
            # assertions are stripped (python -O).
            marked = utils.try_ping_alive(sock, close=True, msgid=KRESD_STARTUP_MSGID)
            assert marked

            # sanity check - kresd didn't crash
            self.process.poll()
            if self.process.returncode is not None:
                raise RuntimeError("Kresd crashed with returncode: {}".format(
                    self.process.returncode))
        except (RuntimeError, ConnectionError):  # pylint: disable=try-except-raise
            with open(self.logfile_path, encoding='UTF-8') as log:  # print log for debugging
                print(log.read())
            raise

    def _release_resources(self):
        """Close tracked sockets, terminate kresd and close the log file."""
        for sock in self.sockets:
            sock.close()
        if self.process is not None:
            self.process.terminate()
        if self.logfile is not None:
            self.logfile.close()

    def __exit__(self, exc_type, exc_value, traceback):
        """Verify kresd is still alive, then tear everything down.

        Cleanup runs even when the liveness check fails.  Only the marker
        file of the plain port is removed.

        Raises:
            RuntimeError: If kresd no longer answers on all ports.
        """
        try:
            if not self.all_ports_alive(msgid=1006):
                raise RuntimeError("Kresd crashed")
        finally:
            self._release_resources()
            Path(KRESD_PORTDIR, str(self.port)).unlink()

    def all_ports_alive(self, msgid=10001):
        """Ping every configured address over TCP and TLS.

        All pings are attempted even after one fails; ``msgid`` up to
        ``msgid + 3`` are used for the individual pings.

        Returns:
            Truthy only if every ping succeeded.
        """
        alive = True
        if self.ip:
            alive &= utils.try_ping_alive(self.ip_tcp_socket(), close=True, msgid=msgid)
            alive &= utils.try_ping_alive(self.ip_tls_socket(), close=True, msgid=msgid + 1)
        if self.ip6:
            alive &= utils.try_ping_alive(self.ip6_tcp_socket(), close=True, msgid=msgid + 2)
            alive &= utils.try_ping_alive(self.ip6_tls_socket(), close=True, msgid=msgid + 3)
        return alive

    def _wait_for_tcp_port(self, max_delay=10, delay_step=0.2):
        """Poll the plain TCP port until kresd accepts a connection.

        Args:
            max_delay: Seconds after which no new attempt is started.
            delay_step: Unit of the randomized sleep between attempts.

        Returns:
            The result of the first ping that completed.

        Raises:
            RuntimeError: If no attempt succeeded before the deadline.
        """
        family = socket.AF_INET if self.ip else socket.AF_INET6
        # Resolve the destination up front so the timeout error below can
        # always name it, even if the loop body never ran.
        dest = self.socket_dest(family)
        attempt = 0
        end_time = time.time() + max_delay

        while time.time() < end_time:
            attempt += 1

            # Randomized back-off: sleep between 0 and (attempt - 1) steps, so
            # the upper bound grows by one step per attempt.
            rand_delay = random.randrange(0, attempt)
            time.sleep(rand_delay * delay_step)

            sock, dest = self.stream_socket(family, timeout=5)
            try:
                try:
                    sock.connect(dest)
                except ConnectionRefusedError:
                    continue  # kresd is not listening yet
                try:
                    return utils.try_ping_alive(sock, close=True, msgid=10000)
                except socket.timeout:
                    continue
            finally:
                # Close on every path so retries do not pile up open sockets.
                sock.close()
        raise RuntimeError("Kresd didn't start in time {}".format(dest))

    def socket_dest(self, family, tls=False):
        """Return the address tuple to connect to for ``family``.

        Args:
            family: ``socket.AF_INET`` or ``socket.AF_INET6``.
            tls: Use the TLS port instead of the plain one.

        Returns:
            ``(ip, port)`` for IPv4, ``(ip6, port, 0, 0)`` for IPv6.

        Raises:
            RuntimeError: For any other address family.
        """
        port = self.tls_port if tls else self.port
        if family == socket.AF_INET:
            return self.ip, port
        if family == socket.AF_INET6:
            return self.ip6, port, 0, 0
        raise RuntimeError("Unsupported socket family: {}".format(family))

    def stream_socket(self, family, tls=False, timeout=20):
        """Initialize a socket and return it along with the destination without connecting."""
        sock = socket.socket(family, socket.SOCK_STREAM)
        sock.settimeout(timeout)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        dest = self.socket_dest(family, tls)
        # Tracked so that teardown can close whatever callers left open.
        self.sockets.append(sock)
        return sock, dest

    def _tcp_socket(self, family):
        """Return a TCP socket connected to the plain port."""
        sock, dest = self.stream_socket(family)
        sock.connect(dest)
        return sock

    def ip_tcp_socket(self):
        """Return a connected IPv4 TCP socket."""
        return self._tcp_socket(socket.AF_INET)

    def ip6_tcp_socket(self):
        """Return a connected IPv6 TCP socket."""
        return self._tcp_socket(socket.AF_INET6)

    def _tls_socket(self, family):
        """Return a connected TLS socket, or ``None`` on the errno-0 failure.

        Raises:
            OSError: For any connect failure other than errno 0.
        """
        sock, dest = self.stream_socket(family, tls=True)
        ctx = utils.make_ssl_context(insecure=True)
        ssock = ctx.wrap_socket(sock)
        try:
            ssock.connect(dest)
        except OSError as exc:
            # ssock itself is not tracked in self.sockets, so close it here.
            ssock.close()
            if exc.errno == 0:  # sometimes happens shortly after startup
                return None
            # Any other failure is real; surface it instead of handing back
            # a socket that never connected.
            raise
        return ssock

    def _tls_socket_with_retry(self, family):
        """Like ``_tls_socket`` but retry once after a short pause.

        Raises:
            RuntimeError: If both attempts returned ``None``.
        """
        sock = self._tls_socket(family)
        if sock is None:
            time.sleep(0.1)
            sock = self._tls_socket(family)
            if sock is None:
                raise RuntimeError("Failed to create TLS socket!")
        return sock

    def ip_tls_socket(self):
        """Return a connected IPv4 TLS socket."""
        return self._tls_socket_with_retry(socket.AF_INET)

    def ip6_tls_socket(self):
        """Return a connected IPv6 TLS socket."""
        return self._tls_socket_with_retry(socket.AF_INET6)

    def partial_log(self):
        """Return the log written after the start-up phase.

        Lines are skipped until one matches ``KRESD_LOG_STARTUP_MSGID`` and a
        later one matches ``KRESD_LOG_IO_CLOSE``; everything after that is
        returned, prefixed with a note that the start was omitted.
        """
        chunks = ['\n (... omitting log start)\n']
        with open(self.logfile_path, encoding='UTF-8') as log:  # display partial log for debugging
            past_startup_msgid = False
            past_startup = False
            for line in log:
                if past_startup:
                    chunks.append(line)
                # find real start of test log (after initial alive-pings)
                elif not past_startup_msgid:
                    if re.match(KRESD_LOG_STARTUP_MSGID, line) is not None:
                        past_startup_msgid = True
                elif re.match(KRESD_LOG_IO_CLOSE, line) is not None:
                    past_startup = True
        return ''.join(chunks)
|
|
|
|
|
|
def is_port_free(port, ip=None, ip6=None):
    """Check whether ``port`` can be bound for both TCP and UDP.

    Args:
        port: Port number to probe.
        ip: IPv4 address to probe, or ``None`` to skip IPv4.
        ip6: IPv6 address to probe, or ``None`` to skip IPv6.

    Returns:
        ``False`` if any bind fails with "address already in use", ``True``
        otherwise (including when no address is given).

    Raises:
        OSError: For any bind failure other than "address already in use".
    """
    import errno  # local import: symbolic errno instead of the Linux-only literal 98

    def check(family, type_, dest):
        # The context manager closes the probe socket even when bind() fails.
        with socket.socket(family, type_) as sock:
            sock.bind(dest)

    try:
        if ip is not None:
            check(socket.AF_INET, socket.SOCK_STREAM, (ip, port))
            check(socket.AF_INET, socket.SOCK_DGRAM, (ip, port))
        if ip6 is not None:
            check(socket.AF_INET6, socket.SOCK_STREAM, (ip6, port, 0, 0))
            check(socket.AF_INET6, socket.SOCK_DGRAM, (ip6, port, 0, 0))
    except OSError as exc:
        if exc.errno == errno.EADDRINUSE:
            return False
        raise
    return True
|
|
|
|
|
|
def take_port(port, ip=None, ip6=None, timeout=0):
    """Reserve ``port`` for one kresd instance and wait until it is bindable.

    The reservation is a marker file named after the port inside
    ``KRESD_PORTDIR``.

    Args:
        port: Port number to reserve.
        ip: IPv4 address used for the bind check, or ``None``.
        ip6: IPv6 address used for the bind check, or ``None``.
        timeout: Seconds to keep retrying (in 5 second steps) while the port
            is still in use.

    Returns:
        ``port``.

    Raises:
        ValueError: If the marker file already exists, or the port is still
            in use once ``timeout`` has passed.
    """
    marker = Path(KRESD_PORTDIR, str(port))
    deadline = time.time() + timeout

    try:
        marker.touch(exist_ok=False)
    except FileExistsError as ex:
        raise ValueError(
            "Port {} already reserved by system or another kresd instance!".format(port)) from ex

    # NOTE: The marker file isn't removed, so other instances don't have to
    # attempt to take the same port again. This has the side effect of leaving
    # many of these files behind, because when another kresd shuts down and
    # removes its file, the port still can't be reserved for a while. This
    # shouldn't become an issue unless we have thousands of tests (and run out
    # of the port range).
    while not is_port_free(port, ip, ip6):
        if time.time() >= deadline:
            raise ValueError(
                "Port {} is reserved by system!".format(port))
        time.sleep(5)

    return port
|
|
|
|
|
|
def make_port(ip=None, ip6=None):
    """Pick a random test port and reserve it with ``take_port``.

    Up to 10 candidates from ``KRESD_TESTPORT_MIN``..``KRESD_TESTPORT_MAX``
    (inclusive) are tried.

    Returns:
        The reserved port number.

    Raises:
        RuntimeError: If none of the candidates could be reserved.
    """
    attempts_left = 10  # max attempts
    while attempts_left > 0:
        attempts_left -= 1
        candidate = random.randint(KRESD_TESTPORT_MIN, KRESD_TESTPORT_MAX)
        try:
            take_port(candidate, ip, ip6)
        except ValueError:
            # port reserved by system / another kresd instance
            pass
        else:
            return candidate

    raise RuntimeError("No available port found!")
|
|
|
|
|
|
# Matches a log line that starts with one bracketed tag, directly followed by
# a second bracket opening with the special start-up message ID.
KRESD_LOG_STARTUP_MSGID = re.compile(
    r'^\[[^]]+\]\[{}.*'.format(KRESD_STARTUP_MSGID))

# Matches an "[io ]" log line reporting that the peer closed the connection.
# NOTE(review): the whitespace inside the tag must match kresd's log output
# exactly -- verify the padding against a real log.
KRESD_LOG_IO_CLOSE = re.compile(
    r'^\[io \].*closed by peer.*')
|
|
|
|
|
|
@contextmanager
def make_kresd(workdir, certname=None, ip='127.0.0.1', ip6='::1', **kwargs):
    """Yield a running ``Kresd`` and print its partial log afterwards.

    Args:
        workdir: Working directory passed to ``Kresd``.
        certname: Certificate base name passed to ``Kresd``.
        ip: IPv4 address, defaults to the loopback address.
        ip6: IPv6 address, defaults to the loopback address.
        **kwargs: Any further ``Kresd`` keyword arguments.
    """
    instance = Kresd(workdir, ip=ip, ip6=ip6, certname=certname, **kwargs)
    with instance as kresd:
        yield kresd
        # Reached only when the caller's block finished without raising.
        print(kresd.partial_log())
|