diff options
Diffstat (limited to 'third_party/rust/prost-derive/src/field')
-rw-r--r-- | third_party/rust/prost-derive/src/field/group.rs | 134 | ||||
-rw-r--r-- | third_party/rust/prost-derive/src/field/map.rs | 394 | ||||
-rw-r--r-- | third_party/rust/prost-derive/src/field/message.rs | 134 | ||||
-rw-r--r-- | third_party/rust/prost-derive/src/field/mod.rs | 366 | ||||
-rw-r--r-- | third_party/rust/prost-derive/src/field/oneof.rs | 99 | ||||
-rw-r--r-- | third_party/rust/prost-derive/src/field/scalar.rs | 810 |
6 files changed, 1937 insertions, 0 deletions
diff --git a/third_party/rust/prost-derive/src/field/group.rs b/third_party/rust/prost-derive/src/field/group.rs new file mode 100644 index 0000000000..6396128802 --- /dev/null +++ b/third_party/rust/prost-derive/src/field/group.rs @@ -0,0 +1,134 @@ +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::Meta; + +use crate::field::{set_bool, set_option, tag_attr, word_attr, Label}; + +#[derive(Clone)] +pub struct Field { + pub label: Label, + pub tag: u32, +} + +impl Field { + pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> { + let mut group = false; + let mut label = None; + let mut tag = None; + let mut boxed = false; + + let mut unknown_attrs = Vec::new(); + + for attr in attrs { + if word_attr("group", attr) { + set_bool(&mut group, "duplicate group attributes")?; + } else if word_attr("boxed", attr) { + set_bool(&mut boxed, "duplicate boxed attributes")?; + } else if let Some(t) = tag_attr(attr)? { + set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(l) = Label::from_attr(attr) { + set_option(&mut label, l, "duplicate label attributes")?; + } else { + unknown_attrs.push(attr); + } + } + + if !group { + return Ok(None); + } + + match unknown_attrs.len() { + 0 => (), + 1 => bail!("unknown attribute for group field: {:?}", unknown_attrs[0]), + _ => bail!("unknown attributes for group field: {:?}", unknown_attrs), + } + + let tag = match tag.or(inferred_tag) { + Some(tag) => tag, + None => bail!("group field is missing a tag attribute"), + }; + + Ok(Some(Field { + label: label.unwrap_or(Label::Optional), + tag, + })) + } + + pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> { + if let Some(mut field) = Field::new(attrs, None)? { + if let Some(attr) = attrs.iter().find(|attr| Label::from_attr(attr).is_some()) { + bail!( + "invalid attribute for oneof field: {}", + attr.path().into_token_stream() + ); + } + field.label = Label::Required; + Ok(Some(field)) + } else { + Ok(None) + } + } + + pub fn encode(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + match self.label { + Label::Optional => quote! { + if let Some(ref msg) = #ident { + ::prost::encoding::group::encode(#tag, msg, buf); + } + }, + Label::Required => quote! { + ::prost::encoding::group::encode(#tag, &#ident, buf); + }, + Label::Repeated => quote! { + for msg in &#ident { + ::prost::encoding::group::encode(#tag, msg, buf); + } + }, + } + } + + pub fn merge(&self, ident: TokenStream) -> TokenStream { + match self.label { + Label::Optional => quote! { + ::prost::encoding::group::merge( + tag, + wire_type, + #ident.get_or_insert_with(Default::default), + buf, + ctx, + ) + }, + Label::Required => quote! { + ::prost::encoding::group::merge(tag, wire_type, #ident, buf, ctx) + }, + Label::Repeated => quote! { + ::prost::encoding::group::merge_repeated(tag, wire_type, #ident, buf, ctx) + }, + } + } + + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + match self.label { + Label::Optional => quote! { + #ident.as_ref().map_or(0, |msg| ::prost::encoding::group::encoded_len(#tag, msg)) + }, + Label::Required => quote! { + ::prost::encoding::group::encoded_len(#tag, &#ident) + }, + Label::Repeated => quote! { + ::prost::encoding::group::encoded_len_repeated(#tag, &#ident) + }, + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + match self.label { + Label::Optional => quote!(#ident = ::core::option::Option::None), + Label::Required => quote!(#ident.clear()), + Label::Repeated => quote!(#ident.clear()), + } + } +} diff --git a/third_party/rust/prost-derive/src/field/map.rs b/third_party/rust/prost-derive/src/field/map.rs new file mode 100644 index 0000000000..1228a6fae2 --- /dev/null +++ b/third_party/rust/prost-derive/src/field/map.rs @@ -0,0 +1,394 @@ +use anyhow::{bail, Error}; +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::{Ident, Lit, Meta, MetaNameValue, NestedMeta}; + +use crate::field::{scalar, set_option, tag_attr}; + +#[derive(Clone, Debug)] +pub enum MapTy { + HashMap, + BTreeMap, +} + +impl MapTy { + fn from_str(s: &str) -> Option<MapTy> { + match s { + "map" | "hash_map" => Some(MapTy::HashMap), + "btree_map" => Some(MapTy::BTreeMap), + _ => None, + } + } + + fn module(&self) -> Ident { + match *self { + MapTy::HashMap => Ident::new("hash_map", Span::call_site()), + MapTy::BTreeMap => Ident::new("btree_map", Span::call_site()), + } + } + + fn lib(&self) -> TokenStream { + match self { + MapTy::HashMap => quote! { std }, + MapTy::BTreeMap => quote! { prost::alloc }, + } + } +} + +fn fake_scalar(ty: scalar::Ty) -> scalar::Field { + let kind = scalar::Kind::Plain(scalar::DefaultValue::new(&ty)); + scalar::Field { + ty, + kind, + tag: 0, // Not used here + } +} + +#[derive(Clone)] +pub struct Field { + pub map_ty: MapTy, + pub key_ty: scalar::Ty, + pub value_ty: ValueTy, + pub tag: u32, +} + +impl Field { + pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> { + let mut types = None; + let mut tag = None; + + for attr in attrs { + if let Some(t) = tag_attr(attr)? { + set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(map_ty) = attr + .path() + .get_ident() + .and_then(|i| MapTy::from_str(&i.to_string())) + { + let (k, v): (String, String) = match &*attr { + Meta::NameValue(MetaNameValue { + lit: Lit::Str(lit), .. + }) => { + let items = lit.value(); + let mut items = items.split(',').map(ToString::to_string); + let k = items.next().unwrap(); + let v = match items.next() { + Some(k) => k, + None => bail!("invalid map attribute: must have key and value types"), + }; + if items.next().is_some() { + bail!("invalid map attribute: {:?}", attr); + } + (k, v) + } + Meta::List(meta_list) => { + // TODO(rustlang/rust#23121): slice pattern matching would make this much nicer. + if meta_list.nested.len() != 2 { + bail!("invalid map attribute: must contain key and value types"); + } + let k = match &meta_list.nested[0] { + NestedMeta::Meta(Meta::Path(k)) if k.get_ident().is_some() => { + k.get_ident().unwrap().to_string() + } + _ => bail!("invalid map attribute: key must be an identifier"), + }; + let v = match &meta_list.nested[1] { + NestedMeta::Meta(Meta::Path(v)) if v.get_ident().is_some() => { + v.get_ident().unwrap().to_string() + } + _ => bail!("invalid map attribute: value must be an identifier"), + }; + (k, v) + } + _ => return Ok(None), + }; + set_option( + &mut types, + (map_ty, key_ty_from_str(&k)?, ValueTy::from_str(&v)?), + "duplicate map type attribute", + )?; + } else { + return Ok(None); + } + } + + Ok(match (types, tag.or(inferred_tag)) { + (Some((map_ty, key_ty, value_ty)), Some(tag)) => Some(Field { + map_ty, + key_ty, + value_ty, + tag, + }), + _ => None, + }) + } + + pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> { + Field::new(attrs, None) + } + + /// Returns a statement which encodes the map field. + pub fn encode(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + let key_mod = self.key_ty.module(); + let ke = quote!(::prost::encoding::#key_mod::encode); + let kl = quote!(::prost::encoding::#key_mod::encoded_len); + let module = self.map_ty.module(); + match &self.value_ty { + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { + let default = quote!(#ty::default() as i32); + quote! { + ::prost::encoding::#module::encode_with_default( + #ke, + #kl, + ::prost::encoding::int32::encode, + ::prost::encoding::int32::encoded_len, + &(#default), + #tag, + &#ident, + buf, + ); + } + } + ValueTy::Scalar(value_ty) => { + let val_mod = value_ty.module(); + let ve = quote!(::prost::encoding::#val_mod::encode); + let vl = quote!(::prost::encoding::#val_mod::encoded_len); + quote! { + ::prost::encoding::#module::encode( + #ke, + #kl, + #ve, + #vl, + #tag, + &#ident, + buf, + ); + } + } + ValueTy::Message => quote! { + ::prost::encoding::#module::encode( + #ke, + #kl, + ::prost::encoding::message::encode, + ::prost::encoding::message::encoded_len, + #tag, + &#ident, + buf, + ); + }, + } + } + + /// Returns an expression which evaluates to the result of merging a decoded key value pair + /// into the map. + pub fn merge(&self, ident: TokenStream) -> TokenStream { + let key_mod = self.key_ty.module(); + let km = quote!(::prost::encoding::#key_mod::merge); + let module = self.map_ty.module(); + match &self.value_ty { + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { + let default = quote!(#ty::default() as i32); + quote! { + ::prost::encoding::#module::merge_with_default( + #km, + ::prost::encoding::int32::merge, + #default, + &mut #ident, + buf, + ctx, + ) + } + } + ValueTy::Scalar(value_ty) => { + let val_mod = value_ty.module(); + let vm = quote!(::prost::encoding::#val_mod::merge); + quote!(::prost::encoding::#module::merge(#km, #vm, &mut #ident, buf, ctx)) + } + ValueTy::Message => quote! { + ::prost::encoding::#module::merge( + #km, + ::prost::encoding::message::merge, + &mut #ident, + buf, + ctx, + ) + }, + } + } + + /// Returns an expression which evaluates to the encoded length of the map. + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + let key_mod = self.key_ty.module(); + let kl = quote!(::prost::encoding::#key_mod::encoded_len); + let module = self.map_ty.module(); + match &self.value_ty { + ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => { + let default = quote!(#ty::default() as i32); + quote! { + ::prost::encoding::#module::encoded_len_with_default( + #kl, + ::prost::encoding::int32::encoded_len, + &(#default), + #tag, + &#ident, + ) + } + } + ValueTy::Scalar(value_ty) => { + let val_mod = value_ty.module(); + let vl = quote!(::prost::encoding::#val_mod::encoded_len); + quote!(::prost::encoding::#module::encoded_len(#kl, #vl, #tag, &#ident)) + } + ValueTy::Message => quote! { + ::prost::encoding::#module::encoded_len( + #kl, + ::prost::encoding::message::encoded_len, + #tag, + &#ident, + ) + }, + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + quote!(#ident.clear()) + } + + /// Returns methods to embed in the message. + pub fn methods(&self, ident: &Ident) -> Option<TokenStream> { + if let ValueTy::Scalar(scalar::Ty::Enumeration(ty)) = &self.value_ty { + let key_ty = self.key_ty.rust_type(); + let key_ref_ty = self.key_ty.rust_ref_type(); + + let get = Ident::new(&format!("get_{}", ident), Span::call_site()); + let insert = Ident::new(&format!("insert_{}", ident), Span::call_site()); + let take_ref = if self.key_ty.is_numeric() { + quote!(&) + } else { + quote!() + }; + + let get_doc = format!( + "Returns the enum value for the corresponding key in `{}`, \ + or `None` if the entry does not exist or it is not a valid enum value.", + ident, + ); + let insert_doc = format!("Inserts a key value pair into `{}`.", ident); + Some(quote! { + #[doc=#get_doc] + pub fn #get(&self, key: #key_ref_ty) -> ::core::option::Option<#ty> { + self.#ident.get(#take_ref key).cloned().and_then(#ty::from_i32) + } + #[doc=#insert_doc] + pub fn #insert(&mut self, key: #key_ty, value: #ty) -> ::core::option::Option<#ty> { + self.#ident.insert(key, value as i32).and_then(#ty::from_i32) + } + }) + } else { + None + } + } + + /// Returns a newtype wrapper around the map, implementing nicer Debug + /// + /// The Debug tries to convert any enumerations met into the variants if possible, instead of + /// outputting the raw numbers. + pub fn debug(&self, wrapper_name: TokenStream) -> TokenStream { + let type_name = match self.map_ty { + MapTy::HashMap => Ident::new("HashMap", Span::call_site()), + MapTy::BTreeMap => Ident::new("BTreeMap", Span::call_site()), + }; + + // A fake field for generating the debug wrapper + let key_wrapper = fake_scalar(self.key_ty.clone()).debug(quote!(KeyWrapper)); + let key = self.key_ty.rust_type(); + let value_wrapper = self.value_ty.debug(); + let libname = self.map_ty.lib(); + let fmt = quote! { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + #key_wrapper + #value_wrapper + let mut builder = f.debug_map(); + for (k, v) in self.0 { + builder.entry(&KeyWrapper(k), &ValueWrapper(v)); + } + builder.finish() + } + }; + match &self.value_ty { + ValueTy::Scalar(ty) => { + let value = ty.rust_type(); + quote! { + struct #wrapper_name<'a>(&'a ::#libname::collections::#type_name<#key, #value>); + impl<'a> ::core::fmt::Debug for #wrapper_name<'a> { + #fmt + } + } + } + ValueTy::Message => quote! { + struct #wrapper_name<'a, V: 'a>(&'a ::#libname::collections::#type_name<#key, V>); + impl<'a, V> ::core::fmt::Debug for #wrapper_name<'a, V> + where + V: ::core::fmt::Debug + 'a, + { + #fmt + } + }, + } + } +} + +fn key_ty_from_str(s: &str) -> Result<scalar::Ty, Error> { + let ty = scalar::Ty::from_str(s)?; + match ty { + scalar::Ty::Int32 + | scalar::Ty::Int64 + | scalar::Ty::Uint32 + | scalar::Ty::Uint64 + | scalar::Ty::Sint32 + | scalar::Ty::Sint64 + | scalar::Ty::Fixed32 + | scalar::Ty::Fixed64 + | scalar::Ty::Sfixed32 + | scalar::Ty::Sfixed64 + | scalar::Ty::Bool + | scalar::Ty::String => Ok(ty), + _ => bail!("invalid map key type: {}", s), + } +} + +/// A map value type. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ValueTy { + Scalar(scalar::Ty), + Message, +} + +impl ValueTy { + fn from_str(s: &str) -> Result<ValueTy, Error> { + if let Ok(ty) = scalar::Ty::from_str(s) { + Ok(ValueTy::Scalar(ty)) + } else if s.trim() == "message" { + Ok(ValueTy::Message) + } else { + bail!("invalid map value type: {}", s); + } + } + + /// Returns a newtype wrapper around the ValueTy for nicer debug. + /// + /// If the contained value is enumeration, it tries to convert it to the variant. If not, it + /// just forwards the implementation. + fn debug(&self) -> TokenStream { + match self { + ValueTy::Scalar(ty) => fake_scalar(ty.clone()).debug(quote!(ValueWrapper)), + ValueTy::Message => quote!( + fn ValueWrapper<T>(v: T) -> T { + v + } + ), + } + } +} diff --git a/third_party/rust/prost-derive/src/field/message.rs b/third_party/rust/prost-derive/src/field/message.rs new file mode 100644 index 0000000000..1ff7c37983 --- /dev/null +++ b/third_party/rust/prost-derive/src/field/message.rs @@ -0,0 +1,134 @@ +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::Meta; + +use crate::field::{set_bool, set_option, tag_attr, word_attr, Label}; + +#[derive(Clone)] +pub struct Field { + pub label: Label, + pub tag: u32, +} + +impl Field { + pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> { + let mut message = false; + let mut label = None; + let mut tag = None; + let mut boxed = false; + + let mut unknown_attrs = Vec::new(); + + for attr in attrs { + if word_attr("message", attr) { + set_bool(&mut message, "duplicate message attribute")?; + } else if word_attr("boxed", attr) { + set_bool(&mut boxed, "duplicate boxed attribute")?; + } else if let Some(t) = tag_attr(attr)? { + set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(l) = Label::from_attr(attr) { + set_option(&mut label, l, "duplicate label attributes")?; + } else { + unknown_attrs.push(attr); + } + } + + if !message { + return Ok(None); + } + + match unknown_attrs.len() { + 0 => (), + 1 => bail!( + "unknown attribute for message field: {:?}", + unknown_attrs[0] + ), + _ => bail!("unknown attributes for message field: {:?}", unknown_attrs), + } + + let tag = match tag.or(inferred_tag) { + Some(tag) => tag, + None => bail!("message field is missing a tag attribute"), + }; + + Ok(Some(Field { + label: label.unwrap_or(Label::Optional), + tag, + })) + } + + pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> { + if let Some(mut field) = Field::new(attrs, None)? { + if let Some(attr) = attrs.iter().find(|attr| Label::from_attr(attr).is_some()) { + bail!( + "invalid attribute for oneof field: {}", + attr.path().into_token_stream() + ); + } + field.label = Label::Required; + Ok(Some(field)) + } else { + Ok(None) + } + } + + pub fn encode(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + match self.label { + Label::Optional => quote! { + if let Some(ref msg) = #ident { + ::prost::encoding::message::encode(#tag, msg, buf); + } + }, + Label::Required => quote! { + ::prost::encoding::message::encode(#tag, &#ident, buf); + }, + Label::Repeated => quote! { + for msg in &#ident { + ::prost::encoding::message::encode(#tag, msg, buf); + } + }, + } + } + + pub fn merge(&self, ident: TokenStream) -> TokenStream { + match self.label { + Label::Optional => quote! { + ::prost::encoding::message::merge(wire_type, + #ident.get_or_insert_with(Default::default), + buf, + ctx) + }, + Label::Required => quote! { + ::prost::encoding::message::merge(wire_type, #ident, buf, ctx) + }, + Label::Repeated => quote! { + ::prost::encoding::message::merge_repeated(wire_type, #ident, buf, ctx) + }, + } + } + + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + let tag = self.tag; + match self.label { + Label::Optional => quote! { + #ident.as_ref().map_or(0, |msg| ::prost::encoding::message::encoded_len(#tag, msg)) + }, + Label::Required => quote! { + ::prost::encoding::message::encoded_len(#tag, &#ident) + }, + Label::Repeated => quote! { + ::prost::encoding::message::encoded_len_repeated(#tag, &#ident) + }, + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + match self.label { + Label::Optional => quote!(#ident = ::core::option::Option::None), + Label::Required => quote!(#ident.clear()), + Label::Repeated => quote!(#ident.clear()), + } + } +} diff --git a/third_party/rust/prost-derive/src/field/mod.rs b/third_party/rust/prost-derive/src/field/mod.rs new file mode 100644 index 0000000000..09fef830ef --- /dev/null +++ b/third_party/rust/prost-derive/src/field/mod.rs @@ -0,0 +1,366 @@ +mod group; +mod map; +mod message; +mod oneof; +mod scalar; + +use std::fmt; +use std::slice; + +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{Attribute, Ident, Lit, LitBool, Meta, MetaList, MetaNameValue, NestedMeta}; + +#[derive(Clone)] +pub enum Field { + /// A scalar field. + Scalar(scalar::Field), + /// A message field. + Message(message::Field), + /// A map field. + Map(map::Field), + /// A oneof field. + Oneof(oneof::Field), + /// A group field. + Group(group::Field), +} + +impl Field { + /// Creates a new `Field` from an iterator of field attributes. + /// + /// If the meta items are invalid, an error will be returned. + /// If the field should be ignored, `None` is returned. + pub fn new(attrs: Vec<Attribute>, inferred_tag: Option<u32>) -> Result<Option<Field>, Error> { + let attrs = prost_attrs(attrs); + + // TODO: check for ignore attribute. + + let field = if let Some(field) = scalar::Field::new(&attrs, inferred_tag)? { + Field::Scalar(field) + } else if let Some(field) = message::Field::new(&attrs, inferred_tag)? { + Field::Message(field) + } else if let Some(field) = map::Field::new(&attrs, inferred_tag)? { + Field::Map(field) + } else if let Some(field) = oneof::Field::new(&attrs)? { + Field::Oneof(field) + } else if let Some(field) = group::Field::new(&attrs, inferred_tag)? { + Field::Group(field) + } else { + bail!("no type attribute"); + }; + + Ok(Some(field)) + } + + /// Creates a new oneof `Field` from an iterator of field attributes. + /// + /// If the meta items are invalid, an error will be returned. + /// If the field should be ignored, `None` is returned. + pub fn new_oneof(attrs: Vec<Attribute>) -> Result<Option<Field>, Error> { + let attrs = prost_attrs(attrs); + + // TODO: check for ignore attribute. + + let field = if let Some(field) = scalar::Field::new_oneof(&attrs)? { + Field::Scalar(field) + } else if let Some(field) = message::Field::new_oneof(&attrs)? { + Field::Message(field) + } else if let Some(field) = map::Field::new_oneof(&attrs)? { + Field::Map(field) + } else if let Some(field) = group::Field::new_oneof(&attrs)? { + Field::Group(field) + } else { + bail!("no type attribute for oneof field"); + }; + + Ok(Some(field)) + } + + pub fn tags(&self) -> Vec<u32> { + match *self { + Field::Scalar(ref scalar) => vec![scalar.tag], + Field::Message(ref message) => vec![message.tag], + Field::Map(ref map) => vec![map.tag], + Field::Oneof(ref oneof) => oneof.tags.clone(), + Field::Group(ref group) => vec![group.tag], + } + } + + /// Returns a statement which encodes the field. + pub fn encode(&self, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.encode(ident), + Field::Message(ref message) => message.encode(ident), + Field::Map(ref map) => map.encode(ident), + Field::Oneof(ref oneof) => oneof.encode(ident), + Field::Group(ref group) => group.encode(ident), + } + } + + /// Returns an expression which evaluates to the result of merging a decoded + /// value into the field. + pub fn merge(&self, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.merge(ident), + Field::Message(ref message) => message.merge(ident), + Field::Map(ref map) => map.merge(ident), + Field::Oneof(ref oneof) => oneof.merge(ident), + Field::Group(ref group) => group.merge(ident), + } + } + + /// Returns an expression which evaluates to the encoded length of the field. + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.encoded_len(ident), + Field::Map(ref map) => map.encoded_len(ident), + Field::Message(ref msg) => msg.encoded_len(ident), + Field::Oneof(ref oneof) => oneof.encoded_len(ident), + Field::Group(ref group) => group.encoded_len(ident), + } + } + + /// Returns a statement which clears the field. + pub fn clear(&self, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.clear(ident), + Field::Message(ref message) => message.clear(ident), + Field::Map(ref map) => map.clear(ident), + Field::Oneof(ref oneof) => oneof.clear(ident), + Field::Group(ref group) => group.clear(ident), + } + } + + pub fn default(&self) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => scalar.default(), + _ => quote!(::core::default::Default::default()), + } + } + + /// Produces the fragment implementing debug for the given field. + pub fn debug(&self, ident: TokenStream) -> TokenStream { + match *self { + Field::Scalar(ref scalar) => { + let wrapper = scalar.debug(quote!(ScalarWrapper)); + quote! { + { + #wrapper + ScalarWrapper(&#ident) + } + } + } + Field::Map(ref map) => { + let wrapper = map.debug(quote!(MapWrapper)); + quote! { + { + #wrapper + MapWrapper(&#ident) + } + } + } + _ => quote!(&#ident), + } + } + + pub fn methods(&self, ident: &Ident) -> Option<TokenStream> { + match *self { + Field::Scalar(ref scalar) => scalar.methods(ident), + Field::Map(ref map) => map.methods(ident), + _ => None, + } + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum Label { + /// An optional field. + Optional, + /// A required field. + Required, + /// A repeated field. + Repeated, +} + +impl Label { + fn as_str(self) -> &'static str { + match self { + Label::Optional => "optional", + Label::Required => "required", + Label::Repeated => "repeated", + } + } + + fn variants() -> slice::Iter<'static, Label> { + const VARIANTS: &[Label] = &[Label::Optional, Label::Required, Label::Repeated]; + VARIANTS.iter() + } + + /// Parses a string into a field label. + /// If the string doesn't match a field label, `None` is returned. + fn from_attr(attr: &Meta) -> Option<Label> { + if let Meta::Path(ref path) = *attr { + for &label in Label::variants() { + if path.is_ident(label.as_str()) { + return Some(label); + } + } + } + None + } +} + +impl fmt::Debug for Label { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl fmt::Display for Label { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +/// Get the items belonging to the 'prost' list attribute, e.g. `#[prost(foo, bar="baz")]`. +fn prost_attrs(attrs: Vec<Attribute>) -> Vec<Meta> { + attrs + .iter() + .flat_map(Attribute::parse_meta) + .flat_map(|meta| match meta { + Meta::List(MetaList { path, nested, .. }) => { + if path.is_ident("prost") { + nested.into_iter().collect() + } else { + Vec::new() + } + } + _ => Vec::new(), + }) + .flat_map(|attr| -> Result<_, _> { + match attr { + NestedMeta::Meta(attr) => Ok(attr), + NestedMeta::Lit(lit) => bail!("invalid prost attribute: {:?}", lit), + } + }) + .collect() +} + +pub fn set_option<T>(option: &mut Option<T>, value: T, message: &str) -> Result<(), Error> +where + T: fmt::Debug, +{ + if let Some(ref existing) = *option { + bail!("{}: {:?} and {:?}", message, existing, value); + } + *option = Some(value); + Ok(()) +} + +pub fn set_bool(b: &mut bool, message: &str) -> Result<(), Error> { + if *b { + bail!("{}", message); + } else { + *b = true; + Ok(()) + } +} + +/// Unpacks an attribute into a (key, boolean) pair, returning the boolean value. +/// If the key doesn't match the attribute, `None` is returned. +fn bool_attr(key: &str, attr: &Meta) -> Result<Option<bool>, Error> { + if !attr.path().is_ident(key) { + return Ok(None); + } + match *attr { + Meta::Path(..) => Ok(Some(true)), + Meta::List(ref meta_list) => { + // TODO(rustlang/rust#23121): slice pattern matching would make this much nicer. + if meta_list.nested.len() == 1 { + if let NestedMeta::Lit(Lit::Bool(LitBool { value, .. })) = meta_list.nested[0] { + return Ok(Some(value)); + } + } + bail!("invalid {} attribute", key); + } + Meta::NameValue(MetaNameValue { + lit: Lit::Str(ref lit), + .. + }) => lit + .value() + .parse::<bool>() + .map_err(Error::from) + .map(Option::Some), + Meta::NameValue(MetaNameValue { + lit: Lit::Bool(LitBool { value, .. }), + .. + }) => Ok(Some(value)), + _ => bail!("invalid {} attribute", key), + } +} + +/// Checks if an attribute matches a word. +fn word_attr(key: &str, attr: &Meta) -> bool { + if let Meta::Path(ref path) = *attr { + path.is_ident(key) + } else { + false + } +} + +pub(super) fn tag_attr(attr: &Meta) -> Result<Option<u32>, Error> { + if !attr.path().is_ident("tag") { + return Ok(None); + } + match *attr { + Meta::List(ref meta_list) => { + // TODO(rustlang/rust#23121): slice pattern matching would make this much nicer. + if meta_list.nested.len() == 1 { + if let NestedMeta::Lit(Lit::Int(ref lit)) = meta_list.nested[0] { + return Ok(Some(lit.base10_parse()?)); + } + } + bail!("invalid tag attribute: {:?}", attr); + } + Meta::NameValue(ref meta_name_value) => match meta_name_value.lit { + Lit::Str(ref lit) => lit + .value() + .parse::<u32>() + .map_err(Error::from) + .map(Option::Some), + Lit::Int(ref lit) => Ok(Some(lit.base10_parse()?)), + _ => bail!("invalid tag attribute: {:?}", attr), + }, + _ => bail!("invalid tag attribute: {:?}", attr), + } +} + +fn tags_attr(attr: &Meta) -> Result<Option<Vec<u32>>, Error> { + if !attr.path().is_ident("tags") { + return Ok(None); + } + match *attr { + Meta::List(ref meta_list) => { + let mut tags = Vec::with_capacity(meta_list.nested.len()); + for item in &meta_list.nested { + if let NestedMeta::Lit(Lit::Int(ref lit)) = *item { + tags.push(lit.base10_parse()?); + } else { + bail!("invalid tag attribute: {:?}", attr); + } + } + Ok(Some(tags)) + } + Meta::NameValue(MetaNameValue { + lit: Lit::Str(ref lit), + .. + }) => lit + .value() + .split(',') + .map(|s| s.trim().parse::<u32>().map_err(Error::from)) + .collect::<Result<Vec<u32>, _>>() + .map(Some), + _ => bail!("invalid tag attribute: {:?}", attr), + } +} diff --git a/third_party/rust/prost-derive/src/field/oneof.rs b/third_party/rust/prost-derive/src/field/oneof.rs new file mode 100644 index 0000000000..7e7f08671c --- /dev/null +++ b/third_party/rust/prost-derive/src/field/oneof.rs @@ -0,0 +1,99 @@ +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{parse_str, Lit, Meta, MetaNameValue, NestedMeta, Path}; + +use crate::field::{set_option, tags_attr}; + +#[derive(Clone)] +pub struct Field { + pub ty: Path, + pub tags: Vec<u32>, +} + +impl Field { + pub fn new(attrs: &[Meta]) -> Result<Option<Field>, Error> { + let mut ty = None; + let mut tags = None; + let mut unknown_attrs = Vec::new(); + + for attr in attrs { + if attr.path().is_ident("oneof") { + let t = match *attr { + Meta::NameValue(MetaNameValue { + lit: Lit::Str(ref lit), + .. + }) => parse_str::<Path>(&lit.value())?, + Meta::List(ref list) if list.nested.len() == 1 => { + // TODO(rustlang/rust#23121): slice pattern matching would make this much nicer. + if let NestedMeta::Meta(Meta::Path(ref path)) = list.nested[0] { + if let Some(ident) = path.get_ident() { + Path::from(ident.clone()) + } else { + bail!("invalid oneof attribute: item must be an identifier"); + } + } else { + bail!("invalid oneof attribute: item must be an identifier"); + } + } + _ => bail!("invalid oneof attribute: {:?}", attr), + }; + set_option(&mut ty, t, "duplicate oneof attribute")?; + } else if let Some(t) = tags_attr(attr)? { + set_option(&mut tags, t, "duplicate tags attributes")?; + } else { + unknown_attrs.push(attr); + } + } + + let ty = match ty { + Some(ty) => ty, + None => return Ok(None), + }; + + match unknown_attrs.len() { + 0 => (), + 1 => bail!( + "unknown attribute for message field: {:?}", + unknown_attrs[0] + ), + _ => bail!("unknown attributes for message field: {:?}", unknown_attrs), + } + + let tags = match tags { + Some(tags) => tags, + None => bail!("oneof field is missing a tags attribute"), + }; + + Ok(Some(Field { ty, tags })) + } + + /// Returns a statement which encodes the oneof field. + pub fn encode(&self, ident: TokenStream) -> TokenStream { + quote! { + if let Some(ref oneof) = #ident { + oneof.encode(buf) + } + } + } + + /// Returns an expression which evaluates to the result of decoding the oneof field. + pub fn merge(&self, ident: TokenStream) -> TokenStream { + let ty = &self.ty; + quote! { + #ty::merge(#ident, tag, wire_type, buf, ctx) + } + } + + /// Returns an expression which evaluates to the encoded length of the oneof field. + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + let ty = &self.ty; + quote! { + #ident.as_ref().map_or(0, #ty::encoded_len) + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + quote!(#ident = ::core::option::Option::None) + } +} diff --git a/third_party/rust/prost-derive/src/field/scalar.rs b/third_party/rust/prost-derive/src/field/scalar.rs new file mode 100644 index 0000000000..75a4fe3a86 --- /dev/null +++ b/third_party/rust/prost-derive/src/field/scalar.rs @@ -0,0 +1,810 @@ +use std::convert::TryFrom; +use std::fmt; + +use anyhow::{anyhow, bail, Error}; +use proc_macro2::{Span, TokenStream}; +use quote::{quote, ToTokens, TokenStreamExt}; +use syn::{parse_str, Ident, Lit, LitByteStr, Meta, MetaList, MetaNameValue, NestedMeta, Path}; + +use crate::field::{bool_attr, set_option, tag_attr, Label}; + +/// A scalar protobuf field. +#[derive(Clone)] +pub struct Field { + pub ty: Ty, + pub kind: Kind, + pub tag: u32, +} + +impl Field { + pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> { + let mut ty = None; + let mut label = None; + let mut packed = None; + let mut default = None; + let mut tag = None; + + let mut unknown_attrs = Vec::new(); + + for attr in attrs { + if let Some(t) = Ty::from_attr(attr)? { + set_option(&mut ty, t, "duplicate type attributes")?; + } else if let Some(p) = bool_attr("packed", attr)? { + set_option(&mut packed, p, "duplicate packed attributes")?; + } else if let Some(t) = tag_attr(attr)? { + set_option(&mut tag, t, "duplicate tag attributes")?; + } else if let Some(l) = Label::from_attr(attr) { + set_option(&mut label, l, "duplicate label attributes")?; + } else if let Some(d) = DefaultValue::from_attr(attr)? { + set_option(&mut default, d, "duplicate default attributes")?; + } else { + unknown_attrs.push(attr); + } + } + + let ty = match ty { + Some(ty) => ty, + None => return Ok(None), + }; + + match unknown_attrs.len() { + 0 => (), + 1 => bail!("unknown attribute: {:?}", unknown_attrs[0]), + _ => bail!("unknown attributes: {:?}", unknown_attrs), + } + + let tag = match tag.or(inferred_tag) { + Some(tag) => tag, + None => bail!("missing tag attribute"), + }; + + let has_default = default.is_some(); + let default = default.map_or_else( + || Ok(DefaultValue::new(&ty)), + |lit| DefaultValue::from_lit(&ty, lit), + )?; + + let kind = match (label, packed, has_default) { + (None, Some(true), _) + | (Some(Label::Optional), Some(true), _) + | (Some(Label::Required), Some(true), _) => { + bail!("packed attribute may only be applied to repeated fields"); + } + (Some(Label::Repeated), Some(true), _) if !ty.is_numeric() => { + bail!("packed attribute may only be applied to numeric types"); + } + (Some(Label::Repeated), _, true) => { + bail!("repeated fields may not have a default value"); + } + + (None, _, _) => Kind::Plain(default), + (Some(Label::Optional), _, _) => Kind::Optional(default), + (Some(Label::Required), _, _) => Kind::Required(default), + (Some(Label::Repeated), packed, false) if packed.unwrap_or_else(|| ty.is_numeric()) => { + Kind::Packed + } + (Some(Label::Repeated), _, false) => Kind::Repeated, + }; + + Ok(Some(Field { ty, kind, tag })) + } + + pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> { + if let Some(mut field) = Field::new(attrs, None)? { + match field.kind { + Kind::Plain(default) => { + field.kind = Kind::Required(default); + Ok(Some(field)) + } + Kind::Optional(..) => bail!("invalid optional attribute on oneof field"), + Kind::Required(..) => bail!("invalid required attribute on oneof field"), + Kind::Packed | Kind::Repeated => bail!("invalid repeated attribute on oneof field"), + } + } else { + Ok(None) + } + } + + pub fn encode(&self, ident: TokenStream) -> TokenStream { + let module = self.ty.module(); + let encode_fn = match self.kind { + Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encode), + Kind::Repeated => quote!(encode_repeated), + Kind::Packed => quote!(encode_packed), + }; + let encode_fn = quote!(::prost::encoding::#module::#encode_fn); + let tag = self.tag; + + match self.kind { + Kind::Plain(ref default) => { + let default = default.typed(); + quote! { + if #ident != #default { + #encode_fn(#tag, &#ident, buf); + } + } + } + Kind::Optional(..) => quote! { + if let ::core::option::Option::Some(ref value) = #ident { + #encode_fn(#tag, value, buf); + } + }, + Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! { + #encode_fn(#tag, &#ident, buf); + }, + } + } + + /// Returns an expression which evaluates to the result of merging a decoded + /// scalar value into the field. + pub fn merge(&self, ident: TokenStream) -> TokenStream { + let module = self.ty.module(); + let merge_fn = match self.kind { + Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(merge), + Kind::Repeated | Kind::Packed => quote!(merge_repeated), + }; + let merge_fn = quote!(::prost::encoding::#module::#merge_fn); + + match self.kind { + Kind::Plain(..) | Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! { + #merge_fn(wire_type, #ident, buf, ctx) + }, + Kind::Optional(..) => quote! { + #merge_fn(wire_type, + #ident.get_or_insert_with(Default::default), + buf, + ctx) + }, + } + } + + /// Returns an expression which evaluates to the encoded length of the field. + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + let module = self.ty.module(); + let encoded_len_fn = match self.kind { + Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encoded_len), + Kind::Repeated => quote!(encoded_len_repeated), + Kind::Packed => quote!(encoded_len_packed), + }; + let encoded_len_fn = quote!(::prost::encoding::#module::#encoded_len_fn); + let tag = self.tag; + + match self.kind { + Kind::Plain(ref default) => { + let default = default.typed(); + quote! { + if #ident != #default { + #encoded_len_fn(#tag, &#ident) + } else { + 0 + } + } + } + Kind::Optional(..) => quote! { + #ident.as_ref().map_or(0, |value| #encoded_len_fn(#tag, value)) + }, + Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! { + #encoded_len_fn(#tag, &#ident) + }, + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + match self.kind { + Kind::Plain(ref default) | Kind::Required(ref default) => { + let default = default.typed(); + match self.ty { + Ty::String | Ty::Bytes(..) => quote!(#ident.clear()), + _ => quote!(#ident = #default), + } + } + Kind::Optional(_) => quote!(#ident = ::core::option::Option::None), + Kind::Repeated | Kind::Packed => quote!(#ident.clear()), + } + } + + /// Returns an expression which evaluates to the default value of the field. + pub fn default(&self) -> TokenStream { + match self.kind { + Kind::Plain(ref value) | Kind::Required(ref value) => value.owned(), + Kind::Optional(_) => quote!(::core::option::Option::None), + Kind::Repeated | Kind::Packed => quote!(::prost::alloc::vec::Vec::new()), + } + } + + /// An inner debug wrapper, around the base type. + fn debug_inner(&self, wrap_name: TokenStream) -> TokenStream { + if let Ty::Enumeration(ref ty) = self.ty { + quote! { + struct #wrap_name<'a>(&'a i32); + impl<'a> ::core::fmt::Debug for #wrap_name<'a> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + match #ty::from_i32(*self.0) { + None => ::core::fmt::Debug::fmt(&self.0, f), + Some(en) => ::core::fmt::Debug::fmt(&en, f), + } + } + } + } + } else { + quote! { + fn #wrap_name<T>(v: T) -> T { v } + } + } + } + + /// Returns a fragment for formatting the field `ident` in `Debug`. + pub fn debug(&self, wrapper_name: TokenStream) -> TokenStream { + let wrapper = self.debug_inner(quote!(Inner)); + let inner_ty = self.ty.rust_type(); + match self.kind { + Kind::Plain(_) | Kind::Required(_) => self.debug_inner(wrapper_name), + Kind::Optional(_) => quote! { + struct #wrapper_name<'a>(&'a ::core::option::Option<#inner_ty>); + impl<'a> ::core::fmt::Debug for #wrapper_name<'a> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + #wrapper + ::core::fmt::Debug::fmt(&self.0.as_ref().map(Inner), f) + } + } + }, + Kind::Repeated | Kind::Packed => { + quote! { + struct #wrapper_name<'a>(&'a ::prost::alloc::vec::Vec<#inner_ty>); + impl<'a> ::core::fmt::Debug for #wrapper_name<'a> { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result { + let mut vec_builder = f.debug_list(); + for v in self.0 { + #wrapper + vec_builder.entry(&Inner(v)); + } + vec_builder.finish() + } + } + } + } + } + } + + /// Returns methods to embed in the message. + pub fn methods(&self, ident: &Ident) -> Option<TokenStream> { + let mut ident_str = ident.to_string(); + if ident_str.starts_with("r#") { + ident_str = ident_str[2..].to_owned(); + } + + if let Ty::Enumeration(ref ty) = self.ty { + let set = Ident::new(&format!("set_{}", ident_str), Span::call_site()); + let set_doc = format!("Sets `{}` to the provided enum value.", ident_str); + Some(match self.kind { + Kind::Plain(ref default) | Kind::Required(ref default) => { + let get_doc = format!( + "Returns the enum value of `{}`, \ + or the default if the field is set to an invalid enum value.", + ident_str, + ); + quote! { + #[doc=#get_doc] + pub fn #ident(&self) -> #ty { + #ty::from_i32(self.#ident).unwrap_or(#default) + } + + #[doc=#set_doc] + pub fn #set(&mut self, value: #ty) { + self.#ident = value as i32; + } + } + } + Kind::Optional(ref default) => { + let get_doc = format!( + "Returns the enum value of `{}`, \ + or the default if the field is unset or set to an invalid enum value.", + ident_str, + ); + quote! { + #[doc=#get_doc] + pub fn #ident(&self) -> #ty { + self.#ident.and_then(#ty::from_i32).unwrap_or(#default) + } + + #[doc=#set_doc] + pub fn #set(&mut self, value: #ty) { + self.#ident = ::core::option::Option::Some(value as i32); + } + } + } + Kind::Repeated | Kind::Packed => { + let iter_doc = format!( + "Returns an iterator which yields the valid enum values contained in `{}`.", + ident_str, + ); + let push = Ident::new(&format!("push_{}", ident_str), Span::call_site()); + let push_doc = format!("Appends the provided enum value to `{}`.", ident_str); + quote! { + #[doc=#iter_doc] + pub fn #ident(&self) -> ::core::iter::FilterMap< + ::core::iter::Cloned<::core::slice::Iter<i32>>, + fn(i32) -> ::core::option::Option<#ty>, + > { + self.#ident.iter().cloned().filter_map(#ty::from_i32) + } + #[doc=#push_doc] + pub fn #push(&mut self, value: #ty) { + self.#ident.push(value as i32); + } + } + } + }) + } else if let Kind::Optional(ref default) = self.kind { + let ty = self.ty.rust_ref_type(); + + let match_some = if self.ty.is_numeric() { + quote!(::core::option::Option::Some(val) => val,) + } else { + quote!(::core::option::Option::Some(ref val) => &val[..],) + }; + + let get_doc = format!( + "Returns the value of `{0}`, or the default value if `{0}` is unset.", + ident_str, + ); + + Some(quote! { + #[doc=#get_doc] + pub fn #ident(&self) -> #ty { + match self.#ident { + #match_some + ::core::option::Option::None => #default, + } + } + }) + } else { + None + } + } +} + +/// A scalar protobuf field type. +#[derive(Clone, PartialEq, Eq)] +pub enum Ty { + Double, + Float, + Int32, + Int64, + Uint32, + Uint64, + Sint32, + Sint64, + Fixed32, + Fixed64, + Sfixed32, + Sfixed64, + Bool, + String, + Bytes(BytesTy), + Enumeration(Path), +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum BytesTy { + Vec, + Bytes, +} + +impl BytesTy { + fn try_from_str(s: &str) -> Result<Self, Error> { + match s { + "vec" => Ok(BytesTy::Vec), + "bytes" => Ok(BytesTy::Bytes), + _ => bail!("Invalid bytes type: {}", s), + } + } + + fn rust_type(&self) -> TokenStream { + match self { + BytesTy::Vec => quote! { ::prost::alloc::vec::Vec<u8> }, + BytesTy::Bytes => quote! { ::prost::bytes::Bytes }, + } + } +} + +impl Ty { + pub fn from_attr(attr: &Meta) -> Result<Option<Ty>, Error> { + let ty = match *attr { + Meta::Path(ref name) if name.is_ident("float") => Ty::Float, + Meta::Path(ref name) if name.is_ident("double") => Ty::Double, + Meta::Path(ref name) if name.is_ident("int32") => Ty::Int32, + Meta::Path(ref name) if name.is_ident("int64") => Ty::Int64, + Meta::Path(ref name) if name.is_ident("uint32") => Ty::Uint32, + Meta::Path(ref name) if name.is_ident("uint64") => Ty::Uint64, + Meta::Path(ref name) if name.is_ident("sint32") => Ty::Sint32, + Meta::Path(ref name) if name.is_ident("sint64") => Ty::Sint64, + Meta::Path(ref name) if name.is_ident("fixed32") => Ty::Fixed32, + Meta::Path(ref name) if name.is_ident("fixed64") => Ty::Fixed64, + Meta::Path(ref name) if name.is_ident("sfixed32") => Ty::Sfixed32, + Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64, + Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool, + Meta::Path(ref name) if name.is_ident("string") => Ty::String, + Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec), + Meta::NameValue(MetaNameValue { + ref path, + lit: Lit::Str(ref l), + .. + }) if path.is_ident("bytes") => Ty::Bytes(BytesTy::try_from_str(&l.value())?), + Meta::NameValue(MetaNameValue { + ref path, + lit: Lit::Str(ref l), + .. + }) if path.is_ident("enumeration") => Ty::Enumeration(parse_str::<Path>(&l.value())?), + Meta::List(MetaList { + ref path, + ref nested, + .. + }) if path.is_ident("enumeration") => { + // TODO(rustlang/rust#23121): slice pattern matching would make this much nicer. + if nested.len() == 1 { + if let NestedMeta::Meta(Meta::Path(ref path)) = nested[0] { + Ty::Enumeration(path.clone()) + } else { + bail!("invalid enumeration attribute: item must be an identifier"); + } + } else { + bail!("invalid enumeration attribute: only a single identifier is supported"); + } + } + _ => return Ok(None), + }; + Ok(Some(ty)) + } + + pub fn from_str(s: &str) -> Result<Ty, Error> { + let enumeration_len = "enumeration".len(); + let error = Err(anyhow!("invalid type: {}", s)); + let ty = match s.trim() { + "float" => Ty::Float, + "double" => Ty::Double, + "int32" => Ty::Int32, + "int64" => Ty::Int64, + "uint32" => Ty::Uint32, + "uint64" => Ty::Uint64, + "sint32" => Ty::Sint32, + "sint64" => Ty::Sint64, + "fixed32" => Ty::Fixed32, + "fixed64" => Ty::Fixed64, + "sfixed32" => Ty::Sfixed32, + "sfixed64" => Ty::Sfixed64, + "bool" => Ty::Bool, + "string" => Ty::String, + "bytes" => Ty::Bytes(BytesTy::Vec), + s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => { + let s = &s[enumeration_len..].trim(); + match s.chars().next() { + Some('<') | Some('(') => (), + _ => return error, + } + match s.chars().next_back() { + Some('>') | Some(')') => (), + _ => return error, + } + + Ty::Enumeration(parse_str::<Path>(s[1..s.len() - 1].trim())?) + } + _ => return error, + }; + Ok(ty) + } + + /// Returns the type as it appears in protobuf field declarations. + pub fn as_str(&self) -> &'static str { + match *self { + Ty::Double => "double", + Ty::Float => "float", + Ty::Int32 => "int32", + Ty::Int64 => "int64", + Ty::Uint32 => "uint32", + Ty::Uint64 => "uint64", + Ty::Sint32 => "sint32", + Ty::Sint64 => "sint64", + Ty::Fixed32 => "fixed32", + Ty::Fixed64 => "fixed64", + Ty::Sfixed32 => "sfixed32", + Ty::Sfixed64 => "sfixed64", + Ty::Bool => "bool", + Ty::String => "string", + Ty::Bytes(..) => "bytes", + Ty::Enumeration(..) => "enum", + } + } + + // TODO: rename to 'owned_type'. + pub fn rust_type(&self) -> TokenStream { + match self { + Ty::String => quote!(::prost::alloc::string::String), + Ty::Bytes(ty) => ty.rust_type(), + _ => self.rust_ref_type(), + } + } + + // TODO: rename to 'ref_type' + pub fn rust_ref_type(&self) -> TokenStream { + match *self { + Ty::Double => quote!(f64), + Ty::Float => quote!(f32), + Ty::Int32 => quote!(i32), + Ty::Int64 => quote!(i64), + Ty::Uint32 => quote!(u32), + Ty::Uint64 => quote!(u64), + Ty::Sint32 => quote!(i32), + Ty::Sint64 => quote!(i64), + Ty::Fixed32 => quote!(u32), + Ty::Fixed64 => quote!(u64), + Ty::Sfixed32 => quote!(i32), + Ty::Sfixed64 => quote!(i64), + Ty::Bool => quote!(bool), + Ty::String => quote!(&str), + Ty::Bytes(..) => quote!(&[u8]), + Ty::Enumeration(..) => quote!(i32), + } + } + + pub fn module(&self) -> Ident { + match *self { + Ty::Enumeration(..) => Ident::new("int32", Span::call_site()), + _ => Ident::new(self.as_str(), Span::call_site()), + } + } + + /// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`). + pub fn is_numeric(&self) -> bool { + !matches!(self, Ty::String | Ty::Bytes(..)) + } +} + +impl fmt::Debug for Ty { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl fmt::Display for Ty { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +/// Scalar Protobuf field types. +#[derive(Clone)] +pub enum Kind { + /// A plain proto3 scalar field. + Plain(DefaultValue), + /// An optional scalar field. + Optional(DefaultValue), + /// A required proto2 scalar field. + Required(DefaultValue), + /// A repeated scalar field. + Repeated, + /// A packed repeated scalar field. + Packed, +} + +/// Scalar Protobuf field default value. +#[derive(Clone, Debug)] +pub enum DefaultValue { + F64(f64), + F32(f32), + I32(i32), + I64(i64), + U32(u32), + U64(u64), + Bool(bool), + String(String), + Bytes(Vec<u8>), + Enumeration(TokenStream), + Path(Path), +} + +impl DefaultValue { + pub fn from_attr(attr: &Meta) -> Result<Option<Lit>, Error> { + if !attr.path().is_ident("default") { + Ok(None) + } else if let Meta::NameValue(ref name_value) = *attr { + Ok(Some(name_value.lit.clone())) + } else { + bail!("invalid default value attribute: {:?}", attr) + } + } + + pub fn from_lit(ty: &Ty, lit: Lit) -> Result<DefaultValue, Error> { + let is_i32 = *ty == Ty::Int32 || *ty == Ty::Sint32 || *ty == Ty::Sfixed32; + let is_i64 = *ty == Ty::Int64 || *ty == Ty::Sint64 || *ty == Ty::Sfixed64; + + let is_u32 = *ty == Ty::Uint32 || *ty == Ty::Fixed32; + let is_u64 = *ty == Ty::Uint64 || *ty == Ty::Fixed64; + + let empty_or_is = |expected, actual: &str| expected == actual || actual.is_empty(); + + let default = match lit { + Lit::Int(ref lit) if is_i32 && empty_or_is("i32", lit.suffix()) => { + DefaultValue::I32(lit.base10_parse()?) + } + Lit::Int(ref lit) if is_i64 && empty_or_is("i64", lit.suffix()) => { + DefaultValue::I64(lit.base10_parse()?) + } + Lit::Int(ref lit) if is_u32 && empty_or_is("u32", lit.suffix()) => { + DefaultValue::U32(lit.base10_parse()?) + } + Lit::Int(ref lit) if is_u64 && empty_or_is("u64", lit.suffix()) => { + DefaultValue::U64(lit.base10_parse()?) + } + + Lit::Float(ref lit) if *ty == Ty::Float && empty_or_is("f32", lit.suffix()) => { + DefaultValue::F32(lit.base10_parse()?) + } + Lit::Int(ref lit) if *ty == Ty::Float => DefaultValue::F32(lit.base10_parse()?), + + Lit::Float(ref lit) if *ty == Ty::Double && empty_or_is("f64", lit.suffix()) => { + DefaultValue::F64(lit.base10_parse()?) + } + Lit::Int(ref lit) if *ty == Ty::Double => DefaultValue::F64(lit.base10_parse()?), + + Lit::Bool(ref lit) if *ty == Ty::Bool => DefaultValue::Bool(lit.value), + Lit::Str(ref lit) if *ty == Ty::String => DefaultValue::String(lit.value()), + Lit::ByteStr(ref lit) + if *ty == Ty::Bytes(BytesTy::Bytes) || *ty == Ty::Bytes(BytesTy::Vec) => + { + DefaultValue::Bytes(lit.value()) + } + + Lit::Str(ref lit) => { + let value = lit.value(); + let value = value.trim(); + + if let Ty::Enumeration(ref path) = *ty { + let variant = Ident::new(value, Span::call_site()); + return Ok(DefaultValue::Enumeration(quote!(#path::#variant))); + } + + // Parse special floating point values. + if *ty == Ty::Float { + match value { + "inf" => { + return Ok(DefaultValue::Path(parse_str::<Path>( + "::core::f32::INFINITY", + )?)); + } + "-inf" => { + return Ok(DefaultValue::Path(parse_str::<Path>( + "::core::f32::NEG_INFINITY", + )?)); + } + "nan" => { + return Ok(DefaultValue::Path(parse_str::<Path>("::core::f32::NAN")?)); + } + _ => (), + } + } + if *ty == Ty::Double { + match value { + "inf" => { + return Ok(DefaultValue::Path(parse_str::<Path>( + "::core::f64::INFINITY", + )?)); + } + "-inf" => { + return Ok(DefaultValue::Path(parse_str::<Path>( + "::core::f64::NEG_INFINITY", + )?)); + } + "nan" => { + return Ok(DefaultValue::Path(parse_str::<Path>("::core::f64::NAN")?)); + } + _ => (), + } + } + + // Rust doesn't have a negative literals, so they have to be parsed specially. + if let Some(Ok(lit)) = value.strip_prefix('-').map(syn::parse_str::<Lit>) { + match lit { + Lit::Int(ref lit) if is_i32 && empty_or_is("i32", lit.suffix()) => { + // Initially parse into an i64, so that i32::MIN does not overflow. + let value: i64 = -lit.base10_parse()?; + return Ok(i32::try_from(value).map(DefaultValue::I32)?); + } + Lit::Int(ref lit) if is_i64 && empty_or_is("i64", lit.suffix()) => { + // Initially parse into an i128, so that i64::MIN does not overflow. + let value: i128 = -lit.base10_parse()?; + return Ok(i64::try_from(value).map(DefaultValue::I64)?); + } + Lit::Float(ref lit) + if *ty == Ty::Float && empty_or_is("f32", lit.suffix()) => + { + return Ok(DefaultValue::F32(-lit.base10_parse()?)); + } + Lit::Float(ref lit) + if *ty == Ty::Double && empty_or_is("f64", lit.suffix()) => + { + return Ok(DefaultValue::F64(-lit.base10_parse()?)); + } + Lit::Int(ref lit) if *ty == Ty::Float && lit.suffix().is_empty() => { + return Ok(DefaultValue::F32(-lit.base10_parse()?)); + } + Lit::Int(ref lit) if *ty == Ty::Double && lit.suffix().is_empty() => { + return Ok(DefaultValue::F64(-lit.base10_parse()?)); + } + _ => (), + } + } + match syn::parse_str::<Lit>(&value) { + Ok(Lit::Str(_)) => (), + Ok(lit) => return DefaultValue::from_lit(ty, lit), + _ => (), + } + bail!("invalid default value: {}", quote!(#value)); + } + _ => bail!("invalid default value: {}", quote!(#lit)), + }; + + Ok(default) + } + + pub fn new(ty: &Ty) -> DefaultValue { + match *ty { + Ty::Float => DefaultValue::F32(0.0), + Ty::Double => DefaultValue::F64(0.0), + Ty::Int32 | Ty::Sint32 | Ty::Sfixed32 => DefaultValue::I32(0), + Ty::Int64 | Ty::Sint64 | Ty::Sfixed64 => DefaultValue::I64(0), + Ty::Uint32 | Ty::Fixed32 => DefaultValue::U32(0), + Ty::Uint64 | Ty::Fixed64 => DefaultValue::U64(0), + + Ty::Bool => DefaultValue::Bool(false), + Ty::String => DefaultValue::String(String::new()), + Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()), + Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())), + } + } + + pub fn owned(&self) -> TokenStream { + match *self { + DefaultValue::String(ref value) if value.is_empty() => { + quote!(::prost::alloc::string::String::new()) + } + DefaultValue::String(ref value) => quote!(#value.into()), + DefaultValue::Bytes(ref value) if value.is_empty() => quote!(Default::default()), + DefaultValue::Bytes(ref value) => { + let lit = LitByteStr::new(value, Span::call_site()); + quote!(#lit.as_ref().into()) + } + + ref other => other.typed(), + } + } + + pub fn typed(&self) -> TokenStream { + if let DefaultValue::Enumeration(_) = *self { + quote!(#self as i32) + } else { + quote!(#self) + } + } +} + +impl ToTokens for DefaultValue { + fn to_tokens(&self, tokens: &mut TokenStream) { + match *self { + DefaultValue::F64(value) => value.to_tokens(tokens), + DefaultValue::F32(value) => value.to_tokens(tokens), + DefaultValue::I32(value) => value.to_tokens(tokens), + DefaultValue::I64(value) => value.to_tokens(tokens), + DefaultValue::U32(value) => value.to_tokens(tokens), + DefaultValue::U64(value) => value.to_tokens(tokens), + DefaultValue::Bool(value) => value.to_tokens(tokens), + DefaultValue::String(ref value) => value.to_tokens(tokens), + DefaultValue::Bytes(ref value) => { + let byte_str = LitByteStr::new(value, Span::call_site()); + tokens.append_all(quote!(#byte_str as &[u8])); + } + DefaultValue::Enumeration(ref value) => value.to_tokens(tokens), + DefaultValue::Path(ref value) => value.to_tokens(tokens), + } + } +} |