summaryrefslogtreecommitdiffstats
path: root/intl/icu/source/common/lstmbe.cpp
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 01:47:29 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 01:47:29 +0000
commit0ebf5bdf043a27fd3dfb7f92e0cb63d88954c44d (patch)
treea31f07c9bcca9d56ce61e9a1ffd30ef350d513aa /intl/icu/source/common/lstmbe.cpp
parentInitial commit. (diff)
downloadfirefox-esr-0ebf5bdf043a27fd3dfb7f92e0cb63d88954c44d.tar.xz
firefox-esr-0ebf5bdf043a27fd3dfb7f92e0cb63d88954c44d.zip
Adding upstream version 115.8.0esr.upstream/115.8.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'intl/icu/source/common/lstmbe.cpp')
-rw-r--r--intl/icu/source/common/lstmbe.cpp856
1 files changed, 856 insertions, 0 deletions
diff --git a/intl/icu/source/common/lstmbe.cpp b/intl/icu/source/common/lstmbe.cpp
new file mode 100644
index 0000000000..fb8eb01761
--- /dev/null
+++ b/intl/icu/source/common/lstmbe.cpp
@@ -0,0 +1,856 @@
+// © 2021 and later: Unicode, Inc. and others.
+// License & terms of use: http://www.unicode.org/copyright.html
+
+#include <complex>
+#include <utility>
+
+#include "unicode/utypes.h"
+
+#if !UCONFIG_NO_BREAK_ITERATION
+
+#include "brkeng.h"
+#include "charstr.h"
+#include "cmemory.h"
+#include "lstmbe.h"
+#include "putilimp.h"
+#include "uassert.h"
+#include "ubrkimpl.h"
+#include "uresimp.h"
+#include "uvectr32.h"
+#include "uvector.h"
+
+#include "unicode/brkiter.h"
+#include "unicode/resbund.h"
+#include "unicode/ubrk.h"
+#include "unicode/uniset.h"
+#include "unicode/ustring.h"
+#include "unicode/utf.h"
+
+U_NAMESPACE_BEGIN
+
+// Uncomment the following #define to debug.
+// #define LSTM_DEBUG 1
+// #define LSTM_VECTORIZER_DEBUG 1
+
+/**
+ * Interface for reading 1D array.
+ */
+class ReadArray1D {
+public:
+ virtual ~ReadArray1D();
+ virtual int32_t d1() const = 0;
+ virtual float get(int32_t i) const = 0;
+
+#ifdef LSTM_DEBUG
+ void print() const {
+ printf("\n[");
+ for (int32_t i = 0; i < d1(); i++) {
+ printf("%0.8e ", get(i));
+ if (i % 4 == 3) printf("\n");
+ }
+ printf("]\n");
+ }
+#endif
+};
+
+ReadArray1D::~ReadArray1D()
+{
+}
+
+/**
+ * Interface for reading 2D array.
+ */
+class ReadArray2D {
+public:
+ virtual ~ReadArray2D();
+ virtual int32_t d1() const = 0;
+ virtual int32_t d2() const = 0;
+ virtual float get(int32_t i, int32_t j) const = 0;
+};
+
+ReadArray2D::~ReadArray2D()
+{
+}
+
+/**
+ * A class to index a float array as a 1D Array without owning the pointer or
+ * copy the data.
+ */
+class ConstArray1D : public ReadArray1D {
+public:
+ ConstArray1D() : data_(nullptr), d1_(0) {}
+
+ ConstArray1D(const float* data, int32_t d1) : data_(data), d1_(d1) {}
+
+ virtual ~ConstArray1D();
+
+ // Init the object, the object does not own the data nor copy.
+ // It is designed to directly use data from memory mapped resources.
+ void init(const int32_t* data, int32_t d1) {
+ U_ASSERT(IEEE_754 == 1);
+ data_ = reinterpret_cast<const float*>(data);
+ d1_ = d1;
+ }
+
+ // ReadArray1D methods.
+ virtual int32_t d1() const override { return d1_; }
+ virtual float get(int32_t i) const override {
+ U_ASSERT(i < d1_);
+ return data_[i];
+ }
+
+private:
+ const float* data_;
+ int32_t d1_;
+};
+
+ConstArray1D::~ConstArray1D()
+{
+}
+
+/**
+ * A class to index a float array as a 2D Array without owning the pointer or
+ * copy the data.
+ */
+class ConstArray2D : public ReadArray2D {
+public:
+ ConstArray2D() : data_(nullptr), d1_(0), d2_(0) {}
+
+ ConstArray2D(const float* data, int32_t d1, int32_t d2)
+ : data_(data), d1_(d1), d2_(d2) {}
+
+ virtual ~ConstArray2D();
+
+ // Init the object, the object does not own the data nor copy.
+ // It is designed to directly use data from memory mapped resources.
+ void init(const int32_t* data, int32_t d1, int32_t d2) {
+ U_ASSERT(IEEE_754 == 1);
+ data_ = reinterpret_cast<const float*>(data);
+ d1_ = d1;
+ d2_ = d2;
+ }
+
+ // ReadArray2D methods.
+ inline int32_t d1() const override { return d1_; }
+ inline int32_t d2() const override { return d2_; }
+ float get(int32_t i, int32_t j) const override {
+ U_ASSERT(i < d1_);
+ U_ASSERT(j < d2_);
+ return data_[i * d2_ + j];
+ }
+
+ // Expose the ith row as a ConstArray1D
+ inline ConstArray1D row(int32_t i) const {
+ U_ASSERT(i < d1_);
+ return ConstArray1D(data_ + i * d2_, d2_);
+ }
+
+private:
+ const float* data_;
+ int32_t d1_;
+ int32_t d2_;
+};
+
+ConstArray2D::~ConstArray2D()
+{
+}
+
+/**
+ * A class to allocate data as a writable 1D array.
+ * This is the main class implement matrix operation.
+ */
+class Array1D : public ReadArray1D {
+public:
+ Array1D() : memory_(nullptr), data_(nullptr), d1_(0) {}
+ Array1D(int32_t d1, UErrorCode &status)
+ : memory_(uprv_malloc(d1 * sizeof(float))),
+ data_((float*)memory_), d1_(d1) {
+ if (U_SUCCESS(status)) {
+ if (memory_ == nullptr) {
+ status = U_MEMORY_ALLOCATION_ERROR;
+ return;
+ }
+ clear();
+ }
+ }
+
+ virtual ~Array1D();
+
+ // A special constructor which does not own the memory but writeable
+ // as a slice of an array.
+ Array1D(float* data, int32_t d1)
+ : memory_(nullptr), data_(data), d1_(d1) {}
+
+ // ReadArray1D methods.
+ virtual int32_t d1() const override { return d1_; }
+ virtual float get(int32_t i) const override {
+ U_ASSERT(i < d1_);
+ return data_[i];
+ }
+
+ // Return the index which point to the max data in the array.
+ inline int32_t maxIndex() const {
+ int32_t index = 0;
+ float max = data_[0];
+ for (int32_t i = 1; i < d1_; i++) {
+ if (data_[i] > max) {
+ max = data_[i];
+ index = i;
+ }
+ }
+ return index;
+ }
+
+ // Slice part of the array to a new one.
+ inline Array1D slice(int32_t from, int32_t size) const {
+ U_ASSERT(from >= 0);
+ U_ASSERT(from < d1_);
+ U_ASSERT(from + size <= d1_);
+ return Array1D(data_ + from, size);
+ }
+
+ // Add dot product of a 1D array and a 2D array into this one.
+ inline Array1D& addDotProduct(const ReadArray1D& a, const ReadArray2D& b) {
+ U_ASSERT(a.d1() == b.d1());
+ U_ASSERT(b.d2() == d1());
+ for (int32_t i = 0; i < d1(); i++) {
+ for (int32_t j = 0; j < a.d1(); j++) {
+ data_[i] += a.get(j) * b.get(j, i);
+ }
+ }
+ return *this;
+ }
+
+ // Hadamard Product the values of another array of the same size into this one.
+ inline Array1D& hadamardProduct(const ReadArray1D& a) {
+ U_ASSERT(a.d1() == d1());
+ for (int32_t i = 0; i < d1(); i++) {
+ data_[i] *= a.get(i);
+ }
+ return *this;
+ }
+
+ // Add the Hadamard Product of two arrays of the same size into this one.
+ inline Array1D& addHadamardProduct(const ReadArray1D& a, const ReadArray1D& b) {
+ U_ASSERT(a.d1() == d1());
+ U_ASSERT(b.d1() == d1());
+ for (int32_t i = 0; i < d1(); i++) {
+ data_[i] += a.get(i) * b.get(i);
+ }
+ return *this;
+ }
+
+ // Add the values of another array of the same size into this one.
+ inline Array1D& add(const ReadArray1D& a) {
+ U_ASSERT(a.d1() == d1());
+ for (int32_t i = 0; i < d1(); i++) {
+ data_[i] += a.get(i);
+ }
+ return *this;
+ }
+
+ // Assign the values of another array of the same size into this one.
+ inline Array1D& assign(const ReadArray1D& a) {
+ U_ASSERT(a.d1() == d1());
+ for (int32_t i = 0; i < d1(); i++) {
+ data_[i] = a.get(i);
+ }
+ return *this;
+ }
+
+ // Apply tanh to all the elements in the array.
+ inline Array1D& tanh() {
+ return tanh(*this);
+ }
+
+ // Apply tanh of a and store into this array.
+ inline Array1D& tanh(const Array1D& a) {
+ U_ASSERT(a.d1() == d1());
+ for (int32_t i = 0; i < d1_; i++) {
+ data_[i] = std::tanh(a.get(i));
+ }
+ return *this;
+ }
+
+ // Apply sigmoid to all the elements in the array.
+ inline Array1D& sigmoid() {
+ for (int32_t i = 0; i < d1_; i++) {
+ data_[i] = 1.0f/(1.0f + expf(-data_[i]));
+ }
+ return *this;
+ }
+
+ inline Array1D& clear() {
+ uprv_memset(data_, 0, d1_ * sizeof(float));
+ return *this;
+ }
+
+private:
+ void* memory_;
+ float* data_;
+ int32_t d1_;
+};
+
+Array1D::~Array1D()
+{
+ uprv_free(memory_);
+}
+
+class Array2D : public ReadArray2D {
+public:
+ Array2D() : memory_(nullptr), data_(nullptr), d1_(0), d2_(0) {}
+ Array2D(int32_t d1, int32_t d2, UErrorCode &status)
+ : memory_(uprv_malloc(d1 * d2 * sizeof(float))),
+ data_((float*)memory_), d1_(d1), d2_(d2) {
+ if (U_SUCCESS(status)) {
+ if (memory_ == nullptr) {
+ status = U_MEMORY_ALLOCATION_ERROR;
+ return;
+ }
+ clear();
+ }
+ }
+ virtual ~Array2D();
+
+ // ReadArray2D methods.
+ virtual int32_t d1() const override { return d1_; }
+ virtual int32_t d2() const override { return d2_; }
+ virtual float get(int32_t i, int32_t j) const override {
+ U_ASSERT(i < d1_);
+ U_ASSERT(j < d2_);
+ return data_[i * d2_ + j];
+ }
+
+ inline Array1D row(int32_t i) const {
+ U_ASSERT(i < d1_);
+ return Array1D(data_ + i * d2_, d2_);
+ }
+
+ inline Array2D& clear() {
+ uprv_memset(data_, 0, d1_ * d2_ * sizeof(float));
+ return *this;
+ }
+
+private:
+ void* memory_;
+ float* data_;
+ int32_t d1_;
+ int32_t d2_;
+};
+
+Array2D::~Array2D()
+{
+ uprv_free(memory_);
+}
+
+typedef enum {
+ BEGIN,
+ INSIDE,
+ END,
+ SINGLE
+} LSTMClass;
+
+typedef enum {
+ UNKNOWN,
+ CODE_POINTS,
+ GRAPHEME_CLUSTER,
+} EmbeddingType;
+
+struct LSTMData : public UMemory {
+ LSTMData(UResourceBundle* rb, UErrorCode &status);
+ ~LSTMData();
+ UHashtable* fDict;
+ EmbeddingType fType;
+ const char16_t* fName;
+ ConstArray2D fEmbedding;
+ ConstArray2D fForwardW;
+ ConstArray2D fForwardU;
+ ConstArray1D fForwardB;
+ ConstArray2D fBackwardW;
+ ConstArray2D fBackwardU;
+ ConstArray1D fBackwardB;
+ ConstArray2D fOutputW;
+ ConstArray1D fOutputB;
+
+private:
+ UResourceBundle* fBundle;
+};
+
+LSTMData::LSTMData(UResourceBundle* rb, UErrorCode &status)
+ : fDict(nullptr), fType(UNKNOWN), fName(nullptr),
+ fBundle(rb)
+{
+ if (U_FAILURE(status)) {
+ return;
+ }
+ if (IEEE_754 != 1) {
+ status = U_UNSUPPORTED_ERROR;
+ return;
+ }
+ LocalUResourceBundlePointer embeddings_res(
+ ures_getByKey(rb, "embeddings", nullptr, &status));
+ int32_t embedding_size = ures_getInt(embeddings_res.getAlias(), &status);
+ LocalUResourceBundlePointer hunits_res(
+ ures_getByKey(rb, "hunits", nullptr, &status));
+ if (U_FAILURE(status)) return;
+ int32_t hunits = ures_getInt(hunits_res.getAlias(), &status);
+ const char16_t* type = ures_getStringByKey(rb, "type", nullptr, &status);
+ if (U_FAILURE(status)) return;
+ if (u_strCompare(type, -1, u"codepoints", -1, false) == 0) {
+ fType = CODE_POINTS;
+ } else if (u_strCompare(type, -1, u"graphclust", -1, false) == 0) {
+ fType = GRAPHEME_CLUSTER;
+ }
+ fName = ures_getStringByKey(rb, "model", nullptr, &status);
+ LocalUResourceBundlePointer dataRes(ures_getByKey(rb, "data", nullptr, &status));
+ if (U_FAILURE(status)) return;
+ int32_t data_len = 0;
+ const int32_t* data = ures_getIntVector(dataRes.getAlias(), &data_len, &status);
+ fDict = uhash_open(uhash_hashUChars, uhash_compareUChars, nullptr, &status);
+
+ StackUResourceBundle stackTempBundle;
+ ResourceDataValue value;
+ ures_getValueWithFallback(rb, "dict", stackTempBundle.getAlias(), value, status);
+ ResourceArray stringArray = value.getArray(status);
+ int32_t num_index = stringArray.getSize();
+ if (U_FAILURE(status)) { return; }
+
+ // put dict into hash
+ int32_t stringLength;
+ for (int32_t idx = 0; idx < num_index; idx++) {
+ stringArray.getValue(idx, value);
+ const char16_t* str = value.getString(stringLength, status);
+ uhash_putiAllowZero(fDict, (void*)str, idx, &status);
+ if (U_FAILURE(status)) return;
+#ifdef LSTM_VECTORIZER_DEBUG
+ printf("Assign [");
+ while (*str != 0x0000) {
+ printf("U+%04x ", *str);
+ str++;
+ }
+ printf("] map to %d\n", idx-1);
+#endif
+ }
+ int32_t mat1_size = (num_index + 1) * embedding_size;
+ int32_t mat2_size = embedding_size * 4 * hunits;
+ int32_t mat3_size = hunits * 4 * hunits;
+ int32_t mat4_size = 4 * hunits;
+ int32_t mat5_size = mat2_size;
+ int32_t mat6_size = mat3_size;
+ int32_t mat7_size = mat4_size;
+ int32_t mat8_size = 2 * hunits * 4;
+#if U_DEBUG
+ int32_t mat9_size = 4;
+ U_ASSERT(data_len == mat1_size + mat2_size + mat3_size + mat4_size + mat5_size +
+ mat6_size + mat7_size + mat8_size + mat9_size);
+#endif
+
+ fEmbedding.init(data, (num_index + 1), embedding_size);
+ data += mat1_size;
+ fForwardW.init(data, embedding_size, 4 * hunits);
+ data += mat2_size;
+ fForwardU.init(data, hunits, 4 * hunits);
+ data += mat3_size;
+ fForwardB.init(data, 4 * hunits);
+ data += mat4_size;
+ fBackwardW.init(data, embedding_size, 4 * hunits);
+ data += mat5_size;
+ fBackwardU.init(data, hunits, 4 * hunits);
+ data += mat6_size;
+ fBackwardB.init(data, 4 * hunits);
+ data += mat7_size;
+ fOutputW.init(data, 2 * hunits, 4);
+ data += mat8_size;
+ fOutputB.init(data, 4);
+}
+
+LSTMData::~LSTMData() {
+ uhash_close(fDict);
+ ures_close(fBundle);
+}
+
+class Vectorizer : public UMemory {
+public:
+ Vectorizer(UHashtable* dict) : fDict(dict) {}
+ virtual ~Vectorizer();
+ virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
+ UVector32 &offsets, UVector32 &indices,
+ UErrorCode &status) const = 0;
+protected:
+ int32_t stringToIndex(const char16_t* str) const {
+ UBool found = false;
+ int32_t ret = uhash_getiAndFound(fDict, (const void*)str, &found);
+ if (!found) {
+ ret = fDict->count;
+ }
+#ifdef LSTM_VECTORIZER_DEBUG
+ printf("[");
+ while (*str != 0x0000) {
+ printf("U+%04x ", *str);
+ str++;
+ }
+ printf("] map to %d\n", ret);
+#endif
+ return ret;
+ }
+
+private:
+ UHashtable* fDict;
+};
+
+Vectorizer::~Vectorizer()
+{
+}
+
+class CodePointsVectorizer : public Vectorizer {
+public:
+ CodePointsVectorizer(UHashtable* dict) : Vectorizer(dict) {}
+ virtual ~CodePointsVectorizer();
+ virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
+ UVector32 &offsets, UVector32 &indices,
+ UErrorCode &status) const override;
+};
+
+CodePointsVectorizer::~CodePointsVectorizer()
+{
+}
+
+void CodePointsVectorizer::vectorize(
+ UText *text, int32_t startPos, int32_t endPos,
+ UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
+{
+ if (offsets.ensureCapacity(endPos - startPos, status) &&
+ indices.ensureCapacity(endPos - startPos, status)) {
+ if (U_FAILURE(status)) return;
+ utext_setNativeIndex(text, startPos);
+ int32_t current;
+ char16_t str[2] = {0, 0};
+ while (U_SUCCESS(status) &&
+ (current = (int32_t)utext_getNativeIndex(text)) < endPos) {
+ // Since the LSTMBreakEngine is currently only accept chars in BMP,
+ // we can ignore the possibility of hitting supplementary code
+ // point.
+ str[0] = (char16_t) utext_next32(text);
+ U_ASSERT(!U_IS_SURROGATE(str[0]));
+ offsets.addElement(current, status);
+ indices.addElement(stringToIndex(str), status);
+ }
+ }
+}
+
+class GraphemeClusterVectorizer : public Vectorizer {
+public:
+ GraphemeClusterVectorizer(UHashtable* dict)
+ : Vectorizer(dict)
+ {
+ }
+ virtual ~GraphemeClusterVectorizer();
+ virtual void vectorize(UText *text, int32_t startPos, int32_t endPos,
+ UVector32 &offsets, UVector32 &indices,
+ UErrorCode &status) const override;
+};
+
+GraphemeClusterVectorizer::~GraphemeClusterVectorizer()
+{
+}
+
+constexpr int32_t MAX_GRAPHEME_CLSTER_LENGTH = 10;
+
+void GraphemeClusterVectorizer::vectorize(
+ UText *text, int32_t startPos, int32_t endPos,
+ UVector32 &offsets, UVector32 &indices, UErrorCode &status) const
+{
+ if (U_FAILURE(status)) return;
+ if (!offsets.ensureCapacity(endPos - startPos, status) ||
+ !indices.ensureCapacity(endPos - startPos, status)) {
+ return;
+ }
+ if (U_FAILURE(status)) return;
+ LocalPointer<BreakIterator> graphemeIter(BreakIterator::createCharacterInstance(Locale(), status));
+ if (U_FAILURE(status)) return;
+ graphemeIter->setText(text, status);
+ if (U_FAILURE(status)) return;
+
+ if (startPos != 0) {
+ graphemeIter->preceding(startPos);
+ }
+ int32_t last = startPos;
+ int32_t current = startPos;
+ char16_t str[MAX_GRAPHEME_CLSTER_LENGTH];
+ while ((current = graphemeIter->next()) != BreakIterator::DONE) {
+ if (current >= endPos) {
+ break;
+ }
+ if (current > startPos) {
+ utext_extract(text, last, current, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
+ if (U_FAILURE(status)) return;
+ offsets.addElement(last, status);
+ indices.addElement(stringToIndex(str), status);
+ if (U_FAILURE(status)) return;
+ }
+ last = current;
+ }
+ if (U_FAILURE(status) || last >= endPos) {
+ return;
+ }
+ utext_extract(text, last, endPos, str, MAX_GRAPHEME_CLSTER_LENGTH, &status);
+ if (U_SUCCESS(status)) {
+ offsets.addElement(last, status);
+ indices.addElement(stringToIndex(str), status);
+ }
+}
+
+// Computing LSTM as stated in
+// https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate
+// ifco is temp array allocate outside which does not need to be
+// input/output value but could avoid unnecessary memory alloc/free if passing
+// in.
+void compute(
+ int32_t hunits,
+ const ReadArray2D& W, const ReadArray2D& U, const ReadArray1D& b,
+ const ReadArray1D& x, Array1D& h, Array1D& c,
+ Array1D& ifco)
+{
+ // ifco = x * W + h * U + b
+ ifco.assign(b)
+ .addDotProduct(x, W)
+ .addDotProduct(h, U);
+
+ ifco.slice(0*hunits, hunits).sigmoid(); // i: sigmod
+ ifco.slice(1*hunits, hunits).sigmoid(); // f: sigmoid
+ ifco.slice(2*hunits, hunits).tanh(); // c_: tanh
+ ifco.slice(3*hunits, hunits).sigmoid(); // o: sigmod
+
+ c.hadamardProduct(ifco.slice(hunits, hunits))
+ .addHadamardProduct(ifco.slice(0, hunits), ifco.slice(2*hunits, hunits));
+
+ h.tanh(c)
+ .hadamardProduct(ifco.slice(3*hunits, hunits));
+}
+
+// Minimum word size
+static const int32_t MIN_WORD = 2;
+
+// Minimum number of characters for two words
+static const int32_t MIN_WORD_SPAN = MIN_WORD * 2;
+
+int32_t
+LSTMBreakEngine::divideUpDictionaryRange( UText *text,
+ int32_t startPos,
+ int32_t endPos,
+ UVector32 &foundBreaks,
+ UBool /* isPhraseBreaking */,
+ UErrorCode& status) const {
+ if (U_FAILURE(status)) return 0;
+ int32_t beginFoundBreakSize = foundBreaks.size();
+ utext_setNativeIndex(text, startPos);
+ utext_moveIndex32(text, MIN_WORD_SPAN);
+ if (utext_getNativeIndex(text) >= endPos) {
+ return 0; // Not enough characters for two words
+ }
+ utext_setNativeIndex(text, startPos);
+
+ UVector32 offsets(status);
+ UVector32 indices(status);
+ if (U_FAILURE(status)) return 0;
+ fVectorizer->vectorize(text, startPos, endPos, offsets, indices, status);
+ if (U_FAILURE(status)) return 0;
+ int32_t* offsetsBuf = offsets.getBuffer();
+ int32_t* indicesBuf = indices.getBuffer();
+
+ int32_t input_seq_len = indices.size();
+ int32_t hunits = fData->fForwardU.d1();
+
+ // ----- Begin of all the Array memory allocation needed for this function
+ // Allocate temp array used inside compute()
+ Array1D ifco(4 * hunits, status);
+
+ Array1D c(hunits, status);
+ Array1D logp(4, status);
+
+ // TODO: limit size of hBackward. If input_seq_len is too big, we could
+ // run out of memory.
+ // Backward LSTM
+ Array2D hBackward(input_seq_len, hunits, status);
+
+ // Allocate fbRow and slice the internal array in two.
+ Array1D fbRow(2 * hunits, status);
+
+ // ----- End of all the Array memory allocation needed for this function
+ if (U_FAILURE(status)) return 0;
+
+ // To save the needed memory usage, the following is different from the
+ // Python or ICU4X implementation. We first perform the Backward LSTM
+ // and then merge the iteration of the forward LSTM and the output layer
+ // together because we only neetdto remember the h[t-1] for Forward LSTM.
+ for (int32_t i = input_seq_len - 1; i >= 0; i--) {
+ Array1D hRow = hBackward.row(i);
+ if (i != input_seq_len - 1) {
+ hRow.assign(hBackward.row(i+1));
+ }
+#ifdef LSTM_DEBUG
+ printf("hRow %d\n", i);
+ hRow.print();
+ printf("indicesBuf[%d] = %d\n", i, indicesBuf[i]);
+ printf("fData->fEmbedding.row(indicesBuf[%d]):\n", i);
+ fData->fEmbedding.row(indicesBuf[i]).print();
+#endif // LSTM_DEBUG
+ compute(hunits,
+ fData->fBackwardW, fData->fBackwardU, fData->fBackwardB,
+ fData->fEmbedding.row(indicesBuf[i]),
+ hRow, c, ifco);
+ }
+
+
+ Array1D forwardRow = fbRow.slice(0, hunits); // point to first half of data in fbRow.
+ Array1D backwardRow = fbRow.slice(hunits, hunits); // point to second half of data n fbRow.
+
+ // The following iteration merge the forward LSTM and the output layer
+ // together.
+ c.clear(); // reuse c since it is the same size.
+ for (int32_t i = 0; i < input_seq_len; i++) {
+#ifdef LSTM_DEBUG
+ printf("forwardRow %d\n", i);
+ forwardRow.print();
+#endif // LSTM_DEBUG
+ // Forward LSTM
+ // Calculate the result into forwardRow, which point to the data in the first half
+ // of fbRow.
+ compute(hunits,
+ fData->fForwardW, fData->fForwardU, fData->fForwardB,
+ fData->fEmbedding.row(indicesBuf[i]),
+ forwardRow, c, ifco);
+
+ // assign the data from hBackward.row(i) to second half of fbRowa.
+ backwardRow.assign(hBackward.row(i));
+
+ logp.assign(fData->fOutputB).addDotProduct(fbRow, fData->fOutputW);
+#ifdef LSTM_DEBUG
+ printf("backwardRow %d\n", i);
+ backwardRow.print();
+ printf("logp %d\n", i);
+ logp.print();
+#endif // LSTM_DEBUG
+
+ // current = argmax(logp)
+ LSTMClass current = (LSTMClass)logp.maxIndex();
+ // BIES logic.
+ if (current == BEGIN || current == SINGLE) {
+ if (i != 0) {
+ foundBreaks.addElement(offsetsBuf[i], status);
+ if (U_FAILURE(status)) return 0;
+ }
+ }
+ }
+ return foundBreaks.size() - beginFoundBreakSize;
+}
+
+Vectorizer* createVectorizer(const LSTMData* data, UErrorCode &status) {
+ if (U_FAILURE(status)) {
+ return nullptr;
+ }
+ switch (data->fType) {
+ case CODE_POINTS:
+ return new CodePointsVectorizer(data->fDict);
+ break;
+ case GRAPHEME_CLUSTER:
+ return new GraphemeClusterVectorizer(data->fDict);
+ break;
+ default:
+ break;
+ }
+ UPRV_UNREACHABLE_EXIT;
+}
+
+LSTMBreakEngine::LSTMBreakEngine(const LSTMData* data, const UnicodeSet& set, UErrorCode &status)
+ : DictionaryBreakEngine(), fData(data), fVectorizer(createVectorizer(fData, status))
+{
+ if (U_FAILURE(status)) {
+ fData = nullptr; // If failure, we should not delete fData in destructor because the caller will do so.
+ return;
+ }
+ setCharacters(set);
+}
+
+LSTMBreakEngine::~LSTMBreakEngine() {
+ delete fData;
+ delete fVectorizer;
+}
+
+const char16_t* LSTMBreakEngine::name() const {
+ return fData->fName;
+}
+
+UnicodeString defaultLSTM(UScriptCode script, UErrorCode& status) {
+ // open root from brkitr tree.
+ UResourceBundle *b = ures_open(U_ICUDATA_BRKITR, "", &status);
+ b = ures_getByKeyWithFallback(b, "lstm", b, &status);
+ UnicodeString result = ures_getUnicodeStringByKey(b, uscript_getShortName(script), &status);
+ ures_close(b);
+ return result;
+}
+
+U_CAPI const LSTMData* U_EXPORT2 CreateLSTMDataForScript(UScriptCode script, UErrorCode& status)
+{
+ if (script != USCRIPT_KHMER && script != USCRIPT_LAO && script != USCRIPT_MYANMAR && script != USCRIPT_THAI) {
+ return nullptr;
+ }
+ UnicodeString name = defaultLSTM(script, status);
+ if (U_FAILURE(status)) return nullptr;
+ CharString namebuf;
+ namebuf.appendInvariantChars(name, status).truncate(namebuf.lastIndexOf('.'));
+
+ LocalUResourceBundlePointer rb(
+ ures_openDirect(U_ICUDATA_BRKITR, namebuf.data(), &status));
+ if (U_FAILURE(status)) return nullptr;
+
+ return CreateLSTMData(rb.orphan(), status);
+}
+
+U_CAPI const LSTMData* U_EXPORT2 CreateLSTMData(UResourceBundle* rb, UErrorCode& status)
+{
+ return new LSTMData(rb, status);
+}
+
+U_CAPI const LanguageBreakEngine* U_EXPORT2
+CreateLSTMBreakEngine(UScriptCode script, const LSTMData* data, UErrorCode& status)
+{
+ UnicodeString unicodeSetString;
+ switch(script) {
+ case USCRIPT_THAI:
+ unicodeSetString = UnicodeString(u"[[:Thai:]&[:LineBreak=SA:]]");
+ break;
+ case USCRIPT_MYANMAR:
+ unicodeSetString = UnicodeString(u"[[:Mymr:]&[:LineBreak=SA:]]");
+ break;
+ default:
+ delete data;
+ return nullptr;
+ }
+ UnicodeSet unicodeSet;
+ unicodeSet.applyPattern(unicodeSetString, status);
+ const LanguageBreakEngine* engine = new LSTMBreakEngine(data, unicodeSet, status);
+ if (U_FAILURE(status) || engine == nullptr) {
+ if (engine != nullptr) {
+ delete engine;
+ } else {
+ status = U_MEMORY_ALLOCATION_ERROR;
+ }
+ return nullptr;
+ }
+ return engine;
+}
+
+U_CAPI void U_EXPORT2 DeleteLSTMData(const LSTMData* data)
+{
+ delete data;
+}
+
+U_CAPI const char16_t* U_EXPORT2 LSTMDataName(const LSTMData* data)
+{
+ return data->fName;
+}
+
+U_NAMESPACE_END
+
+#endif /* #if !UCONFIG_NO_BREAK_ITERATION */