Diffstat (limited to 'compiler/rustc_mir_transform')
53 files changed, 3664 insertions, 2455 deletions
diff --git a/compiler/rustc_mir_transform/Cargo.toml b/compiler/rustc_mir_transform/Cargo.toml
index f1198d9bf..0448e9d27 100644
--- a/compiler/rustc_mir_transform/Cargo.toml
+++ b/compiler/rustc_mir_transform/Cargo.toml
@@ -3,30 +3,33 @@ name = "rustc_mir_transform"
 version = "0.0.0"
 edition = "2021"
 
-[lib]
-
 [dependencies]
-itertools = "0.10.1"
-smallvec = { version = "1.8.1", features = ["union", "may_dangle"] }
-tracing = "0.1"
+# tidy-alphabetical-start
 either = "1"
+itertools = "0.10.1"
+rustc_arena = { path = "../rustc_arena" }
 rustc_ast = { path = "../rustc_ast" }
 rustc_attr = { path = "../rustc_attr" }
+rustc_const_eval = { path = "../rustc_const_eval" }
 rustc_data_structures = { path = "../rustc_data_structures" }
 rustc_errors = { path = "../rustc_errors" }
+rustc_fluent_macro = { path = "../rustc_fluent_macro" }
 rustc_hir = { path = "../rustc_hir" }
 rustc_index = { path = "../rustc_index" }
+rustc_macros = { path = "../rustc_macros" }
 rustc_middle = { path = "../rustc_middle" }
-rustc_const_eval = { path = "../rustc_const_eval" }
 rustc_mir_build = { path = "../rustc_mir_build" }
 rustc_mir_dataflow = { path = "../rustc_mir_dataflow" }
 rustc_serialize = { path = "../rustc_serialize" }
 rustc_session = { path = "../rustc_session" }
+rustc_span = { path = "../rustc_span" }
 rustc_target = { path = "../rustc_target" }
 rustc_trait_selection = { path = "../rustc_trait_selection" }
-rustc_span = { path = "../rustc_span" }
-rustc_fluent_macro = { path = "../rustc_fluent_macro" }
-rustc_macros = { path = "../rustc_macros" }
+smallvec = { version = "1.8.1", features = ["union", "may_dangle"] }
+tracing = "0.1"
+# tidy-alphabetical-end
 
 [dev-dependencies]
+# tidy-alphabetical-start
 coverage_test_macros = { path = "src/coverage/test_macros" }
+# tidy-alphabetical-end
diff --git a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs
index 4500bb7ff..2b3d423ea 100644
--- a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs
+++ b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs
@@ -40,7 +40,7 @@ impl<'tcx> MirPass<'tcx> for AbortUnwindingCalls {
         let body_abi = match body_ty.kind() {
             ty::FnDef(..) => body_ty.fn_sig(tcx).abi(),
             ty::Closure(..) => Abi::RustCall,
-            ty::Generator(..) => Abi::Rust,
+            ty::Coroutine(..) => Abi::Rust,
             _ => span_bug!(body.span, "unexpected body ty: {:?}", body_ty),
         };
         let body_can_unwind = layout::fn_can_unwind(tcx, Some(def_id), body_abi);
@@ -113,6 +113,6 @@ impl<'tcx> MirPass<'tcx> for AbortUnwindingCalls {
         }
         // We may have invalidated some `cleanup` blocks so clean those up now.
-        super::simplify::remove_dead_blocks(tcx, body);
+        super::simplify::remove_dead_blocks(body);
     }
 }
diff --git a/compiler/rustc_mir_transform/src/check_alignment.rs b/compiler/rustc_mir_transform/src/check_alignment.rs
index 28765af20..42b2f1886 100644
--- a/compiler/rustc_mir_transform/src/check_alignment.rs
+++ b/compiler/rustc_mir_transform/src/check_alignment.rs
@@ -1,13 +1,12 @@
 use crate::MirPass;
-use rustc_hir::def_id::DefId;
 use rustc_hir::lang_items::LangItem;
 use rustc_index::IndexVec;
 use rustc_middle::mir::*;
 use rustc_middle::mir::{
     interpret::Scalar,
-    visit::{PlaceContext, Visitor},
+    visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor},
 };
-use rustc_middle::ty::{Ty, TyCtxt, TypeAndMut};
+use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt, TypeAndMut};
 use rustc_session::Session;
 
 pub struct CheckAlignment;
@@ -30,7 +29,12 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment {
         let basic_blocks = body.basic_blocks.as_mut();
         let local_decls = &mut body.local_decls;
+        let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
 
+        // This pass inserts new blocks. Each insertion changes the Location for all
+        // statements/blocks after. Iterating or visiting the MIR in order would require updating
+        // our current location after every insertion. By iterating backwards, we dodge this issue:
+        // The only Locations that an insertion changes have already been handled.
         for block in (0..basic_blocks.len()).rev() {
             let block = block.into();
             for statement_index in (0..basic_blocks[block].statements.len()).rev() {
@@ -38,22 +42,19 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment {
                 let statement = &basic_blocks[block].statements[statement_index];
                 let source_info = statement.source_info;
 
-                let mut finder = PointerFinder {
-                    local_decls,
-                    tcx,
-                    pointers: Vec::new(),
-                    def_id: body.source.def_id(),
-                };
-                for (pointer, pointee_ty) in finder.find_pointers(statement) {
-                    debug!("Inserting alignment check for {:?}", pointer.ty(&*local_decls, tcx).ty);
+                let mut finder =
+                    PointerFinder { tcx, local_decls, param_env, pointers: Vec::new() };
+                finder.visit_statement(statement, location);
 
+                for (local, ty) in finder.pointers {
+                    debug!("Inserting alignment check for {:?}", ty);
                     let new_block = split_block(basic_blocks, location);
                     insert_alignment_check(
                         tcx,
                         local_decls,
                         &mut basic_blocks[block],
-                        pointer,
-                        pointee_ty,
+                        local,
+                        ty,
                         source_info,
                         new_block,
                     );
@@ -63,69 +64,71 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment {
     }
 }
 
-impl<'tcx, 'a> PointerFinder<'tcx, 'a> {
-    fn find_pointers(&mut self, statement: &Statement<'tcx>) -> Vec<(Place<'tcx>, Ty<'tcx>)> {
-        self.pointers.clear();
-        self.visit_statement(statement, Location::START);
-        core::mem::take(&mut self.pointers)
-    }
-}
-
 struct PointerFinder<'tcx, 'a> {
-    local_decls: &'a mut LocalDecls<'tcx>,
     tcx: TyCtxt<'tcx>,
-    def_id: DefId,
+    local_decls: &'a mut LocalDecls<'tcx>,
+    param_env: ParamEnv<'tcx>,
     pointers: Vec<(Place<'tcx>, Ty<'tcx>)>,
 }
 
 impl<'tcx, 'a> Visitor<'tcx> for PointerFinder<'tcx, 'a> {
-    fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
-        if let Rvalue::AddressOf(..) = rvalue {
-            // Ignore dereferences inside of an AddressOf
-            return;
+    fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
+        // We want to only check reads and writes to Places, so we specifically exclude
+        // Borrows and AddressOf.
+        match context {
+            PlaceContext::MutatingUse(
+                MutatingUseContext::Store
+                | MutatingUseContext::AsmOutput
+                | MutatingUseContext::Call
+                | MutatingUseContext::Yield
+                | MutatingUseContext::Drop,
+            ) => {}
+            PlaceContext::NonMutatingUse(
+                NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
+            ) => {}
+            _ => {
+                return;
+            }
         }
-        self.super_rvalue(rvalue, location);
-    }
 
-    fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) {
-        if let PlaceContext::NonUse(_) = context {
-            return;
-        }
         if !place.is_indirect() {
             return;
         }
 
+        // Since Deref projections must come first and only once, the pointer for an indirect place
+        // is the Local that the Place is based on.
         let pointer = Place::from(place.local);
-        let pointer_ty = pointer.ty(&*self.local_decls, self.tcx).ty;
+        let pointer_ty = self.local_decls[place.local].ty;
 
-        // We only want to check unsafe pointers
+        // We only want to check places based on unsafe pointers
         if !pointer_ty.is_unsafe_ptr() {
-            trace!("Indirect, but not an unsafe ptr, not checking {:?}", pointer_ty);
+            trace!("Indirect, but not based on an unsafe ptr, not checking {:?}", place);
             return;
         }
 
-        let Some(pointee) = pointer_ty.builtin_deref(true) else {
-            debug!("Indirect but no builtin deref: {:?}", pointer_ty);
+        let pointee_ty =
+            pointer_ty.builtin_deref(true).expect("no builtin_deref for an unsafe pointer").ty;
+        // Ideally we'd support this in the future, but for now we are limited to sized types.
+        if !pointee_ty.is_sized(self.tcx, self.param_env) {
+            debug!("Unsafe pointer, but pointee is not known to be sized: {:?}", pointer_ty);
             return;
-        };
-        let mut pointee_ty = pointee.ty;
-        if pointee_ty.is_array() || pointee_ty.is_slice() || pointee_ty.is_str() {
-            pointee_ty = pointee_ty.sequence_element_type(self.tcx);
         }
 
-        if !pointee_ty.is_sized(self.tcx, self.tcx.param_env_reveal_all_normalized(self.def_id)) {
-            debug!("Unsafe pointer, but unsized: {:?}", pointer_ty);
+        // Try to detect types we are sure have an alignment of 1 and skip the check
+        // We don't need to look for str and slices, we already rejected unsized types above
+        let element_ty = match pointee_ty.kind() {
+            ty::Array(ty, _) => *ty,
+            _ => pointee_ty,
+        };
+        if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8].contains(&element_ty) {
+            debug!("Trivially aligned place type: {:?}", pointee_ty);
            return;
        }
 
-        if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8, self.tcx.types.str_]
-            .contains(&pointee_ty)
-        {
-            debug!("Trivially aligned pointee type: {:?}", pointer_ty);
-            return;
-        }
+        // Ensure that this place is based on an aligned pointer.
+        self.pointers.push((pointer, pointee_ty));
 
-        self.pointers.push((pointer, pointee_ty))
+        self.super_place(place, context, location);
     }
 }
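Note: the backwards iteration introduced above matters because `split_block` inserts new blocks and shifts the `Location` of everything after the insertion point. A minimal sketch of the idea in plain Rust, using a `Vec<String>` stand-in for the basic-block list (illustration only, not rustc's API):

    fn insert_checks(blocks: &mut Vec<String>, needs_check: impl Fn(&str) -> bool) {
        // Visit in reverse: an insertion at index `i` only shifts elements at
        // indices > i, and those have all been visited already, so no index
        // bookkeeping is required.
        for i in (0..blocks.len()).rev() {
            if needs_check(&blocks[i]) {
                blocks.insert(i + 1, format!("alignment check after {}", blocks[i]));
            }
        }
    }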
diff --git a/compiler/rustc_mir_transform/src/check_const_item_mutation.rs b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs
index b79150737..61bf530f1 100644
--- a/compiler/rustc_mir_transform/src/check_const_item_mutation.rs
+++ b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs
@@ -97,13 +97,15 @@ impl<'tcx> Visitor<'tcx> for ConstMutationChecker<'_, 'tcx> {
         // so emitting a lint would be redundant.
         if !lhs.projection.is_empty() {
             if let Some(def_id) = self.is_const_item_without_destructor(lhs.local)
-                && let Some((lint_root, span, item)) = self.should_lint_const_item_usage(&lhs, def_id, loc) {
-                    self.tcx.emit_spanned_lint(
-                        CONST_ITEM_MUTATION,
-                        lint_root,
-                        span,
-                        errors::ConstMutate::Modify { konst: item }
-                    );
+                && let Some((lint_root, span, item)) =
+                    self.should_lint_const_item_usage(&lhs, def_id, loc)
+            {
+                self.tcx.emit_spanned_lint(
+                    CONST_ITEM_MUTATION,
+                    lint_root,
+                    span,
+                    errors::ConstMutate::Modify { konst: item },
+                );
             }
         }
         // We are looking for MIR of the form:
diff --git a/compiler/rustc_mir_transform/src/check_packed_ref.rs b/compiler/rustc_mir_transform/src/check_packed_ref.rs
index 2e6cf603d..9ee0a7040 100644
--- a/compiler/rustc_mir_transform/src/check_packed_ref.rs
+++ b/compiler/rustc_mir_transform/src/check_packed_ref.rs
@@ -46,9 +46,14 @@ impl<'tcx> Visitor<'tcx> for PackedRefChecker<'_, 'tcx> {
                 // If we ever reach here it means that the generated derive
                 // code is somehow doing an unaligned reference, which it
                 // shouldn't do.
-                span_bug!(self.source_info.span, "builtin derive created an unaligned reference");
+                span_bug!(
+                    self.source_info.span,
+                    "builtin derive created an unaligned reference"
+                );
             } else {
-                self.tcx.sess.emit_err(errors::UnalignedPackedRef { span: self.source_info.span });
+                self.tcx
+                    .sess
+                    .emit_err(errors::UnalignedPackedRef { span: self.source_info.span });
             }
         }
     }
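For context, the pattern this pass rejects outside of builtin-derive code looks like the following (an illustrative example, a compile error in current Rust):

    #[repr(packed)]
    struct Packed {
        field: u32,
    }

    fn main() {
        let p = Packed { field: 1 };
        // Inside a packed struct the field is only guaranteed 1-byte
        // alignment, so a `&u32` pointing at it might be misaligned:
        let _r = &p.field; // error[E0793]: reference to packed field is unaligned
    }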
diff --git a/compiler/rustc_mir_transform/src/check_unsafety.rs b/compiler/rustc_mir_transform/src/check_unsafety.rs
index bacabc62e..8872f9a97 100644
--- a/compiler/rustc_mir_transform/src/check_unsafety.rs
+++ b/compiler/rustc_mir_transform/src/check_unsafety.rs
@@ -56,7 +56,7 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> {
             | TerminatorKind::Drop { .. }
             | TerminatorKind::Yield { .. }
             | TerminatorKind::Assert { .. }
-            | TerminatorKind::GeneratorDrop
+            | TerminatorKind::CoroutineDrop
             | TerminatorKind::UnwindResume
             | TerminatorKind::UnwindTerminate(_)
             | TerminatorKind::Return
@@ -128,7 +128,7 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> {
                         ),
                     }
                 }
-                &AggregateKind::Closure(def_id, _) | &AggregateKind::Generator(def_id, _, _) => {
+                &AggregateKind::Closure(def_id, _) | &AggregateKind::Coroutine(def_id, _, _) => {
                     let def_id = def_id.expect_local();
                     let UnsafetyCheckResult { violations, used_unsafe_blocks, .. } =
                         self.tcx.unsafety_check_result(def_id);
@@ -179,7 +179,7 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> {
         // Check the base local: it might be an unsafe-to-access static. We only check derefs of the
         // temporary holding the static pointer to avoid duplicate errors
         // <https://github.com/rust-lang/rust/pull/78068#issuecomment-731753506>.
-        if decl.internal && place.projection.first() == Some(&ProjectionElem::Deref) {
+        if place.projection.first() == Some(&ProjectionElem::Deref) {
             // If the projection root is an artificial local that we introduced when
             // desugaring `static`, give a more specific error message
             // (avoid the general "raw pointer" clause below, that would only be confusing).
@@ -540,8 +540,7 @@ pub fn check_unsafety(tcx: TyCtxt<'_>, def_id: LocalDefId) {
             && let BlockCheckMode::UnsafeBlock(_) = block.rules
         {
             true
-        }
-        else if let Some(sig) = tcx.hir().fn_sig_by_hir_id(*id)
+        } else if let Some(sig) = tcx.hir().fn_sig_by_hir_id(*id)
             && sig.header.is_unsafe()
         {
             true
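The "unsafe-to-access static" case mentioned in the hunk above corresponds to source like the following, where the checker demands an `unsafe` block (illustrative example, not from this patch):

    static mut COUNTER: u32 = 0;

    fn main() {
        // Accessing a `static mut` is an unsafe operation. Its desugaring goes
        // through an artificial temporary holding a pointer to the static, and
        // the deref of that temporary is what the checker reports on.
        unsafe {
            COUNTER += 1;
        }
    }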
diff --git a/compiler/rustc_mir_transform/src/const_debuginfo.rs b/compiler/rustc_mir_transform/src/const_debuginfo.rs
index 40cd28254..e4e4270c4 100644
--- a/compiler/rustc_mir_transform/src/const_debuginfo.rs
+++ b/compiler/rustc_mir_transform/src/const_debuginfo.rs
@@ -55,7 +55,9 @@ fn find_optimization_opportunities<'tcx>(body: &Body<'tcx>) -> Vec<(Local, Const
     let mut locals_to_debuginfo = BitSet::new_empty(body.local_decls.len());
     for debuginfo in &body.var_debug_info {
-        if let VarDebugInfoContents::Place(p) = debuginfo.value && let Some(l) = p.as_local() {
+        if let VarDebugInfoContents::Place(p) = debuginfo.value
+            && let Some(l) = p.as_local()
+        {
             locals_to_debuginfo.insert(l);
         }
     }
diff --git a/compiler/rustc_mir_transform/src/const_prop.rs b/compiler/rustc_mir_transform/src/const_prop.rs
index 50443e739..f7f882310 100644
--- a/compiler/rustc_mir_transform/src/const_prop.rs
+++ b/compiler/rustc_mir_transform/src/const_prop.rs
@@ -2,8 +2,6 @@
 //! assertion failures
 
 use either::Right;
-
-use rustc_const_eval::const_eval::CheckAlignment;
 use rustc_const_eval::ReportErrorExt;
 use rustc_data_structures::fx::FxHashSet;
 use rustc_hir::def::DefKind;
@@ -16,7 +14,7 @@ use rustc_middle::mir::*;
 use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout};
 use rustc_middle::ty::{self, GenericArgs, Instance, ParamEnv, Ty, TyCtxt, TypeVisitableExt};
 use rustc_span::{def_id::DefId, Span};
-use rustc_target::abi::{self, Align, HasDataLayout, Size, TargetDataLayout};
+use rustc_target::abi::{self, HasDataLayout, Size, TargetDataLayout};
 use rustc_target::spec::abi::Abi as CallAbi;
 
 use crate::dataflow_const_prop::Patch;
@@ -84,11 +82,11 @@ impl<'tcx> MirPass<'tcx> for ConstProp {
             return;
         }
 
-        // FIXME(welseywiser) const prop doesn't work on generators because of query cycles
+        // FIXME(welseywiser) const prop doesn't work on coroutines because of query cycles
         // computing their layout.
-        let is_generator = def_kind == DefKind::Generator;
-        if is_generator {
-            trace!("ConstProp skipped for generator {:?}", def_id);
+        let is_coroutine = def_kind == DefKind::Coroutine;
+        if is_coroutine {
+            trace!("ConstProp skipped for coroutine {:?}", def_id);
             return;
         }
 
@@ -141,27 +139,14 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx>
     type MemoryKind = !;
 
     #[inline(always)]
-    fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> CheckAlignment {
-        // We do not check for alignment to avoid having to carry an `Align`
-        // in `ConstValue::Indirect`.
-        CheckAlignment::No
+    fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> bool {
+        false // no reason to enforce alignment
     }
 
     #[inline(always)]
     fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {
         false // for now, we don't enforce validity
     }
-    fn alignment_check_failed(
-        ecx: &InterpCx<'mir, 'tcx, Self>,
-        _has: Align,
-        _required: Align,
-        _check: CheckAlignment,
-    ) -> InterpResult<'tcx, ()> {
-        span_bug!(
-            ecx.cur_span(),
-            "`alignment_check_failed` called when no alignment check requested"
-        )
-    }
 
     fn load_mir(
         _ecx: &InterpCx<'mir, 'tcx, Self>,
@@ -455,6 +440,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
 
         // FIXME we need to revisit this for #67176
         if rvalue.has_param() {
+            trace!("skipping, has param");
             return None;
         }
         if !rvalue
@@ -527,7 +513,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
 
     fn replace_with_const(&mut self, place: Place<'tcx>) -> Option<Const<'tcx>> {
         // This will return None if the above `const_prop` invocation only "wrote" a
-        // type whose creation requires no write. E.g. a generator whose initial state
+        // type whose creation requires no write. E.g. a coroutine whose initial state
         // consists solely of uninitialized memory (so it doesn't capture any locals).
         let value = self.get_const(place)?;
         if !self.tcx.consider_optimizing(|| format!("ConstantPropagation - {value:?}")) {
@@ -699,7 +685,9 @@ impl<'tcx> Visitor<'tcx> for CanConstProp {
 impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> {
     fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) {
         self.super_operand(operand, location);
-        if let Some(place) = operand.place() && let Some(value) = self.replace_with_const(place) {
+        if let Some(place) = operand.place()
+            && let Some(value) = self.replace_with_const(place)
+        {
             self.patch.before_effect.insert((location, place), value);
         }
     }
@@ -721,7 +709,11 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> {
     fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, location: Location) {
         self.super_assign(place, rvalue, location);
 
-        let Some(()) = self.check_rvalue(rvalue) else { return };
+        let Some(()) = self.check_rvalue(rvalue) else {
+            trace!("rvalue check failed, removing const");
+            Self::remove_const(&mut self.ecx, place.local);
+            return;
+        };
 
         match self.ecx.machine.can_const_prop[place.local] {
             // Do nothing if the place is indirect.
@@ -733,7 +725,10 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> {
                 if let Rvalue::Use(Operand::Constant(c)) = rvalue
                     && let Const::Val(..) = c.const_
                 {
-                    trace!("skipping replace of Rvalue::Use({:?} because it is already a const", c);
+                    trace!(
+                        "skipping replace of Rvalue::Use({:?} because it is already a const",
+                        c
+                    );
                 } else if let Some(operand) = self.replace_with_const(*place) {
                     self.patch.assignments.insert(location, operand);
                 }
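As a reminder of what this pass does, here is a source-level picture of the folding (the pass itself rewrites MIR, not source; the MIR shown in the comments is schematic):

    fn before() -> u32 {
        let x = 2 + 2; // MIR: _1 = Add(const 2_u32, const 2_u32)
        x * 10         // MIR: _0 = Mul(move _1, const 10_u32)
    }
    // After ConstProp both statements have been folded to constants:
    //     _1 = const 4_u32
    //     _0 = const 40_u32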
diff --git a/compiler/rustc_mir_transform/src/const_prop_lint.rs b/compiler/rustc_mir_transform/src/const_prop_lint.rs
index 64e262c6c..a23ba9c4a 100644
--- a/compiler/rustc_mir_transform/src/const_prop_lint.rs
+++ b/compiler/rustc_mir_transform/src/const_prop_lint.rs
@@ -22,7 +22,6 @@ use rustc_middle::ty::{
 };
 use rustc_span::Span;
 use rustc_target::abi::{HasDataLayout, Size, TargetDataLayout};
-use rustc_trait_selection::traits;
 
 use crate::const_prop::CanConstProp;
 use crate::const_prop::ConstPropMachine;
@@ -35,9 +34,9 @@ use crate::MirLint;
 /// Severely regress performance.
 const MAX_ALLOC_LIMIT: u64 = 1024;
 
-pub struct ConstProp;
+pub struct ConstPropLint;
 
-impl<'tcx> MirLint<'tcx> for ConstProp {
+impl<'tcx> MirLint<'tcx> for ConstPropLint {
     fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
         if body.tainted_by_errors.is_some() {
             return;
@@ -49,61 +48,25 @@ impl<'tcx> MirLint<'tcx> for ConstProp {
         }
 
         let def_id = body.source.def_id().expect_local();
-        let is_fn_like = tcx.def_kind(def_id).is_fn_like();
-        let is_assoc_const = tcx.def_kind(def_id) == DefKind::AssocConst;
+        let def_kind = tcx.def_kind(def_id);
+        let is_fn_like = def_kind.is_fn_like();
+        let is_assoc_const = def_kind == DefKind::AssocConst;
 
         // Only run const prop on functions, methods, closures and associated constants
         if !is_fn_like && !is_assoc_const {
             // skip anon_const/statics/consts because they'll be evaluated by miri anyway
-            trace!("ConstProp skipped for {:?}", def_id);
+            trace!("ConstPropLint skipped for {:?}", def_id);
             return;
         }
 
-        let is_generator = tcx.type_of(def_id.to_def_id()).instantiate_identity().is_generator();
-        // FIXME(welseywiser) const prop doesn't work on generators because of query cycles
+        // FIXME(welseywiser) const prop doesn't work on coroutines because of query cycles
         // computing their layout.
-        if is_generator {
-            trace!("ConstProp skipped for generator {:?}", def_id);
+        if let DefKind::Coroutine = def_kind {
+            trace!("ConstPropLint skipped for coroutine {:?}", def_id);
             return;
         }
 
-        // Check if it's even possible to satisfy the 'where' clauses
-        // for this item.
-        // This branch will never be taken for any normal function.
-        // However, it's possible to `#!feature(trivial_bounds)]` to write
-        // a function with impossible to satisfy clauses, e.g.:
-        // `fn foo() where String: Copy {}`
-        //
-        // We don't usually need to worry about this kind of case,
-        // since we would get a compilation error if the user tried
-        // to call it. However, since we can do const propagation
-        // even without any calls to the function, we need to make
-        // sure that it even makes sense to try to evaluate the body.
-        // If there are unsatisfiable where clauses, then all bets are
-        // off, and we just give up.
-        //
-        // We manually filter the predicates, skipping anything that's not
-        // "global". We are in a potentially generic context
-        // (e.g. we are evaluating a function without substituting generic
-        // parameters, so this filtering serves two purposes:
-        //
-        // 1. We skip evaluating any predicates that we would
-        // never be able prove are unsatisfiable (e.g. `<T as Foo>`
-        // 2. We avoid trying to normalize predicates involving generic
-        // parameters (e.g. `<T as Foo>::MyItem`). This can confuse
-        // the normalization code (leading to cycle errors), since
-        // it's usually never invoked in this way.
-        let predicates = tcx
-            .predicates_of(def_id.to_def_id())
-            .predicates
-            .iter()
-            .filter_map(|(p, _)| if p.is_global() { Some(*p) } else { None });
-        if traits::impossible_predicates(tcx, traits::elaborate(tcx, predicates).collect()) {
-            trace!("ConstProp skipped for {:?}: found unsatisfiable predicates", def_id);
-            return;
-        }
-
-        trace!("ConstProp starting for {:?}", def_id);
+        trace!("ConstPropLint starting for {:?}", def_id);
 
         // FIXME(oli-obk, eddyb) Optimize locals (or even local paths) to hold
         // constants, instead of just checking for const-folding succeeding.
@@ -112,7 +75,7 @@ impl<'tcx> MirLint<'tcx> for ConstProp {
         let mut linter = ConstPropagator::new(body, tcx);
         linter.visit_body(body);
-        trace!("ConstProp done for {:?}", def_id);
+        trace!("ConstPropLint done for {:?}", def_id);
     }
 }
 
@@ -664,9 +627,10 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> {
             }
             TerminatorKind::SwitchInt { ref discr, ref targets } => {
                 if let Some(ref value) = self.eval_operand(&discr, location)
-                    && let Some(value_const) = self.use_ecx(location, |this| this.ecx.read_scalar(value))
-                    && let Ok(constant) = value_const.try_to_int()
-                    && let Ok(constant) = constant.to_bits(constant.size())
+                    && let Some(value_const) =
+                        self.use_ecx(location, |this| this.ecx.read_scalar(value))
+                    && let Ok(constant) = value_const.try_to_int()
+                    && let Ok(constant) = constant.to_bits(constant.size())
                 {
                     // We managed to evaluate the discriminant, so we know we only need to visit
                     // one target.
@@ -684,7 +648,7 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> {
             | TerminatorKind::Unreachable
             | TerminatorKind::Drop { .. }
             | TerminatorKind::Yield { .. }
-            | TerminatorKind::GeneratorDrop
+            | TerminatorKind::CoroutineDrop
             | TerminatorKind::FalseEdge { .. }
             | TerminatorKind::FalseUnwind { .. }
             | TerminatorKind::Call { .. }
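The lint pass uses the same evaluation machinery to report operations that are statically known to panic, for example (illustrative; exact diagnostic wording varies by compiler version):

    fn main() {
        let xs = [1, 2, 3];
        let i = 4;
        // Const propagation proves the index is out of bounds at compile time:
        // error: this operation will panic at runtime  [unconditional_panic]
        let _x = xs[i];
    }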
diff --git a/compiler/rustc_mir_transform/src/copy_prop.rs b/compiler/rustc_mir_transform/src/copy_prop.rs
index 9c38a6f81..f5db7ce97 100644
--- a/compiler/rustc_mir_transform/src/copy_prop.rs
+++ b/compiler/rustc_mir_transform/src/copy_prop.rs
@@ -168,14 +168,15 @@ impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> {
             && self.storage_to_remove.contains(l)
         {
             stmt.make_nop();
-            return
+            return;
         }
 
         self.super_statement(stmt, loc);
 
         // Do not leave tautological assignments around.
         if let StatementKind::Assign(box (lhs, ref rhs)) = stmt.kind
-            && let Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)) | Rvalue::CopyForDeref(rhs) = *rhs
+            && let Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)) | Rvalue::CopyForDeref(rhs) =
+                *rhs
             && lhs == rhs
         {
             stmt.make_nop();
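A source-level sketch of what CopyProp achieves (the pass operates on MIR locals; shown at the source level for readability):

    // Before: `b` is just a copy of `a`.
    fn before(a: u32) -> u32 {
        let b = a;
        b + 1
    }

    // Conceptually after CopyProp: uses of `b` are replaced by `a`, the now
    // tautological assignment is made a nop, and its storage markers go away.
    fn after(a: u32) -> u32 {
        a + 1
    }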
diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/coroutine.rs
index e261b8ac2..abaed103f 100644
--- a/compiler/rustc_mir_transform/src/generator.rs
+++ b/compiler/rustc_mir_transform/src/coroutine.rs
@@ -1,53 +1,53 @@
-//! This is the implementation of the pass which transforms generators into state machines.
+//! This is the implementation of the pass which transforms coroutines into state machines.
 //!
-//! MIR generation for generators creates a function which has a self argument which
-//! passes by value. This argument is effectively a generator type which only contains upvars and
-//! is only used for this argument inside the MIR for the generator.
+//! MIR generation for coroutines creates a function which has a self argument which
+//! passes by value. This argument is effectively a coroutine type which only contains upvars and
+//! is only used for this argument inside the MIR for the coroutine.
 //! It is passed by value to enable upvars to be moved out of it. Drop elaboration runs on that
 //! MIR before this pass and creates drop flags for MIR locals.
-//! It will also drop the generator argument (which only consists of upvars) if any of the upvars
-//! are moved out of. This pass elaborates the drops of upvars / generator argument in the case
+//! It will also drop the coroutine argument (which only consists of upvars) if any of the upvars
+//! are moved out of. This pass elaborates the drops of upvars / coroutine argument in the case
 //! that none of the upvars were moved out of. This is because we cannot have any drops of this
-//! generator in the MIR, since it is used to create the drop glue for the generator. We'd get
+//! coroutine in the MIR, since it is used to create the drop glue for the coroutine. We'd get
 //! infinite recursion otherwise.
 //!
-//! This pass creates the implementation for either the `Generator::resume` or `Future::poll`
-//! function and the drop shim for the generator based on the MIR input.
-//! It converts the generator argument from Self to &mut Self adding derefs in the MIR as needed.
-//! It computes the final layout of the generator struct which looks like this:
+//! This pass creates the implementation for either the `Coroutine::resume` or `Future::poll`
+//! function and the drop shim for the coroutine based on the MIR input.
+//! It converts the coroutine argument from Self to &mut Self adding derefs in the MIR as needed.
+//! It computes the final layout of the coroutine struct which looks like this:
 //!     First upvars are stored
-//!     It is followed by the generator state field.
+//!     It is followed by the coroutine state field.
 //!     Then finally the MIR locals which are live across a suspension point are stored.
 //!     ```ignore (illustrative)
-//!     struct Generator {
+//!     struct Coroutine {
 //!         upvars...,
 //!         state: u32,
 //!         mir_locals...,
 //!     }
 //!     ```
 //! This pass computes the meaning of the state field and the MIR locals which are live
-//! across a suspension point. There are however three hardcoded generator states:
-//!     0 - Generator have not been resumed yet
-//!     1 - Generator has returned / is completed
-//!     2 - Generator has been poisoned
+//! across a suspension point. There are however three hardcoded coroutine states:
+//!     0 - Coroutine have not been resumed yet
+//!     1 - Coroutine has returned / is completed
+//!     2 - Coroutine has been poisoned
 //!
-//! It also rewrites `return x` and `yield y` as setting a new generator state and returning
-//! `GeneratorState::Complete(x)` and `GeneratorState::Yielded(y)`,
+//! It also rewrites `return x` and `yield y` as setting a new coroutine state and returning
+//! `CoroutineState::Complete(x)` and `CoroutineState::Yielded(y)`,
 //! or `Poll::Ready(x)` and `Poll::Pending` respectively.
-//! MIR locals which are live across a suspension point are moved to the generator struct
-//! with references to them being updated with references to the generator struct.
+//! MIR locals which are live across a suspension point are moved to the coroutine struct
+//! with references to them being updated with references to the coroutine struct.
 //!
-//! The pass creates two functions which have a switch on the generator state giving
+//! The pass creates two functions which have a switch on the coroutine state giving
 //! the action to take.
 //!
-//! One of them is the implementation of `Generator::resume` / `Future::poll`.
-//! For generators with state 0 (unresumed) it starts the execution of the generator.
-//! For generators with state 1 (returned) and state 2 (poisoned) it panics.
+//! One of them is the implementation of `Coroutine::resume` / `Future::poll`.
+//! For coroutines with state 0 (unresumed) it starts the execution of the coroutine.
+//! For coroutines with state 1 (returned) and state 2 (poisoned) it panics.
 //! Otherwise it continues the execution from the last suspension point.
 //!
-//! The other function is the drop glue for the generator.
-//! For generators with state 0 (unresumed) it drops the upvars of the generator.
-//! For generators with state 1 (returned) and state 2 (poisoned) it does nothing.
+//! The other function is the drop glue for the coroutine.
+//! For coroutines with state 0 (unresumed) it drops the upvars of the coroutine.
+//! For coroutines with state 1 (returned) and state 2 (poisoned) it does nothing.
 //! Otherwise it drops all the values in scope at the last suspension point.
 
 use crate::abort_unwinding_calls;
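A hand-written analogue of the state machine described in the module docs above: a counter that yields 0 and 1, then completes. All names here are hypothetical stand-ins for what the pass emits as MIR (illustration only):

    enum CoroutineState<Y, R> {
        Yielded(Y),
        Complete(R),
    }

    struct Counter {
        state: u32, // 0 = unresumed, 1 = returned, 2 = poisoned, 3.. = suspension points
    }

    fn resume(c: &mut Counter) -> CoroutineState<u32, &'static str> {
        match c.state {
            0 => { c.state = 3; CoroutineState::Yielded(0) } // first yield
            3 => { c.state = 4; CoroutineState::Yielded(1) } // second yield
            4 => { c.state = 1; CoroutineState::Complete("done") }
            1 => panic!("coroutine resumed after completion"),
            2 => panic!("coroutine resumed after panicking"),
            _ => unreachable!(),
        }
    }

    fn main() {
        let mut c = Counter { state: 0 };
        assert!(matches!(resume(&mut c), CoroutineState::Yielded(0)));
        assert!(matches!(resume(&mut c), CoroutineState::Yielded(1)));
        assert!(matches!(resume(&mut c), CoroutineState::Complete("done")));
    }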
@@ -60,7 +60,7 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet};
 use rustc_errors::pluralize;
 use rustc_hir as hir;
 use rustc_hir::lang_items::LangItem;
-use rustc_hir::GeneratorKind;
+use rustc_hir::CoroutineKind;
 use rustc_index::bit_set::{BitMatrix, BitSet, GrowableBitSet};
 use rustc_index::{Idx, IndexVec};
 use rustc_middle::mir::dump_mir;
@@ -68,7 +68,7 @@ use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
 use rustc_middle::mir::*;
 use rustc_middle::ty::InstanceDef;
 use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
-use rustc_middle::ty::{GeneratorArgs, GenericArgsRef};
+use rustc_middle::ty::{CoroutineArgs, GenericArgsRef};
 use rustc_mir_dataflow::impls::{
     MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive,
 };
@@ -147,7 +147,7 @@ impl<'tcx> MutVisitor<'tcx> for DerefArgVisitor<'tcx> {
 }
 
 struct PinArgVisitor<'tcx> {
-    ref_gen_ty: Ty<'tcx>,
+    ref_coroutine_ty: Ty<'tcx>,
     tcx: TyCtxt<'tcx>,
 }
 
@@ -168,7 +168,7 @@ impl<'tcx> MutVisitor<'tcx> for PinArgVisitor<'tcx> {
                         local: SELF_ARG,
                         projection: self.tcx().mk_place_elems(&[ProjectionElem::Field(
                             FieldIdx::new(0),
-                            self.ref_gen_ty,
+                            self.ref_coroutine_ty,
                         )]),
                     },
                     self.tcx,
@@ -196,19 +196,19 @@ fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtx
 
 const SELF_ARG: Local = Local::from_u32(1);
 
-/// Generator has not been resumed yet.
-const UNRESUMED: usize = GeneratorArgs::UNRESUMED;
-/// Generator has returned / is completed.
-const RETURNED: usize = GeneratorArgs::RETURNED;
-/// Generator has panicked and is poisoned.
-const POISONED: usize = GeneratorArgs::POISONED;
+/// Coroutine has not been resumed yet.
+const UNRESUMED: usize = CoroutineArgs::UNRESUMED;
+/// Coroutine has returned / is completed.
+const RETURNED: usize = CoroutineArgs::RETURNED;
+/// Coroutine has panicked and is poisoned.
+const POISONED: usize = CoroutineArgs::POISONED;
 
-/// Number of variants to reserve in generator state. Corresponds to
-/// `UNRESUMED` (beginning of a generator) and `RETURNED`/`POISONED`
-/// (end of a generator) states.
+/// Number of variants to reserve in coroutine state. Corresponds to
+/// `UNRESUMED` (beginning of a coroutine) and `RETURNED`/`POISONED`
+/// (end of a coroutine) states.
 const RESERVED_VARIANTS: usize = 3;
 
-/// A `yield` point in the generator.
+/// A `yield` point in the coroutine.
 struct SuspensionPoint<'tcx> {
     /// State discriminant used when suspending or resuming at this point.
     state: usize,
@@ -216,7 +216,7 @@ struct SuspensionPoint<'tcx> {
     resume: BasicBlock,
     /// Where to move the resume argument after resumption.
     resume_arg: Place<'tcx>,
-    /// Which block to jump to if the generator is dropped in this state.
+    /// Which block to jump to if the coroutine is dropped in this state.
     drop: Option<BasicBlock>,
     /// Set of locals that have live storage while at this suspension point.
     storage_liveness: GrowableBitSet<Local>,
@@ -224,14 +224,14 @@ struct SuspensionPoint<'tcx> {
 
 struct TransformVisitor<'tcx> {
     tcx: TyCtxt<'tcx>,
-    is_async_kind: bool,
+    coroutine_kind: hir::CoroutineKind,
     state_adt_ref: AdtDef<'tcx>,
     state_args: GenericArgsRef<'tcx>,
 
-    // The type of the discriminant in the generator struct
+    // The type of the discriminant in the coroutine struct
     discr_ty: Ty<'tcx>,
 
-    // Mapping from Local to (type of local, generator struct index)
+    // Mapping from Local to (type of local, coroutine struct index)
     // FIXME(eddyb) This should use `IndexVec<Local, Option<_>>`.
     remap: FxHashMap<Local, (Ty<'tcx>, VariantIdx, FieldIdx)>,
 
@@ -249,9 +249,50 @@ struct TransformVisitor<'tcx> {
 }
 
 impl<'tcx> TransformVisitor<'tcx> {
-    // Make a `GeneratorState` or `Poll` variant assignment.
+    fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock {
+        let block = BasicBlock::new(body.basic_blocks.len());
+
+        let source_info = SourceInfo::outermost(body.span);
+
+        let (kind, idx) = self.coroutine_state_adt_and_variant_idx(true);
+        assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
+        let statements = vec![Statement {
+            kind: StatementKind::Assign(Box::new((
+                Place::return_place(),
+                Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
+            ))),
+            source_info,
+        }];
+
+        body.basic_blocks_mut().push(BasicBlockData {
+            statements,
+            terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }),
+            is_cleanup: false,
+        });
+
+        block
+    }
+
+    fn coroutine_state_adt_and_variant_idx(
+        &self,
+        is_return: bool,
+    ) -> (AggregateKind<'tcx>, VariantIdx) {
+        let idx = VariantIdx::new(match (is_return, self.coroutine_kind) {
+            (true, hir::CoroutineKind::Coroutine) => 1,  // CoroutineState::Complete
+            (false, hir::CoroutineKind::Coroutine) => 0, // CoroutineState::Yielded
+            (true, hir::CoroutineKind::Async(_)) => 0,   // Poll::Ready
+            (false, hir::CoroutineKind::Async(_)) => 1,  // Poll::Pending
+            (true, hir::CoroutineKind::Gen(_)) => 0,     // Option::None
+            (false, hir::CoroutineKind::Gen(_)) => 1,    // Option::Some
+        });
+
+        let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None);
+        (kind, idx)
+    }
+
+    // Make a `CoroutineState` or `Poll` variant assignment.
     //
-    // `core::ops::GeneratorState` only has single element tuple variants,
+    // `core::ops::CoroutineState` only has single element tuple variants,
     // so we can just write to the downcasted first field and then set the
     // discriminant to the appropriate variant.
     fn make_state(
@@ -261,31 +302,44 @@ impl<'tcx> TransformVisitor<'tcx> {
         is_return: bool,
         statements: &mut Vec<Statement<'tcx>>,
     ) {
-        let idx = VariantIdx::new(match (is_return, self.is_async_kind) {
-            (true, false) => 1,  // GeneratorState::Complete
-            (false, false) => 0, // GeneratorState::Yielded
-            (true, true) => 0,   // Poll::Ready
-            (false, true) => 1,  // Poll::Pending
-        });
+        let (kind, idx) = self.coroutine_state_adt_and_variant_idx(is_return);
 
-        let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None);
+        match self.coroutine_kind {
+            // `Poll::Pending`
+            CoroutineKind::Async(_) => {
+                if !is_return {
+                    assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
 
-        // `Poll::Pending`
-        if self.is_async_kind && idx == VariantIdx::new(1) {
-            assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
+                    // FIXME(swatinem): assert that `val` is indeed unit?
+                    statements.push(Statement {
+                        kind: StatementKind::Assign(Box::new((
+                            Place::return_place(),
+                            Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
+                        ))),
+                        source_info,
+                    });
+                    return;
+                }
+            }
+            // `Option::None`
+            CoroutineKind::Gen(_) => {
+                if is_return {
+                    assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
 
-            // FIXME(swatinem): assert that `val` is indeed unit?
-            statements.push(Statement {
-                kind: StatementKind::Assign(Box::new((
-                    Place::return_place(),
-                    Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
-                ))),
-                source_info,
-            });
-            return;
+                    statements.push(Statement {
+                        kind: StatementKind::Assign(Box::new((
+                            Place::return_place(),
+                            Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
+                        ))),
+                        source_info,
+                    });
+                    return;
+                }
+            }
+            CoroutineKind::Coroutine => {}
         }
 
-        // else: `Poll::Ready(x)`, `GeneratorState::Yielded(x)` or `GeneratorState::Complete(x)`
+        // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)`
         assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
 
         statements.push(Statement {
@@ -297,7 +351,7 @@ impl<'tcx> TransformVisitor<'tcx> {
         });
     }
 
-    // Create a Place referencing a generator struct field
+    // Create a Place referencing a coroutine struct field
     fn make_field(&self, variant_index: VariantIdx, idx: FieldIdx, ty: Ty<'tcx>) -> Place<'tcx> {
         let self_place = Place::from(SELF_ARG);
         let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index);
@@ -321,7 +375,7 @@ impl<'tcx> TransformVisitor<'tcx> {
 
     // Create a statement which reads the discriminant into a temporary
     fn get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>) {
-        let temp_decl = LocalDecl::new(self.discr_ty, body.span).internal();
+        let temp_decl = LocalDecl::new(self.discr_ty, body.span);
         let local_decls_len = body.local_decls.push(temp_decl);
         let temp = Place::from(local_decls_len);
 
@@ -349,7 +403,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
         _context: PlaceContext,
         _location: Location,
     ) {
-        // Replace an Local in the remap with a generator struct access
+        // Replace an Local in the remap with a coroutine struct access
         if let Some(&(ty, variant_index, idx)) = self.remap.get(&place.local) {
             replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
         }
@@ -413,35 +467,35 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
     }
 }
 
-fn make_generator_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
-    let gen_ty = body.local_decls.raw[1].ty;
+fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+    let coroutine_ty = body.local_decls.raw[1].ty;
 
-    let ref_gen_ty = Ty::new_ref(
+    let ref_coroutine_ty = Ty::new_ref(
         tcx,
         tcx.lifetimes.re_erased,
-        ty::TypeAndMut { ty: gen_ty, mutbl: Mutability::Mut },
+        ty::TypeAndMut { ty: coroutine_ty, mutbl: Mutability::Mut },
     );
 
-    // Replace the by value generator argument
-    body.local_decls.raw[1].ty = ref_gen_ty;
+    // Replace the by value coroutine argument
+    body.local_decls.raw[1].ty = ref_coroutine_ty;
 
-    // Add a deref to accesses of the generator state
+    // Add a deref to accesses of the coroutine state
     DerefArgVisitor { tcx }.visit_body(body);
 }
 
-fn make_generator_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
-    let ref_gen_ty = body.local_decls.raw[1].ty;
+fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+    let ref_coroutine_ty = body.local_decls.raw[1].ty;
 
     let pin_did = tcx.require_lang_item(LangItem::Pin, Some(body.span));
     let pin_adt_ref = tcx.adt_def(pin_did);
-    let args = tcx.mk_args(&[ref_gen_ty.into()]);
-    let pin_ref_gen_ty = Ty::new_adt(tcx, pin_adt_ref, args);
+    let args = tcx.mk_args(&[ref_coroutine_ty.into()]);
+    let pin_ref_coroutine_ty = Ty::new_adt(tcx, pin_adt_ref, args);
 
-    // Replace the by ref generator argument
-    body.local_decls.raw[1].ty = pin_ref_gen_ty;
+    // Replace the by ref coroutine argument
+    body.local_decls.raw[1].ty = pin_ref_coroutine_ty;
 
-    // Add the Pin field access to accesses of the generator state
-    PinArgVisitor { ref_gen_ty, tcx }.visit_body(body);
+    // Add the Pin field access to accesses of the coroutine state
+    PinArgVisitor { ref_coroutine_ty, tcx }.visit_body(body);
 }
 
 /// Allocates a new local and replaces all references of `local` with it. Returns the new local.
@@ -465,7 +519,7 @@ fn replace_local<'tcx>(
     new_local
 }
 
-/// Transforms the `body` of the generator applying the following transforms:
+/// Transforms the `body` of the coroutine applying the following transforms:
 ///
 /// - Eliminates all the `get_context` calls that async lowering created.
 /// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
@@ -485,7 +539,7 @@ fn replace_local<'tcx>(
 ///
 /// The async lowering step and the type / lifetime inference / checking are
 /// still using the `ResumeTy` indirection for the time being, and that indirection
-/// is removed here. After this transform, the generator body only knows about `&mut Context<'_>`.
+/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`.
 fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
     let context_mut_ref = Ty::new_task_context(tcx);
 
@@ -565,10 +619,10 @@ fn replace_resume_ty_local<'tcx>(
 
 struct LivenessInfo {
     /// Which locals are live across any suspension point.
-    saved_locals: GeneratorSavedLocals,
+    saved_locals: CoroutineSavedLocals,
 
     /// The set of saved locals live at each suspension point.
-    live_locals_at_suspension_points: Vec<BitSet<GeneratorSavedLocal>>,
+    live_locals_at_suspension_points: Vec<BitSet<CoroutineSavedLocal>>,
 
     /// Parallel vec to the above with SourceInfo for each yield terminator.
     source_info_at_suspension_points: Vec<SourceInfo>,
@@ -576,7 +630,7 @@ struct LivenessInfo {
     /// For every saved local, the set of other saved locals that are
     /// storage-live at the same time as this local. We cannot overlap locals in
     /// the layout which have conflicting storage.
-    storage_conflicts: BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal>,
+    storage_conflicts: BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal>,
 
     /// For every suspending block, the locals which are storage-live across
     /// that suspension point.
@@ -609,7 +663,7 @@ fn locals_live_across_suspend_points<'tcx>(
     // Calculate the MIR locals which have been previously
     // borrowed (even if they are still active).
     let borrowed_locals_results =
-        MaybeBorrowedLocals.into_engine(tcx, body_ref).pass_name("generator").iterate_to_fixpoint();
+        MaybeBorrowedLocals.into_engine(tcx, body_ref).pass_name("coroutine").iterate_to_fixpoint();
 
     let mut borrowed_locals_cursor = borrowed_locals_results.cloned_results_cursor(body_ref);
 
@@ -624,7 +678,7 @@ fn locals_live_across_suspend_points<'tcx>(
     // Calculate the liveness of MIR locals ignoring borrows.
     let mut liveness = MaybeLiveLocals
         .into_engine(tcx, body_ref)
-        .pass_name("generator")
+        .pass_name("coroutine")
         .iterate_to_fixpoint()
         .into_results_cursor(body_ref);
 
@@ -643,8 +697,8 @@ fn locals_live_across_suspend_points<'tcx>(
 
             if !movable {
                 // The `liveness` variable contains the liveness of MIR locals ignoring borrows.
-                // This is correct for movable generators since borrows cannot live across
-                // suspension points. However for immovable generators we need to account for
+                // This is correct for movable coroutines since borrows cannot live across
+                // suspension points. However for immovable coroutines we need to account for
                 // borrows, so we conservatively assume that all borrowed locals are live until
                 // we find a StorageDead statement referencing the locals.
                 // To do this we just union our `liveness` result with `borrowed_locals`, which
@@ -667,7 +721,7 @@ fn locals_live_across_suspend_points<'tcx>(
             requires_storage_cursor.seek_before_primary_effect(loc);
             live_locals.intersect(requires_storage_cursor.get());
 
-            // The generator argument is ignored.
+            // The coroutine argument is ignored.
             live_locals.remove(SELF_ARG);
 
             debug!("loc = {:?}, live_locals = {:?}", loc, live_locals);
@@ -682,7 +736,7 @@ fn locals_live_across_suspend_points<'tcx>(
     }
     debug!("live_locals_anywhere = {:?}", live_locals_at_any_suspension_point);
 
-    let saved_locals = GeneratorSavedLocals(live_locals_at_any_suspension_point);
+    let saved_locals = CoroutineSavedLocals(live_locals_at_any_suspension_point);
 
     // Renumber our liveness_map bitsets to include only the locals we are
     // saving.
@@ -709,21 +763,21 @@ fn locals_live_across_suspend_points<'tcx>(
 
 /// The set of `Local`s that must be saved across yield points.
 ///
-/// `GeneratorSavedLocal` is indexed in terms of the elements in this set;
-/// i.e. `GeneratorSavedLocal::new(1)` corresponds to the second local
+/// `CoroutineSavedLocal` is indexed in terms of the elements in this set;
+/// i.e. `CoroutineSavedLocal::new(1)` corresponds to the second local
 /// included in this set.
-struct GeneratorSavedLocals(BitSet<Local>);
+struct CoroutineSavedLocals(BitSet<Local>);
 
-impl GeneratorSavedLocals {
-    /// Returns an iterator over each `GeneratorSavedLocal` along with the `Local` it corresponds
+impl CoroutineSavedLocals {
+    /// Returns an iterator over each `CoroutineSavedLocal` along with the `Local` it corresponds
     /// to.
-    fn iter_enumerated(&self) -> impl '_ + Iterator<Item = (GeneratorSavedLocal, Local)> {
-        self.iter().enumerate().map(|(i, l)| (GeneratorSavedLocal::from(i), l))
+    fn iter_enumerated(&self) -> impl '_ + Iterator<Item = (CoroutineSavedLocal, Local)> {
+        self.iter().enumerate().map(|(i, l)| (CoroutineSavedLocal::from(i), l))
     }
 
     /// Transforms a `BitSet<Local>` that contains only locals saved across yield points to the
-    /// equivalent `BitSet<GeneratorSavedLocal>`.
-    fn renumber_bitset(&self, input: &BitSet<Local>) -> BitSet<GeneratorSavedLocal> {
+    /// equivalent `BitSet<CoroutineSavedLocal>`.
+    fn renumber_bitset(&self, input: &BitSet<Local>) -> BitSet<CoroutineSavedLocal> {
         assert!(self.superset(&input), "{:?} not a superset of {:?}", self.0, input);
         let mut out = BitSet::new_empty(self.count());
         for (saved_local, local) in self.iter_enumerated() {
@@ -734,17 +788,17 @@ impl GeneratorSavedLocals {
         out
     }
 
-    fn get(&self, local: Local) -> Option<GeneratorSavedLocal> {
+    fn get(&self, local: Local) -> Option<CoroutineSavedLocal> {
         if !self.contains(local) {
             return None;
         }
 
         let idx = self.iter().take_while(|&l| l < local).count();
-        Some(GeneratorSavedLocal::new(idx))
+        Some(CoroutineSavedLocal::new(idx))
     }
 }
 
-impl ops::Deref for GeneratorSavedLocals {
+impl ops::Deref for CoroutineSavedLocals {
     type Target = BitSet<Local>;
 
     fn deref(&self) -> &Self::Target {
@@ -755,13 +809,13 @@ impl ops::Deref for GeneratorSavedLocals {
 /// For every saved local, looks for which locals are StorageLive at the same
 /// time. Generates a bitset for every local of all the other locals that may be
 /// StorageLive simultaneously with that local. This is used in the layout
-/// computation; see `GeneratorLayout` for more.
+/// computation; see `CoroutineLayout` for more.
 fn compute_storage_conflicts<'mir, 'tcx>(
     body: &'mir Body<'tcx>,
-    saved_locals: &GeneratorSavedLocals,
+    saved_locals: &CoroutineSavedLocals,
     always_live_locals: BitSet<Local>,
     mut requires_storage: rustc_mir_dataflow::Results<'tcx, MaybeRequiresStorage<'_, 'mir, 'tcx>>,
-) -> BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal> {
+) -> BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal> {
     assert_eq!(body.local_decls.len(), saved_locals.domain_size());
 
     debug!("compute_storage_conflicts({:?})", body.span);
@@ -783,7 +837,7 @@ fn compute_storage_conflicts<'mir, 'tcx>(
 
     let local_conflicts = visitor.local_conflicts;
 
-    // Compress the matrix using only stored locals (Local -> GeneratorSavedLocal).
+    // Compress the matrix using only stored locals (Local -> CoroutineSavedLocal).
     //
     // NOTE: Today we store a full conflict bitset for every local. Technically
     // this is twice as many bits as we need, since the relation is symmetric.
@@ -809,9 +863,9 @@ fn compute_storage_conflicts<'mir, 'tcx>(
 
 struct StorageConflictVisitor<'mir, 'tcx, 's> {
     body: &'mir Body<'tcx>,
-    saved_locals: &'s GeneratorSavedLocals,
+    saved_locals: &'s CoroutineSavedLocals,
     // FIXME(tmandry): Consider using sparse bitsets here once we have good
-    // benchmarks for generators.
+    // benchmarks for coroutines.
     local_conflicts: BitMatrix<Local, Local>,
 }
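The effect of the conflict matrix is easiest to see with an async body, since async lowering goes through this same machinery: locals that are never storage-live across the same suspension point may share a slot in the coroutine layout. A sketch (the overlap itself is an optimization the layout code may apply, not a guarantee):

    async fn tick() {}

    async fn example() {
        let a = vec![0u8; 1024]; // live across the first await only
        tick().await;
        drop(a);
        let b = vec![0u8; 1024]; // live across the second await only
        tick().await;
        drop(b);
    }

    fn main() {
        // `a` and `b` never conflict, so the layout may overlap their slots;
        // if both were live across one await, the future would need room for both.
        println!("future size: {}", std::mem::size_of_val(&example()));
    }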
@@ -866,7 +920,7 @@ fn compute_layout<'tcx>(
     body: &Body<'tcx>,
 ) -> (
     FxHashMap<Local, (Ty<'tcx>, VariantIdx, FieldIdx)>,
-    GeneratorLayout<'tcx>,
+    CoroutineLayout<'tcx>,
     IndexVec<BasicBlock, Option<BitSet<Local>>>,
 ) {
     let LivenessInfo {
@@ -878,10 +932,10 @@ fn compute_layout<'tcx>(
     } = liveness;
 
     // Gather live local types and their indices.
-    let mut locals = IndexVec::<GeneratorSavedLocal, _>::new();
-    let mut tys = IndexVec::<GeneratorSavedLocal, _>::new();
+    let mut locals = IndexVec::<CoroutineSavedLocal, _>::new();
+    let mut tys = IndexVec::<CoroutineSavedLocal, _>::new();
     for (saved_local, local) in saved_locals.iter_enumerated() {
-        debug!("generator saved local {:?} => {:?}", saved_local, local);
+        debug!("coroutine saved local {:?} => {:?}", saved_local, local);
 
         locals.push(local);
         let decl = &body.local_decls[local];
@@ -903,7 +957,7 @@ fn compute_layout<'tcx>(
             _ => false,
         };
         let decl =
-            GeneratorSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits };
+            CoroutineSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits };
         debug!(?decl);
 
         tys.push(decl);
@@ -922,9 +976,9 @@ fn compute_layout<'tcx>(
         .copied()
         .collect();
 
-    // Build the generator variant field list.
-    // Create a map from local indices to generator struct indices.
-    let mut variant_fields: IndexVec<VariantIdx, IndexVec<FieldIdx, GeneratorSavedLocal>> =
+    // Build the coroutine variant field list.
+    // Create a map from local indices to coroutine struct indices.
+    let mut variant_fields: IndexVec<VariantIdx, IndexVec<FieldIdx, CoroutineSavedLocal>> =
         iter::repeat(IndexVec::new()).take(RESERVED_VARIANTS).collect();
     let mut remap = FxHashMap::default();
     for (suspension_point_idx, live_locals) in live_locals_at_suspension_points.iter().enumerate() {
@@ -934,7 +988,7 @@ fn compute_layout<'tcx>(
             fields.push(saved_local);
             // Note that if a field is included in multiple variants, we will
             // just use the first one here. That's fine; fields do not move
-            // around inside generators, so it doesn't matter which variant
+            // around inside coroutines, so it doesn't matter which variant
             // index we access them by.
             let idx = FieldIdx::from_usize(idx);
             remap.entry(locals[saved_local]).or_insert((tys[saved_local].ty, variant_index, idx));
@@ -942,8 +996,8 @@ fn compute_layout<'tcx>(
         variant_fields.push(fields);
         variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]);
     }
-    debug!("generator variant_fields = {:?}", variant_fields);
-    debug!("generator storage_conflicts = {:#?}", storage_conflicts);
+    debug!("coroutine variant_fields = {:?}", variant_fields);
+    debug!("coroutine storage_conflicts = {:#?}", storage_conflicts);
 
     let mut field_names = IndexVec::from_elem(None, &tys);
     for var in &body.var_debug_info {
@@ -955,7 +1009,7 @@ fn compute_layout<'tcx>(
         field_names.get_or_insert_with(saved_local, || var.name);
     }
 
-    let layout = GeneratorLayout {
+    let layout = CoroutineLayout {
         field_tys: tys,
         field_names,
         variant_fields,
@@ -967,7 +1021,7 @@ fn compute_layout<'tcx>(
     (remap, layout, storage_liveness)
 }
 
-/// Replaces the entry point of `body` with a block that switches on the generator discriminant and
+/// Replaces the entry point of `body` with a block that switches on the coroutine discriminant and
 /// dispatches to blocks according to `cases`.
 ///
 /// After this function, the former entry point of the function will be bb1.
@@ -1000,14 +1054,14 @@ fn insert_switch<'tcx>(
     }
 }
 
-fn elaborate_generator_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+fn elaborate_coroutine_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
     use crate::shim::DropShimElaborator;
     use rustc_middle::mir::patch::MirPatch;
     use rustc_mir_dataflow::elaborate_drops::{elaborate_drop, Unwind};
 
-    // Note that `elaborate_drops` only drops the upvars of a generator, and
+    // Note that `elaborate_drops` only drops the upvars of a coroutine, and
     // this is ok because `open_drop` can only be reached within that own
-    // generator's resume function.
+    // coroutine's resume function.
 
     let def_id = body.source.def_id();
     let param_env = tcx.param_env(def_id);
@@ -1055,10 +1109,10 @@ fn elaborate_generator_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
     elaborator.patch.apply(body);
 }
 
-fn create_generator_drop_shim<'tcx>(
+fn create_coroutine_drop_shim<'tcx>(
     tcx: TyCtxt<'tcx>,
     transform: &TransformVisitor<'tcx>,
-    gen_ty: Ty<'tcx>,
+    coroutine_ty: Ty<'tcx>,
     body: &mut Body<'tcx>,
     drop_clean: BasicBlock,
 ) -> Body<'tcx> {
@@ -1078,7 +1132,7 @@ fn create_generator_drop_shim<'tcx>(
 
     for block in body.basic_blocks_mut() {
         let kind = &mut block.terminator_mut().kind;
-        if let TerminatorKind::GeneratorDrop = *kind {
+        if let TerminatorKind::CoroutineDrop = *kind {
             *kind = TerminatorKind::Return;
         }
     }
@@ -1086,36 +1140,27 @@ fn create_generator_drop_shim<'tcx>(
     // Replace the return variable
     body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(Ty::new_unit(tcx), source_info);
 
-    make_generator_state_argument_indirect(tcx, &mut body);
+    make_coroutine_state_argument_indirect(tcx, &mut body);
 
-    // Change the generator argument from &mut to *mut
+    // Change the coroutine argument from &mut to *mut
     body.local_decls[SELF_ARG] = LocalDecl::with_source_info(
-        Ty::new_ptr(tcx, ty::TypeAndMut { ty: gen_ty, mutbl: hir::Mutability::Mut }),
+        Ty::new_ptr(tcx, ty::TypeAndMut { ty: coroutine_ty, mutbl: hir::Mutability::Mut }),
         source_info,
     );
 
     // Make sure we remove dead blocks to remove
     // unrelated code from the resume part of the function
-    simplify::remove_dead_blocks(tcx, &mut body);
+    simplify::remove_dead_blocks(&mut body);
 
     // Update the body's def to become the drop glue.
-    // This needs to be updated before the AbortUnwindingCalls pass.
-    let gen_instance = body.source.instance;
+    let coroutine_instance = body.source.instance;
     let drop_in_place = tcx.require_lang_item(LangItem::DropInPlace, None);
-    let drop_instance = InstanceDef::DropGlue(drop_in_place, Some(gen_ty));
-    body.source.instance = drop_instance;
-
-    pm::run_passes_no_validate(
-        tcx,
-        &mut body,
-        &[&abort_unwinding_calls::AbortUnwindingCalls],
-        None,
-    );
+    let drop_instance = InstanceDef::DropGlue(drop_in_place, Some(coroutine_ty));
 
-    // Temporary change MirSource to generator's instance so that dump_mir produces more sensible
+    // Temporary change MirSource to coroutine's instance so that dump_mir produces more sensible
     // filename.
-    body.source.instance = gen_instance;
-    dump_mir(tcx, false, "generator_drop", &0, &body, |_, _| Ok(()));
+    body.source.instance = coroutine_instance;
+    dump_mir(tcx, false, "coroutine_drop", &0, &body, |_, _| Ok(()));
     body.source.instance = drop_instance;
 
     body
@@ -1190,7 +1235,7 @@ fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool {
             | TerminatorKind::UnwindTerminate(_)
             | TerminatorKind::Return
             | TerminatorKind::Unreachable
-            | TerminatorKind::GeneratorDrop
+            | TerminatorKind::CoroutineDrop
             | TerminatorKind::FalseEdge { .. }
             | TerminatorKind::FalseUnwind { .. } => {}
 
@@ -1199,7 +1244,7 @@ fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool {
             TerminatorKind::UnwindResume => {}
 
             TerminatorKind::Yield { .. } => {
-                unreachable!("`can_unwind` called before generator transform")
+                unreachable!("`can_unwind` called before coroutine transform")
             }
 
             // These may unwind.
@@ -1214,7 +1259,7 @@ fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool {
     false
 }
 
-fn create_generator_resume_function<'tcx>(
+fn create_coroutine_resume_function<'tcx>(
     tcx: TyCtxt<'tcx>,
     transform: TransformVisitor<'tcx>,
     body: &mut Body<'tcx>,
@@ -1222,7 +1267,7 @@ fn create_generator_resume_function<'tcx>(
 ) {
     let can_unwind = can_unwind(tcx, body);
 
-    // Poison the generator when it unwinds
+    // Poison the coroutine when it unwinds
     if can_unwind {
         let source_info = SourceInfo::outermost(body.span);
         let poison_block = body.basic_blocks_mut().push(BasicBlockData {
@@ -1261,34 +1306,37 @@ fn create_generator_resume_function<'tcx>(
     cases.insert(0, (UNRESUMED, START_BLOCK));
 
     // Panic when resumed on the returned or poisoned state
-    let generator_kind = body.generator_kind().unwrap();
+    let coroutine_kind = body.coroutine_kind().unwrap();
 
     if can_unwind {
         cases.insert(
             1,
-            (POISONED, insert_panic_block(tcx, body, ResumedAfterPanic(generator_kind))),
+            (POISONED, insert_panic_block(tcx, body, ResumedAfterPanic(coroutine_kind))),
         );
     }
 
     if can_return {
-        cases.insert(
-            1,
-            (RETURNED, insert_panic_block(tcx, body, ResumedAfterReturn(generator_kind))),
-        );
+        let block = match coroutine_kind {
+            CoroutineKind::Async(_) | CoroutineKind::Coroutine => {
+                insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind))
+            }
+            CoroutineKind::Gen(_) => transform.insert_none_ret_block(body),
+        };
+        cases.insert(1, (RETURNED, block));
     }
 
     insert_switch(body, cases, &transform, TerminatorKind::Unreachable);
 
-    make_generator_state_argument_indirect(tcx, body);
-    make_generator_state_argument_pinned(tcx, body);
+    make_coroutine_state_argument_indirect(tcx, body);
+    make_coroutine_state_argument_pinned(tcx, body);
 
     // Make sure we remove dead blocks to remove
     // unrelated code from the drop part of the function
-    simplify::remove_dead_blocks(tcx, body);
+    simplify::remove_dead_blocks(body);
 
     pm::run_passes_no_validate(tcx, body, &[&abort_unwinding_calls::AbortUnwindingCalls], None);
 
-    dump_mir(tcx, false, "generator_resume", &0, body, |_, _| Ok(()));
+    dump_mir(tcx, false, "coroutine_resume", &0, body, |_, _| Ok(()));
 }
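The `insert_none_ret_block` branch above is what gives `gen`-block coroutines their fused, iterator-like completion behavior: resuming after `RETURNED` produces `None` instead of panicking. Plain iterators already behave this way, which makes a stable illustration possible (gen blocks themselves are nightly-only):

    fn main() {
        let mut it = [1, 2].into_iter();
        assert_eq!(it.next(), Some(1));
        assert_eq!(it.next(), Some(2));
        assert_eq!(it.next(), None); // completed: None, not a panic
        assert_eq!(it.next(), None); // resumed again after completion: still None
    }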
#[derive(PartialEq, Copy, Clone)] enum Operation { Resume, @@ -1389,21 +1437,21 @@ fn create_cases<'tcx>( } #[instrument(level = "debug", skip(tcx), ret)] -pub(crate) fn mir_generator_witnesses<'tcx>( +pub(crate) fn mir_coroutine_witnesses<'tcx>( tcx: TyCtxt<'tcx>, def_id: LocalDefId, -) -> Option<GeneratorLayout<'tcx>> { +) -> Option<CoroutineLayout<'tcx>> { let (body, _) = tcx.mir_promoted(def_id); let body = body.borrow(); let body = &*body; - // The first argument is the generator type passed by value - let gen_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; + // The first argument is the coroutine type passed by value + let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; - let movable = match *gen_ty.kind() { - ty::Generator(_, _, movability) => movability == hir::Movability::Movable, + let movable = match *coroutine_ty.kind() { + ty::Coroutine(_, _, movability) => movability == hir::Movability::Movable, ty::Error(_) => return None, - _ => span_bug!(body.span, "unexpected generator type {}", gen_ty), + _ => span_bug!(body.span, "unexpected coroutine type {}", coroutine_ty), }; // The witness simply contains all locals live across suspend points. @@ -1412,52 +1460,63 @@ pub(crate) fn mir_generator_witnesses<'tcx>( let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable); // Extract locals which are live across suspension point into `layout` - // `remap` gives a mapping from local indices onto generator struct indices + // `remap` gives a mapping from local indices onto coroutine struct indices // `storage_liveness` tells us which locals have live storage at suspension points - let (_, generator_layout, _) = compute_layout(liveness_info, body); + let (_, coroutine_layout, _) = compute_layout(liveness_info, body); - check_suspend_tys(tcx, &generator_layout, &body); + check_suspend_tys(tcx, &coroutine_layout, &body); - Some(generator_layout) + Some(coroutine_layout) } impl<'tcx> MirPass<'tcx> for StateTransform { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let Some(yield_ty) = body.yield_ty() else { - // This only applies to generators + // This only applies to coroutines return; }; - assert!(body.generator_drop().is_none()); + assert!(body.coroutine_drop().is_none()); - // The first argument is the generator type passed by value - let gen_ty = body.local_decls.raw[1].ty; + // The first argument is the coroutine type passed by value + let coroutine_ty = body.local_decls.raw[1].ty; // Get the discriminant type and args which typeck computed - let (discr_ty, movable) = match *gen_ty.kind() { - ty::Generator(_, args, movability) => { - let args = args.as_generator(); + let (discr_ty, movable) = match *coroutine_ty.kind() { + ty::Coroutine(_, args, movability) => { + let args = args.as_coroutine(); (args.discr_ty(tcx), movability == hir::Movability::Movable) } _ => { - tcx.sess.delay_span_bug(body.span, format!("unexpected generator type {gen_ty}")); + tcx.sess + .delay_span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}")); return; } }; - let is_async_kind = matches!(body.generator_kind(), Some(GeneratorKind::Async(_))); - let (state_adt_ref, state_args) = if is_async_kind { - // Compute Poll<return_ty> - let poll_did = tcx.require_lang_item(LangItem::Poll, None); - let poll_adt_ref = tcx.adt_def(poll_did); - let poll_args = tcx.mk_args(&[body.return_ty().into()]); - (poll_adt_ref, poll_args) - } else { - // Compute GeneratorState<yield_ty, return_ty> - let state_did = 
tcx.require_lang_item(LangItem::GeneratorState, None); - let state_adt_ref = tcx.adt_def(state_did); - let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]); - (state_adt_ref, state_args) + let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_))); + let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() { + CoroutineKind::Async(_) => { + // Compute Poll<return_ty> + let poll_did = tcx.require_lang_item(LangItem::Poll, None); + let poll_adt_ref = tcx.adt_def(poll_did); + let poll_args = tcx.mk_args(&[body.return_ty().into()]); + (poll_adt_ref, poll_args) + } + CoroutineKind::Gen(_) => { + // Compute Option<yield_ty> + let option_did = tcx.require_lang_item(LangItem::Option, None); + let option_adt_ref = tcx.adt_def(option_did); + let option_args = tcx.mk_args(&[body.yield_ty().unwrap().into()]); + (option_adt_ref, option_args) + } + CoroutineKind::Coroutine => { + // Compute CoroutineState<yield_ty, return_ty> + let state_did = tcx.require_lang_item(LangItem::CoroutineState, None); + let state_adt_ref = tcx.adt_def(state_did); + let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]); + (state_adt_ref, state_args) + } }; let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args); @@ -1472,8 +1531,8 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // We also replace the resume argument and insert an `Assign`. // This is needed because the resume argument `_2` might be live across a `yield`, in which - // case there is no `Assign` to it that the transform can turn into a store to the generator - // state. After the yield the slot in the generator state would then be uninitialized. + // case there is no `Assign` to it that the transform can turn into a store to the coroutine + // state. After the yield the slot in the coroutine state would then be uninitialized. let resume_local = Local::new(2); let resume_ty = if is_async_kind { Ty::new_task_context(tcx) @@ -1482,7 +1541,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { }; let new_resume_local = replace_local(resume_local, resume_ty, body, tcx); - // When first entering the generator, move the resume argument into its new local. + // When first entering the coroutine, move the resume argument into its new local. 
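// A sketch of the three return shapes selected above, using stand-ins with
// the same shapes as the `Poll`, `Option`, and `CoroutineState` lang items;
// only the types are modelled here, not the MIR rewriting itself.
use std::task::Poll;

// Illustrative stand-in for the `CoroutineState` lang item.
enum CoroutineStateModel<Y, R> {
    Yielded(Y),
    Complete(R),
}

fn async_step(ready: bool) -> Poll<i32> {
    // `Poll::Ready(r)` on return, `Poll::Pending` on yield.
    if ready { Poll::Ready(42) } else { Poll::Pending }
}

fn gen_step(yielded: Option<i32>) -> Option<i32> {
    // `Some(y)` at a yield, `None` once the body has returned.
    yielded
}

fn coroutine_step(done: bool) -> CoroutineStateModel<i32, ()> {
    if done { CoroutineStateModel::Complete(()) } else { CoroutineStateModel::Yielded(7) }
}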
let source_info = SourceInfo::outermost(body.span); let stmts = &mut body.basic_blocks_mut()[START_BLOCK].statements; stmts.insert( @@ -1502,7 +1561,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { locals_live_across_suspend_points(tcx, body, &always_live_locals, movable); if tcx.sess.opts.unstable_opts.validate_mir { - let mut vis = EnsureGeneratorFieldAssignmentsNeverAlias { + let mut vis = EnsureCoroutineFieldAssignmentsNeverAlias { assigned_local: None, saved_locals: &liveness_info.saved_locals, storage_conflicts: &liveness_info.storage_conflicts, @@ -1512,20 +1571,20 @@ impl<'tcx> MirPass<'tcx> for StateTransform { } // Extract locals which are live across suspension point into `layout` - // `remap` gives a mapping from local indices onto generator struct indices + // `remap` gives a mapping from local indices onto coroutine struct indices // `storage_liveness` tells us which locals have live storage at suspension points let (remap, layout, storage_liveness) = compute_layout(liveness_info, body); let can_return = can_return(tcx, body, tcx.param_env(body.source.def_id())); - // Run the transformation which converts Places from Local to generator struct + // Run the transformation which converts Places from Local to coroutine struct // accesses for locals in `remap`. - // It also rewrites `return x` and `yield y` as writing a new generator state and returning - // either GeneratorState::Complete(x) and GeneratorState::Yielded(y), + // It also rewrites `return x` and `yield y` as writing a new coroutine state and returning + // either CoroutineState::Complete(x) and CoroutineState::Yielded(y), // or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`. let mut transform = TransformVisitor { tcx, - is_async_kind, + coroutine_kind: body.coroutine_kind().unwrap(), state_adt_ref, state_args, remap, @@ -1548,30 +1607,30 @@ impl<'tcx> MirPass<'tcx> for StateTransform { var.argument_index = None; } - body.generator.as_mut().unwrap().yield_ty = None; - body.generator.as_mut().unwrap().generator_layout = Some(layout); + body.coroutine.as_mut().unwrap().yield_ty = None; + body.coroutine.as_mut().unwrap().coroutine_layout = Some(layout); - // Insert `drop(generator_struct)` which is used to drop upvars for generators in + // Insert `drop(coroutine_struct)` which is used to drop upvars for coroutines in // the unresumed state. - // This is expanded to a drop ladder in `elaborate_generator_drops`. + // This is expanded to a drop ladder in `elaborate_coroutine_drops`. let drop_clean = insert_clean_drop(body); - dump_mir(tcx, false, "generator_pre-elab", &0, body, |_, _| Ok(())); + dump_mir(tcx, false, "coroutine_pre-elab", &0, body, |_, _| Ok(())); - // Expand `drop(generator_struct)` to a drop ladder which destroys upvars. + // Expand `drop(coroutine_struct)` to a drop ladder which destroys upvars. // If any upvars are moved out of, drop elaboration will handle upvar destruction. // However we need to also elaborate the code generated by `insert_clean_drop`. 
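// Schematic of the `remap` step described above, assuming a toy coroutine
// whose locals `_3: u64` and `_5: bool` are live across a yield; the struct
// and field names are illustrative. Uses of those locals become field
// projections on the coroutine state, so their values survive suspension.
struct ToyCoroutineState {
    discriminant: u32,
    saved_3: u64,  // was local `_3`
    saved_5: bool, // was local `_5`
}

fn rewritten_use(state: &mut ToyCoroutineState) -> u64 {
    state.saved_5 = true; // was: `_5 = true;`
    state.saved_3 // was the read side of `_4 = _3;`
}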
- elaborate_generator_drops(tcx, body); + elaborate_coroutine_drops(tcx, body); - dump_mir(tcx, false, "generator_post-transform", &0, body, |_, _| Ok(())); + dump_mir(tcx, false, "coroutine_post-transform", &0, body, |_, _| Ok(())); - // Create a copy of our MIR and use it to create the drop shim for the generator - let drop_shim = create_generator_drop_shim(tcx, &transform, gen_ty, body, drop_clean); + // Create a copy of our MIR and use it to create the drop shim for the coroutine + let drop_shim = create_coroutine_drop_shim(tcx, &transform, coroutine_ty, body, drop_clean); - body.generator.as_mut().unwrap().generator_drop = Some(drop_shim); + body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim); - // Create the Generator::resume / Future::poll function - create_generator_resume_function(tcx, transform, body, can_return); + // Create the Coroutine::resume / Future::poll function + create_coroutine_resume_function(tcx, transform, body, can_return); // Run derefer to fix Derefs that are not in the first place deref_finder(tcx, body); @@ -1579,25 +1638,25 @@ impl<'tcx> MirPass<'tcx> for StateTransform { } /// Looks for any assignments between locals (e.g., `_4 = _5`) that will both be converted to fields -/// in the generator state machine but whose storage is not marked as conflicting +/// in the coroutine state machine but whose storage is not marked as conflicting /// /// Validation needs to happen immediately *before* `TransformVisitor` is invoked, not after. /// /// This condition would arise when the assignment is the last use of `_5` but the initial /// definition of `_4` if we weren't extra careful to mark all locals used inside a statement as -/// conflicting. Non-conflicting generator saved locals may be stored at the same location within -/// the generator state machine, which would result in ill-formed MIR: the left-hand and right-hand +/// conflicting. Non-conflicting coroutine saved locals may be stored at the same location within +/// the coroutine state machine, which would result in ill-formed MIR: the left-hand and right-hand /// sides of an assignment may not alias. This caused a miscompilation in [#73137]. 
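// The invariant guarded by the validator below, in miniature: a MIR
// assignment `lhs = rhs` assumes the two places do not overlap, much like
// the contract of `ptr::copy_nonoverlapping`. If two saved locals were
// wrongly given the same slot in the layout, the lowered assignment would
// become an overlapping copy. A hedged sketch of the analogous contract:
use std::ptr;

unsafe fn model_mir_assign(dst: *mut u64, src: *const u64) {
    // Sound only while `dst` and `src` are disjoint, which is exactly what
    // the storage-conflict check asserts for coroutine saved locals.
    ptr::copy_nonoverlapping(src, dst, 1);
}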
/// /// [#73137]: https://github.com/rust-lang/rust/issues/73137 -struct EnsureGeneratorFieldAssignmentsNeverAlias<'a> { - saved_locals: &'a GeneratorSavedLocals, - storage_conflicts: &'a BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal>, - assigned_local: Option<GeneratorSavedLocal>, +struct EnsureCoroutineFieldAssignmentsNeverAlias<'a> { + saved_locals: &'a CoroutineSavedLocals, + storage_conflicts: &'a BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal>, + assigned_local: Option<CoroutineSavedLocal>, } -impl EnsureGeneratorFieldAssignmentsNeverAlias<'_> { - fn saved_local_for_direct_place(&self, place: Place<'_>) -> Option<GeneratorSavedLocal> { +impl EnsureCoroutineFieldAssignmentsNeverAlias<'_> { + fn saved_local_for_direct_place(&self, place: Place<'_>) -> Option<CoroutineSavedLocal> { if place.is_indirect() { return None; } @@ -1616,7 +1675,7 @@ impl EnsureGeneratorFieldAssignmentsNeverAlias<'_> { } } -impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> { +impl<'tcx> Visitor<'tcx> for EnsureCoroutineFieldAssignmentsNeverAlias<'_> { fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) { let Some(lhs) = self.assigned_local else { // This visitor only invokes `visit_place` for the right-hand side of an assignment @@ -1631,7 +1690,7 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> { if !self.storage_conflicts.contains(lhs, rhs) { bug!( - "Assignment between generator saved locals whose storage is not \ + "Assignment between coroutine saved locals whose storage is not \ marked as conflicting: {:?}: {:?} = {:?}", location, lhs, @@ -1698,14 +1757,14 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> { | TerminatorKind::Unreachable | TerminatorKind::Drop { .. } | TerminatorKind::Assert { .. } - | TerminatorKind::GeneratorDrop + | TerminatorKind::CoroutineDrop | TerminatorKind::FalseEdge { .. } | TerminatorKind::FalseUnwind { .. } => {} } } } -fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &GeneratorLayout<'tcx>, body: &Body<'tcx>) { +fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &CoroutineLayout<'tcx>, body: &Body<'tcx>) { let mut linted_tys = FxHashSet::default(); // We want a user-facing param-env. diff --git a/compiler/rustc_mir_transform/src/cost_checker.rs b/compiler/rustc_mir_transform/src/cost_checker.rs new file mode 100644 index 000000000..9bb26693c --- /dev/null +++ b/compiler/rustc_mir_transform/src/cost_checker.rs @@ -0,0 +1,98 @@ +use rustc_middle::mir::visit::*; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt}; + +const INSTR_COST: usize = 5; +const CALL_PENALTY: usize = 25; +const LANDINGPAD_PENALTY: usize = 50; +const RESUME_PENALTY: usize = 45; + +/// Verify that the callee body is compatible with the caller. 
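// A worked example of the cost constants above, assuming a hypothetical
// callee with three ordinary statements, one potentially-unwinding call,
// and one drop of a type that needs dropping:
const EXAMPLE_INLINE_COST: usize = 3 * 5 // three statements at INSTR_COST
    + (25 + 50)  // Call: CALL_PENALTY plus LANDINGPAD_PENALTY for its cleanup
    + (25 + 50); // Drop: CALL_PENALTY plus LANDINGPAD_PENALTY for its cleanup
// EXAMPLE_INLINE_COST == 165, the figure an inlining threshold would compare against.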
+#[derive(Clone)] +pub(crate) struct CostChecker<'b, 'tcx> { + tcx: TyCtxt<'tcx>, + param_env: ParamEnv<'tcx>, + cost: usize, + callee_body: &'b Body<'tcx>, + instance: Option<ty::Instance<'tcx>>, +} + +impl<'b, 'tcx> CostChecker<'b, 'tcx> { + pub fn new( + tcx: TyCtxt<'tcx>, + param_env: ParamEnv<'tcx>, + instance: Option<ty::Instance<'tcx>>, + callee_body: &'b Body<'tcx>, + ) -> CostChecker<'b, 'tcx> { + CostChecker { tcx, param_env, callee_body, instance, cost: 0 } + } + + pub fn cost(&self) -> usize { + self.cost + } + + fn instantiate_ty(&self, v: Ty<'tcx>) -> Ty<'tcx> { + if let Some(instance) = self.instance { + instance.instantiate_mir(self.tcx, ty::EarlyBinder::bind(&v)) + } else { + v + } + } +} + +impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { + fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) { + // Don't count StorageLive/StorageDead in the inlining cost. + match statement.kind { + StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Deinit(_) + | StatementKind::Nop => {} + _ => self.cost += INSTR_COST, + } + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, _: Location) { + let tcx = self.tcx; + match terminator.kind { + TerminatorKind::Drop { ref place, unwind, .. } => { + // If the place doesn't actually need dropping, treat it like a regular goto. + let ty = self.instantiate_ty(place.ty(self.callee_body, tcx).ty); + if ty.needs_drop(tcx, self.param_env) { + self.cost += CALL_PENALTY; + if let UnwindAction::Cleanup(_) = unwind { + self.cost += LANDINGPAD_PENALTY; + } + } else { + self.cost += INSTR_COST; + } + } + TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. } => { + let fn_ty = self.instantiate_ty(f.const_.ty()); + self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() && tcx.is_intrinsic(def_id) { + // Don't give intrinsics the extra penalty for calls + INSTR_COST + } else { + CALL_PENALTY + }; + if let UnwindAction::Cleanup(_) = unwind { + self.cost += LANDINGPAD_PENALTY; + } + } + TerminatorKind::Assert { unwind, .. } => { + self.cost += CALL_PENALTY; + if let UnwindAction::Cleanup(_) = unwind { + self.cost += LANDINGPAD_PENALTY; + } + } + TerminatorKind::UnwindResume => self.cost += RESUME_PENALTY, + TerminatorKind::InlineAsm { unwind, .. } => { + self.cost += INSTR_COST; + if let UnwindAction::Cleanup(_) = unwind { + self.cost += LANDINGPAD_PENALTY; + } + } + _ => self.cost += INSTR_COST, + } + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/counters.rs b/compiler/rustc_mir_transform/src/coverage/counters.rs index d56d4ad4f..b34ec95b4 100644 --- a/compiler/rustc_mir_transform/src/coverage/counters.rs +++ b/compiler/rustc_mir_transform/src/coverage/counters.rs @@ -1,10 +1,6 @@ -use super::Error; - use super::graph; -use super::spans; use graph::{BasicCoverageBlock, BcbBranch, CoverageGraph, TraverseCoverageGraphWithLoops}; -use spans::CoverageSpan; use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::graph::WithNumNodes; @@ -14,14 +10,12 @@ use rustc_middle::mir::coverage::*; use std::fmt::{self, Debug}; -const NESTED_INDENT: &str = " "; - /// The coverage counter or counter expression associated with a particular /// BCB node or BCB edge. #[derive(Clone)] pub(super) enum BcbCounter { Counter { id: CounterId }, - Expression { id: ExpressionId, lhs: Operand, op: Op, rhs: Operand }, + Expression { id: ExpressionId }, } impl BcbCounter { @@ -29,10 +23,10 @@ impl BcbCounter { matches!(self, Self::Expression { .. 
}) } - pub(super) fn as_operand(&self) -> Operand { + pub(super) fn as_term(&self) -> CovTerm { match *self { - BcbCounter::Counter { id, .. } => Operand::Counter(id), - BcbCounter::Expression { id, .. } => Operand::Expression(id), + BcbCounter::Counter { id, .. } => CovTerm::Counter(id), + BcbCounter::Expression { id, .. } => CovTerm::Expression(id), } } } @@ -41,17 +35,7 @@ impl Debug for BcbCounter { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Counter { id, .. } => write!(fmt, "Counter({:?})", id.index()), - Self::Expression { id, lhs, op, rhs } => write!( - fmt, - "Expression({:?}) = {:?} {} {:?}", - id.index(), - lhs, - match op { - Op::Add => "+", - Op::Subtract => "-", - }, - rhs, - ), + Self::Expression { id } => write!(fmt, "Expression({:?})", id.index()), } } } @@ -60,7 +44,6 @@ impl Debug for BcbCounter { /// associated with nodes/edges in the BCB graph. pub(super) struct CoverageCounters { next_counter_id: CounterId, - next_expression_id: ExpressionId, /// Coverage counters/expressions that are associated with individual BCBs. bcb_counters: IndexVec<BasicCoverageBlock, Option<BcbCounter>>, @@ -68,13 +51,12 @@ pub(super) struct CoverageCounters { /// edge between two BCBs. bcb_edge_counters: FxHashMap<(BasicCoverageBlock, BasicCoverageBlock), BcbCounter>, /// Tracks which BCBs have a counter associated with some incoming edge. - /// Only used by debug assertions, to verify that BCBs with incoming edge + /// Only used by assertions, to verify that BCBs with incoming edge /// counters do not have their own physical counters (expressions are allowed). bcb_has_incoming_edge_counters: BitSet<BasicCoverageBlock>, - /// Expression nodes that are not directly associated with any particular - /// BCB/edge, but are needed as operands to more complex expressions. - /// These are always [`BcbCounter::Expression`]. - pub(super) intermediate_expressions: Vec<BcbCounter>, + /// Table of expression data, associating each expression ID with its + /// corresponding operator (+ or -) and its LHS/RHS operands. + expressions: IndexVec<ExpressionId, Expression>, } impl CoverageCounters { @@ -83,24 +65,22 @@ impl CoverageCounters { Self { next_counter_id: CounterId::START, - next_expression_id: ExpressionId::START, - bcb_counters: IndexVec::from_elem_n(None, num_bcbs), bcb_edge_counters: FxHashMap::default(), bcb_has_incoming_edge_counters: BitSet::new_empty(num_bcbs), - intermediate_expressions: Vec::new(), + expressions: IndexVec::new(), } } /// Makes [`BcbCounter`] `Counter`s and `Expressions` for the `BasicCoverageBlock`s directly or - /// indirectly associated with `CoverageSpans`, and accumulates additional `Expression`s + /// indirectly associated with coverage spans, and accumulates additional `Expression`s /// representing intermediate values. 
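// A minimal model of the scheme described above: expression operands live
// once in an id-indexed table and are referenced through lightweight terms,
// instead of being carried inline by every `BcbCounter`. All types here are
// illustrative stand-ins for `CovTerm`, `Op`, and `Expression`.
#[derive(Clone, Copy)]
enum Term {
    Counter(u32),
    Expression(u32),
}

#[derive(Clone, Copy)]
enum ExprOp {
    Add,
    Subtract,
}

struct Expr {
    lhs: Term,
    op: ExprOp,
    rhs: Term,
}

// Mirrors `make_expression`: push the operand data, hand back an id-only term.
fn make_expression(table: &mut Vec<Expr>, lhs: Term, op: ExprOp, rhs: Term) -> Term {
    let id = table.len() as u32;
    table.push(Expr { lhs, op, rhs });
    Term::Expression(id)
}

// How a term resolves to a count once counter values exist (hypothetical
// helper; actual evaluation happens later, during codegen/reporting).
fn eval(term: Term, counters: &[u64], table: &[Expr]) -> u64 {
    match term {
        Term::Counter(id) => counters[id as usize],
        Term::Expression(id) => {
            let e = &table[id as usize];
            let (l, r) = (eval(e.lhs, counters, table), eval(e.rhs, counters, table));
            match e.op {
                ExprOp::Add => l + r,
                ExprOp::Subtract => l - r,
            }
        }
    }
}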
pub fn make_bcb_counters( &mut self, basic_coverage_blocks: &CoverageGraph, - coverage_spans: &[CoverageSpan], - ) -> Result<(), Error> { - MakeBcbCounters::new(self, basic_coverage_blocks).make_bcb_counters(coverage_spans) + bcb_has_coverage_spans: impl Fn(BasicCoverageBlock) -> bool, + ) { + MakeBcbCounters::new(self, basic_coverage_blocks).make_bcb_counters(bcb_has_coverage_spans) } fn make_counter(&mut self) -> BcbCounter { @@ -108,50 +88,44 @@ impl CoverageCounters { BcbCounter::Counter { id } } - fn make_expression(&mut self, lhs: Operand, op: Op, rhs: Operand) -> BcbCounter { - let id = self.next_expression(); - BcbCounter::Expression { id, lhs, op, rhs } - } - - pub fn make_identity_counter(&mut self, counter_operand: Operand) -> BcbCounter { - self.make_expression(counter_operand, Op::Add, Operand::Zero) + fn make_expression(&mut self, lhs: CovTerm, op: Op, rhs: CovTerm) -> BcbCounter { + let id = self.expressions.push(Expression { lhs, op, rhs }); + BcbCounter::Expression { id } } /// Counter IDs start from one and go up. fn next_counter(&mut self) -> CounterId { let next = self.next_counter_id; - self.next_counter_id = next.next_id(); + self.next_counter_id = self.next_counter_id + 1; next } - /// Expression IDs start from 0 and go up. - /// (Counter IDs and Expression IDs are distinguished by the `Operand` enum.) - fn next_expression(&mut self) -> ExpressionId { - let next = self.next_expression_id; - self.next_expression_id = next.next_id(); - next + pub(super) fn num_counters(&self) -> usize { + self.next_counter_id.as_usize() } - fn set_bcb_counter( - &mut self, - bcb: BasicCoverageBlock, - counter_kind: BcbCounter, - ) -> Result<Operand, Error> { - debug_assert!( + #[cfg(test)] + pub(super) fn num_expressions(&self) -> usize { + self.expressions.len() + } + + fn set_bcb_counter(&mut self, bcb: BasicCoverageBlock, counter_kind: BcbCounter) -> CovTerm { + assert!( // If the BCB has an edge counter (to be injected into a new `BasicBlock`), it can also // have an expression (to be injected into an existing `BasicBlock` represented by this // `BasicCoverageBlock`). counter_kind.is_expression() || !self.bcb_has_incoming_edge_counters.contains(bcb), "attempt to add a `Counter` to a BCB target with existing incoming edge counters" ); - let operand = counter_kind.as_operand(); + + let term = counter_kind.as_term(); if let Some(replaced) = self.bcb_counters[bcb].replace(counter_kind) { - Error::from_string(format!( + bug!( "attempt to set a BasicCoverageBlock coverage counter more than once; \ {bcb:?} already had counter {replaced:?}", - )) + ); } else { - Ok(operand) + term } } @@ -160,27 +134,26 @@ impl CoverageCounters { from_bcb: BasicCoverageBlock, to_bcb: BasicCoverageBlock, counter_kind: BcbCounter, - ) -> Result<Operand, Error> { - if level_enabled!(tracing::Level::DEBUG) { - // If the BCB has an edge counter (to be injected into a new `BasicBlock`), it can also - // have an expression (to be injected into an existing `BasicBlock` represented by this - // `BasicCoverageBlock`). - if self.bcb_counter(to_bcb).is_some_and(|c| !c.is_expression()) { - return Error::from_string(format!( - "attempt to add an incoming edge counter from {from_bcb:?} when the target BCB already \ - has a `Counter`" - )); - } + ) -> CovTerm { + // If the BCB has an edge counter (to be injected into a new `BasicBlock`), it can also + // have an expression (to be injected into an existing `BasicBlock` represented by this + // `BasicCoverageBlock`). 
+ if let Some(node_counter) = self.bcb_counter(to_bcb) && !node_counter.is_expression() { + bug!( + "attempt to add an incoming edge counter from {from_bcb:?} \ + when the target BCB already has {node_counter:?}" + ); } + self.bcb_has_incoming_edge_counters.insert(to_bcb); - let operand = counter_kind.as_operand(); + let term = counter_kind.as_term(); if let Some(replaced) = self.bcb_edge_counters.insert((from_bcb, to_bcb), counter_kind) { - Error::from_string(format!( + bug!( "attempt to set an edge counter more than once; from_bcb: \ {from_bcb:?} already had counter {replaced:?}", - )) + ); } else { - Ok(operand) + term } } @@ -188,27 +161,31 @@ impl CoverageCounters { self.bcb_counters[bcb].as_ref() } - pub(super) fn take_bcb_counter(&mut self, bcb: BasicCoverageBlock) -> Option<BcbCounter> { - self.bcb_counters[bcb].take() + pub(super) fn bcb_node_counters( + &self, + ) -> impl Iterator<Item = (BasicCoverageBlock, &BcbCounter)> { + self.bcb_counters + .iter_enumerated() + .filter_map(|(bcb, counter_kind)| Some((bcb, counter_kind.as_ref()?))) } - pub(super) fn drain_bcb_counters( - &mut self, - ) -> impl Iterator<Item = (BasicCoverageBlock, BcbCounter)> + '_ { - self.bcb_counters - .iter_enumerated_mut() - .filter_map(|(bcb, counter)| Some((bcb, counter.take()?))) + /// For each edge in the BCB graph that has an associated counter, yields + /// that edge's *from* and *to* nodes, and its counter. + pub(super) fn bcb_edge_counters( + &self, + ) -> impl Iterator<Item = (BasicCoverageBlock, BasicCoverageBlock, &BcbCounter)> { + self.bcb_edge_counters + .iter() + .map(|(&(from_bcb, to_bcb), counter_kind)| (from_bcb, to_bcb, counter_kind)) } - pub(super) fn drain_bcb_edge_counters( - &mut self, - ) -> impl Iterator<Item = ((BasicCoverageBlock, BasicCoverageBlock), BcbCounter)> + '_ { - self.bcb_edge_counters.drain() + pub(super) fn take_expressions(&mut self) -> IndexVec<ExpressionId, Expression> { + std::mem::take(&mut self.expressions) } } /// Traverse the `CoverageGraph` and add either a `Counter` or `Expression` to every BCB, to be -/// injected with `CoverageSpan`s. `Expressions` have no runtime overhead, so if a viable expression +/// injected with coverage spans. `Expressions` have no runtime overhead, so if a viable expression /// (adding or subtracting two other counters or expressions) can compute the same result as an /// embedded counter, an `Expression` should be used. struct MakeBcbCounters<'a> { @@ -230,21 +207,11 @@ impl<'a> MakeBcbCounters<'a> { /// One way to predict which branch executes the least is by considering loops. A loop is exited /// at a branch, so the branch that jumps to a `BasicCoverageBlock` outside the loop is almost /// always executed less than the branch that does not exit the loop. - /// - /// Returns any non-code-span expressions created to represent intermediate values (such as to - /// add two counters so the result can be subtracted from another counter), or an Error with - /// message for subsequent debugging. - fn make_bcb_counters(&mut self, coverage_spans: &[CoverageSpan]) -> Result<(), Error> { + fn make_bcb_counters(&mut self, bcb_has_coverage_spans: impl Fn(BasicCoverageBlock) -> bool) { debug!("make_bcb_counters(): adding a counter or expression to each BasicCoverageBlock"); - let num_bcbs = self.basic_coverage_blocks.num_nodes(); - - let mut bcbs_with_coverage = BitSet::new_empty(num_bcbs); - for covspan in coverage_spans { - bcbs_with_coverage.insert(covspan.bcb); - } // Walk the `CoverageGraph`. 
For each `BasicCoverageBlock` node with an associated - // `CoverageSpan`, add a counter. If the `BasicCoverageBlock` branches, add a counter or + // coverage span, add a counter. If the `BasicCoverageBlock` branches, add a counter or // expression to each branch `BasicCoverageBlock` (if the branch BCB has only one incoming // edge) or edge from the branching BCB to the branch BCB (if the branch BCB has multiple // incoming edges). @@ -254,39 +221,36 @@ impl<'a> MakeBcbCounters<'a> { // the loop. The `traversal` state includes a `context_stack`, providing a way to know if // the current BCB is in one or more nested loops or not. let mut traversal = TraverseCoverageGraphWithLoops::new(&self.basic_coverage_blocks); - while let Some(bcb) = traversal.next(self.basic_coverage_blocks) { - if bcbs_with_coverage.contains(bcb) { - debug!("{:?} has at least one `CoverageSpan`. Get or make its counter", bcb); - let branching_counter_operand = self.get_or_make_counter_operand(bcb)?; + while let Some(bcb) = traversal.next() { + if bcb_has_coverage_spans(bcb) { + debug!("{:?} has at least one coverage span. Get or make its counter", bcb); + let branching_counter_operand = self.get_or_make_counter_operand(bcb); if self.bcb_needs_branch_counters(bcb) { - self.make_branch_counters(&mut traversal, bcb, branching_counter_operand)?; + self.make_branch_counters(&traversal, bcb, branching_counter_operand); } } else { debug!( - "{:?} does not have any `CoverageSpan`s. A counter will only be added if \ + "{:?} does not have any coverage spans. A counter will only be added if \ and when a covered BCB has an expression dependency.", bcb, ); } } - if traversal.is_complete() { - Ok(()) - } else { - Error::from_string(format!( - "`TraverseCoverageGraphWithLoops` missed some `BasicCoverageBlock`s: {:?}", - traversal.unvisited(), - )) - } + assert!( + traversal.is_complete(), + "`TraverseCoverageGraphWithLoops` missed some `BasicCoverageBlock`s: {:?}", + traversal.unvisited(), + ); } fn make_branch_counters( &mut self, - traversal: &mut TraverseCoverageGraphWithLoops, + traversal: &TraverseCoverageGraphWithLoops<'_>, branching_bcb: BasicCoverageBlock, - branching_counter_operand: Operand, - ) -> Result<(), Error> { + branching_counter_operand: CovTerm, + ) { let branches = self.bcb_branches(branching_bcb); debug!( "{:?} has some branch(es) without counters:\n {}", @@ -319,10 +283,10 @@ impl<'a> MakeBcbCounters<'a> { counter", branch, branching_bcb ); - self.get_or_make_counter_operand(branch.target_bcb)? + self.get_or_make_counter_operand(branch.target_bcb) } else { debug!(" {:?} has multiple incoming edges, so adding an edge counter", branch); - self.get_or_make_edge_counter_operand(branching_bcb, branch.target_bcb)? 
+ self.get_or_make_edge_counter_operand(branching_bcb, branch.target_bcb) }; if let Some(sumup_counter_operand) = some_sumup_counter_operand.replace(branch_counter_operand) @@ -333,8 +297,7 @@ impl<'a> MakeBcbCounters<'a> { sumup_counter_operand, ); debug!(" [new intermediate expression: {:?}]", intermediate_expression); - let intermediate_expression_operand = intermediate_expression.as_operand(); - self.coverage_counters.intermediate_expressions.push(intermediate_expression); + let intermediate_expression_operand = intermediate_expression.as_term(); some_sumup_counter_operand.replace(intermediate_expression_operand); } } @@ -358,31 +321,18 @@ impl<'a> MakeBcbCounters<'a> { debug!("{:?} gets an expression: {:?}", expression_branch, expression); let bcb = expression_branch.target_bcb; if expression_branch.is_only_path_to_target() { - self.coverage_counters.set_bcb_counter(bcb, expression)?; + self.coverage_counters.set_bcb_counter(bcb, expression); } else { - self.coverage_counters.set_bcb_edge_counter(branching_bcb, bcb, expression)?; + self.coverage_counters.set_bcb_edge_counter(branching_bcb, bcb, expression); } - Ok(()) - } - - fn get_or_make_counter_operand(&mut self, bcb: BasicCoverageBlock) -> Result<Operand, Error> { - self.recursive_get_or_make_counter_operand(bcb, 1) } - fn recursive_get_or_make_counter_operand( - &mut self, - bcb: BasicCoverageBlock, - debug_indent_level: usize, - ) -> Result<Operand, Error> { + #[instrument(level = "debug", skip(self))] + fn get_or_make_counter_operand(&mut self, bcb: BasicCoverageBlock) -> CovTerm { // If the BCB already has a counter, return it. if let Some(counter_kind) = &self.coverage_counters.bcb_counters[bcb] { - debug!( - "{}{:?} already has a counter: {:?}", - NESTED_INDENT.repeat(debug_indent_level), - bcb, - counter_kind, - ); - return Ok(counter_kind.as_operand()); + debug!("{bcb:?} already has a counter: {counter_kind:?}"); + return counter_kind.as_term(); } // A BCB with only one incoming edge gets a simple `Counter` (via `make_counter()`). @@ -392,20 +342,12 @@ impl<'a> MakeBcbCounters<'a> { if one_path_to_target || self.bcb_predecessors(bcb).contains(&bcb) { let counter_kind = self.coverage_counters.make_counter(); if one_path_to_target { - debug!( - "{}{:?} gets a new counter: {:?}", - NESTED_INDENT.repeat(debug_indent_level), - bcb, - counter_kind, - ); + debug!("{bcb:?} gets a new counter: {counter_kind:?}"); } else { debug!( - "{}{:?} has itself as its own predecessor. It can't be part of its own \ - Expression sum, so it will get its own new counter: {:?}. (Note, the compiled \ - code will generate an infinite loop.)", - NESTED_INDENT.repeat(debug_indent_level), - bcb, - counter_kind, + "{bcb:?} has itself as its own predecessor. It can't be part of its own \ + Expression sum, so it will get its own new counter: {counter_kind:?}. \ + (Note, the compiled code will generate an infinite loop.)", ); } return self.coverage_counters.set_bcb_counter(bcb, counter_kind); @@ -415,24 +357,14 @@ impl<'a> MakeBcbCounters<'a> { // counters and/or expressions of its incoming edges. This will recursively get or create // counters for those incoming edges first, then call `make_expression()` to sum them up, // with additional intermediate expressions as needed. 
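// The arithmetic behind this sum-up, as a sketch: a branching BCB executed
// `parent` times can give one branch an expression rather than a physical
// counter, because that branch's count is fully determined by
//     expression_branch = parent - (sum of the other branch counts).
// A hypothetical helper with concrete numbers:
fn remaining_branch_count(parent: u64, counted_branches: &[u64]) -> u64 {
    parent - counted_branches.iter().sum::<u64>()
}
// e.g. remaining_branch_count(10, &[3, 4]) == 3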
+ let _sumup_debug_span = debug_span!("(preparing sum-up expression)").entered(); + let mut predecessors = self.bcb_predecessors(bcb).to_owned().into_iter(); - debug!( - "{}{:?} has multiple incoming edges and will get an expression that sums them up...", - NESTED_INDENT.repeat(debug_indent_level), - bcb, - ); - let first_edge_counter_operand = self.recursive_get_or_make_edge_counter_operand( - predecessors.next().unwrap(), - bcb, - debug_indent_level + 1, - )?; + let first_edge_counter_operand = + self.get_or_make_edge_counter_operand(predecessors.next().unwrap(), bcb); let mut some_sumup_edge_counter_operand = None; for predecessor in predecessors { - let edge_counter_operand = self.recursive_get_or_make_edge_counter_operand( - predecessor, - bcb, - debug_indent_level + 1, - )?; + let edge_counter_operand = self.get_or_make_edge_counter_operand(predecessor, bcb); if let Some(sumup_edge_counter_operand) = some_sumup_edge_counter_operand.replace(edge_counter_operand) { @@ -441,13 +373,8 @@ impl<'a> MakeBcbCounters<'a> { Op::Add, edge_counter_operand, ); - debug!( - "{}new intermediate expression: {:?}", - NESTED_INDENT.repeat(debug_indent_level), - intermediate_expression - ); - let intermediate_expression_operand = intermediate_expression.as_operand(); - self.coverage_counters.intermediate_expressions.push(intermediate_expression); + debug!("new intermediate expression: {intermediate_expression:?}"); + let intermediate_expression_operand = intermediate_expression.as_term(); some_sumup_edge_counter_operand.replace(intermediate_expression_operand); } } @@ -456,59 +383,36 @@ impl<'a> MakeBcbCounters<'a> { Op::Add, some_sumup_edge_counter_operand.unwrap(), ); - debug!( - "{}{:?} gets a new counter (sum of predecessor counters): {:?}", - NESTED_INDENT.repeat(debug_indent_level), - bcb, - counter_kind - ); + drop(_sumup_debug_span); + + debug!("{bcb:?} gets a new counter (sum of predecessor counters): {counter_kind:?}"); self.coverage_counters.set_bcb_counter(bcb, counter_kind) } + #[instrument(level = "debug", skip(self))] fn get_or_make_edge_counter_operand( &mut self, from_bcb: BasicCoverageBlock, to_bcb: BasicCoverageBlock, - ) -> Result<Operand, Error> { - self.recursive_get_or_make_edge_counter_operand(from_bcb, to_bcb, 1) - } - - fn recursive_get_or_make_edge_counter_operand( - &mut self, - from_bcb: BasicCoverageBlock, - to_bcb: BasicCoverageBlock, - debug_indent_level: usize, - ) -> Result<Operand, Error> { + ) -> CovTerm { // If the source BCB has only one successor (assumed to be the given target), an edge // counter is unnecessary. Just get or make a counter for the source BCB. let successors = self.bcb_successors(from_bcb).iter(); if successors.len() == 1 { - return self.recursive_get_or_make_counter_operand(from_bcb, debug_indent_level + 1); + return self.get_or_make_counter_operand(from_bcb); } // If the edge already has a counter, return it. if let Some(counter_kind) = self.coverage_counters.bcb_edge_counters.get(&(from_bcb, to_bcb)) { - debug!( - "{}Edge {:?}->{:?} already has a counter: {:?}", - NESTED_INDENT.repeat(debug_indent_level), - from_bcb, - to_bcb, - counter_kind - ); - return Ok(counter_kind.as_operand()); + debug!("Edge {from_bcb:?}->{to_bcb:?} already has a counter: {counter_kind:?}"); + return counter_kind.as_term(); } // Make a new counter to count this edge. 
let counter_kind = self.coverage_counters.make_counter(); - debug!( - "{}Edge {:?}->{:?} gets a new counter: {:?}", - NESTED_INDENT.repeat(debug_indent_level), - from_bcb, - to_bcb, - counter_kind - ); + debug!("Edge {from_bcb:?}->{to_bcb:?} gets a new counter: {counter_kind:?}"); self.coverage_counters.set_bcb_edge_counter(from_bcb, to_bcb, counter_kind) } @@ -516,21 +420,14 @@ impl<'a> MakeBcbCounters<'a> { /// found, select any branch. fn choose_preferred_expression_branch( &self, - traversal: &TraverseCoverageGraphWithLoops, + traversal: &TraverseCoverageGraphWithLoops<'_>, branches: &[BcbBranch], ) -> BcbBranch { - let branch_needs_a_counter = |branch: &BcbBranch| self.branch_has_no_counter(branch); - - let some_reloop_branch = self.find_some_reloop_branch(traversal, &branches); - if let Some(reloop_branch_without_counter) = - some_reloop_branch.filter(branch_needs_a_counter) - { - debug!( - "Selecting reloop_branch={:?} that still needs a counter, to get the \ - `Expression`", - reloop_branch_without_counter - ); - reloop_branch_without_counter + let good_reloop_branch = self.find_good_reloop_branch(traversal, &branches); + if let Some(reloop_branch) = good_reloop_branch { + assert!(self.branch_has_no_counter(&reloop_branch)); + debug!("Selecting reloop branch {reloop_branch:?} to get an expression"); + reloop_branch } else { let &branch_without_counter = branches.iter().find(|&branch| self.branch_has_no_counter(branch)).expect( @@ -547,75 +444,52 @@ impl<'a> MakeBcbCounters<'a> { } } - /// At most, one of the branches (or its edge, from the branching_bcb, if the branch has - /// multiple incoming edges) can have a counter computed by expression. - /// - /// If at least one of the branches leads outside of a loop (`found_loop_exit` is - /// true), and at least one other branch does not exit the loop (the first of which - /// is captured in `some_reloop_branch`), it's likely any reloop branch will be - /// executed far more often than loop exit branch, making the reloop branch a better - /// candidate for an expression. - fn find_some_reloop_branch( + /// Tries to find a branch that leads back to the top of a loop, and that + /// doesn't already have a counter. Such branches are good candidates to + /// be given an expression (instead of a physical counter), because they + /// will tend to be executed more times than a loop-exit branch. + fn find_good_reloop_branch( &self, - traversal: &TraverseCoverageGraphWithLoops, + traversal: &TraverseCoverageGraphWithLoops<'_>, branches: &[BcbBranch], ) -> Option<BcbBranch> { - let branch_needs_a_counter = |branch: &BcbBranch| self.branch_has_no_counter(branch); - - let mut some_reloop_branch: Option<BcbBranch> = None; - for context in traversal.context_stack.iter().rev() { - if let Some((backedge_from_bcbs, _)) = &context.loop_backedges { - let mut found_loop_exit = false; - for &branch in branches.iter() { - if backedge_from_bcbs.iter().any(|&backedge_from_bcb| { - self.bcb_dominates(branch.target_bcb, backedge_from_bcb) - }) { - if let Some(reloop_branch) = some_reloop_branch { - if self.branch_has_no_counter(&reloop_branch) { - // we already found a candidate reloop_branch that still - // needs a counter - continue; - } - } - // The path from branch leads back to the top of the loop. Set this - // branch as the `reloop_branch`. If this branch already has a - // counter, and we find another reloop branch that doesn't have a - // counter yet, that branch will be selected as the `reloop_branch` - // instead. 
- some_reloop_branch = Some(branch); - } else { - // The path from branch leads outside this loop - found_loop_exit = true; - } - if found_loop_exit - && some_reloop_branch.filter(branch_needs_a_counter).is_some() - { - // Found both a branch that exits the loop and a branch that returns - // to the top of the loop (`reloop_branch`), and the `reloop_branch` - // doesn't already have a counter. - break; + // Consider each loop on the current traversal context stack, top-down. + for reloop_bcbs in traversal.reloop_bcbs_per_loop() { + let mut all_branches_exit_this_loop = true; + + // Try to find a branch that doesn't exit this loop and doesn't + // already have a counter. + for &branch in branches { + // A branch is a reloop branch if it dominates any BCB that has + // an edge back to the loop header. (Other branches are exits.) + let is_reloop_branch = reloop_bcbs.iter().any(|&reloop_bcb| { + self.basic_coverage_blocks.dominates(branch.target_bcb, reloop_bcb) + }); + + if is_reloop_branch { + all_branches_exit_this_loop = false; + if self.branch_has_no_counter(&branch) { + // We found a good branch to be given an expression. + return Some(branch); } + // Keep looking for another reloop branch without a counter. + } else { + // This branch exits the loop. } - if !found_loop_exit { - debug!( - "No branches exit the loop, so any branch without an existing \ - counter can have the `Expression`." - ); - break; - } - if some_reloop_branch.is_some() { - debug!( - "Found a branch that exits the loop and a branch the loops back to \ - the top of the loop (`reloop_branch`). The `reloop_branch` will \ - get the `Expression`, as long as it still needs a counter." - ); - break; - } - // else all branches exited this loop context, so run the same checks with - // the outer loop(s) } + + if !all_branches_exit_this_loop { + // We found one or more reloop branches, but all of them already + // have counters. Let the caller choose one of the exit branches. + debug!("All reloop branches had counters; skip checking the other loops"); + return None; + } + + // All of the branches exit this loop, so keep looking for a good + // reloop branch for one of the outer loops. } - some_reloop_branch + + None } #[inline] @@ -661,9 +535,4 @@ impl<'a> MakeBcbCounters<'a> { fn bcb_has_one_path_to_target(&self, bcb: BasicCoverageBlock) -> bool { self.bcb_predecessors(bcb).len() <= 1 } - - #[inline] - fn bcb_dominates(&self, dom: BasicCoverageBlock, node: BasicCoverageBlock) -> bool { - self.basic_coverage_blocks.dominates(dom, node) - } } diff --git a/compiler/rustc_mir_transform/src/coverage/graph.rs b/compiler/rustc_mir_transform/src/coverage/graph.rs index ff2254d69..6bab62aa8 100644 --- a/compiler/rustc_mir_transform/src/coverage/graph.rs +++ b/compiler/rustc_mir_transform/src/coverage/graph.rs @@ -1,10 +1,12 @@ +use rustc_data_structures::captures::Captures; use rustc_data_structures::graph::dominators::{self, Dominators}; use rustc_data_structures::graph::{self, GraphSuccessors, WithNumNodes, WithStartNode}; use rustc_index::bit_set::BitSet; use rustc_index::{IndexSlice, IndexVec}; -use rustc_middle::mir::{self, BasicBlock, BasicBlockData, Terminator, TerminatorKind}; +use rustc_middle::mir::{self, BasicBlock, TerminatorKind}; use std::cmp::Ordering; +use std::collections::VecDeque; use std::ops::{Index, IndexMut}; /// A coverage-specific simplification of the MIR control flow graph (CFG). 
The `CoverageGraph`s @@ -36,9 +38,8 @@ impl CoverageGraph { } let bcb_data = &bcbs[bcb]; let mut bcb_successors = Vec::new(); - for successor in - bcb_filtered_successors(&mir_body, &bcb_data.terminator(mir_body).kind) - .filter_map(|successor_bb| bb_to_bcb[successor_bb]) + for successor in bcb_filtered_successors(&mir_body, bcb_data.last_bb()) + .filter_map(|successor_bb| bb_to_bcb[successor_bb]) { if !seen[successor] { seen[successor] = true; @@ -80,10 +81,9 @@ impl CoverageGraph { // intentionally omits unwind paths. // FIXME(#78544): MIR InstrumentCoverage: Improve coverage of `#[should_panic]` tests and // `catch_unwind()` handlers. - let mir_cfg_without_unwind = ShortCircuitPreorder::new(&mir_body, bcb_filtered_successors); let mut basic_blocks = Vec::new(); - for (bb, data) in mir_cfg_without_unwind { + for bb in short_circuit_preorder(mir_body, bcb_filtered_successors) { if let Some(last) = basic_blocks.last() { let predecessors = &mir_body.basic_blocks.predecessors()[bb]; if predecessors.len() > 1 || !predecessors.contains(last) { @@ -109,7 +109,7 @@ impl CoverageGraph { } basic_blocks.push(bb); - let term = data.terminator(); + let term = mir_body[bb].terminator(); match term.kind { TerminatorKind::Return { .. } @@ -147,7 +147,7 @@ impl CoverageGraph { | TerminatorKind::Unreachable | TerminatorKind::Drop { .. } | TerminatorKind::Call { .. } - | TerminatorKind::GeneratorDrop + | TerminatorKind::CoroutineDrop | TerminatorKind::Assert { .. } | TerminatorKind::FalseEdge { .. } | TerminatorKind::FalseUnwind { .. } @@ -288,9 +288,9 @@ rustc_index::newtype_index! { /// not relevant to coverage analysis. `FalseUnwind`, for example, can be treated the same as /// a `Goto`, and merged with its successor into the same BCB. /// -/// Each BCB with at least one computed `CoverageSpan` will have no more than one `Counter`. +/// Each BCB with at least one computed coverage span will have no more than one `Counter`. /// In some cases, a BCB's execution count can be computed by `Expression`. Additional -/// disjoint `CoverageSpan`s in a BCB can also be counted by `Expression` (by adding `ZERO` +/// disjoint coverage spans in a BCB can also be counted by `Expression` (by adding `ZERO` /// to the BCB's primary counter or expression). /// /// The BCB CFG is critical to simplifying the coverage analysis by ensuring graph path-based @@ -316,11 +316,6 @@ impl BasicCoverageBlockData { pub fn last_bb(&self) -> BasicBlock { *self.basic_blocks.last().unwrap() } - - #[inline(always)] - pub fn terminator<'a, 'tcx>(&self, mir_body: &'a mir::Body<'tcx>) -> &'a Terminator<'tcx> { - &mir_body[self.last_bb()].terminator() - } } /// Represents a successor from a branching BasicCoverageBlock (such as the arms of a `SwitchInt`) @@ -362,26 +357,28 @@ impl std::fmt::Debug for BcbBranch { } } -// Returns the `Terminator`s non-unwind successors. +// Returns the subset of a block's successors that are relevant to the coverage +// graph, i.e. those that do not represent unwinds or unreachable branches. // FIXME(#78544): MIR InstrumentCoverage: Improve coverage of `#[should_panic]` tests and // `catch_unwind()` handlers. fn bcb_filtered_successors<'a, 'tcx>( body: &'a mir::Body<'tcx>, - term_kind: &'a TerminatorKind<'tcx>, -) -> Box<dyn Iterator<Item = BasicBlock> + 'a> { - Box::new( - match &term_kind { - // SwitchInt successors are never unwind, and all of them should be traversed. - TerminatorKind::SwitchInt { ref targets, .. 
} => { - None.into_iter().chain(targets.all_targets().into_iter().copied()) - } - // For all other kinds, return only the first successor, if any, and ignore unwinds. - // NOTE: `chain(&[])` is required to coerce the `option::iter` (from - // `next().into_iter()`) into the `mir::Successors` aliased type. - _ => term_kind.successors().next().into_iter().chain((&[]).into_iter().copied()), - } - .filter(move |&successor| body[successor].terminator().kind != TerminatorKind::Unreachable), - ) + bb: BasicBlock, +) -> impl Iterator<Item = BasicBlock> + Captures<'a> + Captures<'tcx> { + let terminator = body[bb].terminator(); + + let take_n_successors = match terminator.kind { + // SwitchInt successors are never unwinds, so all of them should be traversed. + TerminatorKind::SwitchInt { .. } => usize::MAX, + // For all other kinds, return only the first successor (if any), ignoring any + // unwind successors. + _ => 1, + }; + + terminator + .successors() + .take(take_n_successors) + .filter(move |&successor| body[successor].terminator().kind != TerminatorKind::Unreachable) } /// Maintains separate worklists for each loop in the BasicCoverageBlock CFG, plus one for the @@ -389,57 +386,72 @@ fn bcb_filtered_successors<'a, 'tcx>( /// ensures a loop is completely traversed before processing Blocks after the end of the loop. #[derive(Debug)] pub(super) struct TraversalContext { - /// From one or more backedges returning to a loop header. - pub loop_backedges: Option<(Vec<BasicCoverageBlock>, BasicCoverageBlock)>, - - /// worklist, to be traversed, of CoverageGraph in the loop with the given loop - /// backedges, such that the loop is the inner inner-most loop containing these - /// CoverageGraph - pub worklist: Vec<BasicCoverageBlock>, + /// BCB with one or more incoming loop backedges, indicating which loop + /// this context is for. + /// + /// If `None`, this is the non-loop context for the function as a whole. + loop_header: Option<BasicCoverageBlock>, + + /// Worklist of BCBs to be processed in this context. + worklist: VecDeque<BasicCoverageBlock>, } -pub(super) struct TraverseCoverageGraphWithLoops { - pub backedges: IndexVec<BasicCoverageBlock, Vec<BasicCoverageBlock>>, - pub context_stack: Vec<TraversalContext>, +pub(super) struct TraverseCoverageGraphWithLoops<'a> { + basic_coverage_blocks: &'a CoverageGraph, + + backedges: IndexVec<BasicCoverageBlock, Vec<BasicCoverageBlock>>, + context_stack: Vec<TraversalContext>, visited: BitSet<BasicCoverageBlock>, } -impl TraverseCoverageGraphWithLoops { - pub fn new(basic_coverage_blocks: &CoverageGraph) -> Self { - let start_bcb = basic_coverage_blocks.start_node(); +impl<'a> TraverseCoverageGraphWithLoops<'a> { + pub(super) fn new(basic_coverage_blocks: &'a CoverageGraph) -> Self { let backedges = find_loop_backedges(basic_coverage_blocks); - let context_stack = - vec![TraversalContext { loop_backedges: None, worklist: vec![start_bcb] }]; + + let worklist = VecDeque::from([basic_coverage_blocks.start_node()]); + let context_stack = vec![TraversalContext { loop_header: None, worklist }]; + // `context_stack` starts with a `TraversalContext` for the main function context (beginning // with the `start` BasicCoverageBlock of the function). New worklists are pushed to the top // of the stack as loops are entered, and popped off of the stack when a loop's worklist is // exhausted. 
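// The traversal discipline described above, as a toy: one worklist per loop
// context, where exhausting a loop's worklist pops back to the enclosing
// context, so a loop is finished before anything after it. Plain `usize`
// ids stand in for `BasicCoverageBlock`.
use std::collections::VecDeque;

struct Ctx {
    loop_header: Option<usize>, // `None` for the whole-function context
    worklist: VecDeque<usize>,
}

fn next_node(stack: &mut Vec<Ctx>) -> Option<usize> {
    while let Some(ctx) = stack.last_mut() {
        if let Some(node) = ctx.worklist.pop_front() {
            // The caller pushes a fresh Ctx if `node` is a loop header.
            return Some(node);
        }
        stack.pop(); // this loop is fully traversed; resume the outer context
    }
    None
}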
let visited = BitSet::new_empty(basic_coverage_blocks.num_nodes()); - Self { backedges, context_stack, visited } + Self { basic_coverage_blocks, backedges, context_stack, visited } } - pub fn next(&mut self, basic_coverage_blocks: &CoverageGraph) -> Option<BasicCoverageBlock> { + /// For each loop on the loop context stack (top-down), yields a list of BCBs + /// within that loop that have an outgoing edge back to the loop header. + pub(super) fn reloop_bcbs_per_loop(&self) -> impl Iterator<Item = &[BasicCoverageBlock]> { + self.context_stack + .iter() + .rev() + .filter_map(|context| context.loop_header) + .map(|header_bcb| self.backedges[header_bcb].as_slice()) + } + + pub(super) fn next(&mut self) -> Option<BasicCoverageBlock> { debug!( "TraverseCoverageGraphWithLoops::next - context_stack: {:?}", self.context_stack.iter().rev().collect::<Vec<_>>() ); while let Some(context) = self.context_stack.last_mut() { - if let Some(next_bcb) = context.worklist.pop() { - if !self.visited.insert(next_bcb) { - debug!("Already visited: {:?}", next_bcb); + if let Some(bcb) = context.worklist.pop_front() { + if !self.visited.insert(bcb) { + debug!("Already visited: {bcb:?}"); continue; } - debug!("Visiting {:?}", next_bcb); - if self.backedges[next_bcb].len() > 0 { - debug!("{:?} is a loop header! Start a new TraversalContext...", next_bcb); + debug!("Visiting {bcb:?}"); + + if self.backedges[bcb].len() > 0 { + debug!("{bcb:?} is a loop header! Start a new TraversalContext..."); self.context_stack.push(TraversalContext { - loop_backedges: Some((self.backedges[next_bcb].clone(), next_bcb)), - worklist: Vec::new(), + loop_header: Some(bcb), + worklist: VecDeque::new(), }); } - self.extend_worklist(basic_coverage_blocks, next_bcb); - return Some(next_bcb); + self.add_successors_to_worklists(bcb); + return Some(bcb); } else { // Strip contexts with empty worklists from the top of the stack self.context_stack.pop(); @@ -449,13 +461,10 @@ impl TraverseCoverageGraphWithLoops { None } - pub fn extend_worklist( - &mut self, - basic_coverage_blocks: &CoverageGraph, - bcb: BasicCoverageBlock, - ) { - let successors = &basic_coverage_blocks.successors[bcb]; + pub fn add_successors_to_worklists(&mut self, bcb: BasicCoverageBlock) { + let successors = &self.basic_coverage_blocks.successors[bcb]; debug!("{:?} has {} successors:", bcb, successors.len()); + for &successor in successors { if successor == bcb { debug!( @@ -464,56 +473,44 @@ impl TraverseCoverageGraphWithLoops { bcb ); // Don't re-add this successor to the worklist. We are already processing it. + // FIXME: This claims to skip just the self-successor, but it actually skips + // all other successors as well. Does that matter? break; } - for context in self.context_stack.iter_mut().rev() { - // Add successors of the current BCB to the appropriate context. Successors that - // stay within a loop are added to the BCBs context worklist. Successors that - // exit the loop (they are not dominated by the loop header) must be reachable - // from other BCBs outside the loop, and they will be added to a different - // worklist. - // - // Branching blocks (with more than one successor) must be processed before - // blocks with only one successor, to prevent unnecessarily complicating - // `Expression`s by creating a Counter in a `BasicCoverageBlock` that the - // branching block would have given an `Expression` (or vice versa). 
- let (some_successor_to_add, some_loop_header) = - if let Some((_, loop_header)) = context.loop_backedges { - if basic_coverage_blocks.dominates(loop_header, successor) { - (Some(successor), Some(loop_header)) - } else { - (None, None) - } - } else { - (Some(successor), None) - }; - if let Some(successor_to_add) = some_successor_to_add { - if basic_coverage_blocks.successors[successor_to_add].len() > 1 { - debug!( - "{:?} successor is branching. Prioritize it at the beginning of \ - the {}", - successor_to_add, - if let Some(loop_header) = some_loop_header { - format!("worklist for the loop headed by {loop_header:?}") - } else { - String::from("non-loop worklist") - }, - ); - context.worklist.insert(0, successor_to_add); - } else { - debug!( - "{:?} successor is non-branching. Defer it to the end of the {}", - successor_to_add, - if let Some(loop_header) = some_loop_header { - format!("worklist for the loop headed by {loop_header:?}") - } else { - String::from("non-loop worklist") - }, - ); - context.worklist.push(successor_to_add); + + // Add successors of the current BCB to the appropriate context. Successors that + // stay within a loop are added to the BCBs context worklist. Successors that + // exit the loop (they are not dominated by the loop header) must be reachable + // from other BCBs outside the loop, and they will be added to a different + // worklist. + // + // Branching blocks (with more than one successor) must be processed before + // blocks with only one successor, to prevent unnecessarily complicating + // `Expression`s by creating a Counter in a `BasicCoverageBlock` that the + // branching block would have given an `Expression` (or vice versa). + + let context = self + .context_stack + .iter_mut() + .rev() + .find(|context| match context.loop_header { + Some(loop_header) => { + self.basic_coverage_blocks.dominates(loop_header, successor) } - break; - } + None => true, + }) + .unwrap_or_else(|| bug!("should always fall back to the root non-loop context")); + debug!("adding to worklist for {:?}", context.loop_header); + + // FIXME: The code below had debug messages claiming to add items to a + // particular end of the worklist, but was confused about which end was + // which. The existing behaviour has been preserved for now, but it's + // unclear what the intended behaviour was. 
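// The preserved queueing behaviour, in miniature: branching successors go to
// the back of the deque and single-successor blocks to the front, so a
// `pop_front` traversal visits the straight-line blocks enqueued below first.
// Whether that matches the original intent is exactly what the FIXME above
// leaves open; `n_successors` is a hypothetical stand-in for the BCB
// successor query.
fn enqueue(worklist: &mut std::collections::VecDeque<usize>, succ: usize, n_successors: usize) {
    if n_successors > 1 {
        worklist.push_back(succ);
    } else {
        worklist.push_front(succ);
    }
}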
+ + if self.basic_coverage_blocks.successors[successor].len() > 1 { + context.worklist.push_back(successor); + } else { + context.worklist.push_front(successor); } } } @@ -553,66 +550,28 @@ pub(super) fn find_loop_backedges( backedges } -pub struct ShortCircuitPreorder< - 'a, - 'tcx, - F: Fn(&'a mir::Body<'tcx>, &'a TerminatorKind<'tcx>) -> Box<dyn Iterator<Item = BasicBlock> + 'a>, -> { +fn short_circuit_preorder<'a, 'tcx, F, Iter>( body: &'a mir::Body<'tcx>, - visited: BitSet<BasicBlock>, - worklist: Vec<BasicBlock>, filtered_successors: F, -} - -impl< - 'a, - 'tcx, - F: Fn(&'a mir::Body<'tcx>, &'a TerminatorKind<'tcx>) -> Box<dyn Iterator<Item = BasicBlock> + 'a>, -> ShortCircuitPreorder<'a, 'tcx, F> -{ - pub fn new( - body: &'a mir::Body<'tcx>, - filtered_successors: F, - ) -> ShortCircuitPreorder<'a, 'tcx, F> { - let worklist = vec![mir::START_BLOCK]; - - ShortCircuitPreorder { - body, - visited: BitSet::new_empty(body.basic_blocks.len()), - worklist, - filtered_successors, - } - } -} - -impl< - 'a, - 'tcx, - F: Fn(&'a mir::Body<'tcx>, &'a TerminatorKind<'tcx>) -> Box<dyn Iterator<Item = BasicBlock> + 'a>, -> Iterator for ShortCircuitPreorder<'a, 'tcx, F> +) -> impl Iterator<Item = BasicBlock> + Captures<'a> + Captures<'tcx> +where + F: Fn(&'a mir::Body<'tcx>, BasicBlock) -> Iter, + Iter: Iterator<Item = BasicBlock>, { - type Item = (BasicBlock, &'a BasicBlockData<'tcx>); + let mut visited = BitSet::new_empty(body.basic_blocks.len()); + let mut worklist = vec![mir::START_BLOCK]; - fn next(&mut self) -> Option<(BasicBlock, &'a BasicBlockData<'tcx>)> { - while let Some(idx) = self.worklist.pop() { - if !self.visited.insert(idx) { + std::iter::from_fn(move || { + while let Some(bb) = worklist.pop() { + if !visited.insert(bb) { continue; } - let data = &self.body[idx]; - - if let Some(ref term) = data.terminator { - self.worklist.extend((self.filtered_successors)(&self.body, &term.kind)); - } + worklist.extend(filtered_successors(body, bb)); - return Some((idx, data)); + return Some(bb); } None - } - - fn size_hint(&self) -> (usize, Option<usize>) { - let size = self.body.basic_blocks.len() - self.visited.count(); - (size, Some(size)) - } + }) } diff --git a/compiler/rustc_mir_transform/src/coverage/mod.rs b/compiler/rustc_mir_transform/src/coverage/mod.rs index c75d33eeb..97e4468a0 100644 --- a/compiler/rustc_mir_transform/src/coverage/mod.rs +++ b/compiler/rustc_mir_transform/src/coverage/mod.rs @@ -8,14 +8,12 @@ mod spans; mod tests; use self::counters::{BcbCounter, CoverageCounters}; -use self::graph::{BasicCoverageBlock, BasicCoverageBlockData, CoverageGraph}; -use self::spans::{CoverageSpan, CoverageSpans}; +use self::graph::CoverageGraph; +use self::spans::CoverageSpans; use crate::MirPass; -use rustc_data_structures::graph::WithNumNodes; use rustc_data_structures::sync::Lrc; -use rustc_index::IndexVec; use rustc_middle::hir; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags; use rustc_middle::mir::coverage::*; @@ -28,18 +26,6 @@ use rustc_span::def_id::DefId; use rustc_span::source_map::SourceMap; use rustc_span::{ExpnKind, SourceFile, Span, Symbol}; -/// A simple error message wrapper for `coverage::Error`s. 
-#[derive(Debug)]
-struct Error {
- message: String,
-}
-
-impl Error {
- pub fn from_string<T>(message: String) -> Result<T, Error> {
- Err(Self { message })
- }
-}
-
 /// Inserts `StatementKind::Coverage` statements that instrument the binary with injected
 /// counters (via the `llvm.instrprof.increment` intrinsic) and/or inject metadata used
 /// during codegen to construct the coverage map.
 @@ -154,7 +140,7 @@ impl<'a, 'tcx> Instrumentor<'a, 'tcx> {
 let body_span = self.body_span;
 ////////////////////////////////////////////////////
- // Compute `CoverageSpan`s from the `CoverageGraph`.
+ // Compute coverage spans from the `CoverageGraph`.
 let coverage_spans = CoverageSpans::generate_coverage_spans(
 &self.mir_body,
 fn_sig_span,
 @@ -164,179 +150,106 @@ impl<'a, 'tcx> Instrumentor<'a, 'tcx> {
 ////////////////////////////////////////////////////
 // Create an optimized mix of `Counter`s and `Expression`s for the `CoverageGraph`. Ensure
- // every `CoverageSpan` has a `Counter` or `Expression` assigned to its `BasicCoverageBlock`
+ // every coverage span has a `Counter` or `Expression` assigned to its `BasicCoverageBlock`
 // and all `Expression` dependencies (operands) are also generated, for any other
- // `BasicCoverageBlock`s not already associated with a `CoverageSpan`.
- //
- // Intermediate expressions (used to compute other `Expression` values), which have no
- // direct association with any `BasicCoverageBlock`, are accumulated inside `coverage_counters`.
- let result = self
- .coverage_counters
- .make_bcb_counters(&mut self.basic_coverage_blocks, &coverage_spans);
-
- if let Ok(()) = result {
- ////////////////////////////////////////////////////
- // Remove the counter or edge counter from of each `CoverageSpan`s associated
- // `BasicCoverageBlock`, and inject a `Coverage` statement into the MIR.
- //
- // `Coverage` statements injected from `CoverageSpan`s will include the code regions
- // (source code start and end positions) to be counted by the associated counter.
- //
- // These `CoverageSpan`-associated counters are removed from their associated
- // `BasicCoverageBlock`s so that the only remaining counters in the `CoverageGraph`
- // are indirect counters (to be injected next, without associated code regions).
- self.inject_coverage_span_counters(coverage_spans);
-
- ////////////////////////////////////////////////////
- // For any remaining `BasicCoverageBlock` counters (that were not associated with
- // any `CoverageSpan`), inject `Coverage` statements (_without_ code region `Span`s)
- // to ensure `BasicCoverageBlock` counters that other `Expression`s may depend on
- // are in fact counted, even though they don't directly contribute to counting
- // their own independent code region's coverage.
- self.inject_indirect_counters();
-
- // Intermediate expressions will be injected as the final step, after generating
- // debug output, if any.
- ////////////////////////////////////////////////////
- };
-
- if let Err(e) = result {
- bug!("Error processing: {:?}: {:?}", self.mir_body.source.def_id(), e.message)
- };
-
- ////////////////////////////////////////////////////
- // Finally, inject the intermediate expressions collected along the way.
- for intermediate_expression in &self.coverage_counters.intermediate_expressions {
- inject_intermediate_expression(
- self.mir_body,
- self.make_mir_coverage_kind(intermediate_expression),
- );
- }
+ // `BasicCoverageBlock`s not already associated with a coverage span.
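The `Counter`/`Expression` mix described above is branch arithmetic: a physical counter is only needed where a count cannot be derived from other counts. A minimal illustration with made-up values (the real `BcbCounter::Expression` stores operand IDs, not values):

    // Physical counters on two branch arms.
    let (c1, c2) = (7u64, 3u64);
    // The join block needs no counter of its own: Expression(c1 + c2).
    assert_eq!(c1 + c2, 10);
    // An `if` without an `else` is the same trick with subtraction:
    // Expression(parent - then_arm) counts the implicit else edge.
    let (parent, then_arm) = (10u64, 7u64);
    assert_eq!(parent - then_arm, 3);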
+ let bcb_has_coverage_spans = |bcb| coverage_spans.bcb_has_coverage_spans(bcb); + self.coverage_counters + .make_bcb_counters(&self.basic_coverage_blocks, bcb_has_coverage_spans); + + let mappings = self.create_mappings_and_inject_coverage_statements(&coverage_spans); + + self.mir_body.function_coverage_info = Some(Box::new(FunctionCoverageInfo { + function_source_hash: self.function_source_hash, + num_counters: self.coverage_counters.num_counters(), + expressions: self.coverage_counters.take_expressions(), + mappings, + })); } - /// Inject a counter for each `CoverageSpan`. There can be multiple `CoverageSpan`s for a given - /// BCB, but only one actual counter needs to be incremented per BCB. `bb_counters` maps each - /// `bcb` to its `Counter`, when injected. Subsequent `CoverageSpan`s for a BCB that already has - /// a `Counter` will inject an `Expression` instead, and compute its value by adding `ZERO` to - /// the BCB `Counter` value. - fn inject_coverage_span_counters(&mut self, coverage_spans: Vec<CoverageSpan>) { - let tcx = self.tcx; - let source_map = tcx.sess.source_map(); + /// For each [`BcbCounter`] associated with a BCB node or BCB edge, create + /// any corresponding mappings (for BCB nodes only), and inject any necessary + /// coverage statements into MIR. + fn create_mappings_and_inject_coverage_statements( + &mut self, + coverage_spans: &CoverageSpans, + ) -> Vec<Mapping> { + let source_map = self.tcx.sess.source_map(); let body_span = self.body_span; - let file_name = Symbol::intern(&self.source_file.name.prefer_remapped().to_string_lossy()); - - let mut bcb_counters = IndexVec::from_elem_n(None, self.basic_coverage_blocks.num_nodes()); - for covspan in coverage_spans { - let bcb = covspan.bcb; - let span = covspan.span; - let counter_kind = if let Some(&counter_operand) = bcb_counters[bcb].as_ref() { - self.coverage_counters.make_identity_counter(counter_operand) - } else if let Some(counter_kind) = self.coverage_counters.take_bcb_counter(bcb) { - bcb_counters[bcb] = Some(counter_kind.as_operand()); - counter_kind - } else { - bug!("Every BasicCoverageBlock should have a Counter or Expression"); - }; - - let code_region = make_code_region(source_map, file_name, span, body_span); - inject_statement( - self.mir_body, - self.make_mir_coverage_kind(&counter_kind), - self.bcb_leader_bb(bcb), - Some(code_region), - ); - } - } - - /// `inject_coverage_span_counters()` looped through the `CoverageSpan`s and injected the - /// counter from the `CoverageSpan`s `BasicCoverageBlock`, removing it from the BCB in the - /// process (via `take_counter()`). - /// - /// Any other counter associated with a `BasicCoverageBlock`, or its incoming edge, but not - /// associated with a `CoverageSpan`, should only exist if the counter is an `Expression` - /// dependency (one of the expression operands). Collect them, and inject the additional - /// counters into the MIR, without a reportable coverage span. 
- fn inject_indirect_counters(&mut self) { - let mut bcb_counters_without_direct_coverage_spans = Vec::new(); - for (target_bcb, counter_kind) in self.coverage_counters.drain_bcb_counters() { - bcb_counters_without_direct_coverage_spans.push((None, target_bcb, counter_kind)); - } - for ((from_bcb, target_bcb), counter_kind) in - self.coverage_counters.drain_bcb_edge_counters() - { - bcb_counters_without_direct_coverage_spans.push(( - Some(from_bcb), - target_bcb, - counter_kind, - )); - } + use rustc_session::RemapFileNameExt; + let file_name = + Symbol::intern(&self.source_file.name.for_codegen(self.tcx.sess).to_string_lossy()); + + let mut mappings = Vec::new(); + + // Process the counters and spans associated with BCB nodes. + for (bcb, counter_kind) in self.coverage_counters.bcb_node_counters() { + let spans = coverage_spans.spans_for_bcb(bcb); + let has_mappings = !spans.is_empty(); + + // If this BCB has any coverage spans, add corresponding mappings to + // the mappings table. + if has_mappings { + let term = counter_kind.as_term(); + mappings.extend(spans.iter().map(|&span| { + let code_region = make_code_region(source_map, file_name, span, body_span); + Mapping { code_region, term } + })); + } - for (edge_from_bcb, target_bcb, counter_kind) in bcb_counters_without_direct_coverage_spans - { - match counter_kind { - BcbCounter::Counter { .. } => { - let inject_to_bb = if let Some(from_bcb) = edge_from_bcb { - // The MIR edge starts `from_bb` (the outgoing / last BasicBlock in - // `from_bcb`) and ends at `to_bb` (the incoming / first BasicBlock in the - // `target_bcb`; also called the `leader_bb`). - let from_bb = self.bcb_last_bb(from_bcb); - let to_bb = self.bcb_leader_bb(target_bcb); - - let new_bb = inject_edge_counter_basic_block(self.mir_body, from_bb, to_bb); - debug!( - "Edge {:?} (last {:?}) -> {:?} (leader {:?}) requires a new MIR \ - BasicBlock {:?}, for unclaimed edge counter {:?}", - edge_from_bcb, from_bb, target_bcb, to_bb, new_bb, counter_kind, - ); - new_bb - } else { - let target_bb = self.bcb_last_bb(target_bcb); - debug!( - "{:?} ({:?}) gets a new Coverage statement for unclaimed counter {:?}", - target_bcb, target_bb, counter_kind, - ); - target_bb - }; - - inject_statement( - self.mir_body, - self.make_mir_coverage_kind(&counter_kind), - inject_to_bb, - None, - ); - } - BcbCounter::Expression { .. } => inject_intermediate_expression( + let do_inject = match counter_kind { + // Counter-increment statements always need to be injected. + BcbCounter::Counter { .. } => true, + // The only purpose of expression-used statements is to detect + // when a mapping is unreachable, so we only inject them for + // expressions with one or more mappings. + BcbCounter::Expression { .. } => has_mappings, + }; + if do_inject { + inject_statement( self.mir_body, - self.make_mir_coverage_kind(&counter_kind), - ), + self.make_mir_coverage_kind(counter_kind), + self.basic_coverage_blocks[bcb].leader_bb(), + ); } } - } - #[inline] - fn bcb_leader_bb(&self, bcb: BasicCoverageBlock) -> BasicBlock { - self.bcb_data(bcb).leader_bb() - } + // Process the counters associated with BCB edges. + for (from_bcb, to_bcb, counter_kind) in self.coverage_counters.bcb_edge_counters() { + let do_inject = match counter_kind { + // Counter-increment statements always need to be injected. + BcbCounter::Counter { .. } => true, + // BCB-edge expressions never have mappings, so they never need + // a corresponding statement. + BcbCounter::Expression { .. 
} => false, + }; + if !do_inject { + continue; + } - #[inline] - fn bcb_last_bb(&self, bcb: BasicCoverageBlock) -> BasicBlock { - self.bcb_data(bcb).last_bb() - } + // We need to inject a coverage statement into a new BB between the + // last BB of `from_bcb` and the first BB of `to_bcb`. + let from_bb = self.basic_coverage_blocks[from_bcb].last_bb(); + let to_bb = self.basic_coverage_blocks[to_bcb].leader_bb(); + + let new_bb = inject_edge_counter_basic_block(self.mir_body, from_bb, to_bb); + debug!( + "Edge {from_bcb:?} (last {from_bb:?}) -> {to_bcb:?} (leader {to_bb:?}) \ + requires a new MIR BasicBlock {new_bb:?} for edge counter {counter_kind:?}", + ); + + // Inject a counter into the newly-created BB. + inject_statement(self.mir_body, self.make_mir_coverage_kind(&counter_kind), new_bb); + } - #[inline] - fn bcb_data(&self, bcb: BasicCoverageBlock) -> &BasicCoverageBlockData { - &self.basic_coverage_blocks[bcb] + mappings } fn make_mir_coverage_kind(&self, counter_kind: &BcbCounter) -> CoverageKind { match *counter_kind { - BcbCounter::Counter { id } => { - CoverageKind::Counter { function_source_hash: self.function_source_hash, id } - } - BcbCounter::Expression { id, lhs, op, rhs } => { - CoverageKind::Expression { id, lhs, op, rhs } - } + BcbCounter::Counter { id } => CoverageKind::CounterIncrement { id }, + BcbCounter::Expression { id } => CoverageKind::ExpressionUsed { id }, } } } @@ -364,42 +277,17 @@ fn inject_edge_counter_basic_block( new_bb } -fn inject_statement( - mir_body: &mut mir::Body<'_>, - counter_kind: CoverageKind, - bb: BasicBlock, - some_code_region: Option<CodeRegion>, -) { - debug!( - " injecting statement {:?} for {:?} at code region: {:?}", - counter_kind, bb, some_code_region - ); +fn inject_statement(mir_body: &mut mir::Body<'_>, counter_kind: CoverageKind, bb: BasicBlock) { + debug!(" injecting statement {counter_kind:?} for {bb:?}"); let data = &mut mir_body[bb]; let source_info = data.terminator().source_info; let statement = Statement { source_info, - kind: StatementKind::Coverage(Box::new(Coverage { - kind: counter_kind, - code_region: some_code_region, - })), + kind: StatementKind::Coverage(Box::new(Coverage { kind: counter_kind })), }; data.statements.insert(0, statement); } -// Non-code expressions are injected into the coverage map, without generating executable code. -fn inject_intermediate_expression(mir_body: &mut mir::Body<'_>, expression: CoverageKind) { - debug_assert!(matches!(expression, CoverageKind::Expression { .. 
})); - debug!(" injecting non-code expression {:?}", expression); - let inject_in_bb = mir::START_BLOCK; - let data = &mut mir_body[inject_in_bb]; - let source_info = data.terminator().source_info; - let statement = Statement { - source_info, - kind: StatementKind::Coverage(Box::new(Coverage { kind: expression, code_region: None })), - }; - data.statements.push(statement); -} - /// Convert the Span into its file name, start line and column, and end line and column fn make_code_region( source_map: &SourceMap, diff --git a/compiler/rustc_mir_transform/src/coverage/query.rs b/compiler/rustc_mir_transform/src/coverage/query.rs index 56365c5d4..809407f89 100644 --- a/compiler/rustc_mir_transform/src/coverage/query.rs +++ b/compiler/rustc_mir_transform/src/coverage/query.rs @@ -2,100 +2,31 @@ use super::*; use rustc_data_structures::captures::Captures; use rustc_middle::mir::coverage::*; -use rustc_middle::mir::{self, Body, Coverage, CoverageInfo}; +use rustc_middle::mir::{Body, Coverage, CoverageIdsInfo}; use rustc_middle::query::Providers; use rustc_middle::ty::{self, TyCtxt}; -use rustc_span::def_id::DefId; /// A `query` provider for retrieving coverage information injected into MIR. pub(crate) fn provide(providers: &mut Providers) { - providers.coverageinfo = |tcx, def_id| coverageinfo(tcx, def_id); - providers.covered_code_regions = |tcx, def_id| covered_code_regions(tcx, def_id); + providers.coverage_ids_info = |tcx, def_id| coverage_ids_info(tcx, def_id); } -/// Coverage codegen needs to know the total number of counter IDs and expression IDs that have -/// been used by a function's coverage mappings. These totals are used to create vectors to hold -/// the relevant counter and expression data, and the maximum counter ID (+ 1) is also needed by -/// the `llvm.instrprof.increment` intrinsic. -/// -/// MIR optimization may split and duplicate some BasicBlock sequences, or optimize out some code -/// including injected counters. (It is OK if some counters are optimized out, but those counters -/// are still included in the total `num_counters` or `num_expressions`.) Simply counting the -/// calls may not work; but computing the number of counters or expressions by adding `1` to the -/// highest ID (for a given instrumented function) is valid. -/// -/// It's possible for a coverage expression to remain in MIR while one or both of its operands -/// have been optimized away. To avoid problems in codegen, we include those operands' IDs when -/// determining the maximum counter/expression ID, even if the underlying counter/expression is -/// no longer present. -struct CoverageVisitor { - max_counter_id: CounterId, - max_expression_id: ExpressionId, -} - -impl CoverageVisitor { - /// Updates `max_counter_id` to the maximum encountered counter ID. - #[inline(always)] - fn update_max_counter_id(&mut self, counter_id: CounterId) { - self.max_counter_id = self.max_counter_id.max(counter_id); - } - - /// Updates `max_expression_id` to the maximum encountered expression ID. 
- #[inline(always)] - fn update_max_expression_id(&mut self, expression_id: ExpressionId) { - self.max_expression_id = self.max_expression_id.max(expression_id); - } - - fn update_from_expression_operand(&mut self, operand: Operand) { - match operand { - Operand::Counter(id) => self.update_max_counter_id(id), - Operand::Expression(id) => self.update_max_expression_id(id), - Operand::Zero => {} - } - } - - fn visit_body(&mut self, body: &Body<'_>) { - for coverage in all_coverage_in_mir_body(body) { - self.visit_coverage(coverage); - } - } - - fn visit_coverage(&mut self, coverage: &Coverage) { - match coverage.kind { - CoverageKind::Counter { id, .. } => self.update_max_counter_id(id), - CoverageKind::Expression { id, lhs, rhs, .. } => { - self.update_max_expression_id(id); - self.update_from_expression_operand(lhs); - self.update_from_expression_operand(rhs); - } - CoverageKind::Unreachable => {} - } - } -} - -fn coverageinfo<'tcx>(tcx: TyCtxt<'tcx>, instance_def: ty::InstanceDef<'tcx>) -> CoverageInfo { +/// Query implementation for `coverage_ids_info`. +fn coverage_ids_info<'tcx>( + tcx: TyCtxt<'tcx>, + instance_def: ty::InstanceDef<'tcx>, +) -> CoverageIdsInfo { let mir_body = tcx.instance_mir(instance_def); - let mut coverage_visitor = CoverageVisitor { - max_counter_id: CounterId::START, - max_expression_id: ExpressionId::START, - }; - - coverage_visitor.visit_body(mir_body); - - // Add 1 to the highest IDs to get the total number of IDs. - CoverageInfo { - num_counters: (coverage_visitor.max_counter_id + 1).as_u32(), - num_expressions: (coverage_visitor.max_expression_id + 1).as_u32(), - } -} + let max_counter_id = all_coverage_in_mir_body(mir_body) + .filter_map(|coverage| match coverage.kind { + CoverageKind::CounterIncrement { id } => Some(id), + _ => None, + }) + .max() + .unwrap_or(CounterId::START); -fn covered_code_regions(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<&CodeRegion> { - let body = mir_body(tcx, def_id); - all_coverage_in_mir_body(body) - // Not all coverage statements have an attached code region. - .filter_map(|coverage| coverage.code_region.as_ref()) - .collect() + CoverageIdsInfo { max_counter_id } } fn all_coverage_in_mir_body<'a, 'tcx>( @@ -115,11 +46,3 @@ fn is_inlined(body: &Body<'_>, statement: &Statement<'_>) -> bool { let scope_data = &body.source_scopes[statement.source_info.scope]; scope_data.inlined.is_some() || scope_data.inlined_parent_scope.is_some() } - -/// This function ensures we obtain the correct MIR for the given item irrespective of -/// whether that means const mir or runtime mir. For `const fn` this opts for runtime -/// mir. 
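Since MIR optimization may delete some counter-increment statements, the surviving IDs can be sparse; `coverage_ids_info` above therefore reports the highest surviving ID, and codegen sizes its per-function tables as max ID + 1. A toy illustration (hypothetical IDs):

    // Suppose counters 1, 3 and 4 were optimized away.
    let surviving_ids = [0u32, 2, 5];
    let max_id = surviving_ids.iter().copied().max().unwrap_or(0);
    // Tables must still have a slot for every original counter.
    assert_eq!(max_id + 1, 6);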
-fn mir_body(tcx: TyCtxt<'_>, def_id: DefId) -> &mir::Body<'_> {
- let def = ty::InstanceDef::Item(def_id);
- tcx.instance_mir(def)
-}
diff --git a/compiler/rustc_mir_transform/src/coverage/spans.rs b/compiler/rustc_mir_transform/src/coverage/spans.rs index ed0e104d6..b318134ae 100644
--- a/compiler/rustc_mir_transform/src/coverage/spans.rs
+++ b/compiler/rustc_mir_transform/src/coverage/spans.rs
@@ -1,26 +1,48 @@
-use super::graph::{BasicCoverageBlock, BasicCoverageBlockData, CoverageGraph, START_BCB};
+use std::cell::OnceCell;
 use rustc_data_structures::graph::WithNumNodes;
-use rustc_middle::mir::{
- self, AggregateKind, BasicBlock, FakeReadCause, Rvalue, Statement, StatementKind, Terminator,
- TerminatorKind,
-};
-use rustc_span::source_map::original_sp;
-use rustc_span::{BytePos, ExpnKind, MacroKind, Span, Symbol};
+use rustc_index::IndexVec;
+use rustc_middle::mir;
+use rustc_span::{BytePos, ExpnKind, MacroKind, Span, Symbol, DUMMY_SP};
-use std::cell::OnceCell;
+use super::graph::{BasicCoverageBlock, CoverageGraph, START_BCB};
+
+mod from_mir;
-#[derive(Debug, Copy, Clone)]
-pub(super) enum CoverageStatement {
- Statement(BasicBlock, Span, usize),
- Terminator(BasicBlock, Span),
+pub(super) struct CoverageSpans {
+ /// Map from BCBs to their list of coverage spans.
+ bcb_to_spans: IndexVec<BasicCoverageBlock, Vec<Span>>,
 }
-impl CoverageStatement {
- pub fn span(&self) -> Span {
- match self {
- Self::Statement(_, span, _) | Self::Terminator(_, span) => *span,
+impl CoverageSpans {
+ pub(super) fn generate_coverage_spans(
+ mir_body: &mir::Body<'_>,
+ fn_sig_span: Span,
+ body_span: Span,
+ basic_coverage_blocks: &CoverageGraph,
+ ) -> Self {
+ let coverage_spans = CoverageSpansGenerator::generate_coverage_spans(
+ mir_body,
+ fn_sig_span,
+ body_span,
+ basic_coverage_blocks,
+ );
+
+ // Group the coverage spans by BCB, with the BCBs in sorted order.
+ let mut bcb_to_spans = IndexVec::from_elem_n(Vec::new(), basic_coverage_blocks.num_nodes());
+ for CoverageSpan { bcb, span, .. } in coverage_spans {
+ bcb_to_spans[bcb].push(span);
 }
+
+ Self { bcb_to_spans }
+ }
+
+ pub(super) fn bcb_has_coverage_spans(&self, bcb: BasicCoverageBlock) -> bool {
+ !self.bcb_to_spans[bcb].is_empty()
+ }
+
+ pub(super) fn spans_for_bcb(&self, bcb: BasicCoverageBlock) -> &[Span] {
+ &self.bcb_to_spans[bcb]
 }
 }
@@ -28,87 +50,55 @@ impl CoverageStatement {
 /// references the originating BCB and one or more MIR `Statement`s and/or `Terminator`s.
 /// Initially, the `Span`s come from the `Statement`s and `Terminator`s, but subsequent
 /// transforms can combine adjacent `Span`s and `CoverageSpan`s from the same BCB, merging the
-/// `CoverageStatement` vectors, and the `Span`s to cover the extent of the combined `Span`s.
+/// `merged_spans` vectors, and the `Span`s to cover the extent of the combined `Span`s.
 ///
-/// Note: A `CoverageStatement` merged into another CoverageSpan may come from a `BasicBlock` that
+/// Note: A span merged into another CoverageSpan may come from a `BasicBlock` that
 /// is not part of the `CoverageSpan` bcb if the statement was included because its `Span` matches
 /// or is subsumed by the `Span` associated with this `CoverageSpan`, and its `BasicBlock`
 /// `dominates()` the `BasicBlock`s in this `CoverageSpan`.
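The `bcb_to_spans` table above is a bucket-per-BCB grouping; a miniature sketch with plain `Vec`s standing in for `IndexVec` and `Span` (values hypothetical):

    let num_bcbs = 3;
    let mut bcb_to_spans: Vec<Vec<(u32, u32)>> = vec![Vec::new(); num_bcbs];
    for (bcb, span) in [(0usize, (0u32, 5u32)), (2, (6, 9)), (2, (10, 12))] {
        bcb_to_spans[bcb].push(span);
    }
    assert_eq!(bcb_to_spans[2], vec![(6, 9), (10, 12)]); // spans_for_bcb(2)
    assert!(bcb_to_spans[1].is_empty()); // !bcb_has_coverage_spans(1)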
#[derive(Debug, Clone)] -pub(super) struct CoverageSpan { +struct CoverageSpan { pub span: Span, pub expn_span: Span, pub current_macro_or_none: OnceCell<Option<Symbol>>, pub bcb: BasicCoverageBlock, - pub coverage_statements: Vec<CoverageStatement>, + /// List of all the original spans from MIR that have been merged into this + /// span. Mainly used to precisely skip over gaps when truncating a span. + pub merged_spans: Vec<Span>, pub is_closure: bool, } impl CoverageSpan { pub fn for_fn_sig(fn_sig_span: Span) -> Self { - Self { - span: fn_sig_span, - expn_span: fn_sig_span, - current_macro_or_none: Default::default(), - bcb: START_BCB, - coverage_statements: vec![], - is_closure: false, - } + Self::new(fn_sig_span, fn_sig_span, START_BCB, false) } - pub fn for_statement( - statement: &Statement<'_>, + pub(super) fn new( span: Span, expn_span: Span, bcb: BasicCoverageBlock, - bb: BasicBlock, - stmt_index: usize, + is_closure: bool, ) -> Self { - let is_closure = match statement.kind { - StatementKind::Assign(box (_, Rvalue::Aggregate(box ref kind, _))) => { - matches!(kind, AggregateKind::Closure(_, _) | AggregateKind::Generator(_, _, _)) - } - _ => false, - }; - Self { span, expn_span, current_macro_or_none: Default::default(), bcb, - coverage_statements: vec![CoverageStatement::Statement(bb, span, stmt_index)], + merged_spans: vec![span], is_closure, } } - pub fn for_terminator( - span: Span, - expn_span: Span, - bcb: BasicCoverageBlock, - bb: BasicBlock, - ) -> Self { - Self { - span, - expn_span, - current_macro_or_none: Default::default(), - bcb, - coverage_statements: vec![CoverageStatement::Terminator(bb, span)], - is_closure: false, - } - } - pub fn merge_from(&mut self, mut other: CoverageSpan) { debug_assert!(self.is_mergeable(&other)); self.span = self.span.to(other.span); - self.coverage_statements.append(&mut other.coverage_statements); + self.merged_spans.append(&mut other.merged_spans); } pub fn cutoff_statements_at(&mut self, cutoff_pos: BytePos) { - self.coverage_statements.retain(|covstmt| covstmt.span().hi() <= cutoff_pos); - if let Some(highest_covstmt) = - self.coverage_statements.iter().max_by_key(|covstmt| covstmt.span().hi()) - { - self.span = self.span.with_hi(highest_covstmt.span().hi()); + self.merged_spans.retain(|span| span.hi() <= cutoff_pos); + if let Some(max_hi) = self.merged_spans.iter().map(|span| span.hi()).max() { + self.span = self.span.with_hi(max_hi); } } @@ -139,11 +129,12 @@ impl CoverageSpan { /// If the span is part of a macro, and the macro is visible (expands directly to the given /// body_span), returns the macro name symbol. pub fn visible_macro(&self, body_span: Span) -> Option<Symbol> { - if let Some(current_macro) = self.current_macro() && self - .expn_span - .parent_callsite() - .unwrap_or_else(|| bug!("macro must have a parent")) - .eq_ctxt(body_span) + if let Some(current_macro) = self.current_macro() + && self + .expn_span + .parent_callsite() + .unwrap_or_else(|| bug!("macro must have a parent")) + .eq_ctxt(body_span) { return Some(current_macro); } @@ -162,13 +153,7 @@ impl CoverageSpan { /// * Merge spans that represent continuous (both in source code and control flow), non-branching /// execution /// * Carve out (leave uncovered) any span that will be counted by another MIR (notably, closures) -pub struct CoverageSpans<'a, 'tcx> { - /// The MIR, used to look up `BasicBlockData`. - mir_body: &'a mir::Body<'tcx>, - - /// A `Span` covering the signature of function for the MIR. 
- fn_sig_span: Span, - +struct CoverageSpansGenerator<'a> { /// A `Span` covering the function body of the MIR (typically from left curly brace to right /// curly brace). body_span: Span, @@ -178,7 +163,7 @@ pub struct CoverageSpans<'a, 'tcx> { /// The initial set of `CoverageSpan`s, sorted by `Span` (`lo` and `hi`) and by relative /// dominance between the `BasicCoverageBlock`s of equal `Span`s. - sorted_spans_iter: Option<std::vec::IntoIter<CoverageSpan>>, + sorted_spans_iter: std::vec::IntoIter<CoverageSpan>, /// The current `CoverageSpan` to compare to its `prev`, to possibly merge, discard, force the /// discard of the `prev` (and or `pending_dups`), or keep both (with `prev` moved to @@ -200,9 +185,6 @@ pub struct CoverageSpans<'a, 'tcx> { /// is mutated. prev_original_span: Span, - /// A copy of the expn_span from the prior iteration. - prev_expn_span: Option<Span>, - /// One or more `CoverageSpan`s with the same `Span` but different `BasicCoverageBlock`s, and /// no `BasicCoverageBlock` in this list dominates another `BasicCoverageBlock` in the list. /// If a new `curr` span also fits this criteria (compared to an existing list of @@ -218,7 +200,7 @@ pub struct CoverageSpans<'a, 'tcx> { refined_spans: Vec<CoverageSpan>, } -impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { +impl<'a> CoverageSpansGenerator<'a> { /// Generate a minimal set of `CoverageSpan`s, each representing a contiguous code region to be /// counted. /// @@ -241,109 +223,79 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { /// Note the resulting vector of `CoverageSpan`s may not be fully sorted (and does not need /// to be). pub(super) fn generate_coverage_spans( - mir_body: &'a mir::Body<'tcx>, + mir_body: &mir::Body<'_>, fn_sig_span: Span, // Ensured to be same SourceFile and SyntaxContext as `body_span` body_span: Span, basic_coverage_blocks: &'a CoverageGraph, ) -> Vec<CoverageSpan> { - let mut coverage_spans = CoverageSpans { + let sorted_spans = from_mir::mir_to_initial_sorted_coverage_spans( mir_body, fn_sig_span, body_span, basic_coverage_blocks, - sorted_spans_iter: None, - refined_spans: Vec::with_capacity(basic_coverage_blocks.num_nodes() * 2), + ); + + let coverage_spans = Self { + body_span, + basic_coverage_blocks, + sorted_spans_iter: sorted_spans.into_iter(), some_curr: None, - curr_original_span: Span::with_root_ctxt(BytePos(0), BytePos(0)), + curr_original_span: DUMMY_SP, some_prev: None, - prev_original_span: Span::with_root_ctxt(BytePos(0), BytePos(0)), - prev_expn_span: None, + prev_original_span: DUMMY_SP, pending_dups: Vec::new(), + refined_spans: Vec::with_capacity(basic_coverage_blocks.num_nodes() * 2), }; - let sorted_spans = coverage_spans.mir_to_initial_sorted_coverage_spans(); - - coverage_spans.sorted_spans_iter = Some(sorted_spans.into_iter()); - coverage_spans.to_refined_spans() } - fn mir_to_initial_sorted_coverage_spans(&self) -> Vec<CoverageSpan> { - let mut initial_spans = - Vec::<CoverageSpan>::with_capacity(self.mir_body.basic_blocks.len() * 2); - for (bcb, bcb_data) in self.basic_coverage_blocks.iter_enumerated() { - initial_spans.extend(self.bcb_to_initial_coverage_spans(bcb, bcb_data)); - } - - if initial_spans.is_empty() { - // This can happen if, for example, the function is unreachable (contains only a - // `BasicBlock`(s) with an `Unreachable` terminator). - return initial_spans; - } - - initial_spans.push(CoverageSpan::for_fn_sig(self.fn_sig_span)); - - initial_spans.sort_by(|a, b| { - // First sort by span start. 
- Ord::cmp(&a.span.lo(), &b.span.lo()) - // If span starts are the same, sort by span end in reverse order. - // This ensures that if spans A and B are adjacent in the list, - // and they overlap but are not equal, then either: - // - Span A extends further left, or - // - Both have the same start and span A extends further right - .then_with(|| Ord::cmp(&a.span.hi(), &b.span.hi()).reverse()) - // If both spans are equal, sort the BCBs in dominator order, - // so that dominating BCBs come before other BCBs they dominate. - .then_with(|| self.basic_coverage_blocks.cmp_in_dominator_order(a.bcb, b.bcb)) - // If two spans are otherwise identical, put closure spans first, - // as this seems to be what the refinement step expects. - .then_with(|| Ord::cmp(&a.is_closure, &b.is_closure).reverse()) - }); - - initial_spans - } - /// Iterate through the sorted `CoverageSpan`s, and return the refined list of merged and /// de-duplicated `CoverageSpan`s. fn to_refined_spans(mut self) -> Vec<CoverageSpan> { while self.next_coverage_span() { + // For the first span we don't have `prev` set, so most of the + // span-processing steps don't make sense yet. if self.some_prev.is_none() { debug!(" initial span"); - self.check_invoked_macro_name_span(); - } else if self.curr().is_mergeable(self.prev()) { - debug!(" same bcb (and neither is a closure), merge with prev={:?}", self.prev()); + self.maybe_push_macro_name_span(); + continue; + } + + // The remaining cases assume that `prev` and `curr` are set. + let prev = self.prev(); + let curr = self.curr(); + + if curr.is_mergeable(prev) { + debug!(" same bcb (and neither is a closure), merge with prev={prev:?}"); let prev = self.take_prev(); self.curr_mut().merge_from(prev); - self.check_invoked_macro_name_span(); + self.maybe_push_macro_name_span(); // Note that curr.span may now differ from curr_original_span - } else if self.prev_ends_before_curr() { + } else if prev.span.hi() <= curr.span.lo() { debug!( - " different bcbs and disjoint spans, so keep curr for next iter, and add \ - prev={:?}", - self.prev() + " different bcbs and disjoint spans, so keep curr for next iter, and add prev={prev:?}", ); let prev = self.take_prev(); self.push_refined_span(prev); - self.check_invoked_macro_name_span(); - } else if self.prev().is_closure { + self.maybe_push_macro_name_span(); + } else if prev.is_closure { // drop any equal or overlapping span (`curr`) and keep `prev` to test again in the // next iter debug!( - " curr overlaps a closure (prev). Drop curr and keep prev for next iter. \ - prev={:?}", - self.prev() + " curr overlaps a closure (prev). Drop curr and keep prev for next iter. prev={prev:?}", ); - self.take_curr(); - } else if self.curr().is_closure { + self.take_curr(); // Discards curr. + } else if curr.is_closure { self.carve_out_span_for_closure(); - } else if self.prev_original_span == self.curr().span { + } else if self.prev_original_span == curr.span { // Note that this compares the new (`curr`) span to `prev_original_span`. // In this branch, the actual span byte range of `prev_original_span` is not // important. What is important is knowing whether the new `curr` span was // **originally** the same as the original span of `prev()`. The original spans // reflect their original sort order, and for equal spans, conveys a partial // ordering based on CFG dominator priority. 
- if self.prev().is_macro_expansion() && self.curr().is_macro_expansion() { + if prev.is_macro_expansion() && curr.is_macro_expansion() { // Macros that expand to include branching (such as // `assert_eq!()`, `assert_ne!()`, `info!()`, `debug!()`, or // `trace!()`) typically generate callee spans with identical @@ -357,23 +309,24 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { debug!( " curr and prev are part of a macro expansion, and curr has the same span \ as prev, but is in a different bcb. Drop curr and keep prev for next iter. \ - prev={:?}", - self.prev() + prev={prev:?}", ); - self.take_curr(); + self.take_curr(); // Discards curr. } else { - self.hold_pending_dups_unless_dominated(); + self.update_pending_dups(); } } else { self.cutoff_prev_at_overlapping_curr(); - self.check_invoked_macro_name_span(); + self.maybe_push_macro_name_span(); } } - debug!(" AT END, adding last prev={:?}", self.prev()); let prev = self.take_prev(); - let pending_dups = self.pending_dups.split_off(0); - for dup in pending_dups { + debug!(" AT END, adding last prev={prev:?}"); + + // Take `pending_dups` so that we can drain it while calling self methods. + // It is never used as a field after this point. + for dup in std::mem::take(&mut self.pending_dups) { debug!(" ...adding at least one pending dup={:?}", dup); self.push_refined_span(dup); } @@ -403,91 +356,46 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { } fn push_refined_span(&mut self, covspan: CoverageSpan) { - let len = self.refined_spans.len(); - if len > 0 { - let last = &mut self.refined_spans[len - 1]; - if last.is_mergeable(&covspan) { - debug!( - "merging new refined span with last refined span, last={:?}, covspan={:?}", - last, covspan - ); - last.merge_from(covspan); - return; - } + if let Some(last) = self.refined_spans.last_mut() + && last.is_mergeable(&covspan) + { + // Instead of pushing the new span, merge it with the last refined span. + debug!(?last, ?covspan, "merging new refined span with last refined span"); + last.merge_from(covspan); + } else { + self.refined_spans.push(covspan); } - self.refined_spans.push(covspan) } - fn check_invoked_macro_name_span(&mut self) { - if let Some(visible_macro) = self.curr().visible_macro(self.body_span) { - if !self - .prev_expn_span - .is_some_and(|prev_expn_span| self.curr().expn_span.ctxt() == prev_expn_span.ctxt()) - { - let merged_prefix_len = self.curr_original_span.lo() - self.curr().span.lo(); - let after_macro_bang = - merged_prefix_len + BytePos(visible_macro.as_str().len() as u32 + 1); - if self.curr().span.lo() + after_macro_bang > self.curr().span.hi() { - // Something is wrong with the macro name span; - // return now to avoid emitting malformed mappings. - // FIXME(#117788): Track down why this happens. - return; - } - let mut macro_name_cov = self.curr().clone(); - self.curr_mut().span = - self.curr().span.with_lo(self.curr().span.lo() + after_macro_bang); - macro_name_cov.span = - macro_name_cov.span.with_hi(macro_name_cov.span.lo() + after_macro_bang); - debug!( - " and curr starts a new macro expansion, so add a new span just for \ - the macro `{}!`, new span={:?}", - visible_macro, macro_name_cov - ); - self.push_refined_span(macro_name_cov); - } + /// If `curr` is part of a new macro expansion, carve out and push a separate + /// span that ends just after the macro name and its subsequent `!`. 
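The function below splits the macro name and its `!` off the front of the current span; the same arithmetic with raw byte offsets in place of `Span` (positions hypothetical):

    // A span covering `assert_eq!(a, b)` at bytes 100..116.
    let (lo, hi) = (100u32, 116u32);
    let visible_macro = "assert_eq";
    let after_macro_bang = visible_macro.len() as u32 + 1; // name plus `!`
    let macro_name_span = (lo, lo + after_macro_bang); // covers `assert_eq!`
    let remainder = (lo + after_macro_bang, hi); // covers `(a, b)`
    assert_eq!(macro_name_span, (100, 110));
    assert_eq!(remainder, (110, 116));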
+ fn maybe_push_macro_name_span(&mut self) { + let curr = self.curr(); + + let Some(visible_macro) = curr.visible_macro(self.body_span) else { return }; + if let Some(prev) = &self.some_prev + && prev.expn_span.eq_ctxt(curr.expn_span) + { + return; } - } - // Generate a set of `CoverageSpan`s from the filtered set of `Statement`s and `Terminator`s of - // the `BasicBlock`(s) in the given `BasicCoverageBlockData`. One `CoverageSpan` is generated - // for each `Statement` and `Terminator`. (Note that subsequent stages of coverage analysis will - // merge some `CoverageSpan`s, at which point a `CoverageSpan` may represent multiple - // `Statement`s and/or `Terminator`s.) - fn bcb_to_initial_coverage_spans( - &self, - bcb: BasicCoverageBlock, - bcb_data: &'a BasicCoverageBlockData, - ) -> Vec<CoverageSpan> { - bcb_data - .basic_blocks - .iter() - .flat_map(|&bb| { - let data = &self.mir_body[bb]; - data.statements - .iter() - .enumerate() - .filter_map(move |(index, statement)| { - filtered_statement_span(statement).map(|span| { - CoverageSpan::for_statement( - statement, - function_source_span(span, self.body_span), - span, - bcb, - bb, - index, - ) - }) - }) - .chain(filtered_terminator_span(data.terminator()).map(|span| { - CoverageSpan::for_terminator( - function_source_span(span, self.body_span), - span, - bcb, - bb, - ) - })) - }) - .collect() + let merged_prefix_len = self.curr_original_span.lo() - curr.span.lo(); + let after_macro_bang = merged_prefix_len + BytePos(visible_macro.as_str().len() as u32 + 1); + if self.curr().span.lo() + after_macro_bang > self.curr().span.hi() { + // Something is wrong with the macro name span; + // return now to avoid emitting malformed mappings. + // FIXME(#117788): Track down why this happens. + return; + } + let mut macro_name_cov = curr.clone(); + self.curr_mut().span = curr.span.with_lo(curr.span.lo() + after_macro_bang); + macro_name_cov.span = + macro_name_cov.span.with_hi(macro_name_cov.span.lo() + after_macro_bang); + debug!( + " and curr starts a new macro expansion, so add a new span just for \ + the macro `{visible_macro}!`, new span={macro_name_cov:?}", + ); + self.push_refined_span(macro_name_cov); } fn curr(&self) -> &CoverageSpan { @@ -502,6 +410,12 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { .unwrap_or_else(|| bug!("invalid attempt to unwrap a None some_curr")) } + /// If called, then the next call to `next_coverage_span()` will *not* update `prev` with the + /// `curr` coverage span. + fn take_curr(&mut self) -> CoverageSpan { + self.some_curr.take().unwrap_or_else(|| bug!("invalid attempt to unwrap a None some_curr")) + } + fn prev(&self) -> &CoverageSpan { self.some_prev .as_ref() @@ -527,82 +441,78 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { /// `pending_dups` could have as few as one span) /// In either case, no more spans will match the span of `pending_dups`, so /// add the `pending_dups` if they don't overlap `curr`, and clear the list. 
- fn check_pending_dups(&mut self) { - if let Some(dup) = self.pending_dups.last() && dup.span != self.prev().span { - debug!( - " SAME spans, but pending_dups are NOT THE SAME, so BCBs matched on \ - previous iteration, or prev started a new disjoint span" - ); - if dup.span.hi() <= self.curr().span.lo() { - let pending_dups = self.pending_dups.split_off(0); - for dup in pending_dups.into_iter() { - debug!(" ...adding at least one pending={:?}", dup); - self.push_refined_span(dup); - } - } else { - self.pending_dups.clear(); + fn maybe_flush_pending_dups(&mut self) { + let Some(last_dup) = self.pending_dups.last() else { return }; + if last_dup.span == self.prev().span { + return; + } + + debug!( + " SAME spans, but pending_dups are NOT THE SAME, so BCBs matched on \ + previous iteration, or prev started a new disjoint span" + ); + if last_dup.span.hi() <= self.curr().span.lo() { + // Temporarily steal `pending_dups` into a local, so that we can + // drain it while calling other self methods. + let mut pending_dups = std::mem::take(&mut self.pending_dups); + for dup in pending_dups.drain(..) { + debug!(" ...adding at least one pending={:?}", dup); + self.push_refined_span(dup); } + // The list of dups is now empty, but we can recycle its capacity. + assert!(pending_dups.is_empty() && self.pending_dups.is_empty()); + self.pending_dups = pending_dups; + } else { + self.pending_dups.clear(); } } /// Advance `prev` to `curr` (if any), and `curr` to the next `CoverageSpan` in sorted order. fn next_coverage_span(&mut self) -> bool { if let Some(curr) = self.some_curr.take() { - self.prev_expn_span = Some(curr.expn_span); self.some_prev = Some(curr); self.prev_original_span = self.curr_original_span; } - while let Some(curr) = self.sorted_spans_iter.as_mut().unwrap().next() { + while let Some(curr) = self.sorted_spans_iter.next() { debug!("FOR curr={:?}", curr); - if self.some_prev.is_some() && self.prev_starts_after_next(&curr) { + if let Some(prev) = &self.some_prev && prev.span.lo() > curr.span.lo() { + // Skip curr because prev has already advanced beyond the end of curr. + // This can only happen if a prior iteration updated `prev` to skip past + // a region of code, such as skipping past a closure. debug!( " prev.span starts after curr.span, so curr will be dropped (skipping past \ - closure?); prev={:?}", - self.prev() + closure?); prev={prev:?}", ); } else { // Save a copy of the original span for `curr` in case the `CoverageSpan` is changed // by `self.curr_mut().merge_from(prev)`. self.curr_original_span = curr.span; self.some_curr.replace(curr); - self.check_pending_dups(); + self.maybe_flush_pending_dups(); return true; } } false } - /// If called, then the next call to `next_coverage_span()` will *not* update `prev` with the - /// `curr` coverage span. - fn take_curr(&mut self) -> CoverageSpan { - self.some_curr.take().unwrap_or_else(|| bug!("invalid attempt to unwrap a None some_curr")) - } - - /// Returns true if the curr span should be skipped because prev has already advanced beyond the - /// end of curr. This can only happen if a prior iteration updated `prev` to skip past a region - /// of code, such as skipping past a closure. - fn prev_starts_after_next(&self, next_curr: &CoverageSpan) -> bool { - self.prev().span.lo() > next_curr.span.lo() - } - - /// Returns true if the curr span starts past the end of the prev span, which means they don't - /// overlap, so we now know the prev can be added to the refined coverage spans. 
- fn prev_ends_before_curr(&self) -> bool { - self.prev().span.hi() <= self.curr().span.lo() - } - /// If `prev`s span extends left of the closure (`curr`), carve out the closure's span from /// `prev`'s span. (The closure's coverage counters will be injected when processing the /// closure's own MIR.) Add the portion of the span to the left of the closure; and if the span /// extends to the right of the closure, update `prev` to that portion of the span. For any /// `pending_dups`, repeat the same process. fn carve_out_span_for_closure(&mut self) { - let curr_span = self.curr().span; - let left_cutoff = curr_span.lo(); - let right_cutoff = curr_span.hi(); - let has_pre_closure_span = self.prev().span.lo() < right_cutoff; - let has_post_closure_span = self.prev().span.hi() > right_cutoff; - let mut pending_dups = self.pending_dups.split_off(0); + let prev = self.prev(); + let curr = self.curr(); + + let left_cutoff = curr.span.lo(); + let right_cutoff = curr.span.hi(); + let has_pre_closure_span = prev.span.lo() < right_cutoff; + let has_post_closure_span = prev.span.hi() > right_cutoff; + + // Temporarily steal `pending_dups` into a local, so that we can + // mutate and/or drain it while calling other self methods. + let mut pending_dups = std::mem::take(&mut self.pending_dups); + if has_pre_closure_span { let mut pre_closure = self.prev().clone(); pre_closure.span = pre_closure.span.with_hi(left_cutoff); @@ -616,6 +526,7 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { } self.push_refined_span(pre_closure); } + if has_post_closure_span { // Mutate `prev.span()` to start after the closure (and discard curr). // (**NEVER** update `prev_original_span` because it affects the assumptions @@ -626,12 +537,15 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { debug!(" ...and at least one overlapping dup={:?}", dup); dup.span = dup.span.with_lo(right_cutoff); } - self.pending_dups.append(&mut pending_dups); - let closure_covspan = self.take_curr(); + let closure_covspan = self.take_curr(); // Prevent this curr from becoming prev. self.push_refined_span(closure_covspan); // since self.prev() was already updated } else { pending_dups.clear(); } + + // Restore the modified post-closure spans, or the empty vector's capacity. + assert!(self.pending_dups.is_empty()); + self.pending_dups = pending_dups; } /// Called if `curr.span` equals `prev_original_span` (and potentially equal to all @@ -648,26 +562,28 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { /// neither `CoverageSpan` dominates the other, both (or possibly more than two) are held, /// until their disposition is determined. In this latter case, the `prev` dup is moved into /// `pending_dups` so the new `curr` dup can be moved to `prev` for the next iteration. - fn hold_pending_dups_unless_dominated(&mut self) { + fn update_pending_dups(&mut self) { + let prev_bcb = self.prev().bcb; + let curr_bcb = self.curr().bcb; + // Equal coverage spans are ordered by dominators before dominated (if any), so it should be // impossible for `curr` to dominate any previous `CoverageSpan`. 
- debug_assert!(!self.span_bcb_dominates(self.curr(), self.prev())); + debug_assert!(!self.basic_coverage_blocks.dominates(curr_bcb, prev_bcb)); let initial_pending_count = self.pending_dups.len(); if initial_pending_count > 0 { - let mut pending_dups = self.pending_dups.split_off(0); - pending_dups.retain(|dup| !self.span_bcb_dominates(dup, self.curr())); - self.pending_dups.append(&mut pending_dups); - if self.pending_dups.len() < initial_pending_count { + self.pending_dups + .retain(|dup| !self.basic_coverage_blocks.dominates(dup.bcb, curr_bcb)); + + let n_discarded = initial_pending_count - self.pending_dups.len(); + if n_discarded > 0 { debug!( - " discarded {} of {} pending_dups that dominated curr", - initial_pending_count - self.pending_dups.len(), - initial_pending_count + " discarded {n_discarded} of {initial_pending_count} pending_dups that dominated curr", ); } } - if self.span_bcb_dominates(self.prev(), self.curr()) { + if self.basic_coverage_blocks.dominates(prev_bcb, curr_bcb) { debug!( " different bcbs but SAME spans, and prev dominates curr. Discard prev={:?}", self.prev() @@ -720,7 +636,7 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { if self.pending_dups.is_empty() { let curr_span = self.curr().span; self.prev_mut().cutoff_statements_at(curr_span.lo()); - if self.prev().coverage_statements.is_empty() { + if self.prev().merged_spans.is_empty() { debug!(" ... no non-overlapping statements to add"); } else { debug!(" ... adding modified prev={:?}", self.prev()); @@ -732,109 +648,4 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> { self.pending_dups.clear(); } } - - fn span_bcb_dominates(&self, dom_covspan: &CoverageSpan, covspan: &CoverageSpan) -> bool { - self.basic_coverage_blocks.dominates(dom_covspan.bcb, covspan.bcb) - } -} - -/// If the MIR `Statement` has a span contributive to computing coverage spans, -/// return it; otherwise return `None`. -pub(super) fn filtered_statement_span(statement: &Statement<'_>) -> Option<Span> { - match statement.kind { - // These statements have spans that are often outside the scope of the executed source code - // for their parent `BasicBlock`. - StatementKind::StorageLive(_) - | StatementKind::StorageDead(_) - // Coverage should not be encountered, but don't inject coverage coverage - | StatementKind::Coverage(_) - // Ignore `ConstEvalCounter`s - | StatementKind::ConstEvalCounter - // Ignore `Nop`s - | StatementKind::Nop => None, - - // FIXME(#78546): MIR InstrumentCoverage - Can the source_info.span for `FakeRead` - // statements be more consistent? - // - // FakeReadCause::ForGuardBinding, in this example: - // match somenum { - // x if x < 1 => { ... } - // }... - // The BasicBlock within the match arm code included one of these statements, but the span - // for it covered the `1` in this source. The actual statements have nothing to do with that - // source span: - // FakeRead(ForGuardBinding, _4); - // where `_4` is: - // _4 = &_1; (at the span for the first `x`) - // and `_1` is the `Place` for `somenum`. - // - // If and when the Issue is resolved, remove this special case match pattern: - StatementKind::FakeRead(box (FakeReadCause::ForGuardBinding, _)) => None, - - // Retain spans from all other statements - StatementKind::FakeRead(box (_, _)) // Not including `ForGuardBinding` - | StatementKind::Intrinsic(..) - | StatementKind::Assign(_) - | StatementKind::SetDiscriminant { .. } - | StatementKind::Deinit(..) - | StatementKind::Retag(_, _) - | StatementKind::PlaceMention(..) 
- | StatementKind::AscribeUserType(_, _) => { - Some(statement.source_info.span) - } - } -} - -/// If the MIR `Terminator` has a span contributive to computing coverage spans, -/// return it; otherwise return `None`. -pub(super) fn filtered_terminator_span(terminator: &Terminator<'_>) -> Option<Span> { - match terminator.kind { - // These terminators have spans that don't positively contribute to computing a reasonable - // span of actually executed source code. (For example, SwitchInt terminators extracted from - // an `if condition { block }` has a span that includes the executed block, if true, - // but for coverage, the code region executed, up to *and* through the SwitchInt, - // actually stops before the if's block.) - TerminatorKind::Unreachable // Unreachable blocks are not connected to the MIR CFG - | TerminatorKind::Assert { .. } - | TerminatorKind::Drop { .. } - | TerminatorKind::SwitchInt { .. } - // For `FalseEdge`, only the `real` branch is taken, so it is similar to a `Goto`. - | TerminatorKind::FalseEdge { .. } - | TerminatorKind::Goto { .. } => None, - - // Call `func` operand can have a more specific span when part of a chain of calls - | TerminatorKind::Call { ref func, .. } => { - let mut span = terminator.source_info.span; - if let mir::Operand::Constant(box constant) = func { - if constant.span.lo() > span.lo() { - span = span.with_lo(constant.span.lo()); - } - } - Some(span) - } - - // Retain spans from all other terminators - TerminatorKind::UnwindResume - | TerminatorKind::UnwindTerminate(_) - | TerminatorKind::Return - | TerminatorKind::Yield { .. } - | TerminatorKind::GeneratorDrop - | TerminatorKind::FalseUnwind { .. } - | TerminatorKind::InlineAsm { .. } => { - Some(terminator.source_info.span) - } - } -} - -/// Returns an extrapolated span (pre-expansion[^1]) corresponding to a range -/// within the function's body source. This span is guaranteed to be contained -/// within, or equal to, the `body_span`. If the extrapolated span is not -/// contained within the `body_span`, the `body_span` is returned. -/// -/// [^1]Expansions result from Rust syntax including macros, syntactic sugar, -/// etc.). 
-#[inline] -pub(super) fn function_source_span(span: Span, body_span: Span) -> Span { - let original_span = original_sp(span, body_span).with_ctxt(body_span.ctxt()); - if body_span.contains(original_span) { original_span } else { body_span } } diff --git a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs new file mode 100644 index 000000000..6189e5379 --- /dev/null +++ b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs @@ -0,0 +1,193 @@ +use rustc_data_structures::captures::Captures; +use rustc_middle::mir::{ + self, AggregateKind, FakeReadCause, Rvalue, Statement, StatementKind, Terminator, + TerminatorKind, +}; +use rustc_span::Span; + +use crate::coverage::graph::{BasicCoverageBlock, BasicCoverageBlockData, CoverageGraph}; +use crate::coverage::spans::CoverageSpan; + +pub(super) fn mir_to_initial_sorted_coverage_spans( + mir_body: &mir::Body<'_>, + fn_sig_span: Span, + body_span: Span, + basic_coverage_blocks: &CoverageGraph, +) -> Vec<CoverageSpan> { + let mut initial_spans = Vec::with_capacity(mir_body.basic_blocks.len() * 2); + for (bcb, bcb_data) in basic_coverage_blocks.iter_enumerated() { + initial_spans.extend(bcb_to_initial_coverage_spans(mir_body, body_span, bcb, bcb_data)); + } + + if initial_spans.is_empty() { + // This can happen if, for example, the function is unreachable (contains only a + // `BasicBlock`(s) with an `Unreachable` terminator). + return initial_spans; + } + + initial_spans.push(CoverageSpan::for_fn_sig(fn_sig_span)); + + initial_spans.sort_by(|a, b| { + // First sort by span start. + Ord::cmp(&a.span.lo(), &b.span.lo()) + // If span starts are the same, sort by span end in reverse order. + // This ensures that if spans A and B are adjacent in the list, + // and they overlap but are not equal, then either: + // - Span A extends further left, or + // - Both have the same start and span A extends further right + .then_with(|| Ord::cmp(&a.span.hi(), &b.span.hi()).reverse()) + // If both spans are equal, sort the BCBs in dominator order, + // so that dominating BCBs come before other BCBs they dominate. + .then_with(|| basic_coverage_blocks.cmp_in_dominator_order(a.bcb, b.bcb)) + // If two spans are otherwise identical, put closure spans first, + // as this seems to be what the refinement step expects. + .then_with(|| Ord::cmp(&a.is_closure, &b.is_closure).reverse()) + }); + + initial_spans +} + +// Generate a set of `CoverageSpan`s from the filtered set of `Statement`s and `Terminator`s of +// the `BasicBlock`(s) in the given `BasicCoverageBlockData`. One `CoverageSpan` is generated +// for each `Statement` and `Terminator`. (Note that subsequent stages of coverage analysis will +// merge some `CoverageSpan`s, at which point a `CoverageSpan` may represent multiple +// `Statement`s and/or `Terminator`s.) 
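The sort comparator above can be exercised in isolation; a reduced sketch over plain `(lo, hi)` pairs, with the dominator and closure tie-breakers omitted:

    let mut spans = vec![(4u32, 8u32), (0, 10), (0, 4)];
    spans.sort_by(|a, b| {
        // First by start; for equal starts, the wider span sorts first.
        Ord::cmp(&a.0, &b.0).then_with(|| Ord::cmp(&a.1, &b.1).reverse())
    });
    assert_eq!(spans, vec![(0, 10), (0, 4), (4, 8)]);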
+fn bcb_to_initial_coverage_spans<'a, 'tcx>(
+ mir_body: &'a mir::Body<'tcx>,
+ body_span: Span,
+ bcb: BasicCoverageBlock,
+ bcb_data: &'a BasicCoverageBlockData,
+) -> impl Iterator<Item = CoverageSpan> + Captures<'a> + Captures<'tcx> {
+ bcb_data.basic_blocks.iter().flat_map(move |&bb| {
+ let data = &mir_body[bb];
+
+ let statement_spans = data.statements.iter().filter_map(move |statement| {
+ let expn_span = filtered_statement_span(statement)?;
+ let span = function_source_span(expn_span, body_span);
+
+ Some(CoverageSpan::new(span, expn_span, bcb, is_closure(statement)))
+ });
+
+ let terminator_span = Some(data.terminator()).into_iter().filter_map(move |terminator| {
+ let expn_span = filtered_terminator_span(terminator)?;
+ let span = function_source_span(expn_span, body_span);
+
+ Some(CoverageSpan::new(span, expn_span, bcb, false))
+ });
+
+ statement_spans.chain(terminator_span)
+ })
+}
+
+fn is_closure(statement: &Statement<'_>) -> bool {
+ match statement.kind {
+ StatementKind::Assign(box (_, Rvalue::Aggregate(box ref agg_kind, _))) => match agg_kind {
+ AggregateKind::Closure(_, _) | AggregateKind::Coroutine(_, _, _) => true,
+ _ => false,
+ },
+ _ => false,
+ }
+}
+
+/// If the MIR `Statement` has a span contributive to computing coverage spans,
+/// return it; otherwise return `None`.
+fn filtered_statement_span(statement: &Statement<'_>) -> Option<Span> {
+ match statement.kind {
+ // These statements have spans that are often outside the scope of the executed source code
+ // for their parent `BasicBlock`.
+ StatementKind::StorageLive(_)
+ | StatementKind::StorageDead(_)
+ // Coverage should not be encountered, but don't inject coverage for coverage statements
+ | StatementKind::Coverage(_)
+ // Ignore `ConstEvalCounter`s
+ | StatementKind::ConstEvalCounter
+ // Ignore `Nop`s
+ | StatementKind::Nop => None,
+
+ // FIXME(#78546): MIR InstrumentCoverage - Can the source_info.span for `FakeRead`
+ // statements be more consistent?
+ //
+ // FakeReadCause::ForGuardBinding, in this example:
+ // match somenum {
+ // x if x < 1 => { ... }
+ // }...
+ // The BasicBlock within the match arm code included one of these statements, but the span
+ // for it covered the `1` in this source. The actual statements have nothing to do with that
+ // source span:
+ // FakeRead(ForGuardBinding, _4);
+ // where `_4` is:
+ // _4 = &_1; (at the span for the first `x`)
+ // and `_1` is the `Place` for `somenum`.
+ //
+ // If and when the Issue is resolved, remove this special case match pattern:
+ StatementKind::FakeRead(box (FakeReadCause::ForGuardBinding, _)) => None,
+
+ // Retain spans from all other statements
+ StatementKind::FakeRead(box (_, _)) // Not including `ForGuardBinding`
+ | StatementKind::Intrinsic(..)
+ | StatementKind::Assign(_)
+ | StatementKind::SetDiscriminant { .. }
+ | StatementKind::Deinit(..)
+ | StatementKind::Retag(_, _)
+ | StatementKind::PlaceMention(..)
+ | StatementKind::AscribeUserType(_, _) => {
+ Some(statement.source_info.span)
+ }
+ }
+}
+
+/// If the MIR `Terminator` has a span contributive to computing coverage spans,
+/// return it; otherwise return `None`.
+fn filtered_terminator_span(terminator: &Terminator<'_>) -> Option<Span> {
+ match terminator.kind {
+ // These terminators have spans that don't positively contribute to computing a reasonable
+ // span of actually executed source code.
+ // (For example, a SwitchInt terminator extracted from
+ // an `if condition { block }` has a span that includes the executed block, if true;
+ // but for coverage, the code region executed, up to *and* through the SwitchInt,
+ // actually stops before the if's block.)
+ TerminatorKind::Unreachable // Unreachable blocks are not connected to the MIR CFG
+ | TerminatorKind::Assert { .. }
+ | TerminatorKind::Drop { .. }
+ | TerminatorKind::SwitchInt { .. }
+ // For `FalseEdge`, only the `real` branch is taken, so it is similar to a `Goto`.
+ | TerminatorKind::FalseEdge { .. }
+ | TerminatorKind::Goto { .. } => None,
+
+ // Call `func` operand can have a more specific span when part of a chain of calls
+ | TerminatorKind::Call { ref func, .. } => {
+ let mut span = terminator.source_info.span;
+ if let mir::Operand::Constant(box constant) = func {
+ if constant.span.lo() > span.lo() {
+ span = span.with_lo(constant.span.lo());
+ }
+ }
+ Some(span)
+ }
+
+ // Retain spans from all other terminators
+ TerminatorKind::UnwindResume
+ | TerminatorKind::UnwindTerminate(_)
+ | TerminatorKind::Return
+ | TerminatorKind::Yield { .. }
+ | TerminatorKind::CoroutineDrop
+ | TerminatorKind::FalseUnwind { .. }
+ | TerminatorKind::InlineAsm { .. } => {
+ Some(terminator.source_info.span)
+ }
+ }
+}
+
+/// Returns an extrapolated span (pre-expansion[^1]) corresponding to a range
+/// within the function's body source. This span is guaranteed to be contained
+/// within, or equal to, the `body_span`. If the extrapolated span is not
+/// contained within the `body_span`, the `body_span` is returned.
+///
+/// [^1]: Expansions result from Rust syntax (including macros, syntactic sugar,
+/// etc.).
+#[inline]
+fn function_source_span(span: Span, body_span: Span) -> Span {
+ use rustc_span::source_map::original_sp;
+
+ let original_span = original_sp(span, body_span).with_ctxt(body_span.ctxt());
+ if body_span.contains(original_span) { original_span } else { body_span }
+}
diff --git a/compiler/rustc_mir_transform/src/coverage/tests.rs b/compiler/rustc_mir_transform/src/coverage/tests.rs index 4a066ed3a..702fe5f56 100644
--- a/compiler/rustc_mir_transform/src/coverage/tests.rs
+++ b/compiler/rustc_mir_transform/src/coverage/tests.rs
@@ -25,8 +25,7 @@
 //! to: `rustc_span::create_default_session_globals_then(|| { test_here(); })`.
use super::counters; -use super::graph; -use super::spans; +use super::graph::{self, BasicCoverageBlock}; use coverage_test_macros::let_bcb; @@ -242,7 +241,7 @@ fn print_coverage_graphviz( " {:?} [label=\"{:?}: {}\"];\n{}", bcb, bcb, - bcb_data.terminator(mir_body).kind.name(), + mir_body[bcb_data.last_bb()].terminator().kind.name(), basic_coverage_blocks .successors(bcb) .map(|successor| { format!(" {:?} -> {:?};", bcb, successor) }) @@ -629,7 +628,7 @@ fn test_traverse_coverage_with_loops() { let basic_coverage_blocks = graph::CoverageGraph::from_mir(&mir_body); let mut traversed_in_order = Vec::new(); let mut traversal = graph::TraverseCoverageGraphWithLoops::new(&basic_coverage_blocks); - while let Some(bcb) = traversal.next(&basic_coverage_blocks) { + while let Some(bcb) = traversal.next() { traversed_in_order.push(bcb); } @@ -644,41 +643,18 @@ fn test_traverse_coverage_with_loops() { ); } -fn synthesize_body_span_from_terminators(mir_body: &Body<'_>) -> Span { - let mut some_span: Option<Span> = None; - for (_, data) in mir_body.basic_blocks.iter_enumerated() { - let term_span = data.terminator().source_info.span; - if let Some(span) = some_span.as_mut() { - *span = span.to(term_span); - } else { - some_span = Some(term_span) - } - } - some_span.expect("body must have at least one BasicBlock") -} - #[test] fn test_make_bcb_counters() { rustc_span::create_default_session_globals_then(|| { let mir_body = goto_switchint(); - let body_span = synthesize_body_span_from_terminators(&mir_body); - let mut basic_coverage_blocks = graph::CoverageGraph::from_mir(&mir_body); - let mut coverage_spans = Vec::new(); - for (bcb, data) in basic_coverage_blocks.iter_enumerated() { - if let Some(span) = spans::filtered_terminator_span(data.terminator(&mir_body)) { - coverage_spans.push(spans::CoverageSpan::for_terminator( - spans::function_source_span(span, body_span), - span, - bcb, - data.last_bb(), - )); - } - } + let basic_coverage_blocks = graph::CoverageGraph::from_mir(&mir_body); + // Historically this test would use `spans` internals to set up fake + // coverage spans for BCBs 1 and 2. Now we skip that step and just tell + // BCB counter construction that those BCBs have spans. 
+ let bcb_has_coverage_spans = |bcb: BasicCoverageBlock| (1..=2).contains(&bcb.as_usize()); let mut coverage_counters = counters::CoverageCounters::new(&basic_coverage_blocks); - coverage_counters - .make_bcb_counters(&mut basic_coverage_blocks, &coverage_spans) - .expect("should be Ok"); - assert_eq!(coverage_counters.intermediate_expressions.len(), 0); + coverage_counters.make_bcb_counters(&basic_coverage_blocks, bcb_has_coverage_spans); + assert_eq!(coverage_counters.num_expressions(), 0); let_bcb!(1); assert_eq!( diff --git a/compiler/rustc_mir_transform/src/cross_crate_inline.rs b/compiler/rustc_mir_transform/src/cross_crate_inline.rs new file mode 100644 index 000000000..261d9dd44 --- /dev/null +++ b/compiler/rustc_mir_transform/src/cross_crate_inline.rs @@ -0,0 +1,130 @@ +use crate::inline; +use crate::pass_manager as pm; +use rustc_attr::InlineAttr; +use rustc_hir::def::DefKind; +use rustc_hir::def_id::LocalDefId; +use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::*; +use rustc_middle::query::Providers; +use rustc_middle::ty::TyCtxt; +use rustc_session::config::InliningThreshold; +use rustc_session::config::OptLevel; + +pub fn provide(providers: &mut Providers) { + providers.cross_crate_inlinable = cross_crate_inlinable; +} + +fn cross_crate_inlinable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { + let codegen_fn_attrs = tcx.codegen_fn_attrs(def_id); + // If this has an extern indicator, then this function is globally shared and thus will not + // generate cgu-internal copies which would make it cross-crate inlinable. + if codegen_fn_attrs.contains_extern_indicator() { + return false; + } + + // Obey source annotations first; this is important because it means we can use + // #[inline(never)] to force code generation. + match codegen_fn_attrs.inline { + InlineAttr::Never => return false, + InlineAttr::Hint | InlineAttr::Always => return true, + _ => {} + } + + // This just reproduces the logic from Instance::requires_inline. + match tcx.def_kind(def_id) { + DefKind::Ctor(..) | DefKind::Closure => return true, + DefKind::Fn | DefKind::AssocFn => {} + _ => return false, + } + + // Don't do any inference when incremental compilation is enabled; the additional inlining that + // inference permits also creates more work for small edits. + if tcx.sess.opts.incremental.is_some() { + return false; + } + + // Don't do any inference if codegen optimizations are disabled and also MIR inlining is not + // enabled. This ensures that we do inference even if someone only passes -Zinline-mir, + // which is less confusing than having to also enable -Copt-level=1. 
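+    // (Equivalently: inference runs when codegen optimizations are enabled *or*
+    // the MIR inliner is; it is skipped only when both are off.)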
+ if matches!(tcx.sess.opts.optimize, OptLevel::No) && !pm::should_run_pass(tcx, &inline::Inline) + { + return false; + } + + if !tcx.is_mir_available(def_id) { + return false; + } + + let threshold = match tcx.sess.opts.unstable_opts.cross_crate_inline_threshold { + InliningThreshold::Always => return true, + InliningThreshold::Sometimes(threshold) => threshold, + InliningThreshold::Never => return false, + }; + + let mir = tcx.optimized_mir(def_id); + let mut checker = + CostChecker { tcx, callee_body: mir, calls: 0, statements: 0, landing_pads: 0, resumes: 0 }; + checker.visit_body(mir); + checker.calls == 0 + && checker.resumes == 0 + && checker.landing_pads == 0 + && checker.statements <= threshold +} + +struct CostChecker<'b, 'tcx> { + tcx: TyCtxt<'tcx>, + callee_body: &'b Body<'tcx>, + calls: usize, + statements: usize, + landing_pads: usize, + resumes: usize, +} + +impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { + fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) { + // Don't count StorageLive/StorageDead in the inlining cost. + match statement.kind { + StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Deinit(_) + | StatementKind::Nop => {} + _ => self.statements += 1, + } + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, _: Location) { + let tcx = self.tcx; + match terminator.kind { + TerminatorKind::Drop { ref place, unwind, .. } => { + let ty = place.ty(self.callee_body, tcx).ty; + if !ty.is_trivially_pure_clone_copy() { + self.calls += 1; + if let UnwindAction::Cleanup(_) = unwind { + self.landing_pads += 1; + } + } + } + TerminatorKind::Call { unwind, .. } => { + self.calls += 1; + if let UnwindAction::Cleanup(_) = unwind { + self.landing_pads += 1; + } + } + TerminatorKind::Assert { unwind, .. } => { + self.calls += 1; + if let UnwindAction::Cleanup(_) = unwind { + self.landing_pads += 1; + } + } + TerminatorKind::UnwindResume => self.resumes += 1, + TerminatorKind::InlineAsm { unwind, .. } => { + self.statements += 1; + if let UnwindAction::Cleanup(_) = unwind { + self.landing_pads += 1; + } + } + TerminatorKind::Return => {} + _ => self.statements += 1, + } + } +} diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs index 7b14fef61..81d2bba98 100644 --- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs +++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs @@ -2,14 +2,13 @@ //! //! Currently, this pass only propagates scalar values. 
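For readers unfamiliar with the pass: the scalar-only restriction mirrors the flat lattice the analysis runs on, where each tracked place is either untouched so far, exactly one known scalar, or unknown. Below is a minimal sketch of such a lattice and its join; the types are toys, merely shaped like `FlatSet` from `rustc_mir_dataflow::lattice`, which this file imports.

```rust
/// Toy flat lattice for scalar constant propagation: Bottom ⊑ Elem(x) ⊑ Top.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Flat<T> {
    /// No value has reached this point yet (unreachable so far).
    Bottom,
    /// Exactly one known value reaches this point.
    Elem(T),
    /// Conflicting values reach this point: nothing can be propagated.
    Top,
}

impl<T: PartialEq> Flat<T> {
    /// Combine the facts flowing in from two predecessors.
    fn join(self, other: Self) -> Self {
        match (self, other) {
            (Flat::Bottom, x) | (x, Flat::Bottom) => x,
            (Flat::Elem(a), Flat::Elem(b)) if a == b => Flat::Elem(a),
            // Two different constants (or an already-unknown input): give up.
            _ => Flat::Top,
        }
    }
}

fn main() {
    // Both branches assign 7: the constant survives the merge.
    assert_eq!(Flat::Elem(7).join(Flat::Elem(7)), Flat::Elem(7));
    // The branches disagree: the merged state is unknown.
    assert_eq!(Flat::Elem(7).join(Flat::Elem(8)), Flat::Top);
    // An unreachable predecessor does not destroy information.
    assert_eq!(Flat::Bottom.join(Flat::Elem(7)), Flat::Elem(7));
}
```

Once two control-flow paths disagree on a value, the join degrades it to `Top` and the pass stops propagating it; only places holding a single consistent scalar ever get rewritten into constants.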
-use rustc_const_eval::const_eval::CheckAlignment; -use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable}; +use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable}; use rustc_data_structures::fx::FxHashMap; use rustc_hir::def::DefKind; use rustc_middle::mir::interpret::{AllocId, ConstAllocation, InterpResult, Scalar}; use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; use rustc_middle::mir::*; -use rustc_middle::ty::layout::TyAndLayout; +use rustc_middle::ty::layout::{LayoutOf, TyAndLayout}; use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_mir_dataflow::value_analysis::{ Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace, @@ -17,8 +16,9 @@ use rustc_mir_dataflow::value_analysis::{ use rustc_mir_dataflow::{lattice::FlatSet, Analysis, Results, ResultsVisitor}; use rustc_span::def_id::DefId; use rustc_span::DUMMY_SP; -use rustc_target::abi::{Align, FieldIdx, VariantIdx}; +use rustc_target::abi::{Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT}; +use crate::const_prop::throw_machine_stop_str; use crate::MirPass; // These constants are somewhat random guesses and have not been optimized. @@ -286,9 +286,9 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { let val = match null_op { NullOp::SizeOf if layout.is_sized() => layout.size.bytes(), NullOp::AlignOf if layout.is_sized() => layout.align.abi.bytes(), - NullOp::OffsetOf(fields) => layout - .offset_of_subfield(&self.ecx, fields.iter().map(|f| f.index())) - .bytes(), + NullOp::OffsetOf(fields) => { + layout.offset_of_subfield(&self.ecx, fields.iter()).bytes() + } _ => return ValueOrPlace::Value(FlatSet::Top), }; FlatSet::Elem(Scalar::from_target_usize(val, &self.tcx)) @@ -406,7 +406,8 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(), TrackElem::Discriminant => { let variant = self.ecx.read_discriminant(op).ok()?; - let discr_value = self.ecx.discriminant_for_variant(op.layout, variant).ok()?; + let discr_value = + self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?; Some(discr_value.into()) } TrackElem::DerefLen => { @@ -507,7 +508,8 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { return None; } let enum_ty_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?; - let discr_value = self.ecx.discriminant_for_variant(enum_ty_layout, variant_index).ok()?; + let discr_value = + self.ecx.discriminant_for_variant(enum_ty_layout.ty, variant_index).ok()?; Some(discr_value.to_scalar()) } @@ -554,16 +556,151 @@ impl<'tcx, 'locals> Collector<'tcx, 'locals> { fn try_make_constant( &self, + ecx: &mut InterpCx<'tcx, 'tcx, DummyMachine>, place: Place<'tcx>, state: &State<FlatSet<Scalar>>, map: &Map, ) -> Option<Const<'tcx>> { - let FlatSet::Elem(Scalar::Int(value)) = state.get(place.as_ref(), &map) else { - return None; - }; let ty = place.ty(self.local_decls, self.patch.tcx).ty; - Some(Const::Val(ConstValue::Scalar(value.into()), ty)) + let layout = ecx.layout_of(ty).ok()?; + + if layout.is_zst() { + return Some(Const::zero_sized(ty)); + } + + if layout.is_unsized() { + return None; + } + + let place = map.find(place.as_ref())?; + if layout.abi.is_scalar() + && let Some(value) = propagatable_scalar(place, state, map) + { + return Some(Const::Val(ConstValue::Scalar(value), ty)); + } + + if matches!(layout.abi, Abi::Scalar(..) 
| Abi::ScalarPair(..)) { + let alloc_id = ecx + .intern_with_temp_alloc(layout, |ecx, dest| { + try_write_constant(ecx, dest, place, ty, state, map) + }) + .ok()?; + return Some(Const::Val(ConstValue::Indirect { alloc_id, offset: Size::ZERO }, ty)); + } + + None + } +} + +fn propagatable_scalar( + place: PlaceIndex, + state: &State<FlatSet<Scalar>>, + map: &Map, +) -> Option<Scalar> { + if let FlatSet::Elem(value) = state.get_idx(place, map) && value.try_to_int().is_ok() { + // Do not attempt to propagate pointers, as we may fail to preserve their identity. + Some(value) + } else { + None + } +} + +#[instrument(level = "trace", skip(ecx, state, map))] +fn try_write_constant<'tcx>( + ecx: &mut InterpCx<'_, 'tcx, DummyMachine>, + dest: &PlaceTy<'tcx>, + place: PlaceIndex, + ty: Ty<'tcx>, + state: &State<FlatSet<Scalar>>, + map: &Map, +) -> InterpResult<'tcx> { + let layout = ecx.layout_of(ty)?; + + // Fast path for ZSTs. + if layout.is_zst() { + return Ok(()); + } + + // Fast path for scalars. + if layout.abi.is_scalar() + && let Some(value) = propagatable_scalar(place, state, map) + { + return ecx.write_immediate(Immediate::Scalar(value), dest); + } + + match ty.kind() { + // ZSTs. Nothing to do. + ty::FnDef(..) => {} + + // Those are scalars, must be handled above. + ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => throw_machine_stop_str!("primitive type with provenance"), + + ty::Tuple(elem_tys) => { + for (i, elem) in elem_tys.iter().enumerate() { + let Some(field) = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i))) else { + throw_machine_stop_str!("missing field in tuple") + }; + let field_dest = ecx.project_field(dest, i)?; + try_write_constant(ecx, &field_dest, field, elem, state, map)?; + } + } + + ty::Adt(def, args) => { + if def.is_union() { + throw_machine_stop_str!("cannot propagate unions") + } + + let (variant_idx, variant_def, variant_place, variant_dest) = if def.is_enum() { + let Some(discr) = map.apply(place, TrackElem::Discriminant) else { + throw_machine_stop_str!("missing discriminant for enum") + }; + let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else { + throw_machine_stop_str!("discriminant with provenance") + }; + let discr_bits = discr.assert_bits(discr.size()); + let Some((variant, _)) = def.discriminants(*ecx.tcx).find(|(_, var)| discr_bits == var.val) else { + throw_machine_stop_str!("illegal discriminant for enum") + }; + let Some(variant_place) = map.apply(place, TrackElem::Variant(variant)) else { + throw_machine_stop_str!("missing variant for enum") + }; + let variant_dest = ecx.project_downcast(dest, variant)?; + (variant, def.variant(variant), variant_place, variant_dest) + } else { + (FIRST_VARIANT, def.non_enum_variant(), place, dest.clone()) + }; + + for (i, field) in variant_def.fields.iter_enumerated() { + let ty = field.ty(*ecx.tcx, args); + let Some(field) = map.apply(variant_place, TrackElem::Field(i)) else { + throw_machine_stop_str!("missing field in ADT") + }; + let field_dest = ecx.project_field(&variant_dest, i.as_usize())?; + try_write_constant(ecx, &field_dest, field, ty, state, map)?; + } + ecx.write_discriminant(variant_idx, dest)?; + } + + // Unsupported for now. + ty::Array(_, _) + + // Do not attempt to support indirection in constants. + | ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_) + + | ty::Never + | ty::Foreign(..) + | ty::Alias(..) + | ty::Param(_) + | ty::Bound(..) + | ty::Placeholder(..) + | ty::Closure(..) + | ty::Coroutine(..) + | ty::Dynamic(..) 
=> throw_machine_stop_str!("unsupported type"), + + ty::Error(_) | ty::Infer(..) | ty::CoroutineWitness(..) => bug!(), } + + Ok(()) } impl<'mir, 'tcx> @@ -581,8 +718,13 @@ impl<'mir, 'tcx> ) { match &statement.kind { StatementKind::Assign(box (_, rvalue)) => { - OperandCollector { state, visitor: self, map: &results.analysis.0.map } - .visit_rvalue(rvalue, location); + OperandCollector { + state, + visitor: self, + ecx: &mut results.analysis.0.ecx, + map: &results.analysis.0.map, + } + .visit_rvalue(rvalue, location); } _ => (), } @@ -600,7 +742,12 @@ impl<'mir, 'tcx> // Don't overwrite the assignment if it already uses a constant (to keep the span). } StatementKind::Assign(box (place, _)) => { - if let Some(value) = self.try_make_constant(place, state, &results.analysis.0.map) { + if let Some(value) = self.try_make_constant( + &mut results.analysis.0.ecx, + place, + state, + &results.analysis.0.map, + ) { self.patch.assignments.insert(location, value); } } @@ -615,8 +762,13 @@ impl<'mir, 'tcx> terminator: &'mir Terminator<'tcx>, location: Location, ) { - OperandCollector { state, visitor: self, map: &results.analysis.0.map } - .visit_terminator(terminator, location); + OperandCollector { + state, + visitor: self, + ecx: &mut results.analysis.0.ecx, + map: &results.analysis.0.map, + } + .visit_terminator(terminator, location); } } @@ -671,6 +823,7 @@ impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> { struct OperandCollector<'tcx, 'map, 'locals, 'a> { state: &'a State<FlatSet<Scalar>>, visitor: &'a mut Collector<'tcx, 'locals>, + ecx: &'map mut InterpCx<'tcx, 'tcx, DummyMachine>, map: &'map Map, } @@ -683,7 +836,7 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> { location: Location, ) { if let PlaceElem::Index(local) = elem - && let Some(value) = self.visitor.try_make_constant(local.into(), self.state, self.map) + && let Some(value) = self.visitor.try_make_constant(self.ecx, local.into(), self.state, self.map) { self.visitor.patch.before_effect.insert((location, local.into()), value); } @@ -691,7 +844,9 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> { fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) { if let Some(place) = operand.place() { - if let Some(value) = self.visitor.try_make_constant(place, self.state, self.map) { + if let Some(value) = + self.visitor.try_make_constant(self.ecx, place, self.state, self.map) + { self.visitor.patch.before_effect.insert((location, place), value); } else if !place.projection.is_empty() { // Try to propagate into `Index` projections. @@ -701,7 +856,7 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> { } } -struct DummyMachine; +pub(crate) struct DummyMachine; impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for DummyMachine { rustc_const_eval::interpret::compile_time_machine!(<'mir, 'tcx>); @@ -709,22 +864,12 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm const PANIC_ON_ALLOC_FAIL: bool = true; #[inline(always)] - fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> CheckAlignment { - // We do not check for alignment to avoid having to carry an `Align` - // in `ConstValue::ByRef`. 
-        CheckAlignment::No
+    fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> bool {
+        false // no reason to enforce alignment
     }
 
     fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {
-        unimplemented!()
-    }
-    fn alignment_check_failed(
-        _ecx: &InterpCx<'mir, 'tcx, Self>,
-        _has: Align,
-        _required: Align,
-        _check: CheckAlignment,
-    ) -> interpret::InterpResult<'tcx, ()> {
-        unimplemented!()
+        false
     }
 
     fn before_access_global(
@@ -736,13 +881,13 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
         is_write: bool,
     ) -> InterpResult<'tcx> {
         if is_write {
-            crate::const_prop::throw_machine_stop_str!("can't write to global");
+            throw_machine_stop_str!("can't write to global");
         }
 
         // If the static allocation is mutable, then we can't const prop it as its content
         // might be different at runtime.
         if alloc.inner().mutability.is_mut() {
-            crate::const_prop::throw_machine_stop_str!("can't access mutable globals in ConstProp");
+            throw_machine_stop_str!("can't access mutable globals in ConstProp");
         }
 
         Ok(())
@@ -792,7 +937,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
         _left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
         _right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
     ) -> interpret::InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)> {
-        crate::const_prop::throw_machine_stop_str!("can't do pointer arithmetic");
+        throw_machine_stop_str!("can't do pointer arithmetic");
     }
 
     fn expose_ptr(
diff --git a/compiler/rustc_mir_transform/src/dead_store_elimination.rs b/compiler/rustc_mir_transform/src/dead_store_elimination.rs
index ef1410504..3d74ef7e3 100644
--- a/compiler/rustc_mir_transform/src/dead_store_elimination.rs
+++ b/compiler/rustc_mir_transform/src/dead_store_elimination.rs
@@ -13,10 +13,10 @@
 //!
 use crate::util::is_within_packed;
-use rustc_index::bit_set::BitSet;
 use rustc_middle::mir::visit::Visitor;
 use rustc_middle::mir::*;
 use rustc_middle::ty::TyCtxt;
+use rustc_mir_dataflow::debuginfo::debuginfo_locals;
 use rustc_mir_dataflow::impls::{
     borrowed_locals, LivenessTransferFunction, MaybeTransitiveLiveLocals,
 };
@@ -26,8 +26,15 @@ use rustc_mir_dataflow::Analysis;
 ///
 /// The `borrowed` set must be a `BitSet` of all the locals that are ever borrowed in this body. It
 /// can be generated via the [`borrowed_locals`] function.
-pub fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, borrowed: &BitSet<Local>) {
-    let mut live = MaybeTransitiveLiveLocals::new(borrowed)
+pub fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+    let borrowed_locals = borrowed_locals(body);
+
+    // If the user requests complete debuginfo, mark the locals that appear in it as live, so
+    // we don't remove assignments to them.
+    let mut always_live = debuginfo_locals(body);
+    always_live.union(&borrowed_locals);
+
+    let mut live = MaybeTransitiveLiveLocals::new(&always_live)
         .into_engine(tcx, body)
         .iterate_to_fixpoint()
         .into_results_cursor(body);
@@ -48,7 +55,9 @@ pub fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         for (index, arg) in args.iter().enumerate().rev() {
             if let Operand::Copy(place) = *arg
                 && !place.is_indirect()
-                && !borrowed.contains(place.local)
+                // Do not skip the transformation if the local is in debuginfo, as we do
+                // not really lose any information for this purpose.
+                && !borrowed_locals.contains(place.local)
                 && !state.contains(place.local)
                 // If `place` is a projection of a misaligned field in a packed ADT,
                 // the move may be codegened as a pointer to that field.
@@ -75,7 +84,7 @@ pub fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
                 StatementKind::Assign(box (place, _))
                 | StatementKind::SetDiscriminant { place: box place, .. }
                 | StatementKind::Deinit(box place) => {
-                    if !place.is_indirect() && !borrowed.contains(place.local) {
+                    if !place.is_indirect() && !always_live.contains(place.local) {
                         live.seek_before_primary_effect(loc);
                         if !live.get().contains(place.local) {
                             patch.push(loc);
@@ -126,7 +135,6 @@ impl<'tcx> MirPass<'tcx> for DeadStoreElimination {
     }
 
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
-        let borrowed = borrowed_locals(body);
-        eliminate(tcx, body, &borrowed);
+        eliminate(tcx, body);
     }
 }
diff --git a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs
index 79645310a..990cfb05e 100644
--- a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs
+++ b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs
@@ -44,6 +44,7 @@ impl<'tcx> Visitor<'tcx> for DeduceReadOnly {
                 // Whether mutating through a `&raw const` is allowed is still undecided, so we
                 // disable any sketchy `readonly` optimizations for now.
                 // But we only need to do this if the pointer would point into the argument.
+                // In other words: for indirect places, like `&raw (*local).field`, this surely cannot mutate `local`.
                 !place.is_indirect()
             }
             PlaceContext::NonMutatingUse(..) | PlaceContext::NonUse(..) => {
diff --git a/compiler/rustc_mir_transform/src/deref_separator.rs b/compiler/rustc_mir_transform/src/deref_separator.rs
index 95898b5b7..42be74570 100644
--- a/compiler/rustc_mir_transform/src/deref_separator.rs
+++ b/compiler/rustc_mir_transform/src/deref_separator.rs
@@ -37,7 +37,7 @@ impl<'a, 'tcx> MutVisitor<'tcx> for DerefChecker<'a, 'tcx> {
         for (idx, (p_ref, p_elem)) in place.iter_projections().enumerate() {
             if !p_ref.projection.is_empty() && p_elem == ProjectionElem::Deref {
                 let ty = p_ref.ty(self.local_decls, self.tcx).ty;
-                let temp = self.patcher.new_internal_with_info(
+                let temp = self.patcher.new_local_with_info(
                     ty,
                     self.local_decls[p_ref.local].source_info.span,
                     LocalInfo::DerefTemp,
diff --git a/compiler/rustc_mir_transform/src/dest_prop.rs b/compiler/rustc_mir_transform/src/dest_prop.rs
index d9a132e5c..15502adfb 100644
--- a/compiler/rustc_mir_transform/src/dest_prop.rs
+++ b/compiler/rustc_mir_transform/src/dest_prop.rs
@@ -114,7 +114,7 @@
 //!   approach that only works for some classes of CFGs:
 //! - rustc now has a powerful dataflow analysis framework that can handle forwards and backwards
 //!   analyses efficiently.
-//! - Layout optimizations for generators have been added to improve code generation for
+//! - Layout optimizations for coroutines have been added to improve code generation for
 //!   async/await, which are very similar in spirit to what this optimization does.
 //!
 //! Also, rustc now has a simple NRVO pass (see `nrvo.rs`), which handles a subset of the cases that
@@ -244,7 +244,7 @@ impl<'tcx> MirPass<'tcx> for DestinationPropagation {
         if round_count != 0 {
             // Merging can introduce overlap between moved arguments and/or call destination in
             // unreachable code, which the validator considers ill-formed.
- remove_dead_blocks(tcx, body); + remove_dead_blocks(body); } trace!(round_count); @@ -655,7 +655,7 @@ impl WriteInfo { // `Drop`s create a `&mut` and so are not considered } TerminatorKind::Yield { .. } - | TerminatorKind::GeneratorDrop + | TerminatorKind::CoroutineDrop | TerminatorKind::FalseEdge { .. } | TerminatorKind::FalseUnwind { .. } => { bug!("{:?} not found in this MIR phase", terminator) diff --git a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs index 319fb4eaf..6eb6cb069 100644 --- a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs +++ b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs @@ -95,6 +95,7 @@ pub struct EarlyOtherwiseBranch; impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + // unsound: https://github.com/rust-lang/rust/issues/95162 sess.mir_opt_level() >= 3 && sess.opts.unstable_opts.unsound_mir_opts } diff --git a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs index e51f771e0..1c917a85c 100644 --- a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs +++ b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs @@ -69,7 +69,7 @@ impl<'tcx, 'a> MutVisitor<'tcx> for ElaborateBoxDerefVisitor<'tcx, 'a> { let (unique_ty, nonnull_ty, ptr_ty) = build_ptr_tys(tcx, base_ty.boxed_ty(), self.unique_did, self.nonnull_did); - let ptr_local = self.patch.new_internal(ptr_ty, source_info.span); + let ptr_local = self.patch.new_temp(ptr_ty, source_info.span); self.patch.add_assign( location, diff --git a/compiler/rustc_mir_transform/src/elaborate_drops.rs b/compiler/rustc_mir_transform/src/elaborate_drops.rs index b62d7da2a..59156b242 100644 --- a/compiler/rustc_mir_transform/src/elaborate_drops.rs +++ b/compiler/rustc_mir_transform/src/elaborate_drops.rs @@ -9,9 +9,9 @@ use rustc_mir_dataflow::elaborate_drops::{elaborate_drop, DropFlagState, Unwind} use rustc_mir_dataflow::elaborate_drops::{DropElaborator, DropFlagMode, DropStyle}; use rustc_mir_dataflow::impls::{MaybeInitializedPlaces, MaybeUninitializedPlaces}; use rustc_mir_dataflow::move_paths::{LookupResult, MoveData, MovePathIndex}; +use rustc_mir_dataflow::on_all_children_bits; use rustc_mir_dataflow::on_lookup_result_bits; use rustc_mir_dataflow::MoveDataParamEnv; -use rustc_mir_dataflow::{on_all_children_bits, on_all_drop_children_bits}; use rustc_mir_dataflow::{Analysis, ResultsCursor}; use rustc_span::Span; use rustc_target::abi::{FieldIdx, VariantIdx}; @@ -54,16 +54,10 @@ impl<'tcx> MirPass<'tcx> for ElaborateDrops { let def_id = body.source.def_id(); let param_env = tcx.param_env_reveal_all_normalized(def_id); - let move_data = match MoveData::gather_moves(body, tcx, param_env) { - Ok(move_data) => move_data, - Err((move_data, _)) => { - tcx.sess.delay_span_bug( - body.span, - "No `move_errors` should be allowed in MIR borrowck", - ); - move_data - } - }; + // For types that do not need dropping, the behaviour is trivial. So we only need to track + // init/uninit for types that do need dropping. 
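+        // (Illustration: an `i32` local never needs a drop flag, since dropping it
+        // is a no-op either way; a `String` local needs one exactly when it is
+        // only maybe-initialized at its drop point.)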
+ let move_data = + MoveData::gather_moves(&body, tcx, param_env, |ty| ty.needs_drop(tcx, param_env)); let elaborate_patch = { let env = MoveDataParamEnv { move_data, param_env }; @@ -178,13 +172,19 @@ impl<'a, 'tcx> DropElaborator<'a, 'tcx> for Elaborator<'a, '_, 'tcx> { let mut some_live = false; let mut some_dead = false; let mut children_count = 0; - on_all_drop_children_bits(self.tcx(), self.body(), self.ctxt.env, path, |child| { - let (live, dead) = self.ctxt.init_data.maybe_live_dead(child); - debug!("elaborate_drop: state({:?}) = {:?}", child, (live, dead)); - some_live |= live; - some_dead |= dead; - children_count += 1; - }); + on_all_children_bits( + self.tcx(), + self.body(), + self.ctxt.move_data(), + path, + |child| { + let (live, dead) = self.ctxt.init_data.maybe_live_dead(child); + debug!("elaborate_drop: state({:?}) = {:?}", child, (live, dead)); + some_live |= live; + some_dead |= dead; + children_count += 1; + }, + ); ((some_live, some_dead), children_count != 1) } }; @@ -271,7 +271,7 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { let tcx = self.tcx; let patch = &mut self.patch; debug!("create_drop_flag({:?})", self.body.span); - self.drop_flags[index].get_or_insert_with(|| patch.new_internal(tcx.types.bool, span)); + self.drop_flags[index].get_or_insert_with(|| patch.new_temp(tcx.types.bool, span)); } fn drop_flag(&mut self, index: MovePathIndex) -> Option<Place<'tcx>> { @@ -296,26 +296,36 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { fn collect_drop_flags(&mut self) { for (bb, data) in self.body.basic_blocks.iter_enumerated() { let terminator = data.terminator(); - let place = match terminator.kind { - TerminatorKind::Drop { ref place, .. } => place, - _ => continue, - }; - - self.init_data.seek_before(self.body.terminator_loc(bb)); + let TerminatorKind::Drop { ref place, .. 
} = terminator.kind else { continue }; let path = self.move_data().rev_lookup.find(place.as_ref()); debug!("collect_drop_flags: {:?}, place {:?} ({:?})", bb, place, path); - let path = match path { - LookupResult::Exact(e) => e, - LookupResult::Parent(None) => continue, + match path { + LookupResult::Exact(path) => { + self.init_data.seek_before(self.body.terminator_loc(bb)); + on_all_children_bits(self.tcx, self.body, self.move_data(), path, |child| { + let (maybe_live, maybe_dead) = self.init_data.maybe_live_dead(child); + debug!( + "collect_drop_flags: collecting {:?} from {:?}@{:?} - {:?}", + child, + place, + path, + (maybe_live, maybe_dead) + ); + if maybe_live && maybe_dead { + self.create_drop_flag(child, terminator.source_info.span) + } + }); + } + LookupResult::Parent(None) => {} LookupResult::Parent(Some(parent)) => { - let (_maybe_live, maybe_dead) = self.init_data.maybe_live_dead(parent); - if self.body.local_decls[place.local].is_deref_temp() { continue; } + self.init_data.seek_before(self.body.terminator_loc(bb)); + let (_maybe_live, maybe_dead) = self.init_data.maybe_live_dead(parent); if maybe_dead { self.tcx.sess.delay_span_bug( terminator.source_info.span, @@ -324,80 +334,74 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { ), ); } - continue; } }; - - on_all_drop_children_bits(self.tcx, self.body, self.env, path, |child| { - let (maybe_live, maybe_dead) = self.init_data.maybe_live_dead(child); - debug!( - "collect_drop_flags: collecting {:?} from {:?}@{:?} - {:?}", - child, - place, - path, - (maybe_live, maybe_dead) - ); - if maybe_live && maybe_dead { - self.create_drop_flag(child, terminator.source_info.span) - } - }); } } fn elaborate_drops(&mut self) { + // This function should mirror what `collect_drop_flags` does. for (bb, data) in self.body.basic_blocks.iter_enumerated() { - let loc = Location { block: bb, statement_index: data.statements.len() }; let terminator = data.terminator(); + let TerminatorKind::Drop { place, target, unwind, replace } = terminator.kind else { + continue; + }; - match terminator.kind { - TerminatorKind::Drop { place, target, unwind, replace } => { - self.init_data.seek_before(loc); - match self.move_data().rev_lookup.find(place.as_ref()) { - LookupResult::Exact(path) => { - let unwind = if data.is_cleanup { - Unwind::InCleanup - } else { - match unwind { - UnwindAction::Cleanup(cleanup) => Unwind::To(cleanup), - UnwindAction::Continue => Unwind::To(self.patch.resume_block()), - UnwindAction::Unreachable => { - Unwind::To(self.patch.unreachable_cleanup_block()) - } - UnwindAction::Terminate(reason) => { - debug_assert_ne!( - reason, - UnwindTerminateReason::InCleanup, - "we are not in a cleanup block, InCleanup reason should be impossible" - ); - Unwind::To(self.patch.terminate_block(reason)) - } - } - }; - elaborate_drop( - &mut Elaborator { ctxt: self }, - terminator.source_info, - place, - path, - target, - unwind, - bb, - ) + // This place does not need dropping. It does not have an associated move-path, so the + // match below will conservatively keep an unconditional drop. As that drop is useless, + // just remove it here and now. 
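+            // (For example, `Drop(_1)` where `_1: i32` does nothing at runtime,
+            // so the terminator can be rewritten into a plain `Goto { target }`.)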
+            if !place
+                .ty(&self.body.local_decls, self.tcx)
+                .ty
+                .needs_drop(self.tcx, self.env.param_env)
+            {
+                self.patch.patch_terminator(bb, TerminatorKind::Goto { target });
+                continue;
+            }
+
+            let path = self.move_data().rev_lookup.find(place.as_ref());
+            match path {
+                LookupResult::Exact(path) => {
+                    let unwind = match unwind {
+                        _ if data.is_cleanup => Unwind::InCleanup,
+                        UnwindAction::Cleanup(cleanup) => Unwind::To(cleanup),
+                        UnwindAction::Continue => Unwind::To(self.patch.resume_block()),
+                        UnwindAction::Unreachable => {
+                            Unwind::To(self.patch.unreachable_cleanup_block())
                         }
-                        LookupResult::Parent(..) => {
-                            if !replace {
-                                self.tcx.sess.delay_span_bug(
-                                    terminator.source_info.span,
-                                    format!("drop of untracked value {bb:?}"),
-                                );
-                            }
-                            // A drop and replace behind a pointer/array/whatever.
-                            // The borrow checker requires that these locations are initialized before the assignment,
-                            // so we just leave an unconditional drop.
-                            assert!(!data.is_cleanup);
+                        UnwindAction::Terminate(reason) => {
+                            debug_assert_ne!(
+                                reason,
+                                UnwindTerminateReason::InCleanup,
+                                "we are not in a cleanup block, InCleanup reason should be impossible"
+                            );
+                            Unwind::To(self.patch.terminate_block(reason))
                         }
+                    };
+                    self.init_data.seek_before(self.body.terminator_loc(bb));
+                    elaborate_drop(
+                        &mut Elaborator { ctxt: self },
+                        terminator.source_info,
+                        place,
+                        path,
+                        target,
+                        unwind,
+                        bb,
+                    )
+                }
+                LookupResult::Parent(None) => {}
+                LookupResult::Parent(Some(_)) => {
+                    if !replace {
+                        self.tcx.sess.delay_span_bug(
+                            terminator.source_info.span,
+                            format!("drop of untracked value {bb:?}"),
+                        );
                     }
+                    // A drop and replace behind a pointer/array/whatever.
+                    // The borrow checker requires that these locations are initialized before the assignment,
+                    // so we just leave an unconditional drop.
+                    assert!(!data.is_cleanup);
                 }
-                _ => continue,
             }
         }
     }
diff --git a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs
index d20286084..26fcfad82 100644
--- a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs
+++ b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs
@@ -58,7 +58,7 @@ fn has_ffi_unwind_calls(tcx: TyCtxt<'_>, local_def_id: LocalDefId) -> bool {
     let body_abi = match body_ty.kind() {
         ty::FnDef(..) => body_ty.fn_sig(tcx).abi(),
         ty::Closure(..) => Abi::RustCall,
-        ty::Generator(..) => Abi::Rust,
+        ty::Coroutine(..) => Abi::Rust,
         _ => span_bug!(body.span, "unexpected body ty: {:?}", body_ty),
     };
     let body_can_unwind = layout::fn_can_unwind(tcx, Some(def_id), body_abi);
diff --git a/compiler/rustc_mir_transform/src/gvn.rs b/compiler/rustc_mir_transform/src/gvn.rs
index 56bdc5a17..dce298e92 100644
--- a/compiler/rustc_mir_transform/src/gvn.rs
+++ b/compiler/rustc_mir_transform/src/gvn.rs
@@ -52,19 +52,59 @@
 //!   _a = *_b // _b is &Freeze
 //!   _c = *_b // replaced by _c = _a
 //! ```
+//!
+//! # Determinism of constant propagation
+//!
+//! When registering a new `Value`, we attempt to opportunistically evaluate it as a constant.
+//! The evaluated form is inserted in `evaluated` as an `OpTy` or `None` if evaluation failed.
+//!
+//! The difficulty is the non-deterministic evaluation of MIR constants. Some `Const`s can have
+//! different runtime values each time they are evaluated. This is the case with `Const::Slice`,
+//! which gets a new pointer each time it is evaluated, and with constants that contain a fn
+//! pointer (an `AllocId` pointing to a `GlobalAlloc::Function`), which may resolve to a
+//! different symbol in each codegen unit.
+//!
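+//! For instance, two mentions of the same function as a `fn()` value may be
+//! instantiated in different codegen units, so the resulting pointers are not
+//! guaranteed to compare equal:
+//! ```
+//! fn f() {}
+//! let a = f as fn();
+//! let b = f as fn();
+//! // `a == b` is not guaranteed: each reification may pick a different symbol.
+//! let _ = (a, b);
+//! ```
+//!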
+//! Meanwhile, we want to be able to read indirect constants. For instance:
+//! ```
+//! static A: &'static &'static u8 = &&63;
+//! fn foo() -> u8 {
+//!     **A // We want to replace this by 63.
+//! }
+//! fn bar() -> u8 {
+//!     b"abc"[1] // We want to replace this by `b'b'`.
+//! }
+//! ```
+//!
+//! The `Value::Constant` variant stores a possibly unevaluated constant. Evaluating that constant
+//! may be non-deterministic. When that happens, we assign a disambiguator to ensure that we do not
+//! merge the constants. See the `duplicate_slice` test in `gvn.rs`.
+//!
+//! Moreover, when writing constants back into MIR, we do not write `Const::Slice` or any `Const`
+//! that contains `AllocId`s.
 
+use rustc_const_eval::interpret::{intern_const_alloc_for_constprop, MemoryKind};
+use rustc_const_eval::interpret::{ImmTy, InterpCx, OpTy, Projectable, Scalar};
 use rustc_data_structures::fx::{FxHashMap, FxIndexSet};
 use rustc_data_structures::graph::dominators::Dominators;
+use rustc_hir::def::DefKind;
 use rustc_index::bit_set::BitSet;
 use rustc_index::IndexVec;
 use rustc_macros::newtype_index;
+use rustc_middle::mir::interpret::GlobalAlloc;
 use rustc_middle::mir::visit::*;
 use rustc_middle::mir::*;
-use rustc_middle::ty::{self, Ty, TyCtxt};
-use rustc_target::abi::{VariantIdx, FIRST_VARIANT};
+use rustc_middle::ty::adjustment::PointerCoercion;
+use rustc_middle::ty::layout::LayoutOf;
+use rustc_middle::ty::{self, Ty, TyCtxt, TypeAndMut};
+use rustc_span::def_id::DefId;
+use rustc_span::DUMMY_SP;
+use rustc_target::abi::{self, Abi, Size, VariantIdx, FIRST_VARIANT};
+use std::borrow::Cow;
 
-use crate::ssa::SsaLocals;
+use crate::dataflow_const_prop::DummyMachine;
+use crate::ssa::{AssignedValue, SsaLocals};
 use crate::MirPass;
+use either::Either;
 
 pub struct GVN;
 
@@ -87,21 +127,28 @@ fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
     let dominators = body.basic_blocks.dominators().clone();
 
     let mut state = VnState::new(tcx, param_env, &ssa, &dominators, &body.local_decls);
-    for arg in body.args_iter() {
-        if ssa.is_ssa(arg) {
-            let value = state.new_opaque().unwrap();
-            state.assign(arg, value);
-        }
-    }
-
-    ssa.for_each_assignment_mut(&mut body.basic_blocks, |local, rvalue, location| {
-        let value = state.simplify_rvalue(rvalue, location).or_else(|| state.new_opaque()).unwrap();
-        // FIXME(#112651) `rvalue` may have a subtype to `local`. We can only mark `local` as
-        // reusable if we have an exact type match.
-        if state.local_decls[local].ty == rvalue.ty(state.local_decls, tcx) {
+    ssa.for_each_assignment_mut(
+        body.basic_blocks.as_mut_preserves_cfg(),
+        |local, value, location| {
+            let value = match value {
+                // We do not know anything of this assigned value.
+                AssignedValue::Arg | AssignedValue::Terminator(_) => None,
+                // Try to get some insight.
+                AssignedValue::Rvalue(rvalue) => {
+                    let value = state.simplify_rvalue(rvalue, location);
+                    // FIXME(#112651) `rvalue` may have a subtype to `local`. We can only mark
+                    // `local` as reusable if we have an exact type match.
+                    if state.local_decls[local].ty != rvalue.ty(state.local_decls, tcx) {
+                        return;
+                    }
+                    value
+                }
+            };
+            // `next_opaque` is `Some`, so `new_opaque` must return `Some`.
+            let value = value.or_else(|| state.new_opaque()).unwrap();
             state.assign(local, value);
-        }
-    });
+        },
+    );
 
+    // Stop creating opaques during replacement as it is useless.
+    state.next_opaque = None;
@@ -111,22 +158,33 @@ fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
         let data = &mut body.basic_blocks.as_mut_preserves_cfg()[bb];
         state.visit_basic_block_data(bb, data);
     }
-    let any_replacement = state.any_replacement;
 
     // For each local that is reused (`y` above), we remove its storage statements to avoid any
    // difficulty. Those locals are SSA, so they should be easy to optimize by LLVM without storage
     // statements.
     StorageRemover { tcx, reused_locals: state.reused_locals }.visit_body_preserves_cfg(body);
-
-    if any_replacement {
-        crate::simplify::remove_unused_definitions(body);
-    }
 }
 
 newtype_index! {
     struct VnIndex {}
 }
 
+/// Computing the aggregate's type can be quite slow, so we only keep the minimal amount of
+/// information to reconstruct it when needed.
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
+enum AggregateTy<'tcx> {
+    /// Invariant: this must not be used for an empty array.
+    Array,
+    Tuple,
+    Def(DefId, ty::GenericArgsRef<'tcx>),
+}
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
+enum AddressKind {
+    Ref(BorrowKind),
+    Address(Mutability),
+}
+
 #[derive(Debug, PartialEq, Eq, Hash)]
 enum Value<'tcx> {
     // Root values.
@@ -134,15 +192,21 @@ enum Value<'tcx> {
     /// The `usize` is a counter incremented by `new_opaque`.
     Opaque(usize),
     /// Evaluated or unevaluated constant value.
-    Constant(Const<'tcx>),
+    Constant {
+        value: Const<'tcx>,
+        /// Some constants do not have a deterministic value. To avoid merging two instances of the
+        /// same `Const`, we assign them an additional integer index.
+        disambiguator: usize,
+    },
     /// An aggregate value, either tuple/closure/struct/enum.
     /// This does not contain unions, as we cannot reason with the value.
-    Aggregate(Ty<'tcx>, VariantIdx, Vec<VnIndex>),
+    Aggregate(AggregateTy<'tcx>, VariantIdx, Vec<VnIndex>),
     /// This corresponds to a `[value; count]` expression.
     Repeat(VnIndex, ty::Const<'tcx>),
     /// The address of a place.
     Address {
         place: Place<'tcx>,
+        kind: AddressKind,
         /// Give each borrow and pointer a different provenance, so we don't merge them.
         provenance: usize,
     },
@@ -170,6 +234,7 @@ enum Value<'tcx> {
 
 struct VnState<'body, 'tcx> {
     tcx: TyCtxt<'tcx>,
+    ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
     param_env: ty::ParamEnv<'tcx>,
     local_decls: &'body LocalDecls<'tcx>,
     /// Value stored in each local.
@@ -177,13 +242,14 @@ struct VnState<'body, 'tcx> {
     /// First local to be assigned that value.
     rev_locals: FxHashMap<VnIndex, Vec<Local>>,
     values: FxIndexSet<Value<'tcx>>,
+    /// Values evaluated as constants if possible.
+    evaluated: IndexVec<VnIndex, Option<OpTy<'tcx>>>,
     /// Counter to generate different values.
     /// This is an option to stop creating opaques during replacement.
next_opaque: Option<usize>, ssa: &'body SsaLocals, dominators: &'body Dominators<BasicBlock>, reused_locals: BitSet<Local>, - any_replacement: bool, } impl<'body, 'tcx> VnState<'body, 'tcx> { @@ -196,23 +262,30 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { ) -> Self { VnState { tcx, + ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine), param_env, local_decls, locals: IndexVec::from_elem(None, local_decls), rev_locals: FxHashMap::default(), values: FxIndexSet::default(), + evaluated: IndexVec::new(), next_opaque: Some(0), ssa, dominators, reused_locals: BitSet::new_empty(local_decls.len()), - any_replacement: false, } } #[instrument(level = "trace", skip(self), ret)] fn insert(&mut self, value: Value<'tcx>) -> VnIndex { - let (index, _) = self.values.insert_full(value); - VnIndex::from_usize(index) + let (index, new) = self.values.insert_full(value); + let index = VnIndex::from_usize(index); + if new { + let evaluated = self.eval_to_const(index); + let _index = self.evaluated.push(evaluated); + debug_assert_eq!(index, _index); + } + index } /// Create a new `Value` for which we have no information at all, except that it is distinct @@ -227,9 +300,9 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { /// Create a new `Value::Address` distinct from all the others. #[instrument(level = "trace", skip(self), ret)] - fn new_pointer(&mut self, place: Place<'tcx>) -> Option<VnIndex> { + fn new_pointer(&mut self, place: Place<'tcx>, kind: AddressKind) -> Option<VnIndex> { let next_opaque = self.next_opaque.as_mut()?; - let value = Value::Address { place, provenance: *next_opaque }; + let value = Value::Address { place, kind, provenance: *next_opaque }; *next_opaque += 1; Some(self.insert(value)) } @@ -251,6 +324,343 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { } } + fn insert_constant(&mut self, value: Const<'tcx>) -> Option<VnIndex> { + let disambiguator = if value.is_deterministic() { + // The constant is deterministic, no need to disambiguate. + 0 + } else { + // Multiple mentions of this constant will yield different values, + // so assign a different `disambiguator` to ensure they do not get the same `VnIndex`. + let next_opaque = self.next_opaque.as_mut()?; + let disambiguator = *next_opaque; + *next_opaque += 1; + disambiguator + }; + Some(self.insert(Value::Constant { value, disambiguator })) + } + + fn insert_scalar(&mut self, scalar: Scalar, ty: Ty<'tcx>) -> VnIndex { + self.insert_constant(Const::from_scalar(self.tcx, scalar, ty)) + .expect("scalars are deterministic") + } + + #[instrument(level = "trace", skip(self), ret)] + fn eval_to_const(&mut self, value: VnIndex) -> Option<OpTy<'tcx>> { + use Value::*; + let op = match *self.get(value) { + Opaque(_) => return None, + // Do not bother evaluating repeat expressions. This would uselessly consume memory. + Repeat(..) => return None, + + Constant { ref value, disambiguator: _ } => { + self.ecx.eval_mir_constant(value, None, None).ok()? 
+ } + Aggregate(kind, variant, ref fields) => { + let fields = fields + .iter() + .map(|&f| self.evaluated[f].as_ref()) + .collect::<Option<Vec<_>>>()?; + let ty = match kind { + AggregateTy::Array => { + assert!(fields.len() > 0); + Ty::new_array(self.tcx, fields[0].layout.ty, fields.len() as u64) + } + AggregateTy::Tuple => { + Ty::new_tup_from_iter(self.tcx, fields.iter().map(|f| f.layout.ty)) + } + AggregateTy::Def(def_id, args) => { + self.tcx.type_of(def_id).instantiate(self.tcx, args) + } + }; + let variant = if ty.is_enum() { Some(variant) } else { None }; + let ty = self.ecx.layout_of(ty).ok()?; + if ty.is_zst() { + ImmTy::uninit(ty).into() + } else if matches!(ty.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) { + let dest = self.ecx.allocate(ty, MemoryKind::Stack).ok()?; + let variant_dest = if let Some(variant) = variant { + self.ecx.project_downcast(&dest, variant).ok()? + } else { + dest.clone() + }; + for (field_index, op) in fields.into_iter().enumerate() { + let field_dest = self.ecx.project_field(&variant_dest, field_index).ok()?; + self.ecx.copy_op(op, &field_dest, /*allow_transmute*/ false).ok()?; + } + self.ecx.write_discriminant(variant.unwrap_or(FIRST_VARIANT), &dest).ok()?; + self.ecx.alloc_mark_immutable(dest.ptr().provenance.unwrap()).ok()?; + dest.into() + } else { + return None; + } + } + + Projection(base, elem) => { + let value = self.evaluated[base].as_ref()?; + let elem = match elem { + ProjectionElem::Deref => ProjectionElem::Deref, + ProjectionElem::Downcast(name, read_variant) => { + ProjectionElem::Downcast(name, read_variant) + } + ProjectionElem::Field(f, ty) => ProjectionElem::Field(f, ty), + ProjectionElem::ConstantIndex { offset, min_length, from_end } => { + ProjectionElem::ConstantIndex { offset, min_length, from_end } + } + ProjectionElem::Subslice { from, to, from_end } => { + ProjectionElem::Subslice { from, to, from_end } + } + ProjectionElem::OpaqueCast(ty) => ProjectionElem::OpaqueCast(ty), + ProjectionElem::Subtype(ty) => ProjectionElem::Subtype(ty), + // This should have been replaced by a `ConstantIndex` earlier. + ProjectionElem::Index(_) => return None, + }; + self.ecx.project(value, elem).ok()? + } + Address { place, kind, provenance: _ } => { + if !place.is_indirect_first_projection() { + return None; + } + let local = self.locals[place.local]?; + let pointer = self.evaluated[local].as_ref()?; + let mut mplace = self.ecx.deref_pointer(pointer).ok()?; + for proj in place.projection.iter().skip(1) { + // We have no call stack to associate a local with a value, so we cannot interpret indexing. 
+ if matches!(proj, ProjectionElem::Index(_)) { + return None; + } + mplace = self.ecx.project(&mplace, proj).ok()?; + } + let pointer = mplace.to_ref(&self.ecx); + let ty = match kind { + AddressKind::Ref(bk) => Ty::new_ref( + self.tcx, + self.tcx.lifetimes.re_erased, + ty::TypeAndMut { ty: mplace.layout.ty, mutbl: bk.to_mutbl_lossy() }, + ), + AddressKind::Address(mutbl) => { + Ty::new_ptr(self.tcx, TypeAndMut { ty: mplace.layout.ty, mutbl }) + } + }; + let layout = self.ecx.layout_of(ty).ok()?; + ImmTy::from_immediate(pointer, layout).into() + } + + Discriminant(base) => { + let base = self.evaluated[base].as_ref()?; + let variant = self.ecx.read_discriminant(base).ok()?; + let discr_value = + self.ecx.discriminant_for_variant(base.layout.ty, variant).ok()?; + discr_value.into() + } + Len(slice) => { + let slice = self.evaluated[slice].as_ref()?; + let usize_layout = self.ecx.layout_of(self.tcx.types.usize).unwrap(); + let len = slice.len(&self.ecx).ok()?; + let imm = ImmTy::try_from_uint(len, usize_layout)?; + imm.into() + } + NullaryOp(null_op, ty) => { + let layout = self.ecx.layout_of(ty).ok()?; + if let NullOp::SizeOf | NullOp::AlignOf = null_op && layout.is_unsized() { + return None; + } + let val = match null_op { + NullOp::SizeOf => layout.size.bytes(), + NullOp::AlignOf => layout.align.abi.bytes(), + NullOp::OffsetOf(fields) => { + layout.offset_of_subfield(&self.ecx, fields.iter()).bytes() + } + }; + let usize_layout = self.ecx.layout_of(self.tcx.types.usize).unwrap(); + let imm = ImmTy::try_from_uint(val, usize_layout)?; + imm.into() + } + UnaryOp(un_op, operand) => { + let operand = self.evaluated[operand].as_ref()?; + let operand = self.ecx.read_immediate(operand).ok()?; + let (val, _) = self.ecx.overflowing_unary_op(un_op, &operand).ok()?; + val.into() + } + BinaryOp(bin_op, lhs, rhs) => { + let lhs = self.evaluated[lhs].as_ref()?; + let lhs = self.ecx.read_immediate(lhs).ok()?; + let rhs = self.evaluated[rhs].as_ref()?; + let rhs = self.ecx.read_immediate(rhs).ok()?; + let (val, _) = self.ecx.overflowing_binary_op(bin_op, &lhs, &rhs).ok()?; + val.into() + } + CheckedBinaryOp(bin_op, lhs, rhs) => { + let lhs = self.evaluated[lhs].as_ref()?; + let lhs = self.ecx.read_immediate(lhs).ok()?; + let rhs = self.evaluated[rhs].as_ref()?; + let rhs = self.ecx.read_immediate(rhs).ok()?; + let (val, overflowed) = self.ecx.overflowing_binary_op(bin_op, &lhs, &rhs).ok()?; + let tuple = Ty::new_tup_from_iter( + self.tcx, + [val.layout.ty, self.tcx.types.bool].into_iter(), + ); + let tuple = self.ecx.layout_of(tuple).ok()?; + ImmTy::from_scalar_pair(val.to_scalar(), Scalar::from_bool(overflowed), tuple) + .into() + } + Cast { kind, value, from: _, to } => match kind { + CastKind::IntToInt | CastKind::IntToFloat => { + let value = self.evaluated[value].as_ref()?; + let value = self.ecx.read_immediate(value).ok()?; + let to = self.ecx.layout_of(to).ok()?; + let res = self.ecx.int_to_int_or_float(&value, to).ok()?; + res.into() + } + CastKind::FloatToFloat | CastKind::FloatToInt => { + let value = self.evaluated[value].as_ref()?; + let value = self.ecx.read_immediate(value).ok()?; + let to = self.ecx.layout_of(to).ok()?; + let res = self.ecx.float_to_float_or_int(&value, to).ok()?; + res.into() + } + CastKind::Transmute => { + let value = self.evaluated[value].as_ref()?; + let to = self.ecx.layout_of(to).ok()?; + // `offset` for immediates only supports scalar/scalar-pair ABIs, + // so bail out if the target is not one. 
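+                    // (For instance, transmuting a `u64` immediate to `(u32, u32)`
+                    // goes from `Abi::Scalar` to `Abi::ScalarPair`; mismatched
+                    // shapes like that are rejected just below.)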
+ if value.as_mplace_or_imm().is_right() { + match (value.layout.abi, to.abi) { + (Abi::Scalar(..), Abi::Scalar(..)) => {} + (Abi::ScalarPair(..), Abi::ScalarPair(..)) => {} + _ => return None, + } + } + value.offset(Size::ZERO, to, &self.ecx).ok()? + } + _ => return None, + }, + }; + Some(op) + } + + fn project( + &mut self, + place: PlaceRef<'tcx>, + value: VnIndex, + proj: PlaceElem<'tcx>, + ) -> Option<VnIndex> { + let proj = match proj { + ProjectionElem::Deref => { + let ty = place.ty(self.local_decls, self.tcx).ty; + if let Some(Mutability::Not) = ty.ref_mutability() + && let Some(pointee_ty) = ty.builtin_deref(true) + && pointee_ty.ty.is_freeze(self.tcx, self.param_env) + { + // An immutable borrow `_x` always points to the same value for the + // lifetime of the borrow, so we can merge all instances of `*_x`. + ProjectionElem::Deref + } else { + return None; + } + } + ProjectionElem::Downcast(name, index) => ProjectionElem::Downcast(name, index), + ProjectionElem::Field(f, ty) => { + if let Value::Aggregate(_, _, fields) = self.get(value) { + return Some(fields[f.as_usize()]); + } else if let Value::Projection(outer_value, ProjectionElem::Downcast(_, read_variant)) = self.get(value) + && let Value::Aggregate(_, written_variant, fields) = self.get(*outer_value) + // This pass is not aware of control-flow, so we do not know whether the + // replacement we are doing is actually reachable. We could be in any arm of + // ``` + // match Some(x) { + // Some(y) => /* stuff */, + // None => /* other */, + // } + // ``` + // + // In surface rust, the current statement would be unreachable. + // + // However, from the reference chapter on enums and RFC 2195, + // accessing the wrong variant is not UB if the enum has repr. + // So it's not impossible for a series of MIR opts to generate + // a downcast to an inactive variant. + && written_variant == read_variant + { + return Some(fields[f.as_usize()]); + } + ProjectionElem::Field(f, ty) + } + ProjectionElem::Index(idx) => { + if let Value::Repeat(inner, _) = self.get(value) { + return Some(*inner); + } + let idx = self.locals[idx]?; + ProjectionElem::Index(idx) + } + ProjectionElem::ConstantIndex { offset, min_length, from_end } => { + match self.get(value) { + Value::Repeat(inner, _) => { + return Some(*inner); + } + Value::Aggregate(AggregateTy::Array, _, operands) => { + let offset = if from_end { + operands.len() - offset as usize + } else { + offset as usize + }; + return operands.get(offset).copied(); + } + _ => {} + }; + ProjectionElem::ConstantIndex { offset, min_length, from_end } + } + ProjectionElem::Subslice { from, to, from_end } => { + ProjectionElem::Subslice { from, to, from_end } + } + ProjectionElem::OpaqueCast(ty) => ProjectionElem::OpaqueCast(ty), + ProjectionElem::Subtype(ty) => ProjectionElem::Subtype(ty), + }; + + Some(self.insert(Value::Projection(value, proj))) + } + + /// Simplify the projection chain if we know better. + #[instrument(level = "trace", skip(self))] + fn simplify_place_projection(&mut self, place: &mut Place<'tcx>, location: Location) { + // If the projection is indirect, we treat the local as a value, so can replace it with + // another local. 
+ if place.is_indirect() + && let Some(base) = self.locals[place.local] + && let Some(new_local) = self.try_as_local(base, location) + { + place.local = new_local; + self.reused_locals.insert(new_local); + } + + let mut projection = Cow::Borrowed(&place.projection[..]); + + for i in 0..projection.len() { + let elem = projection[i]; + if let ProjectionElem::Index(idx) = elem + && let Some(idx) = self.locals[idx] + { + if let Some(offset) = self.evaluated[idx].as_ref() + && let Ok(offset) = self.ecx.read_target_usize(offset) + { + projection.to_mut()[i] = ProjectionElem::ConstantIndex { + offset, + min_length: offset + 1, + from_end: false, + }; + } else if let Some(new_idx) = self.try_as_local(idx, location) { + projection.to_mut()[i] = ProjectionElem::Index(new_idx); + self.reused_locals.insert(new_idx); + } + } + } + + if projection.is_owned() { + place.projection = self.tcx.mk_place_elems(&projection); + } + + trace!(?place); + } + /// Represent the *value* which would be read from `place`, and point `place` to a preexisting /// place with the same value (if that already exists). #[instrument(level = "trace", skip(self), ret)] @@ -259,6 +669,8 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { place: &mut Place<'tcx>, location: Location, ) -> Option<VnIndex> { + self.simplify_place_projection(place, location); + // Invariant: `place` and `place_ref` point to the same value, even if they point to // different memory locations. let mut place_ref = place.as_ref(); @@ -273,57 +685,18 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { place_ref = PlaceRef { local, projection: &place.projection[index..] }; } - let proj = match proj { - ProjectionElem::Deref => { - let ty = Place::ty_from( - place.local, - &place.projection[..index], - self.local_decls, - self.tcx, - ) - .ty; - if let Some(Mutability::Not) = ty.ref_mutability() - && let Some(pointee_ty) = ty.builtin_deref(true) - && pointee_ty.ty.is_freeze(self.tcx, self.param_env) - { - // An immutable borrow `_x` always points to the same value for the - // lifetime of the borrow, so we can merge all instances of `*_x`. - ProjectionElem::Deref - } else { - return None; - } - } - ProjectionElem::Field(f, ty) => ProjectionElem::Field(f, ty), - ProjectionElem::Index(idx) => { - let idx = self.locals[idx]?; - ProjectionElem::Index(idx) - } - ProjectionElem::ConstantIndex { offset, min_length, from_end } => { - ProjectionElem::ConstantIndex { offset, min_length, from_end } - } - ProjectionElem::Subslice { from, to, from_end } => { - ProjectionElem::Subslice { from, to, from_end } - } - ProjectionElem::Downcast(name, index) => ProjectionElem::Downcast(name, index), - ProjectionElem::OpaqueCast(ty) => ProjectionElem::OpaqueCast(ty), - ProjectionElem::Subtype(ty) => ProjectionElem::Subtype(ty), - }; - value = self.insert(Value::Projection(value, proj)); + let base = PlaceRef { local: place.local, projection: &place.projection[..index] }; + value = self.project(base, value, proj)?; } - if let Some(local) = self.try_as_local(value, location) - && local != place.local // in case we had no projection to begin with. - { - *place = local.into(); - self.reused_locals.insert(local); - self.any_replacement = true; - } else if place_ref.local != place.local - || place_ref.projection.len() < place.projection.len() - { + if let Some(new_local) = self.try_as_local(value, location) { + place_ref = PlaceRef { local: new_local, projection: &[] }; + } + + if place_ref.local != place.local || place_ref.projection.len() < place.projection.len() { // By the invariant on `place_ref`. 
*place = place_ref.project_deeper(&[], self.tcx); self.reused_locals.insert(place_ref.local); - self.any_replacement = true; } Some(value) @@ -336,12 +709,14 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { location: Location, ) -> Option<VnIndex> { match *operand { - Operand::Constant(ref constant) => Some(self.insert(Value::Constant(constant.const_))), + Operand::Constant(ref mut constant) => { + let const_ = constant.const_.normalize(self.tcx, self.param_env); + self.insert_constant(const_) + } Operand::Copy(ref mut place) | Operand::Move(ref mut place) => { let value = self.simplify_place_value(place, location)?; if let Some(const_) = self.try_as_constant(value) { *operand = Operand::Constant(Box::new(const_)); - self.any_replacement = true; } Some(value) } @@ -370,24 +745,15 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { Value::Repeat(op, amount) } Rvalue::NullaryOp(op, ty) => Value::NullaryOp(op, ty), - Rvalue::Aggregate(box ref kind, ref mut fields) => { - let variant_index = match *kind { - AggregateKind::Array(..) - | AggregateKind::Tuple - | AggregateKind::Closure(..) - | AggregateKind::Generator(..) => FIRST_VARIANT, - AggregateKind::Adt(_, variant_index, _, _, None) => variant_index, - // Do not track unions. - AggregateKind::Adt(_, _, _, _, Some(_)) => return None, - }; - let fields: Option<Vec<_>> = fields - .iter_mut() - .map(|op| self.simplify_operand(op, location).or_else(|| self.new_opaque())) - .collect(); - let ty = rvalue.ty(self.local_decls, self.tcx); - Value::Aggregate(ty, variant_index, fields?) + Rvalue::Aggregate(..) => return self.simplify_aggregate(rvalue, location), + Rvalue::Ref(_, borrow_kind, ref mut place) => { + self.simplify_place_projection(place, location); + return self.new_pointer(*place, AddressKind::Ref(borrow_kind)); + } + Rvalue::AddressOf(mutbl, ref mut place) => { + self.simplify_place_projection(place, location); + return self.new_pointer(*place, AddressKind::Address(mutbl)); } - Rvalue::Ref(.., place) | Rvalue::AddressOf(_, place) => return self.new_pointer(place), // Operations. Rvalue::Len(ref mut place) => { @@ -397,6 +763,14 @@ impl<'body, 'tcx> VnState<'body, 'tcx> { Rvalue::Cast(kind, ref mut value, to) => { let from = value.ty(self.local_decls, self.tcx); let value = self.simplify_operand(value, location)?; + if let CastKind::PointerCoercion( + PointerCoercion::ReifyFnPointer | PointerCoercion::ClosureFnPointer(_), + ) = kind + { + // Each reification of a generic fn may get a different pointer. + // Do not try to merge them. 
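The guard above matters because Rust does not promise a unique address per function item: each `ReifyFnPointer` cast may yield a distinct pointer (for example, across codegen units). A minimal sketch of the observable effect:

```rust
fn generic<T>() {}

fn main() {
    let a = generic::<u32> as fn();
    let b = generic::<u32> as fn();
    // Both casts reify the same item, yet this is not guaranteed to print
    // `true`: each reification may produce a different pointer, which is
    // why GVN treats such casts as opaque instead of merging them.
    println!("{}", a == b);
}
```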
+                return self.new_opaque();
+            }
                 Value::Cast { kind, value, from, to }
             }
             Rvalue::BinaryOp(op, box (ref mut lhs, ref mut rhs)) => {
@@ -415,6 +789,9 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
             }
             Rvalue::Discriminant(ref mut place) => {
                 let place = self.simplify_place_value(place, location)?;
+                if let Some(discr) = self.simplify_discriminant(place) {
+                    return Some(discr);
+                }
                 Value::Discriminant(place)
             }
 
@@ -424,45 +801,182 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
         debug!(?value);
         Some(self.insert(value))
     }
+
+    fn simplify_discriminant(&mut self, place: VnIndex) -> Option<VnIndex> {
+        if let Value::Aggregate(enum_ty, variant, _) = *self.get(place)
+            && let AggregateTy::Def(enum_did, enum_substs) = enum_ty
+            && let DefKind::Enum = self.tcx.def_kind(enum_did)
+        {
+            let enum_ty = self.tcx.type_of(enum_did).instantiate(self.tcx, enum_substs);
+            let discr = self.ecx.discriminant_for_variant(enum_ty, variant).ok()?;
+            return Some(self.insert_scalar(discr.to_scalar(), discr.layout.ty));
+        }
+
+        None
+    }
+
+    fn simplify_aggregate(
+        &mut self,
+        rvalue: &mut Rvalue<'tcx>,
+        location: Location,
+    ) -> Option<VnIndex> {
+        let Rvalue::Aggregate(box ref kind, ref mut fields) = *rvalue else { bug!() };
+
+        let tcx = self.tcx;
+        if fields.is_empty() {
+            let is_zst = match *kind {
+                AggregateKind::Array(..) | AggregateKind::Tuple | AggregateKind::Closure(..) => {
+                    true
+                }
+                // Only enums can be non-ZST.
+                AggregateKind::Adt(did, ..) => tcx.def_kind(did) != DefKind::Enum,
+                // Coroutines are never ZST, as they at least contain the implicit states.
+                AggregateKind::Coroutine(..) => false,
+            };
+
+            if is_zst {
+                let ty = rvalue.ty(self.local_decls, tcx);
+                return self.insert_constant(Const::zero_sized(ty));
+            }
+        }
+
+        let (ty, variant_index) = match *kind {
+            AggregateKind::Array(..) => {
+                assert!(!fields.is_empty());
+                (AggregateTy::Array, FIRST_VARIANT)
+            }
+            AggregateKind::Tuple => {
+                assert!(!fields.is_empty());
+                (AggregateTy::Tuple, FIRST_VARIANT)
+            }
+            AggregateKind::Closure(did, substs) | AggregateKind::Coroutine(did, substs, _) => {
+                (AggregateTy::Def(did, substs), FIRST_VARIANT)
+            }
+            AggregateKind::Adt(did, variant_index, substs, _, None) => {
+                (AggregateTy::Def(did, substs), variant_index)
+            }
+            // Do not track unions.
+            AggregateKind::Adt(_, _, _, _, Some(_)) => return None,
+        };
+
+        let fields: Option<Vec<_>> = fields
+            .iter_mut()
+            .map(|op| self.simplify_operand(op, location).or_else(|| self.new_opaque()))
+            .collect();
+        let fields = fields?;
+
+        if let AggregateTy::Array = ty && fields.len() > 4 {
+            let first = fields[0];
+            if fields.iter().all(|&v| v == first) {
+                let len = ty::Const::from_target_usize(self.tcx, fields.len().try_into().unwrap());
+                if let Some(const_) = self.try_as_constant(first) {
+                    *rvalue = Rvalue::Repeat(Operand::Constant(Box::new(const_)), len);
+                } else if let Some(local) = self.try_as_local(first, location) {
+                    *rvalue = Rvalue::Repeat(Operand::Copy(local.into()), len);
+                    self.reused_locals.insert(local);
+                }
+                return Some(self.insert(Value::Repeat(first, len)));
+            }
+        }
+
+        Some(self.insert(Value::Aggregate(ty, variant_index, fields)))
+    }
+}
+
+fn op_to_prop_const<'tcx>(
+    ecx: &mut InterpCx<'_, 'tcx, DummyMachine>,
+    op: &OpTy<'tcx>,
+) -> Option<ConstValue<'tcx>> {
+    // Do not attempt to propagate unsized locals.
+    if op.layout.is_unsized() {
+        return None;
+    }
+
+    // This constant is a ZST, just return an empty value.
+    if op.layout.is_zst() {
+        return Some(ConstValue::ZeroSized);
+    }
+
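A source-level sketch of the `simplify_aggregate` rewrite above: an array aggregate with more than four identical fields is turned into a `Repeat` rvalue, shrinking the MIR (the function is illustrative):

```rust
// MIR before: _0 = [const 0_u8, const 0_u8, ..., const 0_u8]  (eight operands)
// MIR after:  _0 = [const 0_u8; 8]
fn zeros() -> [u8; 8] {
    [0, 0, 0, 0, 0, 0, 0, 0]
}
```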
+    // Do not synthesize too large constants. Codegen will just memcpy them, which we'd like to
+    // avoid.
+    if !matches!(op.layout.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) {
+        return None;
+    }
+
+    // If this constant has scalar ABI, return it as a `ConstValue::Scalar`.
+    if let Abi::Scalar(abi::Scalar::Initialized { .. }) = op.layout.abi
+        && let Ok(scalar) = ecx.read_scalar(op)
+        && scalar.try_to_int().is_ok()
+    {
+        return Some(ConstValue::Scalar(scalar));
+    }
+
+    // If this constant is already represented as an `Allocation`,
+    // try putting it into global memory to return it.
+    if let Either::Left(mplace) = op.as_mplace_or_imm() {
+        let (size, _align) = ecx.size_and_align_of_mplace(&mplace).ok()??;
+
+        // Do not try interning a value that contains provenance.
+        // Due to https://github.com/rust-lang/rust/issues/79738, doing so could lead to bugs.
+        // FIXME: remove this hack once that issue is fixed.
+        let alloc_ref = ecx.get_ptr_alloc(mplace.ptr(), size).ok()??;
+        if alloc_ref.has_provenance() {
+            return None;
+        }
+
+        let pointer = mplace.ptr().into_pointer_or_addr().ok()?;
+        let (alloc_id, offset) = pointer.into_parts();
+        intern_const_alloc_for_constprop(ecx, alloc_id).ok()?;
+        if matches!(ecx.tcx.global_alloc(alloc_id), GlobalAlloc::Memory(_)) {
+            // `alloc_id` may point to a static. Codegen will choke on an `Indirect` with anything
+            // but `GlobalAlloc::Memory`, so do fall through to copying if needed.
+            // FIXME: find a way to treat this more uniformly
+            // (probably by fixing codegen)
+            return Some(ConstValue::Indirect { alloc_id, offset });
+        }
+    }
+
+    // Everything failed: create a new allocation to hold the data.
+    let alloc_id =
+        ecx.intern_with_temp_alloc(op.layout, |ecx, dest| ecx.copy_op(op, dest, false)).ok()?;
+    let value = ConstValue::Indirect { alloc_id, offset: Size::ZERO };
+
+    // Check that we do not leak a pointer.
+    // Those pointers may lose part of their identity in codegen.
+    // FIXME: remove this hack once https://github.com/rust-lang/rust/issues/79738 is fixed.
+    if ecx.tcx.global_alloc(alloc_id).unwrap_memory().inner().provenance().ptrs().is_empty() {
+        return Some(value);
+    }
+
+    None
+}
 
 impl<'tcx> VnState<'_, 'tcx> {
     /// If `index` is a `Value::Constant`, return the `Constant` to be put in the MIR.
     fn try_as_constant(&mut self, index: VnIndex) -> Option<ConstOperand<'tcx>> {
-        if let Value::Constant(const_) = *self.get(index) {
-            // Some constants may contain pointers. We need to preserve the provenance of these
-            // pointers, but not all constants guarantee this:
-            // - valtrees purposefully do not;
-            // - ConstValue::Slice does not either.
-            match const_ {
-                Const::Ty(c) => match c.kind() {
-                    ty::ConstKind::Value(valtree) => match valtree {
-                        // This is just an integer, keep it.
-                        ty::ValTree::Leaf(_) => {}
-                        ty::ValTree::Branch(_) => return None,
-                    },
-                    ty::ConstKind::Param(..)
-                    | ty::ConstKind::Unevaluated(..)
-                    | ty::ConstKind::Expr(..) => {}
-                    // Should not appear in runtime MIR.
-                    ty::ConstKind::Infer(..)
-                    | ty::ConstKind::Bound(..)
-                    | ty::ConstKind::Placeholder(..)
-                    | ty::ConstKind::Error(..) => bug!(),
-                },
-                Const::Unevaluated(..) => {}
-                // If the same slice appears twice in the MIR, we cannot guarantee that we will
-                // give the same `AllocId` to the data.
-                Const::Val(ConstValue::Slice { .. }, _) => return None,
-                Const::Val(
-                    ConstValue::ZeroSized | ConstValue::Scalar(_) | ConstValue::Indirect { .. },
-                    _,
-                ) => {}
-            }
-            Some(ConstOperand { span: rustc_span::DUMMY_SP, user_ty: None, const_ })
-        } else {
-            None
+        // This was already constant in MIR, do not change it.
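The provenance checks above exist because duplicating a pointer-bearing constant can change observable behavior: each mention of such a constant may be materialized at a fresh address. A sketch of the effect (the constant name is illustrative):

```rust
const A: &u32 = &42;

fn main() {
    // Constants are inlined at each use site, so two mentions of `A` are
    // not guaranteed to point to the same memory. This is why the code
    // below only re-mentions constants whose identity is deterministic.
    let p: *const u32 = A;
    let q: *const u32 = A;
    println!("{}", std::ptr::eq(p, q)); // may print `false`
}
```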
+ if let Value::Constant { value, disambiguator: _ } = *self.get(index) + // If the constant is not deterministic, adding an additional mention of it in MIR will + // not give the same value as the former mention. + && value.is_deterministic() + { + return Some(ConstOperand { span: rustc_span::DUMMY_SP, user_ty: None, const_: value }); + } + + let op = self.evaluated[index].as_ref()?; + if op.layout.is_unsized() { + // Do not attempt to propagate unsized locals. + return None; } + + let value = op_to_prop_const(&mut self.ecx, op)?; + + // Check that we do not leak a pointer. + // Those pointers may lose part of their identity in codegen. + // FIXME: remove this hack once https://github.com/rust-lang/rust/issues/79738 is fixed. + assert!(!value.may_have_provenance(self.tcx, op.layout.size)); + + let const_ = Const::Val(value, op.layout.ty); + Some(ConstOperand { span: rustc_span::DUMMY_SP, user_ty: None, const_ }) } /// If there is a local which is assigned `index`, and its assignment strictly dominates `loc`, @@ -481,27 +995,32 @@ impl<'tcx> MutVisitor<'tcx> for VnState<'_, 'tcx> { self.tcx } + fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, location: Location) { + self.simplify_place_projection(place, location); + } + fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) { self.simplify_operand(operand, location); } fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, location: Location) { - self.super_statement(stmt, location); if let StatementKind::Assign(box (_, ref mut rvalue)) = stmt.kind // Do not try to simplify a constant, it's already in canonical shape. && !matches!(rvalue, Rvalue::Use(Operand::Constant(_))) - && let Some(value) = self.simplify_rvalue(rvalue, location) { - if let Some(const_) = self.try_as_constant(value) { - *rvalue = Rvalue::Use(Operand::Constant(Box::new(const_))); - self.any_replacement = true; - } else if let Some(local) = self.try_as_local(value, location) - && *rvalue != Rvalue::Use(Operand::Move(local.into())) + if let Some(value) = self.simplify_rvalue(rvalue, location) { - *rvalue = Rvalue::Use(Operand::Copy(local.into())); - self.reused_locals.insert(local); - self.any_replacement = true; + if let Some(const_) = self.try_as_constant(value) { + *rvalue = Rvalue::Use(Operand::Constant(Box::new(const_))); + } else if let Some(local) = self.try_as_local(value, location) + && *rvalue != Rvalue::Use(Operand::Move(local.into())) + { + *rvalue = Rvalue::Use(Operand::Copy(local.into())); + self.reused_locals.insert(local); + } } + } else { + self.super_statement(stmt, location); } } } diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index b53e0852c..793dcf0d9 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -14,6 +14,7 @@ use rustc_session::config::OptLevel; use rustc_target::abi::FieldIdx; use rustc_target::spec::abi::Abi; +use crate::cost_checker::CostChecker; use crate::simplify::{remove_dead_blocks, CfgSimplifier}; use crate::util; use crate::MirPass; @@ -22,11 +23,6 @@ use std::ops::{Range, RangeFrom}; pub(crate) mod cycle; -const INSTR_COST: usize = 5; -const CALL_PENALTY: usize = 25; -const LANDINGPAD_PENALTY: usize = 50; -const RESUME_PENALTY: usize = 45; - const TOP_DOWN_DEPTH_LIMIT: usize = 5; pub struct Inline; @@ -63,7 +59,7 @@ impl<'tcx> MirPass<'tcx> for Inline { if inline(tcx, body) { debug!("running simplify cfg on {:?}", body.source); CfgSimplifier::new(body).simplify(); - 
remove_dead_blocks(tcx, body); + remove_dead_blocks(body); deref_finder(tcx, body); } } @@ -79,10 +75,10 @@ fn inline<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool { if body.source.promoted.is_some() { return false; } - // Avoid inlining into generators, since their `optimized_mir` is used for layout computation, + // Avoid inlining into coroutines, since their `optimized_mir` is used for layout computation, // which can create a cycle, even when no attempt is made to inline the function in the other // direction. - if body.generator.is_some() { + if body.coroutine.is_some() { return false; } @@ -169,8 +165,11 @@ impl<'tcx> Inliner<'tcx> { caller_body: &mut Body<'tcx>, callsite: &CallSite<'tcx>, ) -> Result<std::ops::Range<BasicBlock>, &'static str> { + self.check_mir_is_available(caller_body, &callsite.callee)?; + let callee_attrs = self.tcx.codegen_fn_attrs(callsite.callee.def_id()); - self.check_codegen_attributes(callsite, callee_attrs)?; + let cross_crate_inlinable = self.tcx.cross_crate_inlinable(callsite.callee.def_id()); + self.check_codegen_attributes(callsite, callee_attrs, cross_crate_inlinable)?; let terminator = caller_body[callsite.block].terminator.as_ref().unwrap(); let TerminatorKind::Call { args, destination, .. } = &terminator.kind else { bug!() }; @@ -183,9 +182,8 @@ impl<'tcx> Inliner<'tcx> { } } - self.check_mir_is_available(caller_body, &callsite.callee)?; let callee_body = try_instance_mir(self.tcx, callsite.callee.def)?; - self.check_mir_body(callsite, callee_body, callee_attrs)?; + self.check_mir_body(callsite, callee_body, callee_attrs, cross_crate_inlinable)?; if !self.tcx.consider_optimizing(|| { format!("Inline {:?} into {:?}", callsite.callee, caller_body.source) @@ -401,6 +399,7 @@ impl<'tcx> Inliner<'tcx> { &self, callsite: &CallSite<'tcx>, callee_attrs: &CodegenFnAttrs, + cross_crate_inlinable: bool, ) -> Result<(), &'static str> { if let InlineAttr::Never = callee_attrs.inline { return Err("never inline hint"); @@ -414,7 +413,7 @@ impl<'tcx> Inliner<'tcx> { .non_erasable_generics(self.tcx, callsite.callee.def_id()) .next() .is_some(); - if !is_generic && !callee_attrs.requests_inline() { + if !is_generic && !cross_crate_inlinable { return Err("not exported"); } @@ -439,10 +438,13 @@ impl<'tcx> Inliner<'tcx> { return Err("incompatible instruction set"); } - for feature in &callee_attrs.target_features { - if !self.codegen_fn_attrs.target_features.contains(feature) { - return Err("incompatible target feature"); - } + if callee_attrs.target_features != self.codegen_fn_attrs.target_features { + // In general it is not correct to inline a callee with target features that are a + // subset of the caller. This is because the callee might contain calls, and the ABI of + // those calls depends on the target features of the surrounding function. By moving a + // `Call` terminator from one MIR body to another with more target features, we might + // change the ABI of that call! 
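A sketch of that hazard on x86_64 (names are illustrative, not from the patch):

```rust
use std::arch::x86_64::__m256;

// With AVX enabled, a `__m256` argument can travel in a YMM register;
// without it, it is passed through memory. Inlining `with_avx` into a
// caller that lacks AVX would move the call below into a non-AVX body
// and silently change that call's ABI.
#[target_feature(enable = "avx")]
unsafe fn with_avx(v: __m256) {
    takes_vector(v);
}

#[target_feature(enable = "avx")]
unsafe fn takes_vector(_v: __m256) {}
```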
+ return Err("incompatible target features"); } Ok(()) @@ -456,10 +458,11 @@ impl<'tcx> Inliner<'tcx> { callsite: &CallSite<'tcx>, callee_body: &Body<'tcx>, callee_attrs: &CodegenFnAttrs, + cross_crate_inlinable: bool, ) -> Result<(), &'static str> { let tcx = self.tcx; - let mut threshold = if callee_attrs.requests_inline() { + let mut threshold = if cross_crate_inlinable { self.tcx.sess.opts.unstable_opts.inline_mir_hint_threshold.unwrap_or(100) } else { self.tcx.sess.opts.unstable_opts.inline_mir_threshold.unwrap_or(50) @@ -475,13 +478,8 @@ impl<'tcx> Inliner<'tcx> { // FIXME: Give a bonus to functions with only a single caller - let mut checker = CostChecker { - tcx: self.tcx, - param_env: self.param_env, - instance: callsite.callee, - callee_body, - cost: 0, - }; + let mut checker = + CostChecker::new(self.tcx, self.param_env, Some(callsite.callee), callee_body); // Traverse the MIR manually so we can account for the effects of inlining on the CFG. let mut work_list = vec![START_BLOCK]; @@ -503,7 +501,9 @@ impl<'tcx> Inliner<'tcx> { self.tcx, ty::EarlyBinder::bind(&place.ty(callee_body, tcx).ty), ); - if ty.needs_drop(tcx, self.param_env) && let UnwindAction::Cleanup(unwind) = unwind { + if ty.needs_drop(tcx, self.param_env) + && let UnwindAction::Cleanup(unwind) = unwind + { work_list.push(unwind); } } else if callee_attrs.instruction_set != self.codegen_fn_attrs.instruction_set @@ -524,7 +524,7 @@ impl<'tcx> Inliner<'tcx> { // That attribute is often applied to very large functions that exceed LLVM's (very // generous) inlining threshold. Such functions are very poor MIR inlining candidates. // Always inlining #[inline(always)] functions in MIR, on net, slows down the compiler. - let cost = checker.cost; + let cost = checker.cost(); if cost <= threshold { debug!("INLINING {:?} [cost={} <= threshold={}]", callsite, cost, threshold); Ok(()) @@ -616,9 +616,7 @@ impl<'tcx> Inliner<'tcx> { // If there are any locals without storage markers, give them storage only for the // duration of the call. for local in callee_body.vars_and_temps_iter() { - if !callee_body.local_decls[local].internal - && integrator.always_live_locals.contains(local) - { + if integrator.always_live_locals.contains(local) { let new_local = integrator.map_local(local); caller_body[callsite.block].statements.push(Statement { source_info: callsite.source_info, @@ -641,9 +639,7 @@ impl<'tcx> Inliner<'tcx> { n += 1; } for local in callee_body.vars_and_temps_iter().rev() { - if !callee_body.local_decls[local].internal - && integrator.always_live_locals.contains(local) - { + if integrator.always_live_locals.contains(local) { let new_local = integrator.map_local(local); caller_body[block].statements.push(Statement { source_info: callsite.source_info, @@ -801,79 +797,6 @@ impl<'tcx> Inliner<'tcx> { } } -/// Verify that the callee body is compatible with the caller. -/// -/// This visitor mostly computes the inlining cost, -/// but also needs to verify that types match because of normalization failure. -struct CostChecker<'b, 'tcx> { - tcx: TyCtxt<'tcx>, - param_env: ParamEnv<'tcx>, - cost: usize, - callee_body: &'b Body<'tcx>, - instance: ty::Instance<'tcx>, -} - -impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { - fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) { - // Don't count StorageLive/StorageDead in the inlining cost. 
- match statement.kind { - StatementKind::StorageLive(_) - | StatementKind::StorageDead(_) - | StatementKind::Deinit(_) - | StatementKind::Nop => {} - _ => self.cost += INSTR_COST, - } - } - - fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, _: Location) { - let tcx = self.tcx; - match terminator.kind { - TerminatorKind::Drop { ref place, unwind, .. } => { - // If the place doesn't actually need dropping, treat it like a regular goto. - let ty = self.instance.instantiate_mir( - tcx, - ty::EarlyBinder::bind(&place.ty(self.callee_body, tcx).ty), - ); - if ty.needs_drop(tcx, self.param_env) { - self.cost += CALL_PENALTY; - if let UnwindAction::Cleanup(_) = unwind { - self.cost += LANDINGPAD_PENALTY; - } - } else { - self.cost += INSTR_COST; - } - } - TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. } => { - let fn_ty = - self.instance.instantiate_mir(tcx, ty::EarlyBinder::bind(&f.const_.ty())); - self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() && tcx.is_intrinsic(def_id) { - // Don't give intrinsics the extra penalty for calls - INSTR_COST - } else { - CALL_PENALTY - }; - if let UnwindAction::Cleanup(_) = unwind { - self.cost += LANDINGPAD_PENALTY; - } - } - TerminatorKind::Assert { unwind, .. } => { - self.cost += CALL_PENALTY; - if let UnwindAction::Cleanup(_) = unwind { - self.cost += LANDINGPAD_PENALTY; - } - } - TerminatorKind::UnwindResume => self.cost += RESUME_PENALTY, - TerminatorKind::InlineAsm { unwind, .. } => { - self.cost += INSTR_COST; - if let UnwindAction::Cleanup(_) = unwind { - self.cost += LANDINGPAD_PENALTY; - } - } - _ => self.cost += INSTR_COST, - } - } -} - /** * Integrator. * @@ -1010,7 +933,7 @@ impl<'tcx> MutVisitor<'tcx> for Integrator<'_, 'tcx> { } match terminator.kind { - TerminatorKind::GeneratorDrop | TerminatorKind::Yield { .. } => bug!(), + TerminatorKind::CoroutineDrop | TerminatorKind::Yield { .. 
} => bug!(), TerminatorKind::Goto { ref mut target } => { *target = self.map_block(*target); } diff --git a/compiler/rustc_mir_transform/src/instsimplify.rs b/compiler/rustc_mir_transform/src/instsimplify.rs index a6ef2e11a..fbcd6e75a 100644 --- a/compiler/rustc_mir_transform/src/instsimplify.rs +++ b/compiler/rustc_mir_transform/src/instsimplify.rs @@ -93,7 +93,9 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { _ => None, }; - if let Some(new) = new && self.should_simplify(source_info, rvalue) { + if let Some(new) = new + && self.should_simplify(source_info, rvalue) + { *rvalue = new; } } @@ -150,7 +152,8 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { *rvalue = Rvalue::Use(operand.clone()); } else if *kind == CastKind::Transmute { // Transmuting an integer to another integer is just a signedness cast - if let (ty::Int(int), ty::Uint(uint)) | (ty::Uint(uint), ty::Int(int)) = (operand_ty.kind(), cast_ty.kind()) + if let (ty::Int(int), ty::Uint(uint)) | (ty::Uint(uint), ty::Int(int)) = + (operand_ty.kind(), cast_ty.kind()) && int.bit_width() == uint.bit_width() { // The width check isn't strictly necessary, as different widths @@ -172,8 +175,15 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> { for (i, field) in variant.fields.iter().enumerate() { let field_ty = field.ty(self.tcx, args); if field_ty == *cast_ty { - let place = place.project_deeper(&[ProjectionElem::Field(FieldIdx::from_usize(i), *cast_ty)], self.tcx); - let operand = if operand.is_move() { Operand::Move(place) } else { Operand::Copy(place) }; + let place = place.project_deeper( + &[ProjectionElem::Field(FieldIdx::from_usize(i), *cast_ty)], + self.tcx, + ); + let operand = if operand.is_move() { + Operand::Move(place) + } else { + Operand::Copy(place) + }; *rvalue = Rvalue::Use(operand); return; } diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs new file mode 100644 index 000000000..7b918be44 --- /dev/null +++ b/compiler/rustc_mir_transform/src/jump_threading.rs @@ -0,0 +1,759 @@ +//! A jump threading optimization. +//! +//! This optimization seeks to replace join-then-switch control flow patterns by straight jumps +//! X = 0 X = 0 +//! ------------\ /-------- ------------ +//! X = 1 X----X SwitchInt(X) => X = 1 +//! ------------/ \-------- ------------ +//! +//! +//! We proceed by walking the cfg backwards starting from each `SwitchInt` terminator, +//! looking for assignments that will turn the `SwitchInt` into a simple `Goto`. +//! +//! The algorithm maintains a set of replacement conditions: +//! - `conditions[place]` contains `Condition { value, polarity: Eq, target }` +//! if assigning `value` to `place` turns the `SwitchInt` into `Goto { target }`. +//! - `conditions[place]` contains `Condition { value, polarity: Ne, target }` +//! if assigning anything different from `value` to `place` turns the `SwitchInt` +//! into `Goto { target }`. +//! +//! In this file, we denote as `place ?= value` the existence of a replacement condition +//! on `place` with given `value`, irrespective of the polarity and target of that +//! replacement condition. +//! +//! We then walk the CFG backwards transforming the set of conditions. +//! When we find a fulfilling assignment, we record a `ThreadingOpportunity`. +//! All `ThreadingOpportunity`s are applied to the body, by duplicating blocks if required. +//! +//! The optimization search can be very heavy, as it performs a DFS on MIR starting from +//! each `SwitchInt` terminator. To manage the complexity, we: +//! 
- bound the maximum depth by a constant `MAX_BACKTRACK`;
+//! - only traverse `Goto` terminators.
+//!
+//! We try to avoid creating irreducible control-flow by not threading through a loop header.
+//!
+//! Likewise, applying the optimization can create a lot of new MIR, so we bound the instruction
+//! cost by `MAX_COST`.
+
+use rustc_arena::DroplessArena;
+use rustc_data_structures::fx::FxHashSet;
+use rustc_index::bit_set::BitSet;
+use rustc_index::IndexVec;
+use rustc_middle::mir::visit::Visitor;
+use rustc_middle::mir::*;
+use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
+use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
+
+use crate::cost_checker::CostChecker;
+use crate::MirPass;
+
+pub struct JumpThreading;
+
+const MAX_BACKTRACK: usize = 5;
+const MAX_COST: usize = 100;
+const MAX_PLACES: usize = 100;
+
+impl<'tcx> MirPass<'tcx> for JumpThreading {
+    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+        sess.mir_opt_level() >= 4
+    }
+
+    #[instrument(skip_all, level = "debug")]
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        let def_id = body.source.def_id();
+        debug!(?def_id);
+
+        let param_env = tcx.param_env_reveal_all_normalized(def_id);
+        let map = Map::new(tcx, body, Some(MAX_PLACES));
+        let loop_headers = loop_headers(body);
+
+        let arena = DroplessArena::default();
+        let mut finder = TOFinder {
+            tcx,
+            param_env,
+            body,
+            arena: &arena,
+            map: &map,
+            loop_headers: &loop_headers,
+            opportunities: Vec::new(),
+        };
+
+        for (bb, bbdata) in body.basic_blocks.iter_enumerated() {
+            debug!(?bb, term = ?bbdata.terminator());
+            if bbdata.is_cleanup || loop_headers.contains(bb) {
+                continue;
+            }
+            let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { continue };
+            let Some(discr) = discr.place() else { continue };
+            debug!(?discr, ?bb);
+
+            let discr_ty = discr.ty(body, tcx).ty;
+            let Ok(discr_layout) = tcx.layout_of(param_env.and(discr_ty)) else { continue };
+
+            let Some(discr) = finder.map.find(discr.as_ref()) else { continue };
+            debug!(?discr);
+
+            let cost = CostChecker::new(tcx, param_env, None, body);
+
+            let mut state = State::new(ConditionSet::default(), &finder.map);
+
+            let conds = if let Some((value, then, else_)) = targets.as_static_if() {
+                let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else {
+                    continue;
+                };
+                arena.alloc_from_iter([
+                    Condition { value, polarity: Polarity::Eq, target: then },
+                    Condition { value, polarity: Polarity::Ne, target: else_ },
+                ])
+            } else {
+                arena.alloc_from_iter(targets.iter().filter_map(|(value, target)| {
+                    let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
+                    Some(Condition { value, polarity: Polarity::Eq, target })
+                }))
+            };
+            let conds = ConditionSet(conds);
+            state.insert_value_idx(discr, conds, &finder.map);
+
+            finder.find_opportunity(bb, state, cost, 0);
+        }
+
+        let opportunities = finder.opportunities;
+        debug!(?opportunities);
+        if opportunities.is_empty() {
+            return;
+        }
+
+        // Verify that we do not thread through a loop header.
+        for to in opportunities.iter() {
+            assert!(to.chain.iter().all(|&block| !loop_headers.contains(block)));
+        }
+        OpportunitySet::new(body, opportunities).apply(body);
+    }
+}
+
+#[derive(Debug)]
+struct ThreadingOpportunity {
+    /// The list of `BasicBlock`s from the one that found the opportunity to the `SwitchInt`.
+    chain: Vec<BasicBlock>,
+    /// The `SwitchInt` will be replaced by `Goto { target }`.
+ target: BasicBlock, +} + +struct TOFinder<'tcx, 'a> { + tcx: TyCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, + body: &'a Body<'tcx>, + map: &'a Map, + loop_headers: &'a BitSet<BasicBlock>, + /// We use an arena to avoid cloning the slices when cloning `state`. + arena: &'a DroplessArena, + opportunities: Vec<ThreadingOpportunity>, +} + +/// Represent the following statement. If we can prove that the current local is equal/not-equal +/// to `value`, jump to `target`. +#[derive(Copy, Clone, Debug)] +struct Condition { + value: ScalarInt, + polarity: Polarity, + target: BasicBlock, +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +enum Polarity { + Ne, + Eq, +} + +impl Condition { + fn matches(&self, value: ScalarInt) -> bool { + (self.value == value) == (self.polarity == Polarity::Eq) + } + + fn inv(mut self) -> Self { + self.polarity = match self.polarity { + Polarity::Eq => Polarity::Ne, + Polarity::Ne => Polarity::Eq, + }; + self + } +} + +#[derive(Copy, Clone, Debug, Default)] +struct ConditionSet<'a>(&'a [Condition]); + +impl<'a> ConditionSet<'a> { + fn iter(self) -> impl Iterator<Item = Condition> + 'a { + self.0.iter().copied() + } + + fn iter_matches(self, value: ScalarInt) -> impl Iterator<Item = Condition> + 'a { + self.iter().filter(move |c| c.matches(value)) + } + + fn map(self, arena: &'a DroplessArena, f: impl Fn(Condition) -> Condition) -> ConditionSet<'a> { + ConditionSet(arena.alloc_from_iter(self.iter().map(f))) + } +} + +impl<'tcx, 'a> TOFinder<'tcx, 'a> { + fn is_empty(&self, state: &State<ConditionSet<'a>>) -> bool { + state.all(|cs| cs.0.is_empty()) + } + + /// Recursion entry point to find threading opportunities. + #[instrument(level = "trace", skip(self, cost), ret)] + fn find_opportunity( + &mut self, + bb: BasicBlock, + mut state: State<ConditionSet<'a>>, + mut cost: CostChecker<'_, 'tcx>, + depth: usize, + ) { + // Do not thread through loop headers. + if self.loop_headers.contains(bb) { + return; + } + + debug!(cost = ?cost.cost()); + for (statement_index, stmt) in + self.body.basic_blocks[bb].statements.iter().enumerate().rev() + { + if self.is_empty(&state) { + return; + } + + cost.visit_statement(stmt, Location { block: bb, statement_index }); + if cost.cost() > MAX_COST { + return; + } + + // Attempt to turn the `current_condition` on `lhs` into a condition on another place. + self.process_statement(bb, stmt, &mut state); + + // When a statement mutates a place, assignments to that place that happen + // above the mutation cannot fulfill a condition. + // _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`. + // _1 = 6 + if let Some((lhs, tail)) = self.mutated_statement(stmt) { + state.flood_with_tail_elem(lhs.as_ref(), tail, self.map, ConditionSet::default()); + } + } + + if self.is_empty(&state) || depth >= MAX_BACKTRACK { + return; + } + + let last_non_rec = self.opportunities.len(); + + let predecessors = &self.body.basic_blocks.predecessors()[bb]; + if let &[pred] = &predecessors[..] 
&& bb != START_BLOCK {
+            let term = self.body.basic_blocks[pred].terminator();
+            match term.kind {
+                TerminatorKind::SwitchInt { ref discr, ref targets } => {
+                    self.process_switch_int(discr, targets, bb, &mut state);
+                    self.find_opportunity(pred, state, cost, depth + 1);
+                }
+                _ => self.recurse_through_terminator(pred, &state, &cost, depth),
+            }
+        } else {
+            for &pred in predecessors {
+                self.recurse_through_terminator(pred, &state, &cost, depth);
+            }
+        }
+
+        let new_tos = &mut self.opportunities[last_non_rec..];
+        debug!(?new_tos);
+
+        // Try to deduplicate threading opportunities.
+        if new_tos.len() > 1
+            && new_tos.len() == predecessors.len()
+            && predecessors
+                .iter()
+                .zip(new_tos.iter())
+                .all(|(&pred, to)| to.chain == &[pred] && to.target == new_tos[0].target)
+        {
+            // All predecessors have a threading opportunity, and they all point to the same block.
+            debug!(?new_tos, "dedup");
+            let first = &mut new_tos[0];
+            *first = ThreadingOpportunity { chain: vec![bb], target: first.target };
+            self.opportunities.truncate(last_non_rec + 1);
+            return;
+        }
+
+        for op in self.opportunities[last_non_rec..].iter_mut() {
+            op.chain.push(bb);
+        }
+    }
+
+    /// Extract the mutated place from a statement.
+    ///
+    /// This method returns the `Place` so we can flood the state in case of a partial assignment.
+    ///     (_1 as Ok).0 = _5;
+    ///     (_1 as Err).0 = _6;
+    /// We want to ensure that a `SwitchInt((_1 as Ok).0)` does not see the first assignment, as
+    /// the value may have been mangled by the second assignment.
+    ///
+    /// In case we assign to a discriminant, we return `Some(TrackElem::Discriminant)`, so we can
+    /// stop at flooding the discriminant, and preserve the variant fields.
+    ///     (_1 as Some).0 = _6;
+    ///     SetDiscriminant(_1, 1);
+    ///     switchInt((_1 as Some).0)
+    #[instrument(level = "trace", skip(self), ret)]
+    fn mutated_statement(
+        &self,
+        stmt: &Statement<'tcx>,
+    ) -> Option<(Place<'tcx>, Option<TrackElem>)> {
+        match stmt.kind {
+            StatementKind::Assign(box (place, _))
+            | StatementKind::Deinit(box place) => Some((place, None)),
+            StatementKind::SetDiscriminant { box place, variant_index: _ } => {
+                Some((place, Some(TrackElem::Discriminant)))
+            }
+            StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
+                Some((Place::from(local), None))
+            }
+            StatementKind::Retag(..)
+            | StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(..))
+            // copy_nonoverlapping takes pointers and mutates the pointed-to value.
+            | StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(..))
+            | StatementKind::AscribeUserType(..)
+            | StatementKind::Coverage(..)
+            | StatementKind::FakeRead(..)
+            | StatementKind::ConstEvalCounter
+            | StatementKind::PlaceMention(..)
+            | StatementKind::Nop => None,
+        }
+    }
+
+    #[instrument(level = "trace", skip(self))]
+    fn process_operand(
+        &mut self,
+        bb: BasicBlock,
+        lhs: PlaceIndex,
+        rhs: &Operand<'tcx>,
+        state: &mut State<ConditionSet<'a>>,
+    ) -> Option<!> {
+        let register_opportunity = |c: Condition| {
+            debug!(?bb, ?c.target, "register");
+            self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
+        };
+
+        match rhs {
+            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
+            Operand::Constant(constant) => {
+                let conditions = state.try_get_idx(lhs, self.map)?;
+                let constant =
+                    constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
+                conditions.iter_matches(constant).for_each(register_opportunity);
+            }
+            // Transfer the conditions on the copied rhs.
+            Operand::Move(rhs) | Operand::Copy(rhs) => {
+                let rhs = self.map.find(rhs.as_ref())?;
+                state.insert_place_idx(rhs, lhs, self.map);
+            }
+        }
+
+        None
+    }
+
+    #[instrument(level = "trace", skip(self))]
+    fn process_statement(
+        &mut self,
+        bb: BasicBlock,
+        stmt: &Statement<'tcx>,
+        state: &mut State<ConditionSet<'a>>,
+    ) -> Option<!> {
+        let register_opportunity = |c: Condition| {
+            debug!(?bb, ?c.target, "register");
+            self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
+        };
+
+        // Below, `lhs` is the return value of `mutated_statement`,
+        // the place to which `conditions` apply.
+
+        let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| {
+            let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
+            let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
+            let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?;
+            Some(Operand::const_from_scalar(
+                self.tcx,
+                discr.ty,
+                scalar.into(),
+                rustc_span::DUMMY_SP,
+            ))
+        };
+
+        match &stmt.kind {
+            // If we expect `discriminant(place) ?= A`,
+            // we have an opportunity if `variant_index ?= A`.
+            StatementKind::SetDiscriminant { box place, variant_index } => {
+                let discr_target = self.map.find_discr(place.as_ref())?;
+                let enum_ty = place.ty(self.body, self.tcx).ty;
+                let discr = discriminant_for_variant(enum_ty, *variant_index)?;
+                self.process_operand(bb, discr_target, &discr, state)?;
+            }
+            // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
+            StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
+                Operand::Copy(place) | Operand::Move(place),
+            )) => {
+                let conditions = state.try_get(place.as_ref(), self.map)?;
+                conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity);
+            }
+            StatementKind::Assign(box (lhs_place, rhs)) => {
+                if let Some(lhs) = self.map.find(lhs_place.as_ref()) {
+                    match rhs {
+                        Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?,
+                        // Transfer the conditions on the copy rhs.
+                        Rvalue::CopyForDeref(rhs) => {
+                            self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)?
+                        }
+                        Rvalue::Discriminant(rhs) => {
+                            let rhs = self.map.find_discr(rhs.as_ref())?;
+                            state.insert_place_idx(rhs, lhs, self.map);
+                        }
+                        // Process the fields of an aggregate one by one, and its
+                        // discriminant if it is an enum.
+                        Rvalue::Aggregate(box ref kind, ref operands) => {
+                            let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
+                            let lhs = match kind {
+                                // Do not support unions.
+                                AggregateKind::Adt(.., Some(_)) => return None,
+                                AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
+                                    if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
+                                        && let Some(discr_value) = discriminant_for_variant(agg_ty, *variant_index)
+                                    {
+                                        self.process_operand(bb, discr_target, &discr_value, state);
+                                    }
+                                    self.map.apply(lhs, TrackElem::Variant(*variant_index))?
+                                }
+                                _ => lhs,
+                            };
+                            for (field_index, operand) in operands.iter_enumerated() {
+                                if let Some(field) =
+                                    self.map.apply(lhs, TrackElem::Field(field_index))
+                                {
+                                    self.process_operand(bb, field, operand, state);
+                                }
+                            }
+                        }
+                        // Transfer the conditions on the copy rhs, after inverting the polarity.
+                        Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
+                            let conditions = state.try_get_idx(lhs, self.map)?;
+                            let place = self.map.find(place.as_ref())?;
+                            let conds = conditions.map(self.arena, Condition::inv);
+                            state.insert_value_idx(place, conds, self.map);
+                        }
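A runnable sketch of the patterns these arms thread (the function is illustrative; MIR names are approximate):

```rust
// Walking backwards from the final `if`, the condition on its discriminant
// is transferred to `y` (a copy), inverted through `!` onto `x`, and finally
// fulfilled by the constant assignments in the first `if`, so each branch
// can jump straight to the matching arm without re-testing `y`.
fn threads(c: bool) -> u32 {
    let x = if c { true } else { false };
    let y = !x;
    if y { 1 } else { 2 }
}
```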
+                        // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
+                        // Create a condition on `rhs ?= B`.
+                        Rvalue::BinaryOp(
+                            op,
+                            box (
+                                Operand::Move(place) | Operand::Copy(place),
+                                Operand::Constant(value),
+                            )
+                            | box (
+                                Operand::Constant(value),
+                                Operand::Move(place) | Operand::Copy(place),
+                            ),
+                        ) => {
+                            let conditions = state.try_get_idx(lhs, self.map)?;
+                            let place = self.map.find(place.as_ref())?;
+                            let equals = match op {
+                                BinOp::Eq => ScalarInt::TRUE,
+                                BinOp::Ne => ScalarInt::FALSE,
+                                _ => return None,
+                            };
+                            let value = value
+                                .const_
+                                .normalize(self.tcx, self.param_env)
+                                .try_to_scalar_int()?;
+                            let conds = conditions.map(self.arena, |c| Condition {
+                                value,
+                                polarity: if c.matches(equals) {
+                                    Polarity::Eq
+                                } else {
+                                    Polarity::Ne
+                                },
+                                ..c
+                            });
+                            state.insert_value_idx(place, conds, self.map);
+                        }
+
+                        _ => {}
+                    }
+                }
+            }
+            _ => {}
+        }
+
+        None
+    }
+
+    #[instrument(level = "trace", skip(self, cost))]
+    fn recurse_through_terminator(
+        &mut self,
+        bb: BasicBlock,
+        state: &State<ConditionSet<'a>>,
+        cost: &CostChecker<'_, 'tcx>,
+        depth: usize,
+    ) {
+        let register_opportunity = |c: Condition| {
+            debug!(?bb, ?c.target, "register");
+            self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
+        };
+
+        let term = self.body.basic_blocks[bb].terminator();
+        let place_to_flood = match term.kind {
+            // We come from a target, so those are not possible.
+            TerminatorKind::UnwindResume
+            | TerminatorKind::UnwindTerminate(_)
+            | TerminatorKind::Return
+            | TerminatorKind::Unreachable
+            | TerminatorKind::CoroutineDrop => bug!("{term:?} has no successors"),
+            // Disallowed during optimizations.
+            TerminatorKind::FalseEdge { .. }
+            | TerminatorKind::FalseUnwind { .. }
+            | TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
+            // Cannot reason about inline asm.
+            TerminatorKind::InlineAsm { .. } => return,
+            // `SwitchInt` is handled specially.
+            TerminatorKind::SwitchInt { .. } => return,
+            // We can recurse, nothing particular to do.
+            TerminatorKind::Goto { .. } => None,
+            // Flood the overwritten place, and progress through.
+            TerminatorKind::Drop { place: destination, .. }
+            | TerminatorKind::Call { destination, .. } => Some(destination),
+            // Treat as an `assume(cond == expected)`.
+            TerminatorKind::Assert { ref cond, expected, .. } => {
+                if let Some(place) = cond.place()
+                    && let Some(conditions) = state.try_get(place.as_ref(), self.map)
+                {
+                    let expected = if expected { ScalarInt::TRUE } else { ScalarInt::FALSE };
+                    conditions.iter_matches(expected).for_each(register_opportunity);
+                }
+                None
+            }
+        };
+
+        // We can recurse through this terminator.
+        let mut state = state.clone();
+        if let Some(place_to_flood) = place_to_flood {
+            state.flood_with(place_to_flood.as_ref(), self.map, ConditionSet::default());
+        }
+        self.find_opportunity(bb, state, cost.clone(), depth + 1);
+    }
+
+    #[instrument(level = "trace", skip(self))]
+    fn process_switch_int(
+        &mut self,
+        discr: &Operand<'tcx>,
+        targets: &SwitchTargets,
+        target_bb: BasicBlock,
+        state: &mut State<ConditionSet<'a>>,
+    ) -> Option<!> {
+        debug_assert_ne!(target_bb, START_BLOCK);
+        debug_assert_eq!(self.body.basic_blocks.predecessors()[target_bb].len(), 1);
+
+        let discr = discr.place()?;
+        let discr_ty = discr.ty(self.body, self.tcx).ty;
+        let discr_layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?;
+        let conditions = state.try_get(discr.as_ref(), self.map)?;
+
+        if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {
+            let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
+            debug_assert_eq!(targets.iter().filter(|&(_, target)| target == target_bb).count(), 1);
+
+            // We are inside `target_bb`. Since we have a single predecessor, we know we passed
+            // through the `SwitchInt` before arriving here. Therefore, we know that
+            // `discr == value`. If one condition can be fulfilled by `discr == value`,
+            // that's an opportunity.
+            for c in conditions.iter_matches(value) {
+                debug!(?target_bb, ?c.target, "register");
+                self.opportunities.push(ThreadingOpportunity { chain: vec![], target: c.target });
+            }
+        } else if let Some((value, _, else_bb)) = targets.as_static_if()
+            && target_bb == else_bb
+        {
+            let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
+
+            // We only know that `discr != value`. That's much weaker information than
+            // the equality we had in the previous arm. All we can conclude is that
+            // the replacement condition `discr != value` can be threaded, and nothing else.
+            for c in conditions.iter() {
+                if c.value == value && c.polarity == Polarity::Ne {
+                    debug!(?target_bb, ?c.target, "register");
+                    self.opportunities
+                        .push(ThreadingOpportunity { chain: vec![], target: c.target });
+                }
+            }
+        }
+
+        None
+    }
+}
+
+struct OpportunitySet {
+    opportunities: Vec<ThreadingOpportunity>,
+    /// For each bb, give the TOs in which it appears. The pair corresponds to the index
+    /// in `opportunities` and the index in `ThreadingOpportunity::chain`.
+    involving_tos: IndexVec<BasicBlock, Vec<(usize, usize)>>,
+    /// Cache the number of predecessors for each block, as we clear the basic block cache.
+    predecessors: IndexVec<BasicBlock, usize>,
+}
+
+impl OpportunitySet {
+    fn new(body: &Body<'_>, opportunities: Vec<ThreadingOpportunity>) -> OpportunitySet {
+        let mut involving_tos = IndexVec::from_elem(Vec::new(), &body.basic_blocks);
+        for (index, to) in opportunities.iter().enumerate() {
+            for (ibb, &bb) in to.chain.iter().enumerate() {
+                involving_tos[bb].push((index, ibb));
+            }
+            involving_tos[to.target].push((index, to.chain.len()));
+        }
+        let predecessors = predecessor_count(body);
+        OpportunitySet { opportunities, involving_tos, predecessors }
+    }
+
+    /// Apply the opportunities on the graph.
+    fn apply(&mut self, body: &mut Body<'_>) {
+        for i in 0..self.opportunities.len() {
+            self.apply_once(i, body);
+        }
+    }
+
+    #[instrument(level = "trace", skip(self, body))]
+    fn apply_once(&mut self, index: usize, body: &mut Body<'_>) {
+        debug!(?self.predecessors);
+        debug!(?self.involving_tos);
+
+        // Check that `predecessors` satisfies its invariant.
+ debug_assert_eq!(self.predecessors, predecessor_count(body)); + + // Remove the TO from the vector to allow modifying the other ones later. + let op = &mut self.opportunities[index]; + debug!(?op); + let op_chain = std::mem::take(&mut op.chain); + let op_target = op.target; + debug_assert_eq!(op_chain.len(), op_chain.iter().collect::<FxHashSet<_>>().len()); + + let Some((current, chain)) = op_chain.split_first() else { return }; + let basic_blocks = body.basic_blocks.as_mut(); + + // Invariant: the control-flow is well-formed at the end of each iteration. + let mut current = *current; + for &succ in chain { + debug!(?current, ?succ); + + // `succ` must be a successor of `current`. If it is not, this means this TO is not + // satisfiable and a previous TO erased this edge, so we bail out. + if basic_blocks[current].terminator().successors().find(|s| *s == succ).is_none() { + debug!("impossible"); + return; + } + + // Fast path: `succ` is only used once, so we can reuse it directly. + if self.predecessors[succ] == 1 { + debug!("single"); + current = succ; + continue; + } + + let new_succ = basic_blocks.push(basic_blocks[succ].clone()); + debug!(?new_succ); + + // Replace `succ` by `new_succ` where it appears. + let mut num_edges = 0; + for s in basic_blocks[current].terminator_mut().successors_mut() { + if *s == succ { + *s = new_succ; + num_edges += 1; + } + } + + // Update predecessors with the new block. + let _new_succ = self.predecessors.push(num_edges); + debug_assert_eq!(new_succ, _new_succ); + self.predecessors[succ] -= num_edges; + self.update_predecessor_count(basic_blocks[new_succ].terminator(), Update::Incr); + + // Replace the `current -> succ` edge by `current -> new_succ` in all the following + // TOs. This is necessary to avoid trying to thread through a non-existing edge. We + // use `involving_tos` here to avoid traversing the full set of TOs on each iteration. + let mut new_involved = Vec::new(); + for &(to_index, in_to_index) in &self.involving_tos[current] { + // That TO has already been applied, do nothing. + if to_index <= index { + continue; + } + + let other_to = &mut self.opportunities[to_index]; + if other_to.chain.get(in_to_index) != Some(¤t) { + continue; + } + let s = other_to.chain.get_mut(in_to_index + 1).unwrap_or(&mut other_to.target); + if *s == succ { + // `other_to` references the `current -> succ` edge, so replace `succ`. + *s = new_succ; + new_involved.push((to_index, in_to_index + 1)); + } + } + + // The TOs that we just updated now reference `new_succ`. Update `involving_tos` + // in case we need to duplicate an edge starting at `new_succ` later. 
+            let _new_succ = self.involving_tos.push(new_involved);
+            debug_assert_eq!(new_succ, _new_succ);
+
+            current = new_succ;
+        }
+
+        let current = &mut basic_blocks[current];
+        self.update_predecessor_count(current.terminator(), Update::Decr);
+        current.terminator_mut().kind = TerminatorKind::Goto { target: op_target };
+        self.predecessors[op_target] += 1;
+    }
+
+    fn update_predecessor_count(&mut self, terminator: &Terminator<'_>, incr: Update) {
+        match incr {
+            Update::Incr => {
+                for s in terminator.successors() {
+                    self.predecessors[s] += 1;
+                }
+            }
+            Update::Decr => {
+                for s in terminator.successors() {
+                    self.predecessors[s] -= 1;
+                }
+            }
+        }
+    }
+}
+
+fn predecessor_count(body: &Body<'_>) -> IndexVec<BasicBlock, usize> {
+    let mut predecessors: IndexVec<_, _> =
+        body.basic_blocks.predecessors().iter().map(|ps| ps.len()).collect();
+    predecessors[START_BLOCK] += 1; // Account for the implicit entry edge.
+    predecessors
+}
+
+enum Update {
+    Incr,
+    Decr,
+}
+
+/// Compute the set of loop headers in the given body. We define a loop header as a block which has
+/// at least one predecessor which it dominates. This definition is only correct for reducible
+/// CFGs. But if the CFG is already irreducible, there is no point in trying much harder.
+fn loop_headers(body: &Body<'_>) -> BitSet<BasicBlock> {
+    let mut loop_headers = BitSet::new_empty(body.basic_blocks.len());
+    let dominators = body.basic_blocks.dominators();
+    // Only visit reachable blocks.
+    for (bb, bbdata) in traversal::preorder(body) {
+        for succ in bbdata.terminator().successors() {
+            if dominators.dominates(succ, bb) {
+                loop_headers.insert(succ);
+            }
+        }
+    }
+    loop_headers
+}
diff --git a/compiler/rustc_mir_transform/src/large_enums.rs b/compiler/rustc_mir_transform/src/large_enums.rs
index 886ff7604..0a8b13d66 100644
--- a/compiler/rustc_mir_transform/src/large_enums.rs
+++ b/compiler/rustc_mir_transform/src/large_enums.rs
@@ -30,6 +30,9 @@ pub struct EnumSizeOpt {
 impl<'tcx> MirPass<'tcx> for EnumSizeOpt {
     fn is_enabled(&self, sess: &Session) -> bool {
+        // There are some differences in behavior on wasm and ARM that are not properly
+        // understood, so we conservatively treat this optimization as unsound:
+        // https://github.com/rust-lang/rust/pull/85158#issuecomment-1101836457
         sess.opts.unstable_opts.unsound_mir_opts || sess.mir_opt_level() >= 3
     }
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index c0a09b7a7..bf5f0ca7c 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -2,6 +2,7 @@
 #![deny(rustc::untranslatable_diagnostic)]
 #![deny(rustc::diagnostic_outside_of_impl)]
 #![feature(box_patterns)]
+#![feature(cow_is_borrowed)]
 #![feature(decl_macro)]
 #![feature(is_sorted)]
 #![feature(let_chains)]
@@ -20,6 +21,7 @@
 extern crate tracing;
 #[macro_use]
 extern crate rustc_middle;
 
+use hir::ConstContext;
 use required_consts::RequiredConstsVisitor;
 use rustc_const_eval::util;
 use rustc_data_structures::fx::FxIndexSet;
@@ -61,7 +63,10 @@
 mod const_goto;
 mod const_prop;
 mod const_prop_lint;
 mod copy_prop;
+mod coroutine;
+mod cost_checker;
 mod coverage;
+mod cross_crate_inline;
 mod ctfe_limit;
 mod dataflow_const_prop;
 mod dead_store_elimination;
@@ -76,10 +81,10 @@
 mod elaborate_drops;
 mod errors;
 mod ffi_unwind_calls;
 mod function_item_references;
-mod generator;
 mod gvn;
 pub mod inline;
 mod instsimplify;
+mod jump_threading;
 mod large_enums;
mod lower_intrinsics; mod lower_slice_len; @@ -123,6 +128,7 @@ pub fn provide(providers: &mut Providers) { coverage::query::provide(providers); ffi_unwind_calls::provide(providers); shim::provide(providers); + cross_crate_inline::provide(providers); *providers = Providers { mir_keys, mir_const, @@ -130,7 +136,7 @@ pub fn provide(providers: &mut Providers) { mir_promoted, mir_drops_elaborated_and_const_checked, mir_for_ctfe, - mir_generator_witnesses: generator::mir_generator_witnesses, + mir_coroutine_witnesses: coroutine::mir_coroutine_witnesses, optimized_mir, is_mir_available, is_ctfe_mir_available: |tcx, did| is_mir_available(tcx, did), @@ -162,37 +168,50 @@ fn remap_mir_for_const_eval_select<'tcx>( && tcx.item_name(def_id) == sym::const_eval_select && tcx.is_intrinsic(def_id) => { - let [tupled_args, called_in_const, called_at_rt]: [_; 3] = std::mem::take(args).try_into().unwrap(); + let [tupled_args, called_in_const, called_at_rt]: [_; 3] = + std::mem::take(args).try_into().unwrap(); let ty = tupled_args.ty(&body.local_decls, tcx); let fields = ty.tuple_fields(); let num_args = fields.len(); - let func = if context == hir::Constness::Const { called_in_const } else { called_at_rt }; - let (method, place): (fn(Place<'tcx>) -> Operand<'tcx>, Place<'tcx>) = match tupled_args { - Operand::Constant(_) => { - // there is no good way of extracting a tuple arg from a constant (const generic stuff) - // so we just create a temporary and deconstruct that. - let local = body.local_decls.push(LocalDecl::new(ty, fn_span)); - bb.statements.push(Statement { - source_info: SourceInfo::outermost(fn_span), - kind: StatementKind::Assign(Box::new((local.into(), Rvalue::Use(tupled_args.clone())))), - }); - (Operand::Move, local.into()) - } - Operand::Move(place) => (Operand::Move, place), - Operand::Copy(place) => (Operand::Copy, place), - }; - let place_elems = place.projection; - let arguments = (0..num_args).map(|x| { - let mut place_elems = place_elems.to_vec(); - place_elems.push(ProjectionElem::Field(x.into(), fields[x])); - let projection = tcx.mk_place_elems(&place_elems); - let place = Place { - local: place.local, - projection, + let func = + if context == hir::Constness::Const { called_in_const } else { called_at_rt }; + let (method, place): (fn(Place<'tcx>) -> Operand<'tcx>, Place<'tcx>) = + match tupled_args { + Operand::Constant(_) => { + // there is no good way of extracting a tuple arg from a constant (const generic stuff) + // so we just create a temporary and deconstruct that. 
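For context, a nightly-only sketch of the intrinsic this remap lowers (feature gates and exact signature are as of this era of the compiler and may differ): the tupled first argument is deconstructed into individual call arguments, and the call is redirected to whichever function matches the const-ness of the context.

```rust
#![feature(core_intrinsics)]
use std::intrinsics::const_eval_select;

const fn in_const(x: i32) -> i32 { x + 1 }
fn at_runtime(x: i32) -> i32 { x + 1 }

const fn add_one(x: i32) -> i32 {
    // `(x,)` is the tuple that the remap untuples into call arguments.
    // Both branches must compute the same result for this to be sound.
    unsafe { const_eval_select((x,), in_const, at_runtime) }
}
```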
+ let local = body.local_decls.push(LocalDecl::new(ty, fn_span)); + bb.statements.push(Statement { + source_info: SourceInfo::outermost(fn_span), + kind: StatementKind::Assign(Box::new(( + local.into(), + Rvalue::Use(tupled_args.clone()), + ))), + }); + (Operand::Move, local.into()) + } + Operand::Move(place) => (Operand::Move, place), + Operand::Copy(place) => (Operand::Copy, place), }; - method(place) - }).collect(); - terminator.kind = TerminatorKind::Call { func, args: arguments, destination, target, unwind, call_source: CallSource::Misc, fn_span }; + let place_elems = place.projection; + let arguments = (0..num_args) + .map(|x| { + let mut place_elems = place_elems.to_vec(); + place_elems.push(ProjectionElem::Field(x.into(), fields[x])); + let projection = tcx.mk_place_elems(&place_elems); + let place = Place { local: place.local, projection }; + method(place) + }) + .collect(); + terminator.kind = TerminatorKind::Call { + func, + args: arguments, + destination, + target, + unwind, + call_source: CallSource::Misc, + fn_span, + }; } _ => {} } @@ -234,8 +253,13 @@ fn mir_const_qualif(tcx: TyCtxt<'_>, def: LocalDefId) -> ConstQualifs { let const_kind = tcx.hir().body_const_context(def); // No need to const-check a non-const `fn`. - if const_kind.is_none() { - return Default::default(); + match const_kind { + Some(ConstContext::Const { .. } | ConstContext::Static(_)) + | Some(ConstContext::ConstFn) => {} + None => span_bug!( + tcx.def_span(def), + "`mir_const_qualif` should only be called on const fns and const items" + ), } // N.B., this `borrow()` is guaranteed to be valid (i.e., the value @@ -300,7 +324,21 @@ fn mir_promoted( // Ensure that we compute the `mir_const_qualif` for constants at // this point, before we steal the mir-const result. // Also this means promotion can rely on all const checks having been done. - let const_qualifs = tcx.mir_const_qualif(def); + + let const_qualifs = match tcx.def_kind(def) { + DefKind::Fn | DefKind::AssocFn | DefKind::Closure + if tcx.constness(def) == hir::Constness::Const + || tcx.is_const_default_method(def.to_def_id()) => + { + tcx.mir_const_qualif(def) + } + DefKind::AssocConst + | DefKind::Const + | DefKind::Static(_) + | DefKind::InlineConst + | DefKind::AnonConst => tcx.mir_const_qualif(def), + _ => ConstQualifs::default(), + }; let mut body = tcx.mir_const(def).steal(); if let Some(error_reported) = const_qualifs.tainted_by_errors { body.tainted_by_errors = Some(error_reported); @@ -360,15 +398,15 @@ fn inner_mir_for_ctfe(tcx: TyCtxt<'_>, def: LocalDefId) -> Body<'_> { /// mir borrowck *before* doing so in order to ensure that borrowck can be run and doesn't /// end up missing the source MIR due to stealing happening. fn mir_drops_elaborated_and_const_checked(tcx: TyCtxt<'_>, def: LocalDefId) -> &Steal<Body<'_>> { - if let DefKind::Generator = tcx.def_kind(def) { - tcx.ensure_with_value().mir_generator_witnesses(def); + if let DefKind::Coroutine = tcx.def_kind(def) { + tcx.ensure_with_value().mir_coroutine_witnesses(def); } let mir_borrowck = tcx.mir_borrowck(def); let is_fn_like = tcx.def_kind(def).is_fn_like(); if is_fn_like { // Do not compute the mir call graph without said call graph actually being used. - if inline::Inline.is_enabled(&tcx.sess) { + if pm::should_run_pass(tcx, &inline::Inline) { tcx.ensure_with_value().mir_inliner_callees(ty::InstanceDef::Item(def.to_def_id())); } } @@ -494,9 +532,9 @@ fn run_runtime_lowering_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { // `AddRetag` needs to run after `ElaborateDrops`. 
Otherwise it should run fairly late, // but before optimizations begin. &elaborate_box_derefs::ElaborateBoxDerefs, - &generator::StateTransform, + &coroutine::StateTransform, &add_retag::AddRetag, - &Lint(const_prop_lint::ConstProp), + &Lint(const_prop_lint::ConstPropLint), ]; pm::run_passes_no_validate(tcx, body, passes, Some(MirPhase::Runtime(RuntimePhase::Initial))); } @@ -530,10 +568,11 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &[ &check_alignment::CheckAlignment, &lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first - &unreachable_prop::UnreachablePropagation, + &inline::Inline, + // Substitutions during inlining may introduce switch on enums with uninhabited branches. &uninhabited_enum_branching::UninhabitedEnumBranching, + &unreachable_prop::UnreachablePropagation, &o1(simplify::SimplifyCfg::AfterUninhabitedEnumBranching), - &inline::Inline, &remove_storage_markers::RemoveStorageMarkers, &remove_zsts::RemoveZsts, &normalize_array_len::NormalizeArrayLen, // has to run after `slice::len` lowering @@ -553,11 +592,11 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &separate_const_switch::SeparateConstSwitch, &const_prop::ConstProp, &gvn::GVN, + &simplify::SimplifyLocals::AfterGVN, &dataflow_const_prop::DataflowConstProp, - // - // Const-prop runs unconditionally, but doesn't mutate the MIR at mir-opt-level=0. &const_debuginfo::ConstDebugInfo, &o1(simplify_branches::SimplifyConstCondition::AfterConstProp), + &jump_threading::JumpThreading, &early_otherwise_branch::EarlyOtherwiseBranch, &simplify_comparison_integral::SimplifyComparisonIntegral, &dead_store_elimination::DeadStoreElimination, @@ -613,6 +652,15 @@ fn inner_optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> Body<'_> { return body; } + // If `mir_drops_elaborated_and_const_checked` found that the current body has unsatisfiable + // predicates, it will shrink the MIR to a single `unreachable` terminator. + // More generally, if MIR is a lone `unreachable`, there is nothing to optimize. + if let TerminatorKind::Unreachable = body.basic_blocks[START_BLOCK].terminator().kind + && body.basic_blocks[START_BLOCK].statements.is_empty() + { + return body; + } + run_optimization_passes(tcx, &mut body); body diff --git a/compiler/rustc_mir_transform/src/lower_intrinsics.rs b/compiler/rustc_mir_transform/src/lower_intrinsics.rs index 0d2d764c4..5f3d8dfc6 100644 --- a/compiler/rustc_mir_transform/src/lower_intrinsics.rs +++ b/compiler/rustc_mir_transform/src/lower_intrinsics.rs @@ -2,9 +2,8 @@ use crate::MirPass; use rustc_middle::mir::*; -use rustc_middle::ty::GenericArgsRef; -use rustc_middle::ty::{self, Ty, TyCtxt}; -use rustc_span::symbol::{sym, Symbol}; +use rustc_middle::ty::{self, TyCtxt}; +use rustc_span::symbol::sym; use rustc_target::abi::{FieldIdx, VariantIdx}; pub struct LowerIntrinsics; @@ -16,12 +15,10 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics { let terminator = block.terminator.as_mut().unwrap(); if let TerminatorKind::Call { func, args, destination, target, .. 
} = &mut terminator.kind + && let ty::FnDef(def_id, generic_args) = *func.ty(local_decls, tcx).kind() + && tcx.is_intrinsic(def_id) { - let func_ty = func.ty(local_decls, tcx); - let Some((intrinsic_name, generic_args)) = resolve_rust_intrinsic(tcx, func_ty) - else { - continue; - }; + let intrinsic_name = tcx.item_name(def_id); match intrinsic_name { sym::unreachable => { terminator.kind = TerminatorKind::Unreachable; @@ -169,12 +166,16 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics { let [arg] = args.as_slice() else { span_bug!(terminator.source_info.span, "Wrong number of arguments"); }; - let derefed_place = - if let Some(place) = arg.place() && let Some(local) = place.as_local() { - tcx.mk_place_deref(local.into()) - } else { - span_bug!(terminator.source_info.span, "Only passing a local is supported"); - }; + let derefed_place = if let Some(place) = arg.place() + && let Some(local) = place.as_local() + { + tcx.mk_place_deref(local.into()) + } else { + span_bug!( + terminator.source_info.span, + "Only passing a local is supported" + ); + }; // Add new statement at the end of the block that does the read, and patch // up the terminator. block.statements.push(Statement { @@ -201,12 +202,16 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics { "Wrong number of arguments for write_via_move intrinsic", ); }; - let derefed_place = - if let Some(place) = ptr.place() && let Some(local) = place.as_local() { - tcx.mk_place_deref(local.into()) - } else { - span_bug!(terminator.source_info.span, "Only passing a local is supported"); - }; + let derefed_place = if let Some(place) = ptr.place() + && let Some(local) = place.as_local() + { + tcx.mk_place_deref(local.into()) + } else { + span_bug!( + terminator.source_info.span, + "Only passing a local is supported" + ); + }; block.statements.push(Statement { source_info: terminator.source_info, kind: StatementKind::Assign(Box::new(( @@ -309,15 +314,3 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics { } } } - -fn resolve_rust_intrinsic<'tcx>( - tcx: TyCtxt<'tcx>, - func_ty: Ty<'tcx>, -) -> Option<(Symbol, GenericArgsRef<'tcx>)> { - if let ty::FnDef(def_id, args) = *func_ty.kind() { - if tcx.is_intrinsic(def_id) { - return Some((tcx.item_name(def_id), args)); - } - } - None -} diff --git a/compiler/rustc_mir_transform/src/lower_slice_len.rs b/compiler/rustc_mir_transform/src/lower_slice_len.rs index b7cc0db95..ae4878411 100644 --- a/compiler/rustc_mir_transform/src/lower_slice_len.rs +++ b/compiler/rustc_mir_transform/src/lower_slice_len.rs @@ -34,67 +34,43 @@ pub fn lower_slice_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { } } -struct SliceLenPatchInformation<'tcx> { - add_statement: Statement<'tcx>, - new_terminator_kind: TerminatorKind<'tcx>, -} - fn lower_slice_len_call<'tcx>( tcx: TyCtxt<'tcx>, block: &mut BasicBlockData<'tcx>, local_decls: &IndexSlice<Local, LocalDecl<'tcx>>, slice_len_fn_item_def_id: DefId, ) { - let mut patch_found: Option<SliceLenPatchInformation<'_>> = None; - let terminator = block.terminator(); - match &terminator.kind { - TerminatorKind::Call { - func, - args, - destination, - target: Some(bb), - call_source: CallSource::Normal, - .. 
- } => { - // some heuristics for fast rejection - if args.len() != 1 { - return; - } - let Some(arg) = args[0].place() else { return }; - let func_ty = func.ty(local_decls, tcx); - match func_ty.kind() { - ty::FnDef(fn_def_id, _) if fn_def_id == &slice_len_fn_item_def_id => { - // perform modifications - // from something like `_5 = core::slice::<impl [u8]>::len(move _6) -> bb1` - // into: - // ``` - // _5 = Len(*_6) - // goto bb1 - // ``` + if let TerminatorKind::Call { + func, + args, + destination, + target: Some(bb), + call_source: CallSource::Normal, + .. + } = &terminator.kind + // some heuristics for fast rejection + && let [arg] = &args[..] + && let Some(arg) = arg.place() + && let ty::FnDef(fn_def_id, _) = func.ty(local_decls, tcx).kind() + && *fn_def_id == slice_len_fn_item_def_id + { + // perform modifications from something like: + // _5 = core::slice::<impl [u8]>::len(move _6) -> bb1 + // into: + // _5 = Len(*_6) + // goto bb1 - // make new RValue for Len - let deref_arg = tcx.mk_place_deref(arg); - let r_value = Rvalue::Len(deref_arg); - let len_statement_kind = - StatementKind::Assign(Box::new((*destination, r_value))); - let add_statement = - Statement { kind: len_statement_kind, source_info: terminator.source_info }; + // make new RValue for Len + let deref_arg = tcx.mk_place_deref(arg); + let r_value = Rvalue::Len(deref_arg); + let len_statement_kind = StatementKind::Assign(Box::new((*destination, r_value))); + let add_statement = + Statement { kind: len_statement_kind, source_info: terminator.source_info }; - // modify terminator into simple Goto - let new_terminator_kind = TerminatorKind::Goto { target: *bb }; - - let patch = SliceLenPatchInformation { add_statement, new_terminator_kind }; - - patch_found = Some(patch); - } - _ => {} - } - } - _ => {} - } + // modify terminator into simple Goto + let new_terminator_kind = TerminatorKind::Goto { target: *bb }; - if let Some(SliceLenPatchInformation { add_statement, new_terminator_kind }) = patch_found { block.statements.push(add_statement); block.terminator_mut().kind = new_terminator_kind; } diff --git a/compiler/rustc_mir_transform/src/multiple_return_terminators.rs b/compiler/rustc_mir_transform/src/multiple_return_terminators.rs index c97d03454..c9b42e75c 100644 --- a/compiler/rustc_mir_transform/src/multiple_return_terminators.rs +++ b/compiler/rustc_mir_transform/src/multiple_return_terminators.rs @@ -38,6 +38,6 @@ impl<'tcx> MirPass<'tcx> for MultipleReturnTerminators { } } - simplify::remove_dead_blocks(tcx, body) + simplify::remove_dead_blocks(body) } } diff --git a/compiler/rustc_mir_transform/src/normalize_array_len.rs b/compiler/rustc_mir_transform/src/normalize_array_len.rs index d1a4b26a0..206cdf9fe 100644 --- a/compiler/rustc_mir_transform/src/normalize_array_len.rs +++ b/compiler/rustc_mir_transform/src/normalize_array_len.rs @@ -57,7 +57,9 @@ fn compute_slice_length<'tcx>( } // The length information is stored in the fat pointer, so we treat `operand` as a value. 
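Both the `LowerSliceLenCalls` rewrite above and the length propagation here rest on the same representation fact noted in the comment: a slice reference is a fat pointer that carries its length inline, so no real call is ever needed to recover it. A minimal illustration in plain Rust (the exact two-word layout of `&[u8]` is rustc's current choice, not a language guarantee):

fn main() {
    let xs = [1u8, 2, 3];
    let p: &[u8] = &xs; // fat pointer: (data pointer, len = 3)
    // `p.len()` just reads the metadata word; this is what `Len(*_6)` models in MIR.
    assert_eq!(p.len(), 3);
    // On current rustc targets a slice reference is exactly two words wide.
    assert_eq!(std::mem::size_of::<&[u8]>(), 2 * std::mem::size_of::<usize>());
}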
Rvalue::Use(operand) => { - if let Some(rhs) = operand.place() && let Some(rhs) = rhs.as_local() { + if let Some(rhs) = operand.place() + && let Some(rhs) = rhs.as_local() + { slice_lengths[local] = slice_lengths[rhs]; } } diff --git a/compiler/rustc_mir_transform/src/nrvo.rs b/compiler/rustc_mir_transform/src/nrvo.rs index e1298b065..ff309bd10 100644 --- a/compiler/rustc_mir_transform/src/nrvo.rs +++ b/compiler/rustc_mir_transform/src/nrvo.rs @@ -34,7 +34,7 @@ pub struct RenameReturnPlace; impl<'tcx> MirPass<'tcx> for RenameReturnPlace { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - // #111005 + // unsound: #111005 sess.mir_opt_level() > 0 && sess.opts.unstable_opts.unsound_mir_opts } diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs index 5abb2f3d0..a8aba29ad 100644 --- a/compiler/rustc_mir_transform/src/pass_manager.rs +++ b/compiler/rustc_mir_transform/src/pass_manager.rs @@ -83,6 +83,25 @@ pub fn run_passes<'tcx>( run_passes_inner(tcx, body, passes, phase_change, true); } +pub fn should_run_pass<'tcx, P>(tcx: TyCtxt<'tcx>, pass: &P) -> bool +where + P: MirPass<'tcx> + ?Sized, +{ + let name = pass.name(); + + let overridden_passes = &tcx.sess.opts.unstable_opts.mir_enable_passes; + let overridden = + overridden_passes.iter().rev().find(|(s, _)| s == &*name).map(|(_name, polarity)| { + trace!( + pass = %name, + "{} as requested by flag", + if *polarity { "Running" } else { "Not running" }, + ); + *polarity + }); + overridden.unwrap_or_else(|| pass.is_enabled(&tcx.sess)) +} + fn run_passes_inner<'tcx>( tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, @@ -100,19 +119,9 @@ fn run_passes_inner<'tcx>( for pass in passes { let name = pass.name(); - let overridden = overridden_passes.iter().rev().find(|(s, _)| s == &*name).map( - |(_name, polarity)| { - trace!( - pass = %name, - "{} as requested by flag", - if *polarity { "Running" } else { "Not running" }, - ); - *polarity - }, - ); - if !overridden.unwrap_or_else(|| pass.is_enabled(&tcx.sess)) { + if !should_run_pass(tcx, *pass) { continue; - } + }; let dump_enabled = pass.is_mir_dump_enabled(); diff --git a/compiler/rustc_mir_transform/src/ref_prop.rs b/compiler/rustc_mir_transform/src/ref_prop.rs index 67941cf43..df39c819b 100644 --- a/compiler/rustc_mir_transform/src/ref_prop.rs +++ b/compiler/rustc_mir_transform/src/ref_prop.rs @@ -210,14 +210,17 @@ fn compute_replacement<'tcx>( // have been visited before. Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) | Rvalue::CopyForDeref(place) => { - if let Some(rhs) = place.as_local() && ssa.is_ssa(rhs) { + if let Some(rhs) = place.as_local() + && ssa.is_ssa(rhs) + { let target = targets[rhs]; // Only see through immutable reference and pointers, as we do not know yet if // mutable references are fully replaced. 
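The override resolution that the new `should_run_pass` helper above centralizes can be sketched standalone. The helper below is hypothetical, not rustc API, but mirrors the lookup: the last matching `-Zmir-enable-passes` entry wins, and absent any override the pass's own `is_enabled` default applies.

// A standalone model of pass-override resolution (hypothetical names).
// `overrides` plays the role of `sess.opts.unstable_opts.mir_enable_passes`.
fn resolve_pass_override(overrides: &[(String, bool)], name: &str, default: bool) -> bool {
    overrides
        .iter()
        .rev() // the last override on the command line wins
        .find(|(n, _)| n == name)
        .map(|&(_, polarity)| polarity)
        .unwrap_or(default)
}

fn main() {
    // `-Zmir-enable-passes=-Inline,+Inline`: the later `+Inline` wins.
    let overrides = vec![("Inline".to_string(), false), ("Inline".to_string(), true)];
    assert!(resolve_pass_override(&overrides, "Inline", false));
    // No override recorded for this pass: fall back to its `is_enabled` default.
    assert!(!resolve_pass_override(&overrides, "GVN", false));
}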
if !needs_unique && matches!(target, Value::Pointer(..)) { targets[local] = target; } else { - targets[local] = Value::Pointer(tcx.mk_place_deref(rhs.into()), needs_unique); + targets[local] = + Value::Pointer(tcx.mk_place_deref(rhs.into()), needs_unique); } } } @@ -365,7 +368,7 @@ impl<'tcx> MutVisitor<'tcx> for Replacer<'tcx> { *place = Place::from(target.local).project_deeper(rest, self.tcx); self.any_replacement = true; } else { - break + break; } } diff --git a/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs b/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs index 8c48a6677..54892442c 100644 --- a/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs +++ b/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs @@ -69,7 +69,7 @@ impl RemoveNoopLandingPads { | TerminatorKind::FalseUnwind { .. } => { terminator.successors().all(|succ| nop_landing_pads.contains(succ)) } - TerminatorKind::GeneratorDrop + TerminatorKind::CoroutineDrop | TerminatorKind::Yield { .. } | TerminatorKind::Return | TerminatorKind::UnwindTerminate(_) diff --git a/compiler/rustc_mir_transform/src/remove_uninit_drops.rs b/compiler/rustc_mir_transform/src/remove_uninit_drops.rs index 263849747..87fee2410 100644 --- a/compiler/rustc_mir_transform/src/remove_uninit_drops.rs +++ b/compiler/rustc_mir_transform/src/remove_uninit_drops.rs @@ -24,11 +24,8 @@ pub struct RemoveUninitDrops; impl<'tcx> MirPass<'tcx> for RemoveUninitDrops { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let param_env = tcx.param_env(body.source.def_id()); - let Ok(move_data) = MoveData::gather_moves(body, tcx, param_env) else { - // We could continue if there are move errors, but there's not much point since our - // init data isn't complete. - return; - }; + let move_data = + MoveData::gather_moves(&body, tcx, param_env, |ty| ty.needs_drop(tcx, param_env)); let mdpe = MoveDataParamEnv { move_data, param_env }; let mut maybe_inits = MaybeInitializedPlaces::new(tcx, body, &mdpe) diff --git a/compiler/rustc_mir_transform/src/remove_zsts.rs b/compiler/rustc_mir_transform/src/remove_zsts.rs index a34d4b027..5aa3c3cfe 100644 --- a/compiler/rustc_mir_transform/src/remove_zsts.rs +++ b/compiler/rustc_mir_transform/src/remove_zsts.rs @@ -13,8 +13,8 @@ impl<'tcx> MirPass<'tcx> for RemoveZsts { } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - // Avoid query cycles (generators require optimized MIR for layout). - if tcx.type_of(body.source.def_id()).instantiate_identity().is_generator() { + // Avoid query cycles (coroutines require optimized MIR for layout). 
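`RemoveUninitDrops` above now asks `gather_moves` to track only types for which dropping is observable, via the `needs_drop` filter closure. std's `mem::needs_drop` draws the same line and gives a quick feel for the distinction (illustrative only; the pass consults the compiler-internal `Ty::needs_drop`):

fn main() {
    use std::mem::needs_drop;
    // Trivially-droppable data would be filtered out of move gathering...
    assert!(!needs_drop::<u32>());
    assert!(!needs_drop::<&String>()); // references never own their referent
    // ...while owning types still get move paths, since dropping them matters.
    assert!(needs_drop::<String>());
    assert!(needs_drop::<Vec<u8>>());
}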
+ if tcx.type_of(body.source.def_id()).instantiate_identity().is_coroutine() { return; } let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); @@ -126,7 +126,10 @@ impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> { && let ty = place_for_ty.ty(self.local_decls, self.tcx).ty && self.known_to_be_zst(ty) && self.tcx.consider_optimizing(|| { - format!("RemoveZsts - Place: {:?} SourceInfo: {:?}", place_for_ty, statement.source_info) + format!( + "RemoveZsts - Place: {:?} SourceInfo: {:?}", + place_for_ty, statement.source_info + ) }) { statement.make_nop(); diff --git a/compiler/rustc_mir_transform/src/separate_const_switch.rs b/compiler/rustc_mir_transform/src/separate_const_switch.rs index e1e4acccc..907cfe758 100644 --- a/compiler/rustc_mir_transform/src/separate_const_switch.rs +++ b/compiler/rustc_mir_transform/src/separate_const_switch.rs @@ -118,7 +118,7 @@ pub fn separate_const_switch(body: &mut Body<'_>) -> usize { | TerminatorKind::Return | TerminatorKind::Unreachable | TerminatorKind::InlineAsm { .. } - | TerminatorKind::GeneratorDrop => { + | TerminatorKind::CoroutineDrop => { continue 'predec_iter; } } @@ -169,7 +169,7 @@ pub fn separate_const_switch(body: &mut Body<'_>) -> usize { | TerminatorKind::UnwindTerminate(_) | TerminatorKind::Return | TerminatorKind::Unreachable - | TerminatorKind::GeneratorDrop + | TerminatorKind::CoroutineDrop | TerminatorKind::Assert { .. } | TerminatorKind::FalseUnwind { .. } | TerminatorKind::Drop { .. } diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index e9895d97d..ab7961321 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -4,7 +4,7 @@ use rustc_hir::lang_items::LangItem; use rustc_middle::mir::*; use rustc_middle::query::Providers; use rustc_middle::ty::GenericArgs; -use rustc_middle::ty::{self, EarlyBinder, GeneratorArgs, Ty, TyCtxt}; +use rustc_middle::ty::{self, CoroutineArgs, EarlyBinder, Ty, TyCtxt}; use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT}; use rustc_index::{Idx, IndexVec}; @@ -67,18 +67,20 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' } ty::InstanceDef::DropGlue(def_id, ty) => { - // FIXME(#91576): Drop shims for generators aren't subject to the MIR passes at the end + // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end // of this function. Is this intentional? - if let Some(ty::Generator(gen_def_id, args, _)) = ty.map(Ty::kind) { - let body = tcx.optimized_mir(*gen_def_id).generator_drop().unwrap(); + if let Some(ty::Coroutine(coroutine_def_id, args, _)) = ty.map(Ty::kind) { + let body = tcx.optimized_mir(*coroutine_def_id).coroutine_drop().unwrap(); let mut body = EarlyBinder::bind(body.clone()).instantiate(tcx, args); debug!("make_shim({:?}) = {:?}", instance, body); - // Run empty passes to mark phase change and perform validation. 
pm::run_passes( tcx, &mut body, - &[], + &[ + &abort_unwinding_calls::AbortUnwindingCalls, + &add_call_guards::CriticalCallEdges, + ], Some(MirPhase::Runtime(RuntimePhase::Optimized)), ); @@ -171,7 +173,7 @@ fn local_decls_for_sig<'tcx>( fn build_drop_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, ty: Option<Ty<'tcx>>) -> Body<'tcx> { debug!("build_drop_shim(def_id={:?}, ty={:?})", def_id, ty); - assert!(!matches!(ty, Some(ty) if ty.is_generator())); + assert!(!matches!(ty, Some(ty) if ty.is_coroutine())); let args = if let Some(ty) = ty { tcx.mk_args(&[ty.into()]) @@ -392,8 +394,8 @@ fn build_clone_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) - _ if is_copy => builder.copy_shim(), ty::Closure(_, args) => builder.tuple_like_shim(dest, src, args.as_closure().upvar_tys()), ty::Tuple(..) => builder.tuple_like_shim(dest, src, self_ty.tuple_fields()), - ty::Generator(gen_def_id, args, hir::Movability::Movable) => { - builder.generator_shim(dest, src, *gen_def_id, args.as_generator()) + ty::Coroutine(coroutine_def_id, args, hir::Movability::Movable) => { + builder.coroutine_shim(dest, src, *coroutine_def_id, args.as_coroutine()) } _ => bug!("clone shim for `{:?}` which is not `Copy` and is not an aggregate", self_ty), }; @@ -593,12 +595,12 @@ impl<'tcx> CloneShimBuilder<'tcx> { let _final_cleanup_block = self.clone_fields(dest, src, target, unwind, tys); } - fn generator_shim( + fn coroutine_shim( &mut self, dest: Place<'tcx>, src: Place<'tcx>, - gen_def_id: DefId, - args: GeneratorArgs<'tcx>, + coroutine_def_id: DefId, + args: CoroutineArgs<'tcx>, ) { self.block(vec![], TerminatorKind::Goto { target: self.block_index_offset(3) }, false); let unwind = self.block(vec![], TerminatorKind::UnwindResume, true); @@ -607,8 +609,8 @@ impl<'tcx> CloneShimBuilder<'tcx> { let unwind = self.clone_fields(dest, src, switch, unwind, args.upvar_tys()); let target = self.block(vec![], TerminatorKind::Return, false); let unreachable = self.block(vec![], TerminatorKind::Unreachable, false); - let mut cases = Vec::with_capacity(args.state_tys(gen_def_id, self.tcx).count()); - for (index, state_tys) in args.state_tys(gen_def_id, self.tcx).enumerate() { + let mut cases = Vec::with_capacity(args.state_tys(coroutine_def_id, self.tcx).count()); + for (index, state_tys) in args.state_tys(coroutine_def_id, self.tcx).enumerate() { let variant_index = VariantIdx::new(index); let dest = self.tcx.mk_place_downcast_unnamed(dest, variant_index); let src = self.tcx.mk_place_downcast_unnamed(src, variant_index); diff --git a/compiler/rustc_mir_transform/src/simplify.rs b/compiler/rustc_mir_transform/src/simplify.rs index 2795cf157..0a1c01114 100644 --- a/compiler/rustc_mir_transform/src/simplify.rs +++ b/compiler/rustc_mir_transform/src/simplify.rs @@ -28,10 +28,8 @@ //! return. 
use crate::MirPass; -use rustc_data_structures::fx::{FxHashSet, FxIndexSet}; -use rustc_index::bit_set::BitSet; +use rustc_data_structures::fx::FxIndexSet; use rustc_index::{Idx, IndexSlice, IndexVec}; -use rustc_middle::mir::coverage::*; use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor}; use rustc_middle::mir::*; use rustc_middle::ty::TyCtxt; @@ -68,7 +66,7 @@ impl SimplifyCfg { pub fn simplify_cfg<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { CfgSimplifier::new(body).simplify(); remove_duplicate_unreachable_blocks(tcx, body); - remove_dead_blocks(tcx, body); + remove_dead_blocks(body); // FIXME: Should probably be moved into some kind of pass manager body.basic_blocks_mut().raw.shrink_to_fit(); @@ -337,7 +335,7 @@ pub fn remove_duplicate_unreachable_blocks<'tcx>(tcx: TyCtxt<'tcx>, body: &mut B } } -pub fn remove_dead_blocks<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { +pub fn remove_dead_blocks(body: &mut Body<'_>) { let reachable = traversal::reachable_as_bitset(body); let num_blocks = body.basic_blocks.len(); if num_blocks == reachable.count() { @@ -345,10 +343,6 @@ pub fn remove_dead_blocks<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { } let basic_blocks = body.basic_blocks.as_mut(); - let source_scopes = &body.source_scopes; - if tcx.sess.instrument_coverage() { - save_unreachable_coverage(basic_blocks, source_scopes, &reachable); - } let mut replacements: Vec<_> = (0..num_blocks).map(BasicBlock::new).collect(); let mut orig_index = 0; @@ -370,99 +364,9 @@ pub fn remove_dead_blocks<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { } } -/// Some MIR transforms can determine at compile time that a sequences of -/// statements will never be executed, so they can be dropped from the MIR. -/// For example, an `if` or `else` block that is guaranteed to never be executed -/// because its condition can be evaluated at compile time, such as by const -/// evaluation: `if false { ... }`. -/// -/// Those statements are bypassed by redirecting paths in the CFG around the -/// `dead blocks`; but with `-C instrument-coverage`, the dead blocks usually -/// include `Coverage` statements representing the Rust source code regions to -/// be counted at runtime. Without these `Coverage` statements, the regions are -/// lost, and the Rust source code will show no coverage information. -/// -/// What we want to show in a coverage report is the dead code with coverage -/// counts of `0`. To do this, we need to save the code regions, by injecting -/// `Unreachable` coverage statements. These are non-executable statements whose -/// code regions are still recorded in the coverage map, representing regions -/// with `0` executions. -/// -/// If there are no live `Counter` `Coverage` statements remaining, we remove -/// `Coverage` statements along with the dead blocks. Since at least one -/// counter per function is required by LLVM (and necessary, to add the -/// `function_hash` to the counter's call to the LLVM intrinsic -/// `instrprof.increment()`). -/// -/// The `generator::StateTransform` MIR pass and MIR inlining can create -/// atypical conditions, where all live `Counter`s are dropped from the MIR. -/// -/// With MIR inlining we can have coverage counters belonging to different -/// instances in a single body, so the strategy described above is applied to -/// coverage counters from each instance individually. 
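The compaction that the slimmed-down `remove_dead_blocks` performs in the hunk above (compute the reachable set, build a replacement table, shift live blocks down, retarget edges) can be sketched on a toy CFG. All names and types below are stand-ins, not the real `BasicBlocks` machinery:

// Each block is just its list of successor indices here.
fn compact_reachable(blocks: &mut Vec<Vec<usize>>, reachable: &[bool]) -> Vec<Option<usize>> {
    // Map every old index to its new index (None for dead blocks).
    let mut replacement: Vec<Option<usize>> = vec![None; blocks.len()];
    let mut next = 0;
    for old in 0..blocks.len() {
        if reachable[old] {
            replacement[old] = Some(next);
            blocks.swap(old, next);
            next += 1;
        }
    }
    blocks.truncate(next);
    // Reachability is closed under successors, so every surviving edge has a new home.
    for succs in blocks.iter_mut() {
        for s in succs.iter_mut() {
            *s = replacement[*s].expect("live block must not point at a dead block");
        }
    }
    replacement
}

fn main() {
    // bb0 -> bb2, bb1 is dead, bb2 has no successors.
    let mut blocks = vec![vec![2], vec![0], vec![]];
    let reachable = [true, false, true];
    compact_reachable(&mut blocks, &reachable);
    assert_eq!(blocks, vec![vec![1], vec![]]);
}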
-fn save_unreachable_coverage( - basic_blocks: &mut IndexSlice<BasicBlock, BasicBlockData<'_>>, - source_scopes: &IndexSlice<SourceScope, SourceScopeData<'_>>, - reachable: &BitSet<BasicBlock>, -) { - // Identify instances that still have some live coverage counters left. - let mut live = FxHashSet::default(); - for bb in reachable.iter() { - let basic_block = &basic_blocks[bb]; - for statement in &basic_block.statements { - let StatementKind::Coverage(coverage) = &statement.kind else { continue }; - let CoverageKind::Counter { .. } = coverage.kind else { continue }; - let instance = statement.source_info.scope.inlined_instance(source_scopes); - live.insert(instance); - } - } - - for bb in reachable.iter() { - let block = &mut basic_blocks[bb]; - for statement in &mut block.statements { - let StatementKind::Coverage(_) = &statement.kind else { continue }; - let instance = statement.source_info.scope.inlined_instance(source_scopes); - if !live.contains(&instance) { - statement.make_nop(); - } - } - } - - if live.is_empty() { - return; - } - - // Retain coverage for instances that still have some live counters left. - let mut retained_coverage = Vec::new(); - for dead_block in basic_blocks.indices() { - if reachable.contains(dead_block) { - continue; - } - let dead_block = &basic_blocks[dead_block]; - for statement in &dead_block.statements { - let StatementKind::Coverage(coverage) = &statement.kind else { continue }; - let Some(code_region) = &coverage.code_region else { continue }; - let instance = statement.source_info.scope.inlined_instance(source_scopes); - if live.contains(&instance) { - retained_coverage.push((statement.source_info, code_region.clone())); - } - } - } - - let start_block = &mut basic_blocks[START_BLOCK]; - start_block.statements.extend(retained_coverage.into_iter().map( - |(source_info, code_region)| Statement { - source_info, - kind: StatementKind::Coverage(Box::new(Coverage { - kind: CoverageKind::Unreachable, - code_region: Some(code_region), - })), - }, - )); -} - pub enum SimplifyLocals { BeforeConstProp, + AfterGVN, Final, } @@ -470,6 +374,7 @@ impl<'tcx> MirPass<'tcx> for SimplifyLocals { fn name(&self) -> &'static str { match &self { SimplifyLocals::BeforeConstProp => "SimplifyLocals-before-const-prop", + SimplifyLocals::AfterGVN => "SimplifyLocals-after-value-numbering", SimplifyLocals::Final => "SimplifyLocals-final", } } diff --git a/compiler/rustc_mir_transform/src/simplify_branches.rs b/compiler/rustc_mir_transform/src/simplify_branches.rs index b508cd1c9..1f0e605c3 100644 --- a/compiler/rustc_mir_transform/src/simplify_branches.rs +++ b/compiler/rustc_mir_transform/src/simplify_branches.rs @@ -16,8 +16,25 @@ impl<'tcx> MirPass<'tcx> for SimplifyConstCondition { } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + trace!("Running SimplifyConstCondition on {:?}", body.source); let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); - for block in body.basic_blocks_mut() { + 'blocks: for block in body.basic_blocks_mut() { + for stmt in block.statements.iter_mut() { + if let StatementKind::Intrinsic(box ref intrinsic) = stmt.kind + && let NonDivergingIntrinsic::Assume(discr) = intrinsic + && let Operand::Constant(ref c) = discr + && let Some(constant) = c.const_.try_eval_bool(tcx, param_env) + { + if constant { + stmt.make_nop(); + } else { + block.statements.clear(); + block.terminator_mut().kind = TerminatorKind::Unreachable; + continue 'blocks; + } + } + } + let terminator = block.terminator_mut(); terminator.kind = match 
terminator.kind { TerminatorKind::SwitchInt { diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs index c21b1724c..7de4ca667 100644 --- a/compiler/rustc_mir_transform/src/sroa.rs +++ b/compiler/rustc_mir_transform/src/sroa.rs @@ -7,7 +7,7 @@ use rustc_middle::mir::visit::*; use rustc_middle::mir::*; use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields}; -use rustc_target::abi::{FieldIdx, ReprFlags, FIRST_VARIANT}; +use rustc_target::abi::{FieldIdx, FIRST_VARIANT}; pub struct ScalarReplacementOfAggregates; @@ -20,8 +20,8 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { debug!(def_id = ?body.source.def_id()); - // Avoid query cycles (generators require optimized MIR for layout). - if tcx.type_of(body.source.def_id()).instantiate_identity().is_generator() { + // Avoid query cycles (coroutines require optimized MIR for layout). + if tcx.type_of(body.source.def_id()).instantiate_identity().is_coroutine() { return; } @@ -66,7 +66,7 @@ fn escaping_locals<'tcx>( return true; } if let ty::Adt(def, _args) = ty.kind() { - if def.repr().flags.contains(ReprFlags::IS_SIMD) { + if def.repr().simd() { // Exclude #[repr(simd)] types so that they are not de-optimized into an array return true; } diff --git a/compiler/rustc_mir_transform/src/ssa.rs b/compiler/rustc_mir_transform/src/ssa.rs index 43fc1b7b9..1f59c790b 100644 --- a/compiler/rustc_mir_transform/src/ssa.rs +++ b/compiler/rustc_mir_transform/src/ssa.rs @@ -5,7 +5,6 @@ //! As a consequence of rule 2, we consider that borrowed locals are not SSA, even if they are //! `Freeze`, as we do not track that the assignment dominates all uses of the borrow. -use either::Either; use rustc_data_structures::graph::dominators::Dominators; use rustc_index::bit_set::BitSet; use rustc_index::{IndexSlice, IndexVec}; @@ -15,7 +14,7 @@ use rustc_middle::mir::*; pub struct SsaLocals { /// Assignments to each local. This defines whether the local is SSA. - assignments: IndexVec<Local, Set1<LocationExtended>>, + assignments: IndexVec<Local, Set1<DefLocation>>, /// We visit the body in reverse postorder, to ensure each local is assigned before it is used. /// We remember the order in which we saw the assignments to compute the SSA values in a single /// pass. @@ -27,39 +26,10 @@ pub struct SsaLocals { direct_uses: IndexVec<Local, u32>, } -/// We often encounter MIR bodies with 1 or 2 basic blocks. In those cases, it's unnecessary to -/// actually compute dominators, we can just compare block indices because bb0 is always the first -/// block, and in any body all other blocks are always dominated by bb0. -struct SmallDominators<'a> { - inner: Option<&'a Dominators<BasicBlock>>, -} - -impl SmallDominators<'_> { - fn dominates(&self, first: Location, second: Location) -> bool { - if first.block == second.block { - first.statement_index <= second.statement_index - } else if let Some(inner) = &self.inner { - inner.dominates(first.block, second.block) - } else { - first.block < second.block - } - } - - fn check_dominates(&mut self, set: &mut Set1<LocationExtended>, loc: Location) { - let assign_dominates = match *set { - Set1::Empty | Set1::Many => false, - Set1::One(LocationExtended::Arg) => true, - Set1::One(LocationExtended::Plain(assign)) => { - self.dominates(assign.successor_within_block(), loc) - } - }; - // We are visiting a use that is not dominated by an assignment. 
- // Either there is a cycle involved, or we are reading for uninitialized local. - // Bail out. - if !assign_dominates { - *set = Set1::Many; - } - } +pub enum AssignedValue<'a, 'tcx> { + Arg, + Rvalue(&'a mut Rvalue<'tcx>), + Terminator(&'a mut TerminatorKind<'tcx>), } impl SsaLocals { @@ -67,15 +37,14 @@ impl SsaLocals { let assignment_order = Vec::with_capacity(body.local_decls.len()); let assignments = IndexVec::from_elem(Set1::Empty, &body.local_decls); - let dominators = - if body.basic_blocks.len() > 2 { Some(body.basic_blocks.dominators()) } else { None }; - let dominators = SmallDominators { inner: dominators }; + let dominators = body.basic_blocks.dominators(); let direct_uses = IndexVec::from_elem(0, &body.local_decls); let mut visitor = SsaVisitor { assignments, assignment_order, dominators, direct_uses }; for local in body.args_iter() { - visitor.assignments[local] = Set1::One(LocationExtended::Arg); + visitor.assignments[local] = Set1::One(DefLocation::Argument); + visitor.assignment_order.push(local); } // For SSA assignments, a RPO visit will see the assignment before it sees any use. @@ -131,14 +100,7 @@ impl SsaLocals { location: Location, ) -> bool { match self.assignments[local] { - Set1::One(LocationExtended::Arg) => true, - Set1::One(LocationExtended::Plain(ass)) => { - if ass.block == location.block { - ass.statement_index < location.statement_index - } else { - dominators.dominates(ass.block, location.block) - } - } + Set1::One(def) => def.dominates(location, dominators), _ => false, } } @@ -148,9 +110,9 @@ impl SsaLocals { body: &'a Body<'tcx>, ) -> impl Iterator<Item = (Local, &'a Rvalue<'tcx>, Location)> + 'a { self.assignment_order.iter().filter_map(|&local| { - if let Set1::One(LocationExtended::Plain(loc)) = self.assignments[local] { + if let Set1::One(DefLocation::Body(loc)) = self.assignments[local] { + let stmt = body.stmt_at(loc).left()?; // `loc` must point to a direct assignment to `local`. - let Either::Left(stmt) = body.stmt_at(loc) else { bug!() }; let Some((target, rvalue)) = stmt.kind.as_assign() else { bug!() }; assert_eq!(target.as_local(), Some(local)); Some((local, rvalue, loc)) @@ -162,18 +124,33 @@ impl SsaLocals { pub fn for_each_assignment_mut<'tcx>( &self, - basic_blocks: &mut BasicBlocks<'tcx>, - mut f: impl FnMut(Local, &mut Rvalue<'tcx>, Location), + basic_blocks: &mut IndexSlice<BasicBlock, BasicBlockData<'tcx>>, + mut f: impl FnMut(Local, AssignedValue<'_, 'tcx>, Location), ) { for &local in &self.assignment_order { - if let Set1::One(LocationExtended::Plain(loc)) = self.assignments[local] { - // `loc` must point to a direct assignment to `local`. - let bbs = basic_blocks.as_mut_preserves_cfg(); - let bb = &mut bbs[loc.block]; - let stmt = &mut bb.statements[loc.statement_index]; - let StatementKind::Assign(box (target, ref mut rvalue)) = stmt.kind else { bug!() }; - assert_eq!(target.as_local(), Some(local)); - f(local, rvalue, loc) + match self.assignments[local] { + Set1::One(DefLocation::Argument) => f( + local, + AssignedValue::Arg, + Location { block: START_BLOCK, statement_index: 0 }, + ), + Set1::One(DefLocation::Body(loc)) => { + let bb = &mut basic_blocks[loc.block]; + let value = if loc.statement_index < bb.statements.len() { + // `loc` must point to a direct assignment to `local`. 
+ let stmt = &mut bb.statements[loc.statement_index];
+ let StatementKind::Assign(box (target, ref mut rvalue)) = stmt.kind else {
+ bug!()
+ };
+ assert_eq!(target.as_local(), Some(local));
+ AssignedValue::Rvalue(rvalue)
+ } else {
+ let term = bb.terminator_mut();
+ AssignedValue::Terminator(&mut term.kind)
+ };
+ f(local, value, loc)
+ }
+ _ => {}
}
}
}
@@ -224,19 +201,29 @@ impl SsaLocals {
}
}

-#[derive(Copy, Clone, Debug, PartialEq, Eq)]
-enum LocationExtended {
- Plain(Location),
- Arg,
-}
-
struct SsaVisitor<'a> {
- dominators: SmallDominators<'a>,
- assignments: IndexVec<Local, Set1<LocationExtended>>,
+ dominators: &'a Dominators<BasicBlock>,
+ assignments: IndexVec<Local, Set1<DefLocation>>,
assignment_order: Vec<Local>,
direct_uses: IndexVec<Local, u32>,
}

+impl SsaVisitor<'_> {
+ fn check_dominates(&mut self, local: Local, loc: Location) {
+ let set = &mut self.assignments[local];
+ let assign_dominates = match *set {
+ Set1::Empty | Set1::Many => false,
+ Set1::One(def) => def.dominates(loc, self.dominators),
+ };
+ // We are visiting a use that is not dominated by an assignment.
+ // Either there is a cycle involved, or we are reading from an uninitialized local.
+ // Bail out.
+ if !assign_dominates {
+ *set = Set1::Many;
+ }
+ }
+}
+
impl<'tcx> Visitor<'tcx> for SsaVisitor<'_> {
fn visit_local(&mut self, local: Local, ctxt: PlaceContext, loc: Location) {
match ctxt {
@@ -254,7 +241,7 @@ impl<'tcx> Visitor<'tcx> for SsaVisitor<'_> {
self.assignments[local] = Set1::Many;
}
PlaceContext::NonMutatingUse(_) => {
- self.dominators.check_dominates(&mut self.assignments[local], loc);
+ self.check_dominates(local, loc);
self.direct_uses[local] += 1;
}
PlaceContext::NonUse(_) => {}
@@ -262,34 +249,34 @@ impl<'tcx> Visitor<'tcx> for SsaVisitor<'_> {
}

fn visit_place(&mut self, place: &Place<'tcx>, ctxt: PlaceContext, loc: Location) {
- if place.projection.first() == Some(&PlaceElem::Deref) {
- // Do not do anything for storage statements and debuginfo.
+ let location = match ctxt {
+ PlaceContext::MutatingUse(
+ MutatingUseContext::Store | MutatingUseContext::Call | MutatingUseContext::Yield,
+ ) => Some(DefLocation::Body(loc)),
+ _ => None,
+ };
+ if let Some(location) = location
+ && let Some(local) = place.as_local()
+ {
+ self.assignments[local].insert(location);
+ if let Set1::One(_) = self.assignments[local] {
+ // Only record if SSA-like, to avoid growing the vector needlessly.
+ self.assignment_order.push(local);
+ }
+ } else if place.projection.first() == Some(&PlaceElem::Deref) {
+ // Do not do anything for debuginfo.
if ctxt.is_use() {
// Only change the context if it is a real use, not a "use" in debuginfo.
let new_ctxt = PlaceContext::NonMutatingUse(NonMutatingUseContext::Copy);
self.visit_projection(place.as_ref(), new_ctxt, loc);
- self.dominators.check_dominates(&mut self.assignments[place.local], loc);
+ self.check_dominates(place.local, loc);
}
- return;
} else {
self.visit_projection(place.as_ref(), ctxt, loc);
self.visit_local(place.local, ctxt, loc);
}
}
-
- fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, loc: Location) {
- if let Some(local) = place.as_local() {
- self.assignments[local].insert(LocationExtended::Plain(loc));
- if let Set1::One(_) = self.assignments[local] {
- // Only record if SSA-like, to avoid growing the vector needlessly.
- self.assignment_order.push(local); - } - } else { - self.visit_place(place, PlaceContext::MutatingUse(MutatingUseContext::Store), loc); - } - self.visit_rvalue(rvalue, loc); - } } #[instrument(level = "trace", skip(ssa, body))] @@ -356,7 +343,7 @@ fn compute_copy_classes(ssa: &mut SsaLocals, body: &Body<'_>) { #[derive(Debug)] pub(crate) struct StorageLiveLocals { /// Set of "StorageLive" statements for each local. - storage_live: IndexVec<Local, Set1<LocationExtended>>, + storage_live: IndexVec<Local, Set1<DefLocation>>, } impl StorageLiveLocals { @@ -366,13 +353,13 @@ impl StorageLiveLocals { ) -> StorageLiveLocals { let mut storage_live = IndexVec::from_elem(Set1::Empty, &body.local_decls); for local in always_storage_live_locals.iter() { - storage_live[local] = Set1::One(LocationExtended::Arg); + storage_live[local] = Set1::One(DefLocation::Argument); } for (block, bbdata) in body.basic_blocks.iter_enumerated() { for (statement_index, statement) in bbdata.statements.iter().enumerate() { if let StatementKind::StorageLive(local) = statement.kind { storage_live[local] - .insert(LocationExtended::Plain(Location { block, statement_index })); + .insert(DefLocation::Body(Location { block, statement_index })); } } } diff --git a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs index 092bcb5c9..98f67e18a 100644 --- a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs +++ b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs @@ -3,8 +3,7 @@ use crate::MirPass; use rustc_data_structures::fx::FxHashSet; use rustc_middle::mir::{ - BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, SwitchTargets, Terminator, - TerminatorKind, + BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, Terminator, TerminatorKind, }; use rustc_middle::ty::layout::TyAndLayout; use rustc_middle::ty::{Ty, TyCtxt}; @@ -30,18 +29,16 @@ fn get_switched_on_type<'tcx>( let terminator = block_data.terminator(); // Only bother checking blocks which terminate by switching on a local. - if let Some(local) = get_discriminant_local(&terminator.kind) { - let stmt_before_term = (!block_data.statements.is_empty()) - .then(|| &block_data.statements[block_data.statements.len() - 1].kind); - - if let Some(StatementKind::Assign(box (l, Rvalue::Discriminant(place)))) = stmt_before_term - { - if l.as_local() == Some(local) { - let ty = place.ty(body, tcx).ty; - if ty.is_enum() { - return Some(ty); - } - } + let local = get_discriminant_local(&terminator.kind)?; + + let stmt_before_term = block_data.statements.last()?; + + if let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind + && l.as_local() == Some(local) + { + let ty = place.ty(body, tcx).ty; + if ty.is_enum() { + return Some(ty); } } @@ -72,28 +69,6 @@ fn variant_discriminants<'tcx>( } } -/// Ensures that the `otherwise` branch leads to an unreachable bb, returning `None` if so and a new -/// bb to use as the new target if not. 
-fn ensure_otherwise_unreachable<'tcx>( - body: &Body<'tcx>, - targets: &SwitchTargets, -) -> Option<BasicBlockData<'tcx>> { - let otherwise = targets.otherwise(); - let bb = &body.basic_blocks[otherwise]; - if bb.terminator().kind == TerminatorKind::Unreachable - && bb.statements.iter().all(|s| matches!(&s.kind, StatementKind::StorageDead(_))) - { - return None; - } - - let mut new_block = BasicBlockData::new(Some(Terminator { - source_info: bb.terminator().source_info, - kind: TerminatorKind::Unreachable, - })); - new_block.is_cleanup = bb.is_cleanup; - Some(new_block) -} - impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { sess.mir_opt_level() > 0 @@ -102,13 +77,16 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { trace!("UninhabitedEnumBranching starting for {:?}", body.source); - for bb in body.basic_blocks.indices() { + let mut removable_switchs = Vec::new(); + + for (bb, bb_data) in body.basic_blocks.iter_enumerated() { trace!("processing block {:?}", bb); - let Some(discriminant_ty) = get_switched_on_type(&body.basic_blocks[bb], tcx, body) - else { + if bb_data.is_cleanup { continue; - }; + } + + let Some(discriminant_ty) = get_switched_on_type(&bb_data, tcx, body) else { continue }; let layout = tcx.layout_of( tcx.param_env_reveal_all_normalized(body.source.def_id()).and(discriminant_ty), @@ -122,31 +100,38 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching { trace!("allowed_variants = {:?}", allowed_variants); - if let TerminatorKind::SwitchInt { targets, .. } = - &mut body.basic_blocks_mut()[bb].terminator_mut().kind - { - let mut new_targets = SwitchTargets::new( - targets.iter().filter(|(val, _)| allowed_variants.contains(val)), - targets.otherwise(), - ); - - if new_targets.iter().count() == allowed_variants.len() { - if let Some(updated) = ensure_otherwise_unreachable(body, &new_targets) { - let new_otherwise = body.basic_blocks_mut().push(updated); - *new_targets.all_targets_mut().last_mut().unwrap() = new_otherwise; - } - } + let terminator = bb_data.terminator(); + let TerminatorKind::SwitchInt { targets, .. } = &terminator.kind else { bug!() }; - if let TerminatorKind::SwitchInt { targets, .. } = - &mut body.basic_blocks_mut()[bb].terminator_mut().kind - { - *targets = new_targets; + let mut reachable_count = 0; + for (index, (val, _)) in targets.iter().enumerate() { + if allowed_variants.contains(&val) { + reachable_count += 1; } else { - unreachable!() + removable_switchs.push((bb, index)); } - } else { - unreachable!() } + + if reachable_count == allowed_variants.len() { + removable_switchs.push((bb, targets.iter().count())); + } + } + + if removable_switchs.is_empty() { + return; + } + + let new_block = BasicBlockData::new(Some(Terminator { + source_info: body.basic_blocks[removable_switchs[0].0].terminator().source_info, + kind: TerminatorKind::Unreachable, + })); + let unreachable_block = body.basic_blocks.as_mut().push(new_block); + + for (bb, index) in removable_switchs { + let bb = &mut body.basic_blocks.as_mut()[bb]; + let terminator = bb.terminator_mut(); + let TerminatorKind::SwitchInt { targets, .. 
} = &mut terminator.kind else { bug!() };
+ targets.all_targets_mut()[index] = unreachable_block;
}
}
}
diff --git a/compiler/rustc_mir_transform/src/unreachable_prop.rs b/compiler/rustc_mir_transform/src/unreachable_prop.rs
index 0b9311a20..919e8d6a2 100644
--- a/compiler/rustc_mir_transform/src/unreachable_prop.rs
+++ b/compiler/rustc_mir_transform/src/unreachable_prop.rs
@@ -2,11 +2,13 @@
//! when all of their successors are unreachable. This is achieved through a
//! post-order traversal of the blocks.

-use crate::simplify;
use crate::MirPass;
-use rustc_data_structures::fx::{FxHashMap, FxHashSet};
+use rustc_data_structures::fx::FxHashSet;
+use rustc_middle::mir::interpret::Scalar;
+use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::*;
-use rustc_middle::ty::TyCtxt;
+use rustc_middle::ty::{self, TyCtxt};
+use rustc_target::abi::Size;

pub struct UnreachablePropagation;

@@ -21,106 +23,133 @@ impl MirPass<'_> for UnreachablePropagation {
}

fn run_pass<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ let mut patch = MirPatch::new(body);
let mut unreachable_blocks = FxHashSet::default();
- let mut replacements = FxHashMap::default();

for (bb, bb_data) in traversal::postorder(body) {
let terminator = bb_data.terminator();
- if terminator.kind == TerminatorKind::Unreachable {
- unreachable_blocks.insert(bb);
- } else {
- let is_unreachable = |succ: BasicBlock| unreachable_blocks.contains(&succ);
- let terminator_kind_opt = remove_successors(&terminator.kind, is_unreachable);
-
- if let Some(terminator_kind) = terminator_kind_opt {
- if terminator_kind == TerminatorKind::Unreachable {
- unreachable_blocks.insert(bb);
- }
- replacements.insert(bb, terminator_kind);
+ let is_unreachable = match &terminator.kind {
+ TerminatorKind::Unreachable => true,
+ // This will unconditionally run into an unreachable and is therefore unreachable as well.
+ TerminatorKind::Goto { target } if unreachable_blocks.contains(target) => {
+ patch.patch_terminator(bb, TerminatorKind::Unreachable);
+ true
+ }
+ // Try to remove unreachable targets from the switch.
+ TerminatorKind::SwitchInt { .. } => {
+ remove_successors_from_switch(tcx, bb, &unreachable_blocks, body, &mut patch)
}
+ _ => false,
+ };
+ if is_unreachable {
+ unreachable_blocks.insert(bb);
}
}

+ if !tcx
+ .consider_optimizing(|| format!("UnreachablePropagation {:?} ", body.source.def_id()))
+ {
+ return;
+ }
+
+ patch.apply(body);
+
+ // We do want to keep some unreachable blocks, but make them empty.
for bb in unreachable_blocks {
- if !tcx.consider_optimizing(|| {
- format!("UnreachablePropagation {:?} ", body.source.def_id())
- }) {
- break;
- }
-
body.basic_blocks_mut()[bb].statements.clear();
}
+ }
+}

- let replaced = !replacements.is_empty();
+/// Return whether the current terminator is fully unreachable.
+fn remove_successors_from_switch<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ bb: BasicBlock,
+ unreachable_blocks: &FxHashSet<BasicBlock>,
+ body: &Body<'tcx>,
+ patch: &mut MirPatch<'tcx>,
+) -> bool {
+ let terminator = body.basic_blocks[bb].terminator();
+ let TerminatorKind::SwitchInt { discr, targets } = &terminator.kind else { bug!() };
+ let source_info = terminator.source_info;
+ let location = body.terminator_loc(bb);
+
+ let is_unreachable = |bb| unreachable_blocks.contains(&bb);
+
+ // If there are multiple targets, we want to keep information about reachability for codegen.
+ // For example (see tests/codegen/match-optimizes-away.rs) + // + // pub enum Two { A, B } + // pub fn identity(x: Two) -> Two { + // match x { + // Two::A => Two::A, + // Two::B => Two::B, + // } + // } + // + // This generates a `switchInt() -> [0: 0, 1: 1, otherwise: unreachable]`, which allows us or LLVM to + // turn it into just `x` later. Without the unreachable, such a transformation would be illegal. + // + // In order to preserve this information, we record reachable and unreachable targets as + // `Assume` statements in MIR. + + let discr_ty = discr.ty(body, tcx); + let discr_size = Size::from_bits(match discr_ty.kind() { + ty::Uint(uint) => uint.normalize(tcx.sess.target.pointer_width).bit_width().unwrap(), + ty::Int(int) => int.normalize(tcx.sess.target.pointer_width).bit_width().unwrap(), + ty::Char => 32, + ty::Bool => 1, + other => bug!("unhandled type: {:?}", other), + }); + + let mut add_assumption = |binop, value| { + let local = patch.new_temp(tcx.types.bool, source_info.span); + let value = Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::from_scalar(tcx, Scalar::from_uint(value, discr_size), discr_ty), + })); + let cmp = Rvalue::BinaryOp(binop, Box::new((discr.to_copy(), value))); + patch.add_assign(location, local.into(), cmp); + + let assume = NonDivergingIntrinsic::Assume(Operand::Move(local.into())); + patch.add_statement(location, StatementKind::Intrinsic(Box::new(assume))); + }; - for (bb, terminator_kind) in replacements { - if !tcx.consider_optimizing(|| { - format!("UnreachablePropagation {:?} ", body.source.def_id()) - }) { - break; - } + let otherwise = targets.otherwise(); + let otherwise_unreachable = is_unreachable(otherwise); - body.basic_blocks_mut()[bb].terminator_mut().kind = terminator_kind; + let reachable_iter = targets.iter().filter(|&(value, bb)| { + let is_unreachable = is_unreachable(bb); + // We remove this target from the switch, so record the inequality using `Assume`. + if is_unreachable && !otherwise_unreachable { + add_assumption(BinOp::Ne, value); } - - if replaced { - simplify::remove_dead_blocks(tcx, body); + !is_unreachable + }); + + let new_targets = SwitchTargets::new(reachable_iter, otherwise); + + let num_targets = new_targets.all_targets().len(); + let fully_unreachable = num_targets == 1 && otherwise_unreachable; + + let terminator = match (num_targets, otherwise_unreachable) { + // If all targets are unreachable, we can be unreachable as well. + (1, true) => TerminatorKind::Unreachable, + (1, false) => TerminatorKind::Goto { target: otherwise }, + (2, true) => { + // All targets are unreachable except one. Record the equality, and make it a goto. + let (value, target) = new_targets.iter().next().unwrap(); + add_assumption(BinOp::Eq, value); + TerminatorKind::Goto { target } } - } -} - -fn remove_successors<'tcx, F>( - terminator_kind: &TerminatorKind<'tcx>, - is_unreachable: F, -) -> Option<TerminatorKind<'tcx>> -where - F: Fn(BasicBlock) -> bool, -{ - let terminator = match terminator_kind { - // This will unconditionally run into an unreachable and is therefore unreachable as well. - TerminatorKind::Goto { target } if is_unreachable(*target) => TerminatorKind::Unreachable, - TerminatorKind::SwitchInt { targets, discr } => { - let otherwise = targets.otherwise(); - - // If all targets are unreachable, we can be unreachable as well. 
- if targets.all_targets().iter().all(|bb| is_unreachable(*bb)) {
- TerminatorKind::Unreachable
- } else if is_unreachable(otherwise) {
- // If there are multiple targets, don't delete unreachable branches (like an unreachable otherwise)
- // unless otherwise is unreachable, in which case deleting a normal branch causes it to be merged with
- // the otherwise, keeping its unreachable.
- // This loses information about reachability, causing worse codegen.
- // For example (see tests/codegen/match-optimizes-away.rs)
- //
- // pub enum Two { A, B }
- // pub fn identity(x: Two) -> Two {
- // match x {
- // Two::A => Two::A,
- // Two::B => Two::B,
- // }
- // }
- //
- // This generates a `switchInt() -> [0: 0, 1: 1, otherwise: unreachable]`, which allows us or LLVM to
- // turn it into just `x` later. Without the unreachable, such a transformation would be illegal.
- // If the otherwise branch is unreachable, we can delete all other unreachable targets, as they will
- // still point to the unreachable and therefore not lose reachability information.
- let reachable_iter = targets.iter().filter(|(_, bb)| !is_unreachable(*bb));
-
- let new_targets = SwitchTargets::new(reachable_iter, otherwise);
-
- // No unreachable branches were removed.
- if new_targets.all_targets().len() == targets.all_targets().len() {
- return None;
- }
-
- TerminatorKind::SwitchInt { discr: discr.clone(), targets: new_targets }
- } else {
- // If the otherwise branch is reachable, we don't want to delete any unreachable branches.
- return None;
- }
+ _ if num_targets == targets.all_targets().len() => {
+ // Nothing has changed.
+ return false;
+ }
- _ => return None,
+ _ => TerminatorKind::SwitchInt { discr: discr.clone(), targets: new_targets },
};
- Some(terminator)
+
+ patch.patch_terminator(bb, terminator);
+ fully_unreachable
}
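The decision table at the end of `remove_successors_from_switch` can be modeled in miniature. The sketch below uses toy stand-ins for the MIR types; the real pass emits the recorded facts as `Assume` intrinsics through a `MirPatch`:

#[derive(Debug, PartialEq)]
enum Assumption {
    DiscrEq(u128), // switch collapsed to a goto on the single live arm
    DiscrNe(u128), // a dead arm was deleted while `otherwise` stays live
}

/// Prune unreachable switch arms, recording the facts the deleted arms encoded.
/// Returns (kept arm values, recorded assumptions, whole-switch-unreachable).
fn prune_switch(
    arms: &[(u128, bool)], // (discriminant value, target is reachable)
    otherwise_reachable: bool,
) -> (Vec<u128>, Vec<Assumption>, bool) {
    let mut kept = Vec::new();
    let mut assumptions = Vec::new();
    for &(value, reachable) in arms {
        if reachable {
            kept.push(value);
        } else if otherwise_reachable {
            // Deleting this arm merges it into a *live* otherwise, so the
            // "discr != value" fact must be kept explicitly.
            assumptions.push(Assumption::DiscrNe(value));
        }
        // With an unreachable otherwise, dead arms merge into it and
        // reachability information is preserved for free.
    }
    let fully_unreachable = kept.is_empty() && !otherwise_reachable;
    if kept.len() == 1 && !otherwise_reachable {
        // One live arm left: the switch becomes a goto, plus "discr == value".
        assumptions.push(Assumption::DiscrEq(kept[0]));
    }
    (kept, assumptions, fully_unreachable)
}

fn main() {
    // switchInt -> [0: live, 1: dead, otherwise: dead], as in `identity` above.
    let (kept, assumed, dead) = prune_switch(&[(0, true), (1, false)], false);
    assert_eq!(kept, vec![0]);
    assert_eq!(assumed, vec![Assumption::DiscrEq(0)]);
    assert!(!dead);
}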