summaryrefslogtreecommitdiffstats
path: root/vendor/yoke-derive/src/visitor.rs
blob: daca1da13fdef636b176f8f9f9999fded17ebe5d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// This file is part of ICU4X. For terms of use, please see the file
// called LICENSE at the top level of the ICU4X source tree
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).

//! Visitor for determining whether a type mentions in-scope type parameters and/or non-`'static` lifetimes

use std::collections::HashSet;
use syn::visit::{visit_lifetime, visit_type, visit_type_path, Visit};
use syn::{Ident, Lifetime, Type, TypePath};

/// Syntax-tree walker that records whether a type refers to generic parameters.
///
/// Both flags start out `false` and are only ever flipped to `true` while walking.
struct TypeVisitor<'env> {
    /// Set once a lifetime other than `'static` has been encountered
    found_lifetimes: bool,
    /// Set once a path naming one of `typarams` has been encountered
    found_typarams: bool,
    /// Names of the type parameters in scope for the type being walked
    typarams: &'env HashSet<Ident>,
}

impl<'env, 'ast> Visit<'ast> for TypeVisitor<'env> {
    /// Flags any lifetime other than `'static`, then continues the default walk.
    fn visit_lifetime(&mut self, lifetime: &'ast Lifetime) {
        let is_static = lifetime.ident == "static";
        if !is_static {
            self.found_lifetimes = true;
        }
        visit_lifetime(self, lifetime)
    }

    /// Flags a path that is exactly the name of an in-scope type parameter,
    /// then continues the default walk.
    fn visit_type_path(&mut self, path_ty: &'ast TypePath) {
        // `get_ident` only yields an identifier for a bare single-segment path
        // such as `T`, so multi-segment paths (e.g. `T::Assoc`) are not matched
        // by their leading segment here. The qself type and any generic
        // arguments need no special handling: the recursive walk below visits
        // those nested types and re-enters this method for each of them.
        let names_typaram =
            matches!(path_ty.path.get_ident(), Some(ident) if self.typarams.contains(ident));
        if names_typaram {
            self.found_typarams = true;
        }
        visit_type_path(self, path_ty)
    }
}

/// Inspects `ty` for references to generic parameters.
///
/// `typarams` holds the names of the type parameters in scope where `ty` appears.
///
/// Returns `(has_type_params, has_lifetime_params)`:
/// - `has_type_params` is `true` when some path inside `ty` is exactly one of
///   the names in `typarams`;
/// - `has_lifetime_params` is `true` when `ty` mentions any lifetime other
///   than `'static`.
pub fn check_type_for_parameters(ty: &Type, typarams: &HashSet<Ident>) -> (bool, bool) {
    let mut visitor = TypeVisitor {
        found_lifetimes: false,
        found_typarams: false,
        typarams,
    };
    visitor.visit_type(ty);

    let TypeVisitor {
        found_typarams,
        found_lifetimes,
        ..
    } = visitor;
    (found_typarams, found_lifetimes)
}

#[cfg(test)]
mod tests {
    use proc_macro2::Span;
    use std::collections::HashSet;
    use syn::{parse_quote, Ident, Type};

    use super::check_type_for_parameters;

    /// Builds the set of in-scope type-parameter names from plain strings.
    fn make_typarams(params: &[&str]) -> HashSet<Ident> {
        let mut names = HashSet::with_capacity(params.len());
        for &param in params {
            names.insert(Ident::new(param, Span::call_site()));
        }
        names
    }

    /// Asserts that each `(type, expected)` case yields `expected`, which is
    /// `(has_type_params, has_lifetime_params)`.
    fn check_all(environment: &HashSet<Ident>, cases: &[(Type, (bool, bool))]) {
        for (index, (ty, expected)) in cases.iter().enumerate() {
            let actual = check_type_for_parameters(ty, environment);
            assert_eq!(actual, *expected, "case #{} failed", index);
        }
    }

    #[test]
    fn test_simple_type() {
        let environment = make_typarams(&["T", "U", "V"]);
        let cases: [(Type, (bool, bool)); 6] = [
            // A type parameter together with a named lifetime
            (parse_quote!(Foo<'a, T>), (true, true)),
            // A type parameter on its own
            (parse_quote!(Foo<T>), (true, false)),
            // `'static` is not treated as a lifetime parameter
            (parse_quote!(Foo<'static, T>), (true, false)),
            // A lifetime on its own
            (parse_quote!(Foo<'a>), (false, true)),
            // Parameters nested inside generic arguments and tuples
            (parse_quote!(Foo<'a, Bar<U>, Baz<(V, u8)>>), (true, true)),
            // `W` is not in the environment, so it is not a type parameter
            (parse_quote!(Foo<'a, W>), (false, true)),
        ];
        check_all(&environment, &cases);
    }

    #[test]
    fn test_assoc_types() {
        let environment = make_typarams(&["T"]);
        let cases: [(Type, (bool, bool)); 3] = [
            // Parameters inside the trait's generic arguments
            (parse_quote!(<Foo as SomeTrait<'a, T>>::Output), (true, true)),
            (parse_quote!(<Foo as SomeTrait<'static, T>>::Output), (true, false)),
            // A parameter in the qualified-self position
            (parse_quote!(<T as SomeTrait<'static, Foo>>::Output), (true, false)),
        ];
        check_all(&environment, &cases);
    }
}