summaryrefslogtreecommitdiffstats
path: root/third_party/rust/derive_more-impl/src/mul_like.rs
blob: d06372f38acae571411f81ee91c6ef5a6b20a530 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
use crate::add_like;
use crate::mul_helpers::generics_and_exprs;
use crate::utils::{AttrParams, HashSet, MultiFieldData, RefType, State};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use std::iter;
use syn::{DeriveInput, Result};

/// Expands the derive named `trait_name` for `input` into an impl of the form
/// `impl<...> Trait<__RhsT> for Type<...> { type Output = Type<...>; fn method(self, rhs: __RhsT) -> ... }`.
///
/// The right-hand side is a fresh generic parameter `__RhsT`. Every distinct
/// enabled field type `T` is bounded by `T: Trait<__RhsT, Output = T>`, and the
/// method body is built from `generics_and_exprs` plus `MultiFieldData::initializer`.
///
/// If the `forward` attribute parameter is set (`state.default_info.forward`),
/// expansion is delegated entirely to `add_like::expand`.
///
/// # Errors
///
/// Propagates any error returned by `State::with_attr_params` (attribute parsing).
pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
    let lowercase_trait_name = trait_name.to_lowercase();
    let mut state = State::with_attr_params(
        input,
        trait_name,
        lowercase_trait_name,
        AttrParams::struct_(vec!["forward"]),
    )?;

    // `forward` means "behave like the add-like derive"; nothing below applies then.
    if state.default_info.forward {
        return Ok(add_like::expand(input, trait_name));
    }

    // Register the generic RHS type on the trait path *before* extracting the
    // field data, so `trait_path_with_params` below already carries `<__RhsT>`.
    let rhs_ty = format_ident!("__RhsT");
    state.add_trait_path_type_param(quote!(#rhs_ty));

    let fields = state.enabled_fields_data();
    let MultiFieldData {
        input_type,
        field_types,
        ty_generics,
        trait_path,
        trait_path_with_params,
        method_ident,
        ..
    } = fields.clone();

    // One bound per *distinct* field type, so repeated types don't produce
    // duplicate where-predicates.
    let distinct_tys: HashSet<_> = field_types.iter().collect();
    let bounds = distinct_tys
        .iter()
        .zip(iter::repeat(trait_path))
        .map(|(ty, path)| quote!(#ty: #path<#rhs_ty, Output=#ty>));
    let field_bounds = quote!(where #(#bounds),*);

    let (generics, initializers) =
        generics_and_exprs(fields.clone(), &rhs_ty, field_bounds, RefType::No);
    let body = fields.initializer(&initializers);
    let (impl_generics, _, where_clause) = generics.split_for_impl();

    // The implementing type doubles as the associated `Output` and the method's
    // return type.
    let self_ty = quote!(#input_type #ty_generics);

    Ok(quote! {
        #[automatically_derived]
        impl #impl_generics #trait_path_with_params for #self_ty #where_clause {
            type Output = #self_ty;

            #[inline]
            fn #method_ident(self, rhs: #rhs_ty) -> #self_ty {
                #body
            }
        }
    })
}