summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_mir_transform/src/large_enums.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src/large_enums.rs')
-rw-r--r--compiler/rustc_mir_transform/src/large_enums.rs294
1 files changed, 294 insertions, 0 deletions
diff --git a/compiler/rustc_mir_transform/src/large_enums.rs b/compiler/rustc_mir_transform/src/large_enums.rs
new file mode 100644
index 000000000..89e0a007d
--- /dev/null
+++ b/compiler/rustc_mir_transform/src/large_enums.rs
@@ -0,0 +1,294 @@
+use crate::rustc_middle::ty::util::IntTypeExt;
+use crate::MirPass;
+use rustc_data_structures::fx::FxHashMap;
+use rustc_middle::mir::interpret::AllocId;
+use rustc_middle::mir::*;
+use rustc_middle::ty::{self, AdtDef, ParamEnv, Ty, TyCtxt};
+use rustc_session::Session;
+use rustc_target::abi::{HasDataLayout, Size, TagEncoding, Variants};
+
+/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large
+/// enough discrepancy between them.
+///
+/// i.e. If there is are two variants:
+/// ```
+/// enum Example {
+/// Small,
+/// Large([u32; 1024]),
+/// }
+/// ```
+/// Instead of emitting moves of the large variant,
+/// Perform a memcpy instead.
+/// Based off of [this HackMD](https://hackmd.io/@ft4bxUsFT5CEUBmRKYHr7w/rJM8BBPzD).
+///
+/// In summary, what this does is at runtime determine which enum variant is active,
+/// and instead of copying all the bytes of the largest possible variant,
+/// copy only the bytes for the currently active variant.
+pub struct EnumSizeOpt {
+ pub(crate) discrepancy: u64,
+}
+
+impl<'tcx> MirPass<'tcx> for EnumSizeOpt {
+ fn is_enabled(&self, sess: &Session) -> bool {
+ sess.opts.unstable_opts.unsound_mir_opts || sess.mir_opt_level() >= 3
+ }
+ fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ // NOTE: This pass may produce different MIR based on the alignment of the target
+ // platform, but it will still be valid.
+ self.optim(tcx, body);
+ }
+}
+
+impl EnumSizeOpt {
+ fn candidate<'tcx>(
+ &self,
+ tcx: TyCtxt<'tcx>,
+ param_env: ParamEnv<'tcx>,
+ ty: Ty<'tcx>,
+ alloc_cache: &mut FxHashMap<Ty<'tcx>, AllocId>,
+ ) -> Option<(AdtDef<'tcx>, usize, AllocId)> {
+ let adt_def = match ty.kind() {
+ ty::Adt(adt_def, _substs) if adt_def.is_enum() => adt_def,
+ _ => return None,
+ };
+ let layout = tcx.layout_of(param_env.and(ty)).ok()?;
+ let variants = match &layout.variants {
+ Variants::Single { .. } => return None,
+ Variants::Multiple { tag_encoding, .. }
+ if matches!(tag_encoding, TagEncoding::Niche { .. }) =>
+ {
+ return None;
+ }
+ Variants::Multiple { variants, .. } if variants.len() <= 1 => return None,
+ Variants::Multiple { variants, .. } => variants,
+ };
+ let min = variants.iter().map(|v| v.size).min().unwrap();
+ let max = variants.iter().map(|v| v.size).max().unwrap();
+ if max.bytes() - min.bytes() < self.discrepancy {
+ return None;
+ }
+
+ let num_discrs = adt_def.discriminants(tcx).count();
+ if variants.iter_enumerated().any(|(var_idx, _)| {
+ let discr_for_var = adt_def.discriminant_for_variant(tcx, var_idx).val;
+ (discr_for_var > usize::MAX as u128) || (discr_for_var as usize >= num_discrs)
+ }) {
+ return None;
+ }
+ if let Some(alloc_id) = alloc_cache.get(&ty) {
+ return Some((*adt_def, num_discrs, *alloc_id));
+ }
+
+ let data_layout = tcx.data_layout();
+ let ptr_sized_int = data_layout.ptr_sized_integer();
+ let target_bytes = ptr_sized_int.size().bytes() as usize;
+ let mut data = vec![0; target_bytes * num_discrs];
+ macro_rules! encode_store {
+ ($curr_idx: expr, $endian: expr, $bytes: expr) => {
+ let bytes = match $endian {
+ rustc_target::abi::Endian::Little => $bytes.to_le_bytes(),
+ rustc_target::abi::Endian::Big => $bytes.to_be_bytes(),
+ };
+ for (i, b) in bytes.into_iter().enumerate() {
+ data[$curr_idx + i] = b;
+ }
+ };
+ }
+
+ for (var_idx, layout) in variants.iter_enumerated() {
+ let curr_idx =
+ target_bytes * adt_def.discriminant_for_variant(tcx, var_idx).val as usize;
+ let sz = layout.size;
+ match ptr_sized_int {
+ rustc_target::abi::Integer::I32 => {
+ encode_store!(curr_idx, data_layout.endian, sz.bytes() as u32);
+ }
+ rustc_target::abi::Integer::I64 => {
+ encode_store!(curr_idx, data_layout.endian, sz.bytes());
+ }
+ _ => unreachable!(),
+ };
+ }
+ let alloc = interpret::Allocation::from_bytes(
+ data,
+ tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi,
+ Mutability::Not,
+ );
+ let alloc = tcx.create_memory_alloc(tcx.mk_const_alloc(alloc));
+ Some((*adt_def, num_discrs, *alloc_cache.entry(ty).or_insert(alloc)))
+ }
+ fn optim<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ let mut alloc_cache = FxHashMap::default();
+ let body_did = body.source.def_id();
+ let param_env = tcx.param_env(body_did);
+
+ let blocks = body.basic_blocks.as_mut();
+ let local_decls = &mut body.local_decls;
+
+ for bb in blocks {
+ bb.expand_statements(|st| {
+ if let StatementKind::Assign(box (
+ lhs,
+ Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
+ )) = &st.kind
+ {
+ let ty = lhs.ty(local_decls, tcx).ty;
+
+ let source_info = st.source_info;
+ let span = source_info.span;
+
+ let (adt_def, num_variants, alloc_id) =
+ self.candidate(tcx, param_env, ty, &mut alloc_cache)?;
+ let alloc = tcx.global_alloc(alloc_id).unwrap_memory();
+
+ let tmp_ty = tcx.mk_array(tcx.types.usize, num_variants as u64);
+
+ let size_array_local = local_decls.push(LocalDecl::new(tmp_ty, span));
+ let store_live = Statement {
+ source_info,
+ kind: StatementKind::StorageLive(size_array_local),
+ };
+
+ let place = Place::from(size_array_local);
+ let constant_vals = Constant {
+ span,
+ user_ty: None,
+ literal: ConstantKind::Val(
+ interpret::ConstValue::ByRef { alloc, offset: Size::ZERO },
+ tmp_ty,
+ ),
+ };
+ let rval = Rvalue::Use(Operand::Constant(box (constant_vals)));
+
+ let const_assign =
+ Statement { source_info, kind: StatementKind::Assign(box (place, rval)) };
+
+ let discr_place = Place::from(
+ local_decls
+ .push(LocalDecl::new(adt_def.repr().discr_type().to_ty(tcx), span)),
+ );
+
+ let store_discr = Statement {
+ source_info,
+ kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(*rhs))),
+ };
+
+ let discr_cast_place =
+ Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
+
+ let cast_discr = Statement {
+ source_info,
+ kind: StatementKind::Assign(box (
+ discr_cast_place,
+ Rvalue::Cast(
+ CastKind::IntToInt,
+ Operand::Copy(discr_place),
+ tcx.types.usize,
+ ),
+ )),
+ };
+
+ let size_place =
+ Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
+
+ let store_size = Statement {
+ source_info,
+ kind: StatementKind::Assign(box (
+ size_place,
+ Rvalue::Use(Operand::Copy(Place {
+ local: size_array_local,
+ projection: tcx
+ .mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]),
+ })),
+ )),
+ };
+
+ let dst =
+ Place::from(local_decls.push(LocalDecl::new(tcx.mk_mut_ptr(ty), span)));
+
+ let dst_ptr = Statement {
+ source_info,
+ kind: StatementKind::Assign(box (
+ dst,
+ Rvalue::AddressOf(Mutability::Mut, *lhs),
+ )),
+ };
+
+ let dst_cast_ty = tcx.mk_mut_ptr(tcx.types.u8);
+ let dst_cast_place =
+ Place::from(local_decls.push(LocalDecl::new(dst_cast_ty, span)));
+
+ let dst_cast = Statement {
+ source_info,
+ kind: StatementKind::Assign(box (
+ dst_cast_place,
+ Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
+ )),
+ };
+
+ let src =
+ Place::from(local_decls.push(LocalDecl::new(tcx.mk_imm_ptr(ty), span)));
+
+ let src_ptr = Statement {
+ source_info,
+ kind: StatementKind::Assign(box (
+ src,
+ Rvalue::AddressOf(Mutability::Not, *rhs),
+ )),
+ };
+
+ let src_cast_ty = tcx.mk_imm_ptr(tcx.types.u8);
+ let src_cast_place =
+ Place::from(local_decls.push(LocalDecl::new(src_cast_ty, span)));
+
+ let src_cast = Statement {
+ source_info,
+ kind: StatementKind::Assign(box (
+ src_cast_place,
+ Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
+ )),
+ };
+
+ let deinit_old =
+ Statement { source_info, kind: StatementKind::Deinit(box dst) };
+
+ let copy_bytes = Statement {
+ source_info,
+ kind: StatementKind::Intrinsic(
+ box NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
+ src: Operand::Copy(src_cast_place),
+ dst: Operand::Copy(dst_cast_place),
+ count: Operand::Copy(size_place),
+ }),
+ ),
+ };
+
+ let store_dead = Statement {
+ source_info,
+ kind: StatementKind::StorageDead(size_array_local),
+ };
+ let iter = [
+ store_live,
+ const_assign,
+ store_discr,
+ cast_discr,
+ store_size,
+ dst_ptr,
+ dst_cast,
+ src_ptr,
+ src_cast,
+ deinit_old,
+ copy_bytes,
+ store_dead,
+ ]
+ .into_iter();
+
+ st.make_nop();
+ Some(iter)
+ } else {
+ None
+ }
+ });
+ }
+ }
+}