diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
commit | 26a029d407be480d791972afb5975cf62c9360a6 (patch) | |
tree | f435a8308119effd964b339f76abb83a57c29483 /third_party/rust/prost-derive/src | |
parent | Initial commit. (diff) | |
download | firefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz firefox-26a029d407be480d791972afb5975cf62c9360a6.zip |
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/prost-derive/src')
-rw-r--r-- | third_party/rust/prost-derive/src/field/group.rs | 134 | ||||
-rw-r--r-- | third_party/rust/prost-derive/src/field/map.rs | 407 | ||||
-rw-r--r-- | third_party/rust/prost-derive/src/field/message.rs | 134 | ||||
-rw-r--r-- | third_party/rust/prost-derive/src/field/mod.rs | 359 | ||||
-rw-r--r-- | third_party/rust/prost-derive/src/field/oneof.rs | 92 | ||||
-rw-r--r-- | third_party/rust/prost-derive/src/field/scalar.rs | 828 | ||||
-rw-r--r-- | third_party/rust/prost-derive/src/lib.rs | 530 |
7 files changed, 2484 insertions, 0 deletions
diff --git a/third_party/rust/prost-derive/src/field/group.rs b/third_party/rust/prost-derive/src/field/group.rs new file mode 100644 index 0000000000..076b577d73 --- /dev/null +++ b/third_party/rust/prost-derive/src/field/group.rs @@ -0,0 +1,134 @@ +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::Meta; + +use crate::field::{set_bool, set_option, tag_attr, word_attr, Label}; + +#[derive(Clone)] +pub struct Field { + pub label: Label, + pub tag: u32, +} + +impl Field { + pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> { + let mut group = false; + let mut label = None; + let mut tag = None; + let mut boxed = false; + + let mut unknown_attrs = Vec::new(); + + for attr in attrs { + if word_attr("group", attr) { + set_bool(&mut group, "duplicate group attributes")?; + } else if word_attr("boxed", attr) { + set_bool(&mut boxed, "duplicate boxed attributes")?; + } else if let Some(t) = tag_attr(attr)? { + set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(l) = Label::from_attr(attr) { + set_option(&mut label, l, "duplicate label attributes")?; + } else { + unknown_attrs.push(attr); + } + } + + if !group { + return Ok(None); + } + + match unknown_attrs.len() { + 0 => (), + 1 => bail!("unknown attribute for group field: {:?}", unknown_attrs[0]), + _ => bail!("unknown attributes for group field: {:?}", unknown_attrs), + } + + let tag = match tag.or(inferred_tag) { + Some(tag) => tag, + None => bail!("group field is missing a tag attribute"), + }; + + Ok(Some(Field { + label: label.unwrap_or(Label::Optional), + tag, + })) + } + + pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> { + if let Some(mut field) = Field::new(attrs, None)? { + if let Some(attr) = attrs.iter().find(|attr| Label::from_attr(attr).is_some()) { + bail!( + "invalid attribute for oneof field: {}", + attr.path().into_token_stream() + ); + } + field.label = Label::Required; + Ok(Some(field)) + } else { + Ok(None) + } + } + + pub fn encode(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + match self.label { + Label::Optional => quote! { + if let Some(ref msg) = #ident { + ::prost::encoding::group::encode(#tag, msg, buf); + } + }, + Label::Required => quote! { + ::prost::encoding::group::encode(#tag, &#ident, buf); + }, + Label::Repeated => quote! { + for msg in &#ident { + ::prost::encoding::group::encode(#tag, msg, buf); + } + }, + } + } + + pub fn merge(&self, ident: TokenStream) -> TokenStream { + match self.label { + Label::Optional => quote! { + ::prost::encoding::group::merge( + tag, + wire_type, + #ident.get_or_insert_with(::core::default::Default::default), + buf, + ctx, + ) + }, + Label::Required => quote! { + ::prost::encoding::group::merge(tag, wire_type, #ident, buf, ctx) + }, + Label::Repeated => quote! { + ::prost::encoding::group::merge_repeated(tag, wire_type, #ident, buf, ctx) + }, + } + } + + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + match self.label { + Label::Optional => quote! { + #ident.as_ref().map_or(0, |msg| ::prost::encoding::group::encoded_len(#tag, msg)) + }, + Label::Required => quote! { + ::prost::encoding::group::encoded_len(#tag, &#ident) + }, + Label::Repeated => quote! { + ::prost::encoding::group::encoded_len_repeated(#tag, &#ident) + }, + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + match self.label { + Label::Optional => quote!(#ident = ::core::option::Option::None), + Label::Required => quote!(#ident.clear()), + Label::Repeated => quote!(#ident.clear()), + } + } +} diff --git a/third_party/rust/prost-derive/src/field/map.rs b/third_party/rust/prost-derive/src/field/map.rs new file mode 100644 index 0000000000..4855cc5c67 --- /dev/null +++ b/third_party/rust/prost-derive/src/field/map.rs @@ -0,0 +1,407 @@ +use anyhow::{bail, Error}; +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::{Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Token}; + +use crate::field::{scalar, set_option, tag_attr}; + +#[derive(Clone, Debug)] +pub enum MapTy { + HashMap, + BTreeMap, +} + +impl MapTy { + fn from_str(s: &str) -> Option<MapTy> { + match s { + "map" | "hash_map" => Some(MapTy::HashMap), + "btree_map" => Some(MapTy::BTreeMap), + _ => None, + } + } + + fn module(&self) -> Ident { + match *self { + MapTy::HashMap => Ident::new("hash_map", Span::call_site()), + MapTy::BTreeMap => Ident::new("btree_map", Span::call_site()), + } + } + + fn lib(&self) -> TokenStream { + match self { + MapTy::HashMap => quote! { std }, + MapTy::BTreeMap => quote! { prost::alloc }, + } + } +} + +fn fake_scalar(ty: scalar::Ty) -> scalar::Field { + let kind = scalar::Kind::Plain(scalar::DefaultValue::new(&ty)); + scalar::Field { + ty, + kind, + tag: 0, // Not used here + } +} + +#[derive(Clone)] +pub struct Field { + pub map_ty: MapTy, + pub key_ty: scalar::Ty, + pub value_ty: ValueTy, + pub tag: u32, +} + +impl Field { + pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> { + let mut types = None; + let mut tag = None; + + for attr in attrs { + if let Some(t) = tag_attr(attr)? { + set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(map_ty) = attr + .path() + .get_ident() + .and_then(|i| MapTy::from_str(&i.to_string())) + { + let (k, v): (String, String) = match attr { + Meta::NameValue(MetaNameValue { + value: + Expr::Lit(ExprLit { + lit: Lit::Str(lit), .. + }), + .. + }) => { + let items = lit.value(); + let mut items = items.split(',').map(ToString::to_string); + let k = items.next().unwrap(); + let v = match items.next() { + Some(k) => k, + None => bail!("invalid map attribute: must have key and value types"), + }; + if items.next().is_some() { + bail!("invalid map attribute: {:?}", attr); + } + (k, v) + } + Meta::List(meta_list) => { + let nested = meta_list + .parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)? + .into_iter() + .collect::<Vec<_>>(); + if nested.len() != 2 { + bail!("invalid map attribute: must contain key and value types"); + } + (nested[0].to_string(), nested[1].to_string()) + } + _ => return Ok(None), + }; + set_option( + &mut types, + (map_ty, key_ty_from_str(&k)?, ValueTy::from_str(&v)?), + "duplicate map type attribute", + )?; + } else { + return Ok(None); + } + } + + Ok(match (types, tag.or(inferred_tag)) { + (Some((map_ty, key_ty, value_ty)), Some(tag)) => Some(Field { + map_ty, + key_ty, + value_ty, + tag, + }), + _ => None, + }) + } + + pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> { + Field::new(attrs, None) + } + + /// Returns a statement which encodes the map field. + pub fn encode(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + let key_mod = self.key_ty.module(); + let ke = quote!(::prost::encoding::#key_mod::encode); + let kl = quote!(::prost::encoding::#key_mod::encoded_len); + let module = self.map_ty.module(); + match &self.value_ty { + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { + let default = quote!(#ty::default() as i32); + quote! { + ::prost::encoding::#module::encode_with_default( + #ke, + #kl, + ::prost::encoding::int32::encode, + ::prost::encoding::int32::encoded_len, + &(#default), + #tag, + &#ident, + buf, + ); + } + } + ValueTy::Scalar(value_ty) => { + let val_mod = value_ty.module(); + let ve = quote!(::prost::encoding::#val_mod::encode); + let vl = quote!(::prost::encoding::#val_mod::encoded_len); + quote! { + ::prost::encoding::#module::encode( + #ke, + #kl, + #ve, + #vl, + #tag, + &#ident, + buf, + ); + } + } + ValueTy::Message => quote! { + ::prost::encoding::#module::encode( + #ke, + #kl, + ::prost::encoding::message::encode, + ::prost::encoding::message::encoded_len, + #tag, + &#ident, + buf, + ); + }, + } + } + + /// Returns an expression which evaluates to the result of merging a decoded key value pair + /// into the map. + pub fn merge(&self, ident: TokenStream) -> TokenStream { + let key_mod = self.key_ty.module(); + let km = quote!(::prost::encoding::#key_mod::merge); + let module = self.map_ty.module(); + match &self.value_ty { + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { + let default = quote!(#ty::default() as i32); + quote! { + ::prost::encoding::#module::merge_with_default( + #km, + ::prost::encoding::int32::merge, + #default, + &mut #ident, + buf, + ctx, + ) + } + } + ValueTy::Scalar(value_ty) => { + let val_mod = value_ty.module(); + let vm = quote!(::prost::encoding::#val_mod::merge); + quote!(::prost::encoding::#module::merge(#km, #vm, &mut #ident, buf, ctx)) + } + ValueTy::Message => quote! { + ::prost::encoding::#module::merge( + #km, + ::prost::encoding::message::merge, + &mut #ident, + buf, + ctx, + ) + }, + } + } + + /// Returns an expression which evaluates to the encoded length of the map. + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + let key_mod = self.key_ty.module(); + let kl = quote!(::prost::encoding::#key_mod::encoded_len); + let module = self.map_ty.module(); + match &self.value_ty { + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { + let default = quote!(#ty::default() as i32); + quote! { + ::prost::encoding::#module::encoded_len_with_default( + #kl, + ::prost::encoding::int32::encoded_len, + &(#default), + #tag, + &#ident, + ) + } + } + ValueTy::Scalar(value_ty) => { + let val_mod = value_ty.module(); + let vl = quote!(::prost::encoding::#val_mod::encoded_len); + quote!(::prost::encoding::#module::encoded_len(#kl, #vl, #tag, &#ident)) + } + ValueTy::Message => quote! { + ::prost::encoding::#module::encoded_len( + #kl, + ::prost::encoding::message::encoded_len, + #tag, + &#ident, + ) + }, + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + quote!(#ident.clear()) + } + + /// Returns methods to embed in the message. + pub fn methods(&self, ident: &TokenStream) -> Option<TokenStream> { + if let ValueTy::Scalar(scalar::Ty::Enumeration(ty)) = &self.value_ty { + let key_ty = self.key_ty.rust_type(); + let key_ref_ty = self.key_ty.rust_ref_type(); + + let get = Ident::new(&format!("get_{}", ident), Span::call_site()); + let insert = Ident::new(&format!("insert_{}", ident), Span::call_site()); + let take_ref = if self.key_ty.is_numeric() { + quote!(&) + } else { + quote!() + }; + + let get_doc = format!( + "Returns the enum value for the corresponding key in `{}`, \ + or `None` if the entry does not exist or it is not a valid enum value.", + ident, + ); + let insert_doc = format!("Inserts a key value pair into `{}`.", ident); + Some(quote! { + #[doc=#get_doc] + pub fn #get(&self, key: #key_ref_ty) -> ::core::option::Option<#ty> { + self.#ident.get(#take_ref key).cloned().and_then(|x| { + let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }) + } + #[doc=#insert_doc] + pub fn #insert(&mut self, key: #key_ty, value: #ty) -> ::core::option::Option<#ty> { + self.#ident.insert(key, value as i32).and_then(|x| { + let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }) + } + }) + } else { + None + } + } + + /// Returns a newtype wrapper around the map, implementing nicer Debug + /// + /// The Debug tries to convert any enumerations met into the variants if possible, instead of + /// outputting the raw numbers. + pub fn debug(&self, wrapper_name: TokenStream) -> TokenStream { + let type_name = match self.map_ty { + MapTy::HashMap => Ident::new("HashMap", Span::call_site()), + MapTy::BTreeMap => Ident::new("BTreeMap", Span::call_site()), + }; + + // A fake field for generating the debug wrapper + let key_wrapper = fake_scalar(self.key_ty.clone()).debug(quote!(KeyWrapper)); + let key = self.key_ty.rust_type(); + let value_wrapper = self.value_ty.debug(); + let libname = self.map_ty.lib(); + let fmt = quote! { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + #key_wrapper + #value_wrapper + let mut builder = f.debug_map(); + for (k, v) in self.0 { + builder.entry(&KeyWrapper(k), &ValueWrapper(v)); + } + builder.finish() + } + }; + match &self.value_ty { + ValueTy::Scalar(ty) => { + if let scalar::Ty::Bytes(_) = *ty { + return quote! { + struct #wrapper_name<'a>(&'a dyn ::core::fmt::Debug); + impl<'a> ::core::fmt::Debug for #wrapper_name<'a> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + self.0.fmt(f) + } + } + }; + } + + let value = ty.rust_type(); + quote! { + struct #wrapper_name<'a>(&'a ::#libname::collections::#type_name<#key, #value>); + impl<'a> ::core::fmt::Debug for #wrapper_name<'a> { + #fmt + } + } + } + ValueTy::Message => quote! { + struct #wrapper_name<'a, V: 'a>(&'a ::#libname::collections::#type_name<#key, V>); + impl<'a, V> ::core::fmt::Debug for #wrapper_name<'a, V> + where + V: ::core::fmt::Debug + 'a, + { + #fmt + } + }, + } + } +} + +fn key_ty_from_str(s: &str) -> Result<scalar::Ty, Error> { + let ty = scalar::Ty::from_str(s)?; + match ty { + scalar::Ty::Int32 + | scalar::Ty::Int64 + | scalar::Ty::Uint32 + | scalar::Ty::Uint64 + | scalar::Ty::Sint32 + | scalar::Ty::Sint64 + | scalar::Ty::Fixed32 + | scalar::Ty::Fixed64 + | scalar::Ty::Sfixed32 + | scalar::Ty::Sfixed64 + | scalar::Ty::Bool + | scalar::Ty::String => Ok(ty), + _ => bail!("invalid map key type: {}", s), + } +} + +/// A map value type. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ValueTy { + Scalar(scalar::Ty), + Message, +} + +impl ValueTy { + fn from_str(s: &str) -> Result<ValueTy, Error> { + if let Ok(ty) = scalar::Ty::from_str(s) { + Ok(ValueTy::Scalar(ty)) + } else if s.trim() == "message" { + Ok(ValueTy::Message) + } else { + bail!("invalid map value type: {}", s); + } + } + + /// Returns a newtype wrapper around the ValueTy for nicer debug. + /// + /// If the contained value is enumeration, it tries to convert it to the variant. If not, it + /// just forwards the implementation. + fn debug(&self) -> TokenStream { + match self { + ValueTy::Scalar(ty) => fake_scalar(ty.clone()).debug(quote!(ValueWrapper)), + ValueTy::Message => quote!( + fn ValueWrapper<T>(v: T) -> T { + v + } + ), + } + } +} diff --git a/third_party/rust/prost-derive/src/field/message.rs b/third_party/rust/prost-derive/src/field/message.rs new file mode 100644 index 0000000000..3bcdddfb16 --- /dev/null +++ b/third_party/rust/prost-derive/src/field/message.rs @@ -0,0 +1,134 @@ +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::Meta; + +use crate::field::{set_bool, set_option, tag_attr, word_attr, Label}; + +#[derive(Clone)] +pub struct Field { + pub label: Label, + pub tag: u32, +} + +impl Field { + pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> { + let mut message = false; + let mut label = None; + let mut tag = None; + let mut boxed = false; + + let mut unknown_attrs = Vec::new(); + + for attr in attrs { + if word_attr("message", attr) { + set_bool(&mut message, "duplicate message attribute")?; + } else if word_attr("boxed", attr) { + set_bool(&mut boxed, "duplicate boxed attribute")?; + } else if let Some(t) = tag_attr(attr)? { + set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(l) = Label::from_attr(attr) { + set_option(&mut label, l, "duplicate label attributes")?; + } else { + unknown_attrs.push(attr); + } + } + + if !message { + return Ok(None); + } + + match unknown_attrs.len() { + 0 => (), + 1 => bail!( + "unknown attribute for message field: {:?}", + unknown_attrs[0] + ), + _ => bail!("unknown attributes for message field: {:?}", unknown_attrs), + } + + let tag = match tag.or(inferred_tag) { + Some(tag) => tag, + None => bail!("message field is missing a tag attribute"), + }; + + Ok(Some(Field { + label: label.unwrap_or(Label::Optional), + tag, + })) + } + + pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> { + if let Some(mut field) = Field::new(attrs, None)? { + if let Some(attr) = attrs.iter().find(|attr| Label::from_attr(attr).is_some()) { + bail!( + "invalid attribute for oneof field: {}", + attr.path().into_token_stream() + ); + } + field.label = Label::Required; + Ok(Some(field)) + } else { + Ok(None) + } + } + + pub fn encode(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + match self.label { + Label::Optional => quote! { + if let Some(ref msg) = #ident { + ::prost::encoding::message::encode(#tag, msg, buf); + } + }, + Label::Required => quote! { + ::prost::encoding::message::encode(#tag, &#ident, buf); + }, + Label::Repeated => quote! { + for msg in &#ident { + ::prost::encoding::message::encode(#tag, msg, buf); + } + }, + } + } + + pub fn merge(&self, ident: TokenStream) -> TokenStream { + match self.label { + Label::Optional => quote! { + ::prost::encoding::message::merge(wire_type, + #ident.get_or_insert_with(::core::default::Default::default), + buf, + ctx) + }, + Label::Required => quote! { + ::prost::encoding::message::merge(wire_type, #ident, buf, ctx) + }, + Label::Repeated => quote! { + ::prost::encoding::message::merge_repeated(wire_type, #ident, buf, ctx) + }, + } + } + + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + match self.label { + Label::Optional => quote! { + #ident.as_ref().map_or(0, |msg| ::prost::encoding::message::encoded_len(#tag, msg)) + }, + Label::Required => quote! { + ::prost::encoding::message::encoded_len(#tag, &#ident) + }, + Label::Repeated => quote! { + ::prost::encoding::message::encoded_len_repeated(#tag, &#ident) + }, + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + match self.label { + Label::Optional => quote!(#ident = ::core::option::Option::None), + Label::Required => quote!(#ident.clear()), + Label::Repeated => quote!(#ident.clear()), + } + } +} diff --git a/third_party/rust/prost-derive/src/field/mod.rs b/third_party/rust/prost-derive/src/field/mod.rs new file mode 100644 index 0000000000..4bec5617c2 --- /dev/null +++ b/third_party/rust/prost-derive/src/field/mod.rs @@ -0,0 +1,359 @@ +mod group; +mod map; +mod message; +mod oneof; +mod scalar; + +use std::fmt; +use std::slice; + +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::{Attribute, Expr, ExprLit, Lit, LitBool, LitInt, Meta, MetaNameValue, Token}; + +#[derive(Clone)] +pub enum Field { + /// A scalar field. + Scalar(scalar::Field), + /// A message field. + Message(message::Field), + /// A map field. + Map(map::Field), + /// A oneof field. + Oneof(oneof::Field), + /// A group field. + Group(group::Field), +} + +impl Field { + /// Creates a new `Field` from an iterator of field attributes. + /// + /// If the meta items are invalid, an error will be returned. + /// If the field should be ignored, `None` is returned. + pub fn new(attrs: Vec<Attribute>, inferred_tag: Option<u32>) -> Result<Option<Field>, Error> { + let attrs = prost_attrs(attrs)?; + + // TODO: check for ignore attribute. + + let field = if let Some(field) = scalar::Field::new(&attrs, inferred_tag)? { + Field::Scalar(field) + } else if let Some(field) = message::Field::new(&attrs, inferred_tag)? { + Field::Message(field) + } else if let Some(field) = map::Field::new(&attrs, inferred_tag)? { + Field::Map(field) + } else if let Some(field) = oneof::Field::new(&attrs)? { + Field::Oneof(field) + } else if let Some(field) = group::Field::new(&attrs, inferred_tag)? { + Field::Group(field) + } else { + bail!("no type attribute"); + }; + + Ok(Some(field)) + } + + /// Creates a new oneof `Field` from an iterator of field attributes. + /// + /// If the meta items are invalid, an error will be returned. + /// If the field should be ignored, `None` is returned. + pub fn new_oneof(attrs: Vec<Attribute>) -> Result<Option<Field>, Error> { + let attrs = prost_attrs(attrs)?; + + // TODO: check for ignore attribute. + + let field = if let Some(field) = scalar::Field::new_oneof(&attrs)? { + Field::Scalar(field) + } else if let Some(field) = message::Field::new_oneof(&attrs)? { + Field::Message(field) + } else if let Some(field) = map::Field::new_oneof(&attrs)? { + Field::Map(field) + } else if let Some(field) = group::Field::new_oneof(&attrs)? { + Field::Group(field) + } else { + bail!("no type attribute for oneof field"); + }; + + Ok(Some(field)) + } + + pub fn tags(&self) -> Vec<u32> { + match *self { + Field::Scalar(ref scalar) => vec![scalar.tag], + Field::Message(ref message) => vec![message.tag], + Field::Map(ref map) => vec![map.tag], + Field::Oneof(ref oneof) => oneof.tags.clone(), + Field::Group(ref group) => vec![group.tag], + } + } + + /// Returns a statement which encodes the field. + pub fn encode(&self, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.encode(ident), + Field::Message(ref message) => message.encode(ident), + Field::Map(ref map) => map.encode(ident), + Field::Oneof(ref oneof) => oneof.encode(ident), + Field::Group(ref group) => group.encode(ident), + } + } + + /// Returns an expression which evaluates to the result of merging a decoded + /// value into the field. + pub fn merge(&self, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.merge(ident), + Field::Message(ref message) => message.merge(ident), + Field::Map(ref map) => map.merge(ident), + Field::Oneof(ref oneof) => oneof.merge(ident), + Field::Group(ref group) => group.merge(ident), + } + } + + /// Returns an expression which evaluates to the encoded length of the field. + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.encoded_len(ident), + Field::Map(ref map) => map.encoded_len(ident), + Field::Message(ref msg) => msg.encoded_len(ident), + Field::Oneof(ref oneof) => oneof.encoded_len(ident), + Field::Group(ref group) => group.encoded_len(ident), + } + } + + /// Returns a statement which clears the field. + pub fn clear(&self, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.clear(ident), + Field::Message(ref message) => message.clear(ident), + Field::Map(ref map) => map.clear(ident), + Field::Oneof(ref oneof) => oneof.clear(ident), + Field::Group(ref group) => group.clear(ident), + } + } + + pub fn default(&self) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.default(), + _ => quote!(::core::default::Default::default()), + } + } + + /// Produces the fragment implementing debug for the given field. + pub fn debug(&self, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => { + let wrapper = scalar.debug(quote!(ScalarWrapper)); + quote! { + { + #wrapper + ScalarWrapper(&#ident) + } + } + } + Field::Map(ref map) => { + let wrapper = map.debug(quote!(MapWrapper)); + quote! { + { + #wrapper + MapWrapper(&#ident) + } + } + } + _ => quote!(&#ident), + } + } + + pub fn methods(&self, ident: &TokenStream) -> Option<TokenStream> { + match *self { + Field::Scalar(ref scalar) => scalar.methods(ident), + Field::Map(ref map) => map.methods(ident), + _ => None, + } + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum Label { + /// An optional field. + Optional, + /// A required field. + Required, + /// A repeated field. + Repeated, +} + +impl Label { + fn as_str(self) -> &'static str { + match self { + Label::Optional => "optional", + Label::Required => "required", + Label::Repeated => "repeated", + } + } + + fn variants() -> slice::Iter<'static, Label> { + const VARIANTS: &[Label] = &[Label::Optional, Label::Required, Label::Repeated]; + VARIANTS.iter() + } + + /// Parses a string into a field label. + /// If the string doesn't match a field label, `None` is returned. + fn from_attr(attr: &Meta) -> Option<Label> { + if let Meta::Path(ref path) = *attr { + for &label in Label::variants() { + if path.is_ident(label.as_str()) { + return Some(label); + } + } + } + None + } +} + +impl fmt::Debug for Label { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl fmt::Display for Label { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +/// Get the items belonging to the 'prost' list attribute, e.g. `#[prost(foo, bar="baz")]`. +fn prost_attrs(attrs: Vec<Attribute>) -> Result<Vec<Meta>, Error> { + let mut result = Vec::new(); + for attr in attrs.iter() { + if let Meta::List(meta_list) = &attr.meta { + if meta_list.path.is_ident("prost") { + result.extend( + meta_list + .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)? + .into_iter(), + ) + } + } + } + Ok(result) +} + +pub fn set_option<T>(option: &mut Option<T>, value: T, message: &str) -> Result<(), Error> +where + T: fmt::Debug, +{ + if let Some(ref existing) = *option { + bail!("{}: {:?} and {:?}", message, existing, value); + } + *option = Some(value); + Ok(()) +} + +pub fn set_bool(b: &mut bool, message: &str) -> Result<(), Error> { + if *b { + bail!("{}", message); + } else { + *b = true; + Ok(()) + } +} + +/// Unpacks an attribute into a (key, boolean) pair, returning the boolean value. +/// If the key doesn't match the attribute, `None` is returned. +fn bool_attr(key: &str, attr: &Meta) -> Result<Option<bool>, Error> { + if !attr.path().is_ident(key) { + return Ok(None); + } + match *attr { + Meta::Path(..) => Ok(Some(true)), + Meta::List(ref meta_list) => { + return Ok(Some(meta_list.parse_args::<LitBool>()?.value())); + } + Meta::NameValue(MetaNameValue { + value: + Expr::Lit(ExprLit { + lit: Lit::Str(ref lit), + .. + }), + .. + }) => lit + .value() + .parse::<bool>() + .map_err(Error::from) + .map(Option::Some), + Meta::NameValue(MetaNameValue { + value: + Expr::Lit(ExprLit { + lit: Lit::Bool(LitBool { value, .. }), + .. + }), + .. + }) => Ok(Some(value)), + _ => bail!("invalid {} attribute", key), + } +} + +/// Checks if an attribute matches a word. +fn word_attr(key: &str, attr: &Meta) -> bool { + if let Meta::Path(ref path) = *attr { + path.is_ident(key) + } else { + false + } +} + +pub(super) fn tag_attr(attr: &Meta) -> Result<Option<u32>, Error> { + if !attr.path().is_ident("tag") { + return Ok(None); + } + match *attr { + Meta::List(ref meta_list) => { + return Ok(Some(meta_list.parse_args::<LitInt>()?.base10_parse()?)); + } + Meta::NameValue(MetaNameValue { + value: Expr::Lit(ref expr), + .. + }) => match expr.lit { + Lit::Str(ref lit) => lit + .value() + .parse::<u32>() + .map_err(Error::from) + .map(Option::Some), + Lit::Int(ref lit) => Ok(Some(lit.base10_parse()?)), + _ => bail!("invalid tag attribute: {:?}", attr), + }, + _ => bail!("invalid tag attribute: {:?}", attr), + } +} + +fn tags_attr(attr: &Meta) -> Result<Option<Vec<u32>>, Error> { + if !attr.path().is_ident("tags") { + return Ok(None); + } + match *attr { + Meta::List(ref meta_list) => Ok(Some( + meta_list + .parse_args_with(Punctuated::<LitInt, Token![,]>::parse_terminated)? + .iter() + .map(LitInt::base10_parse) + .collect::<Result<Vec<_>, _>>()?, + )), + Meta::NameValue(MetaNameValue { + value: + Expr::Lit(ExprLit { + lit: Lit::Str(ref lit), + .. + }), + .. + }) => lit + .value() + .split(',') + .map(|s| s.trim().parse::<u32>().map_err(Error::from)) + .collect::<Result<Vec<u32>, _>>() + .map(Some), + _ => bail!("invalid tag attribute: {:?}", attr), + } +} diff --git a/third_party/rust/prost-derive/src/field/oneof.rs b/third_party/rust/prost-derive/src/field/oneof.rs new file mode 100644 index 0000000000..78c77eeb13 --- /dev/null +++ b/third_party/rust/prost-derive/src/field/oneof.rs @@ -0,0 +1,92 @@ +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{parse_str, Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Path}; + +use crate::field::{set_option, tags_attr}; + +#[derive(Clone)] +pub struct Field { + pub ty: Path, + pub tags: Vec<u32>, +} + +impl Field { + pub fn new(attrs: &[Meta]) -> Result<Option<Field>, Error> { + let mut ty = None; + let mut tags = None; + let mut unknown_attrs = Vec::new(); + + for attr in attrs { + if attr.path().is_ident("oneof") { + let t = match *attr { + Meta::NameValue(MetaNameValue { + value: + Expr::Lit(ExprLit { + lit: Lit::Str(ref lit), + .. + }), + .. + }) => parse_str::<Path>(&lit.value())?, + Meta::List(ref list) => list.parse_args::<Ident>()?.into(), + _ => bail!("invalid oneof attribute: {:?}", attr), + }; + set_option(&mut ty, t, "duplicate oneof attribute")?; + } else if let Some(t) = tags_attr(attr)? { + set_option(&mut tags, t, "duplicate tags attributes")?; + } else { + unknown_attrs.push(attr); + } + } + + let ty = match ty { + Some(ty) => ty, + None => return Ok(None), + }; + + match unknown_attrs.len() { + 0 => (), + 1 => bail!( + "unknown attribute for message field: {:?}", + unknown_attrs[0] + ), + _ => bail!("unknown attributes for message field: {:?}", unknown_attrs), + } + + let tags = match tags { + Some(tags) => tags, + None => bail!("oneof field is missing a tags attribute"), + }; + + Ok(Some(Field { ty, tags })) + } + + /// Returns a statement which encodes the oneof field. + pub fn encode(&self, ident: TokenStream) -> TokenStream { + quote! { + if let Some(ref oneof) = #ident { + oneof.encode(buf) + } + } + } + + /// Returns an expression which evaluates to the result of decoding the oneof field. + pub fn merge(&self, ident: TokenStream) -> TokenStream { + let ty = &self.ty; + quote! { + #ty::merge(#ident, tag, wire_type, buf, ctx) + } + } + + /// Returns an expression which evaluates to the encoded length of the oneof field. + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + let ty = &self.ty; + quote! { + #ident.as_ref().map_or(0, #ty::encoded_len) + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + quote!(#ident = ::core::option::Option::None) + } +} diff --git a/third_party/rust/prost-derive/src/field/scalar.rs b/third_party/rust/prost-derive/src/field/scalar.rs new file mode 100644 index 0000000000..5a3dfb2ec3 --- /dev/null +++ b/third_party/rust/prost-derive/src/field/scalar.rs @@ -0,0 +1,828 @@ +use std::convert::TryFrom; +use std::fmt; + +use anyhow::{anyhow, bail, Error}; +use proc_macro2::{Span, TokenStream}; +use quote::{quote, ToTokens, TokenStreamExt}; +use syn::{parse_str, Expr, ExprLit, Ident, Index, Lit, LitByteStr, Meta, MetaNameValue, Path}; + +use crate::field::{bool_attr, set_option, tag_attr, Label}; + +/// A scalar protobuf field. +#[derive(Clone)] +pub struct Field { + pub ty: Ty, + pub kind: Kind, + pub tag: u32, +} + +impl Field { + pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> { + let mut ty = None; + let mut label = None; + let mut packed = None; + let mut default = None; + let mut tag = None; + + let mut unknown_attrs = Vec::new(); + + for attr in attrs { + if let Some(t) = Ty::from_attr(attr)? { + set_option(&mut ty, t, "duplicate type attributes")?; + } else if let Some(p) = bool_attr("packed", attr)? { + set_option(&mut packed, p, "duplicate packed attributes")?; + } else if let Some(t) = tag_attr(attr)? { + set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(l) = Label::from_attr(attr) { + set_option(&mut label, l, "duplicate label attributes")?; + } else if let Some(d) = DefaultValue::from_attr(attr)? { + set_option(&mut default, d, "duplicate default attributes")?; + } else { + unknown_attrs.push(attr); + } + } + + let ty = match ty { + Some(ty) => ty, + None => return Ok(None), + }; + + match unknown_attrs.len() { + 0 => (), + 1 => bail!("unknown attribute: {:?}", unknown_attrs[0]), + _ => bail!("unknown attributes: {:?}", unknown_attrs), + } + + let tag = match tag.or(inferred_tag) { + Some(tag) => tag, + None => bail!("missing tag attribute"), + }; + + let has_default = default.is_some(); + let default = default.map_or_else( + || Ok(DefaultValue::new(&ty)), + |lit| DefaultValue::from_lit(&ty, lit), + )?; + + let kind = match (label, packed, has_default) { + (None, Some(true), _) + | (Some(Label::Optional), Some(true), _) + | (Some(Label::Required), Some(true), _) => { + bail!("packed attribute may only be applied to repeated fields"); + } + (Some(Label::Repeated), Some(true), _) if !ty.is_numeric() => { + bail!("packed attribute may only be applied to numeric types"); + } + (Some(Label::Repeated), _, true) => { + bail!("repeated fields may not have a default value"); + } + + (None, _, _) => Kind::Plain(default), + (Some(Label::Optional), _, _) => Kind::Optional(default), + (Some(Label::Required), _, _) => Kind::Required(default), + (Some(Label::Repeated), packed, false) if packed.unwrap_or_else(|| ty.is_numeric()) => { + Kind::Packed + } + (Some(Label::Repeated), _, false) => Kind::Repeated, + }; + + Ok(Some(Field { ty, kind, tag })) + } + + pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> { + if let Some(mut field) = Field::new(attrs, None)? { + match field.kind { + Kind::Plain(default) => { + field.kind = Kind::Required(default); + Ok(Some(field)) + } + Kind::Optional(..) => bail!("invalid optional attribute on oneof field"), + Kind::Required(..) => bail!("invalid required attribute on oneof field"), + Kind::Packed | Kind::Repeated => bail!("invalid repeated attribute on oneof field"), + } + } else { + Ok(None) + } + } + + pub fn encode(&self, ident: TokenStream) -> TokenStream { + let module = self.ty.module(); + let encode_fn = match self.kind { + Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encode), + Kind::Repeated => quote!(encode_repeated), + Kind::Packed => quote!(encode_packed), + }; + let encode_fn = quote!(::prost::encoding::#module::#encode_fn); + let tag = self.tag; + + match self.kind { + Kind::Plain(ref default) => { + let default = default.typed(); + quote! { + if #ident != #default { + #encode_fn(#tag, &#ident, buf); + } + } + } + Kind::Optional(..) => quote! { + if let ::core::option::Option::Some(ref value) = #ident { + #encode_fn(#tag, value, buf); + } + }, + Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! { + #encode_fn(#tag, &#ident, buf); + }, + } + } + + /// Returns an expression which evaluates to the result of merging a decoded + /// scalar value into the field. + pub fn merge(&self, ident: TokenStream) -> TokenStream { + let module = self.ty.module(); + let merge_fn = match self.kind { + Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(merge), + Kind::Repeated | Kind::Packed => quote!(merge_repeated), + }; + let merge_fn = quote!(::prost::encoding::#module::#merge_fn); + + match self.kind { + Kind::Plain(..) | Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! { + #merge_fn(wire_type, #ident, buf, ctx) + }, + Kind::Optional(..) => quote! { + #merge_fn(wire_type, + #ident.get_or_insert_with(::core::default::Default::default), + buf, + ctx) + }, + } + } + + /// Returns an expression which evaluates to the encoded length of the field. + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + let module = self.ty.module(); + let encoded_len_fn = match self.kind { + Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encoded_len), + Kind::Repeated => quote!(encoded_len_repeated), + Kind::Packed => quote!(encoded_len_packed), + }; + let encoded_len_fn = quote!(::prost::encoding::#module::#encoded_len_fn); + let tag = self.tag; + + match self.kind { + Kind::Plain(ref default) => { + let default = default.typed(); + quote! { + if #ident != #default { + #encoded_len_fn(#tag, &#ident) + } else { + 0 + } + } + } + Kind::Optional(..) => quote! { + #ident.as_ref().map_or(0, |value| #encoded_len_fn(#tag, value)) + }, + Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! { + #encoded_len_fn(#tag, &#ident) + }, + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + match self.kind { + Kind::Plain(ref default) | Kind::Required(ref default) => { + let default = default.typed(); + match self.ty { + Ty::String | Ty::Bytes(..) => quote!(#ident.clear()), + _ => quote!(#ident = #default), + } + } + Kind::Optional(_) => quote!(#ident = ::core::option::Option::None), + Kind::Repeated | Kind::Packed => quote!(#ident.clear()), + } + } + + /// Returns an expression which evaluates to the default value of the field. + pub fn default(&self) -> TokenStream { + match self.kind { + Kind::Plain(ref value) | Kind::Required(ref value) => value.owned(), + Kind::Optional(_) => quote!(::core::option::Option::None), + Kind::Repeated | Kind::Packed => quote!(::prost::alloc::vec::Vec::new()), + } + } + + /// An inner debug wrapper, around the base type. + fn debug_inner(&self, wrap_name: TokenStream) -> TokenStream { + if let Ty::Enumeration(ref ty) = self.ty { + quote! { + struct #wrap_name<'a>(&'a i32); + impl<'a> ::core::fmt::Debug for #wrap_name<'a> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + let res: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(*self.0); + match res { + Err(_) => ::core::fmt::Debug::fmt(&self.0, f), + Ok(en) => ::core::fmt::Debug::fmt(&en, f), + } + } + } + } + } else { + quote! { + #[allow(non_snake_case)] + fn #wrap_name<T>(v: T) -> T { v } + } + } + } + + /// Returns a fragment for formatting the field `ident` in `Debug`. + pub fn debug(&self, wrapper_name: TokenStream) -> TokenStream { + let wrapper = self.debug_inner(quote!(Inner)); + let inner_ty = self.ty.rust_type(); + match self.kind { + Kind::Plain(_) | Kind::Required(_) => self.debug_inner(wrapper_name), + Kind::Optional(_) => quote! { + struct #wrapper_name<'a>(&'a ::core::option::Option<#inner_ty>); + impl<'a> ::core::fmt::Debug for #wrapper_name<'a> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + #wrapper + ::core::fmt::Debug::fmt(&self.0.as_ref().map(Inner), f) + } + } + }, + Kind::Repeated | Kind::Packed => { + quote! { + struct #wrapper_name<'a>(&'a ::prost::alloc::vec::Vec<#inner_ty>); + impl<'a> ::core::fmt::Debug for #wrapper_name<'a> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + let mut vec_builder = f.debug_list(); + for v in self.0 { + #wrapper + vec_builder.entry(&Inner(v)); + } + vec_builder.finish() + } + } + } + } + } + } + + /// Returns methods to embed in the message. + pub fn methods(&self, ident: &TokenStream) -> Option<TokenStream> { + let mut ident_str = ident.to_string(); + if ident_str.starts_with("r#") { + ident_str = ident_str[2..].to_owned(); + } + + // Prepend `get_` for getter methods of tuple structs. + let get = match syn::parse_str::<Index>(&ident_str) { + Ok(index) => { + let get = Ident::new(&format!("get_{}", index.index), Span::call_site()); + quote!(#get) + } + Err(_) => quote!(#ident), + }; + + if let Ty::Enumeration(ref ty) = self.ty { + let set = Ident::new(&format!("set_{}", ident_str), Span::call_site()); + let set_doc = format!("Sets `{}` to the provided enum value.", ident_str); + Some(match self.kind { + Kind::Plain(ref default) | Kind::Required(ref default) => { + let get_doc = format!( + "Returns the enum value of `{}`, \ + or the default if the field is set to an invalid enum value.", + ident_str, + ); + quote! { + #[doc=#get_doc] + pub fn #get(&self) -> #ty { + ::core::convert::TryFrom::try_from(self.#ident).unwrap_or(#default) + } + + #[doc=#set_doc] + pub fn #set(&mut self, value: #ty) { + self.#ident = value as i32; + } + } + } + Kind::Optional(ref default) => { + let get_doc = format!( + "Returns the enum value of `{}`, \ + or the default if the field is unset or set to an invalid enum value.", + ident_str, + ); + quote! { + #[doc=#get_doc] + pub fn #get(&self) -> #ty { + self.#ident.and_then(|x| { + let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }).unwrap_or(#default) + } + + #[doc=#set_doc] + pub fn #set(&mut self, value: #ty) { + self.#ident = ::core::option::Option::Some(value as i32); + } + } + } + Kind::Repeated | Kind::Packed => { + let iter_doc = format!( + "Returns an iterator which yields the valid enum values contained in `{}`.", + ident_str, + ); + let push = Ident::new(&format!("push_{}", ident_str), Span::call_site()); + let push_doc = format!("Appends the provided enum value to `{}`.", ident_str); + quote! { + #[doc=#iter_doc] + pub fn #get(&self) -> ::core::iter::FilterMap< + ::core::iter::Cloned<::core::slice::Iter<i32>>, + fn(i32) -> ::core::option::Option<#ty>, + > { + self.#ident.iter().cloned().filter_map(|x| { + let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x); + result.ok() + }) + } + #[doc=#push_doc] + pub fn #push(&mut self, value: #ty) { + self.#ident.push(value as i32); + } + } + } + }) + } else if let Kind::Optional(ref default) = self.kind { + let ty = self.ty.rust_ref_type(); + + let match_some = if self.ty.is_numeric() { + quote!(::core::option::Option::Some(val) => val,) + } else { + quote!(::core::option::Option::Some(ref val) => &val[..],) + }; + + let get_doc = format!( + "Returns the value of `{0}`, or the default value if `{0}` is unset.", + ident_str, + ); + + Some(quote! { + #[doc=#get_doc] + pub fn #get(&self) -> #ty { + match self.#ident { + #match_some + ::core::option::Option::None => #default, + } + } + }) + } else { + None + } + } +} + +/// A scalar protobuf field type. +#[derive(Clone, PartialEq, Eq)] +pub enum Ty { + Double, + Float, + Int32, + Int64, + Uint32, + Uint64, + Sint32, + Sint64, + Fixed32, + Fixed64, + Sfixed32, + Sfixed64, + Bool, + String, + Bytes(BytesTy), + Enumeration(Path), +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum BytesTy { + Vec, + Bytes, +} + +impl BytesTy { + fn try_from_str(s: &str) -> Result<Self, Error> { + match s { + "vec" => Ok(BytesTy::Vec), + "bytes" => Ok(BytesTy::Bytes), + _ => bail!("Invalid bytes type: {}", s), + } + } + + fn rust_type(&self) -> TokenStream { + match self { + BytesTy::Vec => quote! { ::prost::alloc::vec::Vec<u8> }, + BytesTy::Bytes => quote! { ::prost::bytes::Bytes }, + } + } +} + +impl Ty { + pub fn from_attr(attr: &Meta) -> Result<Option<Ty>, Error> { + let ty = match *attr { + Meta::Path(ref name) if name.is_ident("float") => Ty::Float, + Meta::Path(ref name) if name.is_ident("double") => Ty::Double, + Meta::Path(ref name) if name.is_ident("int32") => Ty::Int32, + Meta::Path(ref name) if name.is_ident("int64") => Ty::Int64, + Meta::Path(ref name) if name.is_ident("uint32") => Ty::Uint32, + Meta::Path(ref name) if name.is_ident("uint64") => Ty::Uint64, + Meta::Path(ref name) if name.is_ident("sint32") => Ty::Sint32, + Meta::Path(ref name) if name.is_ident("sint64") => Ty::Sint64, + Meta::Path(ref name) if name.is_ident("fixed32") => Ty::Fixed32, + Meta::Path(ref name) if name.is_ident("fixed64") => Ty::Fixed64, + Meta::Path(ref name) if name.is_ident("sfixed32") => Ty::Sfixed32, + Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64, + Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool, + Meta::Path(ref name) if name.is_ident("string") => Ty::String, + Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec), + Meta::NameValue(MetaNameValue { + ref path, + value: + Expr::Lit(ExprLit { + lit: Lit::Str(ref l), + .. + }), + .. + }) if path.is_ident("bytes") => Ty::Bytes(BytesTy::try_from_str(&l.value())?), + Meta::NameValue(MetaNameValue { + ref path, + value: + Expr::Lit(ExprLit { + lit: Lit::Str(ref l), + .. + }), + .. + }) if path.is_ident("enumeration") => Ty::Enumeration(parse_str::<Path>(&l.value())?), + Meta::List(ref meta_list) if meta_list.path.is_ident("enumeration") => { + Ty::Enumeration(meta_list.parse_args::<Path>()?) + } + _ => return Ok(None), + }; + Ok(Some(ty)) + } + + pub fn from_str(s: &str) -> Result<Ty, Error> { + let enumeration_len = "enumeration".len(); + let error = Err(anyhow!("invalid type: {}", s)); + let ty = match s.trim() { + "float" => Ty::Float, + "double" => Ty::Double, + "int32" => Ty::Int32, + "int64" => Ty::Int64, + "uint32" => Ty::Uint32, + "uint64" => Ty::Uint64, + "sint32" => Ty::Sint32, + "sint64" => Ty::Sint64, + "fixed32" => Ty::Fixed32, + "fixed64" => Ty::Fixed64, + "sfixed32" => Ty::Sfixed32, + "sfixed64" => Ty::Sfixed64, + "bool" => Ty::Bool, + "string" => Ty::String, + "bytes" => Ty::Bytes(BytesTy::Vec), + s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => { + let s = &s[enumeration_len..].trim(); + match s.chars().next() { + Some('<') | Some('(') => (), + _ => return error, + } + match s.chars().next_back() { + Some('>') | Some(')') => (), + _ => return error, + } + + Ty::Enumeration(parse_str::<Path>(s[1..s.len() - 1].trim())?) + } + _ => return error, + }; + Ok(ty) + } + + /// Returns the type as it appears in protobuf field declarations. + pub fn as_str(&self) -> &'static str { + match *self { + Ty::Double => "double", + Ty::Float => "float", + Ty::Int32 => "int32", + Ty::Int64 => "int64", + Ty::Uint32 => "uint32", + Ty::Uint64 => "uint64", + Ty::Sint32 => "sint32", + Ty::Sint64 => "sint64", + Ty::Fixed32 => "fixed32", + Ty::Fixed64 => "fixed64", + Ty::Sfixed32 => "sfixed32", + Ty::Sfixed64 => "sfixed64", + Ty::Bool => "bool", + Ty::String => "string", + Ty::Bytes(..) => "bytes", + Ty::Enumeration(..) => "enum", + } + } + + // TODO: rename to 'owned_type'. + pub fn rust_type(&self) -> TokenStream { + match self { + Ty::String => quote!(::prost::alloc::string::String), + Ty::Bytes(ty) => ty.rust_type(), + _ => self.rust_ref_type(), + } + } + + // TODO: rename to 'ref_type' + pub fn rust_ref_type(&self) -> TokenStream { + match *self { + Ty::Double => quote!(f64), + Ty::Float => quote!(f32), + Ty::Int32 => quote!(i32), + Ty::Int64 => quote!(i64), + Ty::Uint32 => quote!(u32), + Ty::Uint64 => quote!(u64), + Ty::Sint32 => quote!(i32), + Ty::Sint64 => quote!(i64), + Ty::Fixed32 => quote!(u32), + Ty::Fixed64 => quote!(u64), + Ty::Sfixed32 => quote!(i32), + Ty::Sfixed64 => quote!(i64), + Ty::Bool => quote!(bool), + Ty::String => quote!(&str), + Ty::Bytes(..) => quote!(&[u8]), + Ty::Enumeration(..) => quote!(i32), + } + } + + pub fn module(&self) -> Ident { + match *self { + Ty::Enumeration(..) => Ident::new("int32", Span::call_site()), + _ => Ident::new(self.as_str(), Span::call_site()), + } + } + + /// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`). + pub fn is_numeric(&self) -> bool { + !matches!(self, Ty::String | Ty::Bytes(..)) + } +} + +impl fmt::Debug for Ty { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl fmt::Display for Ty { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +/// Scalar Protobuf field types. +#[derive(Clone)] +pub enum Kind { + /// A plain proto3 scalar field. + Plain(DefaultValue), + /// An optional scalar field. + Optional(DefaultValue), + /// A required proto2 scalar field. + Required(DefaultValue), + /// A repeated scalar field. + Repeated, + /// A packed repeated scalar field. + Packed, +} + +/// Scalar Protobuf field default value. +#[derive(Clone, Debug)] +pub enum DefaultValue { + F64(f64), + F32(f32), + I32(i32), + I64(i64), + U32(u32), + U64(u64), + Bool(bool), + String(String), + Bytes(Vec<u8>), + Enumeration(TokenStream), + Path(Path), +} + +impl DefaultValue { + pub fn from_attr(attr: &Meta) -> Result<Option<Lit>, Error> { + if !attr.path().is_ident("default") { + Ok(None) + } else if let Meta::NameValue(MetaNameValue { + value: Expr::Lit(ExprLit { ref lit, .. }), + .. + }) = *attr + { + Ok(Some(lit.clone())) + } else { + bail!("invalid default value attribute: {:?}", attr) + } + } + + pub fn from_lit(ty: &Ty, lit: Lit) -> Result<DefaultValue, Error> { + let is_i32 = *ty == Ty::Int32 || *ty == Ty::Sint32 || *ty == Ty::Sfixed32; + let is_i64 = *ty == Ty::Int64 || *ty == Ty::Sint64 || *ty == Ty::Sfixed64; + + let is_u32 = *ty == Ty::Uint32 || *ty == Ty::Fixed32; + let is_u64 = *ty == Ty::Uint64 || *ty == Ty::Fixed64; + + let empty_or_is = |expected, actual: &str| expected == actual || actual.is_empty(); + + let default = match lit { + Lit::Int(ref lit) if is_i32 && empty_or_is("i32", lit.suffix()) => { + DefaultValue::I32(lit.base10_parse()?) + } + Lit::Int(ref lit) if is_i64 && empty_or_is("i64", lit.suffix()) => { + DefaultValue::I64(lit.base10_parse()?) + } + Lit::Int(ref lit) if is_u32 && empty_or_is("u32", lit.suffix()) => { + DefaultValue::U32(lit.base10_parse()?) + } + Lit::Int(ref lit) if is_u64 && empty_or_is("u64", lit.suffix()) => { + DefaultValue::U64(lit.base10_parse()?) + } + + Lit::Float(ref lit) if *ty == Ty::Float && empty_or_is("f32", lit.suffix()) => { + DefaultValue::F32(lit.base10_parse()?) + } + Lit::Int(ref lit) if *ty == Ty::Float => DefaultValue::F32(lit.base10_parse()?), + + Lit::Float(ref lit) if *ty == Ty::Double && empty_or_is("f64", lit.suffix()) => { + DefaultValue::F64(lit.base10_parse()?) + } + Lit::Int(ref lit) if *ty == Ty::Double => DefaultValue::F64(lit.base10_parse()?), + + Lit::Bool(ref lit) if *ty == Ty::Bool => DefaultValue::Bool(lit.value), + Lit::Str(ref lit) if *ty == Ty::String => DefaultValue::String(lit.value()), + Lit::ByteStr(ref lit) + if *ty == Ty::Bytes(BytesTy::Bytes) || *ty == Ty::Bytes(BytesTy::Vec) => + { + DefaultValue::Bytes(lit.value()) + } + + Lit::Str(ref lit) => { + let value = lit.value(); + let value = value.trim(); + + if let Ty::Enumeration(ref path) = *ty { + let variant = Ident::new(value, Span::call_site()); + return Ok(DefaultValue::Enumeration(quote!(#path::#variant))); + } + + // Parse special floating point values. + if *ty == Ty::Float { + match value { + "inf" => { + return Ok(DefaultValue::Path(parse_str::<Path>( + "::core::f32::INFINITY", + )?)); + } + "-inf" => { + return Ok(DefaultValue::Path(parse_str::<Path>( + "::core::f32::NEG_INFINITY", + )?)); + } + "nan" => { + return Ok(DefaultValue::Path(parse_str::<Path>("::core::f32::NAN")?)); + } + _ => (), + } + } + if *ty == Ty::Double { + match value { + "inf" => { + return Ok(DefaultValue::Path(parse_str::<Path>( + "::core::f64::INFINITY", + )?)); + } + "-inf" => { + return Ok(DefaultValue::Path(parse_str::<Path>( + "::core::f64::NEG_INFINITY", + )?)); + } + "nan" => { + return Ok(DefaultValue::Path(parse_str::<Path>("::core::f64::NAN")?)); + } + _ => (), + } + } + + // Rust doesn't have a negative literals, so they have to be parsed specially. + if let Some(Ok(lit)) = value.strip_prefix('-').map(syn::parse_str::<Lit>) { + match lit { + Lit::Int(ref lit) if is_i32 && empty_or_is("i32", lit.suffix()) => { + // Initially parse into an i64, so that i32::MIN does not overflow. + let value: i64 = -lit.base10_parse()?; + return Ok(i32::try_from(value).map(DefaultValue::I32)?); + } + Lit::Int(ref lit) if is_i64 && empty_or_is("i64", lit.suffix()) => { + // Initially parse into an i128, so that i64::MIN does not overflow. + let value: i128 = -lit.base10_parse()?; + return Ok(i64::try_from(value).map(DefaultValue::I64)?); + } + Lit::Float(ref lit) + if *ty == Ty::Float && empty_or_is("f32", lit.suffix()) => + { + return Ok(DefaultValue::F32(-lit.base10_parse()?)); + } + Lit::Float(ref lit) + if *ty == Ty::Double && empty_or_is("f64", lit.suffix()) => + { + return Ok(DefaultValue::F64(-lit.base10_parse()?)); + } + Lit::Int(ref lit) if *ty == Ty::Float && lit.suffix().is_empty() => { + return Ok(DefaultValue::F32(-lit.base10_parse()?)); + } + Lit::Int(ref lit) if *ty == Ty::Double && lit.suffix().is_empty() => { + return Ok(DefaultValue::F64(-lit.base10_parse()?)); + } + _ => (), + } + } + match syn::parse_str::<Lit>(value) { + Ok(Lit::Str(_)) => (), + Ok(lit) => return DefaultValue::from_lit(ty, lit), + _ => (), + } + bail!("invalid default value: {}", quote!(#value)); + } + _ => bail!("invalid default value: {}", quote!(#lit)), + }; + + Ok(default) + } + + pub fn new(ty: &Ty) -> DefaultValue { + match *ty { + Ty::Float => DefaultValue::F32(0.0), + Ty::Double => DefaultValue::F64(0.0), + Ty::Int32 | Ty::Sint32 | Ty::Sfixed32 => DefaultValue::I32(0), + Ty::Int64 | Ty::Sint64 | Ty::Sfixed64 => DefaultValue::I64(0), + Ty::Uint32 | Ty::Fixed32 => DefaultValue::U32(0), + Ty::Uint64 | Ty::Fixed64 => DefaultValue::U64(0), + + Ty::Bool => DefaultValue::Bool(false), + Ty::String => DefaultValue::String(String::new()), + Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()), + Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())), + } + } + + pub fn owned(&self) -> TokenStream { + match *self { + DefaultValue::String(ref value) if value.is_empty() => { + quote!(::prost::alloc::string::String::new()) + } + DefaultValue::String(ref value) => quote!(#value.into()), + DefaultValue::Bytes(ref value) if value.is_empty() => { + quote!(::core::default::Default::default()) + } + DefaultValue::Bytes(ref value) => { + let lit = LitByteStr::new(value, Span::call_site()); + quote!(#lit.as_ref().into()) + } + + ref other => other.typed(), + } + } + + pub fn typed(&self) -> TokenStream { + if let DefaultValue::Enumeration(_) = *self { + quote!(#self as i32) + } else { + quote!(#self) + } + } +} + +impl ToTokens for DefaultValue { + fn to_tokens(&self, tokens: &mut TokenStream) { + match *self { + DefaultValue::F64(value) => value.to_tokens(tokens), + DefaultValue::F32(value) => value.to_tokens(tokens), + DefaultValue::I32(value) => value.to_tokens(tokens), + DefaultValue::I64(value) => value.to_tokens(tokens), + DefaultValue::U32(value) => value.to_tokens(tokens), + DefaultValue::U64(value) => value.to_tokens(tokens), + DefaultValue::Bool(value) => value.to_tokens(tokens), + DefaultValue::String(ref value) => value.to_tokens(tokens), + DefaultValue::Bytes(ref value) => { + let byte_str = LitByteStr::new(value, Span::call_site()); + tokens.append_all(quote!(#byte_str as &[u8])); + } + DefaultValue::Enumeration(ref value) => value.to_tokens(tokens), + DefaultValue::Path(ref value) => value.to_tokens(tokens), + } + } +} diff --git a/third_party/rust/prost-derive/src/lib.rs b/third_party/rust/prost-derive/src/lib.rs new file mode 100644 index 0000000000..8bc99c5eda --- /dev/null +++ b/third_party/rust/prost-derive/src/lib.rs @@ -0,0 +1,530 @@ +#![doc(html_root_url = "https://docs.rs/prost-derive/0.12.1")] +// The `quote!` macro requires deep recursion. +#![recursion_limit = "4096"] + +extern crate alloc; +extern crate proc_macro; + +use anyhow::{bail, Error}; +use itertools::Itertools; +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::quote; +use syn::{ + punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed, + FieldsUnnamed, Ident, Index, Variant, +}; + +mod field; +use crate::field::Field; + +fn try_message(input: TokenStream) -> Result<TokenStream, Error> { + let input: DeriveInput = syn::parse(input)?; + + let ident = input.ident; + + syn::custom_keyword!(skip_debug); + let skip_debug = input + .attrs + .into_iter() + .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok()); + + let variant_data = match input.data { + Data::Struct(variant_data) => variant_data, + Data::Enum(..) => bail!("Message can not be derived for an enum"), + Data::Union(..) => bail!("Message can not be derived for a union"), + }; + + let generics = &input.generics; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let (is_struct, fields) = match variant_data { + DataStruct { + fields: Fields::Named(FieldsNamed { named: fields, .. }), + .. + } => (true, fields.into_iter().collect()), + DataStruct { + fields: + Fields::Unnamed(FieldsUnnamed { + unnamed: fields, .. + }), + .. + } => (false, fields.into_iter().collect()), + DataStruct { + fields: Fields::Unit, + .. + } => (false, Vec::new()), + }; + + let mut next_tag: u32 = 1; + let mut fields = fields + .into_iter() + .enumerate() + .flat_map(|(i, field)| { + let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| { + let index = Index { + index: i as u32, + span: Span::call_site(), + }; + quote!(#index) + }); + match Field::new(field.attrs, Some(next_tag)) { + Ok(Some(field)) => { + next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag); + Some(Ok((field_ident, field))) + } + Ok(None) => None, + Err(err) => Some(Err( + err.context(format!("invalid message field {}.{}", ident, field_ident)) + )), + } + }) + .collect::<Result<Vec<_>, _>>()?; + + // We want Debug to be in declaration order + let unsorted_fields = fields.clone(); + + // Sort the fields by tag number so that fields will be encoded in tag order. + // TODO: This encodes oneof fields in the position of their lowest tag, + // regardless of the currently occupied variant, is that consequential? + // See: https://developers.google.com/protocol-buffers/docs/encoding#order + fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap()); + let fields = fields; + + let mut tags = fields + .iter() + .flat_map(|&(_, ref field)| field.tags()) + .collect::<Vec<_>>(); + let num_tags = tags.len(); + tags.sort_unstable(); + tags.dedup(); + if tags.len() != num_tags { + bail!("message {} has fields with duplicate tags", ident); + } + + let encoded_len = fields + .iter() + .map(|&(ref field_ident, ref field)| field.encoded_len(quote!(self.#field_ident))); + + let encode = fields + .iter() + .map(|&(ref field_ident, ref field)| field.encode(quote!(self.#field_ident))); + + let merge = fields.iter().map(|&(ref field_ident, ref field)| { + let merge = field.merge(quote!(value)); + let tags = field.tags().into_iter().map(|tag| quote!(#tag)); + let tags = Itertools::intersperse(tags, quote!(|)); + + quote! { + #(#tags)* => { + let mut value = &mut self.#field_ident; + #merge.map_err(|mut error| { + error.push(STRUCT_NAME, stringify!(#field_ident)); + error + }) + }, + } + }); + + let struct_name = if fields.is_empty() { + quote!() + } else { + quote!( + const STRUCT_NAME: &'static str = stringify!(#ident); + ) + }; + + let clear = fields + .iter() + .map(|&(ref field_ident, ref field)| field.clear(quote!(self.#field_ident))); + + let default = if is_struct { + let default = fields.iter().map(|(field_ident, field)| { + let value = field.default(); + quote!(#field_ident: #value,) + }); + quote! {#ident { + #(#default)* + }} + } else { + let default = fields.iter().map(|(_, field)| { + let value = field.default(); + quote!(#value,) + }); + quote! {#ident ( + #(#default)* + )} + }; + + let methods = fields + .iter() + .flat_map(|&(ref field_ident, ref field)| field.methods(field_ident)) + .collect::<Vec<_>>(); + let methods = if methods.is_empty() { + quote!() + } else { + quote! { + #[allow(dead_code)] + impl #impl_generics #ident #ty_generics #where_clause { + #(#methods)* + } + } + }; + + let expanded = quote! { + impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause { + #[allow(unused_variables)] + fn encode_raw<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut { + #(#encode)* + } + + #[allow(unused_variables)] + fn merge_field<B>( + &mut self, + tag: u32, + wire_type: ::prost::encoding::WireType, + buf: &mut B, + ctx: ::prost::encoding::DecodeContext, + ) -> ::core::result::Result<(), ::prost::DecodeError> + where B: ::prost::bytes::Buf { + #struct_name + match tag { + #(#merge)* + _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx), + } + } + + #[inline] + fn encoded_len(&self) -> usize { + 0 #(+ #encoded_len)* + } + + fn clear(&mut self) { + #(#clear;)* + } + } + + impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause { + fn default() -> Self { + #default + } + } + }; + let expanded = if skip_debug { + expanded + } else { + let debugs = unsorted_fields.iter().map(|&(ref field_ident, ref field)| { + let wrapper = field.debug(quote!(self.#field_ident)); + let call = if is_struct { + quote!(builder.field(stringify!(#field_ident), &wrapper)) + } else { + quote!(builder.field(&wrapper)) + }; + quote! { + let builder = { + let wrapper = #wrapper; + #call + }; + } + }); + let debug_builder = if is_struct { + quote!(f.debug_struct(stringify!(#ident))) + } else { + quote!(f.debug_tuple(stringify!(#ident))) + }; + quote! { + #expanded + + impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + let mut builder = #debug_builder; + #(#debugs;)* + builder.finish() + } + } + } + }; + + let expanded = quote! { + #expanded + + #methods + }; + + Ok(expanded.into()) +} + +#[proc_macro_derive(Message, attributes(prost))] +pub fn message(input: TokenStream) -> TokenStream { + try_message(input).unwrap() +} + +fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> { + let input: DeriveInput = syn::parse(input)?; + let ident = input.ident; + + let generics = &input.generics; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let punctuated_variants = match input.data { + Data::Enum(DataEnum { variants, .. }) => variants, + Data::Struct(_) => bail!("Enumeration can not be derived for a struct"), + Data::Union(..) => bail!("Enumeration can not be derived for a union"), + }; + + // Map the variants into 'fields'. + let mut variants: Vec<(Ident, Expr)> = Vec::new(); + for Variant { + ident, + fields, + discriminant, + .. + } in punctuated_variants + { + match fields { + Fields::Unit => (), + Fields::Named(_) | Fields::Unnamed(_) => { + bail!("Enumeration variants may not have fields") + } + } + + match discriminant { + Some((_, expr)) => variants.push((ident, expr)), + None => bail!("Enumeration variants must have a discriminant"), + } + } + + if variants.is_empty() { + panic!("Enumeration must have at least one variant"); + } + + let default = variants[0].0.clone(); + + let is_valid = variants + .iter() + .map(|&(_, ref value)| quote!(#value => true)); + let from = variants.iter().map( + |&(ref variant, ref value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)), + ); + + let try_from = variants.iter().map( + |&(ref variant, ref value)| quote!(#value => ::core::result::Result::Ok(#ident::#variant)), + ); + + let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident); + let from_i32_doc = format!( + "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.", + ident + ); + + let expanded = quote! { + impl #impl_generics #ident #ty_generics #where_clause { + #[doc=#is_valid_doc] + pub fn is_valid(value: i32) -> bool { + match value { + #(#is_valid,)* + _ => false, + } + } + + #[deprecated = "Use the TryFrom<i32> implementation instead"] + #[doc=#from_i32_doc] + pub fn from_i32(value: i32) -> ::core::option::Option<#ident> { + match value { + #(#from,)* + _ => ::core::option::Option::None, + } + } + } + + impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause { + fn default() -> #ident { + #ident::#default + } + } + + impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause { + fn from(value: #ident) -> i32 { + value as i32 + } + } + + impl #impl_generics ::core::convert::TryFrom::<i32> for #ident #ty_generics #where_clause { + type Error = ::prost::DecodeError; + + fn try_from(value: i32) -> ::core::result::Result<#ident, ::prost::DecodeError> { + match value { + #(#try_from,)* + _ => ::core::result::Result::Err(::prost::DecodeError::new("invalid enumeration value")), + } + } + } + }; + + Ok(expanded.into()) +} + +#[proc_macro_derive(Enumeration, attributes(prost))] +pub fn enumeration(input: TokenStream) -> TokenStream { + try_enumeration(input).unwrap() +} + +fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> { + let input: DeriveInput = syn::parse(input)?; + + let ident = input.ident; + + syn::custom_keyword!(skip_debug); + let skip_debug = input + .attrs + .into_iter() + .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok()); + + let variants = match input.data { + Data::Enum(DataEnum { variants, .. }) => variants, + Data::Struct(..) => bail!("Oneof can not be derived for a struct"), + Data::Union(..) => bail!("Oneof can not be derived for a union"), + }; + + let generics = &input.generics; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + // Map the variants into 'fields'. + let mut fields: Vec<(Ident, Field)> = Vec::new(); + for Variant { + attrs, + ident: variant_ident, + fields: variant_fields, + .. + } in variants + { + let variant_fields = match variant_fields { + Fields::Unit => Punctuated::new(), + Fields::Named(FieldsNamed { named: fields, .. }) + | Fields::Unnamed(FieldsUnnamed { + unnamed: fields, .. + }) => fields, + }; + if variant_fields.len() != 1 { + bail!("Oneof enum variants must have a single field"); + } + match Field::new_oneof(attrs)? { + Some(field) => fields.push((variant_ident, field)), + None => bail!("invalid oneof variant: oneof variants may not be ignored"), + } + } + + let mut tags = fields + .iter() + .flat_map(|&(ref variant_ident, ref field)| -> Result<u32, Error> { + if field.tags().len() > 1 { + bail!( + "invalid oneof variant {}::{}: oneof variants may only have a single tag", + ident, + variant_ident + ); + } + Ok(field.tags()[0]) + }) + .collect::<Vec<_>>(); + tags.sort_unstable(); + tags.dedup(); + if tags.len() != fields.len() { + panic!("invalid oneof {}: variants have duplicate tags", ident); + } + + let encode = fields.iter().map(|&(ref variant_ident, ref field)| { + let encode = field.encode(quote!(*value)); + quote!(#ident::#variant_ident(ref value) => { #encode }) + }); + + let merge = fields.iter().map(|&(ref variant_ident, ref field)| { + let tag = field.tags()[0]; + let merge = field.merge(quote!(value)); + quote! { + #tag => { + match field { + ::core::option::Option::Some(#ident::#variant_ident(ref mut value)) => { + #merge + }, + _ => { + let mut owned_value = ::core::default::Default::default(); + let value = &mut owned_value; + #merge.map(|_| *field = ::core::option::Option::Some(#ident::#variant_ident(owned_value))) + }, + } + } + } + }); + + let encoded_len = fields.iter().map(|&(ref variant_ident, ref field)| { + let encoded_len = field.encoded_len(quote!(*value)); + quote!(#ident::#variant_ident(ref value) => #encoded_len) + }); + + let expanded = quote! { + impl #impl_generics #ident #ty_generics #where_clause { + /// Encodes the message to a buffer. + pub fn encode<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut { + match *self { + #(#encode,)* + } + } + + /// Decodes an instance of the message from a buffer, and merges it into self. + pub fn merge<B>( + field: &mut ::core::option::Option<#ident #ty_generics>, + tag: u32, + wire_type: ::prost::encoding::WireType, + buf: &mut B, + ctx: ::prost::encoding::DecodeContext, + ) -> ::core::result::Result<(), ::prost::DecodeError> + where B: ::prost::bytes::Buf { + match tag { + #(#merge,)* + _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag), + } + } + + /// Returns the encoded length of the message without a length delimiter. + #[inline] + pub fn encoded_len(&self) -> usize { + match *self { + #(#encoded_len,)* + } + } + } + + }; + let expanded = if skip_debug { + expanded + } else { + let debug = fields.iter().map(|&(ref variant_ident, ref field)| { + let wrapper = field.debug(quote!(*value)); + quote!(#ident::#variant_ident(ref value) => { + let wrapper = #wrapper; + f.debug_tuple(stringify!(#variant_ident)) + .field(&wrapper) + .finish() + }) + }); + quote! { + #expanded + + impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + match *self { + #(#debug,)* + } + } + } + } + }; + + Ok(expanded.into()) +} + +#[proc_macro_derive(Oneof, attributes(prost))] +pub fn oneof(input: TokenStream) -> TokenStream { + try_oneof(input).unwrap() +} |