summaryrefslogtreecommitdiffstats
path: root/third_party/rust/derive_more-impl/src/unwrap.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/derive_more-impl/src/unwrap.rs')
-rw-r--r--third_party/rust/derive_more-impl/src/unwrap.rs169
1 files changed, 169 insertions, 0 deletions
diff --git a/third_party/rust/derive_more-impl/src/unwrap.rs b/third_party/rust/derive_more-impl/src/unwrap.rs
new file mode 100644
index 0000000000..02a9010c4c
--- /dev/null
+++ b/third_party/rust/derive_more-impl/src/unwrap.rs
@@ -0,0 +1,169 @@
+use crate::utils::{AttrParams, DeriveType, State};
+use convert_case::{Case, Casing};
+use proc_macro2::TokenStream;
+use quote::{format_ident, quote};
+use syn::{DeriveInput, Fields, Ident, Result, Type};
+
+pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
+ let state = State::with_attr_params(
+ input,
+ trait_name,
+ "unwrap".into(),
+ AttrParams {
+ enum_: vec!["ignore", "owned", "ref", "ref_mut"],
+ variant: vec!["ignore", "owned", "ref", "ref_mut"],
+ struct_: vec!["ignore"],
+ field: vec!["ignore"],
+ },
+ )?;
+ assert!(
+ state.derive_type == DeriveType::Enum,
+ "Unwrap can only be derived for enums",
+ );
+
+ let enum_name = &input.ident;
+ let (imp_generics, type_generics, where_clause) = input.generics.split_for_impl();
+
+ let variant_data = state.enabled_variant_data();
+
+ let mut funcs = vec![];
+ for (variant_state, info) in
+ Iterator::zip(variant_data.variant_states.iter(), variant_data.infos)
+ {
+ let variant = variant_state.variant.unwrap();
+ let fn_name = format_ident!(
+ "unwrap_{ident}",
+ ident = variant.ident.to_string().to_case(Case::Snake),
+ span = variant.ident.span(),
+ );
+ let ref_fn_name = format_ident!(
+ "unwrap_{ident}_ref",
+ ident = variant.ident.to_string().to_case(Case::Snake),
+ span = variant.ident.span(),
+ );
+ let mut_fn_name = format_ident!(
+ "unwrap_{ident}_mut",
+ ident = variant.ident.to_string().to_case(Case::Snake),
+ span = variant.ident.span(),
+ );
+ let variant_ident = &variant.ident;
+ let (data_pattern, ret_value, data_types) = get_field_info(&variant.fields);
+ let pattern = quote! { #enum_name :: #variant_ident #data_pattern };
+
+ let (failed_block, failed_block_ref, failed_block_mut) = (
+ failed_block(&state, enum_name, &fn_name),
+ failed_block(&state, enum_name, &ref_fn_name),
+ failed_block(&state, enum_name, &mut_fn_name),
+ );
+
+ let doc_owned = format!(
+ "Unwraps this value to the `{enum_name}::{variant_ident}` variant.\n",
+ );
+ let doc_ref = format!(
+ "Unwraps this reference to the `{enum_name}::{variant_ident}` variant.\n",
+ );
+ let doc_mut = format!(
+ "Unwraps this mutable reference to the `{enum_name}::{variant_ident}` variant.\n",
+ );
+ let doc_else = "Panics if this value is of any other type.";
+
+ let func = quote! {
+ #[inline]
+ #[track_caller]
+ #[doc = #doc_owned]
+ #[doc = #doc_else]
+ pub fn #fn_name(self) -> (#(#data_types),*) {
+ match self {
+ #pattern => #ret_value,
+ val @ _ => #failed_block,
+ }
+ }
+ };
+
+ let ref_func = quote! {
+ #[inline]
+ #[track_caller]
+ #[doc = #doc_ref]
+ #[doc = #doc_else]
+ pub fn #ref_fn_name(&self) -> (#(&#data_types),*) {
+ match self {
+ #pattern => #ret_value,
+ val @ _ => #failed_block_ref,
+ }
+ }
+ };
+
+ let mut_func = quote! {
+ #[inline]
+ #[track_caller]
+ #[doc = #doc_mut]
+ #[doc = #doc_else]
+ pub fn #mut_fn_name(&mut self) -> (#(&mut #data_types),*) {
+ match self {
+ #pattern => #ret_value,
+ val @ _ => #failed_block_mut,
+ }
+ }
+ };
+
+ if info.owned && state.default_info.owned {
+ funcs.push(func);
+ }
+ if info.ref_ && state.default_info.ref_ {
+ funcs.push(ref_func);
+ }
+ if info.ref_mut && state.default_info.ref_mut {
+ funcs.push(mut_func);
+ }
+ }
+
+ let imp = quote! {
+ #[automatically_derived]
+ impl #imp_generics #enum_name #type_generics #where_clause {
+ #(#funcs)*
+ }
+ };
+
+ Ok(imp)
+}
+
+fn get_field_info(fields: &Fields) -> (TokenStream, TokenStream, Vec<&Type>) {
+ match fields {
+ Fields::Named(_) => panic!("cannot unwrap anonymous records"),
+ Fields::Unnamed(ref fields) => {
+ let (idents, types) = fields
+ .unnamed
+ .iter()
+ .enumerate()
+ .map(|(n, it)| (format_ident!("field_{n}"), &it.ty))
+ .unzip::<_, _, Vec<_>, Vec<_>>();
+ (quote! { (#(#idents),*) }, quote! { (#(#idents),*) }, types)
+ }
+ Fields::Unit => (quote! {}, quote! { () }, vec![]),
+ }
+}
+
+fn failed_block(state: &State, enum_name: &Ident, fn_name: &Ident) -> TokenStream {
+ let arms = state
+ .variant_states
+ .iter()
+ .map(|it| it.variant.unwrap())
+ .map(|variant| {
+ let data_pattern = match variant.fields {
+ Fields::Named(_) => quote! { {..} },
+ Fields::Unnamed(_) => quote! { (..) },
+ Fields::Unit => quote! {},
+ };
+ let variant_ident = &variant.ident;
+ let panic_msg = format!(
+ "called `{enum_name}::{fn_name}()` on a `{enum_name}::{variant_ident}` value"
+ );
+ quote! { #enum_name :: #variant_ident #data_pattern => panic!(#panic_msg) }
+ });
+
+ quote! {
+ match val {
+ #(#arms),*
+ }
+ }
+}