summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_mir_transform/src/generator.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_mir_transform/src/generator.rs')
-rw-r--r--compiler/rustc_mir_transform/src/generator.rs126
1 files changed, 113 insertions, 13 deletions
diff --git a/compiler/rustc_mir_transform/src/generator.rs b/compiler/rustc_mir_transform/src/generator.rs
index 69f96fe48..39c61a34a 100644
--- a/compiler/rustc_mir_transform/src/generator.rs
+++ b/compiler/rustc_mir_transform/src/generator.rs
@@ -460,6 +460,104 @@ fn replace_local<'tcx>(
new_local
}
+/// Transforms the `body` of the generator applying the following transforms:
+///
+/// - Eliminates all the `get_context` calls that async lowering created.
+/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
+///
+/// The `Local`s that have their types replaced are:
+/// - The `resume` argument itself.
+/// - The argument to `get_context`.
+/// - The yielded value of a `yield`.
+///
+/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
+/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
+///
+/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
+/// but rather directly use `&mut Context<'_>`, however that would currently
+/// lead to higher-kinded lifetime errors.
+/// See <https://github.com/rust-lang/rust/issues/105501>.
+///
+/// The async lowering step and the type / lifetime inference / checking are
+/// still using the `ResumeTy` indirection for the time being, and that indirection
+/// is removed here. After this transform, the generator body only knows about `&mut Context<'_>`.
+fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+ let context_mut_ref = tcx.mk_task_context();
+
+ // replace the type of the `resume` argument
+ replace_resume_ty_local(tcx, body, Local::new(2), context_mut_ref);
+
+ let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None);
+
+ for bb in BasicBlock::new(0)..body.basic_blocks.next_index() {
+ let bb_data = &body[bb];
+ if bb_data.is_cleanup {
+ continue;
+ }
+
+ match &bb_data.terminator().kind {
+ TerminatorKind::Call { func, .. } => {
+ let func_ty = func.ty(body, tcx);
+ if let ty::FnDef(def_id, _) = *func_ty.kind() {
+ if def_id == get_context_def_id {
+ let local = eliminate_get_context_call(&mut body[bb]);
+ replace_resume_ty_local(tcx, body, local, context_mut_ref);
+ }
+ } else {
+ continue;
+ }
+ }
+ TerminatorKind::Yield { resume_arg, .. } => {
+ replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
+ }
+ _ => {}
+ }
+ }
+}
+
+fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
+ let terminator = bb_data.terminator.take().unwrap();
+ if let TerminatorKind::Call { mut args, destination, target, .. } = terminator.kind {
+ let arg = args.pop().unwrap();
+ let local = arg.place().unwrap().local;
+
+ let arg = Rvalue::Use(arg);
+ let assign = Statement {
+ source_info: terminator.source_info,
+ kind: StatementKind::Assign(Box::new((destination, arg))),
+ };
+ bb_data.statements.push(assign);
+ bb_data.terminator = Some(Terminator {
+ source_info: terminator.source_info,
+ kind: TerminatorKind::Goto { target: target.unwrap() },
+ });
+ local
+ } else {
+ bug!();
+ }
+}
+
+#[cfg_attr(not(debug_assertions), allow(unused))]
+fn replace_resume_ty_local<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ body: &mut Body<'tcx>,
+ local: Local,
+ context_mut_ref: Ty<'tcx>,
+) {
+ let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
+ // We have to replace the `ResumeTy` that is used for type and borrow checking
+ // with `&mut Context<'_>` in MIR.
+ #[cfg(debug_assertions)]
+ {
+ if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
+ let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
+ assert_eq!(*resume_ty_adt, expected_adt);
+ } else {
+ panic!("expected `ResumeTy`, found `{:?}`", local_ty);
+ };
+ }
+}
+
struct LivenessInfo {
/// Which locals are live across any suspension point.
saved_locals: GeneratorSavedLocals,
@@ -490,7 +588,7 @@ fn locals_live_across_suspend_points<'tcx>(
// Calculate when MIR locals have live storage. This gives us an upper bound of their
// lifetimes.
- let mut storage_live = MaybeStorageLive::new(always_live_locals.clone())
+ let mut storage_live = MaybeStorageLive::new(std::borrow::Cow::Borrowed(always_live_locals))
.into_engine(tcx, body_ref)
.iterate_to_fixpoint()
.into_results_cursor(body_ref);
@@ -877,11 +975,7 @@ fn insert_switch<'tcx>(
let (assign, discr) = transform.get_discr(body);
let switch_targets =
SwitchTargets::new(cases.iter().map(|(i, bb)| ((*i) as u128, *bb)), default_block);
- let switch = TerminatorKind::SwitchInt {
- discr: Operand::Move(discr),
- switch_ty: transform.discr_ty,
- targets: switch_targets,
- };
+ let switch = TerminatorKind::SwitchInt { discr: Operand::Move(discr), targets: switch_targets };
let source_info = SourceInfo::outermost(body.span);
body.basic_blocks_mut().raw.insert(
@@ -1287,13 +1381,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
}
};
- let is_async_kind = body.generator_kind().unwrap() != GeneratorKind::Gen;
+ let is_async_kind = matches!(body.generator_kind(), Some(GeneratorKind::Async(_)));
let (state_adt_ref, state_substs) = if is_async_kind {
// Compute Poll<return_ty>
- let state_did = tcx.require_lang_item(LangItem::Poll, None);
- let state_adt_ref = tcx.adt_def(state_did);
- let state_substs = tcx.intern_substs(&[body.return_ty().into()]);
- (state_adt_ref, state_substs)
+ let poll_did = tcx.require_lang_item(LangItem::Poll, None);
+ let poll_adt_ref = tcx.adt_def(poll_did);
+ let poll_substs = tcx.intern_substs(&[body.return_ty().into()]);
+ (poll_adt_ref, poll_substs)
} else {
// Compute GeneratorState<yield_ty, return_ty>
let state_did = tcx.require_lang_item(LangItem::GeneratorState, None);
@@ -1307,13 +1401,19 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
// RETURN_PLACE then is a fresh unused local with type ret_ty.
let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx);
+ // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
+ if is_async_kind {
+ transform_async_context(tcx, body);
+ }
+
// We also replace the resume argument and insert an `Assign`.
// This is needed because the resume argument `_2` might be live across a `yield`, in which
// case there is no `Assign` to it that the transform can turn into a store to the generator
// state. After the yield the slot in the generator state would then be uninitialized.
let resume_local = Local::new(2);
- let new_resume_local =
- replace_local(resume_local, body.local_decls[resume_local].ty, body, tcx);
+ let resume_ty =
+ if is_async_kind { tcx.mk_task_context() } else { body.local_decls[resume_local].ty };
+ let new_resume_local = replace_local(resume_local, resume_ty, body, tcx);
// When first entering the generator, move the resume argument into its new local.
let source_info = SourceInfo::outermost(body.span);