diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-15 03:34:42 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-15 03:34:42 +0000 |
commit | da4c7e7ed675c3bf405668739c3012d140856109 (patch) | |
tree | cdd868dba063fecba609a1d819de271f0d51b23e /toolkit/components/ml/tests | |
parent | Adding upstream version 125.0.3. (diff) | |
download | firefox-da4c7e7ed675c3bf405668739c3012d140856109.tar.xz firefox-da4c7e7ed675c3bf405668739c3012d140856109.zip |
Adding upstream version 126.0.upstream/126.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'toolkit/components/ml/tests')
6 files changed, 416 insertions, 0 deletions
diff --git a/toolkit/components/ml/tests/browser/browser.toml b/toolkit/components/ml/tests/browser/browser.toml index 9ccda0beaa..57637c8bda 100644 --- a/toolkit/components/ml/tests/browser/browser.toml +++ b/toolkit/components/ml/tests/browser/browser.toml @@ -1,5 +1,9 @@ [DEFAULT] support-files = [ "head.js", + "data/**/*.*" ] + +["browser_ml_cache.js"] + ["browser_ml_engine.js"] diff --git a/toolkit/components/ml/tests/browser/browser_ml_cache.js b/toolkit/components/ml/tests/browser/browser_ml_cache.js new file mode 100644 index 0000000000..d8725368bd --- /dev/null +++ b/toolkit/components/ml/tests/browser/browser_ml_cache.js @@ -0,0 +1,361 @@ +/* Any copyright is dedicated to the Public Domain. +http://creativecommons.org/publicdomain/zero/1.0/ */ +"use strict"; + +const { sinon } = ChromeUtils.importESModule( + "resource://testing-common/Sinon.sys.mjs" +); + +// Root URL of the fake hub, see the `data` dir in the tests. +const FAKE_HUB = + "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data"; + +const FAKE_MODEL_ARGS = { + organization: "acme", + modelName: "bert", + modelVersion: "main", + file: "config.json", +}; + +const FAKE_ONNX_MODEL_ARGS = { + organization: "acme", + modelName: "bert", + modelVersion: "main", + file: "onnx/config.json", +}; + +const badHubs = [ + "https://my.cool.hub", + "https://sub.localhost/myhub", // Subdomain of allowed domain + "https://model-hub.mozilla.org.evil.com", // Manipulating path to mimic domain + "httpsz://localhost/myhub", // Similar-looking scheme + "https://localhost.", // Trailing dot in domain + "resource://user@localhost", // User info in URL + "ftp://localhost/myhub", // Disallowed scheme with allowed host + "https://model-hub.mozilla.org.hack", // Domain that contains allowed domain +]; + +add_task(async function test_bad_hubs() { + for (const badHub of badHubs) { + Assert.throws( + () => new ModelHub({ rootUrl: badHub }), + new RegExp(`Error: Invalid model hub root url: ${badHub}`), + `Should throw with ${badHub}` + ); + } +}); + +let goodHubs = [ + "https:///localhost/myhub", // Triple slashes, see https://stackoverflow.com/a/22775589 + "https://localhost:8080/myhub", + "http://localhost/myhub", + "https://model-hub.mozilla.org", + "chrome://gre/somewhere/in/the/code/base", +]; + +add_task(async function test_allowed_hub() { + goodHubs.forEach(url => new ModelHub({ rootUrl: url })); +}); + +const badInputs = [ + [ + { + organization: "ac me", + modelName: "bert", + modelVersion: "main", + file: "config.json", + }, + "Org can only contain letters, numbers, and hyphens", + ], + [ + { + organization: "1111", + modelName: "bert", + modelVersion: "main", + file: "config.json", + }, + "Org cannot contain only numbers", + ], + [ + { + organization: "-acme", + modelName: "bert", + modelVersion: "main", + file: "config.json", + }, + "Org start or end with a hyphen, or use consecutive hyphens", + ], + [ + { + organization: "a-c-m-e", + modelName: "#bert", + modelVersion: "main", + file: "config.json", + }, + "Models can only contain letters, numbers, and hyphens, underscord, periods", + ], + [ + { + organization: "a-c-m-e", + modelName: "b$ert", + modelVersion: "main", + file: "config.json", + }, + "Models cannot contain spaces or control characters", + ], + [ + { + organization: "a-c-m-e", + modelName: "b$ert", + modelVersion: "main", + file: ".filename", + }, + "File", + ], +]; + +add_task(async function test_bad_inputs() { + const hub = new ModelHub({ rootUrl: FAKE_HUB }); + + for (const badInput of badInputs) { + const params = badInput[0]; + const errorMsg = badInput[1]; + try { + await hub.getModelFileAsArrayBuffer(params); + } catch (error) { + continue; + } + throw new Error(errorMsg); + } +}); + +add_task(async function test_getting_file() { + const hub = new ModelHub({ rootUrl: FAKE_HUB }); + + let [array, headers] = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS); + + Assert.equal(headers["Content-Type"], "application/json"); + + // check the content of the file. + let jsonData = JSON.parse( + String.fromCharCode.apply(null, new Uint8Array(array)) + ); + + Assert.equal(jsonData.hidden_size, 768); +}); + +add_task(async function test_getting_file_in_subdir() { + const hub = new ModelHub({ rootUrl: FAKE_HUB }); + + let [array, metadata] = await hub.getModelFileAsArrayBuffer( + FAKE_ONNX_MODEL_ARGS + ); + + Assert.equal(metadata["Content-Type"], "application/json"); + + // check the content of the file. + let jsonData = JSON.parse( + String.fromCharCode.apply(null, new Uint8Array(array)) + ); + + Assert.equal(jsonData.hidden_size, 768); +}); + +add_task(async function test_getting_file_custom_path() { + const hub = new ModelHub({ + rootUrl: FAKE_HUB, + urlTemplate: "${organization}/${modelName}/resolve/${modelVersion}/${file}", + }); + + let res = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS); + + Assert.equal(res[1]["Content-Type"], "application/json"); +}); + +add_task(async function test_getting_file_custom_path_rogue() { + const urlTemplate = + "${organization}/${modelName}/resolve/${modelVersion}/${file}?some_id=bedqwdw"; + Assert.throws( + () => new ModelHub({ rootUrl: FAKE_HUB, urlTemplate }), + /Invalid URL template/, + `Should throw with ${urlTemplate}` + ); +}); + +add_task(async function test_getting_file_as_response() { + const hub = new ModelHub({ rootUrl: FAKE_HUB }); + + let response = await hub.getModelFileAsResponse(FAKE_MODEL_ARGS); + + // check the content of the file. + let jsonData = await response.json(); + Assert.equal(jsonData.hidden_size, 768); +}); + +add_task(async function test_getting_file_from_cache() { + const hub = new ModelHub({ rootUrl: FAKE_HUB }); + let array = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS); + + // stub to verify that the data was retrieved from IndexDB + let matchMethod = hub.cache._testGetData; + + sinon.stub(hub.cache, "_testGetData").callsFake(function () { + return matchMethod.apply(this, arguments).then(result => { + Assert.notEqual(result, null); + return result; + }); + }); + + // exercises the cache + let array2 = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS); + hub.cache._testGetData.restore(); + + Assert.deepEqual(array, array2); +}); + +// IndexedDB tests + +/** + * Helper function to initialize the cache + */ +async function initializeCache() { + const randomSuffix = Math.floor(Math.random() * 10000); + return await IndexedDBCache.init(`modelFiles-${randomSuffix}`); +} + +/** + * Helper function to delete the cache database + */ +async function deleteCache(cache) { + await cache.dispose(); + indexedDB.deleteDatabase(cache.dbName); +} + +/** + * Test the initialization and creation of the IndexedDBCache instance. + */ +add_task(async function test_Init() { + const cache = await initializeCache(); + Assert.ok( + cache instanceof IndexedDBCache, + "The cache instance should be created successfully." + ); + Assert.ok( + IDBDatabase.isInstance(cache.db), + `The cache should have an IDBDatabase instance. Found ${cache.db}` + ); + await deleteCache(cache); +}); + +/** + * Test adding data to the cache and retrieving it. + */ +add_task(async function test_PutAndGet() { + const cache = await initializeCache(); + const testData = new ArrayBuffer(8); // Example data + await cache.put("org", "model", "v1", "file.txt", testData, { + ETag: "ETAG123", + }); + + const [retrievedData, headers] = await cache.getFile( + "org", + "model", + "v1", + "file.txt" + ); + Assert.deepEqual( + retrievedData, + testData, + "The retrieved data should match the stored data." + ); + Assert.equal( + headers.ETag, + "ETAG123", + "The retrieved ETag should match the stored ETag." + ); + + await deleteCache(cache); +}); + +/** + * Test retrieving the headers for a cache entry. + */ +add_task(async function test_GetHeaders() { + const cache = await initializeCache(); + const testData = new ArrayBuffer(8); + const headers = { + ETag: "ETAG123", + status: 200, + extra: "extra", + }; + + await cache.put("org", "model", "v1", "file.txt", testData, headers); + + const storedHeaders = await cache.getHeaders( + "org", + "model", + "v1", + "file.txt" + ); + + // The `extra` field should be removed from the stored headers because + // it's not part of the allowed keys. + // The content-type one is added when not present + Assert.deepEqual( + { + ETag: "ETAG123", + status: 200, + "Content-Type": "application/octet-stream", + }, + storedHeaders, + "The retrieved headers should match the stored headers." + ); + await deleteCache(cache); +}); + +/** + * Test listing all models stored in the cache. + */ +add_task(async function test_ListModels() { + const cache = await initializeCache(); + await cache.put( + "org1", + "modelA", + "v1", + "file1.txt", + new ArrayBuffer(8), + null + ); + await cache.put( + "org2", + "modelB", + "v1", + "file2.txt", + new ArrayBuffer(8), + null + ); + + const models = await cache.listModels(); + Assert.ok( + models.includes("org1/modelA/v1") && models.includes("org2/modelB/v1"), + "All models should be listed." + ); + await deleteCache(cache); +}); + +/** + * Test deleting a model and its data from the cache. + */ +add_task(async function test_DeleteModel() { + const cache = await initializeCache(); + await cache.put("org", "model", "v1", "file.txt", new ArrayBuffer(8), null); + await cache.deleteModel("org", "model", "v1"); + + const dataAfterDelete = await cache.getFile("org", "model", "v1", "file.txt"); + Assert.equal( + dataAfterDelete, + null, + "The data for the deleted model should not exist." + ); + await deleteCache(cache); +}); diff --git a/toolkit/components/ml/tests/browser/data/README.md b/toolkit/components/ml/tests/browser/data/README.md new file mode 100644 index 0000000000..d826cf7ee6 --- /dev/null +++ b/toolkit/components/ml/tests/browser/data/README.md @@ -0,0 +1,5 @@ +# fake hub + +This directory is a fake hub that is served via chrome://global/content/ml/tests. + +All files in this directory are included with a wildcard in the component `jar.md`. diff --git a/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/config.json b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/config.json new file mode 100644 index 0000000000..50dbb760bb --- /dev/null +++ b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/config.json @@ -0,0 +1,21 @@ +{ + "architectures": ["BertForMaskedLM"], + "attention_probs_dropout_prob": 0.1, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.6.0.dev0", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 30522 +} diff --git a/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/onnx/config.json b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/onnx/config.json new file mode 100644 index 0000000000..50dbb760bb --- /dev/null +++ b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/onnx/config.json @@ -0,0 +1,21 @@ +{ + "architectures": ["BertForMaskedLM"], + "attention_probs_dropout_prob": 0.1, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.6.0.dev0", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 30522 +} diff --git a/toolkit/components/ml/tests/browser/head.js b/toolkit/components/ml/tests/browser/head.js index 99d27ce18a..9fc20c0e84 100644 --- a/toolkit/components/ml/tests/browser/head.js +++ b/toolkit/components/ml/tests/browser/head.js @@ -19,6 +19,10 @@ const { MLEngineParent } = ChromeUtils.importESModule( "resource://gre/actors/MLEngineParent.sys.mjs" ); +const { ModelHub, IndexedDBCache } = ChromeUtils.importESModule( + "chrome://global/content/ml/ModelHub.sys.mjs" +); + // This test suite shares some utility functions with translations as they work in a very // similar fashion. Eventually, the plan is to unify these two components. Services.scriptloader.loadSubScript( |