summaryrefslogtreecommitdiffstats
path: root/src/tools/clippy/clippy_lints/src/loops/manual_find.rs
blob: 09b2376d5c04a6d68ad26d1d3abaaa8002f7d4db (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
use super::utils::make_iterator_snippet;
use super::MANUAL_FIND;
use clippy_utils::{
    diagnostics::span_lint_and_then, higher, is_lang_ctor, path_res, peel_blocks_with_stmt,
    source::snippet_with_applicability, ty::implements_trait,
};
use if_chain::if_chain;
use rustc_errors::Applicability;
use rustc_hir::{
    def::Res, lang_items::LangItem, BindingAnnotation, Block, Expr, ExprKind, HirId, Node, Pat, PatKind, Stmt, StmtKind,
};
use rustc_lint::LateContext;
use rustc_span::source_map::Span;

/// Lints a `for` loop that manually reimplements `Iterator::find`:
///
/// ```ignore
/// for pat in arg {
///     if cond {
///         return Some(binding);
///     }
/// }
/// None
/// ```
///
/// The loop must be the last statement of a function body that then yields `None` (checked by
/// `last_stmt_and_ret`). The suggestion is an iterator chain ending in `.find(..)`.
///
/// * `pat`, `arg`, `body` - the loop pattern, the iterated expression and the loop body.
/// * `span` - joined with the trailing statement / `None` spans to form the linted span.
/// * `expr` - the loop expression, whose parents are inspected to find the function body.
pub(super) fn check<'tcx>(
    cx: &LateContext<'tcx>,
    pat: &'tcx Pat<'_>,
    arg: &'tcx Expr<'_>,
    body: &'tcx Expr<'_>,
    span: Span,
    expr: &'tcx Expr<'_>,
) {
    let inner_expr = peel_blocks_with_stmt(body);
    // Check for the specific case that the result is returned and optimize suggestion for that (more
    // cases can be added later)
    if_chain! {
        if let Some(higher::If { cond, then, r#else: None, }) = higher::If::hir(inner_expr);
        if let Some(binding_id) = get_binding(pat);
        if let ExprKind::Block(block, _) = then.kind;
        if let [stmt] = block.stmts;
        if let StmtKind::Semi(semi) = stmt.kind;
        if let ExprKind::Ret(Some(ret_value)) = semi.kind;
        if let ExprKind::Call(Expr { kind: ExprKind::Path(ctor), .. }, [inner_ret]) = ret_value.kind;
        if is_lang_ctor(cx, ctor, LangItem::OptionSome);
        if path_res(cx, inner_ret) == Res::Local(binding_id);
        if let Some((last_stmt, last_ret)) = last_stmt_and_ret(cx, expr);
        then {
            let mut applicability = Applicability::MachineApplicable;
            let mut snippet = make_iterator_snippet(cx, arg, &mut applicability);
            // Checks if `pat` is a single *shared* reference to a binding (`&x`). `&mut x` is
            // excluded on purpose: the `&&x` closure pattern and the `.copied()` emitted below only
            // type-check for shared references, so `&mut x` is handled by the `.map(..)` path instead.
            let is_ref_to_binding = matches!(
                pat.kind,
                PatKind::Ref(inner, rustc_hir::Mutability::Not) if matches!(inner.kind, PatKind::Binding(..))
            );
            // If `pat` is not a binding or a reference to a binding (`x` or `&x`)
            // we need to map it to the binding returned by the function (i.e. `.map(|(x, _)| x)`)
            if !(matches!(pat.kind, PatKind::Binding(..)) || is_ref_to_binding) {
                snippet.push_str(&format!(
                    ".map(|{}| {})",
                    snippet_with_applicability(cx, pat.span, "..", &mut applicability),
                    snippet_with_applicability(cx, inner_ret.span, "..", &mut applicability),
                ));
            }
            let ty = cx.typeck_results().expr_ty(inner_ret);
            let is_copy = cx.tcx.lang_items().copy_trait().map_or(false, |id| implements_trait(cx, ty, id, &[]));
            if is_copy {
                // `find` hands the closure a reference to each item; peel it (plus the extra `&`
                // of an `&x` pattern) in the closure pattern so `cond` sees the plain binding.
                snippet.push_str(&format!(
                    ".find(|{}{}| {})",
                    "&".repeat(1 + usize::from(is_ref_to_binding)),
                    snippet_with_applicability(cx, inner_ret.span, "..", &mut applicability),
                    snippet_with_applicability(cx, cond.span, "..", &mut applicability),
                ));
                if is_ref_to_binding {
                    snippet.push_str(".copied()");
                }
            } else {
                // The item cannot be copied out of the reference `find` passes, so `cond` now sees a
                // reference and may need derefs we cannot insert. Only ever downgrade: a lower level
                // already chosen (e.g. `HasPlaceholders` after a failed snippet) must not be raised.
                if applicability == Applicability::MachineApplicable {
                    applicability = Applicability::MaybeIncorrect;
                }
                snippet.push_str(&format!(
                    ".find(|{}| {})",
                    snippet_with_applicability(cx, inner_ret.span, "..", &mut applicability),
                    snippet_with_applicability(cx, cond.span, "..", &mut applicability),
                ));
            }
            // Extends to `last_stmt` to include semicolon in case of `return None;`
            let lint_span = span.to(last_stmt.span).to(last_ret.span);
            span_lint_and_then(
                cx,
                MANUAL_FIND,
                lint_span,
                "manual implementation of `Iterator::find`",
                |diag| {
                    // Tied to the non-Copy rewrite itself, not to the applicability level, which
                    // can also drop for unrelated reasons (macro spans, missing snippets).
                    if !is_copy {
                        diag.note("you may need to dereference some variables");
                    }
                    diag.span_suggestion(
                        lint_span,
                        "replace with an iterator",
                        snippet,
                        applicability,
                    );
                },
            );
        }
    }
}

/// Returns the `HirId` of the only binding introduced by `pat`, provided that binding has the
/// plain `BindingAnnotation::NONE` annotation.
///
/// Yields `None` when `pat` binds nothing, binds more than one name, or its single binding
/// carries any other annotation.
fn get_binding(pat: &Pat<'_>) -> Option<HirId> {
    let mut seen = 0usize;
    let mut found: Option<HirId> = None;
    pat.each_binding(|annotation, id, _, _| {
        seen += 1;
        // Only the first binding can be accepted; any later binding invalidates the result.
        found = if seen == 1 && matches!(annotation, BindingAnnotation::NONE) {
            Some(id)
        } else {
            None
        };
    });
    found
}

// Returns the last statement and last return if function fits format for lint
fn last_stmt_and_ret<'tcx>(
    cx: &LateContext<'tcx>,
    expr: &'tcx Expr<'_>,
) -> Option<(&'tcx Stmt<'tcx>, &'tcx Expr<'tcx>)> {
    // Returns last non-return statement and the last return
    fn extract<'tcx>(block: &Block<'tcx>) -> Option<(&'tcx Stmt<'tcx>, &'tcx Expr<'tcx>)> {
        if let [.., last_stmt] = block.stmts {
            if let Some(ret) = block.expr {
                return Some((last_stmt, ret));
            }
            if_chain! {
                if let [.., snd_last, _] = block.stmts;
                if let StmtKind::Semi(last_expr) = last_stmt.kind;
                if let ExprKind::Ret(Some(ret)) = last_expr.kind;
                then {
                    return Some((snd_last, ret));
                }
            }
        }
        None
    }
    let mut parent_iter = cx.tcx.hir().parent_iter(expr.hir_id);
    if_chain! {
        // This should be the loop
        if let Some((node_hir, Node::Stmt(..))) = parent_iter.next();
        // This should be the function body
        if let Some((_, Node::Block(block))) = parent_iter.next();
        if let Some((last_stmt, last_ret)) = extract(block);
        if last_stmt.hir_id == node_hir;
        if let ExprKind::Path(path) = &last_ret.kind;
        if is_lang_ctor(cx, path, LangItem::OptionNone);
        if let Some((_, Node::Expr(_block))) = parent_iter.next();
        // This includes the function header
        if let Some((_, func)) = parent_iter.next();
        if func.fn_kind().is_some();
        then {
            Some((block.stmts.last().unwrap(), last_ret))
        } else {
            None
        }
    }
}