summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_infer/src/infer/combine.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_infer/src/infer/combine.rs')
-rw-r--r--compiler/rustc_infer/src/infer/combine.rs233
1 files changed, 114 insertions, 119 deletions
diff --git a/compiler/rustc_infer/src/infer/combine.rs b/compiler/rustc_infer/src/infer/combine.rs
index 72676b718..33292e871 100644
--- a/compiler/rustc_infer/src/infer/combine.rs
+++ b/compiler/rustc_infer/src/infer/combine.rs
@@ -31,13 +31,18 @@ use super::{InferCtxt, MiscVariable, TypeTrace};
use crate::traits::{Obligation, PredicateObligations};
use rustc_data_structures::sso::SsoHashMap;
use rustc_hir::def_id::DefId;
+use rustc_middle::infer::canonical::OriginalQueryValues;
use rustc_middle::infer::unify_key::{ConstVarValue, ConstVariableValue};
use rustc_middle::infer::unify_key::{ConstVariableOrigin, ConstVariableOriginKind};
+use rustc_middle::traits::query::NoSolution;
use rustc_middle::traits::ObligationCause;
use rustc_middle::ty::error::{ExpectedFound, TypeError};
use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation};
use rustc_middle::ty::subst::SubstsRef;
-use rustc_middle::ty::{self, InferConst, Ty, TyCtxt, TypeVisitable};
+use rustc_middle::ty::{
+ self, AliasKind, FallibleTypeFolder, InferConst, ToPredicate, Ty, TyCtxt, TypeFoldable,
+ TypeSuperFoldable, TypeVisitableExt,
+};
use rustc_middle::ty::{IntType, UintType};
use rustc_span::{Span, DUMMY_SP};
@@ -71,7 +76,7 @@ impl<'tcx> InferCtxt<'tcx> {
b: Ty<'tcx>,
) -> RelateResult<'tcx, Ty<'tcx>>
where
- R: TypeRelation<'tcx>,
+ R: ObligationEmittingRelation<'tcx>,
{
let a_is_expected = relation.a_is_expected();
@@ -119,6 +124,15 @@ impl<'tcx> InferCtxt<'tcx> {
Err(TypeError::Sorts(ty::relate::expected_found(relation, a, b)))
}
+ (ty::Alias(AliasKind::Projection, _), _) if self.tcx.trait_solver_next() => {
+ relation.register_type_equate_obligation(a, b);
+ Ok(b)
+ }
+ (_, ty::Alias(AliasKind::Projection, _)) if self.tcx.trait_solver_next() => {
+ relation.register_type_equate_obligation(b, a);
+ Ok(a)
+ }
+
_ => ty::relate::super_relate_tys(relation, a, b),
}
}
@@ -130,7 +144,7 @@ impl<'tcx> InferCtxt<'tcx> {
b: ty::Const<'tcx>,
) -> RelateResult<'tcx, ty::Const<'tcx>>
where
- R: ConstEquateRelation<'tcx>,
+ R: ObligationEmittingRelation<'tcx>,
{
debug!("{}.consts({:?}, {:?})", relation.tag(), a, b);
if a == b {
@@ -140,7 +154,33 @@ impl<'tcx> InferCtxt<'tcx> {
let a = self.shallow_resolve(a);
let b = self.shallow_resolve(b);
- let a_is_expected = relation.a_is_expected();
+ // We should never have to relate the `ty` field on `Const` as it is checked elsewhere that consts have the
+ // correct type for the generic param they are an argument for. However there have been a number of cases
+ // historically where asserting that the types are equal has found bugs in the compiler so this is valuable
+ // to check even if it is a bit nasty impl wise :(
+ //
+ // This probe is probably not strictly necessary but it seems better to be safe and not accidentally find
+ // ourselves with a check to find bugs being required for code to compile because it made inference progress.
+ self.probe(|_| {
+ if a.ty() == b.ty() {
+ return;
+ }
+
+ // We don't have access to trait solving machinery in `rustc_infer` so the logic for determining if the
+ // two const param's types are able to be equal has to go through a canonical query with the actual logic
+ // in `rustc_trait_selection`.
+ let canonical = self.canonicalize_query(
+ (relation.param_env(), a.ty(), b.ty()),
+ &mut OriginalQueryValues::default(),
+ );
+
+ if let Err(NoSolution) = self.tcx.check_tys_might_be_eq(canonical) {
+ self.tcx.sess.delay_span_bug(
+ DUMMY_SP,
+ &format!("cannot relate consts of different types (a={:?}, b={:?})", a, b,),
+ );
+ }
+ });
match (a.kind(), b.kind()) {
(
@@ -158,17 +198,17 @@ impl<'tcx> InferCtxt<'tcx> {
}
(ty::ConstKind::Infer(InferConst::Var(vid)), _) => {
- return self.unify_const_variable(relation.param_env(), vid, b, a_is_expected);
+ return self.unify_const_variable(vid, b);
}
(_, ty::ConstKind::Infer(InferConst::Var(vid))) => {
- return self.unify_const_variable(relation.param_env(), vid, a, !a_is_expected);
+ return self.unify_const_variable(vid, a);
}
(ty::ConstKind::Unevaluated(..), _) if self.tcx.lazy_normalization() => {
// FIXME(#59490): Need to remove the leak check to accommodate
// escaping bound variables here.
if !a.has_escaping_bound_vars() && !b.has_escaping_bound_vars() {
- relation.const_equate_obligation(a, b);
+ relation.register_const_equate_obligation(a, b);
}
return Ok(b);
}
@@ -176,7 +216,7 @@ impl<'tcx> InferCtxt<'tcx> {
// FIXME(#59490): Need to remove the leak check to accommodate
// escaping bound variables here.
if !a.has_escaping_bound_vars() && !b.has_escaping_bound_vars() {
- relation.const_equate_obligation(a, b);
+ relation.register_const_equate_obligation(a, b);
}
return Ok(a);
}
@@ -223,10 +263,8 @@ impl<'tcx> InferCtxt<'tcx> {
#[instrument(level = "debug", skip(self))]
fn unify_const_variable(
&self,
- param_env: ty::ParamEnv<'tcx>,
target_vid: ty::ConstVid<'tcx>,
ct: ty::Const<'tcx>,
- vid_is_expected: bool,
) -> RelateResult<'tcx, ty::Const<'tcx>> {
let (for_universe, span) = {
let mut inner = self.inner.borrow_mut();
@@ -239,8 +277,12 @@ impl<'tcx> InferCtxt<'tcx> {
ConstVariableValue::Unknown { universe } => (universe, var_value.origin.span),
}
};
- let value = ConstInferUnifier { infcx: self, span, param_env, for_universe, target_vid }
- .relate(ct, ct)?;
+ let value = ct.try_fold_with(&mut ConstInferUnifier {
+ infcx: self,
+ span,
+ for_universe,
+ target_vid,
+ })?;
self.inner.borrow_mut().const_unification_table().union_value(
target_vid,
@@ -432,32 +474,18 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> {
Ok(Generalization { ty, needs_wf })
}
- pub fn add_const_equate_obligation(
- &mut self,
- a_is_expected: bool,
- a: ty::Const<'tcx>,
- b: ty::Const<'tcx>,
- ) {
- let predicate = if a_is_expected {
- ty::PredicateKind::ConstEquate(a, b)
- } else {
- ty::PredicateKind::ConstEquate(b, a)
- };
- self.obligations.push(Obligation::new(
- self.tcx(),
- self.trace.cause.clone(),
- self.param_env,
- ty::Binder::dummy(predicate),
- ));
+ pub fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
+ self.obligations.extend(obligations.into_iter());
+ }
+
+ pub fn register_predicates(&mut self, obligations: impl IntoIterator<Item: ToPredicate<'tcx>>) {
+ self.obligations.extend(obligations.into_iter().map(|to_pred| {
+ Obligation::new(self.infcx.tcx, self.trace.cause.clone(), self.param_env, to_pred)
+ }))
}
pub fn mark_ambiguous(&mut self) {
- self.obligations.push(Obligation::new(
- self.tcx(),
- self.trace.cause.clone(),
- self.param_env,
- ty::Binder::dummy(ty::PredicateKind::Ambiguous),
- ));
+ self.register_predicates([ty::Binder::dummy(ty::PredicateKind::Ambiguous)]);
}
}
@@ -702,6 +730,10 @@ impl<'tcx> TypeRelation<'tcx> for Generalizer<'_, 'tcx> {
return Ok(r);
}
+ ty::ReError(_) => {
+ return Ok(r);
+ }
+
ty::RePlaceholder(..)
| ty::ReVar(..)
| ty::ReStatic
@@ -772,11 +804,39 @@ impl<'tcx> TypeRelation<'tcx> for Generalizer<'_, 'tcx> {
}
}
-pub trait ConstEquateRelation<'tcx>: TypeRelation<'tcx> {
+pub trait ObligationEmittingRelation<'tcx>: TypeRelation<'tcx> {
+ /// Register obligations that must hold in order for this relation to hold
+ fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>);
+
+ /// Register predicates that must hold in order for this relation to hold. Uses
+ /// a default obligation cause, [`ObligationEmittingRelation::register_obligations`] should
+ /// be used if control over the obligaton causes is required.
+ fn register_predicates(&mut self, obligations: impl IntoIterator<Item: ToPredicate<'tcx>>);
+
/// Register an obligation that both constants must be equal to each other.
///
/// If they aren't equal then the relation doesn't hold.
- fn const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>);
+ fn register_const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>) {
+ let (a, b) = if self.a_is_expected() { (a, b) } else { (b, a) };
+
+ self.register_predicates([ty::Binder::dummy(if self.tcx().trait_solver_next() {
+ ty::PredicateKind::AliasEq(a.into(), b.into())
+ } else {
+ ty::PredicateKind::ConstEquate(a, b)
+ })]);
+ }
+
+ /// Register an obligation that both types must be equal to each other.
+ ///
+ /// If they aren't equal then the relation doesn't hold.
+ fn register_type_equate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) {
+ let (a, b) = if self.a_is_expected() { (a, b) } else { (b, a) };
+
+ self.register_predicates([ty::Binder::dummy(ty::PredicateKind::AliasEq(
+ a.into(),
+ b.into(),
+ ))]);
+ }
}
fn int_unification_error<'tcx>(
@@ -800,8 +860,6 @@ struct ConstInferUnifier<'cx, 'tcx> {
span: Span,
- param_env: ty::ParamEnv<'tcx>,
-
for_universe: ty::UniverseIndex,
/// The vid of the const variable that is in the process of being
@@ -810,61 +868,15 @@ struct ConstInferUnifier<'cx, 'tcx> {
target_vid: ty::ConstVid<'tcx>,
}
-// We use `TypeRelation` here to propagate `RelateResult` upwards.
-//
-// Both inputs are expected to be the same.
-impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
- fn tcx(&self) -> TyCtxt<'tcx> {
- self.infcx.tcx
- }
-
- fn intercrate(&self) -> bool {
- assert!(!self.infcx.intercrate);
- false
- }
-
- fn param_env(&self) -> ty::ParamEnv<'tcx> {
- self.param_env
- }
-
- fn tag(&self) -> &'static str {
- "ConstInferUnifier"
- }
-
- fn a_is_expected(&self) -> bool {
- true
- }
-
- fn mark_ambiguous(&mut self) {
- bug!()
- }
+impl<'tcx> FallibleTypeFolder<TyCtxt<'tcx>> for ConstInferUnifier<'_, 'tcx> {
+ type Error = TypeError<'tcx>;
- fn relate_with_variance<T: Relate<'tcx>>(
- &mut self,
- _variance: ty::Variance,
- _info: ty::VarianceDiagInfo<'tcx>,
- a: T,
- b: T,
- ) -> RelateResult<'tcx, T> {
- // We don't care about variance here.
- self.relate(a, b)
- }
-
- fn binders<T>(
- &mut self,
- a: ty::Binder<'tcx, T>,
- b: ty::Binder<'tcx, T>,
- ) -> RelateResult<'tcx, ty::Binder<'tcx, T>>
- where
- T: Relate<'tcx>,
- {
- Ok(a.rebind(self.relate(a.skip_binder(), b.skip_binder())?))
+ fn interner(&self) -> TyCtxt<'tcx> {
+ self.infcx.tcx
}
#[instrument(level = "debug", skip(self), ret)]
- fn tys(&mut self, t: Ty<'tcx>, _t: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
- debug_assert_eq!(t, _t);
-
+ fn try_fold_ty(&mut self, t: Ty<'tcx>) -> Result<Ty<'tcx>, TypeError<'tcx>> {
match t.kind() {
&ty::Infer(ty::TyVar(vid)) => {
let vid = self.infcx.inner.borrow_mut().type_variables().root_var(vid);
@@ -872,7 +884,7 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
match probe {
TypeVariableValue::Known { value: u } => {
debug!("ConstOccursChecker: known value {:?}", u);
- self.tys(u, u)
+ u.try_fold_with(self)
}
TypeVariableValue::Unknown { universe } => {
if self.for_universe.can_name(universe) {
@@ -887,27 +899,26 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
.borrow_mut()
.type_variables()
.new_var(self.for_universe, origin);
- Ok(self.tcx().mk_ty_var(new_var_id))
+ Ok(self.interner().mk_ty_var(new_var_id))
}
}
}
ty::Infer(ty::IntVar(_) | ty::FloatVar(_)) => Ok(t),
- _ => relate::super_relate_tys(self, t, t),
+ _ => t.try_super_fold_with(self),
}
}
- fn regions(
+ #[instrument(level = "debug", skip(self), ret)]
+ fn try_fold_region(
&mut self,
r: ty::Region<'tcx>,
- _r: ty::Region<'tcx>,
- ) -> RelateResult<'tcx, ty::Region<'tcx>> {
- debug_assert_eq!(r, _r);
+ ) -> Result<ty::Region<'tcx>, TypeError<'tcx>> {
debug!("ConstInferUnifier: r={:?}", r);
match *r {
// Never make variables for regions bound within the type itself,
// nor for erased regions.
- ty::ReLateBound(..) | ty::ReErased => {
+ ty::ReLateBound(..) | ty::ReErased | ty::ReError(_) => {
return Ok(r);
}
@@ -930,14 +941,8 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
}
}
- #[instrument(level = "debug", skip(self))]
- fn consts(
- &mut self,
- c: ty::Const<'tcx>,
- _c: ty::Const<'tcx>,
- ) -> RelateResult<'tcx, ty::Const<'tcx>> {
- debug_assert_eq!(c, _c);
-
+ #[instrument(level = "debug", skip(self), ret)]
+ fn try_fold_const(&mut self, c: ty::Const<'tcx>) -> Result<ty::Const<'tcx>, TypeError<'tcx>> {
match c.kind() {
ty::ConstKind::Infer(InferConst::Var(vid)) => {
// Check if the current unification would end up
@@ -958,7 +963,7 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
let var_value =
self.infcx.inner.borrow_mut().const_unification_table().probe_value(vid);
match var_value.val {
- ConstVariableValue::Known { value: u } => self.consts(u, u),
+ ConstVariableValue::Known { value: u } => u.try_fold_with(self),
ConstVariableValue::Unknown { universe } => {
if self.for_universe.can_name(universe) {
Ok(c)
@@ -972,22 +977,12 @@ impl<'tcx> TypeRelation<'tcx> for ConstInferUnifier<'_, 'tcx> {
},
},
);
- Ok(self.tcx().mk_const(new_var_id, c.ty()))
+ Ok(self.interner().mk_const(new_var_id, c.ty()))
}
}
}
}
- ty::ConstKind::Unevaluated(ty::UnevaluatedConst { def, substs }) => {
- let substs = self.relate_with_variance(
- ty::Variance::Invariant,
- ty::VarianceDiagInfo::default(),
- substs,
- substs,
- )?;
-
- Ok(self.tcx().mk_const(ty::UnevaluatedConst { def, substs }, c.ty()))
- }
- _ => relate::super_relate_consts(self, c, c),
+ _ => c.try_super_fold_with(self),
}
}
}