diff options
Diffstat (limited to 'test/modules/md/md_cert_util.py')
-rwxr-xr-x | test/modules/md/md_cert_util.py | 239 |
1 files changed, 239 insertions, 0 deletions
diff --git a/test/modules/md/md_cert_util.py b/test/modules/md/md_cert_util.py new file mode 100755 index 0000000..8cd99aa --- /dev/null +++ b/test/modules/md/md_cert_util.py @@ -0,0 +1,239 @@ +import logging +import re +import os +import socket +import OpenSSL +import time +import sys + +from datetime import datetime +from datetime import tzinfo +from datetime import timedelta +from http.client import HTTPConnection +from urllib.parse import urlparse + + +SEC_PER_DAY = 24 * 60 * 60 + + +log = logging.getLogger(__name__) + + +class MDCertUtil(object): + # Utility class for inspecting certificates in test cases + # Uses PyOpenSSL: https://pyopenssl.org/en/stable/index.html + + @classmethod + def create_self_signed_cert(cls, path, name_list, valid_days, serial=1000): + domain = name_list[0] + if not os.path.exists(path): + os.makedirs(path) + + cert_file = os.path.join(path, 'pubcert.pem') + pkey_file = os.path.join(path, 'privkey.pem') + # create a key pair + if os.path.exists(pkey_file): + key_buffer = open(pkey_file, 'rt').read() + k = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, key_buffer) + else: + k = OpenSSL.crypto.PKey() + k.generate_key(OpenSSL.crypto.TYPE_RSA, 2048) + + # create a self-signed cert + cert = OpenSSL.crypto.X509() + cert.get_subject().C = "DE" + cert.get_subject().ST = "NRW" + cert.get_subject().L = "Muenster" + cert.get_subject().O = "greenbytes GmbH" + cert.get_subject().CN = domain + cert.set_serial_number(serial) + cert.gmtime_adj_notBefore(valid_days["notBefore"] * SEC_PER_DAY) + cert.gmtime_adj_notAfter(valid_days["notAfter"] * SEC_PER_DAY) + cert.set_issuer(cert.get_subject()) + + cert.add_extensions([OpenSSL.crypto.X509Extension( + b"subjectAltName", False, b", ".join(map(lambda n: b"DNS:" + n.encode(), name_list)) + )]) + cert.set_pubkey(k) + cert.sign(k, 'sha1') + + open(cert_file, "wt").write( + OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert).decode('utf-8')) + open(pkey_file, "wt").write( + OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, k).decode('utf-8')) + + @classmethod + def load_server_cert(cls, host_ip, host_port, host_name, tls=None, ciphers=None): + ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) + if tls is not None and tls != 1.0: + ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1) + if tls is not None and tls != 1.1: + ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1_1) + if tls is not None and tls != 1.2: + ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1_2) + if tls is not None and tls != 1.3 and hasattr(OpenSSL.SSL, "OP_NO_TLSv1_3"): + ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1_3) + if ciphers is not None: + ctx.set_cipher_list(ciphers) + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + connection = OpenSSL.SSL.Connection(ctx, s) + connection.connect((host_ip, int(host_port))) + connection.setblocking(1) + connection.set_tlsext_host_name(host_name.encode('utf-8')) + connection.do_handshake() + peer_cert = connection.get_peer_certificate() + return MDCertUtil(None, cert=peer_cert) + + @classmethod + def parse_pem_cert(cls, text): + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, text.encode('utf-8')) + return MDCertUtil(None, cert=cert) + + @classmethod + def get_plain(cls, url, timeout): + server = urlparse(url) + try_until = time.time() + timeout + while time.time() < try_until: + # noinspection PyBroadException + try: + c = HTTPConnection(server.hostname, server.port, timeout=timeout) + c.request('GET', server.path) + resp = c.getresponse() + data = resp.read() + c.close() + return data + except IOError: + log.debug("connect error:", sys.exc_info()[0]) + time.sleep(.1) + except: + log.error("Unexpected error:", sys.exc_info()[0]) + log.error("Unable to contact server after %d sec" % timeout) + return None + + def __init__(self, cert_path, cert=None): + if cert_path is not None: + self.cert_path = cert_path + # load certificate and private key + if cert_path.startswith("http"): + cert_data = self.get_plain(cert_path, 1) + else: + cert_data = MDCertUtil._load_binary_file(cert_path) + + for file_type in (OpenSSL.crypto.FILETYPE_PEM, OpenSSL.crypto.FILETYPE_ASN1): + try: + self.cert = OpenSSL.crypto.load_certificate(file_type, cert_data) + except Exception as error: + self.error = error + if cert is not None: + self.cert = cert + + if self.cert is None: + raise self.error + + def get_issuer(self): + return self.cert.get_issuer() + + def get_serial(self): + # the string representation of a serial number is not unique. Some + # add leading 0s to align with word boundaries. + return ("%lx" % (self.cert.get_serial_number())).upper() + + def same_serial_as(self, other): + if isinstance(other, MDCertUtil): + return self.cert.get_serial_number() == other.cert.get_serial_number() + elif isinstance(other, OpenSSL.crypto.X509): + return self.cert.get_serial_number() == other.get_serial_number() + elif isinstance(other, str): + # assume a hex number + return self.cert.get_serial_number() == int(other, 16) + elif isinstance(other, int): + return self.cert.get_serial_number() == other + return False + + def get_not_before(self): + tsp = self.cert.get_notBefore() + return self._parse_tsp(tsp) + + def get_not_after(self): + tsp = self.cert.get_notAfter() + return self._parse_tsp(tsp) + + def get_cn(self): + return self.cert.get_subject().CN + + def get_key_length(self): + return self.cert.get_pubkey().bits() + + def get_san_list(self): + text = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_TEXT, self.cert).decode("utf-8") + m = re.search(r"X509v3 Subject Alternative Name:\s*(.*)", text) + sans_list = [] + if m: + sans_list = m.group(1).split(",") + + def _strip_prefix(s): + return s.split(":")[1] if s.strip().startswith("DNS:") else s.strip() + return list(map(_strip_prefix, sans_list)) + + def get_must_staple(self): + text = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_TEXT, self.cert).decode("utf-8") + m = re.search(r"1.3.6.1.5.5.7.1.24:\s*\n\s*0....", text) + if not m: + # Newer openssl versions print this differently + m = re.search(r"TLS Feature:\s*\n\s*status_request\s*\n", text) + return m is not None + + @classmethod + def validate_privkey(cls, privkey_path, passphrase=None): + privkey_data = cls._load_binary_file(privkey_path) + if passphrase: + privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, privkey_data, passphrase) + else: + privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, privkey_data) + return privkey.check() + + def validate_cert_matches_priv_key(self, privkey_path): + # Verifies that the private key and cert match. + privkey_data = MDCertUtil._load_binary_file(privkey_path) + privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, privkey_data) + context = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) + context.use_privatekey(privkey) + context.use_certificate(self.cert) + context.check_privatekey() + + # --------- _utils_ --------- + + def astr(self, s): + return s.decode('utf-8') + + def _parse_tsp(self, tsp): + # timestampss returned by PyOpenSSL are bytes + # parse date and time part + s = ("%s-%s-%s %s:%s:%s" % (self.astr(tsp[0:4]), self.astr(tsp[4:6]), self.astr(tsp[6:8]), + self.astr(tsp[8:10]), self.astr(tsp[10:12]), self.astr(tsp[12:14]))) + timestamp = datetime.strptime(s, '%Y-%m-%d %H:%M:%S') + # adjust timezone + tz_h, tz_m = 0, 0 + m = re.match(r"([+\-]\d{2})(\d{2})", self.astr(tsp[14:])) + if m: + tz_h, tz_m = int(m.group(1)), int(m.group(2)) if tz_h > 0 else -1 * int(m.group(2)) + return timestamp.replace(tzinfo=self.FixedOffset(60 * tz_h + tz_m)) + + @classmethod + def _load_binary_file(cls, path): + with open(path, mode="rb") as file: + return file.read() + + class FixedOffset(tzinfo): + + def __init__(self, offset): + self.__offset = timedelta(minutes=offset) + + def utcoffset(self, dt): + return self.__offset + + def tzname(self, dt): + return None + + def dst(self, dt): + return timedelta(0) |