summaryrefslogtreecommitdiffstats
path: root/test/modules/md/md_certs.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rwxr-xr-xtest/modules/md/md_certs.py444
1 files changed, 444 insertions, 0 deletions
diff --git a/test/modules/md/md_certs.py b/test/modules/md/md_certs.py
new file mode 100755
index 0000000..2501d25
--- /dev/null
+++ b/test/modules/md/md_certs.py
@@ -0,0 +1,444 @@
+import os
+import re
+from datetime import timedelta, datetime
+from typing import List, Any, Optional
+
+from cryptography import x509
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.asymmetric import ec, rsa
+from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
+from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
+from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption, load_pem_private_key
+from cryptography.x509 import ExtendedKeyUsageOID, NameOID
+
+
+EC_SUPPORTED = {}
+EC_SUPPORTED.update([(curve.name.upper(), curve) for curve in [
+ ec.SECP192R1,
+ ec.SECP224R1,
+ ec.SECP256R1,
+ ec.SECP384R1,
+]])
+
+
+def _private_key(key_type):
+ if isinstance(key_type, str):
+ key_type = key_type.upper()
+ m = re.match(r'^(RSA)?(\d+)$', key_type)
+ if m:
+ key_type = int(m.group(2))
+
+ if isinstance(key_type, int):
+ return rsa.generate_private_key(
+ public_exponent=65537,
+ key_size=key_type,
+ backend=default_backend()
+ )
+ if not isinstance(key_type, ec.EllipticCurve) and key_type in EC_SUPPORTED:
+ key_type = EC_SUPPORTED[key_type]
+ return ec.generate_private_key(
+ curve=key_type,
+ backend=default_backend()
+ )
+
+
+class CertificateSpec:
+
+ def __init__(self, name: str = None, domains: List[str] = None,
+ email: str = None,
+ key_type: str = None, single_file: bool = False,
+ valid_from: timedelta = timedelta(days=-1),
+ valid_to: timedelta = timedelta(days=89),
+ client: bool = False,
+ sub_specs: List['CertificateSpec'] = None):
+ self._name = name
+ self.domains = domains
+ self.client = client
+ self.email = email
+ self.key_type = key_type
+ self.single_file = single_file
+ self.valid_from = valid_from
+ self.valid_to = valid_to
+ self.sub_specs = sub_specs
+
+ @property
+ def name(self) -> Optional[str]:
+ if self._name:
+ return self._name
+ elif self.domains:
+ return self.domains[0]
+ return None
+
+
+class Credentials:
+
+ def __init__(self, name: str, cert: Any, pkey: Any):
+ self._name = name
+ self._cert = cert
+ self._pkey = pkey
+ self._cert_file = None
+ self._pkey_file = None
+ self._store = None
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ @property
+ def subject(self) -> x509.Name:
+ return self._cert.subject
+
+ @property
+ def key_type(self):
+ if isinstance(self._pkey, RSAPrivateKey):
+ return f"rsa{self._pkey.key_size}"
+ elif isinstance(self._pkey, EllipticCurvePrivateKey):
+ return f"{self._pkey.curve.name}"
+ else:
+ raise Exception(f"unknown key type: {self._pkey}")
+
+ @property
+ def private_key(self) -> Any:
+ return self._pkey
+
+ @property
+ def certificate(self) -> Any:
+ return self._cert
+
+ @property
+ def cert_pem(self) -> bytes:
+ return self._cert.public_bytes(Encoding.PEM)
+
+ @property
+ def pkey_pem(self) -> bytes:
+ return self._pkey.private_bytes(
+ Encoding.PEM,
+ PrivateFormat.TraditionalOpenSSL if self.key_type.startswith('rsa') else PrivateFormat.PKCS8,
+ NoEncryption())
+
+ def set_store(self, store: 'CertStore'):
+ self._store = store
+
+ def set_files(self, cert_file: str, pkey_file: str = None):
+ self._cert_file = cert_file
+ self._pkey_file = pkey_file
+
+ @property
+ def cert_file(self) -> str:
+ return self._cert_file
+
+ @property
+ def pkey_file(self) -> Optional[str]:
+ return self._pkey_file
+
+ def get_first(self, name) -> Optional['Credentials']:
+ creds = self._store.get_credentials_for_name(name) if self._store else []
+ return creds[0] if len(creds) else None
+
+ def get_credentials_for_name(self, name) -> List['Credentials']:
+ return self._store.get_credentials_for_name(name) if self._store else []
+
+ def issue_certs(self, specs: List[CertificateSpec],
+ chain: List['Credentials'] = None) -> List['Credentials']:
+ return [self.issue_cert(spec=spec, chain=chain) for spec in specs]
+
+ def issue_cert(self, spec: CertificateSpec, chain: List['Credentials'] = None) -> 'Credentials':
+ key_type = spec.key_type if spec.key_type else self.key_type
+ creds = self._store.load_credentials(name=spec.name, key_type=key_type, single_file=spec.single_file) \
+ if self._store else None
+ if creds is None:
+ creds = MDTestCA.create_credentials(spec=spec, issuer=self, key_type=key_type,
+ valid_from=spec.valid_from, valid_to=spec.valid_to)
+ if self._store:
+ self._store.save(creds, single_file=spec.single_file)
+
+ if spec.sub_specs:
+ if self._store:
+ sub_store = CertStore(fpath=os.path.join(self._store.path, creds.name))
+ creds.set_store(sub_store)
+ subchain = chain.copy() if chain else []
+ subchain.append(self)
+ creds.issue_certs(spec.sub_specs, chain=subchain)
+ return creds
+
+
+class CertStore:
+
+ def __init__(self, fpath: str):
+ self._store_dir = fpath
+ if not os.path.exists(self._store_dir):
+ os.makedirs(self._store_dir)
+ self._creds_by_name = {}
+
+ @property
+ def path(self) -> str:
+ return self._store_dir
+
+ def save(self, creds: Credentials, name: str = None,
+ chain: List[Credentials] = None,
+ single_file: bool = False) -> None:
+ name = name if name is not None else creds.name
+ cert_file = self.get_cert_file(name=name, key_type=creds.key_type)
+ pkey_file = self.get_pkey_file(name=name, key_type=creds.key_type)
+ if single_file:
+ pkey_file = None
+ with open(cert_file, "wb") as fd:
+ fd.write(creds.cert_pem)
+ if chain:
+ for c in chain:
+ fd.write(c.cert_pem)
+ if pkey_file is None:
+ fd.write(creds.pkey_pem)
+ if pkey_file is not None:
+ with open(pkey_file, "wb") as fd:
+ fd.write(creds.pkey_pem)
+ creds.set_files(cert_file, pkey_file)
+ self._add_credentials(name, creds)
+
+ def _add_credentials(self, name: str, creds: Credentials):
+ if name not in self._creds_by_name:
+ self._creds_by_name[name] = []
+ self._creds_by_name[name].append(creds)
+
+ def get_credentials_for_name(self, name) -> List[Credentials]:
+ return self._creds_by_name[name] if name in self._creds_by_name else []
+
+ def get_cert_file(self, name: str, key_type=None) -> str:
+ key_infix = ".{0}".format(key_type) if key_type is not None else ""
+ return os.path.join(self._store_dir, f'{name}{key_infix}.cert.pem')
+
+ def get_pkey_file(self, name: str, key_type=None) -> str:
+ key_infix = ".{0}".format(key_type) if key_type is not None else ""
+ return os.path.join(self._store_dir, f'{name}{key_infix}.pkey.pem')
+
+ def load_pem_cert(self, fpath: str) -> x509.Certificate:
+ with open(fpath) as fd:
+ return x509.load_pem_x509_certificate("".join(fd.readlines()).encode())
+
+ def load_pem_pkey(self, fpath: str):
+ with open(fpath) as fd:
+ return load_pem_private_key("".join(fd.readlines()).encode(), password=None)
+
+ def load_credentials(self, name: str, key_type=None, single_file: bool = False):
+ cert_file = self.get_cert_file(name=name, key_type=key_type)
+ pkey_file = cert_file if single_file else self.get_pkey_file(name=name, key_type=key_type)
+ if os.path.isfile(cert_file) and os.path.isfile(pkey_file):
+ cert = self.load_pem_cert(cert_file)
+ pkey = self.load_pem_pkey(pkey_file)
+ creds = Credentials(name=name, cert=cert, pkey=pkey)
+ creds.set_store(self)
+ creds.set_files(cert_file, pkey_file)
+ self._add_credentials(name, creds)
+ return creds
+ return None
+
+
+class MDTestCA:
+
+ @classmethod
+ def create_root(cls, name: str, store_dir: str, key_type: str = "rsa2048") -> Credentials:
+ store = CertStore(fpath=store_dir)
+ creds = store.load_credentials(name="ca", key_type=key_type)
+ if creds is None:
+ creds = MDTestCA._make_ca_credentials(name=name, key_type=key_type)
+ store.save(creds, name="ca")
+ creds.set_store(store)
+ return creds
+
+ @staticmethod
+ def create_credentials(spec: CertificateSpec, issuer: Credentials, key_type: Any,
+ valid_from: timedelta = timedelta(days=-1),
+ valid_to: timedelta = timedelta(days=89),
+ ) -> Credentials:
+ """Create a certificate signed by this CA for the given domains.
+ :returns: the certificate and private key PEM file paths
+ """
+ if spec.domains and len(spec.domains):
+ creds = MDTestCA._make_server_credentials(name=spec.name, domains=spec.domains,
+ issuer=issuer, valid_from=valid_from,
+ valid_to=valid_to, key_type=key_type)
+ elif spec.client:
+ creds = MDTestCA._make_client_credentials(name=spec.name, issuer=issuer,
+ email=spec.email, valid_from=valid_from,
+ valid_to=valid_to, key_type=key_type)
+ elif spec.name:
+ creds = MDTestCA._make_ca_credentials(name=spec.name, issuer=issuer,
+ valid_from=valid_from, valid_to=valid_to,
+ key_type=key_type)
+ else:
+ raise Exception(f"unrecognized certificate specification: {spec}")
+ return creds
+
+ @staticmethod
+ def _make_x509_name(org_name: str = None, common_name: str = None, parent: x509.Name = None) -> x509.Name:
+ name_pieces = []
+ if org_name:
+ oid = NameOID.ORGANIZATIONAL_UNIT_NAME if parent else NameOID.ORGANIZATION_NAME
+ name_pieces.append(x509.NameAttribute(oid, org_name))
+ elif common_name:
+ name_pieces.append(x509.NameAttribute(NameOID.COMMON_NAME, common_name))
+ if parent:
+ name_pieces.extend([rdn for rdn in parent])
+ return x509.Name(name_pieces)
+
+ @staticmethod
+ def _make_csr(
+ subject: x509.Name,
+ pkey: Any,
+ issuer_subject: Optional[Credentials],
+ valid_from_delta: timedelta = None,
+ valid_until_delta: timedelta = None
+ ):
+ pubkey = pkey.public_key()
+ issuer_subject = issuer_subject if issuer_subject is not None else subject
+
+ valid_from = datetime.now()
+ if valid_until_delta is not None:
+ valid_from += valid_from_delta
+ valid_until = datetime.now()
+ if valid_until_delta is not None:
+ valid_until += valid_until_delta
+
+ return (
+ x509.CertificateBuilder()
+ .subject_name(subject)
+ .issuer_name(issuer_subject)
+ .public_key(pubkey)
+ .not_valid_before(valid_from)
+ .not_valid_after(valid_until)
+ .serial_number(x509.random_serial_number())
+ .add_extension(
+ x509.SubjectKeyIdentifier.from_public_key(pubkey),
+ critical=False,
+ )
+ )
+
+ @staticmethod
+ def _add_ca_usages(csr: Any) -> Any:
+ return csr.add_extension(
+ x509.BasicConstraints(ca=True, path_length=9),
+ critical=True,
+ ).add_extension(
+ x509.KeyUsage(
+ digital_signature=True,
+ content_commitment=False,
+ key_encipherment=False,
+ data_encipherment=False,
+ key_agreement=False,
+ key_cert_sign=True,
+ crl_sign=True,
+ encipher_only=False,
+ decipher_only=False),
+ critical=True
+ ).add_extension(
+ x509.ExtendedKeyUsage([
+ ExtendedKeyUsageOID.CLIENT_AUTH,
+ ExtendedKeyUsageOID.SERVER_AUTH,
+ ExtendedKeyUsageOID.CODE_SIGNING,
+ ]),
+ critical=True
+ )
+
+ @staticmethod
+ def _add_leaf_usages(csr: Any, domains: List[str], issuer: Credentials) -> Any:
+ return csr.add_extension(
+ x509.BasicConstraints(ca=False, path_length=None),
+ critical=True,
+ ).add_extension(
+ x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
+ issuer.certificate.extensions.get_extension_for_class(
+ x509.SubjectKeyIdentifier).value),
+ critical=False
+ ).add_extension(
+ x509.SubjectAlternativeName([x509.DNSName(domain) for domain in domains]),
+ critical=True,
+ ).add_extension(
+ x509.ExtendedKeyUsage([
+ ExtendedKeyUsageOID.SERVER_AUTH,
+ ]),
+ critical=True
+ )
+
+ @staticmethod
+ def _add_client_usages(csr: Any, issuer: Credentials, rfc82name: str = None) -> Any:
+ cert = csr.add_extension(
+ x509.BasicConstraints(ca=False, path_length=None),
+ critical=True,
+ ).add_extension(
+ x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
+ issuer.certificate.extensions.get_extension_for_class(
+ x509.SubjectKeyIdentifier).value),
+ critical=False
+ )
+ if rfc82name:
+ cert.add_extension(
+ x509.SubjectAlternativeName([x509.RFC822Name(rfc82name)]),
+ critical=True,
+ )
+ cert.add_extension(
+ x509.ExtendedKeyUsage([
+ ExtendedKeyUsageOID.CLIENT_AUTH,
+ ]),
+ critical=True
+ )
+ return cert
+
+ @staticmethod
+ def _make_ca_credentials(name, key_type: Any,
+ issuer: Credentials = None,
+ valid_from: timedelta = timedelta(days=-1),
+ valid_to: timedelta = timedelta(days=89),
+ ) -> Credentials:
+ pkey = _private_key(key_type=key_type)
+ if issuer is not None:
+ issuer_subject = issuer.certificate.subject
+ issuer_key = issuer.private_key
+ else:
+ issuer_subject = None
+ issuer_key = pkey
+ subject = MDTestCA._make_x509_name(org_name=name, parent=issuer.subject if issuer else None)
+ csr = MDTestCA._make_csr(subject=subject,
+ issuer_subject=issuer_subject, pkey=pkey,
+ valid_from_delta=valid_from, valid_until_delta=valid_to)
+ csr = MDTestCA._add_ca_usages(csr)
+ cert = csr.sign(private_key=issuer_key,
+ algorithm=hashes.SHA256(),
+ backend=default_backend())
+ return Credentials(name=name, cert=cert, pkey=pkey)
+
+ @staticmethod
+ def _make_server_credentials(name: str, domains: List[str], issuer: Credentials,
+ key_type: Any,
+ valid_from: timedelta = timedelta(days=-1),
+ valid_to: timedelta = timedelta(days=89),
+ ) -> Credentials:
+ name = name
+ pkey = _private_key(key_type=key_type)
+ subject = MDTestCA._make_x509_name(common_name=name, parent=issuer.subject)
+ csr = MDTestCA._make_csr(subject=subject,
+ issuer_subject=issuer.certificate.subject, pkey=pkey,
+ valid_from_delta=valid_from, valid_until_delta=valid_to)
+ csr = MDTestCA._add_leaf_usages(csr, domains=domains, issuer=issuer)
+ cert = csr.sign(private_key=issuer.private_key,
+ algorithm=hashes.SHA256(),
+ backend=default_backend())
+ return Credentials(name=name, cert=cert, pkey=pkey)
+
+ @staticmethod
+ def _make_client_credentials(name: str,
+ issuer: Credentials, email: Optional[str],
+ key_type: Any,
+ valid_from: timedelta = timedelta(days=-1),
+ valid_to: timedelta = timedelta(days=89),
+ ) -> Credentials:
+ pkey = _private_key(key_type=key_type)
+ subject = MDTestCA._make_x509_name(common_name=name, parent=issuer.subject)
+ csr = MDTestCA._make_csr(subject=subject,
+ issuer_subject=issuer.certificate.subject, pkey=pkey,
+ valid_from_delta=valid_from, valid_until_delta=valid_to)
+ csr = MDTestCA._add_client_usages(csr, issuer=issuer, rfc82name=email)
+ cert = csr.sign(private_key=issuer.private_key,
+ algorithm=hashes.SHA256(),
+ backend=default_backend())
+ return Credentials(name=name, cert=cert, pkey=pkey)