diff options
Diffstat (limited to 'compiler/rustc_mir_transform/src')
24 files changed, 1782 insertions, 1053 deletions
diff --git a/compiler/rustc_mir_transform/src/add_moves_for_packed_drops.rs b/compiler/rustc_mir_transform/src/add_moves_for_packed_drops.rs index ffb5d8c6d..9b2260f68 100644 --- a/compiler/rustc_mir_transform/src/add_moves_for_packed_drops.rs +++ b/compiler/rustc_mir_transform/src/add_moves_for_packed_drops.rs @@ -5,37 +5,36 @@ use crate::util; use crate::MirPass; use rustc_middle::mir::patch::MirPatch; -// This pass moves values being dropped that are within a packed -// struct to a separate local before dropping them, to ensure that -// they are dropped from an aligned address. -// -// For example, if we have something like -// ```Rust -// #[repr(packed)] -// struct Foo { -// dealign: u8, -// data: Vec<u8> -// } -// -// let foo = ...; -// ``` -// -// We want to call `drop_in_place::<Vec<u8>>` on `data` from an aligned -// address. This means we can't simply drop `foo.data` directly, because -// its address is not aligned. -// -// Instead, we move `foo.data` to a local and drop that: -// ``` -// storage.live(drop_temp) -// drop_temp = foo.data; -// drop(drop_temp) -> next -// next: -// storage.dead(drop_temp) -// ``` -// -// The storage instructions are required to avoid stack space -// blowup. - +/// This pass moves values being dropped that are within a packed +/// struct to a separate local before dropping them, to ensure that +/// they are dropped from an aligned address. +/// +/// For example, if we have something like +/// ```ignore (ilustrative) +/// #[repr(packed)] +/// struct Foo { +/// dealign: u8, +/// data: Vec<u8> +/// } +/// +/// let foo = ...; +/// ``` +/// +/// We want to call `drop_in_place::<Vec<u8>>` on `data` from an aligned +/// address. This means we can't simply drop `foo.data` directly, because +/// its address is not aligned. +/// +/// Instead, we move `foo.data` to a local and drop that: +/// ```ignore (ilustrative) +/// storage.live(drop_temp) +/// drop_temp = foo.data; +/// drop(drop_temp) -> next +/// next: +/// storage.dead(drop_temp) +/// ``` +/// +/// The storage instructions are required to avoid stack space +/// blowup. pub struct AddMovesForPackedDrops; impl<'tcx> MirPass<'tcx> for AddMovesForPackedDrops { diff --git a/compiler/rustc_mir_transform/src/add_retag.rs b/compiler/rustc_mir_transform/src/add_retag.rs index 036b55898..3d22035f0 100644 --- a/compiler/rustc_mir_transform/src/add_retag.rs +++ b/compiler/rustc_mir_transform/src/add_retag.rs @@ -10,16 +10,6 @@ use rustc_middle::ty::{self, Ty, TyCtxt}; pub struct AddRetag; -/// Determines whether this place is "stable": Whether, if we evaluate it again -/// after the assignment, we can be sure to obtain the same place value. -/// (Concurrent accesses by other threads are no problem as these are anyway non-atomic -/// copies. Data races are UB.) -fn is_stable(place: PlaceRef<'_>) -> bool { - // Which place this evaluates to can change with any memory write, - // so cannot assume deref to be stable. - !place.has_deref() -} - /// Determine whether this type may contain a reference (or box), and thus needs retagging. /// We will only recurse `depth` times into Tuples/ADTs to bound the cost of this. fn may_contain_reference<'tcx>(ty: Ty<'tcx>, depth: u32, tcx: TyCtxt<'tcx>) -> bool { @@ -69,22 +59,10 @@ impl<'tcx> MirPass<'tcx> for AddRetag { let basic_blocks = body.basic_blocks.as_mut(); let local_decls = &body.local_decls; let needs_retag = |place: &Place<'tcx>| { - // FIXME: Instead of giving up for unstable places, we should introduce - // a temporary and retag on that. - is_stable(place.as_ref()) + !place.has_deref() // we're not eally interested in stores to "outside" locations, they are hard to keep track of anyway && may_contain_reference(place.ty(&*local_decls, tcx).ty, /*depth*/ 3, tcx) && !local_decls[place.local].is_deref_temp() }; - let place_base_raw = |place: &Place<'tcx>| { - // If this is a `Deref`, get the type of what we are deref'ing. - if place.has_deref() { - let ty = &local_decls[place.local].ty; - ty.is_unsafe_ptr() - } else { - // Not a deref, and thus not raw. - false - } - }; // PART 1 // Retag arguments at the beginning of the start block. @@ -108,7 +86,7 @@ impl<'tcx> MirPass<'tcx> for AddRetag { } // PART 2 - // Retag return values of functions. Also escape-to-raw the argument of `drop`. + // Retag return values of functions. // We collect the return destinations because we cannot mutate while iterating. let returns = basic_blocks .iter_mut() @@ -140,30 +118,25 @@ impl<'tcx> MirPass<'tcx> for AddRetag { } // PART 3 - // Add retag after assignment. + // Add retag after assignments where data "enters" this function: the RHS is behind a deref and the LHS is not. for block_data in basic_blocks { // We want to insert statements as we iterate. To this end, we // iterate backwards using indices. for i in (0..block_data.statements.len()).rev() { let (retag_kind, place) = match block_data.statements[i].kind { - // Retag-as-raw after escaping to a raw pointer, if the referent - // is not already a raw pointer. - StatementKind::Assign(box (lplace, Rvalue::AddressOf(_, ref rplace))) - if !place_base_raw(rplace) => - { - (RetagKind::Raw, lplace) - } // Retag after assignments of reference type. StatementKind::Assign(box (ref place, ref rvalue)) if needs_retag(place) => { - let kind = match rvalue { - Rvalue::Ref(_, borrow_kind, _) - if borrow_kind.allows_two_phase_borrow() => - { - RetagKind::TwoPhase - } - _ => RetagKind::Default, + let add_retag = match rvalue { + // Ptr-creating operations already do their own internal retagging, no + // need to also add a retag statement. + Rvalue::Ref(..) | Rvalue::AddressOf(..) => false, + _ => true, }; - (kind, *place) + if add_retag { + (RetagKind::Default, *place) + } else { + continue; + } } // Do nothing for the rest _ => continue, diff --git a/compiler/rustc_mir_transform/src/check_unsafety.rs b/compiler/rustc_mir_transform/src/check_unsafety.rs index f8f04214a..e783d1891 100644 --- a/compiler/rustc_mir_transform/src/check_unsafety.rs +++ b/compiler/rustc_mir_transform/src/check_unsafety.rs @@ -4,6 +4,7 @@ use rustc_hir as hir; use rustc_hir::def_id::{DefId, LocalDefId}; use rustc_hir::hir_id::HirId; use rustc_hir::intravisit; +use rustc_hir::{BlockCheckMode, ExprKind, Node}; use rustc_middle::mir::visit::{MutatingUseContext, PlaceContext, Visitor}; use rustc_middle::mir::*; use rustc_middle::ty::query::Providers; @@ -101,12 +102,10 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> { | StatementKind::Retag { .. } | StatementKind::AscribeUserType(..) | StatementKind::Coverage(..) + | StatementKind::Intrinsic(..) | StatementKind::Nop => { // safe (at least as emitted during MIR construction) } - - // Move to above list once mir construction uses it. - StatementKind::Intrinsic(..) => unreachable!(), } self.super_statement(statement, location); } @@ -472,6 +471,14 @@ fn unsafety_check_result<'tcx>( // `mir_built` force this. let body = &tcx.mir_built(def).borrow(); + if body.should_skip() { + return tcx.arena.alloc(UnsafetyCheckResult { + violations: Vec::new(), + used_unsafe_blocks: FxHashSet::default(), + unused_unsafes: Some(Vec::new()), + }); + } + let param_env = tcx.param_env(def.did); let mut checker = UnsafetyChecker::new(body, def.did, tcx, param_env); @@ -519,24 +526,48 @@ pub fn check_unsafety(tcx: TyCtxt<'_>, def_id: LocalDefId) { for &UnsafetyViolation { source_info, lint_root, kind, details } in violations.iter() { let (description, note) = details.description_and_note(); - // Report an error. - let unsafe_fn_msg = - if unsafe_op_in_unsafe_fn_allowed(tcx, lint_root) { " function or" } else { "" }; - match kind { UnsafetyViolationKind::General => { // once - struct_span_err!( + let unsafe_fn_msg = if unsafe_op_in_unsafe_fn_allowed(tcx, lint_root) { + " function or" + } else { + "" + }; + + let mut err = struct_span_err!( tcx.sess, source_info.span, E0133, "{} is unsafe and requires unsafe{} block", description, unsafe_fn_msg, - ) - .span_label(source_info.span, description) - .note(note) - .emit(); + ); + err.span_label(source_info.span, description).note(note); + let note_non_inherited = tcx.hir().parent_iter(lint_root).find(|(id, node)| { + if let Node::Expr(block) = node + && let ExprKind::Block(block, _) = block.kind + && let BlockCheckMode::UnsafeBlock(_) = block.rules + { + true + } + else if let Some(sig) = tcx.hir().fn_sig_by_hir_id(*id) + && sig.header.is_unsafe() + { + true + } else { + false + } + }); + if let Some((id, _)) = note_non_inherited { + let span = tcx.hir().span(id); + err.span_label( + tcx.sess.source_map().guess_head_span(span), + "items do not inherit unsafety from separate enclosing items", + ); + } + + err.emit(); } UnsafetyViolationKind::UnsafeFn => tcx.struct_span_lint_hir( UNSAFE_OP_IN_UNSAFE_FN, diff --git a/compiler/rustc_mir_transform/src/const_prop.rs b/compiler/rustc_mir_transform/src/const_prop.rs index 4e4515888..b0514e033 100644 --- a/compiler/rustc_mir_transform/src/const_prop.rs +++ b/compiler/rustc_mir_transform/src/const_prop.rs @@ -3,6 +3,8 @@ use std::cell::Cell; +use either::Right; + use rustc_ast::Mutability; use rustc_data_structures::fx::FxHashSet; use rustc_hir::def::DefKind; @@ -266,7 +268,7 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx> _tcx: TyCtxt<'tcx>, _machine: &Self, _alloc_id: AllocId, - alloc: ConstAllocation<'tcx, Self::Provenance, Self::AllocExtra>, + alloc: ConstAllocation<'tcx>, _static_def_id: Option<DefId>, is_write: bool, ) -> InterpResult<'tcx> { @@ -385,7 +387,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { // I don't know how return types can seem to be unsized but this happens in the // `type/type-unsatisfiable.rs` test. .filter(|ret_layout| { - !ret_layout.is_unsized() && ret_layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) + ret_layout.is_sized() && ret_layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) }) .unwrap_or_else(|| ecx.layout_of(tcx.types.unit).unwrap()); @@ -429,7 +431,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { // Try to read the local as an immediate so that if it is representable as a scalar, we can // handle it as such, but otherwise, just return the value as is. Some(match self.ecx.read_immediate_raw(&op) { - Ok(Ok(imm)) => imm.into(), + Ok(Right(imm)) => imm.into(), _ => op, }) } @@ -471,7 +473,8 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { return None; } - self.ecx.const_to_op(&c.literal, None).ok() + // No span, we don't want errors to be shown. + self.ecx.eval_mir_constant(&c.literal, None, None).ok() } /// Returns the value, if any, of evaluating `place`. @@ -742,7 +745,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { // FIXME> figure out what to do when read_immediate_raw fails let imm = self.use_ecx(|this| this.ecx.read_immediate_raw(value)); - if let Some(Ok(imm)) = imm { + if let Some(Right(imm)) = imm { match *imm { interpret::Immediate::Scalar(scalar) => { *rval = Rvalue::Use(self.operand_from_scalar( diff --git a/compiler/rustc_mir_transform/src/const_prop_lint.rs b/compiler/rustc_mir_transform/src/const_prop_lint.rs index 479c4e577..0ab67228f 100644 --- a/compiler/rustc_mir_transform/src/const_prop_lint.rs +++ b/compiler/rustc_mir_transform/src/const_prop_lint.rs @@ -1,11 +1,10 @@ //! Propagates constants for early reporting of statically known //! assertion failures -use crate::const_prop::CanConstProp; -use crate::const_prop::ConstPropMachine; -use crate::const_prop::ConstPropMode; -use crate::MirLint; -use rustc_const_eval::const_eval::ConstEvalErr; +use std::cell::Cell; + +use either::{Left, Right}; + use rustc_const_eval::interpret::Immediate; use rustc_const_eval::interpret::{ self, InterpCx, InterpResult, LocalState, LocalValue, MemoryKind, OpTy, Scalar, StackPopCleanup, @@ -27,12 +26,17 @@ use rustc_session::lint; use rustc_span::Span; use rustc_target::abi::{HasDataLayout, Size, TargetDataLayout}; use rustc_trait_selection::traits; -use std::cell::Cell; + +use crate::const_prop::CanConstProp; +use crate::const_prop::ConstPropMachine; +use crate::const_prop::ConstPropMode; +use crate::MirLint; /// The maximum number of bytes that we'll allocate space for a local or the return value. /// Needed for #66397, because otherwise we eval into large places and that can cause OOM or just /// Severely regress performance. const MAX_ALLOC_LIMIT: u64 = 1024; + pub struct ConstProp; impl<'tcx> MirLint<'tcx> for ConstProp { @@ -199,7 +203,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { // I don't know how return types can seem to be unsized but this happens in the // `type/type-unsatisfiable.rs` test. .filter(|ret_layout| { - !ret_layout.is_unsized() && ret_layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) + ret_layout.is_sized() && ret_layout.size < Size::from_bytes(MAX_ALLOC_LIMIT) }) .unwrap_or_else(|| ecx.layout_of(tcx.types.unit).unwrap()); @@ -244,7 +248,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { // Try to read the local as an immediate so that if it is representable as a scalar, we can // handle it as such, but otherwise, just return the value as is. Some(match self.ecx.read_immediate_raw(&op) { - Ok(Ok(imm)) => imm.into(), + Ok(Left(imm)) => imm.into(), _ => op, }) } @@ -267,7 +271,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { F: FnOnce(&mut Self) -> InterpResult<'tcx, T>, { // Overwrite the PC -- whatever the interpreter does to it does not make any sense anyway. - self.ecx.frame_mut().loc = Err(source_info.span); + self.ecx.frame_mut().loc = Right(source_info.span); match f(self) { Ok(val) => Some(val), Err(error) => { @@ -286,25 +290,13 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { } /// Returns the value, if any, of evaluating `c`. - fn eval_constant( - &mut self, - c: &Constant<'tcx>, - _source_info: SourceInfo, - ) -> Option<OpTy<'tcx>> { + fn eval_constant(&mut self, c: &Constant<'tcx>, source_info: SourceInfo) -> Option<OpTy<'tcx>> { // FIXME we need to revisit this for #67176 if c.needs_subst() { return None; } - match self.ecx.const_to_op(&c.literal, None) { - Ok(op) => Some(op), - Err(error) => { - let tcx = self.ecx.tcx.at(c.span); - let err = ConstEvalErr::new(&self.ecx, error, Some(c.span)); - err.report_as_error(tcx, "erroneous constant used"); - None - } - } + self.use_ecx(source_info, |this| this.ecx.eval_mir_constant(&c.literal, Some(c.span), None)) } /// Returns the value, if any, of evaluating `place`. diff --git a/compiler/rustc_mir_transform/src/coverage/debug.rs b/compiler/rustc_mir_transform/src/coverage/debug.rs index 0f8679b0b..d6a298fad 100644 --- a/compiler/rustc_mir_transform/src/coverage/debug.rs +++ b/compiler/rustc_mir_transform/src/coverage/debug.rs @@ -638,7 +638,7 @@ pub(super) fn dump_coverage_spanview<'tcx>( let def_id = mir_source.def_id(); let span_viewables = span_viewables(tcx, mir_body, basic_coverage_blocks, &coverage_spans); - let mut file = create_dump_file(tcx, "html", None, pass_name, &0, mir_source) + let mut file = create_dump_file(tcx, "html", false, pass_name, &0, mir_body) .expect("Unexpected error creating MIR spanview HTML file"); let crate_name = tcx.crate_name(def_id.krate); let item_name = tcx.def_path(def_id).to_filename_friendly_no_crate(); @@ -739,7 +739,7 @@ pub(super) fn dump_coverage_graphviz<'tcx>( .join("\n ") )); } - let mut file = create_dump_file(tcx, "dot", None, pass_name, &0, mir_source) + let mut file = create_dump_file(tcx, "dot", false, pass_name, &0, mir_body) .expect("Unexpected error creating BasicCoverageBlock graphviz DOT file"); graphviz_writer .write_graphviz(tcx, &mut file) diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs new file mode 100644 index 000000000..e90273874 --- /dev/null +++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs @@ -0,0 +1,530 @@ +//! A constant propagation optimization pass based on dataflow analysis. +//! +//! Currently, this pass only propagates scalar values. + +use rustc_const_eval::interpret::{ConstValue, ImmTy, Immediate, InterpCx, Scalar}; +use rustc_data_structures::fx::FxHashMap; +use rustc_middle::mir::visit::{MutVisitor, Visitor}; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_mir_dataflow::value_analysis::{Map, State, TrackElem, ValueAnalysis, ValueOrPlace}; +use rustc_mir_dataflow::{lattice::FlatSet, Analysis, ResultsVisitor, SwitchIntEdgeEffects}; +use rustc_span::DUMMY_SP; + +use crate::MirPass; + +// These constants are somewhat random guesses and have not been optimized. +// If `tcx.sess.mir_opt_level() >= 4`, we ignore the limits (this can become very expensive). +const BLOCK_LIMIT: usize = 100; +const PLACE_LIMIT: usize = 100; + +pub struct DataflowConstProp; + +impl<'tcx> MirPass<'tcx> for DataflowConstProp { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 3 + } + + #[instrument(skip_all level = "debug")] + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + if tcx.sess.mir_opt_level() < 4 && body.basic_blocks.len() > BLOCK_LIMIT { + debug!("aborted dataflow const prop due too many basic blocks"); + return; + } + + // Decide which places to track during the analysis. + let map = Map::from_filter(tcx, body, Ty::is_scalar); + + // We want to have a somewhat linear runtime w.r.t. the number of statements/terminators. + // Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function + // applications, where `h` is the height of the lattice. Because the height of our lattice + // is linear w.r.t. the number of tracked places, this is `O(tracked_places * n)`. However, + // because every transfer function application could traverse the whole map, this becomes + // `O(num_nodes * tracked_places * n)` in terms of time complexity. Since the number of + // map nodes is strongly correlated to the number of tracked places, this becomes more or + // less `O(n)` if we place a constant limit on the number of tracked places. + if tcx.sess.mir_opt_level() < 4 && map.tracked_places() > PLACE_LIMIT { + debug!("aborted dataflow const prop due to too many tracked places"); + return; + } + + // Perform the actual dataflow analysis. + let analysis = ConstAnalysis::new(tcx, body, map); + let results = debug_span!("analyze") + .in_scope(|| analysis.wrap().into_engine(tcx, body).iterate_to_fixpoint()); + + // Collect results and patch the body afterwards. + let mut visitor = CollectAndPatch::new(tcx, &results.analysis.0.map); + debug_span!("collect").in_scope(|| results.visit_reachable_with(body, &mut visitor)); + debug_span!("patch").in_scope(|| visitor.visit_body(body)); + } +} + +struct ConstAnalysis<'tcx> { + map: Map, + tcx: TyCtxt<'tcx>, + ecx: InterpCx<'tcx, 'tcx, DummyMachine>, + param_env: ty::ParamEnv<'tcx>, +} + +impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> { + type Value = FlatSet<ScalarTy<'tcx>>; + + const NAME: &'static str = "ConstAnalysis"; + + fn map(&self) -> &Map { + &self.map + } + + fn handle_assign( + &self, + target: Place<'tcx>, + rvalue: &Rvalue<'tcx>, + state: &mut State<Self::Value>, + ) { + match rvalue { + Rvalue::CheckedBinaryOp(op, box (left, right)) => { + let target = self.map().find(target.as_ref()); + if let Some(target) = target { + // We should not track any projections other than + // what is overwritten below, but just in case... + state.flood_idx(target, self.map()); + } + + let value_target = target + .and_then(|target| self.map().apply(target, TrackElem::Field(0_u32.into()))); + let overflow_target = target + .and_then(|target| self.map().apply(target, TrackElem::Field(1_u32.into()))); + + if value_target.is_some() || overflow_target.is_some() { + let (val, overflow) = self.binary_op(state, *op, left, right); + + if let Some(value_target) = value_target { + state.assign_idx(value_target, ValueOrPlace::Value(val), self.map()); + } + if let Some(overflow_target) = overflow_target { + let overflow = match overflow { + FlatSet::Top => FlatSet::Top, + FlatSet::Elem(overflow) => { + if overflow { + // Overflow cannot be reliably propagated. See: https://github.com/rust-lang/rust/pull/101168#issuecomment-1288091446 + FlatSet::Top + } else { + self.wrap_scalar(Scalar::from_bool(false), self.tcx.types.bool) + } + } + FlatSet::Bottom => FlatSet::Bottom, + }; + state.assign_idx( + overflow_target, + ValueOrPlace::Value(overflow), + self.map(), + ); + } + } + } + _ => self.super_assign(target, rvalue, state), + } + } + + fn handle_rvalue( + &self, + rvalue: &Rvalue<'tcx>, + state: &mut State<Self::Value>, + ) -> ValueOrPlace<Self::Value> { + match rvalue { + Rvalue::Cast( + kind @ (CastKind::IntToInt + | CastKind::FloatToInt + | CastKind::FloatToFloat + | CastKind::IntToFloat), + operand, + ty, + ) => match self.eval_operand(operand, state) { + FlatSet::Elem(op) => match kind { + CastKind::IntToInt | CastKind::IntToFloat => { + self.ecx.int_to_int_or_float(&op, *ty) + } + CastKind::FloatToInt | CastKind::FloatToFloat => { + self.ecx.float_to_float_or_int(&op, *ty) + } + _ => unreachable!(), + } + .map(|result| ValueOrPlace::Value(self.wrap_immediate(result, *ty))) + .unwrap_or(ValueOrPlace::top()), + _ => ValueOrPlace::top(), + }, + Rvalue::BinaryOp(op, box (left, right)) => { + // Overflows must be ignored here. + let (val, _overflow) = self.binary_op(state, *op, left, right); + ValueOrPlace::Value(val) + } + Rvalue::UnaryOp(op, operand) => match self.eval_operand(operand, state) { + FlatSet::Elem(value) => self + .ecx + .unary_op(*op, &value) + .map(|val| ValueOrPlace::Value(self.wrap_immty(val))) + .unwrap_or(ValueOrPlace::Value(FlatSet::Top)), + FlatSet::Bottom => ValueOrPlace::Value(FlatSet::Bottom), + FlatSet::Top => ValueOrPlace::Value(FlatSet::Top), + }, + _ => self.super_rvalue(rvalue, state), + } + } + + fn handle_constant( + &self, + constant: &Constant<'tcx>, + _state: &mut State<Self::Value>, + ) -> Self::Value { + constant + .literal + .eval(self.tcx, self.param_env) + .try_to_scalar() + .map(|value| FlatSet::Elem(ScalarTy(value, constant.ty()))) + .unwrap_or(FlatSet::Top) + } + + fn handle_switch_int( + &self, + discr: &Operand<'tcx>, + apply_edge_effects: &mut impl SwitchIntEdgeEffects<State<Self::Value>>, + ) { + // FIXME: The dataflow framework only provides the state if we call `apply()`, which makes + // this more inefficient than it has to be. + let mut discr_value = None; + let mut handled = false; + apply_edge_effects.apply(|state, target| { + let discr_value = match discr_value { + Some(value) => value, + None => { + let value = match self.handle_operand(discr, state) { + ValueOrPlace::Value(value) => value, + ValueOrPlace::Place(place) => state.get_idx(place, self.map()), + }; + let result = match value { + FlatSet::Top => FlatSet::Top, + FlatSet::Elem(ScalarTy(scalar, _)) => { + let int = scalar.assert_int(); + FlatSet::Elem(int.assert_bits(int.size())) + } + FlatSet::Bottom => FlatSet::Bottom, + }; + discr_value = Some(result); + result + } + }; + + let FlatSet::Elem(choice) = discr_value else { + // Do nothing if we don't know which branch will be taken. + return + }; + + if target.value.map(|n| n == choice).unwrap_or(!handled) { + // Branch is taken. Has no effect on state. + handled = true; + } else { + // Branch is not taken. + state.mark_unreachable(); + } + }) + } +} + +#[derive(Clone, PartialEq, Eq)] +struct ScalarTy<'tcx>(Scalar, Ty<'tcx>); + +impl<'tcx> std::fmt::Debug for ScalarTy<'tcx> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // This is used for dataflow visualization, so we return something more concise. + std::fmt::Display::fmt(&ConstantKind::Val(ConstValue::Scalar(self.0), self.1), f) + } +} + +impl<'tcx> ConstAnalysis<'tcx> { + pub fn new(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, map: Map) -> Self { + let param_env = tcx.param_env(body.source.def_id()); + Self { + map, + tcx, + ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine), + param_env: param_env, + } + } + + fn binary_op( + &self, + state: &mut State<FlatSet<ScalarTy<'tcx>>>, + op: BinOp, + left: &Operand<'tcx>, + right: &Operand<'tcx>, + ) -> (FlatSet<ScalarTy<'tcx>>, FlatSet<bool>) { + let left = self.eval_operand(left, state); + let right = self.eval_operand(right, state); + match (left, right) { + (FlatSet::Elem(left), FlatSet::Elem(right)) => { + match self.ecx.overflowing_binary_op(op, &left, &right) { + Ok((val, overflow, ty)) => (self.wrap_scalar(val, ty), FlatSet::Elem(overflow)), + _ => (FlatSet::Top, FlatSet::Top), + } + } + (FlatSet::Bottom, _) | (_, FlatSet::Bottom) => (FlatSet::Bottom, FlatSet::Bottom), + (_, _) => { + // Could attempt some algebraic simplifcations here. + (FlatSet::Top, FlatSet::Top) + } + } + } + + fn eval_operand( + &self, + op: &Operand<'tcx>, + state: &mut State<FlatSet<ScalarTy<'tcx>>>, + ) -> FlatSet<ImmTy<'tcx>> { + let value = match self.handle_operand(op, state) { + ValueOrPlace::Value(value) => value, + ValueOrPlace::Place(place) => state.get_idx(place, &self.map), + }; + match value { + FlatSet::Top => FlatSet::Top, + FlatSet::Elem(ScalarTy(scalar, ty)) => self + .tcx + .layout_of(self.param_env.and(ty)) + .map(|layout| FlatSet::Elem(ImmTy::from_scalar(scalar, layout))) + .unwrap_or(FlatSet::Top), + FlatSet::Bottom => FlatSet::Bottom, + } + } + + fn wrap_scalar(&self, scalar: Scalar, ty: Ty<'tcx>) -> FlatSet<ScalarTy<'tcx>> { + FlatSet::Elem(ScalarTy(scalar, ty)) + } + + fn wrap_immediate(&self, imm: Immediate, ty: Ty<'tcx>) -> FlatSet<ScalarTy<'tcx>> { + match imm { + Immediate::Scalar(scalar) => self.wrap_scalar(scalar, ty), + _ => FlatSet::Top, + } + } + + fn wrap_immty(&self, val: ImmTy<'tcx>) -> FlatSet<ScalarTy<'tcx>> { + self.wrap_immediate(*val, val.layout.ty) + } +} + +struct CollectAndPatch<'tcx, 'map> { + tcx: TyCtxt<'tcx>, + map: &'map Map, + + /// For a given MIR location, this stores the values of the operands used by that location. In + /// particular, this is before the effect, such that the operands of `_1 = _1 + _2` are + /// properly captured. (This may become UB soon, but it is currently emitted even by safe code.) + before_effect: FxHashMap<(Location, Place<'tcx>), ScalarTy<'tcx>>, + + /// Stores the assigned values for assignments where the Rvalue is constant. + assignments: FxHashMap<Location, ScalarTy<'tcx>>, +} + +impl<'tcx, 'map> CollectAndPatch<'tcx, 'map> { + fn new(tcx: TyCtxt<'tcx>, map: &'map Map) -> Self { + Self { tcx, map, before_effect: FxHashMap::default(), assignments: FxHashMap::default() } + } + + fn make_operand(&self, scalar: ScalarTy<'tcx>) -> Operand<'tcx> { + Operand::Constant(Box::new(Constant { + span: DUMMY_SP, + user_ty: None, + literal: ConstantKind::Val(ConstValue::Scalar(scalar.0), scalar.1), + })) + } +} + +impl<'mir, 'tcx, 'map> ResultsVisitor<'mir, 'tcx> for CollectAndPatch<'tcx, 'map> { + type FlowState = State<FlatSet<ScalarTy<'tcx>>>; + + fn visit_statement_before_primary_effect( + &mut self, + state: &Self::FlowState, + statement: &'mir Statement<'tcx>, + location: Location, + ) { + match &statement.kind { + StatementKind::Assign(box (_, rvalue)) => { + OperandCollector { state, visitor: self }.visit_rvalue(rvalue, location); + } + _ => (), + } + } + + fn visit_statement_after_primary_effect( + &mut self, + state: &Self::FlowState, + statement: &'mir Statement<'tcx>, + location: Location, + ) { + match statement.kind { + StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(_)))) => { + // Don't overwrite the assignment if it already uses a constant (to keep the span). + } + StatementKind::Assign(box (place, _)) => match state.get(place.as_ref(), self.map) { + FlatSet::Top => (), + FlatSet::Elem(value) => { + self.assignments.insert(location, value); + } + FlatSet::Bottom => { + // This assignment is either unreachable, or an uninitialized value is assigned. + } + }, + _ => (), + } + } + + fn visit_terminator_before_primary_effect( + &mut self, + state: &Self::FlowState, + terminator: &'mir Terminator<'tcx>, + location: Location, + ) { + OperandCollector { state, visitor: self }.visit_terminator(terminator, location); + } +} + +impl<'tcx, 'map> MutVisitor<'tcx> for CollectAndPatch<'tcx, 'map> { + fn tcx<'a>(&'a self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { + if let Some(value) = self.assignments.get(&location) { + match &mut statement.kind { + StatementKind::Assign(box (_, rvalue)) => { + *rvalue = Rvalue::Use(self.make_operand(value.clone())); + } + _ => bug!("found assignment info for non-assign statement"), + } + } else { + self.super_statement(statement, location); + } + } + + fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) { + match operand { + Operand::Copy(place) | Operand::Move(place) => { + if let Some(value) = self.before_effect.get(&(location, *place)) { + *operand = self.make_operand(value.clone()); + } + } + _ => (), + } + } +} + +struct OperandCollector<'tcx, 'map, 'a> { + state: &'a State<FlatSet<ScalarTy<'tcx>>>, + visitor: &'a mut CollectAndPatch<'tcx, 'map>, +} + +impl<'tcx, 'map, 'a> Visitor<'tcx> for OperandCollector<'tcx, 'map, 'a> { + fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) { + match operand { + Operand::Copy(place) | Operand::Move(place) => { + match self.state.get(place.as_ref(), self.visitor.map) { + FlatSet::Top => (), + FlatSet::Elem(value) => { + self.visitor.before_effect.insert((location, *place), value); + } + FlatSet::Bottom => (), + } + } + _ => (), + } + } +} + +struct DummyMachine; + +impl<'mir, 'tcx> rustc_const_eval::interpret::Machine<'mir, 'tcx> for DummyMachine { + rustc_const_eval::interpret::compile_time_machine!(<'mir, 'tcx>); + type MemoryKind = !; + const PANIC_ON_ALLOC_FAIL: bool = true; + + fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> bool { + unimplemented!() + } + + fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>) -> bool { + unimplemented!() + } + + fn find_mir_or_eval_fn( + _ecx: &mut InterpCx<'mir, 'tcx, Self>, + _instance: ty::Instance<'tcx>, + _abi: rustc_target::spec::abi::Abi, + _args: &[rustc_const_eval::interpret::OpTy<'tcx, Self::Provenance>], + _destination: &rustc_const_eval::interpret::PlaceTy<'tcx, Self::Provenance>, + _target: Option<BasicBlock>, + _unwind: rustc_const_eval::interpret::StackPopUnwind, + ) -> interpret::InterpResult<'tcx, Option<(&'mir Body<'tcx>, ty::Instance<'tcx>)>> { + unimplemented!() + } + + fn call_intrinsic( + _ecx: &mut InterpCx<'mir, 'tcx, Self>, + _instance: ty::Instance<'tcx>, + _args: &[rustc_const_eval::interpret::OpTy<'tcx, Self::Provenance>], + _destination: &rustc_const_eval::interpret::PlaceTy<'tcx, Self::Provenance>, + _target: Option<BasicBlock>, + _unwind: rustc_const_eval::interpret::StackPopUnwind, + ) -> interpret::InterpResult<'tcx> { + unimplemented!() + } + + fn assert_panic( + _ecx: &mut InterpCx<'mir, 'tcx, Self>, + _msg: &rustc_middle::mir::AssertMessage<'tcx>, + _unwind: Option<BasicBlock>, + ) -> interpret::InterpResult<'tcx> { + unimplemented!() + } + + fn binary_ptr_op( + _ecx: &InterpCx<'mir, 'tcx, Self>, + _bin_op: BinOp, + _left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>, + _right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>, + ) -> interpret::InterpResult<'tcx, (interpret::Scalar<Self::Provenance>, bool, Ty<'tcx>)> { + throw_unsup!(Unsupported("".into())) + } + + fn expose_ptr( + _ecx: &mut InterpCx<'mir, 'tcx, Self>, + _ptr: interpret::Pointer<Self::Provenance>, + ) -> interpret::InterpResult<'tcx> { + unimplemented!() + } + + fn init_frame_extra( + _ecx: &mut InterpCx<'mir, 'tcx, Self>, + _frame: rustc_const_eval::interpret::Frame<'mir, 'tcx, Self::Provenance>, + ) -> interpret::InterpResult< + 'tcx, + rustc_const_eval::interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>, + > { + unimplemented!() + } + + fn stack<'a>( + _ecx: &'a InterpCx<'mir, 'tcx, Self>, + ) -> &'a [rustc_const_eval::interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>] + { + unimplemented!() + } + + fn stack_mut<'a>( + _ecx: &'a mut InterpCx<'mir, 'tcx, Self>, + ) -> &'a mut Vec< + rustc_const_eval::interpret::Frame<'mir, 'tcx, Self::Provenance, Self::FrameExtra>, + > { + unimplemented!() + } +} diff --git a/compiler/rustc_mir_transform/src/dead_store_elimination.rs b/compiler/rustc_mir_transform/src/dead_store_elimination.rs index 3f3870cc7..42c580c63 100644 --- a/compiler/rustc_mir_transform/src/dead_store_elimination.rs +++ b/compiler/rustc_mir_transform/src/dead_store_elimination.rs @@ -70,6 +70,8 @@ pub fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, borrowed: &BitS for Location { block, statement_index } in patch { bbs[block].statements[statement_index].make_nop(); } + + crate::simplify::SimplifyLocals.run_pass(tcx, body) } pub struct DeadStoreElimination; diff --git a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs index 28b1c5a48..92f1fff6b 100644 --- a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs +++ b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs @@ -110,15 +110,16 @@ impl<'tcx> Visitor<'tcx> for DeduceReadOnly { if let TerminatorKind::Call { ref args, .. } = terminator.kind { for arg in args { - if let Operand::Move(_) = *arg { - // ArgumentChecker panics if a direct move of an argument from a caller to a - // callee was detected. - // - // If, in the future, MIR optimizations cause arguments to be moved directly - // from callers to callees, change the panic to instead add the argument in - // question to `mutating_uses`. - ArgumentChecker::new(self.mutable_args.domain_size()) - .visit_operand(arg, location) + if let Operand::Move(place) = *arg { + let local = place.local; + if place.is_indirect() + || local == RETURN_PLACE + || local.index() > self.mutable_args.domain_size() + { + continue; + } + + self.mutable_args.insert(local.index() - 1); } } }; @@ -127,35 +128,6 @@ impl<'tcx> Visitor<'tcx> for DeduceReadOnly { } } -/// A visitor that simply panics if a direct move of an argument from a caller to a callee was -/// detected. -struct ArgumentChecker { - /// The number of arguments to the calling function. - arg_count: usize, -} - -impl ArgumentChecker { - /// Creates a new ArgumentChecker. - fn new(arg_count: usize) -> Self { - Self { arg_count } - } -} - -impl<'tcx> Visitor<'tcx> for ArgumentChecker { - fn visit_local(&mut self, local: Local, context: PlaceContext, _: Location) { - // Check to make sure that, if this local is an argument, we didn't move directly from it. - if matches!(context, PlaceContext::NonMutatingUse(NonMutatingUseContext::Move)) - && local != RETURN_PLACE - && local.index() <= self.arg_count - { - // If, in the future, MIR optimizations cause arguments to be moved directly from - // callers to callees, change this panic to instead add the argument in question to - // `mutating_uses`. - panic!("Detected a direct move from a caller's argument to a callee's argument!") - } - } -} - /// Returns true if values of a given type will never be passed indirectly, regardless of ABI. fn type_will_always_be_passed_directly<'tcx>(ty: Ty<'tcx>) -> bool { matches!( diff --git a/compiler/rustc_mir_transform/src/dest_prop.rs b/compiler/rustc_mir_transform/src/dest_prop.rs index 9bc47613e..97485c4f5 100644 --- a/compiler/rustc_mir_transform/src/dest_prop.rs +++ b/compiler/rustc_mir_transform/src/dest_prop.rs @@ -20,7 +20,8 @@ //! values or the return place `_0`. On a very high level, independent of the actual implementation //! details, it does the following: //! -//! 1) Identify `dest = src;` statements that can be soundly eliminated. +//! 1) Identify `dest = src;` statements with values for `dest` and `src` whose storage can soundly +//! be merged. //! 2) Replace all mentions of `src` with `dest` ("unifying" them and propagating the destination //! backwards). //! 3) Delete the `dest = src;` statement (by making it a `nop`). @@ -29,44 +30,80 @@ //! //! ## Soundness //! -//! Given an `Assign` statement `dest = src;`, where `dest` is a `Place` and `src` is an `Rvalue`, -//! there are a few requirements that must hold for the optimization to be sound: +//! We have a pair of places `p` and `q`, whose memory we would like to merge. In order for this to +//! be sound, we need to check a number of conditions: //! -//! * `dest` must not contain any *indirection* through a pointer. It must access part of the base -//! local. Otherwise it might point to arbitrary memory that is hard to track. +//! * `p` and `q` must both be *constant* - it does not make much sense to talk about merging them +//! if they do not consistently refer to the same place in memory. This is satisfied if they do +//! not contain any indirection through a pointer or any indexing projections. //! -//! It must also not contain any indexing projections, since those take an arbitrary `Local` as -//! the index, and that local might only be initialized shortly before `dest` is used. +//! * We need to make sure that the goal of "merging the memory" is actually structurally possible +//! in MIR. For example, even if all the other conditions are satisfied, there is no way to +//! "merge" `_5.foo` and `_6.bar`. For now, we ensure this by requiring that both `p` and `q` are +//! locals with no further projections. Future iterations of this pass should improve on this. //! -//! * `src` must be a bare `Local` without any indirections or field projections (FIXME: Is this a -//! fundamental restriction or just current impl state?). It can be copied or moved by the -//! assignment. +//! * Finally, we want `p` and `q` to use the same memory - however, we still need to make sure that +//! each of them has enough "ownership" of that memory to continue "doing its job." More +//! precisely, what we will check is that whenever the program performs a write to `p`, then it +//! does not currently care about what the value in `q` is (and vice versa). We formalize the +//! notion of "does not care what the value in `q` is" by checking the *liveness* of `q`. //! -//! * The `dest` and `src` locals must never be [*live*][liveness] at the same time. If they are, it -//! means that they both hold a (potentially different) value that is needed by a future use of -//! the locals. Unifying them would overwrite one of the values. +//! Because of the difficulty of computing liveness of places that have their address taken, we do +//! not even attempt to do it. Any places that are in a local that has its address taken is +//! excluded from the optimization. //! -//! Note that computing liveness of locals that have had their address taken is more difficult: -//! Short of doing full escape analysis on the address/pointer/reference, the pass would need to -//! assume that any operation that can potentially involve opaque user code (such as function -//! calls, destructors, and inline assembly) may access any local that had its address taken -//! before that point. +//! The first two conditions are simple structural requirements on the `Assign` statements that can +//! be trivially checked. The third requirement however is more difficult and costly to check. //! -//! Here, the first two conditions are simple structural requirements on the `Assign` statements -//! that can be trivially checked. The liveness requirement however is more difficult and costly to -//! check. +//! ## Future Improvements +//! +//! There are a number of ways in which this pass could be improved in the future: +//! +//! * Merging storage liveness ranges instead of removing storage statements completely. This may +//! improve stack usage. +//! +//! * Allow merging locals into places with projections, eg `_5` into `_6.foo`. +//! +//! * Liveness analysis with more precision than whole locals at a time. The smaller benefit of this +//! is that it would allow us to dest prop at "sub-local" levels in some cases. The bigger benefit +//! of this is that such liveness analysis can report more accurate results about whole locals at +//! a time. For example, consider: +//! +//! ```ignore (syntax-highliting-only) +//! _1 = u; +//! // unrelated code +//! _1.f1 = v; +//! _2 = _1.f1; +//! ``` +//! +//! Because the current analysis only thinks in terms of locals, it does not have enough +//! information to report that `_1` is dead in the "unrelated code" section. +//! +//! * Liveness analysis enabled by alias analysis. This would allow us to not just bail on locals +//! that ever have their address taken. Of course that requires actually having alias analysis +//! (and a model to build it on), so this might be a bit of a ways off. +//! +//! * Various perf improvents. There are a bunch of comments in here marked `PERF` with ideas for +//! how to do things more efficiently. However, the complexity of the pass as a whole should be +//! kept in mind. //! //! ## Previous Work //! -//! A [previous attempt] at implementing an optimization like this turned out to be a significant -//! regression in compiler performance. Fixing the regressions introduced a lot of undesirable -//! complexity to the implementation. +//! A [previous attempt][attempt 1] at implementing an optimization like this turned out to be a +//! significant regression in compiler performance. Fixing the regressions introduced a lot of +//! undesirable complexity to the implementation. +//! +//! A [subsequent approach][attempt 2] tried to avoid the costly computation by limiting itself to +//! acyclic CFGs, but still turned out to be far too costly to run due to suboptimal performance +//! within individual basic blocks, requiring a walk across the entire block for every assignment +//! found within the block. For the `tuple-stress` benchmark, which has 458745 statements in a +//! single block, this proved to be far too costly. //! -//! A [subsequent approach] tried to avoid the costly computation by limiting itself to acyclic -//! CFGs, but still turned out to be far too costly to run due to suboptimal performance within -//! individual basic blocks, requiring a walk across the entire block for every assignment found -//! within the block. For the `tuple-stress` benchmark, which has 458745 statements in a single -//! block, this proved to be far too costly. +//! [Another approach after that][attempt 3] was much closer to correct, but had some soundness +//! issues - it was failing to consider stores outside live ranges, and failed to uphold some of the +//! requirements that MIR has for non-overlapping places within statements. However, it also had +//! performance issues caused by `O(l² * s)` runtime, where `l` is the number of locals and `s` is +//! the number of statements and terminators. //! //! Since the first attempt at this, the compiler has improved dramatically, and new analysis //! frameworks have been added that should make this approach viable without requiring a limited @@ -74,8 +111,7 @@ //! - rustc now has a powerful dataflow analysis framework that can handle forwards and backwards //! analyses efficiently. //! - Layout optimizations for generators have been added to improve code generation for -//! async/await, which are very similar in spirit to what this optimization does. Both walk the -//! MIR and record conflicting uses of locals in a `BitMatrix`. +//! async/await, which are very similar in spirit to what this optimization does. //! //! Also, rustc now has a simple NRVO pass (see `nrvo.rs`), which handles a subset of the cases that //! this destination propagation pass handles, proving that similar optimizations can be performed @@ -87,253 +123,205 @@ //! it replaces the eliminated assign statements with `nop`s and leaves unused locals behind. //! //! [liveness]: https://en.wikipedia.org/wiki/Live_variable_analysis -//! [previous attempt]: https://github.com/rust-lang/rust/pull/47954 -//! [subsequent approach]: https://github.com/rust-lang/rust/pull/71003 +//! [attempt 1]: https://github.com/rust-lang/rust/pull/47954 +//! [attempt 2]: https://github.com/rust-lang/rust/pull/71003 +//! [attempt 3]: https://github.com/rust-lang/rust/pull/72632 + +use std::collections::hash_map::{Entry, OccupiedEntry}; use crate::MirPass; -use itertools::Itertools; -use rustc_data_structures::unify::{InPlaceUnificationTable, UnifyKey}; -use rustc_index::{ - bit_set::{BitMatrix, BitSet}, - vec::IndexVec, -}; -use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; +use rustc_data_structures::fx::FxHashMap; +use rustc_index::bit_set::BitSet; use rustc_middle::mir::{dump_mir, PassWhere}; use rustc_middle::mir::{ - traversal, Body, InlineAsmOperand, Local, LocalKind, Location, Operand, Place, PlaceElem, - Rvalue, Statement, StatementKind, Terminator, TerminatorKind, + traversal, BasicBlock, Body, InlineAsmOperand, Local, LocalKind, Location, Operand, Place, + Rvalue, Statement, StatementKind, TerminatorKind, +}; +use rustc_middle::mir::{ + visit::{MutVisitor, PlaceContext, Visitor}, + ProjectionElem, }; use rustc_middle::ty::TyCtxt; -use rustc_mir_dataflow::impls::{borrowed_locals, MaybeInitializedLocals, MaybeLiveLocals}; -use rustc_mir_dataflow::Analysis; - -// Empirical measurements have resulted in some observations: -// - Running on a body with a single block and 500 locals takes barely any time -// - Running on a body with ~400 blocks and ~300 relevant locals takes "too long" -// ...so we just limit both to somewhat reasonable-ish looking values. -const MAX_LOCALS: usize = 500; -const MAX_BLOCKS: usize = 250; +use rustc_mir_dataflow::impls::MaybeLiveLocals; +use rustc_mir_dataflow::{Analysis, ResultsCursor}; pub struct DestinationPropagation; impl<'tcx> MirPass<'tcx> for DestinationPropagation { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - // FIXME(#79191, #82678): This is unsound. - // - // Only run at mir-opt-level=3 or higher for now (we don't fix up debuginfo and remove - // storage statements at the moment). - sess.opts.unstable_opts.unsound_mir_opts && sess.mir_opt_level() >= 3 + // For now, only run at MIR opt level 3. Two things need to be changed before this can be + // turned on by default: + // 1. Because of the overeager removal of storage statements, this can cause stack space + // regressions. This opt is not the place to fix this though, it's a more general + // problem in MIR. + // 2. Despite being an overall perf improvement, this still causes a 30% regression in + // keccak. We can temporarily fix this by bounding function size, but in the long term + // we should fix this by being smarter about invalidating analysis results. + sess.mir_opt_level() >= 3 } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let def_id = body.source.def_id(); + let mut allocations = Allocations::default(); + trace!(func = ?tcx.def_path_str(def_id)); - let candidates = find_candidates(body); - if candidates.is_empty() { - debug!("{:?}: no dest prop candidates, done", def_id); - return; - } + let borrowed = rustc_mir_dataflow::impls::borrowed_locals(body); - // Collect all locals we care about. We only compute conflicts for these to save time. - let mut relevant_locals = BitSet::new_empty(body.local_decls.len()); - for CandidateAssignment { dest, src, loc: _ } in &candidates { - relevant_locals.insert(dest.local); - relevant_locals.insert(*src); - } - - // This pass unfortunately has `O(l² * s)` performance, where `l` is the number of locals - // and `s` is the number of statements and terminators in the function. - // To prevent blowing up compile times too much, we bail out when there are too many locals. - let relevant = relevant_locals.count(); - debug!( - "{:?}: {} locals ({} relevant), {} blocks", - def_id, - body.local_decls.len(), - relevant, - body.basic_blocks.len() - ); - if relevant > MAX_LOCALS { - warn!( - "too many candidate locals in {:?} ({}, max is {}), not optimizing", - def_id, relevant, MAX_LOCALS + // In order to avoid having to collect data for every single pair of locals in the body, we + // do not allow doing more than one merge for places that are derived from the same local at + // once. To avoid missed opportunities, we instead iterate to a fixed point - we'll refer to + // each of these iterations as a "round." + // + // Reaching a fixed point could in theory take up to `min(l, s)` rounds - however, we do not + // expect to see MIR like that. To verify this, a test was run against `[rust-lang/regex]` - + // the average MIR body saw 1.32 full iterations of this loop. The most that was hit were 30 + // for a single function. Only 80/2801 (2.9%) of functions saw at least 5. + // + // [rust-lang/regex]: + // https://github.com/rust-lang/regex/tree/b5372864e2df6a2f5e543a556a62197f50ca3650 + let mut round_count = 0; + loop { + // PERF: Can we do something smarter than recalculating the candidates and liveness + // results? + let mut candidates = find_candidates( + body, + &borrowed, + &mut allocations.candidates, + &mut allocations.candidates_reverse, ); - return; - } - if body.basic_blocks.len() > MAX_BLOCKS { - warn!( - "too many blocks in {:?} ({}, max is {}), not optimizing", - def_id, - body.basic_blocks.len(), - MAX_BLOCKS + trace!(?candidates); + let mut live = MaybeLiveLocals + .into_engine(tcx, body) + .iterate_to_fixpoint() + .into_results_cursor(body); + dest_prop_mir_dump(tcx, body, &mut live, round_count); + + FilterInformation::filter_liveness( + &mut candidates, + &mut live, + &mut allocations.write_info, + body, ); - return; - } - let mut conflicts = Conflicts::build(tcx, body, &relevant_locals); + // Because we do not update liveness information, it is unsound to use a local for more + // than one merge operation within a single round of optimizations. We store here which + // ones we have already used. + let mut merged_locals: BitSet<Local> = BitSet::new_empty(body.local_decls.len()); - let mut replacements = Replacements::new(body.local_decls.len()); - for candidate @ CandidateAssignment { dest, src, loc } in candidates { - // Merge locals that don't conflict. - if !conflicts.can_unify(dest.local, src) { - debug!("at assignment {:?}, conflict {:?} vs. {:?}", loc, dest.local, src); - continue; - } + // This is the set of merges we will apply this round. It is a subset of the candidates. + let mut merges = FxHashMap::default(); - if replacements.for_src(candidate.src).is_some() { - debug!("src {:?} already has replacement", candidate.src); - continue; + for (src, candidates) in candidates.c.iter() { + if merged_locals.contains(*src) { + continue; + } + let Some(dest) = + candidates.iter().find(|dest| !merged_locals.contains(**dest)) else { + continue; + }; + if !tcx.consider_optimizing(|| { + format!("{} round {}", tcx.def_path_str(def_id), round_count) + }) { + break; + } + merges.insert(*src, *dest); + merged_locals.insert(*src); + merged_locals.insert(*dest); } + trace!(merging = ?merges); - if !tcx.consider_optimizing(|| { - format!("DestinationPropagation {:?} {:?}", def_id, candidate) - }) { + if merges.is_empty() { break; } + round_count += 1; - replacements.push(candidate); - conflicts.unify(candidate.src, candidate.dest.local); + apply_merges(body, tcx, &merges, &merged_locals); } - replacements.flatten(tcx); - - debug!("replacements {:?}", replacements.map); - - Replacer { tcx, replacements, place_elem_cache: Vec::new() }.visit_body(body); - - // FIXME fix debug info + trace!(round_count); } } -#[derive(Debug, Eq, PartialEq, Copy, Clone)] -struct UnifyLocal(Local); - -impl From<Local> for UnifyLocal { - fn from(l: Local) -> Self { - Self(l) - } -} - -impl UnifyKey for UnifyLocal { - type Value = (); - #[inline] - fn index(&self) -> u32 { - self.0.as_u32() - } - #[inline] - fn from_index(u: u32) -> Self { - Self(Local::from_u32(u)) - } - fn tag() -> &'static str { - "UnifyLocal" - } +/// Container for the various allocations that we need. +/// +/// We store these here and hand out `&mut` access to them, instead of dropping and recreating them +/// frequently. Everything with a `&'alloc` lifetime points into here. +#[derive(Default)] +struct Allocations { + candidates: FxHashMap<Local, Vec<Local>>, + candidates_reverse: FxHashMap<Local, Vec<Local>>, + write_info: WriteInfo, + // PERF: Do this for `MaybeLiveLocals` allocations too. } -struct Replacements<'tcx> { - /// Maps locals to their replacement. - map: IndexVec<Local, Option<Place<'tcx>>>, - - /// Whose locals' live ranges to kill. - kill: BitSet<Local>, +#[derive(Debug)] +struct Candidates<'alloc> { + /// The set of candidates we are considering in this optimization. + /// + /// We will always merge the key into at most one of its values. + /// + /// Whether a place ends up in the key or the value does not correspond to whether it appears as + /// the lhs or rhs of any assignment. As a matter of fact, the places in here might never appear + /// in an assignment at all. This happens because if we see an assignment like this: + /// + /// ```ignore (syntax-highlighting-only) + /// _1.0 = _2.0 + /// ``` + /// + /// We will still report that we would like to merge `_1` and `_2` in an attempt to allow us to + /// remove that assignment. + c: &'alloc mut FxHashMap<Local, Vec<Local>>, + /// A reverse index of the `c` set; if the `c` set contains `a => Place { local: b, proj }`, + /// then this contains `b => a`. + // PERF: Possibly these should be `SmallVec`s? + reverse: &'alloc mut FxHashMap<Local, Vec<Local>>, } -impl<'tcx> Replacements<'tcx> { - fn new(locals: usize) -> Self { - Self { map: IndexVec::from_elem_n(None, locals), kill: BitSet::new_empty(locals) } - } - - fn push(&mut self, candidate: CandidateAssignment<'tcx>) { - trace!("Replacements::push({:?})", candidate); - let entry = &mut self.map[candidate.src]; - assert!(entry.is_none()); - - *entry = Some(candidate.dest); - self.kill.insert(candidate.src); - self.kill.insert(candidate.dest.local); - } - - /// Applies the stored replacements to all replacements, until no replacements would result in - /// locals that need further replacements when applied. - fn flatten(&mut self, tcx: TyCtxt<'tcx>) { - // Note: This assumes that there are no cycles in the replacements, which is enforced via - // `self.unified_locals`. Otherwise this can cause an infinite loop. - - for local in self.map.indices() { - if let Some(replacement) = self.map[local] { - // Substitute the base local of `replacement` until fixpoint. - let mut base = replacement.local; - let mut reversed_projection_slices = Vec::with_capacity(1); - while let Some(replacement_for_replacement) = self.map[base] { - base = replacement_for_replacement.local; - reversed_projection_slices.push(replacement_for_replacement.projection); - } - - let projection: Vec<_> = reversed_projection_slices - .iter() - .rev() - .flat_map(|projs| projs.iter()) - .chain(replacement.projection.iter()) - .collect(); - let projection = tcx.intern_place_elems(&projection); +////////////////////////////////////////////////////////// +// Merging +// +// Applies the actual optimization - // Replace with the final `Place`. - self.map[local] = Some(Place { local: base, projection }); - } - } - } - - fn for_src(&self, src: Local) -> Option<Place<'tcx>> { - self.map[src] - } +fn apply_merges<'tcx>( + body: &mut Body<'tcx>, + tcx: TyCtxt<'tcx>, + merges: &FxHashMap<Local, Local>, + merged_locals: &BitSet<Local>, +) { + let mut merger = Merger { tcx, merges, merged_locals }; + merger.visit_body_preserves_cfg(body); } -struct Replacer<'tcx> { +struct Merger<'a, 'tcx> { tcx: TyCtxt<'tcx>, - replacements: Replacements<'tcx>, - place_elem_cache: Vec<PlaceElem<'tcx>>, + merges: &'a FxHashMap<Local, Local>, + merged_locals: &'a BitSet<Local>, } -impl<'tcx> MutVisitor<'tcx> for Replacer<'tcx> { +impl<'a, 'tcx> MutVisitor<'tcx> for Merger<'a, 'tcx> { fn tcx(&self) -> TyCtxt<'tcx> { self.tcx } - fn visit_local(&mut self, local: &mut Local, context: PlaceContext, location: Location) { - if context.is_use() && self.replacements.for_src(*local).is_some() { - bug!( - "use of local {:?} should have been replaced by visit_place; context={:?}, loc={:?}", - local, - context, - location, - ); + fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _location: Location) { + if let Some(dest) = self.merges.get(local) { + *local = *dest; } } - fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { - if let Some(replacement) = self.replacements.for_src(place.local) { - // Rebase `place`s projections onto `replacement`'s. - self.place_elem_cache.clear(); - self.place_elem_cache.extend(replacement.projection.iter().chain(place.projection)); - let projection = self.tcx.intern_place_elems(&self.place_elem_cache); - let new_place = Place { local: replacement.local, projection }; - - debug!("Replacer: {:?} -> {:?}", place, new_place); - *place = new_place; - } - - self.super_place(place, context, location); - } - fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { - self.super_statement(statement, location); - match &statement.kind { - // FIXME: Don't delete storage statements, merge the live ranges instead + // FIXME: Don't delete storage statements, but "merge" the storage ranges instead. StatementKind::StorageDead(local) | StatementKind::StorageLive(local) - if self.replacements.kill.contains(*local) => + if self.merged_locals.contains(*local) => { - statement.make_nop() + statement.make_nop(); + return; } - + _ => (), + }; + self.super_statement(statement, location); + match &statement.kind { StatementKind::Assign(box (dest, rvalue)) => { match rvalue { Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) => { @@ -353,524 +341,427 @@ impl<'tcx> MutVisitor<'tcx> for Replacer<'tcx> { } } -struct Conflicts<'a> { - relevant_locals: &'a BitSet<Local>, - - /// The conflict matrix. It is always symmetric and the adjacency matrix of the corresponding - /// conflict graph. - matrix: BitMatrix<Local, Local>, - - /// Preallocated `BitSet` used by `unify`. - unify_cache: BitSet<Local>, - - /// Tracks locals that have been merged together to prevent cycles and propagate conflicts. - unified_locals: InPlaceUnificationTable<UnifyLocal>, +////////////////////////////////////////////////////////// +// Liveness filtering +// +// This section enforces bullet point 2 + +struct FilterInformation<'a, 'body, 'alloc, 'tcx> { + body: &'body Body<'tcx>, + live: &'a mut ResultsCursor<'body, 'tcx, MaybeLiveLocals>, + candidates: &'a mut Candidates<'alloc>, + write_info: &'alloc mut WriteInfo, + at: Location, } -impl<'a> Conflicts<'a> { - fn build<'tcx>( - tcx: TyCtxt<'tcx>, - body: &'_ Body<'tcx>, - relevant_locals: &'a BitSet<Local>, - ) -> Self { - // We don't have to look out for locals that have their address taken, since - // `find_candidates` already takes care of that. - - let conflicts = BitMatrix::from_row_n( - &BitSet::new_empty(body.local_decls.len()), - body.local_decls.len(), - ); - - let mut init = MaybeInitializedLocals - .into_engine(tcx, body) - .iterate_to_fixpoint() - .into_results_cursor(body); - let mut live = - MaybeLiveLocals.into_engine(tcx, body).iterate_to_fixpoint().into_results_cursor(body); - - let mut reachable = None; - dump_mir(tcx, None, "DestinationPropagation-dataflow", &"", body, |pass_where, w| { - let reachable = reachable.get_or_insert_with(|| traversal::reachable_as_bitset(body)); - - match pass_where { - PassWhere::BeforeLocation(loc) if reachable.contains(loc.block) => { - init.seek_before_primary_effect(loc); - live.seek_after_primary_effect(loc); - - writeln!(w, " // init: {:?}", init.get())?; - writeln!(w, " // live: {:?}", live.get())?; - } - PassWhere::AfterTerminator(bb) if reachable.contains(bb) => { - let loc = body.terminator_loc(bb); - init.seek_after_primary_effect(loc); - live.seek_before_primary_effect(loc); - - writeln!(w, " // init: {:?}", init.get())?; - writeln!(w, " // live: {:?}", live.get())?; - } - - PassWhere::BeforeBlock(bb) if reachable.contains(bb) => { - init.seek_to_block_start(bb); - live.seek_to_block_start(bb); - - writeln!(w, " // init: {:?}", init.get())?; - writeln!(w, " // live: {:?}", live.get())?; - } - - PassWhere::BeforeCFG | PassWhere::AfterCFG | PassWhere::AfterLocation(_) => {} - - PassWhere::BeforeLocation(_) | PassWhere::AfterTerminator(_) => { - writeln!(w, " // init: <unreachable>")?; - writeln!(w, " // live: <unreachable>")?; - } - - PassWhere::BeforeBlock(_) => { - writeln!(w, " // init: <unreachable>")?; - writeln!(w, " // live: <unreachable>")?; - } +// We first implement some utility functions which we will expose removing candidates according to +// different needs. Throughout the livenss filtering, the `candidates` are only ever accessed +// through these methods, and not directly. +impl<'alloc> Candidates<'alloc> { + /// Just `Vec::retain`, but the condition is inverted and we add debugging output + fn vec_remove_debug( + src: Local, + v: &mut Vec<Local>, + mut f: impl FnMut(Local) -> bool, + at: Location, + ) { + v.retain(|dest| { + let remove = f(*dest); + if remove { + trace!("eliminating {:?} => {:?} due to conflict at {:?}", src, dest, at); } - - Ok(()) + !remove }); + } - let mut this = Self { - relevant_locals, - matrix: conflicts, - unify_cache: BitSet::new_empty(body.local_decls.len()), - unified_locals: { - let mut table = InPlaceUnificationTable::new(); - // Pre-fill table with all locals (this creates N nodes / "connected" components, - // "graph"-ically speaking). - for local in 0..body.local_decls.len() { - assert_eq!(table.new_key(()), UnifyLocal(Local::from_usize(local))); - } - table - }, - }; - - let mut live_and_init_locals = Vec::new(); - - // Visit only reachable basic blocks. The exact order is not important. - for (block, data) in traversal::preorder(body) { - // We need to observe the dataflow state *before* all possible locations (statement or - // terminator) in each basic block, and then observe the state *after* the terminator - // effect is applied. As long as neither `init` nor `borrowed` has a "before" effect, - // we will observe all possible dataflow states. - - // Since liveness is a backwards analysis, we need to walk the results backwards. To do - // that, we first collect in the `MaybeInitializedLocals` results in a forwards - // traversal. - - live_and_init_locals.resize_with(data.statements.len() + 1, || { - BitSet::new_empty(body.local_decls.len()) - }); - - // First, go forwards for `MaybeInitializedLocals` and apply intra-statement/terminator - // conflicts. - for (i, statement) in data.statements.iter().enumerate() { - this.record_statement_conflicts(statement); - - let loc = Location { block, statement_index: i }; - init.seek_before_primary_effect(loc); + /// `vec_remove_debug` but for an `Entry` + fn entry_remove( + mut entry: OccupiedEntry<'_, Local, Vec<Local>>, + p: Local, + f: impl FnMut(Local) -> bool, + at: Location, + ) { + let candidates = entry.get_mut(); + Self::vec_remove_debug(p, candidates, f, at); + if candidates.len() == 0 { + entry.remove(); + } + } - live_and_init_locals[i].clone_from(init.get()); + /// Removes all candidates `(p, q)` or `(q, p)` where `p` is the indicated local and `f(q)` is true. + fn remove_candidates_if(&mut self, p: Local, mut f: impl FnMut(Local) -> bool, at: Location) { + // Cover the cases where `p` appears as a `src` + if let Entry::Occupied(entry) = self.c.entry(p) { + Self::entry_remove(entry, p, &mut f, at); + } + // And the cases where `p` appears as a `dest` + let Some(srcs) = self.reverse.get_mut(&p) else { + return; + }; + // We use `retain` here to remove the elements from the reverse set if we've removed the + // matching candidate in the forward set. + srcs.retain(|src| { + if !f(*src) { + return true; } + let Entry::Occupied(entry) = self.c.entry(*src) else { + return false; + }; + Self::entry_remove(entry, *src, |dest| dest == p, at); + false + }); + } +} - this.record_terminator_conflicts(data.terminator()); - let term_loc = Location { block, statement_index: data.statements.len() }; - init.seek_before_primary_effect(term_loc); - live_and_init_locals[term_loc.statement_index].clone_from(init.get()); - - // Now, go backwards and union with the liveness results. - for statement_index in (0..=data.statements.len()).rev() { - let loc = Location { block, statement_index }; - live.seek_after_primary_effect(loc); - - live_and_init_locals[statement_index].intersect(live.get()); - - trace!("record conflicts at {:?}", loc); +impl<'a, 'body, 'alloc, 'tcx> FilterInformation<'a, 'body, 'alloc, 'tcx> { + /// Filters the set of candidates to remove those that conflict. + /// + /// The steps we take are exactly those that are outlined at the top of the file. For each + /// statement/terminator, we collect the set of locals that are written to in that + /// statement/terminator, and then we remove all pairs of candidates that contain one such local + /// and another one that is live. + /// + /// We need to be careful about the ordering of operations within each statement/terminator + /// here. Many statements might write and read from more than one place, and we need to consider + /// them all. The strategy for doing this is as follows: We first gather all the places that are + /// written to within the statement/terminator via `WriteInfo`. Then, we use the liveness + /// analysis from *before* the statement/terminator (in the control flow sense) to eliminate + /// candidates - this is because we want to conservatively treat a pair of locals that is both + /// read and written in the statement/terminator to be conflicting, and the liveness analysis + /// before the statement/terminator will correctly report locals that are read in the + /// statement/terminator to be live. We are additionally conservative by treating all written to + /// locals as also being read from. + fn filter_liveness<'b>( + candidates: &mut Candidates<'alloc>, + live: &mut ResultsCursor<'b, 'tcx, MaybeLiveLocals>, + write_info_alloc: &'alloc mut WriteInfo, + body: &'b Body<'tcx>, + ) { + let mut this = FilterInformation { + body, + live, + candidates, + // We don't actually store anything at this scope, we just keep things here to be able + // to reuse the allocation. + write_info: write_info_alloc, + // Doesn't matter what we put here, will be overwritten before being used + at: Location { block: BasicBlock::from_u32(0), statement_index: 0 }, + }; + this.internal_filter_liveness(); + } - this.record_dataflow_conflicts(&mut live_and_init_locals[statement_index]); + fn internal_filter_liveness(&mut self) { + for (block, data) in traversal::preorder(self.body) { + self.at = Location { block, statement_index: data.statements.len() }; + self.live.seek_after_primary_effect(self.at); + self.write_info.for_terminator(&data.terminator().kind); + self.apply_conflicts(); + + for (i, statement) in data.statements.iter().enumerate().rev() { + self.at = Location { block, statement_index: i }; + self.live.seek_after_primary_effect(self.at); + self.get_statement_write_info(&statement.kind); + self.apply_conflicts(); } - - init.seek_to_block_end(block); - live.seek_to_block_end(block); - let mut conflicts = init.get().clone(); - conflicts.intersect(live.get()); - trace!("record conflicts at end of {:?}", block); - - this.record_dataflow_conflicts(&mut conflicts); } - - this } - fn record_dataflow_conflicts(&mut self, new_conflicts: &mut BitSet<Local>) { - // Remove all locals that are not candidates. - new_conflicts.intersect(self.relevant_locals); + fn apply_conflicts(&mut self) { + let writes = &self.write_info.writes; + for p in writes { + self.candidates.remove_candidates_if( + *p, + // It is possible that a local may be live for less than the + // duration of a statement This happens in the case of function + // calls or inline asm. Because of this, we also mark locals as + // conflicting when both of them are written to in the same + // statement. + |q| self.live.contains(q) || writes.contains(&q), + self.at, + ); + } + } - for local in new_conflicts.iter() { - self.matrix.union_row_with(&new_conflicts, local); + /// Gets the write info for the `statement`. + fn get_statement_write_info(&mut self, statement: &StatementKind<'tcx>) { + self.write_info.writes.clear(); + match statement { + StatementKind::Assign(box (lhs, rhs)) => match rhs { + Rvalue::Use(op) => { + if !lhs.is_indirect() { + self.get_assign_use_write_info(*lhs, op); + return; + } + } + _ => (), + }, + _ => (), } + + self.write_info.for_statement(statement); } - fn record_local_conflict(&mut self, a: Local, b: Local, why: &str) { - trace!("conflict {:?} <-> {:?} due to {}", a, b, why); - self.matrix.insert(a, b); - self.matrix.insert(b, a); + fn get_assign_use_write_info(&mut self, lhs: Place<'tcx>, rhs: &Operand<'tcx>) { + // We register the writes for the operand unconditionally + self.write_info.add_operand(rhs); + // However, we cannot do the same thing for the `lhs` as that would always block the + // optimization. Instead, we consider removing candidates manually. + let Some(rhs) = rhs.place() else { + self.write_info.add_place(lhs); + return; + }; + // Find out which candidate pair we should skip, if any + let Some((src, dest)) = places_to_candidate_pair(lhs, rhs, self.body) else { + self.write_info.add_place(lhs); + return; + }; + self.candidates.remove_candidates_if( + lhs.local, + |other| { + // Check if this is the candidate pair that should not be removed + if (lhs.local == src && other == dest) || (lhs.local == dest && other == src) { + return false; + } + // Otherwise, do the "standard" thing + self.live.contains(other) + }, + self.at, + ) } +} - /// Records locals that must not overlap during the evaluation of `stmt`. These locals conflict - /// and must not be merged. - fn record_statement_conflicts(&mut self, stmt: &Statement<'_>) { - match &stmt.kind { - // While the left and right sides of an assignment must not overlap, we do not mark - // conflicts here as that would make this optimization useless. When we optimize, we - // eliminate the resulting self-assignments automatically. - StatementKind::Assign(_) => {} - - StatementKind::SetDiscriminant { .. } - | StatementKind::Deinit(..) - | StatementKind::StorageLive(..) - | StatementKind::StorageDead(..) - | StatementKind::Retag(..) - | StatementKind::FakeRead(..) - | StatementKind::AscribeUserType(..) - | StatementKind::Coverage(..) - | StatementKind::Intrinsic(..) - | StatementKind::Nop => {} +/// Describes where a statement/terminator writes to +#[derive(Default, Debug)] +struct WriteInfo { + writes: Vec<Local>, +} + +impl WriteInfo { + fn for_statement<'tcx>(&mut self, statement: &StatementKind<'tcx>) { + match statement { + StatementKind::Assign(box (lhs, rhs)) => { + self.add_place(*lhs); + match rhs { + Rvalue::Use(op) | Rvalue::Repeat(op, _) => { + self.add_operand(op); + } + Rvalue::Cast(_, op, _) + | Rvalue::UnaryOp(_, op) + | Rvalue::ShallowInitBox(op, _) => { + self.add_operand(op); + } + Rvalue::BinaryOp(_, ops) | Rvalue::CheckedBinaryOp(_, ops) => { + for op in [&ops.0, &ops.1] { + self.add_operand(op); + } + } + Rvalue::Aggregate(_, ops) => { + for op in ops { + self.add_operand(op); + } + } + Rvalue::ThreadLocalRef(_) + | Rvalue::NullaryOp(_, _) + | Rvalue::Ref(_, _, _) + | Rvalue::AddressOf(_, _) + | Rvalue::Len(_) + | Rvalue::Discriminant(_) + | Rvalue::CopyForDeref(_) => (), + } + } + // Retags are technically also reads, but reporting them as a write suffices + StatementKind::SetDiscriminant { place, .. } + | StatementKind::Deinit(place) + | StatementKind::Retag(_, place) => { + self.add_place(**place); + } + StatementKind::Intrinsic(_) + | StatementKind::Nop + | StatementKind::Coverage(_) + | StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) => (), + StatementKind::FakeRead(_) | StatementKind::AscribeUserType(_, _) => { + bug!("{:?} not found in this MIR phase", statement) + } } } - fn record_terminator_conflicts(&mut self, term: &Terminator<'_>) { - match &term.kind { - TerminatorKind::DropAndReplace { - place: dropped_place, - value, - target: _, - unwind: _, - } => { - if let Some(place) = value.place() - && !place.is_indirect() - && !dropped_place.is_indirect() - { - self.record_local_conflict( - place.local, - dropped_place.local, - "DropAndReplace operand overlap", - ); - } + fn for_terminator<'tcx>(&mut self, terminator: &TerminatorKind<'tcx>) { + self.writes.clear(); + match terminator { + TerminatorKind::SwitchInt { discr: op, .. } + | TerminatorKind::Assert { cond: op, .. } => { + self.add_operand(op); } - TerminatorKind::Yield { value, resume: _, resume_arg, drop: _ } => { - if let Some(place) = value.place() { - if !place.is_indirect() && !resume_arg.is_indirect() { - self.record_local_conflict( - place.local, - resume_arg.local, - "Yield operand overlap", - ); - } + TerminatorKind::Call { destination, func, args, .. } => { + self.add_place(*destination); + self.add_operand(func); + for arg in args { + self.add_operand(arg); } } - TerminatorKind::Call { - func, - args, - destination, - target: _, - cleanup: _, - from_hir_call: _, - fn_span: _, - } => { - // No arguments may overlap with the destination. - for arg in args.iter().chain(Some(func)) { - if let Some(place) = arg.place() { - if !place.is_indirect() && !destination.is_indirect() { - self.record_local_conflict( - destination.local, - place.local, - "call dest/arg overlap", - ); + TerminatorKind::InlineAsm { operands, .. } => { + for asm_operand in operands { + match asm_operand { + InlineAsmOperand::In { value, .. } => { + self.add_operand(value); } - } - } - } - TerminatorKind::InlineAsm { - template: _, - operands, - options: _, - line_spans: _, - destination: _, - cleanup: _, - } => { - // The intended semantics here aren't documented, we just assume that nothing that - // could be written to by the assembly may overlap with any other operands. - for op in operands { - match op { - InlineAsmOperand::Out { reg: _, late: _, place: Some(dest_place) } - | InlineAsmOperand::InOut { - reg: _, - late: _, - in_value: _, - out_place: Some(dest_place), - } => { - // For output place `place`, add all places accessed by the inline asm. - for op in operands { - match op { - InlineAsmOperand::In { reg: _, value } => { - if let Some(p) = value.place() - && !p.is_indirect() - && !dest_place.is_indirect() - { - self.record_local_conflict( - p.local, - dest_place.local, - "asm! operand overlap", - ); - } - } - InlineAsmOperand::Out { - reg: _, - late: _, - place: Some(place), - } => { - if !place.is_indirect() && !dest_place.is_indirect() { - self.record_local_conflict( - place.local, - dest_place.local, - "asm! operand overlap", - ); - } - } - InlineAsmOperand::InOut { - reg: _, - late: _, - in_value, - out_place, - } => { - if let Some(place) = in_value.place() - && !place.is_indirect() - && !dest_place.is_indirect() - { - self.record_local_conflict( - place.local, - dest_place.local, - "asm! operand overlap", - ); - } - - if let Some(place) = out_place - && !place.is_indirect() - && !dest_place.is_indirect() - { - self.record_local_conflict( - place.local, - dest_place.local, - "asm! operand overlap", - ); - } - } - InlineAsmOperand::Out { reg: _, late: _, place: None } - | InlineAsmOperand::Const { value: _ } - | InlineAsmOperand::SymFn { value: _ } - | InlineAsmOperand::SymStatic { def_id: _ } => {} - } + InlineAsmOperand::Out { place, .. } => { + if let Some(place) = place { + self.add_place(*place); } } - InlineAsmOperand::InOut { - reg: _, - late: _, - in_value: _, - out_place: None, + // Note that the `late` field in `InOut` is about whether the registers used + // for these things overlap, and is of absolutely no interest to us. + InlineAsmOperand::InOut { in_value, out_place, .. } => { + if let Some(place) = out_place { + self.add_place(*place); + } + self.add_operand(in_value); } - | InlineAsmOperand::In { reg: _, value: _ } - | InlineAsmOperand::Out { reg: _, late: _, place: None } - | InlineAsmOperand::Const { value: _ } - | InlineAsmOperand::SymFn { value: _ } - | InlineAsmOperand::SymStatic { def_id: _ } => {} + InlineAsmOperand::Const { .. } + | InlineAsmOperand::SymFn { .. } + | InlineAsmOperand::SymStatic { .. } => (), } } } - TerminatorKind::Goto { .. } - | TerminatorKind::SwitchInt { .. } - | TerminatorKind::Resume - | TerminatorKind::Abort + | TerminatorKind::Resume { .. } + | TerminatorKind::Abort { .. } | TerminatorKind::Return - | TerminatorKind::Unreachable - | TerminatorKind::Drop { .. } - | TerminatorKind::Assert { .. } + | TerminatorKind::Unreachable { .. } => (), + TerminatorKind::Drop { .. } => { + // `Drop`s create a `&mut` and so are not considered + } + TerminatorKind::DropAndReplace { .. } + | TerminatorKind::Yield { .. } | TerminatorKind::GeneratorDrop | TerminatorKind::FalseEdge { .. } - | TerminatorKind::FalseUnwind { .. } => {} + | TerminatorKind::FalseUnwind { .. } => { + bug!("{:?} not found in this MIR phase", terminator) + } } } - /// Checks whether `a` and `b` may be merged. Returns `false` if there's a conflict. - fn can_unify(&mut self, a: Local, b: Local) -> bool { - // After some locals have been unified, their conflicts are only tracked in the root key, - // so look that up. - let a = self.unified_locals.find(a).0; - let b = self.unified_locals.find(b).0; - - if a == b { - // Already merged (part of the same connected component). - return false; - } + fn add_place<'tcx>(&mut self, place: Place<'tcx>) { + self.writes.push(place.local); + } - if self.matrix.contains(a, b) { - // Conflict (derived via dataflow, intra-statement conflicts, or inherited from another - // local during unification). - return false; + fn add_operand<'tcx>(&mut self, op: &Operand<'tcx>) { + match op { + // FIXME(JakobDegen): In a previous version, the `Move` case was incorrectly treated as + // being a read only. This was unsound, however we cannot add a regression test because + // it is not possible to set this off with current MIR. Once we have that ability, a + // regression test should be added. + Operand::Move(p) => self.add_place(*p), + Operand::Copy(_) | Operand::Constant(_) => (), } - - true } +} - /// Merges the conflicts of `a` and `b`, so that each one inherits all conflicts of the other. - /// - /// `can_unify` must have returned `true` for the same locals, or this may panic or lead to - /// miscompiles. - /// - /// This is called when the pass makes the decision to unify `a` and `b` (or parts of `a` and - /// `b`) and is needed to ensure that future unification decisions take potentially newly - /// introduced conflicts into account. - /// - /// For an example, assume we have locals `_0`, `_1`, `_2`, and `_3`. There are these conflicts: - /// - /// * `_0` <-> `_1` - /// * `_1` <-> `_2` - /// * `_3` <-> `_0` - /// - /// We then decide to merge `_2` with `_3` since they don't conflict. Then we decide to merge - /// `_2` with `_0`, which also doesn't have a conflict in the above list. However `_2` is now - /// `_3`, which does conflict with `_0`. - fn unify(&mut self, a: Local, b: Local) { - trace!("unify({:?}, {:?})", a, b); - - // Get the root local of the connected components. The root local stores the conflicts of - // all locals in the connected component (and *is stored* as the conflicting local of other - // locals). - let a = self.unified_locals.find(a).0; - let b = self.unified_locals.find(b).0; - assert_ne!(a, b); - - trace!("roots: a={:?}, b={:?}", a, b); - trace!("{:?} conflicts: {:?}", a, self.matrix.iter(a).format(", ")); - trace!("{:?} conflicts: {:?}", b, self.matrix.iter(b).format(", ")); - - self.unified_locals.union(a, b); - - let root = self.unified_locals.find(a).0; - assert!(root == a || root == b); - - // Make all locals that conflict with `a` also conflict with `b`, and vice versa. - self.unify_cache.clear(); - for conflicts_with_a in self.matrix.iter(a) { - self.unify_cache.insert(conflicts_with_a); - } - for conflicts_with_b in self.matrix.iter(b) { - self.unify_cache.insert(conflicts_with_b); - } - for conflicts_with_a_or_b in self.unify_cache.iter() { - // Set both `a` and `b` for this local's row. - self.matrix.insert(conflicts_with_a_or_b, a); - self.matrix.insert(conflicts_with_a_or_b, b); - } +///////////////////////////////////////////////////// +// Candidate accumulation - // Write the locals `a` conflicts with to `b`'s row. - self.matrix.union_rows(a, b); - // Write the locals `b` conflicts with to `a`'s row. - self.matrix.union_rows(b, a); - } +fn is_constant<'tcx>(place: Place<'tcx>) -> bool { + place.projection.iter().all(|p| !matches!(p, ProjectionElem::Deref | ProjectionElem::Index(_))) } -/// A `dest = {move} src;` statement at `loc`. +/// If the pair of places is being considered for merging, returns the candidate which would be +/// merged in order to accomplish this. +/// +/// The contract here is in one direction - there is a guarantee that merging the locals that are +/// outputted by this function would result in an assignment between the inputs becoming a +/// self-assignment. However, there is no guarantee that the returned pair is actually suitable for +/// merging - candidate collection must still check this independently. /// -/// We want to consider merging `dest` and `src` due to this assignment. -#[derive(Debug, Copy, Clone)] -struct CandidateAssignment<'tcx> { - /// Does not contain indirection or indexing (so the only local it contains is the place base). - dest: Place<'tcx>, - src: Local, - loc: Location, +/// This output is unique for each unordered pair of input places. +fn places_to_candidate_pair<'tcx>( + a: Place<'tcx>, + b: Place<'tcx>, + body: &Body<'tcx>, +) -> Option<(Local, Local)> { + let (mut a, mut b) = if a.projection.len() == 0 && b.projection.len() == 0 { + (a.local, b.local) + } else { + return None; + }; + + // By sorting, we make sure we're input order independent + if a > b { + std::mem::swap(&mut a, &mut b); + } + + // We could now return `(a, b)`, but then we miss some candidates in the case where `a` can't be + // used as a `src`. + if is_local_required(a, body) { + std::mem::swap(&mut a, &mut b); + } + // We could check `is_local_required` again here, but there's no need - after all, we make no + // promise that the candidate pair is actually valid + Some((a, b)) } -/// Scans the MIR for assignments between locals that we might want to consider merging. +/// Collects the candidates for merging /// -/// This will filter out assignments that do not match the right form (as described in the top-level -/// comment) and also throw out assignments that involve a local that has its address taken or is -/// otherwise ineligible (eg. locals used as array indices are ignored because we cannot propagate -/// arbitrary places into array indices). -fn find_candidates<'tcx>(body: &Body<'tcx>) -> Vec<CandidateAssignment<'tcx>> { - let mut visitor = FindAssignments { - body, - candidates: Vec::new(), - ever_borrowed_locals: borrowed_locals(body), - locals_used_as_array_index: locals_used_as_array_index(body), - }; +/// This is responsible for enforcing the first and third bullet point. +fn find_candidates<'alloc, 'tcx>( + body: &Body<'tcx>, + borrowed: &BitSet<Local>, + candidates: &'alloc mut FxHashMap<Local, Vec<Local>>, + candidates_reverse: &'alloc mut FxHashMap<Local, Vec<Local>>, +) -> Candidates<'alloc> { + candidates.clear(); + candidates_reverse.clear(); + let mut visitor = FindAssignments { body, candidates, borrowed }; visitor.visit_body(body); - visitor.candidates + // Deduplicate candidates + for (_, cands) in candidates.iter_mut() { + cands.sort(); + cands.dedup(); + } + // Generate the reverse map + for (src, cands) in candidates.iter() { + for dest in cands.iter().copied() { + candidates_reverse.entry(dest).or_default().push(*src); + } + } + Candidates { c: candidates, reverse: candidates_reverse } } -struct FindAssignments<'a, 'tcx> { +struct FindAssignments<'a, 'alloc, 'tcx> { body: &'a Body<'tcx>, - candidates: Vec<CandidateAssignment<'tcx>>, - ever_borrowed_locals: BitSet<Local>, - locals_used_as_array_index: BitSet<Local>, + candidates: &'alloc mut FxHashMap<Local, Vec<Local>>, + borrowed: &'a BitSet<Local>, } -impl<'tcx> Visitor<'tcx> for FindAssignments<'_, 'tcx> { - fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { +impl<'tcx> Visitor<'tcx> for FindAssignments<'_, '_, 'tcx> { + fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) { if let StatementKind::Assign(box ( - dest, - Rvalue::Use(Operand::Copy(src) | Operand::Move(src)), + lhs, + Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)), )) = &statement.kind { - // `dest` must not have pointer indirection. - if dest.is_indirect() { - return; - } - - // `src` must be a plain local. - if !src.projection.is_empty() { + if !is_constant(*lhs) || !is_constant(*rhs) { return; } - // Since we want to replace `src` with `dest`, `src` must not be required. - if is_local_required(src.local, self.body) { + let Some((src, dest)) = places_to_candidate_pair(*lhs, *rhs, self.body) else { return; - } + }; - // Can't optimize if either local ever has their address taken. This optimization does - // liveness analysis only based on assignments, and a local can be live even if its - // never assigned to again, because a reference to it might be live. - // FIXME: This can be smarter and take `StorageDead` into account (which invalidates - // borrows). - if self.ever_borrowed_locals.contains(dest.local) - || self.ever_borrowed_locals.contains(src.local) - { + // As described at the top of the file, we do not go near things that have their address + // taken. + if self.borrowed.contains(src) || self.borrowed.contains(dest) { return; } - assert_ne!(dest.local, src.local, "self-assignments are UB"); - - // We can't replace locals occurring in `PlaceElem::Index` for now. - if self.locals_used_as_array_index.contains(src.local) { + // Also, we need to make sure that MIR actually allows the `src` to be removed + if is_local_required(src, self.body) { return; } - for elem in dest.projection { - if let PlaceElem::Index(_) = elem { - // `dest` contains an indexing projection. - return; - } - } - - self.candidates.push(CandidateAssignment { - dest: *dest, - src: src.local, - loc: location, - }); + // We may insert duplicates here, but that's fine + self.candidates.entry(src).or_default().push(dest); } } } @@ -886,32 +777,46 @@ fn is_local_required(local: Local, body: &Body<'_>) -> bool { } } -/// `PlaceElem::Index` only stores a `Local`, so we can't replace that with a full `Place`. -/// -/// Collect locals used as indices so we don't generate candidates that are impossible to apply -/// later. -fn locals_used_as_array_index(body: &Body<'_>) -> BitSet<Local> { - let mut visitor = IndexCollector { locals: BitSet::new_empty(body.local_decls.len()) }; - visitor.visit_body(body); - visitor.locals -} +///////////////////////////////////////////////////////// +// MIR Dump -struct IndexCollector { - locals: BitSet<Local>, -} +fn dest_prop_mir_dump<'body, 'tcx>( + tcx: TyCtxt<'tcx>, + body: &'body Body<'tcx>, + live: &mut ResultsCursor<'body, 'tcx, MaybeLiveLocals>, + round: usize, +) { + let mut reachable = None; + dump_mir(tcx, false, "DestinationPropagation-dataflow", &round, body, |pass_where, w| { + let reachable = reachable.get_or_insert_with(|| traversal::reachable_as_bitset(body)); + + match pass_where { + PassWhere::BeforeLocation(loc) if reachable.contains(loc.block) => { + live.seek_after_primary_effect(loc); + writeln!(w, " // live: {:?}", live.get())?; + } + PassWhere::AfterTerminator(bb) if reachable.contains(bb) => { + let loc = body.terminator_loc(bb); + live.seek_before_primary_effect(loc); + writeln!(w, " // live: {:?}", live.get())?; + } -impl<'tcx> Visitor<'tcx> for IndexCollector { - fn visit_projection_elem( - &mut self, - local: Local, - proj_base: &[PlaceElem<'tcx>], - elem: PlaceElem<'tcx>, - context: PlaceContext, - location: Location, - ) { - if let PlaceElem::Index(i) = elem { - self.locals.insert(i); + PassWhere::BeforeBlock(bb) if reachable.contains(bb) => { + live.seek_to_block_start(bb); + writeln!(w, " // live: {:?}", live.get())?; + } + + PassWhere::BeforeCFG | PassWhere::AfterCFG | PassWhere::AfterLocation(_) => {} + + PassWhere::BeforeLocation(_) | PassWhere::AfterTerminator(_) => { + writeln!(w, " // live: <unreachable>")?; + } + + PassWhere::BeforeBlock(_) => { + writeln!(w, " // live: <unreachable>")?; + } } - self.super_projection_elem(local, proj_base, elem, context, location); - } + + Ok(()) + }); } diff --git a/compiler/rustc_mir_transform/src/dump_mir.rs b/compiler/rustc_mir_transform/src/dump_mir.rs index 6b995141a..594cbd897 100644 --- a/compiler/rustc_mir_transform/src/dump_mir.rs +++ b/compiler/rustc_mir_transform/src/dump_mir.rs @@ -1,6 +1,5 @@ //! This pass just dumps MIR at a specified point. -use std::borrow::Cow; use std::fs::File; use std::io; @@ -8,20 +7,20 @@ use crate::MirPass; use rustc_middle::mir::write_mir_pretty; use rustc_middle::mir::Body; use rustc_middle::ty::TyCtxt; -use rustc_session::config::{OutputFilenames, OutputType}; +use rustc_session::config::OutputType; pub struct Marker(pub &'static str); impl<'tcx> MirPass<'tcx> for Marker { - fn name(&self) -> Cow<'_, str> { - Cow::Borrowed(self.0) + fn name(&self) -> &str { + self.0 } fn run_pass(&self, _tcx: TyCtxt<'tcx>, _body: &mut Body<'tcx>) {} } -pub fn emit_mir(tcx: TyCtxt<'_>, outputs: &OutputFilenames) -> io::Result<()> { - let path = outputs.path(OutputType::Mir); +pub fn emit_mir(tcx: TyCtxt<'_>) -> io::Result<()> { + let path = tcx.output_filenames(()).path(OutputType::Mir); let mut f = io::BufWriter::new(File::create(&path)?); write_mir_pretty(tcx, None, &mut f)?; Ok(()) diff --git a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs index ef8d6bb65..932134bd6 100644 --- a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs +++ b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs @@ -25,7 +25,7 @@ pub fn build_ptr_tys<'tcx>( (unique_ty, nonnull_ty, ptr_ty) } -// Constructs the projection needed to access a Box's pointer +/// Constructs the projection needed to access a Box's pointer pub fn build_projection<'tcx>( unique_ty: Ty<'tcx>, nonnull_ty: Ty<'tcx>, diff --git a/compiler/rustc_mir_transform/src/function_item_references.rs b/compiler/rustc_mir_transform/src/function_item_references.rs index 469566694..b708d780b 100644 --- a/compiler/rustc_mir_transform/src/function_item_references.rs +++ b/compiler/rustc_mir_transform/src/function_item_references.rs @@ -110,7 +110,7 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { /// If the given predicate is the trait `fmt::Pointer`, returns the bound parameter type. fn is_pointer_trait(&self, bound: &PredicateKind<'tcx>) -> Option<Ty<'tcx>> { - if let ty::PredicateKind::Trait(predicate) = bound { + if let ty::PredicateKind::Clause(ty::Clause::Trait(predicate)) = bound { if self.tcx.is_diagnostic_item(sym::Pointer, predicate.def_id()) { Some(predicate.trait_ref.self_ty()) } else { diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs index c833de3a8..69f96fe48 100644 --- a/compiler/rustc_mir_transform/src/generator.rs +++ b/compiler/rustc_mir_transform/src/generator.rs @@ -11,10 +11,10 @@ //! generator in the MIR, since it is used to create the drop glue for the generator. We'd get //! infinite recursion otherwise. //! -//! This pass creates the implementation for the Generator::resume function and the drop shim -//! for the generator based on the MIR input. It converts the generator argument from Self to -//! &mut Self adding derefs in the MIR as needed. It computes the final layout of the generator -//! struct which looks like this: +//! This pass creates the implementation for either the `Generator::resume` or `Future::poll` +//! function and the drop shim for the generator based on the MIR input. +//! It converts the generator argument from Self to &mut Self adding derefs in the MIR as needed. +//! It computes the final layout of the generator struct which looks like this: //! First upvars are stored //! It is followed by the generator state field. //! Then finally the MIR locals which are live across a suspension point are stored. @@ -32,14 +32,15 @@ //! 2 - Generator has been poisoned //! //! It also rewrites `return x` and `yield y` as setting a new generator state and returning -//! GeneratorState::Complete(x) and GeneratorState::Yielded(y) respectively. +//! `GeneratorState::Complete(x)` and `GeneratorState::Yielded(y)`, +//! or `Poll::Ready(x)` and `Poll::Pending` respectively. //! MIR locals which are live across a suspension point are moved to the generator struct //! with references to them being updated with references to the generator struct. //! //! The pass creates two functions which have a switch on the generator state giving //! the action to take. //! -//! One of them is the implementation of Generator::resume. +//! One of them is the implementation of `Generator::resume` / `Future::poll`. //! For generators with state 0 (unresumed) it starts the execution of the generator. //! For generators with state 1 (returned) and state 2 (poisoned) it panics. //! Otherwise it continues the execution from the last suspension point. @@ -56,6 +57,7 @@ use crate::MirPass; use rustc_data_structures::fx::FxHashMap; use rustc_hir as hir; use rustc_hir::lang_items::LangItem; +use rustc_hir::GeneratorKind; use rustc_index::bit_set::{BitMatrix, BitSet, GrowableBitSet}; use rustc_index::vec::{Idx, IndexVec}; use rustc_middle::mir::dump_mir; @@ -215,6 +217,7 @@ struct SuspensionPoint<'tcx> { struct TransformVisitor<'tcx> { tcx: TyCtxt<'tcx>, + is_async_kind: bool, state_adt_ref: AdtDef<'tcx>, state_substs: SubstsRef<'tcx>, @@ -239,28 +242,57 @@ struct TransformVisitor<'tcx> { } impl<'tcx> TransformVisitor<'tcx> { - // Make a GeneratorState variant assignment. `core::ops::GeneratorState` only has single - // element tuple variants, so we can just write to the downcasted first field and then set the + // Make a `GeneratorState` or `Poll` variant assignment. + // + // `core::ops::GeneratorState` only has single element tuple variants, + // so we can just write to the downcasted first field and then set the // discriminant to the appropriate variant. fn make_state( &self, - idx: VariantIdx, val: Operand<'tcx>, source_info: SourceInfo, - ) -> impl Iterator<Item = Statement<'tcx>> { + is_return: bool, + statements: &mut Vec<Statement<'tcx>>, + ) { + let idx = VariantIdx::new(match (is_return, self.is_async_kind) { + (true, false) => 1, // GeneratorState::Complete + (false, false) => 0, // GeneratorState::Yielded + (true, true) => 0, // Poll::Ready + (false, true) => 1, // Poll::Pending + }); + let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_substs, None, None); + + // `Poll::Pending` + if self.is_async_kind && idx == VariantIdx::new(1) { + assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); + + // FIXME(swatinem): assert that `val` is indeed unit? + statements.extend(expand_aggregate( + Place::return_place(), + std::iter::empty(), + kind, + source_info, + self.tcx, + )); + return; + } + + // else: `Poll::Ready(x)`, `GeneratorState::Yielded(x)` or `GeneratorState::Complete(x)` assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1); + let ty = self .tcx .bound_type_of(self.state_adt_ref.variant(idx).fields[0].did) .subst(self.tcx, self.state_substs); - expand_aggregate( + + statements.extend(expand_aggregate( Place::return_place(), std::iter::once((val, ty)), kind, source_info, self.tcx, - ) + )); } // Create a Place referencing a generator struct field @@ -331,22 +363,19 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> { }); let ret_val = match data.terminator().kind { - TerminatorKind::Return => Some(( - VariantIdx::new(1), - None, - Operand::Move(Place::from(self.new_ret_local)), - None, - )), + TerminatorKind::Return => { + Some((true, None, Operand::Move(Place::from(self.new_ret_local)), None)) + } TerminatorKind::Yield { ref value, resume, resume_arg, drop } => { - Some((VariantIdx::new(0), Some((resume, resume_arg)), value.clone(), drop)) + Some((false, Some((resume, resume_arg)), value.clone(), drop)) } _ => None, }; - if let Some((state_idx, resume, v, drop)) = ret_val { + if let Some((is_return, resume, v, drop)) = ret_val { let source_info = data.terminator().source_info; // We must assign the value first in case it gets declared dead below - data.statements.extend(self.make_state(state_idx, v, source_info)); + self.make_state(v, source_info, is_return, &mut data.statements); let state = if let Some((resume, mut resume_arg)) = resume { // Yield let state = RESERVED_VARIANTS + self.suspension_points.len(); @@ -956,22 +985,12 @@ fn create_generator_drop_shim<'tcx>( tcx.mk_ptr(ty::TypeAndMut { ty: gen_ty, mutbl: hir::Mutability::Mut }), source_info, ); - if tcx.sess.opts.unstable_opts.mir_emit_retag { - // Alias tracking must know we changed the type - body.basic_blocks_mut()[START_BLOCK].statements.insert( - 0, - Statement { - source_info, - kind: StatementKind::Retag(RetagKind::Raw, Box::new(Place::from(SELF_ARG))), - }, - ) - } // Make sure we remove dead blocks to remove // unrelated code from the resume part of the function simplify::remove_dead_blocks(tcx, &mut body); - dump_mir(tcx, None, "generator_drop", &0, &body, |_, _| Ok(())); + dump_mir(tcx, false, "generator_drop", &0, &body, |_, _| Ok(())); body } @@ -1015,7 +1034,7 @@ fn insert_panic_block<'tcx>( fn can_return<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, param_env: ty::ParamEnv<'tcx>) -> bool { // Returning from a function with an uninhabited return type is undefined behavior. - if tcx.conservative_is_privately_uninhabited(param_env.and(body.return_ty())) { + if body.return_ty().is_privately_uninhabited(tcx, param_env) { return false; } @@ -1142,7 +1161,7 @@ fn create_generator_resume_function<'tcx>( // unrelated code from the drop part of the function simplify::remove_dead_blocks(tcx, body); - dump_mir(tcx, None, "generator_resume", &0, body, |_, _| Ok(())); + dump_mir(tcx, false, "generator_resume", &0, body, |_, _| Ok(())); } fn insert_clean_drop(body: &mut Body<'_>) -> BasicBlock { @@ -1268,10 +1287,20 @@ impl<'tcx> MirPass<'tcx> for StateTransform { } }; - // Compute GeneratorState<yield_ty, return_ty> - let state_did = tcx.require_lang_item(LangItem::GeneratorState, None); - let state_adt_ref = tcx.adt_def(state_did); - let state_substs = tcx.intern_substs(&[yield_ty.into(), body.return_ty().into()]); + let is_async_kind = body.generator_kind().unwrap() != GeneratorKind::Gen; + let (state_adt_ref, state_substs) = if is_async_kind { + // Compute Poll<return_ty> + let state_did = tcx.require_lang_item(LangItem::Poll, None); + let state_adt_ref = tcx.adt_def(state_did); + let state_substs = tcx.intern_substs(&[body.return_ty().into()]); + (state_adt_ref, state_substs) + } else { + // Compute GeneratorState<yield_ty, return_ty> + let state_did = tcx.require_lang_item(LangItem::GeneratorState, None); + let state_adt_ref = tcx.adt_def(state_did); + let state_substs = tcx.intern_substs(&[yield_ty.into(), body.return_ty().into()]); + (state_adt_ref, state_substs) + }; let ret_ty = tcx.mk_adt(state_adt_ref, state_substs); // We rename RETURN_PLACE which has type mir.return_ty to new_ret_local @@ -1327,9 +1356,11 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // Run the transformation which converts Places from Local to generator struct // accesses for locals in `remap`. // It also rewrites `return x` and `yield y` as writing a new generator state and returning - // GeneratorState::Complete(x) and GeneratorState::Yielded(y) respectively. + // either GeneratorState::Complete(x) and GeneratorState::Yielded(y), + // or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`. let mut transform = TransformVisitor { tcx, + is_async_kind, state_adt_ref, state_substs, remap, @@ -1353,21 +1384,21 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // This is expanded to a drop ladder in `elaborate_generator_drops`. let drop_clean = insert_clean_drop(body); - dump_mir(tcx, None, "generator_pre-elab", &0, body, |_, _| Ok(())); + dump_mir(tcx, false, "generator_pre-elab", &0, body, |_, _| Ok(())); // Expand `drop(generator_struct)` to a drop ladder which destroys upvars. // If any upvars are moved out of, drop elaboration will handle upvar destruction. // However we need to also elaborate the code generated by `insert_clean_drop`. elaborate_generator_drops(tcx, body); - dump_mir(tcx, None, "generator_post-transform", &0, body, |_, _| Ok(())); + dump_mir(tcx, false, "generator_post-transform", &0, body, |_, _| Ok(())); // Create a copy of our MIR and use it to create the drop shim for the generator let drop_shim = create_generator_drop_shim(tcx, &transform, gen_ty, body, drop_clean); body.generator.as_mut().unwrap().generator_drop = Some(drop_shim); - // Create the Generator::resume function + // Create the Generator::resume / Future::poll function create_generator_resume_function(tcx, transform, body, can_return); // Run derefer to fix Derefs that are not in the first place diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 780b91d92..220cf7df9 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -1,7 +1,6 @@ //! Inlining pass for MIR functions use crate::deref_separator::deref_finder; use rustc_attr::InlineAttr; -use rustc_const_eval::transform::validate::equal_up_to_regions; use rustc_index::bit_set::BitSet; use rustc_index::vec::Idx; use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}; @@ -9,12 +8,12 @@ use rustc_middle::mir::visit::*; use rustc_middle::mir::*; use rustc_middle::ty::{self, Instance, InstanceDef, ParamEnv, Ty, TyCtxt}; use rustc_session::config::OptLevel; -use rustc_span::def_id::DefId; use rustc_span::{hygiene::ExpnKind, ExpnData, LocalExpnId, Span}; use rustc_target::abi::VariantIdx; use rustc_target::spec::abi::Abi; -use super::simplify::{remove_dead_blocks, CfgSimplifier}; +use crate::simplify::{remove_dead_blocks, CfgSimplifier}; +use crate::util; use crate::MirPass; use std::iter; use std::ops::{Range, RangeFrom}; @@ -87,13 +86,8 @@ fn inline<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool { let param_env = tcx.param_env_reveal_all_normalized(def_id); - let mut this = Inliner { - tcx, - param_env, - codegen_fn_attrs: tcx.codegen_fn_attrs(def_id), - history: Vec::new(), - changed: false, - }; + let mut this = + Inliner { tcx, param_env, codegen_fn_attrs: tcx.codegen_fn_attrs(def_id), changed: false }; let blocks = BasicBlock::new(0)..body.basic_blocks.next_index(); this.process_blocks(body, blocks); this.changed @@ -104,12 +98,6 @@ struct Inliner<'tcx> { param_env: ParamEnv<'tcx>, /// Caller codegen attributes. codegen_fn_attrs: &'tcx CodegenFnAttrs, - /// Stack of inlined instances. - /// We only check the `DefId` and not the substs because we want to - /// avoid inlining cases of polymorphic recursion. - /// The number of `DefId`s is finite, so checking history is enough - /// to ensure that we do not loop endlessly while inlining. - history: Vec<DefId>, /// Indicates that the caller body has been modified. changed: bool, } @@ -134,12 +122,12 @@ impl<'tcx> Inliner<'tcx> { debug!("not-inlined {} [{}]", callsite.callee, reason); continue; } - Ok(new_blocks) => { + Ok(_) => { debug!("inlined {}", callsite.callee); self.changed = true; - self.history.push(callsite.callee.def_id()); - self.process_blocks(caller_body, new_blocks); - self.history.pop(); + // We could process the blocks returned by `try_inlining` here. However, that + // leads to exponential compile times due to the top-down nature of this kind + // of inlining. } } } @@ -180,7 +168,7 @@ impl<'tcx> Inliner<'tcx> { let TerminatorKind::Call { args, destination, .. } = &terminator.kind else { bug!() }; let destination_ty = destination.ty(&caller_body.local_decls, self.tcx).ty; let output_type = callee_body.return_ty(); - if !equal_up_to_regions(self.tcx, self.param_env, output_type, destination_ty) { + if !util::is_subtype(self.tcx, self.param_env, output_type, destination_ty) { trace!(?output_type, ?destination_ty); return Err("failed to normalize return type"); } @@ -200,7 +188,7 @@ impl<'tcx> Inliner<'tcx> { arg_tuple_tys.iter().zip(callee_body.args_iter().skip(skipped_args)) { let input_type = callee_body.local_decls[input].ty; - if !equal_up_to_regions(self.tcx, self.param_env, arg_ty, input_type) { + if !util::is_subtype(self.tcx, self.param_env, input_type, arg_ty) { trace!(?arg_ty, ?input_type); return Err("failed to normalize tuple argument type"); } @@ -209,7 +197,7 @@ impl<'tcx> Inliner<'tcx> { for (arg, input) in args.iter().zip(callee_body.args_iter()) { let input_type = callee_body.local_decls[input].ty; let arg_ty = arg.ty(&caller_body.local_decls, self.tcx); - if !equal_up_to_regions(self.tcx, self.param_env, arg_ty, input_type) { + if !util::is_subtype(self.tcx, self.param_env, input_type, arg_ty) { trace!(?arg_ty, ?input_type); return Err("failed to normalize argument type"); } @@ -313,10 +301,6 @@ impl<'tcx> Inliner<'tcx> { return None; } - if self.history.contains(&callee.def_id()) { - return None; - } - let fn_sig = self.tcx.bound_fn_sig(def_id).subst(self.tcx, substs); return Some(CallSite { @@ -363,10 +347,6 @@ impl<'tcx> Inliner<'tcx> { return Err("C variadic"); } - if callee_attrs.flags.contains(CodegenFnAttrFlags::NAKED) { - return Err("naked"); - } - if callee_attrs.flags.contains(CodegenFnAttrFlags::COLD) { return Err("cold"); } @@ -375,7 +355,12 @@ impl<'tcx> Inliner<'tcx> { return Err("incompatible sanitizer set"); } - if callee_attrs.instruction_set != self.codegen_fn_attrs.instruction_set { + // Two functions are compatible if the callee has no attribute (meaning + // that it's codegen agnostic), or sets an attribute that is identical + // to this function's attribute. + if callee_attrs.instruction_set.is_some() + && callee_attrs.instruction_set != self.codegen_fn_attrs.instruction_set + { return Err("incompatible instruction set"); } @@ -453,6 +438,15 @@ impl<'tcx> Inliner<'tcx> { if ty.needs_drop(tcx, self.param_env) && let Some(unwind) = unwind { work_list.push(unwind); } + } else if callee_attrs.instruction_set != self.codegen_fn_attrs.instruction_set + && matches!(term.kind, TerminatorKind::InlineAsm { .. }) + { + // During the attribute checking stage we allow a callee with no + // instruction_set assigned to count as compatible with a function that does + // assign one. However, during this stage we require an exact match when any + // inline-asm is detected. LLVM will still possibly do an inline later on + // if the no-attribute function ends up with the same instruction set anyway. + return Err("Cannot move inline-asm across instruction sets"); } else { work_list.extend(term.successors()) } @@ -847,7 +841,7 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { let parent = Place { local, projection: self.tcx.intern_place_elems(proj_base) }; let parent_ty = parent.ty(&self.callee_body.local_decls, self.tcx); let check_equal = |this: &mut Self, f_ty| { - if !equal_up_to_regions(this.tcx, this.param_env, ty, f_ty) { + if !util::is_equal_up_to_subtyping(this.tcx, this.param_env, ty, f_ty) { trace!(?ty, ?f_ty); this.validation = Err("failed to normalize projection type"); return; diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 5be223254..93200b288 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -1,5 +1,6 @@ #![allow(rustc::potential_query_instability)] #![feature(box_patterns)] +#![feature(drain_filter)] #![feature(let_chains)] #![feature(map_try_insert)] #![feature(min_specialization)] @@ -54,6 +55,7 @@ mod const_goto; mod const_prop; mod const_prop_lint; mod coverage; +mod dataflow_const_prop; mod dead_store_elimination; mod deaggregator; mod deduce_param_attrs; @@ -91,6 +93,7 @@ pub mod simplify; mod simplify_branches; mod simplify_comparison_integral; mod simplify_try; +mod sroa; mod uninhabited_enum_branching; mod unreachable_prop; @@ -217,19 +220,18 @@ fn mir_keys(tcx: TyCtxt<'_>, (): ()) -> FxIndexSet<LocalDefId> { // Additionally, tuple struct/variant constructors have MIR, but // they don't have a BodyId, so we need to build them separately. - struct GatherCtors<'a, 'tcx> { - tcx: TyCtxt<'tcx>, + struct GatherCtors<'a> { set: &'a mut FxIndexSet<LocalDefId>, } - impl<'tcx> Visitor<'tcx> for GatherCtors<'_, 'tcx> { + impl<'tcx> Visitor<'tcx> for GatherCtors<'_> { fn visit_variant_data(&mut self, v: &'tcx hir::VariantData<'tcx>) { - if let hir::VariantData::Tuple(_, hir_id) = *v { - self.set.insert(self.tcx.hir().local_def_id(hir_id)); + if let hir::VariantData::Tuple(_, _, def_id) = *v { + self.set.insert(def_id); } intravisit::walk_struct_def(self, v) } } - tcx.hir().visit_all_item_likes_in_crate(&mut GatherCtors { tcx, set: &mut set }); + tcx.hir().visit_all_item_likes_in_crate(&mut GatherCtors { set: &mut set }); set } @@ -288,7 +290,7 @@ fn mir_const<'tcx>( let mut body = tcx.mir_built(def).steal(); - rustc_middle::mir::dump_mir(tcx, None, "mir_map", &0, &body, |_, _| Ok(())); + pass_manager::dump_mir_for_phase_change(tcx, &body); pm::run_passes( tcx, @@ -561,6 +563,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &remove_zsts::RemoveZsts, &const_goto::ConstGoto, &remove_unneeded_drops::RemoveUnneededDrops, + &sroa::ScalarReplacementOfAggregates, &match_branches::MatchBranchSimplification, // inst combine is after MatchBranchSimplification to clean up Ne(_1, false) &multiple_return_terminators::MultipleReturnTerminators, @@ -569,6 +572,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { // // FIXME(#70073): This pass is responsible for both optimization as well as some lints. &const_prop::ConstProp, + &dataflow_const_prop::DataflowConstProp, // // Const-prop runs unconditionally, but doesn't mutate the MIR at mir-opt-level=0. &const_debuginfo::ConstDebugInfo, diff --git a/compiler/rustc_mir_transform/src/pass_manager.rs b/compiler/rustc_mir_transform/src/pass_manager.rs index 230c6a7cb..e1b65823a 100644 --- a/compiler/rustc_mir_transform/src/pass_manager.rs +++ b/compiler/rustc_mir_transform/src/pass_manager.rs @@ -1,5 +1,3 @@ -use std::borrow::Cow; - use rustc_middle::mir::{self, Body, MirPhase, RuntimePhase}; use rustc_middle::ty::TyCtxt; use rustc_session::Session; @@ -8,13 +6,9 @@ use crate::{validate, MirPass}; /// Just like `MirPass`, except it cannot mutate `Body`. pub trait MirLint<'tcx> { - fn name(&self) -> Cow<'_, str> { + fn name(&self) -> &str { let name = std::any::type_name::<Self>(); - if let Some(tail) = name.rfind(':') { - Cow::from(&name[tail + 1..]) - } else { - Cow::from(name) - } + if let Some((_, tail)) = name.rsplit_once(':') { tail } else { name } } fn is_enabled(&self, _sess: &Session) -> bool { @@ -32,7 +26,7 @@ impl<'tcx, T> MirPass<'tcx> for Lint<T> where T: MirLint<'tcx>, { - fn name(&self) -> Cow<'_, str> { + fn name(&self) -> &str { self.0.name() } @@ -55,7 +49,7 @@ impl<'tcx, T> MirPass<'tcx> for WithMinOptLevel<T> where T: MirPass<'tcx>, { - fn name(&self) -> Cow<'_, str> { + fn name(&self) -> &str { self.1.name() } @@ -96,45 +90,48 @@ fn run_passes_inner<'tcx>( phase_change: Option<MirPhase>, validate_each: bool, ) { - let validate = validate_each & tcx.sess.opts.unstable_opts.validate_mir; + let validate = validate_each & tcx.sess.opts.unstable_opts.validate_mir & !body.should_skip(); let overridden_passes = &tcx.sess.opts.unstable_opts.mir_enable_passes; trace!(?overridden_passes); - for pass in passes { - let name = pass.name(); - - let overridden = - overridden_passes.iter().rev().find(|(s, _)| s == &*name).map(|(_name, polarity)| { - trace!( - pass = %name, - "{} as requested by flag", - if *polarity { "Running" } else { "Not running" }, - ); - *polarity - }); - if !overridden.unwrap_or_else(|| pass.is_enabled(&tcx.sess)) { - continue; - } - - let dump_enabled = pass.is_mir_dump_enabled(); - - if dump_enabled { - dump_mir_for_pass(tcx, body, &name, false); - } - if validate { - validate_body(tcx, body, format!("before pass {}", name)); - } - - pass.run_pass(tcx, body); - - if dump_enabled { - dump_mir_for_pass(tcx, body, &name, true); + if !body.should_skip() { + for pass in passes { + let name = pass.name(); + + let overridden = overridden_passes.iter().rev().find(|(s, _)| s == &*name).map( + |(_name, polarity)| { + trace!( + pass = %name, + "{} as requested by flag", + if *polarity { "Running" } else { "Not running" }, + ); + *polarity + }, + ); + if !overridden.unwrap_or_else(|| pass.is_enabled(&tcx.sess)) { + continue; + } + + let dump_enabled = pass.is_mir_dump_enabled(); + + if dump_enabled { + dump_mir_for_pass(tcx, body, &name, false); + } + if validate { + validate_body(tcx, body, format!("before pass {}", name)); + } + + pass.run_pass(tcx, body); + + if dump_enabled { + dump_mir_for_pass(tcx, body, &name, true); + } + if validate { + validate_body(tcx, body, format!("after pass {}", name)); + } + + body.pass_count += 1; } - if validate { - validate_body(tcx, body, format!("after pass {}", name)); - } - - body.pass_count += 1; } if let Some(new_phase) = phase_change { @@ -143,10 +140,11 @@ fn run_passes_inner<'tcx>( } body.phase = new_phase; + body.pass_count = 0; dump_mir_for_phase_change(tcx, body); if validate || new_phase == MirPhase::Runtime(RuntimePhase::Optimized) { - validate_body(tcx, body, format!("after phase change to {}", new_phase)); + validate_body(tcx, body, format!("after phase change to {}", new_phase.name())); } body.pass_count = 1; @@ -163,11 +161,9 @@ pub fn dump_mir_for_pass<'tcx>( pass_name: &str, is_after: bool, ) { - let phase_index = body.phase.phase_index(); - mir::dump_mir( tcx, - Some(&format_args!("{:03}-{:03}", phase_index, body.pass_count)), + true, pass_name, if is_after { &"after" } else { &"before" }, body, @@ -176,14 +172,6 @@ pub fn dump_mir_for_pass<'tcx>( } pub fn dump_mir_for_phase_change<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) { - let phase_index = body.phase.phase_index(); - - mir::dump_mir( - tcx, - Some(&format_args!("{:03}-000", phase_index)), - &format!("{}", body.phase), - &"after", - body, - |_, _| Ok(()), - ) + assert_eq!(body.pass_count, 0); + mir::dump_mir(tcx, true, body.phase.name(), &"after", body, |_, _| Ok(())) } diff --git a/compiler/rustc_mir_transform/src/remove_zsts.rs b/compiler/rustc_mir_transform/src/remove_zsts.rs index 40be4f146..569e783fe 100644 --- a/compiler/rustc_mir_transform/src/remove_zsts.rs +++ b/compiler/rustc_mir_transform/src/remove_zsts.rs @@ -1,8 +1,7 @@ //! Removes assignments to ZST places. use crate::MirPass; -use rustc_middle::mir::tcx::PlaceTy; -use rustc_middle::mir::{Body, LocalDecls, Place, StatementKind}; +use rustc_middle::mir::{Body, StatementKind}; use rustc_middle::ty::{self, Ty, TyCtxt}; pub struct RemoveZsts; @@ -35,9 +34,6 @@ impl<'tcx> MirPass<'tcx> for RemoveZsts { if !layout.is_zst() { continue; } - if involves_a_union(place, local_decls, tcx) { - continue; - } if tcx.consider_optimizing(|| { format!( "RemoveZsts - Place: {:?} SourceInfo: {:?}", @@ -63,24 +59,3 @@ fn maybe_zst(ty: Ty<'_>) -> bool { _ => false, } } - -/// Miri lazily allocates memory for locals on assignment, -/// so we must preserve writes to unions and union fields, -/// or it will ICE on reads of those fields. -fn involves_a_union<'tcx>( - place: Place<'tcx>, - local_decls: &LocalDecls<'tcx>, - tcx: TyCtxt<'tcx>, -) -> bool { - let mut place_ty = PlaceTy::from_ty(local_decls[place.local].ty); - if place_ty.ty.is_union() { - return true; - } - for elem in place.projection { - place_ty = place_ty.projection_ty(tcx, elem); - if place_ty.ty.is_union() { - return true; - } - } - return false; -} diff --git a/compiler/rustc_mir_transform/src/required_consts.rs b/compiler/rustc_mir_transform/src/required_consts.rs index cc75947d9..0ea8f2ba9 100644 --- a/compiler/rustc_mir_transform/src/required_consts.rs +++ b/compiler/rustc_mir_transform/src/required_consts.rs @@ -17,7 +17,7 @@ impl<'tcx> Visitor<'tcx> for RequiredConstsVisitor<'_, 'tcx> { let literal = constant.literal; match literal { ConstantKind::Ty(c) => match c.kind() { - ConstKind::Param(_) => {} + ConstKind::Param(_) | ConstKind::Error(_) => {} _ => bug!("only ConstKind::Param should be encountered here, got {:#?}", c), }, ConstantKind::Unevaluated(..) => self.required_consts.push(*constant), diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs index 4e8798b7a..16b7dcad1 100644 --- a/compiler/rustc_mir_transform/src/shim.rs +++ b/compiler/rustc_mir_transform/src/shim.rs @@ -37,7 +37,7 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<' } ty::InstanceDef::FnPtrShim(def_id, ty) => { let trait_ = tcx.trait_of_item(def_id).unwrap(); - let adjustment = match tcx.fn_trait_kind_from_lang_item(trait_) { + let adjustment = match tcx.fn_trait_kind_from_def_id(trait_) { Some(ty::ClosureKind::FnOnce) => Adjustment::Identity, Some(ty::ClosureKind::FnMut | ty::ClosureKind::Fn) => Adjustment::Deref, None => bug!("fn pointer {:?} is not an fn", ty), @@ -177,16 +177,6 @@ fn build_drop_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, ty: Option<Ty<'tcx>>) if ty.is_some() { // The first argument (index 0), but add 1 for the return value. let dropee_ptr = Place::from(Local::new(1 + 0)); - if tcx.sess.opts.unstable_opts.mir_emit_retag { - // Function arguments should be retagged, and we make this one raw. - body.basic_blocks_mut()[START_BLOCK].statements.insert( - 0, - Statement { - source_info, - kind: StatementKind::Retag(RetagKind::Raw, Box::new(dropee_ptr)), - }, - ); - } let patch = { let param_env = tcx.param_env_reveal_all_normalized(def_id); let mut elaborator = @@ -346,7 +336,7 @@ impl<'tcx> CloneShimBuilder<'tcx> { // we must subst the self_ty because it's // otherwise going to be TySelf and we can't index // or access fields of a Place of type TySelf. - let substs = tcx.mk_substs_trait(self_ty, &[]); + let substs = tcx.mk_substs_trait(self_ty, []); let sig = tcx.bound_fn_sig(def_id).subst(tcx, substs); let sig = tcx.erase_late_bound_regions(sig); let span = tcx.def_span(def_id); @@ -427,7 +417,7 @@ impl<'tcx> CloneShimBuilder<'tcx> { ) { let tcx = self.tcx; - let substs = tcx.mk_substs_trait(ty, &[]); + let substs = tcx.mk_substs_trait(ty, []); // `func == Clone::clone(&ty) -> ty` let func_ty = tcx.mk_fn_def(self.def_id, substs); @@ -569,17 +559,13 @@ impl<'tcx> CloneShimBuilder<'tcx> { /// Builds a "call" shim for `instance`. The shim calls the function specified by `call_kind`, /// first adjusting its first argument according to `rcvr_adjustment`. +#[instrument(level = "debug", skip(tcx), ret)] fn build_call_shim<'tcx>( tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>, rcvr_adjustment: Option<Adjustment>, call_kind: CallKind<'tcx>, ) -> Body<'tcx> { - debug!( - "build_call_shim(instance={:?}, rcvr_adjustment={:?}, call_kind={:?})", - instance, rcvr_adjustment, call_kind - ); - // `FnPtrShim` contains the fn pointer type that a call shim is being built for - this is used // to substitute into the signature of the shim. It is not necessary for users of this // MIR body to perform further substitutions (see `InstanceDef::has_polymorphic_mir_body`). @@ -590,7 +576,7 @@ fn build_call_shim<'tcx>( // Create substitutions for the `Self` and `Args` generic parameters of the shim body. let arg_tup = tcx.mk_tup(untuple_args.iter()); - let sig_substs = tcx.mk_substs_trait(ty, &[ty::subst::GenericArg::from(arg_tup)]); + let sig_substs = tcx.mk_substs_trait(ty, [ty::subst::GenericArg::from(arg_tup)]); (Some(sig_substs), Some(untuple_args)) } else { @@ -641,7 +627,7 @@ fn build_call_shim<'tcx>( let span = tcx.def_span(def_id); - debug!("build_call_shim: sig={:?}", sig); + debug!(?sig); let mut local_decls = local_decls_for_sig(&sig, span); let source_info = SourceInfo::outermost(span); @@ -845,7 +831,7 @@ pub fn build_adt_ctor(tcx: TyCtxt<'_>, ctor_id: DefId) -> Body<'_> { span, ); - rustc_middle::mir::dump_mir(tcx, None, "mir_map", &0, &body, |_, _| Ok(())); + crate::pass_manager::dump_mir_for_phase_change(tcx, &body); body } diff --git a/compiler/rustc_mir_transform/src/simplify.rs b/compiler/rustc_mir_transform/src/simplify.rs index 57d372fda..475e2ec9a 100644 --- a/compiler/rustc_mir_transform/src/simplify.rs +++ b/compiler/rustc_mir_transform/src/simplify.rs @@ -35,7 +35,6 @@ use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Vis use rustc_middle::mir::*; use rustc_middle::ty::TyCtxt; use smallvec::SmallVec; -use std::borrow::Cow; use std::convert::TryInto; pub struct SimplifyCfg { @@ -57,8 +56,8 @@ pub fn simplify_cfg<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { } impl<'tcx> MirPass<'tcx> for SimplifyCfg { - fn name(&self) -> Cow<'_, str> { - Cow::Borrowed(&self.label) + fn name(&self) -> &str { + &self.label } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { diff --git a/compiler/rustc_mir_transform/src/simplify_branches.rs b/compiler/rustc_mir_transform/src/simplify_branches.rs index 3bbae5b89..405ebce4d 100644 --- a/compiler/rustc_mir_transform/src/simplify_branches.rs +++ b/compiler/rustc_mir_transform/src/simplify_branches.rs @@ -2,8 +2,6 @@ use crate::MirPass; use rustc_middle::mir::*; use rustc_middle::ty::TyCtxt; -use std::borrow::Cow; - /// A pass that replaces a branch with a goto when its condition is known. pub struct SimplifyConstCondition { label: String, @@ -16,8 +14,8 @@ impl SimplifyConstCondition { } impl<'tcx> MirPass<'tcx> for SimplifyConstCondition { - fn name(&self) -> Cow<'_, str> { - Cow::Borrowed(&self.label) + fn name(&self) -> &str { + &self.label } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs new file mode 100644 index 000000000..558a372fb --- /dev/null +++ b/compiler/rustc_mir_transform/src/sroa.rs @@ -0,0 +1,348 @@ +use crate::MirPass; +use rustc_data_structures::fx::{FxIndexMap, IndexEntry}; +use rustc_index::bit_set::BitSet; +use rustc_index::vec::IndexVec; +use rustc_middle::mir::visit::*; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +pub struct ScalarReplacementOfAggregates; + +impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates { + fn is_enabled(&self, sess: &rustc_session::Session) -> bool { + sess.mir_opt_level() >= 3 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let escaping = escaping_locals(&*body); + debug!(?escaping); + let replacements = compute_flattening(tcx, body, escaping); + debug!(?replacements); + replace_flattened_locals(tcx, body, replacements); + } +} + +/// Identify all locals that are not eligible for SROA. +/// +/// There are 3 cases: +/// - the aggegated local is used or passed to other code (function parameters and arguments); +/// - the locals is a union or an enum; +/// - the local's address is taken, and thus the relative addresses of the fields are observable to +/// client code. +fn escaping_locals(body: &Body<'_>) -> BitSet<Local> { + let mut set = BitSet::new_empty(body.local_decls.len()); + set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count)); + for (local, decl) in body.local_decls().iter_enumerated() { + if decl.ty.is_union() || decl.ty.is_enum() { + set.insert(local); + } + } + let mut visitor = EscapeVisitor { set }; + visitor.visit_body(body); + return visitor.set; + + struct EscapeVisitor { + set: BitSet<Local>, + } + + impl<'tcx> Visitor<'tcx> for EscapeVisitor { + fn visit_local(&mut self, local: Local, _: PlaceContext, _: Location) { + self.set.insert(local); + } + + fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) { + // Mirror the implementation in PreFlattenVisitor. + if let &[PlaceElem::Field(..), ..] = &place.projection[..] { + return; + } + self.super_place(place, context, location); + } + + fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { + if let Rvalue::AddressOf(.., place) | Rvalue::Ref(.., place) = rvalue { + if !place.is_indirect() { + // Raw pointers may be used to access anything inside the enclosing place. + self.set.insert(place.local); + return; + } + } + self.super_rvalue(rvalue, location) + } + + fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { + if let StatementKind::StorageLive(..) + | StatementKind::StorageDead(..) + | StatementKind::Deinit(..) = statement.kind + { + // Storage statements are expanded in run_pass. + return; + } + self.super_statement(statement, location) + } + + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { + // Drop implicitly calls `drop_in_place`, which takes a `&mut`. + // This implies that `Drop` implicitly takes the address of the place. + if let TerminatorKind::Drop { place, .. } + | TerminatorKind::DropAndReplace { place, .. } = terminator.kind + { + if !place.is_indirect() { + // Raw pointers may be used to access anything inside the enclosing place. + self.set.insert(place.local); + return; + } + } + self.super_terminator(terminator, location); + } + + // We ignore anything that happens in debuginfo, since we expand it using + // `VarDebugInfoContents::Composite`. + fn visit_var_debug_info(&mut self, _: &VarDebugInfo<'tcx>) {} + } +} + +#[derive(Default, Debug)] +struct ReplacementMap<'tcx> { + fields: FxIndexMap<PlaceRef<'tcx>, Local>, +} + +/// Compute the replacement of flattened places into locals. +/// +/// For each eligible place, we assign a new local to each accessed field. +/// The replacement will be done later in `ReplacementVisitor`. +fn compute_flattening<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + escaping: BitSet<Local>, +) -> ReplacementMap<'tcx> { + let mut visitor = PreFlattenVisitor { + tcx, + escaping, + local_decls: &mut body.local_decls, + map: Default::default(), + }; + for (block, bbdata) in body.basic_blocks.iter_enumerated() { + visitor.visit_basic_block_data(block, bbdata); + } + return visitor.map; + + struct PreFlattenVisitor<'tcx, 'll> { + tcx: TyCtxt<'tcx>, + local_decls: &'ll mut LocalDecls<'tcx>, + escaping: BitSet<Local>, + map: ReplacementMap<'tcx>, + } + + impl<'tcx, 'll> PreFlattenVisitor<'tcx, 'll> { + fn create_place(&mut self, place: PlaceRef<'tcx>) { + if self.escaping.contains(place.local) { + return; + } + + match self.map.fields.entry(place) { + IndexEntry::Occupied(_) => {} + IndexEntry::Vacant(v) => { + let ty = place.ty(&*self.local_decls, self.tcx).ty; + let local = self.local_decls.push(LocalDecl { + ty, + user_ty: None, + ..self.local_decls[place.local].clone() + }); + v.insert(local); + } + } + } + } + + impl<'tcx, 'll> Visitor<'tcx> for PreFlattenVisitor<'tcx, 'll> { + fn visit_place(&mut self, place: &Place<'tcx>, _: PlaceContext, _: Location) { + if let &[PlaceElem::Field(..), ..] = &place.projection[..] { + let pr = PlaceRef { local: place.local, projection: &place.projection[..1] }; + self.create_place(pr) + } + } + } +} + +/// Perform the replacement computed by `compute_flattening`. +fn replace_flattened_locals<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + replacements: ReplacementMap<'tcx>, +) { + let mut all_dead_locals = BitSet::new_empty(body.local_decls.len()); + for p in replacements.fields.keys() { + all_dead_locals.insert(p.local); + } + debug!(?all_dead_locals); + if all_dead_locals.is_empty() { + return; + } + + let mut fragments = IndexVec::new(); + for (k, v) in &replacements.fields { + fragments.ensure_contains_elem(k.local, || Vec::new()); + fragments[k.local].push((&k.projection[..], *v)); + } + debug!(?fragments); + + let mut visitor = ReplacementVisitor { + tcx, + local_decls: &body.local_decls, + replacements, + all_dead_locals, + fragments, + }; + for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() { + visitor.visit_basic_block_data(bb, data); + } + for scope in &mut body.source_scopes { + visitor.visit_source_scope_data(scope); + } + for (index, annotation) in body.user_type_annotations.iter_enumerated_mut() { + visitor.visit_user_type_annotation(index, annotation); + } + for var_debug_info in &mut body.var_debug_info { + visitor.visit_var_debug_info(var_debug_info); + } +} + +struct ReplacementVisitor<'tcx, 'll> { + tcx: TyCtxt<'tcx>, + /// This is only used to compute the type for `VarDebugInfoContents::Composite`. + local_decls: &'ll LocalDecls<'tcx>, + /// Work to do. + replacements: ReplacementMap<'tcx>, + /// This is used to check that we are not leaving references to replaced locals behind. + all_dead_locals: BitSet<Local>, + /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage + /// and deinit statement and debuginfo. + fragments: IndexVec<Local, Vec<(&'tcx [PlaceElem<'tcx>], Local)>>, +} + +impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> { + fn gather_debug_info_fragments( + &self, + place: PlaceRef<'tcx>, + ) -> Vec<VarDebugInfoFragment<'tcx>> { + let mut fragments = Vec::new(); + let parts = &self.fragments[place.local]; + for (proj, replacement_local) in parts { + if proj.starts_with(place.projection) { + fragments.push(VarDebugInfoFragment { + projection: proj[place.projection.len()..].to_vec(), + contents: Place::from(*replacement_local), + }); + } + } + fragments + } + + fn replace_place(&self, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> { + if let &[PlaceElem::Field(..), ref rest @ ..] = place.projection { + let pr = PlaceRef { local: place.local, projection: &place.projection[..1] }; + let local = self.replacements.fields.get(&pr)?; + Some(Place { local: *local, projection: self.tcx.intern_place_elems(&rest) }) + } else { + None + } + } +} + +impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { + if let StatementKind::StorageLive(..) + | StatementKind::StorageDead(..) + | StatementKind::Deinit(..) = statement.kind + { + // Storage statements are expanded in run_pass. + return; + } + self.super_statement(statement, location) + } + + fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { + if let Some(repl) = self.replace_place(place.as_ref()) { + *place = repl + } else { + self.super_place(place, context, location) + } + } + + fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) { + match &mut var_debug_info.value { + VarDebugInfoContents::Place(ref mut place) => { + if let Some(repl) = self.replace_place(place.as_ref()) { + *place = repl; + } else if self.all_dead_locals.contains(place.local) { + let ty = place.ty(self.local_decls, self.tcx).ty; + let fragments = self.gather_debug_info_fragments(place.as_ref()); + var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments }; + } + } + VarDebugInfoContents::Composite { ty: _, ref mut fragments } => { + let mut new_fragments = Vec::new(); + fragments + .drain_filter(|fragment| { + if let Some(repl) = self.replace_place(fragment.contents.as_ref()) { + fragment.contents = repl; + true + } else if self.all_dead_locals.contains(fragment.contents.local) { + let frg = self.gather_debug_info_fragments(fragment.contents.as_ref()); + new_fragments.extend(frg.into_iter().map(|mut f| { + f.projection.splice(0..0, fragment.projection.iter().copied()); + f + })); + false + } else { + true + } + }) + .for_each(drop); + fragments.extend(new_fragments); + } + VarDebugInfoContents::Const(_) => {} + } + } + + fn visit_basic_block_data(&mut self, bb: BasicBlock, bbdata: &mut BasicBlockData<'tcx>) { + self.super_basic_block_data(bb, bbdata); + + #[derive(Debug)] + enum Stmt { + StorageLive, + StorageDead, + Deinit, + } + + bbdata.expand_statements(|stmt| { + let source_info = stmt.source_info; + let (stmt, origin_local) = match &stmt.kind { + StatementKind::StorageLive(l) => (Stmt::StorageLive, *l), + StatementKind::StorageDead(l) => (Stmt::StorageDead, *l), + StatementKind::Deinit(p) if let Some(l) = p.as_local() => (Stmt::Deinit, l), + _ => return None, + }; + if !self.all_dead_locals.contains(origin_local) { + return None; + } + let final_locals = self.fragments.get(origin_local)?; + Some(final_locals.iter().map(move |&(_, l)| { + let kind = match stmt { + Stmt::StorageLive => StatementKind::StorageLive(l), + Stmt::StorageDead => StatementKind::StorageDead(l), + Stmt::Deinit => StatementKind::Deinit(Box::new(l.into())), + }; + Statement { source_info, kind } + })) + }); + } + + fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) { + assert!(!self.all_dead_locals.contains(*local)); + } +} diff --git a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs index 96ea15f1b..be0aa0fc4 100644 --- a/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs +++ b/compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs @@ -65,7 +65,7 @@ fn variant_discriminants<'tcx>( Variants::Multiple { variants, .. } => variants .iter_enumerated() .filter_map(|(idx, layout)| { - (layout.abi() != Abi::Uninhabited) + (layout.abi != Abi::Uninhabited) .then(|| ty.discriminant_for_variant(tcx, idx).unwrap().val) }) .collect(), |