import os
import re
from datetime import timedelta, datetime, timezone
from typing import List, Any, Optional

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, rsa
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption, load_pem_private_key
from cryptography.x509 import ExtendedKeyUsageOID, NameOID


# Elliptic curves that can be requested by name. Keys are the upper-cased
# curve names, values are the curve *classes* (instantiated on use).
EC_SUPPORTED = {}
EC_SUPPORTED.update([(curve.name.upper(), curve) for curve in [
    ec.SECP192R1,
    ec.SECP224R1,
    ec.SECP256R1,
    ec.SECP384R1,
]])


def _private_key(key_type):
    """Generate a fresh private key of the requested type.

    :param key_type: one of
        - an ``int``: RSA key size in bits,
        - a string of digits, optionally prefixed with ``rsa``
          (case-insensitive), e.g. ``"rsa2048"``: RSA key of that size,
        - the name of a curve in ``EC_SUPPORTED`` (case-insensitive),
        - an ``ec.EllipticCurve`` class or instance.
    :returns: the generated RSA or EC private key
    :raises ValueError: if ``key_type`` matches none of the above
    """
    if isinstance(key_type, str):
        key_type = key_type.upper()
        m = re.match(r'^(RSA)?(\d+)$', key_type)
        if m:
            key_type = int(m.group(2))

    if isinstance(key_type, int):
        return rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_type,
            backend=default_backend()
        )

    if not isinstance(key_type, ec.EllipticCurve) and key_type in EC_SUPPORTED:
        key_type = EC_SUPPORTED[key_type]
    # EC_SUPPORTED holds curve classes, but generate_private_key() is
    # documented to take a curve *instance*; instantiate when given a class.
    if isinstance(key_type, type) and issubclass(key_type, ec.EllipticCurve):
        key_type = key_type()
    if not isinstance(key_type, ec.EllipticCurve):
        raise ValueError(f"unsupported key type: {key_type!r}")
    return ec.generate_private_key(
        curve=key_type,
        backend=default_backend()
    )


class CertificateSpec:
    """Describes a certificate to be issued.

    ``MDTestCA.create_credentials`` turns a spec into a server certificate
    when ``domains`` is non-empty, otherwise into a client certificate when
    ``client`` is set, otherwise into an intermediate CA when a name is set.
    ``sub_specs`` are issued by the credentials created for this spec.
    """

    def __init__(self, name: Optional[str] = None,
                 domains: Optional[List[str]] = None,
                 email: Optional[str] = None,
                 key_type: Optional[str] = None,
                 single_file: bool = False,
                 valid_from: timedelta = timedelta(days=-1),
                 valid_to: timedelta = timedelta(days=89),
                 client: bool = False,
                 sub_specs: Optional[List['CertificateSpec']] = None):
        self._name = name
        self.domains = domains
        self.client = client
        self.email = email
        self.key_type = key_type
        self.single_file = single_file
        # validity is expressed as offsets relative to the time of issuing
        self.valid_from = valid_from
        self.valid_to = valid_to
        self.sub_specs = sub_specs

    @property
    def name(self) -> Optional[str]:
        """The explicit name, else the first domain, else ``None``."""
        if self._name:
            return self._name
        elif self.domains:
            return self.domains[0]
        return None


