summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_mir_transform/src/sroa.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src/sroa.rs')
-rw-r--r--compiler/rustc_mir_transform/src/sroa.rs439
1 files changed, 259 insertions, 180 deletions
diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs
index 42124f5a4..13168e9a2 100644
--- a/compiler/rustc_mir_transform/src/sroa.rs
+++ b/compiler/rustc_mir_transform/src/sroa.rs
@@ -1,10 +1,11 @@
use crate::MirPass;
-use rustc_data_structures::fx::{FxIndexMap, IndexEntry};
-use rustc_index::bit_set::BitSet;
+use rustc_index::bit_set::{BitSet, GrowableBitSet};
use rustc_index::vec::IndexVec;
+use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
-use rustc_middle::ty::TyCtxt;
+use rustc_middle::ty::{Ty, TyCtxt};
+use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
pub struct ScalarReplacementOfAggregates;
@@ -13,27 +14,43 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
sess.mir_opt_level() >= 3
}
+ #[instrument(level = "debug", skip(self, tcx, body))]
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
- let escaping = escaping_locals(&*body);
- debug!(?escaping);
- let replacements = compute_flattening(tcx, body, escaping);
- debug!(?replacements);
- replace_flattened_locals(tcx, body, replacements);
+ debug!(def_id = ?body.source.def_id());
+ let mut excluded = excluded_locals(body);
+ loop {
+ debug!(?excluded);
+ let escaping = escaping_locals(&excluded, body);
+ debug!(?escaping);
+ let replacements = compute_flattening(tcx, body, escaping);
+ debug!(?replacements);
+ let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
+ if !all_dead_locals.is_empty() {
+ excluded.union(&all_dead_locals);
+ excluded = {
+ let mut growable = GrowableBitSet::from(excluded);
+ growable.ensure(body.local_decls.len());
+ growable.into()
+ };
+ } else {
+ break;
+ }
+ }
}
}
/// Identify all locals that are not eligible for SROA.
///
/// There are 3 cases:
-/// - the aggegated local is used or passed to other code (function parameters and arguments);
+/// - the aggregated local is used or passed to other code (function parameters and arguments);
/// - the locals is a union or an enum;
/// - the local's address is taken, and thus the relative addresses of the fields are observable to
/// client code.
-fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
+fn escaping_locals(excluded: &BitSet<Local>, body: &Body<'_>) -> BitSet<Local> {
let mut set = BitSet::new_empty(body.local_decls.len());
set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
for (local, decl) in body.local_decls().iter_enumerated() {
- if decl.ty.is_union() || decl.ty.is_enum() {
+ if decl.ty.is_union() || decl.ty.is_enum() || excluded.contains(local) {
set.insert(local);
}
}
@@ -58,41 +75,33 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
self.super_place(place, context, location);
}
- fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
- if let Rvalue::AddressOf(.., place) | Rvalue::Ref(.., place) = rvalue {
- if !place.is_indirect() {
- // Raw pointers may be used to access anything inside the enclosing place.
- self.set.insert(place.local);
- return;
+ fn visit_assign(
+ &mut self,
+ lvalue: &Place<'tcx>,
+ rvalue: &Rvalue<'tcx>,
+ location: Location,
+ ) {
+ if lvalue.as_local().is_some() {
+ match rvalue {
+ // Aggregate assignments are expanded in run_pass.
+ Rvalue::Aggregate(..) | Rvalue::Use(..) => {
+ self.visit_rvalue(rvalue, location);
+ return;
+ }
+ _ => {}
}
}
- self.super_rvalue(rvalue, location)
+ self.super_assign(lvalue, rvalue, location)
}
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
- if let StatementKind::StorageLive(..)
- | StatementKind::StorageDead(..)
- | StatementKind::Deinit(..) = statement.kind
- {
+ match statement.kind {
// Storage statements are expanded in run_pass.
- return;
+ StatementKind::StorageLive(..)
+ | StatementKind::StorageDead(..)
+ | StatementKind::Deinit(..) => return,
+ _ => self.super_statement(statement, location),
}
- self.super_statement(statement, location)
- }
-
- fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
- // Drop implicitly calls `drop_in_place`, which takes a `&mut`.
- // This implies that `Drop` implicitly takes the address of the place.
- if let TerminatorKind::Drop { place, .. }
- | TerminatorKind::DropAndReplace { place, .. } = terminator.kind
- {
- if !place.is_indirect() {
- // Raw pointers may be used to access anything inside the enclosing place.
- self.set.insert(place.local);
- return;
- }
- }
- self.super_terminator(terminator, location);
}
// We ignore anything that happens in debuginfo, since we expand it using
@@ -103,7 +112,30 @@ fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
#[derive(Default, Debug)]
struct ReplacementMap<'tcx> {
- fields: FxIndexMap<PlaceRef<'tcx>, Local>,
+ /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
+ /// and deinit statement and debuginfo.
+ fragments: IndexVec<Local, Option<IndexVec<Field, Option<(Ty<'tcx>, Local)>>>>,
+}
+
+impl<'tcx> ReplacementMap<'tcx> {
+ fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
+ let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else { return None; };
+ let fields = self.fragments[place.local].as_ref()?;
+ let (_, new_local) = fields[f]?;
+ Some(Place { local: new_local, projection: tcx.mk_place_elems(&rest) })
+ }
+
+ fn place_fragments(
+ &self,
+ place: Place<'tcx>,
+ ) -> Option<impl Iterator<Item = (Field, Ty<'tcx>, Local)> + '_> {
+ let local = place.as_local()?;
+ let fields = self.fragments[local].as_ref()?;
+ Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| {
+ let (ty, local) = opt_ty_local?;
+ Some((field, ty, local))
+ }))
+ }
}
/// Compute the replacement of flattened places into locals.
@@ -115,53 +147,25 @@ fn compute_flattening<'tcx>(
body: &mut Body<'tcx>,
escaping: BitSet<Local>,
) -> ReplacementMap<'tcx> {
- let mut visitor = PreFlattenVisitor {
- tcx,
- escaping,
- local_decls: &mut body.local_decls,
- map: Default::default(),
- };
- for (block, bbdata) in body.basic_blocks.iter_enumerated() {
- visitor.visit_basic_block_data(block, bbdata);
- }
- return visitor.map;
-
- struct PreFlattenVisitor<'tcx, 'll> {
- tcx: TyCtxt<'tcx>,
- local_decls: &'ll mut LocalDecls<'tcx>,
- escaping: BitSet<Local>,
- map: ReplacementMap<'tcx>,
- }
-
- impl<'tcx, 'll> PreFlattenVisitor<'tcx, 'll> {
- fn create_place(&mut self, place: PlaceRef<'tcx>) {
- if self.escaping.contains(place.local) {
- return;
- }
+ let mut fragments = IndexVec::from_elem(None, &body.local_decls);
- match self.map.fields.entry(place) {
- IndexEntry::Occupied(_) => {}
- IndexEntry::Vacant(v) => {
- let ty = place.ty(&*self.local_decls, self.tcx).ty;
- let local = self.local_decls.push(LocalDecl {
- ty,
- user_ty: None,
- ..self.local_decls[place.local].clone()
- });
- v.insert(local);
- }
- }
- }
- }
-
- impl<'tcx, 'll> Visitor<'tcx> for PreFlattenVisitor<'tcx, 'll> {
- fn visit_place(&mut self, place: &Place<'tcx>, _: PlaceContext, _: Location) {
- if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
- let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
- self.create_place(pr)
- }
+ for local in body.local_decls.indices() {
+ if escaping.contains(local) {
+ continue;
}
+ let decl = body.local_decls[local].clone();
+ let ty = decl.ty;
+ iter_fields(ty, tcx, |variant, field, field_ty| {
+ if variant.is_some() {
+ // Downcasts are currently not supported.
+ return;
+ };
+ let new_local =
+ body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() });
+ fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local));
+ });
}
+ ReplacementMap { fragments }
}
/// Perform the replacement computed by `compute_flattening`.
@@ -169,29 +173,24 @@ fn replace_flattened_locals<'tcx>(
tcx: TyCtxt<'tcx>,
body: &mut Body<'tcx>,
replacements: ReplacementMap<'tcx>,
-) {
- let mut all_dead_locals = BitSet::new_empty(body.local_decls.len());
- for p in replacements.fields.keys() {
- all_dead_locals.insert(p.local);
+) -> BitSet<Local> {
+ let mut all_dead_locals = BitSet::new_empty(replacements.fragments.len());
+ for (local, replacements) in replacements.fragments.iter_enumerated() {
+ if replacements.is_some() {
+ all_dead_locals.insert(local);
+ }
}
debug!(?all_dead_locals);
if all_dead_locals.is_empty() {
- return;
+ return all_dead_locals;
}
- let mut fragments = IndexVec::new();
- for (k, v) in &replacements.fields {
- fragments.ensure_contains_elem(k.local, || Vec::new());
- fragments[k.local].push((k.projection, *v));
- }
- debug!(?fragments);
-
let mut visitor = ReplacementVisitor {
tcx,
local_decls: &body.local_decls,
- replacements,
+ replacements: &replacements,
all_dead_locals,
- fragments,
+ patch: MirPatch::new(body),
};
for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
visitor.visit_basic_block_data(bb, data);
@@ -205,6 +204,9 @@ fn replace_flattened_locals<'tcx>(
for var_debug_info in &mut body.var_debug_info {
visitor.visit_var_debug_info(var_debug_info);
}
+ let ReplacementVisitor { patch, all_dead_locals, .. } = visitor;
+ patch.apply(body);
+ all_dead_locals
}
struct ReplacementVisitor<'tcx, 'll> {
@@ -212,40 +214,23 @@ struct ReplacementVisitor<'tcx, 'll> {
/// This is only used to compute the type for `VarDebugInfoContents::Composite`.
local_decls: &'ll LocalDecls<'tcx>,
/// Work to do.
- replacements: ReplacementMap<'tcx>,
+ replacements: &'ll ReplacementMap<'tcx>,
/// This is used to check that we are not leaving references to replaced locals behind.
all_dead_locals: BitSet<Local>,
- /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
- /// and deinit statement and debuginfo.
- fragments: IndexVec<Local, Vec<(&'tcx [PlaceElem<'tcx>], Local)>>,
+ patch: MirPatch<'tcx>,
}
-impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> {
- fn gather_debug_info_fragments(
- &self,
- place: PlaceRef<'tcx>,
- ) -> Vec<VarDebugInfoFragment<'tcx>> {
+impl<'tcx> ReplacementVisitor<'tcx, '_> {
+ fn gather_debug_info_fragments(&self, local: Local) -> Option<Vec<VarDebugInfoFragment<'tcx>>> {
let mut fragments = Vec::new();
- let parts = &self.fragments[place.local];
- for (proj, replacement_local) in parts {
- if proj.starts_with(place.projection) {
- fragments.push(VarDebugInfoFragment {
- projection: proj[place.projection.len()..].to_vec(),
- contents: Place::from(*replacement_local),
- });
- }
- }
- fragments
- }
-
- fn replace_place(&self, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
- if let &[PlaceElem::Field(..), ref rest @ ..] = place.projection {
- let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
- let local = self.replacements.fields.get(&pr)?;
- Some(Place { local: *local, projection: self.tcx.intern_place_elems(&rest) })
- } else {
- None
+ let parts = self.replacements.place_fragments(local.into())?;
+ for (field, ty, replacement_local) in parts {
+ fragments.push(VarDebugInfoFragment {
+ projection: vec![PlaceElem::Field(field, ty)],
+ contents: Place::from(replacement_local),
+ });
}
+ Some(fragments)
}
}
@@ -254,94 +239,188 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
self.tcx
}
- fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
- if let StatementKind::StorageLive(..)
- | StatementKind::StorageDead(..)
- | StatementKind::Deinit(..) = statement.kind
- {
- // Storage statements are expanded in run_pass.
- return;
- }
- self.super_statement(statement, location)
- }
-
fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
- if let Some(repl) = self.replace_place(place.as_ref()) {
+ if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
*place = repl
} else {
self.super_place(place, context, location)
}
}
+ #[instrument(level = "trace", skip(self))]
+ fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
+ match statement.kind {
+ // Duplicate storage and deinit statements, as they pretty much apply to all fields.
+ StatementKind::StorageLive(l) => {
+ if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
+ for (_, _, fl) in final_locals {
+ self.patch.add_statement(location, StatementKind::StorageLive(fl));
+ }
+ statement.make_nop();
+ }
+ return;
+ }
+ StatementKind::StorageDead(l) => {
+ if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
+ for (_, _, fl) in final_locals {
+ self.patch.add_statement(location, StatementKind::StorageDead(fl));
+ }
+ statement.make_nop();
+ }
+ return;
+ }
+ StatementKind::Deinit(box place) => {
+ if let Some(final_locals) = self.replacements.place_fragments(place) {
+ for (_, _, fl) in final_locals {
+ self.patch
+ .add_statement(location, StatementKind::Deinit(Box::new(fl.into())));
+ }
+ statement.make_nop();
+ return;
+ }
+ }
+
+ // We have `a = Struct { 0: x, 1: y, .. }`.
+ // We replace it by
+ // ```
+ // a_0 = x
+ // a_1 = y
+ // ...
+ // ```
+ StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => {
+ if let Some(local) = place.as_local()
+ && let Some(final_locals) = &self.replacements.fragments[local]
+ {
+ // This is ok as we delete the statement later.
+ let operands = std::mem::take(operands);
+ for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) {
+ if let Some((_, new_local)) = opt_ty_local {
+ // Replace mentions of SROA'd locals that appear in the operand.
+ self.visit_operand(&mut operand, location);
+
+ let rvalue = Rvalue::Use(operand);
+ self.patch.add_statement(
+ location,
+ StatementKind::Assign(Box::new((new_local.into(), rvalue))),
+ );
+ }
+ }
+ statement.make_nop();
+ return;
+ }
+ }
+
+ // We have `a = some constant`
+ // We add the projections.
+ // ```
+ // a_0 = a.0
+ // a_1 = a.1
+ // ...
+ // ```
+ // ConstProp will pick up the pieces and replace them by actual constants.
+ StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => {
+ if let Some(final_locals) = self.replacements.place_fragments(place) {
+ // Put the deaggregated statements *after* the original one.
+ let location = location.successor_within_block();
+ for (field, ty, new_local) in final_locals {
+ let rplace = self.tcx.mk_place_field(place, field, ty);
+ let rvalue = Rvalue::Use(Operand::Move(rplace));
+ self.patch.add_statement(
+ location,
+ StatementKind::Assign(Box::new((new_local.into(), rvalue))),
+ );
+ }
+ // We still need `place.local` to exist, so don't make it nop.
+ return;
+ }
+ }
+
+ // We have `a = move? place`
+ // We replace it by
+ // ```
+ // a_0 = move? place.0
+ // a_1 = move? place.1
+ // ...
+ // ```
+ StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => {
+ let (rplace, copy) = match *op {
+ Operand::Copy(rplace) => (rplace, true),
+ Operand::Move(rplace) => (rplace, false),
+ Operand::Constant(_) => bug!(),
+ };
+ if let Some(final_locals) = self.replacements.place_fragments(lhs) {
+ for (field, ty, new_local) in final_locals {
+ let rplace = self.tcx.mk_place_field(rplace, field, ty);
+ debug!(?rplace);
+ let rplace = self
+ .replacements
+ .replace_place(self.tcx, rplace.as_ref())
+ .unwrap_or(rplace);
+ debug!(?rplace);
+ let rvalue = if copy {
+ Rvalue::Use(Operand::Copy(rplace))
+ } else {
+ Rvalue::Use(Operand::Move(rplace))
+ };
+ self.patch.add_statement(
+ location,
+ StatementKind::Assign(Box::new((new_local.into(), rvalue))),
+ );
+ }
+ statement.make_nop();
+ return;
+ }
+ }
+
+ _ => {}
+ }
+ self.super_statement(statement, location)
+ }
+
+ #[instrument(level = "trace", skip(self))]
fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) {
match &mut var_debug_info.value {
VarDebugInfoContents::Place(ref mut place) => {
- if let Some(repl) = self.replace_place(place.as_ref()) {
+ if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
*place = repl;
- } else if self.all_dead_locals.contains(place.local) {
+ } else if let Some(local) = place.as_local()
+ && let Some(fragments) = self.gather_debug_info_fragments(local)
+ {
let ty = place.ty(self.local_decls, self.tcx).ty;
- let fragments = self.gather_debug_info_fragments(place.as_ref());
var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments };
}
}
VarDebugInfoContents::Composite { ty: _, ref mut fragments } => {
let mut new_fragments = Vec::new();
+ debug!(?fragments);
fragments
.drain_filter(|fragment| {
- if let Some(repl) = self.replace_place(fragment.contents.as_ref()) {
+ if let Some(repl) =
+ self.replacements.replace_place(self.tcx, fragment.contents.as_ref())
+ {
fragment.contents = repl;
- true
- } else if self.all_dead_locals.contains(fragment.contents.local) {
- let frg = self.gather_debug_info_fragments(fragment.contents.as_ref());
+ false
+ } else if let Some(local) = fragment.contents.as_local()
+ && let Some(frg) = self.gather_debug_info_fragments(local)
+ {
new_fragments.extend(frg.into_iter().map(|mut f| {
f.projection.splice(0..0, fragment.projection.iter().copied());
f
}));
- false
- } else {
true
+ } else {
+ false
}
})
.for_each(drop);
+ debug!(?fragments);
+ debug!(?new_fragments);
fragments.extend(new_fragments);
}
VarDebugInfoContents::Const(_) => {}
}
}
- fn visit_basic_block_data(&mut self, bb: BasicBlock, bbdata: &mut BasicBlockData<'tcx>) {
- self.super_basic_block_data(bb, bbdata);
-
- #[derive(Debug)]
- enum Stmt {
- StorageLive,
- StorageDead,
- Deinit,
- }
-
- bbdata.expand_statements(|stmt| {
- let source_info = stmt.source_info;
- let (stmt, origin_local) = match &stmt.kind {
- StatementKind::StorageLive(l) => (Stmt::StorageLive, *l),
- StatementKind::StorageDead(l) => (Stmt::StorageDead, *l),
- StatementKind::Deinit(p) if let Some(l) = p.as_local() => (Stmt::Deinit, l),
- _ => return None,
- };
- if !self.all_dead_locals.contains(origin_local) {
- return None;
- }
- let final_locals = self.fragments.get(origin_local)?;
- Some(final_locals.iter().map(move |&(_, l)| {
- let kind = match stmt {
- Stmt::StorageLive => StatementKind::StorageLive(l),
- Stmt::StorageDead => StatementKind::StorageDead(l),
- Stmt::Deinit => StatementKind::Deinit(Box::new(l.into())),
- };
- Statement { source_info, kind }
- }))
- });
- }
-
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
assert!(!self.all_dead_locals.contains(*local));
}