diff options
Diffstat (limited to 'src/jaegertracing/thrift/lib/rs')
30 files changed, 9254 insertions, 0 deletions
diff --git a/src/jaegertracing/thrift/lib/rs/Cargo.toml b/src/jaegertracing/thrift/lib/rs/Cargo.toml new file mode 100644 index 000000000..69da0f399 --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "thrift" +description = "Rust bindings for the Apache Thrift RPC system" +version = "0.13.0" +license = "Apache-2.0" +authors = ["Apache Thrift Developers <dev@thrift.apache.org>"] +homepage = "http://thrift.apache.org" +documentation = "https://thrift.apache.org" +readme = "README.md" +exclude = ["Makefile*", "test/**", "*.iml"] +keywords = ["thrift"] + +[dependencies] +ordered-float = "1.0" +byteorder = "1.3" +integer-encoding = "1.0" +log = "0.4" +threadpool = "1.7" diff --git a/src/jaegertracing/thrift/lib/rs/Makefile.am b/src/jaegertracing/thrift/lib/rs/Makefile.am new file mode 100644 index 000000000..0a34120a3 --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/Makefile.am @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +SUBDIRS = . + +if WITH_TESTS +SUBDIRS += test +endif + +install: + @echo '##############################################################' + @echo '##############################################################' + @echo 'The Rust client library should be installed via a Cargo.toml dependency - please see /lib/rs/README.md' + @echo '##############################################################' + @echo '##############################################################' + +check-local: + $(CARGO) test + +all-local: + $(CARGO) build + +clean-local: + $(CARGO) clean + -$(RM) Cargo.lock + +EXTRA_DIST = \ + src \ + Cargo.toml \ + README.md diff --git a/src/jaegertracing/thrift/lib/rs/README.md b/src/jaegertracing/thrift/lib/rs/README.md new file mode 100644 index 000000000..f518f4eb6 --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/README.md @@ -0,0 +1,120 @@ +# Rust Thrift library + +## Overview + +This crate implements the components required to build a working Thrift server +and client. It is divided into the following modules: + + 1. errors + 2. protocol + 3. transport + 4. server + 5. autogen + +The modules are layered as shown. The `generated` layer is code generated by the +Thrift compiler's Rust plugin. It uses the components defined in this crate to +serialize and deserialize types and implement RPC. Users interact with these +types and services by writing their own code on top. + + ```text + +-----------+ + | app dev | + +-----------+ + | generated | <-> errors/results + +-----------+ + | protocol | + +-----------+ + | transport | + +-----------+ + ``` + +## Using this crate + +Add `thrift = "x.y.z"` to your `Cargo.toml`, where `x.y.z` is the version of the +Thrift compiler you're using. + +## API Documentation + +Full [Rustdoc](https://docs.rs/thrift/) + +## Compatibility + +The Rust library and auto-generated code targets Rust versions 1.28+. +It does not currently use any Rust 2018 features. + +### Breaking Changes + +Breaking changes are minimized. When they are made they will be outlined below with transition guidelines. + +##### Thrift 0.13.0 + +* **[THRIFT-4536]** - Use TryFrom from std, required rust 1.34.0 or higher + + Previously TryFrom was from try_from crate, it is now from the std library, + but this functionality is only available in rust 1.34.0. Additionally, + ordered-float is now re-exported under the thrift module to reduce + possible dependency mismatches. + +##### Thrift 0.12.0 + +* **[THRIFT-4529]** - Rust enum variants are now camel-cased instead of uppercased to conform to Rust naming conventions + + Previously, enum variants were uppercased in the auto-generated code. + For example, the following thrift enum: + + ```thrift + // THRIFT + enum Operation { + ADD, + SUBTRACT, + MULTIPLY, + DIVIDE, + } + ``` + + used to generate: + + ```rust + // OLD AUTO-GENERATED RUST + pub enum Operation { + ADD, + SUBTRACT, + MULTIPLY, + DIVIDE, + } + ``` + It *now* generates: + ```rust + // NEW AUTO-GENERATED RUST + pub enum Operation { + Add, + Subtract, + Multiply, + Divide, + } + ``` + + You will have to change all enum variants in your code to use camel-cased names. + This should be a search and replace. + +## Contributing + +Bug reports and PRs are always welcome! Please see the +[Thrift website](https://thrift.apache.org/) for more details. + +Thrift Rust support requires code in several directories: + +* `compiler/cpp/src/thrift/generate/t_rs_generator.cc`: binding code generator +* `lib/rs`: runtime library +* `lib/rs/test`: supplemental tests +* `tutorial/rs`: tutorial client and server +* `test/rs`: cross-language test client and server + +All library code, test code and auto-generated code compiles and passes clippy +without warnings. All new code must do the same! When making changes ensure that: + +* `rustc` does does output any warnings +* `clippy` with default settings does not output any warnings (includes auto-generated code) +* `cargo test` is successful +* `make precross` and `make check` are successful +* `tutorial/bin/tutorial_client` and `tutorial/bin/tutorial_server` communicate diff --git a/src/jaegertracing/thrift/lib/rs/RELEASING.md b/src/jaegertracing/thrift/lib/rs/RELEASING.md new file mode 100644 index 000000000..073d7a02a --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/RELEASING.md @@ -0,0 +1,57 @@ +# Publishing the thrift crate + +Publishing the Rust thrift crate is straightforward, and involves two major steps: + +1. Setting up your [crates.io](https://www.crates.io) account _(one-time)_ + +2. Packaging/publishing the Rust thrift crate itself + +## Set up your crates.io account (one-time) + +1. Go to [crates.io](https://www.crates.io) and click the `Log In` button at the top right. + + Log in **as the Github user with write permissions to the thrift repo!** + +2. Click your user icon button at the top right and select `Account Settings`. + +3. Click `New Token` next to `API Access`. + + This generates a new API key that cargo uses to publish packages to crates.io. + Store this API key somewhere safe. If you will only use this Github account to + publish crates to crates.io you can follow the instructions to save the + generated key to `~/.cargo/credentials`. + +## Package and Publish + +You can use the automated script or run the release steps manually. + +**Important**: `cargo` expects that version numbers follow the semantic versioning format. +This means that `THRIFT_RELEASE_VERSION` must have a major, minor and patch number, i.e., must +be in the form `#.##.##`. + +#### Automated + +Run `./release.sh [THRIFT_RELEASE_VERSION]`. + +_Requires you to have stored your credentials in `~/.cargo/credentials`._ + +#### Manual + +1. Edit `Cargo.toml` and update the `version = 1.0` key to `version = [THRIFT_RELEASE_VERSION]` + +2. `git add Cargo.toml` + +3. `git commit -m "Update thrift crate version to [THRIFT_RELEASE_VERSION]" -m "Client: rs"` + +4. `cargo login` + + _(not required if you have stored your credentials in `~/.cargo/credentials`)_ + +5. `cargo clean` + +6. `cargo package` + + This step fails if there are any uncommitted or ignored files. Do **not** use the `--allow-dirty` + flag! Instead, add the highlighted files as entries in the `Cargo.toml` `exclude` key. + +7. `cargo publish` diff --git a/src/jaegertracing/thrift/lib/rs/release.sh b/src/jaegertracing/thrift/lib/rs/release.sh new file mode 100755 index 000000000..c4e5b4892 --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/release.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +set -o errexit +set -o pipefail +set -o nounset + +if ! [[ $# -eq 1 && $1 =~ ^[0-9](\.[0-9][0-9]*){2}$ ]]; then + (>&2 echo "Usage: ./publish-crate.sh [THRIFT_RELEASE_VERSION] ") + (>&2 echo " THRIFT_RELEASE_VERSION is in semantic versioning format, i.e. #.##.##") + exit 1 +fi + +THRIFT_RELEASE_VERSION=${1:-} + +echo "Updating Cargo.toml to ${THRIFT_RELEASE_VERSION}" +sed -i.old -e "s/^version = .*$/version = \"${THRIFT_RELEASE_VERSION}\"/g" Cargo.toml +rm Cargo.toml.old + +echo "Committing updated Cargo.toml" +git add Cargo.toml +git commit -m "Update thrift crate version to ${THRIFT_RELEASE_VERSION}" -m "Client: rs" + +echo "Packaging and releasing rust thrift crate with version ${THRIFT_RELEASE_VERSION}" +cargo clean +cargo package +cargo publish diff --git a/src/jaegertracing/thrift/lib/rs/src/autogen.rs b/src/jaegertracing/thrift/lib/rs/src/autogen.rs new file mode 100644 index 000000000..6806a08ce --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/src/autogen.rs @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Thrift compiler auto-generated support. +//! +//! +//! Types and functions used internally by the Thrift compiler's Rust plugin +//! to implement required functionality. Users should never have to use code +//! in this module directly. + +use protocol::{TInputProtocol, TOutputProtocol}; + +/// Specifies the minimum functionality an auto-generated client should provide +/// to communicate with a Thrift server. +pub trait TThriftClient { + /// Returns the input protocol used to read serialized Thrift messages + /// from the Thrift server. + fn i_prot_mut(&mut self) -> &mut dyn TInputProtocol; + /// Returns the output protocol used to write serialized Thrift messages + /// to the Thrift server. + fn o_prot_mut(&mut self) -> &mut dyn TOutputProtocol; + /// Returns the sequence number of the last message written to the Thrift + /// server. Returns `0` if no messages have been written. Sequence + /// numbers should *never* be negative, and this method returns an `i32` + /// simply because the Thrift protocol encodes sequence numbers as `i32` on + /// the wire. + fn sequence_number(&self) -> i32; // FIXME: consider returning a u32 + /// Increments the sequence number, indicating that a message with that + /// number has been sent to the Thrift server. + fn increment_sequence_number(&mut self) -> i32; +} diff --git a/src/jaegertracing/thrift/lib/rs/src/errors.rs b/src/jaegertracing/thrift/lib/rs/src/errors.rs new file mode 100644 index 000000000..68cdc9c17 --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/src/errors.rs @@ -0,0 +1,667 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::convert::{From, Into}; +use std::error::Error as StdError; +use std::fmt::{Debug, Display, Formatter}; +use std::{error, fmt, io, string}; +use std::convert::TryFrom; + +use protocol::{TFieldIdentifier, TInputProtocol, TOutputProtocol, TStructIdentifier, TType}; + +// FIXME: should all my error structs impl error::Error as well? +// FIXME: should all fields in TransportError, ProtocolError and ApplicationError be optional? + +/// Error type returned by all runtime library functions. +/// +/// `thrift::Error` is used throughout this crate as well as in auto-generated +/// Rust code. It consists of four variants defined by convention across Thrift +/// implementations: +/// +/// 1. `Transport`: errors encountered while operating on I/O channels +/// 2. `Protocol`: errors encountered during runtime-library processing +/// 3. `Application`: errors encountered within auto-generated code +/// 4. `User`: IDL-defined exception structs +/// +/// The `Application` variant also functions as a catch-all: all handler errors +/// are automatically turned into application errors. +/// +/// All error variants except `Error::User` take an eponymous struct with two +/// required fields: +/// +/// 1. `kind`: variant-specific enum identifying the error sub-type +/// 2. `message`: human-readable error info string +/// +/// `kind` is defined by convention while `message` is freeform. If none of the +/// enumerated kinds are suitable use `Unknown`. +/// +/// To simplify error creation convenience constructors are defined for all +/// variants, and conversions from their structs (`thrift::TransportError`, +/// `thrift::ProtocolError` and `thrift::ApplicationError` into `thrift::Error`. +/// +/// # Examples +/// +/// Create a `TransportError`. +/// +/// ``` +/// use thrift::{TransportError, TransportErrorKind}; +/// +/// // explicit +/// let err0: thrift::Result<()> = Err( +/// thrift::Error::Transport( +/// TransportError { +/// kind: TransportErrorKind::TimedOut, +/// message: format!("connection to server timed out") +/// } +/// ) +/// ); +/// +/// // use conversion +/// let err1: thrift::Result<()> = Err( +/// thrift::Error::from( +/// TransportError { +/// kind: TransportErrorKind::TimedOut, +/// message: format!("connection to server timed out") +/// } +/// ) +/// ); +/// +/// // use struct constructor +/// let err2: thrift::Result<()> = Err( +/// thrift::Error::Transport( +/// TransportError::new( +/// TransportErrorKind::TimedOut, +/// "connection to server timed out" +/// ) +/// ) +/// ); +/// +/// +/// // use error variant constructor +/// let err3: thrift::Result<()> = Err( +/// thrift::new_transport_error( +/// TransportErrorKind::TimedOut, +/// "connection to server timed out" +/// ) +/// ); +/// ``` +/// +/// Create an error from a string. +/// +/// ``` +/// use thrift::{ApplicationError, ApplicationErrorKind}; +/// +/// // we just use `From::from` to convert a `String` into a `thrift::Error` +/// let err0: thrift::Result<()> = Err( +/// thrift::Error::from("This is an error") +/// ); +/// +/// // err0 is equivalent to... +/// let err1: thrift::Result<()> = Err( +/// thrift::Error::Application( +/// ApplicationError { +/// kind: ApplicationErrorKind::Unknown, +/// message: format!("This is an error") +/// } +/// ) +/// ); +/// ``` +/// +/// Return an IDL-defined exception. +/// +/// ```text +/// // Thrift IDL exception definition. +/// exception Xception { +/// 1: i32 errorCode, +/// 2: string message +/// } +/// ``` +/// +/// ``` +/// use std::error::Error; +/// use std::fmt; +/// use std::fmt::{Display, Formatter}; +/// +/// // auto-generated by the Thrift compiler +/// #[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)] +/// pub struct Xception { +/// pub error_code: Option<i32>, +/// pub message: Option<String>, +/// } +/// +/// // auto-generated by the Thrift compiler +/// impl Error for Xception { +/// fn description(&self) -> &str { +/// "remote service threw Xception" +/// } +/// } +/// +/// // auto-generated by the Thrift compiler +/// impl From<Xception> for thrift::Error { +/// fn from(e: Xception) -> Self { +/// thrift::Error::User(Box::new(e)) +/// } +/// } +/// +/// // auto-generated by the Thrift compiler +/// impl Display for Xception { +/// fn fmt(&self, f: &mut Formatter) -> fmt::Result { +/// self.description().fmt(f) +/// } +/// } +/// +/// // in user code... +/// let err: thrift::Result<()> = Err( +/// thrift::Error::from(Xception { error_code: Some(1), message: None }) +/// ); +/// ``` +pub enum Error { + /// Errors encountered while operating on I/O channels. + /// + /// These include *connection closed* and *bind failure*. + Transport(TransportError), + /// Errors encountered during runtime-library processing. + /// + /// These include *message too large* and *unsupported protocol version*. + Protocol(ProtocolError), + /// Errors encountered within auto-generated code, or when incoming + /// or outgoing messages violate the Thrift spec. + /// + /// These include *out-of-order messages* and *missing required struct + /// fields*. + /// + /// This variant also functions as a catch-all: errors from handler + /// functions are automatically returned as an `ApplicationError`. + Application(ApplicationError), + /// IDL-defined exception structs. + User(Box<dyn error::Error + Sync + Send>), +} + +impl Error { + /// Create an `ApplicationError` from its wire representation. + /// + /// Application code **should never** call this method directly. + pub fn read_application_error_from_in_protocol( + i: &mut dyn TInputProtocol, + ) -> ::Result<ApplicationError> { + let mut message = "general remote error".to_owned(); + let mut kind = ApplicationErrorKind::Unknown; + + i.read_struct_begin()?; + + loop { + let field_ident = i.read_field_begin()?; + + if field_ident.field_type == TType::Stop { + break; + } + + let id = field_ident + .id + .expect("sender should always specify id for non-STOP field"); + + match id { + 1 => { + let remote_message = i.read_string()?; + i.read_field_end()?; + message = remote_message; + } + 2 => { + let remote_type_as_int = i.read_i32()?; + let remote_kind: ApplicationErrorKind = TryFrom::try_from(remote_type_as_int) + .unwrap_or(ApplicationErrorKind::Unknown); + i.read_field_end()?; + kind = remote_kind; + } + _ => { + i.skip(field_ident.field_type)?; + } + } + } + + i.read_struct_end()?; + + Ok(ApplicationError { + kind: kind, + message: message, + }) + } + + /// Convert an `ApplicationError` into its wire representation and write + /// it to the remote. + /// + /// Application code **should never** call this method directly. + pub fn write_application_error_to_out_protocol( + e: &ApplicationError, + o: &mut dyn TOutputProtocol, + ) -> ::Result<()> { + o.write_struct_begin(&TStructIdentifier { + name: "TApplicationException".to_owned(), + })?; + + let message_field = TFieldIdentifier::new("message", TType::String, 1); + let type_field = TFieldIdentifier::new("type", TType::I32, 2); + + o.write_field_begin(&message_field)?; + o.write_string(&e.message)?; + o.write_field_end()?; + + o.write_field_begin(&type_field)?; + o.write_i32(e.kind as i32)?; + o.write_field_end()?; + + o.write_field_stop()?; + o.write_struct_end()?; + + o.flush() + } +} + +impl error::Error for Error { + fn description(&self) -> &str { + match *self { + Error::Transport(ref e) => TransportError::description(e), + Error::Protocol(ref e) => ProtocolError::description(e), + Error::Application(ref e) => ApplicationError::description(e), + Error::User(ref e) => e.description(), + } + } +} + +impl Debug for Error { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match *self { + Error::Transport(ref e) => Debug::fmt(e, f), + Error::Protocol(ref e) => Debug::fmt(e, f), + Error::Application(ref e) => Debug::fmt(e, f), + Error::User(ref e) => Debug::fmt(e, f), + } + } +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match *self { + Error::Transport(ref e) => Display::fmt(e, f), + Error::Protocol(ref e) => Display::fmt(e, f), + Error::Application(ref e) => Display::fmt(e, f), + Error::User(ref e) => Display::fmt(e, f), + } + } +} + +impl From<String> for Error { + fn from(s: String) -> Self { + Error::Application(ApplicationError { + kind: ApplicationErrorKind::Unknown, + message: s, + }) + } +} + +impl<'a> From<&'a str> for Error { + fn from(s: &'a str) -> Self { + Error::Application(ApplicationError { + kind: ApplicationErrorKind::Unknown, + message: String::from(s), + }) + } +} + +impl From<TransportError> for Error { + fn from(e: TransportError) -> Self { + Error::Transport(e) + } +} + +impl From<ProtocolError> for Error { + fn from(e: ProtocolError) -> Self { + Error::Protocol(e) + } +} + +impl From<ApplicationError> for Error { + fn from(e: ApplicationError) -> Self { + Error::Application(e) + } +} + +/// Create a new `Error` instance of type `Transport` that wraps a +/// `TransportError`. +pub fn new_transport_error<S: Into<String>>(kind: TransportErrorKind, message: S) -> Error { + Error::Transport(TransportError::new(kind, message)) +} + +/// Information about I/O errors. +#[derive(Debug, Eq, PartialEq)] +pub struct TransportError { + /// I/O error variant. + /// + /// If a specific `TransportErrorKind` does not apply use + /// `TransportErrorKind::Unknown`. + pub kind: TransportErrorKind, + /// Human-readable error message. + pub message: String, +} + +impl TransportError { + /// Create a new `TransportError`. + pub fn new<S: Into<String>>(kind: TransportErrorKind, message: S) -> TransportError { + TransportError { + kind: kind, + message: message.into(), + } + } +} + +/// I/O error categories. +/// +/// This list may grow, and it is not recommended to match against it. +#[derive(Clone, Copy, Eq, Debug, PartialEq)] +pub enum TransportErrorKind { + /// Catch-all I/O error. + Unknown = 0, + /// An I/O operation was attempted when the transport channel was not open. + NotOpen = 1, + /// The transport channel cannot be opened because it was opened previously. + AlreadyOpen = 2, + /// An I/O operation timed out. + TimedOut = 3, + /// A read could not complete because no bytes were available. + EndOfFile = 4, + /// An invalid (buffer/message) size was requested or received. + NegativeSize = 5, + /// Too large a buffer or message size was requested or received. + SizeLimit = 6, +} + +impl TransportError { + fn description(&self) -> &str { + match self.kind { + TransportErrorKind::Unknown => "transport error", + TransportErrorKind::NotOpen => "not open", + TransportErrorKind::AlreadyOpen => "already open", + TransportErrorKind::TimedOut => "timed out", + TransportErrorKind::EndOfFile => "end of file", + TransportErrorKind::NegativeSize => "negative size message", + TransportErrorKind::SizeLimit => "message too long", + } + } +} + +impl Display for TransportError { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}", self.description()) + } +} + +impl TryFrom<i32> for TransportErrorKind { + type Error = Error; + fn try_from(from: i32) -> Result<Self, Self::Error> { + match from { + 0 => Ok(TransportErrorKind::Unknown), + 1 => Ok(TransportErrorKind::NotOpen), + 2 => Ok(TransportErrorKind::AlreadyOpen), + 3 => Ok(TransportErrorKind::TimedOut), + 4 => Ok(TransportErrorKind::EndOfFile), + 5 => Ok(TransportErrorKind::NegativeSize), + 6 => Ok(TransportErrorKind::SizeLimit), + _ => Err(Error::Protocol(ProtocolError { + kind: ProtocolErrorKind::Unknown, + message: format!("cannot convert {} to TransportErrorKind", from), + })), + } + } +} + +impl From<io::Error> for Error { + fn from(err: io::Error) -> Self { + match err.kind() { + io::ErrorKind::ConnectionReset + | io::ErrorKind::ConnectionRefused + | io::ErrorKind::NotConnected => Error::Transport(TransportError { + kind: TransportErrorKind::NotOpen, + message: err.description().to_owned(), + }), + io::ErrorKind::AlreadyExists => Error::Transport(TransportError { + kind: TransportErrorKind::AlreadyOpen, + message: err.description().to_owned(), + }), + io::ErrorKind::TimedOut => Error::Transport(TransportError { + kind: TransportErrorKind::TimedOut, + message: err.description().to_owned(), + }), + io::ErrorKind::UnexpectedEof => Error::Transport(TransportError { + kind: TransportErrorKind::EndOfFile, + message: err.description().to_owned(), + }), + _ => { + Error::Transport(TransportError { + kind: TransportErrorKind::Unknown, + message: err.description().to_owned(), // FIXME: use io error's debug string + }) + } + } + } +} + +impl From<string::FromUtf8Error> for Error { + fn from(err: string::FromUtf8Error) -> Self { + Error::Protocol(ProtocolError { + kind: ProtocolErrorKind::InvalidData, + message: err.description().to_owned(), // FIXME: use fmt::Error's debug string + }) + } +} + +/// Create a new `Error` instance of type `Protocol` that wraps a +/// `ProtocolError`. +pub fn new_protocol_error<S: Into<String>>(kind: ProtocolErrorKind, message: S) -> Error { + Error::Protocol(ProtocolError::new(kind, message)) +} + +/// Information about errors that occur in the runtime library. +#[derive(Debug, Eq, PartialEq)] +pub struct ProtocolError { + /// Protocol error variant. + /// + /// If a specific `ProtocolErrorKind` does not apply use + /// `ProtocolErrorKind::Unknown`. + pub kind: ProtocolErrorKind, + /// Human-readable error message. + pub message: String, +} + +impl ProtocolError { + /// Create a new `ProtocolError`. + pub fn new<S: Into<String>>(kind: ProtocolErrorKind, message: S) -> ProtocolError { + ProtocolError { + kind: kind, + message: message.into(), + } + } +} + +/// Runtime library error categories. +/// +/// This list may grow, and it is not recommended to match against it. +#[derive(Clone, Copy, Eq, Debug, PartialEq)] +pub enum ProtocolErrorKind { + /// Catch-all runtime-library error. + Unknown = 0, + /// An invalid argument was supplied to a library function, or invalid data + /// was received from a Thrift endpoint. + InvalidData = 1, + /// An invalid size was received in an encoded field. + NegativeSize = 2, + /// Thrift message or field was too long. + SizeLimit = 3, + /// Unsupported or unknown Thrift protocol version. + BadVersion = 4, + /// Unsupported Thrift protocol, server or field type. + NotImplemented = 5, + /// Reached the maximum nested depth to which an encoded Thrift field could + /// be skipped. + DepthLimit = 6, +} + +impl ProtocolError { + fn description(&self) -> &str { + match self.kind { + ProtocolErrorKind::Unknown => "protocol error", + ProtocolErrorKind::InvalidData => "bad data", + ProtocolErrorKind::NegativeSize => "negative message size", + ProtocolErrorKind::SizeLimit => "message too long", + ProtocolErrorKind::BadVersion => "invalid thrift version", + ProtocolErrorKind::NotImplemented => "not implemented", + ProtocolErrorKind::DepthLimit => "maximum skip depth reached", + } + } +} + +impl Display for ProtocolError { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}", self.description()) + } +} + +impl TryFrom<i32> for ProtocolErrorKind { + type Error = Error; + fn try_from(from: i32) -> Result<Self, Self::Error> { + match from { + 0 => Ok(ProtocolErrorKind::Unknown), + 1 => Ok(ProtocolErrorKind::InvalidData), + 2 => Ok(ProtocolErrorKind::NegativeSize), + 3 => Ok(ProtocolErrorKind::SizeLimit), + 4 => Ok(ProtocolErrorKind::BadVersion), + 5 => Ok(ProtocolErrorKind::NotImplemented), + 6 => Ok(ProtocolErrorKind::DepthLimit), + _ => Err(Error::Protocol(ProtocolError { + kind: ProtocolErrorKind::Unknown, + message: format!("cannot convert {} to ProtocolErrorKind", from), + })), + } + } +} + +/// Create a new `Error` instance of type `Application` that wraps an +/// `ApplicationError`. +pub fn new_application_error<S: Into<String>>(kind: ApplicationErrorKind, message: S) -> Error { + Error::Application(ApplicationError::new(kind, message)) +} + +/// Information about errors in auto-generated code or in user-implemented +/// service handlers. +#[derive(Debug, Eq, PartialEq)] +pub struct ApplicationError { + /// Application error variant. + /// + /// If a specific `ApplicationErrorKind` does not apply use + /// `ApplicationErrorKind::Unknown`. + pub kind: ApplicationErrorKind, + /// Human-readable error message. + pub message: String, +} + +impl ApplicationError { + /// Create a new `ApplicationError`. + pub fn new<S: Into<String>>(kind: ApplicationErrorKind, message: S) -> ApplicationError { + ApplicationError { + kind: kind, + message: message.into(), + } + } +} + +/// Auto-generated or user-implemented code error categories. +/// +/// This list may grow, and it is not recommended to match against it. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum ApplicationErrorKind { + /// Catch-all application error. + Unknown = 0, + /// Made service call to an unknown service method. + UnknownMethod = 1, + /// Received an unknown Thrift message type. That is, not one of the + /// `thrift::protocol::TMessageType` variants. + InvalidMessageType = 2, + /// Method name in a service reply does not match the name of the + /// receiving service method. + WrongMethodName = 3, + /// Received an out-of-order Thrift message. + BadSequenceId = 4, + /// Service reply is missing required fields. + MissingResult = 5, + /// Auto-generated code failed unexpectedly. + InternalError = 6, + /// Thrift protocol error. When possible use `Error::ProtocolError` with a + /// specific `ProtocolErrorKind` instead. + ProtocolError = 7, + /// *Unknown*. Included only for compatibility with existing Thrift implementations. + InvalidTransform = 8, // ?? + /// Thrift endpoint requested, or is using, an unsupported encoding. + InvalidProtocol = 9, // ?? + /// Thrift endpoint requested, or is using, an unsupported auto-generated client type. + UnsupportedClientType = 10, // ?? +} + +impl ApplicationError { + fn description(&self) -> &str { + match self.kind { + ApplicationErrorKind::Unknown => "service error", + ApplicationErrorKind::UnknownMethod => "unknown service method", + ApplicationErrorKind::InvalidMessageType => "wrong message type received", + ApplicationErrorKind::WrongMethodName => "unknown method reply received", + ApplicationErrorKind::BadSequenceId => "out of order sequence id", + ApplicationErrorKind::MissingResult => "missing method result", + ApplicationErrorKind::InternalError => "remote service threw exception", + ApplicationErrorKind::ProtocolError => "protocol error", + ApplicationErrorKind::InvalidTransform => "invalid transform", + ApplicationErrorKind::InvalidProtocol => "invalid protocol requested", + ApplicationErrorKind::UnsupportedClientType => "unsupported protocol client", + } + } +} + +impl Display for ApplicationError { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}", self.description()) + } +} + +impl TryFrom<i32> for ApplicationErrorKind { + type Error = Error; + fn try_from(from: i32) -> Result<Self, Self::Error> { + match from { + 0 => Ok(ApplicationErrorKind::Unknown), + 1 => Ok(ApplicationErrorKind::UnknownMethod), + 2 => Ok(ApplicationErrorKind::InvalidMessageType), + 3 => Ok(ApplicationErrorKind::WrongMethodName), + 4 => Ok(ApplicationErrorKind::BadSequenceId), + 5 => Ok(ApplicationErrorKind::MissingResult), + 6 => Ok(ApplicationErrorKind::InternalError), + 7 => Ok(ApplicationErrorKind::ProtocolError), + 8 => Ok(ApplicationErrorKind::InvalidTransform), + 9 => Ok(ApplicationErrorKind::InvalidProtocol), + 10 => Ok(ApplicationErrorKind::UnsupportedClientType), + _ => Err(Error::Application(ApplicationError { + kind: ApplicationErrorKind::Unknown, + message: format!("cannot convert {} to ApplicationErrorKind", from), + })), + } + } +} diff --git a/src/jaegertracing/thrift/lib/rs/src/lib.rs b/src/jaegertracing/thrift/lib/rs/src/lib.rs new file mode 100644 index 000000000..cdd60f0a9 --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/src/lib.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Rust runtime library for the Apache Thrift RPC system. +//! +//! This crate implements the components required to build a working +//! Thrift server and client. It is divided into the following modules: +//! +//! 1. errors +//! 2. protocol +//! 3. transport +//! 4. server +//! 5. autogen +//! +//! The modules are layered as shown in the diagram below. The `autogen'd` +//! layer is generated by the Thrift compiler's Rust plugin. It uses the +//! types and functions defined in this crate to serialize and deserialize +//! messages and implement RPC. Users interact with these types and services +//! by writing their own code that uses the auto-generated clients and +//! servers. +//! +//! ```text +//! +-----------+ +//! | user app | +//! +-----------+ +//! | autogen'd | (uses errors, autogen) +//! +-----------+ +//! | protocol | +//! +-----------+ +//! | transport | +//! +-----------+ +//! ``` + +#![crate_type = "lib"] +#![doc(test(attr(allow(unused_variables), deny(warnings))))] +#![deny(bare_trait_objects)] + +extern crate byteorder; +extern crate ordered_float; +extern crate integer_encoding; +extern crate threadpool; + +#[macro_use] +extern crate log; + +// NOTE: this macro has to be defined before any modules. See: +// https://danielkeep.github.io/quick-intro-to-macros.html#some-more-gotchas + +/// Assert that an expression returning a `Result` is a success. If it is, +/// return the value contained in the result, i.e. `expr.unwrap()`. +#[cfg(test)] +macro_rules! assert_success { + ($e: expr) => {{ + let res = $e; + assert!(res.is_ok()); + res.unwrap() + }}; +} + +pub mod protocol; +pub mod server; +pub mod transport; + +mod errors; +pub use errors::*; + +mod autogen; +pub use autogen::*; + +/// Result type returned by all runtime library functions. +/// +/// As is convention this is a typedef of `std::result::Result` +/// with `E` defined as the `thrift::Error` type. +pub type Result<T> = std::result::Result<T, self::Error>; + +// Re-export ordered-float, since it is used by the generator +pub use ordered_float::OrderedFloat as OrderedFloat;
\ No newline at end of file diff --git a/src/jaegertracing/thrift/lib/rs/src/protocol/binary.rs b/src/jaegertracing/thrift/lib/rs/src/protocol/binary.rs new file mode 100644 index 000000000..2069cf9fd --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/src/protocol/binary.rs @@ -0,0 +1,956 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt}; +use std::convert::{From, TryFrom}; + +use super::{ + TFieldIdentifier, TInputProtocol, TInputProtocolFactory, TListIdentifier, TMapIdentifier, + TMessageIdentifier, TMessageType, +}; +use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType}; +use transport::{TReadTransport, TWriteTransport}; +use {ProtocolError, ProtocolErrorKind}; + +const BINARY_PROTOCOL_VERSION_1: u32 = 0x80010000; + +/// Read messages encoded in the Thrift simple binary encoding. +/// +/// There are two available modes: `strict` and `non-strict`, where the +/// `non-strict` version does not check for the protocol version in the +/// received message header. +/// +/// # Examples +/// +/// Create and use a `TBinaryInputProtocol`. +/// +/// ```no_run +/// use thrift::protocol::{TBinaryInputProtocol, TInputProtocol}; +/// use thrift::transport::TTcpChannel; +/// +/// let mut channel = TTcpChannel::new(); +/// channel.open("localhost:9090").unwrap(); +/// +/// let mut protocol = TBinaryInputProtocol::new(channel, true); +/// +/// let recvd_bool = protocol.read_bool().unwrap(); +/// let recvd_string = protocol.read_string().unwrap(); +/// ``` +#[derive(Debug)] +pub struct TBinaryInputProtocol<T> +where + T: TReadTransport, +{ + strict: bool, + pub transport: T, // FIXME: shouldn't be public +} + +impl<'a, T> TBinaryInputProtocol<T> +where + T: TReadTransport, +{ + /// Create a `TBinaryInputProtocol` that reads bytes from `transport`. + /// + /// Set `strict` to `true` if all incoming messages contain the protocol + /// version number in the protocol header. + pub fn new(transport: T, strict: bool) -> TBinaryInputProtocol<T> { + TBinaryInputProtocol { + strict: strict, + transport: transport, + } + } +} + +impl<T> TInputProtocol for TBinaryInputProtocol<T> +where + T: TReadTransport, +{ + #[cfg_attr(feature = "cargo-clippy", allow(collapsible_if))] + fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> { + let mut first_bytes = vec![0; 4]; + self.transport.read_exact(&mut first_bytes[..])?; + + // the thrift version header is intentionally negative + // so the first check we'll do is see if the sign bit is set + // and if so - assume it's the protocol-version header + if first_bytes[0] >= 8 { + // apparently we got a protocol-version header - check + // it, and if it matches, read the rest of the fields + if first_bytes[0..2] != [0x80, 0x01] { + Err(::Error::Protocol(ProtocolError { + kind: ProtocolErrorKind::BadVersion, + message: format!("received bad version: {:?}", &first_bytes[0..2]), + })) + } else { + let message_type: TMessageType = TryFrom::try_from(first_bytes[3])?; + let name = self.read_string()?; + let sequence_number = self.read_i32()?; + Ok(TMessageIdentifier::new(name, message_type, sequence_number)) + } + } else { + // apparently we didn't get a protocol-version header, + // which happens if the sender is not using the strict protocol + if self.strict { + // we're in strict mode however, and that always + // requires the protocol-version header to be written first + Err(::Error::Protocol(ProtocolError { + kind: ProtocolErrorKind::BadVersion, + message: format!("received bad version: {:?}", &first_bytes[0..2]), + })) + } else { + // in the non-strict version the first message field + // is the message name. strings (byte arrays) are length-prefixed, + // so we've just read the length in the first 4 bytes + let name_size = BigEndian::read_i32(&first_bytes) as usize; + let mut name_buf: Vec<u8> = vec![0; name_size]; + self.transport.read_exact(&mut name_buf)?; + let name = String::from_utf8(name_buf)?; + + // read the rest of the fields + let message_type: TMessageType = self.read_byte().and_then(TryFrom::try_from)?; + let sequence_number = self.read_i32()?; + Ok(TMessageIdentifier::new(name, message_type, sequence_number)) + } + } + } + + fn read_message_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn read_struct_begin(&mut self) -> ::Result<Option<TStructIdentifier>> { + Ok(None) + } + + fn read_struct_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn read_field_begin(&mut self) -> ::Result<TFieldIdentifier> { + let field_type_byte = self.read_byte()?; + let field_type = field_type_from_u8(field_type_byte)?; + let id = match field_type { + TType::Stop => Ok(0), + _ => self.read_i16(), + }?; + Ok(TFieldIdentifier::new::<Option<String>, String, i16>( + None, field_type, id, + )) + } + + fn read_field_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn read_bytes(&mut self) -> ::Result<Vec<u8>> { + let num_bytes = self.transport.read_i32::<BigEndian>()? as usize; + let mut buf = vec![0u8; num_bytes]; + self.transport + .read_exact(&mut buf) + .map(|_| buf) + .map_err(From::from) + } + + fn read_bool(&mut self) -> ::Result<bool> { + let b = self.read_i8()?; + match b { + 0 => Ok(false), + _ => Ok(true), + } + } + + fn read_i8(&mut self) -> ::Result<i8> { + self.transport.read_i8().map_err(From::from) + } + + fn read_i16(&mut self) -> ::Result<i16> { + self.transport.read_i16::<BigEndian>().map_err(From::from) + } + + fn read_i32(&mut self) -> ::Result<i32> { + self.transport.read_i32::<BigEndian>().map_err(From::from) + } + + fn read_i64(&mut self) -> ::Result<i64> { + self.transport.read_i64::<BigEndian>().map_err(From::from) + } + + fn read_double(&mut self) -> ::Result<f64> { + self.transport.read_f64::<BigEndian>().map_err(From::from) + } + + fn read_string(&mut self) -> ::Result<String> { + let bytes = self.read_bytes()?; + String::from_utf8(bytes).map_err(From::from) + } + + fn read_list_begin(&mut self) -> ::Result<TListIdentifier> { + let element_type: TType = self.read_byte().and_then(field_type_from_u8)?; + let size = self.read_i32()?; + Ok(TListIdentifier::new(element_type, size)) + } + + fn read_list_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn read_set_begin(&mut self) -> ::Result<TSetIdentifier> { + let element_type: TType = self.read_byte().and_then(field_type_from_u8)?; + let size = self.read_i32()?; + Ok(TSetIdentifier::new(element_type, size)) + } + + fn read_set_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn read_map_begin(&mut self) -> ::Result<TMapIdentifier> { + let key_type: TType = self.read_byte().and_then(field_type_from_u8)?; + let value_type: TType = self.read_byte().and_then(field_type_from_u8)?; + let size = self.read_i32()?; + Ok(TMapIdentifier::new(key_type, value_type, size)) + } + + fn read_map_end(&mut self) -> ::Result<()> { + Ok(()) + } + + // utility + // + + fn read_byte(&mut self) -> ::Result<u8> { + self.transport.read_u8().map_err(From::from) + } +} + +/// Factory for creating instances of `TBinaryInputProtocol`. +#[derive(Default)] +pub struct TBinaryInputProtocolFactory; + +impl TBinaryInputProtocolFactory { + /// Create a `TBinaryInputProtocolFactory`. + pub fn new() -> TBinaryInputProtocolFactory { + TBinaryInputProtocolFactory {} + } +} + +impl TInputProtocolFactory for TBinaryInputProtocolFactory { + fn create(&self, transport: Box<dyn TReadTransport + Send>) -> Box<dyn TInputProtocol + Send> { + Box::new(TBinaryInputProtocol::new(transport, true)) + } +} + +/// Write messages using the Thrift simple binary encoding. +/// +/// There are two available modes: `strict` and `non-strict`, where the +/// `strict` version writes the protocol version number in the outgoing message +/// header and the `non-strict` version does not. +/// +/// # Examples +/// +/// Create and use a `TBinaryOutputProtocol`. +/// +/// ```no_run +/// use thrift::protocol::{TBinaryOutputProtocol, TOutputProtocol}; +/// use thrift::transport::TTcpChannel; +/// +/// let mut channel = TTcpChannel::new(); +/// channel.open("localhost:9090").unwrap(); +/// +/// let mut protocol = TBinaryOutputProtocol::new(channel, true); +/// +/// protocol.write_bool(true).unwrap(); +/// protocol.write_string("test_string").unwrap(); +/// ``` +#[derive(Debug)] +pub struct TBinaryOutputProtocol<T> +where + T: TWriteTransport, +{ + strict: bool, + pub transport: T, // FIXME: do not make public; only public for testing! +} + +impl<T> TBinaryOutputProtocol<T> +where + T: TWriteTransport, +{ + /// Create a `TBinaryOutputProtocol` that writes bytes to `transport`. + /// + /// Set `strict` to `true` if all outgoing messages should contain the + /// protocol version number in the protocol header. + pub fn new(transport: T, strict: bool) -> TBinaryOutputProtocol<T> { + TBinaryOutputProtocol { + strict: strict, + transport: transport, + } + } +} + +impl<T> TOutputProtocol for TBinaryOutputProtocol<T> +where + T: TWriteTransport, +{ + fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> { + if self.strict { + let message_type: u8 = identifier.message_type.into(); + let header = BINARY_PROTOCOL_VERSION_1 | (message_type as u32); + self.transport.write_u32::<BigEndian>(header)?; + self.write_string(&identifier.name)?; + self.write_i32(identifier.sequence_number) + } else { + self.write_string(&identifier.name)?; + self.write_byte(identifier.message_type.into())?; + self.write_i32(identifier.sequence_number) + } + } + + fn write_message_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn write_struct_begin(&mut self, _: &TStructIdentifier) -> ::Result<()> { + Ok(()) + } + + fn write_struct_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> ::Result<()> { + if identifier.id.is_none() && identifier.field_type != TType::Stop { + return Err(::Error::Protocol(ProtocolError { + kind: ProtocolErrorKind::Unknown, + message: format!( + "cannot write identifier {:?} without sequence number", + &identifier + ), + })); + } + + self.write_byte(field_type_to_u8(identifier.field_type))?; + if let Some(id) = identifier.id { + self.write_i16(id) + } else { + Ok(()) + } + } + + fn write_field_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn write_field_stop(&mut self) -> ::Result<()> { + self.write_byte(field_type_to_u8(TType::Stop)) + } + + fn write_bytes(&mut self, b: &[u8]) -> ::Result<()> { + self.write_i32(b.len() as i32)?; + self.transport.write_all(b).map_err(From::from) + } + + fn write_bool(&mut self, b: bool) -> ::Result<()> { + if b { + self.write_i8(1) + } else { + self.write_i8(0) + } + } + + fn write_i8(&mut self, i: i8) -> ::Result<()> { + self.transport.write_i8(i).map_err(From::from) + } + + fn write_i16(&mut self, i: i16) -> ::Result<()> { + self.transport.write_i16::<BigEndian>(i).map_err(From::from) + } + + fn write_i32(&mut self, i: i32) -> ::Result<()> { + self.transport.write_i32::<BigEndian>(i).map_err(From::from) + } + + fn write_i64(&mut self, i: i64) -> ::Result<()> { + self.transport.write_i64::<BigEndian>(i).map_err(From::from) + } + + fn write_double(&mut self, d: f64) -> ::Result<()> { + self.transport.write_f64::<BigEndian>(d).map_err(From::from) + } + + fn write_string(&mut self, s: &str) -> ::Result<()> { + self.write_bytes(s.as_bytes()) + } + + fn write_list_begin(&mut self, identifier: &TListIdentifier) -> ::Result<()> { + self.write_byte(field_type_to_u8(identifier.element_type))?; + self.write_i32(identifier.size) + } + + fn write_list_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> ::Result<()> { + self.write_byte(field_type_to_u8(identifier.element_type))?; + self.write_i32(identifier.size) + } + + fn write_set_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> ::Result<()> { + let key_type = identifier + .key_type + .expect("map identifier to write should contain key type"); + self.write_byte(field_type_to_u8(key_type))?; + let val_type = identifier + .value_type + .expect("map identifier to write should contain value type"); + self.write_byte(field_type_to_u8(val_type))?; + self.write_i32(identifier.size) + } + + fn write_map_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn flush(&mut self) -> ::Result<()> { + self.transport.flush().map_err(From::from) + } + + // utility + // + + fn write_byte(&mut self, b: u8) -> ::Result<()> { + self.transport.write_u8(b).map_err(From::from) + } +} + +/// Factory for creating instances of `TBinaryOutputProtocol`. +#[derive(Default)] +pub struct TBinaryOutputProtocolFactory; + +impl TBinaryOutputProtocolFactory { + /// Create a `TBinaryOutputProtocolFactory`. + pub fn new() -> TBinaryOutputProtocolFactory { + TBinaryOutputProtocolFactory {} + } +} + +impl TOutputProtocolFactory for TBinaryOutputProtocolFactory { + fn create(&self, transport: Box<dyn TWriteTransport + Send>) -> Box<dyn TOutputProtocol + Send> { + Box::new(TBinaryOutputProtocol::new(transport, true)) + } +} + +fn field_type_to_u8(field_type: TType) -> u8 { + match field_type { + TType::Stop => 0x00, + TType::Void => 0x01, + TType::Bool => 0x02, + TType::I08 => 0x03, // equivalent to TType::Byte + TType::Double => 0x04, + TType::I16 => 0x06, + TType::I32 => 0x08, + TType::I64 => 0x0A, + TType::String | TType::Utf7 => 0x0B, + TType::Struct => 0x0C, + TType::Map => 0x0D, + TType::Set => 0x0E, + TType::List => 0x0F, + TType::Utf8 => 0x10, + TType::Utf16 => 0x11, + } +} + +fn field_type_from_u8(b: u8) -> ::Result<TType> { + match b { + 0x00 => Ok(TType::Stop), + 0x01 => Ok(TType::Void), + 0x02 => Ok(TType::Bool), + 0x03 => Ok(TType::I08), // Equivalent to TType::Byte + 0x04 => Ok(TType::Double), + 0x06 => Ok(TType::I16), + 0x08 => Ok(TType::I32), + 0x0A => Ok(TType::I64), + 0x0B => Ok(TType::String), // technically, also a UTF7, but we'll treat it as string + 0x0C => Ok(TType::Struct), + 0x0D => Ok(TType::Map), + 0x0E => Ok(TType::Set), + 0x0F => Ok(TType::List), + 0x10 => Ok(TType::Utf8), + 0x11 => Ok(TType::Utf16), + unkn => Err(::Error::Protocol(ProtocolError { + kind: ProtocolErrorKind::InvalidData, + message: format!("cannot convert {} to TType", unkn), + })), + } +} + +#[cfg(test)] +mod tests { + + use protocol::{ + TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier, + TMessageType, TOutputProtocol, TSetIdentifier, TStructIdentifier, TType, + }; + use transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf}; + + use super::*; + + #[test] + fn must_write_strict_message_call_begin() { + let (_, mut o_prot) = test_objects(true); + + let ident = TMessageIdentifier::new("test", TMessageType::Call, 1); + assert!(o_prot.write_message_begin(&ident).is_ok()); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let expected: [u8; 16] = [ + 0x80, + 0x01, + 0x00, + 0x01, + 0x00, + 0x00, + 0x00, + 0x04, + 0x74, + 0x65, + 0x73, + 0x74, + 0x00, + 0x00, + 0x00, + 0x01, + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_write_non_strict_message_call_begin() { + let (_, mut o_prot) = test_objects(false); + + let ident = TMessageIdentifier::new("test", TMessageType::Call, 1); + assert!(o_prot.write_message_begin(&ident).is_ok()); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let expected: [u8; 13] = [ + 0x00, + 0x00, + 0x00, + 0x04, + 0x74, + 0x65, + 0x73, + 0x74, + 0x01, + 0x00, + 0x00, + 0x00, + 0x01, + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_write_strict_message_reply_begin() { + let (_, mut o_prot) = test_objects(true); + + let ident = TMessageIdentifier::new("test", TMessageType::Reply, 10); + assert!(o_prot.write_message_begin(&ident).is_ok()); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let expected: [u8; 16] = [ + 0x80, + 0x01, + 0x00, + 0x02, + 0x00, + 0x00, + 0x00, + 0x04, + 0x74, + 0x65, + 0x73, + 0x74, + 0x00, + 0x00, + 0x00, + 0x0A, + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_write_non_strict_message_reply_begin() { + let (_, mut o_prot) = test_objects(false); + + let ident = TMessageIdentifier::new("test", TMessageType::Reply, 10); + assert!(o_prot.write_message_begin(&ident).is_ok()); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let expected: [u8; 13] = [ + 0x00, + 0x00, + 0x00, + 0x04, + 0x74, + 0x65, + 0x73, + 0x74, + 0x02, + 0x00, + 0x00, + 0x00, + 0x0A, + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_strict_message_begin() { + let (mut i_prot, mut o_prot) = test_objects(true); + + let sent_ident = TMessageIdentifier::new("test", TMessageType::Call, 1); + assert!(o_prot.write_message_begin(&sent_ident).is_ok()); + + copy_write_buffer_to_read_buffer!(o_prot); + + let received_ident = assert_success!(i_prot.read_message_begin()); + assert_eq!(&received_ident, &sent_ident); + } + + #[test] + fn must_round_trip_non_strict_message_begin() { + let (mut i_prot, mut o_prot) = test_objects(false); + + let sent_ident = TMessageIdentifier::new("test", TMessageType::Call, 1); + assert!(o_prot.write_message_begin(&sent_ident).is_ok()); + + copy_write_buffer_to_read_buffer!(o_prot); + + let received_ident = assert_success!(i_prot.read_message_begin()); + assert_eq!(&received_ident, &sent_ident); + } + + #[test] + fn must_write_message_end() { + assert_no_write(|o| o.write_message_end(), true); + } + + #[test] + fn must_write_struct_begin() { + assert_no_write( + |o| o.write_struct_begin(&TStructIdentifier::new("foo")), + true, + ); + } + + #[test] + fn must_write_struct_end() { + assert_no_write(|o| o.write_struct_end(), true); + } + + #[test] + fn must_write_field_begin() { + let (_, mut o_prot) = test_objects(true); + + assert!(o_prot + .write_field_begin(&TFieldIdentifier::new("some_field", TType::String, 22)) + .is_ok()); + + let expected: [u8; 3] = [0x0B, 0x00, 0x16]; + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_field_begin() { + let (mut i_prot, mut o_prot) = test_objects(true); + + let sent_field_ident = TFieldIdentifier::new("foo", TType::I64, 20); + assert!(o_prot.write_field_begin(&sent_field_ident).is_ok()); + + copy_write_buffer_to_read_buffer!(o_prot); + + let expected_ident = TFieldIdentifier { + name: None, + field_type: TType::I64, + id: Some(20), + }; // no name + let received_ident = assert_success!(i_prot.read_field_begin()); + assert_eq!(&received_ident, &expected_ident); + } + + #[test] + fn must_write_stop_field() { + let (_, mut o_prot) = test_objects(true); + + assert!(o_prot.write_field_stop().is_ok()); + + let expected: [u8; 1] = [0x00]; + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_field_stop() { + let (mut i_prot, mut o_prot) = test_objects(true); + + assert!(o_prot.write_field_stop().is_ok()); + + copy_write_buffer_to_read_buffer!(o_prot); + + let expected_ident = TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: Some(0), + }; // we get id 0 + + let received_ident = assert_success!(i_prot.read_field_begin()); + assert_eq!(&received_ident, &expected_ident); + } + + #[test] + fn must_write_field_end() { + assert_no_write(|o| o.write_field_end(), true); + } + + #[test] + fn must_write_list_begin() { + let (_, mut o_prot) = test_objects(true); + + assert!(o_prot + .write_list_begin(&TListIdentifier::new(TType::Bool, 5)) + .is_ok()); + + let expected: [u8; 5] = [0x02, 0x00, 0x00, 0x00, 0x05]; + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_list_begin() { + let (mut i_prot, mut o_prot) = test_objects(true); + + let ident = TListIdentifier::new(TType::List, 900); + assert!(o_prot.write_list_begin(&ident).is_ok()); + + copy_write_buffer_to_read_buffer!(o_prot); + + let received_ident = assert_success!(i_prot.read_list_begin()); + assert_eq!(&received_ident, &ident); + } + + #[test] + fn must_write_list_end() { + assert_no_write(|o| o.write_list_end(), true); + } + + #[test] + fn must_write_set_begin() { + let (_, mut o_prot) = test_objects(true); + + assert!(o_prot + .write_set_begin(&TSetIdentifier::new(TType::I16, 7)) + .is_ok()); + + let expected: [u8; 5] = [0x06, 0x00, 0x00, 0x00, 0x07]; + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_set_begin() { + let (mut i_prot, mut o_prot) = test_objects(true); + + let ident = TSetIdentifier::new(TType::I64, 2000); + assert!(o_prot.write_set_begin(&ident).is_ok()); + + copy_write_buffer_to_read_buffer!(o_prot); + + let received_ident_result = i_prot.read_set_begin(); + assert!(received_ident_result.is_ok()); + assert_eq!(&received_ident_result.unwrap(), &ident); + } + + #[test] + fn must_write_set_end() { + assert_no_write(|o| o.write_set_end(), true); + } + + #[test] + fn must_write_map_begin() { + let (_, mut o_prot) = test_objects(true); + + assert!(o_prot + .write_map_begin(&TMapIdentifier::new(TType::I64, TType::Struct, 32)) + .is_ok()); + + let expected: [u8; 6] = [0x0A, 0x0C, 0x00, 0x00, 0x00, 0x20]; + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_map_begin() { + let (mut i_prot, mut o_prot) = test_objects(true); + + let ident = TMapIdentifier::new(TType::Map, TType::Set, 100); + assert!(o_prot.write_map_begin(&ident).is_ok()); + + copy_write_buffer_to_read_buffer!(o_prot); + + let received_ident = assert_success!(i_prot.read_map_begin()); + assert_eq!(&received_ident, &ident); + } + + #[test] + fn must_write_map_end() { + assert_no_write(|o| o.write_map_end(), true); + } + + #[test] + fn must_write_bool_true() { + let (_, mut o_prot) = test_objects(true); + + assert!(o_prot.write_bool(true).is_ok()); + + let expected: [u8; 1] = [0x01]; + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_write_bool_false() { + let (_, mut o_prot) = test_objects(true); + + assert!(o_prot.write_bool(false).is_ok()); + + let expected: [u8; 1] = [0x00]; + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_read_bool_true() { + let (mut i_prot, _) = test_objects(true); + + set_readable_bytes!(i_prot, &[0x01]); + + let read_bool = assert_success!(i_prot.read_bool()); + assert_eq!(read_bool, true); + } + + #[test] + fn must_read_bool_false() { + let (mut i_prot, _) = test_objects(true); + + set_readable_bytes!(i_prot, &[0x00]); + + let read_bool = assert_success!(i_prot.read_bool()); + assert_eq!(read_bool, false); + } + + #[test] + fn must_allow_any_non_zero_value_to_be_interpreted_as_bool_true() { + let (mut i_prot, _) = test_objects(true); + + set_readable_bytes!(i_prot, &[0xAC]); + + let read_bool = assert_success!(i_prot.read_bool()); + assert_eq!(read_bool, true); + } + + #[test] + fn must_write_bytes() { + let (_, mut o_prot) = test_objects(true); + + let bytes: [u8; 10] = [0x0A, 0xCC, 0xD1, 0x84, 0x99, 0x12, 0xAB, 0xBB, 0x45, 0xDF]; + + assert!(o_prot.write_bytes(&bytes).is_ok()); + + let buf = o_prot.transport.write_bytes(); + assert_eq!(&buf[0..4], [0x00, 0x00, 0x00, 0x0A]); // length + assert_eq!(&buf[4..], bytes); // actual bytes + } + + #[test] + fn must_round_trip_bytes() { + let (mut i_prot, mut o_prot) = test_objects(true); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let bytes: [u8; 25] = [ + 0x20, + 0xFD, + 0x18, + 0x84, + 0x99, + 0x12, + 0xAB, + 0xBB, + 0x45, + 0xDF, + 0x34, + 0xDC, + 0x98, + 0xA4, + 0x6D, + 0xF3, + 0x99, + 0xB4, + 0xB7, + 0xD4, + 0x9C, + 0xA5, + 0xB3, + 0xC9, + 0x88, + ]; + + assert!(o_prot.write_bytes(&bytes).is_ok()); + + copy_write_buffer_to_read_buffer!(o_prot); + + let received_bytes = assert_success!(i_prot.read_bytes()); + assert_eq!(&received_bytes, &bytes); + } + + fn test_objects( + strict: bool, + ) -> ( + TBinaryInputProtocol<ReadHalf<TBufferChannel>>, + TBinaryOutputProtocol<WriteHalf<TBufferChannel>>, + ) { + let mem = TBufferChannel::with_capacity(40, 40); + + let (r_mem, w_mem) = mem.split().unwrap(); + + let i_prot = TBinaryInputProtocol::new(r_mem, strict); + let o_prot = TBinaryOutputProtocol::new(w_mem, strict); + + (i_prot, o_prot) + } + + fn assert_no_write<F>(mut write_fn: F, strict: bool) + where + F: FnMut(&mut TBinaryOutputProtocol<WriteHalf<TBufferChannel>>) -> ::Result<()>, + { + let (_, mut o_prot) = test_objects(strict); + assert!(write_fn(&mut o_prot).is_ok()); + assert_eq!(o_prot.transport.write_bytes().len(), 0); + } +} diff --git a/src/jaegertracing/thrift/lib/rs/src/protocol/compact.rs b/src/jaegertracing/thrift/lib/rs/src/protocol/compact.rs new file mode 100644 index 000000000..1750bc42e --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/src/protocol/compact.rs @@ -0,0 +1,2385 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use integer_encoding::{VarIntReader, VarIntWriter}; +use std::convert::{From, TryFrom}; +use std::io; + +use super::{ + TFieldIdentifier, TInputProtocol, TInputProtocolFactory, TListIdentifier, TMapIdentifier, + TMessageIdentifier, TMessageType, +}; +use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType}; +use transport::{TReadTransport, TWriteTransport}; + +const COMPACT_PROTOCOL_ID: u8 = 0x82; +const COMPACT_VERSION: u8 = 0x01; +const COMPACT_VERSION_MASK: u8 = 0x1F; + +/// Read messages encoded in the Thrift compact protocol. +/// +/// # Examples +/// +/// Create and use a `TCompactInputProtocol`. +/// +/// ```no_run +/// use thrift::protocol::{TCompactInputProtocol, TInputProtocol}; +/// use thrift::transport::TTcpChannel; +/// +/// let mut channel = TTcpChannel::new(); +/// channel.open("localhost:9090").unwrap(); +/// +/// let mut protocol = TCompactInputProtocol::new(channel); +/// +/// let recvd_bool = protocol.read_bool().unwrap(); +/// let recvd_string = protocol.read_string().unwrap(); +/// ``` +#[derive(Debug)] +pub struct TCompactInputProtocol<T> +where + T: TReadTransport, +{ + // Identifier of the last field deserialized for a struct. + last_read_field_id: i16, + // Stack of the last read field ids (a new entry is added each time a nested struct is read). + read_field_id_stack: Vec<i16>, + // Boolean value for a field. + // Saved because boolean fields and their value are encoded in a single byte, + // and reading the field only occurs after the field id is read. + pending_read_bool_value: Option<bool>, + // Underlying transport used for byte-level operations. + transport: T, +} + +impl<T> TCompactInputProtocol<T> +where + T: TReadTransport, +{ + /// Create a `TCompactInputProtocol` that reads bytes from `transport`. + pub fn new(transport: T) -> TCompactInputProtocol<T> { + TCompactInputProtocol { + last_read_field_id: 0, + read_field_id_stack: Vec::new(), + pending_read_bool_value: None, + transport: transport, + } + } + + fn read_list_set_begin(&mut self) -> ::Result<(TType, i32)> { + let header = self.read_byte()?; + let element_type = collection_u8_to_type(header & 0x0F)?; + + let element_count; + let possible_element_count = (header & 0xF0) >> 4; + if possible_element_count != 15 { + // high bits set high if count and type encoded separately + element_count = possible_element_count as i32; + } else { + element_count = self.transport.read_varint::<u32>()? as i32; + } + + Ok((element_type, element_count)) + } +} + +impl<T> TInputProtocol for TCompactInputProtocol<T> +where + T: TReadTransport, +{ + fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> { + let compact_id = self.read_byte()?; + if compact_id != COMPACT_PROTOCOL_ID { + Err(::Error::Protocol(::ProtocolError { + kind: ::ProtocolErrorKind::BadVersion, + message: format!("invalid compact protocol header {:?}", compact_id), + })) + } else { + Ok(()) + }?; + + let type_and_byte = self.read_byte()?; + let received_version = type_and_byte & COMPACT_VERSION_MASK; + if received_version != COMPACT_VERSION { + Err(::Error::Protocol(::ProtocolError { + kind: ::ProtocolErrorKind::BadVersion, + message: format!( + "cannot process compact protocol version {:?}", + received_version + ), + })) + } else { + Ok(()) + }?; + + // NOTE: unsigned right shift will pad with 0s + let message_type: TMessageType = TMessageType::try_from(type_and_byte >> 5)?; + let sequence_number = self.read_i32()?; + let service_call_name = self.read_string()?; + + self.last_read_field_id = 0; + + Ok(TMessageIdentifier::new( + service_call_name, + message_type, + sequence_number, + )) + } + + fn read_message_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn read_struct_begin(&mut self) -> ::Result<Option<TStructIdentifier>> { + self.read_field_id_stack.push(self.last_read_field_id); + self.last_read_field_id = 0; + Ok(None) + } + + fn read_struct_end(&mut self) -> ::Result<()> { + self.last_read_field_id = self + .read_field_id_stack + .pop() + .expect("should have previous field ids"); + Ok(()) + } + + fn read_field_begin(&mut self) -> ::Result<TFieldIdentifier> { + // we can read at least one byte, which is: + // - the type + // - the field delta and the type + let field_type = self.read_byte()?; + let field_delta = (field_type & 0xF0) >> 4; + let field_type = match field_type & 0x0F { + 0x01 => { + self.pending_read_bool_value = Some(true); + Ok(TType::Bool) + } + 0x02 => { + self.pending_read_bool_value = Some(false); + Ok(TType::Bool) + } + ttu8 => u8_to_type(ttu8), + }?; + + match field_type { + TType::Stop => Ok( + TFieldIdentifier::new::<Option<String>, String, Option<i16>>( + None, + TType::Stop, + None, + ), + ), + _ => { + if field_delta != 0 { + self.last_read_field_id += field_delta as i16; + } else { + self.last_read_field_id = self.read_i16()?; + }; + + Ok(TFieldIdentifier { + name: None, + field_type: field_type, + id: Some(self.last_read_field_id), + }) + } + } + } + + fn read_field_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn read_bool(&mut self) -> ::Result<bool> { + match self.pending_read_bool_value.take() { + Some(b) => Ok(b), + None => { + let b = self.read_byte()?; + match b { + 0x01 => Ok(true), + 0x02 => Ok(false), + unkn => Err(::Error::Protocol(::ProtocolError { + kind: ::ProtocolErrorKind::InvalidData, + message: format!("cannot convert {} into bool", unkn), + })), + } + } + } + } + + fn read_bytes(&mut self) -> ::Result<Vec<u8>> { + let len = self.transport.read_varint::<u32>()?; + let mut buf = vec![0u8; len as usize]; + self.transport + .read_exact(&mut buf) + .map_err(From::from) + .map(|_| buf) + } + + fn read_i8(&mut self) -> ::Result<i8> { + self.read_byte().map(|i| i as i8) + } + + fn read_i16(&mut self) -> ::Result<i16> { + self.transport.read_varint::<i16>().map_err(From::from) + } + + fn read_i32(&mut self) -> ::Result<i32> { + self.transport.read_varint::<i32>().map_err(From::from) + } + + fn read_i64(&mut self) -> ::Result<i64> { + self.transport.read_varint::<i64>().map_err(From::from) + } + + fn read_double(&mut self) -> ::Result<f64> { + self.transport.read_f64::<BigEndian>().map_err(From::from) + } + + fn read_string(&mut self) -> ::Result<String> { + let bytes = self.read_bytes()?; + String::from_utf8(bytes).map_err(From::from) + } + + fn read_list_begin(&mut self) -> ::Result<TListIdentifier> { + let (element_type, element_count) = self.read_list_set_begin()?; + Ok(TListIdentifier::new(element_type, element_count)) + } + + fn read_list_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn read_set_begin(&mut self) -> ::Result<TSetIdentifier> { + let (element_type, element_count) = self.read_list_set_begin()?; + Ok(TSetIdentifier::new(element_type, element_count)) + } + + fn read_set_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn read_map_begin(&mut self) -> ::Result<TMapIdentifier> { + let element_count = self.transport.read_varint::<u32>()? as i32; + if element_count == 0 { + Ok(TMapIdentifier::new(None, None, 0)) + } else { + let type_header = self.read_byte()?; + let key_type = collection_u8_to_type((type_header & 0xF0) >> 4)?; + let val_type = collection_u8_to_type(type_header & 0x0F)?; + Ok(TMapIdentifier::new(key_type, val_type, element_count)) + } + } + + fn read_map_end(&mut self) -> ::Result<()> { + Ok(()) + } + + // utility + // + + fn read_byte(&mut self) -> ::Result<u8> { + let mut buf = [0u8; 1]; + self.transport + .read_exact(&mut buf) + .map_err(From::from) + .map(|_| buf[0]) + } +} + +impl<T> io::Seek for TCompactInputProtocol<T> +where + T: io::Seek + TReadTransport, +{ + fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> { + self.transport.seek(pos) + } +} + +/// Factory for creating instances of `TCompactInputProtocol`. +#[derive(Default)] +pub struct TCompactInputProtocolFactory; + +impl TCompactInputProtocolFactory { + /// Create a `TCompactInputProtocolFactory`. + pub fn new() -> TCompactInputProtocolFactory { + TCompactInputProtocolFactory {} + } +} + +impl TInputProtocolFactory for TCompactInputProtocolFactory { + fn create(&self, transport: Box<dyn TReadTransport + Send>) -> Box<dyn TInputProtocol + Send> { + Box::new(TCompactInputProtocol::new(transport)) + } +} + +/// Write messages using the Thrift compact protocol. +/// +/// # Examples +/// +/// Create and use a `TCompactOutputProtocol`. +/// +/// ```no_run +/// use thrift::protocol::{TCompactOutputProtocol, TOutputProtocol}; +/// use thrift::transport::TTcpChannel; +/// +/// let mut channel = TTcpChannel::new(); +/// channel.open("localhost:9090").unwrap(); +/// +/// let mut protocol = TCompactOutputProtocol::new(channel); +/// +/// protocol.write_bool(true).unwrap(); +/// protocol.write_string("test_string").unwrap(); +/// ``` +#[derive(Debug)] +pub struct TCompactOutputProtocol<T> +where + T: TWriteTransport, +{ + // Identifier of the last field serialized for a struct. + last_write_field_id: i16, + // Stack of the last written field ids (new entry added each time a nested struct is written). + write_field_id_stack: Vec<i16>, + // Field identifier of the boolean field to be written. + // Saved because boolean fields and their value are encoded in a single byte + pending_write_bool_field_identifier: Option<TFieldIdentifier>, + // Underlying transport used for byte-level operations. + transport: T, +} + +impl<T> TCompactOutputProtocol<T> +where + T: TWriteTransport, +{ + /// Create a `TCompactOutputProtocol` that writes bytes to `transport`. + pub fn new(transport: T) -> TCompactOutputProtocol<T> { + TCompactOutputProtocol { + last_write_field_id: 0, + write_field_id_stack: Vec::new(), + pending_write_bool_field_identifier: None, + transport: transport, + } + } + + // FIXME: field_type as unconstrained u8 is bad + fn write_field_header(&mut self, field_type: u8, field_id: i16) -> ::Result<()> { + let field_delta = field_id - self.last_write_field_id; + if field_delta > 0 && field_delta < 15 { + self.write_byte(((field_delta as u8) << 4) | field_type)?; + } else { + self.write_byte(field_type)?; + self.write_i16(field_id)?; + } + self.last_write_field_id = field_id; + Ok(()) + } + + fn write_list_set_begin(&mut self, element_type: TType, element_count: i32) -> ::Result<()> { + let elem_identifier = collection_type_to_u8(element_type); + if element_count <= 14 { + let header = (element_count as u8) << 4 | elem_identifier; + self.write_byte(header) + } else { + let header = 0xF0 | elem_identifier; + self.write_byte(header)?; + self.transport + .write_varint(element_count as u32) + .map_err(From::from) + .map(|_| ()) + } + } + + fn assert_no_pending_bool_write(&self) { + if let Some(ref f) = self.pending_write_bool_field_identifier { + panic!("pending bool field {:?} not written", f) + } + } +} + +impl<T> TOutputProtocol for TCompactOutputProtocol<T> +where + T: TWriteTransport, +{ + fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> { + self.write_byte(COMPACT_PROTOCOL_ID)?; + self.write_byte((u8::from(identifier.message_type) << 5) | COMPACT_VERSION)?; + self.write_i32(identifier.sequence_number)?; + self.write_string(&identifier.name)?; + Ok(()) + } + + fn write_message_end(&mut self) -> ::Result<()> { + self.assert_no_pending_bool_write(); + Ok(()) + } + + fn write_struct_begin(&mut self, _: &TStructIdentifier) -> ::Result<()> { + self.write_field_id_stack.push(self.last_write_field_id); + self.last_write_field_id = 0; + Ok(()) + } + + fn write_struct_end(&mut self) -> ::Result<()> { + self.assert_no_pending_bool_write(); + self.last_write_field_id = self + .write_field_id_stack + .pop() + .expect("should have previous field ids"); + Ok(()) + } + + fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> ::Result<()> { + match identifier.field_type { + TType::Bool => { + if self.pending_write_bool_field_identifier.is_some() { + panic!( + "should not have a pending bool while writing another bool with id: \ + {:?}", + identifier + ) + } + self.pending_write_bool_field_identifier = Some(identifier.clone()); + Ok(()) + } + _ => { + let field_type = type_to_u8(identifier.field_type); + let field_id = identifier.id.expect("non-stop field should have field id"); + self.write_field_header(field_type, field_id) + } + } + } + + fn write_field_end(&mut self) -> ::Result<()> { + self.assert_no_pending_bool_write(); + Ok(()) + } + + fn write_field_stop(&mut self) -> ::Result<()> { + self.assert_no_pending_bool_write(); + self.write_byte(type_to_u8(TType::Stop)) + } + + fn write_bool(&mut self, b: bool) -> ::Result<()> { + match self.pending_write_bool_field_identifier.take() { + Some(pending) => { + let field_id = pending.id.expect("bool field should have a field id"); + let field_type_as_u8 = if b { 0x01 } else { 0x02 }; + self.write_field_header(field_type_as_u8, field_id) + } + None => { + if b { + self.write_byte(0x01) + } else { + self.write_byte(0x02) + } + } + } + } + + fn write_bytes(&mut self, b: &[u8]) -> ::Result<()> { + self.transport.write_varint(b.len() as u32)?; + self.transport.write_all(b).map_err(From::from) + } + + fn write_i8(&mut self, i: i8) -> ::Result<()> { + self.write_byte(i as u8) + } + + fn write_i16(&mut self, i: i16) -> ::Result<()> { + self.transport + .write_varint(i) + .map_err(From::from) + .map(|_| ()) + } + + fn write_i32(&mut self, i: i32) -> ::Result<()> { + self.transport + .write_varint(i) + .map_err(From::from) + .map(|_| ()) + } + + fn write_i64(&mut self, i: i64) -> ::Result<()> { + self.transport + .write_varint(i) + .map_err(From::from) + .map(|_| ()) + } + + fn write_double(&mut self, d: f64) -> ::Result<()> { + self.transport.write_f64::<BigEndian>(d).map_err(From::from) + } + + fn write_string(&mut self, s: &str) -> ::Result<()> { + self.write_bytes(s.as_bytes()) + } + + fn write_list_begin(&mut self, identifier: &TListIdentifier) -> ::Result<()> { + self.write_list_set_begin(identifier.element_type, identifier.size) + } + + fn write_list_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> ::Result<()> { + self.write_list_set_begin(identifier.element_type, identifier.size) + } + + fn write_set_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> ::Result<()> { + if identifier.size == 0 { + self.write_byte(0) + } else { + self.transport.write_varint(identifier.size as u32)?; + + let key_type = identifier + .key_type + .expect("map identifier to write should contain key type"); + let key_type_byte = collection_type_to_u8(key_type) << 4; + + let val_type = identifier + .value_type + .expect("map identifier to write should contain value type"); + let val_type_byte = collection_type_to_u8(val_type); + + let map_type_header = key_type_byte | val_type_byte; + self.write_byte(map_type_header) + } + } + + fn write_map_end(&mut self) -> ::Result<()> { + Ok(()) + } + + fn flush(&mut self) -> ::Result<()> { + self.transport.flush().map_err(From::from) + } + + // utility + // + + fn write_byte(&mut self, b: u8) -> ::Result<()> { + self.transport.write(&[b]).map_err(From::from).map(|_| ()) + } +} + +/// Factory for creating instances of `TCompactOutputProtocol`. +#[derive(Default)] +pub struct TCompactOutputProtocolFactory; + +impl TCompactOutputProtocolFactory { + /// Create a `TCompactOutputProtocolFactory`. + pub fn new() -> TCompactOutputProtocolFactory { + TCompactOutputProtocolFactory {} + } +} + +impl TOutputProtocolFactory for TCompactOutputProtocolFactory { + fn create(&self, transport: Box<dyn TWriteTransport + Send>) -> Box<dyn TOutputProtocol + Send> { + Box::new(TCompactOutputProtocol::new(transport)) + } +} + +fn collection_type_to_u8(field_type: TType) -> u8 { + match field_type { + TType::Bool => 0x01, + f => type_to_u8(f), + } +} + +fn type_to_u8(field_type: TType) -> u8 { + match field_type { + TType::Stop => 0x00, + TType::I08 => 0x03, // equivalent to TType::Byte + TType::I16 => 0x04, + TType::I32 => 0x05, + TType::I64 => 0x06, + TType::Double => 0x07, + TType::String => 0x08, + TType::List => 0x09, + TType::Set => 0x0A, + TType::Map => 0x0B, + TType::Struct => 0x0C, + _ => panic!(format!( + "should not have attempted to convert {} to u8", + field_type + )), + } +} + +fn collection_u8_to_type(b: u8) -> ::Result<TType> { + match b { + 0x01 => Ok(TType::Bool), + o => u8_to_type(o), + } +} + +fn u8_to_type(b: u8) -> ::Result<TType> { + match b { + 0x00 => Ok(TType::Stop), + 0x03 => Ok(TType::I08), // equivalent to TType::Byte + 0x04 => Ok(TType::I16), + 0x05 => Ok(TType::I32), + 0x06 => Ok(TType::I64), + 0x07 => Ok(TType::Double), + 0x08 => Ok(TType::String), + 0x09 => Ok(TType::List), + 0x0A => Ok(TType::Set), + 0x0B => Ok(TType::Map), + 0x0C => Ok(TType::Struct), + unkn => Err(::Error::Protocol(::ProtocolError { + kind: ::ProtocolErrorKind::InvalidData, + message: format!("cannot convert {} into TType", unkn), + })), + } +} + +#[cfg(test)] +mod tests { + + use protocol::{ + TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier, + TMessageType, TOutputProtocol, TSetIdentifier, TStructIdentifier, TType, + }; + use transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf}; + + use super::*; + + #[test] + fn must_write_message_begin_0() { + let (_, mut o_prot) = test_objects(); + + assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new( + "foo", + TMessageType::Call, + 431 + ))); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let expected: [u8; 8] = [ + 0x82, /* protocol ID */ + 0x21, /* message type | protocol version */ + 0xDE, + 0x06, /* zig-zag varint sequence number */ + 0x03, /* message-name length */ + 0x66, + 0x6F, + 0x6F /* "foo" */, + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_write_message_begin_1() { + let (_, mut o_prot) = test_objects(); + + assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new( + "bar", + TMessageType::Reply, + 991828 + ))); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let expected: [u8; 9] = [ + 0x82, /* protocol ID */ + 0x41, /* message type | protocol version */ + 0xA8, + 0x89, + 0x79, /* zig-zag varint sequence number */ + 0x03, /* message-name length */ + 0x62, + 0x61, + 0x72 /* "bar" */, + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_message_begin() { + let (mut i_prot, mut o_prot) = test_objects(); + + let ident = TMessageIdentifier::new("service_call", TMessageType::Call, 1283948); + + assert_success!(o_prot.write_message_begin(&ident)); + + copy_write_buffer_to_read_buffer!(o_prot); + + let res = assert_success!(i_prot.read_message_begin()); + assert_eq!(&res, &ident); + } + + #[test] + fn must_write_message_end() { + assert_no_write(|o| o.write_message_end()); + } + + // NOTE: structs and fields are tested together + // + + #[test] + fn must_write_struct_with_delta_fields() { + let (_, mut o_prot) = test_objects(); + + // no bytes should be written however + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // write three fields with tiny field ids + // since they're small the field ids will be encoded as deltas + + // since this is the first field (and it's zero) it gets the full varint write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I08, 0))); + assert_success!(o_prot.write_field_end()); + + // since this delta > 0 and < 15 it can be encoded as a delta + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I16, 4))); + assert_success!(o_prot.write_field_end()); + + // since this delta > 0 and < 15 it can be encoded as a delta + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::List, 9))); + assert_success!(o_prot.write_field_end()); + + // now, finish the struct off + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let expected: [u8; 5] = [ + 0x03, /* field type */ + 0x00, /* first field id */ + 0x44, /* field delta (4) | field type */ + 0x59, /* field delta (5) | field type */ + 0x00 /* field stop */, + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_struct_with_delta_fields() { + let (mut i_prot, mut o_prot) = test_objects(); + + // no bytes should be written however + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // write three fields with tiny field ids + // since they're small the field ids will be encoded as deltas + + // since this is the first field (and it's zero) it gets the full varint write + let field_ident_1 = TFieldIdentifier::new("foo", TType::I08, 0); + assert_success!(o_prot.write_field_begin(&field_ident_1)); + assert_success!(o_prot.write_field_end()); + + // since this delta > 0 and < 15 it can be encoded as a delta + let field_ident_2 = TFieldIdentifier::new("foo", TType::I16, 4); + assert_success!(o_prot.write_field_begin(&field_ident_2)); + assert_success!(o_prot.write_field_end()); + + // since this delta > 0 and < 15 it can be encoded as a delta + let field_ident_3 = TFieldIdentifier::new("foo", TType::List, 9); + assert_success!(o_prot.write_field_begin(&field_ident_3)); + assert_success!(o_prot.write_field_end()); + + // now, finish the struct off + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + copy_write_buffer_to_read_buffer!(o_prot); + + // read the struct back + assert_success!(i_prot.read_struct_begin()); + + let read_ident_1 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_2 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_3 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_4 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); + + assert_success!(i_prot.read_struct_end()); + } + + #[test] + fn must_write_struct_with_non_zero_initial_field_and_delta_fields() { + let (_, mut o_prot) = test_objects(); + + // no bytes should be written however + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // write three fields with tiny field ids + // since they're small the field ids will be encoded as deltas + + // gets a delta write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I32, 1))); + assert_success!(o_prot.write_field_end()); + + // since this delta > 0 and < 15 it can be encoded as a delta + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Set, 2))); + assert_success!(o_prot.write_field_end()); + + // since this delta > 0 and < 15 it can be encoded as a delta + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::String, 6))); + assert_success!(o_prot.write_field_end()); + + // now, finish the struct off + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let expected: [u8; 4] = [ + 0x15, /* field delta (1) | field type */ + 0x1A, /* field delta (1) | field type */ + 0x48, /* field delta (4) | field type */ + 0x00 /* field stop */, + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_struct_with_non_zero_initial_field_and_delta_fields() { + let (mut i_prot, mut o_prot) = test_objects(); + + // no bytes should be written however + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // write three fields with tiny field ids + // since they're small the field ids will be encoded as deltas + + // gets a delta write + let field_ident_1 = TFieldIdentifier::new("foo", TType::I32, 1); + assert_success!(o_prot.write_field_begin(&field_ident_1)); + assert_success!(o_prot.write_field_end()); + + // since this delta > 0 and < 15 it can be encoded as a delta + let field_ident_2 = TFieldIdentifier::new("foo", TType::Set, 2); + assert_success!(o_prot.write_field_begin(&field_ident_2)); + assert_success!(o_prot.write_field_end()); + + // since this delta > 0 and < 15 it can be encoded as a delta + let field_ident_3 = TFieldIdentifier::new("foo", TType::String, 6); + assert_success!(o_prot.write_field_begin(&field_ident_3)); + assert_success!(o_prot.write_field_end()); + + // now, finish the struct off + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + copy_write_buffer_to_read_buffer!(o_prot); + + // read the struct back + assert_success!(i_prot.read_struct_begin()); + + let read_ident_1 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_2 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_3 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_4 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); + + assert_success!(i_prot.read_struct_end()); + } + + #[test] + fn must_write_struct_with_long_fields() { + let (_, mut o_prot) = test_objects(); + + // no bytes should be written however + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // write three fields with field ids that cannot be encoded as deltas + + // since this is the first field (and it's zero) it gets the full varint write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I32, 0))); + assert_success!(o_prot.write_field_end()); + + // since this delta is > 15 it is encoded as a zig-zag varint + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I64, 16))); + assert_success!(o_prot.write_field_end()); + + // since this delta is > 15 it is encoded as a zig-zag varint + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Set, 99))); + assert_success!(o_prot.write_field_end()); + + // now, finish the struct off + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let expected: [u8; 8] = [ + 0x05, /* field type */ + 0x00, /* first field id */ + 0x06, /* field type */ + 0x20, /* zig-zag varint field id */ + 0x0A, /* field type */ + 0xC6, + 0x01, /* zig-zag varint field id */ + 0x00 /* field stop */, + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_struct_with_long_fields() { + let (mut i_prot, mut o_prot) = test_objects(); + + // no bytes should be written however + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // write three fields with field ids that cannot be encoded as deltas + + // since this is the first field (and it's zero) it gets the full varint write + let field_ident_1 = TFieldIdentifier::new("foo", TType::I32, 0); + assert_success!(o_prot.write_field_begin(&field_ident_1)); + assert_success!(o_prot.write_field_end()); + + // since this delta is > 15 it is encoded as a zig-zag varint + let field_ident_2 = TFieldIdentifier::new("foo", TType::I64, 16); + assert_success!(o_prot.write_field_begin(&field_ident_2)); + assert_success!(o_prot.write_field_end()); + + // since this delta is > 15 it is encoded as a zig-zag varint + let field_ident_3 = TFieldIdentifier::new("foo", TType::Set, 99); + assert_success!(o_prot.write_field_begin(&field_ident_3)); + assert_success!(o_prot.write_field_end()); + + // now, finish the struct off + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + copy_write_buffer_to_read_buffer!(o_prot); + + // read the struct back + assert_success!(i_prot.read_struct_begin()); + + let read_ident_1 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_2 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_3 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_4 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); + + assert_success!(i_prot.read_struct_end()); + } + + #[test] + fn must_write_struct_with_mix_of_long_and_delta_fields() { + let (_, mut o_prot) = test_objects(); + + // no bytes should be written however + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // write three fields with field ids that cannot be encoded as deltas + + // since the delta is > 0 and < 15 it gets a delta write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I64, 1))); + assert_success!(o_prot.write_field_end()); + + // since this delta > 0 and < 15 it gets a delta write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I32, 9))); + assert_success!(o_prot.write_field_end()); + + // since this delta is > 15 it is encoded as a zig-zag varint + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Set, 1000))); + assert_success!(o_prot.write_field_end()); + + // since this delta is > 15 it is encoded as a zig-zag varint + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Set, 2001))); + assert_success!(o_prot.write_field_end()); + + // since this is only 3 up from the previous it is recorded as a delta + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Set, 2004))); + assert_success!(o_prot.write_field_end()); + + // now, finish the struct off + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let expected: [u8; 10] = [ + 0x16, /* field delta (1) | field type */ + 0x85, /* field delta (8) | field type */ + 0x0A, /* field type */ + 0xD0, + 0x0F, /* zig-zag varint field id */ + 0x0A, /* field type */ + 0xA2, + 0x1F, /* zig-zag varint field id */ + 0x3A, /* field delta (3) | field type */ + 0x00 /* field stop */, + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_struct_with_mix_of_long_and_delta_fields() { + let (mut i_prot, mut o_prot) = test_objects(); + + // no bytes should be written however + let struct_ident = TStructIdentifier::new("foo"); + assert_success!(o_prot.write_struct_begin(&struct_ident)); + + // write three fields with field ids that cannot be encoded as deltas + + // since the delta is > 0 and < 15 it gets a delta write + let field_ident_1 = TFieldIdentifier::new("foo", TType::I64, 1); + assert_success!(o_prot.write_field_begin(&field_ident_1)); + assert_success!(o_prot.write_field_end()); + + // since this delta > 0 and < 15 it gets a delta write + let field_ident_2 = TFieldIdentifier::new("foo", TType::I32, 9); + assert_success!(o_prot.write_field_begin(&field_ident_2)); + assert_success!(o_prot.write_field_end()); + + // since this delta is > 15 it is encoded as a zig-zag varint + let field_ident_3 = TFieldIdentifier::new("foo", TType::Set, 1000); + assert_success!(o_prot.write_field_begin(&field_ident_3)); + assert_success!(o_prot.write_field_end()); + + // since this delta is > 15 it is encoded as a zig-zag varint + let field_ident_4 = TFieldIdentifier::new("foo", TType::Set, 2001); + assert_success!(o_prot.write_field_begin(&field_ident_4)); + assert_success!(o_prot.write_field_end()); + + // since this is only 3 up from the previous it is recorded as a delta + let field_ident_5 = TFieldIdentifier::new("foo", TType::Set, 2004); + assert_success!(o_prot.write_field_begin(&field_ident_5)); + assert_success!(o_prot.write_field_end()); + + // now, finish the struct off + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + copy_write_buffer_to_read_buffer!(o_prot); + + // read the struct back + assert_success!(i_prot.read_struct_begin()); + + let read_ident_1 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_2 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_3 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_4 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + ..field_ident_4 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_5 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_5, + TFieldIdentifier { + name: None, + ..field_ident_5 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_6 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_6, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); + + assert_success!(i_prot.read_struct_end()); + } + + #[test] + fn must_write_nested_structs_0() { + // last field of the containing struct is a delta + // first field of the the contained struct is a delta + + let (_, mut o_prot) = test_objects(); + + // start containing struct + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // containing struct + // since the delta is > 0 and < 15 it gets a delta write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I64, 1))); + assert_success!(o_prot.write_field_end()); + + // containing struct + // since this delta > 0 and < 15 it gets a delta write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I32, 9))); + assert_success!(o_prot.write_field_end()); + + // start contained struct + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // contained struct + // since the delta is > 0 and < 15 it gets a delta write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I08, 7))); + assert_success!(o_prot.write_field_end()); + + // contained struct + // since this delta > 15 it gets a full write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Double, 24))); + assert_success!(o_prot.write_field_end()); + + // end contained struct + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + // end containing struct + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let expected: [u8; 7] = [ + 0x16, /* field delta (1) | field type */ + 0x85, /* field delta (8) | field type */ + 0x73, /* field delta (7) | field type */ + 0x07, /* field type */ + 0x30, /* zig-zag varint field id */ + 0x00, /* field stop - contained */ + 0x00 /* field stop - containing */, + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_nested_structs_0() { + // last field of the containing struct is a delta + // first field of the the contained struct is a delta + + let (mut i_prot, mut o_prot) = test_objects(); + + // start containing struct + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // containing struct + // since the delta is > 0 and < 15 it gets a delta write + let field_ident_1 = TFieldIdentifier::new("foo", TType::I64, 1); + assert_success!(o_prot.write_field_begin(&field_ident_1)); + assert_success!(o_prot.write_field_end()); + + // containing struct + // since this delta > 0 and < 15 it gets a delta write + let field_ident_2 = TFieldIdentifier::new("foo", TType::I32, 9); + assert_success!(o_prot.write_field_begin(&field_ident_2)); + assert_success!(o_prot.write_field_end()); + + // start contained struct + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // contained struct + // since the delta is > 0 and < 15 it gets a delta write + let field_ident_3 = TFieldIdentifier::new("foo", TType::I08, 7); + assert_success!(o_prot.write_field_begin(&field_ident_3)); + assert_success!(o_prot.write_field_end()); + + // contained struct + // since this delta > 15 it gets a full write + let field_ident_4 = TFieldIdentifier::new("foo", TType::Double, 24); + assert_success!(o_prot.write_field_begin(&field_ident_4)); + assert_success!(o_prot.write_field_end()); + + // end contained struct + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + // end containing struct + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + copy_write_buffer_to_read_buffer!(o_prot); + + // read containing struct back + assert_success!(i_prot.read_struct_begin()); + + let read_ident_1 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_2 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); + assert_success!(i_prot.read_field_end()); + + // read contained struct back + assert_success!(i_prot.read_struct_begin()); + + let read_ident_3 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_4 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + ..field_ident_4 + } + ); + assert_success!(i_prot.read_field_end()); + + // end contained struct + let read_ident_6 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_6, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); + assert_success!(i_prot.read_struct_end()); + + // end containing struct + let read_ident_7 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_7, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); + assert_success!(i_prot.read_struct_end()); + } + + #[test] + fn must_write_nested_structs_1() { + // last field of the containing struct is a delta + // first field of the the contained struct is a full write + + let (_, mut o_prot) = test_objects(); + + // start containing struct + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // containing struct + // since the delta is > 0 and < 15 it gets a delta write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I64, 1))); + assert_success!(o_prot.write_field_end()); + + // containing struct + // since this delta > 0 and < 15 it gets a delta write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I32, 9))); + assert_success!(o_prot.write_field_end()); + + // start contained struct + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // contained struct + // since this delta > 15 it gets a full write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Double, 24))); + assert_success!(o_prot.write_field_end()); + + // contained struct + // since the delta is > 0 and < 15 it gets a delta write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I08, 27))); + assert_success!(o_prot.write_field_end()); + + // end contained struct + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + // end containing struct + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let expected: [u8; 7] = [ + 0x16, /* field delta (1) | field type */ + 0x85, /* field delta (8) | field type */ + 0x07, /* field type */ + 0x30, /* zig-zag varint field id */ + 0x33, /* field delta (3) | field type */ + 0x00, /* field stop - contained */ + 0x00 /* field stop - containing */, + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_nested_structs_1() { + // last field of the containing struct is a delta + // first field of the the contained struct is a full write + + let (mut i_prot, mut o_prot) = test_objects(); + + // start containing struct + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // containing struct + // since the delta is > 0 and < 15 it gets a delta write + let field_ident_1 = TFieldIdentifier::new("foo", TType::I64, 1); + assert_success!(o_prot.write_field_begin(&field_ident_1)); + assert_success!(o_prot.write_field_end()); + + // containing struct + // since this delta > 0 and < 15 it gets a delta write + let field_ident_2 = TFieldIdentifier::new("foo", TType::I32, 9); + assert_success!(o_prot.write_field_begin(&field_ident_2)); + assert_success!(o_prot.write_field_end()); + + // start contained struct + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // contained struct + // since this delta > 15 it gets a full write + let field_ident_3 = TFieldIdentifier::new("foo", TType::Double, 24); + assert_success!(o_prot.write_field_begin(&field_ident_3)); + assert_success!(o_prot.write_field_end()); + + // contained struct + // since the delta is > 0 and < 15 it gets a delta write + let field_ident_4 = TFieldIdentifier::new("foo", TType::I08, 27); + assert_success!(o_prot.write_field_begin(&field_ident_4)); + assert_success!(o_prot.write_field_end()); + + // end contained struct + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + // end containing struct + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + copy_write_buffer_to_read_buffer!(o_prot); + + // read containing struct back + assert_success!(i_prot.read_struct_begin()); + + let read_ident_1 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_2 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); + assert_success!(i_prot.read_field_end()); + + // read contained struct back + assert_success!(i_prot.read_struct_begin()); + + let read_ident_3 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_4 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + ..field_ident_4 + } + ); + assert_success!(i_prot.read_field_end()); + + // end contained struct + let read_ident_6 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_6, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); + assert_success!(i_prot.read_struct_end()); + + // end containing struct + let read_ident_7 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_7, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); + assert_success!(i_prot.read_struct_end()); + } + + #[test] + fn must_write_nested_structs_2() { + // last field of the containing struct is a full write + // first field of the the contained struct is a delta write + + let (_, mut o_prot) = test_objects(); + + // start containing struct + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // containing struct + // since the delta is > 0 and < 15 it gets a delta write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I64, 1))); + assert_success!(o_prot.write_field_end()); + + // containing struct + // since this delta > 15 it gets a full write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::String, 21))); + assert_success!(o_prot.write_field_end()); + + // start contained struct + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // contained struct + // since this delta > 0 and < 15 it gets a delta write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Double, 7))); + assert_success!(o_prot.write_field_end()); + + // contained struct + // since the delta is > 0 and < 15 it gets a delta write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I08, 10))); + assert_success!(o_prot.write_field_end()); + + // end contained struct + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + // end containing struct + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let expected: [u8; 7] = [ + 0x16, /* field delta (1) | field type */ + 0x08, /* field type */ + 0x2A, /* zig-zag varint field id */ + 0x77, /* field delta(7) | field type */ + 0x33, /* field delta (3) | field type */ + 0x00, /* field stop - contained */ + 0x00 /* field stop - containing */, + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_nested_structs_2() { + let (mut i_prot, mut o_prot) = test_objects(); + + // start containing struct + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // containing struct + // since the delta is > 0 and < 15 it gets a delta write + let field_ident_1 = TFieldIdentifier::new("foo", TType::I64, 1); + assert_success!(o_prot.write_field_begin(&field_ident_1)); + assert_success!(o_prot.write_field_end()); + + // containing struct + // since this delta > 15 it gets a full write + let field_ident_2 = TFieldIdentifier::new("foo", TType::String, 21); + assert_success!(o_prot.write_field_begin(&field_ident_2)); + assert_success!(o_prot.write_field_end()); + + // start contained struct + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // contained struct + // since this delta > 0 and < 15 it gets a delta write + let field_ident_3 = TFieldIdentifier::new("foo", TType::Double, 7); + assert_success!(o_prot.write_field_begin(&field_ident_3)); + assert_success!(o_prot.write_field_end()); + + // contained struct + // since the delta is > 0 and < 15 it gets a delta write + let field_ident_4 = TFieldIdentifier::new("foo", TType::I08, 10); + assert_success!(o_prot.write_field_begin(&field_ident_4)); + assert_success!(o_prot.write_field_end()); + + // end contained struct + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + // end containing struct + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + copy_write_buffer_to_read_buffer!(o_prot); + + // read containing struct back + assert_success!(i_prot.read_struct_begin()); + + let read_ident_1 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_2 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); + assert_success!(i_prot.read_field_end()); + + // read contained struct back + assert_success!(i_prot.read_struct_begin()); + + let read_ident_3 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_4 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + ..field_ident_4 + } + ); + assert_success!(i_prot.read_field_end()); + + // end contained struct + let read_ident_6 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_6, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); + assert_success!(i_prot.read_struct_end()); + + // end containing struct + let read_ident_7 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_7, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); + assert_success!(i_prot.read_struct_end()); + } + + #[test] + fn must_write_nested_structs_3() { + // last field of the containing struct is a full write + // first field of the the contained struct is a full write + + let (_, mut o_prot) = test_objects(); + + // start containing struct + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // containing struct + // since the delta is > 0 and < 15 it gets a delta write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I64, 1))); + assert_success!(o_prot.write_field_end()); + + // containing struct + // since this delta > 15 it gets a full write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::String, 21))); + assert_success!(o_prot.write_field_end()); + + // start contained struct + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // contained struct + // since this delta > 15 it gets a full write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Double, 21))); + assert_success!(o_prot.write_field_end()); + + // contained struct + // since the delta is > 0 and < 15 it gets a delta write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I08, 27))); + assert_success!(o_prot.write_field_end()); + + // end contained struct + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + // end containing struct + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let expected: [u8; 8] = [ + 0x16, /* field delta (1) | field type */ + 0x08, /* field type */ + 0x2A, /* zig-zag varint field id */ + 0x07, /* field type */ + 0x2A, /* zig-zag varint field id */ + 0x63, /* field delta (6) | field type */ + 0x00, /* field stop - contained */ + 0x00 /* field stop - containing */, + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_nested_structs_3() { + // last field of the containing struct is a full write + // first field of the the contained struct is a full write + + let (mut i_prot, mut o_prot) = test_objects(); + + // start containing struct + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // containing struct + // since the delta is > 0 and < 15 it gets a delta write + let field_ident_1 = TFieldIdentifier::new("foo", TType::I64, 1); + assert_success!(o_prot.write_field_begin(&field_ident_1)); + assert_success!(o_prot.write_field_end()); + + // containing struct + // since this delta > 15 it gets a full write + let field_ident_2 = TFieldIdentifier::new("foo", TType::String, 21); + assert_success!(o_prot.write_field_begin(&field_ident_2)); + assert_success!(o_prot.write_field_end()); + + // start contained struct + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // contained struct + // since this delta > 15 it gets a full write + let field_ident_3 = TFieldIdentifier::new("foo", TType::Double, 21); + assert_success!(o_prot.write_field_begin(&field_ident_3)); + assert_success!(o_prot.write_field_end()); + + // contained struct + // since the delta is > 0 and < 15 it gets a delta write + let field_ident_4 = TFieldIdentifier::new("foo", TType::I08, 27); + assert_success!(o_prot.write_field_begin(&field_ident_4)); + assert_success!(o_prot.write_field_end()); + + // end contained struct + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + // end containing struct + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + copy_write_buffer_to_read_buffer!(o_prot); + + // read containing struct back + assert_success!(i_prot.read_struct_begin()); + + let read_ident_1 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_2 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); + assert_success!(i_prot.read_field_end()); + + // read contained struct back + assert_success!(i_prot.read_struct_begin()); + + let read_ident_3 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); + assert_success!(i_prot.read_field_end()); + + let read_ident_4 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + ..field_ident_4 + } + ); + assert_success!(i_prot.read_field_end()); + + // end contained struct + let read_ident_6 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_6, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); + assert_success!(i_prot.read_struct_end()); + + // end containing struct + let read_ident_7 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_7, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); + assert_success!(i_prot.read_struct_end()); + } + + #[test] + fn must_write_bool_field() { + let (_, mut o_prot) = test_objects(); + + // no bytes should be written however + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + + // write three fields with field ids that cannot be encoded as deltas + + // since the delta is > 0 and < 16 it gets a delta write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 1))); + assert_success!(o_prot.write_bool(true)); + assert_success!(o_prot.write_field_end()); + + // since this delta > 0 and < 15 it gets a delta write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 9))); + assert_success!(o_prot.write_bool(false)); + assert_success!(o_prot.write_field_end()); + + // since this delta > 15 it gets a full write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 26))); + assert_success!(o_prot.write_bool(true)); + assert_success!(o_prot.write_field_end()); + + // since this delta > 15 it gets a full write + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 45))); + assert_success!(o_prot.write_bool(false)); + assert_success!(o_prot.write_field_end()); + + // now, finish the struct off + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let expected: [u8; 7] = [ + 0x11, /* field delta (1) | true */ + 0x82, /* field delta (8) | false */ + 0x01, /* true */ + 0x34, /* field id */ + 0x02, /* false */ + 0x5A, /* field id */ + 0x00 /* stop field */, + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_bool_field() { + let (mut i_prot, mut o_prot) = test_objects(); + + // no bytes should be written however + let struct_ident = TStructIdentifier::new("foo"); + assert_success!(o_prot.write_struct_begin(&struct_ident)); + + // write two fields + + // since the delta is > 0 and < 16 it gets a delta write + let field_ident_1 = TFieldIdentifier::new("foo", TType::Bool, 1); + assert_success!(o_prot.write_field_begin(&field_ident_1)); + assert_success!(o_prot.write_bool(true)); + assert_success!(o_prot.write_field_end()); + + // since this delta > 0 and < 15 it gets a delta write + let field_ident_2 = TFieldIdentifier::new("foo", TType::Bool, 9); + assert_success!(o_prot.write_field_begin(&field_ident_2)); + assert_success!(o_prot.write_bool(false)); + assert_success!(o_prot.write_field_end()); + + // since this delta > 15 it gets a full write + let field_ident_3 = TFieldIdentifier::new("foo", TType::Bool, 26); + assert_success!(o_prot.write_field_begin(&field_ident_3)); + assert_success!(o_prot.write_bool(true)); + assert_success!(o_prot.write_field_end()); + + // since this delta > 15 it gets a full write + let field_ident_4 = TFieldIdentifier::new("foo", TType::Bool, 45); + assert_success!(o_prot.write_field_begin(&field_ident_4)); + assert_success!(o_prot.write_bool(false)); + assert_success!(o_prot.write_field_end()); + + // now, finish the struct off + assert_success!(o_prot.write_field_stop()); + assert_success!(o_prot.write_struct_end()); + + copy_write_buffer_to_read_buffer!(o_prot); + + // read the struct back + assert_success!(i_prot.read_struct_begin()); + + let read_ident_1 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_1, + TFieldIdentifier { + name: None, + ..field_ident_1 + } + ); + let read_value_1 = assert_success!(i_prot.read_bool()); + assert_eq!(read_value_1, true); + assert_success!(i_prot.read_field_end()); + + let read_ident_2 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_2, + TFieldIdentifier { + name: None, + ..field_ident_2 + } + ); + let read_value_2 = assert_success!(i_prot.read_bool()); + assert_eq!(read_value_2, false); + assert_success!(i_prot.read_field_end()); + + let read_ident_3 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_3, + TFieldIdentifier { + name: None, + ..field_ident_3 + } + ); + let read_value_3 = assert_success!(i_prot.read_bool()); + assert_eq!(read_value_3, true); + assert_success!(i_prot.read_field_end()); + + let read_ident_4 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_4, + TFieldIdentifier { + name: None, + ..field_ident_4 + } + ); + let read_value_4 = assert_success!(i_prot.read_bool()); + assert_eq!(read_value_4, false); + assert_success!(i_prot.read_field_end()); + + let read_ident_5 = assert_success!(i_prot.read_field_begin()); + assert_eq!( + read_ident_5, + TFieldIdentifier { + name: None, + field_type: TType::Stop, + id: None, + } + ); + + assert_success!(i_prot.read_struct_end()); + } + + #[test] + #[should_panic] + fn must_fail_if_write_field_end_without_writing_bool_value() { + let (_, mut o_prot) = test_objects(); + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 1))); + o_prot.write_field_end().unwrap(); + } + + #[test] + #[should_panic] + fn must_fail_if_write_stop_field_without_writing_bool_value() { + let (_, mut o_prot) = test_objects(); + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 1))); + o_prot.write_field_stop().unwrap(); + } + + #[test] + #[should_panic] + fn must_fail_if_write_struct_end_without_writing_bool_value() { + let (_, mut o_prot) = test_objects(); + assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo"))); + assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 1))); + o_prot.write_struct_end().unwrap(); + } + + #[test] + #[should_panic] + fn must_fail_if_write_struct_end_without_any_fields() { + let (_, mut o_prot) = test_objects(); + o_prot.write_struct_end().unwrap(); + } + + #[test] + fn must_write_field_end() { + assert_no_write(|o| o.write_field_end()); + } + + #[test] + fn must_write_small_sized_list_begin() { + let (_, mut o_prot) = test_objects(); + + assert_success!(o_prot.write_list_begin(&TListIdentifier::new(TType::I64, 4))); + + let expected: [u8; 1] = [0x46 /* size | elem_type */]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_small_sized_list_begin() { + let (mut i_prot, mut o_prot) = test_objects(); + + let ident = TListIdentifier::new(TType::I08, 10); + + assert_success!(o_prot.write_list_begin(&ident)); + + copy_write_buffer_to_read_buffer!(o_prot); + + let res = assert_success!(i_prot.read_list_begin()); + assert_eq!(&res, &ident); + } + + #[test] + fn must_write_large_sized_list_begin() { + let (_, mut o_prot) = test_objects(); + + let res = o_prot.write_list_begin(&TListIdentifier::new(TType::List, 9999)); + assert!(res.is_ok()); + + let expected: [u8; 3] = [ + 0xF9, /* 0xF0 | elem_type */ + 0x8F, 0x4E, /* size as varint */ + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_large_sized_list_begin() { + let (mut i_prot, mut o_prot) = test_objects(); + + let ident = TListIdentifier::new(TType::Set, 47381); + + assert_success!(o_prot.write_list_begin(&ident)); + + copy_write_buffer_to_read_buffer!(o_prot); + + let res = assert_success!(i_prot.read_list_begin()); + assert_eq!(&res, &ident); + } + + #[test] + fn must_write_list_end() { + assert_no_write(|o| o.write_list_end()); + } + + #[test] + fn must_write_small_sized_set_begin() { + let (_, mut o_prot) = test_objects(); + + assert_success!(o_prot.write_set_begin(&TSetIdentifier::new(TType::Struct, 2))); + + let expected: [u8; 1] = [0x2C /* size | elem_type */]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_small_sized_set_begin() { + let (mut i_prot, mut o_prot) = test_objects(); + + let ident = TSetIdentifier::new(TType::I16, 7); + + assert_success!(o_prot.write_set_begin(&ident)); + + copy_write_buffer_to_read_buffer!(o_prot); + + let res = assert_success!(i_prot.read_set_begin()); + assert_eq!(&res, &ident); + } + + #[test] + fn must_write_large_sized_set_begin() { + let (_, mut o_prot) = test_objects(); + + assert_success!(o_prot.write_set_begin(&TSetIdentifier::new(TType::Double, 23891))); + + let expected: [u8; 4] = [ + 0xF7, /* 0xF0 | elem_type */ + 0xD3, 0xBA, 0x01, /* size as varint */ + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_large_sized_set_begin() { + let (mut i_prot, mut o_prot) = test_objects(); + + let ident = TSetIdentifier::new(TType::Map, 3928429); + + assert_success!(o_prot.write_set_begin(&ident)); + + copy_write_buffer_to_read_buffer!(o_prot); + + let res = assert_success!(i_prot.read_set_begin()); + assert_eq!(&res, &ident); + } + + #[test] + fn must_write_set_end() { + assert_no_write(|o| o.write_set_end()); + } + + #[test] + fn must_write_zero_sized_map_begin() { + let (_, mut o_prot) = test_objects(); + + assert_success!(o_prot.write_map_begin(&TMapIdentifier::new(TType::String, TType::I32, 0))); + + let expected: [u8; 1] = [0x00]; // since size is zero we don't write anything + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_read_zero_sized_map_begin() { + let (mut i_prot, mut o_prot) = test_objects(); + + assert_success!(o_prot.write_map_begin(&TMapIdentifier::new(TType::Double, TType::I32, 0))); + + copy_write_buffer_to_read_buffer!(o_prot); + + let res = assert_success!(i_prot.read_map_begin()); + assert_eq!( + &res, + &TMapIdentifier { + key_type: None, + value_type: None, + size: 0, + } + ); + } + + #[test] + fn must_write_map_begin() { + let (_, mut o_prot) = test_objects(); + + assert_success!(o_prot.write_map_begin(&TMapIdentifier::new( + TType::Double, + TType::String, + 238 + ))); + + let expected: [u8; 3] = [ + 0xEE, 0x01, /* size as varint */ + 0x78, /* key type | val type */ + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_map_begin() { + let (mut i_prot, mut o_prot) = test_objects(); + + let ident = TMapIdentifier::new(TType::Map, TType::List, 1928349); + + assert_success!(o_prot.write_map_begin(&ident)); + + copy_write_buffer_to_read_buffer!(o_prot); + + let res = assert_success!(i_prot.read_map_begin()); + assert_eq!(&res, &ident); + } + + #[test] + fn must_write_map_end() { + assert_no_write(|o| o.write_map_end()); + } + + #[test] + fn must_write_map_with_bool_key_and_value() { + let (_, mut o_prot) = test_objects(); + + assert_success!(o_prot.write_map_begin(&TMapIdentifier::new(TType::Bool, TType::Bool, 1))); + assert_success!(o_prot.write_bool(true)); + assert_success!(o_prot.write_bool(false)); + assert_success!(o_prot.write_map_end()); + + let expected: [u8; 4] = [ + 0x01, /* size as varint */ + 0x11, /* key type | val type */ + 0x01, /* key: true */ + 0x02, /* val: false */ + ]; + + assert_eq_written_bytes!(o_prot, expected); + } + + #[test] + fn must_round_trip_map_with_bool_value() { + let (mut i_prot, mut o_prot) = test_objects(); + + let map_ident = TMapIdentifier::new(TType::Bool, TType::Bool, 2); + assert_success!(o_prot.write_map_begin(&map_ident)); + assert_success!(o_prot.write_bool(true)); + assert_success!(o_prot.write_bool(false)); + assert_success!(o_prot.write_bool(false)); + assert_success!(o_prot.write_bool(true)); + assert_success!(o_prot.write_map_end()); + + copy_write_buffer_to_read_buffer!(o_prot); + + // map header + let rcvd_ident = assert_success!(i_prot.read_map_begin()); + assert_eq!(&rcvd_ident, &map_ident); + // key 1 + let b = assert_success!(i_prot.read_bool()); + assert_eq!(b, true); + // val 1 + let b = assert_success!(i_prot.read_bool()); + assert_eq!(b, false); + // key 2 + let b = assert_success!(i_prot.read_bool()); + assert_eq!(b, false); + // val 2 + let b = assert_success!(i_prot.read_bool()); + assert_eq!(b, true); + // map end + assert_success!(i_prot.read_map_end()); + } + + #[test] + fn must_read_map_end() { + let (mut i_prot, _) = test_objects(); + assert!(i_prot.read_map_end().is_ok()); // will blow up if we try to read from empty buffer + } + + fn test_objects() -> ( + TCompactInputProtocol<ReadHalf<TBufferChannel>>, + TCompactOutputProtocol<WriteHalf<TBufferChannel>>, + ) { + let mem = TBufferChannel::with_capacity(80, 80); + + let (r_mem, w_mem) = mem.split().unwrap(); + + let i_prot = TCompactInputProtocol::new(r_mem); + let o_prot = TCompactOutputProtocol::new(w_mem); + + (i_prot, o_prot) + } + + fn assert_no_write<F>(mut write_fn: F) + where + F: FnMut(&mut TCompactOutputProtocol<WriteHalf<TBufferChannel>>) -> ::Result<()>, + { + let (_, mut o_prot) = test_objects(); + assert!(write_fn(&mut o_prot).is_ok()); + assert_eq!(o_prot.transport.write_bytes().len(), 0); + } +} diff --git a/src/jaegertracing/thrift/lib/rs/src/protocol/mod.rs b/src/jaegertracing/thrift/lib/rs/src/protocol/mod.rs new file mode 100644 index 000000000..2d8513f2c --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/src/protocol/mod.rs @@ -0,0 +1,968 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Types used to send and receive primitives between a Thrift client and server. +//! +//! # Examples +//! +//! Create and use a `TInputProtocol`. +//! +//! ```no_run +//! use thrift::protocol::{TBinaryInputProtocol, TInputProtocol}; +//! use thrift::transport::TTcpChannel; +//! +//! // create the I/O channel +//! let mut channel = TTcpChannel::new(); +//! channel.open("127.0.0.1:9090").unwrap(); +//! +//! // create the protocol to decode bytes into types +//! let mut protocol = TBinaryInputProtocol::new(channel, true); +//! +//! // read types from the wire +//! let field_identifier = protocol.read_field_begin().unwrap(); +//! let field_contents = protocol.read_string().unwrap(); +//! let field_end = protocol.read_field_end().unwrap(); +//! ``` +//! +//! Create and use a `TOutputProtocol`. +//! +//! ```no_run +//! use thrift::protocol::{TBinaryOutputProtocol, TFieldIdentifier, TOutputProtocol, TType}; +//! use thrift::transport::TTcpChannel; +//! +//! // create the I/O channel +//! let mut channel = TTcpChannel::new(); +//! channel.open("127.0.0.1:9090").unwrap(); +//! +//! // create the protocol to encode types into bytes +//! let mut protocol = TBinaryOutputProtocol::new(channel, true); +//! +//! // write types +//! protocol.write_field_begin(&TFieldIdentifier::new("string_thing", TType::String, 1)).unwrap(); +//! protocol.write_string("foo").unwrap(); +//! protocol.write_field_end().unwrap(); +//! ``` + +use std::convert::{From, TryFrom}; +use std::fmt; +use std::fmt::{Display, Formatter}; + +use transport::{TReadTransport, TWriteTransport}; +use {ProtocolError, ProtocolErrorKind}; + +#[cfg(test)] +macro_rules! assert_eq_written_bytes { + ($o_prot:ident, $expected_bytes:ident) => {{ + assert_eq!($o_prot.transport.write_bytes(), &$expected_bytes); + }}; +} + +// FIXME: should take both read and write +#[cfg(test)] +macro_rules! copy_write_buffer_to_read_buffer { + ($o_prot:ident) => {{ + $o_prot.transport.copy_write_buffer_to_read_buffer(); + }}; +} + +#[cfg(test)] +macro_rules! set_readable_bytes { + ($i_prot:ident, $bytes:expr) => { + $i_prot.transport.set_readable_bytes($bytes); + }; +} + +mod binary; +mod compact; +mod multiplexed; +mod stored; + +pub use self::binary::{ + TBinaryInputProtocol, TBinaryInputProtocolFactory, TBinaryOutputProtocol, + TBinaryOutputProtocolFactory, +}; +pub use self::compact::{ + TCompactInputProtocol, TCompactInputProtocolFactory, TCompactOutputProtocol, + TCompactOutputProtocolFactory, +}; +pub use self::multiplexed::TMultiplexedOutputProtocol; +pub use self::stored::TStoredInputProtocol; + +// Default maximum depth to which `TInputProtocol::skip` will skip a Thrift +// field. A default is necessary because Thrift structs or collections may +// contain nested structs and collections, which could result in indefinite +// recursion. +const MAXIMUM_SKIP_DEPTH: i8 = 64; + +/// Converts a stream of bytes into Thrift identifiers, primitives, +/// containers, or structs. +/// +/// This trait does not deal with higher-level Thrift concepts like structs or +/// exceptions - only with primitives and message or container boundaries. Once +/// bytes are read they are deserialized and an identifier (for example +/// `TMessageIdentifier`) or a primitive is returned. +/// +/// All methods return a `thrift::Result`. If an `Err` is returned the protocol +/// instance and its underlying transport should be terminated. +/// +/// # Examples +/// +/// Create and use a `TInputProtocol` +/// +/// ```no_run +/// use thrift::protocol::{TBinaryInputProtocol, TInputProtocol}; +/// use thrift::transport::TTcpChannel; +/// +/// let mut channel = TTcpChannel::new(); +/// channel.open("127.0.0.1:9090").unwrap(); +/// +/// let mut protocol = TBinaryInputProtocol::new(channel, true); +/// +/// let field_identifier = protocol.read_field_begin().unwrap(); +/// let field_contents = protocol.read_string().unwrap(); +/// let field_end = protocol.read_field_end().unwrap(); +/// ``` +pub trait TInputProtocol { + /// Read the beginning of a Thrift message. + fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier>; + /// Read the end of a Thrift message. + fn read_message_end(&mut self) -> ::Result<()>; + /// Read the beginning of a Thrift struct. + fn read_struct_begin(&mut self) -> ::Result<Option<TStructIdentifier>>; + /// Read the end of a Thrift struct. + fn read_struct_end(&mut self) -> ::Result<()>; + /// Read the beginning of a Thrift struct field. + fn read_field_begin(&mut self) -> ::Result<TFieldIdentifier>; + /// Read the end of a Thrift struct field. + fn read_field_end(&mut self) -> ::Result<()>; + /// Read a bool. + fn read_bool(&mut self) -> ::Result<bool>; + /// Read a fixed-length byte array. + fn read_bytes(&mut self) -> ::Result<Vec<u8>>; + /// Read a word. + fn read_i8(&mut self) -> ::Result<i8>; + /// Read a 16-bit signed integer. + fn read_i16(&mut self) -> ::Result<i16>; + /// Read a 32-bit signed integer. + fn read_i32(&mut self) -> ::Result<i32>; + /// Read a 64-bit signed integer. + fn read_i64(&mut self) -> ::Result<i64>; + /// Read a 64-bit float. + fn read_double(&mut self) -> ::Result<f64>; + /// Read a fixed-length string (not null terminated). + fn read_string(&mut self) -> ::Result<String>; + /// Read the beginning of a list. + fn read_list_begin(&mut self) -> ::Result<TListIdentifier>; + /// Read the end of a list. + fn read_list_end(&mut self) -> ::Result<()>; + /// Read the beginning of a set. + fn read_set_begin(&mut self) -> ::Result<TSetIdentifier>; + /// Read the end of a set. + fn read_set_end(&mut self) -> ::Result<()>; + /// Read the beginning of a map. + fn read_map_begin(&mut self) -> ::Result<TMapIdentifier>; + /// Read the end of a map. + fn read_map_end(&mut self) -> ::Result<()>; + /// Skip a field with type `field_type` recursively until the default + /// maximum skip depth is reached. + fn skip(&mut self, field_type: TType) -> ::Result<()> { + self.skip_till_depth(field_type, MAXIMUM_SKIP_DEPTH) + } + /// Skip a field with type `field_type` recursively up to `depth` levels. + fn skip_till_depth(&mut self, field_type: TType, depth: i8) -> ::Result<()> { + if depth == 0 { + return Err(::Error::Protocol(ProtocolError { + kind: ProtocolErrorKind::DepthLimit, + message: format!("cannot parse past {:?}", field_type), + })); + } + + match field_type { + TType::Bool => self.read_bool().map(|_| ()), + TType::I08 => self.read_i8().map(|_| ()), + TType::I16 => self.read_i16().map(|_| ()), + TType::I32 => self.read_i32().map(|_| ()), + TType::I64 => self.read_i64().map(|_| ()), + TType::Double => self.read_double().map(|_| ()), + TType::String => self.read_string().map(|_| ()), + TType::Struct => { + self.read_struct_begin()?; + loop { + let field_ident = self.read_field_begin()?; + if field_ident.field_type == TType::Stop { + break; + } + self.skip_till_depth(field_ident.field_type, depth - 1)?; + } + self.read_struct_end() + } + TType::List => { + let list_ident = self.read_list_begin()?; + for _ in 0..list_ident.size { + self.skip_till_depth(list_ident.element_type, depth - 1)?; + } + self.read_list_end() + } + TType::Set => { + let set_ident = self.read_set_begin()?; + for _ in 0..set_ident.size { + self.skip_till_depth(set_ident.element_type, depth - 1)?; + } + self.read_set_end() + } + TType::Map => { + let map_ident = self.read_map_begin()?; + for _ in 0..map_ident.size { + let key_type = map_ident + .key_type + .expect("non-zero sized map should contain key type"); + let val_type = map_ident + .value_type + .expect("non-zero sized map should contain value type"); + self.skip_till_depth(key_type, depth - 1)?; + self.skip_till_depth(val_type, depth - 1)?; + } + self.read_map_end() + } + u => Err(::Error::Protocol(ProtocolError { + kind: ProtocolErrorKind::Unknown, + message: format!("cannot skip field type {:?}", &u), + })), + } + } + + // utility (DO NOT USE IN GENERATED CODE!!!!) + // + + /// Read an unsigned byte. + /// + /// This method should **never** be used in generated code. + fn read_byte(&mut self) -> ::Result<u8>; +} + +/// Converts Thrift identifiers, primitives, containers or structs into a +/// stream of bytes. +/// +/// This trait does not deal with higher-level Thrift concepts like structs or +/// exceptions - only with primitives and message or container boundaries. +/// Write methods take an identifier (for example, `TMessageIdentifier`) or a +/// primitive. Any or all of the fields in an identifier may be omitted when +/// writing to the transport. Write methods may even be noops. All of this is +/// transparent to the caller; as long as a matching `TInputProtocol` +/// implementation is used, received messages will be decoded correctly. +/// +/// All methods return a `thrift::Result`. If an `Err` is returned the protocol +/// instance and its underlying transport should be terminated. +/// +/// # Examples +/// +/// Create and use a `TOutputProtocol` +/// +/// ```no_run +/// use thrift::protocol::{TBinaryOutputProtocol, TFieldIdentifier, TOutputProtocol, TType}; +/// use thrift::transport::TTcpChannel; +/// +/// let mut channel = TTcpChannel::new(); +/// channel.open("127.0.0.1:9090").unwrap(); +/// +/// let mut protocol = TBinaryOutputProtocol::new(channel, true); +/// +/// protocol.write_field_begin(&TFieldIdentifier::new("string_thing", TType::String, 1)).unwrap(); +/// protocol.write_string("foo").unwrap(); +/// protocol.write_field_end().unwrap(); +/// ``` +pub trait TOutputProtocol { + /// Write the beginning of a Thrift message. + fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()>; + /// Write the end of a Thrift message. + fn write_message_end(&mut self) -> ::Result<()>; + /// Write the beginning of a Thrift struct. + fn write_struct_begin(&mut self, identifier: &TStructIdentifier) -> ::Result<()>; + /// Write the end of a Thrift struct. + fn write_struct_end(&mut self) -> ::Result<()>; + /// Write the beginning of a Thrift field. + fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> ::Result<()>; + /// Write the end of a Thrift field. + fn write_field_end(&mut self) -> ::Result<()>; + /// Write a STOP field indicating that all the fields in a struct have been + /// written. + fn write_field_stop(&mut self) -> ::Result<()>; + /// Write a bool. + fn write_bool(&mut self, b: bool) -> ::Result<()>; + /// Write a fixed-length byte array. + fn write_bytes(&mut self, b: &[u8]) -> ::Result<()>; + /// Write an 8-bit signed integer. + fn write_i8(&mut self, i: i8) -> ::Result<()>; + /// Write a 16-bit signed integer. + fn write_i16(&mut self, i: i16) -> ::Result<()>; + /// Write a 32-bit signed integer. + fn write_i32(&mut self, i: i32) -> ::Result<()>; + /// Write a 64-bit signed integer. + fn write_i64(&mut self, i: i64) -> ::Result<()>; + /// Write a 64-bit float. + fn write_double(&mut self, d: f64) -> ::Result<()>; + /// Write a fixed-length string. + fn write_string(&mut self, s: &str) -> ::Result<()>; + /// Write the beginning of a list. + fn write_list_begin(&mut self, identifier: &TListIdentifier) -> ::Result<()>; + /// Write the end of a list. + fn write_list_end(&mut self) -> ::Result<()>; + /// Write the beginning of a set. + fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> ::Result<()>; + /// Write the end of a set. + fn write_set_end(&mut self) -> ::Result<()>; + /// Write the beginning of a map. + fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> ::Result<()>; + /// Write the end of a map. + fn write_map_end(&mut self) -> ::Result<()>; + /// Flush buffered bytes to the underlying transport. + fn flush(&mut self) -> ::Result<()>; + + // utility (DO NOT USE IN GENERATED CODE!!!!) + // + + /// Write an unsigned byte. + /// + /// This method should **never** be used in generated code. + fn write_byte(&mut self, b: u8) -> ::Result<()>; // FIXME: REMOVE +} + +impl<P> TInputProtocol for Box<P> +where + P: TInputProtocol + ?Sized, +{ + fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> { + (**self).read_message_begin() + } + + fn read_message_end(&mut self) -> ::Result<()> { + (**self).read_message_end() + } + + fn read_struct_begin(&mut self) -> ::Result<Option<TStructIdentifier>> { + (**self).read_struct_begin() + } + + fn read_struct_end(&mut self) -> ::Result<()> { + (**self).read_struct_end() + } + + fn read_field_begin(&mut self) -> ::Result<TFieldIdentifier> { + (**self).read_field_begin() + } + + fn read_field_end(&mut self) -> ::Result<()> { + (**self).read_field_end() + } + + fn read_bool(&mut self) -> ::Result<bool> { + (**self).read_bool() + } + + fn read_bytes(&mut self) -> ::Result<Vec<u8>> { + (**self).read_bytes() + } + + fn read_i8(&mut self) -> ::Result<i8> { + (**self).read_i8() + } + + fn read_i16(&mut self) -> ::Result<i16> { + (**self).read_i16() + } + + fn read_i32(&mut self) -> ::Result<i32> { + (**self).read_i32() + } + + fn read_i64(&mut self) -> ::Result<i64> { + (**self).read_i64() + } + + fn read_double(&mut self) -> ::Result<f64> { + (**self).read_double() + } + + fn read_string(&mut self) -> ::Result<String> { + (**self).read_string() + } + + fn read_list_begin(&mut self) -> ::Result<TListIdentifier> { + (**self).read_list_begin() + } + + fn read_list_end(&mut self) -> ::Result<()> { + (**self).read_list_end() + } + + fn read_set_begin(&mut self) -> ::Result<TSetIdentifier> { + (**self).read_set_begin() + } + + fn read_set_end(&mut self) -> ::Result<()> { + (**self).read_set_end() + } + + fn read_map_begin(&mut self) -> ::Result<TMapIdentifier> { + (**self).read_map_begin() + } + + fn read_map_end(&mut self) -> ::Result<()> { + (**self).read_map_end() + } + + fn read_byte(&mut self) -> ::Result<u8> { + (**self).read_byte() + } +} + +impl<P> TOutputProtocol for Box<P> +where + P: TOutputProtocol + ?Sized, +{ + fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> { + (**self).write_message_begin(identifier) + } + + fn write_message_end(&mut self) -> ::Result<()> { + (**self).write_message_end() + } + + fn write_struct_begin(&mut self, identifier: &TStructIdentifier) -> ::Result<()> { + (**self).write_struct_begin(identifier) + } + + fn write_struct_end(&mut self) -> ::Result<()> { + (**self).write_struct_end() + } + + fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> ::Result<()> { + (**self).write_field_begin(identifier) + } + + fn write_field_end(&mut self) -> ::Result<()> { + (**self).write_field_end() + } + + fn write_field_stop(&mut self) -> ::Result<()> { + (**self).write_field_stop() + } + + fn write_bool(&mut self, b: bool) -> ::Result<()> { + (**self).write_bool(b) + } + + fn write_bytes(&mut self, b: &[u8]) -> ::Result<()> { + (**self).write_bytes(b) + } + + fn write_i8(&mut self, i: i8) -> ::Result<()> { + (**self).write_i8(i) + } + + fn write_i16(&mut self, i: i16) -> ::Result<()> { + (**self).write_i16(i) + } + + fn write_i32(&mut self, i: i32) -> ::Result<()> { + (**self).write_i32(i) + } + + fn write_i64(&mut self, i: i64) -> ::Result<()> { + (**self).write_i64(i) + } + + fn write_double(&mut self, d: f64) -> ::Result<()> { + (**self).write_double(d) + } + + fn write_string(&mut self, s: &str) -> ::Result<()> { + (**self).write_string(s) + } + + fn write_list_begin(&mut self, identifier: &TListIdentifier) -> ::Result<()> { + (**self).write_list_begin(identifier) + } + + fn write_list_end(&mut self) -> ::Result<()> { + (**self).write_list_end() + } + + fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> ::Result<()> { + (**self).write_set_begin(identifier) + } + + fn write_set_end(&mut self) -> ::Result<()> { + (**self).write_set_end() + } + + fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> ::Result<()> { + (**self).write_map_begin(identifier) + } + + fn write_map_end(&mut self) -> ::Result<()> { + (**self).write_map_end() + } + + fn flush(&mut self) -> ::Result<()> { + (**self).flush() + } + + fn write_byte(&mut self, b: u8) -> ::Result<()> { + (**self).write_byte(b) + } +} + +/// Helper type used by servers to create `TInputProtocol` instances for +/// accepted client connections. +/// +/// # Examples +/// +/// Create a `TInputProtocolFactory` and use it to create a `TInputProtocol`. +/// +/// ```no_run +/// use thrift::protocol::{TBinaryInputProtocolFactory, TInputProtocolFactory}; +/// use thrift::transport::TTcpChannel; +/// +/// let mut channel = TTcpChannel::new(); +/// channel.open("127.0.0.1:9090").unwrap(); +/// +/// let factory = TBinaryInputProtocolFactory::new(); +/// let protocol = factory.create(Box::new(channel)); +/// ``` +pub trait TInputProtocolFactory { + // Create a `TInputProtocol` that reads bytes from `transport`. + fn create(&self, transport: Box<dyn TReadTransport + Send>) -> Box<dyn TInputProtocol + Send>; +} + +impl<T> TInputProtocolFactory for Box<T> +where + T: TInputProtocolFactory + ?Sized, +{ + fn create(&self, transport: Box<dyn TReadTransport + Send>) -> Box<dyn TInputProtocol + Send> { + (**self).create(transport) + } +} + +/// Helper type used by servers to create `TOutputProtocol` instances for +/// accepted client connections. +/// +/// # Examples +/// +/// Create a `TOutputProtocolFactory` and use it to create a `TOutputProtocol`. +/// +/// ```no_run +/// use thrift::protocol::{TBinaryOutputProtocolFactory, TOutputProtocolFactory}; +/// use thrift::transport::TTcpChannel; +/// +/// let mut channel = TTcpChannel::new(); +/// channel.open("127.0.0.1:9090").unwrap(); +/// +/// let factory = TBinaryOutputProtocolFactory::new(); +/// let protocol = factory.create(Box::new(channel)); +/// ``` +pub trait TOutputProtocolFactory { + /// Create a `TOutputProtocol` that writes bytes to `transport`. + fn create(&self, transport: Box<dyn TWriteTransport + Send>) -> Box<dyn TOutputProtocol + Send>; +} + +impl<T> TOutputProtocolFactory for Box<T> +where + T: TOutputProtocolFactory + ?Sized, +{ + fn create(&self, transport: Box<dyn TWriteTransport + Send>) -> Box<dyn TOutputProtocol + Send> { + (**self).create(transport) + } +} + +/// Thrift message identifier. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct TMessageIdentifier { + /// Service call the message is associated with. + pub name: String, + /// Message type. + pub message_type: TMessageType, + /// Ordered sequence number identifying the message. + pub sequence_number: i32, +} + +impl TMessageIdentifier { + /// Create a `TMessageIdentifier` for a Thrift service-call named `name` + /// with message type `message_type` and sequence number `sequence_number`. + pub fn new<S: Into<String>>( + name: S, + message_type: TMessageType, + sequence_number: i32, + ) -> TMessageIdentifier { + TMessageIdentifier { + name: name.into(), + message_type: message_type, + sequence_number: sequence_number, + } + } +} + +/// Thrift struct identifier. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct TStructIdentifier { + /// Name of the encoded Thrift struct. + pub name: String, +} + +impl TStructIdentifier { + /// Create a `TStructIdentifier` for a struct named `name`. + pub fn new<S: Into<String>>(name: S) -> TStructIdentifier { + TStructIdentifier { name: name.into() } + } +} + +/// Thrift field identifier. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct TFieldIdentifier { + /// Name of the Thrift field. + /// + /// `None` if it's not sent over the wire. + pub name: Option<String>, + /// Field type. + /// + /// This may be a primitive, container, or a struct. + pub field_type: TType, + /// Thrift field id. + /// + /// `None` only if `field_type` is `TType::Stop`. + pub id: Option<i16>, +} + +impl TFieldIdentifier { + /// Create a `TFieldIdentifier` for a field named `name` with type + /// `field_type` and field id `id`. + /// + /// `id` should be `None` if `field_type` is `TType::Stop`. + pub fn new<N, S, I>(name: N, field_type: TType, id: I) -> TFieldIdentifier + where + N: Into<Option<S>>, + S: Into<String>, + I: Into<Option<i16>>, + { + TFieldIdentifier { + name: name.into().map(|n| n.into()), + field_type: field_type, + id: id.into(), + } + } +} + +/// Thrift list identifier. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct TListIdentifier { + /// Type of the elements in the list. + pub element_type: TType, + /// Number of elements in the list. + pub size: i32, +} + +impl TListIdentifier { + /// Create a `TListIdentifier` for a list with `size` elements of type + /// `element_type`. + pub fn new(element_type: TType, size: i32) -> TListIdentifier { + TListIdentifier { + element_type: element_type, + size: size, + } + } +} + +/// Thrift set identifier. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct TSetIdentifier { + /// Type of the elements in the set. + pub element_type: TType, + /// Number of elements in the set. + pub size: i32, +} + +impl TSetIdentifier { + /// Create a `TSetIdentifier` for a set with `size` elements of type + /// `element_type`. + pub fn new(element_type: TType, size: i32) -> TSetIdentifier { + TSetIdentifier { + element_type: element_type, + size: size, + } + } +} + +/// Thrift map identifier. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct TMapIdentifier { + /// Map key type. + pub key_type: Option<TType>, + /// Map value type. + pub value_type: Option<TType>, + /// Number of entries in the map. + pub size: i32, +} + +impl TMapIdentifier { + /// Create a `TMapIdentifier` for a map with `size` entries of type + /// `key_type -> value_type`. + pub fn new<K, V>(key_type: K, value_type: V, size: i32) -> TMapIdentifier + where + K: Into<Option<TType>>, + V: Into<Option<TType>>, + { + TMapIdentifier { + key_type: key_type.into(), + value_type: value_type.into(), + size: size, + } + } +} + +/// Thrift message types. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum TMessageType { + /// Service-call request. + Call, + /// Service-call response. + Reply, + /// Unexpected error in the remote service. + Exception, + /// One-way service-call request (no response is expected). + OneWay, +} + +impl Display for TMessageType { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match *self { + TMessageType::Call => write!(f, "Call"), + TMessageType::Reply => write!(f, "Reply"), + TMessageType::Exception => write!(f, "Exception"), + TMessageType::OneWay => write!(f, "OneWay"), + } + } +} + +impl From<TMessageType> for u8 { + fn from(message_type: TMessageType) -> Self { + match message_type { + TMessageType::Call => 0x01, + TMessageType::Reply => 0x02, + TMessageType::Exception => 0x03, + TMessageType::OneWay => 0x04, + } + } +} + +impl TryFrom<u8> for TMessageType { + type Error = ::Error; + fn try_from(b: u8) -> Result<Self, Self::Error> { + match b { + 0x01 => Ok(TMessageType::Call), + 0x02 => Ok(TMessageType::Reply), + 0x03 => Ok(TMessageType::Exception), + 0x04 => Ok(TMessageType::OneWay), + unkn => Err(::Error::Protocol(ProtocolError { + kind: ProtocolErrorKind::InvalidData, + message: format!("cannot convert {} to TMessageType", unkn), + })), + } + } +} + +/// Thrift struct-field types. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum TType { + /// Indicates that there are no more serialized fields in this Thrift struct. + Stop, + /// Void (`()`) field. + Void, + /// Boolean. + Bool, + /// Signed 8-bit int. + I08, + /// Double-precision number. + Double, + /// Signed 16-bit int. + I16, + /// Signed 32-bit int. + I32, + /// Signed 64-bit int. + I64, + /// UTF-8 string. + String, + /// UTF-7 string. *Unsupported*. + Utf7, + /// Thrift struct. + Struct, + /// Map. + Map, + /// Set. + Set, + /// List. + List, + /// UTF-8 string. + Utf8, + /// UTF-16 string. *Unsupported*. + Utf16, +} + +impl Display for TType { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match *self { + TType::Stop => write!(f, "STOP"), + TType::Void => write!(f, "void"), + TType::Bool => write!(f, "bool"), + TType::I08 => write!(f, "i08"), + TType::Double => write!(f, "double"), + TType::I16 => write!(f, "i16"), + TType::I32 => write!(f, "i32"), + TType::I64 => write!(f, "i64"), + TType::String => write!(f, "string"), + TType::Utf7 => write!(f, "UTF7"), + TType::Struct => write!(f, "struct"), + TType::Map => write!(f, "map"), + TType::Set => write!(f, "set"), + TType::List => write!(f, "list"), + TType::Utf8 => write!(f, "UTF8"), + TType::Utf16 => write!(f, "UTF16"), + } + } +} + +/// Compare the expected message sequence number `expected` with the received +/// message sequence number `actual`. +/// +/// Return `()` if `actual == expected`, `Err` otherwise. +pub fn verify_expected_sequence_number(expected: i32, actual: i32) -> ::Result<()> { + if expected == actual { + Ok(()) + } else { + Err(::Error::Application(::ApplicationError { + kind: ::ApplicationErrorKind::BadSequenceId, + message: format!("expected {} got {}", expected, actual), + })) + } +} + +/// Compare the expected service-call name `expected` with the received +/// service-call name `actual`. +/// +/// Return `()` if `actual == expected`, `Err` otherwise. +pub fn verify_expected_service_call(expected: &str, actual: &str) -> ::Result<()> { + if expected == actual { + Ok(()) + } else { + Err(::Error::Application(::ApplicationError { + kind: ::ApplicationErrorKind::WrongMethodName, + message: format!("expected {} got {}", expected, actual), + })) + } +} + +/// Compare the expected message type `expected` with the received message type +/// `actual`. +/// +/// Return `()` if `actual == expected`, `Err` otherwise. +pub fn verify_expected_message_type(expected: TMessageType, actual: TMessageType) -> ::Result<()> { + if expected == actual { + Ok(()) + } else { + Err(::Error::Application(::ApplicationError { + kind: ::ApplicationErrorKind::InvalidMessageType, + message: format!("expected {} got {}", expected, actual), + })) + } +} + +/// Check if a required Thrift struct field exists. +/// +/// Return `()` if it does, `Err` otherwise. +pub fn verify_required_field_exists<T>(field_name: &str, field: &Option<T>) -> ::Result<()> { + match *field { + Some(_) => Ok(()), + None => Err(::Error::Protocol(::ProtocolError { + kind: ::ProtocolErrorKind::Unknown, + message: format!("missing required field {}", field_name), + })), + } +} + +/// Extract the field id from a Thrift field identifier. +/// +/// `field_ident` must *not* have `TFieldIdentifier.field_type` of type `TType::Stop`. +/// +/// Return `TFieldIdentifier.id` if an id exists, `Err` otherwise. +pub fn field_id(field_ident: &TFieldIdentifier) -> ::Result<i16> { + field_ident.id.ok_or_else(|| { + ::Error::Protocol(::ProtocolError { + kind: ::ProtocolErrorKind::Unknown, + message: format!("missing field in in {:?}", field_ident), + }) + }) +} + +#[cfg(test)] +mod tests { + + use std::io::Cursor; + + use super::*; + use transport::{TReadTransport, TWriteTransport}; + + #[test] + fn must_create_usable_input_protocol_from_concrete_input_protocol() { + let r: Box<dyn TReadTransport> = Box::new(Cursor::new([0, 1, 2])); + let mut t = TCompactInputProtocol::new(r); + takes_input_protocol(&mut t) + } + + #[test] + fn must_create_usable_input_protocol_from_boxed_input() { + let r: Box<dyn TReadTransport> = Box::new(Cursor::new([0, 1, 2])); + let mut t: Box<dyn TInputProtocol> = Box::new(TCompactInputProtocol::new(r)); + takes_input_protocol(&mut t) + } + + #[test] + fn must_create_usable_output_protocol_from_concrete_output_protocol() { + let w: Box<dyn TWriteTransport> = Box::new(vec![0u8; 10]); + let mut t = TCompactOutputProtocol::new(w); + takes_output_protocol(&mut t) + } + + #[test] + fn must_create_usable_output_protocol_from_boxed_output() { + let w: Box<dyn TWriteTransport> = Box::new(vec![0u8; 10]); + let mut t: Box<dyn TOutputProtocol> = Box::new(TCompactOutputProtocol::new(w)); + takes_output_protocol(&mut t) + } + + fn takes_input_protocol<R>(t: &mut R) + where + R: TInputProtocol, + { + t.read_byte().unwrap(); + } + + fn takes_output_protocol<W>(t: &mut W) + where + W: TOutputProtocol, + { + t.flush().unwrap(); + } +} diff --git a/src/jaegertracing/thrift/lib/rs/src/protocol/multiplexed.rs b/src/jaegertracing/thrift/lib/rs/src/protocol/multiplexed.rs new file mode 100644 index 000000000..aaee44f73 --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/src/protocol/multiplexed.rs @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::{ + TFieldIdentifier, TListIdentifier, TMapIdentifier, TMessageIdentifier, TMessageType, + TOutputProtocol, TSetIdentifier, TStructIdentifier, +}; + +/// `TOutputProtocol` that prefixes the service name to all outgoing Thrift +/// messages. +/// +/// A `TMultiplexedOutputProtocol` should be used when multiple Thrift services +/// send messages over a single I/O channel. By prefixing service identifiers +/// to outgoing messages receivers are able to demux them and route them to the +/// appropriate service processor. Rust receivers must use a `TMultiplexedProcessor` +/// to process incoming messages, while other languages must use their +/// corresponding multiplexed processor implementations. +/// +/// For example, given a service `TestService` and a service call `test_call`, +/// this implementation would identify messages as originating from +/// `TestService:test_call`. +/// +/// # Examples +/// +/// Create and use a `TMultiplexedOutputProtocol`. +/// +/// ```no_run +/// use thrift::protocol::{TMessageIdentifier, TMessageType, TOutputProtocol}; +/// use thrift::protocol::{TBinaryOutputProtocol, TMultiplexedOutputProtocol}; +/// use thrift::transport::TTcpChannel; +/// +/// let mut channel = TTcpChannel::new(); +/// channel.open("localhost:9090").unwrap(); +/// +/// let protocol = TBinaryOutputProtocol::new(channel, true); +/// let mut protocol = TMultiplexedOutputProtocol::new("service_name", protocol); +/// +/// let ident = TMessageIdentifier::new("svc_call", TMessageType::Call, 1); +/// protocol.write_message_begin(&ident).unwrap(); +/// ``` +#[derive(Debug)] +pub struct TMultiplexedOutputProtocol<P> +where + P: TOutputProtocol, +{ + service_name: String, + inner: P, +} + +impl<P> TMultiplexedOutputProtocol<P> +where + P: TOutputProtocol, +{ + /// Create a `TMultiplexedOutputProtocol` that identifies outgoing messages + /// as originating from a service named `service_name` and sends them over + /// the `wrapped` `TOutputProtocol`. Outgoing messages are encoded and sent + /// by `wrapped`, not by this instance. + pub fn new(service_name: &str, wrapped: P) -> TMultiplexedOutputProtocol<P> { + TMultiplexedOutputProtocol { + service_name: service_name.to_owned(), + inner: wrapped, + } + } +} + +// FIXME: avoid passthrough methods +impl<P> TOutputProtocol for TMultiplexedOutputProtocol<P> +where + P: TOutputProtocol, +{ + fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> { + match identifier.message_type { + // FIXME: is there a better way to override identifier here? + TMessageType::Call | TMessageType::OneWay => { + let identifier = TMessageIdentifier { + name: format!("{}:{}", self.service_name, identifier.name), + ..*identifier + }; + self.inner.write_message_begin(&identifier) + } + _ => self.inner.write_message_begin(identifier), + } + } + + fn write_message_end(&mut self) -> ::Result<()> { + self.inner.write_message_end() + } + + fn write_struct_begin(&mut self, identifier: &TStructIdentifier) -> ::Result<()> { + self.inner.write_struct_begin(identifier) + } + + fn write_struct_end(&mut self) -> ::Result<()> { + self.inner.write_struct_end() + } + + fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> ::Result<()> { + self.inner.write_field_begin(identifier) + } + + fn write_field_end(&mut self) -> ::Result<()> { + self.inner.write_field_end() + } + + fn write_field_stop(&mut self) -> ::Result<()> { + self.inner.write_field_stop() + } + + fn write_bytes(&mut self, b: &[u8]) -> ::Result<()> { + self.inner.write_bytes(b) + } + + fn write_bool(&mut self, b: bool) -> ::Result<()> { + self.inner.write_bool(b) + } + + fn write_i8(&mut self, i: i8) -> ::Result<()> { + self.inner.write_i8(i) + } + + fn write_i16(&mut self, i: i16) -> ::Result<()> { + self.inner.write_i16(i) + } + + fn write_i32(&mut self, i: i32) -> ::Result<()> { + self.inner.write_i32(i) + } + + fn write_i64(&mut self, i: i64) -> ::Result<()> { + self.inner.write_i64(i) + } + + fn write_double(&mut self, d: f64) -> ::Result<()> { + self.inner.write_double(d) + } + + fn write_string(&mut self, s: &str) -> ::Result<()> { + self.inner.write_string(s) + } + + fn write_list_begin(&mut self, identifier: &TListIdentifier) -> ::Result<()> { + self.inner.write_list_begin(identifier) + } + + fn write_list_end(&mut self) -> ::Result<()> { + self.inner.write_list_end() + } + + fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> ::Result<()> { + self.inner.write_set_begin(identifier) + } + + fn write_set_end(&mut self) -> ::Result<()> { + self.inner.write_set_end() + } + + fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> ::Result<()> { + self.inner.write_map_begin(identifier) + } + + fn write_map_end(&mut self) -> ::Result<()> { + self.inner.write_map_end() + } + + fn flush(&mut self) -> ::Result<()> { + self.inner.flush() + } + + // utility + // + + fn write_byte(&mut self, b: u8) -> ::Result<()> { + self.inner.write_byte(b) + } +} + +#[cfg(test)] +mod tests { + + use protocol::{TBinaryOutputProtocol, TMessageIdentifier, TMessageType, TOutputProtocol}; + use transport::{TBufferChannel, TIoChannel, WriteHalf}; + + use super::*; + + #[test] + fn must_write_message_begin_with_prefixed_service_name() { + let mut o_prot = test_objects(); + + let ident = TMessageIdentifier::new("bar", TMessageType::Call, 2); + assert_success!(o_prot.write_message_begin(&ident)); + + #[cfg_attr(rustfmt, rustfmt::skip)] + let expected: [u8; 19] = [ + 0x80, + 0x01, /* protocol identifier */ + 0x00, + 0x01, /* message type */ + 0x00, + 0x00, + 0x00, + 0x07, + 0x66, + 0x6F, + 0x6F, /* "foo" */ + 0x3A, /* ":" */ + 0x62, + 0x61, + 0x72, /* "bar" */ + 0x00, + 0x00, + 0x00, + 0x02 /* sequence number */, + ]; + + assert_eq!(o_prot.inner.transport.write_bytes(), expected); + } + + fn test_objects() -> TMultiplexedOutputProtocol<TBinaryOutputProtocol<WriteHalf<TBufferChannel>>> + { + let c = TBufferChannel::with_capacity(40, 40); + let (_, w_chan) = c.split().unwrap(); + let prot = TBinaryOutputProtocol::new(w_chan, true); + TMultiplexedOutputProtocol::new("foo", prot) + } +} diff --git a/src/jaegertracing/thrift/lib/rs/src/protocol/stored.rs b/src/jaegertracing/thrift/lib/rs/src/protocol/stored.rs new file mode 100644 index 000000000..faa51288e --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/src/protocol/stored.rs @@ -0,0 +1,195 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::convert::Into; + +use super::{ + TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier, + TSetIdentifier, TStructIdentifier, +}; +use ProtocolErrorKind; + +/// `TInputProtocol` required to use a `TMultiplexedProcessor`. +/// +/// A `TMultiplexedProcessor` reads incoming message identifiers to determine to +/// which `TProcessor` requests should be forwarded. However, once read, those +/// message identifier bytes are no longer on the wire. Since downstream +/// processors expect to read message identifiers from the given input protocol +/// we need some way of supplying a `TMessageIdentifier` with the service-name +/// stripped. This implementation stores the received `TMessageIdentifier` +/// (without the service name) and passes it to the wrapped `TInputProtocol` +/// when `TInputProtocol::read_message_begin(...)` is called. It delegates all +/// other calls directly to the wrapped `TInputProtocol`. +/// +/// This type **should not** be used by application code. +/// +/// # Examples +/// +/// Create and use a `TStoredInputProtocol`. +/// +/// ```no_run +/// use thrift::protocol::{TInputProtocol, TMessageIdentifier, TMessageType, TOutputProtocol}; +/// use thrift::protocol::{TBinaryInputProtocol, TBinaryOutputProtocol, TStoredInputProtocol}; +/// use thrift::server::TProcessor; +/// use thrift::transport::{TIoChannel, TTcpChannel}; +/// +/// // sample processor +/// struct ActualProcessor; +/// impl TProcessor for ActualProcessor { +/// fn process( +/// &self, +/// _: &mut TInputProtocol, +/// _: &mut TOutputProtocol +/// ) -> thrift::Result<()> { +/// unimplemented!() +/// } +/// } +/// let processor = ActualProcessor {}; +/// +/// // construct the shared transport +/// let mut channel = TTcpChannel::new(); +/// channel.open("localhost:9090").unwrap(); +/// +/// let (i_chan, o_chan) = channel.split().unwrap(); +/// +/// // construct the actual input and output protocols +/// let mut i_prot = TBinaryInputProtocol::new(i_chan, true); +/// let mut o_prot = TBinaryOutputProtocol::new(o_chan, true); +/// +/// // message identifier received from remote and modified to remove the service name +/// let new_msg_ident = TMessageIdentifier::new("service_call", TMessageType::Call, 1); +/// +/// // construct the proxy input protocol +/// let mut proxy_i_prot = TStoredInputProtocol::new(&mut i_prot, new_msg_ident); +/// let res = processor.process(&mut proxy_i_prot, &mut o_prot); +/// ``` +// FIXME: implement Debug +pub struct TStoredInputProtocol<'a> { + inner: &'a mut dyn TInputProtocol, + message_ident: Option<TMessageIdentifier>, +} + +impl<'a> TStoredInputProtocol<'a> { + /// Create a `TStoredInputProtocol` that delegates all calls other than + /// `TInputProtocol::read_message_begin(...)` to a `wrapped` + /// `TInputProtocol`. `message_ident` is the modified message identifier - + /// with service name stripped - that will be passed to + /// `wrapped.read_message_begin(...)`. + pub fn new( + wrapped: &mut dyn TInputProtocol, + message_ident: TMessageIdentifier, + ) -> TStoredInputProtocol { + TStoredInputProtocol { + inner: wrapped, + message_ident: message_ident.into(), + } + } +} + +impl<'a> TInputProtocol for TStoredInputProtocol<'a> { + fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> { + self.message_ident.take().ok_or_else(|| { + ::errors::new_protocol_error( + ProtocolErrorKind::Unknown, + "message identifier already read", + ) + }) + } + + fn read_message_end(&mut self) -> ::Result<()> { + self.inner.read_message_end() + } + + fn read_struct_begin(&mut self) -> ::Result<Option<TStructIdentifier>> { + self.inner.read_struct_begin() + } + + fn read_struct_end(&mut self) -> ::Result<()> { + self.inner.read_struct_end() + } + + fn read_field_begin(&mut self) -> ::Result<TFieldIdentifier> { + self.inner.read_field_begin() + } + + fn read_field_end(&mut self) -> ::Result<()> { + self.inner.read_field_end() + } + + fn read_bytes(&mut self) -> ::Result<Vec<u8>> { + self.inner.read_bytes() + } + + fn read_bool(&mut self) -> ::Result<bool> { + self.inner.read_bool() + } + + fn read_i8(&mut self) -> ::Result<i8> { + self.inner.read_i8() + } + + fn read_i16(&mut self) -> ::Result<i16> { + self.inner.read_i16() + } + + fn read_i32(&mut self) -> ::Result<i32> { + self.inner.read_i32() + } + + fn read_i64(&mut self) -> ::Result<i64> { + self.inner.read_i64() + } + + fn read_double(&mut self) -> ::Result<f64> { + self.inner.read_double() + } + + fn read_string(&mut self) -> ::Result<String> { + self.inner.read_string() + } + + fn read_list_begin(&mut self) -> ::Result<TListIdentifier> { + self.inner.read_list_begin() + } + + fn read_list_end(&mut self) -> ::Result<()> { + self.inner.read_list_end() + } + + fn read_set_begin(&mut self) -> ::Result<TSetIdentifier> { + self.inner.read_set_begin() + } + + fn read_set_end(&mut self) -> ::Result<()> { + self.inner.read_set_end() + } + + fn read_map_begin(&mut self) -> ::Result<TMapIdentifier> { + self.inner.read_map_begin() + } + + fn read_map_end(&mut self) -> ::Result<()> { + self.inner.read_map_end() + } + + // utility + // + + fn read_byte(&mut self) -> ::Result<u8> { + self.inner.read_byte() + } +} diff --git a/src/jaegertracing/thrift/lib/rs/src/server/mod.rs b/src/jaegertracing/thrift/lib/rs/src/server/mod.rs new file mode 100644 index 000000000..b719d1ba8 --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/src/server/mod.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Types used to implement a Thrift server. + +use protocol::{TInputProtocol, TMessageIdentifier, TMessageType, TOutputProtocol}; +use {ApplicationError, ApplicationErrorKind}; + +mod multiplexed; +mod threaded; + +pub use self::multiplexed::TMultiplexedProcessor; +pub use self::threaded::TServer; + +/// Handles incoming Thrift messages and dispatches them to the user-defined +/// handler functions. +/// +/// An implementation is auto-generated for each Thrift service. When used by a +/// server (for example, a `TSimpleServer`), it will demux incoming service +/// calls and invoke the corresponding user-defined handler function. +/// +/// # Examples +/// +/// Create and start a server using the auto-generated `TProcessor` for +/// a Thrift service `SimpleService`. +/// +/// ```no_run +/// use thrift::protocol::{TInputProtocol, TOutputProtocol}; +/// use thrift::server::TProcessor; +/// +/// // +/// // auto-generated +/// // +/// +/// // processor for `SimpleService` +/// struct SimpleServiceSyncProcessor; +/// impl SimpleServiceSyncProcessor { +/// fn new<H: SimpleServiceSyncHandler>(processor: H) -> SimpleServiceSyncProcessor { +/// unimplemented!(); +/// } +/// } +/// +/// // `TProcessor` implementation for `SimpleService` +/// impl TProcessor for SimpleServiceSyncProcessor { +/// fn process(&self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> thrift::Result<()> { +/// unimplemented!(); +/// } +/// } +/// +/// // service functions for SimpleService +/// trait SimpleServiceSyncHandler { +/// fn service_call(&self) -> thrift::Result<()>; +/// } +/// +/// // +/// // user-code follows +/// // +/// +/// // define a handler that will be invoked when `service_call` is received +/// struct SimpleServiceHandlerImpl; +/// impl SimpleServiceSyncHandler for SimpleServiceHandlerImpl { +/// fn service_call(&self) -> thrift::Result<()> { +/// unimplemented!(); +/// } +/// } +/// +/// // instantiate the processor +/// let processor = SimpleServiceSyncProcessor::new(SimpleServiceHandlerImpl {}); +/// +/// // at this point you can pass the processor to the server +/// // let server = TServer::new(..., processor); +/// ``` +pub trait TProcessor { + /// Process a Thrift service call. + /// + /// Reads arguments from `i`, executes the user's handler code, and writes + /// the response to `o`. + /// + /// Returns `()` if the handler was executed; `Err` otherwise. + fn process(&self, i: &mut dyn TInputProtocol, o: &mut dyn TOutputProtocol) -> ::Result<()>; +} + +/// Convenience function used in generated `TProcessor` implementations to +/// return an `ApplicationError` if thrift message processing failed. +pub fn handle_process_result( + msg_ident: &TMessageIdentifier, + res: ::Result<()>, + o_prot: &mut dyn TOutputProtocol, +) -> ::Result<()> { + if let Err(e) = res { + let e = match e { + ::Error::Application(a) => a, + _ => ApplicationError::new(ApplicationErrorKind::Unknown, format!("{:?}", e)), + }; + + let ident = TMessageIdentifier::new( + msg_ident.name.clone(), + TMessageType::Exception, + msg_ident.sequence_number, + ); + + o_prot.write_message_begin(&ident)?; + ::Error::write_application_error_to_out_protocol(&e, o_prot)?; + o_prot.write_message_end()?; + o_prot.flush() + } else { + Ok(()) + } +} diff --git a/src/jaegertracing/thrift/lib/rs/src/server/multiplexed.rs b/src/jaegertracing/thrift/lib/rs/src/server/multiplexed.rs new file mode 100644 index 000000000..3f9bc78e4 --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/src/server/multiplexed.rs @@ -0,0 +1,351 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::convert::Into; +use std::fmt; +use std::fmt::{Debug, Formatter}; +use std::sync::{Arc, Mutex}; + +use protocol::{TInputProtocol, TMessageIdentifier, TOutputProtocol, TStoredInputProtocol}; + +use super::{handle_process_result, TProcessor}; + +const MISSING_SEPARATOR_AND_NO_DEFAULT: &'static str = + "missing service separator and no default processor set"; +type ThreadSafeProcessor = Box<dyn TProcessor + Send + Sync>; + +/// A `TProcessor` that can demux service calls to multiple underlying +/// Thrift services. +/// +/// Users register service-specific `TProcessor` instances with a +/// `TMultiplexedProcessor`, and then register that processor with a server +/// implementation. Following that, all incoming service calls are automatically +/// routed to the service-specific `TProcessor`. +/// +/// A `TMultiplexedProcessor` can only handle messages sent by a +/// `TMultiplexedOutputProtocol`. +#[derive(Default)] +pub struct TMultiplexedProcessor { + stored: Mutex<StoredProcessors>, +} + +#[derive(Default)] +struct StoredProcessors { + processors: HashMap<String, Arc<ThreadSafeProcessor>>, + default_processor: Option<Arc<ThreadSafeProcessor>>, +} + +impl TMultiplexedProcessor { + /// Create a new `TMultiplexedProcessor` with no registered service-specific + /// processors. + pub fn new() -> TMultiplexedProcessor { + TMultiplexedProcessor { + stored: Mutex::new(StoredProcessors { + processors: HashMap::new(), + default_processor: None, + }), + } + } + + /// Register a service-specific `processor` for the service named + /// `service_name`. This implementation is also backwards-compatible with + /// non-multiplexed clients. Set `as_default` to `true` to allow + /// non-namespaced requests to be dispatched to a default processor. + /// + /// Returns success if a new entry was inserted. Returns an error if: + /// * A processor exists for `service_name` + /// * You attempt to register a processor as default, and an existing default exists + #[cfg_attr(feature = "cargo-clippy", allow(map_entry))] + pub fn register<S: Into<String>>( + &mut self, + service_name: S, + processor: Box<dyn TProcessor + Send + Sync>, + as_default: bool, + ) -> ::Result<()> { + let mut stored = self.stored.lock().unwrap(); + + let name = service_name.into(); + if !stored.processors.contains_key(&name) { + let processor = Arc::new(processor); + + if as_default { + if stored.default_processor.is_none() { + stored.processors.insert(name, processor.clone()); + stored.default_processor = Some(processor.clone()); + Ok(()) + } else { + Err("cannot reset default processor".into()) + } + } else { + stored.processors.insert(name, processor); + Ok(()) + } + } else { + Err(format!("cannot overwrite existing processor for service {}", name).into()) + } + } + + fn process_message( + &self, + msg_ident: &TMessageIdentifier, + i_prot: &mut dyn TInputProtocol, + o_prot: &mut dyn TOutputProtocol, + ) -> ::Result<()> { + let (svc_name, svc_call) = split_ident_name(&msg_ident.name); + debug!("routing svc_name {:?} svc_call {}", &svc_name, &svc_call); + + let processor: Option<Arc<ThreadSafeProcessor>> = { + let stored = self.stored.lock().unwrap(); + if let Some(name) = svc_name { + stored.processors.get(name).cloned() + } else { + stored.default_processor.clone() + } + }; + + match processor { + Some(arc) => { + let new_msg_ident = TMessageIdentifier::new( + svc_call, + msg_ident.message_type, + msg_ident.sequence_number, + ); + let mut proxy_i_prot = TStoredInputProtocol::new(i_prot, new_msg_ident); + (*arc).process(&mut proxy_i_prot, o_prot) + } + None => Err(missing_processor_message(svc_name).into()), + } + } +} + +impl TProcessor for TMultiplexedProcessor { + fn process(&self, i_prot: &mut dyn TInputProtocol, o_prot: &mut dyn TOutputProtocol) -> ::Result<()> { + let msg_ident = i_prot.read_message_begin()?; + + debug!("process incoming msg id:{:?}", &msg_ident); + let res = self.process_message(&msg_ident, i_prot, o_prot); + + handle_process_result(&msg_ident, res, o_prot) + } +} + +impl Debug for TMultiplexedProcessor { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let stored = self.stored.lock().unwrap(); + write!( + f, + "TMultiplexedProcess {{ registered_count: {:?} default: {:?} }}", + stored.processors.keys().len(), + stored.default_processor.is_some() + ) + } +} + +fn split_ident_name(ident_name: &str) -> (Option<&str>, &str) { + ident_name + .find(':') + .map(|pos| { + let (svc_name, svc_call) = ident_name.split_at(pos); + let (_, svc_call) = svc_call.split_at(1); // remove colon from service call name + (Some(svc_name), svc_call) + }) + .or_else(|| Some((None, ident_name))) + .unwrap() +} + +fn missing_processor_message(svc_name: Option<&str>) -> String { + match svc_name { + Some(name) => format!("no processor found for service {}", name), + None => MISSING_SEPARATOR_AND_NO_DEFAULT.to_owned(), + } +} + +#[cfg(test)] +mod tests { + use std::convert::Into; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + + use protocol::{TBinaryInputProtocol, TBinaryOutputProtocol, TMessageIdentifier, TMessageType}; + use transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf}; + use {ApplicationError, ApplicationErrorKind}; + + use super::*; + + #[test] + fn should_split_name_into_proper_separator_and_service_call() { + let ident_name = "foo:bar_call"; + let (serv, call) = split_ident_name(&ident_name); + assert_eq!(serv, Some("foo")); + assert_eq!(call, "bar_call"); + } + + #[test] + fn should_return_full_ident_if_no_separator_exists() { + let ident_name = "bar_call"; + let (serv, call) = split_ident_name(&ident_name); + assert_eq!(serv, None); + assert_eq!(call, "bar_call"); + } + + #[test] + fn should_write_error_if_no_separator_found_and_no_default_processor_exists() { + let (mut i, mut o) = build_objects(); + + let sent_ident = TMessageIdentifier::new("foo", TMessageType::Call, 10); + o.write_message_begin(&sent_ident).unwrap(); + o.flush().unwrap(); + o.transport.copy_write_buffer_to_read_buffer(); + o.transport.empty_write_buffer(); + + let p = TMultiplexedProcessor::new(); + p.process(&mut i, &mut o).unwrap(); // at this point an error should be written out + + i.transport.set_readable_bytes(&o.transport.write_bytes()); + let rcvd_ident = i.read_message_begin().unwrap(); + let expected_ident = TMessageIdentifier::new("foo", TMessageType::Exception, 10); + assert_eq!(rcvd_ident, expected_ident); + let rcvd_err = ::Error::read_application_error_from_in_protocol(&mut i).unwrap(); + let expected_err = ApplicationError::new( + ApplicationErrorKind::Unknown, + MISSING_SEPARATOR_AND_NO_DEFAULT, + ); + assert_eq!(rcvd_err, expected_err); + } + + #[test] + fn should_write_error_if_separator_exists_and_no_processor_found() { + let (mut i, mut o) = build_objects(); + + let sent_ident = TMessageIdentifier::new("missing:call", TMessageType::Call, 10); + o.write_message_begin(&sent_ident).unwrap(); + o.flush().unwrap(); + o.transport.copy_write_buffer_to_read_buffer(); + o.transport.empty_write_buffer(); + + let p = TMultiplexedProcessor::new(); + p.process(&mut i, &mut o).unwrap(); // at this point an error should be written out + + i.transport.set_readable_bytes(&o.transport.write_bytes()); + let rcvd_ident = i.read_message_begin().unwrap(); + let expected_ident = TMessageIdentifier::new("missing:call", TMessageType::Exception, 10); + assert_eq!(rcvd_ident, expected_ident); + let rcvd_err = ::Error::read_application_error_from_in_protocol(&mut i).unwrap(); + let expected_err = ApplicationError::new( + ApplicationErrorKind::Unknown, + missing_processor_message(Some("missing")), + ); + assert_eq!(rcvd_err, expected_err); + } + + #[derive(Default)] + struct Service { + pub invoked: Arc<AtomicBool>, + } + + impl TProcessor for Service { + fn process(&self, _: &mut dyn TInputProtocol, _: &mut dyn TOutputProtocol) -> ::Result<()> { + let res = self + .invoked + .compare_and_swap(false, true, Ordering::Relaxed); + if res { + Ok(()) + } else { + Err("failed swap".into()) + } + } + } + + #[test] + fn should_route_call_to_correct_processor() { + let (mut i, mut o) = build_objects(); + + // build the services + let svc_1 = Service { + invoked: Arc::new(AtomicBool::new(false)), + }; + let atm_1 = svc_1.invoked.clone(); + let svc_2 = Service { + invoked: Arc::new(AtomicBool::new(false)), + }; + let atm_2 = svc_2.invoked.clone(); + + // register them + let mut p = TMultiplexedProcessor::new(); + p.register("service_1", Box::new(svc_1), false).unwrap(); + p.register("service_2", Box::new(svc_2), false).unwrap(); + + // make the service call + let sent_ident = TMessageIdentifier::new("service_1:call", TMessageType::Call, 10); + o.write_message_begin(&sent_ident).unwrap(); + o.flush().unwrap(); + o.transport.copy_write_buffer_to_read_buffer(); + o.transport.empty_write_buffer(); + + p.process(&mut i, &mut o).unwrap(); + + // service 1 should have been invoked, not service 2 + assert_eq!(atm_1.load(Ordering::Relaxed), true); + assert_eq!(atm_2.load(Ordering::Relaxed), false); + } + + #[test] + fn should_route_call_to_correct_processor_if_no_separator_exists_and_default_processor_set() { + let (mut i, mut o) = build_objects(); + + // build the services + let svc_1 = Service { + invoked: Arc::new(AtomicBool::new(false)), + }; + let atm_1 = svc_1.invoked.clone(); + let svc_2 = Service { + invoked: Arc::new(AtomicBool::new(false)), + }; + let atm_2 = svc_2.invoked.clone(); + + // register them + let mut p = TMultiplexedProcessor::new(); + p.register("service_1", Box::new(svc_1), false).unwrap(); + p.register("service_2", Box::new(svc_2), true).unwrap(); // second processor is default + + // make the service call (it's an old client, so we have to be backwards compatible) + let sent_ident = TMessageIdentifier::new("old_call", TMessageType::Call, 10); + o.write_message_begin(&sent_ident).unwrap(); + o.flush().unwrap(); + o.transport.copy_write_buffer_to_read_buffer(); + o.transport.empty_write_buffer(); + + p.process(&mut i, &mut o).unwrap(); + + // service 2 should have been invoked, not service 1 + assert_eq!(atm_1.load(Ordering::Relaxed), false); + assert_eq!(atm_2.load(Ordering::Relaxed), true); + } + + fn build_objects() -> ( + TBinaryInputProtocol<ReadHalf<TBufferChannel>>, + TBinaryOutputProtocol<WriteHalf<TBufferChannel>>, + ) { + let c = TBufferChannel::with_capacity(128, 128); + let (r_c, w_c) = c.split().unwrap(); + ( + TBinaryInputProtocol::new(r_c, true), + TBinaryOutputProtocol::new(w_c, true), + ) + } +} diff --git a/src/jaegertracing/thrift/lib/rs/src/server/threaded.rs b/src/jaegertracing/thrift/lib/rs/src/server/threaded.rs new file mode 100644 index 000000000..8f8c082d6 --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/src/server/threaded.rs @@ -0,0 +1,233 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::net::{TcpListener, TcpStream}; +use std::sync::Arc; +use threadpool::ThreadPool; + +use protocol::{TInputProtocol, TInputProtocolFactory, TOutputProtocol, TOutputProtocolFactory}; +use transport::{TIoChannel, TReadTransportFactory, TTcpChannel, TWriteTransportFactory}; +use {ApplicationError, ApplicationErrorKind}; + +use super::TProcessor; + +/// Fixed-size thread-pool blocking Thrift server. +/// +/// A `TServer` listens on a given address and submits accepted connections +/// to an **unbounded** queue. Connections from this queue are serviced by +/// the first available worker thread from a **fixed-size** thread pool. Each +/// accepted connection is handled by that worker thread, and communication +/// over this thread occurs sequentially and synchronously (i.e. calls block). +/// Accepted connections have an input half and an output half, each of which +/// uses a `TTransport` and `TInputProtocol`/`TOutputProtocol` to translate +/// messages to and from byes. Any combination of `TInputProtocol`, `TOutputProtocol` +/// and `TTransport` may be used. +/// +/// # Examples +/// +/// Creating and running a `TServer` using Thrift-compiler-generated +/// service code. +/// +/// ```no_run +/// use thrift::protocol::{TInputProtocolFactory, TOutputProtocolFactory}; +/// use thrift::protocol::{TBinaryInputProtocolFactory, TBinaryOutputProtocolFactory}; +/// use thrift::protocol::{TInputProtocol, TOutputProtocol}; +/// use thrift::transport::{TBufferedReadTransportFactory, TBufferedWriteTransportFactory, +/// TReadTransportFactory, TWriteTransportFactory}; +/// use thrift::server::{TProcessor, TServer}; +/// +/// // +/// // auto-generated +/// // +/// +/// // processor for `SimpleService` +/// struct SimpleServiceSyncProcessor; +/// impl SimpleServiceSyncProcessor { +/// fn new<H: SimpleServiceSyncHandler>(processor: H) -> SimpleServiceSyncProcessor { +/// unimplemented!(); +/// } +/// } +/// +/// // `TProcessor` implementation for `SimpleService` +/// impl TProcessor for SimpleServiceSyncProcessor { +/// fn process(&self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> thrift::Result<()> { +/// unimplemented!(); +/// } +/// } +/// +/// // service functions for SimpleService +/// trait SimpleServiceSyncHandler { +/// fn service_call(&self) -> thrift::Result<()>; +/// } +/// +/// // +/// // user-code follows +/// // +/// +/// // define a handler that will be invoked when `service_call` is received +/// struct SimpleServiceHandlerImpl; +/// impl SimpleServiceSyncHandler for SimpleServiceHandlerImpl { +/// fn service_call(&self) -> thrift::Result<()> { +/// unimplemented!(); +/// } +/// } +/// +/// // instantiate the processor +/// let processor = SimpleServiceSyncProcessor::new(SimpleServiceHandlerImpl {}); +/// +/// // instantiate the server +/// let i_tr_fact: Box<TReadTransportFactory> = Box::new(TBufferedReadTransportFactory::new()); +/// let i_pr_fact: Box<TInputProtocolFactory> = Box::new(TBinaryInputProtocolFactory::new()); +/// let o_tr_fact: Box<TWriteTransportFactory> = Box::new(TBufferedWriteTransportFactory::new()); +/// let o_pr_fact: Box<TOutputProtocolFactory> = Box::new(TBinaryOutputProtocolFactory::new()); +/// +/// let mut server = TServer::new( +/// i_tr_fact, +/// i_pr_fact, +/// o_tr_fact, +/// o_pr_fact, +/// processor, +/// 10 +/// ); +/// +/// // start listening for incoming connections +/// match server.listen("127.0.0.1:8080") { +/// Ok(_) => println!("listen completed"), +/// Err(e) => println!("listen failed with error {:?}", e), +/// } +/// ``` +#[derive(Debug)] +pub struct TServer<PRC, RTF, IPF, WTF, OPF> +where + PRC: TProcessor + Send + Sync + 'static, + RTF: TReadTransportFactory + 'static, + IPF: TInputProtocolFactory + 'static, + WTF: TWriteTransportFactory + 'static, + OPF: TOutputProtocolFactory + 'static, +{ + r_trans_factory: RTF, + i_proto_factory: IPF, + w_trans_factory: WTF, + o_proto_factory: OPF, + processor: Arc<PRC>, + worker_pool: ThreadPool, +} + +impl<PRC, RTF, IPF, WTF, OPF> TServer<PRC, RTF, IPF, WTF, OPF> +where + PRC: TProcessor + Send + Sync + 'static, + RTF: TReadTransportFactory + 'static, + IPF: TInputProtocolFactory + 'static, + WTF: TWriteTransportFactory + 'static, + OPF: TOutputProtocolFactory + 'static, +{ + /// Create a `TServer`. + /// + /// Each accepted connection has an input and output half, each of which + /// requires a `TTransport` and `TProtocol`. `TServer` uses + /// `read_transport_factory` and `input_protocol_factory` to create + /// implementations for the input, and `write_transport_factory` and + /// `output_protocol_factory` to create implementations for the output. + pub fn new( + read_transport_factory: RTF, + input_protocol_factory: IPF, + write_transport_factory: WTF, + output_protocol_factory: OPF, + processor: PRC, + num_workers: usize, + ) -> TServer<PRC, RTF, IPF, WTF, OPF> { + TServer { + r_trans_factory: read_transport_factory, + i_proto_factory: input_protocol_factory, + w_trans_factory: write_transport_factory, + o_proto_factory: output_protocol_factory, + processor: Arc::new(processor), + worker_pool: ThreadPool::with_name("Thrift service processor".to_owned(), num_workers), + } + } + + /// Listen for incoming connections on `listen_address`. + /// + /// `listen_address` should be in the form `host:port`, + /// for example: `127.0.0.1:8080`. + /// + /// Return `()` if successful. + /// + /// Return `Err` when the server cannot bind to `listen_address` or there + /// is an unrecoverable error. + pub fn listen(&mut self, listen_address: &str) -> ::Result<()> { + let listener = TcpListener::bind(listen_address)?; + for stream in listener.incoming() { + match stream { + Ok(s) => { + let (i_prot, o_prot) = self.new_protocols_for_connection(s)?; + let processor = self.processor.clone(); + self.worker_pool + .execute(move || handle_incoming_connection(processor, i_prot, o_prot)); + } + Err(e) => { + warn!("failed to accept remote connection with error {:?}", e); + } + } + } + + Err(::Error::Application(ApplicationError { + kind: ApplicationErrorKind::Unknown, + message: "aborted listen loop".into(), + })) + } + + fn new_protocols_for_connection( + &mut self, + stream: TcpStream, + ) -> ::Result<(Box<dyn TInputProtocol + Send>, Box<dyn TOutputProtocol + Send>)> { + // create the shared tcp stream + let channel = TTcpChannel::with_stream(stream); + + // split it into two - one to be owned by the + // input tran/proto and the other by the output + let (r_chan, w_chan) = channel.split()?; + + // input protocol and transport + let r_tran = self.r_trans_factory.create(Box::new(r_chan)); + let i_prot = self.i_proto_factory.create(r_tran); + + // output protocol and transport + let w_tran = self.w_trans_factory.create(Box::new(w_chan)); + let o_prot = self.o_proto_factory.create(w_tran); + + Ok((i_prot, o_prot)) + } +} + +fn handle_incoming_connection<PRC>( + processor: Arc<PRC>, + i_prot: Box<dyn TInputProtocol>, + o_prot: Box<dyn TOutputProtocol>, +) where + PRC: TProcessor, +{ + let mut i_prot = i_prot; + let mut o_prot = o_prot; + loop { + let r = processor.process(&mut *i_prot, &mut *o_prot); + if let Err(e) = r { + warn!("processor completed with error: {:?}", e); + break; + } + } +} diff --git a/src/jaegertracing/thrift/lib/rs/src/transport/buffered.rs b/src/jaegertracing/thrift/lib/rs/src/transport/buffered.rs new file mode 100644 index 000000000..b33eb4f55 --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/src/transport/buffered.rs @@ -0,0 +1,483 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::cmp; +use std::io; +use std::io::{Read, Write}; + +use super::{TReadTransport, TReadTransportFactory, TWriteTransport, TWriteTransportFactory}; + +/// Default capacity of the read buffer in bytes. +const READ_CAPACITY: usize = 4096; + +/// Default capacity of the write buffer in bytes.. +const WRITE_CAPACITY: usize = 4096; + +/// Transport that reads messages via an internal buffer. +/// +/// A `TBufferedReadTransport` maintains a fixed-size internal read buffer. +/// On a call to `TBufferedReadTransport::read(...)` one full message - both +/// fixed-length header and bytes - is read from the wrapped channel and buffered. +/// Subsequent read calls are serviced from the internal buffer until it is +/// exhausted, at which point the next full message is read from the wrapped +/// channel. +/// +/// # Examples +/// +/// Create and use a `TBufferedReadTransport`. +/// +/// ```no_run +/// use std::io::Read; +/// use thrift::transport::{TBufferedReadTransport, TTcpChannel}; +/// +/// let mut c = TTcpChannel::new(); +/// c.open("localhost:9090").unwrap(); +/// +/// let mut t = TBufferedReadTransport::new(c); +/// +/// t.read(&mut vec![0u8; 1]).unwrap(); +/// ``` +#[derive(Debug)] +pub struct TBufferedReadTransport<C> +where + C: Read, +{ + buf: Box<[u8]>, + pos: usize, + cap: usize, + chan: C, +} + +impl<C> TBufferedReadTransport<C> +where + C: Read, +{ + /// Create a `TBufferedTransport` with default-sized internal read and + /// write buffers that wraps the given `TIoChannel`. + pub fn new(channel: C) -> TBufferedReadTransport<C> { + TBufferedReadTransport::with_capacity(READ_CAPACITY, channel) + } + + /// Create a `TBufferedTransport` with an internal read buffer of size + /// `read_capacity` and an internal write buffer of size + /// `write_capacity` that wraps the given `TIoChannel`. + pub fn with_capacity(read_capacity: usize, channel: C) -> TBufferedReadTransport<C> { + TBufferedReadTransport { + buf: vec![0; read_capacity].into_boxed_slice(), + pos: 0, + cap: 0, + chan: channel, + } + } + + fn get_bytes(&mut self) -> io::Result<&[u8]> { + if self.cap - self.pos == 0 { + self.pos = 0; + self.cap = self.chan.read(&mut self.buf)?; + } + + Ok(&self.buf[self.pos..self.cap]) + } + + fn consume(&mut self, consumed: usize) { + // TODO: was a bug here += <-- test somehow + self.pos = cmp::min(self.cap, self.pos + consumed); + } +} + +impl<C> Read for TBufferedReadTransport<C> +where + C: Read, +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + let mut bytes_read = 0; + + loop { + let nread = { + let avail_bytes = self.get_bytes()?; + let avail_space = buf.len() - bytes_read; + let nread = cmp::min(avail_space, avail_bytes.len()); + buf[bytes_read..(bytes_read + nread)].copy_from_slice(&avail_bytes[..nread]); + nread + }; + + self.consume(nread); + bytes_read += nread; + + if bytes_read == buf.len() || nread == 0 { + break; + } + } + + Ok(bytes_read) + } +} + +/// Factory for creating instances of `TBufferedReadTransport`. +#[derive(Default)] +pub struct TBufferedReadTransportFactory; + +impl TBufferedReadTransportFactory { + pub fn new() -> TBufferedReadTransportFactory { + TBufferedReadTransportFactory {} + } +} + +impl TReadTransportFactory for TBufferedReadTransportFactory { + /// Create a `TBufferedReadTransport`. + fn create(&self, channel: Box<dyn Read + Send>) -> Box<dyn TReadTransport + Send> { + Box::new(TBufferedReadTransport::new(channel)) + } +} + +/// Transport that writes messages via an internal buffer. +/// +/// A `TBufferedWriteTransport` maintains a fixed-size internal write buffer. +/// All writes are made to this buffer and are sent to the wrapped channel only +/// when `TBufferedWriteTransport::flush()` is called. On a flush a fixed-length +/// header with a count of the buffered bytes is written, followed by the bytes +/// themselves. +/// +/// # Examples +/// +/// Create and use a `TBufferedWriteTransport`. +/// +/// ```no_run +/// use std::io::Write; +/// use thrift::transport::{TBufferedWriteTransport, TTcpChannel}; +/// +/// let mut c = TTcpChannel::new(); +/// c.open("localhost:9090").unwrap(); +/// +/// let mut t = TBufferedWriteTransport::new(c); +/// +/// t.write(&[0x00]).unwrap(); +/// t.flush().unwrap(); +/// ``` +#[derive(Debug)] +pub struct TBufferedWriteTransport<C> +where + C: Write, +{ + buf: Vec<u8>, + cap: usize, + channel: C, +} + +impl<C> TBufferedWriteTransport<C> +where + C: Write, +{ + /// Create a `TBufferedTransport` with default-sized internal read and + /// write buffers that wraps the given `TIoChannel`. + pub fn new(channel: C) -> TBufferedWriteTransport<C> { + TBufferedWriteTransport::with_capacity(WRITE_CAPACITY, channel) + } + + /// Create a `TBufferedTransport` with an internal read buffer of size + /// `read_capacity` and an internal write buffer of size + /// `write_capacity` that wraps the given `TIoChannel`. + pub fn with_capacity(write_capacity: usize, channel: C) -> TBufferedWriteTransport<C> { + assert!( + write_capacity > 0, + "write buffer size must be a positive integer" + ); + + TBufferedWriteTransport { + buf: Vec::with_capacity(write_capacity), + cap: write_capacity, + channel: channel, + } + } +} + +impl<C> Write for TBufferedWriteTransport<C> +where + C: Write, +{ + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + if !buf.is_empty() { + let mut avail_bytes; + + loop { + avail_bytes = cmp::min(buf.len(), self.cap - self.buf.len()); + + if avail_bytes == 0 { + self.flush()?; + } else { + break; + } + } + + let avail_bytes = avail_bytes; + + self.buf.extend_from_slice(&buf[..avail_bytes]); + assert!(self.buf.len() <= self.cap, "copy overflowed buffer"); + + Ok(avail_bytes) + } else { + Ok(0) + } + } + + fn flush(&mut self) -> io::Result<()> { + self.channel.write_all(&self.buf)?; + self.channel.flush()?; + self.buf.clear(); + Ok(()) + } +} + +/// Factory for creating instances of `TBufferedWriteTransport`. +#[derive(Default)] +pub struct TBufferedWriteTransportFactory; + +impl TBufferedWriteTransportFactory { + pub fn new() -> TBufferedWriteTransportFactory { + TBufferedWriteTransportFactory {} + } +} + +impl TWriteTransportFactory for TBufferedWriteTransportFactory { + /// Create a `TBufferedWriteTransport`. + fn create(&self, channel: Box<dyn Write + Send>) -> Box<dyn TWriteTransport + Send> { + Box::new(TBufferedWriteTransport::new(channel)) + } +} + +#[cfg(test)] +mod tests { + use std::io::{Read, Write}; + + use super::*; + use transport::TBufferChannel; + + #[test] + fn must_return_zero_if_read_buffer_is_empty() { + let mem = TBufferChannel::with_capacity(10, 0); + let mut t = TBufferedReadTransport::with_capacity(10, mem); + + let mut b = vec![0; 10]; + let read_result = t.read(&mut b); + + assert_eq!(read_result.unwrap(), 0); + } + + #[test] + fn must_return_zero_if_caller_reads_into_zero_capacity_buffer() { + let mem = TBufferChannel::with_capacity(10, 0); + let mut t = TBufferedReadTransport::with_capacity(10, mem); + + let read_result = t.read(&mut []); + + assert_eq!(read_result.unwrap(), 0); + } + + #[test] + fn must_return_zero_if_nothing_more_can_be_read() { + let mem = TBufferChannel::with_capacity(4, 0); + let mut t = TBufferedReadTransport::with_capacity(4, mem); + + t.chan.set_readable_bytes(&[0, 1, 2, 3]); + + // read buffer is exactly the same size as bytes available + let mut buf = vec![0u8; 4]; + let read_result = t.read(&mut buf); + + // we've read exactly 4 bytes + assert_eq!(read_result.unwrap(), 4); + assert_eq!(&buf, &[0, 1, 2, 3]); + + // try read again + let buf_again = vec![0u8; 4]; + let read_result = t.read(&mut buf); + + // this time, 0 bytes and we haven't changed the buffer + assert_eq!(read_result.unwrap(), 0); + assert_eq!(&buf_again, &[0, 0, 0, 0]) + } + + #[test] + fn must_fill_user_buffer_with_only_as_many_bytes_as_available() { + let mem = TBufferChannel::with_capacity(4, 0); + let mut t = TBufferedReadTransport::with_capacity(4, mem); + + t.chan.set_readable_bytes(&[0, 1, 2, 3]); + + // read buffer is much larger than the bytes available + let mut buf = vec![0u8; 8]; + let read_result = t.read(&mut buf); + + // we've read exactly 4 bytes + assert_eq!(read_result.unwrap(), 4); + assert_eq!(&buf[..4], &[0, 1, 2, 3]); + + // try read again + let read_result = t.read(&mut buf[4..]); + + // this time, 0 bytes and we haven't changed the buffer + assert_eq!(read_result.unwrap(), 0); + assert_eq!(&buf, &[0, 1, 2, 3, 0, 0, 0, 0]) + } + + #[test] + fn must_read_successfully() { + // this test involves a few loops within the buffered transport + // itself where it has to drain the underlying transport in order + // to service a read + + // we have a much smaller buffer than the + // underlying transport has bytes available + let mem = TBufferChannel::with_capacity(10, 0); + let mut t = TBufferedReadTransport::with_capacity(2, mem); + + // fill the underlying transport's byte buffer + let mut readable_bytes = [0u8; 10]; + for i in 0..10 { + readable_bytes[i] = i as u8; + } + + t.chan.set_readable_bytes(&readable_bytes); + + // we ask to read into a buffer that's much larger + // than the one the buffered transport has; as a result + // it's going to have to keep asking the underlying + // transport for more bytes + let mut buf = [0u8; 8]; + let read_result = t.read(&mut buf); + + // we should have read 8 bytes + assert_eq!(read_result.unwrap(), 8); + assert_eq!(&buf, &[0, 1, 2, 3, 4, 5, 6, 7]); + + // let's clear out the buffer and try read again + for i in 0..8 { + buf[i] = 0; + } + let read_result = t.read(&mut buf); + + // this time we were only able to read 2 bytes + // (all that's remaining from the underlying transport) + // let's also check that the remaining bytes are untouched + assert_eq!(read_result.unwrap(), 2); + assert_eq!(&buf[0..2], &[8, 9]); + assert_eq!(&buf[2..], &[0, 0, 0, 0, 0, 0]); + + // try read again (we should get 0) + // and all the existing bytes were untouched + let read_result = t.read(&mut buf); + assert_eq!(read_result.unwrap(), 0); + assert_eq!(&buf[0..2], &[8, 9]); + assert_eq!(&buf[2..], &[0, 0, 0, 0, 0, 0]); + } + + #[test] + fn must_return_error_when_nothing_can_be_written_to_underlying_channel() { + let mem = TBufferChannel::with_capacity(0, 0); + let mut t = TBufferedWriteTransport::with_capacity(1, mem); + + let b = vec![0; 10]; + let r = t.write(&b); + + // should have written 1 byte + assert_eq!(r.unwrap(), 1); + + // let's try again... + let r = t.write(&b[1..]); + + // this time we'll error out because the auto-flush failed + assert!(r.is_err()); + } + + #[test] + fn must_return_zero_if_caller_calls_write_with_empty_buffer() { + let mem = TBufferChannel::with_capacity(0, 10); + let mut t = TBufferedWriteTransport::with_capacity(10, mem); + + let r = t.write(&[]); + let expected: [u8; 0] = []; + + assert_eq!(r.unwrap(), 0); + assert_eq_transport_written_bytes!(t, expected); + } + + #[test] + fn must_auto_flush_if_write_buffer_full() { + let mem = TBufferChannel::with_capacity(0, 8); + let mut t = TBufferedWriteTransport::with_capacity(4, mem); + + let b0 = [0x00, 0x01, 0x02, 0x03]; + let b1 = [0x04, 0x05, 0x06, 0x07]; + + // write the first 4 bytes; we've now filled the transport's write buffer + let r = t.write(&b0); + assert_eq!(r.unwrap(), 4); + + // try write the next 4 bytes; this causes the transport to auto-flush the first 4 bytes + let r = t.write(&b1); + assert_eq!(r.unwrap(), 4); + + // check that in writing the second 4 bytes we auto-flushed the first 4 bytes + assert_eq_transport_num_written_bytes!(t, 4); + assert_eq_transport_written_bytes!(t, b0); + t.channel.empty_write_buffer(); + + // now flush the transport to push the second 4 bytes to the underlying channel + assert!(t.flush().is_ok()); + + // check that we wrote out the second 4 bytes + assert_eq_transport_written_bytes!(t, b1); + } + + #[test] + fn must_write_to_inner_transport_on_flush() { + let mem = TBufferChannel::with_capacity(10, 10); + let mut t = TBufferedWriteTransport::new(mem); + + let b: [u8; 5] = [0, 1, 2, 3, 4]; + assert_eq!(t.write(&b).unwrap(), 5); + assert_eq_transport_num_written_bytes!(t, 0); + + assert!(t.flush().is_ok()); + + assert_eq_transport_written_bytes!(t, b); + } + + #[test] + fn must_write_successfully_after_flush() { + let mem = TBufferChannel::with_capacity(0, 5); + let mut t = TBufferedWriteTransport::with_capacity(5, mem); + + // write and flush + let b: [u8; 5] = [0, 1, 2, 3, 4]; + assert_eq!(t.write(&b).unwrap(), 5); + assert!(t.flush().is_ok()); + + // check the flushed bytes + assert_eq_transport_written_bytes!(t, b); + + // reset our underlying transport + t.channel.empty_write_buffer(); + + // write and flush again + assert_eq!(t.write(&b).unwrap(), 5); + assert!(t.flush().is_ok()); + + // check the flushed bytes + assert_eq_transport_written_bytes!(t, b); + } +} diff --git a/src/jaegertracing/thrift/lib/rs/src/transport/framed.rs b/src/jaegertracing/thrift/lib/rs/src/transport/framed.rs new file mode 100644 index 000000000..98ad1bb2f --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/src/transport/framed.rs @@ -0,0 +1,459 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use std::cmp; +use std::io; +use std::io::{Read, Write}; + +use super::{TReadTransport, TReadTransportFactory, TWriteTransport, TWriteTransportFactory}; + +/// Default capacity of the read buffer in bytes. +const READ_CAPACITY: usize = 4096; + +/// Default capacity of the write buffer in bytes. +const WRITE_CAPACITY: usize = 4096; + +/// Transport that reads framed messages. +/// +/// A `TFramedReadTransport` maintains a fixed-size internal read buffer. +/// On a call to `TFramedReadTransport::read(...)` one full message - both +/// fixed-length header and bytes - is read from the wrapped channel and +/// buffered. Subsequent read calls are serviced from the internal buffer +/// until it is exhausted, at which point the next full message is read +/// from the wrapped channel. +/// +/// # Examples +/// +/// Create and use a `TFramedReadTransport`. +/// +/// ```no_run +/// use std::io::Read; +/// use thrift::transport::{TFramedReadTransport, TTcpChannel}; +/// +/// let mut c = TTcpChannel::new(); +/// c.open("localhost:9090").unwrap(); +/// +/// let mut t = TFramedReadTransport::new(c); +/// +/// t.read(&mut vec![0u8; 1]).unwrap(); +/// ``` +#[derive(Debug)] +pub struct TFramedReadTransport<C> +where + C: Read, +{ + buf: Vec<u8>, + pos: usize, + cap: usize, + chan: C, +} + +impl<C> TFramedReadTransport<C> +where + C: Read, +{ + /// Create a `TFramedReadTransport` with a default-sized + /// internal read buffer that wraps the given `TIoChannel`. + pub fn new(channel: C) -> TFramedReadTransport<C> { + TFramedReadTransport::with_capacity(READ_CAPACITY, channel) + } + + /// Create a `TFramedTransport` with an internal read buffer + /// of size `read_capacity` that wraps the given `TIoChannel`. + pub fn with_capacity(read_capacity: usize, channel: C) -> TFramedReadTransport<C> { + TFramedReadTransport { + buf: vec![0; read_capacity], // FIXME: do I actually have to do this? + pos: 0, + cap: 0, + chan: channel, + } + } +} + +impl<C> Read for TFramedReadTransport<C> +where + C: Read, +{ + fn read(&mut self, b: &mut [u8]) -> io::Result<usize> { + if self.cap - self.pos == 0 { + let message_size = self.chan.read_i32::<BigEndian>()? as usize; + + let buf_capacity = cmp::max(message_size, READ_CAPACITY); + self.buf.resize(buf_capacity, 0); + + self.chan.read_exact(&mut self.buf[..message_size])?; + self.cap = message_size as usize; + self.pos = 0; + } + + let nread = cmp::min(b.len(), self.cap - self.pos); + b[..nread].clone_from_slice(&self.buf[self.pos..self.pos + nread]); + self.pos += nread; + + Ok(nread) + } +} + +/// Factory for creating instances of `TFramedReadTransport`. +#[derive(Default)] +pub struct TFramedReadTransportFactory; + +impl TFramedReadTransportFactory { + pub fn new() -> TFramedReadTransportFactory { + TFramedReadTransportFactory {} + } +} + +impl TReadTransportFactory for TFramedReadTransportFactory { + /// Create a `TFramedReadTransport`. + fn create(&self, channel: Box<dyn Read + Send>) -> Box<dyn TReadTransport + Send> { + Box::new(TFramedReadTransport::new(channel)) + } +} + +/// Transport that writes framed messages. +/// +/// A `TFramedWriteTransport` maintains a fixed-size internal write buffer. All +/// writes are made to this buffer and are sent to the wrapped channel only +/// when `TFramedWriteTransport::flush()` is called. On a flush a fixed-length +/// header with a count of the buffered bytes is written, followed by the bytes +/// themselves. +/// +/// # Examples +/// +/// Create and use a `TFramedWriteTransport`. +/// +/// ```no_run +/// use std::io::Write; +/// use thrift::transport::{TFramedWriteTransport, TTcpChannel}; +/// +/// let mut c = TTcpChannel::new(); +/// c.open("localhost:9090").unwrap(); +/// +/// let mut t = TFramedWriteTransport::new(c); +/// +/// t.write(&[0x00]).unwrap(); +/// t.flush().unwrap(); +/// ``` +#[derive(Debug)] +pub struct TFramedWriteTransport<C> +where + C: Write, +{ + buf: Vec<u8>, + channel: C, +} + +impl<C> TFramedWriteTransport<C> +where + C: Write, +{ + /// Create a `TFramedWriteTransport` with default-sized internal + /// write buffer that wraps the given `TIoChannel`. + pub fn new(channel: C) -> TFramedWriteTransport<C> { + TFramedWriteTransport::with_capacity(WRITE_CAPACITY, channel) + } + + /// Create a `TFramedWriteTransport` with an internal write buffer + /// of size `write_capacity` that wraps the given `TIoChannel`. + pub fn with_capacity(write_capacity: usize, channel: C) -> TFramedWriteTransport<C> { + TFramedWriteTransport { + buf: Vec::with_capacity(write_capacity), + channel, + } + } +} + +impl<C> Write for TFramedWriteTransport<C> +where + C: Write, +{ + fn write(&mut self, b: &[u8]) -> io::Result<usize> { + let current_capacity = self.buf.capacity(); + let available_space = current_capacity - self.buf.len(); + if b.len() > available_space { + let additional_space = cmp::max(b.len() - available_space, current_capacity); + self.buf.reserve(additional_space); + } + + self.buf.extend_from_slice(b); + Ok(b.len()) + } + + fn flush(&mut self) -> io::Result<()> { + let message_size = self.buf.len(); + + if let 0 = message_size { + return Ok(()); + } else { + self.channel.write_i32::<BigEndian>(message_size as i32)?; + } + + // will spin if the underlying channel can't be written to + let mut byte_index = 0; + while byte_index < message_size { + let nwrite = self.channel.write(&self.buf[byte_index..message_size])?; + byte_index = cmp::min(byte_index + nwrite, message_size); + } + + let buf_capacity = cmp::min(self.buf.capacity(), WRITE_CAPACITY); + self.buf.resize(buf_capacity, 0); + self.buf.clear(); + + self.channel.flush() + } +} + +/// Factory for creating instances of `TFramedWriteTransport`. +#[derive(Default)] +pub struct TFramedWriteTransportFactory; + +impl TFramedWriteTransportFactory { + pub fn new() -> TFramedWriteTransportFactory { + TFramedWriteTransportFactory {} + } +} + +impl TWriteTransportFactory for TFramedWriteTransportFactory { + /// Create a `TFramedWriteTransport`. + fn create(&self, channel: Box<dyn Write + Send>) -> Box<dyn TWriteTransport + Send> { + Box::new(TFramedWriteTransport::new(channel)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use transport::mem::TBufferChannel; + + // FIXME: test a forced reserve + + #[test] + fn must_read_message_smaller_than_initial_buffer_size() { + let c = TBufferChannel::with_capacity(10, 10); + let mut t = TFramedReadTransport::with_capacity(8, c); + + t.chan.set_readable_bytes(&[ + 0x00, 0x00, 0x00, 0x04, /* message size */ + 0x00, 0x01, 0x02, 0x03, /* message body */ + ]); + + let mut buf = vec![0; 8]; + + // we've read exactly 4 bytes + assert_eq!(t.read(&mut buf).unwrap(), 4); + assert_eq!(&buf[..4], &[0x00, 0x01, 0x02, 0x03]); + } + + #[test] + fn must_read_message_greater_than_initial_buffer_size() { + let c = TBufferChannel::with_capacity(10, 10); + let mut t = TFramedReadTransport::with_capacity(2, c); + + t.chan.set_readable_bytes(&[ + 0x00, 0x00, 0x00, 0x04, /* message size */ + 0x00, 0x01, 0x02, 0x03, /* message body */ + ]); + + let mut buf = vec![0; 8]; + + // we've read exactly 4 bytes + assert_eq!(t.read(&mut buf).unwrap(), 4); + assert_eq!(&buf[..4], &[0x00, 0x01, 0x02, 0x03]); + } + + #[test] + fn must_read_multiple_messages_in_sequence_correctly() { + let c = TBufferChannel::with_capacity(10, 10); + let mut t = TFramedReadTransport::with_capacity(2, c); + + // + // 1st message + // + + t.chan.set_readable_bytes(&[ + 0x00, 0x00, 0x00, 0x04, /* message size */ + 0x00, 0x01, 0x02, 0x03, /* message body */ + ]); + + let mut buf = vec![0; 8]; + + // we've read exactly 4 bytes + assert_eq!(t.read(&mut buf).unwrap(), 4); + assert_eq!(&buf, &[0x00, 0x01, 0x02, 0x03, 0x00, 0x00, 0x00, 0x00]); + + // + // 2nd message + // + + t.chan.set_readable_bytes(&[ + 0x00, 0x00, 0x00, 0x01, /* message size */ + 0x04, /* message body */ + ]); + + let mut buf = vec![0; 8]; + + // we've read exactly 1 byte + assert_eq!(t.read(&mut buf).unwrap(), 1); + assert_eq!(&buf, &[0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); + } + + #[test] + fn must_write_message_smaller_than_buffer_size() { + let mem = TBufferChannel::with_capacity(0, 0); + let mut t = TFramedWriteTransport::with_capacity(20, mem); + + let b = vec![0; 10]; + + // should have written 10 bytes + assert_eq!(t.write(&b).unwrap(), 10); + } + + #[test] + fn must_return_zero_if_caller_calls_write_with_empty_buffer() { + let mem = TBufferChannel::with_capacity(0, 10); + let mut t = TFramedWriteTransport::with_capacity(10, mem); + + let expected: [u8; 0] = []; + + assert_eq!(t.write(&[]).unwrap(), 0); + assert_eq_transport_written_bytes!(t, expected); + } + + #[test] + fn must_write_to_inner_transport_on_flush() { + let mem = TBufferChannel::with_capacity(10, 10); + let mut t = TFramedWriteTransport::new(mem); + + let b: [u8; 5] = [0x00, 0x01, 0x02, 0x03, 0x04]; + assert_eq!(t.write(&b).unwrap(), 5); + assert_eq_transport_num_written_bytes!(t, 0); + + assert!(t.flush().is_ok()); + + let expected_bytes = [ + 0x00, 0x00, 0x00, 0x05, /* message size */ + 0x00, 0x01, 0x02, 0x03, 0x04, /* message body */ + ]; + + assert_eq_transport_written_bytes!(t, expected_bytes); + } + + #[test] + fn must_write_message_greater_than_buffer_size_00() { + let mem = TBufferChannel::with_capacity(0, 10); + + // IMPORTANT: DO **NOT** CHANGE THE WRITE_CAPACITY OR THE NUMBER OF BYTES TO BE WRITTEN! + // these lengths were chosen to be just long enough + // that doubling the capacity is a **worse** choice than + // simply resizing the buffer to b.len() + + let mut t = TFramedWriteTransport::with_capacity(1, mem); + let b = [0x00, 0x01, 0x02]; + + // should have written 3 bytes + assert_eq!(t.write(&b).unwrap(), 3); + assert_eq_transport_num_written_bytes!(t, 0); + + assert!(t.flush().is_ok()); + + let expected_bytes = [ + 0x00, 0x00, 0x00, 0x03, /* message size */ + 0x00, 0x01, 0x02, /* message body */ + ]; + + assert_eq_transport_written_bytes!(t, expected_bytes); + } + + #[test] + fn must_write_message_greater_than_buffer_size_01() { + let mem = TBufferChannel::with_capacity(0, 10); + + // IMPORTANT: DO **NOT** CHANGE THE WRITE_CAPACITY OR THE NUMBER OF BYTES TO BE WRITTEN! + // these lengths were chosen to be just long enough + // that doubling the capacity is a **better** choice than + // simply resizing the buffer to b.len() + + let mut t = TFramedWriteTransport::with_capacity(2, mem); + let b = [0x00, 0x01, 0x02]; + + // should have written 3 bytes + assert_eq!(t.write(&b).unwrap(), 3); + assert_eq_transport_num_written_bytes!(t, 0); + + assert!(t.flush().is_ok()); + + let expected_bytes = [ + 0x00, 0x00, 0x00, 0x03, /* message size */ + 0x00, 0x01, 0x02, /* message body */ + ]; + + assert_eq_transport_written_bytes!(t, expected_bytes); + } + + #[test] + fn must_return_error_if_nothing_can_be_written_to_inner_transport_on_flush() { + let mem = TBufferChannel::with_capacity(0, 0); + let mut t = TFramedWriteTransport::with_capacity(1, mem); + + let b = vec![0; 10]; + + // should have written 10 bytes + assert_eq!(t.write(&b).unwrap(), 10); + + // let's flush + let r = t.flush(); + + // this time we'll error out because the flush can't write to the underlying channel + assert!(r.is_err()); + } + + #[test] + fn must_write_successfully_after_flush() { + // IMPORTANT: write capacity *MUST* be greater + // than message sizes used in this test + 4-byte frame header + let mem = TBufferChannel::with_capacity(0, 10); + let mut t = TFramedWriteTransport::with_capacity(5, mem); + + // write and flush + let first_message: [u8; 5] = [0x00, 0x01, 0x02, 0x03, 0x04]; + assert_eq!(t.write(&first_message).unwrap(), 5); + assert!(t.flush().is_ok()); + + let mut expected = Vec::new(); + expected.write_all(&[0x00, 0x00, 0x00, 0x05]).unwrap(); // message size + expected.extend_from_slice(&first_message); + + // check the flushed bytes + assert_eq!(t.channel.write_bytes(), expected); + + // reset our underlying transport + t.channel.empty_write_buffer(); + + let second_message: [u8; 3] = [0x05, 0x06, 0x07]; + assert_eq!(t.write(&second_message).unwrap(), 3); + assert!(t.flush().is_ok()); + + expected.clear(); + expected.write_all(&[0x00, 0x00, 0x00, 0x03]).unwrap(); // message size + expected.extend_from_slice(&second_message); + + // check the flushed bytes + assert_eq!(t.channel.write_bytes(), expected); + } +} diff --git a/src/jaegertracing/thrift/lib/rs/src/transport/mem.rs b/src/jaegertracing/thrift/lib/rs/src/transport/mem.rs new file mode 100644 index 000000000..82c4b579f --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/src/transport/mem.rs @@ -0,0 +1,385 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::cmp; +use std::io; +use std::sync::{Arc, Mutex}; + +use super::{ReadHalf, TIoChannel, WriteHalf}; + +/// In-memory read and write channel with fixed-size read and write buffers. +/// +/// On a `write` bytes are written to the internal write buffer. Writes are no +/// longer accepted once this buffer is full. Callers must `empty_write_buffer()` +/// before subsequent writes are accepted. +/// +/// You can set readable bytes in the internal read buffer by filling it with +/// `set_readable_bytes(...)`. Callers can then read until the buffer is +/// depleted. No further reads are accepted until the internal read buffer is +/// replenished again. +#[derive(Debug)] +pub struct TBufferChannel { + read: Arc<Mutex<ReadData>>, + write: Arc<Mutex<WriteData>>, +} + +#[derive(Debug)] +struct ReadData { + buf: Box<[u8]>, + pos: usize, + idx: usize, + cap: usize, +} + +#[derive(Debug)] +struct WriteData { + buf: Box<[u8]>, + pos: usize, + cap: usize, +} + +impl TBufferChannel { + /// Constructs a new, empty `TBufferChannel` with the given + /// read buffer capacity and write buffer capacity. + pub fn with_capacity(read_capacity: usize, write_capacity: usize) -> TBufferChannel { + TBufferChannel { + read: Arc::new(Mutex::new(ReadData { + buf: vec![0; read_capacity].into_boxed_slice(), + idx: 0, + pos: 0, + cap: read_capacity, + })), + write: Arc::new(Mutex::new(WriteData { + buf: vec![0; write_capacity].into_boxed_slice(), + pos: 0, + cap: write_capacity, + })), + } + } + + /// Return a copy of the bytes held by the internal read buffer. + /// Returns an empty vector if no readable bytes are present. + pub fn read_bytes(&self) -> Vec<u8> { + let rdata = self.read.as_ref().lock().unwrap(); + let mut buf = vec![0u8; rdata.idx]; + buf.copy_from_slice(&rdata.buf[..rdata.idx]); + buf + } + + // FIXME: do I really need this API call? + // FIXME: should this simply reset to the last set of readable bytes? + /// Reset the number of readable bytes to zero. + /// + /// Subsequent calls to `read` will return nothing. + pub fn empty_read_buffer(&mut self) { + let mut rdata = self.read.as_ref().lock().unwrap(); + rdata.pos = 0; + rdata.idx = 0; + } + + /// Copy bytes from the source buffer `buf` into the internal read buffer, + /// overwriting any existing bytes. Returns the number of bytes copied, + /// which is `min(buf.len(), internal_read_buf.len())`. + pub fn set_readable_bytes(&mut self, buf: &[u8]) -> usize { + self.empty_read_buffer(); + let mut rdata = self.read.as_ref().lock().unwrap(); + let max_bytes = cmp::min(rdata.cap, buf.len()); + rdata.buf[..max_bytes].clone_from_slice(&buf[..max_bytes]); + rdata.idx = max_bytes; + max_bytes + } + + /// Return a copy of the bytes held by the internal write buffer. + /// Returns an empty vector if no bytes were written. + pub fn write_bytes(&self) -> Vec<u8> { + let wdata = self.write.as_ref().lock().unwrap(); + let mut buf = vec![0u8; wdata.pos]; + buf.copy_from_slice(&wdata.buf[..wdata.pos]); + buf + } + + /// Resets the internal write buffer, making it seem like no bytes were + /// written. Calling `write_buffer` after this returns an empty vector. + pub fn empty_write_buffer(&mut self) { + let mut wdata = self.write.as_ref().lock().unwrap(); + wdata.pos = 0; + } + + /// Overwrites the contents of the read buffer with the contents of the + /// write buffer. The write buffer is emptied after this operation. + pub fn copy_write_buffer_to_read_buffer(&mut self) { + // FIXME: redo this entire method + let buf = { + let wdata = self.write.as_ref().lock().unwrap(); + let b = &wdata.buf[..wdata.pos]; + let mut b_ret = vec![0; b.len()]; + b_ret.copy_from_slice(b); + b_ret + }; + + let bytes_copied = self.set_readable_bytes(&buf); + assert_eq!(bytes_copied, buf.len()); + + self.empty_write_buffer(); + } +} + +impl TIoChannel for TBufferChannel { + fn split(self) -> ::Result<(ReadHalf<Self>, WriteHalf<Self>)> + where + Self: Sized, + { + Ok(( + ReadHalf { + handle: TBufferChannel { + read: self.read.clone(), + write: self.write.clone(), + }, + }, + WriteHalf { + handle: TBufferChannel { + read: self.read.clone(), + write: self.write.clone(), + }, + }, + )) + } +} + +impl io::Read for TBufferChannel { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + let mut rdata = self.read.as_ref().lock().unwrap(); + let nread = cmp::min(buf.len(), rdata.idx - rdata.pos); + buf[..nread].clone_from_slice(&rdata.buf[rdata.pos..rdata.pos + nread]); + rdata.pos += nread; + Ok(nread) + } +} + +impl io::Write for TBufferChannel { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + let mut wdata = self.write.as_ref().lock().unwrap(); + let nwrite = cmp::min(buf.len(), wdata.cap - wdata.pos); + let (start, end) = (wdata.pos, wdata.pos + nwrite); + wdata.buf[start..end].clone_from_slice(&buf[..nwrite]); + wdata.pos += nwrite; + Ok(nwrite) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) // nothing to do on flush + } +} + +#[cfg(test)] +mod tests { + use std::io::{Read, Write}; + + use super::TBufferChannel; + + #[test] + fn must_empty_write_buffer() { + let mut t = TBufferChannel::with_capacity(0, 1); + + let bytes_to_write: [u8; 1] = [0x01]; + let result = t.write(&bytes_to_write); + assert_eq!(result.unwrap(), 1); + assert_eq!(&t.write_bytes(), &bytes_to_write); + + t.empty_write_buffer(); + assert_eq!(t.write_bytes().len(), 0); + } + + #[test] + fn must_accept_writes_after_buffer_emptied() { + let mut t = TBufferChannel::with_capacity(0, 2); + + let bytes_to_write: [u8; 2] = [0x01, 0x02]; + + // first write (all bytes written) + let result = t.write(&bytes_to_write); + assert_eq!(result.unwrap(), 2); + assert_eq!(&t.write_bytes(), &bytes_to_write); + + // try write again (nothing should be written) + let result = t.write(&bytes_to_write); + assert_eq!(result.unwrap(), 0); + assert_eq!(&t.write_bytes(), &bytes_to_write); // still the same as before + + // now reset the buffer + t.empty_write_buffer(); + assert_eq!(t.write_bytes().len(), 0); + + // now try write again - the write should succeed + let result = t.write(&bytes_to_write); + assert_eq!(result.unwrap(), 2); + assert_eq!(&t.write_bytes(), &bytes_to_write); + } + + #[test] + fn must_accept_multiple_writes_until_buffer_is_full() { + let mut t = TBufferChannel::with_capacity(0, 10); + + // first write (all bytes written) + let bytes_to_write_0: [u8; 2] = [0x01, 0x41]; + let write_0_result = t.write(&bytes_to_write_0); + assert_eq!(write_0_result.unwrap(), 2); + assert_eq!(t.write_bytes(), &bytes_to_write_0); + + // second write (all bytes written, starting at index 2) + let bytes_to_write_1: [u8; 7] = [0x24, 0x41, 0x32, 0x33, 0x11, 0x98, 0xAF]; + let write_1_result = t.write(&bytes_to_write_1); + assert_eq!(write_1_result.unwrap(), 7); + assert_eq!(&t.write_bytes()[2..], &bytes_to_write_1); + + // third write (only 1 byte written - that's all we have space for) + let bytes_to_write_2: [u8; 3] = [0xBF, 0xDA, 0x98]; + let write_2_result = t.write(&bytes_to_write_2); + assert_eq!(write_2_result.unwrap(), 1); + assert_eq!(&t.write_bytes()[9..], &bytes_to_write_2[0..1]); // how does this syntax work?! + + // fourth write (no writes are accepted) + let bytes_to_write_3: [u8; 3] = [0xBF, 0xAA, 0xFD]; + let write_3_result = t.write(&bytes_to_write_3); + assert_eq!(write_3_result.unwrap(), 0); + + // check the full write buffer + let mut expected: Vec<u8> = Vec::with_capacity(10); + expected.extend_from_slice(&bytes_to_write_0); + expected.extend_from_slice(&bytes_to_write_1); + expected.extend_from_slice(&bytes_to_write_2[0..1]); + assert_eq!(t.write_bytes(), &expected[..]); + } + + #[test] + fn must_empty_read_buffer() { + let mut t = TBufferChannel::with_capacity(1, 0); + + let bytes_to_read: [u8; 1] = [0x01]; + let result = t.set_readable_bytes(&bytes_to_read); + assert_eq!(result, 1); + assert_eq!(t.read_bytes(), &bytes_to_read); + + t.empty_read_buffer(); + assert_eq!(t.read_bytes().len(), 0); + } + + #[test] + fn must_allow_readable_bytes_to_be_set_after_read_buffer_emptied() { + let mut t = TBufferChannel::with_capacity(1, 0); + + let bytes_to_read_0: [u8; 1] = [0x01]; + let result = t.set_readable_bytes(&bytes_to_read_0); + assert_eq!(result, 1); + assert_eq!(t.read_bytes(), &bytes_to_read_0); + + t.empty_read_buffer(); + assert_eq!(t.read_bytes().len(), 0); + + let bytes_to_read_1: [u8; 1] = [0x02]; + let result = t.set_readable_bytes(&bytes_to_read_1); + assert_eq!(result, 1); + assert_eq!(t.read_bytes(), &bytes_to_read_1); + } + + #[test] + fn must_accept_multiple_reads_until_all_bytes_read() { + let mut t = TBufferChannel::with_capacity(10, 0); + + let readable_bytes: [u8; 10] = [0xFF, 0xEE, 0xDD, 0xCC, 0xBB, 0x00, 0x1A, 0x2B, 0x3C, 0x4D]; + + // check that we're able to set the bytes to be read + let result = t.set_readable_bytes(&readable_bytes); + assert_eq!(result, 10); + assert_eq!(t.read_bytes(), &readable_bytes); + + // first read + let mut read_buf_0 = vec![0; 5]; + let read_result = t.read(&mut read_buf_0); + assert_eq!(read_result.unwrap(), 5); + assert_eq!(read_buf_0.as_slice(), &(readable_bytes[0..5])); + + // second read + let mut read_buf_1 = vec![0; 4]; + let read_result = t.read(&mut read_buf_1); + assert_eq!(read_result.unwrap(), 4); + assert_eq!(read_buf_1.as_slice(), &(readable_bytes[5..9])); + + // third read (only 1 byte remains to be read) + let mut read_buf_2 = vec![0; 3]; + let read_result = t.read(&mut read_buf_2); + assert_eq!(read_result.unwrap(), 1); + read_buf_2.truncate(1); // FIXME: does the caller have to do this? + assert_eq!(read_buf_2.as_slice(), &(readable_bytes[9..])); + + // fourth read (nothing should be readable) + let mut read_buf_3 = vec![0; 10]; + let read_result = t.read(&mut read_buf_3); + assert_eq!(read_result.unwrap(), 0); + read_buf_3.truncate(0); + + // check that all the bytes we received match the original (again!) + let mut bytes_read = Vec::with_capacity(10); + bytes_read.extend_from_slice(&read_buf_0); + bytes_read.extend_from_slice(&read_buf_1); + bytes_read.extend_from_slice(&read_buf_2); + bytes_read.extend_from_slice(&read_buf_3); + assert_eq!(&bytes_read, &readable_bytes); + } + + #[test] + fn must_allow_reads_to_succeed_after_read_buffer_replenished() { + let mut t = TBufferChannel::with_capacity(3, 0); + + let readable_bytes_0: [u8; 3] = [0x02, 0xAB, 0x33]; + + // check that we're able to set the bytes to be read + let result = t.set_readable_bytes(&readable_bytes_0); + assert_eq!(result, 3); + assert_eq!(t.read_bytes(), &readable_bytes_0); + + let mut read_buf = vec![0; 4]; + + // drain the read buffer + let read_result = t.read(&mut read_buf); + assert_eq!(read_result.unwrap(), 3); + assert_eq!(t.read_bytes(), &read_buf[0..3]); + + // check that a subsequent read fails + let read_result = t.read(&mut read_buf); + assert_eq!(read_result.unwrap(), 0); + + // we don't modify the read buffer on failure + let mut expected_bytes = Vec::with_capacity(4); + expected_bytes.extend_from_slice(&readable_bytes_0); + expected_bytes.push(0x00); + assert_eq!(&read_buf, &expected_bytes); + + // replenish the read buffer again + let readable_bytes_1: [u8; 2] = [0x91, 0xAA]; + + // check that we're able to set the bytes to be read + let result = t.set_readable_bytes(&readable_bytes_1); + assert_eq!(result, 2); + assert_eq!(t.read_bytes(), &readable_bytes_1); + + // read again + let read_result = t.read(&mut read_buf); + assert_eq!(read_result.unwrap(), 2); + assert_eq!(t.read_bytes(), &read_buf[0..2]); + } +} diff --git a/src/jaegertracing/thrift/lib/rs/src/transport/mod.rs b/src/jaegertracing/thrift/lib/rs/src/transport/mod.rs new file mode 100644 index 000000000..32c07998a --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/src/transport/mod.rs @@ -0,0 +1,291 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Types used to send and receive bytes over an I/O channel. +//! +//! The core types are the `TReadTransport`, `TWriteTransport` and the +//! `TIoChannel` traits, through which `TInputProtocol` or +//! `TOutputProtocol` can receive and send primitives over the wire. While +//! `TInputProtocol` and `TOutputProtocol` instances deal with language primitives +//! the types in this module understand only bytes. + +use std::io; +use std::io::{Read, Write}; +use std::ops::{Deref, DerefMut}; + +#[cfg(test)] +macro_rules! assert_eq_transport_num_written_bytes { + ($transport:ident, $num_written_bytes:expr) => {{ + assert_eq!($transport.channel.write_bytes().len(), $num_written_bytes); + }}; +} + +#[cfg(test)] +macro_rules! assert_eq_transport_written_bytes { + ($transport:ident, $expected_bytes:ident) => {{ + assert_eq!($transport.channel.write_bytes(), &$expected_bytes); + }}; +} + +mod buffered; +mod framed; +mod mem; +mod socket; + +pub use self::buffered::{ + TBufferedReadTransport, TBufferedReadTransportFactory, TBufferedWriteTransport, + TBufferedWriteTransportFactory, +}; +pub use self::framed::{ + TFramedReadTransport, TFramedReadTransportFactory, TFramedWriteTransport, + TFramedWriteTransportFactory, +}; +pub use self::mem::TBufferChannel; +pub use self::socket::TTcpChannel; + +/// Identifies a transport used by a `TInputProtocol` to receive bytes. +pub trait TReadTransport: Read {} + +/// Helper type used by a server to create `TReadTransport` instances for +/// accepted client connections. +pub trait TReadTransportFactory { + /// Create a `TTransport` that wraps a channel over which bytes are to be read. + fn create(&self, channel: Box<dyn Read + Send>) -> Box<dyn TReadTransport + Send>; +} + +/// Identifies a transport used by `TOutputProtocol` to send bytes. +pub trait TWriteTransport: Write {} + +/// Helper type used by a server to create `TWriteTransport` instances for +/// accepted client connections. +pub trait TWriteTransportFactory { + /// Create a `TTransport` that wraps a channel over which bytes are to be sent. + fn create(&self, channel: Box<dyn Write + Send>) -> Box<dyn TWriteTransport + Send>; +} + +impl<T> TReadTransport for T where T: Read {} + +impl<T> TWriteTransport for T where T: Write {} + +// FIXME: implement the Debug trait for boxed transports + +impl<T> TReadTransportFactory for Box<T> +where + T: TReadTransportFactory + ?Sized, +{ + fn create(&self, channel: Box<dyn Read + Send>) -> Box<dyn TReadTransport + Send> { + (**self).create(channel) + } +} + +impl<T> TWriteTransportFactory for Box<T> +where + T: TWriteTransportFactory + ?Sized, +{ + fn create(&self, channel: Box<dyn Write + Send>) -> Box<dyn TWriteTransport + Send> { + (**self).create(channel) + } +} + +/// Identifies a splittable bidirectional I/O channel used to send and receive bytes. +pub trait TIoChannel: Read + Write { + /// Split the channel into a readable half and a writable half, where the + /// readable half implements `io::Read` and the writable half implements + /// `io::Write`. Returns `None` if the channel was not initialized, or if it + /// cannot be split safely. + /// + /// Returned halves may share the underlying OS channel or buffer resources. + /// Implementations **should ensure** that these two halves can be safely + /// used independently by concurrent threads. + fn split(self) -> ::Result<(::transport::ReadHalf<Self>, ::transport::WriteHalf<Self>)> + where + Self: Sized; +} + +/// The readable half of an object returned from `TIoChannel::split`. +#[derive(Debug)] +pub struct ReadHalf<C> +where + C: Read, +{ + handle: C, +} + +/// The writable half of an object returned from `TIoChannel::split`. +#[derive(Debug)] +pub struct WriteHalf<C> +where + C: Write, +{ + handle: C, +} + +impl<C> ReadHalf<C> +where + C: Read, +{ + /// Create a `ReadHalf` associated with readable `handle` + pub fn new(handle: C) -> ReadHalf<C> { + ReadHalf { handle } + } +} + +impl<C> WriteHalf<C> +where + C: Write, +{ + /// Create a `WriteHalf` associated with writable `handle` + pub fn new(handle: C) -> WriteHalf<C> { + WriteHalf { handle } + } +} + +impl<C> Read for ReadHalf<C> +where + C: Read, +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + self.handle.read(buf) + } +} + +impl<C> Write for WriteHalf<C> +where + C: Write, +{ + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + self.handle.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.handle.flush() + } +} + +impl<C> Deref for ReadHalf<C> +where + C: Read, +{ + type Target = C; + + fn deref(&self) -> &Self::Target { + &self.handle + } +} + +impl<C> DerefMut for ReadHalf<C> +where + C: Read, +{ + fn deref_mut(&mut self) -> &mut C { + &mut self.handle + } +} + +impl<C> Deref for WriteHalf<C> +where + C: Write, +{ + type Target = C; + + fn deref(&self) -> &Self::Target { + &self.handle + } +} + +impl<C> DerefMut for WriteHalf<C> +where + C: Write, +{ + fn deref_mut(&mut self) -> &mut C { + &mut self.handle + } +} + +#[cfg(test)] +mod tests { + + use std::io::Cursor; + + use super::*; + + #[test] + fn must_create_usable_read_channel_from_concrete_read_type() { + let r = Cursor::new([0, 1, 2]); + let _ = TBufferedReadTransport::new(r); + } + + #[test] + fn must_create_usable_read_channel_from_boxed_read() { + let r: Box<dyn Read> = Box::new(Cursor::new([0, 1, 2])); + let _ = TBufferedReadTransport::new(r); + } + + #[test] + fn must_create_usable_write_channel_from_concrete_write_type() { + let w = vec![0u8; 10]; + let _ = TBufferedWriteTransport::new(w); + } + + #[test] + fn must_create_usable_write_channel_from_boxed_write() { + let w: Box<dyn Write> = Box::new(vec![0u8; 10]); + let _ = TBufferedWriteTransport::new(w); + } + + #[test] + fn must_create_usable_read_transport_from_concrete_read_transport() { + let r = Cursor::new([0, 1, 2]); + let mut t = TBufferedReadTransport::new(r); + takes_read_transport(&mut t) + } + + #[test] + fn must_create_usable_read_transport_from_boxed_read() { + let r = Cursor::new([0, 1, 2]); + let mut t: Box<dyn TReadTransport> = Box::new(TBufferedReadTransport::new(r)); + takes_read_transport(&mut t) + } + + #[test] + fn must_create_usable_write_transport_from_concrete_write_transport() { + let w = vec![0u8; 10]; + let mut t = TBufferedWriteTransport::new(w); + takes_write_transport(&mut t) + } + + #[test] + fn must_create_usable_write_transport_from_boxed_write() { + let w = vec![0u8; 10]; + let mut t: Box<dyn TWriteTransport> = Box::new(TBufferedWriteTransport::new(w)); + takes_write_transport(&mut t) + } + + fn takes_read_transport<R>(t: &mut R) + where + R: TReadTransport, + { + t.bytes(); + } + + fn takes_write_transport<W>(t: &mut W) + where + W: TWriteTransport, + { + t.flush().unwrap(); + } +} diff --git a/src/jaegertracing/thrift/lib/rs/src/transport/socket.rs b/src/jaegertracing/thrift/lib/rs/src/transport/socket.rs new file mode 100644 index 000000000..0bef67bed --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/src/transport/socket.rs @@ -0,0 +1,168 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::convert::From; +use std::io; +use std::io::{ErrorKind, Read, Write}; +use std::net::{Shutdown, TcpStream}; + +use super::{ReadHalf, TIoChannel, WriteHalf}; +use {new_transport_error, TransportErrorKind}; + +/// Bidirectional TCP/IP channel. +/// +/// # Examples +/// +/// Create a `TTcpChannel`. +/// +/// ```no_run +/// use std::io::{Read, Write}; +/// use thrift::transport::TTcpChannel; +/// +/// let mut c = TTcpChannel::new(); +/// c.open("localhost:9090").unwrap(); +/// +/// let mut buf = vec![0u8; 4]; +/// c.read(&mut buf).unwrap(); +/// c.write(&vec![0, 1, 2]).unwrap(); +/// ``` +/// +/// Create a `TTcpChannel` by wrapping an existing `TcpStream`. +/// +/// ```no_run +/// use std::io::{Read, Write}; +/// use std::net::TcpStream; +/// use thrift::transport::TTcpChannel; +/// +/// let stream = TcpStream::connect("127.0.0.1:9189").unwrap(); +/// +/// // no need to call c.open() since we've already connected above +/// let mut c = TTcpChannel::with_stream(stream); +/// +/// let mut buf = vec![0u8; 4]; +/// c.read(&mut buf).unwrap(); +/// c.write(&vec![0, 1, 2]).unwrap(); +/// ``` +#[derive(Debug, Default)] +pub struct TTcpChannel { + stream: Option<TcpStream>, +} + +impl TTcpChannel { + /// Create an uninitialized `TTcpChannel`. + /// + /// The returned instance must be opened using `TTcpChannel::open(...)` + /// before it can be used. + pub fn new() -> TTcpChannel { + TTcpChannel { stream: None } + } + + /// Create a `TTcpChannel` that wraps an existing `TcpStream`. + /// + /// The passed-in stream is assumed to have been opened before being wrapped + /// by the created `TTcpChannel` instance. + pub fn with_stream(stream: TcpStream) -> TTcpChannel { + TTcpChannel { + stream: Some(stream), + } + } + + /// Connect to `remote_address`, which should have the form `host:port`. + pub fn open(&mut self, remote_address: &str) -> ::Result<()> { + if self.stream.is_some() { + Err(new_transport_error( + TransportErrorKind::AlreadyOpen, + "tcp connection previously opened", + )) + } else { + match TcpStream::connect(&remote_address) { + Ok(s) => { + self.stream = Some(s); + Ok(()) + } + Err(e) => Err(From::from(e)), + } + } + } + + /// Shut down this channel. + /// + /// Both send and receive halves are closed, and this instance can no + /// longer be used to communicate with another endpoint. + pub fn close(&mut self) -> ::Result<()> { + self.if_set(|s| s.shutdown(Shutdown::Both)) + .map_err(From::from) + } + + fn if_set<F, T>(&mut self, mut stream_operation: F) -> io::Result<T> + where + F: FnMut(&mut TcpStream) -> io::Result<T>, + { + if let Some(ref mut s) = self.stream { + stream_operation(s) + } else { + Err(io::Error::new( + ErrorKind::NotConnected, + "tcp endpoint not connected", + )) + } + } +} + +impl TIoChannel for TTcpChannel { + fn split(self) -> ::Result<(ReadHalf<Self>, WriteHalf<Self>)> + where + Self: Sized, + { + let mut s = self; + + s.stream + .as_mut() + .and_then(|s| s.try_clone().ok()) + .map(|cloned| { + let read_half = ReadHalf::new(TTcpChannel { + stream: s.stream.take(), + }); + let write_half = WriteHalf::new(TTcpChannel { + stream: Some(cloned), + }); + (read_half, write_half) + }) + .ok_or_else(|| { + new_transport_error( + TransportErrorKind::Unknown, + "cannot clone underlying tcp stream", + ) + }) + } +} + +impl Read for TTcpChannel { + fn read(&mut self, b: &mut [u8]) -> io::Result<usize> { + self.if_set(|s| s.read(b)) + } +} + +impl Write for TTcpChannel { + fn write(&mut self, b: &[u8]) -> io::Result<usize> { + self.if_set(|s| s.write(b)) + } + + fn flush(&mut self) -> io::Result<()> { + self.if_set(|s| s.flush()) + } +} diff --git a/src/jaegertracing/thrift/lib/rs/test/Cargo.toml b/src/jaegertracing/thrift/lib/rs/test/Cargo.toml new file mode 100644 index 000000000..dc4ffe32b --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/test/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "kitchen-sink" +version = "0.1.0" +license = "Apache-2.0" +authors = ["Apache Thrift Developers <dev@thrift.apache.org>"] +publish = false + +[dependencies] +clap = "<2.28.0" +ordered-float = "1.0" +try_from = "0.3" + +[dependencies.thrift] +path = "../" + diff --git a/src/jaegertracing/thrift/lib/rs/test/Makefile.am b/src/jaegertracing/thrift/lib/rs/test/Makefile.am new file mode 100644 index 000000000..486188cfe --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/test/Makefile.am @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +THRIFT = $(top_builddir)/compiler/cpp/thrift + +stubs: thrifts/Base_One.thrift thrifts/Base_Two.thrift thrifts/Midlayer.thrift thrifts/Ultimate.thrift $(top_builddir)/test/Recursive.thrift $(THRIFT) + $(THRIFT) -I ./thrifts -out src --gen rs thrifts/Base_One.thrift + $(THRIFT) -I ./thrifts -out src --gen rs thrifts/Base_Two.thrift + $(THRIFT) -I ./thrifts -out src --gen rs thrifts/Midlayer.thrift + $(THRIFT) -I ./thrifts -out src --gen rs thrifts/Ultimate.thrift + $(THRIFT) -out src --gen rs $(top_builddir)/test/Recursive.thrift + $(THRIFT) -out src --gen rs $(top_builddir)/test/Identifiers.thrift #THRIFT-4953 + +check: stubs + $(CARGO) build + $(CARGO) test + [ -d bin ] || mkdir bin + cp target/debug/kitchen_sink_server bin/kitchen_sink_server + cp target/debug/kitchen_sink_client bin/kitchen_sink_client + +clean-local: + $(CARGO) clean + -$(RM) Cargo.lock + -$(RM) src/base_one.rs + -$(RM) src/base_two.rs + -$(RM) src/midlayer.rs + -$(RM) src/ultimate.rs + -$(RM) -r bin + +EXTRA_DIST = \ + Cargo.toml \ + thrifts/Base_One.thrift \ + thrifts/Base_Two.thrift \ + thrifts/Midlayer.thrift \ + thrifts/Ultimate.thrift \ + src/lib.rs \ + src/bin/kitchen_sink_server.rs \ + src/bin/kitchen_sink_client.rs + diff --git a/src/jaegertracing/thrift/lib/rs/test/src/bin/kitchen_sink_client.rs b/src/jaegertracing/thrift/lib/rs/test/src/bin/kitchen_sink_client.rs new file mode 100644 index 000000000..d295c8870 --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/test/src/bin/kitchen_sink_client.rs @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate clap; + +extern crate kitchen_sink; +extern crate thrift; + +use std::convert::Into; + +use kitchen_sink::base_two::{TNapkinServiceSyncClient, TRamenServiceSyncClient}; +use kitchen_sink::midlayer::{MealServiceSyncClient, TMealServiceSyncClient}; +use kitchen_sink::recursive; +use kitchen_sink::recursive::{CoRec, CoRec2, RecList, RecTree, TTestServiceSyncClient}; +use kitchen_sink::ultimate::{FullMealServiceSyncClient, TFullMealServiceSyncClient}; +use thrift::protocol::{ + TBinaryInputProtocol, TBinaryOutputProtocol, TCompactInputProtocol, TCompactOutputProtocol, + TInputProtocol, TOutputProtocol, +}; +use thrift::transport::{ + ReadHalf, TFramedReadTransport, TFramedWriteTransport, TIoChannel, TTcpChannel, WriteHalf, +}; + +fn main() { + match run() { + Ok(()) => println!("kitchen sink client completed successfully"), + Err(e) => { + println!("kitchen sink client failed with error {:?}", e); + std::process::exit(1); + } + } +} + +fn run() -> thrift::Result<()> { + let matches = clap_app!(rust_kitchen_sink_client => + (version: "0.1.0") + (author: "Apache Thrift Developers <dev@thrift.apache.org>") + (about: "Thrift Rust kitchen sink client") + (@arg host: --host +takes_value "Host on which the Thrift test server is located") + (@arg port: --port +takes_value "Port on which the Thrift test server is listening") + (@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\")") + (@arg service: --service +takes_value "Service type to contact (\"part\", \"full\", \"recursive\")") + ) + .get_matches(); + + let host = matches.value_of("host").unwrap_or("127.0.0.1"); + let port = value_t!(matches, "port", u16).unwrap_or(9090); + let protocol = matches.value_of("protocol").unwrap_or("compact"); + let service = matches.value_of("service").unwrap_or("part"); + + let (i_chan, o_chan) = tcp_channel(host, port)?; + let (i_tran, o_tran) = ( + TFramedReadTransport::new(i_chan), + TFramedWriteTransport::new(o_chan), + ); + + let (i_prot, o_prot): (Box<TInputProtocol>, Box<TOutputProtocol>) = match protocol { + "binary" => ( + Box::new(TBinaryInputProtocol::new(i_tran, true)), + Box::new(TBinaryOutputProtocol::new(o_tran, true)), + ), + "compact" => ( + Box::new(TCompactInputProtocol::new(i_tran)), + Box::new(TCompactOutputProtocol::new(o_tran)), + ), + unmatched => return Err(format!("unsupported protocol {}", unmatched).into()), + }; + + run_client(service, i_prot, o_prot) +} + +fn run_client( + service: &str, + i_prot: Box<TInputProtocol>, + o_prot: Box<TOutputProtocol>, +) -> thrift::Result<()> { + match service { + "full" => exec_full_meal_client(i_prot, o_prot), + "part" => exec_meal_client(i_prot, o_prot), + "recursive" => exec_recursive_client(i_prot, o_prot), + _ => Err(thrift::Error::from(format!( + "unknown service type {}", + service + ))), + } +} + +fn tcp_channel( + host: &str, + port: u16, +) -> thrift::Result<(ReadHalf<TTcpChannel>, WriteHalf<TTcpChannel>)> { + let mut c = TTcpChannel::new(); + c.open(&format!("{}:{}", host, port))?; + c.split() +} + +fn exec_meal_client( + i_prot: Box<TInputProtocol>, + o_prot: Box<TOutputProtocol>, +) -> thrift::Result<()> { + let mut client = MealServiceSyncClient::new(i_prot, o_prot); + + // client.full_meal(); // <-- IMPORTANT: if you uncomment this, compilation *should* fail + // this is because the MealService struct does not contain the appropriate service marker + + // only the following three calls work + execute_call("part", "ramen", || client.ramen(50)).map(|_| ())?; + execute_call("part", "meal", || client.meal()).map(|_| ())?; + execute_call("part", "napkin", || client.napkin()).map(|_| ())?; + + Ok(()) +} + +fn exec_full_meal_client( + i_prot: Box<TInputProtocol>, + o_prot: Box<TOutputProtocol>, +) -> thrift::Result<()> { + let mut client = FullMealServiceSyncClient::new(i_prot, o_prot); + + execute_call("full", "ramen", || client.ramen(100)).map(|_| ())?; + execute_call("full", "meal", || client.meal()).map(|_| ())?; + execute_call("full", "napkin", || client.napkin()).map(|_| ())?; + execute_call("full", "full meal", || client.full_meal()).map(|_| ())?; + + Ok(()) +} + +fn exec_recursive_client( + i_prot: Box<TInputProtocol>, + o_prot: Box<TOutputProtocol>, +) -> thrift::Result<()> { + let mut client = recursive::TestServiceSyncClient::new(i_prot, o_prot); + + let tree = RecTree { + children: Some(vec![Box::new(RecTree { + children: Some(vec![ + Box::new(RecTree { + children: None, + item: Some(3), + }), + Box::new(RecTree { + children: None, + item: Some(4), + }), + ]), + item: Some(2), + })]), + item: Some(1), + }; + + let expected_tree = RecTree { + children: Some(vec![Box::new(RecTree { + children: Some(vec![ + Box::new(RecTree { + children: Some(Vec::new()), // remote returns an empty list + item: Some(3), + }), + Box::new(RecTree { + children: Some(Vec::new()), // remote returns an empty list + item: Some(4), + }), + ]), + item: Some(2), + })]), + item: Some(1), + }; + + let returned_tree = execute_call("recursive", "echo_tree", || client.echo_tree(tree.clone()))?; + if returned_tree != expected_tree { + return Err(format!( + "mismatched recursive tree {:?} {:?}", + expected_tree, returned_tree + ) + .into()); + } + + let list = RecList { + nextitem: Some(Box::new(RecList { + nextitem: Some(Box::new(RecList { + nextitem: None, + item: Some(3), + })), + item: Some(2), + })), + item: Some(1), + }; + let returned_list = execute_call("recursive", "echo_list", || client.echo_list(list.clone()))?; + if returned_list != list { + return Err(format!("mismatched recursive list {:?} {:?}", list, returned_list).into()); + } + + let co_rec = CoRec { + other: Some(Box::new(CoRec2 { + other: Some(CoRec { + other: Some(Box::new(CoRec2 { other: None })), + }), + })), + }; + let returned_co_rec = execute_call("recursive", "echo_co_rec", || { + client.echo_co_rec(co_rec.clone()) + })?; + if returned_co_rec != co_rec { + return Err(format!("mismatched co_rec {:?} {:?}", co_rec, returned_co_rec).into()); + } + + Ok(()) +} + +fn execute_call<F, R>(service_type: &str, call_name: &str, mut f: F) -> thrift::Result<R> +where + F: FnMut() -> thrift::Result<R>, +{ + let res = f(); + + match res { + Ok(_) => println!("{}: completed {} call", service_type, call_name), + Err(ref e) => println!( + "{}: failed {} call with error {:?}", + service_type, call_name, e + ), + } + + res +} diff --git a/src/jaegertracing/thrift/lib/rs/test/src/bin/kitchen_sink_server.rs b/src/jaegertracing/thrift/lib/rs/test/src/bin/kitchen_sink_server.rs new file mode 100644 index 000000000..73801eaf8 --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/test/src/bin/kitchen_sink_server.rs @@ -0,0 +1,313 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate clap; +extern crate kitchen_sink; +extern crate thrift; + +use thrift::protocol::{ + TBinaryInputProtocolFactory, TBinaryOutputProtocolFactory, TCompactInputProtocolFactory, + TCompactOutputProtocolFactory, TInputProtocolFactory, TOutputProtocolFactory, +}; +use thrift::server::TServer; +use thrift::transport::{ + TFramedReadTransportFactory, TFramedWriteTransportFactory, TReadTransportFactory, + TWriteTransportFactory, +}; + +use kitchen_sink::base_one::Noodle; +use kitchen_sink::base_two::{ + BrothType, Napkin, NapkinServiceSyncHandler, Ramen, RamenServiceSyncHandler, +}; +use kitchen_sink::midlayer::{ + Dessert, Meal, MealServiceSyncHandler, MealServiceSyncProcessor, Pie, +}; +use kitchen_sink::recursive; +use kitchen_sink::ultimate::FullMealAndDrinksServiceSyncHandler; +use kitchen_sink::ultimate::{ + Drink, FullMeal, FullMealAndDrinks, FullMealAndDrinksServiceSyncProcessor, + FullMealServiceSyncHandler, +}; + +fn main() { + match run() { + Ok(()) => println!("kitchen sink server completed successfully"), + Err(e) => { + println!("kitchen sink server failed with error {:?}", e); + std::process::exit(1); + } + } +} + +fn run() -> thrift::Result<()> { + let matches = clap_app!(rust_kitchen_sink_server => + (version: "0.1.0") + (author: "Apache Thrift Developers <dev@thrift.apache.org>") + (about: "Thrift Rust kitchen sink test server") + (@arg port: --port +takes_value "port on which the test server listens") + (@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\")") + (@arg service: --service +takes_value "Service type to contact (\"part\", \"full\", \"recursive\")") + ) + .get_matches(); + + let port = value_t!(matches, "port", u16).unwrap_or(9090); + let protocol = matches.value_of("protocol").unwrap_or("compact"); + let service = matches.value_of("service").unwrap_or("part"); + let listen_address = format!("127.0.0.1:{}", port); + + println!("binding to {}", listen_address); + + let r_transport_factory = TFramedReadTransportFactory::new(); + let w_transport_factory = TFramedWriteTransportFactory::new(); + + let (i_protocol_factory, o_protocol_factory): ( + Box<TInputProtocolFactory>, + Box<TOutputProtocolFactory>, + ) = match &*protocol { + "binary" => ( + Box::new(TBinaryInputProtocolFactory::new()), + Box::new(TBinaryOutputProtocolFactory::new()), + ), + "compact" => ( + Box::new(TCompactInputProtocolFactory::new()), + Box::new(TCompactOutputProtocolFactory::new()), + ), + unknown => { + return Err(format!("unsupported transport type {}", unknown).into()); + } + }; + + // FIXME: should processor be boxed as well? + // + // [sigh] I hate Rust generics implementation + // + // I would have preferred to build a server here, return it, and then do + // the common listen-and-handle stuff, but since the server doesn't have a + // common type (because each match arm instantiates a server with a + // different processor) this isn't possible. + // + // Since what I'm doing is uncommon I'm just going to duplicate the code + match &*service { + "part" => run_meal_server( + &listen_address, + r_transport_factory, + i_protocol_factory, + w_transport_factory, + o_protocol_factory, + ), + "full" => run_full_meal_server( + &listen_address, + r_transport_factory, + i_protocol_factory, + w_transport_factory, + o_protocol_factory, + ), + "recursive" => run_recursive_server( + &listen_address, + r_transport_factory, + i_protocol_factory, + w_transport_factory, + o_protocol_factory, + ), + unknown => Err(format!("unsupported service type {}", unknown).into()), + } +} + +fn run_meal_server<RTF, IPF, WTF, OPF>( + listen_address: &str, + r_transport_factory: RTF, + i_protocol_factory: IPF, + w_transport_factory: WTF, + o_protocol_factory: OPF, +) -> thrift::Result<()> +where + RTF: TReadTransportFactory + 'static, + IPF: TInputProtocolFactory + 'static, + WTF: TWriteTransportFactory + 'static, + OPF: TOutputProtocolFactory + 'static, +{ + let processor = MealServiceSyncProcessor::new(PartHandler {}); + let mut server = TServer::new( + r_transport_factory, + i_protocol_factory, + w_transport_factory, + o_protocol_factory, + processor, + 1, + ); + + server.listen(listen_address) +} + +fn run_full_meal_server<RTF, IPF, WTF, OPF>( + listen_address: &str, + r_transport_factory: RTF, + i_protocol_factory: IPF, + w_transport_factory: WTF, + o_protocol_factory: OPF, +) -> thrift::Result<()> +where + RTF: TReadTransportFactory + 'static, + IPF: TInputProtocolFactory + 'static, + WTF: TWriteTransportFactory + 'static, + OPF: TOutputProtocolFactory + 'static, +{ + let processor = FullMealAndDrinksServiceSyncProcessor::new(FullHandler {}); + let mut server = TServer::new( + r_transport_factory, + i_protocol_factory, + w_transport_factory, + o_protocol_factory, + processor, + 1, + ); + + server.listen(listen_address) +} + +struct PartHandler; + +impl MealServiceSyncHandler for PartHandler { + fn handle_meal(&self) -> thrift::Result<Meal> { + println!("part: handling meal call"); + Ok(meal()) + } +} + +impl RamenServiceSyncHandler for PartHandler { + fn handle_ramen(&self, _: i32) -> thrift::Result<Ramen> { + println!("part: handling ramen call"); + Ok(ramen()) + } +} + +impl NapkinServiceSyncHandler for PartHandler { + fn handle_napkin(&self) -> thrift::Result<Napkin> { + println!("part: handling napkin call"); + Ok(napkin()) + } +} + +// full service +// + +struct FullHandler; + +impl FullMealAndDrinksServiceSyncHandler for FullHandler { + fn handle_full_meal_and_drinks(&self) -> thrift::Result<FullMealAndDrinks> { + println!("full_meal_and_drinks: handling full meal and drinks call"); + Ok(FullMealAndDrinks::new(full_meal(), Drink::CanadianWhisky)) + } + + fn handle_best_pie(&self) -> thrift::Result<Pie> { + println!("full_meal_and_drinks: handling pie call"); + Ok(Pie::MississippiMud) // I prefer Pie::Pumpkin, but I have to check that casing works + } +} + +impl FullMealServiceSyncHandler for FullHandler { + fn handle_full_meal(&self) -> thrift::Result<FullMeal> { + println!("full: handling full meal call"); + Ok(full_meal()) + } +} + +impl MealServiceSyncHandler for FullHandler { + fn handle_meal(&self) -> thrift::Result<Meal> { + println!("full: handling meal call"); + Ok(meal()) + } +} + +impl RamenServiceSyncHandler for FullHandler { + fn handle_ramen(&self, _: i32) -> thrift::Result<Ramen> { + println!("full: handling ramen call"); + Ok(ramen()) + } +} + +impl NapkinServiceSyncHandler for FullHandler { + fn handle_napkin(&self) -> thrift::Result<Napkin> { + println!("full: handling napkin call"); + Ok(napkin()) + } +} + +fn full_meal() -> FullMeal { + FullMeal::new(meal(), Dessert::Port("Graham's Tawny".to_owned())) +} + +fn meal() -> Meal { + Meal::new(noodle(), ramen()) +} + +fn noodle() -> Noodle { + Noodle::new("spelt".to_owned(), 100) +} + +fn ramen() -> Ramen { + Ramen::new("Mr Ramen".to_owned(), 72, BrothType::Miso) +} + +fn napkin() -> Napkin { + Napkin {} +} + +fn run_recursive_server<RTF, IPF, WTF, OPF>( + listen_address: &str, + r_transport_factory: RTF, + i_protocol_factory: IPF, + w_transport_factory: WTF, + o_protocol_factory: OPF, +) -> thrift::Result<()> +where + RTF: TReadTransportFactory + 'static, + IPF: TInputProtocolFactory + 'static, + WTF: TWriteTransportFactory + 'static, + OPF: TOutputProtocolFactory + 'static, +{ + let processor = recursive::TestServiceSyncProcessor::new(RecursiveTestServerHandler {}); + let mut server = TServer::new( + r_transport_factory, + i_protocol_factory, + w_transport_factory, + o_protocol_factory, + processor, + 1, + ); + + server.listen(listen_address) +} + +struct RecursiveTestServerHandler; +impl recursive::TestServiceSyncHandler for RecursiveTestServerHandler { + fn handle_echo_tree(&self, tree: recursive::RecTree) -> thrift::Result<recursive::RecTree> { + println!("{:?}", tree); + Ok(tree) + } + + fn handle_echo_list(&self, lst: recursive::RecList) -> thrift::Result<recursive::RecList> { + println!("{:?}", lst); + Ok(lst) + } + + fn handle_echo_co_rec(&self, item: recursive::CoRec) -> thrift::Result<recursive::CoRec> { + println!("{:?}", item); + Ok(item) + } +} diff --git a/src/jaegertracing/thrift/lib/rs/test/src/lib.rs b/src/jaegertracing/thrift/lib/rs/test/src/lib.rs new file mode 100644 index 000000000..9debdca54 --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/test/src/lib.rs @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate thrift; + +pub mod base_one; +pub mod base_two; +pub mod midlayer; +pub mod ultimate; +pub mod recursive; + +#[cfg(test)] +mod tests { + + use std::default::Default; + + use super::*; + + #[test] + fn must_be_able_to_use_constructor() { + let _ = midlayer::Meal::new(Some(base_one::Noodle::default()), None); + } + + #[test] + fn must_be_able_to_use_constructor_with_no_fields() { + let _ = midlayer::Meal::new(None, None); + } + + #[test] + fn must_be_able_to_use_constructor_without_option_wrap() { + let _ = midlayer::Meal::new(base_one::Noodle::default(), None); + } + + #[test] + fn must_be_able_to_use_defaults() { + let _ = midlayer::Meal { + noodle: Some(base_one::Noodle::default()), + ..Default::default() + }; + } +} diff --git a/src/jaegertracing/thrift/lib/rs/test/thrifts/Base_One.thrift b/src/jaegertracing/thrift/lib/rs/test/thrifts/Base_One.thrift new file mode 100644 index 000000000..c5fa6c20d --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/test/thrifts/Base_One.thrift @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * Contains some contributions under the Thrift Software License. + * Please see doc/old-thrift-license.txt in the Thrift distribution for + * details. + */ + +typedef i64 Temperature + +typedef i8 Size + +typedef string Location + +const i32 BoilingPoint = 100 + +const list<Temperature> Temperatures = [10, 11, 22, 33] + +// IMPORTANT: temps should end with ".0" because this tests +// that we don't have a problem with const float list generation +const list<double> CommonTemperatures = [300.0, 450.0] + +const double MealsPerDay = 2.5; + +const string DefaultRecipeName = "Soup-rise of the Day" +const binary DefaultRecipeBinary = "Soup-rise of the 01010101" + +struct Noodle { + 1: string flourType + 2: Temperature cookTemp +} + +struct Spaghetti { + 1: optional list<Noodle> noodles +} + +const Noodle SpeltNoodle = { "flourType": "spelt", "cookTemp": 110 } + +struct MeasuringSpoon { + 1: Size size +} + +struct MeasuringCup { + 1: double millis +} + +union MeasuringAids { + 1: MeasuringSpoon spoon + 2: MeasuringCup cup +} + +struct CookingTemperatures { + 1: set<double> commonTemperatures + 2: list<double> usedTemperatures + 3: map<double, double> fahrenheitToCentigradeConversions +} + +struct Recipe { + 1: string recipeName + 2: string cuisine + 3: i8 page +} + +union CookingTools { + 1: set<MeasuringSpoon> measuringSpoons + 2: map<Size, Location> measuringCups, + 3: list<Recipe> recipes +} + diff --git a/src/jaegertracing/thrift/lib/rs/test/thrifts/Base_Two.thrift b/src/jaegertracing/thrift/lib/rs/test/thrifts/Base_Two.thrift new file mode 100644 index 000000000..caa6acb86 --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/test/thrifts/Base_Two.thrift @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * Contains some contributions under the Thrift Software License. + * Please see doc/old-thrift-license.txt in the Thrift distribution for + * details. + */ + +const i32 WaterWeight = 200 + +enum brothType { + Miso, + shouyu, +} + +struct Ramen { + 1: optional string ramenType + 2: required i32 noodleCount + 3: brothType broth +} + +struct Napkin { + // empty +} + +service NapkinService { + Napkin napkin() +} + +service RamenService extends NapkinService { + Ramen ramen(1: i32 requestedNoodleCount) +} + +/* const struct CookedRamen = { "bar": 10 } */ + diff --git a/src/jaegertracing/thrift/lib/rs/test/thrifts/Midlayer.thrift b/src/jaegertracing/thrift/lib/rs/test/thrifts/Midlayer.thrift new file mode 100644 index 000000000..16ff49b0e --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/test/thrifts/Midlayer.thrift @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * Contains some contributions under the Thrift Software License. + * Please see doc/old-thrift-license.txt in the Thrift distribution for + * details. + */ + +include "Base_One.thrift" +include "Base_Two.thrift" + +const i32 WaterBoilingPoint = Base_One.BoilingPoint + +const map<string, Base_One.Temperature> TemperatureNames = { "freezing": 0, "boiling": 100 } + +const map<set<i32>, map<list<string>, string>> MyConstNestedMap = { + [0, 1, 2, 3]: { ["foo"]: "bar" }, + [20]: { ["nut", "ton"] : "bar" }, + [30, 40]: { ["bouncy", "tinkly"]: "castle" } +} + +const list<list<i32>> MyConstNestedList = [ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8] +] + +const set<set<i32>> MyConstNestedSet = [ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8] +] + +enum Pie { + PUMPKIN, + apple, // intentionally poorly cased + STRAWBERRY_RHUBARB, + Key_Lime, // intentionally poorly cased + coconut_Cream, // intentionally poorly cased + mississippi_mud, // intentionally poorly cased +} + +struct Meal { + 1: Base_One.Noodle noodle + 2: Base_Two.Ramen ramen +} + +union Dessert { + 1: string port + 2: string iceWine +} + +service MealService extends Base_Two.RamenService { + Meal meal() +} + diff --git a/src/jaegertracing/thrift/lib/rs/test/thrifts/Ultimate.thrift b/src/jaegertracing/thrift/lib/rs/test/thrifts/Ultimate.thrift new file mode 100644 index 000000000..72fa100a6 --- /dev/null +++ b/src/jaegertracing/thrift/lib/rs/test/thrifts/Ultimate.thrift @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * Contains some contributions under the Thrift Software License. + * Please see doc/old-thrift-license.txt in the Thrift distribution for + * details. + */ + +include "Midlayer.thrift" + +enum Drink { + WATER, + WHISKEY, + WINE, + scotch, // intentionally poorly cased + LATE_HARVEST_WINE, + India_Pale_Ale, // intentionally poorly cased + apple_cider, // intentially poorly cased + belgian_Ale, // intentionally poorly cased + Canadian_whisky, // intentionally poorly cased +} + +const map<i8, Midlayer.Pie> RankedPies = { + 1: Midlayer.Pie.PUMPKIN, + 2: Midlayer.Pie.STRAWBERRY_RHUBARB, + 3: Midlayer.Pie.apple, + 4: Midlayer.Pie.mississippi_mud, + 5: Midlayer.Pie.coconut_Cream, + 6: Midlayer.Pie.Key_Lime, +} + +struct FullMeal { + 1: required Midlayer.Meal meal + 2: required Midlayer.Dessert dessert +} + +struct FullMealAndDrinks { + 1: required FullMeal fullMeal + 2: optional Drink drink +} + +service FullMealService extends Midlayer.MealService { + FullMeal fullMeal() +} + +service FullMealAndDrinksService extends FullMealService { + FullMealAndDrinks fullMealAndDrinks() + + Midlayer.Pie bestPie() +} + |