summaryrefslogtreecommitdiffstats
path: root/compiler/rustc_index/src/interval
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--compiler/rustc_index/src/interval.rs305
-rw-r--r--compiler/rustc_index/src/interval/tests.rs199
2 files changed, 504 insertions, 0 deletions
diff --git a/compiler/rustc_index/src/interval.rs b/compiler/rustc_index/src/interval.rs
new file mode 100644
index 000000000..3592fb330
--- /dev/null
+++ b/compiler/rustc_index/src/interval.rs
@@ -0,0 +1,305 @@
+use std::iter::Step;
+use std::marker::PhantomData;
+use std::ops::RangeBounds;
+use std::ops::{Bound, Range};
+
+use crate::vec::Idx;
+use crate::vec::IndexVec;
+use smallvec::SmallVec;
+
+#[cfg(test)]
+mod tests;
+
+/// Stores a set of intervals on the indices.
+///
+/// The elements in `map` are sorted and non-adjacent, which means
+/// the second value of the previous element is *greater* than the
+/// first value of the following element.
+#[derive(Debug, Clone)]
+pub struct IntervalSet<I> {
+ // Start, end
+ map: SmallVec<[(u32, u32); 4]>,
+ domain: usize,
+ _data: PhantomData<I>,
+}
+
+#[inline]
+fn inclusive_start<T: Idx>(range: impl RangeBounds<T>) -> u32 {
+ match range.start_bound() {
+ Bound::Included(start) => start.index() as u32,
+ Bound::Excluded(start) => start.index() as u32 + 1,
+ Bound::Unbounded => 0,
+ }
+}
+
+#[inline]
+fn inclusive_end<T: Idx>(domain: usize, range: impl RangeBounds<T>) -> Option<u32> {
+ let end = match range.end_bound() {
+ Bound::Included(end) => end.index() as u32,
+ Bound::Excluded(end) => end.index().checked_sub(1)? as u32,
+ Bound::Unbounded => domain.checked_sub(1)? as u32,
+ };
+ Some(end)
+}
+
+impl<I: Idx> IntervalSet<I> {
+ pub fn new(domain: usize) -> IntervalSet<I> {
+ IntervalSet { map: SmallVec::new(), domain, _data: PhantomData }
+ }
+
+ pub fn clear(&mut self) {
+ self.map.clear();
+ }
+
+ pub fn iter(&self) -> impl Iterator<Item = I> + '_
+ where
+ I: Step,
+ {
+ self.iter_intervals().flatten()
+ }
+
+ /// Iterates through intervals stored in the set, in order.
+ pub fn iter_intervals(&self) -> impl Iterator<Item = std::ops::Range<I>> + '_
+ where
+ I: Step,
+ {
+ self.map.iter().map(|&(start, end)| I::new(start as usize)..I::new(end as usize + 1))
+ }
+
+ /// Returns true if we increased the number of elements present.
+ pub fn insert(&mut self, point: I) -> bool {
+ self.insert_range(point..=point)
+ }
+
+ /// Returns true if we increased the number of elements present.
+ pub fn insert_range(&mut self, range: impl RangeBounds<I> + Clone) -> bool {
+ let start = inclusive_start(range.clone());
+ let Some(end) = inclusive_end(self.domain, range) else {
+ // empty range
+ return false;
+ };
+ if start > end {
+ return false;
+ }
+
+ // This condition looks a bit weird, but actually makes sense.
+ //
+ // if r.0 == end + 1, then we're actually adjacent, so we want to
+ // continue to the next range. We're looking here for the first
+ // range which starts *non-adjacently* to our end.
+ let next = self.map.partition_point(|r| r.0 <= end + 1);
+ let result = if let Some(right) = next.checked_sub(1) {
+ let (prev_start, prev_end) = self.map[right];
+ if prev_end + 1 >= start {
+ // If the start for the inserted range is adjacent to the
+ // end of the previous, we can extend the previous range.
+ if start < prev_start {
+ // The first range which ends *non-adjacently* to our start.
+ // And we can ensure that left <= right.
+ let left = self.map.partition_point(|l| l.1 + 1 < start);
+ let min = std::cmp::min(self.map[left].0, start);
+ let max = std::cmp::max(prev_end, end);
+ self.map[right] = (min, max);
+ if left != right {
+ self.map.drain(left..right);
+ }
+ true
+ } else {
+ // We overlap with the previous range, increase it to
+ // include us.
+ //
+ // Make sure we're actually going to *increase* it though --
+ // it may be that end is just inside the previously existing
+ // set.
+ if end > prev_end {
+ self.map[right].1 = end;
+ true
+ } else {
+ false
+ }
+ }
+ } else {
+ // Otherwise, we don't overlap, so just insert
+ self.map.insert(right + 1, (start, end));
+ true
+ }
+ } else {
+ if self.map.is_empty() {
+ // Quite common in practice, and expensive to call memcpy
+ // with length zero.
+ self.map.push((start, end));
+ } else {
+ self.map.insert(next, (start, end));
+ }
+ true
+ };
+ debug_assert!(
+ self.check_invariants(),
+ "wrong intervals after insert {:?}..={:?} to {:?}",
+ start,
+ end,
+ self
+ );
+ result
+ }
+
+ pub fn contains(&self, needle: I) -> bool {
+ let needle = needle.index() as u32;
+ let Some(last) = self.map.partition_point(|r| r.0 <= needle).checked_sub(1) else {
+ // All ranges in the map start after the new range's end
+ return false;
+ };
+ let (_, prev_end) = &self.map[last];
+ needle <= *prev_end
+ }
+
+ pub fn superset(&self, other: &IntervalSet<I>) -> bool
+ where
+ I: Step,
+ {
+ let mut sup_iter = self.iter_intervals();
+ let mut current = None;
+ let contains = |sup: Range<I>, sub: Range<I>, current: &mut Option<Range<I>>| {
+ if sup.end < sub.start {
+ // if `sup.end == sub.start`, the next sup doesn't contain `sub.start`
+ None // continue to the next sup
+ } else if sup.end >= sub.end && sup.start <= sub.start {
+ *current = Some(sup); // save the current sup
+ Some(true)
+ } else {
+ Some(false)
+ }
+ };
+ other.iter_intervals().all(|sub| {
+ current
+ .take()
+ .and_then(|sup| contains(sup, sub.clone(), &mut current))
+ .or_else(|| sup_iter.find_map(|sup| contains(sup, sub.clone(), &mut current)))
+ .unwrap_or(false)
+ })
+ }
+
+ pub fn is_empty(&self) -> bool {
+ self.map.is_empty()
+ }
+
+ /// Returns the maximum (last) element present in the set from `range`.
+ pub fn last_set_in(&self, range: impl RangeBounds<I> + Clone) -> Option<I> {
+ let start = inclusive_start(range.clone());
+ let Some(end) = inclusive_end(self.domain, range) else {
+ // empty range
+ return None;
+ };
+ if start > end {
+ return None;
+ }
+ let Some(last) = self.map.partition_point(|r| r.0 <= end).checked_sub(1) else {
+ // All ranges in the map start after the new range's end
+ return None;
+ };
+ let (_, prev_end) = &self.map[last];
+ if start <= *prev_end { Some(I::new(std::cmp::min(*prev_end, end) as usize)) } else { None }
+ }
+
+ pub fn insert_all(&mut self) {
+ self.clear();
+ if let Some(end) = self.domain.checked_sub(1) {
+ self.map.push((0, end.try_into().unwrap()));
+ }
+ debug_assert!(self.check_invariants());
+ }
+
+ pub fn union(&mut self, other: &IntervalSet<I>) -> bool
+ where
+ I: Step,
+ {
+ assert_eq!(self.domain, other.domain);
+ let mut did_insert = false;
+ for range in other.iter_intervals() {
+ did_insert |= self.insert_range(range);
+ }
+ debug_assert!(self.check_invariants());
+ did_insert
+ }
+
+ // Check the intervals are valid, sorted and non-adjacent
+ fn check_invariants(&self) -> bool {
+ let mut current: Option<u32> = None;
+ for (start, end) in &self.map {
+ if start > end || current.map_or(false, |x| x + 1 >= *start) {
+ return false;
+ }
+ current = Some(*end);
+ }
+ current.map_or(true, |x| x < self.domain as u32)
+ }
+}
+
+/// This data structure optimizes for cases where the stored bits in each row
+/// are expected to be highly contiguous (long ranges of 1s or 0s), in contrast
+/// to BitMatrix and SparseBitMatrix which are optimized for
+/// "random"/non-contiguous bits and cheap(er) point queries at the expense of
+/// memory usage.
+#[derive(Clone)]
+pub struct SparseIntervalMatrix<R, C>
+where
+ R: Idx,
+ C: Idx,
+{
+ rows: IndexVec<R, IntervalSet<C>>,
+ column_size: usize,
+}
+
+impl<R: Idx, C: Step + Idx> SparseIntervalMatrix<R, C> {
+ pub fn new(column_size: usize) -> SparseIntervalMatrix<R, C> {
+ SparseIntervalMatrix { rows: IndexVec::new(), column_size }
+ }
+
+ pub fn rows(&self) -> impl Iterator<Item = R> {
+ self.rows.indices()
+ }
+
+ pub fn row(&self, row: R) -> Option<&IntervalSet<C>> {
+ self.rows.get(row)
+ }
+
+ fn ensure_row(&mut self, row: R) -> &mut IntervalSet<C> {
+ self.rows.ensure_contains_elem(row, || IntervalSet::new(self.column_size));
+ &mut self.rows[row]
+ }
+
+ pub fn union_row(&mut self, row: R, from: &IntervalSet<C>) -> bool
+ where
+ C: Step,
+ {
+ self.ensure_row(row).union(from)
+ }
+
+ pub fn union_rows(&mut self, read: R, write: R) -> bool
+ where
+ C: Step,
+ {
+ if read == write || self.rows.get(read).is_none() {
+ return false;
+ }
+ self.ensure_row(write);
+ let (read_row, write_row) = self.rows.pick2_mut(read, write);
+ write_row.union(read_row)
+ }
+
+ pub fn insert_all_into_row(&mut self, row: R) {
+ self.ensure_row(row).insert_all();
+ }
+
+ pub fn insert_range(&mut self, row: R, range: impl RangeBounds<C> + Clone) {
+ self.ensure_row(row).insert_range(range);
+ }
+
+ pub fn insert(&mut self, row: R, point: C) -> bool {
+ self.ensure_row(row).insert(point)
+ }
+
+ pub fn contains(&self, row: R, point: C) -> bool {
+ self.row(row).map_or(false, |r| r.contains(point))
+ }
+}
diff --git a/compiler/rustc_index/src/interval/tests.rs b/compiler/rustc_index/src/interval/tests.rs
new file mode 100644
index 000000000..375af60f6
--- /dev/null
+++ b/compiler/rustc_index/src/interval/tests.rs
@@ -0,0 +1,199 @@
+use super::*;
+
+#[test]
+fn insert_collapses() {
+ let mut set = IntervalSet::<u32>::new(10000);
+ set.insert_range(9831..=9837);
+ set.insert_range(43..=9830);
+ assert_eq!(set.iter_intervals().collect::<Vec<_>>(), [43..9838]);
+}
+
+#[test]
+fn contains() {
+ let mut set = IntervalSet::new(300);
+ set.insert(0u32);
+ assert!(set.contains(0));
+ set.insert_range(0..10);
+ assert!(set.contains(9));
+ assert!(!set.contains(10));
+ set.insert_range(10..11);
+ assert!(set.contains(10));
+}
+
+#[test]
+fn insert() {
+ for i in 0..30usize {
+ let mut set = IntervalSet::new(300);
+ for j in i..30usize {
+ set.insert(j);
+ for k in i..j {
+ assert!(set.contains(k));
+ }
+ }
+ }
+
+ let mut set = IntervalSet::new(300);
+ set.insert_range(0..1u32);
+ assert!(set.contains(0), "{:?}", set.map);
+ assert!(!set.contains(1));
+ set.insert_range(1..1);
+ assert!(set.contains(0));
+ assert!(!set.contains(1));
+
+ let mut set = IntervalSet::new(300);
+ set.insert_range(4..5u32);
+ set.insert_range(5..10);
+ assert_eq!(set.iter().collect::<Vec<_>>(), [4, 5, 6, 7, 8, 9]);
+ set.insert_range(3..7);
+ assert_eq!(set.iter().collect::<Vec<_>>(), [3, 4, 5, 6, 7, 8, 9]);
+
+ let mut set = IntervalSet::new(300);
+ set.insert_range(0..10u32);
+ set.insert_range(3..5);
+ assert_eq!(set.iter().collect::<Vec<_>>(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
+
+ let mut set = IntervalSet::new(300);
+ set.insert_range(0..10u32);
+ set.insert_range(0..3);
+ assert_eq!(set.iter().collect::<Vec<_>>(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
+
+ let mut set = IntervalSet::new(300);
+ set.insert_range(0..10u32);
+ set.insert_range(0..10);
+ assert_eq!(set.iter().collect::<Vec<_>>(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
+
+ let mut set = IntervalSet::new(300);
+ set.insert_range(0..10u32);
+ set.insert_range(5..10);
+ assert_eq!(set.iter().collect::<Vec<_>>(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
+
+ let mut set = IntervalSet::new(300);
+ set.insert_range(0..10u32);
+ set.insert_range(5..13);
+ assert_eq!(set.iter().collect::<Vec<_>>(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
+}
+
+#[test]
+fn insert_range() {
+ #[track_caller]
+ fn check<R>(range: R)
+ where
+ R: RangeBounds<usize> + Clone + IntoIterator<Item = usize> + std::fmt::Debug,
+ {
+ let mut set = IntervalSet::new(300);
+ set.insert_range(range.clone());
+ for i in set.iter() {
+ assert!(range.contains(&i));
+ }
+ for i in range.clone() {
+ assert!(set.contains(i), "A: {} in {:?}, inserted {:?}", i, set, range);
+ }
+ set.insert_range(range.clone());
+ for i in set.iter() {
+ assert!(range.contains(&i), "{} in {:?}", i, set);
+ }
+ for i in range.clone() {
+ assert!(set.contains(i), "B: {} in {:?}, inserted {:?}", i, set, range);
+ }
+ }
+ check(10..10);
+ check(10..100);
+ check(10..30);
+ check(0..5);
+ check(0..250);
+ check(200..250);
+
+ check(10..=10);
+ check(10..=100);
+ check(10..=30);
+ check(0..=5);
+ check(0..=250);
+ check(200..=250);
+
+ for i in 0..30 {
+ for j in i..30 {
+ check(i..j);
+ check(i..=j);
+ }
+ }
+}
+
+#[test]
+fn insert_range_dual() {
+ let mut set = IntervalSet::<u32>::new(300);
+ set.insert_range(0..3);
+ assert_eq!(set.iter().collect::<Vec<_>>(), [0, 1, 2]);
+ set.insert_range(5..7);
+ assert_eq!(set.iter().collect::<Vec<_>>(), [0, 1, 2, 5, 6]);
+ set.insert_range(3..4);
+ assert_eq!(set.iter().collect::<Vec<_>>(), [0, 1, 2, 3, 5, 6]);
+ set.insert_range(3..5);
+ assert_eq!(set.iter().collect::<Vec<_>>(), [0, 1, 2, 3, 4, 5, 6]);
+}
+
+#[test]
+fn last_set_before_adjacent() {
+ let mut set = IntervalSet::<u32>::new(300);
+ set.insert_range(0..3);
+ set.insert_range(3..5);
+ assert_eq!(set.last_set_in(0..3), Some(2));
+ assert_eq!(set.last_set_in(0..5), Some(4));
+ assert_eq!(set.last_set_in(3..5), Some(4));
+ set.insert_range(2..5);
+ assert_eq!(set.last_set_in(0..3), Some(2));
+ assert_eq!(set.last_set_in(0..5), Some(4));
+ assert_eq!(set.last_set_in(3..5), Some(4));
+}
+
+#[test]
+fn last_set_in() {
+ fn easy(set: &IntervalSet<usize>, needle: impl RangeBounds<usize>) -> Option<usize> {
+ let mut last_leq = None;
+ for e in set.iter() {
+ if needle.contains(&e) {
+ last_leq = Some(e);
+ }
+ }
+ last_leq
+ }
+
+ #[track_caller]
+ fn cmp(set: &IntervalSet<usize>, needle: impl RangeBounds<usize> + Clone + std::fmt::Debug) {
+ assert_eq!(
+ set.last_set_in(needle.clone()),
+ easy(set, needle.clone()),
+ "{:?} in {:?}",
+ needle,
+ set
+ );
+ }
+ let mut set = IntervalSet::new(300);
+ cmp(&set, 50..=50);
+ set.insert(64);
+ cmp(&set, 64..=64);
+ set.insert(64 - 1);
+ cmp(&set, 0..=64 - 1);
+ cmp(&set, 0..=5);
+ cmp(&set, 10..100);
+ set.insert(100);
+ cmp(&set, 100..110);
+ cmp(&set, 99..100);
+ cmp(&set, 99..=100);
+
+ for i in 0..=30 {
+ for j in i..=30 {
+ for k in 0..30 {
+ let mut set = IntervalSet::new(100);
+ cmp(&set, ..j);
+ cmp(&set, i..);
+ cmp(&set, i..j);
+ cmp(&set, i..=j);
+ set.insert(k);
+ cmp(&set, ..j);
+ cmp(&set, i..);
+ cmp(&set, i..j);
+ cmp(&set, i..=j);
+ }
+ }
+ }
+}