diff options
Diffstat (limited to 'tests/pytests/kresd.py')
-rw-r--r-- | tests/pytests/kresd.py | 306 |
1 files changed, 306 insertions, 0 deletions
diff --git a/tests/pytests/kresd.py b/tests/pytests/kresd.py new file mode 100644 index 0000000..149efe9 --- /dev/null +++ b/tests/pytests/kresd.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +from collections import namedtuple +from contextlib import ContextDecorator, contextmanager +import os +from pathlib import Path +import random +import re +import shutil +import socket +import subprocess +import time + +import jinja2 + +import utils + + +PYTESTS_DIR = os.path.dirname(os.path.realpath(__file__)) +CERTS_DIR = os.path.join(PYTESTS_DIR, 'certs') +TEMPLATES_DIR = os.path.join(PYTESTS_DIR, 'templates') +KRESD_CONF_TEMPLATE = 'kresd.conf.j2' +KRESD_STARTUP_MSGID = 10005 # special unique ID at the start of the "test" log +KRESD_PORTDIR = '/tmp/pytest-kresd-portdir' +KRESD_TESTPORT_MIN = 1024 +KRESD_TESTPORT_MAX = 32768 # avoid overlap with docker ephemeral port range + + +def init_portdir(): + try: + shutil.rmtree(KRESD_PORTDIR) + except FileNotFoundError: + pass + os.makedirs(KRESD_PORTDIR) + + +def create_file_from_template(template_path, dest, data): + env = jinja2.Environment( + loader=jinja2.FileSystemLoader(TEMPLATES_DIR)) + template = env.get_template(template_path) + rendered_template = template.render(**data) + + with open(dest, "w", encoding='UTF-8') as fh: + fh.write(rendered_template) + + +Forward = namedtuple('Forward', ['proto', 'ip', 'port', 'hostname', 'ca_file']) + + +class Kresd(ContextDecorator): + def __init__( + self, workdir, port=None, tls_port=None, ip=None, ip6=None, certname=None, + verbose=True, hints=None, forward=None, policy_test_pass=False): + if ip is None and ip6 is None: + raise ValueError("IPv4 or IPv6 must be specified!") + self.workdir = str(workdir) + self.port = port + self.tls_port = tls_port + self.ip = ip + self.ip6 = ip6 + self.process = None + self.sockets = [] + self.logfile = None + self.verbose = verbose + self.hints = {} if hints is None else hints + self.forward = forward + self.policy_test_pass = policy_test_pass + + if certname: + self.tls_cert_path = os.path.join(CERTS_DIR, certname + '.cert.pem') + self.tls_key_path = os.path.join(CERTS_DIR, certname + '.key.pem') + else: + self.tls_cert_path = None + self.tls_key_path = None + + @property + def config_path(self): + return str(os.path.join(self.workdir, 'kresd.conf')) + + @property + def logfile_path(self): + return str(os.path.join(self.workdir, 'kresd.log')) + + def __enter__(self): + if self.port is not None: + take_port(self.port, self.ip, self.ip6, timeout=120) + else: + self.port = make_port(self.ip, self.ip6) + if self.tls_port is not None: + take_port(self.tls_port, self.ip, self.ip6, timeout=120) + else: + self.tls_port = make_port(self.ip, self.ip6) + + create_file_from_template(KRESD_CONF_TEMPLATE, self.config_path, {'kresd': self}) + self.logfile = open(self.logfile_path, 'w', encoding='UTF-8') + self.process = subprocess.Popen( + ['kresd', '-c', self.config_path, '-n', self.workdir], + stderr=self.logfile, env=os.environ.copy()) + + try: + self._wait_for_tcp_port() # wait for ports to be up and responding + if not self.all_ports_alive(msgid=10001): + raise RuntimeError("Kresd not listening on all ports") + + # issue special msgid to mark start of test log + sock = self.ip_tcp_socket() if self.ip else self.ip6_tcp_socket() + assert utils.try_ping_alive(sock, close=True, msgid=KRESD_STARTUP_MSGID) + + # sanity check - kresd didn't crash + self.process.poll() + if self.process.returncode is not None: + raise RuntimeError("Kresd crashed with returncode: {}".format( + self.process.returncode)) + except (RuntimeError, ConnectionError): # pylint: disable=try-except-raise + with open(self.logfile_path, encoding='UTF-8') as log: # print log for debugging + print(log.read()) + raise + + return self + + def __exit__(self, exc_type, exc_value, traceback): + try: + if not self.all_ports_alive(msgid=1006): + raise RuntimeError("Kresd crashed") + finally: + for sock in self.sockets: + sock.close() + self.process.terminate() + self.logfile.close() + Path(KRESD_PORTDIR, str(self.port)).unlink() + + def all_ports_alive(self, msgid=10001): + alive = True + if self.ip: + alive &= utils.try_ping_alive(self.ip_tcp_socket(), close=True, msgid=msgid) + alive &= utils.try_ping_alive(self.ip_tls_socket(), close=True, msgid=msgid + 1) + if self.ip6: + alive &= utils.try_ping_alive(self.ip6_tcp_socket(), close=True, msgid=msgid + 2) + alive &= utils.try_ping_alive(self.ip6_tls_socket(), close=True, msgid=msgid + 3) + return alive + + def _wait_for_tcp_port(self, max_delay=10, delay_step=0.2): + family = socket.AF_INET if self.ip else socket.AF_INET6 + i = 0 + end_time = time.time() + max_delay + + while time.time() < end_time: + i += 1 + + # use exponential backoff algorithm to choose next delay + rand_delay = random.randrange(0, i) + time.sleep(rand_delay * delay_step) + + try: + sock, dest = self.stream_socket(family, timeout=5) + sock.connect(dest) + except ConnectionRefusedError: + continue + else: + try: + return utils.try_ping_alive(sock, close=True, msgid=10000) + except socket.timeout: + continue + finally: + sock.close() + raise RuntimeError("Kresd didn't start in time {}".format(dest)) + + def socket_dest(self, family, tls=False): + port = self.tls_port if tls else self.port + if family == socket.AF_INET: + return self.ip, port + elif family == socket.AF_INET6: + return self.ip6, port, 0, 0 + raise RuntimeError("Unsupported socket family: {}".format(family)) + + def stream_socket(self, family, tls=False, timeout=20): + """Initialize a socket and return it along with the destination without connecting.""" + sock = socket.socket(family, socket.SOCK_STREAM) + sock.settimeout(timeout) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + dest = self.socket_dest(family, tls) + self.sockets.append(sock) + return sock, dest + + def _tcp_socket(self, family): + sock, dest = self.stream_socket(family) + sock.connect(dest) + return sock + + def ip_tcp_socket(self): + return self._tcp_socket(socket.AF_INET) + + def ip6_tcp_socket(self): + return self._tcp_socket(socket.AF_INET6) + + def _tls_socket(self, family): + sock, dest = self.stream_socket(family, tls=True) + ctx = utils.make_ssl_context(insecure=True) + ssock = ctx.wrap_socket(sock) + try: + ssock.connect(dest) + except OSError as exc: + if exc.errno == 0: # sometimes happens shortly after startup + return None + return ssock + + def _tls_socket_with_retry(self, family): + sock = self._tls_socket(family) + if sock is None: + time.sleep(0.1) + sock = self._tls_socket(family) + if sock is None: + raise RuntimeError("Failed to create TLS socket!") + return sock + + def ip_tls_socket(self): + return self._tls_socket_with_retry(socket.AF_INET) + + def ip6_tls_socket(self): + return self._tls_socket_with_retry(socket.AF_INET6) + + def partial_log(self): + partial_log = '\n (... omitting log start)\n' + with open(self.logfile_path, encoding='UTF-8') as log: # display partial log for debugging + past_startup_msgid = False + past_startup = False + for line in log: + if past_startup: + partial_log += line + else: # find real start of test log (after initial alive-pings) + if not past_startup_msgid: + if re.match(KRESD_LOG_STARTUP_MSGID, line) is not None: + past_startup_msgid = True + else: + if re.match(KRESD_LOG_IO_CLOSE, line) is not None: + past_startup = True + return partial_log + + +def is_port_free(port, ip=None, ip6=None): + def check(family, type_, dest): + sock = socket.socket(family, type_) + sock.bind(dest) + sock.close() + + try: + if ip is not None: + check(socket.AF_INET, socket.SOCK_STREAM, (ip, port)) + check(socket.AF_INET, socket.SOCK_DGRAM, (ip, port)) + if ip6 is not None: + check(socket.AF_INET6, socket.SOCK_STREAM, (ip6, port, 0, 0)) + check(socket.AF_INET6, socket.SOCK_DGRAM, (ip6, port, 0, 0)) + except OSError as exc: + if exc.errno == 98: # address already in use + return False + else: + raise + return True + + +def take_port(port, ip=None, ip6=None, timeout=0): + port_path = Path(KRESD_PORTDIR, str(port)) + end_time = time.time() + timeout + try: + port_path.touch(exist_ok=False) + except FileExistsError as ex: + raise ValueError( + "Port {} already reserved by system or another kresd instance!".format(port)) from ex + + while True: + if is_port_free(port, ip, ip6): + # NOTE: The port_path isn't removed, so other instances don't have to attempt to + # take the same port again. This has the side effect of leaving many of these + # files behind, because when another kresd shuts down and removes its file, the + # port still can't be reserved for a while. This shouldn't become an issue unless + # we have thousands of tests (and run out of the port range). + break + + if time.time() < end_time: + time.sleep(5) + else: + raise ValueError( + "Port {} is reserved by system!".format(port)) + return port + + +def make_port(ip=None, ip6=None): + for _ in range(10): # max attempts + port = random.randint(KRESD_TESTPORT_MIN, KRESD_TESTPORT_MAX) + try: + take_port(port, ip, ip6) + except ValueError: + continue # port reserved by system / another kresd instance + return port + raise RuntimeError("No available port found!") + + +KRESD_LOG_STARTUP_MSGID = re.compile(r'^\[[^]]+\]\[{}.*'.format(KRESD_STARTUP_MSGID)) +KRESD_LOG_IO_CLOSE = re.compile(r'^\[io \].*closed by peer.*') + + +@contextmanager +def make_kresd(workdir, certname=None, ip='127.0.0.1', ip6='::1', **kwargs): + with Kresd(workdir, ip=ip, ip6=ip6, certname=certname, **kwargs) as kresd: + yield kresd + print(kresd.partial_log()) |