use proc_macro2::TokenStream; use quote::{quote, quote_spanned}; use syn::spanned::Spanned; use crate::attrs::*; use crate::utils::*; type Variants = syn::punctuated::Punctuated; /// Defines and implements `config_type` enum. pub fn define_config_type_on_enum(em: &syn::ItemEnum) -> syn::Result { let syn::ItemEnum { vis, enum_token, ident, generics, variants, .. } = em; let mod_name_str = format!("__define_config_type_on_enum_{}", ident); let mod_name = syn::Ident::new(&mod_name_str, ident.span()); let variants = fold_quote(variants.iter().map(process_variant), |meta| quote!(#meta,)); let impl_doc_hint = impl_doc_hint(&em.ident, &em.variants); let impl_from_str = impl_from_str(&em.ident, &em.variants); let impl_display = impl_display(&em.ident, &em.variants); let impl_serde = impl_serde(&em.ident, &em.variants); let impl_deserialize = impl_deserialize(&em.ident, &em.variants); Ok(quote! { #[allow(non_snake_case)] mod #mod_name { #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub #enum_token #ident #generics { #variants } #impl_display #impl_doc_hint #impl_from_str #impl_serde #impl_deserialize } #vis use #mod_name::#ident; }) } /// Remove attributes specific to `config_proc_macro` from enum variant fields. fn process_variant(variant: &syn::Variant) -> TokenStream { let metas = variant .attrs .iter() .filter(|attr| !is_doc_hint(attr) && !is_config_value(attr) && !is_unstable_variant(attr)); let attrs = fold_quote(metas, |meta| quote!(#meta)); let syn::Variant { ident, fields, .. } = variant; quote!(#attrs #ident #fields) } /// Return the correct syntax to pattern match on the enum variant, discarding all /// internal field data. fn fields_in_variant(variant: &syn::Variant) -> TokenStream { // With thanks to https://stackoverflow.com/a/65182902 match &variant.fields { syn::Fields::Unnamed(_) => quote_spanned! { variant.span() => (..) }, syn::Fields::Unit => quote_spanned! { variant.span() => }, syn::Fields::Named(_) => quote_spanned! { variant.span() => {..} }, } } fn impl_doc_hint(ident: &syn::Ident, variants: &Variants) -> TokenStream { let doc_hint = variants .iter() .map(doc_hint_of_variant) .collect::>() .join("|"); let doc_hint = format!("[{}]", doc_hint); let variant_stables = variants .iter() .map(|v| (&v.ident, fields_in_variant(&v), !unstable_of_variant(v))); let match_patterns = fold_quote(variant_stables, |(v, fields, stable)| { quote! { #ident::#v #fields => #stable, } }); quote! { use crate::config::ConfigType; impl ConfigType for #ident { fn doc_hint() -> String { #doc_hint.to_owned() } fn stable_variant(&self) -> bool { match self { #match_patterns } } } } } fn impl_display(ident: &syn::Ident, variants: &Variants) -> TokenStream { let vs = variants .iter() .filter(|v| is_unit(v)) .map(|v| (config_value_of_variant(v), &v.ident)); let match_patterns = fold_quote(vs, |(s, v)| { quote! { #ident::#v => write!(f, "{}", #s), } }); quote! { use std::fmt; impl fmt::Display for #ident { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { #match_patterns _ => unimplemented!(), } } } } } fn impl_from_str(ident: &syn::Ident, variants: &Variants) -> TokenStream { let vs = variants .iter() .filter(|v| is_unit(v)) .map(|v| (config_value_of_variant(v), &v.ident)); let if_patterns = fold_quote(vs, |(s, v)| { quote! { if #s.eq_ignore_ascii_case(s) { return Ok(#ident::#v); } } }); let mut err_msg = String::from("Bad variant, expected one of:"); for v in variants.iter().filter(|v| is_unit(v)) { err_msg.push_str(&format!(" `{}`", v.ident)); } quote! { impl ::std::str::FromStr for #ident { type Err = &'static str; fn from_str(s: &str) -> Result { #if_patterns return Err(#err_msg); } } } } fn doc_hint_of_variant(variant: &syn::Variant) -> String { let mut text = find_doc_hint(&variant.attrs).unwrap_or(variant.ident.to_string()); if unstable_of_variant(&variant) { text.push_str(" (unstable)") }; text } fn config_value_of_variant(variant: &syn::Variant) -> String { find_config_value(&variant.attrs).unwrap_or(variant.ident.to_string()) } fn unstable_of_variant(variant: &syn::Variant) -> bool { any_unstable_variant(&variant.attrs) } fn impl_serde(ident: &syn::Ident, variants: &Variants) -> TokenStream { let arms = fold_quote(variants.iter(), |v| { let v_ident = &v.ident; let pattern = match v.fields { syn::Fields::Named(..) => quote!(#ident::v_ident{..}), syn::Fields::Unnamed(..) => quote!(#ident::#v_ident(..)), syn::Fields::Unit => quote!(#ident::#v_ident), }; let option_value = config_value_of_variant(v); quote! { #pattern => serializer.serialize_str(&#option_value), } }); quote! { impl ::serde::ser::Serialize for #ident { fn serialize(&self, serializer: S) -> Result where S: ::serde::ser::Serializer, { use serde::ser::Error; match self { #arms _ => Err(S::Error::custom(format!("Cannot serialize {:?}", self))), } } } } } // Currently only unit variants are supported. fn impl_deserialize(ident: &syn::Ident, variants: &Variants) -> TokenStream { let supported_vs = variants.iter().filter(|v| is_unit(v)); let if_patterns = fold_quote(supported_vs, |v| { let config_value = config_value_of_variant(v); let variant_ident = &v.ident; quote! { if #config_value.eq_ignore_ascii_case(s) { return Ok(#ident::#variant_ident); } } }); let supported_vs = variants.iter().filter(|v| is_unit(v)); let allowed = fold_quote(supported_vs.map(config_value_of_variant), |s| quote!(#s,)); quote! { impl<'de> serde::de::Deserialize<'de> for #ident { fn deserialize(d: D) -> Result where D: serde::Deserializer<'de>, { use serde::de::{Error, Visitor}; use std::marker::PhantomData; use std::fmt; struct StringOnly(PhantomData); impl<'de, T> Visitor<'de> for StringOnly where T: serde::Deserializer<'de> { type Value = String; fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { formatter.write_str("string") } fn visit_str(self, value: &str) -> Result { Ok(String::from(value)) } } let s = &d.deserialize_string(StringOnly::(PhantomData))?; #if_patterns static ALLOWED: &'static[&str] = &[#allowed]; Err(D::Error::unknown_variant(&s, ALLOWED)) } } } }