use std::iter::Step; use std::marker::PhantomData; use std::ops::RangeBounds; use std::ops::{Bound, Range}; use smallvec::SmallVec; use crate::idx::Idx; use crate::vec::IndexVec; #[cfg(test)] mod tests; /// Stores a set of intervals on the indices. /// /// The elements in `map` are sorted and non-adjacent, which means /// the second value of the previous element is *greater* than the /// first value of the following element. #[derive(Debug, Clone)] pub struct IntervalSet { // Start, end map: SmallVec<[(u32, u32); 4]>, domain: usize, _data: PhantomData, } #[inline] fn inclusive_start(range: impl RangeBounds) -> u32 { match range.start_bound() { Bound::Included(start) => start.index() as u32, Bound::Excluded(start) => start.index() as u32 + 1, Bound::Unbounded => 0, } } #[inline] fn inclusive_end(domain: usize, range: impl RangeBounds) -> Option { let end = match range.end_bound() { Bound::Included(end) => end.index() as u32, Bound::Excluded(end) => end.index().checked_sub(1)? as u32, Bound::Unbounded => domain.checked_sub(1)? as u32, }; Some(end) } impl IntervalSet { pub fn new(domain: usize) -> IntervalSet { IntervalSet { map: SmallVec::new(), domain, _data: PhantomData } } pub fn clear(&mut self) { self.map.clear(); } pub fn iter(&self) -> impl Iterator + '_ where I: Step, { self.iter_intervals().flatten() } /// Iterates through intervals stored in the set, in order. pub fn iter_intervals(&self) -> impl Iterator> + '_ where I: Step, { self.map.iter().map(|&(start, end)| I::new(start as usize)..I::new(end as usize + 1)) } /// Returns true if we increased the number of elements present. pub fn insert(&mut self, point: I) -> bool { self.insert_range(point..=point) } /// Returns true if we increased the number of elements present. pub fn insert_range(&mut self, range: impl RangeBounds + Clone) -> bool { let start = inclusive_start(range.clone()); let Some(end) = inclusive_end(self.domain, range) else { // empty range return false; }; if start > end { return false; } // This condition looks a bit weird, but actually makes sense. // // if r.0 == end + 1, then we're actually adjacent, so we want to // continue to the next range. We're looking here for the first // range which starts *non-adjacently* to our end. let next = self.map.partition_point(|r| r.0 <= end + 1); let result = if let Some(right) = next.checked_sub(1) { let (prev_start, prev_end) = self.map[right]; if prev_end + 1 >= start { // If the start for the inserted range is adjacent to the // end of the previous, we can extend the previous range. if start < prev_start { // The first range which ends *non-adjacently* to our start. // And we can ensure that left <= right. let left = self.map.partition_point(|l| l.1 + 1 < start); let min = std::cmp::min(self.map[left].0, start); let max = std::cmp::max(prev_end, end); self.map[right] = (min, max); if left != right { self.map.drain(left..right); } true } else { // We overlap with the previous range, increase it to // include us. // // Make sure we're actually going to *increase* it though -- // it may be that end is just inside the previously existing // set. if end > prev_end { self.map[right].1 = end; true } else { false } } } else { // Otherwise, we don't overlap, so just insert self.map.insert(right + 1, (start, end)); true } } else { if self.map.is_empty() { // Quite common in practice, and expensive to call memcpy // with length zero. self.map.push((start, end)); } else { self.map.insert(next, (start, end)); } true }; debug_assert!( self.check_invariants(), "wrong intervals after insert {start:?}..={end:?} to {self:?}" ); result } pub fn contains(&self, needle: I) -> bool { let needle = needle.index() as u32; let Some(last) = self.map.partition_point(|r| r.0 <= needle).checked_sub(1) else { // All ranges in the map start after the new range's end return false; }; let (_, prev_end) = &self.map[last]; needle <= *prev_end } pub fn superset(&self, other: &IntervalSet) -> bool where I: Step, { let mut sup_iter = self.iter_intervals(); let mut current = None; let contains = |sup: Range, sub: Range, current: &mut Option>| { if sup.end < sub.start { // if `sup.end == sub.start`, the next sup doesn't contain `sub.start` None // continue to the next sup } else if sup.end >= sub.end && sup.start <= sub.start { *current = Some(sup); // save the current sup Some(true) } else { Some(false) } }; other.iter_intervals().all(|sub| { current .take() .and_then(|sup| contains(sup, sub.clone(), &mut current)) .or_else(|| sup_iter.find_map(|sup| contains(sup, sub.clone(), &mut current))) .unwrap_or(false) }) } pub fn is_empty(&self) -> bool { self.map.is_empty() } /// Equivalent to `range.iter().find(|i| !self.contains(i))`. pub fn first_unset_in(&self, range: impl RangeBounds + Clone) -> Option { let start = inclusive_start(range.clone()); let Some(end) = inclusive_end(self.domain, range) else { // empty range return None; }; if start > end { return None; } let Some(last) = self.map.partition_point(|r| r.0 <= start).checked_sub(1) else { // All ranges in the map start after the new range's end return Some(I::new(start as usize)); }; let (_, prev_end) = self.map[last]; if start > prev_end { Some(I::new(start as usize)) } else if prev_end < end { Some(I::new(prev_end as usize + 1)) } else { None } } /// Returns the maximum (last) element present in the set from `range`. pub fn last_set_in(&self, range: impl RangeBounds + Clone) -> Option { let start = inclusive_start(range.clone()); let Some(end) = inclusive_end(self.domain, range) else { // empty range return None; }; if start > end { return None; } let Some(last) = self.map.partition_point(|r| r.0 <= end).checked_sub(1) else { // All ranges in the map start after the new range's end return None; }; let (_, prev_end) = &self.map[last]; if start <= *prev_end { Some(I::new(std::cmp::min(*prev_end, end) as usize)) } else { None } } pub fn insert_all(&mut self) { self.clear(); if let Some(end) = self.domain.checked_sub(1) { self.map.push((0, end.try_into().unwrap())); } debug_assert!(self.check_invariants()); } pub fn union(&mut self, other: &IntervalSet) -> bool where I: Step, { assert_eq!(self.domain, other.domain); let mut did_insert = false; for range in other.iter_intervals() { did_insert |= self.insert_range(range); } debug_assert!(self.check_invariants()); did_insert } // Check the intervals are valid, sorted and non-adjacent fn check_invariants(&self) -> bool { let mut current: Option = None; for (start, end) in &self.map { if start > end || current.is_some_and(|x| x + 1 >= *start) { return false; } current = Some(*end); } current.map_or(true, |x| x < self.domain as u32) } } /// This data structure optimizes for cases where the stored bits in each row /// are expected to be highly contiguous (long ranges of 1s or 0s), in contrast /// to BitMatrix and SparseBitMatrix which are optimized for /// "random"/non-contiguous bits and cheap(er) point queries at the expense of /// memory usage. #[derive(Clone)] pub struct SparseIntervalMatrix where R: Idx, C: Idx, { rows: IndexVec>, column_size: usize, } impl SparseIntervalMatrix { pub fn new(column_size: usize) -> SparseIntervalMatrix { SparseIntervalMatrix { rows: IndexVec::new(), column_size } } pub fn rows(&self) -> impl Iterator { self.rows.indices() } pub fn row(&self, row: R) -> Option<&IntervalSet> { self.rows.get(row) } fn ensure_row(&mut self, row: R) -> &mut IntervalSet { self.rows.ensure_contains_elem(row, || IntervalSet::new(self.column_size)) } pub fn union_row(&mut self, row: R, from: &IntervalSet) -> bool where C: Step, { self.ensure_row(row).union(from) } pub fn union_rows(&mut self, read: R, write: R) -> bool where C: Step, { if read == write || self.rows.get(read).is_none() { return false; } self.ensure_row(write); let (read_row, write_row) = self.rows.pick2_mut(read, write); write_row.union(read_row) } pub fn insert_all_into_row(&mut self, row: R) { self.ensure_row(row).insert_all(); } pub fn insert_range(&mut self, row: R, range: impl RangeBounds + Clone) { self.ensure_row(row).insert_range(range); } pub fn insert(&mut self, row: R, point: C) -> bool { self.ensure_row(row).insert(point) } pub fn contains(&self, row: R, point: C) -> bool { self.row(row).is_some_and(|r| r.contains(point)) } }