summaryrefslogtreecommitdiffstats
path: root/tests/pytests
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/pytests/README.rst54
-rwxr-xr-xtests/pytests/certs/tt-certgen-expired.sh18
-rwxr-xr-xtests/pytests/certs/tt-certgen.sh4
-rw-r--r--tests/pytests/certs/tt-expired.cert.pem80
-rw-r--r--tests/pytests/certs/tt-expired.key.pem27
-rw-r--r--tests/pytests/certs/tt.cert.pem22
-rw-r--r--tests/pytests/certs/tt.conf353
-rw-r--r--tests/pytests/certs/tt.key.pem28
-rw-r--r--tests/pytests/conftest.py78
-rw-r--r--tests/pytests/conn_flood.py83
-rw-r--r--tests/pytests/kresd.py305
-rw-r--r--tests/pytests/pylintrc29
-rw-r--r--tests/pytests/rehandshake/Makefile28
-rw-r--r--tests/pytests/rehandshake/array.h166
-rw-r--r--tests/pytests/rehandshake/tcp-proxy.c336
-rw-r--r--tests/pytests/rehandshake/tcp-proxy.h12
-rw-r--r--tests/pytests/rehandshake/tcproxy.c25
-rw-r--r--tests/pytests/rehandshake/tls-proxy.c848
-rw-r--r--tests/pytests/rehandshake/tls-proxy.h14
-rw-r--r--tests/pytests/rehandshake/tlsproxy.c31
-rw-r--r--tests/pytests/requirements.txt5
-rw-r--r--tests/pytests/templates/kresd.conf.j242
-rw-r--r--tests/pytests/test_conn_mgmt.py213
-rw-r--r--tests/pytests/test_prefix.py113
-rw-r--r--tests/pytests/test_rehandshake.py87
-rw-r--r--tests/pytests/test_tls.py77
-rw-r--r--tests/pytests/utils.py131
27 files changed, 3209 insertions, 0 deletions
diff --git a/tests/pytests/README.rst b/tests/pytests/README.rst
new file mode 100644
index 0000000..9a11ccd
--- /dev/null
+++ b/tests/pytests/README.rst
@@ -0,0 +1,54 @@
+Python client tests for kresd
+=============================
+
+The tests run `/usr/bin/env kresd` (can be modified with `$PATH`) with custom config
+and execute client-side testing, such as TCP / TLS connection management.
+
+Requirements
+------------
+
+- pip3 install -r requirements.txt
+
+Executing tests
+---------------
+
+Tests can be executed with the pytest framework.
+
+.. code-block:: bash
+
+ $ pytest-3 # sequential, all tests (with exception of few special tests)
+ $ pytest-3 test_conn_mgmt.py::test_ignore_garbage # specific test only
+ $ pytest-3 --html pytests.html --self-contained-html # html report
+
+It's highly recommended to run these tests in parallel, since lot of them
+wait for kresd timeout. This can be done with `python-xdist`:
+
+.. code-block:: bash
+
+ $ pytest-3 -n 24 # parallel with 24 jobs
+
+Each test spawns an independent kresd instance, so test failures shouldn't affect
+each other.
+
+Some tests are omitted from automatic test collection by default, due to their
+resource contraints. These typicially have to be executed separately by providing
+the path to test file directly.
+
+.. code-block:: bash
+
+ $ pytest-3 conn_flood.py
+
+Note: some tests may fail without an internet connection.
+
+Developer notes
+---------------
+
+Typically, each test requires a setup of kresd, and a connected socket to run tests on.
+The framework provides a few useful pytest fixtures to simplify this process:
+
+- `kresd_sock` provides a connected socket to a test-specific, running kresd instance.
+ It expands to 4 values (tests) - IPv4 TCP, IPv6 TCP, IPv4 TLS, IPv6 TLS sockets
+- `make_kresd_sock` is similar to `kresd_sock`, except it's a factory function that
+ produces a new connected socket (of the same type) on each call
+- `kresd`, `kresd_tt` are all Kresd instances, already running
+ and initialized with config (with no / valid TLS certificates)
diff --git a/tests/pytests/certs/tt-certgen-expired.sh b/tests/pytests/certs/tt-certgen-expired.sh
new file mode 100755
index 0000000..7750291
--- /dev/null
+++ b/tests/pytests/certs/tt-certgen-expired.sh
@@ -0,0 +1,18 @@
+# !/bin/bash
+
+if [ ! -d ./demoCA ]; then
+ mkdir ./demoCA
+fi
+if [ ! -d ./demoCA/newcerts ]; then
+ mkdir ./demoCA/newcerts
+fi
+touch ./demoCA/index.txt
+touch ./demoCA/index.txt.attr
+if [ ! -f ./demoCA/serial ]; then
+ echo 01 > ./demoCA/serial
+fi
+
+openssl genrsa -out tt-expired.key.pem 2048
+openssl req -config tt.conf -new -key tt-expired.key.pem -out tt-expired.csr.pem
+openssl ca -config tt.conf -selfsign -keyfile tt-expired.key.pem -out tt-expired.cert.pem -in tt-expired.csr.pem -startdate 19700101000000Z -enddate 19700101000000Z
+
diff --git a/tests/pytests/certs/tt-certgen.sh b/tests/pytests/certs/tt-certgen.sh
new file mode 100755
index 0000000..b6b3d7f
--- /dev/null
+++ b/tests/pytests/certs/tt-certgen.sh
@@ -0,0 +1,4 @@
+# !/bin/sh
+
+openssl req -config tt.conf -new -x509 -newkey rsa:2048 -nodes -keyout tt.key.pem -sha256 -out tt.cert.pem -days 20000
+
diff --git a/tests/pytests/certs/tt-expired.cert.pem b/tests/pytests/certs/tt-expired.cert.pem
new file mode 100644
index 0000000..c9f8c09
--- /dev/null
+++ b/tests/pytests/certs/tt-expired.cert.pem
@@ -0,0 +1,80 @@
+Certificate:
+ Data:
+ Version: 3 (0x2)
+ Serial Number: 1 (0x1)
+ Signature Algorithm: sha256WithRSAEncryption
+ Issuer: C=CZ, ST=PRAGUE, CN=transport-test-server.com
+ Validity
+ Not Before: Jan 1 00:00:00 1970 GMT
+ Not After : Jan 1 00:00:00 1970 GMT
+ Subject: C=CZ, ST=PRAGUE, CN=transport-test-server.com
+ Subject Public Key Info:
+ Public Key Algorithm: rsaEncryption
+ Public-Key: (2048 bit)
+ Modulus:
+ 00:bf:6b:1a:11:47:01:ac:eb:5c:2d:cf:ce:6a:a4:
+ 00:ce:2f:d1:25:03:5f:06:38:02:92:24:18:92:2a:
+ 69:19:b2:2b:a3:4f:f7:79:de:35:c3:f5:72:37:83:
+ 44:93:f9:76:fc:89:29:32:9c:0d:4b:95:7d:d1:5d:
+ 40:e9:ba:49:50:7d:c6:0a:c8:1e:e7:90:1e:37:7c:
+ 0b:23:a3:e3:bc:c9:53:81:de:d6:5f:cb:b2:3d:36:
+ ac:59:b0:33:91:8f:0c:5f:10:20:70:bf:a3:22:b3:
+ 98:ac:d4:7a:ea:67:b8:b1:8c:cf:e5:fe:8f:a0:a5:
+ 02:ad:6d:ce:f1:62:ab:dc:5d:96:9c:4f:95:47:d5:
+ 82:b7:b3:e3:87:4c:8d:38:85:2a:24:9d:7f:c7:a4:
+ 0e:bd:8a:2d:6b:d2:d4:e8:78:62:1b:aa:25:5f:5a:
+ 64:e5:76:23:ae:11:03:9a:5c:ed:a2:ba:51:ec:b1:
+ f3:ae:ba:5c:eb:dd:49:63:ca:c7:af:0c:16:1d:94:
+ 95:3a:ce:2c:8f:e2:94:7f:1f:a1:76:e2:9f:d1:41:
+ 31:f0:68:e5:ae:df:d0:75:a0:34:f5:25:93:85:b3:
+ 25:50:42:6c:00:c0:fe:3b:e0:fb:00:de:75:33:86:
+ 6a:21:35:14:9d:7f:4a:af:f7:15:f2:d7:bb:2f:de:
+ df:ab
+ Exponent: 65537 (0x10001)
+ X509v3 extensions:
+ X509v3 Basic Constraints:
+ CA:FALSE
+ Netscape Comment:
+ OpenSSL Generated Certificate
+ X509v3 Subject Key Identifier:
+ B3:42:0A:9A:00:19:CB:CB:24:A0:02:45:1E:8A:B0:54:CB:9F:55:FE
+ X509v3 Authority Key Identifier:
+ keyid:B3:42:0A:9A:00:19:CB:CB:24:A0:02:45:1E:8A:B0:54:CB:9F:55:FE
+
+ Signature Algorithm: sha256WithRSAEncryption
+ 32:9a:05:e3:6f:ae:ee:b1:a2:12:0a:9f:0a:e7:78:26:df:90:
+ fb:84:60:ae:13:fc:ff:fd:42:84:23:14:c3:2e:e2:a9:df:4b:
+ 5c:2f:5b:0e:3d:f9:5a:56:50:13:bc:89:1a:08:70:dd:6c:6c:
+ e8:ae:cf:22:39:92:f2:3b:40:03:8f:4e:bc:54:88:6b:fd:8c:
+ b6:eb:30:90:21:db:fc:4e:5c:7e:12:75:e2:52:76:df:19:0f:
+ 30:49:1e:15:bc:ba:6a:e6:f7:af:93:ad:e4:36:da:47:47:a6:
+ 88:b0:ae:46:1e:91:e1:d6:b1:5e:a4:f0:68:02:81:57:86:5d:
+ 17:d1:6c:7e:7a:9f:5e:0d:fc:10:e7:7a:1a:b5:f9:4b:1d:78:
+ a4:9a:9d:d7:c2:64:c3:52:28:7f:a1:b7:25:d7:13:3f:09:7f:
+ f2:fd:dd:c6:91:eb:9b:51:80:e2:36:cb:9f:5b:4e:47:eb:77:
+ d3:cc:8b:18:b5:0b:97:a2:53:8e:fb:9b:94:7d:57:21:32:c6:
+ f3:67:93:a4:9b:eb:46:b7:cd:08:43:99:dd:c1:c3:51:b9:19:
+ ef:92:77:1c:84:67:80:67:95:ba:00:75:3d:7b:8b:ff:24:30:
+ f1:fa:6d:da:31:9d:cf:06:da:5d:04:07:14:45:8c:6b:e7:21:
+ 31:ec:7b:23
+-----BEGIN CERTIFICATE-----
+MIIDfjCCAmagAwIBAgIBATANBgkqhkiG9w0BAQsFADBCMQswCQYDVQQGEwJDWjEP
+MA0GA1UECAwGUFJBR1VFMSIwIAYDVQQDDBl0cmFuc3BvcnQtdGVzdC1zZXJ2ZXIu
+Y29tMCIYDzE5NzAwMTAxMDAwMDAwWhgPMTk3MDAxMDEwMDAwMDBaMEIxCzAJBgNV
+BAYTAkNaMQ8wDQYDVQQIDAZQUkFHVUUxIjAgBgNVBAMMGXRyYW5zcG9ydC10ZXN0
+LXNlcnZlci5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC/axoR
+RwGs61wtz85qpADOL9ElA18GOAKSJBiSKmkZsiujT/d53jXD9XI3g0ST+Xb8iSky
+nA1LlX3RXUDpuklQfcYKyB7nkB43fAsjo+O8yVOB3tZfy7I9NqxZsDORjwxfECBw
+v6Mis5is1HrqZ7ixjM/l/o+gpQKtbc7xYqvcXZacT5VH1YK3s+OHTI04hSoknX/H
+pA69ii1r0tToeGIbqiVfWmTldiOuEQOaXO2iulHssfOuulzr3UljysevDBYdlJU6
+ziyP4pR/H6F24p/RQTHwaOWu39B1oDT1JZOFsyVQQmwAwP474PsA3nUzhmohNRSd
+f0qv9xXy17sv3t+rAgMBAAGjezB5MAkGA1UdEwQCMAAwLAYJYIZIAYb4QgENBB8W
+HU9wZW5TU0wgR2VuZXJhdGVkIENlcnRpZmljYXRlMB0GA1UdDgQWBBSzQgqaABnL
+yySgAkUeirBUy59V/jAfBgNVHSMEGDAWgBSzQgqaABnLyySgAkUeirBUy59V/jAN
+BgkqhkiG9w0BAQsFAAOCAQEAMpoF42+u7rGiEgqfCud4Jt+Q+4RgrhP8//1ChCMU
+wy7iqd9LXC9bDj35WlZQE7yJGghw3Wxs6K7PIjmS8jtAA49OvFSIa/2MtuswkCHb
+/E5cfhJ14lJ23xkPMEkeFby6aub3r5Ot5DbaR0emiLCuRh6R4daxXqTwaAKBV4Zd
+F9FsfnqfXg38EOd6GrX5Sx14pJqd18Jkw1Iof6G3JdcTPwl/8v3dxpHrm1GA4jbL
+n1tOR+t308yLGLULl6JTjvublH1XITLG82eTpJvrRrfNCEOZ3cHDUbkZ75J3HIRn
+gGeVugB1PXuL/yQw8fpt2jGdzwbaXQQHFEWMa+chMex7Iw==
+-----END CERTIFICATE-----
diff --git a/tests/pytests/certs/tt-expired.key.pem b/tests/pytests/certs/tt-expired.key.pem
new file mode 100644
index 0000000..ca2988c
--- /dev/null
+++ b/tests/pytests/certs/tt-expired.key.pem
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEogIBAAKCAQEAv2saEUcBrOtcLc/OaqQAzi/RJQNfBjgCkiQYkippGbIro0/3
+ed41w/VyN4NEk/l2/IkpMpwNS5V90V1A6bpJUH3GCsge55AeN3wLI6PjvMlTgd7W
+X8uyPTasWbAzkY8MXxAgcL+jIrOYrNR66me4sYzP5f6PoKUCrW3O8WKr3F2WnE+V
+R9WCt7Pjh0yNOIUqJJ1/x6QOvYota9LU6HhiG6olX1pk5XYjrhEDmlztorpR7LHz
+rrpc691JY8rHrwwWHZSVOs4sj+KUfx+hduKf0UEx8Gjlrt/QdaA09SWThbMlUEJs
+AMD+O+D7AN51M4ZqITUUnX9Kr/cV8te7L97fqwIDAQABAoIBAEA4ytIpJKLDhHXK
+VtLom2ySFnV4oBUSDarCeYvwtrpsUL/GQJ2etCM+4kdFv2h2NjmcOzpDqSJG0aPA
+ydqhKZ/b0uojIltGuxyafZJDllDsqxvTi9EwImjvQvwEZgjcGaZ7Xqb1ZOJrpzm1
+QFgM3KaVO9tKgR3Avxk40kmidU7FctFi5IELwnH/RR1OHvJbxOE4+i0LlDx0QzhX
+QHtnvHLqLLdqsFk8KvuVuVj1FwqJ6cSL0JrAdt7dnGmXBo4PDqT8Hj0AjM+CcNrV
+1D6Li9xr4y55EZUK2qU/FVDC3LqlYQy5mBfasJAXPQG4RgSVFxJ929HC7gi8vMCO
+UMeLniECgYEA6gBoRwzQ5pJUXfZGW41lJt08utfycGZm7VrA81r0x0F+DcuZ2t6J
+kB9Wnp/MNpB4DJLbl7oM2OlFOO3cw0n3VaFpNMPHVHzNbyi1hp94AIIeDz/sxfUI
+Lx7ynAQSPPQzDRfVJesT8waBdweA71TBOlrFQ2Cp7O4Qf+p0akQSv3cCgYEA0Wnd
+1Gbierv2m6Jnblg+brTMQwbRsOAM2n0V4Gd2kRaLSYd23ebshvx8xTWipRlrb5vP
+UEh+LkfuscqaJDCrikasht9z5FJtfIzHKgTrLSoR3MJRjrnuLJWTQUwSqzd0UNN6
+HigV6p+CqesNnELErak53IMfmkHAhTSkII8R9m0CgYBRY+DhTaDfgegcYouoTm7v
+bKYx6uillciZKCbSvkFDiREaJUYXba31ViEfvT8ff3JyFSaSCKFtVP3BxmIx/ukr
+fKAGPU54oYwm7Mbu00q/CoMAFOD7HbZCBYanI3dggiO7mx2FOdXPguTHDPIYzKcE
+8AuK2vVftpJAm8DwMUtAEwKBgH/eRc5ZGDdbKGS10LQm+9A7Y3IV6to2pIKQ2FfS
+tSo4espmBeXPCGQQLdt5OZvYHqril77s1OdLkutKy74HXecr6lLchHZJAoOHrmDw
+6e0FAC0tFgGxdEYS+vxnCAs17DciOjHJxkAiL/WzCfd9KXzklOkZw6U8OuLbVtBu
+q8gtAoGAbl03XZm+SHrO7XjHK/Fe5YD13cOirg48htvjbpqEDZNQr3l0eVnEj074
+IopDa/wUFlaaqPZ/DVFctqSocyskWIP4u9HfmsNBHjK5zQlge7B1fNVao++YKund
+qnVnXjWQuF2aL8k2geFxdSKmHTF4/N1qEyeyR+tMaFpGfMZOuM8=
+-----END RSA PRIVATE KEY-----
diff --git a/tests/pytests/certs/tt.cert.pem b/tests/pytests/certs/tt.cert.pem
new file mode 100644
index 0000000..2ea4898
--- /dev/null
+++ b/tests/pytests/certs/tt.cert.pem
@@ -0,0 +1,22 @@
+-----BEGIN CERTIFICATE-----
+MIIDkTCCAnmgAwIBAgIJAP/KybHquuyUMA0GCSqGSIb3DQEBCwUAMEIxCzAJBgNV
+BAYTAkNaMQ8wDQYDVQQIDAZQUkFHVUUxIjAgBgNVBAMMGXRyYW5zcG9ydC10ZXN0
+LXNlcnZlci5jb20wIBcNMTgwNzE4MTAwODIzWhgPMjA3MzA0MjAxMDA4MjNaMEIx
+CzAJBgNVBAYTAkNaMQ8wDQYDVQQIDAZQUkFHVUUxIjAgBgNVBAMMGXRyYW5zcG9y
+dC10ZXN0LXNlcnZlci5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB
+AQDRAQDAX6+lFKurvm7fgQqm8WyYzT/wxfPJjsVQGe87OlH1KFzVfzYzgEt0RMlM
+eZgipREBZB2zK+WFM5RBHWYAwlI5PKt7EAGn8q1Zm4z+M9Uom3/Hy3bZ9q+AJwjk
+odpHYuFyWJqHIQBqaQ3SFyJwdZ/GsuzEUfWuIl74oyyMAeykTKFGdaVuIlLC3fKm
+8UCnfk99i/LEXUwRcmOV0uaG7deN5ITDDCFdb615yVjLkMhGY/jHK7uuxATOopEk
+4vThQ1aQjSkHwluaqFUW6Zl4QF8WOAufoWQPFZ8XxmUYEIG/sMvLv6dol7ltjEbC
+bfyzlS+9Qbnq6MfhTZF/4jAPAgMBAAGjgYcwgYQwHQYDVR0OBBYEFBNiUgCiKw4b
+CFNKaEkqhkNSer7wMB8GA1UdIwQYMBaAFBNiUgCiKw4bCFNKaEkqhkNSer7wMA8G
+A1UdEwEB/wQFMAMBAf8wJAYDVR0RBB0wG4IZdHJhbnNwb3J0LXRlc3Qtc2VydmVy
+LmNvbTALBgNVHQ8EBAMCAaYwDQYJKoZIhvcNAQELBQADggEBACz1ZQ8XkGobhTcA
+hkSTSw0ko6qwVuJJD5ue3SUcWLATsskohTJmN6bde3IMDRyQvLJAlMdG2p1qMbtA
+OTbnQJTT7oDLaW8w2D+eO5oWTJvxLpl6TxbIfJN/8ITB1ltOCxTU9cVNbd2eh8sj
+l3R4etg9djYRrqtNxCQZOYSwvhHw2MefnwjGVuJEu6JYOn3IE8Jqsh/LI59C87nE
+MetZrXlzC6kSAFfRYgQET9RhBobMU9yFR8zGVHDFoxqNQs2lYKPz/3rFPetL2rjT
+cFwzxkxDdwn+RNisBc1LMfIg7pvSMFR6sAnpjeRHN0Uoem1jj2qtzjbFENDuyQ4/
+HSi4UcE=
+-----END CERTIFICATE-----
diff --git a/tests/pytests/certs/tt.conf b/tests/pytests/certs/tt.conf
new file mode 100644
index 0000000..5ac7737
--- /dev/null
+++ b/tests/pytests/certs/tt.conf
@@ -0,0 +1,353 @@
+#
+# OpenSSL example configuration file.
+# This is mostly being used for generation of certificate requests.
+#
+
+# This definition stops the following lines choking if HOME isn't
+# defined.
+HOME = .
+RANDFILE = $ENV::HOME/.rnd
+
+# Extra OBJECT IDENTIFIER info:
+#oid_file = $ENV::HOME/.oid
+oid_section = new_oids
+
+# To use this configuration file with the "-extfile" option of the
+# "openssl x509" utility, name here the section containing the
+# X.509v3 extensions to use:
+# extensions =
+# (Alternatively, use a configuration file that has only
+# X.509v3 extensions in its main [= default] section.)
+
+[ new_oids ]
+
+# We can add new OIDs in here for use by 'ca', 'req' and 'ts'.
+# Add a simple OID like this:
+# testoid1=1.2.3.4
+# Or use config file substitution like this:
+# testoid2=${testoid1}.5.6
+
+# Policies used by the TSA examples.
+tsa_policy1 = 1.2.3.4.1
+tsa_policy2 = 1.2.3.4.5.6
+tsa_policy3 = 1.2.3.4.5.7
+
+####################################################################
+[ ca ]
+default_ca = CA_default # The default ca section
+
+####################################################################
+[ CA_default ]
+
+dir = ./demoCA # Where everything is kept
+certs = $dir/certs # Where the issued certs are kept
+crl_dir = $dir/crl # Where the issued crl are kept
+database = $dir/index.txt # database index file.
+#unique_subject = no # Set to 'no' to allow creation of
+ # several certs with same subject.
+new_certs_dir = $dir/newcerts # default place for new certs.
+
+certificate = $dir/cacert.pem # The CA certificate
+serial = $dir/serial # The current serial number
+crlnumber = $dir/crlnumber # the current crl number
+ # must be commented out to leave a V1 CRL
+crl = $dir/crl.pem # The current CRL
+private_key = $dir/private/cakey.pem# The private key
+RANDFILE = $dir/private/.rand # private random number file
+
+x509_extensions = usr_cert # The extensions to add to the cert
+
+# Comment out the following two lines for the "traditional"
+# (and highly broken) format.
+name_opt = ca_default # Subject Name options
+cert_opt = ca_default # Certificate field options
+
+# Extension copying option: use with caution.
+copy_extensions = copy
+
+# Extensions to add to a CRL. Note: Netscape communicator chokes on V2 CRLs
+# so this is commented out by default to leave a V1 CRL.
+# crlnumber must also be commented out to leave a V1 CRL.
+# crl_extensions = crl_ext
+
+default_days = 365 # how long to certify for
+default_crl_days= 30 # how long before next CRL
+default_md = default # use public key default MD
+preserve = no # keep passed DN ordering
+
+# A few difference way of specifying how similar the request should look
+# For type CA, the listed attributes must be the same, and the optional
+# and supplied fields are just that :-)
+policy = policy_match
+
+# For the CA policy
+[ policy_match ]
+countryName = optional
+stateOrProvinceName = optional
+organizationName = optional
+organizationalUnitName = optional
+commonName = supplied
+emailAddress = optional
+
+# For the 'anything' policy
+# At this point in time, you must list all acceptable 'object'
+# types.
+[ policy_anything ]
+countryName = optional
+stateOrProvinceName = optional
+localityName = optional
+organizationName = optional
+organizationalUnitName = optional
+commonName = supplied
+emailAddress = optional
+
+####################################################################
+[ req ]
+default_bits = 2048
+default_keyfile = privkey.pem
+distinguished_name = req_distinguished_name
+attributes = req_attributes
+x509_extensions = v3_ca # The extensions to add to the self signed cert
+
+# Passwords for private keys if not present they will be prompted for
+# input_password = secret
+# output_password = secret
+
+# This sets a mask for permitted string types. There are several options.
+# default: PrintableString, T61String, BMPString.
+# pkix : PrintableString, BMPString (PKIX recommendation before 2004)
+# utf8only: only UTF8Strings (PKIX recommendation after 2004).
+# nombstr : PrintableString, T61String (no BMPStrings or UTF8Strings).
+# MASK:XXXX a literal mask value.
+# WARNING: ancient versions of Netscape crash on BMPStrings or UTF8Strings.
+string_mask = utf8only
+
+# req_extensions = v3_req # The extensions to add to a certificate request
+
+[ req_distinguished_name ]
+countryName = Country Name (2 letter code)
+countryName_default = CZ
+countryName_min = 2
+countryName_max = 2
+
+stateOrProvinceName = State or Province Name (full name)
+stateOrProvinceName_default = PRAGUE
+
+localityName = Locality Name (eg, city)
+
+0.organizationName = Organization Name (eg, company)
+0.organizationName_default =
+
+# we can do this but it is not needed normally :-)
+#1.organizationName = Second Organization Name (eg, company)
+#1.organizationName_default = World Wide Web Pty Ltd
+
+organizationalUnitName = Organizational Unit Name (eg, section)
+#organizationalUnitName_default =
+
+commonName = Common Name (e.g. server FQDN or YOUR name)
+commonName_max = 64
+commonName_default = transport-test-server.com
+
+emailAddress = Email Address
+emailAddress_max = 64
+
+# SET-ex3 = SET extension number 3
+
+[ req_attributes ]
+challengePassword = A challenge password
+challengePassword_min = 4
+challengePassword_max = 20
+
+unstructuredName = An optional company name
+
+[ usr_cert ]
+
+# These extensions are added when 'ca' signs a request.
+
+# This goes against PKIX guidelines but some CAs do it and some software
+# requires this to avoid interpreting an end user certificate as a CA.
+
+basicConstraints=CA:FALSE
+
+# Here are some examples of the usage of nsCertType. If it is omitted
+# the certificate can be used for anything *except* object signing.
+
+# This is OK for an SSL server.
+# nsCertType = server
+
+# For an object signing certificate this would be used.
+# nsCertType = objsign
+
+# For normal client use this is typical
+# nsCertType = client, email
+
+# and for everything including object signing:
+# nsCertType = client, email, objsign
+
+# This is typical in keyUsage for a client certificate.
+# keyUsage = nonRepudiation, digitalSignature, keyEncipherment
+
+# This will be displayed in Netscape's comment listbox.
+nsComment = "OpenSSL Generated Certificate"
+
+# PKIX recommendations harmless if included in all certificates.
+subjectKeyIdentifier=hash
+authorityKeyIdentifier=keyid,issuer
+
+# This stuff is for subjectAltName and issuerAltname.
+# Import the email address.
+# subjectAltName=email:copy
+# An alternative to produce certificates that aren't
+# deprecated according to PKIX.
+# subjectAltName=email:move
+
+# Copy subject details
+# issuerAltName=issuer:copy
+
+#nsCaRevocationUrl = http://www.domain.dom/ca-crl.pem
+#nsBaseUrl
+#nsRevocationUrl
+#nsRenewalUrl
+#nsCaPolicyUrl
+#nsSslServerName
+
+# This is required for TSA certificates.
+# extendedKeyUsage = critical,timeStamping
+
+[ v3_req ]
+
+# Extensions to add to a certificate request
+
+basicConstraints = CA:FALSE
+keyUsage = nonRepudiation, digitalSignature, keyEncipherment
+
+[ v3_ca ]
+
+
+# Extensions for a typical CA
+
+
+# PKIX recommendation.
+
+subjectKeyIdentifier=hash
+
+authorityKeyIdentifier=keyid:always,issuer
+
+basicConstraints = critical,CA:true
+
+subjectAltName = @alternate_names
+
+# Key usage: this is typical for a CA certificate. However since it will
+# prevent it being used as an test self-signed certificate it is best
+# left out by default.
+keyUsage = digitalSignature, keyEncipherment, cRLSign, keyCertSign
+
+# Some might want this also
+# nsCertType = sslCA, emailCA
+
+# Include email address in subject alt name: another PKIX recommendation
+# subjectAltName=email:copy
+# Copy issuer details
+# issuerAltName=issuer:copy
+
+# DER hex encoding of an extension: beware experts only!
+# obj=DER:02:03
+# Where 'obj' is a standard or added object
+# You can even override a supported extension:
+# basicConstraints= critical, DER:30:03:01:01:FF
+
+[ crl_ext ]
+
+# CRL extensions.
+# Only issuerAltName and authorityKeyIdentifier make any sense in a CRL.
+
+# issuerAltName=issuer:copy
+authorityKeyIdentifier=keyid:always
+
+[ proxy_cert_ext ]
+# These extensions should be added when creating a proxy certificate
+
+# This goes against PKIX guidelines but some CAs do it and some software
+# requires this to avoid interpreting an end user certificate as a CA.
+
+basicConstraints=CA:FALSE
+
+# Here are some examples of the usage of nsCertType. If it is omitted
+# the certificate can be used for anything *except* object signing.
+
+# This is OK for an SSL server.
+# nsCertType = server
+
+# For an object signing certificate this would be used.
+# nsCertType = objsign
+
+# For normal client use this is typical
+# nsCertType = client, email
+
+# and for everything including object signing:
+# nsCertType = client, email, objsign
+
+# This is typical in keyUsage for a client certificate.
+# keyUsage = nonRepudiation, digitalSignature, keyEncipherment
+
+# This will be displayed in Netscape's comment listbox.
+nsComment = "OpenSSL Generated Certificate"
+
+# PKIX recommendations harmless if included in all certificates.
+subjectKeyIdentifier=hash
+authorityKeyIdentifier=keyid,issuer
+
+# This stuff is for subjectAltName and issuerAltname.
+# Import the email address.
+# subjectAltName=email:copy
+# An alternative to produce certificates that aren't
+# deprecated according to PKIX.
+# subjectAltName=email:move
+
+# Copy subject details
+# issuerAltName=issuer:copy
+
+#nsCaRevocationUrl = http://www.domain.dom/ca-crl.pem
+#nsBaseUrl
+#nsRevocationUrl
+#nsRenewalUrl
+#nsCaPolicyUrl
+#nsSslServerName
+
+# This really needs to be in place for it to be a proxy certificate.
+proxyCertInfo=critical,language:id-ppl-anyLanguage,pathlen:3,policy:foo
+
+####################################################################
+[ tsa ]
+
+default_tsa = tsa_config1 # the default TSA section
+
+[ tsa_config1 ]
+
+# These are used by the TSA reply generation only.
+dir = ./demoCA # TSA root directory
+serial = $dir/tsaserial # The current serial number (mandatory)
+crypto_device = builtin # OpenSSL engine to use for signing
+signer_cert = $dir/tsacert.pem # The TSA signing certificate
+ # (optional)
+certs = $dir/cacert.pem # Certificate chain to include in reply
+ # (optional)
+signer_key = $dir/private/tsakey.pem # The TSA private key (optional)
+signer_digest = sha256 # Signing digest to use. (Optional)
+default_policy = tsa_policy1 # Policy if request did not specify it
+ # (optional)
+other_policies = tsa_policy2, tsa_policy3 # acceptable policies (optional)
+digests = sha1, sha256, sha384, sha512 # Acceptable message digests (mandatory)
+accuracy = secs:1, millisecs:500, microsecs:100 # (optional)
+clock_precision_digits = 0 # number of digits after dot. (optional)
+ordering = yes # Is ordering defined for timestamps?
+ # (optional, default: no)
+tsa_name = yes # Must the TSA name be included in the reply?
+ # (optional, default: no)
+ess_cert_id_chain = no # Must the ESS cert id chain be included?
+ # (optional, default: no)
+
+[ alternate_names ]
+
+DNS.1 = transport-test-server.com
diff --git a/tests/pytests/certs/tt.key.pem b/tests/pytests/certs/tt.key.pem
new file mode 100644
index 0000000..1974be7
--- /dev/null
+++ b/tests/pytests/certs/tt.key.pem
@@ -0,0 +1,28 @@
+-----BEGIN PRIVATE KEY-----
+MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDRAQDAX6+lFKur
+vm7fgQqm8WyYzT/wxfPJjsVQGe87OlH1KFzVfzYzgEt0RMlMeZgipREBZB2zK+WF
+M5RBHWYAwlI5PKt7EAGn8q1Zm4z+M9Uom3/Hy3bZ9q+AJwjkodpHYuFyWJqHIQBq
+aQ3SFyJwdZ/GsuzEUfWuIl74oyyMAeykTKFGdaVuIlLC3fKm8UCnfk99i/LEXUwR
+cmOV0uaG7deN5ITDDCFdb615yVjLkMhGY/jHK7uuxATOopEk4vThQ1aQjSkHwlua
+qFUW6Zl4QF8WOAufoWQPFZ8XxmUYEIG/sMvLv6dol7ltjEbCbfyzlS+9Qbnq6Mfh
+TZF/4jAPAgMBAAECggEBALSs10d18FMW0WjAUPxpgxnaLnTRSesMVLjy8ONT6Bkd
+S2hRIh91vxc6WwABzrqLita4N0EqmPoggmNpuUmo7lrNoWLVbbAOoD/da7nA3FuL
+10MpWYcP/ohh1klEdU2gFSAM/LNqoPsbrk5OzqHFWgI5zItqdX8pEucb01nBRWsp
+VMY2vzVuFB2jweZQ5+LCpfSMcRIzlxQa9CG4Peu6YW1Z4b3aUcS63/829JN/ZOGd
+uoRqR+gP71yNIt6i7wA5cot5FRmzlFEGhb1XzBOB1FFHOiknOZzbBtDsGUUmVtfA
+6mXcTumhdHbC0bXnHei/s2s9X0EeyQFYPkoS4NUQ2dECgYEA/7lhgn89K8rpUPnS
+eccTpKVPWp8luQei98Hi/F94kwP32l7Zl7Bmu2nltUoB1GBRXoXY6KzTphmT6ioA
+8joLCKIii5/nOdZAdHbIN2tkXS56h524q5I2jKogjfRrpCaAJE8x99f8L9uTBfZb
+/7BBQDHai1/S6LcpIRf/4g1/xBsCgYEA0Tq4V5hR9mGDUFir1FDGhA3ijDkIE/sO
+3QGTU7W90BL27te98FuQtWOPqfd1fi26WypNpNQUZb3V4x5tmDcpWscfj6I10432
+4zECPlDgaevucJjj245U7WjUhdAvlRy6K8H/8MgRBAjw9h8dwIGIx9gmOqKdA+/h
+ve3xyjKQex0CgYAz0XzQ1LewiA1/OyBLTOvOETFjS5x5QfLkAYXdXfswzz0KIu40
+rqoij/LcKYL1Zg8W+Ehb3amFnuk6KgjHDLvvo+scH+ra7W9iKi+oCzrrJt/tWyhw
+m9Ax8Mdn/H9TY/nTYbjeYAXaLMQ+EQ3TYgPW3kNKusAiJ/tNmW9gfxvEwQKBgGSJ
+Rbj5fTDZjGKYKQDdS3Z6wYhFg0culObHcgaARtPruPHtgtwy82blj0vJl5Bo4qoZ
+urNgIOj+ff8jSOAiaWGwWs8Gz7x289IZY42UCTF8Z9d878g5LT/i5nPiJGsPIboS
+/yuwxtRcg4SQURiGZbY5e60jJDWXF67O3icdguVVAoGARXLufXvZ/9Xf1DmFFxjq
+PJMCa1sfofqjB4KqYbt17vFtTsddCiyqsbpx36oY6nIdm9yUiGo10YaSEJtDEGLS
+L3TPZ4s8M8dcjOfj8Kk75pKbJ7NY4qA64dtbxcZbrFp3/mGZkDing94y+Zc/aFqa
+xQsA/yhmYV9r+FHDL54Cn6I=
+-----END PRIVATE KEY-----
diff --git a/tests/pytests/conftest.py b/tests/pytests/conftest.py
new file mode 100644
index 0000000..e10ba3a
--- /dev/null
+++ b/tests/pytests/conftest.py
@@ -0,0 +1,78 @@
+import socket
+
+import pytest
+
+from kresd import init_portdir, make_kresd
+
+
+@pytest.fixture
+def kresd(tmpdir):
+ with make_kresd(tmpdir) as kresd:
+ yield kresd
+
+
+@pytest.fixture
+def kresd_tt(tmpdir):
+ with make_kresd(tmpdir, 'tt') as kresd:
+ yield kresd
+
+
+@pytest.fixture(params=[
+ 'ip_tcp_socket',
+ 'ip6_tcp_socket',
+ 'ip_tls_socket',
+ 'ip6_tls_socket',
+])
+def make_kresd_sock(request, kresd):
+ """Factory function to create sockets of the same kind."""
+ sock_func = getattr(kresd, request.param)
+
+ def _make_kresd_sock():
+ return sock_func()
+
+ return _make_kresd_sock
+
+
+@pytest.fixture
+def kresd_sock(make_kresd_sock):
+ return make_kresd_sock()
+
+
+@pytest.fixture(params=[
+ socket.AF_INET,
+ socket.AF_INET6,
+])
+def sock_family(request):
+ return request.param
+
+
+@pytest.fixture(params=[
+ True,
+ False
+])
+def single_buffer(request): # whether to send all data in a single buffer
+ return request.param
+
+
+@pytest.fixture(params=[
+ True,
+ False
+])
+def query_before(request): # whether to send an initial query
+ return request.param
+
+
+@pytest.mark.optionalhook
+def pytest_metadata(metadata): # filter potentially sensitive data from GitLab CI
+ keys_to_delete = []
+ for key in metadata.keys():
+ key_lower = key.lower()
+ if 'password' in key_lower or 'token' in key_lower or \
+ key_lower.startswith('ci') or key_lower.startswith('gitlab'):
+ keys_to_delete.append(key)
+ for key in keys_to_delete:
+ del metadata[key]
+
+
+def pytest_sessionstart(session): # pylint: disable=unused-argument
+ init_portdir()
diff --git a/tests/pytests/conn_flood.py b/tests/pytests/conn_flood.py
new file mode 100644
index 0000000..8c2625a
--- /dev/null
+++ b/tests/pytests/conn_flood.py
@@ -0,0 +1,83 @@
+"""Test opening as many connections as possible.
+
+Due to resource-intensity of this test, it's filename doesn't contain
+"test" on purpose, so it doesn't automatically get picked up by pytest
+(to allow easy parallel testing).
+
+To execute this test, pass the filename of this file to pytest directly.
+Also, make sure not to use parallel execution (-n).
+"""
+
+import resource
+import time
+
+import pytest
+
+from kresd import Kresd
+import utils
+
+
+MAX_SOCKETS = 10000 # upper bound of how many connections to open
+MAX_ITERATIONS = 10 # number of iterations to run the test
+
+# we can't use softlimit ifself since kresd already has open sockets,
+# so use lesser value
+RESERVED_NOFILE = 40 # 40 is empirical value
+
+
+@pytest.mark.parametrize('sock_func_name', [
+ 'ip_tcp_socket',
+ 'ip6_tcp_socket',
+ 'ip_tls_socket',
+ 'ip6_tls_socket',
+])
+def test_conn_flood(tmpdir, sock_func_name):
+ def create_sockets(make_sock, nsockets):
+ sockets = []
+ next_ping = time.time() + 4 # less than tcp idle timeout / 2
+ while True:
+ additional_sockets = 0
+ while time.time() < next_ping:
+ nsock_to_init = min(100, nsockets - len(sockets))
+ if not nsock_to_init:
+ return sockets
+ sockets.extend([make_sock() for _ in range(nsock_to_init)])
+ additional_sockets += nsock_to_init
+
+ # large number of connections can take a lot of time to open
+ # send some valid data to avoid TCP idle timeout for already open sockets
+ next_ping = time.time() + 4
+ for s in sockets:
+ utils.ping_alive(s)
+
+ # break when no more than 20% additional sockets are created
+ if additional_sockets / len(sockets) < 0.2:
+ return sockets
+
+ max_num_of_open_files = resource.getrlimit(resource.RLIMIT_NOFILE)[0] - RESERVED_NOFILE
+ nsockets = min(max_num_of_open_files, MAX_SOCKETS)
+
+ # create kresd instance with verbose=False
+ ip = '127.0.0.1'
+ ip6 = '::1'
+ with Kresd(tmpdir, ip=ip, ip6=ip6, verbose=False) as kresd:
+ make_sock = getattr(kresd, sock_func_name) # function for creating sockets
+ sockets = create_sockets(make_sock, nsockets)
+ print("\nEstablished {} connections".format(len(sockets)))
+
+ print("Start sending data")
+ for i in range(MAX_ITERATIONS):
+ for s in sockets:
+ utils.ping_alive(s)
+ print("Iteration {} done...".format(i))
+
+ print("Close connections")
+ for s in sockets:
+ s.close()
+
+ # check in kresd is alive
+ print("Check upstream is still alive")
+ sock = make_sock()
+ utils.ping_alive(sock)
+
+ print("OK!")
diff --git a/tests/pytests/kresd.py b/tests/pytests/kresd.py
new file mode 100644
index 0000000..4b122f6
--- /dev/null
+++ b/tests/pytests/kresd.py
@@ -0,0 +1,305 @@
+from collections import namedtuple
+from contextlib import ContextDecorator, contextmanager
+import os
+from pathlib import Path
+import random
+import re
+import shutil
+import socket
+import subprocess
+import time
+
+import jinja2
+
+import utils
+
+
+PYTESTS_DIR = os.path.dirname(os.path.realpath(__file__))
+CERTS_DIR = os.path.join(PYTESTS_DIR, 'certs')
+TEMPLATES_DIR = os.path.join(PYTESTS_DIR, 'templates')
+KRESD_CONF_TEMPLATE = 'kresd.conf.j2'
+KRESD_STARTUP_MSGID = 10005 # special unique ID at the start of the "test" log
+KRESD_PORTDIR = '/tmp/pytest-kresd-portdir'
+KRESD_TESTPORT_MIN = 1024
+KRESD_TESTPORT_MAX = 32768 # avoid overlap with docker ephemeral port range
+
+
+def init_portdir():
+ try:
+ shutil.rmtree(KRESD_PORTDIR)
+ except FileNotFoundError:
+ pass
+ os.makedirs(KRESD_PORTDIR)
+
+
+def create_file_from_template(template_path, dest, data):
+ env = jinja2.Environment(
+ loader=jinja2.FileSystemLoader(TEMPLATES_DIR))
+ template = env.get_template(template_path)
+ rendered_template = template.render(**data)
+
+ with open(dest, "w") as fh:
+ fh.write(rendered_template)
+
+
+Forward = namedtuple('Forward', ['proto', 'ip', 'port', 'hostname', 'ca_file'])
+
+
+class Kresd(ContextDecorator):
+ def __init__(
+ self, workdir, port=None, tls_port=None, ip=None, ip6=None, certname=None,
+ verbose=True, hints=None, forward=None):
+ if ip is None and ip6 is None:
+ raise ValueError("IPv4 or IPv6 must be specified!")
+ self.workdir = str(workdir)
+ self.port = port
+ self.tls_port = tls_port
+ self.ip = ip
+ self.ip6 = ip6
+ self.process = None
+ self.sockets = []
+ self.logfile = None
+ self.verbose = verbose
+ self.hints = {} if hints is None else hints
+ self.forward = forward
+
+ if certname:
+ self.tls_cert_path = os.path.join(CERTS_DIR, certname + '.cert.pem')
+ self.tls_key_path = os.path.join(CERTS_DIR, certname + '.key.pem')
+ else:
+ self.tls_cert_path = None
+ self.tls_key_path = None
+
+ @property
+ def config_path(self):
+ return str(os.path.join(self.workdir, 'kresd.conf'))
+
+ @property
+ def logfile_path(self):
+ return str(os.path.join(self.workdir, 'kresd.log'))
+
+ def __enter__(self):
+ if self.port is not None:
+ take_port(self.port, self.ip, self.ip6, timeout=120)
+ else:
+ self.port = make_port(self.ip, self.ip6)
+ if self.tls_port is not None:
+ take_port(self.tls_port, self.ip, self.ip6, timeout=120)
+ else:
+ self.tls_port = make_port(self.ip, self.ip6)
+
+ create_file_from_template(KRESD_CONF_TEMPLATE, self.config_path, {'kresd': self})
+ self.logfile = open(self.logfile_path, 'w')
+ self.process = subprocess.Popen(
+ ['kresd', '-c', self.config_path, '-f', '1', self.workdir],
+ stdout=self.logfile, env=os.environ.copy())
+
+ try:
+ self._wait_for_tcp_port() # wait for ports to be up and responding
+ if not self.all_ports_alive(msgid=10001):
+ raise RuntimeError("Kresd not listening on all ports")
+
+ # issue special msgid to mark start of test log
+ sock = self.ip_tcp_socket() if self.ip else self.ip6_tcp_socket()
+ assert utils.try_ping_alive(sock, close=True, msgid=KRESD_STARTUP_MSGID)
+
+ # sanity check - kresd didn't crash
+ self.process.poll()
+ if self.process.returncode is not None:
+ raise RuntimeError("Kresd crashed with returncode: {}".format(
+ self.process.returncode))
+ except (RuntimeError, ConnectionError): # pylint: disable=try-except-raise
+ with open(self.logfile_path) as log: # print log for debugging
+ print(log.read())
+ raise
+
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ try:
+ if not self.all_ports_alive(msgid=1006):
+ raise RuntimeError("Kresd crashed")
+ finally:
+ for sock in self.sockets:
+ sock.close()
+ self.process.terminate()
+ self.logfile.close()
+ Path(KRESD_PORTDIR, str(self.port)).unlink()
+
+ def all_ports_alive(self, msgid=10001):
+ alive = True
+ if self.ip:
+ alive &= utils.try_ping_alive(self.ip_tcp_socket(), close=True, msgid=msgid)
+ alive &= utils.try_ping_alive(self.ip_tls_socket(), close=True, msgid=msgid + 1)
+ if self.ip6:
+ alive &= utils.try_ping_alive(self.ip6_tcp_socket(), close=True, msgid=msgid + 2)
+ alive &= utils.try_ping_alive(self.ip6_tls_socket(), close=True, msgid=msgid + 3)
+ return alive
+
+ def _wait_for_tcp_port(self, max_delay=10, delay_step=0.2):
+ family = socket.AF_INET if self.ip else socket.AF_INET6
+ i = 0
+ end_time = time.time() + max_delay
+
+ while time.time() < end_time:
+ i += 1
+
+ # use exponential backoff algorhitm to choose next delay
+ rand_delay = random.randrange(0, i)
+ time.sleep(rand_delay * delay_step)
+
+ try:
+ sock, dest = self.stream_socket(family, timeout=5)
+ sock.connect(dest)
+ except ConnectionRefusedError:
+ continue
+ else:
+ try:
+ return utils.try_ping_alive(sock, close=True, msgid=10000)
+ except socket.timeout:
+ continue
+ finally:
+ sock.close()
+ raise RuntimeError("Kresd didn't start in time")
+
+ def socket_dest(self, family, tls=False):
+ port = self.tls_port if tls else self.port
+ if family == socket.AF_INET:
+ return self.ip, port
+ elif family == socket.AF_INET6:
+ return self.ip6, port, 0, 0
+ raise RuntimeError("Unsupported socket family: {}".format(family))
+
+ def stream_socket(self, family, tls=False, timeout=20):
+ """Initialize a socket and return it along with the destination without connecting."""
+ sock = socket.socket(family, socket.SOCK_STREAM)
+ sock.settimeout(timeout)
+ sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+ dest = self.socket_dest(family, tls)
+ self.sockets.append(sock)
+ return sock, dest
+
+ def _tcp_socket(self, family):
+ sock, dest = self.stream_socket(family)
+ sock.connect(dest)
+ return sock
+
+ def ip_tcp_socket(self):
+ return self._tcp_socket(socket.AF_INET)
+
+ def ip6_tcp_socket(self):
+ return self._tcp_socket(socket.AF_INET6)
+
+ def _tls_socket(self, family):
+ sock, dest = self.stream_socket(family, tls=True)
+ ctx = utils.make_ssl_context(insecure=True)
+ ssock = ctx.wrap_socket(sock)
+ try:
+ ssock.connect(dest)
+ except OSError as exc:
+ if exc.errno == 0: # sometimes happens shortly after startup
+ return None
+ return ssock
+
+ def _tls_socket_with_retry(self, family):
+ sock = self._tls_socket(family)
+ if sock is None:
+ time.sleep(0.1)
+ sock = self._tls_socket(family)
+ if sock is None:
+ raise RuntimeError("Failed to create TLS socket!")
+ return sock
+
+ def ip_tls_socket(self):
+ return self._tls_socket_with_retry(socket.AF_INET)
+
+ def ip6_tls_socket(self):
+ return self._tls_socket_with_retry(socket.AF_INET6)
+
+ def partial_log(self):
+ partial_log = '\n (... ommiting log start)\n'
+ with open(self.logfile_path) as log: # display partial log for debugging
+ past_startup_msgid = False
+ past_startup = False
+ for line in log:
+ if past_startup:
+ partial_log += line
+ else: # find real start of test log (after initial alive-pings)
+ if not past_startup_msgid:
+ if re.match(KRESD_LOG_STARTUP_MSGID, line) is not None:
+ past_startup_msgid = True
+ else:
+ if re.match(KRESD_LOG_IO_CLOSE, line) is not None:
+ past_startup = True
+ return partial_log
+
+
+def is_port_free(port, ip=None, ip6=None):
+ def check(family, type_, dest):
+ sock = socket.socket(family, type_)
+ sock.bind(dest)
+ sock.close()
+
+ try:
+ if ip is not None:
+ check(socket.AF_INET, socket.SOCK_STREAM, (ip, port))
+ check(socket.AF_INET, socket.SOCK_DGRAM, (ip, port))
+ if ip6 is not None:
+ check(socket.AF_INET6, socket.SOCK_STREAM, (ip6, port, 0, 0))
+ check(socket.AF_INET6, socket.SOCK_DGRAM, (ip6, port, 0, 0))
+ except OSError as exc:
+ if exc.errno == 98: # address alrady in use
+ return False
+ else:
+ raise
+ return True
+
+
+def take_port(port, ip=None, ip6=None, timeout=0):
+ port_path = Path(KRESD_PORTDIR, str(port))
+ end_time = time.time() + timeout
+ try:
+ port_path.touch(exist_ok=False)
+ except FileExistsError:
+ raise ValueError(
+ "Port {} already reserved by system or another kresd instance!".format(port))
+
+ while True:
+ if is_port_free(port, ip, ip6):
+ # NOTE: The port_path isn't removed, so other instances don't have to attempt to
+ # take the same port again. This has the side effect of leaving many of these
+ # files behind, because when another kresd shuts down and removes its file, the
+ # port still can't be reserved for a while. This shouldn't become an issue unless
+ # we have thousands of tests (and run out of the port range).
+ break
+
+ if time.time() < end_time:
+ time.sleep(5)
+ else:
+ raise ValueError(
+ "Port {} is reserved by system!".format(port))
+ return port
+
+
+def make_port(ip=None, ip6=None):
+ for _ in range(10): # max attempts
+ port = random.randint(KRESD_TESTPORT_MIN, KRESD_TESTPORT_MAX)
+ try:
+ take_port(port, ip, ip6)
+ except ValueError:
+ continue # port reserved by system / another kresd instance
+ return port
+ raise RuntimeError("No available port found!")
+
+
+KRESD_LOG_STARTUP_MSGID = re.compile(r'^\[{}.*'.format(KRESD_STARTUP_MSGID))
+KRESD_LOG_IO_CLOSE = re.compile(r'^\[io\].*closed by peer.*')
+
+
+@contextmanager
+def make_kresd(
+ workdir, certname=None, ip='127.0.0.1', ip6='::1', forward=None, hints=None,
+ port=None, tls_port=None):
+ with Kresd(workdir, port, tls_port, ip, ip6, certname, forward=forward, hints=hints) as kresd:
+ yield kresd
+ print(kresd.partial_log())
diff --git a/tests/pytests/pylintrc b/tests/pytests/pylintrc
new file mode 100644
index 0000000..d9cf1a1
--- /dev/null
+++ b/tests/pytests/pylintrc
@@ -0,0 +1,29 @@
+[MESSAGES CONTROL]
+
+disable=
+ missing-docstring,
+ too-many-arguments,
+ too-many-instance-attributes,
+ fixme,
+ unused-import, # checked by flake8
+ line-too-long, # checked by flake8
+ invalid-name,
+ broad-except,
+ bad-continuation,
+ global-statement,
+ no-else-return,
+ redefined-outer-name, # commonly used with pytest fixtures
+
+
+[SIMILARITIES]
+min-similarity-lines=6
+ignore-comments=yes
+ignore-docstrings=yes
+ignore-imports=no
+
+[DESIGN]
+max-parents=10
+max-locals=20
+
+[TYPECHECK]
+ignored-modules=ssl
diff --git a/tests/pytests/rehandshake/Makefile b/tests/pytests/rehandshake/Makefile
new file mode 100644
index 0000000..170b89e
--- /dev/null
+++ b/tests/pytests/rehandshake/Makefile
@@ -0,0 +1,28 @@
+CC=gcc
+CFLAGS_TLS=-DDEBUG -ggdb3 -O0 -lgnutls -luv
+CFLAGS_TCP=-DDEBUG -ggdb3 -O0 -luv
+
+all: tcproxy tlsproxy
+
+tlsproxy: tls-proxy.o tlsproxy.o
+ $(CC) tls-proxy.o tlsproxy.o -o tlsproxy $(CFLAGS_TLS)
+
+tls-proxy.o: tls-proxy.c tls-proxy.h array.h
+ $(CC) -c -o $@ $< $(CFLAGS_TLS)
+
+tlsproxy.o: tlsproxy.c tls-proxy.h
+ $(CC) -c -o $@ $< $(CFLAGS_TLS)
+
+tcproxy: tcp-proxy.o tcproxy.o
+ $(CC) tcp-proxy.o tcproxy.o -o tcproxy $(CFLAGS_TCP)
+
+tcp-proxy.o: tcp-proxy.c tcp-proxy.h array.h
+ $(CC) -c -o $@ $< $(CFLAGS_TCP)
+
+tcproxy.o: tcproxy.c tcp-proxy.h
+ $(CC) -c -o $@ $< $(CFLAGS_TCP)
+
+clean:
+ rm -f tcp-proxy.o tcproxy.o tcproxy tls-proxy.o tlsproxy.o tlsproxy
+
+.PHONY: all clean
diff --git a/tests/pytests/rehandshake/array.h b/tests/pytests/rehandshake/array.h
new file mode 100644
index 0000000..ece4dd1
--- /dev/null
+++ b/tests/pytests/rehandshake/array.h
@@ -0,0 +1,166 @@
+/* Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+/**
+ *
+ * @file array.h
+ * @brief A set of simple macros to make working with dynamic arrays easier.
+ *
+ * @note The C has no generics, so it is implemented mostly using macros.
+ * Be aware of that, as direct usage of the macros in the evaluating macros
+ * may lead to different expectations:
+ *
+ * @code{.c}
+ * MIN(array_push(arr, val), other)
+ * @endcode
+ *
+ * May evaluate the code twice, leading to unexpected behaviour.
+ * This is a price to pay for the absence of proper generics.
+ *
+ * # Example usage:
+ *
+ * @code{.c}
+ * array_t(const char*) arr;
+ * array_init(arr);
+ *
+ * // Reserve memory in advance
+ * if (array_reserve(arr, 2) < 0) {
+ * return ENOMEM;
+ * }
+ *
+ * // Already reserved, cannot fail
+ * array_push(arr, "princess");
+ * array_push(arr, "leia");
+ *
+ * // Not reserved, may fail
+ * if (array_push(arr, "han") < 0) {
+ * return ENOMEM;
+ * }
+ *
+ * // It does not hide what it really is
+ * for (size_t i = 0; i < arr.len; ++i) {
+ * printf("%s\n", arr.at[i]);
+ * }
+ *
+ * // Random delete
+ * array_del(arr, 0);
+ * @endcode
+ * \addtogroup generics
+ * @{
+ */
+
+#pragma once
+#include <stdlib.h>
+
+/** Simplified Qt containers growth strategy. */
+static inline size_t array_next_count(size_t want)
+{
+ if (want < 2048) {
+ return (want < 20) ? want + 4 : want * 2;
+ } else {
+ return want + 2048;
+ }
+}
+
+/** @internal Incremental memory reservation */
+static inline int array_std_reserve(void *baton, char **mem, size_t elm_size, size_t want, size_t *have)
+{
+ if (*have >= want) {
+ return 0;
+ }
+ /* Simplified Qt containers growth strategy */
+ size_t next_size = array_next_count(want);
+ void *mem_new = realloc(*mem, next_size * elm_size);
+ if (mem_new != NULL) {
+ *mem = mem_new;
+ *have = next_size;
+ return 0;
+ }
+ return -1;
+}
+
+/** @internal Wrapper for stdlib free. */
+static inline void array_std_free(void *baton, void *p)
+{
+ free(p);
+}
+
+/** Declare an array structure. */
+#define array_t(type) struct {type * at; size_t len; size_t cap; }
+
+/** Zero-initialize the array. */
+#define array_init(array) ((array).at = NULL, (array).len = (array).cap = 0)
+
+/** Free and zero-initialize the array (plain malloc/free). */
+#define array_clear(array) \
+ array_clear_mm(array, array_std_free, NULL)
+
+/** Make the array empty and free pointed-to memory.
+ * Mempool usage: pass mm_free and a knot_mm_t* . */
+#define array_clear_mm(array, free, baton) \
+ (free)((baton), (array).at), array_init(array)
+
+/** Reserve capacity for at least n elements.
+ * @return 0 if success, <0 on failure */
+#define array_reserve(array, n) \
+ array_reserve_mm(array, n, array_std_reserve, NULL)
+
+/** Reserve capacity for at least n elements.
+ * Mempool usage: pass kr_memreserve and a knot_mm_t* .
+ * @return 0 if success, <0 on failure */
+#define array_reserve_mm(array, n, reserve, baton) \
+ (reserve)((baton), (char **) &(array).at, sizeof((array).at[0]), (n), &(array).cap)
+
+/**
+ * Push value at the end of the array, resize it if necessary.
+ * Mempool usage: pass kr_memreserve and a knot_mm_t* .
+ * @note May fail if the capacity is not reserved.
+ * @return element index on success, <0 on failure
+ */
+#define array_push_mm(array, val, reserve, baton) \
+ (int)((array).len < (array).cap ? ((array).at[(array).len] = val, (array).len++) \
+ : (array_reserve_mm(array, ((array).cap + 1), reserve, baton) < 0 ? -1 \
+ : ((array).at[(array).len] = val, (array).len++)))
+
+/**
+ * Push value at the end of the array, resize it if necessary (plain malloc/free).
+ * @note May fail if the capacity is not reserved.
+ * @return element index on success, <0 on failure
+ */
+#define array_push(array, val) \
+ array_push_mm(array, val, array_std_reserve, NULL)
+
+/**
+ * Pop value from the end of the array.
+ */
+#define array_pop(array) \
+ (array).len -= 1
+
+/**
+ * Remove value at given index.
+ * @return 0 on success, <0 on failure
+ */
+#define array_del(array, i) \
+ (int)((i) < (array).len ? ((array).len -= 1,(array).at[i] = (array).at[(array).len], 0) : -1)
+
+/**
+ * Return last element of the array.
+ * @warning Undefined if the array is empty.
+ */
+#define array_tail(array) \
+ (array).at[(array).len - 1]
+
+/** @} */
diff --git a/tests/pytests/rehandshake/tcp-proxy.c b/tests/pytests/rehandshake/tcp-proxy.c
new file mode 100644
index 0000000..ba7198b
--- /dev/null
+++ b/tests/pytests/rehandshake/tcp-proxy.c
@@ -0,0 +1,336 @@
+#include <assert.h>
+#include <stdio.h>
+#include <unistd.h>
+#include <string.h>
+#include <stdlib.h>
+#include <stdbool.h>
+#include <uv.h>
+#include "array.h"
+
+struct buf {
+ char buf[16 * 1024];
+ size_t size;
+};
+
+enum peer_state {
+ STATE_NOT_CONNECTED,
+ STATE_LISTENING,
+ STATE_CONNECTED,
+ STATE_CONNECT_IN_PROGRESS,
+ STATE_CLOSING_IN_PROGRESS
+};
+
+struct proxy_ctx {
+ uv_loop_t *loop;
+ uv_tcp_t server;
+ uv_tcp_t client;
+ uv_tcp_t upstream;
+ struct sockaddr_storage server_addr;
+ struct sockaddr_storage upstream_addr;
+
+ int server_state;
+ int client_state;
+ int upstream_state;
+
+ array_t(struct buf *) buffer_pool;
+ array_t(struct buf *) upstream_pending;
+};
+
+static void read_from_upstream_cb(uv_stream_t *upstream, ssize_t nread, const uv_buf_t *buf);
+static void read_from_client_cb(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf);
+
+static struct buf *borrow_io_buffer(struct proxy_ctx *proxy)
+{
+ struct buf *buf = NULL;
+ if (proxy->buffer_pool.len > 0) {
+ buf = array_tail(proxy->buffer_pool);
+ array_pop(proxy->buffer_pool);
+ } else {
+ buf = calloc(1, sizeof (struct buf));
+ }
+ return buf;
+}
+
+static void release_io_buffer(struct proxy_ctx *proxy, struct buf *buf)
+{
+ if (!buf) {
+ return;
+ }
+
+ if (proxy->buffer_pool.len < 1000) {
+ buf->size = 0;
+ array_push(proxy->buffer_pool, buf);
+ } else {
+ free(buf);
+ }
+}
+
+static void push_to_upstream_pending(struct proxy_ctx *proxy, const char *buf, size_t size)
+{
+ while (size > 0) {
+ struct buf *b = borrow_io_buffer(proxy);
+ b->size = size <= sizeof(b->buf) ? size : sizeof(b->buf);
+ memcpy(b->buf, buf, b->size);
+ array_push(proxy->upstream_pending, b);
+ size -= b->size;
+ }
+}
+
+static struct buf *get_first_upstream_pending(struct proxy_ctx *proxy)
+{
+ struct buf *buf = NULL;
+ if (proxy->upstream_pending.len > 0) {
+ buf = proxy->upstream_pending.at[0];
+ }
+ return buf;
+}
+
+static void remove_first_upstream_pending(struct proxy_ctx *proxy)
+{
+ for (int i = 1; i < proxy->upstream_pending.len; ++i) {
+ proxy->upstream_pending.at[i - 1] = proxy->upstream_pending.at[i];
+ }
+ if (proxy->upstream_pending.len > 0) {
+ proxy->upstream_pending.len -= 1;
+ }
+}
+
+static void clear_upstream_pending(struct proxy_ctx *proxy)
+{
+ for (int i = 1; i < proxy->upstream_pending.len; ++i) {
+ struct buf *b = proxy->upstream_pending.at[i];
+ release_io_buffer(proxy, b);
+ }
+ proxy->upstream_pending.len = 0;
+}
+
+static void clear_buffer_pool(struct proxy_ctx *proxy)
+{
+ for (int i = 1; i < proxy->buffer_pool.len; ++i) {
+ struct buf *b = proxy->buffer_pool.at[i];
+ free(b);
+ }
+ proxy->buffer_pool.len = 0;
+}
+
+static void alloc_uv_buffer(uv_handle_t *handle, size_t suggested_size, uv_buf_t *buf)
+{
+ buf->base = (char*)malloc(suggested_size);
+ buf->len = suggested_size;
+}
+
+static void on_client_close(uv_handle_t *handle)
+{
+ struct proxy_ctx *proxy = (struct proxy_ctx *)handle->loop->data;
+ proxy->client_state = STATE_NOT_CONNECTED;
+}
+
+static void on_upstream_close(uv_handle_t *handle)
+{
+ struct proxy_ctx *proxy = (struct proxy_ctx *)handle->loop->data;
+ proxy->upstream_state = STATE_NOT_CONNECTED;
+}
+
+static void write_to_client_cb(uv_write_t *req, int status)
+{
+ struct proxy_ctx *proxy = (struct proxy_ctx *)req->handle->loop->data;
+ free(req);
+ if (status) {
+ fprintf(stderr, "error writing to client: %s\n", uv_strerror(status));
+ clear_upstream_pending(proxy);
+ proxy->client_state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&proxy->client, on_client_close);
+ }
+}
+
+static void write_to_upstream_cb(uv_write_t *req, int status)
+{
+ struct proxy_ctx *proxy = (struct proxy_ctx *)req->handle->loop->data;
+ free(req);
+ if (status) {
+ fprintf(stderr, "error writing to upstream: %s\n", uv_strerror(status));
+ clear_upstream_pending(proxy);
+ proxy->upstream_state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close);
+ return;
+ }
+ if (proxy->upstream_pending.len > 0) {
+ struct buf *buf = get_first_upstream_pending(proxy);
+ remove_first_upstream_pending(proxy);
+ release_io_buffer(proxy, buf);
+ if (proxy->upstream_state == STATE_CONNECTED &&
+ proxy->upstream_pending.len > 0) {
+ buf = get_first_upstream_pending(proxy);
+ /* TODO avoid allocation */
+ uv_write_t *req = (uv_write_t *) malloc(sizeof(uv_write_t));
+ uv_buf_t wrbuf = uv_buf_init(buf->buf, buf->size);
+ uv_write(req, (uv_stream_t *)&proxy->upstream, &wrbuf, 1, write_to_upstream_cb);
+ }
+ }
+}
+
+static void on_client_connection(uv_stream_t *server, int status)
+{
+ if (status < 0) {
+ fprintf(stderr, "incoming connection error: %s\n", uv_strerror(status));
+ return;
+ }
+
+ fprintf(stdout, "incoming connection\n");
+
+ struct proxy_ctx *proxy = (struct proxy_ctx *)server->loop->data;
+ if (proxy->client_state != STATE_NOT_CONNECTED) {
+ fprintf(stderr, "client already connected, ignoring\n");
+ return;
+ }
+
+ uv_tcp_init(proxy->loop, &proxy->client);
+ proxy->client_state = STATE_CONNECTED;
+ if (uv_accept(server, (uv_stream_t*)&proxy->client) == 0) {
+ uv_read_start((uv_stream_t*)&proxy->client, alloc_uv_buffer, read_from_client_cb);
+ } else {
+ proxy->client_state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&proxy->client, on_client_close);
+ }
+}
+
+static void on_connect_to_upstream(uv_connect_t *req, int status)
+{
+ struct proxy_ctx *proxy = (struct proxy_ctx *)req->handle->loop->data;
+ free(req);
+ if (status < 0) {
+ fprintf(stderr, "error connecting to upstream: %s\n", uv_strerror(status));
+ clear_upstream_pending(proxy);
+ proxy->upstream_state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close);
+ return;
+ }
+
+ proxy->upstream_state = STATE_CONNECTED;
+ uv_read_start((uv_stream_t*)&proxy->upstream, alloc_uv_buffer, read_from_upstream_cb);
+ if (proxy->upstream_pending.len > 0) {
+ struct buf *buf = get_first_upstream_pending(proxy);
+ /* TODO avoid allocation */
+ uv_write_t *wreq = (uv_write_t *) malloc(sizeof(uv_write_t));
+ uv_buf_t wrbuf = uv_buf_init(buf->buf, buf->size);
+ uv_write(wreq, (uv_stream_t *)&proxy->upstream, &wrbuf, 1, write_to_upstream_cb);
+ }
+}
+
+static void read_from_client_cb(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf)
+{
+ if (nread == 0) {
+ return;
+ }
+ struct proxy_ctx *proxy = (struct proxy_ctx *)client->loop->data;
+ if (nread < 0) {
+ if (nread != UV_EOF) {
+ fprintf(stderr, "error reading from client: %s\n", uv_err_name(nread));
+ }
+ if (proxy->client_state == STATE_CONNECTED) {
+ proxy->client_state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*) client, on_client_close);
+ }
+ return;
+ }
+ if (proxy->upstream_state == STATE_CONNECTED) {
+ if (proxy->upstream_pending.len > 0) {
+ push_to_upstream_pending(proxy, buf->base, nread);
+ } else {
+ /* TODO avoid allocation */
+ uv_write_t *req = (uv_write_t *) malloc(sizeof(uv_write_t));
+ uv_buf_t wrbuf = uv_buf_init(buf->base, nread);
+ uv_write(req, (uv_stream_t *)&proxy->upstream, &wrbuf, 1, write_to_upstream_cb);
+ }
+ } else if (proxy->upstream_state == STATE_NOT_CONNECTED) {
+ /* TODO avoid allocation */
+ uv_tcp_init(proxy->loop, &proxy->upstream);
+ uv_connect_t *conn = (uv_connect_t *) malloc(sizeof(uv_connect_t));
+ proxy->upstream_state = STATE_CONNECT_IN_PROGRESS;
+ uv_tcp_connect(conn, &proxy->upstream, (struct sockaddr *)&proxy->upstream_addr,
+ on_connect_to_upstream);
+ push_to_upstream_pending(proxy, buf->base, nread);
+ } else if (proxy->upstream_state == STATE_CONNECT_IN_PROGRESS) {
+ push_to_upstream_pending(proxy, buf->base, nread);
+ }
+}
+
+static void read_from_upstream_cb(uv_stream_t *upstream, ssize_t nread, const uv_buf_t *buf)
+{
+ if (nread == 0) {
+ return;
+ }
+ struct proxy_ctx *proxy = (struct proxy_ctx *)upstream->loop->data;
+ if (nread < 0) {
+ if (nread != UV_EOF) {
+ fprintf(stderr, "error reading from upstream: %s\n", uv_err_name(nread));
+ }
+ clear_upstream_pending(proxy);
+ if (proxy->upstream_state == STATE_CONNECTED) {
+ proxy->upstream_state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close);
+ }
+ return;
+ }
+ if (proxy->client_state == STATE_CONNECTED) {
+ /* TODO Avoid allocation */
+ uv_write_t *req = (uv_write_t *) malloc(sizeof(uv_write_t));
+ uv_buf_t wrbuf = uv_buf_init(buf->base, nread);
+ uv_write(req, (uv_stream_t *)&proxy->client, &wrbuf, 1, write_to_client_cb);
+ }
+}
+
+struct proxy_ctx *proxy_allocate()
+{
+ return malloc(sizeof(struct proxy_ctx));
+}
+
+int proxy_init(struct proxy_ctx *proxy,
+ const char *server_addr, int server_port,
+ const char *upstream_addr, int upstream_port)
+{
+ proxy->loop = uv_default_loop();
+ uv_tcp_init(proxy->loop, &proxy->server);
+ int res = uv_ip4_addr(server_addr, server_port, (struct sockaddr_in *)&proxy->server_addr);
+ if (res != 0) {
+ return res;
+ }
+ res = uv_ip4_addr(upstream_addr, upstream_port, (struct sockaddr_in *)&proxy->upstream_addr);
+ if (res != 0) {
+ return res;
+ }
+ array_init(proxy->buffer_pool);
+ array_init(proxy->upstream_pending);
+ proxy->server_state = STATE_NOT_CONNECTED;
+ proxy->client_state = STATE_NOT_CONNECTED;
+ proxy->upstream_state = STATE_NOT_CONNECTED;
+
+ proxy->loop->data = proxy;
+ return 0;
+}
+
+void proxy_free(struct proxy_ctx *proxy)
+{
+ if (!proxy) {
+ return;
+ }
+ clear_upstream_pending(proxy);
+ clear_buffer_pool(proxy);
+ /* TODO correctly close all the uv_tcp_t */
+ free(proxy);
+}
+
+int proxy_start_listen(struct proxy_ctx *proxy)
+{
+ uv_tcp_bind(&proxy->server, (const struct sockaddr*)&proxy->server_addr, 0);
+ int ret = uv_listen((uv_stream_t*)&proxy->server, 128, on_client_connection);
+ if (ret == 0) {
+ proxy->server_state = STATE_LISTENING;
+ }
+ return ret;
+}
+
+int proxy_run(struct proxy_ctx *proxy)
+{
+ return uv_run(proxy->loop, UV_RUN_DEFAULT);
+}
diff --git a/tests/pytests/rehandshake/tcp-proxy.h b/tests/pytests/rehandshake/tcp-proxy.h
new file mode 100644
index 0000000..668a65f
--- /dev/null
+++ b/tests/pytests/rehandshake/tcp-proxy.h
@@ -0,0 +1,12 @@
+#pragma once
+
+struct proxy_ctx;
+
+struct proxy_ctx *proxy_allocate();
+void proxy_free(struct proxy_ctx *proxy);
+int proxy_init(struct proxy_ctx *proxy,
+ const char *server_addr, int server_port,
+ const char *upstream_addr, int upstream_port);
+int proxy_start_listen(struct proxy_ctx *proxy);
+int proxy_run(struct proxy_ctx *proxy);
+
diff --git a/tests/pytests/rehandshake/tcproxy.c b/tests/pytests/rehandshake/tcproxy.c
new file mode 100644
index 0000000..87a6b4c
--- /dev/null
+++ b/tests/pytests/rehandshake/tcproxy.c
@@ -0,0 +1,25 @@
+#include <stdio.h>
+#include "tcp-proxy.h"
+
+int main()
+{
+ struct proxy_ctx *proxy = proxy_allocate();
+ if (!proxy) {
+ fprintf(stderr, "can't allocate proxy structure\n");
+ return 1;
+ }
+ int res = proxy_init(proxy, "127.0.0.1", 54000, "127.0.0.1", 53001);
+ if (res) {
+ fprintf(stderr, "can't initialize proxy by given addresses\n");
+ return res;
+ }
+ res = proxy_start_listen(proxy);
+ if (res) {
+ fprintf(stderr, "error starting listen, error code: %i\n", res);
+ return res;
+ }
+ res = proxy_run(proxy);
+ proxy_free(proxy);
+ return res;
+}
+
diff --git a/tests/pytests/rehandshake/tls-proxy.c b/tests/pytests/rehandshake/tls-proxy.c
new file mode 100644
index 0000000..bf6cc0d
--- /dev/null
+++ b/tests/pytests/rehandshake/tls-proxy.c
@@ -0,0 +1,848 @@
+#include <assert.h>
+#include <stdio.h>
+#include <unistd.h>
+#include <string.h>
+#include <stdlib.h>
+#include <stdbool.h>
+#include <gnutls/gnutls.h>
+#include <uv.h>
+#include "array.h"
+
+#define TLS_MAX_SEND_RETRIES 100
+#define CLIENT_ANSWER_CHUNK_SIZE 8
+struct buf {
+ char buf[16 * 1024];
+ size_t size;
+};
+
+enum peer_state {
+ STATE_NOT_CONNECTED,
+ STATE_LISTENING,
+ STATE_CONNECTED,
+ STATE_CONNECT_IN_PROGRESS,
+ STATE_CLOSING_IN_PROGRESS
+};
+
+enum handshake_state {
+ TLS_HS_NOT_STARTED = 0,
+ TLS_HS_EXPECTED,
+ TLS_HS_IN_PROGRESS,
+ TLS_HS_DONE,
+ TLS_HS_CLOSING,
+ TLS_HS_LAST
+};
+
+struct tls_ctx {
+ gnutls_session_t session;
+ int handshake_state;
+ gnutls_certificate_credentials_t credentials;
+ gnutls_priority_t priority_cache;
+ /* for reading from the network */
+ const uint8_t *buf;
+ ssize_t nread;
+ ssize_t consumed;
+ uint8_t recv_buf[4096];
+};
+
+struct tls_proxy_ctx {
+ uv_loop_t *loop;
+ uv_tcp_t server;
+ uv_tcp_t client;
+ uv_tcp_t upstream;
+ struct sockaddr_storage server_addr;
+ struct sockaddr_storage upstream_addr;
+ struct sockaddr_storage client_addr;
+
+ int server_state;
+ int client_state;
+ int upstream_state;
+
+ array_t(struct buf *) buffer_pool;
+ array_t(struct buf *) upstream_pending;
+ array_t(struct buf *) client_pending;
+
+ char io_buf[0xFFFF];
+ struct tls_ctx tls;
+};
+
+static void read_from_upstream_cb(uv_stream_t *upstream, ssize_t nread, const uv_buf_t *buf);
+static void read_from_client_cb(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf);
+static ssize_t proxy_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len);
+static ssize_t proxy_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len);
+static int tls_process_from_upstream(struct tls_proxy_ctx *proxy, const uint8_t *buf, ssize_t nread);
+static int tls_process_from_client(struct tls_proxy_ctx *proxy, const uint8_t *buf, ssize_t nread);
+static int write_to_upstream_pending(struct tls_proxy_ctx *proxy);
+static int write_to_client_pending(struct tls_proxy_ctx *proxy);
+
+
+static int gnutls_references = 0;
+
+const void *ip_addr(const struct sockaddr *addr)
+{
+ if (!addr) {
+ return NULL;
+ }
+ switch (addr->sa_family) {
+ case AF_INET: return (const void *)&(((const struct sockaddr_in *)addr)->sin_addr);
+ case AF_INET6: return (const void *)&(((const struct sockaddr_in6 *)addr)->sin6_addr);
+ default: return NULL;
+ }
+}
+
+uint16_t ip_addr_port(const struct sockaddr *addr)
+{
+ if (!addr) {
+ return 0;
+ }
+ switch (addr->sa_family) {
+ case AF_INET: return ntohs(((const struct sockaddr_in *)addr)->sin_port);
+ case AF_INET6: return ntohs(((const struct sockaddr_in6 *)addr)->sin6_port);
+ default: return 0;
+ }
+}
+
+static int ip_addr_str(const struct sockaddr *addr, char *buf, size_t *buflen)
+{
+ int ret = 0;
+ if (!addr || !buf || !buflen) {
+ return EINVAL;
+ }
+
+ char str[INET6_ADDRSTRLEN + 6];
+ if (!inet_ntop(addr->sa_family, ip_addr(addr), str, sizeof(str))) {
+ return errno;
+ }
+ int len = strlen(str);
+ str[len] = '#';
+ snprintf(&str[len + 1], 6, "%uh", ip_addr_port(addr));
+ len += 6;
+ str[len] = 0;
+ if (len >= *buflen) {
+ ret = ENOSPC;
+ } else {
+ memcpy(buf, str, len + 1);
+ }
+ *buflen = len;
+ return ret;
+}
+
+static inline char *ip_straddr(const struct sockaddr *addr)
+{
+ assert(addr != NULL);
+ /* We are the sinle-threaded application */
+ static char str[INET6_ADDRSTRLEN + 6];
+ size_t len = sizeof(str);
+ int ret = ip_addr_str(addr, str, &len);
+ return ret != 0 || len == 0 ? NULL : str;
+}
+
+static struct buf *borrow_io_buffer(struct tls_proxy_ctx *proxy)
+{
+ struct buf *buf = NULL;
+ if (proxy->buffer_pool.len > 0) {
+ buf = array_tail(proxy->buffer_pool);
+ array_pop(proxy->buffer_pool);
+ } else {
+ buf = calloc(1, sizeof (struct buf));
+ }
+ return buf;
+}
+
+static void release_io_buffer(struct tls_proxy_ctx *proxy, struct buf *buf)
+{
+ if (!buf) {
+ return;
+ }
+
+ if (proxy->buffer_pool.len < 1000) {
+ buf->size = 0;
+ array_push(proxy->buffer_pool, buf);
+ } else {
+ free(buf);
+ }
+}
+
+static struct buf *get_first_upstream_pending(struct tls_proxy_ctx *proxy)
+{
+ struct buf *buf = NULL;
+ if (proxy->upstream_pending.len > 0) {
+ buf = proxy->upstream_pending.at[0];
+ }
+ return buf;
+}
+
+static struct buf *get_first_client_pending(struct tls_proxy_ctx *proxy)
+{
+ struct buf *buf = NULL;
+ if (proxy->client_pending.len > 0) {
+ buf = proxy->client_pending.at[0];
+ }
+ return buf;
+}
+
+static void remove_first_upstream_pending(struct tls_proxy_ctx *proxy)
+{
+ for (int i = 1; i < proxy->upstream_pending.len; ++i) {
+ proxy->upstream_pending.at[i - 1] = proxy->upstream_pending.at[i];
+ }
+ if (proxy->upstream_pending.len > 0) {
+ proxy->upstream_pending.len -= 1;
+ }
+}
+
+static void remove_first_client_pending(struct tls_proxy_ctx *proxy)
+{
+ for (int i = 1; i < proxy->client_pending.len; ++i) {
+ proxy->client_pending.at[i - 1] = proxy->client_pending.at[i];
+ }
+ if (proxy->client_pending.len > 0) {
+ proxy->client_pending.len -= 1;
+ }
+}
+
+static void clear_upstream_pending(struct tls_proxy_ctx *proxy)
+{
+ for (int i = 0; i < proxy->upstream_pending.len; ++i) {
+ struct buf *b = proxy->upstream_pending.at[i];
+ release_io_buffer(proxy, b);
+ }
+ proxy->upstream_pending.len = 0;
+}
+
+static void clear_client_pending(struct tls_proxy_ctx *proxy)
+{
+ for (int i = 0; i < proxy->client_pending.len; ++i) {
+ struct buf *b = proxy->client_pending.at[i];
+ release_io_buffer(proxy, b);
+ }
+ proxy->client_pending.len = 0;
+}
+
+static void clear_buffer_pool(struct tls_proxy_ctx *proxy)
+{
+ for (int i = 0; i < proxy->buffer_pool.len; ++i) {
+ struct buf *b = proxy->buffer_pool.at[i];
+ free(b);
+ }
+ proxy->buffer_pool.len = 0;
+}
+
+static void alloc_uv_buffer(uv_handle_t *handle, size_t suggested_size, uv_buf_t *buf)
+{
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)handle->loop->data;
+ buf->base = proxy->io_buf;
+ buf->len = sizeof(proxy->io_buf);
+}
+
+static void on_client_close(uv_handle_t *handle)
+{
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)handle->loop->data;
+ gnutls_deinit(proxy->tls.session);
+ proxy->tls.handshake_state = TLS_HS_NOT_STARTED;
+ proxy->client_state = STATE_NOT_CONNECTED;
+}
+
+static void on_dummmy_client_close(uv_handle_t *handle)
+{
+ free(handle);
+}
+
+static void on_upstream_close(uv_handle_t *handle)
+{
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)handle->loop->data;
+ proxy->upstream_state = STATE_NOT_CONNECTED;
+}
+
+static void write_to_client_cb(uv_write_t *req, int status)
+{
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)req->handle->loop->data;
+ free(req);
+ if (status) {
+ fprintf(stderr, "error writing to client: %s\n", uv_strerror(status));
+ clear_client_pending(proxy);
+ clear_upstream_pending(proxy);
+ if (proxy->client_state == STATE_CONNECTED) {
+ proxy->client_state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&proxy->client, on_client_close);
+ return;
+ }
+ }
+ fprintf(stdout, "successfully wrote to client, pending len is %zd\n",
+ proxy->client_pending.len);
+ if (proxy->client_state == STATE_CONNECTED &&
+ proxy->tls.handshake_state == TLS_HS_DONE) {
+ write_to_client_pending(proxy);
+ }
+}
+
+static void write_to_upstream_cb(uv_write_t *req, int status)
+{
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)req->handle->loop->data;
+ if (status) {
+ free(req);
+ fprintf(stderr, "error writing to upstream: %s\n", uv_strerror(status));
+ clear_upstream_pending(proxy);
+ proxy->upstream_state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close);
+ return;
+ }
+ if (req->data != NULL) {
+ assert(proxy->upstream_pending.len > 0);
+ struct buf *buf = get_first_upstream_pending(proxy);
+ assert(req->data == (void *)buf->buf);
+ fprintf(stdout, "successfully wrote %zi bytes to upstream, pending len is %zd\n",
+ buf->size, proxy->upstream_pending.len);
+ remove_first_upstream_pending(proxy);
+ release_io_buffer(proxy, buf);
+ } else {
+ fprintf(stdout, "successfully wrote bytes to upstream, pending len is %zd\n",
+ proxy->upstream_pending.len);
+ }
+ if (proxy->upstream_state == STATE_CONNECTED &&
+ proxy->upstream_pending.len > 0) {
+ write_to_upstream_pending(proxy);
+ }
+ free(req);
+}
+
+static void on_client_connection(uv_stream_t *server, int status)
+{
+ if (status < 0) {
+ fprintf(stderr, "incoming connection error: %s\n", uv_strerror(status));
+ return;
+ }
+
+ int err = 0;
+ int ret = 0;
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)server->loop->data;
+ if (proxy->client_state != STATE_NOT_CONNECTED) {
+ fprintf(stderr, "incoming connection");
+ uv_tcp_t *dummy_client = malloc(sizeof(uv_tcp_t));
+ uv_tcp_init(proxy->loop, dummy_client);
+ err = uv_accept(server, (uv_stream_t*)dummy_client);
+ if (err == 0) {
+ struct sockaddr dummy_addr;
+ int dummy_addr_len = sizeof(dummy_addr);
+ ret = uv_tcp_getpeername(dummy_client,
+ &dummy_addr,
+ &dummy_addr_len);
+ if (ret == 0) {
+ fprintf(stderr, " from %s", ip_straddr(&dummy_addr));
+ }
+ uv_close((uv_handle_t *)dummy_client, on_dummmy_client_close);
+ } else {
+ on_dummmy_client_close((uv_handle_t *)dummy_client);
+ }
+ fprintf(stderr, " - client already connected, rejecting\n");
+ return;
+ }
+
+ uv_tcp_init(proxy->loop, &proxy->client);
+ uv_tcp_nodelay((uv_tcp_t *)&proxy->client, 1);
+ proxy->client_state = STATE_CONNECTED;
+ err = uv_accept(server, (uv_stream_t*)&proxy->client);
+ if (err != 0) {
+ fprintf(stderr, "incoming connection - uv_accept() failed: (%d) %s\n",
+ err, uv_strerror(err));
+ return;
+ }
+
+ struct sockaddr *addr = (struct sockaddr *)&(proxy->client_addr);
+ int addr_len = sizeof(proxy->client_addr);
+ ret = uv_tcp_getpeername(&proxy->client, addr, &addr_len);
+ if (ret || addr->sa_family == AF_UNSPEC) {
+ proxy->client_state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&proxy->client, on_client_close);
+ fprintf(stderr, "incoming connection - uv_tcp_getpeername() failed: (%d) %s\n",
+ err, uv_strerror(err));
+ return;
+ }
+
+ fprintf(stdout, "incoming connection from %s\n", ip_straddr(addr));
+
+ uv_read_start((uv_stream_t*)&proxy->client, alloc_uv_buffer, read_from_client_cb);
+
+ const char *errpos = NULL;
+ struct tls_ctx *tls = &proxy->tls;
+ assert (tls->handshake_state == TLS_HS_NOT_STARTED);
+ err = gnutls_init(&tls->session, GNUTLS_SERVER | GNUTLS_NONBLOCK);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stderr, "gnutls_init() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ }
+ err = gnutls_priority_set(tls->session, tls->priority_cache);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stderr, "gnutls_priority_set() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ }
+ err = gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE, tls->credentials);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stderr, "gnutls_credentials_set() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ }
+ gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_IGNORE);
+ gnutls_handshake_set_timeout(tls->session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT);
+
+ gnutls_transport_set_pull_function(tls->session, proxy_gnutls_pull);
+ gnutls_transport_set_push_function(tls->session, proxy_gnutls_push);
+ gnutls_transport_set_ptr(tls->session, proxy);
+
+ tls->handshake_state = TLS_HS_IN_PROGRESS;
+}
+
+static void on_connect_to_upstream(uv_connect_t *req, int status)
+{
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)req->handle->loop->data;
+ free(req);
+ if (status < 0) {
+ fprintf(stderr, "error connecting to upstream (%s): %s\n",
+ ip_straddr((struct sockaddr *)&proxy->upstream_addr),
+ uv_strerror(status));
+ clear_upstream_pending(proxy);
+ proxy->upstream_state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close);
+ return;
+ }
+ fprintf(stdout, "connected to %s\n", ip_straddr((struct sockaddr *)&proxy->upstream_addr));
+
+ proxy->upstream_state = STATE_CONNECTED;
+ uv_read_start((uv_stream_t*)&proxy->upstream, alloc_uv_buffer, read_from_upstream_cb);
+ if (proxy->upstream_pending.len > 0) {
+ write_to_upstream_pending(proxy);
+ }
+}
+
+static void read_from_client_cb(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf)
+{
+ fprintf(stdout, "reading %zd bytes from client\n", nread);
+ if (nread == 0) {
+ return;
+ }
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)client->loop->data;
+ if (nread < 0) {
+ if (nread != UV_EOF) {
+ fprintf(stderr, "error reading from client: %s\n", uv_err_name(nread));
+ } else {
+ fprintf(stdout, "client has closed the connection\n");
+ }
+ if (proxy->client_state == STATE_CONNECTED) {
+ proxy->client_state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*) client, on_client_close);
+ }
+ return;
+ }
+
+ int res = tls_process_from_client(proxy, buf->base, nread);
+ if (res < 0) {
+ if (proxy->client_state == STATE_CONNECTED) {
+ proxy->client_state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*) client, on_client_close);
+ }
+ }
+}
+
+static void read_from_upstream_cb(uv_stream_t *upstream, ssize_t nread, const uv_buf_t *buf)
+{
+ fprintf(stdout, "reading %zd bytes from upstream\n", nread);
+ if (nread == 0) {
+ return;
+ }
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)upstream->loop->data;
+ if (nread < 0) {
+ if (nread != UV_EOF) {
+ fprintf(stderr, "error reading from upstream: %s\n", uv_err_name(nread));
+ } else {
+ fprintf(stdout, "upstream has closed the connection\n");
+ }
+ clear_upstream_pending(proxy);
+ if (proxy->upstream_state == STATE_CONNECTED) {
+ proxy->upstream_state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close);
+ }
+ return;
+ }
+ int res = tls_process_from_upstream(proxy, buf->base, nread);
+ if (res < 0) {
+ fprintf(stderr, "error sending tls data to client\n");
+ if (proxy->client_state == STATE_CONNECTED) {
+ proxy->client_state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&proxy->client, on_client_close);
+ }
+ }
+}
+
+static void push_to_upstream_pending(struct tls_proxy_ctx *proxy, const char *buf, size_t size)
+{
+ while (size > 0) {
+ struct buf *b = borrow_io_buffer(proxy);
+ b->size = size <= sizeof(b->buf) ? size : sizeof(b->buf);
+ memcpy(b->buf, buf, b->size);
+ array_push(proxy->upstream_pending, b);
+ size -= b->size;
+ buf += b->size;
+ }
+}
+
+static void push_to_client_pending(struct tls_proxy_ctx *proxy, const char *buf, size_t size)
+{
+ while (size > 0) {
+ struct buf *b = borrow_io_buffer(proxy);
+ b->size = size <= sizeof(b->buf) ? size : sizeof(b->buf);
+ if (b->size > CLIENT_ANSWER_CHUNK_SIZE) {
+ b->size = CLIENT_ANSWER_CHUNK_SIZE;
+ }
+ memcpy(b->buf, buf, b->size);
+ array_push(proxy->client_pending, b);
+ size -= b->size;
+ buf += b->size;
+ }
+}
+
+static int write_to_upstream_pending(struct tls_proxy_ctx *proxy)
+{
+ struct buf *buf = get_first_upstream_pending(proxy);
+ /* TODO avoid allocation */
+ uv_write_t *req = (uv_write_t *) malloc(sizeof(uv_write_t));
+ uv_buf_t wrbuf = uv_buf_init(buf->buf, buf->size);
+ req->data = buf->buf;
+ fprintf(stdout, "writing %zd bytes to upstream\n", buf->size);
+ return uv_write(req, (uv_stream_t *)&proxy->upstream, &wrbuf, 1, write_to_upstream_cb);
+}
+
+static ssize_t proxy_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len)
+{
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)h;
+ struct tls_ctx *t = &proxy->tls;
+
+ fprintf(stdout, "\t gnutls: pulling %zd bytes from client\n", len);
+
+ if (t->nread <= t->consumed) {
+ errno = EAGAIN;
+ fprintf(stdout, "\t gnutls: return EAGAIN\n");
+ return -1;
+ }
+
+ ssize_t avail = t->nread - t->consumed;
+ ssize_t transfer = (avail <= len ? avail : len);
+ memcpy(buf, t->buf + t->consumed, transfer);
+ t->consumed += transfer;
+ return transfer;
+}
+
+ssize_t proxy_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len)
+{
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)h;
+ struct tls_ctx *t = &proxy->tls;
+ fprintf(stdout, "\t gnutls: writing %zd bytes to client\n", len);
+
+ ssize_t ret = -1;
+ const size_t req_size_aligned = ((sizeof(uv_write_t) / 16) + 1) * 16;
+ char *common_buf = malloc(req_size_aligned + len);
+ uv_write_t *req = (uv_write_t *) common_buf;
+ char *data = common_buf + req_size_aligned;
+ const uv_buf_t uv_buf[1] = {
+ { data, len }
+ };
+ memcpy(data, buf, len);
+ req->data = data;
+ int res = uv_write(req, (uv_stream_t *)&proxy->client, uv_buf, 1, write_to_client_cb);
+ if (res == 0) {
+ ret = len;
+ } else {
+ free(common_buf);
+ errno = EIO;
+ }
+ return ret;
+}
+
+static int write_to_client_pending(struct tls_proxy_ctx *proxy)
+{
+ if (proxy->client_pending.len == 0) {
+ return 0;
+ }
+
+ struct buf *buf = get_first_client_pending(proxy);
+ uv_buf_t wrbuf = uv_buf_init(buf->buf, buf->size);
+ fprintf(stdout, "writing %zd bytes to client\n", buf->size);
+
+ gnutls_session_t tls_session = proxy->tls.session;
+ assert(proxy->tls.handshake_state != TLS_HS_IN_PROGRESS);
+ assert(gnutls_record_check_corked(tls_session) == 0);
+
+ char *data = buf->buf;
+ size_t len = buf->size;
+
+ ssize_t count = 0;
+ ssize_t submitted = len;
+ ssize_t retries = 0;
+ do {
+ count = gnutls_record_send(tls_session, data, len);
+ if (count < 0) {
+ if (gnutls_error_is_fatal(count)) {
+ fprintf(stderr, "gnutls_record_send failed: %s (%zd)\n",
+ gnutls_strerror_name(count), count);
+ return -1;
+ }
+ if (++retries > TLS_MAX_SEND_RETRIES) {
+ fprintf(stderr, "gnutls_record_send: too many sequential non-fatal errors (%zd), last error is: %s (%zd)\n",
+ retries, gnutls_strerror_name(count), count);
+ return -1;
+ }
+ } else if (count != 0) {
+ data += count;
+ len -= count;
+ retries = 0;
+ } else {
+ if (++retries < TLS_MAX_SEND_RETRIES) {
+ continue;
+ }
+ fprintf(stderr, "gnutls_record_send: too many retries (%zd)\n",
+ retries);
+ fprintf(stderr, "tls_push_to_client didn't send all data(%zd of %zd)\n",
+ len, submitted);
+ return -1;
+ }
+ } while (len > 0);
+
+ remove_first_client_pending(proxy);
+ release_io_buffer(proxy, buf);
+
+ fprintf(stdout, "submitted %zd bytes to client\n", submitted);
+ assert (gnutls_safe_renegotiation_status(tls_session) != 0);
+ assert (gnutls_rehandshake(tls_session) == GNUTLS_E_SUCCESS);
+ /* Prevent write-to-client callback from sending next pending chunk.
+ * At the same time tls_process_from_client() must not call gnutls_handshake()
+ * as there can be application data in this direction. */
+ proxy->tls.handshake_state = TLS_HS_EXPECTED;
+ fprintf(stdout, "rehandshake started\n");
+ return submitted;
+}
+
+static int tls_process_from_upstream(struct tls_proxy_ctx *proxy, const uint8_t *buf, ssize_t len)
+{
+ gnutls_session_t tls_session = proxy->tls.session;
+
+ fprintf(stdout, "pushing %zd bytes to client\n", len);
+
+ assert(gnutls_record_check_corked(tls_session) == 0);
+ ssize_t submitted = 0;
+ if (proxy->client_state != STATE_CONNECTED) {
+ return submitted;
+ }
+
+ bool list_was_empty = (proxy->client_pending.len == 0);
+ push_to_client_pending(proxy, buf, len);
+ submitted = len;
+ if (proxy->tls.handshake_state == TLS_HS_DONE) {
+ if (list_was_empty && proxy->client_pending.len > 0) {
+ int ret = write_to_client_pending(proxy);
+ if (ret < 0) {
+ submitted = -1;
+ }
+ }
+ }
+
+ return submitted;
+}
+
+int tls_process_handshake(struct tls_proxy_ctx *proxy)
+{
+ struct tls_ctx *tls = &proxy->tls;
+ int ret = 1;
+ while (tls->handshake_state == TLS_HS_IN_PROGRESS) {
+ fprintf(stdout, "TLS handshake in progress...\n");
+ int err = gnutls_handshake(tls->session);
+ if (err == GNUTLS_E_SUCCESS) {
+ tls->handshake_state = TLS_HS_DONE;
+ fprintf(stdout, "TLS handshake has completed\n");
+ ret = 1;
+ if (proxy->client_pending.len != 0) {
+ write_to_client_pending(proxy);
+ }
+ } else if (gnutls_error_is_fatal(err)) {
+ fprintf(stderr, "gnutls_handshake failed: %s (%d)\n",
+ gnutls_strerror_name(err), err);
+ ret = -1;
+ break;
+ } else {
+ fprintf(stderr, "gnutls_handshake nonfatal error: %s (%d)\n",
+ gnutls_strerror_name(err), err);
+ ret = 0;
+ break;
+ }
+ }
+ return ret;
+}
+
+int tls_process_from_client(struct tls_proxy_ctx *proxy, const uint8_t *buf, ssize_t nread)
+{
+ struct tls_ctx *tls = &proxy->tls;
+
+ tls->buf = buf;
+ tls->nread = nread >= 0 ? nread : 0;
+ tls->consumed = 0;
+
+ fprintf(stdout, "tls_process: reading %zd bytes from client\n", nread);
+
+ int ret = tls_process_handshake(proxy);
+ if (ret <= 0) {
+ return ret;
+ }
+
+ int submitted = 0;
+ while (true) {
+ ssize_t count = 0;
+ count = gnutls_record_recv(tls->session, tls->recv_buf, sizeof(tls->recv_buf));
+ if (count == GNUTLS_E_AGAIN) {
+ break; /* No data available */
+ } else if (count == GNUTLS_E_INTERRUPTED) {
+ continue; /* Try reading again */
+ } else if (count == GNUTLS_E_REHANDSHAKE) {
+ tls->handshake_state = TLS_HS_IN_PROGRESS;
+ ret = tls_process_handshake(proxy);
+ if (ret <= 0) {
+ return ret;
+ }
+ continue;
+ } else if (count < 0) {
+ fprintf(stderr, "gnutls_record_recv failed: %s (%zd)\n",
+ gnutls_strerror_name(count), count);
+ return -1;
+ } else if (count == 0) {
+ break;
+ }
+ if (proxy->upstream_state == STATE_CONNECTED) {
+ bool upstream_pending_is_empty = (proxy->upstream_pending.len == 0);
+ push_to_upstream_pending(proxy, tls->recv_buf, count);
+ if (upstream_pending_is_empty) {
+ write_to_upstream_pending(proxy);
+ }
+ } else if (proxy->upstream_state == STATE_NOT_CONNECTED) {
+ /* TODO avoid allocation */
+ uv_tcp_init(proxy->loop, &proxy->upstream);
+ uv_connect_t *conn = (uv_connect_t *) malloc(sizeof(uv_connect_t));
+ proxy->upstream_state = STATE_CONNECT_IN_PROGRESS;
+ fprintf(stdout, "connecting to %s\n",
+ ip_straddr((struct sockaddr *)&proxy->upstream_addr));
+ uv_tcp_connect(conn, &proxy->upstream, (struct sockaddr *)&proxy->upstream_addr,
+ on_connect_to_upstream);
+ push_to_upstream_pending(proxy, tls->recv_buf, count);
+ } else if (proxy->upstream_state == STATE_CONNECT_IN_PROGRESS) {
+ push_to_upstream_pending(proxy, tls->recv_buf, count);
+ }
+ submitted += count;
+ }
+ return submitted;
+}
+
+struct tls_proxy_ctx *tls_proxy_allocate()
+{
+ return malloc(sizeof(struct tls_proxy_ctx));
+}
+
+int tls_proxy_init(struct tls_proxy_ctx *proxy,
+ const char *server_addr, int server_port,
+ const char *upstream_addr, int upstream_port,
+ const char *cert_file, const char *key_file)
+{
+ proxy->loop = uv_default_loop();
+ uv_tcp_init(proxy->loop, &proxy->server);
+ int res = uv_ip4_addr(server_addr, server_port, (struct sockaddr_in *)&proxy->server_addr);
+ if (res != 0) {
+ fprintf(stderr, "uv_ip4_addr failed with string '%s'\n", server_addr);
+ return -1;
+ }
+ res = uv_ip4_addr(upstream_addr, upstream_port, (struct sockaddr_in *)&proxy->upstream_addr);
+ if (res != 0) {
+ fprintf(stderr, "uv_ip4_addr failed with string '%s'\n", upstream_addr);
+ return -1;
+ }
+ array_init(proxy->buffer_pool);
+ array_init(proxy->upstream_pending);
+ array_init(proxy->client_pending);
+ proxy->server_state = STATE_NOT_CONNECTED;
+ proxy->client_state = STATE_NOT_CONNECTED;
+ proxy->upstream_state = STATE_NOT_CONNECTED;
+
+ proxy->loop->data = proxy;
+
+ int err = 0;
+ if (gnutls_references == 0) {
+ err = gnutls_global_init();
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stderr, "gnutls_global_init() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ return -1;
+ }
+ }
+ gnutls_references += 1;
+
+ err = gnutls_certificate_allocate_credentials(&proxy->tls.credentials);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stderr, "gnutls_certificate_allocate_credentials() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ return -1;
+ }
+
+ err = gnutls_certificate_set_x509_system_trust(proxy->tls.credentials);
+ if (err <= 0) {
+ fprintf(stderr, "gnutls_certificate_set_x509_system_trust() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ return -1;
+ }
+
+ if (cert_file && key_file) {
+ err = gnutls_certificate_set_x509_key_file(proxy->tls.credentials,
+ cert_file, key_file, GNUTLS_X509_FMT_PEM);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stderr, "gnutls_certificate_set_x509_key_file() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ return -1;
+ }
+ }
+
+ err = gnutls_priority_init(&proxy->tls.priority_cache, NULL, NULL);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stderr, "gnutls_priority_init() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ return -1;
+ }
+
+
+ proxy->tls.handshake_state = TLS_HS_NOT_STARTED;
+ return 0;
+}
+
+void tls_proxy_free(struct tls_proxy_ctx *proxy)
+{
+ if (!proxy) {
+ return;
+ }
+ clear_upstream_pending(proxy);
+ clear_client_pending(proxy);
+ clear_buffer_pool(proxy);
+ gnutls_certificate_free_credentials(proxy->tls.credentials);
+ gnutls_priority_deinit(proxy->tls.priority_cache);
+ /* TODO correctly close all the uv_tcp_t */
+ free(proxy);
+
+ gnutls_references -= 1;
+ if (gnutls_references == 0) {
+ gnutls_global_deinit();
+ }
+}
+
+int tls_proxy_start_listen(struct tls_proxy_ctx *proxy)
+{
+ uv_tcp_bind(&proxy->server, (const struct sockaddr*)&proxy->server_addr, 0);
+ int ret = uv_listen((uv_stream_t*)&proxy->server, 128, on_client_connection);
+ if (ret == 0) {
+ proxy->server_state = STATE_LISTENING;
+ }
+ return ret;
+}
+
+int tls_proxy_run(struct tls_proxy_ctx *proxy)
+{
+ return uv_run(proxy->loop, UV_RUN_DEFAULT);
+}
diff --git a/tests/pytests/rehandshake/tls-proxy.h b/tests/pytests/rehandshake/tls-proxy.h
new file mode 100644
index 0000000..1204eda
--- /dev/null
+++ b/tests/pytests/rehandshake/tls-proxy.h
@@ -0,0 +1,14 @@
+#pragma once
+
+struct tls_proxy_ctx;
+
+struct tls_proxy_ctx *tls_proxy_allocate();
+void tls_proxy_free(struct tls_proxy_ctx *proxy);
+int tls_proxy_init(struct tls_proxy_ctx *proxy,
+ const char *server_addr, int server_port,
+ const char *upstream_addr, int upstream_port,
+ const char *cert_file, const char *key_file);
+int tls_proxy_start_listen(struct tls_proxy_ctx *proxy);
+int tls_proxy_run(struct tls_proxy_ctx *proxy);
+
+
diff --git a/tests/pytests/rehandshake/tlsproxy.c b/tests/pytests/rehandshake/tlsproxy.c
new file mode 100644
index 0000000..0c074f1
--- /dev/null
+++ b/tests/pytests/rehandshake/tlsproxy.c
@@ -0,0 +1,31 @@
+#include <stdio.h>
+#include "tls-proxy.h"
+#include <gnutls/gnutls.h>
+
+int main()
+{
+ struct tls_proxy_ctx *proxy = tls_proxy_allocate();
+ if (!proxy) {
+ fprintf(stderr, "can't allocate tls_proxy structure\n");
+ return 1;
+ }
+ int res = tls_proxy_init(proxy,
+ "127.0.0.1", 53921, /* Address to listen */
+ "127.0.0.1", 53910, /* Upstream address */
+ "../certs/tt.cert.pem",
+ "../certs/tt.key.pem");
+ if (res) {
+ fprintf(stderr, "can't initialize tls_proxy structure\n");
+ return res;
+ }
+ res = tls_proxy_start_listen(proxy);
+ if (res) {
+ fprintf(stderr, "error starting listen, error code: %i\n", res);
+ return res;
+ }
+ fprintf(stdout, "started...\n");
+ res = tls_proxy_run(proxy);
+ tls_proxy_free(proxy);
+ return res;
+}
+
diff --git a/tests/pytests/requirements.txt b/tests/pytests/requirements.txt
new file mode 100644
index 0000000..6e2e4d2
--- /dev/null
+++ b/tests/pytests/requirements.txt
@@ -0,0 +1,5 @@
+dnspython
+jinja2
+pytest
+pytest-html
+pytest-xdist
diff --git a/tests/pytests/templates/kresd.conf.j2 b/tests/pytests/templates/kresd.conf.j2
new file mode 100644
index 0000000..4d95521
--- /dev/null
+++ b/tests/pytests/templates/kresd.conf.j2
@@ -0,0 +1,42 @@
+modules = {
+ 'policy',
+ 'hints > iterate',
+}
+
+verbose({{ 'true' if kresd.verbose else 'false' }})
+
+{% if kresd.ip %}
+net.listen('{{ kresd.ip }}', {{ kresd.port }})
+net.listen('{{ kresd.ip }}', {{ kresd.tls_port }}, {tls = true})
+{% endif %}
+
+{% if kresd.ip6 %}
+net.listen('{{ kresd.ip6 }}', {{ kresd.port }})
+net.listen('{{ kresd.ip6 }}', {{ kresd.tls_port }}, {tls = true})
+{% endif %}
+
+net.ipv4=true
+net.ipv6=true
+
+{% if kresd.tls_key_path and kresd.tls_cert_path %}
+net.tls("{{ kresd.tls_cert_path }}", "{{ kresd.tls_key_path }}")
+{% endif %}
+
+{% for name, ip in kresd.hints.items() %}
+hints['{{ name }}'] = '{{ ip }}'
+{% endfor %}
+
+policy.add(policy.all(policy.QTRACE))
+
+{% if kresd.forward %}
+policy.add(policy.all(
+ {% if kresd.forward.proto == 'tls' %}
+ policy.TLS_FORWARD({
+ {"{{ kresd.forward.ip }}@{{ kresd.forward.port }}", hostname='{{ kresd.forward.hostname}}', ca_file='{{ kresd.forward.ca_file }}'}})
+ {% endif %}
+))
+{% endif %}
+
+modules.unload("ta_signal_query")
+modules.unload("priming")
+modules.unload("detect_time_skew")
diff --git a/tests/pytests/test_conn_mgmt.py b/tests/pytests/test_conn_mgmt.py
new file mode 100644
index 0000000..c4b1cba
--- /dev/null
+++ b/tests/pytests/test_conn_mgmt.py
@@ -0,0 +1,213 @@
+"""TCP Connection Management tests"""
+
+import socket
+import struct
+import time
+
+import pytest
+
+import utils
+
+
+@pytest.mark.parametrize('garbage_lengths', [
+ (1,),
+ (1024,),
+ (65533,), # max size garbage
+ (65533, 65533),
+ (1024, 1024, 1024),
+ # (0,), # currently kresd uses this as a heuristic of "lost in bytestream"
+ # (0, 1024), # and closes the connection
+])
+def test_ignore_garbage(kresd_sock, garbage_lengths, single_buffer, query_before):
+ """Send chunk of garbage, prefixed by garbage length. It should be ignored."""
+ buff = b''
+ if query_before: # optionally send initial query
+ msg_buff_before, msgid_before = utils.get_msgbuff()
+ if single_buffer:
+ buff += msg_buff_before
+ else:
+ kresd_sock.sendall(msg_buff_before)
+
+ for glength in garbage_lengths: # prepare garbage data
+ if glength is None:
+ continue
+ garbage_buff = utils.get_prefixed_garbage(glength)
+ if single_buffer:
+ buff += garbage_buff
+ else:
+ kresd_sock.sendall(garbage_buff)
+
+ msg_buff, msgid = utils.get_msgbuff() # final query
+ buff += msg_buff
+ kresd_sock.sendall(buff)
+
+ if query_before:
+ answer_before = utils.receive_parse_answer(kresd_sock)
+ assert answer_before.id == msgid_before
+ answer = utils.receive_parse_answer(kresd_sock)
+ assert answer.id == msgid
+
+
+def test_pipelining(kresd_sock):
+ """
+ First query takes longer to resolve - answer to second query should arrive sooner.
+
+ This test requires internet connection.
+ """
+ # initialization (to avoid issues with net.ipv6=true)
+ buff_pre, msgid_pre = utils.get_msgbuff('0.delay.getdnsapi.net.')
+ kresd_sock.sendall(buff_pre)
+ msg_answer = utils.receive_parse_answer(kresd_sock)
+ assert msg_answer.id == msgid_pre
+
+ # test
+ buff1, msgid1 = utils.get_msgbuff('1500.delay.getdnsapi.net.', msgid=1)
+ buff2, msgid2 = utils.get_msgbuff('1.delay.getdnsapi.net.', msgid=2)
+ buff = buff1 + buff2
+ kresd_sock.sendall(buff)
+
+ msg_answer = utils.receive_parse_answer(kresd_sock)
+ assert msg_answer.id == msgid2
+
+ msg_answer = utils.receive_parse_answer(kresd_sock)
+ assert msg_answer.id == msgid1
+
+
+@pytest.mark.parametrize('duration, delay', [
+ (utils.MAX_TIMEOUT, 0.1),
+ (utils.MAX_TIMEOUT, 3),
+ (utils.MAX_TIMEOUT, 7),
+ (utils.MAX_TIMEOUT + 10, 3),
+])
+def test_long_lived(kresd_sock, duration, delay):
+ """Establish and keep connection alive for longer than maximum timeout."""
+ utils.ping_alive(kresd_sock)
+ end_time = time.time() + duration
+
+ while time.time() < end_time:
+ time.sleep(delay)
+ utils.ping_alive(kresd_sock)
+
+
+def test_close(kresd_sock, query_before):
+ """Establish a connection and wait for timeout from kresd."""
+ if query_before:
+ utils.ping_alive(kresd_sock)
+ time.sleep(utils.MAX_TIMEOUT)
+
+ with utils.expect_kresd_close():
+ utils.ping_alive(kresd_sock)
+
+
+def test_slow_lorris(kresd_sock, query_before):
+ """Simulate slow-lorris attack by sending byte after byte with delays in between."""
+ if query_before:
+ utils.ping_alive(kresd_sock)
+
+ buff, _ = utils.get_msgbuff()
+ end_time = time.time() + utils.MAX_TIMEOUT
+
+ with utils.expect_kresd_close():
+ for i in range(len(buff)):
+ b = buff[i:i+1]
+ kresd_sock.send(b)
+ if time.time() > end_time:
+ break
+ time.sleep(1)
+
+
+@pytest.mark.parametrize('sock_func_name', [
+ 'ip_tcp_socket',
+ 'ip6_tcp_socket',
+])
+def test_oob(kresd, sock_func_name):
+ """TCP out-of-band (urgent) data must not crash resolver."""
+ make_sock = getattr(kresd, sock_func_name)
+ sock = make_sock()
+ msg_buff, msgid = utils.get_msgbuff()
+ sock.sendall(msg_buff, socket.MSG_OOB)
+
+ try:
+ msg_answer = utils.receive_parse_answer(sock)
+ assert msg_answer.id == msgid
+ except ConnectionError:
+ pass # TODO kresd responds with TCP RST, this should be fixed
+
+ # check kresd is alive
+ sock2 = make_sock()
+ utils.ping_alive(sock2)
+
+
+def flood_buffer(msgcount):
+ flood_buff = bytes()
+ msgbuff, _ = utils.get_msgbuff()
+ noid_msgbuff = msgbuff[2:]
+
+ def gen_msg(msgid):
+ return struct.pack("!H", len(msgbuff)) + struct.pack("!H", msgid) + noid_msgbuff
+
+ for i in range(msgcount):
+ flood_buff += gen_msg(i)
+ return flood_buff
+
+
+def test_query_flood_close(make_kresd_sock):
+ """Flood resolver with queries and close the connection."""
+ buff = flood_buffer(10000)
+ sock1 = make_kresd_sock()
+ sock1.sendall(buff)
+ sock1.close()
+
+ sock2 = make_kresd_sock()
+ utils.ping_alive(sock2)
+
+
+def test_query_flood_no_recv(make_kresd_sock):
+ """Flood resolver with queries but don't read any data."""
+ # A use-case for TCP_USER_TIMEOUT socket option? See RFC 793 and RFC 5482
+
+ # It seems it doesn't works as expected. libuv doesn't return any error
+ # (neither on uv_write() call, not in the callback) when kresd sends answers,
+ # so kresd can't recognize that client didn't read any answers. At a certain
+ # point, kresd stops receiving queries from the client (whilst client keep
+ # sending) and closes connection due to timeout.
+
+ buff = flood_buffer(10000)
+ sock1 = make_kresd_sock()
+ end_time = time.time() + utils.MAX_TIMEOUT
+
+ with utils.expect_kresd_close(rst_ok=True): # connection must be closed
+ while time.time() < end_time:
+ sock1.sendall(buff)
+ time.sleep(0.5)
+
+ sock2 = make_kresd_sock()
+ utils.ping_alive(sock2) # resolver must stay alive
+
+
+@pytest.mark.parametrize('glength, gcount, delay', [
+ (65533, 100, 0.5),
+ (0, 100000, 0.5),
+ (1024, 1000, 0.5),
+ (65533, 1, 0),
+ (0, 1, 0),
+ (1024, 1, 0),
+])
+def test_query_flood_garbage(make_kresd_sock, glength, gcount, delay, query_before):
+ """Flood resolver with prefixed garbage."""
+ sock1 = make_kresd_sock()
+ if query_before:
+ utils.ping_alive(sock1)
+
+ gbuff = utils.get_prefixed_garbage(glength)
+ buff = gbuff * gcount
+
+ end_time = time.time() + utils.MAX_TIMEOUT
+
+ with utils.expect_kresd_close(rst_ok=True): # connection must be closed
+ while time.time() < end_time:
+ sock1.sendall(buff)
+ time.sleep(delay)
+
+ sock2 = make_kresd_sock()
+ utils.ping_alive(sock2) # resolver must stay alive
diff --git a/tests/pytests/test_prefix.py b/tests/pytests/test_prefix.py
new file mode 100644
index 0000000..5d8ef16
--- /dev/null
+++ b/tests/pytests/test_prefix.py
@@ -0,0 +1,113 @@
+"""TCP Connection Management tests - prefix length
+
+RFC1035
+4.2.2. TCP usage
+The message is prefixed with a two byte length field which gives the message
+length, excluding the two byte length field.
+
+The following test suite focuses on edge cases for the prefix - when it
+is either too short or too long, instead of matching the length of DNS
+message exactly.
+"""
+
+import time
+
+import pytest
+
+import utils
+
+
+@pytest.fixture(params=[
+ 'no_query_before',
+ 'query_before',
+ 'query_before_in_single_buffer',
+])
+def send_query(request):
+ """Function sends a buffer, either by itself, or with a valid query before.
+ If a valid query is sent before, it can be sent either in a separate buffer, or
+ along with the provided buffer."""
+
+ # pylint: disable=possibly-unused-variable
+
+ def no_query_before(sock, buff): # pylint: disable=unused-argument
+ sock.sendall(buff)
+
+ def query_before(sock, buff, single_buffer=False):
+ """Send an initial query and expect a response."""
+ msg_buff, msgid = utils.get_msgbuff()
+
+ if single_buffer:
+ sock.sendall(msg_buff + buff)
+ else:
+ sock.sendall(msg_buff)
+ sock.sendall(buff)
+
+ answer = utils.receive_parse_answer(sock)
+ assert answer.id == msgid
+
+ def query_before_in_single_buffer(sock, buff):
+ return query_before(sock, buff, single_buffer=True)
+
+ return locals()[request.param]
+
+
+@pytest.mark.parametrize('datalen', [
+ 1, # just one byte of DNS header
+ 11, # DNS header size minus 1
+ 14, # DNS Header size plus 2
+])
+def test_prefix_cuts_message(kresd_sock, datalen, send_query):
+ """Prefix is shorter than the DNS message."""
+ wire, _ = utils.prepare_wire()
+ assert datalen < len(wire)
+ invalid_buff = utils.prepare_buffer(wire, datalen)
+
+ send_query(kresd_sock, invalid_buff) # buffer breaks parsing of TCP stream
+
+ with utils.expect_kresd_close():
+ utils.ping_alive(kresd_sock)
+
+
+def test_prefix_greater_than_message(kresd_sock, send_query):
+ """Prefix is greater than the length of the entire DNS message."""
+ wire, invalid_msgid = utils.prepare_wire()
+ datalen = len(wire) + 16
+ invalid_buff = utils.prepare_buffer(wire, datalen)
+
+ send_query(kresd_sock, invalid_buff)
+
+ valid_buff, _ = utils.get_msgbuff()
+ kresd_sock.sendall(valid_buff)
+
+ # invalid_buff is answered (treats additional data as trailing garbage)
+ answer = utils.receive_parse_answer(kresd_sock)
+ assert answer.id == invalid_msgid
+
+ # parsing stream is broken by the invalid_buff, valid query is never answered
+ with utils.expect_kresd_close():
+ utils.receive_parse_answer(kresd_sock)
+
+
+@pytest.mark.parametrize('glength', [
+ 0,
+ 1,
+ 8,
+ 1024,
+ 4096,
+ 20000,
+])
+def test_prefix_trailing_garbage(kresd_sock, glength, query_before):
+ """Send messages with trailing garbage (its length included in prefix)."""
+ if query_before:
+ utils.ping_alive(kresd_sock)
+
+ for _ in range(10):
+ wire, msgid = utils.prepare_wire()
+ wire += utils.get_garbage(glength)
+ buff = utils.prepare_buffer(wire)
+
+ kresd_sock.sendall(buff)
+ answer = utils.receive_parse_answer(kresd_sock)
+ assert answer.id == msgid
+
+ time.sleep(0.1)
diff --git a/tests/pytests/test_rehandshake.py b/tests/pytests/test_rehandshake.py
new file mode 100644
index 0000000..ffbc10b
--- /dev/null
+++ b/tests/pytests/test_rehandshake.py
@@ -0,0 +1,87 @@
+"""TLS rehandshake test
+
+Test utilizes rehandshake/tls-proxy, which forwards queries to configured
+resolver, but when it sends the response back to the query source, it
+performs a rehandshake after every 8 bytes sent.
+
+It is expected the answer will be received by the source kresd instance
+and sent back to the client (this test).
+
+Make sure to run `make all` in `rehandshake/` to compile the proxy.
+"""
+
+import os
+import re
+import subprocess
+import time
+
+import dns
+import dns.rcode
+import pytest
+
+from kresd import CERTS_DIR, Forward, make_kresd, PYTESTS_DIR
+import utils
+
+
+REHANDSHAKE_PROXY = os.path.join(PYTESTS_DIR, 'rehandshake', 'tlsproxy')
+
+
+@pytest.mark.skipif(not os.path.exists(REHANDSHAKE_PROXY),
+ reason="tlsproxy not found (did you compile it?)")
+def test_rehandshake(tmpdir):
+ def resolve_hint(sock, qname):
+ buff, msgid = utils.get_msgbuff(qname)
+ sock.sendall(buff)
+ answer = utils.receive_parse_answer(sock)
+ assert answer.id == msgid
+ assert answer.rcode() == dns.rcode.NOERROR
+ assert answer.answer[0][0].address == '127.0.0.1'
+
+ hints = {
+ '0.foo.': '127.0.0.1',
+ '1.foo.': '127.0.0.1',
+ '2.foo.': '127.0.0.1',
+ '3.foo.': '127.0.0.1',
+ }
+ # run forward target instance
+ workdir = os.path.join(str(tmpdir), 'kresd_fwd_target')
+ os.makedirs(workdir)
+
+ with make_kresd(workdir, hints=hints, port=53910) as kresd_fwd_target:
+ sock = kresd_fwd_target.ip_tls_socket()
+ resolve_hint(sock, '0.foo.')
+
+ # run proxy
+ cwd, cmd = os.path.split(REHANDSHAKE_PROXY)
+ cmd = './' + cmd
+ ca_file = os.path.join(CERTS_DIR, 'tt.cert.pem')
+ try:
+ proxy = subprocess.Popen(
+ [cmd], cwd=cwd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+
+ # run test kresd instance
+ workdir2 = os.path.join(str(tmpdir), 'kresd')
+ os.makedirs(workdir2)
+ forward = Forward(proto='tls', ip='127.0.0.1', port=53921,
+ hostname='transport-test-server.com', ca_file=ca_file)
+ with make_kresd(workdir2, forward=forward) as kresd:
+ sock2 = kresd.ip_tcp_socket()
+ try:
+ for hint in hints:
+ resolve_hint(sock2, hint)
+ time.sleep(0.1)
+ finally:
+ # verify log
+ n_connecting_to = 0
+ n_rehandshake = 0
+ partial_log = kresd.partial_log()
+ print(partial_log)
+ for line in partial_log.splitlines():
+ if re.search(r"connecting to: .*", line) is not None:
+ n_connecting_to += 1
+ elif re.search(r"TLS rehandshake .* has started", line) is not None:
+ n_rehandshake += 1
+ assert n_connecting_to == 0 # shouldn't be present in partial log
+ assert n_rehandshake > 0
+ finally:
+ proxy.terminate()
diff --git a/tests/pytests/test_tls.py b/tests/pytests/test_tls.py
new file mode 100644
index 0000000..361741d
--- /dev/null
+++ b/tests/pytests/test_tls.py
@@ -0,0 +1,77 @@
+"""TLS-specific tests"""
+
+import itertools
+import os
+from socket import AF_INET, AF_INET6
+import ssl
+import sys
+
+import pytest
+
+from kresd import make_kresd
+import utils
+
+
+def test_tls_no_cert(kresd, sock_family):
+ """Use TLS without certificates."""
+ sock, dest = kresd.stream_socket(sock_family, tls=True)
+ ctx = utils.make_ssl_context(insecure=True)
+ ssock = ctx.wrap_socket(sock)
+ ssock.connect(dest)
+
+ utils.ping_alive(ssock)
+
+
+def test_tls_selfsigned_cert(kresd_tt, sock_family):
+ """Use TLS with a self signed certificate."""
+ sock, dest = kresd_tt.stream_socket(sock_family, tls=True)
+ ctx = utils.make_ssl_context(verify_location=kresd_tt.tls_cert_path)
+ ssock = ctx.wrap_socket(sock, server_hostname='transport-test-server.com')
+ ssock.connect(dest)
+
+ utils.ping_alive(ssock)
+
+
+def test_tls_cert_hostname_mismatch(kresd_tt, sock_family):
+ """Attempt to use self signed certificate and incorrect hostname."""
+ sock, dest = kresd_tt.stream_socket(sock_family, tls=True)
+ ctx = utils.make_ssl_context(verify_location=kresd_tt.tls_cert_path)
+ ssock = ctx.wrap_socket(sock, server_hostname='wrong-host-name')
+
+ with pytest.raises(ssl.CertificateError):
+ ssock.connect(dest)
+
+
+@pytest.mark.skipif(sys.version_info < (3, 6),
+ reason="requires python3.6 or higher")
+@pytest.mark.parametrize('sf1, sf2, sf3', itertools.product(
+ [AF_INET, AF_INET6], [AF_INET, AF_INET6], [AF_INET, AF_INET6]))
+def test_tls_session_resumption(tmpdir, sf1, sf2, sf3):
+ """Attempt TLS session resumption against the same kresd instance and a different one."""
+ # TODO ensure that session can't be resumed after session ticket key regeneration
+ # at the first kresd instance
+
+ def connect(kresd, ctx, sf, session=None):
+ sock, dest = kresd.stream_socket(sf, tls=True)
+ ssock = ctx.wrap_socket(
+ sock, server_hostname='transport-test-server.com', session=session)
+ ssock.connect(dest)
+ new_session = ssock.session
+ assert new_session.has_ticket
+ assert ssock.session_reused == (session is not None)
+ utils.ping_alive(ssock)
+ ssock.close()
+ return new_session
+
+ workdir = os.path.join(str(tmpdir), 'kresd')
+ os.makedirs(workdir)
+
+ with make_kresd(workdir, 'tt') as kresd:
+ ctx = utils.make_ssl_context(verify_location=kresd.tls_cert_path)
+ session = connect(kresd, ctx, sf1) # initial conn
+ connect(kresd, ctx, sf2, session) # resume session on the same instance
+
+ workdir2 = os.path.join(str(tmpdir), 'kresd2')
+ os.makedirs(workdir2)
+ with make_kresd(workdir2, 'tt') as kresd2:
+ connect(kresd2, ctx, sf3, session) # resume session on a different instance
diff --git a/tests/pytests/utils.py b/tests/pytests/utils.py
new file mode 100644
index 0000000..dcdc14c
--- /dev/null
+++ b/tests/pytests/utils.py
@@ -0,0 +1,131 @@
+from contextlib import contextmanager
+import random
+import ssl
+import struct
+import time
+
+import dns
+import dns.message
+import pytest
+
+
+# default net.tcp_in_idle is 10s, TCP_DEFER_ACCEPT 3s, some extra for
+# Python handling / edge cases
+MAX_TIMEOUT = 16
+
+
+def receive_answer(sock):
+ answer_total_len = 0
+ data = sock.recv(2)
+ if not data:
+ return None
+ answer_total_len = struct.unpack_from("!H", data)[0]
+
+ answer_received_len = 0
+ data_answer = b''
+ while answer_received_len < answer_total_len:
+ data_chunk = sock.recv(answer_total_len - answer_received_len)
+ if not data_chunk:
+ return None
+ data_answer += data_chunk
+ answer_received_len += len(data_answer)
+
+ return data_answer
+
+
+def receive_parse_answer(sock):
+ data_answer = receive_answer(sock)
+
+ if data_answer is None:
+ raise BrokenPipeError("kresd closed connection")
+
+ msg_answer = dns.message.from_wire(data_answer, one_rr_per_rrset=True)
+ return msg_answer
+
+
+def prepare_wire(
+ qname='localhost.',
+ qtype=dns.rdatatype.A,
+ qclass=dns.rdataclass.IN,
+ msgid=None):
+ """Utility function to generate DNS wire format message"""
+ msg = dns.message.make_query(qname, qtype, qclass)
+ if msgid is not None:
+ msg.id = msgid
+ return msg.to_wire(), msg.id
+
+
+def prepare_buffer(wire, datalen=None):
+ """Utility function to prepare TCP buffer from DNS message in wire format"""
+ assert isinstance(wire, bytes)
+ if datalen is None:
+ datalen = len(wire)
+ return struct.pack("!H", datalen) + wire
+
+
+def get_msgbuff(qname='localhost.', qtype=dns.rdatatype.A, msgid=None):
+ wire, msgid = prepare_wire(qname, qtype, msgid=msgid)
+ buff = prepare_buffer(wire)
+ return buff, msgid
+
+
+def get_garbage(length):
+ return bytes(random.getrandbits(8) for _ in range(length))
+
+
+def get_prefixed_garbage(length):
+ data = get_garbage(length)
+ return prepare_buffer(data)
+
+
+def try_ping_alive(sock, msgid=None, close=False):
+ try:
+ ping_alive(sock, msgid)
+ except AssertionError:
+ return False
+ finally:
+ if close:
+ sock.close()
+ return True
+
+
+def ping_alive(sock, msgid=None):
+ buff, msgid = get_msgbuff(msgid=msgid)
+ sock.sendall(buff)
+ answer = receive_parse_answer(sock)
+ assert answer.id == msgid
+
+
+@contextmanager
+def expect_kresd_close(rst_ok=False):
+ with pytest.raises(BrokenPipeError, message="kresd didn't close the connection"):
+ try:
+ time.sleep(0.2) # give kresd time to close connection with TCP FIN
+ yield
+ except ConnectionResetError:
+ if rst_ok:
+ raise BrokenPipeError
+ else:
+ pytest.skip("kresd closed connection with TCP RST")
+
+
+def make_ssl_context(insecure=False, verify_location=None):
+ # set TLS v1.2+
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS)
+ context.options |= ssl.OP_NO_SSLv2
+ context.options |= ssl.OP_NO_SSLv3
+ context.options |= ssl.OP_NO_TLSv1
+ context.options |= ssl.OP_NO_TLSv1_1
+
+ if insecure:
+ # turn off certificate verification
+ context.check_hostname = False
+ context.verify_mode = ssl.CERT_NONE
+ else:
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.check_hostname = True
+
+ if verify_location is not None:
+ context.load_verify_locations(verify_location)
+
+ return context