From 3d0386f27ca66379acf50199e1d1298386eeeeb8 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 6 May 2024 02:55:53 +0200 Subject: Adding upstream version 3.2.1. Signed-off-by: Daniel Baumann --- tests/pytests/README.rst | 54 ++ tests/pytests/certs/tt-certgen-expired.sh | 18 + tests/pytests/certs/tt-certgen.sh | 4 + tests/pytests/certs/tt-expired.cert.pem | 80 +++ tests/pytests/certs/tt-expired.key.pem | 27 + tests/pytests/certs/tt.cert.pem | 22 + tests/pytests/certs/tt.conf | 353 +++++++++++++ tests/pytests/certs/tt.key.pem | 28 + tests/pytests/conftest.py | 78 +++ tests/pytests/conn_flood.py | 83 +++ tests/pytests/kresd.py | 305 +++++++++++ tests/pytests/pylintrc | 29 + tests/pytests/rehandshake/Makefile | 28 + tests/pytests/rehandshake/array.h | 166 ++++++ tests/pytests/rehandshake/tcp-proxy.c | 336 ++++++++++++ tests/pytests/rehandshake/tcp-proxy.h | 12 + tests/pytests/rehandshake/tcproxy.c | 25 + tests/pytests/rehandshake/tls-proxy.c | 848 ++++++++++++++++++++++++++++++ tests/pytests/rehandshake/tls-proxy.h | 14 + tests/pytests/rehandshake/tlsproxy.c | 31 ++ tests/pytests/requirements.txt | 5 + tests/pytests/templates/kresd.conf.j2 | 42 ++ tests/pytests/test_conn_mgmt.py | 213 ++++++++ tests/pytests/test_prefix.py | 113 ++++ tests/pytests/test_rehandshake.py | 87 +++ tests/pytests/test_tls.py | 77 +++ tests/pytests/utils.py | 131 +++++ 27 files changed, 3209 insertions(+) create mode 100644 tests/pytests/README.rst create mode 100755 tests/pytests/certs/tt-certgen-expired.sh create mode 100755 tests/pytests/certs/tt-certgen.sh create mode 100644 tests/pytests/certs/tt-expired.cert.pem create mode 100644 tests/pytests/certs/tt-expired.key.pem create mode 100644 tests/pytests/certs/tt.cert.pem create mode 100644 tests/pytests/certs/tt.conf create mode 100644 tests/pytests/certs/tt.key.pem create mode 100644 tests/pytests/conftest.py create mode 100644 tests/pytests/conn_flood.py create mode 100644 tests/pytests/kresd.py create mode 100644 tests/pytests/pylintrc create mode 100644 tests/pytests/rehandshake/Makefile create mode 100644 tests/pytests/rehandshake/array.h create mode 100644 tests/pytests/rehandshake/tcp-proxy.c create mode 100644 tests/pytests/rehandshake/tcp-proxy.h create mode 100644 tests/pytests/rehandshake/tcproxy.c create mode 100644 tests/pytests/rehandshake/tls-proxy.c create mode 100644 tests/pytests/rehandshake/tls-proxy.h create mode 100644 tests/pytests/rehandshake/tlsproxy.c create mode 100644 tests/pytests/requirements.txt create mode 100644 tests/pytests/templates/kresd.conf.j2 create mode 100644 tests/pytests/test_conn_mgmt.py create mode 100644 tests/pytests/test_prefix.py create mode 100644 tests/pytests/test_rehandshake.py create mode 100644 tests/pytests/test_tls.py create mode 100644 tests/pytests/utils.py (limited to 'tests/pytests') diff --git a/tests/pytests/README.rst b/tests/pytests/README.rst new file mode 100644 index 0000000..9a11ccd --- /dev/null +++ b/tests/pytests/README.rst @@ -0,0 +1,54 @@ +Python client tests for kresd +============================= + +The tests run `/usr/bin/env kresd` (can be modified with `$PATH`) with custom config +and execute client-side testing, such as TCP / TLS connection management. + +Requirements +------------ + +- pip3 install -r requirements.txt + +Executing tests +--------------- + +Tests can be executed with the pytest framework. + +.. code-block:: bash + + $ pytest-3 # sequential, all tests (with exception of few special tests) + $ pytest-3 test_conn_mgmt.py::test_ignore_garbage # specific test only + $ pytest-3 --html pytests.html --self-contained-html # html report + +It's highly recommended to run these tests in parallel, since lot of them +wait for kresd timeout. This can be done with `python-xdist`: + +.. code-block:: bash + + $ pytest-3 -n 24 # parallel with 24 jobs + +Each test spawns an independent kresd instance, so test failures shouldn't affect +each other. + +Some tests are omitted from automatic test collection by default, due to their +resource contraints. These typicially have to be executed separately by providing +the path to test file directly. + +.. code-block:: bash + + $ pytest-3 conn_flood.py + +Note: some tests may fail without an internet connection. + +Developer notes +--------------- + +Typically, each test requires a setup of kresd, and a connected socket to run tests on. +The framework provides a few useful pytest fixtures to simplify this process: + +- `kresd_sock` provides a connected socket to a test-specific, running kresd instance. + It expands to 4 values (tests) - IPv4 TCP, IPv6 TCP, IPv4 TLS, IPv6 TLS sockets +- `make_kresd_sock` is similar to `kresd_sock`, except it's a factory function that + produces a new connected socket (of the same type) on each call +- `kresd`, `kresd_tt` are all Kresd instances, already running + and initialized with config (with no / valid TLS certificates) diff --git a/tests/pytests/certs/tt-certgen-expired.sh b/tests/pytests/certs/tt-certgen-expired.sh new file mode 100755 index 0000000..7750291 --- /dev/null +++ b/tests/pytests/certs/tt-certgen-expired.sh @@ -0,0 +1,18 @@ +# !/bin/bash + +if [ ! -d ./demoCA ]; then + mkdir ./demoCA +fi +if [ ! -d ./demoCA/newcerts ]; then + mkdir ./demoCA/newcerts +fi +touch ./demoCA/index.txt +touch ./demoCA/index.txt.attr +if [ ! -f ./demoCA/serial ]; then + echo 01 > ./demoCA/serial +fi + +openssl genrsa -out tt-expired.key.pem 2048 +openssl req -config tt.conf -new -key tt-expired.key.pem -out tt-expired.csr.pem +openssl ca -config tt.conf -selfsign -keyfile tt-expired.key.pem -out tt-expired.cert.pem -in tt-expired.csr.pem -startdate 19700101000000Z -enddate 19700101000000Z + diff --git a/tests/pytests/certs/tt-certgen.sh b/tests/pytests/certs/tt-certgen.sh new file mode 100755 index 0000000..b6b3d7f --- /dev/null +++ b/tests/pytests/certs/tt-certgen.sh @@ -0,0 +1,4 @@ +# !/bin/sh + +openssl req -config tt.conf -new -x509 -newkey rsa:2048 -nodes -keyout tt.key.pem -sha256 -out tt.cert.pem -days 20000 + diff --git a/tests/pytests/certs/tt-expired.cert.pem b/tests/pytests/certs/tt-expired.cert.pem new file mode 100644 index 0000000..c9f8c09 --- /dev/null +++ b/tests/pytests/certs/tt-expired.cert.pem @@ -0,0 +1,80 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: 1 (0x1) + Signature Algorithm: sha256WithRSAEncryption + Issuer: C=CZ, ST=PRAGUE, CN=transport-test-server.com + Validity + Not Before: Jan 1 00:00:00 1970 GMT + Not After : Jan 1 00:00:00 1970 GMT + Subject: C=CZ, ST=PRAGUE, CN=transport-test-server.com + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (2048 bit) + Modulus: + 00:bf:6b:1a:11:47:01:ac:eb:5c:2d:cf:ce:6a:a4: + 00:ce:2f:d1:25:03:5f:06:38:02:92:24:18:92:2a: + 69:19:b2:2b:a3:4f:f7:79:de:35:c3:f5:72:37:83: + 44:93:f9:76:fc:89:29:32:9c:0d:4b:95:7d:d1:5d: + 40:e9:ba:49:50:7d:c6:0a:c8:1e:e7:90:1e:37:7c: + 0b:23:a3:e3:bc:c9:53:81:de:d6:5f:cb:b2:3d:36: + ac:59:b0:33:91:8f:0c:5f:10:20:70:bf:a3:22:b3: + 98:ac:d4:7a:ea:67:b8:b1:8c:cf:e5:fe:8f:a0:a5: + 02:ad:6d:ce:f1:62:ab:dc:5d:96:9c:4f:95:47:d5: + 82:b7:b3:e3:87:4c:8d:38:85:2a:24:9d:7f:c7:a4: + 0e:bd:8a:2d:6b:d2:d4:e8:78:62:1b:aa:25:5f:5a: + 64:e5:76:23:ae:11:03:9a:5c:ed:a2:ba:51:ec:b1: + f3:ae:ba:5c:eb:dd:49:63:ca:c7:af:0c:16:1d:94: + 95:3a:ce:2c:8f:e2:94:7f:1f:a1:76:e2:9f:d1:41: + 31:f0:68:e5:ae:df:d0:75:a0:34:f5:25:93:85:b3: + 25:50:42:6c:00:c0:fe:3b:e0:fb:00:de:75:33:86: + 6a:21:35:14:9d:7f:4a:af:f7:15:f2:d7:bb:2f:de: + df:ab + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: + CA:FALSE + Netscape Comment: + OpenSSL Generated Certificate + X509v3 Subject Key Identifier: + B3:42:0A:9A:00:19:CB:CB:24:A0:02:45:1E:8A:B0:54:CB:9F:55:FE + X509v3 Authority Key Identifier: + keyid:B3:42:0A:9A:00:19:CB:CB:24:A0:02:45:1E:8A:B0:54:CB:9F:55:FE + + Signature Algorithm: sha256WithRSAEncryption + 32:9a:05:e3:6f:ae:ee:b1:a2:12:0a:9f:0a:e7:78:26:df:90: + fb:84:60:ae:13:fc:ff:fd:42:84:23:14:c3:2e:e2:a9:df:4b: + 5c:2f:5b:0e:3d:f9:5a:56:50:13:bc:89:1a:08:70:dd:6c:6c: + e8:ae:cf:22:39:92:f2:3b:40:03:8f:4e:bc:54:88:6b:fd:8c: + b6:eb:30:90:21:db:fc:4e:5c:7e:12:75:e2:52:76:df:19:0f: + 30:49:1e:15:bc:ba:6a:e6:f7:af:93:ad:e4:36:da:47:47:a6: + 88:b0:ae:46:1e:91:e1:d6:b1:5e:a4:f0:68:02:81:57:86:5d: + 17:d1:6c:7e:7a:9f:5e:0d:fc:10:e7:7a:1a:b5:f9:4b:1d:78: + a4:9a:9d:d7:c2:64:c3:52:28:7f:a1:b7:25:d7:13:3f:09:7f: + f2:fd:dd:c6:91:eb:9b:51:80:e2:36:cb:9f:5b:4e:47:eb:77: + d3:cc:8b:18:b5:0b:97:a2:53:8e:fb:9b:94:7d:57:21:32:c6: + f3:67:93:a4:9b:eb:46:b7:cd:08:43:99:dd:c1:c3:51:b9:19: + ef:92:77:1c:84:67:80:67:95:ba:00:75:3d:7b:8b:ff:24:30: + f1:fa:6d:da:31:9d:cf:06:da:5d:04:07:14:45:8c:6b:e7:21: + 31:ec:7b:23 +-----BEGIN CERTIFICATE----- +MIIDfjCCAmagAwIBAgIBATANBgkqhkiG9w0BAQsFADBCMQswCQYDVQQGEwJDWjEP +MA0GA1UECAwGUFJBR1VFMSIwIAYDVQQDDBl0cmFuc3BvcnQtdGVzdC1zZXJ2ZXIu +Y29tMCIYDzE5NzAwMTAxMDAwMDAwWhgPMTk3MDAxMDEwMDAwMDBaMEIxCzAJBgNV +BAYTAkNaMQ8wDQYDVQQIDAZQUkFHVUUxIjAgBgNVBAMMGXRyYW5zcG9ydC10ZXN0 +LXNlcnZlci5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC/axoR +RwGs61wtz85qpADOL9ElA18GOAKSJBiSKmkZsiujT/d53jXD9XI3g0ST+Xb8iSky +nA1LlX3RXUDpuklQfcYKyB7nkB43fAsjo+O8yVOB3tZfy7I9NqxZsDORjwxfECBw +v6Mis5is1HrqZ7ixjM/l/o+gpQKtbc7xYqvcXZacT5VH1YK3s+OHTI04hSoknX/H +pA69ii1r0tToeGIbqiVfWmTldiOuEQOaXO2iulHssfOuulzr3UljysevDBYdlJU6 +ziyP4pR/H6F24p/RQTHwaOWu39B1oDT1JZOFsyVQQmwAwP474PsA3nUzhmohNRSd +f0qv9xXy17sv3t+rAgMBAAGjezB5MAkGA1UdEwQCMAAwLAYJYIZIAYb4QgENBB8W +HU9wZW5TU0wgR2VuZXJhdGVkIENlcnRpZmljYXRlMB0GA1UdDgQWBBSzQgqaABnL +yySgAkUeirBUy59V/jAfBgNVHSMEGDAWgBSzQgqaABnLyySgAkUeirBUy59V/jAN +BgkqhkiG9w0BAQsFAAOCAQEAMpoF42+u7rGiEgqfCud4Jt+Q+4RgrhP8//1ChCMU +wy7iqd9LXC9bDj35WlZQE7yJGghw3Wxs6K7PIjmS8jtAA49OvFSIa/2MtuswkCHb +/E5cfhJ14lJ23xkPMEkeFby6aub3r5Ot5DbaR0emiLCuRh6R4daxXqTwaAKBV4Zd +F9FsfnqfXg38EOd6GrX5Sx14pJqd18Jkw1Iof6G3JdcTPwl/8v3dxpHrm1GA4jbL +n1tOR+t308yLGLULl6JTjvublH1XITLG82eTpJvrRrfNCEOZ3cHDUbkZ75J3HIRn +gGeVugB1PXuL/yQw8fpt2jGdzwbaXQQHFEWMa+chMex7Iw== +-----END CERTIFICATE----- diff --git a/tests/pytests/certs/tt-expired.key.pem b/tests/pytests/certs/tt-expired.key.pem new file mode 100644 index 0000000..ca2988c --- /dev/null +++ b/tests/pytests/certs/tt-expired.key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAv2saEUcBrOtcLc/OaqQAzi/RJQNfBjgCkiQYkippGbIro0/3 +ed41w/VyN4NEk/l2/IkpMpwNS5V90V1A6bpJUH3GCsge55AeN3wLI6PjvMlTgd7W +X8uyPTasWbAzkY8MXxAgcL+jIrOYrNR66me4sYzP5f6PoKUCrW3O8WKr3F2WnE+V +R9WCt7Pjh0yNOIUqJJ1/x6QOvYota9LU6HhiG6olX1pk5XYjrhEDmlztorpR7LHz +rrpc691JY8rHrwwWHZSVOs4sj+KUfx+hduKf0UEx8Gjlrt/QdaA09SWThbMlUEJs +AMD+O+D7AN51M4ZqITUUnX9Kr/cV8te7L97fqwIDAQABAoIBAEA4ytIpJKLDhHXK +VtLom2ySFnV4oBUSDarCeYvwtrpsUL/GQJ2etCM+4kdFv2h2NjmcOzpDqSJG0aPA +ydqhKZ/b0uojIltGuxyafZJDllDsqxvTi9EwImjvQvwEZgjcGaZ7Xqb1ZOJrpzm1 +QFgM3KaVO9tKgR3Avxk40kmidU7FctFi5IELwnH/RR1OHvJbxOE4+i0LlDx0QzhX +QHtnvHLqLLdqsFk8KvuVuVj1FwqJ6cSL0JrAdt7dnGmXBo4PDqT8Hj0AjM+CcNrV +1D6Li9xr4y55EZUK2qU/FVDC3LqlYQy5mBfasJAXPQG4RgSVFxJ929HC7gi8vMCO +UMeLniECgYEA6gBoRwzQ5pJUXfZGW41lJt08utfycGZm7VrA81r0x0F+DcuZ2t6J +kB9Wnp/MNpB4DJLbl7oM2OlFOO3cw0n3VaFpNMPHVHzNbyi1hp94AIIeDz/sxfUI +Lx7ynAQSPPQzDRfVJesT8waBdweA71TBOlrFQ2Cp7O4Qf+p0akQSv3cCgYEA0Wnd +1Gbierv2m6Jnblg+brTMQwbRsOAM2n0V4Gd2kRaLSYd23ebshvx8xTWipRlrb5vP +UEh+LkfuscqaJDCrikasht9z5FJtfIzHKgTrLSoR3MJRjrnuLJWTQUwSqzd0UNN6 +HigV6p+CqesNnELErak53IMfmkHAhTSkII8R9m0CgYBRY+DhTaDfgegcYouoTm7v +bKYx6uillciZKCbSvkFDiREaJUYXba31ViEfvT8ff3JyFSaSCKFtVP3BxmIx/ukr +fKAGPU54oYwm7Mbu00q/CoMAFOD7HbZCBYanI3dggiO7mx2FOdXPguTHDPIYzKcE +8AuK2vVftpJAm8DwMUtAEwKBgH/eRc5ZGDdbKGS10LQm+9A7Y3IV6to2pIKQ2FfS +tSo4espmBeXPCGQQLdt5OZvYHqril77s1OdLkutKy74HXecr6lLchHZJAoOHrmDw +6e0FAC0tFgGxdEYS+vxnCAs17DciOjHJxkAiL/WzCfd9KXzklOkZw6U8OuLbVtBu +q8gtAoGAbl03XZm+SHrO7XjHK/Fe5YD13cOirg48htvjbpqEDZNQr3l0eVnEj074 +IopDa/wUFlaaqPZ/DVFctqSocyskWIP4u9HfmsNBHjK5zQlge7B1fNVao++YKund +qnVnXjWQuF2aL8k2geFxdSKmHTF4/N1qEyeyR+tMaFpGfMZOuM8= +-----END RSA PRIVATE KEY----- diff --git a/tests/pytests/certs/tt.cert.pem b/tests/pytests/certs/tt.cert.pem new file mode 100644 index 0000000..2ea4898 --- /dev/null +++ b/tests/pytests/certs/tt.cert.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDkTCCAnmgAwIBAgIJAP/KybHquuyUMA0GCSqGSIb3DQEBCwUAMEIxCzAJBgNV +BAYTAkNaMQ8wDQYDVQQIDAZQUkFHVUUxIjAgBgNVBAMMGXRyYW5zcG9ydC10ZXN0 +LXNlcnZlci5jb20wIBcNMTgwNzE4MTAwODIzWhgPMjA3MzA0MjAxMDA4MjNaMEIx +CzAJBgNVBAYTAkNaMQ8wDQYDVQQIDAZQUkFHVUUxIjAgBgNVBAMMGXRyYW5zcG9y +dC10ZXN0LXNlcnZlci5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB +AQDRAQDAX6+lFKurvm7fgQqm8WyYzT/wxfPJjsVQGe87OlH1KFzVfzYzgEt0RMlM +eZgipREBZB2zK+WFM5RBHWYAwlI5PKt7EAGn8q1Zm4z+M9Uom3/Hy3bZ9q+AJwjk +odpHYuFyWJqHIQBqaQ3SFyJwdZ/GsuzEUfWuIl74oyyMAeykTKFGdaVuIlLC3fKm +8UCnfk99i/LEXUwRcmOV0uaG7deN5ITDDCFdb615yVjLkMhGY/jHK7uuxATOopEk +4vThQ1aQjSkHwluaqFUW6Zl4QF8WOAufoWQPFZ8XxmUYEIG/sMvLv6dol7ltjEbC +bfyzlS+9Qbnq6MfhTZF/4jAPAgMBAAGjgYcwgYQwHQYDVR0OBBYEFBNiUgCiKw4b +CFNKaEkqhkNSer7wMB8GA1UdIwQYMBaAFBNiUgCiKw4bCFNKaEkqhkNSer7wMA8G +A1UdEwEB/wQFMAMBAf8wJAYDVR0RBB0wG4IZdHJhbnNwb3J0LXRlc3Qtc2VydmVy +LmNvbTALBgNVHQ8EBAMCAaYwDQYJKoZIhvcNAQELBQADggEBACz1ZQ8XkGobhTcA +hkSTSw0ko6qwVuJJD5ue3SUcWLATsskohTJmN6bde3IMDRyQvLJAlMdG2p1qMbtA +OTbnQJTT7oDLaW8w2D+eO5oWTJvxLpl6TxbIfJN/8ITB1ltOCxTU9cVNbd2eh8sj +l3R4etg9djYRrqtNxCQZOYSwvhHw2MefnwjGVuJEu6JYOn3IE8Jqsh/LI59C87nE +MetZrXlzC6kSAFfRYgQET9RhBobMU9yFR8zGVHDFoxqNQs2lYKPz/3rFPetL2rjT +cFwzxkxDdwn+RNisBc1LMfIg7pvSMFR6sAnpjeRHN0Uoem1jj2qtzjbFENDuyQ4/ +HSi4UcE= +-----END CERTIFICATE----- diff --git a/tests/pytests/certs/tt.conf b/tests/pytests/certs/tt.conf new file mode 100644 index 0000000..5ac7737 --- /dev/null +++ b/tests/pytests/certs/tt.conf @@ -0,0 +1,353 @@ +# +# OpenSSL example configuration file. +# This is mostly being used for generation of certificate requests. +# + +# This definition stops the following lines choking if HOME isn't +# defined. +HOME = . +RANDFILE = $ENV::HOME/.rnd + +# Extra OBJECT IDENTIFIER info: +#oid_file = $ENV::HOME/.oid +oid_section = new_oids + +# To use this configuration file with the "-extfile" option of the +# "openssl x509" utility, name here the section containing the +# X.509v3 extensions to use: +# extensions = +# (Alternatively, use a configuration file that has only +# X.509v3 extensions in its main [= default] section.) + +[ new_oids ] + +# We can add new OIDs in here for use by 'ca', 'req' and 'ts'. +# Add a simple OID like this: +# testoid1=1.2.3.4 +# Or use config file substitution like this: +# testoid2=${testoid1}.5.6 + +# Policies used by the TSA examples. +tsa_policy1 = 1.2.3.4.1 +tsa_policy2 = 1.2.3.4.5.6 +tsa_policy3 = 1.2.3.4.5.7 + +#################################################################### +[ ca ] +default_ca = CA_default # The default ca section + +#################################################################### +[ CA_default ] + +dir = ./demoCA # Where everything is kept +certs = $dir/certs # Where the issued certs are kept +crl_dir = $dir/crl # Where the issued crl are kept +database = $dir/index.txt # database index file. +#unique_subject = no # Set to 'no' to allow creation of + # several certs with same subject. +new_certs_dir = $dir/newcerts # default place for new certs. + +certificate = $dir/cacert.pem # The CA certificate +serial = $dir/serial # The current serial number +crlnumber = $dir/crlnumber # the current crl number + # must be commented out to leave a V1 CRL +crl = $dir/crl.pem # The current CRL +private_key = $dir/private/cakey.pem# The private key +RANDFILE = $dir/private/.rand # private random number file + +x509_extensions = usr_cert # The extensions to add to the cert + +# Comment out the following two lines for the "traditional" +# (and highly broken) format. +name_opt = ca_default # Subject Name options +cert_opt = ca_default # Certificate field options + +# Extension copying option: use with caution. +copy_extensions = copy + +# Extensions to add to a CRL. Note: Netscape communicator chokes on V2 CRLs +# so this is commented out by default to leave a V1 CRL. +# crlnumber must also be commented out to leave a V1 CRL. +# crl_extensions = crl_ext + +default_days = 365 # how long to certify for +default_crl_days= 30 # how long before next CRL +default_md = default # use public key default MD +preserve = no # keep passed DN ordering + +# A few difference way of specifying how similar the request should look +# For type CA, the listed attributes must be the same, and the optional +# and supplied fields are just that :-) +policy = policy_match + +# For the CA policy +[ policy_match ] +countryName = optional +stateOrProvinceName = optional +organizationName = optional +organizationalUnitName = optional +commonName = supplied +emailAddress = optional + +# For the 'anything' policy +# At this point in time, you must list all acceptable 'object' +# types. +[ policy_anything ] +countryName = optional +stateOrProvinceName = optional +localityName = optional +organizationName = optional +organizationalUnitName = optional +commonName = supplied +emailAddress = optional + +#################################################################### +[ req ] +default_bits = 2048 +default_keyfile = privkey.pem +distinguished_name = req_distinguished_name +attributes = req_attributes +x509_extensions = v3_ca # The extensions to add to the self signed cert + +# Passwords for private keys if not present they will be prompted for +# input_password = secret +# output_password = secret + +# This sets a mask for permitted string types. There are several options. +# default: PrintableString, T61String, BMPString. +# pkix : PrintableString, BMPString (PKIX recommendation before 2004) +# utf8only: only UTF8Strings (PKIX recommendation after 2004). +# nombstr : PrintableString, T61String (no BMPStrings or UTF8Strings). +# MASK:XXXX a literal mask value. +# WARNING: ancient versions of Netscape crash on BMPStrings or UTF8Strings. +string_mask = utf8only + +# req_extensions = v3_req # The extensions to add to a certificate request + +[ req_distinguished_name ] +countryName = Country Name (2 letter code) +countryName_default = CZ +countryName_min = 2 +countryName_max = 2 + +stateOrProvinceName = State or Province Name (full name) +stateOrProvinceName_default = PRAGUE + +localityName = Locality Name (eg, city) + +0.organizationName = Organization Name (eg, company) +0.organizationName_default = + +# we can do this but it is not needed normally :-) +#1.organizationName = Second Organization Name (eg, company) +#1.organizationName_default = World Wide Web Pty Ltd + +organizationalUnitName = Organizational Unit Name (eg, section) +#organizationalUnitName_default = + +commonName = Common Name (e.g. server FQDN or YOUR name) +commonName_max = 64 +commonName_default = transport-test-server.com + +emailAddress = Email Address +emailAddress_max = 64 + +# SET-ex3 = SET extension number 3 + +[ req_attributes ] +challengePassword = A challenge password +challengePassword_min = 4 +challengePassword_max = 20 + +unstructuredName = An optional company name + +[ usr_cert ] + +# These extensions are added when 'ca' signs a request. + +# This goes against PKIX guidelines but some CAs do it and some software +# requires this to avoid interpreting an end user certificate as a CA. + +basicConstraints=CA:FALSE + +# Here are some examples of the usage of nsCertType. If it is omitted +# the certificate can be used for anything *except* object signing. + +# This is OK for an SSL server. +# nsCertType = server + +# For an object signing certificate this would be used. +# nsCertType = objsign + +# For normal client use this is typical +# nsCertType = client, email + +# and for everything including object signing: +# nsCertType = client, email, objsign + +# This is typical in keyUsage for a client certificate. +# keyUsage = nonRepudiation, digitalSignature, keyEncipherment + +# This will be displayed in Netscape's comment listbox. +nsComment = "OpenSSL Generated Certificate" + +# PKIX recommendations harmless if included in all certificates. +subjectKeyIdentifier=hash +authorityKeyIdentifier=keyid,issuer + +# This stuff is for subjectAltName and issuerAltname. +# Import the email address. +# subjectAltName=email:copy +# An alternative to produce certificates that aren't +# deprecated according to PKIX. +# subjectAltName=email:move + +# Copy subject details +# issuerAltName=issuer:copy + +#nsCaRevocationUrl = http://www.domain.dom/ca-crl.pem +#nsBaseUrl +#nsRevocationUrl +#nsRenewalUrl +#nsCaPolicyUrl +#nsSslServerName + +# This is required for TSA certificates. +# extendedKeyUsage = critical,timeStamping + +[ v3_req ] + +# Extensions to add to a certificate request + +basicConstraints = CA:FALSE +keyUsage = nonRepudiation, digitalSignature, keyEncipherment + +[ v3_ca ] + + +# Extensions for a typical CA + + +# PKIX recommendation. + +subjectKeyIdentifier=hash + +authorityKeyIdentifier=keyid:always,issuer + +basicConstraints = critical,CA:true + +subjectAltName = @alternate_names + +# Key usage: this is typical for a CA certificate. However since it will +# prevent it being used as an test self-signed certificate it is best +# left out by default. +keyUsage = digitalSignature, keyEncipherment, cRLSign, keyCertSign + +# Some might want this also +# nsCertType = sslCA, emailCA + +# Include email address in subject alt name: another PKIX recommendation +# subjectAltName=email:copy +# Copy issuer details +# issuerAltName=issuer:copy + +# DER hex encoding of an extension: beware experts only! +# obj=DER:02:03 +# Where 'obj' is a standard or added object +# You can even override a supported extension: +# basicConstraints= critical, DER:30:03:01:01:FF + +[ crl_ext ] + +# CRL extensions. +# Only issuerAltName and authorityKeyIdentifier make any sense in a CRL. + +# issuerAltName=issuer:copy +authorityKeyIdentifier=keyid:always + +[ proxy_cert_ext ] +# These extensions should be added when creating a proxy certificate + +# This goes against PKIX guidelines but some CAs do it and some software +# requires this to avoid interpreting an end user certificate as a CA. + +basicConstraints=CA:FALSE + +# Here are some examples of the usage of nsCertType. If it is omitted +# the certificate can be used for anything *except* object signing. + +# This is OK for an SSL server. +# nsCertType = server + +# For an object signing certificate this would be used. +# nsCertType = objsign + +# For normal client use this is typical +# nsCertType = client, email + +# and for everything including object signing: +# nsCertType = client, email, objsign + +# This is typical in keyUsage for a client certificate. +# keyUsage = nonRepudiation, digitalSignature, keyEncipherment + +# This will be displayed in Netscape's comment listbox. +nsComment = "OpenSSL Generated Certificate" + +# PKIX recommendations harmless if included in all certificates. +subjectKeyIdentifier=hash +authorityKeyIdentifier=keyid,issuer + +# This stuff is for subjectAltName and issuerAltname. +# Import the email address. +# subjectAltName=email:copy +# An alternative to produce certificates that aren't +# deprecated according to PKIX. +# subjectAltName=email:move + +# Copy subject details +# issuerAltName=issuer:copy + +#nsCaRevocationUrl = http://www.domain.dom/ca-crl.pem +#nsBaseUrl +#nsRevocationUrl +#nsRenewalUrl +#nsCaPolicyUrl +#nsSslServerName + +# This really needs to be in place for it to be a proxy certificate. +proxyCertInfo=critical,language:id-ppl-anyLanguage,pathlen:3,policy:foo + +#################################################################### +[ tsa ] + +default_tsa = tsa_config1 # the default TSA section + +[ tsa_config1 ] + +# These are used by the TSA reply generation only. +dir = ./demoCA # TSA root directory +serial = $dir/tsaserial # The current serial number (mandatory) +crypto_device = builtin # OpenSSL engine to use for signing +signer_cert = $dir/tsacert.pem # The TSA signing certificate + # (optional) +certs = $dir/cacert.pem # Certificate chain to include in reply + # (optional) +signer_key = $dir/private/tsakey.pem # The TSA private key (optional) +signer_digest = sha256 # Signing digest to use. (Optional) +default_policy = tsa_policy1 # Policy if request did not specify it + # (optional) +other_policies = tsa_policy2, tsa_policy3 # acceptable policies (optional) +digests = sha1, sha256, sha384, sha512 # Acceptable message digests (mandatory) +accuracy = secs:1, millisecs:500, microsecs:100 # (optional) +clock_precision_digits = 0 # number of digits after dot. (optional) +ordering = yes # Is ordering defined for timestamps? + # (optional, default: no) +tsa_name = yes # Must the TSA name be included in the reply? + # (optional, default: no) +ess_cert_id_chain = no # Must the ESS cert id chain be included? + # (optional, default: no) + +[ alternate_names ] + +DNS.1 = transport-test-server.com diff --git a/tests/pytests/certs/tt.key.pem b/tests/pytests/certs/tt.key.pem new file mode 100644 index 0000000..1974be7 --- /dev/null +++ b/tests/pytests/certs/tt.key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDRAQDAX6+lFKur +vm7fgQqm8WyYzT/wxfPJjsVQGe87OlH1KFzVfzYzgEt0RMlMeZgipREBZB2zK+WF +M5RBHWYAwlI5PKt7EAGn8q1Zm4z+M9Uom3/Hy3bZ9q+AJwjkodpHYuFyWJqHIQBq +aQ3SFyJwdZ/GsuzEUfWuIl74oyyMAeykTKFGdaVuIlLC3fKm8UCnfk99i/LEXUwR +cmOV0uaG7deN5ITDDCFdb615yVjLkMhGY/jHK7uuxATOopEk4vThQ1aQjSkHwlua +qFUW6Zl4QF8WOAufoWQPFZ8XxmUYEIG/sMvLv6dol7ltjEbCbfyzlS+9Qbnq6Mfh +TZF/4jAPAgMBAAECggEBALSs10d18FMW0WjAUPxpgxnaLnTRSesMVLjy8ONT6Bkd +S2hRIh91vxc6WwABzrqLita4N0EqmPoggmNpuUmo7lrNoWLVbbAOoD/da7nA3FuL +10MpWYcP/ohh1klEdU2gFSAM/LNqoPsbrk5OzqHFWgI5zItqdX8pEucb01nBRWsp +VMY2vzVuFB2jweZQ5+LCpfSMcRIzlxQa9CG4Peu6YW1Z4b3aUcS63/829JN/ZOGd +uoRqR+gP71yNIt6i7wA5cot5FRmzlFEGhb1XzBOB1FFHOiknOZzbBtDsGUUmVtfA +6mXcTumhdHbC0bXnHei/s2s9X0EeyQFYPkoS4NUQ2dECgYEA/7lhgn89K8rpUPnS +eccTpKVPWp8luQei98Hi/F94kwP32l7Zl7Bmu2nltUoB1GBRXoXY6KzTphmT6ioA +8joLCKIii5/nOdZAdHbIN2tkXS56h524q5I2jKogjfRrpCaAJE8x99f8L9uTBfZb +/7BBQDHai1/S6LcpIRf/4g1/xBsCgYEA0Tq4V5hR9mGDUFir1FDGhA3ijDkIE/sO +3QGTU7W90BL27te98FuQtWOPqfd1fi26WypNpNQUZb3V4x5tmDcpWscfj6I10432 +4zECPlDgaevucJjj245U7WjUhdAvlRy6K8H/8MgRBAjw9h8dwIGIx9gmOqKdA+/h +ve3xyjKQex0CgYAz0XzQ1LewiA1/OyBLTOvOETFjS5x5QfLkAYXdXfswzz0KIu40 +rqoij/LcKYL1Zg8W+Ehb3amFnuk6KgjHDLvvo+scH+ra7W9iKi+oCzrrJt/tWyhw +m9Ax8Mdn/H9TY/nTYbjeYAXaLMQ+EQ3TYgPW3kNKusAiJ/tNmW9gfxvEwQKBgGSJ +Rbj5fTDZjGKYKQDdS3Z6wYhFg0culObHcgaARtPruPHtgtwy82blj0vJl5Bo4qoZ +urNgIOj+ff8jSOAiaWGwWs8Gz7x289IZY42UCTF8Z9d878g5LT/i5nPiJGsPIboS +/yuwxtRcg4SQURiGZbY5e60jJDWXF67O3icdguVVAoGARXLufXvZ/9Xf1DmFFxjq +PJMCa1sfofqjB4KqYbt17vFtTsddCiyqsbpx36oY6nIdm9yUiGo10YaSEJtDEGLS +L3TPZ4s8M8dcjOfj8Kk75pKbJ7NY4qA64dtbxcZbrFp3/mGZkDing94y+Zc/aFqa +xQsA/yhmYV9r+FHDL54Cn6I= +-----END PRIVATE KEY----- diff --git a/tests/pytests/conftest.py b/tests/pytests/conftest.py new file mode 100644 index 0000000..e10ba3a --- /dev/null +++ b/tests/pytests/conftest.py @@ -0,0 +1,78 @@ +import socket + +import pytest + +from kresd import init_portdir, make_kresd + + +@pytest.fixture +def kresd(tmpdir): + with make_kresd(tmpdir) as kresd: + yield kresd + + +@pytest.fixture +def kresd_tt(tmpdir): + with make_kresd(tmpdir, 'tt') as kresd: + yield kresd + + +@pytest.fixture(params=[ + 'ip_tcp_socket', + 'ip6_tcp_socket', + 'ip_tls_socket', + 'ip6_tls_socket', +]) +def make_kresd_sock(request, kresd): + """Factory function to create sockets of the same kind.""" + sock_func = getattr(kresd, request.param) + + def _make_kresd_sock(): + return sock_func() + + return _make_kresd_sock + + +@pytest.fixture +def kresd_sock(make_kresd_sock): + return make_kresd_sock() + + +@pytest.fixture(params=[ + socket.AF_INET, + socket.AF_INET6, +]) +def sock_family(request): + return request.param + + +@pytest.fixture(params=[ + True, + False +]) +def single_buffer(request): # whether to send all data in a single buffer + return request.param + + +@pytest.fixture(params=[ + True, + False +]) +def query_before(request): # whether to send an initial query + return request.param + + +@pytest.mark.optionalhook +def pytest_metadata(metadata): # filter potentially sensitive data from GitLab CI + keys_to_delete = [] + for key in metadata.keys(): + key_lower = key.lower() + if 'password' in key_lower or 'token' in key_lower or \ + key_lower.startswith('ci') or key_lower.startswith('gitlab'): + keys_to_delete.append(key) + for key in keys_to_delete: + del metadata[key] + + +def pytest_sessionstart(session): # pylint: disable=unused-argument + init_portdir() diff --git a/tests/pytests/conn_flood.py b/tests/pytests/conn_flood.py new file mode 100644 index 0000000..8c2625a --- /dev/null +++ b/tests/pytests/conn_flood.py @@ -0,0 +1,83 @@ +"""Test opening as many connections as possible. + +Due to resource-intensity of this test, it's filename doesn't contain +"test" on purpose, so it doesn't automatically get picked up by pytest +(to allow easy parallel testing). + +To execute this test, pass the filename of this file to pytest directly. +Also, make sure not to use parallel execution (-n). +""" + +import resource +import time + +import pytest + +from kresd import Kresd +import utils + + +MAX_SOCKETS = 10000 # upper bound of how many connections to open +MAX_ITERATIONS = 10 # number of iterations to run the test + +# we can't use softlimit ifself since kresd already has open sockets, +# so use lesser value +RESERVED_NOFILE = 40 # 40 is empirical value + + +@pytest.mark.parametrize('sock_func_name', [ + 'ip_tcp_socket', + 'ip6_tcp_socket', + 'ip_tls_socket', + 'ip6_tls_socket', +]) +def test_conn_flood(tmpdir, sock_func_name): + def create_sockets(make_sock, nsockets): + sockets = [] + next_ping = time.time() + 4 # less than tcp idle timeout / 2 + while True: + additional_sockets = 0 + while time.time() < next_ping: + nsock_to_init = min(100, nsockets - len(sockets)) + if not nsock_to_init: + return sockets + sockets.extend([make_sock() for _ in range(nsock_to_init)]) + additional_sockets += nsock_to_init + + # large number of connections can take a lot of time to open + # send some valid data to avoid TCP idle timeout for already open sockets + next_ping = time.time() + 4 + for s in sockets: + utils.ping_alive(s) + + # break when no more than 20% additional sockets are created + if additional_sockets / len(sockets) < 0.2: + return sockets + + max_num_of_open_files = resource.getrlimit(resource.RLIMIT_NOFILE)[0] - RESERVED_NOFILE + nsockets = min(max_num_of_open_files, MAX_SOCKETS) + + # create kresd instance with verbose=False + ip = '127.0.0.1' + ip6 = '::1' + with Kresd(tmpdir, ip=ip, ip6=ip6, verbose=False) as kresd: + make_sock = getattr(kresd, sock_func_name) # function for creating sockets + sockets = create_sockets(make_sock, nsockets) + print("\nEstablished {} connections".format(len(sockets))) + + print("Start sending data") + for i in range(MAX_ITERATIONS): + for s in sockets: + utils.ping_alive(s) + print("Iteration {} done...".format(i)) + + print("Close connections") + for s in sockets: + s.close() + + # check in kresd is alive + print("Check upstream is still alive") + sock = make_sock() + utils.ping_alive(sock) + + print("OK!") diff --git a/tests/pytests/kresd.py b/tests/pytests/kresd.py new file mode 100644 index 0000000..4b122f6 --- /dev/null +++ b/tests/pytests/kresd.py @@ -0,0 +1,305 @@ +from collections import namedtuple +from contextlib import ContextDecorator, contextmanager +import os +from pathlib import Path +import random +import re +import shutil +import socket +import subprocess +import time + +import jinja2 + +import utils + + +PYTESTS_DIR = os.path.dirname(os.path.realpath(__file__)) +CERTS_DIR = os.path.join(PYTESTS_DIR, 'certs') +TEMPLATES_DIR = os.path.join(PYTESTS_DIR, 'templates') +KRESD_CONF_TEMPLATE = 'kresd.conf.j2' +KRESD_STARTUP_MSGID = 10005 # special unique ID at the start of the "test" log +KRESD_PORTDIR = '/tmp/pytest-kresd-portdir' +KRESD_TESTPORT_MIN = 1024 +KRESD_TESTPORT_MAX = 32768 # avoid overlap with docker ephemeral port range + + +def init_portdir(): + try: + shutil.rmtree(KRESD_PORTDIR) + except FileNotFoundError: + pass + os.makedirs(KRESD_PORTDIR) + + +def create_file_from_template(template_path, dest, data): + env = jinja2.Environment( + loader=jinja2.FileSystemLoader(TEMPLATES_DIR)) + template = env.get_template(template_path) + rendered_template = template.render(**data) + + with open(dest, "w") as fh: + fh.write(rendered_template) + + +Forward = namedtuple('Forward', ['proto', 'ip', 'port', 'hostname', 'ca_file']) + + +class Kresd(ContextDecorator): + def __init__( + self, workdir, port=None, tls_port=None, ip=None, ip6=None, certname=None, + verbose=True, hints=None, forward=None): + if ip is None and ip6 is None: + raise ValueError("IPv4 or IPv6 must be specified!") + self.workdir = str(workdir) + self.port = port + self.tls_port = tls_port + self.ip = ip + self.ip6 = ip6 + self.process = None + self.sockets = [] + self.logfile = None + self.verbose = verbose + self.hints = {} if hints is None else hints + self.forward = forward + + if certname: + self.tls_cert_path = os.path.join(CERTS_DIR, certname + '.cert.pem') + self.tls_key_path = os.path.join(CERTS_DIR, certname + '.key.pem') + else: + self.tls_cert_path = None + self.tls_key_path = None + + @property + def config_path(self): + return str(os.path.join(self.workdir, 'kresd.conf')) + + @property + def logfile_path(self): + return str(os.path.join(self.workdir, 'kresd.log')) + + def __enter__(self): + if self.port is not None: + take_port(self.port, self.ip, self.ip6, timeout=120) + else: + self.port = make_port(self.ip, self.ip6) + if self.tls_port is not None: + take_port(self.tls_port, self.ip, self.ip6, timeout=120) + else: + self.tls_port = make_port(self.ip, self.ip6) + + create_file_from_template(KRESD_CONF_TEMPLATE, self.config_path, {'kresd': self}) + self.logfile = open(self.logfile_path, 'w') + self.process = subprocess.Popen( + ['kresd', '-c', self.config_path, '-f', '1', self.workdir], + stdout=self.logfile, env=os.environ.copy()) + + try: + self._wait_for_tcp_port() # wait for ports to be up and responding + if not self.all_ports_alive(msgid=10001): + raise RuntimeError("Kresd not listening on all ports") + + # issue special msgid to mark start of test log + sock = self.ip_tcp_socket() if self.ip else self.ip6_tcp_socket() + assert utils.try_ping_alive(sock, close=True, msgid=KRESD_STARTUP_MSGID) + + # sanity check - kresd didn't crash + self.process.poll() + if self.process.returncode is not None: + raise RuntimeError("Kresd crashed with returncode: {}".format( + self.process.returncode)) + except (RuntimeError, ConnectionError): # pylint: disable=try-except-raise + with open(self.logfile_path) as log: # print log for debugging + print(log.read()) + raise + + return self + + def __exit__(self, exc_type, exc_value, traceback): + try: + if not self.all_ports_alive(msgid=1006): + raise RuntimeError("Kresd crashed") + finally: + for sock in self.sockets: + sock.close() + self.process.terminate() + self.logfile.close() + Path(KRESD_PORTDIR, str(self.port)).unlink() + + def all_ports_alive(self, msgid=10001): + alive = True + if self.ip: + alive &= utils.try_ping_alive(self.ip_tcp_socket(), close=True, msgid=msgid) + alive &= utils.try_ping_alive(self.ip_tls_socket(), close=True, msgid=msgid + 1) + if self.ip6: + alive &= utils.try_ping_alive(self.ip6_tcp_socket(), close=True, msgid=msgid + 2) + alive &= utils.try_ping_alive(self.ip6_tls_socket(), close=True, msgid=msgid + 3) + return alive + + def _wait_for_tcp_port(self, max_delay=10, delay_step=0.2): + family = socket.AF_INET if self.ip else socket.AF_INET6 + i = 0 + end_time = time.time() + max_delay + + while time.time() < end_time: + i += 1 + + # use exponential backoff algorhitm to choose next delay + rand_delay = random.randrange(0, i) + time.sleep(rand_delay * delay_step) + + try: + sock, dest = self.stream_socket(family, timeout=5) + sock.connect(dest) + except ConnectionRefusedError: + continue + else: + try: + return utils.try_ping_alive(sock, close=True, msgid=10000) + except socket.timeout: + continue + finally: + sock.close() + raise RuntimeError("Kresd didn't start in time") + + def socket_dest(self, family, tls=False): + port = self.tls_port if tls else self.port + if family == socket.AF_INET: + return self.ip, port + elif family == socket.AF_INET6: + return self.ip6, port, 0, 0 + raise RuntimeError("Unsupported socket family: {}".format(family)) + + def stream_socket(self, family, tls=False, timeout=20): + """Initialize a socket and return it along with the destination without connecting.""" + sock = socket.socket(family, socket.SOCK_STREAM) + sock.settimeout(timeout) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + dest = self.socket_dest(family, tls) + self.sockets.append(sock) + return sock, dest + + def _tcp_socket(self, family): + sock, dest = self.stream_socket(family) + sock.connect(dest) + return sock + + def ip_tcp_socket(self): + return self._tcp_socket(socket.AF_INET) + + def ip6_tcp_socket(self): + return self._tcp_socket(socket.AF_INET6) + + def _tls_socket(self, family): + sock, dest = self.stream_socket(family, tls=True) + ctx = utils.make_ssl_context(insecure=True) + ssock = ctx.wrap_socket(sock) + try: + ssock.connect(dest) + except OSError as exc: + if exc.errno == 0: # sometimes happens shortly after startup + return None + return ssock + + def _tls_socket_with_retry(self, family): + sock = self._tls_socket(family) + if sock is None: + time.sleep(0.1) + sock = self._tls_socket(family) + if sock is None: + raise RuntimeError("Failed to create TLS socket!") + return sock + + def ip_tls_socket(self): + return self._tls_socket_with_retry(socket.AF_INET) + + def ip6_tls_socket(self): + return self._tls_socket_with_retry(socket.AF_INET6) + + def partial_log(self): + partial_log = '\n (... ommiting log start)\n' + with open(self.logfile_path) as log: # display partial log for debugging + past_startup_msgid = False + past_startup = False + for line in log: + if past_startup: + partial_log += line + else: # find real start of test log (after initial alive-pings) + if not past_startup_msgid: + if re.match(KRESD_LOG_STARTUP_MSGID, line) is not None: + past_startup_msgid = True + else: + if re.match(KRESD_LOG_IO_CLOSE, line) is not None: + past_startup = True + return partial_log + + +def is_port_free(port, ip=None, ip6=None): + def check(family, type_, dest): + sock = socket.socket(family, type_) + sock.bind(dest) + sock.close() + + try: + if ip is not None: + check(socket.AF_INET, socket.SOCK_STREAM, (ip, port)) + check(socket.AF_INET, socket.SOCK_DGRAM, (ip, port)) + if ip6 is not None: + check(socket.AF_INET6, socket.SOCK_STREAM, (ip6, port, 0, 0)) + check(socket.AF_INET6, socket.SOCK_DGRAM, (ip6, port, 0, 0)) + except OSError as exc: + if exc.errno == 98: # address alrady in use + return False + else: + raise + return True + + +def take_port(port, ip=None, ip6=None, timeout=0): + port_path = Path(KRESD_PORTDIR, str(port)) + end_time = time.time() + timeout + try: + port_path.touch(exist_ok=False) + except FileExistsError: + raise ValueError( + "Port {} already reserved by system or another kresd instance!".format(port)) + + while True: + if is_port_free(port, ip, ip6): + # NOTE: The port_path isn't removed, so other instances don't have to attempt to + # take the same port again. This has the side effect of leaving many of these + # files behind, because when another kresd shuts down and removes its file, the + # port still can't be reserved for a while. This shouldn't become an issue unless + # we have thousands of tests (and run out of the port range). + break + + if time.time() < end_time: + time.sleep(5) + else: + raise ValueError( + "Port {} is reserved by system!".format(port)) + return port + + +def make_port(ip=None, ip6=None): + for _ in range(10): # max attempts + port = random.randint(KRESD_TESTPORT_MIN, KRESD_TESTPORT_MAX) + try: + take_port(port, ip, ip6) + except ValueError: + continue # port reserved by system / another kresd instance + return port + raise RuntimeError("No available port found!") + + +KRESD_LOG_STARTUP_MSGID = re.compile(r'^\[{}.*'.format(KRESD_STARTUP_MSGID)) +KRESD_LOG_IO_CLOSE = re.compile(r'^\[io\].*closed by peer.*') + + +@contextmanager +def make_kresd( + workdir, certname=None, ip='127.0.0.1', ip6='::1', forward=None, hints=None, + port=None, tls_port=None): + with Kresd(workdir, port, tls_port, ip, ip6, certname, forward=forward, hints=hints) as kresd: + yield kresd + print(kresd.partial_log()) diff --git a/tests/pytests/pylintrc b/tests/pytests/pylintrc new file mode 100644 index 0000000..d9cf1a1 --- /dev/null +++ b/tests/pytests/pylintrc @@ -0,0 +1,29 @@ +[MESSAGES CONTROL] + +disable= + missing-docstring, + too-many-arguments, + too-many-instance-attributes, + fixme, + unused-import, # checked by flake8 + line-too-long, # checked by flake8 + invalid-name, + broad-except, + bad-continuation, + global-statement, + no-else-return, + redefined-outer-name, # commonly used with pytest fixtures + + +[SIMILARITIES] +min-similarity-lines=6 +ignore-comments=yes +ignore-docstrings=yes +ignore-imports=no + +[DESIGN] +max-parents=10 +max-locals=20 + +[TYPECHECK] +ignored-modules=ssl diff --git a/tests/pytests/rehandshake/Makefile b/tests/pytests/rehandshake/Makefile new file mode 100644 index 0000000..170b89e --- /dev/null +++ b/tests/pytests/rehandshake/Makefile @@ -0,0 +1,28 @@ +CC=gcc +CFLAGS_TLS=-DDEBUG -ggdb3 -O0 -lgnutls -luv +CFLAGS_TCP=-DDEBUG -ggdb3 -O0 -luv + +all: tcproxy tlsproxy + +tlsproxy: tls-proxy.o tlsproxy.o + $(CC) tls-proxy.o tlsproxy.o -o tlsproxy $(CFLAGS_TLS) + +tls-proxy.o: tls-proxy.c tls-proxy.h array.h + $(CC) -c -o $@ $< $(CFLAGS_TLS) + +tlsproxy.o: tlsproxy.c tls-proxy.h + $(CC) -c -o $@ $< $(CFLAGS_TLS) + +tcproxy: tcp-proxy.o tcproxy.o + $(CC) tcp-proxy.o tcproxy.o -o tcproxy $(CFLAGS_TCP) + +tcp-proxy.o: tcp-proxy.c tcp-proxy.h array.h + $(CC) -c -o $@ $< $(CFLAGS_TCP) + +tcproxy.o: tcproxy.c tcp-proxy.h + $(CC) -c -o $@ $< $(CFLAGS_TCP) + +clean: + rm -f tcp-proxy.o tcproxy.o tcproxy tls-proxy.o tlsproxy.o tlsproxy + +.PHONY: all clean diff --git a/tests/pytests/rehandshake/array.h b/tests/pytests/rehandshake/array.h new file mode 100644 index 0000000..ece4dd1 --- /dev/null +++ b/tests/pytests/rehandshake/array.h @@ -0,0 +1,166 @@ +/* Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + */ + +/** + * + * @file array.h + * @brief A set of simple macros to make working with dynamic arrays easier. + * + * @note The C has no generics, so it is implemented mostly using macros. + * Be aware of that, as direct usage of the macros in the evaluating macros + * may lead to different expectations: + * + * @code{.c} + * MIN(array_push(arr, val), other) + * @endcode + * + * May evaluate the code twice, leading to unexpected behaviour. + * This is a price to pay for the absence of proper generics. + * + * # Example usage: + * + * @code{.c} + * array_t(const char*) arr; + * array_init(arr); + * + * // Reserve memory in advance + * if (array_reserve(arr, 2) < 0) { + * return ENOMEM; + * } + * + * // Already reserved, cannot fail + * array_push(arr, "princess"); + * array_push(arr, "leia"); + * + * // Not reserved, may fail + * if (array_push(arr, "han") < 0) { + * return ENOMEM; + * } + * + * // It does not hide what it really is + * for (size_t i = 0; i < arr.len; ++i) { + * printf("%s\n", arr.at[i]); + * } + * + * // Random delete + * array_del(arr, 0); + * @endcode + * \addtogroup generics + * @{ + */ + +#pragma once +#include + +/** Simplified Qt containers growth strategy. */ +static inline size_t array_next_count(size_t want) +{ + if (want < 2048) { + return (want < 20) ? want + 4 : want * 2; + } else { + return want + 2048; + } +} + +/** @internal Incremental memory reservation */ +static inline int array_std_reserve(void *baton, char **mem, size_t elm_size, size_t want, size_t *have) +{ + if (*have >= want) { + return 0; + } + /* Simplified Qt containers growth strategy */ + size_t next_size = array_next_count(want); + void *mem_new = realloc(*mem, next_size * elm_size); + if (mem_new != NULL) { + *mem = mem_new; + *have = next_size; + return 0; + } + return -1; +} + +/** @internal Wrapper for stdlib free. */ +static inline void array_std_free(void *baton, void *p) +{ + free(p); +} + +/** Declare an array structure. */ +#define array_t(type) struct {type * at; size_t len; size_t cap; } + +/** Zero-initialize the array. */ +#define array_init(array) ((array).at = NULL, (array).len = (array).cap = 0) + +/** Free and zero-initialize the array (plain malloc/free). */ +#define array_clear(array) \ + array_clear_mm(array, array_std_free, NULL) + +/** Make the array empty and free pointed-to memory. + * Mempool usage: pass mm_free and a knot_mm_t* . */ +#define array_clear_mm(array, free, baton) \ + (free)((baton), (array).at), array_init(array) + +/** Reserve capacity for at least n elements. + * @return 0 if success, <0 on failure */ +#define array_reserve(array, n) \ + array_reserve_mm(array, n, array_std_reserve, NULL) + +/** Reserve capacity for at least n elements. + * Mempool usage: pass kr_memreserve and a knot_mm_t* . + * @return 0 if success, <0 on failure */ +#define array_reserve_mm(array, n, reserve, baton) \ + (reserve)((baton), (char **) &(array).at, sizeof((array).at[0]), (n), &(array).cap) + +/** + * Push value at the end of the array, resize it if necessary. + * Mempool usage: pass kr_memreserve and a knot_mm_t* . + * @note May fail if the capacity is not reserved. + * @return element index on success, <0 on failure + */ +#define array_push_mm(array, val, reserve, baton) \ + (int)((array).len < (array).cap ? ((array).at[(array).len] = val, (array).len++) \ + : (array_reserve_mm(array, ((array).cap + 1), reserve, baton) < 0 ? -1 \ + : ((array).at[(array).len] = val, (array).len++))) + +/** + * Push value at the end of the array, resize it if necessary (plain malloc/free). + * @note May fail if the capacity is not reserved. + * @return element index on success, <0 on failure + */ +#define array_push(array, val) \ + array_push_mm(array, val, array_std_reserve, NULL) + +/** + * Pop value from the end of the array. + */ +#define array_pop(array) \ + (array).len -= 1 + +/** + * Remove value at given index. + * @return 0 on success, <0 on failure + */ +#define array_del(array, i) \ + (int)((i) < (array).len ? ((array).len -= 1,(array).at[i] = (array).at[(array).len], 0) : -1) + +/** + * Return last element of the array. + * @warning Undefined if the array is empty. + */ +#define array_tail(array) \ + (array).at[(array).len - 1] + +/** @} */ diff --git a/tests/pytests/rehandshake/tcp-proxy.c b/tests/pytests/rehandshake/tcp-proxy.c new file mode 100644 index 0000000..ba7198b --- /dev/null +++ b/tests/pytests/rehandshake/tcp-proxy.c @@ -0,0 +1,336 @@ +#include +#include +#include +#include +#include +#include +#include +#include "array.h" + +struct buf { + char buf[16 * 1024]; + size_t size; +}; + +enum peer_state { + STATE_NOT_CONNECTED, + STATE_LISTENING, + STATE_CONNECTED, + STATE_CONNECT_IN_PROGRESS, + STATE_CLOSING_IN_PROGRESS +}; + +struct proxy_ctx { + uv_loop_t *loop; + uv_tcp_t server; + uv_tcp_t client; + uv_tcp_t upstream; + struct sockaddr_storage server_addr; + struct sockaddr_storage upstream_addr; + + int server_state; + int client_state; + int upstream_state; + + array_t(struct buf *) buffer_pool; + array_t(struct buf *) upstream_pending; +}; + +static void read_from_upstream_cb(uv_stream_t *upstream, ssize_t nread, const uv_buf_t *buf); +static void read_from_client_cb(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf); + +static struct buf *borrow_io_buffer(struct proxy_ctx *proxy) +{ + struct buf *buf = NULL; + if (proxy->buffer_pool.len > 0) { + buf = array_tail(proxy->buffer_pool); + array_pop(proxy->buffer_pool); + } else { + buf = calloc(1, sizeof (struct buf)); + } + return buf; +} + +static void release_io_buffer(struct proxy_ctx *proxy, struct buf *buf) +{ + if (!buf) { + return; + } + + if (proxy->buffer_pool.len < 1000) { + buf->size = 0; + array_push(proxy->buffer_pool, buf); + } else { + free(buf); + } +} + +static void push_to_upstream_pending(struct proxy_ctx *proxy, const char *buf, size_t size) +{ + while (size > 0) { + struct buf *b = borrow_io_buffer(proxy); + b->size = size <= sizeof(b->buf) ? size : sizeof(b->buf); + memcpy(b->buf, buf, b->size); + array_push(proxy->upstream_pending, b); + size -= b->size; + } +} + +static struct buf *get_first_upstream_pending(struct proxy_ctx *proxy) +{ + struct buf *buf = NULL; + if (proxy->upstream_pending.len > 0) { + buf = proxy->upstream_pending.at[0]; + } + return buf; +} + +static void remove_first_upstream_pending(struct proxy_ctx *proxy) +{ + for (int i = 1; i < proxy->upstream_pending.len; ++i) { + proxy->upstream_pending.at[i - 1] = proxy->upstream_pending.at[i]; + } + if (proxy->upstream_pending.len > 0) { + proxy->upstream_pending.len -= 1; + } +} + +static void clear_upstream_pending(struct proxy_ctx *proxy) +{ + for (int i = 1; i < proxy->upstream_pending.len; ++i) { + struct buf *b = proxy->upstream_pending.at[i]; + release_io_buffer(proxy, b); + } + proxy->upstream_pending.len = 0; +} + +static void clear_buffer_pool(struct proxy_ctx *proxy) +{ + for (int i = 1; i < proxy->buffer_pool.len; ++i) { + struct buf *b = proxy->buffer_pool.at[i]; + free(b); + } + proxy->buffer_pool.len = 0; +} + +static void alloc_uv_buffer(uv_handle_t *handle, size_t suggested_size, uv_buf_t *buf) +{ + buf->base = (char*)malloc(suggested_size); + buf->len = suggested_size; +} + +static void on_client_close(uv_handle_t *handle) +{ + struct proxy_ctx *proxy = (struct proxy_ctx *)handle->loop->data; + proxy->client_state = STATE_NOT_CONNECTED; +} + +static void on_upstream_close(uv_handle_t *handle) +{ + struct proxy_ctx *proxy = (struct proxy_ctx *)handle->loop->data; + proxy->upstream_state = STATE_NOT_CONNECTED; +} + +static void write_to_client_cb(uv_write_t *req, int status) +{ + struct proxy_ctx *proxy = (struct proxy_ctx *)req->handle->loop->data; + free(req); + if (status) { + fprintf(stderr, "error writing to client: %s\n", uv_strerror(status)); + clear_upstream_pending(proxy); + proxy->client_state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&proxy->client, on_client_close); + } +} + +static void write_to_upstream_cb(uv_write_t *req, int status) +{ + struct proxy_ctx *proxy = (struct proxy_ctx *)req->handle->loop->data; + free(req); + if (status) { + fprintf(stderr, "error writing to upstream: %s\n", uv_strerror(status)); + clear_upstream_pending(proxy); + proxy->upstream_state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close); + return; + } + if (proxy->upstream_pending.len > 0) { + struct buf *buf = get_first_upstream_pending(proxy); + remove_first_upstream_pending(proxy); + release_io_buffer(proxy, buf); + if (proxy->upstream_state == STATE_CONNECTED && + proxy->upstream_pending.len > 0) { + buf = get_first_upstream_pending(proxy); + /* TODO avoid allocation */ + uv_write_t *req = (uv_write_t *) malloc(sizeof(uv_write_t)); + uv_buf_t wrbuf = uv_buf_init(buf->buf, buf->size); + uv_write(req, (uv_stream_t *)&proxy->upstream, &wrbuf, 1, write_to_upstream_cb); + } + } +} + +static void on_client_connection(uv_stream_t *server, int status) +{ + if (status < 0) { + fprintf(stderr, "incoming connection error: %s\n", uv_strerror(status)); + return; + } + + fprintf(stdout, "incoming connection\n"); + + struct proxy_ctx *proxy = (struct proxy_ctx *)server->loop->data; + if (proxy->client_state != STATE_NOT_CONNECTED) { + fprintf(stderr, "client already connected, ignoring\n"); + return; + } + + uv_tcp_init(proxy->loop, &proxy->client); + proxy->client_state = STATE_CONNECTED; + if (uv_accept(server, (uv_stream_t*)&proxy->client) == 0) { + uv_read_start((uv_stream_t*)&proxy->client, alloc_uv_buffer, read_from_client_cb); + } else { + proxy->client_state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&proxy->client, on_client_close); + } +} + +static void on_connect_to_upstream(uv_connect_t *req, int status) +{ + struct proxy_ctx *proxy = (struct proxy_ctx *)req->handle->loop->data; + free(req); + if (status < 0) { + fprintf(stderr, "error connecting to upstream: %s\n", uv_strerror(status)); + clear_upstream_pending(proxy); + proxy->upstream_state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close); + return; + } + + proxy->upstream_state = STATE_CONNECTED; + uv_read_start((uv_stream_t*)&proxy->upstream, alloc_uv_buffer, read_from_upstream_cb); + if (proxy->upstream_pending.len > 0) { + struct buf *buf = get_first_upstream_pending(proxy); + /* TODO avoid allocation */ + uv_write_t *wreq = (uv_write_t *) malloc(sizeof(uv_write_t)); + uv_buf_t wrbuf = uv_buf_init(buf->buf, buf->size); + uv_write(wreq, (uv_stream_t *)&proxy->upstream, &wrbuf, 1, write_to_upstream_cb); + } +} + +static void read_from_client_cb(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf) +{ + if (nread == 0) { + return; + } + struct proxy_ctx *proxy = (struct proxy_ctx *)client->loop->data; + if (nread < 0) { + if (nread != UV_EOF) { + fprintf(stderr, "error reading from client: %s\n", uv_err_name(nread)); + } + if (proxy->client_state == STATE_CONNECTED) { + proxy->client_state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*) client, on_client_close); + } + return; + } + if (proxy->upstream_state == STATE_CONNECTED) { + if (proxy->upstream_pending.len > 0) { + push_to_upstream_pending(proxy, buf->base, nread); + } else { + /* TODO avoid allocation */ + uv_write_t *req = (uv_write_t *) malloc(sizeof(uv_write_t)); + uv_buf_t wrbuf = uv_buf_init(buf->base, nread); + uv_write(req, (uv_stream_t *)&proxy->upstream, &wrbuf, 1, write_to_upstream_cb); + } + } else if (proxy->upstream_state == STATE_NOT_CONNECTED) { + /* TODO avoid allocation */ + uv_tcp_init(proxy->loop, &proxy->upstream); + uv_connect_t *conn = (uv_connect_t *) malloc(sizeof(uv_connect_t)); + proxy->upstream_state = STATE_CONNECT_IN_PROGRESS; + uv_tcp_connect(conn, &proxy->upstream, (struct sockaddr *)&proxy->upstream_addr, + on_connect_to_upstream); + push_to_upstream_pending(proxy, buf->base, nread); + } else if (proxy->upstream_state == STATE_CONNECT_IN_PROGRESS) { + push_to_upstream_pending(proxy, buf->base, nread); + } +} + +static void read_from_upstream_cb(uv_stream_t *upstream, ssize_t nread, const uv_buf_t *buf) +{ + if (nread == 0) { + return; + } + struct proxy_ctx *proxy = (struct proxy_ctx *)upstream->loop->data; + if (nread < 0) { + if (nread != UV_EOF) { + fprintf(stderr, "error reading from upstream: %s\n", uv_err_name(nread)); + } + clear_upstream_pending(proxy); + if (proxy->upstream_state == STATE_CONNECTED) { + proxy->upstream_state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close); + } + return; + } + if (proxy->client_state == STATE_CONNECTED) { + /* TODO Avoid allocation */ + uv_write_t *req = (uv_write_t *) malloc(sizeof(uv_write_t)); + uv_buf_t wrbuf = uv_buf_init(buf->base, nread); + uv_write(req, (uv_stream_t *)&proxy->client, &wrbuf, 1, write_to_client_cb); + } +} + +struct proxy_ctx *proxy_allocate() +{ + return malloc(sizeof(struct proxy_ctx)); +} + +int proxy_init(struct proxy_ctx *proxy, + const char *server_addr, int server_port, + const char *upstream_addr, int upstream_port) +{ + proxy->loop = uv_default_loop(); + uv_tcp_init(proxy->loop, &proxy->server); + int res = uv_ip4_addr(server_addr, server_port, (struct sockaddr_in *)&proxy->server_addr); + if (res != 0) { + return res; + } + res = uv_ip4_addr(upstream_addr, upstream_port, (struct sockaddr_in *)&proxy->upstream_addr); + if (res != 0) { + return res; + } + array_init(proxy->buffer_pool); + array_init(proxy->upstream_pending); + proxy->server_state = STATE_NOT_CONNECTED; + proxy->client_state = STATE_NOT_CONNECTED; + proxy->upstream_state = STATE_NOT_CONNECTED; + + proxy->loop->data = proxy; + return 0; +} + +void proxy_free(struct proxy_ctx *proxy) +{ + if (!proxy) { + return; + } + clear_upstream_pending(proxy); + clear_buffer_pool(proxy); + /* TODO correctly close all the uv_tcp_t */ + free(proxy); +} + +int proxy_start_listen(struct proxy_ctx *proxy) +{ + uv_tcp_bind(&proxy->server, (const struct sockaddr*)&proxy->server_addr, 0); + int ret = uv_listen((uv_stream_t*)&proxy->server, 128, on_client_connection); + if (ret == 0) { + proxy->server_state = STATE_LISTENING; + } + return ret; +} + +int proxy_run(struct proxy_ctx *proxy) +{ + return uv_run(proxy->loop, UV_RUN_DEFAULT); +} diff --git a/tests/pytests/rehandshake/tcp-proxy.h b/tests/pytests/rehandshake/tcp-proxy.h new file mode 100644 index 0000000..668a65f --- /dev/null +++ b/tests/pytests/rehandshake/tcp-proxy.h @@ -0,0 +1,12 @@ +#pragma once + +struct proxy_ctx; + +struct proxy_ctx *proxy_allocate(); +void proxy_free(struct proxy_ctx *proxy); +int proxy_init(struct proxy_ctx *proxy, + const char *server_addr, int server_port, + const char *upstream_addr, int upstream_port); +int proxy_start_listen(struct proxy_ctx *proxy); +int proxy_run(struct proxy_ctx *proxy); + diff --git a/tests/pytests/rehandshake/tcproxy.c b/tests/pytests/rehandshake/tcproxy.c new file mode 100644 index 0000000..87a6b4c --- /dev/null +++ b/tests/pytests/rehandshake/tcproxy.c @@ -0,0 +1,25 @@ +#include +#include "tcp-proxy.h" + +int main() +{ + struct proxy_ctx *proxy = proxy_allocate(); + if (!proxy) { + fprintf(stderr, "can't allocate proxy structure\n"); + return 1; + } + int res = proxy_init(proxy, "127.0.0.1", 54000, "127.0.0.1", 53001); + if (res) { + fprintf(stderr, "can't initialize proxy by given addresses\n"); + return res; + } + res = proxy_start_listen(proxy); + if (res) { + fprintf(stderr, "error starting listen, error code: %i\n", res); + return res; + } + res = proxy_run(proxy); + proxy_free(proxy); + return res; +} + diff --git a/tests/pytests/rehandshake/tls-proxy.c b/tests/pytests/rehandshake/tls-proxy.c new file mode 100644 index 0000000..bf6cc0d --- /dev/null +++ b/tests/pytests/rehandshake/tls-proxy.c @@ -0,0 +1,848 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include "array.h" + +#define TLS_MAX_SEND_RETRIES 100 +#define CLIENT_ANSWER_CHUNK_SIZE 8 +struct buf { + char buf[16 * 1024]; + size_t size; +}; + +enum peer_state { + STATE_NOT_CONNECTED, + STATE_LISTENING, + STATE_CONNECTED, + STATE_CONNECT_IN_PROGRESS, + STATE_CLOSING_IN_PROGRESS +}; + +enum handshake_state { + TLS_HS_NOT_STARTED = 0, + TLS_HS_EXPECTED, + TLS_HS_IN_PROGRESS, + TLS_HS_DONE, + TLS_HS_CLOSING, + TLS_HS_LAST +}; + +struct tls_ctx { + gnutls_session_t session; + int handshake_state; + gnutls_certificate_credentials_t credentials; + gnutls_priority_t priority_cache; + /* for reading from the network */ + const uint8_t *buf; + ssize_t nread; + ssize_t consumed; + uint8_t recv_buf[4096]; +}; + +struct tls_proxy_ctx { + uv_loop_t *loop; + uv_tcp_t server; + uv_tcp_t client; + uv_tcp_t upstream; + struct sockaddr_storage server_addr; + struct sockaddr_storage upstream_addr; + struct sockaddr_storage client_addr; + + int server_state; + int client_state; + int upstream_state; + + array_t(struct buf *) buffer_pool; + array_t(struct buf *) upstream_pending; + array_t(struct buf *) client_pending; + + char io_buf[0xFFFF]; + struct tls_ctx tls; +}; + +static void read_from_upstream_cb(uv_stream_t *upstream, ssize_t nread, const uv_buf_t *buf); +static void read_from_client_cb(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf); +static ssize_t proxy_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len); +static ssize_t proxy_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len); +static int tls_process_from_upstream(struct tls_proxy_ctx *proxy, const uint8_t *buf, ssize_t nread); +static int tls_process_from_client(struct tls_proxy_ctx *proxy, const uint8_t *buf, ssize_t nread); +static int write_to_upstream_pending(struct tls_proxy_ctx *proxy); +static int write_to_client_pending(struct tls_proxy_ctx *proxy); + + +static int gnutls_references = 0; + +const void *ip_addr(const struct sockaddr *addr) +{ + if (!addr) { + return NULL; + } + switch (addr->sa_family) { + case AF_INET: return (const void *)&(((const struct sockaddr_in *)addr)->sin_addr); + case AF_INET6: return (const void *)&(((const struct sockaddr_in6 *)addr)->sin6_addr); + default: return NULL; + } +} + +uint16_t ip_addr_port(const struct sockaddr *addr) +{ + if (!addr) { + return 0; + } + switch (addr->sa_family) { + case AF_INET: return ntohs(((const struct sockaddr_in *)addr)->sin_port); + case AF_INET6: return ntohs(((const struct sockaddr_in6 *)addr)->sin6_port); + default: return 0; + } +} + +static int ip_addr_str(const struct sockaddr *addr, char *buf, size_t *buflen) +{ + int ret = 0; + if (!addr || !buf || !buflen) { + return EINVAL; + } + + char str[INET6_ADDRSTRLEN + 6]; + if (!inet_ntop(addr->sa_family, ip_addr(addr), str, sizeof(str))) { + return errno; + } + int len = strlen(str); + str[len] = '#'; + snprintf(&str[len + 1], 6, "%uh", ip_addr_port(addr)); + len += 6; + str[len] = 0; + if (len >= *buflen) { + ret = ENOSPC; + } else { + memcpy(buf, str, len + 1); + } + *buflen = len; + return ret; +} + +static inline char *ip_straddr(const struct sockaddr *addr) +{ + assert(addr != NULL); + /* We are the sinle-threaded application */ + static char str[INET6_ADDRSTRLEN + 6]; + size_t len = sizeof(str); + int ret = ip_addr_str(addr, str, &len); + return ret != 0 || len == 0 ? NULL : str; +} + +static struct buf *borrow_io_buffer(struct tls_proxy_ctx *proxy) +{ + struct buf *buf = NULL; + if (proxy->buffer_pool.len > 0) { + buf = array_tail(proxy->buffer_pool); + array_pop(proxy->buffer_pool); + } else { + buf = calloc(1, sizeof (struct buf)); + } + return buf; +} + +static void release_io_buffer(struct tls_proxy_ctx *proxy, struct buf *buf) +{ + if (!buf) { + return; + } + + if (proxy->buffer_pool.len < 1000) { + buf->size = 0; + array_push(proxy->buffer_pool, buf); + } else { + free(buf); + } +} + +static struct buf *get_first_upstream_pending(struct tls_proxy_ctx *proxy) +{ + struct buf *buf = NULL; + if (proxy->upstream_pending.len > 0) { + buf = proxy->upstream_pending.at[0]; + } + return buf; +} + +static struct buf *get_first_client_pending(struct tls_proxy_ctx *proxy) +{ + struct buf *buf = NULL; + if (proxy->client_pending.len > 0) { + buf = proxy->client_pending.at[0]; + } + return buf; +} + +static void remove_first_upstream_pending(struct tls_proxy_ctx *proxy) +{ + for (int i = 1; i < proxy->upstream_pending.len; ++i) { + proxy->upstream_pending.at[i - 1] = proxy->upstream_pending.at[i]; + } + if (proxy->upstream_pending.len > 0) { + proxy->upstream_pending.len -= 1; + } +} + +static void remove_first_client_pending(struct tls_proxy_ctx *proxy) +{ + for (int i = 1; i < proxy->client_pending.len; ++i) { + proxy->client_pending.at[i - 1] = proxy->client_pending.at[i]; + } + if (proxy->client_pending.len > 0) { + proxy->client_pending.len -= 1; + } +} + +static void clear_upstream_pending(struct tls_proxy_ctx *proxy) +{ + for (int i = 0; i < proxy->upstream_pending.len; ++i) { + struct buf *b = proxy->upstream_pending.at[i]; + release_io_buffer(proxy, b); + } + proxy->upstream_pending.len = 0; +} + +static void clear_client_pending(struct tls_proxy_ctx *proxy) +{ + for (int i = 0; i < proxy->client_pending.len; ++i) { + struct buf *b = proxy->client_pending.at[i]; + release_io_buffer(proxy, b); + } + proxy->client_pending.len = 0; +} + +static void clear_buffer_pool(struct tls_proxy_ctx *proxy) +{ + for (int i = 0; i < proxy->buffer_pool.len; ++i) { + struct buf *b = proxy->buffer_pool.at[i]; + free(b); + } + proxy->buffer_pool.len = 0; +} + +static void alloc_uv_buffer(uv_handle_t *handle, size_t suggested_size, uv_buf_t *buf) +{ + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)handle->loop->data; + buf->base = proxy->io_buf; + buf->len = sizeof(proxy->io_buf); +} + +static void on_client_close(uv_handle_t *handle) +{ + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)handle->loop->data; + gnutls_deinit(proxy->tls.session); + proxy->tls.handshake_state = TLS_HS_NOT_STARTED; + proxy->client_state = STATE_NOT_CONNECTED; +} + +static void on_dummmy_client_close(uv_handle_t *handle) +{ + free(handle); +} + +static void on_upstream_close(uv_handle_t *handle) +{ + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)handle->loop->data; + proxy->upstream_state = STATE_NOT_CONNECTED; +} + +static void write_to_client_cb(uv_write_t *req, int status) +{ + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)req->handle->loop->data; + free(req); + if (status) { + fprintf(stderr, "error writing to client: %s\n", uv_strerror(status)); + clear_client_pending(proxy); + clear_upstream_pending(proxy); + if (proxy->client_state == STATE_CONNECTED) { + proxy->client_state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&proxy->client, on_client_close); + return; + } + } + fprintf(stdout, "successfully wrote to client, pending len is %zd\n", + proxy->client_pending.len); + if (proxy->client_state == STATE_CONNECTED && + proxy->tls.handshake_state == TLS_HS_DONE) { + write_to_client_pending(proxy); + } +} + +static void write_to_upstream_cb(uv_write_t *req, int status) +{ + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)req->handle->loop->data; + if (status) { + free(req); + fprintf(stderr, "error writing to upstream: %s\n", uv_strerror(status)); + clear_upstream_pending(proxy); + proxy->upstream_state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close); + return; + } + if (req->data != NULL) { + assert(proxy->upstream_pending.len > 0); + struct buf *buf = get_first_upstream_pending(proxy); + assert(req->data == (void *)buf->buf); + fprintf(stdout, "successfully wrote %zi bytes to upstream, pending len is %zd\n", + buf->size, proxy->upstream_pending.len); + remove_first_upstream_pending(proxy); + release_io_buffer(proxy, buf); + } else { + fprintf(stdout, "successfully wrote bytes to upstream, pending len is %zd\n", + proxy->upstream_pending.len); + } + if (proxy->upstream_state == STATE_CONNECTED && + proxy->upstream_pending.len > 0) { + write_to_upstream_pending(proxy); + } + free(req); +} + +static void on_client_connection(uv_stream_t *server, int status) +{ + if (status < 0) { + fprintf(stderr, "incoming connection error: %s\n", uv_strerror(status)); + return; + } + + int err = 0; + int ret = 0; + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)server->loop->data; + if (proxy->client_state != STATE_NOT_CONNECTED) { + fprintf(stderr, "incoming connection"); + uv_tcp_t *dummy_client = malloc(sizeof(uv_tcp_t)); + uv_tcp_init(proxy->loop, dummy_client); + err = uv_accept(server, (uv_stream_t*)dummy_client); + if (err == 0) { + struct sockaddr dummy_addr; + int dummy_addr_len = sizeof(dummy_addr); + ret = uv_tcp_getpeername(dummy_client, + &dummy_addr, + &dummy_addr_len); + if (ret == 0) { + fprintf(stderr, " from %s", ip_straddr(&dummy_addr)); + } + uv_close((uv_handle_t *)dummy_client, on_dummmy_client_close); + } else { + on_dummmy_client_close((uv_handle_t *)dummy_client); + } + fprintf(stderr, " - client already connected, rejecting\n"); + return; + } + + uv_tcp_init(proxy->loop, &proxy->client); + uv_tcp_nodelay((uv_tcp_t *)&proxy->client, 1); + proxy->client_state = STATE_CONNECTED; + err = uv_accept(server, (uv_stream_t*)&proxy->client); + if (err != 0) { + fprintf(stderr, "incoming connection - uv_accept() failed: (%d) %s\n", + err, uv_strerror(err)); + return; + } + + struct sockaddr *addr = (struct sockaddr *)&(proxy->client_addr); + int addr_len = sizeof(proxy->client_addr); + ret = uv_tcp_getpeername(&proxy->client, addr, &addr_len); + if (ret || addr->sa_family == AF_UNSPEC) { + proxy->client_state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&proxy->client, on_client_close); + fprintf(stderr, "incoming connection - uv_tcp_getpeername() failed: (%d) %s\n", + err, uv_strerror(err)); + return; + } + + fprintf(stdout, "incoming connection from %s\n", ip_straddr(addr)); + + uv_read_start((uv_stream_t*)&proxy->client, alloc_uv_buffer, read_from_client_cb); + + const char *errpos = NULL; + struct tls_ctx *tls = &proxy->tls; + assert (tls->handshake_state == TLS_HS_NOT_STARTED); + err = gnutls_init(&tls->session, GNUTLS_SERVER | GNUTLS_NONBLOCK); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stderr, "gnutls_init() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + } + err = gnutls_priority_set(tls->session, tls->priority_cache); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stderr, "gnutls_priority_set() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + } + err = gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE, tls->credentials); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stderr, "gnutls_credentials_set() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + } + gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_IGNORE); + gnutls_handshake_set_timeout(tls->session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT); + + gnutls_transport_set_pull_function(tls->session, proxy_gnutls_pull); + gnutls_transport_set_push_function(tls->session, proxy_gnutls_push); + gnutls_transport_set_ptr(tls->session, proxy); + + tls->handshake_state = TLS_HS_IN_PROGRESS; +} + +static void on_connect_to_upstream(uv_connect_t *req, int status) +{ + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)req->handle->loop->data; + free(req); + if (status < 0) { + fprintf(stderr, "error connecting to upstream (%s): %s\n", + ip_straddr((struct sockaddr *)&proxy->upstream_addr), + uv_strerror(status)); + clear_upstream_pending(proxy); + proxy->upstream_state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close); + return; + } + fprintf(stdout, "connected to %s\n", ip_straddr((struct sockaddr *)&proxy->upstream_addr)); + + proxy->upstream_state = STATE_CONNECTED; + uv_read_start((uv_stream_t*)&proxy->upstream, alloc_uv_buffer, read_from_upstream_cb); + if (proxy->upstream_pending.len > 0) { + write_to_upstream_pending(proxy); + } +} + +static void read_from_client_cb(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf) +{ + fprintf(stdout, "reading %zd bytes from client\n", nread); + if (nread == 0) { + return; + } + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)client->loop->data; + if (nread < 0) { + if (nread != UV_EOF) { + fprintf(stderr, "error reading from client: %s\n", uv_err_name(nread)); + } else { + fprintf(stdout, "client has closed the connection\n"); + } + if (proxy->client_state == STATE_CONNECTED) { + proxy->client_state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*) client, on_client_close); + } + return; + } + + int res = tls_process_from_client(proxy, buf->base, nread); + if (res < 0) { + if (proxy->client_state == STATE_CONNECTED) { + proxy->client_state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*) client, on_client_close); + } + } +} + +static void read_from_upstream_cb(uv_stream_t *upstream, ssize_t nread, const uv_buf_t *buf) +{ + fprintf(stdout, "reading %zd bytes from upstream\n", nread); + if (nread == 0) { + return; + } + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)upstream->loop->data; + if (nread < 0) { + if (nread != UV_EOF) { + fprintf(stderr, "error reading from upstream: %s\n", uv_err_name(nread)); + } else { + fprintf(stdout, "upstream has closed the connection\n"); + } + clear_upstream_pending(proxy); + if (proxy->upstream_state == STATE_CONNECTED) { + proxy->upstream_state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close); + } + return; + } + int res = tls_process_from_upstream(proxy, buf->base, nread); + if (res < 0) { + fprintf(stderr, "error sending tls data to client\n"); + if (proxy->client_state == STATE_CONNECTED) { + proxy->client_state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&proxy->client, on_client_close); + } + } +} + +static void push_to_upstream_pending(struct tls_proxy_ctx *proxy, const char *buf, size_t size) +{ + while (size > 0) { + struct buf *b = borrow_io_buffer(proxy); + b->size = size <= sizeof(b->buf) ? size : sizeof(b->buf); + memcpy(b->buf, buf, b->size); + array_push(proxy->upstream_pending, b); + size -= b->size; + buf += b->size; + } +} + +static void push_to_client_pending(struct tls_proxy_ctx *proxy, const char *buf, size_t size) +{ + while (size > 0) { + struct buf *b = borrow_io_buffer(proxy); + b->size = size <= sizeof(b->buf) ? size : sizeof(b->buf); + if (b->size > CLIENT_ANSWER_CHUNK_SIZE) { + b->size = CLIENT_ANSWER_CHUNK_SIZE; + } + memcpy(b->buf, buf, b->size); + array_push(proxy->client_pending, b); + size -= b->size; + buf += b->size; + } +} + +static int write_to_upstream_pending(struct tls_proxy_ctx *proxy) +{ + struct buf *buf = get_first_upstream_pending(proxy); + /* TODO avoid allocation */ + uv_write_t *req = (uv_write_t *) malloc(sizeof(uv_write_t)); + uv_buf_t wrbuf = uv_buf_init(buf->buf, buf->size); + req->data = buf->buf; + fprintf(stdout, "writing %zd bytes to upstream\n", buf->size); + return uv_write(req, (uv_stream_t *)&proxy->upstream, &wrbuf, 1, write_to_upstream_cb); +} + +static ssize_t proxy_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len) +{ + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)h; + struct tls_ctx *t = &proxy->tls; + + fprintf(stdout, "\t gnutls: pulling %zd bytes from client\n", len); + + if (t->nread <= t->consumed) { + errno = EAGAIN; + fprintf(stdout, "\t gnutls: return EAGAIN\n"); + return -1; + } + + ssize_t avail = t->nread - t->consumed; + ssize_t transfer = (avail <= len ? avail : len); + memcpy(buf, t->buf + t->consumed, transfer); + t->consumed += transfer; + return transfer; +} + +ssize_t proxy_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len) +{ + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)h; + struct tls_ctx *t = &proxy->tls; + fprintf(stdout, "\t gnutls: writing %zd bytes to client\n", len); + + ssize_t ret = -1; + const size_t req_size_aligned = ((sizeof(uv_write_t) / 16) + 1) * 16; + char *common_buf = malloc(req_size_aligned + len); + uv_write_t *req = (uv_write_t *) common_buf; + char *data = common_buf + req_size_aligned; + const uv_buf_t uv_buf[1] = { + { data, len } + }; + memcpy(data, buf, len); + req->data = data; + int res = uv_write(req, (uv_stream_t *)&proxy->client, uv_buf, 1, write_to_client_cb); + if (res == 0) { + ret = len; + } else { + free(common_buf); + errno = EIO; + } + return ret; +} + +static int write_to_client_pending(struct tls_proxy_ctx *proxy) +{ + if (proxy->client_pending.len == 0) { + return 0; + } + + struct buf *buf = get_first_client_pending(proxy); + uv_buf_t wrbuf = uv_buf_init(buf->buf, buf->size); + fprintf(stdout, "writing %zd bytes to client\n", buf->size); + + gnutls_session_t tls_session = proxy->tls.session; + assert(proxy->tls.handshake_state != TLS_HS_IN_PROGRESS); + assert(gnutls_record_check_corked(tls_session) == 0); + + char *data = buf->buf; + size_t len = buf->size; + + ssize_t count = 0; + ssize_t submitted = len; + ssize_t retries = 0; + do { + count = gnutls_record_send(tls_session, data, len); + if (count < 0) { + if (gnutls_error_is_fatal(count)) { + fprintf(stderr, "gnutls_record_send failed: %s (%zd)\n", + gnutls_strerror_name(count), count); + return -1; + } + if (++retries > TLS_MAX_SEND_RETRIES) { + fprintf(stderr, "gnutls_record_send: too many sequential non-fatal errors (%zd), last error is: %s (%zd)\n", + retries, gnutls_strerror_name(count), count); + return -1; + } + } else if (count != 0) { + data += count; + len -= count; + retries = 0; + } else { + if (++retries < TLS_MAX_SEND_RETRIES) { + continue; + } + fprintf(stderr, "gnutls_record_send: too many retries (%zd)\n", + retries); + fprintf(stderr, "tls_push_to_client didn't send all data(%zd of %zd)\n", + len, submitted); + return -1; + } + } while (len > 0); + + remove_first_client_pending(proxy); + release_io_buffer(proxy, buf); + + fprintf(stdout, "submitted %zd bytes to client\n", submitted); + assert (gnutls_safe_renegotiation_status(tls_session) != 0); + assert (gnutls_rehandshake(tls_session) == GNUTLS_E_SUCCESS); + /* Prevent write-to-client callback from sending next pending chunk. + * At the same time tls_process_from_client() must not call gnutls_handshake() + * as there can be application data in this direction. */ + proxy->tls.handshake_state = TLS_HS_EXPECTED; + fprintf(stdout, "rehandshake started\n"); + return submitted; +} + +static int tls_process_from_upstream(struct tls_proxy_ctx *proxy, const uint8_t *buf, ssize_t len) +{ + gnutls_session_t tls_session = proxy->tls.session; + + fprintf(stdout, "pushing %zd bytes to client\n", len); + + assert(gnutls_record_check_corked(tls_session) == 0); + ssize_t submitted = 0; + if (proxy->client_state != STATE_CONNECTED) { + return submitted; + } + + bool list_was_empty = (proxy->client_pending.len == 0); + push_to_client_pending(proxy, buf, len); + submitted = len; + if (proxy->tls.handshake_state == TLS_HS_DONE) { + if (list_was_empty && proxy->client_pending.len > 0) { + int ret = write_to_client_pending(proxy); + if (ret < 0) { + submitted = -1; + } + } + } + + return submitted; +} + +int tls_process_handshake(struct tls_proxy_ctx *proxy) +{ + struct tls_ctx *tls = &proxy->tls; + int ret = 1; + while (tls->handshake_state == TLS_HS_IN_PROGRESS) { + fprintf(stdout, "TLS handshake in progress...\n"); + int err = gnutls_handshake(tls->session); + if (err == GNUTLS_E_SUCCESS) { + tls->handshake_state = TLS_HS_DONE; + fprintf(stdout, "TLS handshake has completed\n"); + ret = 1; + if (proxy->client_pending.len != 0) { + write_to_client_pending(proxy); + } + } else if (gnutls_error_is_fatal(err)) { + fprintf(stderr, "gnutls_handshake failed: %s (%d)\n", + gnutls_strerror_name(err), err); + ret = -1; + break; + } else { + fprintf(stderr, "gnutls_handshake nonfatal error: %s (%d)\n", + gnutls_strerror_name(err), err); + ret = 0; + break; + } + } + return ret; +} + +int tls_process_from_client(struct tls_proxy_ctx *proxy, const uint8_t *buf, ssize_t nread) +{ + struct tls_ctx *tls = &proxy->tls; + + tls->buf = buf; + tls->nread = nread >= 0 ? nread : 0; + tls->consumed = 0; + + fprintf(stdout, "tls_process: reading %zd bytes from client\n", nread); + + int ret = tls_process_handshake(proxy); + if (ret <= 0) { + return ret; + } + + int submitted = 0; + while (true) { + ssize_t count = 0; + count = gnutls_record_recv(tls->session, tls->recv_buf, sizeof(tls->recv_buf)); + if (count == GNUTLS_E_AGAIN) { + break; /* No data available */ + } else if (count == GNUTLS_E_INTERRUPTED) { + continue; /* Try reading again */ + } else if (count == GNUTLS_E_REHANDSHAKE) { + tls->handshake_state = TLS_HS_IN_PROGRESS; + ret = tls_process_handshake(proxy); + if (ret <= 0) { + return ret; + } + continue; + } else if (count < 0) { + fprintf(stderr, "gnutls_record_recv failed: %s (%zd)\n", + gnutls_strerror_name(count), count); + return -1; + } else if (count == 0) { + break; + } + if (proxy->upstream_state == STATE_CONNECTED) { + bool upstream_pending_is_empty = (proxy->upstream_pending.len == 0); + push_to_upstream_pending(proxy, tls->recv_buf, count); + if (upstream_pending_is_empty) { + write_to_upstream_pending(proxy); + } + } else if (proxy->upstream_state == STATE_NOT_CONNECTED) { + /* TODO avoid allocation */ + uv_tcp_init(proxy->loop, &proxy->upstream); + uv_connect_t *conn = (uv_connect_t *) malloc(sizeof(uv_connect_t)); + proxy->upstream_state = STATE_CONNECT_IN_PROGRESS; + fprintf(stdout, "connecting to %s\n", + ip_straddr((struct sockaddr *)&proxy->upstream_addr)); + uv_tcp_connect(conn, &proxy->upstream, (struct sockaddr *)&proxy->upstream_addr, + on_connect_to_upstream); + push_to_upstream_pending(proxy, tls->recv_buf, count); + } else if (proxy->upstream_state == STATE_CONNECT_IN_PROGRESS) { + push_to_upstream_pending(proxy, tls->recv_buf, count); + } + submitted += count; + } + return submitted; +} + +struct tls_proxy_ctx *tls_proxy_allocate() +{ + return malloc(sizeof(struct tls_proxy_ctx)); +} + +int tls_proxy_init(struct tls_proxy_ctx *proxy, + const char *server_addr, int server_port, + const char *upstream_addr, int upstream_port, + const char *cert_file, const char *key_file) +{ + proxy->loop = uv_default_loop(); + uv_tcp_init(proxy->loop, &proxy->server); + int res = uv_ip4_addr(server_addr, server_port, (struct sockaddr_in *)&proxy->server_addr); + if (res != 0) { + fprintf(stderr, "uv_ip4_addr failed with string '%s'\n", server_addr); + return -1; + } + res = uv_ip4_addr(upstream_addr, upstream_port, (struct sockaddr_in *)&proxy->upstream_addr); + if (res != 0) { + fprintf(stderr, "uv_ip4_addr failed with string '%s'\n", upstream_addr); + return -1; + } + array_init(proxy->buffer_pool); + array_init(proxy->upstream_pending); + array_init(proxy->client_pending); + proxy->server_state = STATE_NOT_CONNECTED; + proxy->client_state = STATE_NOT_CONNECTED; + proxy->upstream_state = STATE_NOT_CONNECTED; + + proxy->loop->data = proxy; + + int err = 0; + if (gnutls_references == 0) { + err = gnutls_global_init(); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stderr, "gnutls_global_init() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + return -1; + } + } + gnutls_references += 1; + + err = gnutls_certificate_allocate_credentials(&proxy->tls.credentials); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stderr, "gnutls_certificate_allocate_credentials() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + return -1; + } + + err = gnutls_certificate_set_x509_system_trust(proxy->tls.credentials); + if (err <= 0) { + fprintf(stderr, "gnutls_certificate_set_x509_system_trust() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + return -1; + } + + if (cert_file && key_file) { + err = gnutls_certificate_set_x509_key_file(proxy->tls.credentials, + cert_file, key_file, GNUTLS_X509_FMT_PEM); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stderr, "gnutls_certificate_set_x509_key_file() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + return -1; + } + } + + err = gnutls_priority_init(&proxy->tls.priority_cache, NULL, NULL); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stderr, "gnutls_priority_init() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + return -1; + } + + + proxy->tls.handshake_state = TLS_HS_NOT_STARTED; + return 0; +} + +void tls_proxy_free(struct tls_proxy_ctx *proxy) +{ + if (!proxy) { + return; + } + clear_upstream_pending(proxy); + clear_client_pending(proxy); + clear_buffer_pool(proxy); + gnutls_certificate_free_credentials(proxy->tls.credentials); + gnutls_priority_deinit(proxy->tls.priority_cache); + /* TODO correctly close all the uv_tcp_t */ + free(proxy); + + gnutls_references -= 1; + if (gnutls_references == 0) { + gnutls_global_deinit(); + } +} + +int tls_proxy_start_listen(struct tls_proxy_ctx *proxy) +{ + uv_tcp_bind(&proxy->server, (const struct sockaddr*)&proxy->server_addr, 0); + int ret = uv_listen((uv_stream_t*)&proxy->server, 128, on_client_connection); + if (ret == 0) { + proxy->server_state = STATE_LISTENING; + } + return ret; +} + +int tls_proxy_run(struct tls_proxy_ctx *proxy) +{ + return uv_run(proxy->loop, UV_RUN_DEFAULT); +} diff --git a/tests/pytests/rehandshake/tls-proxy.h b/tests/pytests/rehandshake/tls-proxy.h new file mode 100644 index 0000000..1204eda --- /dev/null +++ b/tests/pytests/rehandshake/tls-proxy.h @@ -0,0 +1,14 @@ +#pragma once + +struct tls_proxy_ctx; + +struct tls_proxy_ctx *tls_proxy_allocate(); +void tls_proxy_free(struct tls_proxy_ctx *proxy); +int tls_proxy_init(struct tls_proxy_ctx *proxy, + const char *server_addr, int server_port, + const char *upstream_addr, int upstream_port, + const char *cert_file, const char *key_file); +int tls_proxy_start_listen(struct tls_proxy_ctx *proxy); +int tls_proxy_run(struct tls_proxy_ctx *proxy); + + diff --git a/tests/pytests/rehandshake/tlsproxy.c b/tests/pytests/rehandshake/tlsproxy.c new file mode 100644 index 0000000..0c074f1 --- /dev/null +++ b/tests/pytests/rehandshake/tlsproxy.c @@ -0,0 +1,31 @@ +#include +#include "tls-proxy.h" +#include + +int main() +{ + struct tls_proxy_ctx *proxy = tls_proxy_allocate(); + if (!proxy) { + fprintf(stderr, "can't allocate tls_proxy structure\n"); + return 1; + } + int res = tls_proxy_init(proxy, + "127.0.0.1", 53921, /* Address to listen */ + "127.0.0.1", 53910, /* Upstream address */ + "../certs/tt.cert.pem", + "../certs/tt.key.pem"); + if (res) { + fprintf(stderr, "can't initialize tls_proxy structure\n"); + return res; + } + res = tls_proxy_start_listen(proxy); + if (res) { + fprintf(stderr, "error starting listen, error code: %i\n", res); + return res; + } + fprintf(stdout, "started...\n"); + res = tls_proxy_run(proxy); + tls_proxy_free(proxy); + return res; +} + diff --git a/tests/pytests/requirements.txt b/tests/pytests/requirements.txt new file mode 100644 index 0000000..6e2e4d2 --- /dev/null +++ b/tests/pytests/requirements.txt @@ -0,0 +1,5 @@ +dnspython +jinja2 +pytest +pytest-html +pytest-xdist diff --git a/tests/pytests/templates/kresd.conf.j2 b/tests/pytests/templates/kresd.conf.j2 new file mode 100644 index 0000000..4d95521 --- /dev/null +++ b/tests/pytests/templates/kresd.conf.j2 @@ -0,0 +1,42 @@ +modules = { + 'policy', + 'hints > iterate', +} + +verbose({{ 'true' if kresd.verbose else 'false' }}) + +{% if kresd.ip %} +net.listen('{{ kresd.ip }}', {{ kresd.port }}) +net.listen('{{ kresd.ip }}', {{ kresd.tls_port }}, {tls = true}) +{% endif %} + +{% if kresd.ip6 %} +net.listen('{{ kresd.ip6 }}', {{ kresd.port }}) +net.listen('{{ kresd.ip6 }}', {{ kresd.tls_port }}, {tls = true}) +{% endif %} + +net.ipv4=true +net.ipv6=true + +{% if kresd.tls_key_path and kresd.tls_cert_path %} +net.tls("{{ kresd.tls_cert_path }}", "{{ kresd.tls_key_path }}") +{% endif %} + +{% for name, ip in kresd.hints.items() %} +hints['{{ name }}'] = '{{ ip }}' +{% endfor %} + +policy.add(policy.all(policy.QTRACE)) + +{% if kresd.forward %} +policy.add(policy.all( + {% if kresd.forward.proto == 'tls' %} + policy.TLS_FORWARD({ + {"{{ kresd.forward.ip }}@{{ kresd.forward.port }}", hostname='{{ kresd.forward.hostname}}', ca_file='{{ kresd.forward.ca_file }}'}}) + {% endif %} +)) +{% endif %} + +modules.unload("ta_signal_query") +modules.unload("priming") +modules.unload("detect_time_skew") diff --git a/tests/pytests/test_conn_mgmt.py b/tests/pytests/test_conn_mgmt.py new file mode 100644 index 0000000..c4b1cba --- /dev/null +++ b/tests/pytests/test_conn_mgmt.py @@ -0,0 +1,213 @@ +"""TCP Connection Management tests""" + +import socket +import struct +import time + +import pytest + +import utils + + +@pytest.mark.parametrize('garbage_lengths', [ + (1,), + (1024,), + (65533,), # max size garbage + (65533, 65533), + (1024, 1024, 1024), + # (0,), # currently kresd uses this as a heuristic of "lost in bytestream" + # (0, 1024), # and closes the connection +]) +def test_ignore_garbage(kresd_sock, garbage_lengths, single_buffer, query_before): + """Send chunk of garbage, prefixed by garbage length. It should be ignored.""" + buff = b'' + if query_before: # optionally send initial query + msg_buff_before, msgid_before = utils.get_msgbuff() + if single_buffer: + buff += msg_buff_before + else: + kresd_sock.sendall(msg_buff_before) + + for glength in garbage_lengths: # prepare garbage data + if glength is None: + continue + garbage_buff = utils.get_prefixed_garbage(glength) + if single_buffer: + buff += garbage_buff + else: + kresd_sock.sendall(garbage_buff) + + msg_buff, msgid = utils.get_msgbuff() # final query + buff += msg_buff + kresd_sock.sendall(buff) + + if query_before: + answer_before = utils.receive_parse_answer(kresd_sock) + assert answer_before.id == msgid_before + answer = utils.receive_parse_answer(kresd_sock) + assert answer.id == msgid + + +def test_pipelining(kresd_sock): + """ + First query takes longer to resolve - answer to second query should arrive sooner. + + This test requires internet connection. + """ + # initialization (to avoid issues with net.ipv6=true) + buff_pre, msgid_pre = utils.get_msgbuff('0.delay.getdnsapi.net.') + kresd_sock.sendall(buff_pre) + msg_answer = utils.receive_parse_answer(kresd_sock) + assert msg_answer.id == msgid_pre + + # test + buff1, msgid1 = utils.get_msgbuff('1500.delay.getdnsapi.net.', msgid=1) + buff2, msgid2 = utils.get_msgbuff('1.delay.getdnsapi.net.', msgid=2) + buff = buff1 + buff2 + kresd_sock.sendall(buff) + + msg_answer = utils.receive_parse_answer(kresd_sock) + assert msg_answer.id == msgid2 + + msg_answer = utils.receive_parse_answer(kresd_sock) + assert msg_answer.id == msgid1 + + +@pytest.mark.parametrize('duration, delay', [ + (utils.MAX_TIMEOUT, 0.1), + (utils.MAX_TIMEOUT, 3), + (utils.MAX_TIMEOUT, 7), + (utils.MAX_TIMEOUT + 10, 3), +]) +def test_long_lived(kresd_sock, duration, delay): + """Establish and keep connection alive for longer than maximum timeout.""" + utils.ping_alive(kresd_sock) + end_time = time.time() + duration + + while time.time() < end_time: + time.sleep(delay) + utils.ping_alive(kresd_sock) + + +def test_close(kresd_sock, query_before): + """Establish a connection and wait for timeout from kresd.""" + if query_before: + utils.ping_alive(kresd_sock) + time.sleep(utils.MAX_TIMEOUT) + + with utils.expect_kresd_close(): + utils.ping_alive(kresd_sock) + + +def test_slow_lorris(kresd_sock, query_before): + """Simulate slow-lorris attack by sending byte after byte with delays in between.""" + if query_before: + utils.ping_alive(kresd_sock) + + buff, _ = utils.get_msgbuff() + end_time = time.time() + utils.MAX_TIMEOUT + + with utils.expect_kresd_close(): + for i in range(len(buff)): + b = buff[i:i+1] + kresd_sock.send(b) + if time.time() > end_time: + break + time.sleep(1) + + +@pytest.mark.parametrize('sock_func_name', [ + 'ip_tcp_socket', + 'ip6_tcp_socket', +]) +def test_oob(kresd, sock_func_name): + """TCP out-of-band (urgent) data must not crash resolver.""" + make_sock = getattr(kresd, sock_func_name) + sock = make_sock() + msg_buff, msgid = utils.get_msgbuff() + sock.sendall(msg_buff, socket.MSG_OOB) + + try: + msg_answer = utils.receive_parse_answer(sock) + assert msg_answer.id == msgid + except ConnectionError: + pass # TODO kresd responds with TCP RST, this should be fixed + + # check kresd is alive + sock2 = make_sock() + utils.ping_alive(sock2) + + +def flood_buffer(msgcount): + flood_buff = bytes() + msgbuff, _ = utils.get_msgbuff() + noid_msgbuff = msgbuff[2:] + + def gen_msg(msgid): + return struct.pack("!H", len(msgbuff)) + struct.pack("!H", msgid) + noid_msgbuff + + for i in range(msgcount): + flood_buff += gen_msg(i) + return flood_buff + + +def test_query_flood_close(make_kresd_sock): + """Flood resolver with queries and close the connection.""" + buff = flood_buffer(10000) + sock1 = make_kresd_sock() + sock1.sendall(buff) + sock1.close() + + sock2 = make_kresd_sock() + utils.ping_alive(sock2) + + +def test_query_flood_no_recv(make_kresd_sock): + """Flood resolver with queries but don't read any data.""" + # A use-case for TCP_USER_TIMEOUT socket option? See RFC 793 and RFC 5482 + + # It seems it doesn't works as expected. libuv doesn't return any error + # (neither on uv_write() call, not in the callback) when kresd sends answers, + # so kresd can't recognize that client didn't read any answers. At a certain + # point, kresd stops receiving queries from the client (whilst client keep + # sending) and closes connection due to timeout. + + buff = flood_buffer(10000) + sock1 = make_kresd_sock() + end_time = time.time() + utils.MAX_TIMEOUT + + with utils.expect_kresd_close(rst_ok=True): # connection must be closed + while time.time() < end_time: + sock1.sendall(buff) + time.sleep(0.5) + + sock2 = make_kresd_sock() + utils.ping_alive(sock2) # resolver must stay alive + + +@pytest.mark.parametrize('glength, gcount, delay', [ + (65533, 100, 0.5), + (0, 100000, 0.5), + (1024, 1000, 0.5), + (65533, 1, 0), + (0, 1, 0), + (1024, 1, 0), +]) +def test_query_flood_garbage(make_kresd_sock, glength, gcount, delay, query_before): + """Flood resolver with prefixed garbage.""" + sock1 = make_kresd_sock() + if query_before: + utils.ping_alive(sock1) + + gbuff = utils.get_prefixed_garbage(glength) + buff = gbuff * gcount + + end_time = time.time() + utils.MAX_TIMEOUT + + with utils.expect_kresd_close(rst_ok=True): # connection must be closed + while time.time() < end_time: + sock1.sendall(buff) + time.sleep(delay) + + sock2 = make_kresd_sock() + utils.ping_alive(sock2) # resolver must stay alive diff --git a/tests/pytests/test_prefix.py b/tests/pytests/test_prefix.py new file mode 100644 index 0000000..5d8ef16 --- /dev/null +++ b/tests/pytests/test_prefix.py @@ -0,0 +1,113 @@ +"""TCP Connection Management tests - prefix length + +RFC1035 +4.2.2. TCP usage +The message is prefixed with a two byte length field which gives the message +length, excluding the two byte length field. + +The following test suite focuses on edge cases for the prefix - when it +is either too short or too long, instead of matching the length of DNS +message exactly. +""" + +import time + +import pytest + +import utils + + +@pytest.fixture(params=[ + 'no_query_before', + 'query_before', + 'query_before_in_single_buffer', +]) +def send_query(request): + """Function sends a buffer, either by itself, or with a valid query before. + If a valid query is sent before, it can be sent either in a separate buffer, or + along with the provided buffer.""" + + # pylint: disable=possibly-unused-variable + + def no_query_before(sock, buff): # pylint: disable=unused-argument + sock.sendall(buff) + + def query_before(sock, buff, single_buffer=False): + """Send an initial query and expect a response.""" + msg_buff, msgid = utils.get_msgbuff() + + if single_buffer: + sock.sendall(msg_buff + buff) + else: + sock.sendall(msg_buff) + sock.sendall(buff) + + answer = utils.receive_parse_answer(sock) + assert answer.id == msgid + + def query_before_in_single_buffer(sock, buff): + return query_before(sock, buff, single_buffer=True) + + return locals()[request.param] + + +@pytest.mark.parametrize('datalen', [ + 1, # just one byte of DNS header + 11, # DNS header size minus 1 + 14, # DNS Header size plus 2 +]) +def test_prefix_cuts_message(kresd_sock, datalen, send_query): + """Prefix is shorter than the DNS message.""" + wire, _ = utils.prepare_wire() + assert datalen < len(wire) + invalid_buff = utils.prepare_buffer(wire, datalen) + + send_query(kresd_sock, invalid_buff) # buffer breaks parsing of TCP stream + + with utils.expect_kresd_close(): + utils.ping_alive(kresd_sock) + + +def test_prefix_greater_than_message(kresd_sock, send_query): + """Prefix is greater than the length of the entire DNS message.""" + wire, invalid_msgid = utils.prepare_wire() + datalen = len(wire) + 16 + invalid_buff = utils.prepare_buffer(wire, datalen) + + send_query(kresd_sock, invalid_buff) + + valid_buff, _ = utils.get_msgbuff() + kresd_sock.sendall(valid_buff) + + # invalid_buff is answered (treats additional data as trailing garbage) + answer = utils.receive_parse_answer(kresd_sock) + assert answer.id == invalid_msgid + + # parsing stream is broken by the invalid_buff, valid query is never answered + with utils.expect_kresd_close(): + utils.receive_parse_answer(kresd_sock) + + +@pytest.mark.parametrize('glength', [ + 0, + 1, + 8, + 1024, + 4096, + 20000, +]) +def test_prefix_trailing_garbage(kresd_sock, glength, query_before): + """Send messages with trailing garbage (its length included in prefix).""" + if query_before: + utils.ping_alive(kresd_sock) + + for _ in range(10): + wire, msgid = utils.prepare_wire() + wire += utils.get_garbage(glength) + buff = utils.prepare_buffer(wire) + + kresd_sock.sendall(buff) + answer = utils.receive_parse_answer(kresd_sock) + assert answer.id == msgid + + time.sleep(0.1) diff --git a/tests/pytests/test_rehandshake.py b/tests/pytests/test_rehandshake.py new file mode 100644 index 0000000..ffbc10b --- /dev/null +++ b/tests/pytests/test_rehandshake.py @@ -0,0 +1,87 @@ +"""TLS rehandshake test + +Test utilizes rehandshake/tls-proxy, which forwards queries to configured +resolver, but when it sends the response back to the query source, it +performs a rehandshake after every 8 bytes sent. + +It is expected the answer will be received by the source kresd instance +and sent back to the client (this test). + +Make sure to run `make all` in `rehandshake/` to compile the proxy. +""" + +import os +import re +import subprocess +import time + +import dns +import dns.rcode +import pytest + +from kresd import CERTS_DIR, Forward, make_kresd, PYTESTS_DIR +import utils + + +REHANDSHAKE_PROXY = os.path.join(PYTESTS_DIR, 'rehandshake', 'tlsproxy') + + +@pytest.mark.skipif(not os.path.exists(REHANDSHAKE_PROXY), + reason="tlsproxy not found (did you compile it?)") +def test_rehandshake(tmpdir): + def resolve_hint(sock, qname): + buff, msgid = utils.get_msgbuff(qname) + sock.sendall(buff) + answer = utils.receive_parse_answer(sock) + assert answer.id == msgid + assert answer.rcode() == dns.rcode.NOERROR + assert answer.answer[0][0].address == '127.0.0.1' + + hints = { + '0.foo.': '127.0.0.1', + '1.foo.': '127.0.0.1', + '2.foo.': '127.0.0.1', + '3.foo.': '127.0.0.1', + } + # run forward target instance + workdir = os.path.join(str(tmpdir), 'kresd_fwd_target') + os.makedirs(workdir) + + with make_kresd(workdir, hints=hints, port=53910) as kresd_fwd_target: + sock = kresd_fwd_target.ip_tls_socket() + resolve_hint(sock, '0.foo.') + + # run proxy + cwd, cmd = os.path.split(REHANDSHAKE_PROXY) + cmd = './' + cmd + ca_file = os.path.join(CERTS_DIR, 'tt.cert.pem') + try: + proxy = subprocess.Popen( + [cmd], cwd=cwd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + # run test kresd instance + workdir2 = os.path.join(str(tmpdir), 'kresd') + os.makedirs(workdir2) + forward = Forward(proto='tls', ip='127.0.0.1', port=53921, + hostname='transport-test-server.com', ca_file=ca_file) + with make_kresd(workdir2, forward=forward) as kresd: + sock2 = kresd.ip_tcp_socket() + try: + for hint in hints: + resolve_hint(sock2, hint) + time.sleep(0.1) + finally: + # verify log + n_connecting_to = 0 + n_rehandshake = 0 + partial_log = kresd.partial_log() + print(partial_log) + for line in partial_log.splitlines(): + if re.search(r"connecting to: .*", line) is not None: + n_connecting_to += 1 + elif re.search(r"TLS rehandshake .* has started", line) is not None: + n_rehandshake += 1 + assert n_connecting_to == 0 # shouldn't be present in partial log + assert n_rehandshake > 0 + finally: + proxy.terminate() diff --git a/tests/pytests/test_tls.py b/tests/pytests/test_tls.py new file mode 100644 index 0000000..361741d --- /dev/null +++ b/tests/pytests/test_tls.py @@ -0,0 +1,77 @@ +"""TLS-specific tests""" + +import itertools +import os +from socket import AF_INET, AF_INET6 +import ssl +import sys + +import pytest + +from kresd import make_kresd +import utils + + +def test_tls_no_cert(kresd, sock_family): + """Use TLS without certificates.""" + sock, dest = kresd.stream_socket(sock_family, tls=True) + ctx = utils.make_ssl_context(insecure=True) + ssock = ctx.wrap_socket(sock) + ssock.connect(dest) + + utils.ping_alive(ssock) + + +def test_tls_selfsigned_cert(kresd_tt, sock_family): + """Use TLS with a self signed certificate.""" + sock, dest = kresd_tt.stream_socket(sock_family, tls=True) + ctx = utils.make_ssl_context(verify_location=kresd_tt.tls_cert_path) + ssock = ctx.wrap_socket(sock, server_hostname='transport-test-server.com') + ssock.connect(dest) + + utils.ping_alive(ssock) + + +def test_tls_cert_hostname_mismatch(kresd_tt, sock_family): + """Attempt to use self signed certificate and incorrect hostname.""" + sock, dest = kresd_tt.stream_socket(sock_family, tls=True) + ctx = utils.make_ssl_context(verify_location=kresd_tt.tls_cert_path) + ssock = ctx.wrap_socket(sock, server_hostname='wrong-host-name') + + with pytest.raises(ssl.CertificateError): + ssock.connect(dest) + + +@pytest.mark.skipif(sys.version_info < (3, 6), + reason="requires python3.6 or higher") +@pytest.mark.parametrize('sf1, sf2, sf3', itertools.product( + [AF_INET, AF_INET6], [AF_INET, AF_INET6], [AF_INET, AF_INET6])) +def test_tls_session_resumption(tmpdir, sf1, sf2, sf3): + """Attempt TLS session resumption against the same kresd instance and a different one.""" + # TODO ensure that session can't be resumed after session ticket key regeneration + # at the first kresd instance + + def connect(kresd, ctx, sf, session=None): + sock, dest = kresd.stream_socket(sf, tls=True) + ssock = ctx.wrap_socket( + sock, server_hostname='transport-test-server.com', session=session) + ssock.connect(dest) + new_session = ssock.session + assert new_session.has_ticket + assert ssock.session_reused == (session is not None) + utils.ping_alive(ssock) + ssock.close() + return new_session + + workdir = os.path.join(str(tmpdir), 'kresd') + os.makedirs(workdir) + + with make_kresd(workdir, 'tt') as kresd: + ctx = utils.make_ssl_context(verify_location=kresd.tls_cert_path) + session = connect(kresd, ctx, sf1) # initial conn + connect(kresd, ctx, sf2, session) # resume session on the same instance + + workdir2 = os.path.join(str(tmpdir), 'kresd2') + os.makedirs(workdir2) + with make_kresd(workdir2, 'tt') as kresd2: + connect(kresd2, ctx, sf3, session) # resume session on a different instance diff --git a/tests/pytests/utils.py b/tests/pytests/utils.py new file mode 100644 index 0000000..dcdc14c --- /dev/null +++ b/tests/pytests/utils.py @@ -0,0 +1,131 @@ +from contextlib import contextmanager +import random +import ssl +import struct +import time + +import dns +import dns.message +import pytest + + +# default net.tcp_in_idle is 10s, TCP_DEFER_ACCEPT 3s, some extra for +# Python handling / edge cases +MAX_TIMEOUT = 16 + + +def receive_answer(sock): + answer_total_len = 0 + data = sock.recv(2) + if not data: + return None + answer_total_len = struct.unpack_from("!H", data)[0] + + answer_received_len = 0 + data_answer = b'' + while answer_received_len < answer_total_len: + data_chunk = sock.recv(answer_total_len - answer_received_len) + if not data_chunk: + return None + data_answer += data_chunk + answer_received_len += len(data_answer) + + return data_answer + + +def receive_parse_answer(sock): + data_answer = receive_answer(sock) + + if data_answer is None: + raise BrokenPipeError("kresd closed connection") + + msg_answer = dns.message.from_wire(data_answer, one_rr_per_rrset=True) + return msg_answer + + +def prepare_wire( + qname='localhost.', + qtype=dns.rdatatype.A, + qclass=dns.rdataclass.IN, + msgid=None): + """Utility function to generate DNS wire format message""" + msg = dns.message.make_query(qname, qtype, qclass) + if msgid is not None: + msg.id = msgid + return msg.to_wire(), msg.id + + +def prepare_buffer(wire, datalen=None): + """Utility function to prepare TCP buffer from DNS message in wire format""" + assert isinstance(wire, bytes) + if datalen is None: + datalen = len(wire) + return struct.pack("!H", datalen) + wire + + +def get_msgbuff(qname='localhost.', qtype=dns.rdatatype.A, msgid=None): + wire, msgid = prepare_wire(qname, qtype, msgid=msgid) + buff = prepare_buffer(wire) + return buff, msgid + + +def get_garbage(length): + return bytes(random.getrandbits(8) for _ in range(length)) + + +def get_prefixed_garbage(length): + data = get_garbage(length) + return prepare_buffer(data) + + +def try_ping_alive(sock, msgid=None, close=False): + try: + ping_alive(sock, msgid) + except AssertionError: + return False + finally: + if close: + sock.close() + return True + + +def ping_alive(sock, msgid=None): + buff, msgid = get_msgbuff(msgid=msgid) + sock.sendall(buff) + answer = receive_parse_answer(sock) + assert answer.id == msgid + + +@contextmanager +def expect_kresd_close(rst_ok=False): + with pytest.raises(BrokenPipeError, message="kresd didn't close the connection"): + try: + time.sleep(0.2) # give kresd time to close connection with TCP FIN + yield + except ConnectionResetError: + if rst_ok: + raise BrokenPipeError + else: + pytest.skip("kresd closed connection with TCP RST") + + +def make_ssl_context(insecure=False, verify_location=None): + # set TLS v1.2+ + context = ssl.SSLContext(ssl.PROTOCOL_TLS) + context.options |= ssl.OP_NO_SSLv2 + context.options |= ssl.OP_NO_SSLv3 + context.options |= ssl.OP_NO_TLSv1 + context.options |= ssl.OP_NO_TLSv1_1 + + if insecure: + # turn off certificate verification + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + else: + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = True + + if verify_location is not None: + context.load_verify_locations(verify_location) + + return context -- cgit v1.2.3