diff options
Diffstat (limited to 'compiler/rustc_index/src/interval.rs')
-rw-r--r-- | compiler/rustc_index/src/interval.rs | 305 |
1 files changed, 305 insertions, 0 deletions
diff --git a/compiler/rustc_index/src/interval.rs b/compiler/rustc_index/src/interval.rs new file mode 100644 index 000000000..3592fb330 --- /dev/null +++ b/compiler/rustc_index/src/interval.rs @@ -0,0 +1,305 @@ +use std::iter::Step; +use std::marker::PhantomData; +use std::ops::RangeBounds; +use std::ops::{Bound, Range}; + +use crate::vec::Idx; +use crate::vec::IndexVec; +use smallvec::SmallVec; + +#[cfg(test)] +mod tests; + +/// Stores a set of intervals on the indices. +/// +/// The elements in `map` are sorted and non-adjacent, which means +/// the second value of the previous element is *greater* than the +/// first value of the following element. +#[derive(Debug, Clone)] +pub struct IntervalSet<I> { + // Start, end + map: SmallVec<[(u32, u32); 4]>, + domain: usize, + _data: PhantomData<I>, +} + +#[inline] +fn inclusive_start<T: Idx>(range: impl RangeBounds<T>) -> u32 { + match range.start_bound() { + Bound::Included(start) => start.index() as u32, + Bound::Excluded(start) => start.index() as u32 + 1, + Bound::Unbounded => 0, + } +} + +#[inline] +fn inclusive_end<T: Idx>(domain: usize, range: impl RangeBounds<T>) -> Option<u32> { + let end = match range.end_bound() { + Bound::Included(end) => end.index() as u32, + Bound::Excluded(end) => end.index().checked_sub(1)? as u32, + Bound::Unbounded => domain.checked_sub(1)? as u32, + }; + Some(end) +} + +impl<I: Idx> IntervalSet<I> { + pub fn new(domain: usize) -> IntervalSet<I> { + IntervalSet { map: SmallVec::new(), domain, _data: PhantomData } + } + + pub fn clear(&mut self) { + self.map.clear(); + } + + pub fn iter(&self) -> impl Iterator<Item = I> + '_ + where + I: Step, + { + self.iter_intervals().flatten() + } + + /// Iterates through intervals stored in the set, in order. + pub fn iter_intervals(&self) -> impl Iterator<Item = std::ops::Range<I>> + '_ + where + I: Step, + { + self.map.iter().map(|&(start, end)| I::new(start as usize)..I::new(end as usize + 1)) + } + + /// Returns true if we increased the number of elements present. + pub fn insert(&mut self, point: I) -> bool { + self.insert_range(point..=point) + } + + /// Returns true if we increased the number of elements present. + pub fn insert_range(&mut self, range: impl RangeBounds<I> + Clone) -> bool { + let start = inclusive_start(range.clone()); + let Some(end) = inclusive_end(self.domain, range) else { + // empty range + return false; + }; + if start > end { + return false; + } + + // This condition looks a bit weird, but actually makes sense. + // + // if r.0 == end + 1, then we're actually adjacent, so we want to + // continue to the next range. We're looking here for the first + // range which starts *non-adjacently* to our end. + let next = self.map.partition_point(|r| r.0 <= end + 1); + let result = if let Some(right) = next.checked_sub(1) { + let (prev_start, prev_end) = self.map[right]; + if prev_end + 1 >= start { + // If the start for the inserted range is adjacent to the + // end of the previous, we can extend the previous range. + if start < prev_start { + // The first range which ends *non-adjacently* to our start. + // And we can ensure that left <= right. + let left = self.map.partition_point(|l| l.1 + 1 < start); + let min = std::cmp::min(self.map[left].0, start); + let max = std::cmp::max(prev_end, end); + self.map[right] = (min, max); + if left != right { + self.map.drain(left..right); + } + true + } else { + // We overlap with the previous range, increase it to + // include us. + // + // Make sure we're actually going to *increase* it though -- + // it may be that end is just inside the previously existing + // set. + if end > prev_end { + self.map[right].1 = end; + true + } else { + false + } + } + } else { + // Otherwise, we don't overlap, so just insert + self.map.insert(right + 1, (start, end)); + true + } + } else { + if self.map.is_empty() { + // Quite common in practice, and expensive to call memcpy + // with length zero. + self.map.push((start, end)); + } else { + self.map.insert(next, (start, end)); + } + true + }; + debug_assert!( + self.check_invariants(), + "wrong intervals after insert {:?}..={:?} to {:?}", + start, + end, + self + ); + result + } + + pub fn contains(&self, needle: I) -> bool { + let needle = needle.index() as u32; + let Some(last) = self.map.partition_point(|r| r.0 <= needle).checked_sub(1) else { + // All ranges in the map start after the new range's end + return false; + }; + let (_, prev_end) = &self.map[last]; + needle <= *prev_end + } + + pub fn superset(&self, other: &IntervalSet<I>) -> bool + where + I: Step, + { + let mut sup_iter = self.iter_intervals(); + let mut current = None; + let contains = |sup: Range<I>, sub: Range<I>, current: &mut Option<Range<I>>| { + if sup.end < sub.start { + // if `sup.end == sub.start`, the next sup doesn't contain `sub.start` + None // continue to the next sup + } else if sup.end >= sub.end && sup.start <= sub.start { + *current = Some(sup); // save the current sup + Some(true) + } else { + Some(false) + } + }; + other.iter_intervals().all(|sub| { + current + .take() + .and_then(|sup| contains(sup, sub.clone(), &mut current)) + .or_else(|| sup_iter.find_map(|sup| contains(sup, sub.clone(), &mut current))) + .unwrap_or(false) + }) + } + + pub fn is_empty(&self) -> bool { + self.map.is_empty() + } + + /// Returns the maximum (last) element present in the set from `range`. + pub fn last_set_in(&self, range: impl RangeBounds<I> + Clone) -> Option<I> { + let start = inclusive_start(range.clone()); + let Some(end) = inclusive_end(self.domain, range) else { + // empty range + return None; + }; + if start > end { + return None; + } + let Some(last) = self.map.partition_point(|r| r.0 <= end).checked_sub(1) else { + // All ranges in the map start after the new range's end + return None; + }; + let (_, prev_end) = &self.map[last]; + if start <= *prev_end { Some(I::new(std::cmp::min(*prev_end, end) as usize)) } else { None } + } + + pub fn insert_all(&mut self) { + self.clear(); + if let Some(end) = self.domain.checked_sub(1) { + self.map.push((0, end.try_into().unwrap())); + } + debug_assert!(self.check_invariants()); + } + + pub fn union(&mut self, other: &IntervalSet<I>) -> bool + where + I: Step, + { + assert_eq!(self.domain, other.domain); + let mut did_insert = false; + for range in other.iter_intervals() { + did_insert |= self.insert_range(range); + } + debug_assert!(self.check_invariants()); + did_insert + } + + // Check the intervals are valid, sorted and non-adjacent + fn check_invariants(&self) -> bool { + let mut current: Option<u32> = None; + for (start, end) in &self.map { + if start > end || current.map_or(false, |x| x + 1 >= *start) { + return false; + } + current = Some(*end); + } + current.map_or(true, |x| x < self.domain as u32) + } +} + +/// This data structure optimizes for cases where the stored bits in each row +/// are expected to be highly contiguous (long ranges of 1s or 0s), in contrast +/// to BitMatrix and SparseBitMatrix which are optimized for +/// "random"/non-contiguous bits and cheap(er) point queries at the expense of +/// memory usage. +#[derive(Clone)] +pub struct SparseIntervalMatrix<R, C> +where + R: Idx, + C: Idx, +{ + rows: IndexVec<R, IntervalSet<C>>, + column_size: usize, +} + +impl<R: Idx, C: Step + Idx> SparseIntervalMatrix<R, C> { + pub fn new(column_size: usize) -> SparseIntervalMatrix<R, C> { + SparseIntervalMatrix { rows: IndexVec::new(), column_size } + } + + pub fn rows(&self) -> impl Iterator<Item = R> { + self.rows.indices() + } + + pub fn row(&self, row: R) -> Option<&IntervalSet<C>> { + self.rows.get(row) + } + + fn ensure_row(&mut self, row: R) -> &mut IntervalSet<C> { + self.rows.ensure_contains_elem(row, || IntervalSet::new(self.column_size)); + &mut self.rows[row] + } + + pub fn union_row(&mut self, row: R, from: &IntervalSet<C>) -> bool + where + C: Step, + { + self.ensure_row(row).union(from) + } + + pub fn union_rows(&mut self, read: R, write: R) -> bool + where + C: Step, + { + if read == write || self.rows.get(read).is_none() { + return false; + } + self.ensure_row(write); + let (read_row, write_row) = self.rows.pick2_mut(read, write); + write_row.union(read_row) + } + + pub fn insert_all_into_row(&mut self, row: R) { + self.ensure_row(row).insert_all(); + } + + pub fn insert_range(&mut self, row: R, range: impl RangeBounds<C> + Clone) { + self.ensure_row(row).insert_range(range); + } + + pub fn insert(&mut self, row: R, point: C) -> bool { + self.ensure_row(row).insert(point) + } + + pub fn contains(&self, row: R, point: C) -> bool { + self.row(row).map_or(false, |r| r.contains(point)) + } +} |