diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/prost | |
parent | Initial commit. (diff) | |
download | firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip |
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
24 files changed, 6006 insertions, 0 deletions
diff --git a/third_party/rust/prost-derive/.cargo-checksum.json b/third_party/rust/prost-derive/.cargo-checksum.json new file mode 100644 index 0000000000..e64d4cd29b --- /dev/null +++ b/third_party/rust/prost-derive/.cargo-checksum.json @@ -0,0 +1 @@ +{"files":{"Cargo.toml":"4af10c29150183b223cd13b9eb0375fdbef6daba35e43a654c55c0491ca78ba2","README.md":"6c67fa1e48f14adfaf834f520f798ddfb79f90804f46cc215ee391a7d57913a4","src/field/group.rs":"dfd31b34008741abc857dc11362c11a06afd8eddceb62924443bf2bea117460a","src/field/map.rs":"69e1fa14cfeb4d4407b9c343a2a72b4ba68c3ecb7af912965a63b50e91dd61fb","src/field/message.rs":"fb910a563eba0b0c68a7a3855f7877535931faef4d788b05a3402f0cdd037e3e","src/field/mod.rs":"ca917a3f673623f0946e9e37ebad2916f3c1f8d163a1844143ef7d21d245394c","src/field/oneof.rs":"4fc488445b05e464070fadd8799cafb806db5c23d1494c4300cb293394863012","src/field/scalar.rs":"03d69f74daa2037a0210a086bd873ae9beca6b615e2973c96c43a6d0d30aa525","src/lib.rs":"8cfa6f1fc707e8931df82f9df67fa07e1b822b7581e4c9733e3ce24febabf708"},"package":"600d2f334aa05acb02a755e217ef1ab6dea4d51b58b7846588b747edec04efba"}
\ No newline at end of file diff --git a/third_party/rust/prost-derive/Cargo.toml b/third_party/rust/prost-derive/Cargo.toml new file mode 100644 index 0000000000..d6c2d3e74c --- /dev/null +++ b/third_party/rust/prost-derive/Cargo.toml @@ -0,0 +1,40 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies +# +# If you believe there's an error in this file please file an +# issue against the rust-lang/cargo repository. If you're +# editing this file be aware that the upstream Cargo.toml +# will likely look very different (and much more reasonable) + +[package] +edition = "2018" +name = "prost-derive" +version = "0.8.0" +authors = ["Dan Burkert <dan@danburkert.com>", "Tokio Contributors <team@tokio.rs>"] +description = "A Protocol Buffers implementation for the Rust Language." +documentation = "https://docs.rs/prost-derive" +readme = "README.md" +license = "Apache-2.0" +repository = "https://github.com/tokio-rs/prost" + +[lib] +proc_macro = true +[dependencies.anyhow] +version = "1" + +[dependencies.itertools] +version = "0.10" + +[dependencies.proc-macro2] +version = "1" + +[dependencies.quote] +version = "1" + +[dependencies.syn] +version = "1" +features = ["extra-traits"] diff --git a/third_party/rust/prost-derive/README.md b/third_party/rust/prost-derive/README.md new file mode 100644 index 0000000000..a51050e7e6 --- /dev/null +++ b/third_party/rust/prost-derive/README.md @@ -0,0 +1,16 @@ +[![Documentation](https://docs.rs/prost-derive/badge.svg)](https://docs.rs/prost-derive/) +[![Crate](https://img.shields.io/crates/v/prost-derive.svg)](https://crates.io/crates/prost-derive) + +# prost-derive + +`prost-derive` handles generating encoding and decoding implementations for Rust +types annotated with `prost` annotation. For the most part, users of `prost` +shouldn't need to interact with `prost-derive` directly. + +## License + +`prost-derive` is distributed under the terms of the Apache License (Version 2.0). + +See [LICENSE](../LICENSE) for details. + +Copyright 2017 Dan Burkert diff --git a/third_party/rust/prost-derive/src/field/group.rs b/third_party/rust/prost-derive/src/field/group.rs new file mode 100644 index 0000000000..6396128802 --- /dev/null +++ b/third_party/rust/prost-derive/src/field/group.rs @@ -0,0 +1,134 @@ +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::Meta; + +use crate::field::{set_bool, set_option, tag_attr, word_attr, Label}; + +#[derive(Clone)] +pub struct Field { + pub label: Label, + pub tag: u32, +} + +impl Field { + pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> { + let mut group = false; + let mut label = None; + let mut tag = None; + let mut boxed = false; + + let mut unknown_attrs = Vec::new(); + + for attr in attrs { + if word_attr("group", attr) { + set_bool(&mut group, "duplicate group attributes")?; + } else if word_attr("boxed", attr) { + set_bool(&mut boxed, "duplicate boxed attributes")?; + } else if let Some(t) = tag_attr(attr)? { + set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(l) = Label::from_attr(attr) { + set_option(&mut label, l, "duplicate label attributes")?; + } else { + unknown_attrs.push(attr); + } + } + + if !group { + return Ok(None); + } + + match unknown_attrs.len() { + 0 => (), + 1 => bail!("unknown attribute for group field: {:?}", unknown_attrs[0]), + _ => bail!("unknown attributes for group field: {:?}", unknown_attrs), + } + + let tag = match tag.or(inferred_tag) { + Some(tag) => tag, + None => bail!("group field is missing a tag attribute"), + }; + + Ok(Some(Field { + label: label.unwrap_or(Label::Optional), + tag, + })) + } + + pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> { + if let Some(mut field) = Field::new(attrs, None)? { + if let Some(attr) = attrs.iter().find(|attr| Label::from_attr(attr).is_some()) { + bail!( + "invalid attribute for oneof field: {}", + attr.path().into_token_stream() + ); + } + field.label = Label::Required; + Ok(Some(field)) + } else { + Ok(None) + } + } + + pub fn encode(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + match self.label { + Label::Optional => quote! { + if let Some(ref msg) = #ident { + ::prost::encoding::group::encode(#tag, msg, buf); + } + }, + Label::Required => quote! { + ::prost::encoding::group::encode(#tag, &#ident, buf); + }, + Label::Repeated => quote! { + for msg in &#ident { + ::prost::encoding::group::encode(#tag, msg, buf); + } + }, + } + } + + pub fn merge(&self, ident: TokenStream) -> TokenStream { + match self.label { + Label::Optional => quote! { + ::prost::encoding::group::merge( + tag, + wire_type, + #ident.get_or_insert_with(Default::default), + buf, + ctx, + ) + }, + Label::Required => quote! { + ::prost::encoding::group::merge(tag, wire_type, #ident, buf, ctx) + }, + Label::Repeated => quote! { + ::prost::encoding::group::merge_repeated(tag, wire_type, #ident, buf, ctx) + }, + } + } + + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + match self.label { + Label::Optional => quote! { + #ident.as_ref().map_or(0, |msg| ::prost::encoding::group::encoded_len(#tag, msg)) + }, + Label::Required => quote! { + ::prost::encoding::group::encoded_len(#tag, &#ident) + }, + Label::Repeated => quote! { + ::prost::encoding::group::encoded_len_repeated(#tag, &#ident) + }, + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + match self.label { + Label::Optional => quote!(#ident = ::core::option::Option::None), + Label::Required => quote!(#ident.clear()), + Label::Repeated => quote!(#ident.clear()), + } + } +} diff --git a/third_party/rust/prost-derive/src/field/map.rs b/third_party/rust/prost-derive/src/field/map.rs new file mode 100644 index 0000000000..1228a6fae2 --- /dev/null +++ b/third_party/rust/prost-derive/src/field/map.rs @@ -0,0 +1,394 @@ +use anyhow::{bail, Error}; +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::{Ident, Lit, Meta, MetaNameValue, NestedMeta}; + +use crate::field::{scalar, set_option, tag_attr}; + +#[derive(Clone, Debug)] +pub enum MapTy { + HashMap, + BTreeMap, +} + +impl MapTy { + fn from_str(s: &str) -> Option<MapTy> { + match s { + "map" | "hash_map" => Some(MapTy::HashMap), + "btree_map" => Some(MapTy::BTreeMap), + _ => None, + } + } + + fn module(&self) -> Ident { + match *self { + MapTy::HashMap => Ident::new("hash_map", Span::call_site()), + MapTy::BTreeMap => Ident::new("btree_map", Span::call_site()), + } + } + + fn lib(&self) -> TokenStream { + match self { + MapTy::HashMap => quote! { std }, + MapTy::BTreeMap => quote! { prost::alloc }, + } + } +} + +fn fake_scalar(ty: scalar::Ty) -> scalar::Field { + let kind = scalar::Kind::Plain(scalar::DefaultValue::new(&ty)); + scalar::Field { + ty, + kind, + tag: 0, // Not used here + } +} + +#[derive(Clone)] +pub struct Field { + pub map_ty: MapTy, + pub key_ty: scalar::Ty, + pub value_ty: ValueTy, + pub tag: u32, +} + +impl Field { + pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> { + let mut types = None; + let mut tag = None; + + for attr in attrs { + if let Some(t) = tag_attr(attr)? { + set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(map_ty) = attr + .path() + .get_ident() + .and_then(|i| MapTy::from_str(&i.to_string())) + { + let (k, v): (String, String) = match &*attr { + Meta::NameValue(MetaNameValue { + lit: Lit::Str(lit), .. + }) => { + let items = lit.value(); + let mut items = items.split(',').map(ToString::to_string); + let k = items.next().unwrap(); + let v = match items.next() { + Some(k) => k, + None => bail!("invalid map attribute: must have key and value types"), + }; + if items.next().is_some() { + bail!("invalid map attribute: {:?}", attr); + } + (k, v) + } + Meta::List(meta_list) => { + // TODO(rustlang/rust#23121): slice pattern matching would make this much nicer. + if meta_list.nested.len() != 2 { + bail!("invalid map attribute: must contain key and value types"); + } + let k = match &meta_list.nested[0] { + NestedMeta::Meta(Meta::Path(k)) if k.get_ident().is_some() => { + k.get_ident().unwrap().to_string() + } + _ => bail!("invalid map attribute: key must be an identifier"), + }; + let v = match &meta_list.nested[1] { + NestedMeta::Meta(Meta::Path(v)) if v.get_ident().is_some() => { + v.get_ident().unwrap().to_string() + } + _ => bail!("invalid map attribute: value must be an identifier"), + }; + (k, v) + } + _ => return Ok(None), + }; + set_option( + &mut types, + (map_ty, key_ty_from_str(&k)?, ValueTy::from_str(&v)?), + "duplicate map type attribute", + )?; + } else { + return Ok(None); + } + } + + Ok(match (types, tag.or(inferred_tag)) { + (Some((map_ty, key_ty, value_ty)), Some(tag)) => Some(Field { + map_ty, + key_ty, + value_ty, + tag, + }), + _ => None, + }) + } + + pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> { + Field::new(attrs, None) + } + + /// Returns a statement which encodes the map field. + pub fn encode(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + let key_mod = self.key_ty.module(); + let ke = quote!(::prost::encoding::#key_mod::encode); + let kl = quote!(::prost::encoding::#key_mod::encoded_len); + let module = self.map_ty.module(); + match &self.value_ty { + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { + let default = quote!(#ty::default() as i32); + quote! { + ::prost::encoding::#module::encode_with_default( + #ke, + #kl, + ::prost::encoding::int32::encode, + ::prost::encoding::int32::encoded_len, + &(#default), + #tag, + &#ident, + buf, + ); + } + } + ValueTy::Scalar(value_ty) => { + let val_mod = value_ty.module(); + let ve = quote!(::prost::encoding::#val_mod::encode); + let vl = quote!(::prost::encoding::#val_mod::encoded_len); + quote! { + ::prost::encoding::#module::encode( + #ke, + #kl, + #ve, + #vl, + #tag, + &#ident, + buf, + ); + } + } + ValueTy::Message => quote! { + ::prost::encoding::#module::encode( + #ke, + #kl, + ::prost::encoding::message::encode, + ::prost::encoding::message::encoded_len, + #tag, + &#ident, + buf, + ); + }, + } + } + + /// Returns an expression which evaluates to the result of merging a decoded key value pair + /// into the map. + pub fn merge(&self, ident: TokenStream) -> TokenStream { + let key_mod = self.key_ty.module(); + let km = quote!(::prost::encoding::#key_mod::merge); + let module = self.map_ty.module(); + match &self.value_ty { + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { + let default = quote!(#ty::default() as i32); + quote! { + ::prost::encoding::#module::merge_with_default( + #km, + ::prost::encoding::int32::merge, + #default, + &mut #ident, + buf, + ctx, + ) + } + } + ValueTy::Scalar(value_ty) => { + let val_mod = value_ty.module(); + let vm = quote!(::prost::encoding::#val_mod::merge); + quote!(::prost::encoding::#module::merge(#km, #vm, &mut #ident, buf, ctx)) + } + ValueTy::Message => quote! { + ::prost::encoding::#module::merge( + #km, + ::prost::encoding::message::merge, + &mut #ident, + buf, + ctx, + ) + }, + } + } + + /// Returns an expression which evaluates to the encoded length of the map. + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + let key_mod = self.key_ty.module(); + let kl = quote!(::prost::encoding::#key_mod::encoded_len); + let module = self.map_ty.module(); + match &self.value_ty { + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { + let default = quote!(#ty::default() as i32); + quote! { + ::prost::encoding::#module::encoded_len_with_default( + #kl, + ::prost::encoding::int32::encoded_len, + &(#default), + #tag, + &#ident, + ) + } + } + ValueTy::Scalar(value_ty) => { + let val_mod = value_ty.module(); + let vl = quote!(::prost::encoding::#val_mod::encoded_len); + quote!(::prost::encoding::#module::encoded_len(#kl, #vl, #tag, &#ident)) + } + ValueTy::Message => quote! { + ::prost::encoding::#module::encoded_len( + #kl, + ::prost::encoding::message::encoded_len, + #tag, + &#ident, + ) + }, + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + quote!(#ident.clear()) + } + + /// Returns methods to embed in the message. + pub fn methods(&self, ident: &Ident) -> Option<TokenStream> { + if let ValueTy::Scalar(scalar::Ty::Enumeration(ty)) = &self.value_ty { + let key_ty = self.key_ty.rust_type(); + let key_ref_ty = self.key_ty.rust_ref_type(); + + let get = Ident::new(&format!("get_{}", ident), Span::call_site()); + let insert = Ident::new(&format!("insert_{}", ident), Span::call_site()); + let take_ref = if self.key_ty.is_numeric() { + quote!(&) + } else { + quote!() + }; + + let get_doc = format!( + "Returns the enum value for the corresponding key in `{}`, \ + or `None` if the entry does not exist or it is not a valid enum value.", + ident, + ); + let insert_doc = format!("Inserts a key value pair into `{}`.", ident); + Some(quote! { + #[doc=#get_doc] + pub fn #get(&self, key: #key_ref_ty) -> ::core::option::Option<#ty> { + self.#ident.get(#take_ref key).cloned().and_then(#ty::from_i32) + } + #[doc=#insert_doc] + pub fn #insert(&mut self, key: #key_ty, value: #ty) -> ::core::option::Option<#ty> { + self.#ident.insert(key, value as i32).and_then(#ty::from_i32) + } + }) + } else { + None + } + } + + /// Returns a newtype wrapper around the map, implementing nicer Debug + /// + /// The Debug tries to convert any enumerations met into the variants if possible, instead of + /// outputting the raw numbers. + pub fn debug(&self, wrapper_name: TokenStream) -> TokenStream { + let type_name = match self.map_ty { + MapTy::HashMap => Ident::new("HashMap", Span::call_site()), + MapTy::BTreeMap => Ident::new("BTreeMap", Span::call_site()), + }; + + // A fake field for generating the debug wrapper + let key_wrapper = fake_scalar(self.key_ty.clone()).debug(quote!(KeyWrapper)); + let key = self.key_ty.rust_type(); + let value_wrapper = self.value_ty.debug(); + let libname = self.map_ty.lib(); + let fmt = quote! { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + #key_wrapper + #value_wrapper + let mut builder = f.debug_map(); + for (k, v) in self.0 { + builder.entry(&KeyWrapper(k), &ValueWrapper(v)); + } + builder.finish() + } + }; + match &self.value_ty { + ValueTy::Scalar(ty) => { + let value = ty.rust_type(); + quote! { + struct #wrapper_name<'a>(&'a ::#libname::collections::#type_name<#key, #value>); + impl<'a> ::core::fmt::Debug for #wrapper_name<'a> { + #fmt + } + } + } + ValueTy::Message => quote! { + struct #wrapper_name<'a, V: 'a>(&'a ::#libname::collections::#type_name<#key, V>); + impl<'a, V> ::core::fmt::Debug for #wrapper_name<'a, V> + where + V: ::core::fmt::Debug + 'a, + { + #fmt + } + }, + } + } +} + +fn key_ty_from_str(s: &str) -> Result<scalar::Ty, Error> { + let ty = scalar::Ty::from_str(s)?; + match ty { + scalar::Ty::Int32 + | scalar::Ty::Int64 + | scalar::Ty::Uint32 + | scalar::Ty::Uint64 + | scalar::Ty::Sint32 + | scalar::Ty::Sint64 + | scalar::Ty::Fixed32 + | scalar::Ty::Fixed64 + | scalar::Ty::Sfixed32 + | scalar::Ty::Sfixed64 + | scalar::Ty::Bool + | scalar::Ty::String => Ok(ty), + _ => bail!("invalid map key type: {}", s), + } +} + +/// A map value type. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ValueTy { + Scalar(scalar::Ty), + Message, +} + +impl ValueTy { + fn from_str(s: &str) -> Result<ValueTy, Error> { + if let Ok(ty) = scalar::Ty::from_str(s) { + Ok(ValueTy::Scalar(ty)) + } else if s.trim() == "message" { + Ok(ValueTy::Message) + } else { + bail!("invalid map value type: {}", s); + } + } + + /// Returns a newtype wrapper around the ValueTy for nicer debug. + /// + /// If the contained value is enumeration, it tries to convert it to the variant. If not, it + /// just forwards the implementation. + fn debug(&self) -> TokenStream { + match self { + ValueTy::Scalar(ty) => fake_scalar(ty.clone()).debug(quote!(ValueWrapper)), + ValueTy::Message => quote!( + fn ValueWrapper<T>(v: T) -> T { + v + } + ), + } + } +} diff --git a/third_party/rust/prost-derive/src/field/message.rs b/third_party/rust/prost-derive/src/field/message.rs new file mode 100644 index 0000000000..1ff7c37983 --- /dev/null +++ b/third_party/rust/prost-derive/src/field/message.rs @@ -0,0 +1,134 @@ +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::Meta; + +use crate::field::{set_bool, set_option, tag_attr, word_attr, Label}; + +#[derive(Clone)] +pub struct Field { + pub label: Label, + pub tag: u32, +} + +impl Field { + pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> { + let mut message = false; + let mut label = None; + let mut tag = None; + let mut boxed = false; + + let mut unknown_attrs = Vec::new(); + + for attr in attrs { + if word_attr("message", attr) { + set_bool(&mut message, "duplicate message attribute")?; + } else if word_attr("boxed", attr) { + set_bool(&mut boxed, "duplicate boxed attribute")?; + } else if let Some(t) = tag_attr(attr)? { + set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(l) = Label::from_attr(attr) { + set_option(&mut label, l, "duplicate label attributes")?; + } else { + unknown_attrs.push(attr); + } + } + + if !message { + return Ok(None); + } + + match unknown_attrs.len() { + 0 => (), + 1 => bail!( + "unknown attribute for message field: {:?}", + unknown_attrs[0] + ), + _ => bail!("unknown attributes for message field: {:?}", unknown_attrs), + } + + let tag = match tag.or(inferred_tag) { + Some(tag) => tag, + None => bail!("message field is missing a tag attribute"), + }; + + Ok(Some(Field { + label: label.unwrap_or(Label::Optional), + tag, + })) + } + + pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> { + if let Some(mut field) = Field::new(attrs, None)? { + if let Some(attr) = attrs.iter().find(|attr| Label::from_attr(attr).is_some()) { + bail!( + "invalid attribute for oneof field: {}", + attr.path().into_token_stream() + ); + } + field.label = Label::Required; + Ok(Some(field)) + } else { + Ok(None) + } + } + + pub fn encode(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + match self.label { + Label::Optional => quote! { + if let Some(ref msg) = #ident { + ::prost::encoding::message::encode(#tag, msg, buf); + } + }, + Label::Required => quote! { + ::prost::encoding::message::encode(#tag, &#ident, buf); + }, + Label::Repeated => quote! { + for msg in &#ident { + ::prost::encoding::message::encode(#tag, msg, buf); + } + }, + } + } + + pub fn merge(&self, ident: TokenStream) -> TokenStream { + match self.label { + Label::Optional => quote! { + ::prost::encoding::message::merge(wire_type, + #ident.get_or_insert_with(Default::default), + buf, + ctx) + }, + Label::Required => quote! { + ::prost::encoding::message::merge(wire_type, #ident, buf, ctx) + }, + Label::Repeated => quote! { + ::prost::encoding::message::merge_repeated(wire_type, #ident, buf, ctx) + }, + } + } + + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + match self.label { + Label::Optional => quote! { + #ident.as_ref().map_or(0, |msg| ::prost::encoding::message::encoded_len(#tag, msg)) + }, + Label::Required => quote! { + ::prost::encoding::message::encoded_len(#tag, &#ident) + }, + Label::Repeated => quote! { + ::prost::encoding::message::encoded_len_repeated(#tag, &#ident) + }, + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + match self.label { + Label::Optional => quote!(#ident = ::core::option::Option::None), + Label::Required => quote!(#ident.clear()), + Label::Repeated => quote!(#ident.clear()), + } + } +} diff --git a/third_party/rust/prost-derive/src/field/mod.rs b/third_party/rust/prost-derive/src/field/mod.rs new file mode 100644 index 0000000000..09fef830ef --- /dev/null +++ b/third_party/rust/prost-derive/src/field/mod.rs @@ -0,0 +1,366 @@ +mod group; +mod map; +mod message; +mod oneof; +mod scalar; + +use std::fmt; +use std::slice; + +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{Attribute, Ident, Lit, LitBool, Meta, MetaList, MetaNameValue, NestedMeta}; + +#[derive(Clone)] +pub enum Field { + /// A scalar field. + Scalar(scalar::Field), + /// A message field. + Message(message::Field), + /// A map field. + Map(map::Field), + /// A oneof field. + Oneof(oneof::Field), + /// A group field. + Group(group::Field), +} + +impl Field { + /// Creates a new `Field` from an iterator of field attributes. + /// + /// If the meta items are invalid, an error will be returned. + /// If the field should be ignored, `None` is returned. + pub fn new(attrs: Vec<Attribute>, inferred_tag: Option<u32>) -> Result<Option<Field>, Error> { + let attrs = prost_attrs(attrs); + + // TODO: check for ignore attribute. + + let field = if let Some(field) = scalar::Field::new(&attrs, inferred_tag)? { + Field::Scalar(field) + } else if let Some(field) = message::Field::new(&attrs, inferred_tag)? { + Field::Message(field) + } else if let Some(field) = map::Field::new(&attrs, inferred_tag)? { + Field::Map(field) + } else if let Some(field) = oneof::Field::new(&attrs)? { + Field::Oneof(field) + } else if let Some(field) = group::Field::new(&attrs, inferred_tag)? { + Field::Group(field) + } else { + bail!("no type attribute"); + }; + + Ok(Some(field)) + } + + /// Creates a new oneof `Field` from an iterator of field attributes. + /// + /// If the meta items are invalid, an error will be returned. + /// If the field should be ignored, `None` is returned. + pub fn new_oneof(attrs: Vec<Attribute>) -> Result<Option<Field>, Error> { + let attrs = prost_attrs(attrs); + + // TODO: check for ignore attribute. + + let field = if let Some(field) = scalar::Field::new_oneof(&attrs)? { + Field::Scalar(field) + } else if let Some(field) = message::Field::new_oneof(&attrs)? { + Field::Message(field) + } else if let Some(field) = map::Field::new_oneof(&attrs)? { + Field::Map(field) + } else if let Some(field) = group::Field::new_oneof(&attrs)? { + Field::Group(field) + } else { + bail!("no type attribute for oneof field"); + }; + + Ok(Some(field)) + } + + pub fn tags(&self) -> Vec<u32> { + match *self { + Field::Scalar(ref scalar) => vec![scalar.tag], + Field::Message(ref message) => vec![message.tag], + Field::Map(ref map) => vec![map.tag], + Field::Oneof(ref oneof) => oneof.tags.clone(), + Field::Group(ref group) => vec![group.tag], + } + } + + /// Returns a statement which encodes the field. + pub fn encode(&self, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.encode(ident), + Field::Message(ref message) => message.encode(ident), + Field::Map(ref map) => map.encode(ident), + Field::Oneof(ref oneof) => oneof.encode(ident), + Field::Group(ref group) => group.encode(ident), + } + } + + /// Returns an expression which evaluates to the result of merging a decoded + /// value into the field. + pub fn merge(&self, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.merge(ident), + Field::Message(ref message) => message.merge(ident), + Field::Map(ref map) => map.merge(ident), + Field::Oneof(ref oneof) => oneof.merge(ident), + Field::Group(ref group) => group.merge(ident), + } + } + + /// Returns an expression which evaluates to the encoded length of the field. + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.encoded_len(ident), + Field::Map(ref map) => map.encoded_len(ident), + Field::Message(ref msg) => msg.encoded_len(ident), + Field::Oneof(ref oneof) => oneof.encoded_len(ident), + Field::Group(ref group) => group.encoded_len(ident), + } + } + + /// Returns a statement which clears the field. + pub fn clear(&self, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.clear(ident), + Field::Message(ref message) => message.clear(ident), + Field::Map(ref map) => map.clear(ident), + Field::Oneof(ref oneof) => oneof.clear(ident), + Field::Group(ref group) => group.clear(ident), + } + } + + pub fn default(&self) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.default(), + _ => quote!(::core::default::Default::default()), + } + } + + /// Produces the fragment implementing debug for the given field. + pub fn debug(&self, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => { + let wrapper = scalar.debug(quote!(ScalarWrapper)); + quote! { + { + #wrapper + ScalarWrapper(&#ident) + } + } + } + Field::Map(ref map) => { + let wrapper = map.debug(quote!(MapWrapper)); + quote! { + { + #wrapper + MapWrapper(&#ident) + } + } + } + _ => quote!(&#ident), + } + } + + pub fn methods(&self, ident: &Ident) -> Option<TokenStream> { + match *self { + Field::Scalar(ref scalar) => scalar.methods(ident), + Field::Map(ref map) => map.methods(ident), + _ => None, + } + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum Label { + /// An optional field. + Optional, + /// A required field. + Required, + /// A repeated field. + Repeated, +} + +impl Label { + fn as_str(self) -> &'static str { + match self { + Label::Optional => "optional", + Label::Required => "required", + Label::Repeated => "repeated", + } + } + + fn variants() -> slice::Iter<'static, Label> { + const VARIANTS: &[Label] = &[Label::Optional, Label::Required, Label::Repeated]; + VARIANTS.iter() + } + + /// Parses a string into a field label. + /// If the string doesn't match a field label, `None` is returned. + fn from_attr(attr: &Meta) -> Option<Label> { + if let Meta::Path(ref path) = *attr { + for &label in Label::variants() { + if path.is_ident(label.as_str()) { + return Some(label); + } + } + } + None + } +} + +impl fmt::Debug for Label { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl fmt::Display for Label { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +/// Get the items belonging to the 'prost' list attribute, e.g. `#[prost(foo, bar="baz")]`. +fn prost_attrs(attrs: Vec<Attribute>) -> Vec<Meta> { + attrs + .iter() + .flat_map(Attribute::parse_meta) + .flat_map(|meta| match meta { + Meta::List(MetaList { path, nested, .. }) => { + if path.is_ident("prost") { + nested.into_iter().collect() + } else { + Vec::new() + } + } + _ => Vec::new(), + }) + .flat_map(|attr| -> Result<_, _> { + match attr { + NestedMeta::Meta(attr) => Ok(attr), + NestedMeta::Lit(lit) => bail!("invalid prost attribute: {:?}", lit), + } + }) + .collect() +} + +pub fn set_option<T>(option: &mut Option<T>, value: T, message: &str) -> Result<(), Error> +where + T: fmt::Debug, +{ + if let Some(ref existing) = *option { + bail!("{}: {:?} and {:?}", message, existing, value); + } + *option = Some(value); + Ok(()) +} + +pub fn set_bool(b: &mut bool, message: &str) -> Result<(), Error> { + if *b { + bail!("{}", message); + } else { + *b = true; + Ok(()) + } +} + +/// Unpacks an attribute into a (key, boolean) pair, returning the boolean value. +/// If the key doesn't match the attribute, `None` is returned. +fn bool_attr(key: &str, attr: &Meta) -> Result<Option<bool>, Error> { + if !attr.path().is_ident(key) { + return Ok(None); + } + match *attr { + Meta::Path(..) => Ok(Some(true)), + Meta::List(ref meta_list) => { + // TODO(rustlang/rust#23121): slice pattern matching would make this much nicer. + if meta_list.nested.len() == 1 { + if let NestedMeta::Lit(Lit::Bool(LitBool { value, .. })) = meta_list.nested[0] { + return Ok(Some(value)); + } + } + bail!("invalid {} attribute", key); + } + Meta::NameValue(MetaNameValue { + lit: Lit::Str(ref lit), + .. + }) => lit + .value() + .parse::<bool>() + .map_err(Error::from) + .map(Option::Some), + Meta::NameValue(MetaNameValue { + lit: Lit::Bool(LitBool { value, .. }), + .. + }) => Ok(Some(value)), + _ => bail!("invalid {} attribute", key), + } +} + +/// Checks if an attribute matches a word. +fn word_attr(key: &str, attr: &Meta) -> bool { + if let Meta::Path(ref path) = *attr { + path.is_ident(key) + } else { + false + } +} + +pub(super) fn tag_attr(attr: &Meta) -> Result<Option<u32>, Error> { + if !attr.path().is_ident("tag") { + return Ok(None); + } + match *attr { + Meta::List(ref meta_list) => { + // TODO(rustlang/rust#23121): slice pattern matching would make this much nicer. + if meta_list.nested.len() == 1 { + if let NestedMeta::Lit(Lit::Int(ref lit)) = meta_list.nested[0] { + return Ok(Some(lit.base10_parse()?)); + } + } + bail!("invalid tag attribute: {:?}", attr); + } + Meta::NameValue(ref meta_name_value) => match meta_name_value.lit { + Lit::Str(ref lit) => lit + .value() + .parse::<u32>() + .map_err(Error::from) + .map(Option::Some), + Lit::Int(ref lit) => Ok(Some(lit.base10_parse()?)), + _ => bail!("invalid tag attribute: {:?}", attr), + }, + _ => bail!("invalid tag attribute: {:?}", attr), + } +} + +fn tags_attr(attr: &Meta) -> Result<Option<Vec<u32>>, Error> { + if !attr.path().is_ident("tags") { + return Ok(None); + } + match *attr { + Meta::List(ref meta_list) => { + let mut tags = Vec::with_capacity(meta_list.nested.len()); + for item in &meta_list.nested { + if let NestedMeta::Lit(Lit::Int(ref lit)) = *item { + tags.push(lit.base10_parse()?); + } else { + bail!("invalid tag attribute: {:?}", attr); + } + } + Ok(Some(tags)) + } + Meta::NameValue(MetaNameValue { + lit: Lit::Str(ref lit), + .. + }) => lit + .value() + .split(',') + .map(|s| s.trim().parse::<u32>().map_err(Error::from)) + .collect::<Result<Vec<u32>, _>>() + .map(Some), + _ => bail!("invalid tag attribute: {:?}", attr), + } +} diff --git a/third_party/rust/prost-derive/src/field/oneof.rs b/third_party/rust/prost-derive/src/field/oneof.rs new file mode 100644 index 0000000000..7e7f08671c --- /dev/null +++ b/third_party/rust/prost-derive/src/field/oneof.rs @@ -0,0 +1,99 @@ +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{parse_str, Lit, Meta, MetaNameValue, NestedMeta, Path}; + +use crate::field::{set_option, tags_attr}; + +#[derive(Clone)] +pub struct Field { + pub ty: Path, + pub tags: Vec<u32>, +} + +impl Field { + pub fn new(attrs: &[Meta]) -> Result<Option<Field>, Error> { + let mut ty = None; + let mut tags = None; + let mut unknown_attrs = Vec::new(); + + for attr in attrs { + if attr.path().is_ident("oneof") { + let t = match *attr { + Meta::NameValue(MetaNameValue { + lit: Lit::Str(ref lit), + .. + }) => parse_str::<Path>(&lit.value())?, + Meta::List(ref list) if list.nested.len() == 1 => { + // TODO(rustlang/rust#23121): slice pattern matching would make this much nicer. + if let NestedMeta::Meta(Meta::Path(ref path)) = list.nested[0] { + if let Some(ident) = path.get_ident() { + Path::from(ident.clone()) + } else { + bail!("invalid oneof attribute: item must be an identifier"); + } + } else { + bail!("invalid oneof attribute: item must be an identifier"); + } + } + _ => bail!("invalid oneof attribute: {:?}", attr), + }; + set_option(&mut ty, t, "duplicate oneof attribute")?; + } else if let Some(t) = tags_attr(attr)? { + set_option(&mut tags, t, "duplicate tags attributes")?; + } else { + unknown_attrs.push(attr); + } + } + + let ty = match ty { + Some(ty) => ty, + None => return Ok(None), + }; + + match unknown_attrs.len() { + 0 => (), + 1 => bail!( + "unknown attribute for message field: {:?}", + unknown_attrs[0] + ), + _ => bail!("unknown attributes for message field: {:?}", unknown_attrs), + } + + let tags = match tags { + Some(tags) => tags, + None => bail!("oneof field is missing a tags attribute"), + }; + + Ok(Some(Field { ty, tags })) + } + + /// Returns a statement which encodes the oneof field. + pub fn encode(&self, ident: TokenStream) -> TokenStream { + quote! { + if let Some(ref oneof) = #ident { + oneof.encode(buf) + } + } + } + + /// Returns an expression which evaluates to the result of decoding the oneof field. + pub fn merge(&self, ident: TokenStream) -> TokenStream { + let ty = &self.ty; + quote! { + #ty::merge(#ident, tag, wire_type, buf, ctx) + } + } + + /// Returns an expression which evaluates to the encoded length of the oneof field. + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + let ty = &self.ty; + quote! { + #ident.as_ref().map_or(0, #ty::encoded_len) + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + quote!(#ident = ::core::option::Option::None) + } +} diff --git a/third_party/rust/prost-derive/src/field/scalar.rs b/third_party/rust/prost-derive/src/field/scalar.rs new file mode 100644 index 0000000000..75a4fe3a86 --- /dev/null +++ b/third_party/rust/prost-derive/src/field/scalar.rs @@ -0,0 +1,810 @@ +use std::convert::TryFrom; +use std::fmt; + +use anyhow::{anyhow, bail, Error}; +use proc_macro2::{Span, TokenStream}; +use quote::{quote, ToTokens, TokenStreamExt}; +use syn::{parse_str, Ident, Lit, LitByteStr, Meta, MetaList, MetaNameValue, NestedMeta, Path}; + +use crate::field::{bool_attr, set_option, tag_attr, Label}; + +/// A scalar protobuf field. +#[derive(Clone)] +pub struct Field { + pub ty: Ty, + pub kind: Kind, + pub tag: u32, +} + +impl Field { + pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> { + let mut ty = None; + let mut label = None; + let mut packed = None; + let mut default = None; + let mut tag = None; + + let mut unknown_attrs = Vec::new(); + + for attr in attrs { + if let Some(t) = Ty::from_attr(attr)? { + set_option(&mut ty, t, "duplicate type attributes")?; + } else if let Some(p) = bool_attr("packed", attr)? { + set_option(&mut packed, p, "duplicate packed attributes")?; + } else if let Some(t) = tag_attr(attr)? { + set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(l) = Label::from_attr(attr) { + set_option(&mut label, l, "duplicate label attributes")?; + } else if let Some(d) = DefaultValue::from_attr(attr)? { + set_option(&mut default, d, "duplicate default attributes")?; + } else { + unknown_attrs.push(attr); + } + } + + let ty = match ty { + Some(ty) => ty, + None => return Ok(None), + }; + + match unknown_attrs.len() { + 0 => (), + 1 => bail!("unknown attribute: {:?}", unknown_attrs[0]), + _ => bail!("unknown attributes: {:?}", unknown_attrs), + } + + let tag = match tag.or(inferred_tag) { + Some(tag) => tag, + None => bail!("missing tag attribute"), + }; + + let has_default = default.is_some(); + let default = default.map_or_else( + || Ok(DefaultValue::new(&ty)), + |lit| DefaultValue::from_lit(&ty, lit), + )?; + + let kind = match (label, packed, has_default) { + (None, Some(true), _) + | (Some(Label::Optional), Some(true), _) + | (Some(Label::Required), Some(true), _) => { + bail!("packed attribute may only be applied to repeated fields"); + } + (Some(Label::Repeated), Some(true), _) if !ty.is_numeric() => { + bail!("packed attribute may only be applied to numeric types"); + } + (Some(Label::Repeated), _, true) => { + bail!("repeated fields may not have a default value"); + } + + (None, _, _) => Kind::Plain(default), + (Some(Label::Optional), _, _) => Kind::Optional(default), + (Some(Label::Required), _, _) => Kind::Required(default), + (Some(Label::Repeated), packed, false) if packed.unwrap_or_else(|| ty.is_numeric()) => { + Kind::Packed + } + (Some(Label::Repeated), _, false) => Kind::Repeated, + }; + + Ok(Some(Field { ty, kind, tag })) + } + + pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> { + if let Some(mut field) = Field::new(attrs, None)? { + match field.kind { + Kind::Plain(default) => { + field.kind = Kind::Required(default); + Ok(Some(field)) + } + Kind::Optional(..) => bail!("invalid optional attribute on oneof field"), + Kind::Required(..) => bail!("invalid required attribute on oneof field"), + Kind::Packed | Kind::Repeated => bail!("invalid repeated attribute on oneof field"), + } + } else { + Ok(None) + } + } + + pub fn encode(&self, ident: TokenStream) -> TokenStream { + let module = self.ty.module(); + let encode_fn = match self.kind { + Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encode), + Kind::Repeated => quote!(encode_repeated), + Kind::Packed => quote!(encode_packed), + }; + let encode_fn = quote!(::prost::encoding::#module::#encode_fn); + let tag = self.tag; + + match self.kind { + Kind::Plain(ref default) => { + let default = default.typed(); + quote! { + if #ident != #default { + #encode_fn(#tag, &#ident, buf); + } + } + } + Kind::Optional(..) => quote! { + if let ::core::option::Option::Some(ref value) = #ident { + #encode_fn(#tag, value, buf); + } + }, + Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! { + #encode_fn(#tag, &#ident, buf); + }, + } + } + + /// Returns an expression which evaluates to the result of merging a decoded + /// scalar value into the field. + pub fn merge(&self, ident: TokenStream) -> TokenStream { + let module = self.ty.module(); + let merge_fn = match self.kind { + Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(merge), + Kind::Repeated | Kind::Packed => quote!(merge_repeated), + }; + let merge_fn = quote!(::prost::encoding::#module::#merge_fn); + + match self.kind { + Kind::Plain(..) | Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! { + #merge_fn(wire_type, #ident, buf, ctx) + }, + Kind::Optional(..) => quote! { + #merge_fn(wire_type, + #ident.get_or_insert_with(Default::default), + buf, + ctx) + }, + } + } + + /// Returns an expression which evaluates to the encoded length of the field. + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + let module = self.ty.module(); + let encoded_len_fn = match self.kind { + Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encoded_len), + Kind::Repeated => quote!(encoded_len_repeated), + Kind::Packed => quote!(encoded_len_packed), + }; + let encoded_len_fn = quote!(::prost::encoding::#module::#encoded_len_fn); + let tag = self.tag; + + match self.kind { + Kind::Plain(ref default) => { + let default = default.typed(); + quote! { + if #ident != #default { + #encoded_len_fn(#tag, &#ident) + } else { + 0 + } + } + } + Kind::Optional(..) => quote! { + #ident.as_ref().map_or(0, |value| #encoded_len_fn(#tag, value)) + }, + Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! { + #encoded_len_fn(#tag, &#ident) + }, + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + match self.kind { + Kind::Plain(ref default) | Kind::Required(ref default) => { + let default = default.typed(); + match self.ty { + Ty::String | Ty::Bytes(..) => quote!(#ident.clear()), + _ => quote!(#ident = #default), + } + } + Kind::Optional(_) => quote!(#ident = ::core::option::Option::None), + Kind::Repeated | Kind::Packed => quote!(#ident.clear()), + } + } + + /// Returns an expression which evaluates to the default value of the field. + pub fn default(&self) -> TokenStream { + match self.kind { + Kind::Plain(ref value) | Kind::Required(ref value) => value.owned(), + Kind::Optional(_) => quote!(::core::option::Option::None), + Kind::Repeated | Kind::Packed => quote!(::prost::alloc::vec::Vec::new()), + } + } + + /// An inner debug wrapper, around the base type. + fn debug_inner(&self, wrap_name: TokenStream) -> TokenStream { + if let Ty::Enumeration(ref ty) = self.ty { + quote! { + struct #wrap_name<'a>(&'a i32); + impl<'a> ::core::fmt::Debug for #wrap_name<'a> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + match #ty::from_i32(*self.0) { + None => ::core::fmt::Debug::fmt(&self.0, f), + Some(en) => ::core::fmt::Debug::fmt(&en, f), + } + } + } + } + } else { + quote! { + fn #wrap_name<T>(v: T) -> T { v } + } + } + } + + /// Returns a fragment for formatting the field `ident` in `Debug`. + pub fn debug(&self, wrapper_name: TokenStream) -> TokenStream { + let wrapper = self.debug_inner(quote!(Inner)); + let inner_ty = self.ty.rust_type(); + match self.kind { + Kind::Plain(_) | Kind::Required(_) => self.debug_inner(wrapper_name), + Kind::Optional(_) => quote! { + struct #wrapper_name<'a>(&'a ::core::option::Option<#inner_ty>); + impl<'a> ::core::fmt::Debug for #wrapper_name<'a> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + #wrapper + ::core::fmt::Debug::fmt(&self.0.as_ref().map(Inner), f) + } + } + }, + Kind::Repeated | Kind::Packed => { + quote! { + struct #wrapper_name<'a>(&'a ::prost::alloc::vec::Vec<#inner_ty>); + impl<'a> ::core::fmt::Debug for #wrapper_name<'a> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + let mut vec_builder = f.debug_list(); + for v in self.0 { + #wrapper + vec_builder.entry(&Inner(v)); + } + vec_builder.finish() + } + } + } + } + } + } + + /// Returns methods to embed in the message. + pub fn methods(&self, ident: &Ident) -> Option<TokenStream> { + let mut ident_str = ident.to_string(); + if ident_str.starts_with("r#") { + ident_str = ident_str[2..].to_owned(); + } + + if let Ty::Enumeration(ref ty) = self.ty { + let set = Ident::new(&format!("set_{}", ident_str), Span::call_site()); + let set_doc = format!("Sets `{}` to the provided enum value.", ident_str); + Some(match self.kind { + Kind::Plain(ref default) | Kind::Required(ref default) => { + let get_doc = format!( + "Returns the enum value of `{}`, \ + or the default if the field is set to an invalid enum value.", + ident_str, + ); + quote! { + #[doc=#get_doc] + pub fn #ident(&self) -> #ty { + #ty::from_i32(self.#ident).unwrap_or(#default) + } + + #[doc=#set_doc] + pub fn #set(&mut self, value: #ty) { + self.#ident = value as i32; + } + } + } + Kind::Optional(ref default) => { + let get_doc = format!( + "Returns the enum value of `{}`, \ + or the default if the field is unset or set to an invalid enum value.", + ident_str, + ); + quote! { + #[doc=#get_doc] + pub fn #ident(&self) -> #ty { + self.#ident.and_then(#ty::from_i32).unwrap_or(#default) + } + + #[doc=#set_doc] + pub fn #set(&mut self, value: #ty) { + self.#ident = ::core::option::Option::Some(value as i32); + } + } + } + Kind::Repeated | Kind::Packed => { + let iter_doc = format!( + "Returns an iterator which yields the valid enum values contained in `{}`.", + ident_str, + ); + let push = Ident::new(&format!("push_{}", ident_str), Span::call_site()); + let push_doc = format!("Appends the provided enum value to `{}`.", ident_str); + quote! { + #[doc=#iter_doc] + pub fn #ident(&self) -> ::core::iter::FilterMap< + ::core::iter::Cloned<::core::slice::Iter<i32>>, + fn(i32) -> ::core::option::Option<#ty>, + > { + self.#ident.iter().cloned().filter_map(#ty::from_i32) + } + #[doc=#push_doc] + pub fn #push(&mut self, value: #ty) { + self.#ident.push(value as i32); + } + } + } + }) + } else if let Kind::Optional(ref default) = self.kind { + let ty = self.ty.rust_ref_type(); + + let match_some = if self.ty.is_numeric() { + quote!(::core::option::Option::Some(val) => val,) + } else { + quote!(::core::option::Option::Some(ref val) => &val[..],) + }; + + let get_doc = format!( + "Returns the value of `{0}`, or the default value if `{0}` is unset.", + ident_str, + ); + + Some(quote! { + #[doc=#get_doc] + pub fn #ident(&self) -> #ty { + match self.#ident { + #match_some + ::core::option::Option::None => #default, + } + } + }) + } else { + None + } + } +} + +/// A scalar protobuf field type. +#[derive(Clone, PartialEq, Eq)] +pub enum Ty { + Double, + Float, + Int32, + Int64, + Uint32, + Uint64, + Sint32, + Sint64, + Fixed32, + Fixed64, + Sfixed32, + Sfixed64, + Bool, + String, + Bytes(BytesTy), + Enumeration(Path), +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum BytesTy { + Vec, + Bytes, +} + +impl BytesTy { + fn try_from_str(s: &str) -> Result<Self, Error> { + match s { + "vec" => Ok(BytesTy::Vec), + "bytes" => Ok(BytesTy::Bytes), + _ => bail!("Invalid bytes type: {}", s), + } + } + + fn rust_type(&self) -> TokenStream { + match self { + BytesTy::Vec => quote! { ::prost::alloc::vec::Vec<u8> }, + BytesTy::Bytes => quote! { ::prost::bytes::Bytes }, + } + } +} + +impl Ty { + pub fn from_attr(attr: &Meta) -> Result<Option<Ty>, Error> { + let ty = match *attr { + Meta::Path(ref name) if name.is_ident("float") => Ty::Float, + Meta::Path(ref name) if name.is_ident("double") => Ty::Double, + Meta::Path(ref name) if name.is_ident("int32") => Ty::Int32, + Meta::Path(ref name) if name.is_ident("int64") => Ty::Int64, + Meta::Path(ref name) if name.is_ident("uint32") => Ty::Uint32, + Meta::Path(ref name) if name.is_ident("uint64") => Ty::Uint64, + Meta::Path(ref name) if name.is_ident("sint32") => Ty::Sint32, + Meta::Path(ref name) if name.is_ident("sint64") => Ty::Sint64, + Meta::Path(ref name) if name.is_ident("fixed32") => Ty::Fixed32, + Meta::Path(ref name) if name.is_ident("fixed64") => Ty::Fixed64, + Meta::Path(ref name) if name.is_ident("sfixed32") => Ty::Sfixed32, + Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64, + Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool, + Meta::Path(ref name) if name.is_ident("string") => Ty::String, + Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec), + Meta::NameValue(MetaNameValue { + ref path, + lit: Lit::Str(ref l), + .. + }) if path.is_ident("bytes") => Ty::Bytes(BytesTy::try_from_str(&l.value())?), + Meta::NameValue(MetaNameValue { + ref path, + lit: Lit::Str(ref l), + .. + }) if path.is_ident("enumeration") => Ty::Enumeration(parse_str::<Path>(&l.value())?), + Meta::List(MetaList { + ref path, + ref nested, + .. + }) if path.is_ident("enumeration") => { + // TODO(rustlang/rust#23121): slice pattern matching would make this much nicer. + if nested.len() == 1 { + if let NestedMeta::Meta(Meta::Path(ref path)) = nested[0] { + Ty::Enumeration(path.clone()) + } else { + bail!("invalid enumeration attribute: item must be an identifier"); + } + } else { + bail!("invalid enumeration attribute: only a single identifier is supported"); + } + } + _ => return Ok(None), + }; + Ok(Some(ty)) + } + + pub fn from_str(s: &str) -> Result<Ty, Error> { + let enumeration_len = "enumeration".len(); + let error = Err(anyhow!("invalid type: {}", s)); + let ty = match s.trim() { + "float" => Ty::Float, + "double" => Ty::Double, + "int32" => Ty::Int32, + "int64" => Ty::Int64, + "uint32" => Ty::Uint32, + "uint64" => Ty::Uint64, + "sint32" => Ty::Sint32, + "sint64" => Ty::Sint64, + "fixed32" => Ty::Fixed32, + "fixed64" => Ty::Fixed64, + "sfixed32" => Ty::Sfixed32, + "sfixed64" => Ty::Sfixed64, + "bool" => Ty::Bool, + "string" => Ty::String, + "bytes" => Ty::Bytes(BytesTy::Vec), + s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => { + let s = &s[enumeration_len..].trim(); + match s.chars().next() { + Some('<') | Some('(') => (), + _ => return error, + } + match s.chars().next_back() { + Some('>') | Some(')') => (), + _ => return error, + } + + Ty::Enumeration(parse_str::<Path>(s[1..s.len() - 1].trim())?) + } + _ => return error, + }; + Ok(ty) + } + + /// Returns the type as it appears in protobuf field declarations. + pub fn as_str(&self) -> &'static str { + match *self { + Ty::Double => "double", + Ty::Float => "float", + Ty::Int32 => "int32", + Ty::Int64 => "int64", + Ty::Uint32 => "uint32", + Ty::Uint64 => "uint64", + Ty::Sint32 => "sint32", + Ty::Sint64 => "sint64", + Ty::Fixed32 => "fixed32", + Ty::Fixed64 => "fixed64", + Ty::Sfixed32 => "sfixed32", + Ty::Sfixed64 => "sfixed64", + Ty::Bool => "bool", + Ty::String => "string", + Ty::Bytes(..) => "bytes", + Ty::Enumeration(..) => "enum", + } + } + + // TODO: rename to 'owned_type'. + pub fn rust_type(&self) -> TokenStream { + match self { + Ty::String => quote!(::prost::alloc::string::String), + Ty::Bytes(ty) => ty.rust_type(), + _ => self.rust_ref_type(), + } + } + + // TODO: rename to 'ref_type' + pub fn rust_ref_type(&self) -> TokenStream { + match *self { + Ty::Double => quote!(f64), + Ty::Float => quote!(f32), + Ty::Int32 => quote!(i32), + Ty::Int64 => quote!(i64), + Ty::Uint32 => quote!(u32), + Ty::Uint64 => quote!(u64), + Ty::Sint32 => quote!(i32), + Ty::Sint64 => quote!(i64), + Ty::Fixed32 => quote!(u32), + Ty::Fixed64 => quote!(u64), + Ty::Sfixed32 => quote!(i32), + Ty::Sfixed64 => quote!(i64), + Ty::Bool => quote!(bool), + Ty::String => quote!(&str), + Ty::Bytes(..) => quote!(&[u8]), + Ty::Enumeration(..) => quote!(i32), + } + } + + pub fn module(&self) -> Ident { + match *self { + Ty::Enumeration(..) => Ident::new("int32", Span::call_site()), + _ => Ident::new(self.as_str(), Span::call_site()), + } + } + + /// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`). + pub fn is_numeric(&self) -> bool { + !matches!(self, Ty::String | Ty::Bytes(..)) + } +} + +impl fmt::Debug for Ty { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl fmt::Display for Ty { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +/// Scalar Protobuf field types. +#[derive(Clone)] +pub enum Kind { + /// A plain proto3 scalar field. + Plain(DefaultValue), + /// An optional scalar field. + Optional(DefaultValue), + /// A required proto2 scalar field. + Required(DefaultValue), + /// A repeated scalar field. + Repeated, + /// A packed repeated scalar field. + Packed, +} + +/// Scalar Protobuf field default value. +#[derive(Clone, Debug)] +pub enum DefaultValue { + F64(f64), + F32(f32), + I32(i32), + I64(i64), + U32(u32), + U64(u64), + Bool(bool), + String(String), + Bytes(Vec<u8>), + Enumeration(TokenStream), + Path(Path), +} + +impl DefaultValue { + pub fn from_attr(attr: &Meta) -> Result<Option<Lit>, Error> { + if !attr.path().is_ident("default") { + Ok(None) + } else if let Meta::NameValue(ref name_value) = *attr { + Ok(Some(name_value.lit.clone())) + } else { + bail!("invalid default value attribute: {:?}", attr) + } + } + + pub fn from_lit(ty: &Ty, lit: Lit) -> Result<DefaultValue, Error> { + let is_i32 = *ty == Ty::Int32 || *ty == Ty::Sint32 || *ty == Ty::Sfixed32; + let is_i64 = *ty == Ty::Int64 || *ty == Ty::Sint64 || *ty == Ty::Sfixed64; + + let is_u32 = *ty == Ty::Uint32 || *ty == Ty::Fixed32; + let is_u64 = *ty == Ty::Uint64 || *ty == Ty::Fixed64; + + let empty_or_is = |expected, actual: &str| expected == actual || actual.is_empty(); + + let default = match lit { + Lit::Int(ref lit) if is_i32 && empty_or_is("i32", lit.suffix()) => { + DefaultValue::I32(lit.base10_parse()?) + } + Lit::Int(ref lit) if is_i64 && empty_or_is("i64", lit.suffix()) => { + DefaultValue::I64(lit.base10_parse()?) + } + Lit::Int(ref lit) if is_u32 && empty_or_is("u32", lit.suffix()) => { + DefaultValue::U32(lit.base10_parse()?) + } + Lit::Int(ref lit) if is_u64 && empty_or_is("u64", lit.suffix()) => { + DefaultValue::U64(lit.base10_parse()?) + } + + Lit::Float(ref lit) if *ty == Ty::Float && empty_or_is("f32", lit.suffix()) => { + DefaultValue::F32(lit.base10_parse()?) + } + Lit::Int(ref lit) if *ty == Ty::Float => DefaultValue::F32(lit.base10_parse()?), + + Lit::Float(ref lit) if *ty == Ty::Double && empty_or_is("f64", lit.suffix()) => { + DefaultValue::F64(lit.base10_parse()?) + } + Lit::Int(ref lit) if *ty == Ty::Double => DefaultValue::F64(lit.base10_parse()?), + + Lit::Bool(ref lit) if *ty == Ty::Bool => DefaultValue::Bool(lit.value), + Lit::Str(ref lit) if *ty == Ty::String => DefaultValue::String(lit.value()), + Lit::ByteStr(ref lit) + if *ty == Ty::Bytes(BytesTy::Bytes) || *ty == Ty::Bytes(BytesTy::Vec) => + { + DefaultValue::Bytes(lit.value()) + } + + Lit::Str(ref lit) => { + let value = lit.value(); + let value = value.trim(); + + if let Ty::Enumeration(ref path) = *ty { + let variant = Ident::new(value, Span::call_site()); + return Ok(DefaultValue::Enumeration(quote!(#path::#variant))); + } + + // Parse special floating point values. + if *ty == Ty::Float { + match value { + "inf" => { + return Ok(DefaultValue::Path(parse_str::<Path>( + "::core::f32::INFINITY", + )?)); + } + "-inf" => { + return Ok(DefaultValue::Path(parse_str::<Path>( + "::core::f32::NEG_INFINITY", + )?)); + } + "nan" => { + return Ok(DefaultValue::Path(parse_str::<Path>("::core::f32::NAN")?)); + } + _ => (), + } + } + if *ty == Ty::Double { + match value { + "inf" => { + return Ok(DefaultValue::Path(parse_str::<Path>( + "::core::f64::INFINITY", + )?)); + } + "-inf" => { + return Ok(DefaultValue::Path(parse_str::<Path>( + "::core::f64::NEG_INFINITY", + )?)); + } + "nan" => { + return Ok(DefaultValue::Path(parse_str::<Path>("::core::f64::NAN")?)); + } + _ => (), + } + } + + // Rust doesn't have a negative literals, so they have to be parsed specially. + if let Some(Ok(lit)) = value.strip_prefix('-').map(syn::parse_str::<Lit>) { + match lit { + Lit::Int(ref lit) if is_i32 && empty_or_is("i32", lit.suffix()) => { + // Initially parse into an i64, so that i32::MIN does not overflow. + let value: i64 = -lit.base10_parse()?; + return Ok(i32::try_from(value).map(DefaultValue::I32)?); + } + Lit::Int(ref lit) if is_i64 && empty_or_is("i64", lit.suffix()) => { + // Initially parse into an i128, so that i64::MIN does not overflow. + let value: i128 = -lit.base10_parse()?; + return Ok(i64::try_from(value).map(DefaultValue::I64)?); + } + Lit::Float(ref lit) + if *ty == Ty::Float && empty_or_is("f32", lit.suffix()) => + { + return Ok(DefaultValue::F32(-lit.base10_parse()?)); + } + Lit::Float(ref lit) + if *ty == Ty::Double && empty_or_is("f64", lit.suffix()) => + { + return Ok(DefaultValue::F64(-lit.base10_parse()?)); + } + Lit::Int(ref lit) if *ty == Ty::Float && lit.suffix().is_empty() => { + return Ok(DefaultValue::F32(-lit.base10_parse()?)); + } + Lit::Int(ref lit) if *ty == Ty::Double && lit.suffix().is_empty() => { + return Ok(DefaultValue::F64(-lit.base10_parse()?)); + } + _ => (), + } + } + match syn::parse_str::<Lit>(&value) { + Ok(Lit::Str(_)) => (), + Ok(lit) => return DefaultValue::from_lit(ty, lit), + _ => (), + } + bail!("invalid default value: {}", quote!(#value)); + } + _ => bail!("invalid default value: {}", quote!(#lit)), + }; + + Ok(default) + } + + pub fn new(ty: &Ty) -> DefaultValue { + match *ty { + Ty::Float => DefaultValue::F32(0.0), + Ty::Double => DefaultValue::F64(0.0), + Ty::Int32 | Ty::Sint32 | Ty::Sfixed32 => DefaultValue::I32(0), + Ty::Int64 | Ty::Sint64 | Ty::Sfixed64 => DefaultValue::I64(0), + Ty::Uint32 | Ty::Fixed32 => DefaultValue::U32(0), + Ty::Uint64 | Ty::Fixed64 => DefaultValue::U64(0), + + Ty::Bool => DefaultValue::Bool(false), + Ty::String => DefaultValue::String(String::new()), + Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()), + Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())), + } + } + + pub fn owned(&self) -> TokenStream { + match *self { + DefaultValue::String(ref value) if value.is_empty() => { + quote!(::prost::alloc::string::String::new()) + } + DefaultValue::String(ref value) => quote!(#value.into()), + DefaultValue::Bytes(ref value) if value.is_empty() => quote!(Default::default()), + DefaultValue::Bytes(ref value) => { + let lit = LitByteStr::new(value, Span::call_site()); + quote!(#lit.as_ref().into()) + } + + ref other => other.typed(), + } + } + + pub fn typed(&self) -> TokenStream { + if let DefaultValue::Enumeration(_) = *self { + quote!(#self as i32) + } else { + quote!(#self) + } + } +} + +impl ToTokens for DefaultValue { + fn to_tokens(&self, tokens: &mut TokenStream) { + match *self { + DefaultValue::F64(value) => value.to_tokens(tokens), + DefaultValue::F32(value) => value.to_tokens(tokens), + DefaultValue::I32(value) => value.to_tokens(tokens), + DefaultValue::I64(value) => value.to_tokens(tokens), + DefaultValue::U32(value) => value.to_tokens(tokens), + DefaultValue::U64(value) => value.to_tokens(tokens), + DefaultValue::Bool(value) => value.to_tokens(tokens), + DefaultValue::String(ref value) => value.to_tokens(tokens), + DefaultValue::Bytes(ref value) => { + let byte_str = LitByteStr::new(value, Span::call_site()); + tokens.append_all(quote!(#byte_str as &[u8])); + } + DefaultValue::Enumeration(ref value) => value.to_tokens(tokens), + DefaultValue::Path(ref value) => value.to_tokens(tokens), + } + } +} diff --git a/third_party/rust/prost-derive/src/lib.rs b/third_party/rust/prost-derive/src/lib.rs new file mode 100644 index 0000000000..a517c41e80 --- /dev/null +++ b/third_party/rust/prost-derive/src/lib.rs @@ -0,0 +1,470 @@ +#![doc(html_root_url = "https://docs.rs/prost-derive/0.8.0")] +// The `quote!` macro requires deep recursion. +#![recursion_limit = "4096"] + +extern crate alloc; +extern crate proc_macro; + +use anyhow::{bail, Error}; +use itertools::Itertools; +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::quote; +use syn::{ + punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed, + FieldsUnnamed, Ident, Variant, +}; + +mod field; +use crate::field::Field; + +fn try_message(input: TokenStream) -> Result<TokenStream, Error> { + let input: DeriveInput = syn::parse(input)?; + + let ident = input.ident; + + let variant_data = match input.data { + Data::Struct(variant_data) => variant_data, + Data::Enum(..) => bail!("Message can not be derived for an enum"), + Data::Union(..) => bail!("Message can not be derived for a union"), + }; + + let generics = &input.generics; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let fields = match variant_data { + DataStruct { + fields: Fields::Named(FieldsNamed { named: fields, .. }), + .. + } + | DataStruct { + fields: + Fields::Unnamed(FieldsUnnamed { + unnamed: fields, .. + }), + .. + } => fields.into_iter().collect(), + DataStruct { + fields: Fields::Unit, + .. + } => Vec::new(), + }; + + let mut next_tag: u32 = 1; + let mut fields = fields + .into_iter() + .enumerate() + .flat_map(|(idx, field)| { + let field_ident = field + .ident + .unwrap_or_else(|| Ident::new(&idx.to_string(), Span::call_site())); + match Field::new(field.attrs, Some(next_tag)) { + Ok(Some(field)) => { + next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag); + Some(Ok((field_ident, field))) + } + Ok(None) => None, + Err(err) => Some(Err( + err.context(format!("invalid message field {}.{}", ident, field_ident)) + )), + } + }) + .collect::<Result<Vec<_>, _>>()?; + + // We want Debug to be in declaration order + let unsorted_fields = fields.clone(); + + // Sort the fields by tag number so that fields will be encoded in tag order. + // TODO: This encodes oneof fields in the position of their lowest tag, + // regardless of the currently occupied variant, is that consequential? + // See: https://developers.google.com/protocol-buffers/docs/encoding#order + fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap()); + let fields = fields; + + let mut tags = fields + .iter() + .flat_map(|&(_, ref field)| field.tags()) + .collect::<Vec<_>>(); + let num_tags = tags.len(); + tags.sort_unstable(); + tags.dedup(); + if tags.len() != num_tags { + bail!("message {} has fields with duplicate tags", ident); + } + + let encoded_len = fields + .iter() + .map(|&(ref field_ident, ref field)| field.encoded_len(quote!(self.#field_ident))); + + let encode = fields + .iter() + .map(|&(ref field_ident, ref field)| field.encode(quote!(self.#field_ident))); + + let merge = fields.iter().map(|&(ref field_ident, ref field)| { + let merge = field.merge(quote!(value)); + let tags = field.tags().into_iter().map(|tag| quote!(#tag)); + let tags = Itertools::intersperse(tags, quote!(|)); + + quote! { + #(#tags)* => { + let mut value = &mut self.#field_ident; + #merge.map_err(|mut error| { + error.push(STRUCT_NAME, stringify!(#field_ident)); + error + }) + }, + } + }); + + let struct_name = if fields.is_empty() { + quote!() + } else { + quote!( + const STRUCT_NAME: &'static str = stringify!(#ident); + ) + }; + + // TODO + let is_struct = true; + + let clear = fields + .iter() + .map(|&(ref field_ident, ref field)| field.clear(quote!(self.#field_ident))); + + let default = fields.iter().map(|&(ref field_ident, ref field)| { + let value = field.default(); + quote!(#field_ident: #value,) + }); + + let methods = fields + .iter() + .flat_map(|&(ref field_ident, ref field)| field.methods(field_ident)) + .collect::<Vec<_>>(); + let methods = if methods.is_empty() { + quote!() + } else { + quote! { + #[allow(dead_code)] + impl #impl_generics #ident #ty_generics #where_clause { + #(#methods)* + } + } + }; + + let debugs = unsorted_fields.iter().map(|&(ref field_ident, ref field)| { + let wrapper = field.debug(quote!(self.#field_ident)); + let call = if is_struct { + quote!(builder.field(stringify!(#field_ident), &wrapper)) + } else { + quote!(builder.field(&wrapper)) + }; + quote! { + let builder = { + let wrapper = #wrapper; + #call + }; + } + }); + let debug_builder = if is_struct { + quote!(f.debug_struct(stringify!(#ident))) + } else { + quote!(f.debug_tuple(stringify!(#ident))) + }; + + let expanded = quote! { + impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause { + #[allow(unused_variables)] + fn encode_raw<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut { + #(#encode)* + } + + #[allow(unused_variables)] + fn merge_field<B>( + &mut self, + tag: u32, + wire_type: ::prost::encoding::WireType, + buf: &mut B, + ctx: ::prost::encoding::DecodeContext, + ) -> ::core::result::Result<(), ::prost::DecodeError> + where B: ::prost::bytes::Buf { + #struct_name + match tag { + #(#merge)* + _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx), + } + } + + #[inline] + fn encoded_len(&self) -> usize { + 0 #(+ #encoded_len)* + } + + fn clear(&mut self) { + #(#clear;)* + } + } + + impl #impl_generics Default for #ident #ty_generics #where_clause { + fn default() -> Self { + #ident { + #(#default)* + } + } + } + + impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + let mut builder = #debug_builder; + #(#debugs;)* + builder.finish() + } + } + + #methods + }; + + Ok(expanded.into()) +} + +#[proc_macro_derive(Message, attributes(prost))] +pub fn message(input: TokenStream) -> TokenStream { + try_message(input).unwrap() +} + +fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> { + let input: DeriveInput = syn::parse(input)?; + let ident = input.ident; + + let generics = &input.generics; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let punctuated_variants = match input.data { + Data::Enum(DataEnum { variants, .. }) => variants, + Data::Struct(_) => bail!("Enumeration can not be derived for a struct"), + Data::Union(..) => bail!("Enumeration can not be derived for a union"), + }; + + // Map the variants into 'fields'. + let mut variants: Vec<(Ident, Expr)> = Vec::new(); + for Variant { + ident, + fields, + discriminant, + .. + } in punctuated_variants + { + match fields { + Fields::Unit => (), + Fields::Named(_) | Fields::Unnamed(_) => { + bail!("Enumeration variants may not have fields") + } + } + + match discriminant { + Some((_, expr)) => variants.push((ident, expr)), + None => bail!("Enumeration variants must have a disriminant"), + } + } + + if variants.is_empty() { + panic!("Enumeration must have at least one variant"); + } + + let default = variants[0].0.clone(); + + let is_valid = variants + .iter() + .map(|&(_, ref value)| quote!(#value => true)); + let from = variants.iter().map( + |&(ref variant, ref value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)), + ); + + let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident); + let from_i32_doc = format!( + "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.", + ident + ); + + let expanded = quote! { + impl #impl_generics #ident #ty_generics #where_clause { + #[doc=#is_valid_doc] + pub fn is_valid(value: i32) -> bool { + match value { + #(#is_valid,)* + _ => false, + } + } + + #[doc=#from_i32_doc] + pub fn from_i32(value: i32) -> ::core::option::Option<#ident> { + match value { + #(#from,)* + _ => ::core::option::Option::None, + } + } + } + + impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause { + fn default() -> #ident { + #ident::#default + } + } + + impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause { + fn from(value: #ident) -> i32 { + value as i32 + } + } + }; + + Ok(expanded.into()) +} + +#[proc_macro_derive(Enumeration, attributes(prost))] +pub fn enumeration(input: TokenStream) -> TokenStream { + try_enumeration(input).unwrap() +} + +fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> { + let input: DeriveInput = syn::parse(input)?; + + let ident = input.ident; + + let variants = match input.data { + Data::Enum(DataEnum { variants, .. }) => variants, + Data::Struct(..) => bail!("Oneof can not be derived for a struct"), + Data::Union(..) => bail!("Oneof can not be derived for a union"), + }; + + let generics = &input.generics; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + // Map the variants into 'fields'. + let mut fields: Vec<(Ident, Field)> = Vec::new(); + for Variant { + attrs, + ident: variant_ident, + fields: variant_fields, + .. + } in variants + { + let variant_fields = match variant_fields { + Fields::Unit => Punctuated::new(), + Fields::Named(FieldsNamed { named: fields, .. }) + | Fields::Unnamed(FieldsUnnamed { + unnamed: fields, .. + }) => fields, + }; + if variant_fields.len() != 1 { + bail!("Oneof enum variants must have a single field"); + } + match Field::new_oneof(attrs)? { + Some(field) => fields.push((variant_ident, field)), + None => bail!("invalid oneof variant: oneof variants may not be ignored"), + } + } + + let mut tags = fields + .iter() + .flat_map(|&(ref variant_ident, ref field)| -> Result<u32, Error> { + if field.tags().len() > 1 { + bail!( + "invalid oneof variant {}::{}: oneof variants may only have a single tag", + ident, + variant_ident + ); + } + Ok(field.tags()[0]) + }) + .collect::<Vec<_>>(); + tags.sort_unstable(); + tags.dedup(); + if tags.len() != fields.len() { + panic!("invalid oneof {}: variants have duplicate tags", ident); + } + + let encode = fields.iter().map(|&(ref variant_ident, ref field)| { + let encode = field.encode(quote!(*value)); + quote!(#ident::#variant_ident(ref value) => { #encode }) + }); + + let merge = fields.iter().map(|&(ref variant_ident, ref field)| { + let tag = field.tags()[0]; + let merge = field.merge(quote!(value)); + quote! { + #tag => { + match field { + ::core::option::Option::Some(#ident::#variant_ident(ref mut value)) => { + #merge + }, + _ => { + let mut owned_value = ::core::default::Default::default(); + let value = &mut owned_value; + #merge.map(|_| *field = ::core::option::Option::Some(#ident::#variant_ident(owned_value))) + }, + } + } + } + }); + + let encoded_len = fields.iter().map(|&(ref variant_ident, ref field)| { + let encoded_len = field.encoded_len(quote!(*value)); + quote!(#ident::#variant_ident(ref value) => #encoded_len) + }); + + let debug = fields.iter().map(|&(ref variant_ident, ref field)| { + let wrapper = field.debug(quote!(*value)); + quote!(#ident::#variant_ident(ref value) => { + let wrapper = #wrapper; + f.debug_tuple(stringify!(#variant_ident)) + .field(&wrapper) + .finish() + }) + }); + + let expanded = quote! { + impl #impl_generics #ident #ty_generics #where_clause { + pub fn encode<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut { + match *self { + #(#encode,)* + } + } + + pub fn merge<B>( + field: &mut ::core::option::Option<#ident #ty_generics>, + tag: u32, + wire_type: ::prost::encoding::WireType, + buf: &mut B, + ctx: ::prost::encoding::DecodeContext, + ) -> ::core::result::Result<(), ::prost::DecodeError> + where B: ::prost::bytes::Buf { + match tag { + #(#merge,)* + _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag), + } + } + + #[inline] + pub fn encoded_len(&self) -> usize { + match *self { + #(#encoded_len,)* + } + } + } + + impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + match *self { + #(#debug,)* + } + } + } + }; + + Ok(expanded.into()) +} + +#[proc_macro_derive(Oneof, attributes(prost))] +pub fn oneof(input: TokenStream) -> TokenStream { + try_oneof(input).unwrap() +} diff --git a/third_party/rust/prost/.cargo-checksum.json b/third_party/rust/prost/.cargo-checksum.json new file mode 100644 index 0000000000..addf3e74e4 --- /dev/null +++ b/third_party/rust/prost/.cargo-checksum.json @@ -0,0 +1 @@ +{"files":{"Cargo.toml":"181dc7aae816183fc30486a41d4efcffc9dff51604b8a285b5259193103283a0","FUZZING.md":"0223441ab805f2ccfda9f172955446c3287e1e125cc68814f77a46573935feba","LICENSE":"a60eea817514531668d7e00765731449fe14d059d3249e0bc93b36de45f759f2","README.md":"582696a33d4cd5da2797a0928d35689bbb0a51317ca03c94be235224cbadd6cc","benches/varint.rs":"25e28eadeb5092882281eaa16307ae72cbdabe04a308c9cfe15d3a62c8ead40b","clippy.toml":"10eea08f9e26e0dc498e431ac3a62b861bd74029d1cad8a394284af4cbc90532","prepare-release.sh":"67f42e0649d33269c88272e69a9bf48d02de42126a9942ac6298b6995adea8df","publish-release.sh":"a9ff9a5a65a6772fbe115b64051b1284b0b81825f839a65594d6834c53d7a78f","src/encoding.rs":"20dd077efd9f12e45657b802cfe57de0739de4b9b81b6e420a569b7459150ff2","src/error.rs":"60194cd97e5a6b0f985e74630d3f41c14e3ce1bd009a802c1540462e9a671bc5","src/lib.rs":"7bbc5d3d941ab687dc2bed0982289d28e5a4e1f66805cde3ce09dd97d39eb2df","src/message.rs":"19527eb5efa2a959d555c45f9391022a8f380984fc6bbf35ab23069a0a6091ce","src/types.rs":"edfefaf56ab4bc12c98cbcd9d0b3ca18c10352217556c193be6314990ecffd9c"},"package":"de5e2533f59d08fcf364fd374ebda0692a70bd6d7e66ef97f306f45c6c5d8020"}
\ No newline at end of file diff --git a/third_party/rust/prost/Cargo.toml b/third_party/rust/prost/Cargo.toml new file mode 100644 index 0000000000..c1112fdc20 --- /dev/null +++ b/third_party/rust/prost/Cargo.toml @@ -0,0 +1,60 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies +# +# If you believe there's an error in this file please file an +# issue against the rust-lang/cargo repository. If you're +# editing this file be aware that the upstream Cargo.toml +# will likely look very different (and much more reasonable) + +[package] +edition = "2018" +name = "prost" +version = "0.8.0" +authors = ["Dan Burkert <dan@danburkert.com>", "Tokio Contributors <team@tokio.rs>"] +description = "A Protocol Buffers implementation for the Rust Language." +documentation = "https://docs.rs/prost" +readme = "README.md" +keywords = ["protobuf", "serialization"] +categories = ["encoding"] +license = "Apache-2.0" +repository = "https://github.com/tokio-rs/prost" +[profile.bench] +debug = true + +[lib] +bench = false + +[[bench]] +name = "varint" +harness = false +[dependencies.bytes] +version = "1" +default-features = false + +[dependencies.prost-derive] +version = "0.8.0" +optional = true +[dev-dependencies.criterion] +version = "0.3" + +[dev-dependencies.env_logger] +version = "0.8" +default-features = false + +[dev-dependencies.log] +version = "0.4" + +[dev-dependencies.proptest] +version = "1" + +[dev-dependencies.rand] +version = "0.8" + +[features] +default = ["prost-derive", "std"] +no-recursion-limit = [] +std = [] diff --git a/third_party/rust/prost/FUZZING.md b/third_party/rust/prost/FUZZING.md new file mode 100644 index 0000000000..d47268d699 --- /dev/null +++ b/third_party/rust/prost/FUZZING.md @@ -0,0 +1,27 @@ +# Fuzzing + +Prost ships a few fuzz tests, using both libfuzzer and aflfuzz. + + +## afl + +To run the afl fuzz tests, first install cargo-afl: + + cargo install -f afl + +Then build a fuzz target and run afl on it: + + cd afl/<target>/ + cargo afl build --bin fuzz-target + cargo afl fuzz -i in -o out target/debug/fuzz-target + +To reproduce a crash: + + cd afl/<target>/ + cargo build --bin reproduce + cargo run --bin reproduce -- out/crashes/<crashfile> + + +## libfuzzer + +TODO diff --git a/third_party/rust/prost/LICENSE b/third_party/rust/prost/LICENSE new file mode 100644 index 0000000000..16fe87b06e --- /dev/null +++ b/third_party/rust/prost/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/third_party/rust/prost/README.md b/third_party/rust/prost/README.md new file mode 100644 index 0000000000..fc5d87b80d --- /dev/null +++ b/third_party/rust/prost/README.md @@ -0,0 +1,462 @@ +![continuous integration](https://github.com/tokio-rs/prost/workflows/continuous%20integration/badge.svg) +[![Documentation](https://docs.rs/prost/badge.svg)](https://docs.rs/prost/) +[![Crate](https://img.shields.io/crates/v/prost.svg)](https://crates.io/crates/prost) +[![Dependency Status](https://deps.rs/repo/github/tokio-rs/prost/status.svg)](https://deps.rs/repo/github/tokio-rs/prost) + +# *PROST!* + +`prost` is a [Protocol Buffers](https://developers.google.com/protocol-buffers/) +implementation for the [Rust Language](https://www.rust-lang.org/). `prost` +generates simple, idiomatic Rust code from `proto2` and `proto3` files. + +Compared to other Protocol Buffers implementations, `prost` + +* Generates simple, idiomatic, and readable Rust types by taking advantage of + Rust `derive` attributes. +* Retains comments from `.proto` files in generated Rust code. +* Allows existing Rust types (not generated from a `.proto`) to be serialized + and deserialized by adding attributes. +* Uses the [`bytes::{Buf, BufMut}`](https://github.com/carllerche/bytes) + abstractions for serialization instead of `std::io::{Read, Write}`. +* Respects the Protobuf `package` specifier when organizing generated code + into Rust modules. +* Preserves unknown enum values during deserialization. +* Does not include support for runtime reflection or message descriptors. + +## Using `prost` in a Cargo Project + +First, add `prost` and its public dependencies to your `Cargo.toml`: + +``` +[dependencies] +prost = "0.8" +# Only necessary if using Protobuf well-known types: +prost-types = "0.8" +``` + +The recommended way to add `.proto` compilation to a Cargo project is to use the +`prost-build` library. See the [`prost-build` documentation](prost-build) for +more details and examples. + +## Generated Code + +`prost` generates Rust code from source `.proto` files using the `proto2` or +`proto3` syntax. `prost`'s goal is to make the generated code as simple as +possible. + +### Packages + +All `.proto` files used with `prost` must contain a +[`package` specifier][package]. `prost` will translate the Protobuf package into +a Rust module. For example, given the `package` specifier: + +[package]: https://developers.google.com/protocol-buffers/docs/proto#packages + +```proto +package foo.bar; +``` + +All Rust types generated from the file will be in the `foo::bar` module. + +### Messages + +Given a simple message declaration: + +```proto +// Sample message. +message Foo { +} +``` + +`prost` will generate the following Rust struct: + +```rust +/// Sample message. +#[derive(Clone, Debug, PartialEq, Message)] +pub struct Foo { +} +``` + +### Fields + +Fields in Protobuf messages are translated into Rust as public struct fields of the +corresponding type. + +#### Scalar Values + +Scalar value types are converted as follows: + +| Protobuf Type | Rust Type | +| --- | --- | +| `double` | `f64` | +| `float` | `f32` | +| `int32` | `i32` | +| `int64` | `i64` | +| `uint32` | `u32` | +| `uint64` | `u64` | +| `sint32` | `i32` | +| `sint64` | `i64` | +| `fixed32` | `u32` | +| `fixed64` | `u64` | +| `sfixed32` | `i32` | +| `sfixed64` | `i64` | +| `bool` | `bool` | +| `string` | `String` | +| `bytes` | `Vec<u8>` | + +#### Enumerations + +All `.proto` enumeration types convert to the Rust `i32` type. Additionally, +each enumeration type gets a corresponding Rust `enum` type. For example, this +`proto` enum: + +```proto +enum PhoneType { + MOBILE = 0; + HOME = 1; + WORK = 2; +} +``` + +gets this corresponding Rust enum [1]: + +```rust +pub enum PhoneType { + Mobile = 0, + Home = 1, + Work = 2, +} +``` + +You can convert a `PhoneType` value to an `i32` by doing: + +```rust +PhoneType::Mobile as i32 +``` + +The `#[derive(::prost::Enumeration)]` annotation added to the generated +`PhoneType` adds these associated functions to the type: + +```rust +impl PhoneType { + pub fn is_valid(value: i32) -> bool { ... } + pub fn from_i32(value: i32) -> Option<PhoneType> { ... } +} +``` + +so you can convert an `i32` to its corresponding `PhoneType` value by doing, +for example: + +```rust +let phone_type = 2i32; + +match PhoneType::from_i32(phone_type) { + Some(PhoneType::Mobile) => ..., + Some(PhoneType::Home) => ..., + Some(PhoneType::Work) => ..., + None => ..., +} +``` + +Additionally, wherever a `proto` enum is used as a field in a `Message`, the +message will have 'accessor' methods to get/set the value of the field as the +Rust enum type. For instance, this proto `PhoneNumber` message that has a field +named `type` of type `PhoneType`: + +```proto +message PhoneNumber { + string number = 1; + PhoneType type = 2; +} +``` + +will become the following Rust type [1] with methods `type` and `set_type`: + +```rust +pub struct PhoneNumber { + pub number: String, + pub r#type: i32, // the `r#` is needed because `type` is a Rust keyword +} + +impl PhoneNumber { + pub fn r#type(&self) -> PhoneType { ... } + pub fn set_type(&mut self, value: PhoneType) { ... } +} +``` + +Note that the getter methods will return the Rust enum's default value if the +field has an invalid `i32` value. + +The `enum` type isn't used directly as a field, because the Protobuf spec +mandates that enumerations values are 'open', and decoding unrecognized +enumeration values must be possible. + +[1] Annotations have been elided for clarity. See below for a full example. + +#### Field Modifiers + +Protobuf scalar value and enumeration message fields can have a modifier +depending on the Protobuf version. Modifiers change the corresponding type of +the Rust field: + +| `.proto` Version | Modifier | Rust Type | +| --- | --- | --- | +| `proto2` | `optional` | `Option<T>` | +| `proto2` | `required` | `T` | +| `proto3` | default | `T` | +| `proto2`/`proto3` | repeated | `Vec<T>` | + +#### Map Fields + +Map fields are converted to a Rust `HashMap` with key and value type converted +from the Protobuf key and value types. + +#### Message Fields + +Message fields are converted to the corresponding struct type. The table of +field modifiers above applies to message fields, except that `proto3` message +fields without a modifier (the default) will be wrapped in an `Option`. +Typically message fields are unboxed. `prost` will automatically box a message +field if the field type and the parent type are recursively nested in order to +avoid an infinite sized struct. + +#### Oneof Fields + +Oneof fields convert to a Rust enum. Protobuf `oneof`s types are not named, so +`prost` uses the name of the `oneof` field for the resulting Rust enum, and +defines the enum in a module under the struct. For example, a `proto3` message +such as: + +```proto +message Foo { + oneof widget { + int32 quux = 1; + string bar = 2; + } +} +``` + +generates the following Rust[1]: + +```rust +pub struct Foo { + pub widget: Option<foo::Widget>, +} +pub mod foo { + pub enum Widget { + Quux(i32), + Bar(String), + } +} +``` + +`oneof` fields are always wrapped in an `Option`. + +[1] Annotations have been elided for clarity. See below for a full example. + +### Services + +`prost-build` allows a custom code-generator to be used for processing `service` +definitions. This can be used to output Rust traits according to an +application's specific needs. + +### Generated Code Example + +Example `.proto` file: + +```proto +syntax = "proto3"; +package tutorial; + +message Person { + string name = 1; + int32 id = 2; // Unique ID number for this person. + string email = 3; + + enum PhoneType { + MOBILE = 0; + HOME = 1; + WORK = 2; + } + + message PhoneNumber { + string number = 1; + PhoneType type = 2; + } + + repeated PhoneNumber phones = 4; +} + +// Our address book file is just one of these. +message AddressBook { + repeated Person people = 1; +} +``` + +and the generated Rust code (`tutorial.rs`): + +```rust +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Person { + #[prost(string, tag="1")] + pub name: ::prost::alloc::string::String, + /// Unique ID number for this person. + #[prost(int32, tag="2")] + pub id: i32, + #[prost(string, tag="3")] + pub email: ::prost::alloc::string::String, + #[prost(message, repeated, tag="4")] + pub phones: ::prost::alloc::vec::Vec<person::PhoneNumber>, +} +/// Nested message and enum types in `Person`. +pub mod person { + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct PhoneNumber { + #[prost(string, tag="1")] + pub number: ::prost::alloc::string::String, + #[prost(enumeration="PhoneType", tag="2")] + pub r#type: i32, + } + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] + #[repr(i32)] + pub enum PhoneType { + Mobile = 0, + Home = 1, + Work = 2, + } +} +/// Our address book file is just one of these. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct AddressBook { + #[prost(message, repeated, tag="1")] + pub people: ::prost::alloc::vec::Vec<Person>, +} +``` + +## Accessing the `protoc` `FileDescriptorSet` + +The `prost_build::Config::file_descriptor_set_path` option can be used to emit a file descriptor set +during the build & code generation step. When used in conjunction with the `std::include_bytes` +macro and the `prost_types::FileDescriptorSet` type, applications and libraries using Prost can +implement introspection capabilities requiring details from the original `.proto` files. + +## Using `prost` in a `no_std` Crate + +`prost` is compatible with `no_std` crates. To enable `no_std` support, disable +the `std` features in `prost` and `prost-types`: + +``` +[dependencies] +prost = { version = "0.6", default-features = false, features = ["prost-derive"] } +# Only necessary if using Protobuf well-known types: +prost-types = { version = "0.6", default-features = false } +``` + +Additionally, configure `prost-build` to output `BTreeMap`s instead of `HashMap`s +for all Protobuf `map` fields in your `build.rs`: + +```rust +let mut config = prost_build::Config::new(); +config.btree_map(&["."]); +``` + +When using edition 2015, it may be necessary to add an `extern crate core;` +directive to the crate which includes `prost`-generated code. + +## Serializing Existing Types + +`prost` uses a custom derive macro to handle encoding and decoding types, which +means that if your existing Rust type is compatible with Protobuf types, you can +serialize and deserialize it by adding the appropriate derive and field +annotations. + +Currently the best documentation on adding annotations is to look at the +generated code examples above. + +### Tag Inference for Existing Types + +Prost automatically infers tags for the struct. + +Fields are tagged sequentially in the order they +are specified, starting with `1`. + +You may skip tags which have been reserved, or where there are gaps between +sequentially occurring tag values by specifying the tag number to skip to with +the `tag` attribute on the first field after the gap. The following fields will +be tagged sequentially starting from the next number. + +```rust +use prost; +use prost::{Enumeration, Message}; + +#[derive(Clone, PartialEq, Message)] +struct Person { + #[prost(string, tag = "1")] + pub id: String, // tag=1 + // NOTE: Old "name" field has been removed + // pub name: String, // tag=2 (Removed) + #[prost(string, tag = "6")] + pub given_name: String, // tag=6 + #[prost(string)] + pub family_name: String, // tag=7 + #[prost(string)] + pub formatted_name: String, // tag=8 + #[prost(uint32, tag = "3")] + pub age: u32, // tag=3 + #[prost(uint32)] + pub height: u32, // tag=4 + #[prost(enumeration = "Gender")] + pub gender: i32, // tag=5 + // NOTE: Skip to less commonly occurring fields + #[prost(string, tag = "16")] + pub name_prefix: String, // tag=16 (eg. mr/mrs/ms) + #[prost(string)] + pub name_suffix: String, // tag=17 (eg. jr/esq) + #[prost(string)] + pub maiden_name: String, // tag=18 +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Enumeration)] +pub enum Gender { + Unknown = 0, + Female = 1, + Male = 2, +} +``` + +## FAQ + +1. **Could `prost` be implemented as a serializer for [Serde](https://serde.rs/)?** + + Probably not, however I would like to hear from a Serde expert on the matter. + There are two complications with trying to serialize Protobuf messages with + Serde: + + - Protobuf fields require a numbered tag, and currently there appears to be no + mechanism suitable for this in `serde`. + - The mapping of Protobuf type to Rust type is not 1-to-1. As a result, + trait-based approaches to dispatching don't work very well. Example: six + different Protobuf field types correspond to a Rust `Vec<i32>`: `repeated + int32`, `repeated sint32`, `repeated sfixed32`, and their packed + counterparts. + + But it is possible to place `serde` derive tags onto the generated types, so + the same structure can support both `prost` and `Serde`. + +2. **I get errors when trying to run `cargo test` on MacOS** + + If the errors are about missing `autoreconf` or similar, you can probably fix + them by running + + ``` + brew install automake + brew install libtool + ``` + +## License + +`prost` is distributed under the terms of the Apache License (Version 2.0). + +See [LICENSE](LICENSE) for details. + +Copyright 2017 Dan Burkert diff --git a/third_party/rust/prost/benches/varint.rs b/third_party/rust/prost/benches/varint.rs new file mode 100644 index 0000000000..34951f5eaf --- /dev/null +++ b/third_party/rust/prost/benches/varint.rs @@ -0,0 +1,99 @@ +use std::mem; + +use bytes::Buf; +use criterion::{Criterion, Throughput}; +use prost::encoding::{decode_varint, encode_varint, encoded_len_varint}; +use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng}; + +fn benchmark_varint(criterion: &mut Criterion, name: &str, mut values: Vec<u64>) { + // Shuffle the values in a stable order. + values.shuffle(&mut StdRng::seed_from_u64(0)); + let name = format!("varint/{}", name); + + let encoded_len = values + .iter() + .cloned() + .map(encoded_len_varint) + .sum::<usize>() as u64; + let decoded_len = (values.len() * mem::size_of::<u64>()) as u64; + + criterion + .benchmark_group(&name) + .bench_function("encode", { + let encode_values = values.clone(); + move |b| { + let mut buf = Vec::<u8>::with_capacity(encode_values.len() * 10); + b.iter(|| { + buf.clear(); + for &value in &encode_values { + encode_varint(value, &mut buf); + } + criterion::black_box(&buf); + }) + } + }) + .throughput(Throughput::Bytes(encoded_len)); + + criterion + .benchmark_group(&name) + .bench_function("decode", { + let decode_values = values.clone(); + + move |b| { + let mut buf = Vec::with_capacity(decode_values.len() * 10); + for &value in &decode_values { + encode_varint(value, &mut buf); + } + + b.iter(|| { + let mut buf = &mut buf.as_slice(); + while buf.has_remaining() { + let result = decode_varint(&mut buf); + debug_assert!(result.is_ok()); + criterion::black_box(&result); + } + }) + } + }) + .throughput(Throughput::Bytes(decoded_len)); + + criterion + .benchmark_group(&name) + .bench_function("encoded_len", move |b| { + b.iter(|| { + let mut sum = 0; + for &value in &values { + sum += encoded_len_varint(value); + } + criterion::black_box(sum); + }) + }) + .throughput(Throughput::Bytes(decoded_len)); +} + +fn main() { + let mut criterion = Criterion::default().configure_from_args(); + + // Benchmark encoding and decoding 100 small (1 byte) varints. + benchmark_varint(&mut criterion, "small", (0..100).collect()); + + // Benchmark encoding and decoding 100 medium (5 byte) varints. + benchmark_varint(&mut criterion, "medium", (1 << 28..).take(100).collect()); + + // Benchmark encoding and decoding 100 large (10 byte) varints. + benchmark_varint(&mut criterion, "large", (1 << 63..).take(100).collect()); + + // Benchmark encoding and decoding 100 varints of mixed width (average 5.5 bytes). + benchmark_varint( + &mut criterion, + "mixed", + (0..10) + .flat_map(move |width| { + let exponent = width * 7; + (0..10).map(move |offset| offset + (1 << exponent)) + }) + .collect(), + ); + + criterion.final_summary(); +} diff --git a/third_party/rust/prost/clippy.toml b/third_party/rust/prost/clippy.toml new file mode 100644 index 0000000000..5988e12d8f --- /dev/null +++ b/third_party/rust/prost/clippy.toml @@ -0,0 +1 @@ +too-many-arguments-threshold=8 diff --git a/third_party/rust/prost/prepare-release.sh b/third_party/rust/prost/prepare-release.sh new file mode 100755 index 0000000000..34e5202c24 --- /dev/null +++ b/third_party/rust/prost/prepare-release.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +# Script which automates modifying source version fields, and creating a release +# commit and tag. The commit and tag are not automatically pushed, nor are the +# crates published (see publish-release.sh). + +set -ex + +if [ "$#" -ne 1 ] +then + echo "Usage: $0 <version>" + exit 1 +fi + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +VERSION="$1" +MINOR="$( echo ${VERSION} | cut -d\. -f1-2 )" + +VERSION_MATCHER="([a-z0-9\\.-]+)" +PROST_CRATE_MATCHER="(prost|prost-[a-z]+)" + +# Update the README.md. +sed -i -E "s/${PROST_CRATE_MATCHER} = \"${VERSION_MATCHER}\"/\1 = \"${MINOR}\"/" "$DIR/README.md" + +# Update html_root_url attributes. +sed -i -E "s~html_root_url = \"https://docs\.rs/${PROST_CRATE_MATCHER}/$VERSION_MATCHER\"~html_root_url = \"https://docs.rs/\1/${VERSION}\"~" \ + "$DIR/src/lib.rs" \ + "$DIR/prost-derive/src/lib.rs" \ + "$DIR/prost-build/src/lib.rs" \ + "$DIR/prost-types/src/lib.rs" + +# Update Cargo.toml version fields. +sed -i -E "s/^version = \"${VERSION_MATCHER}\"$/version = \"${VERSION}\"/" \ + "$DIR/Cargo.toml" \ + "$DIR/prost-derive/Cargo.toml" \ + "$DIR/prost-build/Cargo.toml" \ + "$DIR/prost-types/Cargo.toml" + +# Update Cargo.toml dependency versions. +sed -i -E "s/^${PROST_CRATE_MATCHER} = \{ version = \"${VERSION_MATCHER}\"/\1 = { version = \"${VERSION}\"/" \ + "$DIR/Cargo.toml" \ + "$DIR/prost-derive/Cargo.toml" \ + "$DIR/prost-build/Cargo.toml" \ + "$DIR/prost-types/Cargo.toml" + +git commit -a -m "release ${VERSION}" +git tag -a "v${VERSION}" -m "release ${VERSION}" diff --git a/third_party/rust/prost/publish-release.sh b/third_party/rust/prost/publish-release.sh new file mode 100755 index 0000000000..b2be598298 --- /dev/null +++ b/third_party/rust/prost/publish-release.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# Script which automates publishing a crates.io release of the prost crates. + +set -ex + +if [ "$#" -ne 0 ] +then + echo "Usage: $0" + exit 1 +fi + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +CRATES=( \ + "prost-derive" \ + "." \ + "prost-types" \ + "prost-build" \ +) + +for CRATE in "${CRATES[@]}"; do + pushd "$DIR/$CRATE" + cargo publish + popd +done diff --git a/third_party/rust/prost/src/encoding.rs b/third_party/rust/prost/src/encoding.rs new file mode 100644 index 0000000000..252358685c --- /dev/null +++ b/third_party/rust/prost/src/encoding.rs @@ -0,0 +1,1770 @@ +//! Utility functions and types for encoding and decoding Protobuf types. +//! +//! Meant to be used only from `Message` implementations. + +#![allow(clippy::implicit_hasher, clippy::ptr_arg)] + +use alloc::collections::BTreeMap; +use alloc::format; +use alloc::string::String; +use alloc::vec::Vec; +use core::cmp::min; +use core::convert::TryFrom; +use core::mem; +use core::str; +use core::u32; +use core::usize; + +use ::bytes::{Buf, BufMut, Bytes}; + +use crate::DecodeError; +use crate::Message; + +/// Encodes an integer value into LEB128 variable length format, and writes it to the buffer. +/// The buffer must have enough remaining space (maximum 10 bytes). +#[inline] +pub fn encode_varint<B>(mut value: u64, buf: &mut B) +where + B: BufMut, +{ + // Safety notes: + // + // - ptr::write is an unsafe raw pointer write. The use here is safe since the length of the + // uninit slice is checked. + // - advance_mut is unsafe because it could cause uninitialized memory to be advanced over. The + // use here is safe since each byte which is advanced over has been written to in the + // previous loop iteration. + unsafe { + let mut i; + 'outer: loop { + i = 0; + + let uninit_slice = buf.chunk_mut(); + for offset in 0..uninit_slice.len() { + i += 1; + let ptr = uninit_slice.as_mut_ptr().add(offset); + if value < 0x80 { + ptr.write(value as u8); + break 'outer; + } else { + ptr.write(((value & 0x7F) | 0x80) as u8); + value >>= 7; + } + } + + buf.advance_mut(i); + debug_assert!(buf.has_remaining_mut()); + } + + buf.advance_mut(i); + } +} + +/// Decodes a LEB128-encoded variable length integer from the buffer. +pub fn decode_varint<B>(buf: &mut B) -> Result<u64, DecodeError> +where + B: Buf, +{ + let bytes = buf.chunk(); + let len = bytes.len(); + if len == 0 { + return Err(DecodeError::new("invalid varint")); + } + + let byte = unsafe { *bytes.get_unchecked(0) }; + if byte < 0x80 { + buf.advance(1); + Ok(u64::from(byte)) + } else if len > 10 || bytes[len - 1] < 0x80 { + let (value, advance) = unsafe { decode_varint_slice(bytes) }?; + buf.advance(advance); + Ok(value) + } else { + decode_varint_slow(buf) + } +} + +/// Decodes a LEB128-encoded variable length integer from the slice, returning the value and the +/// number of bytes read. +/// +/// Based loosely on [`ReadVarint64FromArray`][1]. +/// +/// ## Safety +/// +/// The caller must ensure that `bytes` is non-empty and either `bytes.len() >= 10` or the last +/// element in bytes is < `0x80`. +/// +/// [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.cc#L365-L406 +#[inline] +unsafe fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> { + // Fully unrolled varint decoding loop. Splitting into 32-bit pieces gives better performance. + + let mut b: u8; + let mut part0: u32; + b = *bytes.get_unchecked(0); + part0 = u32::from(b); + if b < 0x80 { + return Ok((u64::from(part0), 1)); + }; + part0 -= 0x80; + b = *bytes.get_unchecked(1); + part0 += u32::from(b) << 7; + if b < 0x80 { + return Ok((u64::from(part0), 2)); + }; + part0 -= 0x80 << 7; + b = *bytes.get_unchecked(2); + part0 += u32::from(b) << 14; + if b < 0x80 { + return Ok((u64::from(part0), 3)); + }; + part0 -= 0x80 << 14; + b = *bytes.get_unchecked(3); + part0 += u32::from(b) << 21; + if b < 0x80 { + return Ok((u64::from(part0), 4)); + }; + part0 -= 0x80 << 21; + let value = u64::from(part0); + + let mut part1: u32; + b = *bytes.get_unchecked(4); + part1 = u32::from(b); + if b < 0x80 { + return Ok((value + (u64::from(part1) << 28), 5)); + }; + part1 -= 0x80; + b = *bytes.get_unchecked(5); + part1 += u32::from(b) << 7; + if b < 0x80 { + return Ok((value + (u64::from(part1) << 28), 6)); + }; + part1 -= 0x80 << 7; + b = *bytes.get_unchecked(6); + part1 += u32::from(b) << 14; + if b < 0x80 { + return Ok((value + (u64::from(part1) << 28), 7)); + }; + part1 -= 0x80 << 14; + b = *bytes.get_unchecked(7); + part1 += u32::from(b) << 21; + if b < 0x80 { + return Ok((value + (u64::from(part1) << 28), 8)); + }; + part1 -= 0x80 << 21; + let value = value + ((u64::from(part1)) << 28); + + let mut part2: u32; + b = *bytes.get_unchecked(8); + part2 = u32::from(b); + if b < 0x80 { + return Ok((value + (u64::from(part2) << 56), 9)); + }; + part2 -= 0x80; + b = *bytes.get_unchecked(9); + part2 += u32::from(b) << 7; + if b < 0x80 { + return Ok((value + (u64::from(part2) << 56), 10)); + }; + + // We have overrun the maximum size of a varint (10 bytes). Assume the data is corrupt. + Err(DecodeError::new("invalid varint")) +} + +/// Decodes a LEB128-encoded variable length integer from the buffer, advancing the buffer as +/// necessary. +#[inline(never)] +fn decode_varint_slow<B>(buf: &mut B) -> Result<u64, DecodeError> +where + B: Buf, +{ + let mut value = 0; + for count in 0..min(10, buf.remaining()) { + let byte = buf.get_u8(); + value |= u64::from(byte & 0x7F) << (count * 7); + if byte <= 0x7F { + return Ok(value); + } + } + + Err(DecodeError::new("invalid varint")) +} + +/// Additional information passed to every decode/merge function. +/// +/// The context should be passed by value and can be freely cloned. When passing +/// to a function which is decoding a nested object, then use `enter_recursion`. +#[derive(Clone, Debug)] +pub struct DecodeContext { + /// How many times we can recurse in the current decode stack before we hit + /// the recursion limit. + /// + /// The recursion limit is defined by `RECURSION_LIMIT` and cannot be + /// customized. The recursion limit can be ignored by building the Prost + /// crate with the `no-recursion-limit` feature. + #[cfg(not(feature = "no-recursion-limit"))] + recurse_count: u32, +} + +impl Default for DecodeContext { + #[cfg(not(feature = "no-recursion-limit"))] + #[inline] + fn default() -> DecodeContext { + DecodeContext { + recurse_count: crate::RECURSION_LIMIT, + } + } + + #[cfg(feature = "no-recursion-limit")] + #[inline] + fn default() -> DecodeContext { + DecodeContext {} + } +} + +impl DecodeContext { + /// Call this function before recursively decoding. + /// + /// There is no `exit` function since this function creates a new `DecodeContext` + /// to be used at the next level of recursion. Continue to use the old context + // at the previous level of recursion. + #[cfg(not(feature = "no-recursion-limit"))] + #[inline] + pub(crate) fn enter_recursion(&self) -> DecodeContext { + DecodeContext { + recurse_count: self.recurse_count - 1, + } + } + + #[cfg(feature = "no-recursion-limit")] + #[inline] + pub(crate) fn enter_recursion(&self) -> DecodeContext { + DecodeContext {} + } + + /// Checks whether the recursion limit has been reached in the stack of + /// decodes described by the `DecodeContext` at `self.ctx`. + /// + /// Returns `Ok<()>` if it is ok to continue recursing. + /// Returns `Err<DecodeError>` if the recursion limit has been reached. + #[cfg(not(feature = "no-recursion-limit"))] + #[inline] + pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> { + if self.recurse_count == 0 { + Err(DecodeError::new("recursion limit reached")) + } else { + Ok(()) + } + } + + #[cfg(feature = "no-recursion-limit")] + #[inline] + #[allow(clippy::unnecessary_wraps)] // needed in other features + pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> { + Ok(()) + } +} + +/// Returns the encoded length of the value in LEB128 variable length format. +/// The returned value will be between 1 and 10, inclusive. +#[inline] +pub fn encoded_len_varint(value: u64) -> usize { + // Based on [VarintSize64][1]. + // [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.h#L1301-L1309 + ((((value | 1).leading_zeros() ^ 63) * 9 + 73) / 64) as usize +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(u8)] +pub enum WireType { + Varint = 0, + SixtyFourBit = 1, + LengthDelimited = 2, + StartGroup = 3, + EndGroup = 4, + ThirtyTwoBit = 5, +} + +pub const MIN_TAG: u32 = 1; +pub const MAX_TAG: u32 = (1 << 29) - 1; + +impl TryFrom<u64> for WireType { + type Error = DecodeError; + + #[inline] + fn try_from(value: u64) -> Result<Self, Self::Error> { + match value { + 0 => Ok(WireType::Varint), + 1 => Ok(WireType::SixtyFourBit), + 2 => Ok(WireType::LengthDelimited), + 3 => Ok(WireType::StartGroup), + 4 => Ok(WireType::EndGroup), + 5 => Ok(WireType::ThirtyTwoBit), + _ => Err(DecodeError::new(format!( + "invalid wire type value: {}", + value + ))), + } + } +} + +/// Encodes a Protobuf field key, which consists of a wire type designator and +/// the field tag. +#[inline] +pub fn encode_key<B>(tag: u32, wire_type: WireType, buf: &mut B) +where + B: BufMut, +{ + debug_assert!((MIN_TAG..=MAX_TAG).contains(&tag)); + let key = (tag << 3) | wire_type as u32; + encode_varint(u64::from(key), buf); +} + +/// Decodes a Protobuf field key, which consists of a wire type designator and +/// the field tag. +#[inline(always)] +pub fn decode_key<B>(buf: &mut B) -> Result<(u32, WireType), DecodeError> +where + B: Buf, +{ + let key = decode_varint(buf)?; + if key > u64::from(u32::MAX) { + return Err(DecodeError::new(format!("invalid key value: {}", key))); + } + let wire_type = WireType::try_from(key & 0x07)?; + let tag = key as u32 >> 3; + + if tag < MIN_TAG { + return Err(DecodeError::new("invalid tag value: 0")); + } + + Ok((tag, wire_type)) +} + +/// Returns the width of an encoded Protobuf field key with the given tag. +/// The returned width will be between 1 and 5 bytes (inclusive). +#[inline] +pub fn key_len(tag: u32) -> usize { + encoded_len_varint(u64::from(tag << 3)) +} + +/// Checks that the expected wire type matches the actual wire type, +/// or returns an error result. +#[inline] +pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> { + if expected != actual { + return Err(DecodeError::new(format!( + "invalid wire type: {:?} (expected {:?})", + actual, expected + ))); + } + Ok(()) +} + +/// Helper function which abstracts reading a length delimiter prefix followed +/// by decoding values until the length of bytes is exhausted. +pub fn merge_loop<T, M, B>( + value: &mut T, + buf: &mut B, + ctx: DecodeContext, + mut merge: M, +) -> Result<(), DecodeError> +where + M: FnMut(&mut T, &mut B, DecodeContext) -> Result<(), DecodeError>, + B: Buf, +{ + let len = decode_varint(buf)?; + let remaining = buf.remaining(); + if len > remaining as u64 { + return Err(DecodeError::new("buffer underflow")); + } + + let limit = remaining - len as usize; + while buf.remaining() > limit { + merge(value, buf, ctx.clone())?; + } + + if buf.remaining() != limit { + return Err(DecodeError::new("delimited length exceeded")); + } + Ok(()) +} + +pub fn skip_field<B>( + wire_type: WireType, + tag: u32, + buf: &mut B, + ctx: DecodeContext, +) -> Result<(), DecodeError> +where + B: Buf, +{ + ctx.limit_reached()?; + let len = match wire_type { + WireType::Varint => decode_varint(buf).map(|_| 0)?, + WireType::ThirtyTwoBit => 4, + WireType::SixtyFourBit => 8, + WireType::LengthDelimited => decode_varint(buf)?, + WireType::StartGroup => loop { + let (inner_tag, inner_wire_type) = decode_key(buf)?; + match inner_wire_type { + WireType::EndGroup => { + if inner_tag != tag { + return Err(DecodeError::new("unexpected end group tag")); + } + break 0; + } + _ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?, + } + }, + WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")), + }; + + if len > buf.remaining() as u64 { + return Err(DecodeError::new("buffer underflow")); + } + + buf.advance(len as usize); + Ok(()) +} + +/// Helper macro which emits an `encode_repeated` function for the type. +macro_rules! encode_repeated { + ($ty:ty) => { + pub fn encode_repeated<B>(tag: u32, values: &[$ty], buf: &mut B) + where + B: BufMut, + { + for value in values { + encode(tag, value, buf); + } + } + }; +} + +/// Helper macro which emits a `merge_repeated` function for the numeric type. +macro_rules! merge_repeated_numeric { + ($ty:ty, + $wire_type:expr, + $merge:ident, + $merge_repeated:ident) => { + pub fn $merge_repeated<B>( + wire_type: WireType, + values: &mut Vec<$ty>, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + if wire_type == WireType::LengthDelimited { + // Packed. + merge_loop(values, buf, ctx, |values, buf, ctx| { + let mut value = Default::default(); + $merge($wire_type, &mut value, buf, ctx)?; + values.push(value); + Ok(()) + }) + } else { + // Unpacked. + check_wire_type($wire_type, wire_type)?; + let mut value = Default::default(); + $merge(wire_type, &mut value, buf, ctx)?; + values.push(value); + Ok(()) + } + } + }; +} + +/// Macro which emits a module containing a set of encoding functions for a +/// variable width numeric type. +macro_rules! varint { + ($ty:ty, + $proto_ty:ident) => ( + varint!($ty, + $proto_ty, + to_uint64(value) { *value as u64 }, + from_uint64(value) { value as $ty }); + ); + + ($ty:ty, + $proto_ty:ident, + to_uint64($to_uint64_value:ident) $to_uint64:expr, + from_uint64($from_uint64_value:ident) $from_uint64:expr) => ( + + pub mod $proto_ty { + use crate::encoding::*; + + pub fn encode<B>(tag: u32, $to_uint64_value: &$ty, buf: &mut B) where B: BufMut { + encode_key(tag, WireType::Varint, buf); + encode_varint($to_uint64, buf); + } + + pub fn merge<B>(wire_type: WireType, value: &mut $ty, buf: &mut B, _ctx: DecodeContext) -> Result<(), DecodeError> where B: Buf { + check_wire_type(WireType::Varint, wire_type)?; + let $from_uint64_value = decode_varint(buf)?; + *value = $from_uint64; + Ok(()) + } + + encode_repeated!($ty); + + pub fn encode_packed<B>(tag: u32, values: &[$ty], buf: &mut B) where B: BufMut { + if values.is_empty() { return; } + + encode_key(tag, WireType::LengthDelimited, buf); + let len: usize = values.iter().map(|$to_uint64_value| { + encoded_len_varint($to_uint64) + }).sum(); + encode_varint(len as u64, buf); + + for $to_uint64_value in values { + encode_varint($to_uint64, buf); + } + } + + merge_repeated_numeric!($ty, WireType::Varint, merge, merge_repeated); + + #[inline] + pub fn encoded_len(tag: u32, $to_uint64_value: &$ty) -> usize { + key_len(tag) + encoded_len_varint($to_uint64) + } + + #[inline] + pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize { + key_len(tag) * values.len() + values.iter().map(|$to_uint64_value| { + encoded_len_varint($to_uint64) + }).sum::<usize>() + } + + #[inline] + pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize { + if values.is_empty() { + 0 + } else { + let len = values.iter() + .map(|$to_uint64_value| encoded_len_varint($to_uint64)) + .sum::<usize>(); + key_len(tag) + encoded_len_varint(len as u64) + len + } + } + + #[cfg(test)] + mod test { + use proptest::prelude::*; + + use crate::encoding::$proto_ty::*; + use crate::encoding::test::{ + check_collection_type, + check_type, + }; + + proptest! { + #[test] + fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) { + check_type(value, tag, WireType::Varint, + encode, merge, encoded_len)?; + } + #[test] + fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) { + check_collection_type(value, tag, WireType::Varint, + encode_repeated, merge_repeated, + encoded_len_repeated)?; + } + #[test] + fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) { + check_type(value, tag, WireType::LengthDelimited, + encode_packed, merge_repeated, + encoded_len_packed)?; + } + } + } + } + + ); +} +varint!(bool, bool, + to_uint64(value) if *value { 1u64 } else { 0u64 }, + from_uint64(value) value != 0); +varint!(i32, int32); +varint!(i64, int64); +varint!(u32, uint32); +varint!(u64, uint64); +varint!(i32, sint32, +to_uint64(value) { + ((value << 1) ^ (value >> 31)) as u32 as u64 +}, +from_uint64(value) { + let value = value as u32; + ((value >> 1) as i32) ^ (-((value & 1) as i32)) +}); +varint!(i64, sint64, +to_uint64(value) { + ((value << 1) ^ (value >> 63)) as u64 +}, +from_uint64(value) { + ((value >> 1) as i64) ^ (-((value & 1) as i64)) +}); + +/// Macro which emits a module containing a set of encoding functions for a +/// fixed width numeric type. +macro_rules! fixed_width { + ($ty:ty, + $width:expr, + $wire_type:expr, + $proto_ty:ident, + $put:ident, + $get:ident) => { + pub mod $proto_ty { + use crate::encoding::*; + + pub fn encode<B>(tag: u32, value: &$ty, buf: &mut B) + where + B: BufMut, + { + encode_key(tag, $wire_type, buf); + buf.$put(*value); + } + + pub fn merge<B>( + wire_type: WireType, + value: &mut $ty, + buf: &mut B, + _ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + check_wire_type($wire_type, wire_type)?; + if buf.remaining() < $width { + return Err(DecodeError::new("buffer underflow")); + } + *value = buf.$get(); + Ok(()) + } + + encode_repeated!($ty); + + pub fn encode_packed<B>(tag: u32, values: &[$ty], buf: &mut B) + where + B: BufMut, + { + if values.is_empty() { + return; + } + + encode_key(tag, WireType::LengthDelimited, buf); + let len = values.len() as u64 * $width; + encode_varint(len as u64, buf); + + for value in values { + buf.$put(*value); + } + } + + merge_repeated_numeric!($ty, $wire_type, merge, merge_repeated); + + #[inline] + pub fn encoded_len(tag: u32, _: &$ty) -> usize { + key_len(tag) + $width + } + + #[inline] + pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize { + (key_len(tag) + $width) * values.len() + } + + #[inline] + pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize { + if values.is_empty() { + 0 + } else { + let len = $width * values.len(); + key_len(tag) + encoded_len_varint(len as u64) + len + } + } + + #[cfg(test)] + mod test { + use proptest::prelude::*; + + use super::super::test::{check_collection_type, check_type}; + use super::*; + + proptest! { + #[test] + fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) { + check_type(value, tag, $wire_type, + encode, merge, encoded_len)?; + } + #[test] + fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) { + check_collection_type(value, tag, $wire_type, + encode_repeated, merge_repeated, + encoded_len_repeated)?; + } + #[test] + fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) { + check_type(value, tag, WireType::LengthDelimited, + encode_packed, merge_repeated, + encoded_len_packed)?; + } + } + } + } + }; +} +fixed_width!( + f32, + 4, + WireType::ThirtyTwoBit, + float, + put_f32_le, + get_f32_le +); +fixed_width!( + f64, + 8, + WireType::SixtyFourBit, + double, + put_f64_le, + get_f64_le +); +fixed_width!( + u32, + 4, + WireType::ThirtyTwoBit, + fixed32, + put_u32_le, + get_u32_le +); +fixed_width!( + u64, + 8, + WireType::SixtyFourBit, + fixed64, + put_u64_le, + get_u64_le +); +fixed_width!( + i32, + 4, + WireType::ThirtyTwoBit, + sfixed32, + put_i32_le, + get_i32_le +); +fixed_width!( + i64, + 8, + WireType::SixtyFourBit, + sfixed64, + put_i64_le, + get_i64_le +); + +/// Macro which emits encoding functions for a length-delimited type. +macro_rules! length_delimited { + ($ty:ty) => { + encode_repeated!($ty); + + pub fn merge_repeated<B>( + wire_type: WireType, + values: &mut Vec<$ty>, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + check_wire_type(WireType::LengthDelimited, wire_type)?; + let mut value = Default::default(); + merge(wire_type, &mut value, buf, ctx)?; + values.push(value); + Ok(()) + } + + #[inline] + pub fn encoded_len(tag: u32, value: &$ty) -> usize { + key_len(tag) + encoded_len_varint(value.len() as u64) + value.len() + } + + #[inline] + pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize { + key_len(tag) * values.len() + + values + .iter() + .map(|value| encoded_len_varint(value.len() as u64) + value.len()) + .sum::<usize>() + } + }; +} + +pub mod string { + use super::*; + + pub fn encode<B>(tag: u32, value: &String, buf: &mut B) + where + B: BufMut, + { + encode_key(tag, WireType::LengthDelimited, buf); + encode_varint(value.len() as u64, buf); + buf.put_slice(value.as_bytes()); + } + pub fn merge<B>( + wire_type: WireType, + value: &mut String, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + // ## Unsafety + // + // `string::merge` reuses `bytes::merge`, with an additional check of utf-8 + // well-formedness. If the utf-8 is not well-formed, or if any other error occurs, then the + // string is cleared, so as to avoid leaking a string field with invalid data. + // + // This implementation uses the unsafe `String::as_mut_vec` method instead of the safe + // alternative of temporarily swapping an empty `String` into the field, because it results + // in up to 10% better performance on the protobuf message decoding benchmarks. + // + // It's required when using `String::as_mut_vec` that invalid utf-8 data not be leaked into + // the backing `String`. To enforce this, even in the event of a panic in `bytes::merge` or + // in the buf implementation, a drop guard is used. + unsafe { + struct DropGuard<'a>(&'a mut Vec<u8>); + impl<'a> Drop for DropGuard<'a> { + #[inline] + fn drop(&mut self) { + self.0.clear(); + } + } + + let drop_guard = DropGuard(value.as_mut_vec()); + bytes::merge(wire_type, drop_guard.0, buf, ctx)?; + match str::from_utf8(drop_guard.0) { + Ok(_) => { + // Success; do not clear the bytes. + mem::forget(drop_guard); + Ok(()) + } + Err(_) => Err(DecodeError::new( + "invalid string value: data is not UTF-8 encoded", + )), + } + } + } + + length_delimited!(String); + + #[cfg(test)] + mod test { + use proptest::prelude::*; + + use super::super::test::{check_collection_type, check_type}; + use super::*; + + proptest! { + #[test] + fn check(value: String, tag in MIN_TAG..=MAX_TAG) { + super::test::check_type(value, tag, WireType::LengthDelimited, + encode, merge, encoded_len)?; + } + #[test] + fn check_repeated(value: Vec<String>, tag in MIN_TAG..=MAX_TAG) { + super::test::check_collection_type(value, tag, WireType::LengthDelimited, + encode_repeated, merge_repeated, + encoded_len_repeated)?; + } + } + } +} + +pub trait BytesAdapter: sealed::BytesAdapter {} + +mod sealed { + use super::{Buf, BufMut}; + + pub trait BytesAdapter: Default + Sized + 'static { + fn len(&self) -> usize; + + /// Replace contents of this buffer with the contents of another buffer. + fn replace_with<B>(&mut self, buf: B) + where + B: Buf; + + /// Appends this buffer to the (contents of) other buffer. + fn append_to<B>(&self, buf: &mut B) + where + B: BufMut; + + fn is_empty(&self) -> bool { + self.len() == 0 + } + } +} + +impl BytesAdapter for Bytes {} + +impl sealed::BytesAdapter for Bytes { + fn len(&self) -> usize { + Buf::remaining(self) + } + + fn replace_with<B>(&mut self, mut buf: B) + where + B: Buf, + { + *self = buf.copy_to_bytes(buf.remaining()); + } + + fn append_to<B>(&self, buf: &mut B) + where + B: BufMut, + { + buf.put(self.clone()) + } +} + +impl BytesAdapter for Vec<u8> {} + +impl sealed::BytesAdapter for Vec<u8> { + fn len(&self) -> usize { + Vec::len(self) + } + + fn replace_with<B>(&mut self, buf: B) + where + B: Buf, + { + self.clear(); + self.reserve(buf.remaining()); + self.put(buf); + } + + fn append_to<B>(&self, buf: &mut B) + where + B: BufMut, + { + buf.put(self.as_slice()) + } +} + +pub mod bytes { + use super::*; + + pub fn encode<A, B>(tag: u32, value: &A, buf: &mut B) + where + A: BytesAdapter, + B: BufMut, + { + encode_key(tag, WireType::LengthDelimited, buf); + encode_varint(value.len() as u64, buf); + value.append_to(buf); + } + + pub fn merge<A, B>( + wire_type: WireType, + value: &mut A, + buf: &mut B, + _ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + A: BytesAdapter, + B: Buf, + { + check_wire_type(WireType::LengthDelimited, wire_type)?; + let len = decode_varint(buf)?; + if len > buf.remaining() as u64 { + return Err(DecodeError::new("buffer underflow")); + } + let len = len as usize; + + // Clear the existing value. This follows from the following rule in the encoding guide[1]: + // + // > Normally, an encoded message would never have more than one instance of a non-repeated + // > field. However, parsers are expected to handle the case in which they do. For numeric + // > types and strings, if the same field appears multiple times, the parser accepts the + // > last value it sees. + // + // [1]: https://developers.google.com/protocol-buffers/docs/encoding#optional + value.replace_with(buf.copy_to_bytes(len)); + Ok(()) + } + + length_delimited!(impl BytesAdapter); + + #[cfg(test)] + mod test { + use proptest::prelude::*; + + use super::super::test::{check_collection_type, check_type}; + use super::*; + + proptest! { + #[test] + fn check_vec(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) { + super::test::check_type::<Vec<u8>, Vec<u8>>(value, tag, WireType::LengthDelimited, + encode, merge, encoded_len)?; + } + + #[test] + fn check_bytes(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) { + let value = Bytes::from(value); + super::test::check_type::<Bytes, Bytes>(value, tag, WireType::LengthDelimited, + encode, merge, encoded_len)?; + } + + #[test] + fn check_repeated_vec(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) { + super::test::check_collection_type(value, tag, WireType::LengthDelimited, + encode_repeated, merge_repeated, + encoded_len_repeated)?; + } + + #[test] + fn check_repeated_bytes(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) { + let value = value.into_iter().map(Bytes::from).collect(); + super::test::check_collection_type(value, tag, WireType::LengthDelimited, + encode_repeated, merge_repeated, + encoded_len_repeated)?; + } + } + } +} + +pub mod message { + use super::*; + + pub fn encode<M, B>(tag: u32, msg: &M, buf: &mut B) + where + M: Message, + B: BufMut, + { + encode_key(tag, WireType::LengthDelimited, buf); + encode_varint(msg.encoded_len() as u64, buf); + msg.encode_raw(buf); + } + + pub fn merge<M, B>( + wire_type: WireType, + msg: &mut M, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + M: Message, + B: Buf, + { + check_wire_type(WireType::LengthDelimited, wire_type)?; + ctx.limit_reached()?; + merge_loop( + msg, + buf, + ctx.enter_recursion(), + |msg: &mut M, buf: &mut B, ctx| { + let (tag, wire_type) = decode_key(buf)?; + msg.merge_field(tag, wire_type, buf, ctx) + }, + ) + } + + pub fn encode_repeated<M, B>(tag: u32, messages: &[M], buf: &mut B) + where + M: Message, + B: BufMut, + { + for msg in messages { + encode(tag, msg, buf); + } + } + + pub fn merge_repeated<M, B>( + wire_type: WireType, + messages: &mut Vec<M>, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + M: Message + Default, + B: Buf, + { + check_wire_type(WireType::LengthDelimited, wire_type)?; + let mut msg = M::default(); + merge(WireType::LengthDelimited, &mut msg, buf, ctx)?; + messages.push(msg); + Ok(()) + } + + #[inline] + pub fn encoded_len<M>(tag: u32, msg: &M) -> usize + where + M: Message, + { + let len = msg.encoded_len(); + key_len(tag) + encoded_len_varint(len as u64) + len + } + + #[inline] + pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize + where + M: Message, + { + key_len(tag) * messages.len() + + messages + .iter() + .map(Message::encoded_len) + .map(|len| len + encoded_len_varint(len as u64)) + .sum::<usize>() + } +} + +pub mod group { + use super::*; + + pub fn encode<M, B>(tag: u32, msg: &M, buf: &mut B) + where + M: Message, + B: BufMut, + { + encode_key(tag, WireType::StartGroup, buf); + msg.encode_raw(buf); + encode_key(tag, WireType::EndGroup, buf); + } + + pub fn merge<M, B>( + tag: u32, + wire_type: WireType, + msg: &mut M, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + M: Message, + B: Buf, + { + check_wire_type(WireType::StartGroup, wire_type)?; + + ctx.limit_reached()?; + loop { + let (field_tag, field_wire_type) = decode_key(buf)?; + if field_wire_type == WireType::EndGroup { + if field_tag != tag { + return Err(DecodeError::new("unexpected end group tag")); + } + return Ok(()); + } + + M::merge_field(msg, field_tag, field_wire_type, buf, ctx.enter_recursion())?; + } + } + + pub fn encode_repeated<M, B>(tag: u32, messages: &[M], buf: &mut B) + where + M: Message, + B: BufMut, + { + for msg in messages { + encode(tag, msg, buf); + } + } + + pub fn merge_repeated<M, B>( + tag: u32, + wire_type: WireType, + messages: &mut Vec<M>, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + M: Message + Default, + B: Buf, + { + check_wire_type(WireType::StartGroup, wire_type)?; + let mut msg = M::default(); + merge(tag, WireType::StartGroup, &mut msg, buf, ctx)?; + messages.push(msg); + Ok(()) + } + + #[inline] + pub fn encoded_len<M>(tag: u32, msg: &M) -> usize + where + M: Message, + { + 2 * key_len(tag) + msg.encoded_len() + } + + #[inline] + pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize + where + M: Message, + { + 2 * key_len(tag) * messages.len() + messages.iter().map(Message::encoded_len).sum::<usize>() + } +} + +/// Rust doesn't have a `Map` trait, so macros are currently the best way to be +/// generic over `HashMap` and `BTreeMap`. +macro_rules! map { + ($map_ty:ident) => { + use crate::encoding::*; + use core::hash::Hash; + + /// Generic protobuf map encode function. + pub fn encode<K, V, B, KE, KL, VE, VL>( + key_encode: KE, + key_encoded_len: KL, + val_encode: VE, + val_encoded_len: VL, + tag: u32, + values: &$map_ty<K, V>, + buf: &mut B, + ) where + K: Default + Eq + Hash + Ord, + V: Default + PartialEq, + B: BufMut, + KE: Fn(u32, &K, &mut B), + KL: Fn(u32, &K) -> usize, + VE: Fn(u32, &V, &mut B), + VL: Fn(u32, &V) -> usize, + { + encode_with_default( + key_encode, + key_encoded_len, + val_encode, + val_encoded_len, + &V::default(), + tag, + values, + buf, + ) + } + + /// Generic protobuf map merge function. + pub fn merge<K, V, B, KM, VM>( + key_merge: KM, + val_merge: VM, + values: &mut $map_ty<K, V>, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + K: Default + Eq + Hash + Ord, + V: Default, + B: Buf, + KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>, + VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>, + { + merge_with_default(key_merge, val_merge, V::default(), values, buf, ctx) + } + + /// Generic protobuf map encode function. + pub fn encoded_len<K, V, KL, VL>( + key_encoded_len: KL, + val_encoded_len: VL, + tag: u32, + values: &$map_ty<K, V>, + ) -> usize + where + K: Default + Eq + Hash + Ord, + V: Default + PartialEq, + KL: Fn(u32, &K) -> usize, + VL: Fn(u32, &V) -> usize, + { + encoded_len_with_default(key_encoded_len, val_encoded_len, &V::default(), tag, values) + } + + /// Generic protobuf map encode function with an overriden value default. + /// + /// This is necessary because enumeration values can have a default value other + /// than 0 in proto2. + pub fn encode_with_default<K, V, B, KE, KL, VE, VL>( + key_encode: KE, + key_encoded_len: KL, + val_encode: VE, + val_encoded_len: VL, + val_default: &V, + tag: u32, + values: &$map_ty<K, V>, + buf: &mut B, + ) where + K: Default + Eq + Hash + Ord, + V: PartialEq, + B: BufMut, + KE: Fn(u32, &K, &mut B), + KL: Fn(u32, &K) -> usize, + VE: Fn(u32, &V, &mut B), + VL: Fn(u32, &V) -> usize, + { + for (key, val) in values.iter() { + let skip_key = key == &K::default(); + let skip_val = val == val_default; + + let len = (if skip_key { 0 } else { key_encoded_len(1, key) }) + + (if skip_val { 0 } else { val_encoded_len(2, val) }); + + encode_key(tag, WireType::LengthDelimited, buf); + encode_varint(len as u64, buf); + if !skip_key { + key_encode(1, key, buf); + } + if !skip_val { + val_encode(2, val, buf); + } + } + } + + /// Generic protobuf map merge function with an overriden value default. + /// + /// This is necessary because enumeration values can have a default value other + /// than 0 in proto2. + pub fn merge_with_default<K, V, B, KM, VM>( + key_merge: KM, + val_merge: VM, + val_default: V, + values: &mut $map_ty<K, V>, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + K: Default + Eq + Hash + Ord, + B: Buf, + KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>, + VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>, + { + let mut key = Default::default(); + let mut val = val_default; + ctx.limit_reached()?; + merge_loop( + &mut (&mut key, &mut val), + buf, + ctx.enter_recursion(), + |&mut (ref mut key, ref mut val), buf, ctx| { + let (tag, wire_type) = decode_key(buf)?; + match tag { + 1 => key_merge(wire_type, key, buf, ctx), + 2 => val_merge(wire_type, val, buf, ctx), + _ => skip_field(wire_type, tag, buf, ctx), + } + }, + )?; + values.insert(key, val); + + Ok(()) + } + + /// Generic protobuf map encode function with an overriden value default. + /// + /// This is necessary because enumeration values can have a default value other + /// than 0 in proto2. + pub fn encoded_len_with_default<K, V, KL, VL>( + key_encoded_len: KL, + val_encoded_len: VL, + val_default: &V, + tag: u32, + values: &$map_ty<K, V>, + ) -> usize + where + K: Default + Eq + Hash + Ord, + V: PartialEq, + KL: Fn(u32, &K) -> usize, + VL: Fn(u32, &V) -> usize, + { + key_len(tag) * values.len() + + values + .iter() + .map(|(key, val)| { + let len = (if key == &K::default() { + 0 + } else { + key_encoded_len(1, key) + }) + (if val == val_default { + 0 + } else { + val_encoded_len(2, val) + }); + encoded_len_varint(len as u64) + len + }) + .sum::<usize>() + } + }; +} + +#[cfg(feature = "std")] +pub mod hash_map { + use std::collections::HashMap; + map!(HashMap); +} + +pub mod btree_map { + map!(BTreeMap); +} + +#[cfg(test)] +mod test { + use alloc::string::ToString; + use core::borrow::Borrow; + use core::fmt::Debug; + use core::u64; + + use ::bytes::{Bytes, BytesMut}; + use proptest::{prelude::*, test_runner::TestCaseResult}; + + use crate::encoding::*; + + pub fn check_type<T, B>( + value: T, + tag: u32, + wire_type: WireType, + encode: fn(u32, &B, &mut BytesMut), + merge: fn(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>, + encoded_len: fn(u32, &B) -> usize, + ) -> TestCaseResult + where + T: Debug + Default + PartialEq + Borrow<B>, + B: ?Sized, + { + prop_assume!(MIN_TAG <= tag && tag <= MAX_TAG); + + let expected_len = encoded_len(tag, value.borrow()); + + let mut buf = BytesMut::with_capacity(expected_len); + encode(tag, value.borrow(), &mut buf); + + let mut buf = buf.freeze(); + + prop_assert_eq!( + buf.remaining(), + expected_len, + "encoded_len wrong; expected: {}, actual: {}", + expected_len, + buf.remaining() + ); + + if !buf.has_remaining() { + // Short circuit for empty packed values. + return Ok(()); + } + + let (decoded_tag, decoded_wire_type) = + decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?; + prop_assert_eq!( + tag, + decoded_tag, + "decoded tag does not match; expected: {}, actual: {}", + tag, + decoded_tag + ); + + prop_assert_eq!( + wire_type, + decoded_wire_type, + "decoded wire type does not match; expected: {:?}, actual: {:?}", + wire_type, + decoded_wire_type, + ); + + match wire_type { + WireType::SixtyFourBit if buf.remaining() != 8 => Err(TestCaseError::fail(format!( + "64bit wire type illegal remaining: {}, tag: {}", + buf.remaining(), + tag + ))), + WireType::ThirtyTwoBit if buf.remaining() != 4 => Err(TestCaseError::fail(format!( + "32bit wire type illegal remaining: {}, tag: {}", + buf.remaining(), + tag + ))), + _ => Ok(()), + }?; + + let mut roundtrip_value = T::default(); + merge( + wire_type, + &mut roundtrip_value, + &mut buf, + DecodeContext::default(), + ) + .map_err(|error| TestCaseError::fail(error.to_string()))?; + + prop_assert!( + !buf.has_remaining(), + "expected buffer to be empty, remaining: {}", + buf.remaining() + ); + + prop_assert_eq!(value, roundtrip_value); + + Ok(()) + } + + pub fn check_collection_type<T, B, E, M, L>( + value: T, + tag: u32, + wire_type: WireType, + encode: E, + mut merge: M, + encoded_len: L, + ) -> TestCaseResult + where + T: Debug + Default + PartialEq + Borrow<B>, + B: ?Sized, + E: FnOnce(u32, &B, &mut BytesMut), + M: FnMut(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>, + L: FnOnce(u32, &B) -> usize, + { + prop_assume!(MIN_TAG <= tag && tag <= MAX_TAG); + + let expected_len = encoded_len(tag, value.borrow()); + + let mut buf = BytesMut::with_capacity(expected_len); + encode(tag, value.borrow(), &mut buf); + + let mut buf = buf.freeze(); + + prop_assert_eq!( + buf.remaining(), + expected_len, + "encoded_len wrong; expected: {}, actual: {}", + expected_len, + buf.remaining() + ); + + let mut roundtrip_value = Default::default(); + while buf.has_remaining() { + let (decoded_tag, decoded_wire_type) = + decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?; + + prop_assert_eq!( + tag, + decoded_tag, + "decoded tag does not match; expected: {}, actual: {}", + tag, + decoded_tag + ); + + prop_assert_eq!( + wire_type, + decoded_wire_type, + "decoded wire type does not match; expected: {:?}, actual: {:?}", + wire_type, + decoded_wire_type + ); + + merge( + wire_type, + &mut roundtrip_value, + &mut buf, + DecodeContext::default(), + ) + .map_err(|error| TestCaseError::fail(error.to_string()))?; + } + + prop_assert_eq!(value, roundtrip_value); + + Ok(()) + } + + #[test] + fn string_merge_invalid_utf8() { + let mut s = String::new(); + let buf = b"\x02\x80\x80"; + + let r = string::merge( + WireType::LengthDelimited, + &mut s, + &mut &buf[..], + DecodeContext::default(), + ); + r.expect_err("must be an error"); + assert!(s.is_empty()); + } + + #[test] + fn varint() { + fn check(value: u64, mut encoded: &[u8]) { + // TODO(rust-lang/rust-clippy#5494) + #![allow(clippy::clone_double_ref)] + + // Small buffer. + let mut buf = Vec::with_capacity(1); + encode_varint(value, &mut buf); + assert_eq!(buf, encoded); + + // Large buffer. + let mut buf = Vec::with_capacity(100); + encode_varint(value, &mut buf); + assert_eq!(buf, encoded); + + assert_eq!(encoded_len_varint(value), encoded.len()); + + let roundtrip_value = decode_varint(&mut encoded.clone()).expect("decoding failed"); + assert_eq!(value, roundtrip_value); + + let roundtrip_value = decode_varint_slow(&mut encoded).expect("slow decoding failed"); + assert_eq!(value, roundtrip_value); + } + + check(2u64.pow(0) - 1, &[0x00]); + check(2u64.pow(0), &[0x01]); + + check(2u64.pow(7) - 1, &[0x7F]); + check(2u64.pow(7), &[0x80, 0x01]); + check(300, &[0xAC, 0x02]); + + check(2u64.pow(14) - 1, &[0xFF, 0x7F]); + check(2u64.pow(14), &[0x80, 0x80, 0x01]); + + check(2u64.pow(21) - 1, &[0xFF, 0xFF, 0x7F]); + check(2u64.pow(21), &[0x80, 0x80, 0x80, 0x01]); + + check(2u64.pow(28) - 1, &[0xFF, 0xFF, 0xFF, 0x7F]); + check(2u64.pow(28), &[0x80, 0x80, 0x80, 0x80, 0x01]); + + check(2u64.pow(35) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0x7F]); + check(2u64.pow(35), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x01]); + + check(2u64.pow(42) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]); + check(2u64.pow(42), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]); + + check( + 2u64.pow(49) - 1, + &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F], + ); + check( + 2u64.pow(49), + &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01], + ); + + check( + 2u64.pow(56) - 1, + &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F], + ); + check( + 2u64.pow(56), + &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01], + ); + + check( + 2u64.pow(63) - 1, + &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F], + ); + check( + 2u64.pow(63), + &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01], + ); + + check( + u64::MAX, + &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01], + ); + } + + /// This big bowl o' macro soup generates an encoding property test for each combination of map + /// type, scalar map key, and value type. + /// TODO: these tests take a long time to compile, can this be improved? + #[cfg(feature = "std")] + macro_rules! map_tests { + (keys: $keys:tt, + vals: $vals:tt) => { + mod hash_map { + map_tests!(@private HashMap, hash_map, $keys, $vals); + } + mod btree_map { + map_tests!(@private BTreeMap, btree_map, $keys, $vals); + } + }; + + (@private $map_type:ident, + $mod_name:ident, + [$(($key_ty:ty, $key_proto:ident)),*], + $vals:tt) => { + $( + mod $key_proto { + use std::collections::$map_type; + + use proptest::prelude::*; + + use crate::encoding::*; + use crate::encoding::test::check_collection_type; + + map_tests!(@private $map_type, $mod_name, ($key_ty, $key_proto), $vals); + } + )* + }; + + (@private $map_type:ident, + $mod_name:ident, + ($key_ty:ty, $key_proto:ident), + [$(($val_ty:ty, $val_proto:ident)),*]) => { + $( + proptest! { + #[test] + fn $val_proto(values: $map_type<$key_ty, $val_ty>, tag in MIN_TAG..=MAX_TAG) { + check_collection_type(values, tag, WireType::LengthDelimited, + |tag, values, buf| { + $mod_name::encode($key_proto::encode, + $key_proto::encoded_len, + $val_proto::encode, + $val_proto::encoded_len, + tag, + values, + buf) + }, + |wire_type, values, buf, ctx| { + check_wire_type(WireType::LengthDelimited, wire_type)?; + $mod_name::merge($key_proto::merge, + $val_proto::merge, + values, + buf, + ctx) + }, + |tag, values| { + $mod_name::encoded_len($key_proto::encoded_len, + $val_proto::encoded_len, + tag, + values) + })?; + } + } + )* + }; + } + + #[cfg(feature = "std")] + map_tests!(keys: [ + (i32, int32), + (i64, int64), + (u32, uint32), + (u64, uint64), + (i32, sint32), + (i64, sint64), + (u32, fixed32), + (u64, fixed64), + (i32, sfixed32), + (i64, sfixed64), + (bool, bool), + (String, string) + ], + vals: [ + (f32, float), + (f64, double), + (i32, int32), + (i64, int64), + (u32, uint32), + (u64, uint64), + (i32, sint32), + (i64, sint64), + (u32, fixed32), + (u64, fixed64), + (i32, sfixed32), + (i64, sfixed64), + (bool, bool), + (String, string), + (Vec<u8>, bytes) + ]); +} diff --git a/third_party/rust/prost/src/error.rs b/third_party/rust/prost/src/error.rs new file mode 100644 index 0000000000..fc098299c8 --- /dev/null +++ b/third_party/rust/prost/src/error.rs @@ -0,0 +1,131 @@ +//! Protobuf encoding and decoding errors. + +use alloc::borrow::Cow; +use alloc::boxed::Box; +use alloc::vec::Vec; + +use core::fmt; + +/// A Protobuf message decoding error. +/// +/// `DecodeError` indicates that the input buffer does not caontain a valid +/// Protobuf message. The error details should be considered 'best effort': in +/// general it is not possible to exactly pinpoint why data is malformed. +#[derive(Clone, PartialEq, Eq)] +pub struct DecodeError { + inner: Box<Inner>, +} + +#[derive(Clone, PartialEq, Eq)] +struct Inner { + /// A 'best effort' root cause description. + description: Cow<'static, str>, + /// A stack of (message, field) name pairs, which identify the specific + /// message type and field where decoding failed. The stack contains an + /// entry per level of nesting. + stack: Vec<(&'static str, &'static str)>, +} + +impl DecodeError { + /// Creates a new `DecodeError` with a 'best effort' root cause description. + /// + /// Meant to be used only by `Message` implementations. + #[doc(hidden)] + #[cold] + pub fn new(description: impl Into<Cow<'static, str>>) -> DecodeError { + DecodeError { + inner: Box::new(Inner { + description: description.into(), + stack: Vec::new(), + }), + } + } + + /// Pushes a (message, field) name location pair on to the location stack. + /// + /// Meant to be used only by `Message` implementations. + #[doc(hidden)] + pub fn push(&mut self, message: &'static str, field: &'static str) { + self.inner.stack.push((message, field)); + } +} + +impl fmt::Debug for DecodeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DecodeError") + .field("description", &self.inner.description) + .field("stack", &self.inner.stack) + .finish() + } +} + +impl fmt::Display for DecodeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("failed to decode Protobuf message: ")?; + for &(message, field) in &self.inner.stack { + write!(f, "{}.{}: ", message, field)?; + } + f.write_str(&self.inner.description) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for DecodeError {} + +#[cfg(feature = "std")] +impl From<DecodeError> for std::io::Error { + fn from(error: DecodeError) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::InvalidData, error) + } +} + +/// A Protobuf message encoding error. +/// +/// `EncodeError` always indicates that a message failed to encode because the +/// provided buffer had insufficient capacity. Message encoding is otherwise +/// infallible. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct EncodeError { + required: usize, + remaining: usize, +} + +impl EncodeError { + /// Creates a new `EncodeError`. + pub(crate) fn new(required: usize, remaining: usize) -> EncodeError { + EncodeError { + required, + remaining, + } + } + + /// Returns the required buffer capacity to encode the message. + pub fn required_capacity(&self) -> usize { + self.required + } + + /// Returns the remaining length in the provided buffer at the time of encoding. + pub fn remaining(&self) -> usize { + self.remaining + } +} + +impl fmt::Display for EncodeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "failed to encode Protobuf messsage; insufficient buffer capacity (required: {}, remaining: {})", + self.required, self.remaining + ) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for EncodeError {} + +#[cfg(feature = "std")] +impl From<EncodeError> for std::io::Error { + fn from(error: EncodeError) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::InvalidInput, error) + } +} diff --git a/third_party/rust/prost/src/lib.rs b/third_party/rust/prost/src/lib.rs new file mode 100644 index 0000000000..9d4069e76e --- /dev/null +++ b/third_party/rust/prost/src/lib.rs @@ -0,0 +1,93 @@ +#![doc(html_root_url = "https://docs.rs/prost/0.8.0")] +#![cfg_attr(not(feature = "std"), no_std)] + +// Re-export the alloc crate for use within derived code. +#[doc(hidden)] +pub extern crate alloc; + +// Re-export the bytes crate for use within derived code. +#[doc(hidden)] +pub use bytes; + +mod error; +mod message; +mod types; + +#[doc(hidden)] +pub mod encoding; + +pub use crate::error::{DecodeError, EncodeError}; +pub use crate::message::Message; + +use bytes::{Buf, BufMut}; + +use crate::encoding::{decode_varint, encode_varint, encoded_len_varint}; + +// See `encoding::DecodeContext` for more info. +// 100 is the default recursion limit in the C++ implementation. +#[cfg(not(feature = "no-recursion-limit"))] +const RECURSION_LIMIT: u32 = 100; + +/// Encodes a length delimiter to the buffer. +/// +/// See [Message.encode_length_delimited] for more info. +/// +/// An error will be returned if the buffer does not have sufficient capacity to encode the +/// delimiter. +pub fn encode_length_delimiter<B>(length: usize, buf: &mut B) -> Result<(), EncodeError> +where + B: BufMut, +{ + let length = length as u64; + let required = encoded_len_varint(length); + let remaining = buf.remaining_mut(); + if required > remaining { + return Err(EncodeError::new(required, remaining)); + } + encode_varint(length, buf); + Ok(()) +} + +/// Returns the encoded length of a length delimiter. +/// +/// Applications may use this method to ensure sufficient buffer capacity before calling +/// `encode_length_delimiter`. The returned size will be between 1 and 10, inclusive. +pub fn length_delimiter_len(length: usize) -> usize { + encoded_len_varint(length as u64) +} + +/// Decodes a length delimiter from the buffer. +/// +/// This method allows the length delimiter to be decoded independently of the message, when the +/// message is encoded with [Message.encode_length_delimited]. +/// +/// An error may be returned in two cases: +/// +/// * If the supplied buffer contains fewer than 10 bytes, then an error indicates that more +/// input is required to decode the full delimiter. +/// * If the supplied buffer contains more than 10 bytes, then the buffer contains an invalid +/// delimiter, and typically the buffer should be considered corrupt. +pub fn decode_length_delimiter<B>(mut buf: B) -> Result<usize, DecodeError> +where + B: Buf, +{ + let length = decode_varint(&mut buf)?; + if length > usize::max_value() as u64 { + return Err(DecodeError::new( + "length delimiter exceeds maximum usize value", + )); + } + Ok(length as usize) +} + +// Re-export #[derive(Message, Enumeration, Oneof)]. +// Based on serde's equivalent re-export [1], but enabled by default. +// +// [1]: https://github.com/serde-rs/serde/blob/v1.0.89/serde/src/lib.rs#L245-L256 +#[cfg(feature = "prost-derive")] +#[allow(unused_imports)] +#[macro_use] +extern crate prost_derive; +#[cfg(feature = "prost-derive")] +#[doc(hidden)] +pub use prost_derive::*; diff --git a/third_party/rust/prost/src/message.rs b/third_party/rust/prost/src/message.rs new file mode 100644 index 0000000000..112c7b89f1 --- /dev/null +++ b/third_party/rust/prost/src/message.rs @@ -0,0 +1,200 @@ +use alloc::boxed::Box; +use core::fmt::Debug; +use core::usize; + +use bytes::{Buf, BufMut}; + +use crate::encoding::{ + decode_key, encode_varint, encoded_len_varint, message, DecodeContext, WireType, +}; +use crate::DecodeError; +use crate::EncodeError; + +/// A Protocol Buffers message. +pub trait Message: Debug + Send + Sync { + /// Encodes the message to a buffer. + /// + /// This method will panic if the buffer has insufficient capacity. + /// + /// Meant to be used only by `Message` implementations. + #[doc(hidden)] + fn encode_raw<B>(&self, buf: &mut B) + where + B: BufMut, + Self: Sized; + + /// Decodes a field from a buffer, and merges it into `self`. + /// + /// Meant to be used only by `Message` implementations. + #[doc(hidden)] + fn merge_field<B>( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + Self: Sized; + + /// Returns the encoded length of the message without a length delimiter. + fn encoded_len(&self) -> usize; + + /// Encodes the message to a buffer. + /// + /// An error will be returned if the buffer does not have sufficient capacity. + fn encode<B>(&self, buf: &mut B) -> Result<(), EncodeError> + where + B: BufMut, + Self: Sized, + { + let required = self.encoded_len(); + let remaining = buf.remaining_mut(); + if required > buf.remaining_mut() { + return Err(EncodeError::new(required, remaining)); + } + + self.encode_raw(buf); + Ok(()) + } + + #[cfg(feature = "std")] + /// Encodes the message to a newly allocated buffer. + fn encode_to_vec(&self) -> Vec<u8> + where + Self: Sized, + { + let mut buf = Vec::with_capacity(self.encoded_len()); + + self.encode_raw(&mut buf); + buf + } + + /// Encodes the message with a length-delimiter to a buffer. + /// + /// An error will be returned if the buffer does not have sufficient capacity. + fn encode_length_delimited<B>(&self, buf: &mut B) -> Result<(), EncodeError> + where + B: BufMut, + Self: Sized, + { + let len = self.encoded_len(); + let required = len + encoded_len_varint(len as u64); + let remaining = buf.remaining_mut(); + if required > remaining { + return Err(EncodeError::new(required, remaining)); + } + encode_varint(len as u64, buf); + self.encode_raw(buf); + Ok(()) + } + + #[cfg(feature = "std")] + /// Encodes the message with a length-delimiter to a newly allocated buffer. + fn encode_length_delimited_to_vec(&self) -> Vec<u8> + where + Self: Sized, + { + let len = self.encoded_len(); + let mut buf = Vec::with_capacity(len + encoded_len_varint(len as u64)); + + encode_varint(len as u64, &mut buf); + self.encode_raw(&mut buf); + buf + } + + /// Decodes an instance of the message from a buffer. + /// + /// The entire buffer will be consumed. + fn decode<B>(mut buf: B) -> Result<Self, DecodeError> + where + B: Buf, + Self: Default, + { + let mut message = Self::default(); + Self::merge(&mut message, &mut buf).map(|_| message) + } + + /// Decodes a length-delimited instance of the message from the buffer. + fn decode_length_delimited<B>(buf: B) -> Result<Self, DecodeError> + where + B: Buf, + Self: Default, + { + let mut message = Self::default(); + message.merge_length_delimited(buf)?; + Ok(message) + } + + /// Decodes an instance of the message from a buffer, and merges it into `self`. + /// + /// The entire buffer will be consumed. + fn merge<B>(&mut self, mut buf: B) -> Result<(), DecodeError> + where + B: Buf, + Self: Sized, + { + let ctx = DecodeContext::default(); + while buf.has_remaining() { + let (tag, wire_type) = decode_key(&mut buf)?; + self.merge_field(tag, wire_type, &mut buf, ctx.clone())?; + } + Ok(()) + } + + /// Decodes a length-delimited instance of the message from buffer, and + /// merges it into `self`. + fn merge_length_delimited<B>(&mut self, mut buf: B) -> Result<(), DecodeError> + where + B: Buf, + Self: Sized, + { + message::merge( + WireType::LengthDelimited, + self, + &mut buf, + DecodeContext::default(), + ) + } + + /// Clears the message, resetting all fields to their default. + fn clear(&mut self); +} + +impl<M> Message for Box<M> +where + M: Message, +{ + fn encode_raw<B>(&self, buf: &mut B) + where + B: BufMut, + { + (**self).encode_raw(buf) + } + fn merge_field<B>( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + (**self).merge_field(tag, wire_type, buf, ctx) + } + fn encoded_len(&self) -> usize { + (**self).encoded_len() + } + fn clear(&mut self) { + (**self).clear() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const _MESSAGE_IS_OBJECT_SAFE: Option<&dyn Message> = None; +} diff --git a/third_party/rust/prost/src/types.rs b/third_party/rust/prost/src/types.rs new file mode 100644 index 0000000000..864a2adda1 --- /dev/null +++ b/third_party/rust/prost/src/types.rs @@ -0,0 +1,424 @@ +//! Protocol Buffers well-known wrapper types. +//! +//! This module provides implementations of `Message` for Rust standard library types which +//! correspond to a Protobuf well-known wrapper type. The remaining well-known types are defined in +//! the `prost-types` crate in order to avoid a cyclic dependency between `prost` and +//! `prost-build`. + +use alloc::string::String; +use alloc::vec::Vec; + +use ::bytes::{Buf, BufMut, Bytes}; + +use crate::{ + encoding::{ + bool, bytes, double, float, int32, int64, skip_field, string, uint32, uint64, + DecodeContext, WireType, + }, + DecodeError, Message, +}; + +/// `google.protobuf.BoolValue` +impl Message for bool { + fn encode_raw<B>(&self, buf: &mut B) + where + B: BufMut, + { + if *self { + bool::encode(1, self, buf) + } + } + fn merge_field<B>( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + if tag == 1 { + bool::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, tag, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if *self { + 2 + } else { + 0 + } + } + fn clear(&mut self) { + *self = false; + } +} + +/// `google.protobuf.UInt32Value` +impl Message for u32 { + fn encode_raw<B>(&self, buf: &mut B) + where + B: BufMut, + { + if *self != 0 { + uint32::encode(1, self, buf) + } + } + fn merge_field<B>( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + if tag == 1 { + uint32::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, tag, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if *self != 0 { + uint32::encoded_len(1, self) + } else { + 0 + } + } + fn clear(&mut self) { + *self = 0; + } +} + +/// `google.protobuf.UInt64Value` +impl Message for u64 { + fn encode_raw<B>(&self, buf: &mut B) + where + B: BufMut, + { + if *self != 0 { + uint64::encode(1, self, buf) + } + } + fn merge_field<B>( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + if tag == 1 { + uint64::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, tag, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if *self != 0 { + uint64::encoded_len(1, self) + } else { + 0 + } + } + fn clear(&mut self) { + *self = 0; + } +} + +/// `google.protobuf.Int32Value` +impl Message for i32 { + fn encode_raw<B>(&self, buf: &mut B) + where + B: BufMut, + { + if *self != 0 { + int32::encode(1, self, buf) + } + } + fn merge_field<B>( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + if tag == 1 { + int32::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, tag, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if *self != 0 { + int32::encoded_len(1, self) + } else { + 0 + } + } + fn clear(&mut self) { + *self = 0; + } +} + +/// `google.protobuf.Int64Value` +impl Message for i64 { + fn encode_raw<B>(&self, buf: &mut B) + where + B: BufMut, + { + if *self != 0 { + int64::encode(1, self, buf) + } + } + fn merge_field<B>( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + if tag == 1 { + int64::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, tag, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if *self != 0 { + int64::encoded_len(1, self) + } else { + 0 + } + } + fn clear(&mut self) { + *self = 0; + } +} + +/// `google.protobuf.FloatValue` +impl Message for f32 { + fn encode_raw<B>(&self, buf: &mut B) + where + B: BufMut, + { + if *self != 0.0 { + float::encode(1, self, buf) + } + } + fn merge_field<B>( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + if tag == 1 { + float::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, tag, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if *self != 0.0 { + float::encoded_len(1, self) + } else { + 0 + } + } + fn clear(&mut self) { + *self = 0.0; + } +} + +/// `google.protobuf.DoubleValue` +impl Message for f64 { + fn encode_raw<B>(&self, buf: &mut B) + where + B: BufMut, + { + if *self != 0.0 { + double::encode(1, self, buf) + } + } + fn merge_field<B>( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + if tag == 1 { + double::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, tag, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if *self != 0.0 { + double::encoded_len(1, self) + } else { + 0 + } + } + fn clear(&mut self) { + *self = 0.0; + } +} + +/// `google.protobuf.StringValue` +impl Message for String { + fn encode_raw<B>(&self, buf: &mut B) + where + B: BufMut, + { + if !self.is_empty() { + string::encode(1, self, buf) + } + } + fn merge_field<B>( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + if tag == 1 { + string::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, tag, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if !self.is_empty() { + string::encoded_len(1, self) + } else { + 0 + } + } + fn clear(&mut self) { + self.clear(); + } +} + +/// `google.protobuf.BytesValue` +impl Message for Vec<u8> { + fn encode_raw<B>(&self, buf: &mut B) + where + B: BufMut, + { + if !self.is_empty() { + bytes::encode(1, self, buf) + } + } + fn merge_field<B>( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + if tag == 1 { + bytes::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, tag, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if !self.is_empty() { + bytes::encoded_len(1, self) + } else { + 0 + } + } + fn clear(&mut self) { + self.clear(); + } +} + +/// `google.protobuf.BytesValue` +impl Message for Bytes { + fn encode_raw<B>(&self, buf: &mut B) + where + B: BufMut, + { + if !self.is_empty() { + bytes::encode(1, self, buf) + } + } + fn merge_field<B>( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + if tag == 1 { + bytes::merge(wire_type, self, buf, ctx) + } else { + skip_field(wire_type, tag, buf, ctx) + } + } + fn encoded_len(&self) -> usize { + if !self.is_empty() { + bytes::encoded_len(1, self) + } else { + 0 + } + } + fn clear(&mut self) { + self.clear(); + } +} + +/// `google.protobuf.Empty` +impl Message for () { + fn encode_raw<B>(&self, _buf: &mut B) + where + B: BufMut, + { + } + fn merge_field<B>( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut B, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + B: Buf, + { + skip_field(wire_type, tag, buf, ctx) + } + fn encoded_len(&self) -> usize { + 0 + } + fn clear(&mut self) {} +} |