From 4f9fe856a25ab29345b90e7725509e9ee38a37be Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 17 Apr 2024 14:19:41 +0200 Subject: Adding upstream version 1.69.0+dfsg1. Signed-off-by: Daniel Baumann --- compiler/rustc_mir_transform/src/generator.rs | 389 ++++++++++++++++++++++---- 1 file changed, 333 insertions(+), 56 deletions(-) (limited to 'compiler/rustc_mir_transform/src/generator.rs') diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs index 39c61a34a..2e97312ee 100644 --- a/compiler/rustc_mir_transform/src/generator.rs +++ b/compiler/rustc_mir_transform/src/generator.rs @@ -52,9 +52,9 @@ use crate::deref_separator::deref_finder; use crate::simplify; -use crate::util::expand_aggregate; use crate::MirPass; -use rustc_data_structures::fx::FxHashMap; +use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use rustc_errors::pluralize; use rustc_hir as hir; use rustc_hir::lang_items::LangItem; use rustc_hir::GeneratorKind; @@ -70,6 +70,9 @@ use rustc_mir_dataflow::impls::{ }; use rustc_mir_dataflow::storage::always_storage_live_locals; use rustc_mir_dataflow::{self, Analysis}; +use rustc_span::def_id::DefId; +use rustc_span::symbol::sym; +use rustc_span::Span; use rustc_target::abi::VariantIdx; use rustc_target::spec::PanicStrategy; use std::{iter, ops}; @@ -123,7 +126,7 @@ impl<'tcx> MutVisitor<'tcx> for DerefArgVisitor<'tcx> { place, Place { local: SELF_ARG, - projection: self.tcx().intern_place_elems(&[ProjectionElem::Deref]), + projection: self.tcx().mk_place_elems(&[ProjectionElem::Deref]), }, self.tcx, ); @@ -159,10 +162,9 @@ impl<'tcx> MutVisitor<'tcx> for PinArgVisitor<'tcx> { place, Place { local: SELF_ARG, - projection: self.tcx().intern_place_elems(&[ProjectionElem::Field( - Field::new(0), - self.ref_gen_ty, - )]), + projection: self + .tcx() + .mk_place_elems(&[ProjectionElem::Field(Field::new(0), self.ref_gen_ty)]), }, self.tcx, ); @@ -184,7 +186,7 @@ fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtx let mut new_projection = new_base.projection.to_vec(); new_projection.append(&mut place.projection.to_vec()); - place.projection = tcx.intern_place_elems(&new_projection); + place.projection = tcx.mk_place_elems(&new_projection); } const SELF_ARG: Local = Local::from_u32(1); @@ -268,31 +270,26 @@ impl<'tcx> TransformVisitor<'tcx> { assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); // FIXME(swatinem): assert that `val` is indeed unit? - statements.extend(expand_aggregate( - Place::return_place(), - std::iter::empty(), - kind, + statements.push(Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate(Box::new(kind), vec![]), + ))), source_info, - self.tcx, - )); + }); return; } // else: `Poll::Ready(x)`, `GeneratorState::Yielded(x)` or `GeneratorState::Complete(x)` assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1); - let ty = self - .tcx - .bound_type_of(self.state_adt_ref.variant(idx).fields[0].did) - .subst(self.tcx, self.state_substs); - - statements.extend(expand_aggregate( - Place::return_place(), - std::iter::once((val, ty)), - kind, + statements.push(Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate(Box::new(kind), vec![val]), + ))), source_info, - self.tcx, - )); + }); } // Create a Place referencing a generator struct field @@ -302,7 +299,7 @@ impl<'tcx> TransformVisitor<'tcx> { let mut projection = base.projection.to_vec(); projection.push(ProjectionElem::Field(Field::new(idx), ty)); - Place { local: base.local, projection: self.tcx.intern_place_elems(&projection) } + Place { local: base.local, projection: self.tcx.mk_place_elems(&projection) } } // Create a statement which changes the discriminant @@ -429,7 +426,7 @@ fn make_generator_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body let pin_did = tcx.require_lang_item(LangItem::Pin, Some(body.span)); let pin_adt_ref = tcx.adt_def(pin_did); - let substs = tcx.intern_substs(&[ref_gen_ty.into()]); + let substs = tcx.mk_substs(&[ref_gen_ty.into()]); let pin_ref_gen_ty = tcx.mk_adt(pin_adt_ref, substs); // Replace the by ref generator argument @@ -489,7 +486,7 @@ fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None); - for bb in BasicBlock::new(0)..body.basic_blocks.next_index() { + for bb in START_BLOCK..body.basic_blocks.next_index() { let bb_data = &body[bb]; if bb_data.is_cleanup { continue; @@ -854,7 +851,7 @@ fn sanitize_witness<'tcx>( body: &Body<'tcx>, witness: Ty<'tcx>, upvars: Vec>, - saved_locals: &GeneratorSavedLocals, + layout: &GeneratorLayout<'tcx>, ) { let did = body.source.def_id(); let param_env = tcx.param_env(did); @@ -873,31 +870,36 @@ fn sanitize_witness<'tcx>( } }; - for (local, decl) in body.local_decls.iter_enumerated() { - // Ignore locals which are internal or not saved between yields. - if !saved_locals.contains(local) || decl.internal { + let mut mismatches = Vec::new(); + for fty in &layout.field_tys { + if fty.ignore_for_traits { continue; } - let decl_ty = tcx.normalize_erasing_regions(param_env, decl.ty); + let decl_ty = tcx.normalize_erasing_regions(param_env, fty.ty); // Sanity check that typeck knows about the type of locals which are // live across a suspension point if !allowed.contains(&decl_ty) && !allowed_upvars.contains(&decl_ty) { - span_bug!( - body.span, - "Broken MIR: generator contains type {} in MIR, \ - but typeck only knows about {} and {:?}", - decl_ty, - allowed, - allowed_upvars - ); + mismatches.push(decl_ty); } } + + if !mismatches.is_empty() { + span_bug!( + body.span, + "Broken MIR: generator contains type {:?} in MIR, \ + but typeck only knows about {} and {:?}", + mismatches, + allowed, + allowed_upvars + ); + } } fn compute_layout<'tcx>( + tcx: TyCtxt<'tcx>, liveness: LivenessInfo, - body: &mut Body<'tcx>, + body: &Body<'tcx>, ) -> ( FxHashMap, VariantIdx, usize)>, GeneratorLayout<'tcx>, @@ -915,9 +917,33 @@ fn compute_layout<'tcx>( let mut locals = IndexVec::::new(); let mut tys = IndexVec::::new(); for (saved_local, local) in saved_locals.iter_enumerated() { - locals.push(local); - tys.push(body.local_decls[local].ty); debug!("generator saved local {:?} => {:?}", saved_local, local); + + locals.push(local); + let decl = &body.local_decls[local]; + debug!(?decl); + + let ignore_for_traits = if tcx.sess.opts.unstable_opts.drop_tracking_mir { + match decl.local_info { + // Do not include raw pointers created from accessing `static` items, as those could + // well be re-created by another access to the same static. + Some(box LocalInfo::StaticRef { is_thread_local, .. }) => !is_thread_local, + // Fake borrows are only read by fake reads, so do not have any reality in + // post-analysis MIR. + Some(box LocalInfo::FakeBorrow) => true, + _ => false, + } + } else { + // FIXME(#105084) HIR-based drop tracking does not account for all the temporaries that + // MIR building may introduce. This leads to wrongly ignored types, but this is + // necessary for internal consistency and to avoid ICEs. + decl.internal + }; + let decl = + GeneratorSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits }; + debug!(?decl); + + tys.push(decl); } // Leave empty variants for the UNRESUMED, RETURNED, and POISONED states. @@ -947,7 +973,7 @@ fn compute_layout<'tcx>( // just use the first one here. That's fine; fields do not move // around inside generators, so it doesn't matter which variant // index we access them by. - remap.entry(locals[saved_local]).or_insert((tys[saved_local], variant_index, idx)); + remap.entry(locals[saved_local]).or_insert((tys[saved_local].ty, variant_index, idx)); } variant_fields.push(fields); variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]); @@ -957,6 +983,7 @@ fn compute_layout<'tcx>( let layout = GeneratorLayout { field_tys: tys, variant_fields, variant_source_info, storage_conflicts }; + debug!(?layout); (remap, layout, storage_liveness) } @@ -1227,7 +1254,7 @@ fn create_generator_resume_function<'tcx>( use rustc_middle::mir::AssertKind::{ResumedAfterPanic, ResumedAfterReturn}; // Jump to the entry point on the unresumed - cases.insert(0, (UNRESUMED, BasicBlock::new(0))); + cases.insert(0, (UNRESUMED, START_BLOCK)); // Panic when resumed on the returned or poisoned state let generator_kind = body.generator_kind().unwrap(); @@ -1351,6 +1378,42 @@ fn create_cases<'tcx>( .collect() } +#[instrument(level = "debug", skip(tcx), ret)] +pub(crate) fn mir_generator_witnesses<'tcx>( + tcx: TyCtxt<'tcx>, + def_id: DefId, +) -> GeneratorLayout<'tcx> { + assert!(tcx.sess.opts.unstable_opts.drop_tracking_mir); + let def_id = def_id.expect_local(); + + let (body, _) = tcx.mir_promoted(ty::WithOptConstParam::unknown(def_id)); + let body = body.borrow(); + let body = &*body; + + // The first argument is the generator type passed by value + let gen_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; + + // Get the interior types and substs which typeck computed + let movable = match *gen_ty.kind() { + ty::Generator(_, _, movability) => movability == hir::Movability::Movable, + _ => span_bug!(body.span, "unexpected generator type {}", gen_ty), + }; + + // When first entering the generator, move the resume argument into its new local. + let always_live_locals = always_storage_live_locals(&body); + + let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable); + + // Extract locals which are live across suspension point into `layout` + // `remap` gives a mapping from local indices onto generator struct indices + // `storage_liveness` tells us which locals have live storage at suspension points + let (_, generator_layout, _) = compute_layout(tcx, liveness_info, body); + + check_suspend_tys(tcx, &generator_layout, &body); + + generator_layout +} + impl<'tcx> MirPass<'tcx> for StateTransform { fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let Some(yield_ty) = body.yield_ty() else { @@ -1363,14 +1426,14 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // The first argument is the generator type passed by value let gen_ty = body.local_decls.raw[1].ty; - // Get the interior types and substs which typeck computed - let (upvars, interior, discr_ty, movable) = match *gen_ty.kind() { + // Get the discriminant type and substs which typeck computed + let (discr_ty, upvars, interior, movable) = match *gen_ty.kind() { ty::Generator(_, substs, movability) => { let substs = substs.as_generator(); ( - substs.upvar_tys().collect(), - substs.witness(), substs.discr_ty(tcx), + substs.upvar_tys().collect::>(), + substs.witness(), movability == hir::Movability::Movable, ) } @@ -1386,13 +1449,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // Compute Poll let poll_did = tcx.require_lang_item(LangItem::Poll, None); let poll_adt_ref = tcx.adt_def(poll_did); - let poll_substs = tcx.intern_substs(&[body.return_ty().into()]); + let poll_substs = tcx.mk_substs(&[body.return_ty().into()]); (poll_adt_ref, poll_substs) } else { // Compute GeneratorState let state_did = tcx.require_lang_item(LangItem::GeneratorState, None); let state_adt_ref = tcx.adt_def(state_did); - let state_substs = tcx.intern_substs(&[yield_ty.into(), body.return_ty().into()]); + let state_substs = tcx.mk_substs(&[yield_ty.into(), body.return_ty().into()]); (state_adt_ref, state_substs) }; let ret_ty = tcx.mk_adt(state_adt_ref, state_substs); @@ -1417,7 +1480,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // When first entering the generator, move the resume argument into its new local. let source_info = SourceInfo::outermost(body.span); - let stmts = &mut body.basic_blocks_mut()[BasicBlock::new(0)].statements; + let stmts = &mut body.basic_blocks_mut()[START_BLOCK].statements; stmts.insert( 0, Statement { @@ -1434,8 +1497,6 @@ impl<'tcx> MirPass<'tcx> for StateTransform { let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable); - sanitize_witness(tcx, body, interior, upvars, &liveness_info.saved_locals); - if tcx.sess.opts.unstable_opts.validate_mir { let mut vis = EnsureGeneratorFieldAssignmentsNeverAlias { assigned_local: None, @@ -1449,7 +1510,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // Extract locals which are live across suspension point into `layout` // `remap` gives a mapping from local indices onto generator struct indices // `storage_liveness` tells us which locals have live storage at suspension points - let (remap, layout, storage_liveness) = compute_layout(liveness_info, body); + let (remap, layout, storage_liveness) = compute_layout(tcx, liveness_info, body); + + if tcx.sess.opts.unstable_opts.validate_mir + && !tcx.sess.opts.unstable_opts.drop_tracking_mir + { + sanitize_witness(tcx, body, interior, upvars, &layout); + } let can_return = can_return(tcx, body, tcx.param_env(body.source.def_id())); @@ -1583,6 +1650,7 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> { | StatementKind::AscribeUserType(..) | StatementKind::Coverage(..) | StatementKind::Intrinsic(..) + | StatementKind::ConstEvalCounter | StatementKind::Nop => {} } } @@ -1631,3 +1699,212 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> { } } } + +fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &GeneratorLayout<'tcx>, body: &Body<'tcx>) { + let mut linted_tys = FxHashSet::default(); + + // We want a user-facing param-env. + let param_env = tcx.param_env(body.source.def_id()); + + for (variant, yield_source_info) in + layout.variant_fields.iter().zip(&layout.variant_source_info) + { + debug!(?variant); + for &local in variant { + let decl = &layout.field_tys[local]; + debug!(?decl); + + if !decl.ignore_for_traits && linted_tys.insert(decl.ty) { + let Some(hir_id) = decl.source_info.scope.lint_root(&body.source_scopes) else { continue }; + + check_must_not_suspend_ty( + tcx, + decl.ty, + hir_id, + param_env, + SuspendCheckData { + source_span: decl.source_info.span, + yield_span: yield_source_info.span, + plural_len: 1, + ..Default::default() + }, + ); + } + } + } +} + +#[derive(Default)] +struct SuspendCheckData<'a> { + source_span: Span, + yield_span: Span, + descr_pre: &'a str, + descr_post: &'a str, + plural_len: usize, +} + +// Returns whether it emitted a diagnostic or not +// Note that this fn and the proceeding one are based on the code +// for creating must_use diagnostics +// +// Note that this technique was chosen over things like a `Suspend` marker trait +// as it is simpler and has precedent in the compiler +fn check_must_not_suspend_ty<'tcx>( + tcx: TyCtxt<'tcx>, + ty: Ty<'tcx>, + hir_id: hir::HirId, + param_env: ty::ParamEnv<'tcx>, + data: SuspendCheckData<'_>, +) -> bool { + if ty.is_unit() { + return false; + } + + let plural_suffix = pluralize!(data.plural_len); + + debug!("Checking must_not_suspend for {}", ty); + + match *ty.kind() { + ty::Adt(..) if ty.is_box() => { + let boxed_ty = ty.boxed_ty(); + let descr_pre = &format!("{}boxed ", data.descr_pre); + check_must_not_suspend_ty( + tcx, + boxed_ty, + hir_id, + param_env, + SuspendCheckData { descr_pre, ..data }, + ) + } + ty::Adt(def, _) => check_must_not_suspend_def(tcx, def.did(), hir_id, data), + // FIXME: support adding the attribute to TAITs + ty::Alias(ty::Opaque, ty::AliasTy { def_id: def, .. }) => { + let mut has_emitted = false; + for &(predicate, _) in tcx.explicit_item_bounds(def) { + // We only look at the `DefId`, so it is safe to skip the binder here. + if let ty::PredicateKind::Clause(ty::Clause::Trait(ref poly_trait_predicate)) = + predicate.kind().skip_binder() + { + let def_id = poly_trait_predicate.trait_ref.def_id; + let descr_pre = &format!("{}implementer{} of ", data.descr_pre, plural_suffix); + if check_must_not_suspend_def( + tcx, + def_id, + hir_id, + SuspendCheckData { descr_pre, ..data }, + ) { + has_emitted = true; + break; + } + } + } + has_emitted + } + ty::Dynamic(binder, _, _) => { + let mut has_emitted = false; + for predicate in binder.iter() { + if let ty::ExistentialPredicate::Trait(ref trait_ref) = predicate.skip_binder() { + let def_id = trait_ref.def_id; + let descr_post = &format!(" trait object{}{}", plural_suffix, data.descr_post); + if check_must_not_suspend_def( + tcx, + def_id, + hir_id, + SuspendCheckData { descr_post, ..data }, + ) { + has_emitted = true; + break; + } + } + } + has_emitted + } + ty::Tuple(fields) => { + let mut has_emitted = false; + for (i, ty) in fields.iter().enumerate() { + let descr_post = &format!(" in tuple element {i}"); + if check_must_not_suspend_ty( + tcx, + ty, + hir_id, + param_env, + SuspendCheckData { descr_post, ..data }, + ) { + has_emitted = true; + } + } + has_emitted + } + ty::Array(ty, len) => { + let descr_pre = &format!("{}array{} of ", data.descr_pre, plural_suffix); + check_must_not_suspend_ty( + tcx, + ty, + hir_id, + param_env, + SuspendCheckData { + descr_pre, + plural_len: len.try_eval_target_usize(tcx, param_env).unwrap_or(0) as usize + 1, + ..data + }, + ) + } + // If drop tracking is enabled, we want to look through references, since the referrent + // may not be considered live across the await point. + ty::Ref(_region, ty, _mutability) => { + let descr_pre = &format!("{}reference{} to ", data.descr_pre, plural_suffix); + check_must_not_suspend_ty( + tcx, + ty, + hir_id, + param_env, + SuspendCheckData { descr_pre, ..data }, + ) + } + _ => false, + } +} + +fn check_must_not_suspend_def( + tcx: TyCtxt<'_>, + def_id: DefId, + hir_id: hir::HirId, + data: SuspendCheckData<'_>, +) -> bool { + if let Some(attr) = tcx.get_attr(def_id, sym::must_not_suspend) { + let msg = format!( + "{}`{}`{} held across a suspend point, but should not be", + data.descr_pre, + tcx.def_path_str(def_id), + data.descr_post, + ); + tcx.struct_span_lint_hir( + rustc_session::lint::builtin::MUST_NOT_SUSPEND, + hir_id, + data.source_span, + msg, + |lint| { + // add span pointing to the offending yield/await + lint.span_label(data.yield_span, "the value is held across this suspend point"); + + // Add optional reason note + if let Some(note) = attr.value_str() { + // FIXME(guswynn): consider formatting this better + lint.span_note(data.source_span, note.as_str()); + } + + // Add some quick suggestions on what to do + // FIXME: can `drop` work as a suggestion here as well? + lint.span_help( + data.source_span, + "consider using a block (`{ ... }`) \ + to shrink the value's scope, ending before the suspend point", + ) + }, + ); + + true + } else { + false + } +} -- cgit v1.2.3