summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_mir_transform
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform')
-rw-r--r--compiler/rustc_mir_transform/src/abort_unwinding_calls.rs2
-rw-r--r--compiler/rustc_mir_transform/src/check_const_item_mutation.rs2
-rw-r--r--compiler/rustc_mir_transform/src/check_packed_ref.rs88
-rw-r--r--compiler/rustc_mir_transform/src/check_unsafety.rs4
-rw-r--r--compiler/rustc_mir_transform/src/const_goto.rs9
-rw-r--r--compiler/rustc_mir_transform/src/const_prop.rs351
-rw-r--r--compiler/rustc_mir_transform/src/const_prop_lint.rs18
-rw-r--r--compiler/rustc_mir_transform/src/copy_prop.rs182
-rw-r--r--compiler/rustc_mir_transform/src/coverage/counters.rs6
-rw-r--r--compiler/rustc_mir_transform/src/coverage/debug.rs5
-rw-r--r--compiler/rustc_mir_transform/src/coverage/graph.rs14
-rw-r--r--compiler/rustc_mir_transform/src/coverage/mod.rs3
-rw-r--r--compiler/rustc_mir_transform/src/coverage/spans.rs18
-rw-r--r--compiler/rustc_mir_transform/src/ctfe_limit.rs58
-rw-r--r--compiler/rustc_mir_transform/src/dataflow_const_prop.rs150
-rw-r--r--compiler/rustc_mir_transform/src/dead_store_elimination.rs1
-rw-r--r--compiler/rustc_mir_transform/src/deaggregator.rs45
-rw-r--r--compiler/rustc_mir_transform/src/deduce_param_attrs.rs2
-rw-r--r--compiler/rustc_mir_transform/src/dest_prop.rs12
-rw-r--r--compiler/rustc_mir_transform/src/elaborate_box_derefs.rs10
-rw-r--r--compiler/rustc_mir_transform/src/elaborate_drops.rs86
-rw-r--r--compiler/rustc_mir_transform/src/ffi_unwind_calls.rs2
-rw-r--r--compiler/rustc_mir_transform/src/function_item_references.rs13
-rw-r--r--compiler/rustc_mir_transform/src/generator.rs389
-rw-r--r--compiler/rustc_mir_transform/src/inline.rs12
-rw-r--r--compiler/rustc_mir_transform/src/inline/cycle.rs2
-rw-r--r--compiler/rustc_mir_transform/src/instcombine.rs86
-rw-r--r--compiler/rustc_mir_transform/src/large_enums.rs294
-rw-r--r--compiler/rustc_mir_transform/src/lib.rs32
-rw-r--r--compiler/rustc_mir_transform/src/lower_intrinsics.rs26
-rw-r--r--compiler/rustc_mir_transform/src/lower_slice_len.rs5
-rw-r--r--compiler/rustc_mir_transform/src/normalize_array_len.rs322
-rw-r--r--compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs1
-rw-r--r--compiler/rustc_mir_transform/src/remove_zsts.rs2
-rw-r--r--compiler/rustc_mir_transform/src/separate_const_switch.rs2
-rw-r--r--compiler/rustc_mir_transform/src/shim.rs49
-rw-r--r--compiler/rustc_mir_transform/src/simplify.rs20
-rw-r--r--compiler/rustc_mir_transform/src/simplify_try.rs822
-rw-r--r--compiler/rustc_mir_transform/src/sroa.rs439
-rw-r--r--compiler/rustc_mir_transform/src/ssa.rs257
40 files changed, 2041 insertions, 1800 deletions
diff --git a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs
index d8f85d2e3..9b4b72070 100644
--- a/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs
+++ b/compiler/rustc_mir_transform/src/abort_unwinding_calls.rs
@@ -41,7 +41,7 @@ impl<'tcx> MirPass<'tcx> for AbortUnwindingCalls {
//
// Here we test for this function itself whether its ABI allows
// unwinding or not.
- let body_ty = tcx.type_of(def_id);
+ let body_ty = tcx.type_of(def_id).skip_binder();
let body_abi = match body_ty.kind() {
ty::FnDef(..) => body_ty.fn_sig(tcx).abi(),
ty::Closure(..) => Abi::RustCall,
diff --git a/compiler/rustc_mir_transform/src/check_const_item_mutation.rs b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs
index fa5f392fa..536745d2c 100644
--- a/compiler/rustc_mir_transform/src/check_const_item_mutation.rs
+++ b/compiler/rustc_mir_transform/src/check_const_item_mutation.rs
@@ -74,7 +74,7 @@ impl<'tcx> ConstMutationChecker<'_, 'tcx> {
//
// `unsafe { *FOO = 0; *BAR.field = 1; }`
// `unsafe { &mut *FOO }`
- // `unsafe { (*ARRAY)[0] = val; }
+ // `unsafe { (*ARRAY)[0] = val; }`
if !place.projection.iter().any(|p| matches!(p, PlaceElem::Deref)) {
let source_info = self.body.source_info(location);
let lint_root = self.body.source_scopes[source_info.scope]
diff --git a/compiler/rustc_mir_transform/src/check_packed_ref.rs b/compiler/rustc_mir_transform/src/check_packed_ref.rs
index 51abcf511..f5f1c1010 100644
--- a/compiler/rustc_mir_transform/src/check_packed_ref.rs
+++ b/compiler/rustc_mir_transform/src/check_packed_ref.rs
@@ -1,17 +1,11 @@
-use rustc_hir::def_id::LocalDefId;
+use rustc_errors::struct_span_err;
use rustc_middle::mir::visit::{PlaceContext, Visitor};
use rustc_middle::mir::*;
-use rustc_middle::ty::query::Providers;
use rustc_middle::ty::{self, TyCtxt};
-use rustc_session::lint::builtin::UNALIGNED_REFERENCES;
use crate::util;
use crate::MirLint;
-pub(crate) fn provide(providers: &mut Providers) {
- *providers = Providers { unsafe_derive_on_repr_packed, ..*providers };
-}
-
pub struct CheckPackedRef;
impl<'tcx> MirLint<'tcx> for CheckPackedRef {
@@ -30,32 +24,6 @@ struct PackedRefChecker<'a, 'tcx> {
source_info: SourceInfo,
}
-fn unsafe_derive_on_repr_packed(tcx: TyCtxt<'_>, def_id: LocalDefId) {
- let lint_hir_id = tcx.hir().local_def_id_to_hir_id(def_id);
-
- // FIXME: when we make this a hard error, this should have its
- // own error code.
-
- let extra = if tcx.generics_of(def_id).own_requires_monomorphization() {
- "with type or const parameters"
- } else {
- "that does not derive `Copy`"
- };
- let message = format!(
- "`{}` can't be derived on this `#[repr(packed)]` struct {}",
- tcx.item_name(tcx.trait_id_of_impl(def_id.to_def_id()).expect("derived trait name")),
- extra
- );
-
- tcx.struct_span_lint_hir(
- UNALIGNED_REFERENCES,
- lint_hir_id,
- tcx.def_span(def_id),
- message,
- |lint| lint,
- );
-}
-
impl<'tcx> Visitor<'tcx> for PackedRefChecker<'_, 'tcx> {
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
// Make sure we know where in the MIR we are.
@@ -73,40 +41,30 @@ impl<'tcx> Visitor<'tcx> for PackedRefChecker<'_, 'tcx> {
if context.is_borrow() {
if util::is_disaligned(self.tcx, self.body, self.param_env, *place) {
let def_id = self.body.source.instance.def_id();
- if let Some(impl_def_id) = self
- .tcx
- .impl_of_method(def_id)
- .filter(|&def_id| self.tcx.is_builtin_derive(def_id))
+ if let Some(impl_def_id) = self.tcx.impl_of_method(def_id)
+ && self.tcx.is_builtin_derived(impl_def_id)
{
- // If a method is defined in the local crate,
- // the impl containing that method should also be.
- self.tcx.ensure().unsafe_derive_on_repr_packed(impl_def_id.expect_local());
+ // If we ever reach here it means that the generated derive
+ // code is somehow doing an unaligned reference, which it
+ // shouldn't do.
+ span_bug!(self.source_info.span, "builtin derive created an unaligned reference");
} else {
- let source_info = self.source_info;
- let lint_root = self.body.source_scopes[source_info.scope]
- .local_data
- .as_ref()
- .assert_crate_local()
- .lint_root;
- self.tcx.struct_span_lint_hir(
- UNALIGNED_REFERENCES,
- lint_root,
- source_info.span,
- "reference to packed field is unaligned",
- |lint| {
- lint
- .note(
- "fields of packed structs are not properly aligned, and creating \
- a misaligned reference is undefined behavior (even if that \
- reference is never dereferenced)",
- )
- .help(
- "copy the field contents to a local variable, or replace the \
- reference with a raw pointer and use `read_unaligned`/`write_unaligned` \
- (loads and stores via `*p` must be properly aligned even when using raw pointers)"
- )
- },
- );
+ struct_span_err!(
+ self.tcx.sess,
+ self.source_info.span,
+ E0793,
+ "reference to packed field is unaligned"
+ )
+ .note(
+ "fields of packed structs are not properly aligned, and creating \
+ a misaligned reference is undefined behavior (even if that \
+ reference is never dereferenced)",
+ ).help(
+ "copy the field contents to a local variable, or replace the \
+ reference with a raw pointer and use `read_unaligned`/`write_unaligned` \
+ (loads and stores via `*p` must be properly aligned even when using raw pointers)"
+ )
+ .emit();
}
}
}
diff --git a/compiler/rustc_mir_transform/src/check_unsafety.rs b/compiler/rustc_mir_transform/src/check_unsafety.rs
index adf6ae4c7..d00ee1f4b 100644
--- a/compiler/rustc_mir_transform/src/check_unsafety.rs
+++ b/compiler/rustc_mir_transform/src/check_unsafety.rs
@@ -104,6 +104,7 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> {
| StatementKind::AscribeUserType(..)
| StatementKind::Coverage(..)
| StatementKind::Intrinsic(..)
+ | StatementKind::ConstEvalCounter
| StatementKind::Nop => {
// safe (at least as emitted during MIR construction)
}
@@ -125,6 +126,7 @@ impl<'tcx> Visitor<'tcx> for UnsafetyChecker<'_, 'tcx> {
}
}
&AggregateKind::Closure(def_id, _) | &AggregateKind::Generator(def_id, _, _) => {
+ let def_id = def_id.expect_local();
let UnsafetyCheckResult { violations, used_unsafe_blocks, .. } =
self.tcx.unsafety_check_result(def_id);
self.register_violations(violations, used_unsafe_blocks.iter().copied());
@@ -445,7 +447,7 @@ impl<'tcx> intravisit::Visitor<'tcx> for UnusedUnsafeVisitor<'_, 'tcx> {
_fd: &'tcx hir::FnDecl<'tcx>,
b: hir::BodyId,
_s: rustc_span::Span,
- _id: HirId,
+ _id: LocalDefId,
) {
if matches!(fk, intravisit::FnKind::Closure) {
self.visit_body(self.tcx.hir().body(b))
diff --git a/compiler/rustc_mir_transform/src/const_goto.rs b/compiler/rustc_mir_transform/src/const_goto.rs
index 40eefda4f..da101ca7a 100644
--- a/compiler/rustc_mir_transform/src/const_goto.rs
+++ b/compiler/rustc_mir_transform/src/const_goto.rs
@@ -57,6 +57,15 @@ impl<'tcx> MirPass<'tcx> for ConstGoto {
}
impl<'tcx> Visitor<'tcx> for ConstGotoOptimizationFinder<'_, 'tcx> {
+ fn visit_basic_block_data(&mut self, block: BasicBlock, data: &BasicBlockData<'tcx>) {
+ if data.is_cleanup {
+ // Because of the restrictions around control flow in cleanup blocks, we don't perform
+ // this optimization at all in such blocks.
+ return;
+ }
+ self.super_basic_block_data(block, data);
+ }
+
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
let _: Option<_> = try {
let target = terminator.kind.as_goto()?;
diff --git a/compiler/rustc_mir_transform/src/const_prop.rs b/compiler/rustc_mir_transform/src/const_prop.rs
index 5c45abc5a..6b2eefce2 100644
--- a/compiler/rustc_mir_transform/src/const_prop.rs
+++ b/compiler/rustc_mir_transform/src/const_prop.rs
@@ -5,7 +5,6 @@ use std::cell::Cell;
use either::Right;
-use rustc_ast::Mutability;
use rustc_const_eval::const_eval::CheckAlignment;
use rustc_data_structures::fx::FxHashSet;
use rustc_hir::def::DefKind;
@@ -14,14 +13,10 @@ use rustc_index::vec::IndexVec;
use rustc_middle::mir::visit::{
MutVisitor, MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor,
};
-use rustc_middle::mir::{
- BasicBlock, BinOp, Body, Constant, ConstantKind, Local, LocalDecl, LocalKind, Location,
- Operand, Place, Rvalue, SourceInfo, Statement, StatementKind, Terminator, TerminatorKind, UnOp,
- RETURN_PLACE,
-};
+use rustc_middle::mir::*;
use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout};
use rustc_middle::ty::InternalSubsts;
-use rustc_middle::ty::{self, ConstKind, Instance, ParamEnv, Ty, TyCtxt, TypeVisitable};
+use rustc_middle::ty::{self, ConstKind, Instance, ParamEnv, Ty, TyCtxt, TypeVisitableExt};
use rustc_span::{def_id::DefId, Span};
use rustc_target::abi::{self, Align, HasDataLayout, Size, TargetDataLayout};
use rustc_target::spec::abi::Abi as CallAbi;
@@ -83,7 +78,7 @@ impl<'tcx> MirPass<'tcx> for ConstProp {
return;
}
- let is_generator = tcx.type_of(def_id.to_def_id()).is_generator();
+ let is_generator = tcx.type_of(def_id.to_def_id()).subst_identity().is_generator();
// FIXME(welseywiser) const prop doesn't work on generators because of query cycles
// computing their layout.
if is_generator {
@@ -289,7 +284,7 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx>
}
// If the static allocation is mutable, then we can't const prop it as its content
// might be different at runtime.
- if alloc.inner().mutability == Mutability::Mut {
+ if alloc.inner().mutability.is_mut() {
throw_machine_stop_str!("can't access mutable globals in ConstProp");
}
@@ -457,27 +452,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
};
}
- fn use_ecx<F, T>(&mut self, f: F) -> Option<T>
- where
- F: FnOnce(&mut Self) -> InterpResult<'tcx, T>,
- {
- match f(self) {
- Ok(val) => Some(val),
- Err(error) => {
- trace!("InterpCx operation failed: {:?}", error);
- // Some errors shouldn't come up because creating them causes
- // an allocation, which we should avoid. When that happens,
- // dedicated error variants should be introduced instead.
- assert!(
- !error.kind().formatted_string(),
- "const-prop encountered formatting error: {}",
- error
- );
- None
- }
- }
- }
-
/// Returns the value, if any, of evaluating `c`.
fn eval_constant(&mut self, c: &Constant<'tcx>) -> Option<OpTy<'tcx>> {
// FIXME we need to revisit this for #67176
@@ -492,7 +466,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
/// Returns the value, if any, of evaluating `place`.
fn eval_place(&mut self, place: Place<'tcx>) -> Option<OpTy<'tcx>> {
trace!("eval_place(place={:?})", place);
- self.use_ecx(|this| this.ecx.eval_place_to_op(place, None))
+ self.ecx.eval_place_to_op(place, None).ok()
}
/// Returns the value, if any, of evaluating `op`. Calls upon `eval_constant`
@@ -504,55 +478,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
}
}
- fn check_unary_op(&mut self, op: UnOp, arg: &Operand<'tcx>) -> Option<()> {
- if self.use_ecx(|this| {
- let val = this.ecx.read_immediate(&this.ecx.eval_operand(arg, None)?)?;
- let (_res, overflow, _ty) = this.ecx.overflowing_unary_op(op, &val)?;
- Ok(overflow)
- })? {
- // `AssertKind` only has an `OverflowNeg` variant, so make sure that is
- // appropriate to use.
- assert_eq!(op, UnOp::Neg, "Neg is the only UnOp that can overflow");
- return None;
- }
-
- Some(())
- }
-
- fn check_binary_op(
- &mut self,
- op: BinOp,
- left: &Operand<'tcx>,
- right: &Operand<'tcx>,
- ) -> Option<()> {
- let r = self.use_ecx(|this| this.ecx.read_immediate(&this.ecx.eval_operand(right, None)?));
- let l = self.use_ecx(|this| this.ecx.read_immediate(&this.ecx.eval_operand(left, None)?));
- // Check for exceeding shifts *even if* we cannot evaluate the LHS.
- if op == BinOp::Shr || op == BinOp::Shl {
- let r = r.clone()?;
- // We need the type of the LHS. We cannot use `place_layout` as that is the type
- // of the result, which for checked binops is not the same!
- let left_ty = left.ty(self.local_decls, self.tcx);
- let left_size = self.ecx.layout_of(left_ty).ok()?.size;
- let right_size = r.layout.size;
- let r_bits = r.to_scalar().to_bits(right_size).ok();
- if r_bits.map_or(false, |b| b >= left_size.bits() as u128) {
- return None;
- }
- }
-
- if let (Some(l), Some(r)) = (&l, &r) {
- // The remaining operators are handled through `overflowing_binary_op`.
- if self.use_ecx(|this| {
- let (_res, overflow, _ty) = this.ecx.overflowing_binary_op(op, l, r)?;
- Ok(overflow)
- })? {
- return None;
- }
- }
- Some(())
- }
-
fn propagate_operand(&mut self, operand: &mut Operand<'tcx>) {
match *operand {
Operand::Copy(l) | Operand::Move(l) => {
@@ -588,28 +513,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
// 2. Working around bugs in other parts of the compiler
// - In this case, we'll return `None` from this function to stop evaluation.
match rvalue {
- // Additional checking: give lints to the user if an overflow would occur.
- // We do this here and not in the `Assert` terminator as that terminator is
- // only sometimes emitted (overflow checks can be disabled), but we want to always
- // lint.
- Rvalue::UnaryOp(op, arg) => {
- trace!("checking UnaryOp(op = {:?}, arg = {:?})", op, arg);
- self.check_unary_op(*op, arg)?;
- }
- Rvalue::BinaryOp(op, box (left, right)) => {
- trace!("checking BinaryOp(op = {:?}, left = {:?}, right = {:?})", op, left, right);
- self.check_binary_op(*op, left, right)?;
- }
- Rvalue::CheckedBinaryOp(op, box (left, right)) => {
- trace!(
- "checking CheckedBinaryOp(op = {:?}, left = {:?}, right = {:?})",
- op,
- left,
- right
- );
- self.check_binary_op(*op, left, right)?;
- }
-
// Do not try creating references (#67862)
Rvalue::AddressOf(_, place) | Rvalue::Ref(_, _, place) => {
trace!("skipping AddressOf | Ref for {:?}", place);
@@ -639,7 +542,10 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
| Rvalue::Cast(..)
| Rvalue::ShallowInitBox(..)
| Rvalue::Discriminant(..)
- | Rvalue::NullaryOp(..) => {}
+ | Rvalue::NullaryOp(..)
+ | Rvalue::UnaryOp(..)
+ | Rvalue::BinaryOp(..)
+ | Rvalue::CheckedBinaryOp(..) => {}
}
// FIXME we need to revisit this for #67176
@@ -664,35 +570,37 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
rvalue: &Rvalue<'tcx>,
place: Place<'tcx>,
) -> Option<()> {
- self.use_ecx(|this| match rvalue {
+ match rvalue {
Rvalue::BinaryOp(op, box (left, right))
| Rvalue::CheckedBinaryOp(op, box (left, right)) => {
- let l = this.ecx.eval_operand(left, None).and_then(|x| this.ecx.read_immediate(&x));
+ let l = self.ecx.eval_operand(left, None).and_then(|x| self.ecx.read_immediate(&x));
let r =
- this.ecx.eval_operand(right, None).and_then(|x| this.ecx.read_immediate(&x));
+ self.ecx.eval_operand(right, None).and_then(|x| self.ecx.read_immediate(&x));
let const_arg = match (l, r) {
(Ok(x), Err(_)) | (Err(_), Ok(x)) => x, // exactly one side is known
- (Err(e), Err(_)) => return Err(e), // neither side is known
- (Ok(_), Ok(_)) => return this.ecx.eval_rvalue_into_place(rvalue, place), // both sides are known
+ (Err(_), Err(_)) => return None, // neither side is known
+ (Ok(_), Ok(_)) => return self.ecx.eval_rvalue_into_place(rvalue, place).ok(), // both sides are known
};
if !matches!(const_arg.layout.abi, abi::Abi::Scalar(..)) {
// We cannot handle Scalar Pair stuff.
// No point in calling `eval_rvalue_into_place`, since only one side is known
- throw_machine_stop_str!("cannot optimize this")
+ return None;
}
- let arg_value = const_arg.to_scalar().to_bits(const_arg.layout.size)?;
- let dest = this.ecx.eval_place(place)?;
+ let arg_value = const_arg.to_scalar().to_bits(const_arg.layout.size).ok()?;
+ let dest = self.ecx.eval_place(place).ok()?;
match op {
- BinOp::BitAnd if arg_value == 0 => this.ecx.write_immediate(*const_arg, &dest),
+ BinOp::BitAnd if arg_value == 0 => {
+ self.ecx.write_immediate(*const_arg, &dest).ok()
+ }
BinOp::BitOr
if arg_value == const_arg.layout.size.truncate(u128::MAX)
|| (const_arg.layout.ty.is_bool() && arg_value == 1) =>
{
- this.ecx.write_immediate(*const_arg, &dest)
+ self.ecx.write_immediate(*const_arg, &dest).ok()
}
BinOp::Mul if const_arg.layout.ty.is_integral() && arg_value == 0 => {
if let Rvalue::CheckedBinaryOp(_, _) = rvalue {
@@ -700,16 +608,16 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
const_arg.to_scalar(),
Scalar::from_bool(false),
);
- this.ecx.write_immediate(val, &dest)
+ self.ecx.write_immediate(val, &dest).ok()
} else {
- this.ecx.write_immediate(*const_arg, &dest)
+ self.ecx.write_immediate(*const_arg, &dest).ok()
}
}
- _ => throw_machine_stop_str!("cannot optimize this"),
+ _ => None,
}
}
- _ => this.ecx.eval_rvalue_into_place(rvalue, place),
- })
+ _ => self.ecx.eval_rvalue_into_place(rvalue, place).ok(),
+ }
}
/// Creates a new `Operand::Constant` from a `Scalar` value
@@ -751,7 +659,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
}
// FIXME> figure out what to do when read_immediate_raw fails
- let imm = self.use_ecx(|this| this.ecx.read_immediate_raw(value));
+ let imm = self.ecx.read_immediate_raw(value).ok();
if let Some(Right(imm)) = imm {
match *imm {
@@ -771,25 +679,23 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
if let ty::Tuple(types) = ty.kind() {
// Only do it if tuple is also a pair with two scalars
if let [ty1, ty2] = types[..] {
- let alloc = self.use_ecx(|this| {
- let ty_is_scalar = |ty| {
- this.ecx.layout_of(ty).ok().map(|layout| layout.abi.is_scalar())
- == Some(true)
- };
- if ty_is_scalar(ty1) && ty_is_scalar(ty2) {
- let alloc = this
- .ecx
- .intern_with_temp_alloc(value.layout, |ecx, dest| {
- ecx.write_immediate(*imm, dest)
- })
- .unwrap();
- Ok(Some(alloc))
- } else {
- Ok(None)
- }
- });
-
- if let Some(Some(alloc)) = alloc {
+ let ty_is_scalar = |ty| {
+ self.ecx.layout_of(ty).ok().map(|layout| layout.abi.is_scalar())
+ == Some(true)
+ };
+ let alloc = if ty_is_scalar(ty1) && ty_is_scalar(ty2) {
+ let alloc = self
+ .ecx
+ .intern_with_temp_alloc(value.layout, |ecx, dest| {
+ ecx.write_immediate(*imm, dest)
+ })
+ .unwrap();
+ Some(alloc)
+ } else {
+ None
+ };
+
+ if let Some(alloc) = alloc {
// Assign entire constant in a single statement.
// We can't use aggregates, as we run after the aggregate-lowering `MirPhase`.
let const_val = ConstValue::ByRef { alloc, offset: Size::ZERO };
@@ -990,84 +896,80 @@ impl<'tcx> MutVisitor<'tcx> for ConstPropagator<'_, 'tcx> {
trace!("visit_statement: {:?}", statement);
let source_info = statement.source_info;
self.source_info = Some(source_info);
- if let StatementKind::Assign(box (place, ref mut rval)) = statement.kind {
- let can_const_prop = self.ecx.machine.can_const_prop[place.local];
- if let Some(()) = self.const_prop(rval, place) {
- // This will return None if the above `const_prop` invocation only "wrote" a
- // type whose creation requires no write. E.g. a generator whose initial state
- // consists solely of uninitialized memory (so it doesn't capture any locals).
- if let Some(ref value) = self.get_const(place) && self.should_const_prop(value) {
- trace!("replacing {:?} with {:?}", rval, value);
- self.replace_with_const(rval, value, source_info);
- if can_const_prop == ConstPropMode::FullConstProp
- || can_const_prop == ConstPropMode::OnlyInsideOwnBlock
- {
- trace!("propagated into {:?}", place);
+ match statement.kind {
+ StatementKind::Assign(box (place, ref mut rval)) => {
+ let can_const_prop = self.ecx.machine.can_const_prop[place.local];
+ if let Some(()) = self.const_prop(rval, place) {
+ // This will return None if the above `const_prop` invocation only "wrote" a
+ // type whose creation requires no write. E.g. a generator whose initial state
+ // consists solely of uninitialized memory (so it doesn't capture any locals).
+ if let Some(ref value) = self.get_const(place) && self.should_const_prop(value) {
+ trace!("replacing {:?} with {:?}", rval, value);
+ self.replace_with_const(rval, value, source_info);
+ if can_const_prop == ConstPropMode::FullConstProp
+ || can_const_prop == ConstPropMode::OnlyInsideOwnBlock
+ {
+ trace!("propagated into {:?}", place);
+ }
}
- }
- match can_const_prop {
- ConstPropMode::OnlyInsideOwnBlock => {
- trace!(
- "found local restricted to its block. \
+ match can_const_prop {
+ ConstPropMode::OnlyInsideOwnBlock => {
+ trace!(
+ "found local restricted to its block. \
Will remove it from const-prop after block is finished. Local: {:?}",
- place.local
- );
- }
- ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => {
- trace!("can't propagate into {:?}", place);
- if place.local != RETURN_PLACE {
- Self::remove_const(&mut self.ecx, place.local);
+ place.local
+ );
}
- }
- ConstPropMode::FullConstProp => {}
- }
- } else {
- // Const prop failed, so erase the destination, ensuring that whatever happens
- // from here on, does not know about the previous value.
- // This is important in case we have
- // ```rust
- // let mut x = 42;
- // x = SOME_MUTABLE_STATIC;
- // // x must now be uninit
- // ```
- // FIXME: we overzealously erase the entire local, because that's easier to
- // implement.
- trace!(
- "propagation into {:?} failed.
- Nuking the entire site from orbit, it's the only way to be sure",
- place,
- );
- Self::remove_const(&mut self.ecx, place.local);
- }
- } else {
- match statement.kind {
- StatementKind::SetDiscriminant { ref place, .. } => {
- match self.ecx.machine.can_const_prop[place.local] {
- ConstPropMode::FullConstProp | ConstPropMode::OnlyInsideOwnBlock => {
- if self.use_ecx(|this| this.ecx.statement(statement)).is_some() {
- trace!("propped discriminant into {:?}", place);
- } else {
+ ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => {
+ trace!("can't propagate into {:?}", place);
+ if place.local != RETURN_PLACE {
Self::remove_const(&mut self.ecx, place.local);
}
}
- ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => {
- Self::remove_const(&mut self.ecx, place.local);
- }
+ ConstPropMode::FullConstProp => {}
}
+ } else {
+ // Const prop failed, so erase the destination, ensuring that whatever happens
+ // from here on, does not know about the previous value.
+ // This is important in case we have
+ // ```rust
+ // let mut x = 42;
+ // x = SOME_MUTABLE_STATIC;
+ // // x must now be uninit
+ // ```
+ // FIXME: we overzealously erase the entire local, because that's easier to
+ // implement.
+ trace!(
+ "propagation into {:?} failed.
+ Nuking the entire site from orbit, it's the only way to be sure",
+ place,
+ );
+ Self::remove_const(&mut self.ecx, place.local);
}
- StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
- let frame = self.ecx.frame_mut();
- frame.locals[local].value =
- if let StatementKind::StorageLive(_) = statement.kind {
- LocalValue::Live(interpret::Operand::Immediate(
- interpret::Immediate::Uninit,
- ))
+ }
+ StatementKind::SetDiscriminant { ref place, .. } => {
+ match self.ecx.machine.can_const_prop[place.local] {
+ ConstPropMode::FullConstProp | ConstPropMode::OnlyInsideOwnBlock => {
+ if self.ecx.statement(statement).is_ok() {
+ trace!("propped discriminant into {:?}", place);
} else {
- LocalValue::Dead
- };
+ Self::remove_const(&mut self.ecx, place.local);
+ }
+ }
+ ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => {
+ Self::remove_const(&mut self.ecx, place.local);
+ }
}
- _ => {}
}
+ StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
+ let frame = self.ecx.frame_mut();
+ frame.locals[local].value = if let StatementKind::StorageLive(_) = statement.kind {
+ LocalValue::Live(interpret::Operand::Immediate(interpret::Immediate::Uninit))
+ } else {
+ LocalValue::Dead
+ };
+ }
+ _ => {}
}
self.super_statement(statement, location);
@@ -1077,34 +979,19 @@ impl<'tcx> MutVisitor<'tcx> for ConstPropagator<'_, 'tcx> {
let source_info = terminator.source_info;
self.source_info = Some(source_info);
self.super_terminator(terminator, location);
- // Do NOT early return in this function, it does some crucial fixup of the state at the end!
+
match &mut terminator.kind {
TerminatorKind::Assert { expected, ref mut cond, .. } => {
- if let Some(ref value) = self.eval_operand(&cond) {
+ if let Some(ref value) = self.eval_operand(&cond)
+ && let Ok(value_const) = self.ecx.read_scalar(&value)
+ && self.should_const_prop(value)
+ {
trace!("assertion on {:?} should be {:?}", value, expected);
- let expected = Scalar::from_bool(*expected);
- // FIXME should be used use_ecx rather than a local match... but we have
- // quite a few of these read_scalar/read_immediate that need fixing.
- if let Ok(value_const) = self.ecx.read_scalar(&value) {
- if expected != value_const {
- // Poison all places this operand references so that further code
- // doesn't use the invalid value
- match cond {
- Operand::Move(ref place) | Operand::Copy(ref place) => {
- Self::remove_const(&mut self.ecx, place.local);
- }
- Operand::Constant(_) => {}
- }
- } else {
- if self.should_const_prop(value) {
- *cond = self.operand_from_scalar(
- value_const,
- self.tcx.types.bool,
- source_info.span,
- );
- }
- }
- }
+ *cond = self.operand_from_scalar(
+ value_const,
+ self.tcx.types.bool,
+ source_info.span,
+ );
}
}
TerminatorKind::SwitchInt { ref mut discr, .. } => {
@@ -1132,6 +1019,10 @@ impl<'tcx> MutVisitor<'tcx> for ConstPropagator<'_, 'tcx> {
// gated on `mir_opt_level=3`.
TerminatorKind::Call { .. } => {}
}
+ }
+
+ fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
+ self.super_basic_block_data(block, data);
// We remove all Locals which are restricted in propagation to their containing blocks and
// which were modified in the current block.
diff --git a/compiler/rustc_mir_transform/src/const_prop_lint.rs b/compiler/rustc_mir_transform/src/const_prop_lint.rs
index 0ab67228f..fd9475748 100644
--- a/compiler/rustc_mir_transform/src/const_prop_lint.rs
+++ b/compiler/rustc_mir_transform/src/const_prop_lint.rs
@@ -21,7 +21,9 @@ use rustc_middle::mir::{
};
use rustc_middle::ty::layout::{LayoutError, LayoutOf, LayoutOfHelpers, TyAndLayout};
use rustc_middle::ty::InternalSubsts;
-use rustc_middle::ty::{self, ConstInt, Instance, ParamEnv, ScalarInt, Ty, TyCtxt, TypeVisitable};
+use rustc_middle::ty::{
+ self, ConstInt, Instance, ParamEnv, ScalarInt, Ty, TyCtxt, TypeVisitableExt,
+};
use rustc_session::lint;
use rustc_span::Span;
use rustc_target::abi::{HasDataLayout, Size, TargetDataLayout};
@@ -57,7 +59,7 @@ impl<'tcx> MirLint<'tcx> for ConstProp {
return;
}
- let is_generator = tcx.type_of(def_id.to_def_id()).is_generator();
+ let is_generator = tcx.type_of(def_id.to_def_id()).subst_identity().is_generator();
// FIXME(welseywiser) const prop doesn't work on generators because of query cycles
// computing their layout.
if is_generator {
@@ -296,7 +298,15 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
return None;
}
- self.use_ecx(source_info, |this| this.ecx.eval_mir_constant(&c.literal, Some(c.span), None))
+ // Normalization needed b/c const prop lint runs in
+ // `mir_drops_elaborated_and_const_checked`, which happens before
+ // optimized MIR. Only after optimizing the MIR can we guarantee
+ // that the `RevealAll` pass has happened and that the body's consts
+ // are normalized, so any call to resolve before that needs to be
+ // manually normalized.
+ let val = self.tcx.try_normalize_erasing_regions(self.param_env, c.literal).ok()?;
+
+ self.use_ecx(source_info, |this| this.ecx.eval_mir_constant(&val, Some(c.span), None))
}
/// Returns the value, if any, of evaluating `place`.
@@ -368,7 +378,7 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
this.ecx.read_immediate(&this.ecx.eval_operand(left, None)?)
});
// Check for exceeding shifts *even if* we cannot evaluate the LHS.
- if op == BinOp::Shr || op == BinOp::Shl {
+ if matches!(op, BinOp::Shr | BinOp::Shl) {
let r = r.clone()?;
// We need the type of the LHS. We cannot use `place_layout` as that is the type
// of the result, which for checked binops is not the same!
diff --git a/compiler/rustc_mir_transform/src/copy_prop.rs b/compiler/rustc_mir_transform/src/copy_prop.rs
new file mode 100644
index 000000000..f27beb64a
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/copy_prop.rs
@@ -0,0 +1,182 @@
+use rustc_index::bit_set::BitSet;
+use rustc_index::vec::IndexVec;
+use rustc_middle::mir::visit::*;
+use rustc_middle::mir::*;
+use rustc_middle::ty::TyCtxt;
+use rustc_mir_dataflow::impls::borrowed_locals;
+
+use crate::ssa::SsaLocals;
+use crate::MirPass;
+
+/// Unify locals that copy each other.
+///
+/// We consider patterns of the form
+/// _a = rvalue
+/// _b = move? _a
+/// _c = move? _a
+/// _d = move? _c
+/// where each of the locals is only assigned once.
+///
+/// We want to replace all those locals by `_a`, either copied or moved.
+pub struct CopyProp;
+
+impl<'tcx> MirPass<'tcx> for CopyProp {
+ fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+ sess.mir_opt_level() >= 1
+ }
+
+ #[instrument(level = "trace", skip(self, tcx, body))]
+ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ debug!(def_id = ?body.source.def_id());
+ propagate_ssa(tcx, body);
+ }
+}
+
+fn propagate_ssa<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
+ let borrowed_locals = borrowed_locals(body);
+ let ssa = SsaLocals::new(tcx, param_env, body, &borrowed_locals);
+
+ let fully_moved = fully_moved_locals(&ssa, body);
+ debug!(?fully_moved);
+
+ let mut storage_to_remove = BitSet::new_empty(fully_moved.domain_size());
+ for (local, &head) in ssa.copy_classes().iter_enumerated() {
+ if local != head {
+ storage_to_remove.insert(head);
+ }
+ }
+
+ let any_replacement = ssa.copy_classes().iter_enumerated().any(|(l, &h)| l != h);
+
+ Replacer {
+ tcx,
+ copy_classes: &ssa.copy_classes(),
+ fully_moved,
+ borrowed_locals,
+ storage_to_remove,
+ }
+ .visit_body_preserves_cfg(body);
+
+ if any_replacement {
+ crate::simplify::remove_unused_definitions(body);
+ }
+}
+
+/// `SsaLocals` computed equivalence classes between locals considering copy/move assignments.
+///
+/// This function also returns whether all the `move?` in the pattern are `move` and not copies.
+/// A local which is in the bitset can be replaced by `move _a`. Otherwise, it must be
+/// replaced by `copy _a`, as we cannot move multiple times from `_a`.
+///
+/// If an operand copies `_c`, it must happen before the assignment `_d = _c`, otherwise it is UB.
+/// This means that replacing it by a copy of `_a` if ok, since this copy happens before `_c` is
+/// moved, and therefore that `_d` is moved.
+#[instrument(level = "trace", skip(ssa, body))]
+fn fully_moved_locals(ssa: &SsaLocals, body: &Body<'_>) -> BitSet<Local> {
+ let mut fully_moved = BitSet::new_filled(body.local_decls.len());
+
+ for (_, rvalue) in ssa.assignments(body) {
+ let (Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) | Rvalue::CopyForDeref(place))
+ = rvalue
+ else { continue };
+
+ let Some(rhs) = place.as_local() else { continue };
+ if !ssa.is_ssa(rhs) {
+ continue;
+ }
+
+ if let Rvalue::Use(Operand::Copy(_)) | Rvalue::CopyForDeref(_) = rvalue {
+ fully_moved.remove(rhs);
+ }
+ }
+
+ ssa.meet_copy_equivalence(&mut fully_moved);
+
+ fully_moved
+}
+
+/// Utility to help performing substitution of `*pattern` by `target`.
+struct Replacer<'a, 'tcx> {
+ tcx: TyCtxt<'tcx>,
+ fully_moved: BitSet<Local>,
+ storage_to_remove: BitSet<Local>,
+ borrowed_locals: BitSet<Local>,
+ copy_classes: &'a IndexVec<Local, Local>,
+}
+
+impl<'tcx> MutVisitor<'tcx> for Replacer<'_, 'tcx> {
+ fn tcx(&self) -> TyCtxt<'tcx> {
+ self.tcx
+ }
+
+ fn visit_local(&mut self, local: &mut Local, ctxt: PlaceContext, _: Location) {
+ let new_local = self.copy_classes[*local];
+ match ctxt {
+ // Do not modify the local in storage statements.
+ PlaceContext::NonUse(NonUseContext::StorageLive | NonUseContext::StorageDead) => {}
+ // The local should have been marked as non-SSA.
+ PlaceContext::MutatingUse(_) => assert_eq!(*local, new_local),
+ // We access the value.
+ _ => *local = new_local,
+ }
+ }
+
+ fn visit_place(&mut self, place: &mut Place<'tcx>, ctxt: PlaceContext, loc: Location) {
+ if let Some(new_projection) = self.process_projection(&place.projection, loc) {
+ place.projection = self.tcx().mk_place_elems(&new_projection);
+ }
+
+ let observes_address = match ctxt {
+ PlaceContext::NonMutatingUse(
+ NonMutatingUseContext::SharedBorrow
+ | NonMutatingUseContext::ShallowBorrow
+ | NonMutatingUseContext::UniqueBorrow
+ | NonMutatingUseContext::AddressOf,
+ ) => true,
+ // For debuginfo, merging locals is ok.
+ PlaceContext::NonUse(NonUseContext::VarDebugInfo) => {
+ self.borrowed_locals.contains(place.local)
+ }
+ _ => false,
+ };
+ if observes_address && !place.is_indirect() {
+ // We observe the address of `place.local`. Do not replace it.
+ } else {
+ self.visit_local(
+ &mut place.local,
+ PlaceContext::NonMutatingUse(NonMutatingUseContext::Copy),
+ loc,
+ )
+ }
+ }
+
+ fn visit_operand(&mut self, operand: &mut Operand<'tcx>, loc: Location) {
+ if let Operand::Move(place) = *operand
+ // A move out of a projection of a copy is equivalent to a copy of the original projection.
+ && !place.has_deref()
+ && !self.fully_moved.contains(place.local)
+ {
+ *operand = Operand::Copy(place);
+ }
+ self.super_operand(operand, loc);
+ }
+
+ fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, loc: Location) {
+ match stmt.kind {
+ // When removing storage statements, we need to remove both (#107511).
+ StatementKind::StorageLive(l) | StatementKind::StorageDead(l)
+ if self.storage_to_remove.contains(l) =>
+ {
+ stmt.make_nop()
+ }
+ StatementKind::Assign(box (ref place, ref mut rvalue))
+ if place.as_local().is_some() =>
+ {
+ // Do not replace assignments.
+ self.visit_rvalue(rvalue, loc)
+ }
+ _ => self.super_statement(stmt, loc),
+ }
+ }
+}
diff --git a/compiler/rustc_mir_transform/src/coverage/counters.rs b/compiler/rustc_mir_transform/src/coverage/counters.rs
index 45de0c280..658e01d93 100644
--- a/compiler/rustc_mir_transform/src/coverage/counters.rs
+++ b/compiler/rustc_mir_transform/src/coverage/counters.rs
@@ -520,7 +520,7 @@ impl<'a> BcbCounters<'a> {
let mut found_loop_exit = false;
for &branch in branches.iter() {
if backedge_from_bcbs.iter().any(|&backedge_from_bcb| {
- self.bcb_is_dominated_by(backedge_from_bcb, branch.target_bcb)
+ self.bcb_dominates(branch.target_bcb, backedge_from_bcb)
}) {
if let Some(reloop_branch) = some_reloop_branch {
if reloop_branch.counter(&self.basic_coverage_blocks).is_none() {
@@ -603,8 +603,8 @@ impl<'a> BcbCounters<'a> {
}
#[inline]
- fn bcb_is_dominated_by(&self, node: BasicCoverageBlock, dom: BasicCoverageBlock) -> bool {
- self.basic_coverage_blocks.is_dominated_by(node, dom)
+ fn bcb_dominates(&self, dom: BasicCoverageBlock, node: BasicCoverageBlock) -> bool {
+ self.basic_coverage_blocks.dominates(dom, node)
}
#[inline]
diff --git a/compiler/rustc_mir_transform/src/coverage/debug.rs b/compiler/rustc_mir_transform/src/coverage/debug.rs
index d6a298fad..22ea8710e 100644
--- a/compiler/rustc_mir_transform/src/coverage/debug.rs
+++ b/compiler/rustc_mir_transform/src/coverage/debug.rs
@@ -323,7 +323,10 @@ impl DebugCounters {
String::new()
},
self.format_operand(lhs),
- if op == Op::Add { "+" } else { "-" },
+ match op {
+ Op::Add => "+",
+ Op::Subtract => "-",
+ },
self.format_operand(rhs),
);
}
diff --git a/compiler/rustc_mir_transform/src/coverage/graph.rs b/compiler/rustc_mir_transform/src/coverage/graph.rs
index 78d28f1eb..a2671eef2 100644
--- a/compiler/rustc_mir_transform/src/coverage/graph.rs
+++ b/compiler/rustc_mir_transform/src/coverage/graph.rs
@@ -209,8 +209,8 @@ impl CoverageGraph {
}
#[inline(always)]
- pub fn is_dominated_by(&self, node: BasicCoverageBlock, dom: BasicCoverageBlock) -> bool {
- self.dominators.as_ref().unwrap().is_dominated_by(node, dom)
+ pub fn dominates(&self, dom: BasicCoverageBlock, node: BasicCoverageBlock) -> bool {
+ self.dominators.as_ref().unwrap().dominates(dom, node)
}
#[inline(always)]
@@ -312,7 +312,7 @@ rustc_index::newtype_index! {
/// to the BCB's primary counter or expression).
///
/// The BCB CFG is critical to simplifying the coverage analysis by ensuring graph path-based
-/// queries (`is_dominated_by()`, `predecessors`, `successors`, etc.) have branch (control flow)
+/// queries (`dominates()`, `predecessors`, `successors`, etc.) have branch (control flow)
/// significance.
#[derive(Debug, Clone)]
pub(super) struct BasicCoverageBlockData {
@@ -594,7 +594,7 @@ impl TraverseCoverageGraphWithLoops {
// branching block would have given an `Expression` (or vice versa).
let (some_successor_to_add, some_loop_header) =
if let Some((_, loop_header)) = context.loop_backedges {
- if basic_coverage_blocks.is_dominated_by(successor, loop_header) {
+ if basic_coverage_blocks.dominates(loop_header, successor) {
(Some(successor), Some(loop_header))
} else {
(None, None)
@@ -666,15 +666,15 @@ pub(super) fn find_loop_backedges(
//
// The overall complexity appears to be comparable to many other MIR transform algorithms, and I
// don't expect that this function is creating a performance hot spot, but if this becomes an
- // issue, there may be ways to optimize the `is_dominated_by` algorithm (as indicated by an
+ // issue, there may be ways to optimize the `dominates` algorithm (as indicated by an
// existing `FIXME` comment in that code), or possibly ways to optimize it's usage here, perhaps
// by keeping track of results for visited `BasicCoverageBlock`s if they can be used to short
- // circuit downstream `is_dominated_by` checks.
+ // circuit downstream `dominates` checks.
//
// For now, that kind of optimization seems unnecessarily complicated.
for (bcb, _) in basic_coverage_blocks.iter_enumerated() {
for &successor in &basic_coverage_blocks.successors[bcb] {
- if basic_coverage_blocks.is_dominated_by(bcb, successor) {
+ if basic_coverage_blocks.dominates(successor, bcb) {
let loop_header = successor;
let backedge_from_bcb = bcb;
debug!(
diff --git a/compiler/rustc_mir_transform/src/coverage/mod.rs b/compiler/rustc_mir_transform/src/coverage/mod.rs
index 1468afc64..9a6171598 100644
--- a/compiler/rustc_mir_transform/src/coverage/mod.rs
+++ b/compiler/rustc_mir_transform/src/coverage/mod.rs
@@ -540,7 +540,8 @@ fn fn_sig_and_body(
// FIXME(#79625): Consider improving MIR to provide the information needed, to avoid going back
// to HIR for it.
let hir_node = tcx.hir().get_if_local(def_id).expect("expected DefId is local");
- let fn_body_id = hir::map::associated_body(hir_node).expect("HIR node is a function with body");
+ let (_, fn_body_id) =
+ hir::map::associated_body(hir_node).expect("HIR node is a function with body");
(hir_node.fn_sig(), tcx.hir().body(fn_body_id))
}
diff --git a/compiler/rustc_mir_transform/src/coverage/spans.rs b/compiler/rustc_mir_transform/src/coverage/spans.rs
index c54348404..8ee316773 100644
--- a/compiler/rustc_mir_transform/src/coverage/spans.rs
+++ b/compiler/rustc_mir_transform/src/coverage/spans.rs
@@ -63,7 +63,7 @@ impl CoverageStatement {
/// Note: A `CoverageStatement` merged into another CoverageSpan may come from a `BasicBlock` that
/// is not part of the `CoverageSpan` bcb if the statement was included because it's `Span` matches
/// or is subsumed by the `Span` associated with this `CoverageSpan`, and it's `BasicBlock`
-/// `is_dominated_by()` the `BasicBlock`s in this `CoverageSpan`.
+/// `dominates()` the `BasicBlock`s in this `CoverageSpan`.
#[derive(Debug, Clone)]
pub(super) struct CoverageSpan {
pub span: Span,
@@ -407,7 +407,7 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
if self.prev().is_macro_expansion() && self.curr().is_macro_expansion() {
// Macros that expand to include branching (such as
// `assert_eq!()`, `assert_ne!()`, `info!()`, `debug!()`, or
- // `trace!()) typically generate callee spans with identical
+ // `trace!()`) typically generate callee spans with identical
// ranges (typically the full span of the macro) for all
// `BasicBlocks`. This makes it impossible to distinguish
// the condition (`if val1 != val2`) from the optional
@@ -694,7 +694,7 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
/// `prev.span.hi()` will be greater than (further right of) `prev_original_span.hi()`.
/// If prev.span() was split off to the right of a closure, prev.span().lo() will be
/// greater than prev_original_span.lo(). The actual span of `prev_original_span` is
- /// not as important as knowing that `prev()` **used to have the same span** as `curr(),
+ /// not as important as knowing that `prev()` **used to have the same span** as `curr()`,
/// which means their sort order is still meaningful for determining the dominator
/// relationship.
///
@@ -705,12 +705,12 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
fn hold_pending_dups_unless_dominated(&mut self) {
// Equal coverage spans are ordered by dominators before dominated (if any), so it should be
// impossible for `curr` to dominate any previous `CoverageSpan`.
- debug_assert!(!self.span_bcb_is_dominated_by(self.prev(), self.curr()));
+ debug_assert!(!self.span_bcb_dominates(self.curr(), self.prev()));
let initial_pending_count = self.pending_dups.len();
if initial_pending_count > 0 {
let mut pending_dups = self.pending_dups.split_off(0);
- pending_dups.retain(|dup| !self.span_bcb_is_dominated_by(self.curr(), dup));
+ pending_dups.retain(|dup| !self.span_bcb_dominates(dup, self.curr()));
self.pending_dups.append(&mut pending_dups);
if self.pending_dups.len() < initial_pending_count {
debug!(
@@ -721,7 +721,7 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
}
}
- if self.span_bcb_is_dominated_by(self.curr(), self.prev()) {
+ if self.span_bcb_dominates(self.prev(), self.curr()) {
debug!(
" different bcbs but SAME spans, and prev dominates curr. Discard prev={:?}",
self.prev()
@@ -787,8 +787,8 @@ impl<'a, 'tcx> CoverageSpans<'a, 'tcx> {
}
}
- fn span_bcb_is_dominated_by(&self, covspan: &CoverageSpan, dom_covspan: &CoverageSpan) -> bool {
- self.basic_coverage_blocks.is_dominated_by(covspan.bcb, dom_covspan.bcb)
+ fn span_bcb_dominates(&self, dom_covspan: &CoverageSpan, covspan: &CoverageSpan) -> bool {
+ self.basic_coverage_blocks.dominates(dom_covspan.bcb, covspan.bcb)
}
}
@@ -802,6 +802,8 @@ pub(super) fn filtered_statement_span(statement: &Statement<'_>) -> Option<Span>
| StatementKind::StorageDead(_)
// Coverage should not be encountered, but don't inject coverage coverage
| StatementKind::Coverage(_)
+ // Ignore `ConstEvalCounter`s
+ | StatementKind::ConstEvalCounter
// Ignore `Nop`s
| StatementKind::Nop => None,
diff --git a/compiler/rustc_mir_transform/src/ctfe_limit.rs b/compiler/rustc_mir_transform/src/ctfe_limit.rs
new file mode 100644
index 000000000..1b3ac78fb
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/ctfe_limit.rs
@@ -0,0 +1,58 @@
+//! A pass that inserts the `ConstEvalCounter` instruction into any blocks that have a back edge
+//! (thus indicating there is a loop in the CFG), or whose terminator is a function call.
+use crate::MirPass;
+
+use rustc_data_structures::graph::dominators::Dominators;
+use rustc_middle::mir::{
+ BasicBlock, BasicBlockData, Body, Statement, StatementKind, TerminatorKind,
+};
+use rustc_middle::ty::TyCtxt;
+
+pub struct CtfeLimit;
+
+impl<'tcx> MirPass<'tcx> for CtfeLimit {
+ #[instrument(skip(self, _tcx, body))]
+ fn run_pass(&self, _tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ let doms = body.basic_blocks.dominators();
+ let indices: Vec<BasicBlock> = body
+ .basic_blocks
+ .iter_enumerated()
+ .filter_map(|(node, node_data)| {
+ if matches!(node_data.terminator().kind, TerminatorKind::Call { .. })
+ // Back edges in a CFG indicate loops
+ || has_back_edge(&doms, node, &node_data)
+ {
+ Some(node)
+ } else {
+ None
+ }
+ })
+ .collect();
+ for index in indices {
+ insert_counter(
+ body.basic_blocks_mut()
+ .get_mut(index)
+ .expect("basic_blocks index {index} should exist"),
+ );
+ }
+ }
+}
+
+fn has_back_edge(
+ doms: &Dominators<BasicBlock>,
+ node: BasicBlock,
+ node_data: &BasicBlockData<'_>,
+) -> bool {
+ if !doms.is_reachable(node) {
+ return false;
+ }
+ // Check if any of the dominators of the node are also the node's successor.
+ doms.dominators(node).any(|dom| node_data.terminator().successors().any(|succ| succ == dom))
+}
+
+fn insert_counter(basic_block_data: &mut BasicBlockData<'_>) {
+ basic_block_data.statements.push(Statement {
+ source_info: basic_block_data.terminator().source_info,
+ kind: StatementKind::ConstEvalCounter,
+ });
+}
diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
index c75fe2327..49ded10ba 100644
--- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
+++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs
@@ -5,6 +5,7 @@
use rustc_const_eval::const_eval::CheckAlignment;
use rustc_const_eval::interpret::{ConstValue, ImmTy, Immediate, InterpCx, Scalar};
use rustc_data_structures::fx::FxHashMap;
+use rustc_hir::def::DefKind;
use rustc_middle::mir::visit::{MutVisitor, Visitor};
use rustc_middle::mir::*;
use rustc_middle::ty::{self, Ty, TyCtxt};
@@ -12,6 +13,7 @@ use rustc_mir_dataflow::value_analysis::{Map, State, TrackElem, ValueAnalysis, V
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, ResultsVisitor, SwitchIntEdgeEffects};
use rustc_span::DUMMY_SP;
use rustc_target::abi::Align;
+use rustc_target::abi::VariantIdx;
use crate::MirPass;
@@ -29,14 +31,12 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
#[instrument(skip_all level = "debug")]
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ debug!(def_id = ?body.source.def_id());
if tcx.sess.mir_opt_level() < 4 && body.basic_blocks.len() > BLOCK_LIMIT {
debug!("aborted dataflow const prop due too many basic blocks");
return;
}
- // Decide which places to track during the analysis.
- let map = Map::from_filter(tcx, body, Ty::is_scalar);
-
// We want to have a somewhat linear runtime w.r.t. the number of statements/terminators.
// Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function
// applications, where `h` is the height of the lattice. Because the height of our lattice
@@ -45,10 +45,10 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
// `O(num_nodes * tracked_places * n)` in terms of time complexity. Since the number of
// map nodes is strongly correlated to the number of tracked places, this becomes more or
// less `O(n)` if we place a constant limit on the number of tracked places.
- if tcx.sess.mir_opt_level() < 4 && map.tracked_places() > PLACE_LIMIT {
- debug!("aborted dataflow const prop due to too many tracked places");
- return;
- }
+ let place_limit = if tcx.sess.mir_opt_level() < 4 { Some(PLACE_LIMIT) } else { None };
+
+ // Decide which places to track during the analysis.
+ let map = Map::from_filter(tcx, body, Ty::is_scalar, place_limit);
// Perform the actual dataflow analysis.
let analysis = ConstAnalysis::new(tcx, body, map);
@@ -62,14 +62,31 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp {
}
}
-struct ConstAnalysis<'tcx> {
+struct ConstAnalysis<'a, 'tcx> {
map: Map,
tcx: TyCtxt<'tcx>,
+ local_decls: &'a LocalDecls<'tcx>,
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
param_env: ty::ParamEnv<'tcx>,
}
-impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
+impl<'tcx> ConstAnalysis<'_, 'tcx> {
+ fn eval_discriminant(
+ &self,
+ enum_ty: Ty<'tcx>,
+ variant_index: VariantIdx,
+ ) -> Option<ScalarTy<'tcx>> {
+ if !enum_ty.is_enum() {
+ return None;
+ }
+ let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
+ let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
+ let discr_value = Scalar::try_from_uint(discr.val, discr_layout.size)?;
+ Some(ScalarTy(discr_value, discr.ty))
+ }
+}
+
+impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> {
type Value = FlatSet<ScalarTy<'tcx>>;
const NAME: &'static str = "ConstAnalysis";
@@ -78,6 +95,25 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
&self.map
}
+ fn handle_statement(&self, statement: &Statement<'tcx>, state: &mut State<Self::Value>) {
+ match statement.kind {
+ StatementKind::SetDiscriminant { box ref place, variant_index } => {
+ state.flood_discr(place.as_ref(), &self.map);
+ if self.map.find_discr(place.as_ref()).is_some() {
+ let enum_ty = place.ty(self.local_decls, self.tcx).ty;
+ if let Some(discr) = self.eval_discriminant(enum_ty, variant_index) {
+ state.assign_discr(
+ place.as_ref(),
+ ValueOrPlace::Value(FlatSet::Elem(discr)),
+ &self.map,
+ );
+ }
+ }
+ }
+ _ => self.super_statement(statement, state),
+ }
+ }
+
fn handle_assign(
&self,
target: Place<'tcx>,
@@ -85,13 +121,59 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
state: &mut State<Self::Value>,
) {
match rvalue {
+ Rvalue::Aggregate(kind, operands) => {
+ // If we assign `target = Enum::Variant#0(operand)`,
+ // we must make sure that all `target as Variant#i` are `Top`.
+ state.flood(target.as_ref(), self.map());
+
+ if let Some(target_idx) = self.map().find(target.as_ref()) {
+ let (variant_target, variant_index) = match **kind {
+ AggregateKind::Tuple | AggregateKind::Closure(..) => {
+ (Some(target_idx), None)
+ }
+ AggregateKind::Adt(def_id, variant_index, ..) => {
+ match self.tcx.def_kind(def_id) {
+ DefKind::Struct => (Some(target_idx), None),
+ DefKind::Enum => (
+ self.map.apply(target_idx, TrackElem::Variant(variant_index)),
+ Some(variant_index),
+ ),
+ _ => (None, None),
+ }
+ }
+ _ => (None, None),
+ };
+ if let Some(variant_target_idx) = variant_target {
+ for (field_index, operand) in operands.iter().enumerate() {
+ if let Some(field) = self.map().apply(
+ variant_target_idx,
+ TrackElem::Field(Field::from_usize(field_index)),
+ ) {
+ let result = self.handle_operand(operand, state);
+ state.insert_idx(field, result, self.map());
+ }
+ }
+ }
+ if let Some(variant_index) = variant_index
+ && let Some(discr_idx) = self.map().apply(target_idx, TrackElem::Discriminant)
+ {
+ // We are assigning the discriminant as part of an aggregate.
+ // This discriminant can only alias a variant field's value if the operand
+ // had an invalid value for that type.
+ // Using invalid values is UB, so we are allowed to perform the assignment
+ // without extra flooding.
+ let enum_ty = target.ty(self.local_decls, self.tcx).ty;
+ if let Some(discr_val) = self.eval_discriminant(enum_ty, variant_index) {
+ state.insert_value_idx(discr_idx, FlatSet::Elem(discr_val), &self.map);
+ }
+ }
+ }
+ }
Rvalue::CheckedBinaryOp(op, box (left, right)) => {
+ // Flood everything now, so we can use `insert_value_idx` directly later.
+ state.flood(target.as_ref(), self.map());
+
let target = self.map().find(target.as_ref());
- if let Some(target) = target {
- // We should not track any projections other than
- // what is overwritten below, but just in case...
- state.flood_idx(target, self.map());
- }
let value_target = target
.and_then(|target| self.map().apply(target, TrackElem::Field(0_u32.into())));
@@ -102,26 +184,19 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
let (val, overflow) = self.binary_op(state, *op, left, right);
if let Some(value_target) = value_target {
- state.assign_idx(value_target, ValueOrPlace::Value(val), self.map());
+ // We have flooded `target` earlier.
+ state.insert_value_idx(value_target, val, self.map());
}
if let Some(overflow_target) = overflow_target {
let overflow = match overflow {
FlatSet::Top => FlatSet::Top,
FlatSet::Elem(overflow) => {
- if overflow {
- // Overflow cannot be reliably propagated. See: https://github.com/rust-lang/rust/pull/101168#issuecomment-1288091446
- FlatSet::Top
- } else {
- self.wrap_scalar(Scalar::from_bool(false), self.tcx.types.bool)
- }
+ self.wrap_scalar(Scalar::from_bool(overflow), self.tcx.types.bool)
}
FlatSet::Bottom => FlatSet::Bottom,
};
- state.assign_idx(
- overflow_target,
- ValueOrPlace::Value(overflow),
- self.map(),
- );
+ // We have flooded `target` earlier.
+ state.insert_value_idx(overflow_target, overflow, self.map());
}
}
}
@@ -170,6 +245,9 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> {
FlatSet::Bottom => ValueOrPlace::Value(FlatSet::Bottom),
FlatSet::Top => ValueOrPlace::Value(FlatSet::Top),
},
+ Rvalue::Discriminant(place) => {
+ ValueOrPlace::Value(state.get_discr(place.as_ref(), self.map()))
+ }
_ => self.super_rvalue(rvalue, state),
}
}
@@ -243,12 +321,13 @@ impl<'tcx> std::fmt::Debug for ScalarTy<'tcx> {
}
}
-impl<'tcx> ConstAnalysis<'tcx> {
- pub fn new(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, map: Map) -> Self {
+impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> {
+ pub fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, map: Map) -> Self {
let param_env = tcx.param_env(body.source.def_id());
Self {
map,
tcx,
+ local_decls: &body.local_decls,
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
param_env: param_env,
}
@@ -441,6 +520,21 @@ impl<'tcx, 'map, 'a> Visitor<'tcx> for OperandCollector<'tcx, 'map, 'a> {
_ => (),
}
}
+
+ fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
+ match rvalue {
+ Rvalue::Discriminant(place) => {
+ match self.state.get_discr(place.as_ref(), self.visitor.map) {
+ FlatSet::Top => (),
+ FlatSet::Elem(value) => {
+ self.visitor.before_effect.insert((location, *place), value);
+ }
+ FlatSet::Bottom => (),
+ }
+ }
+ _ => self.super_rvalue(rvalue, location),
+ }
+ }
}
struct DummyMachine;
diff --git a/compiler/rustc_mir_transform/src/dead_store_elimination.rs b/compiler/rustc_mir_transform/src/dead_store_elimination.rs
index 09546330c..9dbfb089d 100644
--- a/compiler/rustc_mir_transform/src/dead_store_elimination.rs
+++ b/compiler/rustc_mir_transform/src/dead_store_elimination.rs
@@ -53,6 +53,7 @@ pub fn eliminate<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, borrowed: &BitS
| StatementKind::StorageDead(_)
| StatementKind::Coverage(_)
| StatementKind::Intrinsic(_)
+ | StatementKind::ConstEvalCounter
| StatementKind::Nop => (),
StatementKind::FakeRead(_) | StatementKind::AscribeUserType(_, _) => {
diff --git a/compiler/rustc_mir_transform/src/deaggregator.rs b/compiler/rustc_mir_transform/src/deaggregator.rs
deleted file mode 100644
index fe272de20..000000000
--- a/compiler/rustc_mir_transform/src/deaggregator.rs
+++ /dev/null
@@ -1,45 +0,0 @@
-use crate::util::expand_aggregate;
-use crate::MirPass;
-use rustc_middle::mir::*;
-use rustc_middle::ty::TyCtxt;
-
-pub struct Deaggregator;
-
-impl<'tcx> MirPass<'tcx> for Deaggregator {
- fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
- let basic_blocks = body.basic_blocks.as_mut_preserves_cfg();
- for bb in basic_blocks {
- bb.expand_statements(|stmt| {
- // FIXME(eddyb) don't match twice on `stmt.kind` (post-NLL).
- match stmt.kind {
- // FIXME(#48193) Deaggregate arrays when it's cheaper to do so.
- StatementKind::Assign(box (
- _,
- Rvalue::Aggregate(box AggregateKind::Array(_), _),
- )) => {
- return None;
- }
- StatementKind::Assign(box (_, Rvalue::Aggregate(_, _))) => {}
- _ => return None,
- }
-
- let stmt = stmt.replace_nop();
- let source_info = stmt.source_info;
- let StatementKind::Assign(box (lhs, Rvalue::Aggregate(kind, operands))) = stmt.kind else {
- bug!();
- };
-
- Some(expand_aggregate(
- lhs,
- operands.into_iter().map(|op| {
- let ty = op.ty(&body.local_decls, tcx);
- (op, ty)
- }),
- *kind,
- source_info,
- tcx,
- ))
- });
- }
- }
-}
diff --git a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs
index ddab7bbb2..89ca04a15 100644
--- a/compiler/rustc_mir_transform/src/deduce_param_attrs.rs
+++ b/compiler/rustc_mir_transform/src/deduce_param_attrs.rs
@@ -163,7 +163,7 @@ pub fn deduced_param_attrs<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> &'tcx [Ded
// Codegen won't use this information for anything if all the function parameters are passed
// directly. Detect that and bail, for compilation speed.
- let fn_ty = tcx.type_of(def_id);
+ let fn_ty = tcx.type_of(def_id).subst_identity();
if matches!(fn_ty.kind(), ty::FnDef(..)) {
if fn_ty
.fn_sig(tcx)
diff --git a/compiler/rustc_mir_transform/src/dest_prop.rs b/compiler/rustc_mir_transform/src/dest_prop.rs
index 08e296a83..2e481b972 100644
--- a/compiler/rustc_mir_transform/src/dest_prop.rs
+++ b/compiler/rustc_mir_transform/src/dest_prop.rs
@@ -136,8 +136,8 @@ use rustc_index::bit_set::BitSet;
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
use rustc_middle::mir::{dump_mir, PassWhere};
use rustc_middle::mir::{
- traversal, BasicBlock, Body, InlineAsmOperand, Local, LocalKind, Location, Operand, Place,
- Rvalue, Statement, StatementKind, TerminatorKind,
+ traversal, Body, InlineAsmOperand, Local, LocalKind, Location, Operand, Place, Rvalue,
+ Statement, StatementKind, TerminatorKind,
};
use rustc_middle::ty::TyCtxt;
use rustc_mir_dataflow::impls::MaybeLiveLocals;
@@ -328,7 +328,8 @@ impl<'a, 'tcx> MutVisitor<'tcx> for Merger<'a, 'tcx> {
match &statement.kind {
StatementKind::Assign(box (dest, rvalue)) => {
match rvalue {
- Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) => {
+ Rvalue::CopyForDeref(place)
+ | Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) => {
// These might've been turned into self-assignments by the replacement
// (this includes the original statement we wanted to eliminate).
if dest == place {
@@ -467,7 +468,7 @@ impl<'a, 'body, 'alloc, 'tcx> FilterInformation<'a, 'body, 'alloc, 'tcx> {
// to reuse the allocation.
write_info: write_info_alloc,
// Doesn't matter what we put here, will be overwritten before being used
- at: Location { block: BasicBlock::from_u32(0), statement_index: 0 },
+ at: Location::START,
};
this.internal_filter_liveness();
}
@@ -577,6 +578,7 @@ impl WriteInfo {
self.add_place(**place);
}
StatementKind::Intrinsic(_)
+ | StatementKind::ConstEvalCounter
| StatementKind::Nop
| StatementKind::Coverage(_)
| StatementKind::StorageLive(_)
@@ -754,7 +756,7 @@ impl<'tcx> Visitor<'tcx> for FindAssignments<'_, '_, 'tcx> {
fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) {
if let StatementKind::Assign(box (
lhs,
- Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
+ Rvalue::CopyForDeref(rhs) | Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
)) = &statement.kind
{
let Some((src, dest)) = places_to_candidate_pair(*lhs, *rhs, self.body) else {
diff --git a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs
index 932134bd6..954bb5aff 100644
--- a/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs
+++ b/compiler/rustc_mir_transform/src/elaborate_box_derefs.rs
@@ -17,9 +17,9 @@ pub fn build_ptr_tys<'tcx>(
unique_did: DefId,
nonnull_did: DefId,
) -> (Ty<'tcx>, Ty<'tcx>, Ty<'tcx>) {
- let substs = tcx.intern_substs(&[pointee.into()]);
- let unique_ty = tcx.bound_type_of(unique_did).subst(tcx, substs);
- let nonnull_ty = tcx.bound_type_of(nonnull_did).subst(tcx, substs);
+ let substs = tcx.mk_substs(&[pointee.into()]);
+ let unique_ty = tcx.type_of(unique_did).subst(tcx, substs);
+ let nonnull_ty = tcx.type_of(nonnull_did).subst(tcx, substs);
let ptr_ty = tcx.mk_imm_ptr(pointee);
(unique_ty, nonnull_ty, ptr_ty)
@@ -93,7 +93,7 @@ impl<'tcx> MirPass<'tcx> for ElaborateBoxDerefs {
if let Some(def_id) = tcx.lang_items().owned_box() {
let unique_did = tcx.adt_def(def_id).non_enum_variant().fields[0].did;
- let Some(nonnull_def) = tcx.type_of(unique_did).ty_adt_def() else {
+ let Some(nonnull_def) = tcx.type_of(unique_did).subst_identity().ty_adt_def() else {
span_bug!(tcx.def_span(unique_did), "expected Box to contain Unique")
};
@@ -138,7 +138,7 @@ impl<'tcx> MirPass<'tcx> for ElaborateBoxDerefs {
if let Some(mut new_projections) = new_projections {
new_projections.extend_from_slice(&place.projection[last_deref..]);
- place.projection = tcx.intern_place_elems(&new_projections);
+ place.projection = tcx.mk_place_elems(&new_projections);
}
}
}
diff --git a/compiler/rustc_mir_transform/src/elaborate_drops.rs b/compiler/rustc_mir_transform/src/elaborate_drops.rs
index 65f4956d2..bdfd8dc6e 100644
--- a/compiler/rustc_mir_transform/src/elaborate_drops.rs
+++ b/compiler/rustc_mir_transform/src/elaborate_drops.rs
@@ -18,6 +18,35 @@ use rustc_span::Span;
use rustc_target::abi::VariantIdx;
use std::fmt;
+/// During MIR building, Drop and DropAndReplace terminators are inserted in every place where a drop may occur.
+/// However, in this phase, the presence of these terminators does not guarantee that a destructor will run,
+/// as the target of the drop may be uninitialized.
+/// In general, the compiler cannot determine at compile time whether a destructor will run or not.
+///
+/// At a high level, this pass refines Drop and DropAndReplace to only run the destructor if the
+/// target is initialized. The way this is achievied is by inserting drop flags for every variable
+/// that may be dropped, and then using those flags to determine whether a destructor should run.
+/// This pass also removes DropAndReplace, replacing it with a Drop paired with an assign statement.
+/// Once this is complete, Drop terminators in the MIR correspond to a call to the "drop glue" or
+/// "drop shim" for the type of the dropped place.
+///
+/// This pass relies on dropped places having an associated move path, which is then used to determine
+/// the initialization status of the place and its descendants.
+/// It's worth noting that a MIR containing a Drop without an associated move path is probably ill formed,
+/// as it would allow running a destructor on a place behind a reference:
+///
+/// ```text
+// fn drop_term<T>(t: &mut T) {
+// mir!(
+// {
+// Drop(*t, exit)
+// }
+// exit = {
+// Return()
+// }
+// )
+// }
+/// ```
pub struct ElaborateDrops;
impl<'tcx> MirPass<'tcx> for ElaborateDrops {
@@ -38,13 +67,11 @@ impl<'tcx> MirPass<'tcx> for ElaborateDrops {
};
let un_derefer = UnDerefer { tcx: tcx, derefer_sidetable: side_table };
let elaborate_patch = {
- let body = &*body;
let env = MoveDataParamEnv { move_data, param_env };
- let dead_unwinds = find_dead_unwinds(tcx, body, &env, &un_derefer);
+ remove_dead_unwinds(tcx, body, &env, &un_derefer);
let inits = MaybeInitializedPlaces::new(tcx, body, &env)
.into_engine(tcx, body)
- .dead_unwinds(&dead_unwinds)
.pass_name("elaborate_drops")
.iterate_to_fixpoint()
.into_results_cursor(body);
@@ -52,11 +79,12 @@ impl<'tcx> MirPass<'tcx> for ElaborateDrops {
let uninits = MaybeUninitializedPlaces::new(tcx, body, &env)
.mark_inactive_variants_as_uninit()
.into_engine(tcx, body)
- .dead_unwinds(&dead_unwinds)
.pass_name("elaborate_drops")
.iterate_to_fixpoint()
.into_results_cursor(body);
+ let reachable = traversal::reachable_as_bitset(body);
+
ElaborateDropsCtxt {
tcx,
body,
@@ -65,6 +93,7 @@ impl<'tcx> MirPass<'tcx> for ElaborateDrops {
drop_flags: Default::default(),
patch: MirPatch::new(body),
un_derefer: un_derefer,
+ reachable,
}
.elaborate()
};
@@ -73,22 +102,21 @@ impl<'tcx> MirPass<'tcx> for ElaborateDrops {
}
}
-/// Returns the set of basic blocks whose unwind edges are known
-/// to not be reachable, because they are `drop` terminators
+/// Removes unwind edges which are known to be unreachable, because they are in `drop` terminators
/// that can't drop anything.
-fn find_dead_unwinds<'tcx>(
+fn remove_dead_unwinds<'tcx>(
tcx: TyCtxt<'tcx>,
- body: &Body<'tcx>,
+ body: &mut Body<'tcx>,
env: &MoveDataParamEnv<'tcx>,
und: &UnDerefer<'tcx>,
-) -> BitSet<BasicBlock> {
- debug!("find_dead_unwinds({:?})", body.span);
+) {
+ debug!("remove_dead_unwinds({:?})", body.span);
// We only need to do this pass once, because unwind edges can only
// reach cleanup blocks, which can't have unwind edges themselves.
- let mut dead_unwinds = BitSet::new_empty(body.basic_blocks.len());
+ let mut dead_unwinds = Vec::new();
let mut flow_inits = MaybeInitializedPlaces::new(tcx, body, &env)
.into_engine(tcx, body)
- .pass_name("find_dead_unwinds")
+ .pass_name("remove_dead_unwinds")
.iterate_to_fixpoint()
.into_results_cursor(body);
for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
@@ -100,16 +128,16 @@ fn find_dead_unwinds<'tcx>(
_ => continue,
};
- debug!("find_dead_unwinds @ {:?}: {:?}", bb, bb_data);
+ debug!("remove_dead_unwinds @ {:?}: {:?}", bb, bb_data);
let LookupResult::Exact(path) = env.move_data.rev_lookup.find(place.as_ref()) else {
- debug!("find_dead_unwinds: has parent; skipping");
+ debug!("remove_dead_unwinds: has parent; skipping");
continue;
};
flow_inits.seek_before_primary_effect(body.terminator_loc(bb));
debug!(
- "find_dead_unwinds @ {:?}: path({:?})={:?}; init_data={:?}",
+ "remove_dead_unwinds @ {:?}: path({:?})={:?}; init_data={:?}",
bb,
place,
path,
@@ -121,13 +149,22 @@ fn find_dead_unwinds<'tcx>(
maybe_live |= flow_inits.contains(child);
});
- debug!("find_dead_unwinds @ {:?}: maybe_live={}", bb, maybe_live);
+ debug!("remove_dead_unwinds @ {:?}: maybe_live={}", bb, maybe_live);
if !maybe_live {
- dead_unwinds.insert(bb);
+ dead_unwinds.push(bb);
}
}
- dead_unwinds
+ if dead_unwinds.is_empty() {
+ return;
+ }
+
+ let basic_blocks = body.basic_blocks.as_mut();
+ for &bb in dead_unwinds.iter() {
+ if let Some(unwind) = basic_blocks[bb].terminator_mut().unwind_mut() {
+ *unwind = None;
+ }
+ }
}
struct InitializationData<'mir, 'tcx> {
@@ -261,6 +298,7 @@ struct ElaborateDropsCtxt<'a, 'tcx> {
drop_flags: FxHashMap<MovePathIndex, Local>,
patch: MirPatch<'tcx>,
un_derefer: UnDerefer<'tcx>,
+ reachable: BitSet<BasicBlock>,
}
impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> {
@@ -300,6 +338,9 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> {
fn collect_drop_flags(&mut self) {
for (bb, data) in self.body.basic_blocks.iter_enumerated() {
+ if !self.reachable.contains(bb) {
+ continue;
+ }
let terminator = data.terminator();
let place = match terminator.kind {
TerminatorKind::Drop { ref place, .. }
@@ -355,6 +396,9 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> {
fn elaborate_drops(&mut self) {
for (bb, data) in self.body.basic_blocks.iter_enumerated() {
+ if !self.reachable.contains(bb) {
+ continue;
+ }
let loc = Location { block: bb, statement_index: data.statements.len() };
let terminator = data.terminator();
@@ -512,6 +556,9 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> {
fn drop_flags_for_fn_rets(&mut self) {
for (bb, data) in self.body.basic_blocks.iter_enumerated() {
+ if !self.reachable.contains(bb) {
+ continue;
+ }
if let TerminatorKind::Call {
destination, target: Some(tgt), cleanup: Some(_), ..
} = data.terminator().kind
@@ -547,6 +594,9 @@ impl<'b, 'tcx> ElaborateDropsCtxt<'b, 'tcx> {
// clobbered before they are read.
for (bb, data) in self.body.basic_blocks.iter_enumerated() {
+ if !self.reachable.contains(bb) {
+ continue;
+ }
debug!("drop_flags_for_locs({:?})", data);
for i in 0..(data.statements.len() + 1) {
debug!("drop_flag_for_locs: stmt {}", i);
diff --git a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs
index 1244c1802..e6546911a 100644
--- a/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs
+++ b/compiler/rustc_mir_transform/src/ffi_unwind_calls.rs
@@ -49,7 +49,7 @@ fn has_ffi_unwind_calls(tcx: TyCtxt<'_>, local_def_id: LocalDefId) -> bool {
let body = &*tcx.mir_built(ty::WithOptConstParam::unknown(local_def_id)).borrow();
- let body_ty = tcx.type_of(def_id);
+ let body_ty = tcx.type_of(def_id).skip_binder();
let body_abi = match body_ty.kind() {
ty::FnDef(..) => body_ty.fn_sig(tcx).abi(),
ty::Closure(..) => Abi::RustCall,
diff --git a/compiler/rustc_mir_transform/src/function_item_references.rs b/compiler/rustc_mir_transform/src/function_item_references.rs
index b708d780b..66d32b954 100644
--- a/compiler/rustc_mir_transform/src/function_item_references.rs
+++ b/compiler/rustc_mir_transform/src/function_item_references.rs
@@ -79,7 +79,7 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> {
for bound in bounds {
if let Some(bound_ty) = self.is_pointer_trait(&bound.kind().skip_binder()) {
// Get the argument types as they appear in the function signature.
- let arg_defs = self.tcx.fn_sig(def_id).skip_binder().inputs();
+ let arg_defs = self.tcx.fn_sig(def_id).subst_identity().skip_binder().inputs();
for (arg_num, arg_def) in arg_defs.iter().enumerate() {
// For all types reachable from the argument type in the fn sig
for generic_inner_ty in arg_def.walk() {
@@ -111,11 +111,9 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> {
/// If the given predicate is the trait `fmt::Pointer`, returns the bound parameter type.
fn is_pointer_trait(&self, bound: &PredicateKind<'tcx>) -> Option<Ty<'tcx>> {
if let ty::PredicateKind::Clause(ty::Clause::Trait(predicate)) = bound {
- if self.tcx.is_diagnostic_item(sym::Pointer, predicate.def_id()) {
- Some(predicate.trait_ref.self_ty())
- } else {
- None
- }
+ self.tcx
+ .is_diagnostic_item(sym::Pointer, predicate.def_id())
+ .then(|| predicate.trait_ref.self_ty())
} else {
None
}
@@ -161,7 +159,8 @@ impl<'tcx> FunctionItemRefChecker<'_, 'tcx> {
.as_ref()
.assert_crate_local()
.lint_root;
- let fn_sig = self.tcx.fn_sig(fn_id);
+ // FIXME: use existing printing routines to print the function signature
+ let fn_sig = self.tcx.fn_sig(fn_id).subst(self.tcx, fn_substs);
let unsafety = fn_sig.unsafety().prefix_str();
let abi = match fn_sig.abi() {
Abi::Rust => String::from(""),
diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs
index 39c61a34a..2e97312ee 100644
--- a/compiler/rustc_mir_transform/src/generator.rs
+++ b/compiler/rustc_mir_transform/src/generator.rs
@@ -52,9 +52,9 @@
use crate::deref_separator::deref_finder;
use crate::simplify;
-use crate::util::expand_aggregate;
use crate::MirPass;
-use rustc_data_structures::fx::FxHashMap;
+use rustc_data_structures::fx::{FxHashMap, FxHashSet};
+use rustc_errors::pluralize;
use rustc_hir as hir;
use rustc_hir::lang_items::LangItem;
use rustc_hir::GeneratorKind;
@@ -70,6 +70,9 @@ use rustc_mir_dataflow::impls::{
};
use rustc_mir_dataflow::storage::always_storage_live_locals;
use rustc_mir_dataflow::{self, Analysis};
+use rustc_span::def_id::DefId;
+use rustc_span::symbol::sym;
+use rustc_span::Span;
use rustc_target::abi::VariantIdx;
use rustc_target::spec::PanicStrategy;
use std::{iter, ops};
@@ -123,7 +126,7 @@ impl<'tcx> MutVisitor<'tcx> for DerefArgVisitor<'tcx> {
place,
Place {
local: SELF_ARG,
- projection: self.tcx().intern_place_elems(&[ProjectionElem::Deref]),
+ projection: self.tcx().mk_place_elems(&[ProjectionElem::Deref]),
},
self.tcx,
);
@@ -159,10 +162,9 @@ impl<'tcx> MutVisitor<'tcx> for PinArgVisitor<'tcx> {
place,
Place {
local: SELF_ARG,
- projection: self.tcx().intern_place_elems(&[ProjectionElem::Field(
- Field::new(0),
- self.ref_gen_ty,
- )]),
+ projection: self
+ .tcx()
+ .mk_place_elems(&[ProjectionElem::Field(Field::new(0), self.ref_gen_ty)]),
},
self.tcx,
);
@@ -184,7 +186,7 @@ fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtx
let mut new_projection = new_base.projection.to_vec();
new_projection.append(&mut place.projection.to_vec());
- place.projection = tcx.intern_place_elems(&new_projection);
+ place.projection = tcx.mk_place_elems(&new_projection);
}
const SELF_ARG: Local = Local::from_u32(1);
@@ -268,31 +270,26 @@ impl<'tcx> TransformVisitor<'tcx> {
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
// FIXME(swatinem): assert that `val` is indeed unit?
- statements.extend(expand_aggregate(
- Place::return_place(),
- std::iter::empty(),
- kind,
+ statements.push(Statement {
+ kind: StatementKind::Assign(Box::new((
+ Place::return_place(),
+ Rvalue::Aggregate(Box::new(kind), vec![]),
+ ))),
source_info,
- self.tcx,
- ));
+ });
return;
}
// else: `Poll::Ready(x)`, `GeneratorState::Yielded(x)` or `GeneratorState::Complete(x)`
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
- let ty = self
- .tcx
- .bound_type_of(self.state_adt_ref.variant(idx).fields[0].did)
- .subst(self.tcx, self.state_substs);
-
- statements.extend(expand_aggregate(
- Place::return_place(),
- std::iter::once((val, ty)),
- kind,
+ statements.push(Statement {
+ kind: StatementKind::Assign(Box::new((
+ Place::return_place(),
+ Rvalue::Aggregate(Box::new(kind), vec![val]),
+ ))),
source_info,
- self.tcx,
- ));
+ });
}
// Create a Place referencing a generator struct field
@@ -302,7 +299,7 @@ impl<'tcx> TransformVisitor<'tcx> {
let mut projection = base.projection.to_vec();
projection.push(ProjectionElem::Field(Field::new(idx), ty));
- Place { local: base.local, projection: self.tcx.intern_place_elems(&projection) }
+ Place { local: base.local, projection: self.tcx.mk_place_elems(&projection) }
}
// Create a statement which changes the discriminant
@@ -429,7 +426,7 @@ fn make_generator_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
let pin_did = tcx.require_lang_item(LangItem::Pin, Some(body.span));
let pin_adt_ref = tcx.adt_def(pin_did);
- let substs = tcx.intern_substs(&[ref_gen_ty.into()]);
+ let substs = tcx.mk_substs(&[ref_gen_ty.into()]);
let pin_ref_gen_ty = tcx.mk_adt(pin_adt_ref, substs);
// Replace the by ref generator argument
@@ -489,7 +486,7 @@ fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None);
- for bb in BasicBlock::new(0)..body.basic_blocks.next_index() {
+ for bb in START_BLOCK..body.basic_blocks.next_index() {
let bb_data = &body[bb];
if bb_data.is_cleanup {
continue;
@@ -854,7 +851,7 @@ fn sanitize_witness<'tcx>(
body: &Body<'tcx>,
witness: Ty<'tcx>,
upvars: Vec<Ty<'tcx>>,
- saved_locals: &GeneratorSavedLocals,
+ layout: &GeneratorLayout<'tcx>,
) {
let did = body.source.def_id();
let param_env = tcx.param_env(did);
@@ -873,31 +870,36 @@ fn sanitize_witness<'tcx>(
}
};
- for (local, decl) in body.local_decls.iter_enumerated() {
- // Ignore locals which are internal or not saved between yields.
- if !saved_locals.contains(local) || decl.internal {
+ let mut mismatches = Vec::new();
+ for fty in &layout.field_tys {
+ if fty.ignore_for_traits {
continue;
}
- let decl_ty = tcx.normalize_erasing_regions(param_env, decl.ty);
+ let decl_ty = tcx.normalize_erasing_regions(param_env, fty.ty);
// Sanity check that typeck knows about the type of locals which are
// live across a suspension point
if !allowed.contains(&decl_ty) && !allowed_upvars.contains(&decl_ty) {
- span_bug!(
- body.span,
- "Broken MIR: generator contains type {} in MIR, \
- but typeck only knows about {} and {:?}",
- decl_ty,
- allowed,
- allowed_upvars
- );
+ mismatches.push(decl_ty);
}
}
+
+ if !mismatches.is_empty() {
+ span_bug!(
+ body.span,
+ "Broken MIR: generator contains type {:?} in MIR, \
+ but typeck only knows about {} and {:?}",
+ mismatches,
+ allowed,
+ allowed_upvars
+ );
+ }
}
fn compute_layout<'tcx>(
+ tcx: TyCtxt<'tcx>,
liveness: LivenessInfo,
- body: &mut Body<'tcx>,
+ body: &Body<'tcx>,
) -> (
FxHashMap<Local, (Ty<'tcx>, VariantIdx, usize)>,
GeneratorLayout<'tcx>,
@@ -915,9 +917,33 @@ fn compute_layout<'tcx>(
let mut locals = IndexVec::<GeneratorSavedLocal, _>::new();
let mut tys = IndexVec::<GeneratorSavedLocal, _>::new();
for (saved_local, local) in saved_locals.iter_enumerated() {
- locals.push(local);
- tys.push(body.local_decls[local].ty);
debug!("generator saved local {:?} => {:?}", saved_local, local);
+
+ locals.push(local);
+ let decl = &body.local_decls[local];
+ debug!(?decl);
+
+ let ignore_for_traits = if tcx.sess.opts.unstable_opts.drop_tracking_mir {
+ match decl.local_info {
+ // Do not include raw pointers created from accessing `static` items, as those could
+ // well be re-created by another access to the same static.
+ Some(box LocalInfo::StaticRef { is_thread_local, .. }) => !is_thread_local,
+ // Fake borrows are only read by fake reads, so do not have any reality in
+ // post-analysis MIR.
+ Some(box LocalInfo::FakeBorrow) => true,
+ _ => false,
+ }
+ } else {
+ // FIXME(#105084) HIR-based drop tracking does not account for all the temporaries that
+ // MIR building may introduce. This leads to wrongly ignored types, but this is
+ // necessary for internal consistency and to avoid ICEs.
+ decl.internal
+ };
+ let decl =
+ GeneratorSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits };
+ debug!(?decl);
+
+ tys.push(decl);
}
// Leave empty variants for the UNRESUMED, RETURNED, and POISONED states.
@@ -947,7 +973,7 @@ fn compute_layout<'tcx>(
// just use the first one here. That's fine; fields do not move
// around inside generators, so it doesn't matter which variant
// index we access them by.
- remap.entry(locals[saved_local]).or_insert((tys[saved_local], variant_index, idx));
+ remap.entry(locals[saved_local]).or_insert((tys[saved_local].ty, variant_index, idx));
}
variant_fields.push(fields);
variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]);
@@ -957,6 +983,7 @@ fn compute_layout<'tcx>(
let layout =
GeneratorLayout { field_tys: tys, variant_fields, variant_source_info, storage_conflicts };
+ debug!(?layout);
(remap, layout, storage_liveness)
}
@@ -1227,7 +1254,7 @@ fn create_generator_resume_function<'tcx>(
use rustc_middle::mir::AssertKind::{ResumedAfterPanic, ResumedAfterReturn};
// Jump to the entry point on the unresumed
- cases.insert(0, (UNRESUMED, BasicBlock::new(0)));
+ cases.insert(0, (UNRESUMED, START_BLOCK));
// Panic when resumed on the returned or poisoned state
let generator_kind = body.generator_kind().unwrap();
@@ -1351,6 +1378,42 @@ fn create_cases<'tcx>(
.collect()
}
+#[instrument(level = "debug", skip(tcx), ret)]
+pub(crate) fn mir_generator_witnesses<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ def_id: DefId,
+) -> GeneratorLayout<'tcx> {
+ assert!(tcx.sess.opts.unstable_opts.drop_tracking_mir);
+ let def_id = def_id.expect_local();
+
+ let (body, _) = tcx.mir_promoted(ty::WithOptConstParam::unknown(def_id));
+ let body = body.borrow();
+ let body = &*body;
+
+ // The first argument is the generator type passed by value
+ let gen_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
+
+ // Get the interior types and substs which typeck computed
+ let movable = match *gen_ty.kind() {
+ ty::Generator(_, _, movability) => movability == hir::Movability::Movable,
+ _ => span_bug!(body.span, "unexpected generator type {}", gen_ty),
+ };
+
+ // When first entering the generator, move the resume argument into its new local.
+ let always_live_locals = always_storage_live_locals(&body);
+
+ let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
+
+ // Extract locals which are live across suspension point into `layout`
+ // `remap` gives a mapping from local indices onto generator struct indices
+ // `storage_liveness` tells us which locals have live storage at suspension points
+ let (_, generator_layout, _) = compute_layout(tcx, liveness_info, body);
+
+ check_suspend_tys(tcx, &generator_layout, &body);
+
+ generator_layout
+}
+
impl<'tcx> MirPass<'tcx> for StateTransform {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let Some(yield_ty) = body.yield_ty() else {
@@ -1363,14 +1426,14 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
// The first argument is the generator type passed by value
let gen_ty = body.local_decls.raw[1].ty;
- // Get the interior types and substs which typeck computed
- let (upvars, interior, discr_ty, movable) = match *gen_ty.kind() {
+ // Get the discriminant type and substs which typeck computed
+ let (discr_ty, upvars, interior, movable) = match *gen_ty.kind() {
ty::Generator(_, substs, movability) => {
let substs = substs.as_generator();
(
- substs.upvar_tys().collect(),
- substs.witness(),
substs.discr_ty(tcx),
+ substs.upvar_tys().collect::<Vec<_>>(),
+ substs.witness(),
movability == hir::Movability::Movable,
)
}
@@ -1386,13 +1449,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
// Compute Poll<return_ty>
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
let poll_adt_ref = tcx.adt_def(poll_did);
- let poll_substs = tcx.intern_substs(&[body.return_ty().into()]);
+ let poll_substs = tcx.mk_substs(&[body.return_ty().into()]);
(poll_adt_ref, poll_substs)
} else {
// Compute GeneratorState<yield_ty, return_ty>
let state_did = tcx.require_lang_item(LangItem::GeneratorState, None);
let state_adt_ref = tcx.adt_def(state_did);
- let state_substs = tcx.intern_substs(&[yield_ty.into(), body.return_ty().into()]);
+ let state_substs = tcx.mk_substs(&[yield_ty.into(), body.return_ty().into()]);
(state_adt_ref, state_substs)
};
let ret_ty = tcx.mk_adt(state_adt_ref, state_substs);
@@ -1417,7 +1480,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
// When first entering the generator, move the resume argument into its new local.
let source_info = SourceInfo::outermost(body.span);
- let stmts = &mut body.basic_blocks_mut()[BasicBlock::new(0)].statements;
+ let stmts = &mut body.basic_blocks_mut()[START_BLOCK].statements;
stmts.insert(
0,
Statement {
@@ -1434,8 +1497,6 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
let liveness_info =
locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
- sanitize_witness(tcx, body, interior, upvars, &liveness_info.saved_locals);
-
if tcx.sess.opts.unstable_opts.validate_mir {
let mut vis = EnsureGeneratorFieldAssignmentsNeverAlias {
assigned_local: None,
@@ -1449,7 +1510,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
// Extract locals which are live across suspension point into `layout`
// `remap` gives a mapping from local indices onto generator struct indices
// `storage_liveness` tells us which locals have live storage at suspension points
- let (remap, layout, storage_liveness) = compute_layout(liveness_info, body);
+ let (remap, layout, storage_liveness) = compute_layout(tcx, liveness_info, body);
+
+ if tcx.sess.opts.unstable_opts.validate_mir
+ && !tcx.sess.opts.unstable_opts.drop_tracking_mir
+ {
+ sanitize_witness(tcx, body, interior, upvars, &layout);
+ }
let can_return = can_return(tcx, body, tcx.param_env(body.source.def_id()));
@@ -1583,6 +1650,7 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> {
| StatementKind::AscribeUserType(..)
| StatementKind::Coverage(..)
| StatementKind::Intrinsic(..)
+ | StatementKind::ConstEvalCounter
| StatementKind::Nop => {}
}
}
@@ -1631,3 +1699,212 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> {
}
}
}
+
+fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &GeneratorLayout<'tcx>, body: &Body<'tcx>) {
+ let mut linted_tys = FxHashSet::default();
+
+ // We want a user-facing param-env.
+ let param_env = tcx.param_env(body.source.def_id());
+
+ for (variant, yield_source_info) in
+ layout.variant_fields.iter().zip(&layout.variant_source_info)
+ {
+ debug!(?variant);
+ for &local in variant {
+ let decl = &layout.field_tys[local];
+ debug!(?decl);
+
+ if !decl.ignore_for_traits && linted_tys.insert(decl.ty) {
+ let Some(hir_id) = decl.source_info.scope.lint_root(&body.source_scopes) else { continue };
+
+ check_must_not_suspend_ty(
+ tcx,
+ decl.ty,
+ hir_id,
+ param_env,
+ SuspendCheckData {
+ source_span: decl.source_info.span,
+ yield_span: yield_source_info.span,
+ plural_len: 1,
+ ..Default::default()
+ },
+ );
+ }
+ }
+ }
+}
+
+#[derive(Default)]
+struct SuspendCheckData<'a> {
+ source_span: Span,
+ yield_span: Span,
+ descr_pre: &'a str,
+ descr_post: &'a str,
+ plural_len: usize,
+}
+
+// Returns whether it emitted a diagnostic or not
+// Note that this fn and the proceeding one are based on the code
+// for creating must_use diagnostics
+//
+// Note that this technique was chosen over things like a `Suspend` marker trait
+// as it is simpler and has precedent in the compiler
+fn check_must_not_suspend_ty<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ ty: Ty<'tcx>,
+ hir_id: hir::HirId,
+ param_env: ty::ParamEnv<'tcx>,
+ data: SuspendCheckData<'_>,
+) -> bool {
+ if ty.is_unit() {
+ return false;
+ }
+
+ let plural_suffix = pluralize!(data.plural_len);
+
+ debug!("Checking must_not_suspend for {}", ty);
+
+ match *ty.kind() {
+ ty::Adt(..) if ty.is_box() => {
+ let boxed_ty = ty.boxed_ty();
+ let descr_pre = &format!("{}boxed ", data.descr_pre);
+ check_must_not_suspend_ty(
+ tcx,
+ boxed_ty,
+ hir_id,
+ param_env,
+ SuspendCheckData { descr_pre, ..data },
+ )
+ }
+ ty::Adt(def, _) => check_must_not_suspend_def(tcx, def.did(), hir_id, data),
+ // FIXME: support adding the attribute to TAITs
+ ty::Alias(ty::Opaque, ty::AliasTy { def_id: def, .. }) => {
+ let mut has_emitted = false;
+ for &(predicate, _) in tcx.explicit_item_bounds(def) {
+ // We only look at the `DefId`, so it is safe to skip the binder here.
+ if let ty::PredicateKind::Clause(ty::Clause::Trait(ref poly_trait_predicate)) =
+ predicate.kind().skip_binder()
+ {
+ let def_id = poly_trait_predicate.trait_ref.def_id;
+ let descr_pre = &format!("{}implementer{} of ", data.descr_pre, plural_suffix);
+ if check_must_not_suspend_def(
+ tcx,
+ def_id,
+ hir_id,
+ SuspendCheckData { descr_pre, ..data },
+ ) {
+ has_emitted = true;
+ break;
+ }
+ }
+ }
+ has_emitted
+ }
+ ty::Dynamic(binder, _, _) => {
+ let mut has_emitted = false;
+ for predicate in binder.iter() {
+ if let ty::ExistentialPredicate::Trait(ref trait_ref) = predicate.skip_binder() {
+ let def_id = trait_ref.def_id;
+ let descr_post = &format!(" trait object{}{}", plural_suffix, data.descr_post);
+ if check_must_not_suspend_def(
+ tcx,
+ def_id,
+ hir_id,
+ SuspendCheckData { descr_post, ..data },
+ ) {
+ has_emitted = true;
+ break;
+ }
+ }
+ }
+ has_emitted
+ }
+ ty::Tuple(fields) => {
+ let mut has_emitted = false;
+ for (i, ty) in fields.iter().enumerate() {
+ let descr_post = &format!(" in tuple element {i}");
+ if check_must_not_suspend_ty(
+ tcx,
+ ty,
+ hir_id,
+ param_env,
+ SuspendCheckData { descr_post, ..data },
+ ) {
+ has_emitted = true;
+ }
+ }
+ has_emitted
+ }
+ ty::Array(ty, len) => {
+ let descr_pre = &format!("{}array{} of ", data.descr_pre, plural_suffix);
+ check_must_not_suspend_ty(
+ tcx,
+ ty,
+ hir_id,
+ param_env,
+ SuspendCheckData {
+ descr_pre,
+ plural_len: len.try_eval_target_usize(tcx, param_env).unwrap_or(0) as usize + 1,
+ ..data
+ },
+ )
+ }
+ // If drop tracking is enabled, we want to look through references, since the referrent
+ // may not be considered live across the await point.
+ ty::Ref(_region, ty, _mutability) => {
+ let descr_pre = &format!("{}reference{} to ", data.descr_pre, plural_suffix);
+ check_must_not_suspend_ty(
+ tcx,
+ ty,
+ hir_id,
+ param_env,
+ SuspendCheckData { descr_pre, ..data },
+ )
+ }
+ _ => false,
+ }
+}
+
+fn check_must_not_suspend_def(
+ tcx: TyCtxt<'_>,
+ def_id: DefId,
+ hir_id: hir::HirId,
+ data: SuspendCheckData<'_>,
+) -> bool {
+ if let Some(attr) = tcx.get_attr(def_id, sym::must_not_suspend) {
+ let msg = format!(
+ "{}`{}`{} held across a suspend point, but should not be",
+ data.descr_pre,
+ tcx.def_path_str(def_id),
+ data.descr_post,
+ );
+ tcx.struct_span_lint_hir(
+ rustc_session::lint::builtin::MUST_NOT_SUSPEND,
+ hir_id,
+ data.source_span,
+ msg,
+ |lint| {
+ // add span pointing to the offending yield/await
+ lint.span_label(data.yield_span, "the value is held across this suspend point");
+
+ // Add optional reason note
+ if let Some(note) = attr.value_str() {
+ // FIXME(guswynn): consider formatting this better
+ lint.span_note(data.source_span, note.as_str());
+ }
+
+ // Add some quick suggestions on what to do
+ // FIXME: can `drop` work as a suggestion here as well?
+ lint.span_help(
+ data.source_span,
+ "consider using a block (`{ ... }`) \
+ to shrink the value's scope, ending before the suspend point",
+ )
+ },
+ );
+
+ true
+ } else {
+ false
+ }
+}
diff --git a/compiler/rustc_mir_transform/src/inline.rs b/compiler/rustc_mir_transform/src/inline.rs
index 28c9080d3..6e6d6566f 100644
--- a/compiler/rustc_mir_transform/src/inline.rs
+++ b/compiler/rustc_mir_transform/src/inline.rs
@@ -96,7 +96,7 @@ fn inline<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool {
history: Vec::new(),
changed: false,
};
- let blocks = BasicBlock::new(0)..body.basic_blocks.next_index();
+ let blocks = START_BLOCK..body.basic_blocks.next_index();
this.process_blocks(body, blocks);
this.changed
}
@@ -331,7 +331,7 @@ impl<'tcx> Inliner<'tcx> {
return None;
}
- let fn_sig = self.tcx.bound_fn_sig(def_id).subst(self.tcx, substs);
+ let fn_sig = self.tcx.fn_sig(def_id).subst(self.tcx, substs);
let source_info = SourceInfo { span: fn_span, ..terminator.source_info };
return Some(CallSite { callee, fn_sig, block: bb, target, source_info });
@@ -888,7 +888,7 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
location: Location,
) {
if let ProjectionElem::Field(f, ty) = elem {
- let parent = Place { local, projection: self.tcx.intern_place_elems(proj_base) };
+ let parent = Place { local, projection: self.tcx.mk_place_elems(proj_base) };
let parent_ty = parent.ty(&self.callee_body.local_decls, self.tcx);
let check_equal = |this: &mut Self, f_ty| {
if !util::is_equal_up_to_subtyping(this.tcx, this.param_env, ty, f_ty) {
@@ -900,7 +900,7 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
let kind = match parent_ty.ty.kind() {
&ty::Alias(ty::Opaque, ty::AliasTy { def_id, substs, .. }) => {
- self.tcx.bound_type_of(def_id).subst(self.tcx, substs).kind()
+ self.tcx.type_of(def_id).subst(self.tcx, substs).kind()
}
kind => kind,
};
@@ -947,12 +947,12 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
return;
};
- let Some(&f_ty) = layout.field_tys.get(local) else {
+ let Some(f_ty) = layout.field_tys.get(local) else {
self.validation = Err("malformed MIR");
return;
};
- f_ty
+ f_ty.ty
} else {
let Some(f_ty) = substs.as_generator().prefix_tys().nth(f.index()) else {
self.validation = Err("malformed MIR");
diff --git a/compiler/rustc_mir_transform/src/inline/cycle.rs b/compiler/rustc_mir_transform/src/inline/cycle.rs
index b027f9492..792457c80 100644
--- a/compiler/rustc_mir_transform/src/inline/cycle.rs
+++ b/compiler/rustc_mir_transform/src/inline/cycle.rs
@@ -2,7 +2,7 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet};
use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_middle::mir::TerminatorKind;
-use rustc_middle::ty::TypeVisitable;
+use rustc_middle::ty::TypeVisitableExt;
use rustc_middle::ty::{self, subst::SubstsRef, InstanceDef, TyCtxt};
use rustc_session::Limit;
diff --git a/compiler/rustc_mir_transform/src/instcombine.rs b/compiler/rustc_mir_transform/src/instcombine.rs
index 2f3c65869..4182da195 100644
--- a/compiler/rustc_mir_transform/src/instcombine.rs
+++ b/compiler/rustc_mir_transform/src/instcombine.rs
@@ -6,7 +6,9 @@ use rustc_middle::mir::{
BinOp, Body, Constant, ConstantKind, LocalDecls, Operand, Place, ProjectionElem, Rvalue,
SourceInfo, Statement, StatementKind, Terminator, TerminatorKind, UnOp,
};
-use rustc_middle::ty::{self, TyCtxt};
+use rustc_middle::ty::layout::ValidityRequirement;
+use rustc_middle::ty::{self, ParamEnv, SubstsRef, Ty, TyCtxt};
+use rustc_span::symbol::Symbol;
pub struct InstCombine;
@@ -16,7 +18,11 @@ impl<'tcx> MirPass<'tcx> for InstCombine {
}
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
- let ctx = InstCombineContext { tcx, local_decls: &body.local_decls };
+ let ctx = InstCombineContext {
+ tcx,
+ local_decls: &body.local_decls,
+ param_env: tcx.param_env_reveal_all_normalized(body.source.def_id()),
+ };
for block in body.basic_blocks.as_mut() {
for statement in block.statements.iter_mut() {
match statement.kind {
@@ -24,6 +30,7 @@ impl<'tcx> MirPass<'tcx> for InstCombine {
ctx.combine_bool_cmp(&statement.source_info, rvalue);
ctx.combine_ref_deref(&statement.source_info, rvalue);
ctx.combine_len(&statement.source_info, rvalue);
+ ctx.combine_cast(&statement.source_info, rvalue);
}
_ => {}
}
@@ -33,6 +40,10 @@ impl<'tcx> MirPass<'tcx> for InstCombine {
&mut block.terminator.as_mut().unwrap(),
&mut block.statements,
);
+ ctx.combine_intrinsic_assert(
+ &mut block.terminator.as_mut().unwrap(),
+ &mut block.statements,
+ );
}
}
}
@@ -40,6 +51,7 @@ impl<'tcx> MirPass<'tcx> for InstCombine {
struct InstCombineContext<'tcx, 'a> {
tcx: TyCtxt<'tcx>,
local_decls: &'a LocalDecls<'tcx>,
+ param_env: ParamEnv<'tcx>,
}
impl<'tcx> InstCombineContext<'tcx, '_> {
@@ -99,11 +111,7 @@ impl<'tcx> InstCombineContext<'tcx, '_> {
fn combine_ref_deref(&self, source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) {
if let Rvalue::Ref(_, _, place) = rvalue {
if let Some((base, ProjectionElem::Deref)) = place.as_ref().last_projection() {
- if let ty::Ref(_, _, Mutability::Not) =
- base.ty(self.local_decls, self.tcx).ty.kind()
- {
- // The dereferenced place must have type `&_`, so that we don't copy `&mut _`.
- } else {
+ if rvalue.ty(self.local_decls, self.tcx) != base.ty(self.local_decls, self.tcx).ty {
return;
}
@@ -113,7 +121,7 @@ impl<'tcx> InstCombineContext<'tcx, '_> {
*rvalue = Rvalue::Use(Operand::Copy(Place {
local: base.local,
- projection: self.tcx.intern_place_elems(base.projection),
+ projection: self.tcx.mk_place_elems(base.projection),
}));
}
}
@@ -135,6 +143,14 @@ impl<'tcx> InstCombineContext<'tcx, '_> {
}
}
+ fn combine_cast(&self, _source_info: &SourceInfo, rvalue: &mut Rvalue<'tcx>) {
+ if let Rvalue::Cast(_kind, operand, ty) = rvalue {
+ if operand.ty(self.local_decls, self.tcx) == *ty {
+ *rvalue = Rvalue::Use(operand.clone());
+ }
+ }
+ }
+
fn combine_primitive_clone(
&self,
terminator: &mut Terminator<'tcx>,
@@ -200,4 +216,58 @@ impl<'tcx> InstCombineContext<'tcx, '_> {
});
terminator.kind = TerminatorKind::Goto { target: destination_block };
}
+
+ fn combine_intrinsic_assert(
+ &self,
+ terminator: &mut Terminator<'tcx>,
+ _statements: &mut Vec<Statement<'tcx>>,
+ ) {
+ let TerminatorKind::Call { func, target, .. } = &mut terminator.kind else { return; };
+ let Some(target_block) = target else { return; };
+ let func_ty = func.ty(self.local_decls, self.tcx);
+ let Some((intrinsic_name, substs)) = resolve_rust_intrinsic(self.tcx, func_ty) else {
+ return;
+ };
+ // The intrinsics we are interested in have one generic parameter
+ if substs.is_empty() {
+ return;
+ }
+ let ty = substs.type_at(0);
+
+ let known_is_valid = intrinsic_assert_panics(self.tcx, self.param_env, ty, intrinsic_name);
+ match known_is_valid {
+ // We don't know the layout or it's not validity assertion at all, don't touch it
+ None => {}
+ Some(true) => {
+ // If we know the assert panics, indicate to later opts that the call diverges
+ *target = None;
+ }
+ Some(false) => {
+ // If we know the assert does not panic, turn the call into a Goto
+ terminator.kind = TerminatorKind::Goto { target: *target_block };
+ }
+ }
+ }
+}
+
+fn intrinsic_assert_panics<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ param_env: ty::ParamEnv<'tcx>,
+ ty: Ty<'tcx>,
+ intrinsic_name: Symbol,
+) -> Option<bool> {
+ let requirement = ValidityRequirement::from_intrinsic(intrinsic_name)?;
+ Some(!tcx.check_validity_requirement((requirement, param_env.and(ty))).ok()?)
+}
+
+fn resolve_rust_intrinsic<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ func_ty: Ty<'tcx>,
+) -> Option<(Symbol, SubstsRef<'tcx>)> {
+ if let ty::FnDef(def_id, substs) = *func_ty.kind() {
+ if tcx.is_intrinsic(def_id) {
+ return Some((tcx.item_name(def_id), substs));
+ }
+ }
+ None
}
diff --git a/compiler/rustc_mir_transform/src/large_enums.rs b/compiler/rustc_mir_transform/src/large_enums.rs
new file mode 100644
index 000000000..89e0a007d
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/large_enums.rs
@@ -0,0 +1,294 @@
+use crate::rustc_middle::ty::util::IntTypeExt;
+use crate::MirPass;
+use rustc_data_structures::fx::FxHashMap;
+use rustc_middle::mir::interpret::AllocId;
+use rustc_middle::mir::*;
+use rustc_middle::ty::{self, AdtDef, ParamEnv, Ty, TyCtxt};
+use rustc_session::Session;
+use rustc_target::abi::{HasDataLayout, Size, TagEncoding, Variants};
+
+/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large
+/// enough discrepancy between them.
+///
+/// i.e. If there is are two variants:
+/// ```
+/// enum Example {
+/// Small,
+/// Large([u32; 1024]),
+/// }
+/// ```
+/// Instead of emitting moves of the large variant,
+/// Perform a memcpy instead.
+/// Based off of [this HackMD](https://hackmd.io/@ft4bxUsFT5CEUBmRKYHr7w/rJM8BBPzD).
+///
+/// In summary, what this does is at runtime determine which enum variant is active,
+/// and instead of copying all the bytes of the largest possible variant,
+/// copy only the bytes for the currently active variant.
+pub struct EnumSizeOpt {
+ pub(crate) discrepancy: u64,
+}
+
+impl<'tcx> MirPass<'tcx> for EnumSizeOpt {
+ fn is_enabled(&self, sess: &Session) -> bool {
+ sess.opts.unstable_opts.unsound_mir_opts || sess.mir_opt_level() >= 3
+ }
+ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ // NOTE: This pass may produce different MIR based on the alignment of the target
+ // platform, but it will still be valid.
+ self.optim(tcx, body);
+ }
+}
+
+impl EnumSizeOpt {
+ fn candidate<'tcx>(
+ &self,
+ tcx: TyCtxt<'tcx>,
+ param_env: ParamEnv<'tcx>,
+ ty: Ty<'tcx>,
+ alloc_cache: &mut FxHashMap<Ty<'tcx>, AllocId>,
+ ) -> Option<(AdtDef<'tcx>, usize, AllocId)> {
+ let adt_def = match ty.kind() {
+ ty::Adt(adt_def, _substs) if adt_def.is_enum() => adt_def,
+ _ => return None,
+ };
+ let layout = tcx.layout_of(param_env.and(ty)).ok()?;
+ let variants = match &layout.variants {
+ Variants::Single { .. } => return None,
+ Variants::Multiple { tag_encoding, .. }
+ if matches!(tag_encoding, TagEncoding::Niche { .. }) =>
+ {
+ return None;
+ }
+ Variants::Multiple { variants, .. } if variants.len() <= 1 => return None,
+ Variants::Multiple { variants, .. } => variants,
+ };
+ let min = variants.iter().map(|v| v.size).min().unwrap();
+ let max = variants.iter().map(|v| v.size).max().unwrap();
+ if max.bytes() - min.bytes() < self.discrepancy {
+ return None;
+ }
+
+ let num_discrs = adt_def.discriminants(tcx).count();
+ if variants.iter_enumerated().any(|(var_idx, _)| {
+ let discr_for_var = adt_def.discriminant_for_variant(tcx, var_idx).val;
+ (discr_for_var > usize::MAX as u128) || (discr_for_var as usize >= num_discrs)
+ }) {
+ return None;
+ }
+ if let Some(alloc_id) = alloc_cache.get(&ty) {
+ return Some((*adt_def, num_discrs, *alloc_id));
+ }
+
+ let data_layout = tcx.data_layout();
+ let ptr_sized_int = data_layout.ptr_sized_integer();
+ let target_bytes = ptr_sized_int.size().bytes() as usize;
+ let mut data = vec![0; target_bytes * num_discrs];
+ macro_rules! encode_store {
+ ($curr_idx: expr, $endian: expr, $bytes: expr) => {
+ let bytes = match $endian {
+ rustc_target::abi::Endian::Little => $bytes.to_le_bytes(),
+ rustc_target::abi::Endian::Big => $bytes.to_be_bytes(),
+ };
+ for (i, b) in bytes.into_iter().enumerate() {
+ data[$curr_idx + i] = b;
+ }
+ };
+ }
+
+ for (var_idx, layout) in variants.iter_enumerated() {
+ let curr_idx =
+ target_bytes * adt_def.discriminant_for_variant(tcx, var_idx).val as usize;
+ let sz = layout.size;
+ match ptr_sized_int {
+ rustc_target::abi::Integer::I32 => {
+ encode_store!(curr_idx, data_layout.endian, sz.bytes() as u32);
+ }
+ rustc_target::abi::Integer::I64 => {
+ encode_store!(curr_idx, data_layout.endian, sz.bytes());
+ }
+ _ => unreachable!(),
+ };
+ }
+ let alloc = interpret::Allocation::from_bytes(
+ data,
+ tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi,
+ Mutability::Not,
+ );
+ let alloc = tcx.create_memory_alloc(tcx.mk_const_alloc(alloc));
+ Some((*adt_def, num_discrs, *alloc_cache.entry(ty).or_insert(alloc)))
+ }
+ fn optim<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ let mut alloc_cache = FxHashMap::default();
+ let body_did = body.source.def_id();
+ let param_env = tcx.param_env(body_did);
+
+ let blocks = body.basic_blocks.as_mut();
+ let local_decls = &mut body.local_decls;
+
+ for bb in blocks {
+ bb.expand_statements(|st| {
+ if let StatementKind::Assign(box (
+ lhs,
+ Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
+ )) = &st.kind
+ {
+ let ty = lhs.ty(local_decls, tcx).ty;
+
+ let source_info = st.source_info;
+ let span = source_info.span;
+
+ let (adt_def, num_variants, alloc_id) =
+ self.candidate(tcx, param_env, ty, &mut alloc_cache)?;
+ let alloc = tcx.global_alloc(alloc_id).unwrap_memory();
+
+ let tmp_ty = tcx.mk_array(tcx.types.usize, num_variants as u64);
+
+ let size_array_local = local_decls.push(LocalDecl::new(tmp_ty, span));
+ let store_live = Statement {
+ source_info,
+ kind: StatementKind::StorageLive(size_array_local),
+ };
+
+ let place = Place::from(size_array_local);
+ let constant_vals = Constant {
+ span,
+ user_ty: None,
+ literal: ConstantKind::Val(
+ interpret::ConstValue::ByRef { alloc, offset: Size::ZERO },
+ tmp_ty,
+ ),
+ };
+ let rval = Rvalue::Use(Operand::Constant(box (constant_vals)));
+
+ let const_assign =
+ Statement { source_info, kind: StatementKind::Assign(box (place, rval)) };
+
+ let discr_place = Place::from(
+ local_decls
+ .push(LocalDecl::new(adt_def.repr().discr_type().to_ty(tcx), span)),
+ );
+
+ let store_discr = Statement {
+ source_info,
+ kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(*rhs))),
+ };
+
+ let discr_cast_place =
+ Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
+
+ let cast_discr = Statement {
+ source_info,
+ kind: StatementKind::Assign(box (
+ discr_cast_place,
+ Rvalue::Cast(
+ CastKind::IntToInt,
+ Operand::Copy(discr_place),
+ tcx.types.usize,
+ ),
+ )),
+ };
+
+ let size_place =
+ Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
+
+ let store_size = Statement {
+ source_info,
+ kind: StatementKind::Assign(box (
+ size_place,
+ Rvalue::Use(Operand::Copy(Place {
+ local: size_array_local,
+ projection: tcx
+ .mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]),
+ })),
+ )),
+ };
+
+ let dst =
+ Place::from(local_decls.push(LocalDecl::new(tcx.mk_mut_ptr(ty), span)));
+
+ let dst_ptr = Statement {
+ source_info,
+ kind: StatementKind::Assign(box (
+ dst,
+ Rvalue::AddressOf(Mutability::Mut, *lhs),
+ )),
+ };
+
+ let dst_cast_ty = tcx.mk_mut_ptr(tcx.types.u8);
+ let dst_cast_place =
+ Place::from(local_decls.push(LocalDecl::new(dst_cast_ty, span)));
+
+ let dst_cast = Statement {
+ source_info,
+ kind: StatementKind::Assign(box (
+ dst_cast_place,
+ Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
+ )),
+ };
+
+ let src =
+ Place::from(local_decls.push(LocalDecl::new(tcx.mk_imm_ptr(ty), span)));
+
+ let src_ptr = Statement {
+ source_info,
+ kind: StatementKind::Assign(box (
+ src,
+ Rvalue::AddressOf(Mutability::Not, *rhs),
+ )),
+ };
+
+ let src_cast_ty = tcx.mk_imm_ptr(tcx.types.u8);
+ let src_cast_place =
+ Place::from(local_decls.push(LocalDecl::new(src_cast_ty, span)));
+
+ let src_cast = Statement {
+ source_info,
+ kind: StatementKind::Assign(box (
+ src_cast_place,
+ Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
+ )),
+ };
+
+ let deinit_old =
+ Statement { source_info, kind: StatementKind::Deinit(box dst) };
+
+ let copy_bytes = Statement {
+ source_info,
+ kind: StatementKind::Intrinsic(
+ box NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
+ src: Operand::Copy(src_cast_place),
+ dst: Operand::Copy(dst_cast_place),
+ count: Operand::Copy(size_place),
+ }),
+ ),
+ };
+
+ let store_dead = Statement {
+ source_info,
+ kind: StatementKind::StorageDead(size_array_local),
+ };
+ let iter = [
+ store_live,
+ const_assign,
+ store_discr,
+ cast_discr,
+ store_size,
+ dst_ptr,
+ dst_cast,
+ src_ptr,
+ src_cast,
+ deinit_old,
+ copy_bytes,
+ store_dead,
+ ]
+ .into_iter();
+
+ st.make_nop();
+ Some(iter)
+ } else {
+ None
+ }
+ });
+ }
+ }
+}
diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs
index 20b7fdcfe..cdd28ae0c 100644
--- a/compiler/rustc_mir_transform/src/lib.rs
+++ b/compiler/rustc_mir_transform/src/lib.rs
@@ -1,6 +1,7 @@
#![allow(rustc::potential_query_instability)]
#![feature(box_patterns)]
#![feature(drain_filter)]
+#![feature(box_syntax)]
#![feature(let_chains)]
#![feature(map_try_insert)]
#![feature(min_specialization)]
@@ -23,6 +24,7 @@ use rustc_const_eval::util;
use rustc_data_structures::fx::FxIndexSet;
use rustc_data_structures::steal::Steal;
use rustc_hir as hir;
+use rustc_hir::def::DefKind;
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_hir::intravisit::{self, Visitor};
use rustc_index::vec::IndexVec;
@@ -33,7 +35,7 @@ use rustc_middle::mir::{
TerminatorKind,
};
use rustc_middle::ty::query::Providers;
-use rustc_middle::ty::{self, TyCtxt, TypeVisitable};
+use rustc_middle::ty::{self, TyCtxt, TypeVisitableExt};
use rustc_span::sym;
#[macro_use]
@@ -54,10 +56,11 @@ mod const_debuginfo;
mod const_goto;
mod const_prop;
mod const_prop_lint;
+mod copy_prop;
mod coverage;
+mod ctfe_limit;
mod dataflow_const_prop;
mod dead_store_elimination;
-mod deaggregator;
mod deduce_param_attrs;
mod deduplicate_blocks;
mod deref_separator;
@@ -71,6 +74,7 @@ mod function_item_references;
mod generator;
mod inline;
mod instcombine;
+mod large_enums;
mod lower_intrinsics;
mod lower_slice_len;
mod match_branches;
@@ -86,11 +90,11 @@ mod required_consts;
mod reveal_all;
mod separate_const_switch;
mod shim;
+mod ssa;
// This pass is public to allow external drivers to perform MIR cleanup
pub mod simplify;
mod simplify_branches;
mod simplify_comparison_integral;
-mod simplify_try;
mod sroa;
mod uninhabited_enum_branching;
mod unreachable_prop;
@@ -102,7 +106,6 @@ use rustc_mir_dataflow::rustc_peek;
pub fn provide(providers: &mut Providers) {
check_unsafety::provide(providers);
- check_packed_ref::provide(providers);
coverage::query::provide(providers);
ffi_unwind_calls::provide(providers);
shim::provide(providers);
@@ -124,6 +127,7 @@ pub fn provide(providers: &mut Providers) {
mir_drops_elaborated_and_const_checked,
mir_for_ctfe,
mir_for_ctfe_of_const_arg,
+ mir_generator_witnesses: generator::mir_generator_witnesses,
optimized_mir,
is_mir_available,
is_ctfe_mir_available: |tcx, did| is_mir_available(tcx, did),
@@ -188,7 +192,7 @@ fn remap_mir_for_const_eval_select<'tcx>(
let arguments = (0..num_args).map(|x| {
let mut place_elems = place_elems.to_vec();
place_elems.push(ProjectionElem::Field(x.into(), fields[x]));
- let projection = tcx.intern_place_elems(&place_elems);
+ let projection = tcx.mk_place_elems(&place_elems);
let place = Place {
local: place.local,
projection,
@@ -244,7 +248,7 @@ fn mir_const_qualif(tcx: TyCtxt<'_>, def: ty::WithOptConstParam<LocalDefId>) ->
// N.B., this `borrow()` is guaranteed to be valid (i.e., the value
// cannot yet be stolen), because `mir_promoted()`, which steals
- // from `mir_const(), forces this query to execute before
+ // from `mir_const()`, forces this query to execute before
// performing the steal.
let body = &tcx.mir_const(def).borrow();
@@ -410,6 +414,8 @@ fn inner_mir_for_ctfe(tcx: TyCtxt<'_>, def: ty::WithOptConstParam<LocalDefId>) -
}
}
+ pm::run_passes(tcx, &mut body, &[&ctfe_limit::CtfeLimit], None);
+
debug_assert!(!body.has_free_regions(), "Free regions in MIR for CTFE");
body
@@ -426,6 +432,11 @@ fn mir_drops_elaborated_and_const_checked(
return tcx.mir_drops_elaborated_and_const_checked(def);
}
+ if tcx.sess.opts.unstable_opts.drop_tracking_mir
+ && let DefKind::Generator = tcx.def_kind(def.did)
+ {
+ tcx.ensure().mir_generator_witnesses(def.did);
+ }
let mir_borrowck = tcx.mir_borrowck_opt_const_arg(def);
let is_fn_like = tcx.def_kind(def.did).is_fn_like();
@@ -513,9 +524,6 @@ fn run_runtime_lowering_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&elaborate_box_derefs::ElaborateBoxDerefs,
&generator::StateTransform,
&add_retag::AddRetag,
- // Deaggregator is necessary for const prop. We may want to consider implementing
- // CTFE support for aggregates.
- &deaggregator::Deaggregator,
&Lint(const_prop_lint::ConstProp),
];
pm::run_passes_no_validate(tcx, body, passes, Some(MirPhase::Runtime(RuntimePhase::Initial)));
@@ -541,13 +549,13 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&[
&reveal_all::RevealAll, // has to be done before inlining, since inlined code is in RevealAll mode.
&lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first
- &normalize_array_len::NormalizeArrayLen, // has to run after `slice::len` lowering
&unreachable_prop::UnreachablePropagation,
&uninhabited_enum_branching::UninhabitedEnumBranching,
&o1(simplify::SimplifyCfg::new("after-uninhabited-enum-branching")),
&inline::Inline,
&remove_storage_markers::RemoveStorageMarkers,
&remove_zsts::RemoveZsts,
+ &normalize_array_len::NormalizeArrayLen, // has to run after `slice::len` lowering
&const_goto::ConstGoto,
&remove_unneeded_drops::RemoveUnneededDrops,
&sroa::ScalarReplacementOfAggregates,
@@ -557,6 +565,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&instcombine::InstCombine,
&separate_const_switch::SeparateConstSwitch,
&simplify::SimplifyLocals::new("before-const-prop"),
+ &copy_prop::CopyProp,
//
// FIXME(#70073): This pass is responsible for both optimization as well as some lints.
&const_prop::ConstProp,
@@ -567,8 +576,6 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&o1(simplify_branches::SimplifyConstCondition::new("after-const-prop")),
&early_otherwise_branch::EarlyOtherwiseBranch,
&simplify_comparison_integral::SimplifyComparisonIntegral,
- &simplify_try::SimplifyArmIdentity,
- &simplify_try::SimplifyBranchSame,
&dead_store_elimination::DeadStoreElimination,
&dest_prop::DestinationPropagation,
&o1(simplify_branches::SimplifyConstCondition::new("final")),
@@ -578,6 +585,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&simplify::SimplifyLocals::new("final"),
&multiple_return_terminators::MultipleReturnTerminators,
&deduplicate_blocks::DeduplicateBlocks,
+ &large_enums::EnumSizeOpt { discrepancy: 128 },
// Some cleanup necessary at least for LLVM and potentially other codegen backends.
&add_call_guards::CriticalCallEdges,
// Dump the end result for testing and debugging purposes.
diff --git a/compiler/rustc_mir_transform/src/lower_intrinsics.rs b/compiler/rustc_mir_transform/src/lower_intrinsics.rs
index 9892580e6..f596cc180 100644
--- a/compiler/rustc_mir_transform/src/lower_intrinsics.rs
+++ b/compiler/rustc_mir_transform/src/lower_intrinsics.rs
@@ -107,9 +107,29 @@ impl<'tcx> MirPass<'tcx> for LowerIntrinsics {
}
}
sym::add_with_overflow | sym::sub_with_overflow | sym::mul_with_overflow => {
- // The checked binary operations are not suitable target for lowering here,
- // since their semantics depend on the value of overflow-checks flag used
- // during codegen. Issue #35310.
+ if let Some(target) = *target {
+ let lhs;
+ let rhs;
+ {
+ let mut args = args.drain(..);
+ lhs = args.next().unwrap();
+ rhs = args.next().unwrap();
+ }
+ let bin_op = match intrinsic_name {
+ sym::add_with_overflow => BinOp::Add,
+ sym::sub_with_overflow => BinOp::Sub,
+ sym::mul_with_overflow => BinOp::Mul,
+ _ => bug!("unexpected intrinsic"),
+ };
+ block.statements.push(Statement {
+ source_info: terminator.source_info,
+ kind: StatementKind::Assign(Box::new((
+ *destination,
+ Rvalue::CheckedBinaryOp(bin_op, Box::new((lhs, rhs))),
+ ))),
+ });
+ terminator.kind = TerminatorKind::Goto { target };
+ }
}
sym::size_of | sym::min_align_of => {
if let Some(target) = *target {
diff --git a/compiler/rustc_mir_transform/src/lower_slice_len.rs b/compiler/rustc_mir_transform/src/lower_slice_len.rs
index 2f02d00ec..c6e7468aa 100644
--- a/compiler/rustc_mir_transform/src/lower_slice_len.rs
+++ b/compiler/rustc_mir_transform/src/lower_slice_len.rs
@@ -68,8 +68,11 @@ fn lower_slice_len_call<'tcx>(
ty::FnDef(fn_def_id, _) if fn_def_id == &slice_len_fn_item_def_id => {
// perform modifications
// from something like `_5 = core::slice::<impl [u8]>::len(move _6) -> bb1`
- // into `_5 = Len(*_6)
+ // into:
+ // ```
+ // _5 = Len(*_6)
// goto bb1
+ // ```
// make new RValue for Len
let deref_arg = tcx.mk_place_deref(arg);
diff --git a/compiler/rustc_mir_transform/src/normalize_array_len.rs b/compiler/rustc_mir_transform/src/normalize_array_len.rs
index 1708b287e..b36c8a0bd 100644
--- a/compiler/rustc_mir_transform/src/normalize_array_len.rs
+++ b/compiler/rustc_mir_transform/src/normalize_array_len.rs
@@ -1,288 +1,104 @@
//! This pass eliminates casting of arrays into slices when their length
//! is taken using `.len()` method. Handy to preserve information in MIR for const prop
+use crate::ssa::SsaLocals;
use crate::MirPass;
-use rustc_data_structures::fx::FxIndexMap;
-use rustc_data_structures::intern::Interned;
-use rustc_index::bit_set::BitSet;
use rustc_index::vec::IndexVec;
+use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
-use rustc_middle::ty::{self, ReErased, Region, TyCtxt};
-
-const MAX_NUM_BLOCKS: usize = 800;
-const MAX_NUM_LOCALS: usize = 3000;
+use rustc_middle::ty::{self, TyCtxt};
+use rustc_mir_dataflow::impls::borrowed_locals;
pub struct NormalizeArrayLen;
impl<'tcx> MirPass<'tcx> for NormalizeArrayLen {
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
- // See #105929
- sess.mir_opt_level() >= 4 && sess.opts.unstable_opts.unsound_mir_opts
+ sess.mir_opt_level() >= 3
}
+ #[instrument(level = "trace", skip(self, tcx, body))]
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
- // early returns for edge cases of highly unrolled functions
- if body.basic_blocks.len() > MAX_NUM_BLOCKS {
- return;
- }
- if body.local_decls.len() > MAX_NUM_LOCALS {
- return;
- }
+ debug!(def_id = ?body.source.def_id());
normalize_array_len_calls(tcx, body)
}
}
-pub fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
- // We don't ever touch terminators, so no need to invalidate the CFG cache
- let basic_blocks = body.basic_blocks.as_mut_preserves_cfg();
- let local_decls = &mut body.local_decls;
+fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
+ let borrowed_locals = borrowed_locals(body);
+ let ssa = SsaLocals::new(tcx, param_env, body, &borrowed_locals);
- // do a preliminary analysis to see if we ever have locals of type `[T;N]` or `&[T;N]`
- let mut interesting_locals = BitSet::new_empty(local_decls.len());
- for (local, decl) in local_decls.iter_enumerated() {
- match decl.ty.kind() {
- ty::Array(..) => {
- interesting_locals.insert(local);
- }
- ty::Ref(.., ty, Mutability::Not) => match ty.kind() {
- ty::Array(..) => {
- interesting_locals.insert(local);
- }
- _ => {}
- },
- _ => {}
- }
- }
- if interesting_locals.is_empty() {
- // we have found nothing to analyze
- return;
- }
- let num_intesting_locals = interesting_locals.count();
- let mut state = FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
- let mut patches_scratchpad =
- FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
- let mut replacements_scratchpad =
- FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
- for block in basic_blocks {
- // make length calls for arrays [T; N] not to decay into length calls for &[T]
- // that forbids constant propagation
- normalize_array_len_call(
- tcx,
- block,
- local_decls,
- &interesting_locals,
- &mut state,
- &mut patches_scratchpad,
- &mut replacements_scratchpad,
- );
- state.clear();
- patches_scratchpad.clear();
- replacements_scratchpad.clear();
- }
-}
+ let slice_lengths = compute_slice_length(tcx, &ssa, body);
+ debug!(?slice_lengths);
-struct Patcher<'a, 'tcx> {
- tcx: TyCtxt<'tcx>,
- patches_scratchpad: &'a FxIndexMap<usize, usize>,
- replacements_scratchpad: &'a mut FxIndexMap<usize, Local>,
- local_decls: &'a mut IndexVec<Local, LocalDecl<'tcx>>,
- statement_idx: usize,
+ Replacer { tcx, slice_lengths }.visit_body_preserves_cfg(body);
}
-impl<'tcx> Patcher<'_, 'tcx> {
- fn patch_expand_statement(
- &mut self,
- statement: &mut Statement<'tcx>,
- ) -> Option<std::vec::IntoIter<Statement<'tcx>>> {
- let idx = self.statement_idx;
- if let Some(len_statemnt_idx) = self.patches_scratchpad.get(&idx).copied() {
- let mut statements = Vec::with_capacity(2);
-
- // we are at statement that performs a cast. The only sound way is
- // to create another local that performs a similar copy without a cast and then
- // use this copy in the Len operation
-
- match &statement.kind {
- StatementKind::Assign(box (
- ..,
- Rvalue::Cast(
- CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
- operand,
- _,
- ),
- )) => {
- match operand {
- Operand::Copy(place) | Operand::Move(place) => {
- // create new local
- let ty = operand.ty(self.local_decls, self.tcx);
- let local_decl = LocalDecl::with_source_info(ty, statement.source_info);
- let local = self.local_decls.push(local_decl);
- // make it live
- let mut make_live_statement = statement.clone();
- make_live_statement.kind = StatementKind::StorageLive(local);
- statements.push(make_live_statement);
- // copy into it
-
- let operand = Operand::Copy(*place);
- let mut make_copy_statement = statement.clone();
- let assign_to = Place::from(local);
- let rvalue = Rvalue::Use(operand);
- make_copy_statement.kind =
- StatementKind::Assign(Box::new((assign_to, rvalue)));
- statements.push(make_copy_statement);
-
- // to reorder we have to copy and make NOP
- statements.push(statement.clone());
- statement.make_nop();
-
- self.replacements_scratchpad.insert(len_statemnt_idx, local);
- }
- _ => {
- unreachable!("it's a bug in the implementation")
- }
- }
- }
- _ => {
- unreachable!("it's a bug in the implementation")
+fn compute_slice_length<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ ssa: &SsaLocals,
+ body: &Body<'tcx>,
+) -> IndexVec<Local, Option<ty::Const<'tcx>>> {
+ let mut slice_lengths = IndexVec::from_elem(None, &body.local_decls);
+
+ for (local, rvalue) in ssa.assignments(body) {
+ match rvalue {
+ Rvalue::Cast(
+ CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
+ operand,
+ cast_ty,
+ ) => {
+ let operand_ty = operand.ty(body, tcx);
+ debug!(?operand_ty);
+ if let Some(operand_ty) = operand_ty.builtin_deref(true)
+ && let ty::Array(_, len) = operand_ty.ty.kind()
+ && let Some(cast_ty) = cast_ty.builtin_deref(true)
+ && let ty::Slice(..) = cast_ty.ty.kind()
+ {
+ slice_lengths[local] = Some(*len);
}
}
-
- self.statement_idx += 1;
-
- Some(statements.into_iter())
- } else if let Some(local) = self.replacements_scratchpad.get(&idx).copied() {
- let mut statements = Vec::with_capacity(2);
-
- match &statement.kind {
- StatementKind::Assign(box (into, Rvalue::Len(place))) => {
- let add_deref = if let Some(..) = place.as_local() {
- false
- } else if let Some(..) = place.local_or_deref_local() {
- true
- } else {
- unreachable!("it's a bug in the implementation")
- };
- // replace len statement
- let mut len_statement = statement.clone();
- let mut place = Place::from(local);
- if add_deref {
- place = self.tcx.mk_place_deref(place);
- }
- len_statement.kind =
- StatementKind::Assign(Box::new((*into, Rvalue::Len(place))));
- statements.push(len_statement);
-
- // make temporary dead
- let mut make_dead_statement = statement.clone();
- make_dead_statement.kind = StatementKind::StorageDead(local);
- statements.push(make_dead_statement);
-
- // make original statement NOP
- statement.make_nop();
+ // The length information is stored in the fat pointer, so we treat `operand` as a value.
+ Rvalue::Use(operand) => {
+ if let Some(rhs) = operand.place() && let Some(rhs) = rhs.as_local() {
+ slice_lengths[local] = slice_lengths[rhs];
}
- _ => {
- unreachable!("it's a bug in the implementation")
+ }
+ // The length information is stored in the fat pointer.
+ // Reborrowing copies length information from one pointer to the other.
+ Rvalue::Ref(_, _, rhs) | Rvalue::AddressOf(_, rhs) => {
+ if let [PlaceElem::Deref] = rhs.projection[..] {
+ slice_lengths[local] = slice_lengths[rhs.local];
}
}
-
- self.statement_idx += 1;
-
- Some(statements.into_iter())
- } else {
- self.statement_idx += 1;
- None
+ _ => {}
}
}
+
+ slice_lengths
}
-fn normalize_array_len_call<'tcx>(
+struct Replacer<'tcx> {
tcx: TyCtxt<'tcx>,
- block: &mut BasicBlockData<'tcx>,
- local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
- interesting_locals: &BitSet<Local>,
- state: &mut FxIndexMap<Local, usize>,
- patches_scratchpad: &mut FxIndexMap<usize, usize>,
- replacements_scratchpad: &mut FxIndexMap<usize, Local>,
-) {
- for (statement_idx, statement) in block.statements.iter_mut().enumerate() {
- match &mut statement.kind {
- StatementKind::Assign(box (place, rvalue)) => {
- match rvalue {
- Rvalue::Cast(
- CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
- operand,
- cast_ty,
- ) => {
- let Some(local) = place.as_local() else { return };
- match operand {
- Operand::Copy(place) | Operand::Move(place) => {
- let Some(operand_local) = place.local_or_deref_local() else { return; };
- if !interesting_locals.contains(operand_local) {
- return;
- }
- let operand_ty = local_decls[operand_local].ty;
- match (operand_ty.kind(), cast_ty.kind()) {
- (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => {
- if of_ty_src == of_ty_dst {
- // this is a cast from [T; N] into [T], so we are good
- state.insert(local, statement_idx);
- }
- }
- // current way of patching doesn't allow to work with `mut`
- (
- ty::Ref(
- Region(Interned(ReErased, _)),
- operand_ty,
- Mutability::Not,
- ),
- ty::Ref(
- Region(Interned(ReErased, _)),
- cast_ty,
- Mutability::Not,
- ),
- ) => {
- match (operand_ty.kind(), cast_ty.kind()) {
- // current way of patching doesn't allow to work with `mut`
- (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => {
- if of_ty_src == of_ty_dst {
- // this is a cast from [T; N] into [T], so we are good
- state.insert(local, statement_idx);
- }
- }
- _ => {}
- }
- }
- _ => {}
- }
- }
- _ => {}
- }
- }
- Rvalue::Len(place) => {
- let Some(local) = place.local_or_deref_local() else {
- return;
- };
- if let Some(cast_statement_idx) = state.get(&local).copied() {
- patches_scratchpad.insert(cast_statement_idx, statement_idx);
- }
- }
- _ => {
- // invalidate
- state.remove(&place.local);
- }
- }
- }
- _ => {}
- }
- }
+ slice_lengths: IndexVec<Local, Option<ty::Const<'tcx>>>,
+}
- let mut patcher = Patcher {
- tcx,
- patches_scratchpad: &*patches_scratchpad,
- replacements_scratchpad,
- local_decls,
- statement_idx: 0,
- };
+impl<'tcx> MutVisitor<'tcx> for Replacer<'tcx> {
+ fn tcx(&self) -> TyCtxt<'tcx> {
+ self.tcx
+ }
- block.expand_statements(|st| patcher.patch_expand_statement(st));
+ fn visit_rvalue(&mut self, rvalue: &mut Rvalue<'tcx>, loc: Location) {
+ if let Rvalue::Len(place) = rvalue
+ && let [PlaceElem::Deref] = &place.projection[..]
+ && let Some(len) = self.slice_lengths[place.local]
+ {
+ *rvalue = Rvalue::Use(Operand::Constant(Box::new(Constant {
+ span: rustc_span::DUMMY_SP,
+ user_ty: None,
+ literal: ConstantKind::from_const(len, self.tcx),
+ })));
+ }
+ self.super_rvalue(rvalue, loc);
+ }
}
diff --git a/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs b/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs
index f1bbf2ea7..e3a03aa08 100644
--- a/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs
+++ b/compiler/rustc_mir_transform/src/remove_noop_landing_pads.rs
@@ -35,6 +35,7 @@ impl RemoveNoopLandingPads {
| StatementKind::StorageDead(_)
| StatementKind::AscribeUserType(..)
| StatementKind::Coverage(..)
+ | StatementKind::ConstEvalCounter
| StatementKind::Nop => {
// These are all noops in a landing pad
}
diff --git a/compiler/rustc_mir_transform/src/remove_zsts.rs b/compiler/rustc_mir_transform/src/remove_zsts.rs
index 6cabef92d..1becfddb2 100644
--- a/compiler/rustc_mir_transform/src/remove_zsts.rs
+++ b/compiler/rustc_mir_transform/src/remove_zsts.rs
@@ -13,7 +13,7 @@ impl<'tcx> MirPass<'tcx> for RemoveZsts {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
// Avoid query cycles (generators require optimized MIR for layout).
- if tcx.type_of(body.source.def_id()).is_generator() {
+ if tcx.type_of(body.source.def_id()).subst_identity().is_generator() {
return;
}
let param_env = tcx.param_env(body.source.def_id());
diff --git a/compiler/rustc_mir_transform/src/separate_const_switch.rs b/compiler/rustc_mir_transform/src/separate_const_switch.rs
index 2f116aaa9..a24d2d34d 100644
--- a/compiler/rustc_mir_transform/src/separate_const_switch.rs
+++ b/compiler/rustc_mir_transform/src/separate_const_switch.rs
@@ -250,6 +250,7 @@ fn is_likely_const<'tcx>(mut tracked_place: Place<'tcx>, block: &BasicBlockData<
| StatementKind::Coverage(_)
| StatementKind::StorageDead(_)
| StatementKind::Intrinsic(_)
+ | StatementKind::ConstEvalCounter
| StatementKind::Nop => {}
}
}
@@ -318,6 +319,7 @@ fn find_determining_place<'tcx>(
| StatementKind::AscribeUserType(_, _)
| StatementKind::Coverage(_)
| StatementKind::Intrinsic(_)
+ | StatementKind::ConstEvalCounter
| StatementKind::Nop => {}
// If the discriminant is set, it is always set
diff --git a/compiler/rustc_mir_transform/src/shim.rs b/compiler/rustc_mir_transform/src/shim.rs
index dace540fa..ebe63d6cb 100644
--- a/compiler/rustc_mir_transform/src/shim.rs
+++ b/compiler/rustc_mir_transform/src/shim.rs
@@ -15,7 +15,6 @@ use rustc_target::spec::abi::Abi;
use std::fmt;
use std::iter;
-use crate::util::expand_aggregate;
use crate::{
abort_unwinding_calls, add_call_guards, add_moves_for_packed_drops, deref_separator,
pass_manager as pm, remove_noop_landing_pads, simplify,
@@ -148,11 +147,11 @@ fn build_drop_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, ty: Option<Ty<'tcx>>)
assert!(!matches!(ty, Some(ty) if ty.is_generator()));
let substs = if let Some(ty) = ty {
- tcx.intern_substs(&[ty.into()])
+ tcx.mk_substs(&[ty.into()])
} else {
InternalSubsts::identity_for_item(tcx, def_id)
};
- let sig = tcx.bound_fn_sig(def_id).subst(tcx, substs);
+ let sig = tcx.fn_sig(def_id).subst(tcx, substs);
let sig = tcx.erase_late_bound_regions(sig);
let span = tcx.def_span(def_id);
@@ -363,7 +362,7 @@ impl<'tcx> CloneShimBuilder<'tcx> {
// we must subst the self_ty because it's
// otherwise going to be TySelf and we can't index
// or access fields of a Place of type TySelf.
- let sig = tcx.bound_fn_sig(def_id).subst(tcx, &[self_ty.into()]);
+ let sig = tcx.fn_sig(def_id).subst(tcx, &[self_ty.into()]);
let sig = tcx.erase_late_bound_regions(sig);
let span = tcx.def_span(def_id);
@@ -427,7 +426,7 @@ impl<'tcx> CloneShimBuilder<'tcx> {
fn make_place(&mut self, mutability: Mutability, ty: Ty<'tcx>) -> Place<'tcx> {
let span = self.span;
let mut local = LocalDecl::new(ty, span);
- if mutability == Mutability::Not {
+ if mutability.is_not() {
local = local.immutable();
}
Place::from(self.local_decls.push(local))
@@ -598,7 +597,7 @@ fn build_call_shim<'tcx>(
let untuple_args = sig.inputs();
// Create substitutions for the `Self` and `Args` generic parameters of the shim body.
- let arg_tup = tcx.mk_tup(untuple_args.iter());
+ let arg_tup = tcx.mk_tup(untuple_args);
(Some([ty.into(), arg_tup.into()]), Some(untuple_args))
} else {
@@ -606,7 +605,7 @@ fn build_call_shim<'tcx>(
};
let def_id = instance.def_id();
- let sig = tcx.bound_fn_sig(def_id);
+ let sig = tcx.fn_sig(def_id);
let sig = sig.map_bound(|sig| tcx.erase_late_bound_regions(sig));
assert_eq!(sig_substs.is_some(), !instance.has_polymorphic_mir_body());
@@ -633,7 +632,7 @@ fn build_call_shim<'tcx>(
Adjustment::Deref => tcx.mk_imm_ptr(fnty),
Adjustment::RefMut => tcx.mk_mut_ptr(fnty),
};
- sig.inputs_and_output = tcx.intern_type_list(&inputs_and_output);
+ sig.inputs_and_output = tcx.mk_type_list(&inputs_and_output);
}
// FIXME(eddyb) avoid having this snippet both here and in
@@ -644,7 +643,7 @@ fn build_call_shim<'tcx>(
let self_arg = &mut inputs_and_output[0];
debug_assert!(tcx.generics_of(def_id).has_self && *self_arg == tcx.types.self_param);
*self_arg = tcx.mk_mut_ptr(*self_arg);
- sig.inputs_and_output = tcx.intern_type_list(&inputs_and_output);
+ sig.inputs_and_output = tcx.mk_type_list(&inputs_and_output);
}
let span = tcx.def_span(def_id);
@@ -693,7 +692,7 @@ fn build_call_shim<'tcx>(
// `FnDef` call with optional receiver.
CallKind::Direct(def_id) => {
- let ty = tcx.type_of(def_id);
+ let ty = tcx.type_of(def_id).subst_identity();
(
Operand::Constant(Box::new(Constant {
span,
@@ -798,7 +797,11 @@ pub fn build_adt_ctor(tcx: TyCtxt<'_>, ctor_id: DefId) -> Body<'_> {
let param_env = tcx.param_env(ctor_id);
// Normalize the sig.
- let sig = tcx.fn_sig(ctor_id).no_bound_vars().expect("LBR in ADT constructor signature");
+ let sig = tcx
+ .fn_sig(ctor_id)
+ .subst_identity()
+ .no_bound_vars()
+ .expect("LBR in ADT constructor signature");
let sig = tcx.normalize_erasing_regions(param_env, sig);
let ty::Adt(adt_def, substs) = sig.output().kind() else {
@@ -827,19 +830,23 @@ pub fn build_adt_ctor(tcx: TyCtxt<'_>, ctor_id: DefId) -> Body<'_> {
// return;
debug!("build_ctor: variant_index={:?}", variant_index);
- let statements = expand_aggregate(
- Place::return_place(),
- adt_def.variant(variant_index).fields.iter().enumerate().map(|(idx, field_def)| {
- (Operand::Move(Place::from(Local::new(idx + 1))), field_def.ty(tcx, substs))
- }),
- AggregateKind::Adt(adt_def.did(), variant_index, substs, None, None),
+ let kind = AggregateKind::Adt(adt_def.did(), variant_index, substs, None, None);
+ let variant = adt_def.variant(variant_index);
+ let statement = Statement {
+ kind: StatementKind::Assign(Box::new((
+ Place::return_place(),
+ Rvalue::Aggregate(
+ Box::new(kind),
+ (0..variant.fields.len())
+ .map(|idx| Operand::Move(Place::from(Local::new(idx + 1))))
+ .collect(),
+ ),
+ ))),
source_info,
- tcx,
- )
- .collect();
+ };
let start_block = BasicBlockData {
- statements,
+ statements: vec![statement],
terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }),
is_cleanup: false,
};
diff --git a/compiler/rustc_mir_transform/src/simplify.rs b/compiler/rustc_mir_transform/src/simplify.rs
index 8f6abe7a9..9ef55c558 100644
--- a/compiler/rustc_mir_transform/src/simplify.rs
+++ b/compiler/rustc_mir_transform/src/simplify.rs
@@ -404,6 +404,18 @@ impl<'tcx> MirPass<'tcx> for SimplifyLocals {
}
}
+pub fn remove_unused_definitions<'tcx>(body: &mut Body<'tcx>) {
+ // First, we're going to get a count of *actual* uses for every `Local`.
+ let mut used_locals = UsedLocals::new(body);
+
+ // Next, we're going to remove any `Local` with zero actual uses. When we remove those
+ // `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals`
+ // count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from
+ // `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a
+ // fixedpoint where there are no more unused locals.
+ remove_unused_definitions_helper(&mut used_locals, body);
+}
+
pub fn simplify_locals<'tcx>(body: &mut Body<'tcx>, tcx: TyCtxt<'tcx>) {
// First, we're going to get a count of *actual* uses for every `Local`.
let mut used_locals = UsedLocals::new(body);
@@ -413,7 +425,7 @@ pub fn simplify_locals<'tcx>(body: &mut Body<'tcx>, tcx: TyCtxt<'tcx>) {
// count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from
// `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a
// fixedpoint where there are no more unused locals.
- remove_unused_definitions(&mut used_locals, body);
+ remove_unused_definitions_helper(&mut used_locals, body);
// Finally, we'll actually do the work of shrinking `body.local_decls` and remapping the `Local`s.
let map = make_local_map(&mut body.local_decls, &used_locals);
@@ -484,7 +496,7 @@ impl UsedLocals {
self.increment = false;
// The location of the statement is irrelevant.
- let location = Location { block: START_BLOCK, statement_index: 0 };
+ let location = Location::START;
self.visit_statement(statement, location);
}
@@ -517,7 +529,7 @@ impl<'tcx> Visitor<'tcx> for UsedLocals {
self.super_statement(statement, location);
}
- StatementKind::Nop => {}
+ StatementKind::ConstEvalCounter | StatementKind::Nop => {}
StatementKind::StorageLive(_local) | StatementKind::StorageDead(_local) => {}
@@ -548,7 +560,7 @@ impl<'tcx> Visitor<'tcx> for UsedLocals {
}
/// Removes unused definitions. Updates the used locals to reflect the changes made.
-fn remove_unused_definitions(used_locals: &mut UsedLocals, body: &mut Body<'_>) {
+fn remove_unused_definitions_helper(used_locals: &mut UsedLocals, body: &mut Body<'_>) {
// The use counts are updated as we remove the statements. A local might become unused
// during the retain operation, leading to a temporary inconsistency (storage statements or
// definitions referencing the local might remain). For correctness it is crucial that this
diff --git a/compiler/rustc_mir_transform/src/simplify_try.rs b/compiler/rustc_mir_transform/src/simplify_try.rs
deleted file mode 100644
index e4f3ace9a..000000000
--- a/compiler/rustc_mir_transform/src/simplify_try.rs
+++ /dev/null
@@ -1,822 +0,0 @@
-//! The general point of the optimizations provided here is to simplify something like:
-//!
-//! ```rust
-//! # fn foo<T, E>(x: Result<T, E>) -> Result<T, E> {
-//! match x {
-//! Ok(x) => Ok(x),
-//! Err(x) => Err(x)
-//! }
-//! # }
-//! ```
-//!
-//! into just `x`.
-
-use crate::{simplify, MirPass};
-use itertools::Itertools as _;
-use rustc_index::{bit_set::BitSet, vec::IndexVec};
-use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor};
-use rustc_middle::mir::*;
-use rustc_middle::ty::{self, List, Ty, TyCtxt};
-use rustc_target::abi::VariantIdx;
-use std::iter::{once, Enumerate, Peekable};
-use std::slice::Iter;
-
-/// Simplifies arms of form `Variant(x) => Variant(x)` to just a move.
-///
-/// This is done by transforming basic blocks where the statements match:
-///
-/// ```ignore (MIR)
-/// _LOCAL_TMP = ((_LOCAL_1 as Variant ).FIELD: TY );
-/// _TMP_2 = _LOCAL_TMP;
-/// ((_LOCAL_0 as Variant).FIELD: TY) = move _TMP_2;
-/// discriminant(_LOCAL_0) = VAR_IDX;
-/// ```
-///
-/// into:
-///
-/// ```ignore (MIR)
-/// _LOCAL_0 = move _LOCAL_1
-/// ```
-pub struct SimplifyArmIdentity;
-
-#[derive(Debug)]
-struct ArmIdentityInfo<'tcx> {
- /// Storage location for the variant's field
- local_temp_0: Local,
- /// Storage location holding the variant being read from
- local_1: Local,
- /// The variant field being read from
- vf_s0: VarField<'tcx>,
- /// Index of the statement which loads the variant being read
- get_variant_field_stmt: usize,
-
- /// Tracks each assignment to a temporary of the variant's field
- field_tmp_assignments: Vec<(Local, Local)>,
-
- /// Storage location holding the variant's field that was read from
- local_tmp_s1: Local,
- /// Storage location holding the enum that we are writing to
- local_0: Local,
- /// The variant field being written to
- vf_s1: VarField<'tcx>,
-
- /// Storage location that the discriminant is being written to
- set_discr_local: Local,
- /// The variant being written
- set_discr_var_idx: VariantIdx,
-
- /// Index of the statement that should be overwritten as a move
- stmt_to_overwrite: usize,
- /// SourceInfo for the new move
- source_info: SourceInfo,
-
- /// Indices of matching Storage{Live,Dead} statements encountered.
- /// (StorageLive index,, StorageDead index, Local)
- storage_stmts: Vec<(usize, usize, Local)>,
-
- /// The statements that should be removed (turned into nops)
- stmts_to_remove: Vec<usize>,
-
- /// Indices of debug variables that need to be adjusted to point to
- // `{local_0}.{dbg_projection}`.
- dbg_info_to_adjust: Vec<usize>,
-
- /// The projection used to rewrite debug info.
- dbg_projection: &'tcx List<PlaceElem<'tcx>>,
-}
-
-fn get_arm_identity_info<'a, 'tcx>(
- stmts: &'a [Statement<'tcx>],
- locals_count: usize,
- debug_info: &'a [VarDebugInfo<'tcx>],
-) -> Option<ArmIdentityInfo<'tcx>> {
- // This can't possibly match unless there are at least 3 statements in the block
- // so fail fast on tiny blocks.
- if stmts.len() < 3 {
- return None;
- }
-
- let mut tmp_assigns = Vec::new();
- let mut nop_stmts = Vec::new();
- let mut storage_stmts = Vec::new();
- let mut storage_live_stmts = Vec::new();
- let mut storage_dead_stmts = Vec::new();
-
- type StmtIter<'a, 'tcx> = Peekable<Enumerate<Iter<'a, Statement<'tcx>>>>;
-
- fn is_storage_stmt(stmt: &Statement<'_>) -> bool {
- matches!(stmt.kind, StatementKind::StorageLive(_) | StatementKind::StorageDead(_))
- }
-
- /// Eats consecutive Statements which match `test`, performing the specified `action` for each.
- /// The iterator `stmt_iter` is not advanced if none were matched.
- fn try_eat<'a, 'tcx>(
- stmt_iter: &mut StmtIter<'a, 'tcx>,
- test: impl Fn(&'a Statement<'tcx>) -> bool,
- mut action: impl FnMut(usize, &'a Statement<'tcx>),
- ) {
- while stmt_iter.peek().map_or(false, |(_, stmt)| test(stmt)) {
- let (idx, stmt) = stmt_iter.next().unwrap();
-
- action(idx, stmt);
- }
- }
-
- /// Eats consecutive `StorageLive` and `StorageDead` Statements.
- /// The iterator `stmt_iter` is not advanced if none were found.
- fn try_eat_storage_stmts(
- stmt_iter: &mut StmtIter<'_, '_>,
- storage_live_stmts: &mut Vec<(usize, Local)>,
- storage_dead_stmts: &mut Vec<(usize, Local)>,
- ) {
- try_eat(stmt_iter, is_storage_stmt, |idx, stmt| {
- if let StatementKind::StorageLive(l) = stmt.kind {
- storage_live_stmts.push((idx, l));
- } else if let StatementKind::StorageDead(l) = stmt.kind {
- storage_dead_stmts.push((idx, l));
- }
- })
- }
-
- fn is_tmp_storage_stmt(stmt: &Statement<'_>) -> bool {
- use rustc_middle::mir::StatementKind::Assign;
- if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = &stmt.kind {
- place.as_local().is_some() && p.as_local().is_some()
- } else {
- false
- }
- }
-
- /// Eats consecutive `Assign` Statements.
- // The iterator `stmt_iter` is not advanced if none were found.
- fn try_eat_assign_tmp_stmts(
- stmt_iter: &mut StmtIter<'_, '_>,
- tmp_assigns: &mut Vec<(Local, Local)>,
- nop_stmts: &mut Vec<usize>,
- ) {
- try_eat(stmt_iter, is_tmp_storage_stmt, |idx, stmt| {
- use rustc_middle::mir::StatementKind::Assign;
- if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) =
- &stmt.kind
- {
- tmp_assigns.push((place.as_local().unwrap(), p.as_local().unwrap()));
- nop_stmts.push(idx);
- }
- })
- }
-
- fn find_storage_live_dead_stmts_for_local(
- local: Local,
- stmts: &[Statement<'_>],
- ) -> Option<(usize, usize)> {
- trace!("looking for {:?}", local);
- let mut storage_live_stmt = None;
- let mut storage_dead_stmt = None;
- for (idx, stmt) in stmts.iter().enumerate() {
- if stmt.kind == StatementKind::StorageLive(local) {
- storage_live_stmt = Some(idx);
- } else if stmt.kind == StatementKind::StorageDead(local) {
- storage_dead_stmt = Some(idx);
- }
- }
-
- Some((storage_live_stmt?, storage_dead_stmt.unwrap_or(usize::MAX)))
- }
-
- // Try to match the expected MIR structure with the basic block we're processing.
- // We want to see something that looks like:
- // ```
- // (StorageLive(_) | StorageDead(_));*
- // _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
- // (StorageLive(_) | StorageDead(_));*
- // (tmp_n+1 = tmp_n);*
- // (StorageLive(_) | StorageDead(_));*
- // (tmp_n+1 = tmp_n);*
- // ((LOCAL_FROM as Variant).FIELD: TY) = move tmp;
- // discriminant(LOCAL_FROM) = VariantIdx;
- // (StorageLive(_) | StorageDead(_));*
- // ```
- let mut stmt_iter = stmts.iter().enumerate().peekable();
-
- try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
-
- let (get_variant_field_stmt, stmt) = stmt_iter.next()?;
- let (local_tmp_s0, local_1, vf_s0, dbg_projection) = match_get_variant_field(stmt)?;
-
- try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
-
- try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts);
-
- try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
-
- try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts);
-
- let (idx, stmt) = stmt_iter.next()?;
- let (local_tmp_s1, local_0, vf_s1) = match_set_variant_field(stmt)?;
- nop_stmts.push(idx);
-
- let (idx, stmt) = stmt_iter.next()?;
- let (set_discr_local, set_discr_var_idx) = match_set_discr(stmt)?;
- let discr_stmt_source_info = stmt.source_info;
- nop_stmts.push(idx);
-
- try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
-
- for (live_idx, live_local) in storage_live_stmts {
- if let Some(i) = storage_dead_stmts.iter().rposition(|(_, l)| *l == live_local) {
- let (dead_idx, _) = storage_dead_stmts.swap_remove(i);
- storage_stmts.push((live_idx, dead_idx, live_local));
-
- if live_local == local_tmp_s0 {
- nop_stmts.push(get_variant_field_stmt);
- }
- }
- }
- // We sort primitive usize here so we can use unstable sort
- nop_stmts.sort_unstable();
-
- // Use one of the statements we're going to discard between the point
- // where the storage location for the variant field becomes live and
- // is killed.
- let (live_idx, dead_idx) = find_storage_live_dead_stmts_for_local(local_tmp_s0, stmts)?;
- let stmt_to_overwrite =
- nop_stmts.iter().find(|stmt_idx| live_idx < **stmt_idx && **stmt_idx < dead_idx);
-
- let mut tmp_assigned_vars = BitSet::new_empty(locals_count);
- for (l, r) in &tmp_assigns {
- tmp_assigned_vars.insert(*l);
- tmp_assigned_vars.insert(*r);
- }
-
- let dbg_info_to_adjust: Vec<_> = debug_info
- .iter()
- .enumerate()
- .filter_map(|(i, var_info)| {
- if let VarDebugInfoContents::Place(p) = var_info.value {
- if tmp_assigned_vars.contains(p.local) {
- return Some(i);
- }
- }
-
- None
- })
- .collect();
-
- Some(ArmIdentityInfo {
- local_temp_0: local_tmp_s0,
- local_1,
- vf_s0,
- get_variant_field_stmt,
- field_tmp_assignments: tmp_assigns,
- local_tmp_s1,
- local_0,
- vf_s1,
- set_discr_local,
- set_discr_var_idx,
- stmt_to_overwrite: *stmt_to_overwrite?,
- source_info: discr_stmt_source_info,
- storage_stmts,
- stmts_to_remove: nop_stmts,
- dbg_info_to_adjust,
- dbg_projection,
- })
-}
-
-fn optimization_applies<'tcx>(
- opt_info: &ArmIdentityInfo<'tcx>,
- local_decls: &IndexVec<Local, LocalDecl<'tcx>>,
- local_uses: &IndexVec<Local, usize>,
- var_debug_info: &[VarDebugInfo<'tcx>],
-) -> bool {
- trace!("testing if optimization applies...");
-
- // FIXME(wesleywiser): possibly relax this restriction?
- if opt_info.local_0 == opt_info.local_1 {
- trace!("NO: moving into ourselves");
- return false;
- } else if opt_info.vf_s0 != opt_info.vf_s1 {
- trace!("NO: the field-and-variant information do not match");
- return false;
- } else if local_decls[opt_info.local_0].ty != local_decls[opt_info.local_1].ty {
- // FIXME(Centril,oli-obk): possibly relax to same layout?
- trace!("NO: source and target locals have different types");
- return false;
- } else if (opt_info.local_0, opt_info.vf_s0.var_idx)
- != (opt_info.set_discr_local, opt_info.set_discr_var_idx)
- {
- trace!("NO: the discriminants do not match");
- return false;
- }
-
- // Verify the assignment chain consists of the form b = a; c = b; d = c; etc...
- if opt_info.field_tmp_assignments.is_empty() {
- trace!("NO: no assignments found");
- return false;
- }
- let mut last_assigned_to = opt_info.field_tmp_assignments[0].1;
- let source_local = last_assigned_to;
- for (l, r) in &opt_info.field_tmp_assignments {
- if *r != last_assigned_to {
- trace!("NO: found unexpected assignment {:?} = {:?}", l, r);
- return false;
- }
-
- last_assigned_to = *l;
- }
-
- // Check that the first and last used locals are only used twice
- // since they are of the form:
- //
- // ```
- // _first = ((_x as Variant).n: ty);
- // _n = _first;
- // ...
- // ((_y as Variant).n: ty) = _n;
- // discriminant(_y) = z;
- // ```
- for (l, r) in &opt_info.field_tmp_assignments {
- if local_uses[*l] != 2 {
- warn!("NO: FAILED assignment chain local {:?} was used more than twice", l);
- return false;
- } else if local_uses[*r] != 2 {
- warn!("NO: FAILED assignment chain local {:?} was used more than twice", r);
- return false;
- }
- }
-
- // Check that debug info only points to full Locals and not projections.
- for dbg_idx in &opt_info.dbg_info_to_adjust {
- let dbg_info = &var_debug_info[*dbg_idx];
- if let VarDebugInfoContents::Place(p) = dbg_info.value {
- if !p.projection.is_empty() {
- trace!("NO: debug info for {:?} had a projection {:?}", dbg_info.name, p);
- return false;
- }
- }
- }
-
- if source_local != opt_info.local_temp_0 {
- trace!(
- "NO: start of assignment chain does not match enum variant temp: {:?} != {:?}",
- source_local,
- opt_info.local_temp_0
- );
- return false;
- } else if last_assigned_to != opt_info.local_tmp_s1 {
- trace!(
- "NO: end of assignment chain does not match written enum temp: {:?} != {:?}",
- last_assigned_to,
- opt_info.local_tmp_s1
- );
- return false;
- }
-
- trace!("SUCCESS: optimization applies!");
- true
-}
-
-impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity {
- fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
- // FIXME(77359): This optimization can result in unsoundness.
- if !tcx.sess.opts.unstable_opts.unsound_mir_opts {
- return;
- }
-
- let source = body.source;
- trace!("running SimplifyArmIdentity on {:?}", source);
-
- let local_uses = LocalUseCounter::get_local_uses(body);
- for bb in body.basic_blocks.as_mut() {
- if let Some(opt_info) =
- get_arm_identity_info(&bb.statements, body.local_decls.len(), &body.var_debug_info)
- {
- trace!("got opt_info = {:#?}", opt_info);
- if !optimization_applies(
- &opt_info,
- &body.local_decls,
- &local_uses,
- &body.var_debug_info,
- ) {
- debug!("optimization skipped for {:?}", source);
- continue;
- }
-
- // Also remove unused Storage{Live,Dead} statements which correspond
- // to temps used previously.
- for (live_idx, dead_idx, local) in &opt_info.storage_stmts {
- // The temporary that we've read the variant field into is scoped to this block,
- // so we can remove the assignment.
- if *local == opt_info.local_temp_0 {
- bb.statements[opt_info.get_variant_field_stmt].make_nop();
- }
-
- for (left, right) in &opt_info.field_tmp_assignments {
- if local == left || local == right {
- bb.statements[*live_idx].make_nop();
- bb.statements[*dead_idx].make_nop();
- }
- }
- }
-
- // Right shape; transform
- for stmt_idx in opt_info.stmts_to_remove {
- bb.statements[stmt_idx].make_nop();
- }
-
- let stmt = &mut bb.statements[opt_info.stmt_to_overwrite];
- stmt.source_info = opt_info.source_info;
- stmt.kind = StatementKind::Assign(Box::new((
- opt_info.local_0.into(),
- Rvalue::Use(Operand::Move(opt_info.local_1.into())),
- )));
-
- bb.statements.retain(|stmt| stmt.kind != StatementKind::Nop);
-
- // Fix the debug info to point to the right local
- for dbg_index in opt_info.dbg_info_to_adjust {
- let dbg_info = &mut body.var_debug_info[dbg_index];
- assert!(
- matches!(dbg_info.value, VarDebugInfoContents::Place(_)),
- "value was not a Place"
- );
- if let VarDebugInfoContents::Place(p) = &mut dbg_info.value {
- assert!(p.projection.is_empty());
- p.local = opt_info.local_0;
- p.projection = opt_info.dbg_projection;
- }
- }
-
- trace!("block is now {:?}", bb.statements);
- }
- }
- }
-}
-
-struct LocalUseCounter {
- local_uses: IndexVec<Local, usize>,
-}
-
-impl LocalUseCounter {
- fn get_local_uses(body: &Body<'_>) -> IndexVec<Local, usize> {
- let mut counter = LocalUseCounter { local_uses: IndexVec::from_elem(0, &body.local_decls) };
- counter.visit_body(body);
- counter.local_uses
- }
-}
-
-impl Visitor<'_> for LocalUseCounter {
- fn visit_local(&mut self, local: Local, context: PlaceContext, _location: Location) {
- if context.is_storage_marker()
- || context == PlaceContext::NonUse(NonUseContext::VarDebugInfo)
- {
- return;
- }
-
- self.local_uses[local] += 1;
- }
-}
-
-/// Match on:
-/// ```ignore (MIR)
-/// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
-/// ```
-fn match_get_variant_field<'tcx>(
- stmt: &Statement<'tcx>,
-) -> Option<(Local, Local, VarField<'tcx>, &'tcx List<PlaceElem<'tcx>>)> {
- match &stmt.kind {
- StatementKind::Assign(box (
- place_into,
- Rvalue::Use(Operand::Copy(pf) | Operand::Move(pf)),
- )) => {
- let local_into = place_into.as_local()?;
- let (local_from, vf) = match_variant_field_place(*pf)?;
- Some((local_into, local_from, vf, pf.projection))
- }
- _ => None,
- }
-}
-
-/// Match on:
-/// ```ignore (MIR)
-/// ((_LOCAL_FROM as Variant).FIELD: TY) = move _LOCAL_INTO;
-/// ```
-fn match_set_variant_field<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, Local, VarField<'tcx>)> {
- match &stmt.kind {
- StatementKind::Assign(box (place_from, Rvalue::Use(Operand::Move(place_into)))) => {
- let local_into = place_into.as_local()?;
- let (local_from, vf) = match_variant_field_place(*place_from)?;
- Some((local_into, local_from, vf))
- }
- _ => None,
- }
-}
-
-/// Match on:
-/// ```ignore (MIR)
-/// discriminant(_LOCAL_TO_SET) = VAR_IDX;
-/// ```
-fn match_set_discr(stmt: &Statement<'_>) -> Option<(Local, VariantIdx)> {
- match &stmt.kind {
- StatementKind::SetDiscriminant { place, variant_index } => {
- Some((place.as_local()?, *variant_index))
- }
- _ => None,
- }
-}
-
-#[derive(PartialEq, Debug)]
-struct VarField<'tcx> {
- field: Field,
- field_ty: Ty<'tcx>,
- var_idx: VariantIdx,
-}
-
-/// Match on `((_LOCAL as Variant).FIELD: TY)`.
-fn match_variant_field_place(place: Place<'_>) -> Option<(Local, VarField<'_>)> {
- match place.as_ref() {
- PlaceRef {
- local,
- projection: &[ProjectionElem::Downcast(_, var_idx), ProjectionElem::Field(field, ty)],
- } => Some((local, VarField { field, field_ty: ty, var_idx })),
- _ => None,
- }
-}
-
-/// Simplifies `SwitchInt(_) -> [targets]`,
-/// where all the `targets` have the same form,
-/// into `goto -> target_first`.
-pub struct SimplifyBranchSame;
-
-impl<'tcx> MirPass<'tcx> for SimplifyBranchSame {
- fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
- // This optimization is disabled by default for now due to
- // soundness concerns; see issue #89485 and PR #89489.
- if !tcx.sess.opts.unstable_opts.unsound_mir_opts {
- return;
- }
-
- trace!("Running SimplifyBranchSame on {:?}", body.source);
- let finder = SimplifyBranchSameOptimizationFinder { body, tcx };
- let opts = finder.find();
-
- let did_remove_blocks = opts.len() > 0;
- for opt in opts.iter() {
- trace!("SUCCESS: Applying optimization {:?}", opt);
- // Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`.
- body.basic_blocks_mut()[opt.bb_to_opt_terminator].terminator_mut().kind =
- TerminatorKind::Goto { target: opt.bb_to_goto };
- }
-
- if did_remove_blocks {
- // We have dead blocks now, so remove those.
- simplify::remove_dead_blocks(tcx, body);
- }
- }
-}
-
-#[derive(Debug)]
-struct SimplifyBranchSameOptimization {
- /// All basic blocks are equal so go to this one
- bb_to_goto: BasicBlock,
- /// Basic block where the terminator can be simplified to a goto
- bb_to_opt_terminator: BasicBlock,
-}
-
-struct SwitchTargetAndValue {
- target: BasicBlock,
- // None in case of the `otherwise` case
- value: Option<u128>,
-}
-
-struct SimplifyBranchSameOptimizationFinder<'a, 'tcx> {
- body: &'a Body<'tcx>,
- tcx: TyCtxt<'tcx>,
-}
-
-impl<'tcx> SimplifyBranchSameOptimizationFinder<'_, 'tcx> {
- fn find(&self) -> Vec<SimplifyBranchSameOptimization> {
- self.body
- .basic_blocks
- .iter_enumerated()
- .filter_map(|(bb_idx, bb)| {
- let (discr_switched_on, targets_and_values) = match &bb.terminator().kind {
- TerminatorKind::SwitchInt { targets, discr, .. } => {
- let targets_and_values: Vec<_> = targets.iter()
- .map(|(val, target)| SwitchTargetAndValue { target, value: Some(val) })
- .chain(once(SwitchTargetAndValue { target: targets.otherwise(), value: None }))
- .collect();
- (discr, targets_and_values)
- },
- _ => return None,
- };
-
- // find the adt that has its discriminant read
- // assuming this must be the last statement of the block
- let adt_matched_on = match &bb.statements.last()?.kind {
- StatementKind::Assign(box (place, rhs))
- if Some(*place) == discr_switched_on.place() =>
- {
- match rhs {
- Rvalue::Discriminant(adt_place) if adt_place.ty(self.body, self.tcx).ty.is_enum() => adt_place,
- _ => {
- trace!("NO: expected a discriminant read of an enum instead of: {:?}", rhs);
- return None;
- }
- }
- }
- other => {
- trace!("NO: expected an assignment of a discriminant read to a place. Found: {:?}", other);
- return None
- },
- };
-
- let mut iter_bbs_reachable = targets_and_values
- .iter()
- .map(|target_and_value| (target_and_value, &self.body.basic_blocks[target_and_value.target]))
- .filter(|(_, bb)| {
- // Reaching `unreachable` is UB so assume it doesn't happen.
- bb.terminator().kind != TerminatorKind::Unreachable
- })
- .peekable();
-
- let bb_first = iter_bbs_reachable.peek().map_or(&targets_and_values[0], |(idx, _)| *idx);
- let mut all_successors_equivalent = StatementEquality::TrivialEqual;
-
- // All successor basic blocks must be equal or contain statements that are pairwise considered equal.
- for ((target_and_value_l,bb_l), (target_and_value_r,bb_r)) in iter_bbs_reachable.tuple_windows() {
- let trivial_checks = bb_l.is_cleanup == bb_r.is_cleanup
- && bb_l.terminator().kind == bb_r.terminator().kind
- && bb_l.statements.len() == bb_r.statements.len();
- let statement_check = || {
- bb_l.statements.iter().zip(&bb_r.statements).try_fold(StatementEquality::TrivialEqual, |acc,(l,r)| {
- let stmt_equality = self.statement_equality(*adt_matched_on, &l, target_and_value_l, &r, target_and_value_r);
- if matches!(stmt_equality, StatementEquality::NotEqual) {
- // short circuit
- None
- } else {
- Some(acc.combine(&stmt_equality))
- }
- })
- .unwrap_or(StatementEquality::NotEqual)
- };
- if !trivial_checks {
- all_successors_equivalent = StatementEquality::NotEqual;
- break;
- }
- all_successors_equivalent = all_successors_equivalent.combine(&statement_check());
- };
-
- match all_successors_equivalent{
- StatementEquality::TrivialEqual => {
- // statements are trivially equal, so just take first
- trace!("Statements are trivially equal");
- Some(SimplifyBranchSameOptimization {
- bb_to_goto: bb_first.target,
- bb_to_opt_terminator: bb_idx,
- })
- }
- StatementEquality::ConsideredEqual(bb_to_choose) => {
- trace!("Statements are considered equal");
- Some(SimplifyBranchSameOptimization {
- bb_to_goto: bb_to_choose,
- bb_to_opt_terminator: bb_idx,
- })
- }
- StatementEquality::NotEqual => {
- trace!("NO: not all successors of basic block {:?} were equivalent", bb_idx);
- None
- }
- }
- })
- .collect()
- }
-
- /// Tests if two statements can be considered equal
- ///
- /// Statements can be trivially equal if the kinds match.
- /// But they can also be considered equal in the following case A:
- /// ```ignore (MIR)
- /// discriminant(_0) = 0; // bb1
- /// _0 = move _1; // bb2
- /// ```
- /// In this case the two statements are equal iff
- /// - `_0` is an enum where the variant index 0 is fieldless, and
- /// - bb1 was targeted by a switch where the discriminant of `_1` was switched on
- fn statement_equality(
- &self,
- adt_matched_on: Place<'tcx>,
- x: &Statement<'tcx>,
- x_target_and_value: &SwitchTargetAndValue,
- y: &Statement<'tcx>,
- y_target_and_value: &SwitchTargetAndValue,
- ) -> StatementEquality {
- let helper = |rhs: &Rvalue<'tcx>,
- place: &Place<'tcx>,
- variant_index: VariantIdx,
- switch_value: u128,
- side_to_choose| {
- let place_type = place.ty(self.body, self.tcx).ty;
- let adt = match *place_type.kind() {
- ty::Adt(adt, _) if adt.is_enum() => adt,
- _ => return StatementEquality::NotEqual,
- };
- // We need to make sure that the switch value that targets the bb with
- // SetDiscriminant is the same as the variant discriminant.
- let variant_discr = adt.discriminant_for_variant(self.tcx, variant_index).val;
- if variant_discr != switch_value {
- trace!(
- "NO: variant discriminant {} does not equal switch value {}",
- variant_discr,
- switch_value
- );
- return StatementEquality::NotEqual;
- }
- let variant_is_fieldless = adt.variant(variant_index).fields.is_empty();
- if !variant_is_fieldless {
- trace!("NO: variant {:?} was not fieldless", variant_index);
- return StatementEquality::NotEqual;
- }
-
- match rhs {
- Rvalue::Use(operand) if operand.place() == Some(adt_matched_on) => {
- StatementEquality::ConsideredEqual(side_to_choose)
- }
- _ => {
- trace!(
- "NO: RHS of assignment was {:?}, but expected it to match the adt being matched on in the switch, which is {:?}",
- rhs,
- adt_matched_on
- );
- StatementEquality::NotEqual
- }
- }
- };
- match (&x.kind, &y.kind) {
- // trivial case
- (x, y) if x == y => StatementEquality::TrivialEqual,
-
- // check for case A
- (
- StatementKind::Assign(box (_, rhs)),
- &StatementKind::SetDiscriminant { ref place, variant_index },
- ) if y_target_and_value.value.is_some() => {
- // choose basic block of x, as that has the assign
- helper(
- rhs,
- place,
- variant_index,
- y_target_and_value.value.unwrap(),
- x_target_and_value.target,
- )
- }
- (
- &StatementKind::SetDiscriminant { ref place, variant_index },
- &StatementKind::Assign(box (_, ref rhs)),
- ) if x_target_and_value.value.is_some() => {
- // choose basic block of y, as that has the assign
- helper(
- rhs,
- place,
- variant_index,
- x_target_and_value.value.unwrap(),
- y_target_and_value.target,
- )
- }
- _ => {
- trace!("NO: statements `{:?}` and `{:?}` not considered equal", x, y);
- StatementEquality::NotEqual
- }
- }
- }
-}
-
-#[derive(Copy, Clone, Eq, PartialEq)]
-enum StatementEquality {
- /// The two statements are trivially equal; same kind
- TrivialEqual,
- /// The two statements are considered equal, but may be of different kinds. The BasicBlock field is the basic block to jump to when performing the branch-same optimization.
- /// For example, `_0 = _1` and `discriminant(_0) = discriminant(0)` are considered equal if 0 is a fieldless variant of an enum. But we don't want to jump to the basic block with the SetDiscriminant, as that is not legal if _1 is not the 0 variant index
- ConsideredEqual(BasicBlock),
- /// The two statements are not equal
- NotEqual,
-}
-
-impl StatementEquality {
- fn combine(&self, other: &StatementEquality) -> StatementEquality {
- use StatementEquality::*;
- match (self, other) {
- (TrivialEqual, TrivialEqual) => TrivialEqual,
- (TrivialEqual, ConsideredEqual(b)) | (ConsideredEqual(b), TrivialEqual) => {
- ConsideredEqual(*b)
- }
- (ConsideredEqual(b1), ConsideredEqual(b2)) => {
- if b1 == b2 {
- ConsideredEqual(*b1)
- } else {
- NotEqual
- }
- }
- (_, NotEqual) | (NotEqual, _) => NotEqual,
- }
- }
-}
diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs
index 42124f5a4..13168e9a2 100644
--- a/compiler/rustc_mir_transform/src/sroa.rs
+++ b/compiler/rustc_mir_transform/src/sroa.rs
@@ -1,10 +1,11 @@
use crate::MirPass;
-use rustc_data_structures::fx::{FxIndexMap, IndexEntry};
-use rustc_index::bit_set::BitSet;
+use rustc_index::bit_set::{BitSet, GrowableBitSet};
use rustc_index::vec::IndexVec;
+use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
-use rustc_middle::ty::TyCtxt;
+use rustc_middle::ty::{Ty, TyCtxt};
+use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
pub struct ScalarReplacementOfAggregates;
@@ -13,27 +14,43 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
sess.mir_opt_level() >= 3
}
+ #[instrument(level = "debug", skip(self, tcx, body))]
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
- let escaping = escaping_locals(&*body);
- debug!(?escaping);
- let replacements = compute_flattening(tcx, body, escaping);
- debug!(?replacements);
- replace_flattened_locals(tcx, body, replacements);
+ debug!(def_id = ?body.source.def_id());
+ let mut excluded = excluded_locals(body);
+ loop {
+ debug!(?excluded);
+ let escaping = escaping_locals(&excluded, body);
+ debug!(?escaping);
+ let replacements = compute_flattening(tcx, body, escaping);
+ debug!(?replacements);
+ let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
+ if !all_dead_locals.is_empty() {
+ excluded.union(&all_dead_locals);
+ excluded = {
+ let mut growable = GrowableBitSet::from(excluded);
+ growable.ensure(body.local_decls.len());
+ growable.into()
+ };
+ } else {
+ break;
+ }
+ }
}
}
/// Identify all locals that are not eligible for SROA.
///
/// There are 3 cases:
-/// - the aggegated local is used or passed to other code (function parameters and arguments);
+/// - the aggregated local is used or passed to other code (function parameters and arguments);
/// - the locals is a union or an enum;
/// - the local's address is taken, and thus the relative addresses of the fields are observable to
/// client code.
-fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
+fn escaping_locals(excluded: &BitSet<Local>, body: &Body<'_>) -> BitSet<Local> {
let mut set = BitSet::new_empty(body.local_decls.len());
set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
for (local, decl) in body.local_decls().iter_enumerated() {
- if decl.ty.is_union() || decl.ty.is_enum() {
+ if decl.ty.is_union() || decl.ty.is_enum() || excluded.contains(local) {
set.insert(local);
}
}
@@ -58,41 +75,33 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
self.super_place(place, context, location);
}
- fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
- if let Rvalue::AddressOf(.., place) | Rvalue::Ref(.., place) = rvalue {
- if !place.is_indirect() {
- // Raw pointers may be used to access anything inside the enclosing place.
- self.set.insert(place.local);
- return;
+ fn visit_assign(
+ &mut self,
+ lvalue: &Place<'tcx>,
+ rvalue: &Rvalue<'tcx>,
+ location: Location,
+ ) {
+ if lvalue.as_local().is_some() {
+ match rvalue {
+ // Aggregate assignments are expanded in run_pass.
+ Rvalue::Aggregate(..) | Rvalue::Use(..) => {
+ self.visit_rvalue(rvalue, location);
+ return;
+ }
+ _ => {}
}
}
- self.super_rvalue(rvalue, location)
+ self.super_assign(lvalue, rvalue, location)
}
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
- if let StatementKind::StorageLive(..)
- | StatementKind::StorageDead(..)
- | StatementKind::Deinit(..) = statement.kind
- {
+ match statement.kind {
// Storage statements are expanded in run_pass.
- return;
+ StatementKind::StorageLive(..)
+ | StatementKind::StorageDead(..)
+ | StatementKind::Deinit(..) => return,
+ _ => self.super_statement(statement, location),
}
- self.super_statement(statement, location)
- }
-
- fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
- // Drop implicitly calls `drop_in_place`, which takes a `&mut`.
- // This implies that `Drop` implicitly takes the address of the place.
- if let TerminatorKind::Drop { place, .. }
- | TerminatorKind::DropAndReplace { place, .. } = terminator.kind
- {
- if !place.is_indirect() {
- // Raw pointers may be used to access anything inside the enclosing place.
- self.set.insert(place.local);
- return;
- }
- }
- self.super_terminator(terminator, location);
}
// We ignore anything that happens in debuginfo, since we expand it using
@@ -103,7 +112,30 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
#[derive(Default, Debug)]
struct ReplacementMap<'tcx> {
- fields: FxIndexMap<PlaceRef<'tcx>, Local>,
+ /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
+ /// and deinit statement and debuginfo.
+ fragments: IndexVec<Local, Option<IndexVec<Field, Option<(Ty<'tcx>, Local)>>>>,
+}
+
+impl<'tcx> ReplacementMap<'tcx> {
+ fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
+ let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else { return None; };
+ let fields = self.fragments[place.local].as_ref()?;
+ let (_, new_local) = fields[f]?;
+ Some(Place { local: new_local, projection: tcx.mk_place_elems(&rest) })
+ }
+
+ fn place_fragments(
+ &self,
+ place: Place<'tcx>,
+ ) -> Option<impl Iterator<Item = (Field, Ty<'tcx>, Local)> + '_> {
+ let local = place.as_local()?;
+ let fields = self.fragments[local].as_ref()?;
+ Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| {
+ let (ty, local) = opt_ty_local?;
+ Some((field, ty, local))
+ }))
+ }
}
/// Compute the replacement of flattened places into locals.
@@ -115,53 +147,25 @@ fn compute_flattening<'tcx>(
body: &mut Body<'tcx>,
escaping: BitSet<Local>,
) -> ReplacementMap<'tcx> {
- let mut visitor = PreFlattenVisitor {
- tcx,
- escaping,
- local_decls: &mut body.local_decls,
- map: Default::default(),
- };
- for (block, bbdata) in body.basic_blocks.iter_enumerated() {
- visitor.visit_basic_block_data(block, bbdata);
- }
- return visitor.map;
-
- struct PreFlattenVisitor<'tcx, 'll> {
- tcx: TyCtxt<'tcx>,
- local_decls: &'ll mut LocalDecls<'tcx>,
- escaping: BitSet<Local>,
- map: ReplacementMap<'tcx>,
- }
-
- impl<'tcx, 'll> PreFlattenVisitor<'tcx, 'll> {
- fn create_place(&mut self, place: PlaceRef<'tcx>) {
- if self.escaping.contains(place.local) {
- return;
- }
+ let mut fragments = IndexVec::from_elem(None, &body.local_decls);
- match self.map.fields.entry(place) {
- IndexEntry::Occupied(_) => {}
- IndexEntry::Vacant(v) => {
- let ty = place.ty(&*self.local_decls, self.tcx).ty;
- let local = self.local_decls.push(LocalDecl {
- ty,
- user_ty: None,
- ..self.local_decls[place.local].clone()
- });
- v.insert(local);
- }
- }
- }
- }
-
- impl<'tcx, 'll> Visitor<'tcx> for PreFlattenVisitor<'tcx, 'll> {
- fn visit_place(&mut self, place: &Place<'tcx>, _: PlaceContext, _: Location) {
- if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
- let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
- self.create_place(pr)
- }
+ for local in body.local_decls.indices() {
+ if escaping.contains(local) {
+ continue;
}
+ let decl = body.local_decls[local].clone();
+ let ty = decl.ty;
+ iter_fields(ty, tcx, |variant, field, field_ty| {
+ if variant.is_some() {
+ // Downcasts are currently not supported.
+ return;
+ };
+ let new_local =
+ body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() });
+ fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local));
+ });
}
+ ReplacementMap { fragments }
}
/// Perform the replacement computed by `compute_flattening`.
@@ -169,29 +173,24 @@ fn replace_flattened_locals<'tcx>(
tcx: TyCtxt<'tcx>,
body: &mut Body<'tcx>,
replacements: ReplacementMap<'tcx>,
-) {
- let mut all_dead_locals = BitSet::new_empty(body.local_decls.len());
- for p in replacements.fields.keys() {
- all_dead_locals.insert(p.local);
+) -> BitSet<Local> {
+ let mut all_dead_locals = BitSet::new_empty(replacements.fragments.len());
+ for (local, replacements) in replacements.fragments.iter_enumerated() {
+ if replacements.is_some() {
+ all_dead_locals.insert(local);
+ }
}
debug!(?all_dead_locals);
if all_dead_locals.is_empty() {
- return;
+ return all_dead_locals;
}
- let mut fragments = IndexVec::new();
- for (k, v) in &replacements.fields {
- fragments.ensure_contains_elem(k.local, || Vec::new());
- fragments[k.local].push((k.projection, *v));
- }
- debug!(?fragments);
-
let mut visitor = ReplacementVisitor {
tcx,
local_decls: &body.local_decls,
- replacements,
+ replacements: &replacements,
all_dead_locals,
- fragments,
+ patch: MirPatch::new(body),
};
for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
visitor.visit_basic_block_data(bb, data);
@@ -205,6 +204,9 @@ fn replace_flattened_locals<'tcx>(
for var_debug_info in &mut body.var_debug_info {
visitor.visit_var_debug_info(var_debug_info);
}
+ let ReplacementVisitor { patch, all_dead_locals, .. } = visitor;
+ patch.apply(body);
+ all_dead_locals
}
struct ReplacementVisitor<'tcx, 'll> {
@@ -212,40 +214,23 @@ struct ReplacementVisitor<'tcx, 'll> {
/// This is only used to compute the type for `VarDebugInfoContents::Composite`.
local_decls: &'ll LocalDecls<'tcx>,
/// Work to do.
- replacements: ReplacementMap<'tcx>,
+ replacements: &'ll ReplacementMap<'tcx>,
/// This is used to check that we are not leaving references to replaced locals behind.
all_dead_locals: BitSet<Local>,
- /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
- /// and deinit statement and debuginfo.
- fragments: IndexVec<Local, Vec<(&'tcx [PlaceElem<'tcx>], Local)>>,
+ patch: MirPatch<'tcx>,
}
-impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> {
- fn gather_debug_info_fragments(
- &self,
- place: PlaceRef<'tcx>,
- ) -> Vec<VarDebugInfoFragment<'tcx>> {
+impl<'tcx> ReplacementVisitor<'tcx, '_> {
+ fn gather_debug_info_fragments(&self, local: Local) -> Option<Vec<VarDebugInfoFragment<'tcx>>> {
let mut fragments = Vec::new();
- let parts = &self.fragments[place.local];
- for (proj, replacement_local) in parts {
- if proj.starts_with(place.projection) {
- fragments.push(VarDebugInfoFragment {
- projection: proj[place.projection.len()..].to_vec(),
- contents: Place::from(*replacement_local),
- });
- }
- }
- fragments
- }
-
- fn replace_place(&self, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
- if let &[PlaceElem::Field(..), ref rest @ ..] = place.projection {
- let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
- let local = self.replacements.fields.get(&pr)?;
- Some(Place { local: *local, projection: self.tcx.intern_place_elems(&rest) })
- } else {
- None
+ let parts = self.replacements.place_fragments(local.into())?;
+ for (field, ty, replacement_local) in parts {
+ fragments.push(VarDebugInfoFragment {
+ projection: vec![PlaceElem::Field(field, ty)],
+ contents: Place::from(replacement_local),
+ });
}
+ Some(fragments)
}
}
@@ -254,94 +239,188 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
self.tcx
}
- fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
- if let StatementKind::StorageLive(..)
- | StatementKind::StorageDead(..)
- | StatementKind::Deinit(..) = statement.kind
- {
- // Storage statements are expanded in run_pass.
- return;
- }
- self.super_statement(statement, location)
- }
-
fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
- if let Some(repl) = self.replace_place(place.as_ref()) {
+ if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
*place = repl
} else {
self.super_place(place, context, location)
}
}
+ #[instrument(level = "trace", skip(self))]
+ fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
+ match statement.kind {
+ // Duplicate storage and deinit statements, as they pretty much apply to all fields.
+ StatementKind::StorageLive(l) => {
+ if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
+ for (_, _, fl) in final_locals {
+ self.patch.add_statement(location, StatementKind::StorageLive(fl));
+ }
+ statement.make_nop();
+ }
+ return;
+ }
+ StatementKind::StorageDead(l) => {
+ if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
+ for (_, _, fl) in final_locals {
+ self.patch.add_statement(location, StatementKind::StorageDead(fl));
+ }
+ statement.make_nop();
+ }
+ return;
+ }
+ StatementKind::Deinit(box place) => {
+ if let Some(final_locals) = self.replacements.place_fragments(place) {
+ for (_, _, fl) in final_locals {
+ self.patch
+ .add_statement(location, StatementKind::Deinit(Box::new(fl.into())));
+ }
+ statement.make_nop();
+ return;
+ }
+ }
+
+ // We have `a = Struct { 0: x, 1: y, .. }`.
+ // We replace it by
+ // ```
+ // a_0 = x
+ // a_1 = y
+ // ...
+ // ```
+ StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => {
+ if let Some(local) = place.as_local()
+ && let Some(final_locals) = &self.replacements.fragments[local]
+ {
+ // This is ok as we delete the statement later.
+ let operands = std::mem::take(operands);
+ for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) {
+ if let Some((_, new_local)) = opt_ty_local {
+ // Replace mentions of SROA'd locals that appear in the operand.
+ self.visit_operand(&mut operand, location);
+
+ let rvalue = Rvalue::Use(operand);
+ self.patch.add_statement(
+ location,
+ StatementKind::Assign(Box::new((new_local.into(), rvalue))),
+ );
+ }
+ }
+ statement.make_nop();
+ return;
+ }
+ }
+
+ // We have `a = some constant`
+ // We add the projections.
+ // ```
+ // a_0 = a.0
+ // a_1 = a.1
+ // ...
+ // ```
+ // ConstProp will pick up the pieces and replace them by actual constants.
+ StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => {
+ if let Some(final_locals) = self.replacements.place_fragments(place) {
+ // Put the deaggregated statements *after* the original one.
+ let location = location.successor_within_block();
+ for (field, ty, new_local) in final_locals {
+ let rplace = self.tcx.mk_place_field(place, field, ty);
+ let rvalue = Rvalue::Use(Operand::Move(rplace));
+ self.patch.add_statement(
+ location,
+ StatementKind::Assign(Box::new((new_local.into(), rvalue))),
+ );
+ }
+ // We still need `place.local` to exist, so don't make it nop.
+ return;
+ }
+ }
+
+ // We have `a = move? place`
+ // We replace it by
+ // ```
+ // a_0 = move? place.0
+ // a_1 = move? place.1
+ // ...
+ // ```
+ StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => {
+ let (rplace, copy) = match *op {
+ Operand::Copy(rplace) => (rplace, true),
+ Operand::Move(rplace) => (rplace, false),
+ Operand::Constant(_) => bug!(),
+ };
+ if let Some(final_locals) = self.replacements.place_fragments(lhs) {
+ for (field, ty, new_local) in final_locals {
+ let rplace = self.tcx.mk_place_field(rplace, field, ty);
+ debug!(?rplace);
+ let rplace = self
+ .replacements
+ .replace_place(self.tcx, rplace.as_ref())
+ .unwrap_or(rplace);
+ debug!(?rplace);
+ let rvalue = if copy {
+ Rvalue::Use(Operand::Copy(rplace))
+ } else {
+ Rvalue::Use(Operand::Move(rplace))
+ };
+ self.patch.add_statement(
+ location,
+ StatementKind::Assign(Box::new((new_local.into(), rvalue))),
+ );
+ }
+ statement.make_nop();
+ return;
+ }
+ }
+
+ _ => {}
+ }
+ self.super_statement(statement, location)
+ }
+
+ #[instrument(level = "trace", skip(self))]
fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) {
match &mut var_debug_info.value {
VarDebugInfoContents::Place(ref mut place) => {
- if let Some(repl) = self.replace_place(place.as_ref()) {
+ if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
*place = repl;
- } else if self.all_dead_locals.contains(place.local) {
+ } else if let Some(local) = place.as_local()
+ && let Some(fragments) = self.gather_debug_info_fragments(local)
+ {
let ty = place.ty(self.local_decls, self.tcx).ty;
- let fragments = self.gather_debug_info_fragments(place.as_ref());
var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments };
}
}
VarDebugInfoContents::Composite { ty: _, ref mut fragments } => {
let mut new_fragments = Vec::new();
+ debug!(?fragments);
fragments
.drain_filter(|fragment| {
- if let Some(repl) = self.replace_place(fragment.contents.as_ref()) {
+ if let Some(repl) =
+ self.replacements.replace_place(self.tcx, fragment.contents.as_ref())
+ {
fragment.contents = repl;
- true
- } else if self.all_dead_locals.contains(fragment.contents.local) {
- let frg = self.gather_debug_info_fragments(fragment.contents.as_ref());
+ false
+ } else if let Some(local) = fragment.contents.as_local()
+ && let Some(frg) = self.gather_debug_info_fragments(local)
+ {
new_fragments.extend(frg.into_iter().map(|mut f| {
f.projection.splice(0..0, fragment.projection.iter().copied());
f
}));
- false
- } else {
true
+ } else {
+ false
}
})
.for_each(drop);
+ debug!(?fragments);
+ debug!(?new_fragments);
fragments.extend(new_fragments);
}
VarDebugInfoContents::Const(_) => {}
}
}
- fn visit_basic_block_data(&mut self, bb: BasicBlock, bbdata: &mut BasicBlockData<'tcx>) {
- self.super_basic_block_data(bb, bbdata);
-
- #[derive(Debug)]
- enum Stmt {
- StorageLive,
- StorageDead,
- Deinit,
- }
-
- bbdata.expand_statements(|stmt| {
- let source_info = stmt.source_info;
- let (stmt, origin_local) = match &stmt.kind {
- StatementKind::StorageLive(l) => (Stmt::StorageLive, *l),
- StatementKind::StorageDead(l) => (Stmt::StorageDead, *l),
- StatementKind::Deinit(p) if let Some(l) = p.as_local() => (Stmt::Deinit, l),
- _ => return None,
- };
- if !self.all_dead_locals.contains(origin_local) {
- return None;
- }
- let final_locals = self.fragments.get(origin_local)?;
- Some(final_locals.iter().map(move |&(_, l)| {
- let kind = match stmt {
- Stmt::StorageLive => StatementKind::StorageLive(l),
- Stmt::StorageDead => StatementKind::StorageDead(l),
- Stmt::Deinit => StatementKind::Deinit(Box::new(l.into())),
- };
- Statement { source_info, kind }
- }))
- });
- }
-
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
assert!(!self.all_dead_locals.contains(*local));
}
diff --git a/compiler/rustc_mir_transform/src/ssa.rs b/compiler/rustc_mir_transform/src/ssa.rs
new file mode 100644
index 000000000..c1e7f62de
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/ssa.rs
@@ -0,0 +1,257 @@
+use either::Either;
+use rustc_data_structures::graph::dominators::Dominators;
+use rustc_index::bit_set::BitSet;
+use rustc_index::vec::IndexVec;
+use rustc_middle::middle::resolve_bound_vars::Set1;
+use rustc_middle::mir::visit::*;
+use rustc_middle::mir::*;
+use rustc_middle::ty::{ParamEnv, TyCtxt};
+
+#[derive(Debug)]
+pub struct SsaLocals {
+ /// Assignments to each local. This defines whether the local is SSA.
+ assignments: IndexVec<Local, Set1<LocationExtended>>,
+ /// We visit the body in reverse postorder, to ensure each local is assigned before it is used.
+ /// We remember the order in which we saw the assignments to compute the SSA values in a single
+ /// pass.
+ assignment_order: Vec<Local>,
+ /// Copy equivalence classes between locals. See `copy_classes` for documentation.
+ copy_classes: IndexVec<Local, Local>,
+}
+
+/// We often encounter MIR bodies with 1 or 2 basic blocks. In those cases, it's unnecessary to
+/// actually compute dominators, we can just compare block indices because bb0 is always the first
+/// block, and in any body all other blocks are always always dominated by bb0.
+struct SmallDominators {
+ inner: Option<Dominators<BasicBlock>>,
+}
+
+trait DomExt {
+ fn dominates(self, _other: Self, dominators: &SmallDominators) -> bool;
+}
+
+impl DomExt for Location {
+ fn dominates(self, other: Location, dominators: &SmallDominators) -> bool {
+ if self.block == other.block {
+ self.statement_index <= other.statement_index
+ } else {
+ dominators.dominates(self.block, other.block)
+ }
+ }
+}
+
+impl SmallDominators {
+ fn dominates(&self, dom: BasicBlock, node: BasicBlock) -> bool {
+ if let Some(inner) = &self.inner { inner.dominates(dom, node) } else { dom < node }
+ }
+}
+
+impl SsaLocals {
+ pub fn new<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ param_env: ParamEnv<'tcx>,
+ body: &Body<'tcx>,
+ borrowed_locals: &BitSet<Local>,
+ ) -> SsaLocals {
+ let assignment_order = Vec::new();
+
+ let assignments = IndexVec::from_elem(Set1::Empty, &body.local_decls);
+ let dominators =
+ if body.basic_blocks.len() > 2 { Some(body.basic_blocks.dominators()) } else { None };
+ let dominators = SmallDominators { inner: dominators };
+ let mut visitor = SsaVisitor { assignments, assignment_order, dominators };
+
+ for (local, decl) in body.local_decls.iter_enumerated() {
+ if matches!(body.local_kind(local), LocalKind::Arg) {
+ visitor.assignments[local] = Set1::One(LocationExtended::Arg);
+ }
+ if borrowed_locals.contains(local) && !decl.ty.is_freeze(tcx, param_env) {
+ visitor.assignments[local] = Set1::Many;
+ }
+ }
+
+ if body.basic_blocks.len() > 2 {
+ for (bb, data) in traversal::reverse_postorder(body) {
+ visitor.visit_basic_block_data(bb, data);
+ }
+ } else {
+ for (bb, data) in body.basic_blocks.iter_enumerated() {
+ visitor.visit_basic_block_data(bb, data);
+ }
+ }
+
+ for var_debug_info in &body.var_debug_info {
+ visitor.visit_var_debug_info(var_debug_info);
+ }
+
+ debug!(?visitor.assignments);
+
+ visitor
+ .assignment_order
+ .retain(|&local| matches!(visitor.assignments[local], Set1::One(_)));
+ debug!(?visitor.assignment_order);
+
+ let copy_classes = compute_copy_classes(&visitor, body);
+
+ SsaLocals {
+ assignments: visitor.assignments,
+ assignment_order: visitor.assignment_order,
+ copy_classes,
+ }
+ }
+
+ pub fn is_ssa(&self, local: Local) -> bool {
+ matches!(self.assignments[local], Set1::One(_))
+ }
+
+ pub fn assignments<'a, 'tcx>(
+ &'a self,
+ body: &'a Body<'tcx>,
+ ) -> impl Iterator<Item = (Local, &'a Rvalue<'tcx>)> + 'a {
+ self.assignment_order.iter().filter_map(|&local| {
+ if let Set1::One(LocationExtended::Plain(loc)) = self.assignments[local] {
+ // `loc` must point to a direct assignment to `local`.
+ let Either::Left(stmt) = body.stmt_at(loc) else { bug!() };
+ let Some((target, rvalue)) = stmt.kind.as_assign() else { bug!() };
+ assert_eq!(target.as_local(), Some(local));
+ Some((local, rvalue))
+ } else {
+ None
+ }
+ })
+ }
+
+ /// Compute the equivalence classes for locals, based on copy statements.
+ ///
+ /// The returned vector maps each local to the one it copies. In the following case:
+ /// _a = &mut _0
+ /// _b = move? _a
+ /// _c = move? _a
+ /// _d = move? _c
+ /// We return the mapping
+ /// _a => _a // not a copy so, represented by itself
+ /// _b => _a
+ /// _c => _a
+ /// _d => _a // transitively through _c
+ ///
+ /// Exception: we do not see through the return place, as it cannot be substituted.
+ pub fn copy_classes(&self) -> &IndexVec<Local, Local> {
+ &self.copy_classes
+ }
+
+ /// Make a property uniform on a copy equivalence class by removing elements.
+ pub fn meet_copy_equivalence(&self, property: &mut BitSet<Local>) {
+ // Consolidate to have a local iff all its copies are.
+ //
+ // `copy_classes` defines equivalence classes between locals. The `local`s that recursively
+ // move/copy the same local all have the same `head`.
+ for (local, &head) in self.copy_classes.iter_enumerated() {
+ // If any copy does not have `property`, then the head is not.
+ if !property.contains(local) {
+ property.remove(head);
+ }
+ }
+ for (local, &head) in self.copy_classes.iter_enumerated() {
+ // If any copy does not have `property`, then the head doesn't either,
+ // then no copy has `property`.
+ if !property.contains(head) {
+ property.remove(local);
+ }
+ }
+
+ // Verify that we correctly computed equivalence classes.
+ #[cfg(debug_assertions)]
+ for (local, &head) in self.copy_classes.iter_enumerated() {
+ assert_eq!(property.contains(local), property.contains(head));
+ }
+ }
+}
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+enum LocationExtended {
+ Plain(Location),
+ Arg,
+}
+
+struct SsaVisitor {
+ dominators: SmallDominators,
+ assignments: IndexVec<Local, Set1<LocationExtended>>,
+ assignment_order: Vec<Local>,
+}
+
+impl<'tcx> Visitor<'tcx> for SsaVisitor {
+ fn visit_local(&mut self, local: Local, ctxt: PlaceContext, loc: Location) {
+ match ctxt {
+ PlaceContext::MutatingUse(MutatingUseContext::Store) => {
+ self.assignments[local].insert(LocationExtended::Plain(loc));
+ self.assignment_order.push(local);
+ }
+ // Anything can happen with raw pointers, so remove them.
+ PlaceContext::NonMutatingUse(NonMutatingUseContext::AddressOf)
+ | PlaceContext::MutatingUse(_) => self.assignments[local] = Set1::Many,
+ // Immutable borrows are taken into account in `SsaLocals::new` by
+ // removing non-freeze locals.
+ PlaceContext::NonMutatingUse(_) => {
+ let set = &mut self.assignments[local];
+ let assign_dominates = match *set {
+ Set1::Empty | Set1::Many => false,
+ Set1::One(LocationExtended::Arg) => true,
+ Set1::One(LocationExtended::Plain(assign)) => {
+ assign.successor_within_block().dominates(loc, &self.dominators)
+ }
+ };
+ // We are visiting a use that is not dominated by an assignment.
+ // Either there is a cycle involved, or we are reading for uninitialized local.
+ // Bail out.
+ if !assign_dominates {
+ *set = Set1::Many;
+ }
+ }
+ PlaceContext::NonUse(_) => {}
+ }
+ }
+}
+
+#[instrument(level = "trace", skip(ssa, body))]
+fn compute_copy_classes(ssa: &SsaVisitor, body: &Body<'_>) -> IndexVec<Local, Local> {
+ let mut copies = IndexVec::from_fn_n(|l| l, body.local_decls.len());
+
+ for &local in &ssa.assignment_order {
+ debug!(?local);
+
+ if local == RETURN_PLACE {
+ // `_0` is special, we cannot rename it.
+ continue;
+ }
+
+ // This is not SSA: mark that we don't know the value.
+ debug!(assignments = ?ssa.assignments[local]);
+ let Set1::One(LocationExtended::Plain(loc)) = ssa.assignments[local] else { continue };
+
+ // `loc` must point to a direct assignment to `local`.
+ let Either::Left(stmt) = body.stmt_at(loc) else { bug!() };
+ let Some((_target, rvalue)) = stmt.kind.as_assign() else { bug!() };
+ assert_eq!(_target.as_local(), Some(local));
+
+ let (Rvalue::Use(Operand::Copy(place) | Operand::Move(place)) | Rvalue::CopyForDeref(place))
+ = rvalue
+ else { continue };
+
+ let Some(rhs) = place.as_local() else { continue };
+ let Set1::One(_) = ssa.assignments[rhs] else { continue };
+
+ // We visit in `assignment_order`, ie. reverse post-order, so `rhs` has been
+ // visited before `local`, and we just have to copy the representing local.
+ copies[local] = copies[rhs];
+ }
+
+ debug!(?copies);
+
+ // Invariant: `copies` must point to the head of an equivalence class.
+ #[cfg(debug_assertions)]
+ for &head in copies.iter() {
+ assert_eq!(copies[head], head);
+ }
+
+ copies
+}