summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_mir_transform/src/generator.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src/generator.rs')
-rw-r--r--compiler/rustc_mir_transform/src/generator.rs389
1 files changed, 333 insertions, 56 deletions
diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs
index 39c61a34a..2e97312ee 100644
--- a/compiler/rustc_mir_transform/src/generator.rs
+++ b/compiler/rustc_mir_transform/src/generator.rs
@@ -52,9 +52,9 @@
use crate::deref_separator::deref_finder;
use crate::simplify;
-use crate::util::expand_aggregate;
use crate::MirPass;
-use rustc_data_structures::fx::FxHashMap;
+use rustc_data_structures::fx::{FxHashMap, FxHashSet};
+use rustc_errors::pluralize;
use rustc_hir as hir;
use rustc_hir::lang_items::LangItem;
use rustc_hir::GeneratorKind;
@@ -70,6 +70,9 @@ use rustc_mir_dataflow::impls::{
};
use rustc_mir_dataflow::storage::always_storage_live_locals;
use rustc_mir_dataflow::{self, Analysis};
+use rustc_span::def_id::DefId;
+use rustc_span::symbol::sym;
+use rustc_span::Span;
use rustc_target::abi::VariantIdx;
use rustc_target::spec::PanicStrategy;
use std::{iter, ops};
@@ -123,7 +126,7 @@ impl<'tcx> MutVisitor<'tcx> for DerefArgVisitor<'tcx> {
place,
Place {
local: SELF_ARG,
- projection: self.tcx().intern_place_elems(&[ProjectionElem::Deref]),
+ projection: self.tcx().mk_place_elems(&[ProjectionElem::Deref]),
},
self.tcx,
);
@@ -159,10 +162,9 @@ impl<'tcx> MutVisitor<'tcx> for PinArgVisitor<'tcx> {
place,
Place {
local: SELF_ARG,
- projection: self.tcx().intern_place_elems(&[ProjectionElem::Field(
- Field::new(0),
- self.ref_gen_ty,
- )]),
+ projection: self
+ .tcx()
+ .mk_place_elems(&[ProjectionElem::Field(Field::new(0), self.ref_gen_ty)]),
},
self.tcx,
);
@@ -184,7 +186,7 @@ fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtx
let mut new_projection = new_base.projection.to_vec();
new_projection.append(&mut place.projection.to_vec());
- place.projection = tcx.intern_place_elems(&new_projection);
+ place.projection = tcx.mk_place_elems(&new_projection);
}
const SELF_ARG: Local = Local::from_u32(1);
@@ -268,31 +270,26 @@ impl<'tcx> TransformVisitor<'tcx> {
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
// FIXME(swatinem): assert that `val` is indeed unit?
- statements.extend(expand_aggregate(
- Place::return_place(),
- std::iter::empty(),
- kind,
+ statements.push(Statement {
+ kind: StatementKind::Assign(Box::new((
+ Place::return_place(),
+ Rvalue::Aggregate(Box::new(kind), vec![]),
+ ))),
source_info,
- self.tcx,
- ));
+ });
return;
}
// else: `Poll::Ready(x)`, `GeneratorState::Yielded(x)` or `GeneratorState::Complete(x)`
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
- let ty = self
- .tcx
- .bound_type_of(self.state_adt_ref.variant(idx).fields[0].did)
- .subst(self.tcx, self.state_substs);
-
- statements.extend(expand_aggregate(
- Place::return_place(),
- std::iter::once((val, ty)),
- kind,
+ statements.push(Statement {
+ kind: StatementKind::Assign(Box::new((
+ Place::return_place(),
+ Rvalue::Aggregate(Box::new(kind), vec![val]),
+ ))),
source_info,
- self.tcx,
- ));
+ });
}
// Create a Place referencing a generator struct field
@@ -302,7 +299,7 @@ impl<'tcx> TransformVisitor<'tcx> {
let mut projection = base.projection.to_vec();
projection.push(ProjectionElem::Field(Field::new(idx), ty));
- Place { local: base.local, projection: self.tcx.intern_place_elems(&projection) }
+ Place { local: base.local, projection: self.tcx.mk_place_elems(&projection) }
}
// Create a statement which changes the discriminant
@@ -429,7 +426,7 @@ fn make_generator_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
let pin_did = tcx.require_lang_item(LangItem::Pin, Some(body.span));
let pin_adt_ref = tcx.adt_def(pin_did);
- let substs = tcx.intern_substs(&[ref_gen_ty.into()]);
+ let substs = tcx.mk_substs(&[ref_gen_ty.into()]);
let pin_ref_gen_ty = tcx.mk_adt(pin_adt_ref, substs);
// Replace the by ref generator argument
@@ -489,7 +486,7 @@ fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None);
- for bb in BasicBlock::new(0)..body.basic_blocks.next_index() {
+ for bb in START_BLOCK..body.basic_blocks.next_index() {
let bb_data = &body[bb];
if bb_data.is_cleanup {
continue;
@@ -854,7 +851,7 @@ fn sanitize_witness<'tcx>(
body: &Body<'tcx>,
witness: Ty<'tcx>,
upvars: Vec<Ty<'tcx>>,
- saved_locals: &GeneratorSavedLocals,
+ layout: &GeneratorLayout<'tcx>,
) {
let did = body.source.def_id();
let param_env = tcx.param_env(did);
@@ -873,31 +870,36 @@ fn sanitize_witness<'tcx>(
}
};
- for (local, decl) in body.local_decls.iter_enumerated() {
- // Ignore locals which are internal or not saved between yields.
- if !saved_locals.contains(local) || decl.internal {
+ let mut mismatches = Vec::new();
+ for fty in &layout.field_tys {
+ if fty.ignore_for_traits {
continue;
}
- let decl_ty = tcx.normalize_erasing_regions(param_env, decl.ty);
+ let decl_ty = tcx.normalize_erasing_regions(param_env, fty.ty);
// Sanity check that typeck knows about the type of locals which are
// live across a suspension point
if !allowed.contains(&decl_ty) && !allowed_upvars.contains(&decl_ty) {
- span_bug!(
- body.span,
- "Broken MIR: generator contains type {} in MIR, \
- but typeck only knows about {} and {:?}",
- decl_ty,
- allowed,
- allowed_upvars
- );
+ mismatches.push(decl_ty);
}
}
+
+ if !mismatches.is_empty() {
+ span_bug!(
+ body.span,
+ "Broken MIR: generator contains type {:?} in MIR, \
+ but typeck only knows about {} and {:?}",
+ mismatches,
+ allowed,
+ allowed_upvars
+ );
+ }
}
fn compute_layout<'tcx>(
+ tcx: TyCtxt<'tcx>,
liveness: LivenessInfo,
- body: &mut Body<'tcx>,
+ body: &Body<'tcx>,
) -> (
FxHashMap<Local, (Ty<'tcx>, VariantIdx, usize)>,
GeneratorLayout<'tcx>,
@@ -915,9 +917,33 @@ fn compute_layout<'tcx>(
let mut locals = IndexVec::<GeneratorSavedLocal, _>::new();
let mut tys = IndexVec::<GeneratorSavedLocal, _>::new();
for (saved_local, local) in saved_locals.iter_enumerated() {
- locals.push(local);
- tys.push(body.local_decls[local].ty);
debug!("generator saved local {:?} => {:?}", saved_local, local);
+
+ locals.push(local);
+ let decl = &body.local_decls[local];
+ debug!(?decl);
+
+ let ignore_for_traits = if tcx.sess.opts.unstable_opts.drop_tracking_mir {
+ match decl.local_info {
+ // Do not include raw pointers created from accessing `static` items, as those could
+ // well be re-created by another access to the same static.
+ Some(box LocalInfo::StaticRef { is_thread_local, .. }) => !is_thread_local,
+ // Fake borrows are only read by fake reads, so do not have any reality in
+ // post-analysis MIR.
+ Some(box LocalInfo::FakeBorrow) => true,
+ _ => false,
+ }
+ } else {
+ // FIXME(#105084) HIR-based drop tracking does not account for all the temporaries that
+ // MIR building may introduce. This leads to wrongly ignored types, but this is
+ // necessary for internal consistency and to avoid ICEs.
+ decl.internal
+ };
+ let decl =
+ GeneratorSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits };
+ debug!(?decl);
+
+ tys.push(decl);
}
// Leave empty variants for the UNRESUMED, RETURNED, and POISONED states.
@@ -947,7 +973,7 @@ fn compute_layout<'tcx>(
// just use the first one here. That's fine; fields do not move
// around inside generators, so it doesn't matter which variant
// index we access them by.
- remap.entry(locals[saved_local]).or_insert((tys[saved_local], variant_index, idx));
+ remap.entry(locals[saved_local]).or_insert((tys[saved_local].ty, variant_index, idx));
}
variant_fields.push(fields);
variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]);
@@ -957,6 +983,7 @@ fn compute_layout<'tcx>(
let layout =
GeneratorLayout { field_tys: tys, variant_fields, variant_source_info, storage_conflicts };
+ debug!(?layout);
(remap, layout, storage_liveness)
}
@@ -1227,7 +1254,7 @@ fn create_generator_resume_function<'tcx>(
use rustc_middle::mir::AssertKind::{ResumedAfterPanic, ResumedAfterReturn};
// Jump to the entry point on the unresumed
- cases.insert(0, (UNRESUMED, BasicBlock::new(0)));
+ cases.insert(0, (UNRESUMED, START_BLOCK));
// Panic when resumed on the returned or poisoned state
let generator_kind = body.generator_kind().unwrap();
@@ -1351,6 +1378,42 @@ fn create_cases<'tcx>(
.collect()
}
+#[instrument(level = "debug", skip(tcx), ret)]
+pub(crate) fn mir_generator_witnesses<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ def_id: DefId,
+) -> GeneratorLayout<'tcx> {
+ assert!(tcx.sess.opts.unstable_opts.drop_tracking_mir);
+ let def_id = def_id.expect_local();
+
+ let (body, _) = tcx.mir_promoted(ty::WithOptConstParam::unknown(def_id));
+ let body = body.borrow();
+ let body = &*body;
+
+ // The first argument is the generator type passed by value
+ let gen_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
+
+ // Get the interior types and substs which typeck computed
+ let movable = match *gen_ty.kind() {
+ ty::Generator(_, _, movability) => movability == hir::Movability::Movable,
+ _ => span_bug!(body.span, "unexpected generator type {}", gen_ty),
+ };
+
+ // When first entering the generator, move the resume argument into its new local.
+ let always_live_locals = always_storage_live_locals(&body);
+
+ let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
+
+ // Extract locals which are live across suspension point into `layout`
+ // `remap` gives a mapping from local indices onto generator struct indices
+ // `storage_liveness` tells us which locals have live storage at suspension points
+ let (_, generator_layout, _) = compute_layout(tcx, liveness_info, body);
+
+ check_suspend_tys(tcx, &generator_layout, &body);
+
+ generator_layout
+}
+
impl<'tcx> MirPass<'tcx> for StateTransform {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let Some(yield_ty) = body.yield_ty() else {
@@ -1363,14 +1426,14 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
// The first argument is the generator type passed by value
let gen_ty = body.local_decls.raw[1].ty;
- // Get the interior types and substs which typeck computed
- let (upvars, interior, discr_ty, movable) = match *gen_ty.kind() {
+ // Get the discriminant type and substs which typeck computed
+ let (discr_ty, upvars, interior, movable) = match *gen_ty.kind() {
ty::Generator(_, substs, movability) => {
let substs = substs.as_generator();
(
- substs.upvar_tys().collect(),
- substs.witness(),
substs.discr_ty(tcx),
+ substs.upvar_tys().collect::<Vec<_>>(),
+ substs.witness(),
movability == hir::Movability::Movable,
)
}
@@ -1386,13 +1449,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
// Compute Poll<return_ty>
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
let poll_adt_ref = tcx.adt_def(poll_did);
- let poll_substs = tcx.intern_substs(&[body.return_ty().into()]);
+ let poll_substs = tcx.mk_substs(&[body.return_ty().into()]);
(poll_adt_ref, poll_substs)
} else {
// Compute GeneratorState<yield_ty, return_ty>
let state_did = tcx.require_lang_item(LangItem::GeneratorState, None);
let state_adt_ref = tcx.adt_def(state_did);
- let state_substs = tcx.intern_substs(&[yield_ty.into(), body.return_ty().into()]);
+ let state_substs = tcx.mk_substs(&[yield_ty.into(), body.return_ty().into()]);
(state_adt_ref, state_substs)
};
let ret_ty = tcx.mk_adt(state_adt_ref, state_substs);
@@ -1417,7 +1480,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
// When first entering the generator, move the resume argument into its new local.
let source_info = SourceInfo::outermost(body.span);
- let stmts = &mut body.basic_blocks_mut()[BasicBlock::new(0)].statements;
+ let stmts = &mut body.basic_blocks_mut()[START_BLOCK].statements;
stmts.insert(
0,
Statement {
@@ -1434,8 +1497,6 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
let liveness_info =
locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
- sanitize_witness(tcx, body, interior, upvars, &liveness_info.saved_locals);
-
if tcx.sess.opts.unstable_opts.validate_mir {
let mut vis = EnsureGeneratorFieldAssignmentsNeverAlias {
assigned_local: None,
@@ -1449,7 +1510,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
// Extract locals which are live across suspension point into `layout`
// `remap` gives a mapping from local indices onto generator struct indices
// `storage_liveness` tells us which locals have live storage at suspension points
- let (remap, layout, storage_liveness) = compute_layout(liveness_info, body);
+ let (remap, layout, storage_liveness) = compute_layout(tcx, liveness_info, body);
+
+ if tcx.sess.opts.unstable_opts.validate_mir
+ && !tcx.sess.opts.unstable_opts.drop_tracking_mir
+ {
+ sanitize_witness(tcx, body, interior, upvars, &layout);
+ }
let can_return = can_return(tcx, body, tcx.param_env(body.source.def_id()));
@@ -1583,6 +1650,7 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> {
| StatementKind::AscribeUserType(..)
| StatementKind::Coverage(..)
| StatementKind::Intrinsic(..)
+ | StatementKind::ConstEvalCounter
| StatementKind::Nop => {}
}
}
@@ -1631,3 +1699,212 @@ impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> {
}
}
}
+
+fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &GeneratorLayout<'tcx>, body: &Body<'tcx>) {
+ let mut linted_tys = FxHashSet::default();
+
+ // We want a user-facing param-env.
+ let param_env = tcx.param_env(body.source.def_id());
+
+ for (variant, yield_source_info) in
+ layout.variant_fields.iter().zip(&layout.variant_source_info)
+ {
+ debug!(?variant);
+ for &local in variant {
+ let decl = &layout.field_tys[local];
+ debug!(?decl);
+
+ if !decl.ignore_for_traits && linted_tys.insert(decl.ty) {
+ let Some(hir_id) = decl.source_info.scope.lint_root(&body.source_scopes) else { continue };
+
+ check_must_not_suspend_ty(
+ tcx,
+ decl.ty,
+ hir_id,
+ param_env,
+ SuspendCheckData {
+ source_span: decl.source_info.span,
+ yield_span: yield_source_info.span,
+ plural_len: 1,
+ ..Default::default()
+ },
+ );
+ }
+ }
+ }
+}
+
+#[derive(Default)]
+struct SuspendCheckData<'a> {
+ source_span: Span,
+ yield_span: Span,
+ descr_pre: &'a str,
+ descr_post: &'a str,
+ plural_len: usize,
+}
+
+// Returns whether it emitted a diagnostic or not
+// Note that this fn and the proceeding one are based on the code
+// for creating must_use diagnostics
+//
+// Note that this technique was chosen over things like a `Suspend` marker trait
+// as it is simpler and has precedent in the compiler
+fn check_must_not_suspend_ty<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ ty: Ty<'tcx>,
+ hir_id: hir::HirId,
+ param_env: ty::ParamEnv<'tcx>,
+ data: SuspendCheckData<'_>,
+) -> bool {
+ if ty.is_unit() {
+ return false;
+ }
+
+ let plural_suffix = pluralize!(data.plural_len);
+
+ debug!("Checking must_not_suspend for {}", ty);
+
+ match *ty.kind() {
+ ty::Adt(..) if ty.is_box() => {
+ let boxed_ty = ty.boxed_ty();
+ let descr_pre = &format!("{}boxed ", data.descr_pre);
+ check_must_not_suspend_ty(
+ tcx,
+ boxed_ty,
+ hir_id,
+ param_env,
+ SuspendCheckData { descr_pre, ..data },
+ )
+ }
+ ty::Adt(def, _) => check_must_not_suspend_def(tcx, def.did(), hir_id, data),
+ // FIXME: support adding the attribute to TAITs
+ ty::Alias(ty::Opaque, ty::AliasTy { def_id: def, .. }) => {
+ let mut has_emitted = false;
+ for &(predicate, _) in tcx.explicit_item_bounds(def) {
+ // We only look at the `DefId`, so it is safe to skip the binder here.
+ if let ty::PredicateKind::Clause(ty::Clause::Trait(ref poly_trait_predicate)) =
+ predicate.kind().skip_binder()
+ {
+ let def_id = poly_trait_predicate.trait_ref.def_id;
+ let descr_pre = &format!("{}implementer{} of ", data.descr_pre, plural_suffix);
+ if check_must_not_suspend_def(
+ tcx,
+ def_id,
+ hir_id,
+ SuspendCheckData { descr_pre, ..data },
+ ) {
+ has_emitted = true;
+ break;
+ }
+ }
+ }
+ has_emitted
+ }
+ ty::Dynamic(binder, _, _) => {
+ let mut has_emitted = false;
+ for predicate in binder.iter() {
+ if let ty::ExistentialPredicate::Trait(ref trait_ref) = predicate.skip_binder() {
+ let def_id = trait_ref.def_id;
+ let descr_post = &format!(" trait object{}{}", plural_suffix, data.descr_post);
+ if check_must_not_suspend_def(
+ tcx,
+ def_id,
+ hir_id,
+ SuspendCheckData { descr_post, ..data },
+ ) {
+ has_emitted = true;
+ break;
+ }
+ }
+ }
+ has_emitted
+ }
+ ty::Tuple(fields) => {
+ let mut has_emitted = false;
+ for (i, ty) in fields.iter().enumerate() {
+ let descr_post = &format!(" in tuple element {i}");
+ if check_must_not_suspend_ty(
+ tcx,
+ ty,
+ hir_id,
+ param_env,
+ SuspendCheckData { descr_post, ..data },
+ ) {
+ has_emitted = true;
+ }
+ }
+ has_emitted
+ }
+ ty::Array(ty, len) => {
+ let descr_pre = &format!("{}array{} of ", data.descr_pre, plural_suffix);
+ check_must_not_suspend_ty(
+ tcx,
+ ty,
+ hir_id,
+ param_env,
+ SuspendCheckData {
+ descr_pre,
+ plural_len: len.try_eval_target_usize(tcx, param_env).unwrap_or(0) as usize + 1,
+ ..data
+ },
+ )
+ }
+ // If drop tracking is enabled, we want to look through references, since the referrent
+ // may not be considered live across the await point.
+ ty::Ref(_region, ty, _mutability) => {
+ let descr_pre = &format!("{}reference{} to ", data.descr_pre, plural_suffix);
+ check_must_not_suspend_ty(
+ tcx,
+ ty,
+ hir_id,
+ param_env,
+ SuspendCheckData { descr_pre, ..data },
+ )
+ }
+ _ => false,
+ }
+}
+
+fn check_must_not_suspend_def(
+ tcx: TyCtxt<'_>,
+ def_id: DefId,
+ hir_id: hir::HirId,
+ data: SuspendCheckData<'_>,
+) -> bool {
+ if let Some(attr) = tcx.get_attr(def_id, sym::must_not_suspend) {
+ let msg = format!(
+ "{}`{}`{} held across a suspend point, but should not be",
+ data.descr_pre,
+ tcx.def_path_str(def_id),
+ data.descr_post,
+ );
+ tcx.struct_span_lint_hir(
+ rustc_session::lint::builtin::MUST_NOT_SUSPEND,
+ hir_id,
+ data.source_span,
+ msg,
+ |lint| {
+ // add span pointing to the offending yield/await
+ lint.span_label(data.yield_span, "the value is held across this suspend point");
+
+ // Add optional reason note
+ if let Some(note) = attr.value_str() {
+ // FIXME(guswynn): consider formatting this better
+ lint.span_note(data.source_span, note.as_str());
+ }
+
+ // Add some quick suggestions on what to do
+ // FIXME: can `drop` work as a suggestion here as well?
+ lint.span_help(
+ data.source_span,
+ "consider using a block (`{ ... }`) \
+ to shrink the value's scope, ending before the suspend point",
+ )
+ },
+ );
+
+ true
+ } else {
+ false
+ }
+}