summaryrefslogtreecommitdiffstats
path: root/src/tools/clippy/clippy_lints/src/size_of_in_element_count.rs
blob: ac4e29e9dfdfa0f5d5c95b93b2b6883284aeb757 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
//! Lint on use of `size_of` or `size_of_val` of `T` in an expression
//! expecting a count of `T`.

use clippy_utils::diagnostics::span_lint_and_help;
use clippy_utils::{match_def_path, paths};
use if_chain::if_chain;
use rustc_hir::BinOpKind;
use rustc_hir::{Expr, ExprKind};
use rustc_lint::{LateContext, LateLintPass};
use rustc_middle::ty::{self, Ty, TypeAndMut};
use rustc_session::{declare_lint_pass, declare_tool_lint};
use rustc_span::sym;

declare_clippy_lint! {
    /// ### What it does
    /// Detects expressions where
    /// `size_of::<T>` or `size_of_val::<T>` is used as a
    /// count of elements of type `T`.
    ///
    /// ### Why is this bad?
    /// The pointer and slice functions checked by this lint (for example
    /// `ptr::copy_nonoverlapping`, `slice::from_raw_parts` or the raw pointer
    /// method `add`) expect a count of `T` and not a number of bytes; the
    /// count already gets multiplied by the size of `T`.
    ///
    /// ### Example
    /// ```rust,no_run
    /// # use std::ptr::copy_nonoverlapping;
    /// # use std::mem::size_of;
    /// const SIZE: usize = 128;
    /// let x = [2u8; SIZE];
    /// let mut y = [2u8; SIZE];
    /// unsafe { copy_nonoverlapping(x.as_ptr(), y.as_mut_ptr(), size_of::<u8>() * SIZE) };
    /// ```
    #[clippy::version = "1.50.0"]
    pub SIZE_OF_IN_ELEMENT_COUNT,
    correctness,
    "using `size_of::<T>` or `size_of_val::<T>` where a count of elements of `T` is expected"
}

// Declares the `SizeOfInElementCount` lint pass and associates it with the
// `SIZE_OF_IN_ELEMENT_COUNT` lint.
declare_lint_pass!(SizeOfInElementCount => [SIZE_OF_IN_ELEMENT_COUNT]);

fn get_size_of_ty<'tcx>(cx: &LateContext<'tcx>, expr: &'tcx Expr<'_>, inverted: bool) -> Option<Ty<'tcx>> {
    match expr.kind {
        ExprKind::Call(count_func, _func_args) => {
            if_chain! {
                if !inverted;
                if let ExprKind::Path(ref count_func_qpath) = count_func.kind;
                if let Some(def_id) = cx.qpath_res(count_func_qpath, count_func.hir_id).opt_def_id();
                if matches!(cx.tcx.get_diagnostic_name(def_id), Some(sym::mem_size_of | sym::mem_size_of_val));
                then {
                    cx.typeck_results().node_substs(count_func.hir_id).types().next()
                } else {
                    None
                }
            }
        },
        ExprKind::Binary(op, left, right) if BinOpKind::Mul == op.node => {
            get_size_of_ty(cx, left, inverted).or_else(|| get_size_of_ty(cx, right, inverted))
        },
        ExprKind::Binary(op, left, right) if BinOpKind::Div == op.node => {
            get_size_of_ty(cx, left, inverted).or_else(|| get_size_of_ty(cx, right, !inverted))
        },
        ExprKind::Cast(expr, _) => get_size_of_ty(cx, expr, inverted),
        _ => None,
    }
}

fn get_pointee_ty_and_count_expr<'tcx>(
    cx: &LateContext<'tcx>,
    expr: &'tcx Expr<'_>,
) -> Option<(Ty<'tcx>, &'tcx Expr<'tcx>)> {
    const FUNCTIONS: [&[&str]; 8] = [
        &paths::PTR_COPY_NONOVERLAPPING,
        &paths::PTR_COPY,
        &paths::PTR_WRITE_BYTES,
        &paths::PTR_SWAP_NONOVERLAPPING,
        &paths::PTR_SLICE_FROM_RAW_PARTS,
        &paths::PTR_SLICE_FROM_RAW_PARTS_MUT,
        &paths::SLICE_FROM_RAW_PARTS,
        &paths::SLICE_FROM_RAW_PARTS_MUT,
    ];
    const METHODS: [&str; 11] = [
        "write_bytes",
        "copy_to",
        "copy_from",
        "copy_to_nonoverlapping",
        "copy_from_nonoverlapping",
        "add",
        "wrapping_add",
        "sub",
        "wrapping_sub",
        "offset",
        "wrapping_offset",
    ];

    if_chain! {
        // Find calls to ptr::{copy, copy_nonoverlapping}
        // and ptr::{swap_nonoverlapping, write_bytes},
        if let ExprKind::Call(func, [.., count]) = expr.kind;
        if let ExprKind::Path(ref func_qpath) = func.kind;
        if let Some(def_id) = cx.qpath_res(func_qpath, func.hir_id).opt_def_id();
        if FUNCTIONS.iter().any(|func_path| match_def_path(cx, def_id, func_path));

        // Get the pointee type
        if let Some(pointee_ty) = cx.typeck_results().node_substs(func.hir_id).types().next();
        then {
            return Some((pointee_ty, count));
        }
    };
    if_chain! {
        // Find calls to copy_{from,to}{,_nonoverlapping} and write_bytes methods
        if let ExprKind::MethodCall(method_path, ptr_self, [.., count], _) = expr.kind;
        let method_ident = method_path.ident.as_str();
        if METHODS.iter().any(|m| *m == method_ident);

        // Get the pointee type
        if let ty::RawPtr(TypeAndMut { ty: pointee_ty, .. }) =
            cx.typeck_results().expr_ty(ptr_self).kind();
        then {
            return Some((*pointee_ty, count));
        }
    };
    None
}

impl<'tcx> LateLintPass<'tcx> for SizeOfInElementCount {
    /// Emits `SIZE_OF_IN_ELEMENT_COUNT` on the count argument of a recognised
    /// call when that argument contains a `size_of`/`size_of_val` of the same
    /// type the call's pointer points to.
    fn check_expr(&mut self, cx: &LateContext<'tcx>, expr: &'tcx Expr<'_>) {
        const LINT_MSG: &str = "found a count of bytes instead of a count of elements of `T`";

        const HELP_MSG: &str = "use a count of elements instead of a count of bytes, \
            it already gets multiplied by the size of the type";

        if_chain! {
            // The expression must be a call with an element count parameter;
            // this yields the pointee type and the count argument.
            if let Some((pointee_ty, count_expr)) = get_pointee_ty_and_count_expr(cx, expr);

            // The count argument must contain a `size_of` call measuring
            // exactly that pointee type.
            if get_size_of_ty(cx, count_expr, false) == Some(pointee_ty);
            then {
                span_lint_and_help(cx, SIZE_OF_IN_ELEMENT_COUNT, count_expr.span, LINT_MSG, None, HELP_MSG);
            }
        }
    }
}