summaryrefslogtreecommitdiffstats
path: root/tests/pytests
diff options
context:
space:
mode:
Diffstat (limited to 'tests/pytests')
-rw-r--r--tests/pytests/README.rst56
-rwxr-xr-xtests/pytests/certs/tt-certgen-expired.sh19
-rwxr-xr-xtests/pytests/certs/tt-certgen.sh5
-rw-r--r--tests/pytests/certs/tt-expired.cert.pem80
-rw-r--r--tests/pytests/certs/tt-expired.key.pem27
-rw-r--r--tests/pytests/certs/tt.cert.pem22
-rw-r--r--tests/pytests/certs/tt.conf353
-rw-r--r--tests/pytests/certs/tt.key.pem28
-rw-r--r--tests/pytests/conftest.py102
-rw-r--r--tests/pytests/conn_flood.py85
-rw-r--r--tests/pytests/kresd.py306
-rw-r--r--tests/pytests/meson.build77
-rw-r--r--tests/pytests/proxy.py161
-rw-r--r--tests/pytests/proxy/tls-proxy.c1038
-rw-r--r--tests/pytests/proxy/tls-proxy.h34
-rw-r--r--tests/pytests/proxy/tlsproxy.c198
-rw-r--r--tests/pytests/pylintrc33
-rw-r--r--tests/pytests/requirements.txt5
-rw-r--r--tests/pytests/templates/kresd.conf.j262
-rw-r--r--tests/pytests/test_conn_mgmt.py214
-rw-r--r--tests/pytests/test_edns.py22
-rw-r--r--tests/pytests/test_prefix.py114
-rw-r--r--tests/pytests/test_random_close.py54
-rw-r--r--tests/pytests/test_rehandshake.py52
-rw-r--r--tests/pytests/test_tls.py83
-rw-r--r--tests/pytests/utils.py136
26 files changed, 3366 insertions, 0 deletions
diff --git a/tests/pytests/README.rst b/tests/pytests/README.rst
new file mode 100644
index 0000000..173dc40
--- /dev/null
+++ b/tests/pytests/README.rst
@@ -0,0 +1,56 @@
+.. SPDX-License-Identifier: GPL-3.0-or-later
+
+Python client tests for kresd
+=============================
+
+The tests run `/usr/bin/env kresd` (can be modified with `$PATH`) with custom config
+and execute client-side testing, such as TCP / TLS connection management.
+
+Requirements
+------------
+
+- pip3 install -r requirements.txt
+
+Executing tests
+---------------
+
+Tests can be executed with the pytest framework.
+
+.. code-block:: bash
+
+ $ pytest-3 # sequential, all tests (with exception of few special tests)
+ $ pytest-3 test_conn_mgmt.py::test_ignore_garbage # specific test only
+ $ pytest-3 --html pytests.html --self-contained-html # html report
+
+It's highly recommended to run these tests in parallel, since lot of them
+wait for kresd timeout. This can be done with `python-xdist`:
+
+.. code-block:: bash
+
+ $ pytest-3 -n 24 # parallel with 24 jobs
+
+Each test spawns an independent kresd instance, so test failures shouldn't affect
+each other.
+
+Some tests are omitted from automatic test collection by default, due to their
+resource constraints. These typically have to be executed separately by providing
+the path to test file directly.
+
+.. code-block:: bash
+
+ $ pytest-3 conn_flood.py
+
+Note: some tests may fail without an internet connection.
+
+Developer notes
+---------------
+
+Typically, each test requires a setup of kresd, and a connected socket to run tests on.
+The framework provides a few useful pytest fixtures to simplify this process:
+
+- `kresd_sock` provides a connected socket to a test-specific, running kresd instance.
+ It expands to 4 values (tests) - IPv4 TCP, IPv6 TCP, IPv4 TLS, IPv6 TLS sockets
+- `make_kresd_sock` is similar to `kresd_sock`, except it's a factory function that
+ produces a new connected socket (of the same type) on each call
+- `kresd`, `kresd_tt` are all Kresd instances, already running
+ and initialized with config (with no / valid TLS certificates)
diff --git a/tests/pytests/certs/tt-certgen-expired.sh b/tests/pytests/certs/tt-certgen-expired.sh
new file mode 100755
index 0000000..23a6978
--- /dev/null
+++ b/tests/pytests/certs/tt-certgen-expired.sh
@@ -0,0 +1,19 @@
+# !/bin/bash
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+if [ ! -d ./demoCA ]; then
+ mkdir ./demoCA
+fi
+if [ ! -d ./demoCA/newcerts ]; then
+ mkdir ./demoCA/newcerts
+fi
+touch ./demoCA/index.txt
+touch ./demoCA/index.txt.attr
+if [ ! -f ./demoCA/serial ]; then
+ echo 01 > ./demoCA/serial
+fi
+
+openssl genrsa -out tt-expired.key.pem 2048
+openssl req -config tt.conf -new -key tt-expired.key.pem -out tt-expired.csr.pem
+openssl ca -config tt.conf -selfsign -keyfile tt-expired.key.pem -out tt-expired.cert.pem -in tt-expired.csr.pem -startdate 19700101000000Z -enddate 19700101000000Z
+
diff --git a/tests/pytests/certs/tt-certgen.sh b/tests/pytests/certs/tt-certgen.sh
new file mode 100755
index 0000000..9414475
--- /dev/null
+++ b/tests/pytests/certs/tt-certgen.sh
@@ -0,0 +1,5 @@
+# !/bin/sh
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+openssl req -config tt.conf -new -x509 -newkey rsa:2048 -nodes -keyout tt.key.pem -sha256 -out tt.cert.pem -days 20000
+
diff --git a/tests/pytests/certs/tt-expired.cert.pem b/tests/pytests/certs/tt-expired.cert.pem
new file mode 100644
index 0000000..c9f8c09
--- /dev/null
+++ b/tests/pytests/certs/tt-expired.cert.pem
@@ -0,0 +1,80 @@
+Certificate:
+ Data:
+ Version: 3 (0x2)
+ Serial Number: 1 (0x1)
+ Signature Algorithm: sha256WithRSAEncryption
+ Issuer: C=CZ, ST=PRAGUE, CN=transport-test-server.com
+ Validity
+ Not Before: Jan 1 00:00:00 1970 GMT
+ Not After : Jan 1 00:00:00 1970 GMT
+ Subject: C=CZ, ST=PRAGUE, CN=transport-test-server.com
+ Subject Public Key Info:
+ Public Key Algorithm: rsaEncryption
+ Public-Key: (2048 bit)
+ Modulus:
+ 00:bf:6b:1a:11:47:01:ac:eb:5c:2d:cf:ce:6a:a4:
+ 00:ce:2f:d1:25:03:5f:06:38:02:92:24:18:92:2a:
+ 69:19:b2:2b:a3:4f:f7:79:de:35:c3:f5:72:37:83:
+ 44:93:f9:76:fc:89:29:32:9c:0d:4b:95:7d:d1:5d:
+ 40:e9:ba:49:50:7d:c6:0a:c8:1e:e7:90:1e:37:7c:
+ 0b:23:a3:e3:bc:c9:53:81:de:d6:5f:cb:b2:3d:36:
+ ac:59:b0:33:91:8f:0c:5f:10:20:70:bf:a3:22:b3:
+ 98:ac:d4:7a:ea:67:b8:b1:8c:cf:e5:fe:8f:a0:a5:
+ 02:ad:6d:ce:f1:62:ab:dc:5d:96:9c:4f:95:47:d5:
+ 82:b7:b3:e3:87:4c:8d:38:85:2a:24:9d:7f:c7:a4:
+ 0e:bd:8a:2d:6b:d2:d4:e8:78:62:1b:aa:25:5f:5a:
+ 64:e5:76:23:ae:11:03:9a:5c:ed:a2:ba:51:ec:b1:
+ f3:ae:ba:5c:eb:dd:49:63:ca:c7:af:0c:16:1d:94:
+ 95:3a:ce:2c:8f:e2:94:7f:1f:a1:76:e2:9f:d1:41:
+ 31:f0:68:e5:ae:df:d0:75:a0:34:f5:25:93:85:b3:
+ 25:50:42:6c:00:c0:fe:3b:e0:fb:00:de:75:33:86:
+ 6a:21:35:14:9d:7f:4a:af:f7:15:f2:d7:bb:2f:de:
+ df:ab
+ Exponent: 65537 (0x10001)
+ X509v3 extensions:
+ X509v3 Basic Constraints:
+ CA:FALSE
+ Netscape Comment:
+ OpenSSL Generated Certificate
+ X509v3 Subject Key Identifier:
+ B3:42:0A:9A:00:19:CB:CB:24:A0:02:45:1E:8A:B0:54:CB:9F:55:FE
+ X509v3 Authority Key Identifier:
+ keyid:B3:42:0A:9A:00:19:CB:CB:24:A0:02:45:1E:8A:B0:54:CB:9F:55:FE
+
+ Signature Algorithm: sha256WithRSAEncryption
+ 32:9a:05:e3:6f:ae:ee:b1:a2:12:0a:9f:0a:e7:78:26:df:90:
+ fb:84:60:ae:13:fc:ff:fd:42:84:23:14:c3:2e:e2:a9:df:4b:
+ 5c:2f:5b:0e:3d:f9:5a:56:50:13:bc:89:1a:08:70:dd:6c:6c:
+ e8:ae:cf:22:39:92:f2:3b:40:03:8f:4e:bc:54:88:6b:fd:8c:
+ b6:eb:30:90:21:db:fc:4e:5c:7e:12:75:e2:52:76:df:19:0f:
+ 30:49:1e:15:bc:ba:6a:e6:f7:af:93:ad:e4:36:da:47:47:a6:
+ 88:b0:ae:46:1e:91:e1:d6:b1:5e:a4:f0:68:02:81:57:86:5d:
+ 17:d1:6c:7e:7a:9f:5e:0d:fc:10:e7:7a:1a:b5:f9:4b:1d:78:
+ a4:9a:9d:d7:c2:64:c3:52:28:7f:a1:b7:25:d7:13:3f:09:7f:
+ f2:fd:dd:c6:91:eb:9b:51:80:e2:36:cb:9f:5b:4e:47:eb:77:
+ d3:cc:8b:18:b5:0b:97:a2:53:8e:fb:9b:94:7d:57:21:32:c6:
+ f3:67:93:a4:9b:eb:46:b7:cd:08:43:99:dd:c1:c3:51:b9:19:
+ ef:92:77:1c:84:67:80:67:95:ba:00:75:3d:7b:8b:ff:24:30:
+ f1:fa:6d:da:31:9d:cf:06:da:5d:04:07:14:45:8c:6b:e7:21:
+ 31:ec:7b:23
+-----BEGIN CERTIFICATE-----
+MIIDfjCCAmagAwIBAgIBATANBgkqhkiG9w0BAQsFADBCMQswCQYDVQQGEwJDWjEP
+MA0GA1UECAwGUFJBR1VFMSIwIAYDVQQDDBl0cmFuc3BvcnQtdGVzdC1zZXJ2ZXIu
+Y29tMCIYDzE5NzAwMTAxMDAwMDAwWhgPMTk3MDAxMDEwMDAwMDBaMEIxCzAJBgNV
+BAYTAkNaMQ8wDQYDVQQIDAZQUkFHVUUxIjAgBgNVBAMMGXRyYW5zcG9ydC10ZXN0
+LXNlcnZlci5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC/axoR
+RwGs61wtz85qpADOL9ElA18GOAKSJBiSKmkZsiujT/d53jXD9XI3g0ST+Xb8iSky
+nA1LlX3RXUDpuklQfcYKyB7nkB43fAsjo+O8yVOB3tZfy7I9NqxZsDORjwxfECBw
+v6Mis5is1HrqZ7ixjM/l/o+gpQKtbc7xYqvcXZacT5VH1YK3s+OHTI04hSoknX/H
+pA69ii1r0tToeGIbqiVfWmTldiOuEQOaXO2iulHssfOuulzr3UljysevDBYdlJU6
+ziyP4pR/H6F24p/RQTHwaOWu39B1oDT1JZOFsyVQQmwAwP474PsA3nUzhmohNRSd
+f0qv9xXy17sv3t+rAgMBAAGjezB5MAkGA1UdEwQCMAAwLAYJYIZIAYb4QgENBB8W
+HU9wZW5TU0wgR2VuZXJhdGVkIENlcnRpZmljYXRlMB0GA1UdDgQWBBSzQgqaABnL
+yySgAkUeirBUy59V/jAfBgNVHSMEGDAWgBSzQgqaABnLyySgAkUeirBUy59V/jAN
+BgkqhkiG9w0BAQsFAAOCAQEAMpoF42+u7rGiEgqfCud4Jt+Q+4RgrhP8//1ChCMU
+wy7iqd9LXC9bDj35WlZQE7yJGghw3Wxs6K7PIjmS8jtAA49OvFSIa/2MtuswkCHb
+/E5cfhJ14lJ23xkPMEkeFby6aub3r5Ot5DbaR0emiLCuRh6R4daxXqTwaAKBV4Zd
+F9FsfnqfXg38EOd6GrX5Sx14pJqd18Jkw1Iof6G3JdcTPwl/8v3dxpHrm1GA4jbL
+n1tOR+t308yLGLULl6JTjvublH1XITLG82eTpJvrRrfNCEOZ3cHDUbkZ75J3HIRn
+gGeVugB1PXuL/yQw8fpt2jGdzwbaXQQHFEWMa+chMex7Iw==
+-----END CERTIFICATE-----
diff --git a/tests/pytests/certs/tt-expired.key.pem b/tests/pytests/certs/tt-expired.key.pem
new file mode 100644
index 0000000..ca2988c
--- /dev/null
+++ b/tests/pytests/certs/tt-expired.key.pem
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEogIBAAKCAQEAv2saEUcBrOtcLc/OaqQAzi/RJQNfBjgCkiQYkippGbIro0/3
+ed41w/VyN4NEk/l2/IkpMpwNS5V90V1A6bpJUH3GCsge55AeN3wLI6PjvMlTgd7W
+X8uyPTasWbAzkY8MXxAgcL+jIrOYrNR66me4sYzP5f6PoKUCrW3O8WKr3F2WnE+V
+R9WCt7Pjh0yNOIUqJJ1/x6QOvYota9LU6HhiG6olX1pk5XYjrhEDmlztorpR7LHz
+rrpc691JY8rHrwwWHZSVOs4sj+KUfx+hduKf0UEx8Gjlrt/QdaA09SWThbMlUEJs
+AMD+O+D7AN51M4ZqITUUnX9Kr/cV8te7L97fqwIDAQABAoIBAEA4ytIpJKLDhHXK
+VtLom2ySFnV4oBUSDarCeYvwtrpsUL/GQJ2etCM+4kdFv2h2NjmcOzpDqSJG0aPA
+ydqhKZ/b0uojIltGuxyafZJDllDsqxvTi9EwImjvQvwEZgjcGaZ7Xqb1ZOJrpzm1
+QFgM3KaVO9tKgR3Avxk40kmidU7FctFi5IELwnH/RR1OHvJbxOE4+i0LlDx0QzhX
+QHtnvHLqLLdqsFk8KvuVuVj1FwqJ6cSL0JrAdt7dnGmXBo4PDqT8Hj0AjM+CcNrV
+1D6Li9xr4y55EZUK2qU/FVDC3LqlYQy5mBfasJAXPQG4RgSVFxJ929HC7gi8vMCO
+UMeLniECgYEA6gBoRwzQ5pJUXfZGW41lJt08utfycGZm7VrA81r0x0F+DcuZ2t6J
+kB9Wnp/MNpB4DJLbl7oM2OlFOO3cw0n3VaFpNMPHVHzNbyi1hp94AIIeDz/sxfUI
+Lx7ynAQSPPQzDRfVJesT8waBdweA71TBOlrFQ2Cp7O4Qf+p0akQSv3cCgYEA0Wnd
+1Gbierv2m6Jnblg+brTMQwbRsOAM2n0V4Gd2kRaLSYd23ebshvx8xTWipRlrb5vP
+UEh+LkfuscqaJDCrikasht9z5FJtfIzHKgTrLSoR3MJRjrnuLJWTQUwSqzd0UNN6
+HigV6p+CqesNnELErak53IMfmkHAhTSkII8R9m0CgYBRY+DhTaDfgegcYouoTm7v
+bKYx6uillciZKCbSvkFDiREaJUYXba31ViEfvT8ff3JyFSaSCKFtVP3BxmIx/ukr
+fKAGPU54oYwm7Mbu00q/CoMAFOD7HbZCBYanI3dggiO7mx2FOdXPguTHDPIYzKcE
+8AuK2vVftpJAm8DwMUtAEwKBgH/eRc5ZGDdbKGS10LQm+9A7Y3IV6to2pIKQ2FfS
+tSo4espmBeXPCGQQLdt5OZvYHqril77s1OdLkutKy74HXecr6lLchHZJAoOHrmDw
+6e0FAC0tFgGxdEYS+vxnCAs17DciOjHJxkAiL/WzCfd9KXzklOkZw6U8OuLbVtBu
+q8gtAoGAbl03XZm+SHrO7XjHK/Fe5YD13cOirg48htvjbpqEDZNQr3l0eVnEj074
+IopDa/wUFlaaqPZ/DVFctqSocyskWIP4u9HfmsNBHjK5zQlge7B1fNVao++YKund
+qnVnXjWQuF2aL8k2geFxdSKmHTF4/N1qEyeyR+tMaFpGfMZOuM8=
+-----END RSA PRIVATE KEY-----
diff --git a/tests/pytests/certs/tt.cert.pem b/tests/pytests/certs/tt.cert.pem
new file mode 100644
index 0000000..2ea4898
--- /dev/null
+++ b/tests/pytests/certs/tt.cert.pem
@@ -0,0 +1,22 @@
+-----BEGIN CERTIFICATE-----
+MIIDkTCCAnmgAwIBAgIJAP/KybHquuyUMA0GCSqGSIb3DQEBCwUAMEIxCzAJBgNV
+BAYTAkNaMQ8wDQYDVQQIDAZQUkFHVUUxIjAgBgNVBAMMGXRyYW5zcG9ydC10ZXN0
+LXNlcnZlci5jb20wIBcNMTgwNzE4MTAwODIzWhgPMjA3MzA0MjAxMDA4MjNaMEIx
+CzAJBgNVBAYTAkNaMQ8wDQYDVQQIDAZQUkFHVUUxIjAgBgNVBAMMGXRyYW5zcG9y
+dC10ZXN0LXNlcnZlci5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB
+AQDRAQDAX6+lFKurvm7fgQqm8WyYzT/wxfPJjsVQGe87OlH1KFzVfzYzgEt0RMlM
+eZgipREBZB2zK+WFM5RBHWYAwlI5PKt7EAGn8q1Zm4z+M9Uom3/Hy3bZ9q+AJwjk
+odpHYuFyWJqHIQBqaQ3SFyJwdZ/GsuzEUfWuIl74oyyMAeykTKFGdaVuIlLC3fKm
+8UCnfk99i/LEXUwRcmOV0uaG7deN5ITDDCFdb615yVjLkMhGY/jHK7uuxATOopEk
+4vThQ1aQjSkHwluaqFUW6Zl4QF8WOAufoWQPFZ8XxmUYEIG/sMvLv6dol7ltjEbC
+bfyzlS+9Qbnq6MfhTZF/4jAPAgMBAAGjgYcwgYQwHQYDVR0OBBYEFBNiUgCiKw4b
+CFNKaEkqhkNSer7wMB8GA1UdIwQYMBaAFBNiUgCiKw4bCFNKaEkqhkNSer7wMA8G
+A1UdEwEB/wQFMAMBAf8wJAYDVR0RBB0wG4IZdHJhbnNwb3J0LXRlc3Qtc2VydmVy
+LmNvbTALBgNVHQ8EBAMCAaYwDQYJKoZIhvcNAQELBQADggEBACz1ZQ8XkGobhTcA
+hkSTSw0ko6qwVuJJD5ue3SUcWLATsskohTJmN6bde3IMDRyQvLJAlMdG2p1qMbtA
+OTbnQJTT7oDLaW8w2D+eO5oWTJvxLpl6TxbIfJN/8ITB1ltOCxTU9cVNbd2eh8sj
+l3R4etg9djYRrqtNxCQZOYSwvhHw2MefnwjGVuJEu6JYOn3IE8Jqsh/LI59C87nE
+MetZrXlzC6kSAFfRYgQET9RhBobMU9yFR8zGVHDFoxqNQs2lYKPz/3rFPetL2rjT
+cFwzxkxDdwn+RNisBc1LMfIg7pvSMFR6sAnpjeRHN0Uoem1jj2qtzjbFENDuyQ4/
+HSi4UcE=
+-----END CERTIFICATE-----
diff --git a/tests/pytests/certs/tt.conf b/tests/pytests/certs/tt.conf
new file mode 100644
index 0000000..f011e5a
--- /dev/null
+++ b/tests/pytests/certs/tt.conf
@@ -0,0 +1,353 @@
+# SPDX-License-Identifier: CC0-1.0
+# OpenSSL example configuration file.
+# This is mostly being used for generation of certificate requests.
+#
+
+# This definition stops the following lines choking if HOME isn't
+# defined.
+HOME = .
+RANDFILE = $ENV::HOME/.rnd
+
+# Extra OBJECT IDENTIFIER info:
+#oid_file = $ENV::HOME/.oid
+oid_section = new_oids
+
+# To use this configuration file with the "-extfile" option of the
+# "openssl x509" utility, name here the section containing the
+# X.509v3 extensions to use:
+# extensions =
+# (Alternatively, use a configuration file that has only
+# X.509v3 extensions in its main [= default] section.)
+
+[ new_oids ]
+
+# We can add new OIDs in here for use by 'ca', 'req' and 'ts'.
+# Add a simple OID like this:
+# testoid1=1.2.3.4
+# Or use config file substitution like this:
+# testoid2=${testoid1}.5.6
+
+# Policies used by the TSA examples.
+tsa_policy1 = 1.2.3.4.1
+tsa_policy2 = 1.2.3.4.5.6
+tsa_policy3 = 1.2.3.4.5.7
+
+####################################################################
+[ ca ]
+default_ca = CA_default # The default ca section
+
+####################################################################
+[ CA_default ]
+
+dir = ./demoCA # Where everything is kept
+certs = $dir/certs # Where the issued certs are kept
+crl_dir = $dir/crl # Where the issued crl are kept
+database = $dir/index.txt # database index file.
+#unique_subject = no # Set to 'no' to allow creation of
+ # several certs with same subject.
+new_certs_dir = $dir/newcerts # default place for new certs.
+
+certificate = $dir/cacert.pem # The CA certificate
+serial = $dir/serial # The current serial number
+crlnumber = $dir/crlnumber # the current crl number
+ # must be commented out to leave a V1 CRL
+crl = $dir/crl.pem # The current CRL
+private_key = $dir/private/cakey.pem# The private key
+RANDFILE = $dir/private/.rand # private random number file
+
+x509_extensions = usr_cert # The extensions to add to the cert
+
+# Comment out the following two lines for the "traditional"
+# (and highly broken) format.
+name_opt = ca_default # Subject Name options
+cert_opt = ca_default # Certificate field options
+
+# Extension copying option: use with caution.
+copy_extensions = copy
+
+# Extensions to add to a CRL. Note: Netscape communicator chokes on V2 CRLs
+# so this is commented out by default to leave a V1 CRL.
+# crlnumber must also be commented out to leave a V1 CRL.
+# crl_extensions = crl_ext
+
+default_days = 365 # how long to certify for
+default_crl_days= 30 # how long before next CRL
+default_md = default # use public key default MD
+preserve = no # keep passed DN ordering
+
+# A few difference way of specifying how similar the request should look
+# For type CA, the listed attributes must be the same, and the optional
+# and supplied fields are just that :-)
+policy = policy_match
+
+# For the CA policy
+[ policy_match ]
+countryName = optional
+stateOrProvinceName = optional
+organizationName = optional
+organizationalUnitName = optional
+commonName = supplied
+emailAddress = optional
+
+# For the 'anything' policy
+# At this point in time, you must list all acceptable 'object'
+# types.
+[ policy_anything ]
+countryName = optional
+stateOrProvinceName = optional
+localityName = optional
+organizationName = optional
+organizationalUnitName = optional
+commonName = supplied
+emailAddress = optional
+
+####################################################################
+[ req ]
+default_bits = 2048
+default_keyfile = privkey.pem
+distinguished_name = req_distinguished_name
+attributes = req_attributes
+x509_extensions = v3_ca # The extensions to add to the self signed cert
+
+# Passwords for private keys if not present they will be prompted for
+# input_password = secret
+# output_password = secret
+
+# This sets a mask for permitted string types. There are several options.
+# default: PrintableString, T61String, BMPString.
+# pkix : PrintableString, BMPString (PKIX recommendation before 2004)
+# utf8only: only UTF8Strings (PKIX recommendation after 2004).
+# nombstr : PrintableString, T61String (no BMPStrings or UTF8Strings).
+# MASK:XXXX a literal mask value.
+# WARNING: ancient versions of Netscape crash on BMPStrings or UTF8Strings.
+string_mask = utf8only
+
+# req_extensions = v3_req # The extensions to add to a certificate request
+
+[ req_distinguished_name ]
+countryName = Country Name (2 letter code)
+countryName_default = CZ
+countryName_min = 2
+countryName_max = 2
+
+stateOrProvinceName = State or Province Name (full name)
+stateOrProvinceName_default = PRAGUE
+
+localityName = Locality Name (eg, city)
+
+0.organizationName = Organization Name (eg, company)
+0.organizationName_default =
+
+# we can do this but it is not needed normally :-)
+#1.organizationName = Second Organization Name (eg, company)
+#1.organizationName_default = World Wide Web Pty Ltd
+
+organizationalUnitName = Organizational Unit Name (eg, section)
+#organizationalUnitName_default =
+
+commonName = Common Name (e.g. server FQDN or YOUR name)
+commonName_max = 64
+commonName_default = transport-test-server.com
+
+emailAddress = Email Address
+emailAddress_max = 64
+
+# SET-ex3 = SET extension number 3
+
+[ req_attributes ]
+challengePassword = A challenge password
+challengePassword_min = 4
+challengePassword_max = 20
+
+unstructuredName = An optional company name
+
+[ usr_cert ]
+
+# These extensions are added when 'ca' signs a request.
+
+# This goes against PKIX guidelines but some CAs do it and some software
+# requires this to avoid interpreting an end user certificate as a CA.
+
+basicConstraints=CA:FALSE
+
+# Here are some examples of the usage of nsCertType. If it is omitted
+# the certificate can be used for anything *except* object signing.
+
+# This is OK for an SSL server.
+# nsCertType = server
+
+# For an object signing certificate this would be used.
+# nsCertType = objsign
+
+# For normal client use this is typical
+# nsCertType = client, email
+
+# and for everything including object signing:
+# nsCertType = client, email, objsign
+
+# This is typical in keyUsage for a client certificate.
+# keyUsage = nonRepudiation, digitalSignature, keyEncipherment
+
+# This will be displayed in Netscape's comment listbox.
+nsComment = "OpenSSL Generated Certificate"
+
+# PKIX recommendations harmless if included in all certificates.
+subjectKeyIdentifier=hash
+authorityKeyIdentifier=keyid,issuer
+
+# This stuff is for subjectAltName and issuerAltname.
+# Import the email address.
+# subjectAltName=email:copy
+# An alternative to produce certificates that aren't
+# deprecated according to PKIX.
+# subjectAltName=email:move
+
+# Copy subject details
+# issuerAltName=issuer:copy
+
+#nsCaRevocationUrl = http://www.domain.dom/ca-crl.pem
+#nsBaseUrl
+#nsRevocationUrl
+#nsRenewalUrl
+#nsCaPolicyUrl
+#nsSslServerName
+
+# This is required for TSA certificates.
+# extendedKeyUsage = critical,timeStamping
+
+[ v3_req ]
+
+# Extensions to add to a certificate request
+
+basicConstraints = CA:FALSE
+keyUsage = nonRepudiation, digitalSignature, keyEncipherment
+
+[ v3_ca ]
+
+
+# Extensions for a typical CA
+
+
+# PKIX recommendation.
+
+subjectKeyIdentifier=hash
+
+authorityKeyIdentifier=keyid:always,issuer
+
+basicConstraints = critical,CA:true
+
+subjectAltName = @alternate_names
+
+# Key usage: this is typical for a CA certificate. However since it will
+# prevent it being used as an test self-signed certificate it is best
+# left out by default.
+keyUsage = digitalSignature, keyEncipherment, cRLSign, keyCertSign
+
+# Some might want this also
+# nsCertType = sslCA, emailCA
+
+# Include email address in subject alt name: another PKIX recommendation
+# subjectAltName=email:copy
+# Copy issuer details
+# issuerAltName=issuer:copy
+
+# DER hex encoding of an extension: beware experts only!
+# obj=DER:02:03
+# Where 'obj' is a standard or added object
+# You can even override a supported extension:
+# basicConstraints= critical, DER:30:03:01:01:FF
+
+[ crl_ext ]
+
+# CRL extensions.
+# Only issuerAltName and authorityKeyIdentifier make any sense in a CRL.
+
+# issuerAltName=issuer:copy
+authorityKeyIdentifier=keyid:always
+
+[ proxy_cert_ext ]
+# These extensions should be added when creating a proxy certificate
+
+# This goes against PKIX guidelines but some CAs do it and some software
+# requires this to avoid interpreting an end user certificate as a CA.
+
+basicConstraints=CA:FALSE
+
+# Here are some examples of the usage of nsCertType. If it is omitted
+# the certificate can be used for anything *except* object signing.
+
+# This is OK for an SSL server.
+# nsCertType = server
+
+# For an object signing certificate this would be used.
+# nsCertType = objsign
+
+# For normal client use this is typical
+# nsCertType = client, email
+
+# and for everything including object signing:
+# nsCertType = client, email, objsign
+
+# This is typical in keyUsage for a client certificate.
+# keyUsage = nonRepudiation, digitalSignature, keyEncipherment
+
+# This will be displayed in Netscape's comment listbox.
+nsComment = "OpenSSL Generated Certificate"
+
+# PKIX recommendations harmless if included in all certificates.
+subjectKeyIdentifier=hash
+authorityKeyIdentifier=keyid,issuer
+
+# This stuff is for subjectAltName and issuerAltname.
+# Import the email address.
+# subjectAltName=email:copy
+# An alternative to produce certificates that aren't
+# deprecated according to PKIX.
+# subjectAltName=email:move
+
+# Copy subject details
+# issuerAltName=issuer:copy
+
+#nsCaRevocationUrl = http://www.domain.dom/ca-crl.pem
+#nsBaseUrl
+#nsRevocationUrl
+#nsRenewalUrl
+#nsCaPolicyUrl
+#nsSslServerName
+
+# This really needs to be in place for it to be a proxy certificate.
+proxyCertInfo=critical,language:id-ppl-anyLanguage,pathlen:3,policy:foo
+
+####################################################################
+[ tsa ]
+
+default_tsa = tsa_config1 # the default TSA section
+
+[ tsa_config1 ]
+
+# These are used by the TSA reply generation only.
+dir = ./demoCA # TSA root directory
+serial = $dir/tsaserial # The current serial number (mandatory)
+crypto_device = builtin # OpenSSL engine to use for signing
+signer_cert = $dir/tsacert.pem # The TSA signing certificate
+ # (optional)
+certs = $dir/cacert.pem # Certificate chain to include in reply
+ # (optional)
+signer_key = $dir/private/tsakey.pem # The TSA private key (optional)
+signer_digest = sha256 # Signing digest to use. (Optional)
+default_policy = tsa_policy1 # Policy if request did not specify it
+ # (optional)
+other_policies = tsa_policy2, tsa_policy3 # acceptable policies (optional)
+digests = sha1, sha256, sha384, sha512 # Acceptable message digests (mandatory)
+accuracy = secs:1, millisecs:500, microsecs:100 # (optional)
+clock_precision_digits = 0 # number of digits after dot. (optional)
+ordering = yes # Is ordering defined for timestamps?
+ # (optional, default: no)
+tsa_name = yes # Must the TSA name be included in the reply?
+ # (optional, default: no)
+ess_cert_id_chain = no # Must the ESS cert id chain be included?
+ # (optional, default: no)
+
+[ alternate_names ]
+
+DNS.1 = transport-test-server.com
diff --git a/tests/pytests/certs/tt.key.pem b/tests/pytests/certs/tt.key.pem
new file mode 100644
index 0000000..1974be7
--- /dev/null
+++ b/tests/pytests/certs/tt.key.pem
@@ -0,0 +1,28 @@
+-----BEGIN PRIVATE KEY-----
+MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDRAQDAX6+lFKur
+vm7fgQqm8WyYzT/wxfPJjsVQGe87OlH1KFzVfzYzgEt0RMlMeZgipREBZB2zK+WF
+M5RBHWYAwlI5PKt7EAGn8q1Zm4z+M9Uom3/Hy3bZ9q+AJwjkodpHYuFyWJqHIQBq
+aQ3SFyJwdZ/GsuzEUfWuIl74oyyMAeykTKFGdaVuIlLC3fKm8UCnfk99i/LEXUwR
+cmOV0uaG7deN5ITDDCFdb615yVjLkMhGY/jHK7uuxATOopEk4vThQ1aQjSkHwlua
+qFUW6Zl4QF8WOAufoWQPFZ8XxmUYEIG/sMvLv6dol7ltjEbCbfyzlS+9Qbnq6Mfh
+TZF/4jAPAgMBAAECggEBALSs10d18FMW0WjAUPxpgxnaLnTRSesMVLjy8ONT6Bkd
+S2hRIh91vxc6WwABzrqLita4N0EqmPoggmNpuUmo7lrNoWLVbbAOoD/da7nA3FuL
+10MpWYcP/ohh1klEdU2gFSAM/LNqoPsbrk5OzqHFWgI5zItqdX8pEucb01nBRWsp
+VMY2vzVuFB2jweZQ5+LCpfSMcRIzlxQa9CG4Peu6YW1Z4b3aUcS63/829JN/ZOGd
+uoRqR+gP71yNIt6i7wA5cot5FRmzlFEGhb1XzBOB1FFHOiknOZzbBtDsGUUmVtfA
+6mXcTumhdHbC0bXnHei/s2s9X0EeyQFYPkoS4NUQ2dECgYEA/7lhgn89K8rpUPnS
+eccTpKVPWp8luQei98Hi/F94kwP32l7Zl7Bmu2nltUoB1GBRXoXY6KzTphmT6ioA
+8joLCKIii5/nOdZAdHbIN2tkXS56h524q5I2jKogjfRrpCaAJE8x99f8L9uTBfZb
+/7BBQDHai1/S6LcpIRf/4g1/xBsCgYEA0Tq4V5hR9mGDUFir1FDGhA3ijDkIE/sO
+3QGTU7W90BL27te98FuQtWOPqfd1fi26WypNpNQUZb3V4x5tmDcpWscfj6I10432
+4zECPlDgaevucJjj245U7WjUhdAvlRy6K8H/8MgRBAjw9h8dwIGIx9gmOqKdA+/h
+ve3xyjKQex0CgYAz0XzQ1LewiA1/OyBLTOvOETFjS5x5QfLkAYXdXfswzz0KIu40
+rqoij/LcKYL1Zg8W+Ehb3amFnuk6KgjHDLvvo+scH+ra7W9iKi+oCzrrJt/tWyhw
+m9Ax8Mdn/H9TY/nTYbjeYAXaLMQ+EQ3TYgPW3kNKusAiJ/tNmW9gfxvEwQKBgGSJ
+Rbj5fTDZjGKYKQDdS3Z6wYhFg0culObHcgaARtPruPHtgtwy82blj0vJl5Bo4qoZ
+urNgIOj+ff8jSOAiaWGwWs8Gz7x289IZY42UCTF8Z9d878g5LT/i5nPiJGsPIboS
+/yuwxtRcg4SQURiGZbY5e60jJDWXF67O3icdguVVAoGARXLufXvZ/9Xf1DmFFxjq
+PJMCa1sfofqjB4KqYbt17vFtTsddCiyqsbpx36oY6nIdm9yUiGo10YaSEJtDEGLS
+L3TPZ4s8M8dcjOfj8Kk75pKbJ7NY4qA64dtbxcZbrFp3/mGZkDing94y+Zc/aFqa
+xQsA/yhmYV9r+FHDL54Cn6I=
+-----END PRIVATE KEY-----
diff --git a/tests/pytests/conftest.py b/tests/pytests/conftest.py
new file mode 100644
index 0000000..4c711f8
--- /dev/null
+++ b/tests/pytests/conftest.py
@@ -0,0 +1,102 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import socket
+
+import pytest
+
+from kresd import init_portdir, make_kresd
+
+
+@pytest.fixture
+def kresd(tmpdir):
+ with make_kresd(tmpdir) as kresd:
+ yield kresd
+
+
+@pytest.fixture
+def kresd_silent(tmpdir):
+ with make_kresd(tmpdir, verbose=False) as kresd:
+ yield kresd
+
+
+@pytest.fixture
+def kresd_tt(tmpdir):
+ with make_kresd(tmpdir, 'tt') as kresd:
+ yield kresd
+
+
+@pytest.fixture(params=[
+ 'ip_tcp_socket',
+ 'ip6_tcp_socket',
+ 'ip_tls_socket',
+ 'ip6_tls_socket',
+])
+def make_kresd_sock(request, kresd):
+ """Factory function to create sockets of the same kind."""
+ sock_func = getattr(kresd, request.param)
+
+ def _make_kresd_sock():
+ return sock_func()
+
+ return _make_kresd_sock
+
+
+@pytest.fixture(params=[
+ 'ip_tcp_socket',
+ 'ip6_tcp_socket',
+ 'ip_tls_socket',
+ 'ip6_tls_socket',
+])
+def make_kresd_silent_sock(request, kresd_silent):
+ """Factory function to create sockets of the same kind (no verbose)."""
+ sock_func = getattr(kresd_silent, request.param)
+
+ def _make_kresd_sock():
+ return sock_func()
+
+ return _make_kresd_sock
+
+
+@pytest.fixture
+def kresd_sock(make_kresd_sock):
+ return make_kresd_sock()
+
+
+@pytest.fixture(params=[
+ socket.AF_INET,
+ socket.AF_INET6,
+])
+def sock_family(request):
+ return request.param
+
+
+@pytest.fixture(params=[
+ True,
+ False
+])
+def single_buffer(request): # whether to send all data in a single buffer
+ return request.param
+
+
+@pytest.fixture(params=[
+ True,
+ False
+])
+def query_before(request): # whether to send an initial query
+ return request.param
+
+
+@pytest.mark.optionalhook
+def pytest_metadata(metadata): # filter potentially sensitive data from GitLab CI
+ keys_to_delete = []
+ for key in metadata.keys():
+ key_lower = key.lower()
+ if 'password' in key_lower or 'token' in key_lower or \
+ key_lower.startswith('ci') or key_lower.startswith('gitlab'):
+ keys_to_delete.append(key)
+ for key in keys_to_delete:
+ del metadata[key]
+
+
+def pytest_sessionstart(session): # pylint: disable=unused-argument
+ init_portdir()
diff --git a/tests/pytests/conn_flood.py b/tests/pytests/conn_flood.py
new file mode 100644
index 0000000..88bbfa5
--- /dev/null
+++ b/tests/pytests/conn_flood.py
@@ -0,0 +1,85 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Test opening as many connections as possible.
+
+Due to resource-intensity of this test, it's filename doesn't contain
+"test" on purpose, so it doesn't automatically get picked up by pytest
+(to allow easy parallel testing).
+
+To execute this test, pass the filename of this file to pytest directly.
+Also, make sure not to use parallel execution (-n).
+"""
+
+import resource
+import time
+
+import pytest
+
+from kresd import Kresd
+import utils
+
+
+MAX_SOCKETS = 10000 # upper bound of how many connections to open
+MAX_ITERATIONS = 10 # number of iterations to run the test
+
+# we can't use softlimit itself since kresd already has open sockets,
+# so use lesser value
+RESERVED_NOFILE = 40 # 40 is empirical value
+
+
+@pytest.mark.parametrize('sock_func_name', [
+ 'ip_tcp_socket',
+ 'ip6_tcp_socket',
+ 'ip_tls_socket',
+ 'ip6_tls_socket',
+])
+def test_conn_flood(tmpdir, sock_func_name):
+ def create_sockets(make_sock, nsockets):
+ sockets = []
+ next_ping = time.time() + 4 # less than tcp idle timeout / 2
+ while True:
+ additional_sockets = 0
+ while time.time() < next_ping:
+ nsock_to_init = min(100, nsockets - len(sockets))
+ if not nsock_to_init:
+ return sockets
+ sockets.extend([make_sock() for _ in range(nsock_to_init)])
+ additional_sockets += nsock_to_init
+
+ # large number of connections can take a lot of time to open
+ # send some valid data to avoid TCP idle timeout for already open sockets
+ next_ping = time.time() + 4
+ for s in sockets:
+ utils.ping_alive(s)
+
+ # break when no more than 20% additional sockets are created
+ if additional_sockets / len(sockets) < 0.2:
+ return sockets
+
+ max_num_of_open_files = resource.getrlimit(resource.RLIMIT_NOFILE)[0] - RESERVED_NOFILE
+ nsockets = min(max_num_of_open_files, MAX_SOCKETS)
+
+ # create kresd instance with verbose=False
+ ip = '127.0.0.1'
+ ip6 = '::1'
+ with Kresd(tmpdir, ip=ip, ip6=ip6, verbose=False) as kresd:
+ make_sock = getattr(kresd, sock_func_name) # function for creating sockets
+ sockets = create_sockets(make_sock, nsockets)
+ print("\nEstablished {} connections".format(len(sockets)))
+
+ print("Start sending data")
+ for i in range(MAX_ITERATIONS):
+ for s in sockets:
+ utils.ping_alive(s)
+ print("Iteration {} done...".format(i))
+
+ print("Close connections")
+ for s in sockets:
+ s.close()
+
+ # check in kresd is alive
+ print("Check upstream is still alive")
+ sock = make_sock()
+ utils.ping_alive(sock)
+
+ print("OK!")
diff --git a/tests/pytests/kresd.py b/tests/pytests/kresd.py
new file mode 100644
index 0000000..149efe9
--- /dev/null
+++ b/tests/pytests/kresd.py
@@ -0,0 +1,306 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from collections import namedtuple
+from contextlib import ContextDecorator, contextmanager
+import os
+from pathlib import Path
+import random
+import re
+import shutil
+import socket
+import subprocess
+import time
+
+import jinja2
+
+import utils
+
+
+PYTESTS_DIR = os.path.dirname(os.path.realpath(__file__))
+CERTS_DIR = os.path.join(PYTESTS_DIR, 'certs')
+TEMPLATES_DIR = os.path.join(PYTESTS_DIR, 'templates')
+KRESD_CONF_TEMPLATE = 'kresd.conf.j2'
+KRESD_STARTUP_MSGID = 10005 # special unique ID at the start of the "test" log
+KRESD_PORTDIR = '/tmp/pytest-kresd-portdir'
+KRESD_TESTPORT_MIN = 1024
+KRESD_TESTPORT_MAX = 32768 # avoid overlap with docker ephemeral port range
+
+
+def init_portdir():
+ try:
+ shutil.rmtree(KRESD_PORTDIR)
+ except FileNotFoundError:
+ pass
+ os.makedirs(KRESD_PORTDIR)
+
+
+def create_file_from_template(template_path, dest, data):
+ env = jinja2.Environment(
+ loader=jinja2.FileSystemLoader(TEMPLATES_DIR))
+ template = env.get_template(template_path)
+ rendered_template = template.render(**data)
+
+ with open(dest, "w", encoding='UTF-8') as fh:
+ fh.write(rendered_template)
+
+
+Forward = namedtuple('Forward', ['proto', 'ip', 'port', 'hostname', 'ca_file'])
+
+
+class Kresd(ContextDecorator):
+ def __init__(
+ self, workdir, port=None, tls_port=None, ip=None, ip6=None, certname=None,
+ verbose=True, hints=None, forward=None, policy_test_pass=False):
+ if ip is None and ip6 is None:
+ raise ValueError("IPv4 or IPv6 must be specified!")
+ self.workdir = str(workdir)
+ self.port = port
+ self.tls_port = tls_port
+ self.ip = ip
+ self.ip6 = ip6
+ self.process = None
+ self.sockets = []
+ self.logfile = None
+ self.verbose = verbose
+ self.hints = {} if hints is None else hints
+ self.forward = forward
+ self.policy_test_pass = policy_test_pass
+
+ if certname:
+ self.tls_cert_path = os.path.join(CERTS_DIR, certname + '.cert.pem')
+ self.tls_key_path = os.path.join(CERTS_DIR, certname + '.key.pem')
+ else:
+ self.tls_cert_path = None
+ self.tls_key_path = None
+
+ @property
+ def config_path(self):
+ return str(os.path.join(self.workdir, 'kresd.conf'))
+
+ @property
+ def logfile_path(self):
+ return str(os.path.join(self.workdir, 'kresd.log'))
+
+ def __enter__(self):
+ if self.port is not None:
+ take_port(self.port, self.ip, self.ip6, timeout=120)
+ else:
+ self.port = make_port(self.ip, self.ip6)
+ if self.tls_port is not None:
+ take_port(self.tls_port, self.ip, self.ip6, timeout=120)
+ else:
+ self.tls_port = make_port(self.ip, self.ip6)
+
+ create_file_from_template(KRESD_CONF_TEMPLATE, self.config_path, {'kresd': self})
+ self.logfile = open(self.logfile_path, 'w', encoding='UTF-8')
+ self.process = subprocess.Popen(
+ ['kresd', '-c', self.config_path, '-n', self.workdir],
+ stderr=self.logfile, env=os.environ.copy())
+
+ try:
+ self._wait_for_tcp_port() # wait for ports to be up and responding
+ if not self.all_ports_alive(msgid=10001):
+ raise RuntimeError("Kresd not listening on all ports")
+
+ # issue special msgid to mark start of test log
+ sock = self.ip_tcp_socket() if self.ip else self.ip6_tcp_socket()
+ assert utils.try_ping_alive(sock, close=True, msgid=KRESD_STARTUP_MSGID)
+
+ # sanity check - kresd didn't crash
+ self.process.poll()
+ if self.process.returncode is not None:
+ raise RuntimeError("Kresd crashed with returncode: {}".format(
+ self.process.returncode))
+ except (RuntimeError, ConnectionError): # pylint: disable=try-except-raise
+ with open(self.logfile_path, encoding='UTF-8') as log: # print log for debugging
+ print(log.read())
+ raise
+
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ try:
+ if not self.all_ports_alive(msgid=1006):
+ raise RuntimeError("Kresd crashed")
+ finally:
+ for sock in self.sockets:
+ sock.close()
+ self.process.terminate()
+ self.logfile.close()
+ Path(KRESD_PORTDIR, str(self.port)).unlink()
+
+ def all_ports_alive(self, msgid=10001):
+ alive = True
+ if self.ip:
+ alive &= utils.try_ping_alive(self.ip_tcp_socket(), close=True, msgid=msgid)
+ alive &= utils.try_ping_alive(self.ip_tls_socket(), close=True, msgid=msgid + 1)
+ if self.ip6:
+ alive &= utils.try_ping_alive(self.ip6_tcp_socket(), close=True, msgid=msgid + 2)
+ alive &= utils.try_ping_alive(self.ip6_tls_socket(), close=True, msgid=msgid + 3)
+ return alive
+
+ def _wait_for_tcp_port(self, max_delay=10, delay_step=0.2):
+ family = socket.AF_INET if self.ip else socket.AF_INET6
+ i = 0
+ end_time = time.time() + max_delay
+
+ while time.time() < end_time:
+ i += 1
+
+ # use exponential backoff algorithm to choose next delay
+ rand_delay = random.randrange(0, i)
+ time.sleep(rand_delay * delay_step)
+
+ try:
+ sock, dest = self.stream_socket(family, timeout=5)
+ sock.connect(dest)
+ except ConnectionRefusedError:
+ continue
+ else:
+ try:
+ return utils.try_ping_alive(sock, close=True, msgid=10000)
+ except socket.timeout:
+ continue
+ finally:
+ sock.close()
+ raise RuntimeError("Kresd didn't start in time {}".format(dest))
+
+ def socket_dest(self, family, tls=False):
+ port = self.tls_port if tls else self.port
+ if family == socket.AF_INET:
+ return self.ip, port
+ elif family == socket.AF_INET6:
+ return self.ip6, port, 0, 0
+ raise RuntimeError("Unsupported socket family: {}".format(family))
+
+ def stream_socket(self, family, tls=False, timeout=20):
+ """Initialize a socket and return it along with the destination without connecting."""
+ sock = socket.socket(family, socket.SOCK_STREAM)
+ sock.settimeout(timeout)
+ sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+ dest = self.socket_dest(family, tls)
+ self.sockets.append(sock)
+ return sock, dest
+
+ def _tcp_socket(self, family):
+ sock, dest = self.stream_socket(family)
+ sock.connect(dest)
+ return sock
+
+ def ip_tcp_socket(self):
+ return self._tcp_socket(socket.AF_INET)
+
+ def ip6_tcp_socket(self):
+ return self._tcp_socket(socket.AF_INET6)
+
+ def _tls_socket(self, family):
+ sock, dest = self.stream_socket(family, tls=True)
+ ctx = utils.make_ssl_context(insecure=True)
+ ssock = ctx.wrap_socket(sock)
+ try:
+ ssock.connect(dest)
+ except OSError as exc:
+ if exc.errno == 0: # sometimes happens shortly after startup
+ return None
+ return ssock
+
+ def _tls_socket_with_retry(self, family):
+ sock = self._tls_socket(family)
+ if sock is None:
+ time.sleep(0.1)
+ sock = self._tls_socket(family)
+ if sock is None:
+ raise RuntimeError("Failed to create TLS socket!")
+ return sock
+
+ def ip_tls_socket(self):
+ return self._tls_socket_with_retry(socket.AF_INET)
+
+ def ip6_tls_socket(self):
+ return self._tls_socket_with_retry(socket.AF_INET6)
+
+ def partial_log(self):
+ partial_log = '\n (... omitting log start)\n'
+ with open(self.logfile_path, encoding='UTF-8') as log: # display partial log for debugging
+ past_startup_msgid = False
+ past_startup = False
+ for line in log:
+ if past_startup:
+ partial_log += line
+ else: # find real start of test log (after initial alive-pings)
+ if not past_startup_msgid:
+ if re.match(KRESD_LOG_STARTUP_MSGID, line) is not None:
+ past_startup_msgid = True
+ else:
+ if re.match(KRESD_LOG_IO_CLOSE, line) is not None:
+ past_startup = True
+ return partial_log
+
+
+def is_port_free(port, ip=None, ip6=None):
+ def check(family, type_, dest):
+ sock = socket.socket(family, type_)
+ sock.bind(dest)
+ sock.close()
+
+ try:
+ if ip is not None:
+ check(socket.AF_INET, socket.SOCK_STREAM, (ip, port))
+ check(socket.AF_INET, socket.SOCK_DGRAM, (ip, port))
+ if ip6 is not None:
+ check(socket.AF_INET6, socket.SOCK_STREAM, (ip6, port, 0, 0))
+ check(socket.AF_INET6, socket.SOCK_DGRAM, (ip6, port, 0, 0))
+ except OSError as exc:
+ if exc.errno == 98: # address already in use
+ return False
+ else:
+ raise
+ return True
+
+
+def take_port(port, ip=None, ip6=None, timeout=0):
+ port_path = Path(KRESD_PORTDIR, str(port))
+ end_time = time.time() + timeout
+ try:
+ port_path.touch(exist_ok=False)
+ except FileExistsError as ex:
+ raise ValueError(
+ "Port {} already reserved by system or another kresd instance!".format(port)) from ex
+
+ while True:
+ if is_port_free(port, ip, ip6):
+ # NOTE: The port_path isn't removed, so other instances don't have to attempt to
+ # take the same port again. This has the side effect of leaving many of these
+ # files behind, because when another kresd shuts down and removes its file, the
+ # port still can't be reserved for a while. This shouldn't become an issue unless
+ # we have thousands of tests (and run out of the port range).
+ break
+
+ if time.time() < end_time:
+ time.sleep(5)
+ else:
+ raise ValueError(
+ "Port {} is reserved by system!".format(port))
+ return port
+
+
+def make_port(ip=None, ip6=None):
+ for _ in range(10): # max attempts
+ port = random.randint(KRESD_TESTPORT_MIN, KRESD_TESTPORT_MAX)
+ try:
+ take_port(port, ip, ip6)
+ except ValueError:
+ continue # port reserved by system / another kresd instance
+ return port
+ raise RuntimeError("No available port found!")
+
+
+KRESD_LOG_STARTUP_MSGID = re.compile(r'^\[[^]]+\]\[{}.*'.format(KRESD_STARTUP_MSGID))
+KRESD_LOG_IO_CLOSE = re.compile(r'^\[io \].*closed by peer.*')
+
+
+@contextmanager
+def make_kresd(workdir, certname=None, ip='127.0.0.1', ip6='::1', **kwargs):
+ with Kresd(workdir, ip=ip, ip6=ip6, certname=certname, **kwargs) as kresd:
+ yield kresd
+ print(kresd.partial_log())
diff --git a/tests/pytests/meson.build b/tests/pytests/meson.build
new file mode 100644
index 0000000..d717dc2
--- /dev/null
+++ b/tests/pytests/meson.build
@@ -0,0 +1,77 @@
+# tests: pytests
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+# python 3 dependencies
+py3_deps += [
+ ['jinja2', 'jinja2 (for pytests)'],
+ ['dns', 'dnspython (for pytests)'],
+ ['pytest', 'pytest (for pytests)'],
+ ['pytest_html', 'pytest-html (for pytests)'],
+ ['xdist', 'pytest-xdist (for pytests)'],
+]
+
+if gnutls.version().version_compare('<3.6.4')
+ error('pytests require GnuTLS >= 3.6.4')
+endif
+
+# compile tlsproxy
+tlsproxy_src = files([
+ 'proxy/tlsproxy.c',
+ 'proxy/tls-proxy.c',
+])
+tlsproxy = executable(
+ 'tlsproxy',
+ tlsproxy_src,
+ dependencies: [
+ libkres_dep,
+ libuv,
+ gnutls,
+ ],
+)
+
+# path to kresd and tlsproxy
+pytests_env = environment()
+pytests_env.prepend('PATH', sbin_dir, meson.current_build_dir())
+
+test(
+ 'pytests.parallel',
+ python3,
+ args: [
+ '-m', 'pytest',
+ '-d',
+ '--html', 'pytests.parallel.html',
+ '--self-contained-html',
+ '--junitxml=pytests.parallel.junit.xml',
+ '-n', '24',
+ '-v',
+ ],
+ env: pytests_env,
+ suite: [
+ 'postinstall',
+ 'pytests',
+ ],
+ workdir: meson.current_source_dir(),
+ is_parallel: false,
+ timeout: 180,
+ depends: tlsproxy,
+)
+
+test(
+ 'pytests.single',
+ python3,
+ args: [
+ '-m', 'pytest',
+ '--junitxml=pytests.single.junit.xml',
+ '-ra',
+ '--capture=no',
+ 'conn_flood.py',
+ ],
+ env: pytests_env,
+ suite: [
+ 'postinstall',
+ 'pytests',
+ ],
+ workdir: meson.current_source_dir(),
+ is_parallel: false,
+ timeout: 240,
+)
diff --git a/tests/pytests/proxy.py b/tests/pytests/proxy.py
new file mode 100644
index 0000000..b8a53cd
--- /dev/null
+++ b/tests/pytests/proxy.py
@@ -0,0 +1,161 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from contextlib import contextmanager, ContextDecorator
+import os
+import subprocess
+from typing import Any, Dict, Optional
+
+import dns
+import dns.rcode
+import pytest
+
+from kresd import CERTS_DIR, Forward, Kresd, make_kresd, make_port
+import utils
+
+
+HINTS = {
+ '0.foo.': '127.0.0.1',
+ '1.foo.': '127.0.0.1',
+ '2.foo.': '127.0.0.1',
+ '3.foo.': '127.0.0.1',
+}
+
+
+def resolve_hint(sock, qname):
+ buff, msgid = utils.get_msgbuff(qname)
+ sock.sendall(buff)
+ answer = utils.receive_parse_answer(sock)
+ assert answer.id == msgid
+ assert answer.rcode() == dns.rcode.NOERROR
+ assert answer.answer[0][0].address == HINTS[qname]
+
+
+class Proxy(ContextDecorator):
+ EXECUTABLE = ''
+
+ def __init__(
+ self,
+ local_ip: str = '127.0.0.1',
+ local_port: Optional[int] = None,
+ upstream_ip: str = '127.0.0.1',
+ upstream_port: Optional[int] = None
+ ) -> None:
+ self.local_ip = local_ip
+ self.local_port = local_port
+ self.upstream_ip = upstream_ip
+ self.upstream_port = upstream_port
+ self.proxy = None
+
+ def get_args(self):
+ args = []
+ args.append('--local')
+ args.append(self.local_ip)
+ if self.local_port is not None:
+ args.append('--lport')
+ args.append(str(self.local_port))
+ args.append('--upstream')
+ args.append(self.upstream_ip)
+ if self.upstream_port is not None:
+ args.append('--uport')
+ args.append(str(self.upstream_port))
+ return args
+
+ def __enter__(self):
+ args = [self.EXECUTABLE] + self.get_args()
+ print(' '.join(args))
+
+ try:
+ self.proxy = subprocess.Popen(
+ args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+ except subprocess.CalledProcessError:
+ pytest.skip("proxy '{}' failed to run (did you compile it?)"
+ .format(self.EXECUTABLE))
+
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ if self.proxy is not None:
+ self.proxy.terminate()
+ self.proxy = None
+
+
+class TLSProxy(Proxy):
+ EXECUTABLE = 'tlsproxy'
+
+ def __init__(
+ self,
+ local_ip: str = '127.0.0.1',
+ local_port: Optional[int] = None,
+ upstream_ip: str = '127.0.0.1',
+ upstream_port: Optional[int] = None,
+ certname: Optional[str] = 'tt',
+ close: Optional[int] = None,
+ rehandshake: bool = False,
+ force_tls13: bool = False
+ ) -> None:
+ super().__init__(local_ip, local_port, upstream_ip, upstream_port)
+ if certname is not None:
+ self.cert_path = os.path.join(CERTS_DIR, certname + '.cert.pem')
+ self.key_path = os.path.join(CERTS_DIR, certname + '.key.pem')
+ else:
+ self.cert_path = None
+ self.key_path = None
+ self.close = close
+ self.rehandshake = rehandshake
+ self.force_tls13 = force_tls13
+
+ def get_args(self):
+ args = super().get_args()
+ if self.cert_path is not None:
+ args.append('--cert')
+ args.append(self.cert_path)
+ if self.key_path is not None:
+ args.append('--key')
+ args.append(self.key_path)
+ if self.close is not None:
+ args.append('--close')
+ args.append(str(self.close))
+ if self.rehandshake:
+ args.append('--rehandshake')
+ if self.force_tls13:
+ args.append('--tls13')
+ return args
+
+
+@contextmanager
+def kresd_tls_client(
+ workdir: str,
+ proxy: TLSProxy,
+ kresd_tls_client_kwargs: Optional[Dict[Any, Any]] = None,
+ kresd_fwd_target_kwargs: Optional[Dict[Any, Any]] = None
+ ) -> Kresd:
+ """kresd_tls_client --(tls)--> tlsproxy --(tcp)--> kresd_fwd_target"""
+ ALLOWED_IPS = {'127.0.0.1', '::1'}
+ assert proxy.local_ip in ALLOWED_IPS, "only localhost IPs supported for proxy"
+ assert proxy.upstream_ip in ALLOWED_IPS, "only localhost IPs are supported for proxy"
+
+ if kresd_tls_client_kwargs is None:
+ kresd_tls_client_kwargs = {}
+ if kresd_fwd_target_kwargs is None:
+ kresd_fwd_target_kwargs = {}
+
+ # run forward target instance
+ dir1 = os.path.join(workdir, 'kresd_fwd_target')
+ os.makedirs(dir1)
+
+ with make_kresd(dir1, hints=HINTS, **kresd_fwd_target_kwargs) as kresd_fwd_target:
+ sock = kresd_fwd_target.ip_tcp_socket()
+ resolve_hint(sock, list(HINTS.keys())[0])
+
+ proxy.local_port = make_port('127.0.0.1', '::1')
+ proxy.upstream_port = kresd_fwd_target.port
+
+ with proxy:
+ # run test kresd instance
+ dir2 = os.path.join(workdir, 'kresd_tls_client')
+ os.makedirs(dir2)
+ forward = Forward(
+ proto='tls', ip=proxy.local_ip, port=proxy.local_port,
+ hostname='transport-test-server.com', ca_file=proxy.cert_path)
+ with make_kresd(dir2, forward=forward, **kresd_tls_client_kwargs) as kresd:
+ yield kresd
diff --git a/tests/pytests/proxy/tls-proxy.c b/tests/pytests/proxy/tls-proxy.c
new file mode 100644
index 0000000..986716c
--- /dev/null
+++ b/tests/pytests/proxy/tls-proxy.c
@@ -0,0 +1,1038 @@
+/* SPDX-License-Identifier: GPL-3.0-or-later */
+
+#include <assert.h>
+#include <stdio.h>
+#include <unistd.h>
+#include <string.h>
+#include <stdlib.h>
+#include <stdbool.h>
+#include <gnutls/gnutls.h>
+#include <uv.h>
+#include "lib/generic/array.h"
+#include "tls-proxy.h"
+
+#define TLS_MAX_SEND_RETRIES 100
+#define CLIENT_ANSWER_CHUNK_SIZE 8
+
+#define MAX_CLIENT_PENDING_SIZE 4096
+
+struct buf {
+ size_t size;
+ char buf[];
+};
+
+enum peer_state {
+ STATE_NOT_CONNECTED,
+ STATE_LISTENING,
+ STATE_CONNECTED,
+ STATE_CONNECT_IN_PROGRESS,
+ STATE_CLOSING_IN_PROGRESS
+};
+
+enum handshake_state {
+ TLS_HS_NOT_STARTED = 0,
+ TLS_HS_EXPECTED,
+ TLS_HS_REAUTH_EXPECTED,
+ TLS_HS_IN_PROGRESS,
+ TLS_HS_DONE,
+ TLS_HS_CLOSING,
+ TLS_HS_LAST
+};
+
+struct tls_ctx {
+ gnutls_session_t session;
+ enum handshake_state handshake_state;
+ /* for reading from the network */
+ const uint8_t *buf;
+ ssize_t nread;
+ ssize_t consumed;
+ uint8_t recv_buf[4096];
+};
+
+struct peer {
+ uv_tcp_t handle;
+ enum peer_state state;
+ struct sockaddr_storage addr;
+ array_t(struct buf *) pending_buf;
+ uint64_t connection_timestamp;
+ struct tls_ctx *tls;
+ struct peer *peer;
+ int active_requests;
+};
+
+struct tls_proxy_ctx {
+ const struct args *a;
+ uv_loop_t *loop;
+ gnutls_certificate_credentials_t tls_credentials;
+ gnutls_priority_t tls_priority_cache;
+ struct {
+ uv_tcp_t handle;
+ struct sockaddr_storage addr;
+ } server;
+ struct sockaddr_storage upstream_addr;
+ array_t(struct peer *) client_list;
+ char uv_wire_buf[65535 * 2];
+ int conn_sequence;
+};
+
+static void read_from_upstream_cb(uv_stream_t *upstream, ssize_t nread, const uv_buf_t *buf);
+static void read_from_client_cb(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf);
+static ssize_t proxy_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len);
+static ssize_t proxy_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len);
+static int tls_process_from_upstream(struct peer *upstream, const uint8_t *buf, ssize_t nread);
+static int tls_process_from_client(struct peer *client, const uint8_t *buf, ssize_t nread);
+static int write_to_upstream_pending(struct peer *peer);
+static int write_to_client_pending(struct peer *peer);
+static void on_client_close(uv_handle_t *handle);
+static void on_upstream_close(uv_handle_t *handle);
+
+static int gnutls_references = 0;
+
+static const char * const tlsv12_priorities =
+ "NORMAL:" /* GnuTLS defaults */
+ "-VERS-TLS1.0:-VERS-TLS1.1:+VERS-TLS1.2:-VERS-TLS1.3:" /* TLS 1.2 only */
+ "-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL";
+
+static const char * const tlsv13_priorities =
+ "NORMAL:" /* GnuTLS defaults */
+ "-VERS-TLS1.0:-VERS-TLS1.1:-VERS-TLS1.2:+VERS-TLS1.3:" /* TLS 1.3 only */
+ "-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL";
+
+static struct tls_proxy_ctx *get_proxy(struct peer *peer)
+{
+ return (struct tls_proxy_ctx *)peer->handle.loop->data;
+}
+
+const void *ip_addr(const struct sockaddr *addr)
+{
+ if (!addr) {
+ return NULL;
+ }
+ switch (addr->sa_family) {
+ case AF_INET: return (const void *)&(((const struct sockaddr_in *)addr)->sin_addr);
+ case AF_INET6: return (const void *)&(((const struct sockaddr_in6 *)addr)->sin6_addr);
+ default: return NULL;
+ }
+}
+
+uint16_t ip_addr_port(const struct sockaddr *addr)
+{
+ if (!addr) {
+ return 0;
+ }
+ switch (addr->sa_family) {
+ case AF_INET: return ntohs(((const struct sockaddr_in *)addr)->sin_port);
+ case AF_INET6: return ntohs(((const struct sockaddr_in6 *)addr)->sin6_port);
+ default: return 0;
+ }
+}
+
+static int ip_addr_str(const struct sockaddr *addr, char *buf, size_t *buflen)
+{
+ int ret = 0;
+ if (!addr || !buf || !buflen) {
+ return EINVAL;
+ }
+
+ char str[INET6_ADDRSTRLEN + 6];
+ if (!inet_ntop(addr->sa_family, ip_addr(addr), str, sizeof(str))) {
+ return errno;
+ }
+ int len = strlen(str);
+ str[len] = '#';
+ snprintf(&str[len + 1], 6, "%hu", ip_addr_port(addr));
+ len += 6;
+ str[len] = 0;
+ if (len >= *buflen) {
+ ret = ENOSPC;
+ } else {
+ memcpy(buf, str, len + 1);
+ }
+ *buflen = len;
+ return ret;
+}
+
+static inline char *ip_straddr(const struct sockaddr_storage *saddr_storage)
+{
+ assert(saddr_storage != NULL);
+ const struct sockaddr *addr = (const struct sockaddr *)saddr_storage;
+ /* We are the single-threaded application */
+ static char str[INET6_ADDRSTRLEN + 6];
+ size_t len = sizeof(str);
+ int ret = ip_addr_str(addr, str, &len);
+ return ret != 0 || len == 0 ? NULL : str;
+}
+
+static struct buf *alloc_io_buffer(size_t size)
+{
+ struct buf *buf = calloc(1, sizeof (struct buf) + size);
+ buf->size = size;
+ return buf;
+}
+
+static void free_io_buffer(struct buf *buf)
+{
+ if (!buf) {
+ return;
+ }
+ free(buf);
+}
+
+static struct buf *get_first_pending_buf(struct peer *peer)
+{
+ struct buf *buf = NULL;
+ if (peer->pending_buf.len > 0) {
+ buf = peer->pending_buf.at[0];
+ }
+ return buf;
+}
+
+static struct buf *remove_first_pending_buf(struct peer *peer)
+{
+ if (peer->pending_buf.len == 0) {
+ return NULL;
+ }
+ struct buf * buf = peer->pending_buf.at[0];
+ for (int i = 1; i < peer->pending_buf.len; ++i) {
+ peer->pending_buf.at[i - 1] = peer->pending_buf.at[i];
+ }
+ peer->pending_buf.len -= 1;
+ return buf;
+}
+
+static void clear_pending_bufs(struct peer *peer)
+{
+ for (int i = 0; i < peer->pending_buf.len; ++i) {
+ struct buf *b = peer->pending_buf.at[i];
+ free_io_buffer(b);
+ }
+ peer->pending_buf.len = 0;
+}
+
+static void alloc_uv_buffer(uv_handle_t *handle, size_t suggested_size, uv_buf_t *buf)
+{
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)handle->loop->data;
+ buf->base = proxy->uv_wire_buf;
+ buf->len = sizeof(proxy->uv_wire_buf);
+}
+
+static void on_client_close(uv_handle_t *handle)
+{
+ struct peer *client = (struct peer *)handle->data;
+ struct peer *upstream = client->peer;
+ fprintf(stdout, "[client] connection with '%s' closed\n", ip_straddr(&client->addr));
+ assert(client->tls);
+ gnutls_deinit(client->tls->session);
+ client->tls->handshake_state = TLS_HS_NOT_STARTED;
+ client->state = STATE_NOT_CONNECTED;
+ if (upstream->state != STATE_NOT_CONNECTED) {
+ if (upstream->state == STATE_CONNECTED) {
+ fprintf(stdout, "[client] closing connection with upstream for '%s'\n",
+ ip_straddr(&client->addr));
+ upstream->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&upstream->handle, on_upstream_close);
+ }
+ return;
+ }
+ struct tls_proxy_ctx *proxy = get_proxy(client);
+ for (size_t i = 0; i < proxy->client_list.len; ++i) {
+ struct peer *client_i = proxy->client_list.at[i];
+ if (client_i == client) {
+ fprintf(stdout, "[client] connection structures deallocated for '%s'\n",
+ ip_straddr(&client->addr));
+ array_del(proxy->client_list, i);
+ free(client->tls);
+ free(client);
+ break;
+ }
+ }
+}
+
+static void on_upstream_close(uv_handle_t *handle)
+{
+ struct peer *upstream = (struct peer *)handle->data;
+ struct peer *client = upstream->peer;
+ assert(upstream->tls == NULL);
+ upstream->state = STATE_NOT_CONNECTED;
+ fprintf(stdout, "[upstream] connection with upstream closed for client '%s'\n", ip_straddr(&client->addr));
+ if (client->state != STATE_NOT_CONNECTED) {
+ if (client->state == STATE_CONNECTED) {
+ fprintf(stdout, "[upstream] closing connection to client '%s'\n",
+ ip_straddr(&client->addr));
+ client->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&client->handle, on_client_close);
+ }
+ return;
+ }
+ struct tls_proxy_ctx *proxy = get_proxy(upstream);
+ for (size_t i = 0; i < proxy->client_list.len; ++i) {
+ struct peer *client_i = proxy->client_list.at[i];
+ if (client_i == client) {
+ fprintf(stdout, "[upstream] connection structures deallocated for '%s'\n",
+ ip_straddr(&client->addr));
+ array_del(proxy->client_list, i);
+ free(upstream);
+ free(client->tls);
+ free(client);
+ break;
+ }
+ }
+}
+
+static void write_to_client_cb(uv_write_t *req, int status)
+{
+ struct peer *client = (struct peer *)req->handle->data;
+ free(req);
+ client->active_requests -= 1;
+ if (status) {
+ fprintf(stdout, "[client] error writing to client '%s': %s\n",
+ ip_straddr(&client->addr), uv_strerror(status));
+ clear_pending_bufs(client);
+ if (client->state == STATE_CONNECTED) {
+ client->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&client->handle, on_client_close);
+ return;
+ }
+ }
+ fprintf(stdout, "[client] successfully wrote to client '%s', pending len is %zd, active requests %i\n",
+ ip_straddr(&client->addr), client->pending_buf.len, client->active_requests);
+ if (client->state == STATE_CONNECTED &&
+ client->tls->handshake_state == TLS_HS_DONE) {
+ struct tls_proxy_ctx *proxy = get_proxy(client);
+ uint64_t elapsed = uv_now(proxy->loop) - client->connection_timestamp;
+ if (!proxy->a->close_connection || elapsed < proxy->a->close_timeout) {
+ write_to_client_pending(client);
+ } else {
+ clear_pending_bufs(client);
+ client->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&client->handle, on_client_close);
+ fprintf(stdout, "[client] closing connection to client '%s'\n", ip_straddr(&client->addr));
+ }
+ }
+}
+
+static void write_to_upstream_cb(uv_write_t *req, int status)
+{
+ struct peer *upstream = (struct peer *)req->handle->data;
+ void *data = req->data;
+ free(req);
+ if (status) {
+ fprintf(stdout, "[upstream] error writing to upstream: %s\n", uv_strerror(status));
+ clear_pending_bufs(upstream);
+ upstream->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&upstream->handle, on_upstream_close);
+ return;
+ }
+ if (data != NULL) {
+ assert(upstream->pending_buf.len > 0);
+ struct buf *buf = get_first_pending_buf(upstream);
+ assert(data == (void *)buf->buf);
+ fprintf(stdout, "[upstream] successfully wrote %zi bytes to upstream, pending len is %zd\n",
+ buf->size, upstream->pending_buf.len);
+ remove_first_pending_buf(upstream);
+ free_io_buffer(buf);
+ } else {
+ fprintf(stdout, "[upstream] successfully wrote to upstream, pending len is %zd\n",
+ upstream->pending_buf.len);
+ }
+ if (upstream->peer == NULL || upstream->peer->state != STATE_CONNECTED) {
+ clear_pending_bufs(upstream);
+ } else if (upstream->state == STATE_CONNECTED && upstream->pending_buf.len > 0) {
+ write_to_upstream_pending(upstream);
+ }
+}
+
+static void accept_connection_from_client(uv_stream_t *server)
+{
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)server->loop->data;
+ struct peer *client = calloc(1, sizeof(struct peer));
+ uv_tcp_init(proxy->loop, &client->handle);
+ uv_tcp_nodelay((uv_tcp_t *)&client->handle, 1);
+
+ int err = uv_accept(server, (uv_stream_t*)&client->handle);
+ if (err != 0) {
+ fprintf(stdout, "[client] incoming connection - uv_accept() failed: (%d) %s\n",
+ err, uv_strerror(err));
+ proxy->conn_sequence = 0;
+ return;
+ }
+
+ client->state = STATE_CONNECTED;
+ array_init(client->pending_buf);
+ client->handle.data = client;
+
+ struct peer *upstream = calloc(1, sizeof(struct peer));
+ uv_tcp_init(proxy->loop, &upstream->handle);
+ uv_tcp_nodelay((uv_tcp_t *)&upstream->handle, 1);
+
+ client->peer = upstream;
+
+ array_init(upstream->pending_buf);
+ upstream->state = STATE_NOT_CONNECTED;
+ upstream->peer = client;
+ upstream->handle.data = upstream;
+
+ struct sockaddr *addr = (struct sockaddr *)&(client->addr);
+ int addr_len = sizeof(client->addr);
+ int ret = uv_tcp_getpeername(&client->handle, addr, &addr_len);
+ if (ret || addr->sa_family == AF_UNSPEC) {
+ client->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&client->handle, on_client_close);
+ fprintf(stdout, "[client] incoming connection - uv_tcp_getpeername() failed: (%d) %s\n",
+ err, uv_strerror(err));
+ proxy->conn_sequence = 0;
+ return;
+ }
+ memcpy(&upstream->addr, &proxy->upstream_addr, sizeof(struct sockaddr_storage));
+
+ struct tls_ctx *tls = calloc(1, sizeof(struct tls_ctx));
+ tls->handshake_state = TLS_HS_NOT_STARTED;
+
+ client->tls = tls;
+ const char *errpos = NULL;
+ unsigned int gnutls_flags = GNUTLS_SERVER | GNUTLS_NONBLOCK;
+#if GNUTLS_VERSION_NUMBER >= 0x030604
+ if (proxy->a->tls_13) {
+ gnutls_flags |= GNUTLS_POST_HANDSHAKE_AUTH;
+ }
+#endif
+ err = gnutls_init(&tls->session, gnutls_flags);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[client] gnutls_init() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ }
+ err = gnutls_priority_set(tls->session, proxy->tls_priority_cache);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[client] gnutls_priority_set() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ }
+
+ const char *direct_priorities = proxy->a->tls_13 ? tlsv13_priorities : tlsv12_priorities;
+ err = gnutls_priority_set_direct(tls->session, direct_priorities, &errpos);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[client] setting priority '%s' failed at character %zd (...'%s') with %s (%d)\n",
+ direct_priorities, errpos - direct_priorities, errpos,
+ gnutls_strerror_name(err), err);
+ }
+ err = gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE, proxy->tls_credentials);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[client] gnutls_credentials_set() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ }
+ if (proxy->a->tls_13) {
+ gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_REQUEST);
+ } else {
+ gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_IGNORE);
+ }
+ gnutls_handshake_set_timeout(tls->session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT);
+ gnutls_transport_set_pull_function(tls->session, proxy_gnutls_pull);
+ gnutls_transport_set_push_function(tls->session, proxy_gnutls_push);
+ gnutls_transport_set_ptr(tls->session, client);
+
+ tls->handshake_state = TLS_HS_IN_PROGRESS;
+
+ client->connection_timestamp = uv_now(proxy->loop);
+ proxy->conn_sequence += 1;
+ array_push(proxy->client_list, client);
+
+ fprintf(stdout, "[client] incoming connection from '%s'\n", ip_straddr(&client->addr));
+ uv_read_start((uv_stream_t*)&client->handle, alloc_uv_buffer, read_from_client_cb);
+}
+
+static void dynamic_handle_close_cb(uv_handle_t *handle)
+{
+ free(handle);
+}
+
+static void delayed_accept_timer_cb(uv_timer_t *timer)
+{
+ uv_stream_t *server = (uv_stream_t *)timer->data;
+ fprintf(stdout, "[client] delayed connection processing\n");
+ accept_connection_from_client(server);
+ uv_close((uv_handle_t *)timer, dynamic_handle_close_cb);
+}
+
+static void on_client_connection(uv_stream_t *server, int status)
+{
+ if (status < 0) {
+ fprintf(stdout, "[client] incoming connection error: %s\n", uv_strerror(status));
+ return;
+ }
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)server->loop->data;
+ proxy->conn_sequence += 1;
+ if (proxy->a->max_conn_sequence > 0 &&
+ proxy->conn_sequence > proxy->a->max_conn_sequence) {
+ fprintf(stdout, "[client] incoming connection, delaying\n");
+ uv_timer_t *timer = (uv_timer_t*)malloc(sizeof *timer);
+ uv_timer_init(uv_default_loop(), timer);
+ timer->data = server;
+ uv_timer_start(timer, delayed_accept_timer_cb, 10000, 0);
+ proxy->conn_sequence = 0;
+ } else {
+ accept_connection_from_client(server);
+ }
+}
+
+static void on_connect_to_upstream(uv_connect_t *req, int status)
+{
+ struct peer *upstream = (struct peer *)req->handle->data;
+ free(req);
+ if (status < 0) {
+ fprintf(stdout, "[upstream] error connecting to upstream (%s): %s\n",
+ ip_straddr(&upstream->addr),
+ uv_strerror(status));
+ clear_pending_bufs(upstream);
+ upstream->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&upstream->handle, on_upstream_close);
+ return;
+ }
+ fprintf(stdout, "[upstream] connected to %s\n", ip_straddr(&upstream->addr));
+
+ upstream->state = STATE_CONNECTED;
+ uv_read_start((uv_stream_t*)&upstream->handle, alloc_uv_buffer, read_from_upstream_cb);
+ if (upstream->pending_buf.len > 0) {
+ write_to_upstream_pending(upstream);
+ }
+}
+
+static void read_from_client_cb(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf)
+{
+ if (nread == 0) {
+ fprintf(stdout, "[client] reading %zd bytes\n", nread);
+ return;
+ }
+ struct peer *client = (struct peer *)handle->data;
+ if (nread < 0) {
+ if (nread != UV_EOF) {
+ fprintf(stdout, "[client] error reading from '%s': %s\n",
+ ip_straddr(&client->addr),
+ uv_err_name(nread));
+ } else {
+ fprintf(stdout, "[client] closing connection with '%s'\n",
+ ip_straddr(&client->addr));
+ }
+ if (client->state == STATE_CONNECTED) {
+ client->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)handle, on_client_close);
+ }
+ return;
+ }
+
+ struct tls_proxy_ctx *proxy = get_proxy(client);
+ if (proxy->a->accept_only) {
+ fprintf(stdout, "[client] ignoring %zd bytes from '%s'\n", nread, ip_straddr(&client->addr));
+ return;
+ }
+ fprintf(stdout, "[client] reading %zd bytes from '%s'\n", nread, ip_straddr(&client->addr));
+
+ int res = tls_process_from_client(client, (const uint8_t *)buf->base, nread);
+ if (res < 0) {
+ if (client->state == STATE_CONNECTED) {
+ client->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&client->handle, on_client_close);
+ }
+ }
+}
+
+static void read_from_upstream_cb(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf)
+{
+ fprintf(stdout, "[upstream] reading %zd bytes\n", nread);
+ if (nread == 0) {
+ return;
+ }
+ struct peer *upstream = (struct peer *)handle->data;
+ if (nread < 0) {
+ if (nread != UV_EOF) {
+ fprintf(stdout, "[upstream] error reading from upstream: %s\n", uv_err_name(nread));
+ } else {
+ fprintf(stdout, "[upstream] closing connection\n");
+ }
+ clear_pending_bufs(upstream);
+ if (upstream->state == STATE_CONNECTED) {
+ upstream->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&upstream->handle, on_upstream_close);
+ }
+ return;
+ }
+ int res = tls_process_from_upstream(upstream, (const uint8_t *)buf->base, nread);
+ if (res < 0) {
+ fprintf(stdout, "[upstream] error processing tls data to client\n");
+ if (upstream->peer->state == STATE_CONNECTED) {
+ upstream->peer->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&upstream->peer->handle, on_client_close);
+ }
+ }
+}
+
+static void push_to_upstream_pending(struct peer *upstream, const char *buf, size_t size)
+{
+ struct buf *b = alloc_io_buffer(size);
+ memcpy(b->buf, buf, b->size);
+ array_push(upstream->pending_buf, b);
+}
+
+static void push_to_client_pending(struct peer *client, const char *buf, size_t size)
+{
+ struct tls_proxy_ctx *proxy = get_proxy(client);
+ while (size > 0) {
+ int temp_size = size;
+ if (proxy->a->rehandshake && temp_size > CLIENT_ANSWER_CHUNK_SIZE) {
+ temp_size = CLIENT_ANSWER_CHUNK_SIZE;
+ }
+ struct buf *b = alloc_io_buffer(temp_size);
+ memcpy(b->buf, buf, b->size);
+ array_push(client->pending_buf, b);
+ size -= temp_size;
+ buf += temp_size;
+ }
+}
+
+static int write_to_upstream_pending(struct peer *upstream)
+{
+ struct buf *buf = get_first_pending_buf(upstream);
+ uv_write_t *req = (uv_write_t *) malloc(sizeof(uv_write_t));
+ uv_buf_t wrbuf = uv_buf_init(buf->buf, buf->size);
+ req->data = buf->buf;
+ fprintf(stdout, "[upstream] writing %zd bytes\n", buf->size);
+ return uv_write(req, (uv_stream_t *)&upstream->handle, &wrbuf, 1, write_to_upstream_cb);
+}
+
+static ssize_t proxy_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len)
+{
+ struct peer *peer = (struct peer *)h;
+ struct tls_ctx *t = peer->tls;
+
+ fprintf(stdout, "[gnutls_pull] pulling %zd bytes\n", len);
+
+ if (t->nread <= t->consumed) {
+ errno = EAGAIN;
+ fprintf(stdout, "[gnutls_pull] return EAGAIN\n");
+ return -1;
+ }
+
+ ssize_t avail = t->nread - t->consumed;
+ ssize_t transfer = (avail <= len ? avail : len);
+ memcpy(buf, t->buf + t->consumed, transfer);
+ t->consumed += transfer;
+ return transfer;
+}
+
+ssize_t proxy_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len)
+{
+ struct peer *client = (struct peer *)h;
+ fprintf(stdout, "[gnutls_push] writing %zd bytes\n", len);
+
+ ssize_t ret = -1;
+ const size_t req_size_aligned = ((sizeof(uv_write_t) / 16) + 1) * 16;
+ char *common_buf = malloc(req_size_aligned + len);
+ uv_write_t *req = (uv_write_t *) common_buf;
+ char *data = common_buf + req_size_aligned;
+ const uv_buf_t uv_buf[1] = {
+ { data, len }
+ };
+ memcpy(data, buf, len);
+ int res = uv_write(req, (uv_stream_t *)&client->handle, uv_buf, 1, write_to_client_cb);
+ if (res == 0) {
+ ret = len;
+ client->active_requests += 1;
+ } else {
+ free(common_buf);
+ errno = EIO;
+ }
+ return ret;
+}
+
+static int write_to_client_pending(struct peer *client)
+{
+ if (client->pending_buf.len == 0) {
+ return 0;
+ }
+
+ struct tls_proxy_ctx *proxy = get_proxy(client);
+ struct buf *buf = get_first_pending_buf(client);
+ fprintf(stdout, "[client] writing %zd bytes\n", buf->size);
+
+ gnutls_session_t tls_session = client->tls->session;
+ assert(client->tls->handshake_state != TLS_HS_IN_PROGRESS);
+
+ char *data = buf->buf;
+ size_t len = buf->size;
+
+ ssize_t count = 0;
+ ssize_t submitted = len;
+ ssize_t retries = 0;
+ do {
+ count = gnutls_record_send(tls_session, data, len);
+ if (count < 0) {
+ if (gnutls_error_is_fatal(count)) {
+ fprintf(stdout, "[client] gnutls_record_send failed: %s (%zd)\n",
+ gnutls_strerror_name(count), count);
+ return -1;
+ }
+ if (++retries > TLS_MAX_SEND_RETRIES) {
+ fprintf(stdout, "[client] gnutls_record_send: too many sequential non-fatal errors (%zd), last error is: %s (%zd)\n",
+ retries, gnutls_strerror_name(count), count);
+ return -1;
+ }
+ } else if (count != 0) {
+ data += count;
+ len -= count;
+ retries = 0;
+ } else {
+ if (++retries < TLS_MAX_SEND_RETRIES) {
+ continue;
+ }
+ fprintf(stdout, "[client] gnutls_record_send: too many retries (%zd)\n",
+ retries);
+ fprintf(stdout, "[client] tls_push_to_client didn't send all data(%zd of %zd)\n",
+ len, submitted);
+ return -1;
+ }
+ } while (len > 0);
+
+ remove_first_pending_buf(client);
+ free_io_buffer(buf);
+
+ fprintf(stdout, "[client] submitted %zd bytes\n", submitted);
+ if (proxy->a->rehandshake) {
+ int err = GNUTLS_E_SUCCESS;
+#if GNUTLS_VERSION_NUMBER >= 0x030604
+ if (proxy->a->tls_13) {
+ int flags = gnutls_session_get_flags(tls_session);
+ if ((flags & GNUTLS_SFLAGS_POST_HANDSHAKE_AUTH) == 0) {
+ /* Client doesn't support post-handshake re-authentication,
+ * nothing to test here */
+ fprintf(stdout, "[client] GNUTLS_SFLAGS_POST_HANDSHAKE_AUTH flag not detected\n");
+ assert(false);
+ }
+ err = gnutls_reauth(tls_session, 0);
+ if (err != GNUTLS_E_INTERRUPTED &&
+ err != GNUTLS_E_AGAIN &&
+ err != GNUTLS_E_GOT_APPLICATION_DATA) {
+ fprintf(stdout, "[client] gnutls_reauth() failed: %s (%i)\n",
+ gnutls_strerror_name(err), err);
+ } else {
+ fprintf(stdout, "[client] post-handshake authentication initiated\n");
+ }
+ client->tls->handshake_state = TLS_HS_REAUTH_EXPECTED;
+ } else {
+ assert (gnutls_safe_renegotiation_status(tls_session) != 0);
+ err = gnutls_rehandshake(tls_session);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[client] gnutls_rehandshake() failed: %s (%i)\n",
+ gnutls_strerror_name(err), err);
+ assert(false);
+ } else {
+ fprintf(stdout, "[client] rehandshake started\n");
+ }
+ client->tls->handshake_state = TLS_HS_EXPECTED;
+ }
+#else
+ assert (gnutls_safe_renegotiation_status(tls_session) != 0);
+ err = gnutls_rehandshake(tls_session);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[client] gnutls_rehandshake() failed: %s (%i)\n",
+ gnutls_strerror_name(err), err);
+ assert(false);
+ } else {
+ fprintf(stdout, "[client] rehandshake started\n");
+ }
+ /* Prevent write-to-client callback from sending next pending chunk.
+ * At the same time tls_process_from_client() must not call gnutls_handshake()
+ * as there can be application data in this direction. */
+ client->tls->handshake_state = TLS_HS_EXPECTED;
+#endif
+ }
+ return submitted;
+}
+
+static int tls_process_from_upstream(struct peer *upstream, const uint8_t *buf, ssize_t len)
+{
+ struct peer *client = upstream->peer;
+
+ fprintf(stdout, "[upstream] pushing %zd bytes to client\n", len);
+
+ ssize_t submitted = 0;
+ if (client->state != STATE_CONNECTED) {
+ return submitted;
+ }
+
+ bool list_was_empty = (client->pending_buf.len == 0);
+ push_to_client_pending(client, (const char *)buf, len);
+ submitted = len;
+ if (client->tls->handshake_state == TLS_HS_DONE) {
+ if (list_was_empty && client->pending_buf.len > 0) {
+ int ret = write_to_client_pending(client);
+ if (ret < 0) {
+ submitted = -1;
+ }
+ }
+ }
+
+ return submitted;
+}
+
+int tls_process_handshake(struct peer *peer)
+{
+ struct tls_ctx *tls = peer->tls;
+ int ret = 1;
+ while (tls->handshake_state == TLS_HS_IN_PROGRESS) {
+ fprintf(stdout, "[tls] TLS handshake in progress...\n");
+ int err = gnutls_handshake(tls->session);
+ if (err == GNUTLS_E_SUCCESS) {
+ tls->handshake_state = TLS_HS_DONE;
+ fprintf(stdout, "[tls] TLS handshake has completed\n");
+ ret = 1;
+ if (peer->pending_buf.len != 0) {
+ write_to_client_pending(peer);
+ }
+ } else if (gnutls_error_is_fatal(err)) {
+ fprintf(stdout, "[tls] gnutls_handshake failed: %s (%d)\n",
+ gnutls_strerror_name(err), err);
+ ret = -1;
+ break;
+ } else {
+ fprintf(stdout, "[tls] gnutls_handshake nonfatal error: %s (%d)\n",
+ gnutls_strerror_name(err), err);
+ ret = 0;
+ break;
+ }
+ }
+ return ret;
+}
+
+#if GNUTLS_VERSION_NUMBER >= 0x030604
+int tls_process_reauth(struct peer *peer)
+{
+ struct tls_ctx *tls = peer->tls;
+ int ret = 1;
+ while (tls->handshake_state == TLS_HS_REAUTH_EXPECTED) {
+ fprintf(stdout, "[tls] TLS re-authentication in progress...\n");
+ int err = gnutls_reauth(tls->session, 0);
+ if (err == GNUTLS_E_SUCCESS) {
+ tls->handshake_state = TLS_HS_DONE;
+ fprintf(stdout, "[tls] TLS re-authentication has completed\n");
+ ret = 1;
+ if (peer->pending_buf.len != 0) {
+ write_to_client_pending(peer);
+ }
+ } else if (err != GNUTLS_E_INTERRUPTED &&
+ err != GNUTLS_E_AGAIN &&
+ err != GNUTLS_E_GOT_APPLICATION_DATA) {
+ /* these are listed as nonfatal errors there
+ * https://www.gnutls.org/manual/gnutls.html#gnutls_005freauth */
+ fprintf(stdout, "[tls] gnutls_reauth failed: %s (%d)\n",
+ gnutls_strerror_name(err), err);
+ ret = -1;
+ break;
+ } else {
+ fprintf(stdout, "[tls] gnutls_reauth nonfatal error: %s (%d)\n",
+ gnutls_strerror_name(err), err);
+ ret = 0;
+ break;
+ }
+ }
+ return ret;
+}
+#endif
+
+int tls_process_from_client(struct peer *client, const uint8_t *buf, ssize_t nread)
+{
+ struct tls_ctx *tls = client->tls;
+
+ tls->buf = buf;
+ tls->nread = nread >= 0 ? nread : 0;
+ tls->consumed = 0;
+
+ fprintf(stdout, "[client] tls_process: reading %zd bytes from client\n", nread);
+
+ int ret = 0;
+ if (tls->handshake_state == TLS_HS_REAUTH_EXPECTED) {
+ ret = tls_process_reauth(client);
+ } else {
+ ret = tls_process_handshake(client);
+ }
+ if (ret <= 0) {
+ return ret;
+ }
+
+ int submitted = 0;
+ while (true) {
+ ssize_t count = gnutls_record_recv(tls->session, tls->recv_buf, sizeof(tls->recv_buf));
+ if (count == GNUTLS_E_AGAIN) {
+ break; /* No data available */
+ } else if (count == GNUTLS_E_INTERRUPTED) {
+ continue; /* Try reading again */
+ } else if (count == GNUTLS_E_REHANDSHAKE) {
+ tls->handshake_state = TLS_HS_IN_PROGRESS;
+ ret = tls_process_handshake(client);
+ if (ret < 0) { /* Critical error */
+ return ret;
+ }
+ if (ret == 0) { /* Non fatal, most likely GNUTLS_E_AGAIN */
+ break;
+ }
+ continue;
+ }
+#if GNUTLS_VERSION_NUMBER >= 0x030604
+ else if (count == GNUTLS_E_REAUTH_REQUEST) {
+ assert(false);
+ tls->handshake_state = TLS_HS_IN_PROGRESS;
+ ret = tls_process_reauth(client);
+ if (ret < 0) { /* Critical error */
+ return ret;
+ }
+ if (ret == 0) { /* Non fatal, most likely GNUTLS_E_AGAIN */
+ break;
+ }
+ continue;
+ }
+#endif
+ else if (count < 0) {
+ fprintf(stdout, "[client] gnutls_record_recv failed: %s (%zd)\n",
+ gnutls_strerror_name(count), count);
+ assert(false);
+ return -1;
+ } else if (count == 0) {
+ break;
+ }
+ struct peer *upstream = client->peer;
+ if (upstream->state == STATE_CONNECTED) {
+ bool upstream_pending_is_empty = (upstream->pending_buf.len == 0);
+ push_to_upstream_pending(upstream, (const char *)tls->recv_buf, count);
+ if (upstream_pending_is_empty) {
+ write_to_upstream_pending(upstream);
+ }
+ } else if (upstream->state == STATE_NOT_CONNECTED) {
+ uv_connect_t *conn = (uv_connect_t *) malloc(sizeof(uv_connect_t));
+ upstream->state = STATE_CONNECT_IN_PROGRESS;
+ fprintf(stdout, "[client] connecting to upstream '%s'\n", ip_straddr(&upstream->addr));
+ uv_tcp_connect(conn, &upstream->handle, (struct sockaddr *)&upstream->addr,
+ on_connect_to_upstream);
+ push_to_upstream_pending(upstream, (const char *)tls->recv_buf, count);
+ } else if (upstream->state == STATE_CONNECT_IN_PROGRESS) {
+ push_to_upstream_pending(upstream, (const char *)tls->recv_buf, count);
+ }
+ submitted += count;
+ }
+ return submitted;
+}
+
+struct tls_proxy_ctx *tls_proxy_allocate()
+{
+ return malloc(sizeof(struct tls_proxy_ctx));
+}
+
+int tls_proxy_init(struct tls_proxy_ctx *proxy, const struct args *a)
+{
+ const char *server_addr = a->local_addr;
+ int server_port = a->local_port;
+ const char *upstream_addr = a->upstream;
+ int upstream_port = a->upstream_port;
+ const char *cert_file = a->cert_file;
+ const char *key_file = a->key_file;
+ proxy->a = a;
+ proxy->loop = uv_default_loop();
+ uv_tcp_init(proxy->loop, &proxy->server.handle);
+ int res = uv_ip4_addr(server_addr, server_port, (struct sockaddr_in *)&proxy->server.addr);
+ if (res != 0) {
+ res = uv_ip6_addr(server_addr, server_port, (struct sockaddr_in6 *)&proxy->server.addr);
+ if (res != 0) {
+ fprintf(stdout, "[proxy] tls_proxy_init: can't parse local address '%s'\n", server_addr);
+ return -1;
+ }
+ }
+ res = uv_ip4_addr(upstream_addr, upstream_port, (struct sockaddr_in *)&proxy->upstream_addr);
+ if (res != 0) {
+ res = uv_ip6_addr(upstream_addr, upstream_port, (struct sockaddr_in6 *)&proxy->upstream_addr);
+ if (res != 0) {
+ fprintf(stdout, "[proxy] tls_proxy_init: can't parse upstream address '%s'\n", upstream_addr);
+ return -1;
+ }
+ }
+ array_init(proxy->client_list);
+ proxy->conn_sequence = 0;
+
+ proxy->loop->data = proxy;
+
+ int err = 0;
+ if (gnutls_references == 0) {
+ err = gnutls_global_init();
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[proxy] gnutls_global_init() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ return -1;
+ }
+ }
+ gnutls_references += 1;
+
+ err = gnutls_certificate_allocate_credentials(&proxy->tls_credentials);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[proxy] gnutls_certificate_allocate_credentials() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ return -1;
+ }
+
+ err = gnutls_certificate_set_x509_system_trust(proxy->tls_credentials);
+ if (err <= 0) {
+ fprintf(stdout, "[proxy] gnutls_certificate_set_x509_system_trust() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ return -1;
+ }
+
+ if (cert_file && key_file) {
+ err = gnutls_certificate_set_x509_key_file(proxy->tls_credentials,
+ cert_file, key_file, GNUTLS_X509_FMT_PEM);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[proxy] gnutls_certificate_set_x509_key_file() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ return -1;
+ }
+ }
+
+ err = gnutls_priority_init(&proxy->tls_priority_cache, NULL, NULL);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[proxy] gnutls_priority_init() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ return -1;
+ }
+
+ return 0;
+}
+
+void tls_proxy_free(struct tls_proxy_ctx *proxy)
+{
+ if (!proxy) {
+ return;
+ }
+ while (proxy->client_list.len > 0) {
+ size_t last_index = proxy->client_list.len - 1;
+ struct peer *client = proxy->client_list.at[last_index];
+ clear_pending_bufs(client);
+ clear_pending_bufs(client->peer);
+ /* TODO correctly close all the uv_tcp_t */
+ free(client->peer);
+ free(client);
+ array_del(proxy->client_list, last_index);
+ }
+ gnutls_certificate_free_credentials(proxy->tls_credentials);
+ gnutls_priority_deinit(proxy->tls_priority_cache);
+ free(proxy);
+
+ gnutls_references -= 1;
+ if (gnutls_references == 0) {
+ gnutls_global_deinit();
+ }
+}
+
+int tls_proxy_start_listen(struct tls_proxy_ctx *proxy)
+{
+ uv_tcp_bind(&proxy->server.handle, (const struct sockaddr*)&proxy->server.addr, 0);
+ int ret = uv_listen((uv_stream_t*)&proxy->server.handle, 128, on_client_connection);
+ return ret;
+}
+
+int tls_proxy_run(struct tls_proxy_ctx *proxy)
+{
+ return uv_run(proxy->loop, UV_RUN_DEFAULT);
+}
diff --git a/tests/pytests/proxy/tls-proxy.h b/tests/pytests/proxy/tls-proxy.h
new file mode 100644
index 0000000..2dbd9fd
--- /dev/null
+++ b/tests/pytests/proxy/tls-proxy.h
@@ -0,0 +1,34 @@
+#pragma once
+/* SPDX-License-Identifier: GPL-3.0-or-later */
+
+#define __STDC_FORMAT_MACROS
+#include <inttypes.h>
+#include <stdint.h>
+#include <stdbool.h>
+#include <netinet/in.h>
+
+struct args {
+ const char *local_addr;
+ uint16_t local_port;
+ const char *upstream;
+ uint16_t upstream_port;
+
+ bool rehandshake;
+ bool close_connection;
+ bool accept_only;
+ bool tls_13;
+
+ uint64_t close_timeout;
+ uint32_t max_conn_sequence;
+
+ const char *cert_file;
+ const char *key_file;
+};
+
+struct tls_proxy_ctx;
+
+struct tls_proxy_ctx *tls_proxy_allocate();
+void tls_proxy_free(struct tls_proxy_ctx *proxy);
+int tls_proxy_init(struct tls_proxy_ctx *proxy, const struct args *a);
+int tls_proxy_start_listen(struct tls_proxy_ctx *proxy);
+int tls_proxy_run(struct tls_proxy_ctx *proxy);
diff --git a/tests/pytests/proxy/tlsproxy.c b/tests/pytests/proxy/tlsproxy.c
new file mode 100644
index 0000000..bcdffb0
--- /dev/null
+++ b/tests/pytests/proxy/tlsproxy.c
@@ -0,0 +1,198 @@
+/* SPDX-License-Identifier: GPL-3.0-or-later */
+
+#include <stdio.h>
+#include <getopt.h>
+#include <stdlib.h>
+#include <signal.h>
+#include <errno.h>
+#include <string.h>
+#include <gnutls/gnutls.h>
+#include "tls-proxy.h"
+
+static char default_local_addr[] = "127.0.0.1";
+static char default_upstream_addr[] = "127.0.0.1";
+static char default_cert_path[] = "../certs/tt.cert.pem";
+static char default_key_path[] = "../certs/tt.key.pem";
+
+void help(char *argv[], struct args *a)
+{
+ printf("Usage: %s [parameters] [rundir]\n", argv[0]);
+ printf("\nParameters:\n"
+ " -l, --local=[addr] Server address to bind to (default: %s).\n"
+ " -p, --lport=[port] Server port to bind to (default: %u).\n"
+ " -u, --upstream=[addr] Upstream address (default: %s).\n"
+ " -d, --uport=[port] Upstream port (default: %u).\n"
+ " -t, --cert=[path] Path to certificate file (default: %s).\n"
+ " -k, --key=[path] Path to key file (default: %s).\n"
+ " -c, --close=[N] Close connection to client after\n"
+ " every N ms (default: %" PRIu64 ").\n"
+ " -f, --fail=[N] Delay every Nth incoming connection by 10 sec,\n"
+ " 0 disables delaying (default: 0).\n"
+ " -r, --rehandshake Do TLS rehandshake after every 8 bytes\n"
+ " sent to the client (default: no).\n"
+ " -a, --acceptonly Accept incoming connections, but don't\n"
+ " connect to upstream (default: no).\n"
+ " -v, --tls13 Force use of TLSv1.3. If not turned on,\n"
+ " TLSv1.2 will be used (default: no).\n"
+ ,
+ a->local_addr, a->local_port,
+ a->upstream, a->upstream_port,
+ a->cert_file, a->key_file,
+ a->close_timeout);
+}
+
+void init_args(struct args *a)
+{
+ a->local_addr = default_local_addr;
+ a->local_port = 54000;
+ a->upstream = default_upstream_addr;
+ a->upstream_port = 53000;
+ a->cert_file = default_cert_path;
+ a->key_file = default_key_path;
+ a->rehandshake = false;
+ a->accept_only = false;
+ a->tls_13 = false;
+ a->close_connection = false;
+ a->close_timeout = 1000;
+ a->max_conn_sequence = 0; /* disabled */
+}
+
+int main(int argc, char **argv)
+{
+ long int li_value = 0;
+ int c = 0, li = 0;
+ struct option opts[] = {
+ {"local", required_argument, 0, 'l'},
+ {"lport", required_argument, 0, 'p'},
+ {"upstream", required_argument, 0, 'u'},
+ {"uport", required_argument, 0, 'd'},
+ {"cert", required_argument, 0, 't'},
+ {"key", required_argument, 0, 'k'},
+ {"close", required_argument, 0, 'c'},
+ {"fail", required_argument, 0, 'f'},
+ {"rehandshake", no_argument, 0, 'r'},
+ {"acceptonly", no_argument, 0, 'a'},
+#if GNUTLS_VERSION_NUMBER >= 0x030604
+ {"tls13", no_argument, 0, 'v'},
+#endif
+ {0, 0, 0, 0}
+ };
+ struct args args;
+ init_args(&args);
+ while ((c = getopt_long(argc, argv, "l:p:u:d:t:k:c:f:rav", opts, &li)) != -1) {
+ switch (c)
+ {
+ case 'l':
+ args.local_addr = optarg;
+ break;
+ case 'u':
+ args.upstream = optarg;
+ break;
+ case 't':
+ args.cert_file = optarg;
+ break;
+ case 'k':
+ args.key_file = optarg;
+ break;
+ case 'p':
+ li_value = strtol(optarg, NULL, 10);
+ if (li_value <= 0 || li_value > UINT16_MAX) {
+ printf("error: '-p' requires a positive"
+ " number less or equal to 65535, not '%s'\n", optarg);
+ return -1;
+ }
+ args.local_port = (uint16_t)li_value;
+ break;
+ case 'd':
+ li_value = strtol(optarg, NULL, 10);
+ if (li_value <= 0 || li_value > UINT16_MAX) {
+ printf("error: '-d' requires a positive"
+ " number less or equal to 65535, not '%s'\n", optarg);
+ return -1;
+ }
+ args.upstream_port = (uint16_t)li_value;
+ break;
+ case 'c':
+ li_value = strtol(optarg, NULL, 10);
+ if (li_value <= 0) {
+ printf("[system] error '-c' requires a positive"
+ " number, not '%s'\n", optarg);
+ return -1;
+ }
+ args.close_connection = true;
+ args.close_timeout = li_value;
+ break;
+ case 'f':
+ li_value = strtol(optarg, NULL, 10);
+ if (li_value <= 0 || li_value > UINT32_MAX) {
+ printf("error: '-f' requires a positive"
+ " number less or equal to %i, not '%s'\n",
+ UINT32_MAX, optarg);
+ return -1;
+ }
+ args.max_conn_sequence = (uint32_t)li_value;
+ break;
+ case 'r':
+ args.rehandshake = true;
+ break;
+ case 'a':
+ args.accept_only = true;
+ break;
+ case 'v':
+#if GNUTLS_VERSION_NUMBER >= 0x030604
+ args.tls_13 = true;
+#endif
+ break;
+ default:
+ init_args(&args);
+ help(argv, &args);
+ return -1;
+ }
+ }
+ if (signal(SIGPIPE, SIG_IGN) == SIG_ERR) {
+ fprintf(stderr, "failed to set up SIGPIPE handler to ignore(%s)\n",
+ strerror(errno));
+ }
+ struct tls_proxy_ctx *proxy = tls_proxy_allocate();
+ if (!proxy) {
+ fprintf(stderr, "can't allocate tls_proxy structure\n");
+ return 1;
+ }
+ int res = tls_proxy_init(proxy, &args);
+ if (res) {
+ fprintf(stderr, "can't initialize tls_proxy structure\n");
+ return res;
+ }
+ res = tls_proxy_start_listen(proxy);
+ if (res) {
+ fprintf(stderr, "error starting listen, error code: %i\n", res);
+ return res;
+ }
+ fprintf(stdout, "Listen on %s#%u\n"
+ "Upstream is expected on %s#%u\n"
+ "Certificate file %s\n"
+ "Key file %s\n"
+ "Rehandshake %s\n"
+ "Close %s\n"
+ "Refuse incoming connections every %ith%s\n"
+ "Only accept, don't forward %s\n"
+ "Force TLSv1.3 %s\n"
+ ,
+ args.local_addr, args.local_port,
+ args.upstream, args.upstream_port,
+ args.cert_file, args.key_file,
+ args.rehandshake ? "yes" : "no",
+ args.close_connection ? "yes" : "no",
+ args.max_conn_sequence, args.max_conn_sequence ? "" : " (disabled)",
+ args.accept_only ? "yes" : "no",
+#if GNUTLS_VERSION_NUMBER >= 0x030604
+ args.tls_13 ? "yes" : "no"
+#else
+ "Not supported"
+#endif
+ );
+ res = tls_proxy_run(proxy);
+ tls_proxy_free(proxy);
+ return res;
+}
+
diff --git a/tests/pytests/pylintrc b/tests/pytests/pylintrc
new file mode 100644
index 0000000..2c406be
--- /dev/null
+++ b/tests/pytests/pylintrc
@@ -0,0 +1,33 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+[MESSAGES CONTROL]
+
+disable=
+ missing-docstring,
+ too-few-public-methods,
+ too-many-arguments,
+ too-many-instance-attributes,
+ fixme,
+ unused-import, # checked by flake8
+ line-too-long, # checked by flake8
+ invalid-name,
+ broad-except,
+ bad-continuation,
+ global-statement,
+ no-else-return,
+ redefined-outer-name, # commonly used with pytest fixtures
+ consider-using-with,
+ consider-using-f-string,
+
+
+[SIMILARITIES]
+min-similarity-lines=6
+ignore-comments=yes
+ignore-docstrings=yes
+ignore-imports=no
+
+[DESIGN]
+max-parents=10
+max-locals=20
+
+[TYPECHECK]
+ignored-modules=ssl
diff --git a/tests/pytests/requirements.txt b/tests/pytests/requirements.txt
new file mode 100644
index 0000000..6e2e4d2
--- /dev/null
+++ b/tests/pytests/requirements.txt
@@ -0,0 +1,5 @@
+dnspython
+jinja2
+pytest
+pytest-html
+pytest-xdist
diff --git a/tests/pytests/templates/kresd.conf.j2 b/tests/pytests/templates/kresd.conf.j2
new file mode 100644
index 0000000..b87515c
--- /dev/null
+++ b/tests/pytests/templates/kresd.conf.j2
@@ -0,0 +1,62 @@
+-- SPDX-License-Identifier: GPL-3.0-or-later
+modules = {
+ 'hints > policy',
+ 'policy > iterate',
+}
+
+{% if kresd.verbose %}
+log_level('debug')
+{% endif %}
+
+{% if kresd.ip %}
+net.listen('{{ kresd.ip }}', {{ kresd.port }})
+net.listen('{{ kresd.ip }}', {{ kresd.tls_port }}, {tls = true})
+{% endif %}
+
+{% if kresd.ip6 %}
+net.listen('{{ kresd.ip6 }}', {{ kresd.port }})
+net.listen('{{ kresd.ip6 }}', {{ kresd.tls_port }}, {tls = true})
+{% endif %}
+
+net.ipv4=true
+net.ipv6=true
+
+{% if kresd.tls_key_path and kresd.tls_cert_path %}
+net.tls("{{ kresd.tls_cert_path }}", "{{ kresd.tls_key_path }}")
+net.tls_sticket_secret('0123456789ABCDEF0123456789ABCDEF')
+{% endif %}
+
+hints['localhost.'] = '127.0.0.1'
+{% for name, ip in kresd.hints.items() %}
+hints['{{ name }}'] = '{{ ip }}'
+{% endfor %}
+
+policy.add(policy.all(policy.QTRACE))
+
+{% if kresd.forward %}
+policy.add(policy.all(
+ {% if kresd.forward.proto == 'tls' %}
+ policy.TLS_FORWARD({
+ {"{{ kresd.forward.ip }}@{{ kresd.forward.port }}", hostname='{{ kresd.forward.hostname}}', ca_file='{{ kresd.forward.ca_file }}'}})
+ {% endif %}
+))
+{% endif %}
+
+{% if kresd.policy_test_pass %}
+policy.add(policy.suffix(policy.PASS, {todname('test.')}))
+{% endif %}
+
+-- EDNS EDE tests
+policy.add(policy.suffix(policy.DENY, {todname('deny.test.')}))
+policy.add(policy.suffix(policy.REFUSE, {todname('refuse.test.')}))
+policy.add(policy.suffix(policy.ANSWER({ [kres.type.A] = { rdata=kres.str2ip('192.0.2.7'), ttl=300 } }), {todname('forge.test.')}))
+
+-- make sure DNSSEC is turned off for tests
+trust_anchors.remove('.')
+modules.unload("ta_update")
+modules.unload("ta_signal_query")
+modules.unload("priming")
+modules.unload("detect_time_skew")
+
+-- choose a small cache, since it is preallocated
+cache.size = 1 * MB
diff --git a/tests/pytests/test_conn_mgmt.py b/tests/pytests/test_conn_mgmt.py
new file mode 100644
index 0000000..1d15091
--- /dev/null
+++ b/tests/pytests/test_conn_mgmt.py
@@ -0,0 +1,214 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""TCP Connection Management tests"""
+
+import socket
+import struct
+import time
+
+import pytest
+
+import utils
+
+
+@pytest.mark.parametrize('garbage_lengths', [
+ (1,),
+ (1024,),
+ (65533,), # max size garbage
+ (65533, 65533),
+ (1024, 1024, 1024),
+ # (0,), # currently kresd uses this as a heuristic of "lost in bytestream"
+ # (0, 1024), # and closes the connection
+])
+def test_ignore_garbage(kresd_sock, garbage_lengths, single_buffer, query_before):
+ """Send chunk of garbage, prefixed by garbage length. It should be ignored."""
+ buff = b''
+ if query_before: # optionally send initial query
+ msg_buff_before, msgid_before = utils.get_msgbuff()
+ if single_buffer:
+ buff += msg_buff_before
+ else:
+ kresd_sock.sendall(msg_buff_before)
+
+ for glength in garbage_lengths: # prepare garbage data
+ if glength is None:
+ continue
+ garbage_buff = utils.get_prefixed_garbage(glength)
+ if single_buffer:
+ buff += garbage_buff
+ else:
+ kresd_sock.sendall(garbage_buff)
+
+ msg_buff, msgid = utils.get_msgbuff() # final query
+ buff += msg_buff
+ kresd_sock.sendall(buff)
+
+ if query_before:
+ answer_before = utils.receive_parse_answer(kresd_sock)
+ assert answer_before.id == msgid_before
+ answer = utils.receive_parse_answer(kresd_sock)
+ assert answer.id == msgid
+
+
+def test_pipelining(kresd_sock):
+ """
+ First query takes longer to resolve - answer to second query should arrive sooner.
+
+ This test requires internet connection.
+ """
+ # initialization (to avoid issues with net.ipv6=true)
+ buff_pre, msgid_pre = utils.get_msgbuff('0.delay.getdnsapi.net.')
+ kresd_sock.sendall(buff_pre)
+ msg_answer = utils.receive_parse_answer(kresd_sock)
+ assert msg_answer.id == msgid_pre
+
+ # test
+ buff1, msgid1 = utils.get_msgbuff('1500.delay.getdnsapi.net.', msgid=1)
+ buff2, msgid2 = utils.get_msgbuff('1.delay.getdnsapi.net.', msgid=2)
+ buff = buff1 + buff2
+ kresd_sock.sendall(buff)
+
+ msg_answer = utils.receive_parse_answer(kresd_sock)
+ assert msg_answer.id == msgid2
+
+ msg_answer = utils.receive_parse_answer(kresd_sock)
+ assert msg_answer.id == msgid1
+
+
+@pytest.mark.parametrize('duration, delay', [
+ (utils.MAX_TIMEOUT, 0.1),
+ (utils.MAX_TIMEOUT, 3),
+ (utils.MAX_TIMEOUT, 7),
+ (utils.MAX_TIMEOUT + 10, 3),
+])
+def test_long_lived(kresd_sock, duration, delay):
+ """Establish and keep connection alive for longer than maximum timeout."""
+ utils.ping_alive(kresd_sock)
+ end_time = time.time() + duration
+
+ while time.time() < end_time:
+ time.sleep(delay)
+ utils.ping_alive(kresd_sock)
+
+
+def test_close(kresd_sock, query_before):
+ """Establish a connection and wait for timeout from kresd."""
+ if query_before:
+ utils.ping_alive(kresd_sock)
+ time.sleep(utils.MAX_TIMEOUT)
+
+ with utils.expect_kresd_close():
+ utils.ping_alive(kresd_sock)
+
+
+def test_slow_lorris(kresd_sock, query_before):
+ """Simulate slow-lorris attack by sending byte after byte with delays in between."""
+ if query_before:
+ utils.ping_alive(kresd_sock)
+
+ buff, _ = utils.get_msgbuff()
+ end_time = time.time() + utils.MAX_TIMEOUT
+
+ with utils.expect_kresd_close():
+ for i in range(len(buff)):
+ b = buff[i:i+1]
+ kresd_sock.send(b)
+ if time.time() > end_time:
+ break
+ time.sleep(1)
+
+
+@pytest.mark.parametrize('sock_func_name', [
+ 'ip_tcp_socket',
+ 'ip6_tcp_socket',
+])
+def test_oob(kresd, sock_func_name):
+ """TCP out-of-band (urgent) data must not crash resolver."""
+ make_sock = getattr(kresd, sock_func_name)
+ sock = make_sock()
+ msg_buff, msgid = utils.get_msgbuff()
+ sock.sendall(msg_buff, socket.MSG_OOB)
+
+ try:
+ msg_answer = utils.receive_parse_answer(sock)
+ assert msg_answer.id == msgid
+ except ConnectionError:
+ pass # TODO kresd responds with TCP RST, this should be fixed
+
+ # check kresd is alive
+ sock2 = make_sock()
+ utils.ping_alive(sock2)
+
+
+def flood_buffer(msgcount):
+ flood_buff = bytes()
+ msgbuff, _ = utils.get_msgbuff()
+ noid_msgbuff = msgbuff[2:]
+
+ def gen_msg(msgid):
+ return struct.pack("!H", len(msgbuff)) + struct.pack("!H", msgid) + noid_msgbuff
+
+ for i in range(msgcount):
+ flood_buff += gen_msg(i)
+ return flood_buff
+
+
+def test_query_flood_close(make_kresd_sock):
+ """Flood resolver with queries and close the connection."""
+ buff = flood_buffer(10000)
+ sock1 = make_kresd_sock()
+ sock1.sendall(buff)
+ sock1.close()
+
+ sock2 = make_kresd_sock()
+ utils.ping_alive(sock2)
+
+
+def test_query_flood_no_recv(make_kresd_sock):
+ """Flood resolver with queries but don't read any data."""
+ # A use-case for TCP_USER_TIMEOUT socket option? See RFC 793 and RFC 5482
+
+ # It seems it doesn't works as expected. libuv doesn't return any error
+ # (neither on uv_write() call, not in the callback) when kresd sends answers,
+ # so kresd can't recognize that client didn't read any answers. At a certain
+ # point, kresd stops receiving queries from the client (whilst client keep
+ # sending) and closes connection due to timeout.
+
+ buff = flood_buffer(10000)
+ sock1 = make_kresd_sock()
+ end_time = time.time() + utils.MAX_TIMEOUT
+
+ with utils.expect_kresd_close(rst_ok=True): # connection must be closed
+ while time.time() < end_time:
+ sock1.sendall(buff)
+ time.sleep(0.5)
+
+ sock2 = make_kresd_sock()
+ utils.ping_alive(sock2) # resolver must stay alive
+
+
+@pytest.mark.parametrize('glength, gcount, delay', [
+ (65533, 100, 0.5),
+ (0, 100000, 0.5),
+ (1024, 1000, 0.5),
+ (65533, 1, 0),
+ (0, 1, 0),
+ (1024, 1, 0),
+])
+def test_query_flood_garbage(make_kresd_silent_sock, glength, gcount, delay, query_before):
+ """Flood resolver with prefixed garbage."""
+ sock1 = make_kresd_silent_sock()
+ if query_before:
+ utils.ping_alive(sock1)
+
+ gbuff = utils.get_prefixed_garbage(glength)
+ buff = gbuff * gcount
+
+ end_time = time.time() + utils.MAX_TIMEOUT
+
+ with utils.expect_kresd_close(rst_ok=True): # connection must be closed
+ while time.time() < end_time:
+ sock1.sendall(buff)
+ time.sleep(delay)
+
+ sock2 = make_kresd_silent_sock()
+ utils.ping_alive(sock2) # resolver must stay alive
diff --git a/tests/pytests/test_edns.py b/tests/pytests/test_edns.py
new file mode 100644
index 0000000..b19af37
--- /dev/null
+++ b/tests/pytests/test_edns.py
@@ -0,0 +1,22 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""EDNS tests"""
+
+import dns
+import pytest
+
+import utils
+
+
+@pytest.mark.parametrize('dname, code, text', [
+ ('deny.test.', dns.edns.EDECode.BLOCKED, 'CR36'),
+ ('refuse.test.', dns.edns.EDECode.PROHIBITED, 'EIM4'),
+ ('forge.test.', dns.edns.EDECode.FORGED_ANSWER, '5DO5'),
+])
+def test_edns_ede(kresd_sock, dname, code, text):
+ """Check that kresd responds with EDNS EDE codes in selected cases."""
+ buff, msgid = utils.get_msgbuff(dname)
+ kresd_sock.sendall(buff)
+ answer = utils.receive_parse_answer(kresd_sock)
+ assert answer.id == msgid
+ assert answer.options[0].code == code
+ assert answer.options[0].text == text
diff --git a/tests/pytests/test_prefix.py b/tests/pytests/test_prefix.py
new file mode 100644
index 0000000..2b576b6
--- /dev/null
+++ b/tests/pytests/test_prefix.py
@@ -0,0 +1,114 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""TCP Connection Management tests - prefix length
+
+RFC1035
+4.2.2. TCP usage
+The message is prefixed with a two byte length field which gives the message
+length, excluding the two byte length field.
+
+The following test suite focuses on edge cases for the prefix - when it
+is either too short or too long, instead of matching the length of DNS
+message exactly.
+"""
+
+import time
+
+import pytest
+
+import utils
+
+
+@pytest.fixture(params=[
+ 'no_query_before',
+ 'query_before',
+ 'query_before_in_single_buffer',
+])
+def send_query(request):
+ """Function sends a buffer, either by itself, or with a valid query before.
+ If a valid query is sent before, it can be sent either in a separate buffer, or
+ along with the provided buffer."""
+
+ # pylint: disable=possibly-unused-variable
+
+ def no_query_before(sock, buff): # pylint: disable=unused-argument
+ sock.sendall(buff)
+
+ def query_before(sock, buff, single_buffer=False):
+ """Send an initial query and expect a response."""
+ msg_buff, msgid = utils.get_msgbuff()
+
+ if single_buffer:
+ sock.sendall(msg_buff + buff)
+ else:
+ sock.sendall(msg_buff)
+ sock.sendall(buff)
+
+ answer = utils.receive_parse_answer(sock)
+ assert answer.id == msgid
+
+ def query_before_in_single_buffer(sock, buff):
+ return query_before(sock, buff, single_buffer=True)
+
+ return locals()[request.param]
+
+
+@pytest.mark.parametrize('datalen', [
+ 1, # just one byte of DNS header
+ 11, # DNS header size minus 1
+ 14, # DNS Header size plus 2
+])
+def test_prefix_cuts_message(kresd_sock, datalen, send_query):
+ """Prefix is shorter than the DNS message."""
+ wire, _ = utils.prepare_wire()
+ assert datalen < len(wire)
+ invalid_buff = utils.prepare_buffer(wire, datalen)
+
+ send_query(kresd_sock, invalid_buff) # buffer breaks parsing of TCP stream
+
+ with utils.expect_kresd_close():
+ utils.ping_alive(kresd_sock)
+
+
+def test_prefix_greater_than_message(kresd_sock, send_query):
+ """Prefix is greater than the length of the entire DNS message."""
+ wire, invalid_msgid = utils.prepare_wire()
+ datalen = len(wire) + 16
+ invalid_buff = utils.prepare_buffer(wire, datalen)
+
+ send_query(kresd_sock, invalid_buff)
+
+ valid_buff, _ = utils.get_msgbuff()
+ kresd_sock.sendall(valid_buff)
+
+ # invalid_buff is answered (treats additional data as trailing garbage)
+ answer = utils.receive_parse_answer(kresd_sock)
+ assert answer.id == invalid_msgid
+
+ # parsing stream is broken by the invalid_buff, valid query is never answered
+ with utils.expect_kresd_close():
+ utils.receive_parse_answer(kresd_sock)
+
+
+@pytest.mark.parametrize('glength', [
+ 0,
+ 1,
+ 8,
+ 1024,
+ 4096,
+ 20000,
+])
+def test_prefix_trailing_garbage(kresd_sock, glength, query_before):
+ """Send messages with trailing garbage (its length included in prefix)."""
+ if query_before:
+ utils.ping_alive(kresd_sock)
+
+ for _ in range(10):
+ wire, msgid = utils.prepare_wire()
+ wire += utils.get_garbage(glength)
+ buff = utils.prepare_buffer(wire)
+
+ kresd_sock.sendall(buff)
+ answer = utils.receive_parse_answer(kresd_sock)
+ assert answer.id == msgid
+
+ time.sleep(0.1)
diff --git a/tests/pytests/test_random_close.py b/tests/pytests/test_random_close.py
new file mode 100644
index 0000000..a7cc877
--- /dev/null
+++ b/tests/pytests/test_random_close.py
@@ -0,0 +1,54 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""TLS test when forward target closes connection after one second
+
+Test utilizes TLS proxy, which forwards queries to configured
+resolver, but closes the connection 1s after establishing.
+
+Kresd must stay alive and be able to answer queries.
+"""
+
+import random
+import string
+import time
+
+from proxy import HINTS, kresd_tls_client, resolve_hint, TLSProxy
+import utils
+
+
+QPS = 500
+
+
+def random_string(size=32, chars=(string.ascii_lowercase + string.digits)):
+ return ''.join(random.choice(chars) for x in range(size))
+
+
+def rsa_cannon(sock, duration, domain='test.', qps=QPS):
+ end_time = time.time() + duration
+
+ while time.time() < end_time:
+ next_time = time.time() + 1/qps
+ buff, _ = utils.get_msgbuff('{}.{}'.format(random_string(), domain))
+ sock.sendall(buff)
+ time_left = next_time - time.time()
+ if time_left > 0:
+ time.sleep(time_left)
+
+
+def test_proxy_random_close(tmpdir):
+ proxy = TLSProxy(close=1000)
+
+ kresd_tls_client_kwargs = {
+ 'verbose': False,
+ 'policy_test_pass': True
+ }
+ kresd_fwd_target_kwargs = {
+ 'verbose': False
+ }
+ with kresd_tls_client(str(tmpdir), proxy, kresd_tls_client_kwargs, kresd_fwd_target_kwargs) \
+ as kresd:
+ sock2 = kresd.ip_tcp_socket()
+ rsa_cannon(sock2, 20)
+ sock3 = kresd.ip_tcp_socket()
+ for hint in HINTS:
+ resolve_hint(sock3, hint)
+ time.sleep(0.1)
diff --git a/tests/pytests/test_rehandshake.py b/tests/pytests/test_rehandshake.py
new file mode 100644
index 0000000..f07ba58
--- /dev/null
+++ b/tests/pytests/test_rehandshake.py
@@ -0,0 +1,52 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""TLS rehandshake test
+
+Test is using TLS proxy with rehandshake. When queries are sent, they are
+simply forwarded. When the responses are sent back, a rehandshake is performed
+after every 8 bytes.
+
+It is expected the answer will be received by the source kresd instance
+and sent back to the client (this test).
+"""
+
+import re
+import time
+
+import pytest
+
+from proxy import HINTS, kresd_tls_client, resolve_hint, TLSProxy
+
+
+def verify_rehandshake(tmpdir, proxy):
+ with kresd_tls_client(str(tmpdir), proxy) as kresd:
+ sock2 = kresd.ip_tcp_socket()
+ try:
+ for hint in HINTS:
+ resolve_hint(sock2, hint)
+ time.sleep(0.1)
+ finally:
+ # verify log
+ n_connecting_to = 0
+ n_rehandshake = 0
+ partial_log = kresd.partial_log()
+ print(partial_log)
+ for line in partial_log.splitlines():
+ if re.search(r"connecting to: .*", line) is not None:
+ n_connecting_to += 1
+ elif re.search(r"TLS rehandshake .* has started", line) is not None:
+ n_rehandshake += 1
+ assert n_connecting_to == 1 # should connect exactly once
+ assert n_rehandshake > 0
+
+
+def test_proxy_rehandshake_tls12(tmpdir):
+ proxy = TLSProxy(rehandshake=True)
+ verify_rehandshake(tmpdir, proxy)
+
+
+# TODO fix TLS v1.3 proxy / kresd rehandshake
+@pytest.mark.xfail(
+ reason="TLS 1.3 rehandshake isn't properly supported either in tlsproxy or in kresd")
+def test_proxy_rehandshake_tls13(tmpdir):
+ proxy = TLSProxy(rehandshake=True, force_tls13=True)
+ verify_rehandshake(tmpdir, proxy)
diff --git a/tests/pytests/test_tls.py b/tests/pytests/test_tls.py
new file mode 100644
index 0000000..3e1328a
--- /dev/null
+++ b/tests/pytests/test_tls.py
@@ -0,0 +1,83 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""TLS-specific tests"""
+
+import itertools
+import os
+from socket import AF_INET, AF_INET6
+import ssl
+import sys
+
+import pytest
+
+from kresd import make_kresd
+import utils
+
+
+def test_tls_no_cert(kresd, sock_family):
+ """Use TLS without certificates."""
+ sock, dest = kresd.stream_socket(sock_family, tls=True)
+ ctx = utils.make_ssl_context(insecure=True)
+ ssock = ctx.wrap_socket(sock)
+ ssock.connect(dest)
+
+ utils.ping_alive(ssock)
+
+
+def test_tls_selfsigned_cert(kresd_tt, sock_family):
+ """Use TLS with a self signed certificate."""
+ sock, dest = kresd_tt.stream_socket(sock_family, tls=True)
+ ctx = utils.make_ssl_context(verify_location=kresd_tt.tls_cert_path)
+ ssock = ctx.wrap_socket(sock, server_hostname='transport-test-server.com')
+ ssock.connect(dest)
+
+ utils.ping_alive(ssock)
+
+
+def test_tls_cert_hostname_mismatch(kresd_tt, sock_family):
+ """Attempt to use self signed certificate and incorrect hostname."""
+ sock, dest = kresd_tt.stream_socket(sock_family, tls=True)
+ ctx = utils.make_ssl_context(verify_location=kresd_tt.tls_cert_path)
+ ssock = ctx.wrap_socket(sock, server_hostname='wrong-host-name')
+
+ with pytest.raises(ssl.CertificateError):
+ ssock.connect(dest)
+
+
+@pytest.mark.skipif(sys.version_info < (3, 6),
+ reason="requires python3.6 or higher")
+@pytest.mark.parametrize('sf1, sf2, sf3', itertools.product(
+ [AF_INET, AF_INET6], [AF_INET, AF_INET6], [AF_INET, AF_INET6]))
+def test_tls_session_resumption(tmpdir, sf1, sf2, sf3):
+ """Attempt TLS session resumption against the same kresd instance and a different one."""
+ # TODO ensure that session can't be resumed after session ticket key regeneration
+ # at the first kresd instance
+
+ # NOTE TLS 1.3 is intentionally disabled for session resumption tests,
+ # because python's SSLSocket.session isn't compatible with TLS 1.3
+ # https://docs.python.org/3/library/ssl.html?highlight=ssl%20ticket#tls-1-3
+
+ def connect(kresd, ctx, sf, session=None):
+ sock, dest = kresd.stream_socket(sf, tls=True)
+ ssock = ctx.wrap_socket(
+ sock, server_hostname='transport-test-server.com', session=session)
+ ssock.connect(dest)
+ new_session = ssock.session
+ assert new_session.has_ticket
+ assert ssock.session_reused == (session is not None)
+ utils.ping_alive(ssock)
+ ssock.close()
+ return new_session
+
+ workdir = os.path.join(str(tmpdir), 'kresd')
+ os.makedirs(workdir)
+
+ with make_kresd(workdir, 'tt') as kresd:
+ ctx = utils.make_ssl_context(
+ verify_location=kresd.tls_cert_path, extra_options=[ssl.OP_NO_TLSv1_3])
+ session = connect(kresd, ctx, sf1) # initial conn
+ connect(kresd, ctx, sf2, session) # resume session on the same instance
+
+ workdir2 = os.path.join(str(tmpdir), 'kresd2')
+ os.makedirs(workdir2)
+ with make_kresd(workdir2, 'tt') as kresd2:
+ connect(kresd2, ctx, sf3, session) # resume session on a different instance
diff --git a/tests/pytests/utils.py b/tests/pytests/utils.py
new file mode 100644
index 0000000..4b995d4
--- /dev/null
+++ b/tests/pytests/utils.py
@@ -0,0 +1,136 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+from contextlib import contextmanager
+import random
+import ssl
+import struct
+import time
+
+import dns
+import dns.message
+import pytest
+
+
+# default net.tcp_in_idle is 10s, TCP_DEFER_ACCEPT 3s, some extra for
+# Python handling / edge cases
+MAX_TIMEOUT = 16
+
+
+def receive_answer(sock):
+ answer_total_len = 0
+ data = sock.recv(2)
+ if not data:
+ return None
+ answer_total_len = struct.unpack_from("!H", data)[0]
+
+ answer_received_len = 0
+ data_answer = b''
+ while answer_received_len < answer_total_len:
+ data_chunk = sock.recv(answer_total_len - answer_received_len)
+ if not data_chunk:
+ return None
+ data_answer += data_chunk
+ answer_received_len += len(data_answer)
+
+ return data_answer
+
+
+def receive_parse_answer(sock):
+ data_answer = receive_answer(sock)
+
+ if data_answer is None:
+ raise BrokenPipeError("kresd closed connection")
+
+ msg_answer = dns.message.from_wire(data_answer, one_rr_per_rrset=True)
+ return msg_answer
+
+
+def prepare_wire(
+ qname='localhost.',
+ qtype=dns.rdatatype.A,
+ qclass=dns.rdataclass.IN,
+ msgid=None):
+ """Utility function to generate DNS wire format message"""
+ msg = dns.message.make_query(qname, qtype, qclass, use_edns=True)
+ if msgid is not None:
+ msg.id = msgid
+ return msg.to_wire(), msg.id
+
+
+def prepare_buffer(wire, datalen=None):
+ """Utility function to prepare TCP buffer from DNS message in wire format"""
+ assert isinstance(wire, bytes)
+ if datalen is None:
+ datalen = len(wire)
+ return struct.pack("!H", datalen) + wire
+
+
+def get_msgbuff(qname='localhost.', qtype=dns.rdatatype.A, msgid=None):
+ wire, msgid = prepare_wire(qname, qtype, msgid=msgid)
+ buff = prepare_buffer(wire)
+ return buff, msgid
+
+
+def get_garbage(length):
+ return bytes(random.getrandbits(8) for _ in range(length))
+
+
+def get_prefixed_garbage(length):
+ data = get_garbage(length)
+ return prepare_buffer(data)
+
+
+def try_ping_alive(sock, msgid=None, close=False):
+ try:
+ ping_alive(sock, msgid)
+ except AssertionError:
+ return False
+ finally:
+ if close:
+ sock.close()
+ return True
+
+
+def ping_alive(sock, msgid=None):
+ buff, msgid = get_msgbuff(msgid=msgid)
+ sock.sendall(buff)
+ answer = receive_parse_answer(sock)
+ assert answer.id == msgid
+
+
+@contextmanager
+def expect_kresd_close(rst_ok=False):
+ with pytest.raises(BrokenPipeError):
+ try:
+ time.sleep(0.2) # give kresd time to close connection with TCP FIN
+ yield
+ except ConnectionResetError as ex:
+ if rst_ok:
+ raise BrokenPipeError from ex
+ pytest.skip("kresd closed connection with TCP RST")
+ pytest.fail("kresd didn't close the connection")
+
+
+def make_ssl_context(insecure=False, verify_location=None, extra_options=None):
+ # set TLS v1.2+
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS)
+ context.options |= ssl.OP_NO_SSLv2
+ context.options |= ssl.OP_NO_SSLv3
+ context.options |= ssl.OP_NO_TLSv1
+ context.options |= ssl.OP_NO_TLSv1_1
+
+ if extra_options is not None:
+ for option in extra_options:
+ context.options |= option
+
+ if insecure:
+ # turn off certificate verification
+ context.check_hostname = False
+ context.verify_mode = ssl.CERT_NONE
+ else:
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.check_hostname = True
+
+ if verify_location is not None:
+ context.load_verify_locations(verify_location)
+
+ return context