From 2e00214b3efbdfeefaa0fe9e8b8fd519de7adc35 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 17 Apr 2024 14:19:50 +0200 Subject: Merging upstream version 1.69.0+dfsg1. Signed-off-by: Daniel Baumann --- .../src/abort_unwinding_calls.rs | 2 +- .../src/check_const_item_mutation.rs | 2 +- .../rustc_mir_transform/src/check_packed_ref.rs | 88 +-- compiler/rustc_mir_transform/src/check_unsafety.rs | 4 +- compiler/rustc_mir_transform/src/const_goto.rs | 9 + compiler/rustc_mir_transform/src/const_prop.rs | 351 +++------ .../rustc_mir_transform/src/const_prop_lint.rs | 18 +- compiler/rustc_mir_transform/src/copy_prop.rs | 182 +++++ .../rustc_mir_transform/src/coverage/counters.rs | 6 +- compiler/rustc_mir_transform/src/coverage/debug.rs | 5 +- compiler/rustc_mir_transform/src/coverage/graph.rs | 14 +- compiler/rustc_mir_transform/src/coverage/mod.rs | 3 +- compiler/rustc_mir_transform/src/coverage/spans.rs | 18 +- compiler/rustc_mir_transform/src/ctfe_limit.rs | 58 ++ .../rustc_mir_transform/src/dataflow_const_prop.rs | 150 +++- .../src/dead_store_elimination.rs | 1 + compiler/rustc_mir_transform/src/deaggregator.rs | 45 -- .../rustc_mir_transform/src/deduce_param_attrs.rs | 2 +- compiler/rustc_mir_transform/src/dest_prop.rs | 12 +- .../src/elaborate_box_derefs.rs | 10 +- .../rustc_mir_transform/src/elaborate_drops.rs | 86 ++- .../rustc_mir_transform/src/ffi_unwind_calls.rs | 2 +- .../src/function_item_references.rs | 13 +- compiler/rustc_mir_transform/src/generator.rs | 389 ++++++++-- compiler/rustc_mir_transform/src/inline.rs | 12 +- compiler/rustc_mir_transform/src/inline/cycle.rs | 2 +- compiler/rustc_mir_transform/src/instcombine.rs | 86 ++- compiler/rustc_mir_transform/src/large_enums.rs | 294 ++++++++ compiler/rustc_mir_transform/src/lib.rs | 32 +- .../rustc_mir_transform/src/lower_intrinsics.rs | 26 +- .../rustc_mir_transform/src/lower_slice_len.rs | 5 +- .../rustc_mir_transform/src/normalize_array_len.rs | 322 ++------ .../src/remove_noop_landing_pads.rs | 1 + compiler/rustc_mir_transform/src/remove_zsts.rs | 2 +- .../src/separate_const_switch.rs | 2 + compiler/rustc_mir_transform/src/shim.rs | 49 +- compiler/rustc_mir_transform/src/simplify.rs | 20 +- compiler/rustc_mir_transform/src/simplify_try.rs | 822 --------------------- compiler/rustc_mir_transform/src/sroa.rs | 439 ++++++----- compiler/rustc_mir_transform/src/ssa.rs | 257 +++++++ 40 files changed, 2041 insertions(+), 1800 deletions(-) create mode 100644 compiler/rustc_mir_transform/src/copy_prop.rs create mode 100644 compiler/rustc_mir_transform/src/ctfe_limit.rs delete mode 100644 compiler/rustc_mir_transform/src/deaggregator.rs create mode 100644 compiler/rustc_mir_transform/src/large_enums.rs delete mode 100644 compiler/rustc_mir_transform/src/simplify_try.rs create mode 100644 compiler/rustc_mir_transform/src/ssa.rs (limited to 'compiler/rustc_mir_transform') diff --git a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs index d8f85d2e3..9b4b72070 100644 --- a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs +++ b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs @@ -41,7 +41,7 @@ impl<'tcx> MirPass<'tcx> for AbortUnwindingCalls { // // Here we test for this function itself whether its ABI allows // unwinding or not. - let body_ty = tcx.type_of(def_id); + let body_ty = tcx.type_of(def_id).skip_binder(); let body_abi = match body_ty.kind() { ty::FnDef(..) => body_ty.fn_sig(tcx).abi(), ty::Closure(..) 
=> Abi::RustCall, diff --git a/compiler/rustc_mir_transform/src/check_const_item_mutation.rs b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs index fa5f392fa..536745d2c 100644 --- a/compiler/rustc_mir_transform/src/check_const_item_mutation.rs +++ b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs @@ -74,7 +74,7 @@ impl<'tcx> ConstMutationChecker<'_, 'tcx> { // // `unsafe { *FOO = 0; *BAR.field = 1; }` // `unsafe { &mut *FOO }` - // `unsafe { (*ARRAY)[0] = val; } + // `unsafe { (*ARRAY)[0] = val; }` if !place.projection.iter().any(|p| matches!(p, PlaceElem::Deref)) { let source_info = self.body.source_info(location); let lint_root = self.body.source_scopes[source_info.scope] diff --git a/compiler/rustc_mir_transform/src/check_packed_ref.rs b/compiler/rustc_mir_transform/src/check_packed_ref.rs index 51abcf511..f5f1c1010 100644 --- a/compiler/rustc_mir_transform/src/check_packed_ref.rs +++ b/compiler/rustc_mir_transform/src/check_packed_ref.rs @@ -1,17 +1,11 @@ -use rustc_hir::def_id::LocalDefId; +use rustc_errors::struct_span_err; use rustc_middle::mir::visit::{PlaceContext, Visitor}; use rustc_middle::mir::*; -use rustc_middle::ty::query::Providers; use rustc_middle::ty::{self, TyCtxt}; -use rustc_session::lint::builtin::UNALIGNED_REFERENCES; use crate::util; use crate::MirLint; -pub(crate) fn provide(providers: &mut Providers) { - *providers = Providers { unsafe_derive_on_repr_packed, ..*providers }; -} - pub struct CheckPackedRef; impl<'tcx> MirLint<'tcx> for CheckPackedRef { @@ -30,32 +24,6 @@ struct PackedRefChecker<'a, 'tcx> { source_info: SourceInfo, } -fn unsafe_derive_on_repr_packed(tcx: TyCtxt<'_>, def_id: LocalDefId) { - let lint_hir_id = tcx.hir().local_def_id_to_hir_id(def_id); - - // FIXME: when we make this a hard error, this should have its - // own error code. - - let extra = if tcx.generics_of(def_id).own_requires_monomorphization() { - "with type or const parameters" - } else { - "that does not derive `Copy`" - }; - let message = format!( - "`{}` can't be derived on this `#[repr(packed)]` struct {}", - tcx.item_name(tcx.trait_id_of_impl(def_id.to_def_id()).expect("derived trait name")), - extra - ); - - tcx.struct_span_lint_hir( - UNALIGNED_REFERENCES, - lint_hir_id, - tcx.def_span(def_id), - message, - |lint| lint, - ); -} - impl<'tcx> Visitor<'tcx> for PackedRefChecker<'_, 'tcx> { fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { // Make sure we know where in the MIR we are. @@ -73,40 +41,30 @@ impl<'tcx> Visitor<'tcx> for PackedRefChecker<'_, 'tcx> { if context.is_borrow() { if util::is_disaligned(self.tcx, self.body, self.param_env, *place) { let def_id = self.body.source.instance.def_id(); - if let Some(impl_def_id) = self - .tcx - .impl_of_method(def_id) - .filter(|&def_id| self.tcx.is_builtin_derive(def_id)) + if let Some(impl_def_id) = self.tcx.impl_of_method(def_id) + && self.tcx.is_builtin_derived(impl_def_id) { - // If a method is defined in the local crate, - // the impl containing that method should also be. - self.tcx.ensure().unsafe_derive_on_repr_packed(impl_def_id.expect_local()); + // If we ever reach here it means that the generated derive + // code is somehow doing an unaligned reference, which it + // shouldn't do. 
+ span_bug!(self.source_info.span, "builtin derive created an unaligned reference"); } else { - let source_info = self.source_info; - let lint_root = self.body.source_scopes[source_info.scope] - .local_data - .as_ref() - .assert_crate_local() - .lint_root; - self.tcx.struct_span_lint_hir( - UNALIGNED_REFERENCES, - lint_root, - source_info.span, - "reference to packed field is unaligned", - |lint| { - lint - .note( - "fields of packed structs are not properly aligned, and creating \ - a misaligned reference is undefined behavior (even if that \ - reference is never dereferenced)", - ) - .help( - "copy the field contents to a local variable, or replace the \ - reference with a raw pointer and use `read_unaligned`/`write_unaligned` \ - (loads and stores via `*p` must be properly aligned even when using raw pointers)" - ) - }, - ); + struct_span_err!( + self.tcx.sess, + self.source_info.span, + E0793, + "reference to packed field is unaligned" + ) + .note( + "fields of packed structs are not properly aligned, and creating \ + a misaligned reference is undefined behavior (even if that \ + reference is never dereferenced)", + ).help( + "copy the field contents to a local variable, or replace the \ + reference with a raw pointer and use `read_unaligned`/`write_unaligned` \ + (loads and stores via `*p` must be properly aligned even when using raw pointers)" + ) + .emit(); } } } diff --git a/compiler/rustc_mir_transform/src/check_unsafety.rs b/compiler/rustc_mir_transform/src/check_unsafety.rs index adf6ae4c7..d00ee1f4b 100644 --- a/compiler/rustc_mir_transform/src/check_unsafety.rs +++ b/compiler/rustc_mir_transform/src/check_unsafety.rs @@ -104,6 +104,7 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> { | StatementKind::AscribeUserType(..) | StatementKind::Coverage(..) | StatementKind::Intrinsic(..) + | StatementKind::ConstEvalCounter | StatementKind::Nop => { // safe (at least as emitted during MIR construction) } @@ -125,6 +126,7 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> { } } &AggregateKind::Closure(def_id, _) | &AggregateKind::Generator(def_id, _, _) => { + let def_id = def_id.expect_local(); let UnsafetyCheckResult { violations, used_unsafe_blocks, .. } = self.tcx.unsafety_check_result(def_id); self.register_violations(violations, used_unsafe_blocks.iter().copied()); @@ -445,7 +447,7 @@ impl<'tcx> intravisit::Visitor<'tcx> for UnusedUnsafeVisitor<'_, 'tcx> { _fd: &'tcx hir::FnDecl<'tcx>, b: hir::BodyId, _s: rustc_span::Span, - _id: HirId, + _id: LocalDefId, ) { if matches!(fk, intravisit::FnKind::Closure) { self.visit_body(self.tcx.hir().body(b)) diff --git a/compiler/rustc_mir_transform/src/const_goto.rs b/compiler/rustc_mir_transform/src/const_goto.rs index 40eefda4f..da101ca7a 100644 --- a/compiler/rustc_mir_transform/src/const_goto.rs +++ b/compiler/rustc_mir_transform/src/const_goto.rs @@ -57,6 +57,15 @@ impl<'tcx> MirPass<'tcx> for ConstGoto { } impl<'tcx> Visitor<'tcx> for ConstGotoOptimizationFinder<'_, 'tcx> { + fn visit_basic_block_data(&mut self, block: BasicBlock, data: &BasicBlockData<'tcx>) { + if data.is_cleanup { + // Because of the restrictions around control flow in cleanup blocks, we don't perform + // this optimization at all in such blocks. 
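[Editor's aside on the `check_packed_ref` change above: E0793 now rejects taking a reference to a field of a `#[repr(packed)]` struct outright. A minimal, hypothetical sketch — struct and values invented purely for illustration — of the rejected pattern and the two fixes the help text recommends:]

    // Hypothetical example, not part of this patch: what E0793 rejects,
    // and the workarounds its help text suggests.
    #[repr(packed)]
    struct Packed {
        byte: u8,
        word: u32, // only 1-aligned inside a packed struct
    }

    fn main() {
        let p = Packed { byte: 7, word: 0xdead_beef };
        // let r = &p.word; // now a hard error: the reference may be unaligned
        let copied = p.word; // fix 1: copy the field contents to a local
        // fix 2: raw pointer + read_unaligned, so no reference is created
        let via_ptr = unsafe { std::ptr::addr_of!(p.word).read_unaligned() };
        assert_eq!(copied, via_ptr);
        assert_eq!(p.byte, 7);
    }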
+            return;
+        }
+        self.super_basic_block_data(block, data);
+    }
+
     fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
         let _: Option<_> = try {
             let target = terminator.kind.as_goto()?;
diff --git a/compiler/rustc_mir_transform/src/const_prop.rs b/compiler/rustc_mir_transform/src/const_prop.rs
index 5c45abc5a..6b2eefce2 100644
--- a/compiler/rustc_mir_transform/src/const_prop.rs
+++ b/compiler/rustc_mir_transform/src/const_prop.rs
@@ -5,7 +5,6 @@
 use std::cell::Cell;
 
 use either::Right;
 
-use rustc_ast::Mutability;
 use rustc_const_eval::const_eval::CheckAlignment;
 use rustc_data_structures::fx::FxHashSet;
 use rustc_hir::def::DefKind;
@@ -14,14 +13,10 @@ use rustc_index::vec::IndexVec;
 use rustc_middle::mir::visit::{
     MutVisitor, MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor,
 };
-use rustc_middle::mir::{
-    BasicBlock, BinOp, Body, Constant, ConstantKind, Local, LocalDecl, LocalKind, Location,
-    Operand, Place, Rvalue, SourceInfo, Statement, StatementKind, Terminator, TerminatorKind, UnOp,
-    RETURN_PLACE,
-};
+use rustc_middle::mir::*;
 use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout};
 use rustc_middle::ty::InternalSubsts;
-use rustc_middle::ty::{self, ConstKind, Instance, ParamEnv, Ty, TyCtxt, TypeVisitable};
+use rustc_middle::ty::{self, ConstKind, Instance, ParamEnv, Ty, TyCtxt, TypeVisitableExt};
 use rustc_span::{def_id::DefId, Span};
 use rustc_target::abi::{self, Align, HasDataLayout, Size, TargetDataLayout};
 use rustc_target::spec::abi::Abi as CallAbi;
@@ -83,7 +78,7 @@ impl<'tcx> MirPass<'tcx> for ConstProp {
             return;
         }
 
-        let is_generator = tcx.type_of(def_id.to_def_id()).is_generator();
+        let is_generator = tcx.type_of(def_id.to_def_id()).subst_identity().is_generator();
         // FIXME(welseywiser) const prop doesn't work on generators because of query cycles
         // computing their layout.
         if is_generator {
@@ -289,7 +284,7 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx>
         }
         // If the static allocation is mutable, then we can't const prop it as its content
         // might be different at runtime.
-        if alloc.inner().mutability == Mutability::Mut {
+        if alloc.inner().mutability.is_mut() {
             throw_machine_stop_str!("can't access mutable globals in ConstProp");
         }
 
@@ -457,27 +452,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
         };
     }
 
-    fn use_ecx<F, T>(&mut self, f: F) -> Option<T>
-    where
-        F: FnOnce(&mut Self) -> InterpResult<'tcx, T>,
-    {
-        match f(self) {
-            Ok(val) => Some(val),
-            Err(error) => {
-                trace!("InterpCx operation failed: {:?}", error);
-                // Some errors shouldn't come up because creating them causes
-                // an allocation, which we should avoid. When that happens,
-                // dedicated error variants should be introduced instead.
-                assert!(
-                    !error.kind().formatted_string(),
-                    "const-prop encountered formatting error: {}",
-                    error
-                );
-                None
-            }
-        }
-    }
-
     /// Returns the value, if any, of evaluating `c`.
     fn eval_constant(&mut self, c: &Constant<'tcx>) -> Option<OpTy<'tcx>> {
         // FIXME we need to revisit this for #67176
@@ -492,7 +466,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
 
     /// Returns the value, if any, of evaluating `place`.
     fn eval_place(&mut self, place: Place<'tcx>) -> Option<OpTy<'tcx>> {
         trace!("eval_place(place={:?})", place);
-        self.use_ecx(|this| this.ecx.eval_place_to_op(place, None))
+        self.ecx.eval_place_to_op(place, None).ok()
     }
 
     /// Returns the value, if any, of evaluating `op`.
Calls upon `eval_constant` @@ -504,55 +478,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { } } - fn check_unary_op(&mut self, op: UnOp, arg: &Operand<'tcx>) -> Option<()> { - if self.use_ecx(|this| { - let val = this.ecx.read_immediate(&this.ecx.eval_operand(arg, None)?)?; - let (_res, overflow, _ty) = this.ecx.overflowing_unary_op(op, &val)?; - Ok(overflow) - })? { - // `AssertKind` only has an `OverflowNeg` variant, so make sure that is - // appropriate to use. - assert_eq!(op, UnOp::Neg, "Neg is the only UnOp that can overflow"); - return None; - } - - Some(()) - } - - fn check_binary_op( - &mut self, - op: BinOp, - left: &Operand<'tcx>, - right: &Operand<'tcx>, - ) -> Option<()> { - let r = self.use_ecx(|this| this.ecx.read_immediate(&this.ecx.eval_operand(right, None)?)); - let l = self.use_ecx(|this| this.ecx.read_immediate(&this.ecx.eval_operand(left, None)?)); - // Check for exceeding shifts *even if* we cannot evaluate the LHS. - if op == BinOp::Shr || op == BinOp::Shl { - let r = r.clone()?; - // We need the type of the LHS. We cannot use `place_layout` as that is the type - // of the result, which for checked binops is not the same! - let left_ty = left.ty(self.local_decls, self.tcx); - let left_size = self.ecx.layout_of(left_ty).ok()?.size; - let right_size = r.layout.size; - let r_bits = r.to_scalar().to_bits(right_size).ok(); - if r_bits.map_or(false, |b| b >= left_size.bits() as u128) { - return None; - } - } - - if let (Some(l), Some(r)) = (&l, &r) { - // The remaining operators are handled through `overflowing_binary_op`. - if self.use_ecx(|this| { - let (_res, overflow, _ty) = this.ecx.overflowing_binary_op(op, l, r)?; - Ok(overflow) - })? { - return None; - } - } - Some(()) - } - fn propagate_operand(&mut self, operand: &mut Operand<'tcx>) { match *operand { Operand::Copy(l) | Operand::Move(l) => { @@ -588,28 +513,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { // 2. Working around bugs in other parts of the compiler // - In this case, we'll return `None` from this function to stop evaluation. match rvalue { - // Additional checking: give lints to the user if an overflow would occur. - // We do this here and not in the `Assert` terminator as that terminator is - // only sometimes emitted (overflow checks can be disabled), but we want to always - // lint. - Rvalue::UnaryOp(op, arg) => { - trace!("checking UnaryOp(op = {:?}, arg = {:?})", op, arg); - self.check_unary_op(*op, arg)?; - } - Rvalue::BinaryOp(op, box (left, right)) => { - trace!("checking BinaryOp(op = {:?}, left = {:?}, right = {:?})", op, left, right); - self.check_binary_op(*op, left, right)?; - } - Rvalue::CheckedBinaryOp(op, box (left, right)) => { - trace!( - "checking CheckedBinaryOp(op = {:?}, left = {:?}, right = {:?})", - op, - left, - right - ); - self.check_binary_op(*op, left, right)?; - } - // Do not try creating references (#67862) Rvalue::AddressOf(_, place) | Rvalue::Ref(_, _, place) => { trace!("skipping AddressOf | Ref for {:?}", place); @@ -639,7 +542,10 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { | Rvalue::Cast(..) | Rvalue::ShallowInitBox(..) | Rvalue::Discriminant(..) - | Rvalue::NullaryOp(..) => {} + | Rvalue::NullaryOp(..) + | Rvalue::UnaryOp(..) + | Rvalue::BinaryOp(..) + | Rvalue::CheckedBinaryOp(..) 
=> {} } // FIXME we need to revisit this for #67176 @@ -664,35 +570,37 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { rvalue: &Rvalue<'tcx>, place: Place<'tcx>, ) -> Option<()> { - self.use_ecx(|this| match rvalue { + match rvalue { Rvalue::BinaryOp(op, box (left, right)) | Rvalue::CheckedBinaryOp(op, box (left, right)) => { - let l = this.ecx.eval_operand(left, None).and_then(|x| this.ecx.read_immediate(&x)); + let l = self.ecx.eval_operand(left, None).and_then(|x| self.ecx.read_immediate(&x)); let r = - this.ecx.eval_operand(right, None).and_then(|x| this.ecx.read_immediate(&x)); + self.ecx.eval_operand(right, None).and_then(|x| self.ecx.read_immediate(&x)); let const_arg = match (l, r) { (Ok(x), Err(_)) | (Err(_), Ok(x)) => x, // exactly one side is known - (Err(e), Err(_)) => return Err(e), // neither side is known - (Ok(_), Ok(_)) => return this.ecx.eval_rvalue_into_place(rvalue, place), // both sides are known + (Err(_), Err(_)) => return None, // neither side is known + (Ok(_), Ok(_)) => return self.ecx.eval_rvalue_into_place(rvalue, place).ok(), // both sides are known }; if !matches!(const_arg.layout.abi, abi::Abi::Scalar(..)) { // We cannot handle Scalar Pair stuff. // No point in calling `eval_rvalue_into_place`, since only one side is known - throw_machine_stop_str!("cannot optimize this") + return None; } - let arg_value = const_arg.to_scalar().to_bits(const_arg.layout.size)?; - let dest = this.ecx.eval_place(place)?; + let arg_value = const_arg.to_scalar().to_bits(const_arg.layout.size).ok()?; + let dest = self.ecx.eval_place(place).ok()?; match op { - BinOp::BitAnd if arg_value == 0 => this.ecx.write_immediate(*const_arg, &dest), + BinOp::BitAnd if arg_value == 0 => { + self.ecx.write_immediate(*const_arg, &dest).ok() + } BinOp::BitOr if arg_value == const_arg.layout.size.truncate(u128::MAX) || (const_arg.layout.ty.is_bool() && arg_value == 1) => { - this.ecx.write_immediate(*const_arg, &dest) + self.ecx.write_immediate(*const_arg, &dest).ok() } BinOp::Mul if const_arg.layout.ty.is_integral() && arg_value == 0 => { if let Rvalue::CheckedBinaryOp(_, _) = rvalue { @@ -700,16 +608,16 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { const_arg.to_scalar(), Scalar::from_bool(false), ); - this.ecx.write_immediate(val, &dest) + self.ecx.write_immediate(val, &dest).ok() } else { - this.ecx.write_immediate(*const_arg, &dest) + self.ecx.write_immediate(*const_arg, &dest).ok() } } - _ => throw_machine_stop_str!("cannot optimize this"), + _ => None, } } - _ => this.ecx.eval_rvalue_into_place(rvalue, place), - }) + _ => self.ecx.eval_rvalue_into_place(rvalue, place).ok(), + } } /// Creates a new `Operand::Constant` from a `Scalar` value @@ -751,7 +659,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { } // FIXME> figure out what to do when read_immediate_raw fails - let imm = self.use_ecx(|this| this.ecx.read_immediate_raw(value)); + let imm = self.ecx.read_immediate_raw(value).ok(); if let Some(Right(imm)) = imm { match *imm { @@ -771,25 +679,23 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { if let ty::Tuple(types) = ty.kind() { // Only do it if tuple is also a pair with two scalars if let [ty1, ty2] = types[..] 
{ - let alloc = self.use_ecx(|this| { - let ty_is_scalar = |ty| { - this.ecx.layout_of(ty).ok().map(|layout| layout.abi.is_scalar()) - == Some(true) - }; - if ty_is_scalar(ty1) && ty_is_scalar(ty2) { - let alloc = this - .ecx - .intern_with_temp_alloc(value.layout, |ecx, dest| { - ecx.write_immediate(*imm, dest) - }) - .unwrap(); - Ok(Some(alloc)) - } else { - Ok(None) - } - }); - - if let Some(Some(alloc)) = alloc { + let ty_is_scalar = |ty| { + self.ecx.layout_of(ty).ok().map(|layout| layout.abi.is_scalar()) + == Some(true) + }; + let alloc = if ty_is_scalar(ty1) && ty_is_scalar(ty2) { + let alloc = self + .ecx + .intern_with_temp_alloc(value.layout, |ecx, dest| { + ecx.write_immediate(*imm, dest) + }) + .unwrap(); + Some(alloc) + } else { + None + }; + + if let Some(alloc) = alloc { // Assign entire constant in a single statement. // We can't use aggregates, as we run after the aggregate-lowering `MirPhase`. let const_val = ConstValue::ByRef { alloc, offset: Size::ZERO }; @@ -990,84 +896,80 @@ impl<'tcx> MutVisitor<'tcx> for ConstPropagator<'_, 'tcx> { trace!("visit_statement: {:?}", statement); let source_info = statement.source_info; self.source_info = Some(source_info); - if let StatementKind::Assign(box (place, ref mut rval)) = statement.kind { - let can_const_prop = self.ecx.machine.can_const_prop[place.local]; - if let Some(()) = self.const_prop(rval, place) { - // This will return None if the above `const_prop` invocation only "wrote" a - // type whose creation requires no write. E.g. a generator whose initial state - // consists solely of uninitialized memory (so it doesn't capture any locals). - if let Some(ref value) = self.get_const(place) && self.should_const_prop(value) { - trace!("replacing {:?} with {:?}", rval, value); - self.replace_with_const(rval, value, source_info); - if can_const_prop == ConstPropMode::FullConstProp - || can_const_prop == ConstPropMode::OnlyInsideOwnBlock - { - trace!("propagated into {:?}", place); + match statement.kind { + StatementKind::Assign(box (place, ref mut rval)) => { + let can_const_prop = self.ecx.machine.can_const_prop[place.local]; + if let Some(()) = self.const_prop(rval, place) { + // This will return None if the above `const_prop` invocation only "wrote" a + // type whose creation requires no write. E.g. a generator whose initial state + // consists solely of uninitialized memory (so it doesn't capture any locals). + if let Some(ref value) = self.get_const(place) && self.should_const_prop(value) { + trace!("replacing {:?} with {:?}", rval, value); + self.replace_with_const(rval, value, source_info); + if can_const_prop == ConstPropMode::FullConstProp + || can_const_prop == ConstPropMode::OnlyInsideOwnBlock + { + trace!("propagated into {:?}", place); + } } - } - match can_const_prop { - ConstPropMode::OnlyInsideOwnBlock => { - trace!( - "found local restricted to its block. \ + match can_const_prop { + ConstPropMode::OnlyInsideOwnBlock => { + trace!( + "found local restricted to its block. \ Will remove it from const-prop after block is finished. Local: {:?}", - place.local - ); - } - ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => { - trace!("can't propagate into {:?}", place); - if place.local != RETURN_PLACE { - Self::remove_const(&mut self.ecx, place.local); + place.local + ); } - } - ConstPropMode::FullConstProp => {} - } - } else { - // Const prop failed, so erase the destination, ensuring that whatever happens - // from here on, does not know about the previous value. 
- // This is important in case we have - // ```rust - // let mut x = 42; - // x = SOME_MUTABLE_STATIC; - // // x must now be uninit - // ``` - // FIXME: we overzealously erase the entire local, because that's easier to - // implement. - trace!( - "propagation into {:?} failed. - Nuking the entire site from orbit, it's the only way to be sure", - place, - ); - Self::remove_const(&mut self.ecx, place.local); - } - } else { - match statement.kind { - StatementKind::SetDiscriminant { ref place, .. } => { - match self.ecx.machine.can_const_prop[place.local] { - ConstPropMode::FullConstProp | ConstPropMode::OnlyInsideOwnBlock => { - if self.use_ecx(|this| this.ecx.statement(statement)).is_some() { - trace!("propped discriminant into {:?}", place); - } else { + ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => { + trace!("can't propagate into {:?}", place); + if place.local != RETURN_PLACE { Self::remove_const(&mut self.ecx, place.local); } } - ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => { - Self::remove_const(&mut self.ecx, place.local); - } + ConstPropMode::FullConstProp => {} } + } else { + // Const prop failed, so erase the destination, ensuring that whatever happens + // from here on, does not know about the previous value. + // This is important in case we have + // ```rust + // let mut x = 42; + // x = SOME_MUTABLE_STATIC; + // // x must now be uninit + // ``` + // FIXME: we overzealously erase the entire local, because that's easier to + // implement. + trace!( + "propagation into {:?} failed. + Nuking the entire site from orbit, it's the only way to be sure", + place, + ); + Self::remove_const(&mut self.ecx, place.local); } - StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => { - let frame = self.ecx.frame_mut(); - frame.locals[local].value = - if let StatementKind::StorageLive(_) = statement.kind { - LocalValue::Live(interpret::Operand::Immediate( - interpret::Immediate::Uninit, - )) + } + StatementKind::SetDiscriminant { ref place, .. } => { + match self.ecx.machine.can_const_prop[place.local] { + ConstPropMode::FullConstProp | ConstPropMode::OnlyInsideOwnBlock => { + if self.ecx.statement(statement).is_ok() { + trace!("propped discriminant into {:?}", place); } else { - LocalValue::Dead - }; + Self::remove_const(&mut self.ecx, place.local); + } + } + ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => { + Self::remove_const(&mut self.ecx, place.local); + } } - _ => {} } + StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => { + let frame = self.ecx.frame_mut(); + frame.locals[local].value = if let StatementKind::StorageLive(_) = statement.kind { + LocalValue::Live(interpret::Operand::Immediate(interpret::Immediate::Uninit)) + } else { + LocalValue::Dead + }; + } + _ => {} } self.super_statement(statement, location); @@ -1077,34 +979,19 @@ impl<'tcx> MutVisitor<'tcx> for ConstPropagator<'_, 'tcx> { let source_info = terminator.source_info; self.source_info = Some(source_info); self.super_terminator(terminator, location); - // Do NOT early return in this function, it does some crucial fixup of the state at the end! + match &mut terminator.kind { TerminatorKind::Assert { expected, ref mut cond, .. 
} => { - if let Some(ref value) = self.eval_operand(&cond) { + if let Some(ref value) = self.eval_operand(&cond) + && let Ok(value_const) = self.ecx.read_scalar(&value) + && self.should_const_prop(value) + { trace!("assertion on {:?} should be {:?}", value, expected); - let expected = Scalar::from_bool(*expected); - // FIXME should be used use_ecx rather than a local match... but we have - // quite a few of these read_scalar/read_immediate that need fixing. - if let Ok(value_const) = self.ecx.read_scalar(&value) { - if expected != value_const { - // Poison all places this operand references so that further code - // doesn't use the invalid value - match cond { - Operand::Move(ref place) | Operand::Copy(ref place) => { - Self::remove_const(&mut self.ecx, place.local); - } - Operand::Constant(_) => {} - } - } else { - if self.should_const_prop(value) { - *cond = self.operand_from_scalar( - value_const, - self.tcx.types.bool, - source_info.span, - ); - } - } - } + *cond = self.operand_from_scalar( + value_const, + self.tcx.types.bool, + source_info.span, + ); } } TerminatorKind::SwitchInt { ref mut discr, .. } => { @@ -1132,6 +1019,10 @@ impl<'tcx> MutVisitor<'tcx> for ConstPropagator<'_, 'tcx> { // gated on `mir_opt_level=3`. TerminatorKind::Call { .. } => {} } + } + + fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) { + self.super_basic_block_data(block, data); // We remove all Locals which are restricted in propagation to their containing blocks and // which were modified in the current block. diff --git a/compiler/rustc_mir_transform/src/const_prop_lint.rs b/compiler/rustc_mir_transform/src/const_prop_lint.rs index 0ab67228f..fd9475748 100644 --- a/compiler/rustc_mir_transform/src/const_prop_lint.rs +++ b/compiler/rustc_mir_transform/src/const_prop_lint.rs @@ -21,7 +21,9 @@ use rustc_middle::mir::{ }; use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout}; use rustc_middle::ty::InternalSubsts; -use rustc_middle::ty::{self, ConstInt, Instance, ParamEnv, ScalarInt, Ty, TyCtxt, TypeVisitable}; +use rustc_middle::ty::{ + self, ConstInt, Instance, ParamEnv, ScalarInt, Ty, TyCtxt, TypeVisitableExt, +}; use rustc_session::lint; use rustc_span::Span; use rustc_target::abi::{HasDataLayout, Size, TargetDataLayout}; @@ -57,7 +59,7 @@ impl<'tcx> MirLint<'tcx> for ConstProp { return; } - let is_generator = tcx.type_of(def_id.to_def_id()).is_generator(); + let is_generator = tcx.type_of(def_id.to_def_id()).subst_identity().is_generator(); // FIXME(welseywiser) const prop doesn't work on generators because of query cycles // computing their layout. if is_generator { @@ -296,7 +298,15 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { return None; } - self.use_ecx(source_info, |this| this.ecx.eval_mir_constant(&c.literal, Some(c.span), None)) + // Normalization needed b/c const prop lint runs in + // `mir_drops_elaborated_and_const_checked`, which happens before + // optimized MIR. Only after optimizing the MIR can we guarantee + // that the `RevealAll` pass has happened and that the body's consts + // are normalized, so any call to resolve before that needs to be + // manually normalized. + let val = self.tcx.try_normalize_erasing_regions(self.param_env, c.literal).ok()?; + + self.use_ecx(source_info, |this| this.ecx.eval_mir_constant(&val, Some(c.span), None)) } /// Returns the value, if any, of evaluating `place`. 
@@ -368,7 +378,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
             this.ecx.read_immediate(&this.ecx.eval_operand(left, None)?)
         });
         // Check for exceeding shifts *even if* we cannot evaluate the LHS.
-        if op == BinOp::Shr || op == BinOp::Shl {
+        if matches!(op, BinOp::Shr | BinOp::Shl) {
             let r = r.clone()?;
             // We need the type of the LHS. We cannot use `place_layout` as that is the type
             // of the result, which for checked binops is not the same!
diff --git a/compiler/rustc_mir_transform/src/copy_prop.rs b/compiler/rustc_mir_transform/src/copy_prop.rs
new file mode 100644
index 000000000..f27beb64a
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/copy_prop.rs
@@ -0,0 +1,182 @@
+use rustc_index::bit_set::BitSet;
+use rustc_index::vec::IndexVec;
+use rustc_middle::mir::visit::*;
+use rustc_middle::mir::*;
+use rustc_middle::ty::TyCtxt;
+use rustc_mir_dataflow::impls::borrowed_locals;
+
+use crate::ssa::SsaLocals;
+use crate::MirPass;
+
+/// Unify locals that copy each other.
+///
+/// We consider patterns of the form
+///   _a = rvalue
+///   _b = move? _a
+///   _c = move? _a
+///   _d = move? _c
+/// where each of the locals is only assigned once.
+///
+/// We want to replace all those locals by `_a`, either copied or moved.
+pub struct CopyProp;
+
+impl<'tcx> MirPass<'tcx> for CopyProp {
+    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+        sess.mir_opt_level() >= 1
+    }
+
+    #[instrument(level = "trace", skip(self, tcx, body))]
+    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        debug!(def_id = ?body.source.def_id());
+        propagate_ssa(tcx, body);
+    }
+}
+
+fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+    let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
+    let borrowed_locals = borrowed_locals(body);
+    let ssa = SsaLocals::new(tcx, param_env, body, &borrowed_locals);
+
+    let fully_moved = fully_moved_locals(&ssa, body);
+    debug!(?fully_moved);
+
+    let mut storage_to_remove = BitSet::new_empty(fully_moved.domain_size());
+    for (local, &head) in ssa.copy_classes().iter_enumerated() {
+        if local != head {
+            storage_to_remove.insert(head);
+        }
+    }
+
+    let any_replacement = ssa.copy_classes().iter_enumerated().any(|(l, &h)| l != h);
+
+    Replacer {
+        tcx,
+        copy_classes: &ssa.copy_classes(),
+        fully_moved,
+        borrowed_locals,
+        storage_to_remove,
+    }
+    .visit_body_preserves_cfg(body);
+
+    if any_replacement {
+        crate::simplify::remove_unused_definitions(body);
+    }
+}
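[Editor's aside: a hypothetical source function — not from the patch — whose MIR exhibits the `_a = rvalue; _b = move? _a; _c = move? _b` chain the doc comment above describes; the MIR locals in the comments are illustrative only:]

    // Hypothetical illustration: each `let` below becomes one MIR local, and
    // the chain of single-assignment copies is what CopyProp unifies into _a.
    fn chain(x: u32) -> u32 {
        let a = x + 1; // _a = Add(_1, const 1_u32)
        let b = a;     // _b = move? _a -- joins _a's copy class
        let c = b;     // _c = move? _b -- also folded into _a's class
        c              // after the pass, uses of _b/_c are rewritten to _a
    }

    fn main() {
        assert_eq!(chain(41), 42);
    }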
+
+/// `SsaLocals` computed equivalence classes between locals considering copy/move assignments.
+///
+/// This function also returns whether all the `move?` in the pattern are `move` and not copies.
+/// A local which is in the bitset can be replaced by `move _a`. Otherwise, it must be
+/// replaced by `copy _a`, as we cannot move multiple times from `_a`.
+///
+/// If an operand copies `_c`, it must happen before the assignment `_d = _c`, otherwise it is UB.
+/// This means that replacing it by a copy of `_a` is ok, since this copy happens before `_c` is
+/// moved, and therefore before `_d` is moved.
+#[instrument(level = "trace", skip(ssa, body))]
+fn fully_moved_locals(ssa: &SsaLocals, body: &Body<'_>) -> BitSet<Local> {
+    let mut fully_moved = BitSet::new_filled(body.local_decls.len());
+
+    for (_, rvalue) in ssa.assignments(body) {
+        let (Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) | Rvalue::CopyForDeref(place))
+            = rvalue
+        else { continue };
+
+        let Some(rhs) = place.as_local() else { continue };
+        if !ssa.is_ssa(rhs) {
+            continue;
+        }
+
+        if let Rvalue::Use(Operand::Copy(_)) | Rvalue::CopyForDeref(_) = rvalue {
+            fully_moved.remove(rhs);
+        }
+    }
+
+    ssa.meet_copy_equivalence(&mut fully_moved);
+
+    fully_moved
+}
+
+/// Utility to help perform substitution of `*pattern` by `target`.
+struct Replacer<'a, 'tcx> {
+    tcx: TyCtxt<'tcx>,
+    fully_moved: BitSet<Local>,
+    storage_to_remove: BitSet<Local>,
+    borrowed_locals: BitSet<Local>,
+    copy_classes: &'a IndexVec<Local, Local>,
+}
+
+impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> {
+    fn tcx(&self) -> TyCtxt<'tcx> {
+        self.tcx
+    }
+
+    fn visit_local(&mut self, local: &mut Local, ctxt: PlaceContext, _: Location) {
+        let new_local = self.copy_classes[*local];
+        match ctxt {
+            // Do not modify the local in storage statements.
+            PlaceContext::NonUse(NonUseContext::StorageLive | NonUseContext::StorageDead) => {}
+            // The local should have been marked as non-SSA.
+            PlaceContext::MutatingUse(_) => assert_eq!(*local, new_local),
+            // We access the value.
+            _ => *local = new_local,
+        }
+    }
+
+    fn visit_place(&mut self, place: &mut Place<'tcx>, ctxt: PlaceContext, loc: Location) {
+        if let Some(new_projection) = self.process_projection(&place.projection, loc) {
+            place.projection = self.tcx().mk_place_elems(&new_projection);
+        }
+
+        let observes_address = match ctxt {
+            PlaceContext::NonMutatingUse(
+                NonMutatingUseContext::SharedBorrow
+                | NonMutatingUseContext::ShallowBorrow
+                | NonMutatingUseContext::UniqueBorrow
+                | NonMutatingUseContext::AddressOf,
+            ) => true,
+            // For debuginfo, merging locals is ok.
+            PlaceContext::NonUse(NonUseContext::VarDebugInfo) => {
+                self.borrowed_locals.contains(place.local)
+            }
+            _ => false,
+        };
+        if observes_address && !place.is_indirect() {
+            // We observe the address of `place.local`. Do not replace it.
+        } else {
+            self.visit_local(
+                &mut place.local,
+                PlaceContext::NonMutatingUse(NonMutatingUseContext::Copy),
+                loc,
+            )
+        }
+    }
+
+    fn visit_operand(&mut self, operand: &mut Operand<'tcx>, loc: Location) {
+        if let Operand::Move(place) = *operand
+            // A move out of a projection of a copy is equivalent to a copy of the original projection.
+            && !place.has_deref()
+            && !self.fully_moved.contains(place.local)
+        {
+            *operand = Operand::Copy(place);
+        }
+        self.super_operand(operand, loc);
+    }
+
+    fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, loc: Location) {
+        match stmt.kind {
+            // When removing storage statements, we need to remove both (#107511).
+            StatementKind::StorageLive(l) | StatementKind::StorageDead(l)
+                if self.storage_to_remove.contains(l) =>
+            {
+                stmt.make_nop()
+            }
+            StatementKind::Assign(box (ref place, ref mut rvalue))
+                if place.as_local().is_some() =>
+            {
+                // Do not replace assignments.
+ self.visit_rvalue(rvalue, loc) + } + _ => self.super_statement(stmt, loc), + } + } +} diff --git a/compiler/rustc_mir_transform/src/coverage/counters.rs b/compiler/rustc_mir_transform/src/coverage/counters.rs index 45de0c280..658e01d93 100644 --- a/compiler/rustc_mir_transform/src/coverage/counters.rs +++ b/compiler/rustc_mir_transform/src/coverage/counters.rs @@ -520,7 +520,7 @@ impl<'a> BcbCounters<'a> { let mut found_loop_exit = false; for &branch in branches.iter() { if backedge_from_bcbs.iter().any(|&backedge_from_bcb| { - self.bcb_is_dominated_by(backedge_from_bcb, branch.target_bcb) + self.bcb_dominates(branch.target_bcb, backedge_from_bcb) }) { if let Some(reloop_branch) = some_reloop_branch { if reloop_branch.counter(&self.basic_coverage_blocks).is_none() { @@ -603,8 +603,8 @@ impl<'a> BcbCounters<'a> { } #[inline] - fn bcb_is_dominated_by(&self, node: BasicCoverageBlock, dom: BasicCoverageBlock) -> bool { - self.basic_coverage_blocks.is_dominated_by(node, dom) + fn bcb_dominates(&self, dom: BasicCoverageBlock, node: BasicCoverageBlock) -> bool { + self.basic_coverage_blocks.dominates(dom, node) } #[inline] diff --git a/compiler/rustc_mir_transform/src/coverage/debug.rs b/compiler/rustc_mir_transform/src/coverage/debug.rs index d6a298fad..22ea8710e 100644 --- a/compiler/rustc_mir_transform/src/coverage/debug.rs +++ b/compiler/rustc_mir_transform/src/coverage/debug.rs @@ -323,7 +323,10 @@ impl DebugCounters { String::new() }, self.format_operand(lhs), - if op == Op::Add { "+" } else { "-" }, + match op { + Op::Add => "+", + Op::Subtract => "-", + }, self.format_operand(rhs), ); } diff --git a/compiler/rustc_mir_transform/src/coverage/graph.rs b/compiler/rustc_mir_transform/src/coverage/graph.rs index 78d28f1eb..a2671eef2 100644 --- a/compiler/rustc_mir_transform/src/coverage/graph.rs +++ b/compiler/rustc_mir_transform/src/coverage/graph.rs @@ -209,8 +209,8 @@ impl CoverageGraph { } #[inline(always)] - pub fn is_dominated_by(&self, node: BasicCoverageBlock, dom: BasicCoverageBlock) -> bool { - self.dominators.as_ref().unwrap().is_dominated_by(node, dom) + pub fn dominates(&self, dom: BasicCoverageBlock, node: BasicCoverageBlock) -> bool { + self.dominators.as_ref().unwrap().dominates(dom, node) } #[inline(always)] @@ -312,7 +312,7 @@ rustc_index::newtype_index! { /// to the BCB's primary counter or expression). /// /// The BCB CFG is critical to simplifying the coverage analysis by ensuring graph path-based -/// queries (`is_dominated_by()`, `predecessors`, `successors`, etc.) have branch (control flow) +/// queries (`dominates()`, `predecessors`, `successors`, etc.) have branch (control flow) /// significance. #[derive(Debug, Clone)] pub(super) struct BasicCoverageBlockData { @@ -594,7 +594,7 @@ impl TraverseCoverageGraphWithLoops { // branching block would have given an `Expression` (or vice versa). 
                 let (some_successor_to_add, some_loop_header) =
                     if let Some((_, loop_header)) = context.loop_backedges {
-                        if basic_coverage_blocks.is_dominated_by(successor, loop_header) {
+                        if basic_coverage_blocks.dominates(loop_header, successor) {
                             (Some(successor), Some(loop_header))
                         } else {
                             (None, None)
                         }
@@ -666,15 +666,15 @@ pub(super) fn find_loop_backedges(
     //
     // The overall complexity appears to be comparable to many other MIR transform algorithms, and I
     // don't expect that this function is creating a performance hot spot, but if this becomes an
-    // issue, there may be ways to optimize the `is_dominated_by` algorithm (as indicated by an
+    // issue, there may be ways to optimize the `dominates` algorithm (as indicated by an
    // existing `FIXME` comment in that code), or possibly ways to optimize its usage here, perhaps
    // by keeping track of results for visited `BasicCoverageBlock`s if they can be used to short
-    // circuit downstream `is_dominated_by` checks.
+    // circuit downstream `dominates` checks.
     //
     // For now, that kind of optimization seems unnecessarily complicated.
     for (bcb, _) in basic_coverage_blocks.iter_enumerated() {
         for &successor in &basic_coverage_blocks.successors[bcb] {
-            if basic_coverage_blocks.is_dominated_by(bcb, successor) {
+            if basic_coverage_blocks.dominates(successor, bcb) {
                 let loop_header = successor;
                 let backedge_from_bcb = bcb;
                 debug!(
diff --git a/compiler/rustc_mir_transform/src/coverage/mod.rs b/compiler/rustc_mir_transform/src/coverage/mod.rs
index 1468afc64..9a6171598 100644
--- a/compiler/rustc_mir_transform/src/coverage/mod.rs
+++ b/compiler/rustc_mir_transform/src/coverage/mod.rs
@@ -540,7 +540,8 @@ fn fn_sig_and_body(
     // FIXME(#79625): Consider improving MIR to provide the information needed, to avoid going back
     // to HIR for it.
     let hir_node = tcx.hir().get_if_local(def_id).expect("expected DefId is local");
-    let fn_body_id = hir::map::associated_body(hir_node).expect("HIR node is a function with body");
+    let (_, fn_body_id) =
+        hir::map::associated_body(hir_node).expect("HIR node is a function with body");
     (hir_node.fn_sig(), tcx.hir().body(fn_body_id))
 }
diff --git a/compiler/rustc_mir_transform/src/coverage/spans.rs b/compiler/rustc_mir_transform/src/coverage/spans.rs
index c54348404..8ee316773 100644
--- a/compiler/rustc_mir_transform/src/coverage/spans.rs
+++ b/compiler/rustc_mir_transform/src/coverage/spans.rs
@@ -63,7 +63,7 @@ impl CoverageStatement {
 /// Note: A `CoverageStatement` merged into another CoverageSpan may come from a `BasicBlock` that
 /// is not part of the `CoverageSpan` bcb if the statement was included because its `Span` matches
 /// or is subsumed by the `Span` associated with this `CoverageSpan`, and its `BasicBlock`
-/// `is_dominated_by()` the `BasicBlock`s in this `CoverageSpan`.
+/// `dominates()` the `BasicBlock`s in this `CoverageSpan`.
 #[derive(Debug, Clone)]
 pub(super) struct CoverageSpan {
     pub span: Span,
@@ -407,7 +407,7 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
         if self.prev().is_macro_expansion() && self.curr().is_macro_expansion() {
             // Macros that expand to include branching (such as
             // `assert_eq!()`, `assert_ne!()`, `info!()`, `debug!()`, or
-            // `trace!()) typically generate callee spans with identical
+            // `trace!()`) typically generate callee spans with identical
             // ranges (typically the full span of the macro) for all
            // `BasicBlocks`. This makes it impossible to distinguish
            // the condition (`if val1 != val2`) from the optional
@@ -694,7 +694,7 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
    /// `prev.span.hi()` will be greater than (further right of) `prev_original_span.hi()`.
    /// If prev.span() was split off to the right of a closure, prev.span().lo() will be
    /// greater than prev_original_span.lo(). The actual span of `prev_original_span` is
-    /// not as important as knowing that `prev()` **used to have the same span** as `curr(),
+    /// not as important as knowing that `prev()` **used to have the same span** as `curr()`,
    /// which means their sort order is still meaningful for determining the dominator
    /// relationship.
    ///
@@ -705,12 +705,12 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
     fn hold_pending_dups_unless_dominated(&mut self) {
         // Equal coverage spans are ordered by dominators before dominated (if any), so it should be
         // impossible for `curr` to dominate any previous `CoverageSpan`.
-        debug_assert!(!self.span_bcb_is_dominated_by(self.prev(), self.curr()));
+        debug_assert!(!self.span_bcb_dominates(self.curr(), self.prev()));
 
         let initial_pending_count = self.pending_dups.len();
         if initial_pending_count > 0 {
             let mut pending_dups = self.pending_dups.split_off(0);
-            pending_dups.retain(|dup| !self.span_bcb_is_dominated_by(self.curr(), dup));
+            pending_dups.retain(|dup| !self.span_bcb_dominates(dup, self.curr()));
             self.pending_dups.append(&mut pending_dups);
             if self.pending_dups.len() < initial_pending_count {
                 debug!(
@@ -721,7 +721,7 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
             }
         }
 
-        if self.span_bcb_is_dominated_by(self.curr(), self.prev()) {
+        if self.span_bcb_dominates(self.prev(), self.curr()) {
             debug!(
                 " different bcbs but SAME spans, and prev dominates curr. Discard prev={:?}",
                 self.prev()
@@ -787,8 +787,8 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
         }
     }
 
-    fn span_bcb_is_dominated_by(&self, covspan: &CoverageSpan, dom_covspan: &CoverageSpan) -> bool {
-        self.basic_coverage_blocks.is_dominated_by(covspan.bcb, dom_covspan.bcb)
+    fn span_bcb_dominates(&self, dom_covspan: &CoverageSpan, covspan: &CoverageSpan) -> bool {
+        self.basic_coverage_blocks.dominates(dom_covspan.bcb, covspan.bcb)
     }
 }
 
@@ -802,6 +802,8 @@ pub(super) fn filtered_statement_span(statement: &Statement<'_>) -> Option
     | StatementKind::StorageDead(_)
     // Coverage should not be encountered, but don't inject coverage
     | StatementKind::Coverage(_)
+    // Ignore `ConstEvalCounter`s
+    | StatementKind::ConstEvalCounter
     // Ignore `Nop`s
     | StatementKind::Nop => None,
diff --git a/compiler/rustc_mir_transform/src/ctfe_limit.rs b/compiler/rustc_mir_transform/src/ctfe_limit.rs
new file mode 100644
index 000000000..1b3ac78fb
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/ctfe_limit.rs
@@ -0,0 +1,58 @@
+//! A pass that inserts the `ConstEvalCounter` instruction into any blocks that have a back edge
+//! (thus indicating there is a loop in the CFG), or whose terminator is a function call.
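[Editor's aside on the module doc above: a hypothetical const fn — not from the patch — whose loop produces the kind of CFG back edge this pass instruments; each inserted `ConstEvalCounter` ticks the interpreter's step limit during const evaluation:]

    // Hypothetical example: the `while` loop below compiles to a block whose
    // terminator jumps back to the loop head -- a back edge, so CtfeLimit adds
    // a ConstEvalCounter statement there (and to blocks ending in calls).
    const fn triangular(n: u32) -> u32 {
        let mut total = 0;
        let mut i = 0;
        while i <= n {
            total += i; // counted once per iteration during CTFE
            i += 1;
        }
        total
    }

    const SUM: u32 = triangular(10); // evaluated by the CTFE interpreter

    fn main() {
        assert_eq!(SUM, 55);
    }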
+use crate::MirPass;
+
+use rustc_data_structures::graph::dominators::Dominators;
+use rustc_middle::mir::{
+    BasicBlock, BasicBlockData, Body, Statement, StatementKind, TerminatorKind,
+};
+use rustc_middle::ty::TyCtxt;
+
+pub struct CtfeLimit;
+
+impl<'tcx> MirPass<'tcx> for CtfeLimit {
+    #[instrument(skip(self, _tcx, body))]
+    fn run_pass(&self, _tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        let doms = body.basic_blocks.dominators();
+        let indices: Vec<BasicBlock> = body
+            .basic_blocks
+            .iter_enumerated()
+            .filter_map(|(node, node_data)| {
+                if matches!(node_data.terminator().kind, TerminatorKind::Call { .. })
+                    // Back edges in a CFG indicate loops
+                    || has_back_edge(&doms, node, &node_data)
+                {
+                    Some(node)
+                } else {
+                    None
+                }
+            })
+            .collect();
+        for index in indices {
+            insert_counter(
+                body.basic_blocks_mut()
+                    .get_mut(index)
+                    .expect("basic_blocks index {index} should exist"),
+            );
+        }
+    }
+}
+
+fn has_back_edge(
+    doms: &Dominators<BasicBlock>,
+    node: BasicBlock,
+    node_data: &BasicBlockData<'_>,
+) -> bool {
+    if !doms.is_reachable(node) {
+        return false;
+    }
+    // Check if any of the dominators of the node are also the node's successor.
+    doms.dominators(node).any(|dom| node_data.terminator().successors().any(|succ| succ == dom))
+}
+
+fn insert_counter(basic_block_data: &mut BasicBlockData<'_>) {
+    basic_block_data.statements.push(Statement {
+        source_info: basic_block_data.terminator().source_info,
+        kind: StatementKind::ConstEvalCounter,
+    });
+}
diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
index c75fe2327..49ded10ba 100644
--- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
+++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
@@ -5,6 +5,7 @@
 use rustc_const_eval::const_eval::CheckAlignment;
 use rustc_const_eval::interpret::{ConstValue, ImmTy, Immediate, InterpCx, Scalar};
 use rustc_data_structures::fx::FxHashMap;
+use rustc_hir::def::DefKind;
 use rustc_middle::mir::visit::{MutVisitor, Visitor};
 use rustc_middle::mir::*;
 use rustc_middle::ty::{self, Ty, TyCtxt};
@@ -12,6 +13,7 @@ use rustc_mir_dataflow::value_analysis::{Map, State, TrackElem, ValueAnalysis, V
 use rustc_mir_dataflow::{lattice::FlatSet, Analysis, ResultsVisitor, SwitchIntEdgeEffects};
 use rustc_span::DUMMY_SP;
 use rustc_target::abi::Align;
+use rustc_target::abi::VariantIdx;
 
 use crate::MirPass;
 
@@ -29,14 +31,12 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
     #[instrument(skip_all level = "debug")]
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+        debug!(def_id = ?body.source.def_id());
         if tcx.sess.mir_opt_level() < 4 && body.basic_blocks.len() > BLOCK_LIMIT {
             debug!("aborted dataflow const prop due to too many basic blocks");
             return;
         }
 
-        // Decide which places to track during the analysis.
-        let map = Map::from_filter(tcx, body, Ty::is_scalar);
-
         // We want to have a somewhat linear runtime w.r.t. the number of statements/terminators.
         // Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function
         // applications, where `h` is the height of the lattice. Because the height of our lattice
         // scales with the number of tracked places and map nodes, this is
         // `O(num_nodes * tracked_places * n)` in terms of time complexity. Since the number of
         // map nodes is strongly correlated to the number of tracked places, this becomes more or
         // less `O(n)` if we place a constant limit on the number of tracked places.
-        if tcx.sess.mir_opt_level() < 4 && map.tracked_places() > PLACE_LIMIT {
-            debug!("aborted dataflow const prop due to too many tracked places");
-            return;
-        }
+        let place_limit = if tcx.sess.mir_opt_level() < 4 { Some(PLACE_LIMIT) } else { None };
+
+        // Decide which places to track during the analysis.
+        let map = Map::from_filter(tcx, body, Ty::is_scalar, place_limit);
 
         // Perform the actual dataflow analysis.
         let analysis = ConstAnalysis::new(tcx, body, map);
@@ -62,14 +62,31 @@
     }
 }
 
-struct ConstAnalysis<'tcx> {
+struct ConstAnalysis<'a, 'tcx> {
     map: Map,
     tcx: TyCtxt<'tcx>,
+    local_decls: &'a LocalDecls<'tcx>,
     ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
     param_env: ty::ParamEnv<'tcx>,
 }
 
-impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
+impl<'tcx> ConstAnalysis<'_, 'tcx> {
+    fn eval_discriminant(
+        &self,
+        enum_ty: Ty<'tcx>,
+        variant_index: VariantIdx,
+    ) -> Option<ScalarTy<'tcx>> {
+        if !enum_ty.is_enum() {
+            return None;
+        }
+        let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
+        let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
+        let discr_value = Scalar::try_from_uint(discr.val, discr_layout.size)?;
+        Some(ScalarTy(discr_value, discr.ty))
+    }
+}
+
+impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> {
     type Value = FlatSet<ScalarTy<'tcx>>;
 
     const NAME: &'static str = "ConstAnalysis";
@@ -78,6 +95,25 @@
         &self.map
     }
 
+    fn handle_statement(&self, statement: &Statement<'tcx>, state: &mut State<Self::Value>) {
+        match statement.kind {
+            StatementKind::SetDiscriminant { box ref place, variant_index } => {
+                state.flood_discr(place.as_ref(), &self.map);
+                if self.map.find_discr(place.as_ref()).is_some() {
+                    let enum_ty = place.ty(self.local_decls, self.tcx).ty;
+                    if let Some(discr) = self.eval_discriminant(enum_ty, variant_index) {
+                        state.assign_discr(
+                            place.as_ref(),
+                            ValueOrPlace::Value(FlatSet::Elem(discr)),
+                            &self.map,
+                        );
+                    }
+                }
+            }
+            _ => self.super_statement(statement, state),
+        }
+    }
+
     fn handle_assign(
         &self,
         target: Place<'tcx>,
@@ -85,13 +121,59 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
         state: &mut State<Self::Value>,
     ) {
         match rvalue {
+            Rvalue::Aggregate(kind, operands) => {
+                // If we assign `target = Enum::Variant#0(operand)`,
+                // we must make sure that all `target as Variant#i` are `Top`.
+                state.flood(target.as_ref(), self.map());
+
+                if let Some(target_idx) = self.map().find(target.as_ref()) {
+                    let (variant_target, variant_index) = match **kind {
+                        AggregateKind::Tuple | AggregateKind::Closure(..) => {
+                            (Some(target_idx), None)
+                        }
+                        AggregateKind::Adt(def_id, variant_index, ..) => {
+                            match self.tcx.def_kind(def_id) {
+                                DefKind::Struct => (Some(target_idx), None),
+                                DefKind::Enum => (
+                                    self.map.apply(target_idx, TrackElem::Variant(variant_index)),
+                                    Some(variant_index),
+                                ),
+                                _ => (None, None),
+                            }
+                        }
+                        _ => (None, None),
+                    };
+                    if let Some(variant_target_idx) = variant_target {
+                        for (field_index, operand) in operands.iter().enumerate() {
+                            if let Some(field) = self.map().apply(
+                                variant_target_idx,
+                                TrackElem::Field(Field::from_usize(field_index)),
+                            ) {
+                                let result = self.handle_operand(operand, state);
+                                state.insert_idx(field, result, self.map());
+                            }
+                        }
+                    }
+                    if let Some(variant_index) = variant_index
+                        && let Some(discr_idx) = self.map().apply(target_idx, TrackElem::Discriminant)
+                    {
+                        // We are assigning the discriminant as part of an aggregate.
+ // This discriminant can only alias a variant field's value if the operand + // had an invalid value for that type. + // Using invalid values is UB, so we are allowed to perform the assignment + // without extra flooding. + let enum_ty = target.ty(self.local_decls, self.tcx).ty; + if let Some(discr_val) = self.eval_discriminant(enum_ty, variant_index) { + state.insert_value_idx(discr_idx, FlatSet::Elem(discr_val), &self.map); + } + } + } + } Rvalue::CheckedBinaryOp(op, box (left, right)) => { + // Flood everything now, so we can use `insert_value_idx` directly later. + state.flood(target.as_ref(), self.map()); + let target = self.map().find(target.as_ref()); - if let Some(target) = target { - // We should not track any projections other than - // what is overwritten below, but just in case... - state.flood_idx(target, self.map()); - } let value_target = target .and_then(|target| self.map().apply(target, TrackElem::Field(0_u32.into()))); @@ -102,26 +184,19 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> { let (val, overflow) = self.binary_op(state, *op, left, right); if let Some(value_target) = value_target { - state.assign_idx(value_target, ValueOrPlace::Value(val), self.map()); + // We have flooded `target` earlier. + state.insert_value_idx(value_target, val, self.map()); } if let Some(overflow_target) = overflow_target { let overflow = match overflow { FlatSet::Top => FlatSet::Top, FlatSet::Elem(overflow) => { - if overflow { - // Overflow cannot be reliably propagated. See: https://github.com/rust-lang/rust/pull/101168#issuecomment-1288091446 - FlatSet::Top - } else { - self.wrap_scalar(Scalar::from_bool(false), self.tcx.types.bool) - } + self.wrap_scalar(Scalar::from_bool(overflow), self.tcx.types.bool) } FlatSet::Bottom => FlatSet::Bottom, }; - state.assign_idx( - overflow_target, - ValueOrPlace::Value(overflow), - self.map(), - ); + // We have flooded `target` earlier. 
+ state.insert_value_idx(overflow_target, overflow, self.map()); } } } @@ -170,6 +245,9 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> { FlatSet::Bottom => ValueOrPlace::Value(FlatSet::Bottom), FlatSet::Top => ValueOrPlace::Value(FlatSet::Top), }, + Rvalue::Discriminant(place) => { + ValueOrPlace::Value(state.get_discr(place.as_ref(), self.map())) + } _ => self.super_rvalue(rvalue, state), } } @@ -243,12 +321,13 @@ impl<'tcx> std::fmt::Debug for ScalarTy<'tcx> { } } -impl<'tcx> ConstAnalysis<'tcx> { - pub fn new(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, map: Map) -> Self { +impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { + pub fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, map: Map) -> Self { let param_env = tcx.param_env(body.source.def_id()); Self { map, tcx, + local_decls: &body.local_decls, ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine), param_env: param_env, } @@ -441,6 +520,21 @@ impl<'tcx, 'map, 'a> Visitor<'tcx> for OperandCollector<'tcx, 'map, 'a> { _ => (), } } + + fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { + match rvalue { + Rvalue::Discriminant(place) => { + match self.state.get_discr(place.as_ref(), self.visitor.map) { + FlatSet::Top => (), + FlatSet::Elem(value) => { + self.visitor.before_effect.insert((location, *place), value); + } + FlatSet::Bottom => (), + } + } + _ => self.super_rvalue(rvalue, location), + } + } } struct DummyMachine; diff --git a/compiler/rustc_mir_transform/src/dead_store_elimination.rs b/compiler/rustc_mir_transform/src/dead_store_elimination.rs index 09546330c..9dbfb089d 100644 --- a/compiler/rustc_mir_transform/src/dead_store_elimination.rs +++ b/compiler/rustc_mir_transform/src/dead_store_elimination.rs @@ -53,6 +53,7 @@ pub fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, borrowed: &BitS | StatementKind::StorageDead(_) | StatementKind::Coverage(_) | StatementKind::Intrinsic(_) + | StatementKind::ConstEvalCounter | StatementKind::Nop => (), StatementKind::FakeRead(_) | StatementKind::AscribeUserType(_, _) => { diff --git a/compiler/rustc_mir_transform/src/deaggregator.rs b/compiler/rustc_mir_transform/src/deaggregator.rs deleted file mode 100644 index fe272de20..000000000 --- a/compiler/rustc_mir_transform/src/deaggregator.rs +++ /dev/null @@ -1,45 +0,0 @@ -use crate::util::expand_aggregate; -use crate::MirPass; -use rustc_middle::mir::*; -use rustc_middle::ty::TyCtxt; - -pub struct Deaggregator; - -impl<'tcx> MirPass<'tcx> for Deaggregator { - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let basic_blocks = body.basic_blocks.as_mut_preserves_cfg(); - for bb in basic_blocks { - bb.expand_statements(|stmt| { - // FIXME(eddyb) don't match twice on `stmt.kind` (post-NLL). - match stmt.kind { - // FIXME(#48193) Deaggregate arrays when it's cheaper to do so. 
- StatementKind::Assign(box ( - _, - Rvalue::Aggregate(box AggregateKind::Array(_), _), - )) => { - return None; - } - StatementKind::Assign(box (_, Rvalue::Aggregate(_, _))) => {} - _ => return None, - } - - let stmt = stmt.replace_nop(); - let source_info = stmt.source_info; - let StatementKind::Assign(box (lhs, Rvalue::Aggregate(kind, operands))) = stmt.kind else { - bug!(); - }; - - Some(expand_aggregate( - lhs, - operands.into_iter().map(|op| { - let ty = op.ty(&body.local_decls, tcx); - (op, ty) - }), - *kind, - source_info, - tcx, - )) - }); - } - } -} diff --git a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs index ddab7bbb2..89ca04a15 100644 --- a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs +++ b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs @@ -163,7 +163,7 @@ pub fn deduced_param_attrs<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> &'tcx [Ded // Codegen won't use this information for anything if all the function parameters are passed // directly. Detect that and bail, for compilation speed. - let fn_ty = tcx.type_of(def_id); + let fn_ty = tcx.type_of(def_id).subst_identity(); if matches!(fn_ty.kind(), ty::FnDef(..)) { if fn_ty .fn_sig(tcx) diff --git a/compiler/rustc_mir_transform/src/dest_prop.rs b/compiler/rustc_mir_transform/src/dest_prop.rs index 08e296a83..2e481b972 100644 --- a/compiler/rustc_mir_transform/src/dest_prop.rs +++ b/compiler/rustc_mir_transform/src/dest_prop.rs @@ -136,8 +136,8 @@ use rustc_index::bit_set::BitSet; use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; use rustc_middle::mir::{dump_mir, PassWhere}; use rustc_middle::mir::{ - traversal, BasicBlock, Body, InlineAsmOperand, Local, LocalKind, Location, Operand, Place, - Rvalue, Statement, StatementKind, TerminatorKind, + traversal, Body, InlineAsmOperand, Local, LocalKind, Location, Operand, Place, Rvalue, + Statement, StatementKind, TerminatorKind, }; use rustc_middle::ty::TyCtxt; use rustc_mir_dataflow::impls::MaybeLiveLocals; @@ -328,7 +328,8 @@ impl<'a, 'tcx> MutVisitor<'tcx> for Merger<'a, 'tcx> { match &statement.kind { StatementKind::Assign(box (dest, rvalue)) => { match rvalue { - Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) => { + Rvalue::CopyForDeref(place) + | Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) => { // These might've been turned into self-assignments by the replacement // (this includes the original statement we wanted to eliminate). if dest == place { @@ -467,7 +468,7 @@ impl<'a, 'body, 'alloc, 'tcx> FilterInformation<'a, 'body, 'alloc, 'tcx> { // to reuse the allocation. 
write_info: write_info_alloc, // Doesn't matter what we put here, will be overwritten before being used - at: Location { block: BasicBlock::from_u32(0), statement_index: 0 }, + at: Location::START, }; this.internal_filter_liveness(); } @@ -577,6 +578,7 @@ impl WriteInfo { self.add_place(**place); } StatementKind::Intrinsic(_) + | StatementKind::ConstEvalCounter | StatementKind::Nop | StatementKind::Coverage(_) | StatementKind::StorageLive(_) @@ -754,7 +756,7 @@ impl<'tcx> Visitor<'tcx> for FindAssignments<'_, '_, 'tcx> { fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) { if let StatementKind::Assign(box ( lhs, - Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)), + Rvalue::CopyForDeref(rhs) | Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)), )) = &statement.kind { let Some((src, dest)) = places_to_candidate_pair(*lhs, *rhs, self.body) else { diff --git a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs index 932134bd6..954bb5aff 100644 --- a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs +++ b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs @@ -17,9 +17,9 @@ pub fn build_ptr_tys<'tcx>( unique_did: DefId, nonnull_did: DefId, ) -> (Ty<'tcx>, Ty<'tcx>, Ty<'tcx>) { - let substs = tcx.intern_substs(&[pointee.into()]); - let unique_ty = tcx.bound_type_of(unique_did).subst(tcx, substs); - let nonnull_ty = tcx.bound_type_of(nonnull_did).subst(tcx, substs); + let substs = tcx.mk_substs(&[pointee.into()]); + let unique_ty = tcx.type_of(unique_did).subst(tcx, substs); + let nonnull_ty = tcx.type_of(nonnull_did).subst(tcx, substs); let ptr_ty = tcx.mk_imm_ptr(pointee); (unique_ty, nonnull_ty, ptr_ty) @@ -93,7 +93,7 @@ impl<'tcx> MirPass<'tcx> for ElaborateBoxDerefs { if let Some(def_id) = tcx.lang_items().owned_box() { let unique_did = tcx.adt_def(def_id).non_enum_variant().fields[0].did; - let Some(nonnull_def) = tcx.type_of(unique_did).ty_adt_def() else { + let Some(nonnull_def) = tcx.type_of(unique_did).subst_identity().ty_adt_def() else { span_bug!(tcx.def_span(unique_did), "expected Box to contain Unique") }; @@ -138,7 +138,7 @@ impl<'tcx> MirPass<'tcx> for ElaborateBoxDerefs { if let Some(mut new_projections) = new_projections { new_projections.extend_from_slice(&place.projection[last_deref..]); - place.projection = tcx.intern_place_elems(&new_projections); + place.projection = tcx.mk_place_elems(&new_projections); } } } diff --git a/compiler/rustc_mir_transform/src/elaborate_drops.rs b/compiler/rustc_mir_transform/src/elaborate_drops.rs index 65f4956d2..bdfd8dc6e 100644 --- a/compiler/rustc_mir_transform/src/elaborate_drops.rs +++ b/compiler/rustc_mir_transform/src/elaborate_drops.rs @@ -18,6 +18,35 @@ use rustc_span::Span; use rustc_target::abi::VariantIdx; use std::fmt; +/// During MIR building, Drop and DropAndReplace terminators are inserted in every place where a drop may occur. +/// However, in this phase, the presence of these terminators does not guarantee that a destructor will run, +/// as the target of the drop may be uninitialized. +/// In general, the compiler cannot determine at compile time whether a destructor will run or not. +/// +/// At a high level, this pass refines Drop and DropAndReplace to only run the destructor if the +/// target is initialized. The way this is achieved is by inserting drop flags for every variable +/// that may be dropped, and then using those flags to determine whether a destructor should run.
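In source-level terms, the drop-flag scheme described above behaves roughly like the following hand-written sketch; this is only an illustration of the idea (the pass itself rewrites MIR and never produces Rust source), and every name in it is invented here:

use std::mem::MaybeUninit;

fn conceptual_drop_elaboration(cond: bool) {
    let mut storage: MaybeUninit<String> = MaybeUninit::uninit();
    let mut drop_flag = false; // the inserted drop flag

    if cond {
        storage.write(String::from("initialized on this path only"));
        drop_flag = true;
    }

    // What the refined Drop amounts to: run the destructor only when the
    // flag proves the place was initialized.
    if drop_flag {
        // SAFETY: drop_flag is set exactly when `storage` was written.
        unsafe { storage.assume_init_drop() };
    }
}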
+/// This pass also removes DropAndReplace, replacing it with a Drop paired with an assign statement. +/// Once this is complete, Drop terminators in the MIR correspond to a call to the "drop glue" or +/// "drop shim" for the type of the dropped place. +/// +/// This pass relies on dropped places having an associated move path, which is then used to determine +/// the initialization status of the place and its descendants. +/// It's worth noting that a MIR containing a Drop without an associated move path is probably ill-formed, +/// as it would allow running a destructor on a place behind a reference: +/// +/// ```text +// fn drop_term<T>(t: &mut T) { +// mir!( +// { +// Drop(*t, exit) +// } +// exit = { +// Return() +// } +// ) +// } +/// ``` pub struct ElaborateDrops; impl<'tcx> MirPass<'tcx> for ElaborateDrops { @@ -38,13 +67,11 @@ impl<'tcx> MirPass<'tcx> for ElaborateDrops { }; let un_derefer = UnDerefer { tcx: tcx, derefer_sidetable: side_table }; let elaborate_patch = { - let body = &*body; let env = MoveDataParamEnv { move_data, param_env }; - let dead_unwinds = find_dead_unwinds(tcx, body, &env, &un_derefer); + remove_dead_unwinds(tcx, body, &env, &un_derefer); let inits = MaybeInitializedPlaces::new(tcx, body, &env) .into_engine(tcx, body) - .dead_unwinds(&dead_unwinds) .pass_name("elaborate_drops") .iterate_to_fixpoint() .into_results_cursor(body); @@ -52,11 +79,12 @@ impl<'tcx> MirPass<'tcx> for ElaborateDrops { let uninits = MaybeUninitializedPlaces::new(tcx, body, &env) .mark_inactive_variants_as_uninit() .into_engine(tcx, body) - .dead_unwinds(&dead_unwinds) .pass_name("elaborate_drops") .iterate_to_fixpoint() .into_results_cursor(body); + let reachable = traversal::reachable_as_bitset(body); + ElaborateDropsCtxt { tcx, body, @@ -65,6 +93,7 @@ impl<'tcx> MirPass<'tcx> for ElaborateDrops { drop_flags: Default::default(), patch: MirPatch::new(body), un_derefer: un_derefer, + reachable, } .elaborate() }; @@ -73,22 +102,21 @@ impl<'tcx> MirPass<'tcx> for ElaborateDrops { } } -/// Returns the set of basic blocks whose unwind edges are known -/// to not be reachable, because they are `drop` terminators +/// Removes unwind edges which are known to be unreachable, because they are in `drop` terminators /// that can't drop anything. -fn find_dead_unwinds<'tcx>( +fn remove_dead_unwinds<'tcx>( tcx: TyCtxt<'tcx>, - body: &Body<'tcx>, + body: &mut Body<'tcx>, env: &MoveDataParamEnv<'tcx>, und: &UnDerefer<'tcx>, -) -> BitSet<BasicBlock> { - debug!("find_dead_unwinds({:?})", body.span); +) { + debug!("remove_dead_unwinds({:?})", body.span); // We only need to do this pass once, because unwind edges can only // reach cleanup blocks, which can't have unwind edges themselves.
- let mut dead_unwinds = BitSet::new_empty(body.basic_blocks.len()); + let mut dead_unwinds = Vec::new(); let mut flow_inits = MaybeInitializedPlaces::new(tcx, body, &env) .into_engine(tcx, body) - .pass_name("find_dead_unwinds") + .pass_name("remove_dead_unwinds") .iterate_to_fixpoint() .into_results_cursor(body); for (bb, bb_data) in body.basic_blocks.iter_enumerated() { @@ -100,16 +128,16 @@ fn find_dead_unwinds<'tcx>( _ => continue, }; - debug!("find_dead_unwinds @ {:?}: {:?}", bb, bb_data); + debug!("remove_dead_unwinds @ {:?}: {:?}", bb, bb_data); let LookupResult::Exact(path) = env.move_data.rev_lookup.find(place.as_ref()) else { - debug!("find_dead_unwinds: has parent; skipping"); + debug!("remove_dead_unwinds: has parent; skipping"); continue; }; flow_inits.seek_before_primary_effect(body.terminator_loc(bb)); debug!( - "find_dead_unwinds @ {:?}: path({:?})={:?}; init_data={:?}", + "remove_dead_unwinds @ {:?}: path({:?})={:?}; init_data={:?}", bb, place, path, @@ -121,13 +149,22 @@ fn find_dead_unwinds<'tcx>( maybe_live |= flow_inits.contains(child); }); - debug!("find_dead_unwinds @ {:?}: maybe_live={}", bb, maybe_live); + debug!("remove_dead_unwinds @ {:?}: maybe_live={}", bb, maybe_live); if !maybe_live { - dead_unwinds.insert(bb); + dead_unwinds.push(bb); } } - dead_unwinds + if dead_unwinds.is_empty() { + return; + } + + let basic_blocks = body.basic_blocks.as_mut(); + for &bb in dead_unwinds.iter() { + if let Some(unwind) = basic_blocks[bb].terminator_mut().unwind_mut() { + *unwind = None; + } + } } struct InitializationData<'mir, 'tcx> { @@ -261,6 +298,7 @@ struct ElaborateDropsCtxt<'a, 'tcx> { drop_flags: FxHashMap<MovePathIndex, Local>, patch: MirPatch<'tcx>, un_derefer: UnDerefer<'tcx>, + reachable: BitSet<BasicBlock>, } impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { @@ -300,6 +338,9 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { fn collect_drop_flags(&mut self) { for (bb, data) in self.body.basic_blocks.iter_enumerated() { + if !self.reachable.contains(bb) { + continue; + } let terminator = data.terminator(); let place = match terminator.kind { TerminatorKind::Drop { ref place, .. } @@ -355,6 +396,9 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { fn elaborate_drops(&mut self) { for (bb, data) in self.body.basic_blocks.iter_enumerated() { + if !self.reachable.contains(bb) { + continue; + } let loc = Location { block: bb, statement_index: data.statements.len() }; let terminator = data.terminator(); @@ -512,6 +556,9 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { fn drop_flags_for_fn_rets(&mut self) { for (bb, data) in self.body.basic_blocks.iter_enumerated() { + if !self.reachable.contains(bb) { + continue; + } if let TerminatorKind::Call { destination, target: Some(tgt), cleanup: Some(_), .. } = data.terminator().kind @@ -547,6 +594,9 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> { // clobbered before they are read.
for (bb, data) in self.body.basic_blocks.iter_enumerated() { + if !self.reachable.contains(bb) { + continue; + } debug!("drop_flags_for_locs({:?})", data); for i in 0..(data.statements.len() + 1) { debug!("drop_flag_for_locs: stmt {}", i); diff --git a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs index 1244c1802..e6546911a 100644 --- a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs +++ b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs @@ -49,7 +49,7 @@ fn has_ffi_unwind_calls(tcx: TyCtxt<'_>, local_def_id: LocalDefId) -> bool { let body = &*tcx.mir_built(ty::WithOptConstParam::unknown(local_def_id)).borrow(); - let body_ty = tcx.type_of(def_id); + let body_ty = tcx.type_of(def_id).skip_binder(); let body_abi = match body_ty.kind() { ty::FnDef(..) => body_ty.fn_sig(tcx).abi(), ty::Closure(..) => Abi::RustCall, diff --git a/compiler/rustc_mir_transform/src/function_item_references.rs b/compiler/rustc_mir_transform/src/function_item_references.rs index b708d780b..66d32b954 100644 --- a/compiler/rustc_mir_transform/src/function_item_references.rs +++ b/compiler/rustc_mir_transform/src/function_item_references.rs @@ -79,7 +79,7 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { for bound in bounds { if let Some(bound_ty) = self.is_pointer_trait(&bound.kind().skip_binder()) { // Get the argument types as they appear in the function signature. - let arg_defs = self.tcx.fn_sig(def_id).skip_binder().inputs(); + let arg_defs = self.tcx.fn_sig(def_id).subst_identity().skip_binder().inputs(); for (arg_num, arg_def) in arg_defs.iter().enumerate() { // For all types reachable from the argument type in the fn sig for generic_inner_ty in arg_def.walk() { @@ -111,11 +111,9 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { /// If the given predicate is the trait `fmt::Pointer`, returns the bound parameter type. 
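The hunk below folds an `if`/`else` that returned an `Option` into `bool::then`, which evaluates its closure only when the receiver is `true`. A standalone sketch of the equivalence (the function name and values are illustrative):

fn self_ty_if_pointer(is_pointer: bool) -> Option<u32> {
    // Before: if is_pointer { Some(42) } else { None }
    // After:
    is_pointer.then(|| 42)
}

fn main() {
    assert_eq!(self_ty_if_pointer(true), Some(42));
    assert_eq!(self_ty_if_pointer(false), None);
}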
fn is_pointer_trait(&self, bound: &PredicateKind<'tcx>) -> Option<Ty<'tcx>> { if let ty::PredicateKind::Clause(ty::Clause::Trait(predicate)) = bound { - if self.tcx.is_diagnostic_item(sym::Pointer, predicate.def_id()) { - Some(predicate.trait_ref.self_ty()) - } else { - None - } + self.tcx + .is_diagnostic_item(sym::Pointer, predicate.def_id()) + .then(|| predicate.trait_ref.self_ty()) } else { None } @@ -161,7 +159,8 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> { .as_ref() .assert_crate_local() .lint_root; - let fn_sig = self.tcx.fn_sig(fn_id); + // FIXME: use existing printing routines to print the function signature + let fn_sig = self.tcx.fn_sig(fn_id).subst(self.tcx, fn_substs); let unsafety = fn_sig.unsafety().prefix_str(); let abi = match fn_sig.abi() { Abi::Rust => String::from(""), diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs index 39c61a34a..2e97312ee 100644 --- a/compiler/rustc_mir_transform/src/generator.rs +++ b/compiler/rustc_mir_transform/src/generator.rs @@ -52,9 +52,9 @@ use crate::deref_separator::deref_finder; use crate::simplify; -use crate::util::expand_aggregate; use crate::MirPass; -use rustc_data_structures::fx::FxHashMap; +use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use rustc_errors::pluralize; use rustc_hir as hir; use rustc_hir::lang_items::LangItem; use rustc_hir::GeneratorKind; @@ -70,6 +70,9 @@ use rustc_mir_dataflow::impls::{ }; use rustc_mir_dataflow::storage::always_storage_live_locals; use rustc_mir_dataflow::{self, Analysis}; +use rustc_span::def_id::DefId; +use rustc_span::symbol::sym; +use rustc_span::Span; use rustc_target::abi::VariantIdx; use rustc_target::spec::PanicStrategy; use std::{iter, ops}; @@ -123,7 +126,7 @@ impl<'tcx> MutVisitor<'tcx> for DerefArgVisitor<'tcx> { place, Place { local: SELF_ARG, - projection: self.tcx().intern_place_elems(&[ProjectionElem::Deref]), + projection: self.tcx().mk_place_elems(&[ProjectionElem::Deref]), }, self.tcx, ); @@ -159,10 +162,9 @@ impl<'tcx> MutVisitor<'tcx> for PinArgVisitor<'tcx> { place, Place { local: SELF_ARG, - projection: self.tcx().intern_place_elems(&[ProjectionElem::Field( - Field::new(0), - self.ref_gen_ty, - )]), + projection: self + .tcx() + .mk_place_elems(&[ProjectionElem::Field(Field::new(0), self.ref_gen_ty)]), }, self.tcx, ); @@ -184,7 +186,7 @@ fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtx let mut new_projection = new_base.projection.to_vec(); new_projection.append(&mut place.projection.to_vec()); - place.projection = tcx.intern_place_elems(&new_projection); + place.projection = tcx.mk_place_elems(&new_projection); } const SELF_ARG: Local = Local::from_u32(1); @@ -268,31 +270,26 @@ impl<'tcx> TransformVisitor<'tcx> { assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); // FIXME(swatinem): assert that `val` is indeed unit?
- statements.extend(expand_aggregate( - Place::return_place(), - std::iter::empty(), - kind, + statements.push(Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate(Box::new(kind), vec![]), + ))), source_info, - self.tcx, - )); + }); return; } // else: `Poll::Ready(x)`, `GeneratorState::Yielded(x)` or `GeneratorState::Complete(x)` assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1); - let ty = self - .tcx - .bound_type_of(self.state_adt_ref.variant(idx).fields[0].did) - .subst(self.tcx, self.state_substs); - - statements.extend(expand_aggregate( - Place::return_place(), - std::iter::once((val, ty)), - kind, + statements.push(Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate(Box::new(kind), vec![val]), + ))), source_info, - self.tcx, - )); + }); } // Create a Place referencing a generator struct field @@ -302,7 +299,7 @@ impl<'tcx> TransformVisitor<'tcx> { let mut projection = base.projection.to_vec(); projection.push(ProjectionElem::Field(Field::new(idx), ty)); - Place { local: base.local, projection: self.tcx.intern_place_elems(&projection) } + Place { local: base.local, projection: self.tcx.mk_place_elems(&projection) } } // Create a statement which changes the discriminant @@ -429,7 +426,7 @@ fn make_generator_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body let pin_did = tcx.require_lang_item(LangItem::Pin, Some(body.span)); let pin_adt_ref = tcx.adt_def(pin_did); - let substs = tcx.intern_substs(&[ref_gen_ty.into()]); + let substs = tcx.mk_substs(&[ref_gen_ty.into()]); let pin_ref_gen_ty = tcx.mk_adt(pin_adt_ref, substs); // Replace the by ref generator argument @@ -489,7 +486,7 @@ fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None); - for bb in BasicBlock::new(0)..body.basic_blocks.next_index() { + for bb in START_BLOCK..body.basic_blocks.next_index() { let bb_data = &body[bb]; if bb_data.is_cleanup { continue; } @@ -854,7 +851,7 @@ fn sanitize_witness<'tcx>( body: &Body<'tcx>, witness: Ty<'tcx>, upvars: Vec<Ty<'tcx>>, - saved_locals: &GeneratorSavedLocals, + layout: &GeneratorLayout<'tcx>, ) { let did = body.source.def_id(); let param_env = tcx.param_env(did); @@ -873,31 +870,36 @@ fn sanitize_witness<'tcx>( } }; - for (local, decl) in body.local_decls.iter_enumerated() { - // Ignore locals which are internal or not saved between yields.
- if !saved_locals.contains(local) || decl.internal { + let mut mismatches = Vec::new(); + for fty in &layout.field_tys { + if fty.ignore_for_traits { continue; } - let decl_ty = tcx.normalize_erasing_regions(param_env, decl.ty); + let decl_ty = tcx.normalize_erasing_regions(param_env, fty.ty); // Sanity check that typeck knows about the type of locals which are // live across a suspension point if !allowed.contains(&decl_ty) && !allowed_upvars.contains(&decl_ty) { - span_bug!( - body.span, - "Broken MIR: generator contains type {} in MIR, \ - but typeck only knows about {} and {:?}", - decl_ty, - allowed, - allowed_upvars - ); + mismatches.push(decl_ty); } } + + if !mismatches.is_empty() { + span_bug!( + body.span, + "Broken MIR: generator contains type {:?} in MIR, \ + but typeck only knows about {} and {:?}", + mismatches, + allowed, + allowed_upvars + ); + } } fn compute_layout<'tcx>( + tcx: TyCtxt<'tcx>, liveness: LivenessInfo, - body: &mut Body<'tcx>, + body: &Body<'tcx>, ) -> ( FxHashMap<Local, (Ty<'tcx>, VariantIdx, usize)>, GeneratorLayout<'tcx>, @@ -915,9 +917,33 @@ fn compute_layout<'tcx>( let mut locals = IndexVec::<GeneratorSavedLocal, _>::new(); let mut tys = IndexVec::<GeneratorSavedLocal, _>::new(); for (saved_local, local) in saved_locals.iter_enumerated() { - locals.push(local); - tys.push(body.local_decls[local].ty); debug!("generator saved local {:?} => {:?}", saved_local, local); + + locals.push(local); + let decl = &body.local_decls[local]; + debug!(?decl); + + let ignore_for_traits = if tcx.sess.opts.unstable_opts.drop_tracking_mir { + match decl.local_info { + // Do not include raw pointers created from accessing `static` items, as those could + // well be re-created by another access to the same static. + Some(box LocalInfo::StaticRef { is_thread_local, .. }) => !is_thread_local, + // Fake borrows are only read by fake reads, so do not have any reality in + // post-analysis MIR. + Some(box LocalInfo::FakeBorrow) => true, + _ => false, + } + } else { + // FIXME(#105084) HIR-based drop tracking does not account for all the temporaries that + // MIR building may introduce. This leads to wrongly ignored types, but this is + // necessary for internal consistency and to avoid ICEs. + decl.internal + }; + let decl = + GeneratorSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits }; + debug!(?decl); + + tys.push(decl); } // Leave empty variants for the UNRESUMED, RETURNED, and POISONED states. @@ -947,7 +973,7 @@ fn compute_layout<'tcx>( // just use the first one here. That's fine; fields do not move // around inside generators, so it doesn't matter which variant // index we access them by.
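As intuition for the layout being computed in this hunk: only values that are live across a suspension point become fields of the generator (or future) state; everything else stays an ordinary MIR local. A small illustrative async function, independent of anything in this patch:

async fn example() -> usize {
    let scratch = 21 * 2; // dead before the await: never saved in the state
    let kept = vec![0u8; scratch]; // live across the await: becomes a saved field
    async {}.await;
    kept.len()
}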
- remap.entry(locals[saved_local]).or_insert((tys[saved_local], variant_index, idx)); + remap.entry(locals[saved_local]).or_insert((tys[saved_local].ty, variant_index, idx)); } variant_fields.push(fields); variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]); @@ -957,6 +983,7 @@ fn compute_layout<'tcx>( let layout = GeneratorLayout { field_tys: tys, variant_fields, variant_source_info, storage_conflicts }; + debug!(?layout); (remap, layout, storage_liveness) } @@ -1227,7 +1254,7 @@ fn create_generator_resume_function<'tcx>( use rustc_middle::mir::AssertKind::{ResumedAfterPanic, ResumedAfterReturn}; // Jump to the entry point on the unresumed - cases.insert(0, (UNRESUMED, BasicBlock::new(0))); + cases.insert(0, (UNRESUMED, START_BLOCK)); // Panic when resumed on the returned or poisoned state let generator_kind = body.generator_kind().unwrap(); @@ -1351,6 +1378,42 @@ fn create_cases<'tcx>( .collect() } +#[instrument(level = "debug", skip(tcx), ret)] +pub(crate) fn mir_generator_witnesses<'tcx>( + tcx: TyCtxt<'tcx>, + def_id: DefId, +) -> GeneratorLayout<'tcx> { + assert!(tcx.sess.opts.unstable_opts.drop_tracking_mir); + let def_id = def_id.expect_local(); + + let (body, _) = tcx.mir_promoted(ty::WithOptConstParam::unknown(def_id)); + let body = body.borrow(); + let body = &*body; + + // The first argument is the generator type passed by value + let gen_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; + + // Get the interior types and substs which typeck computed + let movable = match *gen_ty.kind() { + ty::Generator(_, _, movability) => movability == hir::Movability::Movable, + _ => span_bug!(body.span, "unexpected generator type {}", gen_ty), + }; + + // When first entering the generator, move the resume argument into its new local. 
+ let always_live_locals = always_storage_live_locals(&body); + + let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable); + + // Extract locals which are live across suspension point into `layout` + // `remap` gives a mapping from local indices onto generator struct indices + // `storage_liveness` tells us which locals have live storage at suspension points + let (_, generator_layout, _) = compute_layout(tcx, liveness_info, body); + + check_suspend_tys(tcx, &generator_layout, &body); + + generator_layout +} + impl<'tcx> MirPass<'tcx> for StateTransform { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let Some(yield_ty) = body.yield_ty() else { @@ -1363,14 +1426,14 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // The first argument is the generator type passed by value let gen_ty = body.local_decls.raw[1].ty; - // Get the interior types and substs which typeck computed - let (upvars, interior, discr_ty, movable) = match *gen_ty.kind() { + // Get the discriminant type and substs which typeck computed + let (discr_ty, upvars, interior, movable) = match *gen_ty.kind() { ty::Generator(_, substs, movability) => { let substs = substs.as_generator(); ( - substs.upvar_tys().collect(), - substs.witness(), substs.discr_ty(tcx), + substs.discr_ty(tcx), + substs.upvar_tys().collect::<Vec<_>>(), + substs.witness(), movability == hir::Movability::Movable, ) } @@ -1386,13 +1449,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // Compute Poll let poll_did = tcx.require_lang_item(LangItem::Poll, None); let poll_adt_ref = tcx.adt_def(poll_did); - let poll_substs = tcx.intern_substs(&[body.return_ty().into()]); + let poll_substs = tcx.mk_substs(&[body.return_ty().into()]); (poll_adt_ref, poll_substs) } else { // Compute GeneratorState let state_did = tcx.require_lang_item(LangItem::GeneratorState, None); let state_adt_ref = tcx.adt_def(state_did); - let state_substs = tcx.intern_substs(&[yield_ty.into(), body.return_ty().into()]); + let state_substs = tcx.mk_substs(&[yield_ty.into(), body.return_ty().into()]); (state_adt_ref, state_substs) }; let ret_ty = tcx.mk_adt(state_adt_ref, state_substs); @@ -1417,7 +1480,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // When first entering the generator, move the resume argument into its new local.
let source_info = SourceInfo::outermost(body.span); - let stmts = &mut body.basic_blocks_mut()[BasicBlock::new(0)].statements; + let stmts = &mut body.basic_blocks_mut()[START_BLOCK].statements; stmts.insert( 0, Statement { @@ -1434,8 +1497,6 @@ impl<'tcx> MirPass<'tcx> for StateTransform { let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable); - sanitize_witness(tcx, body, interior, upvars, &liveness_info.saved_locals); - if tcx.sess.opts.unstable_opts.validate_mir { let mut vis = EnsureGeneratorFieldAssignmentsNeverAlias { assigned_local: None, @@ -1449,7 +1510,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // Extract locals which are live across suspension point into `layout` // `remap` gives a mapping from local indices onto generator struct indices // `storage_liveness` tells us which locals have live storage at suspension points - let (remap, layout, storage_liveness) = compute_layout(liveness_info, body); + let (remap, layout, storage_liveness) = compute_layout(tcx, liveness_info, body); + + if tcx.sess.opts.unstable_opts.validate_mir + && !tcx.sess.opts.unstable_opts.drop_tracking_mir + { + sanitize_witness(tcx, body, interior, upvars, &layout); + } let can_return = can_return(tcx, body, tcx.param_env(body.source.def_id())); @@ -1583,6 +1650,7 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> { | StatementKind::AscribeUserType(..) | StatementKind::Coverage(..) | StatementKind::Intrinsic(..) + | StatementKind::ConstEvalCounter | StatementKind::Nop => {} } } @@ -1631,3 +1699,212 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> { } } } + +fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &GeneratorLayout<'tcx>, body: &Body<'tcx>) { + let mut linted_tys = FxHashSet::default(); + + // We want a user-facing param-env. + let param_env = tcx.param_env(body.source.def_id()); + + for (variant, yield_source_info) in + layout.variant_fields.iter().zip(&layout.variant_source_info) + { + debug!(?variant); + for &local in variant { + let decl = &layout.field_tys[local]; + debug!(?decl); + + if !decl.ignore_for_traits && linted_tys.insert(decl.ty) { + let Some(hir_id) = decl.source_info.scope.lint_root(&body.source_scopes) else { continue }; + + check_must_not_suspend_ty( + tcx, + decl.ty, + hir_id, + param_env, + SuspendCheckData { + source_span: decl.source_info.span, + yield_span: yield_source_info.span, + plural_len: 1, + ..Default::default() + }, + ); + } + } + } +} + +#[derive(Default)] +struct SuspendCheckData<'a> { + source_span: Span, + yield_span: Span, + descr_pre: &'a str, + descr_post: &'a str, + plural_len: usize, +} + +// Returns whether it emitted a diagnostic or not +// Note that this fn and the proceeding one are based on the code +// for creating must_use diagnostics +// +// Note that this technique was chosen over things like a `Suspend` marker trait +// as it is simpler and has precedent in the compiler +fn check_must_not_suspend_ty<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + hir_id: hir::HirId, + param_env: ty::ParamEnv<'tcx>, + data: SuspendCheckData<'_>, +) -> bool { + if ty.is_unit() { + return false; + } + + let plural_suffix = pluralize!(data.plural_len); + + debug!("Checking must_not_suspend for {}", ty); + + match *ty.kind() { + ty::Adt(..) 
if ty.is_box() => { + let boxed_ty = ty.boxed_ty(); + let descr_pre = &format!("{}boxed ", data.descr_pre); + check_must_not_suspend_ty( + tcx, + boxed_ty, + hir_id, + param_env, + SuspendCheckData { descr_pre, ..data }, + ) + } + ty::Adt(def, _) => check_must_not_suspend_def(tcx, def.did(), hir_id, data), + // FIXME: support adding the attribute to TAITs + ty::Alias(ty::Opaque, ty::AliasTy { def_id: def, .. }) => { + let mut has_emitted = false; + for &(predicate, _) in tcx.explicit_item_bounds(def) { + // We only look at the `DefId`, so it is safe to skip the binder here. + if let ty::PredicateKind::Clause(ty::Clause::Trait(ref poly_trait_predicate)) = + predicate.kind().skip_binder() + { + let def_id = poly_trait_predicate.trait_ref.def_id; + let descr_pre = &format!("{}implementer{} of ", data.descr_pre, plural_suffix); + if check_must_not_suspend_def( + tcx, + def_id, + hir_id, + SuspendCheckData { descr_pre, ..data }, + ) { + has_emitted = true; + break; + } + } + } + has_emitted + } + ty::Dynamic(binder, _, _) => { + let mut has_emitted = false; + for predicate in binder.iter() { + if let ty::ExistentialPredicate::Trait(ref trait_ref) = predicate.skip_binder() { + let def_id = trait_ref.def_id; + let descr_post = &format!(" trait object{}{}", plural_suffix, data.descr_post); + if check_must_not_suspend_def( + tcx, + def_id, + hir_id, + SuspendCheckData { descr_post, ..data }, + ) { + has_emitted = true; + break; + } + } + } + has_emitted + } + ty::Tuple(fields) => { + let mut has_emitted = false; + for (i, ty) in fields.iter().enumerate() { + let descr_post = &format!(" in tuple element {i}"); + if check_must_not_suspend_ty( + tcx, + ty, + hir_id, + param_env, + SuspendCheckData { descr_post, ..data }, + ) { + has_emitted = true; + } + } + has_emitted + } + ty::Array(ty, len) => { + let descr_pre = &format!("{}array{} of ", data.descr_pre, plural_suffix); + check_must_not_suspend_ty( + tcx, + ty, + hir_id, + param_env, + SuspendCheckData { + descr_pre, + plural_len: len.try_eval_target_usize(tcx, param_env).unwrap_or(0) as usize + 1, + ..data + }, + ) + } + // If drop tracking is enabled, we want to look through references, since the referent + // may not be considered live across the await point. + ty::Ref(_region, ty, _mutability) => { + let descr_pre = &format!("{}reference{} to ", data.descr_pre, plural_suffix); + check_must_not_suspend_ty( + tcx, + ty, + hir_id, + param_env, + SuspendCheckData { descr_pre, ..data }, + ) + } + _ => false, + } +} + +fn check_must_not_suspend_def( + tcx: TyCtxt<'_>, + def_id: DefId, + hir_id: hir::HirId, + data: SuspendCheckData<'_>, +) -> bool { + if let Some(attr) = tcx.get_attr(def_id, sym::must_not_suspend) { + let msg = format!( + "{}`{}`{} held across a suspend point, but should not be", + data.descr_pre, + tcx.def_path_str(def_id), + data.descr_post, + ); + tcx.struct_span_lint_hir( + rustc_session::lint::builtin::MUST_NOT_SUSPEND, + hir_id, + data.source_span, + msg, + |lint| { + // add span pointing to the offending yield/await + lint.span_label(data.yield_span, "the value is held across this suspend point"); + + // Add optional reason note + if let Some(note) = attr.value_str() { + // FIXME(guswynn): consider formatting this better + lint.span_note(data.source_span, note.as_str()); + } + + // Add some quick suggestions on what to do + // FIXME: can `drop` work as a suggestion here as well? + lint.span_help( + data.source_span, + "consider using a block (`{ ...
}`) \ + to shrink the value's scope, ending before the suspend point", + ) + }, + ); + + true + } else { + false + } +} diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs index 28c9080d3..6e6d6566f 100644 --- a/compiler/rustc_mir_transform/src/inline.rs +++ b/compiler/rustc_mir_transform/src/inline.rs @@ -96,7 +96,7 @@ fn inline<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool { history: Vec::new(), changed: false, }; - let blocks = BasicBlock::new(0)..body.basic_blocks.next_index(); + let blocks = START_BLOCK..body.basic_blocks.next_index(); this.process_blocks(body, blocks); this.changed } @@ -331,7 +331,7 @@ impl<'tcx> Inliner<'tcx> { return None; } - let fn_sig = self.tcx.bound_fn_sig(def_id).subst(self.tcx, substs); + let fn_sig = self.tcx.fn_sig(def_id).subst(self.tcx, substs); let source_info = SourceInfo { span: fn_span, ..terminator.source_info }; return Some(CallSite { callee, fn_sig, block: bb, target, source_info }); @@ -888,7 +888,7 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { location: Location, ) { if let ProjectionElem::Field(f, ty) = elem { - let parent = Place { local, projection: self.tcx.intern_place_elems(proj_base) }; + let parent = Place { local, projection: self.tcx.mk_place_elems(proj_base) }; let parent_ty = parent.ty(&self.callee_body.local_decls, self.tcx); let check_equal = |this: &mut Self, f_ty| { if !util::is_equal_up_to_subtyping(this.tcx, this.param_env, ty, f_ty) { @@ -900,7 +900,7 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { let kind = match parent_ty.ty.kind() { &ty::Alias(ty::Opaque, ty::AliasTy { def_id, substs, .. }) => { - self.tcx.bound_type_of(def_id).subst(self.tcx, substs).kind() + self.tcx.type_of(def_id).subst(self.tcx, substs).kind() } kind => kind, }; @@ -947,12 +947,12 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> { return; }; - let Some(&f_ty) = layout.field_tys.get(local) else { + let Some(f_ty) = layout.field_tys.get(local) else { self.validation = Err("malformed MIR"); return; }; - f_ty + f_ty.ty } else { let Some(f_ty) = substs.as_generator().prefix_tys().nth(f.index()) else { self.validation = Err("malformed MIR"); diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs index b027f9492..792457c80 100644 --- a/compiler/rustc_mir_transform/src/inline/cycle.rs +++ b/compiler/rustc_mir_transform/src/inline/cycle.rs @@ -2,7 +2,7 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet}; use rustc_data_structures::stack::ensure_sufficient_stack; use rustc_hir::def_id::{DefId, LocalDefId}; use rustc_middle::mir::TerminatorKind; -use rustc_middle::ty::TypeVisitable; +use rustc_middle::ty::TypeVisitableExt; use rustc_middle::ty::{self, subst::SubstsRef, InstanceDef, TyCtxt}; use rustc_session::Limit; diff --git a/compiler/rustc_mir_transform/src/instcombine.rs b/compiler/rustc_mir_transform/src/instcombine.rs index 2f3c65869..4182da195 100644 --- a/compiler/rustc_mir_transform/src/instcombine.rs +++ b/compiler/rustc_mir_transform/src/instcombine.rs @@ -6,7 +6,9 @@ use rustc_middle::mir::{ BinOp, Body, Constant, ConstantKind, LocalDecls, Operand, Place, ProjectionElem, Rvalue, SourceInfo, Statement, StatementKind, Terminator, TerminatorKind, UnOp, }; -use rustc_middle::ty::{self, TyCtxt}; +use rustc_middle::ty::layout::ValidityRequirement; +use rustc_middle::ty::{self, ParamEnv, SubstsRef, Ty, TyCtxt}; +use rustc_span::symbol::Symbol; pub struct InstCombine; @@ -16,7 +18,11 @@ impl<'tcx> 
MirPass<'tcx> for InstCombine { } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let ctx = InstCombineContext { tcx, local_decls: &body.local_decls }; + let ctx = InstCombineContext { + tcx, + local_decls: &body.local_decls, + param_env: tcx.param_env_reveal_all_normalized(body.source.def_id()), + }; for block in body.basic_blocks.as_mut() { for statement in block.statements.iter_mut() { match statement.kind { @@ -24,6 +30,7 @@ impl<'tcx> MirPass<'tcx> for InstCombine { ctx.combine_bool_cmp(&statement.source_info, rvalue); ctx.combine_ref_deref(&statement.source_info, rvalue); ctx.combine_len(&statement.source_info, rvalue); + ctx.combine_cast(&statement.source_info, rvalue); } _ => {} } @@ -33,6 +40,10 @@ impl<'tcx> MirPass<'tcx> for InstCombine { &mut block.terminator.as_mut().unwrap(), &mut block.statements, ); + ctx.combine_intrinsic_assert( + &mut block.terminator.as_mut().unwrap(), + &mut block.statements, + ); } } } @@ -40,6 +51,7 @@ impl<'tcx> MirPass<'tcx> for InstCombine { struct InstCombineContext<'tcx, 'a> { tcx: TyCtxt<'tcx>, local_decls: &'a LocalDecls<'tcx>, + param_env: ParamEnv<'tcx>, } impl<'tcx> InstCombineContext<'tcx, '_> { @@ -99,11 +111,7 @@ impl<'tcx> InstCombineContext<'tcx, '_> { fn combine_ref_deref(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { if let Rvalue::Ref(_, _, place) = rvalue { if let Some((base, ProjectionElem::Deref)) = place.as_ref().last_projection() { - if let ty::Ref(_, _, Mutability::Not) = - base.ty(self.local_decls, self.tcx).ty.kind() - { - // The dereferenced place must have type `&_`, so that we don't copy `&mut _`. - } else { + if rvalue.ty(self.local_decls, self.tcx) != base.ty(self.local_decls, self.tcx).ty { return; } @@ -113,7 +121,7 @@ impl<'tcx> InstCombineContext<'tcx, '_> { *rvalue = Rvalue::Use(Operand::Copy(Place { local: base.local, - projection: self.tcx.intern_place_elems(base.projection), + projection: self.tcx.mk_place_elems(base.projection), })); } } @@ -135,6 +143,14 @@ impl<'tcx> InstCombineContext<'tcx, '_> { } } + fn combine_cast(&self, _source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) { + if let Rvalue::Cast(_kind, operand, ty) = rvalue { + if operand.ty(self.local_decls, self.tcx) == *ty { + *rvalue = Rvalue::Use(operand.clone()); + } + } + } + fn combine_primitive_clone( &self, terminator: &mut Terminator<'tcx>, @@ -200,4 +216,58 @@ impl<'tcx> InstCombineContext<'tcx, '_> { }); terminator.kind = TerminatorKind::Goto { target: destination_block }; } + + fn combine_intrinsic_assert( + &self, + terminator: &mut Terminator<'tcx>, + _statements: &mut Vec>, + ) { + let TerminatorKind::Call { func, target, .. 
} = &mut terminator.kind else { return; }; + let Some(target_block) = target else { return; }; + let func_ty = func.ty(self.local_decls, self.tcx); + let Some((intrinsic_name, substs)) = resolve_rust_intrinsic(self.tcx, func_ty) else { + return; + }; + // The intrinsics we are interested in have one generic parameter + if substs.is_empty() { + return; + } + let ty = substs.type_at(0); + + let known_is_valid = intrinsic_assert_panics(self.tcx, self.param_env, ty, intrinsic_name); + match known_is_valid { + // We don't know the layout, or it's not a validity assertion at all; don't touch it + None => {} + Some(true) => { + // If we know the assert panics, indicate to later opts that the call diverges + *target = None; + } + Some(false) => { + // If we know the assert does not panic, turn the call into a Goto + terminator.kind = TerminatorKind::Goto { target: *target_block }; + } + } + } +} + +fn intrinsic_assert_panics<'tcx>( + tcx: TyCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, + ty: Ty<'tcx>, + intrinsic_name: Symbol, +) -> Option<bool> { + let requirement = ValidityRequirement::from_intrinsic(intrinsic_name)?; + Some(!tcx.check_validity_requirement((requirement, param_env.and(ty))).ok()?) +} + +fn resolve_rust_intrinsic<'tcx>( + tcx: TyCtxt<'tcx>, + func_ty: Ty<'tcx>, +) -> Option<(Symbol, SubstsRef<'tcx>)> { + if let ty::FnDef(def_id, substs) = *func_ty.kind() { + if tcx.is_intrinsic(def_id) { + return Some((tcx.item_name(def_id), substs)); + } + } + None } diff --git a/compiler/rustc_mir_transform/src/large_enums.rs b/compiler/rustc_mir_transform/src/large_enums.rs new file mode 100644 index 000000000..89e0a007d --- /dev/null +++ b/compiler/rustc_mir_transform/src/large_enums.rs @@ -0,0 +1,294 @@ +use crate::rustc_middle::ty::util::IntTypeExt; +use crate::MirPass; +use rustc_data_structures::fx::FxHashMap; +use rustc_middle::mir::interpret::AllocId; +use rustc_middle::mir::*; +use rustc_middle::ty::{self, AdtDef, ParamEnv, Ty, TyCtxt}; +use rustc_session::Session; +use rustc_target::abi::{HasDataLayout, Size, TagEncoding, Variants}; + +/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large +/// enough discrepancy between them. +/// +/// i.e. If there are two variants: +/// ``` +/// enum Example { +/// Small, +/// Large([u32; 1024]), +/// } +/// ``` +/// Instead of emitting moves of the large variant, +/// perform a memcpy instead. +/// Based off of [this HackMD](https://hackmd.io/@ft4bxUsFT5CEUBmRKYHr7w/rJM8BBPzD). +/// +/// In summary, what this does is at runtime determine which enum variant is active, +/// and instead of copying all the bytes of the largest possible variant, +/// copy only the bytes for the currently active variant. +pub struct EnumSizeOpt { + pub(crate) discrepancy: u64, +} + +impl<'tcx> MirPass<'tcx> for EnumSizeOpt { + fn is_enabled(&self, sess: &Session) -> bool { + sess.opts.unstable_opts.unsound_mir_opts || sess.mir_opt_level() >= 3 + } + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // NOTE: This pass may produce different MIR based on the alignment of the target + // platform, but it will still be valid.
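At runtime, the strategy described in the comments above amounts to indexing a per-variant size table by the discriminant and copying only that many bytes. A hand-written model of the idea, with invented names and made-up sizes (the pass emits the equivalent MIR directly and computes the table from the actual layout):

enum Example {
    Small,
    Large([u32; 1024]),
}

unsafe fn copy_active_variant(dst: *mut Example, src: *const Example, discr: usize) {
    // One entry per variant, indexed by discriminant; the real table holds
    // layout-derived sizes and lives in a compile-time allocation.
    const VARIANT_SIZES: [usize; 2] = [8, 4104];
    std::ptr::copy_nonoverlapping(src.cast::<u8>(), dst.cast::<u8>(), VARIANT_SIZES[discr]);
}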
+ self.optim(tcx, body); + } +} + +impl EnumSizeOpt { + fn candidate<'tcx>( + &self, + tcx: TyCtxt<'tcx>, + param_env: ParamEnv<'tcx>, + ty: Ty<'tcx>, + alloc_cache: &mut FxHashMap<Ty<'tcx>, AllocId>, + ) -> Option<(AdtDef<'tcx>, usize, AllocId)> { + let adt_def = match ty.kind() { + ty::Adt(adt_def, _substs) if adt_def.is_enum() => adt_def, + _ => return None, + }; + let layout = tcx.layout_of(param_env.and(ty)).ok()?; + let variants = match &layout.variants { + Variants::Single { .. } => return None, + Variants::Multiple { tag_encoding, .. } + if matches!(tag_encoding, TagEncoding::Niche { .. }) => + { + return None; + } + Variants::Multiple { variants, .. } if variants.len() <= 1 => return None, + Variants::Multiple { variants, .. } => variants, + }; + let min = variants.iter().map(|v| v.size).min().unwrap(); + let max = variants.iter().map(|v| v.size).max().unwrap(); + if max.bytes() - min.bytes() < self.discrepancy { + return None; + } + + let num_discrs = adt_def.discriminants(tcx).count(); + if variants.iter_enumerated().any(|(var_idx, _)| { + let discr_for_var = adt_def.discriminant_for_variant(tcx, var_idx).val; + (discr_for_var > usize::MAX as u128) || (discr_for_var as usize >= num_discrs) + }) { + return None; + } + if let Some(alloc_id) = alloc_cache.get(&ty) { + return Some((*adt_def, num_discrs, *alloc_id)); + } + + let data_layout = tcx.data_layout(); + let ptr_sized_int = data_layout.ptr_sized_integer(); + let target_bytes = ptr_sized_int.size().bytes() as usize; + let mut data = vec![0; target_bytes * num_discrs]; + macro_rules! encode_store { + ($curr_idx: expr, $endian: expr, $bytes: expr) => { + let bytes = match $endian { + rustc_target::abi::Endian::Little => $bytes.to_le_bytes(), + rustc_target::abi::Endian::Big => $bytes.to_be_bytes(), + }; + for (i, b) in bytes.into_iter().enumerate() { + data[$curr_idx + i] = b; + } + }; + } + + for (var_idx, layout) in variants.iter_enumerated() { + let curr_idx = + target_bytes * adt_def.discriminant_for_variant(tcx, var_idx).val as usize; + let sz = layout.size; + match ptr_sized_int { + rustc_target::abi::Integer::I32 => { + encode_store!(curr_idx, data_layout.endian, sz.bytes() as u32); + } + rustc_target::abi::Integer::I64 => { + encode_store!(curr_idx, data_layout.endian, sz.bytes()); + } + _ => unreachable!(), + }; + } + let alloc = interpret::Allocation::from_bytes( + data, + tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi, + Mutability::Not, + ); + let alloc = tcx.create_memory_alloc(tcx.mk_const_alloc(alloc)); + Some((*adt_def, num_discrs, *alloc_cache.entry(ty).or_insert(alloc))) + } + fn optim<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let mut alloc_cache = FxHashMap::default(); + let body_did = body.source.def_id(); + let param_env = tcx.param_env(body_did); + + let blocks = body.basic_blocks.as_mut(); + let local_decls = &mut body.local_decls; + + for bb in blocks { + bb.expand_statements(|st| { + if let StatementKind::Assign(box ( + lhs, + Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)), + )) = &st.kind + { + let ty = lhs.ty(local_decls, tcx).ty; + + let source_info = st.source_info; + let span = source_info.span; + + let (adt_def, num_variants, alloc_id) = + self.candidate(tcx, param_env, ty, &mut alloc_cache)?; + let alloc = tcx.global_alloc(alloc_id).unwrap_memory(); + + let tmp_ty = tcx.mk_array(tcx.types.usize, num_variants as u64); + + let size_array_local = local_decls.push(LocalDecl::new(tmp_ty, span)); + let store_live = Statement { + source_info, + kind:
StatementKind::StorageLive(size_array_local), + }; + + let place = Place::from(size_array_local); + let constant_vals = Constant { + span, + user_ty: None, + literal: ConstantKind::Val( + interpret::ConstValue::ByRef { alloc, offset: Size::ZERO }, + tmp_ty, + ), + }; + let rval = Rvalue::Use(Operand::Constant(box (constant_vals))); + + let const_assign = + Statement { source_info, kind: StatementKind::Assign(box (place, rval)) }; + + let discr_place = Place::from( + local_decls + .push(LocalDecl::new(adt_def.repr().discr_type().to_ty(tcx), span)), + ); + + let store_discr = Statement { + source_info, + kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(*rhs))), + }; + + let discr_cast_place = + Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span))); + + let cast_discr = Statement { + source_info, + kind: StatementKind::Assign(box ( + discr_cast_place, + Rvalue::Cast( + CastKind::IntToInt, + Operand::Copy(discr_place), + tcx.types.usize, + ), + )), + }; + + let size_place = + Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span))); + + let store_size = Statement { + source_info, + kind: StatementKind::Assign(box ( + size_place, + Rvalue::Use(Operand::Copy(Place { + local: size_array_local, + projection: tcx + .mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]), + })), + )), + }; + + let dst = + Place::from(local_decls.push(LocalDecl::new(tcx.mk_mut_ptr(ty), span))); + + let dst_ptr = Statement { + source_info, + kind: StatementKind::Assign(box ( + dst, + Rvalue::AddressOf(Mutability::Mut, *lhs), + )), + }; + + let dst_cast_ty = tcx.mk_mut_ptr(tcx.types.u8); + let dst_cast_place = + Place::from(local_decls.push(LocalDecl::new(dst_cast_ty, span))); + + let dst_cast = Statement { + source_info, + kind: StatementKind::Assign(box ( + dst_cast_place, + Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty), + )), + }; + + let src = + Place::from(local_decls.push(LocalDecl::new(tcx.mk_imm_ptr(ty), span))); + + let src_ptr = Statement { + source_info, + kind: StatementKind::Assign(box ( + src, + Rvalue::AddressOf(Mutability::Not, *rhs), + )), + }; + + let src_cast_ty = tcx.mk_imm_ptr(tcx.types.u8); + let src_cast_place = + Place::from(local_decls.push(LocalDecl::new(src_cast_ty, span))); + + let src_cast = Statement { + source_info, + kind: StatementKind::Assign(box ( + src_cast_place, + Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty), + )), + }; + + let deinit_old = + Statement { source_info, kind: StatementKind::Deinit(box dst) }; + + let copy_bytes = Statement { + source_info, + kind: StatementKind::Intrinsic( + box NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping { + src: Operand::Copy(src_cast_place), + dst: Operand::Copy(dst_cast_place), + count: Operand::Copy(size_place), + }), + ), + }; + + let store_dead = Statement { + source_info, + kind: StatementKind::StorageDead(size_array_local), + }; + let iter = [ + store_live, + const_assign, + store_discr, + cast_discr, + store_size, + dst_ptr, + dst_cast, + src_ptr, + src_cast, + deinit_old, + copy_bytes, + store_dead, + ] + .into_iter(); + + st.make_nop(); + Some(iter) + } else { + None + } + }); + } + } +} diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 20b7fdcfe..cdd28ae0c 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -1,6 +1,7 @@ #![allow(rustc::potential_query_instability)] #![feature(box_patterns)] #![feature(drain_filter)] 
+#![feature(box_syntax)] #![feature(let_chains)] #![feature(map_try_insert)] #![feature(min_specialization)] @@ -23,6 +24,7 @@ use rustc_const_eval::util; use rustc_data_structures::fx::FxIndexSet; use rustc_data_structures::steal::Steal; use rustc_hir as hir; +use rustc_hir::def::DefKind; use rustc_hir::def_id::{DefId, LocalDefId}; use rustc_hir::intravisit::{self, Visitor}; use rustc_index::vec::IndexVec; @@ -33,7 +35,7 @@ use rustc_middle::mir::{ TerminatorKind, }; use rustc_middle::ty::query::Providers; -use rustc_middle::ty::{self, TyCtxt, TypeVisitable}; +use rustc_middle::ty::{self, TyCtxt, TypeVisitableExt}; use rustc_span::sym; #[macro_use] @@ -54,10 +56,11 @@ mod const_debuginfo; mod const_goto; mod const_prop; mod const_prop_lint; +mod copy_prop; mod coverage; +mod ctfe_limit; mod dataflow_const_prop; mod dead_store_elimination; -mod deaggregator; mod deduce_param_attrs; mod deduplicate_blocks; mod deref_separator; @@ -71,6 +74,7 @@ mod function_item_references; mod generator; mod inline; mod instcombine; +mod large_enums; mod lower_intrinsics; mod lower_slice_len; mod match_branches; @@ -86,11 +90,11 @@ mod required_consts; mod reveal_all; mod separate_const_switch; mod shim; +mod ssa; // This pass is public to allow external drivers to perform MIR cleanup pub mod simplify; mod simplify_branches; mod simplify_comparison_integral; -mod simplify_try; mod sroa; mod uninhabited_enum_branching; mod unreachable_prop; @@ -102,7 +106,6 @@ use rustc_mir_dataflow::rustc_peek; pub fn provide(providers: &mut Providers) { check_unsafety::provide(providers); - check_packed_ref::provide(providers); coverage::query::provide(providers); ffi_unwind_calls::provide(providers); shim::provide(providers); @@ -124,6 +127,7 @@ pub fn provide(providers: &mut Providers) { mir_drops_elaborated_and_const_checked, mir_for_ctfe, mir_for_ctfe_of_const_arg, + mir_generator_witnesses: generator::mir_generator_witnesses, optimized_mir, is_mir_available, is_ctfe_mir_available: |tcx, did| is_mir_available(tcx, did), @@ -188,7 +192,7 @@ fn remap_mir_for_const_eval_select<'tcx>( let arguments = (0..num_args).map(|x| { let mut place_elems = place_elems.to_vec(); place_elems.push(ProjectionElem::Field(x.into(), fields[x])); - let projection = tcx.intern_place_elems(&place_elems); + let projection = tcx.mk_place_elems(&place_elems); let place = Place { local: place.local, projection, @@ -244,7 +248,7 @@ fn mir_const_qualif(tcx: TyCtxt<'_>, def: ty::WithOptConstParam<LocalDefId>) -> // N.B., this `borrow()` is guaranteed to be valid (i.e., the value // cannot yet be stolen), because `mir_promoted()`, which steals - // from `mir_const(), forces this query to execute before + // from `mir_const()`, forces this query to execute before // performing the steal.
let body = &tcx.mir_const(def).borrow(); @@ -410,6 +414,8 @@ fn inner_mir_for_ctfe(tcx: TyCtxt<'_>, def: ty::WithOptConstParam<LocalDefId>) - } } + pm::run_passes(tcx, &mut body, &[&ctfe_limit::CtfeLimit], None); + debug_assert!(!body.has_free_regions(), "Free regions in MIR for CTFE"); body @@ -426,6 +432,11 @@ fn mir_drops_elaborated_and_const_checked( return tcx.mir_drops_elaborated_and_const_checked(def); } + if tcx.sess.opts.unstable_opts.drop_tracking_mir && let DefKind::Generator = tcx.def_kind(def.did) + { + tcx.ensure().mir_generator_witnesses(def.did); + } let mir_borrowck = tcx.mir_borrowck_opt_const_arg(def); let is_fn_like = tcx.def_kind(def.did).is_fn_like(); @@ -513,9 +524,6 @@ fn run_runtime_lowering_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &elaborate_box_derefs::ElaborateBoxDerefs, &generator::StateTransform, &add_retag::AddRetag, - // Deaggregator is necessary for const prop. We may want to consider implementing - // CTFE support for aggregates. - &deaggregator::Deaggregator, &Lint(const_prop_lint::ConstProp), ]; pm::run_passes_no_validate(tcx, body, passes, Some(MirPhase::Runtime(RuntimePhase::Initial))); @@ -541,13 +549,13 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &[ &reveal_all::RevealAll, // has to be done before inlining, since inlined code is in RevealAll mode. &lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first - &normalize_array_len::NormalizeArrayLen, // has to run after `slice::len` lowering &unreachable_prop::UnreachablePropagation, &uninhabited_enum_branching::UninhabitedEnumBranching, &o1(simplify::SimplifyCfg::new("after-uninhabited-enum-branching")), &inline::Inline, &remove_storage_markers::RemoveStorageMarkers, &remove_zsts::RemoveZsts, + &normalize_array_len::NormalizeArrayLen, // has to run after `slice::len` lowering &const_goto::ConstGoto, &remove_unneeded_drops::RemoveUnneededDrops, &sroa::ScalarReplacementOfAggregates, @@ -557,6 +565,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &instcombine::InstCombine, &separate_const_switch::SeparateConstSwitch, &simplify::SimplifyLocals::new("before-const-prop"), + &copy_prop::CopyProp, // // FIXME(#70073): This pass is responsible for both optimization as well as some lints. &const_prop::ConstProp, @@ -567,8 +576,6 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &o1(simplify_branches::SimplifyConstCondition::new("after-const-prop")), &early_otherwise_branch::EarlyOtherwiseBranch, &simplify_comparison_integral::SimplifyComparisonIntegral, - &simplify_try::SimplifyArmIdentity, - &simplify_try::SimplifyBranchSame, &dead_store_elimination::DeadStoreElimination, &dest_prop::DestinationPropagation, &o1(simplify_branches::SimplifyConstCondition::new("final")), @@ -578,6 +585,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &simplify::SimplifyLocals::new("final"), &multiple_return_terminators::MultipleReturnTerminators, &deduplicate_blocks::DeduplicateBlocks, + &large_enums::EnumSizeOpt { discrepancy: 128 }, // Some cleanup necessary at least for LLVM and potentially other codegen backends. &add_call_guards::CriticalCallEdges, // Dump the end result for testing and debugging purposes.
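The first hunk in the next file turns the `*_with_overflow` intrinsics into plain `CheckedBinaryOp` statements. In surface Rust, the semantics being lowered are the ones the stable `overflowing_*` methods expose, for example:

fn main() {
    // `add_with_overflow` yields the wrapped result plus an overflow bit.
    let (wrapped, overflowed) = 200u8.overflowing_add(100);
    assert_eq!((wrapped, overflowed), (44, true));
}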
diff --git a/compiler/rustc_mir_transform/src/lower_intrinsics.rs b/compiler/rustc_mir_transform/src/lower_intrinsics.rs index 9892580e6..f596cc180 100644 --- a/compiler/rustc_mir_transform/src/lower_intrinsics.rs +++ b/compiler/rustc_mir_transform/src/lower_intrinsics.rs @@ -107,9 +107,29 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics { } } sym::add_with_overflow | sym::sub_with_overflow | sym::mul_with_overflow => { - // The checked binary operations are not suitable target for lowering here, - // since their semantics depend on the value of overflow-checks flag used - // during codegen. Issue #35310. + if let Some(target) = *target { + let lhs; + let rhs; + { + let mut args = args.drain(..); + lhs = args.next().unwrap(); + rhs = args.next().unwrap(); + } + let bin_op = match intrinsic_name { + sym::add_with_overflow => BinOp::Add, + sym::sub_with_overflow => BinOp::Sub, + sym::mul_with_overflow => BinOp::Mul, + _ => bug!("unexpected intrinsic"), + }; + block.statements.push(Statement { + source_info: terminator.source_info, + kind: StatementKind::Assign(Box::new(( + *destination, + Rvalue::CheckedBinaryOp(bin_op, Box::new((lhs, rhs))), + ))), + }); + terminator.kind = TerminatorKind::Goto { target }; + } } sym::size_of | sym::min_align_of => { if let Some(target) = *target { diff --git a/compiler/rustc_mir_transform/src/lower_slice_len.rs b/compiler/rustc_mir_transform/src/lower_slice_len.rs index 2f02d00ec..c6e7468aa 100644 --- a/compiler/rustc_mir_transform/src/lower_slice_len.rs +++ b/compiler/rustc_mir_transform/src/lower_slice_len.rs @@ -68,8 +68,11 @@ fn lower_slice_len_call<'tcx>( ty::FnDef(fn_def_id, _) if fn_def_id == &slice_len_fn_item_def_id => { // perform modifications // from something like `_5 = core::slice::<impl [u8]>::len(move _6) -> bb1` - // into `_5 = Len(*_6) + // into: + // ``` + // _5 = Len(*_6) // goto bb1 + // ``` // make new RValue for Len let deref_arg = tcx.mk_place_deref(arg); diff --git a/compiler/rustc_mir_transform/src/normalize_array_len.rs b/compiler/rustc_mir_transform/src/normalize_array_len.rs index 1708b287e..b36c8a0bd 100644 --- a/compiler/rustc_mir_transform/src/normalize_array_len.rs +++ b/compiler/rustc_mir_transform/src/normalize_array_len.rs @@ -1,288 +1,104 @@ //! This pass eliminates casting of arrays into slices when their length //! is taken using `.len()` method.
Handy to preserve information in MIR for const prop +use crate::ssa::SsaLocals; use crate::MirPass; -use rustc_data_structures::fx::FxIndexMap; -use rustc_data_structures::intern::Interned; -use rustc_index::bit_set::BitSet; use rustc_index::vec::IndexVec; +use rustc_middle::mir::visit::*; use rustc_middle::mir::*; -use rustc_middle::ty::{self, ReErased, Region, TyCtxt}; - -const MAX_NUM_BLOCKS: usize = 800; -const MAX_NUM_LOCALS: usize = 3000; +use rustc_middle::ty::{self, TyCtxt}; +use rustc_mir_dataflow::impls::borrowed_locals; pub struct NormalizeArrayLen; impl<'tcx> MirPass<'tcx> for NormalizeArrayLen { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - // See #105929 - sess.mir_opt_level() >= 4 && sess.opts.unstable_opts.unsound_mir_opts + sess.mir_opt_level() >= 3 } + #[instrument(level = "trace", skip(self, tcx, body))] fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - // early returns for edge cases of highly unrolled functions - if body.basic_blocks.len() > MAX_NUM_BLOCKS { - return; - } - if body.local_decls.len() > MAX_NUM_LOCALS { - return; - } + debug!(def_id = ?body.source.def_id()); normalize_array_len_calls(tcx, body) } } -pub fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - // We don't ever touch terminators, so no need to invalidate the CFG cache - let basic_blocks = body.basic_blocks.as_mut_preserves_cfg(); - let local_decls = &mut body.local_decls; +fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); + let borrowed_locals = borrowed_locals(body); + let ssa = SsaLocals::new(tcx, param_env, body, &borrowed_locals); - // do a preliminary analysis to see if we ever have locals of type `[T;N]` or `&[T;N]` - let mut interesting_locals = BitSet::new_empty(local_decls.len()); - for (local, decl) in local_decls.iter_enumerated() { - match decl.ty.kind() { - ty::Array(..) => { - interesting_locals.insert(local); - } - ty::Ref(.., ty, Mutability::Not) => match ty.kind() { - ty::Array(..) 
=> { - interesting_locals.insert(local); - } - _ => {} - }, - _ => {} - } - } - if interesting_locals.is_empty() { - // we have found nothing to analyze - return; - } - let num_intesting_locals = interesting_locals.count(); - let mut state = FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default()); - let mut patches_scratchpad = - FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default()); - let mut replacements_scratchpad = - FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default()); - for block in basic_blocks { - // make length calls for arrays [T; N] not to decay into length calls for &[T] - // that forbids constant propagation - normalize_array_len_call( - tcx, - block, - local_decls, - &interesting_locals, - &mut state, - &mut patches_scratchpad, - &mut replacements_scratchpad, - ); - state.clear(); - patches_scratchpad.clear(); - replacements_scratchpad.clear(); - } -} + let slice_lengths = compute_slice_length(tcx, &ssa, body); + debug!(?slice_lengths); -struct Patcher<'a, 'tcx> { - tcx: TyCtxt<'tcx>, - patches_scratchpad: &'a FxIndexMap, - replacements_scratchpad: &'a mut FxIndexMap, - local_decls: &'a mut IndexVec>, - statement_idx: usize, + Replacer { tcx, slice_lengths }.visit_body_preserves_cfg(body); } -impl<'tcx> Patcher<'_, 'tcx> { - fn patch_expand_statement( - &mut self, - statement: &mut Statement<'tcx>, - ) -> Option>> { - let idx = self.statement_idx; - if let Some(len_statemnt_idx) = self.patches_scratchpad.get(&idx).copied() { - let mut statements = Vec::with_capacity(2); - - // we are at statement that performs a cast. The only sound way is - // to create another local that performs a similar copy without a cast and then - // use this copy in the Len operation - - match &statement.kind { - StatementKind::Assign(box ( - .., - Rvalue::Cast( - CastKind::Pointer(ty::adjustment::PointerCast::Unsize), - operand, - _, - ), - )) => { - match operand { - Operand::Copy(place) | Operand::Move(place) => { - // create new local - let ty = operand.ty(self.local_decls, self.tcx); - let local_decl = LocalDecl::with_source_info(ty, statement.source_info); - let local = self.local_decls.push(local_decl); - // make it live - let mut make_live_statement = statement.clone(); - make_live_statement.kind = StatementKind::StorageLive(local); - statements.push(make_live_statement); - // copy into it - - let operand = Operand::Copy(*place); - let mut make_copy_statement = statement.clone(); - let assign_to = Place::from(local); - let rvalue = Rvalue::Use(operand); - make_copy_statement.kind = - StatementKind::Assign(Box::new((assign_to, rvalue))); - statements.push(make_copy_statement); - - // to reorder we have to copy and make NOP - statements.push(statement.clone()); - statement.make_nop(); - - self.replacements_scratchpad.insert(len_statemnt_idx, local); - } - _ => { - unreachable!("it's a bug in the implementation") - } - } - } - _ => { - unreachable!("it's a bug in the implementation") +fn compute_slice_length<'tcx>( + tcx: TyCtxt<'tcx>, + ssa: &SsaLocals, + body: &Body<'tcx>, +) -> IndexVec>> { + let mut slice_lengths = IndexVec::from_elem(None, &body.local_decls); + + for (local, rvalue) in ssa.assignments(body) { + match rvalue { + Rvalue::Cast( + CastKind::Pointer(ty::adjustment::PointerCast::Unsize), + operand, + cast_ty, + ) => { + let operand_ty = operand.ty(body, tcx); + debug!(?operand_ty); + if let Some(operand_ty) = operand_ty.builtin_deref(true) + && let ty::Array(_, len) = operand_ty.ty.kind() + && let 
Some(cast_ty) = cast_ty.builtin_deref(true) + && let ty::Slice(..) = cast_ty.ty.kind() + { + slice_lengths[local] = Some(*len); } } - - self.statement_idx += 1; - - Some(statements.into_iter()) - } else if let Some(local) = self.replacements_scratchpad.get(&idx).copied() { - let mut statements = Vec::with_capacity(2); - - match &statement.kind { - StatementKind::Assign(box (into, Rvalue::Len(place))) => { - let add_deref = if let Some(..) = place.as_local() { - false - } else if let Some(..) = place.local_or_deref_local() { - true - } else { - unreachable!("it's a bug in the implementation") - }; - // replace len statement - let mut len_statement = statement.clone(); - let mut place = Place::from(local); - if add_deref { - place = self.tcx.mk_place_deref(place); - } - len_statement.kind = - StatementKind::Assign(Box::new((*into, Rvalue::Len(place)))); - statements.push(len_statement); - - // make temporary dead - let mut make_dead_statement = statement.clone(); - make_dead_statement.kind = StatementKind::StorageDead(local); - statements.push(make_dead_statement); - - // make original statement NOP - statement.make_nop(); + // The length information is stored in the fat pointer, so we treat `operand` as a value. + Rvalue::Use(operand) => { + if let Some(rhs) = operand.place() && let Some(rhs) = rhs.as_local() { + slice_lengths[local] = slice_lengths[rhs]; } - _ => { - unreachable!("it's a bug in the implementation") + } + // The length information is stored in the fat pointer. + // Reborrowing copies length information from one pointer to the other. + Rvalue::Ref(_, _, rhs) | Rvalue::AddressOf(_, rhs) => { + if let [PlaceElem::Deref] = rhs.projection[..] { + slice_lengths[local] = slice_lengths[rhs.local]; } } - - self.statement_idx += 1; - - Some(statements.into_iter()) - } else { - self.statement_idx += 1; - None + _ => {} } } + + slice_lengths } -fn normalize_array_len_call<'tcx>( +struct Replacer<'tcx> { tcx: TyCtxt<'tcx>, - block: &mut BasicBlockData<'tcx>, - local_decls: &mut IndexVec>, - interesting_locals: &BitSet, - state: &mut FxIndexMap, - patches_scratchpad: &mut FxIndexMap, - replacements_scratchpad: &mut FxIndexMap, -) { - for (statement_idx, statement) in block.statements.iter_mut().enumerate() { - match &mut statement.kind { - StatementKind::Assign(box (place, rvalue)) => { - match rvalue { - Rvalue::Cast( - CastKind::Pointer(ty::adjustment::PointerCast::Unsize), - operand, - cast_ty, - ) => { - let Some(local) = place.as_local() else { return }; - match operand { - Operand::Copy(place) | Operand::Move(place) => { - let Some(operand_local) = place.local_or_deref_local() else { return; }; - if !interesting_locals.contains(operand_local) { - return; - } - let operand_ty = local_decls[operand_local].ty; - match (operand_ty.kind(), cast_ty.kind()) { - (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => { - if of_ty_src == of_ty_dst { - // this is a cast from [T; N] into [T], so we are good - state.insert(local, statement_idx); - } - } - // current way of patching doesn't allow to work with `mut` - ( - ty::Ref( - Region(Interned(ReErased, _)), - operand_ty, - Mutability::Not, - ), - ty::Ref( - Region(Interned(ReErased, _)), - cast_ty, - Mutability::Not, - ), - ) => { - match (operand_ty.kind(), cast_ty.kind()) { - // current way of patching doesn't allow to work with `mut` - (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => { - if of_ty_src == of_ty_dst { - // this is a cast from [T; N] into [T], so we are good - state.insert(local, statement_idx); - } - } - _ => {} - } - } - _ 
=> {} - } - } - _ => {} - } - } - Rvalue::Len(place) => { - let Some(local) = place.local_or_deref_local() else { - return; - }; - if let Some(cast_statement_idx) = state.get(&local).copied() { - patches_scratchpad.insert(cast_statement_idx, statement_idx); - } - } - _ => { - // invalidate - state.remove(&place.local); - } - } - } - _ => {} - } - } + slice_lengths: IndexVec>>, +} - let mut patcher = Patcher { - tcx, - patches_scratchpad: &*patches_scratchpad, - replacements_scratchpad, - local_decls, - statement_idx: 0, - }; +impl<'tcx> MutVisitor<'tcx> for Replacer<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } - block.expand_statements(|st| patcher.patch_expand_statement(st)); + fn visit_rvalue(&mut self, rvalue: &mut Rvalue<'tcx>, loc: Location) { + if let Rvalue::Len(place) = rvalue + && let [PlaceElem::Deref] = &place.projection[..] + && let Some(len) = self.slice_lengths[place.local] + { + *rvalue = Rvalue::Use(Operand::Constant(Box::new(Constant { + span: rustc_span::DUMMY_SP, + user_ty: None, + literal: ConstantKind::from_const(len, self.tcx), + }))); + } + self.super_rvalue(rvalue, loc); + } } diff --git a/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs b/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs index f1bbf2ea7..e3a03aa08 100644 --- a/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs +++ b/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs @@ -35,6 +35,7 @@ impl RemoveNoopLandingPads { | StatementKind::StorageDead(_) | StatementKind::AscribeUserType(..) | StatementKind::Coverage(..) + | StatementKind::ConstEvalCounter | StatementKind::Nop => { // These are all noops in a landing pad } diff --git a/compiler/rustc_mir_transform/src/remove_zsts.rs b/compiler/rustc_mir_transform/src/remove_zsts.rs index 6cabef92d..1becfddb2 100644 --- a/compiler/rustc_mir_transform/src/remove_zsts.rs +++ b/compiler/rustc_mir_transform/src/remove_zsts.rs @@ -13,7 +13,7 @@ impl<'tcx> MirPass<'tcx> for RemoveZsts { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { // Avoid query cycles (generators require optimized MIR for layout). 
-        if tcx.type_of(body.source.def_id()).is_generator() {
+        if tcx.type_of(body.source.def_id()).subst_identity().is_generator() {
             return;
         }
         let param_env = tcx.param_env(body.source.def_id());
diff --git a/compiler/rustc_mir_transform/src/separate_const_switch.rs b/compiler/rustc_mir_transform/src/separate_const_switch.rs
index 2f116aaa9..a24d2d34d 100644
--- a/compiler/rustc_mir_transform/src/separate_const_switch.rs
+++ b/compiler/rustc_mir_transform/src/separate_const_switch.rs
@@ -250,6 +250,7 @@ fn is_likely_const<'tcx>(mut tracked_place: Place<'tcx>, block: &BasicBlockData<
                 | StatementKind::Coverage(_)
                 | StatementKind::StorageDead(_)
                 | StatementKind::Intrinsic(_)
+                | StatementKind::ConstEvalCounter
                 | StatementKind::Nop => {}
             }
         }
@@ -318,6 +319,7 @@ fn find_determining_place<'tcx>(
                 | StatementKind::AscribeUserType(_, _)
                 | StatementKind::Coverage(_)
                 | StatementKind::Intrinsic(_)
+                | StatementKind::ConstEvalCounter
                 | StatementKind::Nop => {}
 
                 // If the discriminant is set, it is always set
diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs
index dace540fa..ebe63d6cb 100644
--- a/compiler/rustc_mir_transform/src/shim.rs
+++ b/compiler/rustc_mir_transform/src/shim.rs
@@ -15,7 +15,6 @@ use rustc_target::spec::abi::Abi;
 use std::fmt;
 use std::iter;
 
-use crate::util::expand_aggregate;
 use crate::{
     abort_unwinding_calls, add_call_guards, add_moves_for_packed_drops, deref_separator,
     pass_manager as pm, remove_noop_landing_pads, simplify,
@@ -148,11 +147,11 @@ fn build_drop_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, ty: Option<Ty<'tcx>>)
     assert!(!matches!(ty, Some(ty) if ty.is_generator()));
 
     let substs = if let Some(ty) = ty {
-        tcx.intern_substs(&[ty.into()])
+        tcx.mk_substs(&[ty.into()])
     } else {
         InternalSubsts::identity_for_item(tcx, def_id)
     };
-    let sig = tcx.bound_fn_sig(def_id).subst(tcx, substs);
+    let sig = tcx.fn_sig(def_id).subst(tcx, substs);
     let sig = tcx.erase_late_bound_regions(sig);
     let span = tcx.def_span(def_id);
 
@@ -363,7 +362,7 @@ impl<'tcx> CloneShimBuilder<'tcx> {
         // we must subst the self_ty because it's
         // otherwise going to be TySelf and we can't index
         // or access fields of a Place of type TySelf.
-        let sig = tcx.bound_fn_sig(def_id).subst(tcx, &[self_ty.into()]);
+        let sig = tcx.fn_sig(def_id).subst(tcx, &[self_ty.into()]);
         let sig = tcx.erase_late_bound_regions(sig);
         let span = tcx.def_span(def_id);
 
@@ -427,7 +426,7 @@ impl<'tcx> CloneShimBuilder<'tcx> {
     fn make_place(&mut self, mutability: Mutability, ty: Ty<'tcx>) -> Place<'tcx> {
         let span = self.span;
         let mut local = LocalDecl::new(ty, span);
-        if mutability == Mutability::Not {
+        if mutability.is_not() {
             local = local.immutable();
         }
         Place::from(self.local_decls.push(local))
@@ -598,7 +597,7 @@ fn build_call_shim<'tcx>(
         let untuple_args = sig.inputs();
 
         // Create substitutions for the `Self` and `Args` generic parameters of the shim body.
- let arg_tup = tcx.mk_tup(untuple_args.iter()); + let arg_tup = tcx.mk_tup(untuple_args); (Some([ty.into(), arg_tup.into()]), Some(untuple_args)) } else { @@ -606,7 +605,7 @@ fn build_call_shim<'tcx>( }; let def_id = instance.def_id(); - let sig = tcx.bound_fn_sig(def_id); + let sig = tcx.fn_sig(def_id); let sig = sig.map_bound(|sig| tcx.erase_late_bound_regions(sig)); assert_eq!(sig_substs.is_some(), !instance.has_polymorphic_mir_body()); @@ -633,7 +632,7 @@ fn build_call_shim<'tcx>( Adjustment::Deref => tcx.mk_imm_ptr(fnty), Adjustment::RefMut => tcx.mk_mut_ptr(fnty), }; - sig.inputs_and_output = tcx.intern_type_list(&inputs_and_output); + sig.inputs_and_output = tcx.mk_type_list(&inputs_and_output); } // FIXME(eddyb) avoid having this snippet both here and in @@ -644,7 +643,7 @@ fn build_call_shim<'tcx>( let self_arg = &mut inputs_and_output[0]; debug_assert!(tcx.generics_of(def_id).has_self && *self_arg == tcx.types.self_param); *self_arg = tcx.mk_mut_ptr(*self_arg); - sig.inputs_and_output = tcx.intern_type_list(&inputs_and_output); + sig.inputs_and_output = tcx.mk_type_list(&inputs_and_output); } let span = tcx.def_span(def_id); @@ -693,7 +692,7 @@ fn build_call_shim<'tcx>( // `FnDef` call with optional receiver. CallKind::Direct(def_id) => { - let ty = tcx.type_of(def_id); + let ty = tcx.type_of(def_id).subst_identity(); ( Operand::Constant(Box::new(Constant { span, @@ -798,7 +797,11 @@ pub fn build_adt_ctor(tcx: TyCtxt<'_>, ctor_id: DefId) -> Body<'_> { let param_env = tcx.param_env(ctor_id); // Normalize the sig. - let sig = tcx.fn_sig(ctor_id).no_bound_vars().expect("LBR in ADT constructor signature"); + let sig = tcx + .fn_sig(ctor_id) + .subst_identity() + .no_bound_vars() + .expect("LBR in ADT constructor signature"); let sig = tcx.normalize_erasing_regions(param_env, sig); let ty::Adt(adt_def, substs) = sig.output().kind() else { @@ -827,19 +830,23 @@ pub fn build_adt_ctor(tcx: TyCtxt<'_>, ctor_id: DefId) -> Body<'_> { // return; debug!("build_ctor: variant_index={:?}", variant_index); - let statements = expand_aggregate( - Place::return_place(), - adt_def.variant(variant_index).fields.iter().enumerate().map(|(idx, field_def)| { - (Operand::Move(Place::from(Local::new(idx + 1))), field_def.ty(tcx, substs)) - }), - AggregateKind::Adt(adt_def.did(), variant_index, substs, None, None), + let kind = AggregateKind::Adt(adt_def.did(), variant_index, substs, None, None); + let variant = adt_def.variant(variant_index); + let statement = Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate( + Box::new(kind), + (0..variant.fields.len()) + .map(|idx| Operand::Move(Place::from(Local::new(idx + 1)))) + .collect(), + ), + ))), source_info, - tcx, - ) - .collect(); + }; let start_block = BasicBlockData { - statements, + statements: vec![statement], terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }), is_cleanup: false, }; diff --git a/compiler/rustc_mir_transform/src/simplify.rs b/compiler/rustc_mir_transform/src/simplify.rs index 8f6abe7a9..9ef55c558 100644 --- a/compiler/rustc_mir_transform/src/simplify.rs +++ b/compiler/rustc_mir_transform/src/simplify.rs @@ -404,6 +404,18 @@ impl<'tcx> MirPass<'tcx> for SimplifyLocals { } } +pub fn remove_unused_definitions<'tcx>(body: &mut Body<'tcx>) { + // First, we're going to get a count of *actual* uses for every `Local`. + let mut used_locals = UsedLocals::new(body); + + // Next, we're going to remove any `Local` with zero actual uses. 
When we remove those + // `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals` + // count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from + // `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a + // fixedpoint where there are no more unused locals. + remove_unused_definitions_helper(&mut used_locals, body); +} + pub fn simplify_locals<'tcx>(body: &mut Body<'tcx>, tcx: TyCtxt<'tcx>) { // First, we're going to get a count of *actual* uses for every `Local`. let mut used_locals = UsedLocals::new(body); @@ -413,7 +425,7 @@ pub fn simplify_locals<'tcx>(body: &mut Body<'tcx>, tcx: TyCtxt<'tcx>) { // count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from // `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a // fixedpoint where there are no more unused locals. - remove_unused_definitions(&mut used_locals, body); + remove_unused_definitions_helper(&mut used_locals, body); // Finally, we'll actually do the work of shrinking `body.local_decls` and remapping the `Local`s. let map = make_local_map(&mut body.local_decls, &used_locals); @@ -484,7 +496,7 @@ impl UsedLocals { self.increment = false; // The location of the statement is irrelevant. - let location = Location { block: START_BLOCK, statement_index: 0 }; + let location = Location::START; self.visit_statement(statement, location); } @@ -517,7 +529,7 @@ impl<'tcx> Visitor<'tcx> for UsedLocals { self.super_statement(statement, location); } - StatementKind::Nop => {} + StatementKind::ConstEvalCounter | StatementKind::Nop => {} StatementKind::StorageLive(_local) | StatementKind::StorageDead(_local) => {} @@ -548,7 +560,7 @@ impl<'tcx> Visitor<'tcx> for UsedLocals { } /// Removes unused definitions. Updates the used locals to reflect the changes made. -fn remove_unused_definitions(used_locals: &mut UsedLocals, body: &mut Body<'_>) { +fn remove_unused_definitions_helper(used_locals: &mut UsedLocals, body: &mut Body<'_>) { // The use counts are updated as we remove the statements. A local might become unused // during the retain operation, leading to a temporary inconsistency (storage statements or // definitions referencing the local might remain). For correctness it is crucial that this diff --git a/compiler/rustc_mir_transform/src/simplify_try.rs b/compiler/rustc_mir_transform/src/simplify_try.rs deleted file mode 100644 index e4f3ace9a..000000000 --- a/compiler/rustc_mir_transform/src/simplify_try.rs +++ /dev/null @@ -1,822 +0,0 @@ -//! The general point of the optimizations provided here is to simplify something like: -//! -//! ```rust -//! # fn foo(x: Result) -> Result { -//! match x { -//! Ok(x) => Ok(x), -//! Err(x) => Err(x) -//! } -//! # } -//! ``` -//! -//! into just `x`. - -use crate::{simplify, MirPass}; -use itertools::Itertools as _; -use rustc_index::{bit_set::BitSet, vec::IndexVec}; -use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor}; -use rustc_middle::mir::*; -use rustc_middle::ty::{self, List, Ty, TyCtxt}; -use rustc_target::abi::VariantIdx; -use std::iter::{once, Enumerate, Peekable}; -use std::slice::Iter; - -/// Simplifies arms of form `Variant(x) => Variant(x)` to just a move. 
-/// -/// This is done by transforming basic blocks where the statements match: -/// -/// ```ignore (MIR) -/// _LOCAL_TMP = ((_LOCAL_1 as Variant ).FIELD: TY ); -/// _TMP_2 = _LOCAL_TMP; -/// ((_LOCAL_0 as Variant).FIELD: TY) = move _TMP_2; -/// discriminant(_LOCAL_0) = VAR_IDX; -/// ``` -/// -/// into: -/// -/// ```ignore (MIR) -/// _LOCAL_0 = move _LOCAL_1 -/// ``` -pub struct SimplifyArmIdentity; - -#[derive(Debug)] -struct ArmIdentityInfo<'tcx> { - /// Storage location for the variant's field - local_temp_0: Local, - /// Storage location holding the variant being read from - local_1: Local, - /// The variant field being read from - vf_s0: VarField<'tcx>, - /// Index of the statement which loads the variant being read - get_variant_field_stmt: usize, - - /// Tracks each assignment to a temporary of the variant's field - field_tmp_assignments: Vec<(Local, Local)>, - - /// Storage location holding the variant's field that was read from - local_tmp_s1: Local, - /// Storage location holding the enum that we are writing to - local_0: Local, - /// The variant field being written to - vf_s1: VarField<'tcx>, - - /// Storage location that the discriminant is being written to - set_discr_local: Local, - /// The variant being written - set_discr_var_idx: VariantIdx, - - /// Index of the statement that should be overwritten as a move - stmt_to_overwrite: usize, - /// SourceInfo for the new move - source_info: SourceInfo, - - /// Indices of matching Storage{Live,Dead} statements encountered. - /// (StorageLive index,, StorageDead index, Local) - storage_stmts: Vec<(usize, usize, Local)>, - - /// The statements that should be removed (turned into nops) - stmts_to_remove: Vec, - - /// Indices of debug variables that need to be adjusted to point to - // `{local_0}.{dbg_projection}`. - dbg_info_to_adjust: Vec, - - /// The projection used to rewrite debug info. - dbg_projection: &'tcx List>, -} - -fn get_arm_identity_info<'a, 'tcx>( - stmts: &'a [Statement<'tcx>], - locals_count: usize, - debug_info: &'a [VarDebugInfo<'tcx>], -) -> Option> { - // This can't possibly match unless there are at least 3 statements in the block - // so fail fast on tiny blocks. - if stmts.len() < 3 { - return None; - } - - let mut tmp_assigns = Vec::new(); - let mut nop_stmts = Vec::new(); - let mut storage_stmts = Vec::new(); - let mut storage_live_stmts = Vec::new(); - let mut storage_dead_stmts = Vec::new(); - - type StmtIter<'a, 'tcx> = Peekable>>>; - - fn is_storage_stmt(stmt: &Statement<'_>) -> bool { - matches!(stmt.kind, StatementKind::StorageLive(_) | StatementKind::StorageDead(_)) - } - - /// Eats consecutive Statements which match `test`, performing the specified `action` for each. - /// The iterator `stmt_iter` is not advanced if none were matched. - fn try_eat<'a, 'tcx>( - stmt_iter: &mut StmtIter<'a, 'tcx>, - test: impl Fn(&'a Statement<'tcx>) -> bool, - mut action: impl FnMut(usize, &'a Statement<'tcx>), - ) { - while stmt_iter.peek().map_or(false, |(_, stmt)| test(stmt)) { - let (idx, stmt) = stmt_iter.next().unwrap(); - - action(idx, stmt); - } - } - - /// Eats consecutive `StorageLive` and `StorageDead` Statements. - /// The iterator `stmt_iter` is not advanced if none were found. 
- fn try_eat_storage_stmts( - stmt_iter: &mut StmtIter<'_, '_>, - storage_live_stmts: &mut Vec<(usize, Local)>, - storage_dead_stmts: &mut Vec<(usize, Local)>, - ) { - try_eat(stmt_iter, is_storage_stmt, |idx, stmt| { - if let StatementKind::StorageLive(l) = stmt.kind { - storage_live_stmts.push((idx, l)); - } else if let StatementKind::StorageDead(l) = stmt.kind { - storage_dead_stmts.push((idx, l)); - } - }) - } - - fn is_tmp_storage_stmt(stmt: &Statement<'_>) -> bool { - use rustc_middle::mir::StatementKind::Assign; - if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = &stmt.kind { - place.as_local().is_some() && p.as_local().is_some() - } else { - false - } - } - - /// Eats consecutive `Assign` Statements. - // The iterator `stmt_iter` is not advanced if none were found. - fn try_eat_assign_tmp_stmts( - stmt_iter: &mut StmtIter<'_, '_>, - tmp_assigns: &mut Vec<(Local, Local)>, - nop_stmts: &mut Vec, - ) { - try_eat(stmt_iter, is_tmp_storage_stmt, |idx, stmt| { - use rustc_middle::mir::StatementKind::Assign; - if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = - &stmt.kind - { - tmp_assigns.push((place.as_local().unwrap(), p.as_local().unwrap())); - nop_stmts.push(idx); - } - }) - } - - fn find_storage_live_dead_stmts_for_local( - local: Local, - stmts: &[Statement<'_>], - ) -> Option<(usize, usize)> { - trace!("looking for {:?}", local); - let mut storage_live_stmt = None; - let mut storage_dead_stmt = None; - for (idx, stmt) in stmts.iter().enumerate() { - if stmt.kind == StatementKind::StorageLive(local) { - storage_live_stmt = Some(idx); - } else if stmt.kind == StatementKind::StorageDead(local) { - storage_dead_stmt = Some(idx); - } - } - - Some((storage_live_stmt?, storage_dead_stmt.unwrap_or(usize::MAX))) - } - - // Try to match the expected MIR structure with the basic block we're processing. 
- // We want to see something that looks like: - // ``` - // (StorageLive(_) | StorageDead(_));* - // _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY); - // (StorageLive(_) | StorageDead(_));* - // (tmp_n+1 = tmp_n);* - // (StorageLive(_) | StorageDead(_));* - // (tmp_n+1 = tmp_n);* - // ((LOCAL_FROM as Variant).FIELD: TY) = move tmp; - // discriminant(LOCAL_FROM) = VariantIdx; - // (StorageLive(_) | StorageDead(_));* - // ``` - let mut stmt_iter = stmts.iter().enumerate().peekable(); - - try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); - - let (get_variant_field_stmt, stmt) = stmt_iter.next()?; - let (local_tmp_s0, local_1, vf_s0, dbg_projection) = match_get_variant_field(stmt)?; - - try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); - - try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts); - - try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); - - try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts); - - let (idx, stmt) = stmt_iter.next()?; - let (local_tmp_s1, local_0, vf_s1) = match_set_variant_field(stmt)?; - nop_stmts.push(idx); - - let (idx, stmt) = stmt_iter.next()?; - let (set_discr_local, set_discr_var_idx) = match_set_discr(stmt)?; - let discr_stmt_source_info = stmt.source_info; - nop_stmts.push(idx); - - try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); - - for (live_idx, live_local) in storage_live_stmts { - if let Some(i) = storage_dead_stmts.iter().rposition(|(_, l)| *l == live_local) { - let (dead_idx, _) = storage_dead_stmts.swap_remove(i); - storage_stmts.push((live_idx, dead_idx, live_local)); - - if live_local == local_tmp_s0 { - nop_stmts.push(get_variant_field_stmt); - } - } - } - // We sort primitive usize here so we can use unstable sort - nop_stmts.sort_unstable(); - - // Use one of the statements we're going to discard between the point - // where the storage location for the variant field becomes live and - // is killed. - let (live_idx, dead_idx) = find_storage_live_dead_stmts_for_local(local_tmp_s0, stmts)?; - let stmt_to_overwrite = - nop_stmts.iter().find(|stmt_idx| live_idx < **stmt_idx && **stmt_idx < dead_idx); - - let mut tmp_assigned_vars = BitSet::new_empty(locals_count); - for (l, r) in &tmp_assigns { - tmp_assigned_vars.insert(*l); - tmp_assigned_vars.insert(*r); - } - - let dbg_info_to_adjust: Vec<_> = debug_info - .iter() - .enumerate() - .filter_map(|(i, var_info)| { - if let VarDebugInfoContents::Place(p) = var_info.value { - if tmp_assigned_vars.contains(p.local) { - return Some(i); - } - } - - None - }) - .collect(); - - Some(ArmIdentityInfo { - local_temp_0: local_tmp_s0, - local_1, - vf_s0, - get_variant_field_stmt, - field_tmp_assignments: tmp_assigns, - local_tmp_s1, - local_0, - vf_s1, - set_discr_local, - set_discr_var_idx, - stmt_to_overwrite: *stmt_to_overwrite?, - source_info: discr_stmt_source_info, - storage_stmts, - stmts_to_remove: nop_stmts, - dbg_info_to_adjust, - dbg_projection, - }) -} - -fn optimization_applies<'tcx>( - opt_info: &ArmIdentityInfo<'tcx>, - local_decls: &IndexVec>, - local_uses: &IndexVec, - var_debug_info: &[VarDebugInfo<'tcx>], -) -> bool { - trace!("testing if optimization applies..."); - - // FIXME(wesleywiser): possibly relax this restriction? 
- if opt_info.local_0 == opt_info.local_1 { - trace!("NO: moving into ourselves"); - return false; - } else if opt_info.vf_s0 != opt_info.vf_s1 { - trace!("NO: the field-and-variant information do not match"); - return false; - } else if local_decls[opt_info.local_0].ty != local_decls[opt_info.local_1].ty { - // FIXME(Centril,oli-obk): possibly relax to same layout? - trace!("NO: source and target locals have different types"); - return false; - } else if (opt_info.local_0, opt_info.vf_s0.var_idx) - != (opt_info.set_discr_local, opt_info.set_discr_var_idx) - { - trace!("NO: the discriminants do not match"); - return false; - } - - // Verify the assignment chain consists of the form b = a; c = b; d = c; etc... - if opt_info.field_tmp_assignments.is_empty() { - trace!("NO: no assignments found"); - return false; - } - let mut last_assigned_to = opt_info.field_tmp_assignments[0].1; - let source_local = last_assigned_to; - for (l, r) in &opt_info.field_tmp_assignments { - if *r != last_assigned_to { - trace!("NO: found unexpected assignment {:?} = {:?}", l, r); - return false; - } - - last_assigned_to = *l; - } - - // Check that the first and last used locals are only used twice - // since they are of the form: - // - // ``` - // _first = ((_x as Variant).n: ty); - // _n = _first; - // ... - // ((_y as Variant).n: ty) = _n; - // discriminant(_y) = z; - // ``` - for (l, r) in &opt_info.field_tmp_assignments { - if local_uses[*l] != 2 { - warn!("NO: FAILED assignment chain local {:?} was used more than twice", l); - return false; - } else if local_uses[*r] != 2 { - warn!("NO: FAILED assignment chain local {:?} was used more than twice", r); - return false; - } - } - - // Check that debug info only points to full Locals and not projections. - for dbg_idx in &opt_info.dbg_info_to_adjust { - let dbg_info = &var_debug_info[*dbg_idx]; - if let VarDebugInfoContents::Place(p) = dbg_info.value { - if !p.projection.is_empty() { - trace!("NO: debug info for {:?} had a projection {:?}", dbg_info.name, p); - return false; - } - } - } - - if source_local != opt_info.local_temp_0 { - trace!( - "NO: start of assignment chain does not match enum variant temp: {:?} != {:?}", - source_local, - opt_info.local_temp_0 - ); - return false; - } else if last_assigned_to != opt_info.local_tmp_s1 { - trace!( - "NO: end of assignment chain does not match written enum temp: {:?} != {:?}", - last_assigned_to, - opt_info.local_tmp_s1 - ); - return false; - } - - trace!("SUCCESS: optimization applies!"); - true -} - -impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity { - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - // FIXME(77359): This optimization can result in unsoundness. - if !tcx.sess.opts.unstable_opts.unsound_mir_opts { - return; - } - - let source = body.source; - trace!("running SimplifyArmIdentity on {:?}", source); - - let local_uses = LocalUseCounter::get_local_uses(body); - for bb in body.basic_blocks.as_mut() { - if let Some(opt_info) = - get_arm_identity_info(&bb.statements, body.local_decls.len(), &body.var_debug_info) - { - trace!("got opt_info = {:#?}", opt_info); - if !optimization_applies( - &opt_info, - &body.local_decls, - &local_uses, - &body.var_debug_info, - ) { - debug!("optimization skipped for {:?}", source); - continue; - } - - // Also remove unused Storage{Live,Dead} statements which correspond - // to temps used previously. 
- for (live_idx, dead_idx, local) in &opt_info.storage_stmts { - // The temporary that we've read the variant field into is scoped to this block, - // so we can remove the assignment. - if *local == opt_info.local_temp_0 { - bb.statements[opt_info.get_variant_field_stmt].make_nop(); - } - - for (left, right) in &opt_info.field_tmp_assignments { - if local == left || local == right { - bb.statements[*live_idx].make_nop(); - bb.statements[*dead_idx].make_nop(); - } - } - } - - // Right shape; transform - for stmt_idx in opt_info.stmts_to_remove { - bb.statements[stmt_idx].make_nop(); - } - - let stmt = &mut bb.statements[opt_info.stmt_to_overwrite]; - stmt.source_info = opt_info.source_info; - stmt.kind = StatementKind::Assign(Box::new(( - opt_info.local_0.into(), - Rvalue::Use(Operand::Move(opt_info.local_1.into())), - ))); - - bb.statements.retain(|stmt| stmt.kind != StatementKind::Nop); - - // Fix the debug info to point to the right local - for dbg_index in opt_info.dbg_info_to_adjust { - let dbg_info = &mut body.var_debug_info[dbg_index]; - assert!( - matches!(dbg_info.value, VarDebugInfoContents::Place(_)), - "value was not a Place" - ); - if let VarDebugInfoContents::Place(p) = &mut dbg_info.value { - assert!(p.projection.is_empty()); - p.local = opt_info.local_0; - p.projection = opt_info.dbg_projection; - } - } - - trace!("block is now {:?}", bb.statements); - } - } - } -} - -struct LocalUseCounter { - local_uses: IndexVec, -} - -impl LocalUseCounter { - fn get_local_uses(body: &Body<'_>) -> IndexVec { - let mut counter = LocalUseCounter { local_uses: IndexVec::from_elem(0, &body.local_decls) }; - counter.visit_body(body); - counter.local_uses - } -} - -impl Visitor<'_> for LocalUseCounter { - fn visit_local(&mut self, local: Local, context: PlaceContext, _location: Location) { - if context.is_storage_marker() - || context == PlaceContext::NonUse(NonUseContext::VarDebugInfo) - { - return; - } - - self.local_uses[local] += 1; - } -} - -/// Match on: -/// ```ignore (MIR) -/// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY); -/// ``` -fn match_get_variant_field<'tcx>( - stmt: &Statement<'tcx>, -) -> Option<(Local, Local, VarField<'tcx>, &'tcx List>)> { - match &stmt.kind { - StatementKind::Assign(box ( - place_into, - Rvalue::Use(Operand::Copy(pf) | Operand::Move(pf)), - )) => { - let local_into = place_into.as_local()?; - let (local_from, vf) = match_variant_field_place(*pf)?; - Some((local_into, local_from, vf, pf.projection)) - } - _ => None, - } -} - -/// Match on: -/// ```ignore (MIR) -/// ((_LOCAL_FROM as Variant).FIELD: TY) = move _LOCAL_INTO; -/// ``` -fn match_set_variant_field<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, Local, VarField<'tcx>)> { - match &stmt.kind { - StatementKind::Assign(box (place_from, Rvalue::Use(Operand::Move(place_into)))) => { - let local_into = place_into.as_local()?; - let (local_from, vf) = match_variant_field_place(*place_from)?; - Some((local_into, local_from, vf)) - } - _ => None, - } -} - -/// Match on: -/// ```ignore (MIR) -/// discriminant(_LOCAL_TO_SET) = VAR_IDX; -/// ``` -fn match_set_discr(stmt: &Statement<'_>) -> Option<(Local, VariantIdx)> { - match &stmt.kind { - StatementKind::SetDiscriminant { place, variant_index } => { - Some((place.as_local()?, *variant_index)) - } - _ => None, - } -} - -#[derive(PartialEq, Debug)] -struct VarField<'tcx> { - field: Field, - field_ty: Ty<'tcx>, - var_idx: VariantIdx, -} - -/// Match on `((_LOCAL as Variant).FIELD: TY)`. 
-fn match_variant_field_place(place: Place<'_>) -> Option<(Local, VarField<'_>)> { - match place.as_ref() { - PlaceRef { - local, - projection: &[ProjectionElem::Downcast(_, var_idx), ProjectionElem::Field(field, ty)], - } => Some((local, VarField { field, field_ty: ty, var_idx })), - _ => None, - } -} - -/// Simplifies `SwitchInt(_) -> [targets]`, -/// where all the `targets` have the same form, -/// into `goto -> target_first`. -pub struct SimplifyBranchSame; - -impl<'tcx> MirPass<'tcx> for SimplifyBranchSame { - fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - // This optimization is disabled by default for now due to - // soundness concerns; see issue #89485 and PR #89489. - if !tcx.sess.opts.unstable_opts.unsound_mir_opts { - return; - } - - trace!("Running SimplifyBranchSame on {:?}", body.source); - let finder = SimplifyBranchSameOptimizationFinder { body, tcx }; - let opts = finder.find(); - - let did_remove_blocks = opts.len() > 0; - for opt in opts.iter() { - trace!("SUCCESS: Applying optimization {:?}", opt); - // Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`. - body.basic_blocks_mut()[opt.bb_to_opt_terminator].terminator_mut().kind = - TerminatorKind::Goto { target: opt.bb_to_goto }; - } - - if did_remove_blocks { - // We have dead blocks now, so remove those. - simplify::remove_dead_blocks(tcx, body); - } - } -} - -#[derive(Debug)] -struct SimplifyBranchSameOptimization { - /// All basic blocks are equal so go to this one - bb_to_goto: BasicBlock, - /// Basic block where the terminator can be simplified to a goto - bb_to_opt_terminator: BasicBlock, -} - -struct SwitchTargetAndValue { - target: BasicBlock, - // None in case of the `otherwise` case - value: Option, -} - -struct SimplifyBranchSameOptimizationFinder<'a, 'tcx> { - body: &'a Body<'tcx>, - tcx: TyCtxt<'tcx>, -} - -impl<'tcx> SimplifyBranchSameOptimizationFinder<'_, 'tcx> { - fn find(&self) -> Vec { - self.body - .basic_blocks - .iter_enumerated() - .filter_map(|(bb_idx, bb)| { - let (discr_switched_on, targets_and_values) = match &bb.terminator().kind { - TerminatorKind::SwitchInt { targets, discr, .. } => { - let targets_and_values: Vec<_> = targets.iter() - .map(|(val, target)| SwitchTargetAndValue { target, value: Some(val) }) - .chain(once(SwitchTargetAndValue { target: targets.otherwise(), value: None })) - .collect(); - (discr, targets_and_values) - }, - _ => return None, - }; - - // find the adt that has its discriminant read - // assuming this must be the last statement of the block - let adt_matched_on = match &bb.statements.last()?.kind { - StatementKind::Assign(box (place, rhs)) - if Some(*place) == discr_switched_on.place() => - { - match rhs { - Rvalue::Discriminant(adt_place) if adt_place.ty(self.body, self.tcx).ty.is_enum() => adt_place, - _ => { - trace!("NO: expected a discriminant read of an enum instead of: {:?}", rhs); - return None; - } - } - } - other => { - trace!("NO: expected an assignment of a discriminant read to a place. Found: {:?}", other); - return None - }, - }; - - let mut iter_bbs_reachable = targets_and_values - .iter() - .map(|target_and_value| (target_and_value, &self.body.basic_blocks[target_and_value.target])) - .filter(|(_, bb)| { - // Reaching `unreachable` is UB so assume it doesn't happen. 
- bb.terminator().kind != TerminatorKind::Unreachable - }) - .peekable(); - - let bb_first = iter_bbs_reachable.peek().map_or(&targets_and_values[0], |(idx, _)| *idx); - let mut all_successors_equivalent = StatementEquality::TrivialEqual; - - // All successor basic blocks must be equal or contain statements that are pairwise considered equal. - for ((target_and_value_l,bb_l), (target_and_value_r,bb_r)) in iter_bbs_reachable.tuple_windows() { - let trivial_checks = bb_l.is_cleanup == bb_r.is_cleanup - && bb_l.terminator().kind == bb_r.terminator().kind - && bb_l.statements.len() == bb_r.statements.len(); - let statement_check = || { - bb_l.statements.iter().zip(&bb_r.statements).try_fold(StatementEquality::TrivialEqual, |acc,(l,r)| { - let stmt_equality = self.statement_equality(*adt_matched_on, &l, target_and_value_l, &r, target_and_value_r); - if matches!(stmt_equality, StatementEquality::NotEqual) { - // short circuit - None - } else { - Some(acc.combine(&stmt_equality)) - } - }) - .unwrap_or(StatementEquality::NotEqual) - }; - if !trivial_checks { - all_successors_equivalent = StatementEquality::NotEqual; - break; - } - all_successors_equivalent = all_successors_equivalent.combine(&statement_check()); - }; - - match all_successors_equivalent{ - StatementEquality::TrivialEqual => { - // statements are trivially equal, so just take first - trace!("Statements are trivially equal"); - Some(SimplifyBranchSameOptimization { - bb_to_goto: bb_first.target, - bb_to_opt_terminator: bb_idx, - }) - } - StatementEquality::ConsideredEqual(bb_to_choose) => { - trace!("Statements are considered equal"); - Some(SimplifyBranchSameOptimization { - bb_to_goto: bb_to_choose, - bb_to_opt_terminator: bb_idx, - }) - } - StatementEquality::NotEqual => { - trace!("NO: not all successors of basic block {:?} were equivalent", bb_idx); - None - } - } - }) - .collect() - } - - /// Tests if two statements can be considered equal - /// - /// Statements can be trivially equal if the kinds match. - /// But they can also be considered equal in the following case A: - /// ```ignore (MIR) - /// discriminant(_0) = 0; // bb1 - /// _0 = move _1; // bb2 - /// ``` - /// In this case the two statements are equal iff - /// - `_0` is an enum where the variant index 0 is fieldless, and - /// - bb1 was targeted by a switch where the discriminant of `_1` was switched on - fn statement_equality( - &self, - adt_matched_on: Place<'tcx>, - x: &Statement<'tcx>, - x_target_and_value: &SwitchTargetAndValue, - y: &Statement<'tcx>, - y_target_and_value: &SwitchTargetAndValue, - ) -> StatementEquality { - let helper = |rhs: &Rvalue<'tcx>, - place: &Place<'tcx>, - variant_index: VariantIdx, - switch_value: u128, - side_to_choose| { - let place_type = place.ty(self.body, self.tcx).ty; - let adt = match *place_type.kind() { - ty::Adt(adt, _) if adt.is_enum() => adt, - _ => return StatementEquality::NotEqual, - }; - // We need to make sure that the switch value that targets the bb with - // SetDiscriminant is the same as the variant discriminant. 
- let variant_discr = adt.discriminant_for_variant(self.tcx, variant_index).val; - if variant_discr != switch_value { - trace!( - "NO: variant discriminant {} does not equal switch value {}", - variant_discr, - switch_value - ); - return StatementEquality::NotEqual; - } - let variant_is_fieldless = adt.variant(variant_index).fields.is_empty(); - if !variant_is_fieldless { - trace!("NO: variant {:?} was not fieldless", variant_index); - return StatementEquality::NotEqual; - } - - match rhs { - Rvalue::Use(operand) if operand.place() == Some(adt_matched_on) => { - StatementEquality::ConsideredEqual(side_to_choose) - } - _ => { - trace!( - "NO: RHS of assignment was {:?}, but expected it to match the adt being matched on in the switch, which is {:?}", - rhs, - adt_matched_on - ); - StatementEquality::NotEqual - } - } - }; - match (&x.kind, &y.kind) { - // trivial case - (x, y) if x == y => StatementEquality::TrivialEqual, - - // check for case A - ( - StatementKind::Assign(box (_, rhs)), - &StatementKind::SetDiscriminant { ref place, variant_index }, - ) if y_target_and_value.value.is_some() => { - // choose basic block of x, as that has the assign - helper( - rhs, - place, - variant_index, - y_target_and_value.value.unwrap(), - x_target_and_value.target, - ) - } - ( - &StatementKind::SetDiscriminant { ref place, variant_index }, - &StatementKind::Assign(box (_, ref rhs)), - ) if x_target_and_value.value.is_some() => { - // choose basic block of y, as that has the assign - helper( - rhs, - place, - variant_index, - x_target_and_value.value.unwrap(), - y_target_and_value.target, - ) - } - _ => { - trace!("NO: statements `{:?}` and `{:?}` not considered equal", x, y); - StatementEquality::NotEqual - } - } - } -} - -#[derive(Copy, Clone, Eq, PartialEq)] -enum StatementEquality { - /// The two statements are trivially equal; same kind - TrivialEqual, - /// The two statements are considered equal, but may be of different kinds. The BasicBlock field is the basic block to jump to when performing the branch-same optimization. - /// For example, `_0 = _1` and `discriminant(_0) = discriminant(0)` are considered equal if 0 is a fieldless variant of an enum. 
But we don't want to jump to the basic block with the SetDiscriminant, as that is not legal if _1 is not the 0 variant index - ConsideredEqual(BasicBlock), - /// The two statements are not equal - NotEqual, -} - -impl StatementEquality { - fn combine(&self, other: &StatementEquality) -> StatementEquality { - use StatementEquality::*; - match (self, other) { - (TrivialEqual, TrivialEqual) => TrivialEqual, - (TrivialEqual, ConsideredEqual(b)) | (ConsideredEqual(b), TrivialEqual) => { - ConsideredEqual(*b) - } - (ConsideredEqual(b1), ConsideredEqual(b2)) => { - if b1 == b2 { - ConsideredEqual(*b1) - } else { - NotEqual - } - } - (_, NotEqual) | (NotEqual, _) => NotEqual, - } - } -} diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs index 42124f5a4..13168e9a2 100644 --- a/compiler/rustc_mir_transform/src/sroa.rs +++ b/compiler/rustc_mir_transform/src/sroa.rs @@ -1,10 +1,11 @@ use crate::MirPass; -use rustc_data_structures::fx::{FxIndexMap, IndexEntry}; -use rustc_index::bit_set::BitSet; +use rustc_index::bit_set::{BitSet, GrowableBitSet}; use rustc_index::vec::IndexVec; +use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; -use rustc_middle::ty::TyCtxt; +use rustc_middle::ty::{Ty, TyCtxt}; +use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields}; pub struct ScalarReplacementOfAggregates; @@ -13,27 +14,43 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates { sess.mir_opt_level() >= 3 } + #[instrument(level = "debug", skip(self, tcx, body))] fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let escaping = escaping_locals(&*body); - debug!(?escaping); - let replacements = compute_flattening(tcx, body, escaping); - debug!(?replacements); - replace_flattened_locals(tcx, body, replacements); + debug!(def_id = ?body.source.def_id()); + let mut excluded = excluded_locals(body); + loop { + debug!(?excluded); + let escaping = escaping_locals(&excluded, body); + debug!(?escaping); + let replacements = compute_flattening(tcx, body, escaping); + debug!(?replacements); + let all_dead_locals = replace_flattened_locals(tcx, body, replacements); + if !all_dead_locals.is_empty() { + excluded.union(&all_dead_locals); + excluded = { + let mut growable = GrowableBitSet::from(excluded); + growable.ensure(body.local_decls.len()); + growable.into() + }; + } else { + break; + } + } } } /// Identify all locals that are not eligible for SROA. /// /// There are 3 cases: -/// - the aggegated local is used or passed to other code (function parameters and arguments); +/// - the aggregated local is used or passed to other code (function parameters and arguments); /// - the locals is a union or an enum; /// - the local's address is taken, and thus the relative addresses of the fields are observable to /// client code. 
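The third case is the subtle one: once the address of an aggregate is taken, the field offsets become observable and the local has to keep its in-memory layout. A small illustration of that reading, in plain Rust with hypothetical names:

    struct Pair {
        a: u32,
        b: u64,
    }

    // `observe` sees `Pair` through a reference, so the relative address
    // of `p.b` inside the allocation is visible to it.
    fn observe(p: &Pair) -> u64 {
        p.b
    }

    fn main() {
        let q = Pair { a: 1, b: 2 };
        // Taking `&q` makes `q` escape in the sense used here, which
        // keeps it out of SROA.
        println!("{}", observe(&q));
    }
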
-fn escaping_locals(body: &Body<'_>) -> BitSet { +fn escaping_locals(excluded: &BitSet, body: &Body<'_>) -> BitSet { let mut set = BitSet::new_empty(body.local_decls.len()); set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count)); for (local, decl) in body.local_decls().iter_enumerated() { - if decl.ty.is_union() || decl.ty.is_enum() { + if decl.ty.is_union() || decl.ty.is_enum() || excluded.contains(local) { set.insert(local); } } @@ -58,41 +75,33 @@ fn escaping_locals(body: &Body<'_>) -> BitSet { self.super_place(place, context, location); } - fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { - if let Rvalue::AddressOf(.., place) | Rvalue::Ref(.., place) = rvalue { - if !place.is_indirect() { - // Raw pointers may be used to access anything inside the enclosing place. - self.set.insert(place.local); - return; + fn visit_assign( + &mut self, + lvalue: &Place<'tcx>, + rvalue: &Rvalue<'tcx>, + location: Location, + ) { + if lvalue.as_local().is_some() { + match rvalue { + // Aggregate assignments are expanded in run_pass. + Rvalue::Aggregate(..) | Rvalue::Use(..) => { + self.visit_rvalue(rvalue, location); + return; + } + _ => {} } } - self.super_rvalue(rvalue, location) + self.super_assign(lvalue, rvalue, location) } fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { - if let StatementKind::StorageLive(..) - | StatementKind::StorageDead(..) - | StatementKind::Deinit(..) = statement.kind - { + match statement.kind { // Storage statements are expanded in run_pass. - return; + StatementKind::StorageLive(..) + | StatementKind::StorageDead(..) + | StatementKind::Deinit(..) => return, + _ => self.super_statement(statement, location), } - self.super_statement(statement, location) - } - - fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { - // Drop implicitly calls `drop_in_place`, which takes a `&mut`. - // This implies that `Drop` implicitly takes the address of the place. - if let TerminatorKind::Drop { place, .. } - | TerminatorKind::DropAndReplace { place, .. } = terminator.kind - { - if !place.is_indirect() { - // Raw pointers may be used to access anything inside the enclosing place. - self.set.insert(place.local); - return; - } - } - self.super_terminator(terminator, location); } // We ignore anything that happens in debuginfo, since we expand it using @@ -103,7 +112,30 @@ fn escaping_locals(body: &Body<'_>) -> BitSet { #[derive(Default, Debug)] struct ReplacementMap<'tcx> { - fields: FxIndexMap, Local>, + /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage + /// and deinit statement and debuginfo. + fragments: IndexVec, Local)>>>>, +} + +impl<'tcx> ReplacementMap<'tcx> { + fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option> { + let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else { return None; }; + let fields = self.fragments[place.local].as_ref()?; + let (_, new_local) = fields[f]?; + Some(Place { local: new_local, projection: tcx.mk_place_elems(&rest) }) + } + + fn place_fragments( + &self, + place: Place<'tcx>, + ) -> Option, Local)> + '_> { + let local = place.as_local()?; + let fields = self.fragments[local].as_ref()?; + Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| { + let (ty, local) = opt_ty_local?; + Some((field, ty, local)) + })) + } } /// Compute the replacement of flattened places into locals. 
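At the source level, the flattening computed here amounts to splitting each eligible aggregate local into one fresh local per field, recorded in `ReplacementMap::fragments`, so that later passes such as ConstProp and CopyProp can treat every field independently. A rough sketch of the effect under that reading (the `Pair` type and the locals are illustrative only):

    struct Pair {
        a: u32,
        b: u64,
    }

    fn use_fields(p: Pair) -> u64 {
        // `q` never escapes, so SROA can rewrite it as if it were two
        // independent locals, one for `q.a` and one for `q.b`; no
        // `Pair`-typed storage for `q` remains afterwards.
        let q = Pair { a: 1, b: 2 };
        q.b + u64::from(p.a)
    }

    fn main() {
        assert_eq!(use_fields(Pair { a: 3, b: 4 }), 5);
    }
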
@@ -115,53 +147,25 @@ fn compute_flattening<'tcx>( body: &mut Body<'tcx>, escaping: BitSet, ) -> ReplacementMap<'tcx> { - let mut visitor = PreFlattenVisitor { - tcx, - escaping, - local_decls: &mut body.local_decls, - map: Default::default(), - }; - for (block, bbdata) in body.basic_blocks.iter_enumerated() { - visitor.visit_basic_block_data(block, bbdata); - } - return visitor.map; - - struct PreFlattenVisitor<'tcx, 'll> { - tcx: TyCtxt<'tcx>, - local_decls: &'ll mut LocalDecls<'tcx>, - escaping: BitSet, - map: ReplacementMap<'tcx>, - } - - impl<'tcx, 'll> PreFlattenVisitor<'tcx, 'll> { - fn create_place(&mut self, place: PlaceRef<'tcx>) { - if self.escaping.contains(place.local) { - return; - } + let mut fragments = IndexVec::from_elem(None, &body.local_decls); - match self.map.fields.entry(place) { - IndexEntry::Occupied(_) => {} - IndexEntry::Vacant(v) => { - let ty = place.ty(&*self.local_decls, self.tcx).ty; - let local = self.local_decls.push(LocalDecl { - ty, - user_ty: None, - ..self.local_decls[place.local].clone() - }); - v.insert(local); - } - } - } - } - - impl<'tcx, 'll> Visitor<'tcx> for PreFlattenVisitor<'tcx, 'll> { - fn visit_place(&mut self, place: &Place<'tcx>, _: PlaceContext, _: Location) { - if let &[PlaceElem::Field(..), ..] = &place.projection[..] { - let pr = PlaceRef { local: place.local, projection: &place.projection[..1] }; - self.create_place(pr) - } + for local in body.local_decls.indices() { + if escaping.contains(local) { + continue; } + let decl = body.local_decls[local].clone(); + let ty = decl.ty; + iter_fields(ty, tcx, |variant, field, field_ty| { + if variant.is_some() { + // Downcasts are currently not supported. + return; + }; + let new_local = + body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() }); + fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local)); + }); } + ReplacementMap { fragments } } /// Perform the replacement computed by `compute_flattening`. @@ -169,29 +173,24 @@ fn replace_flattened_locals<'tcx>( tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, replacements: ReplacementMap<'tcx>, -) { - let mut all_dead_locals = BitSet::new_empty(body.local_decls.len()); - for p in replacements.fields.keys() { - all_dead_locals.insert(p.local); +) -> BitSet { + let mut all_dead_locals = BitSet::new_empty(replacements.fragments.len()); + for (local, replacements) in replacements.fragments.iter_enumerated() { + if replacements.is_some() { + all_dead_locals.insert(local); + } } debug!(?all_dead_locals); if all_dead_locals.is_empty() { - return; + return all_dead_locals; } - let mut fragments = IndexVec::new(); - for (k, v) in &replacements.fields { - fragments.ensure_contains_elem(k.local, || Vec::new()); - fragments[k.local].push((k.projection, *v)); - } - debug!(?fragments); - let mut visitor = ReplacementVisitor { tcx, local_decls: &body.local_decls, - replacements, + replacements: &replacements, all_dead_locals, - fragments, + patch: MirPatch::new(body), }; for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() { visitor.visit_basic_block_data(bb, data); @@ -205,6 +204,9 @@ fn replace_flattened_locals<'tcx>( for var_debug_info in &mut body.var_debug_info { visitor.visit_var_debug_info(var_debug_info); } + let ReplacementVisitor { patch, all_dead_locals, .. 
} = visitor; + patch.apply(body); + all_dead_locals } struct ReplacementVisitor<'tcx, 'll> { @@ -212,40 +214,23 @@ struct ReplacementVisitor<'tcx, 'll> { /// This is only used to compute the type for `VarDebugInfoContents::Composite`. local_decls: &'ll LocalDecls<'tcx>, /// Work to do. - replacements: ReplacementMap<'tcx>, + replacements: &'ll ReplacementMap<'tcx>, /// This is used to check that we are not leaving references to replaced locals behind. all_dead_locals: BitSet, - /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage - /// and deinit statement and debuginfo. - fragments: IndexVec], Local)>>, + patch: MirPatch<'tcx>, } -impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> { - fn gather_debug_info_fragments( - &self, - place: PlaceRef<'tcx>, - ) -> Vec> { +impl<'tcx> ReplacementVisitor<'tcx, '_> { + fn gather_debug_info_fragments(&self, local: Local) -> Option>> { let mut fragments = Vec::new(); - let parts = &self.fragments[place.local]; - for (proj, replacement_local) in parts { - if proj.starts_with(place.projection) { - fragments.push(VarDebugInfoFragment { - projection: proj[place.projection.len()..].to_vec(), - contents: Place::from(*replacement_local), - }); - } - } - fragments - } - - fn replace_place(&self, place: PlaceRef<'tcx>) -> Option> { - if let &[PlaceElem::Field(..), ref rest @ ..] = place.projection { - let pr = PlaceRef { local: place.local, projection: &place.projection[..1] }; - let local = self.replacements.fields.get(&pr)?; - Some(Place { local: *local, projection: self.tcx.intern_place_elems(&rest) }) - } else { - None + let parts = self.replacements.place_fragments(local.into())?; + for (field, ty, replacement_local) in parts { + fragments.push(VarDebugInfoFragment { + projection: vec![PlaceElem::Field(field, ty)], + contents: Place::from(replacement_local), + }); } + Some(fragments) } } @@ -254,94 +239,188 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { self.tcx } - fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { - if let StatementKind::StorageLive(..) - | StatementKind::StorageDead(..) - | StatementKind::Deinit(..) = statement.kind - { - // Storage statements are expanded in run_pass. - return; - } - self.super_statement(statement, location) - } - fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { - if let Some(repl) = self.replace_place(place.as_ref()) { + if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) { *place = repl } else { self.super_place(place, context, location) } } + #[instrument(level = "trace", skip(self))] + fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { + match statement.kind { + // Duplicate storage and deinit statements, as they pretty much apply to all fields. 
+ StatementKind::StorageLive(l) => { + if let Some(final_locals) = self.replacements.place_fragments(l.into()) { + for (_, _, fl) in final_locals { + self.patch.add_statement(location, StatementKind::StorageLive(fl)); + } + statement.make_nop(); + } + return; + } + StatementKind::StorageDead(l) => { + if let Some(final_locals) = self.replacements.place_fragments(l.into()) { + for (_, _, fl) in final_locals { + self.patch.add_statement(location, StatementKind::StorageDead(fl)); + } + statement.make_nop(); + } + return; + } + StatementKind::Deinit(box place) => { + if let Some(final_locals) = self.replacements.place_fragments(place) { + for (_, _, fl) in final_locals { + self.patch + .add_statement(location, StatementKind::Deinit(Box::new(fl.into()))); + } + statement.make_nop(); + return; + } + } + + // We have `a = Struct { 0: x, 1: y, .. }`. + // We replace it by + // ``` + // a_0 = x + // a_1 = y + // ... + // ``` + StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => { + if let Some(local) = place.as_local() + && let Some(final_locals) = &self.replacements.fragments[local] + { + // This is ok as we delete the statement later. + let operands = std::mem::take(operands); + for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) { + if let Some((_, new_local)) = opt_ty_local { + // Replace mentions of SROA'd locals that appear in the operand. + self.visit_operand(&mut operand, location); + + let rvalue = Rvalue::Use(operand); + self.patch.add_statement( + location, + StatementKind::Assign(Box::new((new_local.into(), rvalue))), + ); + } + } + statement.make_nop(); + return; + } + } + + // We have `a = some constant` + // We add the projections. + // ``` + // a_0 = a.0 + // a_1 = a.1 + // ... + // ``` + // ConstProp will pick up the pieces and replace them by actual constants. + StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => { + if let Some(final_locals) = self.replacements.place_fragments(place) { + // Put the deaggregated statements *after* the original one. + let location = location.successor_within_block(); + for (field, ty, new_local) in final_locals { + let rplace = self.tcx.mk_place_field(place, field, ty); + let rvalue = Rvalue::Use(Operand::Move(rplace)); + self.patch.add_statement( + location, + StatementKind::Assign(Box::new((new_local.into(), rvalue))), + ); + } + // We still need `place.local` to exist, so don't make it nop. + return; + } + } + + // We have `a = move? place` + // We replace it by + // ``` + // a_0 = move? place.0 + // a_1 = move? place.1 + // ... 
+ // ``` + StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => { + let (rplace, copy) = match *op { + Operand::Copy(rplace) => (rplace, true), + Operand::Move(rplace) => (rplace, false), + Operand::Constant(_) => bug!(), + }; + if let Some(final_locals) = self.replacements.place_fragments(lhs) { + for (field, ty, new_local) in final_locals { + let rplace = self.tcx.mk_place_field(rplace, field, ty); + debug!(?rplace); + let rplace = self + .replacements + .replace_place(self.tcx, rplace.as_ref()) + .unwrap_or(rplace); + debug!(?rplace); + let rvalue = if copy { + Rvalue::Use(Operand::Copy(rplace)) + } else { + Rvalue::Use(Operand::Move(rplace)) + }; + self.patch.add_statement( + location, + StatementKind::Assign(Box::new((new_local.into(), rvalue))), + ); + } + statement.make_nop(); + return; + } + } + + _ => {} + } + self.super_statement(statement, location) + } + + #[instrument(level = "trace", skip(self))] fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) { match &mut var_debug_info.value { VarDebugInfoContents::Place(ref mut place) => { - if let Some(repl) = self.replace_place(place.as_ref()) { + if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) { *place = repl; - } else if self.all_dead_locals.contains(place.local) { + } else if let Some(local) = place.as_local() + && let Some(fragments) = self.gather_debug_info_fragments(local) + { let ty = place.ty(self.local_decls, self.tcx).ty; - let fragments = self.gather_debug_info_fragments(place.as_ref()); var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments }; } } VarDebugInfoContents::Composite { ty: _, ref mut fragments } => { let mut new_fragments = Vec::new(); + debug!(?fragments); fragments .drain_filter(|fragment| { - if let Some(repl) = self.replace_place(fragment.contents.as_ref()) { + if let Some(repl) = + self.replacements.replace_place(self.tcx, fragment.contents.as_ref()) + { fragment.contents = repl; - true - } else if self.all_dead_locals.contains(fragment.contents.local) { - let frg = self.gather_debug_info_fragments(fragment.contents.as_ref()); + false + } else if let Some(local) = fragment.contents.as_local() + && let Some(frg) = self.gather_debug_info_fragments(local) + { new_fragments.extend(frg.into_iter().map(|mut f| { f.projection.splice(0..0, fragment.projection.iter().copied()); f })); - false - } else { true + } else { + false } }) .for_each(drop); + debug!(?fragments); + debug!(?new_fragments); fragments.extend(new_fragments); } VarDebugInfoContents::Const(_) => {} } } - fn visit_basic_block_data(&mut self, bb: BasicBlock, bbdata: &mut BasicBlockData<'tcx>) { - self.super_basic_block_data(bb, bbdata); - - #[derive(Debug)] - enum Stmt { - StorageLive, - StorageDead, - Deinit, - } - - bbdata.expand_statements(|stmt| { - let source_info = stmt.source_info; - let (stmt, origin_local) = match &stmt.kind { - StatementKind::StorageLive(l) => (Stmt::StorageLive, *l), - StatementKind::StorageDead(l) => (Stmt::StorageDead, *l), - StatementKind::Deinit(p) if let Some(l) = p.as_local() => (Stmt::Deinit, l), - _ => return None, - }; - if !self.all_dead_locals.contains(origin_local) { - return None; - } - let final_locals = self.fragments.get(origin_local)?; - Some(final_locals.iter().map(move |&(_, l)| { - let kind = match stmt { - Stmt::StorageLive => StatementKind::StorageLive(l), - Stmt::StorageDead => StatementKind::StorageDead(l), - Stmt::Deinit => StatementKind::Deinit(Box::new(l.into())), - }; - Statement { source_info, kind } - })) - }); - } - 
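+    // By this point every mention of a replaced local must have been rewritten by
+    // `visit_place` or by one of the statement rewrites above; `visit_local` below
+    // only asserts that no dead local survives.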
     fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
         assert!(!self.all_dead_locals.contains(*local));
     }
diff --git a/compiler/rustc_mir_transform/src/ssa.rs b/compiler/rustc_mir_transform/src/ssa.rs
new file mode 100644
index 000000000..c1e7f62de
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/ssa.rs
@@ -0,0 +1,257 @@
+use either::Either;
+use rustc_data_structures::graph::dominators::Dominators;
+use rustc_index::bit_set::BitSet;
+use rustc_index::vec::IndexVec;
+use rustc_middle::middle::resolve_bound_vars::Set1;
+use rustc_middle::mir::visit::*;
+use rustc_middle::mir::*;
+use rustc_middle::ty::{ParamEnv, TyCtxt};
+
+#[derive(Debug)]
+pub struct SsaLocals {
+    /// Assignments to each local. This defines whether the local is SSA.
+    assignments: IndexVec<Local, Set1<LocationExtended>>,
+    /// We visit the body in reverse postorder, to ensure each local is assigned before it is used.
+    /// We remember the order in which we saw the assignments to compute the SSA values in a single
+    /// pass.
+    assignment_order: Vec<Local>,
+    /// Copy equivalence classes between locals. See `copy_classes` for documentation.
+    copy_classes: IndexVec<Local, Local>,
+}
+
+/// We often encounter MIR bodies with 1 or 2 basic blocks. In those cases, it's unnecessary to
+/// actually compute dominators: we can just compare block indices, because bb0 is always the
+/// first block, and in any body all other blocks are dominated by bb0.
+struct SmallDominators {
+    inner: Option<Dominators<BasicBlock>>,
+}
+
+trait DomExt {
+    fn dominates(self, _other: Self, dominators: &SmallDominators) -> bool;
+}
+
+impl DomExt for Location {
+    fn dominates(self, other: Location, dominators: &SmallDominators) -> bool {
+        if self.block == other.block {
+            self.statement_index <= other.statement_index
+        } else {
+            dominators.dominates(self.block, other.block)
+        }
+    }
+}
+
+impl SmallDominators {
+    fn dominates(&self, dom: BasicBlock, node: BasicBlock) -> bool {
+        if let Some(inner) = &self.inner { inner.dominates(dom, node) } else { dom < node }
+    }
+}
+
+impl SsaLocals {
+    pub fn new<'tcx>(
+        tcx: TyCtxt<'tcx>,
+        param_env: ParamEnv<'tcx>,
+        body: &Body<'tcx>,
+        borrowed_locals: &BitSet<Local>,
+    ) -> SsaLocals {
+        let assignment_order = Vec::new();
+
+        let assignments = IndexVec::from_elem(Set1::Empty, &body.local_decls);
+        let dominators =
+            if body.basic_blocks.len() > 2 { Some(body.basic_blocks.dominators()) } else { None };
+        let dominators = SmallDominators { inner: dominators };
+        let mut visitor = SsaVisitor { assignments, assignment_order, dominators };
+
+        for (local, decl) in body.local_decls.iter_enumerated() {
+            if matches!(body.local_kind(local), LocalKind::Arg) {
+                visitor.assignments[local] = Set1::One(LocationExtended::Arg);
+            }
+            if borrowed_locals.contains(local) && !decl.ty.is_freeze(tcx, param_env) {
+                visitor.assignments[local] = Set1::Many;
+            }
+        }
+
+        if body.basic_blocks.len() > 2 {
+            for (bb, data) in traversal::reverse_postorder(body) {
+                visitor.visit_basic_block_data(bb, data);
+            }
+        } else {
+            for (bb, data) in body.basic_blocks.iter_enumerated() {
+                visitor.visit_basic_block_data(bb, data);
+            }
+        }
+
+        for var_debug_info in &body.var_debug_info {
+            visitor.visit_var_debug_info(var_debug_info);
+        }
+
+        debug!(?visitor.assignments);
+
+        visitor
+            .assignment_order
+            .retain(|&local| matches!(visitor.assignments[local], Set1::One(_)));
+        debug!(?visitor.assignment_order);
+
+        let copy_classes = compute_copy_classes(&visitor, body);
+
+        SsaLocals {
+            assignments: visitor.assignments,
+            assignment_order: visitor.assignment_order,
+            copy_classes,
+        }
+    }
+
+    pub fn is_ssa(&self, local: Local) -> bool {
+        matches!(self.assignments[local], Set1::One(_))
+    }
+
+    pub fn assignments<'a, 'tcx>(
+        &'a self,
+        body: &'a Body<'tcx>,
+    ) -> impl Iterator<Item = (Local, &'a Rvalue<'tcx>)> + 'a {
+        self.assignment_order.iter().filter_map(|&local| {
+            if let Set1::One(LocationExtended::Plain(loc)) = self.assignments[local] {
+                // `loc` must point to a direct assignment to `local`.
+                let Either::Left(stmt) = body.stmt_at(loc) else { bug!() };
+                let Some((target, rvalue)) = stmt.kind.as_assign() else { bug!() };
+                assert_eq!(target.as_local(), Some(local));
+                Some((local, rvalue))
+            } else {
+                None
+            }
+        })
+    }
+
+    /// Compute the equivalence classes for locals, based on copy statements.
+    ///
+    /// The returned vector maps each local to the one it copies. In the following case:
+    ///     _a = &mut _0
+    ///     _b = move? _a
+    ///     _c = move? _a
+    ///     _d = move? _c
+    /// We return the mapping
+    ///     _a => _a // not a copy, so represented by itself
+    ///     _b => _a
+    ///     _c => _a
+    ///     _d => _a // transitively through _c
+    ///
+    /// Exception: we do not see through the return place, as it cannot be substituted.
+    pub fn copy_classes(&self) -> &IndexVec<Local, Local> {
+        &self.copy_classes
+    }
+
+    /// Make a property uniform on a copy equivalence class by removing elements.
+    pub fn meet_copy_equivalence(&self, property: &mut BitSet<Local>) {
+        // Consolidate: keep a local in `property` iff every local in its copy class is.
+        //
+        // `copy_classes` defines equivalence classes between locals. The `local`s that recursively
+        // move/copy the same local all have the same `head`.
+        for (local, &head) in self.copy_classes.iter_enumerated() {
+            // If any copy lacks `property`, remove it from the head.
+            if !property.contains(local) {
+                property.remove(head);
+            }
+        }
+        for (local, &head) in self.copy_classes.iter_enumerated() {
+            // If the head lacks `property`, then no local in its class has it.
+            if !property.contains(head) {
+                property.remove(local);
+            }
+        }
+
+        // Verify that we correctly computed equivalence classes.
+        #[cfg(debug_assertions)]
+        for (local, &head) in self.copy_classes.iter_enumerated() {
+            assert_eq!(property.contains(local), property.contains(head));
+        }
+    }
+}
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+enum LocationExtended {
+    Plain(Location),
+    Arg,
+}
+
+struct SsaVisitor {
+    dominators: SmallDominators,
+    assignments: IndexVec<Local, Set1<LocationExtended>>,
+    assignment_order: Vec<Local>,
+}
+
+impl<'tcx> Visitor<'tcx> for SsaVisitor {
+    fn visit_local(&mut self, local: Local, ctxt: PlaceContext, loc: Location) {
+        match ctxt {
+            PlaceContext::MutatingUse(MutatingUseContext::Store) => {
+                self.assignments[local].insert(LocationExtended::Plain(loc));
+                self.assignment_order.push(local);
+            }
+            // Anything can happen with raw pointers, so remove them.
+            PlaceContext::NonMutatingUse(NonMutatingUseContext::AddressOf)
+            | PlaceContext::MutatingUse(_) => self.assignments[local] = Set1::Many,
+            // Immutable borrows are taken into account in `SsaLocals::new` by
+            // removing non-freeze locals.
+            PlaceContext::NonMutatingUse(_) => {
+                let set = &mut self.assignments[local];
+                let assign_dominates = match *set {
+                    Set1::Empty | Set1::Many => false,
+                    Set1::One(LocationExtended::Arg) => true,
+                    Set1::One(LocationExtended::Plain(assign)) => {
+                        assign.successor_within_block().dominates(loc, &self.dominators)
+                    }
+                };
+                // We are visiting a use that is not dominated by an assignment.
+                // Either there is a cycle involved, or we are reading from an uninitialized local.
+                // Bail out.
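+                // Illustrative example: in `_2 = Add(_2, const 1)`, the use of `_2`
+                // happens at the same location as its assignment, so the assignment's
+                // successor does not dominate it and `_2` degrades to `Set1::Many`.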
+                if !assign_dominates {
+                    *set = Set1::Many;
+                }
+            }
+            PlaceContext::NonUse(_) => {}
+        }
+    }
+}
+
+#[instrument(level = "trace", skip(ssa, body))]
+fn compute_copy_classes(ssa: &SsaVisitor, body: &Body<'_>) -> IndexVec<Local, Local> {
+    let mut copies = IndexVec::from_fn_n(|l| l, body.local_decls.len());
+
+    for &local in &ssa.assignment_order {
+        debug!(?local);
+
+        if local == RETURN_PLACE {
+            // `_0` is special, we cannot rename it.
+            continue;
+        }
+
+        // Skip locals that are not SSA: we don't know their value.
+        debug!(assignments = ?ssa.assignments[local]);
+        let Set1::One(LocationExtended::Plain(loc)) = ssa.assignments[local] else { continue };
+
+        // `loc` must point to a direct assignment to `local`.
+        let Either::Left(stmt) = body.stmt_at(loc) else { bug!() };
+        let Some((_target, rvalue)) = stmt.kind.as_assign() else { bug!() };
+        assert_eq!(_target.as_local(), Some(local));
+
+        let (Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) | Rvalue::CopyForDeref(place))
+            = rvalue
+        else { continue };
+
+        let Some(rhs) = place.as_local() else { continue };
+        let Set1::One(_) = ssa.assignments[rhs] else { continue };
+
+        // We visit in `assignment_order`, i.e. reverse post-order, so `rhs` has been
+        // visited before `local`, and we just have to copy the representative local.
+        copies[local] = copies[rhs];
+    }
+
+    debug!(?copies);
+
+    // Invariant: `copies` must point to the head of an equivalence class.
+    #[cfg(debug_assertions)]
+    for &head in copies.iter() {
+        assert_eq!(copies[head], head);
+    }
+
+    copies
+}
-- cgit v1.2.3
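The two loops in `meet_copy_equivalence` are a two-pass "meet" over equivalence classes: first a head loses the property if any member lacks it, then every member inherits the head's (possibly cleared) state. A minimal standalone sketch of that idea, with `usize` and `Vec<bool>` as illustrative stand-ins for `Local` and `BitSet<Local>` (none of these names come from the compiler):

```rust
/// `copy_classes[l]` maps each local to the head of its equivalence class.
/// `property[l]` says whether local `l` currently has the property.
fn meet(copy_classes: &[usize], property: &mut [bool]) {
    // Pass 1: a head keeps the property only if every member has it.
    for (local, &head) in copy_classes.iter().enumerate() {
        if !property[local] {
            property[head] = false;
        }
    }
    // Pass 2: propagate the head's state back to every member.
    for (local, &head) in copy_classes.iter().enumerate() {
        if !property[head] {
            property[local] = false;
        }
    }
}

fn main() {
    // Classes: {0} and {1, 2, 3} with head 1 (mirroring `_b`, `_c`, `_d` => `_a`).
    let copy_classes = vec![0, 1, 1, 1];
    let mut property = vec![true, true, false, true];
    meet(&copy_classes, &mut property);
    // Local 2 lacked the property, so the whole class {1, 2, 3} loses it.
    assert_eq!(property, vec![true, false, false, false]);
}
```

Because `copy_classes` always maps a local to the head of its class (exactly the invariant the debug assertion in `compute_copy_classes` checks), one pass in each direction reaches the fixpoint; no iteration is needed.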