Diffstat (limited to 'compiler/rustc_mir_transform')
62 files changed, 19470 insertions, 0 deletions
diff --git a/compiler/rustc_mir_transform/Cargo.toml b/compiler/rustc_mir_transform/Cargo.toml
new file mode 100644
index 000000000..85b7a4af5
--- /dev/null
+++ b/compiler/rustc_mir_transform/Cargo.toml
@@ -0,0 +1,29 @@
+[package]
+name = "rustc_mir_transform"
+version = "0.0.0"
+edition = "2021"
+
+[lib]
+doctest = false
+
+[dependencies]
+itertools = "0.10.1"
+smallvec = { version = "1.8.1", features = ["union", "may_dangle"] }
+tracing = "0.1"
+rustc_ast = { path = "../rustc_ast" }
+rustc_attr = { path = "../rustc_attr" }
+rustc_data_structures = { path = "../rustc_data_structures" }
+rustc_errors = { path = "../rustc_errors" }
+rustc_hir = { path = "../rustc_hir" }
+rustc_index = { path = "../rustc_index" }
+rustc_middle = { path = "../rustc_middle" }
+rustc_const_eval = { path = "../rustc_const_eval" }
+rustc_mir_dataflow = { path = "../rustc_mir_dataflow" }
+rustc_serialize = { path = "../rustc_serialize" }
+rustc_session = { path = "../rustc_session" }
+rustc_target = { path = "../rustc_target" }
+rustc_trait_selection = { path = "../rustc_trait_selection" }
+rustc_span = { path = "../rustc_span" }
+
+[dev-dependencies]
+coverage_test_macros = { path = "src/coverage/test_macros" }
diff --git a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs
new file mode 100644
index 000000000..2502e8b60
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs
@@ -0,0 +1,140 @@
+use crate::MirPass;
+use rustc_ast::InlineAsmOptions;
+use rustc_middle::mir::*;
+use rustc_middle::ty::layout;
+use rustc_middle::ty::{self, TyCtxt};
+use rustc_target::spec::abi::Abi;
+use rustc_target::spec::PanicStrategy;
+
+/// A pass that ensures the codegen guarantees about unwinding are upheld
+/// when compiling panic=abort programs.
+///
+/// When compiling with panic=abort, codegen backends generally want to assume
+/// that all Rust-defined functions do not unwind, and it's UB if they actually
+/// do unwind. Foreign functions, however, can be declared as "may unwind" via
+/// their ABI (e.g. `extern "C-unwind"`). To uphold the guarantee that
+/// Rust-defined functions never unwind, a well-behaved Rust program needs to
+/// catch unwinding from foreign functions and force it to abort.
+///
+/// This pass walks over all function calls that may possibly unwind,
+/// and if any are found sets their cleanup to a block that aborts the process.
+/// In panic=abort mode, this forces every unwind that happens in foreign code
+/// to trigger a process abort.
+#[derive(PartialEq)]
+pub struct AbortUnwindingCalls;
+
+impl<'tcx> MirPass<'tcx> for AbortUnwindingCalls {
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        let def_id = body.source.def_id();
+        let kind = tcx.def_kind(def_id);
+
+        // We don't simplify the MIR of constants at this time because that
+        // would result in a cyclic query when we call `tcx.type_of` below.
+        if !kind.is_fn_like() {
+            return;
+        }
+
+        // This pass only runs on functions which themselves cannot unwind,
+        // forcibly changing the body of the function to structurally provide
+        // this guarantee by aborting on an unwind. If this function can unwind,
+        // then there's nothing to do because it should already work correctly.
+        //
+        // Here we test whether the ABI of this function itself allows
+        // unwinding or not.
+        let body_ty = tcx.type_of(def_id);
+        let body_abi = match body_ty.kind() {
+            ty::FnDef(..) => body_ty.fn_sig(tcx).abi(),
+            ty::Closure(..) => Abi::RustCall,
+            ty::Generator(..) => Abi::Rust,
+            _ => span_bug!(body.span, "unexpected body ty: {:?}", body_ty),
+        };
+        let body_can_unwind = layout::fn_can_unwind(tcx, Some(def_id), body_abi);
+
+        // Look in this function body for any basic blocks which are terminated
+        // with a function call, and whose function we're calling may unwind.
+        // This will filter to functions with `extern "C-unwind"` ABIs, for
+        // example.
+        let mut calls_to_terminate = Vec::new();
+        let mut cleanups_to_remove = Vec::new();
+        for (id, block) in body.basic_blocks().iter_enumerated() {
+            if block.is_cleanup {
+                continue;
+            }
+            let Some(terminator) = &block.terminator else { continue };
+            let span = terminator.source_info.span;
+
+            let call_can_unwind = match &terminator.kind {
+                TerminatorKind::Call { func, .. } => {
+                    let ty = func.ty(body, tcx);
+                    let sig = ty.fn_sig(tcx);
+                    let fn_def_id = match ty.kind() {
+                        ty::FnPtr(_) => None,
+                        &ty::FnDef(def_id, _) => Some(def_id),
+                        _ => span_bug!(span, "invalid callee of type {:?}", ty),
+                    };
+                    layout::fn_can_unwind(tcx, fn_def_id, sig.abi())
+                }
+                TerminatorKind::Drop { .. } | TerminatorKind::DropAndReplace { .. } => {
+                    tcx.sess.opts.unstable_opts.panic_in_drop == PanicStrategy::Unwind
+                        && layout::fn_can_unwind(tcx, None, Abi::Rust)
+                }
+                TerminatorKind::Assert { .. } | TerminatorKind::FalseUnwind { .. } => {
+                    layout::fn_can_unwind(tcx, None, Abi::Rust)
+                }
+                TerminatorKind::InlineAsm { options, .. } => {
+                    options.contains(InlineAsmOptions::MAY_UNWIND)
+                }
+                _ if terminator.unwind().is_some() => {
+                    span_bug!(span, "unexpected terminator that may unwind {:?}", terminator)
+                }
+                _ => continue,
+            };
+
+            // If this function call can't unwind, then there's no need for it
+            // to have a landing pad. This means that we can remove any cleanup
+            // registered for it.
+            if !call_can_unwind {
+                cleanups_to_remove.push(id);
+                continue;
+            }
+
+            // Otherwise, if this call can unwind and the outer function can
+            // also unwind, there's nothing to do. If the outer function can't
+            // unwind, however, we need to change the landing pad for this
+            // function call to one that aborts.
+            if !body_can_unwind {
+                calls_to_terminate.push(id);
+            }
+        }
+
+        // For call instructions which need to be terminated, we insert a
+        // single basic block which simply aborts, and we point the `cleanup`
+        // edge of every call we found at it. Any unwinding that happens in
+        // those calls will then force an abort of the process.
+        if !calls_to_terminate.is_empty() {
+            let bb = BasicBlockData {
+                statements: Vec::new(),
+                is_cleanup: true,
+                terminator: Some(Terminator {
+                    source_info: SourceInfo::outermost(body.span),
+                    kind: TerminatorKind::Abort,
+                }),
+            };
+            let abort_bb = body.basic_blocks_mut().push(bb);
+
+            for bb in calls_to_terminate {
+                let cleanup = body.basic_blocks_mut()[bb].terminator_mut().unwind_mut().unwrap();
+                *cleanup = Some(abort_bb);
+            }
+        }
+
+        for id in cleanups_to_remove {
+            let cleanup = body.basic_blocks_mut()[id].terminator_mut().unwind_mut().unwrap();
+            *cleanup = None;
+        }
+
+        // We may have invalidated some `cleanup` blocks so clean those up now.
+        super::simplify::remove_dead_blocks(tcx, body);
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/add_call_guards.rs b/compiler/rustc_mir_transform/src/add_call_guards.rs
new file mode 100644
index 000000000..f12c8560c
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/add_call_guards.rs
@@ -0,0 +1,81 @@
+use crate::MirPass;
+use rustc_index::vec::{Idx, IndexVec};
+use rustc_middle::mir::*;
+use rustc_middle::ty::TyCtxt;
+
+#[derive(PartialEq)]
+pub enum AddCallGuards {
+    AllCallEdges,
+    CriticalCallEdges,
+}
+pub use self::AddCallGuards::*;
+
+/**
+ * Breaks outgoing critical edges for call terminators in the MIR.
+ *
+ * Critical edges are edges that are neither the only edge leaving a
+ * block, nor the only edge entering one.
+ *
+ * When you want something to happen "along" an edge, you can either
+ * do it at the end of the predecessor block, or at the start of the
+ * successor block. Critical edges have to be broken in order to prevent
+ * "edge actions" from affecting other edges. We need this for calls that are
+ * codegened to LLVM invoke instructions, because invoke is a block terminator
+ * in LLVM so we can't insert any code to handle the call's result into the
+ * block that performs the call.
+ *
+ * This function will break those edges by inserting new blocks along them.
+ *
+ * NOTE: Simplify CFG will happily undo most of the work this pass does.
+ *
+ */
+
+impl<'tcx> MirPass<'tcx> for AddCallGuards {
+    fn run_pass(&self, _tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        self.add_call_guards(body);
+    }
+}
+
+impl AddCallGuards {
+    pub fn add_call_guards(&self, body: &mut Body<'_>) {
+        let mut pred_count: IndexVec<_, _> =
+            body.basic_blocks.predecessors().iter().map(|ps| ps.len()).collect();
+        pred_count[START_BLOCK] += 1;
+
+        // We need a place to store the new blocks generated
+        let mut new_blocks = Vec::new();
+
+        let cur_len = body.basic_blocks().len();
+
+        for block in body.basic_blocks_mut() {
+            match block.terminator {
+                Some(Terminator {
+                    kind: TerminatorKind::Call { target: Some(ref mut destination), cleanup, .. },
+                    source_info,
+                }) if pred_count[*destination] > 1
+                    && (cleanup.is_some() || self == &AllCallEdges) =>
+                {
+                    // It's a critical edge, break it
+                    let call_guard = BasicBlockData {
+                        statements: vec![],
+                        is_cleanup: block.is_cleanup,
+                        terminator: Some(Terminator {
+                            source_info,
+                            kind: TerminatorKind::Goto { target: *destination },
+                        }),
+                    };
+
+                    // Get the index it will be when inserted into the MIR
+                    let idx = cur_len + new_blocks.len();
+                    new_blocks.push(call_guard);
+                    *destination = BasicBlock::new(idx);
+                }
+                _ => {}
+            }
+        }
+
+        debug!("broke {} critical call edges", new_blocks.len());
+
+        body.basic_blocks_mut().extend(new_blocks);
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/add_moves_for_packed_drops.rs b/compiler/rustc_mir_transform/src/add_moves_for_packed_drops.rs
new file mode 100644
index 000000000..8de0aad04
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/add_moves_for_packed_drops.rs
@@ -0,0 +1,107 @@
+use rustc_middle::mir::*;
+use rustc_middle::ty::TyCtxt;
+
+use crate::util;
+use crate::MirPass;
+use rustc_middle::mir::patch::MirPatch;
+
+// This pass moves values being dropped that are within a packed
+// struct to a separate local before dropping them, to ensure that
+// they are dropped from an aligned address.
+//
+// For example, if we have something like
+// ```Rust
+// #[repr(packed)]
+// struct Foo {
+//     dealign: u8,
+//     data: Vec<u8>
+// }
+//
+// let foo = ...;
+// ```
+//
+// We want to call `drop_in_place::<Vec<u8>>` on `data` from an aligned
+// address. This means we can't simply drop `foo.data` directly, because
+// its address is not aligned.
+//
+// Instead, we move `foo.data` to a local and drop that:
+// ```
+// storage.live(drop_temp)
+// drop_temp = foo.data;
+// drop(drop_temp) -> next
+// next:
+// storage.dead(drop_temp)
+// ```
+//
+// The storage instructions are required to avoid stack space
+// blowup.
+
+pub struct AddMovesForPackedDrops;
+
+impl<'tcx> MirPass<'tcx> for AddMovesForPackedDrops {
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        debug!("add_moves_for_packed_drops({:?} @ {:?})", body.source, body.span);
+        add_moves_for_packed_drops(tcx, body);
+    }
+}
+
+pub fn add_moves_for_packed_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+    let patch = add_moves_for_packed_drops_patch(tcx, body);
+    patch.apply(body);
+}
+
+fn add_moves_for_packed_drops_patch<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> MirPatch<'tcx> {
+    let def_id = body.source.def_id();
+    let mut patch = MirPatch::new(body);
+    let param_env = tcx.param_env(def_id);
+
+    for (bb, data) in body.basic_blocks().iter_enumerated() {
+        let loc = Location { block: bb, statement_index: data.statements.len() };
+        let terminator = data.terminator();
+
+        match terminator.kind {
+            TerminatorKind::Drop { place, .. }
+                if util::is_disaligned(tcx, body, param_env, place) =>
+            {
+                add_move_for_packed_drop(tcx, body, &mut patch, terminator, loc, data.is_cleanup);
+            }
+            TerminatorKind::DropAndReplace { .. } => {
+                span_bug!(terminator.source_info.span, "replace in AddMovesForPackedDrops");
+            }
+            _ => {}
+        }
+    }
+
+    patch
+}
+
+fn add_move_for_packed_drop<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    body: &Body<'tcx>,
+    patch: &mut MirPatch<'tcx>,
+    terminator: &Terminator<'tcx>,
+    loc: Location,
+    is_cleanup: bool,
+) {
+    debug!("add_move_for_packed_drop({:?} @ {:?})", terminator, loc);
+    let TerminatorKind::Drop { ref place, target, unwind } = terminator.kind else {
+        unreachable!();
+    };
+
+    let source_info = terminator.source_info;
+    let ty = place.ty(body, tcx).ty;
+    let temp = patch.new_temp(ty, terminator.source_info.span);
+
+    let storage_dead_block = patch.new_block(BasicBlockData {
+        statements: vec![Statement { source_info, kind: StatementKind::StorageDead(temp) }],
+        terminator: Some(Terminator { source_info, kind: TerminatorKind::Goto { target } }),
+        is_cleanup,
+    });
+
+    patch.add_statement(loc, StatementKind::StorageLive(temp));
+    patch.add_assign(loc, Place::from(temp), Rvalue::Use(Operand::Move(*place)));
+    patch.patch_terminator(
+        loc.block,
+        TerminatorKind::Drop { place: Place::from(temp), target: storage_dead_block, unwind },
+    );
+}
diff --git a/compiler/rustc_mir_transform/src/add_retag.rs b/compiler/rustc_mir_transform/src/add_retag.rs
new file mode 100644
index 000000000..9c5896c4e
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/add_retag.rs
@@ -0,0 +1,186 @@
+//! This pass adds retag statements where appropriate.
+//! It has to be run really early, before transformations like inlining, because
+//! introducing these statements *adds* UB -- so, conceptually, this pass is actually
+//! part of MIR building, and only after this pass do we think of the program as having
+//! the normal MIR semantics.
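For orientation, here is a sketch (hypothetical source, not part of this commit) of where this pass places its statements: arguments get `Retag(FnEntry)` in the start block, reference-typed assignments are followed by a `Retag`, and `AddressOf` escapes get `RetagKind::Raw`:

```rust
// Sketch only: comments mark where AddRetag would insert retag statements.
fn demo(v: &mut Vec<i32>) -> *mut i32 {
    // `v` gets `Retag(FnEntry)` at the top of the start block.
    let r = &mut *v;             // reference-typed assignment: `Retag` follows
    r.push(1);
    std::ptr::addr_of_mut!(r[0]) // `Rvalue::AddressOf`: `Retag(Raw)` follows
}
```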
+
+use crate::MirPass;
+use rustc_middle::mir::*;
+use rustc_middle::ty::{self, Ty, TyCtxt};
+
+pub struct AddRetag;
+
+/// Determines whether this place is "stable": whether, if we evaluate it again
+/// after the assignment, we can be sure to obtain the same place value.
+/// (Concurrent accesses by other threads are no problem, since these are
+/// non-atomic copies anyway. Data races are UB.)
+fn is_stable(place: PlaceRef<'_>) -> bool {
+    // Which place this evaluates to can change with any memory write,
+    // so we cannot assume a deref to be stable.
+    !place.has_deref()
+}
+
+/// Determines whether this type may contain a reference (or box), and thus needs retagging.
+/// We will only recurse `depth` times into Tuples/ADTs to bound the cost of this.
+fn may_contain_reference<'tcx>(ty: Ty<'tcx>, depth: u32, tcx: TyCtxt<'tcx>) -> bool {
+    match ty.kind() {
+        // Primitive types that are not references
+        ty::Bool
+        | ty::Char
+        | ty::Float(_)
+        | ty::Int(_)
+        | ty::Uint(_)
+        | ty::RawPtr(..)
+        | ty::FnPtr(..)
+        | ty::Str
+        | ty::FnDef(..)
+        | ty::Never => false,
+        // References
+        ty::Ref(..) => true,
+        ty::Adt(..) if ty.is_box() => true,
+        // Compound types: recurse
+        ty::Array(ty, _) | ty::Slice(ty) => {
+            // This does not branch so we keep the depth the same.
+            may_contain_reference(*ty, depth, tcx)
+        }
+        ty::Tuple(tys) => {
+            depth == 0 || tys.iter().any(|ty| may_contain_reference(ty, depth - 1, tcx))
+        }
+        ty::Adt(adt, subst) => {
+            depth == 0
+                || adt.variants().iter().any(|v| {
+                    v.fields.iter().any(|f| may_contain_reference(f.ty(tcx, subst), depth - 1, tcx))
+                })
+        }
+        // Conservative fallback
+        _ => true,
+    }
+}
+
+impl<'tcx> MirPass<'tcx> for AddRetag {
+    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+        sess.opts.unstable_opts.mir_emit_retag
+    }
+
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        // We need an `AllCallEdges` pass before we can do any work.
+        super::add_call_guards::AllCallEdges.run_pass(tcx, body);
+
+        let (span, arg_count) = (body.span, body.arg_count);
+        let basic_blocks = body.basic_blocks.as_mut();
+        let local_decls = &body.local_decls;
+        let needs_retag = |place: &Place<'tcx>| {
+            // FIXME: Instead of giving up for unstable places, we should introduce
+            // a temporary and retag on that.
+            is_stable(place.as_ref())
+                && may_contain_reference(place.ty(&*local_decls, tcx).ty, /*depth*/ 3, tcx)
+                && !local_decls[place.local].is_deref_temp()
+        };
+        let place_base_raw = |place: &Place<'tcx>| {
+            // If this is a `Deref`, get the type of what we are deref'ing.
+            if place.has_deref() {
+                let ty = &local_decls[place.local].ty;
+                ty.is_unsafe_ptr()
+            } else {
+                // Not a deref, and thus not raw.
+                false
+            }
+        };
+
+        // PART 1
+        // Retag arguments at the beginning of the start block.
+        {
+            // FIXME: Consider using just the span covering the function
+            // argument declaration.
+            let source_info = SourceInfo::outermost(span);
+            // Gather all arguments, skip return value.
+            let places = local_decls
+                .iter_enumerated()
+                .skip(1)
+                .take(arg_count)
+                .map(|(local, _)| Place::from(local))
+                .filter(needs_retag);
+            // Emit their retags.
+            basic_blocks[START_BLOCK].statements.splice(
+                0..0,
+                places.map(|place| Statement {
+                    source_info,
+                    kind: StatementKind::Retag(RetagKind::FnEntry, Box::new(place)),
+                }),
+            );
+        }
+
+        // PART 2
+        // Retag return values of functions.
+        // We collect the return destinations because we cannot mutate while iterating.
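Concretely, in a hypothetical caller like the one below (not from this commit), the call's destination is a place of reference type, so the block the call returns to will start with a `Retag(Default)` of that place:

```rust
fn first(v: &mut Vec<i32>) -> &mut i32 {
    &mut v[0]
}

fn caller(v: &mut Vec<i32>) {
    // `r` is the call's return destination; PART 2 inserts
    // `Retag(Default, r)` at the top of the block the call returns to.
    let r = first(v);
    *r += 1;
}
```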
+        let returns = basic_blocks
+            .iter_mut()
+            .filter_map(|block_data| {
+                match block_data.terminator().kind {
+                    TerminatorKind::Call { target: Some(target), destination, .. }
+                        if needs_retag(&destination) =>
+                    {
+                        // Remember the return destination for later
+                        Some((block_data.terminator().source_info, destination, target))
+                    }
+
+                    // `Drop` is also a call, but it doesn't return anything so we are good.
+                    TerminatorKind::Drop { .. } | TerminatorKind::DropAndReplace { .. } => None,
+                    // Not a block ending in a Call -> ignore.
+                    _ => None,
+                }
+            })
+            .collect::<Vec<_>>();
+        // Now we go over the returns we collected to retag the return values.
+        for (source_info, dest_place, dest_block) in returns {
+            basic_blocks[dest_block].statements.insert(
+                0,
+                Statement {
+                    source_info,
+                    kind: StatementKind::Retag(RetagKind::Default, Box::new(dest_place)),
+                },
+            );
+        }
+
+        // PART 3
+        // Add retag after assignment.
+        for block_data in basic_blocks {
+            // We want to insert statements as we iterate. To this end, we
+            // iterate backwards using indices.
+            for i in (0..block_data.statements.len()).rev() {
+                let (retag_kind, place) = match block_data.statements[i].kind {
+                    // Retag-as-raw after escaping to a raw pointer, if the referent
+                    // is not already a raw pointer.
+                    StatementKind::Assign(box (lplace, Rvalue::AddressOf(_, ref rplace)))
+                        if !place_base_raw(rplace) =>
+                    {
+                        (RetagKind::Raw, lplace)
+                    }
+                    // Retag after assignments of reference type.
+                    StatementKind::Assign(box (ref place, ref rvalue)) if needs_retag(place) => {
+                        let kind = match rvalue {
+                            Rvalue::Ref(_, borrow_kind, _)
+                                if borrow_kind.allows_two_phase_borrow() =>
+                            {
+                                RetagKind::TwoPhase
+                            }
+                            _ => RetagKind::Default,
+                        };
+                        (kind, *place)
+                    }
+                    // Do nothing for the rest
+                    _ => continue,
+                };
+                // Insert a retag after the statement.
+                let source_info = block_data.statements[i].source_info;
+                block_data.statements.insert(
+                    i + 1,
+                    Statement {
+                        source_info,
+                        kind: StatementKind::Retag(retag_kind, Box::new(place)),
+                    },
+                );
+            }
+        }
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/check_const_item_mutation.rs b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs
new file mode 100644
index 000000000..8838b14c5
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs
@@ -0,0 +1,156 @@
+use rustc_errors::{DiagnosticBuilder, LintDiagnosticBuilder};
+use rustc_middle::mir::visit::Visitor;
+use rustc_middle::mir::*;
+use rustc_middle::ty::TyCtxt;
+use rustc_session::lint::builtin::CONST_ITEM_MUTATION;
+use rustc_span::def_id::DefId;
+
+use crate::MirLint;
+
+pub struct CheckConstItemMutation;
+
+impl<'tcx> MirLint<'tcx> for CheckConstItemMutation {
+    fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
+        let mut checker = ConstMutationChecker { body, tcx, target_local: None };
+        checker.visit_body(&body);
+    }
+}
+
+struct ConstMutationChecker<'a, 'tcx> {
+    body: &'a Body<'tcx>,
+    tcx: TyCtxt<'tcx>,
+    target_local: Option<Local>,
+}
+
+impl<'tcx> ConstMutationChecker<'_, 'tcx> {
+    fn is_const_item(&self, local: Local) -> Option<DefId> {
+        if let Some(box LocalInfo::ConstRef { def_id }) = self.body.local_decls[local].local_info {
+            Some(def_id)
+        } else {
+            None
+        }
+    }
+
+    fn is_const_item_without_destructor(&self, local: Local) -> Option<DefId> {
+        let def_id = self.is_const_item(local)?;
+
+        // We avoid linting mutation of a const item if the const's type has a
+        // Drop impl. The Drop logic observes the mutation which was performed.
+        //
+        // pub struct Log { msg: &'static str }
+        // pub const LOG: Log = Log { msg: "" };
+        // impl Drop for Log {
+        //     fn drop(&mut self) { println!("{}", self.msg); }
+        // }
+        //
+        // LOG.msg = "wow"; // prints "wow"
+        //
+        // FIXME(https://github.com/rust-lang/rust/issues/77425):
+        // Drop this exception once there is a stable attribute to suppress the
+        // const item mutation lint for a single specific const only. Something
+        // equivalent to:
+        //
+        // #[const_mutation_allowed]
+        // pub const LOG: Log = Log { msg: "" };
+        match self.tcx.calculate_dtor(def_id, |_, _| Ok(())) {
+            Some(_) => None,
+            None => Some(def_id),
+        }
+    }
+
+    fn lint_const_item_usage(
+        &self,
+        place: &Place<'tcx>,
+        const_item: DefId,
+        location: Location,
+        decorate: impl for<'b> FnOnce(LintDiagnosticBuilder<'b, ()>) -> DiagnosticBuilder<'b, ()>,
+    ) {
+        // Don't lint on borrowing/assigning when a dereference is involved.
+        // If we 'leave' the temporary via a dereference, we must
+        // be modifying something else
+        //
+        // `unsafe { *FOO = 0; *BAR.field = 1; }`
+        // `unsafe { &mut *FOO }`
+        // `unsafe { (*ARRAY)[0] = val; }`
+        if !place.projection.iter().any(|p| matches!(p, PlaceElem::Deref)) {
+            let source_info = self.body.source_info(location);
+            let lint_root = self.body.source_scopes[source_info.scope]
+                .local_data
+                .as_ref()
+                .assert_crate_local()
+                .lint_root;
+
+            self.tcx.struct_span_lint_hir(
+                CONST_ITEM_MUTATION,
+                lint_root,
+                source_info.span,
+                |lint| {
+                    decorate(lint)
+                        .span_note(self.tcx.def_span(const_item), "`const` item defined here")
+                        .emit();
+                },
+            );
+        }
+    }
+}
+
+impl<'tcx> Visitor<'tcx> for ConstMutationChecker<'_, 'tcx> {
+    fn visit_statement(&mut self, stmt: &Statement<'tcx>, loc: Location) {
+        if let StatementKind::Assign(box (lhs, _)) = &stmt.kind {
+            // Check for assignment to fields of a constant.
+            // Assigning directly to a constant (e.g. `FOO = true;`) is a hard error,
+            // so emitting a lint would be redundant.
+            if !lhs.projection.is_empty() {
+                if let Some(def_id) = self.is_const_item_without_destructor(lhs.local) {
+                    self.lint_const_item_usage(&lhs, def_id, loc, |lint| {
+                        let mut lint = lint.build("attempting to modify a `const` item");
+                        lint.note("each usage of a `const` item creates a new temporary; the original `const` item will not be modified");
+                        lint
+                    })
+                }
+            }
+            // We are looking for MIR of the form:
+            //
+            // ```
+            // _1 = const FOO;
+            // _2 = &mut _1;
+            // method_call(_2, ..)
+            // ```
+            //
+            // Record our current LHS, so that we can detect this
+            // pattern in `visit_rvalue`
+            self.target_local = lhs.as_local();
+        }
+        self.super_statement(stmt, loc);
+        self.target_local = None;
+    }
+    fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, loc: Location) {
+        if let Rvalue::Ref(_, BorrowKind::Mut { .. }, place) = rvalue {
+            let local = place.local;
+            if let Some(def_id) = self.is_const_item(local) {
+                // If this Rvalue is being used as the right-hand side of a
+                // `StatementKind::Assign`, see if it ends up getting used as
+                // the `self` parameter of a method call (as the terminator of our current
+                // BasicBlock). If so, we emit a more specific lint.
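At the source level, this MIR shape typically comes from a mutating method call on a `const` item. A hypothetical example that takes this more specific path:

```rust
const FOO: Vec<u32> = Vec::new();

fn main() {
    // `FOO.push(1)` copies `FOO` into a fresh temporary (`_1 = const FOO`),
    // takes `&mut _1`, and passes it as `self`, so only the temporary is
    // mutated -- which is exactly what this lint warns about.
    FOO.push(1);
}
```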
+                let method_did = self.target_local.and_then(|target_local| {
+                    crate::util::find_self_call(self.tcx, &self.body, target_local, loc.block)
+                });
+                let lint_loc =
+                    if method_did.is_some() { self.body.terminator_loc(loc.block) } else { loc };
+                self.lint_const_item_usage(place, def_id, lint_loc, |lint| {
+                    let mut lint = lint.build("taking a mutable reference to a `const` item");
+                    lint
+                        .note("each usage of a `const` item creates a new temporary")
+                        .note("the mutable reference will refer to this temporary, not the original `const` item");
+
+                    if let Some((method_did, _substs)) = method_did {
+                        lint.span_note(self.tcx.def_span(method_did), "mutable reference created due to call to this method");
+                    }
+
+                    lint
+                });
+            }
+        }
+        self.super_rvalue(rvalue, loc);
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/check_packed_ref.rs b/compiler/rustc_mir_transform/src/check_packed_ref.rs
new file mode 100644
index 000000000..3b7ba3f9a
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/check_packed_ref.rs
@@ -0,0 +1,108 @@
+use rustc_hir::def_id::LocalDefId;
+use rustc_middle::mir::visit::{PlaceContext, Visitor};
+use rustc_middle::mir::*;
+use rustc_middle::ty::query::Providers;
+use rustc_middle::ty::{self, TyCtxt};
+use rustc_session::lint::builtin::UNALIGNED_REFERENCES;
+
+use crate::util;
+use crate::MirLint;
+
+pub(crate) fn provide(providers: &mut Providers) {
+    *providers = Providers { unsafe_derive_on_repr_packed, ..*providers };
+}
+
+pub struct CheckPackedRef;
+
+impl<'tcx> MirLint<'tcx> for CheckPackedRef {
+    fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
+        let param_env = tcx.param_env(body.source.def_id());
+        let source_info = SourceInfo::outermost(body.span);
+        let mut checker = PackedRefChecker { body, tcx, param_env, source_info };
+        checker.visit_body(&body);
+    }
+}
+
+struct PackedRefChecker<'a, 'tcx> {
+    body: &'a Body<'tcx>,
+    tcx: TyCtxt<'tcx>,
+    param_env: ty::ParamEnv<'tcx>,
+    source_info: SourceInfo,
+}
+
+fn unsafe_derive_on_repr_packed(tcx: TyCtxt<'_>, def_id: LocalDefId) {
+    let lint_hir_id = tcx.hir().local_def_id_to_hir_id(def_id);
+
+    tcx.struct_span_lint_hir(UNALIGNED_REFERENCES, lint_hir_id, tcx.def_span(def_id), |lint| {
+        // FIXME: when we make this a hard error, this should have its
+        // own error code.
+        let extra = if tcx.generics_of(def_id).own_requires_monomorphization() {
+            "with type or const parameters"
+        } else {
+            "that does not derive `Copy`"
+        };
+        let message = format!(
+            "`{}` can't be derived on this `#[repr(packed)]` struct {}",
+            tcx.item_name(tcx.trait_id_of_impl(def_id.to_def_id()).expect("derived trait name")),
+            extra
+        );
+        lint.build(message).emit();
+    });
+}
+
+impl<'tcx> Visitor<'tcx> for PackedRefChecker<'_, 'tcx> {
+    fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
+        // Make sure we know where in the MIR we are.
+        self.source_info = terminator.source_info;
+        self.super_terminator(terminator, location);
+    }
+
+    fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
+        // Make sure we know where in the MIR we are.
+        self.source_info = statement.source_info;
+        self.super_statement(statement, location);
+    }
+
+    fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) {
+        if context.is_borrow() {
+            if util::is_disaligned(self.tcx, self.body, self.param_env, *place) {
+                let def_id = self.body.source.instance.def_id();
+                if let Some(impl_def_id) = self
+                    .tcx
+                    .impl_of_method(def_id)
+                    .filter(|&def_id| self.tcx.is_builtin_derive(def_id))
+                {
+                    // If a method is defined in the local crate,
+                    // the impl containing that method should also be.
+                    self.tcx.ensure().unsafe_derive_on_repr_packed(impl_def_id.expect_local());
+                } else {
+                    let source_info = self.source_info;
+                    let lint_root = self.body.source_scopes[source_info.scope]
+                        .local_data
+                        .as_ref()
+                        .assert_crate_local()
+                        .lint_root;
+                    self.tcx.struct_span_lint_hir(
+                        UNALIGNED_REFERENCES,
+                        lint_root,
+                        source_info.span,
+                        |lint| {
+                            lint.build("reference to packed field is unaligned")
+                                .note(
+                                    "fields of packed structs are not properly aligned, and creating \
+                                    a misaligned reference is undefined behavior (even if that \
+                                    reference is never dereferenced)",
+                                )
+                                .help(
+                                    "copy the field contents to a local variable, or replace the \
                                    reference with a raw pointer and use `read_unaligned`/`write_unaligned` \
                                    (loads and stores via `*p` must be properly aligned even when using raw pointers)"
+                                )
+                                .emit();
+                        },
+                    );
+                }
+            }
+        }
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/check_unsafety.rs b/compiler/rustc_mir_transform/src/check_unsafety.rs
new file mode 100644
index 000000000..d564f4801
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/check_unsafety.rs
@@ -0,0 +1,619 @@
+use rustc_data_structures::fx::FxHashMap;
+use rustc_errors::struct_span_err;
+use rustc_hir as hir;
+use rustc_hir::def_id::{DefId, LocalDefId};
+use rustc_hir::hir_id::HirId;
+use rustc_hir::intravisit;
+use rustc_middle::mir::visit::{MutatingUseContext, PlaceContext, Visitor};
+use rustc_middle::ty::query::Providers;
+use rustc_middle::ty::{self, TyCtxt};
+use rustc_middle::{lint, mir::*};
+use rustc_session::lint::builtin::{UNSAFE_OP_IN_UNSAFE_FN, UNUSED_UNSAFE};
+use rustc_session::lint::Level;
+
+use std::collections::hash_map;
+use std::ops::Bound;
+
+pub struct UnsafetyChecker<'a, 'tcx> {
+    body: &'a Body<'tcx>,
+    body_did: LocalDefId,
+    violations: Vec<UnsafetyViolation>,
+    source_info: SourceInfo,
+    tcx: TyCtxt<'tcx>,
+    param_env: ty::ParamEnv<'tcx>,
+
+    /// Used `unsafe` blocks in this function. This is used for the "unused_unsafe" lint.
+    ///
+    /// The keys are the used `unsafe` blocks; the `UsedUnsafeBlockData` values indicate
+    /// whether or not any of the usages happen at a place that doesn't allow
+    /// `unsafe_op_in_unsafe_fn`.
+    used_unsafe_blocks: FxHashMap<HirId, UsedUnsafeBlockData>,
+}
+
+impl<'a, 'tcx> UnsafetyChecker<'a, 'tcx> {
+    fn new(
+        body: &'a Body<'tcx>,
+        body_did: LocalDefId,
+        tcx: TyCtxt<'tcx>,
+        param_env: ty::ParamEnv<'tcx>,
+    ) -> Self {
+        Self {
+            body,
+            body_did,
+            violations: vec![],
+            source_info: SourceInfo::outermost(body.span),
+            tcx,
+            param_env,
+            used_unsafe_blocks: Default::default(),
+        }
+    }
+}
+
+impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> {
+    fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
+        self.source_info = terminator.source_info;
+        match terminator.kind {
+            TerminatorKind::Goto { .. }
+            | TerminatorKind::SwitchInt { .. }
+            | TerminatorKind::Drop { .. }
+            | TerminatorKind::Yield { .. }
+            | TerminatorKind::Assert { .. }
+            | TerminatorKind::DropAndReplace { .. }
+            | TerminatorKind::GeneratorDrop
+            | TerminatorKind::Resume
+            | TerminatorKind::Abort
+            | TerminatorKind::Return
+            | TerminatorKind::Unreachable
+            | TerminatorKind::FalseEdge { .. }
+            | TerminatorKind::FalseUnwind { .. } => {
+                // safe (at least as emitted during MIR construction)
+            }
+
+            TerminatorKind::Call { ref func, .. } => {
+                let func_ty = func.ty(self.body, self.tcx);
+                let func_id =
+                    if let ty::FnDef(func_id, _) = func_ty.kind() { Some(func_id) } else { None };
+                let sig = func_ty.fn_sig(self.tcx);
+                if let hir::Unsafety::Unsafe = sig.unsafety() {
+                    self.require_unsafe(
+                        UnsafetyViolationKind::General,
+                        UnsafetyViolationDetails::CallToUnsafeFunction,
+                    )
+                }
+
+                if let Some(func_id) = func_id {
+                    self.check_target_features(*func_id);
+                }
+            }
+
+            TerminatorKind::InlineAsm { .. } => self.require_unsafe(
+                UnsafetyViolationKind::General,
+                UnsafetyViolationDetails::UseOfInlineAssembly,
+            ),
+        }
+        self.super_terminator(terminator, location);
+    }
+
+    fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
+        self.source_info = statement.source_info;
+        match statement.kind {
+            StatementKind::Assign(..)
+            | StatementKind::FakeRead(..)
+            | StatementKind::SetDiscriminant { .. }
+            | StatementKind::Deinit(..)
+            | StatementKind::StorageLive(..)
+            | StatementKind::StorageDead(..)
+            | StatementKind::Retag { .. }
+            | StatementKind::AscribeUserType(..)
+            | StatementKind::Coverage(..)
+            | StatementKind::Nop => {
+                // safe (at least as emitted during MIR construction)
+            }
+
+            StatementKind::CopyNonOverlapping(..) => unreachable!(),
+        }
+        self.super_statement(statement, location);
+    }
+
+    fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
+        match rvalue {
+            Rvalue::Aggregate(box ref aggregate, _) => match aggregate {
+                &AggregateKind::Array(..) | &AggregateKind::Tuple => {}
+                &AggregateKind::Adt(adt_did, ..) => {
+                    match self.tcx.layout_scalar_valid_range(adt_did) {
+                        (Bound::Unbounded, Bound::Unbounded) => {}
+                        _ => self.require_unsafe(
+                            UnsafetyViolationKind::General,
+                            UnsafetyViolationDetails::InitializingTypeWith,
+                        ),
+                    }
+                }
+                &AggregateKind::Closure(def_id, _) | &AggregateKind::Generator(def_id, _, _) => {
+                    let UnsafetyCheckResult { violations, used_unsafe_blocks, .. } =
+                        self.tcx.unsafety_check_result(def_id);
+                    self.register_violations(
+                        violations,
+                        used_unsafe_blocks.iter().map(|(&h, &d)| (h, d)),
+                    );
+                }
+            },
+            _ => {}
+        }
+        self.super_rvalue(rvalue, location);
+    }
+
+    fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) {
+        // On types with `scalar_valid_range`, prevent
+        // * `&mut x.field`
+        // * `x.field = y;`
+        // * `&x.field` if `field`'s type has interior mutability
+        // because either of these would allow modifying the layout constrained field and
+        // insert values that violate the layout constraints.
+        if context.is_mutating_use() || context.is_borrow() {
+            self.check_mut_borrowing_layout_constrained_field(*place, context.is_mutating_use());
+        }
+
+        // Some checks below need the extra meta info of the local declaration.
+        let decl = &self.body.local_decls[place.local];
+
+        // Check the base local: it might be an unsafe-to-access static. We only check derefs of the
+        // temporary holding the static pointer to avoid duplicate errors
+        // <https://github.com/rust-lang/rust/pull/78068#issuecomment-731753506>.
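For context, a hypothetical source pattern that reaches this code path: accesses to a `static mut` are desugared to go through such a static-pointer temporary, and outside an `unsafe` block the access below is what gets reported:

```rust
static mut COUNTER: u32 = 0;

fn bump() {
    // Without the `unsafe` block this access would be flagged as
    // `UseOfMutableStatic` by the check that follows.
    unsafe { COUNTER += 1; }
}
```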
+        if decl.internal && place.projection.first() == Some(&ProjectionElem::Deref) {
+            // If the projection root is an artificial local that we introduced when
+            // desugaring `static`, give a more specific error message
+            // (avoid the general "raw pointer" clause below, that would only be confusing).
+            if let Some(box LocalInfo::StaticRef { def_id, .. }) = decl.local_info {
+                if self.tcx.is_mutable_static(def_id) {
+                    self.require_unsafe(
+                        UnsafetyViolationKind::General,
+                        UnsafetyViolationDetails::UseOfMutableStatic,
+                    );
+                    return;
+                } else if self.tcx.is_foreign_item(def_id) {
+                    self.require_unsafe(
+                        UnsafetyViolationKind::General,
+                        UnsafetyViolationDetails::UseOfExternStatic,
+                    );
+                    return;
+                }
+            }
+        }
+
+        // Check for raw pointer `Deref`.
+        for (base, proj) in place.iter_projections() {
+            if proj == ProjectionElem::Deref {
+                let base_ty = base.ty(self.body, self.tcx).ty;
+                if base_ty.is_unsafe_ptr() {
+                    self.require_unsafe(
+                        UnsafetyViolationKind::General,
+                        UnsafetyViolationDetails::DerefOfRawPointer,
+                    )
+                }
+            }
+        }
+
+        // Check for union fields. For this we traverse right-to-left, as the last `Deref` changes
+        // whether we *read* the union field or potentially *write* to it (if this place is being assigned to).
+        let mut saw_deref = false;
+        for (base, proj) in place.iter_projections().rev() {
+            if proj == ProjectionElem::Deref {
+                saw_deref = true;
+                continue;
+            }
+
+            let base_ty = base.ty(self.body, self.tcx).ty;
+            if base_ty.is_union() {
+                // If we did not hit a `Deref` yet and the overall place use is an assignment, the
+                // rules are different.
+                let assign_to_field = !saw_deref
+                    && matches!(
+                        context,
+                        PlaceContext::MutatingUse(
+                            MutatingUseContext::Store
+                                | MutatingUseContext::Drop
+                                | MutatingUseContext::AsmOutput
+                        )
+                    );
+                // If this is just an assignment, determine if the assigned type needs dropping.
+                if assign_to_field {
+                    // We have to check the actual type of the assignment, as that determines if the
+                    // old value is being dropped.
+                    let assigned_ty = place.ty(&self.body.local_decls, self.tcx).ty;
+                    if assigned_ty.needs_drop(self.tcx, self.param_env) {
+                        // This would be unsafe, but should be outright impossible since we reject such unions.
+                        self.tcx.sess.delay_span_bug(
+                            self.source_info.span,
+                            format!("union fields that need dropping should be impossible: {assigned_ty}")
+                        );
+                    }
+                } else {
+                    self.require_unsafe(
+                        UnsafetyViolationKind::General,
+                        UnsafetyViolationDetails::AccessToUnionField,
+                    )
+                }
+            }
+        }
+    }
+}
+
+impl<'tcx> UnsafetyChecker<'_, 'tcx> {
+    fn require_unsafe(&mut self, kind: UnsafetyViolationKind, details: UnsafetyViolationDetails) {
+        // Violations can turn out to be `UnsafeFn` during analysis, but they should not start out as such.
+        assert_ne!(kind, UnsafetyViolationKind::UnsafeFn);
+
+        let source_info = self.source_info;
+        let lint_root = self.body.source_scopes[self.source_info.scope]
+            .local_data
+            .as_ref()
+            .assert_crate_local()
+            .lint_root;
+        self.register_violations(
+            [&UnsafetyViolation { source_info, lint_root, kind, details }],
+            [],
+        );
+    }
+
+    fn register_violations<'a>(
+        &mut self,
+        violations: impl IntoIterator<Item = &'a UnsafetyViolation>,
+        new_used_unsafe_blocks: impl IntoIterator<Item = (HirId, UsedUnsafeBlockData)>,
+    ) {
+        use UsedUnsafeBlockData::{AllAllowedInUnsafeFn, SomeDisallowedInUnsafeFn};
+
+        let update_entry = |this: &mut Self, hir_id, new_usage| {
+            match this.used_unsafe_blocks.entry(hir_id) {
+                hash_map::Entry::Occupied(mut entry) => {
+                    if new_usage == SomeDisallowedInUnsafeFn {
+                        *entry.get_mut() = SomeDisallowedInUnsafeFn;
+                    }
+                }
+                hash_map::Entry::Vacant(entry) => {
+                    entry.insert(new_usage);
+                }
+            };
+        };
+        let safety = self.body.source_scopes[self.source_info.scope]
+            .local_data
+            .as_ref()
+            .assert_crate_local()
+            .safety;
+        match safety {
+            // `unsafe` blocks are required in safe code
+            Safety::Safe => violations.into_iter().for_each(|&violation| {
+                match violation.kind {
+                    UnsafetyViolationKind::General => {}
+                    UnsafetyViolationKind::UnsafeFn => {
+                        bug!("`UnsafetyViolationKind::UnsafeFn` in a `Safe` context")
+                    }
+                }
+                if !self.violations.contains(&violation) {
+                    self.violations.push(violation)
+                }
+            }),
+            // Per RFC 2585, `unsafe` operations are no longer implicitly allowed in `unsafe fn`s
+            Safety::FnUnsafe => violations.into_iter().for_each(|&(mut violation)| {
+                violation.kind = UnsafetyViolationKind::UnsafeFn;
+                if !self.violations.contains(&violation) {
+                    self.violations.push(violation)
+                }
+            }),
+            Safety::BuiltinUnsafe => {}
+            Safety::ExplicitUnsafe(hir_id) => violations.into_iter().for_each(|violation| {
+                update_entry(
+                    self,
+                    hir_id,
+                    match self.tcx.lint_level_at_node(UNSAFE_OP_IN_UNSAFE_FN, violation.lint_root).0
+                    {
+                        Level::Allow => AllAllowedInUnsafeFn(violation.lint_root),
+                        _ => SomeDisallowedInUnsafeFn,
+                    },
+                )
+            }),
+        };
+
+        new_used_unsafe_blocks
+            .into_iter()
+            .for_each(|(hir_id, usage_data)| update_entry(self, hir_id, usage_data));
+    }
+
+    fn check_mut_borrowing_layout_constrained_field(
+        &mut self,
+        place: Place<'tcx>,
+        is_mut_use: bool,
+    ) {
+        for (place_base, elem) in place.iter_projections().rev() {
+            match elem {
+                // Modifications behind a dereference don't affect the value of
+                // the pointer.
+                ProjectionElem::Deref => return,
+                ProjectionElem::Field(..) => {
+                    let ty = place_base.ty(&self.body.local_decls, self.tcx).ty;
+                    if let ty::Adt(def, _) = ty.kind() {
+                        if self.tcx.layout_scalar_valid_range(def.did())
+                            != (Bound::Unbounded, Bound::Unbounded)
+                        {
+                            let details = if is_mut_use {
+                                UnsafetyViolationDetails::MutationOfLayoutConstrainedField
+
+                            // Check `is_freeze` as late as possible to avoid cycle errors
+                            // with opaque types.
+                            } else if !place
+                                .ty(self.body, self.tcx)
+                                .ty
+                                .is_freeze(self.tcx.at(self.source_info.span), self.param_env)
+                            {
+                                UnsafetyViolationDetails::BorrowOfLayoutConstrainedField
+                            } else {
+                                continue;
+                            };
+                            self.require_unsafe(UnsafetyViolationKind::General, details);
+                        }
+                    }
+                }
+                _ => {}
+            }
+        }
+    }
+
+    /// Checks whether calling `func_did` needs an `unsafe` context or not, i.e. whether
+    /// the called function has target features the calling function hasn't.
+    fn check_target_features(&mut self, func_did: DefId) {
+        // Unsafety isn't required on wasm targets. For more information see
+        // the corresponding check in typeck/src/collect.rs
+        if self.tcx.sess.target.options.is_like_wasm {
+            return;
+        }
+
+        let callee_features = &self.tcx.codegen_fn_attrs(func_did).target_features;
+        // The body might be a constant, so it doesn't have codegen attributes.
+        let self_features = &self.tcx.body_codegen_attrs(self.body_did.to_def_id()).target_features;
+
+        // Is `callee_features` a subset of `self_features`?
+        if !callee_features.iter().all(|feature| self_features.contains(feature)) {
+            self.require_unsafe(
+                UnsafetyViolationKind::General,
+                UnsafetyViolationDetails::CallToFunctionWith,
+            )
+        }
+    }
+}
+
+pub(crate) fn provide(providers: &mut Providers) {
+    *providers = Providers {
+        unsafety_check_result: |tcx, def_id| {
+            if let Some(def) = ty::WithOptConstParam::try_lookup(def_id, tcx) {
+                tcx.unsafety_check_result_for_const_arg(def)
+            } else {
+                unsafety_check_result(tcx, ty::WithOptConstParam::unknown(def_id))
+            }
+        },
+        unsafety_check_result_for_const_arg: |tcx, (did, param_did)| {
+            unsafety_check_result(
+                tcx,
+                ty::WithOptConstParam { did, const_param_did: Some(param_did) },
+            )
+        },
+        ..*providers
+    };
+}
+
+/// Context information for [`UnusedUnsafeVisitor`] traversal,
+/// saving the innermost relevant context.
+#[derive(Copy, Clone, Debug)]
+enum Context {
+    Safe,
+    /// in an `unsafe fn`
+    UnsafeFn(HirId),
+    /// in a *used* `unsafe` block
+    /// (i.e. a block without unused-unsafe warning)
+    UnsafeBlock(HirId),
+}
+
+struct UnusedUnsafeVisitor<'a, 'tcx> {
+    tcx: TyCtxt<'tcx>,
+    used_unsafe_blocks: &'a FxHashMap<HirId, UsedUnsafeBlockData>,
+    context: Context,
+    unused_unsafes: &'a mut Vec<(HirId, UnusedUnsafe)>,
+}
+
+impl<'tcx> intravisit::Visitor<'tcx> for UnusedUnsafeVisitor<'_, 'tcx> {
+    fn visit_block(&mut self, block: &'tcx hir::Block<'tcx>) {
+        use UsedUnsafeBlockData::{AllAllowedInUnsafeFn, SomeDisallowedInUnsafeFn};
+
+        if let hir::BlockCheckMode::UnsafeBlock(hir::UnsafeSource::UserProvided) = block.rules {
+            let used = match self.tcx.lint_level_at_node(UNUSED_UNSAFE, block.hir_id) {
+                (Level::Allow, _) => Some(SomeDisallowedInUnsafeFn),
+                _ => self.used_unsafe_blocks.get(&block.hir_id).copied(),
+            };
+            let unused_unsafe = match (self.context, used) {
+                (_, None) => UnusedUnsafe::Unused,
+                (Context::Safe, Some(_))
+                | (Context::UnsafeFn(_), Some(SomeDisallowedInUnsafeFn)) => {
+                    let previous_context = self.context;
+                    self.context = Context::UnsafeBlock(block.hir_id);
+                    intravisit::walk_block(self, block);
+                    self.context = previous_context;
+                    return;
+                }
+                (Context::UnsafeFn(hir_id), Some(AllAllowedInUnsafeFn(lint_root))) => {
+                    UnusedUnsafe::InUnsafeFn(hir_id, lint_root)
+                }
+                (Context::UnsafeBlock(hir_id), Some(_)) => UnusedUnsafe::InUnsafeBlock(hir_id),
+            };
+            self.unused_unsafes.push((block.hir_id, unused_unsafe));
+        }
+        intravisit::walk_block(self, block);
+    }
+
+    fn visit_fn(
+        &mut self,
+        fk: intravisit::FnKind<'tcx>,
+        _fd: &'tcx hir::FnDecl<'tcx>,
+        b: hir::BodyId,
+        _s: rustc_span::Span,
+        _id: HirId,
+    ) {
+        if matches!(fk, intravisit::FnKind::Closure) {
+            self.visit_body(self.tcx.hir().body(b))
+        }
+    }
+}
+
+fn check_unused_unsafe(
+    tcx: TyCtxt<'_>,
+    def_id: LocalDefId,
+    used_unsafe_blocks: &FxHashMap<HirId, UsedUnsafeBlockData>,
+) -> Vec<(HirId, UnusedUnsafe)> {
+    let body_id = tcx.hir().maybe_body_owned_by(def_id);
+
+    let Some(body_id) = body_id else {
+        debug!("check_unused_unsafe({:?}) - no body found", def_id);
+        return vec![];
+    };
+
+    let body = tcx.hir().body(body_id);
+    let hir_id = tcx.hir().local_def_id_to_hir_id(def_id);
+    let context = match tcx.hir().fn_sig_by_hir_id(hir_id) {
+        Some(sig) if sig.header.unsafety == hir::Unsafety::Unsafe => Context::UnsafeFn(hir_id),
+        _ => Context::Safe,
+    };
+
+    debug!(
+        "check_unused_unsafe({:?}, context={:?}, body={:?}, used_unsafe_blocks={:?})",
+        def_id, body, context, used_unsafe_blocks
+    );
+
+    let mut unused_unsafes = vec![];
+
+    let mut visitor = UnusedUnsafeVisitor {
+        tcx,
+        used_unsafe_blocks,
+        context,
+        unused_unsafes: &mut unused_unsafes,
+    };
+    intravisit::Visitor::visit_body(&mut visitor, body);
+
+    unused_unsafes
+}
+
+fn unsafety_check_result<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    def: ty::WithOptConstParam<LocalDefId>,
+) -> &'tcx UnsafetyCheckResult {
+    debug!("unsafety_violations({:?})", def);
+
+    // N.B., this borrow is valid because all the consumers of
+    // `mir_built` force this.
+    let body = &tcx.mir_built(def).borrow();
+
+    let param_env = tcx.param_env(def.did);
+
+    let mut checker = UnsafetyChecker::new(body, def.did, tcx, param_env);
+    checker.visit_body(&body);
+
+    let unused_unsafes = (!tcx.is_closure(def.did.to_def_id()))
+        .then(|| check_unused_unsafe(tcx, def.did, &checker.used_unsafe_blocks));
+
+    tcx.arena.alloc(UnsafetyCheckResult {
+        violations: checker.violations,
+        used_unsafe_blocks: checker.used_unsafe_blocks,
+        unused_unsafes,
+    })
+}
+
+fn report_unused_unsafe(tcx: TyCtxt<'_>, kind: UnusedUnsafe, id: HirId) {
+    let span = tcx.sess.source_map().guess_head_span(tcx.hir().span(id));
+    tcx.struct_span_lint_hir(UNUSED_UNSAFE, id, span, |lint| {
+        let msg = "unnecessary `unsafe` block";
+        let mut db = lint.build(msg);
+        db.span_label(span, msg);
+        match kind {
+            UnusedUnsafe::Unused => {}
+            UnusedUnsafe::InUnsafeBlock(id) => {
+                db.span_label(
+                    tcx.sess.source_map().guess_head_span(tcx.hir().span(id)),
+                    "because it's nested under this `unsafe` block",
+                );
+            }
+            UnusedUnsafe::InUnsafeFn(id, usage_lint_root) => {
+                db.span_label(
+                    tcx.sess.source_map().guess_head_span(tcx.hir().span(id)),
+                    "because it's nested under this `unsafe` fn",
+                )
+                .note(
+                    "this `unsafe` block does contain unsafe operations, \
+                    but those are already allowed in an `unsafe fn`",
+                );
+                let (level, source) =
+                    tcx.lint_level_at_node(UNSAFE_OP_IN_UNSAFE_FN, usage_lint_root);
+                assert_eq!(level, Level::Allow);
+                lint::explain_lint_level_source(
+                    UNSAFE_OP_IN_UNSAFE_FN,
+                    Level::Allow,
+                    source,
+                    &mut db,
+                );
+            }
+        }
+
+        db.emit();
+    });
+}
+
+pub fn check_unsafety(tcx: TyCtxt<'_>, def_id: LocalDefId) {
+    debug!("check_unsafety({:?})", def_id);
+
+    // closures are handled by their parent fn.
+    if tcx.is_closure(def_id.to_def_id()) {
+        return;
+    }
+
+    let UnsafetyCheckResult { violations, unused_unsafes, .. } = tcx.unsafety_check_result(def_id);
+
+    for &UnsafetyViolation { source_info, lint_root, kind, details } in violations.iter() {
+        let (description, note) = details.description_and_note();
+
+        // Report an error.
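For reference, the hard-error path below renders as the familiar E0133 diagnostic. A hypothetical snippet (intentionally not compiling) and the message it produces in safe code:

```rust
fn main() {
    let x = 0u8;
    let p = &x as *const u8;
    let _y = *p; // error[E0133]: dereference of raw pointer is unsafe and
                 // requires unsafe function or block
}
```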
+        let unsafe_fn_msg =
+            if unsafe_op_in_unsafe_fn_allowed(tcx, lint_root) { " function or" } else { "" };
+
+        match kind {
+            UnsafetyViolationKind::General => {
+                struct_span_err!(
+                    tcx.sess,
+                    source_info.span,
+                    E0133,
+                    "{} is unsafe and requires unsafe{} block",
+                    description,
+                    unsafe_fn_msg,
+                )
+                .span_label(source_info.span, description)
+                .note(note)
+                .emit();
+            }
+            UnsafetyViolationKind::UnsafeFn => tcx.struct_span_lint_hir(
+                UNSAFE_OP_IN_UNSAFE_FN,
+                lint_root,
+                source_info.span,
+                |lint| {
+                    lint.build(&format!(
+                        "{} is unsafe and requires unsafe block (error E0133)",
+                        description,
+                    ))
+                    .span_label(source_info.span, description)
+                    .note(note)
+                    .emit();
+                },
+            ),
+        }
+    }
+
+    for &(block_id, kind) in unused_unsafes.as_ref().unwrap() {
+        report_unused_unsafe(tcx, kind, block_id);
+    }
+}
+
+fn unsafe_op_in_unsafe_fn_allowed(tcx: TyCtxt<'_>, id: HirId) -> bool {
+    tcx.lint_level_at_node(UNSAFE_OP_IN_UNSAFE_FN, id).0 == Level::Allow
+}
diff --git a/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs b/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs
new file mode 100644
index 000000000..611d29a4e
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/cleanup_post_borrowck.rs
@@ -0,0 +1,59 @@
+//! This module provides a pass that replaces the following statements with
+//! [`Nop`]s:
+//!
+//! - [`AscribeUserType`]
+//! - [`FakeRead`]
+//! - [`Assign`] statements with a [`Shallow`] borrow
+//!
+//! The `CleanupNonCodegenStatements` pass is implemented as a single
+//! `DeleteNonCodegenStatements` traversal (aka visit) of the input MIR,
+//! which nops out each of these statements.
+//!
+//! [`AscribeUserType`]: rustc_middle::mir::StatementKind::AscribeUserType
+//! [`Shallow`]: rustc_middle::mir::BorrowKind::Shallow
+//! [`FakeRead`]: rustc_middle::mir::StatementKind::FakeRead
+//! [`Assign`]: rustc_middle::mir::StatementKind::Assign
+//! [`Nop`]: rustc_middle::mir::StatementKind::Nop
+
+use crate::MirPass;
+use rustc_middle::mir::visit::MutVisitor;
+use rustc_middle::mir::{Body, BorrowKind, Location, Rvalue};
+use rustc_middle::mir::{Statement, StatementKind};
+use rustc_middle::ty::TyCtxt;
+
+pub struct CleanupNonCodegenStatements;
+
+pub struct DeleteNonCodegenStatements<'tcx> {
+    tcx: TyCtxt<'tcx>,
+}
+
+impl<'tcx> MirPass<'tcx> for CleanupNonCodegenStatements {
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        let mut delete = DeleteNonCodegenStatements { tcx };
+        delete.visit_body(body);
+        body.user_type_annotations.raw.clear();
+
+        for decl in &mut body.local_decls {
+            decl.user_ty = None;
+        }
+    }
+}
+
+impl<'tcx> MutVisitor<'tcx> for DeleteNonCodegenStatements<'tcx> {
+    fn tcx(&self) -> TyCtxt<'tcx> {
+        self.tcx
+    }
+
+    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
+        match statement.kind {
+            StatementKind::AscribeUserType(..)
+            | StatementKind::Assign(box (_, Rvalue::Ref(_, BorrowKind::Shallow, _)))
+            | StatementKind::FakeRead(..) => statement.make_nop(),
+            _ => (),
+        }
+        self.super_statement(statement, location);
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/const_debuginfo.rs b/compiler/rustc_mir_transform/src/const_debuginfo.rs
new file mode 100644
index 000000000..6f0ae4f07
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/const_debuginfo.rs
@@ -0,0 +1,100 @@
+//! Finds locals which are assigned once to a const and unused except for debuginfo and converts
+//! their debuginfo to use the const directly, allowing the local to be removed.
+
+use rustc_middle::{
+    mir::{
+        visit::{PlaceContext, Visitor},
+        Body, Constant, Local, Location, Operand, Rvalue, StatementKind, VarDebugInfoContents,
+    },
+    ty::TyCtxt,
+};
+
+use crate::MirPass;
+use rustc_index::{bit_set::BitSet, vec::IndexVec};
+
+pub struct ConstDebugInfo;
+
+impl<'tcx> MirPass<'tcx> for ConstDebugInfo {
+    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+        sess.opts.unstable_opts.unsound_mir_opts && sess.mir_opt_level() > 0
+    }
+
+    fn run_pass(&self, _tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        trace!("running ConstDebugInfo on {:?}", body.source);
+
+        for (local, constant) in find_optimization_opportunities(body) {
+            for debuginfo in &mut body.var_debug_info {
+                if let VarDebugInfoContents::Place(p) = debuginfo.value {
+                    if p.local == local && p.projection.is_empty() {
+                        trace!(
+                            "changing debug info for {:?} from place {:?} to constant {:?}",
+                            debuginfo.name,
+                            p,
+                            constant
+                        );
+                        debuginfo.value = VarDebugInfoContents::Const(constant);
+                    }
+                }
+            }
+        }
+    }
+}
+
+struct LocalUseVisitor {
+    local_mutating_uses: IndexVec<Local, u8>,
+    local_assignment_locations: IndexVec<Local, Option<Location>>,
+}
+
+fn find_optimization_opportunities<'tcx>(body: &Body<'tcx>) -> Vec<(Local, Constant<'tcx>)> {
+    let mut visitor = LocalUseVisitor {
+        local_mutating_uses: IndexVec::from_elem(0, &body.local_decls),
+        local_assignment_locations: IndexVec::from_elem(None, &body.local_decls),
+    };
+
+    visitor.visit_body(body);
+
+    let mut locals_to_debuginfo = BitSet::new_empty(body.local_decls.len());
+    for debuginfo in &body.var_debug_info {
+        if let VarDebugInfoContents::Place(p) = debuginfo.value && let Some(l) = p.as_local() {
+            locals_to_debuginfo.insert(l);
+        }
+    }
+
+    let mut eligible_locals = Vec::new();
+    for (local, mutating_uses) in visitor.local_mutating_uses.drain_enumerated(..) {
+        if mutating_uses != 1 || !locals_to_debuginfo.contains(local) {
+            continue;
+        }
+
+        if let Some(location) = visitor.local_assignment_locations[local] {
+            let bb = &body[location.block];
+
+            // The value is assigned as the result of a call, not a constant
+            if bb.statements.len() == location.statement_index {
+                continue;
+            }
+
+            if let StatementKind::Assign(box (p, Rvalue::Use(Operand::Constant(box c)))) =
+                &bb.statements[location.statement_index].kind
+            {
+                if let Some(local) = p.as_local() {
+                    eligible_locals.push((local, *c));
+                }
+            }
+        }
+    }
+
+    eligible_locals
+}
+
+impl Visitor<'_> for LocalUseVisitor {
+    fn visit_local(&mut self, local: Local, context: PlaceContext, location: Location) {
+        if context.is_mutating_use() {
+            self.local_mutating_uses[local] = self.local_mutating_uses[local].saturating_add(1);
+
+            if context.is_place_assignment() {
+                self.local_assignment_locations[local] = Some(location);
+            }
+        }
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/const_goto.rs b/compiler/rustc_mir_transform/src/const_goto.rs
new file mode 100644
index 000000000..5acf939f0
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/const_goto.rs
@@ -0,0 +1,117 @@
+//! This pass optimizes the following sequence
+//! ```rust,ignore (example)
+//! bb2: {
+//!     _2 = const true;
+//!     goto -> bb3;
+//! }
+//!
+//! bb3: {
+//!     switchInt(_2) -> [false: bb4, otherwise: bb5];
+//! }
+//! ```
+//! into
+//! ```rust,ignore (example)
+//! bb2: {
+//!     _2 = const true;
+//!     goto -> bb5;
+//! }
+//! ```
+
+use crate::MirPass;
+use rustc_middle::mir::*;
+use rustc_middle::ty::TyCtxt;
+use rustc_middle::{mir::visit::Visitor, ty::ParamEnv};
+
+use super::simplify::{simplify_cfg, simplify_locals};
+
+pub struct ConstGoto;
+
+impl<'tcx> MirPass<'tcx> for ConstGoto {
+    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+        sess.mir_opt_level() >= 4
+    }
+
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        trace!("Running ConstGoto on {:?}", body.source);
+        let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
+        let mut opt_finder =
+            ConstGotoOptimizationFinder { tcx, body, optimizations: vec![], param_env };
+        opt_finder.visit_body(body);
+        let should_simplify = !opt_finder.optimizations.is_empty();
+        for opt in opt_finder.optimizations {
+            let block = &mut body.basic_blocks_mut()[opt.bb_with_goto];
+            block.statements.extend(opt.stmts_move_up);
+            let terminator = block.terminator_mut();
+            let new_goto = TerminatorKind::Goto { target: opt.target_to_use_in_goto };
+            debug!("SUCCESS: replacing `{:?}` with `{:?}`", terminator.kind, new_goto);
+            terminator.kind = new_goto;
+        }
+
+        // If we applied optimizations, we potentially have some CFG cleanup to do,
+        // which makes things easier for further passes.
+        if should_simplify {
+            simplify_cfg(tcx, body);
+            simplify_locals(body, tcx);
+        }
+    }
+}
+
+impl<'tcx> Visitor<'tcx> for ConstGotoOptimizationFinder<'_, 'tcx> {
+    fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
+        let _: Option<_> = try {
+            let target = terminator.kind.as_goto()?;
+            // We only apply this optimization if the last statement is a const assignment
+            let last_statement = self.body.basic_blocks()[location.block].statements.last()?;
+
+            if let (place, Rvalue::Use(Operand::Constant(_const))) =
+                last_statement.kind.as_assign()?
+            {
+                // We found a constant being assigned to `place`.
+                // Now check that the target of this Goto switches on this place.
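A hypothetical source shape that produces this `goto`-into-`switchInt` pattern: short-circuiting `||`/`&&` lowers each arm to a block that assigns `const true`/`const false` and jumps to a join block that immediately switches on the temporary:

```rust
fn demo(n: u32) -> u32 {
    // Each `||` arm assigns a constant bool and jumps to the join block that
    // switches on it; ConstGoto retargets those jumps straight at the
    // matching switch target.
    let outside = n < 10 || n > 100;
    if outside { 1 } else { 0 }
}
```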
+ let target_bb = &self.body.basic_blocks()[target]; + + // The `StorageDead(..)` statement does not affect the functionality of mir. + // We can move this part of the statement up to the predecessor. + let mut stmts_move_up = Vec::new(); + for stmt in &target_bb.statements { + if let StatementKind::StorageDead(..) = stmt.kind { + stmts_move_up.push(stmt.clone()) + } else { + None?; + } + } + + let target_bb_terminator = target_bb.terminator(); + let (discr, switch_ty, targets) = target_bb_terminator.kind.as_switch()?; + if discr.place() == Some(*place) { + // We now know that the Switch matches on the const place, and it is statementless + // Now find which value in the Switch matches the const value. + let const_value = + _const.literal.try_eval_bits(self.tcx, self.param_env, switch_ty)?; + let target_to_use_in_goto = targets.target_for_value(const_value); + self.optimizations.push(OptimizationToApply { + bb_with_goto: location.block, + target_to_use_in_goto, + stmts_move_up, + }); + } + } + Some(()) + }; + + self.super_terminator(terminator, location); + } +} + +struct OptimizationToApply<'tcx> { + bb_with_goto: BasicBlock, + target_to_use_in_goto: BasicBlock, + stmts_move_up: Vec<Statement<'tcx>>, +} + +pub struct ConstGotoOptimizationFinder<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + body: &'a Body<'tcx>, + param_env: ParamEnv<'tcx>, + optimizations: Vec<OptimizationToApply<'tcx>>, +} diff --git a/compiler/rustc_mir_transform/src/const_prop.rs b/compiler/rustc_mir_transform/src/const_prop.rs new file mode 100644 index 000000000..fbc0a767f --- /dev/null +++ b/compiler/rustc_mir_transform/src/const_prop.rs @@ -0,0 +1,1142 @@ +//! Propagates constants for early reporting of statically known +//! assertion failures + +use std::cell::Cell; + +use rustc_ast::Mutability; +use rustc_data_structures::fx::FxHashSet; +use rustc_hir::def::DefKind; +use rustc_index::bit_set::BitSet; +use rustc_index::vec::IndexVec; +use rustc_middle::mir::visit::{ + MutVisitor, MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor, +}; +use rustc_middle::mir::{ + BasicBlock, BinOp, Body, Constant, ConstantKind, Local, LocalDecl, LocalKind, Location, + Operand, Place, Rvalue, SourceInfo, Statement, StatementKind, Terminator, TerminatorKind, UnOp, + RETURN_PLACE, +}; +use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout}; +use rustc_middle::ty::subst::{InternalSubsts, Subst}; +use rustc_middle::ty::{self, ConstKind, Instance, ParamEnv, Ty, TyCtxt, TypeVisitable}; +use rustc_span::{def_id::DefId, Span}; +use rustc_target::abi::{self, HasDataLayout, Size, TargetDataLayout}; +use rustc_target::spec::abi::Abi as CallAbi; +use rustc_trait_selection::traits; + +use crate::MirPass; +use rustc_const_eval::interpret::{ + self, compile_time_machine, AllocId, ConstAllocation, ConstValue, CtfeValidationMode, Frame, + ImmTy, Immediate, InterpCx, InterpResult, LocalState, LocalValue, MemoryKind, OpTy, PlaceTy, + Pointer, Scalar, ScalarMaybeUninit, StackPopCleanup, StackPopUnwind, +}; + +/// The maximum number of bytes that we'll allocate space for a local or the return value. +/// Needed for #66397, because otherwise we eval into large places and that can cause OOM or just +/// Severely regress performance. +const MAX_ALLOC_LIMIT: u64 = 1024; + +/// Macro for machine-specific `InterpError` without allocation. +/// (These will never be shown to the user, but they help diagnose ICEs.) +macro_rules! throw_machine_stop_str { + ($($tt:tt)*) => {{ + // We make a new local type for it. 
+        // The type itself does not carry any information,
+        // but its vtable (for the `MachineStopType` trait) does.
+        struct Zst;
+        // Printing this type shows the desired string.
+        impl std::fmt::Display for Zst {
+            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                write!(f, $($tt)*)
+            }
+        }
+        impl rustc_middle::mir::interpret::MachineStopType for Zst {}
+        throw_machine_stop!(Zst)
+    }};
+}
+
+pub struct ConstProp;
+
+impl<'tcx> MirPass<'tcx> for ConstProp {
+    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+        sess.mir_opt_level() >= 1
+    }
+
+    #[instrument(skip(self, tcx), level = "debug")]
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        // will be evaluated by miri and produce its errors there
+        if body.source.promoted.is_some() {
+            return;
+        }
+
+        let def_id = body.source.def_id().expect_local();
+        let def_kind = tcx.def_kind(def_id);
+        let is_fn_like = def_kind.is_fn_like();
+        let is_assoc_const = def_kind == DefKind::AssocConst;
+
+        // Only run const prop on functions, methods, closures and associated constants
+        if !is_fn_like && !is_assoc_const {
+            // skip anon_const/statics/consts because they'll be evaluated by miri anyway
+            trace!("ConstProp skipped for {:?}", def_id);
+            return;
+        }
+
+        let is_generator = tcx.type_of(def_id.to_def_id()).is_generator();
+        // FIXME(welseywiser) const prop doesn't work on generators because of query cycles
+        // computing their layout.
+        if is_generator {
+            trace!("ConstProp skipped for generator {:?}", def_id);
+            return;
+        }
+
+        // Check if it's even possible to satisfy the 'where' clauses
+        // for this item.
+        // This branch will never be taken for any normal function.
+        // However, it's possible to use `#![feature(trivial_bounds)]` to write
+        // a function with impossible-to-satisfy clauses, e.g.:
+        // `fn foo() where String: Copy {}`
+        //
+        // We don't usually need to worry about this kind of case,
+        // since we would get a compilation error if the user tried
+        // to call it. However, since we can do const propagation
+        // even without any calls to the function, we need to make
+        // sure that it even makes sense to try to evaluate the body.
+        // If there are unsatisfiable where clauses, then all bets are
+        // off, and we just give up.
+        //
+        // We manually filter the predicates, skipping anything that's not
+        // "global". We are in a potentially generic context
+        // (e.g. we are evaluating a function without substituting generic
+        // parameters), so this filtering serves two purposes:
+        //
+        // 1. We skip evaluating any predicates that we would
+        //    never be able to prove are unsatisfiable (e.g. `<T as Foo>`).
+        // 2. We avoid trying to normalize predicates involving generic
+        //    parameters (e.g. `<T as Foo>::MyItem`). This can confuse
+        //    the normalization code (leading to cycle errors), since
+        //    it's usually never invoked in this way.
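+        //
+        // A hypothetical illustration of such an item (not a test in this
+        // suite):
+        //
+        //     #![feature(trivial_bounds)]
+        //     fn unsatisfiable() where String: Copy {
+        //         let s = String::new();
+        //         let t = s; // a "copy" under the impossible bound
+        //         drop(s); // `s` is still considered live here
+        //         drop(t);
+        //     }
+        //
+        // `unsatisfiable` can never be called, but const prop would still walk
+        // its body, so we bail out before interpreting anything.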
+ let predicates = tcx + .predicates_of(def_id.to_def_id()) + .predicates + .iter() + .filter_map(|(p, _)| if p.is_global() { Some(*p) } else { None }); + if traits::impossible_predicates( + tcx, + traits::elaborate_predicates(tcx, predicates).map(|o| o.predicate).collect(), + ) { + trace!("ConstProp skipped for {:?}: found unsatisfiable predicates", def_id); + return; + } + + trace!("ConstProp starting for {:?}", def_id); + + let dummy_body = &Body::new( + body.source, + body.basic_blocks().clone(), + body.source_scopes.clone(), + body.local_decls.clone(), + Default::default(), + body.arg_count, + Default::default(), + body.span, + body.generator_kind(), + body.tainted_by_errors, + ); + + // FIXME(oli-obk, eddyb) Optimize locals (or even local paths) to hold + // constants, instead of just checking for const-folding succeeding. + // That would require a uniform one-def no-mutation analysis + // and RPO (or recursing when needing the value of a local). + let mut optimization_finder = ConstPropagator::new(body, dummy_body, tcx); + optimization_finder.visit_body(body); + + trace!("ConstProp done for {:?}", def_id); + } +} + +pub struct ConstPropMachine<'mir, 'tcx> { + /// The virtual call stack. + stack: Vec<Frame<'mir, 'tcx>>, + /// `OnlyInsideOwnBlock` locals that were written in the current block get erased at the end. + pub written_only_inside_own_block_locals: FxHashSet<Local>, + /// Locals that need to be cleared after every block terminates. + pub only_propagate_inside_block_locals: BitSet<Local>, + pub can_const_prop: IndexVec<Local, ConstPropMode>, +} + +impl ConstPropMachine<'_, '_> { + pub fn new( + only_propagate_inside_block_locals: BitSet<Local>, + can_const_prop: IndexVec<Local, ConstPropMode>, + ) -> Self { + Self { + stack: Vec::new(), + written_only_inside_own_block_locals: Default::default(), + only_propagate_inside_block_locals, + can_const_prop, + } + } +} + +impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx> { + compile_time_machine!(<'mir, 'tcx>); + const PANIC_ON_ALLOC_FAIL: bool = true; // all allocations are small (see `MAX_ALLOC_LIMIT`) + + type MemoryKind = !; + + fn load_mir( + _ecx: &InterpCx<'mir, 'tcx, Self>, + _instance: ty::InstanceDef<'tcx>, + ) -> InterpResult<'tcx, &'tcx Body<'tcx>> { + throw_machine_stop_str!("calling functions isn't supported in ConstProp") + } + + fn find_mir_or_eval_fn( + _ecx: &mut InterpCx<'mir, 'tcx, Self>, + _instance: ty::Instance<'tcx>, + _abi: CallAbi, + _args: &[OpTy<'tcx>], + _destination: &PlaceTy<'tcx>, + _target: Option<BasicBlock>, + _unwind: StackPopUnwind, + ) -> InterpResult<'tcx, Option<(&'mir Body<'tcx>, ty::Instance<'tcx>)>> { + Ok(None) + } + + fn call_intrinsic( + _ecx: &mut InterpCx<'mir, 'tcx, Self>, + _instance: ty::Instance<'tcx>, + _args: &[OpTy<'tcx>], + _destination: &PlaceTy<'tcx>, + _target: Option<BasicBlock>, + _unwind: StackPopUnwind, + ) -> InterpResult<'tcx> { + throw_machine_stop_str!("calling intrinsics isn't supported in ConstProp") + } + + fn assert_panic( + _ecx: &mut InterpCx<'mir, 'tcx, Self>, + _msg: &rustc_middle::mir::AssertMessage<'tcx>, + _unwind: Option<rustc_middle::mir::BasicBlock>, + ) -> InterpResult<'tcx> { + bug!("panics terminators are not evaluated in ConstProp") + } + + fn binary_ptr_op( + _ecx: &InterpCx<'mir, 'tcx, Self>, + _bin_op: BinOp, + _left: &ImmTy<'tcx>, + _right: &ImmTy<'tcx>, + ) -> InterpResult<'tcx, (Scalar, bool, Ty<'tcx>)> { + // We can't do this because aliasing of memory can differ between const eval and llvm + 
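+        // (Illustration, not tied to a specific test: CTFE knows each
+        // allocation only by its `AllocId` and an offset, so a comparison like
+        // `a_ptr == b_ptr` across two allocations has no stable answer at
+        // compile time, while the backend decides actual addresses at runtime.)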
+        throw_machine_stop_str!("pointer arithmetic or comparisons aren't supported in ConstProp")
+    }
+
+    fn access_local<'a>(
+        frame: &'a Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>,
+        local: Local,
+    ) -> InterpResult<'tcx, &'a interpret::Operand<Self::Provenance>> {
+        let l = &frame.locals[local];
+
+        if matches!(
+            l.value,
+            LocalValue::Live(interpret::Operand::Immediate(interpret::Immediate::Uninit))
+        ) {
+            // For us "uninit" means "we don't know its value; it might be initialized or not".
+            // So stop here.
+            throw_machine_stop_str!("tried to access a local with unknown value")
+        }
+
+        l.access()
+    }
+
+    fn access_local_mut<'a>(
+        ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
+        frame: usize,
+        local: Local,
+    ) -> InterpResult<'tcx, &'a mut interpret::Operand<Self::Provenance>> {
+        if ecx.machine.can_const_prop[local] == ConstPropMode::NoPropagation {
+            throw_machine_stop_str!("tried to write to a local that is marked as not propagatable")
+        }
+        if frame == 0 && ecx.machine.only_propagate_inside_block_locals.contains(local) {
+            trace!(
+                "mutating local {:?} which is restricted to its block. \
+                Will remove it from const-prop after block is finished.",
+                local
+            );
+            ecx.machine.written_only_inside_own_block_locals.insert(local);
+        }
+        ecx.machine.stack[frame].locals[local].access_mut()
+    }
+
+    fn before_access_global(
+        _tcx: TyCtxt<'tcx>,
+        _machine: &Self,
+        _alloc_id: AllocId,
+        alloc: ConstAllocation<'tcx, Self::Provenance, Self::AllocExtra>,
+        _static_def_id: Option<DefId>,
+        is_write: bool,
+    ) -> InterpResult<'tcx> {
+        if is_write {
+            throw_machine_stop_str!("can't write to global");
+        }
+        // If the static allocation is mutable, then we can't const prop it as its content
+        // might be different at runtime.
+        if alloc.inner().mutability == Mutability::Mut {
+            throw_machine_stop_str!("can't access mutable globals in ConstProp");
+        }
+
+        Ok(())
+    }
+
+    #[inline(always)]
+    fn expose_ptr(
+        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
+        _ptr: Pointer<AllocId>,
+    ) -> InterpResult<'tcx> {
+        throw_machine_stop_str!("exposing pointers isn't supported in ConstProp")
+    }
+
+    #[inline(always)]
+    fn init_frame_extra(
+        _ecx: &mut InterpCx<'mir, 'tcx, Self>,
+        frame: Frame<'mir, 'tcx>,
+    ) -> InterpResult<'tcx, Frame<'mir, 'tcx>> {
+        Ok(frame)
+    }
+
+    #[inline(always)]
+    fn stack<'a>(
+        ecx: &'a InterpCx<'mir, 'tcx, Self>,
+    ) -> &'a [Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>] {
+        &ecx.machine.stack
+    }
+
+    #[inline(always)]
+    fn stack_mut<'a>(
+        ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
+    ) -> &'a mut Vec<Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>> {
+        &mut ecx.machine.stack
+    }
+}
+
+/// Finds optimization opportunities on the MIR.
+struct ConstPropagator<'mir, 'tcx> {
+    ecx: InterpCx<'mir, 'tcx, ConstPropMachine<'mir, 'tcx>>,
+    tcx: TyCtxt<'tcx>,
+    param_env: ParamEnv<'tcx>,
+    local_decls: &'mir IndexVec<Local, LocalDecl<'tcx>>,
+    // Because we have `MutVisitor` we can't obtain the `SourceInfo` from a `Location`. So we store
+    // the last known `SourceInfo` here and just keep revisiting it.
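+    // (A `Location` alone cannot be mapped back to a span here, since doing so
+    // would require indexing into the `Body` that the `MutVisitor` is
+    // currently mutating.)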
+ source_info: Option<SourceInfo>, +} + +impl<'tcx> LayoutOfHelpers<'tcx> for ConstPropagator<'_, 'tcx> { + type LayoutOfResult = Result<TyAndLayout<'tcx>, LayoutError<'tcx>>; + + #[inline] + fn handle_layout_err(&self, err: LayoutError<'tcx>, _: Span, _: Ty<'tcx>) -> LayoutError<'tcx> { + err + } +} + +impl HasDataLayout for ConstPropagator<'_, '_> { + #[inline] + fn data_layout(&self) -> &TargetDataLayout { + &self.tcx.data_layout + } +} + +impl<'tcx> ty::layout::HasTyCtxt<'tcx> for ConstPropagator<'_, 'tcx> { + #[inline] + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } +} + +impl<'tcx> ty::layout::HasParamEnv<'tcx> for ConstPropagator<'_, 'tcx> { + #[inline] + fn param_env(&self) -> ty::ParamEnv<'tcx> { + self.param_env + } +} + +impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { + fn new( + body: &Body<'tcx>, + dummy_body: &'mir Body<'tcx>, + tcx: TyCtxt<'tcx>, + ) -> ConstPropagator<'mir, 'tcx> { + let def_id = body.source.def_id(); + let substs = &InternalSubsts::identity_for_item(tcx, def_id); + let param_env = tcx.param_env_reveal_all_normalized(def_id); + + let can_const_prop = CanConstProp::check(tcx, param_env, body); + let mut only_propagate_inside_block_locals = BitSet::new_empty(can_const_prop.len()); + for (l, mode) in can_const_prop.iter_enumerated() { + if *mode == ConstPropMode::OnlyInsideOwnBlock { + only_propagate_inside_block_locals.insert(l); + } + } + let mut ecx = InterpCx::new( + tcx, + tcx.def_span(def_id), + param_env, + ConstPropMachine::new(only_propagate_inside_block_locals, can_const_prop), + ); + + let ret_layout = ecx + .layout_of(body.bound_return_ty().subst(tcx, substs)) + .ok() + // Don't bother allocating memory for large values. + // I don't know how return types can seem to be unsized but this happens in the + // `type/type-unsatisfiable.rs` test. + .filter(|ret_layout| { + !ret_layout.is_unsized() && ret_layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) + }) + .unwrap_or_else(|| ecx.layout_of(tcx.types.unit).unwrap()); + + let ret = ecx + .allocate(ret_layout, MemoryKind::Stack) + .expect("couldn't perform small allocation") + .into(); + + ecx.push_stack_frame( + Instance::new(def_id, substs), + dummy_body, + &ret, + StackPopCleanup::Root { cleanup: false }, + ) + .expect("failed to push initial stack frame"); + + ConstPropagator { + ecx, + tcx, + param_env, + local_decls: &dummy_body.local_decls, + source_info: None, + } + } + + fn get_const(&self, place: Place<'tcx>) -> Option<OpTy<'tcx>> { + let op = match self.ecx.eval_place_to_op(place, None) { + Ok(op) => op, + Err(e) => { + trace!("get_const failed: {}", e); + return None; + } + }; + + // Try to read the local as an immediate so that if it is representable as a scalar, we can + // handle it as such, but otherwise, just return the value as is. + Some(match self.ecx.read_immediate_raw(&op, /*force*/ false) { + Ok(Ok(imm)) => imm.into(), + _ => op, + }) + } + + /// Remove `local` from the pool of `Locals`. Allows writing to them, + /// but not reading from them anymore. 
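+    /// (E.g. once `x = SOME_MUTABLE_STATIC;` fails to const-evaluate, the pass
+    /// calls this for `x`, so later reads of `x` no longer propagate a stale
+    /// value; see the erasure logic in `visit_statement` below.)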
+ fn remove_const(ecx: &mut InterpCx<'mir, 'tcx, ConstPropMachine<'mir, 'tcx>>, local: Local) { + ecx.frame_mut().locals[local] = LocalState { + value: LocalValue::Live(interpret::Operand::Immediate(interpret::Immediate::Uninit)), + layout: Cell::new(None), + }; + } + + fn use_ecx<F, T>(&mut self, f: F) -> Option<T> + where + F: FnOnce(&mut Self) -> InterpResult<'tcx, T>, + { + match f(self) { + Ok(val) => Some(val), + Err(error) => { + trace!("InterpCx operation failed: {:?}", error); + // Some errors shouldn't come up because creating them causes + // an allocation, which we should avoid. When that happens, + // dedicated error variants should be introduced instead. + assert!( + !error.kind().formatted_string(), + "const-prop encountered formatting error: {}", + error + ); + None + } + } + } + + /// Returns the value, if any, of evaluating `c`. + fn eval_constant(&mut self, c: &Constant<'tcx>) -> Option<OpTy<'tcx>> { + // FIXME we need to revisit this for #67176 + if c.needs_subst() { + return None; + } + + self.ecx.mir_const_to_op(&c.literal, None).ok() + } + + /// Returns the value, if any, of evaluating `place`. + fn eval_place(&mut self, place: Place<'tcx>) -> Option<OpTy<'tcx>> { + trace!("eval_place(place={:?})", place); + self.use_ecx(|this| this.ecx.eval_place_to_op(place, None)) + } + + /// Returns the value, if any, of evaluating `op`. Calls upon `eval_constant` + /// or `eval_place`, depending on the variant of `Operand` used. + fn eval_operand(&mut self, op: &Operand<'tcx>) -> Option<OpTy<'tcx>> { + match *op { + Operand::Constant(ref c) => self.eval_constant(c), + Operand::Move(place) | Operand::Copy(place) => self.eval_place(place), + } + } + + fn check_unary_op(&mut self, op: UnOp, arg: &Operand<'tcx>) -> Option<()> { + if self.use_ecx(|this| { + let val = this.ecx.read_immediate(&this.ecx.eval_operand(arg, None)?)?; + let (_res, overflow, _ty) = this.ecx.overflowing_unary_op(op, &val)?; + Ok(overflow) + })? { + // `AssertKind` only has an `OverflowNeg` variant, so make sure that is + // appropriate to use. + assert_eq!(op, UnOp::Neg, "Neg is the only UnOp that can overflow"); + return None; + } + + Some(()) + } + + fn check_binary_op( + &mut self, + op: BinOp, + left: &Operand<'tcx>, + right: &Operand<'tcx>, + ) -> Option<()> { + let r = self.use_ecx(|this| this.ecx.read_immediate(&this.ecx.eval_operand(right, None)?)); + let l = self.use_ecx(|this| this.ecx.read_immediate(&this.ecx.eval_operand(left, None)?)); + // Check for exceeding shifts *even if* we cannot evaluate the LHS. + if op == BinOp::Shr || op == BinOp::Shl { + let r = r.clone()?; + // We need the type of the LHS. We cannot use `place_layout` as that is the type + // of the result, which for checked binops is not the same! + let left_ty = left.ty(self.local_decls, self.tcx); + let left_size = self.ecx.layout_of(left_ty).ok()?.size; + let right_size = r.layout.size; + let r_bits = r.to_scalar().ok(); + let r_bits = r_bits.and_then(|r| r.to_bits(right_size).ok()); + if r_bits.map_or(false, |b| b >= left_size.bits() as u128) { + return None; + } + } + + if let (Some(l), Some(r)) = (&l, &r) { + // The remaining operators are handled through `overflowing_binary_op`. + if self.use_ecx(|this| { + let (_res, overflow, _ty) = this.ecx.overflowing_binary_op(op, l, r)?; + Ok(overflow) + })? 
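+        // (`overflow` means the checked form of `op` would trip its assert,
+        // e.g. `i32::MAX + 1`; we skip folding so that runtime behavior, and
+        // the lint-only twin in `const_prop_lint`, are preserved.)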
{ + return None; + } + } + Some(()) + } + + fn propagate_operand(&mut self, operand: &mut Operand<'tcx>) { + match *operand { + Operand::Copy(l) | Operand::Move(l) => { + if let Some(value) = self.get_const(l) && self.should_const_prop(&value) { + // FIXME(felix91gr): this code only handles `Scalar` cases. + // For now, we're not handling `ScalarPair` cases because + // doing so here would require a lot of code duplication. + // We should hopefully generalize `Operand` handling into a fn, + // and use it to do const-prop here and everywhere else + // where it makes sense. + if let interpret::Operand::Immediate(interpret::Immediate::Scalar( + ScalarMaybeUninit::Scalar(scalar), + )) = *value + { + *operand = self.operand_from_scalar( + scalar, + value.layout.ty, + self.source_info.unwrap().span, + ); + } + } + } + Operand::Constant(_) => (), + } + } + + fn const_prop(&mut self, rvalue: &Rvalue<'tcx>, place: Place<'tcx>) -> Option<()> { + // Perform any special handling for specific Rvalue types. + // Generally, checks here fall into one of two categories: + // 1. Additional checking to provide useful lints to the user + // - In this case, we will do some validation and then fall through to the + // end of the function which evals the assignment. + // 2. Working around bugs in other parts of the compiler + // - In this case, we'll return `None` from this function to stop evaluation. + match rvalue { + // Additional checking: give lints to the user if an overflow would occur. + // We do this here and not in the `Assert` terminator as that terminator is + // only sometimes emitted (overflow checks can be disabled), but we want to always + // lint. + Rvalue::UnaryOp(op, arg) => { + trace!("checking UnaryOp(op = {:?}, arg = {:?})", op, arg); + self.check_unary_op(*op, arg)?; + } + Rvalue::BinaryOp(op, box (left, right)) => { + trace!("checking BinaryOp(op = {:?}, left = {:?}, right = {:?})", op, left, right); + self.check_binary_op(*op, left, right)?; + } + Rvalue::CheckedBinaryOp(op, box (left, right)) => { + trace!( + "checking CheckedBinaryOp(op = {:?}, left = {:?}, right = {:?})", + op, + left, + right + ); + self.check_binary_op(*op, left, right)?; + } + + // Do not try creating references (#67862) + Rvalue::AddressOf(_, place) | Rvalue::Ref(_, _, place) => { + trace!("skipping AddressOf | Ref for {:?}", place); + + // This may be creating mutable references or immutable references to cells. + // If that happens, the pointed to value could be mutated via that reference. + // Since we aren't tracking references, the const propagator loses track of what + // value the local has right now. + // Thus, all locals that have their reference taken + // must not take part in propagation. + Self::remove_const(&mut self.ecx, place.local); + + return None; + } + Rvalue::ThreadLocalRef(def_id) => { + trace!("skipping ThreadLocalRef({:?})", def_id); + + return None; + } + + // There's no other checking to do at this time. + Rvalue::Aggregate(..) + | Rvalue::Use(..) + | Rvalue::CopyForDeref(..) + | Rvalue::Repeat(..) + | Rvalue::Len(..) + | Rvalue::Cast(..) + | Rvalue::ShallowInitBox(..) + | Rvalue::Discriminant(..) + | Rvalue::NullaryOp(..) 
=> {} + } + + // FIXME we need to revisit this for #67176 + if rvalue.needs_subst() { + return None; + } + + if self.tcx.sess.mir_opt_level() >= 4 { + self.eval_rvalue_with_identities(rvalue, place) + } else { + self.use_ecx(|this| this.ecx.eval_rvalue_into_place(rvalue, place)) + } + } + + // Attempt to use algebraic identities to eliminate constant expressions + fn eval_rvalue_with_identities( + &mut self, + rvalue: &Rvalue<'tcx>, + place: Place<'tcx>, + ) -> Option<()> { + self.use_ecx(|this| match rvalue { + Rvalue::BinaryOp(op, box (left, right)) + | Rvalue::CheckedBinaryOp(op, box (left, right)) => { + let l = this.ecx.eval_operand(left, None); + let r = this.ecx.eval_operand(right, None); + + let const_arg = match (l, r) { + (Ok(ref x), Err(_)) | (Err(_), Ok(ref x)) => this.ecx.read_immediate(x)?, + (Err(e), Err(_)) => return Err(e), + (Ok(_), Ok(_)) => return this.ecx.eval_rvalue_into_place(rvalue, place), + }; + + if !matches!(const_arg.layout.abi, abi::Abi::Scalar(..)) { + // We cannot handle Scalar Pair stuff. + return this.ecx.eval_rvalue_into_place(rvalue, place); + } + + let arg_value = const_arg.to_scalar()?.to_bits(const_arg.layout.size)?; + let dest = this.ecx.eval_place(place)?; + + match op { + BinOp::BitAnd if arg_value == 0 => this.ecx.write_immediate(*const_arg, &dest), + BinOp::BitOr + if arg_value == const_arg.layout.size.truncate(u128::MAX) + || (const_arg.layout.ty.is_bool() && arg_value == 1) => + { + this.ecx.write_immediate(*const_arg, &dest) + } + BinOp::Mul if const_arg.layout.ty.is_integral() && arg_value == 0 => { + if let Rvalue::CheckedBinaryOp(_, _) = rvalue { + let val = Immediate::ScalarPair( + const_arg.to_scalar()?.into(), + Scalar::from_bool(false).into(), + ); + this.ecx.write_immediate(val, &dest) + } else { + this.ecx.write_immediate(*const_arg, &dest) + } + } + _ => this.ecx.eval_rvalue_into_place(rvalue, place), + } + } + _ => this.ecx.eval_rvalue_into_place(rvalue, place), + }) + } + + /// Creates a new `Operand::Constant` from a `Scalar` value + fn operand_from_scalar(&self, scalar: Scalar, ty: Ty<'tcx>, span: Span) -> Operand<'tcx> { + Operand::Constant(Box::new(Constant { + span, + user_ty: None, + literal: ConstantKind::from_scalar(self.tcx, scalar, ty), + })) + } + + fn replace_with_const( + &mut self, + rval: &mut Rvalue<'tcx>, + value: &OpTy<'tcx>, + source_info: SourceInfo, + ) { + if let Rvalue::Use(Operand::Constant(c)) = rval { + match c.literal { + ConstantKind::Ty(c) if matches!(c.kind(), ConstKind::Unevaluated(..)) => {} + _ => { + trace!("skipping replace of Rvalue::Use({:?} because it is already a const", c); + return; + } + } + } + + trace!("attempting to replace {:?} with {:?}", rval, value); + if let Err(e) = self.ecx.const_validate_operand( + value, + vec![], + // FIXME: is ref tracking too expensive? + // FIXME: what is the point of ref tracking if we do not even check the tracked refs? 
+ &mut interpret::RefTracking::empty(), + CtfeValidationMode::Regular, + ) { + trace!("validation error, attempt failed: {:?}", e); + return; + } + + // FIXME> figure out what to do when read_immediate_raw fails + let imm = self.use_ecx(|this| this.ecx.read_immediate_raw(value, /*force*/ false)); + + if let Some(Ok(imm)) = imm { + match *imm { + interpret::Immediate::Scalar(ScalarMaybeUninit::Scalar(scalar)) => { + *rval = Rvalue::Use(self.operand_from_scalar( + scalar, + value.layout.ty, + source_info.span, + )); + } + Immediate::ScalarPair( + ScalarMaybeUninit::Scalar(_), + ScalarMaybeUninit::Scalar(_), + ) => { + // Found a value represented as a pair. For now only do const-prop if the type + // of `rvalue` is also a tuple with two scalars. + // FIXME: enable the general case stated above ^. + let ty = value.layout.ty; + // Only do it for tuples + if let ty::Tuple(types) = ty.kind() { + // Only do it if tuple is also a pair with two scalars + if let [ty1, ty2] = types[..] { + let alloc = self.use_ecx(|this| { + let ty_is_scalar = |ty| { + this.ecx.layout_of(ty).ok().map(|layout| layout.abi.is_scalar()) + == Some(true) + }; + if ty_is_scalar(ty1) && ty_is_scalar(ty2) { + let alloc = this + .ecx + .intern_with_temp_alloc(value.layout, |ecx, dest| { + ecx.write_immediate(*imm, dest) + }) + .unwrap(); + Ok(Some(alloc)) + } else { + Ok(None) + } + }); + + if let Some(Some(alloc)) = alloc { + // Assign entire constant in a single statement. + // We can't use aggregates, as we run after the aggregate-lowering `MirPhase`. + let const_val = ConstValue::ByRef { alloc, offset: Size::ZERO }; + let literal = ConstantKind::Val(const_val, ty); + *rval = Rvalue::Use(Operand::Constant(Box::new(Constant { + span: source_info.span, + user_ty: None, + literal, + }))); + } + } + } + } + // Scalars or scalar pairs that contain undef values are assumed to not have + // successfully evaluated and are thus not propagated. + _ => {} + } + } + } + + /// Returns `true` if and only if this `op` should be const-propagated into. + fn should_const_prop(&mut self, op: &OpTy<'tcx>) -> bool { + if !self.tcx.consider_optimizing(|| format!("ConstantPropagation - OpTy: {:?}", op)) { + return false; + } + + match **op { + interpret::Operand::Immediate(Immediate::Scalar(ScalarMaybeUninit::Scalar(s))) => { + s.try_to_int().is_ok() + } + interpret::Operand::Immediate(Immediate::ScalarPair( + ScalarMaybeUninit::Scalar(l), + ScalarMaybeUninit::Scalar(r), + )) => l.try_to_int().is_ok() && r.try_to_int().is_ok(), + _ => false, + } + } +} + +/// The mode that `ConstProp` is allowed to run in for a given `Local`. +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum ConstPropMode { + /// The `Local` can be propagated into and reads of this `Local` can also be propagated. + FullConstProp, + /// The `Local` can only be propagated into and from its own block. + OnlyInsideOwnBlock, + /// The `Local` can be propagated into but reads cannot be propagated. + OnlyPropagateInto, + /// The `Local` cannot be part of propagation at all. Any statement + /// referencing it either for reading or writing will not get propagated. + NoPropagation, +} + +pub struct CanConstProp { + can_const_prop: IndexVec<Local, ConstPropMode>, + // False at the beginning. Once set, no more assignments are allowed to that local. 
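+    // (E.g. a local assigned in both arms of an `if` triggers the second
+    // `insert` in `visit_local` below and is demoted from `FullConstProp` to
+    // `OnlyInsideOwnBlock`.)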
+ found_assignment: BitSet<Local>, + // Cache of locals' information + local_kinds: IndexVec<Local, LocalKind>, +} + +impl CanConstProp { + /// Returns true if `local` can be propagated + pub fn check<'tcx>( + tcx: TyCtxt<'tcx>, + param_env: ParamEnv<'tcx>, + body: &Body<'tcx>, + ) -> IndexVec<Local, ConstPropMode> { + let mut cpv = CanConstProp { + can_const_prop: IndexVec::from_elem(ConstPropMode::FullConstProp, &body.local_decls), + found_assignment: BitSet::new_empty(body.local_decls.len()), + local_kinds: IndexVec::from_fn_n( + |local| body.local_kind(local), + body.local_decls.len(), + ), + }; + for (local, val) in cpv.can_const_prop.iter_enumerated_mut() { + let ty = body.local_decls[local].ty; + match tcx.layout_of(param_env.and(ty)) { + Ok(layout) if layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) => {} + // Either the layout fails to compute, then we can't use this local anyway + // or the local is too large, then we don't want to. + _ => { + *val = ConstPropMode::NoPropagation; + continue; + } + } + // Cannot use args at all + // Cannot use locals because if x < y { y - x } else { x - y } would + // lint for x != y + // FIXME(oli-obk): lint variables until they are used in a condition + // FIXME(oli-obk): lint if return value is constant + if cpv.local_kinds[local] == LocalKind::Arg { + *val = ConstPropMode::OnlyPropagateInto; + trace!( + "local {:?} can't be const propagated because it's a function argument", + local + ); + } else if cpv.local_kinds[local] == LocalKind::Var { + *val = ConstPropMode::OnlyInsideOwnBlock; + trace!( + "local {:?} will only be propagated inside its block, because it's a user variable", + local + ); + } + } + cpv.visit_body(&body); + cpv.can_const_prop + } +} + +impl Visitor<'_> for CanConstProp { + fn visit_local(&mut self, local: Local, context: PlaceContext, _: Location) { + use rustc_middle::mir::visit::PlaceContext::*; + match context { + // Projections are fine, because `&mut foo.x` will be caught by + // `MutatingUseContext::Borrow` elsewhere. + MutatingUse(MutatingUseContext::Projection) + // These are just stores, where the storing is not propagatable, but there may be later + // mutations of the same local via `Store` + | MutatingUse(MutatingUseContext::Call) + | MutatingUse(MutatingUseContext::AsmOutput) + | MutatingUse(MutatingUseContext::Deinit) + // Actual store that can possibly even propagate a value + | MutatingUse(MutatingUseContext::Store) + | MutatingUse(MutatingUseContext::SetDiscriminant) => { + if !self.found_assignment.insert(local) { + match &mut self.can_const_prop[local] { + // If the local can only get propagated in its own block, then we don't have + // to worry about multiple assignments, as we'll nuke the const state at the + // end of the block anyway, and inside the block we overwrite previous + // states as applicable. + ConstPropMode::OnlyInsideOwnBlock => {} + ConstPropMode::NoPropagation => {} + ConstPropMode::OnlyPropagateInto => {} + other @ ConstPropMode::FullConstProp => { + trace!( + "local {:?} can't be propagated because of multiple assignments. 
+                            Previous state: {:?}",
+                            local, other,
+                        );
+                        *other = ConstPropMode::OnlyInsideOwnBlock;
+                    }
+                }
+            }
+        }
+            // Reading constants is allowed an arbitrary number of times
+            NonMutatingUse(NonMutatingUseContext::Copy)
+            | NonMutatingUse(NonMutatingUseContext::Move)
+            | NonMutatingUse(NonMutatingUseContext::Inspect)
+            | NonMutatingUse(NonMutatingUseContext::Projection)
+            | NonUse(_) => {}
+
+            // These could be propagated with a smarter analysis or just some careful thinking about
+            // whether they'd be fine right now.
+            MutatingUse(MutatingUseContext::Yield)
+            | MutatingUse(MutatingUseContext::Drop)
+            | MutatingUse(MutatingUseContext::Retag)
+            // These can't ever be propagated under any scheme, as we can't reason about indirect
+            // mutation.
+            | NonMutatingUse(NonMutatingUseContext::SharedBorrow)
+            | NonMutatingUse(NonMutatingUseContext::ShallowBorrow)
+            | NonMutatingUse(NonMutatingUseContext::UniqueBorrow)
+            | NonMutatingUse(NonMutatingUseContext::AddressOf)
+            | MutatingUse(MutatingUseContext::Borrow)
+            | MutatingUse(MutatingUseContext::AddressOf) => {
+                trace!("local {:?} can't be propagated because it's used: {:?}", local, context);
+                self.can_const_prop[local] = ConstPropMode::NoPropagation;
+            }
+        }
+    }
+}
+
+impl<'tcx> MutVisitor<'tcx> for ConstPropagator<'_, 'tcx> {
+    fn tcx(&self) -> TyCtxt<'tcx> {
+        self.tcx
+    }
+
+    fn visit_body(&mut self, body: &mut Body<'tcx>) {
+        for (bb, data) in body.basic_blocks_mut().iter_enumerated_mut() {
+            self.visit_basic_block_data(bb, data);
+        }
+    }
+
+    fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) {
+        self.super_operand(operand, location);
+
+        // Only const-prop copies and moves on `mir_opt_level >= 3`, as doing so
+        // currently slightly increases compile time in some cases.
+        if self.tcx.sess.mir_opt_level() >= 3 {
+            self.propagate_operand(operand)
+        }
+    }
+
+    fn visit_constant(&mut self, constant: &mut Constant<'tcx>, location: Location) {
+        trace!("visit_constant: {:?}", constant);
+        self.super_constant(constant, location);
+        self.eval_constant(constant);
+    }
+
+    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
+        trace!("visit_statement: {:?}", statement);
+        let source_info = statement.source_info;
+        self.source_info = Some(source_info);
+        if let StatementKind::Assign(box (place, ref mut rval)) = statement.kind {
+            let can_const_prop = self.ecx.machine.can_const_prop[place.local];
+            if let Some(()) = self.const_prop(rval, place) {
+                // This will return None if the above `const_prop` invocation only "wrote" a
+                // type whose creation requires no write. E.g. a generator whose initial state
+                // consists solely of uninitialized memory (so it doesn't capture any locals).
+                if let Some(ref value) = self.get_const(place) && self.should_const_prop(value) {
+                    trace!("replacing {:?} with {:?}", rval, value);
+                    self.replace_with_const(rval, value, source_info);
+                    if can_const_prop == ConstPropMode::FullConstProp
+                        || can_const_prop == ConstPropMode::OnlyInsideOwnBlock
+                    {
+                        trace!("propagated into {:?}", place);
+                    }
+                }
+                match can_const_prop {
+                    ConstPropMode::OnlyInsideOwnBlock => {
+                        trace!(
+                            "found local restricted to its block. \
+                            Will remove it from const-prop after block is finished.
Local: {:?}", + place.local + ); + } + ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => { + trace!("can't propagate into {:?}", place); + if place.local != RETURN_PLACE { + Self::remove_const(&mut self.ecx, place.local); + } + } + ConstPropMode::FullConstProp => {} + } + } else { + // Const prop failed, so erase the destination, ensuring that whatever happens + // from here on, does not know about the previous value. + // This is important in case we have + // ```rust + // let mut x = 42; + // x = SOME_MUTABLE_STATIC; + // // x must now be uninit + // ``` + // FIXME: we overzealously erase the entire local, because that's easier to + // implement. + trace!( + "propagation into {:?} failed. + Nuking the entire site from orbit, it's the only way to be sure", + place, + ); + Self::remove_const(&mut self.ecx, place.local); + } + } else { + match statement.kind { + StatementKind::SetDiscriminant { ref place, .. } => { + match self.ecx.machine.can_const_prop[place.local] { + ConstPropMode::FullConstProp | ConstPropMode::OnlyInsideOwnBlock => { + if self.use_ecx(|this| this.ecx.statement(statement)).is_some() { + trace!("propped discriminant into {:?}", place); + } else { + Self::remove_const(&mut self.ecx, place.local); + } + } + ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => { + Self::remove_const(&mut self.ecx, place.local); + } + } + } + StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => { + let frame = self.ecx.frame_mut(); + frame.locals[local].value = + if let StatementKind::StorageLive(_) = statement.kind { + LocalValue::Live(interpret::Operand::Immediate( + interpret::Immediate::Uninit, + )) + } else { + LocalValue::Dead + }; + } + _ => {} + } + } + + self.super_statement(statement, location); + } + + fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) { + let source_info = terminator.source_info; + self.source_info = Some(source_info); + self.super_terminator(terminator, location); + match &mut terminator.kind { + TerminatorKind::Assert { expected, ref mut cond, .. } => { + if let Some(ref value) = self.eval_operand(&cond) { + trace!("assertion on {:?} should be {:?}", value, expected); + let expected = ScalarMaybeUninit::from(Scalar::from_bool(*expected)); + let value_const = self.ecx.read_scalar(&value).unwrap(); + if expected != value_const { + // Poison all places this operand references so that further code + // doesn't use the invalid value + match cond { + Operand::Move(ref place) | Operand::Copy(ref place) => { + Self::remove_const(&mut self.ecx, place.local); + } + Operand::Constant(_) => {} + } + } else { + if self.should_const_prop(value) { + if let ScalarMaybeUninit::Scalar(scalar) = value_const { + *cond = self.operand_from_scalar( + scalar, + self.tcx.types.bool, + source_info.span, + ); + } + } + } + } + } + TerminatorKind::SwitchInt { ref mut discr, .. } => { + // FIXME: This is currently redundant with `visit_operand`, but sadly + // always visiting operands currently causes a perf regression in LLVM codegen, so + // `visit_operand` currently only runs for propagates places for `mir_opt_level=4`. + self.propagate_operand(discr) + } + // None of these have Operands to const-propagate. + TerminatorKind::Goto { .. } + | TerminatorKind::Resume + | TerminatorKind::Abort + | TerminatorKind::Return + | TerminatorKind::Unreachable + | TerminatorKind::Drop { .. } + | TerminatorKind::DropAndReplace { .. } + | TerminatorKind::Yield { .. 
+            }
+            | TerminatorKind::GeneratorDrop
+            | TerminatorKind::FalseEdge { .. }
+            | TerminatorKind::FalseUnwind { .. }
+            | TerminatorKind::InlineAsm { .. } => {}
+            // Every argument in our function calls has already been propagated in `visit_operand`.
+            //
+            // NOTE: because LLVM codegen gives slight performance regressions with it, this is
+            // gated on `mir_opt_level=3`.
+            TerminatorKind::Call { .. } => {}
+        }
+
+        // We remove all Locals which are restricted in propagation to their containing blocks and
+        // which were modified in the current block.
+        // Take it out of the ecx so we can get a mutable reference to the ecx for `remove_const`.
+        let mut locals = std::mem::take(&mut self.ecx.machine.written_only_inside_own_block_locals);
+        for &local in locals.iter() {
+            Self::remove_const(&mut self.ecx, local);
+        }
+        locals.clear();
+        // Put it back so we reuse the heap of the storage
+        self.ecx.machine.written_only_inside_own_block_locals = locals;
+        if cfg!(debug_assertions) {
+            // Ensure we are correctly erasing locals with the non-debug-assert logic.
+            for local in self.ecx.machine.only_propagate_inside_block_locals.iter() {
+                assert!(
+                    self.get_const(local.into()).is_none()
+                        || self
+                            .layout_of(self.local_decls[local].ty)
+                            .map_or(true, |layout| layout.is_zst())
+                )
+            }
+        }
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/const_prop_lint.rs b/compiler/rustc_mir_transform/src/const_prop_lint.rs
new file mode 100644
index 000000000..c2ea55af4
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/const_prop_lint.rs
@@ -0,0 +1,734 @@
+//! Propagates constants for early reporting of statically known
+//! assertion failures
+
+use crate::const_prop::CanConstProp;
+use crate::const_prop::ConstPropMachine;
+use crate::const_prop::ConstPropMode;
+use crate::MirLint;
+use rustc_const_eval::const_eval::ConstEvalErr;
+use rustc_const_eval::interpret::{
+    self, InterpCx, InterpResult, LocalState, LocalValue, MemoryKind, OpTy, Scalar,
+    ScalarMaybeUninit, StackPopCleanup,
+};
+use rustc_hir::def::DefKind;
+use rustc_hir::HirId;
+use rustc_index::bit_set::BitSet;
+use rustc_index::vec::IndexVec;
+use rustc_middle::mir::visit::Visitor;
+use rustc_middle::mir::{
+    AssertKind, BinOp, Body, Constant, ConstantKind, Local, LocalDecl, Location, Operand, Place,
+    Rvalue, SourceInfo, SourceScope, SourceScopeData, Statement, StatementKind, Terminator,
+    TerminatorKind, UnOp, RETURN_PLACE,
+};
+use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout};
+use rustc_middle::ty::subst::{InternalSubsts, Subst};
+use rustc_middle::ty::{
+    self, ConstInt, ConstKind, Instance, ParamEnv, ScalarInt, Ty, TyCtxt, TypeVisitable,
+};
+use rustc_session::lint;
+use rustc_span::Span;
+use rustc_target::abi::{HasDataLayout, Size, TargetDataLayout};
+use rustc_trait_selection::traits;
+use std::cell::Cell;
+
+/// The maximum number of bytes that we'll allocate space for a local or the return value.
+/// Needed for #66397, because otherwise we eval into large places and that can cause OOM or just
+/// severely regress performance.
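+/// (E.g. with this cap a hypothetical `[u8; 4096]` local is never allocated by
+/// this pass at all: `CanConstProp::check` marks it `NoPropagation`, and the
+/// return-place filter in `ConstPropagator::new` skips oversized layouts.)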
+const MAX_ALLOC_LIMIT: u64 = 1024;
+
+pub struct ConstProp;
+
+impl<'tcx> MirLint<'tcx> for ConstProp {
+    fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
+        // will be evaluated by miri and produce its errors there
+        if body.source.promoted.is_some() {
+            return;
+        }
+
+        let def_id = body.source.def_id().expect_local();
+        let is_fn_like = tcx.def_kind(def_id).is_fn_like();
+        let is_assoc_const = tcx.def_kind(def_id) == DefKind::AssocConst;
+
+        // Only run const prop on functions, methods, closures and associated constants
+        if !is_fn_like && !is_assoc_const {
+            // skip anon_const/statics/consts because they'll be evaluated by miri anyway
+            trace!("ConstProp skipped for {:?}", def_id);
+            return;
+        }
+
+        let is_generator = tcx.type_of(def_id.to_def_id()).is_generator();
+        // FIXME(welseywiser) const prop doesn't work on generators because of query cycles
+        // computing their layout.
+        if is_generator {
+            trace!("ConstProp skipped for generator {:?}", def_id);
+            return;
+        }
+
+        // Check if it's even possible to satisfy the 'where' clauses
+        // for this item.
+        // This branch will never be taken for any normal function.
+        // However, it's possible to use `#![feature(trivial_bounds)]` to write
+        // a function with impossible-to-satisfy clauses, e.g.:
+        // `fn foo() where String: Copy {}`
+        //
+        // We don't usually need to worry about this kind of case,
+        // since we would get a compilation error if the user tried
+        // to call it. However, since we can do const propagation
+        // even without any calls to the function, we need to make
+        // sure that it even makes sense to try to evaluate the body.
+        // If there are unsatisfiable where clauses, then all bets are
+        // off, and we just give up.
+        //
+        // We manually filter the predicates, skipping anything that's not
+        // "global". We are in a potentially generic context
+        // (e.g. we are evaluating a function without substituting generic
+        // parameters), so this filtering serves two purposes:
+        //
+        // 1. We skip evaluating any predicates that we would
+        //    never be able to prove are unsatisfiable (e.g. `<T as Foo>`).
+        // 2. We avoid trying to normalize predicates involving generic
+        //    parameters (e.g. `<T as Foo>::MyItem`). This can confuse
+        //    the normalization code (leading to cycle errors), since
+        //    it's usually never invoked in this way.
+        let predicates = tcx
+            .predicates_of(def_id.to_def_id())
+            .predicates
+            .iter()
+            .filter_map(|(p, _)| if p.is_global() { Some(*p) } else { None });
+        if traits::impossible_predicates(
+            tcx,
+            traits::elaborate_predicates(tcx, predicates).map(|o| o.predicate).collect(),
+        ) {
+            trace!("ConstProp skipped for {:?}: found unsatisfiable predicates", def_id);
+            return;
+        }
+
+        trace!("ConstProp starting for {:?}", def_id);
+
+        let dummy_body = &Body::new(
+            body.source,
+            body.basic_blocks().clone(),
+            body.source_scopes.clone(),
+            body.local_decls.clone(),
+            Default::default(),
+            body.arg_count,
+            Default::default(),
+            body.span,
+            body.generator_kind(),
+            body.tainted_by_errors,
+        );
+
+        // FIXME(oli-obk, eddyb) Optimize locals (or even local paths) to hold
+        // constants, instead of just checking for const-folding succeeding.
+        // That would require a uniform one-def no-mutation analysis
+        // and RPO (or recursing when needing the value of a local).
+        let mut optimization_finder = ConstPropagator::new(body, dummy_body, tcx);
+        optimization_finder.visit_body(body);
+
+        trace!("ConstProp done for {:?}", def_id);
+    }
+}
+
+/// Finds optimization opportunities on the MIR.
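+/// (Unlike the `const_prop` pass this is a `MirLint`, so it never mutates the
+/// body; the interpreter is only used to surface `ARITHMETIC_OVERFLOW` and
+/// `UNCONDITIONAL_PANIC` lints ahead of codegen.)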
+struct ConstPropagator<'mir, 'tcx> { + ecx: InterpCx<'mir, 'tcx, ConstPropMachine<'mir, 'tcx>>, + tcx: TyCtxt<'tcx>, + param_env: ParamEnv<'tcx>, + source_scopes: &'mir IndexVec<SourceScope, SourceScopeData<'tcx>>, + local_decls: &'mir IndexVec<Local, LocalDecl<'tcx>>, + // Because we have `MutVisitor` we can't obtain the `SourceInfo` from a `Location`. So we store + // the last known `SourceInfo` here and just keep revisiting it. + source_info: Option<SourceInfo>, +} + +impl<'tcx> LayoutOfHelpers<'tcx> for ConstPropagator<'_, 'tcx> { + type LayoutOfResult = Result<TyAndLayout<'tcx>, LayoutError<'tcx>>; + + #[inline] + fn handle_layout_err(&self, err: LayoutError<'tcx>, _: Span, _: Ty<'tcx>) -> LayoutError<'tcx> { + err + } +} + +impl HasDataLayout for ConstPropagator<'_, '_> { + #[inline] + fn data_layout(&self) -> &TargetDataLayout { + &self.tcx.data_layout + } +} + +impl<'tcx> ty::layout::HasTyCtxt<'tcx> for ConstPropagator<'_, 'tcx> { + #[inline] + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } +} + +impl<'tcx> ty::layout::HasParamEnv<'tcx> for ConstPropagator<'_, 'tcx> { + #[inline] + fn param_env(&self) -> ty::ParamEnv<'tcx> { + self.param_env + } +} + +impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { + fn new( + body: &Body<'tcx>, + dummy_body: &'mir Body<'tcx>, + tcx: TyCtxt<'tcx>, + ) -> ConstPropagator<'mir, 'tcx> { + let def_id = body.source.def_id(); + let substs = &InternalSubsts::identity_for_item(tcx, def_id); + let param_env = tcx.param_env_reveal_all_normalized(def_id); + + let can_const_prop = CanConstProp::check(tcx, param_env, body); + let mut only_propagate_inside_block_locals = BitSet::new_empty(can_const_prop.len()); + for (l, mode) in can_const_prop.iter_enumerated() { + if *mode == ConstPropMode::OnlyInsideOwnBlock { + only_propagate_inside_block_locals.insert(l); + } + } + let mut ecx = InterpCx::new( + tcx, + tcx.def_span(def_id), + param_env, + ConstPropMachine::new(only_propagate_inside_block_locals, can_const_prop), + ); + + let ret_layout = ecx + .layout_of(body.bound_return_ty().subst(tcx, substs)) + .ok() + // Don't bother allocating memory for large values. + // I don't know how return types can seem to be unsized but this happens in the + // `type/type-unsatisfiable.rs` test. + .filter(|ret_layout| { + !ret_layout.is_unsized() && ret_layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) + }) + .unwrap_or_else(|| ecx.layout_of(tcx.types.unit).unwrap()); + + let ret = ecx + .allocate(ret_layout, MemoryKind::Stack) + .expect("couldn't perform small allocation") + .into(); + + ecx.push_stack_frame( + Instance::new(def_id, substs), + dummy_body, + &ret, + StackPopCleanup::Root { cleanup: false }, + ) + .expect("failed to push initial stack frame"); + + ConstPropagator { + ecx, + tcx, + param_env, + source_scopes: &dummy_body.source_scopes, + local_decls: &dummy_body.local_decls, + source_info: None, + } + } + + fn get_const(&self, place: Place<'tcx>) -> Option<OpTy<'tcx>> { + let op = match self.ecx.eval_place_to_op(place, None) { + Ok(op) => op, + Err(e) => { + trace!("get_const failed: {}", e); + return None; + } + }; + + // Try to read the local as an immediate so that if it is representable as a scalar, we can + // handle it as such, but otherwise, just return the value as is. + Some(match self.ecx.read_immediate_raw(&op, /*force*/ false) { + Ok(Ok(imm)) => imm.into(), + _ => op, + }) + } + + /// Remove `local` from the pool of `Locals`. Allows writing to them, + /// but not reading from them anymore. 
+ fn remove_const(ecx: &mut InterpCx<'mir, 'tcx, ConstPropMachine<'mir, 'tcx>>, local: Local) { + ecx.frame_mut().locals[local] = LocalState { + value: LocalValue::Live(interpret::Operand::Immediate(interpret::Immediate::Uninit)), + layout: Cell::new(None), + }; + } + + fn lint_root(&self, source_info: SourceInfo) -> Option<HirId> { + source_info.scope.lint_root(self.source_scopes) + } + + fn use_ecx<F, T>(&mut self, source_info: SourceInfo, f: F) -> Option<T> + where + F: FnOnce(&mut Self) -> InterpResult<'tcx, T>, + { + // Overwrite the PC -- whatever the interpreter does to it does not make any sense anyway. + self.ecx.frame_mut().loc = Err(source_info.span); + match f(self) { + Ok(val) => Some(val), + Err(error) => { + trace!("InterpCx operation failed: {:?}", error); + // Some errors shouldn't come up because creating them causes + // an allocation, which we should avoid. When that happens, + // dedicated error variants should be introduced instead. + assert!( + !error.kind().formatted_string(), + "const-prop encountered formatting error: {}", + error + ); + None + } + } + } + + /// Returns the value, if any, of evaluating `c`. + fn eval_constant(&mut self, c: &Constant<'tcx>, source_info: SourceInfo) -> Option<OpTy<'tcx>> { + // FIXME we need to revisit this for #67176 + if c.needs_subst() { + return None; + } + + match self.ecx.mir_const_to_op(&c.literal, None) { + Ok(op) => Some(op), + Err(error) => { + let tcx = self.ecx.tcx.at(c.span); + let err = ConstEvalErr::new(&self.ecx, error, Some(c.span)); + if let Some(lint_root) = self.lint_root(source_info) { + let lint_only = match c.literal { + ConstantKind::Ty(ct) => match ct.kind() { + // Promoteds must lint and not error as the user didn't ask for them + ConstKind::Unevaluated(ty::Unevaluated { + def: _, + substs: _, + promoted: Some(_), + }) => true, + // Out of backwards compatibility we cannot report hard errors in unused + // generic functions using associated constants of the generic parameters. + _ => c.literal.needs_subst(), + }, + ConstantKind::Val(_, ty) => ty.needs_subst(), + }; + if lint_only { + // Out of backwards compatibility we cannot report hard errors in unused + // generic functions using associated constants of the generic parameters. + err.report_as_lint(tcx, "erroneous constant used", lint_root, Some(c.span)); + } else { + err.report_as_error(tcx, "erroneous constant used"); + } + } else { + err.report_as_error(tcx, "erroneous constant used"); + } + None + } + } + } + + /// Returns the value, if any, of evaluating `place`. + fn eval_place(&mut self, place: Place<'tcx>, source_info: SourceInfo) -> Option<OpTy<'tcx>> { + trace!("eval_place(place={:?})", place); + self.use_ecx(source_info, |this| this.ecx.eval_place_to_op(place, None)) + } + + /// Returns the value, if any, of evaluating `op`. Calls upon `eval_constant` + /// or `eval_place`, depending on the variant of `Operand` used. 
+ fn eval_operand(&mut self, op: &Operand<'tcx>, source_info: SourceInfo) -> Option<OpTy<'tcx>> { + match *op { + Operand::Constant(ref c) => self.eval_constant(c, source_info), + Operand::Move(place) | Operand::Copy(place) => self.eval_place(place, source_info), + } + } + + fn report_assert_as_lint( + &self, + lint: &'static lint::Lint, + source_info: SourceInfo, + message: &'static str, + panic: AssertKind<impl std::fmt::Debug>, + ) { + if let Some(lint_root) = self.lint_root(source_info) { + self.tcx.struct_span_lint_hir(lint, lint_root, source_info.span, |lint| { + let mut err = lint.build(message); + err.span_label(source_info.span, format!("{:?}", panic)); + err.emit(); + }); + } + } + + fn check_unary_op( + &mut self, + op: UnOp, + arg: &Operand<'tcx>, + source_info: SourceInfo, + ) -> Option<()> { + if let (val, true) = self.use_ecx(source_info, |this| { + let val = this.ecx.read_immediate(&this.ecx.eval_operand(arg, None)?)?; + let (_res, overflow, _ty) = this.ecx.overflowing_unary_op(op, &val)?; + Ok((val, overflow)) + })? { + // `AssertKind` only has an `OverflowNeg` variant, so make sure that is + // appropriate to use. + assert_eq!(op, UnOp::Neg, "Neg is the only UnOp that can overflow"); + self.report_assert_as_lint( + lint::builtin::ARITHMETIC_OVERFLOW, + source_info, + "this arithmetic operation will overflow", + AssertKind::OverflowNeg(val.to_const_int()), + ); + return None; + } + + Some(()) + } + + fn check_binary_op( + &mut self, + op: BinOp, + left: &Operand<'tcx>, + right: &Operand<'tcx>, + source_info: SourceInfo, + ) -> Option<()> { + let r = self.use_ecx(source_info, |this| { + this.ecx.read_immediate(&this.ecx.eval_operand(right, None)?) + }); + let l = self.use_ecx(source_info, |this| { + this.ecx.read_immediate(&this.ecx.eval_operand(left, None)?) + }); + // Check for exceeding shifts *even if* we cannot evaluate the LHS. + if op == BinOp::Shr || op == BinOp::Shl { + let r = r.clone()?; + // We need the type of the LHS. We cannot use `place_layout` as that is the type + // of the result, which for checked binops is not the same! + let left_ty = left.ty(self.local_decls, self.tcx); + let left_size = self.ecx.layout_of(left_ty).ok()?.size; + let right_size = r.layout.size; + let r_bits = r.to_scalar().ok(); + let r_bits = r_bits.and_then(|r| r.to_bits(right_size).ok()); + if r_bits.map_or(false, |b| b >= left_size.bits() as u128) { + debug!("check_binary_op: reporting assert for {:?}", source_info); + self.report_assert_as_lint( + lint::builtin::ARITHMETIC_OVERFLOW, + source_info, + "this arithmetic operation will overflow", + AssertKind::Overflow( + op, + match l { + Some(l) => l.to_const_int(), + // Invent a dummy value, the diagnostic ignores it anyway + None => ConstInt::new( + ScalarInt::try_from_uint(1_u8, left_size).unwrap(), + left_ty.is_signed(), + left_ty.is_ptr_sized_integral(), + ), + }, + r.to_const_int(), + ), + ); + return None; + } + } + + if let (Some(l), Some(r)) = (l, r) { + // The remaining operators are handled through `overflowing_binary_op`. + if self.use_ecx(source_info, |this| { + let (_res, overflow, _ty) = this.ecx.overflowing_binary_op(op, &l, &r)?; + Ok(overflow) + })? 
{ + self.report_assert_as_lint( + lint::builtin::ARITHMETIC_OVERFLOW, + source_info, + "this arithmetic operation will overflow", + AssertKind::Overflow(op, l.to_const_int(), r.to_const_int()), + ); + return None; + } + } + Some(()) + } + + fn const_prop( + &mut self, + rvalue: &Rvalue<'tcx>, + source_info: SourceInfo, + place: Place<'tcx>, + ) -> Option<()> { + // Perform any special handling for specific Rvalue types. + // Generally, checks here fall into one of two categories: + // 1. Additional checking to provide useful lints to the user + // - In this case, we will do some validation and then fall through to the + // end of the function which evals the assignment. + // 2. Working around bugs in other parts of the compiler + // - In this case, we'll return `None` from this function to stop evaluation. + match rvalue { + // Additional checking: give lints to the user if an overflow would occur. + // We do this here and not in the `Assert` terminator as that terminator is + // only sometimes emitted (overflow checks can be disabled), but we want to always + // lint. + Rvalue::UnaryOp(op, arg) => { + trace!("checking UnaryOp(op = {:?}, arg = {:?})", op, arg); + self.check_unary_op(*op, arg, source_info)?; + } + Rvalue::BinaryOp(op, box (left, right)) => { + trace!("checking BinaryOp(op = {:?}, left = {:?}, right = {:?})", op, left, right); + self.check_binary_op(*op, left, right, source_info)?; + } + Rvalue::CheckedBinaryOp(op, box (left, right)) => { + trace!( + "checking CheckedBinaryOp(op = {:?}, left = {:?}, right = {:?})", + op, + left, + right + ); + self.check_binary_op(*op, left, right, source_info)?; + } + + // Do not try creating references (#67862) + Rvalue::AddressOf(_, place) | Rvalue::Ref(_, _, place) => { + trace!("skipping AddressOf | Ref for {:?}", place); + + // This may be creating mutable references or immutable references to cells. + // If that happens, the pointed to value could be mutated via that reference. + // Since we aren't tracking references, the const propagator loses track of what + // value the local has right now. + // Thus, all locals that have their reference taken + // must not take part in propagation. + Self::remove_const(&mut self.ecx, place.local); + + return None; + } + Rvalue::ThreadLocalRef(def_id) => { + trace!("skipping ThreadLocalRef({:?})", def_id); + + return None; + } + + // There's no other checking to do at this time. + Rvalue::Aggregate(..) + | Rvalue::Use(..) + | Rvalue::CopyForDeref(..) + | Rvalue::Repeat(..) + | Rvalue::Len(..) + | Rvalue::Cast(..) + | Rvalue::ShallowInitBox(..) + | Rvalue::Discriminant(..) + | Rvalue::NullaryOp(..) 
=> {} + } + + // FIXME we need to revisit this for #67176 + if rvalue.needs_subst() { + return None; + } + + self.use_ecx(source_info, |this| this.ecx.eval_rvalue_into_place(rvalue, place)) + } +} + +impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> { + fn visit_body(&mut self, body: &Body<'tcx>) { + for (bb, data) in body.basic_blocks().iter_enumerated() { + self.visit_basic_block_data(bb, data); + } + } + + fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) { + self.super_operand(operand, location); + } + + fn visit_constant(&mut self, constant: &Constant<'tcx>, location: Location) { + trace!("visit_constant: {:?}", constant); + self.super_constant(constant, location); + self.eval_constant(constant, self.source_info.unwrap()); + } + + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + trace!("visit_statement: {:?}", statement); + let source_info = statement.source_info; + self.source_info = Some(source_info); + if let StatementKind::Assign(box (place, ref rval)) = statement.kind { + let can_const_prop = self.ecx.machine.can_const_prop[place.local]; + if let Some(()) = self.const_prop(rval, source_info, place) { + match can_const_prop { + ConstPropMode::OnlyInsideOwnBlock => { + trace!( + "found local restricted to its block. \ + Will remove it from const-prop after block is finished. Local: {:?}", + place.local + ); + } + ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => { + trace!("can't propagate into {:?}", place); + if place.local != RETURN_PLACE { + Self::remove_const(&mut self.ecx, place.local); + } + } + ConstPropMode::FullConstProp => {} + } + } else { + // Const prop failed, so erase the destination, ensuring that whatever happens + // from here on, does not know about the previous value. + // This is important in case we have + // ```rust + // let mut x = 42; + // x = SOME_MUTABLE_STATIC; + // // x must now be uninit + // ``` + // FIXME: we overzealously erase the entire local, because that's easier to + // implement. + trace!( + "propagation into {:?} failed. + Nuking the entire site from orbit, it's the only way to be sure", + place, + ); + Self::remove_const(&mut self.ecx, place.local); + } + } else { + match statement.kind { + StatementKind::SetDiscriminant { ref place, .. } => { + match self.ecx.machine.can_const_prop[place.local] { + ConstPropMode::FullConstProp | ConstPropMode::OnlyInsideOwnBlock => { + if self + .use_ecx(source_info, |this| this.ecx.statement(statement)) + .is_some() + { + trace!("propped discriminant into {:?}", place); + } else { + Self::remove_const(&mut self.ecx, place.local); + } + } + ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => { + Self::remove_const(&mut self.ecx, place.local); + } + } + } + StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => { + let frame = self.ecx.frame_mut(); + frame.locals[local].value = + if let StatementKind::StorageLive(_) = statement.kind { + LocalValue::Live(interpret::Operand::Immediate( + interpret::Immediate::Uninit, + )) + } else { + LocalValue::Dead + }; + } + _ => {} + } + } + + self.super_statement(statement, location); + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + let source_info = terminator.source_info; + self.source_info = Some(source_info); + self.super_terminator(terminator, location); + match &terminator.kind { + TerminatorKind::Assert { expected, ref msg, ref cond, .. 
} => { + if let Some(ref value) = self.eval_operand(&cond, source_info) { + trace!("assertion on {:?} should be {:?}", value, expected); + let expected = ScalarMaybeUninit::from(Scalar::from_bool(*expected)); + let value_const = self.ecx.read_scalar(&value).unwrap(); + if expected != value_const { + enum DbgVal<T> { + Val(T), + Underscore, + } + impl<T: std::fmt::Debug> std::fmt::Debug for DbgVal<T> { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Val(val) => val.fmt(fmt), + Self::Underscore => fmt.write_str("_"), + } + } + } + let mut eval_to_int = |op| { + // This can be `None` if the lhs wasn't const propagated and we just + // triggered the assert on the value of the rhs. + self.eval_operand(op, source_info).map_or(DbgVal::Underscore, |op| { + DbgVal::Val(self.ecx.read_immediate(&op).unwrap().to_const_int()) + }) + }; + let msg = match msg { + AssertKind::DivisionByZero(op) => { + Some(AssertKind::DivisionByZero(eval_to_int(op))) + } + AssertKind::RemainderByZero(op) => { + Some(AssertKind::RemainderByZero(eval_to_int(op))) + } + AssertKind::Overflow(bin_op @ (BinOp::Div | BinOp::Rem), op1, op2) => { + // Division overflow is *UB* in the MIR, and different than the + // other overflow checks. + Some(AssertKind::Overflow( + *bin_op, + eval_to_int(op1), + eval_to_int(op2), + )) + } + AssertKind::BoundsCheck { ref len, ref index } => { + let len = eval_to_int(len); + let index = eval_to_int(index); + Some(AssertKind::BoundsCheck { len, index }) + } + // Remaining overflow errors are already covered by checks on the binary operators. + AssertKind::Overflow(..) | AssertKind::OverflowNeg(_) => None, + // Need proper const propagator for these. + _ => None, + }; + // Poison all places this operand references so that further code + // doesn't use the invalid value + match cond { + Operand::Move(ref place) | Operand::Copy(ref place) => { + Self::remove_const(&mut self.ecx, place.local); + } + Operand::Constant(_) => {} + } + if let Some(msg) = msg { + self.report_assert_as_lint( + lint::builtin::UNCONDITIONAL_PANIC, + source_info, + "this operation will panic at runtime", + msg, + ); + } + } + } + } + // None of these have Operands to const-propagate. + TerminatorKind::Goto { .. } + | TerminatorKind::Resume + | TerminatorKind::Abort + | TerminatorKind::Return + | TerminatorKind::Unreachable + | TerminatorKind::Drop { .. } + | TerminatorKind::DropAndReplace { .. } + | TerminatorKind::Yield { .. } + | TerminatorKind::GeneratorDrop + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } + | TerminatorKind::SwitchInt { .. } + | TerminatorKind::Call { .. } + | TerminatorKind::InlineAsm { .. } => {} + } + + // We remove all Locals which are restricted in propagation to their containing blocks and + // which were modified in the current block. + // Take it out of the ecx so we can get a mutable reference to the ecx for `remove_const`. + let mut locals = std::mem::take(&mut self.ecx.machine.written_only_inside_own_block_locals); + for &local in locals.iter() { + Self::remove_const(&mut self.ecx, local); + } + locals.clear(); + // Put it back so we reuse the heap of the storage + self.ecx.machine.written_only_inside_own_block_locals = locals; + if cfg!(debug_assertions) { + // Ensure we are correctly erasing locals with the non-debug-assert logic. 
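+            // (Illustrative sketch, editorial; locals and values are hypothetical. For MIR like
+            //
+            //     bb0: _3 = const 7_i32;   // `_3` is `OnlyInsideOwnBlock`
+            //     bb1: _4 = Add(_3, _3);   // must not see `7` here
+            //
+            // the `remove_const` calls above guarantee that, once bb0's terminator has been
+            // visited, `_3` is unknown to the propagator again.)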
+ for local in self.ecx.machine.only_propagate_inside_block_locals.iter() { + assert!( + self.get_const(local.into()).is_none() + || self + .layout_of(self.local_decls[local].ty) + .map_or(true, |layout| layout.is_zst()) + ) + } + } + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/counters.rs b/compiler/rustc_mir_transform/src/coverage/counters.rs new file mode 100644 index 000000000..45de0c280 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/counters.rs @@ -0,0 +1,614 @@ +use super::Error; + +use super::debug; +use super::graph; +use super::spans; + +use debug::{DebugCounters, NESTED_INDENT}; +use graph::{BasicCoverageBlock, BcbBranch, CoverageGraph, TraverseCoverageGraphWithLoops}; +use spans::CoverageSpan; + +use rustc_data_structures::graph::WithNumNodes; +use rustc_index::bit_set::BitSet; +use rustc_middle::mir::coverage::*; + +/// Manages the counter and expression indexes/IDs to generate `CoverageKind` components for MIR +/// `Coverage` statements. +pub(super) struct CoverageCounters { + function_source_hash: u64, + next_counter_id: u32, + num_expressions: u32, + pub debug_counters: DebugCounters, +} + +impl CoverageCounters { + pub fn new(function_source_hash: u64) -> Self { + Self { + function_source_hash, + next_counter_id: CounterValueReference::START.as_u32(), + num_expressions: 0, + debug_counters: DebugCounters::new(), + } + } + + /// Activate the `DebugCounters` data structures, to provide additional debug formatting + /// features when formatting `CoverageKind` (counter) values. + pub fn enable_debug(&mut self) { + self.debug_counters.enable(); + } + + /// Makes `CoverageKind` `Counter`s and `Expressions` for the `BasicCoverageBlock`s directly or + /// indirectly associated with `CoverageSpans`, and returns additional `Expression`s + /// representing intermediate values. + pub fn make_bcb_counters( + &mut self, + basic_coverage_blocks: &mut CoverageGraph, + coverage_spans: &[CoverageSpan], + ) -> Result<Vec<CoverageKind>, Error> { + let mut bcb_counters = BcbCounters::new(self, basic_coverage_blocks); + bcb_counters.make_bcb_counters(coverage_spans) + } + + fn make_counter<F>(&mut self, debug_block_label_fn: F) -> CoverageKind + where + F: Fn() -> Option<String>, + { + let counter = CoverageKind::Counter { + function_source_hash: self.function_source_hash, + id: self.next_counter(), + }; + if self.debug_counters.is_enabled() { + self.debug_counters.add_counter(&counter, (debug_block_label_fn)()); + } + counter + } + + fn make_expression<F>( + &mut self, + lhs: ExpressionOperandId, + op: Op, + rhs: ExpressionOperandId, + debug_block_label_fn: F, + ) -> CoverageKind + where + F: Fn() -> Option<String>, + { + let id = self.next_expression(); + let expression = CoverageKind::Expression { id, lhs, op, rhs }; + if self.debug_counters.is_enabled() { + self.debug_counters.add_counter(&expression, (debug_block_label_fn)()); + } + expression + } + + pub fn make_identity_counter(&mut self, counter_operand: ExpressionOperandId) -> CoverageKind { + let some_debug_block_label = if self.debug_counters.is_enabled() { + self.debug_counters.some_block_label(counter_operand).cloned() + } else { + None + }; + self.make_expression(counter_operand, Op::Add, ExpressionOperandId::ZERO, || { + some_debug_block_label.clone() + }) + } + + /// Counter IDs start from one and go up. 
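+    // An illustrative sketch of the shared ID space (editorial; counts are hypothetical):
+    // after three counters and two expressions have been allocated, the operand IDs in use are
+    //
+    //     counters:    1, 2, 3                  (ascending from 1)
+    //     expressions: u32::MAX, u32::MAX - 1   (descending from u32::MAX)
+    //
+    // and the `assert!` in each method below guarantees the two ranges can never collide.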
+    fn next_counter(&mut self) -> CounterValueReference {
+        assert!(self.next_counter_id < u32::MAX - self.num_expressions);
+        let next = self.next_counter_id;
+        self.next_counter_id += 1;
+        CounterValueReference::from(next)
+    }
+
+    /// Expression IDs start from `u32::MAX` and go down because an Expression can reference
+    /// (adding or subtracting the counts of) both Counter regions and Expression regions. The
+    /// counter expression operand IDs must therefore be unique across both types.
+    fn next_expression(&mut self) -> InjectedExpressionId {
+        assert!(self.next_counter_id < u32::MAX - self.num_expressions);
+        let next = u32::MAX - self.num_expressions;
+        self.num_expressions += 1;
+        InjectedExpressionId::from(next)
+    }
+}
+
+/// Traverse the `CoverageGraph` and add either a `Counter` or `Expression` to every BCB, to be
+/// injected with `CoverageSpan`s. `Expressions` have no runtime overhead, so if a viable expression
+/// (adding or subtracting two other counters or expressions) can compute the same result as an
+/// embedded counter, an `Expression` should be used.
+struct BcbCounters<'a> {
+    coverage_counters: &'a mut CoverageCounters,
+    basic_coverage_blocks: &'a mut CoverageGraph,
+}
+
+impl<'a> BcbCounters<'a> {
+    fn new(
+        coverage_counters: &'a mut CoverageCounters,
+        basic_coverage_blocks: &'a mut CoverageGraph,
+    ) -> Self {
+        Self { coverage_counters, basic_coverage_blocks }
+    }
+
+    /// If two `BasicCoverageBlock`s branch from another `BasicCoverageBlock`, one of the branches
+    /// can be counted by an `Expression` that subtracts the other branch from the branching
+    /// block. Otherwise, the `BasicCoverageBlock` that is executed the least should have the
+    /// `Counter`. One way to predict which branch executes the least is by considering loops. A
+    /// loop is exited at a branch, so the branch that jumps to a `BasicCoverageBlock` outside the
+    /// loop is almost always executed less than the branch that does not exit the loop.
+    ///
+    /// Returns any non-code-span expressions created to represent intermediate values (such as to
+    /// add two counters so the result can be subtracted from another counter), or an Error with a
+    /// message for subsequent debugging.
+    fn make_bcb_counters(
+        &mut self,
+        coverage_spans: &[CoverageSpan],
+    ) -> Result<Vec<CoverageKind>, Error> {
+        debug!("make_bcb_counters(): adding a counter or expression to each BasicCoverageBlock");
+        let num_bcbs = self.basic_coverage_blocks.num_nodes();
+        let mut collect_intermediate_expressions = Vec::with_capacity(num_bcbs);
+
+        let mut bcbs_with_coverage = BitSet::new_empty(num_bcbs);
+        for covspan in coverage_spans {
+            bcbs_with_coverage.insert(covspan.bcb);
+        }
+
+        // Walk the `CoverageGraph`. For each `BasicCoverageBlock` node with an associated
+        // `CoverageSpan`, add a counter. If the `BasicCoverageBlock` branches, add a counter or
+        // expression to each branch `BasicCoverageBlock` (if the branch BCB has only one incoming
+        // edge) or to the edge from the branching BCB to the branch BCB (if the branch BCB has
+        // multiple incoming edges).
+        //
+        // The `TraverseCoverageGraphWithLoops` traversal ensures that, when a loop is encountered,
+        // all `BasicCoverageBlock` nodes in the loop are visited before visiting any node outside
+        // the loop. The `traversal` state includes a `context_stack`, providing a way to know if
+        // the current BCB is in one or more nested loops or not.
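+        // An illustrative sketch (editorial; `bcb` numbers are hypothetical): for a branching
+        // node `bcb0` with successors `bcb1` and `bcb2`, only one physical counter per branch
+        // pair is needed:
+        //
+        //     bcb0: Counter(c0)
+        //     bcb1: Counter(c1)           // the (likely) colder branch, e.g. a loop exit
+        //     bcb2: Expression(c0 - c1)   // the hotter reloop branch, derived at no runtime cost
+        //
+        // The loop-aware traversal below decides which branch gets the physical counter and
+        // which is derived by expression.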
+        let mut traversal = TraverseCoverageGraphWithLoops::new(&self.basic_coverage_blocks);
+        while let Some(bcb) = traversal.next(self.basic_coverage_blocks) {
+            if bcbs_with_coverage.contains(bcb) {
+                debug!("{:?} has at least one `CoverageSpan`. Get or make its counter", bcb);
+                let branching_counter_operand =
+                    self.get_or_make_counter_operand(bcb, &mut collect_intermediate_expressions)?;
+
+                if self.bcb_needs_branch_counters(bcb) {
+                    self.make_branch_counters(
+                        &mut traversal,
+                        bcb,
+                        branching_counter_operand,
+                        &mut collect_intermediate_expressions,
+                    )?;
+                }
+            } else {
+                debug!(
+                    "{:?} does not have any `CoverageSpan`s. A counter will only be added if \
+                    and when a covered BCB has an expression dependency.",
+                    bcb,
+                );
+            }
+        }
+
+        if traversal.is_complete() {
+            Ok(collect_intermediate_expressions)
+        } else {
+            Error::from_string(format!(
+                "`TraverseCoverageGraphWithLoops` missed some `BasicCoverageBlock`s: {:?}",
+                traversal.unvisited(),
+            ))
+        }
+    }
+
+    fn make_branch_counters(
+        &mut self,
+        traversal: &mut TraverseCoverageGraphWithLoops,
+        branching_bcb: BasicCoverageBlock,
+        branching_counter_operand: ExpressionOperandId,
+        collect_intermediate_expressions: &mut Vec<CoverageKind>,
+    ) -> Result<(), Error> {
+        let branches = self.bcb_branches(branching_bcb);
+        debug!(
+            "{:?} has some branch(es) without counters:\n  {}",
+            branching_bcb,
+            branches
+                .iter()
+                .map(|branch| {
+                    format!("{:?}: {:?}", branch, branch.counter(&self.basic_coverage_blocks))
+                })
+                .collect::<Vec<_>>()
+                .join("\n  "),
+        );
+
+        // Use the `traversal` state to decide if a subset of the branches exit a loop, making it
+        // likely that such a branch is executed less than branches that do not exit the same loop.
+        // In this case, any branch that does not exit the loop (and has not already been assigned a
+        // counter) should be counted by expression, if possible. (If a preferred expression branch
+        // is not selected based on the loop context, select any branch without an existing
+        // counter.)
+        let expression_branch = self.choose_preferred_expression_branch(traversal, &branches);
+
+        // Assign a Counter or Expression to each branch, plus additional `Expression`s, as needed,
+        // to sum up intermediate results.
+        let mut some_sumup_counter_operand = None;
+        for branch in branches {
+            // Skip the selected `expression_branch`, if any. Its expression will be assigned after
+            // all others.
+            if branch != expression_branch {
+                let branch_counter_operand = if branch.is_only_path_to_target() {
+                    debug!(
+                        "  {:?} has only one incoming edge (from {:?}), so adding a \
+                        counter",
+                        branch, branching_bcb
+                    );
+                    self.get_or_make_counter_operand(
+                        branch.target_bcb,
+                        collect_intermediate_expressions,
+                    )?
+                } else {
+                    debug!("  {:?} has multiple incoming edges, so adding an edge counter", branch);
+                    self.get_or_make_edge_counter_operand(
+                        branching_bcb,
+                        branch.target_bcb,
+                        collect_intermediate_expressions,
+                    )?
+ }; + if let Some(sumup_counter_operand) = + some_sumup_counter_operand.replace(branch_counter_operand) + { + let intermediate_expression = self.coverage_counters.make_expression( + branch_counter_operand, + Op::Add, + sumup_counter_operand, + || None, + ); + debug!( + " [new intermediate expression: {}]", + self.format_counter(&intermediate_expression) + ); + let intermediate_expression_operand = intermediate_expression.as_operand_id(); + collect_intermediate_expressions.push(intermediate_expression); + some_sumup_counter_operand.replace(intermediate_expression_operand); + } + } + } + + // Assign the final expression to the `expression_branch` by subtracting the total of all + // other branches from the counter of the branching BCB. + let sumup_counter_operand = + some_sumup_counter_operand.expect("sumup_counter_operand should have a value"); + debug!( + "Making an expression for the selected expression_branch: {:?} \ + (expression_branch predecessors: {:?})", + expression_branch, + self.bcb_predecessors(expression_branch.target_bcb), + ); + let expression = self.coverage_counters.make_expression( + branching_counter_operand, + Op::Subtract, + sumup_counter_operand, + || Some(format!("{:?}", expression_branch)), + ); + debug!("{:?} gets an expression: {}", expression_branch, self.format_counter(&expression)); + let bcb = expression_branch.target_bcb; + if expression_branch.is_only_path_to_target() { + self.basic_coverage_blocks[bcb].set_counter(expression)?; + } else { + self.basic_coverage_blocks[bcb].set_edge_counter_from(branching_bcb, expression)?; + } + Ok(()) + } + + fn get_or_make_counter_operand( + &mut self, + bcb: BasicCoverageBlock, + collect_intermediate_expressions: &mut Vec<CoverageKind>, + ) -> Result<ExpressionOperandId, Error> { + self.recursive_get_or_make_counter_operand(bcb, collect_intermediate_expressions, 1) + } + + fn recursive_get_or_make_counter_operand( + &mut self, + bcb: BasicCoverageBlock, + collect_intermediate_expressions: &mut Vec<CoverageKind>, + debug_indent_level: usize, + ) -> Result<ExpressionOperandId, Error> { + // If the BCB already has a counter, return it. + if let Some(counter_kind) = self.basic_coverage_blocks[bcb].counter() { + debug!( + "{}{:?} already has a counter: {}", + NESTED_INDENT.repeat(debug_indent_level), + bcb, + self.format_counter(counter_kind), + ); + return Ok(counter_kind.as_operand_id()); + } + + // A BCB with only one incoming edge gets a simple `Counter` (via `make_counter()`). + // Also, a BCB that loops back to itself gets a simple `Counter`. This may indicate the + // program results in a tight infinite loop, but it should still compile. + let one_path_to_target = self.bcb_has_one_path_to_target(bcb); + if one_path_to_target || self.bcb_predecessors(bcb).contains(&bcb) { + let counter_kind = self.coverage_counters.make_counter(|| Some(format!("{:?}", bcb))); + if one_path_to_target { + debug!( + "{}{:?} gets a new counter: {}", + NESTED_INDENT.repeat(debug_indent_level), + bcb, + self.format_counter(&counter_kind), + ); + } else { + debug!( + "{}{:?} has itself as its own predecessor. It can't be part of its own \ + Expression sum, so it will get its own new counter: {}. 
(Note, the compiled \ + code will generate an infinite loop.)", + NESTED_INDENT.repeat(debug_indent_level), + bcb, + self.format_counter(&counter_kind), + ); + } + return self.basic_coverage_blocks[bcb].set_counter(counter_kind); + } + + // A BCB with multiple incoming edges can compute its count by `Expression`, summing up the + // counters and/or expressions of its incoming edges. This will recursively get or create + // counters for those incoming edges first, then call `make_expression()` to sum them up, + // with additional intermediate expressions as needed. + let mut predecessors = self.bcb_predecessors(bcb).to_owned().into_iter(); + debug!( + "{}{:?} has multiple incoming edges and will get an expression that sums them up...", + NESTED_INDENT.repeat(debug_indent_level), + bcb, + ); + let first_edge_counter_operand = self.recursive_get_or_make_edge_counter_operand( + predecessors.next().unwrap(), + bcb, + collect_intermediate_expressions, + debug_indent_level + 1, + )?; + let mut some_sumup_edge_counter_operand = None; + for predecessor in predecessors { + let edge_counter_operand = self.recursive_get_or_make_edge_counter_operand( + predecessor, + bcb, + collect_intermediate_expressions, + debug_indent_level + 1, + )?; + if let Some(sumup_edge_counter_operand) = + some_sumup_edge_counter_operand.replace(edge_counter_operand) + { + let intermediate_expression = self.coverage_counters.make_expression( + sumup_edge_counter_operand, + Op::Add, + edge_counter_operand, + || None, + ); + debug!( + "{}new intermediate expression: {}", + NESTED_INDENT.repeat(debug_indent_level), + self.format_counter(&intermediate_expression) + ); + let intermediate_expression_operand = intermediate_expression.as_operand_id(); + collect_intermediate_expressions.push(intermediate_expression); + some_sumup_edge_counter_operand.replace(intermediate_expression_operand); + } + } + let counter_kind = self.coverage_counters.make_expression( + first_edge_counter_operand, + Op::Add, + some_sumup_edge_counter_operand.unwrap(), + || Some(format!("{:?}", bcb)), + ); + debug!( + "{}{:?} gets a new counter (sum of predecessor counters): {}", + NESTED_INDENT.repeat(debug_indent_level), + bcb, + self.format_counter(&counter_kind) + ); + self.basic_coverage_blocks[bcb].set_counter(counter_kind) + } + + fn get_or_make_edge_counter_operand( + &mut self, + from_bcb: BasicCoverageBlock, + to_bcb: BasicCoverageBlock, + collect_intermediate_expressions: &mut Vec<CoverageKind>, + ) -> Result<ExpressionOperandId, Error> { + self.recursive_get_or_make_edge_counter_operand( + from_bcb, + to_bcb, + collect_intermediate_expressions, + 1, + ) + } + + fn recursive_get_or_make_edge_counter_operand( + &mut self, + from_bcb: BasicCoverageBlock, + to_bcb: BasicCoverageBlock, + collect_intermediate_expressions: &mut Vec<CoverageKind>, + debug_indent_level: usize, + ) -> Result<ExpressionOperandId, Error> { + // If the source BCB has only one successor (assumed to be the given target), an edge + // counter is unnecessary. Just get or make a counter for the source BCB. + let successors = self.bcb_successors(from_bcb).iter(); + if successors.len() == 1 { + return self.recursive_get_or_make_counter_operand( + from_bcb, + collect_intermediate_expressions, + debug_indent_level + 1, + ); + } + + // If the edge already has a counter, return it. 
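+        // (Illustrative sketch, editorial; node numbers are hypothetical: if `bcb3` branches to
+        // both `bcb4` and `bcb5`, the count of the edge `bcb3->bcb5` cannot be recovered from the
+        // node counters alone, so the edge itself is counted:
+        //
+        //     bcb3: Counter(c3)
+        //     edge bcb3->bcb5: Counter(c4)      // eventually materialized as a new MIR BasicBlock
+        //     edge bcb3->bcb4: Expression(c3 - c4)
+        //
+        // whereas if `bcb3` had a single successor, its node counter would already count the
+        // edge, which is the early-return case above.)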
+        if let Some(counter_kind) = self.basic_coverage_blocks[to_bcb].edge_counter_from(from_bcb) {
+            debug!(
+                "{}Edge {:?}->{:?} already has a counter: {}",
+                NESTED_INDENT.repeat(debug_indent_level),
+                from_bcb,
+                to_bcb,
+                self.format_counter(counter_kind)
+            );
+            return Ok(counter_kind.as_operand_id());
+        }
+
+        // Make a new counter to count this edge.
+        let counter_kind =
+            self.coverage_counters.make_counter(|| Some(format!("{:?}->{:?}", from_bcb, to_bcb)));
+        debug!(
+            "{}Edge {:?}->{:?} gets a new counter: {}",
+            NESTED_INDENT.repeat(debug_indent_level),
+            from_bcb,
+            to_bcb,
+            self.format_counter(&counter_kind)
+        );
+        self.basic_coverage_blocks[to_bcb].set_edge_counter_from(from_bcb, counter_kind)
+    }
+
+    /// Select a branch for the expression, either the recommended `reloop_branch`, or if none was
+    /// found, select any branch.
+    fn choose_preferred_expression_branch(
+        &self,
+        traversal: &TraverseCoverageGraphWithLoops,
+        branches: &[BcbBranch],
+    ) -> BcbBranch {
+        let branch_needs_a_counter =
+            |branch: &BcbBranch| branch.counter(&self.basic_coverage_blocks).is_none();
+
+        let some_reloop_branch = self.find_some_reloop_branch(traversal, &branches);
+        if let Some(reloop_branch_without_counter) =
+            some_reloop_branch.filter(branch_needs_a_counter)
+        {
+            debug!(
+                "Selecting reloop_branch={:?} that still needs a counter, to get the \
+                `Expression`",
+                reloop_branch_without_counter
+            );
+            reloop_branch_without_counter
+        } else {
+            let &branch_without_counter = branches
+                .iter()
+                .find(|&&branch| branch.counter(&self.basic_coverage_blocks).is_none())
+                .expect(
+                    "needs_branch_counters was `true` so there should be at least one \
+                    branch",
+                );
+            debug!(
+                "Selecting any branch={:?} that still needs a counter, to get the \
+                `Expression` because there was no `reloop_branch`, or it already had a \
+                counter",
+                branch_without_counter
+            );
+            branch_without_counter
+        }
+    }
+
+    /// At most, one of the branches (or its edge, from the branching_bcb, if the branch has
+    /// multiple incoming edges) can have a counter computed by expression.
+    ///
+    /// If at least one of the branches leads outside of a loop (`found_loop_exit` is
+    /// true), and at least one other branch does not exit the loop (the first of which
+    /// is captured in `some_reloop_branch`), it's likely any reloop branch will be
+    /// executed far more often than the loop exit branch, making the reloop branch a better
+    /// candidate for an expression.
+    fn find_some_reloop_branch(
+        &self,
+        traversal: &TraverseCoverageGraphWithLoops,
+        branches: &[BcbBranch],
+    ) -> Option<BcbBranch> {
+        let branch_needs_a_counter =
+            |branch: &BcbBranch| branch.counter(&self.basic_coverage_blocks).is_none();
+
+        let mut some_reloop_branch: Option<BcbBranch> = None;
+        for context in traversal.context_stack.iter().rev() {
+            if let Some((backedge_from_bcbs, _)) = &context.loop_backedges {
+                let mut found_loop_exit = false;
+                for &branch in branches.iter() {
+                    if backedge_from_bcbs.iter().any(|&backedge_from_bcb| {
+                        self.bcb_is_dominated_by(backedge_from_bcb, branch.target_bcb)
+                    }) {
+                        if let Some(reloop_branch) = some_reloop_branch {
+                            if reloop_branch.counter(&self.basic_coverage_blocks).is_none() {
+                                // we already found a candidate reloop_branch that still
+                                // needs a counter
+                                continue;
+                            }
+                        }
+                        // The path from branch leads back to the top of the loop. Set this
+                        // branch as the `reloop_branch`. If this branch already has a
+                        // counter, and we find another reloop branch that doesn't have a
+                        // counter yet, that branch will be selected as the `reloop_branch`
+                        // instead.
+                        some_reloop_branch = Some(branch);
+                    } else {
+                        // The path from branch leads outside this loop
+                        found_loop_exit = true;
+                    }
+                    if found_loop_exit
+                        && some_reloop_branch.filter(branch_needs_a_counter).is_some()
+                    {
+                        // Found both a branch that exits the loop and a branch that returns
+                        // to the top of the loop (`reloop_branch`), and the `reloop_branch`
+                        // doesn't already have a counter.
+                        break;
+                    }
+                }
+                if !found_loop_exit {
+                    debug!(
+                        "No branches exit the loop, so any branch without an existing \
+                        counter can have the `Expression`."
+                    );
+                    break;
+                }
+                if some_reloop_branch.is_some() {
+                    debug!(
+                        "Found a branch that exits the loop and a branch that loops back to \
+                        the top of the loop (`reloop_branch`). The `reloop_branch` will \
+                        get the `Expression`, as long as it still needs a counter."
+                    );
+                    break;
+                }
+                // else all branches exited this loop context, so run the same checks with
+                // the outer loop(s)
+            }
+        }
+        some_reloop_branch
+    }
+
+    #[inline]
+    fn bcb_predecessors(&self, bcb: BasicCoverageBlock) -> &[BasicCoverageBlock] {
+        &self.basic_coverage_blocks.predecessors[bcb]
+    }
+
+    #[inline]
+    fn bcb_successors(&self, bcb: BasicCoverageBlock) -> &[BasicCoverageBlock] {
+        &self.basic_coverage_blocks.successors[bcb]
+    }
+
+    #[inline]
+    fn bcb_branches(&self, from_bcb: BasicCoverageBlock) -> Vec<BcbBranch> {
+        self.bcb_successors(from_bcb)
+            .iter()
+            .map(|&to_bcb| BcbBranch::from_to(from_bcb, to_bcb, &self.basic_coverage_blocks))
+            .collect::<Vec<_>>()
+    }
+
+    fn bcb_needs_branch_counters(&self, bcb: BasicCoverageBlock) -> bool {
+        let branch_needs_a_counter =
+            |branch: &BcbBranch| branch.counter(&self.basic_coverage_blocks).is_none();
+        let branches = self.bcb_branches(bcb);
+        branches.len() > 1 && branches.iter().any(branch_needs_a_counter)
+    }
+
+    /// Returns true if the BasicCoverageBlock has zero or one incoming edge. (If zero, it should
+    /// be the entry point for the function.)
+    #[inline]
+    fn bcb_has_one_path_to_target(&self, bcb: BasicCoverageBlock) -> bool {
+        self.bcb_predecessors(bcb).len() <= 1
+    }
+
+    #[inline]
+    fn bcb_is_dominated_by(&self, node: BasicCoverageBlock, dom: BasicCoverageBlock) -> bool {
+        self.basic_coverage_blocks.is_dominated_by(node, dom)
+    }
+
+    #[inline]
+    fn format_counter(&self, counter_kind: &CoverageKind) -> String {
+        self.coverage_counters.debug_counters.format_counter(counter_kind)
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/coverage/debug.rs b/compiler/rustc_mir_transform/src/coverage/debug.rs
new file mode 100644
index 000000000..0f8679b0b
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/coverage/debug.rs
@@ -0,0 +1,831 @@
+//! The `InstrumentCoverage` MIR pass implementation includes debugging tools and options
+//! to help developers understand and/or improve the analysis and instrumentation of a MIR.
+//!
+//! To enable coverage, include the rustc command line option:
+//!
+//!   * `-C instrument-coverage`
+//!
+//! MIR Dump Files, with additional `CoverageGraph` graphviz and `CoverageSpan` spanview
+//! ------------------------------------------------------------------------------------
+//!
+//! Additional debugging options include:
+//!
+//!   * `-Z dump-mir=InstrumentCoverage` - Generate `.mir` files showing the state of the MIR,
+//!     before and after the `InstrumentCoverage` pass, for each compiled function.
+//!
+//!   * `-Z dump-mir-graphviz` - If `-Z dump-mir` is also enabled for the current MIR node path,
+//!     each MIR dump is accompanied by a before-and-after graphical view of the MIR, in Graphviz
+//!     `.dot` file format (which can be visually rendered as a graph using any of a number of free
+//!     Graphviz viewers and IDE extensions).
+//!
+//!     For the `InstrumentCoverage` pass, this option also enables generation of an additional
+//!     Graphviz `.dot` file for each function, rendering the `CoverageGraph`: the control flow
+//!     graph (CFG) of `BasicCoverageBlocks` (BCBs), as nodes, internally labeled to show the
+//!     `CoverageSpan`-based MIR elements each BCB represents (`BasicBlock`s, `Statement`s and
+//!     `Terminator`s), assigned coverage counters and/or expressions, and edge counters, as
+//!     needed.
+//!
+//!     (Note the additional option, `-Z graphviz-dark-mode`, can be added to change the rendered
+//!     output from its default black-on-white background to a dark color theme, if desired.)
+//!
+//!   * `-Z dump-mir-spanview` - If `-Z dump-mir` is also enabled for the current MIR node path,
+//!     each MIR dump is accompanied by a before-and-after `.html` document showing the function's
+//!     original source code, highlighted by its MIR spans, at the `statement`-level (by default),
+//!     `terminator` only, or the encompassing span for the `Terminator` plus all `Statement`s, in
+//!     each `block` (`BasicBlock`).
+//!
+//!     For the `InstrumentCoverage` pass, this option also enables generation of an additional
+//!     spanview `.html` file for each function, showing the aggregated `CoverageSpan`s that will
+//!     require counters (or counter expressions) for accurate coverage analysis.
+//!
+//! Debug Logging
+//! -------------
+//!
+//! The `InstrumentCoverage` pass includes debug logging messages at various phases and decision
+//! points, which can be enabled via environment variable:
+//!
+//! ```shell
+//! RUSTC_LOG=rustc_mir_transform::transform::coverage=debug
+//! ```
+//!
+//! Other module paths with coverage-related debug logs may also be of interest, particularly for
+//! debugging the coverage map data, injected as global variables in the LLVM IR (during rustc's
+//! code generation pass). For example:
+//!
+//! ```shell
+//! RUSTC_LOG=rustc_mir_transform::transform::coverage,rustc_codegen_ssa::coverageinfo,rustc_codegen_llvm::coverageinfo=debug
+//! ```
+//!
+//! Coverage Debug Options
+//! ----------------------
+//!
+//! Additional debugging options can be enabled using the environment variable:
+//!
+//! ```shell
+//! RUSTC_COVERAGE_DEBUG_OPTIONS=<options>
+//! ```
+//!
+//! These options are comma-separated, and specified in the format `option-name=value`. For
+//! example:
+//!
+//! ```shell
+//! $ RUSTC_COVERAGE_DEBUG_OPTIONS=counter-format=id+operation,allow-unused-expressions=yes cargo build
+//! ```
+//!
+//! Coverage debug options include:
+//!
+//!   * `allow-unused-expressions=yes` or `no` (default: `no`)
+//!
+//!     The `InstrumentCoverage` algorithms _should_ only create and assign expressions to a
+//!     `BasicCoverageBlock`, or an incoming edge, if that expression is either (a) required to
+//!     count a `CoverageSpan`, or (b) a dependency of some other required counter expression.
+//!
+//!     If an expression is generated that does not map to a `CoverageSpan` or dependency, this
+//!     probably indicates there was a bug in the algorithm that creates and assigns counters
+//!     and expressions.
+//!
+//!     When this kind of bug is encountered, the rustc compiler will panic by default. Setting
+//!     `allow-unused-expressions=yes` will log a warning message instead of panicking (effectively
+//!     ignoring the unused expressions), which may be helpful when debugging the root cause of
+//!     the problem.
+//!
+//!   * `counter-format=<choices>`, where `<choices>` can be any plus-separated combination of
+//!     `id`, `block`, and/or `operation` (default: `block+operation`)
+//!
+//!     This option affects both the `CoverageGraph` (graphviz `.dot` files) and debug logging,
+//!     when generating labels for counters and expressions.
+//!
+//!     Depending on the values and combinations, counters can be labeled by:
+//!
+//!       * `id` - counter or expression ID (ascending counter IDs, starting at 1, or descending
+//!         expression IDs, starting at `u32::MAX`)
+//!       * `block` - the `BasicCoverageBlock` label (for example, `bcb0`) or edge label (for
+//!         example `bcb0->bcb1`), for counters or expressions assigned to count a
+//!         `BasicCoverageBlock` or edge. Intermediate expressions (not directly associated with
+//!         a BCB or edge) will be labeled by their expression ID, unless `operation` is also
+//!         specified.
+//!       * `operation` - applied to expressions only, labels include the left-hand-side counter
+//!         or expression label (lhs operand), the operator (`+` or `-`), and the right-hand-side
+//!         counter or expression (rhs operand). Expression operand labels are generated
+//!         recursively, generating labels with nested operations, enclosed in parentheses
+//!         (for example: `bcb2 + (bcb0 - bcb1)`).
+
+use super::graph::{BasicCoverageBlock, BasicCoverageBlockData, CoverageGraph};
+use super::spans::CoverageSpan;
+
+use itertools::Itertools;
+use rustc_middle::mir::create_dump_file;
+use rustc_middle::mir::generic_graphviz::GraphvizWriter;
+use rustc_middle::mir::spanview::{self, SpanViewable};
+
+use rustc_data_structures::fx::FxHashMap;
+use rustc_middle::mir::coverage::*;
+use rustc_middle::mir::{self, BasicBlock, TerminatorKind};
+use rustc_middle::ty::TyCtxt;
+use rustc_span::Span;
+
+use std::iter;
+use std::ops::Deref;
+use std::sync::OnceLock;
+
+pub const NESTED_INDENT: &str = "    ";
+
+const RUSTC_COVERAGE_DEBUG_OPTIONS: &str = "RUSTC_COVERAGE_DEBUG_OPTIONS";
+
+pub(super) fn debug_options<'a>() -> &'a DebugOptions {
+    static DEBUG_OPTIONS: OnceLock<DebugOptions> = OnceLock::new();
+
+    &DEBUG_OPTIONS.get_or_init(DebugOptions::from_env)
+}
+
+/// Parses and maintains coverage-specific debug options captured from the environment variable
+/// "RUSTC_COVERAGE_DEBUG_OPTIONS", if set.
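+// An illustrative sketch (editorial; the values are hypothetical): the parser below strips
+// spaces and normalizes `-` to `_`, so
+//
+//     RUSTC_COVERAGE_DEBUG_OPTIONS=counter-format=id+operation,allow-unused-expressions=yes
+//
+// is read as two settings: `counter_format=id+operation` (only the `id` and `operation` label
+// components enabled) and `allow_unused_expressions=yes` (unused expressions log a warning
+// instead of triggering a panic).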
+#[derive(Debug, Clone)]
+pub(super) struct DebugOptions {
+    pub allow_unused_expressions: bool,
+    counter_format: ExpressionFormat,
+}
+
+impl DebugOptions {
+    fn from_env() -> Self {
+        let mut allow_unused_expressions = true;
+        let mut counter_format = ExpressionFormat::default();
+
+        if let Ok(env_debug_options) = std::env::var(RUSTC_COVERAGE_DEBUG_OPTIONS) {
+            for setting_str in env_debug_options.replace(' ', "").replace('-', "_").split(',') {
+                let (option, value) = match setting_str.split_once('=') {
+                    None => (setting_str, None),
+                    Some((k, v)) => (k, Some(v)),
+                };
+                match option {
+                    "allow_unused_expressions" => {
+                        allow_unused_expressions = bool_option_val(option, value);
+                        debug!(
+                            "{} env option `allow_unused_expressions` is set to {}",
+                            RUSTC_COVERAGE_DEBUG_OPTIONS, allow_unused_expressions
+                        );
+                    }
+                    "counter_format" => {
+                        match value {
+                            None => {
+                                bug!(
+                                    "`{}` option in environment variable {} requires one or more \
+                                    plus-separated choices (a non-empty subset of \
+                                    `id+block+operation`)",
+                                    option,
+                                    RUSTC_COVERAGE_DEBUG_OPTIONS
+                                );
+                            }
+                            Some(val) => {
+                                counter_format = counter_format_option_val(val);
+                                debug!(
+                                    "{} env option `counter_format` is set to {:?}",
+                                    RUSTC_COVERAGE_DEBUG_OPTIONS, counter_format
+                                );
+                            }
+                        };
+                    }
+                    _ => bug!(
+                        "Unsupported setting `{}` in environment variable {}",
+                        option,
+                        RUSTC_COVERAGE_DEBUG_OPTIONS
+                    ),
+                };
+            }
+        }
+
+        Self { allow_unused_expressions, counter_format }
+    }
+}
+
+fn bool_option_val(option: &str, some_strval: Option<&str>) -> bool {
+    if let Some(val) = some_strval {
+        if vec!["yes", "y", "on", "true"].contains(&val) {
+            true
+        } else if vec!["no", "n", "off", "false"].contains(&val) {
+            false
+        } else {
+            bug!(
+                "Unsupported value `{}` for option `{}` in environment variable {}",
+                val,
+                option,
+                RUSTC_COVERAGE_DEBUG_OPTIONS
+            )
+        }
+    } else {
+        true
+    }
+}
+
+fn counter_format_option_val(strval: &str) -> ExpressionFormat {
+    let mut counter_format = ExpressionFormat { id: false, block: false, operation: false };
+    let components = strval.splitn(3, '+');
+    for component in components {
+        match component {
+            "id" => counter_format.id = true,
+            "block" => counter_format.block = true,
+            "operation" => counter_format.operation = true,
+            _ => bug!(
+                "Unsupported counter_format choice `{}` in environment variable {}",
+                component,
+                RUSTC_COVERAGE_DEBUG_OPTIONS
+            ),
+        }
+    }
+    counter_format
+}
+
+#[derive(Debug, Clone)]
+struct ExpressionFormat {
+    id: bool,
+    block: bool,
+    operation: bool,
+}
+
+impl Default for ExpressionFormat {
+    fn default() -> Self {
+        Self { id: false, block: true, operation: true }
+    }
+}
+
+/// If enabled, this struct maintains a map from `CoverageKind` IDs (as `ExpressionOperandId`) to
+/// the `CoverageKind` data and optional label (normally, the counter's associated
+/// `BasicCoverageBlock` format string, if any).
+///
+/// Use `format_counter` to convert one of these `CoverageKind` counters to a debug output string,
+/// as directed by the `DebugOptions`. This allows the format of counter labels in logs and dump
+/// files (including the `CoverageGraph` graphviz file) to be changed at runtime, via environment
+/// variable.
+///
+/// `DebugCounters` supports a recursive rendering of `Expression` counters, so they can be
+/// presented as nested expressions such as `(bcb3 - (bcb0 + bcb1))`.
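+// For example (editorial sketch; the labels and IDs are hypothetical): with the default
+// `counter-format=block+operation`, an expression might render as `Expression(bcb0 - bcb1)`,
+// or with a nested expression operand as `Expression(bcb2:(bcb0 - bcb1) + bcb3)`, while
+// `counter-format=id` would render the same expression as `Expression(#3)`.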
+pub(super) struct DebugCounters {
+    some_counters: Option<FxHashMap<ExpressionOperandId, DebugCounter>>,
+}
+
+impl DebugCounters {
+    pub fn new() -> Self {
+        Self { some_counters: None }
+    }
+
+    pub fn enable(&mut self) {
+        debug_assert!(!self.is_enabled());
+        self.some_counters.replace(FxHashMap::default());
+    }
+
+    pub fn is_enabled(&self) -> bool {
+        self.some_counters.is_some()
+    }
+
+    pub fn add_counter(&mut self, counter_kind: &CoverageKind, some_block_label: Option<String>) {
+        if let Some(counters) = &mut self.some_counters {
+            let id: ExpressionOperandId = match *counter_kind {
+                CoverageKind::Counter { id, .. } => id.into(),
+                CoverageKind::Expression { id, .. } => id.into(),
+                _ => bug!(
+                    "the given `CoverageKind` is not a counter or expression: {:?}",
+                    counter_kind
+                ),
+            };
+            counters
+                .try_insert(id, DebugCounter::new(counter_kind.clone(), some_block_label))
+                .expect("attempt to add the same counter_kind to DebugCounters more than once");
+        }
+    }
+
+    pub fn some_block_label(&self, operand: ExpressionOperandId) -> Option<&String> {
+        self.some_counters.as_ref().map_or(None, |counters| {
+            counters
+                .get(&operand)
+                .map_or(None, |debug_counter| debug_counter.some_block_label.as_ref())
+        })
+    }
+
+    pub fn format_counter(&self, counter_kind: &CoverageKind) -> String {
+        match *counter_kind {
+            CoverageKind::Counter { .. } => {
+                format!("Counter({})", self.format_counter_kind(counter_kind))
+            }
+            CoverageKind::Expression { .. } => {
+                format!("Expression({})", self.format_counter_kind(counter_kind))
+            }
+            CoverageKind::Unreachable { .. } => "Unreachable".to_owned(),
+        }
+    }
+
+    fn format_counter_kind(&self, counter_kind: &CoverageKind) -> String {
+        let counter_format = &debug_options().counter_format;
+        if let CoverageKind::Expression { id, lhs, op, rhs } = *counter_kind {
+            if counter_format.operation {
+                return format!(
+                    "{}{} {} {}",
+                    if counter_format.id || self.some_counters.is_none() {
+                        format!("#{} = ", id.index())
+                    } else {
+                        String::new()
+                    },
+                    self.format_operand(lhs),
+                    if op == Op::Add { "+" } else { "-" },
+                    self.format_operand(rhs),
+                );
+            }
+        }
+
+        let id: ExpressionOperandId = match *counter_kind {
+            CoverageKind::Counter { id, .. } => id.into(),
+            CoverageKind::Expression { id, .. } => id.into(),
+            _ => {
+                bug!("the given `CoverageKind` is not a counter or expression: {:?}", counter_kind)
+            }
+        };
+        if self.some_counters.is_some() && (counter_format.block || !counter_format.id) {
+            let counters = self.some_counters.as_ref().unwrap();
+            if let Some(DebugCounter { some_block_label: Some(block_label), .. }) =
+                counters.get(&id)
+            {
+                return if counter_format.id {
+                    format!("{}#{}", block_label, id.index())
+                } else {
+                    block_label.to_string()
+                };
+            }
+        }
+        format!("#{}", id.index())
+    }
+
+    fn format_operand(&self, operand: ExpressionOperandId) -> String {
+        if operand.index() == 0 {
+            return String::from("0");
+        }
+        if let Some(counters) = &self.some_counters {
+            if let Some(DebugCounter { counter_kind, some_block_label }) = counters.get(&operand) {
+                if let CoverageKind::Expression { .. } = counter_kind {
+                    if let Some(label) = some_block_label && debug_options().counter_format.block {
+                        return format!(
+                            "{}:({})",
+                            label,
+                            self.format_counter_kind(counter_kind)
+                        );
+                    }
+                    return format!("({})", self.format_counter_kind(counter_kind));
+                }
+                return self.format_counter_kind(counter_kind);
+            }
+        }
+        format!("#{}", operand.index())
+    }
+}
+
+/// A non-public support struct for `DebugCounters`.
+#[derive(Debug)] +struct DebugCounter { + counter_kind: CoverageKind, + some_block_label: Option<String>, +} + +impl DebugCounter { + fn new(counter_kind: CoverageKind, some_block_label: Option<String>) -> Self { + Self { counter_kind, some_block_label } + } +} + +/// If enabled, this data structure captures additional debugging information used when generating +/// a Graphviz (.dot file) representation of the `CoverageGraph`, for debugging purposes. +pub(super) struct GraphvizData { + some_bcb_to_coverage_spans_with_counters: + Option<FxHashMap<BasicCoverageBlock, Vec<(CoverageSpan, CoverageKind)>>>, + some_bcb_to_dependency_counters: Option<FxHashMap<BasicCoverageBlock, Vec<CoverageKind>>>, + some_edge_to_counter: Option<FxHashMap<(BasicCoverageBlock, BasicBlock), CoverageKind>>, +} + +impl GraphvizData { + pub fn new() -> Self { + Self { + some_bcb_to_coverage_spans_with_counters: None, + some_bcb_to_dependency_counters: None, + some_edge_to_counter: None, + } + } + + pub fn enable(&mut self) { + debug_assert!(!self.is_enabled()); + self.some_bcb_to_coverage_spans_with_counters = Some(FxHashMap::default()); + self.some_bcb_to_dependency_counters = Some(FxHashMap::default()); + self.some_edge_to_counter = Some(FxHashMap::default()); + } + + pub fn is_enabled(&self) -> bool { + self.some_bcb_to_coverage_spans_with_counters.is_some() + } + + pub fn add_bcb_coverage_span_with_counter( + &mut self, + bcb: BasicCoverageBlock, + coverage_span: &CoverageSpan, + counter_kind: &CoverageKind, + ) { + if let Some(bcb_to_coverage_spans_with_counters) = + self.some_bcb_to_coverage_spans_with_counters.as_mut() + { + bcb_to_coverage_spans_with_counters + .entry(bcb) + .or_insert_with(Vec::new) + .push((coverage_span.clone(), counter_kind.clone())); + } + } + + pub fn get_bcb_coverage_spans_with_counters( + &self, + bcb: BasicCoverageBlock, + ) -> Option<&[(CoverageSpan, CoverageKind)]> { + if let Some(bcb_to_coverage_spans_with_counters) = + self.some_bcb_to_coverage_spans_with_counters.as_ref() + { + bcb_to_coverage_spans_with_counters.get(&bcb).map(Deref::deref) + } else { + None + } + } + + pub fn add_bcb_dependency_counter( + &mut self, + bcb: BasicCoverageBlock, + counter_kind: &CoverageKind, + ) { + if let Some(bcb_to_dependency_counters) = self.some_bcb_to_dependency_counters.as_mut() { + bcb_to_dependency_counters + .entry(bcb) + .or_insert_with(Vec::new) + .push(counter_kind.clone()); + } + } + + pub fn get_bcb_dependency_counters(&self, bcb: BasicCoverageBlock) -> Option<&[CoverageKind]> { + if let Some(bcb_to_dependency_counters) = self.some_bcb_to_dependency_counters.as_ref() { + bcb_to_dependency_counters.get(&bcb).map(Deref::deref) + } else { + None + } + } + + pub fn set_edge_counter( + &mut self, + from_bcb: BasicCoverageBlock, + to_bb: BasicBlock, + counter_kind: &CoverageKind, + ) { + if let Some(edge_to_counter) = self.some_edge_to_counter.as_mut() { + edge_to_counter + .try_insert((from_bcb, to_bb), counter_kind.clone()) + .expect("invalid attempt to insert more than one edge counter for the same edge"); + } + } + + pub fn get_edge_counter( + &self, + from_bcb: BasicCoverageBlock, + to_bb: BasicBlock, + ) -> Option<&CoverageKind> { + if let Some(edge_to_counter) = self.some_edge_to_counter.as_ref() { + edge_to_counter.get(&(from_bcb, to_bb)) + } else { + None + } + } +} + +/// If enabled, this struct captures additional data used to track whether expressions were used, +/// directly or indirectly, to compute the coverage counts for all `CoverageSpan`s, and any that are +/// _not_ 
used are retained in the `unused_expressions` Vec, to be included in debug output (logs +/// and/or a `CoverageGraph` graphviz output). +pub(super) struct UsedExpressions { + some_used_expression_operands: + Option<FxHashMap<ExpressionOperandId, Vec<InjectedExpressionId>>>, + some_unused_expressions: + Option<Vec<(CoverageKind, Option<BasicCoverageBlock>, BasicCoverageBlock)>>, +} + +impl UsedExpressions { + pub fn new() -> Self { + Self { some_used_expression_operands: None, some_unused_expressions: None } + } + + pub fn enable(&mut self) { + debug_assert!(!self.is_enabled()); + self.some_used_expression_operands = Some(FxHashMap::default()); + self.some_unused_expressions = Some(Vec::new()); + } + + pub fn is_enabled(&self) -> bool { + self.some_used_expression_operands.is_some() + } + + pub fn add_expression_operands(&mut self, expression: &CoverageKind) { + if let Some(used_expression_operands) = self.some_used_expression_operands.as_mut() { + if let CoverageKind::Expression { id, lhs, rhs, .. } = *expression { + used_expression_operands.entry(lhs).or_insert_with(Vec::new).push(id); + used_expression_operands.entry(rhs).or_insert_with(Vec::new).push(id); + } + } + } + + pub fn expression_is_used(&self, expression: &CoverageKind) -> bool { + if let Some(used_expression_operands) = self.some_used_expression_operands.as_ref() { + used_expression_operands.contains_key(&expression.as_operand_id()) + } else { + false + } + } + + pub fn add_unused_expression_if_not_found( + &mut self, + expression: &CoverageKind, + edge_from_bcb: Option<BasicCoverageBlock>, + target_bcb: BasicCoverageBlock, + ) { + if let Some(used_expression_operands) = self.some_used_expression_operands.as_ref() { + if !used_expression_operands.contains_key(&expression.as_operand_id()) { + self.some_unused_expressions.as_mut().unwrap().push(( + expression.clone(), + edge_from_bcb, + target_bcb, + )); + } + } + } + + /// Return the list of unused counters (if any) as a tuple with the counter (`CoverageKind`), + /// optional `from_bcb` (if it was an edge counter), and `target_bcb`. + pub fn get_unused_expressions( + &self, + ) -> Vec<(CoverageKind, Option<BasicCoverageBlock>, BasicCoverageBlock)> { + if let Some(unused_expressions) = self.some_unused_expressions.as_ref() { + unused_expressions.clone() + } else { + Vec::new() + } + } + + /// If enabled, validate that every BCB or edge counter not directly associated with a coverage + /// span is at least indirectly associated (it is a dependency of a BCB counter that _is_ + /// associated with a coverage span). 
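+    // A sketch of the fixed point reached below (editorial; the IDs are hypothetical): if an
+    // expression `e1 = c0 + e2` is attached to a coverage span, the first sweep validates `e2`
+    // (it is an operand of the used `e1`) and records `e2`'s own operands as used; a later sweep
+    // then validates anything `e2` depends on. Iteration stops once a full sweep validates
+    // nothing new, leaving only genuinely unused expressions in the worklist.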
+ pub fn validate( + &mut self, + bcb_counters_without_direct_coverage_spans: &[( + Option<BasicCoverageBlock>, + BasicCoverageBlock, + CoverageKind, + )], + ) { + if self.is_enabled() { + let mut not_validated = bcb_counters_without_direct_coverage_spans + .iter() + .map(|(_, _, counter_kind)| counter_kind) + .collect::<Vec<_>>(); + let mut validating_count = 0; + while not_validated.len() != validating_count { + let to_validate = not_validated.split_off(0); + validating_count = to_validate.len(); + for counter_kind in to_validate { + if self.expression_is_used(counter_kind) { + self.add_expression_operands(counter_kind); + } else { + not_validated.push(counter_kind); + } + } + } + } + } + + pub fn alert_on_unused_expressions(&self, debug_counters: &DebugCounters) { + if let Some(unused_expressions) = self.some_unused_expressions.as_ref() { + for (counter_kind, edge_from_bcb, target_bcb) in unused_expressions { + let unused_counter_message = if let Some(from_bcb) = edge_from_bcb.as_ref() { + format!( + "non-coverage edge counter found without a dependent expression, in \ + {:?}->{:?}; counter={}", + from_bcb, + target_bcb, + debug_counters.format_counter(&counter_kind), + ) + } else { + format!( + "non-coverage counter found without a dependent expression, in {:?}; \ + counter={}", + target_bcb, + debug_counters.format_counter(&counter_kind), + ) + }; + + if debug_options().allow_unused_expressions { + debug!("WARNING: {}", unused_counter_message); + } else { + bug!("{}", unused_counter_message); + } + } + } + } +} + +/// Generates the MIR pass `CoverageSpan`-specific spanview dump file. +pub(super) fn dump_coverage_spanview<'tcx>( + tcx: TyCtxt<'tcx>, + mir_body: &mir::Body<'tcx>, + basic_coverage_blocks: &CoverageGraph, + pass_name: &str, + body_span: Span, + coverage_spans: &[CoverageSpan], +) { + let mir_source = mir_body.source; + let def_id = mir_source.def_id(); + + let span_viewables = span_viewables(tcx, mir_body, basic_coverage_blocks, &coverage_spans); + let mut file = create_dump_file(tcx, "html", None, pass_name, &0, mir_source) + .expect("Unexpected error creating MIR spanview HTML file"); + let crate_name = tcx.crate_name(def_id.krate); + let item_name = tcx.def_path(def_id).to_filename_friendly_no_crate(); + let title = format!("{}.{} - Coverage Spans", crate_name, item_name); + spanview::write_document(tcx, body_span, span_viewables, &title, &mut file) + .expect("Unexpected IO error dumping coverage spans as HTML"); +} + +/// Converts the computed `BasicCoverageBlockData`s into `SpanViewable`s. +fn span_viewables<'tcx>( + tcx: TyCtxt<'tcx>, + mir_body: &mir::Body<'tcx>, + basic_coverage_blocks: &CoverageGraph, + coverage_spans: &[CoverageSpan], +) -> Vec<SpanViewable> { + let mut span_viewables = Vec::new(); + for coverage_span in coverage_spans { + let tooltip = coverage_span.format_coverage_statements(tcx, mir_body); + let CoverageSpan { span, bcb, .. } = coverage_span; + let bcb_data = &basic_coverage_blocks[*bcb]; + let id = bcb_data.id(); + let leader_bb = bcb_data.leader_bb(); + span_viewables.push(SpanViewable { bb: leader_bb, span: *span, id, tooltip }); + } + span_viewables +} + +/// Generates the MIR pass coverage-specific graphviz dump file. 
+pub(super) fn dump_coverage_graphviz<'tcx>( + tcx: TyCtxt<'tcx>, + mir_body: &mir::Body<'tcx>, + pass_name: &str, + basic_coverage_blocks: &CoverageGraph, + debug_counters: &DebugCounters, + graphviz_data: &GraphvizData, + intermediate_expressions: &[CoverageKind], + debug_used_expressions: &UsedExpressions, +) { + let mir_source = mir_body.source; + let def_id = mir_source.def_id(); + let node_content = |bcb| { + bcb_to_string_sections( + tcx, + mir_body, + debug_counters, + &basic_coverage_blocks[bcb], + graphviz_data.get_bcb_coverage_spans_with_counters(bcb), + graphviz_data.get_bcb_dependency_counters(bcb), + // intermediate_expressions are injected into the mir::START_BLOCK, so + // include them in the first BCB. + if bcb.index() == 0 { Some(&intermediate_expressions) } else { None }, + ) + }; + let edge_labels = |from_bcb| { + let from_bcb_data = &basic_coverage_blocks[from_bcb]; + let from_terminator = from_bcb_data.terminator(mir_body); + let mut edge_labels = from_terminator.kind.fmt_successor_labels(); + edge_labels.retain(|label| label != "unreachable"); + let edge_counters = from_terminator + .successors() + .map(|successor_bb| graphviz_data.get_edge_counter(from_bcb, successor_bb)); + iter::zip(&edge_labels, edge_counters) + .map(|(label, some_counter)| { + if let Some(counter) = some_counter { + format!("{}\n{}", label, debug_counters.format_counter(counter)) + } else { + label.to_string() + } + }) + .collect::<Vec<_>>() + }; + let graphviz_name = format!("Cov_{}_{}", def_id.krate.index(), def_id.index.index()); + let mut graphviz_writer = + GraphvizWriter::new(basic_coverage_blocks, &graphviz_name, node_content, edge_labels); + let unused_expressions = debug_used_expressions.get_unused_expressions(); + if unused_expressions.len() > 0 { + graphviz_writer.set_graph_label(&format!( + "Unused expressions:\n {}", + unused_expressions + .as_slice() + .iter() + .map(|(counter_kind, edge_from_bcb, target_bcb)| { + if let Some(from_bcb) = edge_from_bcb.as_ref() { + format!( + "{:?}->{:?}: {}", + from_bcb, + target_bcb, + debug_counters.format_counter(&counter_kind), + ) + } else { + format!( + "{:?}: {}", + target_bcb, + debug_counters.format_counter(&counter_kind), + ) + } + }) + .join("\n ") + )); + } + let mut file = create_dump_file(tcx, "dot", None, pass_name, &0, mir_source) + .expect("Unexpected error creating BasicCoverageBlock graphviz DOT file"); + graphviz_writer + .write_graphviz(tcx, &mut file) + .expect("Unexpected error writing BasicCoverageBlock graphviz DOT file"); +} + +fn bcb_to_string_sections<'tcx>( + tcx: TyCtxt<'tcx>, + mir_body: &mir::Body<'tcx>, + debug_counters: &DebugCounters, + bcb_data: &BasicCoverageBlockData, + some_coverage_spans_with_counters: Option<&[(CoverageSpan, CoverageKind)]>, + some_dependency_counters: Option<&[CoverageKind]>, + some_intermediate_expressions: Option<&[CoverageKind]>, +) -> Vec<String> { + let len = bcb_data.basic_blocks.len(); + let mut sections = Vec::new(); + if let Some(collect_intermediate_expressions) = some_intermediate_expressions { + sections.push( + collect_intermediate_expressions + .iter() + .map(|expression| { + format!("Intermediate {}", debug_counters.format_counter(expression)) + }) + .join("\n"), + ); + } + if let Some(coverage_spans_with_counters) = some_coverage_spans_with_counters { + sections.push( + coverage_spans_with_counters + .iter() + .map(|(covspan, counter)| { + format!( + "{} at {}", + debug_counters.format_counter(counter), + covspan.format(tcx, mir_body) + ) + }) + .join("\n"), + ); + } + if 
let Some(dependency_counters) = some_dependency_counters { + sections.push(format!( + "Non-coverage counters:\n {}", + dependency_counters + .iter() + .map(|counter| debug_counters.format_counter(counter)) + .join(" \n"), + )); + } + if let Some(counter_kind) = &bcb_data.counter_kind { + sections.push(format!("{:?}", counter_kind)); + } + let non_term_blocks = bcb_data.basic_blocks[0..len - 1] + .iter() + .map(|&bb| format!("{:?}: {}", bb, term_type(&mir_body[bb].terminator().kind))) + .collect::<Vec<_>>(); + if non_term_blocks.len() > 0 { + sections.push(non_term_blocks.join("\n")); + } + sections.push(format!( + "{:?}: {}", + bcb_data.basic_blocks.last().unwrap(), + term_type(&bcb_data.terminator(mir_body).kind) + )); + sections +} + +/// Returns a simple string representation of a `TerminatorKind` variant, independent of any +/// values it might hold. +pub(super) fn term_type(kind: &TerminatorKind<'_>) -> &'static str { + match kind { + TerminatorKind::Goto { .. } => "Goto", + TerminatorKind::SwitchInt { .. } => "SwitchInt", + TerminatorKind::Resume => "Resume", + TerminatorKind::Abort => "Abort", + TerminatorKind::Return => "Return", + TerminatorKind::Unreachable => "Unreachable", + TerminatorKind::Drop { .. } => "Drop", + TerminatorKind::DropAndReplace { .. } => "DropAndReplace", + TerminatorKind::Call { .. } => "Call", + TerminatorKind::Assert { .. } => "Assert", + TerminatorKind::Yield { .. } => "Yield", + TerminatorKind::GeneratorDrop => "GeneratorDrop", + TerminatorKind::FalseEdge { .. } => "FalseEdge", + TerminatorKind::FalseUnwind { .. } => "FalseUnwind", + TerminatorKind::InlineAsm { .. } => "InlineAsm", + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/graph.rs b/compiler/rustc_mir_transform/src/coverage/graph.rs new file mode 100644 index 000000000..759ea7cd3 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/graph.rs @@ -0,0 +1,753 @@ +use super::Error; + +use itertools::Itertools; +use rustc_data_structures::fx::FxHashMap; +use rustc_data_structures::graph::dominators::{self, Dominators}; +use rustc_data_structures::graph::{self, GraphSuccessors, WithNumNodes, WithStartNode}; +use rustc_index::bit_set::BitSet; +use rustc_index::vec::IndexVec; +use rustc_middle::mir::coverage::*; +use rustc_middle::mir::{self, BasicBlock, BasicBlockData, Terminator, TerminatorKind}; + +use std::ops::{Index, IndexMut}; + +const ID_SEPARATOR: &str = ","; + +/// A coverage-specific simplification of the MIR control flow graph (CFG). The `CoverageGraph`s +/// nodes are `BasicCoverageBlock`s, which encompass one or more MIR `BasicBlock`s, plus a +/// `CoverageKind` counter (to be added by `CoverageCounters::make_bcb_counters`), and an optional +/// set of additional counters--if needed--to count incoming edges, if there are more than one. +/// (These "edge counters" are eventually converted into new MIR `BasicBlock`s.) +#[derive(Debug)] +pub(super) struct CoverageGraph { + bcbs: IndexVec<BasicCoverageBlock, BasicCoverageBlockData>, + bb_to_bcb: IndexVec<BasicBlock, Option<BasicCoverageBlock>>, + pub successors: IndexVec<BasicCoverageBlock, Vec<BasicCoverageBlock>>, + pub predecessors: IndexVec<BasicCoverageBlock, Vec<BasicCoverageBlock>>, + dominators: Option<Dominators<BasicCoverageBlock>>, +} + +impl CoverageGraph { + pub fn from_mir(mir_body: &mir::Body<'_>) -> Self { + let (bcbs, bb_to_bcb) = Self::compute_basic_coverage_blocks(mir_body); + + // Pre-transform MIR `BasicBlock` successors and predecessors into the BasicCoverageBlock + // equivalents. 
Note that since the BasicCoverageBlock graph has been fully simplified, each
+        // predecessor of a BCB leader_bb should be in a unique BCB. It is possible for a
+        // `SwitchInt` to have multiple targets to the same destination `BasicBlock`, so
+        // de-duplication is required. This is done without reordering the successors.
+
+        let bcbs_len = bcbs.len();
+        let mut seen = IndexVec::from_elem_n(false, bcbs_len);
+        let successors = IndexVec::from_fn_n(
+            |bcb| {
+                for b in seen.iter_mut() {
+                    *b = false;
+                }
+                let bcb_data = &bcbs[bcb];
+                let mut bcb_successors = Vec::new();
+                for successor in
+                    bcb_filtered_successors(&mir_body, &bcb_data.terminator(mir_body).kind)
+                        .filter_map(|successor_bb| bb_to_bcb[successor_bb])
+                {
+                    if !seen[successor] {
+                        seen[successor] = true;
+                        bcb_successors.push(successor);
+                    }
+                }
+                bcb_successors
+            },
+            bcbs.len(),
+        );
+
+        let mut predecessors = IndexVec::from_elem_n(Vec::new(), bcbs.len());
+        for (bcb, bcb_successors) in successors.iter_enumerated() {
+            for &successor in bcb_successors {
+                predecessors[successor].push(bcb);
+            }
+        }
+
+        let mut basic_coverage_blocks =
+            Self { bcbs, bb_to_bcb, successors, predecessors, dominators: None };
+        let dominators = dominators::dominators(&basic_coverage_blocks);
+        basic_coverage_blocks.dominators = Some(dominators);
+        basic_coverage_blocks
+    }
+
+    fn compute_basic_coverage_blocks(
+        mir_body: &mir::Body<'_>,
+    ) -> (
+        IndexVec<BasicCoverageBlock, BasicCoverageBlockData>,
+        IndexVec<BasicBlock, Option<BasicCoverageBlock>>,
+    ) {
+        let num_basic_blocks = mir_body.basic_blocks.len();
+        let mut bcbs = IndexVec::with_capacity(num_basic_blocks);
+        let mut bb_to_bcb = IndexVec::from_elem_n(None, num_basic_blocks);
+
+        // Walk the MIR CFG using a Preorder traversal, which starts from `START_BLOCK` and follows
+        // each block terminator's `successors()`. Coverage spans must map to actual source code,
+        // so compiler generated blocks and paths can be ignored. To that end, the CFG traversal
+        // intentionally omits unwind paths.
+        // FIXME(#78544): MIR InstrumentCoverage: Improve coverage of `#[should_panic]` tests and
+        // `catch_unwind()` handlers.
+        let mir_cfg_without_unwind = ShortCircuitPreorder::new(&mir_body, bcb_filtered_successors);
+
+        let mut basic_blocks = Vec::new();
+        for (bb, data) in mir_cfg_without_unwind {
+            if let Some(last) = basic_blocks.last() {
+                let predecessors = &mir_body.basic_blocks.predecessors()[bb];
+                if predecessors.len() > 1 || !predecessors.contains(last) {
+                    // The `bb` has more than one _incoming_ edge (or its single incoming edge is
+                    // not from the immediately preceding block), so it should start its own
+                    // `BasicCoverageBlockData`. (Note, the `basic_blocks` vector does not yet
+                    // include `bb`; it contains a sequence of one or more sequential basic_blocks
+                    // with no intermediate branches in or out. Save these as a new
+                    // `BasicCoverageBlockData` before starting the new one.)
+                    Self::add_basic_coverage_block(
+                        &mut bcbs,
+                        &mut bb_to_bcb,
+                        basic_blocks.split_off(0),
+                    );
+                    debug!(
+                        "  because {}",
+                        if predecessors.len() > 1 {
+                            "predecessors.len() > 1".to_owned()
+                        } else {
+                            format!("bb {} is not in predecessors: {:?}", bb.index(), predecessors)
+                        }
+                    );
+                }
+            }
+            basic_blocks.push(bb);
+
+            let term = data.terminator();
+
+            match term.kind {
+                TerminatorKind::Return { .. }
+                | TerminatorKind::Abort
+                | TerminatorKind::Yield { .. }
+                | TerminatorKind::SwitchInt { .. } => {
+                    // The `bb` has more than one _outgoing_ edge, or exits the function. Save the
+                    // current sequence of `basic_blocks` gathered to this point, as a new
+                    // `BasicCoverageBlockData`.
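+                    // (Illustrative sketch, editorial; block numbers are hypothetical. Given a
+                    // MIR CFG like
+                    //
+                    //     bb0: Goto -> bb1
+                    //     bb1: Call -> bb2 (unwind edge ignored)
+                    //     bb2: SwitchInt -> bb3 | bb4
+                    //
+                    // bb0, bb1, and bb2 collapse into one BCB, because only bb2 branches and the
+                    // Call's unwind path is not followed, while bb3 and bb4 each begin their own
+                    // BCB.)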
+ Self::add_basic_coverage_block( + &mut bcbs, + &mut bb_to_bcb, + basic_blocks.split_off(0), + ); + debug!(" because term.kind = {:?}", term.kind); + // Note that this condition is based on `TerminatorKind`, even though it + // theoretically boils down to `successors().len() != 1`; that is, either zero + // (e.g., `Return`, `Abort`) or multiple successors (e.g., `SwitchInt`), but + // since the BCB CFG ignores things like unwind branches (which exist in the + // `Terminator`s `successors()` list) checking the number of successors won't + // work. + } + + // The following `TerminatorKind`s are either not expected outside an unwind branch, + // or they should not (under normal circumstances) branch. Coverage graphs are + // simplified by assuring coverage results are accurate for program executions that + // don't panic. + // + // Programs that panic and unwind may record slightly inaccurate coverage results + // for a coverage region containing the `Terminator` that began the panic. This + // is as intended. (See Issue #78544 for a possible future option to support + // coverage in test programs that panic.) + TerminatorKind::Goto { .. } + | TerminatorKind::Resume + | TerminatorKind::Unreachable + | TerminatorKind::Drop { .. } + | TerminatorKind::DropAndReplace { .. } + | TerminatorKind::Call { .. } + | TerminatorKind::GeneratorDrop + | TerminatorKind::Assert { .. } + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } + | TerminatorKind::InlineAsm { .. } => {} + } + } + + if !basic_blocks.is_empty() { + // process any remaining basic_blocks into a final `BasicCoverageBlockData` + Self::add_basic_coverage_block(&mut bcbs, &mut bb_to_bcb, basic_blocks.split_off(0)); + debug!(" because the end of the MIR CFG was reached while traversing"); + } + + (bcbs, bb_to_bcb) + } + + fn add_basic_coverage_block( + bcbs: &mut IndexVec<BasicCoverageBlock, BasicCoverageBlockData>, + bb_to_bcb: &mut IndexVec<BasicBlock, Option<BasicCoverageBlock>>, + basic_blocks: Vec<BasicBlock>, + ) { + let bcb = BasicCoverageBlock::from_usize(bcbs.len()); + for &bb in basic_blocks.iter() { + bb_to_bcb[bb] = Some(bcb); + } + let bcb_data = BasicCoverageBlockData::from(basic_blocks); + debug!("adding bcb{}: {:?}", bcb.index(), bcb_data); + bcbs.push(bcb_data); + } + + #[inline(always)] + pub fn iter_enumerated( + &self, + ) -> impl Iterator<Item = (BasicCoverageBlock, &BasicCoverageBlockData)> { + self.bcbs.iter_enumerated() + } + + #[inline(always)] + pub fn iter_enumerated_mut( + &mut self, + ) -> impl Iterator<Item = (BasicCoverageBlock, &mut BasicCoverageBlockData)> { + self.bcbs.iter_enumerated_mut() + } + + #[inline(always)] + pub fn bcb_from_bb(&self, bb: BasicBlock) -> Option<BasicCoverageBlock> { + if bb.index() < self.bb_to_bcb.len() { self.bb_to_bcb[bb] } else { None } + } + + #[inline(always)] + pub fn is_dominated_by(&self, node: BasicCoverageBlock, dom: BasicCoverageBlock) -> bool { + self.dominators.as_ref().unwrap().is_dominated_by(node, dom) + } + + #[inline(always)] + pub fn dominators(&self) -> &Dominators<BasicCoverageBlock> { + self.dominators.as_ref().unwrap() + } +} + +impl Index<BasicCoverageBlock> for CoverageGraph { + type Output = BasicCoverageBlockData; + + #[inline] + fn index(&self, index: BasicCoverageBlock) -> &BasicCoverageBlockData { + &self.bcbs[index] + } +} + +impl IndexMut<BasicCoverageBlock> for CoverageGraph { + #[inline] + fn index_mut(&mut self, index: BasicCoverageBlock) -> &mut BasicCoverageBlockData { + &mut self.bcbs[index] + } +} + +impl 
graph::DirectedGraph for CoverageGraph {
+ type Node = BasicCoverageBlock;
+}
+
+impl graph::WithNumNodes for CoverageGraph {
+ #[inline]
+ fn num_nodes(&self) -> usize {
+ self.bcbs.len()
+ }
+}
+
+impl graph::WithStartNode for CoverageGraph {
+ #[inline]
+ fn start_node(&self) -> Self::Node {
+ self.bcb_from_bb(mir::START_BLOCK)
+ .expect("mir::START_BLOCK should be in a BasicCoverageBlock")
+ }
+}
+
+type BcbSuccessors<'graph> = std::slice::Iter<'graph, BasicCoverageBlock>;
+
+impl<'graph> graph::GraphSuccessors<'graph> for CoverageGraph {
+ type Item = BasicCoverageBlock;
+ type Iter = std::iter::Cloned<BcbSuccessors<'graph>>;
+}
+
+impl graph::WithSuccessors for CoverageGraph {
+ #[inline]
+ fn successors(&self, node: Self::Node) -> <Self as GraphSuccessors<'_>>::Iter {
+ self.successors[node].iter().cloned()
+ }
+}
+
+impl<'graph> graph::GraphPredecessors<'graph> for CoverageGraph {
+ type Item = BasicCoverageBlock;
+ type Iter = std::iter::Copied<std::slice::Iter<'graph, BasicCoverageBlock>>;
+}
+
+impl graph::WithPredecessors for CoverageGraph {
+ #[inline]
+ fn predecessors(&self, node: Self::Node) -> <Self as graph::GraphPredecessors<'_>>::Iter {
+ self.predecessors[node].iter().copied()
+ }
+}
+
+rustc_index::newtype_index! {
+ /// A node in the control-flow graph of CoverageGraph.
+ pub(super) struct BasicCoverageBlock {
+ DEBUG_FORMAT = "bcb{}",
+ const START_BCB = 0,
+ }
+}
+
+/// `BasicCoverageBlockData` holds the data indexed by a `BasicCoverageBlock`.
+///
+/// A `BasicCoverageBlock` (BCB) represents the maximal-length sequence of MIR `BasicBlock`s without
+/// conditional branches; together, the BCBs form a new, simplified, coverage-specific Control Flow
+/// Graph, without altering the original MIR CFG.
+///
+/// Note that running the MIR `SimplifyCfg` transform is not sufficient (and therefore not
+/// necessary). The BCB-based CFG is a more aggressive simplification. For example:
+///
+/// * The BCB CFG ignores (trims) branches not relevant to coverage, such as unwind-related code,
+/// that is injected by the Rust compiler but has no physical source code to count. This also
+/// means a BasicBlock with a `Call` terminator can be merged into its primary successor target
+/// block, in the same BCB. (But, note: Issue #78544: "MIR InstrumentCoverage: Improve coverage
+/// of `#[should_panic]` tests and `catch_unwind()` handlers")
+/// * Some BasicBlock terminators support Rust-specific concerns--like borrow-checking--that are
+/// not relevant to coverage analysis. `FalseUnwind`, for example, can be treated the same as
+/// a `Goto`, and merged with its successor into the same BCB.
+///
+/// Each BCB with at least one computed `CoverageSpan` will have no more than one `Counter`.
+/// In some cases, a BCB's execution count can be computed by `Expression`. Additional
+/// disjoint `CoverageSpan`s in a BCB can also be counted by `Expression` (by adding `ZERO`
+/// to the BCB's primary counter or expression).
+///
+/// The BCB CFG is critical to simplifying the coverage analysis by ensuring graph path-based
+/// queries (`is_dominated_by()`, `predecessors`, `successors`, etc.) have branch (control flow)
+/// significance.
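To make the grouping rule concrete, here is a minimal standalone sketch of the same idea, reduced to plain `usize` block ids and adjacency lists (`group_into_bcbs` is a hypothetical helper, not the pass's implementation; it walks blocks in index order, where the real pass uses the filtered preorder traversal and `TerminatorKind`):

```rust
// A new chunk starts at a join point (a block whose flow does not come
// straight from the previous block, or that has several predecessors), and
// the current chunk ends after an exit (no successors) or a branch (many).
fn group_into_bcbs(succs: &[Vec<usize>]) -> Vec<Vec<usize>> {
    let n = succs.len();
    let mut preds = vec![Vec::new(); n];
    for (b, ss) in succs.iter().enumerate() {
        for &s in ss {
            preds[s].push(b);
        }
    }
    let mut bcbs: Vec<Vec<usize>> = Vec::new();
    let mut current: Vec<usize> = Vec::new();
    for b in 0..n {
        if let Some(&last) = current.last() {
            // Mirrors `predecessors.len() > 1 || !predecessors.contains(last)`.
            if preds[b].len() > 1 || !preds[b].contains(&last) {
                bcbs.push(std::mem::take(&mut current));
            }
        }
        current.push(b);
        if succs[b].len() != 1 {
            // Zero successors (exit) or many (branch) end the current BCB.
            bcbs.push(std::mem::take(&mut current));
        }
    }
    if !current.is_empty() {
        bcbs.push(current);
    }
    bcbs
}

fn main() {
    // A diamond: 0 -> 1 -> {2, 3} -> 4. Blocks 0 and 1 merge into one BCB.
    let succs = vec![vec![1], vec![2, 3], vec![4], vec![4], vec![]];
    assert_eq!(group_into_bcbs(&succs), [vec![0, 1], vec![2], vec![3], vec![4]]);
}
```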
+#[derive(Debug, Clone)] +pub(super) struct BasicCoverageBlockData { + pub basic_blocks: Vec<BasicBlock>, + pub counter_kind: Option<CoverageKind>, + edge_from_bcbs: Option<FxHashMap<BasicCoverageBlock, CoverageKind>>, +} + +impl BasicCoverageBlockData { + pub fn from(basic_blocks: Vec<BasicBlock>) -> Self { + assert!(basic_blocks.len() > 0); + Self { basic_blocks, counter_kind: None, edge_from_bcbs: None } + } + + #[inline(always)] + pub fn leader_bb(&self) -> BasicBlock { + self.basic_blocks[0] + } + + #[inline(always)] + pub fn last_bb(&self) -> BasicBlock { + *self.basic_blocks.last().unwrap() + } + + #[inline(always)] + pub fn terminator<'a, 'tcx>(&self, mir_body: &'a mir::Body<'tcx>) -> &'a Terminator<'tcx> { + &mir_body[self.last_bb()].terminator() + } + + pub fn set_counter( + &mut self, + counter_kind: CoverageKind, + ) -> Result<ExpressionOperandId, Error> { + debug_assert!( + // If the BCB has an edge counter (to be injected into a new `BasicBlock`), it can also + // have an expression (to be injected into an existing `BasicBlock` represented by this + // `BasicCoverageBlock`). + self.edge_from_bcbs.is_none() || counter_kind.is_expression(), + "attempt to add a `Counter` to a BCB target with existing incoming edge counters" + ); + let operand = counter_kind.as_operand_id(); + if let Some(replaced) = self.counter_kind.replace(counter_kind) { + Error::from_string(format!( + "attempt to set a BasicCoverageBlock coverage counter more than once; \ + {:?} already had counter {:?}", + self, replaced, + )) + } else { + Ok(operand) + } + } + + #[inline(always)] + pub fn counter(&self) -> Option<&CoverageKind> { + self.counter_kind.as_ref() + } + + #[inline(always)] + pub fn take_counter(&mut self) -> Option<CoverageKind> { + self.counter_kind.take() + } + + pub fn set_edge_counter_from( + &mut self, + from_bcb: BasicCoverageBlock, + counter_kind: CoverageKind, + ) -> Result<ExpressionOperandId, Error> { + if level_enabled!(tracing::Level::DEBUG) { + // If the BCB has an edge counter (to be injected into a new `BasicBlock`), it can also + // have an expression (to be injected into an existing `BasicBlock` represented by this + // `BasicCoverageBlock`). 
+ if !self.counter_kind.as_ref().map_or(true, |c| c.is_expression()) { + return Error::from_string(format!( + "attempt to add an incoming edge counter from {:?} when the target BCB already \ + has a `Counter`", + from_bcb + )); + } + } + let operand = counter_kind.as_operand_id(); + if let Some(replaced) = + self.edge_from_bcbs.get_or_insert_default().insert(from_bcb, counter_kind) + { + Error::from_string(format!( + "attempt to set an edge counter more than once; from_bcb: \ + {:?} already had counter {:?}", + from_bcb, replaced, + )) + } else { + Ok(operand) + } + } + + #[inline] + pub fn edge_counter_from(&self, from_bcb: BasicCoverageBlock) -> Option<&CoverageKind> { + if let Some(edge_from_bcbs) = &self.edge_from_bcbs { + edge_from_bcbs.get(&from_bcb) + } else { + None + } + } + + #[inline] + pub fn take_edge_counters( + &mut self, + ) -> Option<impl Iterator<Item = (BasicCoverageBlock, CoverageKind)>> { + self.edge_from_bcbs.take().map(|m| m.into_iter()) + } + + pub fn id(&self) -> String { + format!("@{}", self.basic_blocks.iter().map(|bb| bb.index().to_string()).join(ID_SEPARATOR)) + } +} + +/// Represents a successor from a branching BasicCoverageBlock (such as the arms of a `SwitchInt`) +/// as either the successor BCB itself, if it has only one incoming edge, or the successor _plus_ +/// the specific branching BCB, representing the edge between the two. The latter case +/// distinguishes this incoming edge from other incoming edges to the same `target_bcb`. +#[derive(Clone, Copy, PartialEq, Eq)] +pub(super) struct BcbBranch { + pub edge_from_bcb: Option<BasicCoverageBlock>, + pub target_bcb: BasicCoverageBlock, +} + +impl BcbBranch { + pub fn from_to( + from_bcb: BasicCoverageBlock, + to_bcb: BasicCoverageBlock, + basic_coverage_blocks: &CoverageGraph, + ) -> Self { + let edge_from_bcb = if basic_coverage_blocks.predecessors[to_bcb].len() > 1 { + Some(from_bcb) + } else { + None + }; + Self { edge_from_bcb, target_bcb: to_bcb } + } + + pub fn counter<'a>( + &self, + basic_coverage_blocks: &'a CoverageGraph, + ) -> Option<&'a CoverageKind> { + if let Some(from_bcb) = self.edge_from_bcb { + basic_coverage_blocks[self.target_bcb].edge_counter_from(from_bcb) + } else { + basic_coverage_blocks[self.target_bcb].counter() + } + } + + pub fn is_only_path_to_target(&self) -> bool { + self.edge_from_bcb.is_none() + } +} + +impl std::fmt::Debug for BcbBranch { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(from_bcb) = self.edge_from_bcb { + write!(fmt, "{:?}->{:?}", from_bcb, self.target_bcb) + } else { + write!(fmt, "{:?}", self.target_bcb) + } + } +} + +// Returns the `Terminator`s non-unwind successors. +// FIXME(#78544): MIR InstrumentCoverage: Improve coverage of `#[should_panic]` tests and +// `catch_unwind()` handlers. +fn bcb_filtered_successors<'a, 'tcx>( + body: &'a mir::Body<'tcx>, + term_kind: &'a TerminatorKind<'tcx>, +) -> Box<dyn Iterator<Item = BasicBlock> + 'a> { + Box::new( + match &term_kind { + // SwitchInt successors are never unwind, and all of them should be traversed. + TerminatorKind::SwitchInt { ref targets, .. } => { + None.into_iter().chain(targets.all_targets().into_iter().copied()) + } + // For all other kinds, return only the first successor, if any, and ignore unwinds. + // NOTE: `chain(&[])` is required to coerce the `option::iter` (from + // `next().into_iter()`) into the `mir::Successors` aliased type. 
+ _ => term_kind.successors().next().into_iter().chain((&[]).into_iter().copied()),
+ }
+ .filter(move |&successor| body[successor].terminator().kind != TerminatorKind::Unreachable),
+ )
+}
+
+/// Maintains separate worklists for each loop in the BasicCoverageBlock CFG, plus one for the
+/// part of the CoverageGraph outside all loops. This supports traversing the BCB CFG in a way
+/// that ensures a loop is completely traversed before processing blocks after the end of the
+/// loop.
+#[derive(Debug)]
+pub(super) struct TraversalContext {
+ /// From one or more backedges returning to a loop header.
+ pub loop_backedges: Option<(Vec<BasicCoverageBlock>, BasicCoverageBlock)>,
+
+ /// The worklist of BCBs to be traversed, in the loop with the given loop
+ /// backedges, such that the loop is the innermost loop containing these
+ /// BCBs.
+ pub worklist: Vec<BasicCoverageBlock>,
+}
+
+pub(super) struct TraverseCoverageGraphWithLoops {
+ pub backedges: IndexVec<BasicCoverageBlock, Vec<BasicCoverageBlock>>,
+ pub context_stack: Vec<TraversalContext>,
+ visited: BitSet<BasicCoverageBlock>,
+}
+
+impl TraverseCoverageGraphWithLoops {
+ pub fn new(basic_coverage_blocks: &CoverageGraph) -> Self {
+ let start_bcb = basic_coverage_blocks.start_node();
+ let backedges = find_loop_backedges(basic_coverage_blocks);
+ let context_stack =
+ vec![TraversalContext { loop_backedges: None, worklist: vec![start_bcb] }];
+ // `context_stack` starts with a `TraversalContext` for the main function context (beginning
+ // with the `start` BasicCoverageBlock of the function). New worklists are pushed to the top
+ // of the stack as loops are entered, and popped off of the stack when a loop's worklist is
+ // exhausted.
+ let visited = BitSet::new_empty(basic_coverage_blocks.num_nodes());
+ Self { backedges, context_stack, visited }
+ }
+
+ pub fn next(&mut self, basic_coverage_blocks: &CoverageGraph) -> Option<BasicCoverageBlock> {
+ debug!(
+ "TraverseCoverageGraphWithLoops::next - context_stack: {:?}",
+ self.context_stack.iter().rev().collect::<Vec<_>>()
+ );
+ while let Some(next_bcb) = {
+ // Strip contexts with empty worklists from the top of the stack
+ while self.context_stack.last().map_or(false, |context| context.worklist.is_empty()) {
+ self.context_stack.pop();
+ }
+ // Pop the next bcb off of the current context_stack. If none, all BCBs were visited.
+ self.context_stack.last_mut().map_or(None, |context| context.worklist.pop())
+ } {
+ if !self.visited.insert(next_bcb) {
+ debug!("Already visited: {:?}", next_bcb);
+ continue;
+ }
+ debug!("Visiting {:?}", next_bcb);
+ if self.backedges[next_bcb].len() > 0 {
+ debug!("{:?} is a loop header! Start a new TraversalContext...", next_bcb);
+ self.context_stack.push(TraversalContext {
+ loop_backedges: Some((self.backedges[next_bcb].clone(), next_bcb)),
+ worklist: Vec::new(),
+ });
+ }
+ self.extend_worklist(basic_coverage_blocks, next_bcb);
+ return Some(next_bcb);
+ }
+ None
+ }
+
+ pub fn extend_worklist(
+ &mut self,
+ basic_coverage_blocks: &CoverageGraph,
+ bcb: BasicCoverageBlock,
+ ) {
+ let successors = &basic_coverage_blocks.successors[bcb];
+ debug!("{:?} has {} successors:", bcb, successors.len());
+ for &successor in successors {
+ if successor == bcb {
+ debug!(
+ "{:?} has itself as its own successor. (Note, the compiled code will \
+ generate an infinite loop.)",
+ bcb
+ );
+ // Don't re-add this successor to the worklist. We are already processing it.
+ break; + } + for context in self.context_stack.iter_mut().rev() { + // Add successors of the current BCB to the appropriate context. Successors that + // stay within a loop are added to the BCBs context worklist. Successors that + // exit the loop (they are not dominated by the loop header) must be reachable + // from other BCBs outside the loop, and they will be added to a different + // worklist. + // + // Branching blocks (with more than one successor) must be processed before + // blocks with only one successor, to prevent unnecessarily complicating + // `Expression`s by creating a Counter in a `BasicCoverageBlock` that the + // branching block would have given an `Expression` (or vice versa). + let (some_successor_to_add, some_loop_header) = + if let Some((_, loop_header)) = context.loop_backedges { + if basic_coverage_blocks.is_dominated_by(successor, loop_header) { + (Some(successor), Some(loop_header)) + } else { + (None, None) + } + } else { + (Some(successor), None) + }; + if let Some(successor_to_add) = some_successor_to_add { + if basic_coverage_blocks.successors[successor_to_add].len() > 1 { + debug!( + "{:?} successor is branching. Prioritize it at the beginning of \ + the {}", + successor_to_add, + if let Some(loop_header) = some_loop_header { + format!("worklist for the loop headed by {:?}", loop_header) + } else { + String::from("non-loop worklist") + }, + ); + context.worklist.insert(0, successor_to_add); + } else { + debug!( + "{:?} successor is non-branching. Defer it to the end of the {}", + successor_to_add, + if let Some(loop_header) = some_loop_header { + format!("worklist for the loop headed by {:?}", loop_header) + } else { + String::from("non-loop worklist") + }, + ); + context.worklist.push(successor_to_add); + } + break; + } + } + } + } + + pub fn is_complete(&self) -> bool { + self.visited.count() == self.visited.domain_size() + } + + pub fn unvisited(&self) -> Vec<BasicCoverageBlock> { + let mut unvisited_set: BitSet<BasicCoverageBlock> = + BitSet::new_filled(self.visited.domain_size()); + unvisited_set.subtract(&self.visited); + unvisited_set.iter().collect::<Vec<_>>() + } +} + +pub(super) fn find_loop_backedges( + basic_coverage_blocks: &CoverageGraph, +) -> IndexVec<BasicCoverageBlock, Vec<BasicCoverageBlock>> { + let num_bcbs = basic_coverage_blocks.num_nodes(); + let mut backedges = IndexVec::from_elem_n(Vec::<BasicCoverageBlock>::new(), num_bcbs); + + // Identify loops by their backedges. + // + // The computational complexity is bounded by: n(s) x d where `n` is the number of + // `BasicCoverageBlock` nodes (the simplified/reduced representation of the CFG derived from the + // MIR); `s` is the average number of successors per node (which is most likely less than 2, and + // independent of the size of the function, so it can be treated as a constant); + // and `d` is the average number of dominators per node. + // + // The average number of dominators depends on the size and complexity of the function, and + // nodes near the start of the function's control flow graph typically have less dominators + // than nodes near the end of the CFG. Without doing a detailed mathematical analysis, I + // think the resulting complexity has the characteristics of O(n log n). 
+ //
+ // The overall complexity appears to be comparable to many other MIR transform algorithms, and I
+ // don't expect that this function is creating a performance hot spot, but if this becomes an
+ // issue, there may be ways to optimize the `is_dominated_by` algorithm (as indicated by an
+ // existing `FIXME` comment in that code), or possibly ways to optimize its usage here, perhaps
+ // by keeping track of results for visited `BasicCoverageBlock`s if they can be used to short
+ // circuit downstream `is_dominated_by` checks.
+ //
+ // For now, that kind of optimization seems unnecessarily complicated.
+ for (bcb, _) in basic_coverage_blocks.iter_enumerated() {
+ for &successor in &basic_coverage_blocks.successors[bcb] {
+ if basic_coverage_blocks.is_dominated_by(bcb, successor) {
+ let loop_header = successor;
+ let backedge_from_bcb = bcb;
+ debug!(
+ "Found BCB backedge: {:?} -> loop_header: {:?}",
+ backedge_from_bcb, loop_header
+ );
+ backedges[loop_header].push(backedge_from_bcb);
+ }
+ }
+ }
+ backedges
+}
+
+pub struct ShortCircuitPreorder<
+ 'a,
+ 'tcx,
+ F: Fn(&'a mir::Body<'tcx>, &'a TerminatorKind<'tcx>) -> Box<dyn Iterator<Item = BasicBlock> + 'a>,
+> {
+ body: &'a mir::Body<'tcx>,
+ visited: BitSet<BasicBlock>,
+ worklist: Vec<BasicBlock>,
+ filtered_successors: F,
+}
+
+impl<
+ 'a,
+ 'tcx,
+ F: Fn(&'a mir::Body<'tcx>, &'a TerminatorKind<'tcx>) -> Box<dyn Iterator<Item = BasicBlock> + 'a>,
+> ShortCircuitPreorder<'a, 'tcx, F>
+{
+ pub fn new(
+ body: &'a mir::Body<'tcx>,
+ filtered_successors: F,
+ ) -> ShortCircuitPreorder<'a, 'tcx, F> {
+ let worklist = vec![mir::START_BLOCK];
+
+ ShortCircuitPreorder {
+ body,
+ visited: BitSet::new_empty(body.basic_blocks().len()),
+ worklist,
+ filtered_successors,
+ }
+ }
+}
+
+impl<
+ 'a,
+ 'tcx,
+ F: Fn(&'a mir::Body<'tcx>, &'a TerminatorKind<'tcx>) -> Box<dyn Iterator<Item = BasicBlock> + 'a>,
+> Iterator for ShortCircuitPreorder<'a, 'tcx, F>
+{
+ type Item = (BasicBlock, &'a BasicBlockData<'tcx>);
+
+ fn next(&mut self) -> Option<(BasicBlock, &'a BasicBlockData<'tcx>)> {
+ while let Some(idx) = self.worklist.pop() {
+ if !self.visited.insert(idx) {
+ continue;
+ }
+
+ let data = &self.body[idx];
+
+ if let Some(ref term) = data.terminator {
+ self.worklist.extend((self.filtered_successors)(&self.body, &term.kind));
+ }
+
+ return Some((idx, data));
+ }
+
+ None
+ }
+
+ fn size_hint(&self) -> (usize, Option<usize>) {
+ let size = self.body.basic_blocks().len() - self.visited.count();
+ (size, Some(size))
+ }
+}
diff --git a/compiler/rustc_mir_transform/src/coverage/mod.rs b/compiler/rustc_mir_transform/src/coverage/mod.rs
new file mode 100644
index 000000000..2619626a5
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/coverage/mod.rs
@@ -0,0 +1,580 @@
+pub mod query;
+
+mod counters;
+mod debug;
+mod graph;
+mod spans;
+
+#[cfg(test)]
+mod tests;
+
+use counters::CoverageCounters;
+use graph::{BasicCoverageBlock, BasicCoverageBlockData, CoverageGraph};
+use spans::{CoverageSpan, CoverageSpans};
+
+use crate::MirPass;
+
+use rustc_data_structures::graph::WithNumNodes;
+use rustc_data_structures::sync::Lrc;
+use rustc_index::vec::IndexVec;
+use rustc_middle::hir;
+use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags;
+use rustc_middle::mir::coverage::*;
+use rustc_middle::mir::dump_enabled;
+use rustc_middle::mir::{
+ self, BasicBlock, BasicBlockData, Coverage, SourceInfo, Statement, StatementKind, Terminator,
+ TerminatorKind,
+};
+use rustc_middle::ty::TyCtxt;
+use rustc_span::def_id::DefId;
+use 
rustc_span::source_map::SourceMap;
+use rustc_span::{CharPos, ExpnKind, Pos, SourceFile, Span, Symbol};
+
+/// A simple error message wrapper for `coverage::Error`s.
+#[derive(Debug)]
+struct Error {
+ message: String,
+}
+
+impl Error {
+ pub fn from_string<T>(message: String) -> Result<T, Error> {
+ Err(Self { message })
+ }
+}
+
+/// Inserts `StatementKind::Coverage` statements that instrument the binary with injected
+/// counters (via the intrinsic `llvm.instrprof.increment`) and/or inject metadata used during
+/// codegen to construct the coverage map.
+pub struct InstrumentCoverage;
+
+impl<'tcx> MirPass<'tcx> for InstrumentCoverage {
+ fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+ sess.instrument_coverage()
+ }
+
+ fn run_pass(&self, tcx: TyCtxt<'tcx>, mir_body: &mut mir::Body<'tcx>) {
+ let mir_source = mir_body.source;
+
+ // If the InstrumentCoverage pass is called on promoted MIRs, skip them.
+ // See: https://github.com/rust-lang/rust/pull/73011#discussion_r438317601
+ if mir_source.promoted.is_some() {
+ trace!(
+ "InstrumentCoverage skipped for {:?} (already promoted for Miri evaluation)",
+ mir_source.def_id()
+ );
+ return;
+ }
+
+ let is_fn_like =
+ tcx.hir().get_by_def_id(mir_source.def_id().expect_local()).fn_kind().is_some();
+
+ // Only instrument functions, methods, and closures (not constants since they are evaluated
+ // at compile time by Miri).
+ // FIXME(#73156): Handle source code coverage in const eval, but note, if and when const
+ // expressions get coverage spans, we will probably have to "carve out" space for const
+ // expressions from coverage spans in enclosing MIRs, like we do for closures. (That might
+ // be tricky if const expressions have no corresponding statements in the enclosing MIR.
+ // Closures are carved out by their initial `Assign` statement.)
+ if !is_fn_like { + trace!("InstrumentCoverage skipped for {:?} (not an fn-like)", mir_source.def_id()); + return; + } + + match mir_body.basic_blocks()[mir::START_BLOCK].terminator().kind { + TerminatorKind::Unreachable => { + trace!("InstrumentCoverage skipped for unreachable `START_BLOCK`"); + return; + } + _ => {} + } + + let codegen_fn_attrs = tcx.codegen_fn_attrs(mir_source.def_id()); + if codegen_fn_attrs.flags.contains(CodegenFnAttrFlags::NO_COVERAGE) { + return; + } + + trace!("InstrumentCoverage starting for {:?}", mir_source.def_id()); + Instrumentor::new(&self.name(), tcx, mir_body).inject_counters(); + trace!("InstrumentCoverage done for {:?}", mir_source.def_id()); + } +} + +struct Instrumentor<'a, 'tcx> { + pass_name: &'a str, + tcx: TyCtxt<'tcx>, + mir_body: &'a mut mir::Body<'tcx>, + source_file: Lrc<SourceFile>, + fn_sig_span: Span, + body_span: Span, + basic_coverage_blocks: CoverageGraph, + coverage_counters: CoverageCounters, +} + +impl<'a, 'tcx> Instrumentor<'a, 'tcx> { + fn new(pass_name: &'a str, tcx: TyCtxt<'tcx>, mir_body: &'a mut mir::Body<'tcx>) -> Self { + let source_map = tcx.sess.source_map(); + let def_id = mir_body.source.def_id(); + let (some_fn_sig, hir_body) = fn_sig_and_body(tcx, def_id); + + let body_span = get_body_span(tcx, hir_body, mir_body); + + let source_file = source_map.lookup_source_file(body_span.lo()); + let fn_sig_span = match some_fn_sig.filter(|fn_sig| { + fn_sig.span.eq_ctxt(body_span) + && Lrc::ptr_eq(&source_file, &source_map.lookup_source_file(fn_sig.span.lo())) + }) { + Some(fn_sig) => fn_sig.span.with_hi(body_span.lo()), + None => body_span.shrink_to_lo(), + }; + + debug!( + "instrumenting {}: {:?}, fn sig span: {:?}, body span: {:?}", + if tcx.is_closure(def_id) { "closure" } else { "function" }, + def_id, + fn_sig_span, + body_span + ); + + let function_source_hash = hash_mir_source(tcx, hir_body); + let basic_coverage_blocks = CoverageGraph::from_mir(mir_body); + Self { + pass_name, + tcx, + mir_body, + source_file, + fn_sig_span, + body_span, + basic_coverage_blocks, + coverage_counters: CoverageCounters::new(function_source_hash), + } + } + + fn inject_counters(&'a mut self) { + let tcx = self.tcx; + let mir_source = self.mir_body.source; + let def_id = mir_source.def_id(); + let fn_sig_span = self.fn_sig_span; + let body_span = self.body_span; + + let mut graphviz_data = debug::GraphvizData::new(); + let mut debug_used_expressions = debug::UsedExpressions::new(); + + let dump_mir = dump_enabled(tcx, self.pass_name, def_id); + let dump_graphviz = dump_mir && tcx.sess.opts.unstable_opts.dump_mir_graphviz; + let dump_spanview = dump_mir && tcx.sess.opts.unstable_opts.dump_mir_spanview.is_some(); + + if dump_graphviz { + graphviz_data.enable(); + self.coverage_counters.enable_debug(); + } + + if dump_graphviz || level_enabled!(tracing::Level::DEBUG) { + debug_used_expressions.enable(); + } + + //////////////////////////////////////////////////// + // Compute `CoverageSpan`s from the `CoverageGraph`. + let coverage_spans = CoverageSpans::generate_coverage_spans( + &self.mir_body, + fn_sig_span, + body_span, + &self.basic_coverage_blocks, + ); + + if dump_spanview { + debug::dump_coverage_spanview( + tcx, + self.mir_body, + &self.basic_coverage_blocks, + self.pass_name, + body_span, + &coverage_spans, + ); + } + + //////////////////////////////////////////////////// + // Create an optimized mix of `Counter`s and `Expression`s for the `CoverageGraph`. 
Ensure
+ // every `CoverageSpan` has a `Counter` or `Expression` assigned to its `BasicCoverageBlock`
+ // and all `Expression` dependencies (operands) are also generated, for any other
+ // `BasicCoverageBlock`s not already associated with a `CoverageSpan`.
+ //
+ // Intermediate expressions (used to compute other `Expression` values), which have no
+ // direct association with any `BasicCoverageBlock`, are returned in the method `Result`.
+ let intermediate_expressions_or_error = self
+ .coverage_counters
+ .make_bcb_counters(&mut self.basic_coverage_blocks, &coverage_spans);
+
+ let (result, intermediate_expressions) = match intermediate_expressions_or_error {
+ Ok(intermediate_expressions) => {
+ // If debugging, add any intermediate expressions (which are not associated with any
+ // BCB) to the `debug_used_expressions` map.
+ if debug_used_expressions.is_enabled() {
+ for intermediate_expression in &intermediate_expressions {
+ debug_used_expressions.add_expression_operands(intermediate_expression);
+ }
+ }
+
+ ////////////////////////////////////////////////////
+ // Remove the counter or edge counter from each `CoverageSpan`'s associated
+ // `BasicCoverageBlock`, and inject a `Coverage` statement into the MIR.
+ //
+ // `Coverage` statements injected from `CoverageSpan`s will include the code regions
+ // (source code start and end positions) to be counted by the associated counter.
+ //
+ // These `CoverageSpan`-associated counters are removed from their associated
+ // `BasicCoverageBlock`s so that the only remaining counters in the `CoverageGraph`
+ // are indirect counters (to be injected next, without associated code regions).
+ self.inject_coverage_span_counters(
+ coverage_spans,
+ &mut graphviz_data,
+ &mut debug_used_expressions,
+ );
+
+ ////////////////////////////////////////////////////
+ // For any remaining `BasicCoverageBlock` counters (that were not associated with
+ // any `CoverageSpan`), inject `Coverage` statements (_without_ code region `Span`s)
+ // to ensure `BasicCoverageBlock` counters that other `Expression`s may depend on
+ // are in fact counted, even though they don't directly contribute to counting
+ // their own independent code region's coverage.
+ self.inject_indirect_counters(&mut graphviz_data, &mut debug_used_expressions);
+
+ // Intermediate expressions will be injected as the final step, after generating
+ // debug output, if any.
+ ////////////////////////////////////////////////////
+
+ (Ok(()), intermediate_expressions)
+ }
+ Err(e) => (Err(e), Vec::new()),
+ };
+
+ if graphviz_data.is_enabled() {
+ // Even if there was an error, a partial CoverageGraph can still generate a useful
+ // graphviz output.
+ debug::dump_coverage_graphviz(
+ tcx,
+ self.mir_body,
+ self.pass_name,
+ &self.basic_coverage_blocks,
+ &self.coverage_counters.debug_counters,
+ &graphviz_data,
+ &intermediate_expressions,
+ &debug_used_expressions,
+ );
+ }
+
+ if let Err(e) = result {
+ bug!("Error processing: {:?}: {:?}", self.mir_body.source.def_id(), e.message)
+ };
+
+ // Depending on current `debug_options()`, `alert_on_unused_expressions()` could panic, so
+ // this check is performed as late as possible, to allow other debug output (logs and dump
+ // files), which might be helpful in analyzing unused expressions, to still be generated.
+ debug_used_expressions.alert_on_unused_expressions(&self.coverage_counters.debug_counters);
+
+ ////////////////////////////////////////////////////
+ // Finally, inject the intermediate expressions collected along the way.
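As background for this final injection step, a small self-contained sketch of how derived counts combine (an illustrative `Count` tree, not rustc's `CoverageKind`, whose expressions reference operands by ID instead of nesting):

```rust
// A BCB's count is either a raw counter slot or an expression over other
// counts, e.g. `else = parent - then` for an `if` with a counted `then` arm.
enum Count {
    Counter(usize),                     // index into the raw counter array
    Expr(Box<Count>, char, Box<Count>), // lhs ('+' | '-') rhs
}

fn eval(c: &Count, raw: &[u64]) -> u64 {
    match c {
        Count::Counter(i) => raw[*i],
        Count::Expr(lhs, op, rhs) => {
            let (l, r) = (eval(lhs, raw), eval(rhs, raw));
            if *op == '+' { l + r } else { l - r }
        }
    }
}

fn main() {
    // The parent block ran 10 times and the `then` arm 7 times, so the
    // `else` arm must have run 3 times without its own hardware counter.
    let else_arm = Count::Expr(Box::new(Count::Counter(0)), '-', Box::new(Count::Counter(1)));
    assert_eq!(eval(&else_arm, &[10, 7]), 3);
}
```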
+ for intermediate_expression in intermediate_expressions {
+ inject_intermediate_expression(self.mir_body, intermediate_expression);
+ }
+ }
+
+ /// Inject a counter for each `CoverageSpan`. There can be multiple `CoverageSpan`s for a given
+ /// BCB, but only one actual counter needs to be incremented per BCB. `bb_counters` maps each
+ /// `bcb` to its `Counter`, when injected. Subsequent `CoverageSpan`s for a BCB that already has
+ /// a `Counter` will inject an `Expression` instead, and compute its value by adding `ZERO` to
+ /// the BCB `Counter` value.
+ ///
+ /// If debugging, add every BCB `Expression` associated with a `CoverageSpan` to the
+ /// `used_expression_operands` map.
+ fn inject_coverage_span_counters(
+ &mut self,
+ coverage_spans: Vec<CoverageSpan>,
+ graphviz_data: &mut debug::GraphvizData,
+ debug_used_expressions: &mut debug::UsedExpressions,
+ ) {
+ let tcx = self.tcx;
+ let source_map = tcx.sess.source_map();
+ let body_span = self.body_span;
+ let file_name = Symbol::intern(&self.source_file.name.prefer_remapped().to_string_lossy());
+
+ let mut bcb_counters = IndexVec::from_elem_n(None, self.basic_coverage_blocks.num_nodes());
+ for covspan in coverage_spans {
+ let bcb = covspan.bcb;
+ let span = covspan.span;
+ let counter_kind = if let Some(&counter_operand) = bcb_counters[bcb].as_ref() {
+ self.coverage_counters.make_identity_counter(counter_operand)
+ } else if let Some(counter_kind) = self.bcb_data_mut(bcb).take_counter() {
+ bcb_counters[bcb] = Some(counter_kind.as_operand_id());
+ debug_used_expressions.add_expression_operands(&counter_kind);
+ counter_kind
+ } else {
+ bug!("Every BasicCoverageBlock should have a Counter or Expression");
+ };
+ graphviz_data.add_bcb_coverage_span_with_counter(bcb, &covspan, &counter_kind);
+
+ debug!(
+ "Calling make_code_region(file_name={}, source_file={:?}, span={}, body_span={})",
+ file_name,
+ self.source_file,
+ source_map.span_to_diagnostic_string(span),
+ source_map.span_to_diagnostic_string(body_span)
+ );
+
+ inject_statement(
+ self.mir_body,
+ counter_kind,
+ self.bcb_leader_bb(bcb),
+ Some(make_code_region(source_map, file_name, &self.source_file, span, body_span)),
+ );
+ }
+ }
+
+ /// `inject_coverage_span_counters()` looped through the `CoverageSpan`s and injected the
+ /// counter from each `CoverageSpan`'s `BasicCoverageBlock`, removing it from the BCB in the
+ /// process (via `take_counter()`).
+ ///
+ /// Any other counter associated with a `BasicCoverageBlock`, or its incoming edge, but not
+ /// associated with a `CoverageSpan`, should only exist if the counter is an `Expression`
+ /// dependency (one of the expression operands). Collect them, and inject the additional
+ /// counters into the MIR, without a reportable coverage span.
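Before the method itself, a sketch of the edge-redirection step it depends on, reduced to plain adjacency lists (`split_edge` is a hypothetical analogue of the `inject_edge_counter_basic_block` helper defined later in this file):

```rust
// Redirect the `from -> to` edge through a fresh block; the new block is
// where an edge counter can be injected, and it just jumps on to `to`.
fn split_edge(succs: &mut Vec<Vec<usize>>, from: usize, to: usize) -> usize {
    let new_bb = succs.len();
    succs.push(vec![to]);
    let edge = succs[from]
        .iter_mut()
        .find(|s| **s == to)
        .expect("`from` should have a successor `to`");
    *edge = new_bb;
    new_bb
}

fn main() {
    let mut succs = vec![vec![1, 2], vec![], vec![]];
    let new_bb = split_edge(&mut succs, 0, 2);
    assert_eq!(new_bb, 3);
    assert_eq!(succs[0], vec![1, 3]); // edge 0 -> 2 now routes through bb3
    assert_eq!(succs[3], vec![2]);
}
```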
+ fn inject_indirect_counters( + &mut self, + graphviz_data: &mut debug::GraphvizData, + debug_used_expressions: &mut debug::UsedExpressions, + ) { + let mut bcb_counters_without_direct_coverage_spans = Vec::new(); + for (target_bcb, target_bcb_data) in self.basic_coverage_blocks.iter_enumerated_mut() { + if let Some(counter_kind) = target_bcb_data.take_counter() { + bcb_counters_without_direct_coverage_spans.push((None, target_bcb, counter_kind)); + } + if let Some(edge_counters) = target_bcb_data.take_edge_counters() { + for (from_bcb, counter_kind) in edge_counters { + bcb_counters_without_direct_coverage_spans.push(( + Some(from_bcb), + target_bcb, + counter_kind, + )); + } + } + } + + // If debug is enabled, validate that every BCB or edge counter not directly associated + // with a coverage span is at least indirectly associated (it is a dependency of a BCB + // counter that _is_ associated with a coverage span). + debug_used_expressions.validate(&bcb_counters_without_direct_coverage_spans); + + for (edge_from_bcb, target_bcb, counter_kind) in bcb_counters_without_direct_coverage_spans + { + debug_used_expressions.add_unused_expression_if_not_found( + &counter_kind, + edge_from_bcb, + target_bcb, + ); + + match counter_kind { + CoverageKind::Counter { .. } => { + let inject_to_bb = if let Some(from_bcb) = edge_from_bcb { + // The MIR edge starts `from_bb` (the outgoing / last BasicBlock in + // `from_bcb`) and ends at `to_bb` (the incoming / first BasicBlock in the + // `target_bcb`; also called the `leader_bb`). + let from_bb = self.bcb_last_bb(from_bcb); + let to_bb = self.bcb_leader_bb(target_bcb); + + let new_bb = inject_edge_counter_basic_block(self.mir_body, from_bb, to_bb); + graphviz_data.set_edge_counter(from_bcb, new_bb, &counter_kind); + debug!( + "Edge {:?} (last {:?}) -> {:?} (leader {:?}) requires a new MIR \ + BasicBlock {:?}, for unclaimed edge counter {}", + edge_from_bcb, + from_bb, + target_bcb, + to_bb, + new_bb, + self.format_counter(&counter_kind), + ); + new_bb + } else { + let target_bb = self.bcb_last_bb(target_bcb); + graphviz_data.add_bcb_dependency_counter(target_bcb, &counter_kind); + debug!( + "{:?} ({:?}) gets a new Coverage statement for unclaimed counter {}", + target_bcb, + target_bb, + self.format_counter(&counter_kind), + ); + target_bb + }; + + inject_statement(self.mir_body, counter_kind, inject_to_bb, None); + } + CoverageKind::Expression { .. 
} => { + inject_intermediate_expression(self.mir_body, counter_kind) + } + _ => bug!("CoverageKind should be a counter"), + } + } + } + + #[inline] + fn bcb_leader_bb(&self, bcb: BasicCoverageBlock) -> BasicBlock { + self.bcb_data(bcb).leader_bb() + } + + #[inline] + fn bcb_last_bb(&self, bcb: BasicCoverageBlock) -> BasicBlock { + self.bcb_data(bcb).last_bb() + } + + #[inline] + fn bcb_data(&self, bcb: BasicCoverageBlock) -> &BasicCoverageBlockData { + &self.basic_coverage_blocks[bcb] + } + + #[inline] + fn bcb_data_mut(&mut self, bcb: BasicCoverageBlock) -> &mut BasicCoverageBlockData { + &mut self.basic_coverage_blocks[bcb] + } + + #[inline] + fn format_counter(&self, counter_kind: &CoverageKind) -> String { + self.coverage_counters.debug_counters.format_counter(counter_kind) + } +} + +fn inject_edge_counter_basic_block( + mir_body: &mut mir::Body<'_>, + from_bb: BasicBlock, + to_bb: BasicBlock, +) -> BasicBlock { + let span = mir_body[from_bb].terminator().source_info.span.shrink_to_hi(); + let new_bb = mir_body.basic_blocks_mut().push(BasicBlockData { + statements: vec![], // counter will be injected here + terminator: Some(Terminator { + source_info: SourceInfo::outermost(span), + kind: TerminatorKind::Goto { target: to_bb }, + }), + is_cleanup: false, + }); + let edge_ref = mir_body[from_bb] + .terminator_mut() + .successors_mut() + .find(|successor| **successor == to_bb) + .expect("from_bb should have a successor for to_bb"); + *edge_ref = new_bb; + new_bb +} + +fn inject_statement( + mir_body: &mut mir::Body<'_>, + counter_kind: CoverageKind, + bb: BasicBlock, + some_code_region: Option<CodeRegion>, +) { + debug!( + " injecting statement {:?} for {:?} at code region: {:?}", + counter_kind, bb, some_code_region + ); + let data = &mut mir_body[bb]; + let source_info = data.terminator().source_info; + let statement = Statement { + source_info, + kind: StatementKind::Coverage(Box::new(Coverage { + kind: counter_kind, + code_region: some_code_region, + })), + }; + data.statements.insert(0, statement); +} + +// Non-code expressions are injected into the coverage map, without generating executable code. +fn inject_intermediate_expression(mir_body: &mut mir::Body<'_>, expression: CoverageKind) { + debug_assert!(matches!(expression, CoverageKind::Expression { .. })); + debug!(" injecting non-code expression {:?}", expression); + let inject_in_bb = mir::START_BLOCK; + let data = &mut mir_body[inject_in_bb]; + let source_info = data.terminator().source_info; + let statement = Statement { + source_info, + kind: StatementKind::Coverage(Box::new(Coverage { kind: expression, code_region: None })), + }; + data.statements.push(statement); +} + +/// Convert the Span into its file name, start line and column, and end line and column +fn make_code_region( + source_map: &SourceMap, + file_name: Symbol, + source_file: &Lrc<SourceFile>, + span: Span, + body_span: Span, +) -> CodeRegion { + let (start_line, mut start_col) = source_file.lookup_file_pos(span.lo()); + let (end_line, end_col) = if span.hi() == span.lo() { + let (end_line, mut end_col) = (start_line, start_col); + // Extend an empty span by one character so the region will be counted. 
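The branch below implements that rule; as a worked instance (a hypothetical helper over bare `u32` columns, ignoring the `CharPos` wrapper):

```rust
// A zero-width span at column `col` is widened by one column so it still
// yields a countable region; at the very end of the body it is widened to
// the left instead (assuming `col >= 1`), since no column follows it.
fn widen_empty(col: u32, at_body_end: bool) -> (u32, u32) {
    if at_body_end { (col - 1, col) } else { (col, col + 1) }
}

fn main() {
    assert_eq!(widen_empty(10, false), (10, 11)); // extend right: [10, 11)
    assert_eq!(widen_empty(10, true), (9, 10)); // at body end: extend left
}
```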
+ let CharPos(char_pos) = start_col;
+ if span.hi() == body_span.hi() {
+ start_col = CharPos(char_pos - 1);
+ } else {
+ end_col = CharPos(char_pos + 1);
+ }
+ (end_line, end_col)
+ } else {
+ source_file.lookup_file_pos(span.hi())
+ };
+ let start_line = source_map.doctest_offset_line(&source_file.name, start_line);
+ let end_line = source_map.doctest_offset_line(&source_file.name, end_line);
+ CodeRegion {
+ file_name,
+ start_line: start_line as u32,
+ start_col: start_col.to_u32() + 1,
+ end_line: end_line as u32,
+ end_col: end_col.to_u32() + 1,
+ }
+}
+
+fn fn_sig_and_body<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ def_id: DefId,
+) -> (Option<&'tcx rustc_hir::FnSig<'tcx>>, &'tcx rustc_hir::Body<'tcx>) {
+ // FIXME(#79625): Consider improving MIR to provide the information needed, to avoid going back
+ // to HIR for it.
+ let hir_node = tcx.hir().get_if_local(def_id).expect("expected DefId is local");
+ let fn_body_id = hir::map::associated_body(hir_node).expect("HIR node is a function with body");
+ (hir::map::fn_sig(hir_node), tcx.hir().body(fn_body_id))
+}
+
+fn get_body_span<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ hir_body: &rustc_hir::Body<'tcx>,
+ mir_body: &mut mir::Body<'tcx>,
+) -> Span {
+ let mut body_span = hir_body.value.span;
+ let def_id = mir_body.source.def_id();
+
+ if tcx.is_closure(def_id) {
+ // If the MIR function is a closure, and if the closure body span
+ // starts from a macro, but its content is not in that macro, try
+ // to find a non-macro callsite, and instrument the spans there
+ // instead.
+ loop {
+ let expn_data = body_span.ctxt().outer_expn_data();
+ if expn_data.is_root() {
+ break;
+ }
+ if let ExpnKind::Macro { .. } = expn_data.kind {
+ body_span = expn_data.call_site;
+ } else {
+ break;
+ }
+ }
+ }
+
+ body_span
+}
+
+fn hash_mir_source<'tcx>(tcx: TyCtxt<'tcx>, hir_body: &'tcx rustc_hir::Body<'tcx>) -> u64 {
+ // FIXME(cjgillot) Stop hashing HIR manually here.
+ let owner = hir_body.id().hir_id.owner;
+ tcx.hir_owner_nodes(owner).unwrap().hash_including_bodies.to_smaller_hash()
+}
diff --git a/compiler/rustc_mir_transform/src/coverage/query.rs b/compiler/rustc_mir_transform/src/coverage/query.rs
new file mode 100644
index 000000000..9d02f58ae
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/coverage/query.rs
@@ -0,0 +1,170 @@
+use super::*;
+
+use rustc_middle::mir::coverage::*;
+use rustc_middle::mir::{self, Body, Coverage, CoverageInfo};
+use rustc_middle::ty::query::Providers;
+use rustc_middle::ty::{self, TyCtxt};
+use rustc_span::def_id::DefId;
+
+/// A `query` provider for retrieving coverage information injected into MIR.
+pub(crate) fn provide(providers: &mut Providers) {
+ providers.coverageinfo = |tcx, def_id| coverageinfo(tcx, def_id);
+ providers.covered_code_regions = |tcx, def_id| covered_code_regions(tcx, def_id);
+}
+
+/// The `num_counters` argument to `llvm.instrprof.increment` is the max counter_id + 1, or in
+/// other words, the number of counter value references injected into the MIR (plus 1 for the
+/// reserved `ZERO` counter, which uses counter ID `0` when included in an expression). Injected
+/// counters have a counter ID from `1..num_counters-1`.
+///
+/// `num_expressions` is the number of counter expressions added to the MIR body.
+///
+/// Both `num_counters` and `num_expressions` are used to initialize new vectors, during backend
+/// code generation, to look up counters and expressions by simple u32 indexes.
+///
+/// MIR optimization may split and duplicate some BasicBlock sequences, or optimize out some code
+/// including injected counters. (It is OK if some counters are optimized out, but those counters
+/// are still included in the total `num_counters` or `num_expressions`.) Simply counting the
+/// calls may not work; but computing the number of counters or expressions by adding `1` to the
+/// highest ID (for a given instrumented function) is valid.
+///
+/// This visitor runs twice, first with `add_missing_operands` set to `false`, to find the maximum
+/// counter ID and maximum expression ID based on their enum variant `id` fields; then, as a
+/// safeguard, with `add_missing_operands` set to `true`, to find any other counter or expression
+/// IDs referenced by expression operands, if not already seen.
+///
+/// Ideally, each operand ID in a MIR `CoverageKind::Expression` will have a separate MIR `Coverage`
+/// statement for the `Counter` or `Expression` with the referenced ID. But since current or future
+/// MIR optimizations can theoretically optimize out segments of a MIR, it may not be possible to
+/// guarantee this, so the second pass ensures the `CoverageInfo` counts include all referenced IDs.
+struct CoverageVisitor {
+ info: CoverageInfo,
+ add_missing_operands: bool,
+}
+
+impl CoverageVisitor {
+ /// Updates `num_counters` to the maximum encountered zero-based counter_id plus 1. Note the
+ /// final computed number of counters should be the number of all `CoverageKind::Counter`
+ /// statements in the MIR *plus one* for the implicit `ZERO` counter.
+ #[inline(always)]
+ fn update_num_counters(&mut self, counter_id: u32) {
+ self.info.num_counters = std::cmp::max(self.info.num_counters, counter_id + 1);
+ }
+
+ /// Computes an expression index for each expression ID, and updates `num_expressions` to the
+ /// maximum encountered index plus 1.
+ #[inline(always)]
+ fn update_num_expressions(&mut self, expression_id: u32) {
+ let expression_index = u32::MAX - expression_id;
+ self.info.num_expressions = std::cmp::max(self.info.num_expressions, expression_index + 1);
+ }
+
+ fn update_from_expression_operand(&mut self, operand_id: u32) {
+ if operand_id >= self.info.num_counters {
+ let operand_as_expression_index = u32::MAX - operand_id;
+ if operand_as_expression_index >= self.info.num_expressions {
+ // The operand ID is outside the known range of counter IDs and also outside the
+ // known range of expression IDs. In either case, the result of a missing operand
+ // (if and when used in an expression) will be zero, so from a computation
+ // perspective, it doesn't matter whether it is interpreted as a counter or an
+ // expression.
+ //
+ // However, the `num_counters` and `num_expressions` query results are used to
+ // allocate arrays when generating the coverage map (during codegen), so choose
+ // the type that grows either `num_counters` or `num_expressions` the least.
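A standalone sketch of that choice over the two ID spaces (counter IDs count up from `0`, expression IDs count down from `u32::MAX`; `classify` is a hypothetical helper, and the real decision is the `if` that follows):

```rust
// An operand is a known counter if below `num_counters`, a known expression
// if its down-counted index is below `num_expressions`; otherwise pick the
// interpretation that enlarges its backing array the least.
fn classify(operand_id: u32, num_counters: u32, num_expressions: u32) -> &'static str {
    let expr_index = u32::MAX - operand_id;
    if operand_id < num_counters {
        "known counter"
    } else if expr_index < num_expressions {
        "known expression"
    } else if operand_id - num_counters < expr_index - num_expressions {
        "treat as counter"
    } else {
        "treat as expression"
    }
}

fn main() {
    assert_eq!(classify(3, 5, 2), "known counter");
    assert_eq!(classify(u32::MAX - 1, 5, 2), "known expression");
    assert_eq!(classify(6, 5, 2), "treat as counter");
}
```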
+ if operand_id - self.info.num_counters + < operand_as_expression_index - self.info.num_expressions + { + self.update_num_counters(operand_id) + } else { + self.update_num_expressions(operand_id) + } + } + } + } + + fn visit_body(&mut self, body: &Body<'_>) { + for bb_data in body.basic_blocks().iter() { + for statement in bb_data.statements.iter() { + if let StatementKind::Coverage(box ref coverage) = statement.kind { + if is_inlined(body, statement) { + continue; + } + self.visit_coverage(coverage); + } + } + } + } + + fn visit_coverage(&mut self, coverage: &Coverage) { + if self.add_missing_operands { + match coverage.kind { + CoverageKind::Expression { lhs, rhs, .. } => { + self.update_from_expression_operand(u32::from(lhs)); + self.update_from_expression_operand(u32::from(rhs)); + } + _ => {} + } + } else { + match coverage.kind { + CoverageKind::Counter { id, .. } => { + self.update_num_counters(u32::from(id)); + } + CoverageKind::Expression { id, .. } => { + self.update_num_expressions(u32::from(id)); + } + _ => {} + } + } + } +} + +fn coverageinfo<'tcx>(tcx: TyCtxt<'tcx>, instance_def: ty::InstanceDef<'tcx>) -> CoverageInfo { + let mir_body = tcx.instance_mir(instance_def); + + let mut coverage_visitor = CoverageVisitor { + // num_counters always has at least the `ZERO` counter. + info: CoverageInfo { num_counters: 1, num_expressions: 0 }, + add_missing_operands: false, + }; + + coverage_visitor.visit_body(mir_body); + + coverage_visitor.add_missing_operands = true; + coverage_visitor.visit_body(mir_body); + + coverage_visitor.info +} + +fn covered_code_regions<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Vec<&'tcx CodeRegion> { + let body = mir_body(tcx, def_id); + body.basic_blocks() + .iter() + .flat_map(|data| { + data.statements.iter().filter_map(|statement| match statement.kind { + StatementKind::Coverage(box ref coverage) => { + if is_inlined(body, statement) { + None + } else { + coverage.code_region.as_ref() // may be None + } + } + _ => None, + }) + }) + .collect() +} + +fn is_inlined(body: &Body<'_>, statement: &Statement<'_>) -> bool { + let scope_data = &body.source_scopes[statement.source_info.scope]; + scope_data.inlined.is_some() || scope_data.inlined_parent_scope.is_some() +} + +/// This function ensures we obtain the correct MIR for the given item irrespective of +/// whether that means const mir or runtime mir. For `const fn` this opts for runtime +/// mir. 
+fn mir_body<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> &'tcx mir::Body<'tcx> {
+ let id = ty::WithOptConstParam::unknown(def_id);
+ let def = ty::InstanceDef::Item(id);
+ tcx.instance_mir(def)
+}
diff --git a/compiler/rustc_mir_transform/src/coverage/spans.rs b/compiler/rustc_mir_transform/src/coverage/spans.rs
new file mode 100644
index 000000000..423e78317
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/coverage/spans.rs
@@ -0,0 +1,892 @@
+use super::debug::term_type;
+use super::graph::{BasicCoverageBlock, BasicCoverageBlockData, CoverageGraph, START_BCB};
+
+use itertools::Itertools;
+use rustc_data_structures::graph::WithNumNodes;
+use rustc_middle::mir::spanview::source_range_no_file;
+use rustc_middle::mir::{
+ self, AggregateKind, BasicBlock, FakeReadCause, Rvalue, Statement, StatementKind, Terminator,
+ TerminatorKind,
+};
+use rustc_middle::ty::TyCtxt;
+use rustc_span::source_map::original_sp;
+use rustc_span::{BytePos, ExpnKind, MacroKind, Span, Symbol};
+
+use std::cell::RefCell;
+use std::cmp::Ordering;
+
+#[derive(Debug, Copy, Clone)]
+pub(super) enum CoverageStatement {
+ Statement(BasicBlock, Span, usize),
+ Terminator(BasicBlock, Span),
+}
+
+impl CoverageStatement {
+ pub fn format<'tcx>(&self, tcx: TyCtxt<'tcx>, mir_body: &mir::Body<'tcx>) -> String {
+ match *self {
+ Self::Statement(bb, span, stmt_index) => {
+ let stmt = &mir_body[bb].statements[stmt_index];
+ format!(
+ "{}: @{}[{}]: {:?}",
+ source_range_no_file(tcx, span),
+ bb.index(),
+ stmt_index,
+ stmt
+ )
+ }
+ Self::Terminator(bb, span) => {
+ let term = mir_body[bb].terminator();
+ format!(
+ "{}: @{}.{}: {:?}",
+ source_range_no_file(tcx, span),
+ bb.index(),
+ term_type(&term.kind),
+ term.kind
+ )
+ }
+ }
+ }
+
+ pub fn span(&self) -> Span {
+ match self {
+ Self::Statement(_, span, _) | Self::Terminator(_, span) => *span,
+ }
+ }
+}
+
+/// A BCB is deconstructed into one or more `Span`s. Each `Span` maps to a `CoverageSpan` that
+/// references the originating BCB and one or more MIR `Statement`s and/or `Terminator`s.
+/// Initially, the `Span`s come from the `Statement`s and `Terminator`s, but subsequent
+/// transforms can combine adjacent `Span`s and `CoverageSpan`s from the same BCB, merging the
+/// `CoverageStatement` vectors, and the `Span`s to cover the extent of the combined `Span`s.
+///
+/// Note: A `CoverageStatement` merged into another CoverageSpan may come from a `BasicBlock` that
+/// is not part of the `CoverageSpan` bcb if the statement was included because its `Span` matches
+/// or is subsumed by the `Span` associated with this `CoverageSpan`, and its `BasicBlock`
+/// `is_dominated_by()` the `BasicBlock`s in this `CoverageSpan`.
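As a worked instance of the merge described above (bare byte offsets standing in for `Span::to`):

```rust
// Combining two (lo, hi) spans into one span covering both.
fn merge(a: (u32, u32), b: (u32, u32)) -> (u32, u32) {
    (a.0.min(b.0), a.1.max(b.1))
}

fn main() {
    assert_eq!(merge((4, 9), (7, 20)), (4, 20));
}
```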
+#[derive(Debug, Clone)] +pub(super) struct CoverageSpan { + pub span: Span, + pub expn_span: Span, + pub current_macro_or_none: RefCell<Option<Option<Symbol>>>, + pub bcb: BasicCoverageBlock, + pub coverage_statements: Vec<CoverageStatement>, + pub is_closure: bool, +} + +impl CoverageSpan { + pub fn for_fn_sig(fn_sig_span: Span) -> Self { + Self { + span: fn_sig_span, + expn_span: fn_sig_span, + current_macro_or_none: Default::default(), + bcb: START_BCB, + coverage_statements: vec![], + is_closure: false, + } + } + + pub fn for_statement( + statement: &Statement<'_>, + span: Span, + expn_span: Span, + bcb: BasicCoverageBlock, + bb: BasicBlock, + stmt_index: usize, + ) -> Self { + let is_closure = match statement.kind { + StatementKind::Assign(box (_, Rvalue::Aggregate(box ref kind, _))) => { + matches!(kind, AggregateKind::Closure(_, _) | AggregateKind::Generator(_, _, _)) + } + _ => false, + }; + + Self { + span, + expn_span, + current_macro_or_none: Default::default(), + bcb, + coverage_statements: vec![CoverageStatement::Statement(bb, span, stmt_index)], + is_closure, + } + } + + pub fn for_terminator( + span: Span, + expn_span: Span, + bcb: BasicCoverageBlock, + bb: BasicBlock, + ) -> Self { + Self { + span, + expn_span, + current_macro_or_none: Default::default(), + bcb, + coverage_statements: vec![CoverageStatement::Terminator(bb, span)], + is_closure: false, + } + } + + pub fn merge_from(&mut self, mut other: CoverageSpan) { + debug_assert!(self.is_mergeable(&other)); + self.span = self.span.to(other.span); + self.coverage_statements.append(&mut other.coverage_statements); + } + + pub fn cutoff_statements_at(&mut self, cutoff_pos: BytePos) { + self.coverage_statements.retain(|covstmt| covstmt.span().hi() <= cutoff_pos); + if let Some(highest_covstmt) = + self.coverage_statements.iter().max_by_key(|covstmt| covstmt.span().hi()) + { + self.span = self.span.with_hi(highest_covstmt.span().hi()); + } + } + + #[inline] + pub fn is_mergeable(&self, other: &Self) -> bool { + self.is_in_same_bcb(other) && !(self.is_closure || other.is_closure) + } + + #[inline] + pub fn is_in_same_bcb(&self, other: &Self) -> bool { + self.bcb == other.bcb + } + + pub fn format<'tcx>(&self, tcx: TyCtxt<'tcx>, mir_body: &mir::Body<'tcx>) -> String { + format!( + "{}\n {}", + source_range_no_file(tcx, self.span), + self.format_coverage_statements(tcx, mir_body).replace('\n', "\n "), + ) + } + + pub fn format_coverage_statements<'tcx>( + &self, + tcx: TyCtxt<'tcx>, + mir_body: &mir::Body<'tcx>, + ) -> String { + let mut sorted_coverage_statements = self.coverage_statements.clone(); + sorted_coverage_statements.sort_unstable_by_key(|covstmt| match *covstmt { + CoverageStatement::Statement(bb, _, index) => (bb, index), + CoverageStatement::Terminator(bb, _) => (bb, usize::MAX), + }); + sorted_coverage_statements.iter().map(|covstmt| covstmt.format(tcx, mir_body)).join("\n") + } + + /// If the span is part of a macro, returns the macro name symbol. + pub fn current_macro(&self) -> Option<Symbol> { + self.current_macro_or_none + .borrow_mut() + .get_or_insert_with(|| { + if let ExpnKind::Macro(MacroKind::Bang, current_macro) = + self.expn_span.ctxt().outer_expn_data().kind + { + return Some(current_macro); + } + None + }) + .map(|symbol| symbol) + } + + /// If the span is part of a macro, and the macro is visible (expands directly to the given + /// body_span), returns the macro name symbol. 
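The `current_macro_or_none` field above is a double-`Option` memo: the outer `Option` records whether the answer has been computed yet, the inner one is the (possibly negative) answer. A minimal standalone sketch of the pattern (hypothetical `Memo` type), before the `visible_macro` helper documented above:

```rust
use std::cell::RefCell;

struct Memo {
    cached: RefCell<Option<Option<String>>>,
}

impl Memo {
    fn get(&self, compute: impl FnOnce() -> Option<String>) -> Option<String> {
        // First call stores the computed inner Option; later calls reuse it.
        self.cached.borrow_mut().get_or_insert_with(compute).clone()
    }
}

fn main() {
    let memo = Memo { cached: RefCell::new(None) };
    assert_eq!(memo.get(|| Some("assert_eq".to_string())), Some("assert_eq".to_string()));
    assert_eq!(memo.get(|| unreachable!("cached")), Some("assert_eq".to_string()));
}
```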
+ pub fn visible_macro(&self, body_span: Span) -> Option<Symbol> { + if let Some(current_macro) = self.current_macro() && self + .expn_span + .parent_callsite() + .unwrap_or_else(|| bug!("macro must have a parent")) + .eq_ctxt(body_span) + { + return Some(current_macro); + } + None + } + + pub fn is_macro_expansion(&self) -> bool { + self.current_macro().is_some() + } +} + +/// Converts the initial set of `CoverageSpan`s (one per MIR `Statement` or `Terminator`) into a +/// minimal set of `CoverageSpan`s, using the BCB CFG to determine where it is safe and useful to: +/// +/// * Remove duplicate source code coverage regions +/// * Merge spans that represent continuous (both in source code and control flow), non-branching +/// execution +/// * Carve out (leave uncovered) any span that will be counted by another MIR (notably, closures) +pub struct CoverageSpans<'a, 'tcx> { + /// The MIR, used to look up `BasicBlockData`. + mir_body: &'a mir::Body<'tcx>, + + /// A `Span` covering the signature of function for the MIR. + fn_sig_span: Span, + + /// A `Span` covering the function body of the MIR (typically from left curly brace to right + /// curly brace). + body_span: Span, + + /// The BasicCoverageBlock Control Flow Graph (BCB CFG). + basic_coverage_blocks: &'a CoverageGraph, + + /// The initial set of `CoverageSpan`s, sorted by `Span` (`lo` and `hi`) and by relative + /// dominance between the `BasicCoverageBlock`s of equal `Span`s. + sorted_spans_iter: Option<std::vec::IntoIter<CoverageSpan>>, + + /// The current `CoverageSpan` to compare to its `prev`, to possibly merge, discard, force the + /// discard of the `prev` (and or `pending_dups`), or keep both (with `prev` moved to + /// `pending_dups`). If `curr` is not discarded or merged, it becomes `prev` for the next + /// iteration. + some_curr: Option<CoverageSpan>, + + /// The original `span` for `curr`, in case `curr.span()` is modified. The `curr_original_span` + /// **must not be mutated** (except when advancing to the next `curr`), even if `curr.span()` + /// is mutated. + curr_original_span: Span, + + /// The CoverageSpan from a prior iteration; typically assigned from that iteration's `curr`. + /// If that `curr` was discarded, `prev` retains its value from the previous iteration. + some_prev: Option<CoverageSpan>, + + /// Assigned from `curr_original_span` from the previous iteration. The `prev_original_span` + /// **must not be mutated** (except when advancing to the next `prev`), even if `prev.span()` + /// is mutated. + prev_original_span: Span, + + /// A copy of the expn_span from the prior iteration. + prev_expn_span: Option<Span>, + + /// One or more `CoverageSpan`s with the same `Span` but different `BasicCoverageBlock`s, and + /// no `BasicCoverageBlock` in this list dominates another `BasicCoverageBlock` in the list. + /// If a new `curr` span also fits this criteria (compared to an existing list of + /// `pending_dups`), that `curr` `CoverageSpan` moves to `prev` before possibly being added to + /// the `pending_dups` list, on the next iteration. As a result, if `prev` and `pending_dups` + /// have the same `Span`, the criteria for `pending_dups` holds for `prev` as well: a `prev` + /// with a matching `Span` does not dominate any `pending_dup` and no `pending_dup` dominates a + /// `prev` with a matching `Span`) + pending_dups: Vec<CoverageSpan>, + + /// The final `CoverageSpan`s to add to the coverage map. A `Counter` or `Expression` + /// will also be injected into the MIR for each `CoverageSpan`. 
+ refined_spans: Vec<CoverageSpan>, +} + +impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { + /// Generate a minimal set of `CoverageSpan`s, each representing a contiguous code region to be + /// counted. + /// + /// The basic steps are: + /// + /// 1. Extract an initial set of spans from the `Statement`s and `Terminator`s of each + /// `BasicCoverageBlockData`. + /// 2. Sort the spans by span.lo() (starting position). Spans that start at the same position + /// are sorted with longer spans before shorter spans; and equal spans are sorted + /// (deterministically) based on "dominator" relationship (if any). + /// 3. Traverse the spans in sorted order to identify spans that can be dropped (for instance, + /// if another span or spans are already counting the same code region), or should be merged + /// into a broader combined span (because it represents a contiguous, non-branching, and + /// uninterrupted region of source code). + /// + /// Closures are exposed in their enclosing functions as `Assign` `Rvalue`s, and since + /// closures have their own MIR, their `Span` in their enclosing function should be left + /// "uncovered". + /// + /// Note the resulting vector of `CoverageSpan`s may not be fully sorted (and does not need + /// to be). + pub(super) fn generate_coverage_spans( + mir_body: &'a mir::Body<'tcx>, + fn_sig_span: Span, // Ensured to be same SourceFile and SyntaxContext as `body_span` + body_span: Span, + basic_coverage_blocks: &'a CoverageGraph, + ) -> Vec<CoverageSpan> { + let mut coverage_spans = CoverageSpans { + mir_body, + fn_sig_span, + body_span, + basic_coverage_blocks, + sorted_spans_iter: None, + refined_spans: Vec::with_capacity(basic_coverage_blocks.num_nodes() * 2), + some_curr: None, + curr_original_span: Span::with_root_ctxt(BytePos(0), BytePos(0)), + some_prev: None, + prev_original_span: Span::with_root_ctxt(BytePos(0), BytePos(0)), + prev_expn_span: None, + pending_dups: Vec::new(), + }; + + let sorted_spans = coverage_spans.mir_to_initial_sorted_coverage_spans(); + + coverage_spans.sorted_spans_iter = Some(sorted_spans.into_iter()); + + coverage_spans.to_refined_spans() + } + + fn mir_to_initial_sorted_coverage_spans(&self) -> Vec<CoverageSpan> { + let mut initial_spans = + Vec::<CoverageSpan>::with_capacity(self.mir_body.basic_blocks.len() * 2); + for (bcb, bcb_data) in self.basic_coverage_blocks.iter_enumerated() { + initial_spans.extend(self.bcb_to_initial_coverage_spans(bcb, bcb_data)); + } + + if initial_spans.is_empty() { + // This can happen if, for example, the function is unreachable (contains only a + // `BasicBlock`(s) with an `Unreachable` terminator). + return initial_spans; + } + + initial_spans.push(CoverageSpan::for_fn_sig(self.fn_sig_span)); + + initial_spans.sort_unstable_by(|a, b| { + if a.span.lo() == b.span.lo() { + if a.span.hi() == b.span.hi() { + if a.is_in_same_bcb(b) { + Some(Ordering::Equal) + } else { + // Sort equal spans by dominator relationship, in reverse order (so + // dominators always come after the dominated equal spans). When later + // comparing two spans in order, the first will either dominate the second, + // or they will have no dominator relationship. + self.basic_coverage_blocks.dominators().rank_partial_cmp(b.bcb, a.bcb) + } + } else { + // Sort hi() in reverse order so shorter spans are attempted after longer spans. 
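+                    // For example, when two spans start at the same `lo`, a span covering
+                    // bytes 5..20 sorts before a span covering bytes 5..8.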
+ // This guarantees that, if a `prev` span overlaps, and is not equal to, a + // `curr` span, the prev span either extends further left of the curr span, or + // they start at the same position and the prev span extends further right of + // the end of the curr span. + b.span.hi().partial_cmp(&a.span.hi()) + } + } else { + a.span.lo().partial_cmp(&b.span.lo()) + } + .unwrap() + }); + + initial_spans + } + + /// Iterate through the sorted `CoverageSpan`s, and return the refined list of merged and + /// de-duplicated `CoverageSpan`s. + fn to_refined_spans(mut self) -> Vec<CoverageSpan> { + while self.next_coverage_span() { + if self.some_prev.is_none() { + debug!(" initial span"); + self.check_invoked_macro_name_span(); + } else if self.curr().is_mergeable(self.prev()) { + debug!(" same bcb (and neither is a closure), merge with prev={:?}", self.prev()); + let prev = self.take_prev(); + self.curr_mut().merge_from(prev); + self.check_invoked_macro_name_span(); + // Note that curr.span may now differ from curr_original_span + } else if self.prev_ends_before_curr() { + debug!( + " different bcbs and disjoint spans, so keep curr for next iter, and add \ + prev={:?}", + self.prev() + ); + let prev = self.take_prev(); + self.push_refined_span(prev); + self.check_invoked_macro_name_span(); + } else if self.prev().is_closure { + // drop any equal or overlapping span (`curr`) and keep `prev` to test again in the + // next iter + debug!( + " curr overlaps a closure (prev). Drop curr and keep prev for next iter. \ + prev={:?}", + self.prev() + ); + self.take_curr(); + } else if self.curr().is_closure { + self.carve_out_span_for_closure(); + } else if self.prev_original_span == self.curr().span { + // Note that this compares the new (`curr`) span to `prev_original_span`. + // In this branch, the actual span byte range of `prev_original_span` is not + // important. What is important is knowing whether the new `curr` span was + // **originally** the same as the original span of `prev()`. The original spans + // reflect their original sort order, and for equal spans, conveys a partial + // ordering based on CFG dominator priority. + if self.prev().is_macro_expansion() && self.curr().is_macro_expansion() { + // Macros that expand to include branching (such as + // `assert_eq!()`, `assert_ne!()`, `info!()`, `debug!()`, or + // `trace!()) typically generate callee spans with identical + // ranges (typically the full span of the macro) for all + // `BasicBlocks`. This makes it impossible to distinguish + // the condition (`if val1 != val2`) from the optional + // branched statements (such as the call to `panic!()` on + // assert failure). In this case it is better (or less + // worse) to drop the optional branch bcbs and keep the + // non-conditional statements, to count when reached. + debug!( + " curr and prev are part of a macro expansion, and curr has the same span \ + as prev, but is in a different bcb. Drop curr and keep prev for next iter. \ + prev={:?}", + self.prev() + ); + self.take_curr(); + } else { + self.hold_pending_dups_unless_dominated(); + } + } else { + self.cutoff_prev_at_overlapping_curr(); + self.check_invoked_macro_name_span(); + } + } + + debug!(" AT END, adding last prev={:?}", self.prev()); + let prev = self.take_prev(); + let pending_dups = self.pending_dups.split_off(0); + for dup in pending_dups { + debug!(" ...adding at least one pending dup={:?}", dup); + self.push_refined_span(dup); + } + + // Async functions wrap a closure that implements the body to be executed. 
The enclosing + // function is called and returns an `impl Future` without initially executing any of the + // body. To avoid showing the return from the enclosing function as a "covered" return from + // the closure, the enclosing function's `TerminatorKind::Return`s `CoverageSpan` is + // excluded. The closure's `Return` is the only one that will be counted. This provides + // adequate coverage, and more intuitive counts. (Avoids double-counting the closing brace + // of the function body.) + let body_ends_with_closure = if let Some(last_covspan) = self.refined_spans.last() { + last_covspan.is_closure && last_covspan.span.hi() == self.body_span.hi() + } else { + false + }; + + if !body_ends_with_closure { + self.push_refined_span(prev); + } + + // Remove `CoverageSpan`s derived from closures, originally added to ensure the coverage + // regions for the current function leave room for the closure's own coverage regions + // (injected separately, from the closure's own MIR). + self.refined_spans.retain(|covspan| !covspan.is_closure); + self.refined_spans + } + + fn push_refined_span(&mut self, covspan: CoverageSpan) { + let len = self.refined_spans.len(); + if len > 0 { + let last = &mut self.refined_spans[len - 1]; + if last.is_mergeable(&covspan) { + debug!( + "merging new refined span with last refined span, last={:?}, covspan={:?}", + last, covspan + ); + last.merge_from(covspan); + return; + } + } + self.refined_spans.push(covspan) + } + + fn check_invoked_macro_name_span(&mut self) { + if let Some(visible_macro) = self.curr().visible_macro(self.body_span) { + if self.prev_expn_span.map_or(true, |prev_expn_span| { + self.curr().expn_span.ctxt() != prev_expn_span.ctxt() + }) { + let merged_prefix_len = self.curr_original_span.lo() - self.curr().span.lo(); + let after_macro_bang = + merged_prefix_len + BytePos(visible_macro.as_str().len() as u32 + 1); + let mut macro_name_cov = self.curr().clone(); + self.curr_mut().span = + self.curr().span.with_lo(self.curr().span.lo() + after_macro_bang); + macro_name_cov.span = + macro_name_cov.span.with_hi(macro_name_cov.span.lo() + after_macro_bang); + debug!( + " and curr starts a new macro expansion, so add a new span just for \ + the macro `{}!`, new span={:?}", + visible_macro, macro_name_cov + ); + self.push_refined_span(macro_name_cov); + } + } + } + + // Generate a set of `CoverageSpan`s from the filtered set of `Statement`s and `Terminator`s of + // the `BasicBlock`(s) in the given `BasicCoverageBlockData`. One `CoverageSpan` is generated + // for each `Statement` and `Terminator`. (Note that subsequent stages of coverage analysis will + // merge some `CoverageSpan`s, at which point a `CoverageSpan` may represent multiple + // `Statement`s and/or `Terminator`s.) 
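Looking back at `check_invoked_macro_name_span` above: the split point is computed so that the leading `name!` token sequence gets its own coverage span. A standalone sketch of that arithmetic (plain `u32` offsets; simplified in that it ignores the real code's merged-prefix adjustment):

```rust
// Split a macro call's span just past the `!`, so `name!` can get its own
// coverage span. `macro_name` and the offsets are assumed inputs.
fn split_macro_name(span: (u32, u32), macro_name: &str) -> ((u32, u32), (u32, u32)) {
    let after_bang = span.0 + macro_name.len() as u32 + 1; // `+ 1` covers the `!`
    ((span.0, after_bang), (after_bang, span.1))
}

fn main() {
    // A span covering `assert!(x)` at bytes 100..110:
    let (name, rest) = split_macro_name((100, 110), "assert");
    assert_eq!(name, (100, 107)); // covers `assert!`
    assert_eq!(rest, (107, 110)); // covers `(x)`
}
```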
+ fn bcb_to_initial_coverage_spans( + &self, + bcb: BasicCoverageBlock, + bcb_data: &'a BasicCoverageBlockData, + ) -> Vec<CoverageSpan> { + bcb_data + .basic_blocks + .iter() + .flat_map(|&bb| { + let data = &self.mir_body[bb]; + data.statements + .iter() + .enumerate() + .filter_map(move |(index, statement)| { + filtered_statement_span(statement).map(|span| { + CoverageSpan::for_statement( + statement, + function_source_span(span, self.body_span), + span, + bcb, + bb, + index, + ) + }) + }) + .chain(filtered_terminator_span(data.terminator()).map(|span| { + CoverageSpan::for_terminator( + function_source_span(span, self.body_span), + span, + bcb, + bb, + ) + })) + }) + .collect() + } + + fn curr(&self) -> &CoverageSpan { + self.some_curr + .as_ref() + .unwrap_or_else(|| bug!("invalid attempt to unwrap a None some_curr")) + } + + fn curr_mut(&mut self) -> &mut CoverageSpan { + self.some_curr + .as_mut() + .unwrap_or_else(|| bug!("invalid attempt to unwrap a None some_curr")) + } + + fn prev(&self) -> &CoverageSpan { + self.some_prev + .as_ref() + .unwrap_or_else(|| bug!("invalid attempt to unwrap a None some_prev")) + } + + fn prev_mut(&mut self) -> &mut CoverageSpan { + self.some_prev + .as_mut() + .unwrap_or_else(|| bug!("invalid attempt to unwrap a None some_prev")) + } + + fn take_prev(&mut self) -> CoverageSpan { + self.some_prev.take().unwrap_or_else(|| bug!("invalid attempt to unwrap a None some_prev")) + } + + /// If there are `pending_dups` but `prev` is not a matching dup (`prev.span` doesn't match the + /// `pending_dups` spans), then one of the following two things happened during the previous + /// iteration: + /// * the previous `curr` span (which is now `prev`) was not a duplicate of the pending_dups + /// (in which case there should be at least two spans in `pending_dups`); or + /// * the `span` of `prev` was modified by `curr_mut().merge_from(prev)` (in which case + /// `pending_dups` could have as few as one span) + /// In either case, no more spans will match the span of `pending_dups`, so + /// add the `pending_dups` if they don't overlap `curr`, and clear the list. + fn check_pending_dups(&mut self) { + if let Some(dup) = self.pending_dups.last() && dup.span != self.prev().span { + debug!( + " SAME spans, but pending_dups are NOT THE SAME, so BCBs matched on \ + previous iteration, or prev started a new disjoint span" + ); + if dup.span.hi() <= self.curr().span.lo() { + let pending_dups = self.pending_dups.split_off(0); + for dup in pending_dups.into_iter() { + debug!(" ...adding at least one pending={:?}", dup); + self.push_refined_span(dup); + } + } else { + self.pending_dups.clear(); + } + } + } + + /// Advance `prev` to `curr` (if any), and `curr` to the next `CoverageSpan` in sorted order. + fn next_coverage_span(&mut self) -> bool { + if let Some(curr) = self.some_curr.take() { + self.prev_expn_span = Some(curr.expn_span); + self.some_prev = Some(curr); + self.prev_original_span = self.curr_original_span; + } + while let Some(curr) = self.sorted_spans_iter.as_mut().unwrap().next() { + debug!("FOR curr={:?}", curr); + if self.some_prev.is_some() && self.prev_starts_after_next(&curr) { + debug!( + " prev.span starts after curr.span, so curr will be dropped (skipping past \ + closure?); prev={:?}", + self.prev() + ); + } else { + // Save a copy of the original span for `curr` in case the `CoverageSpan` is changed + // by `self.curr_mut().merge_from(prev)`. 
+ self.curr_original_span = curr.span; + self.some_curr.replace(curr); + self.check_pending_dups(); + return true; + } + } + false + } + + /// If called, then the next call to `next_coverage_span()` will *not* update `prev` with the + /// `curr` coverage span. + fn take_curr(&mut self) -> CoverageSpan { + self.some_curr.take().unwrap_or_else(|| bug!("invalid attempt to unwrap a None some_curr")) + } + + /// Returns true if the curr span should be skipped because prev has already advanced beyond the + /// end of curr. This can only happen if a prior iteration updated `prev` to skip past a region + /// of code, such as skipping past a closure. + fn prev_starts_after_next(&self, next_curr: &CoverageSpan) -> bool { + self.prev().span.lo() > next_curr.span.lo() + } + + /// Returns true if the curr span starts past the end of the prev span, which means they don't + /// overlap, so we now know the prev can be added to the refined coverage spans. + fn prev_ends_before_curr(&self) -> bool { + self.prev().span.hi() <= self.curr().span.lo() + } + + /// If `prev`s span extends left of the closure (`curr`), carve out the closure's span from + /// `prev`'s span. (The closure's coverage counters will be injected when processing the + /// closure's own MIR.) Add the portion of the span to the left of the closure; and if the span + /// extends to the right of the closure, update `prev` to that portion of the span. For any + /// `pending_dups`, repeat the same process. + fn carve_out_span_for_closure(&mut self) { + let curr_span = self.curr().span; + let left_cutoff = curr_span.lo(); + let right_cutoff = curr_span.hi(); + let has_pre_closure_span = self.prev().span.lo() < right_cutoff; + let has_post_closure_span = self.prev().span.hi() > right_cutoff; + let mut pending_dups = self.pending_dups.split_off(0); + if has_pre_closure_span { + let mut pre_closure = self.prev().clone(); + pre_closure.span = pre_closure.span.with_hi(left_cutoff); + debug!(" prev overlaps a closure. Adding span for pre_closure={:?}", pre_closure); + if !pending_dups.is_empty() { + for mut dup in pending_dups.iter().cloned() { + dup.span = dup.span.with_hi(left_cutoff); + debug!(" ...and at least one pre_closure dup={:?}", dup); + self.push_refined_span(dup); + } + } + self.push_refined_span(pre_closure); + } + if has_post_closure_span { + // Mutate `prev.span()` to start after the closure (and discard curr). + // (**NEVER** update `prev_original_span` because it affects the assumptions + // about how the `CoverageSpan`s are ordered.) + self.prev_mut().span = self.prev().span.with_lo(right_cutoff); + debug!(" Mutated prev.span to start after the closure. prev={:?}", self.prev()); + for dup in pending_dups.iter_mut() { + debug!(" ...and at least one overlapping dup={:?}", dup); + dup.span = dup.span.with_lo(right_cutoff); + } + self.pending_dups.append(&mut pending_dups); + let closure_covspan = self.take_curr(); + self.push_refined_span(closure_covspan); // since self.prev() was already updated + } else { + pending_dups.clear(); + } + } + + /// Called if `curr.span` equals `prev_original_span` (and potentially equal to all + /// `pending_dups` spans, if any). Keep in mind, `prev.span()` may have been changed. + /// If prev.span() was merged into other spans (with matching BCB, for instance), + /// `prev.span.hi()` will be greater than (further right of) `prev_original_span.hi()`. + /// If prev.span() was split off to the right of a closure, prev.span().lo() will be + /// greater than prev_original_span.lo(). 
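As an aside on `carve_out_span_for_closure` above: the essential operation is interval subtraction, keeping whatever extends left of the closure as a finished span and whatever extends right of it as the new `prev`. A simplified standalone sketch (plain `(lo, hi)` tuples rather than the rustc types, and omitting the `pending_dups` handling):

```rust
// Remove the closure's byte range from `prev`, returning the surviving
// left-hand and right-hand pieces (if any).
fn carve_out(prev: (u32, u32), closure: (u32, u32)) -> (Option<(u32, u32)>, Option<(u32, u32)>) {
    let pre = (prev.0 < closure.0).then(|| (prev.0, closure.0));
    let post = (prev.1 > closure.1).then(|| (closure.1, prev.1));
    (pre, post)
}

fn main() {
    // A closure spanning bytes 10..20 inside a statement spanning 5..30:
    assert_eq!(carve_out((5, 30), (10, 20)), (Some((5, 10)), Some((20, 30))));
    // Closure flush against the end: nothing survives on the right.
    assert_eq!(carve_out((5, 20), (10, 20)), (Some((5, 10)), None));
}
```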
The actual span of `prev_original_span` is
+    /// not as important as knowing that `prev()` **used to have the same span** as `curr()`,
+    /// which means their sort order is still meaningful for determining the dominator
+    /// relationship.
+    ///
+    /// When two `CoverageSpan`s have the same `Span`, dominated spans can be discarded; but if
+    /// neither `CoverageSpan` dominates the other, both (or possibly more than two) are held,
+    /// until their disposition is determined. In this latter case, the `prev` dup is moved into
+    /// `pending_dups` so the new `curr` dup can be moved to `prev` for the next iteration.
+    fn hold_pending_dups_unless_dominated(&mut self) {
+        // Equal coverage spans are ordered by dominators before dominated (if any), so it should be
+        // impossible for `curr` to dominate any previous `CoverageSpan`.
+        debug_assert!(!self.span_bcb_is_dominated_by(self.prev(), self.curr()));
+
+        let initial_pending_count = self.pending_dups.len();
+        if initial_pending_count > 0 {
+            let mut pending_dups = self.pending_dups.split_off(0);
+            pending_dups.retain(|dup| !self.span_bcb_is_dominated_by(self.curr(), dup));
+            self.pending_dups.append(&mut pending_dups);
+            if self.pending_dups.len() < initial_pending_count {
+                debug!(
+                    "  discarded {} of {} pending_dups that dominated curr",
+                    initial_pending_count - self.pending_dups.len(),
+                    initial_pending_count
+                );
+            }
+        }
+
+        if self.span_bcb_is_dominated_by(self.curr(), self.prev()) {
+            debug!(
+                "  different bcbs but SAME spans, and prev dominates curr. Discard prev={:?}",
+                self.prev()
+            );
+            self.cutoff_prev_at_overlapping_curr();
+            // If one span dominates the other, associate the span with the code from the dominated
+            // block only (`curr`), and discard the overlapping portion of the `prev` span. (Note
+            // that if `prev.span` is wider than `prev_original_span`, a `CoverageSpan` will still
+            // be created for `prev`'s block, for the non-overlapping portion, left of `curr.span`.)
+            //
+            // For example:
+            //     match somenum {
+            //         x if x < 1 => { ... }
+            //     }...
+            //
+            // The span for the first `x` is referenced by both the pattern block (every time it is
+            // evaluated) and the arm code (only when matched). The counter will be applied only to
+            // the dominated block. This allows coverage to track and highlight things like the
+            // assignment of `x` above, if the branch is matched, making `x` available to the arm
+            // code; and to track and highlight the question mark `?` "try" operator at the end of
+            // a function call returning a `Result`, so the `?` is covered when the function returns
+            // an `Err`, and not counted as covered if the function always returns `Ok`.
+        } else {
+            // Save `prev` in `pending_dups`. (`curr` will become `prev` in the next iteration.)
+            // If the `curr` CoverageSpan is later discarded, `pending_dups` can be discarded as
+            // well; but if `curr` is added to refined_spans, the `pending_dups` will also be added.
+            debug!(
+                "  different bcbs but SAME spans, and neither dominates, so keep curr for \
+                next iter, and, pending upcoming spans (unless overlapping), add prev={:?}",
+                self.prev()
+            );
+            let prev = self.take_prev();
+            self.pending_dups.push(prev);
+        }
+    }
+
+    /// `curr` overlaps `prev`. If `prev`'s span extends left of `curr`'s span, keep _only_
+    /// statements that end before `curr.lo()` (if any), and add the portion of the
+    /// combined span for those statements. Any other statements have overlapping spans
+    /// that can be ignored because `curr` and/or other upcoming statements/spans inside
+    /// the overlap area will produce their own counters. This disambiguation process
+    /// avoids injecting multiple counters for overlapping spans, and the potential for
+    /// double-counting.
+    fn cutoff_prev_at_overlapping_curr(&mut self) {
+        debug!(
+            "  different bcbs, overlapping spans, so ignore/drop pending and only add prev \
+            if it has statements that end before curr; prev={:?}",
+            self.prev()
+        );
+        if self.pending_dups.is_empty() {
+            let curr_span = self.curr().span;
+            self.prev_mut().cutoff_statements_at(curr_span.lo());
+            if self.prev().coverage_statements.is_empty() {
+                debug!("  ... no non-overlapping statements to add");
+            } else {
+                debug!("  ... adding modified prev={:?}", self.prev());
+                let prev = self.take_prev();
+                self.push_refined_span(prev);
+            }
+        } else {
+            // With `pending_dups`, `prev` cannot have any statements that don't overlap.
+            self.pending_dups.clear();
+        }
+    }
+
+    fn span_bcb_is_dominated_by(&self, covspan: &CoverageSpan, dom_covspan: &CoverageSpan) -> bool {
+        self.basic_coverage_blocks.is_dominated_by(covspan.bcb, dom_covspan.bcb)
+    }
+}
+
+/// If the MIR `Statement` has a span that contributes to computing coverage spans,
+/// return it; otherwise return `None`.
+pub(super) fn filtered_statement_span(statement: &Statement<'_>) -> Option<Span> {
+    match statement.kind {
+        // These statements have spans that are often outside the scope of the executed source code
+        // for their parent `BasicBlock`.
+        StatementKind::StorageLive(_)
+        | StatementKind::StorageDead(_)
+        // `Coverage` statements should not be encountered yet, but if they are, don't inject
+        // coverage for them
+        | StatementKind::Coverage(_)
+        // Ignore `Nop`s
+        | StatementKind::Nop => None,
+
+        // FIXME(#78546): MIR InstrumentCoverage - Can the source_info.span for `FakeRead`
+        // statements be more consistent?
+        //
+        // FakeReadCause::ForGuardBinding, in this example:
+        //     match somenum {
+        //         x if x < 1 => { ... }
+        //     }...
+        // The BasicBlock within the match arm code included one of these statements, but the span
+        // for it covered the `1` in this source. The actual statements have nothing to do with that
+        // source span:
+        //     FakeRead(ForGuardBinding, _4);
+        // where `_4` is:
+        //     _4 = &_1; (at the span for the first `x`)
+        // and `_1` is the `Place` for `somenum`.
+        //
+        // If and when the Issue is resolved, remove this special case match pattern:
+        StatementKind::FakeRead(box (cause, _)) if cause == FakeReadCause::ForGuardBinding => None,
+
+        // Retain spans from all other statements
+        StatementKind::FakeRead(box (_, _)) // Not including `ForGuardBinding`
+        | StatementKind::CopyNonOverlapping(..)
+        | StatementKind::Assign(_)
+        | StatementKind::SetDiscriminant { .. }
+        | StatementKind::Deinit(..)
+        | StatementKind::Retag(_, _)
+        | StatementKind::AscribeUserType(_, _) => {
+            Some(statement.source_info.span)
+        }
+    }
+}
+
+/// If the MIR `Terminator` has a span that contributes to computing coverage spans,
+/// return it; otherwise return `None`.
+pub(super) fn filtered_terminator_span(terminator: &Terminator<'_>) -> Option<Span> {
+    match terminator.kind {
+        // These terminators have spans that don't positively contribute to computing a reasonable
+        // span of actually executed source code.
(For example, SwitchInt terminators extracted from + // an `if condition { block }` has a span that includes the executed block, if true, + // but for coverage, the code region executed, up to *and* through the SwitchInt, + // actually stops before the if's block.) + TerminatorKind::Unreachable // Unreachable blocks are not connected to the MIR CFG + | TerminatorKind::Assert { .. } + | TerminatorKind::Drop { .. } + | TerminatorKind::DropAndReplace { .. } + | TerminatorKind::SwitchInt { .. } + // For `FalseEdge`, only the `real` branch is taken, so it is similar to a `Goto`. + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::Goto { .. } => None, + + // Call `func` operand can have a more specific span when part of a chain of calls + | TerminatorKind::Call { ref func, .. } => { + let mut span = terminator.source_info.span; + if let mir::Operand::Constant(box constant) = func { + if constant.span.lo() > span.lo() { + span = span.with_lo(constant.span.lo()); + } + } + Some(span) + } + + // Retain spans from all other terminators + TerminatorKind::Resume + | TerminatorKind::Abort + | TerminatorKind::Return + | TerminatorKind::Yield { .. } + | TerminatorKind::GeneratorDrop + | TerminatorKind::FalseUnwind { .. } + | TerminatorKind::InlineAsm { .. } => { + Some(terminator.source_info.span) + } + } +} + +/// Returns an extrapolated span (pre-expansion[^1]) corresponding to a range +/// within the function's body source. This span is guaranteed to be contained +/// within, or equal to, the `body_span`. If the extrapolated span is not +/// contained within the `body_span`, the `body_span` is returned. +/// +/// [^1]Expansions result from Rust syntax including macros, syntactic sugar, +/// etc.). +#[inline] +pub(super) fn function_source_span(span: Span, body_span: Span) -> Span { + let original_span = original_sp(span, body_span).with_ctxt(body_span.ctxt()); + if body_span.contains(original_span) { original_span } else { body_span } +} diff --git a/compiler/rustc_mir_transform/src/coverage/test_macros/Cargo.toml b/compiler/rustc_mir_transform/src/coverage/test_macros/Cargo.toml new file mode 100644 index 000000000..f5e8b6565 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/test_macros/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "coverage_test_macros" +version = "0.0.0" +edition = "2021" + +[lib] +proc-macro = true +doctest = false diff --git a/compiler/rustc_mir_transform/src/coverage/test_macros/src/lib.rs b/compiler/rustc_mir_transform/src/coverage/test_macros/src/lib.rs new file mode 100644 index 000000000..3d6095d27 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/test_macros/src/lib.rs @@ -0,0 +1,6 @@ +use proc_macro::TokenStream; + +#[proc_macro] +pub fn let_bcb(item: TokenStream) -> TokenStream { + format!("let bcb{} = graph::BasicCoverageBlock::from_usize({});", item, item).parse().unwrap() +} diff --git a/compiler/rustc_mir_transform/src/coverage/tests.rs b/compiler/rustc_mir_transform/src/coverage/tests.rs new file mode 100644 index 000000000..6380f0352 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/tests.rs @@ -0,0 +1,710 @@ +//! This crate hosts a selection of "unit tests" for components of the `InstrumentCoverage` MIR +//! pass. +//! +//! ```shell +//! ./x.py test --keep-stage 1 compiler/rustc_mir --test-args '--show-output coverage' +//! ``` +//! +//! The tests construct a few "mock" objects, as needed, to support the `InstrumentCoverage` +//! functions and algorithms. Mocked objects include instances of `mir::Body`; including +//! 
`Terminator`s of various `kind`s, and `Span` objects. Some functions used by or used on +//! real, runtime versions of these mocked-up objects have constraints (such as cross-thread +//! limitations) and deep dependencies on other elements of the full Rust compiler (which is +//! *not* constructed or mocked for these tests). +//! +//! Of particular note, attempting to simply print elements of the `mir::Body` with default +//! `Debug` formatting can fail because some `Debug` format implementations require the +//! `TyCtxt`, obtained via a static global variable that is *not* set for these tests. +//! Initializing the global type context is prohibitively complex for the scope and scale of these +//! tests (essentially requiring initializing the entire compiler). +//! +//! Also note, some basic features of `Span` also rely on the `Span`s own "session globals", which +//! are unrelated to the `TyCtxt` global. Without initializing the `Span` session globals, some +//! basic, coverage-specific features would be impossible to test, but thankfully initializing these +//! globals is comparatively simpler. The easiest way is to wrap the test in a closure argument +//! to: `rustc_span::create_default_session_globals_then(|| { test_here(); })`. + +use super::counters; +use super::debug; +use super::graph; +use super::spans; + +use coverage_test_macros::let_bcb; + +use itertools::Itertools; +use rustc_data_structures::graph::WithNumNodes; +use rustc_data_structures::graph::WithSuccessors; +use rustc_index::vec::{Idx, IndexVec}; +use rustc_middle::mir::coverage::CoverageKind; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_span::{self, BytePos, Pos, Span, DUMMY_SP}; + +// All `TEMP_BLOCK` targets should be replaced before calling `to_body() -> mir::Body`. +const TEMP_BLOCK: BasicBlock = BasicBlock::MAX; + +struct MockBlocks<'tcx> { + blocks: IndexVec<BasicBlock, BasicBlockData<'tcx>>, + dummy_place: Place<'tcx>, + next_local: usize, + bool_ty: Ty<'tcx>, +} + +impl<'tcx> MockBlocks<'tcx> { + fn new() -> Self { + Self { + blocks: IndexVec::new(), + dummy_place: Place { local: RETURN_PLACE, projection: ty::List::empty() }, + next_local: 0, + bool_ty: TyCtxt::BOOL_TY_FOR_UNIT_TESTING, + } + } + + fn new_temp(&mut self) -> Local { + let index = self.next_local; + self.next_local += 1; + Local::new(index) + } + + fn push(&mut self, kind: TerminatorKind<'tcx>) -> BasicBlock { + let next_lo = if let Some(last) = self.blocks.last() { + self.blocks[last].terminator().source_info.span.hi() + } else { + BytePos(1) + }; + let next_hi = next_lo + BytePos(1); + self.blocks.push(BasicBlockData { + statements: vec![], + terminator: Some(Terminator { + source_info: SourceInfo::outermost(Span::with_root_ctxt(next_lo, next_hi)), + kind, + }), + is_cleanup: false, + }) + } + + fn link(&mut self, from_block: BasicBlock, to_block: BasicBlock) { + match self.blocks[from_block].terminator_mut().kind { + TerminatorKind::Assert { ref mut target, .. } + | TerminatorKind::Call { target: Some(ref mut target), .. } + | TerminatorKind::Drop { ref mut target, .. } + | TerminatorKind::DropAndReplace { ref mut target, .. } + | TerminatorKind::FalseEdge { real_target: ref mut target, .. } + | TerminatorKind::FalseUnwind { real_target: ref mut target, .. } + | TerminatorKind::Goto { ref mut target } + | TerminatorKind::InlineAsm { destination: Some(ref mut target), .. } + | TerminatorKind::Yield { resume: ref mut target, .. 
} => *target = to_block, + ref invalid => bug!("Invalid from_block: {:?}", invalid), + } + } + + fn add_block_from( + &mut self, + some_from_block: Option<BasicBlock>, + to_kind: TerminatorKind<'tcx>, + ) -> BasicBlock { + let new_block = self.push(to_kind); + if let Some(from_block) = some_from_block { + self.link(from_block, new_block); + } + new_block + } + + fn set_branch(&mut self, switchint: BasicBlock, branch_index: usize, to_block: BasicBlock) { + match self.blocks[switchint].terminator_mut().kind { + TerminatorKind::SwitchInt { ref mut targets, .. } => { + let mut branches = targets.iter().collect::<Vec<_>>(); + let otherwise = if branch_index == branches.len() { + to_block + } else { + let old_otherwise = targets.otherwise(); + if branch_index > branches.len() { + branches.push((branches.len() as u128, old_otherwise)); + while branches.len() < branch_index { + branches.push((branches.len() as u128, TEMP_BLOCK)); + } + to_block + } else { + branches[branch_index] = (branch_index as u128, to_block); + old_otherwise + } + }; + *targets = SwitchTargets::new(branches.into_iter(), otherwise); + } + ref invalid => bug!("Invalid BasicBlock kind or no to_block: {:?}", invalid), + } + } + + fn call(&mut self, some_from_block: Option<BasicBlock>) -> BasicBlock { + self.add_block_from( + some_from_block, + TerminatorKind::Call { + func: Operand::Copy(self.dummy_place.clone()), + args: vec![], + destination: self.dummy_place.clone(), + target: Some(TEMP_BLOCK), + cleanup: None, + from_hir_call: false, + fn_span: DUMMY_SP, + }, + ) + } + + fn goto(&mut self, some_from_block: Option<BasicBlock>) -> BasicBlock { + self.add_block_from(some_from_block, TerminatorKind::Goto { target: TEMP_BLOCK }) + } + + fn switchint(&mut self, some_from_block: Option<BasicBlock>) -> BasicBlock { + let switchint_kind = TerminatorKind::SwitchInt { + discr: Operand::Move(Place::from(self.new_temp())), + switch_ty: self.bool_ty, // just a dummy value + targets: SwitchTargets::static_if(0, TEMP_BLOCK, TEMP_BLOCK), + }; + self.add_block_from(some_from_block, switchint_kind) + } + + fn return_(&mut self, some_from_block: Option<BasicBlock>) -> BasicBlock { + self.add_block_from(some_from_block, TerminatorKind::Return) + } + + fn to_body(self) -> Body<'tcx> { + Body::new_cfg_only(self.blocks) + } +} + +fn debug_basic_blocks<'tcx>(mir_body: &Body<'tcx>) -> String { + format!( + "{:?}", + mir_body + .basic_blocks() + .iter_enumerated() + .map(|(bb, data)| { + let term = &data.terminator(); + let kind = &term.kind; + let span = term.source_info.span; + let sp = format!("(span:{},{})", span.lo().to_u32(), span.hi().to_u32()); + match kind { + TerminatorKind::Assert { target, .. } + | TerminatorKind::Call { target: Some(target), .. } + | TerminatorKind::Drop { target, .. } + | TerminatorKind::DropAndReplace { target, .. } + | TerminatorKind::FalseEdge { real_target: target, .. } + | TerminatorKind::FalseUnwind { real_target: target, .. } + | TerminatorKind::Goto { target } + | TerminatorKind::InlineAsm { destination: Some(target), .. } + | TerminatorKind::Yield { resume: target, .. } => { + format!("{}{:?}:{} -> {:?}", sp, bb, debug::term_type(kind), target) + } + TerminatorKind::SwitchInt { targets, .. 
} => { + format!("{}{:?}:{} -> {:?}", sp, bb, debug::term_type(kind), targets) + } + _ => format!("{}{:?}:{}", sp, bb, debug::term_type(kind)), + } + }) + .collect::<Vec<_>>() + ) +} + +static PRINT_GRAPHS: bool = false; + +fn print_mir_graphviz(name: &str, mir_body: &Body<'_>) { + if PRINT_GRAPHS { + println!( + "digraph {} {{\n{}\n}}", + name, + mir_body + .basic_blocks() + .iter_enumerated() + .map(|(bb, data)| { + format!( + " {:?} [label=\"{:?}: {}\"];\n{}", + bb, + bb, + debug::term_type(&data.terminator().kind), + mir_body + .basic_blocks + .successors(bb) + .map(|successor| { format!(" {:?} -> {:?};", bb, successor) }) + .join("\n") + ) + }) + .join("\n") + ); + } +} + +fn print_coverage_graphviz( + name: &str, + mir_body: &Body<'_>, + basic_coverage_blocks: &graph::CoverageGraph, +) { + if PRINT_GRAPHS { + println!( + "digraph {} {{\n{}\n}}", + name, + basic_coverage_blocks + .iter_enumerated() + .map(|(bcb, bcb_data)| { + format!( + " {:?} [label=\"{:?}: {}\"];\n{}", + bcb, + bcb, + debug::term_type(&bcb_data.terminator(mir_body).kind), + basic_coverage_blocks + .successors(bcb) + .map(|successor| { format!(" {:?} -> {:?};", bcb, successor) }) + .join("\n") + ) + }) + .join("\n") + ); + } +} + +/// Create a mock `Body` with a simple flow. +fn goto_switchint<'a>() -> Body<'a> { + let mut blocks = MockBlocks::new(); + let start = blocks.call(None); + let goto = blocks.goto(Some(start)); + let switchint = blocks.switchint(Some(goto)); + let then_call = blocks.call(None); + let else_call = blocks.call(None); + blocks.set_branch(switchint, 0, then_call); + blocks.set_branch(switchint, 1, else_call); + blocks.return_(Some(then_call)); + blocks.return_(Some(else_call)); + + let mir_body = blocks.to_body(); + print_mir_graphviz("mir_goto_switchint", &mir_body); + /* Graphviz character plots created using: `graph-easy --as=boxart`: + ┌────────────────┐ + │ bb0: Call │ + └────────────────┘ + │ + │ + ▼ + ┌────────────────┐ + │ bb1: Goto │ + └────────────────┘ + │ + │ + ▼ + ┌─────────────┐ ┌────────────────┐ + │ bb4: Call │ ◀── │ bb2: SwitchInt │ + └─────────────┘ └────────────────┘ + │ │ + │ │ + ▼ ▼ + ┌─────────────┐ ┌────────────────┐ + │ bb6: Return │ │ bb3: Call │ + └─────────────┘ └────────────────┘ + │ + │ + ▼ + ┌────────────────┐ + │ bb5: Return │ + └────────────────┘ + */ + mir_body +} + +macro_rules! 
assert_successors { + ($basic_coverage_blocks:ident, $i:ident, [$($successor:ident),*]) => { + let mut successors = $basic_coverage_blocks.successors[$i].clone(); + successors.sort_unstable(); + assert_eq!(successors, vec![$($successor),*]); + } +} + +#[test] +fn test_covgraph_goto_switchint() { + let mir_body = goto_switchint(); + if false { + eprintln!("basic_blocks = {}", debug_basic_blocks(&mir_body)); + } + let basic_coverage_blocks = graph::CoverageGraph::from_mir(&mir_body); + print_coverage_graphviz("covgraph_goto_switchint ", &mir_body, &basic_coverage_blocks); + /* + ┌──────────────┐ ┌─────────────────┐ + │ bcb2: Return │ ◀── │ bcb0: SwitchInt │ + └──────────────┘ └─────────────────┘ + │ + │ + ▼ + ┌─────────────────┐ + │ bcb1: Return │ + └─────────────────┘ + */ + assert_eq!( + basic_coverage_blocks.num_nodes(), + 3, + "basic_coverage_blocks: {:?}", + basic_coverage_blocks.iter_enumerated().collect::<Vec<_>>() + ); + + let_bcb!(0); + let_bcb!(1); + let_bcb!(2); + + assert_successors!(basic_coverage_blocks, bcb0, [bcb1, bcb2]); + assert_successors!(basic_coverage_blocks, bcb1, []); + assert_successors!(basic_coverage_blocks, bcb2, []); +} + +/// Create a mock `Body` with a loop. +fn switchint_then_loop_else_return<'a>() -> Body<'a> { + let mut blocks = MockBlocks::new(); + let start = blocks.call(None); + let switchint = blocks.switchint(Some(start)); + let then_call = blocks.call(None); + blocks.set_branch(switchint, 0, then_call); + let backedge_goto = blocks.goto(Some(then_call)); + blocks.link(backedge_goto, switchint); + let else_return = blocks.return_(None); + blocks.set_branch(switchint, 1, else_return); + + let mir_body = blocks.to_body(); + print_mir_graphviz("mir_switchint_then_loop_else_return", &mir_body); + /* + ┌────────────────┐ + │ bb0: Call │ + └────────────────┘ + │ + │ + ▼ + ┌─────────────┐ ┌────────────────┐ + │ bb4: Return │ ◀── │ bb1: SwitchInt │ ◀┐ + └─────────────┘ └────────────────┘ │ + │ │ + │ │ + ▼ │ + ┌────────────────┐ │ + │ bb2: Call │ │ + └────────────────┘ │ + │ │ + │ │ + ▼ │ + ┌────────────────┐ │ + │ bb3: Goto │ ─┘ + └────────────────┘ + */ + mir_body +} + +#[test] +fn test_covgraph_switchint_then_loop_else_return() { + let mir_body = switchint_then_loop_else_return(); + let basic_coverage_blocks = graph::CoverageGraph::from_mir(&mir_body); + print_coverage_graphviz( + "covgraph_switchint_then_loop_else_return", + &mir_body, + &basic_coverage_blocks, + ); + /* + ┌─────────────────┐ + │ bcb0: Call │ + └─────────────────┘ + │ + │ + ▼ + ┌────────────┐ ┌─────────────────┐ + │ bcb3: Goto │ ◀── │ bcb1: SwitchInt │ ◀┐ + └────────────┘ └─────────────────┘ │ + │ │ │ + │ │ │ + │ ▼ │ + │ ┌─────────────────┐ │ + │ │ bcb2: Return │ │ + │ └─────────────────┘ │ + │ │ + └─────────────────────────────────────┘ + */ + assert_eq!( + basic_coverage_blocks.num_nodes(), + 4, + "basic_coverage_blocks: {:?}", + basic_coverage_blocks.iter_enumerated().collect::<Vec<_>>() + ); + + let_bcb!(0); + let_bcb!(1); + let_bcb!(2); + let_bcb!(3); + + assert_successors!(basic_coverage_blocks, bcb0, [bcb1]); + assert_successors!(basic_coverage_blocks, bcb1, [bcb2, bcb3]); + assert_successors!(basic_coverage_blocks, bcb2, []); + assert_successors!(basic_coverage_blocks, bcb3, [bcb1]); +} + +/// Create a mock `Body` with nested loops. 
+fn switchint_loop_then_inner_loop_else_break<'a>() -> Body<'a> { + let mut blocks = MockBlocks::new(); + let start = blocks.call(None); + let switchint = blocks.switchint(Some(start)); + let then_call = blocks.call(None); + blocks.set_branch(switchint, 0, then_call); + let else_return = blocks.return_(None); + blocks.set_branch(switchint, 1, else_return); + + let inner_start = blocks.call(Some(then_call)); + let inner_switchint = blocks.switchint(Some(inner_start)); + let inner_then_call = blocks.call(None); + blocks.set_branch(inner_switchint, 0, inner_then_call); + let inner_backedge_goto = blocks.goto(Some(inner_then_call)); + blocks.link(inner_backedge_goto, inner_switchint); + let inner_else_break_goto = blocks.goto(None); + blocks.set_branch(inner_switchint, 1, inner_else_break_goto); + + let backedge_goto = blocks.goto(Some(inner_else_break_goto)); + blocks.link(backedge_goto, switchint); + + let mir_body = blocks.to_body(); + print_mir_graphviz("mir_switchint_loop_then_inner_loop_else_break", &mir_body); + /* + ┌────────────────┐ + │ bb0: Call │ + └────────────────┘ + │ + │ + ▼ + ┌─────────────┐ ┌────────────────┐ + │ bb3: Return │ ◀── │ bb1: SwitchInt │ ◀─────┐ + └─────────────┘ └────────────────┘ │ + │ │ + │ │ + ▼ │ + ┌────────────────┐ │ + │ bb2: Call │ │ + └────────────────┘ │ + │ │ + │ │ + ▼ │ + ┌────────────────┐ │ + │ bb4: Call │ │ + └────────────────┘ │ + │ │ + │ │ + ▼ │ + ┌─────────────┐ ┌────────────────┐ │ + │ bb8: Goto │ ◀── │ bb5: SwitchInt │ ◀┐ │ + └─────────────┘ └────────────────┘ │ │ + │ │ │ │ + │ │ │ │ + ▼ ▼ │ │ + ┌─────────────┐ ┌────────────────┐ │ │ + │ bb9: Goto │ ─┐ │ bb6: Call │ │ │ + └─────────────┘ │ └────────────────┘ │ │ + │ │ │ │ + │ │ │ │ + │ ▼ │ │ + │ ┌────────────────┐ │ │ + │ │ bb7: Goto │ ─┘ │ + │ └────────────────┘ │ + │ │ + └───────────────────────────┘ + */ + mir_body +} + +#[test] +fn test_covgraph_switchint_loop_then_inner_loop_else_break() { + let mir_body = switchint_loop_then_inner_loop_else_break(); + let basic_coverage_blocks = graph::CoverageGraph::from_mir(&mir_body); + print_coverage_graphviz( + "covgraph_switchint_loop_then_inner_loop_else_break", + &mir_body, + &basic_coverage_blocks, + ); + /* + ┌─────────────────┐ + │ bcb0: Call │ + └─────────────────┘ + │ + │ + ▼ + ┌──────────────┐ ┌─────────────────┐ + │ bcb2: Return │ ◀── │ bcb1: SwitchInt │ ◀┐ + └──────────────┘ └─────────────────┘ │ + │ │ + │ │ + ▼ │ + ┌─────────────────┐ │ + │ bcb3: Call │ │ + └─────────────────┘ │ + │ │ + │ │ + ▼ │ + ┌──────────────┐ ┌─────────────────┐ │ + │ bcb6: Goto │ ◀── │ bcb4: SwitchInt │ ◀┼────┐ + └──────────────┘ └─────────────────┘ │ │ + │ │ │ │ + │ │ │ │ + │ ▼ │ │ + │ ┌─────────────────┐ │ │ + │ │ bcb5: Goto │ ─┘ │ + │ └─────────────────┘ │ + │ │ + └────────────────────────────────────────────┘ + */ + assert_eq!( + basic_coverage_blocks.num_nodes(), + 7, + "basic_coverage_blocks: {:?}", + basic_coverage_blocks.iter_enumerated().collect::<Vec<_>>() + ); + + let_bcb!(0); + let_bcb!(1); + let_bcb!(2); + let_bcb!(3); + let_bcb!(4); + let_bcb!(5); + let_bcb!(6); + + assert_successors!(basic_coverage_blocks, bcb0, [bcb1]); + assert_successors!(basic_coverage_blocks, bcb1, [bcb2, bcb3]); + assert_successors!(basic_coverage_blocks, bcb2, []); + assert_successors!(basic_coverage_blocks, bcb3, [bcb4]); + assert_successors!(basic_coverage_blocks, bcb4, [bcb5, bcb6]); + assert_successors!(basic_coverage_blocks, bcb5, [bcb1]); + assert_successors!(basic_coverage_blocks, bcb6, [bcb4]); +} + +#[test] +fn test_find_loop_backedges_none() { + let mir_body = 
goto_switchint(); + let basic_coverage_blocks = graph::CoverageGraph::from_mir(&mir_body); + if false { + eprintln!( + "basic_coverage_blocks = {:?}", + basic_coverage_blocks.iter_enumerated().collect::<Vec<_>>() + ); + eprintln!("successors = {:?}", basic_coverage_blocks.successors); + } + let backedges = graph::find_loop_backedges(&basic_coverage_blocks); + assert_eq!( + backedges.iter_enumerated().map(|(_bcb, backedges)| backedges.len()).sum::<usize>(), + 0, + "backedges: {:?}", + backedges + ); +} + +#[test] +fn test_find_loop_backedges_one() { + let mir_body = switchint_then_loop_else_return(); + let basic_coverage_blocks = graph::CoverageGraph::from_mir(&mir_body); + let backedges = graph::find_loop_backedges(&basic_coverage_blocks); + assert_eq!( + backedges.iter_enumerated().map(|(_bcb, backedges)| backedges.len()).sum::<usize>(), + 1, + "backedges: {:?}", + backedges + ); + + let_bcb!(1); + let_bcb!(3); + + assert_eq!(backedges[bcb1], vec![bcb3]); +} + +#[test] +fn test_find_loop_backedges_two() { + let mir_body = switchint_loop_then_inner_loop_else_break(); + let basic_coverage_blocks = graph::CoverageGraph::from_mir(&mir_body); + let backedges = graph::find_loop_backedges(&basic_coverage_blocks); + assert_eq!( + backedges.iter_enumerated().map(|(_bcb, backedges)| backedges.len()).sum::<usize>(), + 2, + "backedges: {:?}", + backedges + ); + + let_bcb!(1); + let_bcb!(4); + let_bcb!(5); + let_bcb!(6); + + assert_eq!(backedges[bcb1], vec![bcb5]); + assert_eq!(backedges[bcb4], vec![bcb6]); +} + +#[test] +fn test_traverse_coverage_with_loops() { + let mir_body = switchint_loop_then_inner_loop_else_break(); + let basic_coverage_blocks = graph::CoverageGraph::from_mir(&mir_body); + let mut traversed_in_order = Vec::new(); + let mut traversal = graph::TraverseCoverageGraphWithLoops::new(&basic_coverage_blocks); + while let Some(bcb) = traversal.next(&basic_coverage_blocks) { + traversed_in_order.push(bcb); + } + + let_bcb!(6); + + // bcb0 is visited first. Then bcb1 starts the first loop, and all remaining nodes, *except* + // bcb6 are inside the first loop. 
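+    // So `bcb6` must be the last node visited, which is what the assertion below checks.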
+ assert_eq!( + *traversed_in_order.last().expect("should have elements"), + bcb6, + "bcb6 should not be visited until all nodes inside the first loop have been visited" + ); +} + +fn synthesize_body_span_from_terminators(mir_body: &Body<'_>) -> Span { + let mut some_span: Option<Span> = None; + for (_, data) in mir_body.basic_blocks().iter_enumerated() { + let term_span = data.terminator().source_info.span; + if let Some(span) = some_span.as_mut() { + *span = span.to(term_span); + } else { + some_span = Some(term_span) + } + } + some_span.expect("body must have at least one BasicBlock") +} + +#[test] +fn test_make_bcb_counters() { + rustc_span::create_default_session_globals_then(|| { + let mir_body = goto_switchint(); + let body_span = synthesize_body_span_from_terminators(&mir_body); + let mut basic_coverage_blocks = graph::CoverageGraph::from_mir(&mir_body); + let mut coverage_spans = Vec::new(); + for (bcb, data) in basic_coverage_blocks.iter_enumerated() { + if let Some(span) = spans::filtered_terminator_span(data.terminator(&mir_body)) { + coverage_spans.push(spans::CoverageSpan::for_terminator( + spans::function_source_span(span, body_span), + span, + bcb, + data.last_bb(), + )); + } + } + let mut coverage_counters = counters::CoverageCounters::new(0); + let intermediate_expressions = coverage_counters + .make_bcb_counters(&mut basic_coverage_blocks, &coverage_spans) + .expect("should be Ok"); + assert_eq!(intermediate_expressions.len(), 0); + + let_bcb!(1); + assert_eq!( + 1, // coincidentally, bcb1 has a `Counter` with id = 1 + match basic_coverage_blocks[bcb1].counter().expect("should have a counter") { + CoverageKind::Counter { id, .. } => id, + _ => panic!("expected a Counter"), + } + .as_u32() + ); + + let_bcb!(2); + assert_eq!( + 2, // coincidentally, bcb2 has a `Counter` with id = 2 + match basic_coverage_blocks[bcb2].counter().expect("should have a counter") { + CoverageKind::Counter { id, .. } => id, + _ => panic!("expected a Counter"), + } + .as_u32() + ); + }); +} diff --git a/compiler/rustc_mir_transform/src/dead_store_elimination.rs b/compiler/rustc_mir_transform/src/dead_store_elimination.rs new file mode 100644 index 000000000..9163672f5 --- /dev/null +++ b/compiler/rustc_mir_transform/src/dead_store_elimination.rs @@ -0,0 +1,86 @@ +//! This module implements a dead store elimination (DSE) routine. +//! +//! This transformation was written specifically for the needs of dest prop. Although it is +//! perfectly sound to use it in any context that might need it, its behavior should not be changed +//! without analyzing the interaction this will have with dest prop. Specifically, in addition to +//! the soundness of this pass in general, dest prop needs it to satisfy two additional conditions: +//! +//! 1. It's idempotent, meaning that running this pass a second time immediately after running it a +//! first time will not cause any further changes. +//! 2. This idempotence persists across dest prop's main transform, in other words inserting any +//! number of iterations of dest prop between the first and second application of this transform +//! will still not cause any further changes. +//! + +use rustc_index::bit_set::BitSet; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use rustc_mir_dataflow::impls::{borrowed_locals, MaybeTransitiveLiveLocals}; +use rustc_mir_dataflow::Analysis; + +/// Performs the optimization on the body +/// +/// The `borrowed` set must be a `BitSet` of all the locals that are ever borrowed in this body. 
It +/// can be generated via the [`borrowed_locals`] function. +pub fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, borrowed: &BitSet<Local>) { + let mut live = MaybeTransitiveLiveLocals::new(borrowed) + .into_engine(tcx, body) + .iterate_to_fixpoint() + .into_results_cursor(body); + + let mut patch = Vec::new(); + for (bb, bb_data) in traversal::preorder(body) { + for (statement_index, statement) in bb_data.statements.iter().enumerate().rev() { + let loc = Location { block: bb, statement_index }; + if let StatementKind::Assign(assign) = &statement.kind { + if !assign.1.is_safe_to_remove() { + continue; + } + } + match &statement.kind { + StatementKind::Assign(box (place, _)) + | StatementKind::SetDiscriminant { place: box place, .. } + | StatementKind::Deinit(box place) => { + if !place.is_indirect() && !borrowed.contains(place.local) { + live.seek_before_primary_effect(loc); + if !live.get().contains(place.local) { + patch.push(loc); + } + } + } + StatementKind::Retag(_, _) + | StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Coverage(_) + | StatementKind::CopyNonOverlapping(_) + | StatementKind::Nop => (), + + StatementKind::FakeRead(_) | StatementKind::AscribeUserType(_, _) => { + bug!("{:?} not found in this MIR phase!", &statement.kind) + } + } + } + } + + if patch.is_empty() { + return; + } + + let bbs = body.basic_blocks.as_mut_preserves_cfg(); + for Location { block, statement_index } in patch { + bbs[block].statements[statement_index].make_nop(); + } +} + +pub struct DeadStoreElimination; + +impl<'tcx> MirPass<'tcx> for DeadStoreElimination { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 2 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let borrowed = borrowed_locals(body); + eliminate(tcx, body, &borrowed); + } +} diff --git a/compiler/rustc_mir_transform/src/deaggregator.rs b/compiler/rustc_mir_transform/src/deaggregator.rs new file mode 100644 index 000000000..b93fe5879 --- /dev/null +++ b/compiler/rustc_mir_transform/src/deaggregator.rs @@ -0,0 +1,49 @@ +use crate::util::expand_aggregate; +use crate::MirPass; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +pub struct Deaggregator; + +impl<'tcx> MirPass<'tcx> for Deaggregator { + fn phase_change(&self) -> Option<MirPhase> { + Some(MirPhase::Deaggregated) + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let basic_blocks = body.basic_blocks.as_mut_preserves_cfg(); + for bb in basic_blocks { + bb.expand_statements(|stmt| { + // FIXME(eddyb) don't match twice on `stmt.kind` (post-NLL). + match stmt.kind { + // FIXME(#48193) Deaggregate arrays when it's cheaper to do so. 
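+                    // (Expanding an `[T; N]` aggregate would emit one assignment per element,
+                    // which can bloat MIR for large arrays; hence arrays are left as aggregates.)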
+ StatementKind::Assign(box ( + _, + Rvalue::Aggregate(box AggregateKind::Array(_), _), + )) => { + return None; + } + StatementKind::Assign(box (_, Rvalue::Aggregate(_, _))) => {} + _ => return None, + } + + let stmt = stmt.replace_nop(); + let source_info = stmt.source_info; + let StatementKind::Assign(box (lhs, Rvalue::Aggregate(kind, operands))) = stmt.kind else { + bug!(); + }; + + Some(expand_aggregate( + lhs, + operands.into_iter().map(|op| { + let ty = op.ty(&body.local_decls, tcx); + (op, ty) + }), + *kind, + source_info, + tcx, + )) + }); + } + } +} diff --git a/compiler/rustc_mir_transform/src/deduplicate_blocks.rs b/compiler/rustc_mir_transform/src/deduplicate_blocks.rs new file mode 100644 index 000000000..d1977ed49 --- /dev/null +++ b/compiler/rustc_mir_transform/src/deduplicate_blocks.rs @@ -0,0 +1,190 @@ +//! This pass finds basic blocks that are completely equal, +//! and replaces all uses with just one of them. + +use std::{collections::hash_map::Entry, hash::Hash, hash::Hasher, iter}; + +use crate::MirPass; + +use rustc_data_structures::fx::FxHashMap; +use rustc_middle::mir::visit::MutVisitor; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +use super::simplify::simplify_cfg; + +pub struct DeduplicateBlocks; + +impl<'tcx> MirPass<'tcx> for DeduplicateBlocks { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 4 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + debug!("Running DeduplicateBlocks on `{:?}`", body.source); + let duplicates = find_duplicates(body); + let has_opts_to_apply = !duplicates.is_empty(); + + if has_opts_to_apply { + let mut opt_applier = OptApplier { tcx, duplicates }; + opt_applier.visit_body(body); + simplify_cfg(tcx, body); + } + } +} + +struct OptApplier<'tcx> { + tcx: TyCtxt<'tcx>, + duplicates: FxHashMap<BasicBlock, BasicBlock>, +} + +impl<'tcx> MutVisitor<'tcx> for OptApplier<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) { + for target in terminator.successors_mut() { + if let Some(replacement) = self.duplicates.get(target) { + debug!("SUCCESS: Replacing: `{:?}` with `{:?}`", target, replacement); + *target = *replacement; + } + } + + self.super_terminator(terminator, location); + } +} + +fn find_duplicates(body: &Body<'_>) -> FxHashMap<BasicBlock, BasicBlock> { + let mut duplicates = FxHashMap::default(); + + let bbs_to_go_through = + body.basic_blocks().iter_enumerated().filter(|(_, bbd)| !bbd.is_cleanup).count(); + + let mut same_hashes = + FxHashMap::with_capacity_and_hasher(bbs_to_go_through, Default::default()); + + // Go through the basic blocks backwards. This means that in case of duplicates, + // we can use the basic block with the highest index as the replacement for all lower ones. + // For example, if bb1, bb2 and bb3 are duplicates, we will first insert bb3 in same_hashes. + // Then we will see that bb2 is a duplicate of bb3, + // and insert bb2 with the replacement bb3 in the duplicates list. + // When we see bb1, we see that it is a duplicate of bb3, and therefore insert it in the duplicates list + // with replacement bb3. + // When the duplicates are removed, we will end up with only bb3. + for (bb, bbd) in body.basic_blocks().iter_enumerated().rev().filter(|(_, bbd)| !bbd.is_cleanup) + { + // Basic blocks can get really big, so to avoid checking for duplicates in basic blocks + // that are unlikely to have duplicates, we stop early. 
The early bail number has been + // found experimentally by eprintln while compiling the crates in the rustc-perf suite. + if bbd.statements.len() > 10 { + continue; + } + + let to_hash = BasicBlockHashable { basic_block_data: bbd }; + let entry = same_hashes.entry(to_hash); + match entry { + Entry::Occupied(occupied) => { + // The basic block was already in the hashmap, which means we have a duplicate + let value = *occupied.get(); + debug!("Inserting {:?} -> {:?}", bb, value); + duplicates.try_insert(bb, value).expect("key was already inserted"); + } + Entry::Vacant(vacant) => { + vacant.insert(bb); + } + } + } + + duplicates +} + +struct BasicBlockHashable<'tcx, 'a> { + basic_block_data: &'a BasicBlockData<'tcx>, +} + +impl Hash for BasicBlockHashable<'_, '_> { + fn hash<H: Hasher>(&self, state: &mut H) { + hash_statements(state, self.basic_block_data.statements.iter()); + // Note that since we only hash the kind, we lose span information if we deduplicate the blocks + self.basic_block_data.terminator().kind.hash(state); + } +} + +impl Eq for BasicBlockHashable<'_, '_> {} + +impl PartialEq for BasicBlockHashable<'_, '_> { + fn eq(&self, other: &Self) -> bool { + self.basic_block_data.statements.len() == other.basic_block_data.statements.len() + && &self.basic_block_data.terminator().kind == &other.basic_block_data.terminator().kind + && iter::zip(&self.basic_block_data.statements, &other.basic_block_data.statements) + .all(|(x, y)| statement_eq(&x.kind, &y.kind)) + } +} + +fn hash_statements<'a, 'tcx, H: Hasher>( + hasher: &mut H, + iter: impl Iterator<Item = &'a Statement<'tcx>>, +) where + 'tcx: 'a, +{ + for stmt in iter { + statement_hash(hasher, &stmt.kind); + } +} + +fn statement_hash<H: Hasher>(hasher: &mut H, stmt: &StatementKind<'_>) { + match stmt { + StatementKind::Assign(box (place, rvalue)) => { + place.hash(hasher); + rvalue_hash(hasher, rvalue) + } + x => x.hash(hasher), + }; +} + +fn rvalue_hash<H: Hasher>(hasher: &mut H, rvalue: &Rvalue<'_>) { + match rvalue { + Rvalue::Use(op) => operand_hash(hasher, op), + x => x.hash(hasher), + }; +} + +fn operand_hash<H: Hasher>(hasher: &mut H, operand: &Operand<'_>) { + match operand { + Operand::Constant(box Constant { user_ty: _, literal, span: _ }) => literal.hash(hasher), + x => x.hash(hasher), + }; +} + +fn statement_eq<'tcx>(lhs: &StatementKind<'tcx>, rhs: &StatementKind<'tcx>) -> bool { + let res = match (lhs, rhs) { + ( + StatementKind::Assign(box (place, rvalue)), + StatementKind::Assign(box (place2, rvalue2)), + ) => place == place2 && rvalue_eq(rvalue, rvalue2), + (x, y) => x == y, + }; + debug!("statement_eq lhs: `{:?}` rhs: `{:?}` result: {:?}", lhs, rhs, res); + res +} + +fn rvalue_eq<'tcx>(lhs: &Rvalue<'tcx>, rhs: &Rvalue<'tcx>) -> bool { + let res = match (lhs, rhs) { + (Rvalue::Use(op1), Rvalue::Use(op2)) => operand_eq(op1, op2), + (x, y) => x == y, + }; + debug!("rvalue_eq lhs: `{:?}` rhs: `{:?}` result: {:?}", lhs, rhs, res); + res +} + +fn operand_eq<'tcx>(lhs: &Operand<'tcx>, rhs: &Operand<'tcx>) -> bool { + let res = match (lhs, rhs) { + ( + Operand::Constant(box Constant { user_ty: _, literal, span: _ }), + Operand::Constant(box Constant { user_ty: _, literal: literal2, span: _ }), + ) => literal == literal2, + (x, y) => x == y, + }; + debug!("operand_eq lhs: `{:?}` rhs: `{:?}` result: {:?}", lhs, rhs, res); + res +} diff --git a/compiler/rustc_mir_transform/src/deref_separator.rs b/compiler/rustc_mir_transform/src/deref_separator.rs new file mode 100644 index 000000000..a00bb16f7 --- /dev/null +++ 
b/compiler/rustc_mir_transform/src/deref_separator.rs @@ -0,0 +1,105 @@ +use crate::MirPass; +use rustc_index::vec::IndexVec; +use rustc_middle::mir::patch::MirPatch; +use rustc_middle::mir::visit::NonUseContext::VarDebugInfo; +use rustc_middle::mir::visit::{MutVisitor, PlaceContext}; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +pub struct Derefer; + +pub struct DerefChecker<'tcx> { + tcx: TyCtxt<'tcx>, + patcher: MirPatch<'tcx>, + local_decls: IndexVec<Local, LocalDecl<'tcx>>, +} + +impl<'tcx> MutVisitor<'tcx> for DerefChecker<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_place(&mut self, place: &mut Place<'tcx>, cntxt: PlaceContext, loc: Location) { + if !place.projection.is_empty() + && cntxt != PlaceContext::NonUse(VarDebugInfo) + && place.projection[1..].contains(&ProjectionElem::Deref) + { + let mut place_local = place.local; + let mut last_len = 0; + let mut last_deref_idx = 0; + + let mut prev_temp: Option<Local> = None; + + for (idx, elem) in place.projection[0..].iter().enumerate() { + if *elem == ProjectionElem::Deref { + last_deref_idx = idx; + } + } + + for (idx, (p_ref, p_elem)) in place.iter_projections().enumerate() { + if !p_ref.projection.is_empty() && p_elem == ProjectionElem::Deref { + let ty = p_ref.ty(&self.local_decls, self.tcx).ty; + let temp = self.patcher.new_local_with_info( + ty, + self.local_decls[p_ref.local].source_info.span, + Some(Box::new(LocalInfo::DerefTemp)), + ); + + self.patcher.add_statement(loc, StatementKind::StorageLive(temp)); + + // We are adding current p_ref's projections to our + // temp value, excluding projections we already covered. + let deref_place = Place::from(place_local) + .project_deeper(&p_ref.projection[last_len..], self.tcx); + + self.patcher.add_assign( + loc, + Place::from(temp), + Rvalue::CopyForDeref(deref_place), + ); + place_local = temp; + last_len = p_ref.projection.len(); + + // Change `Place` only if we are actually at the Place's last deref + if idx == last_deref_idx { + let temp_place = + Place::from(temp).project_deeper(&place.projection[idx..], self.tcx); + *place = temp_place; + } + + // We are destroying the previous temp since it's no longer used. + if let Some(prev_temp) = prev_temp { + self.patcher.add_statement(loc, StatementKind::StorageDead(prev_temp)); + } + + prev_temp = Some(temp); + } + } + + // Since we won't be able to reach final temp, we destroy it outside the loop. + if let Some(prev_temp) = prev_temp { + let last_loc = + Location { block: loc.block, statement_index: loc.statement_index + 1 }; + self.patcher.add_statement(last_loc, StatementKind::StorageDead(prev_temp)); + } + } + } +} + +pub fn deref_finder<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let patch = MirPatch::new(body); + let mut checker = DerefChecker { tcx, patcher: patch, local_decls: body.local_decls.clone() }; + + for (bb, data) in body.basic_blocks_mut().iter_enumerated_mut() { + checker.visit_basic_block_data(bb, data); + } + + checker.patcher.apply(body); +} + +impl<'tcx> MirPass<'tcx> for Derefer { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + deref_finder(tcx, body); + body.phase = MirPhase::Derefered; + } +} diff --git a/compiler/rustc_mir_transform/src/dest_prop.rs b/compiler/rustc_mir_transform/src/dest_prop.rs new file mode 100644 index 000000000..33572068f --- /dev/null +++ b/compiler/rustc_mir_transform/src/dest_prop.rs @@ -0,0 +1,917 @@ +//! Propagates assignment destinations backwards in the CFG to eliminate redundant assignments. +//! +//! 
# Motivation +//! +//! MIR building can insert a lot of redundant copies, and Rust code in general often tends to move +//! values around a lot. The result is a lot of assignments of the form `dest = {move} src;` in MIR. +//! MIR building for constants in particular tends to create additional locals that are only used +//! inside a single block to shuffle a value around unnecessarily. +//! +//! LLVM by itself is not good enough at eliminating these redundant copies (eg. see +//! <https://github.com/rust-lang/rust/issues/32966>), so this leaves some performance on the table +//! that we can regain by implementing an optimization for removing these assign statements in rustc +//! itself. When this optimization runs fast enough, it can also speed up the constant evaluation +//! and code generation phases of rustc due to the reduced number of statements and locals. +//! +//! # The Optimization +//! +//! Conceptually, this optimization is "destination propagation". It is similar to the Named Return +//! Value Optimization, or NRVO, known from the C++ world, except that it isn't limited to return +//! values or the return place `_0`. On a very high level, independent of the actual implementation +//! details, it does the following: +//! +//! 1) Identify `dest = src;` statements that can be soundly eliminated. +//! 2) Replace all mentions of `src` with `dest` ("unifying" them and propagating the destination +//! backwards). +//! 3) Delete the `dest = src;` statement (by making it a `nop`). +//! +//! Step 1) is by far the hardest, so it is explained in more detail below. +//! +//! ## Soundness +//! +//! Given an `Assign` statement `dest = src;`, where `dest` is a `Place` and `src` is an `Rvalue`, +//! there are a few requirements that must hold for the optimization to be sound: +//! +//! * `dest` must not contain any *indirection* through a pointer. It must access part of the base +//! local. Otherwise it might point to arbitrary memory that is hard to track. +//! +//! It must also not contain any indexing projections, since those take an arbitrary `Local` as +//! the index, and that local might only be initialized shortly before `dest` is used. +//! +//! * `src` must be a bare `Local` without any indirections or field projections (FIXME: Is this a +//! fundamental restriction or just current impl state?). It can be copied or moved by the +//! assignment. +//! +//! * The `dest` and `src` locals must never be [*live*][liveness] at the same time. If they are, it +//! means that they both hold a (potentially different) value that is needed by a future use of +//! the locals. Unifying them would overwrite one of the values. +//! +//! Note that computing liveness of locals that have had their address taken is more difficult: +//! Short of doing full escape analysis on the address/pointer/reference, the pass would need to +//! assume that any operation that can potentially involve opaque user code (such as function +//! calls, destructors, and inline assembly) may access any local that had its address taken +//! before that point. +//! +//! Here, the first two conditions are simple structural requirements on the `Assign` statements +//! that can be trivially checked. The liveness requirement however is more difficult and costly to +//! check. +//! +//! ## Previous Work +//! +//! A [previous attempt] at implementing an optimization like this turned out to be a significant +//! regression in compiler performance. Fixing the regressions introduced a lot of undesirable +//! complexity to the implementation. 
+//! +//! A [subsequent approach] tried to avoid the costly computation by limiting itself to acyclic +//! CFGs, but still turned out to be far too costly to run due to suboptimal performance within +//! individual basic blocks, requiring a walk across the entire block for every assignment found +//! within the block. For the `tuple-stress` benchmark, which has 458745 statements in a single +//! block, this proved to be far too costly. +//! +//! Since the first attempt at this, the compiler has improved dramatically, and new analysis +//! frameworks have been added that should make this approach viable without requiring a limited +//! approach that only works for some classes of CFGs: +//! - rustc now has a powerful dataflow analysis framework that can handle forwards and backwards +//! analyses efficiently. +//! - Layout optimizations for generators have been added to improve code generation for +//! async/await, which are very similar in spirit to what this optimization does. Both walk the +//! MIR and record conflicting uses of locals in a `BitMatrix`. +//! +//! Also, rustc now has a simple NRVO pass (see `nrvo.rs`), which handles a subset of the cases that +//! this destination propagation pass handles, proving that similar optimizations can be performed +//! on MIR. +//! +//! ## Pre/Post Optimization +//! +//! It is recommended to run `SimplifyCfg` and then `SimplifyLocals` some time after this pass, as +//! it replaces the eliminated assign statements with `nop`s and leaves unused locals behind. +//! +//! [liveness]: https://en.wikipedia.org/wiki/Live_variable_analysis +//! [previous attempt]: https://github.com/rust-lang/rust/pull/47954 +//! [subsequent approach]: https://github.com/rust-lang/rust/pull/71003 + +use crate::MirPass; +use itertools::Itertools; +use rustc_data_structures::unify::{InPlaceUnificationTable, UnifyKey}; +use rustc_index::{ + bit_set::{BitMatrix, BitSet}, + vec::IndexVec, +}; +use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; +use rustc_middle::mir::{dump_mir, PassWhere}; +use rustc_middle::mir::{ + traversal, Body, InlineAsmOperand, Local, LocalKind, Location, Operand, Place, PlaceElem, + Rvalue, Statement, StatementKind, Terminator, TerminatorKind, +}; +use rustc_middle::ty::TyCtxt; +use rustc_mir_dataflow::impls::{borrowed_locals, MaybeInitializedLocals, MaybeLiveLocals}; +use rustc_mir_dataflow::Analysis; + +// Empirical measurements have resulted in some observations: +// - Running on a body with a single block and 500 locals takes barely any time +// - Running on a body with ~400 blocks and ~300 relevant locals takes "too long" +// ...so we just limit both to somewhat reasonable-ish looking values. +const MAX_LOCALS: usize = 500; +const MAX_BLOCKS: usize = 250; + +pub struct DestinationPropagation; + +impl<'tcx> MirPass<'tcx> for DestinationPropagation { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + // FIXME(#79191, #82678): This is unsound. + // + // Only run at mir-opt-level=3 or higher for now (we don't fix up debuginfo and remove + // storage statements at the moment). + sess.opts.unstable_opts.unsound_mir_opts && sess.mir_opt_level() >= 3 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let def_id = body.source.def_id(); + + let candidates = find_candidates(body); + if candidates.is_empty() { + debug!("{:?}: no dest prop candidates, done", def_id); + return; + } + + // Collect all locals we care about. We only compute conflicts for these to save time. 
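+        // For example (an illustrative sketch, not real output): if the only candidates are
+        // `_2 = _1` and `_5 = move _4`, then `relevant_locals` is `{_1, _2, _4, _5}` and the
+        // conflict matrix is only ever filled in for those four locals.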
+        let mut relevant_locals = BitSet::new_empty(body.local_decls.len());
+        for CandidateAssignment { dest, src, loc: _ } in &candidates {
+            relevant_locals.insert(dest.local);
+            relevant_locals.insert(*src);
+        }
+
+        // This pass unfortunately has `O(l² * s)` performance, where `l` is the number of locals
+        // and `s` is the number of statements and terminators in the function.
+        // To prevent blowing up compile times too much, we bail out when there are too many locals.
+        let relevant = relevant_locals.count();
+        debug!(
+            "{:?}: {} locals ({} relevant), {} blocks",
+            def_id,
+            body.local_decls.len(),
+            relevant,
+            body.basic_blocks().len()
+        );
+        if relevant > MAX_LOCALS {
+            warn!(
+                "too many candidate locals in {:?} ({}, max is {}), not optimizing",
+                def_id, relevant, MAX_LOCALS
+            );
+            return;
+        }
+        if body.basic_blocks().len() > MAX_BLOCKS {
+            warn!(
+                "too many blocks in {:?} ({}, max is {}), not optimizing",
+                def_id,
+                body.basic_blocks().len(),
+                MAX_BLOCKS
+            );
+            return;
+        }
+
+        let mut conflicts = Conflicts::build(tcx, body, &relevant_locals);
+
+        let mut replacements = Replacements::new(body.local_decls.len());
+        for candidate @ CandidateAssignment { dest, src, loc } in candidates {
+            // Merge locals that don't conflict.
+            if !conflicts.can_unify(dest.local, src) {
+                debug!("at assignment {:?}, conflict {:?} vs. {:?}", loc, dest.local, src);
+                continue;
+            }
+
+            if replacements.for_src(candidate.src).is_some() {
+                debug!("src {:?} already has replacement", candidate.src);
+                continue;
+            }
+
+            if !tcx.consider_optimizing(|| {
+                format!("DestinationPropagation {:?} {:?}", def_id, candidate)
+            }) {
+                break;
+            }
+
+            replacements.push(candidate);
+            conflicts.unify(candidate.src, candidate.dest.local);
+        }
+
+        replacements.flatten(tcx);
+
+        debug!("replacements {:?}", replacements.map);
+
+        Replacer { tcx, replacements, place_elem_cache: Vec::new() }.visit_body(body);
+
+        // FIXME fix debug info
+    }
+}
+
+#[derive(Debug, Eq, PartialEq, Copy, Clone)]
+struct UnifyLocal(Local);
+
+impl From<Local> for UnifyLocal {
+    fn from(l: Local) -> Self {
+        Self(l)
+    }
+}
+
+impl UnifyKey for UnifyLocal {
+    type Value = ();
+    #[inline]
+    fn index(&self) -> u32 {
+        self.0.as_u32()
+    }
+    #[inline]
+    fn from_index(u: u32) -> Self {
+        Self(Local::from_u32(u))
+    }
+    fn tag() -> &'static str {
+        "UnifyLocal"
+    }
+}
+
+struct Replacements<'tcx> {
+    /// Maps locals to their replacement.
+    map: IndexVec<Local, Option<Place<'tcx>>>,
+
+    /// Locals whose `StorageLive`/`StorageDead` markers must be removed, because their live
+    /// ranges were merged with another local's.
+    kill: BitSet<Local>,
+}
+
+impl<'tcx> Replacements<'tcx> {
+    fn new(locals: usize) -> Self {
+        Self { map: IndexVec::from_elem_n(None, locals), kill: BitSet::new_empty(locals) }
+    }
+
+    fn push(&mut self, candidate: CandidateAssignment<'tcx>) {
+        trace!("Replacements::push({:?})", candidate);
+        let entry = &mut self.map[candidate.src];
+        assert!(entry.is_none());
+
+        *entry = Some(candidate.dest);
+        self.kill.insert(candidate.src);
+        self.kill.insert(candidate.dest.local);
+    }
+
+    /// Flattens chains of replacements, so that no stored replacement maps to a local that
+    /// itself still has a replacement.
+    fn flatten(&mut self, tcx: TyCtxt<'tcx>) {
+        // Note: This assumes that there are no cycles in the replacements, which is enforced via
+        // `self.unified_locals`. Otherwise this can cause an infinite loop.
+
+        for local in self.map.indices() {
+            if let Some(replacement) = self.map[local] {
+                // Substitute the base local of `replacement` until fixpoint.
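+                // For example (an illustrative sketch): with `map = {_1 -> _2.0, _2 -> _3}`,
+                // flattening `_1` walks the chain `_1 -> _2.0 -> _3.0`, re-basing the `.0`
+                // projection onto `_3`, so the final entry is `_1 -> _3.0`.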
+ let mut base = replacement.local; + let mut reversed_projection_slices = Vec::with_capacity(1); + while let Some(replacement_for_replacement) = self.map[base] { + base = replacement_for_replacement.local; + reversed_projection_slices.push(replacement_for_replacement.projection); + } + + let projection: Vec<_> = reversed_projection_slices + .iter() + .rev() + .flat_map(|projs| projs.iter()) + .chain(replacement.projection.iter()) + .collect(); + let projection = tcx.intern_place_elems(&projection); + + // Replace with the final `Place`. + self.map[local] = Some(Place { local: base, projection }); + } + } + } + + fn for_src(&self, src: Local) -> Option<Place<'tcx>> { + self.map[src] + } +} + +struct Replacer<'tcx> { + tcx: TyCtxt<'tcx>, + replacements: Replacements<'tcx>, + place_elem_cache: Vec<PlaceElem<'tcx>>, +} + +impl<'tcx> MutVisitor<'tcx> for Replacer<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_local(&mut self, local: &mut Local, context: PlaceContext, location: Location) { + if context.is_use() && self.replacements.for_src(*local).is_some() { + bug!( + "use of local {:?} should have been replaced by visit_place; context={:?}, loc={:?}", + local, + context, + location, + ); + } + } + + fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { + if let Some(replacement) = self.replacements.for_src(place.local) { + // Rebase `place`s projections onto `replacement`'s. + self.place_elem_cache.clear(); + self.place_elem_cache.extend(replacement.projection.iter().chain(place.projection)); + let projection = self.tcx.intern_place_elems(&self.place_elem_cache); + let new_place = Place { local: replacement.local, projection }; + + debug!("Replacer: {:?} -> {:?}", place, new_place); + *place = new_place; + } + + self.super_place(place, context, location); + } + + fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { + self.super_statement(statement, location); + + match &statement.kind { + // FIXME: Don't delete storage statements, merge the live ranges instead + StatementKind::StorageDead(local) | StatementKind::StorageLive(local) + if self.replacements.kill.contains(*local) => + { + statement.make_nop() + } + + StatementKind::Assign(box (dest, rvalue)) => { + match rvalue { + Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) => { + // These might've been turned into self-assignments by the replacement + // (this includes the original statement we wanted to eliminate). + if dest == place { + debug!("{:?} turned into self-assignment, deleting", location); + statement.make_nop(); + } + } + _ => {} + } + } + + _ => {} + } + } +} + +struct Conflicts<'a> { + relevant_locals: &'a BitSet<Local>, + + /// The conflict matrix. It is always symmetric and the adjacency matrix of the corresponding + /// conflict graph. + matrix: BitMatrix<Local, Local>, + + /// Preallocated `BitSet` used by `unify`. + unify_cache: BitSet<Local>, + + /// Tracks locals that have been merged together to prevent cycles and propagate conflicts. + unified_locals: InPlaceUnificationTable<UnifyLocal>, +} + +impl<'a> Conflicts<'a> { + fn build<'tcx>( + tcx: TyCtxt<'tcx>, + body: &'_ Body<'tcx>, + relevant_locals: &'a BitSet<Local>, + ) -> Self { + // We don't have to look out for locals that have their address taken, since + // `find_candidates` already takes care of that. 
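+        // The matrix is kept symmetric throughout: `record_local_conflict` always sets both
+        // `(a, b)` and `(b, a)`, so each row can be read as the set of locals that must never
+        // be merged with that row's local.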
+ + let conflicts = BitMatrix::from_row_n( + &BitSet::new_empty(body.local_decls.len()), + body.local_decls.len(), + ); + + let mut init = MaybeInitializedLocals + .into_engine(tcx, body) + .iterate_to_fixpoint() + .into_results_cursor(body); + let mut live = + MaybeLiveLocals.into_engine(tcx, body).iterate_to_fixpoint().into_results_cursor(body); + + let mut reachable = None; + dump_mir(tcx, None, "DestinationPropagation-dataflow", &"", body, |pass_where, w| { + let reachable = reachable.get_or_insert_with(|| traversal::reachable_as_bitset(body)); + + match pass_where { + PassWhere::BeforeLocation(loc) if reachable.contains(loc.block) => { + init.seek_before_primary_effect(loc); + live.seek_after_primary_effect(loc); + + writeln!(w, " // init: {:?}", init.get())?; + writeln!(w, " // live: {:?}", live.get())?; + } + PassWhere::AfterTerminator(bb) if reachable.contains(bb) => { + let loc = body.terminator_loc(bb); + init.seek_after_primary_effect(loc); + live.seek_before_primary_effect(loc); + + writeln!(w, " // init: {:?}", init.get())?; + writeln!(w, " // live: {:?}", live.get())?; + } + + PassWhere::BeforeBlock(bb) if reachable.contains(bb) => { + init.seek_to_block_start(bb); + live.seek_to_block_start(bb); + + writeln!(w, " // init: {:?}", init.get())?; + writeln!(w, " // live: {:?}", live.get())?; + } + + PassWhere::BeforeCFG | PassWhere::AfterCFG | PassWhere::AfterLocation(_) => {} + + PassWhere::BeforeLocation(_) | PassWhere::AfterTerminator(_) => { + writeln!(w, " // init: <unreachable>")?; + writeln!(w, " // live: <unreachable>")?; + } + + PassWhere::BeforeBlock(_) => { + writeln!(w, " // init: <unreachable>")?; + writeln!(w, " // live: <unreachable>")?; + } + } + + Ok(()) + }); + + let mut this = Self { + relevant_locals, + matrix: conflicts, + unify_cache: BitSet::new_empty(body.local_decls.len()), + unified_locals: { + let mut table = InPlaceUnificationTable::new(); + // Pre-fill table with all locals (this creates N nodes / "connected" components, + // "graph"-ically speaking). + for local in 0..body.local_decls.len() { + assert_eq!(table.new_key(()), UnifyLocal(Local::from_usize(local))); + } + table + }, + }; + + let mut live_and_init_locals = Vec::new(); + + // Visit only reachable basic blocks. The exact order is not important. + for (block, data) in traversal::preorder(body) { + // We need to observe the dataflow state *before* all possible locations (statement or + // terminator) in each basic block, and then observe the state *after* the terminator + // effect is applied. As long as neither `init` nor `borrowed` has a "before" effect, + // we will observe all possible dataflow states. + + // Since liveness is a backwards analysis, we need to walk the results backwards. To do + // that, we first collect in the `MaybeInitializedLocals` results in a forwards + // traversal. + + live_and_init_locals.resize_with(data.statements.len() + 1, || { + BitSet::new_empty(body.local_decls.len()) + }); + + // First, go forwards for `MaybeInitializedLocals` and apply intra-statement/terminator + // conflicts. 
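+            // (The backwards walk below intersects these init sets with `MaybeLiveLocals`:
+            // two relevant locals conflict at a point only if both are maybe-initialized
+            // *and* maybe-live there, since only then could merging them clobber a value
+            // that is still needed.)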
+ for (i, statement) in data.statements.iter().enumerate() { + this.record_statement_conflicts(statement); + + let loc = Location { block, statement_index: i }; + init.seek_before_primary_effect(loc); + + live_and_init_locals[i].clone_from(init.get()); + } + + this.record_terminator_conflicts(data.terminator()); + let term_loc = Location { block, statement_index: data.statements.len() }; + init.seek_before_primary_effect(term_loc); + live_and_init_locals[term_loc.statement_index].clone_from(init.get()); + + // Now, go backwards and union with the liveness results. + for statement_index in (0..=data.statements.len()).rev() { + let loc = Location { block, statement_index }; + live.seek_after_primary_effect(loc); + + live_and_init_locals[statement_index].intersect(live.get()); + + trace!("record conflicts at {:?}", loc); + + this.record_dataflow_conflicts(&mut live_and_init_locals[statement_index]); + } + + init.seek_to_block_end(block); + live.seek_to_block_end(block); + let mut conflicts = init.get().clone(); + conflicts.intersect(live.get()); + trace!("record conflicts at end of {:?}", block); + + this.record_dataflow_conflicts(&mut conflicts); + } + + this + } + + fn record_dataflow_conflicts(&mut self, new_conflicts: &mut BitSet<Local>) { + // Remove all locals that are not candidates. + new_conflicts.intersect(self.relevant_locals); + + for local in new_conflicts.iter() { + self.matrix.union_row_with(&new_conflicts, local); + } + } + + fn record_local_conflict(&mut self, a: Local, b: Local, why: &str) { + trace!("conflict {:?} <-> {:?} due to {}", a, b, why); + self.matrix.insert(a, b); + self.matrix.insert(b, a); + } + + /// Records locals that must not overlap during the evaluation of `stmt`. These locals conflict + /// and must not be merged. + fn record_statement_conflicts(&mut self, stmt: &Statement<'_>) { + match &stmt.kind { + // While the left and right sides of an assignment must not overlap, we do not mark + // conflicts here as that would make this optimization useless. When we optimize, we + // eliminate the resulting self-assignments automatically. + StatementKind::Assign(_) => {} + + StatementKind::SetDiscriminant { .. } + | StatementKind::Deinit(..) + | StatementKind::StorageLive(..) + | StatementKind::StorageDead(..) + | StatementKind::Retag(..) + | StatementKind::FakeRead(..) + | StatementKind::AscribeUserType(..) + | StatementKind::Coverage(..) + | StatementKind::CopyNonOverlapping(..) + | StatementKind::Nop => {} + } + } + + fn record_terminator_conflicts(&mut self, term: &Terminator<'_>) { + match &term.kind { + TerminatorKind::DropAndReplace { + place: dropped_place, + value, + target: _, + unwind: _, + } => { + if let Some(place) = value.place() + && !place.is_indirect() + && !dropped_place.is_indirect() + { + self.record_local_conflict( + place.local, + dropped_place.local, + "DropAndReplace operand overlap", + ); + } + } + TerminatorKind::Yield { value, resume: _, resume_arg, drop: _ } => { + if let Some(place) = value.place() { + if !place.is_indirect() && !resume_arg.is_indirect() { + self.record_local_conflict( + place.local, + resume_arg.local, + "Yield operand overlap", + ); + } + } + } + TerminatorKind::Call { + func, + args, + destination, + target: _, + cleanup: _, + from_hir_call: _, + fn_span: _, + } => { + // No arguments may overlap with the destination. 
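+                // For example (an illustrative sketch): in `_1 = f(move _2)`, merging `_1` and
+                // `_2` would let the callee's writes to the return place alias its argument, so
+                // `_1 <-> _2` is recorded as a conflict (and likewise for the callee operand).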
+ for arg in args.iter().chain(Some(func)) { + if let Some(place) = arg.place() { + if !place.is_indirect() && !destination.is_indirect() { + self.record_local_conflict( + destination.local, + place.local, + "call dest/arg overlap", + ); + } + } + } + } + TerminatorKind::InlineAsm { + template: _, + operands, + options: _, + line_spans: _, + destination: _, + cleanup: _, + } => { + // The intended semantics here aren't documented, we just assume that nothing that + // could be written to by the assembly may overlap with any other operands. + for op in operands { + match op { + InlineAsmOperand::Out { reg: _, late: _, place: Some(dest_place) } + | InlineAsmOperand::InOut { + reg: _, + late: _, + in_value: _, + out_place: Some(dest_place), + } => { + // For output place `place`, add all places accessed by the inline asm. + for op in operands { + match op { + InlineAsmOperand::In { reg: _, value } => { + if let Some(p) = value.place() + && !p.is_indirect() + && !dest_place.is_indirect() + { + self.record_local_conflict( + p.local, + dest_place.local, + "asm! operand overlap", + ); + } + } + InlineAsmOperand::Out { + reg: _, + late: _, + place: Some(place), + } => { + if !place.is_indirect() && !dest_place.is_indirect() { + self.record_local_conflict( + place.local, + dest_place.local, + "asm! operand overlap", + ); + } + } + InlineAsmOperand::InOut { + reg: _, + late: _, + in_value, + out_place, + } => { + if let Some(place) = in_value.place() + && !place.is_indirect() + && !dest_place.is_indirect() + { + self.record_local_conflict( + place.local, + dest_place.local, + "asm! operand overlap", + ); + } + + if let Some(place) = out_place + && !place.is_indirect() + && !dest_place.is_indirect() + { + self.record_local_conflict( + place.local, + dest_place.local, + "asm! operand overlap", + ); + } + } + InlineAsmOperand::Out { reg: _, late: _, place: None } + | InlineAsmOperand::Const { value: _ } + | InlineAsmOperand::SymFn { value: _ } + | InlineAsmOperand::SymStatic { def_id: _ } => {} + } + } + } + InlineAsmOperand::InOut { + reg: _, + late: _, + in_value: _, + out_place: None, + } + | InlineAsmOperand::In { reg: _, value: _ } + | InlineAsmOperand::Out { reg: _, late: _, place: None } + | InlineAsmOperand::Const { value: _ } + | InlineAsmOperand::SymFn { value: _ } + | InlineAsmOperand::SymStatic { def_id: _ } => {} + } + } + } + + TerminatorKind::Goto { .. } + | TerminatorKind::SwitchInt { .. } + | TerminatorKind::Resume + | TerminatorKind::Abort + | TerminatorKind::Return + | TerminatorKind::Unreachable + | TerminatorKind::Drop { .. } + | TerminatorKind::Assert { .. } + | TerminatorKind::GeneratorDrop + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } => {} + } + } + + /// Checks whether `a` and `b` may be merged. Returns `false` if there's a conflict. + fn can_unify(&mut self, a: Local, b: Local) -> bool { + // After some locals have been unified, their conflicts are only tracked in the root key, + // so look that up. + let a = self.unified_locals.find(a).0; + let b = self.unified_locals.find(b).0; + + if a == b { + // Already merged (part of the same connected component). + return false; + } + + if self.matrix.contains(a, b) { + // Conflict (derived via dataflow, intra-statement conflicts, or inherited from another + // local during unification). + return false; + } + + true + } + + /// Merges the conflicts of `a` and `b`, so that each one inherits all conflicts of the other. 
+ /// + /// `can_unify` must have returned `true` for the same locals, or this may panic or lead to + /// miscompiles. + /// + /// This is called when the pass makes the decision to unify `a` and `b` (or parts of `a` and + /// `b`) and is needed to ensure that future unification decisions take potentially newly + /// introduced conflicts into account. + /// + /// For an example, assume we have locals `_0`, `_1`, `_2`, and `_3`. There are these conflicts: + /// + /// * `_0` <-> `_1` + /// * `_1` <-> `_2` + /// * `_3` <-> `_0` + /// + /// We then decide to merge `_2` with `_3` since they don't conflict. Then we decide to merge + /// `_2` with `_0`, which also doesn't have a conflict in the above list. However `_2` is now + /// `_3`, which does conflict with `_0`. + fn unify(&mut self, a: Local, b: Local) { + trace!("unify({:?}, {:?})", a, b); + + // Get the root local of the connected components. The root local stores the conflicts of + // all locals in the connected component (and *is stored* as the conflicting local of other + // locals). + let a = self.unified_locals.find(a).0; + let b = self.unified_locals.find(b).0; + assert_ne!(a, b); + + trace!("roots: a={:?}, b={:?}", a, b); + trace!("{:?} conflicts: {:?}", a, self.matrix.iter(a).format(", ")); + trace!("{:?} conflicts: {:?}", b, self.matrix.iter(b).format(", ")); + + self.unified_locals.union(a, b); + + let root = self.unified_locals.find(a).0; + assert!(root == a || root == b); + + // Make all locals that conflict with `a` also conflict with `b`, and vice versa. + self.unify_cache.clear(); + for conflicts_with_a in self.matrix.iter(a) { + self.unify_cache.insert(conflicts_with_a); + } + for conflicts_with_b in self.matrix.iter(b) { + self.unify_cache.insert(conflicts_with_b); + } + for conflicts_with_a_or_b in self.unify_cache.iter() { + // Set both `a` and `b` for this local's row. + self.matrix.insert(conflicts_with_a_or_b, a); + self.matrix.insert(conflicts_with_a_or_b, b); + } + + // Write the locals `a` conflicts with to `b`'s row. + self.matrix.union_rows(a, b); + // Write the locals `b` conflicts with to `a`'s row. + self.matrix.union_rows(b, a); + } +} + +/// A `dest = {move} src;` statement at `loc`. +/// +/// We want to consider merging `dest` and `src` due to this assignment. +#[derive(Debug, Copy, Clone)] +struct CandidateAssignment<'tcx> { + /// Does not contain indirection or indexing (so the only local it contains is the place base). + dest: Place<'tcx>, + src: Local, + loc: Location, +} + +/// Scans the MIR for assignments between locals that we might want to consider merging. +/// +/// This will filter out assignments that do not match the right form (as described in the top-level +/// comment) and also throw out assignments that involve a local that has its address taken or is +/// otherwise ineligible (eg. locals used as array indices are ignored because we cannot propagate +/// arbitrary places into array indices). 
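+///
+/// A hedged sketch of what is accepted and rejected (illustrative MIR, not real pass output):
+///
+/// ```ignore (illustrative MIR)
+/// _3 = move _2;     // accepted: CandidateAssignment { dest: _3, src: _2 }
+/// (*_4) = move _2;  // rejected: `dest` is behind a pointer indirection
+/// _3 = move (_2.0); // rejected: `src` is not a bare local
+/// ```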
+fn find_candidates<'tcx>(body: &Body<'tcx>) -> Vec<CandidateAssignment<'tcx>> {
+    let mut visitor = FindAssignments {
+        body,
+        candidates: Vec::new(),
+        ever_borrowed_locals: borrowed_locals(body),
+        locals_used_as_array_index: locals_used_as_array_index(body),
+    };
+    visitor.visit_body(body);
+    visitor.candidates
+}
+
+struct FindAssignments<'a, 'tcx> {
+    body: &'a Body<'tcx>,
+    candidates: Vec<CandidateAssignment<'tcx>>,
+    ever_borrowed_locals: BitSet<Local>,
+    locals_used_as_array_index: BitSet<Local>,
+}
+
+impl<'tcx> Visitor<'tcx> for FindAssignments<'_, 'tcx> {
+    fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
+        if let StatementKind::Assign(box (
+            dest,
+            Rvalue::Use(Operand::Copy(src) | Operand::Move(src)),
+        )) = &statement.kind
+        {
+            // `dest` must not have pointer indirection.
+            if dest.is_indirect() {
+                return;
+            }
+
+            // `src` must be a plain local.
+            if !src.projection.is_empty() {
+                return;
+            }
+
+            // Since we want to replace `src` with `dest`, `src` must not be required.
+            if is_local_required(src.local, self.body) {
+                return;
+            }
+
+            // Can't optimize if either local ever has its address taken. This optimization does
+            // liveness analysis only based on assignments, and a local can be live even if it's
+            // never assigned to again, because a reference to it might be live.
+            // FIXME: This can be smarter and take `StorageDead` into account (which invalidates
+            // borrows).
+            if self.ever_borrowed_locals.contains(dest.local)
+                || self.ever_borrowed_locals.contains(src.local)
+            {
+                return;
+            }
+
+            assert_ne!(dest.local, src.local, "self-assignments are UB");
+
+            // We can't replace locals occurring in `PlaceElem::Index` for now.
+            if self.locals_used_as_array_index.contains(src.local) {
+                return;
+            }
+
+            for elem in dest.projection {
+                if let PlaceElem::Index(_) = elem {
+                    // `dest` contains an indexing projection.
+                    return;
+                }
+            }
+
+            self.candidates.push(CandidateAssignment {
+                dest: *dest,
+                src: src.local,
+                loc: location,
+            });
+        }
+    }
+}
+
+/// Some locals are part of the function's interface and cannot be removed.
+///
+/// Note that these locals *can* still be merged with non-required locals by removing that other
+/// local.
+fn is_local_required(local: Local, body: &Body<'_>) -> bool {
+    match body.local_kind(local) {
+        LocalKind::Arg | LocalKind::ReturnPointer => true,
+        LocalKind::Var | LocalKind::Temp => false,
+    }
+}
+
+/// `PlaceElem::Index` only stores a `Local`, so we can't replace that with a full `Place`.
+///
+/// Collect locals used as indices so we don't generate candidates that are impossible to apply
+/// later.
+fn locals_used_as_array_index(body: &Body<'_>) -> BitSet<Local> {
+    let mut visitor = IndexCollector { locals: BitSet::new_empty(body.local_decls.len()) };
+    visitor.visit_body(body);
+    visitor.locals
+}
+
+struct IndexCollector {
+    locals: BitSet<Local>,
+}
+
+impl<'tcx> Visitor<'tcx> for IndexCollector {
+    fn visit_projection_elem(
+        &mut self,
+        local: Local,
+        proj_base: &[PlaceElem<'tcx>],
+        elem: PlaceElem<'tcx>,
+        context: PlaceContext,
+        location: Location,
+    ) {
+        if let PlaceElem::Index(i) = elem {
+            self.locals.insert(i);
+        }
+        self.super_projection_elem(local, proj_base, elem, context, location);
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/dump_mir.rs b/compiler/rustc_mir_transform/src/dump_mir.rs
new file mode 100644
index 000000000..6b995141a
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/dump_mir.rs
@@ -0,0 +1,28 @@
+//! This pass just dumps MIR at a specified point.
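+//!
+//! The `Marker` pass below performs no transformation; it only contributes a pass name, so that
+//! a MIR dump can be requested at that point in the pass pipeline. `emit_mir` is the entry point
+//! used when whole-crate MIR output is requested (e.g. via `--emit mir`), pretty-printing the
+//! final MIR into the `.mir` output file.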
+ +use std::borrow::Cow; +use std::fs::File; +use std::io; + +use crate::MirPass; +use rustc_middle::mir::write_mir_pretty; +use rustc_middle::mir::Body; +use rustc_middle::ty::TyCtxt; +use rustc_session::config::{OutputFilenames, OutputType}; + +pub struct Marker(pub &'static str); + +impl<'tcx> MirPass<'tcx> for Marker { + fn name(&self) -> Cow<'_, str> { + Cow::Borrowed(self.0) + } + + fn run_pass(&self, _tcx: TyCtxt<'tcx>, _body: &mut Body<'tcx>) {} +} + +pub fn emit_mir(tcx: TyCtxt<'_>, outputs: &OutputFilenames) -> io::Result<()> { + let path = outputs.path(OutputType::Mir); + let mut f = io::BufWriter::new(File::create(&path)?); + write_mir_pretty(tcx, None, &mut f)?; + Ok(()) +} diff --git a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs new file mode 100644 index 000000000..dba42f7af --- /dev/null +++ b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs @@ -0,0 +1,429 @@ +use rustc_middle::mir::patch::MirPatch; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, Ty, TyCtxt}; +use std::fmt::Debug; + +use super::simplify::simplify_cfg; + +/// This pass optimizes something like +/// ```ignore (syntax-highlighting-only) +/// let x: Option<()>; +/// let y: Option<()>; +/// match (x,y) { +/// (Some(_), Some(_)) => {0}, +/// _ => {1} +/// } +/// ``` +/// into something like +/// ```ignore (syntax-highlighting-only) +/// let x: Option<()>; +/// let y: Option<()>; +/// let discriminant_x = std::mem::discriminant(x); +/// let discriminant_y = std::mem::discriminant(y); +/// if discriminant_x == discriminant_y { +/// match x { +/// Some(_) => 0, +/// _ => 1, // <---- +/// } // | Actually the same bb +/// } else { // | +/// 1 // <-------------- +/// } +/// ``` +/// +/// Specifically, it looks for instances of control flow like this: +/// ```text +/// +/// ================= +/// | BB1 | +/// |---------------| ============================ +/// | ... | /------> | BBC | +/// |---------------| | |--------------------------| +/// | switchInt(Q) | | | _cl = discriminant(P) | +/// | c | --------/ |--------------------------| +/// | d | -------\ | switchInt(_cl) | +/// | ... | | | c | ---> BBC.2 +/// | otherwise | --\ | /--- | otherwise | +/// ================= | | | ============================ +/// | | | +/// ================= | | | +/// | BBU | <-| | | ============================ +/// |---------------| | \-------> | BBD | +/// |---------------| | | |--------------------------| +/// | unreachable | | | | _dl = discriminant(P) | +/// ================= | | |--------------------------| +/// | | | switchInt(_dl) | +/// ================= | | | d | ---> BBD.2 +/// | BB9 | <--------------- | otherwise | +/// |---------------| ============================ +/// | ... | +/// ================= +/// ``` +/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU` or to `BB9`. In the +/// code: +/// - `BB1` is `parent` and `BBC, BBD` are children +/// - `P` is `child_place` +/// - `child_ty` is the type of `_cl`. +/// - `Q` is `parent_op`. +/// - `parent_ty` is the type of `Q`. +/// - `BB9` is `destination` +/// All this is then transformed into: +/// ```text +/// +/// ======================= +/// | BB1 | +/// |---------------------| ============================ +/// | ... 
| /------> | BBEq | +/// | _s = discriminant(P)| | |--------------------------| +/// | _t = Ne(Q, _s) | | |--------------------------| +/// |---------------------| | | switchInt(Q) | +/// | switchInt(_t) | | | c | ---> BBC.2 +/// | false | --------/ | d | ---> BBD.2 +/// | otherwise | ---------------- | otherwise | +/// ======================= | ============================ +/// | +/// ================= | +/// | BB9 | <-----------/ +/// |---------------| +/// | ... | +/// ================= +/// ``` +/// +/// This is only correct for some `P`, since `P` is now computed outside the original `switchInt`. +/// The filter on which `P` are allowed (together with discussion of its correctness) is found in +/// `may_hoist`. +pub struct EarlyOtherwiseBranch; + +impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 3 && sess.opts.unstable_opts.unsound_mir_opts + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + trace!("running EarlyOtherwiseBranch on {:?}", body.source); + + let mut should_cleanup = false; + + // Also consider newly generated bbs in the same pass + for i in 0..body.basic_blocks().len() { + let bbs = body.basic_blocks(); + let parent = BasicBlock::from_usize(i); + let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { + continue + }; + + if !tcx.consider_optimizing(|| format!("EarlyOtherwiseBranch {:?}", &opt_data)) { + break; + } + + trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_data); + + should_cleanup = true; + + let TerminatorKind::SwitchInt { + discr: parent_op, + switch_ty: parent_ty, + targets: parent_targets + } = &bbs[parent].terminator().kind else { + unreachable!() + }; + // Always correct since we can only switch on `Copy` types + let parent_op = match parent_op { + Operand::Move(x) => Operand::Copy(*x), + Operand::Copy(x) => Operand::Copy(*x), + Operand::Constant(x) => Operand::Constant(x.clone()), + }; + let statements_before = bbs[parent].statements.len(); + let parent_end = Location { block: parent, statement_index: statements_before }; + + let mut patch = MirPatch::new(body); + + // create temp to store second discriminant in, `_s` in example above + let second_discriminant_temp = + patch.new_temp(opt_data.child_ty, opt_data.child_source.span); + + patch.add_statement(parent_end, StatementKind::StorageLive(second_discriminant_temp)); + + // create assignment of discriminant + patch.add_assign( + parent_end, + Place::from(second_discriminant_temp), + Rvalue::Discriminant(opt_data.child_place), + ); + + // create temp to store inequality comparison between the two discriminants, `_t` in + // example above + let nequal = BinOp::Ne; + let comp_res_type = nequal.ty(tcx, *parent_ty, opt_data.child_ty); + let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span); + patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp)); + + // create inequality comparison between the two discriminants + let comp_rvalue = Rvalue::BinaryOp( + nequal, + Box::new((parent_op.clone(), Operand::Move(Place::from(second_discriminant_temp)))), + ); + patch.add_statement( + parent_end, + StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))), + ); + + let eq_new_targets = parent_targets.iter().map(|(value, child)| { + let TerminatorKind::SwitchInt{ targets, .. 
} = &bbs[child].terminator().kind else { + unreachable!() + }; + (value, targets.target_for_value(value)) + }); + let eq_targets = SwitchTargets::new(eq_new_targets, opt_data.destination); + + // Create `bbEq` in example above + let eq_switch = BasicBlockData::new(Some(Terminator { + source_info: bbs[parent].terminator().source_info, + kind: TerminatorKind::SwitchInt { + // switch on the first discriminant, so we can mark the second one as dead + discr: parent_op, + switch_ty: opt_data.child_ty, + targets: eq_targets, + }, + })); + + let eq_bb = patch.new_block(eq_switch); + + // Jump to it on the basis of the inequality comparison + let true_case = opt_data.destination; + let false_case = eq_bb; + patch.patch_terminator( + parent, + TerminatorKind::if_( + tcx, + Operand::Move(Place::from(comp_temp)), + true_case, + false_case, + ), + ); + + // generate StorageDead for the second_discriminant_temp not in use anymore + patch.add_statement(parent_end, StatementKind::StorageDead(second_discriminant_temp)); + + // Generate a StorageDead for comp_temp in each of the targets, since we moved it into + // the switch + for bb in [false_case, true_case].iter() { + patch.add_statement( + Location { block: *bb, statement_index: 0 }, + StatementKind::StorageDead(comp_temp), + ); + } + + patch.apply(body); + } + + // Since this optimization adds new basic blocks and invalidates others, + // clean up the cfg to make it nicer for other passes + if should_cleanup { + simplify_cfg(tcx, body); + } + } +} + +/// Returns true if computing the discriminant of `place` may be hoisted out of the branch +fn may_hoist<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, place: Place<'tcx>) -> bool { + // FIXME(JakobDegen): This is unsound. Someone could write code like this: + // ```rust + // let Q = val; + // if discriminant(P) == otherwise { + // let ptr = &mut Q as *mut _ as *mut u8; + // unsafe { *ptr = 10; } // Any invalid value for the type + // } + // + // match P { + // A => match Q { + // A => { + // // code + // } + // _ => { + // // don't use Q + // } + // } + // _ => { + // // don't use Q + // } + // }; + // ``` + // + // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an + // invalid value, which is UB. + // + // In order to fix this, we would either need to show that the discriminant computation of + // `place` is computed in all branches, including the `otherwise` branch, or we would need + // another analysis pass to determine that the place is fully initialized. It might even be best + // to have the hoisting be performed in a different pass and just do the CFG changing in this + // pass. + for (place, proj) in place.iter_projections() { + match proj { + // Dereferencing in the computation of `place` might cause issues from one of two + // categories. First, the referent might be invalid. We protect against this by + // dereferencing references only (not pointers). Second, the use of a reference may + // invalidate other references that are used later (for aliasing reasons). Consider + // where such an invalidated reference may appear: + // - In `Q`: Not possible since `Q` is used as the operand of a `SwitchInt` and so + // cannot contain referenced data. 
+ // - In `BBU`: Not possible since that block contains only the `unreachable` terminator + // - In `BBC.2, BBD.2`: Not possible, since `discriminant(P)` was computed prior to + // reaching that block in the input to our transformation, and so any data + // invalidated by that computation could not have been used there. + // - In `BB9`: Not possible since control flow might have reached `BB9` via the + // `otherwise` branch in `BBC, BBD` in the input to our transformation, which would + // have invalidated the data when computing `discriminant(P)` + // So dereferencing here is correct. + ProjectionElem::Deref => match place.ty(body.local_decls(), tcx).ty.kind() { + ty::Ref(..) => {} + _ => return false, + }, + // Field projections are always valid + ProjectionElem::Field(..) => {} + // We cannot allow + // downcasts either, since the correctness of the downcast may depend on the parent + // branch being taken. An easy example of this is + // ``` + // Q = discriminant(_3) + // P = (_3 as Variant) + // ``` + // However, checking if the child and parent place are the same and only erroring then + // is not sufficient either, since the `discriminant(_3) == 1` (or whatever) check may + // be replaced by another optimization pass with any other condition that can be proven + // equivalent. + ProjectionElem::Downcast(..) => { + return false; + } + // We cannot allow indexing since the index may be out of bounds. + _ => { + return false; + } + } + } + true +} + +#[derive(Debug)] +struct OptimizationData<'tcx> { + destination: BasicBlock, + child_place: Place<'tcx>, + child_ty: Ty<'tcx>, + child_source: SourceInfo, +} + +fn evaluate_candidate<'tcx>( + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, + parent: BasicBlock, +) -> Option<OptimizationData<'tcx>> { + let bbs = body.basic_blocks(); + let TerminatorKind::SwitchInt { + targets, + switch_ty: parent_ty, + .. + } = &bbs[parent].terminator().kind else { + return None + }; + let parent_dest = { + let poss = targets.otherwise(); + // If the fallthrough on the parent is trivially unreachable, we can let the + // children choose the destination + if bbs[poss].statements.len() == 0 + && bbs[poss].terminator().kind == TerminatorKind::Unreachable + { + None + } else { + Some(poss) + } + }; + let (_, child) = targets.iter().next()?; + let child_terminator = &bbs[child].terminator(); + let TerminatorKind::SwitchInt { + switch_ty: child_ty, + targets: child_targets, + .. 
+ } = &child_terminator.kind else { + return None + }; + if child_ty != parent_ty { + return None; + } + let Some(StatementKind::Assign(boxed)) + = &bbs[child].statements.first().map(|x| &x.kind) else { + return None; + }; + let (_, Rvalue::Discriminant(child_place)) = &**boxed else { + return None; + }; + let destination = parent_dest.unwrap_or(child_targets.otherwise()); + + // Verify that the optimization is legal in general + // We can hoist evaluating the child discriminant out of the branch + if !may_hoist(tcx, body, *child_place) { + return None; + } + + // Verify that the optimization is legal for each branch + for (value, child) in targets.iter() { + if !verify_candidate_branch(&bbs[child], value, *child_place, destination) { + return None; + } + } + Some(OptimizationData { + destination, + child_place: *child_place, + child_ty: *child_ty, + child_source: child_terminator.source_info, + }) +} + +fn verify_candidate_branch<'tcx>( + branch: &BasicBlockData<'tcx>, + value: u128, + place: Place<'tcx>, + destination: BasicBlock, +) -> bool { + // In order for the optimization to be correct, the branch must... + // ...have exactly one statement + if branch.statements.len() != 1 { + return false; + } + // ...assign the discriminant of `place` in that statement + let StatementKind::Assign(boxed) = &branch.statements[0].kind else { + return false + }; + let (discr_place, Rvalue::Discriminant(from_place)) = &**boxed else { + return false + }; + if *from_place != place { + return false; + } + // ...make that assignment to a local + if discr_place.projection.len() != 0 { + return false; + } + // ...terminate on a `SwitchInt` that invalidates that local + let TerminatorKind::SwitchInt{ discr: switch_op, targets, .. } = &branch.terminator().kind else { + return false + }; + if *switch_op != Operand::Move(*discr_place) { + return false; + } + // ...fall through to `destination` if the switch misses + if destination != targets.otherwise() { + return false; + } + // ...have a branch for value `value` + let mut iter = targets.iter(); + let Some((target_value, _)) = iter.next() else { + return false; + }; + if target_value != value { + return false; + } + // ...and have no more branches + if let Some(_) = iter.next() { + return false; + } + return true; +} diff --git a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs new file mode 100644 index 000000000..44e3945d6 --- /dev/null +++ b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs @@ -0,0 +1,184 @@ +//! This pass transforms derefs of Box into a deref of the pointer inside Box. +//! +//! Box is not actually a pointer so it is incorrect to dereference it directly. 
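+//!
+//! A hedged sketch of the rewrite (illustrative MIR, not real pass output), where `_1: Box<T>`:
+//!
+//! ```ignore (illustrative MIR)
+//! // before
+//! _2 = (*_1);
+//! // after: `_t` is a fresh `*const T` temporary, built from the three `.0` field
+//! // projections through `Unique<T>` and `NonNull<T>` down to the raw pointer
+//! _t = (((_1.0).0).0);
+//! _2 = (*_t);
+//! ```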
+ +use crate::MirPass; +use rustc_hir::def_id::DefId; +use rustc_index::vec::Idx; +use rustc_middle::mir::patch::MirPatch; +use rustc_middle::mir::visit::MutVisitor; +use rustc_middle::mir::*; +use rustc_middle::ty::subst::Subst; +use rustc_middle::ty::{Ty, TyCtxt}; + +/// Constructs the types used when accessing a Box's pointer +pub fn build_ptr_tys<'tcx>( + tcx: TyCtxt<'tcx>, + pointee: Ty<'tcx>, + unique_did: DefId, + nonnull_did: DefId, +) -> (Ty<'tcx>, Ty<'tcx>, Ty<'tcx>) { + let substs = tcx.intern_substs(&[pointee.into()]); + let unique_ty = tcx.bound_type_of(unique_did).subst(tcx, substs); + let nonnull_ty = tcx.bound_type_of(nonnull_did).subst(tcx, substs); + let ptr_ty = tcx.mk_imm_ptr(pointee); + + (unique_ty, nonnull_ty, ptr_ty) +} + +// Constructs the projection needed to access a Box's pointer +pub fn build_projection<'tcx>( + unique_ty: Ty<'tcx>, + nonnull_ty: Ty<'tcx>, + ptr_ty: Ty<'tcx>, +) -> [PlaceElem<'tcx>; 3] { + [ + PlaceElem::Field(Field::new(0), unique_ty), + PlaceElem::Field(Field::new(0), nonnull_ty), + PlaceElem::Field(Field::new(0), ptr_ty), + ] +} + +struct ElaborateBoxDerefVisitor<'tcx, 'a> { + tcx: TyCtxt<'tcx>, + unique_did: DefId, + nonnull_did: DefId, + local_decls: &'a mut LocalDecls<'tcx>, + patch: MirPatch<'tcx>, +} + +impl<'tcx, 'a> MutVisitor<'tcx> for ElaborateBoxDerefVisitor<'tcx, 'a> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_place( + &mut self, + place: &mut Place<'tcx>, + context: visit::PlaceContext, + location: Location, + ) { + let tcx = self.tcx; + + let base_ty = self.local_decls[place.local].ty; + + // Derefer ensures that derefs are always the first projection + if place.projection.first() == Some(&PlaceElem::Deref) && base_ty.is_box() { + let source_info = self.local_decls[place.local].source_info; + + let (unique_ty, nonnull_ty, ptr_ty) = + build_ptr_tys(tcx, base_ty.boxed_ty(), self.unique_did, self.nonnull_did); + + let ptr_local = self.patch.new_temp(ptr_ty, source_info.span); + self.local_decls.push(LocalDecl::new(ptr_ty, source_info.span)); + + self.patch.add_statement(location, StatementKind::StorageLive(ptr_local)); + + self.patch.add_assign( + location, + Place::from(ptr_local), + Rvalue::Use(Operand::Copy( + Place::from(place.local) + .project_deeper(&build_projection(unique_ty, nonnull_ty, ptr_ty), tcx), + )), + ); + + place.local = ptr_local; + + self.patch.add_statement( + Location { block: location.block, statement_index: location.statement_index + 1 }, + StatementKind::StorageDead(ptr_local), + ); + } + + self.super_place(place, context, location); + } +} + +pub struct ElaborateBoxDerefs; + +impl<'tcx> MirPass<'tcx> for ElaborateBoxDerefs { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + if let Some(def_id) = tcx.lang_items().owned_box() { + let unique_did = tcx.adt_def(def_id).non_enum_variant().fields[0].did; + + let Some(nonnull_def) = tcx.type_of(unique_did).ty_adt_def() else { + span_bug!(tcx.def_span(unique_did), "expected Box to contain Unique") + }; + + let nonnull_did = nonnull_def.non_enum_variant().fields[0].did; + + let patch = MirPatch::new(body); + + let local_decls = &mut body.local_decls; + + let mut visitor = + ElaborateBoxDerefVisitor { tcx, unique_did, nonnull_did, local_decls, patch }; + + for (block, BasicBlockData { statements, terminator, .. 
}) in + body.basic_blocks.as_mut().iter_enumerated_mut() + { + let mut index = 0; + for statement in statements { + let location = Location { block, statement_index: index }; + visitor.visit_statement(statement, location); + index += 1; + } + + if let Some(terminator) = terminator + && !matches!(terminator.kind, TerminatorKind::Yield{..}) + { + let location = Location { block, statement_index: index }; + visitor.visit_terminator(terminator, location); + } + + let location = Location { block, statement_index: index }; + match terminator { + // yielding into a box is handled when lowering generators + Some(Terminator { kind: TerminatorKind::Yield { value, .. }, .. }) => { + visitor.visit_operand(value, location); + } + Some(terminator) => { + visitor.visit_terminator(terminator, location); + } + None => {} + } + } + + visitor.patch.apply(body); + + for debug_info in body.var_debug_info.iter_mut() { + if let VarDebugInfoContents::Place(place) = &mut debug_info.value { + let mut new_projections: Option<Vec<_>> = None; + let mut last_deref = 0; + + for (i, (base, elem)) in place.iter_projections().enumerate() { + let base_ty = base.ty(&body.local_decls, tcx).ty; + + if elem == PlaceElem::Deref && base_ty.is_box() { + let new_projections = new_projections.get_or_insert_default(); + + let (unique_ty, nonnull_ty, ptr_ty) = + build_ptr_tys(tcx, base_ty.boxed_ty(), unique_did, nonnull_did); + + new_projections.extend_from_slice(&base.projection[last_deref..]); + new_projections.extend_from_slice(&build_projection( + unique_ty, nonnull_ty, ptr_ty, + )); + new_projections.push(PlaceElem::Deref); + + last_deref = i; + } + } + + if let Some(mut new_projections) = new_projections { + new_projections.extend_from_slice(&place.projection[last_deref..]); + place.projection = tcx.intern_place_elems(&new_projections); + } + } + } + } else { + // box is not present, this pass doesn't need to do anything + } + } +} diff --git a/compiler/rustc_mir_transform/src/elaborate_drops.rs b/compiler/rustc_mir_transform/src/elaborate_drops.rs new file mode 100644 index 000000000..9c1fcbaa6 --- /dev/null +++ b/compiler/rustc_mir_transform/src/elaborate_drops.rs @@ -0,0 +1,613 @@ +use crate::deref_separator::deref_finder; +use crate::MirPass; +use rustc_data_structures::fx::FxHashMap; +use rustc_index::bit_set::BitSet; +use rustc_middle::mir::patch::MirPatch; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, TyCtxt}; +use rustc_mir_dataflow::elaborate_drops::{elaborate_drop, DropFlagState, Unwind}; +use rustc_mir_dataflow::elaborate_drops::{DropElaborator, DropFlagMode, DropStyle}; +use rustc_mir_dataflow::impls::{MaybeInitializedPlaces, MaybeUninitializedPlaces}; +use rustc_mir_dataflow::move_paths::{LookupResult, MoveData, MovePathIndex}; +use rustc_mir_dataflow::on_lookup_result_bits; +use rustc_mir_dataflow::un_derefer::UnDerefer; +use rustc_mir_dataflow::MoveDataParamEnv; +use rustc_mir_dataflow::{on_all_children_bits, on_all_drop_children_bits}; +use rustc_mir_dataflow::{Analysis, ResultsCursor}; +use rustc_span::Span; +use rustc_target::abi::VariantIdx; +use std::fmt; + +pub struct ElaborateDrops; + +impl<'tcx> MirPass<'tcx> for ElaborateDrops { + fn phase_change(&self) -> Option<MirPhase> { + Some(MirPhase::DropsLowered) + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + debug!("elaborate_drops({:?} @ {:?})", body.source, body.span); + + let def_id = body.source.def_id(); + let param_env = tcx.param_env_reveal_all_normalized(def_id); + let (side_table, move_data) = match 
MoveData::gather_moves(body, tcx, param_env) {
+            Ok(move_data) => move_data,
+            Err((move_data, _)) => {
+                tcx.sess.delay_span_bug(
+                    body.span,
+                    "No `move_errors` should be allowed in MIR borrowck",
+                );
+                (Default::default(), move_data)
+            }
+        };
+        let un_derefer = UnDerefer { tcx, derefer_sidetable: side_table };
+        let elaborate_patch = {
+            let body = &*body;
+            let env = MoveDataParamEnv { move_data, param_env };
+            let dead_unwinds = find_dead_unwinds(tcx, body, &env, &un_derefer);
+
+            let inits = MaybeInitializedPlaces::new(tcx, body, &env)
+                .into_engine(tcx, body)
+                .dead_unwinds(&dead_unwinds)
+                .pass_name("elaborate_drops")
+                .iterate_to_fixpoint()
+                .into_results_cursor(body);
+
+            let uninits = MaybeUninitializedPlaces::new(tcx, body, &env)
+                .mark_inactive_variants_as_uninit()
+                .into_engine(tcx, body)
+                .dead_unwinds(&dead_unwinds)
+                .pass_name("elaborate_drops")
+                .iterate_to_fixpoint()
+                .into_results_cursor(body);
+
+            ElaborateDropsCtxt {
+                tcx,
+                body,
+                env: &env,
+                init_data: InitializationData { inits, uninits },
+                drop_flags: Default::default(),
+                patch: MirPatch::new(body),
+                un_derefer,
+            }
+            .elaborate()
+        };
+        elaborate_patch.apply(body);
+        deref_finder(tcx, body);
+    }
+}
+
+/// Returns the set of basic blocks whose unwind edges are known
+/// to not be reachable, because they are `drop` terminators
+/// that can't drop anything.
+fn find_dead_unwinds<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    body: &Body<'tcx>,
+    env: &MoveDataParamEnv<'tcx>,
+    und: &UnDerefer<'tcx>,
+) -> BitSet<BasicBlock> {
+    debug!("find_dead_unwinds({:?})", body.span);
+    // We only need to do this pass once, because unwind edges can only
+    // reach cleanup blocks, which can't have unwind edges themselves.
+    let mut dead_unwinds = BitSet::new_empty(body.basic_blocks().len());
+    let mut flow_inits = MaybeInitializedPlaces::new(tcx, body, &env)
+        .into_engine(tcx, body)
+        .pass_name("find_dead_unwinds")
+        .iterate_to_fixpoint()
+        .into_results_cursor(body);
+    for (bb, bb_data) in body.basic_blocks().iter_enumerated() {
+        let place = match bb_data.terminator().kind {
+            TerminatorKind::Drop { ref place, unwind: Some(_), .. }
+            | TerminatorKind::DropAndReplace { ref place, unwind: Some(_), ..
} => { + und.derefer(place.as_ref(), body).unwrap_or(*place) + } + _ => continue, + }; + + debug!("find_dead_unwinds @ {:?}: {:?}", bb, bb_data); + + let LookupResult::Exact(path) = env.move_data.rev_lookup.find(place.as_ref()) else { + debug!("find_dead_unwinds: has parent; skipping"); + continue; + }; + + flow_inits.seek_before_primary_effect(body.terminator_loc(bb)); + debug!( + "find_dead_unwinds @ {:?}: path({:?})={:?}; init_data={:?}", + bb, + place, + path, + flow_inits.get() + ); + + let mut maybe_live = false; + on_all_drop_children_bits(tcx, body, &env, path, |child| { + maybe_live |= flow_inits.contains(child); + }); + + debug!("find_dead_unwinds @ {:?}: maybe_live={}", bb, maybe_live); + if !maybe_live { + dead_unwinds.insert(bb); + } + } + + dead_unwinds +} + +struct InitializationData<'mir, 'tcx> { + inits: ResultsCursor<'mir, 'tcx, MaybeInitializedPlaces<'mir, 'tcx>>, + uninits: ResultsCursor<'mir, 'tcx, MaybeUninitializedPlaces<'mir, 'tcx>>, +} + +impl InitializationData<'_, '_> { + fn seek_before(&mut self, loc: Location) { + self.inits.seek_before_primary_effect(loc); + self.uninits.seek_before_primary_effect(loc); + } + + fn maybe_live_dead(&self, path: MovePathIndex) -> (bool, bool) { + (self.inits.contains(path), self.uninits.contains(path)) + } +} + +struct Elaborator<'a, 'b, 'tcx> { + ctxt: &'a mut ElaborateDropsCtxt<'b, 'tcx>, +} + +impl fmt::Debug for Elaborator<'_, '_, '_> { + fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { + Ok(()) + } +} + +impl<'a, 'tcx> DropElaborator<'a, 'tcx> for Elaborator<'a, '_, 'tcx> { + type Path = MovePathIndex; + + fn patch(&mut self) -> &mut MirPatch<'tcx> { + &mut self.ctxt.patch + } + + fn body(&self) -> &'a Body<'tcx> { + self.ctxt.body + } + + fn tcx(&self) -> TyCtxt<'tcx> { + self.ctxt.tcx + } + + fn param_env(&self) -> ty::ParamEnv<'tcx> { + self.ctxt.param_env() + } + + fn drop_style(&self, path: Self::Path, mode: DropFlagMode) -> DropStyle { + let ((maybe_live, maybe_dead), multipart) = match mode { + DropFlagMode::Shallow => (self.ctxt.init_data.maybe_live_dead(path), false), + DropFlagMode::Deep => { + let mut some_live = false; + let mut some_dead = false; + let mut children_count = 0; + on_all_drop_children_bits(self.tcx(), self.body(), self.ctxt.env, path, |child| { + let (live, dead) = self.ctxt.init_data.maybe_live_dead(child); + debug!("elaborate_drop: state({:?}) = {:?}", child, (live, dead)); + some_live |= live; + some_dead |= dead; + children_count += 1; + }); + ((some_live, some_dead), children_count != 1) + } + }; + match (maybe_live, maybe_dead, multipart) { + (false, _, _) => DropStyle::Dead, + (true, false, _) => DropStyle::Static, + (true, true, false) => DropStyle::Conditional, + (true, true, true) => DropStyle::Open, + } + } + + fn clear_drop_flag(&mut self, loc: Location, path: Self::Path, mode: DropFlagMode) { + match mode { + DropFlagMode::Shallow => { + self.ctxt.set_drop_flag(loc, path, DropFlagState::Absent); + } + DropFlagMode::Deep => { + on_all_children_bits( + self.tcx(), + self.body(), + self.ctxt.move_data(), + path, + |child| self.ctxt.set_drop_flag(loc, child, DropFlagState::Absent), + ); + } + } + } + + fn field_subpath(&self, path: Self::Path, field: Field) -> Option<Self::Path> { + rustc_mir_dataflow::move_path_children_matching(self.ctxt.move_data(), path, |e| match e { + ProjectionElem::Field(idx, _) => idx == field, + _ => false, + }) + } + + fn array_subpath(&self, path: Self::Path, index: u64, size: u64) -> Option<Self::Path> { + 
rustc_mir_dataflow::move_path_children_matching(self.ctxt.move_data(), path, |e| match e { + ProjectionElem::ConstantIndex { offset, min_length, from_end } => { + debug_assert!(size == min_length, "min_length should be exact for arrays"); + assert!(!from_end, "from_end should not be used for array element ConstantIndex"); + offset == index + } + _ => false, + }) + } + + fn deref_subpath(&self, path: Self::Path) -> Option<Self::Path> { + rustc_mir_dataflow::move_path_children_matching(self.ctxt.move_data(), path, |e| { + e == ProjectionElem::Deref + }) + } + + fn downcast_subpath(&self, path: Self::Path, variant: VariantIdx) -> Option<Self::Path> { + rustc_mir_dataflow::move_path_children_matching(self.ctxt.move_data(), path, |e| match e { + ProjectionElem::Downcast(_, idx) => idx == variant, + _ => false, + }) + } + + fn get_drop_flag(&mut self, path: Self::Path) -> Option<Operand<'tcx>> { + self.ctxt.drop_flag(path).map(Operand::Copy) + } +} + +struct ElaborateDropsCtxt<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + body: &'a Body<'tcx>, + env: &'a MoveDataParamEnv<'tcx>, + init_data: InitializationData<'a, 'tcx>, + drop_flags: FxHashMap<MovePathIndex, Local>, + patch: MirPatch<'tcx>, + un_derefer: UnDerefer<'tcx>, +} + +impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { + fn move_data(&self) -> &'b MoveData<'tcx> { + &self.env.move_data + } + + fn param_env(&self) -> ty::ParamEnv<'tcx> { + self.env.param_env + } + + fn create_drop_flag(&mut self, index: MovePathIndex, span: Span) { + let tcx = self.tcx; + let patch = &mut self.patch; + debug!("create_drop_flag({:?})", self.body.span); + self.drop_flags.entry(index).or_insert_with(|| patch.new_internal(tcx.types.bool, span)); + } + + fn drop_flag(&mut self, index: MovePathIndex) -> Option<Place<'tcx>> { + self.drop_flags.get(&index).map(|t| Place::from(*t)) + } + + /// create a patch that elaborates all drops in the input + /// MIR. + fn elaborate(mut self) -> MirPatch<'tcx> { + self.collect_drop_flags(); + + self.elaborate_drops(); + + self.drop_flags_on_init(); + self.drop_flags_for_fn_rets(); + self.drop_flags_for_args(); + self.drop_flags_for_locs(); + + self.patch + } + + fn collect_drop_flags(&mut self) { + for (bb, data) in self.body.basic_blocks().iter_enumerated() { + let terminator = data.terminator(); + let place = match terminator.kind { + TerminatorKind::Drop { ref place, .. } + | TerminatorKind::DropAndReplace { ref place, .. 
} => { + self.un_derefer.derefer(place.as_ref(), self.body).unwrap_or(*place) + } + _ => continue, + }; + + self.init_data.seek_before(self.body.terminator_loc(bb)); + + let path = self.move_data().rev_lookup.find(place.as_ref()); + debug!("collect_drop_flags: {:?}, place {:?} ({:?})", bb, place, path); + + let path = match path { + LookupResult::Exact(e) => e, + LookupResult::Parent(None) => continue, + LookupResult::Parent(Some(parent)) => { + let (_maybe_live, maybe_dead) = self.init_data.maybe_live_dead(parent); + + if self.body.local_decls[place.local].is_deref_temp() { + continue; + } + + if maybe_dead { + self.tcx.sess.delay_span_bug( + terminator.source_info.span, + &format!( + "drop of untracked, uninitialized value {:?}, place {:?} ({:?})", + bb, place, path + ), + ); + } + continue; + } + }; + + on_all_drop_children_bits(self.tcx, self.body, self.env, path, |child| { + let (maybe_live, maybe_dead) = self.init_data.maybe_live_dead(child); + debug!( + "collect_drop_flags: collecting {:?} from {:?}@{:?} - {:?}", + child, + place, + path, + (maybe_live, maybe_dead) + ); + if maybe_live && maybe_dead { + self.create_drop_flag(child, terminator.source_info.span) + } + }); + } + } + + fn elaborate_drops(&mut self) { + for (bb, data) in self.body.basic_blocks().iter_enumerated() { + let loc = Location { block: bb, statement_index: data.statements.len() }; + let terminator = data.terminator(); + + let resume_block = self.patch.resume_block(); + match terminator.kind { + TerminatorKind::Drop { mut place, target, unwind } => { + if let Some(new_place) = self.un_derefer.derefer(place.as_ref(), self.body) { + place = new_place; + } + + self.init_data.seek_before(loc); + match self.move_data().rev_lookup.find(place.as_ref()) { + LookupResult::Exact(path) => elaborate_drop( + &mut Elaborator { ctxt: self }, + terminator.source_info, + place, + path, + target, + if data.is_cleanup { + Unwind::InCleanup + } else { + Unwind::To(Option::unwrap_or(unwind, resume_block)) + }, + bb, + ), + LookupResult::Parent(..) => { + self.tcx.sess.delay_span_bug( + terminator.source_info.span, + &format!("drop of untracked value {:?}", bb), + ); + } + } + } + TerminatorKind::DropAndReplace { mut place, ref value, target, unwind } => { + assert!(!data.is_cleanup); + + if let Some(new_place) = self.un_derefer.derefer(place.as_ref(), self.body) { + place = new_place; + } + self.elaborate_replace(loc, place, value, target, unwind); + } + _ => continue, + } + } + } + + /// Elaborate a MIR `replace` terminator. This instruction + /// is not directly handled by codegen, and therefore + /// must be desugared. + /// + /// The desugaring drops the location if needed, and then writes + /// the value (including setting the drop flag) over it in *both* arms. + /// + /// The `replace` terminator can also be called on places that + /// are not tracked by elaboration (for example, + /// `replace x[i] <- tmp0`). The borrow checker requires that + /// these locations are initialized before the assignment, + /// so we just generate an unconditional drop. 
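+ ///
+ /// A minimal sketch of the desugared shape (illustrative only; the block
+ /// names are invented, and the drop-flag assignments are inserted
+ /// separately via `set_drop_flag`):
+ /// ```ignore (illustrative)
+ /// // DropAndReplace(place <- value) [target, unwind] becomes:
+ /// // bb_cur: Drop(place) [target: bb_target, unwind: bb_unwind]
+ /// // bb_target: place = value; goto -> target
+ /// // bb_unwind (cleanup): place = value; goto -> unwind
+ /// ```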
+ fn elaborate_replace( + &mut self, + loc: Location, + place: Place<'tcx>, + value: &Operand<'tcx>, + target: BasicBlock, + unwind: Option<BasicBlock>, + ) { + let bb = loc.block; + let data = &self.body[bb]; + let terminator = data.terminator(); + assert!(!data.is_cleanup, "DropAndReplace in unwind path not supported"); + + let assign = Statement { + kind: StatementKind::Assign(Box::new((place, Rvalue::Use(value.clone())))), + source_info: terminator.source_info, + }; + + let unwind = unwind.unwrap_or_else(|| self.patch.resume_block()); + let unwind = self.patch.new_block(BasicBlockData { + statements: vec![assign.clone()], + terminator: Some(Terminator { + kind: TerminatorKind::Goto { target: unwind }, + ..*terminator + }), + is_cleanup: true, + }); + + let target = self.patch.new_block(BasicBlockData { + statements: vec![assign], + terminator: Some(Terminator { kind: TerminatorKind::Goto { target }, ..*terminator }), + is_cleanup: false, + }); + + match self.move_data().rev_lookup.find(place.as_ref()) { + LookupResult::Exact(path) => { + debug!("elaborate_drop_and_replace({:?}) - tracked {:?}", terminator, path); + self.init_data.seek_before(loc); + elaborate_drop( + &mut Elaborator { ctxt: self }, + terminator.source_info, + place, + path, + target, + Unwind::To(unwind), + bb, + ); + on_all_children_bits(self.tcx, self.body, self.move_data(), path, |child| { + self.set_drop_flag( + Location { block: target, statement_index: 0 }, + child, + DropFlagState::Present, + ); + self.set_drop_flag( + Location { block: unwind, statement_index: 0 }, + child, + DropFlagState::Present, + ); + }); + } + LookupResult::Parent(parent) => { + // drop and replace behind a pointer/array/whatever. The location + // must be initialized. + debug!("elaborate_drop_and_replace({:?}) - untracked {:?}", terminator, parent); + self.patch.patch_terminator( + bb, + TerminatorKind::Drop { place, target, unwind: Some(unwind) }, + ); + } + } + } + + fn constant_bool(&self, span: Span, val: bool) -> Rvalue<'tcx> { + Rvalue::Use(Operand::Constant(Box::new(Constant { + span, + user_ty: None, + literal: ConstantKind::from_bool(self.tcx, val), + }))) + } + + fn set_drop_flag(&mut self, loc: Location, path: MovePathIndex, val: DropFlagState) { + if let Some(&flag) = self.drop_flags.get(&path) { + let span = self.patch.source_info_for_location(self.body, loc).span; + let val = self.constant_bool(span, val.value()); + self.patch.add_assign(loc, Place::from(flag), val); + } + } + + fn drop_flags_on_init(&mut self) { + let loc = Location::START; + let span = self.patch.source_info_for_location(self.body, loc).span; + let false_ = self.constant_bool(span, false); + for flag in self.drop_flags.values() { + self.patch.add_assign(loc, Place::from(*flag), false_.clone()); + } + } + + fn drop_flags_for_fn_rets(&mut self) { + for (bb, data) in self.body.basic_blocks().iter_enumerated() { + if let TerminatorKind::Call { + destination, target: Some(tgt), cleanup: Some(_), .. 
+ } = data.terminator().kind + { + assert!(!self.patch.is_patched(bb)); + + let loc = Location { block: tgt, statement_index: 0 }; + let path = self.move_data().rev_lookup.find(destination.as_ref()); + on_lookup_result_bits(self.tcx, self.body, self.move_data(), path, |child| { + self.set_drop_flag(loc, child, DropFlagState::Present) + }); + } + } + } + + fn drop_flags_for_args(&mut self) { + let loc = Location::START; + rustc_mir_dataflow::drop_flag_effects_for_function_entry( + self.tcx, + self.body, + self.env, + |path, ds| { + self.set_drop_flag(loc, path, ds); + }, + ) + } + + fn drop_flags_for_locs(&mut self) { + // We intentionally iterate only over the *old* basic blocks. + // + // Basic blocks created by drop elaboration update their + // drop flags by themselves, to avoid the drop flags being + // clobbered before they are read. + + for (bb, data) in self.body.basic_blocks().iter_enumerated() { + debug!("drop_flags_for_locs({:?})", data); + for i in 0..(data.statements.len() + 1) { + debug!("drop_flag_for_locs: stmt {}", i); + let mut allow_initializations = true; + if i == data.statements.len() { + match data.terminator().kind { + TerminatorKind::Drop { .. } => { + // drop elaboration should handle that by itself + continue; + } + TerminatorKind::DropAndReplace { .. } => { + // this contains the move of the source and + // the initialization of the destination. We + // only want the former - the latter is handled + // by the elaboration code and must be done + // *after* the destination is dropped. + assert!(self.patch.is_patched(bb)); + allow_initializations = false; + } + TerminatorKind::Resume => { + // It is possible for `Resume` to be patched + // (in particular it can be patched to be replaced with + // a Goto; see `MirPatch::new`). + } + _ => { + assert!(!self.patch.is_patched(bb)); + } + } + } + let loc = Location { block: bb, statement_index: i }; + rustc_mir_dataflow::drop_flag_effects_for_location( + self.tcx, + self.body, + self.env, + loc, + |path, ds| { + if ds == DropFlagState::Absent || allow_initializations { + self.set_drop_flag(loc, path, ds) + } + }, + ) + } + + // There may be a critical edge after this call, + // so mark the return as initialized *before* the + // call. + if let TerminatorKind::Call { destination, target: Some(_), cleanup: None, .. 
} = + data.terminator().kind + { + assert!(!self.patch.is_patched(bb)); + + let loc = Location { block: bb, statement_index: data.statements.len() }; + let path = self.move_data().rev_lookup.find(destination.as_ref()); + on_lookup_result_bits(self.tcx, self.body, self.move_data(), path, |child| { + self.set_drop_flag(loc, child, DropFlagState::Present) + }); + } + } + } +} diff --git a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs new file mode 100644 index 000000000..7728fdaff --- /dev/null +++ b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs @@ -0,0 +1,170 @@ +use rustc_hir::def_id::{CrateNum, LocalDefId, LOCAL_CRATE}; +use rustc_middle::mir::*; +use rustc_middle::ty::layout; +use rustc_middle::ty::query::Providers; +use rustc_middle::ty::{self, TyCtxt}; +use rustc_session::lint::builtin::FFI_UNWIND_CALLS; +use rustc_target::spec::abi::Abi; +use rustc_target::spec::PanicStrategy; + +fn abi_can_unwind(abi: Abi) -> bool { + use Abi::*; + match abi { + C { unwind } + | System { unwind } + | Cdecl { unwind } + | Stdcall { unwind } + | Fastcall { unwind } + | Vectorcall { unwind } + | Thiscall { unwind } + | Aapcs { unwind } + | Win64 { unwind } + | SysV64 { unwind } => unwind, + PtxKernel + | Msp430Interrupt + | X86Interrupt + | AmdGpuKernel + | EfiApi + | AvrInterrupt + | AvrNonBlockingInterrupt + | CCmseNonSecureCall + | Wasm + | RustIntrinsic + | PlatformIntrinsic + | Unadjusted => false, + Rust | RustCall | RustCold => true, + } +} + +// Check if the body of this def_id can possibly leak a foreign unwind into Rust code. +fn has_ffi_unwind_calls(tcx: TyCtxt<'_>, local_def_id: LocalDefId) -> bool { + debug!("has_ffi_unwind_calls({local_def_id:?})"); + + // Only perform check on functions because constants cannot call FFI functions. + let def_id = local_def_id.to_def_id(); + let kind = tcx.def_kind(def_id); + if !kind.is_fn_like() { + return false; + } + + let body = &*tcx.mir_built(ty::WithOptConstParam::unknown(local_def_id)).borrow(); + + let body_ty = tcx.type_of(def_id); + let body_abi = match body_ty.kind() { + ty::FnDef(..) => body_ty.fn_sig(tcx).abi(), + ty::Closure(..) => Abi::RustCall, + ty::Generator(..) => Abi::Rust, + _ => span_bug!(body.span, "unexpected body ty: {:?}", body_ty), + }; + let body_can_unwind = layout::fn_can_unwind(tcx, Some(def_id), body_abi); + + // Foreign unwinds cannot leak past functions that themselves cannot unwind. + if !body_can_unwind { + return false; + } + + let mut tainted = false; + + for block in body.basic_blocks() { + if block.is_cleanup { + continue; + } + let Some(terminator) = &block.terminator else { continue }; + let TerminatorKind::Call { func, .. } = &terminator.kind else { continue }; + + let ty = func.ty(body, tcx); + let sig = ty.fn_sig(tcx); + + // Rust calls cannot themselves create foreign unwinds. + if let Abi::Rust | Abi::RustCall | Abi::RustCold = sig.abi() { + continue; + }; + + let fn_def_id = match ty.kind() { + ty::FnPtr(_) => None, + &ty::FnDef(def_id, _) => { + // Rust calls cannot themselves create foreign unwinds. + if !tcx.is_foreign_item(def_id) { + continue; + } + Some(def_id) + } + _ => bug!("invalid callee of type {:?}", ty), + }; + + if layout::fn_can_unwind(tcx, fn_def_id, sig.abi()) && abi_can_unwind(sig.abi()) { + // We have detected a call that can possibly leak foreign unwind. 
+ //
+ // Because the function body itself can unwind, we are not aborting this function call
+ // upon unwind, so this call can possibly leak foreign unwind into Rust code if the
+ // panic runtime linked is panic-abort.
+
+ let lint_root = body.source_scopes[terminator.source_info.scope]
+ .local_data
+ .as_ref()
+ .assert_crate_local()
+ .lint_root;
+ let span = terminator.source_info.span;
+
+ tcx.struct_span_lint_hir(FFI_UNWIND_CALLS, lint_root, span, |lint| {
+ let msg = match fn_def_id {
+ Some(_) => "call to foreign function with FFI-unwind ABI",
+ None => "call to function pointer with FFI-unwind ABI",
+ };
+ let mut db = lint.build(msg);
+ db.span_label(span, msg);
+ db.emit();
+ });
+
+ tainted = true;
+ }
+ }
+
+ tainted
+}
+
+fn required_panic_strategy(tcx: TyCtxt<'_>, cnum: CrateNum) -> Option<PanicStrategy> {
+ assert_eq!(cnum, LOCAL_CRATE);
+
+ if tcx.is_panic_runtime(LOCAL_CRATE) {
+ return Some(tcx.sess.panic_strategy());
+ }
+
+ if tcx.sess.panic_strategy() == PanicStrategy::Abort {
+ return Some(PanicStrategy::Abort);
+ }
+
+ for def_id in tcx.hir().body_owners() {
+ if tcx.has_ffi_unwind_calls(def_id) {
+ // Given that this crate is compiled with `-C panic=unwind`, the `AbortUnwindingCalls`
+ // MIR pass will not be run on FFI-unwind call sites, therefore a foreign exception
+ // can enter Rust through these sites.
+ //
+ // On the other hand, crates compiled with `-C panic=abort` expect that all Rust
+ // functions cannot unwind (whether the unwind is caused by a Rust panic or a foreign
+ // exception), and this expectation mismatch can cause unsoundness (#96926).
+ //
+ // To address this issue, we enforce that if FFI-unwind calls are used in a crate
+ // compiled with `panic=unwind`, then the final panic strategy must be `panic=unwind`.
+ // This ensures that no crate is built on a wrong unwindability assumption.
+ //
+ // It should be noted that it is okay to link `panic=unwind` code into a `panic=abort`
+ // program if it contains no FFI-unwind calls. In such a case a foreign exception can
+ // only enter Rust in a `panic=abort` crate, which will lead to an abort. There will
+ // also be no exceptions generated from Rust, so the assumption that `panic=abort`
+ // crates make, that no Rust function can unwind, indeed holds for crates compiled
+ // with `panic=unwind` as well. In such a case this function returns `None`,
+ // indicating that the crate does not require a particular final panic strategy, and
+ // can be freely linked to crates with either strategy (we need such ability for
+ // libstd and its dependencies).
+ return Some(PanicStrategy::Unwind);
+ }
+ }
+
+ // This crate can be linked with either runtime.
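+ //
+ // Illustrative summary (hypothetical crates): a `-C panic=unwind` crate that
+ // calls an `extern "C-unwind"` function returns `Some(PanicStrategy::Unwind)`
+ // in the loop above, while the same crate without FFI-unwind calls falls
+ // through to here and may be linked into either kind of program.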
+ None +} + +pub(crate) fn provide(providers: &mut Providers) { + *providers = Providers { has_ffi_unwind_calls, required_panic_strategy, ..*providers }; +} diff --git a/compiler/rustc_mir_transform/src/function_item_references.rs b/compiler/rustc_mir_transform/src/function_item_references.rs new file mode 100644 index 000000000..2e4fe1e3e --- /dev/null +++ b/compiler/rustc_mir_transform/src/function_item_references.rs @@ -0,0 +1,205 @@ +use itertools::Itertools; +use rustc_errors::Applicability; +use rustc_hir::def_id::DefId; +use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::*; +use rustc_middle::ty::{ + self, + subst::{GenericArgKind, Subst, SubstsRef}, + EarlyBinder, PredicateKind, Ty, TyCtxt, +}; +use rustc_session::lint::builtin::FUNCTION_ITEM_REFERENCES; +use rustc_span::{symbol::sym, Span}; +use rustc_target::spec::abi::Abi; + +use crate::MirLint; + +pub struct FunctionItemReferences; + +impl<'tcx> MirLint<'tcx> for FunctionItemReferences { + fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) { + let mut checker = FunctionItemRefChecker { tcx, body }; + checker.visit_body(&body); + } +} + +struct FunctionItemRefChecker<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + body: &'a Body<'tcx>, +} + +impl<'tcx> Visitor<'tcx> for FunctionItemRefChecker<'_, 'tcx> { + /// Emits a lint for function reference arguments bound by `fmt::Pointer` or passed to + /// `transmute`. This only handles arguments in calls outside macro expansions to avoid double + /// counting function references formatted as pointers by macros. + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + if let TerminatorKind::Call { + func, + args, + destination: _, + target: _, + cleanup: _, + from_hir_call: _, + fn_span: _, + } = &terminator.kind + { + let source_info = *self.body.source_info(location); + let func_ty = func.ty(self.body, self.tcx); + if let ty::FnDef(def_id, substs_ref) = *func_ty.kind() { + // Handle calls to `transmute` + if self.tcx.is_diagnostic_item(sym::transmute, def_id) { + let arg_ty = args[0].ty(self.body, self.tcx); + for generic_inner_ty in arg_ty.walk() { + if let GenericArgKind::Type(inner_ty) = generic_inner_ty.unpack() { + if let Some((fn_id, fn_substs)) = + FunctionItemRefChecker::is_fn_ref(inner_ty) + { + let span = self.nth_arg_span(&args, 0); + self.emit_lint(fn_id, fn_substs, source_info, span); + } + } + } + } else { + self.check_bound_args(def_id, substs_ref, &args, source_info); + } + } + } + self.super_terminator(terminator, location); + } +} + +impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { + /// Emits a lint for function reference arguments bound by `fmt::Pointer` in calls to the + /// function defined by `def_id` with the substitutions `substs_ref`. + fn check_bound_args( + &self, + def_id: DefId, + substs_ref: SubstsRef<'tcx>, + args: &[Operand<'tcx>], + source_info: SourceInfo, + ) { + let param_env = self.tcx.param_env(def_id); + let bounds = param_env.caller_bounds(); + for bound in bounds { + if let Some(bound_ty) = self.is_pointer_trait(&bound.kind().skip_binder()) { + // Get the argument types as they appear in the function signature. 
+ let arg_defs = self.tcx.fn_sig(def_id).skip_binder().inputs(); + for (arg_num, arg_def) in arg_defs.iter().enumerate() { + // For all types reachable from the argument type in the fn sig + for generic_inner_ty in arg_def.walk() { + if let GenericArgKind::Type(inner_ty) = generic_inner_ty.unpack() { + // If the inner type matches the type bound by `Pointer` + if inner_ty == bound_ty { + // Do a substitution using the parameters from the callsite + let subst_ty = EarlyBinder(inner_ty).subst(self.tcx, substs_ref); + if let Some((fn_id, fn_substs)) = + FunctionItemRefChecker::is_fn_ref(subst_ty) + { + let mut span = self.nth_arg_span(args, arg_num); + if span.from_expansion() { + // The operand's ctxt wouldn't display the lint since it's inside a macro so + // we have to use the callsite's ctxt. + let callsite_ctxt = span.source_callsite().ctxt(); + span = span.with_ctxt(callsite_ctxt); + } + self.emit_lint(fn_id, fn_substs, source_info, span); + } + } + } + } + } + } + } + } + + /// If the given predicate is the trait `fmt::Pointer`, returns the bound parameter type. + fn is_pointer_trait(&self, bound: &PredicateKind<'tcx>) -> Option<Ty<'tcx>> { + if let ty::PredicateKind::Trait(predicate) = bound { + if self.tcx.is_diagnostic_item(sym::Pointer, predicate.def_id()) { + Some(predicate.trait_ref.self_ty()) + } else { + None + } + } else { + None + } + } + + /// If a type is a reference or raw pointer to the anonymous type of a function definition, + /// returns that function's `DefId` and `SubstsRef`. + fn is_fn_ref(ty: Ty<'tcx>) -> Option<(DefId, SubstsRef<'tcx>)> { + let referent_ty = match ty.kind() { + ty::Ref(_, referent_ty, _) => Some(referent_ty), + ty::RawPtr(ty_and_mut) => Some(&ty_and_mut.ty), + _ => None, + }; + referent_ty + .map(|ref_ty| { + if let ty::FnDef(def_id, substs_ref) = *ref_ty.kind() { + Some((def_id, substs_ref)) + } else { + None + } + }) + .unwrap_or(None) + } + + fn nth_arg_span(&self, args: &[Operand<'tcx>], n: usize) -> Span { + match &args[n] { + Operand::Copy(place) | Operand::Move(place) => { + self.body.local_decls[place.local].source_info.span + } + Operand::Constant(constant) => constant.span, + } + } + + fn emit_lint( + &self, + fn_id: DefId, + fn_substs: SubstsRef<'tcx>, + source_info: SourceInfo, + span: Span, + ) { + let lint_root = self.body.source_scopes[source_info.scope] + .local_data + .as_ref() + .assert_crate_local() + .lint_root; + let fn_sig = self.tcx.fn_sig(fn_id); + let unsafety = fn_sig.unsafety().prefix_str(); + let abi = match fn_sig.abi() { + Abi::Rust => String::from(""), + other_abi => { + let mut s = String::from("extern \""); + s.push_str(other_abi.name()); + s.push_str("\" "); + s + } + }; + let ident = self.tcx.item_name(fn_id).to_ident_string(); + let ty_params = fn_substs.types().map(|ty| format!("{}", ty)); + let const_params = fn_substs.consts().map(|c| format!("{}", c)); + let params = ty_params.chain(const_params).join(", "); + let num_args = fn_sig.inputs().map_bound(|inputs| inputs.len()).skip_binder(); + let variadic = if fn_sig.c_variadic() { ", ..." 
} else { "" };
+ let ret = if fn_sig.output().skip_binder().is_unit() { "" } else { " -> _" };
+ self.tcx.struct_span_lint_hir(FUNCTION_ITEM_REFERENCES, lint_root, span, |lint| {
+ lint.build("taking a reference to a function item does not give a function pointer")
+ .span_suggestion(
+ span,
+ &format!("cast `{}` to obtain a function pointer", ident),
+ format!(
+ "{} as {}{}fn({}{}){}",
+ if params.is_empty() { ident } else { format!("{}::<{}>", ident, params) },
+ unsafety,
+ abi,
+ vec!["_"; num_args].join(", "),
+ variadic,
+ ret,
+ ),
+ Applicability::Unspecified,
+ )
+ .emit();
+ });
+ }
+}
diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs
new file mode 100644
index 000000000..91ecf3879
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/generator.rs
@@ -0,0 +1,1581 @@
+//! This is the implementation of the pass which transforms generators into state machines.
+//!
+//! MIR generation for generators creates a function which has a `self` argument that is
+//! passed by value. This argument is effectively a generator type which only contains upvars,
+//! and that type is only used for this argument inside the MIR for the generator.
+//! It is passed by value to enable upvars to be moved out of it. Drop elaboration runs on that
+//! MIR before this pass and creates drop flags for MIR locals.
+//! It will also drop the generator argument (which only consists of upvars) if any of the upvars
+//! are moved out of it. This pass elaborates the drops of the upvars / generator argument in the
+//! case that none of the upvars were moved out of it. This is because we cannot have any drops
+//! of this generator in the MIR, since it is used to create the drop glue for the generator.
+//! We'd get infinite recursion otherwise.
+//!
+//! This pass creates the implementation for the Generator::resume function and the drop shim
+//! for the generator based on the MIR input. It converts the generator argument from Self to
+//! &mut Self, adding derefs in the MIR as needed. It computes the final layout of the generator
+//! struct, which looks like this:
+//! First the upvars are stored,
+//! followed by the generator state field,
+//! and finally the MIR locals which are live across a suspension point.
+//! ```ignore (illustrative)
+//! struct Generator {
+//! upvars...,
+//! state: u32,
+//! mir_locals...,
+//! }
+//! ```
+//! This pass computes the meaning of the state field and the MIR locals which are live
+//! across a suspension point. There are however three hardcoded generator states:
+//! 0 - Generator has not been resumed yet
+//! 1 - Generator has returned / is completed
+//! 2 - Generator has been poisoned
+//!
+//! It also rewrites `return x` and `yield y` as setting a new generator state and returning
+//! GeneratorState::Complete(x) and GeneratorState::Yielded(y) respectively.
+//! MIR locals which are live across a suspension point are moved to the generator struct
+//! with references to them being updated with references to the generator struct.
+//!
+//! The pass creates two functions which have a switch on the generator state giving
+//! the action to take.
+//!
+//! One of them is the implementation of Generator::resume.
+//! For generators with state 0 (unresumed) it starts the execution of the generator.
+//! For generators with state 1 (returned) and state 2 (poisoned) it panics.
+//! Otherwise it continues the execution from the last suspension point.
+//!
+//! The other function is the drop glue for the generator.
+//! 
For generators with state 0 (unresumed) it drops the upvars of the generator. +//! For generators with state 1 (returned) and state 2 (poisoned) it does nothing. +//! Otherwise it drops all the values in scope at the last suspension point. + +use crate::deref_separator::deref_finder; +use crate::simplify; +use crate::util::expand_aggregate; +use crate::MirPass; +use rustc_data_structures::fx::FxHashMap; +use rustc_hir as hir; +use rustc_hir::lang_items::LangItem; +use rustc_index::bit_set::{BitMatrix, BitSet, GrowableBitSet}; +use rustc_index::vec::{Idx, IndexVec}; +use rustc_middle::mir::dump_mir; +use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::subst::{Subst, SubstsRef}; +use rustc_middle::ty::GeneratorSubsts; +use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt}; +use rustc_mir_dataflow::impls::{ + MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive, +}; +use rustc_mir_dataflow::storage::always_storage_live_locals; +use rustc_mir_dataflow::{self, Analysis}; +use rustc_target::abi::VariantIdx; +use rustc_target::spec::PanicStrategy; +use std::{iter, ops}; + +pub struct StateTransform; + +struct RenameLocalVisitor<'tcx> { + from: Local, + to: Local, + tcx: TyCtxt<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for RenameLocalVisitor<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) { + if *local == self.from { + *local = self.to; + } + } + + fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) { + match terminator.kind { + TerminatorKind::Return => { + // Do not replace the implicit `_0` access here, as that's not possible. The + // transform already handles `return` correctly. 
+ } + _ => self.super_terminator(terminator, location), + } + } +} + +struct DerefArgVisitor<'tcx> { + tcx: TyCtxt<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for DerefArgVisitor<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) { + assert_ne!(*local, SELF_ARG); + } + + fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { + if place.local == SELF_ARG { + replace_base( + place, + Place { + local: SELF_ARG, + projection: self.tcx().intern_place_elems(&[ProjectionElem::Deref]), + }, + self.tcx, + ); + } else { + self.visit_local(&mut place.local, context, location); + + for elem in place.projection.iter() { + if let PlaceElem::Index(local) = elem { + assert_ne!(local, SELF_ARG); + } + } + } + } +} + +struct PinArgVisitor<'tcx> { + ref_gen_ty: Ty<'tcx>, + tcx: TyCtxt<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for PinArgVisitor<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) { + assert_ne!(*local, SELF_ARG); + } + + fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { + if place.local == SELF_ARG { + replace_base( + place, + Place { + local: SELF_ARG, + projection: self.tcx().intern_place_elems(&[ProjectionElem::Field( + Field::new(0), + self.ref_gen_ty, + )]), + }, + self.tcx, + ); + } else { + self.visit_local(&mut place.local, context, location); + + for elem in place.projection.iter() { + if let PlaceElem::Index(local) = elem { + assert_ne!(local, SELF_ARG); + } + } + } + } +} + +fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtxt<'tcx>) { + place.local = new_base.local; + + let mut new_projection = new_base.projection.to_vec(); + new_projection.append(&mut place.projection.to_vec()); + + place.projection = tcx.intern_place_elems(&new_projection); +} + +const SELF_ARG: Local = Local::from_u32(1); + +/// Generator has not been resumed yet. +const UNRESUMED: usize = GeneratorSubsts::UNRESUMED; +/// Generator has returned / is completed. +const RETURNED: usize = GeneratorSubsts::RETURNED; +/// Generator has panicked and is poisoned. +const POISONED: usize = GeneratorSubsts::POISONED; + +/// Number of variants to reserve in generator state. Corresponds to +/// `UNRESUMED` (beginning of a generator) and `RETURNED`/`POISONED` +/// (end of a generator) states. +const RESERVED_VARIANTS: usize = 3; + +/// A `yield` point in the generator. +struct SuspensionPoint<'tcx> { + /// State discriminant used when suspending or resuming at this point. + state: usize, + /// The block to jump to after resumption. + resume: BasicBlock, + /// Where to move the resume argument after resumption. + resume_arg: Place<'tcx>, + /// Which block to jump to if the generator is dropped in this state. + drop: Option<BasicBlock>, + /// Set of locals that have live storage while at this suspension point. + storage_liveness: GrowableBitSet<Local>, +} + +struct TransformVisitor<'tcx> { + tcx: TyCtxt<'tcx>, + state_adt_ref: AdtDef<'tcx>, + state_substs: SubstsRef<'tcx>, + + // The type of the discriminant in the generator struct + discr_ty: Ty<'tcx>, + + // Mapping from Local to (type of local, generator struct index) + // FIXME(eddyb) This should use `IndexVec<Local, Option<_>>`. 
+ remap: FxHashMap<Local, (Ty<'tcx>, VariantIdx, usize)>,
+
+ // A map from a suspension point in a block to the locals which have live storage at that point
+ storage_liveness: IndexVec<BasicBlock, Option<BitSet<Local>>>,
+
+ // A list of suspension points, generated during the transform
+ suspension_points: Vec<SuspensionPoint<'tcx>>,
+
+ // The set of locals that have no `StorageLive`/`StorageDead` annotations.
+ always_live_locals: BitSet<Local>,
+
+ // The original RETURN_PLACE local
+ new_ret_local: Local,
+}
+
+impl<'tcx> TransformVisitor<'tcx> {
+ // Make a GeneratorState variant assignment. `core::ops::GeneratorState` only has single
+ // element tuple variants, so we can just write to the downcasted first field and then set the
+ // discriminant to the appropriate variant.
+ fn make_state(
+ &self,
+ idx: VariantIdx,
+ val: Operand<'tcx>,
+ source_info: SourceInfo,
+ ) -> impl Iterator<Item = Statement<'tcx>> {
+ let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_substs, None, None);
+ assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
+ let ty = self
+ .tcx
+ .bound_type_of(self.state_adt_ref.variant(idx).fields[0].did)
+ .subst(self.tcx, self.state_substs);
+ expand_aggregate(
+ Place::return_place(),
+ std::iter::once((val, ty)),
+ kind,
+ source_info,
+ self.tcx,
+ )
+ }
+
+ // Create a Place referencing a generator struct field
+ fn make_field(&self, variant_index: VariantIdx, idx: usize, ty: Ty<'tcx>) -> Place<'tcx> {
+ let self_place = Place::from(SELF_ARG);
+ let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index);
+ let mut projection = base.projection.to_vec();
+ projection.push(ProjectionElem::Field(Field::new(idx), ty));
+
+ Place { local: base.local, projection: self.tcx.intern_place_elems(&projection) }
+ }
+
+ // Create a statement which changes the discriminant
+ fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> {
+ let self_place = Place::from(SELF_ARG);
+ Statement {
+ source_info,
+ kind: StatementKind::SetDiscriminant {
+ place: Box::new(self_place),
+ variant_index: state_disc,
+ },
+ }
+ }
+
+ // Create a statement which reads the discriminant into a temporary
+ fn get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>) {
+ let temp_decl = LocalDecl::new(self.discr_ty, body.span).internal();
+ let local_decls_len = body.local_decls.push(temp_decl);
+ let temp = Place::from(local_decls_len);
+
+ let self_place = Place::from(SELF_ARG);
+ let assign = Statement {
+ source_info: SourceInfo::outermost(body.span),
+ kind: StatementKind::Assign(Box::new((temp, Rvalue::Discriminant(self_place)))),
+ };
+ (assign, temp)
+ }
+}
+
+impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
+ fn tcx(&self) -> TyCtxt<'tcx> {
+ self.tcx
+ }
+
+ fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
+ assert_eq!(self.remap.get(local), None);
+ }
+
+ fn visit_place(
+ &mut self,
+ place: &mut Place<'tcx>,
+ _context: PlaceContext,
+ _location: Location,
+ ) {
+ // Replace a local in the remap with a generator struct field access
+ if let Some(&(ty, variant_index, idx)) = self.remap.get(&place.local) {
+ replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
+ }
+ }
+
+ fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
+ // Remove StorageLive and StorageDead statements for remapped locals
+ data.retain_statements(|s| match s.kind {
+ StatementKind::StorageLive(l) | StatementKind::StorageDead(l) => { 
!self.remap.contains_key(&l) + } + _ => true, + }); + + let ret_val = match data.terminator().kind { + TerminatorKind::Return => Some(( + VariantIdx::new(1), + None, + Operand::Move(Place::from(self.new_ret_local)), + None, + )), + TerminatorKind::Yield { ref value, resume, resume_arg, drop } => { + Some((VariantIdx::new(0), Some((resume, resume_arg)), value.clone(), drop)) + } + _ => None, + }; + + if let Some((state_idx, resume, v, drop)) = ret_val { + let source_info = data.terminator().source_info; + // We must assign the value first in case it gets declared dead below + data.statements.extend(self.make_state(state_idx, v, source_info)); + let state = if let Some((resume, mut resume_arg)) = resume { + // Yield + let state = RESERVED_VARIANTS + self.suspension_points.len(); + + // The resume arg target location might itself be remapped if its base local is + // live across a yield. + let resume_arg = + if let Some(&(ty, variant, idx)) = self.remap.get(&resume_arg.local) { + replace_base(&mut resume_arg, self.make_field(variant, idx, ty), self.tcx); + resume_arg + } else { + resume_arg + }; + + self.suspension_points.push(SuspensionPoint { + state, + resume, + resume_arg, + drop, + storage_liveness: self.storage_liveness[block].clone().unwrap().into(), + }); + + VariantIdx::new(state) + } else { + // Return + VariantIdx::new(RETURNED) // state for returned + }; + data.statements.push(self.set_discr(state, source_info)); + data.terminator_mut().kind = TerminatorKind::Return; + } + + self.super_basic_block_data(block, data); + } +} + +fn make_generator_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let gen_ty = body.local_decls.raw[1].ty; + + let ref_gen_ty = + tcx.mk_ref(tcx.lifetimes.re_erased, ty::TypeAndMut { ty: gen_ty, mutbl: Mutability::Mut }); + + // Replace the by value generator argument + body.local_decls.raw[1].ty = ref_gen_ty; + + // Add a deref to accesses of the generator state + DerefArgVisitor { tcx }.visit_body(body); +} + +fn make_generator_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let ref_gen_ty = body.local_decls.raw[1].ty; + + let pin_did = tcx.require_lang_item(LangItem::Pin, Some(body.span)); + let pin_adt_ref = tcx.adt_def(pin_did); + let substs = tcx.intern_substs(&[ref_gen_ty.into()]); + let pin_ref_gen_ty = tcx.mk_adt(pin_adt_ref, substs); + + // Replace the by ref generator argument + body.local_decls.raw[1].ty = pin_ref_gen_ty; + + // Add the Pin field access to accesses of the generator state + PinArgVisitor { ref_gen_ty, tcx }.visit_body(body); +} + +/// Allocates a new local and replaces all references of `local` with it. Returns the new local. +/// +/// `local` will be changed to a new local decl with type `ty`. +/// +/// Note that the new local will be uninitialized. It is the caller's responsibility to assign some +/// valid value to it before its first use. +fn replace_local<'tcx>( + local: Local, + ty: Ty<'tcx>, + body: &mut Body<'tcx>, + tcx: TyCtxt<'tcx>, +) -> Local { + let new_decl = LocalDecl::new(ty, body.span); + let new_local = body.local_decls.push(new_decl); + body.local_decls.swap(local, new_local); + + RenameLocalVisitor { from: local, to: new_local, tcx }.visit_body(body); + + new_local +} + +struct LivenessInfo { + /// Which locals are live across any suspension point. + saved_locals: GeneratorSavedLocals, + + /// The set of saved locals live at each suspension point. 
+ live_locals_at_suspension_points: Vec<BitSet<GeneratorSavedLocal>>,
+
+ /// Parallel vec to the above with SourceInfo for each yield terminator.
+ source_info_at_suspension_points: Vec<SourceInfo>,
+
+ /// For every saved local, the set of other saved locals that are
+ /// storage-live at the same time as this local. We cannot overlap locals in
+ /// the layout which have conflicting storage.
+ storage_conflicts: BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal>,
+
+ /// For every suspending block, the locals which are storage-live across
+ /// that suspension point.
+ storage_liveness: IndexVec<BasicBlock, Option<BitSet<Local>>>,
+}
+
+fn locals_live_across_suspend_points<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ body: &Body<'tcx>,
+ always_live_locals: &BitSet<Local>,
+ movable: bool,
+) -> LivenessInfo {
+ let body_ref: &Body<'_> = &body;
+
+ // Calculate when MIR locals have live storage. This gives us an upper bound of their
+ // lifetimes.
+ let mut storage_live = MaybeStorageLive::new(always_live_locals.clone())
+ .into_engine(tcx, body_ref)
+ .iterate_to_fixpoint()
+ .into_results_cursor(body_ref);
+
+ // Calculate the MIR locals which have been previously
+ // borrowed (even if they are still active).
+ let borrowed_locals_results =
+ MaybeBorrowedLocals.into_engine(tcx, body_ref).pass_name("generator").iterate_to_fixpoint();
+
+ let mut borrowed_locals_cursor =
+ rustc_mir_dataflow::ResultsCursor::new(body_ref, &borrowed_locals_results);
+
+ // Calculate the MIR locals that we actually need to keep storage around
+ // for.
+ let requires_storage_results = MaybeRequiresStorage::new(body, &borrowed_locals_results)
+ .into_engine(tcx, body_ref)
+ .iterate_to_fixpoint();
+ let mut requires_storage_cursor =
+ rustc_mir_dataflow::ResultsCursor::new(body_ref, &requires_storage_results);
+
+ // Calculate the liveness of MIR locals ignoring borrows.
+ let mut liveness = MaybeLiveLocals
+ .into_engine(tcx, body_ref)
+ .pass_name("generator")
+ .iterate_to_fixpoint()
+ .into_results_cursor(body_ref);
+
+ let mut storage_liveness_map = IndexVec::from_elem(None, body.basic_blocks());
+ let mut live_locals_at_suspension_points = Vec::new();
+ let mut source_info_at_suspension_points = Vec::new();
+ let mut live_locals_at_any_suspension_point = BitSet::new_empty(body.local_decls.len());
+
+ for (block, data) in body.basic_blocks().iter_enumerated() {
+ if let TerminatorKind::Yield { .. } = data.terminator().kind {
+ let loc = Location { block, statement_index: data.statements.len() };
+
+ liveness.seek_to_block_end(block);
+ let mut live_locals: BitSet<_> = BitSet::new_empty(body.local_decls.len());
+ live_locals.union(liveness.get());
+
+ if !movable {
+ // The `liveness` variable contains the liveness of MIR locals ignoring borrows.
+ // This is correct for movable generators since borrows cannot live across
+ // suspension points. However for immovable generators we need to account for
+ // borrows, so we conservatively assume that all borrowed locals are live until
+ // we find a StorageDead statement referencing the locals.
+ // To do this we just union our `liveness` result with `borrowed_locals`, which
+ // contains all the locals that have been borrowed before this suspension point.
+ // If a borrow is converted to a raw reference, we must also assume that it lives
+ // forever. Note that the final liveness is still bounded by the storage liveness
+ // of the local, which happens using the `intersect` operation below. 
+ borrowed_locals_cursor.seek_before_primary_effect(loc);
+ live_locals.union(borrowed_locals_cursor.get());
+ }
+
+ // Store the storage liveness for later use so we can restore the state
+ // after a suspension point
+ storage_live.seek_before_primary_effect(loc);
+ storage_liveness_map[block] = Some(storage_live.get().clone());
+
+ // Locals are live at this point only if they are used across
+ // suspension points (the `liveness` variable)
+ // and their storage is required (the `requires_storage` cursor)
+ requires_storage_cursor.seek_before_primary_effect(loc);
+ live_locals.intersect(requires_storage_cursor.get());
+
+ // The generator argument is ignored.
+ live_locals.remove(SELF_ARG);
+
+ debug!("loc = {:?}, live_locals = {:?}", loc, live_locals);
+
+ // Add the locals live at this suspension point to the set of locals that live across
+ // any suspension point
+ live_locals_at_any_suspension_point.union(&live_locals);
+
+ live_locals_at_suspension_points.push(live_locals);
+ source_info_at_suspension_points.push(data.terminator().source_info);
+ }
+ }
+
+ debug!("live_locals_anywhere = {:?}", live_locals_at_any_suspension_point);
+ let saved_locals = GeneratorSavedLocals(live_locals_at_any_suspension_point);
+
+ // Renumber our liveness_map bitsets to include only the locals we are
+ // saving.
+ let live_locals_at_suspension_points = live_locals_at_suspension_points
+ .iter()
+ .map(|live_here| saved_locals.renumber_bitset(&live_here))
+ .collect();
+
+ let storage_conflicts = compute_storage_conflicts(
+ body_ref,
+ &saved_locals,
+ always_live_locals.clone(),
+ requires_storage_results,
+ );
+
+ LivenessInfo {
+ saved_locals,
+ live_locals_at_suspension_points,
+ source_info_at_suspension_points,
+ storage_conflicts,
+ storage_liveness: storage_liveness_map,
+ }
+}
+
+/// The set of `Local`s that must be saved across yield points.
+///
+/// `GeneratorSavedLocal` is indexed in terms of the elements in this set;
+/// i.e. `GeneratorSavedLocal::new(1)` corresponds to the second local
+/// included in this set.
+struct GeneratorSavedLocals(BitSet<Local>);
+
+impl GeneratorSavedLocals {
+ /// Returns an iterator over each `GeneratorSavedLocal` along with the `Local` it corresponds
+ /// to.
+ fn iter_enumerated(&self) -> impl '_ + Iterator<Item = (GeneratorSavedLocal, Local)> {
+ self.iter().enumerate().map(|(i, l)| (GeneratorSavedLocal::from(i), l))
+ }
+
+ /// Transforms a `BitSet<Local>` that contains only locals saved across yield points to the
+ /// equivalent `BitSet<GeneratorSavedLocal>`.
+ fn renumber_bitset(&self, input: &BitSet<Local>) -> BitSet<GeneratorSavedLocal> {
+ assert!(self.superset(&input), "{:?} not a superset of {:?}", self.0, input);
+ let mut out = BitSet::new_empty(self.count());
+ for (saved_local, local) in self.iter_enumerated() {
+ if input.contains(local) {
+ out.insert(saved_local);
+ }
+ }
+ out
+ }
+
+ fn get(&self, local: Local) -> Option<GeneratorSavedLocal> {
+ if !self.contains(local) {
+ return None;
+ }
+
+ let idx = self.iter().take_while(|&l| l < local).count();
+ Some(GeneratorSavedLocal::new(idx))
+ }
+}
+
+impl ops::Deref for GeneratorSavedLocals {
+ type Target = BitSet<Local>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+/// For every saved local, looks for which locals are StorageLive at the same
+/// time. Generates a bitset for every local of all the other locals that may be
+/// StorageLive simultaneously with that local. This is used in the layout
+/// computation; see `GeneratorLayout` for more. 
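+///
+/// A small illustrative example (hypothetical saved locals `a` and `b`): if
+/// their storage is ever live at the same program point, the matrix records
+/// the conflict symmetrically and the layout will not overlap their fields:
+/// ```ignore (illustrative)
+/// // storage_conflicts.contains(a, b) == true
+/// // storage_conflicts.contains(b, a) == true
+/// ```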
+fn compute_storage_conflicts<'mir, 'tcx>( + body: &'mir Body<'tcx>, + saved_locals: &GeneratorSavedLocals, + always_live_locals: BitSet<Local>, + requires_storage: rustc_mir_dataflow::Results<'tcx, MaybeRequiresStorage<'mir, 'tcx>>, +) -> BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal> { + assert_eq!(body.local_decls.len(), saved_locals.domain_size()); + + debug!("compute_storage_conflicts({:?})", body.span); + debug!("always_live = {:?}", always_live_locals); + + // Locals that are always live or ones that need to be stored across + // suspension points are not eligible for overlap. + let mut ineligible_locals = always_live_locals; + ineligible_locals.intersect(&**saved_locals); + + // Compute the storage conflicts for all eligible locals. + let mut visitor = StorageConflictVisitor { + body, + saved_locals: &saved_locals, + local_conflicts: BitMatrix::from_row_n(&ineligible_locals, body.local_decls.len()), + }; + + requires_storage.visit_reachable_with(body, &mut visitor); + + let local_conflicts = visitor.local_conflicts; + + // Compress the matrix using only stored locals (Local -> GeneratorSavedLocal). + // + // NOTE: Today we store a full conflict bitset for every local. Technically + // this is twice as many bits as we need, since the relation is symmetric. + // However, in practice these bitsets are not usually large. The layout code + // also needs to keep track of how many conflicts each local has, so it's + // simpler to keep it this way for now. + let mut storage_conflicts = BitMatrix::new(saved_locals.count(), saved_locals.count()); + for (saved_local_a, local_a) in saved_locals.iter_enumerated() { + if ineligible_locals.contains(local_a) { + // Conflicts with everything. + storage_conflicts.insert_all_into_row(saved_local_a); + } else { + // Keep overlap information only for stored locals. + for (saved_local_b, local_b) in saved_locals.iter_enumerated() { + if local_conflicts.contains(local_a, local_b) { + storage_conflicts.insert(saved_local_a, saved_local_b); + } + } + } + } + storage_conflicts +} + +struct StorageConflictVisitor<'mir, 'tcx, 's> { + body: &'mir Body<'tcx>, + saved_locals: &'s GeneratorSavedLocals, + // FIXME(tmandry): Consider using sparse bitsets here once we have good + // benchmarks for generators. + local_conflicts: BitMatrix<Local, Local>, +} + +impl<'mir, 'tcx> rustc_mir_dataflow::ResultsVisitor<'mir, 'tcx> + for StorageConflictVisitor<'mir, 'tcx, '_> +{ + type FlowState = BitSet<Local>; + + fn visit_statement_before_primary_effect( + &mut self, + state: &Self::FlowState, + _statement: &'mir Statement<'tcx>, + loc: Location, + ) { + self.apply_state(state, loc); + } + + fn visit_terminator_before_primary_effect( + &mut self, + state: &Self::FlowState, + _terminator: &'mir Terminator<'tcx>, + loc: Location, + ) { + self.apply_state(state, loc); + } +} + +impl StorageConflictVisitor<'_, '_, '_> { + fn apply_state(&mut self, flow_state: &BitSet<Local>, loc: Location) { + // Ignore unreachable blocks. 
+ if self.body.basic_blocks()[loc.block].terminator().kind == TerminatorKind::Unreachable { + return; + } + + let mut eligible_storage_live = flow_state.clone(); + eligible_storage_live.intersect(&**self.saved_locals); + + for local in eligible_storage_live.iter() { + self.local_conflicts.union_row_with(&eligible_storage_live, local); + } + + if eligible_storage_live.count() > 1 { + trace!("at {:?}, eligible_storage_live={:?}", loc, eligible_storage_live); + } + } +} + +/// Validates the typeck view of the generator against the actual set of types saved between +/// yield points. +fn sanitize_witness<'tcx>( + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, + witness: Ty<'tcx>, + upvars: Vec<Ty<'tcx>>, + saved_locals: &GeneratorSavedLocals, +) { + let did = body.source.def_id(); + let param_env = tcx.param_env(did); + + let allowed_upvars = tcx.normalize_erasing_regions(param_env, upvars); + let allowed = match witness.kind() { + &ty::GeneratorWitness(interior_tys) => { + tcx.normalize_erasing_late_bound_regions(param_env, interior_tys) + } + _ => { + tcx.sess.delay_span_bug( + body.span, + &format!("unexpected generator witness type {:?}", witness.kind()), + ); + return; + } + }; + + for (local, decl) in body.local_decls.iter_enumerated() { + // Ignore locals which are internal or not saved between yields. + if !saved_locals.contains(local) || decl.internal { + continue; + } + let decl_ty = tcx.normalize_erasing_regions(param_env, decl.ty); + + // Sanity check that typeck knows about the type of locals which are + // live across a suspension point + if !allowed.contains(&decl_ty) && !allowed_upvars.contains(&decl_ty) { + span_bug!( + body.span, + "Broken MIR: generator contains type {} in MIR, \ + but typeck only knows about {} and {:?}", + decl_ty, + allowed, + allowed_upvars + ); + } + } +} + +fn compute_layout<'tcx>( + liveness: LivenessInfo, + body: &mut Body<'tcx>, +) -> ( + FxHashMap<Local, (Ty<'tcx>, VariantIdx, usize)>, + GeneratorLayout<'tcx>, + IndexVec<BasicBlock, Option<BitSet<Local>>>, +) { + let LivenessInfo { + saved_locals, + live_locals_at_suspension_points, + source_info_at_suspension_points, + storage_conflicts, + storage_liveness, + } = liveness; + + // Gather live local types and their indices. + let mut locals = IndexVec::<GeneratorSavedLocal, _>::new(); + let mut tys = IndexVec::<GeneratorSavedLocal, _>::new(); + for (saved_local, local) in saved_locals.iter_enumerated() { + locals.push(local); + tys.push(body.local_decls[local].ty); + debug!("generator saved local {:?} => {:?}", saved_local, local); + } + + // Leave empty variants for the UNRESUMED, RETURNED, and POISONED states. + // In debuginfo, these will correspond to the beginning (UNRESUMED) or end + // (RETURNED, POISONED) of the function. + let body_span = body.source_scopes[OUTERMOST_SOURCE_SCOPE].span; + let mut variant_source_info: IndexVec<VariantIdx, SourceInfo> = [ + SourceInfo::outermost(body_span.shrink_to_lo()), + SourceInfo::outermost(body_span.shrink_to_hi()), + SourceInfo::outermost(body_span.shrink_to_hi()), + ] + .iter() + .copied() + .collect(); + + // Build the generator variant field list. + // Create a map from local indices to generator struct indices. 
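+ // For example (illustrative, made-up local and field index): if `_4: i32` is
+ // live at the first suspension point, it is stored in variant
+ // `RESERVED_VARIANTS + 0` at some field index `k`, so
+ // `remap[_4] == (i32, VariantIdx(3), k)`.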
+ let mut variant_fields: IndexVec<VariantIdx, IndexVec<Field, GeneratorSavedLocal>> = + iter::repeat(IndexVec::new()).take(RESERVED_VARIANTS).collect(); + let mut remap = FxHashMap::default(); + for (suspension_point_idx, live_locals) in live_locals_at_suspension_points.iter().enumerate() { + let variant_index = VariantIdx::from(RESERVED_VARIANTS + suspension_point_idx); + let mut fields = IndexVec::new(); + for (idx, saved_local) in live_locals.iter().enumerate() { + fields.push(saved_local); + // Note that if a field is included in multiple variants, we will + // just use the first one here. That's fine; fields do not move + // around inside generators, so it doesn't matter which variant + // index we access them by. + remap.entry(locals[saved_local]).or_insert((tys[saved_local], variant_index, idx)); + } + variant_fields.push(fields); + variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]); + } + debug!("generator variant_fields = {:?}", variant_fields); + debug!("generator storage_conflicts = {:#?}", storage_conflicts); + + let layout = + GeneratorLayout { field_tys: tys, variant_fields, variant_source_info, storage_conflicts }; + + (remap, layout, storage_liveness) +} + +/// Replaces the entry point of `body` with a block that switches on the generator discriminant and +/// dispatches to blocks according to `cases`. +/// +/// After this function, the former entry point of the function will be bb1. +fn insert_switch<'tcx>( + body: &mut Body<'tcx>, + cases: Vec<(usize, BasicBlock)>, + transform: &TransformVisitor<'tcx>, + default: TerminatorKind<'tcx>, +) { + let default_block = insert_term_block(body, default); + let (assign, discr) = transform.get_discr(body); + let switch_targets = + SwitchTargets::new(cases.iter().map(|(i, bb)| ((*i) as u128, *bb)), default_block); + let switch = TerminatorKind::SwitchInt { + discr: Operand::Move(discr), + switch_ty: transform.discr_ty, + targets: switch_targets, + }; + + let source_info = SourceInfo::outermost(body.span); + body.basic_blocks_mut().raw.insert( + 0, + BasicBlockData { + statements: vec![assign], + terminator: Some(Terminator { source_info, kind: switch }), + is_cleanup: false, + }, + ); + + let blocks = body.basic_blocks_mut().iter_mut(); + + for target in blocks.flat_map(|b| b.terminator_mut().successors_mut()) { + *target = BasicBlock::new(target.index() + 1); + } +} + +fn elaborate_generator_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + use crate::shim::DropShimElaborator; + use rustc_middle::mir::patch::MirPatch; + use rustc_mir_dataflow::elaborate_drops::{elaborate_drop, Unwind}; + + // Note that `elaborate_drops` only drops the upvars of a generator, and + // this is ok because `open_drop` can only be reached within that own + // generator's resume function. 
+ + let def_id = body.source.def_id(); + let param_env = tcx.param_env(def_id); + + let mut elaborator = DropShimElaborator { body, patch: MirPatch::new(body), tcx, param_env }; + + for (block, block_data) in body.basic_blocks().iter_enumerated() { + let (target, unwind, source_info) = match block_data.terminator() { + Terminator { source_info, kind: TerminatorKind::Drop { place, target, unwind } } => { + if let Some(local) = place.as_local() { + if local == SELF_ARG { + (target, unwind, source_info) + } else { + continue; + } + } else { + continue; + } + } + _ => continue, + }; + let unwind = if block_data.is_cleanup { + Unwind::InCleanup + } else { + Unwind::To(unwind.unwrap_or_else(|| elaborator.patch.resume_block())) + }; + elaborate_drop( + &mut elaborator, + *source_info, + Place::from(SELF_ARG), + (), + *target, + unwind, + block, + ); + } + elaborator.patch.apply(body); +} + +fn create_generator_drop_shim<'tcx>( + tcx: TyCtxt<'tcx>, + transform: &TransformVisitor<'tcx>, + gen_ty: Ty<'tcx>, + body: &mut Body<'tcx>, + drop_clean: BasicBlock, +) -> Body<'tcx> { + let mut body = body.clone(); + body.arg_count = 1; // make sure the resume argument is not included here + + let source_info = SourceInfo::outermost(body.span); + + let mut cases = create_cases(&mut body, transform, Operation::Drop); + + cases.insert(0, (UNRESUMED, drop_clean)); + + // The returned state and the poisoned state fall through to the default + // case which is just to return + + insert_switch(&mut body, cases, &transform, TerminatorKind::Return); + + for block in body.basic_blocks_mut() { + let kind = &mut block.terminator_mut().kind; + if let TerminatorKind::GeneratorDrop = *kind { + *kind = TerminatorKind::Return; + } + } + + // Replace the return variable + body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(tcx.mk_unit(), source_info); + + make_generator_state_argument_indirect(tcx, &mut body); + + // Change the generator argument from &mut to *mut + body.local_decls[SELF_ARG] = LocalDecl::with_source_info( + tcx.mk_ptr(ty::TypeAndMut { ty: gen_ty, mutbl: hir::Mutability::Mut }), + source_info, + ); + if tcx.sess.opts.unstable_opts.mir_emit_retag { + // Alias tracking must know we changed the type + body.basic_blocks_mut()[START_BLOCK].statements.insert( + 0, + Statement { + source_info, + kind: StatementKind::Retag(RetagKind::Raw, Box::new(Place::from(SELF_ARG))), + }, + ) + } + + // Make sure we remove dead blocks to remove + // unrelated code from the resume part of the function + simplify::remove_dead_blocks(tcx, &mut body); + + dump_mir(tcx, None, "generator_drop", &0, &body, |_, _| Ok(())); + + body +} + +fn insert_term_block<'tcx>(body: &mut Body<'tcx>, kind: TerminatorKind<'tcx>) -> BasicBlock { + let source_info = SourceInfo::outermost(body.span); + body.basic_blocks_mut().push(BasicBlockData { + statements: Vec::new(), + terminator: Some(Terminator { source_info, kind }), + is_cleanup: false, + }) +} + +fn insert_panic_block<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + message: AssertMessage<'tcx>, +) -> BasicBlock { + let assert_block = BasicBlock::new(body.basic_blocks().len()); + let term = TerminatorKind::Assert { + cond: Operand::Constant(Box::new(Constant { + span: body.span, + user_ty: None, + literal: ConstantKind::from_bool(tcx, false), + })), + expected: true, + msg: message, + target: assert_block, + cleanup: None, + }; + + let source_info = SourceInfo::outermost(body.span); + body.basic_blocks_mut().push(BasicBlockData { + statements: Vec::new(), + terminator: 
Some(Terminator { source_info, kind: term }), + is_cleanup: false, + }); + + assert_block +} + +fn can_return<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, param_env: ty::ParamEnv<'tcx>) -> bool { + // Returning from a function with an uninhabited return type is undefined behavior. + if tcx.conservative_is_privately_uninhabited(param_env.and(body.return_ty())) { + return false; + } + + // If there's a return terminator the function may return. + for block in body.basic_blocks() { + if let TerminatorKind::Return = block.terminator().kind { + return true; + } + } + + // Otherwise the function can't return. + false +} + +fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool { + // Nothing can unwind when landing pads are off. + if tcx.sess.panic_strategy() == PanicStrategy::Abort { + return false; + } + + // Unwinds can only start at certain terminators. + for block in body.basic_blocks() { + match block.terminator().kind { + // These never unwind. + TerminatorKind::Goto { .. } + | TerminatorKind::SwitchInt { .. } + | TerminatorKind::Abort + | TerminatorKind::Return + | TerminatorKind::Unreachable + | TerminatorKind::GeneratorDrop + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } => {} + + // Resume will *continue* unwinding, but if there's no other unwinding terminator it + // will never be reached. + TerminatorKind::Resume => {} + + TerminatorKind::Yield { .. } => { + unreachable!("`can_unwind` called before generator transform") + } + + // These may unwind. + TerminatorKind::Drop { .. } + | TerminatorKind::DropAndReplace { .. } + | TerminatorKind::Call { .. } + | TerminatorKind::InlineAsm { .. } + | TerminatorKind::Assert { .. } => return true, + } + } + + // If we didn't find an unwinding terminator, the function cannot unwind. + false +} + +fn create_generator_resume_function<'tcx>( + tcx: TyCtxt<'tcx>, + transform: TransformVisitor<'tcx>, + body: &mut Body<'tcx>, + can_return: bool, +) { + let can_unwind = can_unwind(tcx, body); + + // Poison the generator when it unwinds + if can_unwind { + let source_info = SourceInfo::outermost(body.span); + let poison_block = body.basic_blocks_mut().push(BasicBlockData { + statements: vec![transform.set_discr(VariantIdx::new(POISONED), source_info)], + terminator: Some(Terminator { source_info, kind: TerminatorKind::Resume }), + is_cleanup: true, + }); + + for (idx, block) in body.basic_blocks_mut().iter_enumerated_mut() { + let source_info = block.terminator().source_info; + + if let TerminatorKind::Resume = block.terminator().kind { + // An existing `Resume` terminator is redirected to jump to our dedicated + // "poisoning block" above. + if idx != poison_block { + *block.terminator_mut() = Terminator { + source_info, + kind: TerminatorKind::Goto { target: poison_block }, + }; + } + } else if !block.is_cleanup { + // Any terminators that *can* unwind but don't have an unwind target set are also + // pointed at our poisoning block (unless they're part of the cleanup path). 
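+
+ // A self-contained sketch of the rewiring performed below, over toy
+ // blocks (`ToyBlockInfo` and the function name are invented): every
+ // non-cleanup terminator that can unwind but has no unwind target yet is
+ // pointed at the poisoning block.
+ struct ToyBlockInfo {
+     is_cleanup: bool,
+     // `None` = cannot unwind at all; `Some(None)` = can unwind, target unset.
+     unwind: Option<Option<usize>>,
+ }
+
+ fn sketch_point_unset_unwinds_at(blocks: &mut [ToyBlockInfo], poison_block: usize) {
+     for b in blocks.iter_mut().filter(|b| !b.is_cleanup) {
+         if let Some(unwind @ None) = b.unwind.as_mut() {
+             *unwind = Some(poison_block);
+         }
+     }
+ }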
+                if let Some(unwind @ None) = block.terminator_mut().unwind_mut() {
+                    *unwind = Some(poison_block);
+                }
+            }
+        }
+    }
+
+    let mut cases = create_cases(body, &transform, Operation::Resume);
+
+    use rustc_middle::mir::AssertKind::{ResumedAfterPanic, ResumedAfterReturn};
+
+    // Jump to the entry point on the unresumed state
+    cases.insert(0, (UNRESUMED, BasicBlock::new(0)));
+
+    // Panic when resumed on the returned or poisoned state
+    let generator_kind = body.generator_kind().unwrap();
+
+    if can_unwind {
+        cases.insert(
+            1,
+            (POISONED, insert_panic_block(tcx, body, ResumedAfterPanic(generator_kind))),
+        );
+    }
+
+    if can_return {
+        cases.insert(
+            1,
+            (RETURNED, insert_panic_block(tcx, body, ResumedAfterReturn(generator_kind))),
+        );
+    }
+
+    insert_switch(body, cases, &transform, TerminatorKind::Unreachable);
+
+    make_generator_state_argument_indirect(tcx, body);
+    make_generator_state_argument_pinned(tcx, body);
+
+    // Make sure we remove dead blocks to remove
+    // unrelated code from the drop part of the function
+    simplify::remove_dead_blocks(tcx, body);
+
+    dump_mir(tcx, None, "generator_resume", &0, body, |_, _| Ok(()));
+}
+
+fn insert_clean_drop(body: &mut Body<'_>) -> BasicBlock {
+    let return_block = insert_term_block(body, TerminatorKind::Return);
+
+    let term =
+        TerminatorKind::Drop { place: Place::from(SELF_ARG), target: return_block, unwind: None };
+    let source_info = SourceInfo::outermost(body.span);
+
+    // Create a block to destroy an unresumed generator. This can only destroy upvars.
+    body.basic_blocks_mut().push(BasicBlockData {
+        statements: Vec::new(),
+        terminator: Some(Terminator { source_info, kind: term }),
+        is_cleanup: false,
+    })
+}
+
+/// An operation that can be performed on a generator.
+#[derive(PartialEq, Copy, Clone)]
+enum Operation {
+    Resume,
+    Drop,
+}
+
+impl Operation {
+    fn target_block(self, point: &SuspensionPoint<'_>) -> Option<BasicBlock> {
+        match self {
+            Operation::Resume => Some(point.resume),
+            Operation::Drop => point.drop,
+        }
+    }
+}
+
+fn create_cases<'tcx>(
+    body: &mut Body<'tcx>,
+    transform: &TransformVisitor<'tcx>,
+    operation: Operation,
+) -> Vec<(usize, BasicBlock)> {
+    let tcx = transform.tcx;
+
+    let source_info = SourceInfo::outermost(body.span);
+
+    transform
+        .suspension_points
+        .iter()
+        .filter_map(|point| {
+            // Find the target for this suspension point, if applicable
+            operation.target_block(point).map(|target| {
+                let mut statements = Vec::new();
+
+                // Create StorageLive instructions for locals with live storage
+                for i in 0..(body.local_decls.len()) {
+                    if i == 2 {
+                        // The resume argument is live on function entry. Don't insert a
+                        // `StorageLive`, or the following `Assign` will read from uninitialized
+                        // memory.
+ continue; + } + + let l = Local::new(i); + let needs_storage_live = point.storage_liveness.contains(l) + && !transform.remap.contains_key(&l) + && !transform.always_live_locals.contains(l); + if needs_storage_live { + statements + .push(Statement { source_info, kind: StatementKind::StorageLive(l) }); + } + } + + if operation == Operation::Resume { + // Move the resume argument to the destination place of the `Yield` terminator + let resume_arg = Local::new(2); // 0 = return, 1 = self + + // handle `box yield` properly + let box_place = if let [projection @ .., ProjectionElem::Deref] = + &**point.resume_arg.projection + { + let box_place = + Place::from(point.resume_arg.local).project_deeper(projection, tcx); + + let box_ty = box_place.ty(&body.local_decls, tcx).ty; + + if box_ty.is_box() { Some((box_place, box_ty)) } else { None } + } else { + None + }; + + if let Some((box_place, box_ty)) = box_place { + let unique_did = box_ty + .ty_adt_def() + .expect("expected Box to be an Adt") + .non_enum_variant() + .fields[0] + .did; + + let Some(nonnull_def) = tcx.type_of(unique_did).ty_adt_def() else { + span_bug!(tcx.def_span(unique_did), "expected Box to contain Unique") + }; + + let nonnull_did = nonnull_def.non_enum_variant().fields[0].did; + + let (unique_ty, nonnull_ty, ptr_ty) = + crate::elaborate_box_derefs::build_ptr_tys( + tcx, + box_ty.boxed_ty(), + unique_did, + nonnull_did, + ); + + let ptr_local = body.local_decls.push(LocalDecl::new(ptr_ty, body.span)); + + statements.push(Statement { + source_info, + kind: StatementKind::StorageLive(ptr_local), + }); + + statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + Place::from(ptr_local), + Rvalue::Use(Operand::Copy(box_place.project_deeper( + &crate::elaborate_box_derefs::build_projection( + unique_ty, nonnull_ty, ptr_ty, + ), + tcx, + ))), + ))), + }); + + statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + Place::from(ptr_local) + .project_deeper(&[ProjectionElem::Deref], tcx), + Rvalue::Use(Operand::Move(resume_arg.into())), + ))), + }); + + statements.push(Statement { + source_info, + kind: StatementKind::StorageDead(ptr_local), + }); + } else { + statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + point.resume_arg, + Rvalue::Use(Operand::Move(resume_arg.into())), + ))), + }); + } + } + + // Then jump to the real target + let block = body.basic_blocks_mut().push(BasicBlockData { + statements, + terminator: Some(Terminator { + source_info, + kind: TerminatorKind::Goto { target }, + }), + is_cleanup: false, + }); + + (point.state, block) + }) + }) + .collect() +} + +impl<'tcx> MirPass<'tcx> for StateTransform { + fn phase_change(&self) -> Option<MirPhase> { + Some(MirPhase::GeneratorsLowered) + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let Some(yield_ty) = body.yield_ty() else { + // This only applies to generators + return; + }; + + assert!(body.generator_drop().is_none()); + + // The first argument is the generator type passed by value + let gen_ty = body.local_decls.raw[1].ty; + + // Get the interior types and substs which typeck computed + let (upvars, interior, discr_ty, movable) = match *gen_ty.kind() { + ty::Generator(_, substs, movability) => { + let substs = substs.as_generator(); + ( + substs.upvar_tys().collect(), + substs.witness(), + substs.discr_ty(tcx), + movability == hir::Movability::Movable, + ) + } + _ => { + tcx.sess + .delay_span_bug(body.span, &format!("unexpected generator type {}", 
gen_ty)); + return; + } + }; + + // Compute GeneratorState<yield_ty, return_ty> + let state_did = tcx.require_lang_item(LangItem::GeneratorState, None); + let state_adt_ref = tcx.adt_def(state_did); + let state_substs = tcx.intern_substs(&[yield_ty.into(), body.return_ty().into()]); + let ret_ty = tcx.mk_adt(state_adt_ref, state_substs); + + // We rename RETURN_PLACE which has type mir.return_ty to new_ret_local + // RETURN_PLACE then is a fresh unused local with type ret_ty. + let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx); + + // We also replace the resume argument and insert an `Assign`. + // This is needed because the resume argument `_2` might be live across a `yield`, in which + // case there is no `Assign` to it that the transform can turn into a store to the generator + // state. After the yield the slot in the generator state would then be uninitialized. + let resume_local = Local::new(2); + let new_resume_local = + replace_local(resume_local, body.local_decls[resume_local].ty, body, tcx); + + // When first entering the generator, move the resume argument into its new local. + let source_info = SourceInfo::outermost(body.span); + let stmts = &mut body.basic_blocks_mut()[BasicBlock::new(0)].statements; + stmts.insert( + 0, + Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + new_resume_local.into(), + Rvalue::Use(Operand::Move(resume_local.into())), + ))), + }, + ); + + let always_live_locals = always_storage_live_locals(&body); + + let liveness_info = + locals_live_across_suspend_points(tcx, body, &always_live_locals, movable); + + sanitize_witness(tcx, body, interior, upvars, &liveness_info.saved_locals); + + if tcx.sess.opts.unstable_opts.validate_mir { + let mut vis = EnsureGeneratorFieldAssignmentsNeverAlias { + assigned_local: None, + saved_locals: &liveness_info.saved_locals, + storage_conflicts: &liveness_info.storage_conflicts, + }; + + vis.visit_body(body); + } + + // Extract locals which are live across suspension point into `layout` + // `remap` gives a mapping from local indices onto generator struct indices + // `storage_liveness` tells us which locals have live storage at suspension points + let (remap, layout, storage_liveness) = compute_layout(liveness_info, body); + + let can_return = can_return(tcx, body, tcx.param_env(body.source.def_id())); + + // Run the transformation which converts Places from Local to generator struct + // accesses for locals in `remap`. + // It also rewrites `return x` and `yield y` as writing a new generator state and returning + // GeneratorState::Complete(x) and GeneratorState::Yielded(y) respectively. + let mut transform = TransformVisitor { + tcx, + state_adt_ref, + state_substs, + remap, + storage_liveness, + always_live_locals, + suspension_points: Vec::new(), + new_ret_local, + discr_ty, + }; + transform.visit_body(body); + + // Update our MIR struct to reflect the changes we've made + body.arg_count = 2; // self, resume arg + body.spread_arg = None; + + body.generator.as_mut().unwrap().yield_ty = None; + body.generator.as_mut().unwrap().generator_layout = Some(layout); + + // Insert `drop(generator_struct)` which is used to drop upvars for generators in + // the unresumed state. + // This is expanded to a drop ladder in `elaborate_generator_drops`. + let drop_clean = insert_clean_drop(body); + + dump_mir(tcx, None, "generator_pre-elab", &0, body, |_, _| Ok(())); + + // Expand `drop(generator_struct)` to a drop ladder which destroys upvars. 
+ // If any upvars are moved out of, drop elaboration will handle upvar destruction. + // However we need to also elaborate the code generated by `insert_clean_drop`. + elaborate_generator_drops(tcx, body); + + dump_mir(tcx, None, "generator_post-transform", &0, body, |_, _| Ok(())); + + // Create a copy of our MIR and use it to create the drop shim for the generator + let drop_shim = create_generator_drop_shim(tcx, &transform, gen_ty, body, drop_clean); + + body.generator.as_mut().unwrap().generator_drop = Some(drop_shim); + + // Create the Generator::resume function + create_generator_resume_function(tcx, transform, body, can_return); + + // Run derefer to fix Derefs that are not in the first place + deref_finder(tcx, body); + } +} + +/// Looks for any assignments between locals (e.g., `_4 = _5`) that will both be converted to fields +/// in the generator state machine but whose storage is not marked as conflicting +/// +/// Validation needs to happen immediately *before* `TransformVisitor` is invoked, not after. +/// +/// This condition would arise when the assignment is the last use of `_5` but the initial +/// definition of `_4` if we weren't extra careful to mark all locals used inside a statement as +/// conflicting. Non-conflicting generator saved locals may be stored at the same location within +/// the generator state machine, which would result in ill-formed MIR: the left-hand and right-hand +/// sides of an assignment may not alias. This caused a miscompilation in [#73137]. +/// +/// [#73137]: https://github.com/rust-lang/rust/issues/73137 +struct EnsureGeneratorFieldAssignmentsNeverAlias<'a> { + saved_locals: &'a GeneratorSavedLocals, + storage_conflicts: &'a BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal>, + assigned_local: Option<GeneratorSavedLocal>, +} + +impl EnsureGeneratorFieldAssignmentsNeverAlias<'_> { + fn saved_local_for_direct_place(&self, place: Place<'_>) -> Option<GeneratorSavedLocal> { + if place.is_indirect() { + return None; + } + + self.saved_locals.get(place.local) + } + + fn check_assigned_place(&mut self, place: Place<'_>, f: impl FnOnce(&mut Self)) { + if let Some(assigned_local) = self.saved_local_for_direct_place(place) { + assert!(self.assigned_local.is_none(), "`check_assigned_place` must not recurse"); + + self.assigned_local = Some(assigned_local); + f(self); + self.assigned_local = None; + } + } +} + +impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> { + fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) { + let Some(lhs) = self.assigned_local else { + // This visitor only invokes `visit_place` for the right-hand side of an assignment + // and only after setting `self.assigned_local`. However, the default impl of + // `Visitor::super_body` may call `visit_place` with a `NonUseContext` for places + // with debuginfo. Ignore them here. + assert!(!context.is_use()); + return; + }; + + let Some(rhs) = self.saved_local_for_direct_place(*place) else { return }; + + if !self.storage_conflicts.contains(lhs, rhs) { + bug!( + "Assignment between generator saved locals whose storage is not \ + marked as conflicting: {:?}: {:?} = {:?}", + location, + lhs, + rhs, + ); + } + } + + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + match &statement.kind { + StatementKind::Assign(box (lhs, rhs)) => { + self.check_assigned_place(*lhs, |this| this.visit_rvalue(rhs, location)); + } + + StatementKind::FakeRead(..) + | StatementKind::SetDiscriminant { .. 
} + | StatementKind::Deinit(..) + | StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Retag(..) + | StatementKind::AscribeUserType(..) + | StatementKind::Coverage(..) + | StatementKind::CopyNonOverlapping(..) + | StatementKind::Nop => {} + } + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + // Checking for aliasing in terminators is probably overkill, but until we have actual + // semantics, we should be conservative here. + match &terminator.kind { + TerminatorKind::Call { + func, + args, + destination, + target: Some(_), + cleanup: _, + from_hir_call: _, + fn_span: _, + } => { + self.check_assigned_place(*destination, |this| { + this.visit_operand(func, location); + for arg in args { + this.visit_operand(arg, location); + } + }); + } + + TerminatorKind::Yield { value, resume: _, resume_arg, drop: _ } => { + self.check_assigned_place(*resume_arg, |this| this.visit_operand(value, location)); + } + + // FIXME: Does `asm!` have any aliasing requirements? + TerminatorKind::InlineAsm { .. } => {} + + TerminatorKind::Call { .. } + | TerminatorKind::Goto { .. } + | TerminatorKind::SwitchInt { .. } + | TerminatorKind::Resume + | TerminatorKind::Abort + | TerminatorKind::Return + | TerminatorKind::Unreachable + | TerminatorKind::Drop { .. } + | TerminatorKind::DropAndReplace { .. } + | TerminatorKind::Assert { .. } + | TerminatorKind::GeneratorDrop + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } => {} + } + } +} diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs new file mode 100644 index 000000000..76b1522f3 --- /dev/null +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -0,0 +1,1006 @@ +//! Inlining pass for MIR functions +use crate::deref_separator::deref_finder; +use rustc_attr::InlineAttr; +use rustc_const_eval::transform::validate::equal_up_to_regions; +use rustc_index::bit_set::BitSet; +use rustc_index::vec::Idx; +use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}; +use rustc_middle::mir::visit::*; +use rustc_middle::mir::*; +use rustc_middle::ty::subst::Subst; +use rustc_middle::ty::{self, ConstKind, Instance, InstanceDef, ParamEnv, Ty, TyCtxt}; +use rustc_span::{hygiene::ExpnKind, ExpnData, LocalExpnId, Span}; +use rustc_target::spec::abi::Abi; + +use super::simplify::{remove_dead_blocks, CfgSimplifier}; +use crate::MirPass; +use std::iter; +use std::ops::{Range, RangeFrom}; + +pub(crate) mod cycle; + +const INSTR_COST: usize = 5; +const CALL_PENALTY: usize = 25; +const LANDINGPAD_PENALTY: usize = 50; +const RESUME_PENALTY: usize = 45; + +const UNKNOWN_SIZE_COST: usize = 10; + +pub struct Inline; + +#[derive(Copy, Clone, Debug)] +struct CallSite<'tcx> { + callee: Instance<'tcx>, + fn_sig: ty::PolyFnSig<'tcx>, + block: BasicBlock, + target: Option<BasicBlock>, + source_info: SourceInfo, +} + +impl<'tcx> MirPass<'tcx> for Inline { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + if let Some(enabled) = sess.opts.unstable_opts.inline_mir { + return enabled; + } + + // rust-lang/rust#101004: reverted to old inlining decision logic + sess.mir_opt_level() >= 3 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let span = trace_span!("inline", body = %tcx.def_path_str(body.source.def_id())); + let _guard = span.enter(); + if inline(tcx, body) { + debug!("running simplify cfg on {:?}", body.source); + CfgSimplifier::new(body).simplify(); + remove_dead_blocks(tcx, body); + 
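+
+ // For reference, a standalone sketch of the `is_enabled` gate above, with
+ // plain values standing in for the `Session` (the function name is
+ // invented): an explicit `-Zinline-mir` setting always wins; otherwise
+ // MIR inlining only runs at mir-opt-level 3 or higher.
+ fn sketch_inlining_enabled(inline_mir_flag: Option<bool>, mir_opt_level: usize) -> bool {
+     if let Some(enabled) = inline_mir_flag {
+         return enabled;
+     }
+     mir_opt_level >= 3
+ }
+ // e.g. sketch_inlining_enabled(None, 2) == false, while
+ //      sketch_inlining_enabled(Some(true), 0) == true.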
deref_finder(tcx, body); + } + } +} + +fn inline<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool { + let def_id = body.source.def_id().expect_local(); + + // Only do inlining into fn bodies. + if !tcx.hir().body_owner_kind(def_id).is_fn_or_closure() { + return false; + } + if body.source.promoted.is_some() { + return false; + } + // Avoid inlining into generators, since their `optimized_mir` is used for layout computation, + // which can create a cycle, even when no attempt is made to inline the function in the other + // direction. + if body.generator.is_some() { + return false; + } + + let param_env = tcx.param_env_reveal_all_normalized(def_id); + + let mut this = Inliner { + tcx, + param_env, + codegen_fn_attrs: tcx.codegen_fn_attrs(def_id), + history: Vec::new(), + changed: false, + }; + let blocks = BasicBlock::new(0)..body.basic_blocks().next_index(); + this.process_blocks(body, blocks); + this.changed +} + +struct Inliner<'tcx> { + tcx: TyCtxt<'tcx>, + param_env: ParamEnv<'tcx>, + /// Caller codegen attributes. + codegen_fn_attrs: &'tcx CodegenFnAttrs, + /// Stack of inlined Instances. + history: Vec<ty::Instance<'tcx>>, + /// Indicates that the caller body has been modified. + changed: bool, +} + +impl<'tcx> Inliner<'tcx> { + fn process_blocks(&mut self, caller_body: &mut Body<'tcx>, blocks: Range<BasicBlock>) { + for bb in blocks { + let bb_data = &caller_body[bb]; + if bb_data.is_cleanup { + continue; + } + + let Some(callsite) = self.resolve_callsite(caller_body, bb, bb_data) else { + continue; + }; + + let span = trace_span!("process_blocks", %callsite.callee, ?bb); + let _guard = span.enter(); + + match self.try_inlining(caller_body, &callsite) { + Err(reason) => { + debug!("not-inlined {} [{}]", callsite.callee, reason); + continue; + } + Ok(new_blocks) => { + debug!("inlined {}", callsite.callee); + self.changed = true; + self.history.push(callsite.callee); + self.process_blocks(caller_body, new_blocks); + self.history.pop(); + } + } + } + } + + /// Attempts to inline a callsite into the caller body. When successful returns basic blocks + /// containing the inlined body. Otherwise returns an error describing why inlining didn't take + /// place. + fn try_inlining( + &self, + caller_body: &mut Body<'tcx>, + callsite: &CallSite<'tcx>, + ) -> Result<std::ops::Range<BasicBlock>, &'static str> { + let callee_attrs = self.tcx.codegen_fn_attrs(callsite.callee.def_id()); + self.check_codegen_attributes(callsite, callee_attrs)?; + self.check_mir_is_available(caller_body, &callsite.callee)?; + let callee_body = self.tcx.instance_mir(callsite.callee.def); + self.check_mir_body(callsite, callee_body, callee_attrs)?; + + if !self.tcx.consider_optimizing(|| { + format!("Inline {:?} into {:?}", callsite.callee, caller_body.source) + }) { + return Err("optimization fuel exhausted"); + } + + let Ok(callee_body) = callsite.callee.try_subst_mir_and_normalize_erasing_regions( + self.tcx, + self.param_env, + callee_body.clone(), + ) else { + return Err("failed to normalize callee body"); + }; + + // Check call signature compatibility. + // Normally, this shouldn't be required, but trait normalization failure can create a + // validation ICE. + let terminator = caller_body[callsite.block].terminator.as_ref().unwrap(); + let TerminatorKind::Call { args, destination, .. 
} = &terminator.kind else { bug!() };
+        let destination_ty = destination.ty(&caller_body.local_decls, self.tcx).ty;
+        let output_type = callee_body.return_ty();
+        if !equal_up_to_regions(self.tcx, self.param_env, output_type, destination_ty) {
+            trace!(?output_type, ?destination_ty);
+            return Err("failed to normalize return type");
+        }
+        if callsite.fn_sig.abi() == Abi::RustCall {
+            let (arg_tuple, skipped_args) = match &args[..] {
+                [arg_tuple] => (arg_tuple, 0),
+                [_, arg_tuple] => (arg_tuple, 1),
+                _ => bug!("Expected `rust-call` to have 1 or 2 args"),
+            };
+
+            let arg_tuple_ty = arg_tuple.ty(&caller_body.local_decls, self.tcx);
+            let ty::Tuple(arg_tuple_tys) = arg_tuple_ty.kind() else {
+                bug!("Closure arguments are not passed as a tuple");
+            };
+
+            for (arg_ty, input) in
+                arg_tuple_tys.iter().zip(callee_body.args_iter().skip(skipped_args))
+            {
+                let input_type = callee_body.local_decls[input].ty;
+                if !equal_up_to_regions(self.tcx, self.param_env, arg_ty, input_type) {
+                    trace!(?arg_ty, ?input_type);
+                    return Err("failed to normalize tuple argument type");
+                }
+            }
+        } else {
+            for (arg, input) in args.iter().zip(callee_body.args_iter()) {
+                let input_type = callee_body.local_decls[input].ty;
+                let arg_ty = arg.ty(&caller_body.local_decls, self.tcx);
+                if !equal_up_to_regions(self.tcx, self.param_env, arg_ty, input_type) {
+                    trace!(?arg_ty, ?input_type);
+                    return Err("failed to normalize argument type");
+                }
+            }
+        }
+
+        let old_blocks = caller_body.basic_blocks().next_index();
+        self.inline_call(caller_body, &callsite, callee_body);
+        let new_blocks = old_blocks..caller_body.basic_blocks().next_index();
+
+        Ok(new_blocks)
+    }
+
+    fn check_mir_is_available(
+        &self,
+        caller_body: &Body<'tcx>,
+        callee: &Instance<'tcx>,
+    ) -> Result<(), &'static str> {
+        let caller_def_id = caller_body.source.def_id();
+        let callee_def_id = callee.def_id();
+        if callee_def_id == caller_def_id {
+            return Err("self-recursion");
+        }
+
+        match callee.def {
+            InstanceDef::Item(_) => {
+                // If there is no MIR available (either because it was not in metadata or
+                // because it has no MIR because it's an extern function), then the inliner
+                // won't cause cycles on this.
+                if !self.tcx.is_mir_available(callee_def_id) {
+                    return Err("item MIR unavailable");
+                }
+            }
+            // These have no callable MIR of their own.
+            InstanceDef::Intrinsic(_) | InstanceDef::Virtual(..) => {
+                return Err("instance without MIR (intrinsic / virtual)");
+            }
+            // This cannot result in an immediate cycle since the callee MIR is a shim, which does
+            // not get any optimizations run on it. Any subsequent inlining may cause cycles, but we
+            // do not need to catch this here; we can wait until the inliner decides to continue
+            // inlining a second time.
+            InstanceDef::VTableShim(_)
+            | InstanceDef::ReifyShim(_)
+            | InstanceDef::FnPtrShim(..)
+            | InstanceDef::ClosureOnceShim { .. }
+            | InstanceDef::DropGlue(..)
+            | InstanceDef::CloneShim(..) => return Ok(()),
+        }
+
+        if self.tcx.is_constructor(callee_def_id) {
+            trace!("constructors always have MIR");
+            // Constructor functions cannot cause a query cycle.
+            return Ok(());
+        }
+
+        if callee_def_id.is_local() {
+            // Avoid a cycle here by using `instance_mir` only if we have
+            // a lower `DefPathHash` than the callee. This ensures that the callee will
+            // not inline us. This trick even works with incremental compilation,
+            // since `DefPathHash` is stable.
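+
+ // A toy model of the ordering trick described above (invented name, bare
+ // `u64` values in place of `DefPathHash`): for any pair of functions at
+ // most one of `h(a) < h(b)` and `h(b) < h(a)` holds, so two local
+ // functions can never each decide to inline the other. Longer cycles are
+ // caught by the reachability query used below.
+ fn sketch_hash_order_allows_inlining(caller_hash: u64, callee_hash: u64) -> bool {
+     caller_hash < callee_hash
+ }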
+ if self.tcx.def_path_hash(caller_def_id).local_hash() + < self.tcx.def_path_hash(callee_def_id).local_hash() + { + return Ok(()); + } + + // If we know for sure that the function we're calling will itself try to + // call us, then we avoid inlining that function. + if self.tcx.mir_callgraph_reachable((*callee, caller_def_id.expect_local())) { + return Err("caller might be reachable from callee (query cycle avoidance)"); + } + + Ok(()) + } else { + // This cannot result in an immediate cycle since the callee MIR is from another crate + // and is already optimized. Any subsequent inlining may cause cycles, but we do + // not need to catch this here, we can wait until the inliner decides to continue + // inlining a second time. + trace!("functions from other crates always have MIR"); + Ok(()) + } + } + + fn resolve_callsite( + &self, + caller_body: &Body<'tcx>, + bb: BasicBlock, + bb_data: &BasicBlockData<'tcx>, + ) -> Option<CallSite<'tcx>> { + // Only consider direct calls to functions + let terminator = bb_data.terminator(); + if let TerminatorKind::Call { ref func, target, .. } = terminator.kind { + let func_ty = func.ty(caller_body, self.tcx); + if let ty::FnDef(def_id, substs) = *func_ty.kind() { + // To resolve an instance its substs have to be fully normalized. + let substs = self.tcx.try_normalize_erasing_regions(self.param_env, substs).ok()?; + let callee = + Instance::resolve(self.tcx, self.param_env, def_id, substs).ok().flatten()?; + + if let InstanceDef::Virtual(..) | InstanceDef::Intrinsic(_) = callee.def { + return None; + } + + if self.history.contains(&callee) { + return None; + } + + let fn_sig = self.tcx.bound_fn_sig(def_id).subst(self.tcx, substs); + + return Some(CallSite { + callee, + fn_sig, + block: bb, + target, + source_info: terminator.source_info, + }); + } + } + + None + } + + /// Returns an error if inlining is not possible based on codegen attributes alone. A success + /// indicates that inlining decision should be based on other criteria. + fn check_codegen_attributes( + &self, + callsite: &CallSite<'tcx>, + callee_attrs: &CodegenFnAttrs, + ) -> Result<(), &'static str> { + match callee_attrs.inline { + InlineAttr::Never => return Err("never inline hint"), + InlineAttr::Always | InlineAttr::Hint => {} + InlineAttr::None => { + if self.tcx.sess.mir_opt_level() <= 2 { + return Err("at mir-opt-level=2, only #[inline] is inlined"); + } + } + } + + // Only inline local functions if they would be eligible for cross-crate + // inlining. 
This is to ensure that the final crate doesn't have MIR that
+        // references unexported symbols
+        if callsite.callee.def_id().is_local() {
+            let is_generic = callsite.callee.substs.non_erasable_generics().next().is_some();
+            if !is_generic && !callee_attrs.requests_inline() {
+                return Err("not exported");
+            }
+        }
+
+        if callsite.fn_sig.c_variadic() {
+            return Err("C variadic");
+        }
+
+        if callee_attrs.flags.contains(CodegenFnAttrFlags::NAKED) {
+            return Err("naked");
+        }
+
+        if callee_attrs.flags.contains(CodegenFnAttrFlags::COLD) {
+            return Err("cold");
+        }
+
+        if callee_attrs.no_sanitize != self.codegen_fn_attrs.no_sanitize {
+            return Err("incompatible sanitizer set");
+        }
+
+        if callee_attrs.instruction_set != self.codegen_fn_attrs.instruction_set {
+            return Err("incompatible instruction set");
+        }
+
+        for feature in &callee_attrs.target_features {
+            if !self.codegen_fn_attrs.target_features.contains(feature) {
+                return Err("incompatible target feature");
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Returns the inlining decision based on an examination of the callee MIR body.
+    /// Assumes that codegen attributes have been checked for compatibility already.
+    #[instrument(level = "debug", skip(self, callee_body))]
+    fn check_mir_body(
+        &self,
+        callsite: &CallSite<'tcx>,
+        callee_body: &Body<'tcx>,
+        callee_attrs: &CodegenFnAttrs,
+    ) -> Result<(), &'static str> {
+        let tcx = self.tcx;
+
+        let mut threshold = if callee_attrs.requests_inline() {
+            self.tcx.sess.opts.unstable_opts.inline_mir_hint_threshold.unwrap_or(100)
+        } else {
+            self.tcx.sess.opts.unstable_opts.inline_mir_threshold.unwrap_or(50)
+        };
+
+        // Give a bonus to functions with a small number of blocks; we normally
+        // have two or three blocks for even very small functions.
+        if callee_body.basic_blocks().len() <= 3 {
+            threshold += threshold / 4;
+        }
+        debug!(" final inline threshold = {}", threshold);
+
+        // FIXME: Give a bonus to functions with only a single caller
+        let mut first_block = true;
+        let mut cost = 0;
+
+        // Traverse the MIR manually so we can account for the effects of
+        // inlining on the CFG.
+        let mut work_list = vec![START_BLOCK];
+        let mut visited = BitSet::new_empty(callee_body.basic_blocks().len());
+        while let Some(bb) = work_list.pop() {
+            if !visited.insert(bb.index()) {
+                continue;
+            }
+            let blk = &callee_body.basic_blocks()[bb];
+
+            for stmt in &blk.statements {
+                // Don't count StorageLive/StorageDead in the inlining cost.
+                match stmt.kind {
+                    StatementKind::StorageLive(_)
+                    | StatementKind::StorageDead(_)
+                    | StatementKind::Deinit(_)
+                    | StatementKind::Nop => {}
+                    _ => cost += INSTR_COST,
+                }
+            }
+            let term = blk.terminator();
+            let mut is_drop = false;
+            match term.kind {
+                TerminatorKind::Drop { ref place, target, unwind }
+                | TerminatorKind::DropAndReplace { ref place, target, unwind, .. } => {
+                    is_drop = true;
+                    work_list.push(target);
+                    // If the place doesn't actually need dropping, treat it like
+                    // a regular goto.
+                    let ty = callsite.callee.subst_mir(self.tcx, &place.ty(callee_body, tcx).ty);
+                    if ty.needs_drop(tcx, self.param_env) {
+                        cost += CALL_PENALTY;
+                        if let Some(unwind) = unwind {
+                            cost += LANDINGPAD_PENALTY;
+                            work_list.push(unwind);
+                        }
+                    } else {
+                        cost += INSTR_COST;
+                    }
+                }
+
+                TerminatorKind::Unreachable | TerminatorKind::Call { target: None, .. }
+                    if first_block =>
+                {
+                    // If the function always diverges, don't inline
+                    // unless the cost is zero
+                    threshold = 0;
+                }
+
+                TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, ..
} => { + if let ty::FnDef(def_id, _) = + *callsite.callee.subst_mir(self.tcx, &f.literal.ty()).kind() + { + // Don't give intrinsics the extra penalty for calls + if tcx.is_intrinsic(def_id) { + cost += INSTR_COST; + } else { + cost += CALL_PENALTY; + } + } else { + cost += CALL_PENALTY; + } + if cleanup.is_some() { + cost += LANDINGPAD_PENALTY; + } + } + TerminatorKind::Assert { cleanup, .. } => { + cost += CALL_PENALTY; + + if cleanup.is_some() { + cost += LANDINGPAD_PENALTY; + } + } + TerminatorKind::Resume => cost += RESUME_PENALTY, + TerminatorKind::InlineAsm { cleanup, .. } => { + cost += INSTR_COST; + + if cleanup.is_some() { + cost += LANDINGPAD_PENALTY; + } + } + _ => cost += INSTR_COST, + } + + if !is_drop { + for succ in term.successors() { + work_list.push(succ); + } + } + + first_block = false; + } + + // Count up the cost of local variables and temps, if we know the size + // use that, otherwise we use a moderately-large dummy cost. + + let ptr_size = tcx.data_layout.pointer_size.bytes(); + + for v in callee_body.vars_and_temps_iter() { + let ty = callsite.callee.subst_mir(self.tcx, &callee_body.local_decls[v].ty); + // Cost of the var is the size in machine-words, if we know + // it. + if let Some(size) = type_size_of(tcx, self.param_env, ty) { + cost += ((size + ptr_size - 1) / ptr_size) as usize; + } else { + cost += UNKNOWN_SIZE_COST; + } + } + + if let InlineAttr::Always = callee_attrs.inline { + debug!("INLINING {:?} because inline(always) [cost={}]", callsite, cost); + Ok(()) + } else if cost <= threshold { + debug!("INLINING {:?} [cost={} <= threshold={}]", callsite, cost, threshold); + Ok(()) + } else { + debug!("NOT inlining {:?} [cost={} > threshold={}]", callsite, cost, threshold); + Err("cost above threshold") + } + } + + fn inline_call( + &self, + caller_body: &mut Body<'tcx>, + callsite: &CallSite<'tcx>, + mut callee_body: Body<'tcx>, + ) { + let terminator = caller_body[callsite.block].terminator.take().unwrap(); + match terminator.kind { + TerminatorKind::Call { args, destination, cleanup, .. } => { + // If the call is something like `a[*i] = f(i)`, where + // `i : &mut usize`, then just duplicating the `a[*i]` + // Place could result in two different locations if `f` + // writes to `i`. To prevent this we need to create a temporary + // borrow of the place and pass the destination as `*temp` instead. + fn dest_needs_borrow(place: Place<'_>) -> bool { + for elem in place.projection.iter() { + match elem { + ProjectionElem::Deref | ProjectionElem::Index(_) => return true, + _ => {} + } + } + + false + } + + let dest = if dest_needs_borrow(destination) { + trace!("creating temp for return destination"); + let dest = Rvalue::Ref( + self.tcx.lifetimes.re_erased, + BorrowKind::Mut { allow_two_phase_borrow: false }, + destination, + ); + let dest_ty = dest.ty(caller_body, self.tcx); + let temp = Place::from(self.new_call_temp(caller_body, &callsite, dest_ty)); + caller_body[callsite.block].statements.push(Statement { + source_info: callsite.source_info, + kind: StatementKind::Assign(Box::new((temp, dest))), + }); + self.tcx.mk_place_deref(temp) + } else { + destination + }; + + // Copy the arguments if needed. 
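+
+ // A hedged illustration of the temporary-borrow trick that
+ // `dest_needs_borrow` above guards (all names invented): if the
+ // destination is `a[*i]` and the inlined body writes to `i`, computing
+ // the destination place twice would pick two different slots. Borrowing
+ // the place once up front and writing through the borrow pins the
+ // destination.
+ fn sketch_pin_destination(a: &mut [i32; 4], i: &mut usize, value: i32) {
+     let temp = &mut a[*i]; // computed once, before the callee runs
+     *i += 1; // stand-in for the inlined callee mutating `*i`
+     *temp = value; // still writes the originally selected slot
+ }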
+ let args: Vec<_> = self.make_call_args(args, &callsite, caller_body, &callee_body); + + let mut expn_data = ExpnData::default( + ExpnKind::Inlined, + callsite.source_info.span, + self.tcx.sess.edition(), + None, + None, + ); + expn_data.def_site = callee_body.span; + let expn_data = + self.tcx.with_stable_hashing_context(|hcx| LocalExpnId::fresh(expn_data, hcx)); + let mut integrator = Integrator { + args: &args, + new_locals: Local::new(caller_body.local_decls.len()).., + new_scopes: SourceScope::new(caller_body.source_scopes.len()).., + new_blocks: BasicBlock::new(caller_body.basic_blocks().len()).., + destination: dest, + callsite_scope: caller_body.source_scopes[callsite.source_info.scope].clone(), + callsite, + cleanup_block: cleanup, + in_cleanup_block: false, + tcx: self.tcx, + expn_data, + always_live_locals: BitSet::new_filled(callee_body.local_decls.len()), + }; + + // Map all `Local`s, `SourceScope`s and `BasicBlock`s to new ones + // (or existing ones, in a few special cases) in the caller. + integrator.visit_body(&mut callee_body); + + // If there are any locals without storage markers, give them storage only for the + // duration of the call. + for local in callee_body.vars_and_temps_iter() { + if integrator.always_live_locals.contains(local) { + let new_local = integrator.map_local(local); + caller_body[callsite.block].statements.push(Statement { + source_info: callsite.source_info, + kind: StatementKind::StorageLive(new_local), + }); + } + } + if let Some(block) = callsite.target { + // To avoid repeated O(n) insert, push any new statements to the end and rotate + // the slice once. + let mut n = 0; + for local in callee_body.vars_and_temps_iter().rev() { + if integrator.always_live_locals.contains(local) { + let new_local = integrator.map_local(local); + caller_body[block].statements.push(Statement { + source_info: callsite.source_info, + kind: StatementKind::StorageDead(new_local), + }); + n += 1; + } + } + caller_body[block].statements.rotate_right(n); + } + + // Insert all of the (mapped) parts of the callee body into the caller. + caller_body.local_decls.extend(callee_body.drain_vars_and_temps()); + caller_body.source_scopes.extend(&mut callee_body.source_scopes.drain(..)); + caller_body.var_debug_info.append(&mut callee_body.var_debug_info); + caller_body.basic_blocks_mut().extend(callee_body.basic_blocks_mut().drain(..)); + + caller_body[callsite.block].terminator = Some(Terminator { + source_info: callsite.source_info, + kind: TerminatorKind::Goto { target: integrator.map_block(START_BLOCK) }, + }); + + // Copy only unevaluated constants from the callee_body into the caller_body. + // Although we are only pushing `ConstKind::Unevaluated` consts to + // `required_consts`, here we may not only have `ConstKind::Unevaluated` + // because we are calling `subst_and_normalize_erasing_regions`. + caller_body.required_consts.extend( + callee_body.required_consts.iter().copied().filter(|&ct| { + match ct.literal.const_for_ty() { + Some(ct) => matches!(ct.kind(), ConstKind::Unevaluated(_)), + None => true, + } + }), + ); + } + kind => bug!("unexpected terminator kind {:?}", kind), + } + } + + fn make_call_args( + &self, + args: Vec<Operand<'tcx>>, + callsite: &CallSite<'tcx>, + caller_body: &mut Body<'tcx>, + callee_body: &Body<'tcx>, + ) -> Vec<Local> { + let tcx = self.tcx; + + // There is a bit of a mismatch between the *caller* of a closure and the *callee*. 
+ // The caller provides the arguments wrapped up in a tuple: + // + // tuple_tmp = (a, b, c) + // Fn::call(closure_ref, tuple_tmp) + // + // meanwhile the closure body expects the arguments (here, `a`, `b`, and `c`) + // as distinct arguments. (This is the "rust-call" ABI hack.) Normally, codegen has + // the job of unpacking this tuple. But here, we are codegen. =) So we want to create + // a vector like + // + // [closure_ref, tuple_tmp.0, tuple_tmp.1, tuple_tmp.2] + // + // Except for one tiny wrinkle: we don't actually want `tuple_tmp.0`. It's more convenient + // if we "spill" that into *another* temporary, so that we can map the argument + // variable in the callee MIR directly to an argument variable on our side. + // So we introduce temporaries like: + // + // tmp0 = tuple_tmp.0 + // tmp1 = tuple_tmp.1 + // tmp2 = tuple_tmp.2 + // + // and the vector is `[closure_ref, tmp0, tmp1, tmp2]`. + if callsite.fn_sig.abi() == Abi::RustCall && callee_body.spread_arg.is_none() { + let mut args = args.into_iter(); + let self_ = self.create_temp_if_necessary(args.next().unwrap(), callsite, caller_body); + let tuple = self.create_temp_if_necessary(args.next().unwrap(), callsite, caller_body); + assert!(args.next().is_none()); + + let tuple = Place::from(tuple); + let ty::Tuple(tuple_tys) = tuple.ty(caller_body, tcx).ty.kind() else { + bug!("Closure arguments are not passed as a tuple"); + }; + + // The `closure_ref` in our example above. + let closure_ref_arg = iter::once(self_); + + // The `tmp0`, `tmp1`, and `tmp2` in our example above. + let tuple_tmp_args = tuple_tys.iter().enumerate().map(|(i, ty)| { + // This is e.g., `tuple_tmp.0` in our example above. + let tuple_field = Operand::Move(tcx.mk_place_field(tuple, Field::new(i), ty)); + + // Spill to a local to make e.g., `tmp0`. + self.create_temp_if_necessary(tuple_field, callsite, caller_body) + }); + + closure_ref_arg.chain(tuple_tmp_args).collect() + } else { + args.into_iter() + .map(|a| self.create_temp_if_necessary(a, callsite, caller_body)) + .collect() + } + } + + /// If `arg` is already a temporary, returns it. Otherwise, introduces a fresh + /// temporary `T` and an instruction `T = arg`, and returns `T`. + fn create_temp_if_necessary( + &self, + arg: Operand<'tcx>, + callsite: &CallSite<'tcx>, + caller_body: &mut Body<'tcx>, + ) -> Local { + // Reuse the operand if it is a moved temporary. + if let Operand::Move(place) = &arg + && let Some(local) = place.as_local() + && caller_body.local_kind(local) == LocalKind::Temp + { + return local; + } + + // Otherwise, create a temporary for the argument. + trace!("creating temp for argument {:?}", arg); + let arg_ty = arg.ty(caller_body, self.tcx); + let local = self.new_call_temp(caller_body, callsite, arg_ty); + caller_body[callsite.block].statements.push(Statement { + source_info: callsite.source_info, + kind: StatementKind::Assign(Box::new((Place::from(local), Rvalue::Use(arg)))), + }); + local + } + + /// Introduces a new temporary into the caller body that is live for the duration of the call. 
+ fn new_call_temp( + &self, + caller_body: &mut Body<'tcx>, + callsite: &CallSite<'tcx>, + ty: Ty<'tcx>, + ) -> Local { + let local = caller_body.local_decls.push(LocalDecl::new(ty, callsite.source_info.span)); + + caller_body[callsite.block].statements.push(Statement { + source_info: callsite.source_info, + kind: StatementKind::StorageLive(local), + }); + + if let Some(block) = callsite.target { + caller_body[block].statements.insert( + 0, + Statement { + source_info: callsite.source_info, + kind: StatementKind::StorageDead(local), + }, + ); + } + + local + } +} + +fn type_size_of<'tcx>( + tcx: TyCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, + ty: Ty<'tcx>, +) -> Option<u64> { + tcx.layout_of(param_env.and(ty)).ok().map(|layout| layout.size.bytes()) +} + +/** + * Integrator. + * + * Integrates blocks from the callee function into the calling function. + * Updates block indices, references to locals and other control flow + * stuff. +*/ +struct Integrator<'a, 'tcx> { + args: &'a [Local], + new_locals: RangeFrom<Local>, + new_scopes: RangeFrom<SourceScope>, + new_blocks: RangeFrom<BasicBlock>, + destination: Place<'tcx>, + callsite_scope: SourceScopeData<'tcx>, + callsite: &'a CallSite<'tcx>, + cleanup_block: Option<BasicBlock>, + in_cleanup_block: bool, + tcx: TyCtxt<'tcx>, + expn_data: LocalExpnId, + always_live_locals: BitSet<Local>, +} + +impl Integrator<'_, '_> { + fn map_local(&self, local: Local) -> Local { + let new = if local == RETURN_PLACE { + self.destination.local + } else { + let idx = local.index() - 1; + if idx < self.args.len() { + self.args[idx] + } else { + Local::new(self.new_locals.start.index() + (idx - self.args.len())) + } + }; + trace!("mapping local `{:?}` to `{:?}`", local, new); + new + } + + fn map_scope(&self, scope: SourceScope) -> SourceScope { + let new = SourceScope::new(self.new_scopes.start.index() + scope.index()); + trace!("mapping scope `{:?}` to `{:?}`", scope, new); + new + } + + fn map_block(&self, block: BasicBlock) -> BasicBlock { + let new = BasicBlock::new(self.new_blocks.start.index() + block.index()); + trace!("mapping block `{:?}` to `{:?}`", block, new); + new + } +} + +impl<'tcx> MutVisitor<'tcx> for Integrator<'_, 'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_local(&mut self, local: &mut Local, _ctxt: PlaceContext, _location: Location) { + *local = self.map_local(*local); + } + + fn visit_source_scope_data(&mut self, scope_data: &mut SourceScopeData<'tcx>) { + self.super_source_scope_data(scope_data); + if scope_data.parent_scope.is_none() { + // Attach the outermost callee scope as a child of the callsite + // scope, via the `parent_scope` and `inlined_parent_scope` chains. + scope_data.parent_scope = Some(self.callsite.source_info.scope); + assert_eq!(scope_data.inlined_parent_scope, None); + scope_data.inlined_parent_scope = if self.callsite_scope.inlined.is_some() { + Some(self.callsite.source_info.scope) + } else { + self.callsite_scope.inlined_parent_scope + }; + + // Mark the outermost callee scope as an inlined one. + assert_eq!(scope_data.inlined, None); + scope_data.inlined = Some((self.callsite.callee, self.callsite.source_info.span)); + } else if scope_data.inlined_parent_scope.is_none() { + // Make it easy to find the scope with `inlined` set above. 
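+
+ // For reference, a standalone model of the offset remapping that
+ // `map_local` above performs, with plain `usize` in place of `Local` (the
+ // name and parameters are invented): the return place maps to the call
+ // destination, the first `args.len()` callee locals map to caller-side
+ // argument locals, and the rest are shifted past the caller's existing
+ // locals.
+ fn sketch_map_local(local: usize, destination: usize, args: &[usize], first_new: usize) -> usize {
+     if local == 0 {
+         return destination; // RETURN_PLACE becomes the call destination
+     }
+     let idx = local - 1;
+     if idx < args.len() { args[idx] } else { first_new + (idx - args.len()) }
+ }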
+ scope_data.inlined_parent_scope = Some(self.map_scope(OUTERMOST_SOURCE_SCOPE)); + } + } + + fn visit_source_scope(&mut self, scope: &mut SourceScope) { + *scope = self.map_scope(*scope); + } + + fn visit_span(&mut self, span: &mut Span) { + // Make sure that all spans track the fact that they were inlined. + *span = span.fresh_expansion(self.expn_data); + } + + fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { + for elem in place.projection { + // FIXME: Make sure that return place is not used in an indexing projection, since it + // won't be rebased as it is supposed to be. + assert_ne!(ProjectionElem::Index(RETURN_PLACE), elem); + } + + // If this is the `RETURN_PLACE`, we need to rebase any projections onto it. + let dest_proj_len = self.destination.projection.len(); + if place.local == RETURN_PLACE && dest_proj_len > 0 { + let mut projs = Vec::with_capacity(dest_proj_len + place.projection.len()); + projs.extend(self.destination.projection); + projs.extend(place.projection); + + place.projection = self.tcx.intern_place_elems(&*projs); + } + // Handles integrating any locals that occur in the base + // or projections + self.super_place(place, context, location) + } + + fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) { + self.in_cleanup_block = data.is_cleanup; + self.super_basic_block_data(block, data); + self.in_cleanup_block = false; + } + + fn visit_retag(&mut self, kind: &mut RetagKind, place: &mut Place<'tcx>, loc: Location) { + self.super_retag(kind, place, loc); + + // We have to patch all inlined retags to be aware that they are no longer + // happening on function entry. + if *kind == RetagKind::FnEntry { + *kind = RetagKind::Default; + } + } + + fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { + if let StatementKind::StorageLive(local) | StatementKind::StorageDead(local) = + statement.kind + { + self.always_live_locals.remove(local); + } + self.super_statement(statement, location); + } + + fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, loc: Location) { + // Don't try to modify the implicit `_0` access on return (`return` terminators are + // replaced down below anyways). + if !matches!(terminator.kind, TerminatorKind::Return) { + self.super_terminator(terminator, loc); + } + + match terminator.kind { + TerminatorKind::GeneratorDrop | TerminatorKind::Yield { .. } => bug!(), + TerminatorKind::Goto { ref mut target } => { + *target = self.map_block(*target); + } + TerminatorKind::SwitchInt { ref mut targets, .. } => { + for tgt in targets.all_targets_mut() { + *tgt = self.map_block(*tgt); + } + } + TerminatorKind::Drop { ref mut target, ref mut unwind, .. } + | TerminatorKind::DropAndReplace { ref mut target, ref mut unwind, .. } => { + *target = self.map_block(*target); + if let Some(tgt) = *unwind { + *unwind = Some(self.map_block(tgt)); + } else if !self.in_cleanup_block { + // Unless this drop is in a cleanup block, add an unwind edge to + // the original call's cleanup block + *unwind = self.cleanup_block; + } + } + TerminatorKind::Call { ref mut target, ref mut cleanup, .. 
} => { + if let Some(ref mut tgt) = *target { + *tgt = self.map_block(*tgt); + } + if let Some(tgt) = *cleanup { + *cleanup = Some(self.map_block(tgt)); + } else if !self.in_cleanup_block { + // Unless this call is in a cleanup block, add an unwind edge to + // the original call's cleanup block + *cleanup = self.cleanup_block; + } + } + TerminatorKind::Assert { ref mut target, ref mut cleanup, .. } => { + *target = self.map_block(*target); + if let Some(tgt) = *cleanup { + *cleanup = Some(self.map_block(tgt)); + } else if !self.in_cleanup_block { + // Unless this assert is in a cleanup block, add an unwind edge to + // the original call's cleanup block + *cleanup = self.cleanup_block; + } + } + TerminatorKind::Return => { + terminator.kind = if let Some(tgt) = self.callsite.target { + TerminatorKind::Goto { target: tgt } + } else { + TerminatorKind::Unreachable + } + } + TerminatorKind::Resume => { + if let Some(tgt) = self.cleanup_block { + terminator.kind = TerminatorKind::Goto { target: tgt } + } + } + TerminatorKind::Abort => {} + TerminatorKind::Unreachable => {} + TerminatorKind::FalseEdge { ref mut real_target, ref mut imaginary_target } => { + *real_target = self.map_block(*real_target); + *imaginary_target = self.map_block(*imaginary_target); + } + TerminatorKind::FalseUnwind { real_target: _, unwind: _ } => + // see the ordering of passes in the optimized_mir query. + { + bug!("False unwinds should have been removed before inlining") + } + TerminatorKind::InlineAsm { ref mut destination, ref mut cleanup, .. } => { + if let Some(ref mut tgt) = *destination { + *tgt = self.map_block(*tgt); + } else if !self.in_cleanup_block { + // Unless this inline asm is in a cleanup block, add an unwind edge to + // the original call's cleanup block + *cleanup = self.cleanup_block; + } + } + } + } +} diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs new file mode 100644 index 000000000..7810218fd --- /dev/null +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -0,0 +1,168 @@ +use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet}; +use rustc_data_structures::stack::ensure_sufficient_stack; +use rustc_hir::def_id::{DefId, LocalDefId}; +use rustc_middle::mir::TerminatorKind; +use rustc_middle::ty::TypeVisitable; +use rustc_middle::ty::{self, subst::SubstsRef, InstanceDef, TyCtxt}; +use rustc_session::Limit; + +// FIXME: check whether it is cheaper to precompute the entire call graph instead of invoking +// this query ridiculously often. 
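+
+// Before the real query, a minimal sketch of the question it answers, over
+// a plain adjacency list instead of resolved MIR instances (the name is
+// invented; instance resolution and the recursion limit below are omitted):
+fn sketch_callgraph_reachable(graph: &[Vec<usize>], root: usize, target: usize) -> bool {
+    use std::collections::HashSet;
+    let mut seen = HashSet::new();
+    let mut stack = vec![root];
+    while let Some(func) = stack.pop() {
+        if func == target {
+            return true; // found a path from `root` to `target`
+        }
+        if seen.insert(func) {
+            stack.extend(graph[func].iter().copied());
+        }
+    }
+    false
+}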
+#[instrument(level = "debug", skip(tcx, root, target))] +pub(crate) fn mir_callgraph_reachable<'tcx>( + tcx: TyCtxt<'tcx>, + (root, target): (ty::Instance<'tcx>, LocalDefId), +) -> bool { + trace!(%root, target = %tcx.def_path_str(target.to_def_id())); + let param_env = tcx.param_env_reveal_all_normalized(target); + assert_ne!( + root.def_id().expect_local(), + target, + "you should not call `mir_callgraph_reachable` on immediate self recursion" + ); + assert!( + matches!(root.def, InstanceDef::Item(_)), + "you should not call `mir_callgraph_reachable` on shims" + ); + assert!( + !tcx.is_constructor(root.def_id()), + "you should not call `mir_callgraph_reachable` on enum/struct constructor functions" + ); + #[instrument( + level = "debug", + skip(tcx, param_env, target, stack, seen, recursion_limiter, caller, recursion_limit) + )] + fn process<'tcx>( + tcx: TyCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, + caller: ty::Instance<'tcx>, + target: LocalDefId, + stack: &mut Vec<ty::Instance<'tcx>>, + seen: &mut FxHashSet<ty::Instance<'tcx>>, + recursion_limiter: &mut FxHashMap<DefId, usize>, + recursion_limit: Limit, + ) -> bool { + trace!(%caller); + for &(callee, substs) in tcx.mir_inliner_callees(caller.def) { + let Ok(substs) = caller.try_subst_mir_and_normalize_erasing_regions(tcx, param_env, substs) else { + trace!(?caller, ?param_env, ?substs, "cannot normalize, skipping"); + continue; + }; + let Ok(Some(callee)) = ty::Instance::resolve(tcx, param_env, callee, substs) else { + trace!(?callee, "cannot resolve, skipping"); + continue; + }; + + // Found a path. + if callee.def_id() == target.to_def_id() { + return true; + } + + if tcx.is_constructor(callee.def_id()) { + trace!("constructors always have MIR"); + // Constructor functions cannot cause a query cycle. + continue; + } + + match callee.def { + InstanceDef::Item(_) => { + // If there is no MIR available (either because it was not in metadata or + // because it has no MIR because it's an extern function), then the inliner + // won't cause cycles on this. + if !tcx.is_mir_available(callee.def_id()) { + trace!(?callee, "no mir available, skipping"); + continue; + } + } + // These have no own callable MIR. + InstanceDef::Intrinsic(_) | InstanceDef::Virtual(..) => continue, + // These have MIR and if that MIR is inlined, substituted and then inlining is run + // again, a function item can end up getting inlined. Thus we'll be able to cause + // a cycle that way + InstanceDef::VTableShim(_) + | InstanceDef::ReifyShim(_) + | InstanceDef::FnPtrShim(..) + | InstanceDef::ClosureOnceShim { .. } + | InstanceDef::CloneShim(..) => {} + InstanceDef::DropGlue(..) => { + // FIXME: A not fully substituted drop shim can cause ICEs if one attempts to + // have its MIR built. Likely oli-obk just screwed up the `ParamEnv`s, so this + // needs some more analysis. + if callee.needs_subst() { + continue; + } + } + } + + if seen.insert(callee) { + let recursion = recursion_limiter.entry(callee.def_id()).or_default(); + trace!(?callee, recursion = *recursion); + if recursion_limit.value_within_limit(*recursion) { + *recursion += 1; + stack.push(callee); + let found_recursion = ensure_sufficient_stack(|| { + process( + tcx, + param_env, + callee, + target, + stack, + seen, + recursion_limiter, + recursion_limit, + ) + }); + if found_recursion { + return true; + } + stack.pop(); + } else { + // Pessimistically assume that there could be recursion. 
+ return true; + } + } + } + false + } + process( + tcx, + param_env, + root, + target, + &mut Vec::new(), + &mut FxHashSet::default(), + &mut FxHashMap::default(), + tcx.recursion_limit(), + ) +} + +pub(crate) fn mir_inliner_callees<'tcx>( + tcx: TyCtxt<'tcx>, + instance: ty::InstanceDef<'tcx>, +) -> &'tcx [(DefId, SubstsRef<'tcx>)] { + let steal; + let guard; + let body = match (instance, instance.def_id().as_local()) { + (InstanceDef::Item(_), Some(def_id)) => { + let def = ty::WithOptConstParam::unknown(def_id); + steal = tcx.mir_promoted(def).0; + guard = steal.borrow(); + &*guard + } + // Functions from other crates and MIR shims + _ => tcx.instance_mir(instance), + }; + let mut calls = FxIndexSet::default(); + for bb_data in body.basic_blocks() { + let terminator = bb_data.terminator(); + if let TerminatorKind::Call { func, .. } = &terminator.kind { + let ty = func.ty(&body.local_decls, tcx); + let call = match ty.kind() { + ty::FnDef(def_id, substs) => (*def_id, *substs), + _ => continue, + }; + calls.insert(call); + } + } + tcx.arena.alloc_from_iter(calls.iter().copied()) +} diff --git a/compiler/rustc_mir_transform/src/instcombine.rs b/compiler/rustc_mir_transform/src/instcombine.rs new file mode 100644 index 000000000..2f3c65869 --- /dev/null +++ b/compiler/rustc_mir_transform/src/instcombine.rs @@ -0,0 +1,203 @@ +//! Performs various peephole optimizations. + +use crate::MirPass; +use rustc_hir::Mutability; +use rustc_middle::mir::{ + BinOp, Body, Constant, ConstantKind, LocalDecls, Operand, Place, ProjectionElem, Rvalue, + SourceInfo, Statement, StatementKind, Terminator, TerminatorKind, UnOp, +}; +use rustc_middle::ty::{self, TyCtxt}; + +pub struct InstCombine; + +impl<'tcx> MirPass<'tcx> for InstCombine { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 0 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let ctx = InstCombineContext { tcx, local_decls: &body.local_decls }; + for block in body.basic_blocks.as_mut() { + for statement in block.statements.iter_mut() { + match statement.kind { + StatementKind::Assign(box (_place, ref mut rvalue)) => { + ctx.combine_bool_cmp(&statement.source_info, rvalue); + ctx.combine_ref_deref(&statement.source_info, rvalue); + ctx.combine_len(&statement.source_info, rvalue); + } + _ => {} + } + } + + ctx.combine_primitive_clone( + &mut block.terminator.as_mut().unwrap(), + &mut block.statements, + ); + } + } +} + +struct InstCombineContext<'tcx, 'a> { + tcx: TyCtxt<'tcx>, + local_decls: &'a LocalDecls<'tcx>, +} + +impl<'tcx> InstCombineContext<'tcx, '_> { + fn should_combine(&self, source_info: &SourceInfo, rvalue: &Rvalue<'tcx>) -> bool { + self.tcx.consider_optimizing(|| { + format!("InstCombine - Rvalue: {:?} SourceInfo: {:?}", rvalue, source_info) + }) + } + + /// Transform boolean comparisons into logical operations. 
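+/// +/// For example (an illustrative sketch of the rewrites below): +/// +/// ```ignore (MIR) +/// _2 = Eq(move _1, const true); // becomes: _2 = move _1 +/// _2 = Ne(move _1, const true); // becomes: _2 = Not(move _1) +/// ```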
+ fn combine_bool_cmp(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { + match rvalue { + Rvalue::BinaryOp(op @ (BinOp::Eq | BinOp::Ne), box (a, b)) => { + let new = match (op, self.try_eval_bool(a), self.try_eval_bool(b)) { + // Transform "Eq(a, true)" ==> "a" + (BinOp::Eq, _, Some(true)) => Some(Rvalue::Use(a.clone())), + + // Transform "Ne(a, false)" ==> "a" + (BinOp::Ne, _, Some(false)) => Some(Rvalue::Use(a.clone())), + + // Transform "Eq(true, b)" ==> "b" + (BinOp::Eq, Some(true), _) => Some(Rvalue::Use(b.clone())), + + // Transform "Ne(false, b)" ==> "b" + (BinOp::Ne, Some(false), _) => Some(Rvalue::Use(b.clone())), + + // Transform "Eq(false, b)" ==> "Not(b)" + (BinOp::Eq, Some(false), _) => Some(Rvalue::UnaryOp(UnOp::Not, b.clone())), + + // Transform "Ne(true, b)" ==> "Not(b)" + (BinOp::Ne, Some(true), _) => Some(Rvalue::UnaryOp(UnOp::Not, b.clone())), + + // Transform "Eq(a, false)" ==> "Not(a)" + (BinOp::Eq, _, Some(false)) => Some(Rvalue::UnaryOp(UnOp::Not, a.clone())), + + // Transform "Ne(a, true)" ==> "Not(a)" + (BinOp::Ne, _, Some(true)) => Some(Rvalue::UnaryOp(UnOp::Not, a.clone())), + + _ => None, + }; + + if let Some(new) = new && self.should_combine(source_info, rvalue) { + *rvalue = new; + } + } + + _ => {} + } + } + + fn try_eval_bool(&self, a: &Operand<'_>) -> Option<bool> { + let a = a.constant()?; + if a.literal.ty().is_bool() { a.literal.try_to_bool() } else { None } + } + + /// Transform "&(*a)" ==> "a". + fn combine_ref_deref(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { + if let Rvalue::Ref(_, _, place) = rvalue { + if let Some((base, ProjectionElem::Deref)) = place.as_ref().last_projection() { + if let ty::Ref(_, _, Mutability::Not) = + base.ty(self.local_decls, self.tcx).ty.kind() + { + // The dereferenced place must have type `&_`, so that we don't copy `&mut _`. + } else { + return; + } + + if !self.should_combine(source_info, rvalue) { + return; + } + + *rvalue = Rvalue::Use(Operand::Copy(Place { + local: base.local, + projection: self.tcx.intern_place_elems(base.projection), + })); + } + } + } + + /// Transform "Len([_; N])" ==> "N". + fn combine_len(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { + if let Rvalue::Len(ref place) = *rvalue { + let place_ty = place.ty(self.local_decls, self.tcx).ty; + if let ty::Array(_, len) = *place_ty.kind() { + if !self.should_combine(source_info, rvalue) { + return; + } + + let literal = ConstantKind::from_const(len, self.tcx); + let constant = Constant { span: source_info.span, literal, user_ty: None }; + *rvalue = Rvalue::Use(Operand::Constant(Box::new(constant))); + } + } + } + + fn combine_primitive_clone( + &self, + terminator: &mut Terminator<'tcx>, + statements: &mut Vec<Statement<'tcx>>, + ) { + let TerminatorKind::Call { func, args, destination, target, .. } = &mut terminator.kind + else { return }; + + // It's definitely not a clone if there are multiple arguments + if args.len() != 1 { + return; + } + + let Some(destination_block) = *target + else { return }; + + // Only bother looking more if it's easy to know what we're calling + let Some((fn_def_id, fn_substs)) = func.const_fn_def() + else { return }; + + // Clone needs one subst, so we can cheaply rule out other stuff + if fn_substs.len() != 1 { + return; + } + + // These types are easily available from locals, so check that before + // doing DefId lookups to figure out what we're actually calling. 
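+ // (Illustrative: for `x: &u32`, a `<u32 as Clone>::clone(x)` call can be + // replaced by a plain copy of `*x`, since `u32` is trivially pure-clone + // `Copy`; the checks below establish that we really are in such a case.)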
+ let arg_ty = args[0].ty(self.local_decls, self.tcx); + + let ty::Ref(_region, inner_ty, Mutability::Not) = *arg_ty.kind() + else { return }; + + if !inner_ty.is_trivially_pure_clone_copy() { + return; + } + + let trait_def_id = self.tcx.trait_of_item(fn_def_id); + if trait_def_id.is_none() || trait_def_id != self.tcx.lang_items().clone_trait() { + return; + } + + if !self.tcx.consider_optimizing(|| { + format!( + "InstCombine - Call: {:?} SourceInfo: {:?}", + (fn_def_id, fn_substs), + terminator.source_info + ) + }) { + return; + } + + let Some(arg_place) = args.pop().unwrap().place() + else { return }; + + statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::Use(Operand::Copy( + arg_place.project_deeper(&[ProjectionElem::Deref], self.tcx), + )), + ))), + }); + terminator.kind = TerminatorKind::Goto { target: destination_block }; + } +} diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs new file mode 100644 index 000000000..d968a4885 --- /dev/null +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -0,0 +1,575 @@ +#![allow(rustc::potential_query_instability)] +#![feature(box_patterns)] +#![feature(let_chains)] +#![feature(let_else)] +#![feature(map_try_insert)] +#![feature(min_specialization)] +#![feature(never_type)] +#![feature(once_cell)] +#![feature(option_get_or_insert_default)] +#![feature(trusted_step)] +#![feature(try_blocks)] +#![feature(yeet_expr)] +#![recursion_limit = "256"] + +#[macro_use] +extern crate tracing; +#[macro_use] +extern crate rustc_middle; + +use required_consts::RequiredConstsVisitor; +use rustc_const_eval::util; +use rustc_data_structures::fx::FxIndexSet; +use rustc_data_structures::steal::Steal; +use rustc_hir as hir; +use rustc_hir::def_id::{DefId, LocalDefId}; +use rustc_hir::intravisit::{self, Visitor}; +use rustc_index::vec::IndexVec; +use rustc_middle::mir::visit::Visitor as _; +use rustc_middle::mir::{traversal, Body, ConstQualifs, MirPass, MirPhase, Promoted}; +use rustc_middle::ty::query::Providers; +use rustc_middle::ty::{self, TyCtxt, TypeVisitable}; +use rustc_span::{Span, Symbol}; + +#[macro_use] +mod pass_manager; + +use pass_manager::{self as pm, Lint, MirLint, WithMinOptLevel}; + +mod abort_unwinding_calls; +mod add_call_guards; +mod add_moves_for_packed_drops; +mod add_retag; +mod check_const_item_mutation; +mod check_packed_ref; +pub mod check_unsafety; +// This pass is public to allow external drivers to perform MIR cleanup +pub mod cleanup_post_borrowck; +mod const_debuginfo; +mod const_goto; +mod const_prop; +mod const_prop_lint; +mod coverage; +mod dead_store_elimination; +mod deaggregator; +mod deduplicate_blocks; +mod deref_separator; +mod dest_prop; +pub mod dump_mir; +mod early_otherwise_branch; +mod elaborate_box_derefs; +mod elaborate_drops; +mod ffi_unwind_calls; +mod function_item_references; +mod generator; +mod inline; +mod instcombine; +mod lower_intrinsics; +mod lower_slice_len; +mod marker; +mod match_branches; +mod multiple_return_terminators; +mod normalize_array_len; +mod nrvo; +// This pass is public to allow external drivers to perform MIR cleanup +pub mod remove_false_edges; +mod remove_noop_landing_pads; +mod remove_storage_markers; +mod remove_uninit_drops; +mod remove_unneeded_drops; +mod remove_zsts; +mod required_consts; +mod reveal_all; +mod separate_const_switch; +mod shim; +// This pass is public to allow external drivers to perform MIR cleanup +pub mod simplify; +mod simplify_branches; +mod 
simplify_comparison_integral; +mod simplify_try; +mod uninhabited_enum_branching; +mod unreachable_prop; + +use rustc_const_eval::transform::check_consts::{self, ConstCx}; +use rustc_const_eval::transform::promote_consts; +use rustc_const_eval::transform::validate; +use rustc_mir_dataflow::rustc_peek; + +pub fn provide(providers: &mut Providers) { + check_unsafety::provide(providers); + check_packed_ref::provide(providers); + coverage::query::provide(providers); + ffi_unwind_calls::provide(providers); + shim::provide(providers); + *providers = Providers { + mir_keys, + mir_const, + mir_const_qualif: |tcx, def_id| { + let def_id = def_id.expect_local(); + if let Some(def) = ty::WithOptConstParam::try_lookup(def_id, tcx) { + tcx.mir_const_qualif_const_arg(def) + } else { + mir_const_qualif(tcx, ty::WithOptConstParam::unknown(def_id)) + } + }, + mir_const_qualif_const_arg: |tcx, (did, param_did)| { + mir_const_qualif(tcx, ty::WithOptConstParam { did, const_param_did: Some(param_did) }) + }, + mir_promoted, + mir_drops_elaborated_and_const_checked, + mir_for_ctfe, + mir_for_ctfe_of_const_arg, + optimized_mir, + is_mir_available, + is_ctfe_mir_available: |tcx, did| is_mir_available(tcx, did), + mir_callgraph_reachable: inline::cycle::mir_callgraph_reachable, + mir_inliner_callees: inline::cycle::mir_inliner_callees, + promoted_mir: |tcx, def_id| { + let def_id = def_id.expect_local(); + if let Some(def) = ty::WithOptConstParam::try_lookup(def_id, tcx) { + tcx.promoted_mir_of_const_arg(def) + } else { + promoted_mir(tcx, ty::WithOptConstParam::unknown(def_id)) + } + }, + promoted_mir_of_const_arg: |tcx, (did, param_did)| { + promoted_mir(tcx, ty::WithOptConstParam { did, const_param_did: Some(param_did) }) + }, + ..*providers + }; +} + +fn is_mir_available(tcx: TyCtxt<'_>, def_id: DefId) -> bool { + let def_id = def_id.expect_local(); + tcx.mir_keys(()).contains(&def_id) +} + +/// Finds the full set of `DefId`s within the current crate that have +/// MIR associated with them. +fn mir_keys(tcx: TyCtxt<'_>, (): ()) -> FxIndexSet<LocalDefId> { + let mut set = FxIndexSet::default(); + + // All body-owners have MIR associated with them. + set.extend(tcx.hir().body_owners()); + + // Additionally, tuple struct/variant constructors have MIR, but + // they don't have a BodyId, so we need to build them separately. + struct GatherCtors<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + set: &'a mut FxIndexSet<LocalDefId>, + } + impl<'tcx> Visitor<'tcx> for GatherCtors<'_, 'tcx> { + fn visit_variant_data( + &mut self, + v: &'tcx hir::VariantData<'tcx>, + _: Symbol, + _: &'tcx hir::Generics<'tcx>, + _: hir::HirId, + _: Span, + ) { + if let hir::VariantData::Tuple(_, hir_id) = *v { + self.set.insert(self.tcx.hir().local_def_id(hir_id)); + } + intravisit::walk_struct_def(self, v) + } + } + tcx.hir().visit_all_item_likes_in_crate(&mut GatherCtors { tcx, set: &mut set }); + + set +} + +fn mir_const_qualif(tcx: TyCtxt<'_>, def: ty::WithOptConstParam<LocalDefId>) -> ConstQualifs { + let const_kind = tcx.hir().body_const_context(def.did); + + // No need to const-check a non-const `fn`. + if const_kind.is_none() { + return Default::default(); + } + + // N.B., this `borrow()` is guaranteed to be valid (i.e., the value + // cannot yet be stolen), because `mir_promoted()`, which steals + // from `mir_const()`, forces this query to execute before + // performing the steal.
+ let body = &tcx.mir_const(def).borrow(); + + if body.return_ty().references_error() { + tcx.sess.delay_span_bug(body.span, "mir_const_qualif: MIR had errors"); + return Default::default(); + } + + let ccx = check_consts::ConstCx { body, tcx, const_kind, param_env: tcx.param_env(def.did) }; + + let mut validator = check_consts::check::Checker::new(&ccx); + validator.check_body(); + + // We return the qualifs in the return place for every MIR body, even though it is only used + // when deciding to promote a reference to a `const` for now. + validator.qualifs_in_return_place() +} + +/// Make MIR ready for const evaluation. This is run on all MIR, not just on consts! +fn mir_const<'tcx>( + tcx: TyCtxt<'tcx>, + def: ty::WithOptConstParam<LocalDefId>, +) -> &'tcx Steal<Body<'tcx>> { + if let Some(def) = def.try_upgrade(tcx) { + return tcx.mir_const(def); + } + + // Unsafety check uses the raw mir, so make sure it is run. + if !tcx.sess.opts.unstable_opts.thir_unsafeck { + if let Some(param_did) = def.const_param_did { + tcx.ensure().unsafety_check_result_for_const_arg((def.did, param_did)); + } else { + tcx.ensure().unsafety_check_result(def.did); + } + } + + // has_ffi_unwind_calls query uses the raw mir, so make sure it is run. + tcx.ensure().has_ffi_unwind_calls(def.did); + + let mut body = tcx.mir_built(def).steal(); + + rustc_middle::mir::dump_mir(tcx, None, "mir_map", &0, &body, |_, _| Ok(())); + + pm::run_passes( + tcx, + &mut body, + &[ + // MIR-level lints. + &Lint(check_packed_ref::CheckPackedRef), + &Lint(check_const_item_mutation::CheckConstItemMutation), + &Lint(function_item_references::FunctionItemReferences), + // What we need to do constant evaluation. + &simplify::SimplifyCfg::new("initial"), + &rustc_peek::SanityCheck, // Just a lint + &marker::PhaseChange(MirPhase::Const), + ], + ); + tcx.alloc_steal_mir(body) +} + +/// Compute the main MIR body and the list of MIR bodies of the promoteds. +fn mir_promoted<'tcx>( + tcx: TyCtxt<'tcx>, + def: ty::WithOptConstParam<LocalDefId>, +) -> (&'tcx Steal<Body<'tcx>>, &'tcx Steal<IndexVec<Promoted, Body<'tcx>>>) { + if let Some(def) = def.try_upgrade(tcx) { + return tcx.mir_promoted(def); + } + + // Ensure that we compute the `mir_const_qualif` for constants at + // this point, before we steal the mir-const result. + // Also this means promotion can rely on all const checks having been done. + let const_qualifs = tcx.mir_const_qualif_opt_const_arg(def); + let mut body = tcx.mir_const(def).steal(); + if let Some(error_reported) = const_qualifs.tainted_by_errors { + body.tainted_by_errors = Some(error_reported); + } + + let mut required_consts = Vec::new(); + let mut required_consts_visitor = RequiredConstsVisitor::new(&mut required_consts); + for (bb, bb_data) in traversal::reverse_postorder(&body) { + required_consts_visitor.visit_basic_block_data(bb, bb_data); + } + body.required_consts = required_consts; + + // What we need to run borrowck etc. 
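+ // (Illustrative: in `fn f() -> &'static i32 { &42 }`, promotion splits the + // `42` out into a separate promoted MIR body, so the reference can outlive + // the function's stack frame.)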
+ let promote_pass = promote_consts::PromoteTemps::default(); + pm::run_passes( + tcx, + &mut body, + &[ + &promote_pass, + &simplify::SimplifyCfg::new("promote-consts"), + &coverage::InstrumentCoverage, + ], + ); + + let promoted = promote_pass.promoted_fragments.into_inner(); + (tcx.alloc_steal_mir(body), tcx.alloc_steal_promoted(promoted)) +} + +/// Compute the MIR that is used during CTFE (and thus has no optimizations run on it) +fn mir_for_ctfe<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> &'tcx Body<'tcx> { + let did = def_id.expect_local(); + if let Some(def) = ty::WithOptConstParam::try_lookup(did, tcx) { + tcx.mir_for_ctfe_of_const_arg(def) + } else { + tcx.arena.alloc(inner_mir_for_ctfe(tcx, ty::WithOptConstParam::unknown(did))) + } +} + +/// Same as `mir_for_ctfe`, but used to get the MIR of a const generic parameter. +/// The docs on `WithOptConstParam` explain this a bit more, but the TLDR is that +/// we'd get cycle errors with `mir_for_ctfe`, because typeck would need to typeck +/// the const parameter while type checking the main body, which in turn would try +/// to type check the main body again. +fn mir_for_ctfe_of_const_arg<'tcx>( + tcx: TyCtxt<'tcx>, + (did, param_did): (LocalDefId, DefId), +) -> &'tcx Body<'tcx> { + tcx.arena.alloc(inner_mir_for_ctfe( + tcx, + ty::WithOptConstParam { did, const_param_did: Some(param_did) }, + )) +} + +fn inner_mir_for_ctfe(tcx: TyCtxt<'_>, def: ty::WithOptConstParam<LocalDefId>) -> Body<'_> { + // FIXME: don't duplicate this between the optimized_mir/mir_for_ctfe queries + if tcx.is_constructor(def.did.to_def_id()) { + // There's no reason to run all of the MIR passes on constructors when + // we can just output the MIR we want directly. This also saves const + // qualification and borrow checking the trouble of special casing + // constructors. + return shim::build_adt_ctor(tcx, def.did.to_def_id()); + } + + let context = tcx + .hir() + .body_const_context(def.did) + .expect("mir_for_ctfe should not be used for runtime functions"); + + let mut body = tcx.mir_drops_elaborated_and_const_checked(def).borrow().clone(); + + match context { + // Do not const prop `const fn`s here: either they get executed at runtime or exported + // to metadata, in which case const prop runs on them as part of `optimized_mir`, or + // they don't, in which case const eval evaluates the control-flow paths that are + // actually used and emits any errors in those paths as const eval errors. + hir::ConstContext::ConstFn => {} + // Static items always get evaluated, so we can just let const eval see if any erroneous + // control flow paths get executed. + hir::ConstContext::Static(_) => {} + // Associated constants get const prop run so we detect common failure situations in the + // crate that defined the constant. + // Technically we want to not run on regular const items, but oli-obk doesn't know how to + // conveniently detect that at this point without looking at the HIR. + hir::ConstContext::Const => { + pm::run_passes( + tcx, + &mut body, + &[&const_prop::ConstProp, &marker::PhaseChange(MirPhase::Optimized)], + ); + } + } + + debug_assert!(!body.has_free_regions(), "Free regions in MIR for CTFE"); + + body +} + +/// Obtain just the main MIR (no promoteds) and run some cleanups on it. This also runs +/// mir borrowck *before* doing so in order to ensure that borrowck can be run and doesn't +/// end up missing the source MIR due to stealing happening.
+fn mir_drops_elaborated_and_const_checked<'tcx>( + tcx: TyCtxt<'tcx>, + def: ty::WithOptConstParam<LocalDefId>, +) -> &'tcx Steal<Body<'tcx>> { + if let Some(def) = def.try_upgrade(tcx) { + return tcx.mir_drops_elaborated_and_const_checked(def); + } + + let mir_borrowck = tcx.mir_borrowck_opt_const_arg(def); + + let is_fn_like = tcx.def_kind(def.did).is_fn_like(); + if is_fn_like { + let did = def.did.to_def_id(); + let def = ty::WithOptConstParam::unknown(did); + + // Do not compute the mir call graph without said call graph actually being used. + if inline::Inline.is_enabled(&tcx.sess) { + let _ = tcx.mir_inliner_callees(ty::InstanceDef::Item(def)); + } + } + + let (body, _) = tcx.mir_promoted(def); + let mut body = body.steal(); + if let Some(error_reported) = mir_borrowck.tainted_by_errors { + body.tainted_by_errors = Some(error_reported); + } + + // IMPORTANT + pm::run_passes(tcx, &mut body, &[&remove_false_edges::RemoveFalseEdges]); + + // Do a little drop elaboration before const-checking if `const_precise_live_drops` is enabled. + if check_consts::post_drop_elaboration::checking_enabled(&ConstCx::new(tcx, &body)) { + pm::run_passes( + tcx, + &mut body, + &[ + &simplify::SimplifyCfg::new("remove-false-edges"), + &remove_uninit_drops::RemoveUninitDrops, + ], + ); + check_consts::post_drop_elaboration::check_live_drops(tcx, &body); // FIXME: make this a MIR lint + } + + run_post_borrowck_cleanup_passes(tcx, &mut body); + assert!(body.phase == MirPhase::Deaggregated); + tcx.alloc_steal_mir(body) +} + +/// After this series of passes, no lifetime analysis based on borrowing can be done. +fn run_post_borrowck_cleanup_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + debug!("post_borrowck_cleanup({:?})", body.source.def_id()); + + let post_borrowck_cleanup: &[&dyn MirPass<'tcx>] = &[ + // Remove all things only needed by analysis + &simplify_branches::SimplifyConstCondition::new("initial"), + &remove_noop_landing_pads::RemoveNoopLandingPads, + &cleanup_post_borrowck::CleanupNonCodegenStatements, + &simplify::SimplifyCfg::new("early-opt"), + &deref_separator::Derefer, + // These next passes must be executed together + &add_call_guards::CriticalCallEdges, + &elaborate_drops::ElaborateDrops, + // This will remove extraneous landing pads which are no longer + // necessary, as well as forcing any call in a non-unwinding + // function calling a possibly-unwinding function to abort the process. + &abort_unwinding_calls::AbortUnwindingCalls, + // AddMovesForPackedDrops needs to run after drop + // elaboration. + &add_moves_for_packed_drops::AddMovesForPackedDrops, + // `AddRetag` needs to run after `ElaborateDrops`. Otherwise it should run fairly late, + // but before optimizations begin. + &elaborate_box_derefs::ElaborateBoxDerefs, + &add_retag::AddRetag, + &lower_intrinsics::LowerIntrinsics, + &simplify::SimplifyCfg::new("elaborate-drops"), + // `Deaggregator` is conceptually part of MIR building; some backends rely on it happening + // and it can help optimizations. + &deaggregator::Deaggregator, + &Lint(const_prop_lint::ConstProp), + ]; + + pm::run_passes(tcx, body, post_borrowck_cleanup); +} + +fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + fn o1<T>(x: T) -> WithMinOptLevel<T> { + WithMinOptLevel(1, x) + } + + // Lowering generator control-flow and variables has to happen before we do anything else + // to them. We run some optimizations before that, because they may be harder to do on the state + // machine than on MIR with async primitives.
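+ // (Illustrative: at this point an `async`/generator body is still a single + // CFG with suspension points; after `StateTransform` it becomes a switch + // over a state discriminant, where the same simplification would have to + // reason about every resume arm at once.)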
+ pm::run_passes( + tcx, + body, + &[ + &reveal_all::RevealAll, // has to be done before inlining, since inlined code is in RevealAll mode. + &lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise the actual call will almost always be inlined. Also simple, so it can just run first + &normalize_array_len::NormalizeArrayLen, // has to run after `slice::len` lowering + &unreachable_prop::UnreachablePropagation, + &uninhabited_enum_branching::UninhabitedEnumBranching, + &o1(simplify::SimplifyCfg::new("after-uninhabited-enum-branching")), + &inline::Inline, + &generator::StateTransform, + ], + ); + + assert!(body.phase == MirPhase::GeneratorsLowered); + + // The main optimizations that we do on MIR. + pm::run_passes( + tcx, + body, + &[ + &remove_storage_markers::RemoveStorageMarkers, + &remove_zsts::RemoveZsts, + &const_goto::ConstGoto, + &remove_unneeded_drops::RemoveUnneededDrops, + &match_branches::MatchBranchSimplification, + // inst combine is after MatchBranchSimplification to clean up Ne(_1, false) + &multiple_return_terminators::MultipleReturnTerminators, + &instcombine::InstCombine, + &separate_const_switch::SeparateConstSwitch, + // + // FIXME(#70073): This pass is responsible for both optimization and some lints. + &const_prop::ConstProp, + // + // Const-prop runs unconditionally, but doesn't mutate the MIR at mir-opt-level=0. + &const_debuginfo::ConstDebugInfo, + &o1(simplify_branches::SimplifyConstCondition::new("after-const-prop")), + &early_otherwise_branch::EarlyOtherwiseBranch, + &simplify_comparison_integral::SimplifyComparisonIntegral, + &simplify_try::SimplifyArmIdentity, + &simplify_try::SimplifyBranchSame, + &dead_store_elimination::DeadStoreElimination, + &dest_prop::DestinationPropagation, + &o1(simplify_branches::SimplifyConstCondition::new("final")), + &o1(remove_noop_landing_pads::RemoveNoopLandingPads), + &o1(simplify::SimplifyCfg::new("final")), + &nrvo::RenameReturnPlace, + &simplify::SimplifyLocals, + &multiple_return_terminators::MultipleReturnTerminators, + &deduplicate_blocks::DeduplicateBlocks, + // Some cleanup necessary at least for LLVM and potentially other codegen backends. + &add_call_guards::CriticalCallEdges, + &marker::PhaseChange(MirPhase::Optimized), + // Dump the end result for testing and debugging purposes. + &dump_mir::Marker("PreCodegen"), + ], + ); +} + +/// Optimize the MIR and prepare it for codegen. +fn optimized_mir<'tcx>(tcx: TyCtxt<'tcx>, did: DefId) -> &'tcx Body<'tcx> { + let did = did.expect_local(); + assert_eq!(ty::WithOptConstParam::try_lookup(did, tcx), None); + tcx.arena.alloc(inner_optimized_mir(tcx, did)) +} + +fn inner_optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> Body<'_> { + if tcx.is_constructor(did.to_def_id()) { + // There's no reason to run all of the MIR passes on constructors when + // we can just output the MIR we want directly. This also saves const + // qualification and borrow checking the trouble of special casing + // constructors. + return shim::build_adt_ctor(tcx, did.to_def_id()); + } + + match tcx.hir().body_const_context(did) { + // Run the `mir_for_ctfe` query, which depends on `mir_drops_elaborated_and_const_checked` + // which we are going to steal below. Thus we need to run `mir_for_ctfe` first, so it + // computes and caches its result.
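+ // (Illustrative: a `const fn f() -> u32 { 1 }` may be both evaluated at + // compile time, via `mir_for_ctfe`, and codegenned for runtime callers, via + // this query.)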
+ Some(hir::ConstContext::ConstFn) => tcx.ensure().mir_for_ctfe(did), + None => {} + Some(other) => panic!("do not use `optimized_mir` for constants: {:?}", other), + } + debug!("about to call mir_drops_elaborated..."); + let mut body = + tcx.mir_drops_elaborated_and_const_checked(ty::WithOptConstParam::unknown(did)).steal(); + debug!("body: {:#?}", body); + run_optimization_passes(tcx, &mut body); + + debug_assert!(!body.has_free_regions(), "Free regions in optimized MIR"); + + body +} + +/// Fetch all the promoteds of an item and prepare their MIR bodies to be ready for +/// constant evaluation once all substitutions become known. +fn promoted_mir<'tcx>( + tcx: TyCtxt<'tcx>, + def: ty::WithOptConstParam<LocalDefId>, +) -> &'tcx IndexVec<Promoted, Body<'tcx>> { + if tcx.is_constructor(def.did.to_def_id()) { + return tcx.arena.alloc(IndexVec::new()); + } + + let tainted_by_errors = tcx.mir_borrowck_opt_const_arg(def).tainted_by_errors; + let mut promoted = tcx.mir_promoted(def).1.steal(); + + for body in &mut promoted { + if let Some(error_reported) = tainted_by_errors { + body.tainted_by_errors = Some(error_reported); + } + run_post_borrowck_cleanup_passes(tcx, body); + } + + debug_assert!(!promoted.has_free_regions(), "Free regions in promoted MIR"); + + tcx.arena.alloc(promoted) +} diff --git a/compiler/rustc_mir_transform/src/lower_intrinsics.rs b/compiler/rustc_mir_transform/src/lower_intrinsics.rs new file mode 100644 index 000000000..b7ba61651 --- /dev/null +++ b/compiler/rustc_mir_transform/src/lower_intrinsics.rs @@ -0,0 +1,156 @@ +//! Lowers intrinsic calls + +use crate::MirPass; +use rustc_middle::mir::*; +use rustc_middle::ty::subst::SubstsRef; +use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_span::symbol::{sym, Symbol}; +use rustc_span::Span; + +pub struct LowerIntrinsics; + +impl<'tcx> MirPass<'tcx> for LowerIntrinsics { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let local_decls = &body.local_decls; + for block in body.basic_blocks.as_mut() { + let terminator = block.terminator.as_mut().unwrap(); + if let TerminatorKind::Call { func, args, destination, target, .. 
} = + &mut terminator.kind + { + let func_ty = func.ty(local_decls, tcx); + let Some((intrinsic_name, substs)) = resolve_rust_intrinsic(tcx, func_ty) else { + continue; + }; + match intrinsic_name { + sym::unreachable => { + terminator.kind = TerminatorKind::Unreachable; + } + sym::forget => { + if let Some(target) = *target { + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::Use(Operand::Constant(Box::new(Constant { + span: terminator.source_info.span, + user_ty: None, + literal: ConstantKind::zero_sized(tcx.types.unit), + }))), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + } + sym::copy_nonoverlapping => { + let target = target.unwrap(); + let mut args = args.drain(..); + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::CopyNonOverlapping(Box::new( + rustc_middle::mir::CopyNonOverlapping { + src: args.next().unwrap(), + dst: args.next().unwrap(), + count: args.next().unwrap(), + }, + )), + }); + assert_eq!( + args.next(), + None, + "Extra argument for copy_nonoverlapping intrinsic" + ); + drop(args); + terminator.kind = TerminatorKind::Goto { target }; + } + sym::wrapping_add | sym::wrapping_sub | sym::wrapping_mul => { + if let Some(target) = *target { + let lhs; + let rhs; + { + let mut args = args.drain(..); + lhs = args.next().unwrap(); + rhs = args.next().unwrap(); + } + let bin_op = match intrinsic_name { + sym::wrapping_add => BinOp::Add, + sym::wrapping_sub => BinOp::Sub, + sym::wrapping_mul => BinOp::Mul, + _ => bug!("unexpected intrinsic"), + }; + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::BinaryOp(bin_op, Box::new((lhs, rhs))), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + } + sym::add_with_overflow | sym::sub_with_overflow | sym::mul_with_overflow => { + // The checked binary operations are not a suitable target for lowering here, + // since their semantics depend on the value of the overflow-checks flag used + // during codegen. Issue #35310.
+ } + sym::size_of | sym::min_align_of => { + if let Some(target) = *target { + let tp_ty = substs.type_at(0); + let null_op = match intrinsic_name { + sym::size_of => NullOp::SizeOf, + sym::min_align_of => NullOp::AlignOf, + _ => bug!("unexpected intrinsic"), + }; + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::NullaryOp(null_op, tp_ty), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + } + sym::discriminant_value => { + if let (Some(target), Some(arg)) = (*target, args[0].place()) { + let arg = tcx.mk_place_deref(arg); + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::Discriminant(arg), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } + } + _ if intrinsic_name.as_str().starts_with("simd_shuffle") => { + validate_simd_shuffle(tcx, args, terminator.source_info.span); + } + _ => {} + } + } + } + } +} + +fn resolve_rust_intrinsic<'tcx>( + tcx: TyCtxt<'tcx>, + func_ty: Ty<'tcx>, +) -> Option<(Symbol, SubstsRef<'tcx>)> { + if let ty::FnDef(def_id, substs) = *func_ty.kind() { + if tcx.is_intrinsic(def_id) { + return Some((tcx.item_name(def_id), substs)); + } + } + None +} + +fn validate_simd_shuffle<'tcx>(tcx: TyCtxt<'tcx>, args: &[Operand<'tcx>], span: Span) { + match &args[2] { + Operand::Constant(_) => {} // all good + _ => { + let msg = "last argument of `simd_shuffle` is required to be a `const` item"; + tcx.sess.span_err(span, msg); + } + } +} diff --git a/compiler/rustc_mir_transform/src/lower_slice_len.rs b/compiler/rustc_mir_transform/src/lower_slice_len.rs new file mode 100644 index 000000000..2f02d00ec --- /dev/null +++ b/compiler/rustc_mir_transform/src/lower_slice_len.rs @@ -0,0 +1,99 @@ +//! This pass lowers calls to core::slice::len to just Len op. +//! It should run before inlining! + +use crate::MirPass; +use rustc_hir::def_id::DefId; +use rustc_index::vec::IndexVec; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, TyCtxt}; + +pub struct LowerSliceLenCalls; + +impl<'tcx> MirPass<'tcx> for LowerSliceLenCalls { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 0 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + lower_slice_len_calls(tcx, body) + } +} + +pub fn lower_slice_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let language_items = tcx.lang_items(); + let Some(slice_len_fn_item_def_id) = language_items.slice_len_fn() else { + // there is no language item to compare to :) + return; + }; + + // The one successor remains unchanged, so no need to invalidate + let basic_blocks = body.basic_blocks.as_mut_preserves_cfg(); + for block in basic_blocks { + // lower `<[_]>::len` calls + lower_slice_len_call(tcx, block, &body.local_decls, slice_len_fn_item_def_id); + } +} + +struct SliceLenPatchInformation<'tcx> { + add_statement: Statement<'tcx>, + new_terminator_kind: TerminatorKind<'tcx>, +} + +fn lower_slice_len_call<'tcx>( + tcx: TyCtxt<'tcx>, + block: &mut BasicBlockData<'tcx>, + local_decls: &IndexVec<Local, LocalDecl<'tcx>>, + slice_len_fn_item_def_id: DefId, +) { + let mut patch_found: Option<SliceLenPatchInformation<'_>> = None; + + let terminator = block.terminator(); + match &terminator.kind { + TerminatorKind::Call { + func, + args, + destination, + target: Some(bb), + cleanup: None, + from_hir_call: true, + .. 
+ } => { + // some heuristics for fast rejection + if args.len() != 1 { + return; + } + let Some(arg) = args[0].place() else { return }; + let func_ty = func.ty(local_decls, tcx); + match func_ty.kind() { + ty::FnDef(fn_def_id, _) if fn_def_id == &slice_len_fn_item_def_id => { + // perform modifications + // from something like `_5 = core::slice::<impl [u8]>::len(move _6) -> bb1` + // into `_5 = Len(*_6) + // goto bb1 + + // make new RValue for Len + let deref_arg = tcx.mk_place_deref(arg); + let r_value = Rvalue::Len(deref_arg); + let len_statement_kind = + StatementKind::Assign(Box::new((*destination, r_value))); + let add_statement = + Statement { kind: len_statement_kind, source_info: terminator.source_info }; + + // modify terminator into simple Goto + let new_terminator_kind = TerminatorKind::Goto { target: *bb }; + + let patch = SliceLenPatchInformation { add_statement, new_terminator_kind }; + + patch_found = Some(patch); + } + _ => {} + } + } + _ => {} + } + + if let Some(SliceLenPatchInformation { add_statement, new_terminator_kind }) = patch_found { + block.statements.push(add_statement); + block.terminator_mut().kind = new_terminator_kind; + } +} diff --git a/compiler/rustc_mir_transform/src/marker.rs b/compiler/rustc_mir_transform/src/marker.rs new file mode 100644 index 000000000..06819fc1d --- /dev/null +++ b/compiler/rustc_mir_transform/src/marker.rs @@ -0,0 +1,20 @@ +use std::borrow::Cow; + +use crate::MirPass; +use rustc_middle::mir::{Body, MirPhase}; +use rustc_middle::ty::TyCtxt; + +/// Changes the MIR phase without changing the MIR itself. +pub struct PhaseChange(pub MirPhase); + +impl<'tcx> MirPass<'tcx> for PhaseChange { + fn phase_change(&self) -> Option<MirPhase> { + Some(self.0) + } + + fn name(&self) -> Cow<'_, str> { + Cow::from(format!("PhaseChange-{:?}", self.0)) + } + + fn run_pass(&self, _: TyCtxt<'tcx>, _body: &mut Body<'tcx>) {} +} diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs new file mode 100644 index 000000000..a0ba69c89 --- /dev/null +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -0,0 +1,176 @@ +use crate::MirPass; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use std::iter; + +use super::simplify::simplify_cfg; + +pub struct MatchBranchSimplification; + +/// If a source block is found that switches between two blocks that are exactly +/// the same modulo const bool assignments (e.g., one assigns true another false +/// to the same place), merge a target block statements into the source block, +/// using Eq / Ne comparison with switch value where const bools value differ. 
+/// +/// For example: +/// +/// ```ignore (MIR) +/// bb0: { +/// switchInt(move _3) -> [42_isize: bb1, otherwise: bb2]; +/// } +/// +/// bb1: { +/// _2 = const true; +/// goto -> bb3; +/// } +/// +/// bb2: { +/// _2 = const false; +/// goto -> bb3; +/// } +/// ``` +/// +/// into: +/// +/// ```ignore (MIR) +/// bb0: { +/// _2 = Eq(move _3, const 42_isize); +/// goto -> bb3; +/// } +/// ``` + +impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 3 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let def_id = body.source.def_id(); + let param_env = tcx.param_env(def_id); + + let bbs = body.basic_blocks.as_mut(); + let mut should_cleanup = false; + 'outer: for bb_idx in bbs.indices() { + if !tcx.consider_optimizing(|| format!("MatchBranchSimplification {:?} ", def_id)) { + continue; + } + + let (discr, val, switch_ty, first, second) = match bbs[bb_idx].terminator().kind { + TerminatorKind::SwitchInt { + discr: ref discr @ (Operand::Copy(_) | Operand::Move(_)), + switch_ty, + ref targets, + .. + } if targets.iter().len() == 1 => { + let (value, target) = targets.iter().next().unwrap(); + if target == targets.otherwise() { + continue; + } + (discr, value, switch_ty, target, targets.otherwise()) + } + // Only optimize switch int statements + _ => continue, + }; + + // Check that destinations are identical, and if not, then don't optimize this block + if bbs[first].terminator().kind != bbs[second].terminator().kind { + continue; + } + + // Check that the statements of both blocks are either identical or are const bool + // assignments to the same place, matching up 1-1; if not, don't optimize this block. + let first_stmts = &bbs[first].statements; + let scnd_stmts = &bbs[second].statements; + if first_stmts.len() != scnd_stmts.len() { + continue; + } + for (f, s) in iter::zip(first_stmts, scnd_stmts) { + match (&f.kind, &s.kind) { + // If two statements are exactly the same, we can optimize. + (f_s, s_s) if f_s == s_s => {} + + // If two statements are const bool assignments to the same place, we can optimize. + ( + StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), + ) if lhs_f == lhs_s + && f_c.literal.ty().is_bool() + && s_c.literal.ty().is_bool() + && f_c.literal.try_eval_bool(tcx, param_env).is_some() + && s_c.literal.try_eval_bool(tcx, param_env).is_some() => {} + + // Otherwise we cannot optimize. Try another block. + _ => continue 'outer, + } + } + // Take ownership of items now that we know we can optimize. + let discr = discr.clone(); + + // Introduce a temporary for the discriminant value. + let source_info = bbs[bb_idx].terminator().source_info; + let discr_local = body.local_decls.push(LocalDecl::new(switch_ty, source_info.span)); + + // We already checked that first and second are different blocks, + // and bb_idx has a different terminator from both of them.
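+ // (`pick3_mut` hands out three disjoint mutable borrows into the basic + // block vector, which is why the three indices must be pairwise distinct.)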
+ let (from, first, second) = bbs.pick3_mut(bb_idx, first, second); + + let new_stmts = iter::zip(&first.statements, &second.statements).map(|(f, s)| { + match (&f.kind, &s.kind) { + (f_s, s_s) if f_s == s_s => (*f).clone(), + + ( + StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), + StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), + ) => { + // From earlier loop we know that we are dealing with bool constants only: + let f_b = f_c.literal.try_eval_bool(tcx, param_env).unwrap(); + let s_b = s_c.literal.try_eval_bool(tcx, param_env).unwrap(); + if f_b == s_b { + // Same value in both blocks. Use statement as is. + (*f).clone() + } else { + // Different value between blocks. Make value conditional on switch condition. + let size = tcx.layout_of(param_env.and(switch_ty)).unwrap().size; + let const_cmp = Operand::const_from_scalar( + tcx, + switch_ty, + rustc_const_eval::interpret::Scalar::from_uint(val, size), + rustc_span::DUMMY_SP, + ); + let op = if f_b { BinOp::Eq } else { BinOp::Ne }; + let rhs = Rvalue::BinaryOp( + op, + Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), + ); + Statement { + source_info: f.source_info, + kind: StatementKind::Assign(Box::new((*lhs, rhs))), + } + } + } + + _ => unreachable!(), + } + }); + + from.statements + .push(Statement { source_info, kind: StatementKind::StorageLive(discr_local) }); + from.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + Place::from(discr_local), + Rvalue::Use(discr), + ))), + }); + from.statements.extend(new_stmts); + from.statements + .push(Statement { source_info, kind: StatementKind::StorageDead(discr_local) }); + from.terminator_mut().kind = first.terminator().kind.clone(); + should_cleanup = true; + } + + if should_cleanup { + simplify_cfg(tcx, body); + } + } +} diff --git a/compiler/rustc_mir_transform/src/multiple_return_terminators.rs b/compiler/rustc_mir_transform/src/multiple_return_terminators.rs new file mode 100644 index 000000000..22b6dead9 --- /dev/null +++ b/compiler/rustc_mir_transform/src/multiple_return_terminators.rs @@ -0,0 +1,43 @@ +//! This pass removes jumps to basic blocks containing only a return, and replaces them with a +//! return instead. 
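+//! +//! For example (illustrative MIR): +//! +//! ```ignore (MIR) +//! bb1: { goto -> bb2; } // becomes: bb1: { return; } +//! bb2: { return; } +//! ```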
+ +use crate::{simplify, MirPass}; +use rustc_index::bit_set::BitSet; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +pub struct MultipleReturnTerminators; + +impl<'tcx> MirPass<'tcx> for MultipleReturnTerminators { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 4 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // find basic blocks with no statements and a return terminator + let mut bbs_simple_returns = BitSet::new_empty(body.basic_blocks().len()); + let def_id = body.source.def_id(); + let bbs = body.basic_blocks_mut(); + for idx in bbs.indices() { + if bbs[idx].statements.is_empty() + && bbs[idx].terminator().kind == TerminatorKind::Return + { + bbs_simple_returns.insert(idx); + } + } + + for bb in bbs { + if !tcx.consider_optimizing(|| format!("MultipleReturnTerminators {:?} ", def_id)) { + break; + } + + if let TerminatorKind::Goto { target } = bb.terminator().kind { + if bbs_simple_returns.contains(target) { + bb.terminator_mut().kind = TerminatorKind::Return; + } + } + } + + simplify::remove_dead_blocks(tcx, body) + } +} diff --git a/compiler/rustc_mir_transform/src/normalize_array_len.rs b/compiler/rustc_mir_transform/src/normalize_array_len.rs new file mode 100644 index 000000000..c0217a105 --- /dev/null +++ b/compiler/rustc_mir_transform/src/normalize_array_len.rs @@ -0,0 +1,287 @@ +//! This pass eliminates casts of arrays into slices when their length +//! is taken using the `.len()` method, preserving information in MIR for const prop + +use crate::MirPass; +use rustc_data_structures::fx::FxIndexMap; +use rustc_data_structures::intern::Interned; +use rustc_index::bit_set::BitSet; +use rustc_index::vec::IndexVec; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, ReErased, Region, TyCtxt}; + +const MAX_NUM_BLOCKS: usize = 800; +const MAX_NUM_LOCALS: usize = 3000; + +pub struct NormalizeArrayLen; + +impl<'tcx> MirPass<'tcx> for NormalizeArrayLen { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 4 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // early returns for edge cases of highly unrolled functions + if body.basic_blocks().len() > MAX_NUM_BLOCKS { + return; + } + if body.local_decls().len() > MAX_NUM_LOCALS { + return; + } + normalize_array_len_calls(tcx, body) + } +} + +pub fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // We don't ever touch terminators, so no need to invalidate the CFG cache + let basic_blocks = body.basic_blocks.as_mut_preserves_cfg(); + let local_decls = &mut body.local_decls; + + // do a preliminary analysis to see if we ever have locals of type `[T;N]` or `&[T;N]` + let mut interesting_locals = BitSet::new_empty(local_decls.len()); + for (local, decl) in local_decls.iter_enumerated() { + match decl.ty.kind() { + ty::Array(..) => { + interesting_locals.insert(local); + } + ty::Ref(.., ty, Mutability::Not) => match ty.kind() { + ty::Array(..)
=> { + interesting_locals.insert(local); + } + _ => {} + }, + _ => {} + } + } + if interesting_locals.is_empty() { + // we have found nothing to analyze + return; + } + let num_interesting_locals = interesting_locals.count(); + let mut state = FxIndexMap::with_capacity_and_hasher(num_interesting_locals, Default::default()); + let mut patches_scratchpad = + FxIndexMap::with_capacity_and_hasher(num_interesting_locals, Default::default()); + let mut replacements_scratchpad = + FxIndexMap::with_capacity_and_hasher(num_interesting_locals, Default::default()); + for block in basic_blocks { + // keep length calls for arrays [T; N] from decaying into length calls for &[T], + // which would forbid constant propagation + normalize_array_len_call( + tcx, + block, + local_decls, + &interesting_locals, + &mut state, + &mut patches_scratchpad, + &mut replacements_scratchpad, + ); + state.clear(); + patches_scratchpad.clear(); + replacements_scratchpad.clear(); + } +} + +struct Patcher<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + patches_scratchpad: &'a FxIndexMap<usize, usize>, + replacements_scratchpad: &'a mut FxIndexMap<usize, Local>, + local_decls: &'a mut IndexVec<Local, LocalDecl<'tcx>>, + statement_idx: usize, +} + +impl<'tcx> Patcher<'_, 'tcx> { + fn patch_expand_statement( + &mut self, + statement: &mut Statement<'tcx>, + ) -> Option<std::vec::IntoIter<Statement<'tcx>>> { + let idx = self.statement_idx; + if let Some(len_statement_idx) = self.patches_scratchpad.get(&idx).copied() { + let mut statements = Vec::with_capacity(2); + + // we are at a statement that performs a cast. The only sound way is + // to create another local that performs a similar copy without a cast and then + // use this copy in the Len operation + + match &statement.kind { + StatementKind::Assign(box ( + .., + Rvalue::Cast( + CastKind::Pointer(ty::adjustment::PointerCast::Unsize), + operand, + _, + ), + )) => { + match operand { + Operand::Copy(place) | Operand::Move(place) => { + // create new local + let ty = operand.ty(self.local_decls, self.tcx); + let local_decl = LocalDecl::with_source_info(ty, statement.source_info); + let local = self.local_decls.push(local_decl); + // make it live + let mut make_live_statement = statement.clone(); + make_live_statement.kind = StatementKind::StorageLive(local); + statements.push(make_live_statement); + // copy into it + + let operand = Operand::Copy(*place); + let mut make_copy_statement = statement.clone(); + let assign_to = Place::from(local); + let rvalue = Rvalue::Use(operand); + make_copy_statement.kind = + StatementKind::Assign(Box::new((assign_to, rvalue))); + statements.push(make_copy_statement); + + // to reorder we have to copy and make the original a NOP + statements.push(statement.clone()); + statement.make_nop(); + + self.replacements_scratchpad.insert(len_statement_idx, local); + } + _ => { + unreachable!("it's a bug in the implementation") + } + } + } + _ => { + unreachable!("it's a bug in the implementation") + } + } + + self.statement_idx += 1; + + Some(statements.into_iter()) + } else if let Some(local) = self.replacements_scratchpad.get(&idx).copied() { + let mut statements = Vec::with_capacity(2); + + match &statement.kind { + StatementKind::Assign(box (into, Rvalue::Len(place))) => { + let add_deref = if let Some(..) = place.as_local() { + false + } else if let Some(..)
= place.local_or_deref_local() { + true + } else { + unreachable!("it's a bug in the implementation") + }; + // replace len statement + let mut len_statement = statement.clone(); + let mut place = Place::from(local); + if add_deref { + place = self.tcx.mk_place_deref(place); + } + len_statement.kind = + StatementKind::Assign(Box::new((*into, Rvalue::Len(place)))); + statements.push(len_statement); + + // make temporary dead + let mut make_dead_statement = statement.clone(); + make_dead_statement.kind = StatementKind::StorageDead(local); + statements.push(make_dead_statement); + + // make original statement NOP + statement.make_nop(); + } + _ => { + unreachable!("it's a bug in the implementation") + } + } + + self.statement_idx += 1; + + Some(statements.into_iter()) + } else { + self.statement_idx += 1; + None + } + } +} + +fn normalize_array_len_call<'tcx>( + tcx: TyCtxt<'tcx>, + block: &mut BasicBlockData<'tcx>, + local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, + interesting_locals: &BitSet<Local>, + state: &mut FxIndexMap<Local, usize>, + patches_scratchpad: &mut FxIndexMap<usize, usize>, + replacements_scratchpad: &mut FxIndexMap<usize, Local>, +) { + for (statement_idx, statement) in block.statements.iter_mut().enumerate() { + match &mut statement.kind { + StatementKind::Assign(box (place, rvalue)) => { + match rvalue { + Rvalue::Cast( + CastKind::Pointer(ty::adjustment::PointerCast::Unsize), + operand, + cast_ty, + ) => { + let Some(local) = place.as_local() else { return }; + match operand { + Operand::Copy(place) | Operand::Move(place) => { + let Some(operand_local) = place.local_or_deref_local() else { return; }; + if !interesting_locals.contains(operand_local) { + return; + } + let operand_ty = local_decls[operand_local].ty; + match (operand_ty.kind(), cast_ty.kind()) { + (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => { + if of_ty_src == of_ty_dst { + // this is a cast from [T; N] into [T], so we are good + state.insert(local, statement_idx); + } + } + // the current way of patching can't handle `mut` references + ( + ty::Ref( + Region(Interned(ReErased, _)), + operand_ty, + Mutability::Not, + ), + ty::Ref( + Region(Interned(ReErased, _)), + cast_ty, + Mutability::Not, + ), + ) => { + match (operand_ty.kind(), cast_ty.kind()) { + // the current way of patching can't handle `mut` references + (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => { + if of_ty_src == of_ty_dst { + // this is a cast from [T; N] into [T], so we are good + state.insert(local, statement_idx); + } + } + _ => {} + } + } + _ => {} + } + } + _ => {} + } + } + Rvalue::Len(place) => { + let Some(local) = place.local_or_deref_local() else { + return; + }; + if let Some(cast_statement_idx) = state.get(&local).copied() { + patches_scratchpad.insert(cast_statement_idx, statement_idx); + } + } + _ => { + // invalidate + state.remove(&place.local); + } + } + } + _ => {} + } + } + + let mut patcher = Patcher { + tcx, + patches_scratchpad: &*patches_scratchpad, + replacements_scratchpad, + local_decls, + statement_idx: 0, + }; + + block.expand_statements(|st| patcher.patch_expand_statement(st)); +} diff --git a/compiler/rustc_mir_transform/src/nrvo.rs b/compiler/rustc_mir_transform/src/nrvo.rs new file mode 100644 index 000000000..bb063915f --- /dev/null +++ b/compiler/rustc_mir_transform/src/nrvo.rs @@ -0,0 +1,236 @@ +//! See the docs for [`RenameReturnPlace`].
+ +use rustc_hir::Mutability; +use rustc_index::bit_set::HybridBitSet; +use rustc_middle::mir::visit::{MutVisitor, NonUseContext, PlaceContext, Visitor}; +use rustc_middle::mir::{self, BasicBlock, Local, Location}; +use rustc_middle::ty::TyCtxt; + +use crate::MirPass; + +/// This pass looks for MIR that always copies the same local into the return place and eliminates +/// the copy by renaming all uses of that local to `_0`. +/// +/// This allows LLVM to perform an optimization similar to the named return value optimization +/// (NRVO) that is guaranteed in C++. This avoids a stack allocation and `memcpy` for the +/// relatively common pattern of allocating a buffer on the stack, mutating it, and returning it by +/// value like so: +/// +/// ```rust +/// fn foo(init: fn(&mut [u8; 1024])) -> [u8; 1024] { +/// let mut buf = [0; 1024]; +/// init(&mut buf); +/// buf +/// } +/// ``` +/// +/// For now, this pass is very simple and only capable of eliminating a single copy. A more general +/// version of copy propagation, such as the one based on non-overlapping live ranges in [#47954] and +/// [#71003], could yield even more benefits. +/// +/// [#47954]: https://github.com/rust-lang/rust/pull/47954 +/// [#71003]: https://github.com/rust-lang/rust/pull/71003 +pub struct RenameReturnPlace; + +impl<'tcx> MirPass<'tcx> for RenameReturnPlace { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 0 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { + let def_id = body.source.def_id(); + let Some(returned_local) = local_eligible_for_nrvo(body) else { + debug!("`{:?}` was ineligible for NRVO", def_id); + return; + }; + + if !tcx.consider_optimizing(|| format!("RenameReturnPlace {:?}", def_id)) { + return; + } + + debug!( + "`{:?}` was eligible for NRVO, making {:?} the return place", + def_id, returned_local + ); + + RenameToReturnPlace { tcx, to_rename: returned_local }.visit_body(body); + + // Clean up the `NOP`s we inserted for statements made useless by our renaming. + for block_data in body.basic_blocks_mut() { + block_data.statements.retain(|stmt| stmt.kind != mir::StatementKind::Nop); + } + + // Overwrite the debuginfo of `_0` with that of the renamed local. + let (renamed_decl, ret_decl) = + body.local_decls.pick2_mut(returned_local, mir::RETURN_PLACE); + + // Sometimes, the return place is assigned a local of a different but coercible type, for + // example `&mut T` instead of `&T`. Overwriting the `LocalInfo` for the return place means + // its type may no longer match the return type of its function. This doesn't cause a + // problem in codegen because these two types are layout-compatible, but may be unexpected. + debug!("_0: {:?} = {:?}: {:?}", ret_decl.ty, returned_local, renamed_decl.ty); + ret_decl.clone_from(renamed_decl); + + // The return place is always mutable. + ret_decl.mutability = Mutability::Mut; + } +} + +/// MIR that is eligible for the NRVO must fulfill two conditions: +/// 1. The return place must not be read prior to the `Return` terminator. +/// 2. A simple assignment of a whole local to the return place (e.g., `_0 = _1`) must be the +/// only definition of the return place reaching the `Return` terminator. +/// +/// If the MIR fulfills both these conditions, this function returns the `Local` that is assigned +/// to the return place along all possible paths through the control-flow graph. 
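+/// +/// For example (illustrative MIR fulfilling both conditions): +/// +/// ```ignore (MIR) +/// bb0: { +/// _1 = [const 0_u8; 1024]; +/// _0 = move _1; // the only assignment to `_0` reaching `return` +/// return; +/// } +/// ```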
+fn local_eligible_for_nrvo(body: &mut mir::Body<'_>) -> Option<Local> { + if IsReturnPlaceRead::run(body) { + return None; + } + + let mut copied_to_return_place = None; + for block in body.basic_blocks().indices() { + // Look for blocks with a `Return` terminator. + if !matches!(body[block].terminator().kind, mir::TerminatorKind::Return) { + continue; + } + + // Look for an assignment of a single local to the return place prior to the `Return`. + let returned_local = find_local_assigned_to_return_place(block, body)?; + match body.local_kind(returned_local) { + // FIXME: Can we do this for arguments as well? + mir::LocalKind::Arg => return None, + + mir::LocalKind::ReturnPointer => bug!("Return place was assigned to itself?"), + mir::LocalKind::Var | mir::LocalKind::Temp => {} + } + + // If multiple different locals are copied to the return place, we can't pick a + // single one to rename. + if copied_to_return_place.map_or(false, |old| old != returned_local) { + return None; + } + + copied_to_return_place = Some(returned_local); + } + + copied_to_return_place +} + +fn find_local_assigned_to_return_place( + start: BasicBlock, + body: &mut mir::Body<'_>, +) -> Option<Local> { + let mut block = start; + let mut seen = HybridBitSet::new_empty(body.basic_blocks().len()); + + // Iterate as long as `block` has exactly one predecessor that we have not yet visited. + while seen.insert(block) { + trace!("Looking for assignments to `_0` in {:?}", block); + + let local = body[block].statements.iter().rev().find_map(as_local_assigned_to_return_place); + if local.is_some() { + return local; + } + + match body.basic_blocks.predecessors()[block].as_slice() { + &[pred] => block = pred, + _ => return None, + } + } + + None +} + +// If this statement is an assignment of an unprojected local to the return place, +// return that local. +fn as_local_assigned_to_return_place(stmt: &mir::Statement<'_>) -> Option<Local> { + if let mir::StatementKind::Assign(box (lhs, rhs)) = &stmt.kind { + if lhs.as_local() == Some(mir::RETURN_PLACE) { + if let mir::Rvalue::Use(mir::Operand::Copy(rhs) | mir::Operand::Move(rhs)) = rhs { + return rhs.as_local(); + } + } + } + + None +} + +struct RenameToReturnPlace<'tcx> { + to_rename: Local, + tcx: TyCtxt<'tcx>, +} + +/// Replaces all uses of `self.to_rename` with `_0`. +impl<'tcx> MutVisitor<'tcx> for RenameToReturnPlace<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_statement(&mut self, stmt: &mut mir::Statement<'tcx>, loc: Location) { + // Remove assignments of the local being replaced to the return place, since it is now the + // return place: + // _0 = _1 + if as_local_assigned_to_return_place(stmt) == Some(self.to_rename) { + stmt.kind = mir::StatementKind::Nop; + return; + } + + // Remove storage annotations for the local being replaced: + // StorageLive(_1) + if let mir::StatementKind::StorageLive(local) | mir::StatementKind::StorageDead(local) = + stmt.kind + { + if local == self.to_rename { + stmt.kind = mir::StatementKind::Nop; + return; + } + } + + self.super_statement(stmt, loc) + } + + fn visit_terminator(&mut self, terminator: &mut mir::Terminator<'tcx>, loc: Location) { + // Ignore the implicit "use" of the return place in a `Return` statement.
+ if let mir::TerminatorKind::Return = terminator.kind { + return; + } + + self.super_terminator(terminator, loc); + } + + fn visit_local(&mut self, l: &mut Local, ctxt: PlaceContext, _: Location) { + if *l == mir::RETURN_PLACE { + assert_eq!(ctxt, PlaceContext::NonUse(NonUseContext::VarDebugInfo)); + } else if *l == self.to_rename { + *l = mir::RETURN_PLACE; + } + } +} + +struct IsReturnPlaceRead(bool); + +impl IsReturnPlaceRead { + fn run(body: &mir::Body<'_>) -> bool { + let mut vis = IsReturnPlaceRead(false); + vis.visit_body(body); + vis.0 + } +} + +impl<'tcx> Visitor<'tcx> for IsReturnPlaceRead { + fn visit_local(&mut self, l: Local, ctxt: PlaceContext, _: Location) { + if l == mir::RETURN_PLACE && ctxt.is_use() && !ctxt.is_place_assignment() { + self.0 = true; + } + } + + fn visit_terminator(&mut self, terminator: &mir::Terminator<'tcx>, loc: Location) { + // Ignore the implicit "use" of the return place in a `Return` statement. + if let mir::TerminatorKind::Return = terminator.kind { + return; + } + + self.super_terminator(terminator, loc); + } +} diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs new file mode 100644 index 000000000..e27d4ab16 --- /dev/null +++ b/compiler/rustc_mir_transform/src/pass_manager.rs @@ -0,0 +1,157 @@ +use std::borrow::Cow; + +use rustc_middle::mir::{self, Body, MirPhase}; +use rustc_middle::ty::TyCtxt; +use rustc_session::Session; + +use crate::{validate, MirPass}; + +/// Just like `MirPass`, except it cannot mutate `Body`. +pub trait MirLint<'tcx> { + fn name(&self) -> Cow<'_, str> { + let name = std::any::type_name::<Self>(); + if let Some(tail) = name.rfind(':') { + Cow::from(&name[tail + 1..]) + } else { + Cow::from(name) + } + } + + fn is_enabled(&self, _sess: &Session) -> bool { + true + } + + fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>); +} + +/// An adapter for `MirLint`s that implements `MirPass`. 
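+///
+/// This lets lints sit in the same pass lists as real transformations, e.g. as
+/// `&Lint(SomeMirLint)` (hypothetical lint name); `run_pass` below forwards to `run_lint`
+/// with a shared borrow, so the wrapped lint can never actually mutate the body.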
+#[derive(Debug, Clone)] +pub struct Lint<T>(pub T); + +impl<'tcx, T> MirPass<'tcx> for Lint<T> +where + T: MirLint<'tcx>, +{ + fn name(&self) -> Cow<'_, str> { + self.0.name() + } + + fn is_enabled(&self, sess: &Session) -> bool { + self.0.is_enabled(sess) + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + self.0.run_lint(tcx, body) + } + + fn is_mir_dump_enabled(&self) -> bool { + false + } +} + +pub struct WithMinOptLevel<T>(pub u32, pub T); + +impl<'tcx, T> MirPass<'tcx> for WithMinOptLevel<T> +where + T: MirPass<'tcx>, +{ + fn name(&self) -> Cow<'_, str> { + self.1.name() + } + + fn is_enabled(&self, sess: &Session) -> bool { + sess.mir_opt_level() >= self.0 as usize + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + self.1.run_pass(tcx, body) + } + + fn phase_change(&self) -> Option<MirPhase> { + self.1.phase_change() + } +} + +pub fn run_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, passes: &[&dyn MirPass<'tcx>]) { + let start_phase = body.phase; + let mut cnt = 0; + + let validate = tcx.sess.opts.unstable_opts.validate_mir; + let overridden_passes = &tcx.sess.opts.unstable_opts.mir_enable_passes; + trace!(?overridden_passes); + + if validate { + validate_body(tcx, body, format!("start of phase transition from {:?}", start_phase)); + } + + for pass in passes { + let name = pass.name(); + + if let Some((_, polarity)) = overridden_passes.iter().rev().find(|(s, _)| s == &*name) { + trace!( + pass = %name, + "{} as requested by flag", + if *polarity { "Running" } else { "Not running" }, + ); + if !polarity { + continue; + } + } else { + if !pass.is_enabled(&tcx.sess) { + continue; + } + } + let dump_enabled = pass.is_mir_dump_enabled(); + + if dump_enabled { + dump_mir(tcx, body, start_phase, &name, cnt, false); + } + + pass.run_pass(tcx, body); + + if dump_enabled { + dump_mir(tcx, body, start_phase, &name, cnt, true); + cnt += 1; + } + + if let Some(new_phase) = pass.phase_change() { + if body.phase >= new_phase { + panic!("Invalid MIR phase transition from {:?} to {:?}", body.phase, new_phase); + } + + body.phase = new_phase; + } + + if validate { + validate_body(tcx, body, format!("after pass {}", pass.name())); + } + } + + if validate || body.phase == MirPhase::Optimized { + validate_body(tcx, body, format!("end of phase transition to {:?}", body.phase)); + } +} + +pub fn validate_body<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, when: String) { + validate::Validator { when, mir_phase: body.phase }.run_pass(tcx, body); +} + +pub fn dump_mir<'tcx>( + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, + phase: MirPhase, + pass_name: &str, + cnt: usize, + is_after: bool, +) { + let phase_index = phase as u32; + + mir::dump_mir( + tcx, + Some(&format_args!("{:03}-{:03}", phase_index, cnt)), + pass_name, + if is_after { &"after" } else { &"before" }, + body, + |_, _| Ok(()), + ); +} diff --git a/compiler/rustc_mir_transform/src/remove_false_edges.rs b/compiler/rustc_mir_transform/src/remove_false_edges.rs new file mode 100644 index 000000000..71f5ccf7e --- /dev/null +++ b/compiler/rustc_mir_transform/src/remove_false_edges.rs @@ -0,0 +1,29 @@ +use rustc_middle::mir::{Body, TerminatorKind}; +use rustc_middle::ty::TyCtxt; + +use crate::MirPass; + +/// Removes `FalseEdge` and `FalseUnwind` terminators from the MIR. +/// +/// These are only needed for borrow checking, and can be removed afterwards. +/// +/// FIXME: This should probably have its own MIR phase. 
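+///
+/// Sketch of the rewrite performed below: `falseEdge -> [real: bb1, imaginary: bb2]` and
+/// `falseUnwind -> [real: bb1, cleanup: bb2]` both become `goto -> bb1`.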
+pub struct RemoveFalseEdges; + +impl<'tcx> MirPass<'tcx> for RemoveFalseEdges { + fn run_pass(&self, _: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + for block in body.basic_blocks_mut() { + let terminator = block.terminator_mut(); + terminator.kind = match terminator.kind { + TerminatorKind::FalseEdge { real_target, .. } => { + TerminatorKind::Goto { target: real_target } + } + TerminatorKind::FalseUnwind { real_target, .. } => { + TerminatorKind::Goto { target: real_target } + } + + _ => continue, + } + } + } +} diff --git a/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs b/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs new file mode 100644 index 000000000..5c441c5b1 --- /dev/null +++ b/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs @@ -0,0 +1,131 @@ +use crate::MirPass; +use rustc_index::bit_set::BitSet; +use rustc_middle::mir::patch::MirPatch; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use rustc_target::spec::PanicStrategy; + +/// A pass that removes noop landing pads and replaces jumps to them with +/// `None`. This is important because otherwise LLVM generates terrible +/// code for these. +pub struct RemoveNoopLandingPads; + +impl<'tcx> MirPass<'tcx> for RemoveNoopLandingPads { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.panic_strategy() != PanicStrategy::Abort + } + + fn run_pass(&self, _tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + debug!("remove_noop_landing_pads({:?})", body); + self.remove_nop_landing_pads(body) + } +} + +impl RemoveNoopLandingPads { + fn is_nop_landing_pad( + &self, + bb: BasicBlock, + body: &Body<'_>, + nop_landing_pads: &BitSet<BasicBlock>, + ) -> bool { + for stmt in &body[bb].statements { + match &stmt.kind { + StatementKind::FakeRead(..) + | StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::AscribeUserType(..) + | StatementKind::Coverage(..) + | StatementKind::Nop => { + // These are all noops in a landing pad + } + + StatementKind::Assign(box (place, Rvalue::Use(_) | Rvalue::Discriminant(_))) => { + if place.as_local().is_some() { + // Writing to a local (e.g., a drop flag) does not + // turn a landing pad to a non-nop + } else { + return false; + } + } + + StatementKind::Assign { .. } + | StatementKind::SetDiscriminant { .. } + | StatementKind::Deinit(..) + | StatementKind::CopyNonOverlapping(..) + | StatementKind::Retag { .. } => { + return false; + } + } + } + + let terminator = body[bb].terminator(); + match terminator.kind { + TerminatorKind::Goto { .. } + | TerminatorKind::Resume + | TerminatorKind::SwitchInt { .. } + | TerminatorKind::FalseEdge { .. } + | TerminatorKind::FalseUnwind { .. } => { + terminator.successors().all(|succ| nop_landing_pads.contains(succ)) + } + TerminatorKind::GeneratorDrop + | TerminatorKind::Yield { .. } + | TerminatorKind::Return + | TerminatorKind::Abort + | TerminatorKind::Unreachable + | TerminatorKind::Call { .. } + | TerminatorKind::Assert { .. } + | TerminatorKind::DropAndReplace { .. } + | TerminatorKind::Drop { .. } + | TerminatorKind::InlineAsm { .. 
} => false, + } + } + + fn remove_nop_landing_pads(&self, body: &mut Body<'_>) { + debug!("body: {:#?}", body); + + // make sure there's a resume block + let resume_block = { + let mut patch = MirPatch::new(body); + let resume_block = patch.resume_block(); + patch.apply(body); + resume_block + }; + debug!("remove_noop_landing_pads: resume block is {:?}", resume_block); + + let mut jumps_folded = 0; + let mut landing_pads_removed = 0; + let mut nop_landing_pads = BitSet::new_empty(body.basic_blocks().len()); + + // This is a post-order traversal, so that if A post-dominates B + // then A will be visited before B. + let postorder: Vec<_> = traversal::postorder(body).map(|(bb, _)| bb).collect(); + for bb in postorder { + debug!(" processing {:?}", bb); + if let Some(unwind) = body[bb].terminator_mut().unwind_mut() { + if let Some(unwind_bb) = *unwind { + if nop_landing_pads.contains(unwind_bb) { + debug!(" removing noop landing pad"); + landing_pads_removed += 1; + *unwind = None; + } + } + } + + for target in body[bb].terminator_mut().successors_mut() { + if *target != resume_block && nop_landing_pads.contains(*target) { + debug!(" folding noop jump to {:?} to resume block", target); + *target = resume_block; + jumps_folded += 1; + } + } + + let is_nop_landing_pad = self.is_nop_landing_pad(bb, body, &nop_landing_pads); + if is_nop_landing_pad { + nop_landing_pads.insert(bb); + } + debug!(" is_nop_landing_pad({:?}) = {}", bb, is_nop_landing_pad); + } + + debug!("removed {:?} jumps and {:?} landing pads", jumps_folded, landing_pads_removed); + } +} diff --git a/compiler/rustc_mir_transform/src/remove_storage_markers.rs b/compiler/rustc_mir_transform/src/remove_storage_markers.rs new file mode 100644 index 000000000..dbe082e90 --- /dev/null +++ b/compiler/rustc_mir_transform/src/remove_storage_markers.rs @@ -0,0 +1,29 @@ +//! This pass removes storage markers if they won't be emitted during codegen. + +use crate::MirPass; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +pub struct RemoveStorageMarkers; + +impl<'tcx> MirPass<'tcx> for RemoveStorageMarkers { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 0 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + if tcx.sess.emit_lifetime_markers() { + return; + } + + trace!("Running RemoveStorageMarkers on {:?}", body.source); + for data in body.basic_blocks.as_mut_preserves_cfg() { + data.statements.retain(|statement| match statement.kind { + StatementKind::StorageLive(..) + | StatementKind::StorageDead(..) + | StatementKind::Nop => false, + _ => true, + }) + } + } +} diff --git a/compiler/rustc_mir_transform/src/remove_uninit_drops.rs b/compiler/rustc_mir_transform/src/remove_uninit_drops.rs new file mode 100644 index 000000000..96b715402 --- /dev/null +++ b/compiler/rustc_mir_transform/src/remove_uninit_drops.rs @@ -0,0 +1,171 @@ +use rustc_index::bit_set::ChunkedBitSet; +use rustc_middle::mir::{Body, Field, Rvalue, Statement, StatementKind, TerminatorKind}; +use rustc_middle::ty::subst::SubstsRef; +use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt, VariantDef}; +use rustc_mir_dataflow::impls::MaybeInitializedPlaces; +use rustc_mir_dataflow::move_paths::{LookupResult, MoveData, MovePathIndex}; +use rustc_mir_dataflow::{self, move_path_children_matching, Analysis, MoveDataParamEnv}; + +use crate::MirPass; + +/// Removes `Drop` and `DropAndReplace` terminators whose target is known to be uninitialized at +/// that point. 
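+///
+/// For example (sketch): if `_1` has been moved out of on every path that reaches a
+/// `drop(_1)` terminator, that terminator is replaced with a plain `goto` to its target.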
+/// +/// This is redundant with drop elaboration, but we need to do it prior to const-checking, and +/// running const-checking after drop elaboration makes it optimization dependent, causing issues +/// like [#90770]. +/// +/// [#90770]: https://github.com/rust-lang/rust/issues/90770 +pub struct RemoveUninitDrops; + +impl<'tcx> MirPass<'tcx> for RemoveUninitDrops { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let param_env = tcx.param_env(body.source.def_id()); + let Ok((_,move_data)) = MoveData::gather_moves(body, tcx, param_env) else { + // We could continue if there are move errors, but there's not much point since our + // init data isn't complete. + return; + }; + + let mdpe = MoveDataParamEnv { move_data, param_env }; + let mut maybe_inits = MaybeInitializedPlaces::new(tcx, body, &mdpe) + .into_engine(tcx, body) + .pass_name("remove_uninit_drops") + .iterate_to_fixpoint() + .into_results_cursor(body); + + let mut to_remove = vec![]; + for (bb, block) in body.basic_blocks().iter_enumerated() { + let terminator = block.terminator(); + let (TerminatorKind::Drop { place, .. } | TerminatorKind::DropAndReplace { place, .. }) + = &terminator.kind + else { continue }; + + maybe_inits.seek_before_primary_effect(body.terminator_loc(bb)); + + // If there's no move path for the dropped place, it's probably a `Deref`. Let it alone. + let LookupResult::Exact(mpi) = mdpe.move_data.rev_lookup.find(place.as_ref()) else { + continue; + }; + + let should_keep = is_needs_drop_and_init( + tcx, + param_env, + maybe_inits.get(), + &mdpe.move_data, + place.ty(body, tcx).ty, + mpi, + ); + if !should_keep { + to_remove.push(bb) + } + } + + for bb in to_remove { + let block = &mut body.basic_blocks_mut()[bb]; + + let (TerminatorKind::Drop { target, .. } | TerminatorKind::DropAndReplace { target, .. }) + = &block.terminator().kind + else { unreachable!() }; + + // Replace block terminator with `Goto`. + let target = *target; + let old_terminator_kind = std::mem::replace( + &mut block.terminator_mut().kind, + TerminatorKind::Goto { target }, + ); + + // If this is a `DropAndReplace`, we need to emulate the assignment to the return place. + if let TerminatorKind::DropAndReplace { place, value, .. } = old_terminator_kind { + block.statements.push(Statement { + source_info: block.terminator().source_info, + kind: StatementKind::Assign(Box::new((place, Rvalue::Use(value)))), + }); + } + } + } +} + +fn is_needs_drop_and_init<'tcx>( + tcx: TyCtxt<'tcx>, + param_env: ParamEnv<'tcx>, + maybe_inits: &ChunkedBitSet<MovePathIndex>, + move_data: &MoveData<'tcx>, + ty: Ty<'tcx>, + mpi: MovePathIndex, +) -> bool { + // No need to look deeper if the root is definitely uninit or if it has no `Drop` impl. + if !maybe_inits.contains(mpi) || !ty.needs_drop(tcx, param_env) { + return false; + } + + let field_needs_drop_and_init = |(f, f_ty, mpi)| { + let child = move_path_children_matching(move_data, mpi, |x| x.is_field_to(f)); + let Some(mpi) = child else { + return Ty::needs_drop(f_ty, tcx, param_env); + }; + + is_needs_drop_and_init(tcx, param_env, maybe_inits, move_data, f_ty, mpi) + }; + + // This pass is only needed for const-checking, so it doesn't handle as many cases as + // `DropCtxt::open_drop`, since they aren't relevant in a const-context. 
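+ // E.g. (sketch): for an `Option<String>` whose `Some` payload was moved out on every
+ // path, the payload's move path is no longer maybe-init, so the enclosing `drop` is
+ // reported as removable.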
+ match ty.kind() {
+ ty::Adt(adt, substs) => {
+ let dont_elaborate = adt.is_union() || adt.is_manually_drop() || adt.has_dtor(tcx);
+ if dont_elaborate {
+ return true;
+ }
+
+ // Look at all our fields, or, if we are an enum, all our variants and their fields.
+ //
+ // If a field's projection *is not* present in `MoveData`, it has the same
+ // initializedness as its parent (maybe init).
+ //
+ // If its projection *is* present in `MoveData`, then the field may have been moved
+ // separately from its parent. Recurse.
+ adt.variants().iter_enumerated().any(|(vid, variant)| {
+ // Enums have multiple variants, which are discriminated with a `Downcast` projection.
+ // Structs have a single variant, and don't use a `Downcast` projection.
+ let mpi = if adt.is_enum() {
+ let downcast =
+ move_path_children_matching(move_data, mpi, |x| x.is_downcast_to(vid));
+ let Some(dc_mpi) = downcast else {
+ return variant_needs_drop(tcx, param_env, substs, variant);
+ };
+
+ dc_mpi
+ } else {
+ mpi
+ };
+
+ variant
+ .fields
+ .iter()
+ .enumerate()
+ .map(|(f, field)| (Field::from_usize(f), field.ty(tcx, substs), mpi))
+ .any(field_needs_drop_and_init)
+ })
+ }
+
+ ty::Tuple(fields) => fields
+ .iter()
+ .enumerate()
+ .map(|(f, f_ty)| (Field::from_usize(f), f_ty, mpi))
+ .any(field_needs_drop_and_init),
+
+ _ => true,
+ }
+}
+
+fn variant_needs_drop<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ param_env: ParamEnv<'tcx>,
+ substs: SubstsRef<'tcx>,
+ variant: &VariantDef,
+) -> bool {
+ variant.fields.iter().any(|field| {
+ let f_ty = field.ty(tcx, substs);
+ f_ty.needs_drop(tcx, param_env)
+ })
+}
diff --git a/compiler/rustc_mir_transform/src/remove_unneeded_drops.rs b/compiler/rustc_mir_transform/src/remove_unneeded_drops.rs
new file mode 100644
index 000000000..84ccf6e1f
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/remove_unneeded_drops.rs
@@ -0,0 +1,45 @@
+//! This pass replaces a drop of a type that does not need dropping with a goto.
+//!
+//! When the MIR is built, we check `needs_drop` before emitting a `Drop` for a place. This pass is
+//! useful because (unlike MIR building) it runs after type checking, so it can make use of
+//! `Reveal::All` to provide more precise type information.
+
+use crate::MirPass;
+use rustc_middle::mir::*;
+use rustc_middle::ty::TyCtxt;
+
+use super::simplify::simplify_cfg;
+
+pub struct RemoveUnneededDrops;
+
+impl<'tcx> MirPass<'tcx> for RemoveUnneededDrops {
+ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ trace!("Running RemoveUnneededDrops on {:?}", body.source);
+
+ let did = body.source.def_id();
+ let param_env = tcx.param_env_reveal_all_normalized(did);
+ let mut should_simplify = false;
+
+ for block in body.basic_blocks.as_mut() {
+ let terminator = block.terminator_mut();
+ if let TerminatorKind::Drop { place, target, ..
} = terminator.kind { + let ty = place.ty(&body.local_decls, tcx); + if ty.ty.needs_drop(tcx, param_env) { + continue; + } + if !tcx.consider_optimizing(|| format!("RemoveUnneededDrops {:?} ", did)) { + continue; + } + debug!("SUCCESS: replacing `drop` with goto({:?})", target); + terminator.kind = TerminatorKind::Goto { target }; + should_simplify = true; + } + } + + // if we applied optimizations, we potentially have some cfg to cleanup to + // make it easier for further passes + if should_simplify { + simplify_cfg(tcx, body); + } + } +} diff --git a/compiler/rustc_mir_transform/src/remove_zsts.rs b/compiler/rustc_mir_transform/src/remove_zsts.rs new file mode 100644 index 000000000..40be4f146 --- /dev/null +++ b/compiler/rustc_mir_transform/src/remove_zsts.rs @@ -0,0 +1,86 @@ +//! Removes assignments to ZST places. + +use crate::MirPass; +use rustc_middle::mir::tcx::PlaceTy; +use rustc_middle::mir::{Body, LocalDecls, Place, StatementKind}; +use rustc_middle::ty::{self, Ty, TyCtxt}; + +pub struct RemoveZsts; + +impl<'tcx> MirPass<'tcx> for RemoveZsts { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 0 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // Avoid query cycles (generators require optimized MIR for layout). + if tcx.type_of(body.source.def_id()).is_generator() { + return; + } + let param_env = tcx.param_env(body.source.def_id()); + let basic_blocks = body.basic_blocks.as_mut_preserves_cfg(); + let local_decls = &body.local_decls; + for block in basic_blocks { + for statement in block.statements.iter_mut() { + if let StatementKind::Assign(box (place, _)) | StatementKind::Deinit(box place) = + statement.kind + { + let place_ty = place.ty(local_decls, tcx).ty; + if !maybe_zst(place_ty) { + continue; + } + let Ok(layout) = tcx.layout_of(param_env.and(place_ty)) else { + continue; + }; + if !layout.is_zst() { + continue; + } + if involves_a_union(place, local_decls, tcx) { + continue; + } + if tcx.consider_optimizing(|| { + format!( + "RemoveZsts - Place: {:?} SourceInfo: {:?}", + place, statement.source_info + ) + }) { + statement.make_nop(); + } + } + } + } + } +} + +/// A cheap, approximate check to avoid unnecessary `layout_of` calls. +fn maybe_zst(ty: Ty<'_>) -> bool { + match ty.kind() { + // maybe ZST (could be more precise) + ty::Adt(..) | ty::Array(..) | ty::Closure(..) | ty::Tuple(..) | ty::Opaque(..) => true, + // definitely ZST + ty::FnDef(..) | ty::Never => true, + // unreachable or can't be ZST + _ => false, + } +} + +/// Miri lazily allocates memory for locals on assignment, +/// so we must preserve writes to unions and union fields, +/// or it will ICE on reads of those fields. 
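+///
+/// The loop below checks every prefix of the place's projection, so (sketch) a write to
+/// `_1.0.x` where `_1.0` is a union is preserved even though the final field is not itself
+/// a union.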
+fn involves_a_union<'tcx>( + place: Place<'tcx>, + local_decls: &LocalDecls<'tcx>, + tcx: TyCtxt<'tcx>, +) -> bool { + let mut place_ty = PlaceTy::from_ty(local_decls[place.local].ty); + if place_ty.ty.is_union() { + return true; + } + for elem in place.projection { + place_ty = place_ty.projection_ty(tcx, elem); + if place_ty.ty.is_union() { + return true; + } + } + return false; +} diff --git a/compiler/rustc_mir_transform/src/required_consts.rs b/compiler/rustc_mir_transform/src/required_consts.rs new file mode 100644 index 000000000..827ce0c02 --- /dev/null +++ b/compiler/rustc_mir_transform/src/required_consts.rs @@ -0,0 +1,22 @@ +use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::{Constant, Location}; +use rustc_middle::ty::ConstKind; + +pub struct RequiredConstsVisitor<'a, 'tcx> { + required_consts: &'a mut Vec<Constant<'tcx>>, +} + +impl<'a, 'tcx> RequiredConstsVisitor<'a, 'tcx> { + pub fn new(required_consts: &'a mut Vec<Constant<'tcx>>) -> Self { + RequiredConstsVisitor { required_consts } + } +} + +impl<'tcx> Visitor<'tcx> for RequiredConstsVisitor<'_, 'tcx> { + fn visit_constant(&mut self, constant: &Constant<'tcx>, _: Location) { + let literal = constant.literal; + if let Some(ct) = literal.const_for_ty() && let ConstKind::Unevaluated(_) = ct.kind() { + self.required_consts.push(*constant); + } + } +} diff --git a/compiler/rustc_mir_transform/src/reveal_all.rs b/compiler/rustc_mir_transform/src/reveal_all.rs new file mode 100644 index 000000000..4919ad400 --- /dev/null +++ b/compiler/rustc_mir_transform/src/reveal_all.rs @@ -0,0 +1,44 @@ +//! Normalizes MIR in RevealAll mode. + +use crate::MirPass; +use rustc_middle::mir::visit::*; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, Ty, TyCtxt}; + +pub struct RevealAll; + +impl<'tcx> MirPass<'tcx> for RevealAll { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 3 || super::inline::Inline.is_enabled(sess) + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // Do not apply this transformation to generators. + if body.generator.is_some() { + return; + } + + let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); + RevealAllVisitor { tcx, param_env }.visit_body(body); + } +} + +struct RevealAllVisitor<'tcx> { + tcx: TyCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for RevealAllVisitor<'tcx> { + #[inline] + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + #[inline] + fn visit_ty(&mut self, ty: &mut Ty<'tcx>, _: TyContext) { + // We have to use `try_normalize_erasing_regions` here, since it's + // possible that we visit impossible-to-satisfy where clauses here, + // see #91745 + *ty = self.tcx.try_normalize_erasing_regions(self.param_env, *ty).unwrap_or(*ty); + } +} diff --git a/compiler/rustc_mir_transform/src/separate_const_switch.rs b/compiler/rustc_mir_transform/src/separate_const_switch.rs new file mode 100644 index 000000000..925eb10a1 --- /dev/null +++ b/compiler/rustc_mir_transform/src/separate_const_switch.rs @@ -0,0 +1,341 @@ +//! A pass that duplicates switch-terminated blocks +//! into a new copy for each predecessor, provided +//! the predecessor sets the value being switched +//! over to a constant. +//! +//! The purpose of this pass is to help constant +//! propagation passes to simplify the switch terminator +//! of the copied blocks into gotos when some predecessors +//! statically determine the output of switches. +//! +//! ```text +//! x = 12 --- ---> something +//! \ / 12 +//! 
--> switch x +//! / \ otherwise +//! x = y --- ---> something else +//! ``` +//! becomes +//! ```text +//! x = 12 ---> switch x ------> something +//! \ / 12 +//! X +//! / \ otherwise +//! x = y ---> switch x ------> something else +//! ``` +//! so it can hopefully later be turned by another pass into +//! ```text +//! x = 12 --------------------> something +//! / 12 +//! / +//! / otherwise +//! x = y ---- switch x ------> something else +//! ``` +//! +//! This optimization is meant to cover simple cases +//! like `?` desugaring. For now, it thus focuses on +//! simplicity rather than completeness (it notably +//! sometimes duplicates abusively). + +use crate::MirPass; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; +use smallvec::SmallVec; + +pub struct SeparateConstSwitch; + +impl<'tcx> MirPass<'tcx> for SeparateConstSwitch { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 4 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // If execution did something, applying a simplification layer + // helps later passes optimize the copy away. + if separate_const_switch(body) > 0 { + super::simplify::simplify_cfg(tcx, body); + } + } +} + +/// Returns the amount of blocks that were duplicated +pub fn separate_const_switch(body: &mut Body<'_>) -> usize { + let mut new_blocks: SmallVec<[(BasicBlock, BasicBlock); 6]> = SmallVec::new(); + let predecessors = body.basic_blocks.predecessors(); + 'block_iter: for (block_id, block) in body.basic_blocks().iter_enumerated() { + if let TerminatorKind::SwitchInt { + discr: Operand::Copy(switch_place) | Operand::Move(switch_place), + .. + } = block.terminator().kind + { + // If the block is on an unwind path, do not + // apply the optimization as unwind paths + // rely on a unique parent invariant + if block.is_cleanup { + continue 'block_iter; + } + + // If the block has fewer than 2 predecessors, ignore it + // we could maybe chain blocks that have exactly one + // predecessor, but for now we ignore that + if predecessors[block_id].len() < 2 { + continue 'block_iter; + } + + // First, let's find a non-const place + // that determines the result of the switch + if let Some(switch_place) = find_determining_place(switch_place, block) { + // We now have an input place for which it would + // be interesting if predecessors assigned it from a const + + let mut predecessors_left = predecessors[block_id].len(); + 'predec_iter: for predecessor_id in predecessors[block_id].iter().copied() { + let predecessor = &body.basic_blocks()[predecessor_id]; + + // First we make sure the predecessor jumps + // in a reasonable way + match &predecessor.terminator().kind { + // The following terminators are + // unconditionally valid + TerminatorKind::Goto { .. } | TerminatorKind::SwitchInt { .. } => {} + + TerminatorKind::FalseEdge { real_target, .. } => { + if *real_target != block_id { + continue 'predec_iter; + } + } + + // The following terminators are not allowed + TerminatorKind::Resume + | TerminatorKind::Drop { .. } + | TerminatorKind::DropAndReplace { .. } + | TerminatorKind::Call { .. } + | TerminatorKind::Assert { .. } + | TerminatorKind::FalseUnwind { .. } + | TerminatorKind::Yield { .. } + | TerminatorKind::Abort + | TerminatorKind::Return + | TerminatorKind::Unreachable + | TerminatorKind::InlineAsm { .. 
} + | TerminatorKind::GeneratorDrop => { + continue 'predec_iter; + } + } + + if is_likely_const(switch_place, predecessor) { + new_blocks.push((predecessor_id, block_id)); + predecessors_left -= 1; + if predecessors_left < 2 { + // If the original block only has one predecessor left, + // we have nothing left to do + break 'predec_iter; + } + } + } + } + } + } + + // Once the analysis is done, perform the duplication + let body_span = body.span; + let copied_blocks = new_blocks.len(); + let blocks = body.basic_blocks_mut(); + for (pred_id, target_id) in new_blocks { + let new_block = blocks[target_id].clone(); + let new_block_id = blocks.push(new_block); + let terminator = blocks[pred_id].terminator_mut(); + + match terminator.kind { + TerminatorKind::Goto { ref mut target } => { + *target = new_block_id; + } + + TerminatorKind::FalseEdge { ref mut real_target, .. } => { + if *real_target == target_id { + *real_target = new_block_id; + } + } + + TerminatorKind::SwitchInt { ref mut targets, .. } => { + targets.all_targets_mut().iter_mut().for_each(|x| { + if *x == target_id { + *x = new_block_id; + } + }); + } + + TerminatorKind::Resume + | TerminatorKind::Abort + | TerminatorKind::Return + | TerminatorKind::Unreachable + | TerminatorKind::GeneratorDrop + | TerminatorKind::Assert { .. } + | TerminatorKind::DropAndReplace { .. } + | TerminatorKind::FalseUnwind { .. } + | TerminatorKind::Drop { .. } + | TerminatorKind::Call { .. } + | TerminatorKind::InlineAsm { .. } + | TerminatorKind::Yield { .. } => { + span_bug!( + body_span, + "basic block terminator had unexpected kind {:?}", + &terminator.kind + ) + } + } + } + + copied_blocks +} + +/// This function describes a rough heuristic guessing +/// whether a place is last set with a const within the block. +/// Notably, it will be overly pessimistic in cases that are already +/// not handled by `separate_const_switch`. +fn is_likely_const<'tcx>(mut tracked_place: Place<'tcx>, block: &BasicBlockData<'tcx>) -> bool { + for statement in block.statements.iter().rev() { + match &statement.kind { + StatementKind::Assign(assign) => { + if assign.0 == tracked_place { + match assign.1 { + // These rvalues are definitely constant + Rvalue::Use(Operand::Constant(_)) + | Rvalue::Ref(_, _, _) + | Rvalue::AddressOf(_, _) + | Rvalue::Cast(_, Operand::Constant(_), _) + | Rvalue::NullaryOp(_, _) + | Rvalue::ShallowInitBox(_, _) + | Rvalue::UnaryOp(_, Operand::Constant(_)) => return true, + + // These rvalues make things ambiguous + Rvalue::Repeat(_, _) + | Rvalue::ThreadLocalRef(_) + | Rvalue::Len(_) + | Rvalue::BinaryOp(_, _) + | Rvalue::CheckedBinaryOp(_, _) + | Rvalue::Aggregate(_, _) => return false, + + // These rvalues move the place to track + Rvalue::Cast(_, Operand::Copy(place) | Operand::Move(place), _) + | Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) + | Rvalue::CopyForDeref(place) + | Rvalue::UnaryOp(_, Operand::Copy(place) | Operand::Move(place)) + | Rvalue::Discriminant(place) => tracked_place = place, + } + } + } + + // If the discriminant is set, it is always set + // as a constant, so the job is done. + // As we are **ignoring projections**, if the place + // we are tracking sees its discriminant be set, + // that means we had to be tracking the discriminant + // specifically (as it is impossible to switch over + // an enum directly, and if we were switching over + // its content, we would have had to at least cast it to + // some variant first) + StatementKind::SetDiscriminant { place, .. 
} => { + if **place == tracked_place { + return true; + } + } + + // These statements have no influence on the place + // we are interested in + StatementKind::FakeRead(_) + | StatementKind::Deinit(_) + | StatementKind::StorageLive(_) + | StatementKind::Retag(_, _) + | StatementKind::AscribeUserType(_, _) + | StatementKind::Coverage(_) + | StatementKind::StorageDead(_) + | StatementKind::CopyNonOverlapping(_) + | StatementKind::Nop => {} + } + } + + // If no good reason for the place to be const is found, + // give up. We could maybe go up predecessors, but in + // most cases giving up now should be sufficient. + false +} + +/// Finds a unique place that entirely determines the value +/// of `switch_place`, if it exists. This is only a heuristic. +/// Ideally we would like to track multiple determining places +/// for some edge cases, but one is enough for a lot of situations. +fn find_determining_place<'tcx>( + mut switch_place: Place<'tcx>, + block: &BasicBlockData<'tcx>, +) -> Option<Place<'tcx>> { + for statement in block.statements.iter().rev() { + match &statement.kind { + StatementKind::Assign(op) => { + if op.0 != switch_place { + continue; + } + + match op.1 { + // The following rvalues move the place + // that may be const in the predecessor + Rvalue::Use(Operand::Move(new) | Operand::Copy(new)) + | Rvalue::UnaryOp(_, Operand::Copy(new) | Operand::Move(new)) + | Rvalue::CopyForDeref(new) + | Rvalue::Cast(_, Operand::Move(new) | Operand::Copy(new), _) + | Rvalue::Repeat(Operand::Move(new) | Operand::Copy(new), _) + | Rvalue::Discriminant(new) + => switch_place = new, + + // The following rvalues might still make the block + // be valid but for now we reject them + Rvalue::Len(_) + | Rvalue::Ref(_, _, _) + | Rvalue::BinaryOp(_, _) + | Rvalue::CheckedBinaryOp(_, _) + | Rvalue::Aggregate(_, _) + + // The following rvalues definitely mean we cannot + // or should not apply this optimization + | Rvalue::Use(Operand::Constant(_)) + | Rvalue::Repeat(Operand::Constant(_), _) + | Rvalue::ThreadLocalRef(_) + | Rvalue::AddressOf(_, _) + | Rvalue::NullaryOp(_, _) + | Rvalue::ShallowInitBox(_, _) + | Rvalue::UnaryOp(_, Operand::Constant(_)) + | Rvalue::Cast(_, Operand::Constant(_), _) + => return None, + } + } + + // These statements have no influence on the place + // we are interested in + StatementKind::FakeRead(_) + | StatementKind::Deinit(_) + | StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Retag(_, _) + | StatementKind::AscribeUserType(_, _) + | StatementKind::Coverage(_) + | StatementKind::CopyNonOverlapping(_) + | StatementKind::Nop => {} + + // If the discriminant is set, it is always set + // as a constant, so the job is already done. + // As we are **ignoring projections**, if the place + // we are tracking sees its discriminant be set, + // that means we had to be tracking the discriminant + // specifically (as it is impossible to switch over + // an enum directly, and if we were switching over + // its content, we would have had to at least cast it to + // some variant first) + StatementKind::SetDiscriminant { place, .. 
} => { + if **place == switch_place { + return None; + } + } + } + } + + Some(switch_place) +} diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs new file mode 100644 index 000000000..3620e94be --- /dev/null +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -0,0 +1,790 @@ +use rustc_hir as hir; +use rustc_hir::def_id::DefId; +use rustc_hir::lang_items::LangItem; +use rustc_middle::mir::*; +use rustc_middle::ty::query::Providers; +use rustc_middle::ty::subst::{InternalSubsts, Subst}; +use rustc_middle::ty::{self, EarlyBinder, Ty, TyCtxt}; +use rustc_target::abi::VariantIdx; + +use rustc_index::vec::{Idx, IndexVec}; + +use rustc_span::Span; +use rustc_target::spec::abi::Abi; + +use std::fmt; +use std::iter; + +use crate::util::expand_aggregate; +use crate::{ + abort_unwinding_calls, add_call_guards, add_moves_for_packed_drops, marker, pass_manager as pm, + remove_noop_landing_pads, simplify, +}; +use rustc_middle::mir::patch::MirPatch; +use rustc_mir_dataflow::elaborate_drops::{self, DropElaborator, DropFlagMode, DropStyle}; + +pub fn provide(providers: &mut Providers) { + providers.mir_shims = make_shim; +} + +fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'tcx> { + debug!("make_shim({:?})", instance); + + let mut result = match instance { + ty::InstanceDef::Item(..) => bug!("item {:?} passed to make_shim", instance), + ty::InstanceDef::VTableShim(def_id) => { + build_call_shim(tcx, instance, Some(Adjustment::Deref), CallKind::Direct(def_id)) + } + ty::InstanceDef::FnPtrShim(def_id, ty) => { + let trait_ = tcx.trait_of_item(def_id).unwrap(); + let adjustment = match tcx.fn_trait_kind_from_lang_item(trait_) { + Some(ty::ClosureKind::FnOnce) => Adjustment::Identity, + Some(ty::ClosureKind::FnMut | ty::ClosureKind::Fn) => Adjustment::Deref, + None => bug!("fn pointer {:?} is not an fn", ty), + }; + + build_call_shim(tcx, instance, Some(adjustment), CallKind::Indirect(ty)) + } + // We are generating a call back to our def-id, which the + // codegen backend knows to turn to an actual call, be it + // a virtual call, or a direct call to a function for which + // indirect calls must be codegen'd differently than direct ones + // (such as `#[track_caller]`). + ty::InstanceDef::ReifyShim(def_id) => { + build_call_shim(tcx, instance, None, CallKind::Direct(def_id)) + } + ty::InstanceDef::ClosureOnceShim { call_once: _, track_caller: _ } => { + let fn_mut = tcx.require_lang_item(LangItem::FnMut, None); + let call_mut = tcx + .associated_items(fn_mut) + .in_definition_order() + .find(|it| it.kind == ty::AssocKind::Fn) + .unwrap() + .def_id; + + build_call_shim(tcx, instance, Some(Adjustment::RefMut), CallKind::Direct(call_mut)) + } + + ty::InstanceDef::DropGlue(def_id, ty) => { + // FIXME(#91576): Drop shims for generators aren't subject to the MIR passes at the end + // of this function. Is this intentional? + if let Some(ty::Generator(gen_def_id, substs, _)) = ty.map(Ty::kind) { + let body = tcx.optimized_mir(*gen_def_id).generator_drop().unwrap(); + let body = EarlyBinder(body.clone()).subst(tcx, substs); + debug!("make_shim({:?}) = {:?}", instance, body); + return body; + } + + build_drop_shim(tcx, def_id, ty) + } + ty::InstanceDef::CloneShim(def_id, ty) => build_clone_shim(tcx, def_id, ty), + ty::InstanceDef::Virtual(..) 
=> { + bug!("InstanceDef::Virtual ({:?}) is for direct calls only", instance) + } + ty::InstanceDef::Intrinsic(_) => { + bug!("creating shims from intrinsics ({:?}) is unsupported", instance) + } + }; + debug!("make_shim({:?}) = untransformed {:?}", instance, result); + + pm::run_passes( + tcx, + &mut result, + &[ + &add_moves_for_packed_drops::AddMovesForPackedDrops, + &remove_noop_landing_pads::RemoveNoopLandingPads, + &simplify::SimplifyCfg::new("make_shim"), + &add_call_guards::CriticalCallEdges, + &abort_unwinding_calls::AbortUnwindingCalls, + &marker::PhaseChange(MirPhase::Const), + ], + ); + + debug!("make_shim({:?}) = {:?}", instance, result); + + result +} + +#[derive(Copy, Clone, Debug, PartialEq)] +enum Adjustment { + /// Pass the receiver as-is. + Identity, + + /// We get passed `&[mut] self` and call the target with `*self`. + /// + /// This either copies `self` (if `Self: Copy`, eg. for function items), or moves out of it + /// (for `VTableShim`, which effectively is passed `&own Self`). + Deref, + + /// We get passed `self: Self` and call the target with `&mut self`. + /// + /// In this case we need to ensure that the `Self` is dropped after the call, as the callee + /// won't do it for us. + RefMut, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +enum CallKind<'tcx> { + /// Call the `FnPtr` that was passed as the receiver. + Indirect(Ty<'tcx>), + + /// Call a known `FnDef`. + Direct(DefId), +} + +fn local_decls_for_sig<'tcx>( + sig: &ty::FnSig<'tcx>, + span: Span, +) -> IndexVec<Local, LocalDecl<'tcx>> { + iter::once(LocalDecl::new(sig.output(), span)) + .chain(sig.inputs().iter().map(|ity| LocalDecl::new(*ity, span).immutable())) + .collect() +} + +fn build_drop_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, ty: Option<Ty<'tcx>>) -> Body<'tcx> { + debug!("build_drop_shim(def_id={:?}, ty={:?})", def_id, ty); + + assert!(!matches!(ty, Some(ty) if ty.is_generator())); + + let substs = if let Some(ty) = ty { + tcx.intern_substs(&[ty.into()]) + } else { + InternalSubsts::identity_for_item(tcx, def_id) + }; + let sig = tcx.bound_fn_sig(def_id).subst(tcx, substs); + let sig = tcx.erase_late_bound_regions(sig); + let span = tcx.def_span(def_id); + + let source_info = SourceInfo::outermost(span); + + let return_block = BasicBlock::new(1); + let mut blocks = IndexVec::with_capacity(2); + let block = |blocks: &mut IndexVec<_, _>, kind| { + blocks.push(BasicBlockData { + statements: vec![], + terminator: Some(Terminator { source_info, kind }), + is_cleanup: false, + }) + }; + block(&mut blocks, TerminatorKind::Goto { target: return_block }); + block(&mut blocks, TerminatorKind::Return); + + let source = MirSource::from_instance(ty::InstanceDef::DropGlue(def_id, ty)); + let mut body = + new_body(source, blocks, local_decls_for_sig(&sig, span), sig.inputs().len(), span); + + if ty.is_some() { + // The first argument (index 0), but add 1 for the return value. + let dropee_ptr = Place::from(Local::new(1 + 0)); + if tcx.sess.opts.unstable_opts.mir_emit_retag { + // Function arguments should be retagged, and we make this one raw. 
+ body.basic_blocks_mut()[START_BLOCK].statements.insert( + 0, + Statement { + source_info, + kind: StatementKind::Retag(RetagKind::Raw, Box::new(dropee_ptr)), + }, + ); + } + let patch = { + let param_env = tcx.param_env_reveal_all_normalized(def_id); + let mut elaborator = + DropShimElaborator { body: &body, patch: MirPatch::new(&body), tcx, param_env }; + let dropee = tcx.mk_place_deref(dropee_ptr); + let resume_block = elaborator.patch.resume_block(); + elaborate_drops::elaborate_drop( + &mut elaborator, + source_info, + dropee, + (), + return_block, + elaborate_drops::Unwind::To(resume_block), + START_BLOCK, + ); + elaborator.patch + }; + patch.apply(&mut body); + } + + body +} + +fn new_body<'tcx>( + source: MirSource<'tcx>, + basic_blocks: IndexVec<BasicBlock, BasicBlockData<'tcx>>, + local_decls: IndexVec<Local, LocalDecl<'tcx>>, + arg_count: usize, + span: Span, +) -> Body<'tcx> { + Body::new( + source, + basic_blocks, + IndexVec::from_elem_n( + SourceScopeData { + span, + parent_scope: None, + inlined: None, + inlined_parent_scope: None, + local_data: ClearCrossCrate::Clear, + }, + 1, + ), + local_decls, + IndexVec::new(), + arg_count, + vec![], + span, + None, + // FIXME(compiler-errors): is this correct? + None, + ) +} + +pub struct DropShimElaborator<'a, 'tcx> { + pub body: &'a Body<'tcx>, + pub patch: MirPatch<'tcx>, + pub tcx: TyCtxt<'tcx>, + pub param_env: ty::ParamEnv<'tcx>, +} + +impl fmt::Debug for DropShimElaborator<'_, '_> { + fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + Ok(()) + } +} + +impl<'a, 'tcx> DropElaborator<'a, 'tcx> for DropShimElaborator<'a, 'tcx> { + type Path = (); + + fn patch(&mut self) -> &mut MirPatch<'tcx> { + &mut self.patch + } + fn body(&self) -> &'a Body<'tcx> { + self.body + } + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + fn param_env(&self) -> ty::ParamEnv<'tcx> { + self.param_env + } + + fn drop_style(&self, _path: Self::Path, mode: DropFlagMode) -> DropStyle { + match mode { + DropFlagMode::Shallow => { + // Drops for the contained fields are "shallow" and "static" - they will simply call + // the field's own drop glue. + DropStyle::Static + } + DropFlagMode::Deep => { + // The top-level drop is "deep" and "open" - it will be elaborated to a drop ladder + // dropping each field contained in the value. + DropStyle::Open + } + } + } + + fn get_drop_flag(&mut self, _path: Self::Path) -> Option<Operand<'tcx>> { + None + } + + fn clear_drop_flag(&mut self, _location: Location, _path: Self::Path, _mode: DropFlagMode) {} + + fn field_subpath(&self, _path: Self::Path, _field: Field) -> Option<Self::Path> { + None + } + fn deref_subpath(&self, _path: Self::Path) -> Option<Self::Path> { + None + } + fn downcast_subpath(&self, _path: Self::Path, _variant: VariantIdx) -> Option<Self::Path> { + Some(()) + } + fn array_subpath(&self, _path: Self::Path, _index: u64, _size: u64) -> Option<Self::Path> { + None + } +} + +/// Builds a `Clone::clone` shim for `self_ty`. Here, `def_id` is `Clone::clone`. 
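+///
+/// For `Copy` types the shim is a single block that returns a copy of `*_1`; for tuples and
+/// closures it clones field by field, roughly (sketch) `_0.0 = Clone::clone(&(*_1).0); ...`,
+/// with cleanup blocks that drop already-cloned fields if a later clone unwinds.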
+fn build_clone_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) -> Body<'tcx> { + debug!("build_clone_shim(def_id={:?})", def_id); + + let param_env = tcx.param_env(def_id); + + let mut builder = CloneShimBuilder::new(tcx, def_id, self_ty); + let is_copy = self_ty.is_copy_modulo_regions(tcx.at(builder.span), param_env); + + let dest = Place::return_place(); + let src = tcx.mk_place_deref(Place::from(Local::new(1 + 0))); + + match self_ty.kind() { + _ if is_copy => builder.copy_shim(), + ty::Closure(_, substs) => { + builder.tuple_like_shim(dest, src, substs.as_closure().upvar_tys()) + } + ty::Tuple(..) => builder.tuple_like_shim(dest, src, self_ty.tuple_fields()), + _ => bug!("clone shim for `{:?}` which is not `Copy` and is not an aggregate", self_ty), + }; + + builder.into_mir() +} + +struct CloneShimBuilder<'tcx> { + tcx: TyCtxt<'tcx>, + def_id: DefId, + local_decls: IndexVec<Local, LocalDecl<'tcx>>, + blocks: IndexVec<BasicBlock, BasicBlockData<'tcx>>, + span: Span, + sig: ty::FnSig<'tcx>, +} + +impl<'tcx> CloneShimBuilder<'tcx> { + fn new(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) -> Self { + // we must subst the self_ty because it's + // otherwise going to be TySelf and we can't index + // or access fields of a Place of type TySelf. + let substs = tcx.mk_substs_trait(self_ty, &[]); + let sig = tcx.bound_fn_sig(def_id).subst(tcx, substs); + let sig = tcx.erase_late_bound_regions(sig); + let span = tcx.def_span(def_id); + + CloneShimBuilder { + tcx, + def_id, + local_decls: local_decls_for_sig(&sig, span), + blocks: IndexVec::new(), + span, + sig, + } + } + + fn into_mir(self) -> Body<'tcx> { + let source = MirSource::from_instance(ty::InstanceDef::CloneShim( + self.def_id, + self.sig.inputs_and_output[0], + )); + new_body(source, self.blocks, self.local_decls, self.sig.inputs().len(), self.span) + } + + fn source_info(&self) -> SourceInfo { + SourceInfo::outermost(self.span) + } + + fn block( + &mut self, + statements: Vec<Statement<'tcx>>, + kind: TerminatorKind<'tcx>, + is_cleanup: bool, + ) -> BasicBlock { + let source_info = self.source_info(); + self.blocks.push(BasicBlockData { + statements, + terminator: Some(Terminator { source_info, kind }), + is_cleanup, + }) + } + + /// Gives the index of an upcoming BasicBlock, with an offset. 
+ /// offset=0 will give you the index of the next BasicBlock, + /// offset=1 will give the index of the next-to-next block, + /// offset=-1 will give you the index of the last-created block + fn block_index_offset(&mut self, offset: usize) -> BasicBlock { + BasicBlock::new(self.blocks.len() + offset) + } + + fn make_statement(&self, kind: StatementKind<'tcx>) -> Statement<'tcx> { + Statement { source_info: self.source_info(), kind } + } + + fn copy_shim(&mut self) { + let rcvr = self.tcx.mk_place_deref(Place::from(Local::new(1 + 0))); + let ret_statement = self.make_statement(StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Use(Operand::Copy(rcvr)), + )))); + self.block(vec![ret_statement], TerminatorKind::Return, false); + } + + fn make_place(&mut self, mutability: Mutability, ty: Ty<'tcx>) -> Place<'tcx> { + let span = self.span; + let mut local = LocalDecl::new(ty, span); + if mutability == Mutability::Not { + local = local.immutable(); + } + Place::from(self.local_decls.push(local)) + } + + fn make_clone_call( + &mut self, + dest: Place<'tcx>, + src: Place<'tcx>, + ty: Ty<'tcx>, + next: BasicBlock, + cleanup: BasicBlock, + ) { + let tcx = self.tcx; + + let substs = tcx.mk_substs_trait(ty, &[]); + + // `func == Clone::clone(&ty) -> ty` + let func_ty = tcx.mk_fn_def(self.def_id, substs); + let func = Operand::Constant(Box::new(Constant { + span: self.span, + user_ty: None, + literal: ConstantKind::zero_sized(func_ty), + })); + + let ref_loc = self.make_place( + Mutability::Not, + tcx.mk_ref(tcx.lifetimes.re_erased, ty::TypeAndMut { ty, mutbl: hir::Mutability::Not }), + ); + + // `let ref_loc: &ty = &src;` + let statement = self.make_statement(StatementKind::Assign(Box::new(( + ref_loc, + Rvalue::Ref(tcx.lifetimes.re_erased, BorrowKind::Shared, src), + )))); + + // `let loc = Clone::clone(ref_loc);` + self.block( + vec![statement], + TerminatorKind::Call { + func, + args: vec![Operand::Move(ref_loc)], + destination: dest, + target: Some(next), + cleanup: Some(cleanup), + from_hir_call: true, + fn_span: self.span, + }, + false, + ); + } + + fn tuple_like_shim<I>(&mut self, dest: Place<'tcx>, src: Place<'tcx>, tys: I) + where + I: IntoIterator<Item = Ty<'tcx>>, + { + let mut previous_field = None; + for (i, ity) in tys.into_iter().enumerate() { + let field = Field::new(i); + let src_field = self.tcx.mk_place_field(src, field, ity); + + let dest_field = self.tcx.mk_place_field(dest, field, ity); + + // #(2i + 1) is the cleanup block for the previous clone operation + let cleanup_block = self.block_index_offset(1); + // #(2i + 2) is the next cloning block + // (or the Return terminator if this is the last block) + let next_block = self.block_index_offset(2); + + // BB #(2i) + // `dest.i = Clone::clone(&src.i);` + // Goto #(2i + 2) if ok, #(2i + 1) if unwinding happens. + self.make_clone_call(dest_field, src_field, ity, next_block, cleanup_block); + + // BB #(2i + 1) (cleanup) + if let Some((previous_field, previous_cleanup)) = previous_field.take() { + // Drop previous field and goto previous cleanup block. + self.block( + vec![], + TerminatorKind::Drop { + place: previous_field, + target: previous_cleanup, + unwind: None, + }, + true, + ); + } else { + // Nothing to drop, just resume. + self.block(vec![], TerminatorKind::Resume, true); + } + + previous_field = Some((dest_field, cleanup_block)); + } + + self.block(vec![], TerminatorKind::Return, false); + } +} + +/// Builds a "call" shim for `instance`. 
The shim calls the function specified by `call_kind`, +/// first adjusting its first argument according to `rcvr_adjustment`. +fn build_call_shim<'tcx>( + tcx: TyCtxt<'tcx>, + instance: ty::InstanceDef<'tcx>, + rcvr_adjustment: Option<Adjustment>, + call_kind: CallKind<'tcx>, +) -> Body<'tcx> { + debug!( + "build_call_shim(instance={:?}, rcvr_adjustment={:?}, call_kind={:?})", + instance, rcvr_adjustment, call_kind + ); + + // `FnPtrShim` contains the fn pointer type that a call shim is being built for - this is used + // to substitute into the signature of the shim. It is not necessary for users of this + // MIR body to perform further substitutions (see `InstanceDef::has_polymorphic_mir_body`). + let (sig_substs, untuple_args) = if let ty::InstanceDef::FnPtrShim(_, ty) = instance { + let sig = tcx.erase_late_bound_regions(ty.fn_sig(tcx)); + + let untuple_args = sig.inputs(); + + // Create substitutions for the `Self` and `Args` generic parameters of the shim body. + let arg_tup = tcx.mk_tup(untuple_args.iter()); + let sig_substs = tcx.mk_substs_trait(ty, &[ty::subst::GenericArg::from(arg_tup)]); + + (Some(sig_substs), Some(untuple_args)) + } else { + (None, None) + }; + + let def_id = instance.def_id(); + let sig = tcx.bound_fn_sig(def_id); + let sig = sig.map_bound(|sig| tcx.erase_late_bound_regions(sig)); + + assert_eq!(sig_substs.is_some(), !instance.has_polymorphic_mir_body()); + let mut sig = + if let Some(sig_substs) = sig_substs { sig.subst(tcx, sig_substs) } else { sig.0 }; + + if let CallKind::Indirect(fnty) = call_kind { + // `sig` determines our local decls, and thus the callee type in the `Call` terminator. This + // can only be an `FnDef` or `FnPtr`, but currently will be `Self` since the types come from + // the implemented `FnX` trait. + + // Apply the opposite adjustment to the MIR input. + let mut inputs_and_output = sig.inputs_and_output.to_vec(); + + // Initial signature is `fn(&? Self, Args) -> Self::Output` where `Args` is a tuple of the + // fn arguments. `Self` may be passed via (im)mutable reference or by-value. + assert_eq!(inputs_and_output.len(), 3); + + // `Self` is always the original fn type `ty`. The MIR call terminator is only defined for + // `FnDef` and `FnPtr` callees, not the `Self` type param. + let self_arg = &mut inputs_and_output[0]; + *self_arg = match rcvr_adjustment.unwrap() { + Adjustment::Identity => fnty, + Adjustment::Deref => tcx.mk_imm_ptr(fnty), + Adjustment::RefMut => tcx.mk_mut_ptr(fnty), + }; + sig.inputs_and_output = tcx.intern_type_list(&inputs_and_output); + } + + // FIXME(eddyb) avoid having this snippet both here and in + // `Instance::fn_sig` (introduce `InstanceDef::fn_sig`?). + if let ty::InstanceDef::VTableShim(..) = instance { + // Modify fn(self, ...) to fn(self: *mut Self, ...) 
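+ // The shim body then dereferences this raw pointer when building the callee arguments
+ // (see `Adjustment::Deref` above), so the callee still receives `Self` by value.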
+ let mut inputs_and_output = sig.inputs_and_output.to_vec(); + let self_arg = &mut inputs_and_output[0]; + debug_assert!(tcx.generics_of(def_id).has_self && *self_arg == tcx.types.self_param); + *self_arg = tcx.mk_mut_ptr(*self_arg); + sig.inputs_and_output = tcx.intern_type_list(&inputs_and_output); + } + + let span = tcx.def_span(def_id); + + debug!("build_call_shim: sig={:?}", sig); + + let mut local_decls = local_decls_for_sig(&sig, span); + let source_info = SourceInfo::outermost(span); + + let rcvr_place = || { + assert!(rcvr_adjustment.is_some()); + Place::from(Local::new(1 + 0)) + }; + let mut statements = vec![]; + + let rcvr = rcvr_adjustment.map(|rcvr_adjustment| match rcvr_adjustment { + Adjustment::Identity => Operand::Move(rcvr_place()), + Adjustment::Deref => Operand::Move(tcx.mk_place_deref(rcvr_place())), + Adjustment::RefMut => { + // let rcvr = &mut rcvr; + let ref_rcvr = local_decls.push( + LocalDecl::new( + tcx.mk_ref( + tcx.lifetimes.re_erased, + ty::TypeAndMut { ty: sig.inputs()[0], mutbl: hir::Mutability::Mut }, + ), + span, + ) + .immutable(), + ); + let borrow_kind = BorrowKind::Mut { allow_two_phase_borrow: false }; + statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + Place::from(ref_rcvr), + Rvalue::Ref(tcx.lifetimes.re_erased, borrow_kind, rcvr_place()), + ))), + }); + Operand::Move(Place::from(ref_rcvr)) + } + }); + + let (callee, mut args) = match call_kind { + // `FnPtr` call has no receiver. Args are untupled below. + CallKind::Indirect(_) => (rcvr.unwrap(), vec![]), + + // `FnDef` call with optional receiver. + CallKind::Direct(def_id) => { + let ty = tcx.type_of(def_id); + ( + Operand::Constant(Box::new(Constant { + span, + user_ty: None, + literal: ConstantKind::zero_sized(ty), + })), + rcvr.into_iter().collect::<Vec<_>>(), + ) + } + }; + + let mut arg_range = 0..sig.inputs().len(); + + // Take the `self` ("receiver") argument out of the range (it's adjusted above). + if rcvr_adjustment.is_some() { + arg_range.start += 1; + } + + // Take the last argument, if we need to untuple it (handled below). + if untuple_args.is_some() { + arg_range.end -= 1; + } + + // Pass all of the non-special arguments directly. + args.extend(arg_range.map(|i| Operand::Move(Place::from(Local::new(1 + i))))); + + // Untuple the last argument, if we have to. 
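+ // E.g. (sketch) for a `rust-call` shim around `fn(A, B)`: the arguments arrive as a
+ // single tuple local `_2: (A, B)`, which is expanded into `move (_2.0)` and `move (_2.1)`.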
+ if let Some(untuple_args) = untuple_args { + let tuple_arg = Local::new(1 + (sig.inputs().len() - 1)); + args.extend(untuple_args.iter().enumerate().map(|(i, ity)| { + Operand::Move(tcx.mk_place_field(Place::from(tuple_arg), Field::new(i), *ity)) + })); + } + + let n_blocks = if let Some(Adjustment::RefMut) = rcvr_adjustment { 5 } else { 2 }; + let mut blocks = IndexVec::with_capacity(n_blocks); + let block = |blocks: &mut IndexVec<_, _>, statements, kind, is_cleanup| { + blocks.push(BasicBlockData { + statements, + terminator: Some(Terminator { source_info, kind }), + is_cleanup, + }) + }; + + // BB #0 + block( + &mut blocks, + statements, + TerminatorKind::Call { + func: callee, + args, + destination: Place::return_place(), + target: Some(BasicBlock::new(1)), + cleanup: if let Some(Adjustment::RefMut) = rcvr_adjustment { + Some(BasicBlock::new(3)) + } else { + None + }, + from_hir_call: true, + fn_span: span, + }, + false, + ); + + if let Some(Adjustment::RefMut) = rcvr_adjustment { + // BB #1 - drop for Self + block( + &mut blocks, + vec![], + TerminatorKind::Drop { place: rcvr_place(), target: BasicBlock::new(2), unwind: None }, + false, + ); + } + // BB #1/#2 - return + block(&mut blocks, vec![], TerminatorKind::Return, false); + if let Some(Adjustment::RefMut) = rcvr_adjustment { + // BB #3 - drop if closure panics + block( + &mut blocks, + vec![], + TerminatorKind::Drop { place: rcvr_place(), target: BasicBlock::new(4), unwind: None }, + true, + ); + + // BB #4 - resume + block(&mut blocks, vec![], TerminatorKind::Resume, true); + } + + let mut body = + new_body(MirSource::from_instance(instance), blocks, local_decls, sig.inputs().len(), span); + + if let Abi::RustCall = sig.abi { + body.spread_arg = Some(Local::new(sig.inputs().len())); + } + + body +} + +pub fn build_adt_ctor(tcx: TyCtxt<'_>, ctor_id: DefId) -> Body<'_> { + debug_assert!(tcx.is_constructor(ctor_id)); + + let param_env = tcx.param_env(ctor_id); + + // Normalize the sig. 
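+ // E.g. (sketch) for `struct Foo(u8, u16)`, the normalized ctor signature is
+ // `fn(u8, u16) -> Foo`, and the body built below is just the aggregate expansion
+ // `(_0.0) = _1; (_0.1) = _2; return;`.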
+ let sig = tcx.fn_sig(ctor_id).no_bound_vars().expect("LBR in ADT constructor signature"); + let sig = tcx.normalize_erasing_regions(param_env, sig); + + let ty::Adt(adt_def, substs) = sig.output().kind() else { + bug!("unexpected type for ADT ctor {:?}", sig.output()); + }; + + debug!("build_ctor: ctor_id={:?} sig={:?}", ctor_id, sig); + + let span = tcx.def_span(ctor_id); + + let local_decls = local_decls_for_sig(&sig, span); + + let source_info = SourceInfo::outermost(span); + + let variant_index = if adt_def.is_enum() { + adt_def.variant_index_with_ctor_id(ctor_id) + } else { + VariantIdx::new(0) + }; + + // Generate the following MIR: + // + // (return as Variant).field0 = arg0; + // (return as Variant).field1 = arg1; + // + // return; + debug!("build_ctor: variant_index={:?}", variant_index); + + let statements = expand_aggregate( + Place::return_place(), + adt_def.variant(variant_index).fields.iter().enumerate().map(|(idx, field_def)| { + (Operand::Move(Place::from(Local::new(idx + 1))), field_def.ty(tcx, substs)) + }), + AggregateKind::Adt(adt_def.did(), variant_index, substs, None, None), + source_info, + tcx, + ) + .collect(); + + let start_block = BasicBlockData { + statements, + terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), + is_cleanup: false, + }; + + let source = MirSource::item(ctor_id); + let body = new_body( + source, + IndexVec::from_elem_n(start_block, 1), + local_decls, + sig.inputs().len(), + span, + ); + + rustc_middle::mir::dump_mir(tcx, None, "mir_map", &0, &body, |_, _| Ok(())); + + body +} diff --git a/compiler/rustc_mir_transform/src/simplify.rs b/compiler/rustc_mir_transform/src/simplify.rs new file mode 100644 index 000000000..180f4c7dc --- /dev/null +++ b/compiler/rustc_mir_transform/src/simplify.rs @@ -0,0 +1,590 @@ +//! A number of passes which remove various redundancies in the CFG. +//! +//! The `SimplifyCfg` pass gets rid of unnecessary blocks in the CFG, whereas the `SimplifyLocals` +//! gets rid of all the unnecessary local variable declarations. +//! +//! The `SimplifyLocals` pass is kinda expensive and therefore not very suitable to be run often. +//! Most of the passes should not care or be impacted in meaningful ways due to extra locals +//! either, so running the pass once, right before codegen, should suffice. +//! +//! On the other side of the spectrum, the `SimplifyCfg` pass is considerably cheap to run, thus +//! one should run it after every pass which may modify CFG in significant ways. This pass must +//! also be run before any analysis passes because it removes dead blocks, and some of these can be +//! ill-typed. +//! +//! The cause of this typing issue is typeck allowing most blocks whose end is not reachable have +//! an arbitrary return type, rather than having the usual () return type (as a note, typeck's +//! notion of reachability is in fact slightly weaker than MIR CFG reachability - see #31617). A +//! standard example of the situation is: +//! +//! ```rust +//! fn example() { +//! let _a: char = { return; }; +//! } +//! ``` +//! +//! Here the block (`{ return; }`) has the return type `char`, rather than `()`, but the MIR we +//! naively generate still contains the `_a = ()` write in the unreachable block "after" the +//! return. 
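+//!
+//! That `_a = ()` write is ill-typed (the destination has type `char`), which is why the dead
+//! blocks that contain such writes must be removed before any pass that re-checks statement
+//! types.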
+
+use crate::MirPass;
+use rustc_data_structures::fx::FxHashSet;
+use rustc_index::vec::{Idx, IndexVec};
+use rustc_middle::mir::coverage::*;
+use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
+use rustc_middle::mir::*;
+use rustc_middle::ty::TyCtxt;
+use smallvec::SmallVec;
+use std::borrow::Cow;
+use std::convert::TryInto;
+
+pub struct SimplifyCfg {
+    label: String,
+}
+
+impl SimplifyCfg {
+    pub fn new(label: &str) -> Self {
+        SimplifyCfg { label: format!("SimplifyCfg-{}", label) }
+    }
+}
+
+pub fn simplify_cfg<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+    CfgSimplifier::new(body).simplify();
+    remove_dead_blocks(tcx, body);
+
+    // FIXME: Should probably be moved into some kind of pass manager
+    body.basic_blocks_mut().raw.shrink_to_fit();
+}
+
+impl<'tcx> MirPass<'tcx> for SimplifyCfg {
+    fn name(&self) -> Cow<'_, str> {
+        Cow::Borrowed(&self.label)
+    }
+
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        debug!("SimplifyCfg({:?}) - simplifying {:?}", self.label, body.source);
+        simplify_cfg(tcx, body);
+    }
+}
+
+pub struct CfgSimplifier<'a, 'tcx> {
+    basic_blocks: &'a mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
+    pred_count: IndexVec<BasicBlock, u32>,
+}
+
+impl<'a, 'tcx> CfgSimplifier<'a, 'tcx> {
+    pub fn new(body: &'a mut Body<'tcx>) -> Self {
+        let mut pred_count = IndexVec::from_elem(0u32, body.basic_blocks());
+
+        // we can't use mir.predecessors() here because that counts
+        // dead blocks, which we don't want.
+        pred_count[START_BLOCK] = 1;
+
+        for (_, data) in traversal::preorder(body) {
+            if let Some(ref term) = data.terminator {
+                for tgt in term.successors() {
+                    pred_count[tgt] += 1;
+                }
+            }
+        }
+
+        let basic_blocks = body.basic_blocks_mut();
+
+        CfgSimplifier { basic_blocks, pred_count }
+    }
+
+    pub fn simplify(mut self) {
+        self.strip_nops();
+
+        // Vec of the blocks that should be merged. We store the indices here, instead of the
+        // statements themselves, to avoid moving the (relatively) large statements twice.
+        // We do not push the statements directly into the target block (`bb`) as that is slower
+        // due to additional reallocations.
+        let mut merged_blocks = Vec::new();
+        loop {
+            let mut changed = false;
+
+            for bb in self.basic_blocks.indices() {
+                if self.pred_count[bb] == 0 {
+                    continue;
+                }
+
+                debug!("simplifying {:?}", bb);
+
+                let mut terminator =
+                    self.basic_blocks[bb].terminator.take().expect("invalid terminator state");
+
+                for successor in terminator.successors_mut() {
+                    self.collapse_goto_chain(successor, &mut changed);
+                }
+
+                let mut inner_changed = true;
+                merged_blocks.clear();
+                while inner_changed {
+                    inner_changed = false;
+                    inner_changed |= self.simplify_branch(&mut terminator);
+                    inner_changed |= self.merge_successor(&mut merged_blocks, &mut terminator);
+                    changed |= inner_changed;
+                }
+
+                let statements_to_merge =
+                    merged_blocks.iter().map(|&i| self.basic_blocks[i].statements.len()).sum();
+
+                if statements_to_merge > 0 {
+                    let mut statements = std::mem::take(&mut self.basic_blocks[bb].statements);
+                    statements.reserve(statements_to_merge);
+                    for &from in &merged_blocks {
+                        statements.append(&mut self.basic_blocks[from].statements);
+                    }
+                    self.basic_blocks[bb].statements = statements;
+                }
+
+                self.basic_blocks[bb].terminator = Some(terminator);
+            }
+
+            if !changed {
+                break;
+            }
+        }
+    }
+
+    /// This function will return `None` if
+    /// * the block has statements
+    /// * the block has a terminator other than `goto`
+    /// * the block has no terminator (meaning some other part of the current optimization stole it)
+    fn take_terminator_if_simple_goto(&mut self, bb: BasicBlock) -> Option<Terminator<'tcx>> {
+        match self.basic_blocks[bb] {
+            BasicBlockData {
+                ref statements,
+                terminator:
+                    ref mut terminator @ Some(Terminator { kind: TerminatorKind::Goto { .. }, .. }),
+                ..
+            } if statements.is_empty() => terminator.take(),
+            // if `terminator` is None, this means we are in a loop. In that
+            // case, let the whole loop collapse to its entry.
+            _ => None,
+        }
+    }
+
+    /// Collapse a goto chain starting from `start`
+    fn collapse_goto_chain(&mut self, start: &mut BasicBlock, changed: &mut bool) {
+        // Using `SmallVec` here, because in some logs on libcore oli-obk saw many single-element
+        // goto chains. We should probably benchmark different sizes.
+        let mut terminators: SmallVec<[_; 1]> = Default::default();
+        let mut current = *start;
+        while let Some(terminator) = self.take_terminator_if_simple_goto(current) {
+            let Terminator { kind: TerminatorKind::Goto { target }, .. } = terminator else {
+                unreachable!();
+            };
+            terminators.push((current, terminator));
+            current = target;
+        }
+        let last = current;
+        *start = last;
+        while let Some((current, mut terminator)) = terminators.pop() {
+            let Terminator { kind: TerminatorKind::Goto { ref mut target }, .. } = terminator else {
+                unreachable!();
+            };
+            *changed |= *target != last;
+            *target = last;
+            debug!("collapsing goto chain from {:?} to {:?}", current, target);
+
+            if self.pred_count[current] == 1 {
+                // This is the last reference to current, so the pred-count to
+                // target is moved into the current block.
+                self.pred_count[current] = 0;
+            } else {
+                self.pred_count[*target] += 1;
+                self.pred_count[current] -= 1;
+            }
+            self.basic_blocks[current].terminator = Some(terminator);
+        }
+    }
+
+    // merge a block with 1 `goto` predecessor to its parent
+    fn merge_successor(
+        &mut self,
+        merged_blocks: &mut Vec<BasicBlock>,
+        terminator: &mut Terminator<'tcx>,
+    ) -> bool {
+        let target = match terminator.kind {
+            TerminatorKind::Goto { target } if self.pred_count[target] == 1 => target,
+            _ => return false,
+        };
+
+        debug!("merging block {:?} into {:?}", target, terminator);
+        *terminator = match self.basic_blocks[target].terminator.take() {
+            Some(terminator) => terminator,
+            None => {
+                // unreachable loop - this should not be possible, as we
+                // don't strand blocks, but handle it correctly.
+                return false;
+            }
+        };
+
+        merged_blocks.push(target);
+        self.pred_count[target] = 0;
+
+        true
+    }
+
+    // turn a branch with all successors identical into a goto
+    fn simplify_branch(&mut self, terminator: &mut Terminator<'tcx>) -> bool {
+        match terminator.kind {
+            TerminatorKind::SwitchInt { .. } => {}
+            _ => return false,
+        };
+
+        let first_succ = {
+            if let Some(first_succ) = terminator.successors().next() {
+                if terminator.successors().all(|s| s == first_succ) {
+                    let count = terminator.successors().count();
+                    self.pred_count[first_succ] -= (count - 1) as u32;
+                    first_succ
+                } else {
+                    return false;
+                }
+            } else {
+                return false;
+            }
+        };
+
+        debug!("simplifying branch {:?}", terminator);
+        terminator.kind = TerminatorKind::Goto { target: first_succ };
+        true
+    }
+
+    fn strip_nops(&mut self) {
+        for blk in self.basic_blocks.iter_mut() {
+            blk.statements.retain(|stmt| !matches!(stmt.kind, StatementKind::Nop))
+        }
+    }
+}
+
+pub fn remove_dead_blocks<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+    let reachable = traversal::reachable_as_bitset(body);
+    let num_blocks = body.basic_blocks().len();
+    if num_blocks == reachable.count() {
+        return;
+    }
+
+    let basic_blocks = body.basic_blocks.as_mut();
+    let source_scopes = &body.source_scopes;
+    let mut replacements: Vec<_> = (0..num_blocks).map(BasicBlock::new).collect();
+    let mut used_blocks = 0;
+    for alive_index in reachable.iter() {
+        let alive_index = alive_index.index();
+        replacements[alive_index] = BasicBlock::new(used_blocks);
+        if alive_index != used_blocks {
+            // Swap the next alive block data with the current available slot. Since
+            // alive_index is non-decreasing this is a valid operation.
+            basic_blocks.raw.swap(alive_index, used_blocks);
+        }
+        used_blocks += 1;
+    }
+
+    if tcx.sess.instrument_coverage() {
+        save_unreachable_coverage(basic_blocks, source_scopes, used_blocks);
+    }
+
+    basic_blocks.raw.truncate(used_blocks);
+
+    for block in basic_blocks {
+        for target in block.terminator_mut().successors_mut() {
+            *target = replacements[target.index()];
+        }
+    }
+}
+
+/// Some MIR transforms can determine at compile time that a sequence of
+/// statements will never be executed, so they can be dropped from the MIR.
+/// For example, an `if` or `else` block that is guaranteed to never be executed
+/// because its condition can be evaluated at compile time, such as by const
+/// evaluation: `if false { ... }`.
+///
+/// Those statements are bypassed by redirecting paths in the CFG around the
+/// `dead blocks`; but with `-C instrument-coverage`, the dead blocks usually
+/// include `Coverage` statements representing the Rust source code regions to
+/// be counted at runtime. Without these `Coverage` statements, the regions are
+/// lost, and the Rust source code will show no coverage information.
+///
+/// What we want to show in a coverage report is the dead code with coverage
+/// counts of `0`. To do this, we need to save the code regions, by injecting
+/// `Unreachable` coverage statements. These are non-executable statements whose
+/// code regions are still recorded in the coverage map, representing regions
+/// with `0` executions.
+///
+/// If there are no live `Counter` `Coverage` statements remaining, we remove all
+/// `Coverage` statements along with the dead blocks, since LLVM requires at
+/// least one counter per function (it is needed to attach the `function_hash`
+/// to the counter's call to the LLVM intrinsic `instrprof.increment()`).
+///
+/// The `generator::StateTransform` MIR pass and MIR inlining can create
+/// atypical conditions, where all live `Counter`s are dropped from the MIR.
+///
+/// With MIR inlining we can have coverage counters belonging to different
+/// instances in a single body, so the strategy described above is applied to
+/// coverage counters from each instance individually.
+fn save_unreachable_coverage(
+    basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
+    source_scopes: &IndexVec<SourceScope, SourceScopeData<'_>>,
+    first_dead_block: usize,
+) {
+    // Identify instances that still have some live coverage counters left.
+    let mut live = FxHashSet::default();
+    for basic_block in &basic_blocks.raw[0..first_dead_block] {
+        for statement in &basic_block.statements {
+            let StatementKind::Coverage(coverage) = &statement.kind else { continue };
+            let CoverageKind::Counter { .. } = coverage.kind else { continue };
+            let instance = statement.source_info.scope.inlined_instance(source_scopes);
+            live.insert(instance);
+        }
+    }
+
+    for block in &mut basic_blocks.raw[..first_dead_block] {
+        for statement in &mut block.statements {
+            let StatementKind::Coverage(_) = &statement.kind else { continue };
+            let instance = statement.source_info.scope.inlined_instance(source_scopes);
+            if !live.contains(&instance) {
+                statement.make_nop();
+            }
+        }
+    }
+
+    if live.is_empty() {
+        return;
+    }
+
+    // Retain coverage for instances that still have some live counters left.
+    let mut retained_coverage = Vec::new();
+    for dead_block in &basic_blocks.raw[first_dead_block..]
{
+        for statement in &dead_block.statements {
+            let StatementKind::Coverage(coverage) = &statement.kind else { continue };
+            let Some(code_region) = &coverage.code_region else { continue };
+            let instance = statement.source_info.scope.inlined_instance(source_scopes);
+            if live.contains(&instance) {
+                retained_coverage.push((statement.source_info, code_region.clone()));
+            }
+        }
+    }
+
+    let start_block = &mut basic_blocks[START_BLOCK];
+    start_block.statements.extend(retained_coverage.into_iter().map(
+        |(source_info, code_region)| Statement {
+            source_info,
+            kind: StatementKind::Coverage(Box::new(Coverage {
+                kind: CoverageKind::Unreachable,
+                code_region: Some(code_region),
+            })),
+        },
+    ));
+}
+
+pub struct SimplifyLocals;
+
+impl<'tcx> MirPass<'tcx> for SimplifyLocals {
+    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+        sess.mir_opt_level() > 0
+    }
+
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        trace!("running SimplifyLocals on {:?}", body.source);
+        simplify_locals(body, tcx);
+    }
+}
+
+pub fn simplify_locals<'tcx>(body: &mut Body<'tcx>, tcx: TyCtxt<'tcx>) {
+    // First, we're going to get a count of *actual* uses for every `Local`.
+    let mut used_locals = UsedLocals::new(body);
+
+    // Next, we're going to remove any `Local` with zero actual uses. When we remove those
+    // `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals`
+    // count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from
+    // `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a
+    // fixed point where there are no more unused locals.
+    remove_unused_definitions(&mut used_locals, body);
+
+    // Finally, we'll actually do the work of shrinking `body.local_decls` and remapping the `Local`s.
+    let map = make_local_map(&mut body.local_decls, &used_locals);
+
+    // Only bother running the `LocalUpdater` if we actually found locals to remove.
+    if map.iter().any(Option::is_none) {
+        // Update references to all vars and tmps now
+        let mut updater = LocalUpdater { map, tcx };
+        updater.visit_body(body);
+
+        body.local_decls.shrink_to_fit();
+    }
+}
+
+/// Construct the mapping while swapping unused entries out of the `vec`.
+fn make_local_map<V>(
+    local_decls: &mut IndexVec<Local, V>,
+    used_locals: &UsedLocals,
+) -> IndexVec<Local, Option<Local>> {
+    let mut map: IndexVec<Local, Option<Local>> = IndexVec::from_elem(None, &*local_decls);
+    let mut used = Local::new(0);
+
+    for alive_index in local_decls.indices() {
+        // `is_used` treats the `RETURN_PLACE` and arguments as used.
+        if !used_locals.is_used(alive_index) {
+            continue;
+        }
+
+        map[alive_index] = Some(used);
+        if alive_index != used {
+            local_decls.swap(alive_index, used);
+        }
+        used.increment_by(1);
+    }
+    local_decls.truncate(used.index());
+    map
+}
+
+/// Keeps track of used & unused locals.
+struct UsedLocals {
+    increment: bool,
+    arg_count: u32,
+    use_count: IndexVec<Local, u32>,
+}
+
+impl UsedLocals {
+    /// Determines which locals are used & unused in the given body.
+    fn new(body: &Body<'_>) -> Self {
+        let mut this = Self {
+            increment: true,
+            arg_count: body.arg_count.try_into().unwrap(),
+            use_count: IndexVec::from_elem(0, &body.local_decls),
+        };
+        this.visit_body(body);
+        this
+    }
+
+    /// Checks if the local is used.
+    ///
+    /// Return place and arguments are always considered used.
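+    /// (Illustrative note, not in the original source: with `arg_count == 2`,
+    /// locals `_0` (the return place), `_1`, and `_2` always count as used;
+    /// the `local.as_u32() <= self.arg_count` comparison below encodes that.)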
+ fn is_used(&self, local: Local) -> bool { + trace!("is_used({:?}): use_count: {:?}", local, self.use_count[local]); + local.as_u32() <= self.arg_count || self.use_count[local] != 0 + } + + /// Updates the use counts to reflect the removal of given statement. + fn statement_removed(&mut self, statement: &Statement<'_>) { + self.increment = false; + + // The location of the statement is irrelevant. + let location = Location { block: START_BLOCK, statement_index: 0 }; + self.visit_statement(statement, location); + } + + /// Visits a left-hand side of an assignment. + fn visit_lhs(&mut self, place: &Place<'_>, location: Location) { + if place.is_indirect() { + // A use, not a definition. + self.visit_place(place, PlaceContext::MutatingUse(MutatingUseContext::Store), location); + } else { + // A definition. The base local itself is not visited, so this occurrence is not counted + // toward its use count. There might be other locals still, used in an indexing + // projection. + self.super_projection( + place.as_ref(), + PlaceContext::MutatingUse(MutatingUseContext::Projection), + location, + ); + } + } +} + +impl<'tcx> Visitor<'tcx> for UsedLocals { + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + match statement.kind { + StatementKind::CopyNonOverlapping(..) + | StatementKind::Retag(..) + | StatementKind::Coverage(..) + | StatementKind::FakeRead(..) + | StatementKind::AscribeUserType(..) => { + self.super_statement(statement, location); + } + + StatementKind::Nop => {} + + StatementKind::StorageLive(_local) | StatementKind::StorageDead(_local) => {} + + StatementKind::Assign(box (ref place, ref rvalue)) => { + if rvalue.is_safe_to_remove() { + self.visit_lhs(place, location); + self.visit_rvalue(rvalue, location); + } else { + self.super_statement(statement, location); + } + } + + StatementKind::SetDiscriminant { ref place, variant_index: _ } + | StatementKind::Deinit(ref place) => { + self.visit_lhs(place, location); + } + } + } + + fn visit_local(&mut self, local: Local, _ctx: PlaceContext, _location: Location) { + if self.increment { + self.use_count[local] += 1; + } else { + assert_ne!(self.use_count[local], 0); + self.use_count[local] -= 1; + } + } +} + +/// Removes unused definitions. Updates the used locals to reflect the changes made. +fn remove_unused_definitions(used_locals: &mut UsedLocals, body: &mut Body<'_>) { + // The use counts are updated as we remove the statements. A local might become unused + // during the retain operation, leading to a temporary inconsistency (storage statements or + // definitions referencing the local might remain). For correctness it is crucial that this + // computation reaches a fixed point. + + let mut modified = true; + while modified { + modified = false; + + for data in body.basic_blocks_mut() { + // Remove unnecessary StorageLive and StorageDead annotations. + data.statements.retain(|statement| { + let keep = match &statement.kind { + StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => { + used_locals.is_used(*local) + } + StatementKind::Assign(box (place, _)) => used_locals.is_used(place.local), + + StatementKind::SetDiscriminant { ref place, .. 
} + | StatementKind::Deinit(ref place) => used_locals.is_used(place.local), + _ => true, + }; + + if !keep { + trace!("removing statement {:?}", statement); + modified = true; + used_locals.statement_removed(statement); + } + + keep + }); + } + } +} + +struct LocalUpdater<'tcx> { + map: IndexVec<Local, Option<Local>>, + tcx: TyCtxt<'tcx>, +} + +impl<'tcx> MutVisitor<'tcx> for LocalUpdater<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_local(&mut self, l: &mut Local, _: PlaceContext, _: Location) { + *l = self.map[*l].unwrap(); + } +} diff --git a/compiler/rustc_mir_transform/src/simplify_branches.rs b/compiler/rustc_mir_transform/src/simplify_branches.rs new file mode 100644 index 000000000..3bbae5b89 --- /dev/null +++ b/compiler/rustc_mir_transform/src/simplify_branches.rs @@ -0,0 +1,52 @@ +use crate::MirPass; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +use std::borrow::Cow; + +/// A pass that replaces a branch with a goto when its condition is known. +pub struct SimplifyConstCondition { + label: String, +} + +impl SimplifyConstCondition { + pub fn new(label: &str) -> Self { + SimplifyConstCondition { label: format!("SimplifyConstCondition-{}", label) } + } +} + +impl<'tcx> MirPass<'tcx> for SimplifyConstCondition { + fn name(&self) -> Cow<'_, str> { + Cow::Borrowed(&self.label) + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let param_env = tcx.param_env(body.source.def_id()); + for block in body.basic_blocks_mut() { + let terminator = block.terminator_mut(); + terminator.kind = match terminator.kind { + TerminatorKind::SwitchInt { + discr: Operand::Constant(ref c), + switch_ty, + ref targets, + .. + } => { + let constant = c.literal.try_eval_bits(tcx, param_env, switch_ty); + if let Some(constant) = constant { + let target = targets.target_for_value(constant); + TerminatorKind::Goto { target } + } else { + continue; + } + } + TerminatorKind::Assert { + target, cond: Operand::Constant(ref c), expected, .. + } => match c.literal.try_eval_bool(tcx, param_env) { + Some(v) if v == expected => TerminatorKind::Goto { target }, + _ => continue, + }, + _ => continue, + }; + } + } +} diff --git a/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs b/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs new file mode 100644 index 000000000..bbfaace70 --- /dev/null +++ b/compiler/rustc_mir_transform/src/simplify_comparison_integral.rs @@ -0,0 +1,242 @@ +use std::iter; + +use super::MirPass; +use rustc_middle::{ + mir::{ + interpret::Scalar, BasicBlock, BinOp, Body, Operand, Place, Rvalue, Statement, + StatementKind, SwitchTargets, TerminatorKind, + }, + ty::{Ty, TyCtxt}, +}; + +/// Pass to convert `if` conditions on integrals into switches on the integral. 
+/// For example, it turns something like
+///
+/// ```ignore (MIR)
+/// _3 = Eq(move _4, const 43i32);
+/// StorageDead(_4);
+/// switchInt(_3) -> [false: bb2, otherwise: bb3];
+/// ```
+///
+/// into:
+///
+/// ```ignore (MIR)
+/// switchInt(_4) -> [43i32: bb3, otherwise: bb2];
+/// ```
+pub struct SimplifyComparisonIntegral;
+
+impl<'tcx> MirPass<'tcx> for SimplifyComparisonIntegral {
+    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+        sess.mir_opt_level() > 0
+    }
+
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        trace!("Running SimplifyComparisonIntegral on {:?}", body.source);
+
+        let helper = OptimizationFinder { body };
+        let opts = helper.find_optimizations();
+        let mut storage_deads_to_insert = vec![];
+        let mut storage_deads_to_remove: Vec<(usize, BasicBlock)> = vec![];
+        let param_env = tcx.param_env(body.source.def_id());
+        for opt in opts {
+            trace!("SUCCESS: Applying {:?}", opt);
+            // replace terminator with a switchInt that switches on the integer directly
+            let bbs = &mut body.basic_blocks_mut();
+            let bb = &mut bbs[opt.bb_idx];
+            let new_value = match opt.branch_value_scalar {
+                Scalar::Int(int) => {
+                    let layout = tcx
+                        .layout_of(param_env.and(opt.branch_value_ty))
+                        .expect("if we have an evaluated constant we must know the layout");
+                    int.assert_bits(layout.size)
+                }
+                Scalar::Ptr(..) => continue,
+            };
+            const FALSE: u128 = 0;
+
+            let mut new_targets = opt.targets;
+            let first_value = new_targets.iter().next().unwrap().0;
+            let first_is_false_target = first_value == FALSE;
+            match opt.op {
+                BinOp::Eq => {
+                    // if the assignment was Eq we want the true case to be first
+                    if first_is_false_target {
+                        new_targets.all_targets_mut().swap(0, 1);
+                    }
+                }
+                BinOp::Ne => {
+                    // if the assignment was Ne we want the false case to be first
+                    if !first_is_false_target {
+                        new_targets.all_targets_mut().swap(0, 1);
+                    }
+                }
+                _ => unreachable!(),
+            }
+
+            // delete the comparison statement if the value being switched on was moved, which
+            // means it cannot be used again later on
+            if opt.can_remove_bin_op_stmt {
+                bb.statements[opt.bin_op_stmt_idx].make_nop();
+            } else {
+                // if the integer being compared to a const integral is being moved into the comparison,
+                // e.g. `_2 = Eq(move _3, const 'x');`
+                // we want to avoid making a double move later on in the switchInt on _3.
+                // So to avoid `switchInt(move _3) -> ['x': bb2, otherwise: bb1];`,
+                // we convert the move in the comparison statement to a copy.
+ + // unwrap is safe as we know this statement is an assign + let (_, rhs) = bb.statements[opt.bin_op_stmt_idx].kind.as_assign_mut().unwrap(); + + use Operand::*; + match rhs { + Rvalue::BinaryOp(_, box (ref mut left @ Move(_), Constant(_))) => { + *left = Copy(opt.to_switch_on); + } + Rvalue::BinaryOp(_, box (Constant(_), ref mut right @ Move(_))) => { + *right = Copy(opt.to_switch_on); + } + _ => (), + } + } + + let terminator = bb.terminator(); + + // remove StorageDead (if it exists) being used in the assign of the comparison + for (stmt_idx, stmt) in bb.statements.iter().enumerate() { + if !matches!(stmt.kind, StatementKind::StorageDead(local) if local == opt.to_switch_on.local) + { + continue; + } + storage_deads_to_remove.push((stmt_idx, opt.bb_idx)); + // if we have StorageDeads to remove then make sure to insert them at the top of each target + for bb_idx in new_targets.all_targets() { + storage_deads_to_insert.push(( + *bb_idx, + Statement { + source_info: terminator.source_info, + kind: StatementKind::StorageDead(opt.to_switch_on.local), + }, + )); + } + } + + let [bb_cond, bb_otherwise] = match new_targets.all_targets() { + [a, b] => [*a, *b], + e => bug!("expected 2 switch targets, got: {:?}", e), + }; + + let targets = SwitchTargets::new(iter::once((new_value, bb_cond)), bb_otherwise); + + let terminator = bb.terminator_mut(); + terminator.kind = TerminatorKind::SwitchInt { + discr: Operand::Move(opt.to_switch_on), + switch_ty: opt.branch_value_ty, + targets, + }; + } + + for (idx, bb_idx) in storage_deads_to_remove { + body.basic_blocks_mut()[bb_idx].statements[idx].make_nop(); + } + + for (idx, stmt) in storage_deads_to_insert { + body.basic_blocks_mut()[idx].statements.insert(0, stmt); + } + } +} + +struct OptimizationFinder<'a, 'tcx> { + body: &'a Body<'tcx>, +} + +impl<'tcx> OptimizationFinder<'_, 'tcx> { + fn find_optimizations(&self) -> Vec<OptimizationInfo<'tcx>> { + self.body + .basic_blocks() + .iter_enumerated() + .filter_map(|(bb_idx, bb)| { + // find switch + let (place_switched_on, targets, place_switched_on_moved) = + match &bb.terminator().kind { + rustc_middle::mir::TerminatorKind::SwitchInt { discr, targets, .. } => { + Some((discr.place()?, targets, discr.is_move())) + } + _ => None, + }?; + + // find the statement that assigns the place being switched on + bb.statements.iter().enumerate().rev().find_map(|(stmt_idx, stmt)| { + match &stmt.kind { + rustc_middle::mir::StatementKind::Assign(box (lhs, rhs)) + if *lhs == place_switched_on => + { + match rhs { + Rvalue::BinaryOp( + op @ (BinOp::Eq | BinOp::Ne), + box (left, right), + ) => { + let (branch_value_scalar, branch_value_ty, to_switch_on) = + find_branch_value_info(left, right)?; + + Some(OptimizationInfo { + bin_op_stmt_idx: stmt_idx, + bb_idx, + can_remove_bin_op_stmt: place_switched_on_moved, + to_switch_on, + branch_value_scalar, + branch_value_ty, + op: *op, + targets: targets.clone(), + }) + } + _ => None, + } + } + _ => None, + } + }) + }) + .collect() + } +} + +fn find_branch_value_info<'tcx>( + left: &Operand<'tcx>, + right: &Operand<'tcx>, +) -> Option<(Scalar, Ty<'tcx>, Place<'tcx>)> { + // check that either left or right is a constant. 
+ // if any are, we can use the other to switch on, and the constant as a value in a switch + use Operand::*; + match (left, right) { + (Constant(branch_value), Copy(to_switch_on) | Move(to_switch_on)) + | (Copy(to_switch_on) | Move(to_switch_on), Constant(branch_value)) => { + let branch_value_ty = branch_value.literal.ty(); + // we only want to apply this optimization if we are matching on integrals (and chars), as it is not possible to switch on floats + if !branch_value_ty.is_integral() && !branch_value_ty.is_char() { + return None; + }; + let branch_value_scalar = branch_value.literal.try_to_scalar()?; + Some((branch_value_scalar, branch_value_ty, *to_switch_on)) + } + _ => None, + } +} + +#[derive(Debug)] +struct OptimizationInfo<'tcx> { + /// Basic block to apply the optimization + bb_idx: BasicBlock, + /// Statement index of Eq/Ne assignment that can be removed. None if the assignment can not be removed - i.e the statement is used later on + bin_op_stmt_idx: usize, + /// Can remove Eq/Ne assignment + can_remove_bin_op_stmt: bool, + /// Place that needs to be switched on. This place is of type integral + to_switch_on: Place<'tcx>, + /// Constant to use in switch target value + branch_value_scalar: Scalar, + /// Type of the constant value + branch_value_ty: Ty<'tcx>, + /// Either Eq or Ne + op: BinOp, + /// Current targets used in the switch + targets: SwitchTargets, +} diff --git a/compiler/rustc_mir_transform/src/simplify_try.rs b/compiler/rustc_mir_transform/src/simplify_try.rs new file mode 100644 index 000000000..d52f1261b --- /dev/null +++ b/compiler/rustc_mir_transform/src/simplify_try.rs @@ -0,0 +1,822 @@ +//! The general point of the optimizations provided here is to simplify something like: +//! +//! ```rust +//! # fn foo<T, E>(x: Result<T, E>) -> Result<T, E> { +//! match x { +//! Ok(x) => Ok(x), +//! Err(x) => Err(x) +//! } +//! # } +//! ``` +//! +//! into just `x`. + +use crate::{simplify, MirPass}; +use itertools::Itertools as _; +use rustc_index::{bit_set::BitSet, vec::IndexVec}; +use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, List, Ty, TyCtxt}; +use rustc_target::abi::VariantIdx; +use std::iter::{once, Enumerate, Peekable}; +use std::slice::Iter; + +/// Simplifies arms of form `Variant(x) => Variant(x)` to just a move. 
+///
+/// This is done by transforming basic blocks where the statements match:
+///
+/// ```ignore (MIR)
+/// _LOCAL_TMP = ((_LOCAL_1 as Variant).FIELD: TY);
+/// _TMP_2 = _LOCAL_TMP;
+/// ((_LOCAL_0 as Variant).FIELD: TY) = move _TMP_2;
+/// discriminant(_LOCAL_0) = VAR_IDX;
+/// ```
+///
+/// into:
+///
+/// ```ignore (MIR)
+/// _LOCAL_0 = move _LOCAL_1
+/// ```
+pub struct SimplifyArmIdentity;
+
+#[derive(Debug)]
+struct ArmIdentityInfo<'tcx> {
+    /// Storage location for the variant's field
+    local_temp_0: Local,
+    /// Storage location holding the variant being read from
+    local_1: Local,
+    /// The variant field being read from
+    vf_s0: VarField<'tcx>,
+    /// Index of the statement which loads the variant being read
+    get_variant_field_stmt: usize,
+
+    /// Tracks each assignment to a temporary of the variant's field
+    field_tmp_assignments: Vec<(Local, Local)>,
+
+    /// Storage location holding the variant's field that was read from
+    local_tmp_s1: Local,
+    /// Storage location holding the enum that we are writing to
+    local_0: Local,
+    /// The variant field being written to
+    vf_s1: VarField<'tcx>,
+
+    /// Storage location that the discriminant is being written to
+    set_discr_local: Local,
+    /// The variant being written
+    set_discr_var_idx: VariantIdx,
+
+    /// Index of the statement that should be overwritten as a move
+    stmt_to_overwrite: usize,
+    /// SourceInfo for the new move
+    source_info: SourceInfo,
+
+    /// Indices of matching Storage{Live,Dead} statements encountered.
+    /// (StorageLive index, StorageDead index, Local)
+    storage_stmts: Vec<(usize, usize, Local)>,
+
+    /// The statements that should be removed (turned into nops)
+    stmts_to_remove: Vec<usize>,
+
+    /// Indices of debug variables that need to be adjusted to point to
+    /// `{local_0}.{dbg_projection}`.
+    dbg_info_to_adjust: Vec<usize>,
+
+    /// The projection used to rewrite debug info.
+    dbg_projection: &'tcx List<PlaceElem<'tcx>>,
+}
+
+fn get_arm_identity_info<'a, 'tcx>(
+    stmts: &'a [Statement<'tcx>],
+    locals_count: usize,
+    debug_info: &'a [VarDebugInfo<'tcx>],
+) -> Option<ArmIdentityInfo<'tcx>> {
+    // This can't possibly match unless there are at least 3 statements in the block
+    // so fail fast on tiny blocks.
+    if stmts.len() < 3 {
+        return None;
+    }
+
+    let mut tmp_assigns = Vec::new();
+    let mut nop_stmts = Vec::new();
+    let mut storage_stmts = Vec::new();
+    let mut storage_live_stmts = Vec::new();
+    let mut storage_dead_stmts = Vec::new();
+
+    type StmtIter<'a, 'tcx> = Peekable<Enumerate<Iter<'a, Statement<'tcx>>>>;
+
+    fn is_storage_stmt(stmt: &Statement<'_>) -> bool {
+        matches!(stmt.kind, StatementKind::StorageLive(_) | StatementKind::StorageDead(_))
+    }
+
+    /// Eats consecutive Statements which match `test`, performing the specified `action` for each.
+    /// The iterator `stmt_iter` is not advanced if none were matched.
+    fn try_eat<'a, 'tcx>(
+        stmt_iter: &mut StmtIter<'a, 'tcx>,
+        test: impl Fn(&'a Statement<'tcx>) -> bool,
+        mut action: impl FnMut(usize, &'a Statement<'tcx>),
+    ) {
+        while stmt_iter.peek().map_or(false, |(_, stmt)| test(stmt)) {
+            let (idx, stmt) = stmt_iter.next().unwrap();
+
+            action(idx, stmt);
+        }
+    }
+
+    /// Eats consecutive `StorageLive` and `StorageDead` Statements.
+    /// The iterator `stmt_iter` is not advanced if none were found.
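+    /// (Illustrative, not in the original source: a run such as
+    /// `StorageLive(_4); StorageDead(_2); StorageLive(_5);` is consumed in one
+    /// call, recording the live and dead statement indices into the two
+    /// output vectors.)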
+ fn try_eat_storage_stmts( + stmt_iter: &mut StmtIter<'_, '_>, + storage_live_stmts: &mut Vec<(usize, Local)>, + storage_dead_stmts: &mut Vec<(usize, Local)>, + ) { + try_eat(stmt_iter, is_storage_stmt, |idx, stmt| { + if let StatementKind::StorageLive(l) = stmt.kind { + storage_live_stmts.push((idx, l)); + } else if let StatementKind::StorageDead(l) = stmt.kind { + storage_dead_stmts.push((idx, l)); + } + }) + } + + fn is_tmp_storage_stmt(stmt: &Statement<'_>) -> bool { + use rustc_middle::mir::StatementKind::Assign; + if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = &stmt.kind { + place.as_local().is_some() && p.as_local().is_some() + } else { + false + } + } + + /// Eats consecutive `Assign` Statements. + // The iterator `stmt_iter` is not advanced if none were found. + fn try_eat_assign_tmp_stmts( + stmt_iter: &mut StmtIter<'_, '_>, + tmp_assigns: &mut Vec<(Local, Local)>, + nop_stmts: &mut Vec<usize>, + ) { + try_eat(stmt_iter, is_tmp_storage_stmt, |idx, stmt| { + use rustc_middle::mir::StatementKind::Assign; + if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = + &stmt.kind + { + tmp_assigns.push((place.as_local().unwrap(), p.as_local().unwrap())); + nop_stmts.push(idx); + } + }) + } + + fn find_storage_live_dead_stmts_for_local( + local: Local, + stmts: &[Statement<'_>], + ) -> Option<(usize, usize)> { + trace!("looking for {:?}", local); + let mut storage_live_stmt = None; + let mut storage_dead_stmt = None; + for (idx, stmt) in stmts.iter().enumerate() { + if stmt.kind == StatementKind::StorageLive(local) { + storage_live_stmt = Some(idx); + } else if stmt.kind == StatementKind::StorageDead(local) { + storage_dead_stmt = Some(idx); + } + } + + Some((storage_live_stmt?, storage_dead_stmt.unwrap_or(usize::MAX))) + } + + // Try to match the expected MIR structure with the basic block we're processing. 
+ // We want to see something that looks like: + // ``` + // (StorageLive(_) | StorageDead(_));* + // _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY); + // (StorageLive(_) | StorageDead(_));* + // (tmp_n+1 = tmp_n);* + // (StorageLive(_) | StorageDead(_));* + // (tmp_n+1 = tmp_n);* + // ((LOCAL_FROM as Variant).FIELD: TY) = move tmp; + // discriminant(LOCAL_FROM) = VariantIdx; + // (StorageLive(_) | StorageDead(_));* + // ``` + let mut stmt_iter = stmts.iter().enumerate().peekable(); + + try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); + + let (get_variant_field_stmt, stmt) = stmt_iter.next()?; + let (local_tmp_s0, local_1, vf_s0, dbg_projection) = match_get_variant_field(stmt)?; + + try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); + + try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts); + + try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); + + try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts); + + let (idx, stmt) = stmt_iter.next()?; + let (local_tmp_s1, local_0, vf_s1) = match_set_variant_field(stmt)?; + nop_stmts.push(idx); + + let (idx, stmt) = stmt_iter.next()?; + let (set_discr_local, set_discr_var_idx) = match_set_discr(stmt)?; + let discr_stmt_source_info = stmt.source_info; + nop_stmts.push(idx); + + try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); + + for (live_idx, live_local) in storage_live_stmts { + if let Some(i) = storage_dead_stmts.iter().rposition(|(_, l)| *l == live_local) { + let (dead_idx, _) = storage_dead_stmts.swap_remove(i); + storage_stmts.push((live_idx, dead_idx, live_local)); + + if live_local == local_tmp_s0 { + nop_stmts.push(get_variant_field_stmt); + } + } + } + // We sort primitive usize here so we can use unstable sort + nop_stmts.sort_unstable(); + + // Use one of the statements we're going to discard between the point + // where the storage location for the variant field becomes live and + // is killed. + let (live_idx, dead_idx) = find_storage_live_dead_stmts_for_local(local_tmp_s0, stmts)?; + let stmt_to_overwrite = + nop_stmts.iter().find(|stmt_idx| live_idx < **stmt_idx && **stmt_idx < dead_idx); + + let mut tmp_assigned_vars = BitSet::new_empty(locals_count); + for (l, r) in &tmp_assigns { + tmp_assigned_vars.insert(*l); + tmp_assigned_vars.insert(*r); + } + + let dbg_info_to_adjust: Vec<_> = debug_info + .iter() + .enumerate() + .filter_map(|(i, var_info)| { + if let VarDebugInfoContents::Place(p) = var_info.value { + if tmp_assigned_vars.contains(p.local) { + return Some(i); + } + } + + None + }) + .collect(); + + Some(ArmIdentityInfo { + local_temp_0: local_tmp_s0, + local_1, + vf_s0, + get_variant_field_stmt, + field_tmp_assignments: tmp_assigns, + local_tmp_s1, + local_0, + vf_s1, + set_discr_local, + set_discr_var_idx, + stmt_to_overwrite: *stmt_to_overwrite?, + source_info: discr_stmt_source_info, + storage_stmts, + stmts_to_remove: nop_stmts, + dbg_info_to_adjust, + dbg_projection, + }) +} + +fn optimization_applies<'tcx>( + opt_info: &ArmIdentityInfo<'tcx>, + local_decls: &IndexVec<Local, LocalDecl<'tcx>>, + local_uses: &IndexVec<Local, usize>, + var_debug_info: &[VarDebugInfo<'tcx>], +) -> bool { + trace!("testing if optimization applies..."); + + // FIXME(wesleywiser): possibly relax this restriction? 
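+    // (Illustrative note, not in the original source: if the two locals were the
+    // same, the rewrite would produce the self-assignment `_x = move _x`, so that
+    // case is rejected below.)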
+ if opt_info.local_0 == opt_info.local_1 { + trace!("NO: moving into ourselves"); + return false; + } else if opt_info.vf_s0 != opt_info.vf_s1 { + trace!("NO: the field-and-variant information do not match"); + return false; + } else if local_decls[opt_info.local_0].ty != local_decls[opt_info.local_1].ty { + // FIXME(Centril,oli-obk): possibly relax to same layout? + trace!("NO: source and target locals have different types"); + return false; + } else if (opt_info.local_0, opt_info.vf_s0.var_idx) + != (opt_info.set_discr_local, opt_info.set_discr_var_idx) + { + trace!("NO: the discriminants do not match"); + return false; + } + + // Verify the assignment chain consists of the form b = a; c = b; d = c; etc... + if opt_info.field_tmp_assignments.is_empty() { + trace!("NO: no assignments found"); + return false; + } + let mut last_assigned_to = opt_info.field_tmp_assignments[0].1; + let source_local = last_assigned_to; + for (l, r) in &opt_info.field_tmp_assignments { + if *r != last_assigned_to { + trace!("NO: found unexpected assignment {:?} = {:?}", l, r); + return false; + } + + last_assigned_to = *l; + } + + // Check that the first and last used locals are only used twice + // since they are of the form: + // + // ``` + // _first = ((_x as Variant).n: ty); + // _n = _first; + // ... + // ((_y as Variant).n: ty) = _n; + // discriminant(_y) = z; + // ``` + for (l, r) in &opt_info.field_tmp_assignments { + if local_uses[*l] != 2 { + warn!("NO: FAILED assignment chain local {:?} was used more than twice", l); + return false; + } else if local_uses[*r] != 2 { + warn!("NO: FAILED assignment chain local {:?} was used more than twice", r); + return false; + } + } + + // Check that debug info only points to full Locals and not projections. + for dbg_idx in &opt_info.dbg_info_to_adjust { + let dbg_info = &var_debug_info[*dbg_idx]; + if let VarDebugInfoContents::Place(p) = dbg_info.value { + if !p.projection.is_empty() { + trace!("NO: debug info for {:?} had a projection {:?}", dbg_info.name, p); + return false; + } + } + } + + if source_local != opt_info.local_temp_0 { + trace!( + "NO: start of assignment chain does not match enum variant temp: {:?} != {:?}", + source_local, + opt_info.local_temp_0 + ); + return false; + } else if last_assigned_to != opt_info.local_tmp_s1 { + trace!( + "NO: end of assignment chain does not match written enum temp: {:?} != {:?}", + last_assigned_to, + opt_info.local_tmp_s1 + ); + return false; + } + + trace!("SUCCESS: optimization applies!"); + true +} + +impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // FIXME(77359): This optimization can result in unsoundness. + if !tcx.sess.opts.unstable_opts.unsound_mir_opts { + return; + } + + let source = body.source; + trace!("running SimplifyArmIdentity on {:?}", source); + + let local_uses = LocalUseCounter::get_local_uses(body); + for bb in body.basic_blocks.as_mut() { + if let Some(opt_info) = + get_arm_identity_info(&bb.statements, body.local_decls.len(), &body.var_debug_info) + { + trace!("got opt_info = {:#?}", opt_info); + if !optimization_applies( + &opt_info, + &body.local_decls, + &local_uses, + &body.var_debug_info, + ) { + debug!("optimization skipped for {:?}", source); + continue; + } + + // Also remove unused Storage{Live,Dead} statements which correspond + // to temps used previously. 
+ for (live_idx, dead_idx, local) in &opt_info.storage_stmts { + // The temporary that we've read the variant field into is scoped to this block, + // so we can remove the assignment. + if *local == opt_info.local_temp_0 { + bb.statements[opt_info.get_variant_field_stmt].make_nop(); + } + + for (left, right) in &opt_info.field_tmp_assignments { + if local == left || local == right { + bb.statements[*live_idx].make_nop(); + bb.statements[*dead_idx].make_nop(); + } + } + } + + // Right shape; transform + for stmt_idx in opt_info.stmts_to_remove { + bb.statements[stmt_idx].make_nop(); + } + + let stmt = &mut bb.statements[opt_info.stmt_to_overwrite]; + stmt.source_info = opt_info.source_info; + stmt.kind = StatementKind::Assign(Box::new(( + opt_info.local_0.into(), + Rvalue::Use(Operand::Move(opt_info.local_1.into())), + ))); + + bb.statements.retain(|stmt| stmt.kind != StatementKind::Nop); + + // Fix the debug info to point to the right local + for dbg_index in opt_info.dbg_info_to_adjust { + let dbg_info = &mut body.var_debug_info[dbg_index]; + assert!( + matches!(dbg_info.value, VarDebugInfoContents::Place(_)), + "value was not a Place" + ); + if let VarDebugInfoContents::Place(p) = &mut dbg_info.value { + assert!(p.projection.is_empty()); + p.local = opt_info.local_0; + p.projection = opt_info.dbg_projection; + } + } + + trace!("block is now {:?}", bb.statements); + } + } + } +} + +struct LocalUseCounter { + local_uses: IndexVec<Local, usize>, +} + +impl LocalUseCounter { + fn get_local_uses(body: &Body<'_>) -> IndexVec<Local, usize> { + let mut counter = LocalUseCounter { local_uses: IndexVec::from_elem(0, &body.local_decls) }; + counter.visit_body(body); + counter.local_uses + } +} + +impl Visitor<'_> for LocalUseCounter { + fn visit_local(&mut self, local: Local, context: PlaceContext, _location: Location) { + if context.is_storage_marker() + || context == PlaceContext::NonUse(NonUseContext::VarDebugInfo) + { + return; + } + + self.local_uses[local] += 1; + } +} + +/// Match on: +/// ```ignore (MIR) +/// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY); +/// ``` +fn match_get_variant_field<'tcx>( + stmt: &Statement<'tcx>, +) -> Option<(Local, Local, VarField<'tcx>, &'tcx List<PlaceElem<'tcx>>)> { + match &stmt.kind { + StatementKind::Assign(box ( + place_into, + Rvalue::Use(Operand::Copy(pf) | Operand::Move(pf)), + )) => { + let local_into = place_into.as_local()?; + let (local_from, vf) = match_variant_field_place(*pf)?; + Some((local_into, local_from, vf, pf.projection)) + } + _ => None, + } +} + +/// Match on: +/// ```ignore (MIR) +/// ((_LOCAL_FROM as Variant).FIELD: TY) = move _LOCAL_INTO; +/// ``` +fn match_set_variant_field<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, Local, VarField<'tcx>)> { + match &stmt.kind { + StatementKind::Assign(box (place_from, Rvalue::Use(Operand::Move(place_into)))) => { + let local_into = place_into.as_local()?; + let (local_from, vf) = match_variant_field_place(*place_from)?; + Some((local_into, local_from, vf)) + } + _ => None, + } +} + +/// Match on: +/// ```ignore (MIR) +/// discriminant(_LOCAL_TO_SET) = VAR_IDX; +/// ``` +fn match_set_discr(stmt: &Statement<'_>) -> Option<(Local, VariantIdx)> { + match &stmt.kind { + StatementKind::SetDiscriminant { place, variant_index } => { + Some((place.as_local()?, *variant_index)) + } + _ => None, + } +} + +#[derive(PartialEq, Debug)] +struct VarField<'tcx> { + field: Field, + field_ty: Ty<'tcx>, + var_idx: VariantIdx, +} + +/// Match on `((_LOCAL as Variant).FIELD: TY)`. 
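+/// (Illustrative, not in the original source: for a place such as
+/// `((_1 as Some).0: u32)` this returns the base local `_1` together with a
+/// `VarField` recording field `0`, field type `u32`, and the index of the
+/// `Some` variant.)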
+fn match_variant_field_place<'tcx>(place: Place<'tcx>) -> Option<(Local, VarField<'tcx>)> { + match place.as_ref() { + PlaceRef { + local, + projection: &[ProjectionElem::Downcast(_, var_idx), ProjectionElem::Field(field, ty)], + } => Some((local, VarField { field, field_ty: ty, var_idx })), + _ => None, + } +} + +/// Simplifies `SwitchInt(_) -> [targets]`, +/// where all the `targets` have the same form, +/// into `goto -> target_first`. +pub struct SimplifyBranchSame; + +impl<'tcx> MirPass<'tcx> for SimplifyBranchSame { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // This optimization is disabled by default for now due to + // soundness concerns; see issue #89485 and PR #89489. + if !tcx.sess.opts.unstable_opts.unsound_mir_opts { + return; + } + + trace!("Running SimplifyBranchSame on {:?}", body.source); + let finder = SimplifyBranchSameOptimizationFinder { body, tcx }; + let opts = finder.find(); + + let did_remove_blocks = opts.len() > 0; + for opt in opts.iter() { + trace!("SUCCESS: Applying optimization {:?}", opt); + // Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`. + body.basic_blocks_mut()[opt.bb_to_opt_terminator].terminator_mut().kind = + TerminatorKind::Goto { target: opt.bb_to_goto }; + } + + if did_remove_blocks { + // We have dead blocks now, so remove those. + simplify::remove_dead_blocks(tcx, body); + } + } +} + +#[derive(Debug)] +struct SimplifyBranchSameOptimization { + /// All basic blocks are equal so go to this one + bb_to_goto: BasicBlock, + /// Basic block where the terminator can be simplified to a goto + bb_to_opt_terminator: BasicBlock, +} + +struct SwitchTargetAndValue { + target: BasicBlock, + // None in case of the `otherwise` case + value: Option<u128>, +} + +struct SimplifyBranchSameOptimizationFinder<'a, 'tcx> { + body: &'a Body<'tcx>, + tcx: TyCtxt<'tcx>, +} + +impl<'tcx> SimplifyBranchSameOptimizationFinder<'_, 'tcx> { + fn find(&self) -> Vec<SimplifyBranchSameOptimization> { + self.body + .basic_blocks() + .iter_enumerated() + .filter_map(|(bb_idx, bb)| { + let (discr_switched_on, targets_and_values) = match &bb.terminator().kind { + TerminatorKind::SwitchInt { targets, discr, .. } => { + let targets_and_values: Vec<_> = targets.iter() + .map(|(val, target)| SwitchTargetAndValue { target, value: Some(val) }) + .chain(once(SwitchTargetAndValue { target: targets.otherwise(), value: None })) + .collect(); + (discr, targets_and_values) + }, + _ => return None, + }; + + // find the adt that has its discriminant read + // assuming this must be the last statement of the block + let adt_matched_on = match &bb.statements.last()?.kind { + StatementKind::Assign(box (place, rhs)) + if Some(*place) == discr_switched_on.place() => + { + match rhs { + Rvalue::Discriminant(adt_place) if adt_place.ty(self.body, self.tcx).ty.is_enum() => adt_place, + _ => { + trace!("NO: expected a discriminant read of an enum instead of: {:?}", rhs); + return None; + } + } + } + other => { + trace!("NO: expected an assignment of a discriminant read to a place. Found: {:?}", other); + return None + }, + }; + + let mut iter_bbs_reachable = targets_and_values + .iter() + .map(|target_and_value| (target_and_value, &self.body.basic_blocks()[target_and_value.target])) + .filter(|(_, bb)| { + // Reaching `unreachable` is UB so assume it doesn't happen. 
+ bb.terminator().kind != TerminatorKind::Unreachable + }) + .peekable(); + + let bb_first = iter_bbs_reachable.peek().map_or(&targets_and_values[0], |(idx, _)| *idx); + let mut all_successors_equivalent = StatementEquality::TrivialEqual; + + // All successor basic blocks must be equal or contain statements that are pairwise considered equal. + for ((target_and_value_l,bb_l), (target_and_value_r,bb_r)) in iter_bbs_reachable.tuple_windows() { + let trivial_checks = bb_l.is_cleanup == bb_r.is_cleanup + && bb_l.terminator().kind == bb_r.terminator().kind + && bb_l.statements.len() == bb_r.statements.len(); + let statement_check = || { + bb_l.statements.iter().zip(&bb_r.statements).try_fold(StatementEquality::TrivialEqual, |acc,(l,r)| { + let stmt_equality = self.statement_equality(*adt_matched_on, &l, target_and_value_l, &r, target_and_value_r); + if matches!(stmt_equality, StatementEquality::NotEqual) { + // short circuit + None + } else { + Some(acc.combine(&stmt_equality)) + } + }) + .unwrap_or(StatementEquality::NotEqual) + }; + if !trivial_checks { + all_successors_equivalent = StatementEquality::NotEqual; + break; + } + all_successors_equivalent = all_successors_equivalent.combine(&statement_check()); + }; + + match all_successors_equivalent{ + StatementEquality::TrivialEqual => { + // statements are trivially equal, so just take first + trace!("Statements are trivially equal"); + Some(SimplifyBranchSameOptimization { + bb_to_goto: bb_first.target, + bb_to_opt_terminator: bb_idx, + }) + } + StatementEquality::ConsideredEqual(bb_to_choose) => { + trace!("Statements are considered equal"); + Some(SimplifyBranchSameOptimization { + bb_to_goto: bb_to_choose, + bb_to_opt_terminator: bb_idx, + }) + } + StatementEquality::NotEqual => { + trace!("NO: not all successors of basic block {:?} were equivalent", bb_idx); + None + } + } + }) + .collect() + } + + /// Tests if two statements can be considered equal + /// + /// Statements can be trivially equal if the kinds match. + /// But they can also be considered equal in the following case A: + /// ```ignore (MIR) + /// discriminant(_0) = 0; // bb1 + /// _0 = move _1; // bb2 + /// ``` + /// In this case the two statements are equal iff + /// - `_0` is an enum where the variant index 0 is fieldless, and + /// - bb1 was targeted by a switch where the discriminant of `_1` was switched on + fn statement_equality( + &self, + adt_matched_on: Place<'tcx>, + x: &Statement<'tcx>, + x_target_and_value: &SwitchTargetAndValue, + y: &Statement<'tcx>, + y_target_and_value: &SwitchTargetAndValue, + ) -> StatementEquality { + let helper = |rhs: &Rvalue<'tcx>, + place: &Place<'tcx>, + variant_index: VariantIdx, + switch_value: u128, + side_to_choose| { + let place_type = place.ty(self.body, self.tcx).ty; + let adt = match *place_type.kind() { + ty::Adt(adt, _) if adt.is_enum() => adt, + _ => return StatementEquality::NotEqual, + }; + // We need to make sure that the switch value that targets the bb with + // SetDiscriminant is the same as the variant discriminant. 
+ let variant_discr = adt.discriminant_for_variant(self.tcx, variant_index).val; + if variant_discr != switch_value { + trace!( + "NO: variant discriminant {} does not equal switch value {}", + variant_discr, + switch_value + ); + return StatementEquality::NotEqual; + } + let variant_is_fieldless = adt.variant(variant_index).fields.is_empty(); + if !variant_is_fieldless { + trace!("NO: variant {:?} was not fieldless", variant_index); + return StatementEquality::NotEqual; + } + + match rhs { + Rvalue::Use(operand) if operand.place() == Some(adt_matched_on) => { + StatementEquality::ConsideredEqual(side_to_choose) + } + _ => { + trace!( + "NO: RHS of assignment was {:?}, but expected it to match the adt being matched on in the switch, which is {:?}", + rhs, + adt_matched_on + ); + StatementEquality::NotEqual + } + } + }; + match (&x.kind, &y.kind) { + // trivial case + (x, y) if x == y => StatementEquality::TrivialEqual, + + // check for case A + ( + StatementKind::Assign(box (_, rhs)), + &StatementKind::SetDiscriminant { ref place, variant_index }, + ) if y_target_and_value.value.is_some() => { + // choose basic block of x, as that has the assign + helper( + rhs, + place, + variant_index, + y_target_and_value.value.unwrap(), + x_target_and_value.target, + ) + } + ( + &StatementKind::SetDiscriminant { ref place, variant_index }, + &StatementKind::Assign(box (_, ref rhs)), + ) if x_target_and_value.value.is_some() => { + // choose basic block of y, as that has the assign + helper( + rhs, + place, + variant_index, + x_target_and_value.value.unwrap(), + y_target_and_value.target, + ) + } + _ => { + trace!("NO: statements `{:?}` and `{:?}` not considered equal", x, y); + StatementEquality::NotEqual + } + } + } +} + +#[derive(Copy, Clone, Eq, PartialEq)] +enum StatementEquality { + /// The two statements are trivially equal; same kind + TrivialEqual, + /// The two statements are considered equal, but may be of different kinds. The BasicBlock field is the basic block to jump to when performing the branch-same optimization. + /// For example, `_0 = _1` and `discriminant(_0) = discriminant(0)` are considered equal if 0 is a fieldless variant of an enum. But we don't want to jump to the basic block with the SetDiscriminant, as that is not legal if _1 is not the 0 variant index + ConsideredEqual(BasicBlock), + /// The two statements are not equal + NotEqual, +} + +impl StatementEquality { + fn combine(&self, other: &StatementEquality) -> StatementEquality { + use StatementEquality::*; + match (self, other) { + (TrivialEqual, TrivialEqual) => TrivialEqual, + (TrivialEqual, ConsideredEqual(b)) | (ConsideredEqual(b), TrivialEqual) => { + ConsideredEqual(*b) + } + (ConsideredEqual(b1), ConsideredEqual(b2)) => { + if b1 == b2 { + ConsideredEqual(*b1) + } else { + NotEqual + } + } + (_, NotEqual) | (NotEqual, _) => NotEqual, + } + } +} diff --git a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs new file mode 100644 index 000000000..30be64f5b --- /dev/null +++ b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs @@ -0,0 +1,149 @@ +//! A pass that eliminates branches on uninhabited enum variants. 
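To see at the source level what this pass exploits, consider a hedged, hypothetical example (these types are invented for the illustration and are not part of the compiler): a variant whose payload type is uninhabited can never be constructed, so a switch on the discriminant never needs to test for it.

```rust
enum Never {} // no values of this type exist

enum Message {
    Text(String),
    Impossible(Never), // uninhabited payload: this variant cannot exist at runtime
}

fn describe(m: &Message) -> &'static str {
    match m {
        Message::Text(_) => "text",
        // The source still requires this arm, but the pass can drop it from the
        // generated switch: `Impossible`'s discriminant is never branched on.
        Message::Impossible(_) => "unreachable",
    }
}

fn main() {
    println!("{}", describe(&Message::Text("hi".into())));
}
```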
+ +use crate::MirPass; +use rustc_data_structures::fx::FxHashSet; +use rustc_middle::mir::{ + BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, SwitchTargets, Terminator, + TerminatorKind, +}; +use rustc_middle::ty::layout::TyAndLayout; +use rustc_middle::ty::{Ty, TyCtxt}; +use rustc_target::abi::{Abi, Variants}; + +pub struct UninhabitedEnumBranching; + +fn get_discriminant_local(terminator: &TerminatorKind<'_>) -> Option<Local> { + if let TerminatorKind::SwitchInt { discr: Operand::Move(p), .. } = terminator { + p.as_local() + } else { + None + } +} + +/// If the basic block terminates by switching on a discriminant, this returns the `Ty` the +/// discriminant is read from. Otherwise, returns None. +fn get_switched_on_type<'tcx>( + block_data: &BasicBlockData<'tcx>, + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, +) -> Option<Ty<'tcx>> { + let terminator = block_data.terminator(); + + // Only bother checking blocks which terminate by switching on a local. + if let Some(local) = get_discriminant_local(&terminator.kind) { + let stmt_before_term = (!block_data.statements.is_empty()) + .then(|| &block_data.statements[block_data.statements.len() - 1].kind); + + if let Some(StatementKind::Assign(box (l, Rvalue::Discriminant(place)))) = stmt_before_term + { + if l.as_local() == Some(local) { + let ty = place.ty(body, tcx).ty; + if ty.is_enum() { + return Some(ty); + } + } + } + } + + None +} + +fn variant_discriminants<'tcx>( + layout: &TyAndLayout<'tcx>, + ty: Ty<'tcx>, + tcx: TyCtxt<'tcx>, +) -> FxHashSet<u128> { + match &layout.variants { + Variants::Single { index } => { + let mut res = FxHashSet::default(); + res.insert( + ty.discriminant_for_variant(tcx, *index) + .map_or(index.as_u32() as u128, |discr| discr.val), + ); + res + } + Variants::Multiple { variants, .. } => variants + .iter_enumerated() + .filter_map(|(idx, layout)| { + (layout.abi() != Abi::Uninhabited) + .then(|| ty.discriminant_for_variant(tcx, idx).unwrap().val) + }) + .collect(), + } +} + +/// Ensures that the `otherwise` branch leads to an unreachable bb, returning `None` if so and a new +/// bb to use as the new target if not. +fn ensure_otherwise_unreachable<'tcx>( + body: &Body<'tcx>, + targets: &SwitchTargets, +) -> Option<BasicBlockData<'tcx>> { + let otherwise = targets.otherwise(); + let bb = &body.basic_blocks()[otherwise]; + if bb.terminator().kind == TerminatorKind::Unreachable + && bb.statements.iter().all(|s| matches!(&s.kind, StatementKind::StorageDead(_))) + { + return None; + } + + let mut new_block = BasicBlockData::new(Some(Terminator { + source_info: bb.terminator().source_info, + kind: TerminatorKind::Unreachable, + })); + new_block.is_cleanup = bb.is_cleanup; + Some(new_block) +} + +impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() > 0 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + trace!("UninhabitedEnumBranching starting for {:?}", body.source); + + for bb in body.basic_blocks().indices() { + trace!("processing block {:?}", bb); + + let Some(discriminant_ty) = get_switched_on_type(&body.basic_blocks()[bb], tcx, body) else { + continue; + }; + + let layout = tcx.layout_of(tcx.param_env(body.source.def_id()).and(discriminant_ty)); + + let allowed_variants = if let Ok(layout) = layout { + variant_discriminants(&layout, discriminant_ty, tcx) + } else { + continue; + }; + + trace!("allowed_variants = {:?}", allowed_variants); + + if let TerminatorKind::SwitchInt { targets, .. 
+            if let TerminatorKind::SwitchInt { targets, .. } =
+                &mut body.basic_blocks_mut()[bb].terminator_mut().kind
+            {
+                let mut new_targets = SwitchTargets::new(
+                    targets.iter().filter(|(val, _)| allowed_variants.contains(val)),
+                    targets.otherwise(),
+                );
+
+                if new_targets.iter().count() == allowed_variants.len() {
+                    if let Some(updated) = ensure_otherwise_unreachable(body, &new_targets) {
+                        let new_otherwise = body.basic_blocks_mut().push(updated);
+                        *new_targets.all_targets_mut().last_mut().unwrap() = new_otherwise;
+                    }
+                }
+
+                if let TerminatorKind::SwitchInt { targets, .. } =
+                    &mut body.basic_blocks_mut()[bb].terminator_mut().kind
+                {
+                    *targets = new_targets;
+                } else {
+                    unreachable!()
+                }
+            } else {
+                unreachable!()
+            }
+        }
+    }
+}
diff --git a/compiler/rustc_mir_transform/src/unreachable_prop.rs b/compiler/rustc_mir_transform/src/unreachable_prop.rs
new file mode 100644
index 000000000..f916ca362
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/unreachable_prop.rs
@@ -0,0 +1,102 @@
+//! A pass that propagates the unreachable terminator of a block to its predecessors
+//! when all of their successors are unreachable. This is achieved through a
+//! post-order traversal of the blocks.
+
+use crate::simplify;
+use crate::MirPass;
+use rustc_data_structures::fx::{FxHashMap, FxHashSet};
+use rustc_middle::mir::*;
+use rustc_middle::ty::TyCtxt;
+
+pub struct UnreachablePropagation;
+
+impl MirPass<'_> for UnreachablePropagation {
+    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+        // Enable only under -Zmir-opt-level=4, as in some cases (see the
+        // deeply-nested-opt perf benchmark) LLVM may spend quite a lot of time
+        // optimizing the generated code.
+        sess.mir_opt_level() >= 4
+    }
+
+    fn run_pass<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        let mut unreachable_blocks = FxHashSet::default();
+        let mut replacements = FxHashMap::default();
+
+        for (bb, bb_data) in traversal::postorder(body) {
+            let terminator = bb_data.terminator();
+            if terminator.kind == TerminatorKind::Unreachable {
+                unreachable_blocks.insert(bb);
+            } else {
+                let is_unreachable = |succ: BasicBlock| unreachable_blocks.contains(&succ);
+                let terminator_kind_opt = remove_successors(&terminator.kind, is_unreachable);
+
+                if let Some(terminator_kind) = terminator_kind_opt {
+                    if terminator_kind == TerminatorKind::Unreachable {
+                        unreachable_blocks.insert(bb);
+                    }
+                    replacements.insert(bb, terminator_kind);
+                }
+            }
+        }
+
+        let replaced = !replacements.is_empty();
+        for (bb, terminator_kind) in replacements {
+            if !tcx.consider_optimizing(|| {
+                format!("UnreachablePropagation {:?} ", body.source.def_id())
+            }) {
+                break;
+            }
+
+            body.basic_blocks_mut()[bb].terminator_mut().kind = terminator_kind;
+        }
+
+        if replaced {
+            simplify::remove_dead_blocks(tcx, body);
+        }
+    }
+}
+
+fn remove_successors<'tcx, F>(
+    terminator_kind: &TerminatorKind<'tcx>,
+    predicate: F,
+) -> Option<TerminatorKind<'tcx>>
+where
+    F: Fn(BasicBlock) -> bool,
+{
+    let terminator = match *terminator_kind {
+        TerminatorKind::Goto { target } if predicate(target) => TerminatorKind::Unreachable,
+        TerminatorKind::SwitchInt { ref discr, switch_ty, ref targets } => {
+            let otherwise = targets.otherwise();
+
+            let original_targets_len = targets.iter().len() + 1;
+            let (mut values, mut targets): (Vec<_>, Vec<_>) =
+                targets.iter().filter(|(_, bb)| !predicate(*bb)).unzip();
+
+            if !predicate(otherwise) {
+                targets.push(otherwise);
+            } else {
+                values.pop();
+            }
+
+            let retained_targets_len = targets.len();
+
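+            // Decide what the rewritten terminator looks like: no surviving targets
+            // means the block itself is unreachable, a single surviving target
+            // degenerates into a `Goto`, and if any target was dropped the switch is
+            // rebuilt from the retained values; `None` leaves the terminator as is.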
+            if targets.is_empty() {
+                TerminatorKind::Unreachable
+            } else if targets.len() == 1 {
+                TerminatorKind::Goto { target: targets[0] }
+            } else if original_targets_len != retained_targets_len {
+                TerminatorKind::SwitchInt {
+                    discr: discr.clone(),
+                    switch_ty,
+                    targets: SwitchTargets::new(
+                        values.iter().copied().zip(targets.iter().copied()),
+                        *targets.last().unwrap(),
+                    ),
+                }
+            } else {
+                return None;
+            }
+        }
+        _ => return None,
+    };
+    Some(terminator)
+}
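+
+// A worked illustration of `remove_successors` (hypothetical MIR, not taken from
+// the compiler sources): for `switchInt(_1) -> [0: bb1, 1: bb2, otherwise: bb3]`
+// where only bb2 is unreachable, the `1: bb2` edge is dropped and the switch is
+// rebuilt as `switchInt(_1) -> [0: bb1, otherwise: bb3]`. If bb3 is unreachable
+// as well, bb1 is the only surviving target and the terminator collapses to
+// `goto -> bb1`.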