summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_mir_transform/src/early_otherwise_branch.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src/early_otherwise_branch.rs')
-rw-r--r--compiler/rustc_mir_transform/src/early_otherwise_branch.rs429
1 files changed, 429 insertions, 0 deletions
diff --git a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs
new file mode 100644
index 000000000..dba42f7af
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs
@@ -0,0 +1,429 @@
+use rustc_middle::mir::patch::MirPatch;
+use rustc_middle::mir::*;
+use rustc_middle::ty::{self, Ty, TyCtxt};
+use std::fmt::Debug;
+
+use super::simplify::simplify_cfg;
+
+/// This pass optimizes something like
+/// ```ignore (syntax-highlighting-only)
+/// let x: Option<()>;
+/// let y: Option<()>;
+/// match (x,y) {
+/// (Some(_), Some(_)) => {0},
+/// _ => {1}
+/// }
+/// ```
+/// into something like
+/// ```ignore (syntax-highlighting-only)
+/// let x: Option<()>;
+/// let y: Option<()>;
+/// let discriminant_x = std::mem::discriminant(x);
+/// let discriminant_y = std::mem::discriminant(y);
+/// if discriminant_x == discriminant_y {
+/// match x {
+/// Some(_) => 0,
+/// _ => 1, // <----
+/// } // | Actually the same bb
+/// } else { // |
+/// 1 // <--------------
+/// }
+/// ```
+///
+/// Specifically, it looks for instances of control flow like this:
+/// ```text
+///
+/// =================
+/// | BB1 |
+/// |---------------| ============================
+/// | ... | /------> | BBC |
+/// |---------------| | |--------------------------|
+/// | switchInt(Q) | | | _cl = discriminant(P) |
+/// | c | --------/ |--------------------------|
+/// | d | -------\ | switchInt(_cl) |
+/// | ... | | | c | ---> BBC.2
+/// | otherwise | --\ | /--- | otherwise |
+/// ================= | | | ============================
+/// | | |
+/// ================= | | |
+/// | BBU | <-| | | ============================
+/// |---------------| | \-------> | BBD |
+/// |---------------| | | |--------------------------|
+/// | unreachable | | | | _dl = discriminant(P) |
+/// ================= | | |--------------------------|
+/// | | | switchInt(_dl) |
+/// ================= | | | d | ---> BBD.2
+/// | BB9 | <--------------- | otherwise |
+/// |---------------| ============================
+/// | ... |
+/// =================
+/// ```
+/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU` or to `BB9`. In the
+/// code:
+/// - `BB1` is `parent` and `BBC, BBD` are children
+/// - `P` is `child_place`
+/// - `child_ty` is the type of `_cl`.
+/// - `Q` is `parent_op`.
+/// - `parent_ty` is the type of `Q`.
+/// - `BB9` is `destination`
+/// All this is then transformed into:
+/// ```text
+///
+/// =======================
+/// | BB1 |
+/// |---------------------| ============================
+/// | ... | /------> | BBEq |
+/// | _s = discriminant(P)| | |--------------------------|
+/// | _t = Ne(Q, _s) | | |--------------------------|
+/// |---------------------| | | switchInt(Q) |
+/// | switchInt(_t) | | | c | ---> BBC.2
+/// | false | --------/ | d | ---> BBD.2
+/// | otherwise | ---------------- | otherwise |
+/// ======================= | ============================
+/// |
+/// ================= |
+/// | BB9 | <-----------/
+/// |---------------|
+/// | ... |
+/// =================
+/// ```
+///
+/// This is only correct for some `P`, since `P` is now computed outside the original `switchInt`.
+/// The filter on which `P` are allowed (together with discussion of its correctness) is found in
+/// `may_hoist`.
+pub struct EarlyOtherwiseBranch;
+
+impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
+ fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+ sess.mir_opt_level() >= 3 && sess.opts.unstable_opts.unsound_mir_opts
+ }
+
+ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ trace!("running EarlyOtherwiseBranch on {:?}", body.source);
+
+ let mut should_cleanup = false;
+
+ // Also consider newly generated bbs in the same pass
+ for i in 0..body.basic_blocks().len() {
+ let bbs = body.basic_blocks();
+ let parent = BasicBlock::from_usize(i);
+ let Some(opt_data) = evaluate_candidate(tcx, body, parent) else {
+ continue
+ };
+
+ if !tcx.consider_optimizing(|| format!("EarlyOtherwiseBranch {:?}", &opt_data)) {
+ break;
+ }
+
+ trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_data);
+
+ should_cleanup = true;
+
+ let TerminatorKind::SwitchInt {
+ discr: parent_op,
+ switch_ty: parent_ty,
+ targets: parent_targets
+ } = &bbs[parent].terminator().kind else {
+ unreachable!()
+ };
+ // Always correct since we can only switch on `Copy` types
+ let parent_op = match parent_op {
+ Operand::Move(x) => Operand::Copy(*x),
+ Operand::Copy(x) => Operand::Copy(*x),
+ Operand::Constant(x) => Operand::Constant(x.clone()),
+ };
+ let statements_before = bbs[parent].statements.len();
+ let parent_end = Location { block: parent, statement_index: statements_before };
+
+ let mut patch = MirPatch::new(body);
+
+ // create temp to store second discriminant in, `_s` in example above
+ let second_discriminant_temp =
+ patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
+
+ patch.add_statement(parent_end, StatementKind::StorageLive(second_discriminant_temp));
+
+ // create assignment of discriminant
+ patch.add_assign(
+ parent_end,
+ Place::from(second_discriminant_temp),
+ Rvalue::Discriminant(opt_data.child_place),
+ );
+
+ // create temp to store inequality comparison between the two discriminants, `_t` in
+ // example above
+ let nequal = BinOp::Ne;
+ let comp_res_type = nequal.ty(tcx, *parent_ty, opt_data.child_ty);
+ let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
+ patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));
+
+ // create inequality comparison between the two discriminants
+ let comp_rvalue = Rvalue::BinaryOp(
+ nequal,
+ Box::new((parent_op.clone(), Operand::Move(Place::from(second_discriminant_temp)))),
+ );
+ patch.add_statement(
+ parent_end,
+ StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
+ );
+
+ let eq_new_targets = parent_targets.iter().map(|(value, child)| {
+ let TerminatorKind::SwitchInt{ targets, .. } = &bbs[child].terminator().kind else {
+ unreachable!()
+ };
+ (value, targets.target_for_value(value))
+ });
+ let eq_targets = SwitchTargets::new(eq_new_targets, opt_data.destination);
+
+ // Create `bbEq` in example above
+ let eq_switch = BasicBlockData::new(Some(Terminator {
+ source_info: bbs[parent].terminator().source_info,
+ kind: TerminatorKind::SwitchInt {
+ // switch on the first discriminant, so we can mark the second one as dead
+ discr: parent_op,
+ switch_ty: opt_data.child_ty,
+ targets: eq_targets,
+ },
+ }));
+
+ let eq_bb = patch.new_block(eq_switch);
+
+ // Jump to it on the basis of the inequality comparison
+ let true_case = opt_data.destination;
+ let false_case = eq_bb;
+ patch.patch_terminator(
+ parent,
+ TerminatorKind::if_(
+ tcx,
+ Operand::Move(Place::from(comp_temp)),
+ true_case,
+ false_case,
+ ),
+ );
+
+ // generate StorageDead for the second_discriminant_temp not in use anymore
+ patch.add_statement(parent_end, StatementKind::StorageDead(second_discriminant_temp));
+
+ // Generate a StorageDead for comp_temp in each of the targets, since we moved it into
+ // the switch
+ for bb in [false_case, true_case].iter() {
+ patch.add_statement(
+ Location { block: *bb, statement_index: 0 },
+ StatementKind::StorageDead(comp_temp),
+ );
+ }
+
+ patch.apply(body);
+ }
+
+ // Since this optimization adds new basic blocks and invalidates others,
+ // clean up the cfg to make it nicer for other passes
+ if should_cleanup {
+ simplify_cfg(tcx, body);
+ }
+ }
+}
+
+/// Returns true if computing the discriminant of `place` may be hoisted out of the branch
+fn may_hoist<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, place: Place<'tcx>) -> bool {
+ // FIXME(JakobDegen): This is unsound. Someone could write code like this:
+ // ```rust
+ // let Q = val;
+ // if discriminant(P) == otherwise {
+ // let ptr = &mut Q as *mut _ as *mut u8;
+ // unsafe { *ptr = 10; } // Any invalid value for the type
+ // }
+ //
+ // match P {
+ // A => match Q {
+ // A => {
+ // // code
+ // }
+ // _ => {
+ // // don't use Q
+ // }
+ // }
+ // _ => {
+ // // don't use Q
+ // }
+ // };
+ // ```
+ //
+ // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
+ // invalid value, which is UB.
+ //
+ // In order to fix this, we would either need to show that the discriminant computation of
+ // `place` is computed in all branches, including the `otherwise` branch, or we would need
+ // another analysis pass to determine that the place is fully initialized. It might even be best
+ // to have the hoisting be performed in a different pass and just do the CFG changing in this
+ // pass.
+ for (place, proj) in place.iter_projections() {
+ match proj {
+ // Dereferencing in the computation of `place` might cause issues from one of two
+ // categories. First, the referent might be invalid. We protect against this by
+ // dereferencing references only (not pointers). Second, the use of a reference may
+ // invalidate other references that are used later (for aliasing reasons). Consider
+ // where such an invalidated reference may appear:
+ // - In `Q`: Not possible since `Q` is used as the operand of a `SwitchInt` and so
+ // cannot contain referenced data.
+ // - In `BBU`: Not possible since that block contains only the `unreachable` terminator
+ // - In `BBC.2, BBD.2`: Not possible, since `discriminant(P)` was computed prior to
+ // reaching that block in the input to our transformation, and so any data
+ // invalidated by that computation could not have been used there.
+ // - In `BB9`: Not possible since control flow might have reached `BB9` via the
+ // `otherwise` branch in `BBC, BBD` in the input to our transformation, which would
+ // have invalidated the data when computing `discriminant(P)`
+ // So dereferencing here is correct.
+ ProjectionElem::Deref => match place.ty(body.local_decls(), tcx).ty.kind() {
+ ty::Ref(..) => {}
+ _ => return false,
+ },
+ // Field projections are always valid
+ ProjectionElem::Field(..) => {}
+ // We cannot allow
+ // downcasts either, since the correctness of the downcast may depend on the parent
+ // branch being taken. An easy example of this is
+ // ```
+ // Q = discriminant(_3)
+ // P = (_3 as Variant)
+ // ```
+ // However, checking if the child and parent place are the same and only erroring then
+ // is not sufficient either, since the `discriminant(_3) == 1` (or whatever) check may
+ // be replaced by another optimization pass with any other condition that can be proven
+ // equivalent.
+ ProjectionElem::Downcast(..) => {
+ return false;
+ }
+ // We cannot allow indexing since the index may be out of bounds.
+ _ => {
+ return false;
+ }
+ }
+ }
+ true
+}
+
+#[derive(Debug)]
+struct OptimizationData<'tcx> {
+ destination: BasicBlock,
+ child_place: Place<'tcx>,
+ child_ty: Ty<'tcx>,
+ child_source: SourceInfo,
+}
+
+fn evaluate_candidate<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ body: &Body<'tcx>,
+ parent: BasicBlock,
+) -> Option<OptimizationData<'tcx>> {
+ let bbs = body.basic_blocks();
+ let TerminatorKind::SwitchInt {
+ targets,
+ switch_ty: parent_ty,
+ ..
+ } = &bbs[parent].terminator().kind else {
+ return None
+ };
+ let parent_dest = {
+ let poss = targets.otherwise();
+ // If the fallthrough on the parent is trivially unreachable, we can let the
+ // children choose the destination
+ if bbs[poss].statements.len() == 0
+ && bbs[poss].terminator().kind == TerminatorKind::Unreachable
+ {
+ None
+ } else {
+ Some(poss)
+ }
+ };
+ let (_, child) = targets.iter().next()?;
+ let child_terminator = &bbs[child].terminator();
+ let TerminatorKind::SwitchInt {
+ switch_ty: child_ty,
+ targets: child_targets,
+ ..
+ } = &child_terminator.kind else {
+ return None
+ };
+ if child_ty != parent_ty {
+ return None;
+ }
+ let Some(StatementKind::Assign(boxed))
+ = &bbs[child].statements.first().map(|x| &x.kind) else {
+ return None;
+ };
+ let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
+ return None;
+ };
+ let destination = parent_dest.unwrap_or(child_targets.otherwise());
+
+ // Verify that the optimization is legal in general
+ // We can hoist evaluating the child discriminant out of the branch
+ if !may_hoist(tcx, body, *child_place) {
+ return None;
+ }
+
+ // Verify that the optimization is legal for each branch
+ for (value, child) in targets.iter() {
+ if !verify_candidate_branch(&bbs[child], value, *child_place, destination) {
+ return None;
+ }
+ }
+ Some(OptimizationData {
+ destination,
+ child_place: *child_place,
+ child_ty: *child_ty,
+ child_source: child_terminator.source_info,
+ })
+}
+
+fn verify_candidate_branch<'tcx>(
+ branch: &BasicBlockData<'tcx>,
+ value: u128,
+ place: Place<'tcx>,
+ destination: BasicBlock,
+) -> bool {
+ // In order for the optimization to be correct, the branch must...
+ // ...have exactly one statement
+ if branch.statements.len() != 1 {
+ return false;
+ }
+ // ...assign the discriminant of `place` in that statement
+ let StatementKind::Assign(boxed) = &branch.statements[0].kind else {
+ return false
+ };
+ let (discr_place, Rvalue::Discriminant(from_place)) = &**boxed else {
+ return false
+ };
+ if *from_place != place {
+ return false;
+ }
+ // ...make that assignment to a local
+ if discr_place.projection.len() != 0 {
+ return false;
+ }
+ // ...terminate on a `SwitchInt` that invalidates that local
+ let TerminatorKind::SwitchInt{ discr: switch_op, targets, .. } = &branch.terminator().kind else {
+ return false
+ };
+ if *switch_op != Operand::Move(*discr_place) {
+ return false;
+ }
+ // ...fall through to `destination` if the switch misses
+ if destination != targets.otherwise() {
+ return false;
+ }
+ // ...have a branch for value `value`
+ let mut iter = targets.iter();
+ let Some((target_value, _)) = iter.next() else {
+ return false;
+ };
+ if target_value != value {
+ return false;
+ }
+ // ...and have no more branches
+ if let Some(_) = iter.next() {
+ return false;
+ }
+ return true;
+}