diff options
Diffstat (limited to '')
26 files changed, 3372 insertions, 0 deletions
diff --git a/tests/pytests/README.rst b/tests/pytests/README.rst new file mode 100644 index 0000000..42d85d0 --- /dev/null +++ b/tests/pytests/README.rst @@ -0,0 +1,56 @@ +.. SPDX-License-Identifier: GPL-3.0-or-later + +Python client tests for kresd +============================= + +The tests run `/usr/bin/env kresd` (can be modified with `$PATH`) with custom config +and execute client-side testing, such as TCP / TLS connection management. + +Requirements +------------ + +- pip3 install -r requirements.txt + +Executing tests +--------------- + +Tests can be executed with the pytest framework. + +.. code-block:: bash + + $ pytest-3 # sequential, all tests (with exception of few special tests) + $ pytest-3 test_conn_mgmt.py::test_ignore_garbage # specific test only + $ pytest-3 --html pytests.html --self-contained-html # html report + +It's highly recommended to run these tests in parallel, since lot of them +wait for kresd timeout. This can be done with `python-xdist`: + +.. code-block:: bash + + $ pytest-3 -n 24 # parallel with 24 jobs + +Each test spawns an independent kresd instance, so test failures shouldn't affect +each other. + +Some tests are omitted from automatic test collection by default, due to their +resource contraints. These typicially have to be executed separately by providing +the path to test file directly. + +.. code-block:: bash + + $ pytest-3 conn_flood.py + +Note: some tests may fail without an internet connection. + +Developer notes +--------------- + +Typically, each test requires a setup of kresd, and a connected socket to run tests on. +The framework provides a few useful pytest fixtures to simplify this process: + +- `kresd_sock` provides a connected socket to a test-specific, running kresd instance. + It expands to 4 values (tests) - IPv4 TCP, IPv6 TCP, IPv4 TLS, IPv6 TLS sockets +- `make_kresd_sock` is similar to `kresd_sock`, except it's a factory function that + produces a new connected socket (of the same type) on each call +- `kresd`, `kresd_tt` are all Kresd instances, already running + and initialized with config (with no / valid TLS certificates) diff --git a/tests/pytests/certs/tt-certgen-expired.sh b/tests/pytests/certs/tt-certgen-expired.sh new file mode 100755 index 0000000..23a6978 --- /dev/null +++ b/tests/pytests/certs/tt-certgen-expired.sh @@ -0,0 +1,19 @@ +# !/bin/bash +# SPDX-License-Identifier: GPL-3.0-or-later + +if [ ! -d ./demoCA ]; then + mkdir ./demoCA +fi +if [ ! -d ./demoCA/newcerts ]; then + mkdir ./demoCA/newcerts +fi +touch ./demoCA/index.txt +touch ./demoCA/index.txt.attr +if [ ! -f ./demoCA/serial ]; then + echo 01 > ./demoCA/serial +fi + +openssl genrsa -out tt-expired.key.pem 2048 +openssl req -config tt.conf -new -key tt-expired.key.pem -out tt-expired.csr.pem +openssl ca -config tt.conf -selfsign -keyfile tt-expired.key.pem -out tt-expired.cert.pem -in tt-expired.csr.pem -startdate 19700101000000Z -enddate 19700101000000Z + diff --git a/tests/pytests/certs/tt-certgen.sh b/tests/pytests/certs/tt-certgen.sh new file mode 100755 index 0000000..9414475 --- /dev/null +++ b/tests/pytests/certs/tt-certgen.sh @@ -0,0 +1,5 @@ +# !/bin/sh +# SPDX-License-Identifier: GPL-3.0-or-later + +openssl req -config tt.conf -new -x509 -newkey rsa:2048 -nodes -keyout tt.key.pem -sha256 -out tt.cert.pem -days 20000 + diff --git a/tests/pytests/certs/tt-expired.cert.pem b/tests/pytests/certs/tt-expired.cert.pem new file mode 100644 index 0000000..c9f8c09 --- /dev/null +++ b/tests/pytests/certs/tt-expired.cert.pem @@ -0,0 +1,80 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: 1 (0x1) + Signature Algorithm: sha256WithRSAEncryption + Issuer: C=CZ, ST=PRAGUE, CN=transport-test-server.com + Validity + Not Before: Jan 1 00:00:00 1970 GMT + Not After : Jan 1 00:00:00 1970 GMT + Subject: C=CZ, ST=PRAGUE, CN=transport-test-server.com + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (2048 bit) + Modulus: + 00:bf:6b:1a:11:47:01:ac:eb:5c:2d:cf:ce:6a:a4: + 00:ce:2f:d1:25:03:5f:06:38:02:92:24:18:92:2a: + 69:19:b2:2b:a3:4f:f7:79:de:35:c3:f5:72:37:83: + 44:93:f9:76:fc:89:29:32:9c:0d:4b:95:7d:d1:5d: + 40:e9:ba:49:50:7d:c6:0a:c8:1e:e7:90:1e:37:7c: + 0b:23:a3:e3:bc:c9:53:81:de:d6:5f:cb:b2:3d:36: + ac:59:b0:33:91:8f:0c:5f:10:20:70:bf:a3:22:b3: + 98:ac:d4:7a:ea:67:b8:b1:8c:cf:e5:fe:8f:a0:a5: + 02:ad:6d:ce:f1:62:ab:dc:5d:96:9c:4f:95:47:d5: + 82:b7:b3:e3:87:4c:8d:38:85:2a:24:9d:7f:c7:a4: + 0e:bd:8a:2d:6b:d2:d4:e8:78:62:1b:aa:25:5f:5a: + 64:e5:76:23:ae:11:03:9a:5c:ed:a2:ba:51:ec:b1: + f3:ae:ba:5c:eb:dd:49:63:ca:c7:af:0c:16:1d:94: + 95:3a:ce:2c:8f:e2:94:7f:1f:a1:76:e2:9f:d1:41: + 31:f0:68:e5:ae:df:d0:75:a0:34:f5:25:93:85:b3: + 25:50:42:6c:00:c0:fe:3b:e0:fb:00:de:75:33:86: + 6a:21:35:14:9d:7f:4a:af:f7:15:f2:d7:bb:2f:de: + df:ab + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: + CA:FALSE + Netscape Comment: + OpenSSL Generated Certificate + X509v3 Subject Key Identifier: + B3:42:0A:9A:00:19:CB:CB:24:A0:02:45:1E:8A:B0:54:CB:9F:55:FE + X509v3 Authority Key Identifier: + keyid:B3:42:0A:9A:00:19:CB:CB:24:A0:02:45:1E:8A:B0:54:CB:9F:55:FE + + Signature Algorithm: sha256WithRSAEncryption + 32:9a:05:e3:6f:ae:ee:b1:a2:12:0a:9f:0a:e7:78:26:df:90: + fb:84:60:ae:13:fc:ff:fd:42:84:23:14:c3:2e:e2:a9:df:4b: + 5c:2f:5b:0e:3d:f9:5a:56:50:13:bc:89:1a:08:70:dd:6c:6c: + e8:ae:cf:22:39:92:f2:3b:40:03:8f:4e:bc:54:88:6b:fd:8c: + b6:eb:30:90:21:db:fc:4e:5c:7e:12:75:e2:52:76:df:19:0f: + 30:49:1e:15:bc:ba:6a:e6:f7:af:93:ad:e4:36:da:47:47:a6: + 88:b0:ae:46:1e:91:e1:d6:b1:5e:a4:f0:68:02:81:57:86:5d: + 17:d1:6c:7e:7a:9f:5e:0d:fc:10:e7:7a:1a:b5:f9:4b:1d:78: + a4:9a:9d:d7:c2:64:c3:52:28:7f:a1:b7:25:d7:13:3f:09:7f: + f2:fd:dd:c6:91:eb:9b:51:80:e2:36:cb:9f:5b:4e:47:eb:77: + d3:cc:8b:18:b5:0b:97:a2:53:8e:fb:9b:94:7d:57:21:32:c6: + f3:67:93:a4:9b:eb:46:b7:cd:08:43:99:dd:c1:c3:51:b9:19: + ef:92:77:1c:84:67:80:67:95:ba:00:75:3d:7b:8b:ff:24:30: + f1:fa:6d:da:31:9d:cf:06:da:5d:04:07:14:45:8c:6b:e7:21: + 31:ec:7b:23 +-----BEGIN CERTIFICATE----- +MIIDfjCCAmagAwIBAgIBATANBgkqhkiG9w0BAQsFADBCMQswCQYDVQQGEwJDWjEP +MA0GA1UECAwGUFJBR1VFMSIwIAYDVQQDDBl0cmFuc3BvcnQtdGVzdC1zZXJ2ZXIu +Y29tMCIYDzE5NzAwMTAxMDAwMDAwWhgPMTk3MDAxMDEwMDAwMDBaMEIxCzAJBgNV +BAYTAkNaMQ8wDQYDVQQIDAZQUkFHVUUxIjAgBgNVBAMMGXRyYW5zcG9ydC10ZXN0 +LXNlcnZlci5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC/axoR +RwGs61wtz85qpADOL9ElA18GOAKSJBiSKmkZsiujT/d53jXD9XI3g0ST+Xb8iSky +nA1LlX3RXUDpuklQfcYKyB7nkB43fAsjo+O8yVOB3tZfy7I9NqxZsDORjwxfECBw +v6Mis5is1HrqZ7ixjM/l/o+gpQKtbc7xYqvcXZacT5VH1YK3s+OHTI04hSoknX/H +pA69ii1r0tToeGIbqiVfWmTldiOuEQOaXO2iulHssfOuulzr3UljysevDBYdlJU6 +ziyP4pR/H6F24p/RQTHwaOWu39B1oDT1JZOFsyVQQmwAwP474PsA3nUzhmohNRSd +f0qv9xXy17sv3t+rAgMBAAGjezB5MAkGA1UdEwQCMAAwLAYJYIZIAYb4QgENBB8W +HU9wZW5TU0wgR2VuZXJhdGVkIENlcnRpZmljYXRlMB0GA1UdDgQWBBSzQgqaABnL +yySgAkUeirBUy59V/jAfBgNVHSMEGDAWgBSzQgqaABnLyySgAkUeirBUy59V/jAN +BgkqhkiG9w0BAQsFAAOCAQEAMpoF42+u7rGiEgqfCud4Jt+Q+4RgrhP8//1ChCMU +wy7iqd9LXC9bDj35WlZQE7yJGghw3Wxs6K7PIjmS8jtAA49OvFSIa/2MtuswkCHb +/E5cfhJ14lJ23xkPMEkeFby6aub3r5Ot5DbaR0emiLCuRh6R4daxXqTwaAKBV4Zd +F9FsfnqfXg38EOd6GrX5Sx14pJqd18Jkw1Iof6G3JdcTPwl/8v3dxpHrm1GA4jbL +n1tOR+t308yLGLULl6JTjvublH1XITLG82eTpJvrRrfNCEOZ3cHDUbkZ75J3HIRn +gGeVugB1PXuL/yQw8fpt2jGdzwbaXQQHFEWMa+chMex7Iw== +-----END CERTIFICATE----- diff --git a/tests/pytests/certs/tt-expired.key.pem b/tests/pytests/certs/tt-expired.key.pem new file mode 100644 index 0000000..ca2988c --- /dev/null +++ b/tests/pytests/certs/tt-expired.key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAv2saEUcBrOtcLc/OaqQAzi/RJQNfBjgCkiQYkippGbIro0/3 +ed41w/VyN4NEk/l2/IkpMpwNS5V90V1A6bpJUH3GCsge55AeN3wLI6PjvMlTgd7W +X8uyPTasWbAzkY8MXxAgcL+jIrOYrNR66me4sYzP5f6PoKUCrW3O8WKr3F2WnE+V +R9WCt7Pjh0yNOIUqJJ1/x6QOvYota9LU6HhiG6olX1pk5XYjrhEDmlztorpR7LHz +rrpc691JY8rHrwwWHZSVOs4sj+KUfx+hduKf0UEx8Gjlrt/QdaA09SWThbMlUEJs +AMD+O+D7AN51M4ZqITUUnX9Kr/cV8te7L97fqwIDAQABAoIBAEA4ytIpJKLDhHXK +VtLom2ySFnV4oBUSDarCeYvwtrpsUL/GQJ2etCM+4kdFv2h2NjmcOzpDqSJG0aPA +ydqhKZ/b0uojIltGuxyafZJDllDsqxvTi9EwImjvQvwEZgjcGaZ7Xqb1ZOJrpzm1 +QFgM3KaVO9tKgR3Avxk40kmidU7FctFi5IELwnH/RR1OHvJbxOE4+i0LlDx0QzhX +QHtnvHLqLLdqsFk8KvuVuVj1FwqJ6cSL0JrAdt7dnGmXBo4PDqT8Hj0AjM+CcNrV +1D6Li9xr4y55EZUK2qU/FVDC3LqlYQy5mBfasJAXPQG4RgSVFxJ929HC7gi8vMCO +UMeLniECgYEA6gBoRwzQ5pJUXfZGW41lJt08utfycGZm7VrA81r0x0F+DcuZ2t6J +kB9Wnp/MNpB4DJLbl7oM2OlFOO3cw0n3VaFpNMPHVHzNbyi1hp94AIIeDz/sxfUI +Lx7ynAQSPPQzDRfVJesT8waBdweA71TBOlrFQ2Cp7O4Qf+p0akQSv3cCgYEA0Wnd +1Gbierv2m6Jnblg+brTMQwbRsOAM2n0V4Gd2kRaLSYd23ebshvx8xTWipRlrb5vP +UEh+LkfuscqaJDCrikasht9z5FJtfIzHKgTrLSoR3MJRjrnuLJWTQUwSqzd0UNN6 +HigV6p+CqesNnELErak53IMfmkHAhTSkII8R9m0CgYBRY+DhTaDfgegcYouoTm7v +bKYx6uillciZKCbSvkFDiREaJUYXba31ViEfvT8ff3JyFSaSCKFtVP3BxmIx/ukr +fKAGPU54oYwm7Mbu00q/CoMAFOD7HbZCBYanI3dggiO7mx2FOdXPguTHDPIYzKcE +8AuK2vVftpJAm8DwMUtAEwKBgH/eRc5ZGDdbKGS10LQm+9A7Y3IV6to2pIKQ2FfS +tSo4espmBeXPCGQQLdt5OZvYHqril77s1OdLkutKy74HXecr6lLchHZJAoOHrmDw +6e0FAC0tFgGxdEYS+vxnCAs17DciOjHJxkAiL/WzCfd9KXzklOkZw6U8OuLbVtBu +q8gtAoGAbl03XZm+SHrO7XjHK/Fe5YD13cOirg48htvjbpqEDZNQr3l0eVnEj074 +IopDa/wUFlaaqPZ/DVFctqSocyskWIP4u9HfmsNBHjK5zQlge7B1fNVao++YKund +qnVnXjWQuF2aL8k2geFxdSKmHTF4/N1qEyeyR+tMaFpGfMZOuM8= +-----END RSA PRIVATE KEY----- diff --git a/tests/pytests/certs/tt.cert.pem b/tests/pytests/certs/tt.cert.pem new file mode 100644 index 0000000..2ea4898 --- /dev/null +++ b/tests/pytests/certs/tt.cert.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDkTCCAnmgAwIBAgIJAP/KybHquuyUMA0GCSqGSIb3DQEBCwUAMEIxCzAJBgNV +BAYTAkNaMQ8wDQYDVQQIDAZQUkFHVUUxIjAgBgNVBAMMGXRyYW5zcG9ydC10ZXN0 +LXNlcnZlci5jb20wIBcNMTgwNzE4MTAwODIzWhgPMjA3MzA0MjAxMDA4MjNaMEIx +CzAJBgNVBAYTAkNaMQ8wDQYDVQQIDAZQUkFHVUUxIjAgBgNVBAMMGXRyYW5zcG9y +dC10ZXN0LXNlcnZlci5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB +AQDRAQDAX6+lFKurvm7fgQqm8WyYzT/wxfPJjsVQGe87OlH1KFzVfzYzgEt0RMlM +eZgipREBZB2zK+WFM5RBHWYAwlI5PKt7EAGn8q1Zm4z+M9Uom3/Hy3bZ9q+AJwjk +odpHYuFyWJqHIQBqaQ3SFyJwdZ/GsuzEUfWuIl74oyyMAeykTKFGdaVuIlLC3fKm +8UCnfk99i/LEXUwRcmOV0uaG7deN5ITDDCFdb615yVjLkMhGY/jHK7uuxATOopEk +4vThQ1aQjSkHwluaqFUW6Zl4QF8WOAufoWQPFZ8XxmUYEIG/sMvLv6dol7ltjEbC +bfyzlS+9Qbnq6MfhTZF/4jAPAgMBAAGjgYcwgYQwHQYDVR0OBBYEFBNiUgCiKw4b +CFNKaEkqhkNSer7wMB8GA1UdIwQYMBaAFBNiUgCiKw4bCFNKaEkqhkNSer7wMA8G +A1UdEwEB/wQFMAMBAf8wJAYDVR0RBB0wG4IZdHJhbnNwb3J0LXRlc3Qtc2VydmVy +LmNvbTALBgNVHQ8EBAMCAaYwDQYJKoZIhvcNAQELBQADggEBACz1ZQ8XkGobhTcA +hkSTSw0ko6qwVuJJD5ue3SUcWLATsskohTJmN6bde3IMDRyQvLJAlMdG2p1qMbtA +OTbnQJTT7oDLaW8w2D+eO5oWTJvxLpl6TxbIfJN/8ITB1ltOCxTU9cVNbd2eh8sj +l3R4etg9djYRrqtNxCQZOYSwvhHw2MefnwjGVuJEu6JYOn3IE8Jqsh/LI59C87nE +MetZrXlzC6kSAFfRYgQET9RhBobMU9yFR8zGVHDFoxqNQs2lYKPz/3rFPetL2rjT +cFwzxkxDdwn+RNisBc1LMfIg7pvSMFR6sAnpjeRHN0Uoem1jj2qtzjbFENDuyQ4/ +HSi4UcE= +-----END CERTIFICATE----- diff --git a/tests/pytests/certs/tt.conf b/tests/pytests/certs/tt.conf new file mode 100644 index 0000000..f011e5a --- /dev/null +++ b/tests/pytests/certs/tt.conf @@ -0,0 +1,353 @@ +# SPDX-License-Identifier: CC0-1.0 +# OpenSSL example configuration file. +# This is mostly being used for generation of certificate requests. +# + +# This definition stops the following lines choking if HOME isn't +# defined. +HOME = . +RANDFILE = $ENV::HOME/.rnd + +# Extra OBJECT IDENTIFIER info: +#oid_file = $ENV::HOME/.oid +oid_section = new_oids + +# To use this configuration file with the "-extfile" option of the +# "openssl x509" utility, name here the section containing the +# X.509v3 extensions to use: +# extensions = +# (Alternatively, use a configuration file that has only +# X.509v3 extensions in its main [= default] section.) + +[ new_oids ] + +# We can add new OIDs in here for use by 'ca', 'req' and 'ts'. +# Add a simple OID like this: +# testoid1=1.2.3.4 +# Or use config file substitution like this: +# testoid2=${testoid1}.5.6 + +# Policies used by the TSA examples. +tsa_policy1 = 1.2.3.4.1 +tsa_policy2 = 1.2.3.4.5.6 +tsa_policy3 = 1.2.3.4.5.7 + +#################################################################### +[ ca ] +default_ca = CA_default # The default ca section + +#################################################################### +[ CA_default ] + +dir = ./demoCA # Where everything is kept +certs = $dir/certs # Where the issued certs are kept +crl_dir = $dir/crl # Where the issued crl are kept +database = $dir/index.txt # database index file. +#unique_subject = no # Set to 'no' to allow creation of + # several certs with same subject. +new_certs_dir = $dir/newcerts # default place for new certs. + +certificate = $dir/cacert.pem # The CA certificate +serial = $dir/serial # The current serial number +crlnumber = $dir/crlnumber # the current crl number + # must be commented out to leave a V1 CRL +crl = $dir/crl.pem # The current CRL +private_key = $dir/private/cakey.pem# The private key +RANDFILE = $dir/private/.rand # private random number file + +x509_extensions = usr_cert # The extensions to add to the cert + +# Comment out the following two lines for the "traditional" +# (and highly broken) format. +name_opt = ca_default # Subject Name options +cert_opt = ca_default # Certificate field options + +# Extension copying option: use with caution. +copy_extensions = copy + +# Extensions to add to a CRL. Note: Netscape communicator chokes on V2 CRLs +# so this is commented out by default to leave a V1 CRL. +# crlnumber must also be commented out to leave a V1 CRL. +# crl_extensions = crl_ext + +default_days = 365 # how long to certify for +default_crl_days= 30 # how long before next CRL +default_md = default # use public key default MD +preserve = no # keep passed DN ordering + +# A few difference way of specifying how similar the request should look +# For type CA, the listed attributes must be the same, and the optional +# and supplied fields are just that :-) +policy = policy_match + +# For the CA policy +[ policy_match ] +countryName = optional +stateOrProvinceName = optional +organizationName = optional +organizationalUnitName = optional +commonName = supplied +emailAddress = optional + +# For the 'anything' policy +# At this point in time, you must list all acceptable 'object' +# types. +[ policy_anything ] +countryName = optional +stateOrProvinceName = optional +localityName = optional +organizationName = optional +organizationalUnitName = optional +commonName = supplied +emailAddress = optional + +#################################################################### +[ req ] +default_bits = 2048 +default_keyfile = privkey.pem +distinguished_name = req_distinguished_name +attributes = req_attributes +x509_extensions = v3_ca # The extensions to add to the self signed cert + +# Passwords for private keys if not present they will be prompted for +# input_password = secret +# output_password = secret + +# This sets a mask for permitted string types. There are several options. +# default: PrintableString, T61String, BMPString. +# pkix : PrintableString, BMPString (PKIX recommendation before 2004) +# utf8only: only UTF8Strings (PKIX recommendation after 2004). +# nombstr : PrintableString, T61String (no BMPStrings or UTF8Strings). +# MASK:XXXX a literal mask value. +# WARNING: ancient versions of Netscape crash on BMPStrings or UTF8Strings. +string_mask = utf8only + +# req_extensions = v3_req # The extensions to add to a certificate request + +[ req_distinguished_name ] +countryName = Country Name (2 letter code) +countryName_default = CZ +countryName_min = 2 +countryName_max = 2 + +stateOrProvinceName = State or Province Name (full name) +stateOrProvinceName_default = PRAGUE + +localityName = Locality Name (eg, city) + +0.organizationName = Organization Name (eg, company) +0.organizationName_default = + +# we can do this but it is not needed normally :-) +#1.organizationName = Second Organization Name (eg, company) +#1.organizationName_default = World Wide Web Pty Ltd + +organizationalUnitName = Organizational Unit Name (eg, section) +#organizationalUnitName_default = + +commonName = Common Name (e.g. server FQDN or YOUR name) +commonName_max = 64 +commonName_default = transport-test-server.com + +emailAddress = Email Address +emailAddress_max = 64 + +# SET-ex3 = SET extension number 3 + +[ req_attributes ] +challengePassword = A challenge password +challengePassword_min = 4 +challengePassword_max = 20 + +unstructuredName = An optional company name + +[ usr_cert ] + +# These extensions are added when 'ca' signs a request. + +# This goes against PKIX guidelines but some CAs do it and some software +# requires this to avoid interpreting an end user certificate as a CA. + +basicConstraints=CA:FALSE + +# Here are some examples of the usage of nsCertType. If it is omitted +# the certificate can be used for anything *except* object signing. + +# This is OK for an SSL server. +# nsCertType = server + +# For an object signing certificate this would be used. +# nsCertType = objsign + +# For normal client use this is typical +# nsCertType = client, email + +# and for everything including object signing: +# nsCertType = client, email, objsign + +# This is typical in keyUsage for a client certificate. +# keyUsage = nonRepudiation, digitalSignature, keyEncipherment + +# This will be displayed in Netscape's comment listbox. +nsComment = "OpenSSL Generated Certificate" + +# PKIX recommendations harmless if included in all certificates. +subjectKeyIdentifier=hash +authorityKeyIdentifier=keyid,issuer + +# This stuff is for subjectAltName and issuerAltname. +# Import the email address. +# subjectAltName=email:copy +# An alternative to produce certificates that aren't +# deprecated according to PKIX. +# subjectAltName=email:move + +# Copy subject details +# issuerAltName=issuer:copy + +#nsCaRevocationUrl = http://www.domain.dom/ca-crl.pem +#nsBaseUrl +#nsRevocationUrl +#nsRenewalUrl +#nsCaPolicyUrl +#nsSslServerName + +# This is required for TSA certificates. +# extendedKeyUsage = critical,timeStamping + +[ v3_req ] + +# Extensions to add to a certificate request + +basicConstraints = CA:FALSE +keyUsage = nonRepudiation, digitalSignature, keyEncipherment + +[ v3_ca ] + + +# Extensions for a typical CA + + +# PKIX recommendation. + +subjectKeyIdentifier=hash + +authorityKeyIdentifier=keyid:always,issuer + +basicConstraints = critical,CA:true + +subjectAltName = @alternate_names + +# Key usage: this is typical for a CA certificate. However since it will +# prevent it being used as an test self-signed certificate it is best +# left out by default. +keyUsage = digitalSignature, keyEncipherment, cRLSign, keyCertSign + +# Some might want this also +# nsCertType = sslCA, emailCA + +# Include email address in subject alt name: another PKIX recommendation +# subjectAltName=email:copy +# Copy issuer details +# issuerAltName=issuer:copy + +# DER hex encoding of an extension: beware experts only! +# obj=DER:02:03 +# Where 'obj' is a standard or added object +# You can even override a supported extension: +# basicConstraints= critical, DER:30:03:01:01:FF + +[ crl_ext ] + +# CRL extensions. +# Only issuerAltName and authorityKeyIdentifier make any sense in a CRL. + +# issuerAltName=issuer:copy +authorityKeyIdentifier=keyid:always + +[ proxy_cert_ext ] +# These extensions should be added when creating a proxy certificate + +# This goes against PKIX guidelines but some CAs do it and some software +# requires this to avoid interpreting an end user certificate as a CA. + +basicConstraints=CA:FALSE + +# Here are some examples of the usage of nsCertType. If it is omitted +# the certificate can be used for anything *except* object signing. + +# This is OK for an SSL server. +# nsCertType = server + +# For an object signing certificate this would be used. +# nsCertType = objsign + +# For normal client use this is typical +# nsCertType = client, email + +# and for everything including object signing: +# nsCertType = client, email, objsign + +# This is typical in keyUsage for a client certificate. +# keyUsage = nonRepudiation, digitalSignature, keyEncipherment + +# This will be displayed in Netscape's comment listbox. +nsComment = "OpenSSL Generated Certificate" + +# PKIX recommendations harmless if included in all certificates. +subjectKeyIdentifier=hash +authorityKeyIdentifier=keyid,issuer + +# This stuff is for subjectAltName and issuerAltname. +# Import the email address. +# subjectAltName=email:copy +# An alternative to produce certificates that aren't +# deprecated according to PKIX. +# subjectAltName=email:move + +# Copy subject details +# issuerAltName=issuer:copy + +#nsCaRevocationUrl = http://www.domain.dom/ca-crl.pem +#nsBaseUrl +#nsRevocationUrl +#nsRenewalUrl +#nsCaPolicyUrl +#nsSslServerName + +# This really needs to be in place for it to be a proxy certificate. +proxyCertInfo=critical,language:id-ppl-anyLanguage,pathlen:3,policy:foo + +#################################################################### +[ tsa ] + +default_tsa = tsa_config1 # the default TSA section + +[ tsa_config1 ] + +# These are used by the TSA reply generation only. +dir = ./demoCA # TSA root directory +serial = $dir/tsaserial # The current serial number (mandatory) +crypto_device = builtin # OpenSSL engine to use for signing +signer_cert = $dir/tsacert.pem # The TSA signing certificate + # (optional) +certs = $dir/cacert.pem # Certificate chain to include in reply + # (optional) +signer_key = $dir/private/tsakey.pem # The TSA private key (optional) +signer_digest = sha256 # Signing digest to use. (Optional) +default_policy = tsa_policy1 # Policy if request did not specify it + # (optional) +other_policies = tsa_policy2, tsa_policy3 # acceptable policies (optional) +digests = sha1, sha256, sha384, sha512 # Acceptable message digests (mandatory) +accuracy = secs:1, millisecs:500, microsecs:100 # (optional) +clock_precision_digits = 0 # number of digits after dot. (optional) +ordering = yes # Is ordering defined for timestamps? + # (optional, default: no) +tsa_name = yes # Must the TSA name be included in the reply? + # (optional, default: no) +ess_cert_id_chain = no # Must the ESS cert id chain be included? + # (optional, default: no) + +[ alternate_names ] + +DNS.1 = transport-test-server.com diff --git a/tests/pytests/certs/tt.key.pem b/tests/pytests/certs/tt.key.pem new file mode 100644 index 0000000..1974be7 --- /dev/null +++ b/tests/pytests/certs/tt.key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDRAQDAX6+lFKur +vm7fgQqm8WyYzT/wxfPJjsVQGe87OlH1KFzVfzYzgEt0RMlMeZgipREBZB2zK+WF +M5RBHWYAwlI5PKt7EAGn8q1Zm4z+M9Uom3/Hy3bZ9q+AJwjkodpHYuFyWJqHIQBq +aQ3SFyJwdZ/GsuzEUfWuIl74oyyMAeykTKFGdaVuIlLC3fKm8UCnfk99i/LEXUwR +cmOV0uaG7deN5ITDDCFdb615yVjLkMhGY/jHK7uuxATOopEk4vThQ1aQjSkHwlua +qFUW6Zl4QF8WOAufoWQPFZ8XxmUYEIG/sMvLv6dol7ltjEbCbfyzlS+9Qbnq6Mfh +TZF/4jAPAgMBAAECggEBALSs10d18FMW0WjAUPxpgxnaLnTRSesMVLjy8ONT6Bkd +S2hRIh91vxc6WwABzrqLita4N0EqmPoggmNpuUmo7lrNoWLVbbAOoD/da7nA3FuL +10MpWYcP/ohh1klEdU2gFSAM/LNqoPsbrk5OzqHFWgI5zItqdX8pEucb01nBRWsp +VMY2vzVuFB2jweZQ5+LCpfSMcRIzlxQa9CG4Peu6YW1Z4b3aUcS63/829JN/ZOGd +uoRqR+gP71yNIt6i7wA5cot5FRmzlFEGhb1XzBOB1FFHOiknOZzbBtDsGUUmVtfA +6mXcTumhdHbC0bXnHei/s2s9X0EeyQFYPkoS4NUQ2dECgYEA/7lhgn89K8rpUPnS +eccTpKVPWp8luQei98Hi/F94kwP32l7Zl7Bmu2nltUoB1GBRXoXY6KzTphmT6ioA +8joLCKIii5/nOdZAdHbIN2tkXS56h524q5I2jKogjfRrpCaAJE8x99f8L9uTBfZb +/7BBQDHai1/S6LcpIRf/4g1/xBsCgYEA0Tq4V5hR9mGDUFir1FDGhA3ijDkIE/sO +3QGTU7W90BL27te98FuQtWOPqfd1fi26WypNpNQUZb3V4x5tmDcpWscfj6I10432 +4zECPlDgaevucJjj245U7WjUhdAvlRy6K8H/8MgRBAjw9h8dwIGIx9gmOqKdA+/h +ve3xyjKQex0CgYAz0XzQ1LewiA1/OyBLTOvOETFjS5x5QfLkAYXdXfswzz0KIu40 +rqoij/LcKYL1Zg8W+Ehb3amFnuk6KgjHDLvvo+scH+ra7W9iKi+oCzrrJt/tWyhw +m9Ax8Mdn/H9TY/nTYbjeYAXaLMQ+EQ3TYgPW3kNKusAiJ/tNmW9gfxvEwQKBgGSJ +Rbj5fTDZjGKYKQDdS3Z6wYhFg0culObHcgaARtPruPHtgtwy82blj0vJl5Bo4qoZ +urNgIOj+ff8jSOAiaWGwWs8Gz7x289IZY42UCTF8Z9d878g5LT/i5nPiJGsPIboS +/yuwxtRcg4SQURiGZbY5e60jJDWXF67O3icdguVVAoGARXLufXvZ/9Xf1DmFFxjq +PJMCa1sfofqjB4KqYbt17vFtTsddCiyqsbpx36oY6nIdm9yUiGo10YaSEJtDEGLS +L3TPZ4s8M8dcjOfj8Kk75pKbJ7NY4qA64dtbxcZbrFp3/mGZkDing94y+Zc/aFqa +xQsA/yhmYV9r+FHDL54Cn6I= +-----END PRIVATE KEY----- diff --git a/tests/pytests/conftest.py b/tests/pytests/conftest.py new file mode 100644 index 0000000..4c711f8 --- /dev/null +++ b/tests/pytests/conftest.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +import socket + +import pytest + +from kresd import init_portdir, make_kresd + + +@pytest.fixture +def kresd(tmpdir): + with make_kresd(tmpdir) as kresd: + yield kresd + + +@pytest.fixture +def kresd_silent(tmpdir): + with make_kresd(tmpdir, verbose=False) as kresd: + yield kresd + + +@pytest.fixture +def kresd_tt(tmpdir): + with make_kresd(tmpdir, 'tt') as kresd: + yield kresd + + +@pytest.fixture(params=[ + 'ip_tcp_socket', + 'ip6_tcp_socket', + 'ip_tls_socket', + 'ip6_tls_socket', +]) +def make_kresd_sock(request, kresd): + """Factory function to create sockets of the same kind.""" + sock_func = getattr(kresd, request.param) + + def _make_kresd_sock(): + return sock_func() + + return _make_kresd_sock + + +@pytest.fixture(params=[ + 'ip_tcp_socket', + 'ip6_tcp_socket', + 'ip_tls_socket', + 'ip6_tls_socket', +]) +def make_kresd_silent_sock(request, kresd_silent): + """Factory function to create sockets of the same kind (no verbose).""" + sock_func = getattr(kresd_silent, request.param) + + def _make_kresd_sock(): + return sock_func() + + return _make_kresd_sock + + +@pytest.fixture +def kresd_sock(make_kresd_sock): + return make_kresd_sock() + + +@pytest.fixture(params=[ + socket.AF_INET, + socket.AF_INET6, +]) +def sock_family(request): + return request.param + + +@pytest.fixture(params=[ + True, + False +]) +def single_buffer(request): # whether to send all data in a single buffer + return request.param + + +@pytest.fixture(params=[ + True, + False +]) +def query_before(request): # whether to send an initial query + return request.param + + +@pytest.mark.optionalhook +def pytest_metadata(metadata): # filter potentially sensitive data from GitLab CI + keys_to_delete = [] + for key in metadata.keys(): + key_lower = key.lower() + if 'password' in key_lower or 'token' in key_lower or \ + key_lower.startswith('ci') or key_lower.startswith('gitlab'): + keys_to_delete.append(key) + for key in keys_to_delete: + del metadata[key] + + +def pytest_sessionstart(session): # pylint: disable=unused-argument + init_portdir() diff --git a/tests/pytests/conn_flood.py b/tests/pytests/conn_flood.py new file mode 100644 index 0000000..037ba7c --- /dev/null +++ b/tests/pytests/conn_flood.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Test opening as many connections as possible. + +Due to resource-intensity of this test, it's filename doesn't contain +"test" on purpose, so it doesn't automatically get picked up by pytest +(to allow easy parallel testing). + +To execute this test, pass the filename of this file to pytest directly. +Also, make sure not to use parallel execution (-n). +""" + +import resource +import time + +import pytest + +from kresd import Kresd +import utils + + +MAX_SOCKETS = 10000 # upper bound of how many connections to open +MAX_ITERATIONS = 10 # number of iterations to run the test + +# we can't use softlimit ifself since kresd already has open sockets, +# so use lesser value +RESERVED_NOFILE = 40 # 40 is empirical value + + +@pytest.mark.parametrize('sock_func_name', [ + 'ip_tcp_socket', + 'ip6_tcp_socket', + 'ip_tls_socket', + 'ip6_tls_socket', +]) +def test_conn_flood(tmpdir, sock_func_name): + def create_sockets(make_sock, nsockets): + sockets = [] + next_ping = time.time() + 4 # less than tcp idle timeout / 2 + while True: + additional_sockets = 0 + while time.time() < next_ping: + nsock_to_init = min(100, nsockets - len(sockets)) + if not nsock_to_init: + return sockets + sockets.extend([make_sock() for _ in range(nsock_to_init)]) + additional_sockets += nsock_to_init + + # large number of connections can take a lot of time to open + # send some valid data to avoid TCP idle timeout for already open sockets + next_ping = time.time() + 4 + for s in sockets: + utils.ping_alive(s) + + # break when no more than 20% additional sockets are created + if additional_sockets / len(sockets) < 0.2: + return sockets + + max_num_of_open_files = resource.getrlimit(resource.RLIMIT_NOFILE)[0] - RESERVED_NOFILE + nsockets = min(max_num_of_open_files, MAX_SOCKETS) + + # create kresd instance with verbose=False + ip = '127.0.0.1' + ip6 = '::1' + with Kresd(tmpdir, ip=ip, ip6=ip6, verbose=False) as kresd: + make_sock = getattr(kresd, sock_func_name) # function for creating sockets + sockets = create_sockets(make_sock, nsockets) + print("\nEstablished {} connections".format(len(sockets))) + + print("Start sending data") + for i in range(MAX_ITERATIONS): + for s in sockets: + utils.ping_alive(s) + print("Iteration {} done...".format(i)) + + print("Close connections") + for s in sockets: + s.close() + + # check in kresd is alive + print("Check upstream is still alive") + sock = make_sock() + utils.ping_alive(sock) + + print("OK!") diff --git a/tests/pytests/kresd.py b/tests/pytests/kresd.py new file mode 100644 index 0000000..f190003 --- /dev/null +++ b/tests/pytests/kresd.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +from collections import namedtuple +from contextlib import ContextDecorator, contextmanager +import os +from pathlib import Path +import random +import re +import shutil +import socket +import subprocess +import time + +import jinja2 + +import utils + + +PYTESTS_DIR = os.path.dirname(os.path.realpath(__file__)) +CERTS_DIR = os.path.join(PYTESTS_DIR, 'certs') +TEMPLATES_DIR = os.path.join(PYTESTS_DIR, 'templates') +KRESD_CONF_TEMPLATE = 'kresd.conf.j2' +KRESD_STARTUP_MSGID = 10005 # special unique ID at the start of the "test" log +KRESD_PORTDIR = '/tmp/pytest-kresd-portdir' +KRESD_TESTPORT_MIN = 1024 +KRESD_TESTPORT_MAX = 32768 # avoid overlap with docker ephemeral port range + + +def init_portdir(): + try: + shutil.rmtree(KRESD_PORTDIR) + except FileNotFoundError: + pass + os.makedirs(KRESD_PORTDIR) + + +def create_file_from_template(template_path, dest, data): + env = jinja2.Environment( + loader=jinja2.FileSystemLoader(TEMPLATES_DIR)) + template = env.get_template(template_path) + rendered_template = template.render(**data) + + with open(dest, "w") as fh: + fh.write(rendered_template) + + +Forward = namedtuple('Forward', ['proto', 'ip', 'port', 'hostname', 'ca_file']) + + +class Kresd(ContextDecorator): + def __init__( + self, workdir, port=None, tls_port=None, ip=None, ip6=None, certname=None, + verbose=True, hints=None, forward=None, policy_test_pass=False): + if ip is None and ip6 is None: + raise ValueError("IPv4 or IPv6 must be specified!") + self.workdir = str(workdir) + self.port = port + self.tls_port = tls_port + self.ip = ip + self.ip6 = ip6 + self.process = None + self.sockets = [] + self.logfile = None + self.verbose = verbose + self.hints = {} if hints is None else hints + self.forward = forward + self.policy_test_pass = policy_test_pass + + if certname: + self.tls_cert_path = os.path.join(CERTS_DIR, certname + '.cert.pem') + self.tls_key_path = os.path.join(CERTS_DIR, certname + '.key.pem') + else: + self.tls_cert_path = None + self.tls_key_path = None + + @property + def config_path(self): + return str(os.path.join(self.workdir, 'kresd.conf')) + + @property + def logfile_path(self): + return str(os.path.join(self.workdir, 'kresd.log')) + + def __enter__(self): + if self.port is not None: + take_port(self.port, self.ip, self.ip6, timeout=120) + else: + self.port = make_port(self.ip, self.ip6) + if self.tls_port is not None: + take_port(self.tls_port, self.ip, self.ip6, timeout=120) + else: + self.tls_port = make_port(self.ip, self.ip6) + + create_file_from_template(KRESD_CONF_TEMPLATE, self.config_path, {'kresd': self}) + self.logfile = open(self.logfile_path, 'w') + self.process = subprocess.Popen( + ['kresd', '-c', self.config_path, '-n', self.workdir], + stdout=self.logfile, env=os.environ.copy()) + + try: + self._wait_for_tcp_port() # wait for ports to be up and responding + if not self.all_ports_alive(msgid=10001): + raise RuntimeError("Kresd not listening on all ports") + + # issue special msgid to mark start of test log + sock = self.ip_tcp_socket() if self.ip else self.ip6_tcp_socket() + assert utils.try_ping_alive(sock, close=True, msgid=KRESD_STARTUP_MSGID) + + # sanity check - kresd didn't crash + self.process.poll() + if self.process.returncode is not None: + raise RuntimeError("Kresd crashed with returncode: {}".format( + self.process.returncode)) + except (RuntimeError, ConnectionError): # pylint: disable=try-except-raise + with open(self.logfile_path) as log: # print log for debugging + print(log.read()) + raise + + return self + + def __exit__(self, exc_type, exc_value, traceback): + try: + if not self.all_ports_alive(msgid=1006): + raise RuntimeError("Kresd crashed") + finally: + for sock in self.sockets: + sock.close() + self.process.terminate() + self.logfile.close() + Path(KRESD_PORTDIR, str(self.port)).unlink() + + def all_ports_alive(self, msgid=10001): + alive = True + if self.ip: + alive &= utils.try_ping_alive(self.ip_tcp_socket(), close=True, msgid=msgid) + alive &= utils.try_ping_alive(self.ip_tls_socket(), close=True, msgid=msgid + 1) + if self.ip6: + alive &= utils.try_ping_alive(self.ip6_tcp_socket(), close=True, msgid=msgid + 2) + alive &= utils.try_ping_alive(self.ip6_tls_socket(), close=True, msgid=msgid + 3) + return alive + + def _wait_for_tcp_port(self, max_delay=10, delay_step=0.2): + family = socket.AF_INET if self.ip else socket.AF_INET6 + i = 0 + end_time = time.time() + max_delay + + while time.time() < end_time: + i += 1 + + # use exponential backoff algorhitm to choose next delay + rand_delay = random.randrange(0, i) + time.sleep(rand_delay * delay_step) + + try: + sock, dest = self.stream_socket(family, timeout=5) + sock.connect(dest) + except ConnectionRefusedError: + continue + else: + try: + return utils.try_ping_alive(sock, close=True, msgid=10000) + except socket.timeout: + continue + finally: + sock.close() + raise RuntimeError("Kresd didn't start in time {}".format(dest)) + + def socket_dest(self, family, tls=False): + port = self.tls_port if tls else self.port + if family == socket.AF_INET: + return self.ip, port + elif family == socket.AF_INET6: + return self.ip6, port, 0, 0 + raise RuntimeError("Unsupported socket family: {}".format(family)) + + def stream_socket(self, family, tls=False, timeout=20): + """Initialize a socket and return it along with the destination without connecting.""" + sock = socket.socket(family, socket.SOCK_STREAM) + sock.settimeout(timeout) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + dest = self.socket_dest(family, tls) + self.sockets.append(sock) + return sock, dest + + def _tcp_socket(self, family): + sock, dest = self.stream_socket(family) + sock.connect(dest) + return sock + + def ip_tcp_socket(self): + return self._tcp_socket(socket.AF_INET) + + def ip6_tcp_socket(self): + return self._tcp_socket(socket.AF_INET6) + + def _tls_socket(self, family): + sock, dest = self.stream_socket(family, tls=True) + ctx = utils.make_ssl_context(insecure=True) + ssock = ctx.wrap_socket(sock) + try: + ssock.connect(dest) + except OSError as exc: + if exc.errno == 0: # sometimes happens shortly after startup + return None + return ssock + + def _tls_socket_with_retry(self, family): + sock = self._tls_socket(family) + if sock is None: + time.sleep(0.1) + sock = self._tls_socket(family) + if sock is None: + raise RuntimeError("Failed to create TLS socket!") + return sock + + def ip_tls_socket(self): + return self._tls_socket_with_retry(socket.AF_INET) + + def ip6_tls_socket(self): + return self._tls_socket_with_retry(socket.AF_INET6) + + def partial_log(self): + partial_log = '\n (... ommiting log start)\n' + with open(self.logfile_path) as log: # display partial log for debugging + past_startup_msgid = False + past_startup = False + for line in log: + if past_startup: + partial_log += line + else: # find real start of test log (after initial alive-pings) + if not past_startup_msgid: + if re.match(KRESD_LOG_STARTUP_MSGID, line) is not None: + past_startup_msgid = True + else: + if re.match(KRESD_LOG_IO_CLOSE, line) is not None: + past_startup = True + return partial_log + + +def is_port_free(port, ip=None, ip6=None): + def check(family, type_, dest): + sock = socket.socket(family, type_) + sock.bind(dest) + sock.close() + + try: + if ip is not None: + check(socket.AF_INET, socket.SOCK_STREAM, (ip, port)) + check(socket.AF_INET, socket.SOCK_DGRAM, (ip, port)) + if ip6 is not None: + check(socket.AF_INET6, socket.SOCK_STREAM, (ip6, port, 0, 0)) + check(socket.AF_INET6, socket.SOCK_DGRAM, (ip6, port, 0, 0)) + except OSError as exc: + if exc.errno == 98: # address alrady in use + return False + else: + raise + return True + + +def take_port(port, ip=None, ip6=None, timeout=0): + port_path = Path(KRESD_PORTDIR, str(port)) + end_time = time.time() + timeout + try: + port_path.touch(exist_ok=False) + except FileExistsError as ex: + raise ValueError( + "Port {} already reserved by system or another kresd instance!".format(port)) from ex + + while True: + if is_port_free(port, ip, ip6): + # NOTE: The port_path isn't removed, so other instances don't have to attempt to + # take the same port again. This has the side effect of leaving many of these + # files behind, because when another kresd shuts down and removes its file, the + # port still can't be reserved for a while. This shouldn't become an issue unless + # we have thousands of tests (and run out of the port range). + break + + if time.time() < end_time: + time.sleep(5) + else: + raise ValueError( + "Port {} is reserved by system!".format(port)) + return port + + +def make_port(ip=None, ip6=None): + for _ in range(10): # max attempts + port = random.randint(KRESD_TESTPORT_MIN, KRESD_TESTPORT_MAX) + try: + take_port(port, ip, ip6) + except ValueError: + continue # port reserved by system / another kresd instance + return port + raise RuntimeError("No available port found!") + + +KRESD_LOG_STARTUP_MSGID = re.compile(r'^\[{}.*'.format(KRESD_STARTUP_MSGID)) +KRESD_LOG_IO_CLOSE = re.compile(r'^\[io\].*closed by peer.*') + + +@contextmanager +def make_kresd(workdir, certname=None, ip='127.0.0.1', ip6='::1', **kwargs): + with Kresd(workdir, ip=ip, ip6=ip6, certname=certname, **kwargs) as kresd: + yield kresd + print(kresd.partial_log()) diff --git a/tests/pytests/meson.build b/tests/pytests/meson.build new file mode 100644 index 0000000..217fc02 --- /dev/null +++ b/tests/pytests/meson.build @@ -0,0 +1,75 @@ +# tests: pytests +# SPDX-License-Identifier: GPL-3.0-or-later + +# python 3 dependencies +py3_deps += [ + ['jinja2', 'jinja2 (for pytests)'], + ['dns', 'dnspython (for pytests)'], + ['pytest', 'pytest (for pytests)'], + ['pytest_html', 'pytest-html (for pytests)'], + ['xdist', 'pytest-xdist (for pytests)'], +] + +if gnutls.version().version_compare('<3.6.4') + error('pytests require GnuTLS >= 3.6.4') +endif + +# compile tlsproxy +tlsproxy_src = files([ + 'proxy/tlsproxy.c', + 'proxy/tls-proxy.c', +]) +tlsproxy = executable( + 'tlsproxy', + tlsproxy_src, + dependencies: [ + libkres_dep, + libuv, + gnutls, + ], +) + +# path to kresd and tlsproxy +pytests_env = environment() +pytests_env.prepend('PATH', sbin_dir, meson.current_build_dir()) + +test( + 'pytests.parallel', + python3, + args: [ + '-m', 'pytest', + '-d', + '--html', 'pytests.parallel.html', + '--self-contained-html', + '-n', '24', + '-v', + ], + env: pytests_env, + suite: [ + 'postinstall', + 'pytests', + ], + workdir: meson.current_source_dir(), + is_parallel: false, + timeout: 180, + depends: tlsproxy, +) + +test( + 'pytests.single', + python3, + args: [ + '-m', 'pytest', + '-ra', + '--capture=no', + 'conn_flood.py', + ], + env: pytests_env, + suite: [ + 'postinstall', + 'pytests', + ], + workdir: meson.current_source_dir(), + is_parallel: false, + timeout: 240, +) diff --git a/tests/pytests/proxy.py b/tests/pytests/proxy.py new file mode 100644 index 0000000..a55542b --- /dev/null +++ b/tests/pytests/proxy.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: GPL-3.0-or-later + +from contextlib import contextmanager, ContextDecorator +import os +import subprocess +from typing import Any, Dict, Optional + +import dns +import dns.rcode +import pytest + +from kresd import CERTS_DIR, Forward, Kresd, make_kresd, make_port +import utils + + +HINTS = { + '0.foo.': '127.0.0.1', + '1.foo.': '127.0.0.1', + '2.foo.': '127.0.0.1', + '3.foo.': '127.0.0.1', +} + + +def resolve_hint(sock, qname): + buff, msgid = utils.get_msgbuff(qname) + sock.sendall(buff) + answer = utils.receive_parse_answer(sock) + assert answer.id == msgid + assert answer.rcode() == dns.rcode.NOERROR + assert answer.answer[0][0].address == HINTS[qname] + + +class Proxy(ContextDecorator): + EXECUTABLE = '' + + def __init__( + self, + local_ip: str = '127.0.0.1', + local_port: Optional[int] = None, + upstream_ip: str = '127.0.0.1', + upstream_port: Optional[int] = None + ) -> None: + self.local_ip = local_ip + self.local_port = local_port + self.upstream_ip = upstream_ip + self.upstream_port = upstream_port + self.proxy = None + + def get_args(self): + args = [] + args.append('--local') + args.append(self.local_ip) + if self.local_port is not None: + args.append('--lport') + args.append(str(self.local_port)) + args.append('--upstream') + args.append(self.upstream_ip) + if self.upstream_port is not None: + args.append('--uport') + args.append(str(self.upstream_port)) + return args + + def __enter__(self): + args = [self.EXECUTABLE] + self.get_args() + print(' '.join(args)) + + try: + self.proxy = subprocess.Popen( + args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + except subprocess.CalledProcessError: + pytest.skip("proxy '{}' failed to run (did you compile it?)" + .format(self.EXECUTABLE)) + + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self.proxy is not None: + self.proxy.terminate() + self.proxy = None + + +class TLSProxy(Proxy): + EXECUTABLE = 'tlsproxy' + + def __init__( + self, + local_ip: str = '127.0.0.1', + local_port: Optional[int] = None, + upstream_ip: str = '127.0.0.1', + upstream_port: Optional[int] = None, + certname: Optional[str] = 'tt', + close: Optional[int] = None, + rehandshake: bool = False, + force_tls13: bool = False + ) -> None: + super().__init__(local_ip, local_port, upstream_ip, upstream_port) + if certname is not None: + self.cert_path = os.path.join(CERTS_DIR, certname + '.cert.pem') + self.key_path = os.path.join(CERTS_DIR, certname + '.key.pem') + else: + self.cert_path = None + self.key_path = None + self.close = close + self.rehandshake = rehandshake + self.force_tls13 = force_tls13 + + def get_args(self): + args = super().get_args() + if self.cert_path is not None: + args.append('--cert') + args.append(self.cert_path) + if self.key_path is not None: + args.append('--key') + args.append(self.key_path) + if self.close is not None: + args.append('--close') + args.append(str(self.close)) + if self.rehandshake: + args.append('--rehandshake') + if self.force_tls13: + args.append('--tls13') + return args + + +@contextmanager +def kresd_tls_client( + workdir: str, + proxy: TLSProxy, + kresd_tls_client_kwargs: Optional[Dict[Any, Any]] = None, + kresd_fwd_target_kwargs: Optional[Dict[Any, Any]] = None + ) -> Kresd: + """kresd_tls_client --(tls)--> tlsproxy --(tcp)--> kresd_fwd_target""" + ALLOWED_IPS = {'127.0.0.1', '::1'} + assert proxy.local_ip in ALLOWED_IPS, "only localhost IPs supported for proxy" + assert proxy.upstream_ip in ALLOWED_IPS, "only localhost IPs are supported for proxy" + + if kresd_tls_client_kwargs is None: + kresd_tls_client_kwargs = dict() + if kresd_fwd_target_kwargs is None: + kresd_fwd_target_kwargs = dict() + + # run forward target instance + dir1 = os.path.join(workdir, 'kresd_fwd_target') + os.makedirs(dir1) + + with make_kresd(dir1, hints=HINTS, **kresd_fwd_target_kwargs) as kresd_fwd_target: + sock = kresd_fwd_target.ip_tcp_socket() + resolve_hint(sock, list(HINTS.keys())[0]) + + proxy.local_port = make_port('127.0.0.1', '::1') + proxy.upstream_port = kresd_fwd_target.port + + with proxy: + # run test kresd instance + dir2 = os.path.join(workdir, 'kresd_tls_client') + os.makedirs(dir2) + forward = Forward( + proto='tls', ip=proxy.local_ip, port=proxy.local_port, + hostname='transport-test-server.com', ca_file=proxy.cert_path) + with make_kresd(dir2, forward=forward, **kresd_tls_client_kwargs) as kresd: + yield kresd diff --git a/tests/pytests/proxy/tls-proxy.c b/tests/pytests/proxy/tls-proxy.c new file mode 100644 index 0000000..1002a82 --- /dev/null +++ b/tests/pytests/proxy/tls-proxy.c @@ -0,0 +1,1038 @@ +/* SPDX-License-Identifier: GPL-3.0-or-later */ + +#include <assert.h> +#include <stdio.h> +#include <unistd.h> +#include <string.h> +#include <stdlib.h> +#include <stdbool.h> +#include <gnutls/gnutls.h> +#include <uv.h> +#include "lib/generic/array.h" +#include "tls-proxy.h" + +#define TLS_MAX_SEND_RETRIES 100 +#define CLIENT_ANSWER_CHUNK_SIZE 8 + +#define MAX_CLIENT_PENDING_SIZE 4096 + +struct buf { + size_t size; + char buf[]; +}; + +enum peer_state { + STATE_NOT_CONNECTED, + STATE_LISTENING, + STATE_CONNECTED, + STATE_CONNECT_IN_PROGRESS, + STATE_CLOSING_IN_PROGRESS +}; + +enum handshake_state { + TLS_HS_NOT_STARTED = 0, + TLS_HS_EXPECTED, + TLS_HS_REAUTH_EXPECTED, + TLS_HS_IN_PROGRESS, + TLS_HS_DONE, + TLS_HS_CLOSING, + TLS_HS_LAST +}; + +struct tls_ctx { + gnutls_session_t session; + enum handshake_state handshake_state; + /* for reading from the network */ + const uint8_t *buf; + ssize_t nread; + ssize_t consumed; + uint8_t recv_buf[4096]; +}; + +struct peer { + uv_tcp_t handle; + enum peer_state state; + struct sockaddr_storage addr; + array_t(struct buf *) pending_buf; + uint64_t connection_timestamp; + struct tls_ctx *tls; + struct peer *peer; + int active_requests; +}; + +struct tls_proxy_ctx { + const struct args *a; + uv_loop_t *loop; + gnutls_certificate_credentials_t tls_credentials; + gnutls_priority_t tls_priority_cache; + struct { + uv_tcp_t handle; + struct sockaddr_storage addr; + } server; + struct sockaddr_storage upstream_addr; + array_t(struct peer *) client_list; + char uv_wire_buf[65535 * 2]; + int conn_sequence; +}; + +static void read_from_upstream_cb(uv_stream_t *upstream, ssize_t nread, const uv_buf_t *buf); +static void read_from_client_cb(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf); +static ssize_t proxy_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len); +static ssize_t proxy_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len); +static int tls_process_from_upstream(struct peer *upstream, const uint8_t *buf, ssize_t nread); +static int tls_process_from_client(struct peer *client, const uint8_t *buf, ssize_t nread); +static int write_to_upstream_pending(struct peer *peer); +static int write_to_client_pending(struct peer *peer); +static void on_client_close(uv_handle_t *handle); +static void on_upstream_close(uv_handle_t *handle); + +static int gnutls_references = 0; + +static const char * const tlsv12_priorities = + "NORMAL:" /* GnuTLS defaults */ + "-VERS-TLS1.0:-VERS-TLS1.1:+VERS-TLS1.2:-VERS-TLS1.3:" /* TLS 1.2 only */ + "-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL"; + +static const char * const tlsv13_priorities = + "NORMAL:" /* GnuTLS defaults */ + "-VERS-TLS1.0:-VERS-TLS1.1:-VERS-TLS1.2:+VERS-TLS1.3:" /* TLS 1.3 only */ + "-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL"; + +static struct tls_proxy_ctx *get_proxy(struct peer *peer) +{ + return (struct tls_proxy_ctx *)peer->handle.loop->data; +} + +const void *ip_addr(const struct sockaddr *addr) +{ + if (!addr) { + return NULL; + } + switch (addr->sa_family) { + case AF_INET: return (const void *)&(((const struct sockaddr_in *)addr)->sin_addr); + case AF_INET6: return (const void *)&(((const struct sockaddr_in6 *)addr)->sin6_addr); + default: return NULL; + } +} + +uint16_t ip_addr_port(const struct sockaddr *addr) +{ + if (!addr) { + return 0; + } + switch (addr->sa_family) { + case AF_INET: return ntohs(((const struct sockaddr_in *)addr)->sin_port); + case AF_INET6: return ntohs(((const struct sockaddr_in6 *)addr)->sin6_port); + default: return 0; + } +} + +static int ip_addr_str(const struct sockaddr *addr, char *buf, size_t *buflen) +{ + int ret = 0; + if (!addr || !buf || !buflen) { + return EINVAL; + } + + char str[INET6_ADDRSTRLEN + 6]; + if (!inet_ntop(addr->sa_family, ip_addr(addr), str, sizeof(str))) { + return errno; + } + int len = strlen(str); + str[len] = '#'; + snprintf(&str[len + 1], 6, "%hu", ip_addr_port(addr)); + len += 6; + str[len] = 0; + if (len >= *buflen) { + ret = ENOSPC; + } else { + memcpy(buf, str, len + 1); + } + *buflen = len; + return ret; +} + +static inline char *ip_straddr(const struct sockaddr_storage *saddr_storage) +{ + assert(saddr_storage != NULL); + const struct sockaddr *addr = (const struct sockaddr *)saddr_storage; + /* We are the sinle-threaded application */ + static char str[INET6_ADDRSTRLEN + 6]; + size_t len = sizeof(str); + int ret = ip_addr_str(addr, str, &len); + return ret != 0 || len == 0 ? NULL : str; +} + +static struct buf *alloc_io_buffer(size_t size) +{ + struct buf *buf = calloc(1, sizeof (struct buf) + size); + buf->size = size; + return buf; +} + +static void free_io_buffer(struct buf *buf) +{ + if (!buf) { + return; + } + free(buf); +} + +static struct buf *get_first_pending_buf(struct peer *peer) +{ + struct buf *buf = NULL; + if (peer->pending_buf.len > 0) { + buf = peer->pending_buf.at[0]; + } + return buf; +} + +static struct buf *remove_first_pending_buf(struct peer *peer) +{ + if (peer->pending_buf.len == 0) { + return NULL; + } + struct buf * buf = peer->pending_buf.at[0]; + for (int i = 1; i < peer->pending_buf.len; ++i) { + peer->pending_buf.at[i - 1] = peer->pending_buf.at[i]; + } + peer->pending_buf.len -= 1; + return buf; +} + +static void clear_pending_bufs(struct peer *peer) +{ + for (int i = 0; i < peer->pending_buf.len; ++i) { + struct buf *b = peer->pending_buf.at[i]; + free_io_buffer(b); + } + peer->pending_buf.len = 0; +} + +static void alloc_uv_buffer(uv_handle_t *handle, size_t suggested_size, uv_buf_t *buf) +{ + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)handle->loop->data; + buf->base = proxy->uv_wire_buf; + buf->len = sizeof(proxy->uv_wire_buf); +} + +static void on_client_close(uv_handle_t *handle) +{ + struct peer *client = (struct peer *)handle->data; + struct peer *upstream = client->peer; + fprintf(stdout, "[client] connection with '%s' closed\n", ip_straddr(&client->addr)); + assert(client->tls); + gnutls_deinit(client->tls->session); + client->tls->handshake_state = TLS_HS_NOT_STARTED; + client->state = STATE_NOT_CONNECTED; + if (upstream->state != STATE_NOT_CONNECTED) { + if (upstream->state == STATE_CONNECTED) { + fprintf(stdout, "[client] closing connection with upstream for '%s'\n", + ip_straddr(&client->addr)); + upstream->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&upstream->handle, on_upstream_close); + } + return; + } + struct tls_proxy_ctx *proxy = get_proxy(client); + for (size_t i = 0; i < proxy->client_list.len; ++i) { + struct peer *client_i = proxy->client_list.at[i]; + if (client_i == client) { + fprintf(stdout, "[client] connection structures deallocated for '%s'\n", + ip_straddr(&client->addr)); + array_del(proxy->client_list, i); + free(client->tls); + free(client); + break; + } + } +} + +static void on_upstream_close(uv_handle_t *handle) +{ + struct peer *upstream = (struct peer *)handle->data; + struct peer *client = upstream->peer; + assert(upstream->tls == NULL); + upstream->state = STATE_NOT_CONNECTED; + fprintf(stdout, "[upstream] connection with upstream closed for client '%s'\n", ip_straddr(&client->addr)); + if (client->state != STATE_NOT_CONNECTED) { + if (client->state == STATE_CONNECTED) { + fprintf(stdout, "[upstream] closing connection to client '%s'\n", + ip_straddr(&client->addr)); + client->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&client->handle, on_client_close); + } + return; + } + struct tls_proxy_ctx *proxy = get_proxy(upstream); + for (size_t i = 0; i < proxy->client_list.len; ++i) { + struct peer *client_i = proxy->client_list.at[i]; + if (client_i == client) { + fprintf(stdout, "[upstream] connection structures deallocated for '%s'\n", + ip_straddr(&client->addr)); + array_del(proxy->client_list, i); + free(upstream); + free(client->tls); + free(client); + break; + } + } +} + +static void write_to_client_cb(uv_write_t *req, int status) +{ + struct peer *client = (struct peer *)req->handle->data; + free(req); + client->active_requests -= 1; + if (status) { + fprintf(stdout, "[client] error writing to client '%s': %s\n", + ip_straddr(&client->addr), uv_strerror(status)); + clear_pending_bufs(client); + if (client->state == STATE_CONNECTED) { + client->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&client->handle, on_client_close); + return; + } + } + fprintf(stdout, "[client] successfully wrote to client '%s', pending len is %zd, active requests %i\n", + ip_straddr(&client->addr), client->pending_buf.len, client->active_requests); + if (client->state == STATE_CONNECTED && + client->tls->handshake_state == TLS_HS_DONE) { + struct tls_proxy_ctx *proxy = get_proxy(client); + uint64_t elapsed = uv_now(proxy->loop) - client->connection_timestamp; + if (!proxy->a->close_connection || elapsed < proxy->a->close_timeout) { + write_to_client_pending(client); + } else { + clear_pending_bufs(client); + client->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&client->handle, on_client_close); + fprintf(stdout, "[client] closing connection to client '%s'\n", ip_straddr(&client->addr)); + } + } +} + +static void write_to_upstream_cb(uv_write_t *req, int status) +{ + struct peer *upstream = (struct peer *)req->handle->data; + void *data = req->data; + free(req); + if (status) { + fprintf(stdout, "[upstream] error writing to upstream: %s\n", uv_strerror(status)); + clear_pending_bufs(upstream); + upstream->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&upstream->handle, on_upstream_close); + return; + } + if (data != NULL) { + assert(upstream->pending_buf.len > 0); + struct buf *buf = get_first_pending_buf(upstream); + assert(data == (void *)buf->buf); + fprintf(stdout, "[upstream] successfully wrote %zi bytes to upstream, pending len is %zd\n", + buf->size, upstream->pending_buf.len); + remove_first_pending_buf(upstream); + free_io_buffer(buf); + } else { + fprintf(stdout, "[upstream] successfully wrote to upstream, pending len is %zd\n", + upstream->pending_buf.len); + } + if (upstream->peer == NULL || upstream->peer->state != STATE_CONNECTED) { + clear_pending_bufs(upstream); + } else if (upstream->state == STATE_CONNECTED && upstream->pending_buf.len > 0) { + write_to_upstream_pending(upstream); + } +} + +static void accept_connection_from_client(uv_stream_t *server) +{ + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)server->loop->data; + struct peer *client = calloc(1, sizeof(struct peer)); + uv_tcp_init(proxy->loop, &client->handle); + uv_tcp_nodelay((uv_tcp_t *)&client->handle, 1); + + int err = uv_accept(server, (uv_stream_t*)&client->handle); + if (err != 0) { + fprintf(stdout, "[client] incoming connection - uv_accept() failed: (%d) %s\n", + err, uv_strerror(err)); + proxy->conn_sequence = 0; + return; + } + + client->state = STATE_CONNECTED; + array_init(client->pending_buf); + client->handle.data = client; + + struct peer *upstream = calloc(1, sizeof(struct peer)); + uv_tcp_init(proxy->loop, &upstream->handle); + uv_tcp_nodelay((uv_tcp_t *)&upstream->handle, 1); + + client->peer = upstream; + + array_init(upstream->pending_buf); + upstream->state = STATE_NOT_CONNECTED; + upstream->peer = client; + upstream->handle.data = upstream; + + struct sockaddr *addr = (struct sockaddr *)&(client->addr); + int addr_len = sizeof(client->addr); + int ret = uv_tcp_getpeername(&client->handle, addr, &addr_len); + if (ret || addr->sa_family == AF_UNSPEC) { + client->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&client->handle, on_client_close); + fprintf(stdout, "[client] incoming connection - uv_tcp_getpeername() failed: (%d) %s\n", + err, uv_strerror(err)); + proxy->conn_sequence = 0; + return; + } + memcpy(&upstream->addr, &proxy->upstream_addr, sizeof(struct sockaddr_storage)); + + struct tls_ctx *tls = calloc(1, sizeof(struct tls_ctx)); + tls->handshake_state = TLS_HS_NOT_STARTED; + + client->tls = tls; + const char *errpos = NULL; + unsigned int gnutls_flags = GNUTLS_SERVER | GNUTLS_NONBLOCK; +#if GNUTLS_VERSION_NUMBER >= 0x030604 + if (proxy->a->tls_13) { + gnutls_flags |= GNUTLS_POST_HANDSHAKE_AUTH; + } +#endif + err = gnutls_init(&tls->session, gnutls_flags); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[client] gnutls_init() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + } + err = gnutls_priority_set(tls->session, proxy->tls_priority_cache); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[client] gnutls_priority_set() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + } + + const char *direct_priorities = proxy->a->tls_13 ? tlsv13_priorities : tlsv12_priorities; + err = gnutls_priority_set_direct(tls->session, direct_priorities, &errpos); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[client] setting priority '%s' failed at character %zd (...'%s') with %s (%d)\n", + direct_priorities, errpos - direct_priorities, errpos, + gnutls_strerror_name(err), err); + } + err = gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE, proxy->tls_credentials); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[client] gnutls_credentials_set() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + } + if (proxy->a->tls_13) { + gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_REQUEST); + } else { + gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_IGNORE); + } + gnutls_handshake_set_timeout(tls->session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT); + gnutls_transport_set_pull_function(tls->session, proxy_gnutls_pull); + gnutls_transport_set_push_function(tls->session, proxy_gnutls_push); + gnutls_transport_set_ptr(tls->session, client); + + tls->handshake_state = TLS_HS_IN_PROGRESS; + + client->connection_timestamp = uv_now(proxy->loop); + proxy->conn_sequence += 1; + array_push(proxy->client_list, client); + + fprintf(stdout, "[client] incoming connection from '%s'\n", ip_straddr(&client->addr)); + uv_read_start((uv_stream_t*)&client->handle, alloc_uv_buffer, read_from_client_cb); +} + +static void dynamic_handle_close_cb(uv_handle_t *handle) +{ + free(handle); +} + +static void delayed_accept_timer_cb(uv_timer_t *timer) +{ + uv_stream_t *server = (uv_stream_t *)timer->data; + fprintf(stdout, "[client] delayed connection processing\n"); + accept_connection_from_client(server); + uv_close((uv_handle_t *)timer, dynamic_handle_close_cb); +} + +static void on_client_connection(uv_stream_t *server, int status) +{ + if (status < 0) { + fprintf(stdout, "[client] incoming connection error: %s\n", uv_strerror(status)); + return; + } + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)server->loop->data; + proxy->conn_sequence += 1; + if (proxy->a->max_conn_sequence > 0 && + proxy->conn_sequence > proxy->a->max_conn_sequence) { + fprintf(stdout, "[client] incoming connection, delaying\n"); + uv_timer_t *timer = (uv_timer_t*)malloc(sizeof *timer); + uv_timer_init(uv_default_loop(), timer); + timer->data = server; + uv_timer_start(timer, delayed_accept_timer_cb, 10000, 0); + proxy->conn_sequence = 0; + } else { + accept_connection_from_client(server); + } +} + +static void on_connect_to_upstream(uv_connect_t *req, int status) +{ + struct peer *upstream = (struct peer *)req->handle->data; + free(req); + if (status < 0) { + fprintf(stdout, "[upstream] error connecting to upstream (%s): %s\n", + ip_straddr(&upstream->addr), + uv_strerror(status)); + clear_pending_bufs(upstream); + upstream->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&upstream->handle, on_upstream_close); + return; + } + fprintf(stdout, "[upstream] connected to %s\n", ip_straddr(&upstream->addr)); + + upstream->state = STATE_CONNECTED; + uv_read_start((uv_stream_t*)&upstream->handle, alloc_uv_buffer, read_from_upstream_cb); + if (upstream->pending_buf.len > 0) { + write_to_upstream_pending(upstream); + } +} + +static void read_from_client_cb(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf) +{ + if (nread == 0) { + fprintf(stdout, "[client] reading %zd bytes\n", nread); + return; + } + struct peer *client = (struct peer *)handle->data; + if (nread < 0) { + if (nread != UV_EOF) { + fprintf(stdout, "[client] error reading from '%s': %s\n", + ip_straddr(&client->addr), + uv_err_name(nread)); + } else { + fprintf(stdout, "[client] closing connection with '%s'\n", + ip_straddr(&client->addr)); + } + if (client->state == STATE_CONNECTED) { + client->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)handle, on_client_close); + } + return; + } + + struct tls_proxy_ctx *proxy = get_proxy(client); + if (proxy->a->accept_only) { + fprintf(stdout, "[client] ignoring %zd bytes from '%s'\n", nread, ip_straddr(&client->addr)); + return; + } + fprintf(stdout, "[client] reading %zd bytes from '%s'\n", nread, ip_straddr(&client->addr)); + + int res = tls_process_from_client(client, (const uint8_t *)buf->base, nread); + if (res < 0) { + if (client->state == STATE_CONNECTED) { + client->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&client->handle, on_client_close); + } + } +} + +static void read_from_upstream_cb(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf) +{ + fprintf(stdout, "[upstream] reading %zd bytes\n", nread); + if (nread == 0) { + return; + } + struct peer *upstream = (struct peer *)handle->data; + if (nread < 0) { + if (nread != UV_EOF) { + fprintf(stdout, "[upstream] error reading from upstream: %s\n", uv_err_name(nread)); + } else { + fprintf(stdout, "[upstream] closing connection\n"); + } + clear_pending_bufs(upstream); + if (upstream->state == STATE_CONNECTED) { + upstream->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&upstream->handle, on_upstream_close); + } + return; + } + int res = tls_process_from_upstream(upstream, (const uint8_t *)buf->base, nread); + if (res < 0) { + fprintf(stdout, "[upstream] error processing tls data to client\n"); + if (upstream->peer->state == STATE_CONNECTED) { + upstream->peer->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&upstream->peer->handle, on_client_close); + } + } +} + +static void push_to_upstream_pending(struct peer *upstream, const char *buf, size_t size) +{ + struct buf *b = alloc_io_buffer(size); + memcpy(b->buf, buf, b->size); + array_push(upstream->pending_buf, b); +} + +static void push_to_client_pending(struct peer *client, const char *buf, size_t size) +{ + struct tls_proxy_ctx *proxy = get_proxy(client); + while (size > 0) { + int temp_size = size; + if (proxy->a->rehandshake && temp_size > CLIENT_ANSWER_CHUNK_SIZE) { + temp_size = CLIENT_ANSWER_CHUNK_SIZE; + } + struct buf *b = alloc_io_buffer(temp_size); + memcpy(b->buf, buf, b->size); + array_push(client->pending_buf, b); + size -= temp_size; + buf += temp_size; + } +} + +static int write_to_upstream_pending(struct peer *upstream) +{ + struct buf *buf = get_first_pending_buf(upstream); + uv_write_t *req = (uv_write_t *) malloc(sizeof(uv_write_t)); + uv_buf_t wrbuf = uv_buf_init(buf->buf, buf->size); + req->data = buf->buf; + fprintf(stdout, "[upstream] writing %zd bytes\n", buf->size); + return uv_write(req, (uv_stream_t *)&upstream->handle, &wrbuf, 1, write_to_upstream_cb); +} + +static ssize_t proxy_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len) +{ + struct peer *peer = (struct peer *)h; + struct tls_ctx *t = peer->tls; + + fprintf(stdout, "[gnutls_pull] pulling %zd bytes\n", len); + + if (t->nread <= t->consumed) { + errno = EAGAIN; + fprintf(stdout, "[gnutls_pull] return EAGAIN\n"); + return -1; + } + + ssize_t avail = t->nread - t->consumed; + ssize_t transfer = (avail <= len ? avail : len); + memcpy(buf, t->buf + t->consumed, transfer); + t->consumed += transfer; + return transfer; +} + +ssize_t proxy_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len) +{ + struct peer *client = (struct peer *)h; + fprintf(stdout, "[gnutls_push] writing %zd bytes\n", len); + + ssize_t ret = -1; + const size_t req_size_aligned = ((sizeof(uv_write_t) / 16) + 1) * 16; + char *common_buf = malloc(req_size_aligned + len); + uv_write_t *req = (uv_write_t *) common_buf; + char *data = common_buf + req_size_aligned; + const uv_buf_t uv_buf[1] = { + { data, len } + }; + memcpy(data, buf, len); + int res = uv_write(req, (uv_stream_t *)&client->handle, uv_buf, 1, write_to_client_cb); + if (res == 0) { + ret = len; + client->active_requests += 1; + } else { + free(common_buf); + errno = EIO; + } + return ret; +} + +static int write_to_client_pending(struct peer *client) +{ + if (client->pending_buf.len == 0) { + return 0; + } + + struct tls_proxy_ctx *proxy = get_proxy(client); + struct buf *buf = get_first_pending_buf(client); + fprintf(stdout, "[client] writing %zd bytes\n", buf->size); + + gnutls_session_t tls_session = client->tls->session; + assert(client->tls->handshake_state != TLS_HS_IN_PROGRESS); + + char *data = buf->buf; + size_t len = buf->size; + + ssize_t count = 0; + ssize_t submitted = len; + ssize_t retries = 0; + do { + count = gnutls_record_send(tls_session, data, len); + if (count < 0) { + if (gnutls_error_is_fatal(count)) { + fprintf(stdout, "[client] gnutls_record_send failed: %s (%zd)\n", + gnutls_strerror_name(count), count); + return -1; + } + if (++retries > TLS_MAX_SEND_RETRIES) { + fprintf(stdout, "[client] gnutls_record_send: too many sequential non-fatal errors (%zd), last error is: %s (%zd)\n", + retries, gnutls_strerror_name(count), count); + return -1; + } + } else if (count != 0) { + data += count; + len -= count; + retries = 0; + } else { + if (++retries < TLS_MAX_SEND_RETRIES) { + continue; + } + fprintf(stdout, "[client] gnutls_record_send: too many retries (%zd)\n", + retries); + fprintf(stdout, "[client] tls_push_to_client didn't send all data(%zd of %zd)\n", + len, submitted); + return -1; + } + } while (len > 0); + + remove_first_pending_buf(client); + free_io_buffer(buf); + + fprintf(stdout, "[client] submitted %zd bytes\n", submitted); + if (proxy->a->rehandshake) { + int err = GNUTLS_E_SUCCESS; +#if GNUTLS_VERSION_NUMBER >= 0x030604 + if (proxy->a->tls_13) { + int flags = gnutls_session_get_flags(tls_session); + if ((flags & GNUTLS_SFLAGS_POST_HANDSHAKE_AUTH) == 0) { + /* Client doesn't support post-handshake re-authentication, + * nothing to test here */ + fprintf(stdout, "[client] GNUTLS_SFLAGS_POST_HANDSHAKE_AUTH flag not detected\n"); + assert(false); + } + err = gnutls_reauth(tls_session, 0); + if (err != GNUTLS_E_INTERRUPTED && + err != GNUTLS_E_AGAIN && + err != GNUTLS_E_GOT_APPLICATION_DATA) { + fprintf(stdout, "[client] gnutls_reauth() failed: %s (%i)\n", + gnutls_strerror_name(err), err); + } else { + fprintf(stdout, "[client] post-handshake authentication initiated\n"); + } + client->tls->handshake_state = TLS_HS_REAUTH_EXPECTED; + } else { + assert (gnutls_safe_renegotiation_status(tls_session) != 0); + err = gnutls_rehandshake(tls_session); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[client] gnutls_rehandshake() failed: %s (%i)\n", + gnutls_strerror_name(err), err); + assert(false); + } else { + fprintf(stdout, "[client] rehandshake started\n"); + } + client->tls->handshake_state = TLS_HS_EXPECTED; + } +#else + assert (gnutls_safe_renegotiation_status(tls_session) != 0); + err = gnutls_rehandshake(tls_session); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[client] gnutls_rehandshake() failed: %s (%i)\n", + gnutls_strerror_name(err), err); + assert(false); + } else { + fprintf(stdout, "[client] rehandshake started\n"); + } + /* Prevent write-to-client callback from sending next pending chunk. + * At the same time tls_process_from_client() must not call gnutls_handshake() + * as there can be application data in this direction. */ + client->tls->handshake_state = TLS_HS_EXPECTED; +#endif + } + return submitted; +} + +static int tls_process_from_upstream(struct peer *upstream, const uint8_t *buf, ssize_t len) +{ + struct peer *client = upstream->peer; + + fprintf(stdout, "[upstream] pushing %zd bytes to client\n", len); + + ssize_t submitted = 0; + if (client->state != STATE_CONNECTED) { + return submitted; + } + + bool list_was_empty = (client->pending_buf.len == 0); + push_to_client_pending(client, (const char *)buf, len); + submitted = len; + if (client->tls->handshake_state == TLS_HS_DONE) { + if (list_was_empty && client->pending_buf.len > 0) { + int ret = write_to_client_pending(client); + if (ret < 0) { + submitted = -1; + } + } + } + + return submitted; +} + +int tls_process_handshake(struct peer *peer) +{ + struct tls_ctx *tls = peer->tls; + int ret = 1; + while (tls->handshake_state == TLS_HS_IN_PROGRESS) { + fprintf(stdout, "[tls] TLS handshake in progress...\n"); + int err = gnutls_handshake(tls->session); + if (err == GNUTLS_E_SUCCESS) { + tls->handshake_state = TLS_HS_DONE; + fprintf(stdout, "[tls] TLS handshake has completed\n"); + ret = 1; + if (peer->pending_buf.len != 0) { + write_to_client_pending(peer); + } + } else if (gnutls_error_is_fatal(err)) { + fprintf(stdout, "[tls] gnutls_handshake failed: %s (%d)\n", + gnutls_strerror_name(err), err); + ret = -1; + break; + } else { + fprintf(stdout, "[tls] gnutls_handshake nonfatal error: %s (%d)\n", + gnutls_strerror_name(err), err); + ret = 0; + break; + } + } + return ret; +} + +#if GNUTLS_VERSION_NUMBER >= 0x030604 +int tls_process_reauth(struct peer *peer) +{ + struct tls_ctx *tls = peer->tls; + int ret = 1; + while (tls->handshake_state == TLS_HS_REAUTH_EXPECTED) { + fprintf(stdout, "[tls] TLS re-authentication in progress...\n"); + int err = gnutls_reauth(tls->session, 0); + if (err == GNUTLS_E_SUCCESS) { + tls->handshake_state = TLS_HS_DONE; + fprintf(stdout, "[tls] TLS re-authentication has completed\n"); + ret = 1; + if (peer->pending_buf.len != 0) { + write_to_client_pending(peer); + } + } else if (err != GNUTLS_E_INTERRUPTED && + err != GNUTLS_E_AGAIN && + err != GNUTLS_E_GOT_APPLICATION_DATA) { + /* these are listed as nonfatal errors there + * https://www.gnutls.org/manual/gnutls.html#gnutls_005freauth */ + fprintf(stdout, "[tls] gnutls_reauth failed: %s (%d)\n", + gnutls_strerror_name(err), err); + ret = -1; + break; + } else { + fprintf(stdout, "[tls] gnutls_reauth nonfatal error: %s (%d)\n", + gnutls_strerror_name(err), err); + ret = 0; + break; + } + } + return ret; +} +#endif + +int tls_process_from_client(struct peer *client, const uint8_t *buf, ssize_t nread) +{ + struct tls_ctx *tls = client->tls; + + tls->buf = buf; + tls->nread = nread >= 0 ? nread : 0; + tls->consumed = 0; + + fprintf(stdout, "[client] tls_process: reading %zd bytes from client\n", nread); + + int ret = 0; + if (tls->handshake_state == TLS_HS_REAUTH_EXPECTED) { + ret = tls_process_reauth(client); + } else { + ret = tls_process_handshake(client); + } + if (ret <= 0) { + return ret; + } + + int submitted = 0; + while (true) { + ssize_t count = gnutls_record_recv(tls->session, tls->recv_buf, sizeof(tls->recv_buf)); + if (count == GNUTLS_E_AGAIN) { + break; /* No data available */ + } else if (count == GNUTLS_E_INTERRUPTED) { + continue; /* Try reading again */ + } else if (count == GNUTLS_E_REHANDSHAKE) { + tls->handshake_state = TLS_HS_IN_PROGRESS; + ret = tls_process_handshake(client); + if (ret < 0) { /* Critical error */ + return ret; + } + if (ret == 0) { /* Non fatal, most likely GNUTLS_E_AGAIN */ + break; + } + continue; + } +#if GNUTLS_VERSION_NUMBER >= 0x030604 + else if (count == GNUTLS_E_REAUTH_REQUEST) { + assert(false); + tls->handshake_state = TLS_HS_IN_PROGRESS; + ret = tls_process_reauth(client); + if (ret < 0) { /* Critical error */ + return ret; + } + if (ret == 0) { /* Non fatal, most likely GNUTLS_E_AGAIN */ + break; + } + continue; + } +#endif + else if (count < 0) { + fprintf(stdout, "[client] gnutls_record_recv failed: %s (%zd)\n", + gnutls_strerror_name(count), count); + assert(false); + return -1; + } else if (count == 0) { + break; + } + struct peer *upstream = client->peer; + if (upstream->state == STATE_CONNECTED) { + bool upstream_pending_is_empty = (upstream->pending_buf.len == 0); + push_to_upstream_pending(upstream, (const char *)tls->recv_buf, count); + if (upstream_pending_is_empty) { + write_to_upstream_pending(upstream); + } + } else if (upstream->state == STATE_NOT_CONNECTED) { + uv_connect_t *conn = (uv_connect_t *) malloc(sizeof(uv_connect_t)); + upstream->state = STATE_CONNECT_IN_PROGRESS; + fprintf(stdout, "[client] connecting to upstream '%s'\n", ip_straddr(&upstream->addr)); + uv_tcp_connect(conn, &upstream->handle, (struct sockaddr *)&upstream->addr, + on_connect_to_upstream); + push_to_upstream_pending(upstream, (const char *)tls->recv_buf, count); + } else if (upstream->state == STATE_CONNECT_IN_PROGRESS) { + push_to_upstream_pending(upstream, (const char *)tls->recv_buf, count); + } + submitted += count; + } + return submitted; +} + +struct tls_proxy_ctx *tls_proxy_allocate() +{ + return malloc(sizeof(struct tls_proxy_ctx)); +} + +int tls_proxy_init(struct tls_proxy_ctx *proxy, const struct args *a) +{ + const char *server_addr = a->local_addr; + int server_port = a->local_port; + const char *upstream_addr = a->upstream; + int upstream_port = a->upstream_port; + const char *cert_file = a->cert_file; + const char *key_file = a->key_file; + proxy->a = a; + proxy->loop = uv_default_loop(); + uv_tcp_init(proxy->loop, &proxy->server.handle); + int res = uv_ip4_addr(server_addr, server_port, (struct sockaddr_in *)&proxy->server.addr); + if (res != 0) { + res = uv_ip6_addr(server_addr, server_port, (struct sockaddr_in6 *)&proxy->server.addr); + if (res != 0) { + fprintf(stdout, "[proxy] tls_proxy_init: can't parse local address '%s'\n", server_addr); + return -1; + } + } + res = uv_ip4_addr(upstream_addr, upstream_port, (struct sockaddr_in *)&proxy->upstream_addr); + if (res != 0) { + res = uv_ip6_addr(upstream_addr, upstream_port, (struct sockaddr_in6 *)&proxy->upstream_addr); + if (res != 0) { + fprintf(stdout, "[proxy] tls_proxy_init: can't parse upstream address '%s'\n", upstream_addr); + return -1; + } + } + array_init(proxy->client_list); + proxy->conn_sequence = 0; + + proxy->loop->data = proxy; + + int err = 0; + if (gnutls_references == 0) { + err = gnutls_global_init(); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[proxy] gnutls_global_init() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + return -1; + } + } + gnutls_references += 1; + + err = gnutls_certificate_allocate_credentials(&proxy->tls_credentials); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[proxy] gnutls_certificate_allocate_credentials() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + return -1; + } + + err = gnutls_certificate_set_x509_system_trust(proxy->tls_credentials); + if (err <= 0) { + fprintf(stdout, "[proxy] gnutls_certificate_set_x509_system_trust() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + return -1; + } + + if (cert_file && key_file) { + err = gnutls_certificate_set_x509_key_file(proxy->tls_credentials, + cert_file, key_file, GNUTLS_X509_FMT_PEM); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[proxy] gnutls_certificate_set_x509_key_file() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + return -1; + } + } + + err = gnutls_priority_init(&proxy->tls_priority_cache, NULL, NULL); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[proxy] gnutls_priority_init() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + return -1; + } + + return 0; +} + +void tls_proxy_free(struct tls_proxy_ctx *proxy) +{ + if (!proxy) { + return; + } + while (proxy->client_list.len > 0) { + size_t last_index = proxy->client_list.len - 1; + struct peer *client = proxy->client_list.at[last_index]; + clear_pending_bufs(client); + clear_pending_bufs(client->peer); + /* TODO correctly close all the uv_tcp_t */ + free(client->peer); + free(client); + array_del(proxy->client_list, last_index); + } + gnutls_certificate_free_credentials(proxy->tls_credentials); + gnutls_priority_deinit(proxy->tls_priority_cache); + free(proxy); + + gnutls_references -= 1; + if (gnutls_references == 0) { + gnutls_global_deinit(); + } +} + +int tls_proxy_start_listen(struct tls_proxy_ctx *proxy) +{ + uv_tcp_bind(&proxy->server.handle, (const struct sockaddr*)&proxy->server.addr, 0); + int ret = uv_listen((uv_stream_t*)&proxy->server.handle, 128, on_client_connection); + return ret; +} + +int tls_proxy_run(struct tls_proxy_ctx *proxy) +{ + return uv_run(proxy->loop, UV_RUN_DEFAULT); +} diff --git a/tests/pytests/proxy/tls-proxy.h b/tests/pytests/proxy/tls-proxy.h new file mode 100644 index 0000000..c6f468d --- /dev/null +++ b/tests/pytests/proxy/tls-proxy.h @@ -0,0 +1,32 @@ +#pragma once +/* SPDX-License-Identifier: GPL-3.0-or-later */ + +#include <stdint.h> +#include <stdbool.h> +#include <netinet/in.h> + +struct args { + const char *local_addr; + uint16_t local_port; + const char *upstream; + uint16_t upstream_port; + + bool rehandshake; + bool close_connection; + bool accept_only; + bool tls_13; + + uint64_t close_timeout; + uint32_t max_conn_sequence; + + const char *cert_file; + const char *key_file; +}; + +struct tls_proxy_ctx; + +struct tls_proxy_ctx *tls_proxy_allocate(); +void tls_proxy_free(struct tls_proxy_ctx *proxy); +int tls_proxy_init(struct tls_proxy_ctx *proxy, const struct args *a); +int tls_proxy_start_listen(struct tls_proxy_ctx *proxy); +int tls_proxy_run(struct tls_proxy_ctx *proxy); diff --git a/tests/pytests/proxy/tlsproxy.c b/tests/pytests/proxy/tlsproxy.c new file mode 100644 index 0000000..f1b4acc --- /dev/null +++ b/tests/pytests/proxy/tlsproxy.c @@ -0,0 +1,198 @@ +/* SPDX-License-Identifier: GPL-3.0-or-later */ + +#include <stdio.h> +#include <getopt.h> +#include <stdlib.h> +#include <signal.h> +#include <errno.h> +#include <string.h> +#include <gnutls/gnutls.h> +#include "tls-proxy.h" + +static char default_local_addr[] = "127.0.0.1"; +static char default_upstream_addr[] = "127.0.0.1"; +static char default_cert_path[] = "../certs/tt.cert.pem"; +static char default_key_path[] = "../certs/tt.key.pem"; + +void help(char *argv[], struct args *a) +{ + printf("Usage: %s [parameters] [rundir]\n", argv[0]); + printf("\nParameters:\n" + " -l, --local=[addr] Server address to bind to (default: %s).\n" + " -p, --lport=[port] Server port to bind to (default: %u).\n" + " -u, --upstream=[addr] Upstream address (default: %s).\n" + " -d, --uport=[port] Upstream port (default: %u).\n" + " -t, --cert=[path] Path to certificate file (default: %s).\n" + " -k, --key=[path] Path to key file (default: %s).\n" + " -c, --close=[N] Close connection to client after\n" + " every N ms (default: %li).\n" + " -f, --fail=[N] Delay every Nth incoming connection by 10 sec,\n" + " 0 disables delaying (default: 0).\n" + " -r, --rehandshake Do TLS rehandshake after every 8 bytes\n" + " sent to the client (default: no).\n" + " -a, --acceptonly Accept incoming connections, but don't\n" + " connect to upstream (default: no).\n" + " -v, --tls13 Force use of TLSv1.3. If not turned on,\n" + " TLSv1.2 will be used (default: no).\n" + , + a->local_addr, a->local_port, + a->upstream, a->upstream_port, + a->cert_file, a->key_file, + a->close_timeout); +} + +void init_args(struct args *a) +{ + a->local_addr = default_local_addr; + a->local_port = 54000; + a->upstream = default_upstream_addr; + a->upstream_port = 53000; + a->cert_file = default_cert_path; + a->key_file = default_key_path; + a->rehandshake = false; + a->accept_only = false; + a->tls_13 = false; + a->close_connection = false; + a->close_timeout = 1000; + a->max_conn_sequence = 0; /* disabled */ +} + +int main(int argc, char **argv) +{ + long int li_value = 0; + int c = 0, li = 0; + struct option opts[] = { + {"local", required_argument, 0, 'l'}, + {"lport", required_argument, 0, 'p'}, + {"upstream", required_argument, 0, 'u'}, + {"uport", required_argument, 0, 'd'}, + {"cert", required_argument, 0, 't'}, + {"key", required_argument, 0, 'k'}, + {"close", required_argument, 0, 'c'}, + {"fail", required_argument, 0, 'f'}, + {"rehandshake", no_argument, 0, 'r'}, + {"acceptonly", no_argument, 0, 'a'}, +#if GNUTLS_VERSION_NUMBER >= 0x030604 + {"tls13", no_argument, 0, 'v'}, +#endif + {0, 0, 0, 0} + }; + struct args args; + init_args(&args); + while ((c = getopt_long(argc, argv, "l:p:u:d:t:k:c:f:rav", opts, &li)) != -1) { + switch (c) + { + case 'l': + args.local_addr = optarg; + break; + case 'u': + args.upstream = optarg; + break; + case 't': + args.cert_file = optarg; + break; + case 'k': + args.key_file = optarg; + break; + case 'p': + li_value = strtol(optarg, NULL, 10); + if (li_value <= 0 || li_value > UINT16_MAX) { + printf("error: '-p' requires a positive" + " number less or equal to 65535, not '%s'\n", optarg); + return -1; + } + args.local_port = (uint16_t)li_value; + break; + case 'd': + li_value = strtol(optarg, NULL, 10); + if (li_value <= 0 || li_value > UINT16_MAX) { + printf("error: '-d' requires a positive" + " number less or equal to 65535, not '%s'\n", optarg); + return -1; + } + args.upstream_port = (uint16_t)li_value; + break; + case 'c': + li_value = strtol(optarg, NULL, 10); + if (li_value <= 0) { + printf("[system] error '-c' requires a positive" + " number, not '%s'\n", optarg); + return -1; + } + args.close_connection = true; + args.close_timeout = li_value; + break; + case 'f': + li_value = strtol(optarg, NULL, 10); + if (li_value <= 0 || li_value > UINT32_MAX) { + printf("error: '-f' requires a positive" + " number less or equal to %i, not '%s'\n", + UINT32_MAX, optarg); + return -1; + } + args.max_conn_sequence = (uint32_t)li_value; + break; + case 'r': + args.rehandshake = true; + break; + case 'a': + args.accept_only = true; + break; + case 'v': +#if GNUTLS_VERSION_NUMBER >= 0x030604 + args.tls_13 = true; +#endif + break; + default: + init_args(&args); + help(argv, &args); + return -1; + } + } + if (signal(SIGPIPE, SIG_IGN) == SIG_ERR) { + fprintf(stderr, "failed to set up SIGPIPE handler to ignore(%s)\n", + strerror(errno)); + } + struct tls_proxy_ctx *proxy = tls_proxy_allocate(); + if (!proxy) { + fprintf(stderr, "can't allocate tls_proxy structure\n"); + return 1; + } + int res = tls_proxy_init(proxy, &args); + if (res) { + fprintf(stderr, "can't initialize tls_proxy structure\n"); + return res; + } + res = tls_proxy_start_listen(proxy); + if (res) { + fprintf(stderr, "error starting listen, error code: %i\n", res); + return res; + } + fprintf(stdout, "Listen on %s#%u\n" + "Upstream is expected on %s#%u\n" + "Certificate file %s\n" + "Key file %s\n" + "Rehandshake %s\n" + "Close %s\n" + "Refuse incoming connections every %ith%s\n" + "Only accept, don't forward %s\n" + "Force TLSv1.3 %s\n" + , + args.local_addr, args.local_port, + args.upstream, args.upstream_port, + args.cert_file, args.key_file, + args.rehandshake ? "yes" : "no", + args.close_connection ? "yes" : "no", + args.max_conn_sequence, args.max_conn_sequence ? "" : " (disabled)", + args.accept_only ? "yes" : "no", +#if GNUTLS_VERSION_NUMBER >= 0x030604 + args.tls_13 ? "yes" : "no" +#else + "Not supported" +#endif + ); + res = tls_proxy_run(proxy); + tls_proxy_free(proxy); + return res; +} + diff --git a/tests/pytests/proxyutils.py b/tests/pytests/proxyutils.py new file mode 100644 index 0000000..fa20d51 --- /dev/null +++ b/tests/pytests/proxyutils.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +from contextlib import contextmanager +import os +import subprocess + +import dns +import dns.rcode + +from kresd import CERTS_DIR +import utils + + +HINTS = { + '0.foo.': '127.0.0.1', + '1.foo.': '127.0.0.1', + '2.foo.': '127.0.0.1', + '3.foo.': '127.0.0.1', +} + +PROXY_CA_FILE = os.path.join(CERTS_DIR, 'tt.cert.pem') + + +def resolve_hint(sock, qname): + buff, msgid = utils.get_msgbuff(qname) + sock.sendall(buff) + answer = utils.receive_parse_answer(sock) + assert answer.id == msgid + assert answer.rcode() == dns.rcode.NOERROR + assert answer.answer[0][0].address == HINTS[qname] + + +@contextmanager +def proxy(path): + cwd, cmd = os.path.split(path) + cmd = './' + cmd + try: + proxy = subprocess.Popen( + [cmd], cwd=cwd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + yield proxy + finally: + proxy.terminate() diff --git a/tests/pytests/pylintrc b/tests/pytests/pylintrc new file mode 100644 index 0000000..39fcb28 --- /dev/null +++ b/tests/pytests/pylintrc @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +[MESSAGES CONTROL] + +disable= + missing-docstring, + too-few-public-methods, + too-many-arguments, + too-many-instance-attributes, + fixme, + unused-import, # checked by flake8 + line-too-long, # checked by flake8 + invalid-name, + broad-except, + bad-continuation, + global-statement, + no-else-return, + redefined-outer-name, # commonly used with pytest fixtures + + +[SIMILARITIES] +min-similarity-lines=6 +ignore-comments=yes +ignore-docstrings=yes +ignore-imports=no + +[DESIGN] +max-parents=10 +max-locals=20 + +[TYPECHECK] +ignored-modules=ssl diff --git a/tests/pytests/requirements.txt b/tests/pytests/requirements.txt new file mode 100644 index 0000000..6e2e4d2 --- /dev/null +++ b/tests/pytests/requirements.txt @@ -0,0 +1,5 @@ +dnspython +jinja2 +pytest +pytest-html +pytest-xdist diff --git a/tests/pytests/templates/kresd.conf.j2 b/tests/pytests/templates/kresd.conf.j2 new file mode 100644 index 0000000..08c73c8 --- /dev/null +++ b/tests/pytests/templates/kresd.conf.j2 @@ -0,0 +1,55 @@ +-- SPDX-License-Identifier: GPL-3.0-or-later +modules = { + 'hints > policy', + 'policy > iterate', +} + +verbose({{ 'true' if kresd.verbose else 'false' }}) + +{% if kresd.ip %} +net.listen('{{ kresd.ip }}', {{ kresd.port }}) +net.listen('{{ kresd.ip }}', {{ kresd.tls_port }}, {tls = true}) +{% endif %} + +{% if kresd.ip6 %} +net.listen('{{ kresd.ip6 }}', {{ kresd.port }}) +net.listen('{{ kresd.ip6 }}', {{ kresd.tls_port }}, {tls = true}) +{% endif %} + +net.ipv4=true +net.ipv6=true + +{% if kresd.tls_key_path and kresd.tls_cert_path %} +net.tls("{{ kresd.tls_cert_path }}", "{{ kresd.tls_key_path }}") +net.tls_sticket_secret('0123456789ABCDEF0123456789ABCDEF') +{% endif %} + +hints['localhost.'] = '127.0.0.1' +{% for name, ip in kresd.hints.items() %} +hints['{{ name }}'] = '{{ ip }}' +{% endfor %} + +policy.add(policy.all(policy.QTRACE)) + +{% if kresd.forward %} +policy.add(policy.all( + {% if kresd.forward.proto == 'tls' %} + policy.TLS_FORWARD({ + {"{{ kresd.forward.ip }}@{{ kresd.forward.port }}", hostname='{{ kresd.forward.hostname}}', ca_file='{{ kresd.forward.ca_file }}'}}) + {% endif %} +)) +{% endif %} + +{% if kresd.policy_test_pass %} +policy.add(policy.suffix(policy.PASS, {todname('test.')})) +{% endif %} + +-- make sure DNSSEC is turned off for tests +trust_anchors.remove('.') +modules.unload("ta_update") +modules.unload("ta_signal_query") +modules.unload("priming") +modules.unload("detect_time_skew") + +-- choose a small cache, since it is preallocated +cache.size = 1 * MB diff --git a/tests/pytests/test_conn_mgmt.py b/tests/pytests/test_conn_mgmt.py new file mode 100644 index 0000000..1d15091 --- /dev/null +++ b/tests/pytests/test_conn_mgmt.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +"""TCP Connection Management tests""" + +import socket +import struct +import time + +import pytest + +import utils + + +@pytest.mark.parametrize('garbage_lengths', [ + (1,), + (1024,), + (65533,), # max size garbage + (65533, 65533), + (1024, 1024, 1024), + # (0,), # currently kresd uses this as a heuristic of "lost in bytestream" + # (0, 1024), # and closes the connection +]) +def test_ignore_garbage(kresd_sock, garbage_lengths, single_buffer, query_before): + """Send chunk of garbage, prefixed by garbage length. It should be ignored.""" + buff = b'' + if query_before: # optionally send initial query + msg_buff_before, msgid_before = utils.get_msgbuff() + if single_buffer: + buff += msg_buff_before + else: + kresd_sock.sendall(msg_buff_before) + + for glength in garbage_lengths: # prepare garbage data + if glength is None: + continue + garbage_buff = utils.get_prefixed_garbage(glength) + if single_buffer: + buff += garbage_buff + else: + kresd_sock.sendall(garbage_buff) + + msg_buff, msgid = utils.get_msgbuff() # final query + buff += msg_buff + kresd_sock.sendall(buff) + + if query_before: + answer_before = utils.receive_parse_answer(kresd_sock) + assert answer_before.id == msgid_before + answer = utils.receive_parse_answer(kresd_sock) + assert answer.id == msgid + + +def test_pipelining(kresd_sock): + """ + First query takes longer to resolve - answer to second query should arrive sooner. + + This test requires internet connection. + """ + # initialization (to avoid issues with net.ipv6=true) + buff_pre, msgid_pre = utils.get_msgbuff('0.delay.getdnsapi.net.') + kresd_sock.sendall(buff_pre) + msg_answer = utils.receive_parse_answer(kresd_sock) + assert msg_answer.id == msgid_pre + + # test + buff1, msgid1 = utils.get_msgbuff('1500.delay.getdnsapi.net.', msgid=1) + buff2, msgid2 = utils.get_msgbuff('1.delay.getdnsapi.net.', msgid=2) + buff = buff1 + buff2 + kresd_sock.sendall(buff) + + msg_answer = utils.receive_parse_answer(kresd_sock) + assert msg_answer.id == msgid2 + + msg_answer = utils.receive_parse_answer(kresd_sock) + assert msg_answer.id == msgid1 + + +@pytest.mark.parametrize('duration, delay', [ + (utils.MAX_TIMEOUT, 0.1), + (utils.MAX_TIMEOUT, 3), + (utils.MAX_TIMEOUT, 7), + (utils.MAX_TIMEOUT + 10, 3), +]) +def test_long_lived(kresd_sock, duration, delay): + """Establish and keep connection alive for longer than maximum timeout.""" + utils.ping_alive(kresd_sock) + end_time = time.time() + duration + + while time.time() < end_time: + time.sleep(delay) + utils.ping_alive(kresd_sock) + + +def test_close(kresd_sock, query_before): + """Establish a connection and wait for timeout from kresd.""" + if query_before: + utils.ping_alive(kresd_sock) + time.sleep(utils.MAX_TIMEOUT) + + with utils.expect_kresd_close(): + utils.ping_alive(kresd_sock) + + +def test_slow_lorris(kresd_sock, query_before): + """Simulate slow-lorris attack by sending byte after byte with delays in between.""" + if query_before: + utils.ping_alive(kresd_sock) + + buff, _ = utils.get_msgbuff() + end_time = time.time() + utils.MAX_TIMEOUT + + with utils.expect_kresd_close(): + for i in range(len(buff)): + b = buff[i:i+1] + kresd_sock.send(b) + if time.time() > end_time: + break + time.sleep(1) + + +@pytest.mark.parametrize('sock_func_name', [ + 'ip_tcp_socket', + 'ip6_tcp_socket', +]) +def test_oob(kresd, sock_func_name): + """TCP out-of-band (urgent) data must not crash resolver.""" + make_sock = getattr(kresd, sock_func_name) + sock = make_sock() + msg_buff, msgid = utils.get_msgbuff() + sock.sendall(msg_buff, socket.MSG_OOB) + + try: + msg_answer = utils.receive_parse_answer(sock) + assert msg_answer.id == msgid + except ConnectionError: + pass # TODO kresd responds with TCP RST, this should be fixed + + # check kresd is alive + sock2 = make_sock() + utils.ping_alive(sock2) + + +def flood_buffer(msgcount): + flood_buff = bytes() + msgbuff, _ = utils.get_msgbuff() + noid_msgbuff = msgbuff[2:] + + def gen_msg(msgid): + return struct.pack("!H", len(msgbuff)) + struct.pack("!H", msgid) + noid_msgbuff + + for i in range(msgcount): + flood_buff += gen_msg(i) + return flood_buff + + +def test_query_flood_close(make_kresd_sock): + """Flood resolver with queries and close the connection.""" + buff = flood_buffer(10000) + sock1 = make_kresd_sock() + sock1.sendall(buff) + sock1.close() + + sock2 = make_kresd_sock() + utils.ping_alive(sock2) + + +def test_query_flood_no_recv(make_kresd_sock): + """Flood resolver with queries but don't read any data.""" + # A use-case for TCP_USER_TIMEOUT socket option? See RFC 793 and RFC 5482 + + # It seems it doesn't works as expected. libuv doesn't return any error + # (neither on uv_write() call, not in the callback) when kresd sends answers, + # so kresd can't recognize that client didn't read any answers. At a certain + # point, kresd stops receiving queries from the client (whilst client keep + # sending) and closes connection due to timeout. + + buff = flood_buffer(10000) + sock1 = make_kresd_sock() + end_time = time.time() + utils.MAX_TIMEOUT + + with utils.expect_kresd_close(rst_ok=True): # connection must be closed + while time.time() < end_time: + sock1.sendall(buff) + time.sleep(0.5) + + sock2 = make_kresd_sock() + utils.ping_alive(sock2) # resolver must stay alive + + +@pytest.mark.parametrize('glength, gcount, delay', [ + (65533, 100, 0.5), + (0, 100000, 0.5), + (1024, 1000, 0.5), + (65533, 1, 0), + (0, 1, 0), + (1024, 1, 0), +]) +def test_query_flood_garbage(make_kresd_silent_sock, glength, gcount, delay, query_before): + """Flood resolver with prefixed garbage.""" + sock1 = make_kresd_silent_sock() + if query_before: + utils.ping_alive(sock1) + + gbuff = utils.get_prefixed_garbage(glength) + buff = gbuff * gcount + + end_time = time.time() + utils.MAX_TIMEOUT + + with utils.expect_kresd_close(rst_ok=True): # connection must be closed + while time.time() < end_time: + sock1.sendall(buff) + time.sleep(delay) + + sock2 = make_kresd_silent_sock() + utils.ping_alive(sock2) # resolver must stay alive diff --git a/tests/pytests/test_prefix.py b/tests/pytests/test_prefix.py new file mode 100644 index 0000000..2b576b6 --- /dev/null +++ b/tests/pytests/test_prefix.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +"""TCP Connection Management tests - prefix length + +RFC1035 +4.2.2. TCP usage +The message is prefixed with a two byte length field which gives the message +length, excluding the two byte length field. + +The following test suite focuses on edge cases for the prefix - when it +is either too short or too long, instead of matching the length of DNS +message exactly. +""" + +import time + +import pytest + +import utils + + +@pytest.fixture(params=[ + 'no_query_before', + 'query_before', + 'query_before_in_single_buffer', +]) +def send_query(request): + """Function sends a buffer, either by itself, or with a valid query before. + If a valid query is sent before, it can be sent either in a separate buffer, or + along with the provided buffer.""" + + # pylint: disable=possibly-unused-variable + + def no_query_before(sock, buff): # pylint: disable=unused-argument + sock.sendall(buff) + + def query_before(sock, buff, single_buffer=False): + """Send an initial query and expect a response.""" + msg_buff, msgid = utils.get_msgbuff() + + if single_buffer: + sock.sendall(msg_buff + buff) + else: + sock.sendall(msg_buff) + sock.sendall(buff) + + answer = utils.receive_parse_answer(sock) + assert answer.id == msgid + + def query_before_in_single_buffer(sock, buff): + return query_before(sock, buff, single_buffer=True) + + return locals()[request.param] + + +@pytest.mark.parametrize('datalen', [ + 1, # just one byte of DNS header + 11, # DNS header size minus 1 + 14, # DNS Header size plus 2 +]) +def test_prefix_cuts_message(kresd_sock, datalen, send_query): + """Prefix is shorter than the DNS message.""" + wire, _ = utils.prepare_wire() + assert datalen < len(wire) + invalid_buff = utils.prepare_buffer(wire, datalen) + + send_query(kresd_sock, invalid_buff) # buffer breaks parsing of TCP stream + + with utils.expect_kresd_close(): + utils.ping_alive(kresd_sock) + + +def test_prefix_greater_than_message(kresd_sock, send_query): + """Prefix is greater than the length of the entire DNS message.""" + wire, invalid_msgid = utils.prepare_wire() + datalen = len(wire) + 16 + invalid_buff = utils.prepare_buffer(wire, datalen) + + send_query(kresd_sock, invalid_buff) + + valid_buff, _ = utils.get_msgbuff() + kresd_sock.sendall(valid_buff) + + # invalid_buff is answered (treats additional data as trailing garbage) + answer = utils.receive_parse_answer(kresd_sock) + assert answer.id == invalid_msgid + + # parsing stream is broken by the invalid_buff, valid query is never answered + with utils.expect_kresd_close(): + utils.receive_parse_answer(kresd_sock) + + +@pytest.mark.parametrize('glength', [ + 0, + 1, + 8, + 1024, + 4096, + 20000, +]) +def test_prefix_trailing_garbage(kresd_sock, glength, query_before): + """Send messages with trailing garbage (its length included in prefix).""" + if query_before: + utils.ping_alive(kresd_sock) + + for _ in range(10): + wire, msgid = utils.prepare_wire() + wire += utils.get_garbage(glength) + buff = utils.prepare_buffer(wire) + + kresd_sock.sendall(buff) + answer = utils.receive_parse_answer(kresd_sock) + assert answer.id == msgid + + time.sleep(0.1) diff --git a/tests/pytests/test_random_close.py b/tests/pytests/test_random_close.py new file mode 100644 index 0000000..a7cc877 --- /dev/null +++ b/tests/pytests/test_random_close.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +"""TLS test when forward target closes connection after one second + +Test utilizes TLS proxy, which forwards queries to configured +resolver, but closes the connection 1s after establishing. + +Kresd must stay alive and be able to answer queries. +""" + +import random +import string +import time + +from proxy import HINTS, kresd_tls_client, resolve_hint, TLSProxy +import utils + + +QPS = 500 + + +def random_string(size=32, chars=(string.ascii_lowercase + string.digits)): + return ''.join(random.choice(chars) for x in range(size)) + + +def rsa_cannon(sock, duration, domain='test.', qps=QPS): + end_time = time.time() + duration + + while time.time() < end_time: + next_time = time.time() + 1/qps + buff, _ = utils.get_msgbuff('{}.{}'.format(random_string(), domain)) + sock.sendall(buff) + time_left = next_time - time.time() + if time_left > 0: + time.sleep(time_left) + + +def test_proxy_random_close(tmpdir): + proxy = TLSProxy(close=1000) + + kresd_tls_client_kwargs = { + 'verbose': False, + 'policy_test_pass': True + } + kresd_fwd_target_kwargs = { + 'verbose': False + } + with kresd_tls_client(str(tmpdir), proxy, kresd_tls_client_kwargs, kresd_fwd_target_kwargs) \ + as kresd: + sock2 = kresd.ip_tcp_socket() + rsa_cannon(sock2, 20) + sock3 = kresd.ip_tcp_socket() + for hint in HINTS: + resolve_hint(sock3, hint) + time.sleep(0.1) diff --git a/tests/pytests/test_rehandshake.py b/tests/pytests/test_rehandshake.py new file mode 100644 index 0000000..f07ba58 --- /dev/null +++ b/tests/pytests/test_rehandshake.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +"""TLS rehandshake test + +Test is using TLS proxy with rehandshake. When queries are sent, they are +simply forwarded. When the responses are sent back, a rehandshake is performed +after every 8 bytes. + +It is expected the answer will be received by the source kresd instance +and sent back to the client (this test). +""" + +import re +import time + +import pytest + +from proxy import HINTS, kresd_tls_client, resolve_hint, TLSProxy + + +def verify_rehandshake(tmpdir, proxy): + with kresd_tls_client(str(tmpdir), proxy) as kresd: + sock2 = kresd.ip_tcp_socket() + try: + for hint in HINTS: + resolve_hint(sock2, hint) + time.sleep(0.1) + finally: + # verify log + n_connecting_to = 0 + n_rehandshake = 0 + partial_log = kresd.partial_log() + print(partial_log) + for line in partial_log.splitlines(): + if re.search(r"connecting to: .*", line) is not None: + n_connecting_to += 1 + elif re.search(r"TLS rehandshake .* has started", line) is not None: + n_rehandshake += 1 + assert n_connecting_to == 1 # should connect exactly once + assert n_rehandshake > 0 + + +def test_proxy_rehandshake_tls12(tmpdir): + proxy = TLSProxy(rehandshake=True) + verify_rehandshake(tmpdir, proxy) + + +# TODO fix TLS v1.3 proxy / kresd rehandshake +@pytest.mark.xfail( + reason="TLS 1.3 rehandshake isn't properly supported either in tlsproxy or in kresd") +def test_proxy_rehandshake_tls13(tmpdir): + proxy = TLSProxy(rehandshake=True, force_tls13=True) + verify_rehandshake(tmpdir, proxy) diff --git a/tests/pytests/test_tls.py b/tests/pytests/test_tls.py new file mode 100644 index 0000000..4f622c1 --- /dev/null +++ b/tests/pytests/test_tls.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +"""TLS-specific tests""" + +import itertools +import os +from socket import AF_INET, AF_INET6 +import ssl +import sys + +import pytest + +from kresd import make_kresd +import utils + + +def test_tls_no_cert(kresd, sock_family): + """Use TLS without certificates.""" + sock, dest = kresd.stream_socket(sock_family, tls=True) + ctx = utils.make_ssl_context(insecure=True) + ssock = ctx.wrap_socket(sock) + ssock.connect(dest) + + utils.ping_alive(ssock) + + +def test_tls_selfsigned_cert(kresd_tt, sock_family): + """Use TLS with a self signed certificate.""" + sock, dest = kresd_tt.stream_socket(sock_family, tls=True) + ctx = utils.make_ssl_context(verify_location=kresd_tt.tls_cert_path) + ssock = ctx.wrap_socket(sock, server_hostname='transport-test-server.com') + ssock.connect(dest) + + utils.ping_alive(ssock) + + +def test_tls_cert_hostname_mismatch(kresd_tt, sock_family): + """Attempt to use self signed certificate and incorrect hostname.""" + sock, dest = kresd_tt.stream_socket(sock_family, tls=True) + ctx = utils.make_ssl_context(verify_location=kresd_tt.tls_cert_path) + ssock = ctx.wrap_socket(sock, server_hostname='wrong-host-name') + + with pytest.raises(ssl.CertificateError): + ssock.connect(dest) + + +@pytest.mark.skipif(sys.version_info < (3, 6), + reason="requires python3.6 or higher") +@pytest.mark.parametrize('sf1, sf2, sf3', itertools.product( + [AF_INET, AF_INET6], [AF_INET, AF_INET6], [AF_INET, AF_INET6])) +def test_tls_session_resumption(tmpdir, sf1, sf2, sf3): + """Attempt TLS session resumption against the same kresd instance and a different one.""" + # TODO ensure that session can't be resumed after session ticket key regeneration + # at the first kresd instance + + # NOTE TLS 1.3 is intentionally disabled for session resumption tests, + # becuase python's SSLSocket.session isn't compatible with TLS 1.3 + # https://docs.python.org/3/library/ssl.html?highlight=ssl%20ticket#tls-1-3 + + def connect(kresd, ctx, sf, session=None): + sock, dest = kresd.stream_socket(sf, tls=True) + ssock = ctx.wrap_socket( + sock, server_hostname='transport-test-server.com', session=session) + ssock.connect(dest) + new_session = ssock.session + assert new_session.has_ticket + assert ssock.session_reused == (session is not None) + utils.ping_alive(ssock) + ssock.close() + return new_session + + workdir = os.path.join(str(tmpdir), 'kresd') + os.makedirs(workdir) + + with make_kresd(workdir, 'tt') as kresd: + ctx = utils.make_ssl_context( + verify_location=kresd.tls_cert_path, extra_options=[ssl.OP_NO_TLSv1_3]) + session = connect(kresd, ctx, sf1) # initial conn + connect(kresd, ctx, sf2, session) # resume session on the same instance + + workdir2 = os.path.join(str(tmpdir), 'kresd2') + os.makedirs(workdir2) + with make_kresd(workdir2, 'tt') as kresd2: + connect(kresd2, ctx, sf3, session) # resume session on a different instance diff --git a/tests/pytests/utils.py b/tests/pytests/utils.py new file mode 100644 index 0000000..812b6f3 --- /dev/null +++ b/tests/pytests/utils.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +from contextlib import contextmanager +import random +import ssl +import struct +import time + +import dns +import dns.message +import pytest + + +# default net.tcp_in_idle is 10s, TCP_DEFER_ACCEPT 3s, some extra for +# Python handling / edge cases +MAX_TIMEOUT = 16 + + +def receive_answer(sock): + answer_total_len = 0 + data = sock.recv(2) + if not data: + return None + answer_total_len = struct.unpack_from("!H", data)[0] + + answer_received_len = 0 + data_answer = b'' + while answer_received_len < answer_total_len: + data_chunk = sock.recv(answer_total_len - answer_received_len) + if not data_chunk: + return None + data_answer += data_chunk + answer_received_len += len(data_answer) + + return data_answer + + +def receive_parse_answer(sock): + data_answer = receive_answer(sock) + + if data_answer is None: + raise BrokenPipeError("kresd closed connection") + + msg_answer = dns.message.from_wire(data_answer, one_rr_per_rrset=True) + return msg_answer + + +def prepare_wire( + qname='localhost.', + qtype=dns.rdatatype.A, + qclass=dns.rdataclass.IN, + msgid=None): + """Utility function to generate DNS wire format message""" + msg = dns.message.make_query(qname, qtype, qclass) + if msgid is not None: + msg.id = msgid + return msg.to_wire(), msg.id + + +def prepare_buffer(wire, datalen=None): + """Utility function to prepare TCP buffer from DNS message in wire format""" + assert isinstance(wire, bytes) + if datalen is None: + datalen = len(wire) + return struct.pack("!H", datalen) + wire + + +def get_msgbuff(qname='localhost.', qtype=dns.rdatatype.A, msgid=None): + wire, msgid = prepare_wire(qname, qtype, msgid=msgid) + buff = prepare_buffer(wire) + return buff, msgid + + +def get_garbage(length): + return bytes(random.getrandbits(8) for _ in range(length)) + + +def get_prefixed_garbage(length): + data = get_garbage(length) + return prepare_buffer(data) + + +def try_ping_alive(sock, msgid=None, close=False): + try: + ping_alive(sock, msgid) + except AssertionError: + return False + finally: + if close: + sock.close() + return True + + +def ping_alive(sock, msgid=None): + buff, msgid = get_msgbuff(msgid=msgid) + sock.sendall(buff) + answer = receive_parse_answer(sock) + assert answer.id == msgid + + +@contextmanager +def expect_kresd_close(rst_ok=False): + with pytest.raises(BrokenPipeError): + try: + time.sleep(0.2) # give kresd time to close connection with TCP FIN + yield + except ConnectionResetError as ex: + if rst_ok: + raise BrokenPipeError from ex + pytest.skip("kresd closed connection with TCP RST") + pytest.fail("kresd didn't close the connection") + + +def make_ssl_context(insecure=False, verify_location=None, extra_options=None): + # set TLS v1.2+ + context = ssl.SSLContext(ssl.PROTOCOL_TLS) + context.options |= ssl.OP_NO_SSLv2 + context.options |= ssl.OP_NO_SSLv3 + context.options |= ssl.OP_NO_TLSv1 + context.options |= ssl.OP_NO_TLSv1_1 + + if extra_options is not None: + for option in extra_options: + context.options |= option + + if insecure: + # turn off certificate verification + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + else: + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = True + + if verify_location is not None: + context.load_verify_locations(verify_location) + + return context |