From 4f9fe856a25ab29345b90e7725509e9ee38a37be Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 17 Apr 2024 14:19:41 +0200 Subject: Adding upstream version 1.69.0+dfsg1. Signed-off-by: Daniel Baumann --- compiler/rustc_mir_transform/src/sroa.rs | 439 ++++++++++++++++++------------- 1 file changed, 259 insertions(+), 180 deletions(-) (limited to 'compiler/rustc_mir_transform/src/sroa.rs') diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs index 42124f5a4..13168e9a2 100644 --- a/compiler/rustc_mir_transform/src/sroa.rs +++ b/compiler/rustc_mir_transform/src/sroa.rs @@ -1,10 +1,11 @@ use crate::MirPass; -use rustc_data_structures::fx::{FxIndexMap, IndexEntry}; -use rustc_index::bit_set::BitSet; +use rustc_index::bit_set::{BitSet, GrowableBitSet}; use rustc_index::vec::IndexVec; +use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::visit::*; use rustc_middle::mir::*; -use rustc_middle::ty::TyCtxt; +use rustc_middle::ty::{Ty, TyCtxt}; +use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields}; pub struct ScalarReplacementOfAggregates; @@ -13,27 +14,43 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates { sess.mir_opt_level() >= 3 } + #[instrument(level = "debug", skip(self, tcx, body))] fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { - let escaping = escaping_locals(&*body); - debug!(?escaping); - let replacements = compute_flattening(tcx, body, escaping); - debug!(?replacements); - replace_flattened_locals(tcx, body, replacements); + debug!(def_id = ?body.source.def_id()); + let mut excluded = excluded_locals(body); + loop { + debug!(?excluded); + let escaping = escaping_locals(&excluded, body); + debug!(?escaping); + let replacements = compute_flattening(tcx, body, escaping); + debug!(?replacements); + let all_dead_locals = replace_flattened_locals(tcx, body, replacements); + if !all_dead_locals.is_empty() { + excluded.union(&all_dead_locals); + excluded = { + let mut growable = GrowableBitSet::from(excluded); + growable.ensure(body.local_decls.len()); + growable.into() + }; + } else { + break; + } + } } } /// Identify all locals that are not eligible for SROA. /// /// There are 3 cases: -/// - the aggegated local is used or passed to other code (function parameters and arguments); +/// - the aggregated local is used or passed to other code (function parameters and arguments); /// - the locals is a union or an enum; /// - the local's address is taken, and thus the relative addresses of the fields are observable to /// client code. -fn escaping_locals(body: &Body<'_>) -> BitSet { +fn escaping_locals(excluded: &BitSet, body: &Body<'_>) -> BitSet { let mut set = BitSet::new_empty(body.local_decls.len()); set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count)); for (local, decl) in body.local_decls().iter_enumerated() { - if decl.ty.is_union() || decl.ty.is_enum() { + if decl.ty.is_union() || decl.ty.is_enum() || excluded.contains(local) { set.insert(local); } } @@ -58,41 +75,33 @@ fn escaping_locals(body: &Body<'_>) -> BitSet { self.super_place(place, context, location); } - fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { - if let Rvalue::AddressOf(.., place) | Rvalue::Ref(.., place) = rvalue { - if !place.is_indirect() { - // Raw pointers may be used to access anything inside the enclosing place. - self.set.insert(place.local); - return; + fn visit_assign( + &mut self, + lvalue: &Place<'tcx>, + rvalue: &Rvalue<'tcx>, + location: Location, + ) { + if lvalue.as_local().is_some() { + match rvalue { + // Aggregate assignments are expanded in run_pass. + Rvalue::Aggregate(..) | Rvalue::Use(..) => { + self.visit_rvalue(rvalue, location); + return; + } + _ => {} } } - self.super_rvalue(rvalue, location) + self.super_assign(lvalue, rvalue, location) } fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) { - if let StatementKind::StorageLive(..) - | StatementKind::StorageDead(..) - | StatementKind::Deinit(..) = statement.kind - { + match statement.kind { // Storage statements are expanded in run_pass. - return; + StatementKind::StorageLive(..) + | StatementKind::StorageDead(..) + | StatementKind::Deinit(..) => return, + _ => self.super_statement(statement, location), } - self.super_statement(statement, location) - } - - fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { - // Drop implicitly calls `drop_in_place`, which takes a `&mut`. - // This implies that `Drop` implicitly takes the address of the place. - if let TerminatorKind::Drop { place, .. } - | TerminatorKind::DropAndReplace { place, .. } = terminator.kind - { - if !place.is_indirect() { - // Raw pointers may be used to access anything inside the enclosing place. - self.set.insert(place.local); - return; - } - } - self.super_terminator(terminator, location); } // We ignore anything that happens in debuginfo, since we expand it using @@ -103,7 +112,30 @@ fn escaping_locals(body: &Body<'_>) -> BitSet { #[derive(Default, Debug)] struct ReplacementMap<'tcx> { - fields: FxIndexMap, Local>, + /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage + /// and deinit statement and debuginfo. + fragments: IndexVec, Local)>>>>, +} + +impl<'tcx> ReplacementMap<'tcx> { + fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option> { + let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else { return None; }; + let fields = self.fragments[place.local].as_ref()?; + let (_, new_local) = fields[f]?; + Some(Place { local: new_local, projection: tcx.mk_place_elems(&rest) }) + } + + fn place_fragments( + &self, + place: Place<'tcx>, + ) -> Option, Local)> + '_> { + let local = place.as_local()?; + let fields = self.fragments[local].as_ref()?; + Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| { + let (ty, local) = opt_ty_local?; + Some((field, ty, local)) + })) + } } /// Compute the replacement of flattened places into locals. @@ -115,53 +147,25 @@ fn compute_flattening<'tcx>( body: &mut Body<'tcx>, escaping: BitSet, ) -> ReplacementMap<'tcx> { - let mut visitor = PreFlattenVisitor { - tcx, - escaping, - local_decls: &mut body.local_decls, - map: Default::default(), - }; - for (block, bbdata) in body.basic_blocks.iter_enumerated() { - visitor.visit_basic_block_data(block, bbdata); - } - return visitor.map; - - struct PreFlattenVisitor<'tcx, 'll> { - tcx: TyCtxt<'tcx>, - local_decls: &'ll mut LocalDecls<'tcx>, - escaping: BitSet, - map: ReplacementMap<'tcx>, - } - - impl<'tcx, 'll> PreFlattenVisitor<'tcx, 'll> { - fn create_place(&mut self, place: PlaceRef<'tcx>) { - if self.escaping.contains(place.local) { - return; - } + let mut fragments = IndexVec::from_elem(None, &body.local_decls); - match self.map.fields.entry(place) { - IndexEntry::Occupied(_) => {} - IndexEntry::Vacant(v) => { - let ty = place.ty(&*self.local_decls, self.tcx).ty; - let local = self.local_decls.push(LocalDecl { - ty, - user_ty: None, - ..self.local_decls[place.local].clone() - }); - v.insert(local); - } - } - } - } - - impl<'tcx, 'll> Visitor<'tcx> for PreFlattenVisitor<'tcx, 'll> { - fn visit_place(&mut self, place: &Place<'tcx>, _: PlaceContext, _: Location) { - if let &[PlaceElem::Field(..), ..] = &place.projection[..] { - let pr = PlaceRef { local: place.local, projection: &place.projection[..1] }; - self.create_place(pr) - } + for local in body.local_decls.indices() { + if escaping.contains(local) { + continue; } + let decl = body.local_decls[local].clone(); + let ty = decl.ty; + iter_fields(ty, tcx, |variant, field, field_ty| { + if variant.is_some() { + // Downcasts are currently not supported. + return; + }; + let new_local = + body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() }); + fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local)); + }); } + ReplacementMap { fragments } } /// Perform the replacement computed by `compute_flattening`. @@ -169,29 +173,24 @@ fn replace_flattened_locals<'tcx>( tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, replacements: ReplacementMap<'tcx>, -) { - let mut all_dead_locals = BitSet::new_empty(body.local_decls.len()); - for p in replacements.fields.keys() { - all_dead_locals.insert(p.local); +) -> BitSet { + let mut all_dead_locals = BitSet::new_empty(replacements.fragments.len()); + for (local, replacements) in replacements.fragments.iter_enumerated() { + if replacements.is_some() { + all_dead_locals.insert(local); + } } debug!(?all_dead_locals); if all_dead_locals.is_empty() { - return; + return all_dead_locals; } - let mut fragments = IndexVec::new(); - for (k, v) in &replacements.fields { - fragments.ensure_contains_elem(k.local, || Vec::new()); - fragments[k.local].push((k.projection, *v)); - } - debug!(?fragments); - let mut visitor = ReplacementVisitor { tcx, local_decls: &body.local_decls, - replacements, + replacements: &replacements, all_dead_locals, - fragments, + patch: MirPatch::new(body), }; for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() { visitor.visit_basic_block_data(bb, data); @@ -205,6 +204,9 @@ fn replace_flattened_locals<'tcx>( for var_debug_info in &mut body.var_debug_info { visitor.visit_var_debug_info(var_debug_info); } + let ReplacementVisitor { patch, all_dead_locals, .. } = visitor; + patch.apply(body); + all_dead_locals } struct ReplacementVisitor<'tcx, 'll> { @@ -212,40 +214,23 @@ struct ReplacementVisitor<'tcx, 'll> { /// This is only used to compute the type for `VarDebugInfoContents::Composite`. local_decls: &'ll LocalDecls<'tcx>, /// Work to do. - replacements: ReplacementMap<'tcx>, + replacements: &'ll ReplacementMap<'tcx>, /// This is used to check that we are not leaving references to replaced locals behind. all_dead_locals: BitSet, - /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage - /// and deinit statement and debuginfo. - fragments: IndexVec], Local)>>, + patch: MirPatch<'tcx>, } -impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> { - fn gather_debug_info_fragments( - &self, - place: PlaceRef<'tcx>, - ) -> Vec> { +impl<'tcx> ReplacementVisitor<'tcx, '_> { + fn gather_debug_info_fragments(&self, local: Local) -> Option>> { let mut fragments = Vec::new(); - let parts = &self.fragments[place.local]; - for (proj, replacement_local) in parts { - if proj.starts_with(place.projection) { - fragments.push(VarDebugInfoFragment { - projection: proj[place.projection.len()..].to_vec(), - contents: Place::from(*replacement_local), - }); - } - } - fragments - } - - fn replace_place(&self, place: PlaceRef<'tcx>) -> Option> { - if let &[PlaceElem::Field(..), ref rest @ ..] = place.projection { - let pr = PlaceRef { local: place.local, projection: &place.projection[..1] }; - let local = self.replacements.fields.get(&pr)?; - Some(Place { local: *local, projection: self.tcx.intern_place_elems(&rest) }) - } else { - None + let parts = self.replacements.place_fragments(local.into())?; + for (field, ty, replacement_local) in parts { + fragments.push(VarDebugInfoFragment { + projection: vec![PlaceElem::Field(field, ty)], + contents: Place::from(replacement_local), + }); } + Some(fragments) } } @@ -254,94 +239,188 @@ impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> { self.tcx } - fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { - if let StatementKind::StorageLive(..) - | StatementKind::StorageDead(..) - | StatementKind::Deinit(..) = statement.kind - { - // Storage statements are expanded in run_pass. - return; - } - self.super_statement(statement, location) - } - fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) { - if let Some(repl) = self.replace_place(place.as_ref()) { + if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) { *place = repl } else { self.super_place(place, context, location) } } + #[instrument(level = "trace", skip(self))] + fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) { + match statement.kind { + // Duplicate storage and deinit statements, as they pretty much apply to all fields. + StatementKind::StorageLive(l) => { + if let Some(final_locals) = self.replacements.place_fragments(l.into()) { + for (_, _, fl) in final_locals { + self.patch.add_statement(location, StatementKind::StorageLive(fl)); + } + statement.make_nop(); + } + return; + } + StatementKind::StorageDead(l) => { + if let Some(final_locals) = self.replacements.place_fragments(l.into()) { + for (_, _, fl) in final_locals { + self.patch.add_statement(location, StatementKind::StorageDead(fl)); + } + statement.make_nop(); + } + return; + } + StatementKind::Deinit(box place) => { + if let Some(final_locals) = self.replacements.place_fragments(place) { + for (_, _, fl) in final_locals { + self.patch + .add_statement(location, StatementKind::Deinit(Box::new(fl.into()))); + } + statement.make_nop(); + return; + } + } + + // We have `a = Struct { 0: x, 1: y, .. }`. + // We replace it by + // ``` + // a_0 = x + // a_1 = y + // ... + // ``` + StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => { + if let Some(local) = place.as_local() + && let Some(final_locals) = &self.replacements.fragments[local] + { + // This is ok as we delete the statement later. + let operands = std::mem::take(operands); + for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) { + if let Some((_, new_local)) = opt_ty_local { + // Replace mentions of SROA'd locals that appear in the operand. + self.visit_operand(&mut operand, location); + + let rvalue = Rvalue::Use(operand); + self.patch.add_statement( + location, + StatementKind::Assign(Box::new((new_local.into(), rvalue))), + ); + } + } + statement.make_nop(); + return; + } + } + + // We have `a = some constant` + // We add the projections. + // ``` + // a_0 = a.0 + // a_1 = a.1 + // ... + // ``` + // ConstProp will pick up the pieces and replace them by actual constants. + StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => { + if let Some(final_locals) = self.replacements.place_fragments(place) { + // Put the deaggregated statements *after* the original one. + let location = location.successor_within_block(); + for (field, ty, new_local) in final_locals { + let rplace = self.tcx.mk_place_field(place, field, ty); + let rvalue = Rvalue::Use(Operand::Move(rplace)); + self.patch.add_statement( + location, + StatementKind::Assign(Box::new((new_local.into(), rvalue))), + ); + } + // We still need `place.local` to exist, so don't make it nop. + return; + } + } + + // We have `a = move? place` + // We replace it by + // ``` + // a_0 = move? place.0 + // a_1 = move? place.1 + // ... + // ``` + StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => { + let (rplace, copy) = match *op { + Operand::Copy(rplace) => (rplace, true), + Operand::Move(rplace) => (rplace, false), + Operand::Constant(_) => bug!(), + }; + if let Some(final_locals) = self.replacements.place_fragments(lhs) { + for (field, ty, new_local) in final_locals { + let rplace = self.tcx.mk_place_field(rplace, field, ty); + debug!(?rplace); + let rplace = self + .replacements + .replace_place(self.tcx, rplace.as_ref()) + .unwrap_or(rplace); + debug!(?rplace); + let rvalue = if copy { + Rvalue::Use(Operand::Copy(rplace)) + } else { + Rvalue::Use(Operand::Move(rplace)) + }; + self.patch.add_statement( + location, + StatementKind::Assign(Box::new((new_local.into(), rvalue))), + ); + } + statement.make_nop(); + return; + } + } + + _ => {} + } + self.super_statement(statement, location) + } + + #[instrument(level = "trace", skip(self))] fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) { match &mut var_debug_info.value { VarDebugInfoContents::Place(ref mut place) => { - if let Some(repl) = self.replace_place(place.as_ref()) { + if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) { *place = repl; - } else if self.all_dead_locals.contains(place.local) { + } else if let Some(local) = place.as_local() + && let Some(fragments) = self.gather_debug_info_fragments(local) + { let ty = place.ty(self.local_decls, self.tcx).ty; - let fragments = self.gather_debug_info_fragments(place.as_ref()); var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments }; } } VarDebugInfoContents::Composite { ty: _, ref mut fragments } => { let mut new_fragments = Vec::new(); + debug!(?fragments); fragments .drain_filter(|fragment| { - if let Some(repl) = self.replace_place(fragment.contents.as_ref()) { + if let Some(repl) = + self.replacements.replace_place(self.tcx, fragment.contents.as_ref()) + { fragment.contents = repl; - true - } else if self.all_dead_locals.contains(fragment.contents.local) { - let frg = self.gather_debug_info_fragments(fragment.contents.as_ref()); + false + } else if let Some(local) = fragment.contents.as_local() + && let Some(frg) = self.gather_debug_info_fragments(local) + { new_fragments.extend(frg.into_iter().map(|mut f| { f.projection.splice(0..0, fragment.projection.iter().copied()); f })); - false - } else { true + } else { + false } }) .for_each(drop); + debug!(?fragments); + debug!(?new_fragments); fragments.extend(new_fragments); } VarDebugInfoContents::Const(_) => {} } } - fn visit_basic_block_data(&mut self, bb: BasicBlock, bbdata: &mut BasicBlockData<'tcx>) { - self.super_basic_block_data(bb, bbdata); - - #[derive(Debug)] - enum Stmt { - StorageLive, - StorageDead, - Deinit, - } - - bbdata.expand_statements(|stmt| { - let source_info = stmt.source_info; - let (stmt, origin_local) = match &stmt.kind { - StatementKind::StorageLive(l) => (Stmt::StorageLive, *l), - StatementKind::StorageDead(l) => (Stmt::StorageDead, *l), - StatementKind::Deinit(p) if let Some(l) = p.as_local() => (Stmt::Deinit, l), - _ => return None, - }; - if !self.all_dead_locals.contains(origin_local) { - return None; - } - let final_locals = self.fragments.get(origin_local)?; - Some(final_locals.iter().map(move |&(_, l)| { - let kind = match stmt { - Stmt::StorageLive => StatementKind::StorageLive(l), - Stmt::StorageDead => StatementKind::StorageDead(l), - Stmt::Deinit => StatementKind::Deinit(Box::new(l.into())), - }; - Statement { source_info, kind } - })) - }); - } - fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) { assert!(!self.all_dead_locals.contains(*local)); } -- cgit v1.2.3