summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_mir_transform/src/match_branches.rs
blob: a0ba69c89b048491f9840a9cd6806f91bf2fd46f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
use crate::MirPass;
use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;
use std::iter;

use super::simplify::simplify_cfg;

/// MIR pass that folds a two-way `switchInt` into its source block when both
/// targets are identical except for const `bool` assignments to the same place.
pub struct MatchBranchSimplification;

/// If a source block is found that switches between two blocks that are exactly
/// the same modulo const bool assignments (e.g., one assigns true and the other false
/// to the same place), merge the target blocks' statements into the source block,
/// using an Eq / Ne comparison with the switch value where the const bool values differ.
///
/// For example:
///
/// ```ignore (MIR)
/// bb0: {
///     switchInt(move _3) -> [42_isize: bb1, otherwise: bb2];
/// }
///
/// bb1: {
///     _2 = const true;
///     goto -> bb3;
/// }
///
/// bb2: {
///     _2 = const false;
///     goto -> bb3;
/// }
/// ```
///
/// into:
///
/// ```ignore (MIR)
/// bb0: {
///    _2 = Eq(move _3, const 42_isize);
///    goto -> bb3;
/// }
/// ```

impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        sess.mir_opt_level() >= 3
    }

    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        let def_id = body.source.def_id();
        let param_env = tcx.param_env(def_id);

        let bbs = body.basic_blocks.as_mut();
        let mut should_cleanup = false;
        'outer: for bb_idx in bbs.indices() {
            if !tcx.consider_optimizing(|| format!("MatchBranchSimplification {:?} ", def_id)) {
                continue;
            }

            let (discr, val, switch_ty, first, second) = match bbs[bb_idx].terminator().kind {
                TerminatorKind::SwitchInt {
                    discr: ref discr @ (Operand::Copy(_) | Operand::Move(_)),
                    switch_ty,
                    ref targets,
                    ..
                } if targets.iter().len() == 1 => {
                    let (value, target) = targets.iter().next().unwrap();
                    if target == targets.otherwise() {
                        continue;
                    }
                    (discr, value, switch_ty, target, targets.otherwise())
                }
                // Only optimize switch int statements
                _ => continue,
            };

            // Check that destinations are identical, and if not, then don't optimize this block
            if bbs[first].terminator().kind != bbs[second].terminator().kind {
                continue;
            }

            // Check that blocks are assignments of consts to the same place or same statement,
            // and match up 1-1, if not don't optimize this block.
            let first_stmts = &bbs[first].statements;
            let scnd_stmts = &bbs[second].statements;
            if first_stmts.len() != scnd_stmts.len() {
                continue;
            }
            for (f, s) in iter::zip(first_stmts, scnd_stmts) {
                match (&f.kind, &s.kind) {
                    // If two statements are exactly the same, we can optimize.
                    (f_s, s_s) if f_s == s_s => {}

                    // If two statements are const bool assignments to the same place, we can optimize.
                    (
                        StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
                        StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
                    ) if lhs_f == lhs_s
                        && f_c.literal.ty().is_bool()
                        && s_c.literal.ty().is_bool()
                        && f_c.literal.try_eval_bool(tcx, param_env).is_some()
                        && s_c.literal.try_eval_bool(tcx, param_env).is_some() => {}

                    // Otherwise we cannot optimize. Try another block.
                    _ => continue 'outer,
                }
            }
            // Take ownership of items now that we know we can optimize.
            let discr = discr.clone();

            // Introduce a temporary for the discriminant value.
            let source_info = bbs[bb_idx].terminator().source_info;
            let discr_local = body.local_decls.push(LocalDecl::new(switch_ty, source_info.span));

            // We already checked that first and second are different blocks,
            // and bb_idx has a different terminator from both of them.
            let (from, first, second) = bbs.pick3_mut(bb_idx, first, second);

            let new_stmts = iter::zip(&first.statements, &second.statements).map(|(f, s)| {
                match (&f.kind, &s.kind) {
                    (f_s, s_s) if f_s == s_s => (*f).clone(),

                    (
                        StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
                        StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))),
                    ) => {
                        // From earlier loop we know that we are dealing with bool constants only:
                        let f_b = f_c.literal.try_eval_bool(tcx, param_env).unwrap();
                        let s_b = s_c.literal.try_eval_bool(tcx, param_env).unwrap();
                        if f_b == s_b {
                            // Same value in both blocks. Use statement as is.
                            (*f).clone()
                        } else {
                            // Different value between blocks. Make value conditional on switch condition.
                            let size = tcx.layout_of(param_env.and(switch_ty)).unwrap().size;
                            let const_cmp = Operand::const_from_scalar(
                                tcx,
                                switch_ty,
                                rustc_const_eval::interpret::Scalar::from_uint(val, size),
                                rustc_span::DUMMY_SP,
                            );
                            let op = if f_b { BinOp::Eq } else { BinOp::Ne };
                            let rhs = Rvalue::BinaryOp(
                                op,
                                Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)),
                            );
                            Statement {
                                source_info: f.source_info,
                                kind: StatementKind::Assign(Box::new((*lhs, rhs))),
                            }
                        }
                    }

                    _ => unreachable!(),
                }
            });

            from.statements
                .push(Statement { source_info, kind: StatementKind::StorageLive(discr_local) });
            from.statements.push(Statement {
                source_info,
                kind: StatementKind::Assign(Box::new((
                    Place::from(discr_local),
                    Rvalue::Use(discr),
                ))),
            });
            from.statements.extend(new_stmts);
            from.statements
                .push(Statement { source_info, kind: StatementKind::StorageDead(discr_local) });
            from.terminator_mut().kind = first.terminator().kind.clone();
            should_cleanup = true;
        }

        if should_cleanup {
            simplify_cfg(tcx, body);
        }
    }
}