summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_mir_transform/src/deduplicate_blocks.rs
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-17 12:02:58 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-17 12:02:58 +0000
commit698f8c2f01ea549d77d7dc3338a12e04c11057b9 (patch)
tree173a775858bd501c378080a10dca74132f05bc50 /compiler/rustc_mir_transform/src/deduplicate_blocks.rs
parentInitial commit. (diff)
downloadrustc-698f8c2f01ea549d77d7dc3338a12e04c11057b9.tar.xz
rustc-698f8c2f01ea549d77d7dc3338a12e04c11057b9.zip
Adding upstream version 1.64.0+dfsg1.upstream/1.64.0+dfsg1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'compiler/rustc_mir_transform/src/deduplicate_blocks.rs')
-rw-r--r--compiler/rustc_mir_transform/src/deduplicate_blocks.rs190
1 files changed, 190 insertions, 0 deletions
diff --git a/compiler/rustc_mir_transform/src/deduplicate_blocks.rs b/compiler/rustc_mir_transform/src/deduplicate_blocks.rs
new file mode 100644
index 000000000..d1977ed49
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/deduplicate_blocks.rs
@@ -0,0 +1,190 @@
+//! This pass finds basic blocks that are completely equal,
+//! and replaces all uses with just one of them.
+
+use std::{collections::hash_map::Entry, hash::Hash, hash::Hasher, iter};
+
+use crate::MirPass;
+
+use rustc_data_structures::fx::FxHashMap;
+use rustc_middle::mir::visit::MutVisitor;
+use rustc_middle::mir::*;
+use rustc_middle::ty::TyCtxt;
+
+use super::simplify::simplify_cfg;
+
+pub struct DeduplicateBlocks;
+
+impl<'tcx> MirPass<'tcx> for DeduplicateBlocks {
+ fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
+ sess.mir_opt_level() >= 4
+ }
+
+ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ debug!("Running DeduplicateBlocks on `{:?}`", body.source);
+ let duplicates = find_duplicates(body);
+ let has_opts_to_apply = !duplicates.is_empty();
+
+ if has_opts_to_apply {
+ let mut opt_applier = OptApplier { tcx, duplicates };
+ opt_applier.visit_body(body);
+ simplify_cfg(tcx, body);
+ }
+ }
+}
+
+struct OptApplier<'tcx> {
+ tcx: TyCtxt<'tcx>,
+ duplicates: FxHashMap<BasicBlock, BasicBlock>,
+}
+
+impl<'tcx> MutVisitor<'tcx> for OptApplier<'tcx> {
+ fn tcx(&self) -> TyCtxt<'tcx> {
+ self.tcx
+ }
+
+ fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
+ for target in terminator.successors_mut() {
+ if let Some(replacement) = self.duplicates.get(target) {
+ debug!("SUCCESS: Replacing: `{:?}` with `{:?}`", target, replacement);
+ *target = *replacement;
+ }
+ }
+
+ self.super_terminator(terminator, location);
+ }
+}
+
+fn find_duplicates(body: &Body<'_>) -> FxHashMap<BasicBlock, BasicBlock> {
+ let mut duplicates = FxHashMap::default();
+
+ let bbs_to_go_through =
+ body.basic_blocks().iter_enumerated().filter(|(_, bbd)| !bbd.is_cleanup).count();
+
+ let mut same_hashes =
+ FxHashMap::with_capacity_and_hasher(bbs_to_go_through, Default::default());
+
+ // Go through the basic blocks backwards. This means that in case of duplicates,
+ // we can use the basic block with the highest index as the replacement for all lower ones.
+ // For example, if bb1, bb2 and bb3 are duplicates, we will first insert bb3 in same_hashes.
+ // Then we will see that bb2 is a duplicate of bb3,
+ // and insert bb2 with the replacement bb3 in the duplicates list.
+ // When we see bb1, we see that it is a duplicate of bb3, and therefore insert it in the duplicates list
+ // with replacement bb3.
+ // When the duplicates are removed, we will end up with only bb3.
+ for (bb, bbd) in body.basic_blocks().iter_enumerated().rev().filter(|(_, bbd)| !bbd.is_cleanup)
+ {
+ // Basic blocks can get really big, so to avoid checking for duplicates in basic blocks
+ // that are unlikely to have duplicates, we stop early. The early bail number has been
+ // found experimentally by eprintln while compiling the crates in the rustc-perf suite.
+ if bbd.statements.len() > 10 {
+ continue;
+ }
+
+ let to_hash = BasicBlockHashable { basic_block_data: bbd };
+ let entry = same_hashes.entry(to_hash);
+ match entry {
+ Entry::Occupied(occupied) => {
+ // The basic block was already in the hashmap, which means we have a duplicate
+ let value = *occupied.get();
+ debug!("Inserting {:?} -> {:?}", bb, value);
+ duplicates.try_insert(bb, value).expect("key was already inserted");
+ }
+ Entry::Vacant(vacant) => {
+ vacant.insert(bb);
+ }
+ }
+ }
+
+ duplicates
+}
+
+struct BasicBlockHashable<'tcx, 'a> {
+ basic_block_data: &'a BasicBlockData<'tcx>,
+}
+
+impl Hash for BasicBlockHashable<'_, '_> {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ hash_statements(state, self.basic_block_data.statements.iter());
+ // Note that since we only hash the kind, we lose span information if we deduplicate the blocks
+ self.basic_block_data.terminator().kind.hash(state);
+ }
+}
+
+impl Eq for BasicBlockHashable<'_, '_> {}
+
+impl PartialEq for BasicBlockHashable<'_, '_> {
+ fn eq(&self, other: &Self) -> bool {
+ self.basic_block_data.statements.len() == other.basic_block_data.statements.len()
+ && &self.basic_block_data.terminator().kind == &other.basic_block_data.terminator().kind
+ && iter::zip(&self.basic_block_data.statements, &other.basic_block_data.statements)
+ .all(|(x, y)| statement_eq(&x.kind, &y.kind))
+ }
+}
+
+fn hash_statements<'a, 'tcx, H: Hasher>(
+ hasher: &mut H,
+ iter: impl Iterator<Item = &'a Statement<'tcx>>,
+) where
+ 'tcx: 'a,
+{
+ for stmt in iter {
+ statement_hash(hasher, &stmt.kind);
+ }
+}
+
+fn statement_hash<H: Hasher>(hasher: &mut H, stmt: &StatementKind<'_>) {
+ match stmt {
+ StatementKind::Assign(box (place, rvalue)) => {
+ place.hash(hasher);
+ rvalue_hash(hasher, rvalue)
+ }
+ x => x.hash(hasher),
+ };
+}
+
+fn rvalue_hash<H: Hasher>(hasher: &mut H, rvalue: &Rvalue<'_>) {
+ match rvalue {
+ Rvalue::Use(op) => operand_hash(hasher, op),
+ x => x.hash(hasher),
+ };
+}
+
+fn operand_hash<H: Hasher>(hasher: &mut H, operand: &Operand<'_>) {
+ match operand {
+ Operand::Constant(box Constant { user_ty: _, literal, span: _ }) => literal.hash(hasher),
+ x => x.hash(hasher),
+ };
+}
+
+fn statement_eq<'tcx>(lhs: &StatementKind<'tcx>, rhs: &StatementKind<'tcx>) -> bool {
+ let res = match (lhs, rhs) {
+ (
+ StatementKind::Assign(box (place, rvalue)),
+ StatementKind::Assign(box (place2, rvalue2)),
+ ) => place == place2 && rvalue_eq(rvalue, rvalue2),
+ (x, y) => x == y,
+ };
+ debug!("statement_eq lhs: `{:?}` rhs: `{:?}` result: {:?}", lhs, rhs, res);
+ res
+}
+
+fn rvalue_eq<'tcx>(lhs: &Rvalue<'tcx>, rhs: &Rvalue<'tcx>) -> bool {
+ let res = match (lhs, rhs) {
+ (Rvalue::Use(op1), Rvalue::Use(op2)) => operand_eq(op1, op2),
+ (x, y) => x == y,
+ };
+ debug!("rvalue_eq lhs: `{:?}` rhs: `{:?}` result: {:?}", lhs, rhs, res);
+ res
+}
+
+fn operand_eq<'tcx>(lhs: &Operand<'tcx>, rhs: &Operand<'tcx>) -> bool {
+ let res = match (lhs, rhs) {
+ (
+ Operand::Constant(box Constant { user_ty: _, literal, span: _ }),
+ Operand::Constant(box Constant { user_ty: _, literal: literal2, span: _ }),
+ ) => literal == literal2,
+ (x, y) => x == y,
+ };
+ debug!("operand_eq lhs: `{:?}` rhs: `{:?}` result: {:?}", lhs, rhs, res);
+ res
+}