summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_mir_transform/src/normalize_array_len.rs
blob: a159e61717823e5847d01641fc9f9762bae81139 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
//! This pass eliminates the cast of an array into a slice when the slice is only used to
//! take its length via the `.len()` method. This keeps the array length visible in MIR,
//! which is handy for const prop.

use crate::MirPass;
use rustc_data_structures::fx::FxIndexMap;
use rustc_data_structures::intern::Interned;
use rustc_index::bit_set::BitSet;
use rustc_index::vec::IndexVec;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, ReErased, Region, TyCtxt};

/// Bodies with more basic blocks than this are skipped by `NormalizeArrayLen::run_pass`.
const MAX_NUM_BLOCKS: usize = 800;
/// Bodies with more local declarations than this are skipped by `NormalizeArrayLen::run_pass`.
const MAX_NUM_LOCALS: usize = 3000;

/// Marker type for the array-length normalization pass; the work is done by its
/// `MirPass` implementation, which delegates to `normalize_array_len_calls`.
pub struct NormalizeArrayLen;

impl<'tcx> MirPass<'tcx> for NormalizeArrayLen {
    /// The pass only runs at MIR optimization level 4 or higher.
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        sess.mir_opt_level() >= 4
    }

    /// Runs the normalization on `body`, unless the body exceeds the block or local
    /// limits (an edge case of highly unrolled functions), in which case it is left as is.
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        let too_many_blocks = body.basic_blocks.len() > MAX_NUM_BLOCKS;
        let too_many_locals = body.local_decls.len() > MAX_NUM_LOCALS;
        if too_many_blocks || too_many_locals {
            return;
        }
        normalize_array_len_calls(tcx, body)
    }
}

/// Applies the array-length normalization to every basic block of `body`.
///
/// Locals typed `[T; N]` or `&[T; N]` are collected first; if there are none the body is
/// left untouched. Otherwise each block is processed independently by
/// `normalize_array_len_call`, with the scratch maps emptied between blocks.
pub fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
    // Terminators are never modified, so the CFG cache does not need to be invalidated.
    let basic_blocks = body.basic_blocks.as_mut_preserves_cfg();
    let local_decls = &mut body.local_decls;

    // Preliminary analysis: do we have any locals of type `[T; N]` or `&[T; N]` at all?
    let mut interesting_locals = BitSet::new_empty(local_decls.len());
    for (local, decl) in local_decls.iter_enumerated() {
        let is_array_like = match decl.ty.kind() {
            ty::Array(..) => true,
            ty::Ref(.., pointee, Mutability::Not) => matches!(pointee.kind(), ty::Array(..)),
            _ => false,
        };
        if is_array_like {
            interesting_locals.insert(local);
        }
    }
    if interesting_locals.is_empty() {
        // Nothing to analyze.
        return;
    }

    // All three maps are sized by the number of candidate locals and reused across blocks.
    let capacity = interesting_locals.count();
    let mut state = FxIndexMap::with_capacity_and_hasher(capacity, Default::default());
    let mut patches_scratchpad = FxIndexMap::with_capacity_and_hasher(capacity, Default::default());
    let mut replacements_scratchpad =
        FxIndexMap::with_capacity_and_hasher(capacity, Default::default());

    for block in basic_blocks {
        // Keep length calls on arrays `[T; N]` from decaying into length calls on `&[T]`,
        // which would forbid constant propagation.
        normalize_array_len_call(
            tcx,
            block,
            local_decls,
            &interesting_locals,
            &mut state,
            &mut patches_scratchpad,
            &mut replacements_scratchpad,
        );

        // The analysis is strictly per block: start the next one from a clean slate.
        state.clear();
        patches_scratchpad.clear();
        replacements_scratchpad.clear();
    }
}

