summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_typeck/src/variance/solve.rs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/rustc_typeck/src/variance/solve.rs')
-rw-r--r--compiler/rustc_typeck/src/variance/solve.rs135
1 files changed, 135 insertions, 0 deletions
diff --git a/compiler/rustc_typeck/src/variance/solve.rs b/compiler/rustc_typeck/src/variance/solve.rs
new file mode 100644
index 000000000..97aca621a
--- /dev/null
+++ b/compiler/rustc_typeck/src/variance/solve.rs
@@ -0,0 +1,135 @@
+//! Constraint solving
+//!
+//! The final phase iterates over the constraints, refining the variance
+//! for each inferred until a fixed point is reached. This will be the
+//! optimal solution to the constraints. The final variance for each
+//! inferred is then written into the `variance_map` in the tcx.
+
+use rustc_data_structures::fx::FxHashMap;
+use rustc_hir::def_id::DefId;
+use rustc_middle::ty;
+
+use super::constraints::*;
+use super::terms::VarianceTerm::*;
+use super::terms::*;
+use super::xform::*;
+
+struct SolveContext<'a, 'tcx> {
+ terms_cx: TermsContext<'a, 'tcx>,
+ constraints: Vec<Constraint<'a>>,
+
+ // Maps from an InferredIndex to the inferred value for that variable.
+ solutions: Vec<ty::Variance>,
+}
+
+pub fn solve_constraints<'tcx>(
+ constraints_cx: ConstraintContext<'_, 'tcx>,
+) -> ty::CrateVariancesMap<'tcx> {
+ let ConstraintContext { terms_cx, constraints, .. } = constraints_cx;
+
+ let mut solutions = vec![ty::Bivariant; terms_cx.inferred_terms.len()];
+ for &(id, ref variances) in &terms_cx.lang_items {
+ let InferredIndex(start) = terms_cx.inferred_starts[&id];
+ for (i, &variance) in variances.iter().enumerate() {
+ solutions[start + i] = variance;
+ }
+ }
+
+ let mut solutions_cx = SolveContext { terms_cx, constraints, solutions };
+ solutions_cx.solve();
+ let variances = solutions_cx.create_map();
+
+ ty::CrateVariancesMap { variances }
+}
+
+impl<'a, 'tcx> SolveContext<'a, 'tcx> {
+ fn solve(&mut self) {
+ // Propagate constraints until a fixed point is reached. Note
+ // that the maximum number of iterations is 2C where C is the
+ // number of constraints (each variable can change values at most
+ // twice). Since number of constraints is linear in size of the
+ // input, so is the inference process.
+ let mut changed = true;
+ while changed {
+ changed = false;
+
+ for constraint in &self.constraints {
+ let Constraint { inferred, variance: term } = *constraint;
+ let InferredIndex(inferred) = inferred;
+ let variance = self.evaluate(term);
+ let old_value = self.solutions[inferred];
+ let new_value = glb(variance, old_value);
+ if old_value != new_value {
+ debug!(
+ "updating inferred {} \
+ from {:?} to {:?} due to {:?}",
+ inferred, old_value, new_value, term
+ );
+
+ self.solutions[inferred] = new_value;
+ changed = true;
+ }
+ }
+ }
+ }
+
+ fn enforce_const_invariance(&self, generics: &ty::Generics, variances: &mut [ty::Variance]) {
+ let tcx = self.terms_cx.tcx;
+
+ // Make all const parameters invariant.
+ for param in generics.params.iter() {
+ if let ty::GenericParamDefKind::Const { .. } = param.kind {
+ variances[param.index as usize] = ty::Invariant;
+ }
+ }
+
+ // Make all the const parameters in the parent invariant (recursively).
+ if let Some(def_id) = generics.parent {
+ self.enforce_const_invariance(tcx.generics_of(def_id), variances);
+ }
+ }
+
+ fn create_map(&self) -> FxHashMap<DefId, &'tcx [ty::Variance]> {
+ let tcx = self.terms_cx.tcx;
+
+ let solutions = &self.solutions;
+ self.terms_cx
+ .inferred_starts
+ .iter()
+ .map(|(&def_id, &InferredIndex(start))| {
+ let generics = tcx.generics_of(def_id);
+ let count = generics.count();
+
+ let variances = tcx.arena.alloc_slice(&solutions[start..(start + count)]);
+
+ // Const parameters are always invariant.
+ self.enforce_const_invariance(generics, variances);
+
+ // Functions are permitted to have unused generic parameters: make those invariant.
+ if let ty::FnDef(..) = tcx.type_of(def_id).kind() {
+ for variance in variances.iter_mut() {
+ if *variance == ty::Bivariant {
+ *variance = ty::Invariant;
+ }
+ }
+ }
+
+ (def_id.to_def_id(), &*variances)
+ })
+ .collect()
+ }
+
+ fn evaluate(&self, term: VarianceTermPtr<'a>) -> ty::Variance {
+ match *term {
+ ConstantTerm(v) => v,
+
+ TransformTerm(t1, t2) => {
+ let v1 = self.evaluate(t1);
+ let v2 = self.evaluate(t2);
+ v1.xform(v2)
+ }
+
+ InferredTerm(InferredIndex(index)) => self.solutions[index],
+ }
+ }
+}