summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_const_eval/src/util
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_const_eval/src/util')
-rw-r--r--compiler/rustc_const_eval/src/util/aggregate.rs77
-rw-r--r--compiler/rustc_const_eval/src/util/alignment.rs63
-rw-r--r--compiler/rustc_const_eval/src/util/call_kind.rs146
-rw-r--r--compiler/rustc_const_eval/src/util/collect_writes.rs36
-rw-r--r--compiler/rustc_const_eval/src/util/find_self_call.rs36
-rw-r--r--compiler/rustc_const_eval/src/util/mod.rs10
6 files changed, 368 insertions, 0 deletions
diff --git a/compiler/rustc_const_eval/src/util/aggregate.rs b/compiler/rustc_const_eval/src/util/aggregate.rs
new file mode 100644
index 000000000..180a40043
--- /dev/null
+++ b/compiler/rustc_const_eval/src/util/aggregate.rs
@@ -0,0 +1,77 @@
+use rustc_index::vec::Idx;
+use rustc_middle::mir::*;
+use rustc_middle::ty::{Ty, TyCtxt};
+use rustc_target::abi::VariantIdx;
+
+use std::convert::TryFrom;
+use std::iter::TrustedLen;
+
+/// Expand `lhs = Rvalue::Aggregate(kind, operands)` into assignments to the fields.
+///
+/// Produces something like
+///
+/// (lhs as Variant).field0 = arg0; // We only have a downcast if this is an enum
+/// (lhs as Variant).field1 = arg1;
+/// discriminant(lhs) = variant_index; // If lhs is an enum or generator.
+pub fn expand_aggregate<'tcx>(
+ orig_lhs: Place<'tcx>,
+ operands: impl Iterator<Item = (Operand<'tcx>, Ty<'tcx>)> + TrustedLen,
+ kind: AggregateKind<'tcx>,
+ source_info: SourceInfo,
+ tcx: TyCtxt<'tcx>,
+) -> impl Iterator<Item = Statement<'tcx>> + TrustedLen {
+ let mut lhs = orig_lhs;
+ let mut set_discriminant = None;
+ let active_field_index = match kind {
+ AggregateKind::Adt(adt_did, variant_index, _, _, active_field_index) => {
+ let adt_def = tcx.adt_def(adt_did);
+ if adt_def.is_enum() {
+ set_discriminant = Some(Statement {
+ kind: StatementKind::SetDiscriminant {
+ place: Box::new(orig_lhs),
+ variant_index,
+ },
+ source_info,
+ });
+ lhs = tcx.mk_place_downcast(orig_lhs, adt_def, variant_index);
+ }
+ active_field_index
+ }
+ AggregateKind::Generator(..) => {
+ // Right now we only support initializing generators to
+ // variant 0 (Unresumed).
+ let variant_index = VariantIdx::new(0);
+ set_discriminant = Some(Statement {
+ kind: StatementKind::SetDiscriminant { place: Box::new(orig_lhs), variant_index },
+ source_info,
+ });
+
+ // Operands are upvars stored on the base place, so no
+ // downcast is necessary.
+
+ None
+ }
+ _ => None,
+ };
+
+ let operands = operands.enumerate().map(move |(i, (op, ty))| {
+ let lhs_field = if let AggregateKind::Array(_) = kind {
+ let offset = u64::try_from(i).unwrap();
+ tcx.mk_place_elem(
+ lhs,
+ ProjectionElem::ConstantIndex { offset, min_length: offset + 1, from_end: false },
+ )
+ } else {
+ let field = Field::new(active_field_index.unwrap_or(i));
+ tcx.mk_place_field(lhs, field, ty)
+ };
+ Statement {
+ source_info,
+ kind: StatementKind::Assign(Box::new((lhs_field, Rvalue::Use(op)))),
+ }
+ });
+ [Statement { source_info, kind: StatementKind::Deinit(Box::new(orig_lhs)) }]
+ .into_iter()
+ .chain(operands)
+ .chain(set_discriminant)
+}
diff --git a/compiler/rustc_const_eval/src/util/alignment.rs b/compiler/rustc_const_eval/src/util/alignment.rs
new file mode 100644
index 000000000..4f39dad20
--- /dev/null
+++ b/compiler/rustc_const_eval/src/util/alignment.rs
@@ -0,0 +1,63 @@
+use rustc_middle::mir::*;
+use rustc_middle::ty::{self, TyCtxt};
+use rustc_target::abi::Align;
+
+/// Returns `true` if this place is allowed to be less aligned
+/// than its containing struct (because it is within a packed
+/// struct).
+pub fn is_disaligned<'tcx, L>(
+ tcx: TyCtxt<'tcx>,
+ local_decls: &L,
+ param_env: ty::ParamEnv<'tcx>,
+ place: Place<'tcx>,
+) -> bool
+where
+ L: HasLocalDecls<'tcx>,
+{
+ debug!("is_disaligned({:?})", place);
+ let Some(pack) = is_within_packed(tcx, local_decls, place) else {
+ debug!("is_disaligned({:?}) - not within packed", place);
+ return false;
+ };
+
+ let ty = place.ty(local_decls, tcx).ty;
+ match tcx.layout_of(param_env.and(ty)) {
+ Ok(layout) if layout.align.abi <= pack => {
+ // If the packed alignment is greater or equal to the field alignment, the type won't be
+ // further disaligned.
+ debug!(
+ "is_disaligned({:?}) - align = {}, packed = {}; not disaligned",
+ place,
+ layout.align.abi.bytes(),
+ pack.bytes()
+ );
+ false
+ }
+ _ => {
+ debug!("is_disaligned({:?}) - true", place);
+ true
+ }
+ }
+}
+
+fn is_within_packed<'tcx, L>(
+ tcx: TyCtxt<'tcx>,
+ local_decls: &L,
+ place: Place<'tcx>,
+) -> Option<Align>
+where
+ L: HasLocalDecls<'tcx>,
+{
+ place
+ .iter_projections()
+ .rev()
+ // Stop at `Deref`; standard ABI alignment applies there.
+ .take_while(|(_base, elem)| !matches!(elem, ProjectionElem::Deref))
+ // Consider the packed alignments at play here...
+ .filter_map(|(base, _elem)| {
+ base.ty(local_decls, tcx).ty.ty_adt_def().and_then(|adt| adt.repr().pack)
+ })
+ // ... and compute their minimum.
+ // The overall smallest alignment is what matters.
+ .min()
+}
diff --git a/compiler/rustc_const_eval/src/util/call_kind.rs b/compiler/rustc_const_eval/src/util/call_kind.rs
new file mode 100644
index 000000000..af9d83f06
--- /dev/null
+++ b/compiler/rustc_const_eval/src/util/call_kind.rs
@@ -0,0 +1,146 @@
+//! Common logic for borrowck use-after-move errors when moved into a `fn(self)`,
+//! as well as errors when attempting to call a non-const function in a const
+//! context.
+
+use rustc_hir::def_id::DefId;
+use rustc_hir::lang_items::LangItemGroup;
+use rustc_middle::ty::subst::SubstsRef;
+use rustc_middle::ty::{self, AssocItemContainer, DefIdTree, Instance, ParamEnv, Ty, TyCtxt};
+use rustc_span::symbol::Ident;
+use rustc_span::{sym, DesugaringKind, Span};
+
+#[derive(Clone, Copy, PartialEq, Eq, Debug)]
+pub enum CallDesugaringKind {
+ /// for _ in x {} calls x.into_iter()
+ ForLoopIntoIter,
+ /// x? calls x.branch()
+ QuestionBranch,
+ /// x? calls type_of(x)::from_residual()
+ QuestionFromResidual,
+ /// try { ..; x } calls type_of(x)::from_output(x)
+ TryBlockFromOutput,
+}
+
+impl CallDesugaringKind {
+ pub fn trait_def_id(self, tcx: TyCtxt<'_>) -> DefId {
+ match self {
+ Self::ForLoopIntoIter => tcx.get_diagnostic_item(sym::IntoIterator).unwrap(),
+ Self::QuestionBranch | Self::TryBlockFromOutput => {
+ tcx.lang_items().try_trait().unwrap()
+ }
+ Self::QuestionFromResidual => tcx.get_diagnostic_item(sym::FromResidual).unwrap(),
+ }
+ }
+}
+
+#[derive(Clone, Copy, PartialEq, Eq, Debug)]
+pub enum CallKind<'tcx> {
+ /// A normal method call of the form `receiver.foo(a, b, c)`
+ Normal {
+ self_arg: Option<Ident>,
+ desugaring: Option<(CallDesugaringKind, Ty<'tcx>)>,
+ /// Whether the self type of the method call has an `.as_ref()` method.
+ /// Used for better diagnostics.
+ is_option_or_result: bool,
+ },
+ /// A call to `Fn(..)::call(..)`, desugared from `my_closure(a, b, c)`
+ FnCall { fn_trait_id: DefId, self_ty: Ty<'tcx> },
+ /// A call to an operator trait, desugared from operator syntax (e.g. `a << b`)
+ Operator { self_arg: Option<Ident>, trait_id: DefId, self_ty: Ty<'tcx> },
+ DerefCoercion {
+ /// The `Span` of the `Target` associated type
+ /// in the `Deref` impl we are using.
+ deref_target: Span,
+ /// The type `T::Deref` we are dereferencing to
+ deref_target_ty: Ty<'tcx>,
+ self_ty: Ty<'tcx>,
+ },
+}
+
+pub fn call_kind<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ param_env: ParamEnv<'tcx>,
+ method_did: DefId,
+ method_substs: SubstsRef<'tcx>,
+ fn_call_span: Span,
+ from_hir_call: bool,
+ self_arg: Option<Ident>,
+) -> CallKind<'tcx> {
+ let parent = tcx.opt_associated_item(method_did).and_then(|assoc| {
+ let container_id = assoc.container_id(tcx);
+ match assoc.container {
+ AssocItemContainer::ImplContainer => tcx.trait_id_of_impl(container_id),
+ AssocItemContainer::TraitContainer => Some(container_id),
+ }
+ });
+
+ let fn_call = parent
+ .and_then(|p| tcx.lang_items().group(LangItemGroup::Fn).iter().find(|did| **did == p));
+
+ let operator = (!from_hir_call)
+ .then(|| parent)
+ .flatten()
+ .and_then(|p| tcx.lang_items().group(LangItemGroup::Op).iter().find(|did| **did == p));
+
+ let is_deref = !from_hir_call && tcx.is_diagnostic_item(sym::deref_method, method_did);
+
+ // Check for a 'special' use of 'self' -
+ // an FnOnce call, an operator (e.g. `<<`), or a
+ // deref coercion.
+ let kind = if let Some(&trait_id) = fn_call {
+ Some(CallKind::FnCall { fn_trait_id: trait_id, self_ty: method_substs.type_at(0) })
+ } else if let Some(&trait_id) = operator {
+ Some(CallKind::Operator { self_arg, trait_id, self_ty: method_substs.type_at(0) })
+ } else if is_deref {
+ let deref_target = tcx.get_diagnostic_item(sym::deref_target).and_then(|deref_target| {
+ Instance::resolve(tcx, param_env, deref_target, method_substs).transpose()
+ });
+ if let Some(Ok(instance)) = deref_target {
+ let deref_target_ty = instance.ty(tcx, param_env);
+ Some(CallKind::DerefCoercion {
+ deref_target: tcx.def_span(instance.def_id()),
+ deref_target_ty,
+ self_ty: method_substs.type_at(0),
+ })
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+
+ kind.unwrap_or_else(|| {
+ // This isn't a 'special' use of `self`
+ debug!(?method_did, ?fn_call_span);
+ let desugaring = if Some(method_did) == tcx.lang_items().into_iter_fn()
+ && fn_call_span.desugaring_kind() == Some(DesugaringKind::ForLoop)
+ {
+ Some((CallDesugaringKind::ForLoopIntoIter, method_substs.type_at(0)))
+ } else if fn_call_span.desugaring_kind() == Some(DesugaringKind::QuestionMark) {
+ if Some(method_did) == tcx.lang_items().branch_fn() {
+ Some((CallDesugaringKind::QuestionBranch, method_substs.type_at(0)))
+ } else if Some(method_did) == tcx.lang_items().from_residual_fn() {
+ Some((CallDesugaringKind::QuestionFromResidual, method_substs.type_at(0)))
+ } else {
+ None
+ }
+ } else if Some(method_did) == tcx.lang_items().from_output_fn()
+ && fn_call_span.desugaring_kind() == Some(DesugaringKind::TryBlock)
+ {
+ Some((CallDesugaringKind::TryBlockFromOutput, method_substs.type_at(0)))
+ } else {
+ None
+ };
+ let parent_did = tcx.parent(method_did);
+ let parent_self_ty = (tcx.def_kind(parent_did) == rustc_hir::def::DefKind::Impl)
+ .then_some(parent_did)
+ .and_then(|did| match tcx.type_of(did).kind() {
+ ty::Adt(def, ..) => Some(def.did()),
+ _ => None,
+ });
+ let is_option_or_result = parent_self_ty.map_or(false, |def_id| {
+ matches!(tcx.get_diagnostic_name(def_id), Some(sym::Option | sym::Result))
+ });
+ CallKind::Normal { self_arg, desugaring, is_option_or_result }
+ })
+}
diff --git a/compiler/rustc_const_eval/src/util/collect_writes.rs b/compiler/rustc_const_eval/src/util/collect_writes.rs
new file mode 100644
index 000000000..8d92bb359
--- /dev/null
+++ b/compiler/rustc_const_eval/src/util/collect_writes.rs
@@ -0,0 +1,36 @@
+use rustc_middle::mir::visit::PlaceContext;
+use rustc_middle::mir::visit::Visitor;
+use rustc_middle::mir::{Body, Local, Location};
+
+pub trait FindAssignments {
+ // Finds all statements that assign directly to local (i.e., X = ...)
+ // and returns their locations.
+ fn find_assignments(&self, local: Local) -> Vec<Location>;
+}
+
+impl<'tcx> FindAssignments for Body<'tcx> {
+ fn find_assignments(&self, local: Local) -> Vec<Location> {
+ let mut visitor = FindLocalAssignmentVisitor { needle: local, locations: vec![] };
+ visitor.visit_body(self);
+ visitor.locations
+ }
+}
+
+// The Visitor walks the MIR to return the assignment statements corresponding
+// to a Local.
+struct FindLocalAssignmentVisitor {
+ needle: Local,
+ locations: Vec<Location>,
+}
+
+impl<'tcx> Visitor<'tcx> for FindLocalAssignmentVisitor {
+ fn visit_local(&mut self, local: Local, place_context: PlaceContext, location: Location) {
+ if self.needle != local {
+ return;
+ }
+
+ if place_context.is_place_assignment() {
+ self.locations.push(location);
+ }
+ }
+}
diff --git a/compiler/rustc_const_eval/src/util/find_self_call.rs b/compiler/rustc_const_eval/src/util/find_self_call.rs
new file mode 100644
index 000000000..33ad128ee
--- /dev/null
+++ b/compiler/rustc_const_eval/src/util/find_self_call.rs
@@ -0,0 +1,36 @@
+use rustc_middle::mir::*;
+use rustc_middle::ty::subst::SubstsRef;
+use rustc_middle::ty::{self, TyCtxt};
+use rustc_span::def_id::DefId;
+
+/// Checks if the specified `local` is used as the `self` parameter of a method call
+/// in the provided `BasicBlock`. If it is, then the `DefId` of the called method is
+/// returned.
+pub fn find_self_call<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ body: &Body<'tcx>,
+ local: Local,
+ block: BasicBlock,
+) -> Option<(DefId, SubstsRef<'tcx>)> {
+ debug!("find_self_call(local={:?}): terminator={:?}", local, &body[block].terminator);
+ if let Some(Terminator { kind: TerminatorKind::Call { func, args, .. }, .. }) =
+ &body[block].terminator
+ {
+ debug!("find_self_call: func={:?}", func);
+ if let Operand::Constant(box Constant { literal, .. }) = func {
+ if let ty::FnDef(def_id, substs) = *literal.ty().kind() {
+ if let Some(ty::AssocItem { fn_has_self_parameter: true, .. }) =
+ tcx.opt_associated_item(def_id)
+ {
+ debug!("find_self_call: args={:?}", args);
+ if let [Operand::Move(self_place) | Operand::Copy(self_place), ..] = **args {
+ if self_place.as_local() == Some(local) {
+ return Some((def_id, substs));
+ }
+ }
+ }
+ }
+ }
+ }
+ None
+}
diff --git a/compiler/rustc_const_eval/src/util/mod.rs b/compiler/rustc_const_eval/src/util/mod.rs
new file mode 100644
index 000000000..a1876bed8
--- /dev/null
+++ b/compiler/rustc_const_eval/src/util/mod.rs
@@ -0,0 +1,10 @@
+pub mod aggregate;
+mod alignment;
+mod call_kind;
+pub mod collect_writes;
+mod find_self_call;
+
+pub use self::aggregate::expand_aggregate;
+pub use self::alignment::is_disaligned;
+pub use self::call_kind::{call_kind, CallDesugaringKind, CallKind};
+pub use self::find_self_call::find_self_call;