class Credentials:
    """A certificate together with its private key.

    Credentials may be attached to a ``CertStore``; they can then look up
    other credentials in that store and persist the certificates they issue.
    """

    def __init__(self, name: str, cert: Any, pkey: Any):
        self._name = name
        self._cert = cert
        self._pkey = pkey
        self._cert_file = None
        self._pkey_file = None
        self._store = None

    @property
    def name(self) -> str:
        return self._name

    @property
    def subject(self) -> x509.Name:
        return self._cert.subject

    @property
    def key_type(self):
        """Key type as string: ``rsa<bits>`` for RSA, the curve name for EC.

        :raises Exception: if the private key is neither RSA nor EC
        """
        if isinstance(self._pkey, RSAPrivateKey):
            return f"rsa{self._pkey.key_size}"
        elif isinstance(self._pkey, EllipticCurvePrivateKey):
            return f"{self._pkey.curve.name}"
        else:
            raise Exception(f"unknown key type: {self._pkey}")

    @property
    def private_key(self) -> Any:
        return self._pkey

    @property
    def certificate(self) -> Any:
        return self._cert

    @property
    def cert_pem(self) -> bytes:
        """The certificate in PEM encoding."""
        return self._cert.public_bytes(Encoding.PEM)

    @property
    def pkey_pem(self) -> bytes:
        """The unencrypted private key in PEM encoding.

        RSA keys use the traditional OpenSSL format, all others PKCS#8.
        """
        return self._pkey.private_bytes(
            Encoding.PEM,
            PrivateFormat.TraditionalOpenSSL if self.key_type.startswith('rsa') else PrivateFormat.PKCS8,
            NoEncryption())

    def set_store(self, store: 'CertStore'):
        self._store = store

    def set_files(self, cert_file: str, pkey_file: Optional[str] = None):
        self._cert_file = cert_file
        self._pkey_file = pkey_file

    @property
    def cert_file(self) -> str:
        return self._cert_file

    @property
    def pkey_file(self) -> Optional[str]:
        return self._pkey_file

    def get_first(self, name) -> Optional['Credentials']:
        """Return the first credentials registered under ``name`` in the
        attached store, or ``None`` if there is no store or no match."""
        creds = self._store.get_credentials_for_name(name) if self._store else []
        return creds[0] if creds else None

    def get_credentials_for_name(self, name) -> List['Credentials']:
        """Return all credentials registered under ``name`` in the attached
        store; an empty list when no store is attached."""
        return self._store.get_credentials_for_name(name) if self._store else []

    def issue_certs(self, specs: List[CertificateSpec],
                    chain: Optional[List['Credentials']] = None) -> List['Credentials']:
        """Issue one set of credentials per spec, signed by these credentials."""
        return [self.issue_cert(spec=spec, chain=chain) for spec in specs]

    def issue_cert(self, spec: CertificateSpec,
                   chain: Optional[List['Credentials']] = None) -> 'Credentials':
        """Issue credentials for ``spec``, signed by these credentials.

        When a store is attached, existing credentials for the spec's name
        and key type are loaded from it instead of being created, and newly
        created ones are saved to it. Sub-specs are issued recursively by the
        resulting credentials, stored in a sub-directory named after them.

        :param spec: what to issue; its key type defaults to the issuer's
        :param chain: the issuers above this one; extended by ``self`` and
            handed down when issuing ``spec.sub_specs``
        :returns: the loaded or newly created credentials
        """
        key_type = spec.key_type if spec.key_type else self.key_type
        creds = None
        if self._store:
            creds = self._store.load_credentials(name=spec.name, key_type=key_type,
                                                 single_file=spec.single_file)
        if creds is None:
            creds = MDTestCA.create_credentials(spec=spec, issuer=self, key_type=key_type,
                                                valid_from=spec.valid_from, valid_to=spec.valid_to)
            if self._store:
                self._store.save(creds, single_file=spec.single_file)

        if spec.sub_specs:
            if self._store:
                sub_store = CertStore(fpath=os.path.join(self._store.path, creds.name))
                creds.set_store(sub_store)
            subchain = chain.copy() if chain else []
            subchain.append(self)
            creds.issue_certs(spec.sub_specs, chain=subchain)
        return creds


