Diffstat (limited to 'compiler/rustc_mir_transform/src')
-rw-r--r--  compiler/rustc_mir_transform/src/abort_unwinding_calls.rs | 4
-rw-r--r--  compiler/rustc_mir_transform/src/check_alignment.rs | 105
-rw-r--r--  compiler/rustc_mir_transform/src/check_const_item_mutation.rs | 16
-rw-r--r--  compiler/rustc_mir_transform/src/check_packed_ref.rs | 9
-rw-r--r--  compiler/rustc_mir_transform/src/check_unsafety.rs | 9
-rw-r--r--  compiler/rustc_mir_transform/src/const_debuginfo.rs | 4
-rw-r--r--  compiler/rustc_mir_transform/src/const_prop.rs | 47
-rw-r--r--  compiler/rustc_mir_transform/src/const_prop_lint.rs | 68
-rw-r--r--  compiler/rustc_mir_transform/src/copy_prop.rs | 5
-rw-r--r--  compiler/rustc_mir_transform/src/coroutine.rs (renamed from compiler/rustc_mir_transform/src/generator.rs) | 529
-rw-r--r--  compiler/rustc_mir_transform/src/cost_checker.rs | 98
-rw-r--r--  compiler/rustc_mir_transform/src/coverage/counters.rs | 437
-rw-r--r--  compiler/rustc_mir_transform/src/coverage/graph.rs | 283
-rw-r--r--  compiler/rustc_mir_transform/src/coverage/mod.rs | 292
-rw-r--r--  compiler/rustc_mir_transform/src/coverage/query.rs | 107
-rw-r--r--  compiler/rustc_mir_transform/src/coverage/spans.rs | 597
-rw-r--r--  compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs | 193
-rw-r--r--  compiler/rustc_mir_transform/src/coverage/tests.rs | 44
-rw-r--r--  compiler/rustc_mir_transform/src/cross_crate_inline.rs | 130
-rw-r--r--  compiler/rustc_mir_transform/src/dataflow_const_prop.rs | 219
-rw-r--r--  compiler/rustc_mir_transform/src/dead_store_elimination.rs | 22
-rw-r--r--  compiler/rustc_mir_transform/src/deduce_param_attrs.rs | 1
-rw-r--r--  compiler/rustc_mir_transform/src/deref_separator.rs | 2
-rw-r--r--  compiler/rustc_mir_transform/src/dest_prop.rs | 6
-rw-r--r--  compiler/rustc_mir_transform/src/early_otherwise_branch.rs | 1
-rw-r--r--  compiler/rustc_mir_transform/src/elaborate_box_derefs.rs | 2
-rw-r--r--  compiler/rustc_mir_transform/src/elaborate_drops.rs | 186
-rw-r--r--  compiler/rustc_mir_transform/src/ffi_unwind_calls.rs | 2
-rw-r--r--  compiler/rustc_mir_transform/src/gvn.rs | 795
-rw-r--r--  compiler/rustc_mir_transform/src/inline.rs | 135
-rw-r--r--  compiler/rustc_mir_transform/src/instsimplify.rs | 18
-rw-r--r--  compiler/rustc_mir_transform/src/jump_threading.rs | 759
-rw-r--r--  compiler/rustc_mir_transform/src/large_enums.rs | 3
-rw-r--r--  compiler/rustc_mir_transform/src/lib.rs | 130
-rw-r--r--  compiler/rustc_mir_transform/src/lower_intrinsics.rs | 57
-rw-r--r--  compiler/rustc_mir_transform/src/lower_slice_len.rs | 78
-rw-r--r--  compiler/rustc_mir_transform/src/multiple_return_terminators.rs | 2
-rw-r--r--  compiler/rustc_mir_transform/src/normalize_array_len.rs | 4
-rw-r--r--  compiler/rustc_mir_transform/src/nrvo.rs | 2
-rw-r--r--  compiler/rustc_mir_transform/src/pass_manager.rs | 33
-rw-r--r--  compiler/rustc_mir_transform/src/ref_prop.rs | 9
-rw-r--r--  compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs | 2
-rw-r--r--  compiler/rustc_mir_transform/src/remove_uninit_drops.rs | 7
-rw-r--r--  compiler/rustc_mir_transform/src/remove_zsts.rs | 9
-rw-r--r--  compiler/rustc_mir_transform/src/separate_const_switch.rs | 4
-rw-r--r--  compiler/rustc_mir_transform/src/shim.rs | 30
-rw-r--r--  compiler/rustc_mir_transform/src/simplify.rs | 105
-rw-r--r--  compiler/rustc_mir_transform/src/simplify_branches.rs | 19
-rw-r--r--  compiler/rustc_mir_transform/src/sroa.rs | 8
-rw-r--r--  compiler/rustc_mir_transform/src/ssa.rs | 163
-rw-r--r--  compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs | 107
-rw-r--r--  compiler/rustc_mir_transform/src/unreachable_prop.rs | 201
52 files changed, 3652 insertions, 2446 deletions
diff --git a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs
index 4500bb7ff..2b3d423ea 100644
--- a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs
+++ b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs
@@ -40,7 +40,7 @@ impl<'tcx> MirPass<'tcx> for AbortUnwindingCalls {
let body_abi = match body_ty.kind() {
ty::FnDef(..) => body_ty.fn_sig(tcx).abi(),
ty::Closure(..) => Abi::RustCall,
- ty::Generator(..) => Abi::Rust,
+ ty::Coroutine(..) => Abi::Rust,
_ => span_bug!(body.span, "unexpected body ty: {:?}", body_ty),
};
let body_can_unwind = layout::fn_can_unwind(tcx, Some(def_id), body_abi);
@@ -113,6 +113,6 @@ impl<'tcx> MirPass<'tcx> for AbortUnwindingCalls {
}
// We may have invalidated some `cleanup` blocks so clean those up now.
- super::simplify::remove_dead_blocks(tcx, body);
+ super::simplify::remove_dead_blocks(body);
}
}
diff --git a/compiler/rustc_mir_transform/src/check_alignment.rs b/compiler/rustc_mir_transform/src/check_alignment.rs
index 28765af20..42b2f1886 100644
--- a/compiler/rustc_mir_transform/src/check_alignment.rs
+++ b/compiler/rustc_mir_transform/src/check_alignment.rs
@@ -1,13 +1,12 @@
use crate::MirPass;
-use rustc_hir::def_id::DefId;
use rustc_hir::lang_items::LangItem;
use rustc_index::IndexVec;
use rustc_middle::mir::*;
use rustc_middle::mir::{
interpret::Scalar,
- visit::{PlaceContext, Visitor},
+ visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor},
};
-use rustc_middle::ty::{Ty, TyCtxt, TypeAndMut};
+use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt, TypeAndMut};
use rustc_session::Session;
pub struct CheckAlignment;
@@ -30,7 +29,12 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment {
let basic_blocks = body.basic_blocks.as_mut();
let local_decls = &mut body.local_decls;
+ let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
+ // This pass inserts new blocks. Each insertion changes the Location for all
+ // statements/blocks after. Iterating or visiting the MIR in order would require updating
+ // our current location after every insertion. By iterating backwards, we dodge this issue:
+ // The only Locations that an insertion changes have already been handled.
for block in (0..basic_blocks.len()).rev() {
let block = block.into();
for statement_index in (0..basic_blocks[block].statements.len()).rev() {
@@ -38,22 +42,19 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment {
let statement = &basic_blocks[block].statements[statement_index];
let source_info = statement.source_info;
- let mut finder = PointerFinder {
- local_decls,
- tcx,
- pointers: Vec::new(),
- def_id: body.source.def_id(),
- };
- for (pointer, pointee_ty) in finder.find_pointers(statement) {
- debug!("Inserting alignment check for {:?}", pointer.ty(&*local_decls, tcx).ty);
+ let mut finder =
+ PointerFinder { tcx, local_decls, param_env, pointers: Vec::new() };
+ finder.visit_statement(statement, location);
+ for (local, ty) in finder.pointers {
+ debug!("Inserting alignment check for {:?}", ty);
let new_block = split_block(basic_blocks, location);
insert_alignment_check(
tcx,
local_decls,
&mut basic_blocks[block],
- pointer,
- pointee_ty,
+ local,
+ ty,
source_info,
new_block,
);
@@ -63,69 +64,71 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment {
}
}
-impl<'tcx, 'a> PointerFinder<'tcx, 'a> {
- fn find_pointers(&mut self, statement: &Statement<'tcx>) -> Vec<(Place<'tcx>, Ty<'tcx>)> {
- self.pointers.clear();
- self.visit_statement(statement, Location::START);
- core::mem::take(&mut self.pointers)
- }
-}
-
struct PointerFinder<'tcx, 'a> {
- local_decls: &'a mut LocalDecls<'tcx>,
tcx: TyCtxt<'tcx>,
- def_id: DefId,
+ local_decls: &'a mut LocalDecls<'tcx>,
+ param_env: ParamEnv<'tcx>,
pointers: Vec<(Place<'tcx>, Ty<'tcx>)>,
}
impl<'tcx, 'a> Visitor<'tcx> for PointerFinder<'tcx, 'a> {
- fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
- if let Rvalue::AddressOf(..) = rvalue {
- // Ignore dereferences inside of an AddressOf
- return;
+ fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
+ // We want to only check reads and writes to Places, so we specifically exclude
+ // Borrows and AddressOf.
+ match context {
+ PlaceContext::MutatingUse(
+ MutatingUseContext::Store
+ | MutatingUseContext::AsmOutput
+ | MutatingUseContext::Call
+ | MutatingUseContext::Yield
+ | MutatingUseContext::Drop,
+ ) => {}
+ PlaceContext::NonMutatingUse(
+ NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
+ ) => {}
+ _ => {
+ return;
+ }
}
- self.super_rvalue(rvalue, location);
- }
- fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) {
- if let PlaceContext::NonUse(_) = context {
- return;
- }
if !place.is_indirect() {
return;
}
+ // Since Deref projections must come first and only once, the pointer for an indirect place
+ // is the Local that the Place is based on.
let pointer = Place::from(place.local);
- let pointer_ty = pointer.ty(&*self.local_decls, self.tcx).ty;
+ let pointer_ty = self.local_decls[place.local].ty;
- // We only want to check unsafe pointers
+ // We only want to check places based on unsafe pointers
if !pointer_ty.is_unsafe_ptr() {
- trace!("Indirect, but not an unsafe ptr, not checking {:?}", pointer_ty);
+ trace!("Indirect, but not based on an unsafe ptr, not checking {:?}", place);
return;
}
- let Some(pointee) = pointer_ty.builtin_deref(true) else {
- debug!("Indirect but no builtin deref: {:?}", pointer_ty);
+ let pointee_ty =
+ pointer_ty.builtin_deref(true).expect("no builtin_deref for an unsafe pointer").ty;
+ // Ideally we'd support this in the future, but for now we are limited to sized types.
+ if !pointee_ty.is_sized(self.tcx, self.param_env) {
+ debug!("Unsafe pointer, but pointee is not known to be sized: {:?}", pointer_ty);
return;
- };
- let mut pointee_ty = pointee.ty;
- if pointee_ty.is_array() || pointee_ty.is_slice() || pointee_ty.is_str() {
- pointee_ty = pointee_ty.sequence_element_type(self.tcx);
}
- if !pointee_ty.is_sized(self.tcx, self.tcx.param_env_reveal_all_normalized(self.def_id)) {
- debug!("Unsafe pointer, but unsized: {:?}", pointer_ty);
+ // Try to detect types we are sure have an alignment of 1 and skip the check
+ // We don't need to look for str and slices, we already rejected unsized types above
+ let element_ty = match pointee_ty.kind() {
+ ty::Array(ty, _) => *ty,
+ _ => pointee_ty,
+ };
+ if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8].contains(&element_ty) {
+ debug!("Trivially aligned place type: {:?}", pointee_ty);
return;
}
- if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8, self.tcx.types.str_]
- .contains(&pointee_ty)
- {
- debug!("Trivially aligned pointee type: {:?}", pointer_ty);
- return;
- }
+ // Ensure that this place is based on an aligned pointer.
+ self.pointers.push((pointer, pointee_ty));
- self.pointers.push((pointer, pointee_ty))
+ self.super_place(place, context, location);
}
}
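[editor's note] The rewritten PointerFinder above collects reads and writes through raw pointers to sized pointees, skips pointees that are trivially aligned (bool, i8, u8), and the pass then splits the block and inserts an alignment check before the access. A minimal source-level sketch of the property being enforced follows; it is illustrative only, since the real check is emitted as MIR, not written as library code:

```rust
use std::mem::align_of;

/// Sketch: read through a raw pointer, but first assert the address is a
/// multiple of the pointee's alignment -- roughly the condition the inserted
/// MIR check verifies before a dereference of an unsafe pointer.
fn read_checked<T: Copy>(ptr: *const T) -> T {
    let align = align_of::<T>();
    // The pass skips pointees with alignment 1 and unsized pointees;
    // this sketch checks unconditionally.
    assert!(ptr as usize % align == 0, "misaligned pointer dereference");
    unsafe { ptr.read() }
}

fn main() {
    let x: u32 = 7;
    assert_eq!(read_checked(&x as *const u32), 7);
}
```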
diff --git a/compiler/rustc_mir_transform/src/check_const_item_mutation.rs b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs
index b79150737..61bf530f1 100644
--- a/compiler/rustc_mir_transform/src/check_const_item_mutation.rs
+++ b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs
@@ -97,13 +97,15 @@ impl<'tcx> Visitor<'tcx> for ConstMutationChecker<'_, 'tcx> {
// so emitting a lint would be redundant.
if !lhs.projection.is_empty() {
if let Some(def_id) = self.is_const_item_without_destructor(lhs.local)
- && let Some((lint_root, span, item)) = self.should_lint_const_item_usage(&lhs, def_id, loc) {
- self.tcx.emit_spanned_lint(
- CONST_ITEM_MUTATION,
- lint_root,
- span,
- errors::ConstMutate::Modify { konst: item }
- );
+ && let Some((lint_root, span, item)) =
+ self.should_lint_const_item_usage(&lhs, def_id, loc)
+ {
+ self.tcx.emit_spanned_lint(
+ CONST_ITEM_MUTATION,
+ lint_root,
+ span,
+ errors::ConstMutate::Modify { konst: item },
+ );
}
}
// We are looking for MIR of the form:
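[editor's note] For reference, the CONST_ITEM_MUTATION lint emitted in the hunk above warns about writes through a projection of a `const` item: each use of the constant materializes a fresh temporary, so the write changes that temporary rather than the constant. A hedged example of code that triggers it (exact diagnostic wording may differ by compiler version):

```rust
const FOO: [i32; 1] = [0];

fn main() {
    // This assigns into a temporary copy of `FOO`, not the const itself.
    FOO[0] = 42; // warning: attempting to modify a `const` item
    println!("{:?}", FOO); // still prints "[0]"
}
```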
diff --git a/compiler/rustc_mir_transform/src/check_packed_ref.rs b/compiler/rustc_mir_transform/src/check_packed_ref.rs
index 2e6cf603d..9ee0a7040 100644
--- a/compiler/rustc_mir_transform/src/check_packed_ref.rs
+++ b/compiler/rustc_mir_transform/src/check_packed_ref.rs
@@ -46,9 +46,14 @@ impl<'tcx> Visitor<'tcx> for PackedRefChecker<'_, 'tcx> {
// If we ever reach here it means that the generated derive
// code is somehow doing an unaligned reference, which it
// shouldn't do.
- span_bug!(self.source_info.span, "builtin derive created an unaligned reference");
+ span_bug!(
+ self.source_info.span,
+ "builtin derive created an unaligned reference"
+ );
} else {
- self.tcx.sess.emit_err(errors::UnalignedPackedRef { span: self.source_info.span });
+ self.tcx
+ .sess
+ .emit_err(errors::UnalignedPackedRef { span: self.source_info.span });
}
}
}
diff --git a/compiler/rustc_mir_transform/src/check_unsafety.rs b/compiler/rustc_mir_transform/src/check_unsafety.rs
index bacabc62e..8872f9a97 100644
--- a/compiler/rustc_mir_transform/src/check_unsafety.rs
+++ b/compiler/rustc_mir_transform/src/check_unsafety.rs
@@ -56,7 +56,7 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> {
| TerminatorKind::Drop { .. }
| TerminatorKind::Yield { .. }
| TerminatorKind::Assert { .. }
- | TerminatorKind::GeneratorDrop
+ | TerminatorKind::CoroutineDrop
| TerminatorKind::UnwindResume
| TerminatorKind::UnwindTerminate(_)
| TerminatorKind::Return
@@ -128,7 +128,7 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> {
),
}
}
- &AggregateKind::Closure(def_id, _) | &AggregateKind::Generator(def_id, _, _) => {
+ &AggregateKind::Closure(def_id, _) | &AggregateKind::Coroutine(def_id, _, _) => {
let def_id = def_id.expect_local();
let UnsafetyCheckResult { violations, used_unsafe_blocks, .. } =
self.tcx.unsafety_check_result(def_id);
@@ -179,7 +179,7 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> {
// Check the base local: it might be an unsafe-to-access static. We only check derefs of the
// temporary holding the static pointer to avoid duplicate errors
// <https://github.com/rust-lang/rust/pull/78068#issuecomment-731753506>.
- if decl.internal && place.projection.first() == Some(&ProjectionElem::Deref) {
+ if place.projection.first() == Some(&ProjectionElem::Deref) {
// If the projection root is an artificial local that we introduced when
// desugaring `static`, give a more specific error message
// (avoid the general "raw pointer" clause below, that would only be confusing).
@@ -540,8 +540,7 @@ pub fn check_unsafety(tcx: TyCtxt<'_>, def_id: LocalDefId) {
&& let BlockCheckMode::UnsafeBlock(_) = block.rules
{
true
- }
- else if let Some(sig) = tcx.hir().fn_sig_by_hir_id(*id)
+ } else if let Some(sig) = tcx.hir().fn_sig_by_hir_id(*id)
&& sig.header.is_unsafe()
{
true
diff --git a/compiler/rustc_mir_transform/src/const_debuginfo.rs b/compiler/rustc_mir_transform/src/const_debuginfo.rs
index 40cd28254..e4e4270c4 100644
--- a/compiler/rustc_mir_transform/src/const_debuginfo.rs
+++ b/compiler/rustc_mir_transform/src/const_debuginfo.rs
@@ -55,7 +55,9 @@ fn find_optimization_opportunities<'tcx>(body: &Body<'tcx>) -> Vec<(Local, Const
let mut locals_to_debuginfo = BitSet::new_empty(body.local_decls.len());
for debuginfo in &body.var_debug_info {
- if let VarDebugInfoContents::Place(p) = debuginfo.value && let Some(l) = p.as_local() {
+ if let VarDebugInfoContents::Place(p) = debuginfo.value
+ && let Some(l) = p.as_local()
+ {
locals_to_debuginfo.insert(l);
}
}
diff --git a/compiler/rustc_mir_transform/src/const_prop.rs b/compiler/rustc_mir_transform/src/const_prop.rs
index 50443e739..f7f882310 100644
--- a/compiler/rustc_mir_transform/src/const_prop.rs
+++ b/compiler/rustc_mir_transform/src/const_prop.rs
@@ -2,8 +2,6 @@
//! assertion failures
use either::Right;
-
-use rustc_const_eval::const_eval::CheckAlignment;
use rustc_const_eval::ReportErrorExt;
use rustc_data_structures::fx::FxHashSet;
use rustc_hir::def::DefKind;
@@ -16,7 +14,7 @@ use rustc_middle::mir::*;
use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout};
use rustc_middle::ty::{self, GenericArgs, Instance, ParamEnv, Ty, TyCtxt, TypeVisitableExt};
use rustc_span::{def_id::DefId, Span};
-use rustc_target::abi::{self, Align, HasDataLayout, Size, TargetDataLayout};
+use rustc_target::abi::{self, HasDataLayout, Size, TargetDataLayout};
use rustc_target::spec::abi::Abi as CallAbi;
use crate::dataflow_const_prop::Patch;
@@ -84,11 +82,11 @@ impl<'tcx> MirPass<'tcx> for ConstProp {
return;
}
- // FIXME(welseywiser) const prop doesn't work on generators because of query cycles
+ // FIXME(welseywiser) const prop doesn't work on coroutines because of query cycles
// computing their layout.
- let is_generator = def_kind == DefKind::Generator;
- if is_generator {
- trace!("ConstProp skipped for generator {:?}", def_id);
+ let is_coroutine = def_kind == DefKind::Coroutine;
+ if is_coroutine {
+ trace!("ConstProp skipped for coroutine {:?}", def_id);
return;
}
@@ -141,27 +139,14 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx>
type MemoryKind = !;
#[inline(always)]
- fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> CheckAlignment {
- // We do not check for alignment to avoid having to carry an `Align`
- // in `ConstValue::Indirect`.
- CheckAlignment::No
+ fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> bool {
+ false // no reason to enforce alignment
}
#[inline(always)]
fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {
false // for now, we don't enforce validity
}
- fn alignment_check_failed(
- ecx: &InterpCx<'mir, 'tcx, Self>,
- _has: Align,
- _required: Align,
- _check: CheckAlignment,
- ) -> InterpResult<'tcx, ()> {
- span_bug!(
- ecx.cur_span(),
- "`alignment_check_failed` called when no alignment check requested"
- )
- }
fn load_mir(
_ecx: &InterpCx<'mir, 'tcx, Self>,
@@ -455,6 +440,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
// FIXME we need to revisit this for #67176
if rvalue.has_param() {
+ trace!("skipping, has param");
return None;
}
if !rvalue
@@ -527,7 +513,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
fn replace_with_const(&mut self, place: Place<'tcx>) -> Option<Const<'tcx>> {
// This will return None if the above `const_prop` invocation only "wrote" a
- // type whose creation requires no write. E.g. a generator whose initial state
+ // type whose creation requires no write. E.g. a coroutine whose initial state
// consists solely of uninitialized memory (so it doesn't capture any locals).
let value = self.get_const(place)?;
if !self.tcx.consider_optimizing(|| format!("ConstantPropagation - {value:?}")) {
@@ -699,7 +685,9 @@ impl<'tcx> Visitor<'tcx> for CanConstProp {
impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> {
fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) {
self.super_operand(operand, location);
- if let Some(place) = operand.place() && let Some(value) = self.replace_with_const(place) {
+ if let Some(place) = operand.place()
+ && let Some(value) = self.replace_with_const(place)
+ {
self.patch.before_effect.insert((location, place), value);
}
}
@@ -721,7 +709,11 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> {
fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, location: Location) {
self.super_assign(place, rvalue, location);
- let Some(()) = self.check_rvalue(rvalue) else { return };
+ let Some(()) = self.check_rvalue(rvalue) else {
+ trace!("rvalue check failed, removing const");
+ Self::remove_const(&mut self.ecx, place.local);
+ return;
+ };
match self.ecx.machine.can_const_prop[place.local] {
// Do nothing if the place is indirect.
@@ -733,7 +725,10 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> {
if let Rvalue::Use(Operand::Constant(c)) = rvalue
&& let Const::Val(..) = c.const_
{
- trace!("skipping replace of Rvalue::Use({:?} because it is already a const", c);
+ trace!(
+ "skipping replace of Rvalue::Use({:?} because it is already a const",
+ c
+ );
} else if let Some(operand) = self.replace_with_const(*place) {
self.patch.assignments.insert(location, operand);
}
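[editor's note] As context for the hunks above (the simplified machine hooks and the `replace_with_const` path), ConstProp folds MIR operations whose operands are statically known into constants. A hedged, source-level illustration of the kind of computation it can fold; this is not the pass's actual output:

```rust
fn main() {
    // With all operands known, the addition and the array index below can be
    // evaluated at compile time and written back into the MIR as constants.
    let x = 2 + 3;
    let y = [10, 20, 30][1];
    assert_eq!(x + y, 25);
}
```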
diff --git a/compiler/rustc_mir_transform/src/const_prop_lint.rs b/compiler/rustc_mir_transform/src/const_prop_lint.rs
index 64e262c6c..a23ba9c4a 100644
--- a/compiler/rustc_mir_transform/src/const_prop_lint.rs
+++ b/compiler/rustc_mir_transform/src/const_prop_lint.rs
@@ -22,7 +22,6 @@ use rustc_middle::ty::{
};
use rustc_span::Span;
use rustc_target::abi::{HasDataLayout, Size, TargetDataLayout};
-use rustc_trait_selection::traits;
use crate::const_prop::CanConstProp;
use crate::const_prop::ConstPropMachine;
@@ -35,9 +34,9 @@ use crate::MirLint;
/// Severely regress performance.
const MAX_ALLOC_LIMIT: u64 = 1024;
-pub struct ConstProp;
+pub struct ConstPropLint;
-impl<'tcx> MirLint<'tcx> for ConstProp {
+impl<'tcx> MirLint<'tcx> for ConstPropLint {
fn run_lint(&self, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
if body.tainted_by_errors.is_some() {
return;
@@ -49,61 +48,25 @@ impl<'tcx> MirLint<'tcx> for ConstProp {
}
let def_id = body.source.def_id().expect_local();
- let is_fn_like = tcx.def_kind(def_id).is_fn_like();
- let is_assoc_const = tcx.def_kind(def_id) == DefKind::AssocConst;
+ let def_kind = tcx.def_kind(def_id);
+ let is_fn_like = def_kind.is_fn_like();
+ let is_assoc_const = def_kind == DefKind::AssocConst;
// Only run const prop on functions, methods, closures and associated constants
if !is_fn_like && !is_assoc_const {
// skip anon_const/statics/consts because they'll be evaluated by miri anyway
- trace!("ConstProp skipped for {:?}", def_id);
+ trace!("ConstPropLint skipped for {:?}", def_id);
return;
}
- let is_generator = tcx.type_of(def_id.to_def_id()).instantiate_identity().is_generator();
- // FIXME(welseywiser) const prop doesn't work on generators because of query cycles
+ // FIXME(welseywiser) const prop doesn't work on coroutines because of query cycles
// computing their layout.
- if is_generator {
- trace!("ConstProp skipped for generator {:?}", def_id);
+ if let DefKind::Coroutine = def_kind {
+ trace!("ConstPropLint skipped for coroutine {:?}", def_id);
return;
}
- // Check if it's even possible to satisfy the 'where' clauses
- // for this item.
- // This branch will never be taken for any normal function.
- // However, it's possible to `#!feature(trivial_bounds)]` to write
- // a function with impossible to satisfy clauses, e.g.:
- // `fn foo() where String: Copy {}`
- //
- // We don't usually need to worry about this kind of case,
- // since we would get a compilation error if the user tried
- // to call it. However, since we can do const propagation
- // even without any calls to the function, we need to make
- // sure that it even makes sense to try to evaluate the body.
- // If there are unsatisfiable where clauses, then all bets are
- // off, and we just give up.
- //
- // We manually filter the predicates, skipping anything that's not
- // "global". We are in a potentially generic context
- // (e.g. we are evaluating a function without substituting generic
- // parameters, so this filtering serves two purposes:
- //
- // 1. We skip evaluating any predicates that we would
- // never be able prove are unsatisfiable (e.g. `<T as Foo>`
- // 2. We avoid trying to normalize predicates involving generic
- // parameters (e.g. `<T as Foo>::MyItem`). This can confuse
- // the normalization code (leading to cycle errors), since
- // it's usually never invoked in this way.
- let predicates = tcx
- .predicates_of(def_id.to_def_id())
- .predicates
- .iter()
- .filter_map(|(p, _)| if p.is_global() { Some(*p) } else { None });
- if traits::impossible_predicates(tcx, traits::elaborate(tcx, predicates).collect()) {
- trace!("ConstProp skipped for {:?}: found unsatisfiable predicates", def_id);
- return;
- }
-
- trace!("ConstProp starting for {:?}", def_id);
+ trace!("ConstPropLint starting for {:?}", def_id);
// FIXME(oli-obk, eddyb) Optimize locals (or even local paths) to hold
// constants, instead of just checking for const-folding succeeding.
@@ -112,7 +75,7 @@ impl<'tcx> MirLint<'tcx> for ConstProp {
let mut linter = ConstPropagator::new(body, tcx);
linter.visit_body(body);
- trace!("ConstProp done for {:?}", def_id);
+ trace!("ConstPropLint done for {:?}", def_id);
}
}
@@ -664,9 +627,10 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> {
}
TerminatorKind::SwitchInt { ref discr, ref targets } => {
if let Some(ref value) = self.eval_operand(&discr, location)
- && let Some(value_const) = self.use_ecx(location, |this| this.ecx.read_scalar(value))
- && let Ok(constant) = value_const.try_to_int()
- && let Ok(constant) = constant.to_bits(constant.size())
+ && let Some(value_const) =
+ self.use_ecx(location, |this| this.ecx.read_scalar(value))
+ && let Ok(constant) = value_const.try_to_int()
+ && let Ok(constant) = constant.to_bits(constant.size())
{
// We managed to evaluate the discriminant, so we know we only need to visit
// one target.
@@ -684,7 +648,7 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> {
| TerminatorKind::Unreachable
| TerminatorKind::Drop { .. }
| TerminatorKind::Yield { .. }
- | TerminatorKind::GeneratorDrop
+ | TerminatorKind::CoroutineDrop
| TerminatorKind::FalseEdge { .. }
| TerminatorKind::FalseUnwind { .. }
| TerminatorKind::Call { .. }
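[editor's note] The renamed ConstPropLint pass exists to surface operations that are guaranteed to panic or overflow once constants are propagated. A hedged example of code it is meant to flag; the lints involved are deny-by-default, so this example intentionally fails to compile, and the exact wording depends on the compiler version:

```rust
fn main() {
    let x: u8 = 255;
    let _y = x + 1; // flagged: this arithmetic operation will overflow
                    // (deny-by-default `arithmetic_overflow` lint)

    let xs = [1, 2, 3];
    let i = 4;
    let _z = xs[i]; // flagged: this operation will panic at runtime
                    // (`unconditional_panic` lint: index out of bounds)
}
```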
diff --git a/compiler/rustc_mir_transform/src/copy_prop.rs b/compiler/rustc_mir_transform/src/copy_prop.rs
index 9c38a6f81..f5db7ce97 100644
--- a/compiler/rustc_mir_transform/src/copy_prop.rs
+++ b/compiler/rustc_mir_transform/src/copy_prop.rs
@@ -168,14 +168,15 @@ impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> {
&& self.storage_to_remove.contains(l)
{
stmt.make_nop();
- return
+ return;
}
self.super_statement(stmt, loc);
// Do not leave tautological assignments around.
if let StatementKind::Assign(box (lhs, ref rhs)) = stmt.kind
- && let Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)) | Rvalue::CopyForDeref(rhs) = *rhs
+ && let Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)) | Rvalue::CopyForDeref(rhs) =
+ *rhs
&& lhs == rhs
{
stmt.make_nop();
diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/coroutine.rs
index e261b8ac2..abaed103f 100644
--- a/compiler/rustc_mir_transform/src/generator.rs
+++ b/compiler/rustc_mir_transform/src/coroutine.rs
@@ -1,53 +1,53 @@
-//! This is the implementation of the pass which transforms generators into state machines.
+//! This is the implementation of the pass which transforms coroutines into state machines.
//!
-//! MIR generation for generators creates a function which has a self argument which
-//! passes by value. This argument is effectively a generator type which only contains upvars and
-//! is only used for this argument inside the MIR for the generator.
+//! MIR generation for coroutines creates a function which has a self argument which
+//! passes by value. This argument is effectively a coroutine type which only contains upvars and
+//! is only used for this argument inside the MIR for the coroutine.
//! It is passed by value to enable upvars to be moved out of it. Drop elaboration runs on that
//! MIR before this pass and creates drop flags for MIR locals.
-//! It will also drop the generator argument (which only consists of upvars) if any of the upvars
-//! are moved out of. This pass elaborates the drops of upvars / generator argument in the case
+//! It will also drop the coroutine argument (which only consists of upvars) if any of the upvars
+//! are moved out of. This pass elaborates the drops of upvars / coroutine argument in the case
//! that none of the upvars were moved out of. This is because we cannot have any drops of this
-//! generator in the MIR, since it is used to create the drop glue for the generator. We'd get
+//! coroutine in the MIR, since it is used to create the drop glue for the coroutine. We'd get
//! infinite recursion otherwise.
//!
-//! This pass creates the implementation for either the `Generator::resume` or `Future::poll`
-//! function and the drop shim for the generator based on the MIR input.
-//! It converts the generator argument from Self to &mut Self adding derefs in the MIR as needed.
-//! It computes the final layout of the generator struct which looks like this:
+//! This pass creates the implementation for either the `Coroutine::resume` or `Future::poll`
+//! function and the drop shim for the coroutine based on the MIR input.
+//! It converts the coroutine argument from Self to &mut Self adding derefs in the MIR as needed.
+//! It computes the final layout of the coroutine struct which looks like this:
//! First upvars are stored
-//! It is followed by the generator state field.
+//! It is followed by the coroutine state field.
//! Then finally the MIR locals which are live across a suspension point are stored.
//! ```ignore (illustrative)
-//! struct Generator {
+//! struct Coroutine {
//! upvars...,
//! state: u32,
//! mir_locals...,
//! }
//! ```
//! This pass computes the meaning of the state field and the MIR locals which are live
-//! across a suspension point. There are however three hardcoded generator states:
-//! 0 - Generator have not been resumed yet
-//! 1 - Generator has returned / is completed
-//! 2 - Generator has been poisoned
+//! across a suspension point. There are however three hardcoded coroutine states:
+//! 0 - Coroutine have not been resumed yet
+//! 1 - Coroutine has returned / is completed
+//! 2 - Coroutine has been poisoned
//!
-//! It also rewrites `return x` and `yield y` as setting a new generator state and returning
-//! `GeneratorState::Complete(x)` and `GeneratorState::Yielded(y)`,
+//! It also rewrites `return x` and `yield y` as setting a new coroutine state and returning
+//! `CoroutineState::Complete(x)` and `CoroutineState::Yielded(y)`,
//! or `Poll::Ready(x)` and `Poll::Pending` respectively.
-//! MIR locals which are live across a suspension point are moved to the generator struct
-//! with references to them being updated with references to the generator struct.
+//! MIR locals which are live across a suspension point are moved to the coroutine struct
+//! with references to them being updated with references to the coroutine struct.
//!
-//! The pass creates two functions which have a switch on the generator state giving
+//! The pass creates two functions which have a switch on the coroutine state giving
//! the action to take.
//!
-//! One of them is the implementation of `Generator::resume` / `Future::poll`.
-//! For generators with state 0 (unresumed) it starts the execution of the generator.
-//! For generators with state 1 (returned) and state 2 (poisoned) it panics.
+//! One of them is the implementation of `Coroutine::resume` / `Future::poll`.
+//! For coroutines with state 0 (unresumed) it starts the execution of the coroutine.
+//! For coroutines with state 1 (returned) and state 2 (poisoned) it panics.
//! Otherwise it continues the execution from the last suspension point.
//!
-//! The other function is the drop glue for the generator.
-//! For generators with state 0 (unresumed) it drops the upvars of the generator.
-//! For generators with state 1 (returned) and state 2 (poisoned) it does nothing.
+//! The other function is the drop glue for the coroutine.
+//! For coroutines with state 0 (unresumed) it drops the upvars of the coroutine.
+//! For coroutines with state 1 (returned) and state 2 (poisoned) it does nothing.
//! Otherwise it drops all the values in scope at the last suspension point.
use crate::abort_unwinding_calls;
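[editor's note] To make the module documentation above concrete, here is a hand-written sketch of the state machine shape it describes: three reserved states (unresumed / returned / poisoned) plus one variant per suspension point, with `yield` and `return` rewritten into `CoroutineState::Yielded` and `CoroutineState::Complete`. Names, the saved-locals layout, and the panic message are illustrative, not the MIR this pass actually generates:

```rust
// Hand-lowering of something like `|| { yield 1; 2 }` (illustrative only).
enum CoroutineState<Y, R> {
    Yielded(Y),
    Complete(R),
}

enum State {
    Unresumed,     // reserved state 0: never resumed
    Returned,      // reserved state 1: completed
    Poisoned,      // reserved state 2: panicked during a resume
    Suspend0(i32), // one variant per yield point, holding the live locals
}

fn resume(state: &mut State) -> CoroutineState<i32, i32> {
    // Poison first, so a panic inside the body leaves the coroutine poisoned.
    match std::mem::replace(state, State::Poisoned) {
        State::Unresumed => {
            *state = State::Suspend0(0);
            CoroutineState::Yielded(1)
        }
        State::Suspend0(_live) => {
            *state = State::Returned;
            CoroutineState::Complete(2)
        }
        State::Returned | State::Poisoned => panic!("coroutine resumed after completion"),
    }
}

fn main() {
    let mut c = State::Unresumed;
    assert!(matches!(resume(&mut c), CoroutineState::Yielded(1)));
    assert!(matches!(resume(&mut c), CoroutineState::Complete(2)));
}
```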
@@ -60,7 +60,7 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_errors::pluralize;
use rustc_hir as hir;
use rustc_hir::lang_items::LangItem;
-use rustc_hir::GeneratorKind;
+use rustc_hir::CoroutineKind;
use rustc_index::bit_set::{BitMatrix, BitSet, GrowableBitSet};
use rustc_index::{Idx, IndexVec};
use rustc_middle::mir::dump_mir;
@@ -68,7 +68,7 @@ use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
use rustc_middle::mir::*;
use rustc_middle::ty::InstanceDef;
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
-use rustc_middle::ty::{GeneratorArgs, GenericArgsRef};
+use rustc_middle::ty::{CoroutineArgs, GenericArgsRef};
use rustc_mir_dataflow::impls::{
MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive,
};
@@ -147,7 +147,7 @@ impl<'tcx> MutVisitor<'tcx> for DerefArgVisitor<'tcx> {
}
struct PinArgVisitor<'tcx> {
- ref_gen_ty: Ty<'tcx>,
+ ref_coroutine_ty: Ty<'tcx>,
tcx: TyCtxt<'tcx>,
}
@@ -168,7 +168,7 @@ impl<'tcx> MutVisitor<'tcx> for PinArgVisitor<'tcx> {
local: SELF_ARG,
projection: self.tcx().mk_place_elems(&[ProjectionElem::Field(
FieldIdx::new(0),
- self.ref_gen_ty,
+ self.ref_coroutine_ty,
)]),
},
self.tcx,
@@ -196,19 +196,19 @@ fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtx
const SELF_ARG: Local = Local::from_u32(1);
-/// Generator has not been resumed yet.
-const UNRESUMED: usize = GeneratorArgs::UNRESUMED;
-/// Generator has returned / is completed.
-const RETURNED: usize = GeneratorArgs::RETURNED;
-/// Generator has panicked and is poisoned.
-const POISONED: usize = GeneratorArgs::POISONED;
+/// Coroutine has not been resumed yet.
+const UNRESUMED: usize = CoroutineArgs::UNRESUMED;
+/// Coroutine has returned / is completed.
+const RETURNED: usize = CoroutineArgs::RETURNED;
+/// Coroutine has panicked and is poisoned.
+const POISONED: usize = CoroutineArgs::POISONED;
-/// Number of variants to reserve in generator state. Corresponds to
-/// `UNRESUMED` (beginning of a generator) and `RETURNED`/`POISONED`
-/// (end of a generator) states.
+/// Number of variants to reserve in coroutine state. Corresponds to
+/// `UNRESUMED` (beginning of a coroutine) and `RETURNED`/`POISONED`
+/// (end of a coroutine) states.
const RESERVED_VARIANTS: usize = 3;
-/// A `yield` point in the generator.
+/// A `yield` point in the coroutine.
struct SuspensionPoint<'tcx> {
/// State discriminant used when suspending or resuming at this point.
state: usize,
@@ -216,7 +216,7 @@ struct SuspensionPoint<'tcx> {
resume: BasicBlock,
/// Where to move the resume argument after resumption.
resume_arg: Place<'tcx>,
- /// Which block to jump to if the generator is dropped in this state.
+ /// Which block to jump to if the coroutine is dropped in this state.
drop: Option<BasicBlock>,
/// Set of locals that have live storage while at this suspension point.
storage_liveness: GrowableBitSet<Local>,
@@ -224,14 +224,14 @@ struct SuspensionPoint<'tcx> {
struct TransformVisitor<'tcx> {
tcx: TyCtxt<'tcx>,
- is_async_kind: bool,
+ coroutine_kind: hir::CoroutineKind,
state_adt_ref: AdtDef<'tcx>,
state_args: GenericArgsRef<'tcx>,
- // The type of the discriminant in the generator struct
+ // The type of the discriminant in the coroutine struct
discr_ty: Ty<'tcx>,
- // Mapping from Local to (type of local, generator struct index)
+ // Mapping from Local to (type of local, coroutine struct index)
// FIXME(eddyb) This should use `IndexVec<Local, Option<_>>`.
remap: FxHashMap<Local, (Ty<'tcx>, VariantIdx, FieldIdx)>,
@@ -249,9 +249,50 @@ struct TransformVisitor<'tcx> {
}
impl<'tcx> TransformVisitor<'tcx> {
- // Make a `GeneratorState` or `Poll` variant assignment.
+ fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock {
+ let block = BasicBlock::new(body.basic_blocks.len());
+
+ let source_info = SourceInfo::outermost(body.span);
+
+ let (kind, idx) = self.coroutine_state_adt_and_variant_idx(true);
+ assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
+ let statements = vec![Statement {
+ kind: StatementKind::Assign(Box::new((
+ Place::return_place(),
+ Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
+ ))),
+ source_info,
+ }];
+
+ body.basic_blocks_mut().push(BasicBlockData {
+ statements,
+ terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }),
+ is_cleanup: false,
+ });
+
+ block
+ }
+
+ fn coroutine_state_adt_and_variant_idx(
+ &self,
+ is_return: bool,
+ ) -> (AggregateKind<'tcx>, VariantIdx) {
+ let idx = VariantIdx::new(match (is_return, self.coroutine_kind) {
+ (true, hir::CoroutineKind::Coroutine) => 1, // CoroutineState::Complete
+ (false, hir::CoroutineKind::Coroutine) => 0, // CoroutineState::Yielded
+ (true, hir::CoroutineKind::Async(_)) => 0, // Poll::Ready
+ (false, hir::CoroutineKind::Async(_)) => 1, // Poll::Pending
+ (true, hir::CoroutineKind::Gen(_)) => 0, // Option::None
+ (false, hir::CoroutineKind::Gen(_)) => 1, // Option::Some
+ });
+
+ let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None);
+ (kind, idx)
+ }
+
+ // Make a `CoroutineState` or `Poll` variant assignment.
//
- // `core::ops::GeneratorState` only has single element tuple variants,
+ // `core::ops::CoroutineState` only has single element tuple variants,
// so we can just write to the downcasted first field and then set the
// discriminant to the appropriate variant.
fn make_state(
@@ -261,31 +302,44 @@ impl<'tcx> TransformVisitor<'tcx> {
is_return: bool,
statements: &mut Vec<Statement<'tcx>>,
) {
- let idx = VariantIdx::new(match (is_return, self.is_async_kind) {
- (true, false) => 1, // GeneratorState::Complete
- (false, false) => 0, // GeneratorState::Yielded
- (true, true) => 0, // Poll::Ready
- (false, true) => 1, // Poll::Pending
- });
+ let (kind, idx) = self.coroutine_state_adt_and_variant_idx(is_return);
- let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None);
+ match self.coroutine_kind {
+ // `Poll::Pending`
+ CoroutineKind::Async(_) => {
+ if !is_return {
+ assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
- // `Poll::Pending`
- if self.is_async_kind && idx == VariantIdx::new(1) {
- assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
+ // FIXME(swatinem): assert that `val` is indeed unit?
+ statements.push(Statement {
+ kind: StatementKind::Assign(Box::new((
+ Place::return_place(),
+ Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
+ ))),
+ source_info,
+ });
+ return;
+ }
+ }
+ // `Option::None`
+ CoroutineKind::Gen(_) => {
+ if is_return {
+ assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
- // FIXME(swatinem): assert that `val` is indeed unit?
- statements.push(Statement {
- kind: StatementKind::Assign(Box::new((
- Place::return_place(),
- Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
- ))),
- source_info,
- });
- return;
+ statements.push(Statement {
+ kind: StatementKind::Assign(Box::new((
+ Place::return_place(),
+ Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
+ ))),
+ source_info,
+ });
+ return;
+ }
+ }
+ CoroutineKind::Coroutine => {}
}
- // else: `Poll::Ready(x)`, `GeneratorState::Yielded(x)` or `GeneratorState::Complete(x)`
+ // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)`
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
statements.push(Statement {
@@ -297,7 +351,7 @@ impl<'tcx> TransformVisitor<'tcx> {
});
}
- // Create a Place referencing a generator struct field
+ // Create a Place referencing a coroutine struct field
fn make_field(&self, variant_index: VariantIdx, idx: FieldIdx, ty: Ty<'tcx>) -> Place<'tcx> {
let self_place = Place::from(SELF_ARG);
let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index);
@@ -321,7 +375,7 @@ impl<'tcx> TransformVisitor<'tcx> {
// Create a statement which reads the discriminant into a temporary
fn get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>) {
- let temp_decl = LocalDecl::new(self.discr_ty, body.span).internal();
+ let temp_decl = LocalDecl::new(self.discr_ty, body.span);
let local_decls_len = body.local_decls.push(temp_decl);
let temp = Place::from(local_decls_len);
@@ -349,7 +403,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
_context: PlaceContext,
_location: Location,
) {
- // Replace an Local in the remap with a generator struct access
+ // Replace an Local in the remap with a coroutine struct access
if let Some(&(ty, variant_index, idx)) = self.remap.get(&place.local) {
replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
}
@@ -413,35 +467,35 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
}
}
-fn make_generator_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
- let gen_ty = body.local_decls.raw[1].ty;
+fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ let coroutine_ty = body.local_decls.raw[1].ty;
- let ref_gen_ty = Ty::new_ref(
+ let ref_coroutine_ty = Ty::new_ref(
tcx,
tcx.lifetimes.re_erased,
- ty::TypeAndMut { ty: gen_ty, mutbl: Mutability::Mut },
+ ty::TypeAndMut { ty: coroutine_ty, mutbl: Mutability::Mut },
);
- // Replace the by value generator argument
- body.local_decls.raw[1].ty = ref_gen_ty;
+ // Replace the by value coroutine argument
+ body.local_decls.raw[1].ty = ref_coroutine_ty;
- // Add a deref to accesses of the generator state
+ // Add a deref to accesses of the coroutine state
DerefArgVisitor { tcx }.visit_body(body);
}
-fn make_generator_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
- let ref_gen_ty = body.local_decls.raw[1].ty;
+fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ let ref_coroutine_ty = body.local_decls.raw[1].ty;
let pin_did = tcx.require_lang_item(LangItem::Pin, Some(body.span));
let pin_adt_ref = tcx.adt_def(pin_did);
- let args = tcx.mk_args(&[ref_gen_ty.into()]);
- let pin_ref_gen_ty = Ty::new_adt(tcx, pin_adt_ref, args);
+ let args = tcx.mk_args(&[ref_coroutine_ty.into()]);
+ let pin_ref_coroutine_ty = Ty::new_adt(tcx, pin_adt_ref, args);
- // Replace the by ref generator argument
- body.local_decls.raw[1].ty = pin_ref_gen_ty;
+ // Replace the by ref coroutine argument
+ body.local_decls.raw[1].ty = pin_ref_coroutine_ty;
- // Add the Pin field access to accesses of the generator state
- PinArgVisitor { ref_gen_ty, tcx }.visit_body(body);
+ // Add the Pin field access to accesses of the coroutine state
+ PinArgVisitor { ref_coroutine_ty, tcx }.visit_body(body);
}
/// Allocates a new local and replaces all references of `local` with it. Returns the new local.
@@ -465,7 +519,7 @@ fn replace_local<'tcx>(
new_local
}
-/// Transforms the `body` of the generator applying the following transforms:
+/// Transforms the `body` of the coroutine applying the following transforms:
///
/// - Eliminates all the `get_context` calls that async lowering created.
/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
@@ -485,7 +539,7 @@ fn replace_local<'tcx>(
///
/// The async lowering step and the type / lifetime inference / checking are
/// still using the `ResumeTy` indirection for the time being, and that indirection
-/// is removed here. After this transform, the generator body only knows about `&mut Context<'_>`.
+/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`.
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let context_mut_ref = Ty::new_task_context(tcx);
@@ -565,10 +619,10 @@ fn replace_resume_ty_local<'tcx>(
struct LivenessInfo {
/// Which locals are live across any suspension point.
- saved_locals: GeneratorSavedLocals,
+ saved_locals: CoroutineSavedLocals,
/// The set of saved locals live at each suspension point.
- live_locals_at_suspension_points: Vec<BitSet<GeneratorSavedLocal>>,
+ live_locals_at_suspension_points: Vec<BitSet<CoroutineSavedLocal>>,
/// Parallel vec to the above with SourceInfo for each yield terminator.
source_info_at_suspension_points: Vec<SourceInfo>,
@@ -576,7 +630,7 @@ struct LivenessInfo {
/// For every saved local, the set of other saved locals that are
/// storage-live at the same time as this local. We cannot overlap locals in
/// the layout which have conflicting storage.
- storage_conflicts: BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal>,
+ storage_conflicts: BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal>,
/// For every suspending block, the locals which are storage-live across
/// that suspension point.
@@ -609,7 +663,7 @@ fn locals_live_across_suspend_points<'tcx>(
// Calculate the MIR locals which have been previously
// borrowed (even if they are still active).
let borrowed_locals_results =
- MaybeBorrowedLocals.into_engine(tcx, body_ref).pass_name("generator").iterate_to_fixpoint();
+ MaybeBorrowedLocals.into_engine(tcx, body_ref).pass_name("coroutine").iterate_to_fixpoint();
let mut borrowed_locals_cursor = borrowed_locals_results.cloned_results_cursor(body_ref);
@@ -624,7 +678,7 @@ fn locals_live_across_suspend_points<'tcx>(
// Calculate the liveness of MIR locals ignoring borrows.
let mut liveness = MaybeLiveLocals
.into_engine(tcx, body_ref)
- .pass_name("generator")
+ .pass_name("coroutine")
.iterate_to_fixpoint()
.into_results_cursor(body_ref);
@@ -643,8 +697,8 @@ fn locals_live_across_suspend_points<'tcx>(
if !movable {
// The `liveness` variable contains the liveness of MIR locals ignoring borrows.
- // This is correct for movable generators since borrows cannot live across
- // suspension points. However for immovable generators we need to account for
+ // This is correct for movable coroutines since borrows cannot live across
+ // suspension points. However for immovable coroutines we need to account for
// borrows, so we conservatively assume that all borrowed locals are live until
// we find a StorageDead statement referencing the locals.
// To do this we just union our `liveness` result with `borrowed_locals`, which
@@ -667,7 +721,7 @@ fn locals_live_across_suspend_points<'tcx>(
requires_storage_cursor.seek_before_primary_effect(loc);
live_locals.intersect(requires_storage_cursor.get());
- // The generator argument is ignored.
+ // The coroutine argument is ignored.
live_locals.remove(SELF_ARG);
debug!("loc = {:?}, live_locals = {:?}", loc, live_locals);
@@ -682,7 +736,7 @@ fn locals_live_across_suspend_points<'tcx>(
}
debug!("live_locals_anywhere = {:?}", live_locals_at_any_suspension_point);
- let saved_locals = GeneratorSavedLocals(live_locals_at_any_suspension_point);
+ let saved_locals = CoroutineSavedLocals(live_locals_at_any_suspension_point);
// Renumber our liveness_map bitsets to include only the locals we are
// saving.
@@ -709,21 +763,21 @@ fn locals_live_across_suspend_points<'tcx>(
/// The set of `Local`s that must be saved across yield points.
///
-/// `GeneratorSavedLocal` is indexed in terms of the elements in this set;
-/// i.e. `GeneratorSavedLocal::new(1)` corresponds to the second local
+/// `CoroutineSavedLocal` is indexed in terms of the elements in this set;
+/// i.e. `CoroutineSavedLocal::new(1)` corresponds to the second local
/// included in this set.
-struct GeneratorSavedLocals(BitSet<Local>);
+struct CoroutineSavedLocals(BitSet<Local>);
-impl GeneratorSavedLocals {
- /// Returns an iterator over each `GeneratorSavedLocal` along with the `Local` it corresponds
+impl CoroutineSavedLocals {
+ /// Returns an iterator over each `CoroutineSavedLocal` along with the `Local` it corresponds
/// to.
- fn iter_enumerated(&self) -> impl '_ + Iterator<Item = (GeneratorSavedLocal, Local)> {
- self.iter().enumerate().map(|(i, l)| (GeneratorSavedLocal::from(i), l))
+ fn iter_enumerated(&self) -> impl '_ + Iterator<Item = (CoroutineSavedLocal, Local)> {
+ self.iter().enumerate().map(|(i, l)| (CoroutineSavedLocal::from(i), l))
}
/// Transforms a `BitSet<Local>` that contains only locals saved across yield points to the
- /// equivalent `BitSet<GeneratorSavedLocal>`.
- fn renumber_bitset(&self, input: &BitSet<Local>) -> BitSet<GeneratorSavedLocal> {
+ /// equivalent `BitSet<CoroutineSavedLocal>`.
+ fn renumber_bitset(&self, input: &BitSet<Local>) -> BitSet<CoroutineSavedLocal> {
assert!(self.superset(&input), "{:?} not a superset of {:?}", self.0, input);
let mut out = BitSet::new_empty(self.count());
for (saved_local, local) in self.iter_enumerated() {
@@ -734,17 +788,17 @@ impl GeneratorSavedLocals {
out
}
- fn get(&self, local: Local) -> Option<GeneratorSavedLocal> {
+ fn get(&self, local: Local) -> Option<CoroutineSavedLocal> {
if !self.contains(local) {
return None;
}
let idx = self.iter().take_while(|&l| l < local).count();
- Some(GeneratorSavedLocal::new(idx))
+ Some(CoroutineSavedLocal::new(idx))
}
}
-impl ops::Deref for GeneratorSavedLocals {
+impl ops::Deref for CoroutineSavedLocals {
type Target = BitSet<Local>;
fn deref(&self) -> &Self::Target {
@@ -755,13 +809,13 @@ impl ops::Deref for GeneratorSavedLocals {
/// For every saved local, looks for which locals are StorageLive at the same
/// time. Generates a bitset for every local of all the other locals that may be
/// StorageLive simultaneously with that local. This is used in the layout
-/// computation; see `GeneratorLayout` for more.
+/// computation; see `CoroutineLayout` for more.
fn compute_storage_conflicts<'mir, 'tcx>(
body: &'mir Body<'tcx>,
- saved_locals: &GeneratorSavedLocals,
+ saved_locals: &CoroutineSavedLocals,
always_live_locals: BitSet<Local>,
mut requires_storage: rustc_mir_dataflow::Results<'tcx, MaybeRequiresStorage<'_, 'mir, 'tcx>>,
-) -> BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal> {
+) -> BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal> {
assert_eq!(body.local_decls.len(), saved_locals.domain_size());
debug!("compute_storage_conflicts({:?})", body.span);
@@ -783,7 +837,7 @@ fn compute_storage_conflicts<'mir, 'tcx>(
let local_conflicts = visitor.local_conflicts;
- // Compress the matrix using only stored locals (Local -> GeneratorSavedLocal).
+ // Compress the matrix using only stored locals (Local -> CoroutineSavedLocal).
//
// NOTE: Today we store a full conflict bitset for every local. Technically
// this is twice as many bits as we need, since the relation is symmetric.
@@ -809,9 +863,9 @@ fn compute_storage_conflicts<'mir, 'tcx>(
struct StorageConflictVisitor<'mir, 'tcx, 's> {
body: &'mir Body<'tcx>,
- saved_locals: &'s GeneratorSavedLocals,
+ saved_locals: &'s CoroutineSavedLocals,
// FIXME(tmandry): Consider using sparse bitsets here once we have good
- // benchmarks for generators.
+ // benchmarks for coroutines.
local_conflicts: BitMatrix<Local, Local>,
}
@@ -866,7 +920,7 @@ fn compute_layout<'tcx>(
body: &Body<'tcx>,
) -> (
FxHashMap<Local, (Ty<'tcx>, VariantIdx, FieldIdx)>,
- GeneratorLayout<'tcx>,
+ CoroutineLayout<'tcx>,
IndexVec<BasicBlock, Option<BitSet<Local>>>,
) {
let LivenessInfo {
@@ -878,10 +932,10 @@ fn compute_layout<'tcx>(
} = liveness;
// Gather live local types and their indices.
- let mut locals = IndexVec::<GeneratorSavedLocal, _>::new();
- let mut tys = IndexVec::<GeneratorSavedLocal, _>::new();
+ let mut locals = IndexVec::<CoroutineSavedLocal, _>::new();
+ let mut tys = IndexVec::<CoroutineSavedLocal, _>::new();
for (saved_local, local) in saved_locals.iter_enumerated() {
- debug!("generator saved local {:?} => {:?}", saved_local, local);
+ debug!("coroutine saved local {:?} => {:?}", saved_local, local);
locals.push(local);
let decl = &body.local_decls[local];
@@ -903,7 +957,7 @@ fn compute_layout<'tcx>(
_ => false,
};
let decl =
- GeneratorSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits };
+ CoroutineSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits };
debug!(?decl);
tys.push(decl);
@@ -922,9 +976,9 @@ fn compute_layout<'tcx>(
.copied()
.collect();
- // Build the generator variant field list.
- // Create a map from local indices to generator struct indices.
- let mut variant_fields: IndexVec<VariantIdx, IndexVec<FieldIdx, GeneratorSavedLocal>> =
+ // Build the coroutine variant field list.
+ // Create a map from local indices to coroutine struct indices.
+ let mut variant_fields: IndexVec<VariantIdx, IndexVec<FieldIdx, CoroutineSavedLocal>> =
iter::repeat(IndexVec::new()).take(RESERVED_VARIANTS).collect();
let mut remap = FxHashMap::default();
for (suspension_point_idx, live_locals) in live_locals_at_suspension_points.iter().enumerate() {
@@ -934,7 +988,7 @@ fn compute_layout<'tcx>(
fields.push(saved_local);
// Note that if a field is included in multiple variants, we will
// just use the first one here. That's fine; fields do not move
- // around inside generators, so it doesn't matter which variant
+ // around inside coroutines, so it doesn't matter which variant
// index we access them by.
let idx = FieldIdx::from_usize(idx);
remap.entry(locals[saved_local]).or_insert((tys[saved_local].ty, variant_index, idx));
@@ -942,8 +996,8 @@ fn compute_layout<'tcx>(
variant_fields.push(fields);
variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]);
}
- debug!("generator variant_fields = {:?}", variant_fields);
- debug!("generator storage_conflicts = {:#?}", storage_conflicts);
+ debug!("coroutine variant_fields = {:?}", variant_fields);
+ debug!("coroutine storage_conflicts = {:#?}", storage_conflicts);
let mut field_names = IndexVec::from_elem(None, &tys);
for var in &body.var_debug_info {
@@ -955,7 +1009,7 @@ fn compute_layout<'tcx>(
field_names.get_or_insert_with(saved_local, || var.name);
}
- let layout = GeneratorLayout {
+ let layout = CoroutineLayout {
field_tys: tys,
field_names,
variant_fields,
@@ -967,7 +1021,7 @@ fn compute_layout<'tcx>(
(remap, layout, storage_liveness)
}
-/// Replaces the entry point of `body` with a block that switches on the generator discriminant and
+/// Replaces the entry point of `body` with a block that switches on the coroutine discriminant and
/// dispatches to blocks according to `cases`.
///
/// After this function, the former entry point of the function will be bb1.
@@ -1000,14 +1054,14 @@ fn insert_switch<'tcx>(
}
}
-fn elaborate_generator_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+fn elaborate_coroutine_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
use crate::shim::DropShimElaborator;
use rustc_middle::mir::patch::MirPatch;
use rustc_mir_dataflow::elaborate_drops::{elaborate_drop, Unwind};
- // Note that `elaborate_drops` only drops the upvars of a generator, and
+ // Note that `elaborate_drops` only drops the upvars of a coroutine, and
// this is ok because `open_drop` can only be reached within that own
- // generator's resume function.
+ // coroutine's resume function.
let def_id = body.source.def_id();
let param_env = tcx.param_env(def_id);
@@ -1055,10 +1109,10 @@ fn elaborate_generator_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
elaborator.patch.apply(body);
}
-fn create_generator_drop_shim<'tcx>(
+fn create_coroutine_drop_shim<'tcx>(
tcx: TyCtxt<'tcx>,
transform: &TransformVisitor<'tcx>,
- gen_ty: Ty<'tcx>,
+ coroutine_ty: Ty<'tcx>,
body: &mut Body<'tcx>,
drop_clean: BasicBlock,
) -> Body<'tcx> {
@@ -1078,7 +1132,7 @@ fn create_generator_drop_shim<'tcx>(
for block in body.basic_blocks_mut() {
let kind = &mut block.terminator_mut().kind;
- if let TerminatorKind::GeneratorDrop = *kind {
+ if let TerminatorKind::CoroutineDrop = *kind {
*kind = TerminatorKind::Return;
}
}
@@ -1086,36 +1140,27 @@ fn create_generator_drop_shim<'tcx>(
// Replace the return variable
body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(Ty::new_unit(tcx), source_info);
- make_generator_state_argument_indirect(tcx, &mut body);
+ make_coroutine_state_argument_indirect(tcx, &mut body);
- // Change the generator argument from &mut to *mut
+ // Change the coroutine argument from &mut to *mut
body.local_decls[SELF_ARG] = LocalDecl::with_source_info(
- Ty::new_ptr(tcx, ty::TypeAndMut { ty: gen_ty, mutbl: hir::Mutability::Mut }),
+ Ty::new_ptr(tcx, ty::TypeAndMut { ty: coroutine_ty, mutbl: hir::Mutability::Mut }),
source_info,
);
// Make sure we remove dead blocks to remove
// unrelated code from the resume part of the function
- simplify::remove_dead_blocks(tcx, &mut body);
+ simplify::remove_dead_blocks(&mut body);
// Update the body's def to become the drop glue.
- // This needs to be updated before the AbortUnwindingCalls pass.
- let gen_instance = body.source.instance;
+ let coroutine_instance = body.source.instance;
let drop_in_place = tcx.require_lang_item(LangItem::DropInPlace, None);
- let drop_instance = InstanceDef::DropGlue(drop_in_place, Some(gen_ty));
- body.source.instance = drop_instance;
-
- pm::run_passes_no_validate(
- tcx,
- &mut body,
- &[&abort_unwinding_calls::AbortUnwindingCalls],
- None,
- );
+ let drop_instance = InstanceDef::DropGlue(drop_in_place, Some(coroutine_ty));
- // Temporary change MirSource to generator's instance so that dump_mir produces more sensible
+ // Temporary change MirSource to coroutine's instance so that dump_mir produces more sensible
// filename.
- body.source.instance = gen_instance;
- dump_mir(tcx, false, "generator_drop", &0, &body, |_, _| Ok(()));
+ body.source.instance = coroutine_instance;
+ dump_mir(tcx, false, "coroutine_drop", &0, &body, |_, _| Ok(()));
body.source.instance = drop_instance;
body
@@ -1190,7 +1235,7 @@ fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool {
| TerminatorKind::UnwindTerminate(_)
| TerminatorKind::Return
| TerminatorKind::Unreachable
- | TerminatorKind::GeneratorDrop
+ | TerminatorKind::CoroutineDrop
| TerminatorKind::FalseEdge { .. }
| TerminatorKind::FalseUnwind { .. } => {}
@@ -1199,7 +1244,7 @@ fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool {
TerminatorKind::UnwindResume => {}
TerminatorKind::Yield { .. } => {
- unreachable!("`can_unwind` called before generator transform")
+ unreachable!("`can_unwind` called before coroutine transform")
}
// These may unwind.
@@ -1214,7 +1259,7 @@ fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool {
false
}
-fn create_generator_resume_function<'tcx>(
+fn create_coroutine_resume_function<'tcx>(
tcx: TyCtxt<'tcx>,
transform: TransformVisitor<'tcx>,
body: &mut Body<'tcx>,
@@ -1222,7 +1267,7 @@ fn create_generator_resume_function<'tcx>(
) {
let can_unwind = can_unwind(tcx, body);
- // Poison the generator when it unwinds
+ // Poison the coroutine when it unwinds
if can_unwind {
let source_info = SourceInfo::outermost(body.span);
let poison_block = body.basic_blocks_mut().push(BasicBlockData {
@@ -1261,34 +1306,37 @@ fn create_generator_resume_function<'tcx>(
cases.insert(0, (UNRESUMED, START_BLOCK));
// Panic when resumed on the returned or poisoned state
- let generator_kind = body.generator_kind().unwrap();
+ let coroutine_kind = body.coroutine_kind().unwrap();
if can_unwind {
cases.insert(
1,
- (POISONED, insert_panic_block(tcx, body, ResumedAfterPanic(generator_kind))),
+ (POISONED, insert_panic_block(tcx, body, ResumedAfterPanic(coroutine_kind))),
);
}
if can_return {
- cases.insert(
- 1,
- (RETURNED, insert_panic_block(tcx, body, ResumedAfterReturn(generator_kind))),
- );
+ let block = match coroutine_kind {
+ CoroutineKind::Async(_) | CoroutineKind::Coroutine => {
+ insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind))
+ }
+ CoroutineKind::Gen(_) => transform.insert_none_ret_block(body),
+ };
+ cases.insert(1, (RETURNED, block));
}
insert_switch(body, cases, &transform, TerminatorKind::Unreachable);
- make_generator_state_argument_indirect(tcx, body);
- make_generator_state_argument_pinned(tcx, body);
+ make_coroutine_state_argument_indirect(tcx, body);
+ make_coroutine_state_argument_pinned(tcx, body);
// Make sure we remove dead blocks to remove
// unrelated code from the drop part of the function
- simplify::remove_dead_blocks(tcx, body);
+ simplify::remove_dead_blocks(body);
pm::run_passes_no_validate(tcx, body, &[&abort_unwinding_calls::AbortUnwindingCalls], None);
- dump_mir(tcx, false, "generator_resume", &0, body, |_, _| Ok(()));
+ dump_mir(tcx, false, "coroutine_resume", &0, body, |_, _| Ok(()));
}
fn insert_clean_drop(body: &mut Body<'_>) -> BasicBlock {
@@ -1302,7 +1350,7 @@ fn insert_clean_drop(body: &mut Body<'_>) -> BasicBlock {
};
let source_info = SourceInfo::outermost(body.span);
- // Create a block to destroy an unresumed generators. This can only destroy upvars.
+ // Create a block to destroy an unresumed coroutine. This can only destroy upvars.
body.basic_blocks_mut().push(BasicBlockData {
statements: Vec::new(),
terminator: Some(Terminator { source_info, kind: term }),
@@ -1310,7 +1358,7 @@ fn insert_clean_drop(body: &mut Body<'_>) -> BasicBlock {
})
}
-/// An operation that can be performed on a generator.
+/// An operation that can be performed on a coroutine.
#[derive(PartialEq, Copy, Clone)]
enum Operation {
Resume,
@@ -1389,21 +1437,21 @@ fn create_cases<'tcx>(
}
#[instrument(level = "debug", skip(tcx), ret)]
-pub(crate) fn mir_generator_witnesses<'tcx>(
+pub(crate) fn mir_coroutine_witnesses<'tcx>(
tcx: TyCtxt<'tcx>,
def_id: LocalDefId,
-) -> Option<GeneratorLayout<'tcx>> {
+) -> Option<CoroutineLayout<'tcx>> {
let (body, _) = tcx.mir_promoted(def_id);
let body = body.borrow();
let body = &*body;
- // The first argument is the generator type passed by value
- let gen_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
+ // The first argument is the coroutine type passed by value
+ let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
- let movable = match *gen_ty.kind() {
- ty::Generator(_, _, movability) => movability == hir::Movability::Movable,
+ let movable = match *coroutine_ty.kind() {
+ ty::Coroutine(_, _, movability) => movability == hir::Movability::Movable,
ty::Error(_) => return None,
- _ => span_bug!(body.span, "unexpected generator type {}", gen_ty),
+ _ => span_bug!(body.span, "unexpected coroutine type {}", coroutine_ty),
};
// The witness simply contains all locals live across suspend points.
@@ -1412,52 +1460,63 @@ pub(crate) fn mir_generator_witnesses<'tcx>(
let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
// Extract locals which are live across suspension point into `layout`
- // `remap` gives a mapping from local indices onto generator struct indices
+ // `remap` gives a mapping from local indices onto coroutine struct indices
// `storage_liveness` tells us which locals have live storage at suspension points
- let (_, generator_layout, _) = compute_layout(liveness_info, body);
+ let (_, coroutine_layout, _) = compute_layout(liveness_info, body);
- check_suspend_tys(tcx, &generator_layout, &body);
+ check_suspend_tys(tcx, &coroutine_layout, &body);
- Some(generator_layout)
+ Some(coroutine_layout)
}
impl<'tcx> MirPass<'tcx> for StateTransform {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let Some(yield_ty) = body.yield_ty() else {
- // This only applies to generators
+ // This only applies to coroutines
return;
};
- assert!(body.generator_drop().is_none());
+ assert!(body.coroutine_drop().is_none());
- // The first argument is the generator type passed by value
- let gen_ty = body.local_decls.raw[1].ty;
+ // The first argument is the coroutine type passed by value
+ let coroutine_ty = body.local_decls.raw[1].ty;
// Get the discriminant type and args which typeck computed
- let (discr_ty, movable) = match *gen_ty.kind() {
- ty::Generator(_, args, movability) => {
- let args = args.as_generator();
+ let (discr_ty, movable) = match *coroutine_ty.kind() {
+ ty::Coroutine(_, args, movability) => {
+ let args = args.as_coroutine();
(args.discr_ty(tcx), movability == hir::Movability::Movable)
}
_ => {
- tcx.sess.delay_span_bug(body.span, format!("unexpected generator type {gen_ty}"));
+ tcx.sess
+ .delay_span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}"));
return;
}
};
- let is_async_kind = matches!(body.generator_kind(), Some(GeneratorKind::Async(_)));
- let (state_adt_ref, state_args) = if is_async_kind {
- // Compute Poll<return_ty>
- let poll_did = tcx.require_lang_item(LangItem::Poll, None);
- let poll_adt_ref = tcx.adt_def(poll_did);
- let poll_args = tcx.mk_args(&[body.return_ty().into()]);
- (poll_adt_ref, poll_args)
- } else {
- // Compute GeneratorState<yield_ty, return_ty>
- let state_did = tcx.require_lang_item(LangItem::GeneratorState, None);
- let state_adt_ref = tcx.adt_def(state_did);
- let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]);
- (state_adt_ref, state_args)
+ let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
+ let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() {
+ CoroutineKind::Async(_) => {
+ // Compute Poll<return_ty>
+ let poll_did = tcx.require_lang_item(LangItem::Poll, None);
+ let poll_adt_ref = tcx.adt_def(poll_did);
+ let poll_args = tcx.mk_args(&[body.return_ty().into()]);
+ (poll_adt_ref, poll_args)
+ }
+ CoroutineKind::Gen(_) => {
+ // Compute Option<yield_ty>
+ let option_did = tcx.require_lang_item(LangItem::Option, None);
+ let option_adt_ref = tcx.adt_def(option_did);
+ let option_args = tcx.mk_args(&[body.yield_ty().unwrap().into()]);
+ (option_adt_ref, option_args)
+ }
+ CoroutineKind::Coroutine => {
+ // Compute CoroutineState<yield_ty, return_ty>
+ let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
+ let state_adt_ref = tcx.adt_def(state_did);
+ let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]);
+ (state_adt_ref, state_args)
+ }
};
let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args);
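To make the three-way selection above concrete, here is a minimal standalone sketch (not compiler code; the `CoroutineState` enum below only mirrors the lang item referenced in the match, and the concrete yield/return types are invented) of what each flavour of state machine hands back from its resume/poll function:

use std::task::Poll;

// Stand-in for the `CoroutineState` lang item referenced above.
enum CoroutineState<Y, R> {
    Yielded(Y),
    Complete(R),
}

fn resume_outputs() {
    // Async coroutines poll to `Poll<return_ty>`.
    let _async_step: Poll<i32> = Poll::Ready(0);
    // `gen` coroutines yield `Option<yield_ty>`, with `None` once finished.
    let _gen_step: Option<u32> = Some(1);
    // Plain coroutines return `CoroutineState<yield_ty, return_ty>`.
    let _coroutine_step: CoroutineState<u32, i32> = CoroutineState::Yielded(1);
}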
@@ -1472,8 +1531,8 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
// We also replace the resume argument and insert an `Assign`.
// This is needed because the resume argument `_2` might be live across a `yield`, in which
- // case there is no `Assign` to it that the transform can turn into a store to the generator
- // state. After the yield the slot in the generator state would then be uninitialized.
+ // case there is no `Assign` to it that the transform can turn into a store to the coroutine
+ // state. After the yield the slot in the coroutine state would then be uninitialized.
let resume_local = Local::new(2);
let resume_ty = if is_async_kind {
Ty::new_task_context(tcx)
@@ -1482,7 +1541,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
};
let new_resume_local = replace_local(resume_local, resume_ty, body, tcx);
- // When first entering the generator, move the resume argument into its new local.
+ // When first entering the coroutine, move the resume argument into its new local.
let source_info = SourceInfo::outermost(body.span);
let stmts = &mut body.basic_blocks_mut()[START_BLOCK].statements;
stmts.insert(
@@ -1502,7 +1561,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
if tcx.sess.opts.unstable_opts.validate_mir {
- let mut vis = EnsureGeneratorFieldAssignmentsNeverAlias {
+ let mut vis = EnsureCoroutineFieldAssignmentsNeverAlias {
assigned_local: None,
saved_locals: &liveness_info.saved_locals,
storage_conflicts: &liveness_info.storage_conflicts,
@@ -1512,20 +1571,20 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
}
// Extract locals which are live across suspension point into `layout`
- // `remap` gives a mapping from local indices onto generator struct indices
+ // `remap` gives a mapping from local indices onto coroutine struct indices
// `storage_liveness` tells us which locals have live storage at suspension points
let (remap, layout, storage_liveness) = compute_layout(liveness_info, body);
let can_return = can_return(tcx, body, tcx.param_env(body.source.def_id()));
- // Run the transformation which converts Places from Local to generator struct
+ // Run the transformation which converts Places from Local to coroutine struct
// accesses for locals in `remap`.
- // It also rewrites `return x` and `yield y` as writing a new generator state and returning
- // either GeneratorState::Complete(x) and GeneratorState::Yielded(y),
+ // It also rewrites `return x` and `yield y` as writing a new coroutine state and returning
+ // either CoroutineState::Complete(x) and CoroutineState::Yielded(y),
// or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`.
let mut transform = TransformVisitor {
tcx,
- is_async_kind,
+ coroutine_kind: body.coroutine_kind().unwrap(),
state_adt_ref,
state_args,
remap,
@@ -1548,30 +1607,30 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
var.argument_index = None;
}
- body.generator.as_mut().unwrap().yield_ty = None;
- body.generator.as_mut().unwrap().generator_layout = Some(layout);
+ body.coroutine.as_mut().unwrap().yield_ty = None;
+ body.coroutine.as_mut().unwrap().coroutine_layout = Some(layout);
- // Insert `drop(generator_struct)` which is used to drop upvars for generators in
+ // Insert `drop(coroutine_struct)` which is used to drop upvars for coroutines in
// the unresumed state.
- // This is expanded to a drop ladder in `elaborate_generator_drops`.
+ // This is expanded to a drop ladder in `elaborate_coroutine_drops`.
let drop_clean = insert_clean_drop(body);
- dump_mir(tcx, false, "generator_pre-elab", &0, body, |_, _| Ok(()));
+ dump_mir(tcx, false, "coroutine_pre-elab", &0, body, |_, _| Ok(()));
- // Expand `drop(generator_struct)` to a drop ladder which destroys upvars.
+ // Expand `drop(coroutine_struct)` to a drop ladder which destroys upvars.
// If any upvars are moved out of, drop elaboration will handle upvar destruction.
// However we need to also elaborate the code generated by `insert_clean_drop`.
- elaborate_generator_drops(tcx, body);
+ elaborate_coroutine_drops(tcx, body);
- dump_mir(tcx, false, "generator_post-transform", &0, body, |_, _| Ok(()));
+ dump_mir(tcx, false, "coroutine_post-transform", &0, body, |_, _| Ok(()));
- // Create a copy of our MIR and use it to create the drop shim for the generator
- let drop_shim = create_generator_drop_shim(tcx, &transform, gen_ty, body, drop_clean);
+ // Create a copy of our MIR and use it to create the drop shim for the coroutine
+ let drop_shim = create_coroutine_drop_shim(tcx, &transform, coroutine_ty, body, drop_clean);
- body.generator.as_mut().unwrap().generator_drop = Some(drop_shim);
+ body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim);
- // Create the Generator::resume / Future::poll function
- create_generator_resume_function(tcx, transform, body, can_return);
+ // Create the Coroutine::resume / Future::poll function
+ create_coroutine_resume_function(tcx, transform, body, can_return);
// Run derefer to fix Derefs that are not in the first place
deref_finder(tcx, body);
@@ -1579,25 +1638,25 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
}
/// Looks for any assignments between locals (e.g., `_4 = _5`) that will both be converted to fields
-/// in the generator state machine but whose storage is not marked as conflicting
+/// in the coroutine state machine but whose storage is not marked as conflicting
///
/// Validation needs to happen immediately *before* `TransformVisitor` is invoked, not after.
///
/// This condition would arise when the assignment is the last use of `_5` but the initial
/// definition of `_4` if we weren't extra careful to mark all locals used inside a statement as
-/// conflicting. Non-conflicting generator saved locals may be stored at the same location within
-/// the generator state machine, which would result in ill-formed MIR: the left-hand and right-hand
+/// conflicting. Non-conflicting coroutine saved locals may be stored at the same location within
+/// the coroutine state machine, which would result in ill-formed MIR: the left-hand and right-hand
/// sides of an assignment may not alias. This caused a miscompilation in [#73137].
///
/// [#73137]: https://github.com/rust-lang/rust/issues/73137
-struct EnsureGeneratorFieldAssignmentsNeverAlias<'a> {
- saved_locals: &'a GeneratorSavedLocals,
- storage_conflicts: &'a BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal>,
- assigned_local: Option<GeneratorSavedLocal>,
+struct EnsureCoroutineFieldAssignmentsNeverAlias<'a> {
+ saved_locals: &'a CoroutineSavedLocals,
+ storage_conflicts: &'a BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal>,
+ assigned_local: Option<CoroutineSavedLocal>,
}
-impl EnsureGeneratorFieldAssignmentsNeverAlias<'_> {
- fn saved_local_for_direct_place(&self, place: Place<'_>) -> Option<GeneratorSavedLocal> {
+impl EnsureCoroutineFieldAssignmentsNeverAlias<'_> {
+ fn saved_local_for_direct_place(&self, place: Place<'_>) -> Option<CoroutineSavedLocal> {
if place.is_indirect() {
return None;
}
@@ -1616,7 +1675,7 @@ impl EnsureGeneratorFieldAssignmentsNeverAlias<'_> {
}
}
-impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> {
+impl<'tcx> Visitor<'tcx> for EnsureCoroutineFieldAssignmentsNeverAlias<'_> {
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
let Some(lhs) = self.assigned_local else {
// This visitor only invokes `visit_place` for the right-hand side of an assignment
@@ -1631,7 +1690,7 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> {
if !self.storage_conflicts.contains(lhs, rhs) {
bug!(
- "Assignment between generator saved locals whose storage is not \
+ "Assignment between coroutine saved locals whose storage is not \
marked as conflicting: {:?}: {:?} = {:?}",
location,
lhs,
@@ -1698,14 +1757,14 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> {
| TerminatorKind::Unreachable
| TerminatorKind::Drop { .. }
| TerminatorKind::Assert { .. }
- | TerminatorKind::GeneratorDrop
+ | TerminatorKind::CoroutineDrop
| TerminatorKind::FalseEdge { .. }
| TerminatorKind::FalseUnwind { .. } => {}
}
}
}
-fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &GeneratorLayout<'tcx>, body: &Body<'tcx>) {
+fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &CoroutineLayout<'tcx>, body: &Body<'tcx>) {
let mut linted_tys = FxHashSet::default();
// We want a user-facing param-env.
diff --git a/compiler/rustc_mir_transform/src/cost_checker.rs b/compiler/rustc_mir_transform/src/cost_checker.rs
new file mode 100644
index 000000000..9bb26693c
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/cost_checker.rs
@@ -0,0 +1,98 @@
+use rustc_middle::mir::visit::*;
+use rustc_middle::mir::*;
+use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};
+
+const INSTR_COST: usize = 5;
+const CALL_PENALTY: usize = 25;
+const LANDINGPAD_PENALTY: usize = 50;
+const RESUME_PENALTY: usize = 45;
+
+/// Accumulates an abstract cost for a MIR body, used by inlining heuristics to
+/// decide whether a callee is cheap enough to inline into the caller.
+#[derive(Clone)]
+pub(crate) struct CostChecker<'b, 'tcx> {
+ tcx: TyCtxt<'tcx>,
+ param_env: ParamEnv<'tcx>,
+ cost: usize,
+ callee_body: &'b Body<'tcx>,
+ instance: Option<ty::Instance<'tcx>>,
+}
+
+impl<'b, 'tcx> CostChecker<'b, 'tcx> {
+ pub fn new(
+ tcx: TyCtxt<'tcx>,
+ param_env: ParamEnv<'tcx>,
+ instance: Option<ty::Instance<'tcx>>,
+ callee_body: &'b Body<'tcx>,
+ ) -> CostChecker<'b, 'tcx> {
+ CostChecker { tcx, param_env, callee_body, instance, cost: 0 }
+ }
+
+ pub fn cost(&self) -> usize {
+ self.cost
+ }
+
+ fn instantiate_ty(&self, v: Ty<'tcx>) -> Ty<'tcx> {
+ if let Some(instance) = self.instance {
+ instance.instantiate_mir(self.tcx, ty::EarlyBinder::bind(&v))
+ } else {
+ v
+ }
+ }
+}
+
+impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
+ fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) {
+ // Don't count StorageLive/StorageDead in the inlining cost.
+ match statement.kind {
+ StatementKind::StorageLive(_)
+ | StatementKind::StorageDead(_)
+ | StatementKind::Deinit(_)
+ | StatementKind::Nop => {}
+ _ => self.cost += INSTR_COST,
+ }
+ }
+
+ fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, _: Location) {
+ let tcx = self.tcx;
+ match terminator.kind {
+ TerminatorKind::Drop { ref place, unwind, .. } => {
+ // If the place doesn't actually need dropping, treat it like a regular goto.
+ let ty = self.instantiate_ty(place.ty(self.callee_body, tcx).ty);
+ if ty.needs_drop(tcx, self.param_env) {
+ self.cost += CALL_PENALTY;
+ if let UnwindAction::Cleanup(_) = unwind {
+ self.cost += LANDINGPAD_PENALTY;
+ }
+ } else {
+ self.cost += INSTR_COST;
+ }
+ }
+ TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. } => {
+ let fn_ty = self.instantiate_ty(f.const_.ty());
+ self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() && tcx.is_intrinsic(def_id) {
+ // Don't give intrinsics the extra penalty for calls
+ INSTR_COST
+ } else {
+ CALL_PENALTY
+ };
+ if let UnwindAction::Cleanup(_) = unwind {
+ self.cost += LANDINGPAD_PENALTY;
+ }
+ }
+ TerminatorKind::Assert { unwind, .. } => {
+ self.cost += CALL_PENALTY;
+ if let UnwindAction::Cleanup(_) = unwind {
+ self.cost += LANDINGPAD_PENALTY;
+ }
+ }
+ TerminatorKind::UnwindResume => self.cost += RESUME_PENALTY,
+ TerminatorKind::InlineAsm { unwind, .. } => {
+ self.cost += INSTR_COST;
+ if let UnwindAction::Cleanup(_) = unwind {
+ self.cost += LANDINGPAD_PENALTY;
+ }
+ }
+ _ => self.cost += INSTR_COST,
+ }
+ }
+}
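To make the constants above concrete, a small self-contained sketch (not rustc code; `CostSummary` and its counts are invented for illustration) of how the penalties combine into a single inlining cost:

const INSTR_COST: usize = 5;
const CALL_PENALTY: usize = 25;
const LANDINGPAD_PENALTY: usize = 50;
const RESUME_PENALTY: usize = 45;

struct CostSummary {
    plain_statements: usize,
    calls_without_cleanup: usize,
    calls_with_cleanup: usize,
    resumes: usize,
}

// Each ordinary statement or terminator costs INSTR_COST; calls and asserts cost
// CALL_PENALTY; a cleanup (landing-pad) edge adds LANDINGPAD_PENALTY on top; and
// an unwind resume adds RESUME_PENALTY.
fn estimate_cost(s: &CostSummary) -> usize {
    s.plain_statements * INSTR_COST
        + s.calls_without_cleanup * CALL_PENALTY
        + s.calls_with_cleanup * (CALL_PENALTY + LANDINGPAD_PENALTY)
        + s.resumes * RESUME_PENALTY
}

Under this model, a callee with ten plain statements and one call that has a cleanup edge would score 10 * 5 + (25 + 50) = 125.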
diff --git a/compiler/rustc_mir_transform/src/coverage/counters.rs b/compiler/rustc_mir_transform/src/coverage/counters.rs
index d56d4ad4f..b34ec95b4 100644
--- a/compiler/rustc_mir_transform/src/coverage/counters.rs
+++ b/compiler/rustc_mir_transform/src/coverage/counters.rs
@@ -1,10 +1,6 @@
-use super::Error;
-
use super::graph;
-use super::spans;
use graph::{BasicCoverageBlock, BcbBranch, CoverageGraph, TraverseCoverageGraphWithLoops};
-use spans::CoverageSpan;
use rustc_data_structures::fx::FxHashMap;
use rustc_data_structures::graph::WithNumNodes;
@@ -14,14 +10,12 @@ use rustc_middle::mir::coverage::*;
use std::fmt::{self, Debug};
-const NESTED_INDENT: &str = " ";
-
/// The coverage counter or counter expression associated with a particular
/// BCB node or BCB edge.
#[derive(Clone)]
pub(super) enum BcbCounter {
Counter { id: CounterId },
- Expression { id: ExpressionId, lhs: Operand, op: Op, rhs: Operand },
+ Expression { id: ExpressionId },
}
impl BcbCounter {
@@ -29,10 +23,10 @@ impl BcbCounter {
matches!(self, Self::Expression { .. })
}
- pub(super) fn as_operand(&self) -> Operand {
+ pub(super) fn as_term(&self) -> CovTerm {
match *self {
- BcbCounter::Counter { id, .. } => Operand::Counter(id),
- BcbCounter::Expression { id, .. } => Operand::Expression(id),
+ BcbCounter::Counter { id, .. } => CovTerm::Counter(id),
+ BcbCounter::Expression { id, .. } => CovTerm::Expression(id),
}
}
}
@@ -41,17 +35,7 @@ impl Debug for BcbCounter {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Counter { id, .. } => write!(fmt, "Counter({:?})", id.index()),
- Self::Expression { id, lhs, op, rhs } => write!(
- fmt,
- "Expression({:?}) = {:?} {} {:?}",
- id.index(),
- lhs,
- match op {
- Op::Add => "+",
- Op::Subtract => "-",
- },
- rhs,
- ),
+ Self::Expression { id } => write!(fmt, "Expression({:?})", id.index()),
}
}
}
@@ -60,7 +44,6 @@ impl Debug for BcbCounter {
/// associated with nodes/edges in the BCB graph.
pub(super) struct CoverageCounters {
next_counter_id: CounterId,
- next_expression_id: ExpressionId,
/// Coverage counters/expressions that are associated with individual BCBs.
bcb_counters: IndexVec<BasicCoverageBlock, Option<BcbCounter>>,
@@ -68,13 +51,12 @@ pub(super) struct CoverageCounters {
/// edge between two BCBs.
bcb_edge_counters: FxHashMap<(BasicCoverageBlock, BasicCoverageBlock), BcbCounter>,
/// Tracks which BCBs have a counter associated with some incoming edge.
- /// Only used by debug assertions, to verify that BCBs with incoming edge
+ /// Only used by assertions, to verify that BCBs with incoming edge
/// counters do not have their own physical counters (expressions are allowed).
bcb_has_incoming_edge_counters: BitSet<BasicCoverageBlock>,
- /// Expression nodes that are not directly associated with any particular
- /// BCB/edge, but are needed as operands to more complex expressions.
- /// These are always [`BcbCounter::Expression`].
- pub(super) intermediate_expressions: Vec<BcbCounter>,
+ /// Table of expression data, associating each expression ID with its
+ /// corresponding operator (+ or -) and its LHS/RHS operands.
+ expressions: IndexVec<ExpressionId, Expression>,
}
impl CoverageCounters {
@@ -83,24 +65,22 @@ impl CoverageCounters {
Self {
next_counter_id: CounterId::START,
- next_expression_id: ExpressionId::START,
-
bcb_counters: IndexVec::from_elem_n(None, num_bcbs),
bcb_edge_counters: FxHashMap::default(),
bcb_has_incoming_edge_counters: BitSet::new_empty(num_bcbs),
- intermediate_expressions: Vec::new(),
+ expressions: IndexVec::new(),
}
}
/// Makes [`BcbCounter`] `Counter`s and `Expressions` for the `BasicCoverageBlock`s directly or
- /// indirectly associated with `CoverageSpans`, and accumulates additional `Expression`s
+ /// indirectly associated with coverage spans, and accumulates additional `Expression`s
/// representing intermediate values.
pub fn make_bcb_counters(
&mut self,
basic_coverage_blocks: &CoverageGraph,
- coverage_spans: &[CoverageSpan],
- ) -> Result<(), Error> {
- MakeBcbCounters::new(self, basic_coverage_blocks).make_bcb_counters(coverage_spans)
+ bcb_has_coverage_spans: impl Fn(BasicCoverageBlock) -> bool,
+ ) {
+ MakeBcbCounters::new(self, basic_coverage_blocks).make_bcb_counters(bcb_has_coverage_spans)
}
fn make_counter(&mut self) -> BcbCounter {
@@ -108,50 +88,44 @@ impl CoverageCounters {
BcbCounter::Counter { id }
}
- fn make_expression(&mut self, lhs: Operand, op: Op, rhs: Operand) -> BcbCounter {
- let id = self.next_expression();
- BcbCounter::Expression { id, lhs, op, rhs }
- }
-
- pub fn make_identity_counter(&mut self, counter_operand: Operand) -> BcbCounter {
- self.make_expression(counter_operand, Op::Add, Operand::Zero)
+ fn make_expression(&mut self, lhs: CovTerm, op: Op, rhs: CovTerm) -> BcbCounter {
+ let id = self.expressions.push(Expression { lhs, op, rhs });
+ BcbCounter::Expression { id }
}
/// Counter IDs start from one and go up.
fn next_counter(&mut self) -> CounterId {
let next = self.next_counter_id;
- self.next_counter_id = next.next_id();
+ self.next_counter_id = self.next_counter_id + 1;
next
}
- /// Expression IDs start from 0 and go up.
- /// (Counter IDs and Expression IDs are distinguished by the `Operand` enum.)
- fn next_expression(&mut self) -> ExpressionId {
- let next = self.next_expression_id;
- self.next_expression_id = next.next_id();
- next
+ pub(super) fn num_counters(&self) -> usize {
+ self.next_counter_id.as_usize()
}
- fn set_bcb_counter(
- &mut self,
- bcb: BasicCoverageBlock,
- counter_kind: BcbCounter,
- ) -> Result<Operand, Error> {
- debug_assert!(
+ #[cfg(test)]
+ pub(super) fn num_expressions(&self) -> usize {
+ self.expressions.len()
+ }
+
+ fn set_bcb_counter(&mut self, bcb: BasicCoverageBlock, counter_kind: BcbCounter) -> CovTerm {
+ assert!(
// If the BCB has an edge counter (to be injected into a new `BasicBlock`), it can also
// have an expression (to be injected into an existing `BasicBlock` represented by this
// `BasicCoverageBlock`).
counter_kind.is_expression() || !self.bcb_has_incoming_edge_counters.contains(bcb),
"attempt to add a `Counter` to a BCB target with existing incoming edge counters"
);
- let operand = counter_kind.as_operand();
+
+ let term = counter_kind.as_term();
if let Some(replaced) = self.bcb_counters[bcb].replace(counter_kind) {
- Error::from_string(format!(
+ bug!(
"attempt to set a BasicCoverageBlock coverage counter more than once; \
{bcb:?} already had counter {replaced:?}",
- ))
+ );
} else {
- Ok(operand)
+ term
}
}
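The shape of the new expression table can be illustrated with a stripped-down model (the `Term`, `Op`, and `Expression` types here only mimic the `CovTerm` and `Expression` types used above, and are not the real rustc definitions):

#[derive(Clone, Copy, Debug)]
enum Term {
    Counter(u32),
    Expression(u32),
}

#[derive(Clone, Copy, Debug)]
enum Op {
    Add,
    Subtract,
}

#[derive(Debug)]
struct Expression {
    lhs: Term,
    op: Op,
    rhs: Term,
}

#[derive(Default)]
struct Counters {
    expressions: Vec<Expression>,
}

impl Counters {
    // Mirrors `make_expression` above: append the operands to the table and
    // refer to the new entry by its index, instead of storing the operands
    // inside the counter value itself.
    fn make_expression(&mut self, lhs: Term, op: Op, rhs: Term) -> Term {
        let id = self.expressions.len() as u32;
        self.expressions.push(Expression { lhs, op, rhs });
        Term::Expression(id)
    }
}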
@@ -160,27 +134,26 @@ impl CoverageCounters {
from_bcb: BasicCoverageBlock,
to_bcb: BasicCoverageBlock,
counter_kind: BcbCounter,
- ) -> Result<Operand, Error> {
- if level_enabled!(tracing::Level::DEBUG) {
- // If the BCB has an edge counter (to be injected into a new `BasicBlock`), it can also
- // have an expression (to be injected into an existing `BasicBlock` represented by this
- // `BasicCoverageBlock`).
- if self.bcb_counter(to_bcb).is_some_and(|c| !c.is_expression()) {
- return Error::from_string(format!(
- "attempt to add an incoming edge counter from {from_bcb:?} when the target BCB already \
- has a `Counter`"
- ));
- }
+ ) -> CovTerm {
+ // If the BCB has an edge counter (to be injected into a new `BasicBlock`), it can also
+ // have an expression (to be injected into an existing `BasicBlock` represented by this
+ // `BasicCoverageBlock`).
+ if let Some(node_counter) = self.bcb_counter(to_bcb) && !node_counter.is_expression() {
+ bug!(
+ "attempt to add an incoming edge counter from {from_bcb:?} \
+ when the target BCB already has {node_counter:?}"
+ );
}
+
self.bcb_has_incoming_edge_counters.insert(to_bcb);
- let operand = counter_kind.as_operand();
+ let term = counter_kind.as_term();
if let Some(replaced) = self.bcb_edge_counters.insert((from_bcb, to_bcb), counter_kind) {
- Error::from_string(format!(
+ bug!(
"attempt to set an edge counter more than once; from_bcb: \
{from_bcb:?} already had counter {replaced:?}",
- ))
+ );
} else {
- Ok(operand)
+ term
}
}
@@ -188,27 +161,31 @@ impl CoverageCounters {
self.bcb_counters[bcb].as_ref()
}
- pub(super) fn take_bcb_counter(&mut self, bcb: BasicCoverageBlock) -> Option<BcbCounter> {
- self.bcb_counters[bcb].take()
+ pub(super) fn bcb_node_counters(
+ &self,
+ ) -> impl Iterator<Item = (BasicCoverageBlock, &BcbCounter)> {
+ self.bcb_counters
+ .iter_enumerated()
+ .filter_map(|(bcb, counter_kind)| Some((bcb, counter_kind.as_ref()?)))
}
- pub(super) fn drain_bcb_counters(
- &mut self,
- ) -> impl Iterator<Item = (BasicCoverageBlock, BcbCounter)> + '_ {
- self.bcb_counters
- .iter_enumerated_mut()
- .filter_map(|(bcb, counter)| Some((bcb, counter.take()?)))
+ /// For each edge in the BCB graph that has an associated counter, yields
+ /// that edge's *from* and *to* nodes, and its counter.
+ pub(super) fn bcb_edge_counters(
+ &self,
+ ) -> impl Iterator<Item = (BasicCoverageBlock, BasicCoverageBlock, &BcbCounter)> {
+ self.bcb_edge_counters
+ .iter()
+ .map(|(&(from_bcb, to_bcb), counter_kind)| (from_bcb, to_bcb, counter_kind))
}
- pub(super) fn drain_bcb_edge_counters(
- &mut self,
- ) -> impl Iterator<Item = ((BasicCoverageBlock, BasicCoverageBlock), BcbCounter)> + '_ {
- self.bcb_edge_counters.drain()
+ pub(super) fn take_expressions(&mut self) -> IndexVec<ExpressionId, Expression> {
+ std::mem::take(&mut self.expressions)
}
}
/// Traverse the `CoverageGraph` and add either a `Counter` or `Expression` to every BCB, to be
-/// injected with `CoverageSpan`s. `Expressions` have no runtime overhead, so if a viable expression
+/// injected with coverage spans. `Expressions` have no runtime overhead, so if a viable expression
/// (adding or subtracting two other counters or expressions) can compute the same result as an
/// embedded counter, an `Expression` should be used.
struct MakeBcbCounters<'a> {
@@ -230,21 +207,11 @@ impl<'a> MakeBcbCounters<'a> {
/// One way to predict which branch executes the least is by considering loops. A loop is exited
/// at a branch, so the branch that jumps to a `BasicCoverageBlock` outside the loop is almost
/// always executed less than the branch that does not exit the loop.
- ///
- /// Returns any non-code-span expressions created to represent intermediate values (such as to
- /// add two counters so the result can be subtracted from another counter), or an Error with
- /// message for subsequent debugging.
- fn make_bcb_counters(&mut self, coverage_spans: &[CoverageSpan]) -> Result<(), Error> {
+ fn make_bcb_counters(&mut self, bcb_has_coverage_spans: impl Fn(BasicCoverageBlock) -> bool) {
debug!("make_bcb_counters(): adding a counter or expression to each BasicCoverageBlock");
- let num_bcbs = self.basic_coverage_blocks.num_nodes();
-
- let mut bcbs_with_coverage = BitSet::new_empty(num_bcbs);
- for covspan in coverage_spans {
- bcbs_with_coverage.insert(covspan.bcb);
- }
// Walk the `CoverageGraph`. For each `BasicCoverageBlock` node with an associated
- // `CoverageSpan`, add a counter. If the `BasicCoverageBlock` branches, add a counter or
+ // coverage span, add a counter. If the `BasicCoverageBlock` branches, add a counter or
// expression to each branch `BasicCoverageBlock` (if the branch BCB has only one incoming
// edge) or edge from the branching BCB to the branch BCB (if the branch BCB has multiple
// incoming edges).
@@ -254,39 +221,36 @@ impl<'a> MakeBcbCounters<'a> {
// the loop. The `traversal` state includes a `context_stack`, providing a way to know if
// the current BCB is in one or more nested loops or not.
let mut traversal = TraverseCoverageGraphWithLoops::new(&self.basic_coverage_blocks);
- while let Some(bcb) = traversal.next(self.basic_coverage_blocks) {
- if bcbs_with_coverage.contains(bcb) {
- debug!("{:?} has at least one `CoverageSpan`. Get or make its counter", bcb);
- let branching_counter_operand = self.get_or_make_counter_operand(bcb)?;
+ while let Some(bcb) = traversal.next() {
+ if bcb_has_coverage_spans(bcb) {
+ debug!("{:?} has at least one coverage span. Get or make its counter", bcb);
+ let branching_counter_operand = self.get_or_make_counter_operand(bcb);
if self.bcb_needs_branch_counters(bcb) {
- self.make_branch_counters(&mut traversal, bcb, branching_counter_operand)?;
+ self.make_branch_counters(&traversal, bcb, branching_counter_operand);
}
} else {
debug!(
- "{:?} does not have any `CoverageSpan`s. A counter will only be added if \
+ "{:?} does not have any coverage spans. A counter will only be added if \
and when a covered BCB has an expression dependency.",
bcb,
);
}
}
- if traversal.is_complete() {
- Ok(())
- } else {
- Error::from_string(format!(
- "`TraverseCoverageGraphWithLoops` missed some `BasicCoverageBlock`s: {:?}",
- traversal.unvisited(),
- ))
- }
+ assert!(
+ traversal.is_complete(),
+ "`TraverseCoverageGraphWithLoops` missed some `BasicCoverageBlock`s: {:?}",
+ traversal.unvisited(),
+ );
}
fn make_branch_counters(
&mut self,
- traversal: &mut TraverseCoverageGraphWithLoops,
+ traversal: &TraverseCoverageGraphWithLoops<'_>,
branching_bcb: BasicCoverageBlock,
- branching_counter_operand: Operand,
- ) -> Result<(), Error> {
+ branching_counter_operand: CovTerm,
+ ) {
let branches = self.bcb_branches(branching_bcb);
debug!(
"{:?} has some branch(es) without counters:\n {}",
@@ -319,10 +283,10 @@ impl<'a> MakeBcbCounters<'a> {
counter",
branch, branching_bcb
);
- self.get_or_make_counter_operand(branch.target_bcb)?
+ self.get_or_make_counter_operand(branch.target_bcb)
} else {
debug!(" {:?} has multiple incoming edges, so adding an edge counter", branch);
- self.get_or_make_edge_counter_operand(branching_bcb, branch.target_bcb)?
+ self.get_or_make_edge_counter_operand(branching_bcb, branch.target_bcb)
};
if let Some(sumup_counter_operand) =
some_sumup_counter_operand.replace(branch_counter_operand)
@@ -333,8 +297,7 @@ impl<'a> MakeBcbCounters<'a> {
sumup_counter_operand,
);
debug!(" [new intermediate expression: {:?}]", intermediate_expression);
- let intermediate_expression_operand = intermediate_expression.as_operand();
- self.coverage_counters.intermediate_expressions.push(intermediate_expression);
+ let intermediate_expression_operand = intermediate_expression.as_term();
some_sumup_counter_operand.replace(intermediate_expression_operand);
}
}
@@ -358,31 +321,18 @@ impl<'a> MakeBcbCounters<'a> {
debug!("{:?} gets an expression: {:?}", expression_branch, expression);
let bcb = expression_branch.target_bcb;
if expression_branch.is_only_path_to_target() {
- self.coverage_counters.set_bcb_counter(bcb, expression)?;
+ self.coverage_counters.set_bcb_counter(bcb, expression);
} else {
- self.coverage_counters.set_bcb_edge_counter(branching_bcb, bcb, expression)?;
+ self.coverage_counters.set_bcb_edge_counter(branching_bcb, bcb, expression);
}
- Ok(())
- }
-
- fn get_or_make_counter_operand(&mut self, bcb: BasicCoverageBlock) -> Result<Operand, Error> {
- self.recursive_get_or_make_counter_operand(bcb, 1)
}
- fn recursive_get_or_make_counter_operand(
- &mut self,
- bcb: BasicCoverageBlock,
- debug_indent_level: usize,
- ) -> Result<Operand, Error> {
+ #[instrument(level = "debug", skip(self))]
+ fn get_or_make_counter_operand(&mut self, bcb: BasicCoverageBlock) -> CovTerm {
// If the BCB already has a counter, return it.
if let Some(counter_kind) = &self.coverage_counters.bcb_counters[bcb] {
- debug!(
- "{}{:?} already has a counter: {:?}",
- NESTED_INDENT.repeat(debug_indent_level),
- bcb,
- counter_kind,
- );
- return Ok(counter_kind.as_operand());
+ debug!("{bcb:?} already has a counter: {counter_kind:?}");
+ return counter_kind.as_term();
}
// A BCB with only one incoming edge gets a simple `Counter` (via `make_counter()`).
@@ -392,20 +342,12 @@ impl<'a> MakeBcbCounters<'a> {
if one_path_to_target || self.bcb_predecessors(bcb).contains(&bcb) {
let counter_kind = self.coverage_counters.make_counter();
if one_path_to_target {
- debug!(
- "{}{:?} gets a new counter: {:?}",
- NESTED_INDENT.repeat(debug_indent_level),
- bcb,
- counter_kind,
- );
+ debug!("{bcb:?} gets a new counter: {counter_kind:?}");
} else {
debug!(
- "{}{:?} has itself as its own predecessor. It can't be part of its own \
- Expression sum, so it will get its own new counter: {:?}. (Note, the compiled \
- code will generate an infinite loop.)",
- NESTED_INDENT.repeat(debug_indent_level),
- bcb,
- counter_kind,
+ "{bcb:?} has itself as its own predecessor. It can't be part of its own \
+ Expression sum, so it will get its own new counter: {counter_kind:?}. \
+ (Note, the compiled code will generate an infinite loop.)",
);
}
return self.coverage_counters.set_bcb_counter(bcb, counter_kind);
@@ -415,24 +357,14 @@ impl<'a> MakeBcbCounters<'a> {
// counters and/or expressions of its incoming edges. This will recursively get or create
// counters for those incoming edges first, then call `make_expression()` to sum them up,
// with additional intermediate expressions as needed.
+ let _sumup_debug_span = debug_span!("(preparing sum-up expression)").entered();
+
let mut predecessors = self.bcb_predecessors(bcb).to_owned().into_iter();
- debug!(
- "{}{:?} has multiple incoming edges and will get an expression that sums them up...",
- NESTED_INDENT.repeat(debug_indent_level),
- bcb,
- );
- let first_edge_counter_operand = self.recursive_get_or_make_edge_counter_operand(
- predecessors.next().unwrap(),
- bcb,
- debug_indent_level + 1,
- )?;
+ let first_edge_counter_operand =
+ self.get_or_make_edge_counter_operand(predecessors.next().unwrap(), bcb);
let mut some_sumup_edge_counter_operand = None;
for predecessor in predecessors {
- let edge_counter_operand = self.recursive_get_or_make_edge_counter_operand(
- predecessor,
- bcb,
- debug_indent_level + 1,
- )?;
+ let edge_counter_operand = self.get_or_make_edge_counter_operand(predecessor, bcb);
if let Some(sumup_edge_counter_operand) =
some_sumup_edge_counter_operand.replace(edge_counter_operand)
{
@@ -441,13 +373,8 @@ impl<'a> MakeBcbCounters<'a> {
Op::Add,
edge_counter_operand,
);
- debug!(
- "{}new intermediate expression: {:?}",
- NESTED_INDENT.repeat(debug_indent_level),
- intermediate_expression
- );
- let intermediate_expression_operand = intermediate_expression.as_operand();
- self.coverage_counters.intermediate_expressions.push(intermediate_expression);
+ debug!("new intermediate expression: {intermediate_expression:?}");
+ let intermediate_expression_operand = intermediate_expression.as_term();
some_sumup_edge_counter_operand.replace(intermediate_expression_operand);
}
}
@@ -456,59 +383,36 @@ impl<'a> MakeBcbCounters<'a> {
Op::Add,
some_sumup_edge_counter_operand.unwrap(),
);
- debug!(
- "{}{:?} gets a new counter (sum of predecessor counters): {:?}",
- NESTED_INDENT.repeat(debug_indent_level),
- bcb,
- counter_kind
- );
+ drop(_sumup_debug_span);
+
+ debug!("{bcb:?} gets a new counter (sum of predecessor counters): {counter_kind:?}");
self.coverage_counters.set_bcb_counter(bcb, counter_kind)
}
+ #[instrument(level = "debug", skip(self))]
fn get_or_make_edge_counter_operand(
&mut self,
from_bcb: BasicCoverageBlock,
to_bcb: BasicCoverageBlock,
- ) -> Result<Operand, Error> {
- self.recursive_get_or_make_edge_counter_operand(from_bcb, to_bcb, 1)
- }
-
- fn recursive_get_or_make_edge_counter_operand(
- &mut self,
- from_bcb: BasicCoverageBlock,
- to_bcb: BasicCoverageBlock,
- debug_indent_level: usize,
- ) -> Result<Operand, Error> {
+ ) -> CovTerm {
// If the source BCB has only one successor (assumed to be the given target), an edge
// counter is unnecessary. Just get or make a counter for the source BCB.
let successors = self.bcb_successors(from_bcb).iter();
if successors.len() == 1 {
- return self.recursive_get_or_make_counter_operand(from_bcb, debug_indent_level + 1);
+ return self.get_or_make_counter_operand(from_bcb);
}
// If the edge already has a counter, return it.
if let Some(counter_kind) =
self.coverage_counters.bcb_edge_counters.get(&(from_bcb, to_bcb))
{
- debug!(
- "{}Edge {:?}->{:?} already has a counter: {:?}",
- NESTED_INDENT.repeat(debug_indent_level),
- from_bcb,
- to_bcb,
- counter_kind
- );
- return Ok(counter_kind.as_operand());
+ debug!("Edge {from_bcb:?}->{to_bcb:?} already has a counter: {counter_kind:?}");
+ return counter_kind.as_term();
}
// Make a new counter to count this edge.
let counter_kind = self.coverage_counters.make_counter();
- debug!(
- "{}Edge {:?}->{:?} gets a new counter: {:?}",
- NESTED_INDENT.repeat(debug_indent_level),
- from_bcb,
- to_bcb,
- counter_kind
- );
+ debug!("Edge {from_bcb:?}->{to_bcb:?} gets a new counter: {counter_kind:?}");
self.coverage_counters.set_bcb_edge_counter(from_bcb, to_bcb, counter_kind)
}
@@ -516,21 +420,14 @@ impl<'a> MakeBcbCounters<'a> {
/// found, select any branch.
fn choose_preferred_expression_branch(
&self,
- traversal: &TraverseCoverageGraphWithLoops,
+ traversal: &TraverseCoverageGraphWithLoops<'_>,
branches: &[BcbBranch],
) -> BcbBranch {
- let branch_needs_a_counter = |branch: &BcbBranch| self.branch_has_no_counter(branch);
-
- let some_reloop_branch = self.find_some_reloop_branch(traversal, &branches);
- if let Some(reloop_branch_without_counter) =
- some_reloop_branch.filter(branch_needs_a_counter)
- {
- debug!(
- "Selecting reloop_branch={:?} that still needs a counter, to get the \
- `Expression`",
- reloop_branch_without_counter
- );
- reloop_branch_without_counter
+ let good_reloop_branch = self.find_good_reloop_branch(traversal, &branches);
+ if let Some(reloop_branch) = good_reloop_branch {
+ assert!(self.branch_has_no_counter(&reloop_branch));
+ debug!("Selecting reloop branch {reloop_branch:?} to get an expression");
+ reloop_branch
} else {
let &branch_without_counter =
branches.iter().find(|&branch| self.branch_has_no_counter(branch)).expect(
@@ -547,75 +444,52 @@ impl<'a> MakeBcbCounters<'a> {
}
}
- /// At most, one of the branches (or its edge, from the branching_bcb, if the branch has
- /// multiple incoming edges) can have a counter computed by expression.
- ///
- /// If at least one of the branches leads outside of a loop (`found_loop_exit` is
- /// true), and at least one other branch does not exit the loop (the first of which
- /// is captured in `some_reloop_branch`), it's likely any reloop branch will be
- /// executed far more often than loop exit branch, making the reloop branch a better
- /// candidate for an expression.
- fn find_some_reloop_branch(
+ /// Tries to find a branch that leads back to the top of a loop, and that
+ /// doesn't already have a counter. Such branches are good candidates to
+ /// be given an expression (instead of a physical counter), because they
+ /// will tend to be executed more times than a loop-exit branch.
+ fn find_good_reloop_branch(
&self,
- traversal: &TraverseCoverageGraphWithLoops,
+ traversal: &TraverseCoverageGraphWithLoops<'_>,
branches: &[BcbBranch],
) -> Option<BcbBranch> {
- let branch_needs_a_counter = |branch: &BcbBranch| self.branch_has_no_counter(branch);
-
- let mut some_reloop_branch: Option<BcbBranch> = None;
- for context in traversal.context_stack.iter().rev() {
- if let Some((backedge_from_bcbs, _)) = &context.loop_backedges {
- let mut found_loop_exit = false;
- for &branch in branches.iter() {
- if backedge_from_bcbs.iter().any(|&backedge_from_bcb| {
- self.bcb_dominates(branch.target_bcb, backedge_from_bcb)
- }) {
- if let Some(reloop_branch) = some_reloop_branch {
- if self.branch_has_no_counter(&reloop_branch) {
- // we already found a candidate reloop_branch that still
- // needs a counter
- continue;
- }
- }
- // The path from branch leads back to the top of the loop. Set this
- // branch as the `reloop_branch`. If this branch already has a
- // counter, and we find another reloop branch that doesn't have a
- // counter yet, that branch will be selected as the `reloop_branch`
- // instead.
- some_reloop_branch = Some(branch);
- } else {
- // The path from branch leads outside this loop
- found_loop_exit = true;
- }
- if found_loop_exit
- && some_reloop_branch.filter(branch_needs_a_counter).is_some()
- {
- // Found both a branch that exits the loop and a branch that returns
- // to the top of the loop (`reloop_branch`), and the `reloop_branch`
- // doesn't already have a counter.
- break;
+ // Consider each loop on the current traversal context stack, top-down.
+ for reloop_bcbs in traversal.reloop_bcbs_per_loop() {
+ let mut all_branches_exit_this_loop = true;
+
+ // Try to find a branch that doesn't exit this loop and doesn't
+ // already have a counter.
+ for &branch in branches {
+ // A branch is a reloop branch if it dominates any BCB that has
+ // an edge back to the loop header. (Other branches are exits.)
+ let is_reloop_branch = reloop_bcbs.iter().any(|&reloop_bcb| {
+ self.basic_coverage_blocks.dominates(branch.target_bcb, reloop_bcb)
+ });
+
+ if is_reloop_branch {
+ all_branches_exit_this_loop = false;
+ if self.branch_has_no_counter(&branch) {
+ // We found a good branch to be given an expression.
+ return Some(branch);
}
+ // Keep looking for another reloop branch without a counter.
+ } else {
+ // This branch exits the loop.
}
- if !found_loop_exit {
- debug!(
- "No branches exit the loop, so any branch without an existing \
- counter can have the `Expression`."
- );
- break;
- }
- if some_reloop_branch.is_some() {
- debug!(
- "Found a branch that exits the loop and a branch the loops back to \
- the top of the loop (`reloop_branch`). The `reloop_branch` will \
- get the `Expression`, as long as it still needs a counter."
- );
- break;
- }
- // else all branches exited this loop context, so run the same checks with
- // the outer loop(s)
}
+
+ if !all_branches_exit_this_loop {
+ // We found one or more reloop branches, but all of them already
+ // have counters. Let the caller choose one of the exit branches.
+ debug!("All reloop branches had counters; skip checking the other loops");
+ return None;
+ }
+
+ // All of the branches exit this loop, so keep looking for a good
+ // reloop branch for one of the outer loops.
}
- some_reloop_branch
+
+ None
}
#[inline]
@@ -661,9 +535,4 @@ impl<'a> MakeBcbCounters<'a> {
fn bcb_has_one_path_to_target(&self, bcb: BasicCoverageBlock) -> bool {
self.bcb_predecessors(bcb).len() <= 1
}
-
- #[inline]
- fn bcb_dominates(&self, dom: BasicCoverageBlock, node: BasicCoverageBlock) -> bool {
- self.basic_coverage_blocks.dominates(dom, node)
- }
}
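The selection logic in `find_good_reloop_branch` can be summarized with a simplified standalone model (the `Branch` type, the `loops` argument, and the `leads_back_to` callback are placeholders for the real BCB graph queries, not rustc APIs):

struct Branch {
    id: usize,
    has_counter: bool,
}

// `loops` is ordered innermost-first; each entry lists the blocks that jump
// back to that loop's header. `leads_back_to` stands in for the dominance
// check against those reloop blocks.
fn find_good_reloop_branch(
    branches: &[Branch],
    loops: &[Vec<usize>],
    leads_back_to: impl Fn(&Branch, usize) -> bool,
) -> Option<usize> {
    for reloop_bcbs in loops {
        let mut all_branches_exit_this_loop = true;
        for branch in branches {
            let is_reloop_branch =
                reloop_bcbs.iter().any(|&bcb| leads_back_to(branch, bcb));
            if is_reloop_branch {
                all_branches_exit_this_loop = false;
                if !branch.has_counter {
                    // A reloop branch without a counter: give it the expression.
                    return Some(branch.id);
                }
            }
        }
        if !all_branches_exit_this_loop {
            // Reloop branches exist but all already have counters; the caller
            // falls back to an arbitrary branch without a counter.
            return None;
        }
        // Every branch exits this loop; try the next enclosing loop.
    }
    None
}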
diff --git a/compiler/rustc_mir_transform/src/coverage/graph.rs b/compiler/rustc_mir_transform/src/coverage/graph.rs
index ff2254d69..6bab62aa8 100644
--- a/compiler/rustc_mir_transform/src/coverage/graph.rs
+++ b/compiler/rustc_mir_transform/src/coverage/graph.rs
@@ -1,10 +1,12 @@
+use rustc_data_structures::captures::Captures;
use rustc_data_structures::graph::dominators::{self, Dominators};
use rustc_data_structures::graph::{self, GraphSuccessors, WithNumNodes, WithStartNode};
use rustc_index::bit_set::BitSet;
use rustc_index::{IndexSlice, IndexVec};
-use rustc_middle::mir::{self, BasicBlock, BasicBlockData, Terminator, TerminatorKind};
+use rustc_middle::mir::{self, BasicBlock, TerminatorKind};
use std::cmp::Ordering;
+use std::collections::VecDeque;
use std::ops::{Index, IndexMut};
/// A coverage-specific simplification of the MIR control flow graph (CFG). The `CoverageGraph`s
@@ -36,9 +38,8 @@ impl CoverageGraph {
}
let bcb_data = &bcbs[bcb];
let mut bcb_successors = Vec::new();
- for successor in
- bcb_filtered_successors(&mir_body, &bcb_data.terminator(mir_body).kind)
- .filter_map(|successor_bb| bb_to_bcb[successor_bb])
+ for successor in bcb_filtered_successors(&mir_body, bcb_data.last_bb())
+ .filter_map(|successor_bb| bb_to_bcb[successor_bb])
{
if !seen[successor] {
seen[successor] = true;
@@ -80,10 +81,9 @@ impl CoverageGraph {
// intentionally omits unwind paths.
// FIXME(#78544): MIR InstrumentCoverage: Improve coverage of `#[should_panic]` tests and
// `catch_unwind()` handlers.
- let mir_cfg_without_unwind = ShortCircuitPreorder::new(&mir_body, bcb_filtered_successors);
let mut basic_blocks = Vec::new();
- for (bb, data) in mir_cfg_without_unwind {
+ for bb in short_circuit_preorder(mir_body, bcb_filtered_successors) {
if let Some(last) = basic_blocks.last() {
let predecessors = &mir_body.basic_blocks.predecessors()[bb];
if predecessors.len() > 1 || !predecessors.contains(last) {
@@ -109,7 +109,7 @@ impl CoverageGraph {
}
basic_blocks.push(bb);
- let term = data.terminator();
+ let term = mir_body[bb].terminator();
match term.kind {
TerminatorKind::Return { .. }
@@ -147,7 +147,7 @@ impl CoverageGraph {
| TerminatorKind::Unreachable
| TerminatorKind::Drop { .. }
| TerminatorKind::Call { .. }
- | TerminatorKind::GeneratorDrop
+ | TerminatorKind::CoroutineDrop
| TerminatorKind::Assert { .. }
| TerminatorKind::FalseEdge { .. }
| TerminatorKind::FalseUnwind { .. }
@@ -288,9 +288,9 @@ rustc_index::newtype_index! {
/// not relevant to coverage analysis. `FalseUnwind`, for example, can be treated the same as
/// a `Goto`, and merged with its successor into the same BCB.
///
-/// Each BCB with at least one computed `CoverageSpan` will have no more than one `Counter`.
+/// Each BCB with at least one computed coverage span will have no more than one `Counter`.
/// In some cases, a BCB's execution count can be computed by `Expression`. Additional
-/// disjoint `CoverageSpan`s in a BCB can also be counted by `Expression` (by adding `ZERO`
+/// disjoint coverage spans in a BCB can also be counted by `Expression` (by adding `ZERO`
/// to the BCB's primary counter or expression).
///
/// The BCB CFG is critical to simplifying the coverage analysis by ensuring graph path-based
@@ -316,11 +316,6 @@ impl BasicCoverageBlockData {
pub fn last_bb(&self) -> BasicBlock {
*self.basic_blocks.last().unwrap()
}
-
- #[inline(always)]
- pub fn terminator<'a, 'tcx>(&self, mir_body: &'a mir::Body<'tcx>) -> &'a Terminator<'tcx> {
- &mir_body[self.last_bb()].terminator()
- }
}
/// Represents a successor from a branching BasicCoverageBlock (such as the arms of a `SwitchInt`)
@@ -362,26 +357,28 @@ impl std::fmt::Debug for BcbBranch {
}
}
-// Returns the `Terminator`s non-unwind successors.
+// Returns the subset of a block's successors that are relevant to the coverage
+// graph, i.e. those that do not represent unwinds or unreachable branches.
// FIXME(#78544): MIR InstrumentCoverage: Improve coverage of `#[should_panic]` tests and
// `catch_unwind()` handlers.
fn bcb_filtered_successors<'a, 'tcx>(
body: &'a mir::Body<'tcx>,
- term_kind: &'a TerminatorKind<'tcx>,
-) -> Box<dyn Iterator<Item = BasicBlock> + 'a> {
- Box::new(
- match &term_kind {
- // SwitchInt successors are never unwind, and all of them should be traversed.
- TerminatorKind::SwitchInt { ref targets, .. } => {
- None.into_iter().chain(targets.all_targets().into_iter().copied())
- }
- // For all other kinds, return only the first successor, if any, and ignore unwinds.
- // NOTE: `chain(&[])` is required to coerce the `option::iter` (from
- // `next().into_iter()`) into the `mir::Successors` aliased type.
- _ => term_kind.successors().next().into_iter().chain((&[]).into_iter().copied()),
- }
- .filter(move |&successor| body[successor].terminator().kind != TerminatorKind::Unreachable),
- )
+ bb: BasicBlock,
+) -> impl Iterator<Item = BasicBlock> + Captures<'a> + Captures<'tcx> {
+ let terminator = body[bb].terminator();
+
+ let take_n_successors = match terminator.kind {
+ // SwitchInt successors are never unwinds, so all of them should be traversed.
+ TerminatorKind::SwitchInt { .. } => usize::MAX,
+ // For all other kinds, return only the first successor (if any), ignoring any
+ // unwind successors.
+ _ => 1,
+ };
+
+ terminator
+ .successors()
+ .take(take_n_successors)
+ .filter(move |&successor| body[successor].terminator().kind != TerminatorKind::Unreachable)
}
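As a plain illustration of that filtering rule (the `TermKind` enum below is a stand-in, not the real `TerminatorKind`; the real code additionally drops successors whose blocks are `Unreachable`):

enum TermKind {
    SwitchInt { targets: Vec<usize> },
    Other { first_successor: Option<usize> },
}

fn coverage_successors(kind: &TermKind) -> Vec<usize> {
    match kind {
        // SwitchInt contributes every target; none of them are unwind edges.
        TermKind::SwitchInt { targets } => targets.clone(),
        // Any other terminator contributes at most its first successor, so
        // unwind/cleanup edges never enter the coverage graph.
        TermKind::Other { first_successor } => first_successor.iter().copied().collect(),
    }
}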
/// Maintains separate worklists for each loop in the BasicCoverageBlock CFG, plus one for the
@@ -389,57 +386,72 @@ fn bcb_filtered_successors<'a, 'tcx>(
/// ensures a loop is completely traversed before processing Blocks after the end of the loop.
#[derive(Debug)]
pub(super) struct TraversalContext {
- /// From one or more backedges returning to a loop header.
- pub loop_backedges: Option<(Vec<BasicCoverageBlock>, BasicCoverageBlock)>,
-
- /// worklist, to be traversed, of CoverageGraph in the loop with the given loop
- /// backedges, such that the loop is the inner inner-most loop containing these
- /// CoverageGraph
- pub worklist: Vec<BasicCoverageBlock>,
+ /// BCB with one or more incoming loop backedges, indicating which loop
+ /// this context is for.
+ ///
+ /// If `None`, this is the non-loop context for the function as a whole.
+ loop_header: Option<BasicCoverageBlock>,
+
+ /// Worklist of BCBs to be processed in this context.
+ worklist: VecDeque<BasicCoverageBlock>,
}
-pub(super) struct TraverseCoverageGraphWithLoops {
- pub backedges: IndexVec<BasicCoverageBlock, Vec<BasicCoverageBlock>>,
- pub context_stack: Vec<TraversalContext>,
+pub(super) struct TraverseCoverageGraphWithLoops<'a> {
+ basic_coverage_blocks: &'a CoverageGraph,
+
+ backedges: IndexVec<BasicCoverageBlock, Vec<BasicCoverageBlock>>,
+ context_stack: Vec<TraversalContext>,
visited: BitSet<BasicCoverageBlock>,
}
-impl TraverseCoverageGraphWithLoops {
- pub fn new(basic_coverage_blocks: &CoverageGraph) -> Self {
- let start_bcb = basic_coverage_blocks.start_node();
+impl<'a> TraverseCoverageGraphWithLoops<'a> {
+ pub(super) fn new(basic_coverage_blocks: &'a CoverageGraph) -> Self {
let backedges = find_loop_backedges(basic_coverage_blocks);
- let context_stack =
- vec![TraversalContext { loop_backedges: None, worklist: vec![start_bcb] }];
+
+ let worklist = VecDeque::from([basic_coverage_blocks.start_node()]);
+ let context_stack = vec![TraversalContext { loop_header: None, worklist }];
+
// `context_stack` starts with a `TraversalContext` for the main function context (beginning
// with the `start` BasicCoverageBlock of the function). New worklists are pushed to the top
// of the stack as loops are entered, and popped off of the stack when a loop's worklist is
// exhausted.
let visited = BitSet::new_empty(basic_coverage_blocks.num_nodes());
- Self { backedges, context_stack, visited }
+ Self { basic_coverage_blocks, backedges, context_stack, visited }
}
- pub fn next(&mut self, basic_coverage_blocks: &CoverageGraph) -> Option<BasicCoverageBlock> {
+ /// For each loop on the loop context stack (top-down), yields a list of BCBs
+ /// within that loop that have an outgoing edge back to the loop header.
+ pub(super) fn reloop_bcbs_per_loop(&self) -> impl Iterator<Item = &[BasicCoverageBlock]> {
+ self.context_stack
+ .iter()
+ .rev()
+ .filter_map(|context| context.loop_header)
+ .map(|header_bcb| self.backedges[header_bcb].as_slice())
+ }
+
+ pub(super) fn next(&mut self) -> Option<BasicCoverageBlock> {
debug!(
"TraverseCoverageGraphWithLoops::next - context_stack: {:?}",
self.context_stack.iter().rev().collect::<Vec<_>>()
);
while let Some(context) = self.context_stack.last_mut() {
- if let Some(next_bcb) = context.worklist.pop() {
- if !self.visited.insert(next_bcb) {
- debug!("Already visited: {:?}", next_bcb);
+ if let Some(bcb) = context.worklist.pop_front() {
+ if !self.visited.insert(bcb) {
+ debug!("Already visited: {bcb:?}");
continue;
}
- debug!("Visiting {:?}", next_bcb);
- if self.backedges[next_bcb].len() > 0 {
- debug!("{:?} is a loop header! Start a new TraversalContext...", next_bcb);
+ debug!("Visiting {bcb:?}");
+
+ if self.backedges[bcb].len() > 0 {
+ debug!("{bcb:?} is a loop header! Start a new TraversalContext...");
self.context_stack.push(TraversalContext {
- loop_backedges: Some((self.backedges[next_bcb].clone(), next_bcb)),
- worklist: Vec::new(),
+ loop_header: Some(bcb),
+ worklist: VecDeque::new(),
});
}
- self.extend_worklist(basic_coverage_blocks, next_bcb);
- return Some(next_bcb);
+ self.add_successors_to_worklists(bcb);
+ return Some(bcb);
} else {
// Strip contexts with empty worklists from the top of the stack
self.context_stack.pop();
@@ -449,13 +461,10 @@ impl TraverseCoverageGraphWithLoops {
None
}
- pub fn extend_worklist(
- &mut self,
- basic_coverage_blocks: &CoverageGraph,
- bcb: BasicCoverageBlock,
- ) {
- let successors = &basic_coverage_blocks.successors[bcb];
+ pub fn add_successors_to_worklists(&mut self, bcb: BasicCoverageBlock) {
+ let successors = &self.basic_coverage_blocks.successors[bcb];
debug!("{:?} has {} successors:", bcb, successors.len());
+
for &successor in successors {
if successor == bcb {
debug!(
@@ -464,56 +473,44 @@ impl TraverseCoverageGraphWithLoops {
bcb
);
// Don't re-add this successor to the worklist. We are already processing it.
+ // FIXME: This claims to skip just the self-successor, but it actually skips
+ // all other successors as well. Does that matter?
break;
}
- for context in self.context_stack.iter_mut().rev() {
- // Add successors of the current BCB to the appropriate context. Successors that
- // stay within a loop are added to the BCBs context worklist. Successors that
- // exit the loop (they are not dominated by the loop header) must be reachable
- // from other BCBs outside the loop, and they will be added to a different
- // worklist.
- //
- // Branching blocks (with more than one successor) must be processed before
- // blocks with only one successor, to prevent unnecessarily complicating
- // `Expression`s by creating a Counter in a `BasicCoverageBlock` that the
- // branching block would have given an `Expression` (or vice versa).
- let (some_successor_to_add, some_loop_header) =
- if let Some((_, loop_header)) = context.loop_backedges {
- if basic_coverage_blocks.dominates(loop_header, successor) {
- (Some(successor), Some(loop_header))
- } else {
- (None, None)
- }
- } else {
- (Some(successor), None)
- };
- if let Some(successor_to_add) = some_successor_to_add {
- if basic_coverage_blocks.successors[successor_to_add].len() > 1 {
- debug!(
- "{:?} successor is branching. Prioritize it at the beginning of \
- the {}",
- successor_to_add,
- if let Some(loop_header) = some_loop_header {
- format!("worklist for the loop headed by {loop_header:?}")
- } else {
- String::from("non-loop worklist")
- },
- );
- context.worklist.insert(0, successor_to_add);
- } else {
- debug!(
- "{:?} successor is non-branching. Defer it to the end of the {}",
- successor_to_add,
- if let Some(loop_header) = some_loop_header {
- format!("worklist for the loop headed by {loop_header:?}")
- } else {
- String::from("non-loop worklist")
- },
- );
- context.worklist.push(successor_to_add);
+
+ // Add successors of the current BCB to the appropriate context. Successors that
+ // stay within a loop are added to the BCB's context worklist. Successors that
+ // exit the loop (they are not dominated by the loop header) must be reachable
+ // from other BCBs outside the loop, and they will be added to a different
+ // worklist.
+ //
+ // Branching blocks (with more than one successor) must be processed before
+ // blocks with only one successor, to prevent unnecessarily complicating
+ // `Expression`s by creating a Counter in a `BasicCoverageBlock` that the
+ // branching block would have given an `Expression` (or vice versa).
+
+ let context = self
+ .context_stack
+ .iter_mut()
+ .rev()
+ .find(|context| match context.loop_header {
+ Some(loop_header) => {
+ self.basic_coverage_blocks.dominates(loop_header, successor)
}
- break;
- }
+ None => true,
+ })
+ .unwrap_or_else(|| bug!("should always fall back to the root non-loop context"));
+ debug!("adding to worklist for {:?}", context.loop_header);
+
+ // FIXME: The code below had debug messages claiming to add items to a
+ // particular end of the worklist, but was confused about which end was
+ // which. The existing behaviour has been preserved for now, but it's
+ // unclear what the intended behaviour was.
+
+ if self.basic_coverage_blocks.successors[successor].len() > 1 {
+ context.worklist.push_back(successor);
+ } else {
+ context.worklist.push_front(successor);
}
}
}
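To make the worklist routing described in the comments above concrete, here is a minimal, self-contained sketch of the same loop-aware traversal over a toy graph of `usize` node IDs. The `Context`, `traverse`, and `dominates` names below are illustrative stand-ins, not the compiler's own types:

use std::collections::VecDeque;

struct Context {
    loop_header: Option<usize>,
    worklist: VecDeque<usize>,
}

/// Visits every reachable node once, keeping one worklist per entered loop so
/// that a loop body is drained before control falls back to the outer context.
fn traverse(
    successors: &[Vec<usize>],
    backedges: &[Vec<usize>],
    dominates: impl Fn(usize, usize) -> bool,
) -> Vec<usize> {
    let mut visited = vec![false; successors.len()];
    let mut stack = vec![Context { loop_header: None, worklist: VecDeque::from([0]) }];
    let mut order = Vec::new();

    while let Some(context) = stack.last_mut() {
        let Some(node) = context.worklist.pop_front() else {
            stack.pop(); // This loop's worklist is exhausted; resume the outer context.
            continue;
        };
        if std::mem::replace(&mut visited[node], true) {
            continue; // Already visited.
        }
        order.push(node);

        if !backedges[node].is_empty() {
            // `node` is a loop header: nodes inside its loop get their own worklist.
            stack.push(Context { loop_header: Some(node), worklist: VecDeque::new() });
        }
        for &succ in &successors[node] {
            // Route each successor to the innermost context whose loop header
            // dominates it, falling back to the root (non-loop) context.
            // (The real pass additionally orders branching vs. non-branching
            // successors within each worklist.)
            let target = stack
                .iter_mut()
                .rev()
                .find(|c| c.loop_header.map_or(true, |h| dominates(h, succ)))
                .expect("the root context always matches");
            target.worklist.push_back(succ);
        }
    }
    order
}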
@@ -553,66 +550,28 @@ pub(super) fn find_loop_backedges(
backedges
}
-pub struct ShortCircuitPreorder<
- 'a,
- 'tcx,
- F: Fn(&'a mir::Body<'tcx>, &'a TerminatorKind<'tcx>) -> Box<dyn Iterator<Item = BasicBlock> + 'a>,
-> {
+fn short_circuit_preorder<'a, 'tcx, F, Iter>(
body: &'a mir::Body<'tcx>,
- visited: BitSet<BasicBlock>,
- worklist: Vec<BasicBlock>,
filtered_successors: F,
-}
-
-impl<
- 'a,
- 'tcx,
- F: Fn(&'a mir::Body<'tcx>, &'a TerminatorKind<'tcx>) -> Box<dyn Iterator<Item = BasicBlock> + 'a>,
-> ShortCircuitPreorder<'a, 'tcx, F>
-{
- pub fn new(
- body: &'a mir::Body<'tcx>,
- filtered_successors: F,
- ) -> ShortCircuitPreorder<'a, 'tcx, F> {
- let worklist = vec![mir::START_BLOCK];
-
- ShortCircuitPreorder {
- body,
- visited: BitSet::new_empty(body.basic_blocks.len()),
- worklist,
- filtered_successors,
- }
- }
-}
-
-impl<
- 'a,
- 'tcx,
- F: Fn(&'a mir::Body<'tcx>, &'a TerminatorKind<'tcx>) -> Box<dyn Iterator<Item = BasicBlock> + 'a>,
-> Iterator for ShortCircuitPreorder<'a, 'tcx, F>
+) -> impl Iterator<Item = BasicBlock> + Captures<'a> + Captures<'tcx>
+where
+ F: Fn(&'a mir::Body<'tcx>, BasicBlock) -> Iter,
+ Iter: Iterator<Item = BasicBlock>,
{
- type Item = (BasicBlock, &'a BasicBlockData<'tcx>);
+ let mut visited = BitSet::new_empty(body.basic_blocks.len());
+ let mut worklist = vec![mir::START_BLOCK];
- fn next(&mut self) -> Option<(BasicBlock, &'a BasicBlockData<'tcx>)> {
- while let Some(idx) = self.worklist.pop() {
- if !self.visited.insert(idx) {
+ std::iter::from_fn(move || {
+ while let Some(bb) = worklist.pop() {
+ if !visited.insert(bb) {
continue;
}
- let data = &self.body[idx];
-
- if let Some(ref term) = data.terminator {
- self.worklist.extend((self.filtered_successors)(&self.body, &term.kind));
- }
+ worklist.extend(filtered_successors(body, bb));
- return Some((idx, data));
+ return Some(bb);
}
None
- }
-
- fn size_hint(&self) -> (usize, Option<usize>) {
- let size = self.body.basic_blocks.len() - self.visited.count();
- (size, Some(size))
- }
+ })
}
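As an aside on the pattern used here: stripped of the MIR types, the same `std::iter::from_fn` worklist traversal looks like the sketch below, with plain `usize` nodes and a caller-supplied successor filter standing in for `filtered_successors` (names are illustrative, not the compiler's API):

/// Lazily yields nodes in preorder, but only follows the successors that the
/// filter returns, so the traversal "short-circuits" at edges the filter drops.
fn preorder<F>(num_nodes: usize, filtered_successors: F) -> impl Iterator<Item = usize>
where
    F: Fn(usize) -> Vec<usize>,
{
    let mut visited = vec![false; num_nodes];
    let mut worklist = vec![0usize]; // Start from node 0, like `mir::START_BLOCK`.

    std::iter::from_fn(move || {
        while let Some(node) = worklist.pop() {
            if std::mem::replace(&mut visited[node], true) {
                continue; // Already yielded this node.
            }
            worklist.extend(filtered_successors(node));
            return Some(node);
        }
        None
    })
}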
diff --git a/compiler/rustc_mir_transform/src/coverage/mod.rs b/compiler/rustc_mir_transform/src/coverage/mod.rs
index c75d33eeb..97e4468a0 100644
--- a/compiler/rustc_mir_transform/src/coverage/mod.rs
+++ b/compiler/rustc_mir_transform/src/coverage/mod.rs
@@ -8,14 +8,12 @@ mod spans;
mod tests;
use self::counters::{BcbCounter, CoverageCounters};
-use self::graph::{BasicCoverageBlock, BasicCoverageBlockData, CoverageGraph};
-use self::spans::{CoverageSpan, CoverageSpans};
+use self::graph::CoverageGraph;
+use self::spans::CoverageSpans;
use crate::MirPass;
-use rustc_data_structures::graph::WithNumNodes;
use rustc_data_structures::sync::Lrc;
-use rustc_index::IndexVec;
use rustc_middle::hir;
use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags;
use rustc_middle::mir::coverage::*;
@@ -28,18 +26,6 @@ use rustc_span::def_id::DefId;
use rustc_span::source_map::SourceMap;
use rustc_span::{ExpnKind, SourceFile, Span, Symbol};
-/// A simple error message wrapper for `coverage::Error`s.
-#[derive(Debug)]
-struct Error {
- message: String,
-}
-
-impl Error {
- pub fn from_string<T>(message: String) -> Result<T, Error> {
- Err(Self { message })
- }
-}
-
/// Inserts `StatementKind::Coverage` statements that either instrument the binary with injected
/// counters, via intrinsic `llvm.instrprof.increment`, and/or inject metadata used during codegen
/// to construct the coverage map.
@@ -154,7 +140,7 @@ impl<'a, 'tcx> Instrumentor<'a, 'tcx> {
let body_span = self.body_span;
////////////////////////////////////////////////////
- // Compute `CoverageSpan`s from the `CoverageGraph`.
+ // Compute coverage spans from the `CoverageGraph`.
let coverage_spans = CoverageSpans::generate_coverage_spans(
&self.mir_body,
fn_sig_span,
@@ -164,179 +150,106 @@ impl<'a, 'tcx> Instrumentor<'a, 'tcx> {
////////////////////////////////////////////////////
// Create an optimized mix of `Counter`s and `Expression`s for the `CoverageGraph`. Ensure
- // every `CoverageSpan` has a `Counter` or `Expression` assigned to its `BasicCoverageBlock`
+ // every coverage span has a `Counter` or `Expression` assigned to its `BasicCoverageBlock`
// and all `Expression` dependencies (operands) are also generated, for any other
- // `BasicCoverageBlock`s not already associated with a `CoverageSpan`.
- //
- // Intermediate expressions (used to compute other `Expression` values), which have no
- // direct association with any `BasicCoverageBlock`, are accumulated inside `coverage_counters`.
- let result = self
- .coverage_counters
- .make_bcb_counters(&mut self.basic_coverage_blocks, &coverage_spans);
-
- if let Ok(()) = result {
- ////////////////////////////////////////////////////
- // Remove the counter or edge counter from of each `CoverageSpan`s associated
- // `BasicCoverageBlock`, and inject a `Coverage` statement into the MIR.
- //
- // `Coverage` statements injected from `CoverageSpan`s will include the code regions
- // (source code start and end positions) to be counted by the associated counter.
- //
- // These `CoverageSpan`-associated counters are removed from their associated
- // `BasicCoverageBlock`s so that the only remaining counters in the `CoverageGraph`
- // are indirect counters (to be injected next, without associated code regions).
- self.inject_coverage_span_counters(coverage_spans);
-
- ////////////////////////////////////////////////////
- // For any remaining `BasicCoverageBlock` counters (that were not associated with
- // any `CoverageSpan`), inject `Coverage` statements (_without_ code region `Span`s)
- // to ensure `BasicCoverageBlock` counters that other `Expression`s may depend on
- // are in fact counted, even though they don't directly contribute to counting
- // their own independent code region's coverage.
- self.inject_indirect_counters();
-
- // Intermediate expressions will be injected as the final step, after generating
- // debug output, if any.
- ////////////////////////////////////////////////////
- };
-
- if let Err(e) = result {
- bug!("Error processing: {:?}: {:?}", self.mir_body.source.def_id(), e.message)
- };
-
- ////////////////////////////////////////////////////
- // Finally, inject the intermediate expressions collected along the way.
- for intermediate_expression in &self.coverage_counters.intermediate_expressions {
- inject_intermediate_expression(
- self.mir_body,
- self.make_mir_coverage_kind(intermediate_expression),
- );
- }
+ // `BasicCoverageBlock`s not already associated with a coverage span.
+ let bcb_has_coverage_spans = |bcb| coverage_spans.bcb_has_coverage_spans(bcb);
+ self.coverage_counters
+ .make_bcb_counters(&self.basic_coverage_blocks, bcb_has_coverage_spans);
+
+ let mappings = self.create_mappings_and_inject_coverage_statements(&coverage_spans);
+
+ self.mir_body.function_coverage_info = Some(Box::new(FunctionCoverageInfo {
+ function_source_hash: self.function_source_hash,
+ num_counters: self.coverage_counters.num_counters(),
+ expressions: self.coverage_counters.take_expressions(),
+ mappings,
+ }));
}
- /// Inject a counter for each `CoverageSpan`. There can be multiple `CoverageSpan`s for a given
- /// BCB, but only one actual counter needs to be incremented per BCB. `bb_counters` maps each
- /// `bcb` to its `Counter`, when injected. Subsequent `CoverageSpan`s for a BCB that already has
- /// a `Counter` will inject an `Expression` instead, and compute its value by adding `ZERO` to
- /// the BCB `Counter` value.
- fn inject_coverage_span_counters(&mut self, coverage_spans: Vec<CoverageSpan>) {
- let tcx = self.tcx;
- let source_map = tcx.sess.source_map();
+ /// For each [`BcbCounter`] associated with a BCB node or BCB edge, create
+ /// any corresponding mappings (for BCB nodes only), and inject any necessary
+ /// coverage statements into MIR.
+ fn create_mappings_and_inject_coverage_statements(
+ &mut self,
+ coverage_spans: &CoverageSpans,
+ ) -> Vec<Mapping> {
+ let source_map = self.tcx.sess.source_map();
let body_span = self.body_span;
- let file_name = Symbol::intern(&self.source_file.name.prefer_remapped().to_string_lossy());
-
- let mut bcb_counters = IndexVec::from_elem_n(None, self.basic_coverage_blocks.num_nodes());
- for covspan in coverage_spans {
- let bcb = covspan.bcb;
- let span = covspan.span;
- let counter_kind = if let Some(&counter_operand) = bcb_counters[bcb].as_ref() {
- self.coverage_counters.make_identity_counter(counter_operand)
- } else if let Some(counter_kind) = self.coverage_counters.take_bcb_counter(bcb) {
- bcb_counters[bcb] = Some(counter_kind.as_operand());
- counter_kind
- } else {
- bug!("Every BasicCoverageBlock should have a Counter or Expression");
- };
-
- let code_region = make_code_region(source_map, file_name, span, body_span);
- inject_statement(
- self.mir_body,
- self.make_mir_coverage_kind(&counter_kind),
- self.bcb_leader_bb(bcb),
- Some(code_region),
- );
- }
- }
-
- /// `inject_coverage_span_counters()` looped through the `CoverageSpan`s and injected the
- /// counter from the `CoverageSpan`s `BasicCoverageBlock`, removing it from the BCB in the
- /// process (via `take_counter()`).
- ///
- /// Any other counter associated with a `BasicCoverageBlock`, or its incoming edge, but not
- /// associated with a `CoverageSpan`, should only exist if the counter is an `Expression`
- /// dependency (one of the expression operands). Collect them, and inject the additional
- /// counters into the MIR, without a reportable coverage span.
- fn inject_indirect_counters(&mut self) {
- let mut bcb_counters_without_direct_coverage_spans = Vec::new();
- for (target_bcb, counter_kind) in self.coverage_counters.drain_bcb_counters() {
- bcb_counters_without_direct_coverage_spans.push((None, target_bcb, counter_kind));
- }
- for ((from_bcb, target_bcb), counter_kind) in
- self.coverage_counters.drain_bcb_edge_counters()
- {
- bcb_counters_without_direct_coverage_spans.push((
- Some(from_bcb),
- target_bcb,
- counter_kind,
- ));
- }
+ use rustc_session::RemapFileNameExt;
+ let file_name =
+ Symbol::intern(&self.source_file.name.for_codegen(self.tcx.sess).to_string_lossy());
+
+ let mut mappings = Vec::new();
+
+ // Process the counters and spans associated with BCB nodes.
+ for (bcb, counter_kind) in self.coverage_counters.bcb_node_counters() {
+ let spans = coverage_spans.spans_for_bcb(bcb);
+ let has_mappings = !spans.is_empty();
+
+ // If this BCB has any coverage spans, add corresponding mappings to
+ // the mappings table.
+ if has_mappings {
+ let term = counter_kind.as_term();
+ mappings.extend(spans.iter().map(|&span| {
+ let code_region = make_code_region(source_map, file_name, span, body_span);
+ Mapping { code_region, term }
+ }));
+ }
- for (edge_from_bcb, target_bcb, counter_kind) in bcb_counters_without_direct_coverage_spans
- {
- match counter_kind {
- BcbCounter::Counter { .. } => {
- let inject_to_bb = if let Some(from_bcb) = edge_from_bcb {
- // The MIR edge starts `from_bb` (the outgoing / last BasicBlock in
- // `from_bcb`) and ends at `to_bb` (the incoming / first BasicBlock in the
- // `target_bcb`; also called the `leader_bb`).
- let from_bb = self.bcb_last_bb(from_bcb);
- let to_bb = self.bcb_leader_bb(target_bcb);
-
- let new_bb = inject_edge_counter_basic_block(self.mir_body, from_bb, to_bb);
- debug!(
- "Edge {:?} (last {:?}) -> {:?} (leader {:?}) requires a new MIR \
- BasicBlock {:?}, for unclaimed edge counter {:?}",
- edge_from_bcb, from_bb, target_bcb, to_bb, new_bb, counter_kind,
- );
- new_bb
- } else {
- let target_bb = self.bcb_last_bb(target_bcb);
- debug!(
- "{:?} ({:?}) gets a new Coverage statement for unclaimed counter {:?}",
- target_bcb, target_bb, counter_kind,
- );
- target_bb
- };
-
- inject_statement(
- self.mir_body,
- self.make_mir_coverage_kind(&counter_kind),
- inject_to_bb,
- None,
- );
- }
- BcbCounter::Expression { .. } => inject_intermediate_expression(
+ let do_inject = match counter_kind {
+ // Counter-increment statements always need to be injected.
+ BcbCounter::Counter { .. } => true,
+ // The only purpose of expression-used statements is to detect
+ // when a mapping is unreachable, so we only inject them for
+ // expressions with one or more mappings.
+ BcbCounter::Expression { .. } => has_mappings,
+ };
+ if do_inject {
+ inject_statement(
self.mir_body,
- self.make_mir_coverage_kind(&counter_kind),
- ),
+ self.make_mir_coverage_kind(counter_kind),
+ self.basic_coverage_blocks[bcb].leader_bb(),
+ );
}
}
- }
- #[inline]
- fn bcb_leader_bb(&self, bcb: BasicCoverageBlock) -> BasicBlock {
- self.bcb_data(bcb).leader_bb()
- }
+ // Process the counters associated with BCB edges.
+ for (from_bcb, to_bcb, counter_kind) in self.coverage_counters.bcb_edge_counters() {
+ let do_inject = match counter_kind {
+ // Counter-increment statements always need to be injected.
+ BcbCounter::Counter { .. } => true,
+ // BCB-edge expressions never have mappings, so they never need
+ // a corresponding statement.
+ BcbCounter::Expression { .. } => false,
+ };
+ if !do_inject {
+ continue;
+ }
- #[inline]
- fn bcb_last_bb(&self, bcb: BasicCoverageBlock) -> BasicBlock {
- self.bcb_data(bcb).last_bb()
- }
+ // We need to inject a coverage statement into a new BB between the
+ // last BB of `from_bcb` and the first BB of `to_bcb`.
+ let from_bb = self.basic_coverage_blocks[from_bcb].last_bb();
+ let to_bb = self.basic_coverage_blocks[to_bcb].leader_bb();
+
+ let new_bb = inject_edge_counter_basic_block(self.mir_body, from_bb, to_bb);
+ debug!(
+ "Edge {from_bcb:?} (last {from_bb:?}) -> {to_bcb:?} (leader {to_bb:?}) \
+ requires a new MIR BasicBlock {new_bb:?} for edge counter {counter_kind:?}",
+ );
+
+ // Inject a counter into the newly-created BB.
+ inject_statement(self.mir_body, self.make_mir_coverage_kind(&counter_kind), new_bb);
+ }
- #[inline]
- fn bcb_data(&self, bcb: BasicCoverageBlock) -> &BasicCoverageBlockData {
- &self.basic_coverage_blocks[bcb]
+ mappings
}
fn make_mir_coverage_kind(&self, counter_kind: &BcbCounter) -> CoverageKind {
match *counter_kind {
- BcbCounter::Counter { id } => {
- CoverageKind::Counter { function_source_hash: self.function_source_hash, id }
- }
- BcbCounter::Expression { id, lhs, op, rhs } => {
- CoverageKind::Expression { id, lhs, op, rhs }
- }
+ BcbCounter::Counter { id } => CoverageKind::CounterIncrement { id },
+ BcbCounter::Expression { id } => CoverageKind::ExpressionUsed { id },
}
}
}
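The two `do_inject` matches above (for node counters and for edge counters) reduce to a single rule; a rough restatement, using a hypothetical simplified `CounterKind` enum rather than the compiler's `BcbCounter`:

enum CounterKind {
    Increment,
    Expression,
}

/// Whether a coverage statement must be injected for this node/edge counter.
fn needs_coverage_statement(kind: CounterKind, is_edge: bool, has_mappings: bool) -> bool {
    match kind {
        // Physical counter increments always have to execute.
        CounterKind::Increment => true,
        // Expression-used statements exist only to detect unreachable mappings,
        // and edge expressions never have mappings of their own.
        CounterKind::Expression => !is_edge && has_mappings,
    }
}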
@@ -364,42 +277,17 @@ fn inject_edge_counter_basic_block(
new_bb
}
-fn inject_statement(
- mir_body: &mut mir::Body<'_>,
- counter_kind: CoverageKind,
- bb: BasicBlock,
- some_code_region: Option<CodeRegion>,
-) {
- debug!(
- " injecting statement {:?} for {:?} at code region: {:?}",
- counter_kind, bb, some_code_region
- );
+fn inject_statement(mir_body: &mut mir::Body<'_>, counter_kind: CoverageKind, bb: BasicBlock) {
+ debug!(" injecting statement {counter_kind:?} for {bb:?}");
let data = &mut mir_body[bb];
let source_info = data.terminator().source_info;
let statement = Statement {
source_info,
- kind: StatementKind::Coverage(Box::new(Coverage {
- kind: counter_kind,
- code_region: some_code_region,
- })),
+ kind: StatementKind::Coverage(Box::new(Coverage { kind: counter_kind })),
};
data.statements.insert(0, statement);
}
-// Non-code expressions are injected into the coverage map, without generating executable code.
-fn inject_intermediate_expression(mir_body: &mut mir::Body<'_>, expression: CoverageKind) {
- debug_assert!(matches!(expression, CoverageKind::Expression { .. }));
- debug!(" injecting non-code expression {:?}", expression);
- let inject_in_bb = mir::START_BLOCK;
- let data = &mut mir_body[inject_in_bb];
- let source_info = data.terminator().source_info;
- let statement = Statement {
- source_info,
- kind: StatementKind::Coverage(Box::new(Coverage { kind: expression, code_region: None })),
- };
- data.statements.push(statement);
-}
-
/// Convert the Span into its file name, start line and column, and end line and column
fn make_code_region(
source_map: &SourceMap,
diff --git a/compiler/rustc_mir_transform/src/coverage/query.rs b/compiler/rustc_mir_transform/src/coverage/query.rs
index 56365c5d4..809407f89 100644
--- a/compiler/rustc_mir_transform/src/coverage/query.rs
+++ b/compiler/rustc_mir_transform/src/coverage/query.rs
@@ -2,100 +2,31 @@ use super::*;
use rustc_data_structures::captures::Captures;
use rustc_middle::mir::coverage::*;
-use rustc_middle::mir::{self, Body, Coverage, CoverageInfo};
+use rustc_middle::mir::{Body, Coverage, CoverageIdsInfo};
use rustc_middle::query::Providers;
use rustc_middle::ty::{self, TyCtxt};
-use rustc_span::def_id::DefId;
/// A `query` provider for retrieving coverage information injected into MIR.
pub(crate) fn provide(providers: &mut Providers) {
- providers.coverageinfo = |tcx, def_id| coverageinfo(tcx, def_id);
- providers.covered_code_regions = |tcx, def_id| covered_code_regions(tcx, def_id);
+ providers.coverage_ids_info = |tcx, def_id| coverage_ids_info(tcx, def_id);
}
-/// Coverage codegen needs to know the total number of counter IDs and expression IDs that have
-/// been used by a function's coverage mappings. These totals are used to create vectors to hold
-/// the relevant counter and expression data, and the maximum counter ID (+ 1) is also needed by
-/// the `llvm.instrprof.increment` intrinsic.
-///
-/// MIR optimization may split and duplicate some BasicBlock sequences, or optimize out some code
-/// including injected counters. (It is OK if some counters are optimized out, but those counters
-/// are still included in the total `num_counters` or `num_expressions`.) Simply counting the
-/// calls may not work; but computing the number of counters or expressions by adding `1` to the
-/// highest ID (for a given instrumented function) is valid.
-///
-/// It's possible for a coverage expression to remain in MIR while one or both of its operands
-/// have been optimized away. To avoid problems in codegen, we include those operands' IDs when
-/// determining the maximum counter/expression ID, even if the underlying counter/expression is
-/// no longer present.
-struct CoverageVisitor {
- max_counter_id: CounterId,
- max_expression_id: ExpressionId,
-}
-
-impl CoverageVisitor {
- /// Updates `max_counter_id` to the maximum encountered counter ID.
- #[inline(always)]
- fn update_max_counter_id(&mut self, counter_id: CounterId) {
- self.max_counter_id = self.max_counter_id.max(counter_id);
- }
-
- /// Updates `max_expression_id` to the maximum encountered expression ID.
- #[inline(always)]
- fn update_max_expression_id(&mut self, expression_id: ExpressionId) {
- self.max_expression_id = self.max_expression_id.max(expression_id);
- }
-
- fn update_from_expression_operand(&mut self, operand: Operand) {
- match operand {
- Operand::Counter(id) => self.update_max_counter_id(id),
- Operand::Expression(id) => self.update_max_expression_id(id),
- Operand::Zero => {}
- }
- }
-
- fn visit_body(&mut self, body: &Body<'_>) {
- for coverage in all_coverage_in_mir_body(body) {
- self.visit_coverage(coverage);
- }
- }
-
- fn visit_coverage(&mut self, coverage: &Coverage) {
- match coverage.kind {
- CoverageKind::Counter { id, .. } => self.update_max_counter_id(id),
- CoverageKind::Expression { id, lhs, rhs, .. } => {
- self.update_max_expression_id(id);
- self.update_from_expression_operand(lhs);
- self.update_from_expression_operand(rhs);
- }
- CoverageKind::Unreachable => {}
- }
- }
-}
-
-fn coverageinfo<'tcx>(tcx: TyCtxt<'tcx>, instance_def: ty::InstanceDef<'tcx>) -> CoverageInfo {
+/// Query implementation for `coverage_ids_info`.
+fn coverage_ids_info<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ instance_def: ty::InstanceDef<'tcx>,
+) -> CoverageIdsInfo {
let mir_body = tcx.instance_mir(instance_def);
- let mut coverage_visitor = CoverageVisitor {
- max_counter_id: CounterId::START,
- max_expression_id: ExpressionId::START,
- };
-
- coverage_visitor.visit_body(mir_body);
-
- // Add 1 to the highest IDs to get the total number of IDs.
- CoverageInfo {
- num_counters: (coverage_visitor.max_counter_id + 1).as_u32(),
- num_expressions: (coverage_visitor.max_expression_id + 1).as_u32(),
- }
-}
+ let max_counter_id = all_coverage_in_mir_body(mir_body)
+ .filter_map(|coverage| match coverage.kind {
+ CoverageKind::CounterIncrement { id } => Some(id),
+ _ => None,
+ })
+ .max()
+ .unwrap_or(CounterId::START);
-fn covered_code_regions(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<&CodeRegion> {
- let body = mir_body(tcx, def_id);
- all_coverage_in_mir_body(body)
- // Not all coverage statements have an attached code region.
- .filter_map(|coverage| coverage.code_region.as_ref())
- .collect()
+ CoverageIdsInfo { max_counter_id }
}
fn all_coverage_in_mir_body<'a, 'tcx>(
@@ -115,11 +46,3 @@ fn is_inlined(body: &Body<'_>, statement: &Statement<'_>) -> bool {
let scope_data = &body.source_scopes[statement.source_info.scope];
scope_data.inlined.is_some() || scope_data.inlined_parent_scope.is_some()
}
-
-/// This function ensures we obtain the correct MIR for the given item irrespective of
-/// whether that means const mir or runtime mir. For `const fn` this opts for runtime
-/// mir.
-fn mir_body(tcx: TyCtxt<'_>, def_id: DefId) -> &mir::Body<'_> {
- let def = ty::InstanceDef::Item(def_id);
- tcx.instance_mir(def)
-}
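The arithmetic behind the deleted doc comment still applies to the new query; a tiny sketch, assuming a plain `u32` ID in place of the compiler's `CounterId` newtype:

/// Codegen needs to know how many counter slots to allocate; this is taken to
/// be the highest counter ID still present in the MIR, plus one (IDs start at 0).
fn num_counters_needed(surviving_ids: impl Iterator<Item = u32>) -> u32 {
    surviving_ids.max().unwrap_or(0) + 1
}

// e.g. num_counters_needed([0, 3, 1].into_iter()) == 4, even if the increment
// for counter 2 was optimized out of the MIR.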
diff --git a/compiler/rustc_mir_transform/src/coverage/spans.rs b/compiler/rustc_mir_transform/src/coverage/spans.rs
index ed0e104d6..b318134ae 100644
--- a/compiler/rustc_mir_transform/src/coverage/spans.rs
+++ b/compiler/rustc_mir_transform/src/coverage/spans.rs
@@ -1,26 +1,48 @@
-use super::graph::{BasicCoverageBlock, BasicCoverageBlockData, CoverageGraph, START_BCB};
+use std::cell::OnceCell;
use rustc_data_structures::graph::WithNumNodes;
-use rustc_middle::mir::{
- self, AggregateKind, BasicBlock, FakeReadCause, Rvalue, Statement, StatementKind, Terminator,
- TerminatorKind,
-};
-use rustc_span::source_map::original_sp;
-use rustc_span::{BytePos, ExpnKind, MacroKind, Span, Symbol};
+use rustc_index::IndexVec;
+use rustc_middle::mir;
+use rustc_span::{BytePos, ExpnKind, MacroKind, Span, Symbol, DUMMY_SP};
-use std::cell::OnceCell;
+use super::graph::{BasicCoverageBlock, CoverageGraph, START_BCB};
+
+mod from_mir;
-#[derive(Debug, Copy, Clone)]
-pub(super) enum CoverageStatement {
- Statement(BasicBlock, Span, usize),
- Terminator(BasicBlock, Span),
+pub(super) struct CoverageSpans {
+ /// Map from BCBs to their list of coverage spans.
+ bcb_to_spans: IndexVec<BasicCoverageBlock, Vec<Span>>,
}
-impl CoverageStatement {
- pub fn span(&self) -> Span {
- match self {
- Self::Statement(_, span, _) | Self::Terminator(_, span) => *span,
+impl CoverageSpans {
+ pub(super) fn generate_coverage_spans(
+ mir_body: &mir::Body<'_>,
+ fn_sig_span: Span,
+ body_span: Span,
+ basic_coverage_blocks: &CoverageGraph,
+ ) -> Self {
+ let coverage_spans = CoverageSpansGenerator::generate_coverage_spans(
+ mir_body,
+ fn_sig_span,
+ body_span,
+ basic_coverage_blocks,
+ );
+
+ // Group the coverage spans by BCB, with the BCBs in sorted order.
+ let mut bcb_to_spans = IndexVec::from_elem_n(Vec::new(), basic_coverage_blocks.num_nodes());
+ for CoverageSpan { bcb, span, .. } in coverage_spans {
+ bcb_to_spans[bcb].push(span);
}
+
+ Self { bcb_to_spans }
+ }
+
+ pub(super) fn bcb_has_coverage_spans(&self, bcb: BasicCoverageBlock) -> bool {
+ !self.bcb_to_spans[bcb].is_empty()
+ }
+
+ pub(super) fn spans_for_bcb(&self, bcb: BasicCoverageBlock) -> &[Span] {
+ &self.bcb_to_spans[bcb]
}
}
@@ -28,87 +50,55 @@ impl CoverageStatement {
/// references the originating BCB and one or more MIR `Statement`s and/or `Terminator`s.
/// Initially, the `Span`s come from the `Statement`s and `Terminator`s, but subsequent
/// transforms can combine adjacent `Span`s and `CoverageSpan` from the same BCB, merging the
-/// `CoverageStatement` vectors, and the `Span`s to cover the extent of the combined `Span`s.
+/// `merged_spans` vectors, and the `Span`s to cover the extent of the combined `Span`s.
///
-/// Note: A `CoverageStatement` merged into another CoverageSpan may come from a `BasicBlock` that
+/// Note: A span merged into another CoverageSpan may come from a `BasicBlock` that
/// is not part of the `CoverageSpan` bcb if the statement was included because its `Span` matches
/// or is subsumed by the `Span` associated with this `CoverageSpan`, and its `BasicBlock`
/// `dominates()` the `BasicBlock`s in this `CoverageSpan`.
#[derive(Debug, Clone)]
-pub(super) struct CoverageSpan {
+struct CoverageSpan {
pub span: Span,
pub expn_span: Span,
pub current_macro_or_none: OnceCell<Option<Symbol>>,
pub bcb: BasicCoverageBlock,
- pub coverage_statements: Vec<CoverageStatement>,
+ /// List of all the original spans from MIR that have been merged into this
+ /// span. Mainly used to precisely skip over gaps when truncating a span.
+ pub merged_spans: Vec<Span>,
pub is_closure: bool,
}
impl CoverageSpan {
pub fn for_fn_sig(fn_sig_span: Span) -> Self {
- Self {
- span: fn_sig_span,
- expn_span: fn_sig_span,
- current_macro_or_none: Default::default(),
- bcb: START_BCB,
- coverage_statements: vec![],
- is_closure: false,
- }
+ Self::new(fn_sig_span, fn_sig_span, START_BCB, false)
}
- pub fn for_statement(
- statement: &Statement<'_>,
+ pub(super) fn new(
span: Span,
expn_span: Span,
bcb: BasicCoverageBlock,
- bb: BasicBlock,
- stmt_index: usize,
+ is_closure: bool,
) -> Self {
- let is_closure = match statement.kind {
- StatementKind::Assign(box (_, Rvalue::Aggregate(box ref kind, _))) => {
- matches!(kind, AggregateKind::Closure(_, _) | AggregateKind::Generator(_, _, _))
- }
- _ => false,
- };
-
Self {
span,
expn_span,
current_macro_or_none: Default::default(),
bcb,
- coverage_statements: vec![CoverageStatement::Statement(bb, span, stmt_index)],
+ merged_spans: vec![span],
is_closure,
}
}
- pub fn for_terminator(
- span: Span,
- expn_span: Span,
- bcb: BasicCoverageBlock,
- bb: BasicBlock,
- ) -> Self {
- Self {
- span,
- expn_span,
- current_macro_or_none: Default::default(),
- bcb,
- coverage_statements: vec![CoverageStatement::Terminator(bb, span)],
- is_closure: false,
- }
- }
-
pub fn merge_from(&mut self, mut other: CoverageSpan) {
debug_assert!(self.is_mergeable(&other));
self.span = self.span.to(other.span);
- self.coverage_statements.append(&mut other.coverage_statements);
+ self.merged_spans.append(&mut other.merged_spans);
}
pub fn cutoff_statements_at(&mut self, cutoff_pos: BytePos) {
- self.coverage_statements.retain(|covstmt| covstmt.span().hi() <= cutoff_pos);
- if let Some(highest_covstmt) =
- self.coverage_statements.iter().max_by_key(|covstmt| covstmt.span().hi())
- {
- self.span = self.span.with_hi(highest_covstmt.span().hi());
+ self.merged_spans.retain(|span| span.hi() <= cutoff_pos);
+ if let Some(max_hi) = self.merged_spans.iter().map(|span| span.hi()).max() {
+ self.span = self.span.with_hi(max_hi);
}
}
@@ -139,11 +129,12 @@ impl CoverageSpan {
/// If the span is part of a macro, and the macro is visible (expands directly to the given
/// body_span), returns the macro name symbol.
pub fn visible_macro(&self, body_span: Span) -> Option<Symbol> {
- if let Some(current_macro) = self.current_macro() && self
- .expn_span
- .parent_callsite()
- .unwrap_or_else(|| bug!("macro must have a parent"))
- .eq_ctxt(body_span)
+ if let Some(current_macro) = self.current_macro()
+ && self
+ .expn_span
+ .parent_callsite()
+ .unwrap_or_else(|| bug!("macro must have a parent"))
+ .eq_ctxt(body_span)
{
return Some(current_macro);
}
@@ -162,13 +153,7 @@ impl CoverageSpan {
/// * Merge spans that represent continuous (both in source code and control flow), non-branching
/// execution
/// * Carve out (leave uncovered) any span that will be counted by another MIR (notably, closures)
-pub struct CoverageSpans<'a, 'tcx> {
- /// The MIR, used to look up `BasicBlockData`.
- mir_body: &'a mir::Body<'tcx>,
-
- /// A `Span` covering the signature of function for the MIR.
- fn_sig_span: Span,
-
+struct CoverageSpansGenerator<'a> {
/// A `Span` covering the function body of the MIR (typically from left curly brace to right
/// curly brace).
body_span: Span,
@@ -178,7 +163,7 @@ pub struct CoverageSpans<'a, 'tcx> {
/// The initial set of `CoverageSpan`s, sorted by `Span` (`lo` and `hi`) and by relative
/// dominance between the `BasicCoverageBlock`s of equal `Span`s.
- sorted_spans_iter: Option<std::vec::IntoIter<CoverageSpan>>,
+ sorted_spans_iter: std::vec::IntoIter<CoverageSpan>,
/// The current `CoverageSpan` to compare to its `prev`, to possibly merge, discard, force the
/// discard of the `prev` (and or `pending_dups`), or keep both (with `prev` moved to
@@ -200,9 +185,6 @@ pub struct CoverageSpans<'a, 'tcx> {
/// is mutated.
prev_original_span: Span,
- /// A copy of the expn_span from the prior iteration.
- prev_expn_span: Option<Span>,
-
/// One or more `CoverageSpan`s with the same `Span` but different `BasicCoverageBlock`s, and
/// no `BasicCoverageBlock` in this list dominates another `BasicCoverageBlock` in the list.
/// If a new `curr` span also fits this criteria (compared to an existing list of
@@ -218,7 +200,7 @@ pub struct CoverageSpans<'a, 'tcx> {
refined_spans: Vec<CoverageSpan>,
}
-impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
+impl<'a> CoverageSpansGenerator<'a> {
/// Generate a minimal set of `CoverageSpan`s, each representing a contiguous code region to be
/// counted.
///
@@ -241,109 +223,79 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
/// Note the resulting vector of `CoverageSpan`s may not be fully sorted (and does not need
/// to be).
pub(super) fn generate_coverage_spans(
- mir_body: &'a mir::Body<'tcx>,
+ mir_body: &mir::Body<'_>,
fn_sig_span: Span, // Ensured to be same SourceFile and SyntaxContext as `body_span`
body_span: Span,
basic_coverage_blocks: &'a CoverageGraph,
) -> Vec<CoverageSpan> {
- let mut coverage_spans = CoverageSpans {
+ let sorted_spans = from_mir::mir_to_initial_sorted_coverage_spans(
mir_body,
fn_sig_span,
body_span,
basic_coverage_blocks,
- sorted_spans_iter: None,
- refined_spans: Vec::with_capacity(basic_coverage_blocks.num_nodes() * 2),
+ );
+
+ let coverage_spans = Self {
+ body_span,
+ basic_coverage_blocks,
+ sorted_spans_iter: sorted_spans.into_iter(),
some_curr: None,
- curr_original_span: Span::with_root_ctxt(BytePos(0), BytePos(0)),
+ curr_original_span: DUMMY_SP,
some_prev: None,
- prev_original_span: Span::with_root_ctxt(BytePos(0), BytePos(0)),
- prev_expn_span: None,
+ prev_original_span: DUMMY_SP,
pending_dups: Vec::new(),
+ refined_spans: Vec::with_capacity(basic_coverage_blocks.num_nodes() * 2),
};
- let sorted_spans = coverage_spans.mir_to_initial_sorted_coverage_spans();
-
- coverage_spans.sorted_spans_iter = Some(sorted_spans.into_iter());
-
coverage_spans.to_refined_spans()
}
- fn mir_to_initial_sorted_coverage_spans(&self) -> Vec<CoverageSpan> {
- let mut initial_spans =
- Vec::<CoverageSpan>::with_capacity(self.mir_body.basic_blocks.len() * 2);
- for (bcb, bcb_data) in self.basic_coverage_blocks.iter_enumerated() {
- initial_spans.extend(self.bcb_to_initial_coverage_spans(bcb, bcb_data));
- }
-
- if initial_spans.is_empty() {
- // This can happen if, for example, the function is unreachable (contains only a
- // `BasicBlock`(s) with an `Unreachable` terminator).
- return initial_spans;
- }
-
- initial_spans.push(CoverageSpan::for_fn_sig(self.fn_sig_span));
-
- initial_spans.sort_by(|a, b| {
- // First sort by span start.
- Ord::cmp(&a.span.lo(), &b.span.lo())
- // If span starts are the same, sort by span end in reverse order.
- // This ensures that if spans A and B are adjacent in the list,
- // and they overlap but are not equal, then either:
- // - Span A extends further left, or
- // - Both have the same start and span A extends further right
- .then_with(|| Ord::cmp(&a.span.hi(), &b.span.hi()).reverse())
- // If both spans are equal, sort the BCBs in dominator order,
- // so that dominating BCBs come before other BCBs they dominate.
- .then_with(|| self.basic_coverage_blocks.cmp_in_dominator_order(a.bcb, b.bcb))
- // If two spans are otherwise identical, put closure spans first,
- // as this seems to be what the refinement step expects.
- .then_with(|| Ord::cmp(&a.is_closure, &b.is_closure).reverse())
- });
-
- initial_spans
- }
-
/// Iterate through the sorted `CoverageSpan`s, and return the refined list of merged and
/// de-duplicated `CoverageSpan`s.
fn to_refined_spans(mut self) -> Vec<CoverageSpan> {
while self.next_coverage_span() {
+ // For the first span we don't have `prev` set, so most of the
+ // span-processing steps don't make sense yet.
if self.some_prev.is_none() {
debug!(" initial span");
- self.check_invoked_macro_name_span();
- } else if self.curr().is_mergeable(self.prev()) {
- debug!(" same bcb (and neither is a closure), merge with prev={:?}", self.prev());
+ self.maybe_push_macro_name_span();
+ continue;
+ }
+
+ // The remaining cases assume that `prev` and `curr` are set.
+ let prev = self.prev();
+ let curr = self.curr();
+
+ if curr.is_mergeable(prev) {
+ debug!(" same bcb (and neither is a closure), merge with prev={prev:?}");
let prev = self.take_prev();
self.curr_mut().merge_from(prev);
- self.check_invoked_macro_name_span();
+ self.maybe_push_macro_name_span();
// Note that curr.span may now differ from curr_original_span
- } else if self.prev_ends_before_curr() {
+ } else if prev.span.hi() <= curr.span.lo() {
debug!(
- " different bcbs and disjoint spans, so keep curr for next iter, and add \
- prev={:?}",
- self.prev()
+ " different bcbs and disjoint spans, so keep curr for next iter, and add prev={prev:?}",
);
let prev = self.take_prev();
self.push_refined_span(prev);
- self.check_invoked_macro_name_span();
- } else if self.prev().is_closure {
+ self.maybe_push_macro_name_span();
+ } else if prev.is_closure {
// drop any equal or overlapping span (`curr`) and keep `prev` to test again in the
// next iter
debug!(
- " curr overlaps a closure (prev). Drop curr and keep prev for next iter. \
- prev={:?}",
- self.prev()
+ " curr overlaps a closure (prev). Drop curr and keep prev for next iter. prev={prev:?}",
);
- self.take_curr();
- } else if self.curr().is_closure {
+ self.take_curr(); // Discards curr.
+ } else if curr.is_closure {
self.carve_out_span_for_closure();
- } else if self.prev_original_span == self.curr().span {
+ } else if self.prev_original_span == curr.span {
// Note that this compares the new (`curr`) span to `prev_original_span`.
// In this branch, the actual span byte range of `prev_original_span` is not
// important. What is important is knowing whether the new `curr` span was
// **originally** the same as the original span of `prev()`. The original spans
// reflect their original sort order, and for equal spans, convey a partial
// ordering based on CFG dominator priority.
- if self.prev().is_macro_expansion() && self.curr().is_macro_expansion() {
+ if prev.is_macro_expansion() && curr.is_macro_expansion() {
// Macros that expand to include branching (such as
// `assert_eq!()`, `assert_ne!()`, `info!()`, `debug!()`, or
// `trace!()`) typically generate callee spans with identical
@@ -357,23 +309,24 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
debug!(
" curr and prev are part of a macro expansion, and curr has the same span \
as prev, but is in a different bcb. Drop curr and keep prev for next iter. \
- prev={:?}",
- self.prev()
+ prev={prev:?}",
);
- self.take_curr();
+ self.take_curr(); // Discards curr.
} else {
- self.hold_pending_dups_unless_dominated();
+ self.update_pending_dups();
}
} else {
self.cutoff_prev_at_overlapping_curr();
- self.check_invoked_macro_name_span();
+ self.maybe_push_macro_name_span();
}
}
- debug!(" AT END, adding last prev={:?}", self.prev());
let prev = self.take_prev();
- let pending_dups = self.pending_dups.split_off(0);
- for dup in pending_dups {
+ debug!(" AT END, adding last prev={prev:?}");
+
+ // Take `pending_dups` so that we can drain it while calling self methods.
+ // It is never used as a field after this point.
+ for dup in std::mem::take(&mut self.pending_dups) {
debug!(" ...adding at least one pending dup={:?}", dup);
self.push_refined_span(dup);
}
@@ -403,91 +356,46 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
}
fn push_refined_span(&mut self, covspan: CoverageSpan) {
- let len = self.refined_spans.len();
- if len > 0 {
- let last = &mut self.refined_spans[len - 1];
- if last.is_mergeable(&covspan) {
- debug!(
- "merging new refined span with last refined span, last={:?}, covspan={:?}",
- last, covspan
- );
- last.merge_from(covspan);
- return;
- }
+ if let Some(last) = self.refined_spans.last_mut()
+ && last.is_mergeable(&covspan)
+ {
+ // Instead of pushing the new span, merge it with the last refined span.
+ debug!(?last, ?covspan, "merging new refined span with last refined span");
+ last.merge_from(covspan);
+ } else {
+ self.refined_spans.push(covspan);
}
- self.refined_spans.push(covspan)
}
- fn check_invoked_macro_name_span(&mut self) {
- if let Some(visible_macro) = self.curr().visible_macro(self.body_span) {
- if !self
- .prev_expn_span
- .is_some_and(|prev_expn_span| self.curr().expn_span.ctxt() == prev_expn_span.ctxt())
- {
- let merged_prefix_len = self.curr_original_span.lo() - self.curr().span.lo();
- let after_macro_bang =
- merged_prefix_len + BytePos(visible_macro.as_str().len() as u32 + 1);
- if self.curr().span.lo() + after_macro_bang > self.curr().span.hi() {
- // Something is wrong with the macro name span;
- // return now to avoid emitting malformed mappings.
- // FIXME(#117788): Track down why this happens.
- return;
- }
- let mut macro_name_cov = self.curr().clone();
- self.curr_mut().span =
- self.curr().span.with_lo(self.curr().span.lo() + after_macro_bang);
- macro_name_cov.span =
- macro_name_cov.span.with_hi(macro_name_cov.span.lo() + after_macro_bang);
- debug!(
- " and curr starts a new macro expansion, so add a new span just for \
- the macro `{}!`, new span={:?}",
- visible_macro, macro_name_cov
- );
- self.push_refined_span(macro_name_cov);
- }
+ /// If `curr` is part of a new macro expansion, carve out and push a separate
+ /// span that ends just after the macro name and its subsequent `!`.
+ fn maybe_push_macro_name_span(&mut self) {
+ let curr = self.curr();
+
+ let Some(visible_macro) = curr.visible_macro(self.body_span) else { return };
+ if let Some(prev) = &self.some_prev
+ && prev.expn_span.eq_ctxt(curr.expn_span)
+ {
+ return;
}
- }
- // Generate a set of `CoverageSpan`s from the filtered set of `Statement`s and `Terminator`s of
- // the `BasicBlock`(s) in the given `BasicCoverageBlockData`. One `CoverageSpan` is generated
- // for each `Statement` and `Terminator`. (Note that subsequent stages of coverage analysis will
- // merge some `CoverageSpan`s, at which point a `CoverageSpan` may represent multiple
- // `Statement`s and/or `Terminator`s.)
- fn bcb_to_initial_coverage_spans(
- &self,
- bcb: BasicCoverageBlock,
- bcb_data: &'a BasicCoverageBlockData,
- ) -> Vec<CoverageSpan> {
- bcb_data
- .basic_blocks
- .iter()
- .flat_map(|&bb| {
- let data = &self.mir_body[bb];
- data.statements
- .iter()
- .enumerate()
- .filter_map(move |(index, statement)| {
- filtered_statement_span(statement).map(|span| {
- CoverageSpan::for_statement(
- statement,
- function_source_span(span, self.body_span),
- span,
- bcb,
- bb,
- index,
- )
- })
- })
- .chain(filtered_terminator_span(data.terminator()).map(|span| {
- CoverageSpan::for_terminator(
- function_source_span(span, self.body_span),
- span,
- bcb,
- bb,
- )
- }))
- })
- .collect()
+ let merged_prefix_len = self.curr_original_span.lo() - curr.span.lo();
+ let after_macro_bang = merged_prefix_len + BytePos(visible_macro.as_str().len() as u32 + 1);
+ if self.curr().span.lo() + after_macro_bang > self.curr().span.hi() {
+ // Something is wrong with the macro name span;
+ // return now to avoid emitting malformed mappings.
+ // FIXME(#117788): Track down why this happens.
+ return;
+ }
+ let mut macro_name_cov = curr.clone();
+ self.curr_mut().span = curr.span.with_lo(curr.span.lo() + after_macro_bang);
+ macro_name_cov.span =
+ macro_name_cov.span.with_hi(macro_name_cov.span.lo() + after_macro_bang);
+ debug!(
+ " and curr starts a new macro expansion, so add a new span just for \
+ the macro `{visible_macro}!`, new span={macro_name_cov:?}",
+ );
+ self.push_refined_span(macro_name_cov);
}
fn curr(&self) -> &CoverageSpan {
@@ -502,6 +410,12 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
.unwrap_or_else(|| bug!("invalid attempt to unwrap a None some_curr"))
}
+ /// If called, then the next call to `next_coverage_span()` will *not* update `prev` with the
+ /// `curr` coverage span.
+ fn take_curr(&mut self) -> CoverageSpan {
+ self.some_curr.take().unwrap_or_else(|| bug!("invalid attempt to unwrap a None some_curr"))
+ }
+
fn prev(&self) -> &CoverageSpan {
self.some_prev
.as_ref()
@@ -527,82 +441,78 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
/// `pending_dups` could have as few as one span)
/// In either case, no more spans will match the span of `pending_dups`, so
/// add the `pending_dups` if they don't overlap `curr`, and clear the list.
- fn check_pending_dups(&mut self) {
- if let Some(dup) = self.pending_dups.last() && dup.span != self.prev().span {
- debug!(
- " SAME spans, but pending_dups are NOT THE SAME, so BCBs matched on \
- previous iteration, or prev started a new disjoint span"
- );
- if dup.span.hi() <= self.curr().span.lo() {
- let pending_dups = self.pending_dups.split_off(0);
- for dup in pending_dups.into_iter() {
- debug!(" ...adding at least one pending={:?}", dup);
- self.push_refined_span(dup);
- }
- } else {
- self.pending_dups.clear();
+ fn maybe_flush_pending_dups(&mut self) {
+ let Some(last_dup) = self.pending_dups.last() else { return };
+ if last_dup.span == self.prev().span {
+ return;
+ }
+
+ debug!(
+ " SAME spans, but pending_dups are NOT THE SAME, so BCBs matched on \
+ previous iteration, or prev started a new disjoint span"
+ );
+ if last_dup.span.hi() <= self.curr().span.lo() {
+ // Temporarily steal `pending_dups` into a local, so that we can
+ // drain it while calling other self methods.
+ let mut pending_dups = std::mem::take(&mut self.pending_dups);
+ for dup in pending_dups.drain(..) {
+ debug!(" ...adding at least one pending={:?}", dup);
+ self.push_refined_span(dup);
}
+ // The list of dups is now empty, but we can recycle its capacity.
+ assert!(pending_dups.is_empty() && self.pending_dups.is_empty());
+ self.pending_dups = pending_dups;
+ } else {
+ self.pending_dups.clear();
}
}
/// Advance `prev` to `curr` (if any), and `curr` to the next `CoverageSpan` in sorted order.
fn next_coverage_span(&mut self) -> bool {
if let Some(curr) = self.some_curr.take() {
- self.prev_expn_span = Some(curr.expn_span);
self.some_prev = Some(curr);
self.prev_original_span = self.curr_original_span;
}
- while let Some(curr) = self.sorted_spans_iter.as_mut().unwrap().next() {
+ while let Some(curr) = self.sorted_spans_iter.next() {
debug!("FOR curr={:?}", curr);
- if self.some_prev.is_some() && self.prev_starts_after_next(&curr) {
+ if let Some(prev) = &self.some_prev && prev.span.lo() > curr.span.lo() {
+ // Skip curr because prev has already advanced beyond the end of curr.
+ // This can only happen if a prior iteration updated `prev` to skip past
+ // a region of code, such as skipping past a closure.
debug!(
" prev.span starts after curr.span, so curr will be dropped (skipping past \
- closure?); prev={:?}",
- self.prev()
+ closure?); prev={prev:?}",
);
} else {
// Save a copy of the original span for `curr` in case the `CoverageSpan` is changed
// by `self.curr_mut().merge_from(prev)`.
self.curr_original_span = curr.span;
self.some_curr.replace(curr);
- self.check_pending_dups();
+ self.maybe_flush_pending_dups();
return true;
}
}
false
}
- /// If called, then the next call to `next_coverage_span()` will *not* update `prev` with the
- /// `curr` coverage span.
- fn take_curr(&mut self) -> CoverageSpan {
- self.some_curr.take().unwrap_or_else(|| bug!("invalid attempt to unwrap a None some_curr"))
- }
-
- /// Returns true if the curr span should be skipped because prev has already advanced beyond the
- /// end of curr. This can only happen if a prior iteration updated `prev` to skip past a region
- /// of code, such as skipping past a closure.
- fn prev_starts_after_next(&self, next_curr: &CoverageSpan) -> bool {
- self.prev().span.lo() > next_curr.span.lo()
- }
-
- /// Returns true if the curr span starts past the end of the prev span, which means they don't
- /// overlap, so we now know the prev can be added to the refined coverage spans.
- fn prev_ends_before_curr(&self) -> bool {
- self.prev().span.hi() <= self.curr().span.lo()
- }
-
/// If `prev`s span extends left of the closure (`curr`), carve out the closure's span from
/// `prev`'s span. (The closure's coverage counters will be injected when processing the
/// closure's own MIR.) Add the portion of the span to the left of the closure; and if the span
/// extends to the right of the closure, update `prev` to that portion of the span. For any
/// `pending_dups`, repeat the same process.
fn carve_out_span_for_closure(&mut self) {
- let curr_span = self.curr().span;
- let left_cutoff = curr_span.lo();
- let right_cutoff = curr_span.hi();
- let has_pre_closure_span = self.prev().span.lo() < right_cutoff;
- let has_post_closure_span = self.prev().span.hi() > right_cutoff;
- let mut pending_dups = self.pending_dups.split_off(0);
+ let prev = self.prev();
+ let curr = self.curr();
+
+ let left_cutoff = curr.span.lo();
+ let right_cutoff = curr.span.hi();
+ let has_pre_closure_span = prev.span.lo() < right_cutoff;
+ let has_post_closure_span = prev.span.hi() > right_cutoff;
+
+ // Temporarily steal `pending_dups` into a local, so that we can
+ // mutate and/or drain it while calling other self methods.
+ let mut pending_dups = std::mem::take(&mut self.pending_dups);
+
if has_pre_closure_span {
let mut pre_closure = self.prev().clone();
pre_closure.span = pre_closure.span.with_hi(left_cutoff);
@@ -616,6 +526,7 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
}
self.push_refined_span(pre_closure);
}
+
if has_post_closure_span {
// Mutate `prev.span()` to start after the closure (and discard curr).
// (**NEVER** update `prev_original_span` because it affects the assumptions
@@ -626,12 +537,15 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
debug!(" ...and at least one overlapping dup={:?}", dup);
dup.span = dup.span.with_lo(right_cutoff);
}
- self.pending_dups.append(&mut pending_dups);
- let closure_covspan = self.take_curr();
+ let closure_covspan = self.take_curr(); // Prevent this curr from becoming prev.
self.push_refined_span(closure_covspan); // since self.prev() was already updated
} else {
pending_dups.clear();
}
+
+ // Restore the modified post-closure spans, or the empty vector's capacity.
+ assert!(self.pending_dups.is_empty());
+ self.pending_dups = pending_dups;
}
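A toy version of the carve-out described above, on plain byte offsets rather than `Span`s; the tuple-based `carve_out` helper is purely illustrative and ignores the `pending_dups` handling:

/// Returns the piece of `prev` to the left of the closure (if any) and the
/// remaining piece to the right of it that `prev` should shrink to (if any).
fn carve_out(prev: (u32, u32), closure: (u32, u32)) -> (Option<(u32, u32)>, Option<(u32, u32)>) {
    let (left_cutoff, right_cutoff) = closure;
    let pre_closure = (prev.0 < left_cutoff).then_some((prev.0, left_cutoff));
    let post_closure = (prev.1 > right_cutoff).then_some((right_cutoff, prev.1));
    (pre_closure, post_closure)
}

// e.g. carve_out((10, 50), (20, 30)) == (Some((10, 20)), Some((30, 50))):
// bytes 20..30 are left uncovered here because the closure's own MIR covers them.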
/// Called if `curr.span` equals `prev_original_span` (and potentially equal to all
@@ -648,26 +562,28 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
/// neither `CoverageSpan` dominates the other, both (or possibly more than two) are held,
/// until their disposition is determined. In this latter case, the `prev` dup is moved into
/// `pending_dups` so the new `curr` dup can be moved to `prev` for the next iteration.
- fn hold_pending_dups_unless_dominated(&mut self) {
+ fn update_pending_dups(&mut self) {
+ let prev_bcb = self.prev().bcb;
+ let curr_bcb = self.curr().bcb;
+
// Equal coverage spans are ordered by dominators before dominated (if any), so it should be
// impossible for `curr` to dominate any previous `CoverageSpan`.
- debug_assert!(!self.span_bcb_dominates(self.curr(), self.prev()));
+ debug_assert!(!self.basic_coverage_blocks.dominates(curr_bcb, prev_bcb));
let initial_pending_count = self.pending_dups.len();
if initial_pending_count > 0 {
- let mut pending_dups = self.pending_dups.split_off(0);
- pending_dups.retain(|dup| !self.span_bcb_dominates(dup, self.curr()));
- self.pending_dups.append(&mut pending_dups);
- if self.pending_dups.len() < initial_pending_count {
+ self.pending_dups
+ .retain(|dup| !self.basic_coverage_blocks.dominates(dup.bcb, curr_bcb));
+
+ let n_discarded = initial_pending_count - self.pending_dups.len();
+ if n_discarded > 0 {
debug!(
- " discarded {} of {} pending_dups that dominated curr",
- initial_pending_count - self.pending_dups.len(),
- initial_pending_count
+ " discarded {n_discarded} of {initial_pending_count} pending_dups that dominated curr",
);
}
}
- if self.span_bcb_dominates(self.prev(), self.curr()) {
+ if self.basic_coverage_blocks.dominates(prev_bcb, curr_bcb) {
debug!(
" different bcbs but SAME spans, and prev dominates curr. Discard prev={:?}",
self.prev()
@@ -720,7 +636,7 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
if self.pending_dups.is_empty() {
let curr_span = self.curr().span;
self.prev_mut().cutoff_statements_at(curr_span.lo());
- if self.prev().coverage_statements.is_empty() {
+ if self.prev().merged_spans.is_empty() {
debug!(" ... no non-overlapping statements to add");
} else {
debug!(" ... adding modified prev={:?}", self.prev());
@@ -732,109 +648,4 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
self.pending_dups.clear();
}
}
-
- fn span_bcb_dominates(&self, dom_covspan: &CoverageSpan, covspan: &CoverageSpan) -> bool {
- self.basic_coverage_blocks.dominates(dom_covspan.bcb, covspan.bcb)
- }
-}
-
-/// If the MIR `Statement` has a span contributive to computing coverage spans,
-/// return it; otherwise return `None`.
-pub(super) fn filtered_statement_span(statement: &Statement<'_>) -> Option<Span> {
- match statement.kind {
- // These statements have spans that are often outside the scope of the executed source code
- // for their parent `BasicBlock`.
- StatementKind::StorageLive(_)
- | StatementKind::StorageDead(_)
- // Coverage should not be encountered, but don't inject coverage coverage
- | StatementKind::Coverage(_)
- // Ignore `ConstEvalCounter`s
- | StatementKind::ConstEvalCounter
- // Ignore `Nop`s
- | StatementKind::Nop => None,
-
- // FIXME(#78546): MIR InstrumentCoverage - Can the source_info.span for `FakeRead`
- // statements be more consistent?
- //
- // FakeReadCause::ForGuardBinding, in this example:
- // match somenum {
- // x if x < 1 => { ... }
- // }...
- // The BasicBlock within the match arm code included one of these statements, but the span
- // for it covered the `1` in this source. The actual statements have nothing to do with that
- // source span:
- // FakeRead(ForGuardBinding, _4);
- // where `_4` is:
- // _4 = &_1; (at the span for the first `x`)
- // and `_1` is the `Place` for `somenum`.
- //
- // If and when the Issue is resolved, remove this special case match pattern:
- StatementKind::FakeRead(box (FakeReadCause::ForGuardBinding, _)) => None,
-
- // Retain spans from all other statements
- StatementKind::FakeRead(box (_, _)) // Not including `ForGuardBinding`
- | StatementKind::Intrinsic(..)
- | StatementKind::Assign(_)
- | StatementKind::SetDiscriminant { .. }
- | StatementKind::Deinit(..)
- | StatementKind::Retag(_, _)
- | StatementKind::PlaceMention(..)
- | StatementKind::AscribeUserType(_, _) => {
- Some(statement.source_info.span)
- }
- }
-}
-
-/// If the MIR `Terminator` has a span contributive to computing coverage spans,
-/// return it; otherwise return `None`.
-pub(super) fn filtered_terminator_span(terminator: &Terminator<'_>) -> Option<Span> {
- match terminator.kind {
- // These terminators have spans that don't positively contribute to computing a reasonable
- // span of actually executed source code. (For example, SwitchInt terminators extracted from
- // an `if condition { block }` has a span that includes the executed block, if true,
- // but for coverage, the code region executed, up to *and* through the SwitchInt,
- // actually stops before the if's block.)
- TerminatorKind::Unreachable // Unreachable blocks are not connected to the MIR CFG
- | TerminatorKind::Assert { .. }
- | TerminatorKind::Drop { .. }
- | TerminatorKind::SwitchInt { .. }
- // For `FalseEdge`, only the `real` branch is taken, so it is similar to a `Goto`.
- | TerminatorKind::FalseEdge { .. }
- | TerminatorKind::Goto { .. } => None,
-
- // Call `func` operand can have a more specific span when part of a chain of calls
- | TerminatorKind::Call { ref func, .. } => {
- let mut span = terminator.source_info.span;
- if let mir::Operand::Constant(box constant) = func {
- if constant.span.lo() > span.lo() {
- span = span.with_lo(constant.span.lo());
- }
- }
- Some(span)
- }
-
- // Retain spans from all other terminators
- TerminatorKind::UnwindResume
- | TerminatorKind::UnwindTerminate(_)
- | TerminatorKind::Return
- | TerminatorKind::Yield { .. }
- | TerminatorKind::GeneratorDrop
- | TerminatorKind::FalseUnwind { .. }
- | TerminatorKind::InlineAsm { .. } => {
- Some(terminator.source_info.span)
- }
- }
-}
-
-/// Returns an extrapolated span (pre-expansion[^1]) corresponding to a range
-/// within the function's body source. This span is guaranteed to be contained
-/// within, or equal to, the `body_span`. If the extrapolated span is not
-/// contained within the `body_span`, the `body_span` is returned.
-///
-/// [^1]Expansions result from Rust syntax including macros, syntactic sugar,
-/// etc.).
-#[inline]
-pub(super) fn function_source_span(span: Span, body_span: Span) -> Span {
- let original_span = original_sp(span, body_span).with_ctxt(body_span.ctxt());
- if body_span.contains(original_span) { original_span } else { body_span }
}
diff --git a/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs
new file mode 100644
index 000000000..6189e5379
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/coverage/spans/from_mir.rs
@@ -0,0 +1,193 @@
+use rustc_data_structures::captures::Captures;
+use rustc_middle::mir::{
+ self, AggregateKind, FakeReadCause, Rvalue, Statement, StatementKind, Terminator,
+ TerminatorKind,
+};
+use rustc_span::Span;
+
+use crate::coverage::graph::{BasicCoverageBlock, BasicCoverageBlockData, CoverageGraph};
+use crate::coverage::spans::CoverageSpan;
+
+pub(super) fn mir_to_initial_sorted_coverage_spans(
+ mir_body: &mir::Body<'_>,
+ fn_sig_span: Span,
+ body_span: Span,
+ basic_coverage_blocks: &CoverageGraph,
+) -> Vec<CoverageSpan> {
+ let mut initial_spans = Vec::with_capacity(mir_body.basic_blocks.len() * 2);
+ for (bcb, bcb_data) in basic_coverage_blocks.iter_enumerated() {
+ initial_spans.extend(bcb_to_initial_coverage_spans(mir_body, body_span, bcb, bcb_data));
+ }
+
+ if initial_spans.is_empty() {
+ // This can happen if, for example, the function is unreachable (contains only
+ // `BasicBlock`s with `Unreachable` terminators).
+ return initial_spans;
+ }
+
+ initial_spans.push(CoverageSpan::for_fn_sig(fn_sig_span));
+
+ initial_spans.sort_by(|a, b| {
+ // First sort by span start.
+ Ord::cmp(&a.span.lo(), &b.span.lo())
+ // If span starts are the same, sort by span end in reverse order.
+ // This ensures that if spans A and B are adjacent in the list,
+ // and they overlap but are not equal, then either:
+ // - Span A extends further left, or
+ // - Both have the same start and span A extends further right
+ .then_with(|| Ord::cmp(&a.span.hi(), &b.span.hi()).reverse())
+ // If both spans are equal, sort the BCBs in dominator order,
+ // so that dominating BCBs come before other BCBs they dominate.
+ .then_with(|| basic_coverage_blocks.cmp_in_dominator_order(a.bcb, b.bcb))
+ // If two spans are otherwise identical, put closure spans first,
+ // as this seems to be what the refinement step expects.
+ .then_with(|| Ord::cmp(&a.is_closure, &b.is_closure).reverse())
+ });
+
+ initial_spans
+}
+
+// Generate a set of `CoverageSpan`s from the filtered set of `Statement`s and `Terminator`s of
+// the `BasicBlock`(s) in the given `BasicCoverageBlockData`. One `CoverageSpan` is generated
+// for each `Statement` and `Terminator`. (Note that subsequent stages of coverage analysis will
+// merge some `CoverageSpan`s, at which point a `CoverageSpan` may represent multiple
+// `Statement`s and/or `Terminator`s.)
+fn bcb_to_initial_coverage_spans<'a, 'tcx>(
+ mir_body: &'a mir::Body<'tcx>,
+ body_span: Span,
+ bcb: BasicCoverageBlock,
+ bcb_data: &'a BasicCoverageBlockData,
+) -> impl Iterator<Item = CoverageSpan> + Captures<'a> + Captures<'tcx> {
+ bcb_data.basic_blocks.iter().flat_map(move |&bb| {
+ let data = &mir_body[bb];
+
+ let statement_spans = data.statements.iter().filter_map(move |statement| {
+ let expn_span = filtered_statement_span(statement)?;
+ let span = function_source_span(expn_span, body_span);
+
+ Some(CoverageSpan::new(span, expn_span, bcb, is_closure(statement)))
+ });
+
+ let terminator_span = Some(data.terminator()).into_iter().filter_map(move |terminator| {
+ let expn_span = filtered_terminator_span(terminator)?;
+ let span = function_source_span(expn_span, body_span);
+
+ Some(CoverageSpan::new(span, expn_span, bcb, false))
+ });
+
+ statement_spans.chain(terminator_span)
+ })
+}
+
+fn is_closure(statement: &Statement<'_>) -> bool {
+ match statement.kind {
+ StatementKind::Assign(box (_, Rvalue::Aggregate(box ref agg_kind, _))) => match agg_kind {
+ AggregateKind::Closure(_, _) | AggregateKind::Coroutine(_, _, _) => true,
+ _ => false,
+ },
+ _ => false,
+ }
+}
+
+/// If the MIR `Statement` has a span contributive to computing coverage spans,
+/// return it; otherwise return `None`.
+fn filtered_statement_span(statement: &Statement<'_>) -> Option<Span> {
+ match statement.kind {
+ // These statements have spans that are often outside the scope of the executed source code
+ // for their parent `BasicBlock`.
+ StatementKind::StorageLive(_)
+ | StatementKind::StorageDead(_)
+ // `Coverage` statements should not be encountered here, but in any case they contribute no span
+ | StatementKind::Coverage(_)
+ // Ignore `ConstEvalCounter`s
+ | StatementKind::ConstEvalCounter
+ // Ignore `Nop`s
+ | StatementKind::Nop => None,
+
+ // FIXME(#78546): MIR InstrumentCoverage - Can the source_info.span for `FakeRead`
+ // statements be more consistent?
+ //
+ // FakeReadCause::ForGuardBinding, in this example:
+ // match somenum {
+ // x if x < 1 => { ... }
+ // }...
+ // The BasicBlock within the match arm code included one of these statements, but the span
+ // for it covered the `1` in this source. The actual statements have nothing to do with that
+ // source span:
+ // FakeRead(ForGuardBinding, _4);
+ // where `_4` is:
+ // _4 = &_1; (at the span for the first `x`)
+ // and `_1` is the `Place` for `somenum`.
+ //
+ // If and when the Issue is resolved, remove this special case match pattern:
+ StatementKind::FakeRead(box (FakeReadCause::ForGuardBinding, _)) => None,
+
+ // Retain spans from all other statements
+ StatementKind::FakeRead(box (_, _)) // Not including `ForGuardBinding`
+ | StatementKind::Intrinsic(..)
+ | StatementKind::Assign(_)
+ | StatementKind::SetDiscriminant { .. }
+ | StatementKind::Deinit(..)
+ | StatementKind::Retag(_, _)
+ | StatementKind::PlaceMention(..)
+ | StatementKind::AscribeUserType(_, _) => {
+ Some(statement.source_info.span)
+ }
+ }
+}
+
+/// If the MIR `Terminator` has a span contributive to computing coverage spans,
+/// return it; otherwise return `None`.
+fn filtered_terminator_span(terminator: &Terminator<'_>) -> Option<Span> {
+ match terminator.kind {
+ // These terminators have spans that don't positively contribute to computing a reasonable
+ // span of actually executed source code. (For example, a SwitchInt terminator extracted from
+ // an `if condition { block }` has a span that includes the executed block when the condition is
+ // true, but for coverage purposes, the code region executed up to *and* through the SwitchInt
+ // actually stops before the if's block.)
+ TerminatorKind::Unreachable // Unreachable blocks are not connected to the MIR CFG
+ | TerminatorKind::Assert { .. }
+ | TerminatorKind::Drop { .. }
+ | TerminatorKind::SwitchInt { .. }
+ // For `FalseEdge`, only the `real` branch is taken, so it is similar to a `Goto`.
+ | TerminatorKind::FalseEdge { .. }
+ | TerminatorKind::Goto { .. } => None,
+
+ // Call `func` operand can have a more specific span when part of a chain of calls
+ | TerminatorKind::Call { ref func, .. } => {
+ let mut span = terminator.source_info.span;
+ if let mir::Operand::Constant(box constant) = func {
+ if constant.span.lo() > span.lo() {
+ span = span.with_lo(constant.span.lo());
+ }
+ }
+ Some(span)
+ }
+
+ // Retain spans from all other terminators
+ TerminatorKind::UnwindResume
+ | TerminatorKind::UnwindTerminate(_)
+ | TerminatorKind::Return
+ | TerminatorKind::Yield { .. }
+ | TerminatorKind::CoroutineDrop
+ | TerminatorKind::FalseUnwind { .. }
+ | TerminatorKind::InlineAsm { .. } => {
+ Some(terminator.source_info.span)
+ }
+ }
+}
+
+/// Returns an extrapolated span (pre-expansion[^1]) corresponding to a range
+/// within the function's body source. This span is guaranteed to be contained
+/// within, or equal to, the `body_span`. If the extrapolated span is not
+/// contained within the `body_span`, the `body_span` is returned.
+///
+/// [^1]: Expansions result from Rust syntax (including macros, syntactic
+/// sugar, etc.).
+#[inline]
+fn function_source_span(span: Span, body_span: Span) -> Span {
+ use rustc_span::source_map::original_sp;
+
+ let original_span = original_sp(span, body_span).with_ctxt(body_span.ctxt());
+ if body_span.contains(original_span) { original_span } else { body_span }
+}
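
The comparator in `mir_to_initial_sorted_coverage_spans` above packs four tie-breaking rules into one `sort_by`. The following standalone sketch reproduces that ordering on plain byte offsets; `FakeSpan`, `dominator_rank`, and `sort_initial` are illustrative names, not rustc types.

```rust
// A minimal, self-contained sketch of the initial-span ordering, using plain
// (lo, hi) offsets instead of rustc's `Span`/`CoverageSpan`.
#[derive(Debug, Clone, Copy)]
struct FakeSpan {
    lo: u32,
    hi: u32,
    dominator_rank: u32, // stands in for `cmp_in_dominator_order`
    is_closure: bool,
}

fn sort_initial(spans: &mut [FakeSpan]) {
    spans.sort_by(|a, b| {
        // First sort by span start.
        Ord::cmp(&a.lo, &b.lo)
            // Same start: longer span first, so an enclosing span precedes the
            // spans it contains.
            .then_with(|| Ord::cmp(&a.hi, &b.hi).reverse())
            // Tie-break by dominator order, dominating blocks first.
            .then_with(|| Ord::cmp(&a.dominator_rank, &b.dominator_rank))
            // Otherwise-identical spans: closures first.
            .then_with(|| Ord::cmp(&a.is_closure, &b.is_closure).reverse())
    });
}

fn main() {
    let mut spans = [
        FakeSpan { lo: 10, hi: 20, dominator_rank: 1, is_closure: false },
        FakeSpan { lo: 10, hi: 30, dominator_rank: 0, is_closure: false },
        FakeSpan { lo: 5, hi: 8, dominator_rank: 2, is_closure: true },
    ];
    sort_initial(&mut spans);
    // Expected order: (5, 8), (10, 30), (10, 20).
    println!("{spans:?}");
}
```
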
diff --git a/compiler/rustc_mir_transform/src/coverage/tests.rs b/compiler/rustc_mir_transform/src/coverage/tests.rs
index 4a066ed3a..702fe5f56 100644
--- a/compiler/rustc_mir_transform/src/coverage/tests.rs
+++ b/compiler/rustc_mir_transform/src/coverage/tests.rs
@@ -25,8 +25,7 @@
//! to: `rustc_span::create_default_session_globals_then(|| { test_here(); })`.
use super::counters;
-use super::graph;
-use super::spans;
+use super::graph::{self, BasicCoverageBlock};
use coverage_test_macros::let_bcb;
@@ -242,7 +241,7 @@ fn print_coverage_graphviz(
" {:?} [label=\"{:?}: {}\"];\n{}",
bcb,
bcb,
- bcb_data.terminator(mir_body).kind.name(),
+ mir_body[bcb_data.last_bb()].terminator().kind.name(),
basic_coverage_blocks
.successors(bcb)
.map(|successor| { format!(" {:?} -> {:?};", bcb, successor) })
@@ -629,7 +628,7 @@ fn test_traverse_coverage_with_loops() {
let basic_coverage_blocks = graph::CoverageGraph::from_mir(&mir_body);
let mut traversed_in_order = Vec::new();
let mut traversal = graph::TraverseCoverageGraphWithLoops::new(&basic_coverage_blocks);
- while let Some(bcb) = traversal.next(&basic_coverage_blocks) {
+ while let Some(bcb) = traversal.next() {
traversed_in_order.push(bcb);
}
@@ -644,41 +643,18 @@ fn test_traverse_coverage_with_loops() {
);
}
-fn synthesize_body_span_from_terminators(mir_body: &Body<'_>) -> Span {
- let mut some_span: Option<Span> = None;
- for (_, data) in mir_body.basic_blocks.iter_enumerated() {
- let term_span = data.terminator().source_info.span;
- if let Some(span) = some_span.as_mut() {
- *span = span.to(term_span);
- } else {
- some_span = Some(term_span)
- }
- }
- some_span.expect("body must have at least one BasicBlock")
-}
-
#[test]
fn test_make_bcb_counters() {
rustc_span::create_default_session_globals_then(|| {
let mir_body = goto_switchint();
- let body_span = synthesize_body_span_from_terminators(&mir_body);
- let mut basic_coverage_blocks = graph::CoverageGraph::from_mir(&mir_body);
- let mut coverage_spans = Vec::new();
- for (bcb, data) in basic_coverage_blocks.iter_enumerated() {
- if let Some(span) = spans::filtered_terminator_span(data.terminator(&mir_body)) {
- coverage_spans.push(spans::CoverageSpan::for_terminator(
- spans::function_source_span(span, body_span),
- span,
- bcb,
- data.last_bb(),
- ));
- }
- }
+ let basic_coverage_blocks = graph::CoverageGraph::from_mir(&mir_body);
+ // Historically this test would use `spans` internals to set up fake
+ // coverage spans for BCBs 1 and 2. Now we skip that step and just tell
+ // BCB counter construction that those BCBs have spans.
+ let bcb_has_coverage_spans = |bcb: BasicCoverageBlock| (1..=2).contains(&bcb.as_usize());
let mut coverage_counters = counters::CoverageCounters::new(&basic_coverage_blocks);
- coverage_counters
- .make_bcb_counters(&mut basic_coverage_blocks, &coverage_spans)
- .expect("should be Ok");
- assert_eq!(coverage_counters.intermediate_expressions.len(), 0);
+ coverage_counters.make_bcb_counters(&basic_coverage_blocks, bcb_has_coverage_spans);
+ assert_eq!(coverage_counters.num_expressions(), 0);
let_bcb!(1);
assert_eq!(
diff --git a/compiler/rustc_mir_transform/src/cross_crate_inline.rs b/compiler/rustc_mir_transform/src/cross_crate_inline.rs
new file mode 100644
index 000000000..261d9dd44
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/cross_crate_inline.rs
@@ -0,0 +1,130 @@
+use crate::inline;
+use crate::pass_manager as pm;
+use rustc_attr::InlineAttr;
+use rustc_hir::def::DefKind;
+use rustc_hir::def_id::LocalDefId;
+use rustc_middle::mir::visit::Visitor;
+use rustc_middle::mir::*;
+use rustc_middle::query::Providers;
+use rustc_middle::ty::TyCtxt;
+use rustc_session::config::InliningThreshold;
+use rustc_session::config::OptLevel;
+
+pub fn provide(providers: &mut Providers) {
+ providers.cross_crate_inlinable = cross_crate_inlinable;
+}
+
+fn cross_crate_inlinable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool {
+ let codegen_fn_attrs = tcx.codegen_fn_attrs(def_id);
+ // If this has an extern indicator, then this function is globally shared and thus will not
+ // generate cgu-internal copies which would make it cross-crate inlinable.
+ if codegen_fn_attrs.contains_extern_indicator() {
+ return false;
+ }
+
+ // Obey source annotations first; this is important because it means we can use
+ // #[inline(never)] to force code generation.
+ match codegen_fn_attrs.inline {
+ InlineAttr::Never => return false,
+ InlineAttr::Hint | InlineAttr::Always => return true,
+ _ => {}
+ }
+
+ // This just reproduces the logic from Instance::requires_inline.
+ match tcx.def_kind(def_id) {
+ DefKind::Ctor(..) | DefKind::Closure => return true,
+ DefKind::Fn | DefKind::AssocFn => {}
+ _ => return false,
+ }
+
+ // Don't do any inference when incremental compilation is enabled; the additional inlining that
+ // inference permits also creates more work for small edits.
+ if tcx.sess.opts.incremental.is_some() {
+ return false;
+ }
+
+ // Don't do any inference if codegen optimizations are disabled and also MIR inlining is not
+ // enabled. This ensures that we do inference even if someone only passes -Zinline-mir,
+ // which is less confusing than having to also enable -Copt-level=1.
+ if matches!(tcx.sess.opts.optimize, OptLevel::No) && !pm::should_run_pass(tcx, &inline::Inline)
+ {
+ return false;
+ }
+
+ if !tcx.is_mir_available(def_id) {
+ return false;
+ }
+
+ let threshold = match tcx.sess.opts.unstable_opts.cross_crate_inline_threshold {
+ InliningThreshold::Always => return true,
+ InliningThreshold::Sometimes(threshold) => threshold,
+ InliningThreshold::Never => return false,
+ };
+
+ let mir = tcx.optimized_mir(def_id);
+ let mut checker =
+ CostChecker { tcx, callee_body: mir, calls: 0, statements: 0, landing_pads: 0, resumes: 0 };
+ checker.visit_body(mir);
+ checker.calls == 0
+ && checker.resumes == 0
+ && checker.landing_pads == 0
+ && checker.statements <= threshold
+}
+
+struct CostChecker<'b, 'tcx> {
+ tcx: TyCtxt<'tcx>,
+ callee_body: &'b Body<'tcx>,
+ calls: usize,
+ statements: usize,
+ landing_pads: usize,
+ resumes: usize,
+}
+
+impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
+ fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) {
+ // Don't count StorageLive/StorageDead, Deinit, or Nop in the inlining cost.
+ match statement.kind {
+ StatementKind::StorageLive(_)
+ | StatementKind::StorageDead(_)
+ | StatementKind::Deinit(_)
+ | StatementKind::Nop => {}
+ _ => self.statements += 1,
+ }
+ }
+
+ fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, _: Location) {
+ let tcx = self.tcx;
+ match terminator.kind {
+ TerminatorKind::Drop { ref place, unwind, .. } => {
+ let ty = place.ty(self.callee_body, tcx).ty;
+ if !ty.is_trivially_pure_clone_copy() {
+ self.calls += 1;
+ if let UnwindAction::Cleanup(_) = unwind {
+ self.landing_pads += 1;
+ }
+ }
+ }
+ TerminatorKind::Call { unwind, .. } => {
+ self.calls += 1;
+ if let UnwindAction::Cleanup(_) = unwind {
+ self.landing_pads += 1;
+ }
+ }
+ TerminatorKind::Assert { unwind, .. } => {
+ self.calls += 1;
+ if let UnwindAction::Cleanup(_) = unwind {
+ self.landing_pads += 1;
+ }
+ }
+ TerminatorKind::UnwindResume => self.resumes += 1,
+ TerminatorKind::InlineAsm { unwind, .. } => {
+ self.statements += 1;
+ if let UnwindAction::Cleanup(_) = unwind {
+ self.landing_pads += 1;
+ }
+ }
+ TerminatorKind::Return => {}
+ _ => self.statements += 1,
+ }
+ }
+}
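
The final expression of `cross_crate_inlinable` above boils down to a predicate over the counters gathered by `CostChecker`: any call, landing pad, or resume disqualifies the body, and otherwise the statement count must stay within the configured `cross_crate_inline_threshold`. A minimal sketch of that rule, with `Counts` and `is_cross_crate_inlinable` as hypothetical stand-ins:

```rust
// Hypothetical stand-ins for the counters accumulated by `CostChecker`.
struct Counts {
    calls: usize,
    statements: usize,
    landing_pads: usize,
    resumes: usize,
}

// Mirrors the boolean conjunction at the end of `cross_crate_inlinable`.
fn is_cross_crate_inlinable(c: &Counts, threshold: usize) -> bool {
    c.calls == 0 && c.resumes == 0 && c.landing_pads == 0 && c.statements <= threshold
}

fn main() {
    let small_leaf = Counts { calls: 0, statements: 12, landing_pads: 0, resumes: 0 };
    let has_call = Counts { calls: 1, statements: 3, landing_pads: 0, resumes: 0 };
    assert!(is_cross_crate_inlinable(&small_leaf, 100));
    assert!(!is_cross_crate_inlinable(&has_call, 100));
}
```
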
diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
index 7b14fef61..81d2bba98 100644
--- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
+++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
@@ -2,14 +2,13 @@
//!
//! Currently, this pass only propagates scalar values.
-use rustc_const_eval::const_eval::CheckAlignment;
-use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
+use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable};
use rustc_data_structures::fx::FxHashMap;
use rustc_hir::def::DefKind;
use rustc_middle::mir::interpret::{AllocId, ConstAllocation, InterpResult, Scalar};
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
use rustc_middle::mir::*;
-use rustc_middle::ty::layout::TyAndLayout;
+use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_mir_dataflow::value_analysis::{
Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace,
@@ -17,8 +16,9 @@ use rustc_mir_dataflow::value_analysis::{
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, Results, ResultsVisitor};
use rustc_span::def_id::DefId;
use rustc_span::DUMMY_SP;
-use rustc_target::abi::{Align, FieldIdx, VariantIdx};
+use rustc_target::abi::{Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT};
+use crate::const_prop::throw_machine_stop_str;
use crate::MirPass;
// These constants are somewhat random guesses and have not been optimized.
@@ -286,9 +286,9 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> {
let val = match null_op {
NullOp::SizeOf if layout.is_sized() => layout.size.bytes(),
NullOp::AlignOf if layout.is_sized() => layout.align.abi.bytes(),
- NullOp::OffsetOf(fields) => layout
- .offset_of_subfield(&self.ecx, fields.iter().map(|f| f.index()))
- .bytes(),
+ NullOp::OffsetOf(fields) => {
+ layout.offset_of_subfield(&self.ecx, fields.iter()).bytes()
+ }
_ => return ValueOrPlace::Value(FlatSet::Top),
};
FlatSet::Elem(Scalar::from_target_usize(val, &self.tcx))
@@ -406,7 +406,8 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> {
TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
TrackElem::Discriminant => {
let variant = self.ecx.read_discriminant(op).ok()?;
- let discr_value = self.ecx.discriminant_for_variant(op.layout, variant).ok()?;
+ let discr_value =
+ self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?;
Some(discr_value.into())
}
TrackElem::DerefLen => {
@@ -507,7 +508,8 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> {
return None;
}
let enum_ty_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?;
- let discr_value = self.ecx.discriminant_for_variant(enum_ty_layout, variant_index).ok()?;
+ let discr_value =
+ self.ecx.discriminant_for_variant(enum_ty_layout.ty, variant_index).ok()?;
Some(discr_value.to_scalar())
}
@@ -554,16 +556,151 @@ impl<'tcx, 'locals> Collector<'tcx, 'locals> {
fn try_make_constant(
&self,
+ ecx: &mut InterpCx<'tcx, 'tcx, DummyMachine>,
place: Place<'tcx>,
state: &State<FlatSet<Scalar>>,
map: &Map,
) -> Option<Const<'tcx>> {
- let FlatSet::Elem(Scalar::Int(value)) = state.get(place.as_ref(), &map) else {
- return None;
- };
let ty = place.ty(self.local_decls, self.patch.tcx).ty;
- Some(Const::Val(ConstValue::Scalar(value.into()), ty))
+ let layout = ecx.layout_of(ty).ok()?;
+
+ if layout.is_zst() {
+ return Some(Const::zero_sized(ty));
+ }
+
+ if layout.is_unsized() {
+ return None;
+ }
+
+ let place = map.find(place.as_ref())?;
+ if layout.abi.is_scalar()
+ && let Some(value) = propagatable_scalar(place, state, map)
+ {
+ return Some(Const::Val(ConstValue::Scalar(value), ty));
+ }
+
+ if matches!(layout.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) {
+ let alloc_id = ecx
+ .intern_with_temp_alloc(layout, |ecx, dest| {
+ try_write_constant(ecx, dest, place, ty, state, map)
+ })
+ .ok()?;
+ return Some(Const::Val(ConstValue::Indirect { alloc_id, offset: Size::ZERO }, ty));
+ }
+
+ None
+ }
+}
+
+fn propagatable_scalar(
+ place: PlaceIndex,
+ state: &State<FlatSet<Scalar>>,
+ map: &Map,
+) -> Option<Scalar> {
+ if let FlatSet::Elem(value) = state.get_idx(place, map) && value.try_to_int().is_ok() {
+ // Do not attempt to propagate pointers, as we may fail to preserve their identity.
+ Some(value)
+ } else {
+ None
+ }
+}
+
+#[instrument(level = "trace", skip(ecx, state, map))]
+fn try_write_constant<'tcx>(
+ ecx: &mut InterpCx<'_, 'tcx, DummyMachine>,
+ dest: &PlaceTy<'tcx>,
+ place: PlaceIndex,
+ ty: Ty<'tcx>,
+ state: &State<FlatSet<Scalar>>,
+ map: &Map,
+) -> InterpResult<'tcx> {
+ let layout = ecx.layout_of(ty)?;
+
+ // Fast path for ZSTs.
+ if layout.is_zst() {
+ return Ok(());
+ }
+
+ // Fast path for scalars.
+ if layout.abi.is_scalar()
+ && let Some(value) = propagatable_scalar(place, state, map)
+ {
+ return ecx.write_immediate(Immediate::Scalar(value), dest);
+ }
+
+ match ty.kind() {
+ // ZSTs. Nothing to do.
+ ty::FnDef(..) => {}
+
+ // These are scalars and must have been handled by the scalar fast path above.
+ ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => throw_machine_stop_str!("primitive type with provenance"),
+
+ ty::Tuple(elem_tys) => {
+ for (i, elem) in elem_tys.iter().enumerate() {
+ let Some(field) = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i))) else {
+ throw_machine_stop_str!("missing field in tuple")
+ };
+ let field_dest = ecx.project_field(dest, i)?;
+ try_write_constant(ecx, &field_dest, field, elem, state, map)?;
+ }
+ }
+
+ ty::Adt(def, args) => {
+ if def.is_union() {
+ throw_machine_stop_str!("cannot propagate unions")
+ }
+
+ let (variant_idx, variant_def, variant_place, variant_dest) = if def.is_enum() {
+ let Some(discr) = map.apply(place, TrackElem::Discriminant) else {
+ throw_machine_stop_str!("missing discriminant for enum")
+ };
+ let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else {
+ throw_machine_stop_str!("discriminant with provenance")
+ };
+ let discr_bits = discr.assert_bits(discr.size());
+ let Some((variant, _)) = def.discriminants(*ecx.tcx).find(|(_, var)| discr_bits == var.val) else {
+ throw_machine_stop_str!("illegal discriminant for enum")
+ };
+ let Some(variant_place) = map.apply(place, TrackElem::Variant(variant)) else {
+ throw_machine_stop_str!("missing variant for enum")
+ };
+ let variant_dest = ecx.project_downcast(dest, variant)?;
+ (variant, def.variant(variant), variant_place, variant_dest)
+ } else {
+ (FIRST_VARIANT, def.non_enum_variant(), place, dest.clone())
+ };
+
+ for (i, field) in variant_def.fields.iter_enumerated() {
+ let ty = field.ty(*ecx.tcx, args);
+ let Some(field) = map.apply(variant_place, TrackElem::Field(i)) else {
+ throw_machine_stop_str!("missing field in ADT")
+ };
+ let field_dest = ecx.project_field(&variant_dest, i.as_usize())?;
+ try_write_constant(ecx, &field_dest, field, ty, state, map)?;
+ }
+ ecx.write_discriminant(variant_idx, dest)?;
+ }
+
+ // Unsupported for now.
+ ty::Array(_, _)
+
+ // Do not attempt to support indirection in constants.
+ | ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_)
+
+ | ty::Never
+ | ty::Foreign(..)
+ | ty::Alias(..)
+ | ty::Param(_)
+ | ty::Bound(..)
+ | ty::Placeholder(..)
+ | ty::Closure(..)
+ | ty::Coroutine(..)
+ | ty::Dynamic(..) => throw_machine_stop_str!("unsupported type"),
+
+ ty::Error(_) | ty::Infer(..) | ty::CoroutineWitness(..) => bug!(),
}
+
+ Ok(())
}
impl<'mir, 'tcx>
@@ -581,8 +718,13 @@ impl<'mir, 'tcx>
) {
match &statement.kind {
StatementKind::Assign(box (_, rvalue)) => {
- OperandCollector { state, visitor: self, map: &results.analysis.0.map }
- .visit_rvalue(rvalue, location);
+ OperandCollector {
+ state,
+ visitor: self,
+ ecx: &mut results.analysis.0.ecx,
+ map: &results.analysis.0.map,
+ }
+ .visit_rvalue(rvalue, location);
}
_ => (),
}
@@ -600,7 +742,12 @@ impl<'mir, 'tcx>
// Don't overwrite the assignment if it already uses a constant (to keep the span).
}
StatementKind::Assign(box (place, _)) => {
- if let Some(value) = self.try_make_constant(place, state, &results.analysis.0.map) {
+ if let Some(value) = self.try_make_constant(
+ &mut results.analysis.0.ecx,
+ place,
+ state,
+ &results.analysis.0.map,
+ ) {
self.patch.assignments.insert(location, value);
}
}
@@ -615,8 +762,13 @@ impl<'mir, 'tcx>
terminator: &'mir Terminator<'tcx>,
location: Location,
) {
- OperandCollector { state, visitor: self, map: &results.analysis.0.map }
- .visit_terminator(terminator, location);
+ OperandCollector {
+ state,
+ visitor: self,
+ ecx: &mut results.analysis.0.ecx,
+ map: &results.analysis.0.map,
+ }
+ .visit_terminator(terminator, location);
}
}
@@ -671,6 +823,7 @@ impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> {
struct OperandCollector<'tcx, 'map, 'locals, 'a> {
state: &'a State<FlatSet<Scalar>>,
visitor: &'a mut Collector<'tcx, 'locals>,
+ ecx: &'map mut InterpCx<'tcx, 'tcx, DummyMachine>,
map: &'map Map,
}
@@ -683,7 +836,7 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
location: Location,
) {
if let PlaceElem::Index(local) = elem
- && let Some(value) = self.visitor.try_make_constant(local.into(), self.state, self.map)
+ && let Some(value) = self.visitor.try_make_constant(self.ecx, local.into(), self.state, self.map)
{
self.visitor.patch.before_effect.insert((location, local.into()), value);
}
@@ -691,7 +844,9 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) {
if let Some(place) = operand.place() {
- if let Some(value) = self.visitor.try_make_constant(place, self.state, self.map) {
+ if let Some(value) =
+ self.visitor.try_make_constant(self.ecx, place, self.state, self.map)
+ {
self.visitor.patch.before_effect.insert((location, place), value);
} else if !place.projection.is_empty() {
// Try to propagate into `Index` projections.
@@ -701,7 +856,7 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
}
}
-struct DummyMachine;
+pub(crate) struct DummyMachine;
impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for DummyMachine {
rustc_const_eval::interpret::compile_time_machine!(<'mir, 'tcx>);
@@ -709,22 +864,12 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
const PANIC_ON_ALLOC_FAIL: bool = true;
#[inline(always)]
- fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> CheckAlignment {
- // We do not check for alignment to avoid having to carry an `Align`
- // in `ConstValue::ByRef`.
- CheckAlignment::No
+ fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> bool {
+ false // no reason to enforce alignment
}
fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {
- unimplemented!()
- }
- fn alignment_check_failed(
- _ecx: &InterpCx<'mir, 'tcx, Self>,
- _has: Align,
- _required: Align,
- _check: CheckAlignment,
- ) -> interpret::InterpResult<'tcx, ()> {
- unimplemented!()
+ false
}
fn before_access_global(
@@ -736,13 +881,13 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
is_write: bool,
) -> InterpResult<'tcx> {
if is_write {
- crate::const_prop::throw_machine_stop_str!("can't write to global");
+ throw_machine_stop_str!("can't write to global");
}
// If the static allocation is mutable, then we can't const prop it as its content
// might be different at runtime.
if alloc.inner().mutability.is_mut() {
- crate::const_prop::throw_machine_stop_str!("can't access mutable globals in ConstProp");
+ throw_machine_stop_str!("can't access mutable globals in ConstProp");
}
Ok(())
@@ -792,7 +937,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
_left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
_right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
) -> interpret::InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)> {
- crate::const_prop::throw_machine_stop_str!("can't do pointer arithmetic");
+ throw_machine_stop_str!("can't do pointer arithmetic");
}
fn expose_ptr(
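
The `propagatable_scalar` helper introduced above only lets provenance-free integers be written back as constants; pointers are rejected because their identity may not be preserved. A self-contained sketch of that filter, using toy types (`ToyScalar`, `Flat`, `propagatable`) in place of rustc's `Scalar` and `FlatSet`:

```rust
// Toy stand-ins for rustc's `Scalar` and the `FlatSet` lattice.
#[derive(Debug, Clone, Copy)]
enum ToyScalar {
    Int(u128),
    // A pointer carries provenance; duplicating it as a constant could fail
    // to preserve its identity, so it must not be propagated.
    Ptr { alloc_id: u32, offset: u64 },
}

#[derive(Debug, Clone, Copy)]
enum Flat<T> {
    Bottom,
    Elem(T),
    Top,
}

fn propagatable(value: Flat<ToyScalar>) -> Option<u128> {
    match value {
        // Known, provenance-free integer: safe to materialize as a constant.
        Flat::Elem(ToyScalar::Int(bits)) => Some(bits),
        // Unknown, conflicting, or pointer-valued: leave the MIR unchanged.
        _ => None,
    }
}

fn main() {
    assert_eq!(propagatable(Flat::Elem(ToyScalar::Int(42))), Some(42));
    assert_eq!(propagatable(Flat::Elem(ToyScalar::Ptr { alloc_id: 0, offset: 8 })), None);
    assert_eq!(propagatable(Flat::Bottom), None);
    assert_eq!(propagatable(Flat::Top), None);
}
```
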
diff --git a/compiler/rustc_mir_transform/src/dead_store_elimination.rs b/compiler/rustc_mir_transform/src/dead_store_elimination.rs
index ef1410504..3d74ef7e3 100644
--- a/compiler/rustc_mir_transform/src/dead_store_elimination.rs
+++ b/compiler/rustc_mir_transform/src/dead_store_elimination.rs
@@ -13,10 +13,10 @@
//!
use crate::util::is_within_packed;
-use rustc_index::bit_set::BitSet;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;
+use rustc_mir_dataflow::debuginfo::debuginfo_locals;
use rustc_mir_dataflow::impls::{
borrowed_locals, LivenessTransferFunction, MaybeTransitiveLiveLocals,
};
@@ -26,8 +26,15 @@ use rustc_mir_dataflow::Analysis;
///
 /// The set of locals that are ever borrowed in this body is computed internally via the
 /// [`borrowed_locals`] function; locals that appear in debuginfo are additionally kept live.
-pub fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, borrowed: &BitSet<Local>) {
- let mut live = MaybeTransitiveLiveLocals::new(borrowed)
+pub fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ let borrowed_locals = borrowed_locals(body);
+
+ // If the user requests complete debuginfo, mark the locals that appear in it as live, so
+ // we don't remove assignements to them.
+ let mut always_live = debuginfo_locals(body);
+ always_live.union(&borrowed_locals);
+
+ let mut live = MaybeTransitiveLiveLocals::new(&always_live)
.into_engine(tcx, body)
.iterate_to_fixpoint()
.into_results_cursor(body);
@@ -48,7 +55,9 @@ pub fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, borrowed: &BitS
for (index, arg) in args.iter().enumerate().rev() {
if let Operand::Copy(place) = *arg
&& !place.is_indirect()
- && !borrowed.contains(place.local)
+ // Do not skip the transformation if the local is in debuginfo, as we do
+ // not really lose any information for this purpose.
+ && !borrowed_locals.contains(place.local)
&& !state.contains(place.local)
// If `place` is a projection of a disaligned field in a packed ADT,
// the move may be codegened as a pointer to that field.
@@ -75,7 +84,7 @@ pub fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, borrowed: &BitS
StatementKind::Assign(box (place, _))
| StatementKind::SetDiscriminant { place: box place, .. }
| StatementKind::Deinit(box place) => {
- if !place.is_indirect() && !borrowed.contains(place.local) {
+ if !place.is_indirect() && !always_live.contains(place.local) {
live.seek_before_primary_effect(loc);
if !live.get().contains(place.local) {
patch.push(loc);
@@ -126,7 +135,6 @@ impl<'tcx> MirPass<'tcx> for DeadStoreElimination {
}
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
- let borrowed = borrowed_locals(body);
- eliminate(tcx, body, &borrowed);
+ eliminate(tcx, body);
}
}
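
The reworked `eliminate` above unions the borrowed locals with the debuginfo locals before running liveness, so a store is removed only when its destination is neither borrowed nor mentioned in debuginfo. A minimal sketch of that set construction, with `HashSet<usize>` standing in for `BitSet<Local>` (all names illustrative):

```rust
use std::collections::HashSet;

// Locals in this set are never candidates for dead-store elimination.
fn always_live(borrowed: &HashSet<usize>, debuginfo: &HashSet<usize>) -> HashSet<usize> {
    borrowed.union(debuginfo).copied().collect()
}

fn main() {
    let borrowed: HashSet<usize> = [1, 4].into_iter().collect();
    let debuginfo: HashSet<usize> = [2].into_iter().collect();
    let live = always_live(&borrowed, &debuginfo);
    assert!(live.contains(&2)); // kept only because it appears in debuginfo
    assert!(!live.contains(&3)); // stores to local _3 may still be removed
}
```
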
diff --git a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs
index 79645310a..990cfb05e 100644
--- a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs
+++ b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs
@@ -44,6 +44,7 @@ impl<'tcx> Visitor<'tcx> for DeduceReadOnly {
 // Whether mutating through a `&raw const` is allowed is still undecided, so we
// disable any sketchy `readonly` optimizations for now.
// But we only need to do this if the pointer would point into the argument.
+ // IOW: for indirect places, like `&raw (*local).field`, this surely cannot mutate `local`.
!place.is_indirect()
}
PlaceContext::NonMutatingUse(..) | PlaceContext::NonUse(..) => {
diff --git a/compiler/rustc_mir_transform/src/deref_separator.rs b/compiler/rustc_mir_transform/src/deref_separator.rs
index 95898b5b7..42be74570 100644
--- a/compiler/rustc_mir_transform/src/deref_separator.rs
+++ b/compiler/rustc_mir_transform/src/deref_separator.rs
@@ -37,7 +37,7 @@ impl<'a, 'tcx> MutVisitor<'tcx> for DerefChecker<'a, 'tcx> {
for (idx, (p_ref, p_elem)) in place.iter_projections().enumerate() {
if !p_ref.projection.is_empty() && p_elem == ProjectionElem::Deref {
let ty = p_ref.ty(self.local_decls, self.tcx).ty;
- let temp = self.patcher.new_internal_with_info(
+ let temp = self.patcher.new_local_with_info(
ty,
self.local_decls[p_ref.local].source_info.span,
LocalInfo::DerefTemp,
diff --git a/compiler/rustc_mir_transform/src/dest_prop.rs b/compiler/rustc_mir_transform/src/dest_prop.rs
index d9a132e5c..15502adfb 100644
--- a/compiler/rustc_mir_transform/src/dest_prop.rs
+++ b/compiler/rustc_mir_transform/src/dest_prop.rs
@@ -114,7 +114,7 @@
//! approach that only works for some classes of CFGs:
//! - rustc now has a powerful dataflow analysis framework that can handle forwards and backwards
//! analyses efficiently.
-//! - Layout optimizations for generators have been added to improve code generation for
+//! - Layout optimizations for coroutines have been added to improve code generation for
//! async/await, which are very similar in spirit to what this optimization does.
//!
//! Also, rustc now has a simple NRVO pass (see `nrvo.rs`), which handles a subset of the cases that
@@ -244,7 +244,7 @@ impl<'tcx> MirPass<'tcx> for DestinationPropagation {
if round_count != 0 {
// Merging can introduce overlap between moved arguments and/or call destination in an
// unreachable code, which validator considers to be ill-formed.
- remove_dead_blocks(tcx, body);
+ remove_dead_blocks(body);
}
trace!(round_count);
@@ -655,7 +655,7 @@ impl WriteInfo {
// `Drop`s create a `&mut` and so are not considered
}
TerminatorKind::Yield { .. }
- | TerminatorKind::GeneratorDrop
+ | TerminatorKind::CoroutineDrop
| TerminatorKind::FalseEdge { .. }
| TerminatorKind::FalseUnwind { .. } => {
bug!("{:?} not found in this MIR phase", terminator)
diff --git a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs
index 319fb4eaf..6eb6cb069 100644
--- a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs
+++ b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs
@@ -95,6 +95,7 @@ pub struct EarlyOtherwiseBranch;
impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+ // unsound: https://github.com/rust-lang/rust/issues/95162
sess.mir_opt_level() >= 3 && sess.opts.unstable_opts.unsound_mir_opts
}
diff --git a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs
index e51f771e0..1c917a85c 100644
--- a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs
+++ b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs
@@ -69,7 +69,7 @@ impl<'tcx, 'a> MutVisitor<'tcx> for ElaborateBoxDerefVisitor<'tcx, 'a> {
let (unique_ty, nonnull_ty, ptr_ty) =
build_ptr_tys(tcx, base_ty.boxed_ty(), self.unique_did, self.nonnull_did);
- let ptr_local = self.patch.new_internal(ptr_ty, source_info.span);
+ let ptr_local = self.patch.new_temp(ptr_ty, source_info.span);
self.patch.add_assign(
location,
diff --git a/compiler/rustc_mir_transform/src/elaborate_drops.rs b/compiler/rustc_mir_transform/src/elaborate_drops.rs
index b62d7da2a..59156b242 100644
--- a/compiler/rustc_mir_transform/src/elaborate_drops.rs
+++ b/compiler/rustc_mir_transform/src/elaborate_drops.rs
@@ -9,9 +9,9 @@ use rustc_mir_dataflow::elaborate_drops::{elaborate_drop, DropFlagState, Unwind}
use rustc_mir_dataflow::elaborate_drops::{DropElaborator, DropFlagMode, DropStyle};
use rustc_mir_dataflow::impls::{MaybeInitializedPlaces, MaybeUninitializedPlaces};
use rustc_mir_dataflow::move_paths::{LookupResult, MoveData, MovePathIndex};
+use rustc_mir_dataflow::on_all_children_bits;
use rustc_mir_dataflow::on_lookup_result_bits;
use rustc_mir_dataflow::MoveDataParamEnv;
-use rustc_mir_dataflow::{on_all_children_bits, on_all_drop_children_bits};
use rustc_mir_dataflow::{Analysis, ResultsCursor};
use rustc_span::Span;
use rustc_target::abi::{FieldIdx, VariantIdx};
@@ -54,16 +54,10 @@ impl<'tcx> MirPass<'tcx> for ElaborateDrops {
let def_id = body.source.def_id();
let param_env = tcx.param_env_reveal_all_normalized(def_id);
- let move_data = match MoveData::gather_moves(body, tcx, param_env) {
- Ok(move_data) => move_data,
- Err((move_data, _)) => {
- tcx.sess.delay_span_bug(
- body.span,
- "No `move_errors` should be allowed in MIR borrowck",
- );
- move_data
- }
- };
+ // For types that do not need dropping, the behaviour is trivial. So we only need to track
+ // init/uninit for types that do need dropping.
+ let move_data =
+ MoveData::gather_moves(&body, tcx, param_env, |ty| ty.needs_drop(tcx, param_env));
let elaborate_patch = {
let env = MoveDataParamEnv { move_data, param_env };
@@ -178,13 +172,19 @@ impl<'a, 'tcx> DropElaborator<'a, 'tcx> for Elaborator<'a, '_, 'tcx> {
let mut some_live = false;
let mut some_dead = false;
let mut children_count = 0;
- on_all_drop_children_bits(self.tcx(), self.body(), self.ctxt.env, path, |child| {
- let (live, dead) = self.ctxt.init_data.maybe_live_dead(child);
- debug!("elaborate_drop: state({:?}) = {:?}", child, (live, dead));
- some_live |= live;
- some_dead |= dead;
- children_count += 1;
- });
+ on_all_children_bits(
+ self.tcx(),
+ self.body(),
+ self.ctxt.move_data(),
+ path,
+ |child| {
+ let (live, dead) = self.ctxt.init_data.maybe_live_dead(child);
+ debug!("elaborate_drop: state({:?}) = {:?}", child, (live, dead));
+ some_live |= live;
+ some_dead |= dead;
+ children_count += 1;
+ },
+ );
((some_live, some_dead), children_count != 1)
}
};
@@ -271,7 +271,7 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> {
let tcx = self.tcx;
let patch = &mut self.patch;
debug!("create_drop_flag({:?})", self.body.span);
- self.drop_flags[index].get_or_insert_with(|| patch.new_internal(tcx.types.bool, span));
+ self.drop_flags[index].get_or_insert_with(|| patch.new_temp(tcx.types.bool, span));
}
fn drop_flag(&mut self, index: MovePathIndex) -> Option<Place<'tcx>> {
@@ -296,26 +296,36 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> {
fn collect_drop_flags(&mut self) {
for (bb, data) in self.body.basic_blocks.iter_enumerated() {
let terminator = data.terminator();
- let place = match terminator.kind {
- TerminatorKind::Drop { ref place, .. } => place,
- _ => continue,
- };
-
- self.init_data.seek_before(self.body.terminator_loc(bb));
+ let TerminatorKind::Drop { ref place, .. } = terminator.kind else { continue };
let path = self.move_data().rev_lookup.find(place.as_ref());
debug!("collect_drop_flags: {:?}, place {:?} ({:?})", bb, place, path);
- let path = match path {
- LookupResult::Exact(e) => e,
- LookupResult::Parent(None) => continue,
+ match path {
+ LookupResult::Exact(path) => {
+ self.init_data.seek_before(self.body.terminator_loc(bb));
+ on_all_children_bits(self.tcx, self.body, self.move_data(), path, |child| {
+ let (maybe_live, maybe_dead) = self.init_data.maybe_live_dead(child);
+ debug!(
+ "collect_drop_flags: collecting {:?} from {:?}@{:?} - {:?}",
+ child,
+ place,
+ path,
+ (maybe_live, maybe_dead)
+ );
+ if maybe_live && maybe_dead {
+ self.create_drop_flag(child, terminator.source_info.span)
+ }
+ });
+ }
+ LookupResult::Parent(None) => {}
LookupResult::Parent(Some(parent)) => {
- let (_maybe_live, maybe_dead) = self.init_data.maybe_live_dead(parent);
-
if self.body.local_decls[place.local].is_deref_temp() {
continue;
}
+ self.init_data.seek_before(self.body.terminator_loc(bb));
+ let (_maybe_live, maybe_dead) = self.init_data.maybe_live_dead(parent);
if maybe_dead {
self.tcx.sess.delay_span_bug(
terminator.source_info.span,
@@ -324,80 +334,74 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> {
),
);
}
- continue;
}
};
-
- on_all_drop_children_bits(self.tcx, self.body, self.env, path, |child| {
- let (maybe_live, maybe_dead) = self.init_data.maybe_live_dead(child);
- debug!(
- "collect_drop_flags: collecting {:?} from {:?}@{:?} - {:?}",
- child,
- place,
- path,
- (maybe_live, maybe_dead)
- );
- if maybe_live && maybe_dead {
- self.create_drop_flag(child, terminator.source_info.span)
- }
- });
}
}
fn elaborate_drops(&mut self) {
+ // This function should mirror what `collect_drop_flags` does.
for (bb, data) in self.body.basic_blocks.iter_enumerated() {
- let loc = Location { block: bb, statement_index: data.statements.len() };
let terminator = data.terminator();
+ let TerminatorKind::Drop { place, target, unwind, replace } = terminator.kind else {
+ continue;
+ };
- match terminator.kind {
- TerminatorKind::Drop { place, target, unwind, replace } => {
- self.init_data.seek_before(loc);
- match self.move_data().rev_lookup.find(place.as_ref()) {
- LookupResult::Exact(path) => {
- let unwind = if data.is_cleanup {
- Unwind::InCleanup
- } else {
- match unwind {
- UnwindAction::Cleanup(cleanup) => Unwind::To(cleanup),
- UnwindAction::Continue => Unwind::To(self.patch.resume_block()),
- UnwindAction::Unreachable => {
- Unwind::To(self.patch.unreachable_cleanup_block())
- }
- UnwindAction::Terminate(reason) => {
- debug_assert_ne!(
- reason,
- UnwindTerminateReason::InCleanup,
- "we are not in a cleanup block, InCleanup reason should be impossible"
- );
- Unwind::To(self.patch.terminate_block(reason))
- }
- }
- };
- elaborate_drop(
- &mut Elaborator { ctxt: self },
- terminator.source_info,
- place,
- path,
- target,
- unwind,
- bb,
- )
+ // This place does not need dropping. It does not have an associated move-path, so the
+ // match below will conservatively keep an unconditional drop. As that drop is useless,
+ // just remove it here and now.
+ if !place
+ .ty(&self.body.local_decls, self.tcx)
+ .ty
+ .needs_drop(self.tcx, self.env.param_env)
+ {
+ self.patch.patch_terminator(bb, TerminatorKind::Goto { target });
+ continue;
+ }
+
+ let path = self.move_data().rev_lookup.find(place.as_ref());
+ match path {
+ LookupResult::Exact(path) => {
+ let unwind = match unwind {
+ _ if data.is_cleanup => Unwind::InCleanup,
+ UnwindAction::Cleanup(cleanup) => Unwind::To(cleanup),
+ UnwindAction::Continue => Unwind::To(self.patch.resume_block()),
+ UnwindAction::Unreachable => {
+ Unwind::To(self.patch.unreachable_cleanup_block())
}
- LookupResult::Parent(..) => {
- if !replace {
- self.tcx.sess.delay_span_bug(
- terminator.source_info.span,
- format!("drop of untracked value {bb:?}"),
- );
- }
- // A drop and replace behind a pointer/array/whatever.
- // The borrow checker requires that these locations are initialized before the assignment,
- // so we just leave an unconditional drop.
- assert!(!data.is_cleanup);
+ UnwindAction::Terminate(reason) => {
+ debug_assert_ne!(
+ reason,
+ UnwindTerminateReason::InCleanup,
+ "we are not in a cleanup block, InCleanup reason should be impossible"
+ );
+ Unwind::To(self.patch.terminate_block(reason))
}
+ };
+ self.init_data.seek_before(self.body.terminator_loc(bb));
+ elaborate_drop(
+ &mut Elaborator { ctxt: self },
+ terminator.source_info,
+ place,
+ path,
+ target,
+ unwind,
+ bb,
+ )
+ }
+ LookupResult::Parent(None) => {}
+ LookupResult::Parent(Some(_)) => {
+ if !replace {
+ self.tcx.sess.delay_span_bug(
+ terminator.source_info.span,
+ format!("drop of untracked value {bb:?}"),
+ );
}
+ // A drop and replace behind a pointer/array/whatever.
+ // The borrow checker requires that these locations are initialized before the assignment,
+ // so we just leave an unconditional drop.
+ assert!(!data.is_cleanup);
}
- _ => continue,
}
}
}
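
The early-out added to `elaborate_drops` above replaces a `Drop` of a place whose type needs no dropping with a plain `Goto` to the drop's target. A toy illustration of that rewrite; `ToyTerminator` and the `needs_drop` flag are hypothetical stand-ins for the MIR terminator and the `Ty::needs_drop` query:

```rust
#[derive(Debug, PartialEq)]
enum ToyTerminator {
    Drop { place: &'static str, target: usize },
    Goto { target: usize },
}

fn elaborate(term: ToyTerminator, needs_drop: bool) -> ToyTerminator {
    match term {
        // The drop would be a no-op at runtime, so just fall through to the target.
        ToyTerminator::Drop { target, .. } if !needs_drop => ToyTerminator::Goto { target },
        other => other,
    }
}

fn main() {
    let rewritten = elaborate(ToyTerminator::Drop { place: "_1", target: 3 }, false);
    assert_eq!(rewritten, ToyTerminator::Goto { target: 3 });
    let kept = elaborate(ToyTerminator::Drop { place: "_2", target: 5 }, true);
    assert_eq!(kept, ToyTerminator::Drop { place: "_2", target: 5 });
}
```
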
diff --git a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs
index d20286084..26fcfad82 100644
--- a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs
+++ b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs
@@ -58,7 +58,7 @@ fn has_ffi_unwind_calls(tcx: TyCtxt<'_>, local_def_id: LocalDefId) -> bool {
let body_abi = match body_ty.kind() {
ty::FnDef(..) => body_ty.fn_sig(tcx).abi(),
ty::Closure(..) => Abi::RustCall,
- ty::Generator(..) => Abi::Rust,
+ ty::Coroutine(..) => Abi::Rust,
_ => span_bug!(body.span, "unexpected body ty: {:?}", body_ty),
};
let body_can_unwind = layout::fn_can_unwind(tcx, Some(def_id), body_abi);
diff --git a/compiler/rustc_mir_transform/src/gvn.rs b/compiler/rustc_mir_transform/src/gvn.rs
index 56bdc5a17..dce298e92 100644
--- a/compiler/rustc_mir_transform/src/gvn.rs
+++ b/compiler/rustc_mir_transform/src/gvn.rs
@@ -52,19 +52,59 @@
//! _a = *_b // _b is &Freeze
//! _c = *_b // replaced by _c = _a
//! ```
+//!
+//! # Determinism of constant propagation
+//!
+//! When registering a new `Value`, we attempt to opportunistically evaluate it as a constant.
+//! The evaluated form is inserted in `evaluated` as an `OpTy` or `None` if evaluation failed.
+//!
+//! The difficulty is the non-deterministic evaluation of MIR constants. Some `Const`s can have
+//! different runtime values each time they are evaluated. This is the case for
+//! `Const::Slice`, which gets a new pointer each time it is evaluated, and for constants that
+//! contain a fn pointer (an `AllocId` pointing to a `GlobalAlloc::Function`), which may resolve
+//! to a different symbol in each codegen unit.
+//!
+//! Meanwhile, we want to be able to read indirect constants. For instance:
+//! ```
+//! static A: &'static &'static u8 = &&63;
+//! fn foo() -> u8 {
+//! **A // We want to replace by 63.
+//! }
+//! fn bar() -> u8 {
+//! b"abc"[1] // We want to replace by 'b'.
+//! }
+//! ```
+//!
+//! The `Value::Constant` variant stores a possibly unevaluated constant. Evaluating that constant
+//! may be non-deterministic. When that happens, we assign a disambiguator to ensure that we do not
+//! merge the constants. See `duplicate_slice` test in `gvn.rs`.
+//!
+//! Second, when writing constants in MIR, we do not write `Const::Slice` or `Const`
+//! that contain `AllocId`s.
+use rustc_const_eval::interpret::{intern_const_alloc_for_constprop, MemoryKind};
+use rustc_const_eval::interpret::{ImmTy, InterpCx, OpTy, Projectable, Scalar};
use rustc_data_structures::fx::{FxHashMap, FxIndexSet};
use rustc_data_structures::graph::dominators::Dominators;
+use rustc_hir::def::DefKind;
use rustc_index::bit_set::BitSet;
use rustc_index::IndexVec;
use rustc_macros::newtype_index;
+use rustc_middle::mir::interpret::GlobalAlloc;
use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
-use rustc_middle::ty::{self, Ty, TyCtxt};
-use rustc_target::abi::{VariantIdx, FIRST_VARIANT};
+use rustc_middle::ty::adjustment::PointerCoercion;
+use rustc_middle::ty::layout::LayoutOf;
+use rustc_middle::ty::{self, Ty, TyCtxt, TypeAndMut};
+use rustc_span::def_id::DefId;
+use rustc_span::DUMMY_SP;
+use rustc_target::abi::{self, Abi, Size, VariantIdx, FIRST_VARIANT};
+use std::borrow::Cow;
-use crate::ssa::SsaLocals;
+use crate::dataflow_const_prop::DummyMachine;
+use crate::ssa::{AssignedValue, SsaLocals};
use crate::MirPass;
+use either::Either;
pub struct GVN;
@@ -87,21 +127,28 @@ fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let dominators = body.basic_blocks.dominators().clone();
let mut state = VnState::new(tcx, param_env, &ssa, &dominators, &body.local_decls);
- for arg in body.args_iter() {
- if ssa.is_ssa(arg) {
- let value = state.new_opaque().unwrap();
- state.assign(arg, value);
- }
- }
-
- ssa.for_each_assignment_mut(&mut body.basic_blocks, |local, rvalue, location| {
- let value = state.simplify_rvalue(rvalue, location).or_else(|| state.new_opaque()).unwrap();
- // FIXME(#112651) `rvalue` may have a subtype to `local`. We can only mark `local` as
- // reusable if we have an exact type match.
- if state.local_decls[local].ty == rvalue.ty(state.local_decls, tcx) {
+ ssa.for_each_assignment_mut(
+ body.basic_blocks.as_mut_preserves_cfg(),
+ |local, value, location| {
+ let value = match value {
+ // We do not know anything of this assigned value.
+ AssignedValue::Arg | AssignedValue::Terminator(_) => None,
+ // Try to get some insight.
+ AssignedValue::Rvalue(rvalue) => {
+ let value = state.simplify_rvalue(rvalue, location);
+ // FIXME(#112651) `rvalue`'s type may be a subtype of `local`'s type. We can only mark
+ // `local` as reusable if we have an exact type match.
+ if state.local_decls[local].ty != rvalue.ty(state.local_decls, tcx) {
+ return;
+ }
+ value
+ }
+ };
+ // `next_opaque` is `Some`, so `new_opaque` must return `Some`.
+ let value = value.or_else(|| state.new_opaque()).unwrap();
state.assign(local, value);
- }
- });
+ },
+ );
// Stop creating opaques during replacement as it is useless.
state.next_opaque = None;
@@ -111,22 +158,33 @@ fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let data = &mut body.basic_blocks.as_mut_preserves_cfg()[bb];
state.visit_basic_block_data(bb, data);
}
- let any_replacement = state.any_replacement;
 // For each local that is reused (`y` above), we remove its storage statements to avoid any
 // difficulty. Those locals are SSA, so they should be easy for LLVM to optimize without storage
// statements.
StorageRemover { tcx, reused_locals: state.reused_locals }.visit_body_preserves_cfg(body);
-
- if any_replacement {
- crate::simplify::remove_unused_definitions(body);
- }
}
newtype_index! {
struct VnIndex {}
}
+/// Computing the aggregate's type can be quite slow, so we only keep the minimal amount of
+/// information to reconstruct it when needed.
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
+enum AggregateTy<'tcx> {
+ /// Invariant: this must not be used for an empty array.
+ Array,
+ Tuple,
+ Def(DefId, ty::GenericArgsRef<'tcx>),
+}
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
+enum AddressKind {
+ Ref(BorrowKind),
+ Address(Mutability),
+}
+
#[derive(Debug, PartialEq, Eq, Hash)]
enum Value<'tcx> {
// Root values.
@@ -134,15 +192,21 @@ enum Value<'tcx> {
/// The `usize` is a counter incremented by `new_opaque`.
Opaque(usize),
/// Evaluated or unevaluated constant value.
- Constant(Const<'tcx>),
+ Constant {
+ value: Const<'tcx>,
+ /// Some constants do not have a deterministic value. To avoid merging two instances of the
+ /// same `Const`, we assign them an additional integer index.
+ disambiguator: usize,
+ },
/// An aggregate value, either tuple/closure/struct/enum.
/// This does not contain unions, as we cannot reason with the value.
- Aggregate(Ty<'tcx>, VariantIdx, Vec<VnIndex>),
+ Aggregate(AggregateTy<'tcx>, VariantIdx, Vec<VnIndex>),
/// This corresponds to a `[value; count]` expression.
Repeat(VnIndex, ty::Const<'tcx>),
/// The address of a place.
Address {
place: Place<'tcx>,
+ kind: AddressKind,
/// Give each borrow and pointer a different provenance, so we don't merge them.
provenance: usize,
},
@@ -170,6 +234,7 @@ enum Value<'tcx> {
struct VnState<'body, 'tcx> {
tcx: TyCtxt<'tcx>,
+ ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
param_env: ty::ParamEnv<'tcx>,
local_decls: &'body LocalDecls<'tcx>,
/// Value stored in each local.
@@ -177,13 +242,14 @@ struct VnState<'body, 'tcx> {
/// First local to be assigned that value.
rev_locals: FxHashMap<VnIndex, Vec<Local>>,
values: FxIndexSet<Value<'tcx>>,
+ /// Values evaluated as constants if possible.
+ evaluated: IndexVec<VnIndex, Option<OpTy<'tcx>>>,
/// Counter to generate different values.
/// This is an option to stop creating opaques during replacement.
next_opaque: Option<usize>,
ssa: &'body SsaLocals,
dominators: &'body Dominators<BasicBlock>,
reused_locals: BitSet<Local>,
- any_replacement: bool,
}
impl<'body, 'tcx> VnState<'body, 'tcx> {
@@ -196,23 +262,30 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
) -> Self {
VnState {
tcx,
+ ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
param_env,
local_decls,
locals: IndexVec::from_elem(None, local_decls),
rev_locals: FxHashMap::default(),
values: FxIndexSet::default(),
+ evaluated: IndexVec::new(),
next_opaque: Some(0),
ssa,
dominators,
reused_locals: BitSet::new_empty(local_decls.len()),
- any_replacement: false,
}
}
#[instrument(level = "trace", skip(self), ret)]
fn insert(&mut self, value: Value<'tcx>) -> VnIndex {
- let (index, _) = self.values.insert_full(value);
- VnIndex::from_usize(index)
+ let (index, new) = self.values.insert_full(value);
+ let index = VnIndex::from_usize(index);
+ if new {
+ let evaluated = self.eval_to_const(index);
+ let _index = self.evaluated.push(evaluated);
+ debug_assert_eq!(index, _index);
+ }
+ index
}
/// Create a new `Value` for which we have no information at all, except that it is distinct
@@ -227,9 +300,9 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
/// Create a new `Value::Address` distinct from all the others.
#[instrument(level = "trace", skip(self), ret)]
- fn new_pointer(&mut self, place: Place<'tcx>) -> Option<VnIndex> {
+ fn new_pointer(&mut self, place: Place<'tcx>, kind: AddressKind) -> Option<VnIndex> {
let next_opaque = self.next_opaque.as_mut()?;
- let value = Value::Address { place, provenance: *next_opaque };
+ let value = Value::Address { place, kind, provenance: *next_opaque };
*next_opaque += 1;
Some(self.insert(value))
}
@@ -251,6 +324,343 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
}
}
+ fn insert_constant(&mut self, value: Const<'tcx>) -> Option<VnIndex> {
+ let disambiguator = if value.is_deterministic() {
+ // The constant is deterministic, no need to disambiguate.
+ 0
+ } else {
+ // Multiple mentions of this constant will yield different values,
+ // so assign a different `disambiguator` to ensure they do not get the same `VnIndex`.
+ let next_opaque = self.next_opaque.as_mut()?;
+ let disambiguator = *next_opaque;
+ *next_opaque += 1;
+ disambiguator
+ };
+ Some(self.insert(Value::Constant { value, disambiguator }))
+ }
+
+ fn insert_scalar(&mut self, scalar: Scalar, ty: Ty<'tcx>) -> VnIndex {
+ self.insert_constant(Const::from_scalar(self.tcx, scalar, ty))
+ .expect("scalars are deterministic")
+ }
+
+ #[instrument(level = "trace", skip(self), ret)]
+ fn eval_to_const(&mut self, value: VnIndex) -> Option<OpTy<'tcx>> {
+ use Value::*;
+ let op = match *self.get(value) {
+ Opaque(_) => return None,
+ // Do not bother evaluating repeat expressions. This would uselessly consume memory.
+ Repeat(..) => return None,
+
+ Constant { ref value, disambiguator: _ } => {
+ self.ecx.eval_mir_constant(value, None, None).ok()?
+ }
+ Aggregate(kind, variant, ref fields) => {
+ let fields = fields
+ .iter()
+ .map(|&f| self.evaluated[f].as_ref())
+ .collect::<Option<Vec<_>>>()?;
+ let ty = match kind {
+ AggregateTy::Array => {
+ assert!(fields.len() > 0);
+ Ty::new_array(self.tcx, fields[0].layout.ty, fields.len() as u64)
+ }
+ AggregateTy::Tuple => {
+ Ty::new_tup_from_iter(self.tcx, fields.iter().map(|f| f.layout.ty))
+ }
+ AggregateTy::Def(def_id, args) => {
+ self.tcx.type_of(def_id).instantiate(self.tcx, args)
+ }
+ };
+ let variant = if ty.is_enum() { Some(variant) } else { None };
+ let ty = self.ecx.layout_of(ty).ok()?;
+ if ty.is_zst() {
+ ImmTy::uninit(ty).into()
+ } else if matches!(ty.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) {
+ let dest = self.ecx.allocate(ty, MemoryKind::Stack).ok()?;
+ let variant_dest = if let Some(variant) = variant {
+ self.ecx.project_downcast(&dest, variant).ok()?
+ } else {
+ dest.clone()
+ };
+ for (field_index, op) in fields.into_iter().enumerate() {
+ let field_dest = self.ecx.project_field(&variant_dest, field_index).ok()?;
+ self.ecx.copy_op(op, &field_dest, /*allow_transmute*/ false).ok()?;
+ }
+ self.ecx.write_discriminant(variant.unwrap_or(FIRST_VARIANT), &dest).ok()?;
+ self.ecx.alloc_mark_immutable(dest.ptr().provenance.unwrap()).ok()?;
+ dest.into()
+ } else {
+ return None;
+ }
+ }
+
+ Projection(base, elem) => {
+ let value = self.evaluated[base].as_ref()?;
+ let elem = match elem {
+ ProjectionElem::Deref => ProjectionElem::Deref,
+ ProjectionElem::Downcast(name, read_variant) => {
+ ProjectionElem::Downcast(name, read_variant)
+ }
+ ProjectionElem::Field(f, ty) => ProjectionElem::Field(f, ty),
+ ProjectionElem::ConstantIndex { offset, min_length, from_end } => {
+ ProjectionElem::ConstantIndex { offset, min_length, from_end }
+ }
+ ProjectionElem::Subslice { from, to, from_end } => {
+ ProjectionElem::Subslice { from, to, from_end }
+ }
+ ProjectionElem::OpaqueCast(ty) => ProjectionElem::OpaqueCast(ty),
+ ProjectionElem::Subtype(ty) => ProjectionElem::Subtype(ty),
+ // This should have been replaced by a `ConstantIndex` earlier.
+ ProjectionElem::Index(_) => return None,
+ };
+ self.ecx.project(value, elem).ok()?
+ }
+ Address { place, kind, provenance: _ } => {
+ if !place.is_indirect_first_projection() {
+ return None;
+ }
+ let local = self.locals[place.local]?;
+ let pointer = self.evaluated[local].as_ref()?;
+ let mut mplace = self.ecx.deref_pointer(pointer).ok()?;
+ for proj in place.projection.iter().skip(1) {
+ // We have no call stack to associate a local with a value, so we cannot interpret indexing.
+ if matches!(proj, ProjectionElem::Index(_)) {
+ return None;
+ }
+ mplace = self.ecx.project(&mplace, proj).ok()?;
+ }
+ let pointer = mplace.to_ref(&self.ecx);
+ let ty = match kind {
+ AddressKind::Ref(bk) => Ty::new_ref(
+ self.tcx,
+ self.tcx.lifetimes.re_erased,
+ ty::TypeAndMut { ty: mplace.layout.ty, mutbl: bk.to_mutbl_lossy() },
+ ),
+ AddressKind::Address(mutbl) => {
+ Ty::new_ptr(self.tcx, TypeAndMut { ty: mplace.layout.ty, mutbl })
+ }
+ };
+ let layout = self.ecx.layout_of(ty).ok()?;
+ ImmTy::from_immediate(pointer, layout).into()
+ }
+
+ Discriminant(base) => {
+ let base = self.evaluated[base].as_ref()?;
+ let variant = self.ecx.read_discriminant(base).ok()?;
+ let discr_value =
+ self.ecx.discriminant_for_variant(base.layout.ty, variant).ok()?;
+ discr_value.into()
+ }
+ Len(slice) => {
+ let slice = self.evaluated[slice].as_ref()?;
+ let usize_layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
+ let len = slice.len(&self.ecx).ok()?;
+ let imm = ImmTy::try_from_uint(len, usize_layout)?;
+ imm.into()
+ }
+ NullaryOp(null_op, ty) => {
+ let layout = self.ecx.layout_of(ty).ok()?;
+ if let NullOp::SizeOf | NullOp::AlignOf = null_op && layout.is_unsized() {
+ return None;
+ }
+ let val = match null_op {
+ NullOp::SizeOf => layout.size.bytes(),
+ NullOp::AlignOf => layout.align.abi.bytes(),
+ NullOp::OffsetOf(fields) => {
+ layout.offset_of_subfield(&self.ecx, fields.iter()).bytes()
+ }
+ };
+ let usize_layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
+ let imm = ImmTy::try_from_uint(val, usize_layout)?;
+ imm.into()
+ }
+ UnaryOp(un_op, operand) => {
+ let operand = self.evaluated[operand].as_ref()?;
+ let operand = self.ecx.read_immediate(operand).ok()?;
+ let (val, _) = self.ecx.overflowing_unary_op(un_op, &operand).ok()?;
+ val.into()
+ }
+ BinaryOp(bin_op, lhs, rhs) => {
+ let lhs = self.evaluated[lhs].as_ref()?;
+ let lhs = self.ecx.read_immediate(lhs).ok()?;
+ let rhs = self.evaluated[rhs].as_ref()?;
+ let rhs = self.ecx.read_immediate(rhs).ok()?;
+ let (val, _) = self.ecx.overflowing_binary_op(bin_op, &lhs, &rhs).ok()?;
+ val.into()
+ }
+ CheckedBinaryOp(bin_op, lhs, rhs) => {
+ let lhs = self.evaluated[lhs].as_ref()?;
+ let lhs = self.ecx.read_immediate(lhs).ok()?;
+ let rhs = self.evaluated[rhs].as_ref()?;
+ let rhs = self.ecx.read_immediate(rhs).ok()?;
+ let (val, overflowed) = self.ecx.overflowing_binary_op(bin_op, &lhs, &rhs).ok()?;
+ let tuple = Ty::new_tup_from_iter(
+ self.tcx,
+ [val.layout.ty, self.tcx.types.bool].into_iter(),
+ );
+ let tuple = self.ecx.layout_of(tuple).ok()?;
+ ImmTy::from_scalar_pair(val.to_scalar(), Scalar::from_bool(overflowed), tuple)
+ .into()
+ }
+ Cast { kind, value, from: _, to } => match kind {
+ CastKind::IntToInt | CastKind::IntToFloat => {
+ let value = self.evaluated[value].as_ref()?;
+ let value = self.ecx.read_immediate(value).ok()?;
+ let to = self.ecx.layout_of(to).ok()?;
+ let res = self.ecx.int_to_int_or_float(&value, to).ok()?;
+ res.into()
+ }
+ CastKind::FloatToFloat | CastKind::FloatToInt => {
+ let value = self.evaluated[value].as_ref()?;
+ let value = self.ecx.read_immediate(value).ok()?;
+ let to = self.ecx.layout_of(to).ok()?;
+ let res = self.ecx.float_to_float_or_int(&value, to).ok()?;
+ res.into()
+ }
+ CastKind::Transmute => {
+ let value = self.evaluated[value].as_ref()?;
+ let to = self.ecx.layout_of(to).ok()?;
+ // `offset` for immediates only supports scalar/scalar-pair ABIs,
+ // so bail out if the target is not one.
+ if value.as_mplace_or_imm().is_right() {
+ match (value.layout.abi, to.abi) {
+ (Abi::Scalar(..), Abi::Scalar(..)) => {}
+ (Abi::ScalarPair(..), Abi::ScalarPair(..)) => {}
+ _ => return None,
+ }
+ }
+ value.offset(Size::ZERO, to, &self.ecx).ok()?
+ }
+ _ => return None,
+ },
+ };
+ Some(op)
+ }
+
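+ /// Compute the value-number of `value` projected by `proj`. Fields of known aggregates,
+ /// elements of known repeats and constant indices into known arrays are resolved to the
+ /// value-number of the selected element; otherwise a `Value::Projection` node is inserted.
+ /// Returns `None` for projections we cannot reason about, such as a dereference that does
+ /// not go through an immutable `Freeze` reference.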
+ fn project(
+ &mut self,
+ place: PlaceRef<'tcx>,
+ value: VnIndex,
+ proj: PlaceElem<'tcx>,
+ ) -> Option<VnIndex> {
+ let proj = match proj {
+ ProjectionElem::Deref => {
+ let ty = place.ty(self.local_decls, self.tcx).ty;
+ if let Some(Mutability::Not) = ty.ref_mutability()
+ && let Some(pointee_ty) = ty.builtin_deref(true)
+ && pointee_ty.ty.is_freeze(self.tcx, self.param_env)
+ {
+ // An immutable borrow `_x` always points to the same value for the
+ // lifetime of the borrow, so we can merge all instances of `*_x`.
+ ProjectionElem::Deref
+ } else {
+ return None;
+ }
+ }
+ ProjectionElem::Downcast(name, index) => ProjectionElem::Downcast(name, index),
+ ProjectionElem::Field(f, ty) => {
+ if let Value::Aggregate(_, _, fields) = self.get(value) {
+ return Some(fields[f.as_usize()]);
+ } else if let Value::Projection(outer_value, ProjectionElem::Downcast(_, read_variant)) = self.get(value)
+ && let Value::Aggregate(_, written_variant, fields) = self.get(*outer_value)
+ // This pass is not aware of control-flow, so we do not know whether the
+ // replacement we are doing is actually reachable. We could be in any arm of
+ // ```
+ // match Some(x) {
+ // Some(y) => /* stuff */,
+ // None => /* other */,
+ // }
+ // ```
+ //
+ // In surface rust, the current statement would be unreachable.
+ //
+ // However, from the reference chapter on enums and RFC 2195,
+ // accessing the wrong variant is not UB if the enum has an explicit repr.
+ // So it's not impossible for a series of MIR opts to generate
+ // a downcast to an inactive variant.
+ && written_variant == read_variant
+ {
+ return Some(fields[f.as_usize()]);
+ }
+ ProjectionElem::Field(f, ty)
+ }
+ ProjectionElem::Index(idx) => {
+ if let Value::Repeat(inner, _) = self.get(value) {
+ return Some(*inner);
+ }
+ let idx = self.locals[idx]?;
+ ProjectionElem::Index(idx)
+ }
+ ProjectionElem::ConstantIndex { offset, min_length, from_end } => {
+ match self.get(value) {
+ Value::Repeat(inner, _) => {
+ return Some(*inner);
+ }
+ Value::Aggregate(AggregateTy::Array, _, operands) => {
+ let offset = if from_end {
+ operands.len() - offset as usize
+ } else {
+ offset as usize
+ };
+ return operands.get(offset).copied();
+ }
+ _ => {}
+ };
+ ProjectionElem::ConstantIndex { offset, min_length, from_end }
+ }
+ ProjectionElem::Subslice { from, to, from_end } => {
+ ProjectionElem::Subslice { from, to, from_end }
+ }
+ ProjectionElem::OpaqueCast(ty) => ProjectionElem::OpaqueCast(ty),
+ ProjectionElem::Subtype(ty) => ProjectionElem::Subtype(ty),
+ };
+
+ Some(self.insert(Value::Projection(value, proj)))
+ }
+
+ /// Simplify the projection chain if we know better.
+ #[instrument(level = "trace", skip(self))]
+ fn simplify_place_projection(&mut self, place: &mut Place<'tcx>, location: Location) {
+ // If the projection is indirect, we treat the local as a value, so we can replace it with
+ // another local.
+ if place.is_indirect()
+ && let Some(base) = self.locals[place.local]
+ && let Some(new_local) = self.try_as_local(base, location)
+ {
+ place.local = new_local;
+ self.reused_locals.insert(new_local);
+ }
+
+ let mut projection = Cow::Borrowed(&place.projection[..]);
+
+ for i in 0..projection.len() {
+ let elem = projection[i];
+ if let ProjectionElem::Index(idx) = elem
+ && let Some(idx) = self.locals[idx]
+ {
+ if let Some(offset) = self.evaluated[idx].as_ref()
+ && let Ok(offset) = self.ecx.read_target_usize(offset)
+ {
+ projection.to_mut()[i] = ProjectionElem::ConstantIndex {
+ offset,
+ min_length: offset + 1,
+ from_end: false,
+ };
+ } else if let Some(new_idx) = self.try_as_local(idx, location) {
+ projection.to_mut()[i] = ProjectionElem::Index(new_idx);
+ self.reused_locals.insert(new_idx);
+ }
+ }
+ }
+
+ if projection.is_owned() {
+ place.projection = self.tcx.mk_place_elems(&projection);
+ }
+
+ trace!(?place);
+ }
+
/// Represent the *value* which would be read from `place`, and point `place` to a preexisting
/// place with the same value (if that already exists).
#[instrument(level = "trace", skip(self), ret)]
@@ -259,6 +669,8 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
place: &mut Place<'tcx>,
location: Location,
) -> Option<VnIndex> {
+ self.simplify_place_projection(place, location);
+
// Invariant: `place` and `place_ref` point to the same value, even if they point to
// different memory locations.
let mut place_ref = place.as_ref();
@@ -273,57 +685,18 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
place_ref = PlaceRef { local, projection: &place.projection[index..] };
}
- let proj = match proj {
- ProjectionElem::Deref => {
- let ty = Place::ty_from(
- place.local,
- &place.projection[..index],
- self.local_decls,
- self.tcx,
- )
- .ty;
- if let Some(Mutability::Not) = ty.ref_mutability()
- && let Some(pointee_ty) = ty.builtin_deref(true)
- && pointee_ty.ty.is_freeze(self.tcx, self.param_env)
- {
- // An immutable borrow `_x` always points to the same value for the
- // lifetime of the borrow, so we can merge all instances of `*_x`.
- ProjectionElem::Deref
- } else {
- return None;
- }
- }
- ProjectionElem::Field(f, ty) => ProjectionElem::Field(f, ty),
- ProjectionElem::Index(idx) => {
- let idx = self.locals[idx]?;
- ProjectionElem::Index(idx)
- }
- ProjectionElem::ConstantIndex { offset, min_length, from_end } => {
- ProjectionElem::ConstantIndex { offset, min_length, from_end }
- }
- ProjectionElem::Subslice { from, to, from_end } => {
- ProjectionElem::Subslice { from, to, from_end }
- }
- ProjectionElem::Downcast(name, index) => ProjectionElem::Downcast(name, index),
- ProjectionElem::OpaqueCast(ty) => ProjectionElem::OpaqueCast(ty),
- ProjectionElem::Subtype(ty) => ProjectionElem::Subtype(ty),
- };
- value = self.insert(Value::Projection(value, proj));
+ let base = PlaceRef { local: place.local, projection: &place.projection[..index] };
+ value = self.project(base, value, proj)?;
}
- if let Some(local) = self.try_as_local(value, location)
- && local != place.local // in case we had no projection to begin with.
- {
- *place = local.into();
- self.reused_locals.insert(local);
- self.any_replacement = true;
- } else if place_ref.local != place.local
- || place_ref.projection.len() < place.projection.len()
- {
+ if let Some(new_local) = self.try_as_local(value, location) {
+ place_ref = PlaceRef { local: new_local, projection: &[] };
+ }
+
+ if place_ref.local != place.local || place_ref.projection.len() < place.projection.len() {
// By the invariant on `place_ref`.
*place = place_ref.project_deeper(&[], self.tcx);
self.reused_locals.insert(place_ref.local);
- self.any_replacement = true;
}
Some(value)
@@ -336,12 +709,14 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
location: Location,
) -> Option<VnIndex> {
match *operand {
- Operand::Constant(ref constant) => Some(self.insert(Value::Constant(constant.const_))),
+ Operand::Constant(ref mut constant) => {
+ let const_ = constant.const_.normalize(self.tcx, self.param_env);
+ self.insert_constant(const_)
+ }
Operand::Copy(ref mut place) | Operand::Move(ref mut place) => {
let value = self.simplify_place_value(place, location)?;
if let Some(const_) = self.try_as_constant(value) {
*operand = Operand::Constant(Box::new(const_));
- self.any_replacement = true;
}
Some(value)
}
@@ -370,24 +745,15 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
Value::Repeat(op, amount)
}
Rvalue::NullaryOp(op, ty) => Value::NullaryOp(op, ty),
- Rvalue::Aggregate(box ref kind, ref mut fields) => {
- let variant_index = match *kind {
- AggregateKind::Array(..)
- | AggregateKind::Tuple
- | AggregateKind::Closure(..)
- | AggregateKind::Generator(..) => FIRST_VARIANT,
- AggregateKind::Adt(_, variant_index, _, _, None) => variant_index,
- // Do not track unions.
- AggregateKind::Adt(_, _, _, _, Some(_)) => return None,
- };
- let fields: Option<Vec<_>> = fields
- .iter_mut()
- .map(|op| self.simplify_operand(op, location).or_else(|| self.new_opaque()))
- .collect();
- let ty = rvalue.ty(self.local_decls, self.tcx);
- Value::Aggregate(ty, variant_index, fields?)
+ Rvalue::Aggregate(..) => return self.simplify_aggregate(rvalue, location),
+ Rvalue::Ref(_, borrow_kind, ref mut place) => {
+ self.simplify_place_projection(place, location);
+ return self.new_pointer(*place, AddressKind::Ref(borrow_kind));
+ }
+ Rvalue::AddressOf(mutbl, ref mut place) => {
+ self.simplify_place_projection(place, location);
+ return self.new_pointer(*place, AddressKind::Address(mutbl));
}
- Rvalue::Ref(.., place) | Rvalue::AddressOf(_, place) => return self.new_pointer(place),
// Operations.
Rvalue::Len(ref mut place) => {
@@ -397,6 +763,14 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
Rvalue::Cast(kind, ref mut value, to) => {
let from = value.ty(self.local_decls, self.tcx);
let value = self.simplify_operand(value, location)?;
+ if let CastKind::PointerCoercion(
+ PointerCoercion::ReifyFnPointer | PointerCoercion::ClosureFnPointer(_),
+ ) = kind
+ {
+ // Each reification of a generic fn may get a different pointer.
+ // Do not try to merge them.
+ return self.new_opaque();
+ }
Value::Cast { kind, value, from, to }
}
Rvalue::BinaryOp(op, box (ref mut lhs, ref mut rhs)) => {
@@ -415,6 +789,9 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
}
Rvalue::Discriminant(ref mut place) => {
let place = self.simplify_place_value(place, location)?;
+ if let Some(discr) = self.simplify_discriminant(place) {
+ return Some(discr);
+ }
Value::Discriminant(place)
}
@@ -424,45 +801,182 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
debug!(?value);
Some(self.insert(value))
}
+
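+ /// If `place` is known to be an enum aggregate with a statically-known variant, return the
+ /// value-number of its discriminant directly instead of a symbolic `Discriminant` value.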
+ fn simplify_discriminant(&mut self, place: VnIndex) -> Option<VnIndex> {
+ if let Value::Aggregate(enum_ty, variant, _) = *self.get(place)
+ && let AggregateTy::Def(enum_did, enum_substs) = enum_ty
+ && let DefKind::Enum = self.tcx.def_kind(enum_did)
+ {
+ let enum_ty = self.tcx.type_of(enum_did).instantiate(self.tcx, enum_substs);
+ let discr = self.ecx.discriminant_for_variant(enum_ty, variant).ok()?;
+ return Some(self.insert_scalar(discr.to_scalar(), discr.layout.ty));
+ }
+
+ None
+ }
+
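+ /// Value-number an `Rvalue::Aggregate`. Field-less aggregates that must be ZST are folded
+ /// to zero-sized constants, and arrays of more than 4 elements that all share the same
+ /// value-number are rewritten into `Rvalue::Repeat`.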
+ fn simplify_aggregate(
+ &mut self,
+ rvalue: &mut Rvalue<'tcx>,
+ location: Location,
+ ) -> Option<VnIndex> {
+ let Rvalue::Aggregate(box ref kind, ref mut fields) = *rvalue else { bug!() };
+
+ let tcx = self.tcx;
+ if fields.is_empty() {
+ let is_zst = match *kind {
+ AggregateKind::Array(..) | AggregateKind::Tuple | AggregateKind::Closure(..) => {
+ true
+ }
+ // Only enums can be non-ZST.
+ AggregateKind::Adt(did, ..) => tcx.def_kind(did) != DefKind::Enum,
+ // Coroutines are never ZST, as they at least contain the implicit states.
+ AggregateKind::Coroutine(..) => false,
+ };
+
+ if is_zst {
+ let ty = rvalue.ty(self.local_decls, tcx);
+ return self.insert_constant(Const::zero_sized(ty));
+ }
+ }
+
+ let (ty, variant_index) = match *kind {
+ AggregateKind::Array(..) => {
+ assert!(!fields.is_empty());
+ (AggregateTy::Array, FIRST_VARIANT)
+ }
+ AggregateKind::Tuple => {
+ assert!(!fields.is_empty());
+ (AggregateTy::Tuple, FIRST_VARIANT)
+ }
+ AggregateKind::Closure(did, substs) | AggregateKind::Coroutine(did, substs, _) => {
+ (AggregateTy::Def(did, substs), FIRST_VARIANT)
+ }
+ AggregateKind::Adt(did, variant_index, substs, _, None) => {
+ (AggregateTy::Def(did, substs), variant_index)
+ }
+ // Do not track unions.
+ AggregateKind::Adt(_, _, _, _, Some(_)) => return None,
+ };
+
+ let fields: Option<Vec<_>> = fields
+ .iter_mut()
+ .map(|op| self.simplify_operand(op, location).or_else(|| self.new_opaque()))
+ .collect();
+ let fields = fields?;
+
+ if let AggregateTy::Array = ty && fields.len() > 4 {
+ let first = fields[0];
+ if fields.iter().all(|&v| v == first) {
+ let len = ty::Const::from_target_usize(self.tcx, fields.len().try_into().unwrap());
+ if let Some(const_) = self.try_as_constant(first) {
+ *rvalue = Rvalue::Repeat(Operand::Constant(Box::new(const_)), len);
+ } else if let Some(local) = self.try_as_local(first, location) {
+ *rvalue = Rvalue::Repeat(Operand::Copy(local.into()), len);
+ self.reused_locals.insert(local);
+ }
+ return Some(self.insert(Value::Repeat(first, len)));
+ }
+ }
+
+ Some(self.insert(Value::Aggregate(ty, variant_index, fields)))
+ }
+}
+
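+/// Try to convert an interpreter value into a `ConstValue` that is cheap for codegen to
+/// materialize. Returns `None` for unsized values, for non-ZST values without a scalar or
+/// scalar-pair ABI, and for allocations whose contents carry provenance.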
+fn op_to_prop_const<'tcx>(
+ ecx: &mut InterpCx<'_, 'tcx, DummyMachine>,
+ op: &OpTy<'tcx>,
+) -> Option<ConstValue<'tcx>> {
+ // Do not attempt to propagate unsized locals.
+ if op.layout.is_unsized() {
+ return None;
+ }
+
+ // This constant is a ZST, just return an empty value.
+ if op.layout.is_zst() {
+ return Some(ConstValue::ZeroSized);
+ }
+
+ // Do not synthesize too large constants. Codegen will just memcpy them, which we'd like to avoid.
+ if !matches!(op.layout.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) {
+ return None;
+ }
+
+ // If this constant has scalar ABI, return it as a `ConstValue::Scalar`.
+ if let Abi::Scalar(abi::Scalar::Initialized { .. }) = op.layout.abi
+ && let Ok(scalar) = ecx.read_scalar(op)
+ && scalar.try_to_int().is_ok()
+ {
+ return Some(ConstValue::Scalar(scalar));
+ }
+
+ // If this constant is already represented as an `Allocation`,
+ // try putting it into global memory to return it.
+ if let Either::Left(mplace) = op.as_mplace_or_imm() {
+ let (size, _align) = ecx.size_and_align_of_mplace(&mplace).ok()??;
+
+ // Do not try interning a value that contains provenance.
+ // Due to https://github.com/rust-lang/rust/issues/79738, doing so could lead to bugs.
+ // FIXME: remove this hack once that issue is fixed.
+ let alloc_ref = ecx.get_ptr_alloc(mplace.ptr(), size).ok()??;
+ if alloc_ref.has_provenance() {
+ return None;
+ }
+
+ let pointer = mplace.ptr().into_pointer_or_addr().ok()?;
+ let (alloc_id, offset) = pointer.into_parts();
+ intern_const_alloc_for_constprop(ecx, alloc_id).ok()?;
+ if matches!(ecx.tcx.global_alloc(alloc_id), GlobalAlloc::Memory(_)) {
+ // `alloc_id` may point to a static. Codegen will choke on an `Indirect` with anything
+ // but `GlobalAlloc::Memory`, so do fall through to copying if needed.
+ // FIXME: find a way to treat this more uniformly
+ // (probably by fixing codegen)
+ return Some(ConstValue::Indirect { alloc_id, offset });
+ }
+ }
+
+ // Everything failed: create a new allocation to hold the data.
+ let alloc_id =
+ ecx.intern_with_temp_alloc(op.layout, |ecx, dest| ecx.copy_op(op, dest, false)).ok()?;
+ let value = ConstValue::Indirect { alloc_id, offset: Size::ZERO };
+
+ // Check that we do not leak a pointer.
+ // Those pointers may lose part of their identity in codegen.
+ // FIXME: remove this hack once https://github.com/rust-lang/rust/issues/79738 is fixed.
+ if ecx.tcx.global_alloc(alloc_id).unwrap_memory().inner().provenance().ptrs().is_empty() {
+ return Some(value);
+ }
+
+ None
}
impl<'tcx> VnState<'_, 'tcx> {
/// If `index` is a `Value::Constant`, return the `Constant` to be put in the MIR.
fn try_as_constant(&mut self, index: VnIndex) -> Option<ConstOperand<'tcx>> {
- if let Value::Constant(const_) = *self.get(index) {
- // Some constants may contain pointers. We need to preserve the provenance of these
- // pointers, but not all constants guarantee this:
- // - valtrees purposefully do not;
- // - ConstValue::Slice does not either.
- match const_ {
- Const::Ty(c) => match c.kind() {
- ty::ConstKind::Value(valtree) => match valtree {
- // This is just an integer, keep it.
- ty::ValTree::Leaf(_) => {}
- ty::ValTree::Branch(_) => return None,
- },
- ty::ConstKind::Param(..)
- | ty::ConstKind::Unevaluated(..)
- | ty::ConstKind::Expr(..) => {}
- // Should not appear in runtime MIR.
- ty::ConstKind::Infer(..)
- | ty::ConstKind::Bound(..)
- | ty::ConstKind::Placeholder(..)
- | ty::ConstKind::Error(..) => bug!(),
- },
- Const::Unevaluated(..) => {}
- // If the same slice appears twice in the MIR, we cannot guarantee that we will
- // give the same `AllocId` to the data.
- Const::Val(ConstValue::Slice { .. }, _) => return None,
- Const::Val(
- ConstValue::ZeroSized | ConstValue::Scalar(_) | ConstValue::Indirect { .. },
- _,
- ) => {}
- }
- Some(ConstOperand { span: rustc_span::DUMMY_SP, user_ty: None, const_ })
- } else {
- None
+ // This was already constant in MIR, do not change it.
+ if let Value::Constant { value, disambiguator: _ } = *self.get(index)
+ // If the constant is not deterministic, adding an additional mention of it in MIR will
+ // not give the same value as the former mention.
+ && value.is_deterministic()
+ {
+ return Some(ConstOperand { span: rustc_span::DUMMY_SP, user_ty: None, const_: value });
+ }
+
+ let op = self.evaluated[index].as_ref()?;
+ if op.layout.is_unsized() {
+ // Do not attempt to propagate unsized locals.
+ return None;
}
+
+ let value = op_to_prop_const(&mut self.ecx, op)?;
+
+ // Check that we do not leak a pointer.
+ // Those pointers may lose part of their identity in codegen.
+ // FIXME: remove this hack once https://github.com/rust-lang/rust/issues/79738 is fixed.
+ assert!(!value.may_have_provenance(self.tcx, op.layout.size));
+
+ let const_ = Const::Val(value, op.layout.ty);
+ Some(ConstOperand { span: rustc_span::DUMMY_SP, user_ty: None, const_ })
}
/// If there is a local which is assigned `index`, and its assignment strictly dominates `loc`,
@@ -481,27 +995,32 @@ impl<'tcx> MutVisitor<'tcx> for VnState<'_, 'tcx> {
self.tcx
}
+ fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, location: Location) {
+ self.simplify_place_projection(place, location);
+ }
+
fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) {
self.simplify_operand(operand, location);
}
fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, location: Location) {
- self.super_statement(stmt, location);
if let StatementKind::Assign(box (_, ref mut rvalue)) = stmt.kind
// Do not try to simplify a constant, it's already in canonical shape.
&& !matches!(rvalue, Rvalue::Use(Operand::Constant(_)))
- && let Some(value) = self.simplify_rvalue(rvalue, location)
{
- if let Some(const_) = self.try_as_constant(value) {
- *rvalue = Rvalue::Use(Operand::Constant(Box::new(const_)));
- self.any_replacement = true;
- } else if let Some(local) = self.try_as_local(value, location)
- && *rvalue != Rvalue::Use(Operand::Move(local.into()))
+ if let Some(value) = self.simplify_rvalue(rvalue, location)
{
- *rvalue = Rvalue::Use(Operand::Copy(local.into()));
- self.reused_locals.insert(local);
- self.any_replacement = true;
+ if let Some(const_) = self.try_as_constant(value) {
+ *rvalue = Rvalue::Use(Operand::Constant(Box::new(const_)));
+ } else if let Some(local) = self.try_as_local(value, location)
+ && *rvalue != Rvalue::Use(Operand::Move(local.into()))
+ {
+ *rvalue = Rvalue::Use(Operand::Copy(local.into()));
+ self.reused_locals.insert(local);
+ }
}
+ } else {
+ self.super_statement(stmt, location);
}
}
}
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index b53e0852c..793dcf0d9 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -14,6 +14,7 @@ use rustc_session::config::OptLevel;
use rustc_target::abi::FieldIdx;
use rustc_target::spec::abi::Abi;
+use crate::cost_checker::CostChecker;
use crate::simplify::{remove_dead_blocks, CfgSimplifier};
use crate::util;
use crate::MirPass;
@@ -22,11 +23,6 @@ use std::ops::{Range, RangeFrom};
pub(crate) mod cycle;
-const INSTR_COST: usize = 5;
-const CALL_PENALTY: usize = 25;
-const LANDINGPAD_PENALTY: usize = 50;
-const RESUME_PENALTY: usize = 45;
-
const TOP_DOWN_DEPTH_LIMIT: usize = 5;
pub struct Inline;
@@ -63,7 +59,7 @@ impl<'tcx> MirPass<'tcx> for Inline {
if inline(tcx, body) {
debug!("running simplify cfg on {:?}", body.source);
CfgSimplifier::new(body).simplify();
- remove_dead_blocks(tcx, body);
+ remove_dead_blocks(body);
deref_finder(tcx, body);
}
}
@@ -79,10 +75,10 @@ fn inline<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool {
if body.source.promoted.is_some() {
return false;
}
- // Avoid inlining into generators, since their `optimized_mir` is used for layout computation,
+ // Avoid inlining into coroutines, since their `optimized_mir` is used for layout computation,
// which can create a cycle, even when no attempt is made to inline the function in the other
// direction.
- if body.generator.is_some() {
+ if body.coroutine.is_some() {
return false;
}
@@ -169,8 +165,11 @@ impl<'tcx> Inliner<'tcx> {
caller_body: &mut Body<'tcx>,
callsite: &CallSite<'tcx>,
) -> Result<std::ops::Range<BasicBlock>, &'static str> {
+ self.check_mir_is_available(caller_body, &callsite.callee)?;
+
let callee_attrs = self.tcx.codegen_fn_attrs(callsite.callee.def_id());
- self.check_codegen_attributes(callsite, callee_attrs)?;
+ let cross_crate_inlinable = self.tcx.cross_crate_inlinable(callsite.callee.def_id());
+ self.check_codegen_attributes(callsite, callee_attrs, cross_crate_inlinable)?;
let terminator = caller_body[callsite.block].terminator.as_ref().unwrap();
let TerminatorKind::Call { args, destination, .. } = &terminator.kind else { bug!() };
@@ -183,9 +182,8 @@ impl<'tcx> Inliner<'tcx> {
}
}
- self.check_mir_is_available(caller_body, &callsite.callee)?;
let callee_body = try_instance_mir(self.tcx, callsite.callee.def)?;
- self.check_mir_body(callsite, callee_body, callee_attrs)?;
+ self.check_mir_body(callsite, callee_body, callee_attrs, cross_crate_inlinable)?;
if !self.tcx.consider_optimizing(|| {
format!("Inline {:?} into {:?}", callsite.callee, caller_body.source)
@@ -401,6 +399,7 @@ impl<'tcx> Inliner<'tcx> {
&self,
callsite: &CallSite<'tcx>,
callee_attrs: &CodegenFnAttrs,
+ cross_crate_inlinable: bool,
) -> Result<(), &'static str> {
if let InlineAttr::Never = callee_attrs.inline {
return Err("never inline hint");
@@ -414,7 +413,7 @@ impl<'tcx> Inliner<'tcx> {
.non_erasable_generics(self.tcx, callsite.callee.def_id())
.next()
.is_some();
- if !is_generic && !callee_attrs.requests_inline() {
+ if !is_generic && !cross_crate_inlinable {
return Err("not exported");
}
@@ -439,10 +438,13 @@ impl<'tcx> Inliner<'tcx> {
return Err("incompatible instruction set");
}
- for feature in &callee_attrs.target_features {
- if !self.codegen_fn_attrs.target_features.contains(feature) {
- return Err("incompatible target feature");
- }
+ if callee_attrs.target_features != self.codegen_fn_attrs.target_features {
+ // In general it is not correct to inline a callee with target features that are a
+ // subset of the caller. This is because the callee might contain calls, and the ABI of
+ // those calls depends on the target features of the surrounding function. By moving a
+ // `Call` terminator from one MIR body to another with more target features, we might
+ // change the ABI of that call!
+ return Err("incompatible target features");
}
Ok(())
@@ -456,10 +458,11 @@ impl<'tcx> Inliner<'tcx> {
callsite: &CallSite<'tcx>,
callee_body: &Body<'tcx>,
callee_attrs: &CodegenFnAttrs,
+ cross_crate_inlinable: bool,
) -> Result<(), &'static str> {
let tcx = self.tcx;
- let mut threshold = if callee_attrs.requests_inline() {
+ let mut threshold = if cross_crate_inlinable {
self.tcx.sess.opts.unstable_opts.inline_mir_hint_threshold.unwrap_or(100)
} else {
self.tcx.sess.opts.unstable_opts.inline_mir_threshold.unwrap_or(50)
@@ -475,13 +478,8 @@ impl<'tcx> Inliner<'tcx> {
// FIXME: Give a bonus to functions with only a single caller
- let mut checker = CostChecker {
- tcx: self.tcx,
- param_env: self.param_env,
- instance: callsite.callee,
- callee_body,
- cost: 0,
- };
+ let mut checker =
+ CostChecker::new(self.tcx, self.param_env, Some(callsite.callee), callee_body);
// Traverse the MIR manually so we can account for the effects of inlining on the CFG.
let mut work_list = vec![START_BLOCK];
@@ -503,7 +501,9 @@ impl<'tcx> Inliner<'tcx> {
self.tcx,
ty::EarlyBinder::bind(&place.ty(callee_body, tcx).ty),
);
- if ty.needs_drop(tcx, self.param_env) && let UnwindAction::Cleanup(unwind) = unwind {
+ if ty.needs_drop(tcx, self.param_env)
+ && let UnwindAction::Cleanup(unwind) = unwind
+ {
work_list.push(unwind);
}
} else if callee_attrs.instruction_set != self.codegen_fn_attrs.instruction_set
@@ -524,7 +524,7 @@ impl<'tcx> Inliner<'tcx> {
// That attribute is often applied to very large functions that exceed LLVM's (very
// generous) inlining threshold. Such functions are very poor MIR inlining candidates.
// Always inlining #[inline(always)] functions in MIR, on net, slows down the compiler.
- let cost = checker.cost;
+ let cost = checker.cost();
if cost <= threshold {
debug!("INLINING {:?} [cost={} <= threshold={}]", callsite, cost, threshold);
Ok(())
@@ -616,9 +616,7 @@ impl<'tcx> Inliner<'tcx> {
// If there are any locals without storage markers, give them storage only for the
// duration of the call.
for local in callee_body.vars_and_temps_iter() {
- if !callee_body.local_decls[local].internal
- && integrator.always_live_locals.contains(local)
- {
+ if integrator.always_live_locals.contains(local) {
let new_local = integrator.map_local(local);
caller_body[callsite.block].statements.push(Statement {
source_info: callsite.source_info,
@@ -641,9 +639,7 @@ impl<'tcx> Inliner<'tcx> {
n += 1;
}
for local in callee_body.vars_and_temps_iter().rev() {
- if !callee_body.local_decls[local].internal
- && integrator.always_live_locals.contains(local)
- {
+ if integrator.always_live_locals.contains(local) {
let new_local = integrator.map_local(local);
caller_body[block].statements.push(Statement {
source_info: callsite.source_info,
@@ -801,79 +797,6 @@ impl<'tcx> Inliner<'tcx> {
}
}
-/// Verify that the callee body is compatible with the caller.
-///
-/// This visitor mostly computes the inlining cost,
-/// but also needs to verify that types match because of normalization failure.
-struct CostChecker<'b, 'tcx> {
- tcx: TyCtxt<'tcx>,
- param_env: ParamEnv<'tcx>,
- cost: usize,
- callee_body: &'b Body<'tcx>,
- instance: ty::Instance<'tcx>,
-}
-
-impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
- fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) {
- // Don't count StorageLive/StorageDead in the inlining cost.
- match statement.kind {
- StatementKind::StorageLive(_)
- | StatementKind::StorageDead(_)
- | StatementKind::Deinit(_)
- | StatementKind::Nop => {}
- _ => self.cost += INSTR_COST,
- }
- }
-
- fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, _: Location) {
- let tcx = self.tcx;
- match terminator.kind {
- TerminatorKind::Drop { ref place, unwind, .. } => {
- // If the place doesn't actually need dropping, treat it like a regular goto.
- let ty = self.instance.instantiate_mir(
- tcx,
- ty::EarlyBinder::bind(&place.ty(self.callee_body, tcx).ty),
- );
- if ty.needs_drop(tcx, self.param_env) {
- self.cost += CALL_PENALTY;
- if let UnwindAction::Cleanup(_) = unwind {
- self.cost += LANDINGPAD_PENALTY;
- }
- } else {
- self.cost += INSTR_COST;
- }
- }
- TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. } => {
- let fn_ty =
- self.instance.instantiate_mir(tcx, ty::EarlyBinder::bind(&f.const_.ty()));
- self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() && tcx.is_intrinsic(def_id) {
- // Don't give intrinsics the extra penalty for calls
- INSTR_COST
- } else {
- CALL_PENALTY
- };
- if let UnwindAction::Cleanup(_) = unwind {
- self.cost += LANDINGPAD_PENALTY;
- }
- }
- TerminatorKind::Assert { unwind, .. } => {
- self.cost += CALL_PENALTY;
- if let UnwindAction::Cleanup(_) = unwind {
- self.cost += LANDINGPAD_PENALTY;
- }
- }
- TerminatorKind::UnwindResume => self.cost += RESUME_PENALTY,
- TerminatorKind::InlineAsm { unwind, .. } => {
- self.cost += INSTR_COST;
- if let UnwindAction::Cleanup(_) = unwind {
- self.cost += LANDINGPAD_PENALTY;
- }
- }
- _ => self.cost += INSTR_COST,
- }
- }
-}
-
/**
* Integrator.
*
@@ -1010,7 +933,7 @@ impl<'tcx> MutVisitor<'tcx> for Integrator<'_, 'tcx> {
}
match terminator.kind {
- TerminatorKind::GeneratorDrop | TerminatorKind::Yield { .. } => bug!(),
+ TerminatorKind::CoroutineDrop | TerminatorKind::Yield { .. } => bug!(),
TerminatorKind::Goto { ref mut target } => {
*target = self.map_block(*target);
}
diff --git a/compiler/rustc_mir_transform/src/instsimplify.rs b/compiler/rustc_mir_transform/src/instsimplify.rs
index a6ef2e11a..fbcd6e75a 100644
--- a/compiler/rustc_mir_transform/src/instsimplify.rs
+++ b/compiler/rustc_mir_transform/src/instsimplify.rs
@@ -93,7 +93,9 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> {
_ => None,
};
- if let Some(new) = new && self.should_simplify(source_info, rvalue) {
+ if let Some(new) = new
+ && self.should_simplify(source_info, rvalue)
+ {
*rvalue = new;
}
}
@@ -150,7 +152,8 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> {
*rvalue = Rvalue::Use(operand.clone());
} else if *kind == CastKind::Transmute {
// Transmuting an integer to another integer is just a signedness cast
- if let (ty::Int(int), ty::Uint(uint)) | (ty::Uint(uint), ty::Int(int)) = (operand_ty.kind(), cast_ty.kind())
+ if let (ty::Int(int), ty::Uint(uint)) | (ty::Uint(uint), ty::Int(int)) =
+ (operand_ty.kind(), cast_ty.kind())
&& int.bit_width() == uint.bit_width()
{
// The width check isn't strictly necessary, as different widths
@@ -172,8 +175,15 @@ impl<'tcx> InstSimplifyContext<'tcx, '_> {
for (i, field) in variant.fields.iter().enumerate() {
let field_ty = field.ty(self.tcx, args);
if field_ty == *cast_ty {
- let place = place.project_deeper(&[ProjectionElem::Field(FieldIdx::from_usize(i), *cast_ty)], self.tcx);
- let operand = if operand.is_move() { Operand::Move(place) } else { Operand::Copy(place) };
+ let place = place.project_deeper(
+ &[ProjectionElem::Field(FieldIdx::from_usize(i), *cast_ty)],
+ self.tcx,
+ );
+ let operand = if operand.is_move() {
+ Operand::Move(place)
+ } else {
+ Operand::Copy(place)
+ };
*rvalue = Rvalue::Use(operand);
return;
}
diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs
new file mode 100644
index 000000000..7b918be44
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/jump_threading.rs
@@ -0,0 +1,759 @@
+//! A jump threading optimization.
+//!
+//! This optimization seeks to replace join-then-switch control flow patterns by straight jumps
+//! X = 0 X = 0
+//! ------------\ /-------- ------------
+//! X = 1 X----X SwitchInt(X) => X = 1
+//! ------------/ \-------- ------------
+//!
+//!
+//! We proceed by walking the cfg backwards starting from each `SwitchInt` terminator,
+//! looking for assignments that will turn the `SwitchInt` into a simple `Goto`.
+//!
+//! The algorithm maintains a set of replacement conditions:
+//! - `conditions[place]` contains `Condition { value, polarity: Eq, target }`
+//! if assigning `value` to `place` turns the `SwitchInt` into `Goto { target }`.
+//! - `conditions[place]` contains `Condition { value, polarity: Ne, target }`
+//! if assigning anything different from `value` to `place` turns the `SwitchInt`
+//! into `Goto { target }`.
+//!
+//! In this file, we denote as `place ?= value` the existence of a replacement condition
+//! on `place` with given `value`, irrespective of the polarity and target of that
+//! replacement condition.
+//!
+//! We then walk the CFG backwards transforming the set of conditions.
+//! When we find a fulfilling assignment, we record a `ThreadingOpportunity`.
+//! All `ThreadingOpportunity`s are applied to the body, by duplicating blocks if required.
+//!
+//! The optimization search can be very heavy, as it performs a DFS on MIR starting from
+//! each `SwitchInt` terminator. To manage the complexity, we:
+//! - bound the maximum depth by a constant `MAX_BACKTRACK`;
+//! - we only traverse `Goto` terminators.
+//!
+//! We try to avoid creating irreducible control-flow by not threading through a loop header.
+//!
+//! Likewise, applying the optimization can create a lot of new MIR, so we bound the instruction
+//! cost by `MAX_COST`.
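+//!
+//! As a small made-up illustration (block and local names here are hypothetical):
+//!
+//! bb0: _1 = 0; goto -> bb2;
+//! bb1: _1 = 1; goto -> bb2;
+//! bb2: switchInt(_1) -> [0: bb3, otherwise: bb4];
+//!
+//! Starting from the `SwitchInt` in bb2, the conditions are `_1 == 0 => bb3` (polarity `Eq`)
+//! and `_1 != 0 => bb4` (polarity `Ne`). Walking backwards, the assignment `_1 = 0` in bb0
+//! fulfills the first condition and `_1 = 1` in bb1 fulfills the second, so both edges into
+//! bb2 can be redirected to their known targets, duplicating bb2 while it still has other
+//! predecessors.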
+
+use rustc_arena::DroplessArena;
+use rustc_data_structures::fx::FxHashSet;
+use rustc_index::bit_set::BitSet;
+use rustc_index::IndexVec;
+use rustc_middle::mir::visit::Visitor;
+use rustc_middle::mir::*;
+use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
+use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
+
+use crate::cost_checker::CostChecker;
+use crate::MirPass;
+
+pub struct JumpThreading;
+
+const MAX_BACKTRACK: usize = 5;
+const MAX_COST: usize = 100;
+const MAX_PLACES: usize = 100;
+
+impl<'tcx> MirPass<'tcx> for JumpThreading {
+ fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+ sess.mir_opt_level() >= 4
+ }
+
+ #[instrument(skip_all level = "debug")]
+ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ let def_id = body.source.def_id();
+ debug!(?def_id);
+
+ let param_env = tcx.param_env_reveal_all_normalized(def_id);
+ let map = Map::new(tcx, body, Some(MAX_PLACES));
+ let loop_headers = loop_headers(body);
+
+ let arena = DroplessArena::default();
+ let mut finder = TOFinder {
+ tcx,
+ param_env,
+ body,
+ arena: &arena,
+ map: &map,
+ loop_headers: &loop_headers,
+ opportunities: Vec::new(),
+ };
+
+ for (bb, bbdata) in body.basic_blocks.iter_enumerated() {
+ debug!(?bb, term = ?bbdata.terminator());
+ if bbdata.is_cleanup || loop_headers.contains(bb) {
+ continue;
+ }
+ let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { continue };
+ let Some(discr) = discr.place() else { continue };
+ debug!(?discr, ?bb);
+
+ let discr_ty = discr.ty(body, tcx).ty;
+ let Ok(discr_layout) = tcx.layout_of(param_env.and(discr_ty)) else { continue };
+
+ let Some(discr) = finder.map.find(discr.as_ref()) else { continue };
+ debug!(?discr);
+
+ let cost = CostChecker::new(tcx, param_env, None, body);
+
+ let mut state = State::new(ConditionSet::default(), &finder.map);
+
+ let conds = if let Some((value, then, else_)) = targets.as_static_if() {
+ let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else {
+ continue;
+ };
+ arena.alloc_from_iter([
+ Condition { value, polarity: Polarity::Eq, target: then },
+ Condition { value, polarity: Polarity::Ne, target: else_ },
+ ])
+ } else {
+ arena.alloc_from_iter(targets.iter().filter_map(|(value, target)| {
+ let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
+ Some(Condition { value, polarity: Polarity::Eq, target })
+ }))
+ };
+ let conds = ConditionSet(conds);
+ state.insert_value_idx(discr, conds, &finder.map);
+
+ finder.find_opportunity(bb, state, cost, 0);
+ }
+
+ let opportunities = finder.opportunities;
+ debug!(?opportunities);
+ if opportunities.is_empty() {
+ return;
+ }
+
+ // Verify that we do not thread through a loop header.
+ for to in opportunities.iter() {
+ assert!(to.chain.iter().all(|&block| !loop_headers.contains(block)));
+ }
+ OpportunitySet::new(body, opportunities).apply(body);
+ }
+}
+
+#[derive(Debug)]
+struct ThreadingOpportunity {
+ /// The list of `BasicBlock`s from the one that found the opportunity to the `SwitchInt`.
+ chain: Vec<BasicBlock>,
+ /// The `SwitchInt` will be replaced by `Goto { target }`.
+ target: BasicBlock,
+}
+
+struct TOFinder<'tcx, 'a> {
+ tcx: TyCtxt<'tcx>,
+ param_env: ty::ParamEnv<'tcx>,
+ body: &'a Body<'tcx>,
+ map: &'a Map,
+ loop_headers: &'a BitSet<BasicBlock>,
+ /// We use an arena to avoid cloning the slices when cloning `state`.
+ arena: &'a DroplessArena,
+ opportunities: Vec<ThreadingOpportunity>,
+}
+
+/// Represent the following statement: if we can prove that the current local is equal/not-equal
+/// to `value`, jump to `target`.
+#[derive(Copy, Clone, Debug)]
+struct Condition {
+ value: ScalarInt,
+ polarity: Polarity,
+ target: BasicBlock,
+}
+
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+enum Polarity {
+ Ne,
+ Eq,
+}
+
+impl Condition {
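+ /// Return whether observing `value` fulfills this condition: equality for `Eq` polarity,
+ /// inequality for `Ne`.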
+ fn matches(&self, value: ScalarInt) -> bool {
+ (self.value == value) == (self.polarity == Polarity::Eq)
+ }
+
+ fn inv(mut self) -> Self {
+ self.polarity = match self.polarity {
+ Polarity::Eq => Polarity::Ne,
+ Polarity::Ne => Polarity::Eq,
+ };
+ self
+ }
+}
+
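+/// The set of replacement conditions attached to a single tracked place. The slice lives in the
+/// pass-local arena so that cloning the dataflow `State` stays cheap.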
+#[derive(Copy, Clone, Debug, Default)]
+struct ConditionSet<'a>(&'a [Condition]);
+
+impl<'a> ConditionSet<'a> {
+ fn iter(self) -> impl Iterator<Item = Condition> + 'a {
+ self.0.iter().copied()
+ }
+
+ fn iter_matches(self, value: ScalarInt) -> impl Iterator<Item = Condition> + 'a {
+ self.iter().filter(move |c| c.matches(value))
+ }
+
+ fn map(self, arena: &'a DroplessArena, f: impl Fn(Condition) -> Condition) -> ConditionSet<'a> {
+ ConditionSet(arena.alloc_from_iter(self.iter().map(f)))
+ }
+}
+
+impl<'tcx, 'a> TOFinder<'tcx, 'a> {
+ fn is_empty(&self, state: &State<ConditionSet<'a>>) -> bool {
+ state.all(|cs| cs.0.is_empty())
+ }
+
+ /// Recursion entry point to find threading opportunities.
+ #[instrument(level = "trace", skip(self, cost), ret)]
+ fn find_opportunity(
+ &mut self,
+ bb: BasicBlock,
+ mut state: State<ConditionSet<'a>>,
+ mut cost: CostChecker<'_, 'tcx>,
+ depth: usize,
+ ) {
+ // Do not thread through loop headers.
+ if self.loop_headers.contains(bb) {
+ return;
+ }
+
+ debug!(cost = ?cost.cost());
+ for (statement_index, stmt) in
+ self.body.basic_blocks[bb].statements.iter().enumerate().rev()
+ {
+ if self.is_empty(&state) {
+ return;
+ }
+
+ cost.visit_statement(stmt, Location { block: bb, statement_index });
+ if cost.cost() > MAX_COST {
+ return;
+ }
+
+ // Attempt to turn the `current_condition` on `lhs` into a condition on another place.
+ self.process_statement(bb, stmt, &mut state);
+
+ // When a statement mutates a place, assignments to that place that happen
+ // above the mutation cannot fulfill a condition.
+ // _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`.
+ // _1 = 6
+ if let Some((lhs, tail)) = self.mutated_statement(stmt) {
+ state.flood_with_tail_elem(lhs.as_ref(), tail, self.map, ConditionSet::default());
+ }
+ }
+
+ if self.is_empty(&state) || depth >= MAX_BACKTRACK {
+ return;
+ }
+
+ let last_non_rec = self.opportunities.len();
+
+ let predecessors = &self.body.basic_blocks.predecessors()[bb];
+ if let &[pred] = &predecessors[..] && bb != START_BLOCK {
+ let term = self.body.basic_blocks[pred].terminator();
+ match term.kind {
+ TerminatorKind::SwitchInt { ref discr, ref targets } => {
+ self.process_switch_int(discr, targets, bb, &mut state);
+ self.find_opportunity(pred, state, cost, depth + 1);
+ }
+ _ => self.recurse_through_terminator(pred, &state, &cost, depth),
+ }
+ } else {
+ for &pred in predecessors {
+ self.recurse_through_terminator(pred, &state, &cost, depth);
+ }
+ }
+
+ let new_tos = &mut self.opportunities[last_non_rec..];
+ debug!(?new_tos);
+
+ // Try to deduplicate threading opportunities.
+ if new_tos.len() > 1
+ && new_tos.len() == predecessors.len()
+ && predecessors
+ .iter()
+ .zip(new_tos.iter())
+ .all(|(&pred, to)| to.chain == &[pred] && to.target == new_tos[0].target)
+ {
+ // All predecessors have a threading opportunity, and they all point to the same block.
+ debug!(?new_tos, "dedup");
+ let first = &mut new_tos[0];
+ *first = ThreadingOpportunity { chain: vec![bb], target: first.target };
+ self.opportunities.truncate(last_non_rec + 1);
+ return;
+ }
+
+ for op in self.opportunities[last_non_rec..].iter_mut() {
+ op.chain.push(bb);
+ }
+ }
+
+ /// Extract the mutated place from a statement.
+ ///
+ /// This method returns the `Place` so we can flood the state in case of a partial assignment.
+ /// (_1 as Ok).0 = _5;
+ /// (_1 as Err).0 = _6;
+ /// We want to ensure that a `SwitchInt((_1 as Ok).0)` does not see the first assignment, as
+ /// the value may have been mangled by the second assignment.
+ ///
+ /// In case we assign to a discriminant, we return `Some(TrackElem::Discriminant)`, so we
+ /// only flood the discriminant and preserve the variant fields.
+ /// (_1 as Some).0 = _6;
+ /// SetDiscriminant(_1, 1);
+ /// switchInt((_1 as Some).0)
+ #[instrument(level = "trace", skip(self), ret)]
+ fn mutated_statement(
+ &self,
+ stmt: &Statement<'tcx>,
+ ) -> Option<(Place<'tcx>, Option<TrackElem>)> {
+ match stmt.kind {
+ StatementKind::Assign(box (place, _))
+ | StatementKind::Deinit(box place) => Some((place, None)),
+ StatementKind::SetDiscriminant { box place, variant_index: _ } => {
+ Some((place, Some(TrackElem::Discriminant)))
+ }
+ StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
+ Some((Place::from(local), None))
+ }
+ StatementKind::Retag(..)
+ | StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(..))
+ // copy_nonoverlapping takes pointers and mutates the pointed-to value.
+ | StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(..))
+ | StatementKind::AscribeUserType(..)
+ | StatementKind::Coverage(..)
+ | StatementKind::FakeRead(..)
+ | StatementKind::ConstEvalCounter
+ | StatementKind::PlaceMention(..)
+ | StatementKind::Nop => None,
+ }
+ }
+
+ #[instrument(level = "trace", skip(self))]
+ fn process_operand(
+ &mut self,
+ bb: BasicBlock,
+ lhs: PlaceIndex,
+ rhs: &Operand<'tcx>,
+ state: &mut State<ConditionSet<'a>>,
+ ) -> Option<!> {
+ let register_opportunity = |c: Condition| {
+ debug!(?bb, ?c.target, "register");
+ self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
+ };
+
+ match rhs {
+ // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
+ Operand::Constant(constant) => {
+ let conditions = state.try_get_idx(lhs, self.map)?;
+ let constant =
+ constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
+ conditions.iter_matches(constant).for_each(register_opportunity);
+ }
+ // Transfer the conditions on the copied rhs.
+ Operand::Move(rhs) | Operand::Copy(rhs) => {
+ let rhs = self.map.find(rhs.as_ref())?;
+ state.insert_place_idx(rhs, lhs, self.map);
+ }
+ }
+
+ None
+ }
+
+ #[instrument(level = "trace", skip(self))]
+ fn process_statement(
+ &mut self,
+ bb: BasicBlock,
+ stmt: &Statement<'tcx>,
+ state: &mut State<ConditionSet<'a>>,
+ ) -> Option<!> {
+ let register_opportunity = |c: Condition| {
+ debug!(?bb, ?c.target, "register");
+ self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
+ };
+
+ // Below, `lhs` is the return value of `mutated_statement`,
+ // the place to which `conditions` apply.
+
+ let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| {
+ let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
+ let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
+ let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?;
+ Some(Operand::const_from_scalar(
+ self.tcx,
+ discr.ty,
+ scalar.into(),
+ rustc_span::DUMMY_SP,
+ ))
+ };
+
+ match &stmt.kind {
+ // If we expect `discriminant(place) ?= A`,
+ // we have an opportunity if `variant_index ?= A`.
+ StatementKind::SetDiscriminant { box place, variant_index } => {
+ let discr_target = self.map.find_discr(place.as_ref())?;
+ let enum_ty = place.ty(self.body, self.tcx).ty;
+ let discr = discriminant_for_variant(enum_ty, *variant_index)?;
+ self.process_operand(bb, discr_target, &discr, state)?;
+ }
+ // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
+ StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
+ Operand::Copy(place) | Operand::Move(place),
+ )) => {
+ let conditions = state.try_get(place.as_ref(), self.map)?;
+ conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity);
+ }
+ StatementKind::Assign(box (lhs_place, rhs)) => {
+ if let Some(lhs) = self.map.find(lhs_place.as_ref()) {
+ match rhs {
+ Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?,
+ // Transfer the conditions on the copy rhs.
+ Rvalue::CopyForDeref(rhs) => {
+ self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)?
+ }
+ Rvalue::Discriminant(rhs) => {
+ let rhs = self.map.find_discr(rhs.as_ref())?;
+ state.insert_place_idx(rhs, lhs, self.map);
+ }
+ // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
+ Rvalue::Aggregate(box ref kind, ref operands) => {
+ let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
+ let lhs = match kind {
+ // Do not support unions.
+ AggregateKind::Adt(.., Some(_)) => return None,
+ AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
+ if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
+ && let Some(discr_value) = discriminant_for_variant(agg_ty, *variant_index)
+ {
+ self.process_operand(bb, discr_target, &discr_value, state);
+ }
+ self.map.apply(lhs, TrackElem::Variant(*variant_index))?
+ }
+ _ => lhs,
+ };
+ for (field_index, operand) in operands.iter_enumerated() {
+ if let Some(field) =
+ self.map.apply(lhs, TrackElem::Field(field_index))
+ {
+ self.process_operand(bb, field, operand, state);
+ }
+ }
+ }
+ // Transfer the conditions on the copy rhs, after inverting the polarity.
+ Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
+ let conditions = state.try_get_idx(lhs, self.map)?;
+ let place = self.map.find(place.as_ref())?;
+ let conds = conditions.map(self.arena, Condition::inv);
+ state.insert_value_idx(place, conds, self.map);
+ }
+ // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
+ // Create a condition on `rhs ?= B`.
+ Rvalue::BinaryOp(
+ op,
+ box (
+ Operand::Move(place) | Operand::Copy(place),
+ Operand::Constant(value),
+ )
+ | box (
+ Operand::Constant(value),
+ Operand::Move(place) | Operand::Copy(place),
+ ),
+ ) => {
+ let conditions = state.try_get_idx(lhs, self.map)?;
+ let place = self.map.find(place.as_ref())?;
+ let equals = match op {
+ BinOp::Eq => ScalarInt::TRUE,
+ BinOp::Ne => ScalarInt::FALSE,
+ _ => return None,
+ };
+ let value = value
+ .const_
+ .normalize(self.tcx, self.param_env)
+ .try_to_scalar_int()?;
+ let conds = conditions.map(self.arena, |c| Condition {
+ value,
+ polarity: if c.matches(equals) {
+ Polarity::Eq
+ } else {
+ Polarity::Ne
+ },
+ ..c
+ });
+ state.insert_value_idx(place, conds, self.map);
+ }
+
+ _ => {}
+ }
+ }
+ }
+ _ => {}
+ }
+
+ None
+ }
+
+ #[instrument(level = "trace", skip(self, cost))]
+ fn recurse_through_terminator(
+ &mut self,
+ bb: BasicBlock,
+ state: &State<ConditionSet<'a>>,
+ cost: &CostChecker<'_, 'tcx>,
+ depth: usize,
+ ) {
+ let register_opportunity = |c: Condition| {
+ debug!(?bb, ?c.target, "register");
+ self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
+ };
+
+ let term = self.body.basic_blocks[bb].terminator();
+ let place_to_flood = match term.kind {
+ // We come from a target, so those are not possible.
+ TerminatorKind::UnwindResume
+ | TerminatorKind::UnwindTerminate(_)
+ | TerminatorKind::Return
+ | TerminatorKind::Unreachable
+ | TerminatorKind::CoroutineDrop => bug!("{term:?} has no successors"),
+ // Disallowed during optimizations.
+ TerminatorKind::FalseEdge { .. }
+ | TerminatorKind::FalseUnwind { .. }
+ | TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
+ // Cannot reason about inline asm.
+ TerminatorKind::InlineAsm { .. } => return,
+ // `SwitchInt` is handled specially.
+ TerminatorKind::SwitchInt { .. } => return,
+ // We can recurse, nothing particular to do.
+ TerminatorKind::Goto { .. } => None,
+ // Flood the overwritten place, and progress through.
+ TerminatorKind::Drop { place: destination, .. }
+ | TerminatorKind::Call { destination, .. } => Some(destination),
+ // Treat as an `assume(cond == expected)`.
+ TerminatorKind::Assert { ref cond, expected, .. } => {
+ if let Some(place) = cond.place()
+ && let Some(conditions) = state.try_get(place.as_ref(), self.map)
+ {
+ let expected = if expected { ScalarInt::TRUE } else { ScalarInt::FALSE };
+ conditions.iter_matches(expected).for_each(register_opportunity);
+ }
+ None
+ }
+ };
+
+ // We can recurse through this terminator.
+ let mut state = state.clone();
+ if let Some(place_to_flood) = place_to_flood {
+ state.flood_with(place_to_flood.as_ref(), self.map, ConditionSet::default());
+ }
+ self.find_opportunity(bb, state, cost.clone(), depth + 1);
+ }
+
+ #[instrument(level = "trace", skip(self))]
+ fn process_switch_int(
+ &mut self,
+ discr: &Operand<'tcx>,
+ targets: &SwitchTargets,
+ target_bb: BasicBlock,
+ state: &mut State<ConditionSet<'a>>,
+ ) -> Option<!> {
+ debug_assert_ne!(target_bb, START_BLOCK);
+ debug_assert_eq!(self.body.basic_blocks.predecessors()[target_bb].len(), 1);
+
+ let discr = discr.place()?;
+ let discr_ty = discr.ty(self.body, self.tcx).ty;
+ let discr_layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?;
+ let conditions = state.try_get(discr.as_ref(), self.map)?;
+
+ if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {
+ let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
+ debug_assert_eq!(targets.iter().filter(|&(_, target)| target == target_bb).count(), 1);
+
+ // We are inside `target_bb`. Since we have a single predecessor, we know we passed
+ // through the `SwitchInt` before arriving here. Therefore, we know that
+ // `discr == value`. If one condition can be fulfilled by `discr == value`,
+ // that's an opportunity.
+ for c in conditions.iter_matches(value) {
+ debug!(?target_bb, ?c.target, "register");
+ self.opportunities.push(ThreadingOpportunity { chain: vec![], target: c.target });
+ }
+ } else if let Some((value, _, else_bb)) = targets.as_static_if()
+ && target_bb == else_bb
+ {
+ let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
+
+ // We only know that `discr != value`. That's much weaker information than
+ // the equality we had in the previous arm. All we can conclude is that
+ // the replacement condition `discr != value` can be threaded, and nothing else.
+ for c in conditions.iter() {
+ if c.value == value && c.polarity == Polarity::Ne {
+ debug!(?target_bb, ?c.target, "register");
+ self.opportunities
+ .push(ThreadingOpportunity { chain: vec![], target: c.target });
+ }
+ }
+ }
+
+ None
+ }
+}
+
+struct OpportunitySet {
+ opportunities: Vec<ThreadingOpportunity>,
+ /// For each bb, give the TOs in which it appears. The pair corresponds to the index
+ /// in `opportunities` and the index in `ThreadingOpportunity::chain`.
+ involving_tos: IndexVec<BasicBlock, Vec<(usize, usize)>>,
+ /// Cache the number of predecessors for each block, as we clear the basic block cache.
+ predecessors: IndexVec<BasicBlock, usize>,
+}
+
+impl OpportunitySet {
+ fn new(body: &Body<'_>, opportunities: Vec<ThreadingOpportunity>) -> OpportunitySet {
+ let mut involving_tos = IndexVec::from_elem(Vec::new(), &body.basic_blocks);
+ for (index, to) in opportunities.iter().enumerate() {
+ for (ibb, &bb) in to.chain.iter().enumerate() {
+ involving_tos[bb].push((index, ibb));
+ }
+ involving_tos[to.target].push((index, to.chain.len()));
+ }
+ let predecessors = predecessor_count(body);
+ OpportunitySet { opportunities, involving_tos, predecessors }
+ }
+
+ /// Apply the opportunities on the graph.
+ fn apply(&mut self, body: &mut Body<'_>) {
+ for i in 0..self.opportunities.len() {
+ self.apply_once(i, body);
+ }
+ }
+
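+ /// Apply the `index`-th threading opportunity: walk its chain, duplicating each block that
+ /// still has other predecessors so the rerouted path is private, and finally replace the
+ /// terminator of the last block in the chain by a `Goto` to the opportunity's target.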
+ #[instrument(level = "trace", skip(self, body))]
+ fn apply_once(&mut self, index: usize, body: &mut Body<'_>) {
+ debug!(?self.predecessors);
+ debug!(?self.involving_tos);
+
+ // Check that `predecessors` satisfies its invariant.
+ debug_assert_eq!(self.predecessors, predecessor_count(body));
+
+ // Remove the TO from the vector to allow modifying the other ones later.
+ let op = &mut self.opportunities[index];
+ debug!(?op);
+ let op_chain = std::mem::take(&mut op.chain);
+ let op_target = op.target;
+ debug_assert_eq!(op_chain.len(), op_chain.iter().collect::<FxHashSet<_>>().len());
+
+ let Some((current, chain)) = op_chain.split_first() else { return };
+ let basic_blocks = body.basic_blocks.as_mut();
+
+ // Invariant: the control-flow is well-formed at the end of each iteration.
+ let mut current = *current;
+ for &succ in chain {
+ debug!(?current, ?succ);
+
+ // `succ` must be a successor of `current`. If it is not, this means this TO is not
+ // satisfiable and a previous TO erased this edge, so we bail out.
+ if basic_blocks[current].terminator().successors().find(|s| *s == succ).is_none() {
+ debug!("impossible");
+ return;
+ }
+
+ // Fast path: `succ` is only used once, so we can reuse it directly.
+ if self.predecessors[succ] == 1 {
+ debug!("single");
+ current = succ;
+ continue;
+ }
+
+ let new_succ = basic_blocks.push(basic_blocks[succ].clone());
+ debug!(?new_succ);
+
+ // Replace `succ` by `new_succ` where it appears.
+ let mut num_edges = 0;
+ for s in basic_blocks[current].terminator_mut().successors_mut() {
+ if *s == succ {
+ *s = new_succ;
+ num_edges += 1;
+ }
+ }
+
+ // Update predecessors with the new block.
+ let _new_succ = self.predecessors.push(num_edges);
+ debug_assert_eq!(new_succ, _new_succ);
+ self.predecessors[succ] -= num_edges;
+ self.update_predecessor_count(basic_blocks[new_succ].terminator(), Update::Incr);
+
+ // Replace the `current -> succ` edge by `current -> new_succ` in all the following
+ // TOs. This is necessary to avoid trying to thread through a non-existing edge. We
+ // use `involving_tos` here to avoid traversing the full set of TOs on each iteration.
+ let mut new_involved = Vec::new();
+ for &(to_index, in_to_index) in &self.involving_tos[current] {
+ // That TO has already been applied, do nothing.
+ if to_index <= index {
+ continue;
+ }
+
+ let other_to = &mut self.opportunities[to_index];
+ if other_to.chain.get(in_to_index) != Some(&current) {
+ continue;
+ }
+ let s = other_to.chain.get_mut(in_to_index + 1).unwrap_or(&mut other_to.target);
+ if *s == succ {
+ // `other_to` references the `current -> succ` edge, so replace `succ`.
+ *s = new_succ;
+ new_involved.push((to_index, in_to_index + 1));
+ }
+ }
+
+ // The TOs that we just updated now reference `new_succ`. Update `involving_tos`
+ // in case we need to duplicate an edge starting at `new_succ` later.
+ let _new_succ = self.involving_tos.push(new_involved);
+ debug_assert_eq!(new_succ, _new_succ);
+
+ current = new_succ;
+ }
+
+ let current = &mut basic_blocks[current];
+ self.update_predecessor_count(current.terminator(), Update::Decr);
+ current.terminator_mut().kind = TerminatorKind::Goto { target: op_target };
+ self.predecessors[op_target] += 1;
+ }
+
+ fn update_predecessor_count(&mut self, terminator: &Terminator<'_>, incr: Update) {
+ match incr {
+ Update::Incr => {
+ for s in terminator.successors() {
+ self.predecessors[s] += 1;
+ }
+ }
+ Update::Decr => {
+ for s in terminator.successors() {
+ self.predecessors[s] -= 1;
+ }
+ }
+ }
+ }
+}
+
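+/// Compute the number of predecessor edges of each basic block, counting the implicit entry
+/// edge into `START_BLOCK`.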
+fn predecessor_count(body: &Body<'_>) -> IndexVec<BasicBlock, usize> {
+ let mut predecessors: IndexVec<_, _> =
+ body.basic_blocks.predecessors().iter().map(|ps| ps.len()).collect();
+ predecessors[START_BLOCK] += 1; // Account for the implicit entry edge.
+ predecessors
+}
+
+enum Update {
+ Incr,
+ Decr,
+}
+
+/// Compute the set of loop headers in the given body. We define a loop header as a block which has
+/// at least a predecessor which it dominates. This definition is only correct for reducible CFGs.
+/// But if the CFG is already irreducible, there is no point in trying much harder.
+fn loop_headers(body: &Body<'_>) -> BitSet<BasicBlock> {
+ let mut loop_headers = BitSet::new_empty(body.basic_blocks.len());
+ let dominators = body.basic_blocks.dominators();
+ // Only visit reachable blocks.
+ for (bb, bbdata) in traversal::preorder(body) {
+ for succ in bbdata.terminator().successors() {
+ if dominators.dominates(succ, bb) {
+ loop_headers.insert(succ);
+ }
+ }
+ }
+ loop_headers
+}
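
For illustration, here is a standalone sketch of the loop-header definition above, using a toy CFG (adjacency lists with block 0 as the entry) and textbook iterative dominator sets instead of rustc's `BasicBlocks` and `Dominators`. A block counts as a loop header when it dominates one of its own predecessors, which is the back-edge check `loop_headers` performs on reducible CFGs; all names here are illustrative, not rustc API.

    use std::collections::HashSet;

    // Toy CFG: `succs[b]` lists the successors of block `b`; block 0 is the entry.
    fn toy_loop_headers(succs: &[Vec<usize>]) -> HashSet<usize> {
        let n = succs.len();
        // Predecessor lists, needed both for the dominator sets and for the header check.
        let mut preds = vec![Vec::new(); n];
        for (b, ss) in succs.iter().enumerate() {
            for &s in ss {
                preds[s].push(b);
            }
        }
        // Iterative dominator sets: dom(0) = {0}; for b != 0,
        // dom(b) = {b} union (intersection of dom(p) over all predecessors p).
        let all: HashSet<usize> = (0..n).collect();
        let mut dom: Vec<HashSet<usize>> =
            (0..n).map(|b| if b == 0 { HashSet::from([0]) } else { all.clone() }).collect();
        let mut changed = true;
        while changed {
            changed = false;
            for b in 1..n {
                let mut new: Option<HashSet<usize>> = None;
                for &p in &preds[b] {
                    new = Some(match new {
                        None => dom[p].clone(),
                        Some(acc) => acc.intersection(&dom[p]).copied().collect(),
                    });
                }
                let mut new = new.unwrap_or_default();
                new.insert(b);
                if new != dom[b] {
                    dom[b] = new;
                    changed = true;
                }
            }
        }
        // A block `h` is a loop header if some edge `p -> h` is a back edge,
        // i.e. `h` dominates its predecessor `p`.
        (0..n).filter(|&h| preds[h].iter().any(|&p| dom[p].contains(&h))).collect()
    }

    fn main() {
        // 0 -> 1 -> 2 -> 1 (back edge) and 2 -> 3: block 1 is the only loop header.
        let cfg = vec![vec![1], vec![2], vec![1, 3], vec![]];
        assert_eq!(toy_loop_headers(&cfg), HashSet::from([1]));
    }
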
diff --git a/compiler/rustc_mir_transform/src/large_enums.rs b/compiler/rustc_mir_transform/src/large_enums.rs
index 886ff7604..0a8b13d66 100644
--- a/compiler/rustc_mir_transform/src/large_enums.rs
+++ b/compiler/rustc_mir_transform/src/large_enums.rs
@@ -30,6 +30,9 @@ pub struct EnumSizeOpt {
impl<'tcx> MirPass<'tcx> for EnumSizeOpt {
fn is_enabled(&self, sess: &Session) -> bool {
+ // There are some differences in behavior on wasm and ARM that are not properly
+ // understood, so we conservatively treat this optimization as unsound:
+ // https://github.com/rust-lang/rust/pull/85158#issuecomment-1101836457
sess.opts.unstable_opts.unsound_mir_opts || sess.mir_opt_level() >= 3
}
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index c0a09b7a7..bf5f0ca7c 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -2,6 +2,7 @@
#![deny(rustc::untranslatable_diagnostic)]
#![deny(rustc::diagnostic_outside_of_impl)]
#![feature(box_patterns)]
+#![feature(cow_is_borrowed)]
#![feature(decl_macro)]
#![feature(is_sorted)]
#![feature(let_chains)]
@@ -20,6 +21,7 @@ extern crate tracing;
#[macro_use]
extern crate rustc_middle;
+use hir::ConstContext;
use required_consts::RequiredConstsVisitor;
use rustc_const_eval::util;
use rustc_data_structures::fx::FxIndexSet;
@@ -61,7 +63,10 @@ mod const_goto;
mod const_prop;
mod const_prop_lint;
mod copy_prop;
+mod coroutine;
+mod cost_checker;
mod coverage;
+mod cross_crate_inline;
mod ctfe_limit;
mod dataflow_const_prop;
mod dead_store_elimination;
@@ -76,10 +81,10 @@ mod elaborate_drops;
mod errors;
mod ffi_unwind_calls;
mod function_item_references;
-mod generator;
mod gvn;
pub mod inline;
mod instsimplify;
+mod jump_threading;
mod large_enums;
mod lower_intrinsics;
mod lower_slice_len;
@@ -123,6 +128,7 @@ pub fn provide(providers: &mut Providers) {
coverage::query::provide(providers);
ffi_unwind_calls::provide(providers);
shim::provide(providers);
+ cross_crate_inline::provide(providers);
*providers = Providers {
mir_keys,
mir_const,
@@ -130,7 +136,7 @@ pub fn provide(providers: &mut Providers) {
mir_promoted,
mir_drops_elaborated_and_const_checked,
mir_for_ctfe,
- mir_generator_witnesses: generator::mir_generator_witnesses,
+ mir_coroutine_witnesses: coroutine::mir_coroutine_witnesses,
optimized_mir,
is_mir_available,
is_ctfe_mir_available: |tcx, did| is_mir_available(tcx, did),
@@ -162,37 +168,50 @@ fn remap_mir_for_const_eval_select<'tcx>(
&& tcx.item_name(def_id) == sym::const_eval_select
&& tcx.is_intrinsic(def_id) =>
{
- let [tupled_args, called_in_const, called_at_rt]: [_; 3] = std::mem::take(args).try_into().unwrap();
+ let [tupled_args, called_in_const, called_at_rt]: [_; 3] =
+ std::mem::take(args).try_into().unwrap();
let ty = tupled_args.ty(&body.local_decls, tcx);
let fields = ty.tuple_fields();
let num_args = fields.len();
- let func = if context == hir::Constness::Const { called_in_const } else { called_at_rt };
- let (method, place): (fn(Place<'tcx>) -> Operand<'tcx>, Place<'tcx>) = match tupled_args {
- Operand::Constant(_) => {
- // there is no good way of extracting a tuple arg from a constant (const generic stuff)
- // so we just create a temporary and deconstruct that.
- let local = body.local_decls.push(LocalDecl::new(ty, fn_span));
- bb.statements.push(Statement {
- source_info: SourceInfo::outermost(fn_span),
- kind: StatementKind::Assign(Box::new((local.into(), Rvalue::Use(tupled_args.clone())))),
- });
- (Operand::Move, local.into())
- }
- Operand::Move(place) => (Operand::Move, place),
- Operand::Copy(place) => (Operand::Copy, place),
- };
- let place_elems = place.projection;
- let arguments = (0..num_args).map(|x| {
- let mut place_elems = place_elems.to_vec();
- place_elems.push(ProjectionElem::Field(x.into(), fields[x]));
- let projection = tcx.mk_place_elems(&place_elems);
- let place = Place {
- local: place.local,
- projection,
+ let func =
+ if context == hir::Constness::Const { called_in_const } else { called_at_rt };
+ let (method, place): (fn(Place<'tcx>) -> Operand<'tcx>, Place<'tcx>) =
+ match tupled_args {
+ Operand::Constant(_) => {
+ // there is no good way of extracting a tuple arg from a constant (const generic stuff)
+ // so we just create a temporary and deconstruct that.
+ let local = body.local_decls.push(LocalDecl::new(ty, fn_span));
+ bb.statements.push(Statement {
+ source_info: SourceInfo::outermost(fn_span),
+ kind: StatementKind::Assign(Box::new((
+ local.into(),
+ Rvalue::Use(tupled_args.clone()),
+ ))),
+ });
+ (Operand::Move, local.into())
+ }
+ Operand::Move(place) => (Operand::Move, place),
+ Operand::Copy(place) => (Operand::Copy, place),
};
- method(place)
- }).collect();
- terminator.kind = TerminatorKind::Call { func, args: arguments, destination, target, unwind, call_source: CallSource::Misc, fn_span };
+ let place_elems = place.projection;
+ let arguments = (0..num_args)
+ .map(|x| {
+ let mut place_elems = place_elems.to_vec();
+ place_elems.push(ProjectionElem::Field(x.into(), fields[x]));
+ let projection = tcx.mk_place_elems(&place_elems);
+ let place = Place { local: place.local, projection };
+ method(place)
+ })
+ .collect();
+ terminator.kind = TerminatorKind::Call {
+ func,
+ args: arguments,
+ destination,
+ target,
+ unwind,
+ call_source: CallSource::Misc,
+ fn_span,
+ };
}
_ => {}
}
@@ -234,8 +253,13 @@ fn mir_const_qualif(tcx: TyCtxt<'_>, def: LocalDefId) -> ConstQualifs {
let const_kind = tcx.hir().body_const_context(def);
// No need to const-check a non-const `fn`.
- if const_kind.is_none() {
- return Default::default();
+ match const_kind {
+ Some(ConstContext::Const { .. } | ConstContext::Static(_))
+ | Some(ConstContext::ConstFn) => {}
+ None => span_bug!(
+ tcx.def_span(def),
+ "`mir_const_qualif` should only be called on const fns and const items"
+ ),
}
// N.B., this `borrow()` is guaranteed to be valid (i.e., the value
@@ -300,7 +324,21 @@ fn mir_promoted(
// Ensure that we compute the `mir_const_qualif` for constants at
// this point, before we steal the mir-const result.
// Also this means promotion can rely on all const checks having been done.
- let const_qualifs = tcx.mir_const_qualif(def);
+
+ let const_qualifs = match tcx.def_kind(def) {
+ DefKind::Fn | DefKind::AssocFn | DefKind::Closure
+ if tcx.constness(def) == hir::Constness::Const
+ || tcx.is_const_default_method(def.to_def_id()) =>
+ {
+ tcx.mir_const_qualif(def)
+ }
+ DefKind::AssocConst
+ | DefKind::Const
+ | DefKind::Static(_)
+ | DefKind::InlineConst
+ | DefKind::AnonConst => tcx.mir_const_qualif(def),
+ _ => ConstQualifs::default(),
+ };
let mut body = tcx.mir_const(def).steal();
if let Some(error_reported) = const_qualifs.tainted_by_errors {
body.tainted_by_errors = Some(error_reported);
@@ -360,15 +398,15 @@ fn inner_mir_for_ctfe(tcx: TyCtxt<'_>, def: LocalDefId) -> Body<'_> {
/// mir borrowck *before* doing so in order to ensure that borrowck can be run and doesn't
/// end up missing the source MIR due to stealing happening.
fn mir_drops_elaborated_and_const_checked(tcx: TyCtxt<'_>, def: LocalDefId) -> &Steal<Body<'_>> {
- if let DefKind::Generator = tcx.def_kind(def) {
- tcx.ensure_with_value().mir_generator_witnesses(def);
+ if let DefKind::Coroutine = tcx.def_kind(def) {
+ tcx.ensure_with_value().mir_coroutine_witnesses(def);
}
let mir_borrowck = tcx.mir_borrowck(def);
let is_fn_like = tcx.def_kind(def).is_fn_like();
if is_fn_like {
// Do not compute the mir call graph without said call graph actually being used.
- if inline::Inline.is_enabled(&tcx.sess) {
+ if pm::should_run_pass(tcx, &inline::Inline) {
tcx.ensure_with_value().mir_inliner_callees(ty::InstanceDef::Item(def.to_def_id()));
}
}
@@ -494,9 +532,9 @@ fn run_runtime_lowering_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
// `AddRetag` needs to run after `ElaborateDrops`. Otherwise it should run fairly late,
// but before optimizations begin.
&elaborate_box_derefs::ElaborateBoxDerefs,
- &generator::StateTransform,
+ &coroutine::StateTransform,
&add_retag::AddRetag,
- &Lint(const_prop_lint::ConstProp),
+ &Lint(const_prop_lint::ConstPropLint),
];
pm::run_passes_no_validate(tcx, body, passes, Some(MirPhase::Runtime(RuntimePhase::Initial)));
}
@@ -530,10 +568,11 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&[
&check_alignment::CheckAlignment,
&lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first
- &unreachable_prop::UnreachablePropagation,
+ &inline::Inline,
+ // Substitutions during inlining may introduce switches on enums with uninhabited branches.
&uninhabited_enum_branching::UninhabitedEnumBranching,
+ &unreachable_prop::UnreachablePropagation,
&o1(simplify::SimplifyCfg::AfterUninhabitedEnumBranching),
- &inline::Inline,
&remove_storage_markers::RemoveStorageMarkers,
&remove_zsts::RemoveZsts,
&normalize_array_len::NormalizeArrayLen, // has to run after `slice::len` lowering
@@ -553,11 +592,11 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&separate_const_switch::SeparateConstSwitch,
&const_prop::ConstProp,
&gvn::GVN,
+ &simplify::SimplifyLocals::AfterGVN,
&dataflow_const_prop::DataflowConstProp,
- //
- // Const-prop runs unconditionally, but doesn't mutate the MIR at mir-opt-level=0.
&const_debuginfo::ConstDebugInfo,
&o1(simplify_branches::SimplifyConstCondition::AfterConstProp),
+ &jump_threading::JumpThreading,
&early_otherwise_branch::EarlyOtherwiseBranch,
&simplify_comparison_integral::SimplifyComparisonIntegral,
&dead_store_elimination::DeadStoreElimination,
@@ -613,6 +652,15 @@ fn inner_optimized_mir(tcx: TyCtxt<'_>, did: LocalDefId) -> Body<'_> {
return body;
}
+ // If `mir_drops_elaborated_and_const_checked` found that the current body has unsatisfiable
+ // predicates, it will shrink the MIR to a single `unreachable` terminator.
+ // More generally, if MIR is a lone `unreachable`, there is nothing to optimize.
+ if let TerminatorKind::Unreachable = body.basic_blocks[START_BLOCK].terminator().kind
+ && body.basic_blocks[START_BLOCK].statements.is_empty()
+ {
+ return body;
+ }
+
run_optimization_passes(tcx, &mut body);
body
diff --git a/compiler/rustc_mir_transform/src/lower_intrinsics.rs b/compiler/rustc_mir_transform/src/lower_intrinsics.rs
index 0d2d764c4..5f3d8dfc6 100644
--- a/compiler/rustc_mir_transform/src/lower_intrinsics.rs
+++ b/compiler/rustc_mir_transform/src/lower_intrinsics.rs
@@ -2,9 +2,8 @@
use crate::MirPass;
use rustc_middle::mir::*;
-use rustc_middle::ty::GenericArgsRef;
-use rustc_middle::ty::{self, Ty, TyCtxt};
-use rustc_span::symbol::{sym, Symbol};
+use rustc_middle::ty::{self, TyCtxt};
+use rustc_span::symbol::sym;
use rustc_target::abi::{FieldIdx, VariantIdx};
pub struct LowerIntrinsics;
@@ -16,12 +15,10 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics {
let terminator = block.terminator.as_mut().unwrap();
if let TerminatorKind::Call { func, args, destination, target, .. } =
&mut terminator.kind
+ && let ty::FnDef(def_id, generic_args) = *func.ty(local_decls, tcx).kind()
+ && tcx.is_intrinsic(def_id)
{
- let func_ty = func.ty(local_decls, tcx);
- let Some((intrinsic_name, generic_args)) = resolve_rust_intrinsic(tcx, func_ty)
- else {
- continue;
- };
+ let intrinsic_name = tcx.item_name(def_id);
match intrinsic_name {
sym::unreachable => {
terminator.kind = TerminatorKind::Unreachable;
@@ -169,12 +166,16 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics {
let [arg] = args.as_slice() else {
span_bug!(terminator.source_info.span, "Wrong number of arguments");
};
- let derefed_place =
- if let Some(place) = arg.place() && let Some(local) = place.as_local() {
- tcx.mk_place_deref(local.into())
- } else {
- span_bug!(terminator.source_info.span, "Only passing a local is supported");
- };
+ let derefed_place = if let Some(place) = arg.place()
+ && let Some(local) = place.as_local()
+ {
+ tcx.mk_place_deref(local.into())
+ } else {
+ span_bug!(
+ terminator.source_info.span,
+ "Only passing a local is supported"
+ );
+ };
// Add new statement at the end of the block that does the read, and patch
// up the terminator.
block.statements.push(Statement {
@@ -201,12 +202,16 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics {
"Wrong number of arguments for write_via_move intrinsic",
);
};
- let derefed_place =
- if let Some(place) = ptr.place() && let Some(local) = place.as_local() {
- tcx.mk_place_deref(local.into())
- } else {
- span_bug!(terminator.source_info.span, "Only passing a local is supported");
- };
+ let derefed_place = if let Some(place) = ptr.place()
+ && let Some(local) = place.as_local()
+ {
+ tcx.mk_place_deref(local.into())
+ } else {
+ span_bug!(
+ terminator.source_info.span,
+ "Only passing a local is supported"
+ );
+ };
block.statements.push(Statement {
source_info: terminator.source_info,
kind: StatementKind::Assign(Box::new((
@@ -309,15 +314,3 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics {
}
}
}
-
-fn resolve_rust_intrinsic<'tcx>(
- tcx: TyCtxt<'tcx>,
- func_ty: Ty<'tcx>,
-) -> Option<(Symbol, GenericArgsRef<'tcx>)> {
- if let ty::FnDef(def_id, args) = *func_ty.kind() {
- if tcx.is_intrinsic(def_id) {
- return Some((tcx.item_name(def_id), args));
- }
- }
- None
-}
diff --git a/compiler/rustc_mir_transform/src/lower_slice_len.rs b/compiler/rustc_mir_transform/src/lower_slice_len.rs
index b7cc0db95..ae4878411 100644
--- a/compiler/rustc_mir_transform/src/lower_slice_len.rs
+++ b/compiler/rustc_mir_transform/src/lower_slice_len.rs
@@ -34,67 +34,43 @@ pub fn lower_slice_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
}
}
-struct SliceLenPatchInformation<'tcx> {
- add_statement: Statement<'tcx>,
- new_terminator_kind: TerminatorKind<'tcx>,
-}
-
fn lower_slice_len_call<'tcx>(
tcx: TyCtxt<'tcx>,
block: &mut BasicBlockData<'tcx>,
local_decls: &IndexSlice<Local, LocalDecl<'tcx>>,
slice_len_fn_item_def_id: DefId,
) {
- let mut patch_found: Option<SliceLenPatchInformation<'_>> = None;
-
let terminator = block.terminator();
- match &terminator.kind {
- TerminatorKind::Call {
- func,
- args,
- destination,
- target: Some(bb),
- call_source: CallSource::Normal,
- ..
- } => {
- // some heuristics for fast rejection
- if args.len() != 1 {
- return;
- }
- let Some(arg) = args[0].place() else { return };
- let func_ty = func.ty(local_decls, tcx);
- match func_ty.kind() {
- ty::FnDef(fn_def_id, _) if fn_def_id == &slice_len_fn_item_def_id => {
- // perform modifications
- // from something like `_5 = core::slice::<impl [u8]>::len(move _6) -> bb1`
- // into:
- // ```
- // _5 = Len(*_6)
- // goto bb1
- // ```
+ if let TerminatorKind::Call {
+ func,
+ args,
+ destination,
+ target: Some(bb),
+ call_source: CallSource::Normal,
+ ..
+ } = &terminator.kind
+ // some heuristics for fast rejection
+ && let [arg] = &args[..]
+ && let Some(arg) = arg.place()
+ && let ty::FnDef(fn_def_id, _) = func.ty(local_decls, tcx).kind()
+ && *fn_def_id == slice_len_fn_item_def_id
+ {
+ // perform modifications from something like:
+ // _5 = core::slice::<impl [u8]>::len(move _6) -> bb1
+ // into:
+ // _5 = Len(*_6)
+ // goto bb1
- // make new RValue for Len
- let deref_arg = tcx.mk_place_deref(arg);
- let r_value = Rvalue::Len(deref_arg);
- let len_statement_kind =
- StatementKind::Assign(Box::new((*destination, r_value)));
- let add_statement =
- Statement { kind: len_statement_kind, source_info: terminator.source_info };
+ // make new RValue for Len
+ let deref_arg = tcx.mk_place_deref(arg);
+ let r_value = Rvalue::Len(deref_arg);
+ let len_statement_kind = StatementKind::Assign(Box::new((*destination, r_value)));
+ let add_statement =
+ Statement { kind: len_statement_kind, source_info: terminator.source_info };
- // modify terminator into simple Goto
- let new_terminator_kind = TerminatorKind::Goto { target: *bb };
-
- let patch = SliceLenPatchInformation { add_statement, new_terminator_kind };
-
- patch_found = Some(patch);
- }
- _ => {}
- }
- }
- _ => {}
- }
+ // modify terminator into simple Goto
+ let new_terminator_kind = TerminatorKind::Goto { target: *bb };
- if let Some(SliceLenPatchInformation { add_statement, new_terminator_kind }) = patch_found {
block.statements.push(add_statement);
block.terminator_mut().kind = new_terminator_kind;
}
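
As a source-level illustration of the rewrite above (the MIR in the comments is simplified and the local numbering is hypothetical): a plain `<[T]>::len` call on a local becomes a `Len` of the dereferenced fat pointer plus a `Goto`.

    // A function that hits the pattern handled by `lower_slice_len_call`.
    pub fn slice_length(xs: &[u8]) -> usize {
        // Before the pass (simplified MIR):
        //     _0 = core::slice::<impl [u8]>::len(move _1) -> bb1
        // After the pass:
        //     _0 = Len(*_1);
        //     goto -> bb1
        xs.len()
    }

    fn main() {
        assert_eq!(slice_length(b"abc"), 3);
    }
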
diff --git a/compiler/rustc_mir_transform/src/multiple_return_terminators.rs b/compiler/rustc_mir_transform/src/multiple_return_terminators.rs
index c97d03454..c9b42e75c 100644
--- a/compiler/rustc_mir_transform/src/multiple_return_terminators.rs
+++ b/compiler/rustc_mir_transform/src/multiple_return_terminators.rs
@@ -38,6 +38,6 @@ impl<'tcx> MirPass<'tcx> for MultipleReturnTerminators {
}
}
- simplify::remove_dead_blocks(tcx, body)
+ simplify::remove_dead_blocks(body)
}
}
diff --git a/compiler/rustc_mir_transform/src/normalize_array_len.rs b/compiler/rustc_mir_transform/src/normalize_array_len.rs
index d1a4b26a0..206cdf9fe 100644
--- a/compiler/rustc_mir_transform/src/normalize_array_len.rs
+++ b/compiler/rustc_mir_transform/src/normalize_array_len.rs
@@ -57,7 +57,9 @@ fn compute_slice_length<'tcx>(
}
// The length information is stored in the fat pointer, so we treat `operand` as a value.
Rvalue::Use(operand) => {
- if let Some(rhs) = operand.place() && let Some(rhs) = rhs.as_local() {
+ if let Some(rhs) = operand.place()
+ && let Some(rhs) = rhs.as_local()
+ {
slice_lengths[local] = slice_lengths[rhs];
}
}
diff --git a/compiler/rustc_mir_transform/src/nrvo.rs b/compiler/rustc_mir_transform/src/nrvo.rs
index e1298b065..ff309bd10 100644
--- a/compiler/rustc_mir_transform/src/nrvo.rs
+++ b/compiler/rustc_mir_transform/src/nrvo.rs
@@ -34,7 +34,7 @@ pub struct RenameReturnPlace;
impl<'tcx> MirPass<'tcx> for RenameReturnPlace {
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
- // #111005
+ // unsound: #111005
sess.mir_opt_level() > 0 && sess.opts.unstable_opts.unsound_mir_opts
}
diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs
index 5abb2f3d0..a8aba29ad 100644
--- a/compiler/rustc_mir_transform/src/pass_manager.rs
+++ b/compiler/rustc_mir_transform/src/pass_manager.rs
@@ -83,6 +83,25 @@ pub fn run_passes<'tcx>(
run_passes_inner(tcx, body, passes, phase_change, true);
}
+pub fn should_run_pass<'tcx, P>(tcx: TyCtxt<'tcx>, pass: &P) -> bool
+where
+ P: MirPass<'tcx> + ?Sized,
+{
+ let name = pass.name();
+
+ let overridden_passes = &tcx.sess.opts.unstable_opts.mir_enable_passes;
+ let overridden =
+ overridden_passes.iter().rev().find(|(s, _)| s == &*name).map(|(_name, polarity)| {
+ trace!(
+ pass = %name,
+ "{} as requested by flag",
+ if *polarity { "Running" } else { "Not running" },
+ );
+ *polarity
+ });
+ overridden.unwrap_or_else(|| pass.is_enabled(&tcx.sess))
+}
+
fn run_passes_inner<'tcx>(
tcx: TyCtxt<'tcx>,
body: &mut Body<'tcx>,
@@ -100,19 +119,9 @@ fn run_passes_inner<'tcx>(
for pass in passes {
let name = pass.name();
- let overridden = overridden_passes.iter().rev().find(|(s, _)| s == &*name).map(
- |(_name, polarity)| {
- trace!(
- pass = %name,
- "{} as requested by flag",
- if *polarity { "Running" } else { "Not running" },
- );
- *polarity
- },
- );
- if !overridden.unwrap_or_else(|| pass.is_enabled(&tcx.sess)) {
+ if !should_run_pass(tcx, *pass) {
continue;
- }
+ };
let dump_enabled = pass.is_mir_dump_enabled();
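
The new `should_run_pass` centralizes the `-Zmir-enable-passes` override lookup that previously lived inside the pass loop. A standalone sketch of its decision rule (plain std types, not the rustc API): the last override naming the pass wins, and without a matching override the pass's own `is_enabled` default applies.

    fn should_run(name: &str, overrides: &[(&str, bool)], enabled_by_default: bool) -> bool {
        overrides
            .iter()
            .rev() // later command-line overrides shadow earlier ones
            .find(|(pass, _)| *pass == name)
            .map(|&(_, polarity)| polarity)
            .unwrap_or(enabled_by_default)
    }

    fn main() {
        let overrides = [("JumpThreading", false), ("JumpThreading", true)];
        // The last matching override wins.
        assert!(should_run("JumpThreading", &overrides, false));
        // Passes without an override keep their default.
        assert!(!should_run("Inline", &overrides, false));
    }
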
diff --git a/compiler/rustc_mir_transform/src/ref_prop.rs b/compiler/rustc_mir_transform/src/ref_prop.rs
index 67941cf43..df39c819b 100644
--- a/compiler/rustc_mir_transform/src/ref_prop.rs
+++ b/compiler/rustc_mir_transform/src/ref_prop.rs
@@ -210,14 +210,17 @@ fn compute_replacement<'tcx>(
// have been visited before.
Rvalue::Use(Operand::Copy(place) | Operand::Move(place))
| Rvalue::CopyForDeref(place) => {
- if let Some(rhs) = place.as_local() && ssa.is_ssa(rhs) {
+ if let Some(rhs) = place.as_local()
+ && ssa.is_ssa(rhs)
+ {
let target = targets[rhs];
// Only see through immutable reference and pointers, as we do not know yet if
// mutable references are fully replaced.
if !needs_unique && matches!(target, Value::Pointer(..)) {
targets[local] = target;
} else {
- targets[local] = Value::Pointer(tcx.mk_place_deref(rhs.into()), needs_unique);
+ targets[local] =
+ Value::Pointer(tcx.mk_place_deref(rhs.into()), needs_unique);
}
}
}
@@ -365,7 +368,7 @@ impl<'tcx> MutVisitor<'tcx> for Replacer<'tcx> {
*place = Place::from(target.local).project_deeper(rest, self.tcx);
self.any_replacement = true;
} else {
- break
+ break;
}
}
diff --git a/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs b/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs
index 8c48a6677..54892442c 100644
--- a/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs
+++ b/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs
@@ -69,7 +69,7 @@ impl RemoveNoopLandingPads {
| TerminatorKind::FalseUnwind { .. } => {
terminator.successors().all(|succ| nop_landing_pads.contains(succ))
}
- TerminatorKind::GeneratorDrop
+ TerminatorKind::CoroutineDrop
| TerminatorKind::Yield { .. }
| TerminatorKind::Return
| TerminatorKind::UnwindTerminate(_)
diff --git a/compiler/rustc_mir_transform/src/remove_uninit_drops.rs b/compiler/rustc_mir_transform/src/remove_uninit_drops.rs
index 263849747..87fee2410 100644
--- a/compiler/rustc_mir_transform/src/remove_uninit_drops.rs
+++ b/compiler/rustc_mir_transform/src/remove_uninit_drops.rs
@@ -24,11 +24,8 @@ pub struct RemoveUninitDrops;
impl<'tcx> MirPass<'tcx> for RemoveUninitDrops {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let param_env = tcx.param_env(body.source.def_id());
- let Ok(move_data) = MoveData::gather_moves(body, tcx, param_env) else {
- // We could continue if there are move errors, but there's not much point since our
- // init data isn't complete.
- return;
- };
+ let move_data =
+ MoveData::gather_moves(&body, tcx, param_env, |ty| ty.needs_drop(tcx, param_env));
let mdpe = MoveDataParamEnv { move_data, param_env };
let mut maybe_inits = MaybeInitializedPlaces::new(tcx, body, &mdpe)
diff --git a/compiler/rustc_mir_transform/src/remove_zsts.rs b/compiler/rustc_mir_transform/src/remove_zsts.rs
index a34d4b027..5aa3c3cfe 100644
--- a/compiler/rustc_mir_transform/src/remove_zsts.rs
+++ b/compiler/rustc_mir_transform/src/remove_zsts.rs
@@ -13,8 +13,8 @@ impl<'tcx> MirPass<'tcx> for RemoveZsts {
}
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
- // Avoid query cycles (generators require optimized MIR for layout).
- if tcx.type_of(body.source.def_id()).instantiate_identity().is_generator() {
+ // Avoid query cycles (coroutines require optimized MIR for layout).
+ if tcx.type_of(body.source.def_id()).instantiate_identity().is_coroutine() {
return;
}
let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
@@ -126,7 +126,10 @@ impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> {
&& let ty = place_for_ty.ty(self.local_decls, self.tcx).ty
&& self.known_to_be_zst(ty)
&& self.tcx.consider_optimizing(|| {
- format!("RemoveZsts - Place: {:?} SourceInfo: {:?}", place_for_ty, statement.source_info)
+ format!(
+ "RemoveZsts - Place: {:?} SourceInfo: {:?}",
+ place_for_ty, statement.source_info
+ )
})
{
statement.make_nop();
diff --git a/compiler/rustc_mir_transform/src/separate_const_switch.rs b/compiler/rustc_mir_transform/src/separate_const_switch.rs
index e1e4acccc..907cfe758 100644
--- a/compiler/rustc_mir_transform/src/separate_const_switch.rs
+++ b/compiler/rustc_mir_transform/src/separate_const_switch.rs
@@ -118,7 +118,7 @@ pub fn separate_const_switch(body: &mut Body<'_>) -> usize {
| TerminatorKind::Return
| TerminatorKind::Unreachable
| TerminatorKind::InlineAsm { .. }
- | TerminatorKind::GeneratorDrop => {
+ | TerminatorKind::CoroutineDrop => {
continue 'predec_iter;
}
}
@@ -169,7 +169,7 @@ pub fn separate_const_switch(body: &mut Body<'_>) -> usize {
| TerminatorKind::UnwindTerminate(_)
| TerminatorKind::Return
| TerminatorKind::Unreachable
- | TerminatorKind::GeneratorDrop
+ | TerminatorKind::CoroutineDrop
| TerminatorKind::Assert { .. }
| TerminatorKind::FalseUnwind { .. }
| TerminatorKind::Drop { .. }
diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs
index e9895d97d..ab7961321 100644
--- a/compiler/rustc_mir_transform/src/shim.rs
+++ b/compiler/rustc_mir_transform/src/shim.rs
@@ -4,7 +4,7 @@ use rustc_hir::lang_items::LangItem;
use rustc_middle::mir::*;
use rustc_middle::query::Providers;
use rustc_middle::ty::GenericArgs;
-use rustc_middle::ty::{self, EarlyBinder, GeneratorArgs, Ty, TyCtxt};
+use rustc_middle::ty::{self, CoroutineArgs, EarlyBinder, Ty, TyCtxt};
use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT};
use rustc_index::{Idx, IndexVec};
@@ -67,18 +67,20 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
}
ty::InstanceDef::DropGlue(def_id, ty) => {
- // FIXME(#91576): Drop shims for generators aren't subject to the MIR passes at the end
+ // FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end
// of this function. Is this intentional?
- if let Some(ty::Generator(gen_def_id, args, _)) = ty.map(Ty::kind) {
- let body = tcx.optimized_mir(*gen_def_id).generator_drop().unwrap();
+ if let Some(ty::Coroutine(coroutine_def_id, args, _)) = ty.map(Ty::kind) {
+ let body = tcx.optimized_mir(*coroutine_def_id).coroutine_drop().unwrap();
let mut body = EarlyBinder::bind(body.clone()).instantiate(tcx, args);
debug!("make_shim({:?}) = {:?}", instance, body);
- // Run empty passes to mark phase change and perform validation.
pm::run_passes(
tcx,
&mut body,
- &[],
+ &[
+ &abort_unwinding_calls::AbortUnwindingCalls,
+ &add_call_guards::CriticalCallEdges,
+ ],
Some(MirPhase::Runtime(RuntimePhase::Optimized)),
);
@@ -171,7 +173,7 @@ fn local_decls_for_sig<'tcx>(
fn build_drop_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, ty: Option<Ty<'tcx>>) -> Body<'tcx> {
debug!("build_drop_shim(def_id={:?}, ty={:?})", def_id, ty);
- assert!(!matches!(ty, Some(ty) if ty.is_generator()));
+ assert!(!matches!(ty, Some(ty) if ty.is_coroutine()));
let args = if let Some(ty) = ty {
tcx.mk_args(&[ty.into()])
@@ -392,8 +394,8 @@ fn build_clone_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'tcx>) -
_ if is_copy => builder.copy_shim(),
ty::Closure(_, args) => builder.tuple_like_shim(dest, src, args.as_closure().upvar_tys()),
ty::Tuple(..) => builder.tuple_like_shim(dest, src, self_ty.tuple_fields()),
- ty::Generator(gen_def_id, args, hir::Movability::Movable) => {
- builder.generator_shim(dest, src, *gen_def_id, args.as_generator())
+ ty::Coroutine(coroutine_def_id, args, hir::Movability::Movable) => {
+ builder.coroutine_shim(dest, src, *coroutine_def_id, args.as_coroutine())
}
_ => bug!("clone shim for `{:?}` which is not `Copy` and is not an aggregate", self_ty),
};
@@ -593,12 +595,12 @@ impl<'tcx> CloneShimBuilder<'tcx> {
let _final_cleanup_block = self.clone_fields(dest, src, target, unwind, tys);
}
- fn generator_shim(
+ fn coroutine_shim(
&mut self,
dest: Place<'tcx>,
src: Place<'tcx>,
- gen_def_id: DefId,
- args: GeneratorArgs<'tcx>,
+ coroutine_def_id: DefId,
+ args: CoroutineArgs<'tcx>,
) {
self.block(vec![], TerminatorKind::Goto { target: self.block_index_offset(3) }, false);
let unwind = self.block(vec![], TerminatorKind::UnwindResume, true);
@@ -607,8 +609,8 @@ impl<'tcx> CloneShimBuilder<'tcx> {
let unwind = self.clone_fields(dest, src, switch, unwind, args.upvar_tys());
let target = self.block(vec![], TerminatorKind::Return, false);
let unreachable = self.block(vec![], TerminatorKind::Unreachable, false);
- let mut cases = Vec::with_capacity(args.state_tys(gen_def_id, self.tcx).count());
- for (index, state_tys) in args.state_tys(gen_def_id, self.tcx).enumerate() {
+ let mut cases = Vec::with_capacity(args.state_tys(coroutine_def_id, self.tcx).count());
+ for (index, state_tys) in args.state_tys(coroutine_def_id, self.tcx).enumerate() {
let variant_index = VariantIdx::new(index);
let dest = self.tcx.mk_place_downcast_unnamed(dest, variant_index);
let src = self.tcx.mk_place_downcast_unnamed(src, variant_index);
diff --git a/compiler/rustc_mir_transform/src/simplify.rs b/compiler/rustc_mir_transform/src/simplify.rs
index 2795cf157..0a1c01114 100644
--- a/compiler/rustc_mir_transform/src/simplify.rs
+++ b/compiler/rustc_mir_transform/src/simplify.rs
@@ -28,10 +28,8 @@
//! return.
use crate::MirPass;
-use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
-use rustc_index::bit_set::BitSet;
+use rustc_data_structures::fx::FxIndexSet;
use rustc_index::{Idx, IndexSlice, IndexVec};
-use rustc_middle::mir::coverage::*;
use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;
@@ -68,7 +66,7 @@ impl SimplifyCfg {
pub fn simplify_cfg<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
CfgSimplifier::new(body).simplify();
remove_duplicate_unreachable_blocks(tcx, body);
- remove_dead_blocks(tcx, body);
+ remove_dead_blocks(body);
// FIXME: Should probably be moved into some kind of pass manager
body.basic_blocks_mut().raw.shrink_to_fit();
@@ -337,7 +335,7 @@ pub fn remove_duplicate_unreachable_blocks<'tcx>(tcx: TyCtxt<'tcx>, body: &mut B
}
}
-pub fn remove_dead_blocks<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+pub fn remove_dead_blocks(body: &mut Body<'_>) {
let reachable = traversal::reachable_as_bitset(body);
let num_blocks = body.basic_blocks.len();
if num_blocks == reachable.count() {
@@ -345,10 +343,6 @@ pub fn remove_dead_blocks<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
}
let basic_blocks = body.basic_blocks.as_mut();
- let source_scopes = &body.source_scopes;
- if tcx.sess.instrument_coverage() {
- save_unreachable_coverage(basic_blocks, source_scopes, &reachable);
- }
let mut replacements: Vec<_> = (0..num_blocks).map(BasicBlock::new).collect();
let mut orig_index = 0;
@@ -370,99 +364,9 @@ pub fn remove_dead_blocks<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
}
}
-/// Some MIR transforms can determine at compile time that a sequences of
-/// statements will never be executed, so they can be dropped from the MIR.
-/// For example, an `if` or `else` block that is guaranteed to never be executed
-/// because its condition can be evaluated at compile time, such as by const
-/// evaluation: `if false { ... }`.
-///
-/// Those statements are bypassed by redirecting paths in the CFG around the
-/// `dead blocks`; but with `-C instrument-coverage`, the dead blocks usually
-/// include `Coverage` statements representing the Rust source code regions to
-/// be counted at runtime. Without these `Coverage` statements, the regions are
-/// lost, and the Rust source code will show no coverage information.
-///
-/// What we want to show in a coverage report is the dead code with coverage
-/// counts of `0`. To do this, we need to save the code regions, by injecting
-/// `Unreachable` coverage statements. These are non-executable statements whose
-/// code regions are still recorded in the coverage map, representing regions
-/// with `0` executions.
-///
-/// If there are no live `Counter` `Coverage` statements remaining, we remove
-/// `Coverage` statements along with the dead blocks. Since at least one
-/// counter per function is required by LLVM (and necessary, to add the
-/// `function_hash` to the counter's call to the LLVM intrinsic
-/// `instrprof.increment()`).
-///
-/// The `generator::StateTransform` MIR pass and MIR inlining can create
-/// atypical conditions, where all live `Counter`s are dropped from the MIR.
-///
-/// With MIR inlining we can have coverage counters belonging to different
-/// instances in a single body, so the strategy described above is applied to
-/// coverage counters from each instance individually.
-fn save_unreachable_coverage(
- basic_blocks: &mut IndexSlice<BasicBlock, BasicBlockData<'_>>,
- source_scopes: &IndexSlice<SourceScope, SourceScopeData<'_>>,
- reachable: &BitSet<BasicBlock>,
-) {
- // Identify instances that still have some live coverage counters left.
- let mut live = FxHashSet::default();
- for bb in reachable.iter() {
- let basic_block = &basic_blocks[bb];
- for statement in &basic_block.statements {
- let StatementKind::Coverage(coverage) = &statement.kind else { continue };
- let CoverageKind::Counter { .. } = coverage.kind else { continue };
- let instance = statement.source_info.scope.inlined_instance(source_scopes);
- live.insert(instance);
- }
- }
-
- for bb in reachable.iter() {
- let block = &mut basic_blocks[bb];
- for statement in &mut block.statements {
- let StatementKind::Coverage(_) = &statement.kind else { continue };
- let instance = statement.source_info.scope.inlined_instance(source_scopes);
- if !live.contains(&instance) {
- statement.make_nop();
- }
- }
- }
-
- if live.is_empty() {
- return;
- }
-
- // Retain coverage for instances that still have some live counters left.
- let mut retained_coverage = Vec::new();
- for dead_block in basic_blocks.indices() {
- if reachable.contains(dead_block) {
- continue;
- }
- let dead_block = &basic_blocks[dead_block];
- for statement in &dead_block.statements {
- let StatementKind::Coverage(coverage) = &statement.kind else { continue };
- let Some(code_region) = &coverage.code_region else { continue };
- let instance = statement.source_info.scope.inlined_instance(source_scopes);
- if live.contains(&instance) {
- retained_coverage.push((statement.source_info, code_region.clone()));
- }
- }
- }
-
- let start_block = &mut basic_blocks[START_BLOCK];
- start_block.statements.extend(retained_coverage.into_iter().map(
- |(source_info, code_region)| Statement {
- source_info,
- kind: StatementKind::Coverage(Box::new(Coverage {
- kind: CoverageKind::Unreachable,
- code_region: Some(code_region),
- })),
- },
- ));
-}
-
pub enum SimplifyLocals {
BeforeConstProp,
+ AfterGVN,
Final,
}
@@ -470,6 +374,7 @@ impl<'tcx> MirPass<'tcx> for SimplifyLocals {
fn name(&self) -> &'static str {
match &self {
SimplifyLocals::BeforeConstProp => "SimplifyLocals-before-const-prop",
+ SimplifyLocals::AfterGVN => "SimplifyLocals-after-value-numbering",
SimplifyLocals::Final => "SimplifyLocals-final",
}
}
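
For reference, a standalone sketch of the block compaction that `remove_dead_blocks` performs, using toy `usize` indices and adjacency lists in place of `BasicBlock` and a MIR body (the real function also rewrites terminators through a visitor and returns early when nothing is dead): reachable blocks are slid to the front, a replacement map records where each old index moved, and the surviving edges are patched through that map.

    fn compact_reachable_blocks(succs: &mut Vec<Vec<usize>>, reachable: &[bool]) {
        let num_blocks = succs.len();
        // replacements[old] = new index of the block, if the block is kept.
        let mut replacements: Vec<usize> = (0..num_blocks).collect();
        let mut used = 0;
        for old in 0..num_blocks {
            if reachable[old] {
                replacements[old] = used;
                succs.swap(used, old);
                used += 1;
            }
        }
        succs.truncate(used);
        // Patch every remaining edge to point at the moved blocks.
        for ss in succs.iter_mut() {
            for s in ss.iter_mut() {
                *s = replacements[*s];
            }
        }
    }

    fn main() {
        // Block 1 is dead; blocks 0 and 2 survive and are renumbered to 0 and 1.
        let mut cfg = vec![vec![2], vec![2], vec![]];
        compact_reachable_blocks(&mut cfg, &[true, false, true]);
        assert_eq!(cfg, vec![vec![1], vec![]]);
    }
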
diff --git a/compiler/rustc_mir_transform/src/simplify_branches.rs b/compiler/rustc_mir_transform/src/simplify_branches.rs
index b508cd1c9..1f0e605c3 100644
--- a/compiler/rustc_mir_transform/src/simplify_branches.rs
+++ b/compiler/rustc_mir_transform/src/simplify_branches.rs
@@ -16,8 +16,25 @@ impl<'tcx> MirPass<'tcx> for SimplifyConstCondition {
}
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ trace!("Running SimplifyConstCondition on {:?}", body.source);
let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
- for block in body.basic_blocks_mut() {
+ 'blocks: for block in body.basic_blocks_mut() {
+ for stmt in block.statements.iter_mut() {
+ if let StatementKind::Intrinsic(box ref intrinsic) = stmt.kind
+ && let NonDivergingIntrinsic::Assume(discr) = intrinsic
+ && let Operand::Constant(ref c) = discr
+ && let Some(constant) = c.const_.try_eval_bool(tcx, param_env)
+ {
+ if constant {
+ stmt.make_nop();
+ } else {
+ block.statements.clear();
+ block.terminator_mut().kind = TerminatorKind::Unreachable;
+ continue 'blocks;
+ }
+ }
+ }
+
let terminator = block.terminator_mut();
terminator.kind = match terminator.kind {
TerminatorKind::SwitchInt {
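
To spell out the new statement handling: a constant `assume(true)` carries no information and is turned into a nop, while a constant `assume(false)` proves the block is never reached, so its statements are cleared and its terminator becomes `Unreachable`. A toy model of that rule follows (illustrative types only, not rustc's MIR; the real pass nops the vacuous statement in place rather than removing it).

    #[derive(Debug, PartialEq)]
    enum Stmt {
        Assume(bool),
        Other(&'static str),
    }

    #[derive(Debug, PartialEq)]
    enum Term {
        Goto,
        Unreachable,
    }

    fn simplify_const_assumes(stmts: &mut Vec<Stmt>, term: &mut Term) {
        if stmts.iter().any(|s| *s == Stmt::Assume(false)) {
            // `assume(false)` means control never reaches this block.
            stmts.clear();
            *term = Term::Unreachable;
            return;
        }
        // `assume(true)` is vacuous; drop it (the MIR pass replaces it with a nop).
        stmts.retain(|s| *s != Stmt::Assume(true));
    }

    fn main() {
        let mut stmts = vec![Stmt::Other("a"), Stmt::Assume(true), Stmt::Other("b")];
        let mut term = Term::Goto;
        simplify_const_assumes(&mut stmts, &mut term);
        assert_eq!(stmts, vec![Stmt::Other("a"), Stmt::Other("b")]);
        assert_eq!(term, Term::Goto);

        let mut stmts = vec![Stmt::Other("a"), Stmt::Assume(false)];
        simplify_const_assumes(&mut stmts, &mut term);
        assert!(stmts.is_empty());
        assert_eq!(term, Term::Unreachable);
    }
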
diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs
index c21b1724c..7de4ca667 100644
--- a/compiler/rustc_mir_transform/src/sroa.rs
+++ b/compiler/rustc_mir_transform/src/sroa.rs
@@ -7,7 +7,7 @@ use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
-use rustc_target::abi::{FieldIdx, ReprFlags, FIRST_VARIANT};
+use rustc_target::abi::{FieldIdx, FIRST_VARIANT};
pub struct ScalarReplacementOfAggregates;
@@ -20,8 +20,8 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
debug!(def_id = ?body.source.def_id());
- // Avoid query cycles (generators require optimized MIR for layout).
- if tcx.type_of(body.source.def_id()).instantiate_identity().is_generator() {
+ // Avoid query cycles (coroutines require optimized MIR for layout).
+ if tcx.type_of(body.source.def_id()).instantiate_identity().is_coroutine() {
return;
}
@@ -66,7 +66,7 @@ fn escaping_locals<'tcx>(
return true;
}
if let ty::Adt(def, _args) = ty.kind() {
- if def.repr().flags.contains(ReprFlags::IS_SIMD) {
+ if def.repr().simd() {
// Exclude #[repr(simd)] types so that they are not de-optimized into an array
return true;
}
diff --git a/compiler/rustc_mir_transform/src/ssa.rs b/compiler/rustc_mir_transform/src/ssa.rs
index 43fc1b7b9..1f59c790b 100644
--- a/compiler/rustc_mir_transform/src/ssa.rs
+++ b/compiler/rustc_mir_transform/src/ssa.rs
@@ -5,7 +5,6 @@
//! As a consequence of rule 2, we consider that borrowed locals are not SSA, even if they are
//! `Freeze`, as we do not track that the assignment dominates all uses of the borrow.
-use either::Either;
use rustc_data_structures::graph::dominators::Dominators;
use rustc_index::bit_set::BitSet;
use rustc_index::{IndexSlice, IndexVec};
@@ -15,7 +14,7 @@ use rustc_middle::mir::*;
pub struct SsaLocals {
/// Assignments to each local. This defines whether the local is SSA.
- assignments: IndexVec<Local, Set1<LocationExtended>>,
+ assignments: IndexVec<Local, Set1<DefLocation>>,
/// We visit the body in reverse postorder, to ensure each local is assigned before it is used.
/// We remember the order in which we saw the assignments to compute the SSA values in a single
/// pass.
@@ -27,39 +26,10 @@ pub struct SsaLocals {
direct_uses: IndexVec<Local, u32>,
}
-/// We often encounter MIR bodies with 1 or 2 basic blocks. In those cases, it's unnecessary to
-/// actually compute dominators, we can just compare block indices because bb0 is always the first
-/// block, and in any body all other blocks are always dominated by bb0.
-struct SmallDominators<'a> {
- inner: Option<&'a Dominators<BasicBlock>>,
-}
-
-impl SmallDominators<'_> {
- fn dominates(&self, first: Location, second: Location) -> bool {
- if first.block == second.block {
- first.statement_index <= second.statement_index
- } else if let Some(inner) = &self.inner {
- inner.dominates(first.block, second.block)
- } else {
- first.block < second.block
- }
- }
-
- fn check_dominates(&mut self, set: &mut Set1<LocationExtended>, loc: Location) {
- let assign_dominates = match *set {
- Set1::Empty | Set1::Many => false,
- Set1::One(LocationExtended::Arg) => true,
- Set1::One(LocationExtended::Plain(assign)) => {
- self.dominates(assign.successor_within_block(), loc)
- }
- };
- // We are visiting a use that is not dominated by an assignment.
- // Either there is a cycle involved, or we are reading for uninitialized local.
- // Bail out.
- if !assign_dominates {
- *set = Set1::Many;
- }
- }
+pub enum AssignedValue<'a, 'tcx> {
+ Arg,
+ Rvalue(&'a mut Rvalue<'tcx>),
+ Terminator(&'a mut TerminatorKind<'tcx>),
}
impl SsaLocals {
@@ -67,15 +37,14 @@ impl SsaLocals {
let assignment_order = Vec::with_capacity(body.local_decls.len());
let assignments = IndexVec::from_elem(Set1::Empty, &body.local_decls);
- let dominators =
- if body.basic_blocks.len() > 2 { Some(body.basic_blocks.dominators()) } else { None };
- let dominators = SmallDominators { inner: dominators };
+ let dominators = body.basic_blocks.dominators();
let direct_uses = IndexVec::from_elem(0, &body.local_decls);
let mut visitor = SsaVisitor { assignments, assignment_order, dominators, direct_uses };
for local in body.args_iter() {
- visitor.assignments[local] = Set1::One(LocationExtended::Arg);
+ visitor.assignments[local] = Set1::One(DefLocation::Argument);
+ visitor.assignment_order.push(local);
}
// For SSA assignments, a RPO visit will see the assignment before it sees any use.
@@ -131,14 +100,7 @@ impl SsaLocals {
location: Location,
) -> bool {
match self.assignments[local] {
- Set1::One(LocationExtended::Arg) => true,
- Set1::One(LocationExtended::Plain(ass)) => {
- if ass.block == location.block {
- ass.statement_index < location.statement_index
- } else {
- dominators.dominates(ass.block, location.block)
- }
- }
+ Set1::One(def) => def.dominates(location, dominators),
_ => false,
}
}
@@ -148,9 +110,9 @@ impl SsaLocals {
body: &'a Body<'tcx>,
) -> impl Iterator<Item = (Local, &'a Rvalue<'tcx>, Location)> + 'a {
self.assignment_order.iter().filter_map(|&local| {
- if let Set1::One(LocationExtended::Plain(loc)) = self.assignments[local] {
+ if let Set1::One(DefLocation::Body(loc)) = self.assignments[local] {
+ let stmt = body.stmt_at(loc).left()?;
// `loc` must point to a direct assignment to `local`.
- let Either::Left(stmt) = body.stmt_at(loc) else { bug!() };
let Some((target, rvalue)) = stmt.kind.as_assign() else { bug!() };
assert_eq!(target.as_local(), Some(local));
Some((local, rvalue, loc))
@@ -162,18 +124,33 @@ impl SsaLocals {
pub fn for_each_assignment_mut<'tcx>(
&self,
- basic_blocks: &mut BasicBlocks<'tcx>,
- mut f: impl FnMut(Local, &mut Rvalue<'tcx>, Location),
+ basic_blocks: &mut IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
+ mut f: impl FnMut(Local, AssignedValue<'_, 'tcx>, Location),
) {
for &local in &self.assignment_order {
- if let Set1::One(LocationExtended::Plain(loc)) = self.assignments[local] {
- // `loc` must point to a direct assignment to `local`.
- let bbs = basic_blocks.as_mut_preserves_cfg();
- let bb = &mut bbs[loc.block];
- let stmt = &mut bb.statements[loc.statement_index];
- let StatementKind::Assign(box (target, ref mut rvalue)) = stmt.kind else { bug!() };
- assert_eq!(target.as_local(), Some(local));
- f(local, rvalue, loc)
+ match self.assignments[local] {
+ Set1::One(DefLocation::Argument) => f(
+ local,
+ AssignedValue::Arg,
+ Location { block: START_BLOCK, statement_index: 0 },
+ ),
+ Set1::One(DefLocation::Body(loc)) => {
+ let bb = &mut basic_blocks[loc.block];
+ let value = if loc.statement_index < bb.statements.len() {
+ // `loc` must point to a direct assignment to `local`.
+ let stmt = &mut bb.statements[loc.statement_index];
+ let StatementKind::Assign(box (target, ref mut rvalue)) = stmt.kind else {
+ bug!()
+ };
+ assert_eq!(target.as_local(), Some(local));
+ AssignedValue::Rvalue(rvalue)
+ } else {
+ let term = bb.terminator_mut();
+ AssignedValue::Terminator(&mut term.kind)
+ };
+ f(local, value, loc)
+ }
+ _ => {}
}
}
}
@@ -224,19 +201,29 @@ impl SsaLocals {
}
}
-#[derive(Copy, Clone, Debug, PartialEq, Eq)]
-enum LocationExtended {
- Plain(Location),
- Arg,
-}
-
struct SsaVisitor<'a> {
- dominators: SmallDominators<'a>,
- assignments: IndexVec<Local, Set1<LocationExtended>>,
+ dominators: &'a Dominators<BasicBlock>,
+ assignments: IndexVec<Local, Set1<DefLocation>>,
assignment_order: Vec<Local>,
direct_uses: IndexVec<Local, u32>,
}
+impl SsaVisitor<'_> {
+ fn check_dominates(&mut self, local: Local, loc: Location) {
+ let set = &mut self.assignments[local];
+ let assign_dominates = match *set {
+ Set1::Empty | Set1::Many => false,
+ Set1::One(def) => def.dominates(loc, self.dominators),
+ };
+ // We are visiting a use that is not dominated by an assignment.
+ // Either there is a cycle involved, or we are reading an uninitialized local.
+ // Bail out.
+ if !assign_dominates {
+ *set = Set1::Many;
+ }
+ }
+}
+
impl<'tcx> Visitor<'tcx> for SsaVisitor<'_> {
fn visit_local(&mut self, local: Local, ctxt: PlaceContext, loc: Location) {
match ctxt {
@@ -254,7 +241,7 @@ impl<'tcx> Visitor<'tcx> for SsaVisitor<'_> {
self.assignments[local] = Set1::Many;
}
PlaceContext::NonMutatingUse(_) => {
- self.dominators.check_dominates(&mut self.assignments[local], loc);
+ self.check_dominates(local, loc);
self.direct_uses[local] += 1;
}
PlaceContext::NonUse(_) => {}
@@ -262,34 +249,34 @@ impl<'tcx> Visitor<'tcx> for SsaVisitor<'_> {
}
fn visit_place(&mut self, place: &Place<'tcx>, ctxt: PlaceContext, loc: Location) {
- if place.projection.first() == Some(&PlaceElem::Deref) {
- // Do not do anything for storage statements and debuginfo.
+ let location = match ctxt {
+ PlaceContext::MutatingUse(
+ MutatingUseContext::Store | MutatingUseContext::Call | MutatingUseContext::Yield,
+ ) => Some(DefLocation::Body(loc)),
+ _ => None,
+ };
+ if let Some(location) = location
+ && let Some(local) = place.as_local()
+ {
+ self.assignments[local].insert(location);
+ if let Set1::One(_) = self.assignments[local] {
+ // Only record if SSA-like, to avoid growing the vector needlessly.
+ self.assignment_order.push(local);
+ }
+ } else if place.projection.first() == Some(&PlaceElem::Deref) {
+ // Do not do anything for debuginfo.
if ctxt.is_use() {
// Only change the context if it is a real use, not a "use" in debuginfo.
let new_ctxt = PlaceContext::NonMutatingUse(NonMutatingUseContext::Copy);
self.visit_projection(place.as_ref(), new_ctxt, loc);
- self.dominators.check_dominates(&mut self.assignments[place.local], loc);
+ self.check_dominates(place.local, loc);
}
- return;
} else {
self.visit_projection(place.as_ref(), ctxt, loc);
self.visit_local(place.local, ctxt, loc);
}
}
-
- fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, loc: Location) {
- if let Some(local) = place.as_local() {
- self.assignments[local].insert(LocationExtended::Plain(loc));
- if let Set1::One(_) = self.assignments[local] {
- // Only record if SSA-like, to avoid growing the vector needlessly.
- self.assignment_order.push(local);
- }
- } else {
- self.visit_place(place, PlaceContext::MutatingUse(MutatingUseContext::Store), loc);
- }
- self.visit_rvalue(rvalue, loc);
- }
}
#[instrument(level = "trace", skip(ssa, body))]
@@ -356,7 +343,7 @@ fn compute_copy_classes(ssa: &mut SsaLocals, body: &Body<'_>) {
#[derive(Debug)]
pub(crate) struct StorageLiveLocals {
/// Set of "StorageLive" statements for each local.
- storage_live: IndexVec<Local, Set1<LocationExtended>>,
+ storage_live: IndexVec<Local, Set1<DefLocation>>,
}
impl StorageLiveLocals {
@@ -366,13 +353,13 @@ impl StorageLiveLocals {
) -> StorageLiveLocals {
let mut storage_live = IndexVec::from_elem(Set1::Empty, &body.local_decls);
for local in always_storage_live_locals.iter() {
- storage_live[local] = Set1::One(LocationExtended::Arg);
+ storage_live[local] = Set1::One(DefLocation::Argument);
}
for (block, bbdata) in body.basic_blocks.iter_enumerated() {
for (statement_index, statement) in bbdata.statements.iter().enumerate() {
if let StatementKind::StorageLive(local) = statement.kind {
storage_live[local]
- .insert(LocationExtended::Plain(Location { block, statement_index }));
+ .insert(DefLocation::Body(Location { block, statement_index }));
}
}
}
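
The `assignments` table above relies on `Set1` behaving as a small three-state lattice: no definition seen, exactly one definition, or several, with only the `One` state treated as SSA. Below is a toy re-implementation of just that behavior; it is not rustc's `Set1` from `rustc_mir_dataflow`, merely the semantics `SsaVisitor` depends on.

    #[derive(Debug, PartialEq)]
    enum Set1<T> {
        Empty,
        One(T),
        Many,
    }

    impl<T> Set1<T> {
        fn insert(&mut self, value: T) {
            *self = match self {
                // First definition: the local still looks SSA.
                Set1::Empty => Set1::One(value),
                // Any further definition (or an earlier `Many`) disqualifies it.
                _ => Set1::Many,
            };
        }
    }

    fn main() {
        // A local assigned exactly once stays `One`, i.e. SSA.
        let mut def = Set1::Empty;
        def.insert("bb0[1]");
        assert_eq!(def, Set1::One("bb0[1]"));

        // A second assignment degrades it to `Many`: not SSA.
        def.insert("bb2[0]");
        assert_eq!(def, Set1::Many);
    }
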
diff --git a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs
index 092bcb5c9..98f67e18a 100644
--- a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs
+++ b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs
@@ -3,8 +3,7 @@
use crate::MirPass;
use rustc_data_structures::fx::FxHashSet;
use rustc_middle::mir::{
- BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, SwitchTargets, Terminator,
- TerminatorKind,
+ BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, Terminator, TerminatorKind,
};
use rustc_middle::ty::layout::TyAndLayout;
use rustc_middle::ty::{Ty, TyCtxt};
@@ -30,18 +29,16 @@ fn get_switched_on_type<'tcx>(
let terminator = block_data.terminator();
// Only bother checking blocks which terminate by switching on a local.
- if let Some(local) = get_discriminant_local(&terminator.kind) {
- let stmt_before_term = (!block_data.statements.is_empty())
- .then(|| &block_data.statements[block_data.statements.len() - 1].kind);
-
- if let Some(StatementKind::Assign(box (l, Rvalue::Discriminant(place)))) = stmt_before_term
- {
- if l.as_local() == Some(local) {
- let ty = place.ty(body, tcx).ty;
- if ty.is_enum() {
- return Some(ty);
- }
- }
+ let local = get_discriminant_local(&terminator.kind)?;
+
+ let stmt_before_term = block_data.statements.last()?;
+
+ if let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind
+ && l.as_local() == Some(local)
+ {
+ let ty = place.ty(body, tcx).ty;
+ if ty.is_enum() {
+ return Some(ty);
}
}
@@ -72,28 +69,6 @@ fn variant_discriminants<'tcx>(
}
}
-/// Ensures that the `otherwise` branch leads to an unreachable bb, returning `None` if so and a new
-/// bb to use as the new target if not.
-fn ensure_otherwise_unreachable<'tcx>(
- body: &Body<'tcx>,
- targets: &SwitchTargets,
-) -> Option<BasicBlockData<'tcx>> {
- let otherwise = targets.otherwise();
- let bb = &body.basic_blocks[otherwise];
- if bb.terminator().kind == TerminatorKind::Unreachable
- && bb.statements.iter().all(|s| matches!(&s.kind, StatementKind::StorageDead(_)))
- {
- return None;
- }
-
- let mut new_block = BasicBlockData::new(Some(Terminator {
- source_info: bb.terminator().source_info,
- kind: TerminatorKind::Unreachable,
- }));
- new_block.is_cleanup = bb.is_cleanup;
- Some(new_block)
-}
-
impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
sess.mir_opt_level() > 0
@@ -102,13 +77,16 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
trace!("UninhabitedEnumBranching starting for {:?}", body.source);
- for bb in body.basic_blocks.indices() {
+ let mut removable_switchs = Vec::new();
+
+ for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
trace!("processing block {:?}", bb);
- let Some(discriminant_ty) = get_switched_on_type(&body.basic_blocks[bb], tcx, body)
- else {
+ if bb_data.is_cleanup {
continue;
- };
+ }
+
+ let Some(discriminant_ty) = get_switched_on_type(&bb_data, tcx, body) else { continue };
let layout = tcx.layout_of(
tcx.param_env_reveal_all_normalized(body.source.def_id()).and(discriminant_ty),
@@ -122,31 +100,38 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
trace!("allowed_variants = {:?}", allowed_variants);
- if let TerminatorKind::SwitchInt { targets, .. } =
- &mut body.basic_blocks_mut()[bb].terminator_mut().kind
- {
- let mut new_targets = SwitchTargets::new(
- targets.iter().filter(|(val, _)| allowed_variants.contains(val)),
- targets.otherwise(),
- );
-
- if new_targets.iter().count() == allowed_variants.len() {
- if let Some(updated) = ensure_otherwise_unreachable(body, &new_targets) {
- let new_otherwise = body.basic_blocks_mut().push(updated);
- *new_targets.all_targets_mut().last_mut().unwrap() = new_otherwise;
- }
- }
+ let terminator = bb_data.terminator();
+ let TerminatorKind::SwitchInt { targets, .. } = &terminator.kind else { bug!() };
- if let TerminatorKind::SwitchInt { targets, .. } =
- &mut body.basic_blocks_mut()[bb].terminator_mut().kind
- {
- *targets = new_targets;
+ let mut reachable_count = 0;
+ for (index, (val, _)) in targets.iter().enumerate() {
+ if allowed_variants.contains(&val) {
+ reachable_count += 1;
} else {
- unreachable!()
+ removable_switchs.push((bb, index));
}
- } else {
- unreachable!()
}
+
+ if reachable_count == allowed_variants.len() {
+ removable_switchs.push((bb, targets.iter().count()));
+ }
+ }
+
+ if removable_switchs.is_empty() {
+ return;
+ }
+
+ let new_block = BasicBlockData::new(Some(Terminator {
+ source_info: body.basic_blocks[removable_switchs[0].0].terminator().source_info,
+ kind: TerminatorKind::Unreachable,
+ }));
+ let unreachable_block = body.basic_blocks.as_mut().push(new_block);
+
+ for (bb, index) in removable_switchs {
+ let bb = &mut body.basic_blocks.as_mut()[bb];
+ let terminator = bb.terminator_mut();
+ let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind else { bug!() };
+ targets.all_targets_mut()[index] = unreachable_block;
}
}
}
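
A source-level picture of what the rewritten pass prunes (illustrative program; whether an arm is actually dropped depends on `layout_of` reporting the variant as uninhabited): a variant whose payload type has no values can never be constructed, so the `SwitchInt` target for its discriminant can be redirected to the shared `Unreachable` block created above.

    enum Never {}

    enum Packet {
        Data(u32),
        // `Never` has no values, so this variant can never be built.
        Corrupt(Never),
    }

    fn payload(p: Packet) -> u32 {
        match p {
            Packet::Data(x) => x,
            // Statically unreachable: the empty match on `n` proves it to the
            // type system, and the MIR switch target for this arm can point at
            // an `Unreachable` block.
            Packet::Corrupt(n) => match n {},
        }
    }

    fn main() {
        assert_eq!(payload(Packet::Data(7)), 7);
    }
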
diff --git a/compiler/rustc_mir_transform/src/unreachable_prop.rs b/compiler/rustc_mir_transform/src/unreachable_prop.rs
index 0b9311a20..919e8d6a2 100644
--- a/compiler/rustc_mir_transform/src/unreachable_prop.rs
+++ b/compiler/rustc_mir_transform/src/unreachable_prop.rs
@@ -2,11 +2,13 @@
//! when all of their successors are unreachable. This is achieved through a
//! post-order traversal of the blocks.
-use crate::simplify;
use crate::MirPass;
-use rustc_data_structures::fx::{FxHashMap, FxHashSet};
+use rustc_data_structures::fx::FxHashSet;
+use rustc_middle::mir::interpret::Scalar;
+use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::*;
-use rustc_middle::ty::TyCtxt;
+use rustc_middle::ty::{self, TyCtxt};
+use rustc_target::abi::Size;
pub struct UnreachablePropagation;
@@ -21,106 +23,133 @@ impl MirPass<'_> for UnreachablePropagation {
}
fn run_pass<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ let mut patch = MirPatch::new(body);
let mut unreachable_blocks = FxHashSet::default();
- let mut replacements = FxHashMap::default();
for (bb, bb_data) in traversal::postorder(body) {
let terminator = bb_data.terminator();
- if terminator.kind == TerminatorKind::Unreachable {
- unreachable_blocks.insert(bb);
- } else {
- let is_unreachable = |succ: BasicBlock| unreachable_blocks.contains(&succ);
- let terminator_kind_opt = remove_successors(&terminator.kind, is_unreachable);
-
- if let Some(terminator_kind) = terminator_kind_opt {
- if terminator_kind == TerminatorKind::Unreachable {
- unreachable_blocks.insert(bb);
- }
- replacements.insert(bb, terminator_kind);
+ let is_unreachable = match &terminator.kind {
+ TerminatorKind::Unreachable => true,
+ // This will unconditionally run into an unreachable and is therefore unreachable as well.
+ TerminatorKind::Goto { target } if unreachable_blocks.contains(target) => {
+ patch.patch_terminator(bb, TerminatorKind::Unreachable);
+ true
+ }
+ // Try to remove unreachable targets from the switch.
+ TerminatorKind::SwitchInt { .. } => {
+ remove_successors_from_switch(tcx, bb, &unreachable_blocks, body, &mut patch)
}
+ _ => false,
+ };
+ if is_unreachable {
+ unreachable_blocks.insert(bb);
}
}
+ if !tcx
+ .consider_optimizing(|| format!("UnreachablePropagation {:?} ", body.source.def_id()))
+ {
+ return;
+ }
+
+ patch.apply(body);
+
// We do want do keep some unreachable blocks, but make them empty.
for bb in unreachable_blocks {
- if !tcx.consider_optimizing(|| {
- format!("UnreachablePropagation {:?} ", body.source.def_id())
- }) {
- break;
- }
-
body.basic_blocks_mut()[bb].statements.clear();
}
+ }
+}
- let replaced = !replacements.is_empty();
+/// Return whether the current terminator is fully unreachable.
+fn remove_successors_from_switch<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ bb: BasicBlock,
+ unreachable_blocks: &FxHashSet<BasicBlock>,
+ body: &Body<'tcx>,
+ patch: &mut MirPatch<'tcx>,
+) -> bool {
+ let terminator = body.basic_blocks[bb].terminator();
+ let TerminatorKind::SwitchInt { discr, targets } = &terminator.kind else { bug!() };
+ let source_info = terminator.source_info;
+ let location = body.terminator_loc(bb);
+
+ let is_unreachable = |bb| unreachable_blocks.contains(&bb);
+
+ // If there are multiple targets, we want to keep information about reachability for codegen.
+ // For example (see tests/codegen/match-optimizes-away.rs)
+ //
+ // pub enum Two { A, B }
+ // pub fn identity(x: Two) -> Two {
+ // match x {
+ // Two::A => Two::A,
+ // Two::B => Two::B,
+ // }
+ // }
+ //
+ // This generates a `switchInt() -> [0: 0, 1: 1, otherwise: unreachable]`, which allows us or LLVM to
+ // turn it into just `x` later. Without the unreachable, such a transformation would be illegal.
+ //
+ // In order to preserve this information, we record reachable and unreachable targets as
+ // `Assume` statements in MIR.
+
+ let discr_ty = discr.ty(body, tcx);
+ let discr_size = Size::from_bits(match discr_ty.kind() {
+ ty::Uint(uint) => uint.normalize(tcx.sess.target.pointer_width).bit_width().unwrap(),
+ ty::Int(int) => int.normalize(tcx.sess.target.pointer_width).bit_width().unwrap(),
+ ty::Char => 32,
+ ty::Bool => 1,
+ other => bug!("unhandled type: {:?}", other),
+ });
+
+ let mut add_assumption = |binop, value| {
+ let local = patch.new_temp(tcx.types.bool, source_info.span);
+ let value = Operand::Constant(Box::new(ConstOperand {
+ span: source_info.span,
+ user_ty: None,
+ const_: Const::from_scalar(tcx, Scalar::from_uint(value, discr_size), discr_ty),
+ }));
+ let cmp = Rvalue::BinaryOp(binop, Box::new((discr.to_copy(), value)));
+ patch.add_assign(location, local.into(), cmp);
+
+ let assume = NonDivergingIntrinsic::Assume(Operand::Move(local.into()));
+ patch.add_statement(location, StatementKind::Intrinsic(Box::new(assume)));
+ };
- for (bb, terminator_kind) in replacements {
- if !tcx.consider_optimizing(|| {
- format!("UnreachablePropagation {:?} ", body.source.def_id())
- }) {
- break;
- }
+ let otherwise = targets.otherwise();
+ let otherwise_unreachable = is_unreachable(otherwise);
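+ // If `otherwise` is itself unreachable, a removed target simply falls through to it and
+ // still ends in an unreachable block, so no `Assume` is needed for it.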
- body.basic_blocks_mut()[bb].terminator_mut().kind = terminator_kind;
+ let reachable_iter = targets.iter().filter(|&(value, bb)| {
+ let is_unreachable = is_unreachable(bb);
+ // We remove this target from the switch, so record the inequality using `Assume`.
+ if is_unreachable && !otherwise_unreachable {
+ add_assumption(BinOp::Ne, value);
}
-
- if replaced {
- simplify::remove_dead_blocks(tcx, body);
+ !is_unreachable
+ });
+
+ let new_targets = SwitchTargets::new(reachable_iter, otherwise);
+
+ let num_targets = new_targets.all_targets().len();
+ let fully_unreachable = num_targets == 1 && otherwise_unreachable;
+
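+ // Choose the replacement terminator based on how many targets survive and whether
+ // `otherwise` is unreachable.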
+ let terminator = match (num_targets, otherwise_unreachable) {
+ // If all targets are unreachable, we can be unreachable as well.
+ (1, true) => TerminatorKind::Unreachable,
+ (1, false) => TerminatorKind::Goto { target: otherwise },
+ (2, true) => {
+ // All targets are unreachable except one. Record the equality, and make it a goto.
+ let (value, target) = new_targets.iter().next().unwrap();
+ add_assumption(BinOp::Eq, value);
+ TerminatorKind::Goto { target }
}
- }
-}
-
-fn remove_successors<'tcx, F>(
- terminator_kind: &TerminatorKind<'tcx>,
- is_unreachable: F,
-) -> Option<TerminatorKind<'tcx>>
-where
- F: Fn(BasicBlock) -> bool,
-{
- let terminator = match terminator_kind {
- // This will unconditionally run into an unreachable and is therefore unreachable as well.
- TerminatorKind::Goto { target } if is_unreachable(*target) => TerminatorKind::Unreachable,
- TerminatorKind::SwitchInt { targets, discr } => {
- let otherwise = targets.otherwise();
-
- // If all targets are unreachable, we can be unreachable as well.
- if targets.all_targets().iter().all(|bb| is_unreachable(*bb)) {
- TerminatorKind::Unreachable
- } else if is_unreachable(otherwise) {
- // If there are multiple targets, don't delete unreachable branches (like an unreachable otherwise)
- // unless otherwise is unreachable, in which case deleting a normal branch causes it to be merged with
- // the otherwise, keeping its unreachable.
- // This looses information about reachability causing worse codegen.
- // For example (see tests/codegen/match-optimizes-away.rs)
- //
- // pub enum Two { A, B }
- // pub fn identity(x: Two) -> Two {
- // match x {
- // Two::A => Two::A,
- // Two::B => Two::B,
- // }
- // }
- //
- // This generates a `switchInt() -> [0: 0, 1: 1, otherwise: unreachable]`, which allows us or LLVM to
- // turn it into just `x` later. Without the unreachable, such a transformation would be illegal.
- // If the otherwise branch is unreachable, we can delete all other unreachable targets, as they will
- // still point to the unreachable and therefore not lose reachability information.
- let reachable_iter = targets.iter().filter(|(_, bb)| !is_unreachable(*bb));
-
- let new_targets = SwitchTargets::new(reachable_iter, otherwise);
-
- // No unreachable branches were removed.
- if new_targets.all_targets().len() == targets.all_targets().len() {
- return None;
- }
-
- TerminatorKind::SwitchInt { discr: discr.clone(), targets: new_targets }
- } else {
- // If the otherwise branch is reachable, we don't want to delete any unreachable branches.
- return None;
- }
+ _ if num_targets == targets.all_targets().len() => {
+ // Nothing has changed.
+ return false;
}
- _ => return None,
+ _ => TerminatorKind::SwitchInt { discr: discr.clone(), targets: new_targets },
};
- Some(terminator)
+
+ patch.patch_terminator(bb, terminator);
+ fully_unreachable
}