/// State used while rewriting the statements of a single basic block, one statement at a
/// time, through `patch_expand_statement`.
struct Patcher<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    /// Maps the index of an unsize-cast statement to the index of the `Len` statement
    /// that reads the cast result.
    patches_scratchpad: &'a FxIndexMap<usize, usize>,
    /// Maps the index of a `Len` statement to the temporary local it should read instead;
    /// filled in when the corresponding cast statement is patched.
    replacements_scratchpad: &'a mut FxIndexMap<usize, Local>,
    /// Local declarations of the body; new temporaries are pushed here.
    local_decls: &'a mut IndexVec<Local, LocalDecl<'tcx>>,
    /// Index of the statement the next `patch_expand_statement` call will look at.
    statement_idx: usize,
}

impl<'tcx> Patcher<'_, 'tcx> {
    fn patch_expand_statement(
        &mut self,
        statement: &mut Statement<'tcx>,
    ) -> Option<std::vec::IntoIter<Statement<'tcx>>> {
        let idx = self.statement_idx;
        if let Some(len_statemnt_idx) = self.patches_scratchpad.get(&idx).copied() {
            let mut statements = Vec::with_capacity(2);

            // we are at statement that performs a cast. The only sound way is
            // to create another local that performs a similar copy without a cast and then
            // use this copy in the Len operation

            match &statement.kind {
                StatementKind::Assign(box (
                    ..,
                    Rvalue::Cast(
                        CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
                        operand,
                        _,
                    ),
                )) => {
                    match operand {
                        Operand::Copy(place) | Operand::Move(place) => {
                            // create new local
                            let ty = operand.ty(self.local_decls, self.tcx);
                            let local_decl = LocalDecl::with_source_info(ty, statement.source_info);
                            let local = self.local_decls.push(local_decl);
                            // make it live
                            let mut make_live_statement = statement.clone();
                            make_live_statement.kind = StatementKind::StorageLive(local);
                            statements.push(make_live_statement);
                            // copy into it

                            let operand = Operand::Copy(*place);
                            let mut make_copy_statement = statement.clone();
                            let assign_to = Place::from(local);
                            let rvalue = Rvalue::Use(operand);
                            make_copy_statement.kind =
                                StatementKind::Assign(Box::new((assign_to, rvalue)));
                            statements.push(make_copy_statement);

                            // to reorder we have to copy and make NOP
                            statements.push(statement.clone());
                            statement.make_nop();

                            self.replacements_scratchpad.insert(len_statemnt_idx, local);
                        }
                        _ => {
                            unreachable!("it's a bug in the implementation")
                        }
                    }
                }
                _ => {
                    unreachable!("it's a bug in the implementation")
                }
            }

            self.statement_idx += 1;

            Some(statements.into_iter())
        } else if let Some(local) = self.replacements_scratchpad.get(&idx).copied() {
            let mut statements = Vec::with_capacity(2);

            match &statement.kind {
                StatementKind::Assign(box (into, Rvalue::Len(place))) => {
                    let add_deref = if let Some(..) = place.as_local() {
                        false
                    } else if let Some(..) = place.local_or_deref_local() {
                        true
                    } else {
                        unreachable!("it's a bug in the implementation")
                    };
                    // replace len statement
                    let mut len_statement = statement.clone();
                    let mut place = Place::from(local);
                    if add_deref {
                        place = self.tcx.mk_place_deref(place);
                    }
                    len_statement.kind =
                        StatementKind::Assign(Box::new((*into, Rvalue::Len(place))));
                    statements.push(len_statement);

                    // make temporary dead
                    let mut make_dead_statement = statement.clone();
                    make_dead_statement.kind = StatementKind::StorageDead(local);
                    statements.push(make_dead_statement);

                    // make original statement NOP
                    statement.make_nop();
                }
                _ => {
                    unreachable!("it's a bug in the implementation")
                }
            }

            self.statement_idx += 1;

            Some(statements.into_iter())
        } else {
            self.statement_idx += 1;
            None
        }
    }
}