class CertStore:
    """A directory of PEM files plus an in-memory index of the credentials
    saved to or loaded from it, keyed by name."""

    def __init__(self, fpath: str):
        self._store_dir = fpath
        os.makedirs(self._store_dir, exist_ok=True)
        self._creds_by_name = {}

    @property
    def path(self) -> str:
        return self._store_dir

    def save(self, creds: Credentials, name: Optional[str] = None,
             chain: Optional[List[Credentials]] = None,
             single_file: bool = False) -> None:
        """Write ``creds`` to the store and register them under ``name``.

        The certificate file receives the certificate followed by the
        certificates of ``chain``. With ``single_file`` the private key is
        appended to that same file and no separate key file is written.

        :param creds: the credentials to persist
        :param name: file base name and index key; defaults to ``creds.name``
        :param chain: additional credentials whose certificates are appended
        :param single_file: keep certificate(s) and key in one file
        """
        name = name if name is not None else creds.name
        cert_file = self.get_cert_file(name=name, key_type=creds.key_type)
        pkey_file = self.get_pkey_file(name=name, key_type=creds.key_type)
        if single_file:
            pkey_file = None
        with open(cert_file, "wb") as fd:
            fd.write(creds.cert_pem)
            if chain:
                for c in chain:
                    fd.write(c.cert_pem)
            if pkey_file is None:
                fd.write(creds.pkey_pem)
        if pkey_file is not None:
            with open(pkey_file, "wb") as fd:
                fd.write(creds.pkey_pem)
        creds.set_files(cert_file, pkey_file)
        self._add_credentials(name, creds)

    def _add_credentials(self, name: str, creds: Credentials):
        self._creds_by_name.setdefault(name, []).append(creds)

    def get_credentials_for_name(self, name) -> List[Credentials]:
        """Return the credentials registered under ``name`` (may be empty)."""
        return self._creds_by_name[name] if name in self._creds_by_name else []

    def get_cert_file(self, name: str, key_type=None) -> str:
        """Path of the certificate file: ``<name>[.<key_type>].cert.pem``."""
        key_infix = ".{0}".format(key_type) if key_type is not None else ""
        return os.path.join(self._store_dir, f'{name}{key_infix}.cert.pem')

    def get_pkey_file(self, name: str, key_type=None) -> str:
        """Path of the private key file: ``<name>[.<key_type>].pkey.pem``."""
        key_infix = ".{0}".format(key_type) if key_type is not None else ""
        return os.path.join(self._store_dir, f'{name}{key_infix}.pkey.pem')

    def load_pem_cert(self, fpath: str) -> x509.Certificate:
        """Load a PEM encoded certificate from ``fpath``."""
        # PEM parsers take bytes; read them directly instead of decoding
        # the file as text and encoding it again.
        with open(fpath, "rb") as fd:
            return x509.load_pem_x509_certificate(fd.read())

    def load_pem_pkey(self, fpath: str):
        """Load an unencrypted PEM encoded private key from ``fpath``."""
        with open(fpath, "rb") as fd:
            return load_pem_private_key(fd.read(), password=None)

    def load_credentials(self, name: str, key_type=None, single_file: bool = False):
        """Load credentials from the store's files, if they exist.

        :param name: base name of the files and index key
        :param key_type: key type infix used in the file names
        :param single_file: read the private key from the certificate file
        :returns: the credentials, attached to this store and registered
            under ``name``, or ``None`` when a needed file is missing
        """
        cert_file = self.get_cert_file(name=name, key_type=key_type)
        pkey_file = cert_file if single_file else self.get_pkey_file(name=name, key_type=key_type)
        if os.path.isfile(cert_file) and os.path.isfile(pkey_file):
            cert = self.load_pem_cert(cert_file)
            pkey = self.load_pem_pkey(pkey_file)
            creds = Credentials(name=name, cert=cert, pkey=pkey)
            creds.set_store(self)
            creds.set_files(cert_file, pkey_file)
            self._add_credentials(name, creds)
            return creds
        return None


