diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
commit | 26a029d407be480d791972afb5975cf62c9360a6 (patch) | |
tree | f435a8308119effd964b339f76abb83a57c29483 /third_party/rust/icu_segmenter/src/complex | |
parent | Initial commit. (diff) | |
download | firefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz firefox-26a029d407be480d791972afb5975cf62c9360a6.zip |
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/icu_segmenter/src/complex')
-rw-r--r-- | third_party/rust/icu_segmenter/src/complex/dictionary.rs | 268 | ||||
-rw-r--r-- | third_party/rust/icu_segmenter/src/complex/language.rs | 161 | ||||
-rw-r--r-- | third_party/rust/icu_segmenter/src/complex/lstm/matrix.rs | 540 | ||||
-rw-r--r-- | third_party/rust/icu_segmenter/src/complex/lstm/mod.rs | 402 | ||||
-rw-r--r-- | third_party/rust/icu_segmenter/src/complex/mod.rs | 440 |
5 files changed, 1811 insertions, 0 deletions
diff --git a/third_party/rust/icu_segmenter/src/complex/dictionary.rs b/third_party/rust/icu_segmenter/src/complex/dictionary.rs new file mode 100644 index 0000000000..90360ee2b0 --- /dev/null +++ b/third_party/rust/icu_segmenter/src/complex/dictionary.rs @@ -0,0 +1,268 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +use crate::grapheme::*; +use crate::indices::Utf16Indices; +use crate::provider::*; +use core::str::CharIndices; +use icu_collections::char16trie::{Char16Trie, TrieResult}; + +/// A trait for dictionary based iterator +trait DictionaryType<'l, 's> { + /// The iterator over characters. + type IterAttr: Iterator<Item = (usize, Self::CharType)> + Clone; + + /// The character type. + type CharType: Copy + Into<u32>; + + fn to_char(c: Self::CharType) -> char; + fn char_len(c: Self::CharType) -> usize; +} + +struct DictionaryBreakIterator< + 'l, + 's, + Y: DictionaryType<'l, 's> + ?Sized, + X: Iterator<Item = usize> + ?Sized, +> { + trie: Char16Trie<'l>, + iter: Y::IterAttr, + len: usize, + grapheme_iter: X, + // TODO transform value for byte trie +} + +/// Implement the [`Iterator`] trait over the segmenter break opportunities of the given string. +/// Please see the [module-level documentation](crate) for its usages. 
+/// +/// Lifetimes: +/// - `'l` = lifetime of the segmenter object from which this iterator was created +/// - `'s` = lifetime of the string being segmented +/// +/// [`Iterator`]: core::iter::Iterator +impl<'l, 's, Y: DictionaryType<'l, 's> + ?Sized, X: Iterator<Item = usize> + ?Sized> Iterator + for DictionaryBreakIterator<'l, 's, Y, X> +{ + type Item = usize; + + fn next(&mut self) -> Option<Self::Item> { + let mut trie_iter = self.trie.iter(); + let mut intermediate_length = 0; + let mut not_match = false; + let mut previous_match = None; + let mut last_grapheme_offset = 0; + + while let Some(next) = self.iter.next() { + let ch = Y::to_char(next.1); + match trie_iter.next(ch) { + TrieResult::FinalValue(_) => { + return Some(next.0 + Y::char_len(next.1)); + } + TrieResult::Intermediate(_) => { + // Dictionary has to match with grapheme cluster segment. + // If not, we ignore it. + while last_grapheme_offset < next.0 + Y::char_len(next.1) { + if let Some(offset) = self.grapheme_iter.next() { + last_grapheme_offset = offset; + continue; + } + last_grapheme_offset = self.len; + break; + } + if last_grapheme_offset != next.0 + Y::char_len(next.1) { + continue; + } + + intermediate_length = next.0 + Y::char_len(next.1); + previous_match = Some(self.iter.clone()); + } + TrieResult::NoMatch => { + if intermediate_length > 0 { + if let Some(previous_match) = previous_match { + // Rewind previous match point + self.iter = previous_match; + } + return Some(intermediate_length); + } + // Not found + return Some(next.0 + Y::char_len(next.1)); + } + TrieResult::NoValue => { + // Prefix string is matched + not_match = true; + } + } + } + + if intermediate_length > 0 { + Some(intermediate_length) + } else if not_match { + // no match by scanning text + Some(self.len) + } else { + None + } + } +} + +impl<'l, 's> DictionaryType<'l, 's> for u32 { + type IterAttr = Utf16Indices<'s>; + type CharType = u32; + + fn to_char(c: u32) -> char { + 
char::from_u32(c).unwrap_or(char::REPLACEMENT_CHARACTER) + } + + fn char_len(c: u32) -> usize { + if c >= 0x10000 { + 2 + } else { + 1 + } + } +} + +impl<'l, 's> DictionaryType<'l, 's> for char { + type IterAttr = CharIndices<'s>; + type CharType = char; + + fn to_char(c: char) -> char { + c + } + + fn char_len(c: char) -> usize { + c.len_utf8() + } +} + +pub(super) struct DictionarySegmenter<'l> { + dict: &'l UCharDictionaryBreakDataV1<'l>, + grapheme: &'l RuleBreakDataV1<'l>, +} + +impl<'l> DictionarySegmenter<'l> { + pub(super) fn new( + dict: &'l UCharDictionaryBreakDataV1<'l>, + grapheme: &'l RuleBreakDataV1<'l>, + ) -> Self { + // TODO: no way to verify trie data + Self { dict, grapheme } + } + + /// Create a dictionary based break iterator for an `str` (a UTF-8 string). + pub(super) fn segment_str(&'l self, input: &'l str) -> impl Iterator<Item = usize> + 'l { + let grapheme_iter = GraphemeClusterSegmenter::new_and_segment_str(input, self.grapheme); + DictionaryBreakIterator::<char, GraphemeClusterBreakIteratorUtf8> { + trie: Char16Trie::new(self.dict.trie_data.clone()), + iter: input.char_indices(), + len: input.len(), + grapheme_iter, + } + } + + /// Create a dictionary based break iterator for a UTF-16 string. 
+ pub(super) fn segment_utf16(&'l self, input: &'l [u16]) -> impl Iterator<Item = usize> + 'l { + let grapheme_iter = GraphemeClusterSegmenter::new_and_segment_utf16(input, self.grapheme); + DictionaryBreakIterator::<u32, GraphemeClusterBreakIteratorUtf16> { + trie: Char16Trie::new(self.dict.trie_data.clone()), + iter: Utf16Indices::new(input), + len: input.len(), + grapheme_iter, + } + } +} + +#[cfg(test)] +#[cfg(feature = "serde")] +mod tests { + use super::*; + use crate::{LineSegmenter, WordSegmenter}; + use icu_provider::prelude::*; + + #[test] + fn burmese_dictionary_test() { + let segmenter = LineSegmenter::new_dictionary(); + // From css/css-text/word-break/word-break-normal-my-000.html + let s = "မြန်မာစာမြန်မာစာမြန်မာစာ"; + let result: Vec<usize> = segmenter.segment_str(s).collect(); + assert_eq!(result, vec![0, 18, 24, 42, 48, 66, 72]); + + let s_utf16: Vec<u16> = s.encode_utf16().collect(); + let result: Vec<usize> = segmenter.segment_utf16(&s_utf16).collect(); + assert_eq!(result, vec![0, 6, 8, 14, 16, 22, 24]); + } + + #[test] + fn cj_dictionary_test() { + let dict_payload: DataPayload<DictionaryForWordOnlyAutoV1Marker> = crate::provider::Baked + .load(DataRequest { + locale: &icu_locid::locale!("ja").into(), + metadata: Default::default(), + }) + .unwrap() + .take_payload() + .unwrap(); + let word_segmenter = WordSegmenter::new_dictionary(); + let dict_segmenter = DictionarySegmenter::new( + dict_payload.get(), + crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1, + ); + + // Match case + let s = "龟山岛龟山岛"; + let result: Vec<usize> = dict_segmenter.segment_str(s).collect(); + assert_eq!(result, vec![9, 18]); + + let result: Vec<usize> = word_segmenter.segment_str(s).collect(); + assert_eq!(result, vec![0, 9, 18]); + + let s_utf16: Vec<u16> = s.encode_utf16().collect(); + let result: Vec<usize> = dict_segmenter.segment_utf16(&s_utf16).collect(); + assert_eq!(result, vec![3, 6]); + + let result: Vec<usize> = 
word_segmenter.segment_utf16(&s_utf16).collect(); + assert_eq!(result, vec![0, 3, 6]); + + // Match case, then no match case + let s = "エディターエディ"; + let result: Vec<usize> = dict_segmenter.segment_str(s).collect(); + assert_eq!(result, vec![15, 24]); + + // TODO(#3236): Why is WordSegmenter not returning the middle segment? + let result: Vec<usize> = word_segmenter.segment_str(s).collect(); + assert_eq!(result, vec![0, 24]); + + let s_utf16: Vec<u16> = s.encode_utf16().collect(); + let result: Vec<usize> = dict_segmenter.segment_utf16(&s_utf16).collect(); + assert_eq!(result, vec![5, 8]); + + // TODO(#3236): Why is WordSegmenter not returning the middle segment? + let result: Vec<usize> = word_segmenter.segment_utf16(&s_utf16).collect(); + assert_eq!(result, vec![0, 8]); + } + + #[test] + fn khmer_dictionary_test() { + let segmenter = LineSegmenter::new_dictionary(); + let s = "ភាសាខ្មែរភាសាខ្មែរភាសាខ្មែរ"; + let result: Vec<usize> = segmenter.segment_str(s).collect(); + assert_eq!(result, vec![0, 27, 54, 81]); + + let s_utf16: Vec<u16> = s.encode_utf16().collect(); + let result: Vec<usize> = segmenter.segment_utf16(&s_utf16).collect(); + assert_eq!(result, vec![0, 9, 18, 27]); + } + + #[test] + fn lao_dictionary_test() { + let segmenter = LineSegmenter::new_dictionary(); + let s = "ພາສາລາວພາສາລາວພາສາລາວ"; + let r: Vec<usize> = segmenter.segment_str(s).collect(); + assert_eq!(r, vec![0, 12, 21, 33, 42, 54, 63]); + + let s_utf16: Vec<u16> = s.encode_utf16().collect(); + let r: Vec<usize> = segmenter.segment_utf16(&s_utf16).collect(); + assert_eq!(r, vec![0, 4, 7, 11, 14, 18, 21]); + } +} diff --git a/third_party/rust/icu_segmenter/src/complex/language.rs b/third_party/rust/icu_segmenter/src/complex/language.rs new file mode 100644 index 0000000000..327eea5e20 --- /dev/null +++ b/third_party/rust/icu_segmenter/src/complex/language.rs @@ -0,0 +1,161 @@ +// This file is part of ICU4X. 
For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +#[derive(PartialEq, Debug, Copy, Clone)] +pub(super) enum Language { + Burmese, + ChineseOrJapanese, + Khmer, + Lao, + Thai, + Unknown, +} + +// TODO: Use data provider +fn get_language(codepoint: u32) -> Language { + match codepoint { + 0xe01..=0xe7f => Language::Thai, + 0x0E80..=0x0EFF => Language::Lao, + 0x1000..=0x109f => Language::Burmese, + 0x1780..=0x17FF => Language::Khmer, + 0x19E0..=0x19FF => Language::Khmer, + 0x2E80..=0x2EFF => Language::ChineseOrJapanese, + 0x2F00..=0x2FDF => Language::ChineseOrJapanese, + 0x3040..=0x30FF => Language::ChineseOrJapanese, + 0x31F0..=0x31FF => Language::ChineseOrJapanese, + 0x32D0..=0x32FE => Language::ChineseOrJapanese, + 0x3400..=0x4DBF => Language::ChineseOrJapanese, + 0x4E00..=0x9FFF => Language::ChineseOrJapanese, + 0xa9e0..=0xa9ff => Language::Burmese, + 0xaa60..=0xaa7f => Language::Burmese, + 0xF900..=0xFAFF => Language::ChineseOrJapanese, + 0xFF66..=0xFF9D => Language::ChineseOrJapanese, + 0x16FE2..=0x16FE3 => Language::ChineseOrJapanese, + 0x16FF0..=0x16FF1 => Language::ChineseOrJapanese, + 0x1AFF0..=0x1B16F => Language::ChineseOrJapanese, + 0x1F200 => Language::ChineseOrJapanese, + 0x20000..=0x2FA1F => Language::ChineseOrJapanese, + 0x30000..=0x3134F => Language::ChineseOrJapanese, + _ => Language::Unknown, + } +} + +/// This struct is an iterator that returns the string per language from the +/// given string. 
+pub(super) struct LanguageIterator<'s> { + rest: &'s str, +} + +impl<'s> LanguageIterator<'s> { + pub(super) fn new(input: &'s str) -> Self { + Self { rest: input } + } +} + +impl<'s> Iterator for LanguageIterator<'s> { + type Item = (&'s str, Language); + + fn next(&mut self) -> Option<Self::Item> { + let mut indices = self.rest.char_indices(); + let lang = get_language(indices.next()?.1 as u32); + match indices.find(|&(_, ch)| get_language(ch as u32) != lang) { + Some((i, _)) => { + let (result, rest) = self.rest.split_at(i); + self.rest = rest; + Some((result, lang)) + } + None => Some((core::mem::take(&mut self.rest), lang)), + } + } +} + +pub(super) struct LanguageIteratorUtf16<'s> { + rest: &'s [u16], +} + +impl<'s> LanguageIteratorUtf16<'s> { + pub(super) fn new(input: &'s [u16]) -> Self { + Self { rest: input } + } +} + +impl<'s> Iterator for LanguageIteratorUtf16<'s> { + type Item = (&'s [u16], Language); + + fn next(&mut self) -> Option<Self::Item> { + let lang = get_language(*self.rest.first()? 
as u32); + match self + .rest + .iter() + .position(|&ch| get_language(ch as u32) != lang) + { + Some(i) => { + let (result, rest) = self.rest.split_at(i); + self.rest = rest; + Some((result, lang)) + } + None => Some((core::mem::take(&mut self.rest), lang)), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_thai_only() { + let s = "ภาษาไทยภาษาไทย"; + let utf16: Vec<u16> = s.encode_utf16().collect(); + let mut iter = LanguageIteratorUtf16::new(&utf16); + assert_eq!( + iter.next(), + Some((utf16.as_slice(), Language::Thai)), + "Thai language only with UTF-16" + ); + let mut iter = LanguageIterator::new(s); + assert_eq!( + iter.next(), + Some((s, Language::Thai)), + "Thai language only with UTF-8" + ); + assert_eq!(iter.next(), None, "Iterator for UTF-8 is finished"); + } + + #[test] + fn test_combine() { + const TEST_STR_THAI: &str = "ภาษาไทยภาษาไทย"; + const TEST_STR_BURMESE: &str = "ဗမာနွယ်ဘာသာစကားမျာ"; + let s = format!("{TEST_STR_THAI}{TEST_STR_BURMESE}"); + let utf16: Vec<u16> = s.encode_utf16().collect(); + let thai_utf16: Vec<u16> = TEST_STR_THAI.encode_utf16().collect(); + let burmese_utf16: Vec<u16> = TEST_STR_BURMESE.encode_utf16().collect(); + + let mut iter = LanguageIteratorUtf16::new(&utf16); + assert_eq!( + iter.next(), + Some((thai_utf16.as_slice(), Language::Thai)), + "Thai language with UTF-16 at first" + ); + assert_eq!( + iter.next(), + Some((burmese_utf16.as_slice(), Language::Burmese)), + "Burmese language with UTF-16 at second" + ); + assert_eq!(iter.next(), None, "Iterator for UTF-16 is finished"); + + let mut iter = LanguageIterator::new(&s); + assert_eq!( + iter.next(), + Some((TEST_STR_THAI, Language::Thai)), + "Thai language with UTF-8 at first" + ); + assert_eq!( + iter.next(), + Some((TEST_STR_BURMESE, Language::Burmese)), + "Burmese language with UTF-8 at second" + ); + assert_eq!(iter.next(), None, "Iterator for UTF-8 is finished"); + } +} diff --git 
a/third_party/rust/icu_segmenter/src/complex/lstm/matrix.rs b/third_party/rust/icu_segmenter/src/complex/lstm/matrix.rs new file mode 100644 index 0000000000..3cf5ce2e3c --- /dev/null +++ b/third_party/rust/icu_segmenter/src/complex/lstm/matrix.rs @@ -0,0 +1,540 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +use alloc::vec; +use alloc::vec::Vec; +use core::ops::Range; +#[allow(unused_imports)] +use core_maths::*; +use zerovec::ule::AsULE; +use zerovec::ZeroSlice; + +/// A `D`-dimensional, heap-allocated matrix. +/// +/// This matrix implementation supports slicing matrices into tightly-packed +/// submatrices. For example, indexing into a matrix of size 5x4x3 returns a +/// matrix of size 4x3. For more information, see [`MatrixOwned::submatrix`]. +#[derive(Debug, Clone)] +pub(super) struct MatrixOwned<const D: usize> { + data: Vec<f32>, + dims: [usize; D], +} + +impl<const D: usize> MatrixOwned<D> { + pub(super) fn as_borrowed(&self) -> MatrixBorrowed<D> { + MatrixBorrowed { + data: &self.data, + dims: self.dims, + } + } + + pub(super) fn new_zero(dims: [usize; D]) -> Self { + let total_len = dims.iter().product::<usize>(); + MatrixOwned { + data: vec![0.0; total_len], + dims, + } + } + + /// Returns the tighly packed submatrix at _index_, or `None` if _index_ is out of range. + /// + /// For example, if the matrix is 5x4x3, this function returns a matrix sized 4x3. If the + /// matrix is 4x3, then this function returns a linear matrix of length 3. + /// + /// The type parameter `M` should be `D - 1`. + #[inline] + pub(super) fn submatrix<const M: usize>(&self, index: usize) -> Option<MatrixBorrowed<M>> { + // This assertion is based on const generics; it should always succeed and be elided. 
+ assert_eq!(M, D - 1); + let (range, dims) = self.as_borrowed().submatrix_range(index); + let data = &self.data.get(range)?; + Some(MatrixBorrowed { data, dims }) + } + + pub(super) fn as_mut(&mut self) -> MatrixBorrowedMut<D> { + MatrixBorrowedMut { + data: &mut self.data, + dims: self.dims, + } + } + + /// A mutable version of [`Self::submatrix`]. + #[inline] + pub(super) fn submatrix_mut<const M: usize>( + &mut self, + index: usize, + ) -> Option<MatrixBorrowedMut<M>> { + // This assertion is based on const generics; it should always succeed and be elided. + assert_eq!(M, D - 1); + let (range, dims) = self.as_borrowed().submatrix_range(index); + let data = self.data.get_mut(range)?; + Some(MatrixBorrowedMut { data, dims }) + } +} + +/// A `D`-dimensional, borrowed matrix. +#[derive(Debug, Clone, Copy)] +pub(super) struct MatrixBorrowed<'a, const D: usize> { + data: &'a [f32], + dims: [usize; D], +} + +impl<'a, const D: usize> MatrixBorrowed<'a, D> { + #[cfg(debug_assertions)] + pub(super) fn debug_assert_dims(&self, dims: [usize; D]) { + debug_assert_eq!(dims, self.dims); + let expected_len = dims.iter().product::<usize>(); + debug_assert_eq!(expected_len, self.data.len()); + } + + pub(super) fn as_slice(&self) -> &'a [f32] { + self.data + } + + /// See [`MatrixOwned::submatrix`]. + #[inline] + pub(super) fn submatrix<const M: usize>(&self, index: usize) -> Option<MatrixBorrowed<'a, M>> { + // This assertion is based on const generics; it should always succeed and be elided. + assert_eq!(M, D - 1); + let (range, dims) = self.submatrix_range(index); + let data = &self.data.get(range)?; + Some(MatrixBorrowed { data, dims }) + } + + #[inline] + fn submatrix_range<const M: usize>(&self, index: usize) -> (Range<usize>, [usize; M]) { + // This assertion is based on const generics; it should always succeed and be elided. 
+ assert_eq!(M, D - 1); + // The above assertion guarantees that the following line will succeed + #[allow(clippy::indexing_slicing, clippy::unwrap_used)] + let sub_dims: [usize; M] = self.dims[1..].try_into().unwrap(); + let n = sub_dims.iter().product::<usize>(); + (n * index..n * (index + 1), sub_dims) + } +} + +macro_rules! impl_basic_dim { + ($t1:path, $t2:path, $t3:path) => { + impl<'a> $t1 { + #[allow(dead_code)] + pub(super) fn dim(&self) -> usize { + let [dim] = self.dims; + dim + } + } + impl<'a> $t2 { + #[allow(dead_code)] + pub(super) fn dim(&self) -> (usize, usize) { + let [d0, d1] = self.dims; + (d0, d1) + } + } + impl<'a> $t3 { + #[allow(dead_code)] + pub(super) fn dim(&self) -> (usize, usize, usize) { + let [d0, d1, d2] = self.dims; + (d0, d1, d2) + } + } + }; +} + +impl_basic_dim!(MatrixOwned<1>, MatrixOwned<2>, MatrixOwned<3>); +impl_basic_dim!( + MatrixBorrowed<'a, 1>, + MatrixBorrowed<'a, 2>, + MatrixBorrowed<'a, 3> +); +impl_basic_dim!( + MatrixBorrowedMut<'a, 1>, + MatrixBorrowedMut<'a, 2>, + MatrixBorrowedMut<'a, 3> +); +impl_basic_dim!(MatrixZero<'a, 1>, MatrixZero<'a, 2>, MatrixZero<'a, 3>); + +/// A `D`-dimensional, mutably borrowed matrix. 
+pub(super) struct MatrixBorrowedMut<'a, const D: usize> { + pub(super) data: &'a mut [f32], + pub(super) dims: [usize; D], +} + +impl<'a, const D: usize> MatrixBorrowedMut<'a, D> { + pub(super) fn as_borrowed(&self) -> MatrixBorrowed<D> { + MatrixBorrowed { + data: self.data, + dims: self.dims, + } + } + + pub(super) fn as_mut_slice(&mut self) -> &mut [f32] { + self.data + } + + pub(super) fn copy_submatrix<const M: usize>(&mut self, from: usize, to: usize) { + let (range_from, _) = self.as_borrowed().submatrix_range::<M>(from); + let (range_to, _) = self.as_borrowed().submatrix_range::<M>(to); + if let (Some(_), Some(_)) = ( + self.data.get(range_from.clone()), + self.data.get(range_to.clone()), + ) { + // This function is panicky, but we just validated the ranges + self.data.copy_within(range_from, range_to.start); + } + } + + #[must_use] + pub(super) fn add(&mut self, other: MatrixZero<'_, D>) -> Option<()> { + debug_assert_eq!(self.dims, other.dims); + // TODO: Vectorize? + for i in 0..self.data.len() { + *self.data.get_mut(i)? += other.data.get(i)?; + } + Some(()) + } + + #[allow(dead_code)] // maybe needed for more complicated bies calculations + /// Mutates this matrix by applying a softmax transformation. 
+ pub(super) fn softmax_transform(&mut self) { + for v in self.data.iter_mut() { + *v = v.exp(); + } + let sm = 1.0 / self.data.iter().sum::<f32>(); + for v in self.data.iter_mut() { + *v *= sm; + } + } + + pub(super) fn sigmoid_transform(&mut self) { + for x in &mut self.data.iter_mut() { + *x = 1.0 / (1.0 + (-*x).exp()); + } + } + + pub(super) fn tanh_transform(&mut self) { + for x in &mut self.data.iter_mut() { + *x = x.tanh(); + } + } + + pub(super) fn convolve( + &mut self, + i: MatrixBorrowed<'_, D>, + c: MatrixBorrowed<'_, D>, + f: MatrixBorrowed<'_, D>, + ) { + let i = i.as_slice(); + let c = c.as_slice(); + let f = f.as_slice(); + let len = self.data.len(); + if len != i.len() || len != c.len() || len != f.len() { + debug_assert!(false, "LSTM matrices not the correct dimensions"); + return; + } + for idx in 0..len { + // Safety: The lengths are all the same (checked above) + unsafe { + *self.data.get_unchecked_mut(idx) = i.get_unchecked(idx) * c.get_unchecked(idx) + + self.data.get_unchecked(idx) * f.get_unchecked(idx) + } + } + } + + pub(super) fn mul_tanh(&mut self, o: MatrixBorrowed<'_, D>, c: MatrixBorrowed<'_, D>) { + let o = o.as_slice(); + let c = c.as_slice(); + let len = self.data.len(); + if len != o.len() || len != c.len() { + debug_assert!(false, "LSTM matrices not the correct dimensions"); + return; + } + for idx in 0..len { + // Safety: The lengths are all the same (checked above) + unsafe { + *self.data.get_unchecked_mut(idx) = + o.get_unchecked(idx) * c.get_unchecked(idx).tanh(); + } + } + } +} + +impl<'a> MatrixBorrowed<'a, 1> { + #[allow(dead_code)] // could be useful + pub(super) fn dot_1d(&self, other: MatrixZero<1>) -> f32 { + debug_assert_eq!(self.dims, other.dims); + unrolled_dot_1(self.data, other.data) + } +} + +impl<'a> MatrixBorrowedMut<'a, 1> { + /// Calculate the dot product of a and b, adding the result to self. 
+ /// + /// Note: For better dot product efficiency, if `b` is MxN, then `a` should be N; + /// this is the opposite of standard practice. + pub(super) fn add_dot_2d(&mut self, a: MatrixBorrowed<1>, b: MatrixZero<2>) { + let m = a.dim(); + let n = self.as_borrowed().dim(); + debug_assert_eq!( + m, + b.dim().1, + "dims: {:?}/{:?}/{:?}", + self.as_borrowed().dim(), + a.dim(), + b.dim() + ); + debug_assert_eq!( + n, + b.dim().0, + "dims: {:?}/{:?}/{:?}", + self.as_borrowed().dim(), + a.dim(), + b.dim() + ); + for i in 0..n { + if let (Some(dest), Some(b_sub)) = (self.as_mut_slice().get_mut(i), b.submatrix::<1>(i)) + { + *dest += unrolled_dot_1(a.data, b_sub.data); + } else { + debug_assert!(false, "unreachable: dims checked above"); + } + } + } +} + +impl<'a> MatrixBorrowedMut<'a, 2> { + /// Calculate the dot product of a and b, adding the result to self. + /// + /// Self should be _MxN_; `a`, _O_; and `b`, _MxNxO_. + pub(super) fn add_dot_3d_1(&mut self, a: MatrixBorrowed<1>, b: MatrixZero<3>) { + let m = a.dim(); + let n = self.as_borrowed().dim().0 * self.as_borrowed().dim().1; + debug_assert_eq!( + m, + b.dim().2, + "dims: {:?}/{:?}/{:?}", + self.as_borrowed().dim(), + a.dim(), + b.dim() + ); + debug_assert_eq!( + n, + b.dim().0 * b.dim().1, + "dims: {:?}/{:?}/{:?}", + self.as_borrowed().dim(), + a.dim(), + b.dim() + ); + // Note: The following two loops are equivalent, but the second has more opportunity for + // vectorization since it allows the vectorization to span submatrices. + // for i in 0..b.dim().0 { + // self.submatrix_mut::<1>(i).add_dot_2d(a, b.submatrix(i)); + // } + let lhs = a.as_slice(); + for i in 0..n { + if let (Some(dest), Some(rhs)) = ( + self.as_mut_slice().get_mut(i), + b.as_slice().get_subslice(i * m..(i + 1) * m), + ) { + *dest += unrolled_dot_1(lhs, rhs); + } else { + debug_assert!(false, "unreachable: dims checked above"); + } + } + } + + /// Calculate the dot product of a and b, adding the result to self. 
+ /// + /// Self should be _MxN_; `a`, _O_; and `b`, _MxNxO_. + pub(super) fn add_dot_3d_2(&mut self, a: MatrixZero<1>, b: MatrixZero<3>) { + let m = a.dim(); + let n = self.as_borrowed().dim().0 * self.as_borrowed().dim().1; + debug_assert_eq!( + m, + b.dim().2, + "dims: {:?}/{:?}/{:?}", + self.as_borrowed().dim(), + a.dim(), + b.dim() + ); + debug_assert_eq!( + n, + b.dim().0 * b.dim().1, + "dims: {:?}/{:?}/{:?}", + self.as_borrowed().dim(), + a.dim(), + b.dim() + ); + // Note: The following two loops are equivalent, but the second has more opportunity for + // vectorization since it allows the vectorization to span submatrices. + // for i in 0..b.dim().0 { + // self.submatrix_mut::<1>(i).add_dot_2d(a, b.submatrix(i)); + // } + let lhs = a.as_slice(); + for i in 0..n { + if let (Some(dest), Some(rhs)) = ( + self.as_mut_slice().get_mut(i), + b.as_slice().get_subslice(i * m..(i + 1) * m), + ) { + *dest += unrolled_dot_2(lhs, rhs); + } else { + debug_assert!(false, "unreachable: dims checked above"); + } + } + } +} + +/// A `D`-dimensional matrix borrowed from a [`ZeroSlice`]. 
+#[derive(Debug, Clone, Copy)] +pub(super) struct MatrixZero<'a, const D: usize> { + data: &'a ZeroSlice<f32>, + dims: [usize; D], +} + +impl<'a> From<&'a crate::provider::LstmMatrix1<'a>> for MatrixZero<'a, 1> { + fn from(other: &'a crate::provider::LstmMatrix1<'a>) -> Self { + Self { + data: &other.data, + dims: other.dims.map(|x| x as usize), + } + } +} + +impl<'a> From<&'a crate::provider::LstmMatrix2<'a>> for MatrixZero<'a, 2> { + fn from(other: &'a crate::provider::LstmMatrix2<'a>) -> Self { + Self { + data: &other.data, + dims: other.dims.map(|x| x as usize), + } + } +} + +impl<'a> From<&'a crate::provider::LstmMatrix3<'a>> for MatrixZero<'a, 3> { + fn from(other: &'a crate::provider::LstmMatrix3<'a>) -> Self { + Self { + data: &other.data, + dims: other.dims.map(|x| x as usize), + } + } +} + +impl<'a, const D: usize> MatrixZero<'a, D> { + #[allow(clippy::wrong_self_convention)] // same convention as slice::to_vec + pub(super) fn to_owned(&self) -> MatrixOwned<D> { + MatrixOwned { + data: self.data.iter().collect(), + dims: self.dims, + } + } + + pub(super) fn as_slice(&self) -> &ZeroSlice<f32> { + self.data + } + + #[cfg(debug_assertions)] + pub(super) fn debug_assert_dims(&self, dims: [usize; D]) { + debug_assert_eq!(dims, self.dims); + let expected_len = dims.iter().product::<usize>(); + debug_assert_eq!(expected_len, self.data.len()); + } + + /// See [`MatrixOwned::submatrix`]. + #[inline] + pub(super) fn submatrix<const M: usize>(&self, index: usize) -> Option<MatrixZero<'a, M>> { + // This assertion is based on const generics; it should always succeed and be elided. + assert_eq!(M, D - 1); + let (range, dims) = self.submatrix_range(index); + let data = &self.data.get_subslice(range)?; + Some(MatrixZero { data, dims }) + } + + #[inline] + fn submatrix_range<const M: usize>(&self, index: usize) -> (Range<usize>, [usize; M]) { + // This assertion is based on const generics; it should always succeed and be elided. 
+ assert_eq!(M, D - 1); + // The above assertion guarantees that the following line will succeed + #[allow(clippy::indexing_slicing, clippy::unwrap_used)] + let sub_dims: [usize; M] = self.dims[1..].try_into().unwrap(); + let n = sub_dims.iter().product::<usize>(); + (n * index..n * (index + 1), sub_dims) + } +} + +macro_rules! f32c { + ($ule:expr) => { + f32::from_unaligned($ule) + }; +} + +/// Compute the dot product of an aligned and an unaligned f32 slice. +/// +/// `xs` and `ys` must be the same length +/// +/// (Based on ndarray 0.15.6) +fn unrolled_dot_1(xs: &[f32], ys: &ZeroSlice<f32>) -> f32 { + debug_assert_eq!(xs.len(), ys.len()); + // eightfold unrolled so that floating point can be vectorized + // (even with strict floating point accuracy semantics) + let mut p = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0); + let xit = xs.chunks_exact(8); + let yit = ys.as_ule_slice().chunks_exact(8); + let sum = xit + .remainder() + .iter() + .zip(yit.remainder().iter()) + .map(|(x, y)| x * f32c!(*y)) + .sum::<f32>(); + for (xx, yy) in xit.zip(yit) { + // TODO: Use array_chunks once stable to avoid the unwrap. + // <https://github.com/rust-lang/rust/issues/74985> + #[allow(clippy::unwrap_used)] + let [x0, x1, x2, x3, x4, x5, x6, x7] = *<&[f32; 8]>::try_from(xx).unwrap(); + #[allow(clippy::unwrap_used)] + let [y0, y1, y2, y3, y4, y5, y6, y7] = *<&[<f32 as AsULE>::ULE; 8]>::try_from(yy).unwrap(); + p.0 += x0 * f32c!(y0); + p.1 += x1 * f32c!(y1); + p.2 += x2 * f32c!(y2); + p.3 += x3 * f32c!(y3); + p.4 += x4 * f32c!(y4); + p.5 += x5 * f32c!(y5); + p.6 += x6 * f32c!(y6); + p.7 += x7 * f32c!(y7); + } + sum + (p.0 + p.4) + (p.1 + p.5) + (p.2 + p.6) + (p.3 + p.7) +} + +/// Compute the dot product of two unaligned f32 slices. 
+/// +/// `xs` and `ys` must be the same length +/// +/// (Based on ndarray 0.15.6) +fn unrolled_dot_2(xs: &ZeroSlice<f32>, ys: &ZeroSlice<f32>) -> f32 { + debug_assert_eq!(xs.len(), ys.len()); + // eightfold unrolled so that floating point can be vectorized + // (even with strict floating point accuracy semantics) + let mut p = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0); + let xit = xs.as_ule_slice().chunks_exact(8); + let yit = ys.as_ule_slice().chunks_exact(8); + let sum = xit + .remainder() + .iter() + .zip(yit.remainder().iter()) + .map(|(x, y)| f32c!(*x) * f32c!(*y)) + .sum::<f32>(); + for (xx, yy) in xit.zip(yit) { + // TODO: Use array_chunks once stable to avoid the unwrap. + // <https://github.com/rust-lang/rust/issues/74985> + #[allow(clippy::unwrap_used)] + let [x0, x1, x2, x3, x4, x5, x6, x7] = *<&[<f32 as AsULE>::ULE; 8]>::try_from(xx).unwrap(); + #[allow(clippy::unwrap_used)] + let [y0, y1, y2, y3, y4, y5, y6, y7] = *<&[<f32 as AsULE>::ULE; 8]>::try_from(yy).unwrap(); + p.0 += f32c!(x0) * f32c!(y0); + p.1 += f32c!(x1) * f32c!(y1); + p.2 += f32c!(x2) * f32c!(y2); + p.3 += f32c!(x3) * f32c!(y3); + p.4 += f32c!(x4) * f32c!(y4); + p.5 += f32c!(x5) * f32c!(y5); + p.6 += f32c!(x6) * f32c!(y6); + p.7 += f32c!(x7) * f32c!(y7); + } + sum + (p.0 + p.4) + (p.1 + p.5) + (p.2 + p.6) + (p.3 + p.7) +} diff --git a/third_party/rust/icu_segmenter/src/complex/lstm/mod.rs b/third_party/rust/icu_segmenter/src/complex/lstm/mod.rs new file mode 100644 index 0000000000..8718cbd3da --- /dev/null +++ b/third_party/rust/icu_segmenter/src/complex/lstm/mod.rs @@ -0,0 +1,402 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). 
+ +use crate::grapheme::GraphemeClusterSegmenter; +use crate::provider::*; +use alloc::vec::Vec; +use core::char::{decode_utf16, REPLACEMENT_CHARACTER}; +use zerovec::{maps::ZeroMapBorrowed, ule::UnvalidatedStr}; + +mod matrix; +use matrix::*; + +// A word break iterator using LSTM model. Input string have to be same language. + +struct LstmSegmenterIterator<'s> { + input: &'s str, + pos_utf8: usize, + bies: BiesIterator<'s>, +} + +impl Iterator for LstmSegmenterIterator<'_> { + type Item = usize; + + fn next(&mut self) -> Option<Self::Item> { + #[allow(clippy::indexing_slicing)] // pos_utf8 in range + loop { + let is_e = self.bies.next()?; + self.pos_utf8 += self.input[self.pos_utf8..].chars().next()?.len_utf8(); + if is_e || self.bies.len() == 0 { + return Some(self.pos_utf8); + } + } + } +} + +struct LstmSegmenterIteratorUtf16<'s> { + bies: BiesIterator<'s>, + pos: usize, +} + +impl Iterator for LstmSegmenterIteratorUtf16<'_> { + type Item = usize; + + fn next(&mut self) -> Option<Self::Item> { + loop { + self.pos += 1; + if self.bies.next()? 
|| self.bies.len() == 0 { + return Some(self.pos); + } + } + } +} + +pub(super) struct LstmSegmenter<'l> { + dic: ZeroMapBorrowed<'l, UnvalidatedStr, u16>, + embedding: MatrixZero<'l, 2>, + fw_w: MatrixZero<'l, 3>, + fw_u: MatrixZero<'l, 3>, + fw_b: MatrixZero<'l, 2>, + bw_w: MatrixZero<'l, 3>, + bw_u: MatrixZero<'l, 3>, + bw_b: MatrixZero<'l, 2>, + timew_fw: MatrixZero<'l, 2>, + timew_bw: MatrixZero<'l, 2>, + time_b: MatrixZero<'l, 1>, + grapheme: Option<&'l RuleBreakDataV1<'l>>, +} + +impl<'l> LstmSegmenter<'l> { + /// Returns `Err` if grapheme data is required but not present + pub(super) fn new(lstm: &'l LstmDataV1<'l>, grapheme: &'l RuleBreakDataV1<'l>) -> Self { + let LstmDataV1::Float32(lstm) = lstm; + let time_w = MatrixZero::from(&lstm.time_w); + #[allow(clippy::unwrap_used)] // shape (2, 4, hunits) + let timew_fw = time_w.submatrix(0).unwrap(); + #[allow(clippy::unwrap_used)] // shape (2, 4, hunits) + let timew_bw = time_w.submatrix(1).unwrap(); + Self { + dic: lstm.dic.as_borrowed(), + embedding: MatrixZero::from(&lstm.embedding), + fw_w: MatrixZero::from(&lstm.fw_w), + fw_u: MatrixZero::from(&lstm.fw_u), + fw_b: MatrixZero::from(&lstm.fw_b), + bw_w: MatrixZero::from(&lstm.bw_w), + bw_u: MatrixZero::from(&lstm.bw_u), + bw_b: MatrixZero::from(&lstm.bw_b), + timew_fw, + timew_bw, + time_b: MatrixZero::from(&lstm.time_b), + grapheme: (lstm.model == ModelType::GraphemeClusters).then_some(grapheme), + } + } + + /// Create an LSTM based break iterator for an `str` (a UTF-8 string). 
+ pub(super) fn segment_str(&'l self, input: &'l str) -> impl Iterator<Item = usize> + 'l { + self.segment_str_p(input) + } + + // For unit testing as we cannot inspect the opaque type's bies + fn segment_str_p(&'l self, input: &'l str) -> LstmSegmenterIterator<'l> { + let input_seq = if let Some(grapheme) = self.grapheme { + GraphemeClusterSegmenter::new_and_segment_str(input, grapheme) + .collect::<Vec<usize>>() + .windows(2) + .map(|chunk| { + let range = if let [first, second, ..] = chunk { + *first..*second + } else { + unreachable!() + }; + let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) { + grapheme_cluster + } else { + return self.dic.len() as u16; + }; + + self.dic + .get_copied(UnvalidatedStr::from_str(grapheme_cluster)) + .unwrap_or_else(|| self.dic.len() as u16) + }) + .collect() + } else { + input + .chars() + .map(|c| { + self.dic + .get_copied(UnvalidatedStr::from_str(c.encode_utf8(&mut [0; 4]))) + .unwrap_or_else(|| self.dic.len() as u16) + }) + .collect() + }; + LstmSegmenterIterator { + input, + pos_utf8: 0, + bies: BiesIterator::new(self, input_seq), + } + } + + /// Create an LSTM based break iterator for a UTF-16 string. + pub(super) fn segment_utf16(&'l self, input: &[u16]) -> impl Iterator<Item = usize> + 'l { + let input_seq = if let Some(grapheme) = self.grapheme { + GraphemeClusterSegmenter::new_and_segment_utf16(input, grapheme) + .collect::<Vec<usize>>() + .windows(2) + .map(|chunk| { + let range = if let [first, second, ..] 
= chunk { + *first..*second + } else { + unreachable!() + }; + let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) { + grapheme_cluster + } else { + return self.dic.len() as u16; + }; + + self.dic + .get_copied_by(|key| { + key.as_bytes().iter().copied().cmp( + decode_utf16(grapheme_cluster.iter().copied()).flat_map(|c| { + let mut buf = [0; 4]; + let len = c + .unwrap_or(REPLACEMENT_CHARACTER) + .encode_utf8(&mut buf) + .len(); + buf.into_iter().take(len) + }), + ) + }) + .unwrap_or_else(|| self.dic.len() as u16) + }) + .collect() + } else { + decode_utf16(input.iter().copied()) + .map(|c| c.unwrap_or(REPLACEMENT_CHARACTER)) + .map(|c| { + self.dic + .get_copied(UnvalidatedStr::from_str(c.encode_utf8(&mut [0; 4]))) + .unwrap_or_else(|| self.dic.len() as u16) + }) + .collect() + }; + LstmSegmenterIteratorUtf16 { + bies: BiesIterator::new(self, input_seq), + pos: 0, + } + } +} + +struct BiesIterator<'l> { + segmenter: &'l LstmSegmenter<'l>, + input_seq: core::iter::Enumerate<alloc::vec::IntoIter<u16>>, + h_bw: MatrixOwned<2>, + curr_fw: MatrixOwned<1>, + c_fw: MatrixOwned<1>, +} + +impl<'l> BiesIterator<'l> { + // input_seq is a sequence of id numbers that represents grapheme clusters or code points in the input line. These ids are used later + // in the embedding layer of the model. 
+ fn new(segmenter: &'l LstmSegmenter<'l>, input_seq: Vec<u16>) -> Self { + let hunits = segmenter.fw_u.dim().1; + + // Backward LSTM + let mut c_bw = MatrixOwned::<1>::new_zero([hunits]); + let mut h_bw = MatrixOwned::<2>::new_zero([input_seq.len(), hunits]); + for (i, &g_id) in input_seq.iter().enumerate().rev() { + if i + 1 < input_seq.len() { + h_bw.as_mut().copy_submatrix::<1>(i + 1, i); + } + #[allow(clippy::unwrap_used)] + compute_hc( + segmenter.embedding.submatrix::<1>(g_id as usize).unwrap(), /* shape (dict.len() + 1, hunit), g_id is at most dict.len() */ + h_bw.submatrix_mut(i).unwrap(), // shape (input_seq.len(), hunits) + c_bw.as_mut(), + segmenter.bw_w, + segmenter.bw_u, + segmenter.bw_b, + ); + } + + Self { + input_seq: input_seq.into_iter().enumerate(), + h_bw, + c_fw: MatrixOwned::<1>::new_zero([hunits]), + curr_fw: MatrixOwned::<1>::new_zero([hunits]), + segmenter, + } + } +} + +impl ExactSizeIterator for BiesIterator<'_> { + fn len(&self) -> usize { + self.input_seq.len() + } +} + +impl Iterator for BiesIterator<'_> { + type Item = bool; + + fn next(&mut self) -> Option<Self::Item> { + let (i, g_id) = self.input_seq.next()?; + + #[allow(clippy::unwrap_used)] + compute_hc( + self.segmenter + .embedding + .submatrix::<1>(g_id as usize) + .unwrap(), // shape (dict.len() + 1, hunit), g_id is at most dict.len() + self.curr_fw.as_mut(), + self.c_fw.as_mut(), + self.segmenter.fw_w, + self.segmenter.fw_u, + self.segmenter.fw_b, + ); + + #[allow(clippy::unwrap_used)] // shape (input_seq.len(), hunits) + let curr_bw = self.h_bw.submatrix::<1>(i).unwrap(); + let mut weights = [0.0; 4]; + let mut curr_est = MatrixBorrowedMut { + data: &mut weights, + dims: [4], + }; + curr_est.add_dot_2d(self.curr_fw.as_borrowed(), self.segmenter.timew_fw); + curr_est.add_dot_2d(curr_bw, self.segmenter.timew_bw); + #[allow(clippy::unwrap_used)] // both shape (4) + curr_est.add(self.segmenter.time_b).unwrap(); + // For correct BIES weight calculation we'd now have to apply 
softmax, however
+        // we're only doing a naive argmax, so a monotonic function doesn't make a difference.
+
+        Some(weights[2] > weights[0] && weights[2] > weights[1] && weights[2] > weights[3])
+    }
+}
+
+/// `compute_hc` implements the evaluation of one LSTM layer.
+fn compute_hc<'a>(
+    x_t: MatrixZero<'a, 1>,
+    mut h_tm1: MatrixBorrowedMut<'a, 1>,
+    mut c_tm1: MatrixBorrowedMut<'a, 1>,
+    w: MatrixZero<'a, 3>,
+    u: MatrixZero<'a, 3>,
+    b: MatrixZero<'a, 2>,
+) {
+    #[cfg(debug_assertions)]
+    {
+        let hunits = h_tm1.dim();
+        let embedd_dim = x_t.dim();
+        c_tm1.as_borrowed().debug_assert_dims([hunits]);
+        w.debug_assert_dims([4, hunits, embedd_dim]);
+        u.debug_assert_dims([4, hunits, hunits]);
+        b.debug_assert_dims([4, hunits]);
+    }
+
+    let mut s_t = b.to_owned();
+
+    s_t.as_mut().add_dot_3d_2(x_t, w);
+    s_t.as_mut().add_dot_3d_1(h_tm1.as_borrowed(), u);
+
+    #[allow(clippy::unwrap_used)] // first dimension is 4
+    s_t.submatrix_mut::<1>(0).unwrap().sigmoid_transform();
+    #[allow(clippy::unwrap_used)] // first dimension is 4
+    s_t.submatrix_mut::<1>(1).unwrap().sigmoid_transform();
+    #[allow(clippy::unwrap_used)] // first dimension is 4
+    s_t.submatrix_mut::<1>(2).unwrap().tanh_transform();
+    #[allow(clippy::unwrap_used)] // first dimension is 4
+    s_t.submatrix_mut::<1>(3).unwrap().sigmoid_transform();
+
+    #[allow(clippy::unwrap_used)] // first dimension is 4
+    c_tm1.convolve(
+        s_t.as_borrowed().submatrix(0).unwrap(),
+        s_t.as_borrowed().submatrix(2).unwrap(),
+        s_t.as_borrowed().submatrix(1).unwrap(),
+    );
+
+    #[allow(clippy::unwrap_used)] // first dimension is 4
+    h_tm1.mul_tanh(s_t.as_borrowed().submatrix(3).unwrap(), c_tm1.as_borrowed());
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use icu_locid::locale;
+    use icu_provider::prelude::*;
+    use serde::Deserialize;
+    use std::fs::File;
+    use std::io::BufReader;
+
+    /// `TestCase` is a struct used to store a single test case.
+    /// Each test case has two attributes: `unseg` which denotes the unsegmented line, and `true_bies` which indicates the Bies
+    /// sequence representing the true segmentation.
+    #[derive(PartialEq, Debug, Deserialize)]
+    struct TestCase {
+        unseg: String,
+        expected_bies: String,
+        true_bies: String,
+    }
+
+    /// `TestTextData` is a struct to store a vector of `TestCase` that represents a test text.
+    #[derive(PartialEq, Debug, Deserialize)]
+    struct TestTextData {
+        testcases: Vec<TestCase>,
+    }
+
+    #[derive(Debug)]
+    struct TestText {
+        data: TestTextData,
+    }
+
+    fn load_test_text(filename: &str) -> TestTextData {
+        let file = File::open(filename).expect("File should be present");
+        let reader = BufReader::new(file);
+        serde_json::from_reader(reader).expect("JSON syntax error")
+    }
+
+    #[test]
+    fn segment_file_by_lstm() {
+        let lstm: DataPayload<LstmForWordLineAutoV1Marker> = crate::provider::Baked
+            .load(DataRequest {
+                locale: &locale!("th").into(),
+                metadata: Default::default(),
+            })
+            .unwrap()
+            .take_payload()
+            .unwrap();
+        let lstm = LstmSegmenter::new(
+            lstm.get(),
+            crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
+        );
+
+        // Importing the test data
+        let test_text_data = load_test_text(&format!(
+            "tests/testdata/test_text_{}.json",
+            if lstm.grapheme.is_some() {
+                "grapheme"
+            } else {
+                "codepoints"
+            }
+        ));
+        let test_text = TestText {
+            data: test_text_data,
+        };
+
+        // Testing
+        for test_case in &test_text.data.testcases {
+            let lstm_output = lstm
+                .segment_str_p(&test_case.unseg)
+                .bies
+                .map(|is_e| if is_e { 'e' } else { '?'
}) + .collect::<String>(); + println!("Test case : {}", test_case.unseg); + println!("Expected bies : {}", test_case.expected_bies); + println!("Estimated bies : {lstm_output}"); + println!("True bies : {}", test_case.true_bies); + println!("****************************************************"); + assert_eq!( + test_case.expected_bies.replace(['b', 'i', 's'], "?"), + lstm_output + ); + } + } +} diff --git a/third_party/rust/icu_segmenter/src/complex/mod.rs b/third_party/rust/icu_segmenter/src/complex/mod.rs new file mode 100644 index 0000000000..65f49a92f0 --- /dev/null +++ b/third_party/rust/icu_segmenter/src/complex/mod.rs @@ -0,0 +1,440 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +use crate::provider::*; +use alloc::vec::Vec; +use icu_locid::{locale, Locale}; +use icu_provider::prelude::*; + +mod dictionary; +use dictionary::*; +mod language; +use language::*; +#[cfg(feature = "lstm")] +mod lstm; +#[cfg(feature = "lstm")] +use lstm::*; + +#[cfg(not(feature = "lstm"))] +type DictOrLstm = Result<DataPayload<UCharDictionaryBreakDataV1Marker>, core::convert::Infallible>; +#[cfg(not(feature = "lstm"))] +type DictOrLstmBorrowed<'a> = + Result<&'a DataPayload<UCharDictionaryBreakDataV1Marker>, &'a core::convert::Infallible>; + +#[cfg(feature = "lstm")] +type DictOrLstm = + Result<DataPayload<UCharDictionaryBreakDataV1Marker>, DataPayload<LstmDataV1Marker>>; +#[cfg(feature = "lstm")] +type DictOrLstmBorrowed<'a> = + Result<&'a DataPayload<UCharDictionaryBreakDataV1Marker>, &'a DataPayload<LstmDataV1Marker>>; + +#[derive(Debug)] +pub(crate) struct ComplexPayloads { + grapheme: DataPayload<GraphemeClusterBreakDataV1Marker>, + my: Option<DictOrLstm>, + km: Option<DictOrLstm>, + lo: Option<DictOrLstm>, + th: Option<DictOrLstm>, + ja: Option<DataPayload<UCharDictionaryBreakDataV1Marker>>, +} + +impl 
ComplexPayloads { + fn select(&self, language: Language) -> Option<DictOrLstmBorrowed> { + const ERR: DataError = DataError::custom("No segmentation model for language"); + match language { + Language::Burmese => self.my.as_ref().map(Result::as_ref).or_else(|| { + ERR.with_display_context("my"); + None + }), + Language::Khmer => self.km.as_ref().map(Result::as_ref).or_else(|| { + ERR.with_display_context("km"); + None + }), + Language::Lao => self.lo.as_ref().map(Result::as_ref).or_else(|| { + ERR.with_display_context("lo"); + None + }), + Language::Thai => self.th.as_ref().map(Result::as_ref).or_else(|| { + ERR.with_display_context("th"); + None + }), + Language::ChineseOrJapanese => self.ja.as_ref().map(Ok).or_else(|| { + ERR.with_display_context("ja"); + None + }), + Language::Unknown => None, + } + } + + #[cfg(feature = "lstm")] + #[cfg(feature = "compiled_data")] + pub(crate) fn new_lstm() -> Self { + #[allow(clippy::unwrap_used)] + // try_load is infallible if the provider only returns `MissingLocale`. 
+ Self { + grapheme: DataPayload::from_static_ref( + crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1, + ), + my: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("my")) + .unwrap() + .map(DataPayload::cast) + .map(Err), + km: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("km")) + .unwrap() + .map(DataPayload::cast) + .map(Err), + lo: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("lo")) + .unwrap() + .map(DataPayload::cast) + .map(Err), + th: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("th")) + .unwrap() + .map(DataPayload::cast) + .map(Err), + ja: None, + } + } + + #[cfg(feature = "lstm")] + pub(crate) fn try_new_lstm<D>(provider: &D) -> Result<Self, DataError> + where + D: DataProvider<GraphemeClusterBreakDataV1Marker> + + DataProvider<LstmForWordLineAutoV1Marker> + + ?Sized, + { + Ok(Self { + grapheme: provider.load(Default::default())?.take_payload()?, + my: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("my"))? + .map(DataPayload::cast) + .map(Err), + km: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("km"))? + .map(DataPayload::cast) + .map(Err), + lo: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("lo"))? + .map(DataPayload::cast) + .map(Err), + th: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("th"))? + .map(DataPayload::cast) + .map(Err), + ja: None, + }) + } + + #[cfg(feature = "compiled_data")] + pub(crate) fn new_dict() -> Self { + #[allow(clippy::unwrap_used)] + // try_load is infallible if the provider only returns `MissingLocale`. 
+ Self { + grapheme: DataPayload::from_static_ref( + crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1, + ), + my: try_load::<DictionaryForWordLineExtendedV1Marker, _>( + &crate::provider::Baked, + locale!("my"), + ) + .unwrap() + .map(DataPayload::cast) + .map(Ok), + km: try_load::<DictionaryForWordLineExtendedV1Marker, _>( + &crate::provider::Baked, + locale!("km"), + ) + .unwrap() + .map(DataPayload::cast) + .map(Ok), + lo: try_load::<DictionaryForWordLineExtendedV1Marker, _>( + &crate::provider::Baked, + locale!("lo"), + ) + .unwrap() + .map(DataPayload::cast) + .map(Ok), + th: try_load::<DictionaryForWordLineExtendedV1Marker, _>( + &crate::provider::Baked, + locale!("th"), + ) + .unwrap() + .map(DataPayload::cast) + .map(Ok), + ja: try_load::<DictionaryForWordOnlyAutoV1Marker, _>( + &crate::provider::Baked, + locale!("ja"), + ) + .unwrap() + .map(DataPayload::cast), + } + } + + pub(crate) fn try_new_dict<D>(provider: &D) -> Result<Self, DataError> + where + D: DataProvider<GraphemeClusterBreakDataV1Marker> + + DataProvider<DictionaryForWordLineExtendedV1Marker> + + DataProvider<DictionaryForWordOnlyAutoV1Marker> + + ?Sized, + { + Ok(Self { + grapheme: provider.load(Default::default())?.take_payload()?, + my: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, locale!("my"))? + .map(DataPayload::cast) + .map(Ok), + km: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, locale!("km"))? + .map(DataPayload::cast) + .map(Ok), + lo: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, locale!("lo"))? + .map(DataPayload::cast) + .map(Ok), + th: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, locale!("th"))? + .map(DataPayload::cast) + .map(Ok), + ja: try_load::<DictionaryForWordOnlyAutoV1Marker, D>(provider, locale!("ja"))? + .map(DataPayload::cast), + }) + } + + #[cfg(feature = "auto")] // Use by WordSegmenter with "auto" enabled. 
+ #[cfg(feature = "compiled_data")] + pub(crate) fn new_auto() -> Self { + #[allow(clippy::unwrap_used)] + // try_load is infallible if the provider only returns `MissingLocale`. + Self { + grapheme: DataPayload::from_static_ref( + crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1, + ), + my: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("my")) + .unwrap() + .map(DataPayload::cast) + .map(Err), + km: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("km")) + .unwrap() + .map(DataPayload::cast) + .map(Err), + lo: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("lo")) + .unwrap() + .map(DataPayload::cast) + .map(Err), + th: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("th")) + .unwrap() + .map(DataPayload::cast) + .map(Err), + ja: try_load::<DictionaryForWordOnlyAutoV1Marker, _>( + &crate::provider::Baked, + locale!("ja"), + ) + .unwrap() + .map(DataPayload::cast), + } + } + + #[cfg(feature = "auto")] // Use by WordSegmenter with "auto" enabled. + pub(crate) fn try_new_auto<D>(provider: &D) -> Result<Self, DataError> + where + D: DataProvider<GraphemeClusterBreakDataV1Marker> + + DataProvider<LstmForWordLineAutoV1Marker> + + DataProvider<DictionaryForWordOnlyAutoV1Marker> + + ?Sized, + { + Ok(Self { + grapheme: provider.load(Default::default())?.take_payload()?, + my: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("my"))? + .map(DataPayload::cast) + .map(Err), + km: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("km"))? + .map(DataPayload::cast) + .map(Err), + lo: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("lo"))? + .map(DataPayload::cast) + .map(Err), + th: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("th"))? + .map(DataPayload::cast) + .map(Err), + ja: try_load::<DictionaryForWordOnlyAutoV1Marker, D>(provider, locale!("ja"))? 
+ .map(DataPayload::cast), + }) + } + + #[cfg(feature = "compiled_data")] + pub(crate) fn new_southeast_asian() -> Self { + #[allow(clippy::unwrap_used)] + // try_load is infallible if the provider only returns `MissingLocale`. + Self { + grapheme: DataPayload::from_static_ref( + crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1, + ), + my: try_load::<DictionaryForWordLineExtendedV1Marker, _>( + &crate::provider::Baked, + locale!("my"), + ) + .unwrap() + .map(DataPayload::cast) + .map(Ok), + km: try_load::<DictionaryForWordLineExtendedV1Marker, _>( + &crate::provider::Baked, + locale!("km"), + ) + .unwrap() + .map(DataPayload::cast) + .map(Ok), + lo: try_load::<DictionaryForWordLineExtendedV1Marker, _>( + &crate::provider::Baked, + locale!("lo"), + ) + .unwrap() + .map(DataPayload::cast) + .map(Ok), + th: try_load::<DictionaryForWordLineExtendedV1Marker, _>( + &crate::provider::Baked, + locale!("th"), + ) + .unwrap() + .map(DataPayload::cast) + .map(Ok), + ja: None, + } + } + + pub(crate) fn try_new_southeast_asian<D>(provider: &D) -> Result<Self, DataError> + where + D: DataProvider<DictionaryForWordLineExtendedV1Marker> + + DataProvider<GraphemeClusterBreakDataV1Marker> + + ?Sized, + { + Ok(Self { + grapheme: provider.load(Default::default())?.take_payload()?, + my: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, locale!("my"))? + .map(DataPayload::cast) + .map(Ok), + km: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, locale!("km"))? + .map(DataPayload::cast) + .map(Ok), + lo: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, locale!("lo"))? + .map(DataPayload::cast) + .map(Ok), + th: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, locale!("th"))? 
+ .map(DataPayload::cast) + .map(Ok), + ja: None, + }) + } +} + +fn try_load<M: KeyedDataMarker, P: DataProvider<M> + ?Sized>( + provider: &P, + locale: Locale, +) -> Result<Option<DataPayload<M>>, DataError> { + match provider.load(DataRequest { + locale: &DataLocale::from(locale), + metadata: { + let mut m = DataRequestMetadata::default(); + m.silent = true; + m + }, + }) { + Ok(response) => Ok(Some(response.take_payload()?)), + Err(DataError { + kind: DataErrorKind::MissingLocale, + .. + }) => Ok(None), + Err(e) => Err(e), + } +} + +/// Return UTF-16 segment offset array using dictionary or lstm segmenter. +pub(crate) fn complex_language_segment_utf16( + payloads: &ComplexPayloads, + input: &[u16], +) -> Vec<usize> { + let mut result = Vec::new(); + let mut offset = 0; + for (slice, lang) in LanguageIteratorUtf16::new(input) { + match payloads.select(lang) { + Some(Ok(dict)) => { + result.extend( + DictionarySegmenter::new(dict.get(), payloads.grapheme.get()) + .segment_utf16(slice) + .map(|n| offset + n), + ); + } + #[cfg(feature = "lstm")] + Some(Err(lstm)) => { + result.extend( + LstmSegmenter::new(lstm.get(), payloads.grapheme.get()) + .segment_utf16(slice) + .map(|n| offset + n), + ); + } + #[cfg(not(feature = "lstm"))] + Some(Err(_infallible)) => {} // should be refutable + None => { + result.push(offset + slice.len()); + } + } + offset += slice.len(); + } + result +} + +/// Return UTF-8 segment offset array using dictionary or lstm segmenter. 
+pub(crate) fn complex_language_segment_str(payloads: &ComplexPayloads, input: &str) -> Vec<usize> { + let mut result = Vec::new(); + let mut offset = 0; + for (slice, lang) in LanguageIterator::new(input) { + match payloads.select(lang) { + Some(Ok(dict)) => { + result.extend( + DictionarySegmenter::new(dict.get(), payloads.grapheme.get()) + .segment_str(slice) + .map(|n| offset + n), + ); + } + #[cfg(feature = "lstm")] + Some(Err(lstm)) => { + result.extend( + LstmSegmenter::new(lstm.get(), payloads.grapheme.get()) + .segment_str(slice) + .map(|n| offset + n), + ); + } + #[cfg(not(feature = "lstm"))] + Some(Err(_infallible)) => {} // should be refutable + None => { + result.push(offset + slice.len()); + } + } + offset += slice.len(); + } + result +} + +#[cfg(test)] +#[cfg(feature = "serde")] +mod tests { + use super::*; + + #[test] + fn thai_word_break() { + const TEST_STR: &str = "ภาษาไทยภาษาไทย"; + let utf16: Vec<u16> = TEST_STR.encode_utf16().collect(); + + let lstm = ComplexPayloads::new_lstm(); + let dict = ComplexPayloads::new_dict(); + + assert_eq!( + complex_language_segment_str(&lstm, TEST_STR), + [12, 21, 33, 42] + ); + assert_eq!( + complex_language_segment_utf16(&lstm, &utf16), + [4, 7, 11, 14] + ); + + assert_eq!( + complex_language_segment_str(&dict, TEST_STR), + [12, 21, 33, 42] + ); + assert_eq!( + complex_language_segment_utf16(&dict, &utf16), + [4, 7, 11, 14] + ); + } +} |