From 19fcec84d8d7d21e796c7624e521b60d28ee21ed Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 7 Apr 2024 20:45:59 +0200 Subject: Adding upstream version 16.2.11+ds. Signed-off-by: Daniel Baumann --- src/jaegertracing/thrift/lib/py/CMakeLists.txt | 31 + src/jaegertracing/thrift/lib/py/MANIFEST.in | 1 + src/jaegertracing/thrift/lib/py/Makefile.am | 65 ++ src/jaegertracing/thrift/lib/py/README.md | 35 + .../thrift/lib/py/coding_standards.md | 7 + .../thrift/lib/py/compat/win32/stdint.h | 247 ++++++ src/jaegertracing/thrift/lib/py/setup.cfg | 6 + src/jaegertracing/thrift/lib/py/setup.py | 139 ++++ .../thrift/lib/py/src/TMultiplexedProcessor.py | 82 ++ src/jaegertracing/thrift/lib/py/src/TRecursive.py | 83 ++ src/jaegertracing/thrift/lib/py/src/TSCons.py | 36 + .../thrift/lib/py/src/TSerialization.py | 38 + src/jaegertracing/thrift/lib/py/src/TTornado.py | 188 +++++ src/jaegertracing/thrift/lib/py/src/Thrift.py | 204 +++++ src/jaegertracing/thrift/lib/py/src/__init__.py | 20 + src/jaegertracing/thrift/lib/py/src/compat.py | 46 ++ src/jaegertracing/thrift/lib/py/src/ext/binary.cpp | 38 + src/jaegertracing/thrift/lib/py/src/ext/binary.h | 217 +++++ .../thrift/lib/py/src/ext/compact.cpp | 107 +++ src/jaegertracing/thrift/lib/py/src/ext/compact.h | 368 +++++++++ src/jaegertracing/thrift/lib/py/src/ext/endian.h | 96 +++ src/jaegertracing/thrift/lib/py/src/ext/module.cpp | 203 +++++ src/jaegertracing/thrift/lib/py/src/ext/protocol.h | 96 +++ .../thrift/lib/py/src/ext/protocol.tcc | 913 +++++++++++++++++++++ src/jaegertracing/thrift/lib/py/src/ext/types.cpp | 113 +++ src/jaegertracing/thrift/lib/py/src/ext/types.h | 192 +++++ .../thrift/lib/py/src/protocol/TBase.py | 82 ++ .../thrift/lib/py/src/protocol/TBinaryProtocol.py | 301 +++++++ .../thrift/lib/py/src/protocol/TCompactProtocol.py | 487 +++++++++++ .../thrift/lib/py/src/protocol/THeaderProtocol.py | 225 +++++ .../thrift/lib/py/src/protocol/TJSONProtocol.py | 677 +++++++++++++++ .../lib/py/src/protocol/TMultiplexedProtocol.py | 39 + .../thrift/lib/py/src/protocol/TProtocol.py | 422 ++++++++++ .../lib/py/src/protocol/TProtocolDecorator.py | 26 + .../thrift/lib/py/src/protocol/__init__.py | 21 + .../thrift/lib/py/src/server/THttpServer.py | 131 +++ .../thrift/lib/py/src/server/TNonblockingServer.py | 370 +++++++++ .../thrift/lib/py/src/server/TProcessPoolServer.py | 123 +++ .../thrift/lib/py/src/server/TServer.py | 323 ++++++++ .../thrift/lib/py/src/server/__init__.py | 20 + .../lib/py/src/transport/THeaderTransport.py | 352 ++++++++ .../thrift/lib/py/src/transport/THttpClient.py | 187 +++++ .../thrift/lib/py/src/transport/TSSLSocket.py | 408 +++++++++ .../thrift/lib/py/src/transport/TSocket.py | 215 +++++ .../thrift/lib/py/src/transport/TTransport.py | 456 ++++++++++ .../thrift/lib/py/src/transport/TTwisted.py | 329 ++++++++ .../thrift/lib/py/src/transport/TZlibTransport.py | 248 ++++++ .../thrift/lib/py/src/transport/__init__.py | 20 + .../thrift/lib/py/src/transport/sslcompat.py | 100 +++ .../thrift/lib/py/test/_import_local_thrift.py | 30 + .../thrift/lib/py/test/test_sslsocket.py | 353 ++++++++ .../thrift/lib/py/test/thrift_json.py | 51 ++ 52 files changed, 9567 insertions(+) create mode 100644 src/jaegertracing/thrift/lib/py/CMakeLists.txt create mode 100644 src/jaegertracing/thrift/lib/py/MANIFEST.in create mode 100644 src/jaegertracing/thrift/lib/py/Makefile.am create mode 100644 src/jaegertracing/thrift/lib/py/README.md create mode 100644 src/jaegertracing/thrift/lib/py/coding_standards.md create mode 100644 src/jaegertracing/thrift/lib/py/compat/win32/stdint.h create mode 100644 src/jaegertracing/thrift/lib/py/setup.cfg create mode 100644 src/jaegertracing/thrift/lib/py/setup.py create mode 100644 src/jaegertracing/thrift/lib/py/src/TMultiplexedProcessor.py create mode 100644 src/jaegertracing/thrift/lib/py/src/TRecursive.py create mode 100644 src/jaegertracing/thrift/lib/py/src/TSCons.py create mode 100644 src/jaegertracing/thrift/lib/py/src/TSerialization.py create mode 100644 src/jaegertracing/thrift/lib/py/src/TTornado.py create mode 100644 src/jaegertracing/thrift/lib/py/src/Thrift.py create mode 100644 src/jaegertracing/thrift/lib/py/src/__init__.py create mode 100644 src/jaegertracing/thrift/lib/py/src/compat.py create mode 100644 src/jaegertracing/thrift/lib/py/src/ext/binary.cpp create mode 100644 src/jaegertracing/thrift/lib/py/src/ext/binary.h create mode 100644 src/jaegertracing/thrift/lib/py/src/ext/compact.cpp create mode 100644 src/jaegertracing/thrift/lib/py/src/ext/compact.h create mode 100644 src/jaegertracing/thrift/lib/py/src/ext/endian.h create mode 100644 src/jaegertracing/thrift/lib/py/src/ext/module.cpp create mode 100644 src/jaegertracing/thrift/lib/py/src/ext/protocol.h create mode 100644 src/jaegertracing/thrift/lib/py/src/ext/protocol.tcc create mode 100644 src/jaegertracing/thrift/lib/py/src/ext/types.cpp create mode 100644 src/jaegertracing/thrift/lib/py/src/ext/types.h create mode 100644 src/jaegertracing/thrift/lib/py/src/protocol/TBase.py create mode 100644 src/jaegertracing/thrift/lib/py/src/protocol/TBinaryProtocol.py create mode 100644 src/jaegertracing/thrift/lib/py/src/protocol/TCompactProtocol.py create mode 100644 src/jaegertracing/thrift/lib/py/src/protocol/THeaderProtocol.py create mode 100644 src/jaegertracing/thrift/lib/py/src/protocol/TJSONProtocol.py create mode 100644 src/jaegertracing/thrift/lib/py/src/protocol/TMultiplexedProtocol.py create mode 100644 src/jaegertracing/thrift/lib/py/src/protocol/TProtocol.py create mode 100644 src/jaegertracing/thrift/lib/py/src/protocol/TProtocolDecorator.py create mode 100644 src/jaegertracing/thrift/lib/py/src/protocol/__init__.py create mode 100644 src/jaegertracing/thrift/lib/py/src/server/THttpServer.py create mode 100644 src/jaegertracing/thrift/lib/py/src/server/TNonblockingServer.py create mode 100644 src/jaegertracing/thrift/lib/py/src/server/TProcessPoolServer.py create mode 100644 src/jaegertracing/thrift/lib/py/src/server/TServer.py create mode 100644 src/jaegertracing/thrift/lib/py/src/server/__init__.py create mode 100644 src/jaegertracing/thrift/lib/py/src/transport/THeaderTransport.py create mode 100644 src/jaegertracing/thrift/lib/py/src/transport/THttpClient.py create mode 100644 src/jaegertracing/thrift/lib/py/src/transport/TSSLSocket.py create mode 100644 src/jaegertracing/thrift/lib/py/src/transport/TSocket.py create mode 100644 src/jaegertracing/thrift/lib/py/src/transport/TTransport.py create mode 100644 src/jaegertracing/thrift/lib/py/src/transport/TTwisted.py create mode 100644 src/jaegertracing/thrift/lib/py/src/transport/TZlibTransport.py create mode 100644 src/jaegertracing/thrift/lib/py/src/transport/__init__.py create mode 100644 src/jaegertracing/thrift/lib/py/src/transport/sslcompat.py create mode 100644 src/jaegertracing/thrift/lib/py/test/_import_local_thrift.py create mode 100644 src/jaegertracing/thrift/lib/py/test/test_sslsocket.py create mode 100644 src/jaegertracing/thrift/lib/py/test/thrift_json.py (limited to 'src/jaegertracing/thrift/lib/py') diff --git a/src/jaegertracing/thrift/lib/py/CMakeLists.txt b/src/jaegertracing/thrift/lib/py/CMakeLists.txt new file mode 100644 index 000000000..7bb91fe67 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/CMakeLists.txt @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +include_directories(${PYTHON_INCLUDE_DIRS}) + +add_custom_target(python_build ALL + COMMAND ${PYTHON_EXECUTABLE} setup.py build + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "Building Python library" +) + +if(BUILD_TESTING) + add_test(PythonTestSSLSocket ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test/test_sslsocket.py) + add_test(PythonThriftJson ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test/thrift_json.py) +endif() diff --git a/src/jaegertracing/thrift/lib/py/MANIFEST.in b/src/jaegertracing/thrift/lib/py/MANIFEST.in new file mode 100644 index 000000000..af54e29dc --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/MANIFEST.in @@ -0,0 +1 @@ +include src/ext/* diff --git a/src/jaegertracing/thrift/lib/py/Makefile.am b/src/jaegertracing/thrift/lib/py/Makefile.am new file mode 100644 index 000000000..46e44054b --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/Makefile.am @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +AUTOMAKE_OPTIONS = serial-tests +DESTDIR ?= / + +if WITH_PY3 +py3-build: + $(PYTHON3) setup.py build +py3-test: py3-build + $(PYTHON3) test/thrift_json.py + $(PYTHON3) test/test_sslsocket.py +else +py3-build: +py3-test: +endif + +all-local: py3-build + $(PYTHON) setup.py build + +# We're ignoring prefix here because site-packages seems to be +# the equivalent of /usr/local/lib in Python land. +# Old version (can't put inline because it's not portable). +#$(PYTHON) setup.py install --prefix=$(prefix) --root=$(DESTDIR) $(PYTHON_SETUPUTIL_ARGS) +install-exec-hook: + $(PYTHON) setup.py install --root=$(DESTDIR) --prefix=$(PY_PREFIX) $(PYTHON_SETUPUTIL_ARGS) + +check-local: all py3-test + $(PYTHON) test/thrift_json.py + $(PYTHON) test/test_sslsocket.py + +clean-local: + $(RM) -r build + find . -type f \( -iname "*.pyc" \) | xargs rm -f + find . -type d \( -iname "__pycache__" -or -iname "_trial_temp" \) | xargs rm -rf + +dist-hook: + find $(distdir) -type f \( -iname "*.pyc" \) | xargs rm -f + find $(distdir) -type d \( -iname "__pycache__" -or -iname "_trial_temp" \) | xargs rm -rf + +EXTRA_DIST = \ + CMakeLists.txt \ + MANIFEST.in \ + coding_standards.md \ + compat \ + setup.py \ + setup.cfg \ + src \ + test \ + README.md diff --git a/src/jaegertracing/thrift/lib/py/README.md b/src/jaegertracing/thrift/lib/py/README.md new file mode 100644 index 000000000..29b8c73c4 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/README.md @@ -0,0 +1,35 @@ +Thrift Python Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Using Thrift with Python +======================== + +Thrift is provided as a set of Python packages. The top level package is +thrift, and there are subpackages for the protocol, transport, and server +code. Each package contains modules using standard Thrift naming conventions +(i.e. TProtocol, TTransport) and implementations in corresponding modules +(i.e. TSocket). There is also a subpackage reflection, which contains +the generated code for the reflection structures. + +The Python libraries can be installed manually using the provided setup.py +file, or automatically using the install hook provided via autoconf/automake. +To use the latter, become superuser and do make install. diff --git a/src/jaegertracing/thrift/lib/py/coding_standards.md b/src/jaegertracing/thrift/lib/py/coding_standards.md new file mode 100644 index 000000000..4c560b524 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/coding_standards.md @@ -0,0 +1,7 @@ +## Python Coding Standards + +Please follow: + * [Thrift General Coding Standards](/doc/coding_standards.md) + * Code Style for Python Code [PEP8](http://legacy.python.org/dev/peps/pep-0008/) + +When in doubt - check with or online with . diff --git a/src/jaegertracing/thrift/lib/py/compat/win32/stdint.h b/src/jaegertracing/thrift/lib/py/compat/win32/stdint.h new file mode 100644 index 000000000..d02608a59 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/compat/win32/stdint.h @@ -0,0 +1,247 @@ +// ISO C9x compliant stdint.h for Microsoft Visual Studio +// Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124 +// +// Copyright (c) 2006-2008 Alexander Chemeris +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. The name of the author may be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED +// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO +// EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +// ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +/////////////////////////////////////////////////////////////////////////////// + +#ifndef _MSC_VER // [ +#error "Use this header only with Microsoft Visual C++ compilers!" +#endif // _MSC_VER ] + +#ifndef _MSC_STDINT_H_ // [ +#define _MSC_STDINT_H_ + +#if _MSC_VER > 1000 +#pragma once +#endif + +#include + +// For Visual Studio 6 in C++ mode and for many Visual Studio versions when +// compiling for ARM we should wrap include with 'extern "C++" {}' +// or compiler give many errors like this: +// error C2733: second C linkage of overloaded function 'wmemchr' not allowed +#ifdef __cplusplus +extern "C" { +#endif +# include +#ifdef __cplusplus +} +#endif + +// Define _W64 macros to mark types changing their size, like intptr_t. +#ifndef _W64 +# if !defined(__midl) && (defined(_X86_) || defined(_M_IX86)) && _MSC_VER >= 1300 +# define _W64 __w64 +# else +# define _W64 +# endif +#endif + + +// 7.18.1 Integer types + +// 7.18.1.1 Exact-width integer types + +// Visual Studio 6 and Embedded Visual C++ 4 doesn't +// realize that, e.g. char has the same size as __int8 +// so we give up on __intX for them. +#if (_MSC_VER < 1300) + typedef signed char int8_t; + typedef signed short int16_t; + typedef signed int int32_t; + typedef unsigned char uint8_t; + typedef unsigned short uint16_t; + typedef unsigned int uint32_t; +#else + typedef signed __int8 int8_t; + typedef signed __int16 int16_t; + typedef signed __int32 int32_t; + typedef unsigned __int8 uint8_t; + typedef unsigned __int16 uint16_t; + typedef unsigned __int32 uint32_t; +#endif +typedef signed __int64 int64_t; +typedef unsigned __int64 uint64_t; + + +// 7.18.1.2 Minimum-width integer types +typedef int8_t int_least8_t; +typedef int16_t int_least16_t; +typedef int32_t int_least32_t; +typedef int64_t int_least64_t; +typedef uint8_t uint_least8_t; +typedef uint16_t uint_least16_t; +typedef uint32_t uint_least32_t; +typedef uint64_t uint_least64_t; + +// 7.18.1.3 Fastest minimum-width integer types +typedef int8_t int_fast8_t; +typedef int16_t int_fast16_t; +typedef int32_t int_fast32_t; +typedef int64_t int_fast64_t; +typedef uint8_t uint_fast8_t; +typedef uint16_t uint_fast16_t; +typedef uint32_t uint_fast32_t; +typedef uint64_t uint_fast64_t; + +// 7.18.1.4 Integer types capable of holding object pointers +#ifdef _WIN64 // [ + typedef signed __int64 intptr_t; + typedef unsigned __int64 uintptr_t; +#else // _WIN64 ][ + typedef _W64 signed int intptr_t; + typedef _W64 unsigned int uintptr_t; +#endif // _WIN64 ] + +// 7.18.1.5 Greatest-width integer types +typedef int64_t intmax_t; +typedef uint64_t uintmax_t; + + +// 7.18.2 Limits of specified-width integer types + +#if !defined(__cplusplus) || defined(__STDC_LIMIT_MACROS) // [ See footnote 220 at page 257 and footnote 221 at page 259 + +// 7.18.2.1 Limits of exact-width integer types +#define INT8_MIN ((int8_t)_I8_MIN) +#define INT8_MAX _I8_MAX +#define INT16_MIN ((int16_t)_I16_MIN) +#define INT16_MAX _I16_MAX +#define INT32_MIN ((int32_t)_I32_MIN) +#define INT32_MAX _I32_MAX +#define INT64_MIN ((int64_t)_I64_MIN) +#define INT64_MAX _I64_MAX +#define UINT8_MAX _UI8_MAX +#define UINT16_MAX _UI16_MAX +#define UINT32_MAX _UI32_MAX +#define UINT64_MAX _UI64_MAX + +// 7.18.2.2 Limits of minimum-width integer types +#define INT_LEAST8_MIN INT8_MIN +#define INT_LEAST8_MAX INT8_MAX +#define INT_LEAST16_MIN INT16_MIN +#define INT_LEAST16_MAX INT16_MAX +#define INT_LEAST32_MIN INT32_MIN +#define INT_LEAST32_MAX INT32_MAX +#define INT_LEAST64_MIN INT64_MIN +#define INT_LEAST64_MAX INT64_MAX +#define UINT_LEAST8_MAX UINT8_MAX +#define UINT_LEAST16_MAX UINT16_MAX +#define UINT_LEAST32_MAX UINT32_MAX +#define UINT_LEAST64_MAX UINT64_MAX + +// 7.18.2.3 Limits of fastest minimum-width integer types +#define INT_FAST8_MIN INT8_MIN +#define INT_FAST8_MAX INT8_MAX +#define INT_FAST16_MIN INT16_MIN +#define INT_FAST16_MAX INT16_MAX +#define INT_FAST32_MIN INT32_MIN +#define INT_FAST32_MAX INT32_MAX +#define INT_FAST64_MIN INT64_MIN +#define INT_FAST64_MAX INT64_MAX +#define UINT_FAST8_MAX UINT8_MAX +#define UINT_FAST16_MAX UINT16_MAX +#define UINT_FAST32_MAX UINT32_MAX +#define UINT_FAST64_MAX UINT64_MAX + +// 7.18.2.4 Limits of integer types capable of holding object pointers +#ifdef _WIN64 // [ +# define INTPTR_MIN INT64_MIN +# define INTPTR_MAX INT64_MAX +# define UINTPTR_MAX UINT64_MAX +#else // _WIN64 ][ +# define INTPTR_MIN INT32_MIN +# define INTPTR_MAX INT32_MAX +# define UINTPTR_MAX UINT32_MAX +#endif // _WIN64 ] + +// 7.18.2.5 Limits of greatest-width integer types +#define INTMAX_MIN INT64_MIN +#define INTMAX_MAX INT64_MAX +#define UINTMAX_MAX UINT64_MAX + +// 7.18.3 Limits of other integer types + +#ifdef _WIN64 // [ +# define PTRDIFF_MIN _I64_MIN +# define PTRDIFF_MAX _I64_MAX +#else // _WIN64 ][ +# define PTRDIFF_MIN _I32_MIN +# define PTRDIFF_MAX _I32_MAX +#endif // _WIN64 ] + +#define SIG_ATOMIC_MIN INT_MIN +#define SIG_ATOMIC_MAX INT_MAX + +#ifndef SIZE_MAX // [ +# ifdef _WIN64 // [ +# define SIZE_MAX _UI64_MAX +# else // _WIN64 ][ +# define SIZE_MAX _UI32_MAX +# endif // _WIN64 ] +#endif // SIZE_MAX ] + +// WCHAR_MIN and WCHAR_MAX are also defined in +#ifndef WCHAR_MIN // [ +# define WCHAR_MIN 0 +#endif // WCHAR_MIN ] +#ifndef WCHAR_MAX // [ +# define WCHAR_MAX _UI16_MAX +#endif // WCHAR_MAX ] + +#define WINT_MIN 0 +#define WINT_MAX _UI16_MAX + +#endif // __STDC_LIMIT_MACROS ] + + +// 7.18.4 Limits of other integer types + +#if !defined(__cplusplus) || defined(__STDC_CONSTANT_MACROS) // [ See footnote 224 at page 260 + +// 7.18.4.1 Macros for minimum-width integer constants + +#define INT8_C(val) val##i8 +#define INT16_C(val) val##i16 +#define INT32_C(val) val##i32 +#define INT64_C(val) val##i64 + +#define UINT8_C(val) val##ui8 +#define UINT16_C(val) val##ui16 +#define UINT32_C(val) val##ui32 +#define UINT64_C(val) val##ui64 + +// 7.18.4.2 Macros for greatest-width integer constants +#define INTMAX_C INT64_C +#define UINTMAX_C UINT64_C + +#endif // __STDC_CONSTANT_MACROS ] + + +#endif // _MSC_STDINT_H_ ] diff --git a/src/jaegertracing/thrift/lib/py/setup.cfg b/src/jaegertracing/thrift/lib/py/setup.cfg new file mode 100644 index 000000000..c9ed0aec5 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/setup.cfg @@ -0,0 +1,6 @@ +[install] +optimize = 1 +[metadata] +description-file = README.md +[flake8] +max-line-length = 100 diff --git a/src/jaegertracing/thrift/lib/py/setup.py b/src/jaegertracing/thrift/lib/py/setup.py new file mode 100644 index 000000000..2ba269159 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/setup.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys +try: + from setuptools import setup, Extension +except Exception: + from distutils.core import setup, Extension + +from distutils.command.build_ext import build_ext +from distutils.errors import CCompilerError, DistutilsExecError, DistutilsPlatformError + +# Fix to build sdist under vagrant +import os +if 'vagrant' in str(os.environ): + try: + del os.link + except AttributeError: + pass + +include_dirs = ['src'] +if sys.platform == 'win32': + include_dirs.append('compat/win32') + ext_errors = (CCompilerError, DistutilsExecError, DistutilsPlatformError, IOError) +else: + ext_errors = (CCompilerError, DistutilsExecError, DistutilsPlatformError) + + +class BuildFailed(Exception): + pass + + +class ve_build_ext(build_ext): + def run(self): + try: + build_ext.run(self) + except DistutilsPlatformError: + raise BuildFailed() + + def build_extension(self, ext): + try: + build_ext.build_extension(self, ext) + except ext_errors: + raise BuildFailed() + + +def run_setup(with_binary): + if with_binary: + extensions = dict( + ext_modules=[ + Extension('thrift.protocol.fastbinary', + sources=[ + 'src/ext/module.cpp', + 'src/ext/types.cpp', + 'src/ext/binary.cpp', + 'src/ext/compact.cpp', + ], + include_dirs=include_dirs, + ) + ], + cmdclass=dict(build_ext=ve_build_ext) + ) + else: + extensions = dict() + + ssl_deps = [] + if sys.version_info[0] == 2: + ssl_deps.append('ipaddress') + if sys.hexversion < 0x03050000: + ssl_deps.append('backports.ssl_match_hostname>=3.5') + tornado_deps = ['tornado>=4.0'] + twisted_deps = ['twisted'] + + setup(name='thrift', + version='0.13.0', + description='Python bindings for the Apache Thrift RPC system', + author='Apache Thrift Developers', + author_email='dev@thrift.apache.org', + url='http://thrift.apache.org', + license='Apache License 2.0', + install_requires=['six>=1.7.2'], + extras_require={ + 'ssl': ssl_deps, + 'tornado': tornado_deps, + 'twisted': twisted_deps, + 'all': ssl_deps + tornado_deps + twisted_deps, + }, + packages=[ + 'thrift', + 'thrift.protocol', + 'thrift.transport', + 'thrift.server', + ], + package_dir={'thrift': 'src'}, + classifiers=[ + 'Development Status :: 5 - Production/Stable', + 'Environment :: Console', + 'Intended Audience :: Developers', + 'Programming Language :: Python', + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 3', + 'Topic :: Software Development :: Libraries', + 'Topic :: System :: Networking' + ], + zip_safe=False, + **extensions + ) + + +try: + with_binary = True + run_setup(with_binary) +except BuildFailed: + print() + print('*' * 80) + print("An error occurred while trying to compile with the C extension enabled") + print("Attempting to build without the extension now") + print('*' * 80) + print() + + run_setup(False) diff --git a/src/jaegertracing/thrift/lib/py/src/TMultiplexedProcessor.py b/src/jaegertracing/thrift/lib/py/src/TMultiplexedProcessor.py new file mode 100644 index 000000000..ff88430bd --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/TMultiplexedProcessor.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.Thrift import TProcessor, TMessageType +from thrift.protocol import TProtocolDecorator, TMultiplexedProtocol +from thrift.protocol.TProtocol import TProtocolException + + +class TMultiplexedProcessor(TProcessor): + def __init__(self): + self.defaultProcessor = None + self.services = {} + + def registerDefault(self, processor): + """ + If a non-multiplexed processor connects to the server and wants to + communicate, use the given processor to handle it. This mechanism + allows servers to upgrade from non-multiplexed to multiplexed in a + backwards-compatible way and still handle old clients. + """ + self.defaultProcessor = processor + + def registerProcessor(self, serviceName, processor): + self.services[serviceName] = processor + + def on_message_begin(self, func): + for key in self.services.keys(): + self.services[key].on_message_begin(func) + + def process(self, iprot, oprot): + (name, type, seqid) = iprot.readMessageBegin() + if type != TMessageType.CALL and type != TMessageType.ONEWAY: + raise TProtocolException( + TProtocolException.NOT_IMPLEMENTED, + "TMultiplexedProtocol only supports CALL & ONEWAY") + + index = name.find(TMultiplexedProtocol.SEPARATOR) + if index < 0: + if self.defaultProcessor: + return self.defaultProcessor.process( + StoredMessageProtocol(iprot, (name, type, seqid)), oprot) + else: + raise TProtocolException( + TProtocolException.NOT_IMPLEMENTED, + "Service name not found in message name: " + name + ". " + + "Did you forget to use TMultiplexedProtocol in your client?") + + serviceName = name[0:index] + call = name[index + len(TMultiplexedProtocol.SEPARATOR):] + if serviceName not in self.services: + raise TProtocolException( + TProtocolException.NOT_IMPLEMENTED, + "Service name not found: " + serviceName + ". " + + "Did you forget to call registerProcessor()?") + + standardMessage = (call, type, seqid) + return self.services[serviceName].process( + StoredMessageProtocol(iprot, standardMessage), oprot) + + +class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator): + def __init__(self, protocol, messageBegin): + self.messageBegin = messageBegin + + def readMessageBegin(self): + return self.messageBegin diff --git a/src/jaegertracing/thrift/lib/py/src/TRecursive.py b/src/jaegertracing/thrift/lib/py/src/TRecursive.py new file mode 100644 index 000000000..abf202cb1 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/TRecursive.py @@ -0,0 +1,83 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from thrift.Thrift import TType + +TYPE_IDX = 1 +SPEC_ARGS_IDX = 3 +SPEC_ARGS_CLASS_REF_IDX = 0 +SPEC_ARGS_THRIFT_SPEC_IDX = 1 + + +def fix_spec(all_structs): + """Wire up recursive references for all TStruct definitions inside of each thrift_spec.""" + for struc in all_structs: + spec = struc.thrift_spec + for thrift_spec in spec: + if thrift_spec is None: + continue + elif thrift_spec[TYPE_IDX] == TType.STRUCT: + other = thrift_spec[SPEC_ARGS_IDX][SPEC_ARGS_CLASS_REF_IDX].thrift_spec + thrift_spec[SPEC_ARGS_IDX][SPEC_ARGS_THRIFT_SPEC_IDX] = other + elif thrift_spec[TYPE_IDX] in (TType.LIST, TType.SET): + _fix_list_or_set(thrift_spec[SPEC_ARGS_IDX]) + elif thrift_spec[TYPE_IDX] == TType.MAP: + _fix_map(thrift_spec[SPEC_ARGS_IDX]) + + +def _fix_list_or_set(element_type): + # For a list or set, the thrift_spec entry looks like, + # (1, TType.LIST, 'lister', (TType.STRUCT, [RecList, None], False), None, ), # 1 + # so ``element_type`` will be, + # (TType.STRUCT, [RecList, None], False) + if element_type[0] == TType.STRUCT: + element_type[1][1] = element_type[1][0].thrift_spec + elif element_type[0] in (TType.LIST, TType.SET): + _fix_list_or_set(element_type[1]) + elif element_type[0] == TType.MAP: + _fix_map(element_type[1]) + + +def _fix_map(element_type): + # For a map of key -> value type, ``element_type`` will be, + # (TType.I16, None, TType.STRUCT, [RecMapBasic, None], False), None, ) + # which is just a normal struct definition. + # + # For a map of key -> list / set, ``element_type`` will be, + # (TType.I16, None, TType.LIST, (TType.STRUCT, [RecMapList, None], False), False) + # and we need to process the 3rd element as a list. + # + # For a map of key -> map, ``element_type`` will be, + # (TType.I16, None, TType.MAP, (TType.I16, None, TType.STRUCT, + # [RecMapMap, None], False), False) + # and need to process 3rd element as a map. + + # Is the map key a struct? + if element_type[0] == TType.STRUCT: + element_type[1][1] = element_type[1][0].thrift_spec + elif element_type[0] in (TType.LIST, TType.SET): + _fix_list_or_set(element_type[1]) + elif element_type[0] == TType.MAP: + _fix_map(element_type[1]) + + # Is the map value a struct? + if element_type[2] == TType.STRUCT: + element_type[3][1] = element_type[3][0].thrift_spec + elif element_type[2] in (TType.LIST, TType.SET): + _fix_list_or_set(element_type[3]) + elif element_type[2] == TType.MAP: + _fix_map(element_type[3]) diff --git a/src/jaegertracing/thrift/lib/py/src/TSCons.py b/src/jaegertracing/thrift/lib/py/src/TSCons.py new file mode 100644 index 000000000..bc67d7069 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/TSCons.py @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from os import path +from SCons.Builder import Builder +from six.moves import map + + +def scons_env(env, add=''): + opath = path.dirname(path.abspath('$TARGET')) + lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' + cppbuild = Builder(action=lstr) + env.Append(BUILDERS={'ThriftCpp': cppbuild}) + + +def gen_cpp(env, dir, file): + scons_env(env) + suffixes = ['_types.h', '_types.cpp'] + targets = map(lambda s: 'gen-cpp/' + file + s, suffixes) + return env.ThriftCpp(targets, dir + file + '.thrift') diff --git a/src/jaegertracing/thrift/lib/py/src/TSerialization.py b/src/jaegertracing/thrift/lib/py/src/TSerialization.py new file mode 100644 index 000000000..fbbe76807 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/TSerialization.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from .protocol import TBinaryProtocol +from .transport import TTransport + + +def serialize(thrift_object, + protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()): + transport = TTransport.TMemoryBuffer() + protocol = protocol_factory.getProtocol(transport) + thrift_object.write(protocol) + return transport.getvalue() + + +def deserialize(base, + buf, + protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()): + transport = TTransport.TMemoryBuffer(buf) + protocol = protocol_factory.getProtocol(transport) + base.read(protocol) + return base diff --git a/src/jaegertracing/thrift/lib/py/src/TTornado.py b/src/jaegertracing/thrift/lib/py/src/TTornado.py new file mode 100644 index 000000000..5eff11d2d --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/TTornado.py @@ -0,0 +1,188 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from __future__ import absolute_import +import logging +import socket +import struct + +from .transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer + +from io import BytesIO +from collections import deque +from contextlib import contextmanager +from tornado import gen, iostream, ioloop, tcpserver, concurrent + +__all__ = ['TTornadoServer', 'TTornadoStreamTransport'] + +logger = logging.getLogger(__name__) + + +class _Lock(object): + def __init__(self): + self._waiters = deque() + + def acquired(self): + return len(self._waiters) > 0 + + @gen.coroutine + def acquire(self): + blocker = self._waiters[-1] if self.acquired() else None + future = concurrent.Future() + self._waiters.append(future) + if blocker: + yield blocker + + raise gen.Return(self._lock_context()) + + def release(self): + assert self.acquired(), 'Lock not aquired' + future = self._waiters.popleft() + future.set_result(None) + + @contextmanager + def _lock_context(self): + try: + yield + finally: + self.release() + + +class TTornadoStreamTransport(TTransportBase): + """a framed, buffered transport over a Tornado stream""" + def __init__(self, host, port, stream=None, io_loop=None): + self.host = host + self.port = port + self.io_loop = io_loop or ioloop.IOLoop.current() + self.__wbuf = BytesIO() + self._read_lock = _Lock() + + # servers provide a ready-to-go stream + self.stream = stream + + def with_timeout(self, timeout, future): + return gen.with_timeout(timeout, future, self.io_loop) + + @gen.coroutine + def open(self, timeout=None): + logger.debug('socket connecting') + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + self.stream = iostream.IOStream(sock) + + try: + connect = self.stream.connect((self.host, self.port)) + if timeout is not None: + yield self.with_timeout(timeout, connect) + else: + yield connect + except (socket.error, IOError, ioloop.TimeoutError) as e: + message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e) + raise TTransportException( + type=TTransportException.NOT_OPEN, + message=message) + + raise gen.Return(self) + + def set_close_callback(self, callback): + """ + Should be called only after open() returns + """ + self.stream.set_close_callback(callback) + + def close(self): + # don't raise if we intend to close + self.stream.set_close_callback(None) + self.stream.close() + + def read(self, _): + # The generated code for Tornado shouldn't do individual reads -- only + # frames at a time + assert False, "you're doing it wrong" + + @contextmanager + def io_exception_context(self): + try: + yield + except (socket.error, IOError) as e: + raise TTransportException( + type=TTransportException.END_OF_FILE, + message=str(e)) + except iostream.StreamBufferFullError as e: + raise TTransportException( + type=TTransportException.UNKNOWN, + message=str(e)) + + @gen.coroutine + def readFrame(self): + # IOStream processes reads one at a time + with (yield self._read_lock.acquire()): + with self.io_exception_context(): + frame_header = yield self.stream.read_bytes(4) + if len(frame_header) == 0: + raise iostream.StreamClosedError('Read zero bytes from stream') + frame_length, = struct.unpack('!i', frame_header) + frame = yield self.stream.read_bytes(frame_length) + raise gen.Return(frame) + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + frame = self.__wbuf.getvalue() + # reset wbuf before write/flush to preserve state on underlying failure + frame_length = struct.pack('!i', len(frame)) + self.__wbuf = BytesIO() + with self.io_exception_context(): + return self.stream.write(frame_length + frame) + + +class TTornadoServer(tcpserver.TCPServer): + def __init__(self, processor, iprot_factory, oprot_factory=None, + *args, **kwargs): + super(TTornadoServer, self).__init__(*args, **kwargs) + + self._processor = processor + self._iprot_factory = iprot_factory + self._oprot_factory = (oprot_factory if oprot_factory is not None + else iprot_factory) + + @gen.coroutine + def handle_stream(self, stream, address): + host, port = address[:2] + trans = TTornadoStreamTransport(host=host, port=port, stream=stream, + io_loop=self.io_loop) + oprot = self._oprot_factory.getProtocol(trans) + + try: + while not trans.stream.closed(): + try: + frame = yield trans.readFrame() + except TTransportException as e: + if e.type == TTransportException.END_OF_FILE: + break + else: + raise + tr = TMemoryBuffer(frame) + iprot = self._iprot_factory.getProtocol(tr) + yield self._processor.process(iprot, oprot) + except Exception: + logger.exception('thrift exception in handle_stream') + trans.close() + + logger.info('client disconnected %s:%d', host, port) diff --git a/src/jaegertracing/thrift/lib/py/src/Thrift.py b/src/jaegertracing/thrift/lib/py/src/Thrift.py new file mode 100644 index 000000000..c390cbb54 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/Thrift.py @@ -0,0 +1,204 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys + + +class TType(object): + STOP = 0 + VOID = 1 + BOOL = 2 + BYTE = 3 + I08 = 3 + DOUBLE = 4 + I16 = 6 + I32 = 8 + I64 = 10 + STRING = 11 + UTF7 = 11 + STRUCT = 12 + MAP = 13 + SET = 14 + LIST = 15 + UTF8 = 16 + UTF16 = 17 + + _VALUES_TO_NAMES = ( + 'STOP', + 'VOID', + 'BOOL', + 'BYTE', + 'DOUBLE', + None, + 'I16', + None, + 'I32', + None, + 'I64', + 'STRING', + 'STRUCT', + 'MAP', + 'SET', + 'LIST', + 'UTF8', + 'UTF16', + ) + + +class TMessageType(object): + CALL = 1 + REPLY = 2 + EXCEPTION = 3 + ONEWAY = 4 + + +class TProcessor(object): + """Base class for processor, which works on two streams.""" + + def process(self, iprot, oprot): + """ + Process a request. The normal behvaior is to have the + processor invoke the correct handler and then it is the + server's responsibility to write the response to oprot. + """ + pass + + def on_message_begin(self, func): + """ + Install a callback that receives (name, type, seqid) + after the message header is read. + """ + pass + + +class TException(Exception): + """Base class for all thrift exceptions.""" + + # BaseException.message is deprecated in Python v[2.6,3.0) + if (2, 6, 0) <= sys.version_info < (3, 0): + def _get_message(self): + return self._message + + def _set_message(self, message): + self._message = message + message = property(_get_message, _set_message) + + def __init__(self, message=None): + Exception.__init__(self, message) + self.message = message + + +class TApplicationException(TException): + """Application level thrift exceptions.""" + + UNKNOWN = 0 + UNKNOWN_METHOD = 1 + INVALID_MESSAGE_TYPE = 2 + WRONG_METHOD_NAME = 3 + BAD_SEQUENCE_ID = 4 + MISSING_RESULT = 5 + INTERNAL_ERROR = 6 + PROTOCOL_ERROR = 7 + INVALID_TRANSFORM = 8 + INVALID_PROTOCOL = 9 + UNSUPPORTED_CLIENT_TYPE = 10 + + def __init__(self, type=UNKNOWN, message=None): + TException.__init__(self, message) + self.type = type + + def __str__(self): + if self.message: + return self.message + elif self.type == self.UNKNOWN_METHOD: + return 'Unknown method' + elif self.type == self.INVALID_MESSAGE_TYPE: + return 'Invalid message type' + elif self.type == self.WRONG_METHOD_NAME: + return 'Wrong method name' + elif self.type == self.BAD_SEQUENCE_ID: + return 'Bad sequence ID' + elif self.type == self.MISSING_RESULT: + return 'Missing result' + elif self.type == self.INTERNAL_ERROR: + return 'Internal error' + elif self.type == self.PROTOCOL_ERROR: + return 'Protocol error' + elif self.type == self.INVALID_TRANSFORM: + return 'Invalid transform' + elif self.type == self.INVALID_PROTOCOL: + return 'Invalid protocol' + elif self.type == self.UNSUPPORTED_CLIENT_TYPE: + return 'Unsupported client type' + else: + return 'Default (unknown) TApplicationException' + + def read(self, iprot): + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.STRING: + self.message = iprot.readString() + else: + iprot.skip(ftype) + elif fid == 2: + if ftype == TType.I32: + self.type = iprot.readI32() + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + oprot.writeStructBegin('TApplicationException') + if self.message is not None: + oprot.writeFieldBegin('message', TType.STRING, 1) + oprot.writeString(self.message) + oprot.writeFieldEnd() + if self.type is not None: + oprot.writeFieldBegin('type', TType.I32, 2) + oprot.writeI32(self.type) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + +class TFrozenDict(dict): + """A dictionary that is "frozen" like a frozenset""" + + def __init__(self, *args, **kwargs): + super(TFrozenDict, self).__init__(*args, **kwargs) + # Sort the items so they will be in a consistent order. + # XOR in the hash of the class so we don't collide with + # the hash of a list of tuples. + self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items()))) + + def __setitem__(self, *args): + raise TypeError("Can't modify frozen TFreezableDict") + + def __delitem__(self, *args): + raise TypeError("Can't modify frozen TFreezableDict") + + def __hash__(self): + return self.__hashval diff --git a/src/jaegertracing/thrift/lib/py/src/__init__.py b/src/jaegertracing/thrift/lib/py/src/__init__.py new file mode 100644 index 000000000..48d659c40 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['Thrift', 'TSCons'] diff --git a/src/jaegertracing/thrift/lib/py/src/compat.py b/src/jaegertracing/thrift/lib/py/src/compat.py new file mode 100644 index 000000000..0e8271dc1 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/compat.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys + +if sys.version_info[0] == 2: + + from cStringIO import StringIO as BufferIO + + def binary_to_str(bin_val): + return bin_val + + def str_to_binary(str_val): + return str_val + + def byte_index(bytes_val, i): + return ord(bytes_val[i]) + +else: + + from io import BytesIO as BufferIO # noqa + + def binary_to_str(bin_val): + return bin_val.decode('utf8') + + def str_to_binary(str_val): + return bytes(str_val, 'utf8') + + def byte_index(bytes_val, i): + return bytes_val[i] diff --git a/src/jaegertracing/thrift/lib/py/src/ext/binary.cpp b/src/jaegertracing/thrift/lib/py/src/ext/binary.cpp new file mode 100644 index 000000000..85d8d922e --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/ext/binary.cpp @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "ext/binary.h" +namespace apache { +namespace thrift { +namespace py { + +bool BinaryProtocol::readFieldBegin(TType& type, int16_t& tag) { + uint8_t b = 0; + if (!readByte(b)) { + return false; + } + type = static_cast(b); + if (type == T_STOP) { + return true; + } + return readI16(tag); +} +} +} +} diff --git a/src/jaegertracing/thrift/lib/py/src/ext/binary.h b/src/jaegertracing/thrift/lib/py/src/ext/binary.h new file mode 100644 index 000000000..960b0d003 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/ext/binary.h @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef THRIFT_PY_BINARY_H +#define THRIFT_PY_BINARY_H + +#include +#include "ext/protocol.h" +#include "ext/endian.h" +#include + +namespace apache { +namespace thrift { +namespace py { + +class BinaryProtocol : public ProtocolBase { +public: + virtual ~BinaryProtocol() {} + + void writeI8(int8_t val) { writeBuffer(reinterpret_cast(&val), sizeof(int8_t)); } + + void writeI16(int16_t val) { + int16_t net = static_cast(htons(val)); + writeBuffer(reinterpret_cast(&net), sizeof(int16_t)); + } + + void writeI32(int32_t val) { + int32_t net = static_cast(htonl(val)); + writeBuffer(reinterpret_cast(&net), sizeof(int32_t)); + } + + void writeI64(int64_t val) { + int64_t net = static_cast(htonll(val)); + writeBuffer(reinterpret_cast(&net), sizeof(int64_t)); + } + + void writeDouble(double dub) { + // Unfortunately, bitwise_cast doesn't work in C. Bad C! + union { + double f; + int64_t t; + } transfer; + transfer.f = dub; + writeI64(transfer.t); + } + + void writeBool(int v) { writeByte(static_cast(v)); } + + void writeString(PyObject* value, int32_t len) { + writeI32(len); + writeBuffer(PyBytes_AS_STRING(value), len); + } + + bool writeListBegin(PyObject* value, const SetListTypeArgs& parsedargs, int32_t len) { + writeByte(parsedargs.element_type); + writeI32(len); + return true; + } + + bool writeMapBegin(PyObject* value, const MapTypeArgs& parsedargs, int32_t len) { + writeByte(parsedargs.ktag); + writeByte(parsedargs.vtag); + writeI32(len); + return true; + } + + bool writeStructBegin() { return true; } + bool writeStructEnd() { return true; } + bool writeField(PyObject* value, const StructItemSpec& parsedspec) { + writeByte(static_cast(parsedspec.type)); + writeI16(parsedspec.tag); + return encodeValue(value, parsedspec.type, parsedspec.typeargs); + } + + void writeFieldStop() { writeByte(static_cast(T_STOP)); } + + bool readBool(bool& val) { + char* buf; + if (!readBytes(&buf, 1)) { + return false; + } + val = buf[0] == 1; + return true; + } + + bool readI8(int8_t& val) { + char* buf; + if (!readBytes(&buf, 1)) { + return false; + } + val = buf[0]; + return true; + } + + bool readI16(int16_t& val) { + char* buf; + if (!readBytes(&buf, sizeof(int16_t))) { + return false; + } + memcpy(&val, buf, sizeof(int16_t)); + val = ntohs(val); + return true; + } + + bool readI32(int32_t& val) { + char* buf; + if (!readBytes(&buf, sizeof(int32_t))) { + return false; + } + memcpy(&val, buf, sizeof(int32_t)); + val = ntohl(val); + return true; + } + + bool readI64(int64_t& val) { + char* buf; + if (!readBytes(&buf, sizeof(int64_t))) { + return false; + } + memcpy(&val, buf, sizeof(int64_t)); + val = ntohll(val); + return true; + } + + bool readDouble(double& val) { + union { + int64_t f; + double t; + } transfer; + + if (!readI64(transfer.f)) { + return false; + } + val = transfer.t; + return true; + } + + int32_t readString(char** buf) { + int32_t len = 0; + if (!readI32(len) || !checkLengthLimit(len, stringLimit()) || !readBytes(buf, len)) { + return -1; + } + return len; + } + + int32_t readListBegin(TType& etype) { + int32_t len; + uint8_t b = 0; + if (!readByte(b) || !readI32(len) || !checkLengthLimit(len, containerLimit())) { + return -1; + } + etype = static_cast(b); + return len; + } + + int32_t readMapBegin(TType& ktype, TType& vtype) { + int32_t len; + uint8_t k, v; + if (!readByte(k) || !readByte(v) || !readI32(len) || !checkLengthLimit(len, containerLimit())) { + return -1; + } + ktype = static_cast(k); + vtype = static_cast(v); + return len; + } + + bool readStructBegin() { return true; } + bool readStructEnd() { return true; } + + bool readFieldBegin(TType& type, int16_t& tag); + +#define SKIPBYTES(n) \ + do { \ + if (!readBytes(&dummy_buf_, (n))) { \ + return false; \ + } \ + return true; \ + } while (0) + + bool skipBool() { SKIPBYTES(1); } + bool skipByte() { SKIPBYTES(1); } + bool skipI16() { SKIPBYTES(2); } + bool skipI32() { SKIPBYTES(4); } + bool skipI64() { SKIPBYTES(8); } + bool skipDouble() { SKIPBYTES(8); } + bool skipString() { + int32_t len; + if (!readI32(len)) { + return false; + } + SKIPBYTES(len); + } +#undef SKIPBYTES + +private: + char* dummy_buf_; +}; +} +} +} +#endif // THRIFT_PY_BINARY_H diff --git a/src/jaegertracing/thrift/lib/py/src/ext/compact.cpp b/src/jaegertracing/thrift/lib/py/src/ext/compact.cpp new file mode 100644 index 000000000..15a99a077 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/ext/compact.cpp @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "ext/compact.h" + +namespace apache { +namespace thrift { +namespace py { + +const uint8_t CompactProtocol::TTypeToCType[] = { + CT_STOP, // T_STOP + 0, // unused + CT_BOOLEAN_TRUE, // T_BOOL + CT_BYTE, // T_BYTE + CT_DOUBLE, // T_DOUBLE + 0, // unused + CT_I16, // T_I16 + 0, // unused + CT_I32, // T_I32 + 0, // unused + CT_I64, // T_I64 + CT_BINARY, // T_STRING + CT_STRUCT, // T_STRUCT + CT_MAP, // T_MAP + CT_SET, // T_SET + CT_LIST, // T_LIST +}; + +bool CompactProtocol::readFieldBegin(TType& type, int16_t& tag) { + uint8_t b; + if (!readByte(b)) { + return false; + } + uint8_t ctype = b & 0xf; + type = getTType(ctype); + if (type == -1) { + return false; + } else if (type == T_STOP) { + tag = 0; + return true; + } + uint8_t diff = (b & 0xf0) >> 4; + if (diff) { + tag = readTags_.top() + diff; + } else if (!readI16(tag)) { + readTags_.top() = -1; + return false; + } + if (ctype == CT_BOOLEAN_FALSE || ctype == CT_BOOLEAN_TRUE) { + readBool_.exists = true; + readBool_.value = ctype == CT_BOOLEAN_TRUE; + } + readTags_.top() = tag; + return true; +} + +TType CompactProtocol::getTType(uint8_t type) { + switch (type) { + case T_STOP: + return T_STOP; + case CT_BOOLEAN_FALSE: + case CT_BOOLEAN_TRUE: + return T_BOOL; + case CT_BYTE: + return T_BYTE; + case CT_I16: + return T_I16; + case CT_I32: + return T_I32; + case CT_I64: + return T_I64; + case CT_DOUBLE: + return T_DOUBLE; + case CT_BINARY: + return T_STRING; + case CT_LIST: + return T_LIST; + case CT_SET: + return T_SET; + case CT_MAP: + return T_MAP; + case CT_STRUCT: + return T_STRUCT; + default: + PyErr_Format(PyExc_TypeError, "don't know what type: %d", type); + return static_cast(-1); + } +} +} +} +} diff --git a/src/jaegertracing/thrift/lib/py/src/ext/compact.h b/src/jaegertracing/thrift/lib/py/src/ext/compact.h new file mode 100644 index 000000000..a78d7a703 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/ext/compact.h @@ -0,0 +1,368 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef THRIFT_PY_COMPACT_H +#define THRIFT_PY_COMPACT_H + +#include +#include "ext/protocol.h" +#include "ext/endian.h" +#include +#include + +namespace apache { +namespace thrift { +namespace py { + +class CompactProtocol : public ProtocolBase { +public: + CompactProtocol() { readBool_.exists = false; } + + virtual ~CompactProtocol() {} + + void writeI8(int8_t val) { writeBuffer(reinterpret_cast(&val), 1); } + + void writeI16(int16_t val) { writeVarint(toZigZag(val)); } + + int writeI32(int32_t val) { return writeVarint(toZigZag(val)); } + + void writeI64(int64_t val) { writeVarint64(toZigZag64(val)); } + + void writeDouble(double dub) { + union { + double f; + int64_t t; + } transfer; + transfer.f = htolell(dub); + writeBuffer(reinterpret_cast(&transfer.t), sizeof(int64_t)); + } + + void writeBool(int v) { writeByte(static_cast(v ? CT_BOOLEAN_TRUE : CT_BOOLEAN_FALSE)); } + + void writeString(PyObject* value, int32_t len) { + writeVarint(len); + writeBuffer(PyBytes_AS_STRING(value), len); + } + + bool writeListBegin(PyObject* value, const SetListTypeArgs& args, int32_t len) { + int ctype = toCompactType(args.element_type); + if (len <= 14) { + writeByte(static_cast(len << 4 | ctype)); + } else { + writeByte(0xf0 | ctype); + writeVarint(len); + } + return true; + } + + bool writeMapBegin(PyObject* value, const MapTypeArgs& args, int32_t len) { + if (len == 0) { + writeByte(0); + return true; + } + int ctype = toCompactType(args.ktag) << 4 | toCompactType(args.vtag); + writeVarint(len); + writeByte(ctype); + return true; + } + + bool writeStructBegin() { + writeTags_.push(0); + return true; + } + bool writeStructEnd() { + writeTags_.pop(); + return true; + } + + bool writeField(PyObject* value, const StructItemSpec& spec) { + if (spec.type == T_BOOL) { + doWriteFieldBegin(spec, PyObject_IsTrue(value) ? CT_BOOLEAN_TRUE : CT_BOOLEAN_FALSE); + return true; + } else { + doWriteFieldBegin(spec, toCompactType(spec.type)); + return encodeValue(value, spec.type, spec.typeargs); + } + } + + void writeFieldStop() { writeByte(0); } + + bool readBool(bool& val) { + if (readBool_.exists) { + readBool_.exists = false; + val = readBool_.value; + return true; + } + char* buf; + if (!readBytes(&buf, 1)) { + return false; + } + val = buf[0] == CT_BOOLEAN_TRUE; + return true; + } + bool readI8(int8_t& val) { + char* buf; + if (!readBytes(&buf, 1)) { + return false; + } + val = buf[0]; + return true; + } + + bool readI16(int16_t& val) { + uint16_t uval; + if (readVarint(uval)) { + val = fromZigZag(uval); + return true; + } + return false; + } + + bool readI32(int32_t& val) { + uint32_t uval; + if (readVarint(uval)) { + val = fromZigZag(uval); + return true; + } + return false; + } + + bool readI64(int64_t& val) { + uint64_t uval; + if (readVarint(uval)) { + val = fromZigZag(uval); + return true; + } + return false; + } + + bool readDouble(double& val) { + union { + int64_t f; + double t; + } transfer; + + char* buf; + if (!readBytes(&buf, 8)) { + return false; + } + memcpy(&transfer.f, buf, sizeof(int64_t)); + transfer.f = letohll(transfer.f); + val = transfer.t; + return true; + } + + int32_t readString(char** buf) { + uint32_t len; + if (!readVarint(len) || !checkLengthLimit(len, stringLimit())) { + return -1; + } + if (len == 0) { + return 0; + } + if (!readBytes(buf, len)) { + return -1; + } + return len; + } + + int32_t readListBegin(TType& etype) { + uint8_t b; + if (!readByte(b)) { + return -1; + } + etype = getTType(b & 0xf); + if (etype == -1) { + return -1; + } + uint32_t len = (b >> 4) & 0xf; + if (len == 15 && !readVarint(len)) { + return -1; + } + if (!checkLengthLimit(len, containerLimit())) { + return -1; + } + return len; + } + + int32_t readMapBegin(TType& ktype, TType& vtype) { + uint32_t len; + if (!readVarint(len) || !checkLengthLimit(len, containerLimit())) { + return -1; + } + if (len != 0) { + uint8_t kvType; + if (!readByte(kvType)) { + return -1; + } + ktype = getTType(kvType >> 4); + vtype = getTType(kvType & 0xf); + if (ktype == -1 || vtype == -1) { + return -1; + } + } + return len; + } + + bool readStructBegin() { + readTags_.push(0); + return true; + } + bool readStructEnd() { + readTags_.pop(); + return true; + } + bool readFieldBegin(TType& type, int16_t& tag); + + bool skipBool() { + bool val; + return readBool(val); + } +#define SKIPBYTES(n) \ + do { \ + if (!readBytes(&dummy_buf_, (n))) { \ + return false; \ + } \ + return true; \ + } while (0) + bool skipByte() { SKIPBYTES(1); } + bool skipDouble() { SKIPBYTES(8); } + bool skipI16() { + int16_t val; + return readI16(val); + } + bool skipI32() { + int32_t val; + return readI32(val); + } + bool skipI64() { + int64_t val; + return readI64(val); + } + bool skipString() { + uint32_t len; + if (!readVarint(len)) { + return false; + } + SKIPBYTES(len); + } +#undef SKIPBYTES + +private: + enum Types { + CT_STOP = 0x00, + CT_BOOLEAN_TRUE = 0x01, + CT_BOOLEAN_FALSE = 0x02, + CT_BYTE = 0x03, + CT_I16 = 0x04, + CT_I32 = 0x05, + CT_I64 = 0x06, + CT_DOUBLE = 0x07, + CT_BINARY = 0x08, + CT_LIST = 0x09, + CT_SET = 0x0A, + CT_MAP = 0x0B, + CT_STRUCT = 0x0C + }; + + static const uint8_t TTypeToCType[]; + + TType getTType(uint8_t type); + + int toCompactType(TType type) { + int i = static_cast(type); + return i < 16 ? TTypeToCType[i] : -1; + } + + uint32_t toZigZag(int32_t val) { return (val >> 31) ^ (val << 1); } + + uint64_t toZigZag64(int64_t val) { return (val >> 63) ^ (val << 1); } + + int writeVarint(uint32_t val) { + int cnt = 1; + while (val & ~0x7fU) { + writeByte(static_cast((val & 0x7fU) | 0x80U)); + val >>= 7; + ++cnt; + } + writeByte(static_cast(val)); + return cnt; + } + + int writeVarint64(uint64_t val) { + int cnt = 1; + while (val & ~0x7fULL) { + writeByte(static_cast((val & 0x7fULL) | 0x80ULL)); + val >>= 7; + ++cnt; + } + writeByte(static_cast(val)); + return cnt; + } + + template + bool readVarint(T& result) { + uint8_t b; + T val = 0; + int shift = 0; + for (int i = 0; i < Max; ++i) { + if (!readByte(b)) { + return false; + } + if (b & 0x80) { + val |= static_cast(b & 0x7f) << shift; + } else { + val |= static_cast(b) << shift; + result = val; + return true; + } + shift += 7; + } + PyErr_Format(PyExc_OverflowError, "varint exceeded %d bytes", Max); + return false; + } + + template + S fromZigZag(U val) { + return (val >> 1) ^ static_cast(-static_cast(val & 1)); + } + + void doWriteFieldBegin(const StructItemSpec& spec, int ctype) { + int diff = spec.tag - writeTags_.top(); + if (diff > 0 && diff <= 15) { + writeByte(static_cast(diff << 4 | ctype)); + } else { + writeByte(static_cast(ctype)); + writeI16(spec.tag); + } + writeTags_.top() = spec.tag; + } + + std::stack writeTags_; + std::stack readTags_; + struct { + bool exists; + bool value; + } readBool_; + char* dummy_buf_; +}; +} +} +} +#endif // THRIFT_PY_COMPACT_H diff --git a/src/jaegertracing/thrift/lib/py/src/ext/endian.h b/src/jaegertracing/thrift/lib/py/src/ext/endian.h new file mode 100644 index 000000000..1660cbd98 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/ext/endian.h @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef THRIFT_PY_ENDIAN_H +#define THRIFT_PY_ENDIAN_H + +#include + +#ifndef _WIN32 +#include +#else +#include +#pragma comment(lib, "ws2_32.lib") +#define BIG_ENDIAN (4321) +#define LITTLE_ENDIAN (1234) +#define BYTE_ORDER LITTLE_ENDIAN +#define inline __inline +#endif + +/* Fix endianness issues on Solaris */ +#if defined(__SVR4) && defined(__sun) +#if defined(__i386) && !defined(__i386__) +#define __i386__ +#endif + +#ifndef BIG_ENDIAN +#define BIG_ENDIAN (4321) +#endif +#ifndef LITTLE_ENDIAN +#define LITTLE_ENDIAN (1234) +#endif + +/* I386 is LE, even on Solaris */ +#if !defined(BYTE_ORDER) && defined(__i386__) +#define BYTE_ORDER LITTLE_ENDIAN +#endif +#endif + +#ifndef __BYTE_ORDER +#if defined(BYTE_ORDER) && defined(LITTLE_ENDIAN) && defined(BIG_ENDIAN) +#define __BYTE_ORDER BYTE_ORDER +#define __LITTLE_ENDIAN LITTLE_ENDIAN +#define __BIG_ENDIAN BIG_ENDIAN +#else +#error "Cannot determine endianness" +#endif +#endif + +// Same comment as the enum. Sorry. +#if __BYTE_ORDER == __BIG_ENDIAN +#define ntohll(n) (n) +#define htonll(n) (n) +#if defined(__GNUC__) && defined(__GLIBC__) +#include +#define letohll(n) bswap_64(n) +#define htolell(n) bswap_64(n) +#else /* GNUC & GLIBC */ +#define letohll(n) ((((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32)) +#define htolell(n) ((((unsigned long long)htonl(n)) << 32) + htonl(n >> 32)) +#endif +#elif __BYTE_ORDER == __LITTLE_ENDIAN +#if defined(__GNUC__) && defined(__GLIBC__) +#include +#define ntohll(n) bswap_64(n) +#define htonll(n) bswap_64(n) +#elif defined(_MSC_VER) +#include +#define ntohll(n) _byteswap_uint64(n) +#define htonll(n) _byteswap_uint64(n) +#else /* GNUC & GLIBC */ +#define ntohll(n) ((((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32)) +#define htonll(n) ((((unsigned long long)htonl(n)) << 32) + htonl(n >> 32)) +#endif /* GNUC & GLIBC */ +#define letohll(n) (n) +#define htolell(n) (n) +#else /* __BYTE_ORDER */ +#error "Can't define htonll or ntohll!" +#endif + +#endif // THRIFT_PY_ENDIAN_H diff --git a/src/jaegertracing/thrift/lib/py/src/ext/module.cpp b/src/jaegertracing/thrift/lib/py/src/ext/module.cpp new file mode 100644 index 000000000..7158b8fdf --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/ext/module.cpp @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include "types.h" +#include "binary.h" +#include "compact.h" +#include +#include + +// TODO(dreiss): defval appears to be unused. Look into removing it. +// TODO(dreiss): Make parse_spec_args recursive, and cache the output +// permanently in the object. (Malloc and orphan.) +// TODO(dreiss): Why do we need cStringIO for reading, why not just char*? +// Can cStringIO let us work with a BufferedTransport? +// TODO(dreiss): Don't ignore the rv from cwrite (maybe). + +// Doing a benchmark shows that interning actually makes a difference, amazingly. + +/** Pointer to interned string to speed up attribute lookup. */ +PyObject* INTERN_STRING(TFrozenDict); +PyObject* INTERN_STRING(cstringio_buf); +PyObject* INTERN_STRING(cstringio_refill); +static PyObject* INTERN_STRING(string_length_limit); +static PyObject* INTERN_STRING(container_length_limit); +static PyObject* INTERN_STRING(trans); + +namespace apache { +namespace thrift { +namespace py { + +template +static PyObject* encode_impl(PyObject* args) { + if (!args) + return NULL; + + PyObject* enc_obj = NULL; + PyObject* type_args = NULL; + if (!PyArg_ParseTuple(args, "OO", &enc_obj, &type_args)) { + return NULL; + } + if (!enc_obj || !type_args) { + return NULL; + } + + T protocol; + if (!protocol.prepareEncodeBuffer() || !protocol.encodeValue(enc_obj, T_STRUCT, type_args)) { + return NULL; + } + + return protocol.getEncodedValue(); +} + +static inline long as_long_then_delete(PyObject* value, long default_value) { + ScopedPyObject scope(value); + long v = PyInt_AsLong(value); + if (INT_CONV_ERROR_OCCURRED(v)) { + PyErr_Clear(); + return default_value; + } + return v; +} + +template +static PyObject* decode_impl(PyObject* args) { + PyObject* output_obj = NULL; + PyObject* oprot = NULL; + PyObject* typeargs = NULL; + if (!PyArg_ParseTuple(args, "OOO", &output_obj, &oprot, &typeargs)) { + return NULL; + } + + T protocol; + int32_t default_limit = (std::numeric_limits::max)(); + protocol.setStringLengthLimit( + as_long_then_delete(PyObject_GetAttr(oprot, INTERN_STRING(string_length_limit)), + default_limit)); + protocol.setContainerLengthLimit( + as_long_then_delete(PyObject_GetAttr(oprot, INTERN_STRING(container_length_limit)), + default_limit)); + ScopedPyObject transport(PyObject_GetAttr(oprot, INTERN_STRING(trans))); + if (!transport) { + return NULL; + } + + StructTypeArgs parsedargs; + if (!parse_struct_args(&parsedargs, typeargs)) { + return NULL; + } + + if (!protocol.prepareDecodeBufferFromTransport(transport.get())) { + return NULL; + } + + return protocol.readStruct(output_obj, parsedargs.klass, parsedargs.spec); +} +} +} +} + +using namespace apache::thrift::py; + +/* -- PYTHON MODULE SETUP STUFF --- */ + +extern "C" { + +static PyObject* encode_binary(PyObject*, PyObject* args) { + return encode_impl(args); +} + +static PyObject* decode_binary(PyObject*, PyObject* args) { + return decode_impl(args); +} + +static PyObject* encode_compact(PyObject*, PyObject* args) { + return encode_impl(args); +} + +static PyObject* decode_compact(PyObject*, PyObject* args) { + return decode_impl(args); +} + +static PyMethodDef ThriftFastBinaryMethods[] = { + {"encode_binary", encode_binary, METH_VARARGS, ""}, + {"decode_binary", decode_binary, METH_VARARGS, ""}, + {"encode_compact", encode_compact, METH_VARARGS, ""}, + {"decode_compact", decode_compact, METH_VARARGS, ""}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +#if PY_MAJOR_VERSION >= 3 + +static struct PyModuleDef ThriftFastBinaryDef = {PyModuleDef_HEAD_INIT, + "thrift.protocol.fastbinary", + NULL, + 0, + ThriftFastBinaryMethods, + NULL, + NULL, + NULL, + NULL}; + +#define INITERROR return NULL; + +PyObject* PyInit_fastbinary() { + +#else + +#define INITERROR return; + +void initfastbinary() { + + PycString_IMPORT; + if (PycStringIO == NULL) + INITERROR + +#endif + +#define INIT_INTERN_STRING(value) \ + do { \ + INTERN_STRING(value) = PyString_InternFromString(#value); \ + if (!INTERN_STRING(value)) \ + INITERROR \ + } while (0) + + INIT_INTERN_STRING(TFrozenDict); + INIT_INTERN_STRING(cstringio_buf); + INIT_INTERN_STRING(cstringio_refill); + INIT_INTERN_STRING(string_length_limit); + INIT_INTERN_STRING(container_length_limit); + INIT_INTERN_STRING(trans); +#undef INIT_INTERN_STRING + + PyObject* module = +#if PY_MAJOR_VERSION >= 3 + PyModule_Create(&ThriftFastBinaryDef); +#else + Py_InitModule("thrift.protocol.fastbinary", ThriftFastBinaryMethods); +#endif + if (module == NULL) + INITERROR; + +#if PY_MAJOR_VERSION >= 3 + return module; +#endif +} +} diff --git a/src/jaegertracing/thrift/lib/py/src/ext/protocol.h b/src/jaegertracing/thrift/lib/py/src/ext/protocol.h new file mode 100644 index 000000000..521b7ee92 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/ext/protocol.h @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef THRIFT_PY_PROTOCOL_H +#define THRIFT_PY_PROTOCOL_H + +#include "ext/types.h" +#include +#include + +namespace apache { +namespace thrift { +namespace py { + +template +class ProtocolBase { + +public: + ProtocolBase() + : stringLimit_((std::numeric_limits::max)()), + containerLimit_((std::numeric_limits::max)()), + output_(NULL) {} + inline virtual ~ProtocolBase(); + + bool prepareDecodeBufferFromTransport(PyObject* trans); + + PyObject* readStruct(PyObject* output, PyObject* klass, PyObject* spec_seq); + + bool prepareEncodeBuffer(); + + bool encodeValue(PyObject* value, TType type, PyObject* typeargs); + + PyObject* getEncodedValue(); + + long stringLimit() const { return stringLimit_; } + void setStringLengthLimit(long limit) { stringLimit_ = limit; } + + long containerLimit() const { return containerLimit_; } + void setContainerLengthLimit(long limit) { containerLimit_ = limit; } + +protected: + bool readBytes(char** output, int len); + + bool readByte(uint8_t& val) { + char* buf; + if (!readBytes(&buf, 1)) { + return false; + } + val = static_cast(buf[0]); + return true; + } + + bool writeBuffer(char* data, size_t len); + + void writeByte(uint8_t val) { writeBuffer(reinterpret_cast(&val), 1); } + + PyObject* decodeValue(TType type, PyObject* typeargs); + + bool skip(TType type); + + inline bool checkType(TType got, TType expected); + inline bool checkLengthLimit(int32_t len, long limit); + + inline bool isUtf8(PyObject* typeargs); + +private: + Impl* impl() { return static_cast(this); } + + long stringLimit_; + long containerLimit_; + EncodeBuffer* output_; + DecodeBuffer input_; +}; +} +} +} + +#include "ext/protocol.tcc" + +#endif // THRIFT_PY_PROTOCOL_H diff --git a/src/jaegertracing/thrift/lib/py/src/ext/protocol.tcc b/src/jaegertracing/thrift/lib/py/src/ext/protocol.tcc new file mode 100644 index 000000000..e15df7ea0 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/ext/protocol.tcc @@ -0,0 +1,913 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef THRIFT_PY_PROTOCOL_TCC +#define THRIFT_PY_PROTOCOL_TCC + +#include + +#define CHECK_RANGE(v, min, max) (((v) <= (max)) && ((v) >= (min))) +#define INIT_OUTBUF_SIZE 128 + +#if PY_MAJOR_VERSION < 3 +#include +#else +#include +#endif + +namespace apache { +namespace thrift { +namespace py { + +#if PY_MAJOR_VERSION < 3 + +namespace detail { + +inline bool input_check(PyObject* input) { + return PycStringIO_InputCheck(input); +} + +inline EncodeBuffer* new_encode_buffer(size_t size) { + if (!PycStringIO) { + PycString_IMPORT; + } + if (!PycStringIO) { + return NULL; + } + return PycStringIO->NewOutput(size); +} + +inline int read_buffer(PyObject* buf, char** output, int len) { + if (!PycStringIO) { + PycString_IMPORT; + } + if (!PycStringIO) { + PyErr_SetString(PyExc_ImportError, "failed to import native cStringIO"); + return -1; + } + return PycStringIO->cread(buf, output, len); +} +} + +template +inline ProtocolBase::~ProtocolBase() { + if (output_) { + Py_CLEAR(output_); + } +} + +template +inline bool ProtocolBase::isUtf8(PyObject* typeargs) { + return PyString_Check(typeargs) && !strncmp(PyString_AS_STRING(typeargs), "UTF8", 4); +} + +template +PyObject* ProtocolBase::getEncodedValue() { + if (!PycStringIO) { + PycString_IMPORT; + } + if (!PycStringIO) { + return NULL; + } + return PycStringIO->cgetvalue(output_); +} + +template +inline bool ProtocolBase::writeBuffer(char* data, size_t size) { + if (!PycStringIO) { + PycString_IMPORT; + } + if (!PycStringIO) { + PyErr_SetString(PyExc_ImportError, "failed to import native cStringIO"); + return false; + } + int len = PycStringIO->cwrite(output_, data, size); + if (len < 0) { + PyErr_SetString(PyExc_IOError, "failed to write to cStringIO object"); + return false; + } + if (static_cast(len) != size) { + PyErr_Format(PyExc_EOFError, "write length mismatch: expected %lu got %d", size, len); + return false; + } + return true; +} + +#else + +namespace detail { + +inline bool input_check(PyObject* input) { + // TODO: Check for BytesIO type + return true; +} + +inline EncodeBuffer* new_encode_buffer(size_t size) { + EncodeBuffer* buffer = new EncodeBuffer; + buffer->buf.reserve(size); + buffer->pos = 0; + return buffer; +} + +struct bytesio { + PyObject_HEAD +#if PY_MINOR_VERSION < 5 + char* buf; +#else + PyObject* buf; +#endif + Py_ssize_t pos; + Py_ssize_t string_size; +}; + +inline int read_buffer(PyObject* buf, char** output, int len) { + bytesio* buf2 = reinterpret_cast(buf); +#if PY_MINOR_VERSION < 5 + *output = buf2->buf + buf2->pos; +#else + *output = PyBytes_AS_STRING(buf2->buf) + buf2->pos; +#endif + Py_ssize_t pos0 = buf2->pos; + buf2->pos = (std::min)(buf2->pos + static_cast(len), buf2->string_size); + return static_cast(buf2->pos - pos0); +} +} + +template +inline ProtocolBase::~ProtocolBase() { + if (output_) { + delete output_; + } +} + +template +inline bool ProtocolBase::isUtf8(PyObject* typeargs) { + // while condition for py2 is "arg == 'UTF8'", it should be "arg != 'BINARY'" for py3. + // HACK: check the length and don't bother reading the value + return !PyUnicode_Check(typeargs) || PyUnicode_GET_LENGTH(typeargs) != 6; +} + +template +PyObject* ProtocolBase::getEncodedValue() { + return PyBytes_FromStringAndSize(output_->buf.data(), output_->buf.size()); +} + +template +inline bool ProtocolBase::writeBuffer(char* data, size_t size) { + size_t need = size + output_->pos; + if (output_->buf.capacity() < need) { + try { + output_->buf.reserve(need); + } catch (std::bad_alloc& ex) { + PyErr_SetString(PyExc_MemoryError, "Failed to allocate write buffer"); + return false; + } + } + std::copy(data, data + size, std::back_inserter(output_->buf)); + return true; +} + +#endif + +namespace detail { + +#define DECLARE_OP_SCOPE(name, op) \ + template \ + struct name##Scope { \ + Impl* impl; \ + bool valid; \ + name##Scope(Impl* thiz) : impl(thiz), valid(impl->op##Begin()) {} \ + ~name##Scope() { \ + if (valid) \ + impl->op##End(); \ + } \ + operator bool() { return valid; } \ + }; \ + template class T> \ + name##Scope op##Scope(T* thiz) { \ + return name##Scope(static_cast(thiz)); \ + } +DECLARE_OP_SCOPE(WriteStruct, writeStruct) +DECLARE_OP_SCOPE(ReadStruct, readStruct) +#undef DECLARE_OP_SCOPE + +inline bool check_ssize_t_32(Py_ssize_t len) { + // error from getting the int + if (INT_CONV_ERROR_OCCURRED(len)) { + return false; + } + if (!CHECK_RANGE(len, 0, (std::numeric_limits::max)())) { + PyErr_SetString(PyExc_OverflowError, "size out of range: exceeded INT32_MAX"); + return false; + } + return true; +} +} + +template +bool parse_pyint(PyObject* o, T* ret, int32_t min, int32_t max) { + long val = PyInt_AsLong(o); + + if (INT_CONV_ERROR_OCCURRED(val)) { + return false; + } + if (!CHECK_RANGE(val, min, max)) { + PyErr_SetString(PyExc_OverflowError, "int out of range"); + return false; + } + + *ret = static_cast(val); + return true; +} + +template +inline bool ProtocolBase::checkType(TType got, TType expected) { + if (expected != got) { + PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field"); + return false; + } + return true; +} + +template +bool ProtocolBase::checkLengthLimit(int32_t len, long limit) { + if (len < 0) { + PyErr_Format(PyExc_OverflowError, "negative length: %ld", limit); + return false; + } + if (len > limit) { + PyErr_Format(PyExc_OverflowError, "size exceeded specified limit: %ld", limit); + return false; + } + return true; +} + +template +bool ProtocolBase::readBytes(char** output, int len) { + if (len < 0) { + PyErr_Format(PyExc_ValueError, "attempted to read negative length: %d", len); + return false; + } + // TODO(dreiss): Don't fear the malloc. Think about taking a copy of + // the partial read instead of forcing the transport + // to prepend it to its buffer. + + int rlen = detail::read_buffer(input_.stringiobuf.get(), output, len); + + if (rlen == len) { + return true; + } else if (rlen == -1) { + return false; + } else { + // using building functions as this is a rare codepath + ScopedPyObject newiobuf(PyObject_CallFunction(input_.refill_callable.get(), refill_signature, + *output, rlen, len, NULL)); + if (!newiobuf) { + return false; + } + + // must do this *AFTER* the call so that we don't deref the io buffer + input_.stringiobuf.reset(newiobuf.release()); + + rlen = detail::read_buffer(input_.stringiobuf.get(), output, len); + + if (rlen == len) { + return true; + } else if (rlen == -1) { + return false; + } else { + // TODO(dreiss): This could be a valid code path for big binary blobs. + PyErr_SetString(PyExc_TypeError, "refill claimed to have refilled the buffer, but didn't!!"); + return false; + } + } +} + +template +bool ProtocolBase::prepareDecodeBufferFromTransport(PyObject* trans) { + if (input_.stringiobuf) { + PyErr_SetString(PyExc_ValueError, "decode buffer is already initialized"); + return false; + } + + ScopedPyObject stringiobuf(PyObject_GetAttr(trans, INTERN_STRING(cstringio_buf))); + if (!stringiobuf) { + return false; + } + if (!detail::input_check(stringiobuf.get())) { + PyErr_SetString(PyExc_TypeError, "expecting stringio input_"); + return false; + } + + ScopedPyObject refill_callable(PyObject_GetAttr(trans, INTERN_STRING(cstringio_refill))); + if (!refill_callable) { + return false; + } + if (!PyCallable_Check(refill_callable.get())) { + PyErr_SetString(PyExc_TypeError, "expecting callable"); + return false; + } + + input_.stringiobuf.swap(stringiobuf); + input_.refill_callable.swap(refill_callable); + return true; +} + +template +bool ProtocolBase::prepareEncodeBuffer() { + output_ = detail::new_encode_buffer(INIT_OUTBUF_SIZE); + return output_ != NULL; +} + +template +bool ProtocolBase::encodeValue(PyObject* value, TType type, PyObject* typeargs) { + /* + * Refcounting Strategy: + * + * We assume that elements of the thrift_spec tuple are not going to be + * mutated, so we don't ref count those at all. Other than that, we try to + * keep a reference to all the user-created objects while we work with them. + * encodeValue assumes that a reference is already held. The *caller* is + * responsible for handling references + */ + + switch (type) { + + case T_BOOL: { + int v = PyObject_IsTrue(value); + if (v == -1) { + return false; + } + impl()->writeBool(v); + return true; + } + case T_I08: { + int8_t val; + + if (!parse_pyint(value, &val, (std::numeric_limits::min)(), + (std::numeric_limits::max)())) { + return false; + } + + impl()->writeI8(val); + return true; + } + case T_I16: { + int16_t val; + + if (!parse_pyint(value, &val, (std::numeric_limits::min)(), + (std::numeric_limits::max)())) { + return false; + } + + impl()->writeI16(val); + return true; + } + case T_I32: { + int32_t val; + + if (!parse_pyint(value, &val, (std::numeric_limits::min)(), + (std::numeric_limits::max)())) { + return false; + } + + impl()->writeI32(val); + return true; + } + case T_I64: { + int64_t nval = PyLong_AsLongLong(value); + + if (INT_CONV_ERROR_OCCURRED(nval)) { + return false; + } + + if (!CHECK_RANGE(nval, (std::numeric_limits::min)(), + (std::numeric_limits::max)())) { + PyErr_SetString(PyExc_OverflowError, "int out of range"); + return false; + } + + impl()->writeI64(nval); + return true; + } + + case T_DOUBLE: { + double nval = PyFloat_AsDouble(value); + if (nval == -1.0 && PyErr_Occurred()) { + return false; + } + + impl()->writeDouble(nval); + return true; + } + + case T_STRING: { + ScopedPyObject nval; + + if (PyUnicode_Check(value)) { + nval.reset(PyUnicode_AsUTF8String(value)); + if (!nval) { + return false; + } + } else { + Py_INCREF(value); + nval.reset(value); + } + + Py_ssize_t len = PyBytes_Size(nval.get()); + if (!detail::check_ssize_t_32(len)) { + return false; + } + + impl()->writeString(nval.get(), static_cast(len)); + return true; + } + + case T_LIST: + case T_SET: { + SetListTypeArgs parsedargs; + if (!parse_set_list_args(&parsedargs, typeargs)) { + return false; + } + + Py_ssize_t len = PyObject_Length(value); + if (!detail::check_ssize_t_32(len)) { + return false; + } + + if (!impl()->writeListBegin(value, parsedargs, static_cast(len)) || PyErr_Occurred()) { + return false; + } + ScopedPyObject iterator(PyObject_GetIter(value)); + if (!iterator) { + return false; + } + + while (PyObject* rawItem = PyIter_Next(iterator.get())) { + ScopedPyObject item(rawItem); + if (!encodeValue(item.get(), parsedargs.element_type, parsedargs.typeargs)) { + return false; + } + } + + return true; + } + + case T_MAP: { + Py_ssize_t len = PyDict_Size(value); + if (!detail::check_ssize_t_32(len)) { + return false; + } + + MapTypeArgs parsedargs; + if (!parse_map_args(&parsedargs, typeargs)) { + return false; + } + + if (!impl()->writeMapBegin(value, parsedargs, static_cast(len)) || PyErr_Occurred()) { + return false; + } + Py_ssize_t pos = 0; + PyObject* k = NULL; + PyObject* v = NULL; + // TODO(bmaurer): should support any mapping, not just dicts + while (PyDict_Next(value, &pos, &k, &v)) { + if (!encodeValue(k, parsedargs.ktag, parsedargs.ktypeargs) + || !encodeValue(v, parsedargs.vtag, parsedargs.vtypeargs)) { + return false; + } + } + return true; + } + + case T_STRUCT: { + StructTypeArgs parsedargs; + if (!parse_struct_args(&parsedargs, typeargs)) { + return false; + } + + Py_ssize_t nspec = PyTuple_Size(parsedargs.spec); + if (nspec == -1) { + PyErr_SetString(PyExc_TypeError, "spec is not a tuple"); + return false; + } + + detail::WriteStructScope scope = detail::writeStructScope(this); + if (!scope) { + return false; + } + for (Py_ssize_t i = 0; i < nspec; i++) { + PyObject* spec_tuple = PyTuple_GET_ITEM(parsedargs.spec, i); + if (spec_tuple == Py_None) { + continue; + } + + StructItemSpec parsedspec; + if (!parse_struct_item_spec(&parsedspec, spec_tuple)) { + return false; + } + + ScopedPyObject instval(PyObject_GetAttr(value, parsedspec.attrname)); + + if (!instval) { + return false; + } + + if (instval.get() == Py_None) { + continue; + } + + bool res = impl()->writeField(instval.get(), parsedspec); + if (!res) { + return false; + } + } + impl()->writeFieldStop(); + return true; + } + + case T_STOP: + case T_VOID: + case T_UTF16: + case T_UTF8: + case T_U64: + default: + PyErr_Format(PyExc_TypeError, "Unexpected TType for encodeValue: %d", type); + return false; + } + + return true; +} + +template +bool ProtocolBase::skip(TType type) { + switch (type) { + case T_BOOL: + return impl()->skipBool(); + case T_I08: + return impl()->skipByte(); + case T_I16: + return impl()->skipI16(); + case T_I32: + return impl()->skipI32(); + case T_I64: + return impl()->skipI64(); + case T_DOUBLE: + return impl()->skipDouble(); + + case T_STRING: { + return impl()->skipString(); + } + + case T_LIST: + case T_SET: { + TType etype = T_STOP; + int32_t len = impl()->readListBegin(etype); + if (len < 0) { + return false; + } + for (int32_t i = 0; i < len; i++) { + if (!skip(etype)) { + return false; + } + } + return true; + } + + case T_MAP: { + TType ktype = T_STOP; + TType vtype = T_STOP; + int32_t len = impl()->readMapBegin(ktype, vtype); + if (len < 0) { + return false; + } + for (int32_t i = 0; i < len; i++) { + if (!skip(ktype) || !skip(vtype)) { + return false; + } + } + return true; + } + + case T_STRUCT: { + detail::ReadStructScope scope = detail::readStructScope(this); + if (!scope) { + return false; + } + while (true) { + TType type = T_STOP; + int16_t tag; + if (!impl()->readFieldBegin(type, tag)) { + return false; + } + if (type == T_STOP) { + return true; + } + if (!skip(type)) { + return false; + } + } + return true; + } + + case T_STOP: + case T_VOID: + case T_UTF16: + case T_UTF8: + case T_U64: + default: + PyErr_Format(PyExc_TypeError, "Unexpected TType for skip: %d", type); + return false; + } + + return true; +} + +// Returns a new reference. +template +PyObject* ProtocolBase::decodeValue(TType type, PyObject* typeargs) { + switch (type) { + + case T_BOOL: { + bool v = 0; + if (!impl()->readBool(v)) { + return NULL; + } + if (v) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + } + case T_I08: { + int8_t v = 0; + if (!impl()->readI8(v)) { + return NULL; + } + return PyInt_FromLong(v); + } + case T_I16: { + int16_t v = 0; + if (!impl()->readI16(v)) { + return NULL; + } + return PyInt_FromLong(v); + } + case T_I32: { + int32_t v = 0; + if (!impl()->readI32(v)) { + return NULL; + } + return PyInt_FromLong(v); + } + + case T_I64: { + int64_t v = 0; + if (!impl()->readI64(v)) { + return NULL; + } + // TODO(dreiss): Find out if we can take this fastpath always when + // sizeof(long) == sizeof(long long). + if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) { + return PyInt_FromLong((long)v); + } + return PyLong_FromLongLong(v); + } + + case T_DOUBLE: { + double v = 0.0; + if (!impl()->readDouble(v)) { + return NULL; + } + return PyFloat_FromDouble(v); + } + + case T_STRING: { + char* buf = NULL; + int len = impl()->readString(&buf); + if (len < 0) { + return NULL; + } + if (isUtf8(typeargs)) { + return PyUnicode_DecodeUTF8(buf, len, 0); + } else { + return PyBytes_FromStringAndSize(buf, len); + } + } + + case T_LIST: + case T_SET: { + SetListTypeArgs parsedargs; + if (!parse_set_list_args(&parsedargs, typeargs)) { + return NULL; + } + + TType etype = T_STOP; + int32_t len = impl()->readListBegin(etype); + if (len < 0) { + return NULL; + } + if (len > 0 && !checkType(etype, parsedargs.element_type)) { + return NULL; + } + + bool use_tuple = type == T_LIST && parsedargs.immutable; + ScopedPyObject ret(use_tuple ? PyTuple_New(len) : PyList_New(len)); + if (!ret) { + return NULL; + } + + for (int i = 0; i < len; i++) { + PyObject* item = decodeValue(etype, parsedargs.typeargs); + if (!item) { + return NULL; + } + if (use_tuple) { + PyTuple_SET_ITEM(ret.get(), i, item); + } else { + PyList_SET_ITEM(ret.get(), i, item); + } + } + + // TODO(dreiss): Consider biting the bullet and making two separate cases + // for list and set, avoiding this post facto conversion. + if (type == T_SET) { + PyObject* setret; + setret = parsedargs.immutable ? PyFrozenSet_New(ret.get()) : PySet_New(ret.get()); + return setret; + } + return ret.release(); + } + + case T_MAP: { + MapTypeArgs parsedargs; + if (!parse_map_args(&parsedargs, typeargs)) { + return NULL; + } + + TType ktype = T_STOP; + TType vtype = T_STOP; + uint32_t len = impl()->readMapBegin(ktype, vtype); + if (len > 0 && (!checkType(ktype, parsedargs.ktag) || !checkType(vtype, parsedargs.vtag))) { + return NULL; + } + + ScopedPyObject ret(PyDict_New()); + if (!ret) { + return NULL; + } + + for (uint32_t i = 0; i < len; i++) { + ScopedPyObject k(decodeValue(ktype, parsedargs.ktypeargs)); + if (!k) { + return NULL; + } + ScopedPyObject v(decodeValue(vtype, parsedargs.vtypeargs)); + if (!v) { + return NULL; + } + if (PyDict_SetItem(ret.get(), k.get(), v.get()) == -1) { + return NULL; + } + } + + if (parsedargs.immutable) { + if (!ThriftModule) { + ThriftModule = PyImport_ImportModule("thrift.Thrift"); + } + if (!ThriftModule) { + return NULL; + } + + ScopedPyObject cls(PyObject_GetAttr(ThriftModule, INTERN_STRING(TFrozenDict))); + if (!cls) { + return NULL; + } + + ScopedPyObject arg(PyTuple_New(1)); + PyTuple_SET_ITEM(arg.get(), 0, ret.release()); + ret.reset(PyObject_CallObject(cls.get(), arg.get())); + } + + return ret.release(); + } + + case T_STRUCT: { + StructTypeArgs parsedargs; + if (!parse_struct_args(&parsedargs, typeargs)) { + return NULL; + } + return readStruct(Py_None, parsedargs.klass, parsedargs.spec); + } + + case T_STOP: + case T_VOID: + case T_UTF16: + case T_UTF8: + case T_U64: + default: + PyErr_Format(PyExc_TypeError, "Unexpected TType for decodeValue: %d", type); + return NULL; + } +} + +template +PyObject* ProtocolBase::readStruct(PyObject* output, PyObject* klass, PyObject* spec_seq) { + int spec_seq_len = PyTuple_Size(spec_seq); + bool immutable = output == Py_None; + ScopedPyObject kwargs; + if (spec_seq_len == -1) { + return NULL; + } + + if (immutable) { + kwargs.reset(PyDict_New()); + if (!kwargs) { + PyErr_SetString(PyExc_TypeError, "failed to prepare kwargument storage"); + return NULL; + } + } + + detail::ReadStructScope scope = detail::readStructScope(this); + if (!scope) { + return NULL; + } + while (true) { + TType type = T_STOP; + int16_t tag; + if (!impl()->readFieldBegin(type, tag)) { + return NULL; + } + if (type == T_STOP) { + break; + } + if (tag < 0 || tag >= spec_seq_len) { + if (!skip(type)) { + PyErr_SetString(PyExc_TypeError, "Error while skipping unknown field"); + return NULL; + } + continue; + } + + PyObject* item_spec = PyTuple_GET_ITEM(spec_seq, tag); + if (item_spec == Py_None) { + if (!skip(type)) { + PyErr_SetString(PyExc_TypeError, "Error while skipping unknown field"); + return NULL; + } + continue; + } + StructItemSpec parsedspec; + if (!parse_struct_item_spec(&parsedspec, item_spec)) { + return NULL; + } + if (parsedspec.type != type) { + if (!skip(type)) { + PyErr_Format(PyExc_TypeError, "struct field had wrong type: expected %d but got %d", + parsedspec.type, type); + return NULL; + } + continue; + } + + ScopedPyObject fieldval(decodeValue(parsedspec.type, parsedspec.typeargs)); + if (!fieldval) { + return NULL; + } + + if ((immutable && PyDict_SetItem(kwargs.get(), parsedspec.attrname, fieldval.get()) == -1) + || (!immutable && PyObject_SetAttr(output, parsedspec.attrname, fieldval.get()) == -1)) { + return NULL; + } + } + if (immutable) { + ScopedPyObject args(PyTuple_New(0)); + if (!args) { + PyErr_SetString(PyExc_TypeError, "failed to prepare argument storage"); + return NULL; + } + return PyObject_Call(klass, args.get(), kwargs.get()); + } + Py_INCREF(output); + return output; +} +} +} +} +#endif // THRIFT_PY_PROTOCOL_H diff --git a/src/jaegertracing/thrift/lib/py/src/ext/types.cpp b/src/jaegertracing/thrift/lib/py/src/ext/types.cpp new file mode 100644 index 000000000..68443fbe8 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/ext/types.cpp @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "ext/types.h" +#include "ext/protocol.h" + +namespace apache { +namespace thrift { +namespace py { + +PyObject* ThriftModule = NULL; + +#if PY_MAJOR_VERSION < 3 +char refill_signature[] = {'s', '#', 'i'}; +#else +const char* refill_signature = "y#i"; +#endif + +bool parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) { + // i'd like to use ParseArgs here, but it seems to be a bottleneck. + if (PyTuple_Size(spec_tuple) != 5) { + PyErr_Format(PyExc_TypeError, "expecting 5 arguments for spec tuple but got %d", + static_cast(PyTuple_Size(spec_tuple))); + return false; + } + + dest->tag = static_cast(PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 0))); + if (INT_CONV_ERROR_OCCURRED(dest->tag)) { + return false; + } + + dest->type = static_cast(PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 1))); + if (INT_CONV_ERROR_OCCURRED(dest->type)) { + return false; + } + + dest->attrname = PyTuple_GET_ITEM(spec_tuple, 2); + dest->typeargs = PyTuple_GET_ITEM(spec_tuple, 3); + dest->defval = PyTuple_GET_ITEM(spec_tuple, 4); + return true; +} + +bool parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 3) { + PyErr_SetString(PyExc_TypeError, "expecting tuple of size 3 for list/set type args"); + return false; + } + + dest->element_type = static_cast(PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0))); + if (INT_CONV_ERROR_OCCURRED(dest->element_type)) { + return false; + } + + dest->typeargs = PyTuple_GET_ITEM(typeargs, 1); + + dest->immutable = Py_True == PyTuple_GET_ITEM(typeargs, 2); + + return true; +} + +bool parse_map_args(MapTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 5) { + PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for typeargs to map"); + return false; + } + + dest->ktag = static_cast(PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0))); + if (INT_CONV_ERROR_OCCURRED(dest->ktag)) { + return false; + } + + dest->vtag = static_cast(PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 2))); + if (INT_CONV_ERROR_OCCURRED(dest->vtag)) { + return false; + } + + dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1); + dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3); + dest->immutable = Py_True == PyTuple_GET_ITEM(typeargs, 4); + + return true; +} + +bool parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) { + if (PyList_Size(typeargs) != 2) { + PyErr_SetString(PyExc_TypeError, "expecting list of size 2 for struct args"); + return false; + } + + dest->klass = PyList_GET_ITEM(typeargs, 0); + dest->spec = PyList_GET_ITEM(typeargs, 1); + + return true; +} +} +} +} diff --git a/src/jaegertracing/thrift/lib/py/src/ext/types.h b/src/jaegertracing/thrift/lib/py/src/ext/types.h new file mode 100644 index 000000000..5cd8dda9e --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/ext/types.h @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef THRIFT_PY_TYPES_H +#define THRIFT_PY_TYPES_H + +#include + +#ifdef _MSC_VER +#define __STDC_FORMAT_MACROS +#define __STDC_LIMIT_MACROS +#endif +#include + +#if PY_MAJOR_VERSION >= 3 + +#include + +// TODO: better macros +#define PyInt_AsLong(v) PyLong_AsLong(v) +#define PyInt_FromLong(v) PyLong_FromLong(v) + +#define PyString_InternFromString(v) PyUnicode_InternFromString(v) + +#endif + +#define INTERN_STRING(value) _intern_##value + +#define INT_CONV_ERROR_OCCURRED(v) (((v) == -1) && PyErr_Occurred()) + +extern "C" { +extern PyObject* INTERN_STRING(TFrozenDict); +extern PyObject* INTERN_STRING(cstringio_buf); +extern PyObject* INTERN_STRING(cstringio_refill); +} + +namespace apache { +namespace thrift { +namespace py { + +extern PyObject* ThriftModule; + +// Stolen out of TProtocol.h. +// It would be a huge pain to have both get this from one place. +enum TType { + T_INVALID = -1, + T_STOP = 0, + T_VOID = 1, + T_BOOL = 2, + T_BYTE = 3, + T_I08 = 3, + T_I16 = 6, + T_I32 = 8, + T_U64 = 9, + T_I64 = 10, + T_DOUBLE = 4, + T_STRING = 11, + T_UTF7 = 11, + T_STRUCT = 12, + T_MAP = 13, + T_SET = 14, + T_LIST = 15, + T_UTF8 = 16, + T_UTF16 = 17 +}; + +// replace with unique_ptr when we're OK with C++11 +class ScopedPyObject { +public: + ScopedPyObject() : obj_(NULL) {} + explicit ScopedPyObject(PyObject* py_object) : obj_(py_object) {} + ~ScopedPyObject() { + if (obj_) + Py_DECREF(obj_); + } + PyObject* get() throw() { return obj_; } + operator bool() { return obj_; } + void reset(PyObject* py_object) throw() { + if (obj_) + Py_DECREF(obj_); + obj_ = py_object; + } + PyObject* release() throw() { + PyObject* tmp = obj_; + obj_ = NULL; + return tmp; + } + void swap(ScopedPyObject& other) throw() { + ScopedPyObject tmp(other.release()); + other.reset(release()); + reset(tmp.release()); + } + +private: + ScopedPyObject(const ScopedPyObject&) {} + ScopedPyObject& operator=(const ScopedPyObject&) { return *this; } + + PyObject* obj_; +}; + +/** + * A cache of the two key attributes of a CReadableTransport, + * so we don't have to keep calling PyObject_GetAttr. + */ +struct DecodeBuffer { + ScopedPyObject stringiobuf; + ScopedPyObject refill_callable; +}; + +#if PY_MAJOR_VERSION < 3 +extern char refill_signature[3]; +typedef PyObject EncodeBuffer; +#else +extern const char* refill_signature; +struct EncodeBuffer { + std::vector buf; + size_t pos; +}; +#endif + +/** + * A cache of the spec_args for a set or list, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +struct SetListTypeArgs { + TType element_type; + PyObject* typeargs; + bool immutable; +}; + +/** + * A cache of the spec_args for a map, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +struct MapTypeArgs { + TType ktag; + TType vtag; + PyObject* ktypeargs; + PyObject* vtypeargs; + bool immutable; +}; + +/** + * A cache of the spec_args for a struct, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +struct StructTypeArgs { + PyObject* klass; + PyObject* spec; + bool immutable; +}; + +/** + * A cache of the item spec from a struct specification, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +struct StructItemSpec { + int tag; + TType type; + PyObject* attrname; + PyObject* typeargs; + PyObject* defval; +}; + +bool parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs); + +bool parse_map_args(MapTypeArgs* dest, PyObject* typeargs); + +bool parse_struct_args(StructTypeArgs* dest, PyObject* typeargs); + +bool parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple); +} +} +} + +#endif // THRIFT_PY_TYPES_H diff --git a/src/jaegertracing/thrift/lib/py/src/protocol/TBase.py b/src/jaegertracing/thrift/lib/py/src/protocol/TBase.py new file mode 100644 index 000000000..9ae1b1182 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/protocol/TBase.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.transport import TTransport + + +class TBase(object): + __slots__ = () + + def __repr__(self): + L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + for attr in self.__slots__: + my_val = getattr(self, attr) + other_val = getattr(other, attr) + if my_val != other_val: + return False + return True + + def __ne__(self, other): + return not (self == other) + + def read(self, iprot): + if (iprot._fast_decode is not None and + isinstance(iprot.trans, TTransport.CReadableTransport) and + self.thrift_spec is not None): + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + else: + iprot.readStruct(self, self.thrift_spec) + + def write(self, oprot): + if (oprot._fast_encode is not None and self.thrift_spec is not None): + oprot.trans.write( + oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + else: + oprot.writeStruct(self, self.thrift_spec) + + +class TExceptionBase(TBase, Exception): + pass + + +class TFrozenBase(TBase): + def __setitem__(self, *args): + raise TypeError("Can't modify frozen struct") + + def __delitem__(self, *args): + raise TypeError("Can't modify frozen struct") + + def __hash__(self, *args): + return hash(self.__class__) ^ hash(self.__slots__) + + @classmethod + def read(cls, iprot): + if (iprot._fast_decode is not None and + isinstance(iprot.trans, TTransport.CReadableTransport) and + cls.thrift_spec is not None): + self = cls() + return iprot._fast_decode(None, iprot, + [self.__class__, self.thrift_spec]) + else: + return iprot.readStruct(cls, cls.thrift_spec, True) diff --git a/src/jaegertracing/thrift/lib/py/src/protocol/TBinaryProtocol.py b/src/jaegertracing/thrift/lib/py/src/protocol/TBinaryProtocol.py new file mode 100644 index 000000000..6b2facc4f --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/protocol/TBinaryProtocol.py @@ -0,0 +1,301 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from .TProtocol import TType, TProtocolBase, TProtocolException, TProtocolFactory +from struct import pack, unpack + + +class TBinaryProtocol(TProtocolBase): + """Binary implementation of the Thrift protocol driver.""" + + # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be + # positive, converting this into a long. If we hardcode the int value + # instead it'll stay in 32 bit-land. + + # VERSION_MASK = 0xffff0000 + VERSION_MASK = -65536 + + # VERSION_1 = 0x80010000 + VERSION_1 = -2147418112 + + TYPE_MASK = 0x000000ff + + def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs): + TProtocolBase.__init__(self, trans) + self.strictRead = strictRead + self.strictWrite = strictWrite + self.string_length_limit = kwargs.get('string_length_limit', None) + self.container_length_limit = kwargs.get('container_length_limit', None) + + def _check_string_length(self, length): + self._check_length(self.string_length_limit, length) + + def _check_container_length(self, length): + self._check_length(self.container_length_limit, length) + + def writeMessageBegin(self, name, type, seqid): + if self.strictWrite: + self.writeI32(TBinaryProtocol.VERSION_1 | type) + self.writeString(name) + self.writeI32(seqid) + else: + self.writeString(name) + self.writeByte(type) + self.writeI32(seqid) + + def writeMessageEnd(self): + pass + + def writeStructBegin(self, name): + pass + + def writeStructEnd(self): + pass + + def writeFieldBegin(self, name, type, id): + self.writeByte(type) + self.writeI16(id) + + def writeFieldEnd(self): + pass + + def writeFieldStop(self): + self.writeByte(TType.STOP) + + def writeMapBegin(self, ktype, vtype, size): + self.writeByte(ktype) + self.writeByte(vtype) + self.writeI32(size) + + def writeMapEnd(self): + pass + + def writeListBegin(self, etype, size): + self.writeByte(etype) + self.writeI32(size) + + def writeListEnd(self): + pass + + def writeSetBegin(self, etype, size): + self.writeByte(etype) + self.writeI32(size) + + def writeSetEnd(self): + pass + + def writeBool(self, bool): + if bool: + self.writeByte(1) + else: + self.writeByte(0) + + def writeByte(self, byte): + buff = pack("!b", byte) + self.trans.write(buff) + + def writeI16(self, i16): + buff = pack("!h", i16) + self.trans.write(buff) + + def writeI32(self, i32): + buff = pack("!i", i32) + self.trans.write(buff) + + def writeI64(self, i64): + buff = pack("!q", i64) + self.trans.write(buff) + + def writeDouble(self, dub): + buff = pack("!d", dub) + self.trans.write(buff) + + def writeBinary(self, str): + self.writeI32(len(str)) + self.trans.write(str) + + def readMessageBegin(self): + sz = self.readI32() + if sz < 0: + version = sz & TBinaryProtocol.VERSION_MASK + if version != TBinaryProtocol.VERSION_1: + raise TProtocolException( + type=TProtocolException.BAD_VERSION, + message='Bad version in readMessageBegin: %d' % (sz)) + type = sz & TBinaryProtocol.TYPE_MASK + name = self.readString() + seqid = self.readI32() + else: + if self.strictRead: + raise TProtocolException(type=TProtocolException.BAD_VERSION, + message='No protocol version header') + name = self.trans.readAll(sz) + type = self.readByte() + seqid = self.readI32() + return (name, type, seqid) + + def readMessageEnd(self): + pass + + def readStructBegin(self): + pass + + def readStructEnd(self): + pass + + def readFieldBegin(self): + type = self.readByte() + if type == TType.STOP: + return (None, type, 0) + id = self.readI16() + return (None, type, id) + + def readFieldEnd(self): + pass + + def readMapBegin(self): + ktype = self.readByte() + vtype = self.readByte() + size = self.readI32() + self._check_container_length(size) + return (ktype, vtype, size) + + def readMapEnd(self): + pass + + def readListBegin(self): + etype = self.readByte() + size = self.readI32() + self._check_container_length(size) + return (etype, size) + + def readListEnd(self): + pass + + def readSetBegin(self): + etype = self.readByte() + size = self.readI32() + self._check_container_length(size) + return (etype, size) + + def readSetEnd(self): + pass + + def readBool(self): + byte = self.readByte() + if byte == 0: + return False + return True + + def readByte(self): + buff = self.trans.readAll(1) + val, = unpack('!b', buff) + return val + + def readI16(self): + buff = self.trans.readAll(2) + val, = unpack('!h', buff) + return val + + def readI32(self): + buff = self.trans.readAll(4) + val, = unpack('!i', buff) + return val + + def readI64(self): + buff = self.trans.readAll(8) + val, = unpack('!q', buff) + return val + + def readDouble(self): + buff = self.trans.readAll(8) + val, = unpack('!d', buff) + return val + + def readBinary(self): + size = self.readI32() + self._check_string_length(size) + s = self.trans.readAll(size) + return s + + +class TBinaryProtocolFactory(TProtocolFactory): + def __init__(self, strictRead=False, strictWrite=True, **kwargs): + self.strictRead = strictRead + self.strictWrite = strictWrite + self.string_length_limit = kwargs.get('string_length_limit', None) + self.container_length_limit = kwargs.get('container_length_limit', None) + + def getProtocol(self, trans): + prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite, + string_length_limit=self.string_length_limit, + container_length_limit=self.container_length_limit) + return prot + + +class TBinaryProtocolAccelerated(TBinaryProtocol): + """C-Accelerated version of TBinaryProtocol. + + This class does not override any of TBinaryProtocol's methods, + but the generated code recognizes it directly and will call into + our C module to do the encoding, bypassing this object entirely. + We inherit from TBinaryProtocol so that the normal TBinaryProtocol + encoding can happen if the fastbinary module doesn't work for some + reason. (TODO(dreiss): Make this happen sanely in more cases.) + To disable this behavior, pass fallback=False constructor argument. + + In order to take advantage of the C module, just use + TBinaryProtocolAccelerated instead of TBinaryProtocol. + + NOTE: This code was contributed by an external developer. + The internal Thrift team has reviewed and tested it, + but we cannot guarantee that it is production-ready. + Please feel free to report bugs and/or success stories + to the public mailing list. + """ + pass + + def __init__(self, *args, **kwargs): + fallback = kwargs.pop('fallback', True) + super(TBinaryProtocolAccelerated, self).__init__(*args, **kwargs) + try: + from thrift.protocol import fastbinary + except ImportError: + if not fallback: + raise + else: + self._fast_decode = fastbinary.decode_binary + self._fast_encode = fastbinary.encode_binary + + +class TBinaryProtocolAcceleratedFactory(TProtocolFactory): + def __init__(self, + string_length_limit=None, + container_length_limit=None, + fallback=True): + self.string_length_limit = string_length_limit + self.container_length_limit = container_length_limit + self._fallback = fallback + + def getProtocol(self, trans): + return TBinaryProtocolAccelerated( + trans, + string_length_limit=self.string_length_limit, + container_length_limit=self.container_length_limit, + fallback=self._fallback) diff --git a/src/jaegertracing/thrift/lib/py/src/protocol/TCompactProtocol.py b/src/jaegertracing/thrift/lib/py/src/protocol/TCompactProtocol.py new file mode 100644 index 000000000..700e792f7 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/protocol/TCompactProtocol.py @@ -0,0 +1,487 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from .TProtocol import TType, TProtocolBase, TProtocolException, TProtocolFactory, checkIntegerLimits +from struct import pack, unpack + +from ..compat import binary_to_str, str_to_binary + +__all__ = ['TCompactProtocol', 'TCompactProtocolFactory'] + +CLEAR = 0 +FIELD_WRITE = 1 +VALUE_WRITE = 2 +CONTAINER_WRITE = 3 +BOOL_WRITE = 4 +FIELD_READ = 5 +CONTAINER_READ = 6 +VALUE_READ = 7 +BOOL_READ = 8 + + +def make_helper(v_from, container): + def helper(func): + def nested(self, *args, **kwargs): + assert self.state in (v_from, container), (self.state, v_from, container) + return func(self, *args, **kwargs) + return nested + return helper + + +writer = make_helper(VALUE_WRITE, CONTAINER_WRITE) +reader = make_helper(VALUE_READ, CONTAINER_READ) + + +def makeZigZag(n, bits): + checkIntegerLimits(n, bits) + return (n << 1) ^ (n >> (bits - 1)) + + +def fromZigZag(n): + return (n >> 1) ^ -(n & 1) + + +def writeVarint(trans, n): + assert n >= 0, "Input to TCompactProtocol writeVarint cannot be negative!" + out = bytearray() + while True: + if n & ~0x7f == 0: + out.append(n) + break + else: + out.append((n & 0xff) | 0x80) + n = n >> 7 + trans.write(bytes(out)) + + +def readVarint(trans): + result = 0 + shift = 0 + while True: + x = trans.readAll(1) + byte = ord(x) + result |= (byte & 0x7f) << shift + if byte >> 7 == 0: + return result + shift += 7 + + +class CompactType(object): + STOP = 0x00 + TRUE = 0x01 + FALSE = 0x02 + BYTE = 0x03 + I16 = 0x04 + I32 = 0x05 + I64 = 0x06 + DOUBLE = 0x07 + BINARY = 0x08 + LIST = 0x09 + SET = 0x0A + MAP = 0x0B + STRUCT = 0x0C + + +CTYPES = { + TType.STOP: CompactType.STOP, + TType.BOOL: CompactType.TRUE, # used for collection + TType.BYTE: CompactType.BYTE, + TType.I16: CompactType.I16, + TType.I32: CompactType.I32, + TType.I64: CompactType.I64, + TType.DOUBLE: CompactType.DOUBLE, + TType.STRING: CompactType.BINARY, + TType.STRUCT: CompactType.STRUCT, + TType.LIST: CompactType.LIST, + TType.SET: CompactType.SET, + TType.MAP: CompactType.MAP, +} + +TTYPES = {} +for k, v in CTYPES.items(): + TTYPES[v] = k +TTYPES[CompactType.FALSE] = TType.BOOL +del k +del v + + +class TCompactProtocol(TProtocolBase): + """Compact implementation of the Thrift protocol driver.""" + + PROTOCOL_ID = 0x82 + VERSION = 1 + VERSION_MASK = 0x1f + TYPE_MASK = 0xe0 + TYPE_BITS = 0x07 + TYPE_SHIFT_AMOUNT = 5 + + def __init__(self, trans, + string_length_limit=None, + container_length_limit=None): + TProtocolBase.__init__(self, trans) + self.state = CLEAR + self.__last_fid = 0 + self.__bool_fid = None + self.__bool_value = None + self.__structs = [] + self.__containers = [] + self.string_length_limit = string_length_limit + self.container_length_limit = container_length_limit + + def _check_string_length(self, length): + self._check_length(self.string_length_limit, length) + + def _check_container_length(self, length): + self._check_length(self.container_length_limit, length) + + def __writeVarint(self, n): + writeVarint(self.trans, n) + + def writeMessageBegin(self, name, type, seqid): + assert self.state == CLEAR + self.__writeUByte(self.PROTOCOL_ID) + self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT)) + # The sequence id is a signed 32-bit integer but the compact protocol + # writes this out as a "var int" which is always positive, and attempting + # to write a negative number results in an infinite loop, so we may + # need to do some conversion here... + tseqid = seqid + if tseqid < 0: + tseqid = 2147483648 + (2147483648 + tseqid) + self.__writeVarint(tseqid) + self.__writeBinary(str_to_binary(name)) + self.state = VALUE_WRITE + + def writeMessageEnd(self): + assert self.state == VALUE_WRITE + self.state = CLEAR + + def writeStructBegin(self, name): + assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state + self.__structs.append((self.state, self.__last_fid)) + self.state = FIELD_WRITE + self.__last_fid = 0 + + def writeStructEnd(self): + assert self.state == FIELD_WRITE + self.state, self.__last_fid = self.__structs.pop() + + def writeFieldStop(self): + self.__writeByte(0) + + def __writeFieldHeader(self, type, fid): + delta = fid - self.__last_fid + if 0 < delta <= 15: + self.__writeUByte(delta << 4 | type) + else: + self.__writeByte(type) + self.__writeI16(fid) + self.__last_fid = fid + + def writeFieldBegin(self, name, type, fid): + assert self.state == FIELD_WRITE, self.state + if type == TType.BOOL: + self.state = BOOL_WRITE + self.__bool_fid = fid + else: + self.state = VALUE_WRITE + self.__writeFieldHeader(CTYPES[type], fid) + + def writeFieldEnd(self): + assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state + self.state = FIELD_WRITE + + def __writeUByte(self, byte): + self.trans.write(pack('!B', byte)) + + def __writeByte(self, byte): + self.trans.write(pack('!b', byte)) + + def __writeI16(self, i16): + self.__writeVarint(makeZigZag(i16, 16)) + + def __writeSize(self, i32): + self.__writeVarint(i32) + + def writeCollectionBegin(self, etype, size): + assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state + if size <= 14: + self.__writeUByte(size << 4 | CTYPES[etype]) + else: + self.__writeUByte(0xf0 | CTYPES[etype]) + self.__writeSize(size) + self.__containers.append(self.state) + self.state = CONTAINER_WRITE + writeSetBegin = writeCollectionBegin + writeListBegin = writeCollectionBegin + + def writeMapBegin(self, ktype, vtype, size): + assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state + if size == 0: + self.__writeByte(0) + else: + self.__writeSize(size) + self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype]) + self.__containers.append(self.state) + self.state = CONTAINER_WRITE + + def writeCollectionEnd(self): + assert self.state == CONTAINER_WRITE, self.state + self.state = self.__containers.pop() + writeMapEnd = writeCollectionEnd + writeSetEnd = writeCollectionEnd + writeListEnd = writeCollectionEnd + + def writeBool(self, bool): + if self.state == BOOL_WRITE: + if bool: + ctype = CompactType.TRUE + else: + ctype = CompactType.FALSE + self.__writeFieldHeader(ctype, self.__bool_fid) + elif self.state == CONTAINER_WRITE: + if bool: + self.__writeByte(CompactType.TRUE) + else: + self.__writeByte(CompactType.FALSE) + else: + raise AssertionError("Invalid state in compact protocol") + + writeByte = writer(__writeByte) + writeI16 = writer(__writeI16) + + @writer + def writeI32(self, i32): + self.__writeVarint(makeZigZag(i32, 32)) + + @writer + def writeI64(self, i64): + self.__writeVarint(makeZigZag(i64, 64)) + + @writer + def writeDouble(self, dub): + self.trans.write(pack('> 4 + if delta == 0: + fid = self.__readI16() + else: + fid = self.__last_fid + delta + self.__last_fid = fid + type = type & 0x0f + if type == CompactType.TRUE: + self.state = BOOL_READ + self.__bool_value = True + elif type == CompactType.FALSE: + self.state = BOOL_READ + self.__bool_value = False + else: + self.state = VALUE_READ + return (None, self.__getTType(type), fid) + + def readFieldEnd(self): + assert self.state in (VALUE_READ, BOOL_READ), self.state + self.state = FIELD_READ + + def __readUByte(self): + result, = unpack('!B', self.trans.readAll(1)) + return result + + def __readByte(self): + result, = unpack('!b', self.trans.readAll(1)) + return result + + def __readVarint(self): + return readVarint(self.trans) + + def __readZigZag(self): + return fromZigZag(self.__readVarint()) + + def __readSize(self): + result = self.__readVarint() + if result < 0: + raise TProtocolException("Length < 0") + return result + + def readMessageBegin(self): + assert self.state == CLEAR + proto_id = self.__readUByte() + if proto_id != self.PROTOCOL_ID: + raise TProtocolException(TProtocolException.BAD_VERSION, + 'Bad protocol id in the message: %d' % proto_id) + ver_type = self.__readUByte() + type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS + version = ver_type & self.VERSION_MASK + if version != self.VERSION: + raise TProtocolException(TProtocolException.BAD_VERSION, + 'Bad version: %d (expect %d)' % (version, self.VERSION)) + seqid = self.__readVarint() + # the sequence is a compact "var int" which is treaded as unsigned, + # however the sequence is actually signed... + if seqid > 2147483647: + seqid = -2147483648 - (2147483648 - seqid) + name = binary_to_str(self.__readBinary()) + return (name, type, seqid) + + def readMessageEnd(self): + assert self.state == CLEAR + assert len(self.__structs) == 0 + + def readStructBegin(self): + assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state + self.__structs.append((self.state, self.__last_fid)) + self.state = FIELD_READ + self.__last_fid = 0 + + def readStructEnd(self): + assert self.state == FIELD_READ + self.state, self.__last_fid = self.__structs.pop() + + def readCollectionBegin(self): + assert self.state in (VALUE_READ, CONTAINER_READ), self.state + size_type = self.__readUByte() + size = size_type >> 4 + type = self.__getTType(size_type) + if size == 15: + size = self.__readSize() + self._check_container_length(size) + self.__containers.append(self.state) + self.state = CONTAINER_READ + return type, size + readSetBegin = readCollectionBegin + readListBegin = readCollectionBegin + + def readMapBegin(self): + assert self.state in (VALUE_READ, CONTAINER_READ), self.state + size = self.__readSize() + self._check_container_length(size) + types = 0 + if size > 0: + types = self.__readUByte() + vtype = self.__getTType(types) + ktype = self.__getTType(types >> 4) + self.__containers.append(self.state) + self.state = CONTAINER_READ + return (ktype, vtype, size) + + def readCollectionEnd(self): + assert self.state == CONTAINER_READ, self.state + self.state = self.__containers.pop() + readSetEnd = readCollectionEnd + readListEnd = readCollectionEnd + readMapEnd = readCollectionEnd + + def readBool(self): + if self.state == BOOL_READ: + return self.__bool_value == CompactType.TRUE + elif self.state == CONTAINER_READ: + return self.__readByte() == CompactType.TRUE + else: + raise AssertionError("Invalid state in compact protocol: %d" % + self.state) + + readByte = reader(__readByte) + __readI16 = __readZigZag + readI16 = reader(__readZigZag) + readI32 = reader(__readZigZag) + readI64 = reader(__readZigZag) + + @reader + def readDouble(self): + buff = self.trans.readAll(8) + val, = unpack('= 0xd800 and codeunit <= 0xdbff + + def _isLowSurrogate(self, codeunit): + return codeunit >= 0xdc00 and codeunit <= 0xdfff + + def _toChar(self, high, low=None): + if not low: + if sys.version_info[0] == 2: + return ("\\u%04x" % high).decode('unicode-escape') \ + .encode('utf-8') + else: + return chr(high) + else: + codepoint = (1 << 16) + ((high & 0x3ff) << 10) + codepoint += low & 0x3ff + if sys.version_info[0] == 2: + s = "\\U%08x" % codepoint + return s.decode('unicode-escape').encode('utf-8') + else: + return chr(codepoint) + + def readJSONString(self, skipContext): + highSurrogate = None + string = [] + if skipContext is False: + self.context.read() + self.readJSONSyntaxChar(QUOTE) + while True: + character = self.reader.read() + if character == QUOTE: + break + if ord(character) == ESCSEQ0: + character = self.reader.read() + if ord(character) == ESCSEQ1: + character = self.trans.read(4).decode('ascii') + codeunit = int(character, 16) + if self._isHighSurrogate(codeunit): + if highSurrogate: + raise TProtocolException( + TProtocolException.INVALID_DATA, + "Expected low surrogate char") + highSurrogate = codeunit + continue + elif self._isLowSurrogate(codeunit): + if not highSurrogate: + raise TProtocolException( + TProtocolException.INVALID_DATA, + "Expected high surrogate char") + character = self._toChar(highSurrogate, codeunit) + highSurrogate = None + else: + character = self._toChar(codeunit) + else: + if character not in ESCAPE_CHARS: + raise TProtocolException( + TProtocolException.INVALID_DATA, + "Expected control char") + character = ESCAPE_CHARS[character] + elif character in ESCAPE_CHAR_VALS: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Unescaped control char") + elif sys.version_info[0] > 2: + utf8_bytes = bytearray([ord(character)]) + while ord(self.reader.peek()) >= 0x80: + utf8_bytes.append(ord(self.reader.read())) + character = utf8_bytes.decode('utf8') + string.append(character) + + if highSurrogate: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Expected low surrogate char") + return ''.join(string) + + def isJSONNumeric(self, character): + return (True if NUMERIC_CHAR.find(character) != - 1 else False) + + def readJSONQuotes(self): + if (self.context.escapeNum()): + self.readJSONSyntaxChar(QUOTE) + + def readJSONNumericChars(self): + numeric = [] + while True: + character = self.reader.peek() + if self.isJSONNumeric(character) is False: + break + numeric.append(self.reader.read()) + return b''.join(numeric).decode('ascii') + + def readJSONInteger(self): + self.context.read() + self.readJSONQuotes() + numeric = self.readJSONNumericChars() + self.readJSONQuotes() + try: + return int(numeric) + except ValueError: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data") + + def readJSONDouble(self): + self.context.read() + if self.reader.peek() == QUOTE: + string = self.readJSONString(True) + try: + double = float(string) + if (self.context.escapeNum is False and + not math.isinf(double) and + not math.isnan(double)): + raise TProtocolException( + TProtocolException.INVALID_DATA, + "Numeric data unexpectedly quoted") + return double + except ValueError: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data") + else: + if self.context.escapeNum() is True: + self.readJSONSyntaxChar(QUOTE) + try: + return float(self.readJSONNumericChars()) + except ValueError: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data") + + def readJSONBase64(self): + string = self.readJSONString(False) + size = len(string) + m = size % 4 + # Force padding since b64encode method does not allow it + if m != 0: + for i in range(4 - m): + string += '=' + return base64.b64decode(string) + + def readJSONObjectStart(self): + self.context.read() + self.readJSONSyntaxChar(LBRACE) + self.pushContext(JSONPairContext(self)) + + def readJSONObjectEnd(self): + self.readJSONSyntaxChar(RBRACE) + self.popContext() + + def readJSONArrayStart(self): + self.context.read() + self.readJSONSyntaxChar(LBRACKET) + self.pushContext(JSONListContext(self)) + + def readJSONArrayEnd(self): + self.readJSONSyntaxChar(RBRACKET) + self.popContext() + + +class TJSONProtocol(TJSONProtocolBase): + + def readMessageBegin(self): + self.resetReadContext() + self.readJSONArrayStart() + if self.readJSONInteger() != VERSION: + raise TProtocolException(TProtocolException.BAD_VERSION, + "Message contained bad version.") + name = self.readJSONString(False) + typen = self.readJSONInteger() + seqid = self.readJSONInteger() + return (name, typen, seqid) + + def readMessageEnd(self): + self.readJSONArrayEnd() + + def readStructBegin(self): + self.readJSONObjectStart() + + def readStructEnd(self): + self.readJSONObjectEnd() + + def readFieldBegin(self): + character = self.reader.peek() + ttype = 0 + id = 0 + if character == RBRACE: + ttype = TType.STOP + else: + id = self.readJSONInteger() + self.readJSONObjectStart() + ttype = JTYPES[self.readJSONString(False)] + return (None, ttype, id) + + def readFieldEnd(self): + self.readJSONObjectEnd() + + def readMapBegin(self): + self.readJSONArrayStart() + keyType = JTYPES[self.readJSONString(False)] + valueType = JTYPES[self.readJSONString(False)] + size = self.readJSONInteger() + self.readJSONObjectStart() + return (keyType, valueType, size) + + def readMapEnd(self): + self.readJSONObjectEnd() + self.readJSONArrayEnd() + + def readCollectionBegin(self): + self.readJSONArrayStart() + elemType = JTYPES[self.readJSONString(False)] + size = self.readJSONInteger() + return (elemType, size) + readListBegin = readCollectionBegin + readSetBegin = readCollectionBegin + + def readCollectionEnd(self): + self.readJSONArrayEnd() + readSetEnd = readCollectionEnd + readListEnd = readCollectionEnd + + def readBool(self): + return (False if self.readJSONInteger() == 0 else True) + + def readNumber(self): + return self.readJSONInteger() + readByte = readNumber + readI16 = readNumber + readI32 = readNumber + readI64 = readNumber + + def readDouble(self): + return self.readJSONDouble() + + def readString(self): + return self.readJSONString(False) + + def readBinary(self): + return self.readJSONBase64() + + def writeMessageBegin(self, name, request_type, seqid): + self.resetWriteContext() + self.writeJSONArrayStart() + self.writeJSONNumber(VERSION) + self.writeJSONString(name) + self.writeJSONNumber(request_type) + self.writeJSONNumber(seqid) + + def writeMessageEnd(self): + self.writeJSONArrayEnd() + + def writeStructBegin(self, name): + self.writeJSONObjectStart() + + def writeStructEnd(self): + self.writeJSONObjectEnd() + + def writeFieldBegin(self, name, ttype, id): + self.writeJSONNumber(id) + self.writeJSONObjectStart() + self.writeJSONString(CTYPES[ttype]) + + def writeFieldEnd(self): + self.writeJSONObjectEnd() + + def writeFieldStop(self): + pass + + def writeMapBegin(self, ktype, vtype, size): + self.writeJSONArrayStart() + self.writeJSONString(CTYPES[ktype]) + self.writeJSONString(CTYPES[vtype]) + self.writeJSONNumber(size) + self.writeJSONObjectStart() + + def writeMapEnd(self): + self.writeJSONObjectEnd() + self.writeJSONArrayEnd() + + def writeListBegin(self, etype, size): + self.writeJSONArrayStart() + self.writeJSONString(CTYPES[etype]) + self.writeJSONNumber(size) + + def writeListEnd(self): + self.writeJSONArrayEnd() + + def writeSetBegin(self, etype, size): + self.writeJSONArrayStart() + self.writeJSONString(CTYPES[etype]) + self.writeJSONNumber(size) + + def writeSetEnd(self): + self.writeJSONArrayEnd() + + def writeBool(self, boolean): + self.writeJSONNumber(1 if boolean is True else 0) + + def writeByte(self, byte): + checkIntegerLimits(byte, 8) + self.writeJSONNumber(byte) + + def writeI16(self, i16): + checkIntegerLimits(i16, 16) + self.writeJSONNumber(i16) + + def writeI32(self, i32): + checkIntegerLimits(i32, 32) + self.writeJSONNumber(i32) + + def writeI64(self, i64): + checkIntegerLimits(i64, 64) + self.writeJSONNumber(i64) + + def writeDouble(self, dbl): + # 17 significant digits should be just enough for any double precision + # value. + self.writeJSONNumber(dbl, '{0:.17g}') + + def writeString(self, string): + self.writeJSONString(string) + + def writeBinary(self, binary): + self.writeJSONBase64(binary) + + +class TJSONProtocolFactory(TProtocolFactory): + def getProtocol(self, trans): + return TJSONProtocol(trans) + + @property + def string_length_limit(senf): + return None + + @property + def container_length_limit(senf): + return None + + +class TSimpleJSONProtocol(TJSONProtocolBase): + """Simple, readable, write-only JSON protocol. + + Useful for interacting with scripting languages. + """ + + def readMessageBegin(self): + raise NotImplementedError() + + def readMessageEnd(self): + raise NotImplementedError() + + def readStructBegin(self): + raise NotImplementedError() + + def readStructEnd(self): + raise NotImplementedError() + + def writeMessageBegin(self, name, request_type, seqid): + self.resetWriteContext() + + def writeMessageEnd(self): + pass + + def writeStructBegin(self, name): + self.writeJSONObjectStart() + + def writeStructEnd(self): + self.writeJSONObjectEnd() + + def writeFieldBegin(self, name, ttype, fid): + self.writeJSONString(name) + + def writeFieldEnd(self): + pass + + def writeMapBegin(self, ktype, vtype, size): + self.writeJSONObjectStart() + + def writeMapEnd(self): + self.writeJSONObjectEnd() + + def _writeCollectionBegin(self, etype, size): + self.writeJSONArrayStart() + + def _writeCollectionEnd(self): + self.writeJSONArrayEnd() + writeListBegin = _writeCollectionBegin + writeListEnd = _writeCollectionEnd + writeSetBegin = _writeCollectionBegin + writeSetEnd = _writeCollectionEnd + + def writeByte(self, byte): + checkIntegerLimits(byte, 8) + self.writeJSONNumber(byte) + + def writeI16(self, i16): + checkIntegerLimits(i16, 16) + self.writeJSONNumber(i16) + + def writeI32(self, i32): + checkIntegerLimits(i32, 32) + self.writeJSONNumber(i32) + + def writeI64(self, i64): + checkIntegerLimits(i64, 64) + self.writeJSONNumber(i64) + + def writeBool(self, boolean): + self.writeJSONNumber(1 if boolean is True else 0) + + def writeDouble(self, dbl): + self.writeJSONNumber(dbl) + + def writeString(self, string): + self.writeJSONString(string) + + def writeBinary(self, binary): + self.writeJSONBase64(binary) + + +class TSimpleJSONProtocolFactory(TProtocolFactory): + + def getProtocol(self, trans): + return TSimpleJSONProtocol(trans) diff --git a/src/jaegertracing/thrift/lib/py/src/protocol/TMultiplexedProtocol.py b/src/jaegertracing/thrift/lib/py/src/protocol/TMultiplexedProtocol.py new file mode 100644 index 000000000..0f8390fdb --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/protocol/TMultiplexedProtocol.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.Thrift import TMessageType +from thrift.protocol import TProtocolDecorator + +SEPARATOR = ":" + + +class TMultiplexedProtocol(TProtocolDecorator.TProtocolDecorator): + def __init__(self, protocol, serviceName): + self.serviceName = serviceName + + def writeMessageBegin(self, name, type, seqid): + if (type == TMessageType.CALL or + type == TMessageType.ONEWAY): + super(TMultiplexedProtocol, self).writeMessageBegin( + self.serviceName + SEPARATOR + name, + type, + seqid + ) + else: + super(TMultiplexedProtocol, self).writeMessageBegin(name, type, seqid) diff --git a/src/jaegertracing/thrift/lib/py/src/protocol/TProtocol.py b/src/jaegertracing/thrift/lib/py/src/protocol/TProtocol.py new file mode 100644 index 000000000..3456e8f0e --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/protocol/TProtocol.py @@ -0,0 +1,422 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.Thrift import TException, TType, TFrozenDict +from thrift.transport.TTransport import TTransportException +from ..compat import binary_to_str, str_to_binary + +import six +import sys +from itertools import islice +from six.moves import zip + + +class TProtocolException(TException): + """Custom Protocol Exception class""" + + UNKNOWN = 0 + INVALID_DATA = 1 + NEGATIVE_SIZE = 2 + SIZE_LIMIT = 3 + BAD_VERSION = 4 + NOT_IMPLEMENTED = 5 + DEPTH_LIMIT = 6 + INVALID_PROTOCOL = 7 + + def __init__(self, type=UNKNOWN, message=None): + TException.__init__(self, message) + self.type = type + + +class TProtocolBase(object): + """Base class for Thrift protocol driver.""" + + def __init__(self, trans): + self.trans = trans + self._fast_decode = None + self._fast_encode = None + + @staticmethod + def _check_length(limit, length): + if length < 0: + raise TTransportException(TTransportException.NEGATIVE_SIZE, + 'Negative length: %d' % length) + if limit is not None and length > limit: + raise TTransportException(TTransportException.SIZE_LIMIT, + 'Length exceeded max allowed: %d' % limit) + + def writeMessageBegin(self, name, ttype, seqid): + pass + + def writeMessageEnd(self): + pass + + def writeStructBegin(self, name): + pass + + def writeStructEnd(self): + pass + + def writeFieldBegin(self, name, ttype, fid): + pass + + def writeFieldEnd(self): + pass + + def writeFieldStop(self): + pass + + def writeMapBegin(self, ktype, vtype, size): + pass + + def writeMapEnd(self): + pass + + def writeListBegin(self, etype, size): + pass + + def writeListEnd(self): + pass + + def writeSetBegin(self, etype, size): + pass + + def writeSetEnd(self): + pass + + def writeBool(self, bool_val): + pass + + def writeByte(self, byte): + pass + + def writeI16(self, i16): + pass + + def writeI32(self, i32): + pass + + def writeI64(self, i64): + pass + + def writeDouble(self, dub): + pass + + def writeString(self, str_val): + self.writeBinary(str_to_binary(str_val)) + + def writeBinary(self, str_val): + pass + + def writeUtf8(self, str_val): + self.writeString(str_val.encode('utf8')) + + def readMessageBegin(self): + pass + + def readMessageEnd(self): + pass + + def readStructBegin(self): + pass + + def readStructEnd(self): + pass + + def readFieldBegin(self): + pass + + def readFieldEnd(self): + pass + + def readMapBegin(self): + pass + + def readMapEnd(self): + pass + + def readListBegin(self): + pass + + def readListEnd(self): + pass + + def readSetBegin(self): + pass + + def readSetEnd(self): + pass + + def readBool(self): + pass + + def readByte(self): + pass + + def readI16(self): + pass + + def readI32(self): + pass + + def readI64(self): + pass + + def readDouble(self): + pass + + def readString(self): + return binary_to_str(self.readBinary()) + + def readBinary(self): + pass + + def readUtf8(self): + return self.readString().decode('utf8') + + def skip(self, ttype): + if ttype == TType.BOOL: + self.readBool() + elif ttype == TType.BYTE: + self.readByte() + elif ttype == TType.I16: + self.readI16() + elif ttype == TType.I32: + self.readI32() + elif ttype == TType.I64: + self.readI64() + elif ttype == TType.DOUBLE: + self.readDouble() + elif ttype == TType.STRING: + self.readString() + elif ttype == TType.STRUCT: + name = self.readStructBegin() + while True: + (name, ttype, id) = self.readFieldBegin() + if ttype == TType.STOP: + break + self.skip(ttype) + self.readFieldEnd() + self.readStructEnd() + elif ttype == TType.MAP: + (ktype, vtype, size) = self.readMapBegin() + for i in range(size): + self.skip(ktype) + self.skip(vtype) + self.readMapEnd() + elif ttype == TType.SET: + (etype, size) = self.readSetBegin() + for i in range(size): + self.skip(etype) + self.readSetEnd() + elif ttype == TType.LIST: + (etype, size) = self.readListBegin() + for i in range(size): + self.skip(etype) + self.readListEnd() + else: + raise TProtocolException( + TProtocolException.INVALID_DATA, + "invalid TType") + + # tuple of: ( 'reader method' name, is_container bool, 'writer_method' name ) + _TTYPE_HANDLERS = ( + (None, None, False), # 0 TType.STOP + (None, None, False), # 1 TType.VOID # TODO: handle void? + ('readBool', 'writeBool', False), # 2 TType.BOOL + ('readByte', 'writeByte', False), # 3 TType.BYTE and I08 + ('readDouble', 'writeDouble', False), # 4 TType.DOUBLE + (None, None, False), # 5 undefined + ('readI16', 'writeI16', False), # 6 TType.I16 + (None, None, False), # 7 undefined + ('readI32', 'writeI32', False), # 8 TType.I32 + (None, None, False), # 9 undefined + ('readI64', 'writeI64', False), # 10 TType.I64 + ('readString', 'writeString', False), # 11 TType.STRING and UTF7 + ('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT + ('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP + ('readContainerSet', 'writeContainerSet', True), # 14 TType.SET + ('readContainerList', 'writeContainerList', True), # 15 TType.LIST + (None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types? + (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types? + ) + + def _ttype_handlers(self, ttype, spec): + if spec == 'BINARY': + if ttype != TType.STRING: + raise TProtocolException(type=TProtocolException.INVALID_DATA, + message='Invalid binary field type %d' % ttype) + return ('readBinary', 'writeBinary', False) + if sys.version_info[0] == 2 and spec == 'UTF8': + if ttype != TType.STRING: + raise TProtocolException(type=TProtocolException.INVALID_DATA, + message='Invalid string field type %d' % ttype) + return ('readUtf8', 'writeUtf8', False) + return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False) + + def _read_by_ttype(self, ttype, spec, espec): + reader_name, _, is_container = self._ttype_handlers(ttype, espec) + if reader_name is None: + raise TProtocolException(type=TProtocolException.INVALID_DATA, + message='Invalid type %d' % (ttype)) + reader_func = getattr(self, reader_name) + read = (lambda: reader_func(espec)) if is_container else reader_func + while True: + yield read() + + def readFieldByTType(self, ttype, spec): + return next(self._read_by_ttype(ttype, spec, spec)) + + def readContainerList(self, spec): + ttype, tspec, is_immutable = spec + (list_type, list_len) = self.readListBegin() + # TODO: compare types we just decoded with thrift_spec + elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len) + results = (tuple if is_immutable else list)(elems) + self.readListEnd() + return results + + def readContainerSet(self, spec): + ttype, tspec, is_immutable = spec + (set_type, set_len) = self.readSetBegin() + # TODO: compare types we just decoded with thrift_spec + elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len) + results = (frozenset if is_immutable else set)(elems) + self.readSetEnd() + return results + + def readContainerStruct(self, spec): + (obj_class, obj_spec) = spec + obj = obj_class() + obj.read(self) + return obj + + def readContainerMap(self, spec): + ktype, kspec, vtype, vspec, is_immutable = spec + (map_ktype, map_vtype, map_len) = self.readMapBegin() + # TODO: compare types we just decoded with thrift_spec and + # abort/skip if types disagree + keys = self._read_by_ttype(ktype, spec, kspec) + vals = self._read_by_ttype(vtype, spec, vspec) + keyvals = islice(zip(keys, vals), map_len) + results = (TFrozenDict if is_immutable else dict)(keyvals) + self.readMapEnd() + return results + + def readStruct(self, obj, thrift_spec, is_immutable=False): + if is_immutable: + fields = {} + self.readStructBegin() + while True: + (fname, ftype, fid) = self.readFieldBegin() + if ftype == TType.STOP: + break + try: + field = thrift_spec[fid] + except IndexError: + self.skip(ftype) + else: + if field is not None and ftype == field[1]: + fname = field[2] + fspec = field[3] + val = self.readFieldByTType(ftype, fspec) + if is_immutable: + fields[fname] = val + else: + setattr(obj, fname, val) + else: + self.skip(ftype) + self.readFieldEnd() + self.readStructEnd() + if is_immutable: + return obj(**fields) + + def writeContainerStruct(self, val, spec): + val.write(self) + + def writeContainerList(self, val, spec): + ttype, tspec, _ = spec + self.writeListBegin(ttype, len(val)) + for _ in self._write_by_ttype(ttype, val, spec, tspec): + pass + self.writeListEnd() + + def writeContainerSet(self, val, spec): + ttype, tspec, _ = spec + self.writeSetBegin(ttype, len(val)) + for _ in self._write_by_ttype(ttype, val, spec, tspec): + pass + self.writeSetEnd() + + def writeContainerMap(self, val, spec): + ktype, kspec, vtype, vspec, _ = spec + self.writeMapBegin(ktype, vtype, len(val)) + for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec), + self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)): + pass + self.writeMapEnd() + + def writeStruct(self, obj, thrift_spec): + self.writeStructBegin(obj.__class__.__name__) + for field in thrift_spec: + if field is None: + continue + fname = field[2] + val = getattr(obj, fname) + if val is None: + # skip writing out unset fields + continue + fid = field[0] + ftype = field[1] + fspec = field[3] + self.writeFieldBegin(fname, ftype, fid) + self.writeFieldByTType(ftype, val, fspec) + self.writeFieldEnd() + self.writeFieldStop() + self.writeStructEnd() + + def _write_by_ttype(self, ttype, vals, spec, espec): + _, writer_name, is_container = self._ttype_handlers(ttype, espec) + writer_func = getattr(self, writer_name) + write = (lambda v: writer_func(v, espec)) if is_container else writer_func + for v in vals: + yield write(v) + + def writeFieldByTType(self, ttype, val, spec): + next(self._write_by_ttype(ttype, [val], spec, spec)) + + +def checkIntegerLimits(i, bits): + if bits == 8 and (i < -128 or i > 127): + raise TProtocolException(TProtocolException.INVALID_DATA, + "i8 requires -128 <= number <= 127") + elif bits == 16 and (i < -32768 or i > 32767): + raise TProtocolException(TProtocolException.INVALID_DATA, + "i16 requires -32768 <= number <= 32767") + elif bits == 32 and (i < -2147483648 or i > 2147483647): + raise TProtocolException(TProtocolException.INVALID_DATA, + "i32 requires -2147483648 <= number <= 2147483647") + elif bits == 64 and (i < -9223372036854775808 or i > 9223372036854775807): + raise TProtocolException(TProtocolException.INVALID_DATA, + "i64 requires -9223372036854775808 <= number <= 9223372036854775807") + + +class TProtocolFactory(object): + def getProtocol(self, trans): + pass diff --git a/src/jaegertracing/thrift/lib/py/src/protocol/TProtocolDecorator.py b/src/jaegertracing/thrift/lib/py/src/protocol/TProtocolDecorator.py new file mode 100644 index 000000000..f5546c736 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/protocol/TProtocolDecorator.py @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + + +class TProtocolDecorator(object): + def __new__(cls, protocol, *args, **kwargs): + decorated_cls = type(''.join(['Decorated', protocol.__class__.__name__]), + (cls, protocol.__class__), + protocol.__dict__) + return object.__new__(decorated_cls) diff --git a/src/jaegertracing/thrift/lib/py/src/protocol/__init__.py b/src/jaegertracing/thrift/lib/py/src/protocol/__init__.py new file mode 100644 index 000000000..06647a24b --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/protocol/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol', + 'TJSONProtocol', 'TProtocol', 'TProtocolDecorator'] diff --git a/src/jaegertracing/thrift/lib/py/src/server/THttpServer.py b/src/jaegertracing/thrift/lib/py/src/server/THttpServer.py new file mode 100644 index 000000000..47e817df7 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/server/THttpServer.py @@ -0,0 +1,131 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import ssl + +from six.moves import BaseHTTPServer + +from thrift.Thrift import TMessageType +from thrift.server import TServer +from thrift.transport import TTransport + + +class ResponseException(Exception): + """Allows handlers to override the HTTP response + + Normally, THttpServer always sends a 200 response. If a handler wants + to override this behavior (e.g., to simulate a misconfigured or + overloaded web server during testing), it can raise a ResponseException. + The function passed to the constructor will be called with the + RequestHandler as its only argument. Note that this is irrelevant + for ONEWAY requests, as the HTTP response must be sent before the + RPC is processed. + """ + def __init__(self, handler): + self.handler = handler + + +class THttpServer(TServer.TServer): + """A simple HTTP-based Thrift server + + This class is not very performant, but it is useful (for example) for + acting as a mock version of an Apache-based PHP Thrift endpoint. + Also important to note the HTTP implementation pretty much violates the + transport/protocol/processor/server layering, by performing the transport + functions here. This means things like oneway handling are oddly exposed. + """ + def __init__(self, + processor, + server_address, + inputProtocolFactory, + outputProtocolFactory=None, + server_class=BaseHTTPServer.HTTPServer, + **kwargs): + """Set up protocol factories and HTTP (or HTTPS) server. + + See BaseHTTPServer for server_address. + See TServer for protocol factories. + + To make a secure server, provide the named arguments: + * cafile - to validate clients [optional] + * cert_file - the server cert + * key_file - the server's key + """ + if outputProtocolFactory is None: + outputProtocolFactory = inputProtocolFactory + + TServer.TServer.__init__(self, processor, None, None, None, + inputProtocolFactory, outputProtocolFactory) + + thttpserver = self + self._replied = None + + class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler): + def do_POST(self): + # Don't care about the request path. + thttpserver._replied = False + iftrans = TTransport.TFileObjectTransport(self.rfile) + itrans = TTransport.TBufferedTransport( + iftrans, int(self.headers['Content-Length'])) + otrans = TTransport.TMemoryBuffer() + iprot = thttpserver.inputProtocolFactory.getProtocol(itrans) + oprot = thttpserver.outputProtocolFactory.getProtocol(otrans) + try: + thttpserver.processor.on_message_begin(self.on_begin) + thttpserver.processor.process(iprot, oprot) + except ResponseException as exn: + exn.handler(self) + else: + if not thttpserver._replied: + # If the request was ONEWAY we would have replied already + data = otrans.getvalue() + self.send_response(200) + self.send_header("Content-Length", len(data)) + self.send_header("Content-Type", "application/x-thrift") + self.end_headers() + self.wfile.write(data) + + def on_begin(self, name, type, seqid): + """ + Inspect the message header. + + This allows us to post an immediate transport response + if the request is a ONEWAY message type. + """ + if type == TMessageType.ONEWAY: + self.send_response(200) + self.send_header("Content-Type", "application/x-thrift") + self.end_headers() + thttpserver._replied = True + + self.httpd = server_class(server_address, RequestHander) + + if (kwargs.get('cafile') or kwargs.get('cert_file') or kwargs.get('key_file')): + context = ssl.create_default_context(cafile=kwargs.get('cafile')) + context.check_hostname = False + context.load_cert_chain(kwargs.get('cert_file'), kwargs.get('key_file')) + context.verify_mode = ssl.CERT_REQUIRED if kwargs.get('cafile') else ssl.CERT_NONE + self.httpd.socket = context.wrap_socket(self.httpd.socket, server_side=True) + + def serve(self): + self.httpd.serve_forever() + + def shutdown(self): + self.httpd.socket.close() + # self.httpd.shutdown() # hangs forever, python doesn't handle POLLNVAL properly! diff --git a/src/jaegertracing/thrift/lib/py/src/server/TNonblockingServer.py b/src/jaegertracing/thrift/lib/py/src/server/TNonblockingServer.py new file mode 100644 index 000000000..f62d486eb --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/server/TNonblockingServer.py @@ -0,0 +1,370 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""Implementation of non-blocking server. + +The main idea of the server is to receive and send requests +only from the main thread. + +The thread poool should be sized for concurrent tasks, not +maximum connections +""" + +import logging +import select +import socket +import struct +import threading + +from collections import deque +from six.moves import queue + +from thrift.transport import TTransport +from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory + +__all__ = ['TNonblockingServer'] + +logger = logging.getLogger(__name__) + + +class Worker(threading.Thread): + """Worker is a small helper to process incoming connection.""" + + def __init__(self, queue): + threading.Thread.__init__(self) + self.queue = queue + + def run(self): + """Process queries from task queue, stop if processor is None.""" + while True: + try: + processor, iprot, oprot, otrans, callback = self.queue.get() + if processor is None: + break + processor.process(iprot, oprot) + callback(True, otrans.getvalue()) + except Exception: + logger.exception("Exception while processing request", exc_info=True) + callback(False, b'') + + +WAIT_LEN = 0 +WAIT_MESSAGE = 1 +WAIT_PROCESS = 2 +SEND_ANSWER = 3 +CLOSED = 4 + + +def locked(func): + """Decorator which locks self.lock.""" + def nested(self, *args, **kwargs): + self.lock.acquire() + try: + return func(self, *args, **kwargs) + finally: + self.lock.release() + return nested + + +def socket_exception(func): + """Decorator close object on socket.error.""" + def read(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except socket.error: + logger.debug('ignoring socket exception', exc_info=True) + self.close() + return read + + +class Message(object): + def __init__(self, offset, len_, header): + self.offset = offset + self.len = len_ + self.buffer = None + self.is_header = header + + @property + def end(self): + return self.offset + self.len + + +class Connection(object): + """Basic class is represented connection. + + It can be in state: + WAIT_LEN --- connection is reading request len. + WAIT_MESSAGE --- connection is reading request. + WAIT_PROCESS --- connection has just read whole request and + waits for call ready routine. + SEND_ANSWER --- connection is sending answer string (including length + of answer). + CLOSED --- socket was closed and connection should be deleted. + """ + def __init__(self, new_socket, wake_up): + self.socket = new_socket + self.socket.setblocking(False) + self.status = WAIT_LEN + self.len = 0 + self.received = deque() + self._reading = Message(0, 4, True) + self._rbuf = b'' + self._wbuf = b'' + self.lock = threading.Lock() + self.wake_up = wake_up + self.remaining = False + + @socket_exception + def read(self): + """Reads data from stream and switch state.""" + assert self.status in (WAIT_LEN, WAIT_MESSAGE) + assert not self.received + buf_size = 8192 + first = True + done = False + while not done: + read = self.socket.recv(buf_size) + rlen = len(read) + done = rlen < buf_size + self._rbuf += read + if first and rlen == 0: + if self.status != WAIT_LEN or self._rbuf: + logger.error('could not read frame from socket') + else: + logger.debug('read zero length. client might have disconnected') + self.close() + while len(self._rbuf) >= self._reading.end: + if self._reading.is_header: + mlen, = struct.unpack('!i', self._rbuf[:4]) + self._reading = Message(self._reading.end, mlen, False) + self.status = WAIT_MESSAGE + else: + self._reading.buffer = self._rbuf + self.received.append(self._reading) + self._rbuf = self._rbuf[self._reading.end:] + self._reading = Message(0, 4, True) + first = False + if self.received: + self.status = WAIT_PROCESS + break + self.remaining = not done + + @socket_exception + def write(self): + """Writes data from socket and switch state.""" + assert self.status == SEND_ANSWER + sent = self.socket.send(self._wbuf) + if sent == len(self._wbuf): + self.status = WAIT_LEN + self._wbuf = b'' + self.len = 0 + else: + self._wbuf = self._wbuf[sent:] + + @locked + def ready(self, all_ok, message): + """Callback function for switching state and waking up main thread. + + This function is the only function witch can be called asynchronous. + + The ready can switch Connection to three states: + WAIT_LEN if request was oneway. + SEND_ANSWER if request was processed in normal way. + CLOSED if request throws unexpected exception. + + The one wakes up main thread. + """ + assert self.status == WAIT_PROCESS + if not all_ok: + self.close() + self.wake_up() + return + self.len = 0 + if len(message) == 0: + # it was a oneway request, do not write answer + self._wbuf = b'' + self.status = WAIT_LEN + else: + self._wbuf = struct.pack('!i', len(message)) + message + self.status = SEND_ANSWER + self.wake_up() + + @locked + def is_writeable(self): + """Return True if connection should be added to write list of select""" + return self.status == SEND_ANSWER + + # it's not necessary, but... + @locked + def is_readable(self): + """Return True if connection should be added to read list of select""" + return self.status in (WAIT_LEN, WAIT_MESSAGE) + + @locked + def is_closed(self): + """Returns True if connection is closed.""" + return self.status == CLOSED + + def fileno(self): + """Returns the file descriptor of the associated socket.""" + return self.socket.fileno() + + def close(self): + """Closes connection""" + self.status = CLOSED + self.socket.close() + + +class TNonblockingServer(object): + """Non-blocking server.""" + + def __init__(self, + processor, + lsocket, + inputProtocolFactory=None, + outputProtocolFactory=None, + threads=10): + self.processor = processor + self.socket = lsocket + self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory() + self.out_protocol = outputProtocolFactory or self.in_protocol + self.threads = int(threads) + self.clients = {} + self.tasks = queue.Queue() + self._read, self._write = socket.socketpair() + self.prepared = False + self._stop = False + + def setNumThreads(self, num): + """Set the number of worker threads that should be created.""" + # implement ThreadPool interface + assert not self.prepared, "Can't change number of threads after start" + self.threads = num + + def prepare(self): + """Prepares server for serve requests.""" + if self.prepared: + return + self.socket.listen() + for _ in range(self.threads): + thread = Worker(self.tasks) + thread.setDaemon(True) + thread.start() + self.prepared = True + + def wake_up(self): + """Wake up main thread. + + The server usually waits in select call in we should terminate one. + The simplest way is using socketpair. + + Select always wait to read from the first socket of socketpair. + + In this case, we can just write anything to the second socket from + socketpair. + """ + self._write.send(b'1') + + def stop(self): + """Stop the server. + + This method causes the serve() method to return. stop() may be invoked + from within your handler, or from another thread. + + After stop() is called, serve() will return but the server will still + be listening on the socket. serve() may then be called again to resume + processing requests. Alternatively, close() may be called after + serve() returns to close the server socket and shutdown all worker + threads. + """ + self._stop = True + self.wake_up() + + def _select(self): + """Does select on open connections.""" + readable = [self.socket.handle.fileno(), self._read.fileno()] + writable = [] + remaining = [] + for i, connection in list(self.clients.items()): + if connection.is_readable(): + readable.append(connection.fileno()) + if connection.remaining or connection.received: + remaining.append(connection.fileno()) + if connection.is_writeable(): + writable.append(connection.fileno()) + if connection.is_closed(): + del self.clients[i] + if remaining: + return remaining, [], [], False + else: + return select.select(readable, writable, readable) + (True,) + + def handle(self): + """Handle requests. + + WARNING! You must call prepare() BEFORE calling handle() + """ + assert self.prepared, "You have to call prepare before handle" + rset, wset, xset, selected = self._select() + for readable in rset: + if readable == self._read.fileno(): + # don't care i just need to clean readable flag + self._read.recv(1024) + elif readable == self.socket.handle.fileno(): + try: + client = self.socket.accept() + if client: + self.clients[client.handle.fileno()] = Connection(client.handle, + self.wake_up) + except socket.error: + logger.debug('error while accepting', exc_info=True) + else: + connection = self.clients[readable] + if selected: + connection.read() + if connection.received: + connection.status = WAIT_PROCESS + msg = connection.received.popleft() + itransport = TTransport.TMemoryBuffer(msg.buffer, msg.offset) + otransport = TTransport.TMemoryBuffer() + iprot = self.in_protocol.getProtocol(itransport) + oprot = self.out_protocol.getProtocol(otransport) + self.tasks.put([self.processor, iprot, oprot, + otransport, connection.ready]) + for writeable in wset: + self.clients[writeable].write() + for oob in xset: + self.clients[oob].close() + del self.clients[oob] + + def close(self): + """Closes the server.""" + for _ in range(self.threads): + self.tasks.put([None, None, None, None, None]) + self.socket.close() + self.prepared = False + + def serve(self): + """Serve requests. + + Serve requests forever, or until stop() is called. + """ + self._stop = False + self.prepare() + while not self._stop: + self.handle() diff --git a/src/jaegertracing/thrift/lib/py/src/server/TProcessPoolServer.py b/src/jaegertracing/thrift/lib/py/src/server/TProcessPoolServer.py new file mode 100644 index 000000000..fe6dc8162 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/server/TProcessPoolServer.py @@ -0,0 +1,123 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + + +import logging + +from multiprocessing import Process, Value, Condition + +from .TServer import TServer +from thrift.transport.TTransport import TTransportException + +logger = logging.getLogger(__name__) + + +class TProcessPoolServer(TServer): + """Server with a fixed size pool of worker subprocesses to service requests + + Note that if you need shared state between the handlers - it's up to you! + Written by Dvir Volk, doat.com + """ + def __init__(self, *args): + TServer.__init__(self, *args) + self.numWorkers = 10 + self.workers = [] + self.isRunning = Value('b', False) + self.stopCondition = Condition() + self.postForkCallback = None + + def setPostForkCallback(self, callback): + if not callable(callback): + raise TypeError("This is not a callback!") + self.postForkCallback = callback + + def setNumWorkers(self, num): + """Set the number of worker threads that should be created""" + self.numWorkers = num + + def workerProcess(self): + """Loop getting clients from the shared queue and process them""" + if self.postForkCallback: + self.postForkCallback() + + while self.isRunning.value: + try: + client = self.serverTransport.accept() + if not client: + continue + self.serveClient(client) + except (KeyboardInterrupt, SystemExit): + return 0 + except Exception as x: + logger.exception(x) + + def serveClient(self, client): + """Process input/output from a client for as long as possible""" + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + + try: + while True: + self.processor.process(iprot, oprot) + except TTransportException: + pass + except Exception as x: + logger.exception(x) + + itrans.close() + otrans.close() + + def serve(self): + """Start workers and put into queue""" + # this is a shared state that can tell the workers to exit when False + self.isRunning.value = True + + # first bind and listen to the port + self.serverTransport.listen() + + # fork the children + for i in range(self.numWorkers): + try: + w = Process(target=self.workerProcess) + w.daemon = True + w.start() + self.workers.append(w) + except Exception as x: + logger.exception(x) + + # wait until the condition is set by stop() + while True: + self.stopCondition.acquire() + try: + self.stopCondition.wait() + break + except (SystemExit, KeyboardInterrupt): + break + except Exception as x: + logger.exception(x) + + self.isRunning.value = False + + def stop(self): + self.isRunning.value = False + self.stopCondition.acquire() + self.stopCondition.notify() + self.stopCondition.release() diff --git a/src/jaegertracing/thrift/lib/py/src/server/TServer.py b/src/jaegertracing/thrift/lib/py/src/server/TServer.py new file mode 100644 index 000000000..df2a7bb93 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/server/TServer.py @@ -0,0 +1,323 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from six.moves import queue +import logging +import os +import threading + +from thrift.protocol import TBinaryProtocol +from thrift.protocol.THeaderProtocol import THeaderProtocolFactory +from thrift.transport import TTransport + +logger = logging.getLogger(__name__) + + +class TServer(object): + """Base interface for a server, which must have a serve() method. + + Three constructors for all servers: + 1) (processor, serverTransport) + 2) (processor, serverTransport, transportFactory, protocolFactory) + 3) (processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory) + """ + def __init__(self, *args): + if (len(args) == 2): + self.__initArgs__(args[0], args[1], + TTransport.TTransportFactoryBase(), + TTransport.TTransportFactoryBase(), + TBinaryProtocol.TBinaryProtocolFactory(), + TBinaryProtocol.TBinaryProtocolFactory()) + elif (len(args) == 4): + self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3]) + elif (len(args) == 6): + self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5]) + + def __initArgs__(self, processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory): + self.processor = processor + self.serverTransport = serverTransport + self.inputTransportFactory = inputTransportFactory + self.outputTransportFactory = outputTransportFactory + self.inputProtocolFactory = inputProtocolFactory + self.outputProtocolFactory = outputProtocolFactory + + input_is_header = isinstance(self.inputProtocolFactory, THeaderProtocolFactory) + output_is_header = isinstance(self.outputProtocolFactory, THeaderProtocolFactory) + if any((input_is_header, output_is_header)) and input_is_header != output_is_header: + raise ValueError("THeaderProtocol servers require that both the input and " + "output protocols are THeaderProtocol.") + + def serve(self): + pass + + +class TSimpleServer(TServer): + """Simple single-threaded server that just pumps around one transport.""" + + def __init__(self, *args): + TServer.__init__(self, *args) + + def serve(self): + self.serverTransport.listen() + while True: + client = self.serverTransport.accept() + if not client: + continue + + itrans = self.inputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + + # for THeaderProtocol, we must use the same protocol instance for + # input and output so that the response is in the same dialect that + # the server detected the request was in. + if isinstance(self.inputProtocolFactory, THeaderProtocolFactory): + otrans = None + oprot = iprot + else: + otrans = self.outputTransportFactory.getTransport(client) + oprot = self.outputProtocolFactory.getProtocol(otrans) + + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException: + pass + except Exception as x: + logger.exception(x) + + itrans.close() + if otrans: + otrans.close() + + +class TThreadedServer(TServer): + """Threaded server that spawns a new thread per each connection.""" + + def __init__(self, *args, **kwargs): + TServer.__init__(self, *args) + self.daemon = kwargs.get("daemon", False) + + def serve(self): + self.serverTransport.listen() + while True: + try: + client = self.serverTransport.accept() + if not client: + continue + t = threading.Thread(target=self.handle, args=(client,)) + t.setDaemon(self.daemon) + t.start() + except KeyboardInterrupt: + raise + except Exception as x: + logger.exception(x) + + def handle(self, client): + itrans = self.inputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + + # for THeaderProtocol, we must use the same protocol instance for input + # and output so that the response is in the same dialect that the + # server detected the request was in. + if isinstance(self.inputProtocolFactory, THeaderProtocolFactory): + otrans = None + oprot = iprot + else: + otrans = self.outputTransportFactory.getTransport(client) + oprot = self.outputProtocolFactory.getProtocol(otrans) + + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException: + pass + except Exception as x: + logger.exception(x) + + itrans.close() + if otrans: + otrans.close() + + +class TThreadPoolServer(TServer): + """Server with a fixed size pool of threads which service requests.""" + + def __init__(self, *args, **kwargs): + TServer.__init__(self, *args) + self.clients = queue.Queue() + self.threads = 10 + self.daemon = kwargs.get("daemon", False) + + def setNumThreads(self, num): + """Set the number of worker threads that should be created""" + self.threads = num + + def serveThread(self): + """Loop around getting clients from the shared queue and process them.""" + while True: + try: + client = self.clients.get() + self.serveClient(client) + except Exception as x: + logger.exception(x) + + def serveClient(self, client): + """Process input/output from a client for as long as possible""" + itrans = self.inputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + + # for THeaderProtocol, we must use the same protocol instance for input + # and output so that the response is in the same dialect that the + # server detected the request was in. + if isinstance(self.inputProtocolFactory, THeaderProtocolFactory): + otrans = None + oprot = iprot + else: + otrans = self.outputTransportFactory.getTransport(client) + oprot = self.outputProtocolFactory.getProtocol(otrans) + + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException: + pass + except Exception as x: + logger.exception(x) + + itrans.close() + if otrans: + otrans.close() + + def serve(self): + """Start a fixed number of worker threads and put client into a queue""" + for i in range(self.threads): + try: + t = threading.Thread(target=self.serveThread) + t.setDaemon(self.daemon) + t.start() + except Exception as x: + logger.exception(x) + + # Pump the socket for clients + self.serverTransport.listen() + while True: + try: + client = self.serverTransport.accept() + if not client: + continue + self.clients.put(client) + except Exception as x: + logger.exception(x) + + +class TForkingServer(TServer): + """A Thrift server that forks a new process for each request + + This is more scalable than the threaded server as it does not cause + GIL contention. + + Note that this has different semantics from the threading server. + Specifically, updates to shared variables will no longer be shared. + It will also not work on windows. + + This code is heavily inspired by SocketServer.ForkingMixIn in the + Python stdlib. + """ + def __init__(self, *args): + TServer.__init__(self, *args) + self.children = [] + + def serve(self): + def try_close(file): + try: + file.close() + except IOError as e: + logger.warning(e, exc_info=True) + + self.serverTransport.listen() + while True: + client = self.serverTransport.accept() + if not client: + continue + try: + pid = os.fork() + + if pid: # parent + # add before collect, otherwise you race w/ waitpid + self.children.append(pid) + self.collect_children() + + # Parent must close socket or the connection may not get + # closed promptly + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + try_close(itrans) + try_close(otrans) + else: + itrans = self.inputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + + # for THeaderProtocol, we must use the same protocol + # instance for input and output so that the response is in + # the same dialect that the server detected the request was + # in. + if isinstance(self.inputProtocolFactory, THeaderProtocolFactory): + otrans = None + oprot = iprot + else: + otrans = self.outputTransportFactory.getTransport(client) + oprot = self.outputProtocolFactory.getProtocol(otrans) + + ecode = 0 + try: + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException: + pass + except Exception as e: + logger.exception(e) + ecode = 1 + finally: + try_close(itrans) + if otrans: + try_close(otrans) + + os._exit(ecode) + + except TTransport.TTransportException: + pass + except Exception as x: + logger.exception(x) + + def collect_children(self): + while self.children: + try: + pid, status = os.waitpid(0, os.WNOHANG) + except os.error: + pid = None + + if pid: + self.children.remove(pid) + else: + break diff --git a/src/jaegertracing/thrift/lib/py/src/server/__init__.py b/src/jaegertracing/thrift/lib/py/src/server/__init__.py new file mode 100644 index 000000000..1bf6e254e --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/server/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['TServer', 'TNonblockingServer'] diff --git a/src/jaegertracing/thrift/lib/py/src/transport/THeaderTransport.py b/src/jaegertracing/thrift/lib/py/src/transport/THeaderTransport.py new file mode 100644 index 000000000..c0d564012 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/transport/THeaderTransport.py @@ -0,0 +1,352 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import struct +import zlib + +from thrift.compat import BufferIO, byte_index +from thrift.protocol.TBinaryProtocol import TBinaryProtocol +from thrift.protocol.TCompactProtocol import TCompactProtocol, readVarint, writeVarint +from thrift.Thrift import TApplicationException +from thrift.transport.TTransport import ( + CReadableTransport, + TMemoryBuffer, + TTransportBase, + TTransportException, +) + + +U16 = struct.Struct("!H") +I32 = struct.Struct("!i") +HEADER_MAGIC = 0x0FFF +HARD_MAX_FRAME_SIZE = 0x3FFFFFFF + + +class THeaderClientType(object): + HEADERS = 0x00 + + FRAMED_BINARY = 0x01 + UNFRAMED_BINARY = 0x02 + + FRAMED_COMPACT = 0x03 + UNFRAMED_COMPACT = 0x04 + + +class THeaderSubprotocolID(object): + BINARY = 0x00 + COMPACT = 0x02 + + +class TInfoHeaderType(object): + KEY_VALUE = 0x01 + + +class THeaderTransformID(object): + ZLIB = 0x01 + + +READ_TRANSFORMS_BY_ID = { + THeaderTransformID.ZLIB: zlib.decompress, +} + + +WRITE_TRANSFORMS_BY_ID = { + THeaderTransformID.ZLIB: zlib.compress, +} + + +def _readString(trans): + size = readVarint(trans) + if size < 0: + raise TTransportException( + TTransportException.NEGATIVE_SIZE, + "Negative length" + ) + return trans.read(size) + + +def _writeString(trans, value): + writeVarint(trans, len(value)) + trans.write(value) + + +class THeaderTransport(TTransportBase, CReadableTransport): + def __init__(self, transport, allowed_client_types): + self._transport = transport + self._client_type = THeaderClientType.HEADERS + self._allowed_client_types = allowed_client_types + + self._read_buffer = BufferIO(b"") + self._read_headers = {} + + self._write_buffer = BufferIO() + self._write_headers = {} + self._write_transforms = [] + + self.flags = 0 + self.sequence_id = 0 + self._protocol_id = THeaderSubprotocolID.BINARY + self._max_frame_size = HARD_MAX_FRAME_SIZE + + def isOpen(self): + return self._transport.isOpen() + + def open(self): + return self._transport.open() + + def close(self): + return self._transport.close() + + def get_headers(self): + return self._read_headers + + def set_header(self, key, value): + if not isinstance(key, bytes): + raise ValueError("header names must be bytes") + if not isinstance(value, bytes): + raise ValueError("header values must be bytes") + self._write_headers[key] = value + + def clear_headers(self): + self._write_headers.clear() + + def add_transform(self, transform_id): + if transform_id not in WRITE_TRANSFORMS_BY_ID: + raise ValueError("unknown transform") + self._write_transforms.append(transform_id) + + def set_max_frame_size(self, size): + if not 0 < size < HARD_MAX_FRAME_SIZE: + raise ValueError("maximum frame size should be < %d and > 0" % HARD_MAX_FRAME_SIZE) + self._max_frame_size = size + + @property + def protocol_id(self): + if self._client_type == THeaderClientType.HEADERS: + return self._protocol_id + elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.UNFRAMED_BINARY): + return THeaderSubprotocolID.BINARY + elif self._client_type in (THeaderClientType.FRAMED_COMPACT, THeaderClientType.UNFRAMED_COMPACT): + return THeaderSubprotocolID.COMPACT + else: + raise TTransportException( + TTransportException.INVALID_CLIENT_TYPE, + "Protocol ID not know for client type %d" % self._client_type, + ) + + def read(self, sz): + # if there are bytes left in the buffer, produce those first. + bytes_read = self._read_buffer.read(sz) + bytes_left_to_read = sz - len(bytes_read) + if bytes_left_to_read == 0: + return bytes_read + + # if we've determined this is an unframed client, just pass the read + # through to the underlying transport until we're reset again at the + # beginning of the next message. + if self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT): + return bytes_read + self._transport.read(bytes_left_to_read) + + # we're empty and (maybe) framed. fill the buffers with the next frame. + self.readFrame(bytes_left_to_read) + return bytes_read + self._read_buffer.read(bytes_left_to_read) + + def _set_client_type(self, client_type): + if client_type not in self._allowed_client_types: + raise TTransportException( + TTransportException.INVALID_CLIENT_TYPE, + "Client type %d not allowed by server." % client_type, + ) + self._client_type = client_type + + def readFrame(self, req_sz): + # the first word could either be the length field of a framed message + # or the first bytes of an unframed message. + first_word = self._transport.readAll(I32.size) + frame_size, = I32.unpack(first_word) + is_unframed = False + if frame_size & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1: + self._set_client_type(THeaderClientType.UNFRAMED_BINARY) + is_unframed = True + elif (byte_index(first_word, 0) == TCompactProtocol.PROTOCOL_ID and + byte_index(first_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION): + self._set_client_type(THeaderClientType.UNFRAMED_COMPACT) + is_unframed = True + + if is_unframed: + bytes_left_to_read = req_sz - I32.size + if bytes_left_to_read > 0: + rest = self._transport.read(bytes_left_to_read) + else: + rest = b"" + self._read_buffer = BufferIO(first_word + rest) + return + + # ok, we're still here so we're framed. + if frame_size > self._max_frame_size: + raise TTransportException( + TTransportException.SIZE_LIMIT, + "Frame was too large.", + ) + read_buffer = BufferIO(self._transport.readAll(frame_size)) + + # the next word is either going to be the version field of a + # binary/compact protocol message or the magic value + flags of a + # header protocol message. + second_word = read_buffer.read(I32.size) + version, = I32.unpack(second_word) + read_buffer.seek(0) + if version >> 16 == HEADER_MAGIC: + self._set_client_type(THeaderClientType.HEADERS) + self._read_buffer = self._parse_header_format(read_buffer) + elif version & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1: + self._set_client_type(THeaderClientType.FRAMED_BINARY) + self._read_buffer = read_buffer + elif (byte_index(second_word, 0) == TCompactProtocol.PROTOCOL_ID and + byte_index(second_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION): + self._set_client_type(THeaderClientType.FRAMED_COMPACT) + self._read_buffer = read_buffer + else: + raise TTransportException( + TTransportException.INVALID_CLIENT_TYPE, + "Could not detect client transport type.", + ) + + def _parse_header_format(self, buffer): + # make BufferIO look like TTransport for varint helpers + buffer_transport = TMemoryBuffer() + buffer_transport._buffer = buffer + + buffer.read(2) # discard the magic bytes + self.flags, = U16.unpack(buffer.read(U16.size)) + self.sequence_id, = I32.unpack(buffer.read(I32.size)) + + header_length = U16.unpack(buffer.read(U16.size))[0] * 4 + end_of_headers = buffer.tell() + header_length + if end_of_headers > len(buffer.getvalue()): + raise TTransportException( + TTransportException.SIZE_LIMIT, + "Header size is larger than whole frame.", + ) + + self._protocol_id = readVarint(buffer_transport) + + transforms = [] + transform_count = readVarint(buffer_transport) + for _ in range(transform_count): + transform_id = readVarint(buffer_transport) + if transform_id not in READ_TRANSFORMS_BY_ID: + raise TApplicationException( + TApplicationException.INVALID_TRANSFORM, + "Unknown transform: %d" % transform_id, + ) + transforms.append(transform_id) + transforms.reverse() + + headers = {} + while buffer.tell() < end_of_headers: + header_type = readVarint(buffer_transport) + if header_type == TInfoHeaderType.KEY_VALUE: + count = readVarint(buffer_transport) + for _ in range(count): + key = _readString(buffer_transport) + value = _readString(buffer_transport) + headers[key] = value + else: + break # ignore unknown headers + self._read_headers = headers + + # skip padding / anything we didn't understand + buffer.seek(end_of_headers) + + payload = buffer.read() + for transform_id in transforms: + transform_fn = READ_TRANSFORMS_BY_ID[transform_id] + payload = transform_fn(payload) + return BufferIO(payload) + + def write(self, buf): + self._write_buffer.write(buf) + + def flush(self): + payload = self._write_buffer.getvalue() + self._write_buffer = BufferIO() + + buffer = BufferIO() + if self._client_type == THeaderClientType.HEADERS: + for transform_id in self._write_transforms: + transform_fn = WRITE_TRANSFORMS_BY_ID[transform_id] + payload = transform_fn(payload) + + headers = BufferIO() + writeVarint(headers, self._protocol_id) + writeVarint(headers, len(self._write_transforms)) + for transform_id in self._write_transforms: + writeVarint(headers, transform_id) + if self._write_headers: + writeVarint(headers, TInfoHeaderType.KEY_VALUE) + writeVarint(headers, len(self._write_headers)) + for key, value in self._write_headers.items(): + _writeString(headers, key) + _writeString(headers, value) + self._write_headers = {} + padding_needed = (4 - (len(headers.getvalue()) % 4)) % 4 + headers.write(b"\x00" * padding_needed) + header_bytes = headers.getvalue() + + buffer.write(I32.pack(10 + len(header_bytes) + len(payload))) + buffer.write(U16.pack(HEADER_MAGIC)) + buffer.write(U16.pack(self.flags)) + buffer.write(I32.pack(self.sequence_id)) + buffer.write(U16.pack(len(header_bytes) // 4)) + buffer.write(header_bytes) + buffer.write(payload) + elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.FRAMED_COMPACT): + buffer.write(I32.pack(len(payload))) + buffer.write(payload) + elif self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT): + buffer.write(payload) + else: + raise TTransportException( + TTransportException.INVALID_CLIENT_TYPE, + "Unknown client type.", + ) + + # the frame length field doesn't count towards the frame payload size + frame_bytes = buffer.getvalue() + frame_payload_size = len(frame_bytes) - 4 + if frame_payload_size > self._max_frame_size: + raise TTransportException( + TTransportException.SIZE_LIMIT, + "Attempting to send frame that is too large.", + ) + + self._transport.write(frame_bytes) + self._transport.flush() + + @property + def cstringio_buf(self): + return self._read_buffer + + def cstringio_refill(self, partialread, reqlen): + result = bytearray(partialread) + while len(result) < reqlen: + result += self.read(reqlen - len(result)) + self._read_buffer = BufferIO(result) + return self._read_buffer diff --git a/src/jaegertracing/thrift/lib/py/src/transport/THttpClient.py b/src/jaegertracing/thrift/lib/py/src/transport/THttpClient.py new file mode 100644 index 000000000..37b0a4d8d --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/transport/THttpClient.py @@ -0,0 +1,187 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from io import BytesIO +import os +import ssl +import sys +import warnings +import base64 + +from six.moves import urllib +from six.moves import http_client + +from .TTransport import TTransportBase +import six + + +class THttpClient(TTransportBase): + """Http implementation of TTransport base.""" + + def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=None, key_file=None, ssl_context=None): + """THttpClient supports two different types of construction: + + THttpClient(host, port, path) - deprecated + THttpClient(uri, [port=, path=, cafile=, cert_file=, key_file=, ssl_context=]) + + Only the second supports https. To properly authenticate against the server, + provide the client's identity by specifying cert_file and key_file. To properly + authenticate the server, specify either cafile or ssl_context with a CA defined. + NOTE: if both cafile and ssl_context are defined, ssl_context will override cafile. + """ + if port is not None: + warnings.warn( + "Please use the THttpClient('http{s}://host:port/path') constructor", + DeprecationWarning, + stacklevel=2) + self.host = uri_or_host + self.port = port + assert path + self.path = path + self.scheme = 'http' + else: + parsed = urllib.parse.urlparse(uri_or_host) + self.scheme = parsed.scheme + assert self.scheme in ('http', 'https') + if self.scheme == 'http': + self.port = parsed.port or http_client.HTTP_PORT + elif self.scheme == 'https': + self.port = parsed.port or http_client.HTTPS_PORT + self.certfile = cert_file + self.keyfile = key_file + self.context = ssl.create_default_context(cafile=cafile) if (cafile and not ssl_context) else ssl_context + self.host = parsed.hostname + self.path = parsed.path + if parsed.query: + self.path += '?%s' % parsed.query + try: + proxy = urllib.request.getproxies()[self.scheme] + except KeyError: + proxy = None + else: + if urllib.request.proxy_bypass(self.host): + proxy = None + if proxy: + parsed = urllib.parse.urlparse(proxy) + self.realhost = self.host + self.realport = self.port + self.host = parsed.hostname + self.port = parsed.port + self.proxy_auth = self.basic_proxy_auth_header(parsed) + else: + self.realhost = self.realport = self.proxy_auth = None + self.__wbuf = BytesIO() + self.__http = None + self.__http_response = None + self.__timeout = None + self.__custom_headers = None + + @staticmethod + def basic_proxy_auth_header(proxy): + if proxy is None or not proxy.username: + return None + ap = "%s:%s" % (urllib.parse.unquote(proxy.username), + urllib.parse.unquote(proxy.password)) + cr = base64.b64encode(ap).strip() + return "Basic " + cr + + def using_proxy(self): + return self.realhost is not None + + def open(self): + if self.scheme == 'http': + self.__http = http_client.HTTPConnection(self.host, self.port, + timeout=self.__timeout) + elif self.scheme == 'https': + self.__http = http_client.HTTPSConnection(self.host, self.port, + key_file=self.keyfile, + cert_file=self.certfile, + timeout=self.__timeout, + context=self.context) + if self.using_proxy(): + self.__http.set_tunnel(self.realhost, self.realport, + {"Proxy-Authorization": self.proxy_auth}) + + def close(self): + self.__http.close() + self.__http = None + self.__http_response = None + + def isOpen(self): + return self.__http is not None + + def setTimeout(self, ms): + if ms is None: + self.__timeout = None + else: + self.__timeout = ms / 1000.0 + + def setCustomHeaders(self, headers): + self.__custom_headers = headers + + def read(self, sz): + return self.__http_response.read(sz) + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + if self.isOpen(): + self.close() + self.open() + + # Pull data out of buffer + data = self.__wbuf.getvalue() + self.__wbuf = BytesIO() + + # HTTP request + if self.using_proxy() and self.scheme == "http": + # need full URL of real host for HTTP proxy here (HTTPS uses CONNECT tunnel) + self.__http.putrequest('POST', "http://%s:%s%s" % + (self.realhost, self.realport, self.path)) + else: + self.__http.putrequest('POST', self.path) + + # Write headers + self.__http.putheader('Content-Type', 'application/x-thrift') + self.__http.putheader('Content-Length', str(len(data))) + if self.using_proxy() and self.scheme == "http" and self.proxy_auth is not None: + self.__http.putheader("Proxy-Authorization", self.proxy_auth) + + if not self.__custom_headers or 'User-Agent' not in self.__custom_headers: + user_agent = 'Python/THttpClient' + script = os.path.basename(sys.argv[0]) + if script: + user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script)) + self.__http.putheader('User-Agent', user_agent) + + if self.__custom_headers: + for key, val in six.iteritems(self.__custom_headers): + self.__http.putheader(key, val) + + self.__http.endheaders() + + # Write payload + self.__http.send(data) + + # Get reply to flush the request + self.__http_response = self.__http.getresponse() + self.code = self.__http_response.status + self.message = self.__http_response.reason + self.headers = self.__http_response.msg diff --git a/src/jaegertracing/thrift/lib/py/src/transport/TSSLSocket.py b/src/jaegertracing/thrift/lib/py/src/transport/TSSLSocket.py new file mode 100644 index 000000000..5b3ae5991 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/transport/TSSLSocket.py @@ -0,0 +1,408 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import logging +import os +import socket +import ssl +import sys +import warnings + +from .sslcompat import _match_hostname, _match_has_ipaddress +from thrift.transport import TSocket +from thrift.transport.TTransport import TTransportException + +logger = logging.getLogger(__name__) +warnings.filterwarnings( + 'default', category=DeprecationWarning, module=__name__) + + +class TSSLBase(object): + # SSLContext is not available for Python < 2.7.9 + _has_ssl_context = sys.hexversion >= 0x020709F0 + + # ciphers argument is not available for Python < 2.7.0 + _has_ciphers = sys.hexversion >= 0x020700F0 + + # For python >= 2.7.9, use latest TLS that both client and server + # supports. + # SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3. + # For python < 2.7.9, use TLS 1.0 since TLSv1_X nor OP_NO_SSLvX is + # unavailable. + _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else \ + ssl.PROTOCOL_TLSv1 + + def _init_context(self, ssl_version): + if self._has_ssl_context: + self._context = ssl.SSLContext(ssl_version) + if self._context.protocol == ssl.PROTOCOL_SSLv23: + self._context.options |= ssl.OP_NO_SSLv2 + self._context.options |= ssl.OP_NO_SSLv3 + else: + self._context = None + self._ssl_version = ssl_version + + @property + def _should_verify(self): + if self._has_ssl_context: + return self._context.verify_mode != ssl.CERT_NONE + else: + return self.cert_reqs != ssl.CERT_NONE + + @property + def ssl_version(self): + if self._has_ssl_context: + return self.ssl_context.protocol + else: + return self._ssl_version + + @property + def ssl_context(self): + return self._context + + SSL_VERSION = _default_protocol + """ + Default SSL version. + For backwards compatibility, it can be modified. + Use __init__ keyword argument "ssl_version" instead. + """ + + def _deprecated_arg(self, args, kwargs, pos, key): + if len(args) <= pos: + return + real_pos = pos + 3 + warnings.warn( + '%dth positional argument is deprecated.' + 'please use keyword argument instead.' + % real_pos, DeprecationWarning, stacklevel=3) + + if key in kwargs: + raise TypeError( + 'Duplicate argument: %dth argument and %s keyword argument.' + % (real_pos, key)) + kwargs[key] = args[pos] + + def _unix_socket_arg(self, host, port, args, kwargs): + key = 'unix_socket' + if host is None and port is None and len(args) == 1 and key not in kwargs: + kwargs[key] = args[0] + return True + return False + + def __getattr__(self, key): + if key == 'SSL_VERSION': + warnings.warn( + 'SSL_VERSION is deprecated.' + 'please use ssl_version attribute instead.', + DeprecationWarning, stacklevel=2) + return self.ssl_version + + def __init__(self, server_side, host, ssl_opts): + self._server_side = server_side + if TSSLBase.SSL_VERSION != self._default_protocol: + warnings.warn( + 'SSL_VERSION is deprecated.' + 'please use ssl_version keyword argument instead.', + DeprecationWarning, stacklevel=2) + self._context = ssl_opts.pop('ssl_context', None) + self._server_hostname = None + if not self._server_side: + self._server_hostname = ssl_opts.pop('server_hostname', host) + if self._context: + self._custom_context = True + if ssl_opts: + raise ValueError( + 'Incompatible arguments: ssl_context and %s' + % ' '.join(ssl_opts.keys())) + if not self._has_ssl_context: + raise ValueError( + 'ssl_context is not available for this version of Python') + else: + self._custom_context = False + ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION) + self._init_context(ssl_version) + self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED) + self.ca_certs = ssl_opts.pop('ca_certs', None) + self.keyfile = ssl_opts.pop('keyfile', None) + self.certfile = ssl_opts.pop('certfile', None) + self.ciphers = ssl_opts.pop('ciphers', None) + + if ssl_opts: + raise ValueError( + 'Unknown keyword arguments: ', ' '.join(ssl_opts.keys())) + + if self._should_verify: + if not self.ca_certs: + raise ValueError( + 'ca_certs is needed when cert_reqs is not ssl.CERT_NONE') + if not os.access(self.ca_certs, os.R_OK): + raise IOError('Certificate Authority ca_certs file "%s" ' + 'is not readable, cannot validate SSL ' + 'certificates.' % (self.ca_certs)) + + @property + def certfile(self): + return self._certfile + + @certfile.setter + def certfile(self, certfile): + if self._server_side and not certfile: + raise ValueError('certfile is needed for server-side') + if certfile and not os.access(certfile, os.R_OK): + raise IOError('No such certfile found: %s' % (certfile)) + self._certfile = certfile + + def _wrap_socket(self, sock): + if self._has_ssl_context: + if not self._custom_context: + self.ssl_context.verify_mode = self.cert_reqs + if self.certfile: + self.ssl_context.load_cert_chain(self.certfile, + self.keyfile) + if self.ciphers: + self.ssl_context.set_ciphers(self.ciphers) + if self.ca_certs: + self.ssl_context.load_verify_locations(self.ca_certs) + return self.ssl_context.wrap_socket( + sock, server_side=self._server_side, + server_hostname=self._server_hostname) + else: + ssl_opts = { + 'ssl_version': self._ssl_version, + 'server_side': self._server_side, + 'ca_certs': self.ca_certs, + 'keyfile': self.keyfile, + 'certfile': self.certfile, + 'cert_reqs': self.cert_reqs, + } + if self.ciphers: + if self._has_ciphers: + ssl_opts['ciphers'] = self.ciphers + else: + logger.warning( + 'ciphers is specified but ignored due to old Python version') + return ssl.wrap_socket(sock, **ssl_opts) + + +class TSSLSocket(TSocket.TSocket, TSSLBase): + """ + SSL implementation of TSocket + + This class creates outbound sockets wrapped using the + python standard ssl module for encrypted connections. + """ + + # New signature + # def __init__(self, host='localhost', port=9090, unix_socket=None, + # **ssl_args): + # Deprecated signature + # def __init__(self, host='localhost', port=9090, validate=True, + # ca_certs=None, keyfile=None, certfile=None, + # unix_socket=None, ciphers=None): + def __init__(self, host='localhost', port=9090, *args, **kwargs): + """Positional arguments: ``host``, ``port``, ``unix_socket`` + + Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, + ``ssl_version``, ``ca_certs``, + ``ciphers`` (Python 2.7.0 or later), + ``server_hostname`` (Python 2.7.9 or later) + Passed to ssl.wrap_socket. See ssl.wrap_socket documentation. + + Alternative keyword arguments: (Python 2.7.9 or later) + ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket + ``server_hostname``: Passed to SSLContext.wrap_socket + + Common keyword argument: + ``validate_callback`` (cert, hostname) -> None: + Called after SSL handshake. Can raise when hostname does not + match the cert. + ``socket_keepalive`` enable TCP keepalive, default off. + """ + self.is_valid = False + self.peercert = None + + if args: + if len(args) > 6: + raise TypeError('Too many positional argument') + if not self._unix_socket_arg(host, port, args, kwargs): + self._deprecated_arg(args, kwargs, 0, 'validate') + self._deprecated_arg(args, kwargs, 1, 'ca_certs') + self._deprecated_arg(args, kwargs, 2, 'keyfile') + self._deprecated_arg(args, kwargs, 3, 'certfile') + self._deprecated_arg(args, kwargs, 4, 'unix_socket') + self._deprecated_arg(args, kwargs, 5, 'ciphers') + + validate = kwargs.pop('validate', None) + if validate is not None: + cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE' + warnings.warn( + 'validate is deprecated. please use cert_reqs=ssl.%s instead' + % cert_reqs_name, + DeprecationWarning, stacklevel=2) + if 'cert_reqs' in kwargs: + raise TypeError('Cannot specify both validate and cert_reqs') + kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE + + unix_socket = kwargs.pop('unix_socket', None) + socket_keepalive = kwargs.pop('socket_keepalive', False) + self._validate_callback = kwargs.pop('validate_callback', _match_hostname) + TSSLBase.__init__(self, False, host, kwargs) + TSocket.TSocket.__init__(self, host, port, unix_socket, + socket_keepalive=socket_keepalive) + + def close(self): + try: + self.handle.settimeout(0.001) + self.handle = self.handle.unwrap() + except (ssl.SSLError, socket.error, OSError): + # could not complete shutdown in a reasonable amount of time. bail. + pass + TSocket.TSocket.close(self) + + @property + def validate(self): + warnings.warn('validate is deprecated. please use cert_reqs instead', + DeprecationWarning, stacklevel=2) + return self.cert_reqs != ssl.CERT_NONE + + @validate.setter + def validate(self, value): + warnings.warn('validate is deprecated. please use cert_reqs instead', + DeprecationWarning, stacklevel=2) + self.cert_reqs = ssl.CERT_REQUIRED if value else ssl.CERT_NONE + + def _do_open(self, family, socktype): + plain_sock = socket.socket(family, socktype) + try: + return self._wrap_socket(plain_sock) + except Exception as ex: + plain_sock.close() + msg = 'failed to initialize SSL' + logger.exception(msg) + raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=ex) + + def open(self): + super(TSSLSocket, self).open() + if self._should_verify: + self.peercert = self.handle.getpeercert() + try: + self._validate_callback(self.peercert, self._server_hostname) + self.is_valid = True + except TTransportException: + raise + except Exception as ex: + raise TTransportException(message=str(ex), inner=ex) + + +class TSSLServerSocket(TSocket.TServerSocket, TSSLBase): + """SSL implementation of TServerSocket + + This uses the ssl module's wrap_socket() method to provide SSL + negotiated encryption. + """ + + # New signature + # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args): + # Deprecated signature + # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None): + def __init__(self, host=None, port=9090, *args, **kwargs): + """Positional arguments: ``host``, ``port``, ``unix_socket`` + + Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``, + ``ca_certs``, ``ciphers`` (Python 2.7.0 or later) + See ssl.wrap_socket documentation. + + Alternative keyword arguments: (Python 2.7.9 or later) + ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket + ``server_hostname``: Passed to SSLContext.wrap_socket + + Common keyword argument: + ``validate_callback`` (cert, hostname) -> None: + Called after SSL handshake. Can raise when hostname does not + match the cert. + """ + if args: + if len(args) > 3: + raise TypeError('Too many positional argument') + if not self._unix_socket_arg(host, port, args, kwargs): + self._deprecated_arg(args, kwargs, 0, 'certfile') + self._deprecated_arg(args, kwargs, 1, 'unix_socket') + self._deprecated_arg(args, kwargs, 2, 'ciphers') + + if 'ssl_context' not in kwargs: + # Preserve existing behaviors for default values + if 'cert_reqs' not in kwargs: + kwargs['cert_reqs'] = ssl.CERT_NONE + if'certfile' not in kwargs: + kwargs['certfile'] = 'cert.pem' + + unix_socket = kwargs.pop('unix_socket', None) + self._validate_callback = \ + kwargs.pop('validate_callback', _match_hostname) + TSSLBase.__init__(self, True, None, kwargs) + TSocket.TServerSocket.__init__(self, host, port, unix_socket) + if self._should_verify and not _match_has_ipaddress: + raise ValueError('Need ipaddress and backports.ssl_match_hostname ' + 'module to verify client certificate') + + def setCertfile(self, certfile): + """Set or change the server certificate file used to wrap new + connections. + + @param certfile: The filename of the server certificate, + i.e. '/etc/certs/server.pem' + @type certfile: str + + Raises an IOError exception if the certfile is not present or unreadable. + """ + warnings.warn( + 'setCertfile is deprecated. please use certfile property instead.', + DeprecationWarning, stacklevel=2) + self.certfile = certfile + + def accept(self): + plain_client, addr = self.handle.accept() + try: + client = self._wrap_socket(plain_client) + except (ssl.SSLError, socket.error, OSError): + logger.exception('Error while accepting from %s', addr) + # failed handshake/ssl wrap, close socket to client + plain_client.close() + # raise + # We can't raise the exception, because it kills most TServer derived + # serve() methods. + # Instead, return None, and let the TServer instance deal with it in + # other exception handling. (but TSimpleServer dies anyway) + return None + + if self._should_verify: + client.peercert = client.getpeercert() + try: + self._validate_callback(client.peercert, addr[0]) + client.is_valid = True + except Exception: + logger.warn('Failed to validate client certificate address: %s', + addr[0], exc_info=True) + client.close() + plain_client.close() + return None + + result = TSocket.TSocket() + result.handle = client + return result diff --git a/src/jaegertracing/thrift/lib/py/src/transport/TSocket.py b/src/jaegertracing/thrift/lib/py/src/transport/TSocket.py new file mode 100644 index 000000000..df25d42db --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/transport/TSocket.py @@ -0,0 +1,215 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import errno +import logging +import os +import socket +import sys + +from .TTransport import TTransportBase, TTransportException, TServerTransportBase + +logger = logging.getLogger(__name__) + + +class TSocketBase(TTransportBase): + def _resolveAddr(self): + if self._unix_socket is not None: + return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, + self._unix_socket)] + else: + return socket.getaddrinfo(self.host, + self.port, + self._socket_family, + socket.SOCK_STREAM, + 0, + socket.AI_PASSIVE | socket.AI_ADDRCONFIG) + + def close(self): + if self.handle: + self.handle.close() + self.handle = None + + +class TSocket(TSocketBase): + """Socket implementation of TTransport base.""" + + def __init__(self, host='localhost', port=9090, unix_socket=None, + socket_family=socket.AF_UNSPEC, + socket_keepalive=False): + """Initialize a TSocket + + @param host(str) The host to connect to. + @param port(int) The (TCP) port to connect to. + @param unix_socket(str) The filename of a unix socket to connect to. + (host and port will be ignored.) + @param socket_family(int) The socket family to use with this socket. + @param socket_keepalive(bool) enable TCP keepalive, default off. + """ + self.host = host + self.port = port + self.handle = None + self._unix_socket = unix_socket + self._timeout = None + self._socket_family = socket_family + self._socket_keepalive = socket_keepalive + + def setHandle(self, h): + self.handle = h + + def isOpen(self): + return self.handle is not None + + def setTimeout(self, ms): + if ms is None: + self._timeout = None + else: + self._timeout = ms / 1000.0 + + if self.handle is not None: + self.handle.settimeout(self._timeout) + + def _do_open(self, family, socktype): + return socket.socket(family, socktype) + + @property + def _address(self): + return self._unix_socket if self._unix_socket else '%s:%d' % (self.host, self.port) + + def open(self): + if self.handle: + raise TTransportException(type=TTransportException.ALREADY_OPEN, message="already open") + try: + addrs = self._resolveAddr() + except socket.gaierror as gai: + msg = 'failed to resolve sockaddr for ' + str(self._address) + logger.exception(msg) + raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=gai) + for family, socktype, _, _, sockaddr in addrs: + handle = self._do_open(family, socktype) + + # TCP_KEEPALIVE + if self._socket_keepalive: + handle.setsockopt(socket.IPPROTO_TCP, socket.SO_KEEPALIVE, 1) + + handle.settimeout(self._timeout) + try: + handle.connect(sockaddr) + self.handle = handle + return + except socket.error: + handle.close() + logger.info('Could not connect to %s', sockaddr, exc_info=True) + msg = 'Could not connect to any of %s' % list(map(lambda a: a[4], + addrs)) + logger.error(msg) + raise TTransportException(type=TTransportException.NOT_OPEN, message=msg) + + def read(self, sz): + try: + buff = self.handle.recv(sz) + except socket.error as e: + if (e.args[0] == errno.ECONNRESET and + (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))): + # freebsd and Mach don't follow POSIX semantic of recv + # and fail with ECONNRESET if peer performed shutdown. + # See corresponding comment and code in TSocket::read() + # in lib/cpp/src/transport/TSocket.cpp. + self.close() + # Trigger the check to raise the END_OF_FILE exception below. + buff = '' + elif e.args[0] == errno.ETIMEDOUT: + raise TTransportException(type=TTransportException.TIMED_OUT, message="read timeout", inner=e) + else: + raise TTransportException(message="unexpected exception", inner=e) + if len(buff) == 0: + raise TTransportException(type=TTransportException.END_OF_FILE, + message='TSocket read 0 bytes') + return buff + + def write(self, buff): + if not self.handle: + raise TTransportException(type=TTransportException.NOT_OPEN, + message='Transport not open') + sent = 0 + have = len(buff) + while sent < have: + try: + plus = self.handle.send(buff) + if plus == 0: + raise TTransportException(type=TTransportException.END_OF_FILE, + message='TSocket sent 0 bytes') + sent += plus + buff = buff[plus:] + except socket.error as e: + raise TTransportException(message="unexpected exception", inner=e) + + def flush(self): + pass + + +class TServerSocket(TSocketBase, TServerTransportBase): + """Socket implementation of TServerTransport base.""" + + def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC): + self.host = host + self.port = port + self._unix_socket = unix_socket + self._socket_family = socket_family + self.handle = None + self._backlog = 128 + + def setBacklog(self, backlog=None): + if not self.handle: + self._backlog = backlog + else: + # We cann't update backlog when it is already listening, since the + # handle has been created. + logger.warn('You have to set backlog before listen.') + + def listen(self): + res0 = self._resolveAddr() + socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family + for res in res0: + if res[0] is socket_family or res is res0[-1]: + break + + # We need remove the old unix socket if the file exists and + # nobody is listening on it. + if self._unix_socket: + tmp = socket.socket(res[0], res[1]) + try: + tmp.connect(res[4]) + except socket.error as err: + eno, message = err.args + if eno == errno.ECONNREFUSED: + os.unlink(res[4]) + + self.handle = socket.socket(res[0], res[1]) + self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if hasattr(self.handle, 'settimeout'): + self.handle.settimeout(None) + self.handle.bind(res[4]) + self.handle.listen(self._backlog) + + def accept(self): + client, addr = self.handle.accept() + result = TSocket() + result.setHandle(client) + return result diff --git a/src/jaegertracing/thrift/lib/py/src/transport/TTransport.py b/src/jaegertracing/thrift/lib/py/src/transport/TTransport.py new file mode 100644 index 000000000..9dbe95df4 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/transport/TTransport.py @@ -0,0 +1,456 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from struct import pack, unpack +from thrift.Thrift import TException +from ..compat import BufferIO + + +class TTransportException(TException): + """Custom Transport Exception class""" + + UNKNOWN = 0 + NOT_OPEN = 1 + ALREADY_OPEN = 2 + TIMED_OUT = 3 + END_OF_FILE = 4 + NEGATIVE_SIZE = 5 + SIZE_LIMIT = 6 + INVALID_CLIENT_TYPE = 7 + + def __init__(self, type=UNKNOWN, message=None, inner=None): + TException.__init__(self, message) + self.type = type + self.inner = inner + + +class TTransportBase(object): + """Base class for Thrift transport layer.""" + + def isOpen(self): + pass + + def open(self): + pass + + def close(self): + pass + + def read(self, sz): + pass + + def readAll(self, sz): + buff = b'' + have = 0 + while (have < sz): + chunk = self.read(sz - have) + chunkLen = len(chunk) + have += chunkLen + buff += chunk + + if chunkLen == 0: + raise EOFError() + + return buff + + def write(self, buf): + pass + + def flush(self): + pass + + +# This class should be thought of as an interface. +class CReadableTransport(object): + """base class for transports that are readable from C""" + + # TODO(dreiss): Think about changing this interface to allow us to use + # a (Python, not c) StringIO instead, because it allows + # you to write after reading. + + # NOTE: This is a classic class, so properties will NOT work + # correctly for setting. + @property + def cstringio_buf(self): + """A cStringIO buffer that contains the current chunk we are reading.""" + pass + + def cstringio_refill(self, partialread, reqlen): + """Refills cstringio_buf. + + Returns the currently used buffer (which can but need not be the same as + the old cstringio_buf). partialread is what the C code has read from the + buffer, and should be inserted into the buffer before any more reads. The + return value must be a new, not borrowed reference. Something along the + lines of self._buf should be fine. + + If reqlen bytes can't be read, throw EOFError. + """ + pass + + +class TServerTransportBase(object): + """Base class for Thrift server transports.""" + + def listen(self): + pass + + def accept(self): + pass + + def close(self): + pass + + +class TTransportFactoryBase(object): + """Base class for a Transport Factory""" + + def getTransport(self, trans): + return trans + + +class TBufferedTransportFactory(object): + """Factory transport that builds buffered transports""" + + def getTransport(self, trans): + buffered = TBufferedTransport(trans) + return buffered + + +class TBufferedTransport(TTransportBase, CReadableTransport): + """Class that wraps another transport and buffers its I/O. + + The implementation uses a (configurable) fixed-size read buffer + but buffers all writes until a flush is performed. + """ + DEFAULT_BUFFER = 4096 + + def __init__(self, trans, rbuf_size=DEFAULT_BUFFER): + self.__trans = trans + self.__wbuf = BufferIO() + # Pass string argument to initialize read buffer as cStringIO.InputType + self.__rbuf = BufferIO(b'') + self.__rbuf_size = rbuf_size + + def isOpen(self): + return self.__trans.isOpen() + + def open(self): + return self.__trans.open() + + def close(self): + return self.__trans.close() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) != 0: + return ret + self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size))) + return self.__rbuf.read(sz) + + def write(self, buf): + try: + self.__wbuf.write(buf) + except Exception as e: + # on exception reset wbuf so it doesn't contain a partial function call + self.__wbuf = BufferIO() + raise e + + def flush(self): + out = self.__wbuf.getvalue() + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf = BufferIO() + self.__trans.write(out) + self.__trans.flush() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self.__rbuf + + def cstringio_refill(self, partialread, reqlen): + retstring = partialread + if reqlen < self.__rbuf_size: + # try to make a read of as much as we can. + retstring += self.__trans.read(self.__rbuf_size) + + # but make sure we do read reqlen bytes. + if len(retstring) < reqlen: + retstring += self.__trans.readAll(reqlen - len(retstring)) + + self.__rbuf = BufferIO(retstring) + return self.__rbuf + + +class TMemoryBuffer(TTransportBase, CReadableTransport): + """Wraps a cBytesIO object as a TTransport. + + NOTE: Unlike the C++ version of this class, you cannot write to it + then immediately read from it. If you want to read from a + TMemoryBuffer, you must either pass a string to the constructor. + TODO(dreiss): Make this work like the C++ version. + """ + + def __init__(self, value=None, offset=0): + """value -- a value to read from for stringio + + If value is set, this will be a transport for reading, + otherwise, it is for writing""" + if value is not None: + self._buffer = BufferIO(value) + else: + self._buffer = BufferIO() + if offset: + self._buffer.seek(offset) + + def isOpen(self): + return not self._buffer.closed + + def open(self): + pass + + def close(self): + self._buffer.close() + + def read(self, sz): + return self._buffer.read(sz) + + def write(self, buf): + self._buffer.write(buf) + + def flush(self): + pass + + def getvalue(self): + return self._buffer.getvalue() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self._buffer + + def cstringio_refill(self, partialread, reqlen): + # only one shot at reading... + raise EOFError() + + +class TFramedTransportFactory(object): + """Factory transport that builds framed transports""" + + def getTransport(self, trans): + framed = TFramedTransport(trans) + return framed + + +class TFramedTransport(TTransportBase, CReadableTransport): + """Class that wraps another transport and frames its I/O when writing.""" + + def __init__(self, trans,): + self.__trans = trans + self.__rbuf = BufferIO(b'') + self.__wbuf = BufferIO() + + def isOpen(self): + return self.__trans.isOpen() + + def open(self): + return self.__trans.open() + + def close(self): + return self.__trans.close() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) != 0: + return ret + + self.readFrame() + return self.__rbuf.read(sz) + + def readFrame(self): + buff = self.__trans.readAll(4) + sz, = unpack('!i', buff) + self.__rbuf = BufferIO(self.__trans.readAll(sz)) + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + wout = self.__wbuf.getvalue() + wsz = len(wout) + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf = BufferIO() + # N.B.: Doing this string concatenation is WAY cheaper than making + # two separate calls to the underlying socket object. Socket writes in + # Python turn out to be REALLY expensive, but it seems to do a pretty + # good job of managing string buffer operations without excessive copies + buf = pack("!i", wsz) + wout + self.__trans.write(buf) + self.__trans.flush() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self.__rbuf + + def cstringio_refill(self, prefix, reqlen): + # self.__rbuf will already be empty here because fastbinary doesn't + # ask for a refill until the previous buffer is empty. Therefore, + # we can start reading new frames immediately. + while len(prefix) < reqlen: + self.readFrame() + prefix += self.__rbuf.getvalue() + self.__rbuf = BufferIO(prefix) + return self.__rbuf + + +class TFileObjectTransport(TTransportBase): + """Wraps a file-like object to make it work as a Thrift transport.""" + + def __init__(self, fileobj): + self.fileobj = fileobj + + def isOpen(self): + return True + + def close(self): + self.fileobj.close() + + def read(self, sz): + return self.fileobj.read(sz) + + def write(self, buf): + self.fileobj.write(buf) + + def flush(self): + self.fileobj.flush() + + +class TSaslClientTransport(TTransportBase, CReadableTransport): + """ + SASL transport + """ + + START = 1 + OK = 2 + BAD = 3 + ERROR = 4 + COMPLETE = 5 + + def __init__(self, transport, host, service, mechanism='GSSAPI', + **sasl_kwargs): + """ + transport: an underlying transport to use, typically just a TSocket + host: the name of the server, from a SASL perspective + service: the name of the server's service, from a SASL perspective + mechanism: the name of the preferred mechanism to use + + All other kwargs will be passed to the puresasl.client.SASLClient + constructor. + """ + + from puresasl.client import SASLClient + + self.transport = transport + self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs) + + self.__wbuf = BufferIO() + self.__rbuf = BufferIO(b'') + + def open(self): + if not self.transport.isOpen(): + self.transport.open() + + self.send_sasl_msg(self.START, bytes(self.sasl.mechanism, 'ascii')) + self.send_sasl_msg(self.OK, self.sasl.process()) + + while True: + status, challenge = self.recv_sasl_msg() + if status == self.OK: + self.send_sasl_msg(self.OK, self.sasl.process(challenge)) + elif status == self.COMPLETE: + if not self.sasl.complete: + raise TTransportException( + TTransportException.NOT_OPEN, + "The server erroneously indicated " + "that SASL negotiation was complete") + else: + break + else: + raise TTransportException( + TTransportException.NOT_OPEN, + "Bad SASL negotiation status: %d (%s)" + % (status, challenge)) + + def send_sasl_msg(self, status, body): + header = pack(">BI", status, len(body)) + self.transport.write(header + body) + self.transport.flush() + + def recv_sasl_msg(self): + header = self.transport.readAll(5) + status, length = unpack(">BI", header) + if length > 0: + payload = self.transport.readAll(length) + else: + payload = "" + return status, payload + + def write(self, data): + self.__wbuf.write(data) + + def flush(self): + data = self.__wbuf.getvalue() + encoded = self.sasl.wrap(data) + self.transport.write(pack("!i", len(encoded)) + encoded) + self.transport.flush() + self.__wbuf = BufferIO() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) != 0: + return ret + + self._read_frame() + return self.__rbuf.read(sz) + + def _read_frame(self): + header = self.transport.readAll(4) + length, = unpack('!i', header) + encoded = self.transport.readAll(length) + self.__rbuf = BufferIO(self.sasl.unwrap(encoded)) + + def close(self): + self.sasl.dispose() + self.transport.close() + + # based on TFramedTransport + @property + def cstringio_buf(self): + return self.__rbuf + + def cstringio_refill(self, prefix, reqlen): + # self.__rbuf will already be empty here because fastbinary doesn't + # ask for a refill until the previous buffer is empty. Therefore, + # we can start reading new frames immediately. + while len(prefix) < reqlen: + self._read_frame() + prefix += self.__rbuf.getvalue() + self.__rbuf = BufferIO(prefix) + return self.__rbuf diff --git a/src/jaegertracing/thrift/lib/py/src/transport/TTwisted.py b/src/jaegertracing/thrift/lib/py/src/transport/TTwisted.py new file mode 100644 index 000000000..a27f0adad --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/transport/TTwisted.py @@ -0,0 +1,329 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from io import BytesIO +import struct + +from zope.interface import implementer, Interface, Attribute +from twisted.internet.protocol import ServerFactory, ClientFactory, \ + connectionDone +from twisted.internet import defer +from twisted.internet.threads import deferToThread +from twisted.protocols import basic +from twisted.web import server, resource, http + +from thrift.transport import TTransport + + +class TMessageSenderTransport(TTransport.TTransportBase): + + def __init__(self): + self.__wbuf = BytesIO() + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + msg = self.__wbuf.getvalue() + self.__wbuf = BytesIO() + return self.sendMessage(msg) + + def sendMessage(self, message): + raise NotImplementedError + + +class TCallbackTransport(TMessageSenderTransport): + + def __init__(self, func): + TMessageSenderTransport.__init__(self) + self.func = func + + def sendMessage(self, message): + return self.func(message) + + +class ThriftClientProtocol(basic.Int32StringReceiver): + + MAX_LENGTH = 2 ** 31 - 1 + + def __init__(self, client_class, iprot_factory, oprot_factory=None): + self._client_class = client_class + self._iprot_factory = iprot_factory + if oprot_factory is None: + self._oprot_factory = iprot_factory + else: + self._oprot_factory = oprot_factory + + self.recv_map = {} + self.started = defer.Deferred() + + def dispatch(self, msg): + self.sendString(msg) + + def connectionMade(self): + tmo = TCallbackTransport(self.dispatch) + self.client = self._client_class(tmo, self._oprot_factory) + self.started.callback(self.client) + + def connectionLost(self, reason=connectionDone): + # the called errbacks can add items to our client's _reqs, + # so we need to use a tmp, and iterate until no more requests + # are added during errbacks + if self.client: + tex = TTransport.TTransportException( + type=TTransport.TTransportException.END_OF_FILE, + message='Connection closed (%s)' % reason) + while self.client._reqs: + _, v = self.client._reqs.popitem() + v.errback(tex) + del self.client._reqs + self.client = None + + def stringReceived(self, frame): + tr = TTransport.TMemoryBuffer(frame) + iprot = self._iprot_factory.getProtocol(tr) + (fname, mtype, rseqid) = iprot.readMessageBegin() + + try: + method = self.recv_map[fname] + except KeyError: + method = getattr(self.client, 'recv_' + fname) + self.recv_map[fname] = method + + method(iprot, mtype, rseqid) + + +class ThriftSASLClientProtocol(ThriftClientProtocol): + + START = 1 + OK = 2 + BAD = 3 + ERROR = 4 + COMPLETE = 5 + + MAX_LENGTH = 2 ** 31 - 1 + + def __init__(self, client_class, iprot_factory, oprot_factory=None, + host=None, service=None, mechanism='GSSAPI', **sasl_kwargs): + """ + host: the name of the server, from a SASL perspective + service: the name of the server's service, from a SASL perspective + mechanism: the name of the preferred mechanism to use + + All other kwargs will be passed to the puresasl.client.SASLClient + constructor. + """ + + from puresasl.client import SASLClient + self.SASLCLient = SASLClient + + ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory) + + self._sasl_negotiation_deferred = None + self._sasl_negotiation_status = None + self.client = None + + if host is not None: + self.createSASLClient(host, service, mechanism, **sasl_kwargs) + + def createSASLClient(self, host, service, mechanism, **kwargs): + self.sasl = self.SASLClient(host, service, mechanism, **kwargs) + + def dispatch(self, msg): + encoded = self.sasl.wrap(msg) + len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded)) + ThriftClientProtocol.dispatch(self, len_and_encoded) + + @defer.inlineCallbacks + def connectionMade(self): + self._sendSASLMessage(self.START, self.sasl.mechanism) + initial_message = yield deferToThread(self.sasl.process) + self._sendSASLMessage(self.OK, initial_message) + + while True: + status, challenge = yield self._receiveSASLMessage() + if status == self.OK: + response = yield deferToThread(self.sasl.process, challenge) + self._sendSASLMessage(self.OK, response) + elif status == self.COMPLETE: + if not self.sasl.complete: + msg = "The server erroneously indicated that SASL " \ + "negotiation was complete" + raise TTransport.TTransportException(msg, message=msg) + else: + break + else: + msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge) + raise TTransport.TTransportException(msg, message=msg) + + self._sasl_negotiation_deferred = None + ThriftClientProtocol.connectionMade(self) + + def _sendSASLMessage(self, status, body): + if body is None: + body = "" + header = struct.pack(">BI", status, len(body)) + self.transport.write(header + body) + + def _receiveSASLMessage(self): + self._sasl_negotiation_deferred = defer.Deferred() + self._sasl_negotiation_status = None + return self._sasl_negotiation_deferred + + def connectionLost(self, reason=connectionDone): + if self.client: + ThriftClientProtocol.connectionLost(self, reason) + + def dataReceived(self, data): + if self._sasl_negotiation_deferred: + # we got a sasl challenge in the format (status, length, challenge) + # save the status, let IntNStringReceiver piece the challenge data together + self._sasl_negotiation_status, = struct.unpack("B", data[0]) + ThriftClientProtocol.dataReceived(self, data[1:]) + else: + # normal frame, let IntNStringReceiver piece it together + ThriftClientProtocol.dataReceived(self, data) + + def stringReceived(self, frame): + if self._sasl_negotiation_deferred: + # the frame is just a SASL challenge + response = (self._sasl_negotiation_status, frame) + self._sasl_negotiation_deferred.callback(response) + else: + # there's a second 4 byte length prefix inside the frame + decoded_frame = self.sasl.unwrap(frame[4:]) + ThriftClientProtocol.stringReceived(self, decoded_frame) + + +class ThriftServerProtocol(basic.Int32StringReceiver): + + MAX_LENGTH = 2 ** 31 - 1 + + def dispatch(self, msg): + self.sendString(msg) + + def processError(self, error): + self.transport.loseConnection() + + def processOk(self, _, tmo): + msg = tmo.getvalue() + + if len(msg) > 0: + self.dispatch(msg) + + def stringReceived(self, frame): + tmi = TTransport.TMemoryBuffer(frame) + tmo = TTransport.TMemoryBuffer() + + iprot = self.factory.iprot_factory.getProtocol(tmi) + oprot = self.factory.oprot_factory.getProtocol(tmo) + + d = self.factory.processor.process(iprot, oprot) + d.addCallbacks(self.processOk, self.processError, + callbackArgs=(tmo,)) + + +class IThriftServerFactory(Interface): + + processor = Attribute("Thrift processor") + + iprot_factory = Attribute("Input protocol factory") + + oprot_factory = Attribute("Output protocol factory") + + +class IThriftClientFactory(Interface): + + client_class = Attribute("Thrift client class") + + iprot_factory = Attribute("Input protocol factory") + + oprot_factory = Attribute("Output protocol factory") + + +@implementer(IThriftServerFactory) +class ThriftServerFactory(ServerFactory): + + protocol = ThriftServerProtocol + + def __init__(self, processor, iprot_factory, oprot_factory=None): + self.processor = processor + self.iprot_factory = iprot_factory + if oprot_factory is None: + self.oprot_factory = iprot_factory + else: + self.oprot_factory = oprot_factory + + +@implementer(IThriftClientFactory) +class ThriftClientFactory(ClientFactory): + + protocol = ThriftClientProtocol + + def __init__(self, client_class, iprot_factory, oprot_factory=None): + self.client_class = client_class + self.iprot_factory = iprot_factory + if oprot_factory is None: + self.oprot_factory = iprot_factory + else: + self.oprot_factory = oprot_factory + + def buildProtocol(self, addr): + p = self.protocol(self.client_class, self.iprot_factory, + self.oprot_factory) + p.factory = self + return p + + +class ThriftResource(resource.Resource): + + allowedMethods = ('POST',) + + def __init__(self, processor, inputProtocolFactory, + outputProtocolFactory=None): + resource.Resource.__init__(self) + self.inputProtocolFactory = inputProtocolFactory + if outputProtocolFactory is None: + self.outputProtocolFactory = inputProtocolFactory + else: + self.outputProtocolFactory = outputProtocolFactory + self.processor = processor + + def getChild(self, path, request): + return self + + def _cbProcess(self, _, request, tmo): + msg = tmo.getvalue() + request.setResponseCode(http.OK) + request.setHeader("content-type", "application/x-thrift") + request.write(msg) + request.finish() + + def render_POST(self, request): + request.content.seek(0, 0) + data = request.content.read() + tmi = TTransport.TMemoryBuffer(data) + tmo = TTransport.TMemoryBuffer() + + iprot = self.inputProtocolFactory.getProtocol(tmi) + oprot = self.outputProtocolFactory.getProtocol(tmo) + + d = self.processor.process(iprot, oprot) + d.addCallback(self._cbProcess, request, tmo) + return server.NOT_DONE_YET diff --git a/src/jaegertracing/thrift/lib/py/src/transport/TZlibTransport.py b/src/jaegertracing/thrift/lib/py/src/transport/TZlibTransport.py new file mode 100644 index 000000000..e84857924 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/transport/TZlibTransport.py @@ -0,0 +1,248 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +"""TZlibTransport provides a compressed transport and transport factory +class, using the python standard library zlib module to implement +data compression. +""" + +from __future__ import division +import zlib +from .TTransport import TTransportBase, CReadableTransport +from ..compat import BufferIO + + +class TZlibTransportFactory(object): + """Factory transport that builds zlib compressed transports. + + This factory caches the last single client/transport that it was passed + and returns the same TZlibTransport object that was created. + + This caching means the TServer class will get the _same_ transport + object for both input and output transports from this factory. + (For non-threaded scenarios only, since the cache only holds one object) + + The purpose of this caching is to allocate only one TZlibTransport where + only one is really needed (since it must have separate read/write buffers), + and makes the statistics from getCompSavings() and getCompRatio() + easier to understand. + """ + # class scoped cache of last transport given and zlibtransport returned + _last_trans = None + _last_z = None + + def getTransport(self, trans, compresslevel=9): + """Wrap a transport, trans, with the TZlibTransport + compressed transport class, returning a new + transport to the caller. + + @param compresslevel: The zlib compression level, ranging + from 0 (no compression) to 9 (best compression). Defaults to 9. + @type compresslevel: int + + This method returns a TZlibTransport which wraps the + passed C{trans} TTransport derived instance. + """ + if trans == self._last_trans: + return self._last_z + ztrans = TZlibTransport(trans, compresslevel) + self._last_trans = trans + self._last_z = ztrans + return ztrans + + +class TZlibTransport(TTransportBase, CReadableTransport): + """Class that wraps a transport with zlib, compressing writes + and decompresses reads, using the python standard + library zlib module. + """ + # Read buffer size for the python fastbinary C extension, + # the TBinaryProtocolAccelerated class. + DEFAULT_BUFFSIZE = 4096 + + def __init__(self, trans, compresslevel=9): + """Create a new TZlibTransport, wrapping C{trans}, another + TTransport derived object. + + @param trans: A thrift transport object, i.e. a TSocket() object. + @type trans: TTransport + @param compresslevel: The zlib compression level, ranging + from 0 (no compression) to 9 (best compression). Default is 9. + @type compresslevel: int + """ + self.__trans = trans + self.compresslevel = compresslevel + self.__rbuf = BufferIO() + self.__wbuf = BufferIO() + self._init_zlib() + self._init_stats() + + def _reinit_buffers(self): + """Internal method to initialize/reset the internal StringIO objects + for read and write buffers. + """ + self.__rbuf = BufferIO() + self.__wbuf = BufferIO() + + def _init_stats(self): + """Internal method to reset the internal statistics counters + for compression ratios and bandwidth savings. + """ + self.bytes_in = 0 + self.bytes_out = 0 + self.bytes_in_comp = 0 + self.bytes_out_comp = 0 + + def _init_zlib(self): + """Internal method for setting up the zlib compression and + decompression objects. + """ + self._zcomp_read = zlib.decompressobj() + self._zcomp_write = zlib.compressobj(self.compresslevel) + + def getCompRatio(self): + """Get the current measured compression ratios (in,out) from + this transport. + + Returns a tuple of: + (inbound_compression_ratio, outbound_compression_ratio) + + The compression ratios are computed as: + compressed / uncompressed + + E.g., data that compresses by 10x will have a ratio of: 0.10 + and data that compresses to half of ts original size will + have a ratio of 0.5 + + None is returned if no bytes have yet been processed in + a particular direction. + """ + r_percent, w_percent = (None, None) + if self.bytes_in > 0: + r_percent = self.bytes_in_comp / self.bytes_in + if self.bytes_out > 0: + w_percent = self.bytes_out_comp / self.bytes_out + return (r_percent, w_percent) + + def getCompSavings(self): + """Get the current count of saved bytes due to data + compression. + + Returns a tuple of: + (inbound_saved_bytes, outbound_saved_bytes) + + Note: if compression is actually expanding your + data (only likely with very tiny thrift objects), then + the values returned will be negative. + """ + r_saved = self.bytes_in - self.bytes_in_comp + w_saved = self.bytes_out - self.bytes_out_comp + return (r_saved, w_saved) + + def isOpen(self): + """Return the underlying transport's open status""" + return self.__trans.isOpen() + + def open(self): + """Open the underlying transport""" + self._init_stats() + return self.__trans.open() + + def listen(self): + """Invoke the underlying transport's listen() method""" + self.__trans.listen() + + def accept(self): + """Accept connections on the underlying transport""" + return self.__trans.accept() + + def close(self): + """Close the underlying transport,""" + self._reinit_buffers() + self._init_zlib() + return self.__trans.close() + + def read(self, sz): + """Read up to sz bytes from the decompressed bytes buffer, and + read from the underlying transport if the decompression + buffer is empty. + """ + ret = self.__rbuf.read(sz) + if len(ret) > 0: + return ret + # keep reading from transport until something comes back + while True: + if self.readComp(sz): + break + ret = self.__rbuf.read(sz) + return ret + + def readComp(self, sz): + """Read compressed data from the underlying transport, then + decompress it and append it to the internal StringIO read buffer + """ + zbuf = self.__trans.read(sz) + zbuf = self._zcomp_read.unconsumed_tail + zbuf + buf = self._zcomp_read.decompress(zbuf) + self.bytes_in += len(zbuf) + self.bytes_in_comp += len(buf) + old = self.__rbuf.read() + self.__rbuf = BufferIO(old + buf) + if len(old) + len(buf) == 0: + return False + return True + + def write(self, buf): + """Write some bytes, putting them into the internal write + buffer for eventual compression. + """ + self.__wbuf.write(buf) + + def flush(self): + """Flush any queued up data in the write buffer and ensure the + compression buffer is flushed out to the underlying transport + """ + wout = self.__wbuf.getvalue() + if len(wout) > 0: + zbuf = self._zcomp_write.compress(wout) + self.bytes_out += len(wout) + self.bytes_out_comp += len(zbuf) + else: + zbuf = '' + ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH) + self.bytes_out_comp += len(ztail) + if (len(zbuf) + len(ztail)) > 0: + self.__wbuf = BufferIO() + self.__trans.write(zbuf + ztail) + self.__trans.flush() + + @property + def cstringio_buf(self): + """Implement the CReadableTransport interface""" + return self.__rbuf + + def cstringio_refill(self, partialread, reqlen): + """Implement the CReadableTransport interface for refill""" + retstring = partialread + if reqlen < self.DEFAULT_BUFFSIZE: + retstring += self.read(self.DEFAULT_BUFFSIZE) + while len(retstring) < reqlen: + retstring += self.read(reqlen - len(retstring)) + self.__rbuf = BufferIO(retstring) + return self.__rbuf diff --git a/src/jaegertracing/thrift/lib/py/src/transport/__init__.py b/src/jaegertracing/thrift/lib/py/src/transport/__init__.py new file mode 100644 index 000000000..c9596d9a6 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/transport/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['TTransport', 'TSocket', 'THttpClient', 'TZlibTransport'] diff --git a/src/jaegertracing/thrift/lib/py/src/transport/sslcompat.py b/src/jaegertracing/thrift/lib/py/src/transport/sslcompat.py new file mode 100644 index 000000000..ab00cb2a8 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/src/transport/sslcompat.py @@ -0,0 +1,100 @@ +# +# licensed to the apache software foundation (asf) under one +# or more contributor license agreements. see the notice file +# distributed with this work for additional information +# regarding copyright ownership. the asf licenses this file +# to you under the apache license, version 2.0 (the +# "license"); you may not use this file except in compliance +# with the license. you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, +# software distributed under the license is distributed on an +# "as is" basis, without warranties or conditions of any +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import logging +import sys + +from thrift.transport.TTransport import TTransportException + +logger = logging.getLogger(__name__) + + +def legacy_validate_callback(cert, hostname): + """legacy method to validate the peer's SSL certificate, and to check + the commonName of the certificate to ensure it matches the hostname we + used to make this connection. Does not support subjectAltName records + in certificates. + + raises TTransportException if the certificate fails validation. + """ + if 'subject' not in cert: + raise TTransportException( + TTransportException.NOT_OPEN, + 'No SSL certificate found from %s' % hostname) + fields = cert['subject'] + for field in fields: + # ensure structure we get back is what we expect + if not isinstance(field, tuple): + continue + cert_pair = field[0] + if len(cert_pair) < 2: + continue + cert_key, cert_value = cert_pair[0:2] + if cert_key != 'commonName': + continue + certhost = cert_value + # this check should be performed by some sort of Access Manager + if certhost == hostname: + # success, cert commonName matches desired hostname + return + else: + raise TTransportException( + TTransportException.UNKNOWN, + 'Hostname we connected to "%s" doesn\'t match certificate ' + 'provided commonName "%s"' % (hostname, certhost)) + raise TTransportException( + TTransportException.UNKNOWN, + 'Could not validate SSL certificate from host "%s". Cert=%s' + % (hostname, cert)) + + +def _optional_dependencies(): + try: + import ipaddress # noqa + logger.debug('ipaddress module is available') + ipaddr = True + except ImportError: + logger.warn('ipaddress module is unavailable') + ipaddr = False + + if sys.hexversion < 0x030500F0: + try: + from backports.ssl_match_hostname import match_hostname, __version__ as ver + ver = list(map(int, ver.split('.'))) + logger.debug('backports.ssl_match_hostname module is available') + match = match_hostname + if ver[0] * 10 + ver[1] >= 35: + return ipaddr, match + else: + logger.warn('backports.ssl_match_hostname module is too old') + ipaddr = False + except ImportError: + logger.warn('backports.ssl_match_hostname is unavailable') + ipaddr = False + try: + from ssl import match_hostname + logger.debug('ssl.match_hostname is available') + match = match_hostname + except ImportError: + logger.warn('using legacy validation callback') + match = legacy_validate_callback + return ipaddr, match + + +_match_has_ipaddress, _match_hostname = _optional_dependencies() diff --git a/src/jaegertracing/thrift/lib/py/test/_import_local_thrift.py b/src/jaegertracing/thrift/lib/py/test/_import_local_thrift.py new file mode 100644 index 000000000..d22312298 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/test/_import_local_thrift.py @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import glob +import os +import sys + +SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR))) + +for libpath in glob.glob(os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*')): + if libpath.endswith('-%d.%d' % (sys.version_info[0], sys.version_info[1])): + sys.path.insert(0, libpath) + break diff --git a/src/jaegertracing/thrift/lib/py/test/test_sslsocket.py b/src/jaegertracing/thrift/lib/py/test/test_sslsocket.py new file mode 100644 index 000000000..f4c87f195 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/test/test_sslsocket.py @@ -0,0 +1,353 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import inspect +import logging +import os +import platform +import ssl +import sys +import tempfile +import threading +import unittest +import warnings +from contextlib import contextmanager + +import _import_local_thrift # noqa + +SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR))) +SERVER_PEM = os.path.join(ROOT_DIR, 'test', 'keys', 'server.pem') +SERVER_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'server.crt') +SERVER_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'server.key') +CLIENT_CERT_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.crt') +CLIENT_KEY_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.key') +CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.crt') +CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.key') +CLIENT_CA = os.path.join(ROOT_DIR, 'test', 'keys', 'CA.pem') + +TEST_CIPHERS = 'DES-CBC3-SHA:ECDHE-RSA-AES128-GCM-SHA256' + + +class ServerAcceptor(threading.Thread): + def __init__(self, server, expect_failure=False): + super(ServerAcceptor, self).__init__() + self.daemon = True + self._server = server + self._listening = threading.Event() + self._port = None + self._port_bound = threading.Event() + self._client = None + self._client_accepted = threading.Event() + self._expect_failure = expect_failure + frame = inspect.stack(3)[2] + self.name = frame[3] + del frame + + def run(self): + self._server.listen() + self._listening.set() + + try: + address = self._server.handle.getsockname() + if len(address) > 1: + # AF_INET addresses are 2-tuples (host, port) and AF_INET6 are + # 4-tuples (host, port, ...), but in each case port is in the second slot. + self._port = address[1] + finally: + self._port_bound.set() + + try: + self._client = self._server.accept() + if self._client: + self._client.read(5) # hello + self._client.write(b"there") + except Exception: + logging.exception('error on server side (%s):' % self.name) + if not self._expect_failure: + raise + finally: + self._client_accepted.set() + + def await_listening(self): + self._listening.wait() + + @property + def port(self): + self._port_bound.wait() + return self._port + + @property + def client(self): + self._client_accepted.wait() + return self._client + + def close(self): + if self._client: + self._client.close() + self._server.close() + + +# Python 2.6 compat +class AssertRaises(object): + def __init__(self, expected): + self._expected = expected + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_value, traceback): + if not exc_type or not issubclass(exc_type, self._expected): + raise Exception('fail') + return True + + +class TSSLSocketTest(unittest.TestCase): + def _server_socket(self, **kwargs): + return TSSLServerSocket(port=0, **kwargs) + + @contextmanager + def _connectable_client(self, server, expect_failure=False, path=None, **client_kwargs): + acc = ServerAcceptor(server, expect_failure) + try: + acc.start() + acc.await_listening() + + host, port = ('localhost', acc.port) if path is None else (None, None) + client = TSSLSocket(host, port, unix_socket=path, **client_kwargs) + yield acc, client + finally: + acc.close() + + def _assert_connection_failure(self, server, path=None, **client_args): + logging.disable(logging.CRITICAL) + try: + with self._connectable_client(server, True, path=path, **client_args) as (acc, client): + # We need to wait for a connection failure, but not too long. 20ms is a tunable + # compromise between test speed and stability + client.setTimeout(20) + with self._assert_raises(TTransportException): + client.open() + client.write(b"hello") + client.read(5) # b"there" + finally: + logging.disable(logging.NOTSET) + + def _assert_raises(self, exc): + if sys.hexversion >= 0x020700F0: + return self.assertRaises(exc) + else: + return AssertRaises(exc) + + def _assert_connection_success(self, server, path=None, **client_args): + with self._connectable_client(server, path=path, **client_args) as (acc, client): + try: + client.open() + client.write(b"hello") + self.assertEqual(client.read(5), b"there") + self.assertTrue(acc.client is not None) + finally: + client.close() + + # deprecated feature + def test_deprecation(self): + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__) + TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT) + self.assertEqual(len(w), 1) + + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__) + # Deprecated signature + # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None): + TSSLSocket('localhost', 0, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS) + self.assertEqual(len(w), 7) + + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__) + # Deprecated signature + # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None): + TSSLServerSocket(None, 0, SERVER_PEM, None, TEST_CIPHERS) + self.assertEqual(len(w), 3) + + # deprecated feature + def test_set_cert_reqs_by_validate(self): + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__) + c1 = TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT) + self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED) + + c1 = TSSLSocket('localhost', 0, validate=False) + self.assertEqual(c1.cert_reqs, ssl.CERT_NONE) + + self.assertEqual(len(w), 2) + + # deprecated feature + def test_set_validate_by_cert_reqs(self): + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__) + c1 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_NONE) + self.assertFalse(c1.validate) + + c2 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT) + self.assertTrue(c2.validate) + + c3 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT) + self.assertTrue(c3.validate) + + self.assertEqual(len(w), 3) + + def test_unix_domain_socket(self): + if platform.system() == 'Windows': + print('skipping test_unix_domain_socket') + return + fd, path = tempfile.mkstemp() + os.close(fd) + os.unlink(path) + try: + server = self._server_socket(unix_socket=path, keyfile=SERVER_KEY, certfile=SERVER_CERT) + self._assert_connection_success(server, path=path, cert_reqs=ssl.CERT_NONE) + finally: + os.unlink(path) + + def test_server_cert(self): + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT) + self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT) + + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT) + # server cert not in ca_certs + self._assert_connection_failure(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT) + + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT) + self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE) + + def test_set_server_cert(self): + server = self._server_socket(keyfile=SERVER_KEY, certfile=CLIENT_CERT) + with self._assert_raises(Exception): + server.certfile = 'foo' + with self._assert_raises(Exception): + server.certfile = None + server.certfile = SERVER_CERT + self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT) + + def test_client_cert(self): + if not _match_has_ipaddress: + print('skipping test_client_cert') + return + server = self._server_socket( + cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY, + certfile=SERVER_CERT, ca_certs=CLIENT_CERT) + self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY) + + server = self._server_socket( + cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY, + certfile=SERVER_CERT, ca_certs=CLIENT_CA) + self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP) + + server = self._server_socket( + cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY, + certfile=SERVER_CERT, ca_certs=CLIENT_CA) + self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY) + + server = self._server_socket( + cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY, + certfile=SERVER_CERT, ca_certs=CLIENT_CA) + self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY) + + def test_ciphers(self): + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS) + self._assert_connection_success(server, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS) + + if not TSSLSocket._has_ciphers: + # unittest.skip is not available for Python 2.6 + print('skipping test_ciphers') + return + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT) + self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL') + + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS) + self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL') + + def test_ssl2_and_ssl3_disabled(self): + if not hasattr(ssl, 'PROTOCOL_SSLv3'): + print('PROTOCOL_SSLv3 is not available') + else: + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT) + self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3) + + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3) + self._assert_connection_failure(server, ca_certs=SERVER_CERT) + + if not hasattr(ssl, 'PROTOCOL_SSLv2'): + print('PROTOCOL_SSLv2 is not available') + else: + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT) + self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2) + + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2) + self._assert_connection_failure(server, ca_certs=SERVER_CERT) + + def test_newer_tls(self): + if not TSSLSocket._has_ssl_context: + # unittest.skip is not available for Python 2.6 + print('skipping test_newer_tls') + return + if not hasattr(ssl, 'PROTOCOL_TLSv1_2'): + print('PROTOCOL_TLSv1_2 is not available') + else: + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2) + self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2) + + if not hasattr(ssl, 'PROTOCOL_TLSv1_1'): + print('PROTOCOL_TLSv1_1 is not available') + else: + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1) + self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1) + + if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'): + print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available') + else: + server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2) + self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1) + + def test_ssl_context(self): + if not TSSLSocket._has_ssl_context: + # unittest.skip is not available for Python 2.6 + print('skipping test_ssl_context') + return + server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + server_context.load_cert_chain(SERVER_CERT, SERVER_KEY) + server_context.load_verify_locations(CLIENT_CA) + server_context.verify_mode = ssl.CERT_REQUIRED + server = self._server_socket(ssl_context=server_context) + + client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY) + client_context.load_verify_locations(SERVER_CERT) + client_context.verify_mode = ssl.CERT_REQUIRED + + self._assert_connection_success(server, ssl_context=client_context) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.WARN) + from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket, _match_has_ipaddress + from thrift.transport.TTransport import TTransportException + + unittest.main() diff --git a/src/jaegertracing/thrift/lib/py/test/thrift_json.py b/src/jaegertracing/thrift/lib/py/test/thrift_json.py new file mode 100644 index 000000000..40e7a47e3 --- /dev/null +++ b/src/jaegertracing/thrift/lib/py/test/thrift_json.py @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys +import unittest + +import _import_local_thrift # noqa +from thrift.protocol.TJSONProtocol import TJSONProtocol +from thrift.transport import TTransport + +# +# In order to run the test under Windows. We need to create symbolic link +# name 'thrift' to '../src' folder by using: +# +# mklink /D thrift ..\src +# + + +class TestJSONString(unittest.TestCase): + + def test_escaped_unicode_string(self): + unicode_json = b'"hello \\u0e01\\u0e02\\u0e03\\ud835\\udcab\\udb40\\udc70 unicode"' + unicode_text = u'hello \u0e01\u0e02\u0e03\U0001D4AB\U000E0070 unicode' + + buf = TTransport.TMemoryBuffer(unicode_json) + transport = TTransport.TBufferedTransportFactory().getTransport(buf) + protocol = TJSONProtocol(transport) + + if sys.version_info[0] == 2: + unicode_text = unicode_text.encode('utf8') + self.assertEqual(protocol.readString(), unicode_text) + + +if __name__ == '__main__': + unittest.main() -- cgit v1.2.3