diff options
Diffstat (limited to 'toolkit/components/translations/fasttext/fasttext.js')
-rw-r--r-- | toolkit/components/translations/fasttext/fasttext.js | 536 |
1 files changed, 536 insertions, 0 deletions
diff --git a/toolkit/components/translations/fasttext/fasttext.js b/toolkit/components/translations/fasttext/fasttext.js new file mode 100644 index 0000000000..a79dfeffa0 --- /dev/null +++ b/toolkit/components/translations/fasttext/fasttext.js @@ -0,0 +1,536 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +let fastTextModule; + +const _initFastTextModule = async function (wasmModule) { + try { + fastTextModule = await loadFastTextModule(wasmModule); + } catch(e) { + console.error(e); + } + return true +} + +let postRunFunc = null; +const addOnPostRun = function (func) { + postRunFunc = func; +}; + + +const loadFastText = (wasmModule) => { + _initFastTextModule(wasmModule).then((res) => { + if (postRunFunc) { + postRunFunc(); + } + }) +} +const thisModule = this; +const trainFileInWasmFs = 'train.txt'; +const testFileInWasmFs = 'test.txt'; +const modelFileInWasmFs = 'model.bin'; + +const getFloat32ArrayFromHeap = (len) => { + const dataBytes = len * Float32Array.BYTES_PER_ELEMENT; + const dataPtr = fastTextModule._malloc(dataBytes); + const dataHeap = new Uint8Array(fastTextModule.HEAPU8.buffer, + dataPtr, + dataBytes); + return { + 'ptr':dataHeap.byteOffset, + 'size':len, + 'buffer':dataHeap.buffer + }; +}; + +const heapToFloat32 = (r) => new Float32Array(r.buffer, r.ptr, r.size); + +class FastText { + constructor(fastTextModule) { + this.f = new fastTextModule.FastText(); + } + + /** + * loadModel + * + * Loads the model file from the specified url, and returns the + * corresponding `FastTextModel` object. + * + * @param {string} url + * the url of the model file. + * + * @return {Promise} promise object that resolves to a `FastTextModel` + * + */ + loadModel(url) { + const fetchFunc = (thisModule && thisModule.fetch) || fetch; + + const fastTextNative = this.f; + return new Promise(function(resolve, reject) { + fetchFunc(url).then(response => { + return response.arrayBuffer(); + }).then(bytes => { + const byteArray = new Uint8Array(bytes); + const FS = fastTextModule.FS; + FS.writeFile(modelFileInWasmFs, byteArray); + }).then(() => { + fastTextNative.loadModel(modelFileInWasmFs); + resolve(new FastTextModel(fastTextNative)); + }).catch(error => { + reject(error); + }); + }); + } + + loadModelBinary(buffer) { + const fastTextNative = this.f; + const byteArray = new Uint8Array(buffer); + const FS = fastTextModule.FS; + FS.writeFile(modelFileInWasmFs, byteArray); + fastTextNative.loadModel(modelFileInWasmFs); + return new FastTextModel(fastTextNative); + } + + _train(url, modelName, kwargs = {}, callback = null) { + const fetchFunc = (thisModule && thisModule.fetch) || fetch; + const fastTextNative = this.f; + + return new Promise(function(resolve, reject) { + fetchFunc(url).then(response => { + return response.arrayBuffer(); + }).then(bytes => { + const byteArray = new Uint8Array(bytes); + const FS = fastTextModule.FS; + FS.writeFile(trainFileInWasmFs, byteArray); + }).then(() => { + const argsList = ['lr', 'lrUpdateRate', 'dim', 'ws', 'epoch', + 'minCount', 'minCountLabel', 'neg', 'wordNgrams', 'loss', + 'model', 'bucket', 'minn', 'maxn', 't', 'label', 'verbose', + 'pretrainedVectors', 'saveOutput', 'seed', 'qout', 'retrain', + 'qnorm', 'cutoff', 'dsub', 'qnorm', 'autotuneValidationFile', + 'autotuneMetric', 'autotunePredictions', 'autotuneDuration', + 'autotuneModelSize']; + const args = new fastTextModule.Args(); + argsList.forEach(k => { + if (k in kwargs) { + args[k] = kwargs[k]; + } + }); + args.model = fastTextModule.ModelName[modelName]; + args.loss = ('loss' in kwargs) ? + fastTextModule.LossName[kwargs['loss']] : 'hs'; + args.thread = 1; + args.input = trainFileInWasmFs; + + fastTextNative.train(args, callback); + + resolve(new FastTextModel(fastTextNative)); + }).catch(error => { + reject(error); + }); + }); + } + + /** + * trainSupervised + * + * Downloads the input file from the specified url, trains a supervised + * model and returns a `FastTextModel` object. + * + * @param {string} url + * the url of the input file. + * The input file must must contain at least one label per line. For an + * example consult the example datasets which are part of the fastText + * repository such as the dataset pulled by classification-example.sh. + * + * @param {dict} kwargs + * train parameters. + * For example {'lr': 0.5, 'epoch': 5} + * + * @param {function} callback + * train callback function + * `callback` function is called regularly from the train loop: + * `callback(progress, loss, wordsPerSec, learningRate, eta)` + * + * @return {Promise} promise object that resolves to a `FastTextModel` + * + */ + trainSupervised(url, kwargs = {}, callback) { + const self = this; + return new Promise(function(resolve, reject) { + self._train(url, 'supervised', kwargs, callback).then(model => { + resolve(model); + }).catch(error => { + reject(error); + }); + }); + } + + /** + * trainUnsupervised + * + * Downloads the input file from the specified url, trains an unsupervised + * model and returns a `FastTextModel` object. + * + * @param {string} url + * the url of the input file. + * The input file must not contain any labels or use the specified label + * prefixunless it is ok for those words to be ignored. For an example + * consult the dataset pulled by the example script word-vector-example.sh + * which is part of the fastText repository. + * + * @param {string} modelName + * Model to be used for unsupervised learning. `cbow` or `skipgram`. + * + * @param {dict} kwargs + * train parameters. + * For example {'lr': 0.5, 'epoch': 5} + * + * @param {function} callback + * train callback function + * `callback` function is called regularly from the train loop: + * `callback(progress, loss, wordsPerSec, learningRate, eta)` + * + * @return {Promise} promise object that resolves to a `FastTextModel` + * + */ + trainUnsupervised(url, modelName, kwargs = {}, callback) { + const self = this; + return new Promise(function(resolve, reject) { + self._train(url, modelName, kwargs, callback).then(model => { + resolve(model); + }).catch(error => { + reject(error); + }); + }); + } + +} + + +class FastTextModel { + /** + * `FastTextModel` represents a trained model. + * + * @constructor + * + * @param {object} fastTextNative + * webassembly object that makes the bridge between js and C++ + */ + constructor(fastTextNative) { + this.f = fastTextNative; + } + + /** + * isQuant + * + * @return {bool} true if the model is quantized + * + */ + isQuant() { + return this.f.isQuant; + } + + /** + * getDimension + * + * @return {int} the dimension (size) of a lookup vector (hidden layer) + * + */ + getDimension() { + return this.f.args.dim; + } + + /** + * getWordVector + * + * @param {string} word + * + * @return {Float32Array} the vector representation of `word`. + * + */ + getWordVector(word) { + const b = getFloat32ArrayFromHeap(this.getDimension()); + this.f.getWordVector(b, word); + + return heapToFloat32(b); + } + + /** + * getSentenceVector + * + * @param {string} text + * + * @return {Float32Array} the vector representation of `text`. + * + */ + getSentenceVector(text) { + if (text.indexOf('\n') != -1) { + "sentence vector processes one line at a time (remove '\\n')"; + } + text += '\n'; + const b = getFloat32ArrayFromHeap(this.getDimension()); + this.f.getSentenceVector(b, text); + + return heapToFloat32(b); + } + + /** + * getNearestNeighbors + * + * returns the nearest `k` neighbors of `word`. + * + * @param {string} word + * @param {int} k + * + * @return {Array.<Pair.<number, string>>} + * words and their corresponding cosine similarities. + * + */ + getNearestNeighbors(word, k = 10) { + return this.f.getNN(word, k); + } + + /** + * getAnalogies + * + * returns the nearest `k` neighbors of the operation + * `wordA - wordB + wordC`. + * + * @param {string} wordA + * @param {string} wordB + * @param {string} wordC + * @param {int} k + * + * @return {Array.<Pair.<number, string>>} + * words and their corresponding cosine similarities + * + */ + getAnalogies(wordA, wordB, wordC, k) { + return this.f.getAnalogies(k, wordA, wordB, wordC); + } + + /** + * getWordId + * + * Given a word, get the word id within the dictionary. + * Returns -1 if word is not in the dictionary. + * + * @return {int} word id + * + */ + getWordId(word) { + return this.f.getWordId(word); + } + + /** + * getSubwordId + * + * Given a subword, return the index (within input matrix) it hashes to. + * + * @return {int} subword id + * + */ + getSubwordId(subword) { + return this.f.getSubwordId(subword); + } + + /** + * getSubwords + * + * returns the subwords and their indicies. + * + * @param {string} word + * + * @return {Pair.<Array.<string>, Array.<int>>} + * words and their corresponding indicies + * + */ + getSubwords(word) { + return this.f.getSubwords(word); + } + + /** + * getInputVector + * + * Given an index, get the corresponding vector of the Input Matrix. + * + * @param {int} ind + * + * @return {Float32Array} the vector of the `ind`'th index + * + */ + getInputVector(ind) { + const b = getFloat32ArrayFromHeap(this.getDimension()); + this.f.getInputVector(b, ind); + + return heapToFloat32(b); + } + + /** + * predict + * + * Given a string, get a list of labels and a list of corresponding + * probabilities. k controls the number of returned labels. + * + * @param {string} text + * @param {int} k, the number of predictions to be returned + * @param {number} probability threshold + * + * @return {Array.<Pair.<number, string>>} + * labels and their probabilities + * + */ + predict(text, k = 1, threshold = 0.0) { + return this.f.predict(text, k, threshold); + } + + /** + * getInputMatrix + * + * Get a reference to the full input matrix of a Model. This only + * works if the model is not quantized. + * + * @return {DenseMatrix} + * densematrix with functions: `rows`, `cols`, `at(i,j)` + * + * example: + * let inputMatrix = model.getInputMatrix(); + * let value = inputMatrix.at(1, 2); + */ + getInputMatrix() { + if (this.isQuant()) { + throw new Error("Can't get quantized Matrix"); + } + return this.f.getInputMatrix(); + } + + /** + * getOutputMatrix + * + * Get a reference to the full input matrix of a Model. This only + * works if the model is not quantized. + * + * @return {DenseMatrix} + * densematrix with functions: `rows`, `cols`, `at(i,j)` + * + * example: + * let outputMatrix = model.getOutputMatrix(); + * let value = outputMatrix.at(1, 2); + */ + getOutputMatrix() { + if (this.isQuant()) { + throw new Error("Can't get quantized Matrix"); + } + return this.f.getOutputMatrix(); + } + + /** + * getWords + * + * Get the entire list of words of the dictionary including the frequency + * of the individual words. This does not include any subwords. For that + * please consult the function get_subwords. + * + * @return {Pair.<Array.<string>, Array.<int>>} + * words and their corresponding frequencies + * + */ + getWords() { + return this.f.getWords(); + } + + /** + * getLabels + * + * Get the entire list of labels of the dictionary including the frequency + * of the individual labels. + * + * @return {Pair.<Array.<string>, Array.<int>>} + * labels and their corresponding frequencies + * + */ + getLabels() { + return this.f.getLabels(); + } + + /** + * getLine + * + * Split a line of text into words and labels. Labels must start with + * the prefix used to create the model (__label__ by default). + * + * @param {string} text + * + * @return {Pair.<Array.<string>, Array.<string>>} + * words and labels + * + */ + getLine(text) { + return this.f.getLine(text); + } + + /** + * saveModel + * + * Saves the model file in web assembly in-memory FS and returns a blob + * + * @return {Blob} blob data of the file saved in web assembly FS + * + */ + saveModel() { + this.f.saveModel(modelFileInWasmFs); + const content = fastTextModule.FS.readFile(modelFileInWasmFs, + { encoding: 'binary' }); + return new Blob( + [new Uint8Array(content, content.byteOffset, content.length)], + { type: ' application/octet-stream' } + ); + } + + /** + * test + * + * Downloads the test file from the specified url, evaluates the supervised + * model with it. + * + * @param {string} url + * @param {int} k, the number of predictions to be returned + * @param {number} probability threshold + * + * @return {Promise} promise object that resolves to a `Meter` object + * + * example: + * model.test("/absolute/url/to/test.txt", 1, 0.0).then((meter) => { + * console.log(meter.precision); + * console.log(meter.recall); + * console.log(meter.f1Score); + * console.log(meter.nexamples()); + * }); + * + */ + test(url, k, threshold) { + const fetchFunc = (thisModule && thisModule.fetch) || fetch; + const fastTextNative = this.f; + + return new Promise(function(resolve, reject) { + fetchFunc(url).then(response => { + return response.arrayBuffer(); + }).then(bytes => { + const byteArray = new Uint8Array(bytes); + const FS = fastTextModule.FS; + FS.writeFile(testFileInWasmFs, byteArray); + }).then(() => { + const meter = fastTextNative.test(testFileInWasmFs, k, threshold); + resolve(meter); + }).catch(error => { + reject(error); + }); + }); + } +} |