From 19fcec84d8d7d21e796c7624e521b60d28ee21ed Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 7 Apr 2024 20:45:59 +0200 Subject: Adding upstream version 16.2.11+ds. Signed-off-by: Daniel Baumann --- .../thrift/lib/py/test/_import_local_thrift.py | 30 ++ .../thrift/lib/py/test/test_sslsocket.py | 353 +++++++++++++++++++++ .../thrift/lib/py/test/thrift_json.py | 51 +++ 3 files changed, 434 insertions(+) create mode 100644 src/jaegertracing/thrift/lib/py/test/_import_local_thrift.py create mode 100644 src/jaegertracing/thrift/lib/py/test/test_sslsocket.py create mode 100644 src/jaegertracing/thrift/lib/py/test/thrift_json.py (limited to 'src/jaegertracing/thrift/lib/py/test') diff --git a/src/jaegertracing/thrift/lib/py/test/_import_local_thrift.py b/src/jaegertracing/thrift/lib/py/test/_import_local_thrift.py new file mode 100644 index 000000000..d22312298 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/test/_import_local_thrift.py @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import glob +import os +import sys + +SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR))) + +for libpath in glob.glob(os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*')): + if libpath.endswith('-%d.%d' % (sys.version_info[0], sys.version_info[1])): + sys.path.insert(0, libpath) + break diff --git a/src/jaegertracing/thrift/lib/py/test/test_sslsocket.py b/src/jaegertracing/thrift/lib/py/test/test_sslsocket.py new file mode 100644 index 000000000..f4c87f195 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/test/test_sslsocket.py @@ -0,0 +1,353 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import inspect +import logging +import os +import platform +import ssl +import sys +import tempfile +import threading +import unittest +import warnings +from contextlib import contextmanager + +import _import_local_thrift # noqa + +SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR))) +SERVER_PEM = os.path.join(ROOT_DIR, 'test', 'keys', 'server.pem') +SERVER_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'server.crt') +SERVER_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'server.key') +CLIENT_CERT_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.crt') +CLIENT_KEY_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.key') +CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.crt') +CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.key') +CLIENT_CA = os.path.join(ROOT_DIR, 'test', 'keys', 'CA.pem') + +TEST_CIPHERS = 'DES-CBC3-SHA:ECDHE-RSA-AES128-GCM-SHA256' + + +class ServerAcceptor(threading.Thread): + def __init__(self, server, expect_failure=False): + super(ServerAcceptor, self).__init__() + self.daemon = True + self._server = server + self._listening = threading.Event() + self._port = None + self._port_bound = threading.Event() + self._client = None + self._client_accepted = threading.Event() + self._expect_failure = expect_failure + frame = inspect.stack(3)[2] + self.name = frame[3] + del frame + + def run(self): + self._server.listen() + self._listening.set() + + try: + address = self._server.handle.getsockname() + if len(address) > 1: + # AF_INET addresses are 2-tuples (host, port) and AF_INET6 are + # 4-tuples (host, port, ...), but in each case port is in the second slot. + self._port = address[1] + finally: + self._port_bound.set() + + try: + self._client = self._server.accept() + if self._client: + self._client.read(5) # hello + self._client.write(b"there") + except Exception: + logging.exception('error on server side (%s):' % self.name) + if not self._expect_failure: + raise + finally: + self._client_accepted.set() + + def await_listening(self): + self._listening.wait() + + @property + def port(self): + self._port_bound.wait() + return self._port + + @property + def client(self): + self._client_accepted.wait() + return self._client + + def close(self): + if self._client: + self._client.close() + self._server.close() + + +# Python 2.6 compat +class AssertRaises(object): + def __init__(self, expected): + self._expected = expected + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_value, traceback): + if not exc_type or not issubclass(exc_type, self._expected): + raise Exception('fail') + return True + + +class TSSLSocketTest(unittest.TestCase): + def _server_socket(self, **kwargs): + return TSSLServerSocket(port=0, **kwargs) + + @contextmanager + def _connectable_client(self, server, expect_failure=False, path=None, **client_kwargs): + acc = ServerAcceptor(server, expect_failure) + try: + acc.start() + acc.await_listening() + + host, port = ('localhost', acc.port) if path is None else (None, None) + client = TSSLSocket(host, port, unix_socket=path, **client_kwargs) + yield acc, client + finally: + acc.close() + + def _assert_connection_failure(self, server, path=None, **client_args): + logging.disable(logging.CRITICAL) + try: + with self._connectable_client(server, True, path=path, **client_args) as (acc, client): + # We need to wait for a connection failure, but not too long. 20ms is a tunable + # compromise between test speed and stability + client.setTimeout(20) + with self._assert_raises(TTransportException): + client.open() + client.write(b"hello") + client.read(5) # b"there" + finally: + logging.disable(logging.NOTSET) + + def _assert_raises(self, exc): + if sys.hexversion >= 0x020700F0: + return self.assertRaises(exc) + else: + return AssertRaises(exc) + + def _assert_connection_success(self, server, path=None, **client_args): + with self._connectable_client(server, path=path, **client_args) as (acc, client): + try: + client.open() + client.write(b"hello") + self.assertEqual(client.read(5), b"there") + self.assertTrue(acc.client is not None) + finally: + client.close() + + # deprecated feature + def test_deprecation(self): + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__) + TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT) + self.assertEqual(len(w), 1) + + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__) + # Deprecated signature + # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None): + TSSLSocket('localhost', 0, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS) + self.assertEqual(len(w), 7) + + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__) + # Deprecated signature + # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None): + TSSLServerSocket(None, 0, SERVER_PEM, None, TEST_CIPHERS) + self.assertEqual(len(w), 3) + + # deprecated feature + def test_set_cert_reqs_by_validate(self): + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__) + c1 = TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT) + self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED) + + c1 = TSSLSocket('localhost', 0, validate=False) + self.assertEqual(c1.cert_reqs, ssl.CERT_NONE) + + self.assertEqual(len(w), 2) + + # deprecated feature + def test_set_validate_by_cert_reqs(self): + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__) + c1 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_NONE) + self.assertFalse(c1.validate) + + c2 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT) + self.assertTrue(c2.validate) + + c3 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT) + self.assertTrue(c3.validate) + + self.assertEqual(len(w), 3) + + def test_unix_domain_socket(self): + if platform.system() == 'Windows': + print('skipping test_unix_domain_socket') + return + fd, path = tempfile.mkstemp() + os.close(fd) + os.unlink(path) + try: + server = self._server_socket(unix_socket=path, keyfile=SERVER_KEY, certfile=SERVER_CERT) + self._assert_connection_success(server, path=path, cert_reqs=ssl.CERT_NONE) + finally: + os.unlink(path) + + def test_server_cert(self): + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT) + self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT) + + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT) + # server cert not in ca_certs + self._assert_connection_failure(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT) + + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT) + self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE) + + def test_set_server_cert(self): + server = self._server_socket(keyfile=SERVER_KEY, certfile=CLIENT_CERT) + with self._assert_raises(Exception): + server.certfile = 'foo' + with self._assert_raises(Exception): + server.certfile = None + server.certfile = SERVER_CERT + self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT) + + def test_client_cert(self): + if not _match_has_ipaddress: + print('skipping test_client_cert') + return + server = self._server_socket( + cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY, + certfile=SERVER_CERT, ca_certs=CLIENT_CERT) + self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY) + + server = self._server_socket( + cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY, + certfile=SERVER_CERT, ca_certs=CLIENT_CA) + self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP) + + server = self._server_socket( + cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY, + certfile=SERVER_CERT, ca_certs=CLIENT_CA) + self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY) + + server = self._server_socket( + cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY, + certfile=SERVER_CERT, ca_certs=CLIENT_CA) + self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY) + + def test_ciphers(self): + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS) + self._assert_connection_success(server, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS) + + if not TSSLSocket._has_ciphers: + # unittest.skip is not available for Python 2.6 + print('skipping test_ciphers') + return + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT) + self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL') + + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS) + self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL') + + def test_ssl2_and_ssl3_disabled(self): + if not hasattr(ssl, 'PROTOCOL_SSLv3'): + print('PROTOCOL_SSLv3 is not available') + else: + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT) + self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3) + + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3) + self._assert_connection_failure(server, ca_certs=SERVER_CERT) + + if not hasattr(ssl, 'PROTOCOL_SSLv2'): + print('PROTOCOL_SSLv2 is not available') + else: + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT) + self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2) + + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2) + self._assert_connection_failure(server, ca_certs=SERVER_CERT) + + def test_newer_tls(self): + if not TSSLSocket._has_ssl_context: + # unittest.skip is not available for Python 2.6 + print('skipping test_newer_tls') + return + if not hasattr(ssl, 'PROTOCOL_TLSv1_2'): + print('PROTOCOL_TLSv1_2 is not available') + else: + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2) + self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2) + + if not hasattr(ssl, 'PROTOCOL_TLSv1_1'): + print('PROTOCOL_TLSv1_1 is not available') + else: + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1) + self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1) + + if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'): + print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available') + else: + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2) + self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1) + + def test_ssl_context(self): + if not TSSLSocket._has_ssl_context: + # unittest.skip is not available for Python 2.6 + print('skipping test_ssl_context') + return + server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + server_context.load_cert_chain(SERVER_CERT, SERVER_KEY) + server_context.load_verify_locations(CLIENT_CA) + server_context.verify_mode = ssl.CERT_REQUIRED + server = self._server_socket(ssl_context=server_context) + + client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY) + client_context.load_verify_locations(SERVER_CERT) + client_context.verify_mode = ssl.CERT_REQUIRED + + self._assert_connection_success(server, ssl_context=client_context) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.WARN) + from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket, _match_has_ipaddress + from thrift.transport.TTransport import TTransportException + + unittest.main() diff --git a/src/jaegertracing/thrift/lib/py/test/thrift_json.py b/src/jaegertracing/thrift/lib/py/test/thrift_json.py new file mode 100644 index 000000000..40e7a47e3 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/test/thrift_json.py @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys +import unittest + +import _import_local_thrift # noqa +from thrift.protocol.TJSONProtocol import TJSONProtocol +from thrift.transport import TTransport + +# +# In order to run the test under Windows. We need to create symbolic link +# name 'thrift' to '../src' folder by using: +# +# mklink /D thrift ..\src +# + + +class TestJSONString(unittest.TestCase): + + def test_escaped_unicode_string(self): + unicode_json = b'"hello \\u0e01\\u0e02\\u0e03\\ud835\\udcab\\udb40\\udc70 unicode"' + unicode_text = u'hello \u0e01\u0e02\u0e03\U0001D4AB\U000E0070 unicode' + + buf = TTransport.TMemoryBuffer(unicode_json) + transport = TTransport.TBufferedTransportFactory().getTransport(buf) + protocol = TJSONProtocol(transport) + + if sys.version_info[0] == 2: + unicode_text = unicode_text.encode('utf8') + self.assertEqual(protocol.readString(), unicode_text) + + +if __name__ == '__main__': + unittest.main() -- cgit v1.2.3