diff options
Diffstat (limited to 'vendor/regex-automata-0.2.0/src/util/alphabet.rs')
-rw-r--r-- | vendor/regex-automata-0.2.0/src/util/alphabet.rs | 790 |
1 files changed, 790 insertions, 0 deletions
diff --git a/vendor/regex-automata-0.2.0/src/util/alphabet.rs b/vendor/regex-automata-0.2.0/src/util/alphabet.rs new file mode 100644 index 000000000..0bc1ece58 --- /dev/null +++ b/vendor/regex-automata-0.2.0/src/util/alphabet.rs @@ -0,0 +1,790 @@ +use core::convert::TryFrom; + +use crate::util::{ + bytes::{DeserializeError, SerializeError}, + DebugByte, +}; + +/// Unit represents a single unit of input for DFA based regex engines. +/// +/// **NOTE:** It is not expected for consumers of this crate to need to use +/// this type unless they are implementing their own DFA. And even then, it's +/// not required: implementors may use other techniques to handle input. +/// +/// Typically, a single unit of input for a DFA would be a single byte. +/// However, for the DFAs in this crate, matches are delayed by a single byte +/// in order to handle look-ahead assertions (`\b`, `$` and `\z`). Thus, once +/// we have consumed the haystack, we must run the DFA through one additional +/// transition using an input that indicates the haystack has ended. +/// +/// Since there is no way to represent a sentinel with a `u8` since all +/// possible values *may* be valid inputs to a DFA, this type explicitly adds +/// room for a sentinel value. +/// +/// The sentinel EOI value is always its own equivalence class and is +/// ultimately represented by adding 1 to the maximum equivalence class value. +/// So for example, the regex `^[a-z]+$` might be split into the following +/// equivalence classes: +/// +/// ```text +/// 0 => [\x00-`] +/// 1 => [a-z] +/// 2 => [{-\xFF] +/// 3 => [EOI] +/// ``` +/// +/// Where EOI is the special sentinel value that is always in its own +/// singleton equivalence class. +#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord)] +pub enum Unit { + U8(u8), + EOI(u16), +} + +impl Unit { + /// Create a new input unit from a byte value. + /// + /// All possible byte values are legal. However, when creating an input + /// unit for a specific DFA, one should be careful to only construct input + /// units that are in that DFA's alphabet. Namely, one way to compact a + /// DFA's in-memory representation is to collapse its transitions to a set + /// of equivalence classes into a set of all possible byte values. If a + /// DFA uses equivalence classes instead of byte values, then the byte + /// given here should be the equivalence class. + pub fn u8(byte: u8) -> Unit { + Unit::U8(byte) + } + + pub fn eoi(num_byte_equiv_classes: usize) -> Unit { + assert!( + num_byte_equiv_classes <= 256, + "max number of byte-based equivalent classes is 256, but got {}", + num_byte_equiv_classes, + ); + Unit::EOI(u16::try_from(num_byte_equiv_classes).unwrap()) + } + + pub fn as_u8(self) -> Option<u8> { + match self { + Unit::U8(b) => Some(b), + Unit::EOI(_) => None, + } + } + + #[cfg(feature = "alloc")] + pub fn as_eoi(self) -> Option<usize> { + match self { + Unit::U8(_) => None, + Unit::EOI(eoi) => Some(eoi as usize), + } + } + + pub fn as_usize(self) -> usize { + match self { + Unit::U8(b) => b as usize, + Unit::EOI(eoi) => eoi as usize, + } + } + + pub fn is_eoi(&self) -> bool { + match *self { + Unit::EOI(_) => true, + _ => false, + } + } + + #[cfg(feature = "alloc")] + pub fn is_word_byte(&self) -> bool { + self.as_u8().map_or(false, crate::util::is_word_byte) + } +} + +impl core::fmt::Debug for Unit { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match *self { + Unit::U8(b) => write!(f, "{:?}", DebugByte(b)), + Unit::EOI(_) => write!(f, "EOI"), + } + } +} + +/// A representation of byte oriented equivalence classes. +/// +/// This is used in a DFA to reduce the size of the transition table. This can +/// have a particularly large impact not only on the total size of a dense DFA, +/// but also on compile times. +#[derive(Clone, Copy)] +pub struct ByteClasses([u8; 256]); + +impl ByteClasses { + /// Creates a new set of equivalence classes where all bytes are mapped to + /// the same class. + pub fn empty() -> ByteClasses { + ByteClasses([0; 256]) + } + + /// Creates a new set of equivalence classes where each byte belongs to + /// its own equivalence class. + #[cfg(feature = "alloc")] + pub fn singletons() -> ByteClasses { + let mut classes = ByteClasses::empty(); + for i in 0..256 { + classes.set(i as u8, i as u8); + } + classes + } + + /// Deserializes a byte class map from the given slice. If the slice is of + /// insufficient length or otherwise contains an impossible mapping, then + /// an error is returned. Upon success, the number of bytes read along with + /// the map are returned. The number of bytes read is always a multiple of + /// 8. + pub fn from_bytes( + slice: &[u8], + ) -> Result<(ByteClasses, usize), DeserializeError> { + if slice.len() < 256 { + return Err(DeserializeError::buffer_too_small("byte class map")); + } + let mut classes = ByteClasses::empty(); + for (b, &class) in slice[..256].iter().enumerate() { + classes.set(b as u8, class); + } + for b in classes.iter() { + if b.as_usize() >= classes.alphabet_len() { + return Err(DeserializeError::generic( + "found equivalence class greater than alphabet len", + )); + } + } + Ok((classes, 256)) + } + + /// Writes this byte class map to the given byte buffer. if the given + /// buffer is too small, then an error is returned. Upon success, the total + /// number of bytes written is returned. The number of bytes written is + /// guaranteed to be a multiple of 8. + pub fn write_to( + &self, + mut dst: &mut [u8], + ) -> Result<usize, SerializeError> { + let nwrite = self.write_to_len(); + if dst.len() < nwrite { + return Err(SerializeError::buffer_too_small("byte class map")); + } + for b in 0..=255 { + dst[0] = self.get(b); + dst = &mut dst[1..]; + } + Ok(nwrite) + } + + /// Returns the total number of bytes written by `write_to`. + pub fn write_to_len(&self) -> usize { + 256 + } + + /// Set the equivalence class for the given byte. + #[inline] + pub fn set(&mut self, byte: u8, class: u8) { + self.0[byte as usize] = class; + } + + /// Get the equivalence class for the given byte. + #[inline] + pub fn get(&self, byte: u8) -> u8 { + self.0[byte as usize] + } + + /// Get the equivalence class for the given byte while forcefully + /// eliding bounds checks. + #[inline] + pub unsafe fn get_unchecked(&self, byte: u8) -> u8 { + *self.0.get_unchecked(byte as usize) + } + + /// Get the equivalence class for the given input unit and return the + /// class as a `usize`. + #[inline] + pub fn get_by_unit(&self, unit: Unit) -> usize { + match unit { + Unit::U8(b) => usize::try_from(self.get(b)).unwrap(), + Unit::EOI(b) => usize::try_from(b).unwrap(), + } + } + + #[inline] + pub fn eoi(&self) -> Unit { + Unit::eoi(self.alphabet_len().checked_sub(1).unwrap()) + } + + /// Return the total number of elements in the alphabet represented by + /// these equivalence classes. Equivalently, this returns the total number + /// of equivalence classes. + #[inline] + pub fn alphabet_len(&self) -> usize { + // Add one since the number of equivalence classes is one bigger than + // the last one. But add another to account for the final EOI class + // that isn't explicitly represented. + self.0[255] as usize + 1 + 1 + } + + /// Returns the stride, as a base-2 exponent, required for these + /// equivalence classes. + /// + /// The stride is always the smallest power of 2 that is greater than or + /// equal to the alphabet length. This is done so that converting between + /// state IDs and indices can be done with shifts alone, which is much + /// faster than integer division. + #[cfg(feature = "alloc")] + pub fn stride2(&self) -> usize { + self.alphabet_len().next_power_of_two().trailing_zeros() as usize + } + + /// Returns true if and only if every byte in this class maps to its own + /// equivalence class. Equivalently, there are 257 equivalence classes + /// and each class contains exactly one byte (plus the special EOI class). + #[inline] + pub fn is_singleton(&self) -> bool { + self.alphabet_len() == 257 + } + + /// Returns an iterator over all equivalence classes in this set. + pub fn iter(&self) -> ByteClassIter<'_> { + ByteClassIter { classes: self, i: 0 } + } + + /// Returns an iterator over a sequence of representative bytes from each + /// equivalence class. Namely, this yields exactly N items, where N is + /// equivalent to the number of equivalence classes. Each item is an + /// arbitrary byte drawn from each equivalence class. + /// + /// This is useful when one is determinizing an NFA and the NFA's alphabet + /// hasn't been converted to equivalence classes yet. Picking an arbitrary + /// byte from each equivalence class then permits a full exploration of + /// the NFA instead of using every possible byte value. + #[cfg(feature = "alloc")] + pub fn representatives(&self) -> ByteClassRepresentatives<'_> { + ByteClassRepresentatives { classes: self, byte: 0, last_class: None } + } + + /// Returns an iterator of the bytes in the given equivalence class. + pub fn elements(&self, class: Unit) -> ByteClassElements { + ByteClassElements { classes: self, class, byte: 0 } + } + + /// Returns an iterator of byte ranges in the given equivalence class. + /// + /// That is, a sequence of contiguous ranges are returned. Typically, every + /// class maps to a single contiguous range. + fn element_ranges(&self, class: Unit) -> ByteClassElementRanges { + ByteClassElementRanges { elements: self.elements(class), range: None } + } +} + +impl core::fmt::Debug for ByteClasses { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + if self.is_singleton() { + write!(f, "ByteClasses({{singletons}})") + } else { + write!(f, "ByteClasses(")?; + for (i, class) in self.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:?} => [", class.as_usize())?; + for (start, end) in self.element_ranges(class) { + if start == end { + write!(f, "{:?}", start)?; + } else { + write!(f, "{:?}-{:?}", start, end)?; + } + } + write!(f, "]")?; + } + write!(f, ")") + } + } +} + +/// An iterator over each equivalence class. +#[derive(Debug)] +pub struct ByteClassIter<'a> { + classes: &'a ByteClasses, + i: usize, +} + +impl<'a> Iterator for ByteClassIter<'a> { + type Item = Unit; + + fn next(&mut self) -> Option<Unit> { + if self.i + 1 == self.classes.alphabet_len() { + self.i += 1; + Some(self.classes.eoi()) + } else if self.i < self.classes.alphabet_len() { + let class = self.i as u8; + self.i += 1; + Some(Unit::u8(class)) + } else { + None + } + } +} + +/// An iterator over representative bytes from each equivalence class. +#[cfg(feature = "alloc")] +#[derive(Debug)] +pub struct ByteClassRepresentatives<'a> { + classes: &'a ByteClasses, + byte: usize, + last_class: Option<u8>, +} + +#[cfg(feature = "alloc")] +impl<'a> Iterator for ByteClassRepresentatives<'a> { + type Item = Unit; + + fn next(&mut self) -> Option<Unit> { + while self.byte < 256 { + let byte = self.byte as u8; + let class = self.classes.get(byte); + self.byte += 1; + + if self.last_class != Some(class) { + self.last_class = Some(class); + return Some(Unit::u8(byte)); + } + } + if self.byte == 256 { + self.byte += 1; + return Some(self.classes.eoi()); + } + None + } +} + +/// An iterator over all elements in an equivalence class. +#[derive(Debug)] +pub struct ByteClassElements<'a> { + classes: &'a ByteClasses, + class: Unit, + byte: usize, +} + +impl<'a> Iterator for ByteClassElements<'a> { + type Item = Unit; + + fn next(&mut self) -> Option<Unit> { + while self.byte < 256 { + let byte = self.byte as u8; + self.byte += 1; + if self.class.as_u8() == Some(self.classes.get(byte)) { + return Some(Unit::u8(byte)); + } + } + if self.byte < 257 { + self.byte += 1; + if self.class.is_eoi() { + return Some(Unit::eoi(256)); + } + } + None + } +} + +/// An iterator over all elements in an equivalence class expressed as a +/// sequence of contiguous ranges. +#[derive(Debug)] +pub struct ByteClassElementRanges<'a> { + elements: ByteClassElements<'a>, + range: Option<(Unit, Unit)>, +} + +impl<'a> Iterator for ByteClassElementRanges<'a> { + type Item = (Unit, Unit); + + fn next(&mut self) -> Option<(Unit, Unit)> { + loop { + let element = match self.elements.next() { + None => return self.range.take(), + Some(element) => element, + }; + match self.range.take() { + None => { + self.range = Some((element, element)); + } + Some((start, end)) => { + if end.as_usize() + 1 != element.as_usize() + || element.is_eoi() + { + self.range = Some((element, element)); + return Some((start, end)); + } + self.range = Some((start, element)); + } + } + } + } +} + +/// A byte class set keeps track of an *approximation* of equivalence classes +/// of bytes during NFA construction. That is, every byte in an equivalence +/// class cannot discriminate between a match and a non-match. +/// +/// For example, in the regex `[ab]+`, the bytes `a` and `b` would be in the +/// same equivalence class because it never matters whether an `a` or a `b` is +/// seen, and no combination of `a`s and `b`s in the text can discriminate a +/// match. +/// +/// Note though that this does not compute the minimal set of equivalence +/// classes. For example, in the regex `[ac]+`, both `a` and `c` are in the +/// same equivalence class for the same reason that `a` and `b` are in the +/// same equivalence class in the aforementioned regex. However, in this +/// implementation, `a` and `c` are put into distinct equivalence classes. The +/// reason for this is implementation complexity. In the future, we should +/// endeavor to compute the minimal equivalence classes since they can have a +/// rather large impact on the size of the DFA. (Doing this will likely require +/// rethinking how equivalence classes are computed, including changing the +/// representation here, which is only able to group contiguous bytes into the +/// same equivalence class.) +#[derive(Clone, Debug)] +pub struct ByteClassSet(ByteSet); + +impl ByteClassSet { + /// Create a new set of byte classes where all bytes are part of the same + /// equivalence class. + #[cfg(feature = "alloc")] + pub fn empty() -> Self { + ByteClassSet(ByteSet::empty()) + } + + /// Indicate the the range of byte given (inclusive) can discriminate a + /// match between it and all other bytes outside of the range. + #[cfg(feature = "alloc")] + pub fn set_range(&mut self, start: u8, end: u8) { + debug_assert!(start <= end); + if start > 0 { + self.0.add(start - 1); + } + self.0.add(end); + } + + /// Add the contiguous ranges in the set given to this byte class set. + #[cfg(feature = "alloc")] + pub fn add_set(&mut self, set: &ByteSet) { + for (start, end) in set.iter_ranges() { + self.set_range(start, end); + } + } + + /// Convert this boolean set to a map that maps all byte values to their + /// corresponding equivalence class. The last mapping indicates the largest + /// equivalence class identifier (which is never bigger than 255). + #[cfg(feature = "alloc")] + pub fn byte_classes(&self) -> ByteClasses { + let mut classes = ByteClasses::empty(); + let mut class = 0u8; + let mut b = 0u8; + loop { + classes.set(b, class); + if b == 255 { + break; + } + if self.0.contains(b) { + class = class.checked_add(1).unwrap(); + } + b = b.checked_add(1).unwrap(); + } + classes + } +} + +/// A simple set of bytes that is reasonably cheap to copy and allocation free. +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] +pub struct ByteSet { + bits: BitSet, +} + +/// The representation of a byte set. Split out so that we can define a +/// convenient Debug impl for it while keeping "ByteSet" in the output. +#[derive(Clone, Copy, Default, Eq, PartialEq)] +struct BitSet([u128; 2]); + +impl ByteSet { + /// Create an empty set of bytes. + #[cfg(feature = "alloc")] + pub fn empty() -> ByteSet { + ByteSet { bits: BitSet([0; 2]) } + } + + /// Add a byte to this set. + /// + /// If the given byte already belongs to this set, then this is a no-op. + #[cfg(feature = "alloc")] + pub fn add(&mut self, byte: u8) { + let bucket = byte / 128; + let bit = byte % 128; + self.bits.0[bucket as usize] |= 1 << bit; + } + + /// Add an inclusive range of bytes. + #[cfg(feature = "alloc")] + pub fn add_all(&mut self, start: u8, end: u8) { + for b in start..=end { + self.add(b); + } + } + + /// Remove a byte from this set. + /// + /// If the given byte is not in this set, then this is a no-op. + #[cfg(feature = "alloc")] + pub fn remove(&mut self, byte: u8) { + let bucket = byte / 128; + let bit = byte % 128; + self.bits.0[bucket as usize] &= !(1 << bit); + } + + /// Remove an inclusive range of bytes. + #[cfg(feature = "alloc")] + pub fn remove_all(&mut self, start: u8, end: u8) { + for b in start..=end { + self.remove(b); + } + } + + /// Return true if and only if the given byte is in this set. + pub fn contains(&self, byte: u8) -> bool { + let bucket = byte / 128; + let bit = byte % 128; + self.bits.0[bucket as usize] & (1 << bit) > 0 + } + + /// Return true if and only if the given inclusive range of bytes is in + /// this set. + #[cfg(feature = "alloc")] + pub fn contains_range(&self, start: u8, end: u8) -> bool { + (start..=end).all(|b| self.contains(b)) + } + + /// Returns an iterator over all bytes in this set. + #[cfg(feature = "alloc")] + pub fn iter(&self) -> ByteSetIter { + ByteSetIter { set: self, b: 0 } + } + + /// Returns an iterator over all contiguous ranges of bytes in this set. + #[cfg(feature = "alloc")] + pub fn iter_ranges(&self) -> ByteSetRangeIter { + ByteSetRangeIter { set: self, b: 0 } + } + + /// Return the number of bytes in this set. + #[cfg(feature = "alloc")] + pub fn len(&self) -> usize { + (self.bits.0[0].count_ones() + self.bits.0[1].count_ones()) as usize + } + + /// Return true if and only if this set is empty. + #[cfg(feature = "alloc")] + pub fn is_empty(&self) -> bool { + self.bits.0 == [0, 0] + } +} + +impl core::fmt::Debug for BitSet { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + let mut fmtd = f.debug_set(); + for b in (0..256).map(|b| b as u8) { + if (ByteSet { bits: *self }).contains(b) { + fmtd.entry(&b); + } + } + fmtd.finish() + } +} + +#[derive(Debug)] +pub struct ByteSetIter<'a> { + set: &'a ByteSet, + b: usize, +} + +impl<'a> Iterator for ByteSetIter<'a> { + type Item = u8; + + fn next(&mut self) -> Option<u8> { + while self.b <= 255 { + let b = self.b as u8; + self.b += 1; + if self.set.contains(b) { + return Some(b); + } + } + None + } +} + +#[derive(Debug)] +pub struct ByteSetRangeIter<'a> { + set: &'a ByteSet, + b: usize, +} + +impl<'a> Iterator for ByteSetRangeIter<'a> { + type Item = (u8, u8); + + fn next(&mut self) -> Option<(u8, u8)> { + while self.b <= 255 { + let start = self.b as u8; + self.b += 1; + if !self.set.contains(start) { + continue; + } + + let mut end = start; + while self.b <= 255 && self.set.contains(self.b as u8) { + end = self.b as u8; + self.b += 1; + } + return Some((start, end)); + } + None + } +} + +#[cfg(test)] +#[cfg(feature = "alloc")] +mod tests { + use alloc::{vec, vec::Vec}; + + use super::*; + + #[test] + fn byte_classes() { + let mut set = ByteClassSet::empty(); + set.set_range(b'a', b'z'); + + let classes = set.byte_classes(); + assert_eq!(classes.get(0), 0); + assert_eq!(classes.get(1), 0); + assert_eq!(classes.get(2), 0); + assert_eq!(classes.get(b'a' - 1), 0); + assert_eq!(classes.get(b'a'), 1); + assert_eq!(classes.get(b'm'), 1); + assert_eq!(classes.get(b'z'), 1); + assert_eq!(classes.get(b'z' + 1), 2); + assert_eq!(classes.get(254), 2); + assert_eq!(classes.get(255), 2); + + let mut set = ByteClassSet::empty(); + set.set_range(0, 2); + set.set_range(4, 6); + let classes = set.byte_classes(); + assert_eq!(classes.get(0), 0); + assert_eq!(classes.get(1), 0); + assert_eq!(classes.get(2), 0); + assert_eq!(classes.get(3), 1); + assert_eq!(classes.get(4), 2); + assert_eq!(classes.get(5), 2); + assert_eq!(classes.get(6), 2); + assert_eq!(classes.get(7), 3); + assert_eq!(classes.get(255), 3); + } + + #[test] + fn full_byte_classes() { + let mut set = ByteClassSet::empty(); + for i in 0..256u16 { + set.set_range(i as u8, i as u8); + } + assert_eq!(set.byte_classes().alphabet_len(), 257); + } + + #[test] + fn elements_typical() { + let mut set = ByteClassSet::empty(); + set.set_range(b'b', b'd'); + set.set_range(b'g', b'm'); + set.set_range(b'z', b'z'); + let classes = set.byte_classes(); + // class 0: \x00-a + // class 1: b-d + // class 2: e-f + // class 3: g-m + // class 4: n-y + // class 5: z-z + // class 6: \x7B-\xFF + // class 7: EOI + assert_eq!(classes.alphabet_len(), 8); + + let elements = classes.elements(Unit::u8(0)).collect::<Vec<_>>(); + assert_eq!(elements.len(), 98); + assert_eq!(elements[0], Unit::u8(b'\x00')); + assert_eq!(elements[97], Unit::u8(b'a')); + + let elements = classes.elements(Unit::u8(1)).collect::<Vec<_>>(); + assert_eq!( + elements, + vec![Unit::u8(b'b'), Unit::u8(b'c'), Unit::u8(b'd')], + ); + + let elements = classes.elements(Unit::u8(2)).collect::<Vec<_>>(); + assert_eq!(elements, vec![Unit::u8(b'e'), Unit::u8(b'f')],); + + let elements = classes.elements(Unit::u8(3)).collect::<Vec<_>>(); + assert_eq!( + elements, + vec![ + Unit::u8(b'g'), + Unit::u8(b'h'), + Unit::u8(b'i'), + Unit::u8(b'j'), + Unit::u8(b'k'), + Unit::u8(b'l'), + Unit::u8(b'm'), + ], + ); + + let elements = classes.elements(Unit::u8(4)).collect::<Vec<_>>(); + assert_eq!(elements.len(), 12); + assert_eq!(elements[0], Unit::u8(b'n')); + assert_eq!(elements[11], Unit::u8(b'y')); + + let elements = classes.elements(Unit::u8(5)).collect::<Vec<_>>(); + assert_eq!(elements, vec![Unit::u8(b'z')]); + + let elements = classes.elements(Unit::u8(6)).collect::<Vec<_>>(); + assert_eq!(elements.len(), 133); + assert_eq!(elements[0], Unit::u8(b'\x7B')); + assert_eq!(elements[132], Unit::u8(b'\xFF')); + + let elements = classes.elements(Unit::eoi(7)).collect::<Vec<_>>(); + assert_eq!(elements, vec![Unit::eoi(256)]); + } + + #[test] + fn elements_singletons() { + let classes = ByteClasses::singletons(); + assert_eq!(classes.alphabet_len(), 257); + + let elements = classes.elements(Unit::u8(b'a')).collect::<Vec<_>>(); + assert_eq!(elements, vec![Unit::u8(b'a')]); + + let elements = classes.elements(Unit::eoi(5)).collect::<Vec<_>>(); + assert_eq!(elements, vec![Unit::eoi(256)]); + } + + #[test] + fn elements_empty() { + let classes = ByteClasses::empty(); + assert_eq!(classes.alphabet_len(), 2); + + let elements = classes.elements(Unit::u8(0)).collect::<Vec<_>>(); + assert_eq!(elements.len(), 256); + assert_eq!(elements[0], Unit::u8(b'\x00')); + assert_eq!(elements[255], Unit::u8(b'\xFF')); + + let elements = classes.elements(Unit::eoi(1)).collect::<Vec<_>>(); + assert_eq!(elements, vec![Unit::eoi(256)]); + } +} |