summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_trait_selection/src/traits/engine.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_trait_selection/src/traits/engine.rs')
-rw-r--r--compiler/rustc_trait_selection/src/traits/engine.rs108
1 files changed, 89 insertions, 19 deletions
diff --git a/compiler/rustc_trait_selection/src/traits/engine.rs b/compiler/rustc_trait_selection/src/traits/engine.rs
index e0c8deec9..c028e89e4 100644
--- a/compiler/rustc_trait_selection/src/traits/engine.rs
+++ b/compiler/rustc_trait_selection/src/traits/engine.rs
@@ -1,14 +1,21 @@
use std::cell::RefCell;
+use std::fmt::Debug;
use super::TraitEngine;
use super::{ChalkFulfillmentContext, FulfillmentContext};
-use crate::infer::InferCtxtExt;
-use rustc_data_structures::fx::FxHashSet;
+use crate::traits::NormalizeExt;
+use rustc_data_structures::fx::FxIndexSet;
use rustc_hir::def_id::{DefId, LocalDefId};
+use rustc_infer::infer::at::ToTrace;
+use rustc_infer::infer::canonical::{
+ Canonical, CanonicalVarValues, CanonicalizedQueryResponse, QueryResponse,
+};
use rustc_infer::infer::{InferCtxt, InferOk};
+use rustc_infer::traits::query::Fallible;
use rustc_infer::traits::{
FulfillmentError, Obligation, ObligationCause, PredicateObligation, TraitEngineExt as _,
};
+use rustc_middle::arena::ArenaAllocatable;
use rustc_middle::ty::error::TypeError;
use rustc_middle::ty::ToPredicate;
use rustc_middle::ty::TypeFoldable;
@@ -31,7 +38,7 @@ impl<'tcx> TraitEngineExt<'tcx> for dyn TraitEngine<'tcx> {
fn new_in_snapshot(tcx: TyCtxt<'tcx>) -> Box<Self> {
if tcx.sess.opts.unstable_opts.chalk {
- Box::new(ChalkFulfillmentContext::new())
+ Box::new(ChalkFulfillmentContext::new_in_snapshot())
} else {
Box::new(FulfillmentContext::new_in_snapshot())
}
@@ -86,7 +93,7 @@ impl<'a, 'tcx> ObligationCtxt<'a, 'tcx> {
def_id: DefId,
) {
let tcx = self.infcx.tcx;
- let trait_ref = ty::TraitRef { def_id, substs: tcx.mk_substs_trait(ty, &[]) };
+ let trait_ref = tcx.mk_trait_ref(def_id, [ty]);
self.register_obligation(Obligation {
cause,
recursion_depth: 0,
@@ -97,28 +104,75 @@ impl<'a, 'tcx> ObligationCtxt<'a, 'tcx> {
pub fn normalize<T: TypeFoldable<'tcx>>(
&self,
- cause: ObligationCause<'tcx>,
+ cause: &ObligationCause<'tcx>,
param_env: ty::ParamEnv<'tcx>,
value: T,
) -> T {
- let infer_ok = self.infcx.partially_normalize_associated_types_in(cause, param_env, value);
+ let infer_ok = self.infcx.at(&cause, param_env).normalize(value);
self.register_infer_ok_obligations(infer_ok)
}
- pub fn equate_types(
+ /// Makes `expected <: actual`.
+ pub fn eq_exp<T>(
+ &self,
+ cause: &ObligationCause<'tcx>,
+ param_env: ty::ParamEnv<'tcx>,
+ a_is_expected: bool,
+ a: T,
+ b: T,
+ ) -> Result<(), TypeError<'tcx>>
+ where
+ T: ToTrace<'tcx>,
+ {
+ self.infcx
+ .at(cause, param_env)
+ .eq_exp(a_is_expected, a, b)
+ .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
+ }
+
+ pub fn eq<T: ToTrace<'tcx>>(
&self,
cause: &ObligationCause<'tcx>,
param_env: ty::ParamEnv<'tcx>,
- expected: Ty<'tcx>,
- actual: Ty<'tcx>,
+ expected: T,
+ actual: T,
) -> Result<(), TypeError<'tcx>> {
- match self.infcx.at(cause, param_env).eq(expected, actual) {
- Ok(InferOk { obligations, value: () }) => {
- self.register_obligations(obligations);
- Ok(())
- }
- Err(e) => Err(e),
- }
+ self.infcx
+ .at(cause, param_env)
+ .eq(expected, actual)
+ .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
+ }
+
+ /// Checks whether `expected` is a subtype of `actual`: `expected <: actual`.
+ pub fn sub<T: ToTrace<'tcx>>(
+ &self,
+ cause: &ObligationCause<'tcx>,
+ param_env: ty::ParamEnv<'tcx>,
+ expected: T,
+ actual: T,
+ ) -> Result<(), TypeError<'tcx>> {
+ self.infcx
+ .at(cause, param_env)
+ .sup(expected, actual)
+ .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
+ }
+
+ /// Checks whether `expected` is a supertype of `actual`: `expected :> actual`.
+ pub fn sup<T: ToTrace<'tcx>>(
+ &self,
+ cause: &ObligationCause<'tcx>,
+ param_env: ty::ParamEnv<'tcx>,
+ expected: T,
+ actual: T,
+ ) -> Result<(), TypeError<'tcx>> {
+ self.infcx
+ .at(cause, param_env)
+ .sup(expected, actual)
+ .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
+ }
+
+ pub fn select_where_possible(&self) -> Vec<FulfillmentError<'tcx>> {
+ self.engine.borrow_mut().select_where_possible(self.infcx)
}
pub fn select_all_or_error(&self) -> Vec<FulfillmentError<'tcx>> {
@@ -130,10 +184,10 @@ impl<'a, 'tcx> ObligationCtxt<'a, 'tcx> {
param_env: ty::ParamEnv<'tcx>,
span: Span,
def_id: LocalDefId,
- ) -> FxHashSet<Ty<'tcx>> {
+ ) -> FxIndexSet<Ty<'tcx>> {
let tcx = self.infcx.tcx;
let assumed_wf_types = tcx.assumed_wf_types(def_id);
- let mut implied_bounds = FxHashSet::default();
+ let mut implied_bounds = FxIndexSet::default();
let hir_id = tcx.hir().local_def_id_to_hir_id(def_id);
let cause = ObligationCause::misc(span, hir_id);
for ty in assumed_wf_types {
@@ -149,9 +203,25 @@ impl<'a, 'tcx> ObligationCtxt<'a, 'tcx> {
// sound and then uncomment this line again.
// implied_bounds.insert(ty);
- let normalized = self.normalize(cause.clone(), param_env, ty);
+ let normalized = self.normalize(&cause, param_env, ty);
implied_bounds.insert(normalized);
}
implied_bounds
}
+
+ pub fn make_canonicalized_query_response<T>(
+ &self,
+ inference_vars: CanonicalVarValues<'tcx>,
+ answer: T,
+ ) -> Fallible<CanonicalizedQueryResponse<'tcx, T>>
+ where
+ T: Debug + TypeFoldable<'tcx>,
+ Canonical<'tcx, QueryResponse<'tcx, T>>: ArenaAllocatable<'tcx>,
+ {
+ self.infcx.make_canonicalized_query_response(
+ inference_vars,
+ answer,
+ &mut **self.engine.borrow_mut(),
+ )
+ }
}