diff options
Diffstat (limited to '')
-rw-r--r-- | examples/tests/ngtcp2test/__init__.py | 6 | ||||
-rw-r--r-- | examples/tests/ngtcp2test/certs.py | 476 | ||||
-rw-r--r-- | examples/tests/ngtcp2test/client.py | 187 | ||||
-rw-r--r-- | examples/tests/ngtcp2test/env.py | 191 | ||||
-rw-r--r-- | examples/tests/ngtcp2test/log.py | 101 | ||||
-rw-r--r-- | examples/tests/ngtcp2test/server.py | 137 | ||||
-rw-r--r-- | examples/tests/ngtcp2test/tls.py | 983 |
7 files changed, 2081 insertions, 0 deletions
diff --git a/examples/tests/ngtcp2test/__init__.py b/examples/tests/ngtcp2test/__init__.py new file mode 100644 index 0000000..65c61d8 --- /dev/null +++ b/examples/tests/ngtcp2test/__init__.py @@ -0,0 +1,6 @@ +from .env import Env, CryptoLib +from .log import LogFile +from .client import ExampleClient, ClientRun +from .server import ExampleServer, ServerRun +from .certs import Ngtcp2TestCA, Credentials +from .tls import HandShake, HSRecord diff --git a/examples/tests/ngtcp2test/certs.py b/examples/tests/ngtcp2test/certs.py new file mode 100644 index 0000000..3ab6260 --- /dev/null +++ b/examples/tests/ngtcp2test/certs.py @@ -0,0 +1,476 @@ +import os +import re +from datetime import timedelta, datetime +from typing import List, Any, Optional + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec, rsa +from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey +from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption, load_pem_private_key +from cryptography.x509 import ExtendedKeyUsageOID, NameOID + + +EC_SUPPORTED = {} +EC_SUPPORTED.update([(curve.name.upper(), curve) for curve in [ + ec.SECP192R1, + ec.SECP224R1, + ec.SECP256R1, + ec.SECP384R1, +]]) + + +def _private_key(key_type): + if isinstance(key_type, str): + key_type = key_type.upper() + m = re.match(r'^(RSA)?(\d+)$', key_type) + if m: + key_type = int(m.group(2)) + + if isinstance(key_type, int): + return rsa.generate_private_key( + public_exponent=65537, + key_size=key_type, + backend=default_backend() + ) + if not isinstance(key_type, ec.EllipticCurve) and key_type in EC_SUPPORTED: + key_type = EC_SUPPORTED[key_type] + return ec.generate_private_key( + curve=key_type, + backend=default_backend() + ) + + +class CertificateSpec: + + def __init__(self, name: str = None, domains: List[str] = None, + email: str = None, + key_type: str = None, single_file: bool = False, + valid_from: timedelta = timedelta(days=-1), + valid_to: timedelta = timedelta(days=89), + client: bool = False, + sub_specs: List['CertificateSpec'] = None): + self._name = name + self.domains = domains + self.client = client + self.email = email + self.key_type = key_type + self.single_file = single_file + self.valid_from = valid_from + self.valid_to = valid_to + self.sub_specs = sub_specs + + @property + def name(self) -> Optional[str]: + if self._name: + return self._name + elif self.domains: + return self.domains[0] + return None + + @property + def type(self) -> Optional[str]: + if self.domains and len(self.domains): + return "server" + elif self.client: + return "client" + elif self.name: + return "ca" + return None + + +class Credentials: + + def __init__(self, name: str, cert: Any, pkey: Any, issuer: 'Credentials' = None): + self._name = name + self._cert = cert + self._pkey = pkey + self._issuer = issuer + self._cert_file = None + self._pkey_file = None + self._store = None + + @property + def name(self) -> str: + return self._name + + @property + def subject(self) -> x509.Name: + return self._cert.subject + + @property + def key_type(self): + if isinstance(self._pkey, RSAPrivateKey): + return f"rsa{self._pkey.key_size}" + elif isinstance(self._pkey, EllipticCurvePrivateKey): + return f"{self._pkey.curve.name}" + else: + raise Exception(f"unknown key type: {self._pkey}") + + @property + def private_key(self) -> Any: + return self._pkey + + @property + def certificate(self) -> Any: + return self._cert + + @property + def cert_pem(self) -> bytes: + return self._cert.public_bytes(Encoding.PEM) + + @property + def pkey_pem(self) -> bytes: + return self._pkey.private_bytes( + Encoding.PEM, + PrivateFormat.TraditionalOpenSSL if self.key_type.startswith('rsa') else PrivateFormat.PKCS8, + NoEncryption()) + + @property + def issuer(self) -> Optional['Credentials']: + return self._issuer + + def set_store(self, store: 'CertStore'): + self._store = store + + def set_files(self, cert_file: str, pkey_file: str = None): + self._cert_file = cert_file + self._pkey_file = pkey_file + + @property + def cert_file(self) -> str: + return self._cert_file + + @property + def pkey_file(self) -> Optional[str]: + return self._pkey_file + + def get_first(self, name) -> Optional['Credentials']: + creds = self._store.get_credentials_for_name(name) if self._store else [] + return creds[0] if len(creds) else None + + def get_credentials_for_name(self, name) -> List['Credentials']: + return self._store.get_credentials_for_name(name) if self._store else [] + + def issue_certs(self, specs: List[CertificateSpec], + chain: List['Credentials'] = None) -> List['Credentials']: + return [self.issue_cert(spec=spec, chain=chain) for spec in specs] + + def issue_cert(self, spec: CertificateSpec, chain: List['Credentials'] = None) -> 'Credentials': + key_type = spec.key_type if spec.key_type else self.key_type + creds = None + if self._store: + creds = self._store.load_credentials( + name=spec.name, key_type=key_type, single_file=spec.single_file, issuer=self) + if creds is None: + creds = Ngtcp2TestCA.create_credentials(spec=spec, issuer=self, key_type=key_type, + valid_from=spec.valid_from, valid_to=spec.valid_to) + if self._store: + self._store.save(creds, single_file=spec.single_file) + if spec.type == "ca": + self._store.save_chain(creds, "ca", with_root=True) + + if spec.sub_specs: + if self._store: + sub_store = CertStore(fpath=os.path.join(self._store.path, creds.name)) + creds.set_store(sub_store) + subchain = chain.copy() if chain else [] + subchain.append(self) + creds.issue_certs(spec.sub_specs, chain=subchain) + return creds + + +class CertStore: + + def __init__(self, fpath: str): + self._store_dir = fpath + if not os.path.exists(self._store_dir): + os.makedirs(self._store_dir) + self._creds_by_name = {} + + @property + def path(self) -> str: + return self._store_dir + + def save(self, creds: Credentials, name: str = None, + chain: List[Credentials] = None, + single_file: bool = False) -> None: + name = name if name is not None else creds.name + cert_file = self.get_cert_file(name=name, key_type=creds.key_type) + pkey_file = self.get_pkey_file(name=name, key_type=creds.key_type) + if single_file: + pkey_file = None + with open(cert_file, "wb") as fd: + fd.write(creds.cert_pem) + if chain: + for c in chain: + fd.write(c.cert_pem) + if pkey_file is None: + fd.write(creds.pkey_pem) + if pkey_file is not None: + with open(pkey_file, "wb") as fd: + fd.write(creds.pkey_pem) + creds.set_files(cert_file, pkey_file) + self._add_credentials(name, creds) + + def save_chain(self, creds: Credentials, infix: str, with_root=False): + name = creds.name + chain = [creds] + while creds.issuer is not None: + creds = creds.issuer + chain.append(creds) + if not with_root and len(chain) > 1: + chain = chain[:-1] + chain_file = os.path.join(self._store_dir, f'{name}-{infix}.pem') + with open(chain_file, "wb") as fd: + for c in chain: + fd.write(c.cert_pem) + + def _add_credentials(self, name: str, creds: Credentials): + if name not in self._creds_by_name: + self._creds_by_name[name] = [] + self._creds_by_name[name].append(creds) + + def get_credentials_for_name(self, name) -> List[Credentials]: + return self._creds_by_name[name] if name in self._creds_by_name else [] + + def get_cert_file(self, name: str, key_type=None) -> str: + key_infix = ".{0}".format(key_type) if key_type is not None else "" + return os.path.join(self._store_dir, f'{name}{key_infix}.cert.pem') + + def get_pkey_file(self, name: str, key_type=None) -> str: + key_infix = ".{0}".format(key_type) if key_type is not None else "" + return os.path.join(self._store_dir, f'{name}{key_infix}.pkey.pem') + + def load_pem_cert(self, fpath: str) -> x509.Certificate: + with open(fpath) as fd: + return x509.load_pem_x509_certificate("".join(fd.readlines()).encode()) + + def load_pem_pkey(self, fpath: str): + with open(fpath) as fd: + return load_pem_private_key("".join(fd.readlines()).encode(), password=None) + + def load_credentials(self, name: str, key_type=None, single_file: bool = False, issuer: Credentials = None): + cert_file = self.get_cert_file(name=name, key_type=key_type) + pkey_file = cert_file if single_file else self.get_pkey_file(name=name, key_type=key_type) + if os.path.isfile(cert_file) and os.path.isfile(pkey_file): + cert = self.load_pem_cert(cert_file) + pkey = self.load_pem_pkey(pkey_file) + creds = Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer) + creds.set_store(self) + creds.set_files(cert_file, pkey_file) + self._add_credentials(name, creds) + return creds + return None + + +class Ngtcp2TestCA: + + @classmethod + def create_root(cls, name: str, store_dir: str, key_type: str = "rsa2048") -> Credentials: + store = CertStore(fpath=store_dir) + creds = store.load_credentials(name="ca", key_type=key_type, issuer=None) + if creds is None: + creds = Ngtcp2TestCA._make_ca_credentials(name=name, key_type=key_type) + store.save(creds, name="ca") + creds.set_store(store) + return creds + + @staticmethod + def create_credentials(spec: CertificateSpec, issuer: Credentials, key_type: Any, + valid_from: timedelta = timedelta(days=-1), + valid_to: timedelta = timedelta(days=89), + ) -> Credentials: + """Create a certificate signed by this CA for the given domains. + :returns: the certificate and private key PEM file paths + """ + if spec.domains and len(spec.domains): + creds = Ngtcp2TestCA._make_server_credentials(name=spec.name, domains=spec.domains, + issuer=issuer, valid_from=valid_from, + valid_to=valid_to, key_type=key_type) + elif spec.client: + creds = Ngtcp2TestCA._make_client_credentials(name=spec.name, issuer=issuer, + email=spec.email, valid_from=valid_from, + valid_to=valid_to, key_type=key_type) + elif spec.name: + creds = Ngtcp2TestCA._make_ca_credentials(name=spec.name, issuer=issuer, + valid_from=valid_from, valid_to=valid_to, + key_type=key_type) + else: + raise Exception(f"unrecognized certificate specification: {spec}") + return creds + + @staticmethod + def _make_x509_name(org_name: str = None, common_name: str = None, parent: x509.Name = None) -> x509.Name: + name_pieces = [] + if org_name: + oid = NameOID.ORGANIZATIONAL_UNIT_NAME if parent else NameOID.ORGANIZATION_NAME + name_pieces.append(x509.NameAttribute(oid, org_name)) + elif common_name: + name_pieces.append(x509.NameAttribute(NameOID.COMMON_NAME, common_name)) + if parent: + name_pieces.extend([rdn for rdn in parent]) + return x509.Name(name_pieces) + + @staticmethod + def _make_csr( + subject: x509.Name, + pkey: Any, + issuer_subject: Optional[Credentials], + valid_from_delta: timedelta = None, + valid_until_delta: timedelta = None + ): + pubkey = pkey.public_key() + issuer_subject = issuer_subject if issuer_subject is not None else subject + + valid_from = datetime.now() + if valid_until_delta is not None: + valid_from += valid_from_delta + valid_until = datetime.now() + if valid_until_delta is not None: + valid_until += valid_until_delta + + return ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer_subject) + .public_key(pubkey) + .not_valid_before(valid_from) + .not_valid_after(valid_until) + .serial_number(x509.random_serial_number()) + .add_extension( + x509.SubjectKeyIdentifier.from_public_key(pubkey), + critical=False, + ) + ) + + @staticmethod + def _add_ca_usages(csr: Any) -> Any: + return csr.add_extension( + x509.BasicConstraints(ca=True, path_length=9), + critical=True, + ).add_extension( + x509.KeyUsage( + digital_signature=True, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False), + critical=True + ).add_extension( + x509.ExtendedKeyUsage([ + ExtendedKeyUsageOID.CLIENT_AUTH, + ExtendedKeyUsageOID.SERVER_AUTH, + ExtendedKeyUsageOID.CODE_SIGNING, + ]), + critical=True + ) + + @staticmethod + def _add_leaf_usages(csr: Any, domains: List[str], issuer: Credentials) -> Any: + return csr.add_extension( + x509.BasicConstraints(ca=False, path_length=None), + critical=True, + ).add_extension( + x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier( + issuer.certificate.extensions.get_extension_for_class( + x509.SubjectKeyIdentifier).value), + critical=False + ).add_extension( + x509.SubjectAlternativeName([x509.DNSName(domain) for domain in domains]), + critical=True, + ).add_extension( + x509.ExtendedKeyUsage([ + ExtendedKeyUsageOID.SERVER_AUTH, + ]), + critical=True + ) + + @staticmethod + def _add_client_usages(csr: Any, issuer: Credentials, rfc82name: str = None) -> Any: + cert = csr.add_extension( + x509.BasicConstraints(ca=False, path_length=None), + critical=True, + ).add_extension( + x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier( + issuer.certificate.extensions.get_extension_for_class( + x509.SubjectKeyIdentifier).value), + critical=False + ) + if rfc82name: + cert.add_extension( + x509.SubjectAlternativeName([x509.RFC822Name(rfc82name)]), + critical=True, + ) + cert.add_extension( + x509.ExtendedKeyUsage([ + ExtendedKeyUsageOID.CLIENT_AUTH, + ]), + critical=True + ) + return cert + + @staticmethod + def _make_ca_credentials(name, key_type: Any, + issuer: Credentials = None, + valid_from: timedelta = timedelta(days=-1), + valid_to: timedelta = timedelta(days=89), + ) -> Credentials: + pkey = _private_key(key_type=key_type) + if issuer is not None: + issuer_subject = issuer.certificate.subject + issuer_key = issuer.private_key + else: + issuer_subject = None + issuer_key = pkey + subject = Ngtcp2TestCA._make_x509_name(org_name=name, parent=issuer.subject if issuer else None) + csr = Ngtcp2TestCA._make_csr(subject=subject, + issuer_subject=issuer_subject, pkey=pkey, + valid_from_delta=valid_from, valid_until_delta=valid_to) + csr = Ngtcp2TestCA._add_ca_usages(csr) + cert = csr.sign(private_key=issuer_key, + algorithm=hashes.SHA256(), + backend=default_backend()) + return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer) + + @staticmethod + def _make_server_credentials(name: str, domains: List[str], issuer: Credentials, + key_type: Any, + valid_from: timedelta = timedelta(days=-1), + valid_to: timedelta = timedelta(days=89), + ) -> Credentials: + name = name + pkey = _private_key(key_type=key_type) + subject = Ngtcp2TestCA._make_x509_name(common_name=name, parent=issuer.subject) + csr = Ngtcp2TestCA._make_csr(subject=subject, + issuer_subject=issuer.certificate.subject, pkey=pkey, + valid_from_delta=valid_from, valid_until_delta=valid_to) + csr = Ngtcp2TestCA._add_leaf_usages(csr, domains=domains, issuer=issuer) + cert = csr.sign(private_key=issuer.private_key, + algorithm=hashes.SHA256(), + backend=default_backend()) + return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer) + + @staticmethod + def _make_client_credentials(name: str, + issuer: Credentials, email: Optional[str], + key_type: Any, + valid_from: timedelta = timedelta(days=-1), + valid_to: timedelta = timedelta(days=89), + ) -> Credentials: + pkey = _private_key(key_type=key_type) + subject = Ngtcp2TestCA._make_x509_name(common_name=name, parent=issuer.subject) + csr = Ngtcp2TestCA._make_csr(subject=subject, + issuer_subject=issuer.certificate.subject, pkey=pkey, + valid_from_delta=valid_from, valid_until_delta=valid_to) + csr = Ngtcp2TestCA._add_client_usages(csr, issuer=issuer, rfc82name=email) + cert = csr.sign(private_key=issuer.private_key, + algorithm=hashes.SHA256(), + backend=default_backend()) + return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer) diff --git a/examples/tests/ngtcp2test/client.py b/examples/tests/ngtcp2test/client.py new file mode 100644 index 0000000..2676343 --- /dev/null +++ b/examples/tests/ngtcp2test/client.py @@ -0,0 +1,187 @@ +import logging +import os +import re +import subprocess +from typing import List + +import pytest + +from .server import ExampleServer, ServerRun +from .certs import Credentials +from .tls import HandShake, HSRecord +from .env import Env, CryptoLib +from .log import LogFile, HexDumpScanner + + +log = logging.getLogger(__name__) + + +class ClientRun: + + def __init__(self, env: Env, returncode, logfile: LogFile, srun: ServerRun): + self.env = env + self.returncode = returncode + self.logfile = logfile + self.log_lines = logfile.get_recent() + self._data_recs = None + self._hs_recs = None + self._srun = srun + if self.env.verbose > 1: + log.debug(f'read {len(self.log_lines)} lines from {logfile.path}') + + @property + def handshake(self) -> List[HSRecord]: + if self._data_recs is None: + crypto_line = re.compile(r'Ordered CRYPTO data in \S+ crypto level') + scanner = HexDumpScanner(source=self.log_lines, + leading_regex=crypto_line) + self._data_recs = [data for data in scanner] + if self.env.verbose > 1: + log.debug(f'detected {len(self._data_recs)} crypto hexdumps ' + f'in {self.logfile.path}') + if self._hs_recs is None: + self._hs_recs = [hrec for hrec in HandShake(source=self._data_recs, + verbose=self.env.verbose)] + if self.env.verbose > 1: + log.debug(f'detected {len(self._hs_recs)} crypto ' + f'records in {self.logfile.path}') + return self._hs_recs + + @property + def hs_stripe(self) -> str: + return ":".join([hrec.name for hrec in self.handshake]) + + @property + def early_data_rejected(self) -> bool: + for l in self.log_lines: + if re.match(r'^Early data was rejected by server.*', l): + return True + return False + + @property + def server(self) -> ServerRun: + return self._srun + + def norm_exp(self, c_hs, s_hs, allow_hello_retry=True): + if allow_hello_retry and self.hs_stripe.startswith('HelloRetryRequest:'): + c_hs = "HelloRetryRequest:" + c_hs + s_hs = "ClientHello:" + s_hs + return c_hs, s_hs + + def _assert_hs(self, c_hs, s_hs): + if not self.hs_stripe.startswith(c_hs): + # what happened? + if self.hs_stripe == '': + # server send nothing + if self.server.hs_stripe == '': + # client send nothing + pytest.fail(f'client did not send a ClientHello"') + else: + # client send sth, but server did not respond + pytest.fail(f'server did not respond to ClientHello: ' + f'{self.server.handshake[0].to_text()}"') + else: + pytest.fail(f'Expected "{c_hs}", got "{self.hs_stripe}"') + assert self.server.hs_stripe == s_hs, \ + f'Expected "{s_hs}", got "{self.server.hs_stripe}"\n' + + def assert_non_resume_handshake(self, allow_hello_retry=True): + # for client/server where KEY_SHARE do not match, the hello is retried + c_hs, s_hs = self.norm_exp( + "ServerHello:EncryptedExtensions:Certificate:CertificateVerify:Finished", + "ClientHello:Finished", allow_hello_retry=allow_hello_retry) + self._assert_hs(c_hs, s_hs) + + def assert_resume_handshake(self): + # for client/server where KEY_SHARE do not match, the hello is retried + c_hs, s_hs = self.norm_exp("ServerHello:EncryptedExtensions:Finished", + "ClientHello:Finished") + self._assert_hs(c_hs, s_hs) + + def assert_verify_null_handshake(self): + c_hs, s_hs = self.norm_exp( + "ServerHello:EncryptedExtensions:CertificateRequest:Certificate:CertificateVerify:Finished", + "ClientHello:Certificate:Finished") + self._assert_hs(c_hs, s_hs) + + def assert_verify_cert_handshake(self): + c_hs, s_hs = self.norm_exp( + "ServerHello:EncryptedExtensions:CertificateRequest:Certificate:CertificateVerify:Finished", + "ClientHello:Certificate:CertificateVerify:Finished") + self._assert_hs(c_hs, s_hs) + + +class ExampleClient: + + def __init__(self, env: Env, crypto_lib: str): + self.env = env + self._crypto_lib = crypto_lib + self._path = env.client_path(self._crypto_lib) + self._log_path = f'{self.env.gen_dir}/{self._crypto_lib}-client.log' + self._qlog_path = f'{self.env.gen_dir}/{self._crypto_lib}-client.qlog' + self._session_path = f'{self.env.gen_dir}/{self._crypto_lib}-client.session' + self._tp_path = f'{self.env.gen_dir}/{self._crypto_lib}-client.tp' + self._data_path = f'{self.env.gen_dir}/{self._crypto_lib}-client.data' + + @property + def path(self): + return self._path + + @property + def crypto_lib(self): + return self._crypto_lib + + @property + def uses_cipher_config(self): + return CryptoLib.uses_cipher_config(self.crypto_lib) + + def supports_cipher(self, cipher): + return CryptoLib.supports_cipher(self.crypto_lib, cipher) + + def exists(self): + return os.path.isfile(self.path) + + def clear_session(self): + if os.path.isfile(self._session_path): + os.remove(self._session_path) + if os.path.isfile(self._tp_path): + os.remove(self._tp_path) + + def http_get(self, server: ExampleServer, url: str, extra_args: List[str] = None, + use_session=False, data=None, + credentials: Credentials = None, + ciphers: str = None): + args = [ + self.path, '--exit-on-all-streams-close', + f'--qlog-file={self._qlog_path}' + ] + if use_session: + args.append(f'--session-file={self._session_path}') + args.append(f'--tp-file={self._tp_path}') + if data is not None: + with open(self._data_path, 'w') as fd: + fd.write(data) + args.append(f'--data={self._data_path}') + if ciphers is not None: + ciphers = CryptoLib.adjust_ciphers(self.crypto_lib, ciphers) + args.append(f'--ciphers={ciphers}') + if credentials is not None: + args.append(f'--key={credentials.pkey_file}') + args.append(f'--cert={credentials.cert_file}') + if extra_args is not None: + args.extend(extra_args) + args.extend([ + 'localhost', str(self.env.examples_port), + url + ]) + if os.path.isfile(self._qlog_path): + os.remove(self._qlog_path) + with open(self._log_path, 'w') as log_file: + logfile = LogFile(path=self._log_path) + server.log.advance() + process = subprocess.Popen(args=args, text=True, + stdout=log_file, stderr=log_file) + process.wait() + return ClientRun(env=self.env, returncode=process.returncode, + logfile=logfile, srun=server.get_run()) + diff --git a/examples/tests/ngtcp2test/env.py b/examples/tests/ngtcp2test/env.py new file mode 100644 index 0000000..9699d55 --- /dev/null +++ b/examples/tests/ngtcp2test/env.py @@ -0,0 +1,191 @@ +import logging +import os +from configparser import ConfigParser, ExtendedInterpolation +from typing import Dict, Optional + +from .certs import CertificateSpec, Ngtcp2TestCA, Credentials + +log = logging.getLogger(__name__) + + +class CryptoLib: + + IGNORES_CIPHER_CONFIG = [ + 'picotls', 'boringssl' + ] + UNSUPPORTED_CIPHERS = { + 'wolfssl': [ + 'TLS_AES_128_CCM_SHA256', # no plans to + ], + 'picotls': [ + 'TLS_AES_128_CCM_SHA256', # no plans to + ], + 'boringssl': [ + 'TLS_AES_128_CCM_SHA256', # no plans to + ] + } + GNUTLS_CIPHERS = { + 'TLS_AES_128_GCM_SHA256': 'AES-128-GCM', + 'TLS_AES_256_GCM_SHA384': 'AES-256-GCM', + 'TLS_CHACHA20_POLY1305_SHA256': 'CHACHA20-POLY1305', + 'TLS_AES_128_CCM_SHA256': 'AES-128-CCM', + } + + @classmethod + def uses_cipher_config(cls, crypto_lib): + return crypto_lib not in cls.IGNORES_CIPHER_CONFIG + + @classmethod + def supports_cipher(cls, crypto_lib, cipher): + return crypto_lib not in cls.UNSUPPORTED_CIPHERS or \ + cipher not in cls.UNSUPPORTED_CIPHERS[crypto_lib] + + @classmethod + def adjust_ciphers(cls, crypto_lib, ciphers: str) -> str: + if crypto_lib == 'gnutls': + gciphers = "NORMAL:-VERS-ALL:+VERS-TLS1.3:-CIPHER-ALL" + for cipher in ciphers.split(':'): + gciphers += f':+{cls.GNUTLS_CIPHERS[cipher]}' + return gciphers + return ciphers + + +def init_config_from(conf_path): + if os.path.isfile(conf_path): + config = ConfigParser(interpolation=ExtendedInterpolation()) + config.read(conf_path) + return config + return None + + +TESTS_PATH = os.path.dirname(os.path.dirname(__file__)) +EXAMPLES_PATH = os.path.dirname(TESTS_PATH) +DEF_CONFIG = init_config_from(os.path.join(TESTS_PATH, 'config.ini')) + + +class Env: + + @classmethod + def get_crypto_libs(cls, configurable_ciphers=None): + names = [name for name in DEF_CONFIG['examples'] + if DEF_CONFIG['examples'][name] == 'yes'] + if configurable_ciphers is not None: + names = [n for n in names if CryptoLib.uses_cipher_config(n)] + return names + + def __init__(self, examples_dir=None, tests_dir=None, config=None, + pytestconfig=None): + self._verbose = pytestconfig.option.verbose if pytestconfig is not None else 0 + self._examples_dir = examples_dir if examples_dir is not None else EXAMPLES_PATH + self._tests_dir = examples_dir if tests_dir is not None else TESTS_PATH + self._gen_dir = os.path.join(self._tests_dir, 'gen') + self.config = config if config is not None else DEF_CONFIG + self._version = self.config['ngtcp2']['version'] + self._crypto_libs = [name for name in self.config['examples'] + if self.config['examples'][name] == 'yes'] + self._clients = [self.config['clients'][lib] for lib in self._crypto_libs + if lib in self.config['clients']] + self._servers = [self.config['servers'][lib] for lib in self._crypto_libs + if lib in self.config['servers']] + self._examples_pem = { + 'key': 'xxx', + 'cert': 'xxx', + } + self._htdocs_dir = os.path.join(self._gen_dir, 'htdocs') + self._tld = 'tests.ngtcp2.nghttp2.org' + self._example_domain = f"one.{self._tld}" + self._ca = None + self._cert_specs = [ + CertificateSpec(domains=[self._example_domain], key_type='rsa2048'), + CertificateSpec(name="clientsX", sub_specs=[ + CertificateSpec(name="user1", client=True), + ]), + ] + + def issue_certs(self): + if self._ca is None: + self._ca = Ngtcp2TestCA.create_root(name=self._tld, + store_dir=os.path.join(self.gen_dir, 'ca'), + key_type="rsa2048") + self._ca.issue_certs(self._cert_specs) + + def setup(self): + os.makedirs(self._gen_dir, exist_ok=True) + os.makedirs(self._htdocs_dir, exist_ok=True) + self.issue_certs() + + def get_server_credentials(self) -> Optional[Credentials]: + creds = self.ca.get_credentials_for_name(self._example_domain) + if len(creds) > 0: + return creds[0] + return None + + @property + def verbose(self) -> int: + return self._verbose + + @property + def version(self) -> str: + return self._version + + @property + def gen_dir(self) -> str: + return self._gen_dir + + @property + def ca(self): + return self._ca + + @property + def htdocs_dir(self) -> str: + return self._htdocs_dir + + @property + def example_domain(self) -> str: + return self._example_domain + + @property + def examples_dir(self) -> str: + return self._examples_dir + + @property + def examples_port(self) -> int: + return int(self.config['examples']['port']) + + @property + def examples_pem(self) -> Dict[str, str]: + return self._examples_pem + + @property + def crypto_libs(self): + return self._crypto_libs + + @property + def clients(self): + return self._clients + + @property + def servers(self): + return self._servers + + def client_name(self, crypto_lib): + if crypto_lib in self.config['clients']: + return self.config['clients'][crypto_lib] + return None + + def client_path(self, crypto_lib): + cname = self.client_name(crypto_lib) + if cname is not None: + return os.path.join(self.examples_dir, cname) + return None + + def server_name(self, crypto_lib): + if crypto_lib in self.config['servers']: + return self.config['servers'][crypto_lib] + return None + + def server_path(self, crypto_lib): + sname = self.server_name(crypto_lib) + if sname is not None: + return os.path.join(self.examples_dir, sname) + return None diff --git a/examples/tests/ngtcp2test/log.py b/examples/tests/ngtcp2test/log.py new file mode 100644 index 0000000..9e8f399 --- /dev/null +++ b/examples/tests/ngtcp2test/log.py @@ -0,0 +1,101 @@ +import binascii +import os +import re +import sys +import time +from datetime import timedelta, datetime +from io import SEEK_END +from typing import List + + +class LogFile: + + def __init__(self, path: str): + self._path = path + self._start_pos = 0 + self._last_pos = self._start_pos + + @property + def path(self) -> str: + return self._path + + def reset(self): + self._start_pos = 0 + self._last_pos = self._start_pos + + def advance(self) -> None: + if os.path.isfile(self._path): + with open(self._path) as fd: + self._start_pos = fd.seek(0, SEEK_END) + + def get_recent(self, advance=True) -> List[str]: + lines = [] + if os.path.isfile(self._path): + with open(self._path) as fd: + fd.seek(self._last_pos, os.SEEK_SET) + for line in fd: + lines.append(line) + if advance: + self._last_pos = fd.tell() + return lines + + def scan_recent(self, pattern: re, timeout=10) -> bool: + if not os.path.isfile(self.path): + return False + with open(self.path) as fd: + end = datetime.now() + timedelta(seconds=timeout) + while True: + fd.seek(self._last_pos, os.SEEK_SET) + for line in fd: + if pattern.match(line): + return True + if datetime.now() > end: + raise TimeoutError(f"pattern not found in error log after {timeout} seconds") + time.sleep(.1) + return False + + +class HexDumpScanner: + + def __init__(self, source, leading_regex=None): + self._source = source + self._leading_regex = leading_regex + + def __iter__(self): + data = b'' + offset = 0 if self._leading_regex is None else -1 + idx = 0 + for l in self._source: + if offset == -1: + pass + elif offset == 0: + # possible start of a hex dump + m = re.match(r'^\s*0+(\s+-)?((\s+[0-9a-f]{2}){1,16})(\s+.*)$', + l, re.IGNORECASE) + if m: + data = binascii.unhexlify(re.sub(r'\s+', '', m.group(2))) + offset = 16 + idx = 1 + continue + else: + # possible continuation of a hexdump + m = re.match(r'^\s*([0-9a-f]+)(\s+-)?((\s+[0-9a-f]{2}){1,16})' + r'(\s+.*)$', l, re.IGNORECASE) + if m: + loffset = int(m.group(1), 16) + if loffset == offset or loffset == idx: + data += binascii.unhexlify(re.sub(r'\s+', '', + m.group(3))) + offset += 16 + idx += 1 + continue + else: + sys.stderr.write(f'wrong offset {loffset}, expected {offset} or {idx}\n') + # not a hexdump line, produce any collected data + if len(data) > 0: + yield data + data = b'' + offset = 0 if self._leading_regex is None \ + or self._leading_regex.match(l) else -1 + if len(data) > 0: + yield data diff --git a/examples/tests/ngtcp2test/server.py b/examples/tests/ngtcp2test/server.py new file mode 100644 index 0000000..9f4e9a0 --- /dev/null +++ b/examples/tests/ngtcp2test/server.py @@ -0,0 +1,137 @@ +import logging +import os +import re +import subprocess +import time +from datetime import datetime, timedelta +from threading import Thread + +from .tls import HandShake +from .env import Env, CryptoLib +from .log import LogFile, HexDumpScanner + + +log = logging.getLogger(__name__) + + +class ServerRun: + + def __init__(self, env: Env, logfile: LogFile): + self.env = env + self._logfile = logfile + self.log_lines = self._logfile.get_recent() + self._data_recs = None + self._hs_recs = None + if self.env.verbose > 1: + log.debug(f'read {len(self.log_lines)} lines from {logfile.path}') + + @property + def handshake(self): + if self._data_recs is None: + self._data_recs = [data for data in HexDumpScanner(source=self.log_lines)] + if self.env.verbose > 1: + log.debug(f'detected {len(self._data_recs)} hexdumps ' + f'in {self._logfile.path}') + if self._hs_recs is None: + self._hs_recs = [hrec for hrec in HandShake(source=self._data_recs, + verbose=self.env.verbose)] + if self.env.verbose > 1: + log.debug(f'detected {len(self._hs_recs)} crypto records ' + f'in {self._logfile.path}') + return self._hs_recs + + @property + def hs_stripe(self): + return ":".join([hrec.name for hrec in self.handshake]) + + +def monitor_proc(env: Env, proc): + _env = env + proc.wait() + + +class ExampleServer: + + def __init__(self, env: Env, crypto_lib: str, verify_client=False): + self.env = env + self._crypto_lib = crypto_lib + self._path = env.server_path(self._crypto_lib) + self._logpath = f'{self.env.gen_dir}/{self._crypto_lib}-server.log' + self._log = LogFile(path=self._logpath) + self._logfile = None + self._process = None + self._verify_client = verify_client + + @property + def path(self): + return self._path + + @property + def crypto_lib(self): + return self._crypto_lib + + @property + def uses_cipher_config(self): + return CryptoLib.uses_cipher_config(self.crypto_lib) + + def supports_cipher(self, cipher): + return CryptoLib.supports_cipher(self.crypto_lib, cipher) + + @property + def log(self): + return self._log + + def exists(self): + return os.path.isfile(self.path) + + def start(self): + if self._process is not None: + return False + creds = self.env.get_server_credentials() + assert creds + args = [ + self.path, + f'--htdocs={self.env.htdocs_dir}', + ] + if self._verify_client: + args.append('--verify-client') + args.extend([ + '*', str(self.env.examples_port), + creds.pkey_file, creds.cert_file + ]) + self._logfile = open(self._logpath, 'w') + self._process = subprocess.Popen(args=args, text=True, + stdout=self._logfile, stderr=self._logfile) + t = Thread(target=monitor_proc, daemon=True, args=(self.env, self._process)) + t.start() + timeout = 5 + end = datetime.now() + timedelta(seconds=timeout) + while True: + if self._process.poll(): + return False + try: + if self.log.scan_recent(pattern=re.compile(r'^Using document root'), timeout=0.5): + break + except TimeoutError: + pass + if datetime.now() > end: + raise TimeoutError(f"pattern not found in error log after {timeout} seconds") + self.log.advance() + return True + + def stop(self): + if self._process: + self._process.terminate() + self._process = None + if self._logfile: + self._logfile.close() + self._logfile = None + return True + + def restart(self): + self.stop() + self._log.reset() + return self.start() + + def get_run(self) -> ServerRun: + return ServerRun(env=self.env, logfile=self.log) diff --git a/examples/tests/ngtcp2test/tls.py b/examples/tests/ngtcp2test/tls.py new file mode 100644 index 0000000..f9bce62 --- /dev/null +++ b/examples/tests/ngtcp2test/tls.py @@ -0,0 +1,983 @@ +import binascii +import logging +import sys +from collections.abc import Iterator +from typing import Dict, Any, Iterable + + +log = logging.getLogger(__name__) + + +class ParseError(Exception): + pass + +def _get_int(d, n): + if len(d) < n: + raise ParseError(f'get_int: {n} bytes needed, but data is {d}') + if n == 1: + dlen = d[0] + else: + dlen = int.from_bytes(d[0:n], byteorder='big') + return d[n:], dlen + + +def _get_field(d, dlen): + if dlen > 0: + if len(d) < dlen: + raise ParseError(f'field len={dlen}, but data len={len(d)}') + field = d[0:dlen] + return d[dlen:], field + return d, b'' + + +def _get_len_field(d, n): + d, dlen = _get_int(d, n) + return _get_field(d, dlen) + + +# d are bytes that start with a quic variable length integer +def _get_qint(d): + i = d[0] & 0xc0 + if i == 0: + return d[1:], int(d[0]) + elif i == 0x40: + ndata = bytearray(d[0:2]) + d = d[2:] + ndata[0] = ndata[0] & ~0xc0 + return d, int.from_bytes(ndata, byteorder='big') + elif i == 0x80: + ndata = bytearray(d[0:4]) + d = d[4:] + ndata[0] = ndata[0] & ~0xc0 + return d, int.from_bytes(ndata, byteorder='big') + else: + ndata = bytearray(d[0:8]) + d = d[8:] + ndata[0] = ndata[0] & ~0xc0 + return d, int.from_bytes(ndata, byteorder='big') + + +class TlsSupportedGroups: + NAME_BY_ID = { + 0: 'Reserved', + 1: 'sect163k1', + 2: 'sect163r1', + 3: 'sect163r2', + 4: 'sect193r1', + 5: 'sect193r2', + 6: 'sect233k1', + 7: 'sect233r1', + 8: 'sect239k1', + 9: 'sect283k1', + 10: 'sect283r1', + 11: 'sect409k1', + 12: 'sect409r1', + 13: 'sect571k1', + 14: 'sect571r1', + 15: 'secp160k1', + 16: 'secp160r1', + 17: 'secp160r2', + 18: 'secp192k1', + 19: 'secp192r1', + 20: 'secp224k1', + 21: 'secp224r1', + 22: 'secp256k1', + 23: 'secp256r1', + 24: 'secp384r1', + 25: 'secp521r1', + 26: 'brainpoolP256r1', + 27: 'brainpoolP384r1', + 28: 'brainpoolP512r1', + 29: 'x25519', + 30: 'x448', + 31: 'brainpoolP256r1tls13', + 32: 'brainpoolP384r1tls13', + 33: 'brainpoolP512r1tls13', + 34: 'GC256A', + 35: 'GC256B', + 36: 'GC256C', + 37: 'GC256D', + 38: 'GC512A', + 39: 'GC512B', + 40: 'GC512C', + 41: 'curveSM2', + } + + @classmethod + def name(cls, gid): + if gid in cls.NAME_BY_ID: + return f'{cls.NAME_BY_ID[gid]}(0x{gid:0x})' + return f'0x{gid:0x}' + + +class TlsSignatureScheme: + NAME_BY_ID = { + 0x0201: 'rsa_pkcs1_sha1', + 0x0202: 'Reserved', + 0x0203: 'ecdsa_sha1', + 0x0401: 'rsa_pkcs1_sha256', + 0x0403: 'ecdsa_secp256r1_sha256', + 0x0420: 'rsa_pkcs1_sha256_legacy', + 0x0501: 'rsa_pkcs1_sha384', + 0x0503: 'ecdsa_secp384r1_sha384', + 0x0520: 'rsa_pkcs1_sha384_legacy', + 0x0601: 'rsa_pkcs1_sha512', + 0x0603: 'ecdsa_secp521r1_sha512', + 0x0620: 'rsa_pkcs1_sha512_legacy', + 0x0704: 'eccsi_sha256', + 0x0705: 'iso_ibs1', + 0x0706: 'iso_ibs2', + 0x0707: 'iso_chinese_ibs', + 0x0708: 'sm2sig_sm3', + 0x0709: 'gostr34102012_256a', + 0x070A: 'gostr34102012_256b', + 0x070B: 'gostr34102012_256c', + 0x070C: 'gostr34102012_256d', + 0x070D: 'gostr34102012_512a', + 0x070E: 'gostr34102012_512b', + 0x070F: 'gostr34102012_512c', + 0x0804: 'rsa_pss_rsae_sha256', + 0x0805: 'rsa_pss_rsae_sha384', + 0x0806: 'rsa_pss_rsae_sha512', + 0x0807: 'ed25519', + 0x0808: 'ed448', + 0x0809: 'rsa_pss_pss_sha256', + 0x080A: 'rsa_pss_pss_sha384', + 0x080B: 'rsa_pss_pss_sha512', + 0x081A: 'ecdsa_brainpoolP256r1tls13_sha256', + 0x081B: 'ecdsa_brainpoolP384r1tls13_sha384', + 0x081C: 'ecdsa_brainpoolP512r1tls13_sha512', + } + + @classmethod + def name(cls, gid): + if gid in cls.NAME_BY_ID: + return f'{cls.NAME_BY_ID[gid]}(0x{gid:0x})' + return f'0x{gid:0x}' + + +class TlsCipherSuites: + NAME_BY_ID = { + 0x1301: 'TLS_AES_128_GCM_SHA256', + 0x1302: 'TLS_AES_256_GCM_SHA384', + 0x1303: 'TLS_CHACHA20_POLY1305_SHA256', + 0x1304: 'TLS_AES_128_CCM_SHA256', + 0x1305: 'TLS_AES_128_CCM_8_SHA256', + 0x00ff: 'TLS_EMPTY_RENEGOTIATION_INFO_SCSV', + } + + @classmethod + def name(cls, cid): + if cid in cls.NAME_BY_ID: + return f'{cls.NAME_BY_ID[cid]}(0x{cid:0x})' + return f'Cipher(0x{cid:0x})' + + +class PskKeyExchangeMode: + NAME_BY_ID = { + 0x00: 'psk_ke', + 0x01: 'psk_dhe_ke', + } + + @classmethod + def name(cls, gid): + if gid in cls.NAME_BY_ID: + return f'{cls.NAME_BY_ID[gid]}(0x{gid:0x})' + return f'0x{gid:0x}' + + +class QuicTransportParam: + NAME_BY_ID = { + 0x00: 'original_destination_connection_id', + 0x01: 'max_idle_timeout', + 0x02: 'stateless_reset_token', + 0x03: 'max_udp_payload_size', + 0x04: 'initial_max_data', + 0x05: 'initial_max_stream_data_bidi_local', + 0x06: 'initial_max_stream_data_bidi_remote', + 0x07: 'initial_max_stream_data_uni', + 0x08: 'initial_max_streams_bidi', + 0x09: 'initial_max_streams_uni', + 0x0a: 'ack_delay_exponent', + 0x0b: 'max_ack_delay', + 0x0c: 'disable_active_migration', + 0x0d: 'preferred_address', + 0x0e: 'active_connection_id_limit', + 0x0f: 'initial_source_connection_id', + 0x10: 'retry_source_connection_id', + } + TYPE_BY_ID = { + 0x00: bytes, + 0x01: int, + 0x02: bytes, + 0x03: int, + 0x04: int, + 0x05: int, + 0x06: int, + 0x07: int, + 0x08: int, + 0x09: int, + 0x0a: int, + 0x0b: int, + 0x0c: int, + 0x0d: bytes, + 0x0e: int, + 0x0f: bytes, + 0x10: bytes, + } + + @classmethod + def name(cls, cid): + if cid in cls.NAME_BY_ID: + return f'{cls.NAME_BY_ID[cid]}(0x{cid:0x})' + return f'QuicTP(0x{cid:0x})' + + @classmethod + def is_qint(cls, cid): + if cid in cls.TYPE_BY_ID: + return cls.TYPE_BY_ID[cid] == int + return False + + +class Extension: + + def __init__(self, eid, name, edata, hsid): + self._eid = eid + self._name = name + self._edata = edata + self._hsid = hsid + + @property + def data(self): + return self._edata + + @property + def hsid(self): + return self._hsid + + def to_json(self): + jdata = { + 'id': self._eid, + 'name': self._name, + } + if len(self.data) > 0: + jdata['data'] = binascii.hexlify(self.data).decode() + return jdata + + def to_text(self, indent: int = 0): + ind = ' ' * (indent + 2) + s = f'{ind}{self._name}(0x{self._eid:0x})' + if len(self._edata): + s += f'\n{ind} data({len(self._edata)}): ' \ + f'{binascii.hexlify(self._edata).decode()}' + return s + + +class ExtSupportedGroups(Extension): + + def __init__(self, eid, name, edata, hsid): + super().__init__(eid=eid, name=name, edata=edata, hsid=hsid) + d = edata + self._groups = [] + while len(d) > 0: + d, gid = _get_int(d, 2) + self._groups.append(gid) + + def to_json(self): + jdata = { + 'id': self._eid, + 'name': self._name, + } + if len(self._groups): + jdata['groups'] = [TlsSupportedGroups.name(gid) + for gid in self._groups] + return jdata + + def to_text(self, indent: int = 0): + ind = ' ' * (indent + 2) + gnames = [TlsSupportedGroups.name(gid) for gid in self._groups] + s = f'{ind}{self._name}(0x{self._eid:0x}): {", ".join(gnames)}' + return s + + +class ExtKeyShare(Extension): + + def __init__(self, eid, name, edata, hsid): + super().__init__(eid=eid, name=name, edata=edata, hsid=hsid) + d = self.data + self._keys = [] + self._group = None + self._pubkey = None + if self.hsid == 2: # ServerHello + # single key share (group, pubkey) + d, self._group = _get_int(d, 2) + d, self._pubkey = _get_len_field(d, 2) + elif self.hsid == 6: # HelloRetryRequest + assert len(d) == 2 + d, self._group = _get_int(d, 2) + else: + # list if key shares (group, pubkey) + d, shares = _get_len_field(d, 2) + while len(shares) > 0: + shares, group = _get_int(shares, 2) + shares, pubkey = _get_len_field(shares, 2) + self._keys.append({ + 'group': TlsSupportedGroups.name(group), + 'pubkey': binascii.hexlify(pubkey).decode() + }) + + def to_json(self): + jdata = super().to_json() + if self._group is not None: + jdata['group'] = TlsSupportedGroups.name(self._group) + if self._pubkey is not None: + jdata['pubkey'] = binascii.hexlify(self._pubkey).decode() + if len(self._keys) > 0: + jdata['keys'] = self._keys + return jdata + + def to_text(self, indent: int = 0): + ind = ' ' * (indent + 2) + s = f'{ind}{self._name}(0x{self._eid:0x})' + if self._group is not None: + s += f'\n{ind} group: {TlsSupportedGroups.name(self._group)}' + if self._pubkey is not None: + s += f'\n{ind} pubkey: {binascii.hexlify(self._pubkey).decode()}' + if len(self._keys) > 0: + for idx, key in enumerate(self._keys): + s += f'\n{ind} {idx}: {key["group"]}, {key["pubkey"]}' + return s + + +class ExtSNI(Extension): + + def __init__(self, eid, name, edata, hsid): + super().__init__(eid=eid, name=name, edata=edata, hsid=hsid) + d = self.data + self._indicators = [] + while len(d) > 0: + d, entry = _get_len_field(d, 2) + entry, stype = _get_int(entry, 1) + entry, sname = _get_len_field(entry, 2) + self._indicators.append({ + 'type': stype, + 'name': sname.decode() + }) + + def to_json(self): + jdata = super().to_json() + for i in self._indicators: + if i['type'] == 0: + jdata['host_name'] = i['name'] + else: + jdata[f'0x{i["type"]}'] = i['name'] + return jdata + + def to_text(self, indent: int = 0): + ind = ' ' * (indent + 2) + s = f'{ind}{self._name}(0x{self._eid:0x})' + if len(self._indicators) == 1 and self._indicators[0]['type'] == 0: + s += f': {self._indicators[0]["name"]}' + else: + for i in self._indicators: + ikey = 'host_name' if i["type"] == 0 else f'type(0x{i["type"]:0x}' + s += f'\n{ind} {ikey}: {i["name"]}' + return s + + +class ExtALPN(Extension): + + def __init__(self, eid, name, edata, hsid): + super().__init__(eid=eid, name=name, edata=edata, hsid=hsid) + d = self.data + d, list_len = _get_int(d, 2) + self._protocols = [] + while len(d) > 0: + d, proto = _get_len_field(d, 1) + self._protocols.append(proto.decode()) + + def to_json(self): + jdata = super().to_json() + if len(self._protocols) == 1: + jdata['alpn'] = self._protocols[0] + else: + jdata['alpn'] = self._protocols + return jdata + + def to_text(self, indent: int = 0): + ind = ' ' * (indent + 2) + return f'{ind}{self._name}(0x{self._eid:0x}): {", ".join(self._protocols)}' + + +class ExtEarlyData(Extension): + + def __init__(self, eid, name, edata, hsid): + super().__init__(eid=eid, name=name, edata=edata, hsid=hsid) + self._max_size = None + d = self.data + if hsid == 4: # SessionTicket + assert len(d) == 4, f'expected 4, len is {len(d)} data={d}' + d, self._max_size = _get_int(d, 4) + else: + assert len(d) == 0 + + def to_json(self): + jdata = super().to_json() + if self._max_size is not None: + jdata['max_size'] = self._max_size + return jdata + + +class ExtSignatureAlgorithms(Extension): + + def __init__(self, eid, name, edata, hsid): + super().__init__(eid=eid, name=name, edata=edata, hsid=hsid) + d = self.data + d, list_len = _get_int(d, 2) + self._algos = [] + while len(d) > 0: + d, algo = _get_int(d, 2) + self._algos.append(TlsSignatureScheme.name(algo)) + + def to_json(self): + jdata = super().to_json() + if len(self._algos) > 0: + jdata['algorithms'] = self._algos + return jdata + + def to_text(self, indent: int = 0): + ind = ' ' * (indent + 2) + return f'{ind}{self._name}(0x{self._eid:0x}): {", ".join(self._algos)}' + + +class ExtPSKExchangeModes(Extension): + + def __init__(self, eid, name, edata, hsid): + super().__init__(eid=eid, name=name, edata=edata, hsid=hsid) + d = self.data + d, list_len = _get_int(d, 1) + self._modes = [] + while len(d) > 0: + d, mode = _get_int(d, 1) + self._modes.append(PskKeyExchangeMode.name(mode)) + + def to_json(self): + jdata = super().to_json() + jdata['modes'] = self._modes + return jdata + + def to_text(self, indent: int = 0): + ind = ' ' * (indent + 2) + return f'{ind}{self._name}(0x{self._eid:0x}): {", ".join(self._modes)}' + + +class ExtPreSharedKey(Extension): + + def __init__(self, eid, name, edata, hsid): + super().__init__(eid=eid, name=name, edata=edata, hsid=hsid) + self._kid = None + self._identities = None + self._binders = None + d = self.data + if hsid == 1: # client hello + d, idata = _get_len_field(d, 2) + self._identities = [] + while len(idata): + idata, identity = _get_len_field(idata, 2) + idata, obfs_age = _get_int(idata, 4) + self._identities.append({ + 'id': binascii.hexlify(identity).decode(), + 'age': obfs_age, + }) + d, binders = _get_len_field(d, 2) + self._binders = [] + while len(binders) > 0: + binders, hmac = _get_len_field(binders, 1) + self._binders.append(binascii.hexlify(hmac).decode()) + assert len(d) == 0 + else: + d, self._kid = _get_int(d, 2) + + def to_json(self): + jdata = super().to_json() + if self.hsid == 1: + jdata['identities'] = self._identities + jdata['binders'] = self._binders + else: + jdata['identity'] = self._kid + return jdata + + def to_text(self, indent: int = 0): + ind = ' ' * (indent + 2) + s = f'{ind}{self._name}(0x{self._hsid:0x})' + if self.hsid == 1: + for idx, i in enumerate(self._identities): + s += f'\n{ind} {idx}: {i["id"]} ({i["age"]})' + s += f'\n{ind} binders: {self._binders}' + else: + s += f'\n{ind} identity: {self._kid}' + return s + + +class ExtSupportedVersions(Extension): + + def __init__(self, eid, name, edata, hsid): + super().__init__(eid=eid, name=name, edata=edata, hsid=hsid) + d = self.data + self._versions = [] + if hsid == 1: # client hello + d, list_len = _get_int(d, 1) + while len(d) > 0: + d, version = _get_int(d, 2) + self._versions.append(f'0x{version:0x}') + else: + d, version = _get_int(d, 2) + self._versions.append(f'0x{version:0x}') + + def to_json(self): + jdata = super().to_json() + if len(self._versions) == 1: + jdata['version'] = self._versions[0] + else: + jdata['versions'] = self._versions + return jdata + + def to_text(self, indent: int = 0): + ind = ' ' * (indent + 2) + return f'{ind}{self._name}(0x{self._eid:0x}): {", ".join(self._versions)}' + + +class ExtQuicTP(Extension): + + def __init__(self, eid, name, edata, hsid): + super().__init__(eid=eid, name=name, edata=edata, hsid=hsid) + d = self.data + self._params = [] + while len(d) > 0: + d, ptype = _get_qint(d) + d, plen = _get_qint(d) + d, pvalue = _get_field(d, plen) + if QuicTransportParam.is_qint(ptype): + _, pvalue = _get_qint(pvalue) + else: + pvalue = binascii.hexlify(pvalue).decode() + self._params.append({ + 'key': QuicTransportParam.name(ptype), + 'value': pvalue, + }) + + def to_json(self): + jdata = super().to_json() + jdata['params'] = self._params + return jdata + + def to_text(self, indent: int = 0): + ind = ' ' * (indent + 2) + s = f'{ind}{self._name}(0x{self._eid:0x})' + for p in self._params: + s += f'\n{ind} {p["key"]}: {p["value"]}' + return s + + +class TlsExtensions: + + EXT_TYPES = [ + (0x00, 'SNI', ExtSNI), + (0x01, 'MAX_FRAGMENT_LENGTH', Extension), + (0x03, 'TRUSTED_CA_KEYS', Extension), + (0x04, 'TRUNCATED_HMAC', Extension), + (0x05, 'OSCP_STATUS_REQUEST', Extension), + (0x0a, 'SUPPORTED_GROUPS', ExtSupportedGroups), + (0x0b, 'EC_POINT_FORMATS', Extension), + (0x0d, 'SIGNATURE_ALGORITHMS', ExtSignatureAlgorithms), + (0x0e, 'USE_SRTP', Extension), + (0x10, 'ALPN', ExtALPN), + (0x11, 'STATUS_REQUEST_V2', Extension), + (0x16, 'ENCRYPT_THEN_MAC', Extension), + (0x17, 'EXTENDED_MASTER_SECRET', Extension), + (0x23, 'SESSION_TICKET', Extension), + (0x29, 'PRE_SHARED_KEY', ExtPreSharedKey), + (0x2a, 'EARLY_DATA', ExtEarlyData), + (0x2b, 'SUPPORTED_VERSIONS', ExtSupportedVersions), + (0x2c, 'COOKIE', Extension), + (0x2d, 'PSK_KEY_EXCHANGE_MODES', ExtPSKExchangeModes), + (0x31, 'POST_HANDSHAKE_AUTH', Extension), + (0x32, 'SIGNATURE_ALGORITHMS_CERT', Extension), + (0x33, 'KEY_SHARE', ExtKeyShare), + (0x39, 'QUIC_TP_PARAMS', ExtQuicTP), + (0xff01, 'RENEGOTIATION_INFO', Extension), + (0xffa5, 'QUIC_TP_PARAMS_DRAFT', ExtQuicTP), + ] + NAME_BY_ID = {} + CLASS_BY_ID = {} + + @classmethod + def init(cls): + for (eid, name, ecls) in cls.EXT_TYPES: + cls.NAME_BY_ID[eid] = name + cls.CLASS_BY_ID[eid] = ecls + + @classmethod + def from_data(cls, hsid, data): + exts = [] + d = data + while len(d): + d, eid = _get_int(d, 2) + d, elen = _get_int(d, 2) + d, edata = _get_field(d, elen) + if eid in cls.NAME_BY_ID: + ename = cls.NAME_BY_ID[eid] + ecls = cls.CLASS_BY_ID[eid] + exts.append(ecls(eid=eid, name=ename, edata=edata, hsid=hsid)) + else: + exts.append(Extension(eid=eid, name=f'(0x{eid:0x})', + edata=edata, hsid=hsid)) + return exts + + +TlsExtensions.init() + + +class HSRecord: + + def __init__(self, hsid: int, name: str, data): + self._hsid = hsid + self._name = name + self._data = data + + @property + def hsid(self): + return self._hsid + + @property + def name(self): + return self._name + + @name.setter + def name(self, value): + self._name = value + + @property + def data(self): + return self._data + + def __repr__(self): + return f'{self.name}[{binascii.hexlify(self._data).decode()}]' + + def to_json(self) -> Dict[str, Any]: + return { + 'name': self.name, + 'data': binascii.hexlify(self._data).decode(), + } + + def to_text(self, indent: int = 0): + ind = ' ' * (indent + 2) + return f'{ind}{self._name}\n'\ + f'{ind} id: 0x{self._hsid:0x}\n'\ + f'{ind} data({len(self._data)}): '\ + f'{binascii.hexlify(self._data).decode()}' + + +class ClientHello(HSRecord): + + def __init__(self, hsid: int, name: str, data): + super().__init__(hsid=hsid, name=name, data=data) + d = data + d, self._version = _get_int(d, 2) + d, self._random = _get_field(d, 32) + d, self._session_id = _get_len_field(d, 1) + self._ciphers = [] + d, ciphers = _get_len_field(d, 2) + while len(ciphers): + ciphers, cipher = _get_int(ciphers, 2) + self._ciphers.append(TlsCipherSuites.name(cipher)) + d, comps = _get_len_field(d, 1) + self._compressions = [int(c) for c in comps] + d, edata = _get_len_field(d, 2) + self._extensions = TlsExtensions.from_data(hsid, edata) + + def to_json(self): + jdata = super().to_json() + jdata['version'] = f'0x{self._version:0x}' + jdata['random'] = f'{binascii.hexlify(self._random).decode()}' + jdata['session_id'] = binascii.hexlify(self._session_id).decode() + jdata['ciphers'] = self._ciphers + jdata['compressions'] = self._compressions + jdata['extensions'] = [ext.to_json() for ext in self._extensions] + return jdata + + def to_text(self, indent: int = 0): + ind = ' ' * (indent + 2) + return super().to_text(indent=indent) + '\n'\ + f'{ind} version: 0x{self._version:0x}\n'\ + f'{ind} random: {binascii.hexlify(self._random).decode()}\n' \ + f'{ind} session_id: {binascii.hexlify(self._session_id).decode()}\n' \ + f'{ind} ciphers: {", ".join(self._ciphers)}\n'\ + f'{ind} compressions: {self._compressions}\n'\ + f'{ind} extensions: \n' + '\n'.join( + [ext.to_text(indent=indent+4) for ext in self._extensions]) + + +class ServerHello(HSRecord): + + HELLO_RETRY_RANDOM = binascii.unhexlify( + 'CF21AD74E59A6111BE1D8C021E65B891C2A211167ABB8C5E079E09E2C8A8339C' + ) + + def __init__(self, hsid: int, name: str, data): + super().__init__(hsid=hsid, name=name, data=data) + d = data + d, self._version = _get_int(d, 2) + d, self._random = _get_field(d, 32) + if self._random == self.HELLO_RETRY_RANDOM: + self.name = 'HelloRetryRequest' + hsid = 6 + d, self._session_id = _get_len_field(d, 1) + d, cipher = _get_int(d, 2) + self._cipher = TlsCipherSuites.name(cipher) + d, self._compression = _get_int(d, 1) + d, edata = _get_len_field(d, 2) + self._extensions = TlsExtensions.from_data(hsid, edata) + + def to_json(self): + jdata = super().to_json() + jdata['version'] = f'0x{self._version:0x}' + jdata['random'] = f'{binascii.hexlify(self._random).decode()}' + jdata['session_id'] = binascii.hexlify(self._session_id).decode() + jdata['cipher'] = self._cipher + jdata['compression'] = int(self._compression) + jdata['extensions'] = [ext.to_json() for ext in self._extensions] + return jdata + + def to_text(self, indent: int = 0): + ind = ' ' * (indent + 2) + return super().to_text(indent=indent) + '\n'\ + f'{ind} version: 0x{self._version:0x}\n'\ + f'{ind} random: {binascii.hexlify(self._random).decode()}\n' \ + f'{ind} session_id: {binascii.hexlify(self._session_id).decode()}\n' \ + f'{ind} cipher: {self._cipher}\n'\ + f'{ind} compression: {int(self._compression)}\n'\ + f'{ind} extensions: \n' + '\n'.join( + [ext.to_text(indent=indent+4) for ext in self._extensions]) + + +class EncryptedExtensions(HSRecord): + + def __init__(self, hsid: int, name: str, data): + super().__init__(hsid=hsid, name=name, data=data) + d = data + d, edata = _get_len_field(d, 2) + self._extensions = TlsExtensions.from_data(hsid, edata) + + def to_json(self): + jdata = super().to_json() + jdata['extensions'] = [ext.to_json() for ext in self._extensions] + return jdata + + def to_text(self, indent: int = 0): + ind = ' ' * (indent + 2) + return super().to_text(indent=indent) + '\n'\ + f'{ind} extensions: \n' + '\n'.join( + [ext.to_text(indent=indent+4) for ext in self._extensions]) + + +class CertificateRequest(HSRecord): + + def __init__(self, hsid: int, name: str, data): + super().__init__(hsid=hsid, name=name, data=data) + d = data + d, self._context = _get_int(d, 1) + d, edata = _get_len_field(d, 2) + self._extensions = TlsExtensions.from_data(hsid, edata) + + def to_json(self): + jdata = super().to_json() + jdata['context'] = self._context + jdata['extensions'] = [ext.to_json() for ext in self._extensions] + return jdata + + def to_text(self, indent: int = 0): + ind = ' ' * (indent + 2) + return super().to_text(indent=indent) + '\n'\ + f'{ind} context: {self._context}\n'\ + f'{ind} extensions: \n' + '\n'.join( + [ext.to_text(indent=indent+4) for ext in self._extensions]) + + +class Certificate(HSRecord): + + def __init__(self, hsid: int, name: str, data): + super().__init__(hsid=hsid, name=name, data=data) + d = data + d, self._context = _get_int(d, 1) + d, clist = _get_len_field(d, 3) + self._cert_entries = [] + while len(clist) > 0: + clist, cert_data = _get_len_field(clist, 3) + clist, cert_exts = _get_len_field(clist, 2) + exts = TlsExtensions.from_data(hsid, cert_exts) + self._cert_entries.append({ + 'cert': binascii.hexlify(cert_data).decode(), + 'extensions': exts, + }) + + def to_json(self): + jdata = super().to_json() + jdata['context'] = self._context + jdata['certificate_list'] = [{ + 'cert': e['cert'], + 'extensions': [x.to_json() for x in e['extensions']], + } for e in self._cert_entries] + return jdata + + def _enxtry_text(self, e, indent: int = 0): + ind = ' ' * (indent + 2) + return f'{ind} cert: {e["cert"]}\n'\ + f'{ind} extensions: \n' + '\n'.join( + [x.to_text(indent=indent + 4) for x in e['extensions']]) + + def to_text(self, indent: int = 0): + ind = ' ' * (indent + 2) + return super().to_text(indent=indent) + '\n'\ + f'{ind} context: {self._context}\n'\ + f'{ind} certificate_list: \n' + '\n'.join( + [self._enxtry_text(e, indent+4) for e in self._cert_entries]) + + +class SessionTicket(HSRecord): + + def __init__(self, hsid: int, name: str, data): + super().__init__(hsid=hsid, name=name, data=data) + d = data + d, self._lifetime = _get_int(d, 4) + d, self._age = _get_int(d, 4) + d, self._nonce = _get_len_field(d, 1) + d, self._ticket = _get_len_field(d, 2) + d, edata = _get_len_field(d, 2) + self._extensions = TlsExtensions.from_data(hsid, edata) + + def to_json(self): + jdata = super().to_json() + jdata['lifetime'] = self._lifetime + jdata['age'] = self._age + jdata['nonce'] = binascii.hexlify(self._nonce).decode() + jdata['ticket'] = binascii.hexlify(self._ticket).decode() + jdata['extensions'] = [ext.to_json() for ext in self._extensions] + return jdata + + +class HSIterator(Iterator): + + def __init__(self, recs): + self._recs = recs + self._index = 0 + + def __iter__(self): + return self + + def __next__(self): + try: + result = self._recs[self._index] + except IndexError: + raise StopIteration + self._index += 1 + return result + + +class HandShake: + REC_TYPES = [ + (1, 'ClientHello', ClientHello), + (2, 'ServerHello', ServerHello), + (3, 'HelloVerifyRequest', HSRecord), + (4, 'SessionTicket', SessionTicket), + (5, 'EndOfEarlyData', HSRecord), + (6, 'HelloRetryRequest', ServerHello), + (8, 'EncryptedExtensions', EncryptedExtensions), + (11, 'Certificate', Certificate), + (12, 'ServerKeyExchange ', HSRecord), + (13, 'CertificateRequest', CertificateRequest), + (14, 'ServerHelloDone', HSRecord), + (15, 'CertificateVerify', HSRecord), + (16, 'ClientKeyExchange', HSRecord), + (20, 'Finished', HSRecord), + (22, 'CertificateStatus', HSRecord), + (24, 'KeyUpdate', HSRecord), + ] + RT_NAME_BY_ID = {} + RT_CLS_BY_ID = {} + + @classmethod + def _parse_rec(cls, data): + d, hsid = _get_int(data, 1) + if hsid not in cls.RT_CLS_BY_ID: + raise ParseError(f'unknown type {hsid}') + d, rec_len = _get_int(d, 3) + if rec_len > len(d): + # incomplete, need more data + return data, None + d, rec_data = _get_field(d, rec_len) + if hsid in cls.RT_CLS_BY_ID: + name = cls.RT_NAME_BY_ID[hsid] + rcls = cls.RT_CLS_BY_ID[hsid] + else: + name = f'CryptoRecord(0x{hsid:0x})' + rcls = HSRecord + return d, rcls(hsid=hsid, name=name, data=rec_data) + + @classmethod + def _parse(cls, source, strict=False, verbose: int = 0): + d = b'' + hsid = 0 + hsrecs = [] + if verbose > 0: + log.debug(f'scanning for handshake records') + blocks = [d for d in source] + while len(blocks) > 0: + try: + total_data = b''.join(blocks) + remain, r = cls._parse_rec(total_data) + if r is None: + # if we could not recognize a record, skip the first + # data block and try again + blocks = blocks[1:] + continue + hsrecs.append(r) + cons_len = len(total_data) - len(remain) + while cons_len > 0 and len(blocks) > 0: + if cons_len >= len(blocks[0]): + cons_len -= len(blocks[0]) + blocks = blocks[1:] + else: + blocks[0] = blocks[0][cons_len:] + cons_len = 0 + if verbose > 2: + log.debug(f'added record: {r.to_text()}') + except ParseError as err: + # if we could not recognize a record, skip the first + # data block and try again + blocks = blocks[1:] + if len(blocks) > 0 and strict: + raise Exception(f'possibly incomplete handshake record ' + f'id={hsid}, from raw={blocks}\n') + return hsrecs + + + + @classmethod + def init(cls): + for (hsid, name, rcls) in cls.REC_TYPES: + cls.RT_NAME_BY_ID[hsid] = name + cls.RT_CLS_BY_ID[hsid] = rcls + + def __init__(self, source: Iterable[bytes], strict: bool = False, + verbose: int = 0): + self._source = source + self._strict = strict + self._verbose = verbose + + def __iter__(self): + return HSIterator(recs=self._parse(self._source, strict=self._strict, + verbose=self._verbose)) + + +HandShake.init() |