summaryrefslogtreecommitdiffstats
path: root/test/modules/md/md_env.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rwxr-xr-xtest/modules/md/md_env.py613
1 files changed, 613 insertions, 0 deletions
diff --git a/test/modules/md/md_env.py b/test/modules/md/md_env.py
new file mode 100755
index 0000000..e8e36e5
--- /dev/null
+++ b/test/modules/md/md_env.py
@@ -0,0 +1,613 @@
+import copy
+import inspect
+import json
+import logging
+from configparser import ConfigParser, ExtendedInterpolation
+
+import pytest
+import re
+import os
+import shutil
+import subprocess
+import time
+
+from datetime import datetime, timedelta
+from typing import Dict, Optional
+
+from pyhttpd.certs import CertificateSpec
+from .md_cert_util import MDCertUtil
+from pyhttpd.env import HttpdTestSetup, HttpdTestEnv
+from pyhttpd.result import ExecResult
+
+log = logging.getLogger(__name__)
+
+
+class MDTestSetup(HttpdTestSetup):
+
+ def __init__(self, env: 'MDTestEnv'):
+ super().__init__(env=env)
+ self.mdenv = env
+ self.add_modules(["watchdog", "proxy_connect", "md"])
+
+ def make(self):
+ super().make()
+ if "pebble" == self.mdenv.acme_server:
+ self._make_pebble_conf()
+ self.mdenv.clear_store()
+
+ def _make_pebble_conf(self):
+ our_dir = os.path.dirname(inspect.getfile(MDTestSetup))
+ conf_src_dir = os.path.join(our_dir, 'pebble')
+ conf_dest_dir = os.path.join(self.env.gen_dir, 'pebble')
+ if not os.path.exists(conf_dest_dir):
+ os.makedirs(conf_dest_dir)
+ for name in os.listdir(conf_src_dir):
+ src_path = os.path.join(conf_src_dir, name)
+ m = re.match(r'(.+).template', name)
+ if m:
+ self._make_template(src_path, os.path.join(conf_dest_dir, m.group(1)))
+ elif os.path.isfile(src_path):
+ shutil.copy(src_path, os.path.join(conf_dest_dir, name))
+
+
+class MDTestEnv(HttpdTestEnv):
+
+ MD_S_UNKNOWN = 0
+ MD_S_INCOMPLETE = 1
+ MD_S_COMPLETE = 2
+ MD_S_EXPIRED = 3
+ MD_S_ERROR = 4
+
+ EMPTY_JOUT = {'status': 0, 'output': []}
+
+ DOMAIN_SUFFIX = "%d.org" % time.time()
+ LOG_FMT_TIGHT = '%(levelname)s: %(message)s'
+
+ @classmethod
+ def get_acme_server(cls):
+ return os.environ['ACME'] if 'ACME' in os.environ else "pebble"
+
+ @classmethod
+ def has_acme_server(cls):
+ return cls.get_acme_server() != 'none'
+
+ @classmethod
+ def has_acme_eab(cls):
+ return cls.get_acme_server() == 'pebble'
+
+ @classmethod
+ def is_pebble(cls) -> bool:
+ return cls.get_acme_server() == 'pebble'
+
+ @classmethod
+ def lacks_ocsp(cls):
+ return cls.is_pebble()
+
+ @classmethod
+ def has_a2md(cls):
+ d = os.path.dirname(inspect.getfile(HttpdTestEnv))
+ config = ConfigParser(interpolation=ExtendedInterpolation())
+ config.read(os.path.join(d, 'config.ini'))
+ bin_dir = config.get('global', 'bindir')
+ a2md_bin = os.path.join(bin_dir, 'a2md')
+ return os.path.isfile(a2md_bin)
+
+ def __init__(self, pytestconfig=None):
+ super().__init__(pytestconfig=pytestconfig)
+ self.add_httpd_log_modules(["md"])
+ self._acme_server = self.get_acme_server()
+ self._acme_tos = "accepted"
+ self._acme_ca_pemfile = os.path.join(self.gen_dir, "apache/acme-ca.pem")
+ if "pebble" == self._acme_server:
+ self._acme_url = "https://localhost:14000/dir"
+ self._acme_eab_url = "https://localhost:14001/dir"
+ elif "boulder" == self._acme_server:
+ self._acme_url = "http://localhost:4001/directory"
+ self._acme_eab_url = None
+ else:
+ raise Exception(f"unknown ACME server type: {self._acme_server}")
+ self._acme_server_down = False
+ self._acme_server_ok = False
+
+ self._a2md_bin = os.path.join(self.bin_dir, 'a2md')
+ self._default_domain = f"test1.{self.http_tld}"
+ self._tailscale_domain = "test.headless-chicken.ts.net"
+ self._store_dir = "./md"
+ self.set_store_dir_default()
+
+ self.add_cert_specs([
+ CertificateSpec(domains=[f"expired.{self._http_tld}"],
+ valid_from=timedelta(days=-100),
+ valid_to=timedelta(days=-10)),
+ CertificateSpec(domains=["localhost"], key_type='rsa2048'),
+ CertificateSpec(domains=[self._tailscale_domain]),
+ ])
+
+ def setup_httpd(self, setup: HttpdTestSetup = None):
+ super().setup_httpd(setup=MDTestSetup(env=self))
+
+ def set_store_dir_default(self):
+ dirpath = "md"
+ if self.httpd_is_at_least("2.5.0"):
+ dirpath = os.path.join("state", dirpath)
+ self.set_store_dir(dirpath)
+
+ def set_store_dir(self, dirpath):
+ self._store_dir = os.path.join(self.server_dir, dirpath)
+ if self.acme_url:
+ self.a2md_stdargs([self.a2md_bin, "-a", self.acme_url,
+ "-d", self._store_dir, "-C", self.acme_ca_pemfile, "-j"])
+ self.a2md_rawargs([self.a2md_bin, "-a", self.acme_url,
+ "-d", self._store_dir, "-C", self.acme_ca_pemfile])
+
+ def get_apxs_var(self, name: str) -> str:
+ p = subprocess.run([self._apxs, "-q", name], capture_output=True, text=True)
+ if p.returncode != 0:
+ return ""
+ return p.stdout.strip()
+
+ @property
+ def acme_server(self):
+ return self._acme_server
+
+ @property
+ def acme_url(self):
+ return self._acme_url
+
+ @property
+ def acme_tos(self):
+ return self._acme_tos
+
+ @property
+ def a2md_bin(self):
+ return self._a2md_bin
+
+ @property
+ def acme_ca_pemfile(self):
+ return self._acme_ca_pemfile
+
+ @property
+ def store_dir(self):
+ return self._store_dir
+
+ @property
+ def tailscale_domain(self):
+ return self._tailscale_domain
+
+ def get_request_domain(self, request):
+ name = request.node.originalname if request.node.originalname else request.node.name
+ return "%s-%s" % (re.sub(r'[_]', '-', name), MDTestEnv.DOMAIN_SUFFIX)
+
+ def get_method_domain(self, method):
+ return "%s-%s" % (re.sub(r'[_]', '-', method.__name__.lower()), MDTestEnv.DOMAIN_SUFFIX)
+
+ def get_module_domain(self, module):
+ return "%s-%s" % (re.sub(r'[_]', '-', module.__name__.lower()), MDTestEnv.DOMAIN_SUFFIX)
+
+ def get_class_domain(self, c):
+ return "%s-%s" % (re.sub(r'[_]', '-', c.__name__.lower()), MDTestEnv.DOMAIN_SUFFIX)
+
+ # --------- cmd execution ---------
+
+ _a2md_args = []
+ _a2md_args_raw = []
+
+ def a2md_stdargs(self, args):
+ self._a2md_args = [] + args
+
+ def a2md_rawargs(self, args):
+ self._a2md_args_raw = [] + args
+
+ def a2md(self, args, raw=False) -> ExecResult:
+ preargs = self._a2md_args
+ if raw:
+ preargs = self._a2md_args_raw
+ log.debug("running: {0} {1}".format(preargs, args))
+ return self.run(preargs + args)
+
+ def check_acme(self):
+ if self._acme_server_ok:
+ return True
+ if self._acme_server_down:
+ pytest.skip(msg="ACME server not running")
+ return False
+ if self.is_live(self.acme_url, timeout=timedelta(seconds=0.5)):
+ self._acme_server_ok = True
+ return True
+ else:
+ self._acme_server_down = True
+ pytest.fail(msg="ACME server not running", pytrace=False)
+ return False
+
+ def get_ca_pem_file(self, hostname: str) -> Optional[str]:
+ pem_file = super().get_ca_pem_file(hostname)
+ if pem_file is None:
+ pem_file = self.acme_ca_pemfile
+ return pem_file
+
+ # --------- access local store ---------
+
+ def purge_store(self):
+ log.debug("purge store dir: %s" % self._store_dir)
+ assert len(self._store_dir) > 1
+ if os.path.exists(self._store_dir):
+ shutil.rmtree(self._store_dir, ignore_errors=False)
+ os.makedirs(self._store_dir)
+
+ def clear_store(self):
+ log.debug("clear store dir: %s" % self._store_dir)
+ assert len(self._store_dir) > 1
+ if not os.path.exists(self._store_dir):
+ os.makedirs(self._store_dir)
+ for dirpath in ["challenges", "tmp", "archive", "domains", "accounts", "staging", "ocsp"]:
+ shutil.rmtree(os.path.join(self._store_dir, dirpath), ignore_errors=True)
+
+ def clear_ocsp_store(self):
+ assert len(self._store_dir) > 1
+ dirpath = os.path.join(self._store_dir, "ocsp")
+ log.debug("clear ocsp store dir: %s" % dir)
+ if os.path.exists(dirpath):
+ shutil.rmtree(dirpath, ignore_errors=True)
+
+ def authz_save(self, name, content):
+ dirpath = os.path.join(self._store_dir, 'staging', name)
+ os.makedirs(dirpath)
+ open(os.path.join(dirpath, 'authz.json'), "w").write(content)
+
+ def path_store_json(self):
+ return os.path.join(self._store_dir, 'md_store.json')
+
+ def path_account(self, acct):
+ return os.path.join(self._store_dir, 'accounts', acct, 'account.json')
+
+ def path_account_key(self, acct):
+ return os.path.join(self._store_dir, 'accounts', acct, 'account.pem')
+
+ def store_domains(self):
+ return os.path.join(self._store_dir, 'domains')
+
+ def store_archives(self):
+ return os.path.join(self._store_dir, 'archive')
+
+ def store_stagings(self):
+ return os.path.join(self._store_dir, 'staging')
+
+ def store_challenges(self):
+ return os.path.join(self._store_dir, 'challenges')
+
+ def store_domain_file(self, domain, filename):
+ return os.path.join(self.store_domains(), domain, filename)
+
+ def store_archived_file(self, domain, version, filename):
+ return os.path.join(self.store_archives(), "%s.%d" % (domain, version), filename)
+
+ def store_staged_file(self, domain, filename):
+ return os.path.join(self.store_stagings(), domain, filename)
+
+ def path_fallback_cert(self, domain):
+ return os.path.join(self._store_dir, 'domains', domain, 'fallback-pubcert.pem')
+
+ def path_job(self, domain):
+ return os.path.join(self._store_dir, 'staging', domain, 'job.json')
+
+ def replace_store(self, src):
+ shutil.rmtree(self._store_dir, ignore_errors=False)
+ shutil.copytree(src, self._store_dir)
+
+ def list_accounts(self):
+ return os.listdir(os.path.join(self._store_dir, 'accounts'))
+
+ def check_md(self, domain, md=None, state=-1, ca=None, protocol=None, agreement=None, contacts=None):
+ domains = None
+ if isinstance(domain, list):
+ domains = domain
+ domain = domains[0]
+ if md:
+ domain = md
+ path = self.store_domain_file(domain, 'md.json')
+ with open(path) as f:
+ md = json.load(f)
+ assert md
+ if domains:
+ assert md['domains'] == domains
+ if state >= 0:
+ assert md['state'] == state
+ if ca:
+ assert len(md['ca']['urls']) == 1
+ assert md['ca']['urls'][0] == ca
+ if protocol:
+ assert md['ca']['proto'] == protocol
+ if agreement:
+ assert md['ca']['agreement'] == agreement
+ if contacts:
+ assert md['contacts'] == contacts
+
+ def pkey_fname(self, pkeyspec=None):
+ if pkeyspec and not re.match(r'^rsa( ?\d+)?$', pkeyspec.lower()):
+ return "privkey.{0}.pem".format(pkeyspec.lower())
+ return 'privkey.pem'
+
+ def cert_fname(self, pkeyspec=None):
+ if pkeyspec and not re.match(r'^rsa( ?\d+)?$', pkeyspec.lower()):
+ return "pubcert.{0}.pem".format(pkeyspec.lower())
+ return 'pubcert.pem'
+
+ def check_md_complete(self, domain, pkey=None):
+ md = self.get_md_status(domain)
+ assert md
+ assert 'state' in md, "md is unexpected: {0}".format(md)
+ assert md['state'] is MDTestEnv.MD_S_COMPLETE, f"unexpected state: {md['state']}"
+ pkey_file = self.store_domain_file(domain, self.pkey_fname(pkey))
+ cert_file = self.store_domain_file(domain, self.cert_fname(pkey))
+ r = self.run(['ls', os.path.dirname(pkey_file)])
+ if not os.path.isfile(pkey_file):
+ assert False, f"pkey missing: {pkey_file}: {r.stdout}"
+ if not os.path.isfile(cert_file):
+ assert False, f"cert missing: {cert_file}: {r.stdout}"
+ return md
+
+ def check_md_credentials(self, domain):
+ if isinstance(domain, list):
+ domains = domain
+ domain = domains[0]
+ else:
+ domains = [domain]
+ # check private key, validate certificate, etc
+ MDCertUtil.validate_privkey(self.store_domain_file(domain, 'privkey.pem'))
+ cert = MDCertUtil(self.store_domain_file(domain, 'pubcert.pem'))
+ cert.validate_cert_matches_priv_key(self.store_domain_file(domain, 'privkey.pem'))
+ # check SANs and CN
+ assert cert.get_cn() == domain
+ # compare lists twice in opposite directions: SAN may not respect ordering
+ san_list = list(cert.get_san_list())
+ assert len(san_list) == len(domains)
+ assert set(san_list).issubset(domains)
+ assert set(domains).issubset(san_list)
+ # check valid dates interval
+ not_before = cert.get_not_before()
+ not_after = cert.get_not_after()
+ assert not_before < datetime.now(not_before.tzinfo)
+ assert not_after > datetime.now(not_after.tzinfo)
+
+ # --------- check utilities ---------
+
+ def check_json_contains(self, actual, expected):
+ # write all expected key:value bindings to a copy of the actual data ...
+ # ... assert it stays unchanged
+ test_json = copy.deepcopy(actual)
+ test_json.update(expected)
+ assert actual == test_json
+
+ def check_file_access(self, path, exp_mask):
+ actual_mask = os.lstat(path).st_mode & 0o777
+ assert oct(actual_mask) == oct(exp_mask)
+
+ def check_dir_empty(self, path):
+ assert os.listdir(path) == []
+
+ def get_http_status(self, domain, path, use_https=True):
+ r = self.get_meta(domain, path, use_https, insecure=True)
+ return r.response['status']
+
+ def get_cert(self, domain, tls=None, ciphers=None):
+ return MDCertUtil.load_server_cert(self._httpd_addr, self.https_port,
+ domain, tls=tls, ciphers=ciphers)
+
+ def get_server_cert(self, domain, proto=None, ciphers=None):
+ args = [
+ "openssl", "s_client", "-status",
+ "-connect", "%s:%s" % (self._httpd_addr, self.https_port),
+ "-CAfile", self.acme_ca_pemfile,
+ "-servername", domain,
+ "-showcerts"
+ ]
+ if proto is not None:
+ args.extend(["-{0}".format(proto)])
+ if ciphers is not None:
+ args.extend(["-cipher", ciphers])
+ r = self.run(args)
+ # noinspection PyBroadException
+ try:
+ return MDCertUtil.parse_pem_cert(r.stdout)
+ except:
+ return None
+
+ def verify_cert_key_lenghts(self, domain, pkeys):
+ for p in pkeys:
+ cert = self.get_server_cert(domain, proto="tls1_2", ciphers=p['ciphers'])
+ if 0 == p['keylen']:
+ assert cert is None
+ else:
+ assert cert, "no cert returned for cipher: {0}".format(p['ciphers'])
+ assert cert.get_key_length() == p['keylen'], "key length, expected {0}, got {1}".format(
+ p['keylen'], cert.get_key_length()
+ )
+
+ def get_meta(self, domain, path, use_https=True, insecure=False):
+ schema = "https" if use_https else "http"
+ port = self.https_port if use_https else self.http_port
+ r = self.curl_get(f"{schema}://{domain}:{port}{path}", insecure=insecure)
+ assert r.exit_code == 0
+ assert r.response
+ assert r.response['header']
+ return r
+
+ def get_content(self, domain, path, use_https=True):
+ schema = "https" if use_https else "http"
+ port = self.https_port if use_https else self.http_port
+ r = self.curl_get(f"{schema}://{domain}:{port}{path}")
+ assert r.exit_code == 0
+ return r.stdout
+
+ def get_json_content(self, domain, path, use_https=True, insecure=False):
+ schema = "https" if use_https else "http"
+ port = self.https_port if use_https else self.http_port
+ url = f"{schema}://{domain}:{port}{path}"
+ r = self.curl_get(url, insecure=insecure)
+ if r.exit_code != 0:
+ log.error(f"curl get on {url} returned {r.exit_code}"
+ f"\nstdout: {r.stdout}"
+ f"\nstderr: {r.stderr}")
+ assert r.exit_code == 0, r.stderr
+ return r.json
+
+ def get_certificate_status(self, domain) -> Dict:
+ return self.get_json_content(domain, "/.httpd/certificate-status", insecure=True)
+
+ def get_md_status(self, domain, via_domain=None, use_https=True) -> Dict:
+ if via_domain is None:
+ via_domain = self._default_domain
+ return self.get_json_content(via_domain, f"/md-status/{domain}",
+ use_https=use_https)
+
+ def get_server_status(self, query="/", via_domain=None, use_https=True):
+ if via_domain is None:
+ via_domain = self._default_domain
+ return self.get_content(via_domain, "/server-status%s" % query, use_https=use_https)
+
+ def await_completion(self, names, must_renew=False, restart=True, timeout=60,
+ via_domain=None, use_https=True):
+ try_until = time.time() + timeout
+ renewals = {}
+ names = names.copy()
+ while len(names) > 0:
+ if time.time() >= try_until:
+ return False
+ for name in names:
+ mds = self.get_md_status(name, via_domain=via_domain, use_https=use_https)
+ if mds is None:
+ log.debug("not managed by md: %s" % name)
+ return False
+
+ if 'renewal' in mds:
+ renewal = mds['renewal']
+ renewals[name] = True
+ if 'finished' in renewal and renewal['finished'] is True:
+ if (not must_renew) or (name in renewals):
+ log.debug(f"domain cert was renewed: {name}")
+ names.remove(name)
+
+ if len(names) != 0:
+ time.sleep(0.1)
+ if restart:
+ time.sleep(0.1)
+ return self.apache_restart() == 0
+ return True
+
+ def is_renewing(self, name):
+ stat = self.get_certificate_status(name)
+ return 'renewal' in stat
+
+ def await_renewal(self, names, timeout=60):
+ try_until = time.time() + timeout
+ while len(names) > 0:
+ if time.time() >= try_until:
+ return False
+ for name in names:
+ md = self.get_md_status(name)
+ if md is None:
+ log.debug("not managed by md: %s" % name)
+ return False
+
+ if 'renewal' in md:
+ names.remove(name)
+
+ if len(names) != 0:
+ time.sleep(0.1)
+ return True
+
+ def await_error(self, domain, timeout=60, via_domain=None, use_https=True, errors=1):
+ try_until = time.time() + timeout
+ while True:
+ if time.time() >= try_until:
+ return False
+ md = self.get_md_status(domain, via_domain=via_domain, use_https=use_https)
+ if md:
+ if 'state' in md and md['state'] == MDTestEnv.MD_S_ERROR:
+ return md
+ if 'renewal' in md and 'errors' in md['renewal'] \
+ and md['renewal']['errors'] >= errors:
+ return md
+ time.sleep(0.1)
+
+ def await_file(self, fpath, timeout=60):
+ try_until = time.time() + timeout
+ while True:
+ if time.time() >= try_until:
+ return False
+ if os.path.isfile(fpath):
+ return True
+ time.sleep(0.1)
+
+ def check_file_permissions(self, domain):
+ dpath = os.path.join(self.store_dir, 'domains', domain)
+ assert os.path.isdir(dpath)
+ md = json.load(open(os.path.join(dpath, 'md.json')))
+ assert md
+ acct = md['ca']['account']
+ assert acct
+ self.check_file_access(self.path_store_json(), 0o600)
+ # domains
+ self.check_file_access(self.store_domains(), 0o700)
+ self.check_file_access(os.path.join(self.store_domains(), domain), 0o700)
+ self.check_file_access(self.store_domain_file(domain, 'privkey.pem'), 0o600)
+ self.check_file_access(self.store_domain_file(domain, 'pubcert.pem'), 0o600)
+ self.check_file_access(self.store_domain_file(domain, 'md.json'), 0o600)
+ # archive
+ self.check_file_access(self.store_archived_file(domain, 1, 'md.json'), 0o600)
+ # accounts
+ self.check_file_access(os.path.join(self._store_dir, 'accounts'), 0o755)
+ self.check_file_access(os.path.join(self._store_dir, 'accounts', acct), 0o755)
+ self.check_file_access(self.path_account(acct), 0o644)
+ self.check_file_access(self.path_account_key(acct), 0o644)
+ # staging
+ self.check_file_access(self.store_stagings(), 0o755)
+
+ def get_ocsp_status(self, domain, proto=None, cipher=None, ca_file=None):
+ stat = {}
+ args = [
+ "openssl", "s_client", "-status",
+ "-connect", "%s:%s" % (self._httpd_addr, self.https_port),
+ "-CAfile", ca_file if ca_file else self.acme_ca_pemfile,
+ "-servername", domain,
+ "-showcerts"
+ ]
+ if proto is not None:
+ args.extend(["-{0}".format(proto)])
+ if cipher is not None:
+ args.extend(["-cipher", cipher])
+ r = self.run(args, debug_log=False)
+ ocsp_regex = re.compile(r'OCSP response: +([^=\n]+)\n')
+ matches = ocsp_regex.finditer(r.stdout)
+ for m in matches:
+ if m.group(1) != "":
+ stat['ocsp'] = m.group(1)
+ if 'ocsp' not in stat:
+ ocsp_regex = re.compile(r'OCSP Response Status:\s*(.+)')
+ matches = ocsp_regex.finditer(r.stdout)
+ for m in matches:
+ if m.group(1) != "":
+ stat['ocsp'] = m.group(1)
+ verify_regex = re.compile(r'Verify return code:\s*(.+)')
+ matches = verify_regex.finditer(r.stdout)
+ for m in matches:
+ if m.group(1) != "":
+ stat['verify'] = m.group(1)
+ return stat
+
+ def await_ocsp_status(self, domain, timeout=10, ca_file=None):
+ try_until = time.time() + timeout
+ while True:
+ if time.time() >= try_until:
+ break
+ stat = self.get_ocsp_status(domain, ca_file=ca_file)
+ if 'ocsp' in stat and stat['ocsp'] != "no response sent":
+ return stat
+ time.sleep(0.1)
+ raise TimeoutError(f"ocsp respopnse not available: {domain}")
+
+ def create_self_signed_cert(self, name_list, valid_days, serial=1000, path=None):
+ dirpath = path
+ if not path:
+ dirpath = os.path.join(self.store_domains(), name_list[0])
+ return MDCertUtil.create_self_signed_cert(dirpath, name_list, valid_days, serial)