summaryrefslogtreecommitdiffstats
path: root/examples/tests/ngtcp2test
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--examples/tests/ngtcp2test/__init__.py6
-rw-r--r--examples/tests/ngtcp2test/certs.py476
-rw-r--r--examples/tests/ngtcp2test/client.py187
-rw-r--r--examples/tests/ngtcp2test/env.py191
-rw-r--r--examples/tests/ngtcp2test/log.py101
-rw-r--r--examples/tests/ngtcp2test/server.py137
-rw-r--r--examples/tests/ngtcp2test/tls.py983
7 files changed, 2081 insertions, 0 deletions
diff --git a/examples/tests/ngtcp2test/__init__.py b/examples/tests/ngtcp2test/__init__.py
new file mode 100644
index 0000000..65c61d8
--- /dev/null
+++ b/examples/tests/ngtcp2test/__init__.py
@@ -0,0 +1,6 @@
+from .env import Env, CryptoLib
+from .log import LogFile
+from .client import ExampleClient, ClientRun
+from .server import ExampleServer, ServerRun
+from .certs import Ngtcp2TestCA, Credentials
+from .tls import HandShake, HSRecord
diff --git a/examples/tests/ngtcp2test/certs.py b/examples/tests/ngtcp2test/certs.py
new file mode 100644
index 0000000..3ab6260
--- /dev/null
+++ b/examples/tests/ngtcp2test/certs.py
@@ -0,0 +1,476 @@
+import os
+import re
+from datetime import timedelta, datetime
+from typing import List, Any, Optional
+
+from cryptography import x509
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.asymmetric import ec, rsa
+from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
+from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
+from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption, load_pem_private_key
+from cryptography.x509 import ExtendedKeyUsageOID, NameOID
+
+
+EC_SUPPORTED = {}
+EC_SUPPORTED.update([(curve.name.upper(), curve) for curve in [
+ ec.SECP192R1,
+ ec.SECP224R1,
+ ec.SECP256R1,
+ ec.SECP384R1,
+]])
+
+
+def _private_key(key_type):
+ if isinstance(key_type, str):
+ key_type = key_type.upper()
+ m = re.match(r'^(RSA)?(\d+)$', key_type)
+ if m:
+ key_type = int(m.group(2))
+
+ if isinstance(key_type, int):
+ return rsa.generate_private_key(
+ public_exponent=65537,
+ key_size=key_type,
+ backend=default_backend()
+ )
+ if not isinstance(key_type, ec.EllipticCurve) and key_type in EC_SUPPORTED:
+ key_type = EC_SUPPORTED[key_type]
+ return ec.generate_private_key(
+ curve=key_type,
+ backend=default_backend()
+ )
+
+
+class CertificateSpec:
+
+ def __init__(self, name: str = None, domains: List[str] = None,
+ email: str = None,
+ key_type: str = None, single_file: bool = False,
+ valid_from: timedelta = timedelta(days=-1),
+ valid_to: timedelta = timedelta(days=89),
+ client: bool = False,
+ sub_specs: List['CertificateSpec'] = None):
+ self._name = name
+ self.domains = domains
+ self.client = client
+ self.email = email
+ self.key_type = key_type
+ self.single_file = single_file
+ self.valid_from = valid_from
+ self.valid_to = valid_to
+ self.sub_specs = sub_specs
+
+ @property
+ def name(self) -> Optional[str]:
+ if self._name:
+ return self._name
+ elif self.domains:
+ return self.domains[0]
+ return None
+
+ @property
+ def type(self) -> Optional[str]:
+ if self.domains and len(self.domains):
+ return "server"
+ elif self.client:
+ return "client"
+ elif self.name:
+ return "ca"
+ return None
+
+
+class Credentials:
+
+ def __init__(self, name: str, cert: Any, pkey: Any, issuer: 'Credentials' = None):
+ self._name = name
+ self._cert = cert
+ self._pkey = pkey
+ self._issuer = issuer
+ self._cert_file = None
+ self._pkey_file = None
+ self._store = None
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ @property
+ def subject(self) -> x509.Name:
+ return self._cert.subject
+
+ @property
+ def key_type(self):
+ if isinstance(self._pkey, RSAPrivateKey):
+ return f"rsa{self._pkey.key_size}"
+ elif isinstance(self._pkey, EllipticCurvePrivateKey):
+ return f"{self._pkey.curve.name}"
+ else:
+ raise Exception(f"unknown key type: {self._pkey}")
+
+ @property
+ def private_key(self) -> Any:
+ return self._pkey
+
+ @property
+ def certificate(self) -> Any:
+ return self._cert
+
+ @property
+ def cert_pem(self) -> bytes:
+ return self._cert.public_bytes(Encoding.PEM)
+
+ @property
+ def pkey_pem(self) -> bytes:
+ return self._pkey.private_bytes(
+ Encoding.PEM,
+ PrivateFormat.TraditionalOpenSSL if self.key_type.startswith('rsa') else PrivateFormat.PKCS8,
+ NoEncryption())
+
+ @property
+ def issuer(self) -> Optional['Credentials']:
+ return self._issuer
+
+ def set_store(self, store: 'CertStore'):
+ self._store = store
+
+ def set_files(self, cert_file: str, pkey_file: str = None):
+ self._cert_file = cert_file
+ self._pkey_file = pkey_file
+
+ @property
+ def cert_file(self) -> str:
+ return self._cert_file
+
+ @property
+ def pkey_file(self) -> Optional[str]:
+ return self._pkey_file
+
+ def get_first(self, name) -> Optional['Credentials']:
+ creds = self._store.get_credentials_for_name(name) if self._store else []
+ return creds[0] if len(creds) else None
+
+ def get_credentials_for_name(self, name) -> List['Credentials']:
+ return self._store.get_credentials_for_name(name) if self._store else []
+
+ def issue_certs(self, specs: List[CertificateSpec],
+ chain: List['Credentials'] = None) -> List['Credentials']:
+ return [self.issue_cert(spec=spec, chain=chain) for spec in specs]
+
+ def issue_cert(self, spec: CertificateSpec, chain: List['Credentials'] = None) -> 'Credentials':
+ key_type = spec.key_type if spec.key_type else self.key_type
+ creds = None
+ if self._store:
+ creds = self._store.load_credentials(
+ name=spec.name, key_type=key_type, single_file=spec.single_file, issuer=self)
+ if creds is None:
+ creds = Ngtcp2TestCA.create_credentials(spec=spec, issuer=self, key_type=key_type,
+ valid_from=spec.valid_from, valid_to=spec.valid_to)
+ if self._store:
+ self._store.save(creds, single_file=spec.single_file)
+ if spec.type == "ca":
+ self._store.save_chain(creds, "ca", with_root=True)
+
+ if spec.sub_specs:
+ if self._store:
+ sub_store = CertStore(fpath=os.path.join(self._store.path, creds.name))
+ creds.set_store(sub_store)
+ subchain = chain.copy() if chain else []
+ subchain.append(self)
+ creds.issue_certs(spec.sub_specs, chain=subchain)
+ return creds
+
+
+class CertStore:
+
+ def __init__(self, fpath: str):
+ self._store_dir = fpath
+ if not os.path.exists(self._store_dir):
+ os.makedirs(self._store_dir)
+ self._creds_by_name = {}
+
+ @property
+ def path(self) -> str:
+ return self._store_dir
+
+ def save(self, creds: Credentials, name: str = None,
+ chain: List[Credentials] = None,
+ single_file: bool = False) -> None:
+ name = name if name is not None else creds.name
+ cert_file = self.get_cert_file(name=name, key_type=creds.key_type)
+ pkey_file = self.get_pkey_file(name=name, key_type=creds.key_type)
+ if single_file:
+ pkey_file = None
+ with open(cert_file, "wb") as fd:
+ fd.write(creds.cert_pem)
+ if chain:
+ for c in chain:
+ fd.write(c.cert_pem)
+ if pkey_file is None:
+ fd.write(creds.pkey_pem)
+ if pkey_file is not None:
+ with open(pkey_file, "wb") as fd:
+ fd.write(creds.pkey_pem)
+ creds.set_files(cert_file, pkey_file)
+ self._add_credentials(name, creds)
+
+ def save_chain(self, creds: Credentials, infix: str, with_root=False):
+ name = creds.name
+ chain = [creds]
+ while creds.issuer is not None:
+ creds = creds.issuer
+ chain.append(creds)
+ if not with_root and len(chain) > 1:
+ chain = chain[:-1]
+ chain_file = os.path.join(self._store_dir, f'{name}-{infix}.pem')
+ with open(chain_file, "wb") as fd:
+ for c in chain:
+ fd.write(c.cert_pem)
+
+ def _add_credentials(self, name: str, creds: Credentials):
+ if name not in self._creds_by_name:
+ self._creds_by_name[name] = []
+ self._creds_by_name[name].append(creds)
+
+ def get_credentials_for_name(self, name) -> List[Credentials]:
+ return self._creds_by_name[name] if name in self._creds_by_name else []
+
+ def get_cert_file(self, name: str, key_type=None) -> str:
+ key_infix = ".{0}".format(key_type) if key_type is not None else ""
+ return os.path.join(self._store_dir, f'{name}{key_infix}.cert.pem')
+
+ def get_pkey_file(self, name: str, key_type=None) -> str:
+ key_infix = ".{0}".format(key_type) if key_type is not None else ""
+ return os.path.join(self._store_dir, f'{name}{key_infix}.pkey.pem')
+
+ def load_pem_cert(self, fpath: str) -> x509.Certificate:
+ with open(fpath) as fd:
+ return x509.load_pem_x509_certificate("".join(fd.readlines()).encode())
+
+ def load_pem_pkey(self, fpath: str):
+ with open(fpath) as fd:
+ return load_pem_private_key("".join(fd.readlines()).encode(), password=None)
+
+ def load_credentials(self, name: str, key_type=None, single_file: bool = False, issuer: Credentials = None):
+ cert_file = self.get_cert_file(name=name, key_type=key_type)
+ pkey_file = cert_file if single_file else self.get_pkey_file(name=name, key_type=key_type)
+ if os.path.isfile(cert_file) and os.path.isfile(pkey_file):
+ cert = self.load_pem_cert(cert_file)
+ pkey = self.load_pem_pkey(pkey_file)
+ creds = Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
+ creds.set_store(self)
+ creds.set_files(cert_file, pkey_file)
+ self._add_credentials(name, creds)
+ return creds
+ return None
+
+
+class Ngtcp2TestCA:
+
+ @classmethod
+ def create_root(cls, name: str, store_dir: str, key_type: str = "rsa2048") -> Credentials:
+ store = CertStore(fpath=store_dir)
+ creds = store.load_credentials(name="ca", key_type=key_type, issuer=None)
+ if creds is None:
+ creds = Ngtcp2TestCA._make_ca_credentials(name=name, key_type=key_type)
+ store.save(creds, name="ca")
+ creds.set_store(store)
+ return creds
+
+ @staticmethod
+ def create_credentials(spec: CertificateSpec, issuer: Credentials, key_type: Any,
+ valid_from: timedelta = timedelta(days=-1),
+ valid_to: timedelta = timedelta(days=89),
+ ) -> Credentials:
+ """Create a certificate signed by this CA for the given domains.
+ :returns: the certificate and private key PEM file paths
+ """
+ if spec.domains and len(spec.domains):
+ creds = Ngtcp2TestCA._make_server_credentials(name=spec.name, domains=spec.domains,
+ issuer=issuer, valid_from=valid_from,
+ valid_to=valid_to, key_type=key_type)
+ elif spec.client:
+ creds = Ngtcp2TestCA._make_client_credentials(name=spec.name, issuer=issuer,
+ email=spec.email, valid_from=valid_from,
+ valid_to=valid_to, key_type=key_type)
+ elif spec.name:
+ creds = Ngtcp2TestCA._make_ca_credentials(name=spec.name, issuer=issuer,
+ valid_from=valid_from, valid_to=valid_to,
+ key_type=key_type)
+ else:
+ raise Exception(f"unrecognized certificate specification: {spec}")
+ return creds
+
+ @staticmethod
+ def _make_x509_name(org_name: str = None, common_name: str = None, parent: x509.Name = None) -> x509.Name:
+ name_pieces = []
+ if org_name:
+ oid = NameOID.ORGANIZATIONAL_UNIT_NAME if parent else NameOID.ORGANIZATION_NAME
+ name_pieces.append(x509.NameAttribute(oid, org_name))
+ elif common_name:
+ name_pieces.append(x509.NameAttribute(NameOID.COMMON_NAME, common_name))
+ if parent:
+ name_pieces.extend([rdn for rdn in parent])
+ return x509.Name(name_pieces)
+
+ @staticmethod
+ def _make_csr(
+ subject: x509.Name,
+ pkey: Any,
+ issuer_subject: Optional[Credentials],
+ valid_from_delta: timedelta = None,
+ valid_until_delta: timedelta = None
+ ):
+ pubkey = pkey.public_key()
+ issuer_subject = issuer_subject if issuer_subject is not None else subject
+
+ valid_from = datetime.now()
+ if valid_until_delta is not None:
+ valid_from += valid_from_delta
+ valid_until = datetime.now()
+ if valid_until_delta is not None:
+ valid_until += valid_until_delta
+
+ return (
+ x509.CertificateBuilder()
+ .subject_name(subject)
+ .issuer_name(issuer_subject)
+ .public_key(pubkey)
+ .not_valid_before(valid_from)
+ .not_valid_after(valid_until)
+ .serial_number(x509.random_serial_number())
+ .add_extension(
+ x509.SubjectKeyIdentifier.from_public_key(pubkey),
+ critical=False,
+ )
+ )
+
+ @staticmethod
+ def _add_ca_usages(csr: Any) -> Any:
+ return csr.add_extension(
+ x509.BasicConstraints(ca=True, path_length=9),
+ critical=True,
+ ).add_extension(
+ x509.KeyUsage(
+ digital_signature=True,
+ content_commitment=False,
+ key_encipherment=False,
+ data_encipherment=False,
+ key_agreement=False,
+ key_cert_sign=True,
+ crl_sign=True,
+ encipher_only=False,
+ decipher_only=False),
+ critical=True
+ ).add_extension(
+ x509.ExtendedKeyUsage([
+ ExtendedKeyUsageOID.CLIENT_AUTH,
+ ExtendedKeyUsageOID.SERVER_AUTH,
+ ExtendedKeyUsageOID.CODE_SIGNING,
+ ]),
+ critical=True
+ )
+
+ @staticmethod
+ def _add_leaf_usages(csr: Any, domains: List[str], issuer: Credentials) -> Any:
+ return csr.add_extension(
+ x509.BasicConstraints(ca=False, path_length=None),
+ critical=True,
+ ).add_extension(
+ x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
+ issuer.certificate.extensions.get_extension_for_class(
+ x509.SubjectKeyIdentifier).value),
+ critical=False
+ ).add_extension(
+ x509.SubjectAlternativeName([x509.DNSName(domain) for domain in domains]),
+ critical=True,
+ ).add_extension(
+ x509.ExtendedKeyUsage([
+ ExtendedKeyUsageOID.SERVER_AUTH,
+ ]),
+ critical=True
+ )
+
+ @staticmethod
+ def _add_client_usages(csr: Any, issuer: Credentials, rfc82name: str = None) -> Any:
+ cert = csr.add_extension(
+ x509.BasicConstraints(ca=False, path_length=None),
+ critical=True,
+ ).add_extension(
+ x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
+ issuer.certificate.extensions.get_extension_for_class(
+ x509.SubjectKeyIdentifier).value),
+ critical=False
+ )
+ if rfc82name:
+ cert.add_extension(
+ x509.SubjectAlternativeName([x509.RFC822Name(rfc82name)]),
+ critical=True,
+ )
+ cert.add_extension(
+ x509.ExtendedKeyUsage([
+ ExtendedKeyUsageOID.CLIENT_AUTH,
+ ]),
+ critical=True
+ )
+ return cert
+
+ @staticmethod
+ def _make_ca_credentials(name, key_type: Any,
+ issuer: Credentials = None,
+ valid_from: timedelta = timedelta(days=-1),
+ valid_to: timedelta = timedelta(days=89),
+ ) -> Credentials:
+ pkey = _private_key(key_type=key_type)
+ if issuer is not None:
+ issuer_subject = issuer.certificate.subject
+ issuer_key = issuer.private_key
+ else:
+ issuer_subject = None
+ issuer_key = pkey
+ subject = Ngtcp2TestCA._make_x509_name(org_name=name, parent=issuer.subject if issuer else None)
+ csr = Ngtcp2TestCA._make_csr(subject=subject,
+ issuer_subject=issuer_subject, pkey=pkey,
+ valid_from_delta=valid_from, valid_until_delta=valid_to)
+ csr = Ngtcp2TestCA._add_ca_usages(csr)
+ cert = csr.sign(private_key=issuer_key,
+ algorithm=hashes.SHA256(),
+ backend=default_backend())
+ return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
+
+ @staticmethod
+ def _make_server_credentials(name: str, domains: List[str], issuer: Credentials,
+ key_type: Any,
+ valid_from: timedelta = timedelta(days=-1),
+ valid_to: timedelta = timedelta(days=89),
+ ) -> Credentials:
+ name = name
+ pkey = _private_key(key_type=key_type)
+ subject = Ngtcp2TestCA._make_x509_name(common_name=name, parent=issuer.subject)
+ csr = Ngtcp2TestCA._make_csr(subject=subject,
+ issuer_subject=issuer.certificate.subject, pkey=pkey,
+ valid_from_delta=valid_from, valid_until_delta=valid_to)
+ csr = Ngtcp2TestCA._add_leaf_usages(csr, domains=domains, issuer=issuer)
+ cert = csr.sign(private_key=issuer.private_key,
+ algorithm=hashes.SHA256(),
+ backend=default_backend())
+ return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
+
+ @staticmethod
+ def _make_client_credentials(name: str,
+ issuer: Credentials, email: Optional[str],
+ key_type: Any,
+ valid_from: timedelta = timedelta(days=-1),
+ valid_to: timedelta = timedelta(days=89),
+ ) -> Credentials:
+ pkey = _private_key(key_type=key_type)
+ subject = Ngtcp2TestCA._make_x509_name(common_name=name, parent=issuer.subject)
+ csr = Ngtcp2TestCA._make_csr(subject=subject,
+ issuer_subject=issuer.certificate.subject, pkey=pkey,
+ valid_from_delta=valid_from, valid_until_delta=valid_to)
+ csr = Ngtcp2TestCA._add_client_usages(csr, issuer=issuer, rfc82name=email)
+ cert = csr.sign(private_key=issuer.private_key,
+ algorithm=hashes.SHA256(),
+ backend=default_backend())
+ return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
diff --git a/examples/tests/ngtcp2test/client.py b/examples/tests/ngtcp2test/client.py
new file mode 100644
index 0000000..2676343
--- /dev/null
+++ b/examples/tests/ngtcp2test/client.py
@@ -0,0 +1,187 @@
+import logging
+import os
+import re
+import subprocess
+from typing import List
+
+import pytest
+
+from .server import ExampleServer, ServerRun
+from .certs import Credentials
+from .tls import HandShake, HSRecord
+from .env import Env, CryptoLib
+from .log import LogFile, HexDumpScanner
+
+
+log = logging.getLogger(__name__)
+
+
+class ClientRun:
+
+ def __init__(self, env: Env, returncode, logfile: LogFile, srun: ServerRun):
+ self.env = env
+ self.returncode = returncode
+ self.logfile = logfile
+ self.log_lines = logfile.get_recent()
+ self._data_recs = None
+ self._hs_recs = None
+ self._srun = srun
+ if self.env.verbose > 1:
+ log.debug(f'read {len(self.log_lines)} lines from {logfile.path}')
+
+ @property
+ def handshake(self) -> List[HSRecord]:
+ if self._data_recs is None:
+ crypto_line = re.compile(r'Ordered CRYPTO data in \S+ crypto level')
+ scanner = HexDumpScanner(source=self.log_lines,
+ leading_regex=crypto_line)
+ self._data_recs = [data for data in scanner]
+ if self.env.verbose > 1:
+ log.debug(f'detected {len(self._data_recs)} crypto hexdumps '
+ f'in {self.logfile.path}')
+ if self._hs_recs is None:
+ self._hs_recs = [hrec for hrec in HandShake(source=self._data_recs,
+ verbose=self.env.verbose)]
+ if self.env.verbose > 1:
+ log.debug(f'detected {len(self._hs_recs)} crypto '
+ f'records in {self.logfile.path}')
+ return self._hs_recs
+
+ @property
+ def hs_stripe(self) -> str:
+ return ":".join([hrec.name for hrec in self.handshake])
+
+ @property
+ def early_data_rejected(self) -> bool:
+ for l in self.log_lines:
+ if re.match(r'^Early data was rejected by server.*', l):
+ return True
+ return False
+
+ @property
+ def server(self) -> ServerRun:
+ return self._srun
+
+ def norm_exp(self, c_hs, s_hs, allow_hello_retry=True):
+ if allow_hello_retry and self.hs_stripe.startswith('HelloRetryRequest:'):
+ c_hs = "HelloRetryRequest:" + c_hs
+ s_hs = "ClientHello:" + s_hs
+ return c_hs, s_hs
+
+ def _assert_hs(self, c_hs, s_hs):
+ if not self.hs_stripe.startswith(c_hs):
+ # what happened?
+ if self.hs_stripe == '':
+ # server send nothing
+ if self.server.hs_stripe == '':
+ # client send nothing
+ pytest.fail(f'client did not send a ClientHello"')
+ else:
+ # client send sth, but server did not respond
+ pytest.fail(f'server did not respond to ClientHello: '
+ f'{self.server.handshake[0].to_text()}"')
+ else:
+ pytest.fail(f'Expected "{c_hs}", got "{self.hs_stripe}"')
+ assert self.server.hs_stripe == s_hs, \
+ f'Expected "{s_hs}", got "{self.server.hs_stripe}"\n'
+
+ def assert_non_resume_handshake(self, allow_hello_retry=True):
+ # for client/server where KEY_SHARE do not match, the hello is retried
+ c_hs, s_hs = self.norm_exp(
+ "ServerHello:EncryptedExtensions:Certificate:CertificateVerify:Finished",
+ "ClientHello:Finished", allow_hello_retry=allow_hello_retry)
+ self._assert_hs(c_hs, s_hs)
+
+ def assert_resume_handshake(self):
+ # for client/server where KEY_SHARE do not match, the hello is retried
+ c_hs, s_hs = self.norm_exp("ServerHello:EncryptedExtensions:Finished",
+ "ClientHello:Finished")
+ self._assert_hs(c_hs, s_hs)
+
+ def assert_verify_null_handshake(self):
+ c_hs, s_hs = self.norm_exp(
+ "ServerHello:EncryptedExtensions:CertificateRequest:Certificate:CertificateVerify:Finished",
+ "ClientHello:Certificate:Finished")
+ self._assert_hs(c_hs, s_hs)
+
+ def assert_verify_cert_handshake(self):
+ c_hs, s_hs = self.norm_exp(
+ "ServerHello:EncryptedExtensions:CertificateRequest:Certificate:CertificateVerify:Finished",
+ "ClientHello:Certificate:CertificateVerify:Finished")
+ self._assert_hs(c_hs, s_hs)
+
+
+class ExampleClient:
+
+ def __init__(self, env: Env, crypto_lib: str):
+ self.env = env
+ self._crypto_lib = crypto_lib
+ self._path = env.client_path(self._crypto_lib)
+ self._log_path = f'{self.env.gen_dir}/{self._crypto_lib}-client.log'
+ self._qlog_path = f'{self.env.gen_dir}/{self._crypto_lib}-client.qlog'
+ self._session_path = f'{self.env.gen_dir}/{self._crypto_lib}-client.session'
+ self._tp_path = f'{self.env.gen_dir}/{self._crypto_lib}-client.tp'
+ self._data_path = f'{self.env.gen_dir}/{self._crypto_lib}-client.data'
+
+ @property
+ def path(self):
+ return self._path
+
+ @property
+ def crypto_lib(self):
+ return self._crypto_lib
+
+ @property
+ def uses_cipher_config(self):
+ return CryptoLib.uses_cipher_config(self.crypto_lib)
+
+ def supports_cipher(self, cipher):
+ return CryptoLib.supports_cipher(self.crypto_lib, cipher)
+
+ def exists(self):
+ return os.path.isfile(self.path)
+
+ def clear_session(self):
+ if os.path.isfile(self._session_path):
+ os.remove(self._session_path)
+ if os.path.isfile(self._tp_path):
+ os.remove(self._tp_path)
+
+ def http_get(self, server: ExampleServer, url: str, extra_args: List[str] = None,
+ use_session=False, data=None,
+ credentials: Credentials = None,
+ ciphers: str = None):
+ args = [
+ self.path, '--exit-on-all-streams-close',
+ f'--qlog-file={self._qlog_path}'
+ ]
+ if use_session:
+ args.append(f'--session-file={self._session_path}')
+ args.append(f'--tp-file={self._tp_path}')
+ if data is not None:
+ with open(self._data_path, 'w') as fd:
+ fd.write(data)
+ args.append(f'--data={self._data_path}')
+ if ciphers is not None:
+ ciphers = CryptoLib.adjust_ciphers(self.crypto_lib, ciphers)
+ args.append(f'--ciphers={ciphers}')
+ if credentials is not None:
+ args.append(f'--key={credentials.pkey_file}')
+ args.append(f'--cert={credentials.cert_file}')
+ if extra_args is not None:
+ args.extend(extra_args)
+ args.extend([
+ 'localhost', str(self.env.examples_port),
+ url
+ ])
+ if os.path.isfile(self._qlog_path):
+ os.remove(self._qlog_path)
+ with open(self._log_path, 'w') as log_file:
+ logfile = LogFile(path=self._log_path)
+ server.log.advance()
+ process = subprocess.Popen(args=args, text=True,
+ stdout=log_file, stderr=log_file)
+ process.wait()
+ return ClientRun(env=self.env, returncode=process.returncode,
+ logfile=logfile, srun=server.get_run())
+
diff --git a/examples/tests/ngtcp2test/env.py b/examples/tests/ngtcp2test/env.py
new file mode 100644
index 0000000..9699d55
--- /dev/null
+++ b/examples/tests/ngtcp2test/env.py
@@ -0,0 +1,191 @@
+import logging
+import os
+from configparser import ConfigParser, ExtendedInterpolation
+from typing import Dict, Optional
+
+from .certs import CertificateSpec, Ngtcp2TestCA, Credentials
+
+log = logging.getLogger(__name__)
+
+
+class CryptoLib:
+
+ IGNORES_CIPHER_CONFIG = [
+ 'picotls', 'boringssl'
+ ]
+ UNSUPPORTED_CIPHERS = {
+ 'wolfssl': [
+ 'TLS_AES_128_CCM_SHA256', # no plans to
+ ],
+ 'picotls': [
+ 'TLS_AES_128_CCM_SHA256', # no plans to
+ ],
+ 'boringssl': [
+ 'TLS_AES_128_CCM_SHA256', # no plans to
+ ]
+ }
+ GNUTLS_CIPHERS = {
+ 'TLS_AES_128_GCM_SHA256': 'AES-128-GCM',
+ 'TLS_AES_256_GCM_SHA384': 'AES-256-GCM',
+ 'TLS_CHACHA20_POLY1305_SHA256': 'CHACHA20-POLY1305',
+ 'TLS_AES_128_CCM_SHA256': 'AES-128-CCM',
+ }
+
+ @classmethod
+ def uses_cipher_config(cls, crypto_lib):
+ return crypto_lib not in cls.IGNORES_CIPHER_CONFIG
+
+ @classmethod
+ def supports_cipher(cls, crypto_lib, cipher):
+ return crypto_lib not in cls.UNSUPPORTED_CIPHERS or \
+ cipher not in cls.UNSUPPORTED_CIPHERS[crypto_lib]
+
+ @classmethod
+ def adjust_ciphers(cls, crypto_lib, ciphers: str) -> str:
+ if crypto_lib == 'gnutls':
+ gciphers = "NORMAL:-VERS-ALL:+VERS-TLS1.3:-CIPHER-ALL"
+ for cipher in ciphers.split(':'):
+ gciphers += f':+{cls.GNUTLS_CIPHERS[cipher]}'
+ return gciphers
+ return ciphers
+
+
+def init_config_from(conf_path):
+ if os.path.isfile(conf_path):
+ config = ConfigParser(interpolation=ExtendedInterpolation())
+ config.read(conf_path)
+ return config
+ return None
+
+
+TESTS_PATH = os.path.dirname(os.path.dirname(__file__))
+EXAMPLES_PATH = os.path.dirname(TESTS_PATH)
+DEF_CONFIG = init_config_from(os.path.join(TESTS_PATH, 'config.ini'))
+
+
+class Env:
+
+ @classmethod
+ def get_crypto_libs(cls, configurable_ciphers=None):
+ names = [name for name in DEF_CONFIG['examples']
+ if DEF_CONFIG['examples'][name] == 'yes']
+ if configurable_ciphers is not None:
+ names = [n for n in names if CryptoLib.uses_cipher_config(n)]
+ return names
+
+ def __init__(self, examples_dir=None, tests_dir=None, config=None,
+ pytestconfig=None):
+ self._verbose = pytestconfig.option.verbose if pytestconfig is not None else 0
+ self._examples_dir = examples_dir if examples_dir is not None else EXAMPLES_PATH
+ self._tests_dir = examples_dir if tests_dir is not None else TESTS_PATH
+ self._gen_dir = os.path.join(self._tests_dir, 'gen')
+ self.config = config if config is not None else DEF_CONFIG
+ self._version = self.config['ngtcp2']['version']
+ self._crypto_libs = [name for name in self.config['examples']
+ if self.config['examples'][name] == 'yes']
+ self._clients = [self.config['clients'][lib] for lib in self._crypto_libs
+ if lib in self.config['clients']]
+ self._servers = [self.config['servers'][lib] for lib in self._crypto_libs
+ if lib in self.config['servers']]
+ self._examples_pem = {
+ 'key': 'xxx',
+ 'cert': 'xxx',
+ }
+ self._htdocs_dir = os.path.join(self._gen_dir, 'htdocs')
+ self._tld = 'tests.ngtcp2.nghttp2.org'
+ self._example_domain = f"one.{self._tld}"
+ self._ca = None
+ self._cert_specs = [
+ CertificateSpec(domains=[self._example_domain], key_type='rsa2048'),
+ CertificateSpec(name="clientsX", sub_specs=[
+ CertificateSpec(name="user1", client=True),
+ ]),
+ ]
+
+ def issue_certs(self):
+ if self._ca is None:
+ self._ca = Ngtcp2TestCA.create_root(name=self._tld,
+ store_dir=os.path.join(self.gen_dir, 'ca'),
+ key_type="rsa2048")
+ self._ca.issue_certs(self._cert_specs)
+
+ def setup(self):
+ os.makedirs(self._gen_dir, exist_ok=True)
+ os.makedirs(self._htdocs_dir, exist_ok=True)
+ self.issue_certs()
+
+ def get_server_credentials(self) -> Optional[Credentials]:
+ creds = self.ca.get_credentials_for_name(self._example_domain)
+ if len(creds) > 0:
+ return creds[0]
+ return None
+
+ @property
+ def verbose(self) -> int:
+ return self._verbose
+
+ @property
+ def version(self) -> str:
+ return self._version
+
+ @property
+ def gen_dir(self) -> str:
+ return self._gen_dir
+
+ @property
+ def ca(self):
+ return self._ca
+
+ @property
+ def htdocs_dir(self) -> str:
+ return self._htdocs_dir
+
+ @property
+ def example_domain(self) -> str:
+ return self._example_domain
+
+ @property
+ def examples_dir(self) -> str:
+ return self._examples_dir
+
+ @property
+ def examples_port(self) -> int:
+ return int(self.config['examples']['port'])
+
+ @property
+ def examples_pem(self) -> Dict[str, str]:
+ return self._examples_pem
+
+ @property
+ def crypto_libs(self):
+ return self._crypto_libs
+
+ @property
+ def clients(self):
+ return self._clients
+
+ @property
+ def servers(self):
+ return self._servers
+
+ def client_name(self, crypto_lib):
+ if crypto_lib in self.config['clients']:
+ return self.config['clients'][crypto_lib]
+ return None
+
+ def client_path(self, crypto_lib):
+ cname = self.client_name(crypto_lib)
+ if cname is not None:
+ return os.path.join(self.examples_dir, cname)
+ return None
+
+ def server_name(self, crypto_lib):
+ if crypto_lib in self.config['servers']:
+ return self.config['servers'][crypto_lib]
+ return None
+
+ def server_path(self, crypto_lib):
+ sname = self.server_name(crypto_lib)
+ if sname is not None:
+ return os.path.join(self.examples_dir, sname)
+ return None
diff --git a/examples/tests/ngtcp2test/log.py b/examples/tests/ngtcp2test/log.py
new file mode 100644
index 0000000..9e8f399
--- /dev/null
+++ b/examples/tests/ngtcp2test/log.py
@@ -0,0 +1,101 @@
+import binascii
+import os
+import re
+import sys
+import time
+from datetime import timedelta, datetime
+from io import SEEK_END
+from typing import List
+
+
+class LogFile:
+
+ def __init__(self, path: str):
+ self._path = path
+ self._start_pos = 0
+ self._last_pos = self._start_pos
+
+ @property
+ def path(self) -> str:
+ return self._path
+
+ def reset(self):
+ self._start_pos = 0
+ self._last_pos = self._start_pos
+
+ def advance(self) -> None:
+ if os.path.isfile(self._path):
+ with open(self._path) as fd:
+ self._start_pos = fd.seek(0, SEEK_END)
+
+ def get_recent(self, advance=True) -> List[str]:
+ lines = []
+ if os.path.isfile(self._path):
+ with open(self._path) as fd:
+ fd.seek(self._last_pos, os.SEEK_SET)
+ for line in fd:
+ lines.append(line)
+ if advance:
+ self._last_pos = fd.tell()
+ return lines
+
+ def scan_recent(self, pattern: re, timeout=10) -> bool:
+ if not os.path.isfile(self.path):
+ return False
+ with open(self.path) as fd:
+ end = datetime.now() + timedelta(seconds=timeout)
+ while True:
+ fd.seek(self._last_pos, os.SEEK_SET)
+ for line in fd:
+ if pattern.match(line):
+ return True
+ if datetime.now() > end:
+ raise TimeoutError(f"pattern not found in error log after {timeout} seconds")
+ time.sleep(.1)
+ return False
+
+
+class HexDumpScanner:
+
+ def __init__(self, source, leading_regex=None):
+ self._source = source
+ self._leading_regex = leading_regex
+
+ def __iter__(self):
+ data = b''
+ offset = 0 if self._leading_regex is None else -1
+ idx = 0
+ for l in self._source:
+ if offset == -1:
+ pass
+ elif offset == 0:
+ # possible start of a hex dump
+ m = re.match(r'^\s*0+(\s+-)?((\s+[0-9a-f]{2}){1,16})(\s+.*)$',
+ l, re.IGNORECASE)
+ if m:
+ data = binascii.unhexlify(re.sub(r'\s+', '', m.group(2)))
+ offset = 16
+ idx = 1
+ continue
+ else:
+ # possible continuation of a hexdump
+ m = re.match(r'^\s*([0-9a-f]+)(\s+-)?((\s+[0-9a-f]{2}){1,16})'
+ r'(\s+.*)$', l, re.IGNORECASE)
+ if m:
+ loffset = int(m.group(1), 16)
+ if loffset == offset or loffset == idx:
+ data += binascii.unhexlify(re.sub(r'\s+', '',
+ m.group(3)))
+ offset += 16
+ idx += 1
+ continue
+ else:
+ sys.stderr.write(f'wrong offset {loffset}, expected {offset} or {idx}\n')
+ # not a hexdump line, produce any collected data
+ if len(data) > 0:
+ yield data
+ data = b''
+ offset = 0 if self._leading_regex is None \
+ or self._leading_regex.match(l) else -1
+ if len(data) > 0:
+ yield data
diff --git a/examples/tests/ngtcp2test/server.py b/examples/tests/ngtcp2test/server.py
new file mode 100644
index 0000000..9f4e9a0
--- /dev/null
+++ b/examples/tests/ngtcp2test/server.py
@@ -0,0 +1,137 @@
+import logging
+import os
+import re
+import subprocess
+import time
+from datetime import datetime, timedelta
+from threading import Thread
+
+from .tls import HandShake
+from .env import Env, CryptoLib
+from .log import LogFile, HexDumpScanner
+
+
+log = logging.getLogger(__name__)
+
+
+class ServerRun:
+
+ def __init__(self, env: Env, logfile: LogFile):
+ self.env = env
+ self._logfile = logfile
+ self.log_lines = self._logfile.get_recent()
+ self._data_recs = None
+ self._hs_recs = None
+ if self.env.verbose > 1:
+ log.debug(f'read {len(self.log_lines)} lines from {logfile.path}')
+
+ @property
+ def handshake(self):
+ if self._data_recs is None:
+ self._data_recs = [data for data in HexDumpScanner(source=self.log_lines)]
+ if self.env.verbose > 1:
+ log.debug(f'detected {len(self._data_recs)} hexdumps '
+ f'in {self._logfile.path}')
+ if self._hs_recs is None:
+ self._hs_recs = [hrec for hrec in HandShake(source=self._data_recs,
+ verbose=self.env.verbose)]
+ if self.env.verbose > 1:
+ log.debug(f'detected {len(self._hs_recs)} crypto records '
+ f'in {self._logfile.path}')
+ return self._hs_recs
+
+ @property
+ def hs_stripe(self):
+ return ":".join([hrec.name for hrec in self.handshake])
+
+
+def monitor_proc(env: Env, proc):
+ _env = env
+ proc.wait()
+
+
+class ExampleServer:
+
+ def __init__(self, env: Env, crypto_lib: str, verify_client=False):
+ self.env = env
+ self._crypto_lib = crypto_lib
+ self._path = env.server_path(self._crypto_lib)
+ self._logpath = f'{self.env.gen_dir}/{self._crypto_lib}-server.log'
+ self._log = LogFile(path=self._logpath)
+ self._logfile = None
+ self._process = None
+ self._verify_client = verify_client
+
+ @property
+ def path(self):
+ return self._path
+
+ @property
+ def crypto_lib(self):
+ return self._crypto_lib
+
+ @property
+ def uses_cipher_config(self):
+ return CryptoLib.uses_cipher_config(self.crypto_lib)
+
+ def supports_cipher(self, cipher):
+ return CryptoLib.supports_cipher(self.crypto_lib, cipher)
+
+ @property
+ def log(self):
+ return self._log
+
+ def exists(self):
+ return os.path.isfile(self.path)
+
+ def start(self):
+ if self._process is not None:
+ return False
+ creds = self.env.get_server_credentials()
+ assert creds
+ args = [
+ self.path,
+ f'--htdocs={self.env.htdocs_dir}',
+ ]
+ if self._verify_client:
+ args.append('--verify-client')
+ args.extend([
+ '*', str(self.env.examples_port),
+ creds.pkey_file, creds.cert_file
+ ])
+ self._logfile = open(self._logpath, 'w')
+ self._process = subprocess.Popen(args=args, text=True,
+ stdout=self._logfile, stderr=self._logfile)
+ t = Thread(target=monitor_proc, daemon=True, args=(self.env, self._process))
+ t.start()
+ timeout = 5
+ end = datetime.now() + timedelta(seconds=timeout)
+ while True:
+ if self._process.poll():
+ return False
+ try:
+ if self.log.scan_recent(pattern=re.compile(r'^Using document root'), timeout=0.5):
+ break
+ except TimeoutError:
+ pass
+ if datetime.now() > end:
+ raise TimeoutError(f"pattern not found in error log after {timeout} seconds")
+ self.log.advance()
+ return True
+
+ def stop(self):
+ if self._process:
+ self._process.terminate()
+ self._process = None
+ if self._logfile:
+ self._logfile.close()
+ self._logfile = None
+ return True
+
+ def restart(self):
+ self.stop()
+ self._log.reset()
+ return self.start()
+
+ def get_run(self) -> ServerRun:
+ return ServerRun(env=self.env, logfile=self.log)
diff --git a/examples/tests/ngtcp2test/tls.py b/examples/tests/ngtcp2test/tls.py
new file mode 100644
index 0000000..f9bce62
--- /dev/null
+++ b/examples/tests/ngtcp2test/tls.py
@@ -0,0 +1,983 @@
+import binascii
+import logging
+import sys
+from collections.abc import Iterator
+from typing import Dict, Any, Iterable
+
+
+log = logging.getLogger(__name__)
+
+
+class ParseError(Exception):
+ pass
+
+def _get_int(d, n):
+ if len(d) < n:
+ raise ParseError(f'get_int: {n} bytes needed, but data is {d}')
+ if n == 1:
+ dlen = d[0]
+ else:
+ dlen = int.from_bytes(d[0:n], byteorder='big')
+ return d[n:], dlen
+
+
+def _get_field(d, dlen):
+ if dlen > 0:
+ if len(d) < dlen:
+ raise ParseError(f'field len={dlen}, but data len={len(d)}')
+ field = d[0:dlen]
+ return d[dlen:], field
+ return d, b''
+
+
+def _get_len_field(d, n):
+ d, dlen = _get_int(d, n)
+ return _get_field(d, dlen)
+
+
+# d are bytes that start with a quic variable length integer
+def _get_qint(d):
+ i = d[0] & 0xc0
+ if i == 0:
+ return d[1:], int(d[0])
+ elif i == 0x40:
+ ndata = bytearray(d[0:2])
+ d = d[2:]
+ ndata[0] = ndata[0] & ~0xc0
+ return d, int.from_bytes(ndata, byteorder='big')
+ elif i == 0x80:
+ ndata = bytearray(d[0:4])
+ d = d[4:]
+ ndata[0] = ndata[0] & ~0xc0
+ return d, int.from_bytes(ndata, byteorder='big')
+ else:
+ ndata = bytearray(d[0:8])
+ d = d[8:]
+ ndata[0] = ndata[0] & ~0xc0
+ return d, int.from_bytes(ndata, byteorder='big')
+
+
+class TlsSupportedGroups:
+ NAME_BY_ID = {
+ 0: 'Reserved',
+ 1: 'sect163k1',
+ 2: 'sect163r1',
+ 3: 'sect163r2',
+ 4: 'sect193r1',
+ 5: 'sect193r2',
+ 6: 'sect233k1',
+ 7: 'sect233r1',
+ 8: 'sect239k1',
+ 9: 'sect283k1',
+ 10: 'sect283r1',
+ 11: 'sect409k1',
+ 12: 'sect409r1',
+ 13: 'sect571k1',
+ 14: 'sect571r1',
+ 15: 'secp160k1',
+ 16: 'secp160r1',
+ 17: 'secp160r2',
+ 18: 'secp192k1',
+ 19: 'secp192r1',
+ 20: 'secp224k1',
+ 21: 'secp224r1',
+ 22: 'secp256k1',
+ 23: 'secp256r1',
+ 24: 'secp384r1',
+ 25: 'secp521r1',
+ 26: 'brainpoolP256r1',
+ 27: 'brainpoolP384r1',
+ 28: 'brainpoolP512r1',
+ 29: 'x25519',
+ 30: 'x448',
+ 31: 'brainpoolP256r1tls13',
+ 32: 'brainpoolP384r1tls13',
+ 33: 'brainpoolP512r1tls13',
+ 34: 'GC256A',
+ 35: 'GC256B',
+ 36: 'GC256C',
+ 37: 'GC256D',
+ 38: 'GC512A',
+ 39: 'GC512B',
+ 40: 'GC512C',
+ 41: 'curveSM2',
+ }
+
+ @classmethod
+ def name(cls, gid):
+ if gid in cls.NAME_BY_ID:
+ return f'{cls.NAME_BY_ID[gid]}(0x{gid:0x})'
+ return f'0x{gid:0x}'
+
+
+class TlsSignatureScheme:
+ NAME_BY_ID = {
+ 0x0201: 'rsa_pkcs1_sha1',
+ 0x0202: 'Reserved',
+ 0x0203: 'ecdsa_sha1',
+ 0x0401: 'rsa_pkcs1_sha256',
+ 0x0403: 'ecdsa_secp256r1_sha256',
+ 0x0420: 'rsa_pkcs1_sha256_legacy',
+ 0x0501: 'rsa_pkcs1_sha384',
+ 0x0503: 'ecdsa_secp384r1_sha384',
+ 0x0520: 'rsa_pkcs1_sha384_legacy',
+ 0x0601: 'rsa_pkcs1_sha512',
+ 0x0603: 'ecdsa_secp521r1_sha512',
+ 0x0620: 'rsa_pkcs1_sha512_legacy',
+ 0x0704: 'eccsi_sha256',
+ 0x0705: 'iso_ibs1',
+ 0x0706: 'iso_ibs2',
+ 0x0707: 'iso_chinese_ibs',
+ 0x0708: 'sm2sig_sm3',
+ 0x0709: 'gostr34102012_256a',
+ 0x070A: 'gostr34102012_256b',
+ 0x070B: 'gostr34102012_256c',
+ 0x070C: 'gostr34102012_256d',
+ 0x070D: 'gostr34102012_512a',
+ 0x070E: 'gostr34102012_512b',
+ 0x070F: 'gostr34102012_512c',
+ 0x0804: 'rsa_pss_rsae_sha256',
+ 0x0805: 'rsa_pss_rsae_sha384',
+ 0x0806: 'rsa_pss_rsae_sha512',
+ 0x0807: 'ed25519',
+ 0x0808: 'ed448',
+ 0x0809: 'rsa_pss_pss_sha256',
+ 0x080A: 'rsa_pss_pss_sha384',
+ 0x080B: 'rsa_pss_pss_sha512',
+ 0x081A: 'ecdsa_brainpoolP256r1tls13_sha256',
+ 0x081B: 'ecdsa_brainpoolP384r1tls13_sha384',
+ 0x081C: 'ecdsa_brainpoolP512r1tls13_sha512',
+ }
+
+ @classmethod
+ def name(cls, gid):
+ if gid in cls.NAME_BY_ID:
+ return f'{cls.NAME_BY_ID[gid]}(0x{gid:0x})'
+ return f'0x{gid:0x}'
+
+
+class TlsCipherSuites:
+ NAME_BY_ID = {
+ 0x1301: 'TLS_AES_128_GCM_SHA256',
+ 0x1302: 'TLS_AES_256_GCM_SHA384',
+ 0x1303: 'TLS_CHACHA20_POLY1305_SHA256',
+ 0x1304: 'TLS_AES_128_CCM_SHA256',
+ 0x1305: 'TLS_AES_128_CCM_8_SHA256',
+ 0x00ff: 'TLS_EMPTY_RENEGOTIATION_INFO_SCSV',
+ }
+
+ @classmethod
+ def name(cls, cid):
+ if cid in cls.NAME_BY_ID:
+ return f'{cls.NAME_BY_ID[cid]}(0x{cid:0x})'
+ return f'Cipher(0x{cid:0x})'
+
+
+class PskKeyExchangeMode:
+ NAME_BY_ID = {
+ 0x00: 'psk_ke',
+ 0x01: 'psk_dhe_ke',
+ }
+
+ @classmethod
+ def name(cls, gid):
+ if gid in cls.NAME_BY_ID:
+ return f'{cls.NAME_BY_ID[gid]}(0x{gid:0x})'
+ return f'0x{gid:0x}'
+
+
+class QuicTransportParam:
+ NAME_BY_ID = {
+ 0x00: 'original_destination_connection_id',
+ 0x01: 'max_idle_timeout',
+ 0x02: 'stateless_reset_token',
+ 0x03: 'max_udp_payload_size',
+ 0x04: 'initial_max_data',
+ 0x05: 'initial_max_stream_data_bidi_local',
+ 0x06: 'initial_max_stream_data_bidi_remote',
+ 0x07: 'initial_max_stream_data_uni',
+ 0x08: 'initial_max_streams_bidi',
+ 0x09: 'initial_max_streams_uni',
+ 0x0a: 'ack_delay_exponent',
+ 0x0b: 'max_ack_delay',
+ 0x0c: 'disable_active_migration',
+ 0x0d: 'preferred_address',
+ 0x0e: 'active_connection_id_limit',
+ 0x0f: 'initial_source_connection_id',
+ 0x10: 'retry_source_connection_id',
+ }
+ TYPE_BY_ID = {
+ 0x00: bytes,
+ 0x01: int,
+ 0x02: bytes,
+ 0x03: int,
+ 0x04: int,
+ 0x05: int,
+ 0x06: int,
+ 0x07: int,
+ 0x08: int,
+ 0x09: int,
+ 0x0a: int,
+ 0x0b: int,
+ 0x0c: int,
+ 0x0d: bytes,
+ 0x0e: int,
+ 0x0f: bytes,
+ 0x10: bytes,
+ }
+
+ @classmethod
+ def name(cls, cid):
+ if cid in cls.NAME_BY_ID:
+ return f'{cls.NAME_BY_ID[cid]}(0x{cid:0x})'
+ return f'QuicTP(0x{cid:0x})'
+
+ @classmethod
+ def is_qint(cls, cid):
+ if cid in cls.TYPE_BY_ID:
+ return cls.TYPE_BY_ID[cid] == int
+ return False
+
+
+class Extension:
+
+ def __init__(self, eid, name, edata, hsid):
+ self._eid = eid
+ self._name = name
+ self._edata = edata
+ self._hsid = hsid
+
+ @property
+ def data(self):
+ return self._edata
+
+ @property
+ def hsid(self):
+ return self._hsid
+
+ def to_json(self):
+ jdata = {
+ 'id': self._eid,
+ 'name': self._name,
+ }
+ if len(self.data) > 0:
+ jdata['data'] = binascii.hexlify(self.data).decode()
+ return jdata
+
+ def to_text(self, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ s = f'{ind}{self._name}(0x{self._eid:0x})'
+ if len(self._edata):
+ s += f'\n{ind} data({len(self._edata)}): ' \
+ f'{binascii.hexlify(self._edata).decode()}'
+ return s
+
+
+class ExtSupportedGroups(Extension):
+
+ def __init__(self, eid, name, edata, hsid):
+ super().__init__(eid=eid, name=name, edata=edata, hsid=hsid)
+ d = edata
+ self._groups = []
+ while len(d) > 0:
+ d, gid = _get_int(d, 2)
+ self._groups.append(gid)
+
+ def to_json(self):
+ jdata = {
+ 'id': self._eid,
+ 'name': self._name,
+ }
+ if len(self._groups):
+ jdata['groups'] = [TlsSupportedGroups.name(gid)
+ for gid in self._groups]
+ return jdata
+
+ def to_text(self, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ gnames = [TlsSupportedGroups.name(gid) for gid in self._groups]
+ s = f'{ind}{self._name}(0x{self._eid:0x}): {", ".join(gnames)}'
+ return s
+
+
+class ExtKeyShare(Extension):
+
+ def __init__(self, eid, name, edata, hsid):
+ super().__init__(eid=eid, name=name, edata=edata, hsid=hsid)
+ d = self.data
+ self._keys = []
+ self._group = None
+ self._pubkey = None
+ if self.hsid == 2: # ServerHello
+ # single key share (group, pubkey)
+ d, self._group = _get_int(d, 2)
+ d, self._pubkey = _get_len_field(d, 2)
+ elif self.hsid == 6: # HelloRetryRequest
+ assert len(d) == 2
+ d, self._group = _get_int(d, 2)
+ else:
+ # list if key shares (group, pubkey)
+ d, shares = _get_len_field(d, 2)
+ while len(shares) > 0:
+ shares, group = _get_int(shares, 2)
+ shares, pubkey = _get_len_field(shares, 2)
+ self._keys.append({
+ 'group': TlsSupportedGroups.name(group),
+ 'pubkey': binascii.hexlify(pubkey).decode()
+ })
+
+ def to_json(self):
+ jdata = super().to_json()
+ if self._group is not None:
+ jdata['group'] = TlsSupportedGroups.name(self._group)
+ if self._pubkey is not None:
+ jdata['pubkey'] = binascii.hexlify(self._pubkey).decode()
+ if len(self._keys) > 0:
+ jdata['keys'] = self._keys
+ return jdata
+
+ def to_text(self, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ s = f'{ind}{self._name}(0x{self._eid:0x})'
+ if self._group is not None:
+ s += f'\n{ind} group: {TlsSupportedGroups.name(self._group)}'
+ if self._pubkey is not None:
+ s += f'\n{ind} pubkey: {binascii.hexlify(self._pubkey).decode()}'
+ if len(self._keys) > 0:
+ for idx, key in enumerate(self._keys):
+ s += f'\n{ind} {idx}: {key["group"]}, {key["pubkey"]}'
+ return s
+
+
+class ExtSNI(Extension):
+
+ def __init__(self, eid, name, edata, hsid):
+ super().__init__(eid=eid, name=name, edata=edata, hsid=hsid)
+ d = self.data
+ self._indicators = []
+ while len(d) > 0:
+ d, entry = _get_len_field(d, 2)
+ entry, stype = _get_int(entry, 1)
+ entry, sname = _get_len_field(entry, 2)
+ self._indicators.append({
+ 'type': stype,
+ 'name': sname.decode()
+ })
+
+ def to_json(self):
+ jdata = super().to_json()
+ for i in self._indicators:
+ if i['type'] == 0:
+ jdata['host_name'] = i['name']
+ else:
+ jdata[f'0x{i["type"]}'] = i['name']
+ return jdata
+
+ def to_text(self, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ s = f'{ind}{self._name}(0x{self._eid:0x})'
+ if len(self._indicators) == 1 and self._indicators[0]['type'] == 0:
+ s += f': {self._indicators[0]["name"]}'
+ else:
+ for i in self._indicators:
+ ikey = 'host_name' if i["type"] == 0 else f'type(0x{i["type"]:0x}'
+ s += f'\n{ind} {ikey}: {i["name"]}'
+ return s
+
+
+class ExtALPN(Extension):
+
+ def __init__(self, eid, name, edata, hsid):
+ super().__init__(eid=eid, name=name, edata=edata, hsid=hsid)
+ d = self.data
+ d, list_len = _get_int(d, 2)
+ self._protocols = []
+ while len(d) > 0:
+ d, proto = _get_len_field(d, 1)
+ self._protocols.append(proto.decode())
+
+ def to_json(self):
+ jdata = super().to_json()
+ if len(self._protocols) == 1:
+ jdata['alpn'] = self._protocols[0]
+ else:
+ jdata['alpn'] = self._protocols
+ return jdata
+
+ def to_text(self, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ return f'{ind}{self._name}(0x{self._eid:0x}): {", ".join(self._protocols)}'
+
+
+class ExtEarlyData(Extension):
+
+ def __init__(self, eid, name, edata, hsid):
+ super().__init__(eid=eid, name=name, edata=edata, hsid=hsid)
+ self._max_size = None
+ d = self.data
+ if hsid == 4: # SessionTicket
+ assert len(d) == 4, f'expected 4, len is {len(d)} data={d}'
+ d, self._max_size = _get_int(d, 4)
+ else:
+ assert len(d) == 0
+
+ def to_json(self):
+ jdata = super().to_json()
+ if self._max_size is not None:
+ jdata['max_size'] = self._max_size
+ return jdata
+
+
+class ExtSignatureAlgorithms(Extension):
+
+ def __init__(self, eid, name, edata, hsid):
+ super().__init__(eid=eid, name=name, edata=edata, hsid=hsid)
+ d = self.data
+ d, list_len = _get_int(d, 2)
+ self._algos = []
+ while len(d) > 0:
+ d, algo = _get_int(d, 2)
+ self._algos.append(TlsSignatureScheme.name(algo))
+
+ def to_json(self):
+ jdata = super().to_json()
+ if len(self._algos) > 0:
+ jdata['algorithms'] = self._algos
+ return jdata
+
+ def to_text(self, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ return f'{ind}{self._name}(0x{self._eid:0x}): {", ".join(self._algos)}'
+
+
+class ExtPSKExchangeModes(Extension):
+
+ def __init__(self, eid, name, edata, hsid):
+ super().__init__(eid=eid, name=name, edata=edata, hsid=hsid)
+ d = self.data
+ d, list_len = _get_int(d, 1)
+ self._modes = []
+ while len(d) > 0:
+ d, mode = _get_int(d, 1)
+ self._modes.append(PskKeyExchangeMode.name(mode))
+
+ def to_json(self):
+ jdata = super().to_json()
+ jdata['modes'] = self._modes
+ return jdata
+
+ def to_text(self, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ return f'{ind}{self._name}(0x{self._eid:0x}): {", ".join(self._modes)}'
+
+
+class ExtPreSharedKey(Extension):
+
+ def __init__(self, eid, name, edata, hsid):
+ super().__init__(eid=eid, name=name, edata=edata, hsid=hsid)
+ self._kid = None
+ self._identities = None
+ self._binders = None
+ d = self.data
+ if hsid == 1: # client hello
+ d, idata = _get_len_field(d, 2)
+ self._identities = []
+ while len(idata):
+ idata, identity = _get_len_field(idata, 2)
+ idata, obfs_age = _get_int(idata, 4)
+ self._identities.append({
+ 'id': binascii.hexlify(identity).decode(),
+ 'age': obfs_age,
+ })
+ d, binders = _get_len_field(d, 2)
+ self._binders = []
+ while len(binders) > 0:
+ binders, hmac = _get_len_field(binders, 1)
+ self._binders.append(binascii.hexlify(hmac).decode())
+ assert len(d) == 0
+ else:
+ d, self._kid = _get_int(d, 2)
+
+ def to_json(self):
+ jdata = super().to_json()
+ if self.hsid == 1:
+ jdata['identities'] = self._identities
+ jdata['binders'] = self._binders
+ else:
+ jdata['identity'] = self._kid
+ return jdata
+
+ def to_text(self, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ s = f'{ind}{self._name}(0x{self._hsid:0x})'
+ if self.hsid == 1:
+ for idx, i in enumerate(self._identities):
+ s += f'\n{ind} {idx}: {i["id"]} ({i["age"]})'
+ s += f'\n{ind} binders: {self._binders}'
+ else:
+ s += f'\n{ind} identity: {self._kid}'
+ return s
+
+
+class ExtSupportedVersions(Extension):
+
+ def __init__(self, eid, name, edata, hsid):
+ super().__init__(eid=eid, name=name, edata=edata, hsid=hsid)
+ d = self.data
+ self._versions = []
+ if hsid == 1: # client hello
+ d, list_len = _get_int(d, 1)
+ while len(d) > 0:
+ d, version = _get_int(d, 2)
+ self._versions.append(f'0x{version:0x}')
+ else:
+ d, version = _get_int(d, 2)
+ self._versions.append(f'0x{version:0x}')
+
+ def to_json(self):
+ jdata = super().to_json()
+ if len(self._versions) == 1:
+ jdata['version'] = self._versions[0]
+ else:
+ jdata['versions'] = self._versions
+ return jdata
+
+ def to_text(self, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ return f'{ind}{self._name}(0x{self._eid:0x}): {", ".join(self._versions)}'
+
+
+class ExtQuicTP(Extension):
+
+ def __init__(self, eid, name, edata, hsid):
+ super().__init__(eid=eid, name=name, edata=edata, hsid=hsid)
+ d = self.data
+ self._params = []
+ while len(d) > 0:
+ d, ptype = _get_qint(d)
+ d, plen = _get_qint(d)
+ d, pvalue = _get_field(d, plen)
+ if QuicTransportParam.is_qint(ptype):
+ _, pvalue = _get_qint(pvalue)
+ else:
+ pvalue = binascii.hexlify(pvalue).decode()
+ self._params.append({
+ 'key': QuicTransportParam.name(ptype),
+ 'value': pvalue,
+ })
+
+ def to_json(self):
+ jdata = super().to_json()
+ jdata['params'] = self._params
+ return jdata
+
+ def to_text(self, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ s = f'{ind}{self._name}(0x{self._eid:0x})'
+ for p in self._params:
+ s += f'\n{ind} {p["key"]}: {p["value"]}'
+ return s
+
+
+class TlsExtensions:
+
+ EXT_TYPES = [
+ (0x00, 'SNI', ExtSNI),
+ (0x01, 'MAX_FRAGMENT_LENGTH', Extension),
+ (0x03, 'TRUSTED_CA_KEYS', Extension),
+ (0x04, 'TRUNCATED_HMAC', Extension),
+ (0x05, 'OSCP_STATUS_REQUEST', Extension),
+ (0x0a, 'SUPPORTED_GROUPS', ExtSupportedGroups),
+ (0x0b, 'EC_POINT_FORMATS', Extension),
+ (0x0d, 'SIGNATURE_ALGORITHMS', ExtSignatureAlgorithms),
+ (0x0e, 'USE_SRTP', Extension),
+ (0x10, 'ALPN', ExtALPN),
+ (0x11, 'STATUS_REQUEST_V2', Extension),
+ (0x16, 'ENCRYPT_THEN_MAC', Extension),
+ (0x17, 'EXTENDED_MASTER_SECRET', Extension),
+ (0x23, 'SESSION_TICKET', Extension),
+ (0x29, 'PRE_SHARED_KEY', ExtPreSharedKey),
+ (0x2a, 'EARLY_DATA', ExtEarlyData),
+ (0x2b, 'SUPPORTED_VERSIONS', ExtSupportedVersions),
+ (0x2c, 'COOKIE', Extension),
+ (0x2d, 'PSK_KEY_EXCHANGE_MODES', ExtPSKExchangeModes),
+ (0x31, 'POST_HANDSHAKE_AUTH', Extension),
+ (0x32, 'SIGNATURE_ALGORITHMS_CERT', Extension),
+ (0x33, 'KEY_SHARE', ExtKeyShare),
+ (0x39, 'QUIC_TP_PARAMS', ExtQuicTP),
+ (0xff01, 'RENEGOTIATION_INFO', Extension),
+ (0xffa5, 'QUIC_TP_PARAMS_DRAFT', ExtQuicTP),
+ ]
+ NAME_BY_ID = {}
+ CLASS_BY_ID = {}
+
+ @classmethod
+ def init(cls):
+ for (eid, name, ecls) in cls.EXT_TYPES:
+ cls.NAME_BY_ID[eid] = name
+ cls.CLASS_BY_ID[eid] = ecls
+
+ @classmethod
+ def from_data(cls, hsid, data):
+ exts = []
+ d = data
+ while len(d):
+ d, eid = _get_int(d, 2)
+ d, elen = _get_int(d, 2)
+ d, edata = _get_field(d, elen)
+ if eid in cls.NAME_BY_ID:
+ ename = cls.NAME_BY_ID[eid]
+ ecls = cls.CLASS_BY_ID[eid]
+ exts.append(ecls(eid=eid, name=ename, edata=edata, hsid=hsid))
+ else:
+ exts.append(Extension(eid=eid, name=f'(0x{eid:0x})',
+ edata=edata, hsid=hsid))
+ return exts
+
+
+TlsExtensions.init()
+
+
+class HSRecord:
+
+ def __init__(self, hsid: int, name: str, data):
+ self._hsid = hsid
+ self._name = name
+ self._data = data
+
+ @property
+ def hsid(self):
+ return self._hsid
+
+ @property
+ def name(self):
+ return self._name
+
+ @name.setter
+ def name(self, value):
+ self._name = value
+
+ @property
+ def data(self):
+ return self._data
+
+ def __repr__(self):
+ return f'{self.name}[{binascii.hexlify(self._data).decode()}]'
+
+ def to_json(self) -> Dict[str, Any]:
+ return {
+ 'name': self.name,
+ 'data': binascii.hexlify(self._data).decode(),
+ }
+
+ def to_text(self, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ return f'{ind}{self._name}\n'\
+ f'{ind} id: 0x{self._hsid:0x}\n'\
+ f'{ind} data({len(self._data)}): '\
+ f'{binascii.hexlify(self._data).decode()}'
+
+
+class ClientHello(HSRecord):
+
+ def __init__(self, hsid: int, name: str, data):
+ super().__init__(hsid=hsid, name=name, data=data)
+ d = data
+ d, self._version = _get_int(d, 2)
+ d, self._random = _get_field(d, 32)
+ d, self._session_id = _get_len_field(d, 1)
+ self._ciphers = []
+ d, ciphers = _get_len_field(d, 2)
+ while len(ciphers):
+ ciphers, cipher = _get_int(ciphers, 2)
+ self._ciphers.append(TlsCipherSuites.name(cipher))
+ d, comps = _get_len_field(d, 1)
+ self._compressions = [int(c) for c in comps]
+ d, edata = _get_len_field(d, 2)
+ self._extensions = TlsExtensions.from_data(hsid, edata)
+
+ def to_json(self):
+ jdata = super().to_json()
+ jdata['version'] = f'0x{self._version:0x}'
+ jdata['random'] = f'{binascii.hexlify(self._random).decode()}'
+ jdata['session_id'] = binascii.hexlify(self._session_id).decode()
+ jdata['ciphers'] = self._ciphers
+ jdata['compressions'] = self._compressions
+ jdata['extensions'] = [ext.to_json() for ext in self._extensions]
+ return jdata
+
+ def to_text(self, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ return super().to_text(indent=indent) + '\n'\
+ f'{ind} version: 0x{self._version:0x}\n'\
+ f'{ind} random: {binascii.hexlify(self._random).decode()}\n' \
+ f'{ind} session_id: {binascii.hexlify(self._session_id).decode()}\n' \
+ f'{ind} ciphers: {", ".join(self._ciphers)}\n'\
+ f'{ind} compressions: {self._compressions}\n'\
+ f'{ind} extensions: \n' + '\n'.join(
+ [ext.to_text(indent=indent+4) for ext in self._extensions])
+
+
+class ServerHello(HSRecord):
+
+ HELLO_RETRY_RANDOM = binascii.unhexlify(
+ 'CF21AD74E59A6111BE1D8C021E65B891C2A211167ABB8C5E079E09E2C8A8339C'
+ )
+
+ def __init__(self, hsid: int, name: str, data):
+ super().__init__(hsid=hsid, name=name, data=data)
+ d = data
+ d, self._version = _get_int(d, 2)
+ d, self._random = _get_field(d, 32)
+ if self._random == self.HELLO_RETRY_RANDOM:
+ self.name = 'HelloRetryRequest'
+ hsid = 6
+ d, self._session_id = _get_len_field(d, 1)
+ d, cipher = _get_int(d, 2)
+ self._cipher = TlsCipherSuites.name(cipher)
+ d, self._compression = _get_int(d, 1)
+ d, edata = _get_len_field(d, 2)
+ self._extensions = TlsExtensions.from_data(hsid, edata)
+
+ def to_json(self):
+ jdata = super().to_json()
+ jdata['version'] = f'0x{self._version:0x}'
+ jdata['random'] = f'{binascii.hexlify(self._random).decode()}'
+ jdata['session_id'] = binascii.hexlify(self._session_id).decode()
+ jdata['cipher'] = self._cipher
+ jdata['compression'] = int(self._compression)
+ jdata['extensions'] = [ext.to_json() for ext in self._extensions]
+ return jdata
+
+ def to_text(self, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ return super().to_text(indent=indent) + '\n'\
+ f'{ind} version: 0x{self._version:0x}\n'\
+ f'{ind} random: {binascii.hexlify(self._random).decode()}\n' \
+ f'{ind} session_id: {binascii.hexlify(self._session_id).decode()}\n' \
+ f'{ind} cipher: {self._cipher}\n'\
+ f'{ind} compression: {int(self._compression)}\n'\
+ f'{ind} extensions: \n' + '\n'.join(
+ [ext.to_text(indent=indent+4) for ext in self._extensions])
+
+
+class EncryptedExtensions(HSRecord):
+
+ def __init__(self, hsid: int, name: str, data):
+ super().__init__(hsid=hsid, name=name, data=data)
+ d = data
+ d, edata = _get_len_field(d, 2)
+ self._extensions = TlsExtensions.from_data(hsid, edata)
+
+ def to_json(self):
+ jdata = super().to_json()
+ jdata['extensions'] = [ext.to_json() for ext in self._extensions]
+ return jdata
+
+ def to_text(self, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ return super().to_text(indent=indent) + '\n'\
+ f'{ind} extensions: \n' + '\n'.join(
+ [ext.to_text(indent=indent+4) for ext in self._extensions])
+
+
+class CertificateRequest(HSRecord):
+
+ def __init__(self, hsid: int, name: str, data):
+ super().__init__(hsid=hsid, name=name, data=data)
+ d = data
+ d, self._context = _get_int(d, 1)
+ d, edata = _get_len_field(d, 2)
+ self._extensions = TlsExtensions.from_data(hsid, edata)
+
+ def to_json(self):
+ jdata = super().to_json()
+ jdata['context'] = self._context
+ jdata['extensions'] = [ext.to_json() for ext in self._extensions]
+ return jdata
+
+ def to_text(self, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ return super().to_text(indent=indent) + '\n'\
+ f'{ind} context: {self._context}\n'\
+ f'{ind} extensions: \n' + '\n'.join(
+ [ext.to_text(indent=indent+4) for ext in self._extensions])
+
+
+class Certificate(HSRecord):
+
+ def __init__(self, hsid: int, name: str, data):
+ super().__init__(hsid=hsid, name=name, data=data)
+ d = data
+ d, self._context = _get_int(d, 1)
+ d, clist = _get_len_field(d, 3)
+ self._cert_entries = []
+ while len(clist) > 0:
+ clist, cert_data = _get_len_field(clist, 3)
+ clist, cert_exts = _get_len_field(clist, 2)
+ exts = TlsExtensions.from_data(hsid, cert_exts)
+ self._cert_entries.append({
+ 'cert': binascii.hexlify(cert_data).decode(),
+ 'extensions': exts,
+ })
+
+ def to_json(self):
+ jdata = super().to_json()
+ jdata['context'] = self._context
+ jdata['certificate_list'] = [{
+ 'cert': e['cert'],
+ 'extensions': [x.to_json() for x in e['extensions']],
+ } for e in self._cert_entries]
+ return jdata
+
+ def _enxtry_text(self, e, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ return f'{ind} cert: {e["cert"]}\n'\
+ f'{ind} extensions: \n' + '\n'.join(
+ [x.to_text(indent=indent + 4) for x in e['extensions']])
+
+ def to_text(self, indent: int = 0):
+ ind = ' ' * (indent + 2)
+ return super().to_text(indent=indent) + '\n'\
+ f'{ind} context: {self._context}\n'\
+ f'{ind} certificate_list: \n' + '\n'.join(
+ [self._enxtry_text(e, indent+4) for e in self._cert_entries])
+
+
+class SessionTicket(HSRecord):
+
+ def __init__(self, hsid: int, name: str, data):
+ super().__init__(hsid=hsid, name=name, data=data)
+ d = data
+ d, self._lifetime = _get_int(d, 4)
+ d, self._age = _get_int(d, 4)
+ d, self._nonce = _get_len_field(d, 1)
+ d, self._ticket = _get_len_field(d, 2)
+ d, edata = _get_len_field(d, 2)
+ self._extensions = TlsExtensions.from_data(hsid, edata)
+
+ def to_json(self):
+ jdata = super().to_json()
+ jdata['lifetime'] = self._lifetime
+ jdata['age'] = self._age
+ jdata['nonce'] = binascii.hexlify(self._nonce).decode()
+ jdata['ticket'] = binascii.hexlify(self._ticket).decode()
+ jdata['extensions'] = [ext.to_json() for ext in self._extensions]
+ return jdata
+
+
+class HSIterator(Iterator):
+
+ def __init__(self, recs):
+ self._recs = recs
+ self._index = 0
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ try:
+ result = self._recs[self._index]
+ except IndexError:
+ raise StopIteration
+ self._index += 1
+ return result
+
+
+class HandShake:
+ REC_TYPES = [
+ (1, 'ClientHello', ClientHello),
+ (2, 'ServerHello', ServerHello),
+ (3, 'HelloVerifyRequest', HSRecord),
+ (4, 'SessionTicket', SessionTicket),
+ (5, 'EndOfEarlyData', HSRecord),
+ (6, 'HelloRetryRequest', ServerHello),
+ (8, 'EncryptedExtensions', EncryptedExtensions),
+ (11, 'Certificate', Certificate),
+ (12, 'ServerKeyExchange ', HSRecord),
+ (13, 'CertificateRequest', CertificateRequest),
+ (14, 'ServerHelloDone', HSRecord),
+ (15, 'CertificateVerify', HSRecord),
+ (16, 'ClientKeyExchange', HSRecord),
+ (20, 'Finished', HSRecord),
+ (22, 'CertificateStatus', HSRecord),
+ (24, 'KeyUpdate', HSRecord),
+ ]
+ RT_NAME_BY_ID = {}
+ RT_CLS_BY_ID = {}
+
+ @classmethod
+ def _parse_rec(cls, data):
+ d, hsid = _get_int(data, 1)
+ if hsid not in cls.RT_CLS_BY_ID:
+ raise ParseError(f'unknown type {hsid}')
+ d, rec_len = _get_int(d, 3)
+ if rec_len > len(d):
+ # incomplete, need more data
+ return data, None
+ d, rec_data = _get_field(d, rec_len)
+ if hsid in cls.RT_CLS_BY_ID:
+ name = cls.RT_NAME_BY_ID[hsid]
+ rcls = cls.RT_CLS_BY_ID[hsid]
+ else:
+ name = f'CryptoRecord(0x{hsid:0x})'
+ rcls = HSRecord
+ return d, rcls(hsid=hsid, name=name, data=rec_data)
+
+ @classmethod
+ def _parse(cls, source, strict=False, verbose: int = 0):
+ d = b''
+ hsid = 0
+ hsrecs = []
+ if verbose > 0:
+ log.debug(f'scanning for handshake records')
+ blocks = [d for d in source]
+ while len(blocks) > 0:
+ try:
+ total_data = b''.join(blocks)
+ remain, r = cls._parse_rec(total_data)
+ if r is None:
+ # if we could not recognize a record, skip the first
+ # data block and try again
+ blocks = blocks[1:]
+ continue
+ hsrecs.append(r)
+ cons_len = len(total_data) - len(remain)
+ while cons_len > 0 and len(blocks) > 0:
+ if cons_len >= len(blocks[0]):
+ cons_len -= len(blocks[0])
+ blocks = blocks[1:]
+ else:
+ blocks[0] = blocks[0][cons_len:]
+ cons_len = 0
+ if verbose > 2:
+ log.debug(f'added record: {r.to_text()}')
+ except ParseError as err:
+ # if we could not recognize a record, skip the first
+ # data block and try again
+ blocks = blocks[1:]
+ if len(blocks) > 0 and strict:
+ raise Exception(f'possibly incomplete handshake record '
+ f'id={hsid}, from raw={blocks}\n')
+ return hsrecs
+
+
+
+ @classmethod
+ def init(cls):
+ for (hsid, name, rcls) in cls.REC_TYPES:
+ cls.RT_NAME_BY_ID[hsid] = name
+ cls.RT_CLS_BY_ID[hsid] = rcls
+
+ def __init__(self, source: Iterable[bytes], strict: bool = False,
+ verbose: int = 0):
+ self._source = source
+ self._strict = strict
+ self._verbose = verbose
+
+ def __iter__(self):
+ return HSIterator(recs=self._parse(self._source, strict=self._strict,
+ verbose=self._verbose))
+
+
+HandShake.init()