class MDTestCA:
    """Factory for a test certificate authority and the credentials it issues."""

    @classmethod
    def create_root(cls, name: str, store_dir: str, key_type: str = "rsa2048") -> Credentials:
        """Load or create the self-signed root CA kept in ``store_dir``.

        The root is stored under the fixed file base name ``ca``; it is only
        created (and saved) when no such credentials can be loaded.

        :param name: organization name for a newly created root
        :param store_dir: directory of the certificate store
        :param key_type: key type, see ``_private_key``
        :returns: the root credentials, attached to the store
        """
        store = CertStore(fpath=store_dir)
        creds = store.load_credentials(name="ca", key_type=key_type)
        if creds is None:
            creds = MDTestCA._make_ca_credentials(name=name, key_type=key_type)
            store.save(creds, name="ca")
            creds.set_store(store)
        return creds

    @staticmethod
    def create_credentials(spec: CertificateSpec, issuer: Credentials, key_type: Any,
                           valid_from: timedelta = timedelta(days=-1),
                           valid_to: timedelta = timedelta(days=89),
                           ) -> Credentials:
        """Create credentials for ``spec``, signed by ``issuer``.

        A spec with domains yields a server certificate, otherwise a spec
        with ``client`` set yields a client certificate, otherwise a spec
        with a name yields an intermediate CA.

        :param spec: what to create
        :param issuer: the credentials signing the new certificate
        :param key_type: key type, see ``_private_key``
        :param valid_from: start of validity, relative to now
        :param valid_to: end of validity, relative to now
        :returns: the new credentials (not yet saved anywhere)
        :raises Exception: if the spec matches none of the cases above
        """
        if spec.domains:
            creds = MDTestCA._make_server_credentials(name=spec.name, domains=spec.domains,
                                                      issuer=issuer, valid_from=valid_from,
                                                      valid_to=valid_to, key_type=key_type)
        elif spec.client:
            creds = MDTestCA._make_client_credentials(name=spec.name, issuer=issuer,
                                                      email=spec.email, valid_from=valid_from,
                                                      valid_to=valid_to, key_type=key_type)
        elif spec.name:
            creds = MDTestCA._make_ca_credentials(name=spec.name, issuer=issuer,
                                                  valid_from=valid_from, valid_to=valid_to,
                                                  key_type=key_type)
        else:
            raise Exception(f"unrecognized certificate specification: {spec}")
        return creds

    @staticmethod
    def _make_x509_name(org_name: Optional[str] = None, common_name: Optional[str] = None,
                        parent: Optional[x509.Name] = None) -> x509.Name:
        """Build a subject name, followed by the attributes of ``parent``.

        ``org_name`` takes precedence over ``common_name``; it becomes an
        organizational unit when there is a parent, an organization otherwise.
        """
        name_pieces = []
        if org_name:
            oid = NameOID.ORGANIZATIONAL_UNIT_NAME if parent else NameOID.ORGANIZATION_NAME
            name_pieces.append(x509.NameAttribute(oid, org_name))
        elif common_name:
            name_pieces.append(x509.NameAttribute(NameOID.COMMON_NAME, common_name))
        if parent:
            name_pieces.extend(parent)
        return x509.Name(name_pieces)

    @staticmethod
    def _make_csr(
            subject: x509.Name,
            pkey: Any,
            issuer_subject: Optional[x509.Name],
            valid_from_delta: Optional[timedelta] = None,
            valid_until_delta: Optional[timedelta] = None
    ):
        """Start a certificate builder with the common fields filled in.

        Despite the name this returns a ``x509.CertificateBuilder``, carrying
        subject, issuer, public key, validity, a random serial number and the
        subject key identifier extension.

        :param subject: subject name of the certificate
        :param pkey: private key whose public key is certified
        :param issuer_subject: issuer name; ``None`` means self-issued
        :param valid_from_delta: offset of ``not_valid_before`` from now
        :param valid_until_delta: offset of ``not_valid_after`` from now
        """
        pubkey = pkey.public_key()
        issuer_subject = issuer_subject if issuer_subject is not None else subject

        # cryptography interprets naive datetimes as UTC, so use an aware
        # UTC timestamp rather than naive local time.
        now = datetime.now(timezone.utc)
        valid_from = now
        if valid_from_delta is not None:
            valid_from += valid_from_delta
        valid_until = now
        if valid_until_delta is not None:
            valid_until += valid_until_delta

        return (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer_subject)
            .public_key(pubkey)
            .not_valid_before(valid_from)
            .not_valid_after(valid_until)
            .serial_number(x509.random_serial_number())
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(pubkey),
                critical=False,
            )
        )

    @staticmethod
    def _add_ca_usages(csr: Any) -> Any:
        """Add the extensions of a CA certificate to the builder ``csr``."""
        return csr.add_extension(
            x509.BasicConstraints(ca=True, path_length=9),
            critical=True,
        ).add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False),
            critical=True
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
                ExtendedKeyUsageOID.SERVER_AUTH,
                ExtendedKeyUsageOID.CODE_SIGNING,
            ]),
            critical=True
        )

    @staticmethod
    def _add_leaf_usages(csr: Any, domains: List[str], issuer: Credentials) -> Any:
        """Add the extensions of a server certificate for ``domains``."""
        return csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        ).add_extension(
            x509.SubjectAlternativeName([x509.DNSName(domain) for domain in domains]),
            critical=True,
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.SERVER_AUTH,
            ]),
            critical=True
        )

    @staticmethod
    def _add_client_usages(csr: Any, issuer: Credentials, rfc82name: Optional[str] = None) -> Any:
        """Add the extensions of a client certificate.

        :param csr: the certificate builder to extend
        :param issuer: credentials whose subject key identifier becomes the
            authority key identifier
        :param rfc82name: optional email address, added as RFC 822 subject
            alternative name
        :returns: the extended builder
        """
        cert = csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        )
        # add_extension() returns a new builder and leaves the old one
        # untouched, so the result has to be kept.
        if rfc82name:
            cert = cert.add_extension(
                x509.SubjectAlternativeName([x509.RFC822Name(rfc82name)]),
                critical=True,
            )
        cert = cert.add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
            ]),
            critical=True
        )
        return cert

    @staticmethod
    def _make_ca_credentials(name, key_type: Any,
                             issuer: Optional[Credentials] = None,
                             valid_from: timedelta = timedelta(days=-1),
                             valid_to: timedelta = timedelta(days=89),
                             ) -> Credentials:
        """Create CA credentials; self-signed when no ``issuer`` is given."""
        pkey = _private_key(key_type=key_type)
        if issuer is not None:
            issuer_subject = issuer.certificate.subject
            issuer_key = issuer.private_key
        else:
            # self-signed: _make_csr falls back to the subject as issuer
            issuer_subject = None
            issuer_key = pkey
        subject = MDTestCA._make_x509_name(org_name=name,
                                           parent=issuer.subject if issuer else None)
        csr = MDTestCA._make_csr(subject=subject,
                                 issuer_subject=issuer_subject, pkey=pkey,
                                 valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = MDTestCA._add_ca_usages(csr)
        cert = csr.sign(private_key=issuer_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey)

    @staticmethod
    def _make_server_credentials(name: str, domains: List[str], issuer: Credentials,
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create server credentials for ``domains``, signed by ``issuer``."""
        pkey = _private_key(key_type=key_type)
        subject = MDTestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = MDTestCA._make_csr(subject=subject,
                                 issuer_subject=issuer.certificate.subject, pkey=pkey,
                                 valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = MDTestCA._add_leaf_usages(csr, domains=domains, issuer=issuer)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey)

    @staticmethod
    def _make_client_credentials(name: str,
                                 issuer: Credentials, email: Optional[str],
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create client credentials, signed by ``issuer``."""
        pkey = _private_key(key_type=key_type)
        subject = MDTestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = MDTestCA._make_csr(subject=subject,
                                 issuer_subject=issuer.certificate.subject, pkey=pkey,
                                 valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = MDTestCA._add_client_usages(csr, issuer=issuer, rfc82name=email)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey)