diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
commit | 26a029d407be480d791972afb5975cf62c9360a6 (patch) | |
tree | f435a8308119effd964b339f76abb83a57c29483 /third_party/rust/icu_segmenter/src | |
parent | Initial commit. (diff) | |
download | firefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz firefox-26a029d407be480d791972afb5975cf62c9360a6.zip |
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/icu_segmenter/src')
17 files changed, 5946 insertions, 0 deletions
diff --git a/third_party/rust/icu_segmenter/src/complex/dictionary.rs b/third_party/rust/icu_segmenter/src/complex/dictionary.rs new file mode 100644 index 0000000000..90360ee2b0 --- /dev/null +++ b/third_party/rust/icu_segmenter/src/complex/dictionary.rs @@ -0,0 +1,268 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +use crate::grapheme::*; +use crate::indices::Utf16Indices; +use crate::provider::*; +use core::str::CharIndices; +use icu_collections::char16trie::{Char16Trie, TrieResult}; + +/// A trait for dictionary based iterator +trait DictionaryType<'l, 's> { + /// The iterator over characters. + type IterAttr: Iterator<Item = (usize, Self::CharType)> + Clone; + + /// The character type. + type CharType: Copy + Into<u32>; + + fn to_char(c: Self::CharType) -> char; + fn char_len(c: Self::CharType) -> usize; +} + +struct DictionaryBreakIterator< + 'l, + 's, + Y: DictionaryType<'l, 's> + ?Sized, + X: Iterator<Item = usize> + ?Sized, +> { + trie: Char16Trie<'l>, + iter: Y::IterAttr, + len: usize, + grapheme_iter: X, + // TODO transform value for byte trie +} + +/// Implement the [`Iterator`] trait over the segmenter break opportunities of the given string. +/// Please see the [module-level documentation](crate) for its usages. 
+/// +/// Lifetimes: +/// - `'l` = lifetime of the segmenter object from which this iterator was created +/// - `'s` = lifetime of the string being segmented +/// +/// [`Iterator`]: core::iter::Iterator +impl<'l, 's, Y: DictionaryType<'l, 's> + ?Sized, X: Iterator<Item = usize> + ?Sized> Iterator + for DictionaryBreakIterator<'l, 's, Y, X> +{ + type Item = usize; + + fn next(&mut self) -> Option<Self::Item> { + let mut trie_iter = self.trie.iter(); + let mut intermediate_length = 0; + let mut not_match = false; + let mut previous_match = None; + let mut last_grapheme_offset = 0; + + while let Some(next) = self.iter.next() { + let ch = Y::to_char(next.1); + match trie_iter.next(ch) { + TrieResult::FinalValue(_) => { + return Some(next.0 + Y::char_len(next.1)); + } + TrieResult::Intermediate(_) => { + // Dictionary has to match with grapheme cluster segment. + // If not, we ignore it. + while last_grapheme_offset < next.0 + Y::char_len(next.1) { + if let Some(offset) = self.grapheme_iter.next() { + last_grapheme_offset = offset; + continue; + } + last_grapheme_offset = self.len; + break; + } + if last_grapheme_offset != next.0 + Y::char_len(next.1) { + continue; + } + + intermediate_length = next.0 + Y::char_len(next.1); + previous_match = Some(self.iter.clone()); + } + TrieResult::NoMatch => { + if intermediate_length > 0 { + if let Some(previous_match) = previous_match { + // Rewind previous match point + self.iter = previous_match; + } + return Some(intermediate_length); + } + // Not found + return Some(next.0 + Y::char_len(next.1)); + } + TrieResult::NoValue => { + // Prefix string is matched + not_match = true; + } + } + } + + if intermediate_length > 0 { + Some(intermediate_length) + } else if not_match { + // no match by scanning text + Some(self.len) + } else { + None + } + } +} + +impl<'l, 's> DictionaryType<'l, 's> for u32 { + type IterAttr = Utf16Indices<'s>; + type CharType = u32; + + fn to_char(c: u32) -> char { + 
char::from_u32(c).unwrap_or(char::REPLACEMENT_CHARACTER) + } + + fn char_len(c: u32) -> usize { + if c >= 0x10000 { + 2 + } else { + 1 + } + } +} + +impl<'l, 's> DictionaryType<'l, 's> for char { + type IterAttr = CharIndices<'s>; + type CharType = char; + + fn to_char(c: char) -> char { + c + } + + fn char_len(c: char) -> usize { + c.len_utf8() + } +} + +pub(super) struct DictionarySegmenter<'l> { + dict: &'l UCharDictionaryBreakDataV1<'l>, + grapheme: &'l RuleBreakDataV1<'l>, +} + +impl<'l> DictionarySegmenter<'l> { + pub(super) fn new( + dict: &'l UCharDictionaryBreakDataV1<'l>, + grapheme: &'l RuleBreakDataV1<'l>, + ) -> Self { + // TODO: no way to verify trie data + Self { dict, grapheme } + } + + /// Create a dictionary based break iterator for an `str` (a UTF-8 string). + pub(super) fn segment_str(&'l self, input: &'l str) -> impl Iterator<Item = usize> + 'l { + let grapheme_iter = GraphemeClusterSegmenter::new_and_segment_str(input, self.grapheme); + DictionaryBreakIterator::<char, GraphemeClusterBreakIteratorUtf8> { + trie: Char16Trie::new(self.dict.trie_data.clone()), + iter: input.char_indices(), + len: input.len(), + grapheme_iter, + } + } + + /// Create a dictionary based break iterator for a UTF-16 string. 
+ pub(super) fn segment_utf16(&'l self, input: &'l [u16]) -> impl Iterator<Item = usize> + 'l { + let grapheme_iter = GraphemeClusterSegmenter::new_and_segment_utf16(input, self.grapheme); + DictionaryBreakIterator::<u32, GraphemeClusterBreakIteratorUtf16> { + trie: Char16Trie::new(self.dict.trie_data.clone()), + iter: Utf16Indices::new(input), + len: input.len(), + grapheme_iter, + } + } +} + +#[cfg(test)] +#[cfg(feature = "serde")] +mod tests { + use super::*; + use crate::{LineSegmenter, WordSegmenter}; + use icu_provider::prelude::*; + + #[test] + fn burmese_dictionary_test() { + let segmenter = LineSegmenter::new_dictionary(); + // From css/css-text/word-break/word-break-normal-my-000.html + let s = "မြန်မာစာမြန်မာစာမြန်မာစာ"; + let result: Vec<usize> = segmenter.segment_str(s).collect(); + assert_eq!(result, vec![0, 18, 24, 42, 48, 66, 72]); + + let s_utf16: Vec<u16> = s.encode_utf16().collect(); + let result: Vec<usize> = segmenter.segment_utf16(&s_utf16).collect(); + assert_eq!(result, vec![0, 6, 8, 14, 16, 22, 24]); + } + + #[test] + fn cj_dictionary_test() { + let dict_payload: DataPayload<DictionaryForWordOnlyAutoV1Marker> = crate::provider::Baked + .load(DataRequest { + locale: &icu_locid::locale!("ja").into(), + metadata: Default::default(), + }) + .unwrap() + .take_payload() + .unwrap(); + let word_segmenter = WordSegmenter::new_dictionary(); + let dict_segmenter = DictionarySegmenter::new( + dict_payload.get(), + crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1, + ); + + // Match case + let s = "龟山岛龟山岛"; + let result: Vec<usize> = dict_segmenter.segment_str(s).collect(); + assert_eq!(result, vec![9, 18]); + + let result: Vec<usize> = word_segmenter.segment_str(s).collect(); + assert_eq!(result, vec![0, 9, 18]); + + let s_utf16: Vec<u16> = s.encode_utf16().collect(); + let result: Vec<usize> = dict_segmenter.segment_utf16(&s_utf16).collect(); + assert_eq!(result, vec![3, 6]); + + let result: Vec<usize> = 
word_segmenter.segment_utf16(&s_utf16).collect(); + assert_eq!(result, vec![0, 3, 6]); + + // Match case, then no match case + let s = "エディターエディ"; + let result: Vec<usize> = dict_segmenter.segment_str(s).collect(); + assert_eq!(result, vec![15, 24]); + + // TODO(#3236): Why is WordSegmenter not returning the middle segment? + let result: Vec<usize> = word_segmenter.segment_str(s).collect(); + assert_eq!(result, vec![0, 24]); + + let s_utf16: Vec<u16> = s.encode_utf16().collect(); + let result: Vec<usize> = dict_segmenter.segment_utf16(&s_utf16).collect(); + assert_eq!(result, vec![5, 8]); + + // TODO(#3236): Why is WordSegmenter not returning the middle segment? + let result: Vec<usize> = word_segmenter.segment_utf16(&s_utf16).collect(); + assert_eq!(result, vec![0, 8]); + } + + #[test] + fn khmer_dictionary_test() { + let segmenter = LineSegmenter::new_dictionary(); + let s = "ភាសាខ្មែរភាសាខ្មែរភាសាខ្មែរ"; + let result: Vec<usize> = segmenter.segment_str(s).collect(); + assert_eq!(result, vec![0, 27, 54, 81]); + + let s_utf16: Vec<u16> = s.encode_utf16().collect(); + let result: Vec<usize> = segmenter.segment_utf16(&s_utf16).collect(); + assert_eq!(result, vec![0, 9, 18, 27]); + } + + #[test] + fn lao_dictionary_test() { + let segmenter = LineSegmenter::new_dictionary(); + let s = "ພາສາລາວພາສາລາວພາສາລາວ"; + let r: Vec<usize> = segmenter.segment_str(s).collect(); + assert_eq!(r, vec![0, 12, 21, 33, 42, 54, 63]); + + let s_utf16: Vec<u16> = s.encode_utf16().collect(); + let r: Vec<usize> = segmenter.segment_utf16(&s_utf16).collect(); + assert_eq!(r, vec![0, 4, 7, 11, 14, 18, 21]); + } +} diff --git a/third_party/rust/icu_segmenter/src/complex/language.rs b/third_party/rust/icu_segmenter/src/complex/language.rs new file mode 100644 index 0000000000..327eea5e20 --- /dev/null +++ b/third_party/rust/icu_segmenter/src/complex/language.rs @@ -0,0 +1,161 @@ +// This file is part of ICU4X. 
For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +#[derive(PartialEq, Debug, Copy, Clone)] +pub(super) enum Language { + Burmese, + ChineseOrJapanese, + Khmer, + Lao, + Thai, + Unknown, +} + +// TODO: Use data provider +fn get_language(codepoint: u32) -> Language { + match codepoint { + 0xe01..=0xe7f => Language::Thai, + 0x0E80..=0x0EFF => Language::Lao, + 0x1000..=0x109f => Language::Burmese, + 0x1780..=0x17FF => Language::Khmer, + 0x19E0..=0x19FF => Language::Khmer, + 0x2E80..=0x2EFF => Language::ChineseOrJapanese, + 0x2F00..=0x2FDF => Language::ChineseOrJapanese, + 0x3040..=0x30FF => Language::ChineseOrJapanese, + 0x31F0..=0x31FF => Language::ChineseOrJapanese, + 0x32D0..=0x32FE => Language::ChineseOrJapanese, + 0x3400..=0x4DBF => Language::ChineseOrJapanese, + 0x4E00..=0x9FFF => Language::ChineseOrJapanese, + 0xa9e0..=0xa9ff => Language::Burmese, + 0xaa60..=0xaa7f => Language::Burmese, + 0xF900..=0xFAFF => Language::ChineseOrJapanese, + 0xFF66..=0xFF9D => Language::ChineseOrJapanese, + 0x16FE2..=0x16FE3 => Language::ChineseOrJapanese, + 0x16FF0..=0x16FF1 => Language::ChineseOrJapanese, + 0x1AFF0..=0x1B16F => Language::ChineseOrJapanese, + 0x1F200 => Language::ChineseOrJapanese, + 0x20000..=0x2FA1F => Language::ChineseOrJapanese, + 0x30000..=0x3134F => Language::ChineseOrJapanese, + _ => Language::Unknown, + } +} + +/// This struct is an iterator that returns the string per language from the +/// given string. 
+pub(super) struct LanguageIterator<'s> { + rest: &'s str, +} + +impl<'s> LanguageIterator<'s> { + pub(super) fn new(input: &'s str) -> Self { + Self { rest: input } + } +} + +impl<'s> Iterator for LanguageIterator<'s> { + type Item = (&'s str, Language); + + fn next(&mut self) -> Option<Self::Item> { + let mut indices = self.rest.char_indices(); + let lang = get_language(indices.next()?.1 as u32); + match indices.find(|&(_, ch)| get_language(ch as u32) != lang) { + Some((i, _)) => { + let (result, rest) = self.rest.split_at(i); + self.rest = rest; + Some((result, lang)) + } + None => Some((core::mem::take(&mut self.rest), lang)), + } + } +} + +pub(super) struct LanguageIteratorUtf16<'s> { + rest: &'s [u16], +} + +impl<'s> LanguageIteratorUtf16<'s> { + pub(super) fn new(input: &'s [u16]) -> Self { + Self { rest: input } + } +} + +impl<'s> Iterator for LanguageIteratorUtf16<'s> { + type Item = (&'s [u16], Language); + + fn next(&mut self) -> Option<Self::Item> { + let lang = get_language(*self.rest.first()? 
as u32); + match self + .rest + .iter() + .position(|&ch| get_language(ch as u32) != lang) + { + Some(i) => { + let (result, rest) = self.rest.split_at(i); + self.rest = rest; + Some((result, lang)) + } + None => Some((core::mem::take(&mut self.rest), lang)), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_thai_only() { + let s = "ภาษาไทยภาษาไทย"; + let utf16: Vec<u16> = s.encode_utf16().collect(); + let mut iter = LanguageIteratorUtf16::new(&utf16); + assert_eq!( + iter.next(), + Some((utf16.as_slice(), Language::Thai)), + "Thai language only with UTF-16" + ); + let mut iter = LanguageIterator::new(s); + assert_eq!( + iter.next(), + Some((s, Language::Thai)), + "Thai language only with UTF-8" + ); + assert_eq!(iter.next(), None, "Iterator for UTF-8 is finished"); + } + + #[test] + fn test_combine() { + const TEST_STR_THAI: &str = "ภาษาไทยภาษาไทย"; + const TEST_STR_BURMESE: &str = "ဗမာနွယ်ဘာသာစကားမျာ"; + let s = format!("{TEST_STR_THAI}{TEST_STR_BURMESE}"); + let utf16: Vec<u16> = s.encode_utf16().collect(); + let thai_utf16: Vec<u16> = TEST_STR_THAI.encode_utf16().collect(); + let burmese_utf16: Vec<u16> = TEST_STR_BURMESE.encode_utf16().collect(); + + let mut iter = LanguageIteratorUtf16::new(&utf16); + assert_eq!( + iter.next(), + Some((thai_utf16.as_slice(), Language::Thai)), + "Thai language with UTF-16 at first" + ); + assert_eq!( + iter.next(), + Some((burmese_utf16.as_slice(), Language::Burmese)), + "Burmese language with UTF-16 at second" + ); + assert_eq!(iter.next(), None, "Iterator for UTF-16 is finished"); + + let mut iter = LanguageIterator::new(&s); + assert_eq!( + iter.next(), + Some((TEST_STR_THAI, Language::Thai)), + "Thai language with UTF-8 at first" + ); + assert_eq!( + iter.next(), + Some((TEST_STR_BURMESE, Language::Burmese)), + "Burmese language with UTF-8 at second" + ); + assert_eq!(iter.next(), None, "Iterator for UTF-8 is finished"); + } +} diff --git 
a/third_party/rust/icu_segmenter/src/complex/lstm/matrix.rs b/third_party/rust/icu_segmenter/src/complex/lstm/matrix.rs new file mode 100644 index 0000000000..3cf5ce2e3c --- /dev/null +++ b/third_party/rust/icu_segmenter/src/complex/lstm/matrix.rs @@ -0,0 +1,540 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +use alloc::vec; +use alloc::vec::Vec; +use core::ops::Range; +#[allow(unused_imports)] +use core_maths::*; +use zerovec::ule::AsULE; +use zerovec::ZeroSlice; + +/// A `D`-dimensional, heap-allocated matrix. +/// +/// This matrix implementation supports slicing matrices into tightly-packed +/// submatrices. For example, indexing into a matrix of size 5x4x3 returns a +/// matrix of size 4x3. For more information, see [`MatrixOwned::submatrix`]. +#[derive(Debug, Clone)] +pub(super) struct MatrixOwned<const D: usize> { + data: Vec<f32>, + dims: [usize; D], +} + +impl<const D: usize> MatrixOwned<D> { + pub(super) fn as_borrowed(&self) -> MatrixBorrowed<D> { + MatrixBorrowed { + data: &self.data, + dims: self.dims, + } + } + + pub(super) fn new_zero(dims: [usize; D]) -> Self { + let total_len = dims.iter().product::<usize>(); + MatrixOwned { + data: vec![0.0; total_len], + dims, + } + } + + /// Returns the tighly packed submatrix at _index_, or `None` if _index_ is out of range. + /// + /// For example, if the matrix is 5x4x3, this function returns a matrix sized 4x3. If the + /// matrix is 4x3, then this function returns a linear matrix of length 3. + /// + /// The type parameter `M` should be `D - 1`. + #[inline] + pub(super) fn submatrix<const M: usize>(&self, index: usize) -> Option<MatrixBorrowed<M>> { + // This assertion is based on const generics; it should always succeed and be elided. 
+ assert_eq!(M, D - 1); + let (range, dims) = self.as_borrowed().submatrix_range(index); + let data = &self.data.get(range)?; + Some(MatrixBorrowed { data, dims }) + } + + pub(super) fn as_mut(&mut self) -> MatrixBorrowedMut<D> { + MatrixBorrowedMut { + data: &mut self.data, + dims: self.dims, + } + } + + /// A mutable version of [`Self::submatrix`]. + #[inline] + pub(super) fn submatrix_mut<const M: usize>( + &mut self, + index: usize, + ) -> Option<MatrixBorrowedMut<M>> { + // This assertion is based on const generics; it should always succeed and be elided. + assert_eq!(M, D - 1); + let (range, dims) = self.as_borrowed().submatrix_range(index); + let data = self.data.get_mut(range)?; + Some(MatrixBorrowedMut { data, dims }) + } +} + +/// A `D`-dimensional, borrowed matrix. +#[derive(Debug, Clone, Copy)] +pub(super) struct MatrixBorrowed<'a, const D: usize> { + data: &'a [f32], + dims: [usize; D], +} + +impl<'a, const D: usize> MatrixBorrowed<'a, D> { + #[cfg(debug_assertions)] + pub(super) fn debug_assert_dims(&self, dims: [usize; D]) { + debug_assert_eq!(dims, self.dims); + let expected_len = dims.iter().product::<usize>(); + debug_assert_eq!(expected_len, self.data.len()); + } + + pub(super) fn as_slice(&self) -> &'a [f32] { + self.data + } + + /// See [`MatrixOwned::submatrix`]. + #[inline] + pub(super) fn submatrix<const M: usize>(&self, index: usize) -> Option<MatrixBorrowed<'a, M>> { + // This assertion is based on const generics; it should always succeed and be elided. + assert_eq!(M, D - 1); + let (range, dims) = self.submatrix_range(index); + let data = &self.data.get(range)?; + Some(MatrixBorrowed { data, dims }) + } + + #[inline] + fn submatrix_range<const M: usize>(&self, index: usize) -> (Range<usize>, [usize; M]) { + // This assertion is based on const generics; it should always succeed and be elided. 
+ assert_eq!(M, D - 1); + // The above assertion guarantees that the following line will succeed + #[allow(clippy::indexing_slicing, clippy::unwrap_used)] + let sub_dims: [usize; M] = self.dims[1..].try_into().unwrap(); + let n = sub_dims.iter().product::<usize>(); + (n * index..n * (index + 1), sub_dims) + } +} + +macro_rules! impl_basic_dim { + ($t1:path, $t2:path, $t3:path) => { + impl<'a> $t1 { + #[allow(dead_code)] + pub(super) fn dim(&self) -> usize { + let [dim] = self.dims; + dim + } + } + impl<'a> $t2 { + #[allow(dead_code)] + pub(super) fn dim(&self) -> (usize, usize) { + let [d0, d1] = self.dims; + (d0, d1) + } + } + impl<'a> $t3 { + #[allow(dead_code)] + pub(super) fn dim(&self) -> (usize, usize, usize) { + let [d0, d1, d2] = self.dims; + (d0, d1, d2) + } + } + }; +} + +impl_basic_dim!(MatrixOwned<1>, MatrixOwned<2>, MatrixOwned<3>); +impl_basic_dim!( + MatrixBorrowed<'a, 1>, + MatrixBorrowed<'a, 2>, + MatrixBorrowed<'a, 3> +); +impl_basic_dim!( + MatrixBorrowedMut<'a, 1>, + MatrixBorrowedMut<'a, 2>, + MatrixBorrowedMut<'a, 3> +); +impl_basic_dim!(MatrixZero<'a, 1>, MatrixZero<'a, 2>, MatrixZero<'a, 3>); + +/// A `D`-dimensional, mutably borrowed matrix. 
+pub(super) struct MatrixBorrowedMut<'a, const D: usize> { + pub(super) data: &'a mut [f32], + pub(super) dims: [usize; D], +} + +impl<'a, const D: usize> MatrixBorrowedMut<'a, D> { + pub(super) fn as_borrowed(&self) -> MatrixBorrowed<D> { + MatrixBorrowed { + data: self.data, + dims: self.dims, + } + } + + pub(super) fn as_mut_slice(&mut self) -> &mut [f32] { + self.data + } + + pub(super) fn copy_submatrix<const M: usize>(&mut self, from: usize, to: usize) { + let (range_from, _) = self.as_borrowed().submatrix_range::<M>(from); + let (range_to, _) = self.as_borrowed().submatrix_range::<M>(to); + if let (Some(_), Some(_)) = ( + self.data.get(range_from.clone()), + self.data.get(range_to.clone()), + ) { + // This function is panicky, but we just validated the ranges + self.data.copy_within(range_from, range_to.start); + } + } + + #[must_use] + pub(super) fn add(&mut self, other: MatrixZero<'_, D>) -> Option<()> { + debug_assert_eq!(self.dims, other.dims); + // TODO: Vectorize? + for i in 0..self.data.len() { + *self.data.get_mut(i)? += other.data.get(i)?; + } + Some(()) + } + + #[allow(dead_code)] // maybe needed for more complicated bies calculations + /// Mutates this matrix by applying a softmax transformation. 
+ pub(super) fn softmax_transform(&mut self) { + for v in self.data.iter_mut() { + *v = v.exp(); + } + let sm = 1.0 / self.data.iter().sum::<f32>(); + for v in self.data.iter_mut() { + *v *= sm; + } + } + + pub(super) fn sigmoid_transform(&mut self) { + for x in &mut self.data.iter_mut() { + *x = 1.0 / (1.0 + (-*x).exp()); + } + } + + pub(super) fn tanh_transform(&mut self) { + for x in &mut self.data.iter_mut() { + *x = x.tanh(); + } + } + + pub(super) fn convolve( + &mut self, + i: MatrixBorrowed<'_, D>, + c: MatrixBorrowed<'_, D>, + f: MatrixBorrowed<'_, D>, + ) { + let i = i.as_slice(); + let c = c.as_slice(); + let f = f.as_slice(); + let len = self.data.len(); + if len != i.len() || len != c.len() || len != f.len() { + debug_assert!(false, "LSTM matrices not the correct dimensions"); + return; + } + for idx in 0..len { + // Safety: The lengths are all the same (checked above) + unsafe { + *self.data.get_unchecked_mut(idx) = i.get_unchecked(idx) * c.get_unchecked(idx) + + self.data.get_unchecked(idx) * f.get_unchecked(idx) + } + } + } + + pub(super) fn mul_tanh(&mut self, o: MatrixBorrowed<'_, D>, c: MatrixBorrowed<'_, D>) { + let o = o.as_slice(); + let c = c.as_slice(); + let len = self.data.len(); + if len != o.len() || len != c.len() { + debug_assert!(false, "LSTM matrices not the correct dimensions"); + return; + } + for idx in 0..len { + // Safety: The lengths are all the same (checked above) + unsafe { + *self.data.get_unchecked_mut(idx) = + o.get_unchecked(idx) * c.get_unchecked(idx).tanh(); + } + } + } +} + +impl<'a> MatrixBorrowed<'a, 1> { + #[allow(dead_code)] // could be useful + pub(super) fn dot_1d(&self, other: MatrixZero<1>) -> f32 { + debug_assert_eq!(self.dims, other.dims); + unrolled_dot_1(self.data, other.data) + } +} + +impl<'a> MatrixBorrowedMut<'a, 1> { + /// Calculate the dot product of a and b, adding the result to self. 
+ /// + /// Note: For better dot product efficiency, if `b` is MxN, then `a` should be N; + /// this is the opposite of standard practice. + pub(super) fn add_dot_2d(&mut self, a: MatrixBorrowed<1>, b: MatrixZero<2>) { + let m = a.dim(); + let n = self.as_borrowed().dim(); + debug_assert_eq!( + m, + b.dim().1, + "dims: {:?}/{:?}/{:?}", + self.as_borrowed().dim(), + a.dim(), + b.dim() + ); + debug_assert_eq!( + n, + b.dim().0, + "dims: {:?}/{:?}/{:?}", + self.as_borrowed().dim(), + a.dim(), + b.dim() + ); + for i in 0..n { + if let (Some(dest), Some(b_sub)) = (self.as_mut_slice().get_mut(i), b.submatrix::<1>(i)) + { + *dest += unrolled_dot_1(a.data, b_sub.data); + } else { + debug_assert!(false, "unreachable: dims checked above"); + } + } + } +} + +impl<'a> MatrixBorrowedMut<'a, 2> { + /// Calculate the dot product of a and b, adding the result to self. + /// + /// Self should be _MxN_; `a`, _O_; and `b`, _MxNxO_. + pub(super) fn add_dot_3d_1(&mut self, a: MatrixBorrowed<1>, b: MatrixZero<3>) { + let m = a.dim(); + let n = self.as_borrowed().dim().0 * self.as_borrowed().dim().1; + debug_assert_eq!( + m, + b.dim().2, + "dims: {:?}/{:?}/{:?}", + self.as_borrowed().dim(), + a.dim(), + b.dim() + ); + debug_assert_eq!( + n, + b.dim().0 * b.dim().1, + "dims: {:?}/{:?}/{:?}", + self.as_borrowed().dim(), + a.dim(), + b.dim() + ); + // Note: The following two loops are equivalent, but the second has more opportunity for + // vectorization since it allows the vectorization to span submatrices. + // for i in 0..b.dim().0 { + // self.submatrix_mut::<1>(i).add_dot_2d(a, b.submatrix(i)); + // } + let lhs = a.as_slice(); + for i in 0..n { + if let (Some(dest), Some(rhs)) = ( + self.as_mut_slice().get_mut(i), + b.as_slice().get_subslice(i * m..(i + 1) * m), + ) { + *dest += unrolled_dot_1(lhs, rhs); + } else { + debug_assert!(false, "unreachable: dims checked above"); + } + } + } + + /// Calculate the dot product of a and b, adding the result to self. 
+ /// + /// Self should be _MxN_; `a`, _O_; and `b`, _MxNxO_. + pub(super) fn add_dot_3d_2(&mut self, a: MatrixZero<1>, b: MatrixZero<3>) { + let m = a.dim(); + let n = self.as_borrowed().dim().0 * self.as_borrowed().dim().1; + debug_assert_eq!( + m, + b.dim().2, + "dims: {:?}/{:?}/{:?}", + self.as_borrowed().dim(), + a.dim(), + b.dim() + ); + debug_assert_eq!( + n, + b.dim().0 * b.dim().1, + "dims: {:?}/{:?}/{:?}", + self.as_borrowed().dim(), + a.dim(), + b.dim() + ); + // Note: The following two loops are equivalent, but the second has more opportunity for + // vectorization since it allows the vectorization to span submatrices. + // for i in 0..b.dim().0 { + // self.submatrix_mut::<1>(i).add_dot_2d(a, b.submatrix(i)); + // } + let lhs = a.as_slice(); + for i in 0..n { + if let (Some(dest), Some(rhs)) = ( + self.as_mut_slice().get_mut(i), + b.as_slice().get_subslice(i * m..(i + 1) * m), + ) { + *dest += unrolled_dot_2(lhs, rhs); + } else { + debug_assert!(false, "unreachable: dims checked above"); + } + } + } +} + +/// A `D`-dimensional matrix borrowed from a [`ZeroSlice`]. 
+#[derive(Debug, Clone, Copy)] +pub(super) struct MatrixZero<'a, const D: usize> { + data: &'a ZeroSlice<f32>, + dims: [usize; D], +} + +impl<'a> From<&'a crate::provider::LstmMatrix1<'a>> for MatrixZero<'a, 1> { + fn from(other: &'a crate::provider::LstmMatrix1<'a>) -> Self { + Self { + data: &other.data, + dims: other.dims.map(|x| x as usize), + } + } +} + +impl<'a> From<&'a crate::provider::LstmMatrix2<'a>> for MatrixZero<'a, 2> { + fn from(other: &'a crate::provider::LstmMatrix2<'a>) -> Self { + Self { + data: &other.data, + dims: other.dims.map(|x| x as usize), + } + } +} + +impl<'a> From<&'a crate::provider::LstmMatrix3<'a>> for MatrixZero<'a, 3> { + fn from(other: &'a crate::provider::LstmMatrix3<'a>) -> Self { + Self { + data: &other.data, + dims: other.dims.map(|x| x as usize), + } + } +} + +impl<'a, const D: usize> MatrixZero<'a, D> { + #[allow(clippy::wrong_self_convention)] // same convention as slice::to_vec + pub(super) fn to_owned(&self) -> MatrixOwned<D> { + MatrixOwned { + data: self.data.iter().collect(), + dims: self.dims, + } + } + + pub(super) fn as_slice(&self) -> &ZeroSlice<f32> { + self.data + } + + #[cfg(debug_assertions)] + pub(super) fn debug_assert_dims(&self, dims: [usize; D]) { + debug_assert_eq!(dims, self.dims); + let expected_len = dims.iter().product::<usize>(); + debug_assert_eq!(expected_len, self.data.len()); + } + + /// See [`MatrixOwned::submatrix`]. + #[inline] + pub(super) fn submatrix<const M: usize>(&self, index: usize) -> Option<MatrixZero<'a, M>> { + // This assertion is based on const generics; it should always succeed and be elided. + assert_eq!(M, D - 1); + let (range, dims) = self.submatrix_range(index); + let data = &self.data.get_subslice(range)?; + Some(MatrixZero { data, dims }) + } + + #[inline] + fn submatrix_range<const M: usize>(&self, index: usize) -> (Range<usize>, [usize; M]) { + // This assertion is based on const generics; it should always succeed and be elided. 
+ assert_eq!(M, D - 1); + // The above assertion guarantees that the following line will succeed + #[allow(clippy::indexing_slicing, clippy::unwrap_used)] + let sub_dims: [usize; M] = self.dims[1..].try_into().unwrap(); + let n = sub_dims.iter().product::<usize>(); + (n * index..n * (index + 1), sub_dims) + } +} + +macro_rules! f32c { + ($ule:expr) => { + f32::from_unaligned($ule) + }; +} + +/// Compute the dot product of an aligned and an unaligned f32 slice. +/// +/// `xs` and `ys` must be the same length +/// +/// (Based on ndarray 0.15.6) +fn unrolled_dot_1(xs: &[f32], ys: &ZeroSlice<f32>) -> f32 { + debug_assert_eq!(xs.len(), ys.len()); + // eightfold unrolled so that floating point can be vectorized + // (even with strict floating point accuracy semantics) + let mut p = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0); + let xit = xs.chunks_exact(8); + let yit = ys.as_ule_slice().chunks_exact(8); + let sum = xit + .remainder() + .iter() + .zip(yit.remainder().iter()) + .map(|(x, y)| x * f32c!(*y)) + .sum::<f32>(); + for (xx, yy) in xit.zip(yit) { + // TODO: Use array_chunks once stable to avoid the unwrap. + // <https://github.com/rust-lang/rust/issues/74985> + #[allow(clippy::unwrap_used)] + let [x0, x1, x2, x3, x4, x5, x6, x7] = *<&[f32; 8]>::try_from(xx).unwrap(); + #[allow(clippy::unwrap_used)] + let [y0, y1, y2, y3, y4, y5, y6, y7] = *<&[<f32 as AsULE>::ULE; 8]>::try_from(yy).unwrap(); + p.0 += x0 * f32c!(y0); + p.1 += x1 * f32c!(y1); + p.2 += x2 * f32c!(y2); + p.3 += x3 * f32c!(y3); + p.4 += x4 * f32c!(y4); + p.5 += x5 * f32c!(y5); + p.6 += x6 * f32c!(y6); + p.7 += x7 * f32c!(y7); + } + sum + (p.0 + p.4) + (p.1 + p.5) + (p.2 + p.6) + (p.3 + p.7) +} + +/// Compute the dot product of two unaligned f32 slices. 
+/// +/// `xs` and `ys` must be the same length +/// +/// (Based on ndarray 0.15.6) +fn unrolled_dot_2(xs: &ZeroSlice<f32>, ys: &ZeroSlice<f32>) -> f32 { + debug_assert_eq!(xs.len(), ys.len()); + // eightfold unrolled so that floating point can be vectorized + // (even with strict floating point accuracy semantics) + let mut p = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0); + let xit = xs.as_ule_slice().chunks_exact(8); + let yit = ys.as_ule_slice().chunks_exact(8); + let sum = xit + .remainder() + .iter() + .zip(yit.remainder().iter()) + .map(|(x, y)| f32c!(*x) * f32c!(*y)) + .sum::<f32>(); + for (xx, yy) in xit.zip(yit) { + // TODO: Use array_chunks once stable to avoid the unwrap. + // <https://github.com/rust-lang/rust/issues/74985> + #[allow(clippy::unwrap_used)] + let [x0, x1, x2, x3, x4, x5, x6, x7] = *<&[<f32 as AsULE>::ULE; 8]>::try_from(xx).unwrap(); + #[allow(clippy::unwrap_used)] + let [y0, y1, y2, y3, y4, y5, y6, y7] = *<&[<f32 as AsULE>::ULE; 8]>::try_from(yy).unwrap(); + p.0 += f32c!(x0) * f32c!(y0); + p.1 += f32c!(x1) * f32c!(y1); + p.2 += f32c!(x2) * f32c!(y2); + p.3 += f32c!(x3) * f32c!(y3); + p.4 += f32c!(x4) * f32c!(y4); + p.5 += f32c!(x5) * f32c!(y5); + p.6 += f32c!(x6) * f32c!(y6); + p.7 += f32c!(x7) * f32c!(y7); + } + sum + (p.0 + p.4) + (p.1 + p.5) + (p.2 + p.6) + (p.3 + p.7) +} diff --git a/third_party/rust/icu_segmenter/src/complex/lstm/mod.rs b/third_party/rust/icu_segmenter/src/complex/lstm/mod.rs new file mode 100644 index 0000000000..8718cbd3da --- /dev/null +++ b/third_party/rust/icu_segmenter/src/complex/lstm/mod.rs @@ -0,0 +1,402 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). 

use crate::grapheme::GraphemeClusterSegmenter;
use crate::provider::*;
use alloc::vec::Vec;
use core::char::{decode_utf16, REPLACEMENT_CHARACTER};
use zerovec::{maps::ZeroMapBorrowed, ule::UnvalidatedStr};

mod matrix;
use matrix::*;

// A word break iterator using an LSTM model. The input string has to be in a
// single language (mixed-language text is split before reaching this code).

/// Adapts the per-element BIES decisions of [`BiesIterator`] into UTF-8 byte
/// offsets of word boundaries in `input`.
struct LstmSegmenterIterator<'s> {
    input: &'s str,
    // Byte offset just past the characters consumed so far; always a char boundary.
    pos_utf8: usize,
    bies: BiesIterator<'s>,
}

impl Iterator for LstmSegmenterIterator<'_> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        #[allow(clippy::indexing_slicing)] // pos_utf8 in range
        loop {
            // `true` means the model tagged this element as the End of a word.
            let is_e = self.bies.next()?;
            // Advance past the next character of the input.
            self.pos_utf8 += self.input[self.pos_utf8..].chars().next()?.len_utf8();
            // A boundary is reported on an E tag, or unconditionally at end of input.
            if is_e || self.bies.len() == 0 {
                return Some(self.pos_utf8);
            }
        }
    }
}

/// UTF-16 counterpart of [`LstmSegmenterIterator`].
struct LstmSegmenterIteratorUtf16<'s> {
    bies: BiesIterator<'s>,
    // NOTE(review): `pos` counts model elements (one per decoded scalar), not
    // UTF-16 code units — confirm callers expect this for supplementary-plane text.
    pos: usize,
}

impl Iterator for LstmSegmenterIteratorUtf16<'_> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            self.pos += 1;
            if self.bies.next()? || self.bies.len() == 0 {
                return Some(self.pos);
            }
        }
    }
}

/// An LSTM-based word segmenter borrowing its weights from a data payload.
pub(super) struct LstmSegmenter<'l> {
    /// Maps a grapheme cluster (or single code point) to its embedding row id.
    dic: ZeroMapBorrowed<'l, UnvalidatedStr, u16>,
    embedding: MatrixZero<'l, 2>,
    // Forward-direction LSTM weights (input / recurrent / bias — see `compute_hc`).
    fw_w: MatrixZero<'l, 3>,
    fw_u: MatrixZero<'l, 3>,
    fw_b: MatrixZero<'l, 2>,
    // Backward-direction LSTM weights.
    bw_w: MatrixZero<'l, 3>,
    bw_u: MatrixZero<'l, 3>,
    bw_b: MatrixZero<'l, 2>,
    // Output layer, split into the forward and backward halves of `time_w`.
    timew_fw: MatrixZero<'l, 2>,
    timew_bw: MatrixZero<'l, 2>,
    time_b: MatrixZero<'l, 1>,
    /// `Some` iff the model operates on grapheme clusters rather than code points.
    grapheme: Option<&'l RuleBreakDataV1<'l>>,
}

impl<'l> LstmSegmenter<'l> {
    /// Constructs a segmenter from an LSTM payload and grapheme-cluster rules.
    /// The grapheme data is only retained for models trained on grapheme clusters.
    pub(super) fn new(lstm: &'l LstmDataV1<'l>, grapheme: &'l RuleBreakDataV1<'l>) -> Self {
        let LstmDataV1::Float32(lstm) = lstm;
        let time_w = MatrixZero::from(&lstm.time_w);
        #[allow(clippy::unwrap_used)] // shape (2, 4, hunits)
        let timew_fw = time_w.submatrix(0).unwrap();
        #[allow(clippy::unwrap_used)] // shape (2, 4, hunits)
        let timew_bw = time_w.submatrix(1).unwrap();
        Self {
            dic: lstm.dic.as_borrowed(),
            embedding: MatrixZero::from(&lstm.embedding),
            fw_w: MatrixZero::from(&lstm.fw_w),
            fw_u: MatrixZero::from(&lstm.fw_u),
            fw_b: MatrixZero::from(&lstm.fw_b),
            bw_w: MatrixZero::from(&lstm.bw_w),
            bw_u: MatrixZero::from(&lstm.bw_u),
            bw_b: MatrixZero::from(&lstm.bw_b),
            timew_fw,
            timew_bw,
            time_b: MatrixZero::from(&lstm.time_b),
            grapheme: (lstm.model == ModelType::GraphemeClusters).then_some(grapheme),
        }
    }

    /// Create an LSTM based break iterator for an `str` (a UTF-8 string).
    pub(super) fn segment_str(&'l self, input: &'l str) -> impl Iterator<Item = usize> + 'l {
        self.segment_str_p(input)
    }

    // For unit testing as we cannot inspect the opaque type's bies
    fn segment_str_p(&'l self, input: &'l str) -> LstmSegmenterIterator<'l> {
        // Map the input to embedding ids. Unknown elements map to `dic.len()`,
        // the out-of-vocabulary row of the embedding matrix (which has
        // `dic.len() + 1` rows — see the shape comments in `BiesIterator`).
        let input_seq = if let Some(grapheme) = self.grapheme {
            GraphemeClusterSegmenter::new_and_segment_str(input, grapheme)
                .collect::<Vec<usize>>()
                .windows(2)
                .map(|chunk| {
                    let range = if let [first, second, ..] = chunk {
                        *first..*second
                    } else {
                        unreachable!()
                    };
                    let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) {
                        grapheme_cluster
                    } else {
                        return self.dic.len() as u16;
                    };

                    self.dic
                        .get_copied(UnvalidatedStr::from_str(grapheme_cluster))
                        .unwrap_or_else(|| self.dic.len() as u16)
                })
                .collect()
        } else {
            input
                .chars()
                .map(|c| {
                    self.dic
                        .get_copied(UnvalidatedStr::from_str(c.encode_utf8(&mut [0; 4])))
                        .unwrap_or_else(|| self.dic.len() as u16)
                })
                .collect()
        };
        LstmSegmenterIterator {
            input,
            pos_utf8: 0,
            bies: BiesIterator::new(self, input_seq),
        }
    }

    /// Create an LSTM based break iterator for a UTF-16 string.
    pub(super) fn segment_utf16(&'l self, input: &[u16]) -> impl Iterator<Item = usize> + 'l {
        let input_seq = if let Some(grapheme) = self.grapheme {
            GraphemeClusterSegmenter::new_and_segment_utf16(input, grapheme)
                .collect::<Vec<usize>>()
                .windows(2)
                .map(|chunk| {
                    let range = if let [first, second, ..] = chunk {
                        *first..*second
                    } else {
                        unreachable!()
                    };
                    let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) {
                        grapheme_cluster
                    } else {
                        return self.dic.len() as u16;
                    };

                    // Compare the UTF-8 dictionary key byte-wise against the
                    // UTF-8 encoding of the decoded UTF-16 cluster, avoiding an
                    // intermediate String allocation.
                    self.dic
                        .get_copied_by(|key| {
                            key.as_bytes().iter().copied().cmp(
                                decode_utf16(grapheme_cluster.iter().copied()).flat_map(|c| {
                                    let mut buf = [0; 4];
                                    let len = c
                                        .unwrap_or(REPLACEMENT_CHARACTER)
                                        .encode_utf8(&mut buf)
                                        .len();
                                    buf.into_iter().take(len)
                                }),
                            )
                        })
                        .unwrap_or_else(|| self.dic.len() as u16)
                })
                .collect()
        } else {
            decode_utf16(input.iter().copied())
                .map(|c| c.unwrap_or(REPLACEMENT_CHARACTER))
                .map(|c| {
                    self.dic
                        .get_copied(UnvalidatedStr::from_str(c.encode_utf8(&mut [0; 4])))
                        .unwrap_or_else(|| self.dic.len() as u16)
                })
                .collect()
        };
        LstmSegmenterIteratorUtf16 {
            bies: BiesIterator::new(self, input_seq),
            pos: 0,
        }
    }
}

/// Iterates over the model's BIES decisions, yielding `true` for elements
/// tagged `E` (end of word) and `false` otherwise.
struct BiesIterator<'l> {
    segmenter: &'l LstmSegmenter<'l>,
    input_seq: core::iter::Enumerate<alloc::vec::IntoIter<u16>>,
    // Precomputed backward-pass hidden states, one row per input element.
    h_bw: MatrixOwned<2>,
    // Forward-pass hidden and cell state, updated incrementally in `next()`.
    curr_fw: MatrixOwned<1>,
    c_fw: MatrixOwned<1>,
}

impl<'l> BiesIterator<'l> {
    // input_seq is a sequence of id numbers that represent grapheme clusters or code points in the input line. These ids are used later
    // in the embedding layer of the model.
    /// Runs the entire backward LSTM pass eagerly, storing one hidden-state row
    /// per element in `h_bw`; the forward pass is then computed lazily, one
    /// step per call to `next()`.
    fn new(segmenter: &'l LstmSegmenter<'l>, input_seq: Vec<u16>) -> Self {
        let hunits = segmenter.fw_u.dim().1;

        // Backward LSTM
        let mut c_bw = MatrixOwned::<1>::new_zero([hunits]);
        let mut h_bw = MatrixOwned::<2>::new_zero([input_seq.len(), hunits]);
        for (i, &g_id) in input_seq.iter().enumerate().rev() {
            // Seed step i's hidden state with step i+1's (the previous step in
            // backward order) before updating it in place.
            if i + 1 < input_seq.len() {
                h_bw.as_mut().copy_submatrix::<1>(i + 1, i);
            }
            #[allow(clippy::unwrap_used)]
            compute_hc(
                segmenter.embedding.submatrix::<1>(g_id as usize).unwrap(), /* shape (dict.len() + 1, hunit), g_id is at most dict.len() */
                h_bw.submatrix_mut(i).unwrap(), // shape (input_seq.len(), hunits)
                c_bw.as_mut(),
                segmenter.bw_w,
                segmenter.bw_u,
                segmenter.bw_b,
            );
        }

        Self {
            input_seq: input_seq.into_iter().enumerate(),
            h_bw,
            c_fw: MatrixOwned::<1>::new_zero([hunits]),
            curr_fw: MatrixOwned::<1>::new_zero([hunits]),
            segmenter,
        }
    }
}

impl ExactSizeIterator for BiesIterator<'_> {
    // Number of elements not yet consumed; used by the callers to detect end of input.
    fn len(&self) -> usize {
        self.input_seq.len()
    }
}

impl Iterator for BiesIterator<'_> {
    type Item = bool;

    fn next(&mut self) -> Option<Self::Item> {
        let (i, g_id) = self.input_seq.next()?;

        // One forward-LSTM step for this element, updating curr_fw/c_fw in place.
        #[allow(clippy::unwrap_used)]
        compute_hc(
            self.segmenter
                .embedding
                .submatrix::<1>(g_id as usize)
                .unwrap(), // shape (dict.len() + 1, hunit), g_id is at most dict.len()
            self.curr_fw.as_mut(),
            self.c_fw.as_mut(),
            self.segmenter.fw_w,
            self.segmenter.fw_u,
            self.segmenter.fw_b,
        );

        #[allow(clippy::unwrap_used)] // shape (input_seq.len(), hunits)
        let curr_bw = self.h_bw.submatrix::<1>(i).unwrap();
        // Combine forward and (precomputed) backward states through the output
        // layer to obtain the four B/I/E/S scores.
        let mut weights = [0.0; 4];
        let mut curr_est = MatrixBorrowedMut {
            data: &mut weights,
            dims: [4],
        };
        curr_est.add_dot_2d(self.curr_fw.as_borrowed(), self.segmenter.timew_fw);
        curr_est.add_dot_2d(curr_bw, self.segmenter.timew_bw);
        #[allow(clippy::unwrap_used)] // both shape (4)
        curr_est.add(self.segmenter.time_b).unwrap();
        // For correct BIES weight calculation we'd now have to apply softmax, however
        // we're only doing a naive argmax, so a monotonic function doesn't make a difference.

        // `true` iff `E` (index 2) has the strictly largest score.
        Some(weights[2] > weights[0] && weights[2] > weights[1] && weights[2] > weights[3])
    }
}

/// `compute_hc` implements the evaluation of one LSTM cell for one time step.
///
/// * `x_t` — embedding of the current element
/// * `h_tm1` / `c_tm1` — hidden and cell state of the previous step, updated in place
/// * `w` / `u` / `b` — input weights, recurrent weights and bias, each stacked
///   as four gate slices along the first dimension
fn compute_hc<'a>(
    x_t: MatrixZero<'a, 1>,
    mut h_tm1: MatrixBorrowedMut<'a, 1>,
    mut c_tm1: MatrixBorrowedMut<'a, 1>,
    w: MatrixZero<'a, 3>,
    u: MatrixZero<'a, 3>,
    b: MatrixZero<'a, 2>,
) {
    #[cfg(debug_assertions)]
    {
        let hunits = h_tm1.dim();
        let embedd_dim = x_t.dim();
        c_tm1.as_borrowed().debug_assert_dims([hunits]);
        w.debug_assert_dims([4, hunits, embedd_dim]);
        u.debug_assert_dims([4, hunits, hunits]);
        b.debug_assert_dims([4, hunits]);
    }

    // s_t = b + w·x_t + u·h_{t-1}, for all four gates at once.
    let mut s_t = b.to_owned();

    s_t.as_mut().add_dot_3d_2(x_t, w);
    s_t.as_mut().add_dot_3d_1(h_tm1.as_borrowed(), u);

    // Gate activations: sigmoid for slices 0, 1 and 3, tanh for slice 2.
    #[allow(clippy::unwrap_used)] // first dimension is 4
    s_t.submatrix_mut::<1>(0).unwrap().sigmoid_transform();
    #[allow(clippy::unwrap_used)] // first dimension is 4
    s_t.submatrix_mut::<1>(1).unwrap().sigmoid_transform();
    #[allow(clippy::unwrap_used)] // first dimension is 4
    s_t.submatrix_mut::<1>(2).unwrap().tanh_transform();
    #[allow(clippy::unwrap_used)] // first dimension is 4
    s_t.submatrix_mut::<1>(3).unwrap().sigmoid_transform();

    // c_t = gate(0) * slice(2) + gate(1) * c_{t-1} (combined in `convolve`).
    #[allow(clippy::unwrap_used)] // first dimension is 4
    c_tm1.convolve(
        s_t.as_borrowed().submatrix(0).unwrap(),
        s_t.as_borrowed().submatrix(2).unwrap(),
        s_t.as_borrowed().submatrix(1).unwrap(),
    );

    // h_t = gate(3) * tanh(c_t).
    #[allow(clippy::unwrap_used)] // first dimension is 4
    h_tm1.mul_tanh(s_t.as_borrowed().submatrix(3).unwrap(), c_tm1.as_borrowed());
}

#[cfg(test)]
mod tests {
    use super::*;
    use icu_locid::locale;
    use icu_provider::prelude::*;
    use serde::Deserialize;
    use std::fs::File;
    use std::io::BufReader;

    /// `TestCase` is a struct used to store a single test case.
    /// Each test case has the attributes `unseg` (the unsegmented line),
    /// `expected_bies` (the BIES sequence the model is expected to produce),
    /// and `true_bies` (the BIES sequence representing the true segmentation).
    #[derive(PartialEq, Debug, Deserialize)]
    struct TestCase {
        unseg: String,
        expected_bies: String,
        true_bies: String,
    }

    /// `TestTextData` is a struct to store a vector of `TestCase` that represents a test text.
    #[derive(PartialEq, Debug, Deserialize)]
    struct TestTextData {
        testcases: Vec<TestCase>,
    }

    #[derive(Debug)]
    struct TestText {
        data: TestTextData,
    }

    /// Loads a JSON test file into `TestTextData`; panics on missing file or bad JSON.
    fn load_test_text(filename: &str) -> TestTextData {
        let file = File::open(filename).expect("File should be present");
        let reader = BufReader::new(file);
        serde_json::from_reader(reader).expect("JSON syntax error")
    }

    #[test]
    fn segment_file_by_lstm() {
        // Build a Thai LSTM segmenter from the baked provider.
        let lstm: DataPayload<LstmForWordLineAutoV1Marker> = crate::provider::Baked
            .load(DataRequest {
                locale: &locale!("th").into(),
                metadata: Default::default(),
            })
            .unwrap()
            .take_payload()
            .unwrap();
        let lstm = LstmSegmenter::new(
            lstm.get(),
            crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
        );

        // Importing the test data
        let test_text_data = load_test_text(&format!(
            "tests/testdata/test_text_{}.json",
            if lstm.grapheme.is_some() {
                "grapheme"
            } else {
                "codepoints"
            }
        ));
        let test_text = TestText {
            data: test_text_data,
        };

        // Testing: the iterator only exposes E-vs-not-E, so non-E tags in the
        // expectation are normalized to '?' before comparing.
        for test_case in &test_text.data.testcases {
            let lstm_output = lstm
                .segment_str_p(&test_case.unseg)
                .bies
                .map(|is_e| if is_e { 'e' } else { '?' })
                .collect::<String>();
            println!("Test case : {}", test_case.unseg);
            println!("Expected bies : {}", test_case.expected_bies);
            println!("Estimated bies : {lstm_output}");
            println!("True bies : {}", test_case.true_bies);
            println!("****************************************************");
            assert_eq!(
                test_case.expected_bies.replace(['b', 'i', 's'], "?"),
                lstm_output
            );
        }
    }
}
diff --git a/third_party/rust/icu_segmenter/src/complex/mod.rs b/third_party/rust/icu_segmenter/src/complex/mod.rs
new file mode 100644
index 0000000000..65f49a92f0
--- /dev/null
+++ b/third_party/rust/icu_segmenter/src/complex/mod.rs
@@ -0,0 +1,440 @@
// This file is part of ICU4X. For terms of use, please see the file
// called LICENSE at the top level of the ICU4X source tree
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).

use crate::provider::*;
use alloc::vec::Vec;
use icu_locid::{locale, Locale};
use icu_provider::prelude::*;

mod dictionary;
use dictionary::*;
mod language;
use language::*;
#[cfg(feature = "lstm")]
mod lstm;
#[cfg(feature = "lstm")]
use lstm::*;

// A per-language engine: `Ok` holds a dictionary payload, `Err` holds an LSTM
// payload (or `Infallible` when the "lstm" feature is disabled, making the
// `Err` arm statically unreachable).
#[cfg(not(feature = "lstm"))]
type DictOrLstm = Result<DataPayload<UCharDictionaryBreakDataV1Marker>, core::convert::Infallible>;
#[cfg(not(feature = "lstm"))]
type DictOrLstmBorrowed<'a> =
    Result<&'a DataPayload<UCharDictionaryBreakDataV1Marker>, &'a core::convert::Infallible>;

#[cfg(feature = "lstm")]
type DictOrLstm =
    Result<DataPayload<UCharDictionaryBreakDataV1Marker>, DataPayload<LstmDataV1Marker>>;
#[cfg(feature = "lstm")]
type DictOrLstmBorrowed<'a> =
    Result<&'a DataPayload<UCharDictionaryBreakDataV1Marker>, &'a DataPayload<LstmDataV1Marker>>;

/// Holds the grapheme data plus one optional segmentation engine per
/// complex-script language (Burmese, Khmer, Lao, Thai, Japanese/Chinese).
#[derive(Debug)]
pub(crate) struct ComplexPayloads {
    grapheme: DataPayload<GraphemeClusterBreakDataV1Marker>,
    my: Option<DictOrLstm>,
    km: Option<DictOrLstm>,
    lo: Option<DictOrLstm>,
    th: Option<DictOrLstm>,
    ja: Option<DataPayload<UCharDictionaryBreakDataV1Marker>>,
}

impl
ComplexPayloads {
    /// Returns the engine for `language`, or `None` if no model was loaded.
    /// A missing model for a known language is surfaced via a custom `DataError`
    /// before returning `None`.
    fn select(&self, language: Language) -> Option<DictOrLstmBorrowed> {
        const ERR: DataError = DataError::custom("No segmentation model for language");
        match language {
            // NOTE(review): `with_display_context` appears to be used for its
            // reporting side effect; the returned error value is discarded.
            Language::Burmese => self.my.as_ref().map(Result::as_ref).or_else(|| {
                ERR.with_display_context("my");
                None
            }),
            Language::Khmer => self.km.as_ref().map(Result::as_ref).or_else(|| {
                ERR.with_display_context("km");
                None
            }),
            Language::Lao => self.lo.as_ref().map(Result::as_ref).or_else(|| {
                ERR.with_display_context("lo");
                None
            }),
            Language::Thai => self.th.as_ref().map(Result::as_ref).or_else(|| {
                ERR.with_display_context("th");
                None
            }),
            Language::ChineseOrJapanese => self.ja.as_ref().map(Ok).or_else(|| {
                ERR.with_display_context("ja");
                None
            }),
            Language::Unknown => None,
        }
    }

    /// LSTM models for my/km/lo/th from compiled data; no Japanese dictionary.
    #[cfg(feature = "lstm")]
    #[cfg(feature = "compiled_data")]
    pub(crate) fn new_lstm() -> Self {
        #[allow(clippy::unwrap_used)]
        // try_load is infallible if the provider only returns `MissingLocale`.
        Self {
            grapheme: DataPayload::from_static_ref(
                crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
            ),
            my: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("my"))
                .unwrap()
                .map(DataPayload::cast)
                .map(Err),
            km: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("km"))
                .unwrap()
                .map(DataPayload::cast)
                .map(Err),
            lo: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("lo"))
                .unwrap()
                .map(DataPayload::cast)
                .map(Err),
            th: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("th"))
                .unwrap()
                .map(DataPayload::cast)
                .map(Err),
            ja: None,
        }
    }

    /// Same as [`Self::new_lstm`] but loading from a caller-supplied provider.
    #[cfg(feature = "lstm")]
    pub(crate) fn try_new_lstm<D>(provider: &D) -> Result<Self, DataError>
    where
        D: DataProvider<GraphemeClusterBreakDataV1Marker>
            + DataProvider<LstmForWordLineAutoV1Marker>
            + ?Sized,
    {
        Ok(Self {
            grapheme: provider.load(Default::default())?.take_payload()?,
            my: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("my"))?
                .map(DataPayload::cast)
                .map(Err),
            km: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("km"))?
                .map(DataPayload::cast)
                .map(Err),
            lo: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("lo"))?
                .map(DataPayload::cast)
                .map(Err),
            th: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("th"))?
                .map(DataPayload::cast)
                .map(Err),
            ja: None,
        })
    }

    /// Dictionary models for my/km/lo/th plus the Japanese dictionary, from compiled data.
    #[cfg(feature = "compiled_data")]
    pub(crate) fn new_dict() -> Self {
        #[allow(clippy::unwrap_used)]
        // try_load is infallible if the provider only returns `MissingLocale`.
        Self {
            grapheme: DataPayload::from_static_ref(
                crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
            ),
            my: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
                &crate::provider::Baked,
                locale!("my"),
            )
            .unwrap()
            .map(DataPayload::cast)
            .map(Ok),
            km: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
                &crate::provider::Baked,
                locale!("km"),
            )
            .unwrap()
            .map(DataPayload::cast)
            .map(Ok),
            lo: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
                &crate::provider::Baked,
                locale!("lo"),
            )
            .unwrap()
            .map(DataPayload::cast)
            .map(Ok),
            th: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
                &crate::provider::Baked,
                locale!("th"),
            )
            .unwrap()
            .map(DataPayload::cast)
            .map(Ok),
            ja: try_load::<DictionaryForWordOnlyAutoV1Marker, _>(
                &crate::provider::Baked,
                locale!("ja"),
            )
            .unwrap()
            .map(DataPayload::cast),
        }
    }

    /// Same as [`Self::new_dict`] but loading from a caller-supplied provider.
    pub(crate) fn try_new_dict<D>(provider: &D) -> Result<Self, DataError>
    where
        D: DataProvider<GraphemeClusterBreakDataV1Marker>
            + DataProvider<DictionaryForWordLineExtendedV1Marker>
            + DataProvider<DictionaryForWordOnlyAutoV1Marker>
            + ?Sized,
    {
        Ok(Self {
            grapheme: provider.load(Default::default())?.take_payload()?,
            my: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, locale!("my"))?
                .map(DataPayload::cast)
                .map(Ok),
            km: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, locale!("km"))?
                .map(DataPayload::cast)
                .map(Ok),
            lo: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, locale!("lo"))?
                .map(DataPayload::cast)
                .map(Ok),
            th: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, locale!("th"))?
                .map(DataPayload::cast)
                .map(Ok),
            ja: try_load::<DictionaryForWordOnlyAutoV1Marker, D>(provider, locale!("ja"))?
                .map(DataPayload::cast),
        })
    }

    /// LSTM models for my/km/lo/th plus the Japanese dictionary, from compiled data.
    #[cfg(feature = "auto")] // Use by WordSegmenter with "auto" enabled.
    #[cfg(feature = "compiled_data")]
    pub(crate) fn new_auto() -> Self {
        #[allow(clippy::unwrap_used)]
        // try_load is infallible if the provider only returns `MissingLocale`.
        Self {
            grapheme: DataPayload::from_static_ref(
                crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
            ),
            my: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("my"))
                .unwrap()
                .map(DataPayload::cast)
                .map(Err),
            km: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("km"))
                .unwrap()
                .map(DataPayload::cast)
                .map(Err),
            lo: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("lo"))
                .unwrap()
                .map(DataPayload::cast)
                .map(Err),
            th: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("th"))
                .unwrap()
                .map(DataPayload::cast)
                .map(Err),
            ja: try_load::<DictionaryForWordOnlyAutoV1Marker, _>(
                &crate::provider::Baked,
                locale!("ja"),
            )
            .unwrap()
            .map(DataPayload::cast),
        }
    }

    /// Same as [`Self::new_auto`] but loading from a caller-supplied provider.
    #[cfg(feature = "auto")] // Use by WordSegmenter with "auto" enabled.
    pub(crate) fn try_new_auto<D>(provider: &D) -> Result<Self, DataError>
    where
        D: DataProvider<GraphemeClusterBreakDataV1Marker>
            + DataProvider<LstmForWordLineAutoV1Marker>
            + DataProvider<DictionaryForWordOnlyAutoV1Marker>
            + ?Sized,
    {
        Ok(Self {
            grapheme: provider.load(Default::default())?.take_payload()?,
            my: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("my"))?
                .map(DataPayload::cast)
                .map(Err),
            km: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("km"))?
                .map(DataPayload::cast)
                .map(Err),
            lo: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("lo"))?
                .map(DataPayload::cast)
                .map(Err),
            th: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("th"))?
                .map(DataPayload::cast)
                .map(Err),
            ja: try_load::<DictionaryForWordOnlyAutoV1Marker, D>(provider, locale!("ja"))?
                .map(DataPayload::cast),
        })
    }

    /// Dictionary models for my/km/lo/th only (no Japanese), from compiled data.
    #[cfg(feature = "compiled_data")]
    pub(crate) fn new_southeast_asian() -> Self {
        #[allow(clippy::unwrap_used)]
        // try_load is infallible if the provider only returns `MissingLocale`.
        Self {
            grapheme: DataPayload::from_static_ref(
                crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
            ),
            my: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
                &crate::provider::Baked,
                locale!("my"),
            )
            .unwrap()
            .map(DataPayload::cast)
            .map(Ok),
            km: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
                &crate::provider::Baked,
                locale!("km"),
            )
            .unwrap()
            .map(DataPayload::cast)
            .map(Ok),
            lo: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
                &crate::provider::Baked,
                locale!("lo"),
            )
            .unwrap()
            .map(DataPayload::cast)
            .map(Ok),
            th: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
                &crate::provider::Baked,
                locale!("th"),
            )
            .unwrap()
            .map(DataPayload::cast)
            .map(Ok),
            ja: None,
        }
    }

    /// Same as [`Self::new_southeast_asian`] but loading from a caller-supplied provider.
    pub(crate) fn try_new_southeast_asian<D>(provider: &D) -> Result<Self, DataError>
    where
        D: DataProvider<DictionaryForWordLineExtendedV1Marker>
            + DataProvider<GraphemeClusterBreakDataV1Marker>
            + ?Sized,
    {
        Ok(Self {
            grapheme: provider.load(Default::default())?.take_payload()?,
            my: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, locale!("my"))?
                .map(DataPayload::cast)
                .map(Ok),
            km: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, locale!("km"))?
                .map(DataPayload::cast)
                .map(Ok),
            lo: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, locale!("lo"))?
                .map(DataPayload::cast)
                .map(Ok),
            th: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, locale!("th"))?
                .map(DataPayload::cast)
                .map(Ok),
            ja: None,
        })
    }
}

/// Loads `M` for `locale`, translating `MissingLocale` into `Ok(None)` so that
/// absent per-language models are not treated as hard errors.
fn try_load<M: KeyedDataMarker, P: DataProvider<M> + ?Sized>(
    provider: &P,
    locale: Locale,
) -> Result<Option<DataPayload<M>>, DataError> {
    match provider.load(DataRequest {
        locale: &DataLocale::from(locale),
        metadata: {
            // Silent: a missing locale is an expected, recoverable outcome here.
            let mut m = DataRequestMetadata::default();
            m.silent = true;
            m
        },
    }) {
        Ok(response) => Ok(Some(response.take_payload()?)),
        Err(DataError {
            kind: DataErrorKind::MissingLocale,
            ..
        }) => Ok(None),
        Err(e) => Err(e),
    }
}

/// Return UTF-16 segment offset array using dictionary or lstm segmenter.
pub(crate) fn complex_language_segment_utf16(
    payloads: &ComplexPayloads,
    input: &[u16],
) -> Vec<usize> {
    let mut result = Vec::new();
    let mut offset = 0;
    // Split the input into maximal same-language runs and segment each run with
    // whichever engine is available, rebasing the offsets to the full input.
    for (slice, lang) in LanguageIteratorUtf16::new(input) {
        match payloads.select(lang) {
            Some(Ok(dict)) => {
                result.extend(
                    DictionarySegmenter::new(dict.get(), payloads.grapheme.get())
                        .segment_utf16(slice)
                        .map(|n| offset + n),
                );
            }
            #[cfg(feature = "lstm")]
            Some(Err(lstm)) => {
                result.extend(
                    LstmSegmenter::new(lstm.get(), payloads.grapheme.get())
                        .segment_utf16(slice)
                        .map(|n| offset + n),
                );
            }
            #[cfg(not(feature = "lstm"))]
            Some(Err(_infallible)) => {} // should be refutable
            None => {
                // No model: treat the entire run as a single segment.
                result.push(offset + slice.len());
            }
        }
        offset += slice.len();
    }
    result
}

/// Return UTF-8 segment offset array using dictionary or lstm segmenter.
pub(crate) fn complex_language_segment_str(payloads: &ComplexPayloads, input: &str) -> Vec<usize> {
    let mut result = Vec::new();
    let mut offset = 0;
    // Mirrors `complex_language_segment_utf16`, but over UTF-8 byte offsets.
    for (slice, lang) in LanguageIterator::new(input) {
        match payloads.select(lang) {
            Some(Ok(dict)) => {
                result.extend(
                    DictionarySegmenter::new(dict.get(), payloads.grapheme.get())
                        .segment_str(slice)
                        .map(|n| offset + n),
                );
            }
            #[cfg(feature = "lstm")]
            Some(Err(lstm)) => {
                result.extend(
                    LstmSegmenter::new(lstm.get(), payloads.grapheme.get())
                        .segment_str(slice)
                        .map(|n| offset + n),
                );
            }
            #[cfg(not(feature = "lstm"))]
            Some(Err(_infallible)) => {} // should be refutable
            None => {
                // No model: treat the entire run as a single segment.
                result.push(offset + slice.len());
            }
        }
        offset += slice.len();
    }
    result
}

#[cfg(test)]
#[cfg(feature = "serde")]
mod tests {
    use super::*;

    #[test]
    fn thai_word_break() {
        // Same logical boundaries reported as UTF-8 byte offsets and UTF-16
        // code-unit offsets (each Thai character is 3 UTF-8 bytes / 1 unit).
        const TEST_STR: &str = "ภาษาไทยภาษาไทย";
        let utf16: Vec<u16> = TEST_STR.encode_utf16().collect();

        let lstm = ComplexPayloads::new_lstm();
        let dict = ComplexPayloads::new_dict();

        assert_eq!(
            complex_language_segment_str(&lstm, TEST_STR),
            [12, 21, 33, 42]
        );
        assert_eq!(
            complex_language_segment_utf16(&lstm, &utf16),
            [4, 7, 11, 14]
        );

        assert_eq!(
            complex_language_segment_str(&dict, TEST_STR),
            [12, 21, 33, 42]
        );
        assert_eq!(
            complex_language_segment_utf16(&dict, &utf16),
            [4, 7, 11, 14]
        );
    }
}
diff --git a/third_party/rust/icu_segmenter/src/error.rs b/third_party/rust/icu_segmenter/src/error.rs
new file mode 100644
index 0000000000..b0f79ec85f
--- /dev/null
+++ b/third_party/rust/icu_segmenter/src/error.rs
@@ -0,0 +1,27 @@
// This file is part of ICU4X. For terms of use, please see the file
// called LICENSE at the top level of the ICU4X source tree
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).

use core::fmt::Debug;
use displaydoc::Display;
use icu_provider::DataError;

#[cfg(feature = "std")]
impl std::error::Error for SegmenterError {}

/// A list of error outcomes for various operations in this module.
///
/// Re-exported as [`Error`](crate::Error).
#[derive(Display, Debug, Copy, Clone, PartialEq)]
#[non_exhaustive]
pub enum SegmenterError {
    /// An error originating inside of the [data provider](icu_provider).
    #[displaydoc("{0}")]
    Data(DataError),
}

impl From<DataError> for SegmenterError {
    fn from(e: DataError) -> Self {
        Self::Data(e)
    }
}
diff --git a/third_party/rust/icu_segmenter/src/grapheme.rs b/third_party/rust/icu_segmenter/src/grapheme.rs
new file mode 100644
index 0000000000..9cfe0349bc
--- /dev/null
+++ b/third_party/rust/icu_segmenter/src/grapheme.rs
@@ -0,0 +1,270 @@
// This file is part of ICU4X. For terms of use, please see the file
// called LICENSE at the top level of the ICU4X source tree
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).

use alloc::vec::Vec;
use icu_provider::prelude::*;

use crate::indices::{Latin1Indices, Utf16Indices};
use crate::iterator_helpers::derive_usize_iterator_with_type;
use crate::rule_segmenter::*;
use crate::{provider::*, SegmenterError};
use utf8_iter::Utf8CharIndices;

/// Implements the [`Iterator`] trait over the grapheme cluster boundaries of the given string.
///
/// Lifetimes:
///
/// - `'l` = lifetime of the segmenter object from which this iterator was created
/// - `'s` = lifetime of the string being segmented
///
/// The [`Iterator::Item`] is a [`usize`] representing the index of a code unit
/// _after_ the boundary (for a boundary at the end of text, this index is the length
/// of the [`str`] or array of code units).
///
/// For examples of use, see [`GraphemeClusterSegmenter`].
#[derive(Debug)]
pub struct GraphemeClusterBreakIterator<'l, 's, Y: RuleBreakType<'l, 's> + ?Sized>(
    RuleBreakIterator<'l, 's, Y>,
);

// Forwards the `Iterator<Item = usize>` implementation to the wrapped
// `RuleBreakIterator`.
derive_usize_iterator_with_type!(GraphemeClusterBreakIterator);

/// Grapheme cluster break iterator for an `str` (a UTF-8 string).
///
/// For examples of use, see [`GraphemeClusterSegmenter`].
pub type GraphemeClusterBreakIteratorUtf8<'l, 's> =
    GraphemeClusterBreakIterator<'l, 's, RuleBreakTypeUtf8>;

/// Grapheme cluster break iterator for a potentially invalid UTF-8 string.
///
/// For examples of use, see [`GraphemeClusterSegmenter`].
pub type GraphemeClusterBreakIteratorPotentiallyIllFormedUtf8<'l, 's> =
    GraphemeClusterBreakIterator<'l, 's, RuleBreakTypePotentiallyIllFormedUtf8>;

/// Grapheme cluster break iterator for a Latin-1 (8-bit) string.
///
/// For examples of use, see [`GraphemeClusterSegmenter`].
pub type GraphemeClusterBreakIteratorLatin1<'l, 's> =
    GraphemeClusterBreakIterator<'l, 's, RuleBreakTypeLatin1>;

/// Grapheme cluster break iterator for a UTF-16 string.
///
/// For examples of use, see [`GraphemeClusterSegmenter`].
pub type GraphemeClusterBreakIteratorUtf16<'l, 's> =
    GraphemeClusterBreakIterator<'l, 's, RuleBreakTypeUtf16>;

/// Segments a string into grapheme clusters.
///
/// Supports loading grapheme cluster break data, and creating grapheme cluster break iterators for
/// different string encodings.
///
/// # Examples
///
/// Segment a string:
///
/// ```rust
/// use icu_segmenter::GraphemeClusterSegmenter;
/// let segmenter = GraphemeClusterSegmenter::new();
///
/// let breakpoints: Vec<usize> = segmenter.segment_str("Hello 🗺").collect();
/// // World Map (U+1F5FA) is encoded in four bytes in UTF-8.
/// assert_eq!(&breakpoints, &[0, 1, 2, 3, 4, 5, 6, 10]);
/// ```
///
/// Segment a Latin1 byte string:
///
/// ```rust
/// use icu_segmenter::GraphemeClusterSegmenter;
/// let segmenter = GraphemeClusterSegmenter::new();
///
/// let breakpoints: Vec<usize> =
///     segmenter.segment_latin1(b"Hello World").collect();
/// assert_eq!(&breakpoints, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
/// ```
///
/// Successive boundaries can be used to retrieve the grapheme clusters.
/// In particular, the first boundary is always 0, and the last one is the
/// length of the segmented text in code units.
///
/// ```rust
/// # use icu_segmenter::GraphemeClusterSegmenter;
/// # let segmenter =
/// #     GraphemeClusterSegmenter::new();
/// use itertools::Itertools;
/// let text = "मांजर";
/// let grapheme_clusters: Vec<&str> = segmenter
///     .segment_str(text)
///     .tuple_windows()
///     .map(|(i, j)| &text[i..j])
///     .collect();
/// assert_eq!(&grapheme_clusters, &["मां", "ज", "र"]);
/// ```
///
/// This segmenter applies all rules provided to the constructor.
/// Thus, if the data supplied by the provider comprises all
/// [grapheme cluster boundary rules][Rules] from Unicode Standard Annex #29,
/// _Unicode Text Segmentation_, which is the case of default data
/// (both test data and data produced by `icu_datagen`), the `segment_*`
/// functions return extended grapheme cluster boundaries, as opposed to
/// legacy grapheme cluster boundaries.  See [_Section 3, Grapheme Cluster
/// Boundaries_][GC], and [_Table 1a, Sample Grapheme Clusters_][Sample_GC],
/// in Unicode Standard Annex #29, _Unicode Text Segmentation_.
///
/// [Rules]: https://www.unicode.org/reports/tr29/#Grapheme_Cluster_Boundary_Rules
/// [GC]: https://www.unicode.org/reports/tr29/#Grapheme_Cluster_Boundaries
/// [Sample_GC]: https://www.unicode.org/reports/tr29/#Table_Sample_Grapheme_Clusters
///
/// ```rust
/// use icu_segmenter::GraphemeClusterSegmenter;
/// let segmenter =
///     GraphemeClusterSegmenter::new();
///
/// // நி (TAMIL LETTER NA, TAMIL VOWEL SIGN I) is an extended grapheme cluster,
/// // but not a legacy grapheme cluster.
/// let ni = "நி";
/// let egc_boundaries: Vec<usize> = segmenter.segment_str(ni).collect();
/// assert_eq!(&egc_boundaries, &[0, ni.len()]);
/// ```
#[derive(Debug)]
pub struct GraphemeClusterSegmenter {
    payload: DataPayload<GraphemeClusterBreakDataV1Marker>,
}

#[cfg(feature = "compiled_data")]
impl Default for GraphemeClusterSegmenter {
    fn default() -> Self {
        Self::new()
    }
}

impl GraphemeClusterSegmenter {
    /// Constructs a [`GraphemeClusterSegmenter`] with an invariant locale from compiled data.
    ///
    /// ✨ *Enabled with the `compiled_data` Cargo feature.*
    ///
    /// [📚 Help choosing a constructor](icu_provider::constructors)
    #[cfg(feature = "compiled_data")]
    pub fn new() -> Self {
        Self {
            payload: DataPayload::from_static_ref(
                crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
            ),
        }
    }

    // Generates the `try_new_with_any_provider` / `try_new_with_buffer_provider`
    // constructors that forward to `try_new_unstable`.
    icu_provider::gen_any_buffer_data_constructors!(locale: skip, options: skip, error: SegmenterError,
        #[cfg(skip)]
        functions: [
            new,
            try_new_with_any_provider,
            try_new_with_buffer_provider,
            try_new_unstable,
            Self,
        ]);

    #[doc = icu_provider::gen_any_buffer_unstable_docs!(UNSTABLE, Self::new)]
    pub fn try_new_unstable<D>(provider: &D) -> Result<Self, SegmenterError>
    where
        D: DataProvider<GraphemeClusterBreakDataV1Marker> + ?Sized,
    {
        let payload = provider.load(Default::default())?.take_payload()?;
        Ok(Self { payload })
    }

    /// Creates a grapheme cluster break iterator for an `str` (a UTF-8 string).
    pub fn segment_str<'l, 's>(
        &'l self,
        input: &'s str,
    ) -> GraphemeClusterBreakIteratorUtf8<'l, 's> {
        GraphemeClusterSegmenter::new_and_segment_str(input, self.payload.get())
    }

    /// Creates a grapheme cluster break iterator from grapheme cluster rule payload.
    ///
    /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string.
    // pub(crate) so the complex-language segmenters can reuse it with a borrowed payload.
    pub(crate) fn new_and_segment_str<'l, 's>(
        input: &'s str,
        payload: &'l RuleBreakDataV1<'l>,
    ) -> GraphemeClusterBreakIteratorUtf8<'l, 's> {
        GraphemeClusterBreakIterator(RuleBreakIterator {
            iter: input.char_indices(),
            len: input.len(),
            current_pos_data: None,
            result_cache: Vec::new(),
            data: payload,
            complex: None,
            boundary_property: 0,
        })
    }

    /// Creates a grapheme cluster break iterator for a potentially ill-formed UTF8 string
    ///
    /// Invalid characters are treated as REPLACEMENT CHARACTER
    ///
    /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string.
    pub fn segment_utf8<'l, 's>(
        &'l self,
        input: &'s [u8],
    ) -> GraphemeClusterBreakIteratorPotentiallyIllFormedUtf8<'l, 's> {
        GraphemeClusterBreakIterator(RuleBreakIterator {
            iter: Utf8CharIndices::new(input),
            len: input.len(),
            current_pos_data: None,
            result_cache: Vec::new(),
            data: self.payload.get(),
            complex: None,
            boundary_property: 0,
        })
    }
    /// Creates a grapheme cluster break iterator for a Latin-1 (8-bit) string.
    ///
    /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string.
    pub fn segment_latin1<'l, 's>(
        &'l self,
        input: &'s [u8],
    ) -> GraphemeClusterBreakIteratorLatin1<'l, 's> {
        GraphemeClusterBreakIterator(RuleBreakIterator {
            iter: Latin1Indices::new(input),
            len: input.len(),
            current_pos_data: None,
            result_cache: Vec::new(),
            data: self.payload.get(),
            complex: None,
            boundary_property: 0,
        })
    }

    /// Creates a grapheme cluster break iterator for a UTF-16 string.
    ///
    /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string.
    pub fn segment_utf16<'l, 's>(
        &'l self,
        input: &'s [u16],
    ) -> GraphemeClusterBreakIteratorUtf16<'l, 's> {
        GraphemeClusterSegmenter::new_and_segment_utf16(input, self.payload.get())
    }

    /// Creates a grapheme cluster break iterator from grapheme cluster rule payload.
    pub(crate) fn new_and_segment_utf16<'l, 's>(
        input: &'s [u16],
        payload: &'l RuleBreakDataV1<'l>,
    ) -> GraphemeClusterBreakIteratorUtf16<'l, 's> {
        GraphemeClusterBreakIterator(RuleBreakIterator {
            iter: Utf16Indices::new(input),
            len: input.len(),
            current_pos_data: None,
            result_cache: Vec::new(),
            data: payload,
            complex: None,
            boundary_property: 0,
        })
    }
}

#[test]
fn empty_string() {
    let segmenter = GraphemeClusterSegmenter::new();
    let breaks: Vec<usize> = segmenter.segment_str("").collect();
    assert_eq!(breaks, [0]);
}
diff --git a/third_party/rust/icu_segmenter/src/indices.rs b/third_party/rust/icu_segmenter/src/indices.rs
new file mode 100644
index 0000000000..2ea6b81fc6
--- /dev/null
+++ b/third_party/rust/icu_segmenter/src/indices.rs
@@ -0,0 +1,129 @@
// This file is part of ICU4X. For terms of use, please see the file
// called LICENSE at the top level of the ICU4X source tree
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).

/// Similar to [`core::str::CharIndices`] for Latin-1 strings, represented as `[u8]`.
///
/// Contrary to [`core::str::CharIndices`], the second element of the
/// [`Iterator::Item`] is a [`u8`], representing a Unicode scalar value in the
/// range U+0000–U+00FF.
+#[derive(Clone, Debug)] +pub struct Latin1Indices<'a> { + front_offset: usize, + iter: &'a [u8], +} + +impl<'a> Latin1Indices<'a> { + pub fn new(input: &'a [u8]) -> Self { + Self { + front_offset: 0, + iter: input, + } + } +} + +impl<'a> Iterator for Latin1Indices<'a> { + type Item = (usize, u8); + + #[inline] + fn next(&mut self) -> Option<(usize, u8)> { + self.iter.get(self.front_offset).map(|ch| { + self.front_offset += 1; + (self.front_offset - 1, *ch) + }) + } +} + +/// Similar to [`core::str::CharIndices`] for UTF-16 strings, represented as `[u16]`. +/// +/// Contrary to [`core::str::CharIndices`], the second element of the +/// [`Iterator::Item`] is a Unicode code point represented by a [`u32`], +/// rather than a Unicode scalar value represented by a [`char`], because this +/// iterator preserves unpaired surrogates. +#[derive(Clone, Debug)] +pub struct Utf16Indices<'a> { + front_offset: usize, + iter: &'a [u16], +} + +impl<'a> Utf16Indices<'a> { + pub fn new(input: &'a [u16]) -> Self { + Self { + front_offset: 0, + iter: input, + } + } +} + +impl<'a> Iterator for Utf16Indices<'a> { + type Item = (usize, u32); + + #[inline] + fn next(&mut self) -> Option<(usize, u32)> { + let (index, ch) = self.iter.get(self.front_offset).map(|ch| { + self.front_offset += 1; + (self.front_offset - 1, *ch) + })?; + + let mut ch = ch as u32; + if (ch & 0xfc00) != 0xd800 { + return Some((index, ch)); + } + + if let Some(next) = self.iter.get(self.front_offset) { + let next = *next as u32; + if (next & 0xfc00) == 0xdc00 { + // Combine low and high surrogates to UTF-32 code point. 
+ ch = ((ch & 0x3ff) << 10) + (next & 0x3ff) + 0x10000; + self.front_offset += 1; + } + } + Some((index, ch)) + } +} + +#[cfg(test)] +mod tests { + use crate::indices::*; + + #[test] + fn latin1_indices() { + let latin1 = [0x30, 0x31, 0x32]; + let mut indices = Latin1Indices::new(&latin1); + let n = indices.next().unwrap(); + assert_eq!(n.0, 0); + assert_eq!(n.1, 0x30); + let n = indices.next().unwrap(); + assert_eq!(n.0, 1); + assert_eq!(n.1, 0x31); + let n = indices.next().unwrap(); + assert_eq!(n.0, 2); + assert_eq!(n.1, 0x32); + let n = indices.next(); + assert_eq!(n, None); + } + + #[test] + fn utf16_indices() { + let utf16 = [0xd83d, 0xde03, 0x0020, 0xd83c, 0xdf00, 0xd800, 0x0020]; + let mut indices = Utf16Indices::new(&utf16); + let n = indices.next().unwrap(); + assert_eq!(n.0, 0); + assert_eq!(n.1, 0x1f603); + let n = indices.next().unwrap(); + assert_eq!(n.0, 2); + assert_eq!(n.1, 0x20); + let n = indices.next().unwrap(); + assert_eq!(n.0, 3); + assert_eq!(n.1, 0x1f300); + // This is invalid surrogate pair. + let n = indices.next().unwrap(); + assert_eq!(n.0, 5); + assert_eq!(n.1, 0xd800); + let n = indices.next().unwrap(); + assert_eq!(n.0, 6); + assert_eq!(n.1, 0x0020); + let n = indices.next(); + assert_eq!(n, None); + } +} diff --git a/third_party/rust/icu_segmenter/src/iterator_helpers.rs b/third_party/rust/icu_segmenter/src/iterator_helpers.rs new file mode 100644 index 0000000000..593a4702ca --- /dev/null +++ b/third_party/rust/icu_segmenter/src/iterator_helpers.rs @@ -0,0 +1,19 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +//! Macros and utilities to help implement the various iterator types. + +macro_rules! 
derive_usize_iterator_with_type { + ($ty:tt) => { + impl<'l, 's, Y: RuleBreakType<'l, 's> + ?Sized> Iterator for $ty<'l, 's, Y> { + type Item = usize; + #[inline] + fn next(&mut self) -> Option<Self::Item> { + self.0.next() + } + } + }; +} + +pub(crate) use derive_usize_iterator_with_type; diff --git a/third_party/rust/icu_segmenter/src/lib.rs b/third_party/rust/icu_segmenter/src/lib.rs new file mode 100644 index 0000000000..b286c4e312 --- /dev/null +++ b/third_party/rust/icu_segmenter/src/lib.rs @@ -0,0 +1,174 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +//! Segment strings by lines, graphemes, words, and sentences. +//! +//! This module is published as its own crate ([`icu_segmenter`](https://docs.rs/icu_segmenter/latest/icu_segmenter/)) +//! and as part of the [`icu`](https://docs.rs/icu/latest/icu/) crate. See the latter for more details on the ICU4X project. +//! +//! This module contains segmenter implementation for the following rules. +//! +//! - Line segmenter that is compatible with [Unicode Standard Annex #14][UAX14], _Unicode Line +//! Breaking Algorithm_, with options to tailor line-breaking behavior for CSS [`line-break`] and +//! [`word-break`] properties. +//! - Grapheme cluster segmenter, word segmenter, and sentence segmenter that are compatible with +//! [Unicode Standard Annex #29][UAX29], _Unicode Text Segmentation_. +//! +//! [UAX14]: https://www.unicode.org/reports/tr14/ +//! [UAX29]: https://www.unicode.org/reports/tr29/ +//! [`line-break`]: https://drafts.csswg.org/css-text-3/#line-break-property +//! [`word-break`]: https://drafts.csswg.org/css-text-3/#word-break-property +//! +//! # Examples +//! +//! ## Line Break +//! +//! Find line break opportunities: +//! +//!```rust +//! use icu::segmenter::LineSegmenter; +//! +//! let segmenter = LineSegmenter::new_auto(); +//! +//! 
let breakpoints: Vec<usize> = segmenter +//! .segment_str("Hello World. Xin chào thế giới!") +//! .collect(); +//! assert_eq!(&breakpoints, &[0, 6, 13, 17, 23, 29, 36]); +//! ``` +//! +//! See [`LineSegmenter`] for more examples. +//! +//! ## Grapheme Cluster Break +//! +//! Find all grapheme cluster boundaries: +//! +//!```rust +//! use icu::segmenter::GraphemeClusterSegmenter; +//! +//! let segmenter = GraphemeClusterSegmenter::new(); +//! +//! let breakpoints: Vec<usize> = segmenter +//! .segment_str("Hello World. Xin chào thế giới!") +//! .collect(); +//! assert_eq!( +//! &breakpoints, +//! &[ +//! 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, +//! 19, 21, 22, 23, 24, 25, 28, 29, 30, 31, 34, 35, 36 +//! ] +//! ); +//! ``` +//! +//! See [`GraphemeClusterSegmenter`] for more examples. +//! +//! ## Word Break +//! +//! Find all word boundaries: +//! +//!```rust +//! use icu::segmenter::WordSegmenter; +//! +//! let segmenter = WordSegmenter::new_auto(); +//! +//! let breakpoints: Vec<usize> = segmenter +//! .segment_str("Hello World. Xin chào thế giới!") +//! .collect(); +//! assert_eq!( +//! &breakpoints, +//! &[0, 5, 6, 11, 12, 13, 16, 17, 22, 23, 28, 29, 35, 36] +//! ); +//! ``` +//! +//! See [`WordSegmenter`] for more examples. +//! +//! ## Sentence Break +//! +//! Segment the string into sentences: +//! +//!```rust +//! use icu::segmenter::SentenceSegmenter; +//! +//! let segmenter = SentenceSegmenter::new(); +//! +//! let breakpoints: Vec<usize> = segmenter +//! .segment_str("Hello World. Xin chào thế giới!") +//! .collect(); +//! assert_eq!(&breakpoints, &[0, 13, 36]); +//! ``` +//! +//! See [`SentenceSegmenter`] for more examples. 
+ +// https://github.com/unicode-org/icu4x/blob/main/docs/process/boilerplate.md#library-annotations +#![cfg_attr(not(any(test, feature = "std")), no_std)] +#![cfg_attr( + not(test), + deny( + clippy::indexing_slicing, + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::exhaustive_structs, + clippy::exhaustive_enums, + missing_debug_implementations, + ) +)] +#![warn(missing_docs)] + +extern crate alloc; + +mod complex; +mod error; +mod indices; +mod iterator_helpers; +mod rule_segmenter; + +mod grapheme; +mod line; +mod sentence; +mod word; + +pub mod provider; + +// icu_datagen uses symbols, but we don't want to expose this implementation detail to the users. +#[doc(hidden)] +pub mod symbols; + +// Main Segmenter and BreakIterator public types +pub use crate::grapheme::GraphemeClusterBreakIterator; +pub use crate::grapheme::GraphemeClusterSegmenter; +pub use crate::line::LineBreakIterator; +pub use crate::line::LineSegmenter; +pub use crate::sentence::SentenceBreakIterator; +pub use crate::sentence::SentenceSegmenter; +pub use crate::word::WordBreakIterator; +pub use crate::word::WordSegmenter; + +// Options structs and enums +pub use crate::line::LineBreakOptions; +pub use crate::line::LineBreakStrictness; +pub use crate::line::LineBreakWordOption; +pub use crate::word::WordType; + +// Typedefs +pub use crate::grapheme::GraphemeClusterBreakIteratorLatin1; +pub use crate::grapheme::GraphemeClusterBreakIteratorPotentiallyIllFormedUtf8; +pub use crate::grapheme::GraphemeClusterBreakIteratorUtf16; +pub use crate::grapheme::GraphemeClusterBreakIteratorUtf8; +pub use crate::line::LineBreakIteratorLatin1; +pub use crate::line::LineBreakIteratorPotentiallyIllFormedUtf8; +pub use crate::line::LineBreakIteratorUtf16; +pub use crate::line::LineBreakIteratorUtf8; +pub use crate::sentence::SentenceBreakIteratorLatin1; +pub use crate::sentence::SentenceBreakIteratorPotentiallyIllFormedUtf8; +pub use crate::sentence::SentenceBreakIteratorUtf16; +pub use 
crate::sentence::SentenceBreakIteratorUtf8; +pub use crate::word::WordBreakIteratorLatin1; +pub use crate::word::WordBreakIteratorPotentiallyIllFormedUtf8; +pub use crate::word::WordBreakIteratorUtf16; +pub use crate::word::WordBreakIteratorUtf8; + +pub use error::SegmenterError; + +#[doc(no_inline)] +pub use SegmenterError as Error; diff --git a/third_party/rust/icu_segmenter/src/line.rs b/third_party/rust/icu_segmenter/src/line.rs new file mode 100644 index 0000000000..f93e31b13d --- /dev/null +++ b/third_party/rust/icu_segmenter/src/line.rs @@ -0,0 +1,1641 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +use crate::complex::*; +use crate::indices::*; +use crate::provider::*; +use crate::symbols::*; +use crate::SegmenterError; +use alloc::string::String; +use alloc::vec; +use alloc::vec::Vec; +use core::char; +use core::str::CharIndices; +use icu_provider::prelude::*; +use utf8_iter::Utf8CharIndices; + +/// An enum specifies the strictness of line-breaking rules. It can be passed as +/// an argument when creating a line segmenter. +/// +/// Each enum value has the same meaning with respect to the `line-break` +/// property values in the CSS Text spec. See the details in +/// <https://drafts.csswg.org/css-text-3/#line-break-property>. +#[non_exhaustive] +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub enum LineBreakStrictness { + /// Breaks text using the least restrictive set of line-breaking rules. + /// Typically used for short lines, such as in newspapers. + /// <https://drafts.csswg.org/css-text-3/#valdef-line-break-loose> + Loose, + + /// Breaks text using the most common set of line-breaking rules. + /// <https://drafts.csswg.org/css-text-3/#valdef-line-break-normal> + Normal, + + /// Breaks text using the most stringent set of line-breaking rules. 
+ /// <https://drafts.csswg.org/css-text-3/#valdef-line-break-strict> + /// + /// This is the default behaviour of the Unicode Line Breaking Algorithm, + /// resolving class [CJ](https://www.unicode.org/reports/tr14/#CJ) to + /// [NS](https://www.unicode.org/reports/tr14/#NS); + /// see rule [LB1](https://www.unicode.org/reports/tr14/#LB1). + Strict, + + /// Breaks text assuming there is a soft wrap opportunity around every + /// typographic character unit, disregarding any prohibition against line + /// breaks. See more details in + /// <https://drafts.csswg.org/css-text-3/#valdef-line-break-anywhere>. + Anywhere, +} + +/// An enum specifies the line break opportunities between letters. It can be +/// passed as an argument when creating a line segmenter. +/// +/// Each enum value has the same meaning with respect to the `word-break` +/// property values in the CSS Text spec. See the details in +/// <https://drafts.csswg.org/css-text-3/#word-break-property> +#[non_exhaustive] +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub enum LineBreakWordOption { + /// Words break according to their customary rules. See the details in + /// <https://drafts.csswg.org/css-text-3/#valdef-word-break-normal>. + Normal, + + /// Breaking is allowed within "words". + /// <https://drafts.csswg.org/css-text-3/#valdef-word-break-break-all> + BreakAll, + + /// Breaking is forbidden within "word". + /// <https://drafts.csswg.org/css-text-3/#valdef-word-break-keep-all> + KeepAll, +} + +/// Options to tailor line-breaking behavior. +#[non_exhaustive] +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct LineBreakOptions { + /// Strictness of line-breaking rules. See [`LineBreakStrictness`]. + pub strictness: LineBreakStrictness, + + /// Line break opportunities between letters. See [`LineBreakWordOption`]. + pub word_option: LineBreakWordOption, + + /// Use `true` as a hint to the line segmenter that the writing + /// system is Chinese or Japanese. 
This allows more break opportunities when + /// `LineBreakStrictness` is `Normal` or `Loose`. See + /// <https://drafts.csswg.org/css-text-3/#line-break-property> for details. + /// + /// This option has no effect in Latin-1 mode. + pub ja_zh: bool, +} + +impl Default for LineBreakOptions { + fn default() -> Self { + Self { + strictness: LineBreakStrictness::Strict, + word_option: LineBreakWordOption::Normal, + ja_zh: false, + } + } +} + +/// Line break iterator for an `str` (a UTF-8 string). +/// +/// For examples of use, see [`LineSegmenter`]. +pub type LineBreakIteratorUtf8<'l, 's> = LineBreakIterator<'l, 's, LineBreakTypeUtf8>; + +/// Line break iterator for a potentially invalid UTF-8 string. +/// +/// For examples of use, see [`LineSegmenter`]. +pub type LineBreakIteratorPotentiallyIllFormedUtf8<'l, 's> = + LineBreakIterator<'l, 's, LineBreakTypePotentiallyIllFormedUtf8>; + +/// Line break iterator for a Latin-1 (8-bit) string. +/// +/// For examples of use, see [`LineSegmenter`]. +pub type LineBreakIteratorLatin1<'l, 's> = LineBreakIterator<'l, 's, LineBreakTypeLatin1>; + +/// Line break iterator for a UTF-16 string. +/// +/// For examples of use, see [`LineSegmenter`]. +pub type LineBreakIteratorUtf16<'l, 's> = LineBreakIterator<'l, 's, LineBreakTypeUtf16>; + +/// Supports loading line break data, and creating line break iterators for different string +/// encodings. +/// +/// The segmenter returns mandatory breaks (as defined by [definition LD7][LD7] of +/// Unicode Standard Annex #14, _Unicode Line Breaking Algorithm_) as well as +/// line break opportunities ([definition LD3][LD3]). +/// It does not distinguish them. Callers requiring that distinction can check +/// the Line_Break property of the code point preceding the break against those +/// listed in rules [LB4][LB4] and [LB5][LB5], special-casing the end of text +/// according to [LB3][LB3]. 
+/// +/// For consistency with the grapheme, word, and sentence segmenters, there is +/// always a breakpoint returned at index 0, but this breakpoint is not a +/// meaningful line break opportunity. +/// +/// [LD3]: https://www.unicode.org/reports/tr14/#LD3 +/// [LD7]: https://www.unicode.org/reports/tr14/#LD7 +/// [LB3]: https://www.unicode.org/reports/tr14/#LB3 +/// [LB4]: https://www.unicode.org/reports/tr14/#LB4 +/// [LB5]: https://www.unicode.org/reports/tr14/#LB5 +/// +/// ```rust +/// # use icu_segmenter::LineSegmenter; +/// # +/// # let segmenter = LineSegmenter::new_auto(); +/// # +/// let text = "Summary\r\nThis annex…"; +/// let breakpoints: Vec<usize> = segmenter.segment_str(text).collect(); +/// // 9 and 22 are mandatory breaks, 14 is a line break opportunity. +/// assert_eq!(&breakpoints, &[0, 9, 14, 22]); +/// ``` +/// +/// # Examples +/// +/// Segment a string with default options: +/// +/// ```rust +/// use icu_segmenter::LineSegmenter; +/// +/// let segmenter = LineSegmenter::new_auto(); +/// +/// let breakpoints: Vec<usize> = +/// segmenter.segment_str("Hello World").collect(); +/// assert_eq!(&breakpoints, &[0, 6, 11]); +/// ``` +/// +/// Segment a string with CSS option overrides: +/// +/// ```rust +/// use icu_segmenter::{ +/// LineBreakOptions, LineBreakStrictness, LineBreakWordOption, +/// LineSegmenter, +/// }; +/// +/// let mut options = LineBreakOptions::default(); +/// options.strictness = LineBreakStrictness::Strict; +/// options.word_option = LineBreakWordOption::BreakAll; +/// options.ja_zh = false; +/// let segmenter = LineSegmenter::new_auto_with_options(options); +/// +/// let breakpoints: Vec<usize> = +/// segmenter.segment_str("Hello World").collect(); +/// assert_eq!(&breakpoints, &[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]); +/// ``` +/// +/// Segment a Latin1 byte string: +/// +/// ```rust +/// use icu_segmenter::LineSegmenter; +/// +/// let segmenter = LineSegmenter::new_auto(); +/// +/// let breakpoints: Vec<usize> = +/// 
segmenter.segment_latin1(b"Hello World").collect(); +/// assert_eq!(&breakpoints, &[0, 6, 11]); +/// ``` +/// +/// Separate mandatory breaks from the break opportunities: +/// +/// ```rust +/// use icu::properties::{maps, LineBreak}; +/// use icu_segmenter::LineSegmenter; +/// +/// # let segmenter = LineSegmenter::new_auto(); +/// # +/// let text = "Summary\r\nThis annex…"; +/// +/// let mandatory_breaks: Vec<usize> = segmenter +/// .segment_str(text) +/// .into_iter() +/// .filter(|&i| { +/// text[..i].chars().next_back().map_or(false, |c| { +/// matches!( +/// maps::line_break().get(c), +/// LineBreak::MandatoryBreak +/// | LineBreak::CarriageReturn +/// | LineBreak::LineFeed +/// | LineBreak::NextLine +/// ) || i == text.len() +/// }) +/// }) +/// .collect(); +/// assert_eq!(&mandatory_breaks, &[9, 22]); +/// ``` +#[derive(Debug)] +pub struct LineSegmenter { + options: LineBreakOptions, + payload: DataPayload<LineBreakDataV1Marker>, + complex: ComplexPayloads, +} + +impl LineSegmenter { + /// Constructs a [`LineSegmenter`] with an invariant locale and the best available compiled data for + /// complex scripts (Khmer, Lao, Myanmar, and Thai). + /// + /// The current behavior, which is subject to change, is to use the LSTM model when available. + /// + /// See also [`Self::new_auto_with_options`]. 
+ /// + /// ✨ *Enabled with the `compiled_data` and `auto` Cargo features.* + /// + /// [📚 Help choosing a constructor](icu_provider::constructors) + #[cfg(feature = "compiled_data")] + #[cfg(feature = "auto")] + pub fn new_auto() -> Self { + Self::new_auto_with_options(Default::default()) + } + + #[cfg(feature = "auto")] + icu_provider::gen_any_buffer_data_constructors!( + locale: skip, + options: skip, + error: SegmenterError, + #[cfg(skip)] + functions: [ + new_auto, + try_new_auto_with_any_provider, + try_new_auto_with_buffer_provider, + try_new_auto_unstable, + Self, + ] + ); + + #[doc = icu_provider::gen_any_buffer_unstable_docs!(UNSTABLE, Self::new_auto)] + #[cfg(feature = "auto")] + pub fn try_new_auto_unstable<D>(provider: &D) -> Result<Self, SegmenterError> + where + D: DataProvider<LineBreakDataV1Marker> + + DataProvider<LstmForWordLineAutoV1Marker> + + DataProvider<GraphemeClusterBreakDataV1Marker> + + ?Sized, + { + Self::try_new_auto_with_options_unstable(provider, Default::default()) + } + + /// Constructs a [`LineSegmenter`] with an invariant locale and compiled LSTM data for + /// complex scripts (Khmer, Lao, Myanmar, and Thai). + /// + /// The LSTM, or Long Term Short Memory, is a machine learning model. It is smaller than + /// the full dictionary but more expensive during segmentation (inference). + /// + /// See also [`Self::new_lstm_with_options`]. 
+ /// + /// ✨ *Enabled with the `compiled_data` and `lstm` Cargo features.* + /// + /// [📚 Help choosing a constructor](icu_provider::constructors) + #[cfg(feature = "compiled_data")] + #[cfg(feature = "lstm")] + pub fn new_lstm() -> Self { + Self::new_lstm_with_options(Default::default()) + } + + #[cfg(feature = "lstm")] + icu_provider::gen_any_buffer_data_constructors!( + locale: skip, + options: skip, + error: SegmenterError, + #[cfg(skip)] + functions: [ + new_lstm, + try_new_lstm_with_any_provider, + try_new_lstm_with_buffer_provider, + try_new_lstm_unstable, + Self, + ] + ); + + #[cfg(feature = "lstm")] + #[doc = icu_provider::gen_any_buffer_unstable_docs!(UNSTABLE, Self::new_lstm)] + pub fn try_new_lstm_unstable<D>(provider: &D) -> Result<Self, SegmenterError> + where + D: DataProvider<LineBreakDataV1Marker> + + DataProvider<LstmForWordLineAutoV1Marker> + + DataProvider<GraphemeClusterBreakDataV1Marker> + + ?Sized, + { + Self::try_new_lstm_with_options_unstable(provider, Default::default()) + } + + /// Constructs a [`LineSegmenter`] with an invariant locale and compiled dictionary data for + /// complex scripts (Khmer, Lao, Myanmar, and Thai). + /// + /// The dictionary model uses a list of words to determine appropriate breakpoints. It is + /// faster than the LSTM model but requires more data. + /// + /// See also [`Self::new_dictionary_with_options`]. 
+ /// + /// ✨ *Enabled with the `compiled_data` Cargo feature.* + /// + /// [📚 Help choosing a constructor](icu_provider::constructors) + #[cfg(feature = "compiled_data")] + pub fn new_dictionary() -> Self { + Self::new_dictionary_with_options(Default::default()) + } + + icu_provider::gen_any_buffer_data_constructors!( + locale: skip, + options: skip, + error: SegmenterError, + #[cfg(skip)] + functions: [ + new_dictionary, + try_new_dictionary_with_any_provider, + try_new_dictionary_with_buffer_provider, + try_new_dictionary_unstable, + Self, + ] + ); + + #[doc = icu_provider::gen_any_buffer_unstable_docs!(UNSTABLE, Self::new_dictionary)] + pub fn try_new_dictionary_unstable<D>(provider: &D) -> Result<Self, SegmenterError> + where + D: DataProvider<LineBreakDataV1Marker> + + DataProvider<DictionaryForWordLineExtendedV1Marker> + + DataProvider<GraphemeClusterBreakDataV1Marker> + + ?Sized, + { + Self::try_new_dictionary_with_options_unstable(provider, Default::default()) + } + + /// Constructs a [`LineSegmenter`] with an invariant locale, custom [`LineBreakOptions`], and + /// the best available compiled data for complex scripts (Khmer, Lao, Myanmar, and Thai). + /// + /// The current behavior, which is subject to change, is to use the LSTM model when available. + /// + /// See also [`Self::new_auto`]. 
+ /// + /// ✨ *Enabled with the `compiled_data` and `auto` Cargo features.* + /// + /// [📚 Help choosing a constructor](icu_provider::constructors) + #[cfg(feature = "auto")] + #[cfg(feature = "compiled_data")] + pub fn new_auto_with_options(options: LineBreakOptions) -> Self { + Self::new_lstm_with_options(options) + } + + #[cfg(feature = "auto")] + icu_provider::gen_any_buffer_data_constructors!( + locale: skip, + options: LineBreakOptions, + error: SegmenterError, + #[cfg(skip)] + functions: [ + new_auto_with_options, + try_new_auto_with_options_with_any_provider, + try_new_auto_with_options_with_buffer_provider, + try_new_auto_with_options_unstable, + Self, + ] + ); + + #[cfg(feature = "auto")] + #[doc = icu_provider::gen_any_buffer_unstable_docs!(UNSTABLE, Self::new_auto_with_options)] + pub fn try_new_auto_with_options_unstable<D>( + provider: &D, + options: LineBreakOptions, + ) -> Result<Self, SegmenterError> + where + D: DataProvider<LineBreakDataV1Marker> + + DataProvider<LstmForWordLineAutoV1Marker> + + DataProvider<GraphemeClusterBreakDataV1Marker> + + ?Sized, + { + Self::try_new_lstm_with_options_unstable(provider, options) + } + + /// Constructs a [`LineSegmenter`] with an invariant locale, custom [`LineBreakOptions`], and + /// compiled LSTM data for complex scripts (Khmer, Lao, Myanmar, and Thai). + /// + /// The LSTM, or Long Term Short Memory, is a machine learning model. It is smaller than + /// the full dictionary but more expensive during segmentation (inference). + /// + /// See also [`Self::new_dictionary`]. 
+ /// + /// ✨ *Enabled with the `compiled_data` and `lstm` Cargo features.* + /// + /// [📚 Help choosing a constructor](icu_provider::constructors) + #[cfg(feature = "lstm")] + #[cfg(feature = "compiled_data")] + pub fn new_lstm_with_options(options: LineBreakOptions) -> Self { + Self { + options, + payload: DataPayload::from_static_ref( + crate::provider::Baked::SINGLETON_SEGMENTER_LINE_V1, + ), + complex: ComplexPayloads::new_lstm(), + } + } + + #[cfg(feature = "lstm")] + icu_provider::gen_any_buffer_data_constructors!( + locale: skip, + options: LineBreakOptions, + error: SegmenterError, + #[cfg(skip)] + functions: [ + try_new_lstm_with_options, + try_new_lstm_with_options_with_any_provider, + try_new_lstm_with_options_with_buffer_provider, + try_new_lstm_with_options_unstable, + Self, + ] + ); + + #[cfg(feature = "lstm")] + #[doc = icu_provider::gen_any_buffer_unstable_docs!(UNSTABLE, Self::new_lstm_with_options)] + pub fn try_new_lstm_with_options_unstable<D>( + provider: &D, + options: LineBreakOptions, + ) -> Result<Self, SegmenterError> + where + D: DataProvider<LineBreakDataV1Marker> + + DataProvider<LstmForWordLineAutoV1Marker> + + DataProvider<GraphemeClusterBreakDataV1Marker> + + ?Sized, + { + Ok(Self { + options, + payload: provider.load(Default::default())?.take_payload()?, + complex: ComplexPayloads::try_new_lstm(provider)?, + }) + } + + /// Constructs a [`LineSegmenter`] with an invariant locale, custom [`LineBreakOptions`], and + /// compiled dictionary data for complex scripts (Khmer, Lao, Myanmar, and Thai). + /// + /// The dictionary model uses a list of words to determine appropriate breakpoints. It is + /// faster than the LSTM model but requires more data. + /// + /// See also [`Self::new_dictionary`]. 
+ /// + /// ✨ *Enabled with the `compiled_data` Cargo feature.* + /// + /// [📚 Help choosing a constructor](icu_provider::constructors) + #[cfg(feature = "compiled_data")] + pub fn new_dictionary_with_options(options: LineBreakOptions) -> Self { + Self { + options, + payload: DataPayload::from_static_ref( + crate::provider::Baked::SINGLETON_SEGMENTER_LINE_V1, + ), + // Line segmenter doesn't need to load CJ dictionary because UAX 14 rules handles CJK + // characters [1]. Southeast Asian languages however require complex context analysis + // [2]. + // + // [1]: https://www.unicode.org/reports/tr14/#ID + // [2]: https://www.unicode.org/reports/tr14/#SA + complex: ComplexPayloads::new_southeast_asian(), + } + } + + icu_provider::gen_any_buffer_data_constructors!( + locale: skip, + options: LineBreakOptions, + error: SegmenterError, + #[cfg(skip)] + functions: [ + new_dictionary_with_options, + try_new_dictionary_with_options_with_any_provider, + try_new_dictionary_with_options_with_buffer_provider, + try_new_dictionary_with_options_unstable, + Self, + ] + ); + + #[doc = icu_provider::gen_any_buffer_unstable_docs!(UNSTABLE, Self::new_dictionary_with_options)] + pub fn try_new_dictionary_with_options_unstable<D>( + provider: &D, + options: LineBreakOptions, + ) -> Result<Self, SegmenterError> + where + D: DataProvider<LineBreakDataV1Marker> + + DataProvider<DictionaryForWordLineExtendedV1Marker> + + DataProvider<GraphemeClusterBreakDataV1Marker> + + ?Sized, + { + Ok(Self { + options, + payload: provider.load(Default::default())?.take_payload()?, + // Line segmenter doesn't need to load CJ dictionary because UAX 14 rules handles CJK + // characters [1]. Southeast Asian languages however require complex context analysis + // [2]. 
+ // + // [1]: https://www.unicode.org/reports/tr14/#ID + // [2]: https://www.unicode.org/reports/tr14/#SA + complex: ComplexPayloads::try_new_southeast_asian(provider)?, + }) + } + + /// Creates a line break iterator for an `str` (a UTF-8 string). + /// + /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string. + pub fn segment_str<'l, 's>(&'l self, input: &'s str) -> LineBreakIteratorUtf8<'l, 's> { + LineBreakIterator { + iter: input.char_indices(), + len: input.len(), + current_pos_data: None, + result_cache: Vec::new(), + data: self.payload.get(), + options: &self.options, + complex: &self.complex, + } + } + /// Creates a line break iterator for a potentially ill-formed UTF8 string + /// + /// Invalid characters are treated as REPLACEMENT CHARACTER + /// + /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string. + pub fn segment_utf8<'l, 's>( + &'l self, + input: &'s [u8], + ) -> LineBreakIteratorPotentiallyIllFormedUtf8<'l, 's> { + LineBreakIterator { + iter: Utf8CharIndices::new(input), + len: input.len(), + current_pos_data: None, + result_cache: Vec::new(), + data: self.payload.get(), + options: &self.options, + complex: &self.complex, + } + } + /// Creates a line break iterator for a Latin-1 (8-bit) string. + /// + /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string. + pub fn segment_latin1<'l, 's>(&'l self, input: &'s [u8]) -> LineBreakIteratorLatin1<'l, 's> { + LineBreakIterator { + iter: Latin1Indices::new(input), + len: input.len(), + current_pos_data: None, + result_cache: Vec::new(), + data: self.payload.get(), + options: &self.options, + complex: &self.complex, + } + } + + /// Creates a line break iterator for a UTF-16 string. + /// + /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string. 
+ pub fn segment_utf16<'l, 's>(&'l self, input: &'s [u16]) -> LineBreakIteratorUtf16<'l, 's> { + LineBreakIterator { + iter: Utf16Indices::new(input), + len: input.len(), + current_pos_data: None, + result_cache: Vec::new(), + data: self.payload.get(), + options: &self.options, + complex: &self.complex, + } + } +} + +fn get_linebreak_property_utf32_with_rule( + property_table: &RuleBreakPropertyTable<'_>, + codepoint: u32, + strictness: LineBreakStrictness, + word_option: LineBreakWordOption, +) -> u8 { + // Note: Default value is 0 == UNKNOWN + let prop = property_table.0.get32(codepoint); + + if word_option == LineBreakWordOption::BreakAll + || strictness == LineBreakStrictness::Loose + || strictness == LineBreakStrictness::Normal + { + return match prop { + CJ => ID, // All CJ's General_Category is Other_Letter (Lo). + _ => prop, + }; + } + + // CJ is treated as NS by default, yielding strict line breaking. + // https://www.unicode.org/reports/tr14/#CJ + prop +} + +#[inline] +fn get_linebreak_property_latin1(property_table: &RuleBreakPropertyTable<'_>, codepoint: u8) -> u8 { + // Note: Default value is 0 == UNKNOWN + property_table.0.get32(codepoint as u32) +} + +#[inline] +fn get_linebreak_property_with_rule( + property_table: &RuleBreakPropertyTable<'_>, + codepoint: char, + linebreak_rule: LineBreakStrictness, + wordbreak_rule: LineBreakWordOption, +) -> u8 { + get_linebreak_property_utf32_with_rule( + property_table, + codepoint as u32, + linebreak_rule, + wordbreak_rule, + ) +} + +#[inline] +fn is_break_utf32_by_normal(codepoint: u32, ja_zh: bool) -> bool { + match codepoint { + 0x301C => ja_zh, + 0x30A0 => ja_zh, + _ => false, + } +} + +#[inline] +fn is_break_utf32_by_loose( + right_codepoint: u32, + left_prop: u8, + right_prop: u8, + ja_zh: bool, +) -> Option<bool> { + // breaks before hyphens + if right_prop == BA { + if left_prop == ID && (right_codepoint == 0x2010 || right_codepoint == 0x2013) { + return Some(true); + } + } else if right_prop == NS { + 
// breaks before certain CJK hyphen-like characters + if right_codepoint == 0x301C || right_codepoint == 0x30A0 { + return Some(ja_zh); + } + + // breaks before iteration marks + if right_codepoint == 0x3005 + || right_codepoint == 0x303B + || right_codepoint == 0x309D + || right_codepoint == 0x309E + || right_codepoint == 0x30FD + || right_codepoint == 0x30FE + { + return Some(true); + } + + // breaks before certain centered punctuation marks: + if right_codepoint == 0x30FB + || right_codepoint == 0xFF1A + || right_codepoint == 0xFF1B + || right_codepoint == 0xFF65 + || right_codepoint == 0x203C + || (0x2047..=0x2049).contains(&right_codepoint) + { + return Some(ja_zh); + } + } else if right_prop == IN { + // breaks between inseparable characters such as U+2025, U+2026 i.e. characters with the Unicode Line Break property IN + return Some(true); + } else if right_prop == EX { + // breaks before certain centered punctuation marks: + if right_codepoint == 0xFF01 || right_codepoint == 0xFF1F { + return Some(ja_zh); + } + } + + // breaks before suffixes: + // Characters with the Unicode Line Break property PO and the East Asian Width property + if right_prop == PO_EAW { + return Some(ja_zh); + } + // breaks after prefixes: + // Characters with the Unicode Line Break property PR and the East Asian Width property + if left_prop == PR_EAW { + return Some(ja_zh); + } + None +} + +#[inline] +fn is_break_from_table( + break_state_table: &RuleBreakStateTable<'_>, + property_count: u8, + left: u8, + right: u8, +) -> bool { + let rule = get_break_state_from_table(break_state_table, property_count, left, right); + if rule == KEEP_RULE { + return false; + } + if rule >= 0 { + // need additional next characters to get break rule. 
+ return false; + } + true +} + +#[inline] +fn is_non_break_by_keepall(left: u8, right: u8) -> bool { + // typographic letter units shouldn't be break + (left == AI + || left == AL + || left == ID + || left == NU + || left == HY + || left == H2 + || left == H3 + || left == JL + || left == JV + || left == JT + || left == CJ) + && (right == AI + || right == AL + || right == ID + || right == NU + || right == HY + || right == H2 + || right == H3 + || right == JL + || right == JV + || right == JT + || right == CJ) +} + +#[inline] +fn get_break_state_from_table( + break_state_table: &RuleBreakStateTable<'_>, + property_count: u8, + left: u8, + right: u8, +) -> i8 { + let idx = (left as usize) * (property_count as usize) + (right as usize); + // We use unwrap_or to fall back to the base case and prevent panics on bad data. + break_state_table.0.get(idx).unwrap_or(KEEP_RULE) +} + +#[inline] +fn use_complex_breaking_utf32(property_table: &RuleBreakPropertyTable<'_>, codepoint: u32) -> bool { + let line_break_property = get_linebreak_property_utf32_with_rule( + property_table, + codepoint, + LineBreakStrictness::Strict, + LineBreakWordOption::Normal, + ); + + line_break_property == SA +} + +/* +#[inline] +fn use_complex_breaking_utf32(codepoint: u32) -> bool { + // Thai, Lao and Khmer + (codepoint >= 0xe01 && codepoint <= 0xeff) || (codepoint >= 0x1780 && codepoint <= 0x17ff) +} +*/ + +/// A trait allowing for LineBreakIterator to be generalized to multiple string iteration methods. +/// +/// This is implemented by ICU4X for several common string types. +pub trait LineBreakType<'l, 's> { + /// The iterator over characters. + type IterAttr: Iterator<Item = (usize, Self::CharType)> + Clone; + + /// The character type. 
+ type CharType: Copy + Into<u32>; + + fn use_complex_breaking(iterator: &LineBreakIterator<'l, 's, Self>, c: Self::CharType) -> bool; + + fn get_linebreak_property_with_rule( + iterator: &LineBreakIterator<'l, 's, Self>, + c: Self::CharType, + ) -> u8; + + fn get_current_position_character_len(iterator: &LineBreakIterator<'l, 's, Self>) -> usize; + + fn handle_complex_language( + iterator: &mut LineBreakIterator<'l, 's, Self>, + left_codepoint: Self::CharType, + ) -> Option<usize>; +} + +/// Implements the [`Iterator`] trait over the line break opportunities of the given string. +/// +/// Lifetimes: +/// +/// - `'l` = lifetime of the [`LineSegmenter`] object from which this iterator was created +/// - `'s` = lifetime of the string being segmented +/// +/// The [`Iterator::Item`] is an [`usize`] representing index of a code unit +/// _after_ the break (for a break at the end of text, this index is the length +/// of the [`str`] or array of code units). +/// +/// For examples of use, see [`LineSegmenter`]. 
#[derive(Debug)]
pub struct LineBreakIterator<'l, 's, Y: LineBreakType<'l, 's> + ?Sized> {
    // Source iterator yielding (code-unit index, character) pairs.
    iter: Y::IterAttr,
    // Input length in code units; also doubles as an "already emitted the
    // empty-string breakpoint" marker (see `check_eof`).
    len: usize,
    // The (position, character) pair the iterator currently stands on, or None at EOF.
    current_pos_data: Option<(usize, Y::CharType)>,
    // Break offsets precomputed by the complex-language segmenter, relative to
    // the current position; drained by subsequent `next()` calls.
    result_cache: Vec<usize>,
    // Compiled rule/property tables.
    data: &'l RuleBreakDataV1<'l>,
    options: &'l LineBreakOptions,
    // Dictionary/LSTM payloads for languages UAX #14 rules don't cover.
    complex: &'l ComplexPayloads,
}

impl<'l, 's, Y: LineBreakType<'l, 's>> Iterator for LineBreakIterator<'l, 's, Y> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        // Emit the mandatory breakpoint at index 0 first; stop at end of text.
        match self.check_eof() {
            StringBoundaryPosType::Start => return Some(0),
            StringBoundaryPosType::End => return None,
            _ => (),
        }

        // If we have break point cache by previous run (complex-language
        // segmentation), walk forward to the first cached offset and return it.
        if let Some(&first_pos) = self.result_cache.first() {
            let mut i = 0;
            loop {
                if i == first_pos {
                    // Consume the first cached break and rebase the rest on it.
                    self.result_cache = self.result_cache.iter().skip(1).map(|r| r - i).collect();
                    return self.get_current_position();
                }
                i += Y::get_current_position_character_len(self);
                self.advance_iter();
                if self.is_eof() {
                    self.result_cache.clear();
                    return Some(self.len);
                }
            }
        }

        // Main rule engine: examine (left, right) character pairs until a break
        // opportunity is found.
        loop {
            debug_assert!(!self.is_eof());
            let left_codepoint = self.get_current_codepoint()?;
            let mut left_prop = self.get_linebreak_property(left_codepoint);
            self.advance_iter();

            // End of text is always a break opportunity.
            let Some(right_codepoint) = self.get_current_codepoint() else {
                return Some(self.len);
            };
            let right_prop = self.get_linebreak_property(right_codepoint);

            // CSS word-break property handling
            match self.options.word_option {
                LineBreakWordOption::BreakAll => {
                    // break-all: treat letters/numerals/SEA text as ideographs (breakable).
                    left_prop = match left_prop {
                        AL => ID,
                        NU => ID,
                        SA => ID,
                        _ => left_prop,
                    };
                }
                LineBreakWordOption::KeepAll => {
                    if is_non_break_by_keepall(left_prop, right_prop) {
                        continue;
                    }
                }
                _ => (),
            }

            // CSS line-break property handling
            match self.options.strictness {
                LineBreakStrictness::Normal => {
                    if self.is_break_by_normal(right_codepoint) {
                        return self.get_current_position();
                    }
                }
                LineBreakStrictness::Loose => {
                    // Loose rules may decide the boundary outright; None falls
                    // through to the state table.
                    if let Some(breakable) = is_break_utf32_by_loose(
                        right_codepoint.into(),
                        left_prop,
                        right_prop,
                        self.options.ja_zh,
                    ) {
                        if breakable {
                            return self.get_current_position();
                        }
                        continue;
                    }
                }
                LineBreakStrictness::Anywhere => {
                    return self.get_current_position();
                }
                _ => (),
            };

            // UAX14 doesn't have Thai etc, so use another way.
            if self.options.word_option != LineBreakWordOption::BreakAll
                && Y::use_complex_breaking(self, left_codepoint)
                && Y::use_complex_breaking(self, right_codepoint)
            {
                let result = Y::handle_complex_language(self, left_codepoint);
                if result.is_some() {
                    return result;
                }
                // I may have to fetch text until non-SA character?.
            }

            // If break_state is equal or greater than 0, it is an alias of a property:
            // the rule needs more right-hand context, so keep consuming characters
            // while remembering where to back up to.
            let mut break_state = self.get_break_state_from_table(left_prop, right_prop);
            if break_state >= 0_i8 {
                let mut previous_iter = self.iter.clone();
                let mut previous_pos_data = self.current_pos_data;

                loop {
                    self.advance_iter();

                    let Some(prop) = self.get_current_linebreak_property() else {
                        // Reached EOF. But we are analyzing multiple characters
                        // now, so the next break may be the previous point.
                        let break_state = self
                            .get_break_state_from_table(break_state as u8, self.data.eot_property);
                        if break_state == NOT_MATCH_RULE {
                            // Back up to the last safe position and break there.
                            self.iter = previous_iter;
                            self.current_pos_data = previous_pos_data;
                            return self.get_current_position();
                        }
                        // EOF
                        return Some(self.len);
                    };

                    break_state = self.get_break_state_from_table(break_state as u8, prop);
                    if break_state < 0 {
                        break;
                    }

                    previous_iter = self.iter.clone();
                    previous_pos_data = self.current_pos_data;
                }
                if break_state == KEEP_RULE {
                    continue;
                }
                if break_state == NOT_MATCH_RULE {
                    // The multi-character rule failed; break at the backed-up position.
                    self.iter = previous_iter;
                    self.current_pos_data = previous_pos_data;
                    return self.get_current_position();
                }
                return self.get_current_position();
            }

            // Simple two-character case: consult the table directly.
            if self.is_break_from_table(left_prop, right_prop) {
                return self.get_current_position();
            }
        }
    }
}

// Where the iterator currently stands relative to the text.
enum StringBoundaryPosType {
    Start,
    Middle,
    End,
}

impl<'l, 's, Y: LineBreakType<'l, 's>> LineBreakIterator<'l, 's, Y> {
    // Moves to the next (position, character) pair, or None at EOF.
    fn advance_iter(&mut self) {
        self.current_pos_data = self.iter.next();
    }

    fn is_eof(&self) -> bool {
        self.current_pos_data.is_none()
    }

    // Classifies the current state: Start (emit the index-0 breakpoint),
    // End (iteration finished), or Middle (continue with the rule engine).
    #[inline]
    fn check_eof(&mut self) -> StringBoundaryPosType {
        if self.is_eof() {
            self.advance_iter();
            if self.is_eof() {
                if self.len == 0 {
                    // Empty string. Since `self.current_pos_data` is always going to be empty,
                    // we never read `self.len` except for here, so we can use it to mark that
                    // we have already returned the single empty-string breakpoint.
                    self.len = 1;
                    StringBoundaryPosType::Start
                } else {
                    StringBoundaryPosType::End
                }
            } else {
                StringBoundaryPosType::Start
            }
        } else {
            StringBoundaryPosType::Middle
        }
    }

    // Code-unit index of the current character, if any.
    fn get_current_position(&self) -> Option<usize> {
        self.current_pos_data.map(|(pos, _)| pos)
    }

    fn get_current_codepoint(&self) -> Option<Y::CharType> {
        self.current_pos_data.map(|(_, codepoint)| codepoint)
    }

    fn get_linebreak_property(&self, codepoint: Y::CharType) -> u8 {
        Y::get_linebreak_property_with_rule(self, codepoint)
    }

    fn get_current_linebreak_property(&self) -> Option<u8> {
        self.get_current_codepoint()
            .map(|c| self.get_linebreak_property(c))
    }

    fn is_break_by_normal(&self, codepoint: Y::CharType) -> bool {
        is_break_utf32_by_normal(codepoint.into(), self.options.ja_zh)
    }

    fn get_break_state_from_table(&self, left: u8, right: u8) -> i8 {
        get_break_state_from_table(
            &self.data.break_state_table,
            self.data.property_count,
            left,
            right,
        )
    }

    fn is_break_from_table(&self, left: u8, right: u8) -> bool {
        is_break_from_table(
            &self.data.break_state_table,
            self.data.property_count,
            left,
            right,
        )
    }
}

// Line breaking over `&str` (guaranteed well-formed UTF-8).
#[derive(Debug)]
pub struct LineBreakTypeUtf8;

impl<'l, 's> LineBreakType<'l, 's> for LineBreakTypeUtf8 {
    type IterAttr = CharIndices<'s>;
    type CharType = char;

    fn get_linebreak_property_with_rule(iterator: &LineBreakIterator<Self>, c: char) -> u8 {
        get_linebreak_property_with_rule(
            &iterator.data.property_table,
            c,
            iterator.options.strictness,
            iterator.options.word_option,
        )
    }

    #[inline]
    fn use_complex_breaking(iterator: &LineBreakIterator<Self>, c: char) -> bool {
        use_complex_breaking_utf32(&iterator.data.property_table, c as u32)
    }

    fn get_current_position_character_len(iterator: &LineBreakIterator<Self>) -> usize {
        iterator.get_current_codepoint().map_or(0, |c| c.len_utf8())
    }

    fn handle_complex_language(
        iter: &mut LineBreakIterator<'l, 's, Self>,
        left_codepoint: char,
    ) -> Option<usize> {
        handle_complex_language_utf8(iter, left_codepoint)
    }
}

// Line breaking over potentially ill-formed UTF-8 bytes.
#[derive(Debug)]
pub struct LineBreakTypePotentiallyIllFormedUtf8;

impl<'l, 's> LineBreakType<'l, 's> for LineBreakTypePotentiallyIllFormedUtf8 {
    type IterAttr = Utf8CharIndices<'s>;
    type CharType = char;

    fn get_linebreak_property_with_rule(iterator: &LineBreakIterator<Self>, c: char) -> u8 {
        get_linebreak_property_with_rule(
            &iterator.data.property_table,
            c,
            iterator.options.strictness,
            iterator.options.word_option,
        )
    }

    #[inline]
    fn use_complex_breaking(iterator: &LineBreakIterator<Self>, c: char) -> bool {
        use_complex_breaking_utf32(&iterator.data.property_table, c as u32)
    }

    fn get_current_position_character_len(iterator: &LineBreakIterator<Self>) -> usize {
        iterator.get_current_codepoint().map_or(0, |c| c.len_utf8())
    }

    fn handle_complex_language(
        iter: &mut LineBreakIterator<'l, 's, Self>,
        left_codepoint: char,
    ) -> Option<usize> {
        handle_complex_language_utf8(iter, left_codepoint)
    }
}
/// handle_complex_language impl for UTF8 iterators
///
/// Collects the run of complex-language (SA) characters starting at
/// `left_codepoint`, segments it with the dictionary/LSTM payloads, caches the
/// resulting break offsets in `iter.result_cache`, and returns the first one.
fn handle_complex_language_utf8<'l, 's, T>(
    iter: &mut LineBreakIterator<'l, 's, T>,
    left_codepoint: char,
) -> Option<usize>
where
    T: LineBreakType<'l, 's, CharType = char>,
{
    // word segmenter doesn't define break rules for some languages such as Thai.
    let start_iter = iter.iter.clone();
    let start_point = iter.current_pos_data;
    let mut s = String::new();
    s.push(left_codepoint);
    // Accumulate the contiguous run of characters needing complex breaking.
    loop {
        debug_assert!(!iter.is_eof());
        s.push(iter.get_current_codepoint()?);
        iter.advance_iter();
        if let Some(current_codepoint) = iter.get_current_codepoint() {
            if !T::use_complex_breaking(iter, current_codepoint) {
                break;
            }
        } else {
            // EOF
            break;
        }
    }

    // Restore iterator to move to head of complex string
    iter.iter = start_iter;
    iter.current_pos_data = start_point;
    let breaks = complex_language_segment_str(iter.complex, &s);
    iter.result_cache = breaks;
    let first_pos = *iter.result_cache.first()?;
    // `i` tracks UTF-8 offset within the collected run, starting past the left
    // character (which the caller has already consumed).
    let mut i = left_codepoint.len_utf8();
    loop {
        if i == first_pos {
            // Re-calculate breaking offset
            iter.result_cache = iter.result_cache.iter().skip(1).map(|r| r - i).collect();
            return iter.get_current_position();
        }
        debug_assert!(
            i < first_pos,
            "we should always arrive at first_pos: near index {:?}",
            iter.get_current_position()
        );
        i += T::get_current_position_character_len(iter);
        iter.advance_iter();
        if iter.is_eof() {
            iter.result_cache.clear();
            return Some(iter.len);
        }
    }
}

// Line breaking over Latin-1 bytes; Latin-1 never needs complex breaking.
#[derive(Debug)]
pub struct LineBreakTypeLatin1;

impl<'l, 's> LineBreakType<'l, 's> for LineBreakTypeLatin1 {
    type IterAttr = Latin1Indices<'s>;
    type CharType = u8;

    fn get_linebreak_property_with_rule(iterator: &LineBreakIterator<Self>, c: u8) -> u8 {
        // No CJ on Latin1
        get_linebreak_property_latin1(&iterator.data.property_table, c)
    }

    #[inline]
    fn use_complex_breaking(_iterator: &LineBreakIterator<Self>, _c: u8) -> bool {
        false
    }

    fn get_current_position_character_len(_: &LineBreakIterator<Self>) -> usize {
        // Only called from the complex-language/result-cache paths, which
        // `use_complex_breaking == false` makes unreachable for Latin-1.
        unreachable!()
    }

    fn handle_complex_language(
        _: &mut LineBreakIterator<Self>,
        _: Self::CharType,
    ) -> Option<usize> {
        unreachable!()
    }
}

// Line breaking over UTF-16 code units; characters surface as u32 scalar values.
#[derive(Debug)]
pub struct LineBreakTypeUtf16;

impl<'l, 's> LineBreakType<'l, 's> for LineBreakTypeUtf16 {
    type IterAttr = Utf16Indices<'s>;
    type CharType = u32;

    fn get_linebreak_property_with_rule(iterator: &LineBreakIterator<Self>, c: u32) -> u8 {
        get_linebreak_property_utf32_with_rule(
            &iterator.data.property_table,
            c,
            iterator.options.strictness,
            iterator.options.word_option,
        )
    }

    #[inline]
    fn use_complex_breaking(iterator: &LineBreakIterator<Self>, c: u32) -> bool {
        use_complex_breaking_utf32(&iterator.data.property_table, c)
    }

    fn get_current_position_character_len(iterator: &LineBreakIterator<Self>) -> usize {
        // Supplementary-plane characters occupy a surrogate pair (2 code units).
        match iterator.get_current_codepoint() {
            None => 0,
            Some(ch) if ch >= 0x10000 => 2,
            _ => 1,
        }
    }

    fn handle_complex_language(
        iterator: &mut LineBreakIterator<Self>,
        left_codepoint: Self::CharType,
    ) -> Option<usize> {
        // word segmenter doesn't define break rules for some languages such as Thai.
        let start_iter = iterator.iter.clone();
        let start_point = iterator.current_pos_data;
        // Collect the SA run as u16 code units (complex scripts here are BMP-only).
        let mut s = vec![left_codepoint as u16];
        loop {
            debug_assert!(!iterator.is_eof());
            s.push(iterator.get_current_codepoint()? as u16);
            iterator.advance_iter();
            if let Some(current_codepoint) = iterator.get_current_codepoint() {
                if !Self::use_complex_breaking(iterator, current_codepoint) {
                    break;
                }
            } else {
                // EOF
                break;
            }
        }

        // Restore iterator to move to head of complex string
        iterator.iter = start_iter;
        iterator.current_pos_data = start_point;
        let breaks = complex_language_segment_utf16(iterator.complex, &s);
        iterator.result_cache = breaks;
        // result_cache vector is utf-16 index that is in BMP.
+ let first_pos = *iterator.result_cache.first()?; + let mut i = 1; + loop { + if i == first_pos { + // Re-calculate breaking offset + iterator.result_cache = iterator + .result_cache + .iter() + .skip(1) + .map(|r| r - i) + .collect(); + return iterator.get_current_position(); + } + debug_assert!( + i < first_pos, + "we should always arrive at first_pos: near index {:?}", + iterator.get_current_position() + ); + i += 1; + iterator.advance_iter(); + if iterator.is_eof() { + iterator.result_cache.clear(); + return Some(iterator.len); + } + } + } +} + +#[cfg(test)] +#[cfg(feature = "serde")] +mod tests { + use super::*; + use crate::LineSegmenter; + + #[test] + fn linebreak_property() { + let payload = DataProvider::<LineBreakDataV1Marker>::load( + &crate::provider::Baked, + Default::default(), + ) + .expect("Loading should succeed!") + .take_payload() + .expect("Data should be present!"); + + let get_linebreak_property = |codepoint| { + get_linebreak_property_with_rule( + &payload.get().property_table, + codepoint, + LineBreakStrictness::Strict, + LineBreakWordOption::Normal, + ) + }; + + assert_eq!(get_linebreak_property('\u{0020}'), SP); + assert_eq!(get_linebreak_property('\u{0022}'), QU); + assert_eq!(get_linebreak_property('('), OP_OP30); + assert_eq!(get_linebreak_property('\u{0030}'), NU); + assert_eq!(get_linebreak_property('['), OP_OP30); + assert_eq!(get_linebreak_property('\u{1f3fb}'), EM); + assert_eq!(get_linebreak_property('\u{20000}'), ID); + assert_eq!(get_linebreak_property('\u{e0020}'), CM); + assert_eq!(get_linebreak_property('\u{3041}'), CJ); + assert_eq!(get_linebreak_property('\u{0025}'), PO); + assert_eq!(get_linebreak_property('\u{00A7}'), AI); + assert_eq!(get_linebreak_property('\u{50005}'), XX); + assert_eq!(get_linebreak_property('\u{17D6}'), NS); + assert_eq!(get_linebreak_property('\u{2014}'), B2); + } + + #[test] + #[allow(clippy::bool_assert_comparison)] // clearer when we're testing bools directly + fn break_rule() { + let payload = 
DataProvider::<LineBreakDataV1Marker>::load( + &crate::provider::Baked, + Default::default(), + ) + .expect("Loading should succeed!") + .take_payload() + .expect("Data should be present!"); + let lb_data: &RuleBreakDataV1 = payload.get(); + + let is_break = |left, right| { + is_break_from_table( + &lb_data.break_state_table, + lb_data.property_count, + left, + right, + ) + }; + + // LB4 + assert_eq!(is_break(BK, AL), true); + // LB5 + assert_eq!(is_break(CR, LF), false); + assert_eq!(is_break(CR, AL), true); + assert_eq!(is_break(LF, AL), true); + assert_eq!(is_break(NL, AL), true); + // LB6 + assert_eq!(is_break(AL, BK), false); + assert_eq!(is_break(AL, CR), false); + assert_eq!(is_break(AL, LF), false); + assert_eq!(is_break(AL, NL), false); + // LB7 + assert_eq!(is_break(AL, SP), false); + assert_eq!(is_break(AL, ZW), false); + // LB8 + // LB8a + assert_eq!(is_break(ZWJ, AL), false); + // LB9 + assert_eq!(is_break(AL, ZWJ), false); + assert_eq!(is_break(AL, CM), false); + assert_eq!(is_break(ID, ZWJ), false); + // LB10 + assert_eq!(is_break(ZWJ, SP), false); + assert_eq!(is_break(SP, CM), true); + // LB11 + assert_eq!(is_break(AL, WJ), false); + assert_eq!(is_break(WJ, AL), false); + // LB12 + assert_eq!(is_break(GL, AL), false); + // LB12a + assert_eq!(is_break(AL, GL), false); + assert_eq!(is_break(SP, GL), true); + // LB13 + assert_eq!(is_break(AL, CL), false); + assert_eq!(is_break(AL, CP), false); + assert_eq!(is_break(AL, EX), false); + assert_eq!(is_break(AL, IS), false); + assert_eq!(is_break(AL, SY), false); + // LB18 + assert_eq!(is_break(SP, AL), true); + // LB19 + assert_eq!(is_break(AL, QU), false); + assert_eq!(is_break(QU, AL), false); + // LB20 + assert_eq!(is_break(AL, CB), true); + assert_eq!(is_break(CB, AL), true); + // LB20 + assert_eq!(is_break(AL, BA), false); + assert_eq!(is_break(AL, HY), false); + assert_eq!(is_break(AL, NS), false); + // LB21 + assert_eq!(is_break(AL, BA), false); + assert_eq!(is_break(BB, AL), false); + 
assert_eq!(is_break(ID, BA), false); + assert_eq!(is_break(ID, NS), false); + // LB21a + // LB21b + assert_eq!(is_break(SY, HL), false); + // LB22 + assert_eq!(is_break(AL, IN), false); + // LB 23 + assert_eq!(is_break(AL, NU), false); + assert_eq!(is_break(HL, NU), false); + // LB 23a + assert_eq!(is_break(PR, ID), false); + assert_eq!(is_break(PR, EB), false); + assert_eq!(is_break(PR, EM), false); + assert_eq!(is_break(ID, PO), false); + assert_eq!(is_break(EB, PO), false); + assert_eq!(is_break(EM, PO), false); + // LB26 + assert_eq!(is_break(JL, JL), false); + assert_eq!(is_break(JL, JV), false); + assert_eq!(is_break(JL, H2), false); + // LB27 + assert_eq!(is_break(JL, IN), false); + assert_eq!(is_break(JL, PO), false); + assert_eq!(is_break(PR, JL), false); + // LB28 + assert_eq!(is_break(AL, AL), false); + assert_eq!(is_break(HL, AL), false); + // LB29 + assert_eq!(is_break(IS, AL), false); + assert_eq!(is_break(IS, HL), false); + // LB30b + assert_eq!(is_break(EB, EM), false); + // LB31 + assert_eq!(is_break(ID, ID), true); + } + + #[test] + fn linebreak() { + let segmenter = LineSegmenter::try_new_dictionary_unstable(&crate::provider::Baked) + .expect("Data exists"); + + let mut iter = segmenter.segment_str("hello world"); + assert_eq!(Some(0), iter.next()); + assert_eq!(Some(6), iter.next()); + assert_eq!(Some(11), iter.next()); + assert_eq!(None, iter.next()); + + iter = segmenter.segment_str("$10 $10"); + assert_eq!(Some(0), iter.next()); + assert_eq!(Some(4), iter.next()); + assert_eq!(Some(7), iter.next()); + assert_eq!(None, iter.next()); + + // LB10 + + // LB14 + iter = segmenter.segment_str("[ abc def"); + assert_eq!(Some(0), iter.next()); + assert_eq!(Some(7), iter.next()); + assert_eq!(Some(10), iter.next()); + assert_eq!(None, iter.next()); + + let input: [u8; 10] = [0x5B, 0x20, 0x20, 0x61, 0x62, 0x63, 0x20, 0x64, 0x65, 0x66]; + let mut iter_u8 = segmenter.segment_latin1(&input); + assert_eq!(Some(0), iter_u8.next()); + assert_eq!(Some(7), 
iter_u8.next()); + assert_eq!(Some(10), iter_u8.next()); + assert_eq!(None, iter_u8.next()); + + let input: [u16; 10] = [0x5B, 0x20, 0x20, 0x61, 0x62, 0x63, 0x20, 0x64, 0x65, 0x66]; + let mut iter_u16 = segmenter.segment_utf16(&input); + assert_eq!(Some(0), iter_u16.next()); + assert_eq!(Some(7), iter_u16.next()); + assert_eq!(Some(10), iter_u16.next()); + assert_eq!(None, iter_u16.next()); + + // LB15 + iter = segmenter.segment_str("abc\u{0022} (def"); + assert_eq!(Some(0), iter.next()); + assert_eq!(Some(10), iter.next()); + assert_eq!(None, iter.next()); + + let input: [u8; 10] = [0x61, 0x62, 0x63, 0x22, 0x20, 0x20, 0x28, 0x64, 0x65, 0x66]; + let mut iter_u8 = segmenter.segment_latin1(&input); + assert_eq!(Some(0), iter_u8.next()); + assert_eq!(Some(10), iter_u8.next()); + assert_eq!(None, iter_u8.next()); + + let input: [u16; 10] = [0x61, 0x62, 0x63, 0x22, 0x20, 0x20, 0x28, 0x64, 0x65, 0x66]; + let mut iter_u16 = segmenter.segment_utf16(&input); + assert_eq!(Some(0), iter_u16.next()); + assert_eq!(Some(10), iter_u16.next()); + assert_eq!(None, iter_u16.next()); + + // LB16 + iter = segmenter.segment_str("\u{0029}\u{203C}"); + assert_eq!(Some(0), iter.next()); + assert_eq!(Some(4), iter.next()); + assert_eq!(None, iter.next()); + iter = segmenter.segment_str("\u{0029} \u{203C}"); + assert_eq!(Some(0), iter.next()); + assert_eq!(Some(6), iter.next()); + assert_eq!(None, iter.next()); + + let input: [u16; 4] = [0x29, 0x20, 0x20, 0x203c]; + let mut iter_u16 = segmenter.segment_utf16(&input); + assert_eq!(Some(0), iter_u16.next()); + assert_eq!(Some(4), iter_u16.next()); + assert_eq!(None, iter_u16.next()); + + // LB17 + iter = segmenter.segment_str("\u{2014}\u{2014}aa"); + assert_eq!(Some(0), iter.next()); + assert_eq!(Some(6), iter.next()); + assert_eq!(Some(8), iter.next()); + assert_eq!(None, iter.next()); + iter = segmenter.segment_str("\u{2014} \u{2014}aa"); + assert_eq!(Some(0), iter.next()); + assert_eq!(Some(8), iter.next()); + assert_eq!(Some(10), 
iter.next()); + assert_eq!(None, iter.next()); + + iter = segmenter.segment_str("\u{2014}\u{2014} \u{2014}\u{2014}123 abc"); + assert_eq!(Some(0), iter.next()); + assert_eq!(Some(14), iter.next()); + assert_eq!(Some(18), iter.next()); + assert_eq!(Some(21), iter.next()); + assert_eq!(None, iter.next()); + + // LB25 + let mut iter = segmenter.segment_str("(0,1)+(2,3)"); + assert_eq!(Some(0), iter.next()); + assert_eq!(Some(11), iter.next()); + assert_eq!(None, iter.next()); + let input: [u16; 11] = [ + 0x28, 0x30, 0x2C, 0x31, 0x29, 0x2B, 0x28, 0x32, 0x2C, 0x33, 0x29, + ]; + let mut iter_u16 = segmenter.segment_utf16(&input); + assert_eq!(Some(0), iter_u16.next()); + assert_eq!(Some(11), iter_u16.next()); + assert_eq!(None, iter_u16.next()); + + let input: [u16; 13] = [ + 0x2014, 0x2014, 0x20, 0x20, 0x2014, 0x2014, 0x31, 0x32, 0x33, 0x20, 0x61, 0x62, 0x63, + ]; + let mut iter_u16 = segmenter.segment_utf16(&input); + assert_eq!(Some(0), iter_u16.next()); + assert_eq!(Some(6), iter_u16.next()); + assert_eq!(Some(10), iter_u16.next()); + assert_eq!(Some(13), iter_u16.next()); + assert_eq!(None, iter_u16.next()); + + iter = segmenter.segment_str("\u{1F3FB} \u{1F3FB}"); + assert_eq!(Some(0), iter.next()); + assert_eq!(Some(5), iter.next()); + assert_eq!(Some(9), iter.next()); + assert_eq!(None, iter.next()); + } + + #[test] + #[cfg(feature = "lstm")] + fn thai_line_break() { + const TEST_STR: &str = "ภาษาไทยภาษาไทย"; + + let segmenter = LineSegmenter::new_lstm(); + let breaks: Vec<usize> = segmenter.segment_str(TEST_STR).collect(); + assert_eq!(breaks, [0, 12, 21, 33, TEST_STR.len()], "Thai test"); + + let utf16: Vec<u16> = TEST_STR.encode_utf16().collect(); + let breaks: Vec<usize> = segmenter.segment_utf16(&utf16).collect(); + assert_eq!(breaks, [0, 4, 7, 11, utf16.len()], "Thai test"); + + let utf16: [u16; 4] = [0x0e20, 0x0e32, 0x0e29, 0x0e32]; + let breaks: Vec<usize> = segmenter.segment_utf16(&utf16).collect(); + assert_eq!(breaks, [0, 4], "Thai test"); + } + + 
#[test] + #[cfg(feature = "lstm")] + fn burmese_line_break() { + // "Burmese Language" in Burmese + const TEST_STR: &str = "မြန်မာဘာသာစကား"; + + let segmenter = LineSegmenter::new_lstm(); + let breaks: Vec<usize> = segmenter.segment_str(TEST_STR).collect(); + // LSTM model breaks more characters, but it is better to return [30]. + assert_eq!(breaks, [0, 12, 18, 30, TEST_STR.len()], "Burmese test"); + + let utf16: Vec<u16> = TEST_STR.encode_utf16().collect(); + let breaks: Vec<usize> = segmenter.segment_utf16(&utf16).collect(); + // LSTM model breaks more characters, but it is better to return [10]. + assert_eq!(breaks, [0, 4, 6, 10, utf16.len()], "Burmese utf-16 test"); + } + + #[test] + #[cfg(feature = "lstm")] + fn khmer_line_break() { + const TEST_STR: &str = "សេចក្ដីប្រកាសជាសកលស្ដីពីសិទ្ធិមនុស្ស"; + + let segmenter = LineSegmenter::new_lstm(); + let breaks: Vec<usize> = segmenter.segment_str(TEST_STR).collect(); + // Note: This small sample matches the ICU dictionary segmenter + assert_eq!(breaks, [0, 39, 48, 54, 72, TEST_STR.len()], "Khmer test"); + + let utf16: Vec<u16> = TEST_STR.encode_utf16().collect(); + let breaks: Vec<usize> = segmenter.segment_utf16(&utf16).collect(); + assert_eq!( + breaks, + [0, 13, 16, 18, 24, utf16.len()], + "Khmer utf-16 test" + ); + } + + #[test] + #[cfg(feature = "lstm")] + fn lao_line_break() { + const TEST_STR: &str = "ກ່ຽວກັບສິດຂອງມະນຸດ"; + + let segmenter = LineSegmenter::new_lstm(); + let breaks: Vec<usize> = segmenter.segment_str(TEST_STR).collect(); + // Note: LSTM finds a break at '12' that the dictionary does not find + assert_eq!(breaks, [0, 12, 21, 30, 39, TEST_STR.len()], "Lao test"); + + let utf16: Vec<u16> = TEST_STR.encode_utf16().collect(); + let breaks: Vec<usize> = segmenter.segment_utf16(&utf16).collect(); + assert_eq!(breaks, [0, 4, 7, 10, 13, utf16.len()], "Lao utf-16 test"); + } + + #[test] + fn empty_string() { + let segmenter = LineSegmenter::new_auto(); + let breaks: Vec<usize> = 
segmenter.segment_str("").collect(); + assert_eq!(breaks, [0]); + } +} diff --git a/third_party/rust/icu_segmenter/src/provider/lstm.rs b/third_party/rust/icu_segmenter/src/provider/lstm.rs new file mode 100644 index 0000000000..6a85680e4c --- /dev/null +++ b/third_party/rust/icu_segmenter/src/provider/lstm.rs @@ -0,0 +1,358 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +//! Data provider struct definitions for the lstm + +// Provider structs must be stable +#![allow(clippy::exhaustive_structs, clippy::exhaustive_enums)] + +use icu_provider::prelude::*; +use zerovec::{ule::UnvalidatedStr, ZeroMap, ZeroVec}; + +// We do this instead of const generics because ZeroFrom and Yokeable derives, as well as serde +// don't support them +macro_rules! lstm_matrix { + ($name:ident, $generic:literal) => { + /// The struct that stores a LSTM's matrix. + /// + /// <div class="stab unstable"> + /// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways, + /// including in SemVer minor releases. While the serde representation of data structs is guaranteed + /// to be stable, their Rust representation might not be. Use with caution. + /// </div> + #[derive(PartialEq, Debug, Clone, zerofrom::ZeroFrom, yoke::Yokeable)] + #[cfg_attr(feature = "datagen", derive(serde::Serialize))] + pub struct $name<'data> { + // Invariant: dims.product() == data.len() + #[allow(missing_docs)] + pub(crate) dims: [u16; $generic], + #[allow(missing_docs)] + pub(crate) data: ZeroVec<'data, f32>, + } + + impl<'data> $name<'data> { + #[cfg(any(feature = "serde", feature = "datagen"))] + /// Creates a LstmMatrix with the given dimensions. Fails if the dimensions don't match the data. 
+ pub fn from_parts( + dims: [u16; $generic], + data: ZeroVec<'data, f32>, + ) -> Result<Self, DataError> { + if dims.iter().map(|&i| i as usize).product::<usize>() != data.len() { + Err(DataError::custom("Dimension mismatch")) + } else { + Ok(Self { dims, data }) + } + } + + #[doc(hidden)] // databake + pub const fn from_parts_unchecked( + dims: [u16; $generic], + data: ZeroVec<'data, f32>, + ) -> Self { + Self { dims, data } + } + } + + #[cfg(feature = "serde")] + impl<'de: 'data, 'data> serde::Deserialize<'de> for $name<'data> { + fn deserialize<S>(deserializer: S) -> Result<Self, S::Error> + where + S: serde::de::Deserializer<'de>, + { + #[derive(serde::Deserialize)] + struct Raw<'data> { + dims: [u16; $generic], + #[serde(borrow)] + data: ZeroVec<'data, f32>, + } + + let raw = Raw::deserialize(deserializer)?; + + use serde::de::Error; + Self::from_parts(raw.dims, raw.data) + .map_err(|_| S::Error::custom("Dimension mismatch")) + } + } + + #[cfg(feature = "datagen")] + impl databake::Bake for $name<'_> { + fn bake(&self, env: &databake::CrateEnv) -> databake::TokenStream { + let dims = self.dims.bake(env); + let data = self.data.bake(env); + databake::quote! { + icu_segmenter::provider::$name::from_parts_unchecked(#dims, #data) + } + } + } + }; +} + +lstm_matrix!(LstmMatrix1, 1); +lstm_matrix!(LstmMatrix2, 2); +lstm_matrix!(LstmMatrix3, 3); + +#[derive(PartialEq, Debug, Clone, Copy)] +#[cfg_attr( + feature = "datagen", + derive(serde::Serialize,databake::Bake), + databake(path = icu_segmenter::provider), +)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize))] +/// The type of LSTM model +/// +/// <div class="stab unstable"> +/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways, +/// including in SemVer minor releases. While the serde representation of data structs is guaranteed +/// to be stable, their Rust representation might not be. Use with caution. 
+/// </div> +pub enum ModelType { + /// A model working on code points + Codepoints, + /// A model working on grapheme clusters + GraphemeClusters, +} + +/// The struct that stores a LSTM model. +/// +/// <div class="stab unstable"> +/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways, +/// including in SemVer minor releases. While the serde representation of data structs is guaranteed +/// to be stable, their Rust representation might not be. Use with caution. +/// </div> +#[derive(PartialEq, Debug, Clone, yoke::Yokeable, zerofrom::ZeroFrom)] +#[cfg_attr(feature = "datagen", derive(serde::Serialize))] +#[yoke(prove_covariance_manually)] +pub struct LstmDataFloat32<'data> { + /// Type of the model + pub(crate) model: ModelType, + /// The grapheme cluster dictionary used to train the model + pub(crate) dic: ZeroMap<'data, UnvalidatedStr, u16>, + /// The embedding layer. Shape (dic.len + 1, e) + pub(crate) embedding: LstmMatrix2<'data>, + /// The forward layer's first matrix. Shape (h, 4, e) + pub(crate) fw_w: LstmMatrix3<'data>, + /// The forward layer's second matrix. Shape (h, 4, h) + pub(crate) fw_u: LstmMatrix3<'data>, + /// The forward layer's bias. Shape (h, 4) + pub(crate) fw_b: LstmMatrix2<'data>, + /// The backward layer's first matrix. Shape (h, 4, e) + pub(crate) bw_w: LstmMatrix3<'data>, + /// The backward layer's second matrix. Shape (h, 4, h) + pub(crate) bw_u: LstmMatrix3<'data>, + /// The backward layer's bias. Shape (h, 4) + pub(crate) bw_b: LstmMatrix2<'data>, + /// The output layer's weights. Shape (2, 4, h) + pub(crate) time_w: LstmMatrix3<'data>, + /// The output layer's bias. 
Shape (4)
+ pub(crate) time_b: LstmMatrix1<'data>,
+}
+
+impl<'data> LstmDataFloat32<'data> {
+ #[doc(hidden)] // databake
+ #[allow(clippy::too_many_arguments)] // constructor
+ pub const fn from_parts_unchecked(
+ model: ModelType,
+ dic: ZeroMap<'data, UnvalidatedStr, u16>,
+ embedding: LstmMatrix2<'data>,
+ fw_w: LstmMatrix3<'data>,
+ fw_u: LstmMatrix3<'data>,
+ fw_b: LstmMatrix2<'data>,
+ bw_w: LstmMatrix3<'data>,
+ bw_u: LstmMatrix3<'data>,
+ bw_b: LstmMatrix2<'data>,
+ time_w: LstmMatrix3<'data>,
+ time_b: LstmMatrix1<'data>,
+ ) -> Self {
+ Self {
+ model,
+ dic,
+ embedding,
+ fw_w,
+ fw_u,
+ fw_b,
+ bw_w,
+ bw_u,
+ bw_b,
+ time_w,
+ time_b,
+ }
+ }
+
+ #[cfg(any(feature = "serde", feature = "datagen"))]
+ /// Creates a LstmDataFloat32 with the given data. Fails if the matrix dimensions are inconsistent.
+ #[allow(clippy::too_many_arguments)] // constructor
+ pub fn try_from_parts(
+ model: ModelType,
+ dic: ZeroMap<'data, UnvalidatedStr, u16>,
+ embedding: LstmMatrix2<'data>,
+ fw_w: LstmMatrix3<'data>,
+ fw_u: LstmMatrix3<'data>,
+ fw_b: LstmMatrix2<'data>,
+ bw_w: LstmMatrix3<'data>,
+ bw_u: LstmMatrix3<'data>,
+ bw_b: LstmMatrix2<'data>,
+ time_w: LstmMatrix3<'data>,
+ time_b: LstmMatrix1<'data>,
+ ) -> Result<Self, DataError> {
+ let dic_len = u16::try_from(dic.len())
+ .map_err(|_| DataError::custom("Dictionary does not fit in u16"))?;
+
+ let num_classes = embedding.dims[0];
+ let embedd_dim = embedding.dims[1];
+ let hunits = fw_u.dims[2];
+ if num_classes - 1 != dic_len
+ || fw_w.dims != [4, hunits, embedd_dim]
+ || fw_u.dims != [4, hunits, hunits]
+ || fw_b.dims != [4, hunits]
+ || bw_w.dims != [4, hunits, embedd_dim]
+ || bw_u.dims != [4, hunits, hunits]
+ || bw_b.dims != [4, hunits]
+ || time_w.dims != [2, 4, hunits]
+ || time_b.dims != [4]
+ {
+ return Err(DataError::custom("LSTM dimension mismatch"));
+ }
+
+ #[cfg(debug_assertions)]
+ if !dic.iter_copied_values().all(|(_, g)| g < dic_len) {
+ return Err(DataError::custom("Invalid cluster 
id")); + } + + Ok(Self { + model, + dic, + embedding, + fw_w, + fw_u, + fw_b, + bw_w, + bw_u, + bw_b, + time_w, + time_b, + }) + } +} + +#[cfg(feature = "serde")] +impl<'de: 'data, 'data> serde::Deserialize<'de> for LstmDataFloat32<'data> { + fn deserialize<S>(deserializer: S) -> Result<Self, S::Error> + where + S: serde::de::Deserializer<'de>, + { + #[derive(serde::Deserialize)] + struct Raw<'data> { + model: ModelType, + #[cfg_attr(feature = "serde", serde(borrow))] + dic: ZeroMap<'data, UnvalidatedStr, u16>, + #[cfg_attr(feature = "serde", serde(borrow))] + embedding: LstmMatrix2<'data>, + #[cfg_attr(feature = "serde", serde(borrow))] + fw_w: LstmMatrix3<'data>, + #[cfg_attr(feature = "serde", serde(borrow))] + fw_u: LstmMatrix3<'data>, + #[cfg_attr(feature = "serde", serde(borrow))] + fw_b: LstmMatrix2<'data>, + #[cfg_attr(feature = "serde", serde(borrow))] + bw_w: LstmMatrix3<'data>, + #[cfg_attr(feature = "serde", serde(borrow))] + bw_u: LstmMatrix3<'data>, + #[cfg_attr(feature = "serde", serde(borrow))] + bw_b: LstmMatrix2<'data>, + #[cfg_attr(feature = "serde", serde(borrow))] + time_w: LstmMatrix3<'data>, + #[cfg_attr(feature = "serde", serde(borrow))] + time_b: LstmMatrix1<'data>, + } + + let raw = Raw::deserialize(deserializer)?; + + use serde::de::Error; + Self::try_from_parts( + raw.model, + raw.dic, + raw.embedding, + raw.fw_w, + raw.fw_u, + raw.fw_b, + raw.bw_w, + raw.bw_u, + raw.bw_b, + raw.time_w, + raw.time_b, + ) + .map_err(|_| S::Error::custom("Invalid dimensions")) + } +} + +#[cfg(feature = "datagen")] +impl databake::Bake for LstmDataFloat32<'_> { + fn bake(&self, env: &databake::CrateEnv) -> databake::TokenStream { + let model = self.model.bake(env); + let dic = self.dic.bake(env); + let embedding = self.embedding.bake(env); + let fw_w = self.fw_w.bake(env); + let fw_u = self.fw_u.bake(env); + let fw_b = self.fw_b.bake(env); + let bw_w = self.bw_w.bake(env); + let bw_u = self.bw_u.bake(env); + let bw_b = self.bw_b.bake(env); + let time_w = 
self.time_w.bake(env); + let time_b = self.time_b.bake(env); + databake::quote! { + icu_segmenter::provider::LstmDataFloat32::from_parts_unchecked( + #model, + #dic, + #embedding, + #fw_w, + #fw_u, + #fw_b, + #bw_w, + #bw_u, + #bw_b, + #time_w, + #time_b, + ) + } + } +} + +/// The data to power the LSTM segmentation model. +/// +/// This data enum is extensible: more backends may be added in the future. +/// Old data can be used with newer code but not vice versa. +/// +/// Examples of possible future extensions: +/// +/// 1. Variant to store data in 16 instead of 32 bits +/// 2. Minor changes to the LSTM model, such as different forward/backward matrix sizes +/// +/// <div class="stab unstable"> +/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways, +/// including in SemVer minor releases. While the serde representation of data structs is guaranteed +/// to be stable, their Rust representation might not be. Use with caution. +/// </div> +#[icu_provider::data_struct(LstmForWordLineAutoV1Marker = "segmenter/lstm/wl_auto@1")] +#[derive(Debug, PartialEq, Clone)] +#[cfg_attr( + feature = "datagen", + derive(serde::Serialize, databake::Bake), + databake(path = icu_segmenter::provider), +)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize))] +#[yoke(prove_covariance_manually)] +#[non_exhaustive] +pub enum LstmDataV1<'data> { + /// The data as matrices of zerovec f32 values. 
+ Float32(#[cfg_attr(feature = "serde", serde(borrow))] LstmDataFloat32<'data>), + // new variants should go BELOW existing ones + // Serde serializes based on variant name and index in the enum + // https://docs.rs/serde/latest/serde/trait.Serializer.html#tymethod.serialize_unit_variant +} + +pub(crate) struct LstmDataV1Marker; + +impl DataMarker for LstmDataV1Marker { + type Yokeable = LstmDataV1<'static>; +} diff --git a/third_party/rust/icu_segmenter/src/provider/mod.rs b/third_party/rust/icu_segmenter/src/provider/mod.rs new file mode 100644 index 0000000000..75f0d4d1e7 --- /dev/null +++ b/third_party/rust/icu_segmenter/src/provider/mod.rs @@ -0,0 +1,202 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +//! 🚧 \[Unstable\] Data provider struct definitions for this ICU4X component. +//! +//! <div class="stab unstable"> +//! 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways, +//! including in SemVer minor releases. While the serde representation of data structs is guaranteed +//! to be stable, their Rust representation might not be. Use with caution. +//! </div> +//! +//! Read more about data providers: [`icu_provider`] + +// Provider structs must be stable +#![allow(clippy::exhaustive_structs, clippy::exhaustive_enums)] + +mod lstm; +pub use lstm::*; + +// Re-export this from the provider module because it is needed by datagen +#[cfg(feature = "datagen")] +pub use crate::rule_segmenter::RuleStatusType; + +use icu_collections::codepointtrie::CodePointTrie; +use icu_provider::prelude::*; +use zerovec::ZeroVec; + +#[cfg(feature = "compiled_data")] +#[derive(Debug)] +/// Baked data +/// +/// <div class="stab unstable"> +/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways, +/// including in SemVer minor releases. 
In particular, the `DataProvider` implementations are only +/// guaranteed to match with this version's `*_unstable` providers. Use with caution. +/// </div> +pub struct Baked; + +#[cfg(feature = "compiled_data")] +const _: () = { + pub mod icu { + pub use crate as segmenter; + pub use icu_collections as collections; + } + icu_segmenter_data::make_provider!(Baked); + icu_segmenter_data::impl_segmenter_dictionary_w_auto_v1!(Baked); + icu_segmenter_data::impl_segmenter_dictionary_wl_ext_v1!(Baked); + icu_segmenter_data::impl_segmenter_grapheme_v1!(Baked); + icu_segmenter_data::impl_segmenter_line_v1!(Baked); + #[cfg(feature = "lstm")] + icu_segmenter_data::impl_segmenter_lstm_wl_auto_v1!(Baked); + icu_segmenter_data::impl_segmenter_sentence_v1!(Baked); + icu_segmenter_data::impl_segmenter_word_v1!(Baked); +}; + +#[cfg(feature = "datagen")] +/// The latest minimum set of keys required by this component. +pub const KEYS: &[DataKey] = &[ + DictionaryForWordLineExtendedV1Marker::KEY, + DictionaryForWordOnlyAutoV1Marker::KEY, + GraphemeClusterBreakDataV1Marker::KEY, + LineBreakDataV1Marker::KEY, + LstmForWordLineAutoV1Marker::KEY, + SentenceBreakDataV1Marker::KEY, + WordBreakDataV1Marker::KEY, +]; + +/// Pre-processed Unicode data in the form of tables to be used for rule-based breaking. +/// +/// <div class="stab unstable"> +/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways, +/// including in SemVer minor releases. While the serde representation of data structs is guaranteed +/// to be stable, their Rust representation might not be. Use with caution. 
+/// </div> +#[icu_provider::data_struct( + marker(LineBreakDataV1Marker, "segmenter/line@1", singleton), + marker(WordBreakDataV1Marker, "segmenter/word@1", singleton), + marker(GraphemeClusterBreakDataV1Marker, "segmenter/grapheme@1", singleton), + marker(SentenceBreakDataV1Marker, "segmenter/sentence@1", singleton) +)] +#[derive(Debug, PartialEq, Clone)] +#[cfg_attr( + feature = "datagen", + derive(serde::Serialize,databake::Bake), + databake(path = icu_segmenter::provider), +)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize))] +pub struct RuleBreakDataV1<'data> { + /// Property table for rule-based breaking. + #[cfg_attr(feature = "serde", serde(borrow))] + pub property_table: RuleBreakPropertyTable<'data>, + + /// Break state table for rule-based breaking. + #[cfg_attr(feature = "serde", serde(borrow))] + pub break_state_table: RuleBreakStateTable<'data>, + + /// Rule status table for rule-based breaking. + #[cfg_attr(feature = "serde", serde(borrow))] + pub rule_status_table: RuleStatusTable<'data>, + + /// Number of properties; should be the square root of the length of [`Self::break_state_table`]. + pub property_count: u8, + + /// The index of the last simple state for [`Self::break_state_table`]. (A simple state has no + /// `left` nor `right` in SegmenterProperty). + pub last_codepoint_property: i8, + + /// The index of SOT (start of text) state for [`Self::break_state_table`]. + pub sot_property: u8, + + /// The index of EOT (end of text) state [`Self::break_state_table`]. + pub eot_property: u8, + + /// The index of "SA" state (or 127 if the complex language isn't handled) for + /// [`Self::break_state_table`]. + pub complex_property: u8, +} + +/// Property table for rule-based breaking. +/// +/// <div class="stab unstable"> +/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways, +/// including in SemVer minor releases. 
While the serde representation of data structs is guaranteed +/// to be stable, their Rust representation might not be. Use with caution. +/// </div> +#[derive(Debug, PartialEq, Clone, yoke::Yokeable, zerofrom::ZeroFrom)] +#[cfg_attr( + feature = "datagen", + derive(serde::Serialize,databake::Bake), + databake(path = icu_segmenter::provider), +)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize))] +pub struct RuleBreakPropertyTable<'data>( + #[cfg_attr(feature = "serde", serde(borrow))] pub CodePointTrie<'data, u8>, +); + +/// Break state table for rule-based breaking. +/// +/// <div class="stab unstable"> +/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways, +/// including in SemVer minor releases. While the serde representation of data structs is guaranteed +/// to be stable, their Rust representation might not be. Use with caution. +/// </div> +#[derive(Debug, PartialEq, Clone, yoke::Yokeable, zerofrom::ZeroFrom)] +#[cfg_attr( + feature = "datagen", + derive(serde::Serialize,databake::Bake), + databake(path = icu_segmenter::provider), +)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize))] +pub struct RuleBreakStateTable<'data>( + #[cfg_attr(feature = "serde", serde(borrow))] pub ZeroVec<'data, i8>, +); + +/// Rules status data for rule_status and is_word_like of word segmenter. +/// +/// <div class="stab unstable"> +/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways, +/// including in SemVer minor releases. While the serde representation of data structs is guaranteed +/// to be stable, their Rust representation might not be. Use with caution. 
+/// </div> +#[derive(Debug, PartialEq, Clone, yoke::Yokeable, zerofrom::ZeroFrom)] +#[cfg_attr( + feature = "datagen", + derive(serde::Serialize,databake::Bake), + databake(path = icu_segmenter::provider), +)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize))] +pub struct RuleStatusTable<'data>( + #[cfg_attr(feature = "serde", serde(borrow))] pub ZeroVec<'data, u8>, +); + +/// char16trie data for dictionary break +/// +/// <div class="stab unstable"> +/// 🚧 This code is considered unstable; it may change at any time, in breaking or non-breaking ways, +/// including in SemVer minor releases. While the serde representation of data structs is guaranteed +/// to be stable, their Rust representation might not be. Use with caution. +/// </div> +#[icu_provider::data_struct( + DictionaryForWordOnlyAutoV1Marker = "segmenter/dictionary/w_auto@1", + DictionaryForWordLineExtendedV1Marker = "segmenter/dictionary/wl_ext@1" +)] +#[derive(Debug, PartialEq, Clone)] +#[cfg_attr( + feature = "datagen", + derive(serde::Serialize,databake::Bake), + databake(path = icu_segmenter::provider), +)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize))] +pub struct UCharDictionaryBreakDataV1<'data> { + /// Dictionary data of char16trie. + #[cfg_attr(feature = "serde", serde(borrow))] + pub trie_data: ZeroVec<'data, u16>, +} + +pub(crate) struct UCharDictionaryBreakDataV1Marker; + +impl DataMarker for UCharDictionaryBreakDataV1Marker { + type Yokeable = UCharDictionaryBreakDataV1<'static>; +} diff --git a/third_party/rust/icu_segmenter/src/rule_segmenter.rs b/third_party/rust/icu_segmenter/src/rule_segmenter.rs new file mode 100644 index 0000000000..740138e4ca --- /dev/null +++ b/third_party/rust/icu_segmenter/src/rule_segmenter.rs @@ -0,0 +1,349 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). 
+ +use crate::complex::ComplexPayloads; +use crate::indices::{Latin1Indices, Utf16Indices}; +use crate::provider::RuleBreakDataV1; +use crate::symbols::*; +use core::str::CharIndices; +use utf8_iter::Utf8CharIndices; + +/// The category tag that is returned by +/// [`WordBreakIterator::word_type()`][crate::WordBreakIterator::word_type()]. +#[non_exhaustive] +#[derive(Copy, Clone, PartialEq, Debug)] +#[repr(u8)] +pub enum RuleStatusType { + /// No category tag + None = 0, + /// Number category tag + Number = 1, + /// Letter category tag, including CJK. + Letter = 2, +} + +/// A trait allowing for RuleBreakIterator to be generalized to multiple string +/// encoding methods and granularity such as grapheme cluster, word, etc. +pub trait RuleBreakType<'l, 's> { + /// The iterator over characters. + type IterAttr: Iterator<Item = (usize, Self::CharType)> + Clone + core::fmt::Debug; + + /// The character type. + type CharType: Copy + Into<u32> + core::fmt::Debug; + + fn get_current_position_character_len(iter: &RuleBreakIterator<'l, 's, Self>) -> usize; + + fn handle_complex_language( + iter: &mut RuleBreakIterator<'l, 's, Self>, + left_codepoint: Self::CharType, + ) -> Option<usize>; +} + +/// Implements the [`Iterator`] trait over the segmenter boundaries of the given string. +/// +/// Lifetimes: +/// +/// - `'l` = lifetime of the segmenter object from which this iterator was created +/// - `'s` = lifetime of the string being segmented +/// +/// The [`Iterator::Item`] is an [`usize`] representing index of a code unit +/// _after_ the boundary (for a boundary at the end of text, this index is the length +/// of the [`str`] or array of code units). 
+#[derive(Debug)] +pub struct RuleBreakIterator<'l, 's, Y: RuleBreakType<'l, 's> + ?Sized> { + pub(crate) iter: Y::IterAttr, + pub(crate) len: usize, + pub(crate) current_pos_data: Option<(usize, Y::CharType)>, + pub(crate) result_cache: alloc::vec::Vec<usize>, + pub(crate) data: &'l RuleBreakDataV1<'l>, + pub(crate) complex: Option<&'l ComplexPayloads>, + pub(crate) boundary_property: u8, +} + +impl<'l, 's, Y: RuleBreakType<'l, 's> + ?Sized> Iterator for RuleBreakIterator<'l, 's, Y> { + type Item = usize; + + fn next(&mut self) -> Option<Self::Item> { + // If we have break point cache by previous run, return this result + if let Some(&first_result) = self.result_cache.first() { + let mut i = 0; + loop { + if i == first_result { + self.result_cache = self.result_cache.iter().skip(1).map(|r| r - i).collect(); + return self.get_current_position(); + } + i += Y::get_current_position_character_len(self); + self.advance_iter(); + if self.is_eof() { + self.result_cache.clear(); + return Some(self.len); + } + } + } + + if self.is_eof() { + self.advance_iter(); + if self.is_eof() && self.len == 0 { + // Empty string. Since `self.current_pos_data` is always going to be empty, + // we never read `self.len` except for here, so we can use it to mark that + // we have already returned the single empty-string breakpoint. 
+ self.len = 1;
+ return Some(0);
+ }
+ // SOT x anything
+ let right_prop = self.get_current_break_property()?;
+ if self.is_break_from_table(self.data.sot_property, right_prop) {
+ self.boundary_property = 0; // SOT is special type
+ return self.get_current_position();
+ }
+ }
+
+ loop {
+ debug_assert!(!self.is_eof());
+ let left_codepoint = self.get_current_codepoint()?;
+ let left_prop = self.get_break_property(left_codepoint);
+ self.advance_iter();
+
+ let Some(right_prop) = self.get_current_break_property() else {
+ self.boundary_property = left_prop;
+ return Some(self.len);
+ };
+
+ // Some segmenter rules don't have language-specific rules; we have to use the LSTM (or dictionary) segmenter.
+ // If the property is marked as SA, use it
+ if right_prop == self.data.complex_property {
+ if left_prop != self.data.complex_property {
+ // break before SA
+ self.boundary_property = left_prop;
+ return self.get_current_position();
+ }
+ let break_offset = Y::handle_complex_language(self, left_codepoint);
+ if break_offset.is_some() {
+ return break_offset;
+ }
+ }
+
+ // If break_state is greater than or equal to 0, it is an alias of a property.
+ let mut break_state = self.get_break_state_from_table(left_prop, right_prop);
+
+ if break_state >= 0 {
+ // This isn't simple rule set. We need marker to restore iterator to previous position.
+ let mut previous_iter = self.iter.clone();
+ let mut previous_pos_data = self.current_pos_data;
+ let mut previous_left_prop = left_prop;
+
+ break_state &= !INTERMEDIATE_MATCH_RULE;
+ loop {
+ self.advance_iter();
+
+ let Some(prop) = self.get_current_break_property() else {
+ // Reached EOF. But we are analyzing multiple characters now, so next break may be previous point. 
+ self.boundary_property = break_state as u8; + if self + .get_break_state_from_table(break_state as u8, self.data.eot_property) + == NOT_MATCH_RULE + { + self.boundary_property = previous_left_prop; + self.iter = previous_iter; + self.current_pos_data = previous_pos_data; + return self.get_current_position(); + } + // EOF + return Some(self.len); + }; + + let previous_break_state = break_state; + break_state = self.get_break_state_from_table(break_state as u8, prop); + if break_state < 0 { + break; + } + if previous_break_state >= 0 + && previous_break_state <= self.data.last_codepoint_property + { + // Move marker + previous_iter = self.iter.clone(); + previous_pos_data = self.current_pos_data; + previous_left_prop = break_state as u8; + } + if (break_state & INTERMEDIATE_MATCH_RULE) != 0 { + break_state -= INTERMEDIATE_MATCH_RULE; + previous_iter = self.iter.clone(); + previous_pos_data = self.current_pos_data; + previous_left_prop = break_state as u8; + } + } + if break_state == KEEP_RULE { + continue; + } + if break_state == NOT_MATCH_RULE { + self.boundary_property = previous_left_prop; + self.iter = previous_iter; + self.current_pos_data = previous_pos_data; + return self.get_current_position(); + } + return self.get_current_position(); + } + + if self.is_break_from_table(left_prop, right_prop) { + self.boundary_property = left_prop; + return self.get_current_position(); + } + } + } +} + +impl<'l, 's, Y: RuleBreakType<'l, 's> + ?Sized> RuleBreakIterator<'l, 's, Y> { + pub(crate) fn advance_iter(&mut self) { + self.current_pos_data = self.iter.next(); + } + + pub(crate) fn is_eof(&self) -> bool { + self.current_pos_data.is_none() + } + + pub(crate) fn get_current_break_property(&self) -> Option<u8> { + self.get_current_codepoint() + .map(|c| self.get_break_property(c)) + } + + pub(crate) fn get_current_position(&self) -> Option<usize> { + self.current_pos_data.map(|(pos, _)| pos) + } + + pub(crate) fn get_current_codepoint(&self) -> Option<Y::CharType> { + 
self.current_pos_data.map(|(_, codepoint)| codepoint) + } + + fn get_break_property(&self, codepoint: Y::CharType) -> u8 { + // Note: Default value is 0 == UNKNOWN + self.data.property_table.0.get32(codepoint.into()) + } + + fn get_break_state_from_table(&self, left: u8, right: u8) -> i8 { + let idx = left as usize * self.data.property_count as usize + right as usize; + // We use unwrap_or to fall back to the base case and prevent panics on bad data. + self.data.break_state_table.0.get(idx).unwrap_or(KEEP_RULE) + } + + fn is_break_from_table(&self, left: u8, right: u8) -> bool { + let rule = self.get_break_state_from_table(left, right); + if rule == KEEP_RULE { + return false; + } + if rule >= 0 { + // need additional next characters to get break rule. + return false; + } + true + } + + /// Return the status value of break boundary. + /// If segmenter isn't word, always return RuleStatusType::None + pub fn rule_status(&self) -> RuleStatusType { + if self.result_cache.first().is_some() { + // Dictionary type (CJ and East Asian) is letter. 
+ return RuleStatusType::Letter; + } + if self.boundary_property == 0 { + // break position is SOT / Any + return RuleStatusType::None; + } + match self + .data + .rule_status_table + .0 + .get((self.boundary_property - 1) as usize) + { + Some(1) => RuleStatusType::Number, + Some(2) => RuleStatusType::Letter, + _ => RuleStatusType::None, + } + } + + /// Return true when break boundary is word-like such as letter/number/CJK + /// If segmenter isn't word, return false + pub fn is_word_like(&self) -> bool { + self.rule_status() != RuleStatusType::None + } +} + +#[derive(Debug)] +pub struct RuleBreakTypeUtf8; + +impl<'l, 's> RuleBreakType<'l, 's> for RuleBreakTypeUtf8 { + type IterAttr = CharIndices<'s>; + type CharType = char; + + fn get_current_position_character_len(iter: &RuleBreakIterator<Self>) -> usize { + iter.get_current_codepoint().map_or(0, |c| c.len_utf8()) + } + + fn handle_complex_language( + _: &mut RuleBreakIterator<Self>, + _: Self::CharType, + ) -> Option<usize> { + unreachable!() + } +} + +#[derive(Debug)] +pub struct RuleBreakTypePotentiallyIllFormedUtf8; + +impl<'l, 's> RuleBreakType<'l, 's> for RuleBreakTypePotentiallyIllFormedUtf8 { + type IterAttr = Utf8CharIndices<'s>; + type CharType = char; + + fn get_current_position_character_len(iter: &RuleBreakIterator<Self>) -> usize { + iter.get_current_codepoint().map_or(0, |c| c.len_utf8()) + } + + fn handle_complex_language( + _: &mut RuleBreakIterator<Self>, + _: Self::CharType, + ) -> Option<usize> { + unreachable!() + } +} + +#[derive(Debug)] +pub struct RuleBreakTypeLatin1; + +impl<'l, 's> RuleBreakType<'l, 's> for RuleBreakTypeLatin1 { + type IterAttr = Latin1Indices<'s>; + type CharType = u8; + + fn get_current_position_character_len(_: &RuleBreakIterator<Self>) -> usize { + unreachable!() + } + + fn handle_complex_language( + _: &mut RuleBreakIterator<Self>, + _: Self::CharType, + ) -> Option<usize> { + unreachable!() + } +} + +#[derive(Debug)] +pub struct RuleBreakTypeUtf16; + +impl<'l, 's> 
RuleBreakType<'l, 's> for RuleBreakTypeUtf16 { + type IterAttr = Utf16Indices<'s>; + type CharType = u32; + + fn get_current_position_character_len(iter: &RuleBreakIterator<Self>) -> usize { + match iter.get_current_codepoint() { + None => 0, + Some(ch) if ch >= 0x10000 => 2, + _ => 1, + } + } + + fn handle_complex_language( + _: &mut RuleBreakIterator<Self>, + _: Self::CharType, + ) -> Option<usize> { + unreachable!() + } +} diff --git a/third_party/rust/icu_segmenter/src/sentence.rs b/third_party/rust/icu_segmenter/src/sentence.rs new file mode 100644 index 0000000000..05173f9eb5 --- /dev/null +++ b/third_party/rust/icu_segmenter/src/sentence.rs @@ -0,0 +1,220 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +use alloc::vec::Vec; +use icu_provider::prelude::*; + +use crate::indices::{Latin1Indices, Utf16Indices}; +use crate::iterator_helpers::derive_usize_iterator_with_type; +use crate::rule_segmenter::*; +use crate::{provider::*, SegmenterError}; +use utf8_iter::Utf8CharIndices; + +/// Implements the [`Iterator`] trait over the sentence boundaries of the given string. +/// +/// Lifetimes: +/// +/// - `'l` = lifetime of the segmenter object from which this iterator was created +/// - `'s` = lifetime of the string being segmented +/// +/// The [`Iterator::Item`] is an [`usize`] representing index of a code unit +/// _after_ the boundary (for a boundary at the end of text, this index is the length +/// of the [`str`] or array of code units). +/// +/// For examples of use, see [`SentenceSegmenter`]. +#[derive(Debug)] +pub struct SentenceBreakIterator<'l, 's, Y: RuleBreakType<'l, 's> + ?Sized>( + RuleBreakIterator<'l, 's, Y>, +); + +derive_usize_iterator_with_type!(SentenceBreakIterator); + +/// Sentence break iterator for an `str` (a UTF-8 string). +/// +/// For examples of use, see [`SentenceSegmenter`]. 
+pub type SentenceBreakIteratorUtf8<'l, 's> = SentenceBreakIterator<'l, 's, RuleBreakTypeUtf8>; + +/// Sentence break iterator for a potentially invalid UTF-8 string. +/// +/// For examples of use, see [`SentenceSegmenter`]. +pub type SentenceBreakIteratorPotentiallyIllFormedUtf8<'l, 's> = + SentenceBreakIterator<'l, 's, RuleBreakTypePotentiallyIllFormedUtf8>; + +/// Sentence break iterator for a Latin-1 (8-bit) string. +/// +/// For examples of use, see [`SentenceSegmenter`]. +pub type SentenceBreakIteratorLatin1<'l, 's> = SentenceBreakIterator<'l, 's, RuleBreakTypeLatin1>; + +/// Sentence break iterator for a UTF-16 string. +/// +/// For examples of use, see [`SentenceSegmenter`]. +pub type SentenceBreakIteratorUtf16<'l, 's> = SentenceBreakIterator<'l, 's, RuleBreakTypeUtf16>; + +/// Supports loading sentence break data, and creating sentence break iterators for different string +/// encodings. +/// +/// # Examples +/// +/// Segment a string: +/// +/// ```rust +/// use icu_segmenter::SentenceSegmenter; +/// let segmenter = SentenceSegmenter::new(); +/// +/// let breakpoints: Vec<usize> = +/// segmenter.segment_str("Hello World").collect(); +/// assert_eq!(&breakpoints, &[0, 11]); +/// ``` +/// +/// Segment a Latin1 byte string: +/// +/// ```rust +/// use icu_segmenter::SentenceSegmenter; +/// let segmenter = SentenceSegmenter::new(); +/// +/// let breakpoints: Vec<usize> = +/// segmenter.segment_latin1(b"Hello World").collect(); +/// assert_eq!(&breakpoints, &[0, 11]); +/// ``` +/// +/// Successive boundaries can be used to retrieve the sentences. +/// In particular, the first boundary is always 0, and the last one is the +/// length of the segmented text in code units. +/// +/// ```rust +/// # use icu_segmenter::SentenceSegmenter; +/// # let segmenter = SentenceSegmenter::new(); +/// use itertools::Itertools; +/// let text = "Ceci tuera cela. 
Le livre tuera l’édifice."; +/// let sentences: Vec<&str> = segmenter +/// .segment_str(text) +/// .tuple_windows() +/// .map(|(i, j)| &text[i..j]) +/// .collect(); +/// assert_eq!( +/// &sentences, +/// &["Ceci tuera cela. ", "Le livre tuera l’édifice."] +/// ); +/// ``` +#[derive(Debug)] +pub struct SentenceSegmenter { + payload: DataPayload<SentenceBreakDataV1Marker>, +} + +#[cfg(feature = "compiled_data")] +impl Default for SentenceSegmenter { + fn default() -> Self { + Self::new() + } +} + +impl SentenceSegmenter { + /// Constructs a [`SentenceSegmenter`] with an invariant locale and compiled data. + /// + /// ✨ *Enabled with the `compiled_data` Cargo feature.* + /// + /// [📚 Help choosing a constructor](icu_provider::constructors) + #[cfg(feature = "compiled_data")] + pub fn new() -> Self { + Self { + payload: DataPayload::from_static_ref( + crate::provider::Baked::SINGLETON_SEGMENTER_SENTENCE_V1, + ), + } + } + + icu_provider::gen_any_buffer_data_constructors!(locale: skip, options: skip, error: SegmenterError, + #[cfg(skip)] + functions: [ + new, + try_new_with_any_provider, + try_new_with_buffer_provider, + try_new_unstable, + Self, + ] + ); + + #[doc = icu_provider::gen_any_buffer_unstable_docs!(UNSTABLE, Self::new)] + pub fn try_new_unstable<D>(provider: &D) -> Result<Self, SegmenterError> + where + D: DataProvider<SentenceBreakDataV1Marker> + ?Sized, + { + let payload = provider.load(Default::default())?.take_payload()?; + Ok(Self { payload }) + } + + /// Creates a sentence break iterator for an `str` (a UTF-8 string). + /// + /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string. 
+ pub fn segment_str<'l, 's>(&'l self, input: &'s str) -> SentenceBreakIteratorUtf8<'l, 's> { + SentenceBreakIterator(RuleBreakIterator { + iter: input.char_indices(), + len: input.len(), + current_pos_data: None, + result_cache: Vec::new(), + data: self.payload.get(), + complex: None, + boundary_property: 0, + }) + } + /// Creates a sentence break iterator for a potentially ill-formed UTF8 string + /// + /// Invalid characters are treated as REPLACEMENT CHARACTER + /// + /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string. + pub fn segment_utf8<'l, 's>( + &'l self, + input: &'s [u8], + ) -> SentenceBreakIteratorPotentiallyIllFormedUtf8<'l, 's> { + SentenceBreakIterator(RuleBreakIterator { + iter: Utf8CharIndices::new(input), + len: input.len(), + current_pos_data: None, + result_cache: Vec::new(), + data: self.payload.get(), + complex: None, + boundary_property: 0, + }) + } + /// Creates a sentence break iterator for a Latin-1 (8-bit) string. + /// + /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string. + pub fn segment_latin1<'l, 's>( + &'l self, + input: &'s [u8], + ) -> SentenceBreakIteratorLatin1<'l, 's> { + SentenceBreakIterator(RuleBreakIterator { + iter: Latin1Indices::new(input), + len: input.len(), + current_pos_data: None, + result_cache: Vec::new(), + data: self.payload.get(), + complex: None, + boundary_property: 0, + }) + } + + /// Creates a sentence break iterator for a UTF-16 string. + /// + /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string. 
+ pub fn segment_utf16<'l, 's>(&'l self, input: &'s [u16]) -> SentenceBreakIteratorUtf16<'l, 's> { + SentenceBreakIterator(RuleBreakIterator { + iter: Utf16Indices::new(input), + len: input.len(), + current_pos_data: None, + result_cache: Vec::new(), + data: self.payload.get(), + complex: None, + boundary_property: 0, + }) + } +} + +#[cfg(all(test, feature = "serde"))] +#[test] +fn empty_string() { + let segmenter = SentenceSegmenter::new(); + let breaks: Vec<usize> = segmenter.segment_str("").collect(); + assert_eq!(breaks, [0]); +} diff --git a/third_party/rust/icu_segmenter/src/symbols.rs b/third_party/rust/icu_segmenter/src/symbols.rs new file mode 100644 index 0000000000..b2c9a2450f --- /dev/null +++ b/third_party/rust/icu_segmenter/src/symbols.rs @@ -0,0 +1,141 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +// TODO(#1637): The numeric values of these symbols are generated by the old transformation code +// (aka build.rs). We should move these symbols into RuleBreakDataV1, and remove this file. + +// Used by line.rs. 
// Line-break property classes. The names mirror the UAX #14 line break
// classes (AI, AL, B2, BA, ...) with a few data-driven refinements (e.g.
// ID_CN, OP_EA, PO_EAW). The numeric values come from the old transformation
// code (build.rs) and must stay in sync with the generated break data — do
// not renumber or reorder.
#[allow(dead_code)]
pub const UNKNOWN: u8 = 0;
#[allow(dead_code)]
pub const AI: u8 = 1;
#[allow(dead_code)]
pub const AL: u8 = 2;
#[allow(dead_code)]
pub const B2: u8 = 3;
#[allow(dead_code)]
pub const BA: u8 = 4;
#[allow(dead_code)]
pub const BB: u8 = 5;
#[allow(dead_code)]
pub const BK: u8 = 6;
#[allow(dead_code)]
pub const CB: u8 = 7;
#[allow(dead_code)]
pub const CJ: u8 = 8;
#[allow(dead_code)]
pub const CL: u8 = 9;
#[allow(dead_code)]
pub const CM: u8 = 10;
#[allow(dead_code)]
pub const CP: u8 = 11;
#[allow(dead_code)]
pub const CR: u8 = 12;
#[allow(dead_code)]
pub const EB: u8 = 13;
#[allow(dead_code)]
pub const EM: u8 = 14;
#[allow(dead_code)]
pub const EX: u8 = 15;
#[allow(dead_code)]
pub const GL: u8 = 16;
#[allow(dead_code)]
pub const H2: u8 = 17;
#[allow(dead_code)]
pub const H3: u8 = 18;
#[allow(dead_code)]
pub const HL: u8 = 19;
#[allow(dead_code)]
pub const HY: u8 = 20;
#[allow(dead_code)]
pub const ID: u8 = 21;
#[allow(dead_code)]
pub const ID_CN: u8 = 22;
#[allow(dead_code)]
pub const IN: u8 = 23;
#[allow(dead_code)]
pub const IS: u8 = 24;
#[allow(dead_code)]
pub const JL: u8 = 25;
#[allow(dead_code)]
pub const JT: u8 = 26;
#[allow(dead_code)]
pub const JV: u8 = 27;
#[allow(dead_code)]
pub const LF: u8 = 28;
#[allow(dead_code)]
pub const NL: u8 = 29;
#[allow(dead_code)]
pub const NS: u8 = 30;
#[allow(dead_code)]
pub const NU: u8 = 31;
#[allow(dead_code)]
pub const OP_EA: u8 = 32;
#[allow(dead_code)]
pub const OP_OP30: u8 = 33;
#[allow(dead_code)]
pub const PO: u8 = 34;
#[allow(dead_code)]
pub const PO_EAW: u8 = 35;
#[allow(dead_code)]
pub const PR: u8 = 36;
#[allow(dead_code)]
pub const PR_EAW: u8 = 37;
#[allow(dead_code)]
pub const QU: u8 = 38;
#[allow(dead_code)]
pub const RI: u8 = 39;
#[allow(dead_code)]
pub const SA: u8 = 40;
#[allow(dead_code)]
pub const SG: u8 = 41;
#[allow(dead_code)]
pub const SP: u8 = 42;
#[allow(dead_code)]
pub const SY: u8 = 43;
#[allow(dead_code)]
pub const WJ: u8 = 44;
#[allow(dead_code)]
pub const XX: u8 = 45;
#[allow(dead_code)]
pub const ZW: u8 = 46;
#[allow(dead_code)]
pub const ZWJ: u8 = 47;
// Synthetic states beyond the plain break classes: these appear to be
// intermediate states for matching multi-character rules (e.g. the LB25_*
// states track number sequences for line-break rule LB25) — see the rule
// tables generated by build.rs.
#[allow(dead_code)]
pub const OP_SP: u8 = 48;
#[allow(dead_code)]
pub const QU_SP: u8 = 49;
#[allow(dead_code)]
pub const CL_CP_SP: u8 = 50;
#[allow(dead_code)]
pub const B2_SP: u8 = 51;
#[allow(dead_code)]
pub const HL_HY: u8 = 52;
#[allow(dead_code)]
pub const LB25_HY: u8 = 53;
#[allow(dead_code)]
pub const LB25_OP: u8 = 54;
#[allow(dead_code)]
pub const LB25_NU_IS: u8 = 55;
#[allow(dead_code)]
pub const LB25_NU_SY: u8 = 56;
#[allow(dead_code)]
pub const LB25_NU_CL: u8 = 57;
#[allow(dead_code)]
pub const LB25_NU_CP: u8 = 58;
#[allow(dead_code)]
pub const RI_RI: u8 = 59;
// Start-of-text / end-of-text sentinels.
#[allow(dead_code)]
pub const SOT: u8 = 60;
#[allow(dead_code)]
pub const EOT: u8 = 61;

// Used by all segmenters. Negative values are rule outcomes rather than
// property classes, so they cannot collide with the u8 classes above.
pub const BREAK_RULE: i8 = -128;
pub const UNKNOWN_RULE: i8 = -127;
pub const NOT_MATCH_RULE: i8 = -2;
pub const KEEP_RULE: i8 = -1;
// This is a mask bit chosen sufficiently large than all other concrete states.
// If a break state contains this bit, we have to look ahead one more character.
pub const INTERMEDIATE_MATCH_RULE: i8 = 64;
diff --git a/third_party/rust/icu_segmenter/src/word.rs b/third_party/rust/icu_segmenter/src/word.rs
new file mode 100644
index 0000000000..de4af16543
--- /dev/null
+++ b/third_party/rust/icu_segmenter/src/word.rs
@@ -0,0 +1,605 @@
// This file is part of ICU4X. For terms of use, please see the file
// called LICENSE at the top level of the ICU4X source tree
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).

use crate::complex::*;
use crate::indices::{Latin1Indices, Utf16Indices};
use crate::iterator_helpers::derive_usize_iterator_with_type;
use crate::provider::*;
use crate::rule_segmenter::*;
use crate::SegmenterError;
use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;
use core::str::CharIndices;
use icu_provider::prelude::*;
use utf8_iter::Utf8CharIndices;

/// Implements the [`Iterator`] trait over the word boundaries of the given string.
///
/// Lifetimes:
///
/// - `'l` = lifetime of the segmenter object from which this iterator was created
/// - `'s` = lifetime of the string being segmented
///
/// The [`Iterator::Item`] is an [`usize`] representing index of a code unit
/// _after_ the boundary (for a boundary at the end of text, this index is the length
/// of the [`str`] or array of code units).
///
/// For examples of use, see [`WordSegmenter`].
#[derive(Debug)]
pub struct WordBreakIterator<'l, 's, Y: RuleBreakType<'l, 's> + ?Sized>(
    RuleBreakIterator<'l, 's, Y>,
);

// Forwards the `Iterator<Item = usize>` implementation to the wrapped
// `RuleBreakIterator`.
derive_usize_iterator_with_type!(WordBreakIterator);

/// The word type tag that is returned by [`WordBreakIterator::word_type()`].
#[non_exhaustive]
#[derive(Copy, Clone, PartialEq, Debug)]
#[repr(u8)]
pub enum WordType {
    /// No category tag.
    None = 0,
    /// Number category tag.
    Number = 1,
    /// Letter category tag, including CJK.
    Letter = 2,
}

impl<'l, 's, Y: RuleBreakType<'l, 's> + ?Sized> WordBreakIterator<'l, 's, Y> {
    /// Returns the word type of the segment preceding the current boundary.
    #[inline]
    pub fn word_type(&self) -> WordType {
        // Translate the internal rule status into the public, non-exhaustive enum.
        match self.0.rule_status() {
            RuleStatusType::None => WordType::None,
            RuleStatusType::Number => WordType::Number,
            RuleStatusType::Letter => WordType::Letter,
        }
    }

    /// Returns `true` when the segment preceding the current boundary is word-like,
    /// such as letter, number, or CJK.
    #[inline]
    pub fn is_word_like(&self) -> bool {
        self.0.is_word_like()
    }
}

/// Word break iterator for an `str` (a UTF-8 string).
///
/// For examples of use, see [`WordSegmenter`].
pub type WordBreakIteratorUtf8<'l, 's> = WordBreakIterator<'l, 's, WordBreakTypeUtf8>;

/// Word break iterator for a potentially invalid UTF-8 string.
///
/// For examples of use, see [`WordSegmenter`].
pub type WordBreakIteratorPotentiallyIllFormedUtf8<'l, 's> =
    WordBreakIterator<'l, 's, WordBreakTypePotentiallyIllFormedUtf8>;

/// Word break iterator for a Latin-1 (8-bit) string.
///
/// For examples of use, see [`WordSegmenter`].
pub type WordBreakIteratorLatin1<'l, 's> = WordBreakIterator<'l, 's, RuleBreakTypeLatin1>;

/// Word break iterator for a UTF-16 string.
///
/// For examples of use, see [`WordSegmenter`].
pub type WordBreakIteratorUtf16<'l, 's> = WordBreakIterator<'l, 's, WordBreakTypeUtf16>;

/// Supports loading word break data, and creating word break iterators for different string
/// encodings.
///
/// # Examples
///
/// Segment a string:
///
/// ```rust
/// use icu_segmenter::WordSegmenter;
/// let segmenter = WordSegmenter::new_auto();
///
/// let breakpoints: Vec<usize> =
///     segmenter.segment_str("Hello World").collect();
/// assert_eq!(&breakpoints, &[0, 5, 6, 11]);
/// ```
///
/// Segment a Latin1 byte string:
///
/// ```rust
/// use icu_segmenter::WordSegmenter;
/// let segmenter = WordSegmenter::new_auto();
///
/// let breakpoints: Vec<usize> =
///     segmenter.segment_latin1(b"Hello World").collect();
/// assert_eq!(&breakpoints, &[0, 5, 6, 11]);
/// ```
///
/// Successive boundaries can be used to retrieve the segments.
/// In particular, the first boundary is always 0, and the last one is the
/// length of the segmented text in code units.
///
/// ```rust
/// # use icu_segmenter::WordSegmenter;
/// # let segmenter = WordSegmenter::new_auto();
/// use itertools::Itertools;
/// let text = "Mark’d ye his words?";
/// let segments: Vec<&str> = segmenter
///     .segment_str(text)
///     .tuple_windows()
///     .map(|(i, j)| &text[i..j])
///     .collect();
/// assert_eq!(
///     &segments,
///     &["Mark’d", " ", "ye", " ", "his", " ", "words", "?"]
/// );
/// ```
///
/// Not all segments delimited by word boundaries are words; some are interword
/// segments such as spaces and punctuation.
/// The [`WordBreakIterator::word_type()`] of a boundary can be used to
/// classify the preceding segment.
/// ```rust
/// # use itertools::Itertools;
/// # use icu_segmenter::{WordType, WordSegmenter};
/// # let segmenter = WordSegmenter::new_auto();
/// # let text = "Mark’d ye his words?";
/// let words: Vec<&str> = {
///     let mut it = segmenter.segment_str(text);
///     std::iter::from_fn(move || it.next().map(|i| (i, it.word_type())))
///         .tuple_windows()
///         .filter(|(_, (_, status))| *status == WordType::Letter)
///         .map(|((i, _), (j, _))| &text[i..j])
///         .collect()
/// };
/// assert_eq!(&words, &["Mark’d", "ye", "his", "words"]);
/// ```
#[derive(Debug)]
pub struct WordSegmenter {
    // Rule-based word break data.
    payload: DataPayload<WordBreakDataV1Marker>,
    // Dictionary and/or LSTM payloads used for complex-script runs; which
    // models are loaded depends on the constructor used (auto/lstm/dictionary).
    complex: ComplexPayloads,
}

impl WordSegmenter {
    /// Constructs a [`WordSegmenter`] with an invariant locale and the best available compiled data for
    /// complex scripts (Chinese, Japanese, Khmer, Lao, Myanmar, and Thai).
    ///
    /// The current behavior, which is subject to change, is to use the LSTM model when available
    /// and the dictionary model for Chinese and Japanese.
    ///
    /// ✨ *Enabled with the `compiled_data` and `auto` Cargo features.*
    ///
    /// [📚 Help choosing a constructor](icu_provider::constructors)
    ///
    /// # Examples
    ///
    /// Behavior with complex scripts:
    ///
    /// ```
    /// use icu::segmenter::WordSegmenter;
    ///
    /// let th_str = "ทุกสองสัปดาห์";
    /// let ja_str = "こんにちは世界";
    ///
    /// let segmenter = WordSegmenter::new_auto();
    ///
    /// let th_bps = segmenter.segment_str(th_str).collect::<Vec<_>>();
    /// let ja_bps = segmenter.segment_str(ja_str).collect::<Vec<_>>();
    ///
    /// assert_eq!(th_bps, [0, 9, 18, 39]);
    /// assert_eq!(ja_bps, [0, 15, 21]);
    /// ```
    #[cfg(feature = "compiled_data")]
    #[cfg(feature = "auto")]
    pub fn new_auto() -> Self {
        Self {
            payload: DataPayload::from_static_ref(
                crate::provider::Baked::SINGLETON_SEGMENTER_WORD_V1,
            ),
            complex: ComplexPayloads::new_auto(),
        }
    }

    // Generates the with_any_provider / with_buffer_provider constructor
    // variants listed below from try_new_auto_unstable.
    #[cfg(feature = "auto")]
    icu_provider::gen_any_buffer_data_constructors!(
        locale: skip,
        options: skip,
        error: SegmenterError,
        #[cfg(skip)]
        functions: [
            try_new_auto,
            try_new_auto_with_any_provider,
            try_new_auto_with_buffer_provider,
            try_new_auto_unstable,
            Self
        ]
    );

    #[cfg(feature = "auto")]
    #[doc = icu_provider::gen_any_buffer_unstable_docs!(UNSTABLE, Self::new_auto)]
    pub fn try_new_auto_unstable<D>(provider: &D) -> Result<Self, SegmenterError>
    where
        D: DataProvider<WordBreakDataV1Marker>
            + DataProvider<DictionaryForWordOnlyAutoV1Marker>
            + DataProvider<LstmForWordLineAutoV1Marker>
            + DataProvider<GraphemeClusterBreakDataV1Marker>
            + ?Sized,
    {
        Ok(Self {
            payload: provider.load(Default::default())?.take_payload()?,
            complex: ComplexPayloads::try_new_auto(provider)?,
        })
    }

    /// Constructs a [`WordSegmenter`] with an invariant locale and compiled LSTM data for
    /// complex scripts (Burmese, Khmer, Lao, and Thai).
    ///
    /// The LSTM, or Long Term Short Memory, is a machine learning model. It is smaller than
    /// the full dictionary but more expensive during segmentation (inference).
    ///
    /// Warning: there is not currently an LSTM model for Chinese or Japanese, so the [`WordSegmenter`]
    /// created by this function will have unexpected behavior in spans of those scripts.
    ///
    /// ✨ *Enabled with the `compiled_data` and `lstm` Cargo features.*
    ///
    /// [📚 Help choosing a constructor](icu_provider::constructors)
    ///
    /// # Examples
    ///
    /// Behavior with complex scripts:
    ///
    /// ```
    /// use icu::segmenter::WordSegmenter;
    ///
    /// let th_str = "ทุกสองสัปดาห์";
    /// let ja_str = "こんにちは世界";
    ///
    /// let segmenter = WordSegmenter::new_lstm();
    ///
    /// let th_bps = segmenter.segment_str(th_str).collect::<Vec<_>>();
    /// let ja_bps = segmenter.segment_str(ja_str).collect::<Vec<_>>();
    ///
    /// assert_eq!(th_bps, [0, 9, 18, 39]);
    ///
    /// // Note: We aren't able to find a suitable breakpoint in Chinese/Japanese.
    /// assert_eq!(ja_bps, [0, 21]);
    /// ```
    #[cfg(feature = "compiled_data")]
    #[cfg(feature = "lstm")]
    pub fn new_lstm() -> Self {
        Self {
            payload: DataPayload::from_static_ref(
                crate::provider::Baked::SINGLETON_SEGMENTER_WORD_V1,
            ),
            complex: ComplexPayloads::new_lstm(),
        }
    }

    // Generates the with_any_provider / with_buffer_provider constructor
    // variants listed below from try_new_lstm_unstable.
    #[cfg(feature = "lstm")]
    icu_provider::gen_any_buffer_data_constructors!(
        locale: skip,
        options: skip,
        error: SegmenterError,
        #[cfg(skip)]
        functions: [
            new_lstm,
            try_new_lstm_with_any_provider,
            try_new_lstm_with_buffer_provider,
            try_new_lstm_unstable,
            Self
        ]
    );

    #[cfg(feature = "lstm")]
    #[doc = icu_provider::gen_any_buffer_unstable_docs!(UNSTABLE, Self::new_lstm)]
    pub fn try_new_lstm_unstable<D>(provider: &D) -> Result<Self, SegmenterError>
    where
        D: DataProvider<WordBreakDataV1Marker>
            + DataProvider<LstmForWordLineAutoV1Marker>
            + DataProvider<GraphemeClusterBreakDataV1Marker>
            + ?Sized,
    {
        Ok(Self {
            payload: provider.load(Default::default())?.take_payload()?,
            complex: ComplexPayloads::try_new_lstm(provider)?,
        })
    }

    /// Construct a [`WordSegmenter`] with an invariant locale and compiled dictionary data for
    /// complex scripts (Chinese, Japanese, Khmer, Lao, Myanmar, and Thai).
    ///
    /// The dictionary model uses a list of words to determine appropriate breakpoints. It is
    /// faster than the LSTM model but requires more data.
    ///
    /// ✨ *Enabled with the `compiled_data` Cargo feature.*
    ///
    /// [📚 Help choosing a constructor](icu_provider::constructors)
    ///
    /// # Examples
    ///
    /// Behavior with complex scripts:
    ///
    /// ```
    /// use icu::segmenter::WordSegmenter;
    ///
    /// let th_str = "ทุกสองสัปดาห์";
    /// let ja_str = "こんにちは世界";
    ///
    /// let segmenter = WordSegmenter::new_dictionary();
    ///
    /// let th_bps = segmenter.segment_str(th_str).collect::<Vec<_>>();
    /// let ja_bps = segmenter.segment_str(ja_str).collect::<Vec<_>>();
    ///
    /// assert_eq!(th_bps, [0, 9, 18, 39]);
    /// assert_eq!(ja_bps, [0, 15, 21]);
    /// ```
    #[cfg(feature = "compiled_data")]
    pub fn new_dictionary() -> Self {
        Self {
            payload: DataPayload::from_static_ref(
                crate::provider::Baked::SINGLETON_SEGMENTER_WORD_V1,
            ),
            complex: ComplexPayloads::new_dict(),
        }
    }

    // Generates the with_any_provider / with_buffer_provider constructor
    // variants listed below from try_new_dictionary_unstable.
    icu_provider::gen_any_buffer_data_constructors!(
        locale: skip,
        options: skip,
        error: SegmenterError,
        #[cfg(skip)]
        functions: [
            new_dictionary,
            try_new_dictionary_with_any_provider,
            try_new_dictionary_with_buffer_provider,
            try_new_dictionary_unstable,
            Self
        ]
    );

    #[doc = icu_provider::gen_any_buffer_unstable_docs!(UNSTABLE, Self::new_dictionary)]
    pub fn try_new_dictionary_unstable<D>(provider: &D) -> Result<Self, SegmenterError>
    where
        D: DataProvider<WordBreakDataV1Marker>
            + DataProvider<DictionaryForWordOnlyAutoV1Marker>
            + DataProvider<DictionaryForWordLineExtendedV1Marker>
            + DataProvider<GraphemeClusterBreakDataV1Marker>
            + ?Sized,
    {
        Ok(Self {
            payload:
                provider.load(Default::default())?.take_payload()?,
            complex: ComplexPayloads::try_new_dict(provider)?,
        })
    }

    /// Creates a word break iterator for an `str` (a UTF-8 string).
    ///
    /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string.
    pub fn segment_str<'l, 's>(&'l self, input: &'s str) -> WordBreakIteratorUtf8<'l, 's> {
        WordBreakIterator(RuleBreakIterator {
            iter: input.char_indices(),
            len: input.len(),
            current_pos_data: None,
            result_cache: Vec::new(),
            data: self.payload.get(),
            // Unlike sentence segmentation, word breaking delegates
            // complex-script runs to the dictionary/LSTM payloads.
            complex: Some(&self.complex),
            boundary_property: 0,
        })
    }

    /// Creates a word break iterator for a potentially ill-formed UTF8 string
    ///
    /// Invalid characters are treated as REPLACEMENT CHARACTER
    ///
    /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string.
    pub fn segment_utf8<'l, 's>(
        &'l self,
        input: &'s [u8],
    ) -> WordBreakIteratorPotentiallyIllFormedUtf8<'l, 's> {
        WordBreakIterator(RuleBreakIterator {
            iter: Utf8CharIndices::new(input),
            len: input.len(),
            current_pos_data: None,
            result_cache: Vec::new(),
            data: self.payload.get(),
            complex: Some(&self.complex),
            boundary_property: 0,
        })
    }

    /// Creates a word break iterator for a Latin-1 (8-bit) string.
    ///
    /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string.
    pub fn segment_latin1<'l, 's>(&'l self, input: &'s [u8]) -> WordBreakIteratorLatin1<'l, 's> {
        WordBreakIterator(RuleBreakIterator {
            iter: Latin1Indices::new(input),
            len: input.len(),
            current_pos_data: None,
            result_cache: Vec::new(),
            data: self.payload.get(),
            complex: Some(&self.complex),
            boundary_property: 0,
        })
    }

    /// Creates a word break iterator for a UTF-16 string.
    ///
    /// There are always breakpoints at 0 and the string length, or only at 0 for the empty string.
    pub fn segment_utf16<'l, 's>(&'l self, input: &'s [u16]) -> WordBreakIteratorUtf16<'l, 's> {
        WordBreakIterator(RuleBreakIterator {
            iter: Utf16Indices::new(input),
            len: input.len(),
            current_pos_data: None,
            result_cache: Vec::new(),
            data: self.payload.get(),
            complex: Some(&self.complex),
            boundary_property: 0,
        })
    }
}

/// Marker type selecting the well-formed UTF-8 code path of [`RuleBreakIterator`].
#[derive(Debug)]
pub struct WordBreakTypeUtf8;

impl<'l, 's> RuleBreakType<'l, 's> for WordBreakTypeUtf8 {
    type IterAttr = CharIndices<'s>;
    type CharType = char;

    fn get_current_position_character_len(iter: &RuleBreakIterator<Self>) -> usize {
        // Length in UTF-8 bytes of the code point at the current position;
        // 0 when there is no current code point.
        iter.get_current_codepoint().map_or(0, |c| c.len_utf8())
    }

    fn handle_complex_language(
        iter: &mut RuleBreakIterator<'l, 's, Self>,
        left_codepoint: Self::CharType,
    ) -> Option<usize> {
        handle_complex_language_utf8(iter, left_codepoint)
    }
}

/// Marker type selecting the lossy (REPLACEMENT CHARACTER) UTF-8 code path.
#[derive(Debug)]
pub struct WordBreakTypePotentiallyIllFormedUtf8;

impl<'l, 's> RuleBreakType<'l, 's> for WordBreakTypePotentiallyIllFormedUtf8 {
    type IterAttr = Utf8CharIndices<'s>;
    type CharType = char;

    fn get_current_position_character_len(iter: &RuleBreakIterator<Self>) -> usize {
        iter.get_current_codepoint().map_or(0, |c| c.len_utf8())
    }

    fn handle_complex_language(
        iter: &mut RuleBreakIterator<'l, 's, Self>,
        left_codepoint: Self::CharType,
    ) -> Option<usize> {
        handle_complex_language_utf8(iter, left_codepoint)
    }
}

/// handle_complex_language impl for UTF8 iterators
///
/// Shared by the well-formed and potentially-ill-formed UTF-8 break types,
/// which both iterate over `char`s with byte indices.
fn handle_complex_language_utf8<'l, 's, T>(
    iter: &mut RuleBreakIterator<'l, 's, T>,
    left_codepoint: T::CharType,
) -> Option<usize>
where
    T: RuleBreakType<'l, 's, CharType = char>,
{
    // word segmenter doesn't define break rules for some languages such as Thai.
    // Save the iterator state so we can rewind after buffering the whole
    // run of complex-script characters.
    let start_iter = iter.iter.clone();
    let start_point = iter.current_pos_data;
    let mut s = String::new();
    s.push(left_codepoint);
    // Accumulate characters while they carry the "complex" break property.
    loop {
        debug_assert!(!iter.is_eof());
        s.push(iter.get_current_codepoint()?);
        iter.advance_iter();
        if let Some(current_break_property) = iter.get_current_break_property() {
            if current_break_property != iter.data.complex_property {
                break;
            }
        } else {
            // EOF
            break;
        }
    }

    // Restore iterator to move to head of complex string
    iter.iter = start_iter;
    iter.current_pos_data = start_point;
    // Run the dictionary/LSTM segmenter over the buffered run and cache the
    // resulting breakpoints (byte offsets into `s`).
    #[allow(clippy::unwrap_used)] // iter.complex present for word segmenter
    let breaks = complex_language_segment_str(iter.complex.unwrap(), &s);
    iter.result_cache = breaks;
    let first_pos = *iter.result_cache.first()?;
    // Walk forward until the first cached breakpoint, then re-base the rest
    // of the cache relative to it so later calls can consume it directly.
    let mut i = left_codepoint.len_utf8();
    loop {
        if i == first_pos {
            // Re-calculate breaking offset
            iter.result_cache = iter.result_cache.iter().skip(1).map(|r| r - i).collect();
            return iter.get_current_position();
        }
        debug_assert!(
            i < first_pos,
            "we should always arrive at first_pos: near index {:?}",
            iter.get_current_position()
        );
        i += T::get_current_position_character_len(iter);
        iter.advance_iter();
        if iter.is_eof() {
            iter.result_cache.clear();
            return Some(iter.len);
        }
    }
}

/// Marker type selecting the UTF-16 code path of [`RuleBreakIterator`].
#[derive(Debug)]
pub struct WordBreakTypeUtf16;

impl<'l, 's> RuleBreakType<'l, 's> for WordBreakTypeUtf16 {
    type IterAttr = Utf16Indices<'s>;
    type CharType = u32;

    fn get_current_position_character_len(iter: &RuleBreakIterator<Self>) -> usize {
        // Number of UTF-16 code units: 2 for supplementary-plane code points.
        match iter.get_current_codepoint() {
            None => 0,
            Some(ch) if ch >= 0x10000 => 2,
            _ => 1,
        }
    }

    fn handle_complex_language(
        iter: &mut RuleBreakIterator<Self>,
        left_codepoint: Self::CharType,
    ) -> Option<usize> {
        // word segmenter doesn't define break rules for some languages such as Thai.
        // Same save/segment/rewind/re-base scheme as handle_complex_language_utf8,
        // but operating on UTF-16 code units.
        let start_iter = iter.iter.clone();
        let start_point = iter.current_pos_data;
        // NOTE(review): code points are pushed as a single truncated u16, so
        // supplementary-plane characters would be mangled here — presumably the
        // complex scripts reaching this path are BMP-only; confirm upstream.
        let mut s = vec![left_codepoint as u16];
        loop {
            debug_assert!(!iter.is_eof());
            s.push(iter.get_current_codepoint()? as u16);
            iter.advance_iter();
            if let Some(current_break_property) = iter.get_current_break_property() {
                if current_break_property != iter.data.complex_property {
                    break;
                }
            } else {
                // EOF
                break;
            }
        }

        // Restore iterator to move to head of complex string
        iter.iter = start_iter;
        iter.current_pos_data = start_point;
        #[allow(clippy::unwrap_used)] // iter.complex present for word segmenter
        let breaks = complex_language_segment_utf16(iter.complex.unwrap(), &s);
        iter.result_cache = breaks;
        // result_cache vector is utf-16 index that is in BMP.
        let first_pos = *iter.result_cache.first()?;
        let mut i = 1;
        loop {
            if i == first_pos {
                // Re-calculate breaking offset
                iter.result_cache = iter.result_cache.iter().skip(1).map(|r| r - i).collect();
                return iter.get_current_position();
            }
            debug_assert!(
                i < first_pos,
                "we should always arrive at first_pos: near index {:?}",
                iter.get_current_position()
            );
            i += 1;
            iter.advance_iter();
            if iter.is_eof() {
                iter.result_cache.clear();
                return Some(iter.len);
            }
        }
    }
}

// Smoke test: an empty string still yields the single mandatory breakpoint at 0.
#[cfg(all(test, feature = "serde"))]
#[test]
fn empty_string() {
    let segmenter = WordSegmenter::new_auto();
    let breaks: Vec<usize> = segmenter.segment_str("").collect();
    assert_eq!(breaks, [0]);
}