summaryrefslogtreecommitdiffstats
path: root/test/modules/md/md_cert_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/modules/md/md_cert_util.py')
-rwxr-xr-xtest/modules/md/md_cert_util.py239
1 files changed, 239 insertions, 0 deletions
diff --git a/test/modules/md/md_cert_util.py b/test/modules/md/md_cert_util.py
new file mode 100755
index 0000000..8cd99aa
--- /dev/null
+++ b/test/modules/md/md_cert_util.py
@@ -0,0 +1,239 @@
+import logging
+import re
+import os
+import socket
+import OpenSSL
+import time
+import sys
+
+from datetime import datetime
+from datetime import tzinfo
+from datetime import timedelta
+from http.client import HTTPConnection
+from urllib.parse import urlparse
+
+
+SEC_PER_DAY = 24 * 60 * 60
+
+
+log = logging.getLogger(__name__)
+
+
+class MDCertUtil(object):
+ # Utility class for inspecting certificates in test cases
+ # Uses PyOpenSSL: https://pyopenssl.org/en/stable/index.html
+
+ @classmethod
+ def create_self_signed_cert(cls, path, name_list, valid_days, serial=1000):
+ domain = name_list[0]
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ cert_file = os.path.join(path, 'pubcert.pem')
+ pkey_file = os.path.join(path, 'privkey.pem')
+ # create a key pair
+ if os.path.exists(pkey_file):
+ key_buffer = open(pkey_file, 'rt').read()
+ k = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, key_buffer)
+ else:
+ k = OpenSSL.crypto.PKey()
+ k.generate_key(OpenSSL.crypto.TYPE_RSA, 2048)
+
+ # create a self-signed cert
+ cert = OpenSSL.crypto.X509()
+ cert.get_subject().C = "DE"
+ cert.get_subject().ST = "NRW"
+ cert.get_subject().L = "Muenster"
+ cert.get_subject().O = "greenbytes GmbH"
+ cert.get_subject().CN = domain
+ cert.set_serial_number(serial)
+ cert.gmtime_adj_notBefore(valid_days["notBefore"] * SEC_PER_DAY)
+ cert.gmtime_adj_notAfter(valid_days["notAfter"] * SEC_PER_DAY)
+ cert.set_issuer(cert.get_subject())
+
+ cert.add_extensions([OpenSSL.crypto.X509Extension(
+ b"subjectAltName", False, b", ".join(map(lambda n: b"DNS:" + n.encode(), name_list))
+ )])
+ cert.set_pubkey(k)
+ cert.sign(k, 'sha1')
+
+ open(cert_file, "wt").write(
+ OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert).decode('utf-8'))
+ open(pkey_file, "wt").write(
+ OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, k).decode('utf-8'))
+
+ @classmethod
+ def load_server_cert(cls, host_ip, host_port, host_name, tls=None, ciphers=None):
+ ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
+ if tls is not None and tls != 1.0:
+ ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1)
+ if tls is not None and tls != 1.1:
+ ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1_1)
+ if tls is not None and tls != 1.2:
+ ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1_2)
+ if tls is not None and tls != 1.3 and hasattr(OpenSSL.SSL, "OP_NO_TLSv1_3"):
+ ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1_3)
+ if ciphers is not None:
+ ctx.set_cipher_list(ciphers)
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ connection = OpenSSL.SSL.Connection(ctx, s)
+ connection.connect((host_ip, int(host_port)))
+ connection.setblocking(1)
+ connection.set_tlsext_host_name(host_name.encode('utf-8'))
+ connection.do_handshake()
+ peer_cert = connection.get_peer_certificate()
+ return MDCertUtil(None, cert=peer_cert)
+
+ @classmethod
+ def parse_pem_cert(cls, text):
+ cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, text.encode('utf-8'))
+ return MDCertUtil(None, cert=cert)
+
+ @classmethod
+ def get_plain(cls, url, timeout):
+ server = urlparse(url)
+ try_until = time.time() + timeout
+ while time.time() < try_until:
+ # noinspection PyBroadException
+ try:
+ c = HTTPConnection(server.hostname, server.port, timeout=timeout)
+ c.request('GET', server.path)
+ resp = c.getresponse()
+ data = resp.read()
+ c.close()
+ return data
+ except IOError:
+ log.debug("connect error:", sys.exc_info()[0])
+ time.sleep(.1)
+ except:
+ log.error("Unexpected error:", sys.exc_info()[0])
+ log.error("Unable to contact server after %d sec" % timeout)
+ return None
+
+ def __init__(self, cert_path, cert=None):
+ if cert_path is not None:
+ self.cert_path = cert_path
+ # load certificate and private key
+ if cert_path.startswith("http"):
+ cert_data = self.get_plain(cert_path, 1)
+ else:
+ cert_data = MDCertUtil._load_binary_file(cert_path)
+
+ for file_type in (OpenSSL.crypto.FILETYPE_PEM, OpenSSL.crypto.FILETYPE_ASN1):
+ try:
+ self.cert = OpenSSL.crypto.load_certificate(file_type, cert_data)
+ except Exception as error:
+ self.error = error
+ if cert is not None:
+ self.cert = cert
+
+ if self.cert is None:
+ raise self.error
+
+ def get_issuer(self):
+ return self.cert.get_issuer()
+
+ def get_serial(self):
+ # the string representation of a serial number is not unique. Some
+ # add leading 0s to align with word boundaries.
+ return ("%lx" % (self.cert.get_serial_number())).upper()
+
+ def same_serial_as(self, other):
+ if isinstance(other, MDCertUtil):
+ return self.cert.get_serial_number() == other.cert.get_serial_number()
+ elif isinstance(other, OpenSSL.crypto.X509):
+ return self.cert.get_serial_number() == other.get_serial_number()
+ elif isinstance(other, str):
+ # assume a hex number
+ return self.cert.get_serial_number() == int(other, 16)
+ elif isinstance(other, int):
+ return self.cert.get_serial_number() == other
+ return False
+
+ def get_not_before(self):
+ tsp = self.cert.get_notBefore()
+ return self._parse_tsp(tsp)
+
+ def get_not_after(self):
+ tsp = self.cert.get_notAfter()
+ return self._parse_tsp(tsp)
+
+ def get_cn(self):
+ return self.cert.get_subject().CN
+
+ def get_key_length(self):
+ return self.cert.get_pubkey().bits()
+
+ def get_san_list(self):
+ text = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_TEXT, self.cert).decode("utf-8")
+ m = re.search(r"X509v3 Subject Alternative Name:\s*(.*)", text)
+ sans_list = []
+ if m:
+ sans_list = m.group(1).split(",")
+
+ def _strip_prefix(s):
+ return s.split(":")[1] if s.strip().startswith("DNS:") else s.strip()
+ return list(map(_strip_prefix, sans_list))
+
+ def get_must_staple(self):
+ text = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_TEXT, self.cert).decode("utf-8")
+ m = re.search(r"1.3.6.1.5.5.7.1.24:\s*\n\s*0....", text)
+ if not m:
+ # Newer openssl versions print this differently
+ m = re.search(r"TLS Feature:\s*\n\s*status_request\s*\n", text)
+ return m is not None
+
+ @classmethod
+ def validate_privkey(cls, privkey_path, passphrase=None):
+ privkey_data = cls._load_binary_file(privkey_path)
+ if passphrase:
+ privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, privkey_data, passphrase)
+ else:
+ privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, privkey_data)
+ return privkey.check()
+
+ def validate_cert_matches_priv_key(self, privkey_path):
+ # Verifies that the private key and cert match.
+ privkey_data = MDCertUtil._load_binary_file(privkey_path)
+ privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, privkey_data)
+ context = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
+ context.use_privatekey(privkey)
+ context.use_certificate(self.cert)
+ context.check_privatekey()
+
+ # --------- _utils_ ---------
+
+ def astr(self, s):
+ return s.decode('utf-8')
+
+ def _parse_tsp(self, tsp):
+ # timestampss returned by PyOpenSSL are bytes
+ # parse date and time part
+ s = ("%s-%s-%s %s:%s:%s" % (self.astr(tsp[0:4]), self.astr(tsp[4:6]), self.astr(tsp[6:8]),
+ self.astr(tsp[8:10]), self.astr(tsp[10:12]), self.astr(tsp[12:14])))
+ timestamp = datetime.strptime(s, '%Y-%m-%d %H:%M:%S')
+ # adjust timezone
+ tz_h, tz_m = 0, 0
+ m = re.match(r"([+\-]\d{2})(\d{2})", self.astr(tsp[14:]))
+ if m:
+ tz_h, tz_m = int(m.group(1)), int(m.group(2)) if tz_h > 0 else -1 * int(m.group(2))
+ return timestamp.replace(tzinfo=self.FixedOffset(60 * tz_h + tz_m))
+
+ @classmethod
+ def _load_binary_file(cls, path):
+ with open(path, mode="rb") as file:
+ return file.read()
+
+ class FixedOffset(tzinfo):
+
+ def __init__(self, offset):
+ self.__offset = timedelta(minutes=offset)
+
+ def utcoffset(self, dt):
+ return self.__offset
+
+ def tzname(self, dt):
+ return None
+
+ def dst(self, dt):
+ return timedelta(0)