/// Analyzes one basic block and rewrites `Len` statements that read a slice obtained from
/// an array through an unsize cast, so that they read the array instead.
///
/// The analysis walks the statements in order:
/// * an unsize cast from `[T; N]` / `&[T; N]` to `[T]` / `&[T]` into a plain local records
///   `destination local -> cast statement index` in `state`;
/// * a `Len` of such a local (or of its deref) records
///   `cast statement index -> len statement index` in `patches_scratchpad`;
/// * any other assignment invalidates the mapping of the assigned local.
///
/// Afterwards a `Patcher` applies the recorded patches via `expand_statements`.
///
/// * `local_decls` - declarations of the body; the patcher pushes new temporaries.
/// * `interesting_locals` - locals whose type is `[T; N]` or `&[T; N]`.
/// * `state`, `patches_scratchpad`, `replacements_scratchpad` - scratch maps, expected to
///   be empty on entry (the caller clears them after each block).
///
/// NOTE(review): the early `return`s below abandon the whole block, including patches
/// already discovered in it. That is conservative (nothing is rewritten); whether it is
/// intentional rather than a missed `continue` cannot be told from this file.
fn normalize_array_len_call<'tcx>(
    tcx: TyCtxt<'tcx>,
    block: &mut BasicBlockData<'tcx>,
    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
    interesting_locals: &BitSet<Local>,
    state: &mut FxIndexMap<Local, usize>,
    patches_scratchpad: &mut FxIndexMap<usize, usize>,
    replacements_scratchpad: &mut FxIndexMap<usize, Local>,
) {
    for (statement_idx, statement) in block.statements.iter_mut().enumerate() {
        // Only assignments are relevant to the analysis.
        let StatementKind::Assign(box (place, rvalue)) = &mut statement.kind else {
            continue;
        };

        match rvalue {
            Rvalue::Cast(
                CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
                operand,
                cast_ty,
            ) => {
                let Some(local) = place.as_local() else { return };

                // This statement overwrites `local`, so whatever cast was recorded for it
                // before no longer describes its value. Without this, a later cast that
                // is not recorded below (constant operand, non-array source) would leave
                // a stale entry, and a following `Len` would be rewritten to read the
                // array of the *earlier* cast.
                state.remove(&local);

                let (Operand::Copy(src_place) | Operand::Move(src_place)) = operand else {
                    continue;
                };
                let Some(operand_local) = src_place.local_or_deref_local() else { return };
                if !interesting_locals.contains(operand_local) {
                    return;
                }

                let operand_ty = local_decls[operand_local].ty;
                let is_array_to_slice = match (operand_ty.kind(), cast_ty.kind()) {
                    // `[T; N]` into `[T]`
                    (ty::Array(elem_src, ..), ty::Slice(elem_dst)) => elem_src == elem_dst,
                    // `&[T; N]` into `&[T]`; the current way of patching doesn't allow
                    // working with `mut` references.
                    (
                        ty::Ref(Region(Interned(ReErased, _)), pointee_src, Mutability::Not),
                        ty::Ref(Region(Interned(ReErased, _)), pointee_dst, Mutability::Not),
                    ) => match (pointee_src.kind(), pointee_dst.kind()) {
                        (ty::Array(elem_src, ..), ty::Slice(elem_dst)) => elem_src == elem_dst,
                        _ => false,
                    },
                    _ => false,
                };
                if is_array_to_slice {
                    // This is a cast from `[T; N]` into `[T]`, so we are good.
                    state.insert(local, statement_idx);
                }
            }
            Rvalue::Len(len_place) => {
                let Some(local) = len_place.local_or_deref_local() else {
                    return;
                };
                if let Some(cast_statement_idx) = state.get(&local).copied() {
                    patches_scratchpad.insert(cast_statement_idx, statement_idx);
                }
            }
            _ => {
                // Any other write to the local invalidates what we know about it.
                state.remove(&place.local);
            }
        }
    }

    let mut patcher = Patcher {
        tcx,
        patches_scratchpad: &*patches_scratchpad,
        replacements_scratchpad,
        local_decls,
        statement_idx: 0,
    };

    block.expand_statements(|st| patcher.patch_expand_statement(st));
}