diff options
Diffstat (limited to 'third_party/rust/serde_repr/src')
-rw-r--r-- | third_party/rust/serde_repr/src/lib.rs | 140 | ||||
-rw-r--r-- | third_party/rust/serde_repr/src/parse.rs | 114 |
2 files changed, 254 insertions, 0 deletions
diff --git a/third_party/rust/serde_repr/src/lib.rs b/third_party/rust/serde_repr/src/lib.rs new file mode 100644 index 0000000000..2ad67df065 --- /dev/null +++ b/third_party/rust/serde_repr/src/lib.rs @@ -0,0 +1,140 @@ +//! [![github]](https://github.com/dtolnay/serde-repr) [![crates-io]](https://crates.io/crates/serde_repr) [![docs-rs]](https://docs.rs/serde_repr) +//! +//! [github]: https://img.shields.io/badge/github-8da0cb?style=for-the-badge&labelColor=555555&logo=github +//! [crates-io]: https://img.shields.io/badge/crates.io-fc8d62?style=for-the-badge&labelColor=555555&logo=rust +//! [docs-rs]: https://img.shields.io/badge/docs.rs-66c2a5?style=for-the-badge&labelColor=555555&logo=docs.rs +//! +//! <br> +//! +//! Derive `Serialize` and `Deserialize` that delegates to the underlying repr +//! of a C-like enum. +//! +//! # Examples +//! +//! ``` +//! use serde_repr::{Serialize_repr, Deserialize_repr}; +//! +//! #[derive(Serialize_repr, Deserialize_repr, PartialEq, Debug)] +//! #[repr(u8)] +//! enum SmallPrime { +//! Two = 2, +//! Three = 3, +//! Five = 5, +//! Seven = 7, +//! } +//! +//! fn main() -> serde_json::Result<()> { +//! let j = serde_json::to_string(&SmallPrime::Seven)?; +//! assert_eq!(j, "7"); +//! +//! let p: SmallPrime = serde_json::from_str("2")?; +//! assert_eq!(p, SmallPrime::Two); +//! +//! Ok(()) +//! } +//! ``` + +#![allow(clippy::single_match_else)] + +extern crate proc_macro; + +mod parse; + +use proc_macro::TokenStream; +use quote::quote; +use syn::parse_macro_input; + +use crate::parse::Input; + +#[proc_macro_derive(Serialize_repr)] +pub fn derive_serialize(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as Input); + let ident = input.ident; + let repr = input.repr; + + let match_variants = input.variants.iter().map(|variant| { + let variant = &variant.ident; + quote! { + #ident::#variant => #ident::#variant as #repr, + } + }); + + TokenStream::from(quote! { + impl serde::Serialize for #ident { + #[allow(clippy::use_self)] + fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error> + where + S: serde::Serializer + { + let value: #repr = match *self { + #(#match_variants)* + }; + serde::Serialize::serialize(&value, serializer) + } + } + }) +} + +#[proc_macro_derive(Deserialize_repr, attributes(serde))] +pub fn derive_deserialize(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as Input); + let ident = input.ident; + let repr = input.repr; + let variants = input.variants.iter().map(|variant| &variant.ident); + + let declare_discriminants = input.variants.iter().map(|variant| { + let variant = &variant.ident; + quote! { + const #variant: #repr = #ident::#variant as #repr; + } + }); + + let match_discriminants = input.variants.iter().map(|variant| { + let variant = &variant.ident; + quote! { + discriminant::#variant => core::result::Result::Ok(#ident::#variant), + } + }); + + let error_format = match input.variants.len() { + 1 => "invalid value: {}, expected {}".to_owned(), + 2 => "invalid value: {}, expected {} or {}".to_owned(), + n => "invalid value: {}, expected one of: {}".to_owned() + &", {}".repeat(n - 1), + }; + + let other_arm = match input.default_variant { + Some(variant) => { + let variant = &variant.ident; + quote! { + core::result::Result::Ok(#ident::#variant) + } + } + None => quote! { + core::result::Result::Err(serde::de::Error::custom( + format_args!(#error_format, other #(, discriminant::#variants)*) + )) + }, + }; + + TokenStream::from(quote! { + impl<'de> serde::Deserialize<'de> for #ident { + #[allow(clippy::use_self)] + fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + struct discriminant; + + #[allow(non_upper_case_globals)] + impl discriminant { + #(#declare_discriminants)* + } + + match <#repr as serde::Deserialize>::deserialize(deserializer)? { + #(#match_discriminants)* + other => #other_arm, + } + } + } + }) +} diff --git a/third_party/rust/serde_repr/src/parse.rs b/third_party/rust/serde_repr/src/parse.rs new file mode 100644 index 0000000000..2c0ad1e7dc --- /dev/null +++ b/third_party/rust/serde_repr/src/parse.rs @@ -0,0 +1,114 @@ +use proc_macro2::Span; +use syn::parse::{Error, Parse, ParseStream, Parser, Result}; +use syn::{parenthesized, Data, DeriveInput, Fields, Ident, Meta, NestedMeta}; + +pub struct Input { + pub ident: Ident, + pub repr: Ident, + pub variants: Vec<Variant>, + pub default_variant: Option<Variant>, +} + +#[derive(Clone)] +pub struct Variant { + pub ident: Ident, + pub attrs: VariantAttrs, +} + +#[derive(Clone)] +pub struct VariantAttrs { + pub is_default: bool, +} + +fn parse_meta(attrs: &mut VariantAttrs, meta: &Meta) { + if let Meta::List(value) = meta { + for meta in &value.nested { + if let NestedMeta::Meta(Meta::Path(path)) = meta { + if path.is_ident("other") { + attrs.is_default = true; + } + } + } + } +} + +fn parse_attrs(variant: &syn::Variant) -> Result<VariantAttrs> { + let mut attrs = VariantAttrs { is_default: false }; + for attr in &variant.attrs { + if attr.path.is_ident("serde") { + parse_meta(&mut attrs, &attr.parse_meta()?); + } + } + Ok(attrs) +} + +impl Parse for Input { + fn parse(input: ParseStream) -> Result<Self> { + let call_site = Span::call_site(); + let derive_input = DeriveInput::parse(input)?; + + let data = match derive_input.data { + Data::Enum(data) => data, + _ => { + return Err(Error::new(call_site, "input must be an enum")); + } + }; + + let variants = data + .variants + .into_iter() + .map(|variant| match variant.fields { + Fields::Unit => { + let attrs = parse_attrs(&variant)?; + Ok(Variant { + ident: variant.ident, + attrs, + }) + } + Fields::Named(_) | Fields::Unnamed(_) => { + Err(Error::new(variant.ident.span(), "must be a unit variant")) + } + }) + .collect::<Result<Vec<Variant>>>()?; + + if variants.is_empty() { + return Err(Error::new(call_site, "there must be at least one variant")); + } + + let generics = derive_input.generics; + if !generics.params.is_empty() || generics.where_clause.is_some() { + return Err(Error::new(call_site, "generic enum is not supported")); + } + + let mut repr = None; + for attr in derive_input.attrs { + if attr.path.is_ident("repr") { + fn repr_arg(input: ParseStream) -> Result<Ident> { + let content; + parenthesized!(content in input); + content.parse() + } + let ty = repr_arg.parse2(attr.tokens)?; + repr = Some(ty); + break; + } + } + let repr = repr.ok_or_else(|| Error::new(call_site, "missing #[repr(...)] attribute"))?; + + let mut default_variants = variants.iter().filter(|x| x.attrs.is_default); + let default_variant = default_variants.next().cloned(); + if default_variants.next().is_some() { + return Err(Error::new( + call_site, + "only one variant can be #[serde(other)]", + )); + } + + Ok(Input { + ident: derive_input.ident, + repr, + variants, + default_variant, + }) + } +} |