authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 00:47:55 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 00:47:55 +0000
commit26a029d407be480d791972afb5975cf62c9360a6 (patch)
treef435a8308119effd964b339f76abb83a57c29483 /third_party/rust/icu_segmenter/src/complex
parentInitial commit. (diff)
Adding upstream version 124.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/icu_segmenter/src/complex')
-rw-r--r--  third_party/rust/icu_segmenter/src/complex/dictionary.rs   268
-rw-r--r--  third_party/rust/icu_segmenter/src/complex/language.rs     161
-rw-r--r--  third_party/rust/icu_segmenter/src/complex/lstm/matrix.rs  540
-rw-r--r--  third_party/rust/icu_segmenter/src/complex/lstm/mod.rs     402
-rw-r--r--  third_party/rust/icu_segmenter/src/complex/mod.rs          440
5 files changed, 1811 insertions, 0 deletions
diff --git a/third_party/rust/icu_segmenter/src/complex/dictionary.rs b/third_party/rust/icu_segmenter/src/complex/dictionary.rs
new file mode 100644
index 0000000000..90360ee2b0
--- /dev/null
+++ b/third_party/rust/icu_segmenter/src/complex/dictionary.rs
@@ -0,0 +1,268 @@
+// This file is part of ICU4X. For terms of use, please see the file
+// called LICENSE at the top level of the ICU4X source tree
+// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
+
+use crate::grapheme::*;
+use crate::indices::Utf16Indices;
+use crate::provider::*;
+use core::str::CharIndices;
+use icu_collections::char16trie::{Char16Trie, TrieResult};
+
+/// A trait for dictionary-based iterators.
+trait DictionaryType<'l, 's> {
+ /// The iterator over characters.
+ type IterAttr: Iterator<Item = (usize, Self::CharType)> + Clone;
+
+ /// The character type.
+ type CharType: Copy + Into<u32>;
+
+ fn to_char(c: Self::CharType) -> char;
+ fn char_len(c: Self::CharType) -> usize;
+}
+
+struct DictionaryBreakIterator<
+ 'l,
+ 's,
+ Y: DictionaryType<'l, 's> + ?Sized,
+ X: Iterator<Item = usize> + ?Sized,
+> {
+ trie: Char16Trie<'l>,
+ iter: Y::IterAttr,
+ len: usize,
+ grapheme_iter: X,
+ // TODO transform value for byte trie
+}
+
+/// Implement the [`Iterator`] trait over the segmenter break opportunities of the given string.
+/// Please see the [module-level documentation](crate) for example usage.
+///
+/// Lifetimes:
+/// - `'l` = lifetime of the segmenter object from which this iterator was created
+/// - `'s` = lifetime of the string being segmented
+///
+/// [`Iterator`]: core::iter::Iterator
+impl<'l, 's, Y: DictionaryType<'l, 's> + ?Sized, X: Iterator<Item = usize> + ?Sized> Iterator
+ for DictionaryBreakIterator<'l, 's, Y, X>
+{
+ type Item = usize;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ let mut trie_iter = self.trie.iter();
+ let mut intermediate_length = 0;
+ let mut not_match = false;
+ let mut previous_match = None;
+ let mut last_grapheme_offset = 0;
+
+ while let Some(next) = self.iter.next() {
+ let ch = Y::to_char(next.1);
+ match trie_iter.next(ch) {
+ TrieResult::FinalValue(_) => {
+ return Some(next.0 + Y::char_len(next.1));
+ }
+ TrieResult::Intermediate(_) => {
+ // The dictionary match has to align with a grapheme cluster boundary.
+ // If it does not, we ignore it.
+ while last_grapheme_offset < next.0 + Y::char_len(next.1) {
+ if let Some(offset) = self.grapheme_iter.next() {
+ last_grapheme_offset = offset;
+ continue;
+ }
+ last_grapheme_offset = self.len;
+ break;
+ }
+ if last_grapheme_offset != next.0 + Y::char_len(next.1) {
+ continue;
+ }
+
+ intermediate_length = next.0 + Y::char_len(next.1);
+ previous_match = Some(self.iter.clone());
+ }
+ TrieResult::NoMatch => {
+ if intermediate_length > 0 {
+ if let Some(previous_match) = previous_match {
+ // Rewind to the previous match point
+ self.iter = previous_match;
+ }
+ return Some(intermediate_length);
+ }
+ // Not found
+ return Some(next.0 + Y::char_len(next.1));
+ }
+ TrieResult::NoValue => {
+ // A prefix of a dictionary entry matched so far; keep scanning
+ not_match = true;
+ }
+ }
+ }
+
+ if intermediate_length > 0 {
+ Some(intermediate_length)
+ } else if not_match {
+ // No dictionary match was found while scanning the text
+ Some(self.len)
+ } else {
+ None
+ }
+ }
+}
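+
+// Illustrative walk-through of the loop above, with a hypothetical dictionary:
+// if the dictionary contains both "AB" and "ABC" and the text is "ABD", the
+// trie reports an Intermediate value after "AB" (remembered via
+// `previous_match` and `intermediate_length`), then NoMatch at "D". The
+// iterator rewinds to the remembered state and reports the break after "AB",
+// i.e. the longest complete dictionary entry that also ends on a grapheme
+// cluster boundary.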
+
+impl<'l, 's> DictionaryType<'l, 's> for u32 {
+ type IterAttr = Utf16Indices<'s>;
+ type CharType = u32;
+
+ fn to_char(c: u32) -> char {
+ char::from_u32(c).unwrap_or(char::REPLACEMENT_CHARACTER)
+ }
+
+ fn char_len(c: u32) -> usize {
+ if c >= 0x10000 {
+ 2
+ } else {
+ 1
+ }
+ }
+}
+
+impl<'l, 's> DictionaryType<'l, 's> for char {
+ type IterAttr = CharIndices<'s>;
+ type CharType = char;
+
+ fn to_char(c: char) -> char {
+ c
+ }
+
+ fn char_len(c: char) -> usize {
+ c.len_utf8()
+ }
+}
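+
+// Note on the two impls above: break offsets produced through the `char` impl
+// are UTF-8 byte indices (`len_utf8`, e.g. 3 bytes for U+0E01), while offsets
+// produced through the `u32` impl are UTF-16 code-unit indices (2 units for
+// scalar values at or above U+10000, 1 otherwise). This is why the tests below
+// expect different break positions for the same text in `segment_str` vs
+// `segment_utf16`.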
+
+pub(super) struct DictionarySegmenter<'l> {
+ dict: &'l UCharDictionaryBreakDataV1<'l>,
+ grapheme: &'l RuleBreakDataV1<'l>,
+}
+
+impl<'l> DictionarySegmenter<'l> {
+ pub(super) fn new(
+ dict: &'l UCharDictionaryBreakDataV1<'l>,
+ grapheme: &'l RuleBreakDataV1<'l>,
+ ) -> Self {
+ // TODO: no way to verify trie data
+ Self { dict, grapheme }
+ }
+
+ /// Create a dictionary based break iterator for an `str` (a UTF-8 string).
+ pub(super) fn segment_str(&'l self, input: &'l str) -> impl Iterator<Item = usize> + 'l {
+ let grapheme_iter = GraphemeClusterSegmenter::new_and_segment_str(input, self.grapheme);
+ DictionaryBreakIterator::<char, GraphemeClusterBreakIteratorUtf8> {
+ trie: Char16Trie::new(self.dict.trie_data.clone()),
+ iter: input.char_indices(),
+ len: input.len(),
+ grapheme_iter,
+ }
+ }
+
+ /// Create a dictionary based break iterator for a UTF-16 string.
+ pub(super) fn segment_utf16(&'l self, input: &'l [u16]) -> impl Iterator<Item = usize> + 'l {
+ let grapheme_iter = GraphemeClusterSegmenter::new_and_segment_utf16(input, self.grapheme);
+ DictionaryBreakIterator::<u32, GraphemeClusterBreakIteratorUtf16> {
+ trie: Char16Trie::new(self.dict.trie_data.clone()),
+ iter: Utf16Indices::new(input),
+ len: input.len(),
+ grapheme_iter,
+ }
+ }
+}
+
+#[cfg(test)]
+#[cfg(feature = "serde")]
+mod tests {
+ use super::*;
+ use crate::{LineSegmenter, WordSegmenter};
+ use icu_provider::prelude::*;
+
+ #[test]
+ fn burmese_dictionary_test() {
+ let segmenter = LineSegmenter::new_dictionary();
+ // From css/css-text/word-break/word-break-normal-my-000.html
+ let s = "မြန်မာစာမြန်မာစာမြန်မာစာ";
+ let result: Vec<usize> = segmenter.segment_str(s).collect();
+ assert_eq!(result, vec![0, 18, 24, 42, 48, 66, 72]);
+
+ let s_utf16: Vec<u16> = s.encode_utf16().collect();
+ let result: Vec<usize> = segmenter.segment_utf16(&s_utf16).collect();
+ assert_eq!(result, vec![0, 6, 8, 14, 16, 22, 24]);
+ }
+
+ #[test]
+ fn cj_dictionary_test() {
+ let dict_payload: DataPayload<DictionaryForWordOnlyAutoV1Marker> = crate::provider::Baked
+ .load(DataRequest {
+ locale: &icu_locid::locale!("ja").into(),
+ metadata: Default::default(),
+ })
+ .unwrap()
+ .take_payload()
+ .unwrap();
+ let word_segmenter = WordSegmenter::new_dictionary();
+ let dict_segmenter = DictionarySegmenter::new(
+ dict_payload.get(),
+ crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
+ );
+
+ // Match case
+ let s = "龟山岛龟山岛";
+ let result: Vec<usize> = dict_segmenter.segment_str(s).collect();
+ assert_eq!(result, vec![9, 18]);
+
+ let result: Vec<usize> = word_segmenter.segment_str(s).collect();
+ assert_eq!(result, vec![0, 9, 18]);
+
+ let s_utf16: Vec<u16> = s.encode_utf16().collect();
+ let result: Vec<usize> = dict_segmenter.segment_utf16(&s_utf16).collect();
+ assert_eq!(result, vec![3, 6]);
+
+ let result: Vec<usize> = word_segmenter.segment_utf16(&s_utf16).collect();
+ assert_eq!(result, vec![0, 3, 6]);
+
+ // Match case, then no match case
+ let s = "エディターエディ";
+ let result: Vec<usize> = dict_segmenter.segment_str(s).collect();
+ assert_eq!(result, vec![15, 24]);
+
+ // TODO(#3236): Why is WordSegmenter not returning the middle segment?
+ let result: Vec<usize> = word_segmenter.segment_str(s).collect();
+ assert_eq!(result, vec![0, 24]);
+
+ let s_utf16: Vec<u16> = s.encode_utf16().collect();
+ let result: Vec<usize> = dict_segmenter.segment_utf16(&s_utf16).collect();
+ assert_eq!(result, vec![5, 8]);
+
+ // TODO(#3236): Why is WordSegmenter not returning the middle segment?
+ let result: Vec<usize> = word_segmenter.segment_utf16(&s_utf16).collect();
+ assert_eq!(result, vec![0, 8]);
+ }
+
+ #[test]
+ fn khmer_dictionary_test() {
+ let segmenter = LineSegmenter::new_dictionary();
+ let s = "ភាសាខ្មែរភាសាខ្មែរភាសាខ្មែរ";
+ let result: Vec<usize> = segmenter.segment_str(s).collect();
+ assert_eq!(result, vec![0, 27, 54, 81]);
+
+ let s_utf16: Vec<u16> = s.encode_utf16().collect();
+ let result: Vec<usize> = segmenter.segment_utf16(&s_utf16).collect();
+ assert_eq!(result, vec![0, 9, 18, 27]);
+ }
+
+ #[test]
+ fn lao_dictionary_test() {
+ let segmenter = LineSegmenter::new_dictionary();
+ let s = "ພາສາລາວພາສາລາວພາສາລາວ";
+ let r: Vec<usize> = segmenter.segment_str(s).collect();
+ assert_eq!(r, vec![0, 12, 21, 33, 42, 54, 63]);
+
+ let s_utf16: Vec<u16> = s.encode_utf16().collect();
+ let r: Vec<usize> = segmenter.segment_utf16(&s_utf16).collect();
+ assert_eq!(r, vec![0, 4, 7, 11, 14, 18, 21]);
+ }
+}
diff --git a/third_party/rust/icu_segmenter/src/complex/language.rs b/third_party/rust/icu_segmenter/src/complex/language.rs
new file mode 100644
index 0000000000..327eea5e20
--- /dev/null
+++ b/third_party/rust/icu_segmenter/src/complex/language.rs
@@ -0,0 +1,161 @@
+// This file is part of ICU4X. For terms of use, please see the file
+// called LICENSE at the top level of the ICU4X source tree
+// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
+
+#[derive(PartialEq, Debug, Copy, Clone)]
+pub(super) enum Language {
+ Burmese,
+ ChineseOrJapanese,
+ Khmer,
+ Lao,
+ Thai,
+ Unknown,
+}
+
+// TODO: Use data provider
+fn get_language(codepoint: u32) -> Language {
+ match codepoint {
+ 0xe01..=0xe7f => Language::Thai,
+ 0x0E80..=0x0EFF => Language::Lao,
+ 0x1000..=0x109f => Language::Burmese,
+ 0x1780..=0x17FF => Language::Khmer,
+ 0x19E0..=0x19FF => Language::Khmer,
+ 0x2E80..=0x2EFF => Language::ChineseOrJapanese,
+ 0x2F00..=0x2FDF => Language::ChineseOrJapanese,
+ 0x3040..=0x30FF => Language::ChineseOrJapanese,
+ 0x31F0..=0x31FF => Language::ChineseOrJapanese,
+ 0x32D0..=0x32FE => Language::ChineseOrJapanese,
+ 0x3400..=0x4DBF => Language::ChineseOrJapanese,
+ 0x4E00..=0x9FFF => Language::ChineseOrJapanese,
+ 0xa9e0..=0xa9ff => Language::Burmese,
+ 0xaa60..=0xaa7f => Language::Burmese,
+ 0xF900..=0xFAFF => Language::ChineseOrJapanese,
+ 0xFF66..=0xFF9D => Language::ChineseOrJapanese,
+ 0x16FE2..=0x16FE3 => Language::ChineseOrJapanese,
+ 0x16FF0..=0x16FF1 => Language::ChineseOrJapanese,
+ 0x1AFF0..=0x1B16F => Language::ChineseOrJapanese,
+ 0x1F200 => Language::ChineseOrJapanese,
+ 0x20000..=0x2FA1F => Language::ChineseOrJapanese,
+ 0x30000..=0x3134F => Language::ChineseOrJapanese,
+ _ => Language::Unknown,
+ }
+}
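+
+// A minimal sketch exercising the range table above directly.
+#[cfg(test)]
+mod get_language_sketch {
+    use super::*;
+
+    #[test]
+    fn range_table_lookup() {
+        assert_eq!(get_language(0x0E01), Language::Thai); // THAI CHARACTER KO KAI
+        assert_eq!(get_language(0x1000), Language::Burmese); // MYANMAR LETTER KA
+        assert_eq!(get_language(0x4E00), Language::ChineseOrJapanese); // CJK ideograph
+        assert_eq!(get_language('A' as u32), Language::Unknown); // Latin falls through
+    }
+}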
+
+/// An iterator that splits the given string into substrings, yielding one
+/// substring per contiguous run of a single language.
+pub(super) struct LanguageIterator<'s> {
+ rest: &'s str,
+}
+
+impl<'s> LanguageIterator<'s> {
+ pub(super) fn new(input: &'s str) -> Self {
+ Self { rest: input }
+ }
+}
+
+impl<'s> Iterator for LanguageIterator<'s> {
+ type Item = (&'s str, Language);
+
+ fn next(&mut self) -> Option<Self::Item> {
+ let mut indices = self.rest.char_indices();
+ let lang = get_language(indices.next()?.1 as u32);
+ match indices.find(|&(_, ch)| get_language(ch as u32) != lang) {
+ Some((i, _)) => {
+ let (result, rest) = self.rest.split_at(i);
+ self.rest = rest;
+ Some((result, lang))
+ }
+ None => Some((core::mem::take(&mut self.rest), lang)),
+ }
+ }
+}
+
+pub(super) struct LanguageIteratorUtf16<'s> {
+ rest: &'s [u16],
+}
+
+impl<'s> LanguageIteratorUtf16<'s> {
+ pub(super) fn new(input: &'s [u16]) -> Self {
+ Self { rest: input }
+ }
+}
+
+impl<'s> Iterator for LanguageIteratorUtf16<'s> {
+ type Item = (&'s [u16], Language);
+
+ fn next(&mut self) -> Option<Self::Item> {
+ let lang = get_language(*self.rest.first()? as u32);
+ match self
+ .rest
+ .iter()
+ .position(|&ch| get_language(ch as u32) != lang)
+ {
+ Some(i) => {
+ let (result, rest) = self.rest.split_at(i);
+ self.rest = rest;
+ Some((result, lang))
+ }
+ None => Some((core::mem::take(&mut self.rest), lang)),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_thai_only() {
+ let s = "ภาษาไทยภาษาไทย";
+ let utf16: Vec<u16> = s.encode_utf16().collect();
+ let mut iter = LanguageIteratorUtf16::new(&utf16);
+ assert_eq!(
+ iter.next(),
+ Some((utf16.as_slice(), Language::Thai)),
+ "Thai language only with UTF-16"
+ );
+ let mut iter = LanguageIterator::new(s);
+ assert_eq!(
+ iter.next(),
+ Some((s, Language::Thai)),
+ "Thai language only with UTF-8"
+ );
+ assert_eq!(iter.next(), None, "Iterator for UTF-8 is finished");
+ }
+
+ #[test]
+ fn test_combine() {
+ const TEST_STR_THAI: &str = "ภาษาไทยภาษาไทย";
+ const TEST_STR_BURMESE: &str = "ဗမာနွယ်ဘာသာစကားမျာ";
+ let s = format!("{TEST_STR_THAI}{TEST_STR_BURMESE}");
+ let utf16: Vec<u16> = s.encode_utf16().collect();
+ let thai_utf16: Vec<u16> = TEST_STR_THAI.encode_utf16().collect();
+ let burmese_utf16: Vec<u16> = TEST_STR_BURMESE.encode_utf16().collect();
+
+ let mut iter = LanguageIteratorUtf16::new(&utf16);
+ assert_eq!(
+ iter.next(),
+ Some((thai_utf16.as_slice(), Language::Thai)),
+ "Thai language with UTF-16 at first"
+ );
+ assert_eq!(
+ iter.next(),
+ Some((burmese_utf16.as_slice(), Language::Burmese)),
+ "Burmese language with UTF-16 at second"
+ );
+ assert_eq!(iter.next(), None, "Iterator for UTF-16 is finished");
+
+ let mut iter = LanguageIterator::new(&s);
+ assert_eq!(
+ iter.next(),
+ Some((TEST_STR_THAI, Language::Thai)),
+ "Thai language with UTF-8 at first"
+ );
+ assert_eq!(
+ iter.next(),
+ Some((TEST_STR_BURMESE, Language::Burmese)),
+ "Burmese language with UTF-8 at second"
+ );
+ assert_eq!(iter.next(), None, "Iterator for UTF-8 is finished");
+ }
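+
+    // A minimal sketch: scripts outside the table in `get_language` map to
+    // `Language::Unknown` and are kept together as a single chunk.
+    #[test]
+    fn test_unknown_script_sketch() {
+        let s = "abc";
+        let mut iter = LanguageIterator::new(s);
+        assert_eq!(iter.next(), Some((s, Language::Unknown)));
+        assert_eq!(iter.next(), None);
+    }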
+}
diff --git a/third_party/rust/icu_segmenter/src/complex/lstm/matrix.rs b/third_party/rust/icu_segmenter/src/complex/lstm/matrix.rs
new file mode 100644
index 0000000000..3cf5ce2e3c
--- /dev/null
+++ b/third_party/rust/icu_segmenter/src/complex/lstm/matrix.rs
@@ -0,0 +1,540 @@
+// This file is part of ICU4X. For terms of use, please see the file
+// called LICENSE at the top level of the ICU4X source tree
+// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
+
+use alloc::vec;
+use alloc::vec::Vec;
+use core::ops::Range;
+#[allow(unused_imports)]
+use core_maths::*;
+use zerovec::ule::AsULE;
+use zerovec::ZeroSlice;
+
+/// A `D`-dimensional, heap-allocated matrix.
+///
+/// This matrix implementation supports slicing matrices into tightly-packed
+/// submatrices. For example, indexing into a matrix of size 5x4x3 returns a
+/// matrix of size 4x3. For more information, see [`MatrixOwned::submatrix`].
+#[derive(Debug, Clone)]
+pub(super) struct MatrixOwned<const D: usize> {
+ data: Vec<f32>,
+ dims: [usize; D],
+}
+
+impl<const D: usize> MatrixOwned<D> {
+ pub(super) fn as_borrowed(&self) -> MatrixBorrowed<D> {
+ MatrixBorrowed {
+ data: &self.data,
+ dims: self.dims,
+ }
+ }
+
+ pub(super) fn new_zero(dims: [usize; D]) -> Self {
+ let total_len = dims.iter().product::<usize>();
+ MatrixOwned {
+ data: vec![0.0; total_len],
+ dims,
+ }
+ }
+
+ /// Returns the tightly packed submatrix at _index_, or `None` if _index_ is out of range.
+ ///
+ /// For example, if the matrix is 5x4x3, this function returns a matrix sized 4x3. If the
+ /// matrix is 4x3, then this function returns a linear matrix of length 3.
+ ///
+ /// The type parameter `M` should be `D - 1`.
+ #[inline]
+ pub(super) fn submatrix<const M: usize>(&self, index: usize) -> Option<MatrixBorrowed<M>> {
+ // This assertion is based on const generics; it should always succeed and be elided.
+ assert_eq!(M, D - 1);
+ let (range, dims) = self.as_borrowed().submatrix_range(index);
+ let data = &self.data.get(range)?;
+ Some(MatrixBorrowed { data, dims })
+ }
+
+ pub(super) fn as_mut(&mut self) -> MatrixBorrowedMut<D> {
+ MatrixBorrowedMut {
+ data: &mut self.data,
+ dims: self.dims,
+ }
+ }
+
+ /// A mutable version of [`Self::submatrix`].
+ #[inline]
+ pub(super) fn submatrix_mut<const M: usize>(
+ &mut self,
+ index: usize,
+ ) -> Option<MatrixBorrowedMut<M>> {
+ // This assertion is based on const generics; it should always succeed and be elided.
+ assert_eq!(M, D - 1);
+ let (range, dims) = self.as_borrowed().submatrix_range(index);
+ let data = self.data.get_mut(range)?;
+ Some(MatrixBorrowedMut { data, dims })
+ }
+}
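+
+// A small sketch of the row-major layout: for a 2x3 matrix the rows are the
+// consecutive slices data[0..3] and data[3..6], and `submatrix::<1>(i)`
+// reborrows row `i`.
+#[cfg(test)]
+mod submatrix_sketch {
+    use super::*;
+
+    #[test]
+    fn submatrix_is_row_major() {
+        let mut m = MatrixOwned::<2>::new_zero([2, 3]);
+        m.as_mut().as_mut_slice().copy_from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
+        let row1 = m.submatrix::<1>(1).unwrap();
+        assert_eq!(row1.dim(), 3);
+        assert_eq!(row1.as_slice(), &[3.0, 4.0, 5.0]);
+    }
+}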
+
+/// A `D`-dimensional, borrowed matrix.
+#[derive(Debug, Clone, Copy)]
+pub(super) struct MatrixBorrowed<'a, const D: usize> {
+ data: &'a [f32],
+ dims: [usize; D],
+}
+
+impl<'a, const D: usize> MatrixBorrowed<'a, D> {
+ #[cfg(debug_assertions)]
+ pub(super) fn debug_assert_dims(&self, dims: [usize; D]) {
+ debug_assert_eq!(dims, self.dims);
+ let expected_len = dims.iter().product::<usize>();
+ debug_assert_eq!(expected_len, self.data.len());
+ }
+
+ pub(super) fn as_slice(&self) -> &'a [f32] {
+ self.data
+ }
+
+ /// See [`MatrixOwned::submatrix`].
+ #[inline]
+ pub(super) fn submatrix<const M: usize>(&self, index: usize) -> Option<MatrixBorrowed<'a, M>> {
+ // This assertion is based on const generics; it should always succeed and be elided.
+ assert_eq!(M, D - 1);
+ let (range, dims) = self.submatrix_range(index);
+ let data = &self.data.get(range)?;
+ Some(MatrixBorrowed { data, dims })
+ }
+
+ #[inline]
+ fn submatrix_range<const M: usize>(&self, index: usize) -> (Range<usize>, [usize; M]) {
+ // This assertion is based on const generics; it should always succeed and be elided.
+ assert_eq!(M, D - 1);
+ // The above assertion guarantees that the following line will succeed
+ #[allow(clippy::indexing_slicing, clippy::unwrap_used)]
+ let sub_dims: [usize; M] = self.dims[1..].try_into().unwrap();
+ let n = sub_dims.iter().product::<usize>();
+ (n * index..n * (index + 1), sub_dims)
+ }
+}
+
+macro_rules! impl_basic_dim {
+ ($t1:path, $t2:path, $t3:path) => {
+ impl<'a> $t1 {
+ #[allow(dead_code)]
+ pub(super) fn dim(&self) -> usize {
+ let [dim] = self.dims;
+ dim
+ }
+ }
+ impl<'a> $t2 {
+ #[allow(dead_code)]
+ pub(super) fn dim(&self) -> (usize, usize) {
+ let [d0, d1] = self.dims;
+ (d0, d1)
+ }
+ }
+ impl<'a> $t3 {
+ #[allow(dead_code)]
+ pub(super) fn dim(&self) -> (usize, usize, usize) {
+ let [d0, d1, d2] = self.dims;
+ (d0, d1, d2)
+ }
+ }
+ };
+}
+
+impl_basic_dim!(MatrixOwned<1>, MatrixOwned<2>, MatrixOwned<3>);
+impl_basic_dim!(
+ MatrixBorrowed<'a, 1>,
+ MatrixBorrowed<'a, 2>,
+ MatrixBorrowed<'a, 3>
+);
+impl_basic_dim!(
+ MatrixBorrowedMut<'a, 1>,
+ MatrixBorrowedMut<'a, 2>,
+ MatrixBorrowedMut<'a, 3>
+);
+impl_basic_dim!(MatrixZero<'a, 1>, MatrixZero<'a, 2>, MatrixZero<'a, 3>);
+
+/// A `D`-dimensional, mutably borrowed matrix.
+pub(super) struct MatrixBorrowedMut<'a, const D: usize> {
+ pub(super) data: &'a mut [f32],
+ pub(super) dims: [usize; D],
+}
+
+impl<'a, const D: usize> MatrixBorrowedMut<'a, D> {
+ pub(super) fn as_borrowed(&self) -> MatrixBorrowed<D> {
+ MatrixBorrowed {
+ data: self.data,
+ dims: self.dims,
+ }
+ }
+
+ pub(super) fn as_mut_slice(&mut self) -> &mut [f32] {
+ self.data
+ }
+
+ pub(super) fn copy_submatrix<const M: usize>(&mut self, from: usize, to: usize) {
+ let (range_from, _) = self.as_borrowed().submatrix_range::<M>(from);
+ let (range_to, _) = self.as_borrowed().submatrix_range::<M>(to);
+ if let (Some(_), Some(_)) = (
+ self.data.get(range_from.clone()),
+ self.data.get(range_to.clone()),
+ ) {
+ // This function is panicky, but we just validated the ranges
+ self.data.copy_within(range_from, range_to.start);
+ }
+ }
+
+ #[must_use]
+ pub(super) fn add(&mut self, other: MatrixZero<'_, D>) -> Option<()> {
+ debug_assert_eq!(self.dims, other.dims);
+ // TODO: Vectorize?
+ for i in 0..self.data.len() {
+ *self.data.get_mut(i)? += other.data.get(i)?;
+ }
+ Some(())
+ }
+
+ #[allow(dead_code)] // maybe needed for more complicated bies calculations
+ /// Mutates this matrix by applying a softmax transformation.
+ pub(super) fn softmax_transform(&mut self) {
+ for v in self.data.iter_mut() {
+ *v = v.exp();
+ }
+ let sm = 1.0 / self.data.iter().sum::<f32>();
+ for v in self.data.iter_mut() {
+ *v *= sm;
+ }
+ }
+
+ pub(super) fn sigmoid_transform(&mut self) {
+ for x in &mut self.data.iter_mut() {
+ *x = 1.0 / (1.0 + (-*x).exp());
+ }
+ }
+
+ pub(super) fn tanh_transform(&mut self) {
+ for x in &mut self.data.iter_mut() {
+ *x = x.tanh();
+ }
+ }
+
+ pub(super) fn convolve(
+ &mut self,
+ i: MatrixBorrowed<'_, D>,
+ c: MatrixBorrowed<'_, D>,
+ f: MatrixBorrowed<'_, D>,
+ ) {
+ let i = i.as_slice();
+ let c = c.as_slice();
+ let f = f.as_slice();
+ let len = self.data.len();
+ if len != i.len() || len != c.len() || len != f.len() {
+ debug_assert!(false, "LSTM matrices not the correct dimensions");
+ return;
+ }
+ for idx in 0..len {
+ // Safety: The lengths are all the same (checked above)
+ unsafe {
+ *self.data.get_unchecked_mut(idx) = i.get_unchecked(idx) * c.get_unchecked(idx)
+ + self.data.get_unchecked(idx) * f.get_unchecked(idx)
+ }
+ }
+ }
+
+ pub(super) fn mul_tanh(&mut self, o: MatrixBorrowed<'_, D>, c: MatrixBorrowed<'_, D>) {
+ let o = o.as_slice();
+ let c = c.as_slice();
+ let len = self.data.len();
+ if len != o.len() || len != c.len() {
+ debug_assert!(false, "LSTM matrices not the correct dimensions");
+ return;
+ }
+ for idx in 0..len {
+ // Safety: The lengths are all the same (checked above)
+ unsafe {
+ *self.data.get_unchecked_mut(idx) =
+ o.get_unchecked(idx) * c.get_unchecked(idx).tanh();
+ }
+ }
+ }
+}
+
+impl<'a> MatrixBorrowed<'a, 1> {
+ #[allow(dead_code)] // could be useful
+ pub(super) fn dot_1d(&self, other: MatrixZero<1>) -> f32 {
+ debug_assert_eq!(self.dims, other.dims);
+ unrolled_dot_1(self.data, other.data)
+ }
+}
+
+impl<'a> MatrixBorrowedMut<'a, 1> {
+ /// Calculate the dot product of a and b, adding the result to self.
+ ///
+ /// Note: For better dot product efficiency, if `b` is MxN, then `a` should be N;
+ /// this is the opposite of standard practice.
+ pub(super) fn add_dot_2d(&mut self, a: MatrixBorrowed<1>, b: MatrixZero<2>) {
+ let m = a.dim();
+ let n = self.as_borrowed().dim();
+ debug_assert_eq!(
+ m,
+ b.dim().1,
+ "dims: {:?}/{:?}/{:?}",
+ self.as_borrowed().dim(),
+ a.dim(),
+ b.dim()
+ );
+ debug_assert_eq!(
+ n,
+ b.dim().0,
+ "dims: {:?}/{:?}/{:?}",
+ self.as_borrowed().dim(),
+ a.dim(),
+ b.dim()
+ );
+ for i in 0..n {
+ if let (Some(dest), Some(b_sub)) = (self.as_mut_slice().get_mut(i), b.submatrix::<1>(i))
+ {
+ *dest += unrolled_dot_1(a.data, b_sub.data);
+ } else {
+ debug_assert!(false, "unreachable: dims checked above");
+ }
+ }
+ }
+}
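+
+// Worked example for `add_dot_2d` above: with self = [0, 0] (so N = 2),
+// a = [1, 2, 3] (M = 3) and b a 2x3 matrix stored row-major as
+// [[1, 0, 0], [0, 1, 1]], the loop adds dot(a, row 0) = 1 to self[0] and
+// dot(a, row 1) = 5 to self[1]; i.e. it accumulates the matrix-vector product
+// b * a without materializing a transpose.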
+
+impl<'a> MatrixBorrowedMut<'a, 2> {
+ /// Calculate the dot product of a and b, adding the result to self.
+ ///
+ /// Self should be _MxN_; `a`, _O_; and `b`, _MxNxO_.
+ pub(super) fn add_dot_3d_1(&mut self, a: MatrixBorrowed<1>, b: MatrixZero<3>) {
+ let m = a.dim();
+ let n = self.as_borrowed().dim().0 * self.as_borrowed().dim().1;
+ debug_assert_eq!(
+ m,
+ b.dim().2,
+ "dims: {:?}/{:?}/{:?}",
+ self.as_borrowed().dim(),
+ a.dim(),
+ b.dim()
+ );
+ debug_assert_eq!(
+ n,
+ b.dim().0 * b.dim().1,
+ "dims: {:?}/{:?}/{:?}",
+ self.as_borrowed().dim(),
+ a.dim(),
+ b.dim()
+ );
+ // Note: The following two loops are equivalent, but the second has more opportunity for
+ // vectorization since it allows the vectorization to span submatrices.
+ // for i in 0..b.dim().0 {
+ // self.submatrix_mut::<1>(i).add_dot_2d(a, b.submatrix(i));
+ // }
+ let lhs = a.as_slice();
+ for i in 0..n {
+ if let (Some(dest), Some(rhs)) = (
+ self.as_mut_slice().get_mut(i),
+ b.as_slice().get_subslice(i * m..(i + 1) * m),
+ ) {
+ *dest += unrolled_dot_1(lhs, rhs);
+ } else {
+ debug_assert!(false, "unreachable: dims checked above");
+ }
+ }
+ }
+
+ /// Calculate the dot product of a and b, adding the result to self.
+ ///
+ /// Self should be _MxN_; `a`, _O_; and `b`, _MxNxO_.
+ pub(super) fn add_dot_3d_2(&mut self, a: MatrixZero<1>, b: MatrixZero<3>) {
+ let m = a.dim();
+ let n = self.as_borrowed().dim().0 * self.as_borrowed().dim().1;
+ debug_assert_eq!(
+ m,
+ b.dim().2,
+ "dims: {:?}/{:?}/{:?}",
+ self.as_borrowed().dim(),
+ a.dim(),
+ b.dim()
+ );
+ debug_assert_eq!(
+ n,
+ b.dim().0 * b.dim().1,
+ "dims: {:?}/{:?}/{:?}",
+ self.as_borrowed().dim(),
+ a.dim(),
+ b.dim()
+ );
+ // Note: The following two loops are equivalent, but the second has more opportunity for
+ // vectorization since it allows the vectorization to span submatrices.
+ // for i in 0..b.dim().0 {
+ // self.submatrix_mut::<1>(i).add_dot_2d(a, b.submatrix(i));
+ // }
+ let lhs = a.as_slice();
+ for i in 0..n {
+ if let (Some(dest), Some(rhs)) = (
+ self.as_mut_slice().get_mut(i),
+ b.as_slice().get_subslice(i * m..(i + 1) * m),
+ ) {
+ *dest += unrolled_dot_2(lhs, rhs);
+ } else {
+ debug_assert!(false, "unreachable: dims checked above");
+ }
+ }
+ }
+}
+
+/// A `D`-dimensional matrix borrowed from a [`ZeroSlice`].
+#[derive(Debug, Clone, Copy)]
+pub(super) struct MatrixZero<'a, const D: usize> {
+ data: &'a ZeroSlice<f32>,
+ dims: [usize; D],
+}
+
+impl<'a> From<&'a crate::provider::LstmMatrix1<'a>> for MatrixZero<'a, 1> {
+ fn from(other: &'a crate::provider::LstmMatrix1<'a>) -> Self {
+ Self {
+ data: &other.data,
+ dims: other.dims.map(|x| x as usize),
+ }
+ }
+}
+
+impl<'a> From<&'a crate::provider::LstmMatrix2<'a>> for MatrixZero<'a, 2> {
+ fn from(other: &'a crate::provider::LstmMatrix2<'a>) -> Self {
+ Self {
+ data: &other.data,
+ dims: other.dims.map(|x| x as usize),
+ }
+ }
+}
+
+impl<'a> From<&'a crate::provider::LstmMatrix3<'a>> for MatrixZero<'a, 3> {
+ fn from(other: &'a crate::provider::LstmMatrix3<'a>) -> Self {
+ Self {
+ data: &other.data,
+ dims: other.dims.map(|x| x as usize),
+ }
+ }
+}
+
+impl<'a, const D: usize> MatrixZero<'a, D> {
+ #[allow(clippy::wrong_self_convention)] // same convention as slice::to_vec
+ pub(super) fn to_owned(&self) -> MatrixOwned<D> {
+ MatrixOwned {
+ data: self.data.iter().collect(),
+ dims: self.dims,
+ }
+ }
+
+ pub(super) fn as_slice(&self) -> &ZeroSlice<f32> {
+ self.data
+ }
+
+ #[cfg(debug_assertions)]
+ pub(super) fn debug_assert_dims(&self, dims: [usize; D]) {
+ debug_assert_eq!(dims, self.dims);
+ let expected_len = dims.iter().product::<usize>();
+ debug_assert_eq!(expected_len, self.data.len());
+ }
+
+ /// See [`MatrixOwned::submatrix`].
+ #[inline]
+ pub(super) fn submatrix<const M: usize>(&self, index: usize) -> Option<MatrixZero<'a, M>> {
+ // This assertion is based on const generics; it should always succeed and be elided.
+ assert_eq!(M, D - 1);
+ let (range, dims) = self.submatrix_range(index);
+ let data = &self.data.get_subslice(range)?;
+ Some(MatrixZero { data, dims })
+ }
+
+ #[inline]
+ fn submatrix_range<const M: usize>(&self, index: usize) -> (Range<usize>, [usize; M]) {
+ // This assertion is based on const generics; it should always succeed and be elided.
+ assert_eq!(M, D - 1);
+ // The above assertion guarantees that the following line will succeed
+ #[allow(clippy::indexing_slicing, clippy::unwrap_used)]
+ let sub_dims: [usize; M] = self.dims[1..].try_into().unwrap();
+ let n = sub_dims.iter().product::<usize>();
+ (n * index..n * (index + 1), sub_dims)
+ }
+}
+
+macro_rules! f32c {
+ ($ule:expr) => {
+ f32::from_unaligned($ule)
+ };
+}
+
+/// Compute the dot product of an aligned and an unaligned f32 slice.
+///
+/// `xs` and `ys` must be the same length
+///
+/// (Based on ndarray 0.15.6)
+fn unrolled_dot_1(xs: &[f32], ys: &ZeroSlice<f32>) -> f32 {
+ debug_assert_eq!(xs.len(), ys.len());
+ // eightfold unrolled so that floating point can be vectorized
+ // (even with strict floating point accuracy semantics)
+ let mut p = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0);
+ let xit = xs.chunks_exact(8);
+ let yit = ys.as_ule_slice().chunks_exact(8);
+ let sum = xit
+ .remainder()
+ .iter()
+ .zip(yit.remainder().iter())
+ .map(|(x, y)| x * f32c!(*y))
+ .sum::<f32>();
+ for (xx, yy) in xit.zip(yit) {
+ // TODO: Use array_chunks once stable to avoid the unwrap.
+ // <https://github.com/rust-lang/rust/issues/74985>
+ #[allow(clippy::unwrap_used)]
+ let [x0, x1, x2, x3, x4, x5, x6, x7] = *<&[f32; 8]>::try_from(xx).unwrap();
+ #[allow(clippy::unwrap_used)]
+ let [y0, y1, y2, y3, y4, y5, y6, y7] = *<&[<f32 as AsULE>::ULE; 8]>::try_from(yy).unwrap();
+ p.0 += x0 * f32c!(y0);
+ p.1 += x1 * f32c!(y1);
+ p.2 += x2 * f32c!(y2);
+ p.3 += x3 * f32c!(y3);
+ p.4 += x4 * f32c!(y4);
+ p.5 += x5 * f32c!(y5);
+ p.6 += x6 * f32c!(y6);
+ p.7 += x7 * f32c!(y7);
+ }
+ sum + (p.0 + p.4) + (p.1 + p.5) + (p.2 + p.6) + (p.3 + p.7)
+}
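+
+// Note on the dot-product helpers above and below: both compute the plain
+// sum over i of xs[i] * ys[i]. The tail of length len % 8 is folded in via
+// `remainder()` before the unrolled main loop, and the eight accumulators are
+// combined pairwise at the end; e.g. for xs = ys = [1.0, 2.0, ..., 9.0] the
+// main loop contributes 204 (squares of 1..=8) and the tail contributes 81,
+// giving 285.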
+
+/// Compute the dot product of two unaligned f32 slices.
+///
+/// `xs` and `ys` must be the same length
+///
+/// (Based on ndarray 0.15.6)
+fn unrolled_dot_2(xs: &ZeroSlice<f32>, ys: &ZeroSlice<f32>) -> f32 {
+ debug_assert_eq!(xs.len(), ys.len());
+ // eightfold unrolled so that floating point can be vectorized
+ // (even with strict floating point accuracy semantics)
+ let mut p = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0);
+ let xit = xs.as_ule_slice().chunks_exact(8);
+ let yit = ys.as_ule_slice().chunks_exact(8);
+ let sum = xit
+ .remainder()
+ .iter()
+ .zip(yit.remainder().iter())
+ .map(|(x, y)| f32c!(*x) * f32c!(*y))
+ .sum::<f32>();
+ for (xx, yy) in xit.zip(yit) {
+ // TODO: Use array_chunks once stable to avoid the unwrap.
+ // <https://github.com/rust-lang/rust/issues/74985>
+ #[allow(clippy::unwrap_used)]
+ let [x0, x1, x2, x3, x4, x5, x6, x7] = *<&[<f32 as AsULE>::ULE; 8]>::try_from(xx).unwrap();
+ #[allow(clippy::unwrap_used)]
+ let [y0, y1, y2, y3, y4, y5, y6, y7] = *<&[<f32 as AsULE>::ULE; 8]>::try_from(yy).unwrap();
+ p.0 += f32c!(x0) * f32c!(y0);
+ p.1 += f32c!(x1) * f32c!(y1);
+ p.2 += f32c!(x2) * f32c!(y2);
+ p.3 += f32c!(x3) * f32c!(y3);
+ p.4 += f32c!(x4) * f32c!(y4);
+ p.5 += f32c!(x5) * f32c!(y5);
+ p.6 += f32c!(x6) * f32c!(y6);
+ p.7 += f32c!(x7) * f32c!(y7);
+ }
+ sum + (p.0 + p.4) + (p.1 + p.5) + (p.2 + p.6) + (p.3 + p.7)
+}
diff --git a/third_party/rust/icu_segmenter/src/complex/lstm/mod.rs b/third_party/rust/icu_segmenter/src/complex/lstm/mod.rs
new file mode 100644
index 0000000000..8718cbd3da
--- /dev/null
+++ b/third_party/rust/icu_segmenter/src/complex/lstm/mod.rs
@@ -0,0 +1,402 @@
+// This file is part of ICU4X. For terms of use, please see the file
+// called LICENSE at the top level of the ICU4X source tree
+// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
+
+use crate::grapheme::GraphemeClusterSegmenter;
+use crate::provider::*;
+use alloc::vec::Vec;
+use core::char::{decode_utf16, REPLACEMENT_CHARACTER};
+use zerovec::{maps::ZeroMapBorrowed, ule::UnvalidatedStr};
+
+mod matrix;
+use matrix::*;
+
+// A word break iterator using an LSTM model. The input string has to be in a single language.
+
+struct LstmSegmenterIterator<'s> {
+ input: &'s str,
+ pos_utf8: usize,
+ bies: BiesIterator<'s>,
+}
+
+impl Iterator for LstmSegmenterIterator<'_> {
+ type Item = usize;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ #[allow(clippy::indexing_slicing)] // pos_utf8 in range
+ loop {
+ let is_e = self.bies.next()?;
+ self.pos_utf8 += self.input[self.pos_utf8..].chars().next()?.len_utf8();
+ if is_e || self.bies.len() == 0 {
+ return Some(self.pos_utf8);
+ }
+ }
+ }
+}
+
+struct LstmSegmenterIteratorUtf16<'s> {
+ bies: BiesIterator<'s>,
+ pos: usize,
+}
+
+impl Iterator for LstmSegmenterIteratorUtf16<'_> {
+ type Item = usize;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ loop {
+ self.pos += 1;
+ if self.bies.next()? || self.bies.len() == 0 {
+ return Some(self.pos);
+ }
+ }
+ }
+}
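+
+// Note on the two wrappers above: `BiesIterator` yields one boolean per
+// element of the model's input sequence, true when that element is labelled E
+// (end of word) under the BIES scheme. Both wrappers turn those labels into
+// break offsets, emitting a break after every E label plus an unconditional
+// final break once the sequence is exhausted (the `self.bies.len() == 0`
+// check).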
+
+pub(super) struct LstmSegmenter<'l> {
+ dic: ZeroMapBorrowed<'l, UnvalidatedStr, u16>,
+ embedding: MatrixZero<'l, 2>,
+ fw_w: MatrixZero<'l, 3>,
+ fw_u: MatrixZero<'l, 3>,
+ fw_b: MatrixZero<'l, 2>,
+ bw_w: MatrixZero<'l, 3>,
+ bw_u: MatrixZero<'l, 3>,
+ bw_b: MatrixZero<'l, 2>,
+ timew_fw: MatrixZero<'l, 2>,
+ timew_bw: MatrixZero<'l, 2>,
+ time_b: MatrixZero<'l, 1>,
+ grapheme: Option<&'l RuleBreakDataV1<'l>>,
+}
+
+impl<'l> LstmSegmenter<'l> {
+ /// Creates an LSTM segmenter. Grapheme cluster data is used only if the model operates on grapheme clusters.
+ pub(super) fn new(lstm: &'l LstmDataV1<'l>, grapheme: &'l RuleBreakDataV1<'l>) -> Self {
+ let LstmDataV1::Float32(lstm) = lstm;
+ let time_w = MatrixZero::from(&lstm.time_w);
+ #[allow(clippy::unwrap_used)] // shape (2, 4, hunits)
+ let timew_fw = time_w.submatrix(0).unwrap();
+ #[allow(clippy::unwrap_used)] // shape (2, 4, hunits)
+ let timew_bw = time_w.submatrix(1).unwrap();
+ Self {
+ dic: lstm.dic.as_borrowed(),
+ embedding: MatrixZero::from(&lstm.embedding),
+ fw_w: MatrixZero::from(&lstm.fw_w),
+ fw_u: MatrixZero::from(&lstm.fw_u),
+ fw_b: MatrixZero::from(&lstm.fw_b),
+ bw_w: MatrixZero::from(&lstm.bw_w),
+ bw_u: MatrixZero::from(&lstm.bw_u),
+ bw_b: MatrixZero::from(&lstm.bw_b),
+ timew_fw,
+ timew_bw,
+ time_b: MatrixZero::from(&lstm.time_b),
+ grapheme: (lstm.model == ModelType::GraphemeClusters).then_some(grapheme),
+ }
+ }
+
+ /// Create an LSTM based break iterator for an `str` (a UTF-8 string).
+ pub(super) fn segment_str(&'l self, input: &'l str) -> impl Iterator<Item = usize> + 'l {
+ self.segment_str_p(input)
+ }
+
+ // For unit testing as we cannot inspect the opaque type's bies
+ fn segment_str_p(&'l self, input: &'l str) -> LstmSegmenterIterator<'l> {
+ let input_seq = if let Some(grapheme) = self.grapheme {
+ GraphemeClusterSegmenter::new_and_segment_str(input, grapheme)
+ .collect::<Vec<usize>>()
+ .windows(2)
+ .map(|chunk| {
+ let range = if let [first, second, ..] = chunk {
+ *first..*second
+ } else {
+ unreachable!()
+ };
+ let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) {
+ grapheme_cluster
+ } else {
+ return self.dic.len() as u16;
+ };
+
+ self.dic
+ .get_copied(UnvalidatedStr::from_str(grapheme_cluster))
+ .unwrap_or_else(|| self.dic.len() as u16)
+ })
+ .collect()
+ } else {
+ input
+ .chars()
+ .map(|c| {
+ self.dic
+ .get_copied(UnvalidatedStr::from_str(c.encode_utf8(&mut [0; 4])))
+ .unwrap_or_else(|| self.dic.len() as u16)
+ })
+ .collect()
+ };
+ LstmSegmenterIterator {
+ input,
+ pos_utf8: 0,
+ bies: BiesIterator::new(self, input_seq),
+ }
+ }
+
+ /// Create an LSTM based break iterator for a UTF-16 string.
+ pub(super) fn segment_utf16(&'l self, input: &[u16]) -> impl Iterator<Item = usize> + 'l {
+ let input_seq = if let Some(grapheme) = self.grapheme {
+ GraphemeClusterSegmenter::new_and_segment_utf16(input, grapheme)
+ .collect::<Vec<usize>>()
+ .windows(2)
+ .map(|chunk| {
+ let range = if let [first, second, ..] = chunk {
+ *first..*second
+ } else {
+ unreachable!()
+ };
+ let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) {
+ grapheme_cluster
+ } else {
+ return self.dic.len() as u16;
+ };
+
+ self.dic
+ .get_copied_by(|key| {
+ key.as_bytes().iter().copied().cmp(
+ decode_utf16(grapheme_cluster.iter().copied()).flat_map(|c| {
+ let mut buf = [0; 4];
+ let len = c
+ .unwrap_or(REPLACEMENT_CHARACTER)
+ .encode_utf8(&mut buf)
+ .len();
+ buf.into_iter().take(len)
+ }),
+ )
+ })
+ .unwrap_or_else(|| self.dic.len() as u16)
+ })
+ .collect()
+ } else {
+ decode_utf16(input.iter().copied())
+ .map(|c| c.unwrap_or(REPLACEMENT_CHARACTER))
+ .map(|c| {
+ self.dic
+ .get_copied(UnvalidatedStr::from_str(c.encode_utf8(&mut [0; 4])))
+ .unwrap_or_else(|| self.dic.len() as u16)
+ })
+ .collect()
+ };
+ LstmSegmenterIteratorUtf16 {
+ bies: BiesIterator::new(self, input_seq),
+ pos: 0,
+ }
+ }
+}
+
+struct BiesIterator<'l> {
+ segmenter: &'l LstmSegmenter<'l>,
+ input_seq: core::iter::Enumerate<alloc::vec::IntoIter<u16>>,
+ h_bw: MatrixOwned<2>,
+ curr_fw: MatrixOwned<1>,
+ c_fw: MatrixOwned<1>,
+}
+
+impl<'l> BiesIterator<'l> {
+ // `input_seq` is a sequence of id numbers representing the grapheme clusters or code points of the input line.
+ // These ids are used later in the embedding layer of the model.
+ fn new(segmenter: &'l LstmSegmenter<'l>, input_seq: Vec<u16>) -> Self {
+ let hunits = segmenter.fw_u.dim().1;
+
+ // Backward LSTM
+ let mut c_bw = MatrixOwned::<1>::new_zero([hunits]);
+ let mut h_bw = MatrixOwned::<2>::new_zero([input_seq.len(), hunits]);
+ for (i, &g_id) in input_seq.iter().enumerate().rev() {
+ if i + 1 < input_seq.len() {
+ h_bw.as_mut().copy_submatrix::<1>(i + 1, i);
+ }
+ #[allow(clippy::unwrap_used)]
+ compute_hc(
+ segmenter.embedding.submatrix::<1>(g_id as usize).unwrap(), /* shape (dict.len() + 1, hunit), g_id is at most dict.len() */
+ h_bw.submatrix_mut(i).unwrap(), // shape (input_seq.len(), hunits)
+ c_bw.as_mut(),
+ segmenter.bw_w,
+ segmenter.bw_u,
+ segmenter.bw_b,
+ );
+ }
+
+ Self {
+ input_seq: input_seq.into_iter().enumerate(),
+ h_bw,
+ c_fw: MatrixOwned::<1>::new_zero([hunits]),
+ curr_fw: MatrixOwned::<1>::new_zero([hunits]),
+ segmenter,
+ }
+ }
+}
+
+impl ExactSizeIterator for BiesIterator<'_> {
+ fn len(&self) -> usize {
+ self.input_seq.len()
+ }
+}
+
+impl Iterator for BiesIterator<'_> {
+ type Item = bool;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ let (i, g_id) = self.input_seq.next()?;
+
+ #[allow(clippy::unwrap_used)]
+ compute_hc(
+ self.segmenter
+ .embedding
+ .submatrix::<1>(g_id as usize)
+ .unwrap(), // shape (dict.len() + 1, hunit), g_id is at most dict.len()
+ self.curr_fw.as_mut(),
+ self.c_fw.as_mut(),
+ self.segmenter.fw_w,
+ self.segmenter.fw_u,
+ self.segmenter.fw_b,
+ );
+
+ #[allow(clippy::unwrap_used)] // shape (input_seq.len(), hunits)
+ let curr_bw = self.h_bw.submatrix::<1>(i).unwrap();
+ let mut weights = [0.0; 4];
+ let mut curr_est = MatrixBorrowedMut {
+ data: &mut weights,
+ dims: [4],
+ };
+ curr_est.add_dot_2d(self.curr_fw.as_borrowed(), self.segmenter.timew_fw);
+ curr_est.add_dot_2d(curr_bw, self.segmenter.timew_bw);
+ #[allow(clippy::unwrap_used)] // both shape (4)
+ curr_est.add(self.segmenter.time_b).unwrap();
+ // For correct BIES weight calculation we'd now have to apply softmax, however
+ // we're only doing a naive argmax, so a monotonic function doesn't make a difference.
+
+ Some(weights[2] > weights[0] && weights[2] > weights[1] && weights[2] > weights[3])
+ }
+}
+
+/// `compute_hc` implements the evaluation of one LSTM layer.
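+///
+/// Sketch of the computation, matching the code below (the four rows of `s_t`
+/// are indexed 0..4): `s_t = b + W*x_t + U*h_{t-1}`, then
+/// `i = sigmoid(s_t[0])`, `f = sigmoid(s_t[1])`, `c~ = tanh(s_t[2])`,
+/// `o = sigmoid(s_t[3])`, and the state is updated in place as
+/// `c_t = i * c~ + f * c_{t-1}` followed by `h_t = o * tanh(c_t)`.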
+fn compute_hc<'a>(
+ x_t: MatrixZero<'a, 1>,
+ mut h_tm1: MatrixBorrowedMut<'a, 1>,
+ mut c_tm1: MatrixBorrowedMut<'a, 1>,
+ w: MatrixZero<'a, 3>,
+ u: MatrixZero<'a, 3>,
+ b: MatrixZero<'a, 2>,
+) {
+ #[cfg(debug_assertions)]
+ {
+ let hunits = h_tm1.dim();
+ let embedd_dim = x_t.dim();
+ c_tm1.as_borrowed().debug_assert_dims([hunits]);
+ w.debug_assert_dims([4, hunits, embedd_dim]);
+ u.debug_assert_dims([4, hunits, hunits]);
+ b.debug_assert_dims([4, hunits]);
+ }
+
+ let mut s_t = b.to_owned();
+
+ s_t.as_mut().add_dot_3d_2(x_t, w);
+ s_t.as_mut().add_dot_3d_1(h_tm1.as_borrowed(), u);
+
+ #[allow(clippy::unwrap_used)] // first dimension is 4
+ s_t.submatrix_mut::<1>(0).unwrap().sigmoid_transform();
+ #[allow(clippy::unwrap_used)] // first dimension is 4
+ s_t.submatrix_mut::<1>(1).unwrap().sigmoid_transform();
+ #[allow(clippy::unwrap_used)] // first dimension is 4
+ s_t.submatrix_mut::<1>(2).unwrap().tanh_transform();
+ #[allow(clippy::unwrap_used)] // first dimension is 4
+ s_t.submatrix_mut::<1>(3).unwrap().sigmoid_transform();
+
+ #[allow(clippy::unwrap_used)] // first dimension is 4
+ c_tm1.convolve(
+ s_t.as_borrowed().submatrix(0).unwrap(),
+ s_t.as_borrowed().submatrix(2).unwrap(),
+ s_t.as_borrowed().submatrix(1).unwrap(),
+ );
+
+ #[allow(clippy::unwrap_used)] // first dimension is 4
+ h_tm1.mul_tanh(s_t.as_borrowed().submatrix(3).unwrap(), c_tm1.as_borrowed());
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use icu_locid::locale;
+ use icu_provider::prelude::*;
+ use serde::Deserialize;
+ use std::fs::File;
+ use std::io::BufReader;
+
+ /// `TestCase` is a struct used to store a single test case.
+ /// Each test case has three attributes: `unseg`, which denotes the unsegmented line; `expected_bies`, the BIES
+ /// sequence the model is expected to produce; and `true_bies`, the BIES sequence representing the true segmentation.
+ #[derive(PartialEq, Debug, Deserialize)]
+ struct TestCase {
+ unseg: String,
+ expected_bies: String,
+ true_bies: String,
+ }
+
+ /// `TestTextData` is a struct to store a vector of `TestCase` that represents a test text.
+ #[derive(PartialEq, Debug, Deserialize)]
+ struct TestTextData {
+ testcases: Vec<TestCase>,
+ }
+
+ #[derive(Debug)]
+ struct TestText {
+ data: TestTextData,
+ }
+
+ fn load_test_text(filename: &str) -> TestTextData {
+ let file = File::open(filename).expect("File should be present");
+ let reader = BufReader::new(file);
+ serde_json::from_reader(reader).expect("JSON syntax error")
+ }
+
+ #[test]
+ fn segment_file_by_lstm() {
+ let lstm: DataPayload<LstmForWordLineAutoV1Marker> = crate::provider::Baked
+ .load(DataRequest {
+ locale: &locale!("th").into(),
+ metadata: Default::default(),
+ })
+ .unwrap()
+ .take_payload()
+ .unwrap();
+ let lstm = LstmSegmenter::new(
+ lstm.get(),
+ crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
+ );
+
+ // Importing the test data
+ let test_text_data = load_test_text(&format!(
+ "tests/testdata/test_text_{}.json",
+ if lstm.grapheme.is_some() {
+ "grapheme"
+ } else {
+ "codepoints"
+ }
+ ));
+ let test_text = TestText {
+ data: test_text_data,
+ };
+
+ // Testing
+ for test_case in &test_text.data.testcases {
+ let lstm_output = lstm
+ .segment_str_p(&test_case.unseg)
+ .bies
+ .map(|is_e| if is_e { 'e' } else { '?' })
+ .collect::<String>();
+ println!("Test case : {}", test_case.unseg);
+ println!("Expected bies : {}", test_case.expected_bies);
+ println!("Estimated bies : {lstm_output}");
+ println!("True bies : {}", test_case.true_bies);
+ println!("****************************************************");
+ assert_eq!(
+ test_case.expected_bies.replace(['b', 'i', 's'], "?"),
+ lstm_output
+ );
+ }
+ }
+}
diff --git a/third_party/rust/icu_segmenter/src/complex/mod.rs b/third_party/rust/icu_segmenter/src/complex/mod.rs
new file mode 100644
index 0000000000..65f49a92f0
--- /dev/null
+++ b/third_party/rust/icu_segmenter/src/complex/mod.rs
@@ -0,0 +1,440 @@
+// This file is part of ICU4X. For terms of use, please see the file
+// called LICENSE at the top level of the ICU4X source tree
+// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
+
+use crate::provider::*;
+use alloc::vec::Vec;
+use icu_locid::{locale, Locale};
+use icu_provider::prelude::*;
+
+mod dictionary;
+use dictionary::*;
+mod language;
+use language::*;
+#[cfg(feature = "lstm")]
+mod lstm;
+#[cfg(feature = "lstm")]
+use lstm::*;
+
+#[cfg(not(feature = "lstm"))]
+type DictOrLstm = Result<DataPayload<UCharDictionaryBreakDataV1Marker>, core::convert::Infallible>;
+#[cfg(not(feature = "lstm"))]
+type DictOrLstmBorrowed<'a> =
+ Result<&'a DataPayload<UCharDictionaryBreakDataV1Marker>, &'a core::convert::Infallible>;
+
+#[cfg(feature = "lstm")]
+type DictOrLstm =
+ Result<DataPayload<UCharDictionaryBreakDataV1Marker>, DataPayload<LstmDataV1Marker>>;
+#[cfg(feature = "lstm")]
+type DictOrLstmBorrowed<'a> =
+ Result<&'a DataPayload<UCharDictionaryBreakDataV1Marker>, &'a DataPayload<LstmDataV1Marker>>;
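+
+// Note: in both configurations the `Ok` variant carries dictionary data and
+// the `Err` variant carries LSTM data (or `Infallible` when the "lstm"
+// feature is disabled); `complex_language_segment_str` and
+// `complex_language_segment_utf16` below dispatch on this to pick a
+// `DictionarySegmenter` or an `LstmSegmenter` per language run.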
+
+#[derive(Debug)]
+pub(crate) struct ComplexPayloads {
+ grapheme: DataPayload<GraphemeClusterBreakDataV1Marker>,
+ my: Option<DictOrLstm>,
+ km: Option<DictOrLstm>,
+ lo: Option<DictOrLstm>,
+ th: Option<DictOrLstm>,
+ ja: Option<DataPayload<UCharDictionaryBreakDataV1Marker>>,
+}
+
+impl ComplexPayloads {
+ fn select(&self, language: Language) -> Option<DictOrLstmBorrowed> {
+ const ERR: DataError = DataError::custom("No segmentation model for language");
+ match language {
+ Language::Burmese => self.my.as_ref().map(Result::as_ref).or_else(|| {
+ ERR.with_display_context("my");
+ None
+ }),
+ Language::Khmer => self.km.as_ref().map(Result::as_ref).or_else(|| {
+ ERR.with_display_context("km");
+ None
+ }),
+ Language::Lao => self.lo.as_ref().map(Result::as_ref).or_else(|| {
+ ERR.with_display_context("lo");
+ None
+ }),
+ Language::Thai => self.th.as_ref().map(Result::as_ref).or_else(|| {
+ ERR.with_display_context("th");
+ None
+ }),
+ Language::ChineseOrJapanese => self.ja.as_ref().map(Ok).or_else(|| {
+ ERR.with_display_context("ja");
+ None
+ }),
+ Language::Unknown => None,
+ }
+ }
+
+ #[cfg(feature = "lstm")]
+ #[cfg(feature = "compiled_data")]
+ pub(crate) fn new_lstm() -> Self {
+ #[allow(clippy::unwrap_used)]
+ // try_load is infallible if the provider only returns `MissingLocale`.
+ Self {
+ grapheme: DataPayload::from_static_ref(
+ crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
+ ),
+ my: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("my"))
+ .unwrap()
+ .map(DataPayload::cast)
+ .map(Err),
+ km: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("km"))
+ .unwrap()
+ .map(DataPayload::cast)
+ .map(Err),
+ lo: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("lo"))
+ .unwrap()
+ .map(DataPayload::cast)
+ .map(Err),
+ th: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("th"))
+ .unwrap()
+ .map(DataPayload::cast)
+ .map(Err),
+ ja: None,
+ }
+ }
+
+ #[cfg(feature = "lstm")]
+ pub(crate) fn try_new_lstm<D>(provider: &D) -> Result<Self, DataError>
+ where
+ D: DataProvider<GraphemeClusterBreakDataV1Marker>
+ + DataProvider<LstmForWordLineAutoV1Marker>
+ + ?Sized,
+ {
+ Ok(Self {
+ grapheme: provider.load(Default::default())?.take_payload()?,
+ my: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("my"))?
+ .map(DataPayload::cast)
+ .map(Err),
+ km: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("km"))?
+ .map(DataPayload::cast)
+ .map(Err),
+ lo: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("lo"))?
+ .map(DataPayload::cast)
+ .map(Err),
+ th: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("th"))?
+ .map(DataPayload::cast)
+ .map(Err),
+ ja: None,
+ })
+ }
+
+ #[cfg(feature = "compiled_data")]
+ pub(crate) fn new_dict() -> Self {
+ #[allow(clippy::unwrap_used)]
+ // try_load is infallible if the provider only returns `MissingLocale`.
+ Self {
+ grapheme: DataPayload::from_static_ref(
+ crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
+ ),
+ my: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
+ &crate::provider::Baked,
+ locale!("my"),
+ )
+ .unwrap()
+ .map(DataPayload::cast)
+ .map(Ok),
+ km: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
+ &crate::provider::Baked,
+ locale!("km"),
+ )
+ .unwrap()
+ .map(DataPayload::cast)
+ .map(Ok),
+ lo: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
+ &crate::provider::Baked,
+ locale!("lo"),
+ )
+ .unwrap()
+ .map(DataPayload::cast)
+ .map(Ok),
+ th: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
+ &crate::provider::Baked,
+ locale!("th"),
+ )
+ .unwrap()
+ .map(DataPayload::cast)
+ .map(Ok),
+ ja: try_load::<DictionaryForWordOnlyAutoV1Marker, _>(
+ &crate::provider::Baked,
+ locale!("ja"),
+ )
+ .unwrap()
+ .map(DataPayload::cast),
+ }
+ }
+
+ pub(crate) fn try_new_dict<D>(provider: &D) -> Result<Self, DataError>
+ where
+ D: DataProvider<GraphemeClusterBreakDataV1Marker>
+ + DataProvider<DictionaryForWordLineExtendedV1Marker>
+ + DataProvider<DictionaryForWordOnlyAutoV1Marker>
+ + ?Sized,
+ {
+ Ok(Self {
+ grapheme: provider.load(Default::default())?.take_payload()?,
+ my: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, locale!("my"))?
+ .map(DataPayload::cast)
+ .map(Ok),
+ km: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, locale!("km"))?
+ .map(DataPayload::cast)
+ .map(Ok),
+ lo: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, locale!("lo"))?
+ .map(DataPayload::cast)
+ .map(Ok),
+ th: try_load::<DictionaryForWordLineExtendedV1Marker, D>(provider, locale!("th"))?
+ .map(DataPayload::cast)
+ .map(Ok),
+ ja: try_load::<DictionaryForWordOnlyAutoV1Marker, D>(provider, locale!("ja"))?
+ .map(DataPayload::cast),
+ })
+ }
+
+ #[cfg(feature = "auto")] // Use by WordSegmenter with "auto" enabled.
+ #[cfg(feature = "compiled_data")]
+ pub(crate) fn new_auto() -> Self {
+ #[allow(clippy::unwrap_used)]
+ // try_load is infallible if the provider only returns `MissingLocale`.
+ Self {
+ grapheme: DataPayload::from_static_ref(
+ crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
+ ),
+ my: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("my"))
+ .unwrap()
+ .map(DataPayload::cast)
+ .map(Err),
+ km: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("km"))
+ .unwrap()
+ .map(DataPayload::cast)
+ .map(Err),
+ lo: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("lo"))
+ .unwrap()
+ .map(DataPayload::cast)
+ .map(Err),
+ th: try_load::<LstmForWordLineAutoV1Marker, _>(&crate::provider::Baked, locale!("th"))
+ .unwrap()
+ .map(DataPayload::cast)
+ .map(Err),
+ ja: try_load::<DictionaryForWordOnlyAutoV1Marker, _>(
+ &crate::provider::Baked,
+ locale!("ja"),
+ )
+ .unwrap()
+ .map(DataPayload::cast),
+ }
+ }
+
+ #[cfg(feature = "auto")] // Use by WordSegmenter with "auto" enabled.
+ pub(crate) fn try_new_auto<D>(provider: &D) -> Result<Self, DataError>
+ where
+ D: DataProvider<GraphemeClusterBreakDataV1Marker>
+ + DataProvider<LstmForWordLineAutoV1Marker>
+ + DataProvider<DictionaryForWordOnlyAutoV1Marker>
+ + ?Sized,
+ {
+ Ok(Self {
+ grapheme: provider.load(Default::default())?.take_payload()?,
+ my: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("my"))?
+ .map(DataPayload::cast)
+ .map(Err),
+ km: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("km"))?
+ .map(DataPayload::cast)
+ .map(Err),
+ lo: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("lo"))?
+ .map(DataPayload::cast)
+ .map(Err),
+ th: try_load::<LstmForWordLineAutoV1Marker, D>(provider, locale!("th"))?
+ .map(DataPayload::cast)
+ .map(Err),
+ ja: try_load::<DictionaryForWordOnlyAutoV1Marker, D>(provider, locale!("ja"))?
+ .map(DataPayload::cast),
+ })
+ }
+
+ #[cfg(feature = "compiled_data")]
+ pub(crate) fn new_southeast_asian() -> Self {
+ #[allow(clippy::unwrap_used)]
+ // try_load is infallible if the provider only returns `MissingLocale`.
+ Self {
+ grapheme: DataPayload::from_static_ref(
+ crate::provider::Baked::SINGLETON_SEGMENTER_GRAPHEME_V1,
+ ),
+ my: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
+ &crate::provider::Baked,
+ locale!("my"),
+ )
+ .unwrap()
+ .map(DataPayload::cast)
+ .map(Ok),
+ km: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
+ &crate::provider::Baked,
+ locale!("km"),
+ )
+ .unwrap()
+ .map(DataPayload::cast)
+ .map(Ok),
+ lo: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
+ &crate::provider::Baked,
+ locale!("lo"),
+ )
+ .unwrap()
+ .map(DataPayload::cast)
+ .map(Ok),
+ th: try_load::<DictionaryForWordLineExtendedV1Marker, _>(
+ &crate::provider::Baked,
+ locale!("th"),
+ )
+ .unwrap()
+ .map(DataPayload::cast)
+ .map(Ok),
+ ja: None,
+ }
+ }
+
+ pub(crate) fn try_new_southeast_asian<D>(provider: &D) -> Result<Self, DataError>
+ where
+ D: DataProvider<DictionaryForWordLineExtendedV1Marker>
+ + DataProvider<GraphemeClusterBreakDataV1Marker>
+ + ?Sized,
+ {
+ Ok(Self {
+ grapheme: provider.load(Default::default())?.take_payload()?,
+ my: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, locale!("my"))?
+ .map(DataPayload::cast)
+ .map(Ok),
+ km: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, locale!("km"))?
+ .map(DataPayload::cast)
+ .map(Ok),
+ lo: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, locale!("lo"))?
+ .map(DataPayload::cast)
+ .map(Ok),
+ th: try_load::<DictionaryForWordLineExtendedV1Marker, _>(provider, locale!("th"))?
+ .map(DataPayload::cast)
+ .map(Ok),
+ ja: None,
+ })
+ }
+}
+
+fn try_load<M: KeyedDataMarker, P: DataProvider<M> + ?Sized>(
+ provider: &P,
+ locale: Locale,
+) -> Result<Option<DataPayload<M>>, DataError> {
+ match provider.load(DataRequest {
+ locale: &DataLocale::from(locale),
+ metadata: {
+ let mut m = DataRequestMetadata::default();
+ m.silent = true;
+ m
+ },
+ }) {
+ Ok(response) => Ok(Some(response.take_payload()?)),
+ Err(DataError {
+ kind: DataErrorKind::MissingLocale,
+ ..
+ }) => Ok(None),
+ Err(e) => Err(e),
+ }
+}
+
+/// Returns the UTF-16 segment offset array, using either the dictionary or the LSTM segmenter.
+pub(crate) fn complex_language_segment_utf16(
+ payloads: &ComplexPayloads,
+ input: &[u16],
+) -> Vec<usize> {
+ let mut result = Vec::new();
+ let mut offset = 0;
+ for (slice, lang) in LanguageIteratorUtf16::new(input) {
+ match payloads.select(lang) {
+ Some(Ok(dict)) => {
+ result.extend(
+ DictionarySegmenter::new(dict.get(), payloads.grapheme.get())
+ .segment_utf16(slice)
+ .map(|n| offset + n),
+ );
+ }
+ #[cfg(feature = "lstm")]
+ Some(Err(lstm)) => {
+ result.extend(
+ LstmSegmenter::new(lstm.get(), payloads.grapheme.get())
+ .segment_utf16(slice)
+ .map(|n| offset + n),
+ );
+ }
+ #[cfg(not(feature = "lstm"))]
+ Some(Err(_infallible)) => {} // unreachable: `Infallible` has no values
+ None => {
+ result.push(offset + slice.len());
+ }
+ }
+ offset += slice.len();
+ }
+ result
+}
+
+/// Returns the UTF-8 segment offset array, using either the dictionary or the LSTM segmenter.
+pub(crate) fn complex_language_segment_str(payloads: &ComplexPayloads, input: &str) -> Vec<usize> {
+ let mut result = Vec::new();
+ let mut offset = 0;
+ for (slice, lang) in LanguageIterator::new(input) {
+ match payloads.select(lang) {
+ Some(Ok(dict)) => {
+ result.extend(
+ DictionarySegmenter::new(dict.get(), payloads.grapheme.get())
+ .segment_str(slice)
+ .map(|n| offset + n),
+ );
+ }
+ #[cfg(feature = "lstm")]
+ Some(Err(lstm)) => {
+ result.extend(
+ LstmSegmenter::new(lstm.get(), payloads.grapheme.get())
+ .segment_str(slice)
+ .map(|n| offset + n),
+ );
+ }
+ #[cfg(not(feature = "lstm"))]
+ Some(Err(_infallible)) => {} // unreachable: `Infallible` has no values
+ None => {
+ result.push(offset + slice.len());
+ }
+ }
+ offset += slice.len();
+ }
+ result
+}
+
+#[cfg(test)]
+#[cfg(feature = "serde")]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn thai_word_break() {
+ const TEST_STR: &str = "ภาษาไทยภาษาไทย";
+ let utf16: Vec<u16> = TEST_STR.encode_utf16().collect();
+
+ let lstm = ComplexPayloads::new_lstm();
+ let dict = ComplexPayloads::new_dict();
+
+ assert_eq!(
+ complex_language_segment_str(&lstm, TEST_STR),
+ [12, 21, 33, 42]
+ );
+ assert_eq!(
+ complex_language_segment_utf16(&lstm, &utf16),
+ [4, 7, 11, 14]
+ );
+
+ assert_eq!(
+ complex_language_segment_str(&dict, TEST_STR),
+ [12, 21, 33, 42]
+ );
+ assert_eq!(
+ complex_language_segment_utf16(&dict, &utf16),
+ [4, 7, 11, 14]
+ );
+ }
+}